diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/_VF.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/_VF.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a8d182222e50fc3488750a0476f60cfbc14b0d50 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/_VF.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/__config__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/__config__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2699969d12cc688320be02137cb244b5a8d75917 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/__config__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/__future__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/__future__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e3591bf9ab4ebe945f4f70623cedfd03b3bc4a51 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/__future__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/_appdirs.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/_appdirs.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cde662d3daa0a01dfea09f8edc496ebd775914c0 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/_appdirs.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/_classes.cpython-312.pyc 
b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/_classes.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..35557c4ca65d224441374ce4e0dbf0e2440c9fae Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/_classes.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/_compile.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/_compile.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1d3e2766b236e870138b9babe6c6c744e6f9b61e Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/_compile.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/_custom_ops.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/_custom_ops.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..12d1b491dc03309be0216e3a38427c044a5e799c Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/_custom_ops.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/_environment.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/_environment.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..62be5b1f18928015f4f8a662d868c20e6b410c48 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/_environment.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/_guards.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/_guards.cpython-312.pyc new file 
mode 100644 index 0000000000000000000000000000000000000000..d3f4e3f25fbd547c19b49a07ec55360d542d7f4e Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/_guards.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/_jit_internal.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/_jit_internal.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e2c84ee5f7694229830edc021ac4c3a1d1256d03 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/_jit_internal.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/_linalg_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/_linalg_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a3e1fc0fa5b54e526b9c0bac928cbef9e8af5bb3 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/_linalg_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/_lobpcg.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/_lobpcg.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eaa053b2701a5849b26f95234ab9130337e9b661 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/_lobpcg.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/_lowrank.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/_lowrank.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..07112d63e04da530cfbdf4eded5d1b99ec09e64b 
Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/_lowrank.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/_namedtensor_internals.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/_namedtensor_internals.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b0bb729ad680e778689cc00fa8b31095c2a16955 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/_namedtensor_internals.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/_ops.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/_ops.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e5c978a0a4c90440bab2d70ba6444888f1c7d190 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/_ops.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/_python_dispatcher.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/_python_dispatcher.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6fde9883cba9835cd57b11783fd288554e9892ae Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/_python_dispatcher.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/_size_docs.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/_size_docs.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b569ac3017281b7c834dba1a2f89511578093590 Binary files /dev/null and 
b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/_size_docs.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/_sources.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/_sources.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..230aa1be342b040f130475fe82548198f6596e98 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/_sources.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/_storage_docs.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/_storage_docs.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a8091ba43cd059f98c5473adbd2e704afd3989f4 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/_storage_docs.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/_streambase.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/_streambase.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..af6de25cd9d87492bb0c8aba7b2f4e2ef5872f5f Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/_streambase.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/_tensor.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/_tensor.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7f10908d6a9412ee07962301ecaf2c2c9440240f Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/_tensor.cpython-312.pyc 
differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/_tensor_str.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/_tensor_str.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..41683e67fed2578b068a03adeb271b8b7f0eda5d Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/_tensor_str.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/_thread_safe_fork.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/_thread_safe_fork.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..90d67952214192bb80d894cd0e7cf1578aa0d1d5 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/_thread_safe_fork.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c0cd8b7dc23b32801abdb7cd341d693d7fd84f42 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/_utils_internal.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/_utils_internal.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..af2efa9a7a4b7dcff28c89304c59e7c7bb2fc1fa Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/_utils_internal.cpython-312.pyc differ diff --git 
a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/_vmap_internals.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/_vmap_internals.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f6cfb2175705085060eba08e1e7ca0f0902f7fae Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/_vmap_internals.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/_weights_only_unpickler.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/_weights_only_unpickler.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..01b4785ca4153dadb600347a255e8db02518266c Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/_weights_only_unpickler.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/functional.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/functional.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..85e8eb3285867ba99d51a1796d591467ffc266a5 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/functional.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/hub.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/hub.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9d2f6828b61202ebfea1119db5e353d70f37ebd1 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/hub.cpython-312.pyc differ diff --git 
a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/library.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/library.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b9db4a8a4bd5f1507ecd560369fed6f018228aee Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/library.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/quasirandom.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/quasirandom.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c1067b0e5f88686697039fbd280f570aac8cfcd4 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/quasirandom.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/random.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/random.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..780afff32ad35e98dda87cb93df13d4af19139e8 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/random.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/return_types.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/return_types.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a82c1c6656555961cf5e180b305e8afd37bc9842 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/return_types.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/serialization.cpython-312.pyc 
b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/serialization.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9591d092ecf66174b0d589d4a09aca5e09c3f4de Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/serialization.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/storage.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/storage.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..93c812e01701ae6f1c261b028708e483cbcb169f Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/storage.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/torch_version.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/torch_version.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e5cf75f3d5f34333ae3eedf85b7aabdafde18a99 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/torch_version.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/types.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/types.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..23a6e91aa1acad1644d9178c5a67da43347bb5f9 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/types.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/version.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/version.cpython-312.pyc new file mode 
100644 index 0000000000000000000000000000000000000000..85bea04b1e08cfae38e882f759dfd58f604f6ab9 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/__pycache__/version.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_awaits/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_awaits/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b08067bdcf45a17dbcf3e032b4156315d9e2981b --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_awaits/__init__.py @@ -0,0 +1,53 @@ +from __future__ import annotations + +from typing import Generic, TypeVar + +import torch + +__all__ = ['Await'] + +W = TypeVar("W") + +class _PyAwaitMeta(type(torch._C._Await), type(Generic)): # type: ignore[misc, no-redef] + pass + +class _Await(torch._C._Await, Generic[W], metaclass=_PyAwaitMeta): + r""" + Wrapper around a ``torch._C.Await`` which encapsulates delayed execution + of a callable. All manipulations happen with functions ``torch.jit._awaitable``, + ``torch.jit._awaitable_wait``, ``torch.jit._awaitable_nowait``. + + Torch scriptable manipulations: + ``torch.jit._awaitable(func, *args)`` + Creates ``Await[W]`` object, where W is return type of func. + + Returns: + ``torch.jit._awaitable_wait(Await[W])`` + Returns the result of the function, specified at ``_awaitable``, with specified arguments. + + Returns: + The result of type ``W`` of the function call. The result is owned by ``Await[W]`` + and returned on all following ``_awaitable_wait`` calls. + + + ``torch.jit._awaitable_nowait(W)`` + Returns: + Trivial ``Await[W]`` with specified result. + + + Only in eager mode: + ``fn() -> Callable[Tuple[Any], W]`` + Returns: + Specified at ``_awaitable`` python function ``func``. + + ``args() -> Tuple[Any]`` + Returns: + Specified at ``_awaitable`` python args. 
+ + ``is_nowait() -> _bool`` + Returns: + ``True`` if this object was created via ``_awaitable_nowait`` call (trivial `Await[W]`). + + In eager mode ``Await[W]`` can be used as ``W`` i.e. attributes of W can be called on ``Await[W]``, + ``_awaitable_wait()`` call will be transparently added. + """ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_decomp/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_decomp/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a321a49ac142e637d87eb09433659442e3b47004 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_decomp/__init__.py @@ -0,0 +1,549 @@ +# mypy: allow-untyped-defs +import inspect +from collections import defaultdict +from collections.abc import Callable, Sequence +from functools import lru_cache, partial, wraps +from itertools import chain +from typing import Optional, TYPE_CHECKING, TypeVar, Union +from typing_extensions import ParamSpec + + +if TYPE_CHECKING: + from torch.export.decomp_utils import CustomDecompTable + +import torch +import torch.library +from torch._ops import HigherOrderOperator, OperatorBase, OpOverload, OpOverloadPacket +from torch._prims_common import CustomOutParamAnnotation +from torch._subclasses.functional_tensor import FunctionalTensor +from torch.utils import _pytree as pytree + + +__all__ = [ + "decomposition_table", + "pre_autograd_decomposition_table", + "meta_table", + "register_decomposition", + "get_decompositions", + "core_aten_decompositions", + "_should_decompose_because_unsafe_op", +] + +_T = TypeVar("_T") +_P = ParamSpec("_P") + +# TODO: relax key type here; torch registrations should be possible to; but +# right now this type is accurate +global_decomposition_table: dict[str, dict[torch._ops.OperatorBase, Callable]] = ( + defaultdict(dict) +) + +decomposition_table = global_decomposition_table["post_autograd"] +pre_autograd_decomposition_table = 
global_decomposition_table["pre_autograd"] +meta_table = global_decomposition_table["meta"] + + +def _should_decompose_because_unsafe_op(op: torch._ops.OperatorBase) -> bool: + """ + Returns True if the op must always decompose in export/compile tracing system + + In export, we always decompose certain CIA ops that are tagged with + maybe_aliasing_or_mutating because we statically need to know if the op is + mutating or not. But these CIA ops could have different behaviour in runtime. + + native_batch_norm is a prim op which has a wrong schema and it needs to be replaced + with correct schema. But until then, we will force decompose it via this tag. + """ + if not isinstance(op, torch._ops.OpOverload): + return False + if torch.Tag.maybe_aliasing_or_mutating in op.tags: + return True + return op is torch.ops.aten.native_batch_norm.default + + +def _add_op_to_registry(registry, op, fn): + """ + This is an internal API for adding an op to the decomposition table. + + If op is OpOverload, it will be added to the registry directly. + If op is OpOverloadPacket, all the valid op_overloads in the packet will be added to the registry. 
+ """ + overloads: list[Union[torch._ops.OperatorBase]] = [] + if isinstance(op, HigherOrderOperator): + # There's no concept of overloads for HigherOrderOperator + registry[op] = fn + return + elif isinstance(op, OpOverload): + overloads.append(op) + else: + assert isinstance(op, OpOverloadPacket) + for ol in op.overloads(): + overloads.append(getattr(op, ol)) + + for op_overload in overloads: + if op_overload in registry: + raise RuntimeError(f"duplicate registrations for {op_overload}") + # TorchScript dumps a bunch of extra nonsense overloads + # which don't have corresponding dispatcher entries, we need + # to filter those out, e.g aten.add.float_int + if torch._C._dispatch_has_kernel(op_overload.name()): + registry[op_overload] = fn + + +def _convert_out_params(f): + out_annotation = f.__annotations__.get("out") + + # If there are no out params, do not wrap the function. + if not out_annotation: + return f + + # Hack to detect when out is a Tuple. There seems to be no pretty way of doing this + if getattr(out_annotation, "__origin__", None) is tuple: + sig = inspect.signature(f) + out_names = sig.return_annotation._fields + # If out is a tuple, we need to register a function that unpacks all the out + # elements as this is what native_functions.yaml expects + + @wraps(f) + def _fn(*args, **kwargs): + out_kwargs = tuple(kwargs.pop(o, None) for o in out_names) + # Either all of the out kwargs are set or none of them + is_none = out_kwargs[0] is None + assert all((o is None) == is_none for o in out_kwargs) + return f(*args, **kwargs, out=None if is_none else out_kwargs) + + out_params = [ + inspect.Parameter( + o, + kind=inspect.Parameter.KEYWORD_ONLY, + default=None, + annotation=t, + ) + for o, t in zip(out_names, out_annotation.__args__) + ] + # Drop the out parameter and concatenate the new kwargs in the signature + params = chain((v for k, v in sig.parameters.items() if k != "out"), out_params) + _fn.__signature__ = inspect.Signature( # type: 
ignore[attr-defined] + parameters=params, # type: ignore[arg-type] + return_annotation=sig.return_annotation, + ) + # Drop the out parameter and concatenate the new kwargs in the annotations + _fn.__annotations__ = {k: v for k, v in f.__annotations__.items() if k != "out"} + for o in out_params: + _fn.__annotations__[o.name] = o.annotation + + # Propagate that this function is wrapped by `out_wrapper` + _fn._torch_decompositions_out_wrapper = f._torch_decompositions_out_wrapper # type: ignore[attr-defined] + + return _fn + + # Alternatively, there may be a single tensor out parameter with a name + # other than "out". This will need special treatment and is indicated by an + # annotation, which we will remove here so it is not exposed after wrapping. + custom_out_param_name = f.__annotations__.pop(CustomOutParamAnnotation, None) + if custom_out_param_name: + + @wraps(f) + def _fn(*args, **kwargs): + out_kwarg = kwargs.pop(custom_out_param_name, None) + return f(*args, **kwargs, out=out_kwarg) + + out_param = inspect.Parameter( + custom_out_param_name, + kind=inspect.Parameter.KEYWORD_ONLY, + default=None, + annotation=out_annotation, + ) + + # Drop the out parameter and concatenate the new kwarg in the signature + sig = inspect.signature(f) + params = chain( + (v for k, v in sig.parameters.items() if k != "out"), (out_param,) + ) + _fn.__signature__ = inspect.Signature( # type: ignore[attr-defined] + parameters=params, # type: ignore[arg-type] + return_annotation=sig.return_annotation, + ) + + # Drop the out parameter and concatenate the new kwargs in the annotations + _fn.__annotations__ = {k: v for k, v in f.__annotations__.items() if k != "out"} + _fn.__annotations__[out_param.name] = out_param.annotation + + return _fn + + return f + + +def register_decomposition( + aten_op, registry=None, *, type="post_autograd", unsafe=False +) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: + """ + A decorator to register a function as a decomposition to the Python + 
decomposition table. Use it like this:: + + @register_decomposition(torch.ops.aten.clamp_min) + def clamp_min(x): + return torch.clamp(self, min=min) + + If you are writing a new decomposition, consider contributing it + directly to PyTorch in torch._decomp.decompositions. + + This API is experimental; we are almost certainly going to extend + the API when we make decompositions eligible for use in transforms (e.g., + autograd) and not just backend tracing, where we then need to know if a + decomposition can be used to simulate a transform. + + By default, we also will register it to the Meta key of dispatcher, + and replace the c++ Meta implementation if there is already one. + + unsafe kwarg is for reuse of this function for registering non-function + things + """ + + assert type in {"post_autograd", "pre_autograd", "meta"} + + def decomposition_decorator(fn: Callable[_P, _T]) -> Callable[_P, _T]: + orig_fn = fn + if not unsafe: + fn = _convert_out_params(fn) + + nonlocal registry + if registry is None: + registry = global_decomposition_table[type] + + def register(op): + _add_op_to_registry(registry, op, fn) + + # To handle allowing multiple aten_ops at once + pytree.tree_map_(register, aten_op) + return orig_fn + + return decomposition_decorator + + +def get_decompositions( + aten_ops: Sequence[Union[torch._ops.OperatorBase, OpOverloadPacket]], + type: str = "post_autograd", +) -> dict[torch._ops.OperatorBase, Callable]: + """ + Retrieve a dictionary of decompositions corresponding to the list of + operator overloads and overload packets passed as input. Overload + packets will include all decomposed overloads in the packet. If there is + no decomposition for a requested operator, it is silently ignored. + + This API is experimental; we are almost certainly going to give an alternate, + more recommended formulation, where a user provides the set of operators + they know how to implement, and we provide decompositions for everything + not in this set. 
+ """ + assert type in {"post_autograd", "pre_autograd", "meta"} + + registry = global_decomposition_table[type] + packets_to_overloads = defaultdict(list) + + for opo in registry: + if isinstance(opo, (OpOverload, OpOverloadPacket)): + packets_to_overloads[opo.overloadpacket].append(opo) + decompositions: dict[torch._ops.OperatorBase, Callable] = {} + for op in aten_ops: + if isinstance(op, OpOverloadPacket) and op in packets_to_overloads: + for op_overload in packets_to_overloads[op]: + decompositions[op_overload] = registry[op_overload] + elif isinstance(op, (torch._ops.OperatorBase)) and op in registry: + decompositions[op] = registry[op] + return decompositions + + +def remove_decompositions( + decompositions: dict[torch._ops.OperatorBase, Callable], + aten_ops: Sequence[Union[OpOverload, OpOverloadPacket]], +) -> None: + """ + Given a dictionary of decompositions obtained from get_decompositions(), removes + operators associated with a list of operator overloads and overload packets passed + as input. If the decomposition dictionary does not contain a decomposition that is + specified to be removed, it is silently ignored. 
+ """ + for op in aten_ops: + if isinstance(op, OpOverloadPacket): + for overload_name in op.overloads(): + opo = getattr(op, overload_name) + decompositions.pop(opo, None) + elif isinstance(op, OpOverload): + decompositions.pop(op, None) + + +# populate the table +import torch._decomp.decompositions +import torch._refs + + +def core_aten_decompositions() -> "CustomDecompTable": + from torch.export.exported_program import default_decompositions + + return default_decompositions() + + +# See NOTE [Core ATen Ops] +# +# list was copied from torch/_inductor/decomposition.py +# excluding decompositions that results in prim ops +# Resulting opset of decomposition is core aten ops +def _core_aten_decompositions_post_autograd() -> dict[ + torch._ops.OperatorBase, Callable +]: + aten = torch.ops.aten + return get_decompositions( + [ + aten.addcdiv, + aten.addcdiv_, + aten.addcmul, + aten.addcmul_, + aten.addr, + aten.affine_grid_generator, + aten.alias_copy, + aten.all, + aten.aminmax, + aten.arange.default, + aten.arange.start, + aten.avg_pool2d_backward, + aten.baddbmm, + aten.binary_cross_entropy, + aten.binary_cross_entropy_backward, + aten.binary_cross_entropy_with_logits, + aten.block_diag, + aten.bernoulli.p, + aten.bernoulli.default, + aten.celu, + aten.celu_, + aten.channel_shuffle, + aten.clamp_max, + aten.clamp_min, + aten.col2im, + aten.count_nonzero, + aten.linalg_cross, + aten.cudnn_batch_norm, + aten.cudnn_batch_norm_backward, + aten.miopen_batch_norm_backward, + aten.deg2rad, + aten.deg2rad_, + aten.detach, + aten.diag_embed, + aten.diagonal_backward, + aten.diagonal_copy, + aten.dot, + aten.vdot, + aten.elu_, + aten.elu_backward, + aten._embedding_bag, + aten.embedding_dense_backward, + aten.empty_like, + aten._euclidean_dist.default, + aten.expand_as, + aten.expand_copy, + aten.eye, + aten.fill, + aten.fill_, + aten.floor_divide, + aten.frac, + aten.frac_, + aten._fused_moving_avg_obs_fq_helper, + aten.gelu_, + aten.gelu_backward, + aten.glu, + 
aten.glu_backward, + aten.hardshrink, + aten.hardsigmoid, + aten.hardsigmoid_, + aten.hardsigmoid_backward, + aten.hardswish, + aten.hardswish_, + aten.hardswish_backward, + aten.hardtanh_, + aten.hardtanh_backward, + aten.heaviside, + aten.heaviside_, + aten.huber_loss, + aten.huber_loss_backward, + aten.im2col, + aten.index_add.out, + aten.index_add.default, + aten.index_add_, + aten.index_copy.out, + aten.index_copy.default, + aten.index_copy_, + aten.index_fill.int_Scalar, + aten.index_fill.int_Tensor, + aten.index_fill.int_Scalar_out, + aten.index_fill.int_Tensor_out, + aten.index_fill_, + aten.isin, + aten.isneginf, + aten.isposinf, + aten.l1_loss, + aten._lazy_clone, + aten._test_parallel_materialize, + aten.leaky_relu_, + aten.leaky_relu_backward, + aten.lerp, + aten.lerp_, + aten.linspace, + aten.logaddexp, + aten.logaddexp2, + aten.logit, + aten.logit_, + aten.logit_backward, + aten.log_sigmoid_backward, + aten.log_sigmoid_forward, + aten._log_softmax_backward_data, + aten.logspace, + aten.logsumexp.default, + aten.masked_fill, + aten.masked_fill_, + aten.max_unpool2d, + aten.max_unpool3d, + aten.mish, + aten.mish_, + aten.mish_backward, + aten.mse_loss, + aten.mse_loss_backward, + aten.multi_margin_loss, + aten.multilabel_margin_loss_forward, + aten.mv, + aten.mvlgamma, + aten.mvlgamma_, + aten.nansum, + aten.nan_to_num, + aten.nan_to_num_, + aten.narrow, + aten.native_batch_norm_backward, + aten.native_dropout_backward, + aten.native_group_norm_backward, + aten.native_layer_norm_backward, + aten._fused_rms_norm, + aten._fused_rms_norm_backward, + aten.new_empty, + aten.new_full, + aten.new_ones, + aten.new_zeros, + aten.nll_loss2d_forward, + aten.nll_loss2d_backward, + aten.nll_loss_backward, + aten.nll_loss_forward, + aten.norm.ScalarOpt_dtype, + aten.norm.Scalar, + aten.norm.ScalarOpt_dim_dtype, + aten.norm.ScalarOpt_dim, + aten.norm.dtype_out, + aten.norm.out, + aten.norm.names_dtype_out, + aten.norm.names_out, + aten.norm.ScalarOpt_dtype_out, + 
aten.norm.Scalar_out, + aten.ones, + aten.ones_like, + aten.pixel_shuffle, + aten.pixel_unshuffle, + aten._prelu_kernel, + aten._prelu_kernel_backward, + aten._reshape_alias, + aten.rad2deg, + aten.rad2deg_, + aten.reflection_pad1d, + aten.reflection_pad1d_backward, + aten.reflection_pad2d, + aten.reflection_pad2d_backward, + aten.reflection_pad3d, + aten.reflection_pad3d_backward, + aten.replication_pad1d, + aten.replication_pad2d, + aten.replication_pad3d, + aten.renorm, + aten.renorm_, + aten.replication_pad2d, + aten.resize_as, + aten.roll, + aten.rot90, + aten.rrelu_with_noise, + aten.rrelu_with_noise_, + aten.rsub, + aten._safe_softmax, + aten._scaled_dot_product_flash_attention_for_cpu.default, + aten.select_backward, + aten.select_scatter, + aten.sgn, + aten.sgn_, + aten.sigmoid_backward, + aten.silu, + aten.silu_, + aten.silu_backward.grad_input, + aten.silu_backward, + aten.sinc, + aten.sinc_, + aten.slice_backward, + aten.smooth_l1_loss, + aten.smooth_l1_loss_backward, + aten.soft_margin_loss, + aten.soft_margin_loss_backward, + aten._softmax_backward_data, + aten.softplus, + aten.softplus_backward, + aten.softshrink, + aten.special_entr, + aten.special_log_ndtr, + aten.special_xlog1py, + aten.split.Tensor, + aten.split_with_sizes_copy, + aten.squeeze_copy, + aten.squeeze.default, + aten.squeeze.dim, + aten.std.correction, + aten.std.out, + aten.std.correction_out, + aten.std.names_out, + aten.std.correction_names_out, + aten.std_mean.correction, + aten.std_mean.correction_out, + aten.stack, + aten.sum.default, + aten.sum.out, + aten.t, + aten.t_copy, + aten.take, + aten.tanh_backward, + aten.threshold, + aten.threshold_, + aten.threshold_backward, + aten.trace, + aten.transpose.int, + aten.transpose_copy, + aten.tril, + aten.tril_, + aten.triu, + aten.triu_, + aten.unbind, + aten.unfold_backward, + aten.unfold_copy, + aten._unsafe_index, + aten._unsafe_index_put, + aten._unsafe_masked_index, + aten._unsafe_masked_index_put_accumulate, + 
aten.unsafe_split.Tensor, + aten.unsafe_split_with_sizes, + aten.unsqueeze_copy, + aten._unsafe_view, + aten.upsample_linear1d, + aten.upsample_bilinear2d.out, + aten.upsample_trilinear3d.out, + aten.upsample_nearest2d_backward, + aten.view_as_complex, + aten.xlogy, + aten.xlogy_, + aten.zero, + aten.zero_, + aten.zeros, + aten.zeros_like, + aten._chunk_cat, + aten._weight_norm_interface, + ] + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_decomp/decompositions.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_decomp/decompositions.py new file mode 100644 index 0000000000000000000000000000000000000000..4446ed5cdd3107f5177284ff08f3455663eeff8d --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_decomp/decompositions.py @@ -0,0 +1,5376 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +import functools +import itertools +import numbers +import operator +import sys +from collections.abc import Callable, Iterable +from contextlib import nullcontext +from enum import Enum +from functools import partial, reduce +from itertools import chain, product +from typing import Any, cast, Optional, Union + +import torch +import torch._meta_registrations +import torch._prims as prims +import torch._prims_common as utils +import torch.nn.functional as F +from torch import sym_float, sym_int, Tensor +from torch._decomp import register_decomposition +from torch._higher_order_ops.out_dtype import out_dtype +from torch._prims_common import ( + IntLike, + NumberType, + suggest_memory_format, + TensorLike, + TensorSequenceType, +) +from torch._prims_common.wrappers import ( + _maybe_convert_to_dtype, + _maybe_resize_out, + _safe_copy_out, + out_wrapper, +) +from torch.utils import _pytree as pytree +from torch.utils._pytree import tree_map + + +DispatchKey = torch._C.DispatchKey # type: ignore[attr-defined] + +# None of these functions are publicly accessible; get at them +# from 
torch._decomps +__all__: list[str] = [] + +aten = torch._ops.ops.aten + + +class Reduction(Enum): + NONE = 0 + MEAN = 1 + SUM = 2 + + +# This wraps a decomposition and performs various type promotion logic within it, depending on the strategy provided +# We're currently reusing ELEMENTWISE_TYPE_PROMOTION_KIND, although some of the usages are on non-elementwise ops +# Will need to validate the non-elementwise uses +def type_casts( + f: Callable, + type_promotion: utils.ELEMENTWISE_TYPE_PROMOTION_KIND, + compute_dtype_only: bool = False, + include_non_tensor_args: bool = False, +): + @functools.wraps(f) + def inner(*args, **kwargs): + allowed_types = ( + (Tensor, torch.types._Number) if include_non_tensor_args else (Tensor,) + ) # type: ignore[arg-type] + flat_args = [ + x + for x in pytree.arg_tree_leaves(*args, **kwargs) + if isinstance(x, allowed_types) + ] + computation_dtype, result_dtype = utils.elementwise_dtypes( + *flat_args, type_promotion_kind=type_promotion + ) + + # TODO: pretty sure this is not quite right + def increase_prec(x): + if isinstance(x, Tensor): + return x.to(computation_dtype) + else: + return x + + def decrease_prec(x): + if isinstance(x, Tensor): + return x.to(result_dtype) + else: + return x + + r = f(*tree_map(increase_prec, args), **tree_map(increase_prec, kwargs)) + if compute_dtype_only: + return r + else: + return tree_map(decrease_prec, r) + + return inner + + +compute_only_pw_cast_for_opmath = partial( + type_casts, + type_promotion=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + compute_dtype_only=True, +) +pw_cast_for_opmath = partial( + type_casts, type_promotion=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT +) +pw_cast_for_opmath_non_tensor_args = partial( + type_casts, + type_promotion=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + include_non_tensor_args=True, +) +pw_cast_for_int_to_real = partial( + type_casts, type_promotion=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT +) + + +# This expands x until x.dim() == 
dim. Might be useful as an operator +def _unsqueeze_to_dim(x: Tensor, dim: int) -> Tensor: + for _ in range(dim - x.dim()): + x = x.unsqueeze(-1) + return x + + +@register_decomposition(aten.tanh_backward) +@out_wrapper("grad_input") +@pw_cast_for_opmath +def tanh_backward(out_grad: Tensor, y: Tensor): + return out_grad * (1 - y * y).conj_physical() + + +@register_decomposition(aten.sigmoid_backward) +@out_wrapper("grad_input") +@pw_cast_for_opmath +def sigmoid_backward(out_grad: Tensor, y: Tensor): + return out_grad * (y * (1 - y)).conj_physical() + + +@register_decomposition(aten.softplus_backward) +@out_wrapper("grad_input") +@pw_cast_for_opmath +def softplus_backward(out_grad: Tensor, x: Tensor, beta: float, threshold: float): + z = (x * beta).exp() + return torch.where((x * beta) > threshold, out_grad, out_grad * z / (z + 1.0)) + + +@register_decomposition(aten.elu_backward) +@out_wrapper("grad_input") +@pw_cast_for_opmath +def elu_backward( + grad_output: Tensor, + alpha: float, + scale: float, + input_scale: float, + is_result: bool, + self_or_result: Tensor, +): + negcoef = alpha * scale + poscoef = scale + negiptcoef = input_scale + if is_result: + return torch.where( + self_or_result <= 0, + grad_output * negiptcoef * (self_or_result + negcoef), + grad_output * poscoef, + ) + else: + return torch.where( + self_or_result <= 0, + grad_output * negiptcoef * negcoef * torch.exp(self_or_result * negiptcoef), + grad_output * poscoef, + ) + + +@register_decomposition([aten.fill.Scalar]) +def fill_scalar(self, value): + return torch.full_like(self, value) + + +@register_decomposition([aten.fill.Tensor]) +def fill_tensor(self, value: Tensor): + torch._check( + value.dim() == 0, + lambda: f"fill only supports 0-dimension value tensor but got tensor with {value.dim()} dimensions", + ) + return aten.copy(self, value) + + +@register_decomposition(aten.hardsigmoid) +@out_wrapper() +@pw_cast_for_opmath +def hardsigmoid(self: Tensor) -> Tensor: + return 
torch.clamp(torch.clamp(self + 3, min=0), max=6) / 6 + + +@register_decomposition(aten.hardsigmoid_backward) +@out_wrapper("grad_input") +@pw_cast_for_opmath +def hardsigmoid_backward(grad_output: Tensor, self: Tensor): + return torch.where( + (self > -3.0) & (self < 3.0), + grad_output * (1.0 / 6.0), + 0.0, + ) + + +@register_decomposition(aten.hardtanh_backward) +@out_wrapper("grad_input") +def hardtanh_backward( + grad_output: Tensor, self: Tensor, min_val: float, max_val: float +): + return torch.where((self <= min_val) | (self >= max_val), 0.0, grad_output) + + +@register_decomposition(aten.hardswish) +@out_wrapper() +@pw_cast_for_opmath +def hardswish(self: Tensor) -> Tensor: + return self * torch.clamp(torch.clamp(self + 3, min=0), max=6) / 6 + + +@register_decomposition(aten.hardswish_backward) +@out_wrapper() +@pw_cast_for_opmath +def hardswish_backward(grad_output: Tensor, self: Tensor) -> Tensor: + return torch.where( + self <= -3, + 0.0, + torch.where(self < 3, grad_output * ((self / 3) + 0.5), grad_output), + ) + + +@register_decomposition(aten.threshold_backward) +@out_wrapper("grad_input") +def threshold_backward(grad_output: Tensor, self: Tensor, threshold: float): + return torch.where(self <= threshold, 0, grad_output) + + +@register_decomposition(aten.leaky_relu_backward) +@out_wrapper("grad_input") +@pw_cast_for_opmath +def leaky_relu_backward( + grad_output: Tensor, self: Tensor, negative_slope: float, self_is_result: bool +): + return torch.where(self > 0, grad_output, grad_output * negative_slope) + + +@register_decomposition(aten.gelu_backward) +@out_wrapper("grad_input") +@pw_cast_for_opmath +def gelu_backward(grad: Tensor, self: Tensor, approximate: str = "none"): + M_SQRT2 = 1.41421356237309504880 + M_SQRT1_2 = 0.70710678118654752440 + M_2_SQRTPI = 1.12837916709551257390 + if approximate == "tanh": + kBeta = M_SQRT2 * M_2_SQRTPI * 0.5 + kKappa = 0.044715 + x_sq = self * self + x_cube = x_sq * self + inner = kBeta * (self + kKappa * x_cube) 
+ tanh_inner = torch.tanh(inner) + + left = 0.5 * self + right = 1 + tanh_inner + + left_derivative = 0.5 * right + + tanh_derivative = 1 - tanh_inner * tanh_inner + inner_derivative = kBeta * (1 + 3 * kKappa * x_sq) + right_derivative = left * tanh_derivative * inner_derivative + + return grad * (left_derivative + right_derivative) + else: + kAlpha = M_SQRT1_2 + kBeta = M_2_SQRTPI * M_SQRT1_2 * 0.5 + cdf = 0.5 * (1 + torch.erf(self * kAlpha)) + pdf = kBeta * torch.exp(self * self * -0.5) + return grad * (cdf + self * pdf) + + +@register_decomposition(aten.mish_backward) +@pw_cast_for_opmath +def mish_backward(grad_output: Tensor, input: Tensor): + input_tanh_softplus = torch.tanh(F.softplus(input)) + input_sigmoid = torch.sigmoid(input) + out = input * input_sigmoid * (1 - input_tanh_softplus * input_tanh_softplus) + return grad_output * (input_tanh_softplus + out) + + +@register_decomposition(aten.silu) +@out_wrapper() +@pw_cast_for_opmath +def silu(self: Tensor) -> Tensor: + return self * torch.sigmoid(self) + + +@register_decomposition(aten.silu_backward) +@out_wrapper("grad_input") +@pw_cast_for_opmath +def silu_backward(grad_output: Tensor, self: Tensor) -> Tensor: + sigmoid = 1 / (1 + torch.exp(-self)) + return grad_output * sigmoid * (1 + self * (1 - sigmoid)) + + +@register_decomposition(aten._prelu_kernel) +def _prelu_kernel(self: Tensor, weight: Tensor) -> Tensor: + return torch.where(self > 0, self, weight * self) + + +@register_decomposition(aten._prelu_kernel_backward) +def _prelu_kernel_backward( + grad_output: Tensor, + self: Tensor, + weight: Tensor, +) -> tuple[Tensor, Tensor]: + input_grad = torch.where(self > 0, grad_output, weight * grad_output) + weight_grad = torch.where(self > 0, 0.0, self * grad_output) + return (input_grad, weight_grad) + + +@register_decomposition(aten.rrelu_with_noise_backward) +@out_wrapper() +@pw_cast_for_opmath +def rrelu_with_noise_backward( + grad_output: Tensor, + self: Tensor, + noise: Tensor, + lower: float, + 
upper: float, + training: bool, + self_is_result: bool, +) -> Tensor: + if training and upper - lower > 1e-6: + return grad_output.mul(noise) + else: + negative_slope = (lower + upper) / 2 + return aten.leaky_relu_backward( + grad_output, self, negative_slope, self_is_result + ) + + +@register_decomposition(aten.log_sigmoid_backward) +@out_wrapper("grad_input") +@pw_cast_for_opmath +def log_sigmoid_backward(grad_output: Tensor, self: Tensor, buffer: Tensor) -> Tensor: + in_negative = self < 0 + max_deriv = torch.where(in_negative, 1, 0) + sign = torch.where(in_negative, 1, -1) + z = torch.exp(-torch.abs(self)) + return grad_output * (max_deriv - sign * (z / (1 + z))) + # CPU has a special formula that uses buffer, but disabled for convenience sake + # return (max_deriv - sign * (buffer / (1 + buffer))) * grad_output + + +def apply_loss_reduction(loss: Tensor, reduction: int): + if reduction == Reduction.MEAN.value: + return torch.mean(loss) + elif reduction == Reduction.SUM.value: + return torch.sum(loss) + else: + return loss + + +def to_real_dtype(dtype: torch.dtype): + if dtype == torch.complex32: + return torch.float16 + elif dtype == torch.complex64: + return torch.float32 + elif dtype == torch.complex128: + return torch.float64 + + +# TODO: None of these loss castings are quite correct, see +# https://github.com/pytorch/pytorch/issues/76870. 
Also, the ATen kernels +# perform the pointwise portion in opmath, but don't maintain it between the +# pointwise portion and the reduction + + +@register_decomposition(aten.mse_loss) +@out_wrapper() +@pw_cast_for_opmath +def mse_loss( + self: Tensor, target: Tensor, reduction: int = Reduction.MEAN.value +) -> Tensor: + # pyrefly: ignore [unsupported-operation] + loss = (self - target) ** 2 + return apply_loss_reduction(loss, reduction) + + +@register_decomposition(aten.mse_loss_backward) +@out_wrapper("grad_input") +@pw_cast_for_opmath +def mse_loss_backward( + grad_output: Tensor, input: Tensor, target: Tensor, reduction: int +): + norm = 2.0 / input.numel() if reduction == Reduction.MEAN.value else 2.0 + return norm * (input - target) * grad_output + + +@register_decomposition(aten._safe_softmax) +def safe_softmax(self, dim, dtype=None): + out = torch.softmax(self, dim=dim, dtype=dtype) + masked = self.eq(float("-inf")) + masked_rows = torch.all(masked, dim=dim, keepdim=True) + zeros = torch.zeros_like(out) + return torch.where(masked_rows, zeros, out) + + +@register_decomposition(aten.smooth_l1_loss) +@out_wrapper() +@pw_cast_for_opmath +def smooth_l1_loss( + self: Tensor, + target: Tensor, + reduction: int = Reduction.MEAN.value, + beta: float = 1.0, +): + loss = (self - target).abs() + # pyrefly: ignore [unsupported-operation] + loss = torch.where(loss < beta, 0.5 * loss**2 / beta, loss - 0.5 * beta) + return apply_loss_reduction(loss, reduction) + + +@register_decomposition(aten.smooth_l1_loss_backward.default) +@pw_cast_for_opmath +def smooth_l1_loss_backward( + grad_output: Tensor, self: Tensor, target: Tensor, reduction: int, beta: float +): + norm = 1.0 / self.numel() if reduction == Reduction.MEAN.value else 1.0 + x = self - target + abs_x = torch.abs(x) + norm_grad = norm * grad_output + return torch.where( + abs_x < beta, + norm_grad * x / beta, + norm_grad * torch.sign(x), + ) + + +@register_decomposition(aten.smooth_l1_loss_backward.grad_input) 
+@pw_cast_for_opmath +def smooth_l1_loss_backward_out( + grad_output: Tensor, + self: Tensor, + target: Tensor, + reduction: int, + beta: float, + grad_input: Tensor, +): + result = smooth_l1_loss_backward(grad_output, self, target, reduction, beta) + _maybe_resize_out(grad_input, result.shape) + return _safe_copy_out(copy_from=result, copy_to=grad_input, exact_dtype=True) + + +@register_decomposition(aten.huber_loss_backward.default) +@pw_cast_for_opmath +def huber_loss_backward( + grad_output: Tensor, self: Tensor, target: Tensor, reduction: int, delta: float +): + norm = 1.0 / self.numel() if reduction == Reduction.MEAN.value else 1.0 + x = self - target + return torch.where( + x < -delta, + -norm * grad_output * delta, + torch.where(x > delta, norm * grad_output * delta, norm * x * grad_output), + ) + + +# We cannot use @out_wrapper() here, because the output tensor is not named 'out', it's 'grad_input' +@register_decomposition(aten.huber_loss_backward.out) +@pw_cast_for_opmath +def huber_loss_backward_out( + grad_output: Tensor, + self: Tensor, + target: Tensor, + reduction: int, + delta: float, + grad_input: Tensor, +): + result = huber_loss_backward(grad_output, self, target, reduction, delta) + _maybe_resize_out(grad_input, result.shape) + return _safe_copy_out(copy_from=result, copy_to=grad_input, exact_dtype=True) + + +def _nll_loss_backward( + grad_output: Tensor, + self: Tensor, + target: Tensor, + weight: Optional[Tensor], + reduction: int, + ignore_index: int, + total_weight: Tensor, +) -> Tensor: + channel_dim = 0 if self.dim() < 2 else 1 + if reduction == Reduction.MEAN.value: + grad_output = grad_output / total_weight + + target = target.unsqueeze(channel_dim) + safe_target = torch.where(target != ignore_index, target, 0) + grad_input = torch.zeros_like(self) + grad_input = torch.scatter(grad_input, channel_dim, safe_target, -1.0) + + if grad_input.dim() > grad_output.dim() > 0: + grad_output = grad_output.unsqueeze(channel_dim) + + if weight is 
not None: + new_shape = [1 for _ in range(self.dim())] + new_shape[channel_dim] = weight.shape[0] + weight = weight.reshape(new_shape) + grad_output = grad_output * weight + + grad_output = torch.where(target != ignore_index, grad_output, 0) + + return grad_input * grad_output + + +@register_decomposition(aten.glu_backward) +@out_wrapper("grad_input") +@pw_cast_for_opmath +def glu_backward(grad_output: Tensor, self: Tensor, dim: int) -> Tensor: + assert self.dim() > 0, "glu does not support 0-dimensional tensors" + wrap_dim = utils.canonicalize_dim(self.dim(), dim) + nIn = self.size(wrap_dim) + assert nIn % 2 == 0, ( + f"Halving dimension must be even, but dimension {wrap_dim} is size {nIn}" + ) + inputSize = nIn // 2 + firstHalf = self.narrow(wrap_dim, 0, inputSize) + secondHalf = self.narrow(wrap_dim, inputSize, inputSize) + gradInputFirstHalf = torch.sigmoid(secondHalf) + gradInputSecondHalf = ( + (1.0 - gradInputFirstHalf) * gradInputFirstHalf * firstHalf * grad_output + ) + gradInputFirstHalf = gradInputFirstHalf * grad_output + return torch.cat([gradInputFirstHalf, gradInputSecondHalf], dim=wrap_dim) + + +@register_decomposition(aten.nll_loss_backward) +@out_wrapper("grad_input") +def nll_loss_backward( + grad_output: Tensor, + self: Tensor, + target: Tensor, + weight: Optional[Tensor], + reduction: int, + ignore_index: int, + total_weight: Tensor, +) -> Tensor: + assert 0 <= self.dim() <= 2, "input tensor should be 1D or 2D" + assert target.dim() <= 1, ( + "0D or 1D target tensor expected, multi-target not supported" + ) + + no_batch_dim = self.dim() == 1 and target.dim() == 0 + assert no_batch_dim or (self.shape[0] == target.shape[0]), ( + f"size mismatch (got input: {self.shape}, target: {target.shape})" + ) + assert total_weight.numel() == 1, ( + "expected total_weight to be a single element tensor, got: ", + f"{total_weight.shape} ({total_weight.numel()} elements)", + ) + + assert weight is None or weight.numel() == self.shape[-1], ( + "weight tensor 
should be defined either for all or no classes" + ) + + if reduction == Reduction.NONE.value and self.dim() == 2: + assert grad_output.dim() == 1 and grad_output.shape[0] == self.shape[0], ( + f"Expected a tensor of dimension 1 and tensor.size[0] == {self.shape[0]} but " + f"got: dimension {grad_output.dim()} and tensor.size[0] == {grad_output.shape[0]}" + ) + else: + assert grad_output.dim() <= 1 and grad_output.numel() == 1, ( + f"Expected a single element grad_output tensor, but got: {grad_output.shape}" + ) + + return _nll_loss_backward( + grad_output, self, target, weight, reduction, ignore_index, total_weight + ) + + +@register_decomposition(aten.nll_loss2d_backward) +@out_wrapper("grad_input") +def nll_loss2d_backward( + grad_output: Tensor, + self: Tensor, + target: Tensor, + weight: Optional[Tensor], + reduction: int, + ignore_index: int, + total_weight: Tensor, +) -> Tensor: + assert self.dim() == 4, ( + f"only batches of spatial inputs supported (4D tensors), but got input of dimension: {self.dim()}" + ) + + assert target.dim() == 3, ( + f"only batches of spatial targets supported (3D tensors) but got targets of dimension: {target.dim()}" + ) + + assert ( + self.shape[0] == target.shape[0] + and self.shape[2] == target.shape[1] + and self.shape[3] == target.shape[2] + ), f"size mismatch (got input: {self.shape}, target: {target.shape}" + + assert total_weight.numel() == 1, ( + "expected total_weight to be a single element tensor, " + f"got: {total_weight.shape} ( {total_weight.numel()}, elements)" + ) + + return _nll_loss_backward( + grad_output, self, target, weight, reduction, ignore_index, total_weight + ) + + +@register_decomposition(aten.binary_cross_entropy) +@out_wrapper() +@pw_cast_for_opmath +def binary_cross_entropy( + self: Tensor, + target: Tensor, + weight: Optional[Tensor] = None, + reduction: int = Reduction.MEAN.value, +) -> Tensor: + # We cannot currently model this without introducing data-dependent control flow + # TORCH_CHECK( + # 
(input_val >= 0) && (input_val <= 1), + # "all elements of input should be between 0 and 1" + # ) + loss = (target - 1) * torch.maximum( + torch.log1p(-self), self.new_full((), -100) + ) - target * torch.maximum(torch.log(self), self.new_full((), -100)) + if weight is not None: + loss = loss * weight + return apply_loss_reduction(loss, reduction) + + +@register_decomposition(aten.binary_cross_entropy_backward) +@out_wrapper("grad_input") +@pw_cast_for_opmath +def binary_cross_entropy_backward( + grad_output: Tensor, + self: Tensor, + target: Tensor, + weight: Optional[Tensor] = None, + reduction: int = Reduction.MEAN.value, +) -> Tensor: + EPSILON = 1e-12 + result = grad_output * (self - target) / torch.clamp(self * (1 - self), min=EPSILON) + if weight is not None: + result = result * weight + if reduction == Reduction.MEAN.value: + result = result / self.numel() + return result + + +@register_decomposition(aten.soft_margin_loss) +@out_wrapper() +@pw_cast_for_opmath +def soft_margin_loss( + input: Tensor, + target: Tensor, + reduction: int = Reduction.MEAN.value, +) -> Tensor: + loss = torch.log1p(torch.exp(-input * target)) + return apply_loss_reduction(loss, reduction) + + +@register_decomposition(aten.soft_margin_loss_backward) +@out_wrapper("grad_input") +@pw_cast_for_opmath +def soft_margin_loss_backward( + grad_output: Tensor, + self: Tensor, + target: Tensor, + reduction: int = Reduction.MEAN.value, +) -> Tensor: + grad_input = target * grad_output * (torch.sigmoid(target * self) - 1) + if reduction == Reduction.MEAN.value: + grad_input = grad_input / self.numel() + return grad_input + + +@register_decomposition(aten.dist) +@out_wrapper() +def dist(input: Tensor, other: Tensor, p: float = 2): + return aten.norm(input - other, p=p) + + +@register_decomposition(aten._euclidean_dist) +@out_wrapper() +def _euclidean_dist(x1: Tensor, x2: Tensor) -> Tensor: + x1_norm = x1.pow(2).sum(-1, True) + x1_pad = torch.ones_like(x1_norm, 
memory_format=torch.contiguous_format) + x2_norm = x2.pow(2).sum(-1, True) + x2_pad = torch.ones_like(x2_norm, memory_format=torch.contiguous_format) + x1_ = torch.cat([x1.mul(-2), x1_norm, x1_pad], -1) + x2_ = torch.cat([x2, x2_pad, x2_norm], -1) + result = x1_.matmul(x2_.mT) + return result.clamp_min(0).sqrt() + + +@register_decomposition(aten.slice_backward) +@out_wrapper() +def slice_backward( + grad_output: Tensor, + input_sizes: list[int], + dim: int, + start: int, + end: int, + step: int, +): + grad_input = grad_output.new_zeros(input_sizes) + return torch.slice_scatter(grad_input, grad_output, dim, start, end, step) + + +@register_decomposition(aten.slice.Tensor) +def slice_forward( + # Tensor(a) self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1 + self: Tensor, + dim: int = 0, + start: Optional[int] = None, + end: Optional[int] = None, + step: int = 1, +): + from torch.fx.experimental.symbolic_shapes import statically_known_true + + ndim = self.dim() + if ndim == 0: + raise RuntimeError("slice() cannot be applied to a 0-dim tensor.") + dim = utils.canonicalize_dim(self.dim(), dim) + sizes = list(self.size()) + strides = list(self.stride()) + + if step <= 0: + raise RuntimeError("slice step must be positive") + + start_val = start if start is not None else 0 + end_val = end if end is not None else sys.maxsize # 2^63 - 1 + + if start_val < 0: + start_val += sizes[dim] + + if end_val < 0: + end_val += sizes[dim] + + if start_val < 0: + start_val = 0 + elif start_val > sizes[dim]: + start_val = sizes[dim] + + if statically_known_true(end_val == sys.maxsize): + end_val = sizes[dim] + elif end_val < start_val: + end_val = start_val + elif end_val > sizes[dim]: + end_val = sizes[dim] + + storage_offset = self.storage_offset() + start_val * strides[dim] + len = end_val - start_val + sizes[dim] = (len + step - 1) // step + strides[dim] *= step + + if self.is_quantized: + raise NotImplementedError( + "Slice decomposition for quantized tensors 
aren't implemented" + ) + else: + return self.as_strided(sizes, strides, storage_offset) + + +def _normalize_start_end( + x: Tensor, dim: int, start: Optional[int], end: Optional[int] +) -> tuple[int, int]: + """ + Normalize start and end such that both are in the range + [0, x.get_size()[dim]] and start <= end. + """ + dim_size = x.shape[dim] + + def clamp_wrap(val, lower, upper, default) -> int: + if val is None: + return default + if val < 0: + val = val + dim_size + return min(max(val, lower), upper) + + start = clamp_wrap(start, 0, dim_size, 0) + end = clamp_wrap(end, start, dim_size, dim_size) + return start, end + + +# This is not in torch._refs because aten.index used by +# aten._unsafe_masked_index does not have a decomposition. +@register_decomposition(aten.slice_scatter) +@out_wrapper() +def slice_scatter( + input: Tensor, + src: Tensor, + dim: int = 0, + start: Optional[int] = None, + end: Optional[int] = None, + step: int = 1, +): + dim = utils.canonicalize_dim(input.ndim, dim) + dim_size = input.shape[dim] + start, end = _normalize_start_end(input, dim, start, end) + + src_size = list(input.shape) + src_size[dim] = (end - start + (step - 1)) // step + src = src.expand(src_size) + + if start == 0 and end == dim_size and step == 1: + return src.clone() + + indices: list[Optional[Tensor]] = [None] * input.dim() + idx = torch.arange(dim_size, device=input.device) + indices[dim] = (idx - start) // step + + mask = torch.ones(dim_size, device=input.device, dtype=torch.bool) + if start != 0: + mask = torch.logical_and(mask, idx >= start) + + if end != dim_size: + mask = torch.logical_and(mask, idx < end) + + if step != 1: + mask = torch.logical_and(mask, (idx - start) % step == 0) + + mask_shape = [1] * input.dim() + mask_shape[dim] = -1 + mask = mask.view(mask_shape) + return aten.where(mask, aten._unsafe_masked_index(src, mask, indices, 0), input) + + +@register_decomposition(aten.select_backward) +@out_wrapper() +def select_backward(grad_output: Tensor, 
input_sizes: list[int], dim: int, index: int): + grad_input = grad_output.new_zeros(input_sizes) + return torch.select_scatter(grad_input, grad_output, dim, index) + + +@register_decomposition(aten.diagonal_backward) +@out_wrapper() +def diagonal_backward( + grad_output: Tensor, input_sizes: list[int], offset: int, dim1: int, dim2: int +): + grad_input = grad_output.new_zeros(input_sizes) + return torch.diagonal_scatter(grad_input, grad_output, offset, dim1, dim2) + + +def _cast_grad_to_input_dtype( + grad_output: Tensor, grad_input: Tensor, input_dtype: torch.dtype +): + if grad_output.dtype != input_dtype: + grad_input = grad_input.to(input_dtype) + return grad_input + + +@register_decomposition(aten._softmax_backward_data) +@out_wrapper("grad_input") +@compute_only_pw_cast_for_opmath +def _softmax_backward_data( + grad_output: Tensor, output: Tensor, dim: int, input_dtype: torch.dtype +): + new_grad_output = grad_output * output + grad_input = new_grad_output - output * torch.sum( + new_grad_output, dim=dim, keepdim=True + ) + + # CPU kernel doesn't respect input_dtype, but following check doesn't work for meta tensor + # if grad_output.device == torch.device("cpu"): + # return grad_input.contiguous() + + return _cast_grad_to_input_dtype(grad_output, grad_input, input_dtype).contiguous() + + +@register_decomposition(aten._log_softmax_backward_data) +@out_wrapper() +@compute_only_pw_cast_for_opmath +def _log_softmax_backward_data( + grad_output: Tensor, output: Tensor, dim: int, input_dtype: torch.dtype +): + grad_input = grad_output - torch.exp(output) * torch.sum( + grad_output, dim=dim, keepdim=True + ) + return _cast_grad_to_input_dtype(grad_output, grad_input, input_dtype) + + +def _im2col_col2im_indices_along_dim( + input_d, kernel_d, dilation_d, padding_d, stride_d, device +): + """Utility function to implement im2col and col2im""" + blocks_d = input_d + padding_d * 2 - dilation_d * (kernel_d - 1) + + arange_kw = partial(torch.arange, dtype=torch.int64, 
device=device) + + # Stride kernel over input and find starting indices along dim d + blocks_d_indices = arange_kw(0, blocks_d, stride_d).unsqueeze(0) + + # Apply dilation on kernel and find its indices along dim d + kernel_grid = arange_kw(0, kernel_d * dilation_d, dilation_d).unsqueeze(-1) + + # Broadcast and add kernel starting positions (indices) with + # kernel_grid along dim d, to get block indices along dim d + return blocks_d_indices + kernel_grid + + +@register_decomposition(aten.im2col) +@out_wrapper() +def im2col( + input: Tensor, + kernel_size: list[int], + dilation: list[int], + padding: list[int], + stride: list[int], +) -> Tensor: + torch._check(len(kernel_size) == 2, lambda: "im2col(): only 2D kernel supported") + torch._check(len(dilation) == 2, lambda: "im2col(): only 2D dilation supported") + torch._check(len(padding) == 2, lambda: "im2col(): only 2D padding supported") + torch._check(len(stride) == 2, lambda: "im2col(): only 2D stride supported") + + def check_positive(param, param_name, strict=True): + cond = all(p > 0 for p in param) if strict else all(p >= 0 for p in param) + torch._check( + cond, lambda: f"{param_name} should be greater than zero, but got {param}" + ) + + check_positive(kernel_size, "kernel_size") + check_positive(dilation, "dilation") + check_positive(dilation, "padding", strict=False) + check_positive(stride, "stride") + + shape = input.shape + ndim = len(shape) + torch._check( + ndim in (3, 4) and all(d != 0 for d in shape[-3:]), + lambda: "Expected 3D or 4D (batch mode) tensor for input with possible 0 batch size " + f"and non-zero dimensions, but got: {tuple(shape)}", + ) + output_size = tuple( + 1 + (out + 2 * pad - dil * (ker - 1) - 1) // st + for out, pad, dil, ker, st in zip( + shape[-2:], padding, dilation, kernel_size, stride + ) + ) + torch._check( + all(c > 0 for c in output_size), + lambda: f"Given an input with spatial size {tuple(shape[-2:])}, " + f"kernel_size={kernel_size}, dilation={dilation}, " + 
f"padding={padding}, stride={stride}, " + "the calculated shape of the array of sliding blocks " + f"is {output_size}, but its components must be at least one.", + ) + batched_input = ndim == 4 + if not batched_input: + input = input.unsqueeze(0) + + batch_dim, channel_dim, input_h, input_w = input.shape + + stride_h, stride_w = stride + padding_h, padding_w = padding + dilation_h, dilation_w = dilation + kernel_h, kernel_w = kernel_size + + blocks_row_indices = _im2col_col2im_indices_along_dim( + input_h, kernel_h, dilation_h, padding_h, stride_h, input.device + ) + blocks_col_indices = _im2col_col2im_indices_along_dim( + input_w, kernel_w, dilation_w, padding_w, stride_w, input.device + ) + + # Note that F.pad takes (padding_left, padding_right, padding_top, padding_bottom) + # ugh + padded_input = F.pad(input, (padding_w, padding_w, padding_h, padding_h)) + + blocks_row_indices = blocks_row_indices.unsqueeze(-1).unsqueeze(-1) + output = padded_input[:, :, blocks_row_indices, blocks_col_indices] + output = output.permute(0, 1, 2, 4, 3, 5) + num_blocks_row = blocks_row_indices.size(1) + num_blocks_col = blocks_col_indices.size(1) + output = output.reshape( + batch_dim, channel_dim * kernel_h * kernel_w, num_blocks_row * num_blocks_col + ) + + if not batched_input: + output = output.squeeze(0) + return output + + +@register_decomposition(aten.col2im) +@out_wrapper() +@pw_cast_for_opmath +def col2im( + input: Tensor, + output_size: list[int], + kernel_size: list[int], + dilation: list[int], + padding: list[int], + stride: list[int], +) -> Tensor: + torch._check(len(output_size) == 2, lambda: "only 2D output_size supported") + torch._check(len(kernel_size) == 2, lambda: "only 2D kernel supported") + torch._check(len(dilation) == 2, lambda: "only 2D dilation supported") + torch._check(len(padding) == 2, lambda: "only 2D padding supported") + torch._check(len(stride) == 2, lambda: "only 2D stride supported") + + def check_positive(param, param_name, strict=True): + 
cond = all(p > 0 for p in param) if strict else all(p >= 0 for p in param) + torch._check( + cond, lambda: f"{param_name} should be greater than zero, but got {param}" + ) + + check_positive(kernel_size, "kernel_size") + check_positive(dilation, "dilation") + check_positive(padding, "padding", strict=False) + check_positive(stride, "stride") + check_positive(output_size, "output_size") + + shape = input.shape + ndim = len(shape) + torch._check( + ndim in (2, 3) and all(d != 0 for d in shape[-2:]), + lambda: "Expected 2D or 3D (batch mode) tensor for input with possible 0 batch size " + f"and non-zero dimensions, but got: {tuple(shape)}", + ) + prod_kernel_size = kernel_size[0] * kernel_size[1] + torch._check( + shape[-2] % prod_kernel_size == 0, + lambda: "Expected size of input's first non-batch dimension to be divisible by the " + f"product of kernel_size, but got input.shape[-2] = {shape[-2]} and " + f"kernel_size={kernel_size}", + ) + col = [ + 1 + (out + 2 * pad - dil * (ker - 1) - 1) // st + for out, pad, dil, ker, st in zip( + output_size, padding, dilation, kernel_size, stride + ) + ] + L = col[0] * col[1] + torch._check( + shape[-1] == L, + lambda: f"Given output_size={output_size}, kernel_size={kernel_size}, " + f"dilation={dilation}, padding={padding}, stride={stride}, " + f"expected input.size(-1) to be {L} but got {shape[-1]}.", + ) + torch._check( + L > 0, + lambda: f"Given output_size={output_size}, kernel_size={kernel_size}, " + f"dilation={dilation}, padding={padding}, stride={stride}, " + f"expected input.size(-1) to be {L} but got {shape[-1]}.", + ) + batched_input = ndim == 3 + if not batched_input: + input = input.unsqueeze(0) + + shape = input.shape + + out_h, out_w = output_size + stride_h, stride_w = stride + padding_h, padding_w = padding + dilation_h, dilation_w = dilation + kernel_h, kernel_w = kernel_size + + # col2im is defined as the backwards of im2col, so we differentiate its decomposition by hand + input = input.reshape([shape[0], 
shape[1] // prod_kernel_size] + kernel_size + col) + input = input.permute(0, 1, 2, 4, 3, 5) + + indices_row = _im2col_col2im_indices_along_dim( + out_h, kernel_h, dilation_h, padding_h, stride_h, input.device + ) + indices_row = _unsqueeze_to_dim(indices_row, 4) + indices_col = _im2col_col2im_indices_along_dim( + out_w, kernel_w, dilation_w, padding_w, stride_w, input.device + ) + + output_padded_size = [o + 2 * p for o, p in zip(output_size, padding)] + output = input.new_zeros( + [shape[0], shape[1] // prod(kernel_size)] + output_padded_size + ) + idx = (None, None, indices_row, indices_col) + output = aten._unsafe_index_put(output, idx, input, accumulate=True) + output = F.pad(output, (-padding_w, -padding_w, -padding_h, -padding_h)) + + if not batched_input: + output = output.squeeze(0) + return output + + +@register_decomposition(aten.native_dropout_backward) +@out_wrapper() +def native_dropout_backward(grad_output: Tensor, mask: Tensor, scale: float): + # According to the CUDA kernel implementation we should have this test; + # but it seems to fail tests! + # torch._check(mask.dtype == torch.bool, lambda: f"Mask should be Bool Scalar Type {mask.dtype}") + + # Mimicking CUDA kernel's behavior for output stride: output follow input's memory format + # This different from TensorIterator's behavior + r = (grad_output * (mask.type_as(grad_output) * scale)).clone( + memory_format=utils.suggest_memory_format(grad_output) + ) + return r + + +@register_decomposition(aten.unfold_backward) +@out_wrapper() +def unfold_backward( + grad: Tensor, input_size: list[int], dimension: int, size: int, step: int +) -> Tensor: + if len(input_size) == 0: + return torch.squeeze_copy(grad, 0) + dim = utils.canonicalize_dim(len(input_size), dimension) + idx = torch.arange(input_size[dim], device=grad.device, dtype=torch.int32) + idx = idx.unfold(0, size, step).flatten() + grad = grad.movedim(-1, dim + 1).flatten(dim, dim + 1) + # nb. 
At the moment this generates two kernels in triton + # It could potentially be fused into one call to scatter_reduce, + # in the case step <= size provided scatter_reduce generates 1 kernel + grad_input = grad.new_zeros(input_size) + index = (None,) * dim + (idx,) + return aten._unsafe_index_put(grad_input, index, grad, accumulate=True).contiguous() + + +@register_decomposition(aten.logit_backward.default) +@pw_cast_for_opmath +def logit_backward( + grad_output: Tensor, self: Tensor, eps: Optional[float] = None +) -> Tensor: + if eps is not None: + lo = eps + hi = 1.0 - lo + return torch.where( + torch.logical_and(self >= lo, self <= hi), + grad_output / (self * (1.0 - self)), + 0.0, + ) + else: + return torch.where( + torch.logical_and(self >= 0.0, self <= 1.0), + grad_output / (self * (1.0 - self)), + self.new_full((), float("nan")), + ) + + +@register_decomposition(aten.dropout) +@aten.dropout.default.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten.dropout.default.py_impl(DispatchKey.Autograd) +def dropout(input: Tensor, p: float, train: Optional[bool]): + if train and p != 0: + return aten.native_dropout(input, p, train)[0] + else: + return input.clone() + + +@register_decomposition(aten.native_dropout) +@out_wrapper("out0", "out1") +def native_dropout(input: Tensor, p: float, train: Optional[bool]): + if train and p != 0: + if p == 1: + return (torch.zeros_like(input), torch.zeros_like(input, dtype=torch.bool)) + if not input.dtype.is_floating_point: + raise RuntimeError( + "result type Float can't be cast to the desired output type Long" + ) + bool_mask = torch.rand_like(input) > p + res = bool_mask * input * float(1.0 / (1.0 - p)) + return (res, bool_mask) + else: + return (input, torch.ones_like(input, dtype=torch.bool)) + + +@register_decomposition(aten._softmax) +@out_wrapper() +def _softmax(x: Tensor, dim: int, half_to_float: bool): + from torch.fx.experimental.symbolic_shapes import guard_or_false + + # eager softmax returns a contiguous tensor. 
Ensure that decomp also returns + # a contiguous tensor. + x = x.contiguous() + if half_to_float: + assert x.dtype == torch.half + computation_dtype, result_dtype = utils.elementwise_dtypes( + x, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ) + x = x.to(computation_dtype) + if guard_or_false(x.numel() == 0): + unnormalized = torch.exp(x) + else: + x_max = torch.amax(x, dim, keepdim=True) + unnormalized = torch.exp(x - x_max) + result = unnormalized / torch.sum(unnormalized, dim, keepdim=True) + if not half_to_float: + result = result.to(result_dtype) + return result + + +@register_decomposition(aten._log_softmax) +@out_wrapper(exact_dtype=True) +def _log_softmax(x: Tensor, dim: int, half_to_float: bool): + from torch.fx.experimental.symbolic_shapes import guard_or_false + + # eager log_softmax returns a contiguous tensor. Ensure that decomp also + # returns a contiguous tensor. + x = x.contiguous() + if half_to_float: + assert x.dtype == torch.half + computation_dtype, result_dtype = utils.elementwise_dtypes( + x, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ) + x = x.to(computation_dtype) + if guard_or_false(x.numel() == 0): + shifted = x + else: + x_max = torch.amax(x, dim, keepdim=True) + shifted = x - x_max + shifted_logsumexp = torch.log(torch.sum(torch.exp(shifted), dim, keepdim=True)) + result = shifted - shifted_logsumexp + if not half_to_float: + result = result.to(result_dtype) + return result + + +@register_decomposition(aten.embedding) +@out_wrapper() +def embedding( + weight: Tensor, + indices: Tensor, + padding_idx: int = -1, + scale_grad_by_freq: bool = False, + sparse: bool = False, +) -> Tensor: + assert weight.dim() == 2, "'weight' must be 2-D" + # Nb. 
scale_grad_by_freq is not used in the forward + if indices.ndim <= 1: + # We need this one as weight[indices] calls item() in these cases + out = weight.index_select(0, indices) + if indices.ndim == 0: + out = out.squeeze(0) + return out + else: + return weight[indices] + + +@register_decomposition(aten.embedding_dense_backward) +@out_wrapper() +def embedding_dense_backward( + grad_output: Tensor, + indices: Tensor, + num_weights: int, + padding_idx: int, + scale_grad_by_freq: bool, +): + computation_dtype, result_dtype = utils.elementwise_dtypes( + grad_output, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ) + grad_output = grad_output.to(computation_dtype) + indices = _maybe_convert_to_dtype(indices, torch.long) # type: ignore[assignment] + if scale_grad_by_freq: + counts = indices.new_zeros((num_weights,)) + ones = torch.ones_like(indices) + counts = aten._unsafe_index_put(counts, [indices], ones, accumulate=True) + grad_weights_scale = counts[indices] + grad_output = grad_output / grad_weights_scale.unsqueeze(-1) + + mask = _unsqueeze_to_dim(indices == padding_idx, grad_output.ndim) + grad = grad_output.masked_fill(mask, 0) + grad_weight = grad_output.new_zeros( + (num_weights,) + grad_output.shape[indices.ndim :] + ) + return aten._unsafe_index_put(grad_weight, [indices], grad, accumulate=True).to( + result_dtype + ) + + +def prod(x: list[int]): + r = 1 + for i in x: + r *= i + return r + + +def _pad_chunk( + tensors: list[Tensor], + dim: int, + num_chunks: int, +) -> list[Tensor]: + padded_tensors = [] + for tensor in tensors: + tensor_size = tensor.size() + pad_along_dim = (tensor_size[dim] + num_chunks - 1) // num_chunks * num_chunks + if pad_along_dim != tensor_size[dim]: + # Use aten.constant_pad_nd instead of copy_ for functionalization + pad = [0] * 2 * (tensor.ndim - dim - 1) + [ + 0, + pad_along_dim - tensor_size[dim], + ] + tensor = aten.constant_pad_nd(tensor, pad, 0) + view_size = tensor_size[:dim] + torch.Size([num_chunks, 
-1]) + padded_tensors.append(tensor.reshape(view_size)) + return padded_tensors + + +def have_same_ndims(tensors: list[Tensor]): + ndim = tensors[0].ndim + for tensor in tensors: + if tensor.ndim != ndim: + return False + return True + + +def leading_dimension_matches(tensors: list[Tensor], dim: int): + leading_dim_sizes = tensors[0].size()[:dim] + for tensor in tensors: + torch._check( + tensor.size()[:dim] == leading_dim_sizes, + lambda: "_chunk_cat expects same sizes of 0,...,dim-1 dimensions for all tensors", + ) + + +def _preprocess_chunk_cat_inputs( + tensors: list[Tensor], + dim: int, + num_chunks: int, +): + torch._check(num_chunks >= 1, lambda: "_chunk_cat expects positive num_chunks") + torch._check( + len(tensors) > 0, lambda: "_chunk_cat expects a non-empty input tensor list" + ) + expected_dtype = tensors[0].dtype + expected_device = tensors[0].device + for tensor in tensors: + torch._check(tensor.numel() > 0, lambda: "_chunk_cat expects non-empty tensor") + torch._check( + tensor.dtype == expected_dtype, + lambda: "_chunk_cat expects all input tensors with the same dtype", + ) + torch._check( + tensor.device == expected_device, + lambda: "_chunk_cat expects all inputs tensors on the same device", + ) + if have_same_ndims(tensors): + dim = utils.canonicalize_dim(tensors[0].dim(), dim) + else: + torch._check( + dim >= 0, + lambda: "_chunk_cat expects non-negative dim when input tensors have different ndims", + ) + for tensor in tensors: + torch._check( + dim < tensor.ndim, + lambda: "_chunk_cat expects dim < ndim for all input tensors", + ) + leading_dimension_matches(tensors, dim) + return dim + + +@register_decomposition([aten._chunk_cat.default, aten._chunk_cat.out]) +def _chunk_cat( + tensors: list[Tensor], + dim: int, + num_chunks: int, + out: Optional[Tensor] = None, +) -> Tensor: + dim = _preprocess_chunk_cat_inputs(tensors, dim, num_chunks) + padded_tensors = _pad_chunk(tensors, dim, num_chunks) + if out is None: + return 
torch.cat(padded_tensors, dim + 1) + else: + torch.cat(padded_tensors, dim + 1, out=out) + return out + + +# out_wrapper currently does not allow optional outputs +@register_decomposition( + [aten.split_with_sizes_copy.default, aten.split_with_sizes_copy.out] +) +def split_with_sizes_copy( + self: Tensor, + split_sizes: list[int], + dim: int = 0, + out: Optional[list[Tensor]] = None, +) -> Optional[list[Tensor]]: + splits = aten.split_with_sizes(self, split_sizes, dim=dim) + if out is None: + return [s.clone(memory_format=torch.contiguous_format) for s in splits] + else: + for output, split in zip(out, splits): + _maybe_resize_out(output, split.shape) + _safe_copy_out(copy_from=split, copy_to=output, exact_dtype=True) + return None + + +@register_decomposition(aten.unsafe_split.Tensor) +def unsafe_split(input: Tensor, split_size: int, dim: int = 0) -> tuple[Tensor, ...]: + return aten.split.Tensor(input, split_size, dim) + + +@register_decomposition(aten.unsafe_split_with_sizes.default) +def unsafe_split_with_sizes( + input: Tensor, split_sizes: list[int], dim: int = 0 +) -> tuple[Tensor, ...]: + return aten.split_with_sizes.default(input, split_sizes, dim) + + +@register_decomposition(aten.split.Tensor) +def split(self: Tensor, split_size: int, dim: int = 0) -> tuple[Tensor, ...]: + input_sizes = self.shape + dim_size = input_sizes[dim] + if split_size == 0: + assert dim_size == 0 + return (self.detach(),) + chunks = (dim_size + split_size - 1) // split_size + + # Avoid importing sympy at a module level + from torch.fx.experimental.symbolic_shapes import guard_int + + chunks = guard_int(chunks) + split_sizes = [split_size for i in range(chunks)] + split_sizes[-1] = split_size - (split_size * chunks - dim_size) + return torch.split(self, split_sizes, dim) + + +@aten.tensor_split.tensor_indices_or_sections.py_impl( + DispatchKey.CompositeImplicitAutograd +) +def tensor_split_tensor_indices_or_sections_py_impl( + self: Tensor, + tensor_indices_or_sections: Tensor, + 
dim: int = 0, +) -> tuple[Tensor, ...]: + assert tensor_indices_or_sections.device.type == "cpu" + assert tensor_indices_or_sections.dtype == torch.int64 + split_dim = tensor_indices_or_sections.dim() + torch._check( + split_dim == 1 or split_dim == 0, + lambda: "tensor_split expected tensor_indices_or_sections to be a zero-dimensional " + f"or one-dimensional tensor, but got a tensor with {split_dim} dims", + ) + if split_dim == 0: + sections = tensor_indices_or_sections.item() + assert isinstance(sections, IntLike) + return self.tensor_split(sections, dim) + else: + ctx = nullcontext + if (fake_mode := torch._guards.detect_fake_mode()) and ( + shape_env := fake_mode.shape_env + ): + ctx = shape_env.ignore_fresh_unbacked_symbols # type: ignore[assignment] + # In fake tensor prop, we end up calling slice() with these unbacked indices. + # Because slice has flexible semantics, the unbacked handling generates new output sizes + # for each slice, effectively clobbering over these index symbols. + # To avoid PendingUnbackedSymbolNotFound errors, we tell the compiler it's fine to not bind these. + with ctx(): + indices = [i.item() for i in tensor_indices_or_sections] + # WARNING: Tempted to torch._check(x>0) on the indices here? You + # can't: tensor_split works with negative values in indices: + # + # >>> torch.tensor_split(torch.randn(10), torch.tensor([-5, 5])) + # (tensor([ 0.3540, 2.1074, -0.8507, 1.1639, 0.3055]), tensor([]), + # tensor([-0.4285, 1.0692, -0.1776, 0.9362, 1.6143])) + # + # Sorry, I don't make the rules. Explicitly do the item call in user + # code if you KNOW that they are non-negative. 
+ return self.tensor_split(indices, dim) + + +# TODO: this doesn't appear to have enough precision in bfloat16 +@register_decomposition(aten.addmm) +@out_wrapper(exact_dtype=True) +@pw_cast_for_opmath +def addmm(self: Tensor, mat1: Tensor, mat2: Tensor, beta: int = 1, alpha: int = 1): + if not self.is_floating_point() and not self.is_complex(): + beta = int(beta) + alpha = int(alpha) + out = alpha * torch.mm(mat1, mat2) + if beta == 0: + return out + + # The output of aten.addmm is contiguous, we need to match this behavior in the decomposition. + # The original implementation 'beta * self + out' would return a strided tensor if `self` is strided. + # We thus use `out`, the output of torch.mm, which is always contiguous, as the first argument for addition. + # This is relying on TensorIterator's behavior that it takes higher precedence on the stride of first input. + # Alternative, we can write `(beta * self + out).contiguous()`, but it introduces another copy in some cases. + # This implementation is not ideal, and we should revisit this when we have a better solution. 
+ return out + beta * self + + +@register_decomposition(aten._addmm_activation) +@out_wrapper() +@pw_cast_for_opmath +def _addmm_activation( + self: Tensor, + mat1: Tensor, + mat2: Tensor, + beta: int = 1, + alpha: int = 1, + use_gelu: bool = False, +): + out = addmm(self, mat1, mat2, beta, alpha) + if use_gelu: + if self.is_cuda: + return aten.gelu(out, approximate="tanh") + else: + return aten.gelu(out) + return aten.relu(out) + + +@register_decomposition(aten.addmv) +@out_wrapper(exact_dtype=True) +@pw_cast_for_opmath +def addmv(self: Tensor, mat1: Tensor, vec: Tensor, beta: int = 1, alpha: int = 1): + if not self.is_floating_point() and not self.is_complex(): + beta = int(beta) + alpha = int(alpha) + out = alpha * torch.mv(mat1, vec) + if beta == 0: + return out + if out.numel() == 0: # handle empty matrix + return beta * self + return out + beta * self + + +@register_decomposition(aten.native_group_norm_backward.default) +@pw_cast_for_opmath +def native_group_norm_backward( + grad_output: Tensor, + input: Tensor, + mean: Tensor, + rstd: Tensor, + gamma: Optional[Tensor], + N: int, + C: int, + HxW: int, + group: int, + output_mask: list[bool], +) -> tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]: + utils.check_same_device( + grad_output, input, mean, rstd, allow_cpu_scalar_tensors=False + ) + utils.check_same_shape(input, grad_output, allow_cpu_scalar_tensors=False) + utils.check_same_shape(mean, rstd, allow_cpu_scalar_tensors=False) + torch._check( + input.numel() == N * C * HxW, + lambda: f"Expect input to have {N * C * HxW} elements", + ) + torch._check( + mean.shape == (N, group), + lambda: f"Expect mean to have shape ({N}, {group}, but got {mean.shape}", + ) + torch._check( + gamma is None or gamma.numel() == C, + lambda: f"Expect gamma to have {C} elements but got {gamma.numel() if gamma is not None else -1}", + ) + + cpg = C // group + torch._check( + C == cpg * group, + lambda: f"Expect number of channels {C} to be evenly-divisible by 
number of groups {group}", + ) + + # Compute Internal gradients + ds = torch.mul(grad_output, input).view(N, C, HxW).sum(dim=[2]) + db = grad_output.view(N, C, HxW).sum(dim=[2]) + + d_input: Optional[Tensor] = None + d_gamma: Optional[Tensor] = None + d_bias: Optional[Tensor] = None + if output_mask[0]: + s = 1.0 / (HxW * cpg) + if gamma is not None: + ds_val = torch.mul(ds, gamma.unsqueeze(0)).reshape(N, group, cpg).sum(2) + db_val = torch.mul(db, gamma.unsqueeze(0)).reshape(N, group, cpg).sum(2) + c1 = torch.mul( + rstd.unsqueeze(-1), + gamma.reshape(1, group, cpg), + ) + else: + ds_val = ds.reshape(N, group, cpg).sum(2) + db_val = db.reshape(N, group, cpg).sum(2) + c1 = torch.mul( + rstd.unsqueeze(-1), + torch.ones((1, group, cpg), device=rstd.device), + ) + c2 = (db_val * mean - ds_val) * rstd * rstd * rstd * s + c3 = -c2 * mean - db_val * rstd * s + + c1 = c1.unsqueeze(-1) + c2 = _unsqueeze_to_dim(c2, 4) + c3 = _unsqueeze_to_dim(c3, 4) + d_input = ( + torch.mul(grad_output.reshape(N, group, cpg, HxW), c1) + + torch.mul(input.reshape(N, group, cpg, HxW), c2) + + c3 + ) + d_input = d_input.reshape(input.shape).to(input.dtype) + if output_mask[1]: + d_gamma = ( + ( + (ds.view(N, group, cpg) - db.view(N, group, cpg) * mean.unsqueeze(-1)) + * rstd.unsqueeze(-1) + ) + .sum(dim=[0]) + .reshape(C) + ) + if output_mask[2]: + d_bias = db.sum(dim=[0]) + + return (d_input, d_gamma, d_bias) + + +# out_wrapper currently does not allow optional outputs +@register_decomposition(aten.native_group_norm_backward.out) +def native_group_norm_backward_out( + grad_output: Tensor, + input: Tensor, + mean: Tensor, + rstd: Tensor, + gamma: Optional[Tensor], + N: int, + C: int, + HxW: int, + group: int, + output_mask: list[bool], + *, + out0: torch.Tensor, + out1: torch.Tensor, + out2: torch.Tensor, +) -> tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]: + result = native_group_norm_backward( + grad_output, input, mean, rstd, gamma, N, C, HxW, group, output_mask + ) + 
grad_input = (out0, out1, out2) + for i, r in enumerate(result): + if r is not None: + _maybe_resize_out(grad_input[i], r.shape) + _safe_copy_out(copy_from=r, copy_to=grad_input[i], exact_dtype=True) + + return grad_input + + +def _maybe_cast(x: Optional[Tensor], dtype) -> Optional[Tensor]: + if x is not None: + return x.to(dtype) + return x + + +# TODO: Take a closer look at the type promotion semantics +@register_decomposition(aten.native_layer_norm_backward.default) +def native_layer_norm_backward( + grad_out: Tensor, + input: Tensor, + normalized_shape: list[int], + mean: Tensor, + rstd: Tensor, + weight: Optional[Tensor], + bias: Optional[Tensor], + output_mask: list[bool], +) -> tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]: + input_shape = input.shape + input_ndim = input.dim() + computation_dtype = utils.get_computation_dtype(input.dtype) + grad_out_cast, input_cast, weight_cast, bias_cast = ( + x.to(computation_dtype, memory_format=torch.contiguous_format) + if x is not None + else x + for x in (grad_out, input, weight, bias) + ) + assert grad_out_cast is not None + + axis = input_ndim - len(normalized_shape) + inner_dims = input_shape[axis:] + outer_dims = input_shape[:axis] + inner_dim_indices: list[int] = [] + outer_dim_indices: list[int] = [] + for i in range(input_ndim): + if i >= axis: + inner_dim_indices.append(i) + else: + outer_dim_indices.append(i) + + N = prod(inner_dims) # type: ignore[arg-type] + M = prod(outer_dims) # type: ignore[arg-type] + from torch.fx.experimental.symbolic_shapes import statically_known_true + + if statically_known_true(M == 0) or statically_known_true(N == 0): + return ( + input.new_zeros(input_shape) if output_mask[0] else None, + input.new_zeros(input_shape[axis:]) if output_mask[1] else None, + input.new_zeros(input_shape[axis:]) if output_mask[2] else None, + ) + mean = _unsqueeze_to_dim(mean, input_cast.dim()) # type: ignore[union-attr] + rstd = _unsqueeze_to_dim(rstd, input_cast.dim()) # type: 
ignore[union-attr] + assert input_cast is not None + x_hat = (input_cast - mean) * rstd + if weight_cast is not None: + grad_x_hat = grad_out_cast * weight_cast + else: + grad_x_hat = grad_out_cast + a = grad_x_hat * N + b = torch.sum(grad_x_hat, inner_dim_indices, True) + c1 = torch.mul(grad_x_hat, x_hat) + c2 = torch.sum(c1, inner_dim_indices, True) + c3 = torch.mul(x_hat, c2) + + inner = a - b - c3 + d_input: Optional[Tensor] = None + d_weight: Optional[Tensor] = None + d_bias: Optional[Tensor] = None + if output_mask[0]: + d_input = (rstd / N) * inner + + if output_mask[1] and weight_cast is not None: + if len(outer_dim_indices) > 0: + d_weight = torch.sum(grad_out_cast * x_hat, outer_dim_indices, False) + else: + d_weight = grad_out_cast * x_hat + + if output_mask[2] and bias_cast is not None: + if len(outer_dim_indices) > 0: + d_bias = torch.sum(grad_out_cast, outer_dim_indices, False) + else: + d_bias = grad_out_cast.clone() + + return ( + _maybe_cast(d_input, input.dtype), + _maybe_cast(d_weight, weight.dtype if weight is not None else None), + _maybe_cast(d_bias, bias.dtype if bias is not None else None), + ) + + +# out_wrapper currently does not allow optional outputs +@register_decomposition(aten.native_layer_norm_backward.out) +def native_layer_norm_backward_out( + grad_out: Tensor, + input: Tensor, + normalized_shape: list[int], + mean: Tensor, + rstd: Tensor, + weight: Optional[Tensor], + bias: Optional[Tensor], + output_mask: list[bool], + *, + out0: torch.Tensor, + out1: torch.Tensor, + out2: torch.Tensor, +) -> tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]: + result = native_layer_norm_backward( + grad_out, input, normalized_shape, mean, rstd, weight, bias, output_mask + ) + grad_input = (out0, out1, out2) + for i, r in enumerate(result): + if r is not None: + _maybe_resize_out(grad_input[i], r.shape) + _safe_copy_out(copy_from=r, copy_to=grad_input[i], exact_dtype=True) + + return grad_input + + 
+@register_decomposition(aten._fused_rms_norm.default) +def _fused_rms_norm( + input: Tensor, + normalized_shape: list[int], + weight: Optional[Tensor], + eps: Optional[float], +) -> tuple[Tensor, Tensor]: + dims_to_reduce: list[int] = [] + for i in range(len(normalized_shape)): + dims_to_reduce.append(input.dim() - i - 1) + + # upcast is needed for fp16 and bf16 + computation_dtype = utils.get_computation_dtype(input.dtype) + upcasted_input = input.to(computation_dtype) + + # computation_dtype would be one of [Double, Float, ComplexFloat, ComplexDouble] + if eps is None: + if computation_dtype in (torch.float32, torch.complex64): + eps_val = torch.finfo(torch.float32).eps + else: + eps_val = torch.finfo(torch.float64).eps + else: + eps_val = eps + + rqrst_input = torch.rsqrt( + # NB: don't inplace here, will violate functional IR invariant + # NB: carefully use the Scalar overload of add to ensure compatibility with the C++ decomp + torch.ops.aten.add.Scalar( + torch.pow(upcasted_input, 2).mean(dim=dims_to_reduce, keepdim=True), eps_val + ) + ) + + upcasted_result = upcasted_input.mul(rqrst_input) + + if weight is not None: + upcasted_result = upcasted_result.mul(weight) + + # NB: nested should be dead here, just here for fidelity + is_nested = input.is_nested or (weight is not None and weight.is_nested) + memory_format = utils.suggest_memory_format(input) + is_channels_last = memory_format in ( + torch.channels_last, + torch.channels_last_3d, + ) + + if not is_nested and not is_channels_last: + upcasted_result = upcasted_result.contiguous() + rqrst_input = rqrst_input.contiguous() + + # Cast normalized result back to original input type + result = upcasted_result.type_as(input) + + return result, rqrst_input + + +@register_decomposition(aten._fused_rms_norm_backward.default) +def _fused_rms_norm_backward( + grad_out: Tensor, + input: Tensor, + normalized_shape: list[int], + rstd: Tensor, + weight: Optional[Tensor], + output_mask: list[bool], +) -> 
tuple[Optional[Tensor], Optional[Tensor]]: + input_shape = input.shape + input_ndim = input.dim() + computation_dtype = utils.get_computation_dtype(input.dtype) + + grad_out_cast = grad_out.to( + computation_dtype, memory_format=torch.contiguous_format + ) + input_cast = input.to(computation_dtype, memory_format=torch.contiguous_format) + weight_cast = ( + weight.to(computation_dtype, memory_format=torch.contiguous_format) + if weight is not None + else None + ) + assert grad_out_cast is not None + + axis = input_ndim - len(normalized_shape) + inner_dims = input_shape[axis:] + outer_dims = input_shape[:axis] + inner_dim_indices: list[int] = [] + outer_dim_indices: list[int] = [] + for i in range(input_ndim): + if i >= axis: + inner_dim_indices.append(i) + else: + outer_dim_indices.append(i) + + N = prod(inner_dims) # type: ignore[arg-type] + M = prod(outer_dims) # type: ignore[arg-type] + from torch.fx.experimental.symbolic_shapes import guard_or_false + + if guard_or_false(M == 0) or guard_or_false(N == 0): + return ( + input.new_zeros(input_shape) if output_mask[0] else None, + input.new_zeros(input_shape[axis:]) if output_mask[1] else None, + ) + + rstd = _unsqueeze_to_dim(rstd, input_cast.dim()) # type: ignore[union-attr] + if weight_cast is not None: + grad_x_hat = grad_out_cast * weight_cast + else: + grad_x_hat = grad_out_cast + + d_input: Optional[Tensor] = None + d_weight: Optional[Tensor] = None + + x_hat = input_cast * rstd + + if output_mask[0]: + sum_val = torch.sum(x_hat * grad_x_hat, dim=inner_dim_indices, keepdim=True) + d_input = (grad_x_hat - (x_hat / N) * sum_val) * rstd + + if output_mask[1] and weight_cast is not None: + d_weight_full_shape = grad_out_cast * x_hat + if len(outer_dim_indices) > 0: + d_weight = torch.sum( + d_weight_full_shape, dim=outer_dim_indices, keepdim=False + ) + else: + d_weight = d_weight_full_shape + + return ( + _maybe_cast(d_input, input.dtype), + _maybe_cast(d_weight, input.dtype), + ) + + +def 
native_batch_norm_helper( + input: Tensor, + weight: Optional[Tensor], + bias: Optional[Tensor], + running_mean: Optional[Tensor], + running_var: Optional[Tensor], + training: bool, + momentum: float, + eps: float, + functional: bool, +) -> tuple[Tensor, Tensor, Tensor, Optional[Tensor], Optional[Tensor]]: + reduction_dims = [0] + list(range(2, input.dim())) + computation_dtype = utils.get_computation_dtype(input.dtype) + new_running_mean = running_mean + new_running_var = running_var + if training: + computation_dtype = utils.get_computation_dtype(input.dtype) + input_acc = input.to(dtype=computation_dtype) + biased_var, mean = torch.var_mean( + input_acc, dim=reduction_dims, correction=0, keepdim=True + ) + rstd = torch.rsqrt(biased_var + eps) + + output = (input - mean) * rstd + + save_mean = torch.squeeze(mean, reduction_dims) + save_rstd = torch.squeeze(rstd, reduction_dims) + if running_mean is not None: + new_running_mean = momentum * save_mean + (1 - momentum) * running_mean + if not functional: + running_mean.copy_(new_running_mean) + if running_var is not None: + n = input.numel() / input.shape[1] + # This doesn't strictly match eager's numerics, which accumulates var sum and then directly applies the correction + # But... that would require re-implementing var here, for negligible numerics gain on a tensor whose + # numerics probably don't matter. 
+ squeezed_var = torch.squeeze(biased_var, reduction_dims) + unbiased_var = squeezed_var * (n / (n - 1)) + new_running_var = momentum * unbiased_var + (1 - momentum) * running_var + if not functional: + running_var.copy_(new_running_var) + else: + assert running_mean is not None and running_var is not None + running_mean = running_mean.to(dtype=computation_dtype, copy=True) + new_running_mean = running_mean + running_var = running_var.to(dtype=computation_dtype, copy=True) + new_running_var = running_var + mean = running_mean + invstd = 1 / (torch.sqrt(running_var + eps)) + # Very annoying inconsistency where CPU and CUDA give different shapes + if input.device.type != "cpu": + save_mean = running_mean + save_rstd = invstd + else: + save_mean = input.new_zeros((0,)) + save_rstd = input.new_zeros((0,)) + mean = _unsqueeze_to_dim(mean, input.dim() - 1) + invstd = _unsqueeze_to_dim(invstd, input.dim() - 1) + output = (input - mean) * invstd + + if weight is not None: + weight = weight.flatten() + weight = _unsqueeze_to_dim(weight, input.dim() - 1) + output = output * weight + + if bias is not None: + bias = bias.flatten() + bias = _unsqueeze_to_dim(bias, input.dim() - 1) + output = output + bias + + if input.device.type == "cpu": + save_mean = save_mean.to(dtype=input.dtype) + save_rstd = save_rstd.to(dtype=input.dtype) + return ( + output.to(dtype=input.dtype), + save_mean, + save_rstd, + new_running_mean, + new_running_var, + ) + + +@register_decomposition(aten.native_batch_norm) +@out_wrapper("out", "save_mean", "save_invstd") +def native_batch_norm( + input: Tensor, + weight: Optional[Tensor], + bias: Optional[Tensor], + running_mean: Optional[Tensor], + running_var: Optional[Tensor], + training: bool, + momentum: float, + eps: float, +) -> tuple[Tensor, Tensor, Tensor]: + output, save_mean, save_rstd, _, _ = native_batch_norm_helper( + input, weight, bias, running_mean, running_var, training, momentum, eps, False + ) + return output, save_mean, save_rstd + + +# 
TODO: this decomposition is NOT here to stay. We would much prefer replacing native_batch_norm +# with our new correctly schema'd _native_batch_norm_legit and its variants, but +# we cannot do that immediately in the C++ because it would be forwards incompatible +# with some mobile use cases. +# +# Since this change is most impactful for aot autograd/functionalization, we simply +# register this decomposition on the Autograd key for the python dispatcher (which is +# currently only used by aot autograd/functionalization and no one else, really). +# In two weeks or so, we should remove this decomposition and phase out the current native_batch_norm +# to be _native_batch_norm_legit and have the right schema (stating that there are input mutations). +@aten.native_batch_norm.default.py_impl(DispatchKey.Autograd) +@aten.native_batch_norm.default.py_impl(DispatchKey.CompositeImplicitAutograd) +def native_batch_norm_decomposition( + input: Tensor, + weight: Optional[Tensor], + bias: Optional[Tensor], + running_mean: Optional[Tensor], + running_var: Optional[Tensor], + training: bool, + momentum: float, + eps: float, +) -> tuple[Tensor, Tensor, Tensor]: + if running_mean is None and running_var is None: + return aten._native_batch_norm_legit( + input, weight, bias, training, momentum, eps + ) + if running_mean is None: + raise RuntimeError( + "running_mean is None, but running_var is provided. " + "They should both be None or both be provided." + ) + if running_var is None: + raise RuntimeError( + "running_var is None, but running_mean is provided. " + "They should both be None or both be provided." + ) + if training: + # HACK: batch norm consolidation should clean this up so this op doesn't take in a training arg. 
+ return aten._native_batch_norm_legit( + input, weight, bias, running_mean, running_var, training, momentum, eps + ) + else: + return aten._native_batch_norm_legit_no_training( + input, weight, bias, running_mean, running_var, momentum, eps + ) + + +@aten.unsafe_chunk.default.py_impl(DispatchKey.CompositeImplicitAutograd) +def unsafe_chunk_py_impl(tensor, chunks, dim=0) -> list[Tensor]: + dim_size = tensor.size(dim) + split_size = (dim_size + chunks - 1) // chunks + + if split_size == 0 and dim_size == 0: + split_sizes = [split_size for _ in chunks] + split_sizes[chunks - 1] = split_size - (split_size * chunks - dim_size) + return torch.ops.aten.unsafe_split_with_sizes.default(tensor, split_sizes, dim) + return torch.ops.aten.unsafe_split.Tensor(tensor, split_size, dim) + + +@register_decomposition(aten._native_batch_norm_legit_no_training.default) +def _native_batch_norm_legit_no_training( + input: Tensor, + weight: Optional[Tensor], + bias: Optional[Tensor], + running_mean: Tensor, + running_var: Tensor, + momentum: float, + eps: float, +) -> tuple[Tensor, Tensor, Tensor]: + return aten._native_batch_norm_legit.default( + input, + weight, + bias, + running_mean, + running_var, + False, # training + momentum, + eps, + ) + + +@register_decomposition(aten._native_batch_norm_legit.default) +def _native_batch_norm_legit( + input: Tensor, + weight: Optional[Tensor], + bias: Optional[Tensor], + running_mean: Tensor, + running_var: Tensor, + training: bool, + momentum: float, + eps: float, +) -> tuple[Tensor, Tensor, Tensor]: + output, save_mean, save_rstd, _, _ = native_batch_norm_helper( + input, weight, bias, running_mean, running_var, training, momentum, eps, False + ) + return output, save_mean, save_rstd + + +@register_decomposition(aten._native_batch_norm_legit.no_stats) +def _native_batch_norm_legit_no_stats( + input: Tensor, + weight: Optional[Tensor], + bias: Optional[Tensor], + training: bool, + momentum: float, + eps: float, +) -> tuple[Tensor, Tensor, 
Tensor]: + output, save_mean, save_rstd, _, _ = native_batch_norm_helper( + input, weight, bias, None, None, training, momentum, eps, False + ) + return output, save_mean, save_rstd + + +@register_decomposition(aten._native_batch_norm_legit_functional.default) +def _native_batch_norm_legit_functional( + input: Tensor, + weight: Optional[Tensor], + bias: Optional[Tensor], + running_mean: Tensor, + running_var: Tensor, + training: bool, + momentum: float, + eps: float, +) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: + ( + output, + save_mean, + save_rstd, + new_running_mean, + new_running_var, + ) = native_batch_norm_helper( + input, weight, bias, running_mean, running_var, training, momentum, eps, True + ) + assert new_running_mean is not None, "new_running_mean should not be None" + assert new_running_var is not None, "new_running_var should not be None" + return output, save_mean, save_rstd, new_running_mean, new_running_var + + +def _get_batch_norm_reserve_tensor( + input: Tensor, + weight: Optional[Tensor], + bias: Optional[Tensor], + running_mean: Tensor, + running_var: Tensor, + eps: float, + training: bool, +) -> Tensor: + """ + Return a reserve tensor for batch norm, used only by cudnn to pass forward state to the + backward pass. This is needed for `_batch_norm_with_update` and `_batch_norm_no_update`, + which support a variety of backends including cudnn. We create this tensor here to get + the correct shape in the traced graph if we detect that will call the cudnn kernel, + and rely on DCE to avoid materializing this tensor. 
+ """ + backend = torch._C._select_batch_norm_backend( # type: ignore[attr-defined] + input, weight, bias, running_mean, running_var, True, eps + ) + reserve_size = 0 + if backend == torch._C._BatchNormBackend.Cudnn: # type: ignore[attr-defined] + reserve_size = torch._C._get_cudnn_batch_norm_reserve_space_size( # type: ignore[attr-defined] + input, training + ) + return torch.empty( + reserve_size, dtype=torch.uint8, layout=input.layout, device=input.device + ) + + +@register_decomposition(aten._batch_norm_with_update.default) +def _batch_norm_with_update( + input: Tensor, + weight: Optional[Tensor], + bias: Optional[Tensor], + running_mean: Tensor, + running_var: Tensor, + momentum: float, + eps: float, +) -> tuple[Tensor, Tensor, Tensor, Tensor]: + output, save_mean, save_rstd, _, _ = native_batch_norm_helper( + input, + weight, + bias, + running_mean, + running_var, + True, # training + momentum, + eps, + False, # functional + ) + reserve = _get_batch_norm_reserve_tensor( + input, weight, bias, running_mean, running_var, eps, training=True + ) + return output, save_mean, save_rstd, reserve + + +@register_decomposition(aten._batch_norm_with_update_functional.default) +def _batch_norm_with_update_functional( + input: Tensor, + weight: Optional[Tensor], + bias: Optional[Tensor], + running_mean: Tensor, + running_var: Tensor, + momentum: float, + eps: float, +) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: + ( + output, + save_mean, + save_rstd, + new_rm, + new_rv, + ) = native_batch_norm_helper( + input, weight, bias, running_mean, running_var, True, momentum, eps, True + ) + reserve = _get_batch_norm_reserve_tensor( + input, weight, bias, running_mean, running_var, eps, training=True + ) + assert new_rm is not None, "new_running_mean should not be None" + assert new_rv is not None, "new_running_var should not be None" + return (output, save_mean, save_rstd, reserve, new_rm, new_rv) + + +@register_decomposition(aten._batch_norm_no_update.default) +def 
_batch_norm_no_update( + input: Tensor, + weight: Optional[Tensor], + bias: Optional[Tensor], + running_mean: Tensor, + running_var: Tensor, + momentum: float, + eps: float, +) -> tuple[Tensor, Tensor, Tensor, Tensor]: + output, save_mean, save_rstd, _, _ = native_batch_norm_helper( + input, + weight, + bias, + running_mean, + running_var, + False, # training + momentum, + eps, + False, # functional + ) + reserve = _get_batch_norm_reserve_tensor( + input, weight, bias, running_mean, running_var, eps, training=False + ) + return output, save_mean, save_rstd, reserve + + +@register_decomposition(aten._fused_dropout) +@out_wrapper("out0", "out1") +@pw_cast_for_opmath +def _fused_dropout_decomposition(input, p, generator=None): + assert generator is None + mask = (torch.rand_like(input) < p).to(dtype=torch.uint8) + res = mask.type_as(input) * input * (1.0 / p) + return (res, mask) + + +@register_decomposition(aten._to_copy) +@out_wrapper() +def _to_copy( + x: Union[Tensor, NumberType], + *, + dtype: Optional[torch.dtype] = None, + layout=None, + device: Optional[torch.device] = None, + pin_memory: bool = False, + non_blocking: bool = False, + memory_format: Optional[torch.memory_format] = None, +): + assert not layout or layout == torch.strided, "TODO" + assert not pin_memory, "TODO" + assert isinstance(x, (torch.Tensor, int, float, bool, complex)) + if device is None and dtype is None and memory_format is None: + if isinstance(x, torch.Tensor): + return x.clone() + else: + return x + dtype_converted = False + + if isinstance(x, torch.Tensor): + x_tensor = x + else: + x_tensor = torch.scalar_tensor(x) + + if device is not None and device != x_tensor.device: + # avoid conversions on cpu + if dtype is not None and device.type == "cpu": + x_tensor = torch._prims.convert_element_type(x_tensor, dtype) + dtype_converted = True + x_tensor = torch._prims.device_put(x_tensor, device, non_blocking) + + if dtype is not None and not dtype_converted: + x_tensor = 
torch._prims.convert_element_type(x_tensor, dtype) + dtype_converted = True + + if memory_format is not None: # no ref/prim for memory format + return torch.clone(x_tensor, memory_format=memory_format) + return x_tensor + + +# Questionable decompositions +# This is only valid if we're running the graph without autograd, such as if the backward pass has been traced. +# Note that this decomposition causes issues with in-place ops +@register_decomposition([aten.detach, aten.lift, aten.lift_fresh]) +@out_wrapper() +def nop_decomposition(x): + return aten.alias(x) + + +# Also register to the Autograd dispatch key, so this decomp can run above autograd. +# native_batch_norm needs to decompose into other ops before autograd. +@aten.cudnn_batch_norm.default.py_impl(DispatchKey.Autograd) +@register_decomposition(aten.cudnn_batch_norm) +@out_wrapper("out0", "out1", "out2", "out3") +def cudnn_batch_norm( + input: Tensor, + weight: Tensor, + bias: Optional[Tensor], + running_mean: Optional[Tensor], + running_var: Optional[Tensor], + training: bool, + exponential_average_factor: float, + epsilon: float, +): + a, b, c = aten.native_batch_norm( + input, + weight, + bias, + running_mean, + running_var, + training, + exponential_average_factor, + epsilon, + ) + # Cudnn return running mean and variance when training is True + if training: + return (a, b, c, input.new_zeros((0,), dtype=torch.uint8)) + return ( + a, + weight.new_zeros((0,)), + weight.new_zeros((0,)), + input.new_zeros((0,), dtype=torch.uint8), + ) + + +def _broadcast_batch_norm_backward(x, broadcast_mask): + for axis, mask in enumerate(broadcast_mask): + if mask == 1 and not (axis < x.ndim and x.shape[axis] == mask): + x = x.unsqueeze(axis) + return x + + +@register_decomposition(aten.batch_norm_backward.default) +def batch_norm_backward( + grad_out: Tensor, + input: Tensor, + weight: Optional[Tensor], + running_mean: Optional[Tensor], + running_var: Optional[Tensor], + save_mean: Optional[Tensor], + save_invstd: 
Optional[Tensor], + train: bool, + eps: float, + output_mask: list[bool], + reserve: Tensor, +) -> tuple[Tensor, Optional[Tensor], Optional[Tensor]]: + return native_batch_norm_backward( + grad_out, + input, + weight, + running_mean, + running_var, + save_mean, + save_invstd, + train, + eps, + output_mask, + ) + + +@register_decomposition(aten.native_batch_norm_backward.default) +def native_batch_norm_backward( + grad_out: Tensor, + input: Tensor, + weight: Optional[Tensor], + running_mean: Optional[Tensor], + running_var: Optional[Tensor], + save_mean: Optional[Tensor], + save_invstd: Optional[Tensor], + train: bool, + eps: float, + output_mask: list[bool], +) -> tuple[Tensor, Optional[Tensor], Optional[Tensor]]: + input_dtype = input.dtype + if weight is not None: + weight_dtype = weight.dtype + else: + weight_dtype = input_dtype + computation_dtype = utils.get_computation_dtype(input.dtype) + ( + grad_out_cast, + input_cast, + weight_cast, + running_mean_cast, + running_var_cast, + save_mean_cast, + save_invstd_cast, + ) = ( + x.to(computation_dtype) if x is not None else x + for x in ( + grad_out, + input, + weight, + running_mean, + running_var, + save_mean, + save_invstd, + ) + ) + input_shape = input.shape + input_rank = input.dim() + assert input_rank >= 2, "rank of the input must be at least 2" + + axis = 1 + num_features = prod(list(input_shape)) / input_shape[axis] + mean = save_mean_cast + invstd = save_invstd_cast + if train: + assert mean is not None and invstd is not None + + else: + assert running_mean_cast is not None and running_var_cast is not None + mean = running_mean_cast + invstd = torch.rsqrt(running_var_cast + eps) + + broadcast_mask: list[int] = [1] * input_rank + broadcast_mask[axis] = input_shape[axis] + + reduction_axes: list[int] = [] + for i in range(input_rank): + if i != axis: + reduction_axes.append(i) + + mean = _broadcast_batch_norm_backward(mean, broadcast_mask) # type: ignore[arg-type] + norm = 1.0 / num_features + 
grad_output_sum = torch.sum(grad_out_cast, reduction_axes) # type: ignore[arg-type] + dot_p = torch.sum(grad_out_cast * (input_cast - mean), reduction_axes) # type: ignore[operator] + + grad_mean = _broadcast_batch_norm_backward(grad_output_sum * norm, broadcast_mask) + proj_scale = _broadcast_batch_norm_backward( + torch.mul(dot_p * norm, invstd * invstd), # type: ignore[operator] + broadcast_mask, + ) + + if weight_cast is None: + grad_scale = _broadcast_batch_norm_backward(invstd, broadcast_mask) * 1.0 # type: ignore[arg-type] + else: + grad_scale = _broadcast_batch_norm_backward( + invstd * weight_cast, broadcast_mask + ) + + if train: + proj = (input_cast - mean) * proj_scale # type: ignore[operator] + grad_input = ((grad_out_cast - proj) - grad_mean) * grad_scale + else: + grad_input = grad_out_cast * grad_scale + + if output_mask[1]: + grad_weight = dot_p * invstd + else: + grad_weight = None # "None" doesn't work with vjp, should use zeros for vjp + + if output_mask[2]: + grad_bias = grad_output_sum + else: + grad_bias = None # "None" doesn't work with vjp, should use zeros for vjp + + return ( + grad_input.to(input_dtype), + _maybe_cast(grad_weight, weight_dtype), + _maybe_cast(grad_bias, weight_dtype), + ) + + +# out_wrapper currently does not allow optional outputs +@register_decomposition(aten.native_batch_norm_backward.out) +def native_batch_norm_backward_out( + grad_out: Tensor, + input: Tensor, + weight: Optional[Tensor], + running_mean: Optional[Tensor], + running_var: Optional[Tensor], + save_mean: Optional[Tensor], + save_invstd: Optional[Tensor], + train: bool, + eps: float, + output_mask: list[bool], + *, + out0: torch.Tensor, + out1: torch.Tensor, + out2: torch.Tensor, +) -> tuple[Tensor, Optional[Tensor], Optional[Tensor]]: + result = native_batch_norm_backward( + grad_out, + input, + weight, + running_mean, + running_var, + save_mean, + save_invstd, + train, + eps, + output_mask, + ) + grad_input = (out0, out1, out2) + for i, r in 
enumerate(result): + if r is not None: + _maybe_resize_out(grad_input[i], r.shape) + _safe_copy_out(copy_from=r, copy_to=grad_input[i], exact_dtype=True) + + return grad_input + + +@register_decomposition(aten.miopen_batch_norm_backward) +@out_wrapper("out0", "out1", "out2") +def miopen_batch_norm_backward( + input: Tensor, + grad_output: Tensor, + weight: Tensor, + running_mean: Optional[Tensor], + running_var: Optional[Tensor], + save_mean: Optional[Tensor], + save_var: Optional[Tensor], + epsilon: float, +): + return aten.native_batch_norm_backward( + grad_output, + input, + weight, + running_mean, + running_var, + save_mean, + save_var, + True, + epsilon, + [True, True, True], + ) + + +@register_decomposition(aten.cudnn_batch_norm_backward) +@out_wrapper("out0", "out1", "out2") +def cudnn_batch_norm_backward( + input: Tensor, + grad_output: Tensor, + weight: Tensor, + running_mean: Optional[Tensor], + running_var: Optional[Tensor], + save_mean: Optional[Tensor], + save_var: Optional[Tensor], + epsilon: float, + reserveSpace: Tensor, +): + return aten.native_batch_norm_backward( + grad_output, + input, + weight, + running_mean, + running_var, + save_mean, + save_var, + True, + epsilon, + [True, True, True], + ) + + +@register_decomposition(aten._adaptive_avg_pool2d) +@out_wrapper() +@pw_cast_for_opmath +def adaptive_avg_pool2d(input: Tensor, output_size: tuple[int, int]): + # Preconditions + device = input.device + shape = input.shape + ndim = len(shape) + torch._check( + ndim in (3, 4), + lambda: f"adaptive_avg_pool2d(): Expected 3D or 4D tensor, but got {ndim}", + ) + for d in input.shape[-2:]: + torch._check( + d != 0, + lambda: "adaptive_avg_pool2d(): Expected input to have non-zero size for " + f"non-batch dimensions, but input has shape {tuple(shape)}.", + ) + + # Optimisation (we should also do this in the kernel implementation) + if shape[-2] % output_size[-2] == 0 and shape[-1] % output_size[-1] == 0: + stride = tuple(i // o for i, o in zip(shape[-2:], 
output_size)) + kernel = tuple( + i - (o - 1) * s for i, o, s in zip(shape[-2:], output_size, stride) + ) + return torch.nn.functional.avg_pool2d(input, kernel, stride) + + def start_index(a, b, c): + return torch.div(a * c, b, rounding_mode="trunc") + + def end_index(a, b, c): + return torch.div((a + 1) * c + b - 1, b, rounding_mode="trunc") + + def compute_idx(in_size, out_size): + orange = torch.arange(out_size, device=device, dtype=torch.int64) + i0 = start_index(orange, out_size, in_size) + # Let length = end_index - start_index, i.e. the length of the pooling kernels + # length.max() can be computed analytically as follows: + maxlength = in_size // out_size + 1 + in_size_mod = in_size % out_size + # adaptive = True iff there are kernels with different lengths + adaptive = not (in_size_mod == 0 or out_size % in_size_mod == 0) + if adaptive: + maxlength += 1 + elif in_size_mod == 0: + maxlength -= 1 + + range_max = torch.arange(maxlength, device=device, dtype=torch.int64) + idx = i0.unsqueeze(-1) + range_max + if adaptive: + # Need to clamp to avoid accessing out-of-bounds memory + # TODO make minimum accept scalars + maxval = torch.scalar_tensor( + in_size - 1, dtype=idx.dtype, device=idx.device + ) + idx = torch.minimum(idx, maxval) + + # Compute the length + i1 = end_index(orange, out_size, in_size) + length = i1 - i0 + else: + length = maxlength + return idx, length, range_max, adaptive + + # length is not None if it's constant, otherwise we'll need to compute it + idxh, length_h, range_max_h, adaptive_h = compute_idx(shape[-2], output_size[-2]) + idxw, length_w, range_max_w, adaptive_w = compute_idx(shape[-1], output_size[-1]) + + vals = input[..., _unsqueeze_to_dim(idxh, 4), idxw] + # Shortcut for the simpler case + if not adaptive_h and not adaptive_w: + return torch.mean(vals, dim=(-3, -1)) + + def maybe_mask(vals, length, range_max, adaptive, dim): + if isinstance(length, IntLike): + return vals, length + else: + # zero-out the things we didn't really 
want to select + assert dim < 0 + # hack + mask = range_max >= length.unsqueeze(-1) + if dim == -2: + mask = _unsqueeze_to_dim(mask, 4) + vals = torch.masked_fill(vals, mask, 0.0) + # Compute the length of each window + length = _unsqueeze_to_dim(length, -dim) + return vals, length + + vals, length_h = maybe_mask( + vals, length_h, range_max_h, adaptive=adaptive_h, dim=-2 + ) + vals, length_w = maybe_mask( + vals, length_w, range_max_w, adaptive=adaptive_w, dim=-1 + ) + + # We unroll the sum as we assume that the kernels are going to be small + ret = None + for i, j in product(range(vals.shape[-3]), range(vals.shape[-1])): + if ret is None: + ret = vals[..., i, :, j] + else: + ret = ret + vals[..., i, :, j] + return ret / (length_h * length_w) + + +def _max_unpoolnd( + self: TensorLike, indices: TensorLike, output_size: list[int], dim: int +): + # If the input tensors self and indices came from max_pool call as + # required by the documentation, this operation is deterministic + # because that ensures that if there are two entries in `indices` + # tensor that are equal, the corresponding values in `self` are also + # equal. If this condition is not satisfied, the operation is + # non-deterministic as one of the different values in `self` 'wins'. 
+ utils.alert_not_deterministic(f"max_unpooling{dim}d_forward_out") + nc = reduce(operator.mul, self.shape[:-dim]) + hw = reduce(operator.mul, output_size) + indices_nc_shape = [1] * self.ndim + indices_nc_shape[:-dim] = self.shape[:-dim] + indices_flat = ( + indices + aten.arange(nc, device=self.device).view(indices_nc_shape) * hw + ).reshape(-1) + + output = self.new_zeros(list(self.shape[:-dim]) + list(output_size)) + return aten._unsafe_index_put( + output.reshape(-1), [indices_flat], self.reshape(-1), accumulate=False + ).view(output.shape) + + +@register_decomposition(aten.max_unpool2d) +@out_wrapper() +def max_unpool2d( + self: TensorLike, + indices: TensorLike, + output_size: list[int], +): + torch._check( + indices.dtype == torch.int64, + lambda: f"elements in indices should be type int64 but got: {indices.dtype}", + ) + torch._check( + len(output_size) == 2, + lambda: ( + f"There should be exactly two elements (height, width) in output_size, " + f"but got {len(output_size)} elements." + ), + ) + + torch._check( + self.ndim in (3, 4), + lambda: ( + f"Input to max_unpooling2d should be a 3d or 4d Tensor, " + f"but got a tensor with {self.ndim} dimensions." + ), + ) + torch._check( + self.shape == indices.shape, + lambda: ( + f"Expected shape of indices to be same as that of the input tensor ({self.shape}) " + f"but got indices tensor with shape: {indices.shape}" + ), + ) + + for i in range(1, self.ndim): + torch._check( + self.size(i) > 0, + lambda: ( + f"max_unpooling2d(): " + f"Expected input to have non-zero size for non-batch dimensions, " + f"but got {self.shape} with dimension {i} being empty." 
+ ), + ) + + return _max_unpoolnd(self, indices, output_size, 2) + + +@register_decomposition(aten.max_unpool3d) +@out_wrapper() +def max_unpool3d( + input: TensorLike, + indices: TensorLike, + output_size: list[int], + stride: list[int], + padding: list[int], +): + torch._check( + indices.dtype == torch.int64, lambda: "elements in indices should be type int64" + ) + torch._check( + input.ndim in (4, 5), + lambda: f"Input to max_unpooling3d should be a 4d or 5d Tensor, but got a tensor with {input.ndim} dimensions.", + ) + torch._check( + len(output_size) == 3, + lambda: ( + f"There should be exactly three elements (depth, height, width) in output_size, " + f"but got {len(output_size)} elements." + ), + ) + torch._check( + len(stride) == 3, + lambda: f"There should be exactly three elements (depth, height, width) in stride, but got: {len(stride)} elements.", + ) + torch._check( + len(padding) == 3, + lambda: f"There should be exactly three elements (depth, height, width) in padding, but got: {len(padding)} elements.", + ) + torch._check( + input.shape == indices.shape, + lambda: ( + f"Expected shape of indices to be same as that of the input tensor ({input.shape}) " + f"but got indices tensor with shape: {indices.shape}" + ), + ) + + for i in range(1, input.ndim): + torch._check( + input.size(i) > 0, + lambda: ( + f"max_unpooling3d(): " + f"Expected input to have non-zero size for non-batch dimensions, " + f"but got {input.shape} with dimension {i} being empty." 
+ ), + ) + + torch._check( + stride[0] > 0 and stride[1] > 0 and stride[2] > 0, + lambda: f"strides should be greater than zero, but got stride: {stride}", + ) + + return _max_unpoolnd(input, indices, output_size, 3) + + +@register_decomposition(aten.index_add_) +def index_add_( + x: TensorLike, + dim: int, + index: TensorLike, + tensor: TensorLike, + *, + alpha: NumberType = 1, +): + return _index_add(x, dim, index, tensor, inplace=True, alpha=alpha) + + +@register_decomposition(aten.index_add) +@out_wrapper() +def index_add( + x: TensorLike, + dim: int, + index: TensorLike, + tensor: TensorLike, + *, + alpha: NumberType = 1, +): + return _index_add(x, dim, index, tensor, inplace=False, alpha=alpha) + + +def _index_add( + x: TensorLike, + dim: int, + index: TensorLike, + tensor: TensorLike, + *, + inplace: bool, + alpha: NumberType = 1, +): + dim = utils.canonicalize_dims(x.ndim, dim) + torch._check( + index.ndim <= 1, + lambda: f"Index should have dimension 1 or 0 (got {index.ndim})", + ) + index_size = index.size(0) if index.ndim == 1 else 1 + tensor_size = tensor.size(dim) if tensor.ndim > 0 else 1 + torch._check( + tensor_size == index_size, + lambda: f"Number of indices ({index_size}) should be equal to tensor.size(dim) ({tensor_size}), for {dim=}", + ) + if alpha != 1: + python_type = utils.dtype_to_type(x.dtype) + torch._check( + python_type is bool + or utils.is_weakly_lesser_type(type(alpha), python_type), + lambda: f"alpha argument of type {type(alpha)} cannot be safely cast to type {python_type}!", + ) + tensor = tensor * alpha + # Treat scalars as elements of \R^1 + zero_dim = x.ndim == 0 + x1 = x.unsqueeze(0) if zero_dim else x + idx = (None,) * dim + (index,) + index_put = aten.index_put_ if inplace else aten.index_put + out = index_put(x1, idx, tensor, accumulate=True) + if inplace: + return x + else: + return out.squeeze(0) if zero_dim else out.contiguous() + + +@register_decomposition(aten.pad_sequence.default) 
+@aten.pad_sequence.default.py_impl(DispatchKey.CompositeImplicitAutograd) +def pad_sequence(sequences, batch_first=False, padding_value=0.0): + torch._check(len(sequences) > 0, lambda: "received an empty list of sequences") + sequences_size = len(sequences) + max_size = sequences[0].size() + trailing_dims = max_size[1:] + max_len = max(x.size(0) for x in sequences) + if batch_first: + out_dims = (sequences_size, max_len) + else: + out_dims = (max_len, sequences_size) + out_dims = out_dims + trailing_dims + out = sequences[0].new_full(out_dims, padding_value) + dim_paddings = (0, 0) * len(trailing_dims) + for i in range(sequences_size): + currseq = sequences[i] + row = aten.constant_pad_nd( + currseq, dim_paddings + (0, max_len - currseq.size(0)), padding_value + ) + if batch_first: + out = aten.select_scatter(out, row, dim=0, index=i) + else: + out = aten.select_scatter(out, row, dim=1, index=i) + return out + + +@register_decomposition(aten.index_copy_) +def index_copy_(x: TensorLike, dim: int, index: TensorLike, tensor: TensorLike): + return _index_copy(x, dim, index, tensor, inplace=True) + + +@register_decomposition(aten.index_copy) +@out_wrapper() +def index_copy(x: TensorLike, dim: int, index: TensorLike, tensor: TensorLike): + return _index_copy(x, dim, index, tensor, inplace=False) + + +def _index_copy( + x: TensorLike, dim: int, index: TensorLike, tensor: TensorLike, *, inplace: bool +): + dim = utils.canonicalize_dims(x.ndim, dim) + torch._check( + index.ndim <= 1, + lambda: f"Index should have dimension 1 or 0 (got {index.ndim})", + ) + # Treat scalars as elements of \R^1 + zero_dim = x.ndim == 0 + x1 = x.unsqueeze(0) if zero_dim else x + index = index.unsqueeze(0) if index.ndim == 0 else index + idx = (None,) * dim + (index,) + index_put = aten.index_put_ if inplace else aten.index_put + out = index_put(x1, idx, tensor) + if inplace: + return x + else: + return out.squeeze(0) if zero_dim else out.contiguous() + + +# nb: Should use acc_t, not op_math 
+@register_decomposition(aten.log_sigmoid_forward) +@out_wrapper("output", "buffer") +@pw_cast_for_opmath +def log_sigmoid_forward(self: Tensor) -> tuple[Tensor, Tensor]: + min = torch.minimum(self.new_zeros(()), self) + z = torch.exp(-torch.abs(self)) + if self.is_cuda or self.is_xpu: + buffer = self.new_zeros((0,)) + else: + buffer = z + return min - torch.log1p(z), buffer + + +@register_decomposition(aten.uniform) +@out_wrapper() +def uniform( + x: Tensor, + low: Union[bool, int, float] = 0.0, + high: Union[bool, int, float] = 1.0, + generator: Optional[torch.Generator] = None, +): + return prims._uniform_helper( + x.shape, + low=sym_float(low), + high=sym_float(high), + dtype=x.dtype, + device=x.device, + generator=generator, + ) + + +@register_decomposition(aten.uniform_) +def uniform_(self, low=0, high=1, generator=None): + return self.copy_(uniform(self, low, high, generator)) + + +# aten/src/ATen/native/UpSample.cpp compute_output_size +def upsample_compute_output_size(input_size, output_size, scale_factors): + spatial_dimensions = len(input_size) - 2 + if output_size is not None: + torch._check( + scale_factors is None, + lambda: "Must specify exactly one of output_size and scale_factors", + ) + torch._check(len(output_size) == spatial_dimensions, lambda: "") + return output_size + if scale_factors is not None: + # NB: this isn't necessary lol + torch._check( + output_size is None, + lambda: "Must specify exactly one of output_size and scale_factors", + ) + torch._check(len(scale_factors) == spatial_dimensions, lambda: "") + output_size = [] + for i, s in enumerate(scale_factors): + if int(s) == s: + output_size.append(input_size[i + 2] * int(s)) + else: + output_size.append(sym_int(input_size[i + 2] * s)) + return output_size + torch._check( + False, lambda: "Must specify exactly one of output_size and scale_factors" + ) + + +def get_scale_value(scales, idx): + if scales is None: + return None + return scales[idx] + + 
+@register_decomposition(aten.upsample_nearest1d.vec) +@register_decomposition(aten.upsample_nearest2d.vec) +@register_decomposition(aten.upsample_nearest3d.vec) +@aten.upsample_nearest1d.vec.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten.upsample_nearest1d.vec.py_impl(DispatchKey.Autograd) +@aten.upsample_nearest2d.vec.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten.upsample_nearest2d.vec.py_impl(DispatchKey.Autograd) +@aten.upsample_nearest3d.vec.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten.upsample_nearest3d.vec.py_impl(DispatchKey.Autograd) +def _upsample_nearest_vec( + input: Tensor, + output_size: Optional[list[int]], + scale_factors: Optional[list[float]], +) -> Tensor: + osize = upsample_compute_output_size(input.size(), output_size, scale_factors) + scales = ( + scale_factors if scale_factors else [None] * len(osize) # type: ignore[list-item] + ) + return _upsample_nearest(input, osize, scales) + + +@register_decomposition(aten._upsample_nearest_exact1d.vec) +@register_decomposition(aten._upsample_nearest_exact2d.vec) +@register_decomposition(aten._upsample_nearest_exact3d.vec) +@aten._upsample_nearest_exact1d.vec.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten._upsample_nearest_exact1d.vec.py_impl(DispatchKey.Autograd) +@aten._upsample_nearest_exact2d.vec.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten._upsample_nearest_exact2d.vec.py_impl(DispatchKey.Autograd) +@aten._upsample_nearest_exact3d.vec.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten._upsample_nearest_exact3d.vec.py_impl(DispatchKey.Autograd) +def _upsample_nearest_exact_vec( + input: Tensor, + output_size: Optional[list[int]], + scale_factors: Optional[list[float]], +) -> Tensor: + osize = upsample_compute_output_size(input.size(), output_size, scale_factors) + scales = ( + scale_factors if scale_factors else [None] * len(osize) # type: ignore[list-item] + ) + return _upsample_nearest(input, osize, scales, exact=True) + + +def 
_compute_upsample_nearest_indices(input, output_size, scales, exact=False): + # For each dim in output_size, compute the set of input indices used + # to produce the upsampled output. + indices = [] + num_spatial_dims = len(output_size) + offset = 0.5 if exact else 0.0 + + for d in range(num_spatial_dims): + # Math matches aten/src/ATen/native/cpu/UpSampleKernel.cpp + # + # Indices are computed as follows: + # scale = isize / osize + # Case: exact=False + # input_index = floor(output_index * scale) + # Same as OpenCV INTER_NEAREST + # + # Case: exact=True + # index_f32 = (output_index + 0.5) * scale - 0.5 + # input_index = round(index_f32) + # Same as Pillow and Scikit-Image/Scipy ndi.zoom + osize = output_size[d] + isize = input.shape[-num_spatial_dims + d] + scale = isize / (isize * scales[d]) if scales[d] is not None else isize / osize + + output_indices = torch.arange(osize, dtype=torch.float32, device=input.device) + input_indices = ((output_indices + offset) * scale).to(torch.int64) + for _ in range(num_spatial_dims - 1 - d): + input_indices = input_indices.unsqueeze(-1) + indices.append(input_indices) + return indices + + +@register_decomposition([aten.upsample_nearest1d.default, aten.upsample_nearest1d.out]) +@aten.upsample_nearest1d.default.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten.upsample_nearest1d.default.py_impl(DispatchKey.Autograd) +@out_wrapper(preserve_memory_format=True, exact_dtype=True) +def upsample_nearest1d( + input: Tensor, + output_size: list[int], + scales: Optional[float] = None, +) -> Tensor: + return _upsample_nearest(input, output_size, [scales]) + + +@register_decomposition( + [aten._upsample_nearest_exact1d.default, aten._upsample_nearest_exact1d.out] +) +@aten._upsample_nearest_exact1d.default.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten._upsample_nearest_exact1d.default.py_impl(DispatchKey.Autograd) +@out_wrapper(preserve_memory_format=True, exact_dtype=True) +def upsample_nearest_exact1d( + input: Tensor, + 
output_size: list[int], + scales: Optional[float] = None, +) -> Tensor: + return _upsample_nearest(input, output_size, [scales], exact=True) + + +@register_decomposition([aten.upsample_nearest2d.default, aten.upsample_nearest2d.out]) +@aten.upsample_nearest2d.default.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten.upsample_nearest2d.default.py_impl(DispatchKey.Autograd) +@out_wrapper(preserve_memory_format=True, exact_dtype=True) +def upsample_nearest2d( + input: Tensor, + output_size: list[int], + scales_h: Optional[float] = None, + scales_w: Optional[float] = None, +) -> Tensor: + return _upsample_nearest(input, output_size, [scales_h, scales_w]) + + +@register_decomposition( + [aten._upsample_nearest_exact2d.default, aten._upsample_nearest_exact2d.out] +) +@aten._upsample_nearest_exact2d.default.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten._upsample_nearest_exact2d.default.py_impl(DispatchKey.Autograd) +@out_wrapper(preserve_memory_format=True, exact_dtype=True) +def _upsample_nearest_exact2d( + input: Tensor, + output_size: list[int], + scales_h: Optional[float] = None, + scales_w: Optional[float] = None, +) -> Tensor: + return _upsample_nearest(input, output_size, [scales_h, scales_w], exact=True) + + +@register_decomposition([aten.upsample_nearest3d.default, aten.upsample_nearest3d.out]) +@aten.upsample_nearest3d.default.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten.upsample_nearest3d.default.py_impl(DispatchKey.Autograd) +@out_wrapper(preserve_memory_format=True, exact_dtype=True) +def upsample_nearest3d( + input: Tensor, + output_size: list[int], + scales_d: Optional[float] = None, + scales_h: Optional[float] = None, + scales_w: Optional[float] = None, +) -> Tensor: + return _upsample_nearest(input, output_size, [scales_d, scales_h, scales_w]) + + +@register_decomposition( + [aten._upsample_nearest_exact3d.default, aten._upsample_nearest_exact3d.out] +) 
def gather_params(params, has_biases, has_projections):
    """Split the flat ``params`` sequence into equally sized tuples.

    The group size is 2, plus 2 when ``has_biases`` is set, plus 1 when
    ``has_projections`` is set. ``len(params)`` must be a multiple of the
    group size; otherwise an ``AssertionError`` carrying ``len(params)`` is
    raised.

    Returns:
        A list of tuples, each holding one consecutive group of ``params``.
    """
    per_group = 2 + (2 if has_biases else 0) + (1 if has_projections else 0)
    total = len(params)
    assert total % per_group == 0, total

    groups = []
    for start in range(0, total, per_group):
        groups.append(tuple(params[start : start + per_group]))
    return groups
def update_hidden_for_packed_reverse(
    cur_hidden, last_batch_size, batch_size, inp_hidden
):
    """Grow ``cur_hidden`` along dim 0 from ``last_batch_size`` to ``batch_size`` rows.

    When the two sizes are equal ``cur_hidden`` is returned as is. Otherwise
    ``batch_size`` must be the larger one, and the missing rows
    ``[last_batch_size, batch_size)`` are taken from ``inp_hidden`` and
    appended to ``cur_hidden``.
    """
    if last_batch_size == batch_size:
        return cur_hidden

    assert last_batch_size < batch_size
    missing_rows = batch_size - last_batch_size
    tail = inp_hidden.narrow(0, last_batch_size, missing_rows)
    return torch.concat((cur_hidden, tail))
def one_layer_rnn(inp, hidden, params, has_biases, hidden_fn, reverse=False):
    """Run one recurrent layer over ``inp``, stepping along dim 0.

    Args:
        inp: layer input; iterated along its first dimension.
        hidden: initial hidden state; a leading dim is added before stepping
            and removed again from the returned final state.
        params: ``(ih_weight, hh_weight[, ih_bias, hh_bias])``; the biases are
            only read when ``has_biases`` is true.
        has_biases: whether ``params`` carries the two bias entries.
        hidden_fn: cell called as
            ``hidden_fn(step_input, hidden, ih_weight, ih_bias, hh_weight, hh_bias)``
            where ``step_input`` already has the input projection applied.
        reverse: process the steps in reverse order; the returned outputs are
            put back into the original step order.

    Returns:
        ``(out, final_hidden)`` where ``out`` is the per-step hidden states
        concatenated along dim 0.
    """
    w_ih, w_hh = params[0], params[1]
    if has_biases:
        b_ih, b_hh = params[2], params[3]
    else:
        b_ih = b_hh = None

    # The input projection is applied to the whole input up front.
    projected = F.linear(inp, w_ih, b_ih)
    if reverse:
        projected = projected.flip(0)

    state = hidden.unsqueeze(0)
    per_step = []
    for step_inp in projected:
        state = hidden_fn(step_inp, state, w_ih, b_ih, w_hh, b_hh)
        per_step.append(state)

    if reverse:
        per_step = per_step[::-1]

    return torch.cat(per_step, 0), state.squeeze(0)
+ # Same as aten/src/ATen/native/mkldnn/RNN.cpp: mkldnn_rnn: input = input.contiguous(); + inp = inp.contiguous() + hx = hx.contiguous() + cx = cx.contiguous() + outputs = torch.ops.aten.mkldnn_rnn_layer.default( + inp, + w0, + w1, + w2, + w3, + hx, + cx, + reverse, + batch_sizes, + mode, + hidden_size, + num_layers, + has_biases, + bidirectional, + batch_first, + train, + ) + y, hy, cy = outputs[0], outputs[1], outputs[2] + return y, (hy.squeeze(0), cy.squeeze(0)) + + +def _rnn_helper( + input, + hidden, + params, + has_biases, + num_layers, + dropout, + train, + bidirectional, + batch_first, + layer_fn, +): + input = input.transpose(0, 1) if batch_first else input + final_hiddens = [] + + for i in range(num_layers): + cur_params, cur_hidden, bidir_params, bidir_hidden = params_hiddens( + params, hidden, i, bidirectional + ) + dropout = dropout if (train and num_layers < i - 1) else 0.0 + fwd_inp, fwd_hidden = layer_fn(input, cur_hidden, cur_params, has_biases) + final_hiddens.append(fwd_hidden) + + if bidirectional: + bwd_inp, bwd_hidden = layer_fn( + input, bidir_hidden, bidir_params, has_biases, reverse=True + ) + final_hiddens.append(bwd_hidden) + + if bidirectional: + input = torch.cat([fwd_inp, bwd_inp], fwd_inp.dim() - 1) # type: ignore[possibly-undefined] + else: + input = fwd_inp + + if dropout != 0 and train and i < num_layers - 1: + input = torch.dropout(input, dropout, train=True) + + input = input.transpose(0, 1) if batch_first else input + return input, final_hiddens + + +@register_decomposition(aten.rnn_tanh.input) +@aten.rnn_tanh.input.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten.rnn_tanh.input.py_impl(DispatchKey.Autograd) +def rnn_tanh_input( + input, + hx, + params, + has_biases, + num_layers, + dropout, + train, + bidirectional, + batch_first, +): + hidden = hx.unbind(0) + params = gather_params(params, has_biases, False) + out, final_hiddens = _rnn_helper( + input, + hidden, + params, + has_biases, + num_layers, + dropout, + train, + 
@register_decomposition(aten.rnn_relu.input)
@aten.rnn_relu.input.py_impl(DispatchKey.CompositeImplicitAutograd)
@aten.rnn_relu.input.py_impl(DispatchKey.Autograd)
def rnn_relu_input(
    input,
    hx,
    params,
    has_biases,
    num_layers,
    dropout,
    train,
    bidirectional,
    batch_first,
):
    """Decomposition registered for ``aten.rnn_relu.input``.

    Splits ``hx`` along dim 0 into per-layer states, groups ``params`` with
    ``gather_params`` (no projections), and runs ``_rnn_helper`` with
    ``one_layer_rnn`` using a ``torch.relu`` cell.

    Returns:
        ``(output, hidden)`` where ``hidden`` stacks the final states on dim 0.
    """
    per_layer_hidden = hx.unbind(0)
    per_layer_params = gather_params(params, has_biases, False)
    relu_layer = partial(one_layer_rnn, hidden_fn=rnn_cell(torch.relu))

    out, final_hiddens = _rnn_helper(
        input,
        per_layer_hidden,
        per_layer_params,
        has_biases,
        num_layers,
        dropout,
        train,
        bidirectional,
        batch_first,
        relu_layer,
    )
    stacked_hidden = torch.stack(final_hiddens, 0)
    return out, stacked_hidden
def lstm_cell(inp, hx, cx, hh_weight, hh_bias, hr_weight, chunk_dim):
    """Apply one LSTM step.

    Args:
        inp: the step input, added directly to the hidden projection.
        hx: previous hidden state.
        cx: previous cell state.
        hh_weight: hidden-to-hidden weight passed to ``F.linear``.
        hh_bias: hidden-to-hidden bias passed to ``F.linear``.
        hr_weight: optional projection weight applied to the new hidden state
            (without bias) when not ``None``.
        chunk_dim: dimension along which the gate pre-activations are split
            into four chunks.

    Returns:
        ``(hy, cy)``: the new hidden and cell states.
    """
    pre_activations = F.linear(hx, hh_weight, hh_bias) + inp
    parts = pre_activations.chunk(4, chunk_dim)

    # Chunk order: input gate, forget gate, cell candidate, output gate.
    input_g = parts[0].sigmoid()
    forget_g = parts[1].sigmoid()
    candidate = parts[2].tanh()
    output_g = parts[3].sigmoid()

    cy = forget_g * cx + (input_g * candidate)
    hy = output_g * cy.tanh()
    if hr_weight is not None:
        hy = F.linear(hy, hr_weight, None)

    return hy, cy
def select_one_layer_lstm_function(input, hx, params):
    r"""Choose the single-layer LSTM function used to decompose lstm.

    ``mkldnn_one_layer_lstm`` is returned only when all of the following hold;
    in every other case ``one_layer_lstm`` is returned:
        * ``torch._C._get_mkldnn_enabled()`` returns ``True``.
        * All the input args are on CPU.
        * The dtypes of args are either torch.float or torch.bfloat16.
        * Inference (``input`` does not require grad).
        * No projections (``hx[0]`` and ``hx[1]`` have the same size on dim 2).

    Args:
        * input: the input sequence to LSTM
        * hx: a tuple of the input hidden state and cell state ``(h_0, c_0)`` to LSTM
        * params: the weight and bias tensors of LSTM
    """
    # mkldnn_one_layer_lstm does not depend on seq_len while one_layer_lstm
    # will expand over the seq_len dim
    if not torch._C._get_mkldnn_enabled():
        return one_layer_lstm

    all_tensors = [input] + list(hx) + list(chain.from_iterable(params))

    # Every tensor must live on the same device, and that device must be CPU.
    device_set = {t.device for t in all_tensors}
    if len(device_set) != 1:
        return one_layer_lstm
    if device_set.pop() != torch.device("cpu"):
        return one_layer_lstm

    # With autocast, possible to have mixed dtype here
    allowed_dtypes = [torch.float, torch.bfloat16]
    if any(dt not in allowed_dtypes for dt in {t.dtype for t in all_tensors}):
        return one_layer_lstm

    if input.requires_grad:
        return one_layer_lstm

    # A size mismatch on dim 2 between h_0 and c_0 means projections are used.
    if hx[0].size(2) != hx[1].size(2):
        return one_layer_lstm

    return mkldnn_one_layer_lstm
def gru_cell(inp, cur_hidden, ih_weight, ih_bias, hh_weight, hh_bias):
    """Apply one GRU step to ``cur_hidden``.

    ``inp`` is split into three chunks along dim 1 and the hidden projection
    ``F.linear(cur_hidden, hh_weight, hh_bias)`` into three chunks along dim 2.
    ``ih_weight`` and ``ih_bias`` are accepted but not used here.

    Returns:
        The new hidden state.
    """
    from_input = inp.chunk(3, 1)
    from_hidden = F.linear(cur_hidden, hh_weight, hh_bias).chunk(3, 2)

    reset = (from_hidden[0] + from_input[0]).sigmoid()
    update = (from_hidden[1] + from_input[1]).sigmoid()
    # The reset gate scales only the hidden contribution to the candidate.
    candidate = (from_input[2] + (from_hidden[2] * reset)).tanh()

    blended = (cur_hidden - candidate) * update
    return blended + candidate
@register_decomposition(aten.gru.input)
@aten.gru.input.py_impl(DispatchKey.CompositeImplicitAutograd)
@aten.gru.input.py_impl(DispatchKey.Autograd)
def gru_impl(
    input,
    hx,
    params,
    has_biases,
    num_layers,
    dropout,
    train,
    bidirectional,
    batch_first,
):
    """Decomposition registered for ``aten.gru.input``.

    Groups ``params`` with ``gather_params`` (no projections) and runs
    ``_rnn_helper`` over the states from ``hx.unbind(0)`` with
    ``one_layer_rnn`` using ``gru_cell``.

    Returns:
        ``(output, hidden)`` where ``hidden`` stacks the final states on dim 0.
    """
    per_layer_params = gather_params(params, has_biases, False)
    per_layer_hidden = hx.unbind(0)
    gru_layer = partial(one_layer_rnn, hidden_fn=gru_cell)

    out, final_hiddens = _rnn_helper(
        input,
        per_layer_hidden,
        per_layer_params,
        has_biases,
        num_layers,
        dropout,
        train,
        bidirectional,
        batch_first,
        gru_layer,
    )
    stacked_hidden = torch.stack(final_hiddens, 0)
    return out, stacked_hidden
@register_decomposition([aten.upsample_linear1d.default, aten.upsample_linear1d.out])
@out_wrapper()
def upsample_linear1d(
    input: Tensor,
    output_size: list[int],
    align_corners: bool,
    scales_w: Optional[float] = None,
) -> Tensor:
    """Decomposition registered for ``aten.upsample_linear1d``.

    Delegates to ``_upsample_linear`` with the single optional scale wrapped
    in a one-element list.
    """
    # One entry per spatial dimension; the 1d variant has exactly one.
    per_dim_scales = [scales_w]
    return _upsample_linear(input, output_size, align_corners, per_dim_scales)
(in_size - 1.0) / (out_size - 1.0) if out_size > 1 else 0 + else: + return 1.0 / scale if scale is not None and scale > 0 else in_size / out_size + + +def _compute_source_index(scale, dst_index, align_corners): + if align_corners: + return scale * dst_index + else: + return scale * (dst_index + 0.5) - 0.5 + + +def _sum_tensors_uint8( + src: Iterable[Tensor], weights: Iterable[Tensor], weights_precision: Tensor +) -> Tensor: + output = _sum_tensors( + s.to(torch.int32) * c.to(torch.int32) for s, c in zip(src, weights) + ) + (1 << (weights_precision - 1)) + output = output >> weights_precision + return torch.clamp(output, 0, 255).to(torch.uint8) + + +def _compute_weight_precision(weights: TensorSequenceType) -> Tensor: + max_weight = torch.stack(weights).max() + max_weight_precision = 22 + precisions = torch.arange(max_weight_precision, device=max_weight.device) + values = 0.5 + max_weight * (1 << (precisions + 1)) + mask = values >= (1 << 15) + return max_weight_precision - mask.sum() + + +@pw_cast_for_opmath +def _upsample_linear( + input: Tensor, + output_size: list[int], + align_corners: bool, + scales: list[Optional[float]], +) -> Tensor: + # get dimensions of original image + n_channels = input.shape[1] + inp_sizes = input.shape[2:] + n_dims = len(inp_sizes) + + _, dtype = utils.elementwise_dtypes( + input, + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + ) + + def get_values(inp_size, out_size, scales, nsqueeze): + # First Calculate scaling factor + scale_factor = _compute_scale(inp_size, out_size, align_corners, scales) + # We have to create arange with int64 dtype and use .to in order to avoid + # additional kernels creation in inductor and get a perf slowdown + i = torch.arange(out_size, device=input.device).to(dtype=dtype) + + x_f32 = _compute_source_index(scale_factor, i, align_corners).clamp(min=0.0) + x_f32 = x_f32.reshape(x_f32.shape[0], *[1] * (nsqueeze)) + x = x_f32.to(torch.int64) + xp1 = (x + 1).clamp(max=inp_size - 1) + 
@register_decomposition([aten._unsafe_masked_index])
def _unsafe_masked_index(x, mask, indices, fill):
    """Index ``x`` with ``indices`` and write ``fill`` wherever ``mask`` is false.

    Args:
        x: tensor to index.
        mask: bool tensor; positions where it is false receive ``fill``.
        indices: sequence of index tensors (long or int) or ``None`` entries.
            The sequence passed in is not modified.
        fill: value used for masked-out positions, and for the whole result
            when ``x`` is known to be empty.

    Returns:
        The indexed tensor with masked-out positions replaced by ``fill``.
    """
    for index in indices:
        if index is not None:
            torch._check(
                index.dtype in [torch.long, torch.int],
                lambda: "tensors used as indices must be long or int tensors",
            )

    torch._check(
        mask.dtype == torch.bool,
        lambda: "tensors used as masks must be bool tensors",
    )

    from torch.fx.experimental.symbolic_shapes import guard_or_false

    if guard_or_false(x.numel() == 0):
        # Nothing to read from: only the output shape is needed.
        meta_result = torch._meta_registrations.meta_index_Tensor(x, indices)
        return x.new_full(meta_result.shape, fill)

    # Clamp every index into the valid range of its dimension. The clamped
    # indices go into a new list so the caller's ``indices`` is left untouched.
    clamped_indices = [
        index if index is None else index.clamp(min=0, max=x.size(i) - 1)
        for i, index in enumerate(indices)
    ]

    return aten._unsafe_index(x, clamped_indices).masked_fill(~mask, fill)
@register_decomposition(aten.nll_loss_forward)
@out_wrapper("output", "total_weight")
def nll_loss_forward(
    self: Tensor,
    target: Tensor,
    weight: Optional[Tensor],
    reduction: int,
    ignore_index: int,
) -> tuple[Tensor, Tensor]:
    """Validate the arguments and delegate to ``_nll_loss_forward``.

    Checks that ``self`` is 1D or 2D, that ``target`` is 0D or 1D, that the
    leading sizes of ``self`` and ``target`` agree unless ``self`` is 1D with
    a 0D ``target``, and that ``weight``, when given, is 1D with one entry per
    class (the last dim of ``self``).

    Returns:
        The ``(output, total_weight)`` pair produced by ``_nll_loss_forward``.
    """
    input_rank = self.dim()
    target_rank = target.dim()

    assert 0 < input_rank <= 2, "input tensor should be 1D or 2D"
    assert target_rank <= 1, (
        "0D or 1D target tensor expected, multi-target not supported"
    )

    # Only the unbatched case (1D input, 0D target) skips the size comparison.
    if not (input_rank == 1 and target_rank == 0):
        assert self.shape[0] == target.shape[0], (
            f"size mismatch (got input: {self.shape}, target: {target.shape})"
        )

    n_classes = self.shape[-1]

    if weight is not None:
        assert weight.dim() == 1 and weight.numel() == n_classes, (
            f"weight tensor should be defined either for all {n_classes} classes or no classes "
            f"but got weight tensor of shape: {weight.shape}"
        )

    return _nll_loss_forward(self, target, weight, reduction, ignore_index)
tuple[Tensor, Tensor]: + return _nll_loss_forward(self, target, weight, reduction, ignore_index) + + +# These are adapted from aten/src/ATen/native/UpSample.h, which is based on +# https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm +def _upsample_cubic_convolution1(x: Tensor, A: float) -> Tensor: + return ((A + 2) * x - (A + 3)) * x * x + 1 + + +def _upsample_cubic_convolution2(x: Tensor, A: float) -> Tensor: + return ((A * x - 5 * A) * x + 8 * A) * x - 4 * A + + +def _upsample_get_cubic_coefficients(t: Tensor) -> TensorSequenceType: + A = -0.75 + + if t.device == torch.device("cpu"): + tt1 = torch.stack([t, 1.0 - t], dim=0) + tt2 = torch.stack([t + 1.0, 2.0 - t], dim=0) + w03 = _upsample_cubic_convolution2(tt2, A) + w12 = _upsample_cubic_convolution1(tt1, A) + w0, w3 = torch.unbind(w03, dim=0) + w1, w2 = torch.unbind(w12, dim=0) + return w0, w1, w2, w3 + else: + return ( + _upsample_cubic_convolution2(t + 1.0, A), + _upsample_cubic_convolution1(t, A), + _upsample_cubic_convolution1(1.0 - t, A), + _upsample_cubic_convolution2(2.0 - t, A), + ) + + +def _upsample_cubic_interp1d(coeffs: TensorSequenceType, ts: Tensor) -> Tensor: + coeffs2 = _upsample_get_cubic_coefficients(ts) + return _sum_tensors(c1 * c2 for (c1, c2) in zip(coeffs, coeffs2)) + + +# Need this instead of just sum() to keep mypy happy +def _sum_tensors(ts: Iterable[Tensor]) -> Tensor: + return reduce(torch.add, ts) + + +def _linspace_from_neg_one( + num_steps: int, align_corners: bool, dtype: torch.dtype, device: torch.device +): + if num_steps <= 1: + return torch.tensor(0, device=device, dtype=dtype) + + a = ((num_steps - 1) / num_steps) if not align_corners else 1 + return torch.linspace(-a, a, steps=num_steps, device=device, dtype=dtype) + + +def _make_base_grid_4d(theta: Tensor, h: int, w: int, align_corners: bool): + dtype = theta.dtype + device = theta.device + + # Using padding and summation generates a single kernel vs using torch.stack where 3 kernels generated 
def _make_base_grid_5d(theta: Tensor, d: int, h: int, w: int, align_corners: bool):
    """Build the base grid used by the 5d (volumetric) affine grid generator.

    Each coordinate axis comes from ``_linspace_from_neg_one`` (using the dtype
    and device of ``theta``), is viewed so it broadcasts over the other axes,
    and is zero-padded on the last dim into its own slot of a length-4 vector:
    x at position 0, y at 1, z at 2 and a constant one at 3. The padded
    tensors are then summed.
    """
    dtype, device = theta.dtype, theta.device
    pad = torch.nn.functional.pad

    grid_x = _linspace_from_neg_one(w, align_corners, dtype, device).view(1, 1, w, 1)
    grid_y = _linspace_from_neg_one(h, align_corners, dtype, device).view(1, h, 1, 1)
    grid_z = _linspace_from_neg_one(d, align_corners, dtype, device).view(d, 1, 1, 1)
    grid_one = torch.ones((1, 1, 1, 1), dtype=dtype, device=device)

    # this is just a temporary hack and we should use torch.stack here once #104480 is merged
    # ``pad=(left, right)`` on the last dim places each component in its slot.
    padded = [
        pad(grid_x, pad=(0, 3), mode="constant", value=0),
        pad(grid_y, pad=(1, 2), mode="constant", value=0),
        pad(grid_z, pad=(2, 1), mode="constant", value=0),
        pad(grid_one, pad=(3, 0), mode="constant", value=0),
    ]
    return padded[0] + padded[1] + padded[2] + padded[3]
theta.mT.unsqueeze(1)).sum(-2) + return grid.view(n, h, w, 2) + + +def _affine_grid_generator_5d(theta: Tensor, size: list[int], align_corners: bool): + n, _, d, h, w = size + base_grid = _make_base_grid_5d(theta, d, h, w, align_corners=align_corners) + # base_grid shape is (d, h, w, 4) and theta shape is (n, 3, 4) + # We do manually a matrix multiplication which is faster than mm() + # (d * h * w, 4, 1) * (n, 1, 4, 3) -> (n, h * w, 3) + grid = (base_grid.view(-1, 4, 1) * theta.mT.unsqueeze(1)).sum(-2) + return grid.view(n, d, h, w, 3) + + +@register_decomposition(aten.affine_grid_generator) +@out_wrapper() +@pw_cast_for_opmath +def affine_grid_generator(theta: Tensor, size: list[int], align_corners: bool): + torch._check( + len(size) in (4, 5), + lambda: "affine_grid_generator needs 4d (spatial) or 5d (volumetric) inputs.", + ) + if len(size) == 4: + return _affine_grid_generator_4d(theta, size, align_corners=align_corners) + else: + return _affine_grid_generator_5d(theta, size, align_corners=align_corners) + + +def _grid_sampler_2d( + a: Tensor, + grid: Tensor, + interpolation_mode: int = 0, + padding_mode: int = 0, + align_corners: bool = False, + _expand_grid: bool = True, +) -> Tensor: + # This method is a copy of grid_sampler_2d implementation and introduced with additional arg _expand_grid to + # optionally expand the input grid for performance reasons. + # Experimenting locally it was found that compiled CUDA code is accelerated by ~5x + # and CPU code by ~2x on bicubic mode, if we expand the grid from (N, H, W, 2) into (N, C, H, W, 2) + # However, this leads to a slowdown around ~0.8x on CPU bilinear mode, channels first. + # Thus we apply this hack to not expand the grid for this case. 
+ + torch._check( + interpolation_mode in (0, 1, 2), + lambda: f"Invalid interpolation mode {interpolation_mode}", + ) + torch._check( + padding_mode in (0, 1, 2), lambda: f"Invalid padding mode {padding_mode}" + ) + + def unnormalize(coords: Tensor, size: int) -> Tensor: + # Rescale coordinates from [-1, 1] to: + # [0, size - 1] if align_corners is True + # [-.5, size -.5] if align_corners is False + mul = (size * 0.5 - 0.5) if align_corners else (size * 0.5) + ofs = size * 0.5 - 0.5 + return coords * mul + ofs + + # Reflects coordinates until they fall between low and high (inclusive). + # The bounds are passed as twice their value so that half-integer values + # can be represented as ints. + def reflect_coordinates(coords: Tensor, twice_low: int, twice_high: int) -> Tensor: + if twice_low == twice_high: + return torch.zeros_like(coords) + coords_min = twice_low / 2 + coords_span = (twice_high - twice_low) / 2 + coords2 = (coords - coords_min).abs() + extra = torch.fmod(coords2, coords_span) + flips = (coords2 / coords_span).floor().to(dtype=torch.int8) + return torch.where( + flips & 1 == 0, extra + coords_min, coords_span + coords_min - extra + ) + + def compute_coordinates(coords: Tensor, size: int) -> Tensor: + if padding_mode == 0: # Zero + return coords + elif padding_mode == 1: # Borders + return torch.clamp(coords, 0, size - 1) + else: # padding_mode == 2, Reflection + if align_corners: + coords_reflected = reflect_coordinates(coords, 0, 2 * (size - 1)) + else: + coords_reflected = reflect_coordinates(coords, -1, 2 * size - 1) + return torch.clamp(coords_reflected, 0, size - 1) + + def compute_source_index(coords: Tensor, size: int) -> Tensor: + coords_un = unnormalize(coords, size) + return compute_coordinates(coords_un, size) + + N, C, iH, iW = a.shape + _, oH, oW, two = grid.shape + assert two == 2 + + if _expand_grid: + # Let's expand grid to [N, C, oH, oW, 2] + # This allows to generate a single triton cuda kernel instead of two kernels. 
+ # Two kernels are due source indices, weights have shape (N, 1, oH, oW), xnumel=N*oH*oW + # and output has shape (N, C, oH, oW), xnumel=N*C*oH*oW + # Expanding grid to (N, C, oH, oW, two) unifies xnumel to N*C*oH*oW + grid = grid.view(N, 1, oH, oW, two).expand(N, C, oH, oW, 2) + + def in_bounds_cond(xs: Tensor, ys: Tensor) -> Tensor: + return torch.logical_and( + 0 <= xs, torch.logical_and(xs < iW, torch.logical_and(0 <= ys, ys < iH)) + ) + + N_idx = torch.arange(N, device=a.device).view(N, 1, 1, 1) + C_idx = torch.arange(C, device=a.device).view(1, C, 1, 1) + + def clip(xs: Tensor, ys: Tensor, ws: Tensor) -> TensorSequenceType: + cond = in_bounds_cond(xs, ys) + # To clip to inside valid coordinates, we map the coordinates + # to (x, y) = (0, 0) and also set the weight to 0 + # We also change the shape of the tensor to the appropriate one for + # broadcasting with N_idx, C_idx for the purposes of advanced indexing + c = C if _expand_grid else 1 + return tuple( + torch.where(cond, t, 0).view(N, c, oH, oW) + for t in (xs.to(dtype=torch.int64), ys.to(dtype=torch.int64), ws) + ) + + def get_summand(ix: Tensor, iy: Tensor, w) -> Tensor: + # Perform clipping, index into input tensor and multiply by weight + idx_x, idx_y, w_ = clip(ix, iy, w) + return a[N_idx, C_idx, idx_y, idx_x] * w_ + + x = grid[..., 0] + y = grid[..., 1] + + if interpolation_mode == 0: # Bilinear + ix = compute_source_index(x, iW) + iy = compute_source_index(y, iH) + + ix_nw, iy_nw = ix.floor(), iy.floor() + ix_ne, iy_ne = ix_nw + 1, iy_nw + ix_sw, iy_sw = ix_nw, iy_nw + 1 + ix_se, iy_se = ix_ne, iy_sw + + w_nw = (ix_se - ix) * (iy_se - iy) + w_ne = (ix - ix_sw) * (iy_sw - iy) + w_sw = (ix_ne - ix) * (iy - iy_ne) + w_se = (ix - ix_nw) * (iy - iy_nw) + + return _sum_tensors( + get_summand(ix, iy, w) + for (ix, iy, w) in ( + (ix_nw, iy_nw, w_nw), + (ix_ne, iy_ne, w_ne), + (ix_sw, iy_sw, w_sw), + (ix_se, iy_se, w_se), + ) + ) + elif interpolation_mode == 1: # Nearest + ix = compute_source_index(x, iW) 
+ iy = compute_source_index(y, iH) + + ix_nearest = ix.round() + iy_nearest = iy.round() + + return get_summand(ix_nearest, iy_nearest, 1) + else: # interpolation_mode == 2, Bicubic + ix = unnormalize(x, iW) + iy = unnormalize(y, iH) + + ix_nw = ix.floor() + iy_nw = iy.floor() + + tx = ix - ix_nw + ty = iy - iy_nw + + if not _expand_grid: + tx = tx.unsqueeze(1) + ty = ty.unsqueeze(1) + + def get_value_bounded(ix: Tensor, iy: Tensor) -> Tensor: + x = compute_coordinates(ix, iW) + y = compute_coordinates(iy, iH) + return get_summand(x, y, 1) + + def get_coeff(ofs: int) -> Tensor: + iy_ofs = iy_nw + (ofs - 1) + cs = ( + get_value_bounded(ix_nw - 1, iy_ofs), + get_value_bounded(ix_nw, iy_ofs), + get_value_bounded(ix_nw + 1, iy_ofs), + get_value_bounded(ix_nw + 2, iy_ofs), + ) + return _upsample_cubic_interp1d(cs, tx) + + coeffs = tuple(get_coeff(ofs) for ofs in range(4)) + return _upsample_cubic_interp1d(coeffs, ty) + + +@register_decomposition(aten.grid_sampler_2d) +@out_wrapper() +@pw_cast_for_opmath +def grid_sampler_2d( + a: Tensor, + grid: Tensor, + interpolation_mode: int = 0, + padding_mode: int = 0, + align_corners: bool = False, +) -> Tensor: + return _grid_sampler_2d( + a, + grid=grid, + interpolation_mode=interpolation_mode, + padding_mode=padding_mode, + align_corners=align_corners, + ) + + +@register_decomposition(aten.mv) +@out_wrapper(exact_dtype=True) +@pw_cast_for_opmath +def mv(self, vec): + torch._check( + self.dim() == 2 and vec.dim() == 1, + lambda: f"matrix @ vector expected, got {self.dim()}, {vec.dim()}", + ) + torch._check( + self.size(1) == vec.size(0), + lambda: f"size mismatch, got input ({self.size(0)}x{self.size(1)}), vec ({vec.size(0)})", + ) + return (self * vec).sum(dim=1) + + +@register_decomposition(aten.binary_cross_entropy_with_logits) +@out_wrapper() +def binary_cross_entropy_with_logits( + self, target, weight=None, pos_weight=None, reduction=Reduction.MEAN.value +): + if pos_weight is not None: + log_weight = (pos_weight - 1) * 
target + 1 + loss = (1 - target) * self - (log_weight * F.logsigmoid(self)) + else: + loss = (1 - target) * self - F.logsigmoid(self) + + if weight is not None: + loss = loss * weight + + return apply_loss_reduction(loss, reduction) + + +def should_fold(tensor1: torch.Tensor, tensor2: torch.Tensor, is_out: bool) -> bool: + # For comments of the logic of this function see eager in /native/LinearAlgebra.cpp + + t1, t2 = (tensor1, tensor2) if tensor1.ndim >= tensor2.ndim else (tensor2, tensor1) + + from torch.fx.experimental.symbolic_shapes import guard_or_false + + if not (t1.ndim >= 3 and t2.ndim <= 2): + return False + if t2.requires_grad and not is_out: + return True + if tensor1.ndim == 2: + return False + if guard_or_false(t1.numel() == 0): + return True + + t1_shape = t1.shape + t1_stride = t1.stride() + + # Check the contiguous, we can skip the dim with size of 1 + # as aten: https://github.com/pytorch/pytorch/blob/e201460f8aa1510b4c4686627d57b69756c4b916/aten/src/ATen/TensorGeometry.cpp#L17 + expected_stride = [1] + for size in reversed(t1_shape[1:]): + expected_stride.append(size * expected_stride[-1]) + return all( + guard_or_false(size == 1) or guard_or_false(left == right) + for left, right, size in zip( + t1_stride, list(reversed(expected_stride)), t1_shape + ) + ) + + +@aten.matmul.default.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten.matmul.out.py_impl(DispatchKey.CompositeImplicitAutograd) +@out_wrapper(pass_is_out=True) +def matmul(tensor1, tensor2, *, is_out=False): + from torch.fx.experimental.symbolic_shapes import guard_or_false, guard_or_true + + dim_tensor1 = tensor1.dim() + dim_tensor2 = tensor2.dim() + assert dim_tensor1 != 0 and dim_tensor2 != 0 + if dim_tensor1 == 1 and dim_tensor2 == 1: + return torch.dot(tensor1, tensor2) + elif dim_tensor1 == 2 and dim_tensor2 == 1: + return torch.mv(tensor1, tensor2) + elif dim_tensor1 == 1 and dim_tensor2 == 2: + return torch.squeeze(torch.mm(torch.unsqueeze(tensor1, 0), tensor2), 0) + elif 
dim_tensor1 == 2 and dim_tensor2 == 2: + return torch.mm(tensor1, tensor2) + elif should_fold(tensor1, tensor2, is_out): + # dim_tensor1 >=3 && (dim_tensor2 == 1 || dim_tensor2 == 2) || + # dim_tensor2 >=3 && (dim_tensor1 == 1 || dim_tensor1 == 2) + # and some condition on the strides is fulfilled + + # optimization: use mm instead of bmm by folding the batch of the larger tensor + # into its leading matrix dimension + transpose = dim_tensor2 > dim_tensor1 + t1 = tensor2.mT if transpose else tensor1 + t2 = ( + tensor2 if not transpose else (tensor1.t() if dim_tensor1 == 2 else tensor1) + ) + # Invariant: t1.dim() >= 3 && (t2.dim() == 1 || t2.dim() == 2) + # and t1 and t2 are matmul-compatible + + # Why not t1.view(-1, sizes_1[-1])? + # If the last dim is 0, then view(-1, 0) won't work because the -1 becomes ambiguous. + # This can happen in e.g. [3, 5, 0] @ [0, 0]. + sizes_1 = t1.shape + output_shape = list(sizes_1[:-1]) + folded_dim1 = reduce(operator.mul, output_shape) + + # Readjust output_shape if we are multiplying by a matrix + t2_is_matrix = t2.dim() == 2 + if t2_is_matrix: + output_shape.append(t2.shape[1]) + + # This will almost always be a view. + # It may not be a view if t2->requires_grad(). See should_fold in aten/ for an explanation + t1_folded = t1.reshape(folded_dim1, sizes_1[-1]) + if t2_is_matrix: + # This copies if we perform a 2D @ 3D and the first tensor requires_grad + # See should_fold native/LinearAlgebra.cpp for why. 
+ output = torch.ops.aten._unsafe_view(t1_folded.mm(t2), output_shape) + return output.mT.contiguous() if transpose else output + else: + return torch.ops.aten._unsafe_view(t1_folded.mv(t2), output_shape) + + elif dim_tensor1 >= 1 and dim_tensor2 >= 1: + # We are multiplying b1 x n x m1 by x2 x m2 x p (where b1 can be a list); + # we track m1 vs m2 separately even though they must match for nicer error messages + n = tensor1.size(-2) if dim_tensor1 > 1 else 1 + m1 = tensor1.size(-1) + batch_tensor1 = tensor1.shape[:-2] + m2 = tensor2.size(-2) if dim_tensor2 > 1 else tensor2.size(-1) + p = tensor2.size(-1) if dim_tensor2 > 1 else 1 + + batch_tensor2: list[int] = [] + # TODO: handling of slice + for i in range(dim_tensor2 - 2): + batch_tensor2.append(tensor2.size(i)) + + # Same optimization for the gradients as that in should_fold + # If we're going to broadcast, we force it to go through the should_fold branch + if ( + dim_tensor1 == 3 + and dim_tensor2 == 3 + and guard_or_true(batch_tensor1[0] != batch_tensor2[0]) + ): + if guard_or_false(batch_tensor1[0] == 1) and tensor1.requires_grad: + return matmul(tensor1.squeeze(0), tensor2) + if guard_or_false(batch_tensor2[0] == 1) and tensor2.requires_grad: + return matmul(tensor1, tensor2.squeeze(0)) + + # expand the batch portion (i.e. 
cut off matrix dimensions and expand rest) + expand_batch_portion = list( + torch.broadcast_shapes(batch_tensor1, batch_tensor2) + ) + + tensor1_expand_size = expand_batch_portion + [n, m1] + + expand_batch_product = prod(expand_batch_portion) + + # HACK: We need reshape with symint support + tensor1_expanded = tensor1.expand(tensor1_expand_size).reshape( + expand_batch_product, n, m1 + ) + + vector_rhs = dim_tensor2 == 1 + if vector_rhs: + tensor2_expand_size = expand_batch_portion + [m2] + tensor2_expanded = ( + tensor2.expand(tensor2_expand_size) + .reshape(expand_batch_product, m2) + .unsqueeze(2) + ) + else: + tensor2_expand_size = expand_batch_portion + [m2, p] + tensor2_expanded = tensor2.expand(tensor2_expand_size).reshape( + expand_batch_product, m2, p + ) + + output_shape = expand_batch_portion + if dim_tensor1 > 1: + output_shape.append(n) + + if dim_tensor2 > 1: + output_shape.append(p) + + if vector_rhs: + return tensor1_expanded.bmm(tensor2_expanded).squeeze(-1).view(output_shape) + else: + return tensor1_expanded.bmm(tensor2_expanded).view(output_shape) + else: + torch._check(False, lambda: "both arguments to matmul need to be at least 1D") + + +@register_decomposition([aten.upsample_bicubic2d.default, aten.upsample_bicubic2d.out]) +@aten.upsample_bicubic2d.default.py_impl(DispatchKey.Autograd) +@out_wrapper() +@pw_cast_for_opmath +def upsample_bicubic2d_default( + input: Tensor, + output_size: tuple[int, int], + align_corners: bool, + scale_h: Optional[float] = None, + scale_w: Optional[float] = None, +) -> Tensor: + # get dimensions of original image + _, _, in_h, in_w = input.shape + + # Calculate horizontal and vertical scaling factor + h_scale_factor = _compute_scale(in_h, output_size[0], align_corners, scale_h) + w_scale_factor = _compute_scale(in_w, output_size[1], align_corners, scale_w) + + _, dtype = utils.elementwise_dtypes( + input, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ) + + # We have to create arange 
with int64 dtype and use .to in order to avoid + # additional kernels creation in inductor and get a perf slowdown + i = torch.arange(output_size[0], device=input.device).to(dtype=dtype) + j = torch.arange(output_size[1], device=input.device).to(dtype=dtype) + + x_float = _compute_source_index(w_scale_factor, j, align_corners) + y_float = _compute_source_index(h_scale_factor, i, align_corners) + y_float = y_float.unsqueeze(-1) + + x = x_float.floor() + y = y_float.floor() + + # We should also clamp xscale/yscale + # See guard_index_and_lambda in UpSample.h + yscale = (y_float - y).clamp(0.0, 1.0) + xscale = (x_float - x).clamp(0.0, 1.0) + x = x.to(torch.int64) + y = y.to(torch.int64) + + iys_ofs = (y - 1, y, y + 1, y + 2) + ixs_ofs = (x - 1, x, x + 1, x + 2) + + weights_x = _upsample_get_cubic_coefficients(xscale) + weights_y = _upsample_get_cubic_coefficients(yscale) + + weights_precision_x, weights_precision_y = None, None + if input.dtype == torch.uint8: + weights_precision_x = _compute_weight_precision(weights_x) + weights_precision_y = _compute_weight_precision(weights_y) + + weights_x = [ + (w * (1 << weights_precision_x) + torch.sign(w) * 0.5).to(torch.int16) + for w in weights_x + ] + weights_y = [ + (w * (1 << weights_precision_y) + torch.sign(w) * 0.5).to(torch.int16) + for w in weights_y + ] + + def load_bounded(ys, xs): + y_idx = torch.clamp(ys, 0, in_h - 1) + x_idx = torch.clamp(xs, 0, in_w - 1) + v = aten._unsafe_index(input, [None, None, y_idx, x_idx]) + return v + + def get_x_interp(y): + src_x = tuple(load_bounded(y, x_ofs) for x_ofs in ixs_ofs) + if input.dtype == torch.uint8: + assert weights_precision_x is not None + return _sum_tensors_uint8(src_x, weights_x, weights_precision_x) + return _sum_tensors(c1 * c2 for (c1, c2) in zip(src_x, weights_x)) + + src_y = tuple(get_x_interp(y_ofs) for y_ofs in iys_ofs) + if input.dtype == torch.uint8: + assert weights_precision_y is not None + result = _sum_tensors_uint8(src_y, weights_y, 
weights_precision_y) + else: + result = _sum_tensors(c1 * c2 for (c1, c2) in zip(src_y, weights_y)) + + # convert output to correct memory format, if necessary + memory_format = utils.suggest_memory_format(input) + result = result.contiguous(memory_format=memory_format) + return result + + +@register_decomposition(aten.upsample_bicubic2d.vec) +@aten.upsample_bicubic2d.vec.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten.upsample_bicubic2d.vec.py_impl(DispatchKey.Autograd) +@out_wrapper() +@pw_cast_for_opmath +def upsample_bicubic2d_vec( + a: Tensor, + output_size: Optional[tuple[int, int]], + align_corners: bool, + scale_factors: Optional[tuple[float, float]] = None, +) -> Tensor: + torch._check( + bool(output_size) + bool(scale_factors) == 1, + lambda: "Must specify exactly one of output_size and scale_factors.", + ) + if output_size is None: + assert scale_factors is not None + output_size = cast( + tuple[int, int], + tuple( + sym_int(sym_float(w) * scale) + for w, scale in zip(a.shape[2:], scale_factors) + ), + ) + scale_h, scale_w = scale_factors if scale_factors else (None, None) + return upsample_bicubic2d_default(a, output_size, align_corners, scale_h, scale_w) + + +@register_decomposition(aten.reflection_pad1d) +@register_decomposition(aten.reflection_pad2d) +@register_decomposition(aten.reflection_pad3d) +@pw_cast_for_opmath +@out_wrapper() +def _reflection_pad(a: Tensor, padding: tuple[int, ...]) -> Tensor: + def idx(left, middle, right): + dim_idx = torch.arange(-left, middle + right, device=a.device) + return middle - 1 - (middle - 1 - dim_idx.abs()).abs() + + return _reflection_or_replication_pad( + a, + padding, + idx, + ) + + +@register_decomposition(aten.replication_pad1d) +@register_decomposition(aten.replication_pad2d) +@register_decomposition(aten.replication_pad3d) +@pw_cast_for_opmath +@out_wrapper() +def _replication_pad(a: Tensor, padding: tuple[int, ...]) -> Tensor: + def idx(left, middle, right): + dim_idx = torch.arange(-left, middle 
+ right, device=a.device) + return torch.clamp(dim_idx, 0, middle - 1) + + return _reflection_or_replication_pad( + a, + padding, + idx, + ) + + +def _reflection_or_replication_pad( + a: Tensor, + padding: tuple[int, ...], + idx_fn: Callable[[int, int, int], Tensor], +) -> Tensor: + dim = len(padding) // 2 + torch._check( + a.dim() in (dim + 1, dim + 2), + lambda: f"reflection_pad{dim}d requires {dim + 1}D or {dim + 2}D input", + ) + inp_shape = a.shape[-dim:] + nc_dim = a.dim() - dim + + padding_left = [padding[2 * (dim - 1 - i)] for i in range(dim)] + padding_right = [padding[2 * (dim - 1 - i) + 1] for i in range(dim)] + + result = a + for i in range(dim): + idx: list[Any] = [None] * result.dim() + idx[i + nc_dim] = idx_fn(padding_left[i], inp_shape[i], padding_right[i]) + result = aten._unsafe_index(result, idx) + + # convert output to correct memory format, if necessary + memory_format = utils.suggest_memory_format(result) + result = result.contiguous(memory_format=memory_format) + return result + + +@register_decomposition(aten.reflection_pad1d_backward) +@register_decomposition(aten.reflection_pad2d_backward) +@register_decomposition(aten.reflection_pad3d_backward) +@out_wrapper("grad_input") +def _reflection_pad_backward(grad_output, x, padding): + dim = len(padding) // 2 + + dhw = [h - 1 for h in x.shape[-dim:]] + + padding_left = [padding[2 * (dim - 1 - i)] for i in range(dim)] + padding_right = [padding[2 * (dim - 1 - i) + 1] for i in range(dim)] + + indices = [] + for i in range(x.ndim): + view_shape = [1] * x.ndim + view_shape[i] = -1 + indices.append(torch.arange(x.shape[i], device=x.device).view(view_shape)) + + b = indices[:-dim] + xyz = indices[-dim:] + + def index_range_condition(index_range): + i, lb, ub = index_range + return torch.logical_and(i >= lb, i <= ub) + + # Areas after reflection: + # + # top-left | top | top-right + # ----------------------------------------- + # left | center | right + # ----------------------------------------- + # 
bottom-left | bottom | bottom-right + # + # The center area is the original matrix. Other areas are reflections. + + center = [xyz[i] + padding_left[i] for i in range(dim)] + left_reflect = [padding_left[i] - xyz[i] for i in range(dim)] + right_reflect = [2 * dhw[i] + padding_left[i] - xyz[i] for i in range(dim)] + + # Accumulate gradients from different areas + # If some of the padding is negative, center load is not always valid + range_c = [ + (center[i], 0, dhw[i] + padding_left[i] + padding_right[i]) for i in range(dim) + ] + cond = functools.reduce( + aten.logical_and, [index_range_condition(range_c[i]) for i in range(dim)] + ) + grad = aten._unsafe_masked_index(grad_output, cond, b + center, 0.0) + + def accumulate(grad, out, index_ranges): + # If the upper bound is less than the lower bound, we can get rid of one accumulation. + # This happens when the padding size is zero. + for i in range(dim): + upper_less_than_lower = index_ranges[i][2] < index_ranges[i][1] + if isinstance(upper_less_than_lower, bool) and upper_less_than_lower: + return grad + + cond = functools.reduce( + aten.logical_and, + [index_range_condition(index_range) for index_range in index_ranges], + ) + g = aten._unsafe_masked_index(grad_output, cond, b + out, 0.0) + return grad + g + + for area in itertools.product(*[[-1, 0, 1] for _ in range(dim)]): + if area == tuple([0] * dim): + # center, this is already done. 
+ continue + + outs = [] + index_ranges = [] + + for i in range(dim): + if area[i] == 0: + out = center[i] + index_range = range_c[i] + elif area[i] == -1: + out = left_reflect[i] + index_range = (xyz[i], 1, padding_left[i]) + elif area[i] == 1: + out = right_reflect[i] + index_range = (xyz[i], dhw[i] - padding_right[i], dhw[i] - 1) + + outs.append(out) # type: ignore[possibly-undefined] + index_ranges.append(index_range) # type: ignore[possibly-undefined] + + grad = accumulate(grad, outs, index_ranges) + + return grad + + +@register_decomposition(aten.aminmax) +@out_wrapper("min", "max") +def aminmax(self, *, dim=None, keepdim=False): + # pyrefly: ignore [bad-argument-type] + amin = torch.amin(self, dim=dim, keepdim=keepdim) + # pyrefly: ignore [bad-argument-type] + amax = torch.amax(self, dim=dim, keepdim=keepdim) + return amin, amax + + +@register_decomposition(aten.nansum) +@out_wrapper() +def nansum(self, dim=None, keepdim=False, *, dtype=None): + return aten.sum(torch.where(torch.isnan(self), 0, self), dim, keepdim, dtype=dtype) + + +@register_decomposition([aten.arange.default, aten.arange.out]) +@out_wrapper() +def arange_default( + end: NumberType, + *, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + device: Optional[torch.device] = None, + pin_memory: bool = False, +): + return aten.arange.start_step( + 0, end, 1, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory + ) + + +@register_decomposition([aten.arange.start]) +def arange_start( + start: NumberType, + end: NumberType, + *, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + device: Optional[torch.device] = None, + pin_memory: bool = False, +): + return aten.arange.start_step( + start, end, 1, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory + ) + + +@register_decomposition(out_dtype) +def out_dtype_decomp(*args, **kwargs): + from torch._higher_order_ops.out_dtype import out_dtype_dense + + return 
out_dtype_dense(*args, **kwargs) + + +@register_decomposition(aten.multi_margin_loss) +@aten.multi_margin_loss.default.py_impl(DispatchKey.Autograd) +@out_wrapper() +def multi_margin_loss( + input: Tensor, + target: Tensor, + p: NumberType = 1, + margin: NumberType = 1, + weight: Optional[Tensor] = None, + reduction: int = Reduction.MEAN.value, +) -> Tensor: + input = torch.atleast_2d(input) + target = torch.atleast_1d(target) + nframe = input.shape[0] + dim = input.shape[1] + torch._check(p == 1 or p == 2, lambda: "only p == 1 and p == 2 supported") + torch._check( + input.ndim == 2 and dim != 0, + lambda: f"Expected non-empty vector or matrix with optional 0-dim batch size, but got: {input.shape}", + ) + torch._check( + target.ndim == 1 and target.numel() == nframe, + lambda: f"inconsistent target size, expected {nframe} but got {target.shape}", + ) + if weight is not None: + weight = torch.atleast_1d(weight) + torch._check( + weight.ndim == 1 and weight.numel() == dim, # type: ignore[union-attr] + lambda: f"inconsistent weight size, expected {dim} but got {weight.shape}", # type: ignore[union-attr] + ) + target = target.unsqueeze(1) + u = torch.gather(input, dim=1, index=target) + z = margin - u + input + z = z.clamp_min(0) + z = z if p == 1 else z * z + if weight is not None: + z = z * weight[target] + idx = torch.arange(dim, device=input.device) + z = torch.where(idx != target, z, 0) + if reduction == Reduction.MEAN.value: + return z.mean() + elif reduction == Reduction.SUM.value: + return z.sum() / z.shape[1] + else: + return z.mean(dim=1) + + +@register_decomposition(aten.multilabel_margin_loss_forward) +@aten.multilabel_margin_loss_forward.default.py_impl(DispatchKey.Autograd) +@out_wrapper("output", "is_target") +def multilabel_margin_loss_forward( + input: Tensor, + target: Tensor, + reduction: int, +) -> tuple[Tensor, Tensor]: + orig_input_shape = input.shape + orig_target_shape = target.shape + input = torch.atleast_2d(input) + target = 
torch.atleast_2d(target) + dim = input.shape[1] + torch._check( + len(orig_input_shape) <= 2 and dim != 0, + lambda: f"Expected non-empty vector or matrix with optional 0-dim batch size, but got: {orig_input_shape}", + ) + torch._check( + len(orig_target_shape) <= 2 and orig_target_shape == orig_input_shape, + lambda: f"inconsistent target size: {orig_target_shape} for input of size: {orig_input_shape}", + ) + # ignores labels after the first -1, detects when -1 is not present + idx = torch.arange(dim, device=target.device) + is_end = target == -1 + end_idx = torch.amin(torch.where(is_end, idx, dim), dim=-1, keepdim=True) + # target indices + target_mask = idx < end_idx + # masks target to be able to use gather, which doesn't allow -1 + tidx0 = torch.where(target_mask, target, 0) + u = torch.gather(input, dim=-1, index=tidx0) + # is_target + tidx1 = torch.where(target_mask, target, -1) + is_target = torch.any(idx == tidx1.unsqueeze(dim=-1), dim=1) + # loss + z = 1.0 - u.T.unsqueeze(dim=-1) + input + z = z.clamp_min(0) + z = z / dim + # masks loss + z = torch.where(is_target, 0, z) + # reduction + if reduction == Reduction.MEAN.value: + z = z.sum(dim=(0, -1)).mean() + elif reduction == Reduction.SUM.value: + z = z.sum() + else: + z = z.sum(dim=(0, -1)) + # result + is_target = is_target.to(input.dtype).reshape(orig_target_shape) + return z, is_target + + +# scaled_dot_product_attention used to be decomposed in pre-autograd, given that +# it calls _scaled_dot_product_attention_math and +# _scaled_dot_product_attention_math only has a CompositeImplicitAutograd +# kernel. As a result it's decomposed into ops with finer granularity. +# However recent PRs (#103826 #105131 #115913) added new logic in +# scaled_dot_product_attention and now it calls +# _scaled_dot_product_flash_attention_for_cpu in export path. This results +# in _scaled_dot_product_flash_attention_for_cpu showing up in export result. 
+# This decomposition ensures scaled_dot_product_attention is still decomposed +# the same way as before, i.e., going through +# _scaled_dot_product_attention_math. Notice that this decomp rule should be +# excluded by inductor. +@register_decomposition(aten._scaled_dot_product_flash_attention_for_cpu.default) +def scaled_dot_product_flash_attention_for_cpu( + query: Tensor, + key: Tensor, + value: Tensor, + dropout_p: float = 0.0, + is_causal: bool = False, + *, + attn_mask: Optional[Tensor] = None, + scale: Optional[float] = None, +) -> tuple[Tensor, Tensor]: + torch._check( + torch.is_floating_point(query), + lambda: f"query must be FP32, FP64, BF16, FP16 but got {query.dtype}", + ) + torch._check( + query.dim() == 4 and key.dim() == 4 and value.dim() == 4, + lambda: f"q, k, v must be a 4 dimensional tensor, got {query.dim()}, {key.dim()}, {value.dim()}", + ) + torch._check( + dropout_p == 0.0, lambda: f"dropout probability must be zero, got {dropout_p}" + ) + torch._check( + query.shape[3] == value.shape[3] and key.shape[3] == value.shape[3], + lambda: "q, k, v should have the same head size", + ) + + output, attn = aten._scaled_dot_product_attention_math.default( + query, + key, + value, + attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=is_causal, + dropout_mask=None, + scale=scale, + enable_gqa=query.size(1) != key.size(1), + ) + # Why this change? + # In pre-dispatch export scaled_dot_product_attention is executed via + # * flash_attention. 
+ # flash_attention allocates output tensor as (N, H, L, E) (see PR #134656) + # assume x: [N, H, L, E] is the output sdpa + # In MHA code, this output is then permuted via (2, 0, 1, 3) to get + # (L, N, H, E) dim tensor + # x = x.permute(2, 0, 1, 3).contiguous() and the viewed via + # x = x.view(L * N, H * E) + # During pre autograd dispatch call to contiguous is not traced because + # flash_attention output after the x.permute is already contiguous + # on which the view is valid + # However, during 2nd stage export, post-dispatch, we run _match variant + # instead of flash* to get the decomposition. _match variant returns + # x: [N, H, L, E] applying x.permute(2, 0, 1, 3) returns + # x: [L, N, H, E] and without converting this to contiguous tensor + # subsequent view is not valid and the export fails + # solution is to maintain the return tensor view from the decomp to be + # exactly same as *flash* variant. + + # Really the invariant you want to maintain is: + # pre-dispatch op-output and its decomposed representation must + # return tensor with same view and dims + output = ( + output.permute(2, 0, 1, 3) + .contiguous(memory_format=torch.contiguous_format) + .permute(1, 2, 0, 3) + ) + return output, attn + + +def register_inplace(aten_op, outplace_op): + @register_decomposition(aten_op) + def inplace_op(*args, **kwargs): + out = outplace_op(*args, **kwargs) + return args[0].copy_(out) + + return inplace_op + + +@register_decomposition([aten.baddbmm]) +@out_wrapper(exact_dtype=True) +@pw_cast_for_opmath +def baddbmm(self, batch1, batch2, beta=1, alpha=1): + if not self.is_floating_point() and not self.is_complex(): + beta = int(beta) + alpha = int(alpha) + result = torch.bmm(batch1, batch2) + if not isinstance(alpha, numbers.Number) or alpha != 1: + # pyrefly: ignore [unsupported-operation] + result = result * alpha + if beta == 0: + return result + if not isinstance(beta, numbers.Number) or beta != 1: + self = self * beta + return self + result + + 
+@register_decomposition(aten.floor_divide) +@out_wrapper() +def floor_divide(self, other): + return torch.div(self, other, rounding_mode="floor") + + +@register_decomposition(aten.sym_numel) +def sym_numel(t): + return functools.reduce(operator.mul, t.shape, 1) + + +@register_decomposition([aten.sum.default, aten.sum.out]) +def sum_default( + self: Tensor, + *, + dtype: Optional[torch.dtype] = None, + out: Optional[Tensor] = None, +) -> Tensor: + if out is None: + return aten.sum.dim_IntList(self, [], dtype=dtype) + else: + return aten.sum.IntList_out(self, [], dtype=dtype, out=out) + + +@register_decomposition([aten.squeeze.default, aten.squeeze.dim]) +def squeeze_default(self: Tensor, dim: Optional[int] = None): + # handle a scalar directly + if not isinstance(self, torch.Tensor): + return self + # perform squeeze + if dim is None: + return aten.squeeze.dims(self, list(range(self.dim()))) + else: + return aten.squeeze.dims(self, [dim]) + + +@register_decomposition(torch.ops.aten._weight_norm_interface) +def _weight_norm_interface(v, g, dim=0): + # https://github.com/pytorch/pytorch/blob/852f8526c52190125446adc9a6ecbcc28fb66182/aten/src/ATen/native/WeightNorm.cpp#L58 + keep_dim = tuple(i for i in range(len(v.shape)) if i != dim) + # align with cuda behavior, keep norm in 'float' when g is 'bfloat16' + norm_dtype = torch.float if g.dtype == torch.bfloat16 else None + norm = v.norm(2, keep_dim, keepdim=True, dtype=norm_dtype) + return v * (g / norm.to(g.dtype)), norm + + +@register_decomposition(aten.isin) +@out_wrapper() +def isin(elements, test_elements, *, assume_unique=False, invert=False): + # handle when either elements or test_elements are Scalars (they can't both be) + if not isinstance(elements, torch.Tensor): + elements = torch.tensor(elements, device=test_elements.device) + if not isinstance(test_elements, torch.Tensor): + if invert: + return torch.ne(elements, test_elements) + else: + return torch.eq(elements, test_elements) + + if test_elements.numel() 
< 10.0 * pow(elements.numel(), 0.145): + return isin_default(elements, test_elements, invert=invert) + else: + return isin_sorting( + elements, test_elements, assume_unique=assume_unique, invert=invert + ) + + +@register_decomposition(aten.bernoulli.default) +def bernoulli( + self: torch.Tensor, + *, + generator: Optional[torch.Generator] = None, +) -> torch.Tensor: + if generator is None: + raw_p = torch.rand(self.size(), dtype=torch.float32, device=self.device) + else: + raw_p = torch.rand( + self.size(), + generator=generator, + dtype=torch.float32, + device=self.device, + ) + p = (raw_p < self).to(self.dtype) + return p + + +def isin_default(elements, test_elements, *, invert=False): + if elements.numel() == 0: + return torch.empty_like(elements, dtype=torch.bool) + expanded_elem_shape = elements.shape + (1,) * test_elements.ndim + x = elements.view(expanded_elem_shape) + dim = tuple(range(-1, -test_elements.ndim - 1, -1)) + res = (x == test_elements).any(dim=dim) + return ~res if invert else res + + +def isin_sorting(elements, test_elements, *, assume_unique=False, invert=False): + elements_flat = elements.flatten() + test_elements_flat = test_elements.flatten() + if assume_unique: + # This is the same as the aten implementation. For + # assume_unique=False, we cannot use unique() here, so we use a + # version with searchsorted instead. 
+ all_elements = torch.cat([elements_flat, test_elements_flat]) + sorted_elements, sorted_order = torch.sort(all_elements, stable=True) + + duplicate_mask = sorted_elements[1:] == sorted_elements[:-1] + duplicate_mask = torch.constant_pad_nd(duplicate_mask, [0, 1], False) + + if invert: + duplicate_mask = duplicate_mask.logical_not() + + mask = torch.empty_like(duplicate_mask) + mask = mask.index_copy(0, sorted_order, duplicate_mask) + + return mask[0 : elements.numel()] + else: + sorted_test_elements, _ = torch.sort(test_elements_flat) + idx = torch.searchsorted(sorted_test_elements, elements_flat) + test_idx = torch.where(idx < sorted_test_elements.numel(), idx, 0) + cmp = sorted_test_elements[test_idx] == elements_flat + cmp = cmp.logical_not() if invert else cmp + return cmp.reshape(elements.shape) + + +@register_decomposition(aten.take) +@out_wrapper() +def take(self, index): + flattened = self.reshape(-1) + return flattened[index] + + +@register_decomposition(aten.resize_as) +def resize_as(self, other, memory_format=None): + if memory_format is None: + memory_format = torch.contiguous_format + if memory_format == torch.preserve_format: + memory_format = suggest_memory_format(other) + return aten.resize(self, other.shape, memory_format=memory_format) + + +register_inplace(aten.addbmm_, aten.addbmm) +register_inplace(aten.addmm_, aten.addmm) +register_inplace(aten.addmv_, aten.addmv) +register_inplace(aten.baddbmm_, aten.baddbmm) +register_inplace(aten.fill_, aten.fill) +register_inplace(aten.gelu_, aten.gelu) +register_inplace(aten.hardswish_, aten.hardswish) +register_inplace(aten.hardtanh_, aten.hardtanh) +register_inplace(aten.hardsigmoid_, aten.hardsigmoid) +register_inplace(aten.__iand__, aten.__and__) +register_inplace(aten.__ilshift__, aten.__lshift__) +register_inplace(aten.index_put_, aten.index_put) +register_inplace(aten.index_reduce_, aten.index_reduce) +register_inplace(aten.__ior__, aten.__or__) +register_inplace(aten.__irshift__, 
aten.__rshift__) +register_inplace(aten.__ixor__, aten.__xor__) +register_inplace(aten.leaky_relu_, aten.leaky_relu) +register_inplace(aten.logit_, aten.logit) +register_inplace(aten.relu_, aten.relu) +register_inplace(aten.renorm_, aten.renorm) +register_inplace(aten.round_, aten.round) +register_inplace(aten.scatter_, aten.scatter) +register_inplace(aten.scatter_add_, aten.scatter_add) +register_inplace(aten.scatter_reduce_, aten.scatter_reduce) +register_inplace(aten.silu_, aten.silu) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_decomp/decompositions_for_jvp.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_decomp/decompositions_for_jvp.py new file mode 100644 index 0000000000000000000000000000000000000000..dd3b7e7d8899266501ad57381f190c47a2082739 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_decomp/decompositions_for_jvp.py @@ -0,0 +1,336 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +import inspect +from collections.abc import Callable +from typing import Optional + +import torch +import torch._decomp +from torch import Tensor +from torch._prims_common.wrappers import _maybe_remove_out_wrapper + + +decomposition_table = torch._decomp.decomposition_table +decomposition_table_for_jvp: dict[torch._ops.OperatorBase, Callable] = {} +register_decomposition = torch._decomp.register_decomposition +aten = torch.ops.aten + +# NOTE: [forward-mode AD decompositions mechanism] +# +# The mechanism is in VariableType, +# IF any inputs have forward grad +# AND there is no forward AD formula implemented +# AND the functions are actually differentiable +# run the decomposition +# See run_jit_decomposition_with_args_for_jvp +# We currently use python decompositions that we torchscript. +# +# Note that we would be building the backward graph at the decomposed level +# too, but that is OK, because we would've errored out otherwise anyway. 
+# +# TODO: The mechanism we are using to register decompositions doesn't +# seem to be exclusively used for jvp. So open question here is whether +# torch/csrc/jit/runtime/decomposition_registry.cpp is being used for other things. +# If that is the case, we may go down the decomposition path unexpectedly +# (and possibly produce an unintelligible error) vs erroring out earlier and +# printing that the forward AD formula is not implemented. +# +# The solution to this may be to have an explicitly white list control when +# to enable the decomposition. + + +def maybe_register_decomposition(op): + def decorator(f): + try: + return register_decomposition(op)(f) + except Exception: + return f + + return decorator + + +# Functions where we need a special decomposition for jvp but there's another version that +# should be used more generally (ex. for jvp we need to recompute the mean and variance for +# the backwards of a normalization function. Without jvp, it should use the saved value) +decomposition_table_for_jvp = {} + + +def register_decomposition_for_jvp(fn): + return register_decomposition(fn, registry=decomposition_table_for_jvp) + + +def _register_jit_decomposition_for_jvp(decomp, use_python=False): + if decomp in decomposition_table_for_jvp: + decomposition_table_used = decomposition_table_for_jvp + elif decomp in decomposition_table: + decomposition_table_used = decomposition_table + else: + raise RuntimeError(f"could not find decomposition for {decomp}") + decomp_fn = decomposition_table_used[decomp] + + # `out_wrapper` extends a decompositions signature with + # an `out` parameter. 
However jit will use the unwrapped function's + # signature instead so we need to unwrap here to prevent an error + decomp_fn = _maybe_remove_out_wrapper(decomp_fn) + + if use_python: + decomp_fn = torch.jit.ignore(decomp_fn) + sig = inspect.signature(decomp_fn) + + # Create a string wrapping the function from the signature + # example output: + # def wrapped_decomp(x: torch.Tensor, y: int, z: int): + # return decomp_fn(x, y, z) + # Thanks copilot! + def get_function_def(sig): + param_def = [f"{param_str}" for param_str in sig.parameters.values()] + param_use = [f"{param_str}" for param_str in sig.parameters] + + return f"def wrapped_decomp({', '.join(param_def)}):\n return decomp_fn({', '.join(param_use)})\n" + + f_str = get_function_def(sig) + graph = torch.jit.CompilationUnit(f_str).wrapped_decomp.graph + else: + graph = torch.jit.script(decomp_fn).graph + torch.jit._register_decomposition(decomp, graph) + + +# The only decompositions here are temporary or hacks for the purposes of jvp + + +# TODO: do these also belong here? +@maybe_register_decomposition(aten.trace.default) +def trace(self: Tensor) -> Tensor: + return torch.sum(torch.diag(self)) + + +@maybe_register_decomposition(aten.log_sigmoid_forward.default) +def log_sigmoid_forward(self: Tensor) -> tuple[Tensor, Tensor]: + min = torch.minimum(self.new_zeros(()), self) + z = torch.exp(-torch.abs(self)) + if self.is_cuda or self.is_xpu: + buffer = self.new_zeros((0,)) + else: + buffer = z + return min - torch.log1p(z), buffer + + +def recompute_mean_var( + input: Tensor, rstd: Tensor, inner_dim_indices: list[int], keepdim: bool +): + # for most norm decompositions, it will be the same as the core version except for here. 
+ # We recompute the mean and variance so that they track gradients through input + + mean = torch.mean(input, dim=inner_dim_indices, keepdim=keepdim) + var = torch.var(input, dim=inner_dim_indices, unbiased=False, keepdim=keepdim) + eps = torch.pow(1 / rstd, 2) - var # this makes me so sad inside + eps = eps.detach() + rstd = 1 / torch.sqrt(var + eps) + return mean, rstd + + +@register_decomposition_for_jvp(aten.native_layer_norm_backward) +def native_layer_norm_backward( + grad_out: Tensor, + input: Tensor, + normalized_shape: list[int], + mean: Tensor, + rstd: Tensor, + weight: Optional[Tensor], + bias: Optional[Tensor], + output_mask: list[bool], +) -> tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]: + input_shape = input.shape + input_ndim = input.dim() + + axis = input_ndim - len(normalized_shape) + inner_dims = input_shape[axis:] + outer_dims = input_shape[:axis] + inner_dim_indices = list(range(axis, input_ndim)) + outer_dim_indices = list(range(axis)) + + N = 1 + for i in inner_dims: + N *= i + M = 1 + for i in outer_dims: + M *= i + if M <= 0 or N <= 0: + return ( + input.new_zeros(input_shape), + input.new_zeros(input_shape[axis:]), + input.new_zeros(input_shape[axis:]), + ) + + mean_, rstd_ = recompute_mean_var(input, rstd, inner_dim_indices, keepdim=True) + + x_hat = (input - mean_) * rstd_ + if weight is not None: + grad_x_hat = grad_out * weight + else: + grad_x_hat = grad_out + a = grad_x_hat * N + b = torch.sum(grad_x_hat, inner_dim_indices, True) + c1 = torch.mul(grad_x_hat, x_hat) + c2 = torch.sum(c1, inner_dim_indices, True) + c3 = torch.mul(x_hat, c2) + inner = a - b - c3 + + if output_mask[0]: + d_input: Optional[Tensor] = (rstd_ / N) * inner + else: + d_input = torch.zeros_like(input) # should be None but doesn't work with vjp + + if output_mask[1] and weight is not None: + if len(outer_dim_indices) > 0: + d_weight: Optional[Tensor] = torch.sum( + grad_out * x_hat, outer_dim_indices, False + ) + else: + d_weight = grad_out * x_hat 
+ elif weight is not None: + d_weight = torch.zeros_like(weight) # should be None but doesn't work with vjp + else: + d_weight = torch.zeros(()) # should be None but doesn't work with vjp + + if output_mask[2] and bias is not None: + if len(outer_dim_indices) > 0: + d_bias: Optional[Tensor] = torch.sum(grad_out, outer_dim_indices, False) + else: + d_bias = grad_out.clone() + elif bias is not None: + d_bias = torch.zeros_like(bias) # should be None but doesn't work with vjp + else: + d_bias = torch.zeros(()) # should be None but doesn't work with vjp + + return (d_input, d_weight, d_bias) + + +def prod(x: list[int]): + r = 1 + for i in x: + r *= i + return r + + +@register_decomposition_for_jvp(aten.native_batch_norm_backward) +def native_batch_norm_backward( + grad_out: Tensor, + input: Tensor, + weight: Optional[Tensor], + running_mean: Optional[Tensor], + running_var: Optional[Tensor], + save_mean: Optional[Tensor], + save_invstd: Optional[Tensor], + train: bool, + eps: float, + output_mask: list[bool], +) -> tuple[Tensor, Optional[Tensor], Optional[Tensor]]: + input_shape = input.shape + input_rank = input.dim() + assert input_rank >= 2, "rank of the input must be at least 2" + + axis = 1 + num_features = prod(input_shape) / input_shape[axis] # type: ignore[arg-type] + mean = save_mean + invstd = save_invstd + if train: + assert save_mean is not None and save_invstd is not None, ( + "when train=True, save_mean and save_invstd are required" + ) + + reduciton_dims = [0] + list(range(2, input.dim())) + assert invstd is not None # for typing + mean, invstd = recompute_mean_var(input, invstd, reduciton_dims, keepdim=False) + else: + assert running_mean is not None and running_var is not None + mean = running_mean + invstd = torch.rsqrt(running_var + eps) + + assert invstd is not None and mean is not None + + broadcast_mask = [1] * input_rank + broadcast_mask[axis] = input_shape[axis] + + reduction_axes: list[int] = [] + for i in range(input_rank): + if i != axis: + 
reduction_axes.append(i) + + mean = torch.reshape(mean, broadcast_mask) + norm = 1.0 / num_features + grad_output_sum = torch.sum(grad_out, reduction_axes) + dot_p = torch.sum(grad_out * (input - mean), reduction_axes) + + grad_mean = torch.reshape(grad_output_sum * norm, broadcast_mask) + proj_scale = torch.reshape(torch.mul(dot_p * norm, invstd * invstd), broadcast_mask) + + if weight is None: + grad_scale = torch.reshape(invstd, broadcast_mask) * 1.0 + else: + grad_scale = torch.reshape(invstd * weight, broadcast_mask) + + if train: + proj = (input - mean) * proj_scale + grad_input = ((grad_out - proj) - grad_mean) * grad_scale + else: + grad_input = grad_out * grad_scale + + if output_mask[1]: + grad_weight = dot_p * invstd + elif weight is not None: + grad_weight = torch.zeros_like( + weight + ) # should be None but doesn't work with vjp + else: + grad_weight = torch.zeros(()) # should be None but doesn't work with vjp + + if output_mask[2]: + grad_bias = grad_output_sum + else: + grad_bias = torch.zeros_like( + grad_output_sum + ) # should be None but doesn't work with vjp + + return (grad_input, grad_weight, grad_bias) + + +@register_decomposition_for_jvp(aten.batch_norm_backward) +def batch_norm_backward( + grad_out: Tensor, + input: Tensor, + weight: Tensor, + running_mean: Optional[Tensor], + running_var: Optional[Tensor], + save_mean: Optional[Tensor], + save_var: Optional[Tensor], + update: bool, + eps: float, + output_mask: list[bool], + reserve: Tensor, +) -> tuple[Tensor, Optional[Tensor], Optional[Tensor]]: + return native_batch_norm_backward( + grad_out, + input, + weight, + running_mean, + running_var, + save_mean, + save_var, + update, + eps, + output_mask, + ) + + +_register_jit_decomposition_for_jvp(torch.ops.aten.trace.default, use_python=True) +_register_jit_decomposition_for_jvp(torch.ops.aten.nll_loss_backward.default) +_register_jit_decomposition_for_jvp(torch.ops.aten.nll_loss2d_backward.default) 
+_register_jit_decomposition_for_jvp(torch.ops.aten._log_softmax_backward_data.default) +_register_jit_decomposition_for_jvp(torch.ops.aten._softmax_backward_data.default) +_register_jit_decomposition_for_jvp(torch.ops.aten.log_sigmoid_forward.default) +_register_jit_decomposition_for_jvp(torch.ops.aten.native_layer_norm_backward.default) +_register_jit_decomposition_for_jvp(torch.ops.aten.native_batch_norm_backward.default) +_register_jit_decomposition_for_jvp(torch.ops.aten.cudnn_batch_norm_backward.default) +_register_jit_decomposition_for_jvp(torch.ops.aten.batch_norm_backward.default) +_register_jit_decomposition_for_jvp(torch.ops.aten.miopen_batch_norm_backward.default) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_decomp/decompositions_for_rng.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_decomp/decompositions_for_rng.py new file mode 100644 index 0000000000000000000000000000000000000000..455ef0cc994388a60785cf715c6ec529a0c0fec5 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_decomp/decompositions_for_rng.py @@ -0,0 +1,266 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +import functools +from collections import defaultdict +from collections.abc import Callable + +import torch +import torch._decomp as decomp +from torch._decomp import get_decompositions +from torch._ops import OpOverload + + +aten = torch.ops.aten + +rng_decompositions: dict[str, dict[OpOverload, Callable]] = defaultdict(dict) + + +def register_rng_decomposition(aten_op): + return decomp.register_decomposition(aten_op, rng_decompositions) + + +def throw_on_non_cuda(device): + raise RuntimeError( + f"You are trying to functionalize a {device.type} RNG operator but {device.type} does not " + f"use Philox/counter-based RNG. Therefore, functionalizing a {device.type} RNG operator is " + "not supported. We are discussing the possibility of a Philox-based RNG implementation for CPU." 
+ ) + + +# TODO - We have to register many more distributions here, and also higher level +# ops like dropout which have fused implementation and can hide the rand inside. +@register_rng_decomposition(aten.rand) +def rand(shape, dtype=None, layout=torch.strided, device=None, pin_memory=False): + if device and device.type != "cuda": + throw_on_non_cuda(device) + seed, offset = PhiloxStateTracker.get_state_as_tuple() + dtype = dtype or torch.float32 + out, offset_jump = torch.ops.rngprims.philox_rand( + shape, seed, offset, None, device, dtype + ) + PhiloxStateTracker.advance_offset(offset_jump) + return out + + +@register_rng_decomposition(aten.rand_like) +def rand_like( + x: torch.Tensor, + dtype=None, + layout=None, + device=None, + pin_memory=False, + memory_format=torch.preserve_format, +): + device = device or x.device + if device.type != "cuda": + throw_on_non_cuda(device) + dtype = dtype or x.dtype + seed, offset = PhiloxStateTracker.get_state_as_tuple() + out, offset_jump = torch.ops.rngprims.philox_rand( + x.shape, seed, offset, None, device, dtype + ) + PhiloxStateTracker.advance_offset(offset_jump) + return out + + +class PhiloxState: + """ + Represents a PhiloxRngState - (seed, offset) where offset = base_offset + + relative_offset. seed and base_offset basically point to the rng state just + before tracing starts. relative offset tracks the totally consumed offset at + trace time. 
+ """ + + def __init__(self) -> None: + self.reset() + + def reset(self): + self.seed = torch.tensor(()) + self.base_offset = torch.tensor(()) + self.relative_offset = 0 + self.offset_advanced_alteast_once = False + + def validate_state(self): + assert self.seed.numel() != 0 and self.base_offset.numel() != 0 + + def advance_offset(self, consumed_offset): + self.offset_advanced_alteast_once = True + self.relative_offset = self.relative_offset + consumed_offset + + def set_state(self, seed, base_offset, relative_offset=0): + self.seed = seed + self.base_offset = base_offset + self.relative_offset = relative_offset + + def get_state_as_tuple(self): + self.validate_state() + return (self.seed, self.base_offset + self.relative_offset) + + def get_state_as_tensor(self): + # Only needed because we override get_rng_state. + self.validate_state() + return torch.stack([self.seed, self.base_offset + self.relative_offset]) + + def set_state_from_tensor(self, state): + # Only needed because we override set_rng_state. + self.seed, self.base_offset = torch.unbind(state) + self.relative_offset = 0 + + +class PhiloxStateTracker: + """ + Singleton class to track the philox rng state during AOT Autograd tracing. + For each aot tracing instance, AOT Autograd resets this tracker and keeps + track of both forward and backward offsets. At runtime, we only care about + the total consumed forward and backward offsets. For dynamic shapes, these + offsets are a function of input shapes. Therefore, the AOT generated graphs + have additional outputs that compute total consumed forward and backward + offsets. 
+ """ + + running_state: PhiloxState + fwd_state: PhiloxState + bwd_state: PhiloxState + + def __enter__(self): + PhiloxStateTracker.reset() + return self + + def __exit__(self, exc_type, exc_cal, exc_tb): + PhiloxStateTracker.reset() + + @classmethod + def reset(cls): + cls.running_state = PhiloxState() + cls.fwd_state = PhiloxState() + cls.bwd_state = PhiloxState() + + @classmethod + def mark_beginning_of_forward(cls): + # Tells the tracker to use fwd_state as the running state + cls.running_state = cls.fwd_state + + @classmethod + def mark_beginning_of_backward(cls): + # Tells the tracker to use bwd_state as the running state + cls.running_state = cls.bwd_state + + @classmethod + def record_state(cls, seed, offset, mode): + # Records the seed and offset tensors. These tensors are used to invoke + # the philox_rand functional primitives. + if mode == "forward": + cls.fwd_state.set_state(seed, offset) + cls.mark_beginning_of_forward() + else: + assert mode == "backward" + cls.bwd_state.set_state(seed, offset) + + @classmethod + def get_state_as_tensor(cls): + # The only reason this exists is because we override get_rng_state and + # set_rng_state during tracing. get_rng_state expects a tensor output, + # so return (seed, offset) tuple upset other parts of the program like + # ctx.saved_tensors. + + # A bad consequence is that if user saves and restores rng state, we + # have little bit of ugliness in the generated code, where we first + # concat the (seed, offset) to create a tensor for get_rng_state, and + # then split it back to get (seed, offset) tuple in set_rng_state. + + # TODO: Investigate if there is be a better way to wrap the tuple in a + # false Tensor object, and then desugar it later on. + return cls.running_state.get_state_as_tensor() + + @classmethod + def get_state_as_tuple(cls): + return cls.running_state.get_state_as_tuple() + + @classmethod + def set_state_from_tensor(cls, x): + # This is only needed because we override set_rng_state. 
Look at the + # comment in get_state_from_tensor method. + cls.running_state.set_state_from_tensor(x) + + @classmethod + def advance_offset(cls, consumed_offset): + cls.running_state.advance_offset(consumed_offset) + + @classmethod + def get_current_relative_offset(cls): + return cls.running_state.relative_offset + + @staticmethod + def multiple_of_4(offset): + # torch cuda rng state offset must be a multiple of 4. For inductor, as + # we sum up all the numel, the result might not be a multiple of 4. This + # method achieves that. + return (offset + 3) // 4 * 4 + + @classmethod + def get_updated_fwd_offset(cls): + # Short circuit if no rand ops were observed + if not cls.fwd_state.offset_advanced_alteast_once: + return cls.fwd_state.base_offset + return cls.multiple_of_4( + cls.fwd_state.base_offset + cls.fwd_state.relative_offset + ) + + @classmethod + def get_updated_bwd_offset(cls): + # Short circuit if no rand ops were observed + if not cls.bwd_state.offset_advanced_alteast_once: + return cls.bwd_state.base_offset + return cls.multiple_of_4( + cls.bwd_state.base_offset + cls.bwd_state.relative_offset + ) + + +# Adding more decompositions which eventually use rand_like inside decomps. +# Adding these in rng_decompositions ensures the functionalization of rand_like +# ops used in these decomps. The list is copied from inductor codebase, which +# uses it for similar purpose. +# +# Caution - These decomps do not have same accuracy as that of eager. However, +# we can't just disable them with a config flag like fallback_random, because +# for functionalization of rng ops, we have to decompose these ops. 
+extra_random_decomps = get_decompositions( + [ + aten.cauchy, + aten.cauchy_, + aten.exponential, + aten.exponential_, + aten.geometric, + aten.geometric_, + aten.native_dropout, + aten.normal, + aten.normal_, + aten.normal_functional, + aten.log_normal, + aten.log_normal_, + aten.rrelu_with_noise, + aten.rrelu_with_noise_, + aten.uniform_, + ] +) +register_extra_random_decomp = functools.partial( + decomp.register_decomposition, registry=extra_random_decomps +) + + +@register_extra_random_decomp([aten.bernoulli_]) +def bernoulli_(self, p=0.5): + if self.device == torch.device("cpu"): + return NotImplemented + return self.copy_(torch.rand_like(self, dtype=torch.float32) < p) + + +@register_extra_random_decomp([aten.bernoulli.p]) +def bernoulli_p(self, p=0.5, *, generator=None): + if self.device == torch.device("cpu"): + return NotImplemented + assert generator is None + return torch.rand_like(self, dtype=torch.float32) < p + + +rng_decompositions.update(extra_random_decomps) # type: ignore[arg-type] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dispatch/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dispatch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dispatch/python.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dispatch/python.py new file mode 100644 index 0000000000000000000000000000000000000000..98f6ccf78bb89e37631a43bfa557aef381222d1b --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_dispatch/python.py @@ -0,0 +1,192 @@ +# mypy: allow-untyped-defs +import itertools +import unittest.mock +from collections.abc import Callable, Iterator +from contextlib import contextmanager +from typing import TypeVar, Union +from typing_extensions import ParamSpec + +import torch +import torch._C 
+import torch._ops +import torch.utils._python_dispatch +import torch.utils._pytree as pytree +from torch._C import DispatchKey + + +__all__ = ["enable_python_dispatcher", "no_python_dispatcher", "enable_pre_dispatch"] + +no_python_dispatcher = torch._C._DisablePythonDispatcher +enable_python_dispatcher = torch._C._EnablePythonDispatcher +enable_pre_dispatch = torch._C._EnablePreDispatch + +CROSSREF_FUNCTIONALIZE = False + +_P = ParamSpec("_P") +_T = TypeVar("_T") + + +def all_py_loaded_overloads() -> Iterator[torch._ops.OpOverload]: + """ + Warning: the set of overloads this will report is very subtle. It is precisely + the set of torch.ops functions that have actually been accessed from Python + (e.g., we actually called torch.ops.aten.blah at some point. This is DIFFERENT + from the set of registered operators, which will in general be a larger set, + as this would include all operators which we ran C++ static initializers or + Python operator registration on. This does not eagerly populate the list on + torch.ops.aten; this list is lazy! + + In other words, this is good for traversing over everything that has an + OpOverload object allocated in Python. We use it for cache invalidation, but + don't rely on this list being complete. + + Note that even if we did report all C++ registered overloads, this isn't guaranteed + to be complete either, as a subsequent lazy load of a library which triggers more + registrations could add more things to the set. 
+ """ + for ns in torch.ops: + packets = getattr(torch.ops, ns) + for op_name in packets: + packet = getattr(packets, op_name) + for overload in packet: + yield getattr(packet, overload) + + +@contextmanager +def suspend_functionalization(): + f_tls = torch._C._dispatch_tls_is_dispatch_key_included( + torch._C.DispatchKey.Functionalize + ) + f_rv = torch._C._functionalization_reapply_views_tls() + if f_tls: + torch._disable_functionalization() + try: + yield + finally: + if f_tls: + torch._enable_functionalization(reapply_views=f_rv) + + +def check_tensor_metadata_matches(nv, rv, desc): + assert callable(desc) + assert nv.size() == rv.size(), f"{desc()}: sizes {nv.size()} != {rv.size()}" + assert nv.dtype == rv.dtype, f"{desc()}: dtype {nv.dtype} != {rv.dtype}" + same_strides, idx = torch._prims_common.check_significant_strides( + nv, rv, only_cuda=False + ) + assert same_strides, ( + f"{desc()}: strides {nv.stride()} != {rv.stride()} (mismatch at index {idx})" + ) + + +def check_metadata_matches(n, r, desc): + assert callable(desc) + n_vals, _n_spec = pytree.tree_flatten(n) + r_vals, _r_spec = pytree.tree_flatten(r) + # TODO: test the specs match; empirically sometimes we have a tuple + # on one side and a list on the other + assert len(n_vals) == len(r_vals), f"{len(n_vals)} != {len(r_vals)}" + for i, nv, rv in zip(range(len(n_vals)), n_vals, r_vals): + if not isinstance(rv, torch.Tensor): + continue + check_tensor_metadata_matches(nv, rv, lambda: f"{desc()} output {i}") + + +class Lit: + def __init__(self, s): + self.s = s + + def __repr__(self): + return self.s + + +def _fmt(a: object) -> object: + if isinstance(a, torch.Tensor): + return Lit( + f"torch.empty_strided({tuple(a.size())}, {a.stride()}, dtype={a.dtype})" + ) + else: + return a + + +def make_crossref_functionalize( + op: torch._ops.OpOverload[_P, _T], final_key: DispatchKey +) -> Union[Callable[_P, _T], DispatchKey]: + from torch._subclasses.fake_tensor import FakeTensorMode + + # This case is 
pretty weird, suppress it for now + if op is torch.ops.aten.lift_fresh.default: + return final_key + + def handler(*args: _P.args, **kwargs: _P.kwargs) -> _T: + fake_mode = FakeTensorMode() + + def fakeify_defun(t): + if isinstance(t, torch.Tensor): + if torch._is_functional_tensor(t): + r = torch._from_functional_tensor(t) + # NB: This assumes that the inner tensor sizes/strides match + # the outer tensor sizes/strides. This doesn't necessarily have to + # be the case, see discussion at + # https://github.com/pytorch/pytorch/pull/87610/files/401ddeda1d769bedc88a12de332c7357b60e51a4#r1007264456 + assert t.size() == r.size() + assert t.stride() == r.stride() + else: + r = t + # TODO: suppress guards + return fake_mode.from_tensor(r) + return t + + def maybe_detach(t): + if isinstance(t, torch.Tensor): + return t.detach() + else: + return t + + # TODO: This probably does the wrong thing if you're running other + # substantive modes with the normal op outside here + with ( + torch.utils._python_dispatch._disable_current_modes(), + suspend_functionalization(), + ): + f_args, f_kwargs = pytree.tree_map(fakeify_defun, (args, kwargs)) + orig_f_args, orig_f_kwargs = pytree.tree_map( + maybe_detach, (f_args, f_kwargs) + ) + with fake_mode: + f_r = op(*f_args, **f_kwargs) # pyrefly: ignore [invalid-param-spec] + r = op._op_dk(final_key, *args, **kwargs) + + def desc(): + fmt_args = ", ".join( + itertools.chain( + (repr(pytree.tree_map(_fmt, a)) for a in orig_f_args), + ( + f"{k}={pytree.tree_map(_fmt, v)}" + for k, v in orig_f_kwargs.items() + ), + ) + ) + return f"{op}({fmt_args})" + + check_metadata_matches(f_r, r, desc) + return r + + return handler + + +# NB: enabling this is slow, don't do it in a hot loop. This is purely +# for debugging purposes. 
+@contextmanager +def enable_crossref_functionalize(): + for op in all_py_loaded_overloads(): + op._uncache_dispatch(torch._C.DispatchKey.Functionalize) + try: + with ( + enable_python_dispatcher(), + unittest.mock.patch("torch._dispatch.python.CROSSREF_FUNCTIONALIZE", True), + ): + yield + finally: + for op in all_py_loaded_overloads(): + op._uncache_dispatch(torch._C.DispatchKey.Functionalize) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/config.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/config.py new file mode 100644 index 0000000000000000000000000000000000000000..ec3963eaa34acfbbe4f21354b0a84ab54cccfd71 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/config.py @@ -0,0 +1,45 @@ +""" +Configuration module for torch.export.export. + +This module contains various configuration flags and settings that control torch.export's +behavior, including: +- Runtime behavior flags +- Debugging and development options +""" + +import sys +from typing import Any, TYPE_CHECKING + +from torch._environment import is_fbcode +from torch.utils._config_module import install_config_module + + +# this flag controls whether we use new functional tracer. It +# should be True in the long term. +use_new_tracer_experimental = True + +# this flag is used to control whether we want to instrument +# fake tensor creation to track potential leaks. It is off +# by default, but user can turn it on to debug leaks. +detect_non_strict_fake_tensor_leaks = False + +# error on potentially pre-dispatch/non-strict tracing limitation +# this type of error usually happens when we encounter an op +# that we don't know how to proxy, resulting in untracked fake tensors +error_on_lifted_constant_tensors = True + +# enable auto_functionalized_v2 in export +# We turn this off in fbcode due to downstream users not +# being ready to handle auto_functionalized_v2. 
+enable_auto_functionalized_v2_for_export = not is_fbcode() + +use_legacy_dynamo_graph_capture = True + + +if TYPE_CHECKING: + from torch.utils._config_typing import * # noqa: F401, F403 + + def _make_closure_patcher(**changes: Any) -> Any: ... + + +install_config_module(sys.modules[__name__]) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/error.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/error.py new file mode 100644 index 0000000000000000000000000000000000000000..03b7f52fb9de435b9e58fa4a0bb141cc191e84c5 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/error.py @@ -0,0 +1,56 @@ +from enum import Enum + + +class ExportErrorType(Enum): + # User providing invalid inputs to either tracer, or other public facing APIs + INVALID_INPUT_TYPE = 1 + + # User returning values from their models that we don't support. + INVALID_OUTPUT_TYPE = 2 + + # Generated IR does not conform to Export IR Specification. + VIOLATION_OF_SPEC = 3 + + # User's code contains types and functionalities we don't support. + NOT_SUPPORTED = 4 + + # User's code didn't provide necessary details for us to successfully trace and export. + # For example, we use a lot of decorators and ask users to annotate their model. + MISSING_PROPERTY = 5 + + # User is using an API without proper initialization step. + UNINITIALIZED = 6 + + +def internal_assert(pred: bool, assert_msg: str) -> None: + """ + This is exir's custom assert method. It internally just throws InternalError. + Note that the sole purpose is to throw our own error while maintaining similar syntax + as python assert. + """ + + if not pred: + raise InternalError(assert_msg) + + +class InternalError(Exception): + """ + Raised when an internal invariance is violated in EXIR stack. + Should hint users to report a bug to dev and expose the original + error message. 
+ """ + + def __init__(self, message: str) -> None: + super().__init__(message) + + +class ExportError(Exception): + """ + This type of exception is raised for errors that are directly caused by the user + code. In general, user errors happen during model authoring, tracing, using our public + facing APIs, and writing graph passes. + """ + + def __init__(self, error_code: ExportErrorType, message: str) -> None: + prefix = f"[{error_code}]: " + super().__init__(prefix + message) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/verifier.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/verifier.py new file mode 100644 index 0000000000000000000000000000000000000000..8f8ab1be26483a911872aef476b4a7845daeceb1 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_export/verifier.py @@ -0,0 +1,531 @@ +# mypy: allow-untyped-defs +import inspect +import math +import operator +from collections.abc import Iterable +from typing import Any, final, TYPE_CHECKING + +import torch +from torch._library.opaque_object import is_opaque_type +from torch._ops import HigherOrderOperator, OpOverload +from torch._subclasses.fake_tensor import FakeTensor +from torch.export.graph_signature import ( + CustomObjArgument, + InputKind, + SymBoolArgument, + SymFloatArgument, + SymIntArgument, + TensorArgument, + TokenArgument, +) +from torch.fx import GraphModule + + +if TYPE_CHECKING: + from torch.export.exported_program import ExportedProgram + + +class SpecViolationError(Exception): + pass + + +def is_functional(op: OpOverload) -> bool: + return not op._schema.is_mutable + + +def _check_has_fake_tensor(node: torch.fx.Node) -> None: + # TODO(angelayi): remove this in favor of _check_val + return _check_val(node) + + +def _check_val(node: torch.fx.Node) -> None: + from torch.fx.experimental.symbolic_shapes import SymBool, SymFloat, SymInt + + def _check_correct_val(val): + if val is None: + return True 
+ elif isinstance(val, (int, bool, str, float)): + return True + elif isinstance( + val, (torch.memory_format, torch.dtype, torch.device, torch.layout) + ): + return True + elif isinstance( + val, (FakeTensor, torch.Tensor) + ): # TODO(zhxchen17) Remove Tensor. + return True + elif isinstance(val, (SymInt, SymFloat, SymBool)): + return True + elif isinstance(val, CustomObjArgument): + return True + elif isinstance(val, Iterable): + return all(_check_correct_val(x) for x in val) + elif is_opaque_type(type(val)): + return True + return False + + def _no_returns(op): + if not isinstance(op, OpOverload): + return False + return len(op._schema.returns) == 0 + + if "val" not in node.meta: + if node.op == "call_function" and _no_returns(node.target): + return + raise SpecViolationError(f"Node.meta {node.name} is missing val field.") + + val = node.meta["val"] + if not _check_correct_val(val): + raise SpecViolationError(f"Node.meta {node.name} has invalid val field {val}") + + +def _check_torch_fn(node: torch.fx.Node) -> None: + torch_fn = node.meta.get("torch_fn") + if torch_fn is None: + raise SpecViolationError( + f"Unable to find torch_fn metadata for node {node.name}" + ) + if ( + not isinstance(torch_fn, tuple) + and isinstance(torch_fn[0], str) + and isinstance(torch_fn[1], str) + ): + raise SpecViolationError( + f"Node.meta {node.name} has invalid torch_fn field {torch_fn}" + ) + + +class _VerifierMeta(type): + _registry: dict[str, type["Verifier"]] = {} + + def __new__(metacls, name, bases, attrs): + if bases: + if "check" in attrs or "_check_graph_module" in attrs: + raise SyntaxError("Overriding method check is not allowed.") + assert "dialect" in attrs and attrs["dialect"] != "ATEN" + else: + assert "check" in attrs + assert "_check_graph_module" in attrs + assert attrs["dialect"] == "ATEN" + + assert isinstance(attrs["dialect"], str) + ret = type.__new__(metacls, name, bases, attrs) + metacls._registry[attrs["dialect"]] = ret # type: ignore[assignment] + 
return ret + + +def getattr_recursive(obj: Any, target: str) -> Any: + target_atoms = target.split(".") + attr_itr = obj + for i, atom in enumerate(target_atoms): + if not hasattr(attr_itr, atom): + raise RuntimeError( + f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}" + ) + attr_itr = getattr(attr_itr, atom) + return attr_itr + + +class Verifier(metaclass=_VerifierMeta): + dialect = "ATEN" + + def allowed_builtin_ops(self) -> list: + return [ + operator.getitem, + operator.add, + operator.mul, + operator.sub, + operator.truediv, + operator.ge, + operator.le, + operator.gt, + operator.lt, + operator.eq, + operator.ne, + operator.floordiv, + operator.mod, + operator.and_, + operator.or_, + operator.not_, + operator.pow, + operator.neg, + operator.abs, + operator.lshift, + operator.rshift, + math.ceil, + math.floor, + math.trunc, + round, + ] + + def allowed_op_types(self) -> tuple[type[Any], ...]: + return (OpOverload, HigherOrderOperator) + + def allowed_getattr_types(self) -> tuple[type[Any], ...]: + return (torch.fx.GraphModule, torch.utils._pytree.TreeSpec) + + def allowed_getattr_types_for_subgm(self) -> tuple[type[Any], ...]: + # subgm in HOP's argument could has have getattr(weight) nodes, thus stateful + return ( + torch.fx.GraphModule, + torch.nn.parameter.Parameter, + torch.Tensor, # for buffer and constant tensor + torch.utils._pytree.TreeSpec, + ) + + def check_valid_op(self, op): + pass + + def check_additional(self, gm: GraphModule) -> None: + """ + Additional checks that are specific to some dialects. 
+ """ + + @final + def check(self, ep: "ExportedProgram") -> None: + self._check_graph_module(ep.graph_module) + _verify_exported_program_module_call_graph(ep) + _verify_exported_program_signature(ep) + + @final + def _check_graph_module(self, gm: torch.fx.GraphModule) -> None: + def _allowed_getattr_types(is_toplevel_gm) -> tuple[type[Any], ...]: + if is_toplevel_gm: + ret = self.allowed_getattr_types() + else: + ret = self.allowed_getattr_types_for_subgm() + assert not any(t is object for t in ret) + return ret + + def _check_valid_op(op) -> None: + def _allowed_builtin_ops() -> list: + ret = self.allowed_builtin_ops() + assert all(inspect.isbuiltin(op) for op in ret) + return ret + + def _allowed_op_types() -> tuple[type[Any], ...]: + ret = self.allowed_op_types() + assert not any(t is object for t in ret) + return ret + + # TODO Remove this allowlist. + _allowed_torch_functions = ( + torch.autograd.grad_mode.set_grad_enabled, + torch.sym_int, + torch.sym_float, + torch.sym_ite, + torch.sym_max, + torch.sym_min, + torch.sym_not, + torch.sym_sqrt, + torch.sym_sum, + torch.export.custom_ops._call_custom_autograd_function_in_pre_dispatch, + # TODO (tmanlaibaatar) + # Predispatch export is able to contain autograd ops. 
+ # These will be modeled as HOO later + torch._C._set_grad_enabled, + torch.amp.autocast_mode._enter_autocast, + torch.amp.autocast_mode._exit_autocast, + torch.fx.experimental.symbolic_shapes.cast_symbool_to_symint_guardless, + torch._functorch.predispatch._add_batch_dim, + torch._functorch.predispatch._remove_batch_dim, + torch._functorch.predispatch._vmap_increment_nesting, + torch._functorch.predispatch._vmap_decrement_nesting, + torch._functorch.predispatch.lazy_load_decompositions, + ) + + if not isinstance(op, _allowed_op_types()): + if ( + op not in _allowed_builtin_ops() + and op not in _allowed_torch_functions + ): + raise SpecViolationError( + f"Operator '{op}' is not an allowed operator type: {_allowed_op_types()}\n" + f"Valid builtin ops: {_allowed_builtin_ops()}" + f"Valid torch functions: {_allowed_torch_functions}" + ) + + if isinstance(op, OpOverload): + # All ops functional + # TODO (tmanlaibaatar) more proper way is needed here + if self.dialect != "TRAINING" and not is_functional(op): + raise SpecViolationError(f"operator '{op}' is not functional") + self.check_valid_op(op) + + for mod in gm.modules(): + is_toplevel_gm = mod is gm + + if not isinstance(mod, torch.fx.GraphModule): + continue + + mod.graph.lint() + for node in mod.graph.nodes: + # TODO(T140410192): should have fake tensor for all dialects + if node.op in {"call_module", "call_method"}: + raise SpecViolationError( + f"call_module is not valid: got a class '{node.target}' ", + ) + + elif node.op == "call_function": + _check_val(node) + + _check_valid_op(node.target) + + elif node.op == "get_attr": + if not isinstance(node.target, str): + raise SpecViolationError( + f"Expected get_attr target to be string, but got {type(node.target)}" + ) + + attr = getattr_recursive(mod, node.target) + if isinstance(attr, torch.nn.Module): + + def _is_type(name, ty): + return isinstance(getattr(attr, name, None), ty) + + if type(attr).__name__ == "LoweredBackendModule": + if ( + 
_is_type("backend_id", str) + and hasattr(attr, "original_module") + and hasattr(attr, "module_name") + and getattr(attr, "backend_id", None) == "aoti" + ): + continue + if ( + _is_type("backend_id", str) + and _is_type("processed_bytes", bytes) + and _is_type("compile_specs", list) + and hasattr(attr, "original_module") + ): + continue + else: + backend_id = getattr(attr, "backend_id", None) + processed_bytes = getattr(attr, "processed_bytes", None) + compile_specs = getattr(attr, "compile_specs", None) + raise SpecViolationError( + f"Invalid get_attr type {type(attr)}. \n" + f"LoweredBackendModule fields: " + f"backend_id(str) : {type(backend_id)}, " + f"processed_bytes(bytes) : {type(processed_bytes)}, " + f"compile_specs(list) : {type(compile_specs)}" + ) + elif type(attr).__name__ == "AOTInductorEPModule": + continue + + elif type(attr).__name__ == "AOTInductorRunnerWrapper": + continue + + if not isinstance(attr, _allowed_getattr_types(is_toplevel_gm)): + raise SpecViolationError( + f"Invalid get_attr type {type(attr)} on target {node.target}. \n" + f"Valid get_attr types: {_allowed_getattr_types(is_toplevel_gm)}" + ) + + elif node.op == "placeholder": + _check_val(node) + # TODO(zhxchen17) + # elif node.op == "output": + # _check_flattened_outputs() + + self.check_additional(gm) + + +class TrainingIRVerifier(Verifier): + dialect = "TRAINING" + + +def _verify_exported_program_module_call_graph(exported_program) -> None: + module_call_graph = exported_program.module_call_graph + nodes = {node.name for node in exported_program.graph.nodes} + for entry in module_call_graph: + if entry.signature is not None: + for arg in entry.signature.inputs: + if arg.name and arg.name not in nodes: + raise SpecViolationError( + f"Input {arg.name} does not exist in the graph." + ) + for arg in entry.signature.outputs: + if arg.name and arg.name not in nodes: + raise SpecViolationError( + f"Output {arg.name} does not exist in the graph." 
+ ) + + +def _verify_exported_program_signature(exported_program) -> None: + # Check ExportedProgram signature matches + gs = exported_program.graph_signature + + # Check every node in the signature exists in the graph + input_node_names = [ + node.name for node in exported_program.graph.nodes if node.op == "placeholder" + ] + + if len(input_node_names) != len(gs.input_specs): + raise SpecViolationError( + f"Number of graph inputs ({len(input_node_names)}) " + f"does not match number of inputs in the graph signature ({len(gs.input_specs)})" + ) + + for input_spec, node in zip(gs.input_specs, input_node_names): + if isinstance( + input_spec.arg, + (TensorArgument, SymIntArgument, SymFloatArgument, SymBoolArgument), + ): + if input_spec.arg.name != node: + raise SpecViolationError( + f"Input spec name {input_spec.arg.name} does not match node name {node}" + ) + + if input_spec.kind == InputKind.USER_INPUT: + continue + + elif input_spec.kind == InputKind.PARAMETER: + if not isinstance(input_spec.arg, TensorArgument): + raise SpecViolationError( + f"Parameter {input_spec.name} is not a tensor argument. Found {input_spec.arg} instead." + ) + if input_spec.target is None: + raise SpecViolationError( + f"InputSpec for {input_spec.name} has no target." + ) + + param = input_spec.target + if param not in exported_program.state_dict: + raise SpecViolationError(f"Parameter {param} is not in the state dict.") + + if not isinstance(exported_program.state_dict[param], torch.nn.Parameter): + raise SpecViolationError( + f"State dict entry for parameter {param} is not an instance of torch.nn.Parameter." + ) + + elif input_spec.kind == InputKind.BUFFER: + if not isinstance(input_spec.arg, TensorArgument): + raise SpecViolationError( + f"Buffer {input_spec.name} is not a tensor argument. Found {input_spec.arg} instead." + ) + if input_spec.target is None: + raise SpecViolationError( + f"InputSpec for {input_spec.name} has no target." 
+ ) + + buffer = input_spec.target + if input_spec.persistent is None: + raise SpecViolationError( + f"Buffer {buffer} is missing a persistence flag" + ) + + if ( + input_spec.persistent is True + and buffer not in exported_program.state_dict + ): + raise SpecViolationError(f"Buffer {buffer} is not in the state dict.") + + if input_spec.persistent is False and buffer in exported_program.state_dict: + raise SpecViolationError( + f"Non-persistent buffer {buffer} is in the state dict, it should not be." + ) + elif input_spec.kind == InputKind.CONSTANT_TENSOR: + if not isinstance(input_spec.arg, TensorArgument): + raise SpecViolationError( + f"Constant tensor {input_spec.name} is not a tensor argument. Found {input_spec.arg} instead." + ) + if input_spec.target is None: + raise SpecViolationError( + f"InputSpec for {input_spec.name} has no target." + ) + + tensor_const = input_spec.target + if tensor_const not in exported_program.constants: + raise SpecViolationError( + f"Constant tensor {tensor_const} is not in the constants dictionary." + ) + elif input_spec.kind == InputKind.CUSTOM_OBJ: + if not isinstance(input_spec.arg, CustomObjArgument): + raise SpecViolationError( + f"Custom object {input_spec.name} is not a custom object argument. Found {input_spec.arg} instead." + ) + if input_spec.target is None: + raise SpecViolationError( + f"InputSpec for {input_spec.name} has no target." + ) + + custom_obj = input_spec.target + if custom_obj not in exported_program.constants: + raise SpecViolationError( + f"Custom object {custom_obj} is not in the constants dictionary." + ) + elif input_spec.kind == InputKind.TOKEN: + if not isinstance(input_spec.arg, TokenArgument): + raise SpecViolationError( + f"Constant tensor {input_spec.name} is not a tensor argument. Found {input_spec.arg} instead." 
+ ) + else: + raise SpecViolationError(f"Unknown InputKind {input_spec.kind}.") + + # Check outputs + output_node = list(exported_program.graph.nodes)[-1] + assert output_node.op == "output" + output_nodes = [ + arg.name if isinstance(arg, torch.fx.Node) else arg + for arg in output_node.args[0] + ] + + if len(output_nodes) != len(gs.output_specs): + raise SpecViolationError( + f"Number of output nodes {len(output_nodes)} is different " + "Than the number of outputs specified by the graph signature: \n" + f"Number of mutated buffers: {len(gs.buffers_to_mutate)}. \n" + f"Number of user outputs: {len(gs.user_outputs)}. \n" + ) + + num_tokens = len(gs.output_tokens) + end = ( + len(gs.buffers_to_mutate) + + len(gs.parameters_to_mutate) + + len(gs.user_inputs_to_mutate) + + num_tokens + ) + mutate_nodes: list[str] = output_nodes[num_tokens:end] + user_output_nodes = output_nodes[end : end + len(gs.user_outputs)] + + for mutation_node in mutate_nodes: + if mutation_node in gs.buffers_to_mutate: + if gs.buffers_to_mutate[mutation_node] not in gs.buffers: + raise SpecViolationError( + f"Buffer output {mutation_node} does not point to a buffer that exists. \n" + f"Dict of buffers that are mutated, in order: {gs.buffers_to_mutate} \n" + f"Buffer nodes available: {gs.buffers} \n" + ) + elif mutation_node in gs.parameters_to_mutate: + if gs.parameters_to_mutate[mutation_node] not in gs.parameters: + raise SpecViolationError( + f"Parameter output {mutation_node} does not point to a parameter that exists. \n" + f"Dict of parameters that are mutated, in order: {gs.parameters_to_mutate} \n" + f"Parameter nodes available: {gs.parameters} \n" + ) + elif mutation_node in gs.user_inputs_to_mutate: + if gs.user_inputs_to_mutate[mutation_node] not in gs.user_inputs: + raise SpecViolationError( + f"User input output {mutation_node} does not point to a user input that exists. 
\n" + f"Dict of user inputs that are mutated, in order: {gs.user_inputs_to_mutate} \n" + f"User input nodes available: {gs.user_inputs} \n" + ) + else: + raise SpecViolationError( + f"Mutation node {mutation_node} is neither a buffer nor a user input. " + f"Buffers to mutate: {gs.buffers_to_mutate}, User inputs to mutate: {gs.user_inputs_to_mutate}" + ) + + for user_output_node, user_output_name in zip(user_output_nodes, gs.user_outputs): + if user_output_node != user_output_name: + raise SpecViolationError( + f"User output {user_output_node} is not in the correct " + "order or is not found in the " + f"exported program's user_output list: {gs.user_outputs}. " + ) + + +def load_verifier(dialect: str) -> type[Verifier]: + if dialect == "ATEN" or dialect == "": + return _VerifierMeta._registry.get(dialect, Verifier) + return _VerifierMeta._registry[dialect] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__autotune_main__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__autotune_main__.py new file mode 100644 index 0000000000000000000000000000000000000000..1eb5ca86e8c185e9c355e6dea152b53a3f181519 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__autotune_main__.py @@ -0,0 +1,33 @@ +import argparse +import logging +import os + +from torch._inductor.autotune_process import TuningProcess +from torch._inductor.compile_worker.utils import _async_compile_initializer + + +log = logging.getLogger(__name__) + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--parent", type=int) + parser.add_argument("--read-fd", type=int) + parser.add_argument("--write-fd", type=int) + args = parser.parse_args() + read_pipe = os.fdopen(args.read_fd, "rb") + write_pipe = os.fdopen(args.write_fd, "wb") + + try: + # Ensures the subprocess exits if the parent crashes: + _async_compile_initializer(args.parent) + 
TuningProcess.process_main(read_pipe, write_pipe) + except Exception: + log.exception("Uncaught exception in autotune subprocess") + finally: + read_pipe.close() + write_pipe.close() + + +if __name__ == "__main__": + main() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8e6fde9280c4aa3f5e28d47e32c35c581142c9c6 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/__init__.py @@ -0,0 +1,447 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import io +import logging +import os +from typing import Any, IO, Literal, Optional, TYPE_CHECKING, Union + +import torch.fx + +from .standalone_compile import CompiledArtifact # noqa: TC001 + + +if TYPE_CHECKING: + from torch._inductor.utils import InputType + from torch.export import ExportedProgram + from torch.export.pt2_archive._package import AOTICompiledModel + from torch.export.pt2_archive._package_weights import Weights + from torch.types import FileLike + +__all__ = [ + "compile", + "list_mode_options", + "list_options", + "cudagraph_mark_step_begin", + "standalone_compile", +] + + +log = logging.getLogger(__name__) + + +def compile( + gm: torch.fx.GraphModule, + example_inputs: list[InputType], + options: Optional[dict[str, Any]] = None, +): + """ + Compile a given FX graph with TorchInductor. This allows compiling + FX graphs captured without using TorchDynamo. + + Args: + gm: The FX graph to compile. + example_inputs: List of tensor inputs. + options: Optional dict of config options. See `torch._inductor.config`. + + Returns: + Callable with same behavior as gm but faster. 
+ """ + from .compile_fx import compile_fx + + return compile_fx(gm, example_inputs, config_patches=options) + + +def aoti_compile_and_package( + exported_program: ExportedProgram, + _deprecated_unused_args=None, + _deprecated_unused_kwargs=None, + *, + package_path: Optional[FileLike] = None, + inductor_configs: Optional[dict[str, Any]] = None, +) -> str: + """ + Compiles the exported program with AOTInductor, and packages it into a .pt2 + artifact specified by the input package_path. To load the package, you can + call ``torch._inductor.aoti_load_package(package_path)``. + + An example usage is as follows: + + .. code-block:: python + + ep = torch.export.export(M(), ...) + aoti_file = torch._inductor.aoti_compile_and_package( + ep, package_path="my_package.pt2" + ) + compiled_model = torch._inductor.aoti_load_package("my_package.pt2") + + To compile and save multiple models into a single ``.pt2`` artifact, you can do + the following: + + .. code-block:: python + + ep1 = torch.export.export(M1(), ...) + aoti_file1 = torch._inductor.aot_compile( + ep1, ..., options={"aot_inductor.package": True} + ) + ep2 = torch.export.export(M2(), ...) + aoti_file2 = torch._inductor.aot_compile( + ep2, ..., options={"aot_inductor.package": True} + ) + + from torch._inductor.package import package_aoti, load_package + + package_aoti("my_package.pt2", {"model1": aoti_file1, "model2": aoti_file2}) + + compiled_model1 = load_package("my_package.pt2", "model1") + compiled_model2 = load_package("my_package.pt2", "model2") + + Args: + exported_program: An exported program created through a call from torch.export + package_path: Optional specified path to the generated .pt2 artifact. + inductor_configs: Optional dictionary of configs to control inductor. 
+ + Returns: + Path to the generated artifact + """ + from torch.export import ExportedProgram + + from .debug import aot_inductor_minifier_wrapper + + if not isinstance(exported_program, ExportedProgram): + raise ValueError("Only ExportedProgram is supported") + + if exported_program.example_inputs is None: + raise RuntimeError( + "exported_program.example_inputs is required to be set in order " + "for AOTInductor compilation." + ) + + if _deprecated_unused_args is not None or _deprecated_unused_kwargs is not None: + log.warning( + "You no longer need to specify args/kwargs to aoti_compile_and_package " + "as we can get this information from exported_program.example_inputs." + ) + + assert ( + package_path is None + or ( + isinstance(package_path, (io.IOBase, IO)) + and package_path.writable() + and package_path.seekable() + ) + or ( + isinstance(package_path, (str, os.PathLike)) + and os.fspath(package_path).endswith(".pt2") + ) + ), ( + f"Expect package path to be a file ending in .pt2, is None, or is a buffer. Instead got {package_path}" + ) + + inductor_configs = inductor_configs or {} + inductor_configs["aot_inductor.package"] = True + + if inductor_configs.get("aot_inductor.output_path"): + raise RuntimeError( + "Please pass in a package path to aot_inductor_compile() instead " + "of setting the aot_inductor.output_path config." + ) + + # a wrapper around aoti_compile_and_package_inner. 
+ return aot_inductor_minifier_wrapper( + _aoti_compile_and_package_inner, + exported_program, + # pyrefly: ignore [bad-argument-type] + package_path=package_path, + inductor_configs=inductor_configs, + ) + + +def _aoti_compile_and_package_inner( + gm: torch.nn.Module, + # flat_example_inputs: List[Any], + args: tuple[Any], + kwargs: Optional[dict[str, Any]] = None, + *, + load_and_run: bool = False, + check_accuracy: Optional[str] = None, + package_path: Optional[Union[str, io.BytesIO]] = None, + inductor_configs: Optional[dict[str, Any]] = None, +): + """ + See docstring for aoti_compile_and_package. + + If `load_and_run` is True, this function will load the compiled model and run it. + This is for the minifier to check the correctness of the compiled model. + + If `check_accuracy` is set, this function will check the accuracy of the compiled + model against gm. kwargs must be None if check_accuracy is set. + "strict_accuracy" means "we will minify any time we see anything that + diverges", whereas "accuracy" is more conservative, and will only minify if there + is a meaningful fp64 divergence + """ + + if check_accuracy: + assert kwargs is None or len(kwargs) == 0, ( + "when checking for accuracy, the inputs must have been flattened and kwargs is None" + ) + + from .package import package_aoti + + assert isinstance(gm, torch.fx.GraphModule) + + kwargs = kwargs or {} + + aoti_files = aot_compile(gm, args, kwargs, options=inductor_configs) + assert isinstance(aoti_files, list) + + if package_path is None: + path = [ + os.path.splitext(file)[0] + for file in aoti_files + if isinstance(file, str) and os.path.splitext(file)[1] == ".so" + ] + if len(path) == 0: + path = [ + os.path.splitext(file)[0] + for file in aoti_files + if isinstance(file, str) and os.path.splitext(file)[1] == ".cpp" + ] + package_path = path[0] + ".pt2" + + res = package_aoti(package_path, aoti_files) + assert res == package_path + + if load_and_run or check_accuracy: + compiled_model = 
aoti_load_package(package_path) + if check_accuracy: + from torch._dynamo.debug_utils import AccuracyError, same_two_models + + # This might look inverted but it's not. strict_accuracy means "we will + # minify any time we see anything that diverges", whereas accuracy is more + # conservative, and will only minify if there is a meaningful fp64 + # divergence + not_strict_accuracy = check_accuracy == "accuracy" + if not same_two_models( + gm, + compiled_model, # type: ignore[arg-type] + args, + only_fwd=True, + require_fp64=not_strict_accuracy, + ignore_non_fp=not_strict_accuracy, + ): + raise AccuracyError("Bad accuracy detected") + else: + compiled_model(*args, **kwargs) + + return package_path + + +def aoti_load_package( + path: FileLike, run_single_threaded: bool = False, device_index: int = -1 +) -> AOTICompiledModel: + """ + Loads the model from the PT2 package. + + If multiple models were packaged into the PT2, this will load the default + model. To load a specific model, you can directly call the load API + + .. code-block:: python + + from torch._inductor.package import load_package + + compiled_model1 = load_package("my_package.pt2", "model1") + compiled_model2 = load_package("my_package.pt2", "model2") + + Args: + path: Path to the .pt2 package + run_single_threaded (bool): Whether the model should be run without + thread synchronization logic. This is useful to avoid conflicts with + CUDAGraphs. + device_index (int): The index of the device to which the PT2 package is + to be loaded. By default, `device_index=-1` is used, which corresponds + to the device `cuda` when using CUDA. Passing `device_index=1` would + load the package to `cuda:1`, for example. 
+ """ + from torch._inductor.package import load_package + + return load_package( + path, run_single_threaded=run_single_threaded, device_index=device_index + ) + + +def aot_compile( + gm: torch.fx.GraphModule, + args: tuple[Any, ...], + kwargs: Optional[dict[str, Any]] = None, + *, + options: Optional[dict[str, Any]] = None, +) -> Union[str, list[Union[str, Weights]], torch.fx.GraphModule]: + """ + Ahead-of-time compile a given FX graph with TorchInductor into a shared library. + + Args: + gm: The FX graph to compile. + args: Example arguments + kwargs: Example keyword arguments + options: Optional dict of config options. See `torch._inductor.config`. + + Returns: + Path to the generated shared library, or a list of files generated by + AOTI if aot_inductor.package=True. + TODO: make it return a list by default + """ + from .compile_fx import _aoti_flatten_inputs, compile_fx_aot + + if hasattr(gm, "_guards_fn"): + # Do not compile the guards function, since it may contain checks + # that are not currently supported by AOTI. In particular, non-Tensor + # arguments are converted to None and will fail specialization checks. + node = next(iter(gm.graph.find_nodes(op="call_module", target="_guards_fn"))) + gm.graph.erase_node(node) + delattr(gm, "_guards_fn") + gm.recompile() + + flat_example_inputs, options = _aoti_flatten_inputs( + gm, args, kwargs, options=options + ) + from torch._export.utils import _compiling_state_context + + with _compiling_state_context(): + return compile_fx_aot( + gm, + flat_example_inputs, # type: ignore[arg-type] + config_patches=options, + ) + + +lite_mode_options = { + # Fallback by default unless users explicitly annotated with + # regional inductor compile. 
+ "fallback_by_default": True, + "selective_decompose": True, + # Disable reorder optimizations + "reorder_for_peak_memory": False, + "reorder_for_compute_comm_overlap": False, + "triton.reorder_for_reducing_graph_partitions": False, + # Disable pre-, joint-, post-grad passes + "use_pre_grad_passes": False, + "use_joint_graph_passes": False, + "use_post_grad_passes": False, + # Disable dead code elimination (dce) and buffer reuse + "use_dce": False, + "allow_buffer_reuse": False, +} + + +def list_mode_options( + mode: Optional[str] = None, dynamic: Optional[bool] = None +) -> dict[str, Any]: + r"""Returns a dictionary describing the optimizations that each of the available + modes passed to `torch.compile()` performs. + + Args: + mode (str, optional): The mode to return the optimizations for. + If None, returns optimizations for all modes + dynamic (bool, optional): Whether dynamic shape is enabled. + + Example:: + >>> torch._inductor.list_mode_options() + """ + + mode_options: dict[str, dict[str, bool]] = { + "default": {}, + # lite backend for opt-in optimizations + "lite": lite_mode_options, + # enable cudagraphs + "reduce-overhead": { + "triton.cudagraphs": True, + }, + # enable max-autotune + "max-autotune-no-cudagraphs": { + "max_autotune": True, + "coordinate_descent_tuning": True, + }, + # enable max-autotune + # enable cudagraphs + "max-autotune": { + "max_autotune": True, + "triton.cudagraphs": True, + "coordinate_descent_tuning": True, + }, + } + try: + return mode_options[mode] if mode else mode_options + except KeyError as e: + raise RuntimeError( + f"Unrecognized mode={mode}, should be one of: {', '.join(mode_options.keys())}" + ) from e + + +def list_options() -> list[str]: + r"""Returns a dictionary describing the optimizations and debug configurations + that are available to `torch.compile()`. + + The options are documented in `torch._inductor.config`. 
+ + Example:: + + >>> torch._inductor.list_options() + """ + + from torch._inductor import config + + current_config: dict[str, Any] = config.get_config_copy() + + return list(current_config.keys()) + + +def cudagraph_mark_step_begin(): + "Indicates that a new iteration of inference or training is about to begin." + from .cudagraph_trees import mark_step_begin + + mark_step_begin() + + +def standalone_compile( + gm: torch.fx.GraphModule, + example_inputs: list[InputType], + *, + dynamic_shapes: Literal[ + "from_example_inputs", "from_tracing_context", "from_graph" + ] = "from_graph", + options: Optional[dict[str, Any]] = None, + aot: bool = False, # AOT mode, which uses BundledAOTAutogradCache +) -> CompiledArtifact: + """ + Precompilation API for inductor. + + .. code-block:: python + + compiled_artifact = torch._inductor.standalone_compile(gm, args) + compiled_artifact.save(path=path, format="binary") + + # Later on a new process + loaded = torch._inductor.CompiledArtifact.load(path=path, format="binary") + compiled_out = loaded(*args) + + Args: + gm: Graph Module + example_inputs: Inputs for the graph module + dynamic_shapes: If "from_graph" (default), we will use the dynamic + shapes in the passed-in graph module. + If "from_tracing_context", we use the dynamic shape info in the + ambient tracing context. + If "from_example_inputs", we will specialize the graph on the + example_inputs. + options: Inductor compilation options + + Returns: + CompiledArtifact that can be saved to disk or invoked directly. 
+ """ + from .standalone_compile import standalone_compile + + options = options if options else {} + return standalone_compile( + gm, example_inputs, dynamic_shapes=dynamic_shapes, options=options, aot=aot + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/analyze_preserves_zero_mask.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/analyze_preserves_zero_mask.py new file mode 100644 index 0000000000000000000000000000000000000000..0674d1566c33b46ba439e821ddd3ca9784c84b31 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/analyze_preserves_zero_mask.py @@ -0,0 +1,166 @@ +import dataclasses +import itertools +from typing import Any, Optional, TYPE_CHECKING + +import sympy + +import torch +from torch._inductor import config +from torch._inductor.dtype_propagation import DtypePropagationOpsHandler +from torch._inductor.index_propagation import SymPyOps, TypedExpr + +from .ops_handler import DefaultHandler +from .virtualized import StoreMode, V + + +if TYPE_CHECKING: + from torch._inductor.scheduler import SchedulerNode + + +def construct_symbol(count: int, dtype: torch.dtype) -> sympy.Symbol: + return sympy.Symbol(f"unknown_{count}") + + +class PreservesZeros(SymPyOps, DefaultHandler): + """ + For prologue kernels where the loads are masked, does the final store of this kernel preserve + the zeros. 
+ """ + + def __init__(self) -> None: + self.count = itertools.count(0) + self.store_preserves_zeros: Optional[bool] = None + self.dtype_prop = DtypePropagationOpsHandler() + + def load(self, name: str, index: sympy.Expr) -> TypedExpr: + # In prologue fusion, all loads get broadcasted + dtype = self.dtype_prop.load(name, index) + return TypedExpr( + sympy.Float(0) if dtype.is_floating_point else sympy.Integer(0), dtype + ) + + def store( + self, name: str, index: sympy.Expr, value: TypedExpr, mode: "StoreMode" = None + ) -> None: + assert isinstance(self, PreservesZeros) + # should only have a single store in prologue + assert self.store_preserves_zeros is None + self.store_preserves_zeros = value.is_constant() and value.expr == 0 + + def indirect_indexing(self, *args: Any, **kwargs: Any) -> sympy.Expr: + return construct_symbol(next(self.count), torch.int32) + + def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: + from torch._inductor.codegen.common import OpDecompositions + + if hasattr(OpDecompositions, name): + return getattr(OpDecompositions, name)(*args, **kwargs).value + + dtype = getattr(self.dtype_prop, name)(*args, **kwargs) + return TypedExpr(construct_symbol(next(self.count), dtype), dtype) + + +def prologue_preserves_zero_mask(prologue: "SchedulerNode") -> bool: + """ + Does this prologue preserve zero masks + """ + preserves_zeros = PreservesZeros() + with V.set_ops_handler(preserves_zeros): + prologue._body(*prologue.get_ranges()) + + store_preserves_zeros = preserves_zeros.store_preserves_zeros + assert isinstance(store_preserves_zeros, bool) + + return store_preserves_zeros + + +@dataclasses.dataclass +class DTypeContainer: + dtype: torch.dtype + is_scalar: bool = False + + +class RecordLowPrecisionOps(DefaultHandler): + def __init__(self, disallow_fp32_ops: bool = False) -> None: + self.disallow_fp32_ops = disallow_fp32_ops + self.low_precision_numeric_op = False + self.dtype_prop = DtypePropagationOpsHandler() + 
self.non_numeric_ops = ( + "to_dtype", + "constant", + "where", + ) + + def load(self, name: str, index: sympy.Expr) -> DTypeContainer: + return DTypeContainer(self.dtype_prop.load(name, index)) + + @staticmethod + def store( + name: str, index: sympy.Expr, value: TypedExpr, mode: "StoreMode" = None + ) -> None: + pass + + def check_bounds( + self, expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool + ) -> None: + pass + + @staticmethod + # pyrefly: ignore [bad-override] + def indirect_indexing(*args: Any, **kwargs: Any) -> sympy.Expr: + return sympy.S.Zero + + def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: + out_dtype = getattr(self.dtype_prop, name)(*args, **kwargs) + out = DTypeContainer(out_dtype, is_scalar=(name == "constant")) + if name == "constant": + return DTypeContainer(torch.float, is_scalar=True) + + uses_low_prec = any( + isinstance(dtype_cont, DTypeContainer) + and dtype_cont.dtype is not None + and low_prec_float(dtype_cont.dtype) + for dtype_cont in itertools.chain((out,), args, kwargs.values()) + ) + + if uses_low_prec and name not in self.non_numeric_ops: + self.low_precision_numeric_op = True + + if ( + self.disallow_fp32_ops + and out.dtype in (torch.float32, torch.float64) + and not out.is_scalar + ): + self.low_precision_numeric_op = True + + return out + + +def low_prec_float(dtype: torch.dtype) -> bool: + return dtype.is_floating_point and dtype.itemsize < 4 + + +def can_codegen_without_upcasts( + prologue: "SchedulerNode", + disallow_fp32_ops: bool = False, +) -> bool: + """ + Can this prologue be run without `upcast_to_fp32` while preserving numerics. + + This is only true if the node only contains dtype conversions, indexing, and other non-arithmetic operators. + + If disallow_fp32_ops is True, then we also disallow ops that are explicitly computed in fp32 or fp64. 
+ """ + if prologue.get_operation_names() <= V.graph.low_precision_codegen_ops: + return True + + low_prec_analysis = RecordLowPrecisionOps(disallow_fp32_ops) + + # Need to turn off upcasting to do analysis of whether we can turn it off + with ( + config.patch("triton.codegen_upcast_to_fp32", False), + V.set_ops_handler(low_prec_analysis), + ): + prologue._body(*prologue.get_ranges()) + + return not low_prec_analysis.low_precision_numeric_op diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/aoti_eager.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/aoti_eager.py new file mode 100644 index 0000000000000000000000000000000000000000..991f1caaecbb9b0b6da39b41c96a34a7590deffa --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/aoti_eager.py @@ -0,0 +1,299 @@ +import json +import logging +import os +from collections.abc import Callable +from pathlib import Path +from typing import Any, Optional +from unittest import mock + +import torch +import torch._export +from torch._inductor.utils import is_cpu_device + +from .runtime.runtime_utils import cache_dir + + +log = logging.getLogger(__name__) + + +def aoti_eager_cache_dir(namespace: str, device: str) -> Path: + return Path(cache_dir()) / "aoti_eager" / namespace / device + + +def aoti_eager_op_conf_lock(op_func_name_with_overload: str) -> Any: + # Avoid circular import + from torch._inductor.codecache import get_lock_dir, LOCK_TIMEOUT + from torch.utils._filelock import FileLock + + op_conf_lock_file = f"{op_func_name_with_overload}.lock" + lock_dir = get_lock_dir() + return FileLock(os.path.join(lock_dir, op_conf_lock_file), timeout=LOCK_TIMEOUT) + + +def load_aoti_eager_cache( + ns: str, op_func_name_with_overload: str, device_type: str +) -> list[Optional[dict[str, Any]]]: + device_kernel_cache = aoti_eager_cache_dir(ns, device_type) + op_conf = device_kernel_cache / f"{op_func_name_with_overload}.json" + if 
not op_conf.exists(): + return [] + + try: + with aoti_eager_op_conf_lock(op_func_name_with_overload): + with open(op_conf) as f: + json_data = json.load(f) + for item in json_data: + # Get absolution path for kernel library + kernel_lib_abs_path = device_kernel_cache / item["kernel_path"] + item["kernel_path"] = kernel_lib_abs_path.as_posix() + + # Check if the kernel library exists + if not kernel_lib_abs_path.exists(): + return [] + + for metadata in item["meta_info"]: + if metadata.get("is_dynamic"): + raise NotImplementedError( + "Only support static shape for now" + ) + if ( + "device_type" in metadata + and metadata["device_type"] == "cpu" + ): + metadata["device_index"] = -1 + for dtype_key in ["dtype", "dtype_value"]: + if dtype_key in metadata: + metadata[dtype_key] = getattr( + torch, metadata[dtype_key].split(".")[-1] + ) + if "layout_value" in metadata: + metadata["layout_value"] = getattr( + torch, metadata["layout_value"].split(".")[-1] + ) + if "memory_format_value" in metadata: + metadata["memory_format_value"] = getattr( + torch, metadata["memory_format_value"].split(".")[-1] + ) + + return json_data + except Exception as e: + err_msg = f"Failed to load aoti eager cache: {e}" + log.exception(err_msg) + return [] + + +def supported_builtin_dtype_torch_dtype() -> dict[type, torch.dtype]: + return {int: torch.int32, float: torch.float, bool: torch.bool} + + +def supported_scalar_types() -> tuple[type, ...]: + type_to_torch_dtype = supported_builtin_dtype_torch_dtype() + return tuple(type_to_torch_dtype.keys()) + + +def extract_tensor_metadata(dynamic: bool, input: torch.Tensor) -> dict[str, Any]: + metadata: dict[str, Any] = {} + metadata["is_dynamic"] = dynamic + + assert isinstance(input, torch.Tensor) + metadata["device_type"] = f"{input.device.type}" + if is_cpu_device([input]): + metadata["device_index"] = -1 + else: + metadata["device_index"] = input.device.index + metadata["dtype"] = f"{input.dtype}" + metadata["sizes"] = list(input.size()) + 
metadata["strides"] = list(input.stride()) + metadata["requires_grad"] = input.requires_grad + metadata["dispatch_key_set"] = torch._C._dispatch_keys(input).raw_repr() + return metadata + + +def extract_tensor_list_metadata( + dynamic: bool, + input: list[torch.Tensor], +) -> dict[str, Any]: + metadata_list = [] + for item in input: + assert isinstance(item, torch.Tensor) + metadata_list.append(extract_tensor_metadata(dynamic, item)) + + metadata: dict[str, Any] = {} + metadata["tensor_list"] = metadata_list + return metadata + + +def extract_scalar_metadata(device_type: str, input: Any) -> dict[str, Any]: + assert isinstance(input, supported_scalar_types()) + metadata: dict[str, Any] = {} + metadata["is_dynamic"] = False + # Scalar tensor + metadata["device_type"] = device_type + metadata["device_index"] = -1 if device_type == "cpu" else 0 + type_to_torch_dtype = supported_builtin_dtype_torch_dtype() + metadata["dtype"] = f"{type_to_torch_dtype[type(input)]}" + metadata["scalar_value"] = input + return metadata + + +def extract_string_metadata(input: str) -> dict[str, Any]: + assert isinstance(input, str) + metadata: dict[str, Any] = {} + metadata["string_value"] = input + return metadata + + +def extract_dtype_metadata(input: torch.dtype) -> dict[str, Any]: + assert isinstance(input, torch.dtype) + metadata: dict[str, Any] = {} + metadata["dtype_value"] = f"{input}" + return metadata + + +def extract_device_metadata(input: torch.device) -> dict[str, Any]: + assert isinstance(input, torch.device) + metadata: dict[str, Any] = {} + metadata["device_type_value"] = f"{input.type}" + metadata["device_index_value"] = input.index + return metadata + + +def extract_layout_metadata(input: torch.layout) -> dict[str, Any]: + assert isinstance(input, torch.layout) + metadata: dict[str, Any] = {} + metadata["layout_value"] = f"{input}" + return metadata + + +def aoti_compile_with_persistent_cache( + ns: str, + op_func_name_with_overload: str, + device_type: str, + dynamic: 
bool, + f: Callable[..., Any], + args: tuple[Any], + kwargs: dict[str, Any], + *, + dynamic_shapes: Optional[dict[str, Any]] = None, + options: Optional[dict[str, Any]] = None, + remove_runtime_assertions: bool = False, + disable_constraint_solver: bool = False, +) -> str: + """ + Compile the given function with persistent cache for AOTI eager mode. + """ + assert not dynamic, "Only support static shape for now" + flattened_inputs = list(args) + list(kwargs.values()) + if not all( + isinstance( + input, + ( + supported_scalar_types(), + torch.Tensor, + list, + str, + torch.dtype, + torch.device, + torch.layout, + ), + ) + for input in flattened_inputs + ): + err_msg = f"Unsupported input types: {flattened_inputs}" + log.exception(err_msg) + raise NotImplementedError(err_msg) + + for input in flattened_inputs: + if isinstance(input, list) and not all( + isinstance(item, torch.Tensor) for item in input + ): + err_msg = f"_impl_with_aoti_compile encounters unsupported input types: {flattened_inputs}" + log.exception(err_msg) + raise NotImplementedError(err_msg) + + persistent_cache = aoti_eager_cache_dir(ns, device_type) + if not persistent_cache.exists(): + persistent_cache.mkdir(parents=True) + + persistent_cache_lib = persistent_cache / "lib" + if not persistent_cache_lib.exists(): + persistent_cache_lib.mkdir() + + with mock.patch.dict( + os.environ, + {"TORCHINDUCTOR_CACHE_DIR": persistent_cache_lib.absolute().as_posix()}, + ): + try: + kernel_lib_path = torch._export.aot_compile( + f, + args, + kwargs, + dynamic_shapes=dynamic_shapes, + remove_runtime_assertions=remove_runtime_assertions, + disable_constraint_solver=disable_constraint_solver, + # Some operations may have non-Tensor parameters like int, float, bool. These + # non-Tensor parameters will not be the input of the graph. Therefore, we do + # need to keep the same signature. 
+ same_signature=False, + ) + assert isinstance(kernel_lib_path, str) + + kernel_metadata_items = [] + + for idx, input in enumerate(flattened_inputs): + if isinstance(input, torch.Tensor): + metadata = extract_tensor_metadata(dynamic, input) + elif isinstance(input, list): + assert all(isinstance(item, torch.Tensor) for item in input) + metadata = extract_tensor_list_metadata(dynamic, input) + elif isinstance(input, supported_scalar_types()): + metadata = extract_scalar_metadata(device_type, input) + elif isinstance(input, str): + metadata = extract_string_metadata(input) + elif isinstance(input, torch.dtype): + metadata = extract_dtype_metadata(input) + elif isinstance(input, torch.device): + metadata = extract_device_metadata(input) + elif isinstance(input, torch.layout): + metadata = extract_layout_metadata(input) + else: + raise NotImplementedError(f"Unsupported input type: {type(input)}") + + metadata["arg_order"] = idx + kernel_metadata_items.append(metadata) + + kernel_meta_info: dict[str, Any] = {} + kernel_meta_info["meta_info"] = kernel_metadata_items + kernel_meta_info["kernel_path"] = ( + Path(kernel_lib_path).relative_to(persistent_cache).as_posix() + ) + + json_data = [] + update_json = True + op_conf = persistent_cache / f"{op_func_name_with_overload}.json" + mode = "r" if op_conf.exists() else "w" + with aoti_eager_op_conf_lock(op_func_name_with_overload): + with open(op_conf, mode) as op_conf_file: + try: + json_data = json.load(op_conf_file) + except Exception: + json_data = [] + + assert isinstance(json_data, list) + for item in json_data: + assert isinstance(item, dict) + # Same kernel meta info already exists in the json file + if item["meta_info"] == kernel_metadata_items: + update_json = False + break + + if update_json: + json_data.append(kernel_meta_info) + with open(op_conf, "w") as op_conf_file: + json.dump(json_data, op_conf_file, indent=4) + + return kernel_lib_path + except Exception as e: + err_msg = f"Failed to compile 
{op_func_name_with_overload}: {e}" + log.exception(err_msg) + return "" diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/async_compile.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/async_compile.py new file mode 100644 index 0000000000000000000000000000000000000000..5ede0cd085010af4596335c103f5bdee4f0159bc --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/async_compile.py @@ -0,0 +1,705 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import atexit +import functools +import json +import logging +import multiprocessing +import os +import re +import sys +from concurrent.futures import Future, ThreadPoolExecutor +from concurrent.futures.process import BrokenProcessPool +from functools import partial +from time import time, time_ns +from typing import Any, Optional, TYPE_CHECKING + +import torch +from torch._dynamo.device_interface import get_registered_device_interfaces +from torch._dynamo.utils import ( + counters, + dynamo_timed, + get_metrics_context, + set_feature_use, +) +from torch._inductor import config +from torch._inductor.codecache import ( + _load_triton_kernel_from_source, + code_hash, + CodeCacheFuture, + CppCodeCache, + CppPythonBindingsCodeCache, + CUDACodeCache, + HalideCodeCache, + LambdaFuture, + ROCmCodeCache, + StaticAutotunerFuture, + torch_key, +) +from torch._inductor.compile_worker.subproc_pool import ( + AnyPool, + SubprocException, + SubprocPool, +) +from torch._inductor.compile_worker.tracked_process_pool import ( + TrackedProcessPoolExecutor, +) +from torch._inductor.compile_worker.utils import _async_compile_initializer +from torch._inductor.runtime.compile_tasks import ( + _set_triton_ptxas_path, + _worker_compile_triton, +) +from torch._inductor.utils import clear_on_fresh_cache +from torch._inductor.virtualized import V +from torch._utils_internal import log_triton_builds +from torch.hub import 
_Faketqdm, tqdm +from torch.utils._ordered_set import OrderedSet +from torch.utils._triton import has_triton_package + + +if TYPE_CHECKING: + from collections.abc import Callable + + from torch._inductor.runtime.hints import HalideMeta + from torch._inductor.runtime.triton_heuristics import CachingAutotuner + +# timing metrics for time spent in the compilation +_cumulative_compile_time = 0.0 +_t0: Optional[float] = None + +kernel_code_log = torch._logging.getArtifactLogger(__name__, "kernel_code") + +log = logging.getLogger(__name__) + +_triton_kernel_metrics: Optional[dict[str, dict[str, Any]]] = None + +size_hints_regex = re.compile( + r"size_hints=(\{.*?\})", +) + + +def pre_fork_setup(): + """ + Setup that must be done prior to forking with a process pool. + """ + # ensure properties have been calculated before processes + # are forked + caching_device_properties() + + # Computing the triton key can be slow. If we call it before fork, + # it will be cached for the forked subprocesses. 
+ from torch._inductor.runtime.triton_compat import HAS_TRITON, triton_key + + if HAS_TRITON: + triton_key() + + +def caching_device_properties(): + for _, device_interface in get_registered_device_interfaces(): + if device_interface.is_available(): + device_interface.Worker.get_device_properties() + + +def _compile_start() -> None: + global _t0, _triton_kernel_metrics + if _t0 is None: + _t0 = time() + if _triton_kernel_metrics is None: + _triton_kernel_metrics = {} + + +def _compile_end() -> None: + global _cumulative_compile_time, _t0, _triton_kernel_metrics + if _t0 is not None: + t1 = time() + _cumulative_compile_time += t1 - _t0 + _t0 = None + # print("CUMULATIVE COMPILE TIME", _cumulative_compile_time) + if _triton_kernel_metrics: + # Log triton kernel info + sorted_info = dict(sorted(_triton_kernel_metrics.items())) + torch._logging.trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "triton_kernel_info", + "encoding": "json", + }, + payload_fn=lambda: json.dumps(sorted_info), + ) + _triton_kernel_metrics = None + + +def _add_triton_kernel_info(kernel_name: str, info: dict[str, Any]): + global _triton_kernel_metrics + # Must be called between _compile_start and _compile_end + if _triton_kernel_metrics is not None: + _triton_kernel_metrics[kernel_name] = info + + +_IS_WINDOWS = sys.platform == "win32" + +log = logging.getLogger(__name__) + +# Used to keep track of all process pools invoked so far. 
+_pool_set = OrderedSet[AnyPool]() + + +def shutdown_compile_workers() -> None: + """Shut down all outstanding compile-worker pools.""" + for pool in _pool_set: + pool.shutdown() + AsyncCompile._ready_future = None + after_fork() + + +def after_fork(): + """Reset pools to initial state without shutting them down""" + _pool_set.clear() + AsyncCompile.process_pool.cache_clear() + + +try: + os.register_at_fork(after_in_child=after_fork) +except AttributeError: + pass # register_at_fork does not exists on windows + + +def get_compile_threads() -> int: + """ + Temporary for internal rollout. Assign config.compile_threads lazily and return it. + TODO: remove after rollout. + """ + if config.compile_threads is None: + config.compile_threads = config.decide_compile_threads() + return config.compile_threads + + +@clear_on_fresh_cache +class CompiledTritonKernels: + """ + In memory cache for storing compiled triton kernels. + + Each triton kernel is keyed by the hash of its source code. Each value stored + in the cache is a return value of AsyncCompile.triton(). + + Currently, the cache stores Future objects, but it should be generalizable for any kernels. + """ + + _cache: dict[str, CodeCacheFuture] = {} + + @staticmethod + def key(kernel_src: str): + """ + Generates a cache key given a triton kernel's full source code. + This source includes the inductor meta, compilation metadata, the kernel itself, etc. + `kernel_src` should be the exact string passed to async_compile.triton()'s first argument. + """ + # Hashes the kernel source with torch_key into a single hash key + return code_hash(kernel_src, extra=torch_key()) + + @staticmethod + def save(kernel_src: str, future: CodeCacheFuture): + """ + Saves a compiled triton kernel to the cache. + TODO: We store a LambdaFuture as that's the callable returned by async_compile.triton, + but the real type we want to return here is actually an abstract triton kernel. 
+ + TODO: Source code here is not just the kernel's source code, but also includes the inductor preamble, etc. + so it could be less strict. + """ + key = CompiledTritonKernels.key(kernel_src) + CompiledTritonKernels._cache[key] = future + + @staticmethod + def get(kernel_src: str) -> Optional[CodeCacheFuture]: + key = CompiledTritonKernels.key(kernel_src) + return CompiledTritonKernels._cache.get(key, None) + + @staticmethod + def cache_clear(): + CompiledTritonKernels._cache = {} + + @staticmethod + def remove_future(kernel_src: str) -> None: + key = CompiledTritonKernels.key(kernel_src) + + # Delete the LambdaFuture if there is one + if key in CompiledTritonKernels._cache: + del CompiledTritonKernels._cache[key] + + +class AsyncCompile: + """ + Utilities to compile in thread pools or subprocess pools (in the case of Triton). + """ + + _ready_future: Optional[Future[Any]] = None + + def __init__(self) -> None: + pass + + @staticmethod + @functools.lru_cache(1) + def pool() -> ThreadPoolExecutor: + assert get_compile_threads() > 1 + return ThreadPoolExecutor(get_compile_threads()) + + @staticmethod + def _get_ready(): + """No-op function to help mark when the subprocess pool is ready.""" + return "ready" + + @staticmethod + @functools.lru_cache(1) + def process_pool() -> AnyPool: + assert get_compile_threads() > 1 + AsyncCompile._ready_future = None + log.info( + "Creating '%s' pool with %d workers", + config.worker_start_method, + get_compile_threads(), + ) + + pool: AnyPool + if config.worker_start_method == "subprocess": + # Wrapper around ProcessPoolExecutor forks in a new process we control + pool = SubprocPool( + get_compile_threads(), quiesce=config.quiesce_async_compile_pool + ) + else: + if config.worker_start_method == "spawn": + # Avoid creating pools in the spawned subprocs themselves: + os.environ["TORCH_WARM_POOL"] = "0" + pre_fork_setup() + ctx = multiprocessing.get_context(config.worker_start_method) + pool = TrackedProcessPoolExecutor( + 
get_compile_threads(), + mp_context=ctx, + initializer=partial(_async_compile_initializer, os.getpid()), + ) + # when this pool is created in a subprocess object, the normal exit handler + # doesn't run, and we need to register our own handler. + # exitpriority has to be high, because another one of the finalizers will + # kill the worker thread that sends the shutdown message to the workers... + multiprocessing.util.Finalize(None, pool.shutdown, exitpriority=sys.maxsize) + + _pool_set.add(pool) + return pool + + @classmethod + def warm_pool(cls) -> None: + if get_compile_threads() <= 1: + return + _compile_start() + # Pool is created on first access. Note for a SubprocPool, the sidecar process starts, + # but its ProcessPoolExecutor does not initialize until a wakeup() call or the first + # job is submitted. + cls.process_pool() + _compile_end() + + @classmethod + def wait_pool_ready(cls, timeout=120) -> None: + cls.use_process_pool() + if cls._ready_future is not None: + cls._ready_future.result(timeout=timeout) + + @classmethod + def submit(cls, task: Callable[..., Any]) -> Any: + if get_compile_threads() <= 1: + return task() + return cls.pool().submit(task) + + @classmethod + def use_process_pool(cls): + if get_compile_threads() <= 1: + return False + + # Create a dummy job to check if the pool is ready. Submit it here instead of at + # pool creation so we don't launch the full pool of worker subprocesses until + # we're sure they're needed. + if not cls._ready_future: + cls._ready_future = cls.process_pool().submit(cls._get_ready) + return cls._ready_future.done() + + @classmethod + def wakeup(cls) -> None: + """ + If using a SubprocPool, signal the sidecar process to start up its + ProcessPoolExecutor. 
+ """ + if not cls.use_process_pool(): + return + pool = cls.process_pool() + if isinstance(pool, SubprocPool): + pool.wakeup() + + def triton(self, kernel_name: str, source_code: str, device_str: str = "cuda"): + """ + Async_compile.triton is more complicated than the other backends because + we're trying to optimize compile time as much as possible for this hot callsite. + + First of all, the function is cached by CompiledTritonKernels; if there's a kernel + already compiled, we grab it directly from the cache and return. + + Otherwise, if we have multiple compile threads, we kick off triton compilations on each + worker process by giving it a kernel and source code to compile. The worker initializes + a CachingAutotuner, runs triton compilation, and pickles the kernel back to us. + We use TritonCompileResult to represent the objects being pickled back to us by each + worker. + + Some maybe not obvious things that are pickled back to us: + - Most of the time, we can avoid sending back CachingAutotuner.fn and other metadata + and do not have to pay the cost of loading the triton kernel on the parent. But certain + cases, like coordesc tuning and dynamic_scale_rblock, require us to reload the function + in the parent lazily when we require it. + - The AutotuneCache, if enabled, is constructed on each worker per triton config + and pickled by to us via `CachingAutotuner.save_cache_hook`. 
+ """ + load_kernel = functools.partial( + _load_triton_kernel_from_source, kernel_name, source_code + ) + + def reload_kernel_in_parent(): + # Benchmark how often this happens + with dynamo_timed("reload_kernel_in_parent"): + return load_kernel() + + counters["inductor"]["async_compile_cache_miss"] += 1 + + kernel_code_log.info("Triton Kernel:\n%s", source_code) + _compile_start() + + if os.environ.get("TRITON_INTERPRET", "0") == "1": + return getattr( + torch._inductor.codecache.PyCodeCache.load(source_code), kernel_name + ) + + is_parallel = self.use_process_pool() + set_feature_use("parallel_compile_post_warmup", is_parallel) + + compile_id = torch._guards.CompileContext.current_compile_id() + is_backward = getattr(V.graph, "is_backward", False) + + if (future := CompiledTritonKernels.get(source_code)) is not None: + counters["inductor"]["async_compile_cache_hit"] += 1 + # Set reload_kernel_from_src properly based on source_code + if isinstance(future, StaticAutotunerFuture): + # Remove the future now that we've cache hit + CompiledTritonKernels.remove_future(source_code) + future.reload_kernel_from_src = reload_kernel_in_parent + if is_parallel: + return future + else: + return future.result() + + # Cache miss + if is_parallel: + # We want to support changing these env vars after (and while) the + # process pool is running, so pass them to the subprocess to reset. 
+ env_vars = ["TORCHINDUCTOR_CACHE_DIR", "TRITON_CACHE_DIR"] + extra_env = {v: os.environ[v] for v in env_vars if v in os.environ} + extra_config = { + "use_static_cuda_launcher": torch._inductor.config.use_static_cuda_launcher + } + + if len(torch._inductor.config.autotune_lookup_table) > 0: + m = size_hints_regex.search(source_code) + if m: + size_hints_str = m.group(1) + else: + size_hints_str = str(None) + + triton_src = source_code.split("@triton.jit\n")[1] + from torch._inductor.runtime.triton_heuristics import ( + generate_lookup_hash_from_source_code, + ) + + fn_hash = generate_lookup_hash_from_source_code( + size_hints_str, triton_src + ) + + if fn_hash in torch._inductor.config.autotune_lookup_table: + extra_config["autotune_lookup_table"] = { # type: ignore[assignment] + fn_hash: torch._inductor.config.autotune_lookup_table[fn_hash] + } + + task = self.process_pool().submit( + _worker_compile_triton, + load_kernel, + extra_env, + extra_config, + ) + + def get_result() -> CachingAutotuner: + try: + kernel, elapsed_us = task.result() + except SubprocException as e: + raise e.with_name(kernel_name) from e + + # Now that we've compiled, we should clear the future + # so it can't be used again + kernel.set_compile_info(compile_id, is_backward) + CompiledTritonKernels.remove_future(source_code) + + kernel.restore_after_unpickle(old_values=None) + + kernel.precompile( + warm_cache_only=False, + reload_kernel=reload_kernel_in_parent, + static_triton_bundle_key=CompiledTritonKernels.key(source_code), + ) + info = kernel.autotune_cache_info or {} + info["compile_time_us"] = elapsed_us + _add_triton_kernel_info(kernel_name, info) + get_metrics_context().add_top_n( + "triton_kernel_compile_times_us", kernel_name, elapsed_us + ) + return kernel + + future = LambdaFuture(get_result, future=task) + CompiledTritonKernels.save(source_code, future) + return future + else: + with dynamo_timed( + "async_compile.precompile", + log_pt2_compile_event=True, + 
dynamo_compile_column_us="triton_compile_time_us", + log_waitcounter=True, + waitcounter_name_override="compile_triton", + ): + fail = None + try: + start_ns = time_ns() + _set_triton_ptxas_path() + kernel = load_kernel() + kernel.set_compile_info(compile_id, is_backward) + kernel.precompile( + warm_cache_only=False, + static_triton_bundle_key=CompiledTritonKernels.key(source_code), + ) + elapsed_us = (time_ns() - start_ns) // 1000 + get_metrics_context().add_top_n( + "triton_kernel_compile_times_us", kernel_name, elapsed_us + ) + info = kernel.autotune_cache_info or {} + info["compile_time_us"] = elapsed_us + _add_triton_kernel_info(kernel_name, info) + return kernel + except Exception as e: + fail = str(e) + raise + finally: + log_triton_builds(fail=fail) + + def multi_kernel(self, *args, **kwargs) -> Any: + from torch._inductor.codegen.multi_kernel import MultiKernelCall + + # no need to call this in parallel since the sub-kernels are already parallel tasks + return MultiKernelCall(*args, **kwargs) + + def size_hint_multi_kernel(self, *args, **kwargs) -> Any: + from torch._inductor.codegen.multi_kernel import SizeHintMultiKernelCall + + return SizeHintMultiKernelCall(*args, **kwargs) + + def cpp(self, source_code: str): + kernel_code_log.info("CPP Kernel:\n%s", source_code) + if get_compile_threads() <= 1: + return CppCodeCache.load(source_code).kernel + else: + get_result = CppCodeCache.load_async(source_code, submit_fn=self.submit) + return LambdaFuture(lambda: get_result().kernel) + + def cpp_pybinding(self, argtypes: list[str], source_code: str): + kernel_code_log.info("CPP+Bindings Kernel:\n%s", source_code) + if get_compile_threads() <= 1: + return CppPythonBindingsCodeCache.load_pybinding(argtypes, source_code) + else: + get_result = CppPythonBindingsCodeCache.load_pybinding_async( + argtypes, source_code, submit_fn=self.submit + ) + return LambdaFuture(get_result) + + def cuda(self, source_code, dst_file_ext, aot_compile=False): + 
kernel_code_log.info("CUDA Kernel:\n%s", source_code) + + def task(): + if aot_compile: + # We rely on JITInductor to compile the CUDA code, + # so that we can load it into AOTInductor. + output_path, *_ = CUDACodeCache.compile(source_code, "o") + CUDACodeCache.aot_kernels_o.append(output_path) + return CUDACodeCache.load(source_code, dst_file_ext)[0] + + return self.submit(task) + + def rocm( + self, + source_code, + dst_file_ext, + aot_compile=False, + ): + kernel_code_log.info("ROCm Kernel:\n%s", source_code) + + def task(): + if aot_compile: + output_path, *_ = ROCmCodeCache.compile(source_code, dst_file_ext="o") + ROCmCodeCache.aot_kernels_o.append(output_path) + if config.rocm.generate_test_runner: + _ = ROCmCodeCache.compile(source_code, dst_file_ext="exe") + return ROCmCodeCache.load(source_code, dst_file_ext)[0] + + return self.submit(task) + + def halide(self, meta: HalideMeta, source_code: str): + kernel_code_log.info("Halide Kernel:\n%r\n%s", meta, source_code) + if get_compile_threads() <= 1: + return HalideCodeCache.generate_halide(meta, source_code) + else: + get_result = HalideCodeCache.generate_halide_async( + meta, source_code, submit_fn=self.submit + ) + return LambdaFuture(get_result) + + def cutedsl(self, kernel_name: str, source_code: str): + """ + Compile CuteDSL (CUTLASS Python DSL) kernels. + + Args: + kernel_name: Name of the kernel to be defined + source_code: Source code of the CuteDSL kernel, as a string + + Note: + CuteDSL currently requires source files to do its compilation, there we + use the PyCodeCache to write the source code to a file and load it. 
+ """ + from torch._inductor.codegen.cutedsl.cutedsl_kernel import ( + CuteDSLKernelWrapper, + MAIN_SUFFIX, + ) + + kernel_code_log.info("CuteDSL Kernel:\n%s", source_code) + + def task(): + key, path = torch._inductor.codecache.PyCodeCache.write(source_code) + mod = torch._inductor.codecache.PyCodeCache.load_by_key_path(key, path) + + # Find our special entry point named function + main_func_name = f"{kernel_name}_{MAIN_SUFFIX}" + if not hasattr(mod, main_func_name): + available = [name for name in dir(mod) if callable(getattr(mod, name))] + raise RuntimeError( + f"Could not find CuteDSL main kernel function '{main_func_name}'. Available callables: {available}" + ) + + return CuteDSLKernelWrapper(getattr(mod, main_func_name), kernel_path=path) + + if get_compile_threads() <= 1: + return task() + else: + future = self.submit(task) + return LambdaFuture(lambda: future.result()) + + def pallas(self, kernel_name: str, source_code: str): + """ + Compile Pallas (JAX experimental) kernels. + + Args: + kernel_name: Name of the kernel to be defined + source_code: Source code of the Pallas kernel, as a string + + Note: + Pallas kernels are Python code that uses JAX and Pallas APIs. + We use the PyCodeCache to write the source code to a file and load it. + """ + from torch._inductor.codegen.pallas import MAIN_SUFFIX, PallasKernelWrapper + + kernel_code_log.info("Pallas Kernel:\n%s", source_code) + + def task(): + key, path = torch._inductor.codecache.PyCodeCache.write(source_code) + mod = torch._inductor.codecache.PyCodeCache.load_by_key_path(key, path) + + # Find our special entry point named function + main_func_name = f"{kernel_name}_{MAIN_SUFFIX}" + if not hasattr(mod, main_func_name): + available = [name for name in dir(mod) if callable(getattr(mod, name))] + raise RuntimeError( + f"Could not find Pallas main kernel function '{main_func_name}'. 
Available callables: {available}" + ) + + return PallasKernelWrapper(getattr(mod, main_func_name), kernel_path=path) + + if get_compile_threads() <= 1: + return task() + else: + future = self.submit(task) + return LambdaFuture(lambda: future.result()) + + def wait(self, scope: dict[str, Any]) -> None: + if get_compile_threads() > 1: + with dynamo_timed( + "async_compile.wait", + log_pt2_compile_event=True, + dynamo_compile_column_us="triton_compile_time_us", + log_waitcounter=True, + waitcounter_name_override="compile_triton", + ): + self._wait_futures(scope) + + _compile_end() + + def _wait_futures(self, scope: dict[str, Any]) -> None: + kernels = { + key: value + for key, value in scope.items() + if isinstance(value, (Future, CodeCacheFuture)) + } + pbar = tqdm( + total=len(kernels), + desc="Inductor Compilation", + disable=config.disable_progress, + delay=0, + ) + for key, result in kernels.items(): + if config.verbose_progress and not isinstance(pbar, _Faketqdm): + pbar.set_postfix_str(key) + try: + kernel = result.result() + scope[key] = kernel + except BrokenProcessPool as e: + raise RuntimeError( + "A compilation subprocess exited unexpectedly. This " + "is likely due to a crash. To facilitate debugging, " + "you can re-run with TORCHINDUCTOR_COMPILE_THREADS=1 " + "to cause compilation to occur in the main process." + ) from e + pbar.update(1) + + +def maybe_warm_pool() -> None: + if ( + os.environ.get("TORCH_TNT_IN_USE", "0") == "1" + or os.environ.get("TORCH_WARM_POOL", "1") != "1" + # The subprocess pool is only used for the Triton backend + or not has_triton_package() + # Skip for fbcode. We have internal reports of usages inside multiprocessing + # pools that lead a multiplicative number of compile subprocesses. + or config.is_fbcode() + ): + return + + AsyncCompile.warm_pool() + # TODO: This starts the SubprocPool's internal process pool as early as possible at + # the expense of creating a bunch of worker processes that might not be needed. 
We + # could start them lazily if we're willing to lose a small amount of compile time. + AsyncCompile.wakeup() + + +# On exit give the workers a chance to clean themselves up. Without this the +# resource_tracker can complain about leaked semaphores coming from the +# ProcessPoolExecutor: +# UserWarning: resource_tracker: There appear to be 5 leaked semaphore objects +# to clean up at shutdown +atexit.register(shutdown_compile_workers) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/augmented_graph_helper.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/augmented_graph_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..5a70a34f7b64b72d8e8d8e86523905b959bd1b0e --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/augmented_graph_helper.py @@ -0,0 +1,181 @@ +from collections import defaultdict +from typing import Optional + +import torch +import torch.fx as fx +from torch.utils._ordered_set import OrderedSet + + +class AugmentedGraphHelper: + """ + Graph helper that augments the original graph with additional + dependencies and uses, plus tracks node equivalences for coalescing. 
+ + TODO: if this becomes too large of compile time, consider binding + graphcycles.cc + """ + + def __init__( + self, + graph: fx.Graph, + node_ancestors: Optional[dict[fx.Node, OrderedSet[fx.Node]]] = None, + ): + # Each node starts in its own singleton set + self.graph = graph + self.merge_sets = {node: OrderedSet([node]) for node in graph.nodes} + + # Extra dependencies: node depends on dep (dep must come before node) + self.extra_deps: dict[fx.Node, OrderedSet[fx.Node]] = defaultdict(OrderedSet) + # Extra uses: reverse of extra_deps (node is used by user) + self.extra_uses: dict[fx.Node, OrderedSet[fx.Node]] = defaultdict(OrderedSet) + # Note: only reflect original ancestors, not maintained through additional deps + # or merge sets + self.node_ancestors = node_ancestors + + def add_extra_dep(self, *, n: fx.Node, dep: fx.Node) -> None: + """Add extra dependency: node depends on dep.""" + self.extra_deps[n].add(dep) + self.extra_uses[dep].add(n) + + def remove_extra_dep(self, *, n: fx.Node, dep: fx.Node) -> None: + if dep in self.extra_deps[n]: + self.extra_deps[n].discard(dep) + self.extra_uses[dep].discard(n) + + def merge_to_set(self, existing_node: fx.Node, new_node: fx.Node) -> None: + """ + Merge new_node into existing_node's set. The new node must be a singleton set. 
+ """ + existing_set = self.merge_sets[existing_node] + new_set = self.merge_sets[new_node] + assert len(new_set) == 1 + + # Add all nodes from new_set to existing_set + existing_set.update(new_set) + + # Update all nodes from new_set to point to existing_set + for node in new_set: + self.merge_sets[node] = existing_set + + def unmerge_node(self, node: fx.Node) -> None: + """Remove a node from its merge set, making it singleton.""" + old_set = self.merge_sets[node] + + # If already singleton, nothing to do + if len(old_set) == 1: + return + + # Remove from old set + old_set.remove(node) + + # Make node singleton + self.merge_sets[node] = OrderedSet([node]) + + def get_merged_deps(self, node: fx.Node) -> OrderedSet[fx.Node]: + """ + Get all dependencies of a node considering merges and extra deps. + Combines: + 1. Direct deps (all_input_nodes) of node and its merge equivalents + 2. Extra deps of node and its merge equivalents + """ + deps: OrderedSet[fx.Node] = OrderedSet() + + # For each node in the merge set + for merged_node in self.merge_sets[node]: + # Add direct dependencies from all_input_nodes + deps.update(merged_node.all_input_nodes) + # Add extra dependencies + deps.update(self.extra_deps[merged_node]) + + return deps + + def has_cycle(self) -> bool: + merged_deps = {n: self.get_merged_deps(n) for n in self.graph.nodes} + return torch._dynamo.graph_deduplication._has_cycle(self.graph, merged_deps) + + def has_path(self, source: fx.Node, target: fx.Node) -> bool: + """Check if there's a path from source to target.""" + # we should not be checking path from node to itself + assert self.merge_sets[source] is not self.merge_sets[target] + + # search backwards from target to source + visited: OrderedSet[fx.Node] = OrderedSet() + queue = [target] + visited.add(target) + + while queue: + current = queue.pop() + + for dep in self.get_merged_deps(current): + # Check if we reached source or its equivalent + if dep in self.merge_sets[source]: + return True + + if 
dep in visited: + continue + + # We are searching from target, so this node is necessarily an ancestor + # of target. + # If dep is an ancestor of source, any path through dep to source would imply a cycle + if self.node_ancestors: + source_set = self.merge_sets[source] + is_ancestor_of_source = any( + dep in self.node_ancestors[s] for s in source_set + ) + # Add to visited to avoid recomputing this check if we see dep again + if is_ancestor_of_source: + visited.add(dep) + continue + + visited.add(dep) + queue.append(dep) + + return False + + def transfer_erased_node_deps(self, erased_to_new: dict[fx.Node, fx.Node]) -> None: + """ + Transfer all extra dependencies from erased nodes to their replacements, handling + cross-dependencies between erased nodes correctly. + """ + erased_merge_sets: dict[fx.Node, fx.Node] = {} + + for replaced, new in erased_to_new.items(): + for equiv in self.merge_sets[replaced]: + erased_merge_sets[equiv] = new + + # Transfer dependencies + for old_node, new_node in erased_merge_sets.items(): + # Transfer dependencies FROM old_node (what old_node depended on) + for extra_dep in self.extra_deps[old_node]: + # Redirect if dep is also being erased + updated_dep = erased_merge_sets.get(extra_dep, extra_dep) + self.extra_deps[new_node].add(updated_dep) + self.extra_uses[updated_dep].discard(old_node) + self.extra_uses[updated_dep].add(new_node) + + # Transfer dependencies TO old_node (what depended on old_node) + for extra_use in self.extra_uses[old_node]: + # Redirect if this user is also being erased + updated_use = erased_merge_sets.get(extra_use, extra_use) + + # Update the user's deps to point to new_node + self.extra_deps[updated_use].discard(old_node) + self.extra_deps[updated_use].add(new_node) + self.extra_uses[new_node].add(updated_use) + + # Clean up erased nodes + for old_node in erased_merge_sets: + self.extra_deps[old_node].clear() + self.extra_uses[old_node].clear() + del self.merge_sets[old_node] + + def 
get_all_extra_deps(self) -> dict[fx.Node, OrderedSet[fx.Node]]: + """ + Get all extra dependencies in a format suitable for topological sort. + Returns a copy to avoid external modifications. + """ + return { + node: OrderedSet(deps) + for node, deps in self.extra_deps.items() + if deps # Only include nodes with non-empty deps + } diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/autotune_process.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/autotune_process.py new file mode 100644 index 0000000000000000000000000000000000000000..3b869ce8271c8c837c19bb79d442500057826df3 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/autotune_process.py @@ -0,0 +1,1041 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import atexit +import ctypes +import dataclasses +import functools +import logging +import os +import pickle +import queue +import selectors +import subprocess +import sys +import time +import warnings +from collections.abc import Callable, Iterable, Sequence +from concurrent.futures import ThreadPoolExecutor +from ctypes import byref, c_size_t, c_void_p, CDLL +from typing import Any, IO, Optional, TYPE_CHECKING, Union + +import torch +import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools +from torch._dynamo.device_interface import get_interface_for_device +from torch._dynamo.testing import rand_strided +from torch._inductor import ir +from torch._inductor.codecache import ( + CppCodeCache, + CUDACodeCache, + DLLWrapper, + get_hash, + PyCodeCache, +) +from torch._inductor.utils import ( + do_bench_using_profiling, + get_gpu_type, + get_ld_library_path, + is_gpu, + python_subprocess_env, +) +from torch._logging import getArtifactLogger +from torch.utils._ordered_set import OrderedSet + + +if TYPE_CHECKING: + from types import ModuleType + + from torch._inductor.select_algorithm import PartialRender, 
TritonTemplateCaller + +from . import config +from .runtime.benchmarking import benchmarker +from .virtualized import V + + +CUDA_VISIBLE_DEVICES = "CUDA_VISIBLE_DEVICES" + +autotuning_log = getArtifactLogger(__name__, "autotuning") + + +class NonzeroWorkspaceNotSupportedError(Exception): + pass + + +class TuningProcess: + """ + Class to launch and interact with a benchmarking subprocess. + """ + + @staticmethod + def process_main(read_pipe: IO[bytes], write_pipe: IO[bytes]) -> None: + """ + Entry point for the child process. + """ + autotuning_log.debug( + "Started autotune subprocess %s. Visible devices: %s", + os.getpid(), + os.environ.get(CUDA_VISIBLE_DEVICES), + ) + + def workloop(): + while True: + job, extra_env = TuningProcess.recv(read_pipe) + if job is None: + # None is a sentinel for the child to shut down + break + try: + if extra_env: + os.environ.update(extra_env) + result = job() + except Exception as e: + result = e + TuningProcess.send(result, write_pipe) + + try: + workloop() + except EOFError: + # The parent closed the pipe + pass + + @staticmethod + def send( + obj: Any, write_pipe: IO[bytes], extra_env: dict[str, str] | None = None + ) -> None: + pickle.dump((obj, extra_env), write_pipe) + write_pipe.flush() + + @staticmethod + def recv(read_pipe: IO[bytes]) -> Any: + return pickle.load(read_pipe) + + def __init__(self, device: Optional[int]): + self.device = device + self.start() + + def start(self): + """ + Start the benchmarking subprocess. 
+ """ + entry = os.path.join(os.path.dirname(__file__), "__autotune_main__.py") + + subproc_read_fd, write_fd = os.pipe() + read_fd, subproc_write_fd = os.pipe() + self.write_pipe = os.fdopen(write_fd, "wb") + self.read_pipe = os.fdopen(read_fd, "rb") + + self.selector = selectors.DefaultSelector() + self.selector.register(self.read_pipe, selectors.EVENT_READ) + + cmd = [ + sys.executable, + entry, + f"--parent={os.getpid()}", + f"--read-fd={str(subproc_read_fd)}", + f"--write-fd={str(subproc_write_fd)}", + ] + env = { + **python_subprocess_env(), + # We shouldn't be using the Triton async compile subprocess pool, + # but as a precaution set the env var that disables its creation. + "TORCH_WARM_POOL": "0", + # Some internal usages need a modified LD_LIBRARY_PATH. + "LD_LIBRARY_PATH": get_ld_library_path(), + # This will cause the subprocs to profile using the profiler. + "TORCHINDUCTOR_PROFILE_WITH_DO_BENCH_USING_PROFILING": "1" + if config.profile_bandwidth_with_do_bench_using_profiling + else "0", + } + if self.device is not None: + env[CUDA_VISIBLE_DEVICES] = str(self.device) + self.process = subprocess.Popen( + cmd, + env=env, + pass_fds=(subproc_read_fd, subproc_write_fd), + ) + os.close(subproc_read_fd) + os.close(subproc_write_fd) + + self.running = True + + def alive(self) -> bool: + """ + True if the subprocess is still running. + """ + return self.running and self.process.poll() is None + + def put(self, req: Any, extra_env: dict[str, str] | None = None) -> None: + """ + Push a work item to the child process. + """ + if not self.alive(): + self.start() + TuningProcess.send(req, self.write_pipe, extra_env=extra_env) + + def get(self, timeout: float = 120.0) -> Any: + """ + Get a response from the child process. Raises TimeoutError on timeout; + raises EOFError if the subprocess crashes. 
+ """ + try: + if not self.selector.select(timeout): + raise TimeoutError(f"Timeout in autotune subprocess {self.process.pid}") + result, _ = TuningProcess.recv(self.read_pipe) + except TimeoutError: + self.kill() + raise + except EOFError: + # The subprocess crashed + self.close() + raise + except Exception: + autotuning_log.exception( + "Unexpected exception in autotune subprocess %s", self.process.pid + ) + self.kill() + raise + + if isinstance(result, Exception): + raise result + return result + + def shutdown(self, wait: bool = True) -> None: + """ + Signal the child process to shut down gracefully. + """ + if self.alive(): + TuningProcess.send(None, self.write_pipe) + if wait: + self.wait() + + def wait(self) -> None: + """ + Wait for the child process to exit. + """ + if self.alive(): + self.process.wait() + self.close() + + def close(self) -> None: + """ + Close resources. + """ + self.selector.close() + self.read_pipe.close() + self.write_pipe.close() + self.running = False + + def kill(self) -> None: + """ + Send a SIGKILL to the child process. + """ + if self.alive(): + autotuning_log.error( + "Sending SIGKILL to autotune subprocess %d", + self.process.pid, + ) + self.process.kill() + self.close() + + def restart(self) -> None: + """ + Gracefully restarts the child process. + """ + self.shutdown(wait=True) + self.start() + + +class TuningProcessPool: + """ + Maintains a pool of TuningProcesses to benchmark kernels in parallel + across devices. By default, we create one TuningProcess per device and + set the sub-process environment to make only that device visible. + """ + + def __init__(self) -> None: + """ + Start the child processes. + """ + devices = self.get_device_list() + autotuning_log.debug("Sub-process autotune device list: %s", devices) + + # Launch the child processes. 
+ self.processes = [TuningProcess(device=device) for device in devices] + + self.process_queue: queue.Queue[TuningProcess] = queue.Queue() + for p in self.processes: + self.process_queue.put(p) + + # Use a thread pool to manage distributing work to the subprocesses. + # Threads block on an available process, so it makes sense to match + # the number of threads with the number of devices. + self.executor = ThreadPoolExecutor(max_workers=len(devices)) + + @staticmethod + def get_device_list() -> Sequence[Optional[int]]: + """ + Gather the list of devices to be used in the pool. + """ + if not config.autotune_multi_device: + # Don't use multiple devices + return [None] + + gpu_type = get_gpu_type() + device_interface = get_interface_for_device(gpu_type) + count = device_interface.device_count() + + # If the user specified the visible devices in the env, use those. + if CUDA_VISIBLE_DEVICES in os.environ: + devices = [int(d) for d in os.environ[CUDA_VISIBLE_DEVICES].split(",")] + assert len(devices) <= count + return devices + + return list(range(count)) + + def shutdown(self) -> None: + """ + Signal all child processes to exit. + """ + self.executor.shutdown() + + for p in self.processes: + p.shutdown(wait=False) + for p in self.processes: + p.wait() + + def target(self, choice: TritonTemplateCaller) -> float: + """ + Entry point for the thread-pool helper threads: Wait for an open TuningProcess, + remove it from the queue, execute the benchmark in that subprocess, and return + the TuningProcess to the queue. + """ + assert choice.bmreq is not None + + env_vars = ["TORCHINDUCTOR_CACHE_DIR", "TRITON_CACHE_DIR"] + extra_env = {v: os.environ[v] for v in env_vars if v in os.environ} + process = self.process_queue.get() + process.put(choice.bmreq.benchmark, extra_env=extra_env) + try: + return process.get( + config.max_autotune_subproc_result_timeout_seconds, + ) + except TimeoutError: + warnings.warn( + f"Timed out benchmarking choice '{choice}'. It will be ignored. 
" + "Please debug the root cause in case the choice can bring perf gains." + ) + # Set to INF so this choice will be ignored + return float("inf") + except Exception as process_exception: + warnings.warn( + f"Failed to benchmark choice '{choice}'. It will be ignored. " + "Please debug the root cause in case the choice can bring perf gains." + ) + # An unspecified launch failure (cudaErrorLaunchFailure) corrupts the + # CUDA context, making it unrecoverable. All subsequent CUDA calls will + # fail as well. The process must be restarted to restore CUDA functionality. + if "cudaErrorLaunchFailure" in str(process_exception): + process.restart() + # Set to INF so this choice will be ignored + return float("inf") + finally: + self.process_queue.put(process) + + def benchmark( + self, + choices: list[TritonTemplateCaller], + ) -> dict[TritonTemplateCaller, float]: + """ + Benchmark each choice in a separate process. + """ + + # Use a ThreadExecutorPool to spread the work across the subprocesses and + # to grab subprocesses as soon as they're free. 
+ results = dict(zip(choices, self.executor.map(self.target, choices))) + + return results + + +LayoutOrBuffer = Union[ir.Layout, ir.Buffer] + + +@dataclasses.dataclass +class TensorMeta: + device: torch.device + dtype: torch.dtype + sizes: torch._prims_common.ShapeType + strides: torch._prims_common.StrideType + offset: int + name: Optional[str] = None + + @classmethod + def from_irnodes( + cls, irnodes: Union[LayoutOrBuffer, Sequence[LayoutOrBuffer]] + ) -> Union[TensorMeta, list[TensorMeta]]: + if isinstance(irnodes, Sequence): + result: list[Any] = [cls.from_irnodes(x) for x in irnodes] + assert all(isinstance(x, TensorMeta) for x in result) + return result + + node = irnodes + if isinstance(node, ir.Layout): + node = ir.Buffer(name="fake", layout=node) + + dtype = node.get_dtype() + assert dtype is not None + device = node.get_device() + assert device is not None + + return TensorMeta( + device=device, + dtype=dtype, + sizes=V.graph.sizevars.size_hints( + node.get_size(), + fallback=config.unbacked_symint_fallback, + ), + strides=V.graph.sizevars.size_hints( + node.get_stride(), + fallback=config.unbacked_symint_fallback, + ), + offset=V.graph.sizevars.size_hint( + node.get_layout().offset, + fallback=config.unbacked_symint_fallback, + ), + name=node.get_name(), + ) + + def to_tensor(self) -> torch.Tensor: + return rand_strided( + self.sizes, + self.strides, + device=self.device, + dtype=self.dtype, + extra_size=self.offset, + ) + + +@dataclasses.dataclass +class BenchmarkRequest: + """ + Only handle triton template benchmark for now. The extern kernel benchmark + can be done inside the same process since they usually don't cause crash. + + Important: Instances of this class and subclasses have to be serializable + across process boundaries. Do not put CUDA Tensors in here! 
+ """ + + def __init__( + self, + kernel_name: str, + input_tensor_meta: Union[TensorMeta, list[TensorMeta]], + output_tensor_meta: Union[TensorMeta, list[TensorMeta]], + extra_args: Iterable[Any], + ) -> None: + # the kernel name defined in the module + self.kernel_name = kernel_name + + if isinstance(input_tensor_meta, TensorMeta): + self.input_tensor_meta: list[TensorMeta] = [input_tensor_meta] + else: + self.input_tensor_meta: list[TensorMeta] = input_tensor_meta + + if output_tensor_meta and isinstance(output_tensor_meta, (tuple, list)): + if len(output_tensor_meta) > 1: + # Each output with same meta for Grouped GEMM + assert all( + getattr(output_tensor_meta[0], attr) == getattr(x, attr) + for x in output_tensor_meta + for attr in ["device", "dtype", "sizes", "strides", "offset"] + ) + self.output_tensor_meta = output_tensor_meta[0] + else: + self.output_tensor_meta: TensorMeta = output_tensor_meta + + self.extra_args = extra_args + + def make_run_fn( + self, *input_tensors: torch.Tensor, out: torch.Tensor + ) -> Callable[[], None]: + raise NotImplementedError + + def cleanup_run_fn(self) -> None: + pass + + def do_bench( + self, + fn, + *input_tensors: torch.Tensor, + out: Optional[torch.Tensor] = None, + ) -> float: + raise NotImplementedError + + def benchmark( + self, + *input_tensors: torch.Tensor, + out: Optional[torch.Tensor] = None, + ) -> float: + debug = autotuning_log.isEnabledFor(logging.DEBUG) + if debug: + start_ts = time.time() + + # create args and out tensor + if out is None: + assert self.input_tensor_meta and self.output_tensor_meta, ( + "Input and output tensor meta must be populated when input_tensors is empty" + ) + assert len(input_tensors) == 0 + input_tensors = tuple(x.to_tensor() for x in self.input_tensor_meta) + out = self.output_tensor_meta.to_tensor() + + if debug: + create_tensor_elapse = time.time() - start_ts # type: ignore[possibly-undefined] + start_ts = time.time() + try: + fn = self.make_run_fn(*input_tensors, out=out) + 
except NonzeroWorkspaceNotSupportedError: + # Skipping all ops with nonzero workspace requirements + autotuning_log.info("Skipping op due to nonzero workspace requirement") + return float("inf") + + if debug: + load_elapse = time.time() - start_ts # type: ignore[possibly-undefined] + start_ts = time.time() + + res = self.do_bench(fn, *input_tensors, out) + + if debug: + bench_elapse = time.time() - start_ts # type: ignore[possibly-undefined] + autotuning_log.debug( + "InChildProcess %s: load %f, create tensor %f, bench %f", + str(self), + load_elapse, # type: ignore[possibly-undefined] + create_tensor_elapse, # type: ignore[possibly-undefined] + bench_elapse, + ) + self.cleanup_run_fn() + return res + + +class _TestBenchmarkRequest(BenchmarkRequest): + """ + Supports unit testing. Defined in this file instead of the test file so the + TuningProcess sub-process can unpickle these objects. + """ + + def __init__( + self, + result: float = 0.0, + device: Optional[int] = None, + sleep: Optional[float] = None, + exc: Optional[Exception] = None, + crash: bool = False, + ): + self.result = result + self.device = device + self.sleep = sleep + self.exc = exc + self.crash = crash + + def benchmark( + self, *input_tensors: torch.Tensor, out: Optional[torch.Tensor] = None + ) -> float: + if self.device is not None: + assert os.environ.get(CUDA_VISIBLE_DEVICES, None) == str(self.device) + if self.sleep: + time.sleep(self.sleep) + if self.exc: + raise self.exc + if self.crash: + sys.exit(1) + return self.result + + +class GPUDeviceBenchmarkMixin: + def do_bench( + self, + fn, + *input_tensors: torch.Tensor, + out: Optional[torch.Tensor] = None, + ) -> float: + device_idx_set = OrderedSet( + tensor.device.index + for tensor in [*input_tensors, out] + if isinstance(tensor, torch.Tensor) + and is_gpu(tensor.device.type) + and tensor.device.index is not None + ) + assert len(device_idx_set) <= 1, f"Can not mix devices {device_idx_set}" + device_type = next( + ( + tensor.device.type 
+ for tensor in input_tensors + if is_gpu(tensor.device.type) + ), + "cuda", + ) + device_interface = get_interface_for_device(device_type) + if len(device_idx_set) == 1: + device_idx = next(iter(device_idx_set)) + else: + device_idx = device_interface.current_device() + with device_interface.device(device_idx): # type: ignore[attr-defined] + res = benchmarker.benchmark_gpu(fn) + device_interface.synchronize() # shake out any CUDA errors + + return res + + +class CPUDeviceBenchmarkMixin: + def do_bench( + self, + fn, + *input_tensors: torch.Tensor, + out: Optional[torch.Tensor] = None, + ) -> float: + return benchmarker.benchmark_cpu(fn) + + +class TritonBenchmarkRequest(BenchmarkRequest): + # Important: Instances of this class have to be serializable + # across process boundaries. Do not put CUDA Tensors in here! + def __init__( + self, + kernel_name: str, + input_tensor_meta: Union[TensorMeta, list[TensorMeta]], + output_tensor_meta: Union[TensorMeta, list[TensorMeta]], + extra_args: Iterable[Any], + module_path: str, # the path of the module defining the triton kernel + module_cache_key: str, + num_stages: int, + num_warps: int, + num_consumer_groups: int = 0, + num_buffers_warp_spec: int = 0, + matrix_instr_nonkdim: int = 0, # only used for hip to choose the shape of mfma instruction. 
+ waves_per_eu: int = 0, # only used for hip to schedule waves per execution unit + kpack: int = 0, # ROCm specific gemm parameter + ) -> None: + super().__init__(kernel_name, input_tensor_meta, output_tensor_meta, extra_args) + self.module_path = module_path + self.module_cache_key = module_cache_key + self.num_stages = num_stages + self.num_warps = num_warps + self.num_consumer_groups = num_consumer_groups + self.num_buffers_warp_spec = num_buffers_warp_spec + self.matrix_instr_nonkdim = matrix_instr_nonkdim + self.waves_per_eu = waves_per_eu + self.kpack = kpack + + def make_run_fn( + self, *input_tensors: torch.Tensor, out: torch.Tensor + ) -> Callable[[], None]: + mod = PyCodeCache.load_by_key_path(self.module_cache_key, self.module_path) + autotuning_log.debug( + "benchmark module key: %s, path: %s", + self.module_cache_key, + self.module_path, + ) + + run_method = getattr(mod, self.kernel_name).run + extra_args = list(self.extra_args) + run_method.__self__.with_bandwidth_info = False + + # Newer version of triton add warmup argument to JITFunction.run. + # This code handles backward-compatibility. 
+ warmup_arg = {} + import inspect + + if "warmup" in inspect.signature(run_method).parameters: + warmup_arg["warmup"] = False + + if out.device.type == "cpu": + stream = 0 + else: + device_type = out.device.type + device_interface = get_interface_for_device(device_type) + stream = device_interface.get_raw_stream( + self.output_tensor_meta.device.index + ) + + if isinstance( + getattr(mod, self.kernel_name), + torch._inductor.runtime.triton_heuristics.DebugAutotuner, + ): + return functools.partial( + run_method, + *input_tensors, + out, + *extra_args, + **warmup_arg, + stream=stream, + ) + else: + return functools.partial( + run_method, + *input_tensors, + out, + *extra_args, + **warmup_arg, + stream=stream, + benchmark_run=True, + ) + + def precompile(self): + mod = PyCodeCache.load_by_key_path(self.module_cache_key, self.module_path) + getattr(mod, self.kernel_name).precompile() + + def __str__(self) -> str: + return f"{self.kernel_name=}, {self.module_path=}, {self.module_cache_key=}" + + +class TritonGPUBenchmarkRequest(GPUDeviceBenchmarkMixin, TritonBenchmarkRequest): + pass + + +class TritonCPUBenchmarkRequest(CPUDeviceBenchmarkMixin, TritonBenchmarkRequest): + pass + + +class ExternKernelBenchmarkRequest(BenchmarkRequest): + """ + A class to handle extern kernel benchmark requests. This allows extern kernels + (like aten::mm) to be benchmarked in a subprocess, similar to Triton kernels. + + Important: Instances of this class have to be serializable across + process boundaries. Do not put CUDA Tensors in here! 
+ """ + + def __init__( + self, + kernel_name: str, + input_tensor_meta: Union[TensorMeta, list[TensorMeta]], + output_tensor_meta: Union[TensorMeta, list[TensorMeta]], + extra_args: Iterable[Any], + callable_path: str, # Module path to the callable (e.g., "extern_kernels.mm") + kwargs: Optional[dict[str, Any]] = None, + has_out_variant: bool = True, + ) -> None: + super().__init__(kernel_name, input_tensor_meta, output_tensor_meta, extra_args) + self.callable_path = callable_path + self.kwargs = kwargs or {} + self.has_out_variant = has_out_variant + + def make_run_fn( + self, *input_tensors: torch.Tensor, out: torch.Tensor + ) -> Callable[[], None]: + fn = self.to_callable() + if self.has_out_variant: + # For out=variant, pass output as keyword arg + return functools.partial(fn, *input_tensors, out=out) + else: + # For non-out variant, just call with inputs + return functools.partial(fn, *input_tensors) + + def benchmark( + self, *input_tensors: torch.Tensor, out: Optional[torch.Tensor] = None + ): + if out is not None and out.numel() == 0: + # no need to run the kernel of do benchmarking + return 0.0 + if self.has_out_variant or len(input_tensors) == 0: + return super().benchmark(*input_tensors, out=out) + else: + algo = self.to_callable() + out_new = algo(*input_tensors) + if out is not None: + torch._C._dynamo.guards.assert_size_stride( + out_new, tuple(out.size()), tuple(out.stride()) + ) + out.copy_(out_new) # for correctness checking + if config.profile_bandwidth_with_do_bench_using_profiling: + return do_bench_using_profiling(lambda: algo(*input_tensors)) + return benchmarker.benchmark(algo, input_tensors, {}) + + def precompile(self) -> None: + # Extern kernels don't need precompilation - they're already compiled + pass + + def to_callable(self): + # While ExternKernelChoice also has a to_callable method, + # we avoid calling the ExternKernelChoice version here to make sure + # this is picklable + from torch._inductor.select_algorithm import 
extern_kernels + + fn = getattr(extern_kernels, self.kernel_name) + if self.kwargs: + return functools.partial(fn, **self.kwargs) + + return fn + + def __str__(self) -> str: + return f"ExternKernelBenchmarkRequest({self.callable_path})" + + +class ExternKernelGPUBenchmarkRequest( + GPUDeviceBenchmarkMixin, ExternKernelBenchmarkRequest +): + pass + + +class ExternKernelCPUBenchmarkRequest( + CPUDeviceBenchmarkMixin, ExternKernelBenchmarkRequest +): + pass + + +class CUDABenchmarkRequest(GPUDeviceBenchmarkMixin, BenchmarkRequest): + """ + A class to handle CUDA (CUTLASS) benchmark requests. This class is for + managing the lifecycle of a CUDA kernel benchmark, including compiling + the source code, managing workspace memory, and executing the kernel. + + Important: Instances of this class have to be serializable across + process boundaries. Do not put CUDA Tensors in here! + """ + + def __init__( + self, + kernel_name: str, + input_tensor_meta: Union[TensorMeta, list[TensorMeta]], + output_tensor_meta: Union[TensorMeta, list[TensorMeta]], + extra_args: Iterable[Any], + source_code: str, + ) -> None: + super().__init__(kernel_name, input_tensor_meta, output_tensor_meta, extra_args) + self.source_code = source_code + self.workspace_size: int = 0 + self.workspace: Optional[torch.Tensor] = None + self.DLL: Optional[DLLWrapper] = None + self._workspace_size_updated = False + self.hash_key: str = "" + self.source_file: str = "" + self.hash_key, self.source_file = CUDACodeCache.write(self.source_code, "so") + + def precompile(self): + """ + Precompile the CUDA source code to populate the CUDACodeCache. + This may happen in a separate thread pool. 
+ """ + autotuning_log.debug("Precompiling %s", self) + CUDACodeCache.compile(self.source_code, "so") + autotuning_log.debug("Done precompiling %s", self) + + def make_run_fn( + self, *input_tensors: torch.Tensor, out: torch.Tensor + ) -> Callable[[], None]: + """ + Create a function to run the CUDA kernel with the given input and output tensors. + """ + + self.ensure_dll_loaded() + self.update_workspace_size() + args = [c_void_p(tensor.data_ptr()) for tensor in list(input_tensors) + [out]] + autotuning_log.debug( + "make_run_fn: self.kernel_name=%s, self.source_file=%s, self.hash_key=%s, self.DLL=%s, args=%s, self.extra_args=%s", + self.kernel_name, + self.source_file, + self.hash_key, + self.DLL, + args, + self.extra_args, + ) + stream_ptr = c_void_p(torch.cuda.current_stream().cuda_stream) + run_method = getattr(self.DLL, self.kernel_name) + workspace_ptr = c_void_p(0) + if self.workspace_size > 0: + self.workspace = torch.zeros( + (self.workspace_size + 7) // 8, + dtype=torch.float64, + device=out.device, + ) + workspace_ptr = c_void_p(self.workspace.data_ptr()) + + # Generate partial function. + ret = functools.partial( + run_method, + *args, + *self.extra_args, + None, # null workspace size ptr + workspace_ptr, # set workspace ptr, + stream_ptr, + ) + + # sanity check to make sure we cleanup run fn properly + try: + ret() + except RuntimeError as e: + err_msg = str(e) + + def raise_runtime_error(): + raise RuntimeError(err_msg) + + self.cleanup_run_fn() + return raise_runtime_error + + return ret + + def update_workspace_size(self) -> None: + if self._workspace_size_updated: + return + self.ensure_dll_loaded() + unique_input_count = len( + dict.fromkeys(meta.name for meta in self.input_tensor_meta) + ) + args = [c_void_p(None) for _ in range(unique_input_count + 1)] + stream_ptr = c_void_p(torch.cuda.current_stream().cuda_stream) + + run_method = getattr(self.DLL, self.kernel_name) + # Retrieve workspace_size and initialize workspace. 
+ c_workspace_size = c_size_t() + run_method( + *args, # input ptrs and output ptrs + *self.extra_args, + byref( + c_workspace_size + ), # set workspace size ptr to retrieve workspace size + None, # null workspace ptr + stream_ptr, + ) + torch.cuda.synchronize() # shake out any CUDA errors + self.workspace_size = c_workspace_size.value + autotuning_log.debug( + "update_workspace_size called: new workspace size=%d, self.kernel_name=%s, self.source_file=%s, self.hash_key=%s, self.DLL=%s, args=%s, self.extra_args=%s", # noqa: B950 + self.workspace_size, + self.kernel_name, + self.source_file, + self.hash_key, + self.DLL, + args, + self.extra_args, + ) + self._workspace_size_updated = True + + def ensure_dll_loaded(self): + if self.DLL is None: + self.DLL, self.hash_key, self.source_file = CUDACodeCache.load( + self.source_code, "so" + ) + + def cleanup_run_fn(self) -> None: + if self.DLL is not None: + self.DLL.close() + self.DLL = None + self.workspace = None + + def __str__(self) -> str: + return f"{self.kernel_name=}, {self.source_file=}, {self.hash_key=}" + + +class CppBenchmarkRequest(CPUDeviceBenchmarkMixin, BenchmarkRequest): + # Important: Instances of this class have to be serializable + # across process boundaries. Do not put Tensors in here! 
+ + def __init__( + self, + kernel_name: str, + input_tensor_meta: Union[TensorMeta, list[TensorMeta]], + output_tensor_meta: Union[TensorMeta, list[TensorMeta]], + extra_args: Iterable[Any], + source_code: str, + ) -> None: + super().__init__(kernel_name, input_tensor_meta, output_tensor_meta, extra_args) + self.source_code = source_code + self.hash_key = get_hash(source_code) + self.DLL: Optional[Union[CDLL, ModuleType]] = None + + def precompile(self): + # Prepopulate CppCodeCache + # may happen in separate Threadpool + autotuning_log.debug("Precompiling %s", self) + CppCodeCache.load(self.source_code, device_type="cpu") + autotuning_log.debug("Done precompiling %s", self) + + def make_run_fn( + self, *input_tensors: torch.Tensor, out: torch.Tensor + ) -> Callable[[], None]: + # TODO(jgong5): use CppPythonBindingsCodeCache for better binding perf + self.DLL = CppCodeCache.load(self.source_code, device_type="cpu") + args = [tensor.data_ptr() for tensor in list(input_tensors) + [out]] + autotuning_log.debug( + "make_run_fn: self.kernel_name=%s, self.DLL=%s, args=%s, self.extra_args=%s", + self.kernel_name, + self.DLL, + args, + self.extra_args, + ) + run_method = getattr(self.DLL, self.kernel_name) + # Assume only size with type ctypes.c_ulonglong in extra_args + assert all(isinstance(arg, ctypes.c_ulonglong) for arg in self.extra_args) + run_method.argtypes = [ctypes.c_ulonglong] * ( + len(args) + len(list(self.extra_args)) + ) + + # Generate partial function. 
+ return functools.partial( + run_method, + *args, + *self.extra_args, + ) + + def __str__(self) -> str: + return f"{self.kernel_name=}" + + +class CuteDSLBenchmarkRequest(GPUDeviceBenchmarkMixin, BenchmarkRequest): + """Benchmark request for CuteDSL (CUTLASS Python DSL) kernels.""" + + def __init__( + self, + kernel_name: str, + input_tensor_meta: Union[TensorMeta, list[TensorMeta]], + output_tensor_meta: Union[TensorMeta, list[TensorMeta]], + extra_args: tuple[Any, ...], + source_code: PartialRender, + ) -> None: + super().__init__(kernel_name, input_tensor_meta, output_tensor_meta, extra_args) + + finalized_code = source_code.finalize_all() + self.module_cache_key, self.module_path = PyCodeCache.write(finalized_code) + + def make_run_fn( + self, *input_tensors: torch.Tensor, out: torch.Tensor + ) -> Callable[[], None]: + """ + Create a function to run the CuteDSL kernel with the given input and output tensors. + Similar to TritonBenchmarkRequest.make_run_fn but for CuteDSL kernels. + """ + mod = PyCodeCache.load_by_key_path(self.module_cache_key, self.module_path) + + # Logic replicated async_compile + from .codegen.cutedsl.cutedsl_kernel import MAIN_SUFFIX + + main_func_name = f"{self.kernel_name}_{MAIN_SUFFIX}" + + if not hasattr(mod, main_func_name): + available = [name for name in dir(mod) if callable(getattr(mod, name))] + raise RuntimeError( + f"Could not find CuteDSL main kernel function '{main_func_name}'. 
Available callables: {available}" + ) + + kernel_func = getattr(mod, main_func_name) + + def run_kernel(): + device_interface = get_interface_for_device("cuda") + stream = device_interface.get_raw_stream(out.device.index) + return kernel_func(*input_tensors, out, stream=stream) + + return run_kernel + + +@functools.cache +def get_tuning_process_pool() -> TuningProcessPool: + pool = TuningProcessPool() + atexit.register(pool.shutdown) + return pool + + +def benchmark_in_sub_process( + choices: list[TritonTemplateCaller], +) -> dict[TritonTemplateCaller, float]: + """ + Do benchmarking in a subprocess and return the perf number (latency). + """ + return get_tuning_process_pool().benchmark(choices) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/await_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/await_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2468b0039a18df324216c38e5797e9d2b805edd7 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/await_utils.py @@ -0,0 +1,178 @@ +import asyncio +import sys +import weakref +from asyncio import AbstractEventLoop, Future +from collections.abc import Awaitable, Callable, Coroutine, Generator, Iterator +from contextlib import contextmanager, ExitStack +from contextvars import Context +from typing import Any, Optional, Protocol, TypeVar + +from torch.utils._ordered_set import OrderedSet + + +T = TypeVar("T") +TCoro = Generator[Any, None, T] + +if sys.version_info >= (3, 11): + + class TaskFactory(Protocol): + def __call__( + self, + __loop: AbstractEventLoop, + __factory: Coroutine[None, None, object] | Generator[None, None, object], + __context: Context | None = None, + /, + ) -> asyncio.futures.Future[object]: ... 
+ + TaskFactoryType = TaskFactory +else: + TaskFactoryType = Callable[[AbstractEventLoop, Generator[TCoro, None, T]], Future] # type: ignore[valid-type] + + +def await_sync(awaitable: Awaitable[T]) -> T: + with get_loop() as loop: + return loop.run_until_complete(awaitable) + + +@contextmanager +def get_loop( + always_create_new_loop: bool = False, +) -> Iterator[AbstractEventLoop]: + try: + loop = asyncio.get_event_loop() + except RuntimeError as re: + if "There is no current event loop in thread" in str(re): + with _new_loop() as loop: + yield loop + return + else: + raise + + @contextmanager + def _restore_loop( + loop: asyncio.AbstractEventLoop, + ) -> Iterator[None]: + try: + yield + finally: + asyncio.set_event_loop(loop) + + @contextmanager + def _restore_running_loop() -> Iterator[None]: + loop_from_events = asyncio.events._get_running_loop() + asyncio.events._set_running_loop(None) + try: + yield + finally: + asyncio.events._set_running_loop(loop_from_events) + + with ExitStack() as stack: + if loop.is_running(): + stack.enter_context(_restore_running_loop()) + stack.enter_context(_restore_loop(loop=loop)) + loop = stack.enter_context(_new_loop(loop.get_task_factory())) # type: ignore[arg-type] + elif loop.is_closed(): + loop = stack.enter_context(_new_loop()) # type: ignore[arg-type] + elif always_create_new_loop: + stack.enter_context(_restore_loop(loop=loop)) + loop = stack.enter_context(_new_loop()) # type: ignore[arg-type] + yield loop + + +@contextmanager +def _new_loop( + task_factory: Optional[TaskFactoryType] = None, +) -> Iterator[asyncio.AbstractEventLoop]: + loop = asyncio.new_event_loop() + tasks = _patch_loop(loop) + + if task_factory: + # pyre-ignore[6] + loop.set_task_factory(task_factory) # type: ignore[arg-type] + + asyncio.set_event_loop(loop) + try: + yield loop + finally: + try: + _cancel_all_tasks(loop, tasks) + finally: + asyncio.set_event_loop(None) + loop.close() + + +def _cancel_all_tasks( + loop: AbstractEventLoop, + tasks: 
OrderedSet[Future], # type: ignore[type-arg] +) -> None: + to_cancel = [task for task in tasks if not task.done()] + + if not to_cancel: + return + + # pyre-fixme[1001]: Awaitable assigned to `task` is never awaited. + for task in to_cancel: + task.cancel() + + # pyrefly: ignore [bad-argument-type] + loop.run_until_complete(asyncio.gather(*to_cancel, return_exceptions=True)) + + for task in to_cancel: + if task.cancelled(): + continue + if task.exception() is not None: + loop.call_exception_handler( + { + "message": "unhandled exception during asyncio.run() shutdown", + "exception": task.exception(), + "task": task, + } + ) + + +def _patch_loop(loop: AbstractEventLoop) -> OrderedSet[Future]: # type: ignore[type-arg] + tasks: weakref.WeakSet[Future] = weakref.WeakSet() # type: ignore[type-arg] + + task_factories: list[Optional[TaskFactoryType]] = [None] + + def _set_task_factory(factory: Optional[TaskFactoryType]) -> None: + task_factories[0] = factory + + def _get_task_factory() -> Optional[TaskFactoryType]: + return task_factories[0] + + def _safe_task_factory( + loop: AbstractEventLoop, + coro: TCoro, # type: ignore[type-arg] + *, + context: Context | None = None, + ) -> asyncio.Future: # type: ignore[valid-type, type-arg] + task_factory = task_factories[0] + if task_factory is None: + if sys.version_info >= (3, 11): + # pyrefly: ignore [bad-argument-type] + task = asyncio.Task(coro, loop=loop, context=context) + else: + task = asyncio.Task(coro, loop=loop) + # pyre-ignore[16]: `Task` has no attribute `_source_traceback`. + if task._source_traceback: # type: ignore[attr-defined] + del task._source_traceback[ # type: ignore[attr-defined] + -1 + ] # pragma: no cover # type: ignore[attr-defined] + else: + if sys.version_info >= (3, 11): + task = task_factory(loop, coro, context=context) # type: ignore[arg-type, call-arg, assignment] + else: + task = task_factory(loop, coro) # type: ignore[arg-type] + # `Union[Task[Any], Future[Any]]`. 
+ tasks.add(task) + return task + + # pyre-ignore[6] + loop.set_task_factory(_safe_task_factory) # type: ignore[method-assign, arg-type] + # pyre-ignore[8] + loop.set_task_factory = _set_task_factory # type: ignore[method-assign, assignment] + # pyre-ignore[8] + loop.get_task_factory = _get_task_factory # type: ignore[method-assign, assignment] + + return tasks # type: ignore[return-value] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/bounds.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/bounds.py new file mode 100644 index 0000000000000000000000000000000000000000..bc8dba511925212e14f3230f2a1fb0539706b5c1 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/bounds.py @@ -0,0 +1,260 @@ +import logging +import operator +from collections.abc import Callable +from functools import partial +from typing import Any, Optional, Union + +import sympy +from sympy import Expr + +import torch +from torch.utils._sympy.value_ranges import ( + bound_sympy, + SymPyValueRangeAnalysis, + ValueRanges, +) + +from ..utils._sympy.functions import PowByNatural +from ..utils._sympy.numbers import int_oo +from .loop_body import InterpreterShim, LoopBody, LoopBodyBlock +from .ops_handler import DefaultHandler, ReductionType, StoreMode +from .utils import cache_on_self, dominated_nodes +from .virtualized import V + + +log = logging.getLogger(__name__) + + +class BoundVars: + """ + Performs Value Range Analysis on LoopBody's fx graph by calling BoundVars.run() + It exposes the ranges of the nodes in the `bounds` variable + + Note. A current limitation of this analysis is that it just works on a per-loop basis. + We should be able to propagate the bounds between across the whole graph. This may benefit + the case a bounded variable is returned by a kernel and fed into another. 
+ """ + + def __init__(self, loop_body: LoopBody) -> None: + def upper_bound(v: Union[Expr, int]) -> int: + return bound_sympy(v).upper if isinstance(v, Expr) else v + + self.loop_body = loop_body + self.replacement_vals = { + k: ValueRanges[Expr](0, upper_bound(v) - 1) + for k, v in loop_body.var_ranges.items() + } + # avoid computing these values, pessimistically assume that they are unbounded + self.unbounded_vars = dominated_nodes( + node + for node in self.loop_body.get_nodes() + if node.target in ["load", "reduction", operator.getitem] + or "masked_subblock" in node.target + ) + # To access this variable call `get_bounds()` + self._bounds: dict[torch.fx.Node, ValueRanges[Expr]] = {} + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(" + f"loop_body={self.loop_body},\n " + f"replacement_vals={self.replacement_vals}, \n" + f"unbounded_vars={self.unbounded_vars}, \n" + f"_bounds={self._bounds})" + ) + + @cache_on_self + def get_bounds(self) -> dict[torch.fx.Node, ValueRanges[Expr]]: + submodules = self.swap_submodules(self.loop_body.submodules) + + # Initialize the environment with the unbounded variables + for node in self.unbounded_vars: + # we need to evaluate masked_subblock to recurse, and we need to set indirect values + if not isinstance(node.target, str) or ( + "masked_subblock" not in node.target + and "set_indirect" not in node.target + ): + self._bounds[node] = ValueRanges[Expr].unknown() + + with V.set_ops_handler(ValueRangeAnalysis()): + interpreter = InterpreterShim(self.loop_body.root_block.graph, submodules) + log.debug("get_bounds:\n%s", self.loop_body.root_block.graph) + interpreter.run(V.get_ops_handler(), initial_env=self._bounds) + return self._bounds + + def swap_submodules( + self, submodules: dict[str, Callable[..., Any]] + ) -> dict[str, Callable[..., ValueRanges[Expr]]]: + result: dict[str, Callable[..., ValueRanges[Expr]]] = {} + for key in submodules: + if key == "get_index": + result[key] = self.get_index + elif 
"masked_subblock" in key: + subblock = self.loop_body.subblocks[key] + # The result within the lambda will reference to the final + # set of modules at the end of the for-loop as it stores a reference to it + + # bind subblock in a function because python lambdas close over by reference + # moving the lambda out of make_fn would close over the reference to subblock, + # so all lambdas would have the same subblock reference that is the final + # subblock in the loop + def make_fn( + subblock: LoopBodyBlock, + ) -> Callable[[Any, Any], ValueRanges[Expr]]: + return lambda mask, value: self.masked_subblock( + subblock, self._bounds, mask, value, result + ) + + result[key] = make_fn(subblock) + elif "set_indirect" in key: + idx = int(key[len("set_indirect") :]) + var = self.loop_body.indirect_vars[idx] + indirect = partial(self.set_indirect, var) + result[key] = indirect + else: + assert "scan" in key + result[key] = submodules[key] + + return result + + def masked_subblock( + self, + subblock: LoopBodyBlock, + env: dict[torch.fx.Node, ValueRanges[Expr]], + mask: Any, + value: Any, + submodules: dict[str, Callable[..., Any]], + ) -> ValueRanges[Expr]: + interp = InterpreterShim(subblock.graph, submodules) + interp.run(V.get_ops_handler(), initial_env=env) + output = [node for node in subblock.graph.nodes if node.target == "output"] + assert len(output) == 1 + # dont bother unioning with value since the load from buffer will be + # pessimistically assumed to be inf anyway + return interp.env[output[0]] + + def set_indirect(self, old: Expr, new: ValueRanges[Expr]) -> ValueRanges[Expr]: + assert isinstance(new, ValueRanges) + self.replacement_vals[old] = new + return new + + def get_index(self, name: str) -> ValueRanges[Expr]: + expr = self.loop_body.indexing_exprs[name] + bound = self.replacement_vals.get(expr) + if bound is None: + bound = bound_sympy(expr, self.replacement_vals) + # The following assertion is true at the time of this writing + # We don't assert is as to 
not execute bound_sympy when bound is not None + # assert bound is None or bound == bound_sympy(expr, self.replacement_vals) + self.replacement_vals[name] = bound + return bound + + +class ValueRangeAnalysis(SymPyValueRangeAnalysis, DefaultHandler): + def __init__(self) -> None: + self.name = "ValueRangeAnalysis" + boolean_operators = ( + "xor", + "logical_and", + "logical_or", + "logical_not", + ) + for op in boolean_operators: + setattr(self, op, self.bool_handler) + + @staticmethod + def bool_handler(*args: Any, **kwargs: Any) -> ValueRanges[Any]: + # just assuming bools can have both values + return ValueRanges(sympy.false, sympy.true) # type: ignore[arg-type] + + def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: + # many ops are unlikely to show up in optimizable indexing compute, + # so we dont have full coverage + return ValueRanges.unknown() + + def load(self, name: str, index: sympy.Expr) -> ValueRanges[Any]: + return ValueRanges.unknown() + + def store( + self, name: str, index: sympy.Expr, value: Any, mode: StoreMode = None + ) -> None: + return + + def reduction( + self, + dtype: torch.dtype, + src_dtype: torch.dtype, + reduction_type: ReductionType, + value: Any, + ) -> ValueRanges[Any]: + return ValueRanges.unknown() + + @classmethod + def index_expr(cls, index: Any, dtype: torch.dtype) -> ValueRanges[Any]: + assert isinstance(index, ValueRanges) + return cls.to_dtype(index, dtype) + + @staticmethod + def to_dtype( + x: Any, + dtype: torch.dtype, + src_dtype: Optional[torch.dtype] = None, + use_compute_types: bool = True, + ) -> ValueRanges[Any]: + x = ValueRanges.wrap(x) + + if dtype == torch.bool: + if x.is_singleton(): + return ValueRanges.wrap(x.lower != 0) + elif x.is_bool: + return x + elif 0 not in x: + return ValueRanges.wrap(sympy.true) + else: + return ValueRanges(sympy.false, sympy.true) + + def cast(x: Any, dtype: torch.dtype) -> sympy.Expr: + # dtype is int or float + if dtype.is_floating_point: + return 
sympy.Float(x) + else: + if x in (int_oo, -int_oo): + return x + try: + return sympy.Integer(x) + except TypeError: + # inf cannot be cast to Integer + return x + + if x.is_bool: + if x.is_singleton(): + val = 1 if x.lower else 0 + return ValueRanges.wrap(cast(val, dtype)) + else: + return ValueRanges(cast(0, dtype), cast(1, dtype)) + else: + # int to float or float to int + return ValueRanges(cast(x.lower, dtype), cast(x.upper, dtype)) + + @staticmethod + def square(x: Any) -> ValueRanges[Any]: + return ValueRanges.convex_min_zero_map(x, lambda y: PowByNatural(y, 2)) + + @staticmethod + def neg(x: Any) -> ValueRanges[Any]: + return ValueRanges.decreasing_map(x, operator.neg) + + # TODO: this is slightly inaccurate because truncdiv operates at integer + # precision, but we're going through float truediv which means we can + # potentially lose precision on the bounds + @classmethod + def truncdiv(cls, a: Any, b: Any) -> ValueRanges[Any]: + x = cls.truediv(a, b) + if x == ValueRanges.unknown(): + return x + + return cls.trunc(x) + + @classmethod + def sub(cls, a: Any, b: Any) -> ValueRanges[Any]: + return cls.add(a, cls.neg(b)) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/cache.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/cache.py new file mode 100644 index 0000000000000000000000000000000000000000..118bbf2828799d8fd63a96427b5e88573d6adf19 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/cache.py @@ -0,0 +1,419 @@ +from __future__ import annotations + +import pickle +from abc import ABC, abstractmethod +from ast import literal_eval +from functools import cached_property +from hashlib import sha256 +from os import getenv +from pathlib import Path +from tempfile import gettempdir +from threading import Lock +from typing import Any, Generic, TYPE_CHECKING, TypeVar +from typing_extensions import assert_never, override, Self + +from 
torch.utils._filelock import FileLock + + +if TYPE_CHECKING: + from concurrent.futures import Future, ThreadPoolExecutor + + +# TypeVars can't be recursive, so generic types that fall within +# Key or Value can't be bound properly; for example, Key should +# only take tuples of other Key types: tuple[Key, ...]. this is +# a known shortcoming of torch's typing +Key = TypeVar("Key", str, int, tuple[Any, ...]) +Value = TypeVar("Value", str, int, tuple[Any, ...], bytes, dict[Any, Any], list[Any]) + + +class CacheError(ValueError): + """ + Exception raised for errors encountered during cache operations. + """ + + +class Cache(ABC, Generic[Key, Value]): + """ + Abstract base class for cache implementations. + Provides the interface for cache operations. + """ + + @abstractmethod + def get(self: Self, key: Key) -> Value | None: + """ + Retrieve a value from the cache. + Args: + key (Key): The key to look up. + Returns: + Value | None: The cached value if present, else None. + """ + + @abstractmethod + def insert(self: Self, key: Key, value: Value) -> bool: + """ + Insert a value into the cache. + Args: + key (Key): The key to insert. + value (Value): The value to associate with the key. + Returns: + bool: True if the value was inserted, False if the key already exists. + """ + + +class InMemoryCache(Cache[Key, Value]): + """ + In-memory cache implementation using a dictionary and thread lock. + """ + + def __init__(self: Self) -> None: + """ + Initialize an empty in-memory cache. + """ + self._cache: dict[Key, Value] = {} + self._lock: Lock = Lock() + + def get(self: Self, key: Key) -> Value | None: + """ + Retrieve a value from the cache. + Args: + key (Key): The key to look up. + Returns: + Value | None: The cached value if present, else None. + """ + with self._lock: + if (value := self._cache.get(key)) is not None: + return value + return None + + def insert(self: Self, key: Key, value: Value) -> bool: + """ + Insert a value into the cache. 
+ Args: + key (Key): The key to insert. + value (Value): The value to associate with the key. + Returns: + bool: True if the value was inserted, False if the key already exists. + """ + with self._lock: + if key in self._cache: + # no overwrites for insert! + return False + self._cache[key] = value + return True + + @classmethod + def from_env_var(cls, env_var: str) -> Self: + """ + Create an in-memory cache from an environment variable. + Args: + env_var (str): Name of the environment variable containing cache data. + Returns: + InMemoryCache: An instance populated from the environment variable. + Raises: + CacheError: If the environment variable is malformed or contains invalid data. + """ + cache = cls() + + if (env_val := getenv(env_var)) is None: + # env_var doesn't exist = empty cache + return cache + + for kv_pair in env_val.split(";"): + # ignore whitespace prefix/suffix + kv_pair = kv_pair.strip() + + if not kv_pair: + # kv_pair could be '' if env_val is '' or has ; suffix + continue + + try: + # keys and values should be comma separated + key_bytes_repr, value_bytes_repr = kv_pair.split(",", 1) + except ValueError as err: + raise CacheError( + f"Malformed kv_pair {kv_pair!r} from env_var {env_var!r}, likely missing comma separator." + ) from err + + # ignore whitespace prefix/suffix, again + key_bytes_repr, value_bytes_repr = ( + key_bytes_repr.strip(), + value_bytes_repr.strip(), + ) + + try: + # check that key_bytes_str is an actual, legitimate encoding + key_bytes = literal_eval(key_bytes_repr) + except (ValueError, SyntaxError) as err: + raise CacheError( + f"Malformed key_bytes_repr {key_bytes_repr!r} in kv_pair {kv_pair!r}, encoding is invalid." + ) from err + try: + # check that value_bytes_str is an actual, legitimate encoding + value_bytes = literal_eval(value_bytes_repr) + except (ValueError, SyntaxError) as err: + raise CacheError( + f"Malformed value_bytes_repr {value_bytes_repr!r} in kv_pair {kv_pair!r}, encoding is invalid." 
+ ) from err + + try: + key = pickle.loads(key_bytes) + except pickle.UnpicklingError as err: + raise CacheError( + f"Malformed key_bytes_repr {key_bytes_repr!r} in kv_pair {kv_pair!r}, not un-pickle-able." + ) from err + try: + value = pickle.loads(value_bytes) + except pickle.UnpicklingError as err: + raise CacheError( + f"Malformed value_bytes_repr {value_bytes_repr!r} in kv_pair {kv_pair!r}, not un-pickle-able." + ) from err + + # true duplicates, i.e. multiple occurrences of the same key => value + # mapping are ok and treated as a no-op; key duplicates with differing + # values, i.e. key => value_1 and key => value_2 where value_1 != value_2, + # are not okay since we don't allow overwriting cached values (it's bad regardless) + if (not cache.insert(key, value)) and (cache.get(key) != value): + raise CacheError( + f"Multiple values for key {key!r} found, got {cache.get(key)!r} and {value!r}." + ) + + return cache + + @classmethod + def from_file_path(cls, fpath: Path) -> Self: + """ + Create an in-memory cache from a file path. + Args: + fpath (Path): Path to the file containing pickled cache data. + Returns: + InMemoryCache: An instance populated from the file. + Raises: + CacheError: If the file is not a valid pickled dictionary. + """ + cache = cls() + + if not fpath.is_file(): + # fpath doesn't exit = empty cache + return cache + + try: + with open(fpath, "rb") as fp: + cache._cache = pickle.load(fp) + except pickle.UnpicklingError as err: + raise CacheError( + f"Failed to create cache from file path {fpath}, file contents are un-pickle-able." + ) from err + + if not isinstance(cache._cache, dict): + raise CacheError( + f"Failed to create cache from file path {fpath}, file contents not pickled dict[Key, Value]." + ) + + return cache + + +class AsyncCache(Cache[Key, Value]): + """ + Asynchronous cache implementation using ThreadPoolExecutor. 
+ """ + + def get_async( + self: Self, key: Key, executor: ThreadPoolExecutor + ) -> Future[Value | None]: + """ + Retrieve a value from the cache asynchronously. + Args: + key (Key): The key to look up. + executor (ThreadPoolExecutor): Executor for async execution. + Returns: + Future[Value | None]: Future for the cached value or None. + """ + return executor.submit(self.get, key) + + def insert_async( + self: Self, key: Key, value: Value, executor: ThreadPoolExecutor + ) -> Future[bool]: + """ + Insert a value into the cache asynchronously. + Args: + key (Key): The key to insert. + value (Value): The value to associate with the key. + executor (ThreadPoolExecutor): Executor for async execution. + Returns: + Future[bool]: Future for the result of insertion. + """ + return executor.submit(self.insert, key, value) + + +class OnDiskCache(AsyncCache[Key, Value]): + """ + On-disk cache implementation using files and file locks. + Stores cache data in files on disk, with atomic operations and versioning. + Supports custom cache directory names. + Attributes: + version (int): The version used for cache versioning. + name (str): The name of the cache directory. + """ + + version: int = 0 + + def __init__(self: Self, name: str | None = None) -> None: + """ + Initialize an on-disk cache instance. + Args: + name (str | None, optional): The name of the cache directory. If None, + defaults to "on_disk_cache". + """ + self.name = name or "on_disk_cache" + + @cached_property + def base_dir(self: Self) -> Path: + """ + Get the base directory for the cache. + Returns: + Path: The base directory path for storing cache files. + """ + return Path(gettempdir()) / "cache" / self.name + + def _fpath_from_key(self: Self, key: Key) -> Path: + """ + Get the file path for a given key. + Args: + key (Key): The key to convert to a file path. + Returns: + Path: The file path for the key. + Raises: + CacheError: If the key is not pickle-able. 
+ """ + try: + return self.base_dir / sha256(pickle.dumps(key)).hexdigest()[:32] + except (AttributeError, pickle.PicklingError) as err: + raise CacheError( + f"Failed to get fpath for key {key!r}, key is not pickle-able." + ) from err + # pyrefly: ignore [bad-argument-type] + assert_never(key) + + def _flock_from_fpath(self: Self, fpath: Path) -> FileLock: + """ + Get a file lock for a given file path. + Args: + fpath (Path): The file path. + Returns: + FileLock: The file lock for the path. + """ + # fpath.name is a hex digest, meaning there are 16^4 potential values + # for fpath.name[:4]; this is more than enough unique locks to not + # cause additional overhead from shared locks and it also saves our + # cache dir from becoming 50 percent locks + # pyrefly: ignore [bad-return] + return FileLock(str(fpath.parent / "locks" / fpath.name[:4]) + ".lock") + + @property + def version_prefix(self: Self) -> bytes: + """ + Get the version prefix for the cache. + Returns: + bytes: The version prefix as bytes, derived from the cache version string. + """ + return sha256(str(OnDiskCache.version).encode()).digest()[:4] + + @override + def get(self: Self, key: Key) -> Value | None: + """ + Retrieve a value from the cache. + Args: + key (Key): The key to look up. + Returns: + Value | None: The cached value if present and version matches, else None. + Raises: + CacheError: If the value is corrupted or cannot be unpickled. + Side Effects: + Removes stale cache files if the version prefix does not match. 
+ """ + fpath = self._fpath_from_key(key) + flock = self._flock_from_fpath(fpath) + + with flock: + if not fpath.is_file(): + return None + + value_bytes = None + prefix_length = len(self.version_prefix) + with open(fpath, "rb") as fp: + if fp.read(prefix_length) == self.version_prefix: + value_bytes = fp.read() + + if value_bytes is None: + # version_prefix did not match, so we can't read the stale + # cached value; we should also remove the stale cached value, + # so that key can be re-cached by the newer version + fpath.unlink() + return None + + try: + value = pickle.loads(value_bytes) + except pickle.UnpicklingError as err: + raise CacheError( + f"Failed to get key {key!r}, value is potentially corrupted (value is not un-pickle-able)." + ) from err + + return value + + @override + def insert(self: Self, key: Key, value: Value) -> bool: + """ + Insert a value into the cache. + Args: + key (Key): The key to insert. + value (Value): The value to associate with the key. + Returns: + bool: True if the value was inserted, False if the key already exists. + Raises: + CacheError: If the value is not pickle-able. + Side Effects: + Creates the cache directory if it does not exist. + """ + fpath = self._fpath_from_key(key) + flock = self._flock_from_fpath(fpath) + fpath.parent.mkdir(parents=True, exist_ok=True) + try: + # "x" mode is exclusive creation, meaning the file will be created + # iff the file does not already exist (atomic w/o overwrite); use + # flock for added atomicity guarantee and to prevent partial writes + with flock as _, open(fpath, "xb") as fp: + fp.write(self.version_prefix) + pickle.dump(value, fp) + except pickle.PicklingError as err: + raise CacheError( + f"Failed to insert key {key!r} with value {value!r}, value is not pickle-able." + ) from err + except FileExistsError: + return False + return True + + +class InductorOnDiskCache(OnDiskCache[Key, Value]): + """ + Inductor-specific on-disk cache implementation. 
+ Uses a custom base directory for Inductor cache files. + """ + + def __init__(self: Self) -> None: + """ + Initialize an inductor on-disk cache instance. + Sets the cache directory name to "inductor_on_disk_cache". + """ + super().__init__("inductor_on_disk_cache") + + @cached_property + def base_dir(self: Self) -> Path: + """ + Get the base directory for the Inductor cache. + Returns: + Path: The base directory path for Inductor cache files. + """ + from torch._inductor.runtime.runtime_utils import default_cache_dir + + return Path(default_cache_dir(), "cache", self.name) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/choices.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/choices.py new file mode 100644 index 0000000000000000000000000000000000000000..d2a89684a97302fb34933b933dc1b13ba1ac736d --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/choices.py @@ -0,0 +1,651 @@ +from __future__ import annotations + +import dataclasses +import typing +from typing import Any, Optional, TYPE_CHECKING, Union + +import sympy + +import torch +from torch._inductor.runtime.runtime_utils import next_power_of_2 +from torch._inductor.scheduler import MixOrderReduction +from torch.utils._sympy.value_ranges import bound_sympy + +from . 
import config +from .codecache import write_text +from .kernel_inputs import KernelInputs # noqa: TC001 +from .kernel_template_choice import make_ktc_generator +from .metrics import get_metric_table, is_metric_table_enabled +from .runtime.hints import DeviceProperties, ReductionHint +from .scheduler import BaseSchedulerNode, Scheduler, WhyNoFuse +from .select_algorithm import ExternKernelChoice +from .template_heuristics import get_template_heuristic +from .template_heuristics.triton import ( + BaseConfigHeuristic, + CPUConfigHeuristic, + CUDAConfigHeuristic, + MTIAConfigHeuristic, + ROCmConfigHeuristic, + XPUConfigHeuristic, +) +from .utils import _use_autotune_backend +from .virtualized import V + + +if TYPE_CHECKING: + from collections.abc import Generator + from functools import partial + + from triton import Config as TritonConfig + + from .codegen.common import KernelTemplate + from .codegen.simd_kernel_features import SIMDKernelFeatures + from .codegen.triton import TritonKernel + from .ir import ChoiceCaller + from .kernel_template_choice import KernelTemplateChoice + + from torch.utils._ordered_set import OrderedSet # isort: skip + + +class Sortable(typing.Protocol): + """Anything that can be used as a list.sort() key (int/tuple/etc)""" + + def __lt__(self, other: typing.Self) -> bool: ... 
+ + +@dataclasses.dataclass +class FusionScore: + template_score: int + node_type_score: bool + memory_score: int + proximity_score: int + + def __lt__(self, other): + """ + node_type_score has higher priority than memory_score unless + the memory_score differs too much + """ + threshold = 16 + if self.template_score != other.template_score: + return self.template_score < other.template_score + + if ( + max(self.memory_score, other.memory_score) + > min(self.memory_score, other.memory_score) * threshold + ): + return self.memory_score < other.memory_score + + return (self.node_type_score, self.memory_score, self.proximity_score) < ( + other.node_type_score, + other.memory_score, + other.proximity_score, + ) + + +class InductorChoices: + """ + This class contains a collection of default heuristics that effect performance of our generated + code. We try to not put correctness requirements in this file. + + You can override the choices made here by doing: + + class MyHeuristics(InductorChoices): + ... 
+ + torch._inductor.virtualized.V.set_choices_handler(MyHeuristics()) + """ + + def get_config_heuristics( + self, device_type: Optional[str] = "cuda" + ) -> BaseConfigHeuristic: + if device_type == "cuda": + if torch.version.hip is None: + return CUDAConfigHeuristic() + else: + return ROCmConfigHeuristic() + elif device_type == "xpu": + return XPUConfigHeuristic() + elif device_type == "cpu": + return CPUConfigHeuristic() + elif device_type == "mtia": + return MTIAConfigHeuristic() + else: + return BaseConfigHeuristic() + + # Conv configs + def get_conv_configs( + self, device_type: Optional[str] = "cuda" + ) -> partial[Generator[TritonConfig, None, None]]: + conv_heuristics = self.get_config_heuristics(device_type) + return conv_heuristics.get_conv_configs() + + # Flex attention configs + # TODO(coconutruben): break out flexattention/decode configs into the new retrieval mechanism + def get_flex_attention_fwd_configs( + self, head_dim: int, dtype: torch.dtype, device_type: Optional[str] = "cuda" + ) -> list[Any]: + flex_heuristics = self.get_config_heuristics(device_type) + return flex_heuristics.get_flex_attn_fwd_configs(head_dim, dtype) + + def get_flex_attention_bwd_configs( + self, head_dim: int, dtype: torch.dtype, device_type: Optional[str] = "cuda" + ) -> list[Any]: + flex_heuristics = self.get_config_heuristics(device_type) + return flex_heuristics.get_flex_attn_bwd_configs(head_dim, dtype) + + def get_flex_decode_configs( + self, head_dim: int, dtype: torch.dtype, device_type: Optional[str] = "cuda" + ) -> list[Any]: + flex_heuristics = self.get_config_heuristics(device_type) + return flex_heuristics.get_flex_decode_configs(head_dim, dtype) + + def _finalize_template_configs( + self, + template_choices: dict[str, Generator[KernelTemplateChoice, None, None]], + kernel_inputs: KernelInputs, + templates: list[Union[KernelTemplate, ExternKernelChoice]], + op_name: str, + kwarg_overrides: Optional[dict[str, dict[str, Any]]] = None, + ) -> 
list[KernelTemplateChoice]: + """ + This method can be subclassed to perform any override/modification of the choices. + The incoming parameters are cheap (generators), so you can do any overrides without + incurring too much cost. Override this method to customize the kernel template choices + before they are converted to ChoiceCaller objects, which is expensive on template codegen. + + The full list of arguments are here to facilitate any overrides you may want to do, + as they can be used to start from scratch for each template if so desired. + + Args: + template_choices: Dictionary mapping template UIDs to generators of KernelTemplateChoice objects + kernel_inputs: MMKernelInputs containing input tensor nodes and matrix indices + templates: List of template objects (KernelTemplate or ExternKernelChoice) in use + op_name: Operation name (e.g., "bmm", "baddbmm", "addmm") + kwarg_overrides: Optional dict of kwargs to override for each template heuristic + + Returns: + Flattened list of KernelTemplateChoice objects across all templates + """ + choices: list[KernelTemplateChoice] = [] + for choice_gen in template_choices.values(): + choices.extend(choice_gen) + return choices + + def get_ktc( + self, + kernel_inputs: KernelInputs, + template: Union[KernelTemplate, ExternKernelChoice], + op_name: str, + kwarg_overrides: Optional[dict[str, Any]] = None, + ) -> Generator[KernelTemplateChoice, None, None]: + """ + Utility to get the KernelTemplateChoice generator for a specific input. + + This is a per template/op call, whereas get_template_configs is an op wide call (all templates). 
+ Consider when overriding/using at which level you need to make decisions + """ + # Extract device_type from kernel_inputs + device_type = kernel_inputs.device_type + assert device_type is not None, "get_ktc requires a valid device type" + # Extract template_name from the template object + template_name = template.uid + + # Get the appropriate template-specific heuristic + heuristic = get_template_heuristic(template_name, device_type, op_name) + cs = heuristic.get_template_configs( + kernel_inputs, + op_name, + ) + # adjust the kernel inputs to the template-specific heuristic, if needed + # default here is to just return the kernel_inputs as is + inputs_val = heuristic.adjust_kernel_inputs(kernel_inputs, op_name) + extra_kwargs = heuristic.get_extra_kwargs(kernel_inputs, op_name) + # Create KernelTemplateChoice generator using the moved function + overrides = kwarg_overrides or {} + return make_ktc_generator( + template=template, + cs=cs, + extra_kwargs=extra_kwargs, + overrides=overrides, + layout=kernel_inputs.output_layout(), + inputs=inputs_val, + ) + + def _need_to_fix_layout( + self, + adjusted_choices: list[KernelTemplateChoice], + op_name: str, + ) -> bool: + """ + Check if we need to fix the layout instead of keeping it flexible + + Args: + ktc: KernelTemplateChoice object + + Returns: + True if we need to fix the layout, False otherwise + """ + # TODO: debug and fix + # NOTE: on mps, we see issues with flexible layouts on baddmm. 
This check just makes sure + # that for mps, everything stays as it was before this optimization + if len(adjusted_choices) > 0: + if adjusted_choices[0].inputs.device_type == "mps" and op_name not in [ + "mm", + "addmm", + ]: + return True + + # Since the following backends are not using get_mm_configs yet through the singular call, + if not (config.max_autotune or config.max_autotune_gemm): + # no danger of using other backends than ATEN + if not config.max_autotune_allow_flexible_layouts and op_name not in [ + # The historical implementation for mm and addmm allowed had flexible layouts in the + # not max-autotune world + "mm", + "addmm", + ]: + # TODO: deprecate this by migrating users to the new behavior + return True + return False + + if not config.max_autotune_allow_flexible_layouts: + # we always need to fix the layout + return True + + # Since the following backends are not using get_template_configs yet through the singular call, + # we don't know if they are a valid choice or not. Instead, just skip the optimization + # defensively. + # TODO(coconutruben): remove this once CPP,CK,CUTLASS are supported + if _use_autotune_backend("CUTLASS"): + return True + if _use_autotune_backend("CK") or _use_autotune_backend("CKTILE"): + return True + if _use_autotune_backend("CPP"): + return True + return any( + not isinstance(ktc.template, ExternKernelChoice) for ktc in adjusted_choices + ) + + def get_template_configs( + self, + kernel_inputs: KernelInputs, + templates: list[Union[KernelTemplate, ExternKernelChoice]], + op_name: str, + kwarg_overrides: Optional[dict[str, dict[str, Any]]] = None, + ) -> list[ChoiceCaller]: + """ + Get list of ChoiceCallers for MM templates using template-specific heuristics. 
+ + Args: + kernel_inputs: MMKernelInputs containing input tensor nodes and matrix indices + layout: Output layout + templates: List of template objects (KernelTemplate or ExternKernelChoice) + op_name: Operation name (e.g., "bmm", "baddbmm", "addmm", "mm_plus_mm") + kwarg_overrides: Optional dict of kwargs to override for each template heuristic, + indexed by template.uid. These only override the per config kwargs, not the extra kwargs + Returns: + List of ChoiceCaller objects from the templates + """ + if kwarg_overrides is None: + kwarg_overrides = {} + input_tensors = kernel_inputs.nodes() + if len(input_tensors) < 2: + raise ValueError(f"Need at least 2 input tensors, got {len(input_tensors)}") + layout = kernel_inputs.output_layout() + # First pass: Create dict of template.uid to generator of KernelTemplateChoice objects + template_choices = {} + for template in templates: + template_choices[template.uid] = self.get_ktc( + kernel_inputs, + template, + op_name, + kwarg_overrides.get(template.uid, {}), + ) + + # Second pass: Adjust the template choices + adjusted_choices = self._finalize_template_configs( + template_choices, + kernel_inputs, + templates, + op_name, + kwarg_overrides, + ) + # Layout optimization: if all choices are ExternKernelChoice and layout is FixedLayout, convert to FlexibleLayout + if self._need_to_fix_layout(adjusted_choices, op_name): + layout = kernel_inputs.output_layout(flexible=False) + for ktc in adjusted_choices: + ktc.layout = layout + # for good measure, delete the cached ChoiceCaller from the ktc if it existed. 
+ # ExternKernelChoice are cheap to generate + if hasattr(ktc, "_choice"): + del ktc._choice + # Third pass: Convert to ChoiceCaller objects + return [ktc.choice for ktc in adjusted_choices if ktc.choice is not None] + + def triton_kernel_kwargs( + self, + kernel_cls: type[TritonKernel], + features: SIMDKernelFeatures, + groups: list[sympy.Expr], + kernel_kwargs: dict[str, Any], + ) -> dict[str, Any]: + """Hook to change the kwargs passed to TritonKernel, used to apply fixed configurations""" + return kernel_kwargs + + @staticmethod + def should_use_cooperative_reduction(features: SIMDKernelFeatures) -> bool: + """Heuristic to decide if a cooperative reduction should be used.""" + if config.triton.force_cooperative_reductions: + return True + if ( + not config.triton.cooperative_reductions + or V.graph.get_current_device_or_throw().type == "cpu" + ): + return False + + xhint = V.graph.sizevars.size_hint(features.numel, fallback=2) + if xhint <= 8: + threshold = 32768 * xhint + elif xhint <= 16: + threshold = 2097152 + else: + return False + # TODO(jansel): should this default on for dynamic shapes? + return V.graph.sizevars.statically_known_geq( + features.reduction_numel, threshold + ) + + @staticmethod + def should_use_persistent_reduction( + features: SIMDKernelFeatures, cooperative_reduction: bool + ) -> bool: + """ + Heuristic to decide if a persistent reduction should be used. 
+ """ + if not config.triton.persistent_reductions: + return False + threshold = { + ReductionHint.INNER: 1024, + }.get(features.get_reduction_hint(), 64) + + if features.get_reduction_hint() not in ( + ReductionHint.INNER, + ReductionHint.OUTER_TINY, + ): + bounds = bound_sympy(features.reduction_numel) + lower = bounds.lower + upper = bounds.upper + + if not all( + ( + (isinstance(bound, int) or bound.is_constant()) + and bound != torch.utils._sympy.numbers.IntInfinity() + ) + for bound in (lower, upper) + ): + return False + + lower = next_power_of_2(int(lower)) + upper = next_power_of_2(int(upper)) + + # If we are are coalescing on xblock (not ReductionHint.INNER) and this is not a tiny kernel + # (not ReductionHint.OUTER_TINY), do not use persistent reduction if it induces tile + # quantization. Persistent reduction forces rblock == rnumel, if the bounds between lower + # and upper are large, for the lower values we will be masking off large % of read/writes, + # when we could expand the coalescing xblock instead. + if lower != upper: + return False + + if cooperative_reduction: + # The RSPLIT of cooperative reductions means each thread block is operating on fewer elements + try: + threshold *= 32 // min( + V.graph.sizevars.size_hint_or_throw(features.numel), 32 + ) + except ValueError: + pass # unbacked symint + + # If multi_kernel is enabled, we do more aggressive persistent reduction. + # This may result in some persistent reductions slower than the + # corresponding non-persistent reductions. MultiKernel will do benchmarking + # to pick the faster one. + if config.triton.multi_kernel: + threshold *= 16 + + return V.graph.sizevars.statically_known_leq( + features.reduction_numel, threshold + ) # type: ignore[arg-types] + + @staticmethod + def reduction_split_factor( + device: torch.device, + reduction_numel_hint: int, + numel_hint: int, + inner_reduction: bool, + ) -> int: + """Heuristic to decide the RSPLIT used for split reductions. 
+ When a reduction has a small number of outputs there is not enough parallelism, + so we will do the reduction in two phases.""" + props = DeviceProperties.create(device) + num_sm = props.multi_processor_count + min_elements_per_thread = 32 + max_elements_per_thread = 512 + threads_per_sm = 2048 + min_elements_per_device = min_elements_per_thread * num_sm * threads_per_sm + max_elements_per_device = max_elements_per_thread * num_sm * threads_per_sm + num_warps = 8 + num_threads = 32 * num_warps + + if inner_reduction: + # do heuristics that's close to eager mode for split inner reduction + # we leak reduction autotune configs here, and will need to refactor to avoid this later + if numel_hint >= 2 * num_sm: # don't split if there are enough outputs + return 1 + if reduction_numel_hint <= 8192: + return 1 + if reduction_numel_hint * numel_hint <= min_elements_per_device: + split_size = min_elements_per_thread + elif reduction_numel_hint * numel_hint < max_elements_per_device: + target_blocks = num_sm * threads_per_sm // (2 * num_threads) + blocks_per_output = (target_blocks + numel_hint - 1) // numel_hint + tmp_split_size = ( + reduction_numel_hint + num_threads * blocks_per_output - 1 + ) // (num_threads * blocks_per_output) + divisors = sympy.divisors(reduction_numel_hint) + closest = min(divisors, key=lambda x: abs(x - tmp_split_size)) + if abs(closest - tmp_split_size) < 30: + # prefer even splits, but never smalle than min_elements_per_thread + split_size = max(closest, min_elements_per_thread) + else: + split_size = tmp_split_size + else: + divisors = sympy.divisors(reduction_numel_hint) + closest = min(divisors, key=lambda x: abs(x - max_elements_per_thread)) + if abs(closest - max_elements_per_thread) < 50: + # prefer even splits + split_size = closest + else: + split_size = max_elements_per_thread + return (reduction_numel_hint + split_size * num_threads - 1) // ( + split_size * num_threads + ) + else: + # TODO the best heuristic currently has XBLOCK 
(corresponding to numel_hint) 128 + # extend to even smaller number of outputs + rvals_per_thread = 4 # comes from heuristics, refactor to not leak here + xvals_per_block = 128 + xblocks = (numel_hint + xvals_per_block - 1) // xvals_per_block + if reduction_numel_hint * numel_hint < min_elements_per_device: + split_size = min_elements_per_thread + elif reduction_numel_hint * numel_hint < max_elements_per_device: + target_blocks = num_sm * threads_per_sm // (num_threads) + target_blocks = (target_blocks + xblocks - 1) // xblocks + tmp_split_size = ( + reduction_numel_hint + rvals_per_thread * target_blocks - 1 + ) // (rvals_per_thread * target_blocks) + divisors = sympy.divisors(reduction_numel_hint) + closest = min(divisors, key=lambda x: abs(x - tmp_split_size)) + if abs(tmp_split_size - closest) < 20: + split_size = max(closest, min_elements_per_thread) + else: + split_size = tmp_split_size + else: + divisors = sympy.divisors(reduction_numel_hint) + closest = min(divisors, key=lambda x: abs(x - max_elements_per_thread)) + if abs(closest - max_elements_per_thread) < 50: + # prefer even splits + split_size = closest + else: + split_size = max_elements_per_thread + + return (reduction_numel_hint + rvals_per_thread * split_size - 1) // ( + rvals_per_thread * split_size + ) + + @staticmethod + def can_fuse( + scheduler: Scheduler, + node1: BaseSchedulerNode, + node2: BaseSchedulerNode, + shared_data_score: int, + ) -> bool: + """ + Heuristics to prevent fusion applied to both horizontal and vertical fusions. Heuristics here should not + be needed for correctness and tweaking them may yield additional performance. 
+ + See also some related heuristics that can be changed via config: + - config.triton.tiling_prevents_pointwise_fusion + - config.triton.tiling_prevents_reduction_fusion + - config.aggressive_fusion (will cause this function to be called more times) + """ + if shared_data_score == 0 and ( + not config.aggressive_fusion or node1.is_reduction() or node2.is_reduction() + ): + if is_metric_table_enabled("fusion_failure_due_to_indexing_mismatch"): + common_buf_names: OrderedSet[str] = ( + node1.read_writes.buffer_names() & node2.read_writes.buffer_names() + ) + if len(common_buf_names) > 0: + get_metric_table("fusion_failure_due_to_indexing_mismatch").add_row( + lambda: { + "pre_grad_graph_id": V.graph.graph_id, + "post_grad_graph_id": V.graph.post_grad_graph_id, + "node1_name": node1.get_name(), + "node2_name": node2.get_name(), + "node1_debug_str": write_text(node1.debug_str()), + "node2_debug_str": write_text(node2.debug_str()), + "common_buffer_names": list(common_buf_names), # type: ignore[dict-item] + "failure_reason": scheduler.decide_fusion_fail_reason( + node1, node2, common_buf_names + ), + } + ) + + WhyNoFuse(node1, node2)("no shared data due to indexing mismatch") + return False + WhyNoFuse(node1, node2)("no shared data") + return False # heuristic not needed for correctness + + if ( + not node1.is_foreach() + and not node2.is_foreach() + and len(node1.get_nodes()) + len(node2.get_nodes()) > config.max_fusion_size + ): + WhyNoFuse(node1, node2)("exceeds max fusion") + return False # heuristic not needed for correctness + + if scheduler.can_fusion_increase_peak_memory(node1, node2): + WhyNoFuse(node1, node2)("Fusion will increase peak memory") + return False + + if ( + config.max_fusion_unique_io_buffers is not None + and scheduler.fusion_prevent_too_many_reads_and_writes( + node1, + node2, + config.max_fusion_unique_io_buffers, + ) + ): + WhyNoFuse(node1, node2)("fusion_prevent_too_many_reads_and_writes") + return False + + return True + + @staticmethod + 
def can_fuse_vertical( + scheduler: Scheduler, + node1: BaseSchedulerNode, + node2: BaseSchedulerNode, + shared_data_score: int, + ) -> bool: + """Hook for heuristics to prevent vertical (producer/consumer) fusions""" + return True + + @staticmethod + def can_fuse_horizontal( + scheduler: Scheduler, + node1: BaseSchedulerNode, + node2: BaseSchedulerNode, + shared_data_score: int, + ) -> bool: + """Hook for heuristics to prevent horizontal (consumer/consumer) fusions""" + if MixOrderReduction.can_fuse(node1, node2): + # For mix order reduction, we disregard shared data or + # distance. + return True + if shared_data_score < config.score_fusion_memory_threshold: + WhyNoFuse(node1, node2)("score_fusion_memory_threshold") + return False + if scheduler.are_long_distant_nodes(node1, node2): + WhyNoFuse(node1, node2)( + "Nodes are too far away. Fusing them may increase peak memory." + ) + return False + return True + + @staticmethod + def score_fusion( + scheduler: Scheduler, + node1: BaseSchedulerNode, + node2: BaseSchedulerNode, + ) -> Sortable: + """ + Assign a score (higher comes first) to the fusion of node1 and node2. + When different fusions conflict with each other, this is the way we + decide what order to run them in. 
+ + Our current score is based on: + - The type of fusion (template/reduction/etc) + - Estimate of the saved memory operations + - Fusions closer together in original graph order + """ + + memory_score, is_mix_order_reduction = typing.cast( + tuple[int, bool], + scheduler.score_fusion_memory( + node1, node2, return_is_mix_order_reduction=True + ), + ) + proximity_score = -max( + abs(node1.min_order - node2.max_order), + abs(node2.min_order - node1.max_order), + ) + + # prologue fusion always last + if node2.is_template(): + template_score = 0 + else: + template_score = 1 + ( + (node1.is_template() == config.epilogue_fusion_first) + and memory_score > 0 + ) + + type_score = node1.is_reduction() == node2.is_reduction() and memory_score > 0 + + # pyrefly: ignore [bad-return] + return FusionScore( + template_score, + type_score, + memory_score, + proximity_score, + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codecache.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codecache.py new file mode 100644 index 0000000000000000000000000000000000000000..aad56bca31d6c5436e13affbc2d90d29a03e10f6 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/codecache.py @@ -0,0 +1,4445 @@ +from __future__ import annotations + +import base64 +import copyreg +import dataclasses +import functools +import hashlib +import importlib +import importlib.resources +import io +import itertools +import json +import logging +import os +import pickle +import pkgutil +import platform +import re +import shlex +import shutil +import struct +import subprocess +import sys +import tempfile +import textwrap +import threading +import warnings +from bisect import bisect_right +from copy import copy +from ctypes import c_void_p, CDLL, cdll +from datetime import timedelta +from functools import lru_cache, partial +from pathlib import Path +from tempfile import _TemporaryFileWrapper +from time 
import time, time_ns +from types import ModuleType +from typing import Any, cast, Generic, NoReturn, TYPE_CHECKING, TypeVar, Union +from typing_extensions import override, Self + +import torch +import torch.distributed as dist +from torch import SymInt, Tensor +from torch._dynamo.device_interface import get_interface_for_device +from torch._dynamo.exc import SkipFrame +from torch._dynamo.utils import ( + CompileEventLogger, + counters, + dynamo_timed, + get_metrics_context, +) +from torch._inductor import config, exc, metrics +from torch._inductor.codegen.common import ( + custom_backend_codegen_configs, + custom_backend_passes, + init_backend_registration, +) +from torch._inductor.codegen.cuda import cuda_env +from torch._inductor.codegen.rocm.compile_command import ( + rocm_compile_command, + rocm_compiler, +) +from torch._inductor.compile_worker.utils import in_toplevel_process +from torch._inductor.cpp_builder import ( + _LINKER_SCRIPT, + _set_gpu_runtime_env, + _TORCH_PATH, + _transform_cuda_paths, + convert_cubin_to_obj, + CppBuilder, + CppOptions, + CppTorchDeviceOptions, + get_compiler_version_info, + get_ld_and_objcopy, + get_name_and_dir_from_output_file_path, + normalize_path_separator, + run_asm_build_object, +) +from torch._inductor.cpu_vec_isa import pick_vec_isa +from torch._inductor.custom_graph_pass import ( + CustomGraphModulePass, + CustomGraphPass, + CustomGraphPassType, + CustomPartitionerFn, + CustomPartitionerFnType, +) +from torch._inductor.freezing_utils import has_frozen_params, is_frozen_param +from torch._inductor.runtime.compile_tasks import _reload_python_module +from torch._inductor.runtime.runtime_utils import cache_dir, default_cache_dir +from torch._inductor.utils import ( + ALIGN_BYTES, + clear_on_fresh_cache, + determine_aoti_mmap_flags, + is_linux, + is_windows, +) +from torch._logging import trace_structured +from torch._subclasses.fake_tensor import ( + extract_tensor_metadata, + FakeTensor, + TensorMetadata, +) +from 
torch._utils_internal import log_cache_bypass +from torch.compiler import config as cconfig +from torch.compiler._cache import ( + CacheArtifact, + CacheArtifactFactory, + CacheArtifactManager, +) +from torch.export.pt2_archive._package_weights import TensorProperties, Weights +from torch.export.pt2_archive.constants import CUSTOM_OBJ_FILENAME_PREFIX +from torch.fx.experimental.symbolic_shapes import has_hint, hint_int, ShapeEnv +from torch.utils._ordered_set import OrderedSet + +from .output_code import CompiledFxGraph +from .remote_cache import create_cache +from .runtime import autotune_cache +from .runtime.autotune_cache import AutotuneCacheBundler +from .triton_bundler import TritonBundler +from .virtualized import V + + +if config.is_fbcode(): + from triton.fb.build import build_paths + + +T = TypeVar("T") + +if TYPE_CHECKING: + from collections.abc import Callable, Generator, KeysView, Sequence + from concurrent.futures import Future + + from .compile_fx import _CompileFxKwargs + from .cpp_builder import BuildOptionsBase + from .graph import GraphLowering + from .ir import ChoiceCaller + from .output_code import CompiledFxGraphConstants, OutputCode + from .remote_cache import JsonDataTy, RemoteCache + from .runtime.hints import HalideInputSpec, HalideMeta + from .runtime.triton_heuristics import CachingAutotuner + from .utils import InputType + + +_IS_WINDOWS = sys.platform == "win32" +LOCK_TIMEOUT = config.file_lock_timeout + +output_code_log = torch._logging.getArtifactLogger(__name__, "output_code") +autotuning_log = torch._logging.getArtifactLogger(__name__, "autotuning") +log = logging.getLogger(__name__) + + +def use_re_build() -> bool: + """ + Use for CUTLASS compilation only right now. 
+ """ + if config.is_fbcode() and not cuda_env.nvcc_exist(_cuda_compiler()): + from triton.fb.re_build_helper import should_build_locally + + return not should_build_locally() + return False + + +def get_cpp_wrapper_cubin_path_name() -> str: + return "cubin_path" if torch.version.hip is None else "hsaco_path" + + +def get_kernel_bin_format(device: str) -> str: + if device == "cuda": + return "cubin" if torch.version.hip is None else "hsaco" + elif device == "xpu": + return "spv" + else: + return "" + + +def get_device_information(device_type: str) -> dict[str, str]: + """ + Gets all the current device information used to compile the .so. + """ + metadata: dict[str, str] = { + "AOTI_PLATFORM": sys.platform, + "AOTI_MACHINE": platform.machine(), + "AOTI_CPU_ISA": str(torch._inductor.cpu_vec_isa.pick_vec_isa()).upper(), + "AOTI_COMPUTE_CAPABILITY": str( + get_interface_for_device(device_type).get_compute_capability() + ), + } + return metadata + + +class CacheBase: + @staticmethod + @functools.cache + def get_system() -> dict[str, Any]: + from torch._inductor.runtime.triton_compat import HAS_TRITON, triton_key + + if HAS_TRITON: + # Use triton_key instead of triton.__version__ as the version + # is not updated with each code change + triton_version = triton_key() + else: + triton_version = None + + try: + system: dict[str, Any] = { + "device": {"name": None}, + "version": { + "triton": triton_version, + }, + } + device_properties = torch.cuda.get_device_properties( + torch.cuda.current_device() + ) + if torch.version.cuda is not None: + system["device"]["name"] = device_properties.name + system["version"]["cuda"] = torch.version.cuda + else: + system["device"]["name"] = device_properties.gcnArchName + system["version"]["hip"] = torch.version.hip + except (AssertionError, RuntimeError): + # If cuda is not installed, none of the above config is relevant. 
+ system = {} + + system["hash"] = hashlib.sha256( + json.dumps(system, sort_keys=True).encode("utf-8") + ).hexdigest() + + return system + + @staticmethod + @clear_on_fresh_cache + @functools.cache + def get_local_cache_path() -> Path: + return Path(os.path.join(cache_dir(), "cache", CacheBase.get_system()["hash"])) + + def __init__(self) -> None: + self.system = CacheBase.get_system() + + def get_local_cache(self) -> dict[str, Any]: + local_cache_path = self.get_local_cache_path() + if not local_cache_path.is_file(): + return {} + with open(local_cache_path) as local_cache_fp: + local_cache = json.load(local_cache_fp) + return local_cache["cache"] + + def update_local_cache(self, local_cache: dict[str, Any]) -> None: + local_cache_path = self.get_local_cache_path() + write_atomic( + str(local_cache_path), + json.dumps({"system": self.system, "cache": local_cache}, indent=4), + make_dirs=True, + ) + + +class LocalCache(CacheBase): + def lookup(self, *keys: str) -> dict[str, Any] | None: + cache = self.get_local_cache() + + sub_cache = cache + for key in keys: + if key in cache: + sub_cache = cache[key] + else: + return None + + return sub_cache + + def set_value(self, *keys: str, value: Any) -> None: + cache = self.get_local_cache() + + sub_cache = cache + for key in keys[0:-1]: + sub_cache.setdefault(key, {}) + sub_cache = sub_cache[key] + sub_cache[keys[-1]] = value + + self.update_local_cache(cache) + + +class PersistentCache(CacheBase): + def lookup( + self, + choices: list[ChoiceCaller], + op: str, + inputs: str, + benchmark: Callable[[Any], dict[ChoiceCaller, float]] | None, + hint_override: int | None = None, + ) -> dict[ChoiceCaller, float]: + """ + Check to see if we have benchmarked the given choice callers. For each + choice caller: + + 1. Check local_cache[op][inputs][choice][precision], return benchmark if cached. + 2. If benchmark is not None: + a. 
`max_autotune_gemm=True`: benchmark the choice, update + local_cache[op][inputs][choice], and return the benchmark. + b. `max_autotune_gemm=False`: don't benchmark the choice, return nothing. + """ + precision = torch.get_float32_matmul_precision() + cache_key = f"{inputs}_{hint_override}" if hint_override is not None else inputs + + timings = {} + + def check_cache(cache: dict[str, Any]) -> bool: + """Check if `cache` contains data for all the choices""" + hit = True + for choice in choices: + choice_hash = choice.hash_key() + if choice_hash in cache.get(op, {}).get(cache_key, {}).get( + precision, {} + ): + # cache hit + timings[choice] = cache[op][cache_key][precision][choice_hash] + else: + # cache miss + hit = False + break + return hit + + local_cache = self.get_local_cache() if config.autotune_local_cache else {} + if (not check_cache(local_cache)) and (benchmark is not None): + # re-benchmark everything to try to get consistent numbers from the same machine + timings = benchmark(choices) + assert all(choice in timings for choice in choices) + local_cache.setdefault(op, {}) + local_cache[op].setdefault(cache_key, {}).setdefault(precision, {}) + for choice, timing in timings.items(): + local_cache[op][cache_key][precision][choice.hash_key()] = timing + + self.update_local_cache(local_cache) + + return timings + + +def get_lock_dir() -> str: + lock_dir = os.path.join(cache_dir(), "locks") + if not os.path.exists(lock_dir): + os.makedirs(lock_dir, exist_ok=True) + return lock_dir + + +def sha256_hash(data: bytes) -> str: + # [:51] to strip off the "Q====" suffix common to every hash value. 
+ return base64.b32encode(hashlib.sha256(data).digest())[:51].decode("utf-8").lower() + + +def code_hash(code: str | bytes, extra: str | bytes = "") -> str: + hashing_str = code if isinstance(code, bytes) else code.encode("utf-8") + if extra: + extra_b = extra if isinstance(extra, bytes) else extra.encode("utf-8") + hashing_str = hashing_str + b"||" + extra_b + return "c" + sha256_hash(hashing_str) + + +def get_path( + basename: str, extension: str, specified_dir: str = "" +) -> tuple[str, str, str]: + if specified_dir: + if os.path.isabs(specified_dir): + subdir = specified_dir + else: + subdir = os.path.join(cache_dir(), specified_dir) + else: + subdir = os.path.join(cache_dir(), basename[1:3]) + path = os.path.join(subdir, f"{basename}.{extension}") + return basename, subdir, path + + +def get_hash(content: str | bytes, extra: str = "", hash_type: str = "code") -> str: + if hash_type in {"amdgcn", "code", "ptx", "spv"}: + return code_hash(content, extra) + if hash_type in {"cubin", "hsaco", "spv"}: + return code_hash(repr(content)) + raise AssertionError(f"Unknown hash type {hash_type}") + + +class WritableTempFile: + """ + Avoid "Permission denied error" on Windows: + with tempfile.NamedTemporaryFile("w", suffix=".gv") as temp_file: + # Not writable on Windows: + # https://docs.python.org/3/library/tempfile.html#tempfile.NamedTemporaryFile + + Example: + with WritableTempFile("w", suffix=".gv") as temp_file: + tree.to_dotfile(temp_file.name) + """ + + def __init__( + self, mode: str = "w", *, encoding: Any = None, suffix: Any = None + ) -> None: + self.mode = mode + self.encoding = encoding + self.suffix = suffix + + def __enter__(self) -> _TemporaryFileWrapper[Any]: + self.temp_file = tempfile.NamedTemporaryFile( + self.mode, encoding=self.encoding, suffix=self.suffix, delete=False + ) + return self.temp_file + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + self.temp_file.close() + try: + os.unlink(self.temp_file.name) + except 
OSError as e: + if _IS_WINDOWS: + # On Windows, some case temp file is opened and fail to unlink. Need to ignore it. + pass + else: + raise e + + +def write( + content: str | bytes, + extension: str, + extra: str = "", + hash_type: str = "code", + specified_dir: str = "", + key: str | None = None, +) -> tuple[str, str]: + if key is None: + # use striped content to compute hash so we don't end up with different + # hashes just because the content begins/ends with different number of + # spaces. + key = get_hash(content.strip(), extra, hash_type) + basename, _subdir, path = get_path(key, extension, specified_dir) + if not os.path.exists(path): + write_atomic(path, content, make_dirs=True) + return basename, path + + +def write_text(text: str) -> str: + """ + Write the `text` to a file and return the path computed based on the hash. + """ + return write(text, "txt")[1] + + +def write_atomic( + path_: str, + content: str | bytes, + make_dirs: bool = False, + encode_utf_8: bool = False, +) -> None: + # Write into temporary file first to avoid conflicts between threads + # Avoid using a named temporary file, as those have restricted permissions + assert isinstance(content, (str, bytes)), ( + "Only strings and byte arrays can be saved in the cache" + ) + path = Path(path_) + if make_dirs: + path.parent.mkdir(parents=True, exist_ok=True) + tmp_path = path.parent / f".{os.getpid()}.{threading.get_ident()}.tmp" + write_mode = "w" if isinstance(content, str) else "wb" + with tmp_path.open(write_mode, encoding="utf-8" if encode_utf_8 else None) as f: + f.write(content) + try: + tmp_path.rename(target=path) + except FileExistsError: + if not _IS_WINDOWS: + raise + # On Windows file exist is expected: https://docs.python.org/3/library/pathlib.html#pathlib.Path.rename + # Below two lines code is equal to `tmp_path.rename(path)` on non-Windows OS. + # 1. Copy tmp_file to Target(Dst) file. + shutil.copy2(src=tmp_path, dst=path) + # 2. Delete tmp_file. 
+ os.remove(tmp_path) + + +@dataclasses.dataclass +class TensorMetadataAndValues: + """ + TensorMetadata plus the elements as a list of raw values. + Used for hashing inlined constants. + """ + + tensor_metadata: TensorMetadata + values: list[Any] + + +def _ident(x: T) -> T: + return x + + +def extract_tensor_metadata_for_cache_key(t: Tensor) -> TensorMetadata: + """ + Extracts the tensor metadata and removes fields of the TensorMetadata + that are not needed for caching + """ + meta = extract_tensor_metadata(t) + if not hasattr(t, "_is_inductor_static"): + meta = dataclasses.replace(meta, storage_offset=0, storage_bytes=None) + + return meta + + +class FxGraphCachePickler(pickle.Pickler): + """ + Custom pickler to customize the pickling of some objects (Tensors), only for the + purpose of computing a hash for keying into the FxGraphCache. Tensors contain + objects that don't pickle and/or vary between runs, and we want to capture the + data that allow us to compute a stable, but safe hash. + """ + + def __init__( + self, + gm: torch.fx.GraphModule, + has_user_defined_triton_kernels: bool = False, + ) -> None: + """ + Create an FX graph pickler. If include_non_inlined=True, then pickling will + include the _values_ for all Tensors. (Note that any tensors are constants + attached as attributes to the GraphModule). Otherwise, pickling will include + only the metadata for these tensors. 
+ """ + self._stream = io.BytesIO() + super().__init__(self._stream) + + self.dispatch_table = copyreg.dispatch_table.copy() + self.dispatch_table.update( + { + FakeTensor: functools.partial(self._reduce_fake_tensor), + torch.Tensor: functools.partial(self._reduce_tensor), + torch.nn.parameter.Parameter: functools.partial(self._reduce_tensor), + torch.SymInt: functools.partial(self._reduce_symint), + torch.fx.experimental._backward_state.BackwardState: functools.partial( + self._reduce_unsupported + ), + } + ) + if has_user_defined_triton_kernels: + # Need to use runtime type as GraphModule generates a singleton in __new__ function + self.dispatch_table[gm.__class__] = functools.partial( + self._reduce_graph_module + ) + + # Run with pickler.fast so it doesn't intern strings, making the hash result more predictable + # TODO: pickler.fast is technically deprecated. Will this work on new python versions? + self.fast = True + + def _reduce_fake_tensor( + self, t: Tensor + ) -> tuple[Callable[[T], T], tuple[TensorMetadata]]: + """ + Custom reducer to pickle FakeTensors. + """ + metadata = extract_tensor_metadata_for_cache_key(t) + return (_ident, (metadata,)) + + def _reduce_tensor( + self, t: Tensor + ) -> tuple[Callable[[T], T], tuple[TensorMetadata | TensorMetadataAndValues]]: + """ + Custom reducer to pickle Tensors. If we see tensors, we know they're constants + stored as attributes on the GraphModule. + """ + from .graph import GraphLowering + + if t.is_mkldnn: + # TODO: These tensors don't currently pickle, so we can't cache a compiled + # graph containing them. Just fail now. If mkldnn tensors get pickling + # support, we can remove this. + raise BypassFxGraphCache("mkldnn tensors unpickleable") + + metadata = extract_tensor_metadata_for_cache_key(t) + + # If this is a non-inlined frozen parameter, we consider the metadata only. 
+ if is_frozen_param(t) and not GraphLowering.can_inline_constant(t): + return (_ident, (metadata,)) + + # Very large tensors will be expensive to copy to cpu and hash. Let's at least + # report any slowness. + start = time() + values = t.tolist() + elapsed = time() - start + if elapsed > 1.0: + warnings.warn( + f"FX graph cache copying of a large constant took {elapsed:.1}s. " + "Please file an issue." + ) + + return (_ident, (TensorMetadataAndValues(metadata, values),)) + + def _reduce_symint(self, s: SymInt) -> tuple[Callable[[T], T], tuple[str]]: + """ + Custom reducer to pickle SymInts. + """ + # For hashing purposes, we only care about the name of the symbol and not the + # backed value. We evaluate guards stored with a cached graph to ensure a cached + # entity with SymInt args is safe to reuse. + return (_ident, (str(s),)) + + def _reduce_unsupported(self, s: Any) -> NoReturn: + """ + Custom reducer to handle any objects that we don't support and therefore + raise to bypass caching. + """ + raise BypassFxGraphCache("Reduce unsupported") + + def _reduce_graph_module( + self, gm: torch.fx.GraphModule + ) -> tuple[Any, tuple[dict[str, Any], str]]: + """ + Custom reducer for graph module to handle irrelevant data for user + defined triton kernels + Essentially what we are doing here is a huge hack where user defined + triton kernel contain a dynamo time side table and the arguments to the + call_function are indices into this side table. These arguments are not + for hashing purposes since we included the source code into the cache + key and the numbers are prone to give false negatives due to ordering. + """ + fn, (data, imports) = gm.__reduce__() + code = data["_code"] + code = re.sub(r"kernel_idx = \d+", "", code) + code = re.sub(r"constant_args_idx = \d+", "", code) + data["_code"] = code + return fn, (data, imports) + + def dumps(self, obj: Any) -> bytes: + """ + Pickle an object and return a byte string. 
+ """ + try: + self.dump(obj) + return self._stream.getvalue() + except (TypeError, AttributeError, pickle.PicklingError) as e: + # Some configs options may not pickle. + log.warning("Failed to pickle cache key", exc_info=True) + raise BypassFxGraphCache("Failed to pickle cache key") from e + finally: + # Reset our stream for the next dump. + self._stream.seek(0) + self._stream.truncate(0) + + def get_hash(self, obj: Any) -> str: + """ + Serialize an object and return a hash of the bytes. + """ + serialized_data = self.dumps(obj) + return sha256_hash(serialized_data) + + def debug_lines(self, inp: FxGraphHashDetails) -> list[str]: + """ + Get a printable string describing in more detail all the attributes + comprising an object. Useful for debugging when one graph hashes + to a different value than another. + """ + + def get_str(obj: Any) -> str: + if isinstance(obj, torch.Tensor): + return str(extract_tensor_metadata_for_cache_key(obj)) + elif isinstance(obj, bytes): + val = obj.decode("utf-8", errors="replace") + return val if len(val) <= 1024 else val[:1024] + "..." 
+ elif type(obj) in self.dispatch_table: + # Run the reducer on the object + return str(self.dispatch_table[type(obj)](obj)[1]) + else: + return str(obj) + + lines = [] + for attr, obj in vars(inp).items(): + if isinstance(obj, list): + for ii in range(len(obj)): + h = self.get_hash(obj[ii]) + lines.append(f"[{h}] {attr}[{ii}]: {get_str(obj[ii])}") + elif isinstance(obj, dict): + for k, v in obj.items(): + h = self.get_hash(v) + lines.append(f"[{h}] {attr}[{k}]: {get_str(v)}") + else: + h = self.get_hash(obj) + lines.append(f"[{h}] {attr}: {get_str(obj)}") + return lines + + +def build_code_hash( + roots: list[str] | None, prefix: str, hasher: hashlib._Hash +) -> None: + for lib in sorted(pkgutil.iter_modules(roots, prefix), key=lambda x: x.name): + spec = lib.module_finder.find_spec(lib.name, None) + assert spec is not None + module = spec.origin + assert module is not None + with open(module, "rb") as f: + hasher.update(spec.name.encode("utf-8")) + hasher.update(f.read()) + if lib.ispkg: + # need to also hash submodules + build_code_hash(spec.submodule_search_locations, f"{spec.name}.", hasher) + + +def torch_key_cache(func: Callable[[], bytes]) -> Callable[[], bytes]: + """ + This function is a reimplementation of functools.lru_cache with a + set function that allows prepopulating the cache. 
+ """ + # Use list for reference semantics + _cache: list[bytes] = [] + + def wrapper() -> bytes: + if len(_cache) == 0: + _cache.append(func()) + return _cache[0] + + def set_val(val: bytes) -> None: + assert len(_cache) == 0 + _cache.append(val) + + def clear() -> None: + _cache.clear() + + wrapper.set = set_val # type: ignore[attr-defined] + wrapper.clear = clear # type: ignore[attr-defined] + return wrapper + + +@torch_key_cache +def torch_key() -> bytes: + """ + Compute a key that contains relevant information about torch source files + """ + with dynamo_timed("inductor_codecache_torch_key", log_pt2_compile_event=False): + if not config.is_fbcode(): + + def get_code_hash(root: str) -> bytes: + # This function isn't meant to be used outside of torch_key, just a + # helper for clarity. Instead, use torch_key() directly when you need + # a hash representing the state of the source code. + extra_files = ( + "codegen/aoti_runtime/interface.cpp", + "script.ld", + ) + inductor_root = os.path.dirname(__file__) + extra_files = [os.path.join(inductor_root, x) for x in extra_files] + hasher = hashlib.sha256() + hasher.update(torch.__version__.encode("utf-8")) + build_code_hash([root], "", hasher) + for path in extra_files: + if os.path.exists(path): + with open(path, "rb") as f: + hasher.update(f.read()) + return hasher.digest() + + return get_code_hash(_TORCH_PATH) + + from libfb.py import parutil + + return parutil.get_file_contents("torch/src_hash.txt").rstrip().encode("ascii") + + +def get_inductor_root() -> str: + return os.path.dirname(__file__) + + +@dataclasses.dataclass +class OrderedSetHolder: + """ + See FxGraphHashDetails. Holds a sorted list to support stable hashing + of set kwargs. + """ + + items: list[Any] + + +class BypassFxGraphCache(Exception): + """ + Exception to indicate that the FxGraphCache should be bypassed. 
+ """ + + +class FxGraphHashDetails: + """ + Object to capture all the details for a compiled FX graph relevant to computing + a safe and stable cache key. + """ + + # Excluded kwargs param that are not stable between runs + EXCLUDED_KWARGS = ["graph_id"] + + def __init__( + self, + gm: torch.fx.GraphModule, + example_inputs: Sequence[InputType], + fx_kwargs: _CompileFxKwargs, + inputs_to_check: Sequence[int], + ) -> None: + self.gm = gm + self.example_inputs = example_inputs + self.cache_key_tag = cconfig.cache_key_tag + + # Order kwargs so hashing is stable to changes in kwarg order. Although + # it's technically a _CompileFxKwargs we don't actually need it typed as + # such since we're just using it to generate a hash. + self.fx_kwargs: dict[str, object] = {} + for k, v in sorted(fx_kwargs.items()): + if k not in self.EXCLUDED_KWARGS: + if type(v) in (set, OrderedSet): # noqa: set_linter + # Special case to handle set params. Python sets can't be + # ordered, so sort the elements and store them in a proxy. 
+ self.fx_kwargs[k] = OrderedSetHolder(sorted(v)) # type: ignore[call-overload] + else: + self.fx_kwargs[k] = v + + from torch._higher_order_ops.triton_kernel_wrap import ( + kernel_side_table, + triton_kernel_wrapper_functional, + triton_kernel_wrapper_mutation, + ) + from torch._inductor.codegen.wrapper import ( + user_defined_triton_kernel_transitive_closure_source_code, + ) + + # Node meta will not be part of gm's reduce function, so lets remember + # the kernel source code separately + self.user_defined_triton_source: list[Any] = [] + if gm is not None: + for module in gm.modules(): + if not isinstance(module, torch.fx.GraphModule): + continue + for node in itertools.chain( + module.graph.find_nodes( + op="call_function", target=triton_kernel_wrapper_functional + ), + module.graph.find_nodes( + op="call_function", target=triton_kernel_wrapper_mutation + ), + ): + from triton.runtime.autotuner import Autotuner + + kernel = kernel_side_table.get_kernel(node.kwargs["kernel_idx"]) + configs = None + if isinstance(kernel, Autotuner): + if kernel.configs: + configs = str( + sorted( + sorted(str(kv) for kv in c.all_kwargs().items()) + for c in kernel.configs + ) + ) + kernel = kernel.fn + + kernel_source = ( + user_defined_triton_kernel_transitive_closure_source_code( + kernel + ) + ) + constant_args = kernel_side_table.get_constant_args( + node.kwargs["constant_args_idx"] + ) + self.user_defined_triton_source.append( + (kernel_source, constant_args, configs) + ) + + # Alignment checks + self.inputs_to_check = inputs_to_check + + no_tensor_inputs = not any(isinstance(x, torch.Tensor) for x in example_inputs) + # This device index is usually already encoded by the device of the inputs + # but fx graphs don't necessarily have tensor inputs. 
If there aren't any, + # we need to guard on the device index in case we allocate cuda tensors + if no_tensor_inputs and torch.accelerator.is_available(): + self.default_cuda_device_index = torch.accelerator.current_device_index() + + # 'Deterministic algorithms' can affect codegen via lowering to cuda kernels. + self.deterministic_algorithms_settings = ( + torch.are_deterministic_algorithms_enabled(), + torch.is_deterministic_algorithms_warn_only_enabled(), + torch.utils.deterministic.fill_uninitialized_memory, # type: ignore[attr-defined] + ) + + # Global settings affecting matmul codegen. + self.cuda_matmul_settings = ( + torch.backends.cuda.matmul.fp32_precision, + torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction, + torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction, + ) + + # Also hash on various system info (including the triton compiler version). + self.torch_version = torch_key() + self.system_info = CacheBase.get_system() + self.inductor_config = config.save_config_portable(ignore_private_configs=False) + # Custom post grad passes should provide an ID to hash. + self.post_grad_custom_pre_pass = self._get_custom_pass_detail( + config.post_grad_custom_pre_pass + ) + # TODO: change to more holistic config rather than bundled_autograd_cache + self.precompile_enabled = torch._functorch.config.bundled_autograd_cache + self.post_grad_custom_post_pass = self._get_custom_pass_detail( + config.post_grad_custom_post_pass + ) + self.joint_custom_pre_pass = self._get_custom_pass_detail( + config.joint_custom_pre_pass + ) + self.joint_custom_post_pass = self._get_custom_pass_detail( + config.joint_custom_post_pass + ) + self._pre_fusion_custom_pass = self._get_custom_pass_detail_unsafe( + config._pre_fusion_custom_pass + ) + self._fuse_ddp_communication_passes = self._get_custom_pass_detail_unsafe( + config._fuse_ddp_communication_passes + ) + + # Register indcutor backends and custom passes and get their UUIDs. 
+ init_backend_registration() + self.custom_backend_passes = tuple( + map(self._get_custom_pass_detail, custom_backend_passes.values()) + ) + + # Save custom inductor codegen configs + self.custom_backend_codegen_configs = { + device: custom_config.save_config_portable(ignore_private_configs=False) + for device, custom_config in custom_backend_codegen_configs.items() + if custom_config is not None + } + + # Register the custom partitioner function + self._custom_partitioner_fn = self._get_custom_partitioner_fn_detail( + config.custom_partitioner_fn + ) + + # This is mainly added to handle these two inductor configs, which are (unfortunately) + # sometimes cache safe: + # - _pre_fusion_custom_pass + # - _fuse_ddp_communication_passes + # Their types can be found in `torch/_inductor/config.py`, but: + # - if they are string names, we can cache them safely (one is by default) + # - if any of them are set to custom callables, we will need to cache miss + # Future work is for someone to find any places where these functions are used + # and force them to be of type CustomGraphPass, so we can guarantee serialization. 
+ def _get_custom_pass_detail_unsafe(self, custom_pass: Any) -> Any | None: + if not custom_pass: + return None + if isinstance(custom_pass, list): + return [self._get_custom_pass_detail_unsafe(x) for x in custom_pass] + if isinstance(custom_pass, str): + return custom_pass + if isinstance(custom_pass, CustomGraphPass): + return custom_pass.uuid() + if callable(custom_pass): + # Returning None is safe here because we raise an explicit bypass error + # later if we detect these passes are set to callables + return None + raise AssertionError(f"unknown config type: {str(type(custom_pass))}") + + def _get_custom_pass_detail( + self, custom_pass: CustomGraphPassType | CustomGraphModulePass + ) -> Any | None: + if not custom_pass: + return None + assert isinstance(custom_pass, (CustomGraphPass, CustomGraphModulePass)) + return custom_pass.uuid() + + def _get_custom_partitioner_fn_detail( + self, custom_partitioner_fn: CustomPartitionerFnType + ) -> Any | None: + if not custom_partitioner_fn: + return None + assert isinstance(custom_partitioner_fn, CustomPartitionerFn) + return custom_partitioner_fn.uuid() + + +def compiled_fx_graph_hash( + gm: torch.fx.GraphModule, + example_inputs: Sequence[InputType], + fx_kwargs: _CompileFxKwargs, + inputs_to_check: Sequence[int], +) -> tuple[str, list[str]]: + """ + Generate a unique hash of the FX graph for caching. + """ + details = FxGraphHashDetails(gm, example_inputs, fx_kwargs, inputs_to_check) + has_user_defined_triton_kernels = len(details.user_defined_triton_source) != 0 + pickler = FxGraphCachePickler(gm, has_user_defined_triton_kernels) + + # The prefix distinguishes among the other kinds of objects we + # cache in this module. 
+ key = "f" + pickler.get_hash(details) + debug_lines = pickler.debug_lines(details) + debug_str = "\n".join(debug_lines) + log.debug(f"FX graph cache hash details for key {key}:\n{debug_str}") # noqa: G004 + return key, debug_lines + + +def add_ephemeral_timeout_increase_for_distributed(time_saved_ns: int) -> int: + """ + Ephemerally increases the NCCL timeout when compiling for a distributed job + Returns amount of seconds increased + """ + if not torch.distributed.is_available() or not torch.distributed.is_initialized(): + return 0 + + increased_timeout_sec = int(time_saved_ns // 1e9) # convert to seconds + + if config.is_fbcode(): + fudge_factor = torch._utils_internal.justknobs_getval_int( + "pytorch/remote_cache:ephemeral_timeout_fudge_factor_percentage" + ) + log.info( + "Ephemeral NCCL timeout increase fudge factor %d and original increase value %d", + fudge_factor, + increased_timeout_sec, + ) + increased_timeout_sec += int(increased_timeout_sec * fudge_factor / 100) + + log.info("Increasing NCCL timeout by %d", increased_timeout_sec) + dist.distributed_c10d._add_ephemeral_timeout_for_all_pgs( + timedelta(seconds=increased_timeout_sec) + ) + return increased_timeout_sec + + +class GuardedCache(Generic[T]): + """ + Mixin for caches that have guards associated with their entries. 
+ """ + + @classmethod + def _get_tmp_dir_for_key(cls: type[GuardedCache[T]], _key: str) -> str: + raise NotImplementedError("Implement _get_tmp_dir_for_key on parent class") + + @classmethod + def iterate_over_candidates( + cls: type[GuardedCache[T]], + local: bool, + remote_cache: RemoteCache[JsonDataTy] | None, + key: str, + ) -> Generator[tuple[T, bytes], None, None]: + if local: + subdir = cls._get_tmp_dir_for_key(key) + if os.path.exists(subdir): + for path in sorted(os.listdir(subdir)): + try: + with open(os.path.join(subdir, path), "rb") as f: + content = f.read() + yield pickle.loads(content), content + except Exception: + log.warning( + "fx graph cache unable to load compiled graph", + exc_info=True, + ) + + if remote_cache: + try: + if (cache_data := remote_cache.get(key)) is not None: + assert isinstance(cache_data, dict) + data = cache_data["data"] + assert isinstance(data, (str, bytes)) + content = base64.b64decode(data) + yield pickle.loads(content), content + except Exception: + log.warning( + "%s unable to load compiled graph", cls.__name__, exc_info=True + ) + + @classmethod + def find_guarded_entry( + cls: type[GuardedCache[T]], + key: str, + local: bool, + remote_cache: RemoteCache[JsonDataTy] | None, + evaluate_guards: Callable[[str, list[int] | list[torch.SymInt]], bool], + hints: list[int], + ) -> tuple[T | None, bytes | None, dict[str, str]]: + """ + Find the first cache entry in iterate_over_candidates that passes `evaluate_guards`. + + Args: + key: The cache key to look up + local: Whether to check the local cache + remote_cache: The remote cache to check, if any + evaluate_guards: Function that evaluates whether a guard passes the check, + given a list of hint values and the guard expression. 
+ hints: List of symint hints paired with evaluate_guards + + Returns: + A tuple of (graph, pickled_content) if found, or (None, None) if not found + """ + graph = None + pickled_content = None + result_status = "full_miss" + sample_guards_expr = None + + # Iterate over any entries in the subdir for this key and evaluate + # guards to determine whether there's a hit. + + for candidate, content in cls.iterate_over_candidates(local, remote_cache, key): + assert hasattr(candidate, "guards_expr") + if not candidate.guards_expr: # type: ignore[attr-defined] + # No guards to evaluate, so this is a hit. + graph = candidate + pickled_content = content + result_status = "hit" + break + + # Evaluate the guard expression in the current context. + # If there's not a cache hit, we don't want the evaluation to + # affect the current env, e.g., cause the creation of new guards, + # so we evaluate with the hints instead of the symbols. + hit = bool(evaluate_guards(candidate.guards_expr, hints)) # type: ignore[attr-defined] + if hit: + graph = candidate + pickled_content = content + result_status = "hit" + sample_guards_expr = candidate.guards_expr + break + else: + # At least one guard missed, log this + result_status = "guard_miss" + sample_guards_expr = candidate.guards_expr + + info = {"cache_status_detailed": result_status} + if sample_guards_expr is not None: + info["cache_status_guard_expr"] = sample_guards_expr + return graph, pickled_content, info + + @classmethod + def _filter_backed_symints( + cls: type[GuardedCache[T]], inputs: Sequence[InputType] + ) -> list[torch.SymInt]: + """ + Get the backed SymInt objects from the input list. Note that we can never + have guards that depend on unbacked symint. + """ + return [s for s in inputs if isinstance(s, torch.SymInt) and has_hint(s)] + + @classmethod + def _get_shape_env(cls: type[GuardedCache[T]]) -> ShapeEnv | None: + """ + Helper to get the shape env from the tracing context. 
+ """ + ctx = torch._guards.TracingContext.try_get() + if not ctx or not ctx.fake_mode: + return None + return ctx.fake_mode.shape_env + + +@CacheArtifactFactory.register +class InductorCacheArtifact(CacheArtifact): + @override + def populate_cache(self) -> None: + FxGraphCache._write_to_local_cache(self.key, self.content) + + @override + @staticmethod + def type() -> str: + return "inductor" + + +class FxGraphCache(GuardedCache[CompiledFxGraph]): + """ + Supports caching and reusing compiled Fx graphs. + + The overall strategy is as follows: + - This cache stores entries on disk. When saving an entry, we can't + serialize callables (that could be C++, Triton, etc.), so we serialize + their own disk cache location. We then recreate the compiled artifact + after fetching from disk. + - For indexing the cache, we gather the fields relevant to identifying an + FxGraph (the graph module, graph inputs, system settings etc.) into an + FxGraphCacheDetails object, pickle it, and compute a hash for the key. + See FxGraphCachePickler. + - Among the metadata we store, we also include a guards expression that's + appropriate for validating any symbols for Tensor arguments that have + symbolic bounds. On cache lookup then, we evaluate those guards in the + current context to validate that a cached entry can be served. + - A given graph could have multiple compiled versions, corresponding to + different sets of guards. Therefore, we store cache entries in the form: + // + - On lookup, we compute the key from the graph details, iterate over all + leaf files in the corresponding subdirectory, deserialize the entry, and + evaluate its guards expression. If the evaluation succeeds, we have a + cache hit. If it fails, we compile the graph and store a new entry. + - Finally, on a cache hit, we need to make sure any guards that would + have been created during compilation are added to the current context. 
+ """ + + # TODO(masnesral): Investigate whether it's beneficial to store compiled graphs + # in an in-memory cache after loading from disk. + @staticmethod + def _get_tmp_dir() -> str: + """ + Get the toplevel temporary directory for storing compiled graphs. + """ + return os.path.join(cache_dir(), "fxgraph") + + @classmethod + def _get_tmp_dir_for_key(cls: type[FxGraphCache], key: str) -> str: + """ + Return the disk location for a given cache key. + """ + return os.path.join(FxGraphCache._get_tmp_dir(), key[1:3], key) + + @staticmethod + def cache_hit_post_compile( + graph: CompiledFxGraph, + cache_info: dict[str, Any], + constants: CompiledFxGraphConstants, + ) -> tuple[CompiledFxGraph | None, dict[str, Any]]: + """ + Cache specific post compile steps that need to run if we find a graph in the cache + This includes putting bundled triton artifacts in the right place, + reloading the PyCodeCache artifact, etc. + + These don't always happen (i.e. on a cache miss, so they are in a separate function from + CompiledFxGraph.post_compile) + """ + if bundle := graph._triton_bundle: + triton_bundler_meta = TritonBundler.read_and_emit(bundle) + if (meta := triton_bundler_meta) is not None: + cache_info["triton_bundler_meta"] = str(meta) + CompileEventLogger.try_add_pt2_compile( + "inductor_compile", cached_kernel_names=meta.cached_kernel_names + ) + CompileEventLogger.try_add_pt2_compile( + "AOTAutogradCache.inductor_load", + cached_kernel_names=meta.cached_kernel_names, + ) + if len(meta.cached_kernel_names) > 0: + CompileEventLogger.try_( + CompileEventLogger.increment_toplevel, "num_triton_bundles" + ) + + try: + artifact_path = graph.after_deserialization(constants) + + from .graph import GraphLowering + + # This is used by tests to check the output for specific details. 
+ if GraphLowering.save_output_code is not None: + GraphLowering.save_output_code(graph.source_code) + + except OSError: + # Not expected, but in case the PyCodeCache entry is removed from + # underneath us, treat it as a cache miss and recompile. + return None, cache_info + + inductor_meta = autotune_cache.inductor_meta_from_config() + code = graph.source_code + AutotuneCacheBundler.begin_compile(inductor_meta, code=code) + + # Increment the cached metrics/counters by the amounts recorded when the FX + # graph was compiled for this cache entry. Pretending these counters + # were incremented normally is useful for testing with the cache enabled. + metrics.CachedMetricsHelper.apply_deltas(graph.metrics_deltas) + counters["inductor"] += graph.counter_deltas + + output_code_log.debug("Output code: \n%s", code) + output_code_log.debug("Output code written to: %s", artifact_path) + # On cache hit, use artifact path as filename + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "fx_graph_runnable", + "encoding": "string", + }, + payload_fn=lambda: graph.runnable_graph_str, + ) + trace_structured( + "inductor_post_grad_graph", + payload_fn=lambda: graph.inductor_post_grad_graph_str, + ) + trace_structured( + "inductor_output_code", + lambda: { + "filename": artifact_path, + "file_path": os.path.abspath(artifact_path), + }, + payload_fn=lambda: code, + ) + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "inductor_provenance_tracking_node_mappings", + "encoding": "json", + }, + payload_fn=lambda: graph.inductor_provenance_mapping_str, + ) + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "inductor_provenance_tracking_kernel_stack_traces", + "encoding": "json", + }, + payload_fn=lambda: graph.inductor_provenance_stack_traces_str, + ) + if ( + get_metrics_context().in_progress() + and graph.inductor_provenance_stack_traces_str + ): + get_metrics_context().add_to_set( + "inductor_provenance", 
graph.inductor_provenance_stack_traces_str + ) + return graph, cache_info + + @staticmethod + def _lookup_graph( + key: str, + example_inputs: Sequence[InputType], + local: bool, + remote_cache: RemoteCache[JsonDataTy] | None, + constants: CompiledFxGraphConstants, + evaluate_guards: Callable[[str, list[int] | list[torch.SymInt]], bool] + | None = None, + ) -> tuple[CompiledFxGraph | None, dict[str, Any]]: + """ + Lookup a compiled graph in the cache by key. On a hit, return the + deserialized CompiledFxGraph object. On a miss, return None. + `constants` tracks a list of constants, or a way to obtain the list of constants + associated with a given cache entry + `evaluate_guards` allows AOTAutogradCache and other callers to customize + what constitutes a guard success. Normally, a guard hit happens if + `shape_env.evaluate_guards_expression` returns True. + """ + shape_env = FxGraphCache._get_shape_env() + assert shape_env is not None + + symints = FxGraphCache._filter_backed_symints(example_inputs) + hints = [hint_int(s) for s in symints] + + # If this config is turned on, everything is a guard hit and we check nothing + if config.unsafe_skip_cache_dynamic_shape_guards: + # This also makes it so we don't add anything to the dynamic + # shape environment + evaluate_guards = lambda x, y: True # noqa: E731 + + if evaluate_guards is None: + evaluate_guards = shape_env.evaluate_guards_expression + + cache_info: dict[str, Any] = dict() + + # Use the find_graph_for_key method to find a graph for the given key + graph, pickled_content, guard_info = FxGraphCache.find_guarded_entry( + key, local, remote_cache, evaluate_guards, hints + ) + cache_info.update(guard_info) + if graph is None: + return None, cache_info + + if pickled_content is not None: + CacheArtifactManager.record_artifact( + InductorCacheArtifact.type(), key, pickled_content + ) + + # Now re-evaluate with the symints to add any guards to the current env. 
+ if graph.guards_expr: + check = bool(evaluate_guards(graph.guards_expr, symints)) + assert check is True + log.debug( + "fx graph cache key %s post-load guards: %s", key, shape_env.guards + ) + + return FxGraphCache.cache_hit_post_compile(graph, cache_info, constants) + + @staticmethod + def _write_to_local_cache(key: str, content: bytes) -> None: + subdir = FxGraphCache._get_tmp_dir_for_key(key) + if not os.path.exists(subdir): + os.makedirs(subdir, exist_ok=True) + + # Use a hash of the serialized CompiledFxGraph to get a unique file + # name. The specific name doesn't matter since a lookup involves + # iterating over all entries in the parent subdir. + path = os.path.join(subdir, sha256_hash(content)) + write_atomic(path, content, make_dirs=True) + + @staticmethod + def _save_graph( + key: str, + compiled_graph: OutputCode, + example_inputs: Sequence[InputType], + local: bool, + remote_cache: RemoteCache[JsonDataTy] | None, + ) -> None: + """ + Store a serialized CompiledFxGraph on disk. + """ + from .compile_fx import CompiledFxGraph + + assert isinstance(compiled_graph, CompiledFxGraph), ( + f"serialization for {type(compiled_graph)} NYI" + ) + + # Before serializing, compute the guard expression that will be used to + # ensure that a CompiledFxGraph is valid when loaded from the cache. It's + # sufficient to consider only the SymInt args to the fx graph since the + # Tensor shapes are already captured in the hash for the cache key. Any + # Tensor arg with a symbolic shape will have a SymInt arg for the graph. 
+ shape_env = FxGraphCache._get_shape_env() + assert shape_env is not None + symints = FxGraphCache._filter_backed_symints(example_inputs) + guards = shape_env.get_pruned_guards(symints) + compiled_graph.guards_expr = shape_env.produce_guards_expression( + placeholders=symints, guards=guards + ) + disk_compiled_graph = copy(compiled_graph) + disk_compiled_graph.prepare_for_serialization() + + try: + content = pickle.dumps(disk_compiled_graph) + except Exception: + log.warning( + "fx graph cache unable to serialize compiled graph", exc_info=True + ) + counters["inductor"]["fxgraph_cache_pickle_error"] += 1 + return + + try: + CacheArtifactManager.record_artifact( + InductorCacheArtifact.type(), key, content + ) + if local: + FxGraphCache._write_to_local_cache(key, content) + + if remote_cache: + time_taken_ms = int((disk_compiled_graph._time_taken_ns or 0) // 1e6) + cache_data: JsonDataTy = { + "data": base64.b64encode(content).decode("ascii"), + "time_taken_ms": time_taken_ms, + } + remote_cache.put(key, cache_data) + except Exception: + log.warning("fx graph unable to write to cache", exc_info=True) + counters["inductor"]["fxgraph_cache_write_error"] += 1 + + @staticmethod + def _check_for_hop(gm: torch.fx.GraphModule) -> None: + for module in gm.modules(): + if not isinstance(module, torch.fx.GraphModule): + continue + for node in module.graph.nodes: + if ( + isinstance(node.target, torch._ops.HigherOrderOperator) + and not node.target.cacheable() + ): + raise BypassFxGraphCache( + f"Can't cache HigherOrderOperator: {node.target.name()}" + ) + if node.op == "getattr" and isinstance( + getattr(gm, node.target), torch._C.ScriptObject + ): + raise BypassFxGraphCache("Can't cache torchbind objects") + + @staticmethod + def _check_can_cache(gm: torch.fx.GraphModule) -> None: + """ + Check some conditions that would preclude caching and raise BypassFxGraphCache + to bypass in case caching is not possible. 
+ """ + # Post grad custom passes must implement the CustomGraphPass or we don't + # know how to include them in the cache key calculation. + for p in (config.post_grad_custom_pre_pass, config.post_grad_custom_post_pass): + if p and (not isinstance(p, CustomGraphPass) or not p.uuid()): + raise BypassFxGraphCache("Unsupported post grad custom pass") + # Same with the joint custom passes + for p in (config.joint_custom_pre_pass, config.joint_custom_post_pass): + if p and (not isinstance(p, CustomGraphPass) or not p.uuid()): + raise BypassFxGraphCache("Unsupported joint custom pass") + # We should find any users of _pre_fusion_custom_pass and _fuse_ddp_communication_passes + # and ensure they are not passing us raw callables + if config._pre_fusion_custom_pass is not None: + if not isinstance(config._pre_fusion_custom_pass, CustomGraphPass): + raise BypassFxGraphCache("Unsupported _pre_fusion_custom_pass") + for p in config._fuse_ddp_communication_passes: + if callable(p) and not isinstance(p, CustomGraphPass): + raise BypassFxGraphCache("Unsupported _fuse_ddp_communication_pass") + + # Freezing can embed constants that wouldn't be static across runs. + if has_frozen_params(gm) and not torch._utils_internal.justknobs_check( + "pytorch/inductor:allow_freezing_with_caching" + ): + raise BypassFxGraphCache("Skipping graph with frozen constants") + + if config.aot_inductor.use_runtime_constant_folding: + raise BypassFxGraphCache( + "Runtime constant folding can introduce constants that aren't " + "static across runs" + ) + + from torch._inductor.compiler_bisector import CompilerBisector + + if CompilerBisector.bisection_enabled: + log.debug("dont cache graph when bisect enabled") + raise BypassFxGraphCache + + # The treatment of guards in the caching implementation requires that + # we have a shape env. 
+ if FxGraphCache._get_shape_env() is None: + log.debug("fx graph cache no shape env") + raise BypassFxGraphCache("No shape env") + + # We skip caching if there are any HOPs or torchbind objects. + FxGraphCache._check_for_hop(gm) + + @staticmethod + def prepare_key( + gm: torch.fx.GraphModule, + example_inputs: Sequence[InputType], + fx_kwargs: _CompileFxKwargs, + inputs_to_check: Sequence[int], + remote: bool, + ) -> tuple[tuple[str, list[str]] | None, dict[str, Any]]: + """ + Checks that the inductor input is cacheable, then computes + and returns the cache key for the input. + Returns (key_info, cache_info) where: + - key_info is (hash_key, debug_lines), and + - cache_info will contain debug info in the event of BypassFxGraphCache. + + NB: It is possible to have this function return a union instead. But + I personally believe it is more annoying/difficult to read in that format. + """ + try: + FxGraphCache._check_can_cache(gm) + key, debug_lines = compiled_fx_graph_hash( + gm, example_inputs, fx_kwargs, inputs_to_check + ) + except BypassFxGraphCache as e: + counters["inductor"]["fxgraph_cache_bypass"] += 1 + log.info("Bypassing FX Graph Cache because '%s'", e) # noqa: G200 + if remote: + log_cache_bypass("bypass_fx_graph", str(e)) + cache_info = { + "cache_state": "bypass", + "cache_bypass_reason": str(e), + "cache_event_time": time_ns(), + } + return None, cache_info + # If key exists, then cache_info will come from load_with_key + return (key, debug_lines), {} + + @staticmethod + def get_remote_cache() -> RemoteCache[JsonDataTy] | None: + """ + Attempts to load the remote cache, returns None on error. 
+ """ + cache_id = "fx-graph-v1" + return create_cache( + cache_id, + config.is_fbcode(), + "FbRemoteFxGraphCache", + "RemoteFxGraphCache", + ) + + @staticmethod + def load_with_key( + key: str, + debug_lines: list[str], + example_inputs: Sequence[InputType], + local: bool, + remote_cache: RemoteCache[JsonDataTy] | None, + is_backward: bool, + constants: CompiledFxGraphConstants, + evaluate_guards: Callable[[str, list[int] | list[torch.SymInt]], bool] + | None = None, + ) -> tuple[CompiledFxGraph | None, dict[str, Any]]: + """ + Lookup the graph with the given key, and return results and metadata. + Doesn't do any logging on its own, because AOTAutograd handles a cache miss + differently from FXGraphCache. + """ + compiled_graph, cache_info = FxGraphCache._lookup_graph( + key, example_inputs, local, remote_cache, constants, evaluate_guards + ) + cache_info = { + **cache_info, + "key": key, + "components": debug_lines, + "cache_event_time": time_ns(), + } + if compiled_graph is not None: + log.info("fx graph cache hit for key %s", key) + counters["inductor"]["fxgraph_cache_hit"] += 1 + cache_info["cache_state"] = "hit" + if remote_cache: + # Count remote cache hit stats + CompileEventLogger.try_( + CompileEventLogger.increment_toplevel, + "inductor_fx_remote_cache_hit_count", + ) + CompileEventLogger.try_( + CompileEventLogger.add_to_set_toplevel, + "inductor_fx_remote_cache_hit_keys", + key, + ) + + if (time_saved_ns := compiled_graph._time_taken_ns) is not None: + cache_info["time_saved_ns"] = time_saved_ns + CompileEventLogger.try_( + CompileEventLogger.increment_toplevel, + "distributed_ephemeral_timeout_us", + time_saved_ns // 1000, + ) + if ( + ephemeral_increase + := add_ephemeral_timeout_increase_for_distributed(time_saved_ns) + ) != 0: + cache_info["ephemeral_timeout_increase"] = ephemeral_increase + else: + if remote_cache: + # Count remote cache miss stats + CompileEventLogger.try_( + CompileEventLogger.increment_toplevel, + 
"inductor_fx_remote_cache_miss_count", + ) + CompileEventLogger.try_( + CompileEventLogger.add_to_set_toplevel, + "inductor_fx_remote_cache_miss_keys", + key, + ) + log.info("fx graph cache miss for key %s", key) + counters["inductor"]["fxgraph_cache_miss"] += 1 + cache_info["cache_state"] = "miss" + + return compiled_graph, cache_info + + @staticmethod + def clear() -> None: + """ + Clear out the on-disk cache. + """ + try: + shutil.rmtree(FxGraphCache._get_tmp_dir()) + except FileNotFoundError: + pass + + +@functools.cache +def split_aot_inductor_output_path(path: str) -> tuple[str, str]: + def get_module_ext_type() -> str: + if _IS_WINDOWS: + return ".pyd" + else: + return ".so" + + """Returns the path where the AOT Inductor compiled kernels are stored.""" + if path.endswith(get_module_ext_type()): + return os.path.split(path) + elif path.endswith(".pt2"): + return os.path.split(path) + else: + return path, "" + + +@clear_on_fresh_cache +class CudaKernelParamCache: + cache: dict[str, dict[str, Any]] = {} + cache_clear = staticmethod(cache.clear) + + @classmethod + def set( + cls, + key: str, + params: dict[str, str | None], + cubin: str, + bin_type: str, + asm: str | None = None, + asm_type: str | None = None, + ) -> None: + basename = None + if config.aot_inductor.package_cpp_only: + assert config.triton.unique_kernel_names, ( + "package_cpp_only requires triton kernel names to be unique" + ) + assert params["mangled_name"], "Missing kernel name" + basename = params["mangled_name"] + + _, bin_path = write( + cubin, + bin_type, + hash_type=bin_type, + specified_dir=split_aot_inductor_output_path( + config.aot_inductor.output_path + )[0], + key=basename, + ) + # Retrieve the basename again in case it is a generated hashcode + basename, _ = get_name_and_dir_from_output_file_path(bin_path) + + if config.aot_inductor.emit_multi_arch_kernel: + bin_type_to_ext = {"cubin": ".fatbin", "spv": ".spv", "hsaco": ".hsaco"} + assert bin_type in bin_type_to_ext, ( + 
"multi_arch_kernel_binary only supported in CUDA/XPU/ROCm" + ) + base_path, _ = os.path.splitext(bin_path) + bin_path = base_path + bin_type_to_ext[bin_type] + + asm_path: str = "" + + # Kernel assembly/IR requirements for AOT Inductor: + # - CUDA/XPU: Always require PTX/SPV + # - ROCm multi-arch: Require LLVM IR (.ll) for bundle compilation + if ( + config.aot_inductor.emit_multi_arch_kernel + or config.aot_inductor.package_cpp_only + ): + # Allow ROCm single-arch to skip (asm=None OK), require for everything else + if torch.version.hip is None or (asm and asm_type): + assert asm, "Missing kernel assembly code" + assert asm_type, "Missing kernel assembly type" + + # Cache directory mapping: asm_type → hash_type + # Problem: LLVM IR extension ".ll" isn't a recognized cache category + # Solution: Map to "code" (generic category for non-standard formats) + # Recognized categories: "ptx", "amdgcn", "spv", "code" + hash_kind = asm_type if asm_type in {"amdgcn", "ptx", "spv"} else "code" + + _, asm_path = write( + asm, + asm_type, + hash_type=hash_kind, + specified_dir=split_aot_inductor_output_path( + config.aot_inductor.output_path + )[0], + key=basename, + ) + + params[get_cpp_wrapper_cubin_path_name()] = bin_path + params["asm"] = asm_path + cls.cache[key] = params + + @classmethod + def get(cls, key: str) -> dict[str, Any] | None: + return cls.cache.get(key, None) + + @classmethod + def get_keys(cls) -> KeysView[str]: + return cls.cache.keys() + + +class AotCodeCompiler: + """ + Compile AOT Inductor generated code. + """ + + @classmethod + def compile( + cls, + graph: GraphLowering, + wrapper_code: str, + kernel_code: str, + serialized_extern_kernel_nodes: str | None, + *, + device_type: str, + additional_files: list[str], + ) -> list[Union[str, Weights]] | str: + """ + Returns the .so path, or returns a list of files that were generated if + config.aot_inductor.package=True. 
+ """ + generated_files: list[str | Weights] = additional_files # type: ignore[assignment] + + _set_gpu_runtime_env() # cpp_extension consults the env + + picked_vec_isa = pick_vec_isa() + vec_isa_cmd_gen = CppBuilder( + name="o", + sources="i", + BuildOption=CppTorchDeviceOptions( + vec_isa=picked_vec_isa, + device_type=device_type, + aot_mode=graph.aot_mode, + ), + ) + # write function will calc source_code hash, the same source code with different + # ISA level should be generate different hash. + # So we need get a command_line which contains isa related parameter as a part of hash key. + # And then pass the command_line to below write function as extra parameter to + # guarantee the source code hash contains ISA difference. + cpp_command = repr(vec_isa_cmd_gen.get_command_line()) + + # Meta internal AOTInductor CPU + use_relative_path = ( + config.is_fbcode() and device_type == "cpu" and graph.aot_mode + ) + + ( + specified_output_path, + specified_artifact_name, + ) = split_aot_inductor_output_path(config.aot_inductor.output_path) + + # TODO (benjaminglass1): the CMake packaging path doesn't support linking files + # built with different flags. Until that's implemented, append the kernel code + # to the wrapper and build everything at max optimization. 
+ if config.aot_inductor.package_cpp_only: + wrapper_code = "\n".join((wrapper_code, kernel_code)) + kernel_code = "" + + wrapper_key, wrapper_path = write( + wrapper_code, + "wrapper.cpp", + extra=cpp_command, + specified_dir=specified_output_path, + key=config.aot_inductor.model_name_for_generated_files, + ) + kernel_code = ( + f"// Triton kernels are embedded as comments in {wrapper_path}\n" + + kernel_code + ) + _, kernel_path = write( + kernel_code, + "kernel.cpp", + extra=cpp_command, + specified_dir=specified_output_path, + key=config.aot_inductor.model_name_for_generated_files, + ) + + header_code = "" + header_path = "" + if not config.aot_inductor.dynamic_linkage: + # to link statically, we also need a header file + with open( + os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "csrc", + "inductor", + "aoti_runtime", + "model.h", + ) + ) as f: + # model_name_for_generated_files is guaranteed to be non-empty when compile_standalone + model_class_name = config.aot_inductor.model_name_for_generated_files + class_name = f"AOTInductorModel{model_class_name}" + header_code = f.read() + + # we replace like this to avoid replacing + # AOTInductorModelBase and AOTInductorModelKernelsBase + header_code = ( + header_code.replace("", f"<{class_name}>") + .replace("AOTInductorModel(", f"{class_name}(") + .replace("AOTInductorModel :", f"{class_name} :") + ) + _, header_path = write( + header_code, + "h", + specified_dir=specified_output_path, + key=model_class_name, + ) + + # Log the AOTInductor wrapper and kernel code, if needed. 
+ with WritableTempFile("w+") as t: + """ + Avoid "Permission denied error" on Windows: + with tempfile.NamedTemporaryFile("w", suffix=".gv") as temp_file: + # Not writable on Windows: + # https://docs.python.org/3/library/tempfile.html#tempfile.NamedTemporaryFile + + Example: + with WritableTempFile("w", suffix=".gv") as temp_file: + tree.to_dotfile(temp_file.name) + """ + t.writelines((wrapper_code, "\n", kernel_code, "\n")) + t.flush() + V.debug.output_code(t.name, extension="cpp") + + if config.aot_inductor.package: + generated_files.append(wrapper_path) + if not config.aot_inductor.package_cpp_only: + generated_files.append(kernel_path) + if not config.aot_inductor.dynamic_linkage: + generated_files.append(header_path) + + output_code_log.info("Wrapper code written to: %s", wrapper_path) + output_code_log.info("Kernel code written to: %s", kernel_path) + trace_structured( + "graph_dump", + lambda: { + "name": "inductor_aot_wrapper_code", + "type": "cpp", + "filename": wrapper_path, + }, + payload_fn=lambda: wrapper_code, + ) + trace_structured( + "graph_dump", + lambda: { + "name": "inductor_aot_kernel_code", + "type": "cpp", + "filename": kernel_path, + }, + payload_fn=lambda: kernel_code, + ) + if not config.aot_inductor.dynamic_linkage: + output_code_log.info("Header code written to: %s", header_path) + trace_structured( + "graph_dump", + lambda: { + "name": "inductor_aot_header_code", + "type": "cpp", + "filename": header_path, + }, + payload_fn=lambda: header_code, + ) + + # We use a file lock below to protect FS operations. 
The lock file + # is scoped to the 'key', so make sure the consts_s is protected + # by the same lock: + wrapper_path_operator = Path(wrapper_path) + kernel_path_operator = Path(kernel_path) + specified_sub_dir = wrapper_path_operator.parent / wrapper_key + if not specified_sub_dir.exists(): + specified_sub_dir.mkdir(exist_ok=True) + cmake_path = str(Path(specified_sub_dir) / "CMakeLists.txt") + + def _compile_consts(consts: bytes, platform: str) -> str: + # Load from aot_inductor, and update the value on demand. + use_asm_build: bool = config.aot_inductor.use_consts_asm_build + + if platform == "linux": + if graph.mutated_buffers & OrderedSet(graph.constants.keys()): + # .data section is between .text and .bss. When the size of .data is large, + # during the linking, the relocation of .text against .bss may overflow. + # Rename it to .ldata so that it won't be in between the .text and .bss section + if len(consts) > 2_000_000_000: + raise ValueError( + "Models with buffer mutation included doesn't support constants greater than 2GB!" + ) + section_attr = '.ldata, "aw"' + else: + section_attr = '.lrodata, "a"' + symbol_prefix = "" + elif platform == "darwin": + section_attr = "__DATA,__data" + symbol_prefix = "_" + elif platform == "win32": + symbol_prefix = "" + # ASM build is not supported on Windows, force use CPP build. + use_asm_build = False + else: + raise RuntimeError(f"Unsupported platform: {platform}") + + # Intel compiler failed to compile this manually constructed assembly file. + # Switch XPU to use consts cpp build. 
+ if device_type == "xpu": + use_asm_build = False + + is_large_consts = len(consts) > 1024 + is_zero_size_consts = len(consts) == 0 + + def format_consts_to_gnu_asm( + consts: bytes, + align_bytes: int, + symbol_prefix: str, + is_large_consts: bool, + ) -> tuple[str, str]: + consts_asm = f"\t.section\t{section_attr}\n" + consts_asm += f"\t.balign {align_bytes}\n" + consts_asm += f"\t.globl\t{symbol_prefix}_binary_constants_bin_start\n" + consts_asm += f"{symbol_prefix}_binary_constants_bin_start:\n" + if not is_large_consts: + for c in consts: + consts_asm += f"\t.byte {c}\n" + # Add one element even if constants are empty + # Otherwise assembler will not put them in data section + if not consts: + consts_asm += "\t.space 1\n" + else: + consts_asm += "\t.quad 0x1234567899abcdef\n" + consts_asm += f"\t.space {len(consts) - 8}\n" + consts_asm += f".globl\t{symbol_prefix}_binary_constants_bin_end\n" + consts_asm += f"{symbol_prefix}_binary_constants_bin_end:\n" + return consts_asm, "weights.S" + + # Use c++ to convert consts to object file can support more compilers, such as msvc and icx. 
+ def format_consts_to_cpp( + consts: bytes, align_bytes: int, symbol_prefix: str + ) -> tuple[str, str]: + consts_size = len(consts) + asan_attr = """#if defined(__clang__) || defined (__GNUC__)\t\n\ +#define ATTRIBUTE_NO_SANITIZE_ADDRESS __attribute__((no_sanitize("address")))\t\n\ +#else\t\n\ +#define ATTRIBUTE_NO_SANITIZE_ADDRESS\t\n\ +#endif\t\n\ +\t\n\ +ATTRIBUTE_NO_SANITIZE_ADDRESS\t\n""" + const_cpp = asan_attr + const_cpp += f"alignas({align_bytes}) extern " + const_cpp += f"unsigned char {symbol_prefix}_binary_constants_bin_start[{consts_size}] = {{\t\n" + count_bytes = 0 + for c in consts: + const_cpp += f"{c}, " + count_bytes = count_bytes + 1 + if count_bytes % 16 == 0: + const_cpp += "\t\n" + const_cpp += "};\t\n" + const_cpp += f"alignas({align_bytes}) extern unsigned char * {symbol_prefix}_binary_constants_bin_end;\t\n" + return const_cpp, "weights.cpp" + + def get_zero_consts_asm_code( + align_bytes: int, + symbol_prefix: str, + ) -> tuple[str, str]: + """ + This function handles zero-sized constants because the C++ standard prohibits zero-length arrays: + https://stackoverflow.com/questions/9722632/what-happens-if-i-define-a-0-size-array-in-c-c + + On Windows (MSVC): + The compiler reports error C2466 for zero-sized arrays: + https://learn.microsoft.com/en-us/cpp/error-messages/compiler-errors-1/compiler-error-c2466 + Solution: Use assembly compilation to handle this case. + + Why not use Win32 assembly for all paths? + ml64 only supports alignment up to 16 bytes, which isn't optimal for performance. + + Cross-platform implementation: + Linux: Added '-pedantic' to disable zero-sized arrays in C++ compiler + Windows: MSVC naturally rejects zero-sized arrays by default + """ + if _IS_WINDOWS: + # Windows ml64 is max support align to 16, but it is no effect to zero size data. 
+ asm_code = """ +option casemap:none +.data +?_binary_constants_bin_start@@3PAEA: +align 16 +?_binary_constants_bin_end@@3PAEA: +align 16 +public ?_binary_constants_bin_start@@3PAEA +public ?_binary_constants_bin_end@@3PAEA +end +""" + asm_ext = "asm" + else: + asm_code = f"\t.section\t{section_attr}\n" + asm_code += f"\t.balign {align_bytes}\n" + asm_code += ( + f"\t.globl\t{symbol_prefix}_binary_constants_bin_start\n" + ) + asm_code += f"{symbol_prefix}_binary_constants_bin_start:\n" + asm_code += f".globl\t{symbol_prefix}_binary_constants_bin_end\n" + asm_code += f"{symbol_prefix}_binary_constants_bin_end:\n" + asm_ext = "S" + return asm_code, asm_ext + + if use_asm_build: + consts_code, code_ext = format_consts_to_gnu_asm( + consts, ALIGN_BYTES, symbol_prefix, is_large_consts + ) + else: + if is_zero_size_consts: + consts_code, code_ext = get_zero_consts_asm_code( + ALIGN_BYTES, symbol_prefix + ) + else: + consts_code, code_ext = format_consts_to_cpp( + consts, ALIGN_BYTES, symbol_prefix + ) + + _, consts_s = write( + consts_code, + code_ext, + specified_dir=str(specified_sub_dir), + key=config.aot_inductor.model_name_for_generated_files, + ) + consts_s = Path(consts_s) + object_build_options = CppTorchDeviceOptions( + device_type=device_type, + aot_mode=graph.aot_mode, + compile_only=True, + use_relative_path=use_relative_path, + ) + object_builder = CppBuilder( + name=str(consts_s.stem), + sources=str(consts_s), + output_dir=str(consts_s.parent), + BuildOption=object_build_options, + ) + consts_o = object_builder.get_target_file_path() + if use_asm_build is False and is_zero_size_consts: + run_asm_build_object(str(consts_s), consts_o, str(consts_s.parent)) + else: + object_builder.build() + + if is_large_consts and use_asm_build: + with open(consts_o, "r+b") as f: + f.seek(0) + hdr = f.read(1024) + # Search for magic number and write the actual data over it + start_idx = ( + hdr.find(b"\xef\xcd\xab\x99\x78\x56\x34\x12") + if sys.byteorder == "little" + else 
hdr.find(b"\x12\x34\x56\x78\x99\xab\xcd\xef") + ) + assert start_idx != -1 + f.seek(start_idx) + pos = 0 + while pos < len(consts): + rc = f.write(consts[pos:]) + pos += rc + + # Remove the .S file to save space + os.remove(consts_s) + + return consts_o + + from torch.utils._filelock import FileLock + + lock_dir = get_lock_dir() + lock = FileLock( + os.path.join(lock_dir, wrapper_key + ".lock"), timeout=LOCK_TIMEOUT + ) + with lock: + if serialized_extern_kernel_nodes: + extern_kernel_nodes_json = str( + wrapper_path_operator.with_suffix(".json") + ) + with open(extern_kernel_nodes_json, "w") as f: + f.write(serialized_extern_kernel_nodes) + + if config.aot_inductor.package: + generated_files.append(extern_kernel_nodes_json) + + metadata = config.aot_inductor.metadata + metadata["AOTI_DEVICE_KEY"] = device_type + + # Add environment information to ensure .so compatibility + metadata.update(get_device_information(device_type)) + + # Save user provided metadata + meta_json = str( + wrapper_path_operator.with_name( + f"{wrapper_path_operator.stem}_metadata.json" + ) + ) + for k, v in config.aot_inductor.metadata.items(): + assert isinstance(k, str) and isinstance(v, (str)), ( + "Metadata must only contain strings" + ) + + with open(meta_json, "w") as f: + f.write(json.dumps(config.aot_inductor.metadata)) + + kernel_meta_json = str( + kernel_path_operator.with_name( + f"{kernel_path_operator.stem}_metadata.json" + ) + ) + shutil.copy(meta_json, kernel_meta_json) + + if config.aot_inductor.package: + generated_files.append(meta_json) + if not config.aot_inductor.package_cpp_only: + generated_files.append(kernel_meta_json) + + output_so = ( + config.aot_inductor.output_path + if specified_artifact_name + else str(wrapper_path_operator.with_suffix(".so")) + ) + all_cuda = all( + graph.get_original_value_of_constant(name).is_cuda + for name in graph.constants + if name not in graph.folded_constants + ) + + def _to_bytes(t: torch.Tensor, all_cuda: bool) -> bytes: + def 
_pad_to_alignment(raw_bytes: bytes) -> bytes: + padded_bytes = raw_bytes.ljust( + (len(raw_bytes) + ALIGN_BYTES - 1) // ALIGN_BYTES * ALIGN_BYTES, + b"\x00", + ) + return padded_bytes + + # This serializes the tensor's untyped_storage to bytes by accessing + # the raw data of the underlying structure. + import ctypes + + if t.numel() == 0: + return b"" + + if t.is_mkldnn: + data_ptr = torch.ops.mkldnn.data_ptr(t) + nbytes = torch.ops.mkldnn._nbytes(t) + else: + t_cpu = t.untyped_storage().cpu() + data_ptr = t_cpu.data_ptr() + nbytes = t_cpu.nbytes() + + raw_array = ctypes.cast( + data_ptr, + ctypes.POINTER(ctypes.c_ubyte * nbytes), + ) + # pyrefly: ignore [missing-attribute] + raw_bytes = bytes(raw_array.contents) + return raw_bytes if all_cuda else _pad_to_alignment(raw_bytes) + + if ( + config.aot_inductor.package_constants_in_so + or config.aot_inductor.package_constants_on_disk_format == "binary_blob" + ): + serialized_weights = b"".join( + _to_bytes(graph.get_original_value_of_constant(name), all_cuda) + for name in graph.constants + if name not in graph.folded_constants + ) + else: + serialized_weights = b"" + + if config.aot_inductor.package_constants_on_disk_format == "pickle_weights": + # We need to return a storage key here because the original value tensor might be a clone + weights_dict = Weights( + { + graph.allocated_constant_name[name]: ( + graph.get_original_value_of_constant(name), + TensorProperties(graph.constants[name]), + ) + for name in graph.constants + if name not in graph.folded_constants + } + ) + generated_files.append(weights_dict) + + consts_size = len(serialized_weights) + + use_external_weights, use_mmap_weights = determine_aoti_mmap_flags( + consts_size + ) + if use_external_weights and use_mmap_weights: + # Should never reach here, just a check for sanity + raise RuntimeError( + "use_external_weights and use_mmap_weights cannot both be True." 
+ ) + + external_weights_path = None + if use_external_weights: + external_weights_filename = f"{wrapper_path_operator.stem}_weights.blob" + external_weights_path = str( + wrapper_path_operator.with_name(external_weights_filename) + ) + + compile_command: dict[str, Any] = { + "aot_mode": graph.aot_mode, + "device_type": device_type, + "use_mmap_weights": use_mmap_weights, + "use_mmap_weights_external": use_external_weights, + "use_relative_path": use_relative_path, + "vec_isa": picked_vec_isa, + } + # If we're packaging via CMake, we build the whole code at max optimization. + wrapper_build_options = CppTorchDeviceOptions( + compile_only=True, + min_optimize=not config.aot_inductor.package_cpp_only, + **compile_command, + ) + kernel_build_options = CppTorchDeviceOptions( + compile_only=True, + **compile_command, + ) + + # potentially, precompile the AOT header for this device + if config.aot_inductor.precompile_headers and not _IS_WINDOWS: + header_file = _get_cpp_wrapper_header( + device_type, aot_mode=graph.aot_mode + ) + wrapper_build_options.precompiled_header = _precompile_header( + header_file, + cpp_command, + min_optimize=not config.aot_inductor.package_cpp_only, + **compile_command, + ) + if cpp_prefix := _get_cpp_prefix_header(device_type): + kernel_build_options.precompiled_header = _precompile_header( + cpp_prefix, + cpp_command, + **compile_command, + ) + + wrapper_builder = CppBuilder( + name=str(wrapper_path_operator.stem), + sources=wrapper_path, + output_dir=str(wrapper_path_operator.parent), + BuildOption=wrapper_build_options, + ) + wrapper_compile_cmd = wrapper_builder.get_command_line() + wrapper_o = wrapper_builder.get_target_file_path() + + kernel_builder = CppBuilder( + name=str(kernel_path_operator.stem), + sources=kernel_path, + output_dir=str(wrapper_path_operator.parent), + BuildOption=kernel_build_options, + ) + kernel_compile_cmd = kernel_builder.get_command_line() + kernel_o = kernel_builder.get_target_file_path() + + log.debug("aot 
wrapper compilation command: %s", wrapper_compile_cmd) + log.debug("aot kernel compilation command: %s", kernel_compile_cmd) + if config.aot_inductor.package_cpp_only: + # Not doing the actual compilation here + compile_flags = str( + wrapper_path_operator.with_name( + f"{wrapper_path_operator.stem}_compile_flags.json" + ) + ) + wrapper_build_options.save_flags_to_json(compile_flags) + generated_files.append(compile_flags) + wrapper_builder.save_compile_cmd_to_cmake(cmake_path, device_type) + wrapper_builder.save_src_to_cmake(cmake_path, wrapper_path) + generated_files.append(cmake_path) + else: + try: + wrapper_builder.build() + except (exc.CppCompileError, SkipFrame) as e: + if " is too big to optimize" in str(e): + raise RuntimeError( + "Please use torch._inductor.config.aot_inductor.compile_wrapper_opt_level = 'O0' flag." + ) from e + raise e + kernel_builder.build() + + if not use_mmap_weights: + aot_constants = serialized_weights + magic_number = 0 + if use_external_weights: + aot_constants = struct.pack("q", consts_size) + assert external_weights_path is not None + # For external weights, write weights to separate file and embed minimal placeholder + with open(external_weights_path, "wb") as f_weights: + f_weights.write(serialized_weights) + generated_files.append(external_weights_path) + else: + # we'll append weights binary to the end of .so file and mmap it when loading + magic_number = cast( + int, torch.randint(0, torch.iinfo(torch.int64).max, (1,)).item() + ) + aot_constants = struct.pack("qq", consts_size + 8, magic_number) + + consts_o = _compile_consts(aot_constants, sys.platform) + custom_obj_idx = 0 + # Note that custom_objs_config.json file is different from the model_constants_config.json file produced + # in package_sigmoid(). The keys in custom_objs_config.json directly correspond to the arg name in extern + # nodes json. The key in model_constants_config.json produced by package_sigmoid is the attribute name in the + # user model code. 
+ + qual_name_to_id = {} # Map from constant name to its name in constants folder + for custom_obj_idx, (name, constant) in enumerate( + graph.torchbind_constants.items() + ): + if isinstance( + constant, torch._library.fake_class_registry.FakeScriptObject + ): + constant = constant.real_obj + assert isinstance(constant, torch._C.ScriptObject) + custom_obj_name = f"{CUSTOM_OBJ_FILENAME_PREFIX}{custom_obj_idx}" + + log.debug("saving script object %s as %s", name, custom_obj_name) + + qual_name_to_id[name] = custom_obj_name + custom_obj_bytes = torch._C._pickle_save(constant) + custom_obj_path = os.path.join( + wrapper_path_operator.parent, custom_obj_name + ) + + write_atomic(custom_obj_path, custom_obj_bytes, True) + generated_files.append(custom_obj_path) + + if qual_name_to_id: + constants_config_json = os.path.join( + wrapper_path_operator.parent, "custom_objs_config.json" + ) + with open(constants_config_json, "w") as f: + f.write(json.dumps(qual_name_to_id)) + generated_files.append(constants_config_json) + + gpu_codecache: ROCmCodeCache | CUDACodeCache = ( + ROCmCodeCache() if torch.version.hip else CUDACodeCache() + ) + gpu_kernels_o = gpu_codecache.aot_kernels_o.copy() + # clear the list of aot kernels after each linking + gpu_codecache.aot_kernels_o.clear() + + if gpu_kernels_o: + assert not config.aot_inductor.emit_multi_arch_kernel, ( + "TODO: add emit_multi_arch_kernel support for cutlass kernels" + ) + + cubins_o = [] + asm_files = [] + if not _IS_WINDOWS: + ld, objcopy = get_ld_and_objcopy(use_relative_path) + kernels = getattr(V.graph.wrapper_code, "_kernel_name_to_body", {}) + for kernel_name, value in CudaKernelParamCache.cache.items(): + if kernel_name not in kernels: + # It is possible that CudaKernelParamCache contains more Triton kernels + # than what the current graph uses + continue + + if asm_file := value["asm"]: + asm_files.append(asm_file) + + cubin_file = value[get_cpp_wrapper_cubin_path_name()] + if ( + 
config.aot_inductor.emit_multi_arch_kernel + and device_type == "cuda" + ): + if torch.version.hip is None: + current_arch = _nvcc_arch_as_compile_option() + cmd = ( + # pyrefly: ignore [unbound-name] + f"{_cuda_compiler()} -fatbin {asm_file} -o {cubin_file} " + # Triton only allows generating PTX version as same as the current arch + f"-gencode arch=compute_{current_arch},code=compute_{current_arch} " + # Include SASS for the current specific arch + f"-gencode arch=compute_{current_arch},code=sm_{current_arch} " + ) + try: + subprocess.run( + cmd.split(), + capture_output=True, + text=True, + check=True, + ) + except subprocess.CalledProcessError as e: + print( + f"{cmd} failed with:\nstdout:\n{e.stdout}\nstderr:\n{e.stderr}", + file=sys.stderr, + ) + raise + + else: + # ROCm multi-arch: compile LLVM IR to multi-arch bundle + from torch._inductor.rocm_multiarch_utils import ( + compile_multiarch_bundle_from_llvm_ir, + ) + + if not os.path.exists(asm_file): + raise RuntimeError( + f"Multi-arch ROCm compilation requires LLVM IR file, " + f"but {asm_file} not found. " + f"Ensure asm_type='ll' is captured in triton_heuristics.py" + ) + + # Compile for multiple archs and bundle them + success = compile_multiarch_bundle_from_llvm_ir( + llvm_ir_path=asm_file, + output_bundle_path=cubin_file, + target_archs=None, + ) + + if not success: + raise RuntimeError( + f"Failed to compile multi-arch bundle for kernel {kernel_name}. " + f"Check that ROCm toolchain is available and LLVM IR is valid." 
+ ) + + log.info("Created multi-arch bundle: %s", cubin_file) + + if config.aot_inductor.embed_kernel_binary: + # Embed cubin files into model.so using objcopy + cubins_o.append( + convert_cubin_to_obj(cubin_file, kernel_name, ld, objcopy) + ) + + output_name, output_dir = get_name_and_dir_from_output_file_path(output_so) + so_build_options = CppTorchDeviceOptions( + vec_isa=picked_vec_isa, + device_type=device_type, + aot_mode=graph.aot_mode, + use_relative_path=use_relative_path, + ) + + obj_srcs = [wrapper_o, kernel_o, consts_o, *gpu_kernels_o, *cubins_o] + so_builder = CppBuilder( + name=output_name, + sources=obj_srcs, + output_dir=output_dir, + BuildOption=so_build_options, + ) + link_cmd = so_builder.get_command_line() + output_so = so_builder.get_target_file_path() + + log.debug("aot linkage command: %s", link_cmd) + + # Append cmds to the end of codegen-ed wrapper file + with open(wrapper_path, "a") as f: + f.write("\n") + f.write(f"// Compile cmd\n// {wrapper_compile_cmd}\n") + f.write(f"// Link cmd\n// {link_cmd}\n") + + with open(kernel_path, "a") as f: + f.write("\n") + f.write(f"// Compile cmd\n// {kernel_compile_cmd}\n") + f.write(f"// Link cmd\n// {link_cmd}\n") + + if config.aot_inductor.package_cpp_only: + linker_flags = str( + wrapper_path_operator.with_name( + f"{wrapper_path_operator.stem}_linker_flags.json" + ) + ) + so_build_options.save_flags_to_json(linker_flags) + generated_files.append(linker_flags) + generated_files.append(_LINKER_SCRIPT) + + # If we only want to package the cpp, then we need to save the + # weights separately into a bin, and we also need to prevent compiling the so + if use_mmap_weights: + weight_file = str( + wrapper_path_operator.with_name( + f"{wrapper_path_operator.stem}_serialized_weights.bin" + ) + ) + with open(weight_file, "wb") as f_weights: + f_weights.write(serialized_weights) + f_weights.write(struct.pack("q", magic_number)) + + generated_files.append(weight_file) + else: + # TODO: unify to always use 
mmap_weights + generated_files.append(consts_o) + so_builder.save_src_to_cmake(cmake_path, consts_o) + + # Different CMake strategies for CUDA vs ROCm: + # - CUDA: Save asm for CMake to recompile (user has nvcc) + # - ROCm: Link pre-compiled bundle (user may lack dev tools) + if ( + config.aot_inductor.emit_multi_arch_kernel + and torch.version.hip is None + ): + so_builder.save_kernel_asm_to_cmake(cmake_path, asm_files) + generated_files.extend(asm_files) + else: + # ROCm multi-arch + all single-arch: Link pre-compiled objects + # Bundle already embedded in .o files - just link into .so + obj_srcs = [*gpu_kernels_o, *cubins_o] + generated_files.extend(obj_srcs) + for obj in obj_srcs: + so_builder.save_src_to_cmake(cmake_path, obj) + + so_builder.save_link_cmd_to_cmake(cmake_path) + else: + so_builder.build() + for o_file in obj_srcs: + if o_file in gpu_kernels_o: + continue + # Remove these as they are not needed anymore + os.remove(o_file) + + if use_mmap_weights: + if config.aot_inductor.cross_target_platform == "windows": + raise RuntimeError( + "when cross_target_platform is windows, use_mmap_weights should not be true." 
+ ) + + def get_page_size() -> int: + # Don't use resource.getpagesize() on Windows, as it is a Unix specific package + # as seen in https://docs.python.org/2/library/resource.html + if _IS_WINDOWS: + from ctypes import ( # type: ignore[attr-defined] + byref, + Structure, + windll, + ) + from ctypes.wintypes import DWORD, LPVOID, WORD + + class SYSTEM_INFO(Structure): + _fields_ = [ + ("wProcessorArchitecture", WORD), + ("wReserved", WORD), + ("dwPageSize", DWORD), + ("lpMinimumApplicationAddress", LPVOID), + ("lpMaximumApplicationAddress", LPVOID), + ("dwActiveProcessorMask", DWORD), + ("dwNumberOfProcessors", DWORD), + ("dwProcessorType", DWORD), + ("dwAllocationGranularity", DWORD), + ("wProcessorLevel", WORD), + ("wProcessorRevision", WORD), + ] + + si = SYSTEM_INFO() + windll.kernel32.GetSystemInfo(byref(si)) + sys_page_size = si.dwPageSize + else: + import resource + + sys_page_size = resource.getpagesize() + + return sys_page_size + + page_size_ = get_page_size() + page_size = max(16384, page_size_) + + with open(output_so, "a+b") as f_so: + so_size = f_so.tell() + # Page align the weights + f_so.write(b" " * (page_size - so_size % page_size)) + f_so.write(serialized_weights) + f_so.write(struct.pack("q", magic_number)) + + if config.aot_inductor.package: + generated_files.append(output_so) + + if config.trace.provenance_tracking_level != 0: + kernel_info = torch._inductor.debug.create_kernel_information_json() + kernel_info_json = os.path.join( + wrapper_path_operator.parent, "kernel_information.json" + ) + with open(kernel_info_json, "w") as f: + f.write(json.dumps(kernel_info, indent=4)) + generated_files.append(kernel_info_json) + + if config.aot_inductor.package: + # We want to return the directory that contains all the AOTI + # generated files, not just the so + # return os.path.split(output_so)[0] + return generated_files + + return output_so + + +_libgomp: CDLL | None = None + + +def custom_op_wrapper(op: str, *args: Any) -> list[c_void_p] | c_void_p 
| None: + # This function will be called from generated cpp wrapper code in the JIT mode. + # Because tensors will be passed in as AtenTensorHandle, we need to explicitly convert them. + def convert_arg(arg: Any) -> Any: + if str(type(arg)) == "": + # No easy way to do isinstance check on PyCapsule + return torch._C._aoti.alloc_tensor_by_stealing_from_void_ptr(arg) + elif isinstance(arg, (list, tuple)): + return type(arg)(convert_arg(a) for a in arg) + else: + return arg + + converted_args = [convert_arg(arg) for arg in args] + + assert op.startswith("torch.ops."), ( + op + " can not be called through custom_op_wrapper" + ) + func = None + for i, s in enumerate(op.split(".")): + if i == 0: + func = importlib.import_module(s) + func = getattr(func, s) + + assert callable(func), op + " can not be loaded through custom_op_wrapper" + + # convert any kwarg-only arguments to kwargs + kwargs = dict() + # pyrefly: ignore [missing-attribute] + for func_arg, conv_arg in zip(func._schema.arguments, converted_args): + if func_arg.kwarg_only: + kwargs[func_arg.name] = conv_arg + if kwargs: + del converted_args[-len(kwargs) :] + + result = func(*converted_args, **kwargs) + if result is None: + return None + + if isinstance(result, (list, tuple)): + # unsafe_alloc_void_ptrs_from_tensors expects result contains tensor only + result = [torch.tensor([]) if r is None else r for r in result] + for r in result: + assert isinstance(r, torch.Tensor), op + " returns a list of non-tensors" + return torch._C._aoti.unsafe_alloc_void_ptrs_from_tensors(result) # type: ignore[arg-type] + + assert isinstance(result, torch.Tensor), op + " returns a non-tensor" + return torch._C._aoti.unsafe_alloc_void_ptr_from_tensor(result) + + +# Precompiled headers are persistent past program runtime, but associated with one +# specific compiler version and set of flags. We explicitly use default_cache_dir here +# because these headers need to be global, rather than ignored by fresh_cache. 
+_HEADER_DIR = os.path.join(default_cache_dir(), "precompiled_headers") +_HEADER_LOCK_DIR = os.path.join(_HEADER_DIR, "locks") + + +@functools.cache +def _precompile_header( + header: str, + hashable_cmd_line: str, + **compile_command: Any, +) -> str: + assert not _IS_WINDOWS, ( + "CppBuilder does not currently support precompiling on Windows!" + ) + + # Get the preprocessed output from the header file to be precompiled. This allows + # us to properly invalidate the file cache when any header dependency changes. This + # is thread-safe, as each thread will get its own temporary directory. + # + # N.B. we can't use NamedTemporaryFile here because Windows errors out on attempts + # to read from a file with an open write handle. + with tempfile.TemporaryDirectory() as preprocessing_dir: + preprocessing_header = Path(preprocessing_dir) / "header.hpp" + preprocessing_header.write_text(f"#include <{header}>\n") + preprocessor = CppBuilder( + name=str(preprocessing_header)[:-4], # strip off the .hpp extension + sources=str(preprocessing_header), + BuildOption=CppTorchDeviceOptions(**compile_command, preprocessing=True), + ) + preprocessor.build() + + def _get_file_checksum(filename: str) -> str: + """Reading the whole preprocessed header in for hashing is very expensive, + but calling a fast hashing utility in a subprocess is cheap.""" + # If Windows support needs to be added here, use certutil -hashfile. 
+ cmd_output = subprocess.run( + ("openssl", "sha512", filename), capture_output=True, text=True + ) + return cmd_output.stdout.split()[-1] + + preprocessor_hash = _get_file_checksum(preprocessor.get_target_file_path()) + + header_build_option = CppTorchDeviceOptions(**compile_command, precompiling=True) + header_hash, header_full_path = write( + content=f"#include <{header}>\n", + extension="h", + extra=( + hashable_cmd_line + + preprocessor_hash + + get_compiler_version_info(header_build_option.get_compiler()) + ), + specified_dir=_HEADER_DIR, + ) + cpp_builder = CppBuilder( + name=header_full_path, + sources=header_full_path, + BuildOption=header_build_option, + ) + # _worker_compile_cpp will automatically ignore any compilation whose result already + # exists, so this is always safe. + os.makedirs(_HEADER_LOCK_DIR, exist_ok=True) + _worker_compile_cpp( + os.path.join(_HEADER_LOCK_DIR, f"{header_hash}.lock"), + (cpp_builder,), + ) + + return header_full_path + + +def _get_cpp_prefix_header(device: str) -> str | None: + if device.startswith("cpu"): + return "torch/csrc/inductor/cpp_prefix.h" + return None + + +def _get_cpp_wrapper_header(device: str, aot_mode: bool = False) -> str: + """Given a device type (and optionally whether we're in AOT Inductor mode), returns + the path to the cpp_wrapper header file to be precompiled.""" + base_device = device.split(":", maxsplit=1)[0] + is_array_ref = config.aot_inductor.allow_stack_allocation and base_device == "cpu" + return ( + "torch/csrc/inductor/" + f"{'aoti_include' if aot_mode else 'cpp_wrapper'}/" + f"{'array_ref' if is_array_ref else base_device}.h" + ) + + +@clear_on_fresh_cache +class CppCodeCache: + """Compiles and caches C++ libraries. 
Users of this class supply the source code to + be compiled, while compilation flags are set by CppBuilder.""" + + cache: dict[str, Callable[[], CDLL | ModuleType]] = {} + cache_clear = staticmethod(cache.clear) + cpp_compile_command_flags: dict[str, Any] = {} + + @staticmethod + def _load_library_inner(path: str, key: str) -> CDLL | ModuleType: + return cdll.LoadLibrary(path) + + @classmethod + def _load_library(cls, path: str, key: str) -> CDLL | ModuleType: + try: + result = cls._load_library_inner(path, key) + result.key = key # type: ignore[union-attr] + return result + except (ImportError, OSError) as e: + if "gomp" in str(e) and os.path.exists("/usr/lib64/libgomp.so.1"): + # hacky workaround for fbcode/buck + global _libgomp + _libgomp = cdll.LoadLibrary("/usr/lib64/libgomp.so.1") + result = cls._load_library_inner(path, key) + result.key = key # type: ignore[union-attr] + return result + if "failed to map segment from shared object" in str(e): + raise OSError( + f"{e}. The most common reason this may occur is if the {tempfile.gettempdir()} folder " + "is mounted with noexec (e.g., by default Docker mounts tmp file systems " + f"as noexec). Please remount {tempfile.gettempdir()} with exec enabled, or set another " + "temporary directory with TORCHINDUCTOR_CACHE_DIR environment variable." + ) from e + raise + + @classmethod + def _get_uncompiled_header(cls, device: str) -> str | None: + """ + Given a device type, returns the path to a CPP header file to be precompiled. + """ + return None + + @classmethod + def load_async( + cls, + main_code: str, + device_type: str = "cpu", + submit_fn: Any = None, + extra_flags: Sequence[str] = (), + optimized_code: str | None = None, + ) -> Any: + """Compile and load a C++ library. 
Returns a callable that returns the loaded + library.""" + compile_command = { + **cls.cpp_compile_command_flags, + "device_type": device_type, + "extra_flags": extra_flags, + "use_relative_path": config.is_fbcode(), + "vec_isa": pick_vec_isa(), + } + + _set_gpu_runtime_env() # cpp_extension consults the env + + # Note the distinction between the two booleans. We do minimal optimization if + # the optimized_code argument is present at all, since that's how the user of + # this function opts in, but we do compilation and linking in one step if the + # optimized_code argument is empty (as a micro-optimization). + main_build_option = CppTorchDeviceOptions( + compile_only=bool(optimized_code), + min_optimize=optimized_code is not None, + # pyrefly: ignore [bad-argument-type] + **compile_command, + ) + optimized_build_option = CppTorchDeviceOptions( + # pyrefly: ignore [bad-argument-type] + compile_only=True, + # pyrefly: ignore [bad-argument-type] + **compile_command, + ) + + def get_hashable_command_line(build_option: BuildOptionsBase) -> str: + """Writing the code to file will calculate a hash, which we need to vary if + the command line flags change. This implements a mostly-generic way of + validating that.""" + return CppBuilder( + name="o", sources="i", BuildOption=build_option + ).get_command_line() + + main_cmd_line = get_hashable_command_line(main_build_option) + optimized_cmd_line = get_hashable_command_line(optimized_build_option) + + key, main_path = write( + main_code, "main.cpp", extra=f"{optimized_code} {main_cmd_line}" + ) + + # Don't bother writing if the argument is empty. + if optimized_code: + _, optimized_path = write( + optimized_code, "optimized.cpp", extra=optimized_cmd_line + ) + else: + # Unused, but makes type checkers happy. 
+ optimized_path = os.devnull + + if key not in cls.cache: + from torch.utils._filelock import FileLock + + lock_path = os.path.join(get_lock_dir(), key + ".lock") + future: Future[Any] | None = None + lib = None + + # if requested, pre-compile any headers + if config.cpp_cache_precompile_headers and not _IS_WINDOWS: + if header := cls._get_uncompiled_header(device_type): + main_build_option.precompiled_header = _precompile_header( + header, + main_cmd_line, + min_optimize=optimized_code is not None, + **compile_command, + ) + + # Currently, the optimized_code field is only used for cpp kernel code, + # so go ahead and precompile the relevant header here. Revisit this + # decision if that ever changes. + if optimized_code and (header := _get_cpp_prefix_header(device_type)): + optimized_build_option.precompiled_header = _precompile_header( + # pyrefly: ignore [unbound-name] + header, + optimized_cmd_line, + **compile_command, + ) + + main_name, output_dir = get_name_and_dir_from_output_file_path(main_path) + main_builder = CppBuilder( + name=main_name, + sources=main_path, + BuildOption=main_build_option, + output_dir=output_dir, + ) + + if optimized_code: + optimized_name, _ = get_name_and_dir_from_output_file_path( + optimized_path + ) + optimized_builder = CppBuilder( + name=optimized_name, + sources=optimized_path, + BuildOption=optimized_build_option, + output_dir=output_dir, + ) + + linker = CppBuilder( + name=main_name, + sources=[ + main_builder.get_target_file_path(), + optimized_builder.get_target_file_path(), + ], + # pyrefly: ignore [bad-argument-type] + BuildOption=CppTorchDeviceOptions(**compile_command), + output_dir=output_dir, + ) + + worker_fn = functools.partial( + _worker_compile_cpp, + lock_path, + (main_builder, optimized_builder, linker), + ) + binary_path = normalize_path_separator(linker.get_target_file_path()) + else: + worker_fn = functools.partial( + _worker_compile_cpp, lock_path, (main_builder,) + ) + binary_path = 
normalize_path_separator( + main_builder.get_target_file_path() + ) + + def load_fn() -> Any: + nonlocal lib + if lib is None: + if future is not None: + future.result() + result = worker_fn() + assert result is None + lib = cls._load_library(binary_path, key) + assert lib is not None + return lib + + if submit_fn is not None: + with FileLock(lock_path, timeout=LOCK_TIMEOUT): + if not os.path.exists(binary_path): + future = submit_fn(worker_fn) + + cls.cache[key] = load_fn + + return cls.cache[key] + + @classmethod + def load(cls, *args: Any, **kwargs: Any) -> Any: + return cls.load_async(*args, **kwargs)() + + +def _worker_compile_cpp( + lock_path: str, + cpp_builders: Sequence[CppBuilder], +) -> None: + from torch.utils._filelock import FileLock + + with FileLock(lock_path, timeout=LOCK_TIMEOUT): + for builder in cpp_builders: + if not os.path.exists(builder.get_target_file_path()): + builder.build() + + +# Customized Python binding for cpp kernels +@clear_on_fresh_cache +class CppPythonBindingsCodeCache(CppCodeCache): + cache: dict[str, Callable[[], CDLL | ModuleType]] = {} + cache_clear = staticmethod(cache.clear) + cpp_compile_command_flags = { + # kernels have no dependency on libtorch + "include_pytorch": False, + "shared": True, + } + entry_function = "kernel" + call_entry_function = "kernel({}); Py_RETURN_NONE;" + extra_parse_arg = "" + suffix_template = textwrap.dedent( + """ + // Python bindings to call {entry_func}(): + #define PY_SSIZE_T_CLEAN + #include + #include + #include + + #ifndef _MSC_VER + #if __cplusplus < 202002L + // C++20 (earlier) code + // https://en.cppreference.com/w/cpp/language/attributes/likely + #define likely(x) __builtin_expect(!!(x), 1) + #define unlikely(x) __builtin_expect(!!(x), 0) + #endif + #else + #define likely(x) (x) + #define unlikely(x) (x) + #endif + + // This is defined in guards.cpp so we don't need to import PyTorch headers that are slooow. + // We manually link it below to workaround issues with fbcode build. 
+ static void* (*_torchinductor_pyobject_tensor_data_ptr)(PyObject* obj); + + template static inline T parse_arg(PyObject* args, size_t n) {{ + static_assert(std::is_pointer_v, "arg type must be pointer or long"); + return static_cast(_torchinductor_pyobject_tensor_data_ptr(PyTuple_GET_ITEM(args, n))); + }} + template <> inline int64_t parse_arg(PyObject* args, size_t n) {{ + auto result = PyLong_AsSsize_t(PyTuple_GET_ITEM(args, n)); + if(unlikely(result == -1 && PyErr_Occurred())) + throw std::runtime_error("expected int arg"); + return result; + }} + template <> inline uintptr_t parse_arg(PyObject* args, size_t n) {{ + auto result = PyLong_AsVoidPtr(PyTuple_GET_ITEM(args, n)); + if(unlikely(result == reinterpret_cast(-1) && PyErr_Occurred())) + throw std::runtime_error("expected int arg"); + return reinterpret_cast(result); + }} + template <> inline float parse_arg(PyObject* args, size_t n) {{ + auto result = PyFloat_AsDouble(PyTuple_GET_ITEM(args, n)); + if(unlikely(result == -1.0 && PyErr_Occurred())) + throw std::runtime_error("expected float arg"); + return static_cast(result); + }} + + {extra_parse_arg} + + static PyObject* {entry_func}_py(PyObject* self, PyObject* args) {{ + try {{ + if(unlikely(!PyTuple_CheckExact(args))) + throw std::runtime_error("tuple args required"); + if(unlikely(PyTuple_GET_SIZE(args) != {arg_len})) + throw std::runtime_error("requires {arg_len} args"); + {call_entry_func} + }} catch(std::exception const& e) {{ + PyErr_SetString(PyExc_RuntimeError, e.what()); + return nullptr; + }} catch(...) 
{{ + PyErr_SetString(PyExc_RuntimeError, "unhandled error"); + return nullptr; + }} + }} + + static PyMethodDef py_methods[] = {{ + {{"{entry_func}", {entry_func}_py, METH_VARARGS, ""}}, + {{NULL, NULL, 0, NULL}}}}; + + static struct PyModuleDef py_module = + {{PyModuleDef_HEAD_INIT, "{entry_func}", NULL, -1, py_methods}}; + + PyMODINIT_FUNC PyInit_{entry_func}(void) {{ + const char* str_addr = std::getenv("_TORCHINDUCTOR_PYOBJECT_TENSOR_DATA_PTR"); + if(!str_addr) {{ + PyErr_SetString(PyExc_RuntimeError, "_TORCHINDUCTOR_PYOBJECT_TENSOR_DATA_PTR must be set"); + return nullptr; + }} + std::istringstream iss(str_addr); + uintptr_t addr = 0; + iss >> addr; + _torchinductor_pyobject_tensor_data_ptr = + reinterpret_cast(addr); + PyObject* module = PyModule_Create(&py_module); + if (module == NULL) {{ + return NULL; + }} + #ifdef Py_GIL_DISABLED + PyUnstable_Module_SetGIL(module, Py_MOD_GIL_NOT_USED); + #endif + return module; + }} + """ + ) + + @classmethod + # pyrefly: ignore [bad-override] + def _load_library_inner(cls, path: str, key: str) -> ModuleType: + os.environ["_TORCHINDUCTOR_PYOBJECT_TENSOR_DATA_PTR"] = str( + torch._C._dynamo.guards._torchinductor_pyobject_tensor_data_ptr # type: ignore[attr-defined] + ) + module_name = f"{key}.{cls.entry_function}" + try: + return sys.modules[module_name] + except KeyError: + pass + spec = importlib.util.spec_from_file_location(module_name, path) + assert spec is not None + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + assert spec.loader is not None + spec.loader.exec_module(module) + return module + + @classmethod + def _get_uncompiled_header(cls, device: str) -> str | None: + return _get_cpp_prefix_header(device) + + @classmethod + def load_pybinding_async( + cls, + argtypes: Sequence[str], + main_code: str, + device_type: str = "cpu", + num_outputs: int = -1, + submit_fn: Any = None, + extra_flags: Sequence[str] = (), + kernel_code: str | None = None, + ) -> Any: + """ + Wrap a C++ 
function in fast Python bindings. + + Args: + argtypes: The types of args to ENTRY_FUNCTION(), e.g. ["float*", "long"] + main_code: C++ source code containing ENTRY_FUNCTION(). Will be built at + -O3 if kernel_code is None (to maximize performance in any kernels that + are present), or -O1 otherwise (to minimize compile time). + kernel_code: If present, C++ source code that will be built at -O3 and + linked to main_code. + + Returns: + A python version of ENTRY_FUNCTION() + """ + parseargs = ", ".join( + f"parse_arg<{argtype.replace('const ', '')}>(args, {n})" + for n, argtype in enumerate(argtypes) + ) + suffix = cls.suffix_template.format( + arg_len=len(argtypes), + call_entry_func=cls.call_entry_function.format(parseargs), + entry_func=cls.entry_function, + extra_parse_arg=cls.extra_parse_arg.format(array_len=num_outputs), + ) + get_result = cls.load_async( + main_code + suffix, + device_type, + submit_fn=submit_fn, + extra_flags=extra_flags, + optimized_code=kernel_code, + ) + result = None + + def future() -> Any: + nonlocal result + if result is None: + result = get_result() + assert isinstance(result, ModuleType) + return getattr(result, cls.entry_function) + + return future + + @classmethod + def load_pybinding(cls, *args: Any, **kwargs: Any) -> Any: + return cls.load_pybinding_async(*args, **kwargs)() + + +@clear_on_fresh_cache +class CppWrapperCodeCache(CppPythonBindingsCodeCache): + cache: dict[str, Callable[[], CDLL | ModuleType]] = {} + cache_clear = staticmethod(cache.clear) + cpp_compile_command_flags = { + "include_pytorch": True, + "shared": True, + } + entry_function = "inductor_entry_cpp" + call_entry_function = "return inductor_entry_cpp({});" + extra_parse_arg = textwrap.dedent( + """ + #include + + static inline std::vector unpack_tensor_handle_list(PyObject* pyvec) {{ + std::vector result; + size_t result_len = PyList_GET_SIZE(pyvec); + result.reserve(result_len); + for (size_t i = 0; i < result_len; i++) {{ + // AtenTensorHandle is 
essentially a pointer + void* elem = PyCapsule_GetPointer(PyList_GET_ITEM(pyvec, i), NULL); + result.push_back(reinterpret_cast(elem)); + }} + return result; + }} + + static inline PyObject* pack_tensor_handle_list(const std::array& arr) {{ + PyObject* result = PyList_New({array_len}); + for (size_t i = 0; i < {array_len}; i++) {{ + PyObject *elem = + arr[i] == nullptr + ? Py_None + // Store AtenTensorHandle as PyCapsulate + : PyCapsule_New(reinterpret_cast(arr[i]), NULL, NULL); + PyList_SET_ITEM(result, i, elem); + }} + return result; + }} + + template <> inline std::vector parse_arg>(PyObject* args, size_t n) {{ + return unpack_tensor_handle_list(PyTuple_GET_ITEM(args, n)); + }} + + PyObject* inductor_entry_cpp(std::vector&& input_handles) {{ + // For outputs, we only allocate an array to hold returned tensor handles, + // not the actual output tensor storage. + std::array output_handles{{}}; + try {{ + inductor_entry_impl(input_handles.data(), output_handles.data()); + if (PyErr_Occurred()) {{ + return nullptr; + }} + return pack_tensor_handle_list(output_handles); + }} catch(std::exception const& e) {{ + PyErr_SetString(PyExc_RuntimeError, e.what()); + return nullptr; + }} catch(...) {{ + PyErr_SetString(PyExc_RuntimeError, "unhandled error"); + return nullptr; + }} + }} + """ + ) + + @classmethod + def _get_uncompiled_header(cls, device: str) -> str | None: + return _get_cpp_wrapper_header(device) + + +@clear_on_fresh_cache +class HalideCodeCache(CppPythonBindingsCodeCache): + cache: dict[str, Callable[[], ModuleType | CDLL]] = {} + cache_clear = staticmethod(cache.clear) + _standalone_runtime_path: str | None = None + prefix = textwrap.dedent( + """ + #include "{halideruntime_h}" + #include "{headerfile}" + #include + #include + + namespace c10 {{ + inline long div_floor_integer(long a, long b) {{ + if ((a<0) != (b<0)) {{ + const auto quot = a / b; + const auto rem = a % b; + return rem ? 
quot - 1 : quot; + }} + return a / b; + }} + }} + """ + ) + glue_template_cpp = prefix + textwrap.dedent( + """ + void kernel({argdefs}) {{ + {buffers} + int err = halide_kernel({buffer_names}); + if(err != 0) throw std::runtime_error("halide_kernel failed"); + }} + """ + ) + glue_template_cuda = prefix + textwrap.dedent( + """ + #include + static const halide_device_interface_t* cuda_interface = halide_cuda_device_interface(); + + void kernel({argdefs}, uintptr_t stream) {{ + {buffers} + int err = halide_kernel(reinterpret_cast(stream), {buffer_names}); + if(err != 0) throw std::runtime_error("halide_kernel failed"); + }} + """ + ) + standalone_runtime_cuda_init = textwrap.dedent( + """ + #include "{}" + #include + + static int acquire_context(void* user_context, + void** cuda_context_out, + bool create) {{ + return cuCtxGetCurrent(reinterpret_cast(cuda_context_out)); + }} + + static int release_context(void* user_context) {{ + return 0; + }} + + static int get_stream(void* user_context, + void* cuda_context, + void** stream_out) {{ + *stream_out = user_context; + return 0; + }} + + static int register_halide_hooks() {{ + halide_set_cuda_acquire_context(&acquire_context); + halide_set_cuda_release_context(&release_context); + halide_set_cuda_get_stream(&get_stream); + return 0; + }} + + int inductor_register_halide_hooks_result = register_halide_hooks(); + """ + ) + + @classmethod + def _codegen_buffer(cls, name: str, arg: HalideInputSpec, cuda: bool) -> list[str]: + assert arg.shape is not None + assert arg.stride is not None and len(arg.shape) == len(arg.stride) + assert arg.offset is not None + data_ptr = f"{arg.alias_of or arg.name} + {arg.offset}" + if cuda: + device = f"reinterpret_cast({data_ptr})" + device_interface = "cuda_interface" + host = "nullptr" + flags = "halide_buffer_flag_device_dirty" + else: + device = "0" + device_interface = "nullptr" + host = f"reinterpret_cast({data_ptr})" + flags = "halide_buffer_flag_host_dirty" + + dims = [] + for size, 
stride in zip(arg.shape, arg.stride): + dims.append(f"halide_dimension_t(0, {size}, {stride})") + + return [ + f"halide_buffer_t {name};", + f"halide_dimension_t {name}_dims[] = {{{', '.join(dims)}}};" + if len(dims) > 0 + else f"halide_dimension_t * {name}_dims = nullptr;", + f"{name}.device = {device};", + f"{name}.device_interface = {device_interface};", + f"{name}.host = {host};", + f"{name}.flags = {flags};", + f"{name}.type = {arg.halide_type()};", + f"{name}.dimensions = {len(dims)};", + f"{name}.dim = {name}_dims;", + f"{name}.padding = nullptr;", + ] + + @classmethod + def _codegen_glue(cls, meta: HalideMeta, headerfile: object) -> str: + is_cuda = meta.is_cuda() + assert is_cuda is ("user_context" in meta.target) + assert "no_runtime" in meta.target + buffers = [] + buffer_names = [] + for i, arg in enumerate(meta.argtypes): + if arg.is_buffer(): + # pyrefly: ignore [bad-argument-type] + buffer_names.append(f"&hl_buf_{i}") + buffers.extend(cls._codegen_buffer(f"hl_buf_{i}", arg, is_cuda)) + else: + assert "*" not in arg.ctype + # pyrefly: ignore [bad-argument-type] + buffer_names.append(arg.name) + buffers = "\n".join([f" {line}" for line in buffers]).lstrip() + + glue_template = cls.glue_template_cuda if is_cuda else cls.glue_template_cpp + glue_code = glue_template.format( + halideruntime_h=cls.find_header( + "HalideRuntimeCuda.h" if is_cuda else "HalideRuntime.h" + ), + headerfile=headerfile, + argdefs=", ".join( + f"{a.bindings_type()} {a.name}" + for a in meta.argtypes + if a.alias_of is None + ), + buffers=buffers, + buffer_names=", ".join(buffer_names), + ) + return glue_code + + @classmethod + @functools.cache + def config_hash(cls) -> str: + command_gen = CppBuilder( + name="O", + sources="I", + BuildOption=CppOptions(), + ) + command_line = command_gen.get_command_line() + return sha256_hash( + "\n".join( + [ + cls.glue_template_cpp, + cls.glue_template_cuda, + cls.standalone_runtime_cuda_init, + command_line, + ] + ).encode("utf-8") + ) + + 
@staticmethod + def _search_for_file(suffix: str, errmsg: str) -> str: + spec = importlib.machinery.PathFinder.find_spec("halide") + if spec is None or not spec.submodule_search_locations: + raise RuntimeError("halide python bindings not installed") + try: + search = spec.submodule_search_locations[0] + for file in os.listdir(search): + if file.endswith(".so"): + try: + out = subprocess.check_output( + ["ldd", os.path.join(search, file)] + ) + except subprocess.SubprocessError: + continue + m = re.search(r"(/.*)/libHalide.so", out.decode("utf-8")) + if m: + path = os.path.join(os.path.abspath(m.group(1)), suffix) + if os.path.exists(path): + return os.path.abspath(path) + except Exception as e: + raise RuntimeError(errmsg) from e + raise RuntimeError(errmsg) + + @staticmethod + @functools.cache + def find_libautoschedule(name: str) -> str: + sofile = f"libautoschedule_{name.lower()}.so" + if "HALIDE_LIB" in os.environ: + path = os.path.join(os.environ["HALIDE_LIB"], sofile) + if os.path.exists(path): + return path + errmsg = ( + f"Can't find {sofile}, set env HALIDE_LIB to the directory containing it" + ) + return HalideCodeCache._search_for_file(sofile, errmsg) + + @staticmethod + @functools.cache + def find_header(name: str) -> str: + if "HALIDE_INCLUDE" in os.environ: + path = os.path.join(os.environ["HALIDE_INCLUDE"], name) + if os.path.exists(path): + return path + if "HALIDE_LIB" in os.environ: + path = os.path.abspath( + os.path.join(os.environ["HALIDE_LIB"], f"../include/{name}") + ) + if os.path.exists(path): + return path + errmsg = ( + f"Can't find {name}, set env HALIDE_INCLUDE to the directory containing it" + ) + return HalideCodeCache._search_for_file(f"../include/{name}", errmsg) + + @classmethod + def generate_halide_async( + cls, meta: HalideMeta, source_code: str, submit_fn: Any = None + ) -> Callable[[], Any]: + dirpath = Path( + get_path( + code_hash( + source_code, + extra=repr((cls.config_hash(), meta)), + ), + "halide", + )[2] + ) + 
os.makedirs(dirpath, exist_ok=True) + wait_for_compile = None + genfile = str(dirpath / "generate_kernel.py") + libfile = str(dirpath / "halide_kernel.a") + headerfile = str(dirpath / "halide_kernel.h") + donefile = str(dirpath / "done") + lockfile = str(dirpath / "lock") + need_compile = not os.path.exists(donefile) + jobs: list[Any] = [] + if need_compile: + write_atomic(genfile, source_code) + cmd = [ + sys.executable, + genfile, + "-g", + "kernel", + "-o", + f"{dirpath}", + "-f", + "halide_kernel", + "-e", + "static_library,h,schedule", + ] + if meta.scheduler: + cmd.extend(["-p", cls.find_libautoschedule(meta.scheduler)]) + cmd.extend(meta.args()) + jobs.append(functools.partial(subprocess.check_call, cmd)) + + binding_types = [ + arg.bindings_type() for arg in meta.argtypes if arg.alias_of is None + ] + if meta.is_cuda(): + binding_types.append("uintptr_t") # stream + bindings_future = cls.load_pybinding_async( + binding_types, + cls._codegen_glue(meta, headerfile), + extra_flags=(libfile, cls.build_standalone_runtime()), + submit_fn=jobs.append if need_compile else None, + device_type="cuda" if meta.is_cuda() else "cpu", + ) + + if need_compile: + jobs.append(functools.partial(touch, donefile)) + task = functools.partial(_worker_task_halide, lockfile, jobs) + if submit_fn: + wait_for_compile = submit_fn(task).result + else: + task() + + def load() -> Callable[[], Any]: + if wait_for_compile: + wait_for_compile() + return bindings_future() + + return load + + @classmethod + def generate_halide(cls, *args: Any, **kwargs: Any) -> Callable[[], Any]: + return cls.generate_halide_async(*args, **kwargs)() + + @classmethod + def build_standalone_runtime(cls) -> str: + if cls._standalone_runtime_path and os.path.exists( + cls._standalone_runtime_path + ): + return cls._standalone_runtime_path + device_type = "cuda" if torch.cuda.is_available() else "cpu" + libname = "libStandaloneHalideRuntime.so" + target = "host-cuda" if device_type == "cuda" else "host" + if 
cls._standalone_runtime_path: + assert not os.path.exists(cls._standalone_runtime_path) + # We hit this case in unittests when we run with fresh_cache() + # Generating a fresh runtime over and over causes errors because we initialize + # cuda hundreds of times in the same process and run out of file descriptors. + # Workaround by jail breaking the current fresh_cache(). + base = default_cache_dir() + else: + base = cache_dir() + dirpath = Path(base) / f"halide-runtime-{target}-{cls.config_hash()}" + os.makedirs(dirpath, exist_ok=True) + done_file = str(dirpath / "done") + lock_file = str(dirpath / "lock") + hook_file = str(dirpath / "hooks.cpp") + a_file = str(dirpath / "standalone_halide_runtime.a") + so_file = str(dirpath / libname) + if not os.path.exists(done_file): + import halide as hl # type: ignore[import-untyped,import-not-found] + + from torch.utils._filelock import FileLock + + with FileLock(lock_file, LOCK_TIMEOUT): + if not os.path.exists(done_file): + with open(hook_file, "w") as f: + if device_type == "cuda": + f.write( + cls.standalone_runtime_cuda_init.format( + cls.find_header("HalideRuntimeCuda.h") + ) + ) + hl.compile_standalone_runtime(a_file, hl.Target(target)) + + name, output_dir = get_name_and_dir_from_output_file_path(so_file) + halide_cmd_gen = CppBuilder( + name=name, + sources=[hook_file, a_file], + output_dir=output_dir, + BuildOption=CppTorchDeviceOptions( + device_type=device_type, + ), + ) + + subprocess.check_call( + shlex.split(halide_cmd_gen.get_command_line()) + ) + touch(done_file) + assert os.path.exists(so_file) + cls._standalone_runtime_path = so_file + return so_file + + @classmethod + def _get_uncompiled_header(cls, device: str) -> str | None: + """Header precompiling is currently disabled for halide.""" + return None + + +def _worker_task_halide(lockfile: str, jobs: list[partial[Any]]) -> None: + from torch.utils._filelock import FileLock + + try: + with FileLock(lockfile, LOCK_TIMEOUT): + for job in jobs: + job() + 
except subprocess.SubprocessError as e: + if os.environ.get("HALIDE_REPRO") == "1": + cmd: list[Any] + python, script, *cmd = getattr(e, "cmd", ("", "", "")) + if os.path.basename(python).startswith("python"): + code = Path(script).read_text() + main = " hl.main()" + assert code.count(main) == 1 + + class Out: + def __repr__(self) -> str: + return "out" + + ci = cmd.index("-o") + assert isinstance(ci, int) + # pyrefly: ignore [unsupported-operation] + cmd[ci + 1] = Out() + repl = textwrap.indent( + textwrap.dedent( + f"""\ + import sys, tempfile + with tempfile.TemporaryDirectory() as out: + sys.argv = {["repro.py", *cmd]!r} + hl.main() + """ + ), + " ", + ) + code = code.replace(main, repl) + with open("repro.py", "w") as fd: + fd.write(code.lstrip()) + raise RuntimeError(f"wrote repro.py: {e}") from e + raise + + +def touch(filename: str) -> None: + with open(filename, "a"): + pass + + +@clear_on_fresh_cache +class PyCodeCache: + # Track the loaded modules so we can remove the on-disk artifacts when + # clearing the cache. Note also that we may load the same path more + # than once, but attach different attributes, i.e., due to different + # constant values. + modules: list[ModuleType] = [] + + # Modules loaded without extra attributes are stored here, those do not + # need to be re-loaded. 
+ modules_no_attr: dict[str, ModuleType] = {} + + linemaps: dict[str, list[tuple[Any, ...]]] = {} + + @classmethod + def write(cls, source_code: str, extra: str = "") -> tuple[str, str]: + return write(source_code, "py", extra=extra) + + @classmethod + def load(cls, source_code: str, extra: str = "") -> ModuleType: + key, path = write(source_code, "py", extra=extra) + return cls.load_by_key_path(key, path) + + @classmethod + def load_by_key_path( + cls, + key: str, + path: str, + linemap: list[tuple[int, str]] | None = None, + attrs: dict[str, Any] | None = None, + ) -> ModuleType: + if linemap is None: + linemap = [] + + # we only cache when attrs is None + if attrs is None and path in cls.modules_no_attr: + return cls.modules_no_attr[path] + + in_toplevel = in_toplevel_process() + mod = _reload_python_module(key, path, set_sys_modules=in_toplevel) + + # unzip into separate lines/nodes lists + if in_toplevel: + cls.linemaps[path] = list(zip(*linemap)) + + if attrs is not None: + for k, v in attrs.items(): + setattr(mod, k, v) + + if in_toplevel: + # we only cache when attrs is None + if attrs is None: + cls.modules_no_attr[path] = mod + + cls.modules.append(mod) + return mod + + @classmethod + def cache_clear(cls, purge: bool = False) -> None: + """ + Clear the in-memory module cache. If purge=True, also delete all the + corresponding on-disk source files. + """ + if purge: + for mod in cls.modules: + try: + assert mod.__file__ + os.remove(mod.__file__) + except FileNotFoundError: + pass + cls.modules.clear() + cls.modules_no_attr.clear() + + @classmethod + @functools.cache + def stack_frames_for_code( + cls, path: str, lineno: int + ) -> list[dict[str, Any]] | None: + if path not in cls.linemaps: + return None + if len(cls.linemaps[path]) == 0: + return None + # [(starting_line, ), ...] 
+ lines, nodes = cls.linemaps[path] + p = bisect_right(lines, lineno) + if p == 0: + return None + entry = nodes[p - 1] + if not entry: + return None + + def parse_stack_trace(stack_trace: str) -> list[dict[str, Any]]: + # ideally fx stores stack traces as data rather than a string + # but this is not along a performance critical path + regex = r'File "(.+)", line (\d+), in (.+)\n' + matches = re.findall(regex, stack_trace) + return [ + {"filename": f, "line": int(l), "name": n} + for f, l, n in reversed(matches) + ] + + return parse_stack_trace(entry) + + +def _load_triton_kernel_from_source( + kernel_name: str, source_code: str +) -> CachingAutotuner: + return getattr(PyCodeCache.load(source_code), kernel_name) + + +def _cuda_compiler() -> str | None: + if cuda_env.nvcc_exist(config.cuda.cuda_cxx): + return config.cuda.cuda_cxx + if config.is_fbcode(): + return os.path.join(build_paths.sdk_home, "bin", "nvcc") + if cuda_env.nvcc_exist(os.getenv("CUDACXX")): + return os.getenv("CUDACXX", "") + if cuda_env.nvcc_exist(os.getenv("CUDA_HOME")): + return os.path.realpath(os.path.join(os.getenv("CUDA_HOME", ""), "bin/nvcc")) + return "nvcc" + + +def _cutlass_path() -> str: + if config.is_fbcode(): + from libfb.py import parutil + + return parutil.get_dir_path("cutlass-4-headers") + else: + return config.cuda.cutlass_dir + + +def _cutlass_paths() -> list[str]: + return [ + "include", + "tools/library/include", + "tools/library/src", + "tools/util/include", + ] + + +def _clone_cutlass_paths(build_root: str) -> list[str]: + paths = _cutlass_paths() + cutlass_root = _cutlass_path() + for path in _cutlass_paths(): + old_path = os.path.join(cutlass_root, path) + new_path = os.path.join(build_root, path) + shutil.copytree(old_path, new_path, dirs_exist_ok=True) + return paths + + +def _cutlass_include_paths() -> list[str]: + cutlass_path = _cutlass_path() + return [ + # Use realpath to get canonical absolute paths, in order not to mess up cache keys + 
os.path.realpath(os.path.join(cutlass_path, path)) + for path in _cutlass_paths() + ] + + +@torch_key_cache +def cutlass_key() -> bytes: + """ + Compute a key representing the state of the CUTLASS library. + + Note: OSS and fbcode will have different keys. + """ + if config.is_fbcode(): + with ( + importlib.resources.path( + "cutlass_library", "src_hash.txt" + ) as resource_path, + open(resource_path) as resource_file, + ): + return resource_file.read().encode() + + combined_hash = hashlib.sha256() + build_code_hash([config.cuda.cutlass_dir], "", combined_hash) + return combined_hash.digest() + + +def _cuda_lib_options() -> list[str]: + """ + Util function for CUTLASS backend to find the correct CUDA libraries. + """ + _set_gpu_runtime_env() # cpp_extension consults the env + from torch.utils import cpp_extension + + lpaths = cpp_extension.library_paths(device_type="cuda") + if use_re_build(): + lpaths += [ + build_paths.sdk_lib, + os.path.join(build_paths.sdk_lib, "stubs"), + ] + extra_ldflags: list[str] = [] + if is_linux(): + _transform_cuda_paths(lpaths) + for path in lpaths: + if "torch/lib" in path: + # don't want to depend on pytorch + continue + extra_ldflags.append(f"-L{path}") + # -rpath ensures the DLL can find its dependencies when loaded, even + # if the library path is non-standard. + # But do not add the stubs folder to rpath as the driver is expected to be found at runtime + if os.path.basename(path) != "stubs": + extra_ldflags.extend(["-Xlinker", f"-rpath={path}"]) + extra_ldflags.append("-lcuda") + extra_ldflags.append("-lcudart") + else: + raise NotImplementedError( + "Unsupported env, failed to find cuda libs! Currently only Linux is supported." 
+ ) + return extra_ldflags + + +def _nvcc_host_compiler_options() -> list[str]: + return [ + "-fPIC", + "-fno-strict-aliasing", + "-fvisibility=hidden", + "-Wconversion", + ] + + +def _nvcc_arch_as_compile_option() -> str: + arch = cuda_env.get_cuda_arch() + if arch == "90": + # Required by cutlass compilation. + return "90a" + if arch == "100": + return "100a" + return arch + + +def _nvcc_compiler_options() -> list[str]: + arch = _nvcc_arch_as_compile_option() + code = [f"sm_{arch}", f"compute_{arch}"] + if config.cuda.enable_cuda_lto: + code += [f"lto_{arch}"] + options = [ + "-t=0", + "-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1", + "-DCUTLASS_ENABLE_SM90_EXTENDED_MMA_SHAPES=1", + "-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED", + "-w", + f"-gencode=arch=compute_{arch},code=[{','.join(code)}]", + config.cuda.compile_opt_level, + "-std=c++17", + "--expt-relaxed-constexpr", + "-DNDEBUG", + ] + if config.is_fbcode(): + options.extend(["-ccbin", os.path.dirname(build_paths.gcc)]) + if config.cuda.enable_debug_info: + options.extend(["-lineinfo", "-g", "-DCUTLASS_DEBUG_TRACE_LEVEL=1"]) + if config.cuda.enable_ptxas_info: + options.extend( + [ + "--keep", # Keep the intermediate files for debugging (including ptx, sass, cubin etc.) + "--ptxas-options=--warn-on-local-memory-usage", # warn us if local memory is used in CUDA Kernels + "--ptxas-options=--warn-on-spills", # warn us if register spilling happens in CUDA Kernels + "--resource-usage", # Report on CUDA resource usage (shared mem, registers etc.) 
+ "--source-in-ptx", + ] + ) # Annotate the ptx file with source information + if config.cuda.use_fast_math: + options.extend( + [ + "--use_fast_math", + "-DCUTLASS_USE_TANH_FOR_SIGMOID=1", + ] + ) + return options + + +def cuda_compile_command( + src_files: list[str], + dst_file: str, + dst_file_ext: str, + extra_args: list[str] | None = None, +) -> str: + if extra_args is None: + extra_args = [] + if use_re_build(): + build_path = os.path.dirname(dst_file) + include_paths = _clone_cutlass_paths(build_path) + src_files = [os.path.basename(src_file) for src_file in src_files] + dst_file = os.path.basename(dst_file) + else: + include_paths = _cutlass_include_paths() + cuda_lib_options = _cuda_lib_options() + nvcc_host_compiler_options = _nvcc_host_compiler_options() + nvcc_compiler_options = _nvcc_compiler_options() + options = ( + nvcc_compiler_options + + extra_args + + [ + f"-Xcompiler {opt}" if "=" in opt else f"-Xcompiler={opt}" + for opt in nvcc_host_compiler_options + ] + + ["-I" + path for path in include_paths] + + cuda_lib_options + ) + src_file = " ".join(src_files) + res = "" + if dst_file_ext == "o": + res = f"{_cuda_compiler()} {' '.join(options)} -c -o {dst_file} {src_file}" + elif dst_file_ext == "so": + options.append("-shared") + res = f"{_cuda_compiler()} {' '.join(options)} -o {dst_file} {src_file}" + elif dst_file_ext == "exe": + res = f"{_cuda_compiler()} {' '.join(options)} -o {dst_file} {src_file}" + else: + raise NotImplementedError(f"Unsupported output file suffix {dst_file_ext}!") + if log.isEnabledFor(logging.DEBUG): + log.debug("CUDA command: %s", res) + else: + autotuning_log.debug("CUDA command: %s", res) + return res + + +class DLLWrapper: + """A wrapper for a dynamic library.""" + + def __init__( + self, + lib_path: str, + ) -> None: + self.lib_path = lib_path + self.is_open = False + self.DLL = cdll.LoadLibrary(lib_path) + self.is_open = True + + def close(self) -> None: + if self.is_open: + self._dlclose() + self.is_open = False + 
+ def _dlclose(self) -> None: + f_dlclose = None + + if is_linux(): + syms = CDLL(None) + if not hasattr(syms, "dlclose"): + # Apline Linux + syms = CDLL("libc.so") + + if hasattr(syms, "dlclose"): + f_dlclose = syms.dlclose + elif is_windows(): + import ctypes + + kernel32 = ctypes.CDLL("kernel32", use_last_error=True) + + f_dlclose = kernel32.FreeLibrary + else: + raise NotImplementedError("Unsupported env, failed to do dlclose!") + + if f_dlclose is not None: + if is_linux(): + f_dlclose.argtypes = [c_void_p] + f_dlclose(self.DLL._handle) + elif is_windows(): + import ctypes + from ctypes import wintypes + + f_dlclose.argtypes = [wintypes.HMODULE] + f_dlclose(self.DLL._handle) + else: + log.warning( + "dll unloading function was not found, library may not be unloaded properly!" + ) + + def __getattr__(self, name: str) -> Callable[..., None]: + if not self.is_open: + raise RuntimeError(f"Cannot use closed DLL library: {self.lib_path}") + + method = getattr(self.DLL, name) + + def _wrapped_func(*args: Any) -> None: + err = method(*args) + if err: + raise RuntimeError(f"Error in function: {method.__name__}") + + return _wrapped_func + + def __enter__(self) -> Self: + return self + + def __exit__(self, *args: Any) -> None: + self.close() + + def __del__(self) -> None: + self.close() + + +@lru_cache +def binary_error_path(output_path: str) -> str: + """ + standard format for the error path + """ + return output_path + ".error" + + +@clear_on_fresh_cache +class CUDACodeCache: + """ + A cache for managing the compilation and loading of CUDA source code specifically for CUTLASS. + This class handles writing source code to files, compiling them into shared objects, and caching + the results to avoid redundant compilations. It also manages error handling and logging for the + compilation process. 
+ """ + + @dataclasses.dataclass + class CacheEntry: + input_path: str + output_path: str + error_json: str | None = None + + cache: dict[str, CacheEntry] = {} + aot_kernels_o: list[str] = [] + _SOURCE_CODE_SUFFIX = "cu" + + @staticmethod + def cache_clear() -> None: + CUDACodeCache.cache.clear() + CUDACodeCache.aot_kernels_o.clear() + + @staticmethod + @lru_cache(maxsize=4) + def get_kernel_binary_remote_cache( + caching_enabled: bool, caching_available: bool + ) -> Any | None: + """ + Get or create the class instance of the CUTLASSKernelBinaryRemoteCache. + + Args: + caching_enabled: Whether binary remote caching is enabled + caching_available: Whether we're in fbcode environment + + Returns: + CUTLASSKernelBinaryRemoteCache: The class instance of the kernel binary remote cache + """ + if not caching_enabled: + log.debug("CUTLASSKernelBinaryRemoteCache not requested, skipping") + return None + if not caching_available: + return None + + try: + from torch._inductor.fb.kernel_binary_remote_cache import ( + CUTLASSKernelBinaryRemoteCache, + ) + + return CUTLASSKernelBinaryRemoteCache() + except ImportError: + log.debug( + "CUTLASSKernelBinaryRemoteCache not available, remote caching disabled" + ) + return None + + @classmethod + @lru_cache(None) + def write(cls, source_code: str, dst_file_ext: str) -> tuple[str, str]: + """ + Writes source code into a file with dst_file_ext as the file extension. + Returns the hash key of source code, and the path to the file. 
+ """ + + if config.cuda.cutlass_hash_with_compile_cmd: + cuda_command = repr( + cuda_compile_command(["dummy_input"], "dummy_output", dst_file_ext) + ) + extra = cuda_command + else: + extra = repr( + [ + # nvcc and cuda hash + _cuda_compiler(), + # cutlass flags and gcc hash + _nvcc_compiler_options(), + # flags + _nvcc_host_compiler_options(), + # cutlass key + cutlass_key(), + # hack to deal with AOTI .o compilation + ] + ) + key, input_path = write(source_code, cls._SOURCE_CODE_SUFFIX, extra=extra) + return key, input_path + + @classmethod + def compile( + cls, source_code: str, dst_file_ext: str, extra_args: list[str] | None = None + ) -> tuple[str, str, str]: + """ + Compiles CUDA source_code into a file with dst_file_ext extension. + If dst_file_ext is "so", first compiles to ".o" and then links to ".so". + Returns a tuple of dst_file_path, hash_key, source_code_path + """ + if dst_file_ext == "so": + # Two-step compilation: first compile to .o, then link to .so + obj_path, _, _ = cls.compile(source_code, "o", extra_args) + key, input_path = cls.write(source_code, dst_file_ext) + src_files, operation_name = [obj_path], "Linking" + else: + # Regular compilation for non-.so files + key, input_path = cls.write(source_code, dst_file_ext) + src_files, operation_name = [input_path], "Compilation" + + key_with_ext = key + dst_file_ext + if key_with_ext not in cls.cache: + from torch.utils._filelock import FileLock + + lock_dir = get_lock_dir() + lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT) + with lock: + output_path = input_path[: -len(cls._SOURCE_CODE_SUFFIX)] + dst_file_ext + error_path = binary_error_path(output_path) + binary_remote_cache = cls.get_kernel_binary_remote_cache( + caching_enabled=config.cuda.use_binary_remote_cache + and not config.force_disable_caches, + caching_available=config.is_fbcode(), + ) + if binary_remote_cache is not None: + # The remote cache implementation will only download if the file does + # not 
already exist locally + binary_remote_cache.get(output_path, error_path) + + if os.path.exists(error_path): + with open(error_path, encoding="utf-8") as fh: + error_json = fh.read() + cmd_parts, error_output = json.loads(error_json) + if ( + binary_remote_cache is not None + and config.cuda.upload_to_binary_remote_cache + ): + # This ensures that a local error is uploaded to the remote cache, + # as we make no assumptions about the remote cache having the same + # information as the local cache + binary_remote_cache.put( + error_path, config.cuda.binary_remote_cache_force_write + ) + cls.cache[key_with_ext] = CUDACodeCache.CacheEntry( + input_path, output_path, error_json + ) + raise exc.CUDACompileError(cmd_parts, error_output) + if not os.path.exists(output_path): + cmd = cuda_compile_command( + src_files, output_path, dst_file_ext, extra_args + ) + with open(input_path, "a") as f: + f.write("\n") + f.write(f"// CUDA {operation_name} cmd\n// {cmd}\n") + start_time = time() + log.debug("CUDA %s: %s", operation_name, cmd) + cmd_parts = cmd.split(" ") + try: + if use_re_build(): + from triton.fb.re_build_helper import run_build_command + + run_build_command( + cmd_parts, + os.path.dirname(input_path), + os.path.basename(output_path), + ) + else: + subprocess.check_output( + cmd_parts, stderr=subprocess.STDOUT, env=os.environ + ) + except subprocess.CalledProcessError as error: + cls._record_cuda_compile_error( + error.output.decode("utf-8"), + key_with_ext, + cmd_parts, + input_path, + output_path, + binary_remote_cache, + ) + raise exc.CUDACompileError(cmd_parts, error.output) from error + except Exception as error: + if "COMPILE FAILED WITH" in str(error): + cls._record_cuda_compile_error( + str(error), + key_with_ext, + cmd_parts, + input_path, + output_path, + binary_remote_cache, + ) + raise exc.CUDACompileError(cmd_parts, str(error)) from error + raise error + end_time = time() + log_duration_msg = f"CUDA {operation_name} took {end_time - start_time} seconds. 
Command: {cmd}" + log.info(log_duration_msg) + + else: + log.debug( + "CUDA %s skipped: %s since output already exists", + operation_name, + output_path, + ) + # Upload to remote cache if enabled + if ( + binary_remote_cache is not None + and config.cuda.upload_to_binary_remote_cache + ): + # will log on errors, but not fail out + binary_remote_cache.put( + output_path, config.cuda.binary_remote_cache_force_write + ) + cls.cache[key_with_ext] = CUDACodeCache.CacheEntry( + input_path, output_path, None + ) + + cache_entry: CUDACodeCache.CacheEntry = cls.cache[key_with_ext] + if cache_entry.error_json is not None: + # Restore cached Exception and raise it as if we had compiled + cmd_parts, error_output = json.loads(cache_entry.error_json) + raise exc.CUDACompileError(cmd_parts, error_output.encode("utf-8")) + return (cls.cache[key_with_ext].output_path, key, input_path) + + @classmethod + def load(cls, source_code: str, dst_file_ext: str) -> tuple[DLLWrapper, str, str]: + """ + Compiles source code and loads the generated .so file. + Returns a tuple of DLLWrapper, hash_key, source_code_path + """ + + if dst_file_ext != "so": + raise RuntimeError( + f"Only support loading a .so file for now. " + f"Requested file extension: {dst_file_ext}. 
Source code: {source_code}" + ) + dst_file_path, hash_key, source_code_path = cls.compile( + source_code, dst_file_ext + ) + return (DLLWrapper(dst_file_path), hash_key, source_code_path) + + @classmethod + def _record_cuda_compile_error( + cls, + error_str: str, + key_with_ext: str, + cmd_parts: list[str], + input_path: str, + output_path: str, + # Any here, as the import and type will only work in fbcode + # TODO: Make the typing hint strong here + binary_remote_cache: Any = None, + ) -> None: + error_json = json.dumps([cmd_parts, error_str]) + cls.cache[key_with_ext] = CUDACodeCache.CacheEntry( + input_path, output_path, error_json + ) + error_path = binary_error_path(output_path) + with open(error_path, "w", encoding="utf-8") as fh: + fh.write(error_json) + + # Upload to remote cache directly from memory if enabled + if ( + binary_remote_cache is not None + and config.cuda.upload_to_binary_remote_cache + ): + binary_remote_cache.put( + error_path, config.cuda.binary_remote_cache_force_write + ) + + +@clear_on_fresh_cache +class ROCmCodeCache: + @dataclasses.dataclass + class CacheEntry: + input_path: str + output_path: str + + cache: dict[str, CacheEntry] = {} + aot_kernels_o: list[str] = [] + _SOURCE_CODE_SUFFIX = "cpp" + _logged_compiler_version = False + + @staticmethod + def cache_clear() -> None: + ROCmCodeCache.cache.clear() + ROCmCodeCache.aot_kernels_o.clear() + + @classmethod + def write(cls, source_code: str, dst_file_ext: str) -> tuple[str, str]: + """ + Writes source code into a file with dst_file_ext as the file extension. + Returns the hash key of source code, and the path to the file. 
+ """ + + cuda_command = repr( + rocm_compile_command(["dummy_input"], "dummy_output", dst_file_ext) + ) + key, input_path = write( + source_code, cls._SOURCE_CODE_SUFFIX, extra=cuda_command + ) + return key, input_path + + @classmethod + def compile( + cls, source_code: str, dst_file_ext: str, extra_args: list[str] | None = None + ) -> tuple[str, str, str]: + """ + Compiles source_code into a file with dst_file_ext extension, + using the compile command specific for the ROCm platform. + Returns a tuple of dst_file_path, hash_key, source_code_path + """ + if not cls._logged_compiler_version: + cls._logged_compiler_version = True + log.debug(get_compiler_version_info(str(rocm_compiler()))) + + key, input_path = cls.write(source_code, dst_file_ext) + if key not in cls.cache: + from torch.utils._filelock import FileLock + + lock_dir = get_lock_dir() + lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT) + with lock: + output_path = input_path[: -len(cls._SOURCE_CODE_SUFFIX)] + dst_file_ext + if not os.path.exists(output_path): + cmd = rocm_compile_command( + [input_path], output_path, dst_file_ext, extra_args + ) + start_time = time() + cmd_parts = cmd.split(" ") + try: + output = subprocess.check_output( + cmd_parts, + stderr=subprocess.STDOUT, + text=True, + env=os.environ, + ) + log.debug("Compilation output: %s", output) + except subprocess.CalledProcessError as error: + raise exc.CUDACompileError(cmd_parts, error.output) from error + end_time = time() + log_duration_msg = f"Compilation took {end_time - start_time} seconds. 
Compile command: {cmd}" + log.info(log_duration_msg) + else: + log.debug( + "Skip compiling %s: output %s already exists", + input_path, + output_path, + ) + cls.cache[key] = ROCmCodeCache.CacheEntry(input_path, output_path) + + return (cls.cache[key].output_path, key, input_path) + + @classmethod + def load(cls, source_code: str, dst_file_ext: str) -> tuple[DLLWrapper, str, str]: + """ + Compiles source code and loads the generated .so file. + Returns a tuple of DLLWrapper, hash_key, source_code_path + """ + + if dst_file_ext != "so": + raise RuntimeError( + f"Only support loading a .so file for now. " + f"Requested file extension: {dst_file_ext}. Source code: {source_code}" + ) + dst_file_path, hash_key, source_code_path = cls.compile( + source_code, dst_file_ext + ) + return (DLLWrapper(dst_file_path), hash_key, source_code_path) + + +class CodeCacheFuture: + def result(self) -> Callable[..., Any]: + raise NotImplementedError + + +class LambdaFuture(CodeCacheFuture): + def __init__( + self, result_fn: Callable[..., Any], future: Future[Any] | None = None + ) -> None: + self.result_fn = result_fn + self.future = future + + def result(self) -> Callable[..., Any]: + return self.result_fn() + + +class StaticAutotunerFuture(CodeCacheFuture): + """ + A statically launchable CachingAutotuner, loaded from TritonBundler + """ + + def __init__(self, static_autotuner: CachingAutotuner) -> None: + # Pickled version of CachingAutotuner + self.static_autotuner = static_autotuner + # This needs to be set in AsyncCompile.triton, in case + # we need to reload the CachingAutotuner from its source code + # We don't store the source code on the CachingAutotuner itself + # since it can be very large. 
+ self.reload_kernel_from_src: Callable[[], Any] | None = None + + def result(self) -> CachingAutotuner: + assert self.reload_kernel_from_src is not None + with dynamo_timed("StaticAutotunerFuture.warm_precompile"): + self.static_autotuner.recheck_autotune_cache( + reload_kernel_from_src=self.reload_kernel_from_src + ) + self.static_autotuner.precompile( # type: ignore[union-attr] + warm_cache_only=False, + reload_kernel=self.reload_kernel_from_src, + static_triton_bundle_key=None, # no need to save again + ) + return self.static_autotuner diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/comm_analysis.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/comm_analysis.py new file mode 100644 index 0000000000000000000000000000000000000000..5b174414a67b67f09146b099e3262364a0bff94f --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/comm_analysis.py @@ -0,0 +1,501 @@ +import functools +import logging +import math +import operator +from enum import IntEnum +from typing import Any, Optional + +import sympy + +import torch +import torch.utils._pytree as pytree +from torch.fx.experimental.symbolic_shapes import hint_int +from torch.fx.operator_schemas import normalize_function + +from . 
import ir +from .utils import get_dtype_size, snode_args_kwargs, sympy_product +from .virtualized import V + + +log = logging.getLogger(__name__) + + +class NCCL_COLL(IntEnum): + ALL_REDUCE = 0 + ALL_GATHER = 1 + REDUCE_SCATTER = 2 + ALL_TO_ALL = 3 + UNSUPPORTED = 4 + + +class NVIDIA_GPU_TYPE(IntEnum): + VOLTA = 0 + AMPERE = 1 + HOPPER = 2 + + +@functools.lru_cache +def get_gpu_type() -> NVIDIA_GPU_TYPE: + gpu_info = torch.utils.collect_env.get_gpu_info(torch.utils.collect_env.run) or "" + if "V100" in gpu_info: + return NVIDIA_GPU_TYPE.VOLTA + elif "A100" in gpu_info: + return NVIDIA_GPU_TYPE.AMPERE + elif "H100" in gpu_info: + return NVIDIA_GPU_TYPE.HOPPER + else: + # for other gpu types, assume Ampere + return NVIDIA_GPU_TYPE.AMPERE + + +def get_collective_type_from_kernel_name(kernel_name: str) -> NCCL_COLL: + assert kernel_name is not None + if "all_reduce" in kernel_name: + return NCCL_COLL.ALL_REDUCE + elif "all_gather" in kernel_name: + return NCCL_COLL.ALL_GATHER + elif "reduce_scatter" in kernel_name: + return NCCL_COLL.REDUCE_SCATTER + elif any(comm in kernel_name for comm in ("all_to_all", "alltoall")): + return NCCL_COLL.ALL_TO_ALL + else: + return NCCL_COLL.UNSUPPORTED + + +def get_collective_type(node: ir.IRNode) -> NCCL_COLL: + if not isinstance(node, ir._CollectiveKernel): + raise ValueError(f"node is not a collective kernel: {node}") + + name = node.python_kernel_name + assert name is not None + return get_collective_type_from_kernel_name(name) + + +def get_ir_node_size_numel(size: torch.Size, fallback: int = 4096 * 4096) -> int: + numel = sympy_product(size) + if isinstance(numel, sympy.Integer): + return int(numel) + return V.graph.sizevars.size_hint(numel, fallback=fallback) + + +def get_fx_node_size_numel(size: torch.Size, fallback: int = 4096 * 4096) -> int: + numel = functools.reduce(operator.mul, size, 1) + result = hint_int(numel, fallback=fallback) + return result + + +def get_collective_input_size_bytes(node: ir.IRNode) -> int: + 
sz_bytes = 0 + for inp in node.inputs: # type: ignore[attr-defined] + numel = get_ir_node_size_numel(inp.layout.size) + sz_bytes += numel * get_dtype_size(inp.layout.dtype) + return sz_bytes + + +def get_collective_group_size(node: ir.IRNode) -> int: + if isinstance(node, ir._CollectiveKernel) and not isinstance(node, ir._WaitKernel): + from torch.distributed.distributed_c10d import _get_group_size_by_name + + return _get_group_size_by_name(node.constant_args[-1]) + else: + raise TypeError(f"Unsupported collective type: {node}") + + +#################################################################################################################### +# The following code and constants are adapted from https://github.com/NVIDIA/nccl/blob/master/src/graph/tuning.cc # +#################################################################################################################### + + +class NCCL_HW(IntEnum): + NVLINK = 0 + PCI = 1 + NET = 2 + + +class NCCL_ALGO(IntEnum): + TREE = 0 + RING = 1 + + +class NCCL_PROTO(IntEnum): + # The ordering and enum values here matches original in + # https://github.com/NVIDIA/nccl/blob/0b083e52096c387bad7a5c5c65b26a9dca54de8c/src/include/devcomm.h#L28 + # For difference between these protocols, see https://github.com/NVIDIA/nccl/issues/281#issuecomment-571816990 + LL = 0 # Low-latency + # LL128 = 1 # Low-latency 128-byte + # SIMPLE = 2 + + +# Latencies in us +# len(NCCL_ALGO) x len(NCCL_PROTO) +# NOTE: use array instead of tensor to prevent incompatibility with fake mode +baseLat = [ + # Tree + [ + 6.8, # LL + ], + # Ring + [ + 6.6, # LL + ], +] + +# Latencies in us +# len(NCCL_HW) x len(NCCL_ALGO) x len(NCCL_PROTO) +hwLat = [ + # NVLINK + [ + [0.6], # Tree (LL) + [0.6], # Ring (LL) + ], + # PCI + [ + [1.0], # Tree (LL) + [1.0], # Ring (LL) + ], + # NET + [ + [5.0], # Tree (LL) + [2.7], # Ring (LL) + ], +] + + +# LL128 max BW per channel +llMaxBws = [ + # Volta-N1/Intel-N2/Intel-N4 + [ + 39.0, + 39.0, + 20.4, + ], + # 
Ampere-N1/AMD-N2/AMD-N4 + [ + 87.7, + 22.5, # avg of ring & tree + 19.0, + ], + # Hopper-N1/AMD-N2/AMD-N4 + [ + 87.7, + 22.5, # avg of ring & tree + 19.0, + ], +] + + +def estimate_nccl_collective_runtime_nccl_estimator(snode) -> Optional[float]: # type: ignore[no-untyped-def] + kernel = snode.node + assert kernel is not None + py_kernel_name = getattr(kernel, "python_kernel_name", "") + pg_name = kernel.constant_args[-1] # type: ignore[attr-defined] + from torch.distributed.distributed_c10d import _resolve_process_group + + pg = _resolve_process_group(pg_name) + rank: int = torch.distributed.get_rank(pg) + # TODO(ivankobzarev): Figure out how we can use time estimations, + # without cuda allocations. + device = torch.device(f"cuda:{rank}") + + fn = eval(py_kernel_name) + args, kwargs = snode_args_kwargs(snode) + + # TODO(ivankobzarev): fix out variants snode_args_kwargs + if "all_gather_into_tensor_out" in py_kernel_name: + args = args[1:] + args[0] + + with torch.distributed._time_estimator(group=pg, device=device) as time_estimator: + w = fn(*args, **kwargs) + torch.ops._c10d_functional.wait_tensor.default(w) + + est_time_us = time_estimator.estimated_time + # -1000 constant is NCCL return in case of error during estimations. + # Observed it for all_to_all estimations. + if est_time_us < 0: + return None + est_time_ms = est_time_us / 1e3 + return est_time_ms + + +def estimate_nccl_collective_runtime_impl( + tensor_storage_size_bytes: int, group_size: int, coll: NCCL_COLL +) -> float: + """ + Returns estimated NCCL collective runtime in milliseconds (ms). + + The following heuristics are copied from https://github.com/NVIDIA/nccl/blob/master/src/graph/tuning.cc. + We aim to estimate the runtime as accurately as possible. + + Assumptions: + - only ring algorithm (NCCL_ALGO_RING) is used + - only Low-Latency protocol (NCCL_PROTO_LL) is used, i.e. 
Simple or LL128 is not used + - 8 gpus per node # TODO: Need to find a way to get accurate "gpus per node" and "# nodes" info. + - collective is one of: allreduce, reducescatter, allgather + """ + # Convert bytes to GB + tensor_storage_size_GB = tensor_storage_size_bytes / 1024 / 1024 / 1024 + + # Currently assumes each node has 8 gpus. And when >1 node is used, assumes each node uses all 8 gpus. + # TODO: Need to find a way to get accurate "gpus per node" and "# nodes" info. + num_gpus_per_node = 8 + nNodes = math.ceil(group_size / num_gpus_per_node) + nRanks = group_size # this is total # of gpus globally that participate in this collective op + + if nRanks <= 1: + return 0 + + # Assumes ring algorithm + nccl_algo = NCCL_ALGO.RING + nccl_proto = NCCL_PROTO.LL + + # =============== bandwidth computation =============== + # First compute bandwidth in GB/s; then at the end, convert it to GB/ns + + bwIntra = torch._inductor.config.intra_node_bw + bwInter = torch._inductor.config.inter_node_bw + + compCapIndex = get_gpu_type() + index2 = nNodes - 1 if nNodes <= 2 else 2 + # LL: for single node, we look at GPU type; for multi-node, we look at CPU type + index1 = compCapIndex if nNodes == 1 else 0 + llMaxBw = llMaxBws[index1][index2] + + # NOTE: each step of ring algorithm is synchronized, + # and is bottlenecked by the slowest link which is the inter-node interconnect. + # hence when nNodes >= 2, bw is inter-node bandwidth. + # NOTE: the original code in https://github.com/NVIDIA/nccl/blob/master/src/graph/tuning.cc + # have this as `if nNodes <= 2` which seems wrong. Corrected it here. 
+ bw = bwIntra if nNodes == 1 else bwInter + nChannels = 2 # Assume # channels is 2 + busBw = nChannels * bw + + # Various model refinements + busBw = min( + llMaxBw, + busBw + * (1.0 / 4.0 if (nNodes > 1 or coll == NCCL_COLL.ALL_REDUCE) else 1.0 / 3.0), + ) + + if coll == NCCL_COLL.ALL_REDUCE: + nsteps = 2 * (nRanks - 1) + elif coll == NCCL_COLL.ALL_TO_ALL: + nsteps = 2 * (nRanks - 1) + elif coll in (NCCL_COLL.REDUCE_SCATTER, NCCL_COLL.ALL_GATHER): + nsteps = nRanks - 1 + + # Convert bus BW to algorithm BW (tensor bytes / algoBW = actual execution time) + ratio = (1.0 * nRanks) / nsteps # type: ignore[possibly-undefined] + bandwidth = busBw * ratio + # Convert GB/s to GB/ns + bandwidth_GB_per_ns = bandwidth / 1e9 + + # =============== latency computation =============== + intraHw = NCCL_HW.NVLINK + + if coll == NCCL_COLL.ALL_REDUCE: + if nNodes > 1: + nInterSteps = 2 * nNodes + else: + nInterSteps = 0 + elif coll in (NCCL_COLL.REDUCE_SCATTER, NCCL_COLL.ALL_GATHER, NCCL_COLL.ALL_TO_ALL): + nInterSteps = nNodes - 1 + + # First compute latency in us; then at the end, convert it to ns + latency = baseLat[nccl_algo][nccl_proto] + intraLat = hwLat[intraHw][nccl_algo][nccl_proto] + interLat = hwLat[NCCL_HW.NET][nccl_algo][nccl_proto] + + # Inter-node rings still have to launch nsteps * net overhead. 
+ netOverhead = 0.0 + if nNodes > 1: + netOverhead = 1.0 # getNetOverhead(comm); + intraLat = max(intraLat, netOverhead) + latency += (nsteps - nInterSteps) * intraLat + nInterSteps * interLat # type: ignore[possibly-undefined] + # Convert us to ns + latency_ns = latency * 1e3 + + # =============== final result =============== + transport_ns = tensor_storage_size_GB / bandwidth_GB_per_ns + ns = transport_ns + latency_ns + ms = ns / 1e6 + return ms + + +################################################################################################################ +# The above code and constants are adapted from https://github.com/NVIDIA/nccl/blob/master/src/graph/tuning.cc # +################################################################################################################ + + +def estimate_nccl_collective_runtime(node: ir.IRNode) -> float: + """ + Returns estimated NCCL collective runtime in nanoseconds (ms). + + The following heuristics are copied from https://github.com/NVIDIA/nccl/blob/master/src/graph/tuning.cc. + We aim to estimate the runtime as accurately as possible. + + Assumptions: + - only ring algorithm (NCCL_ALGO_RING) is used + - only Low-Latency protocol (NCCL_PROTO_LL) is used, i.e. Simple or LL128 is not used + - 8 gpus per node # TODO: Need to find a way to get accurate "gpus per node" and "# nodes" info. 
+ - collective is one of: allreduce, reducescatter, allgather + """ + tensor_storage_size_bytes = get_collective_input_size_bytes(node) + group_size = get_collective_group_size(node) + coll = get_collective_type(node) + return estimate_nccl_collective_runtime_impl( + tensor_storage_size_bytes, group_size, coll + ) + + +def estimate_fx_collective_size(fx_node: torch.fx.Node) -> int: + """Estimate the size of a collective operation in bytes, including inputs and outputs.""" + input_bytes = None + + args, kwargs = fx_node.args, fx_node.kwargs + kwargs = dict(kwargs) + + # dont double count pre-allocated buffer passed in + kwargs.pop("out", None) + + def tensor_bytes(t: torch.Tensor) -> int: + return get_fx_node_size_numel(t.size()) * get_dtype_size(t.dtype) + + def add_inp_bytes(inp: torch.fx.Node): + inp_val = inp.meta.get("val", None) + if not isinstance(inp_val, torch.Tensor): + return + + nonlocal input_bytes + if input_bytes is None: + input_bytes = 0 + input_bytes += tensor_bytes(inp_val) + + pytree.tree_map_only( + torch.fx.Node, + add_inp_bytes, + (args, kwargs), + ) + + output_val = fx_node.meta.get("val", None) + + if input_bytes is None or not isinstance(output_val, torch.Tensor): + return 0 + + output_bytes = tensor_bytes(output_val) + + return input_bytes + output_bytes + + +def estimate_fx_collective_memory_footprint(fx_node: torch.fx.Node) -> int: + """Estimate the memory footprint of a collective operation in bytes. + + This returns the total bytes that need to be live concurrently in memory. + For all_reduce, we divide by 2 since it can be done in-place. 
+ """ + from torch._inductor.fx_passes.bucketing import ( + is_all_reduce_tensor as is_all_reduce, + ) + + size = estimate_fx_collective_size(fx_node) + return size if not is_all_reduce(fx_node) else size // 2 + + +def estimate_nccl_collective_runtime_from_fx_node( + fx_node: torch.fx.Node, + override_size: Optional[int] = None, + use_nccl_estimator: bool = True, +) -> float: + """ + Returns estimated NCCL collective runtime in nanoseconds (ms). + + The following heuristics are copied from https://github.com/NVIDIA/nccl/blob/master/src/graph/tuning.cc. + We aim to estimate the runtime as accurately as possible. + + Assumptions: + - only ring algorithm (NCCL_ALGO_RING) is used + - only Low-Latency protocol (NCCL_PROTO_LL) is used, i.e. Simple or LL128 is not used + - 8 gpus per node # TODO: Need to find a way to get accurate "gpus per node" and "# nodes" info. + - collective is one of: allreduce, reducescatter, allgather + """ + from torch.distributed.distributed_c10d import _get_group_size_by_name + + if override_size is None: + tensor_storage_size_bytes = estimate_fx_collective_size(fx_node) + else: + tensor_storage_size_bytes = override_size + + assert not isinstance(fx_node.target, str) + opt_args_kwargs = normalize_function( + fx_node.target, + args=fx_node.args, + kwargs=fx_node.kwargs, + normalize_to_only_use_kwargs=True, + ) + assert opt_args_kwargs is not None + args, kwargs = opt_args_kwargs + + group_name = kwargs["group_name"] + group_size = _get_group_size_by_name(group_name) + assert isinstance(fx_node.target, torch._ops.OpOverload) + coll = get_collective_type_from_kernel_name(fx_node.target.name()) + + def _nccl_estimate() -> Optional[float]: + # TODO: Refactor with estimate_nccl_collective_runtime_nccl_estimator + from torch.distributed.distributed_c10d import ( + _get_pg_default_device, + _resolve_process_group, + ) + + pg = _resolve_process_group(group_name) + if torch.distributed.distributed_c10d.get_backend(pg) == "fake": + # nccl estimator 
requires real process group + return None + + device = _get_pg_default_device(pg) + backend = pg._get_backend(device) + if not backend.supports_time_estimate: + return None + + flat_args, flat_args_pytree_spec = pytree.tree_flatten((args, kwargs)) + + def _tensor(size, dtype, device) -> torch.Tensor: # type: ignore[no-untyped-def] + return torch.empty( + size if override_size is None else [override_size], + dtype=dtype, + device=device, + ) + + def try_size_hint(s: sympy.Expr) -> int: + return V.graph.sizevars.size_hint(s, fallback=0) + + def to_real_tensor(e: Any) -> Any: + if isinstance(e, torch.fx.Node): + return to_real_tensor(e.meta["val"]) + if isinstance(e, torch.Tensor): + return _tensor([get_fx_node_size_numel(e.size())], e.dtype, e.device) + return e + + flat_args = [to_real_tensor(a) for a in flat_args] + real_args, real_kwargs = pytree.tree_unflatten(flat_args, flat_args_pytree_spec) + + fn = fx_node.target + assert isinstance(fn, torch._ops.OpOverload) + with torch.distributed._time_estimator(group=pg) as time_estimator: + w = fn(*real_args, **real_kwargs) + torch.ops._c10d_functional.wait_tensor.default(w) + est_time_us = time_estimator.estimated_time + # -1000 constant is NCCL return in case of error during estimations. + # Observed it for all_to_all estimations. 
+ if est_time_us < 0: + return None + est_time_ms = est_time_us / 1e3 + return est_time_ms + + if use_nccl_estimator: + est_time_ms = _nccl_estimate() + if est_time_ms is not None: + return est_time_ms + + return estimate_nccl_collective_runtime_impl( + tensor_storage_size_bytes, group_size, coll + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/comm_lowering.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/comm_lowering.py new file mode 100644 index 0000000000000000000000000000000000000000..c9a8460b3c048b6d9d4e51178079c1c4498a627b --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/comm_lowering.py @@ -0,0 +1,393 @@ +# mypy: allow-untyped-defs +import logging + +import torch +import torch.utils._pytree as pytree +from torch._inductor.utils import is_symbolic +from torch.utils._ordered_set import OrderedSet + +from . import config, ir +from .virtualized import V + + +log = logging.getLogger(__name__) + + +# NOTE [lowering-time collective optimization] +# +# In collective communication libraries such as NCCL, every rank maintains +# communication buffers that are remotely accessible by some peers. Depending +# on the underlying transport, remote accessibility may be established via +# mechanisms such as ib_reg_mr, CUDA P2P, or CUDA multicast. Typically, these +# buffers are private to the communication library by default, and +# communication ops copy user data in and out of these buffers. +# +# To prevent these copies, an optimization commonly known as "user buffer +# registration" can be employed. This allows direct establishment of remote +# accessibility on user buffers, eliminating the need for copying. However, +# this optimization introduces stringent usage requirements, which are +# typically hard to satisfy without being intrusive to the user code: +# +# - Establishing remote accessibility is expensive and often done ahead of +# time. 
In such implementations, all ranks must agree on the set of allocations +# used for every collective op. Failing to meet this requirement can +# lead to runtime errors or even silent correctness issues. +# - Even if the collective communication library supports gracefully falling +# back to "unregistered" implementations, the fallback mechanism would nullify +# the optimization. +# - Some communication mechanisms impose stricter requirements than others. For +# example, CUDA's multicast + multi-mem instructions require all ranks to agree +# not only on the allocations used for every collective but also on the offsets +# within these allocations. +# +# To support all different mechanisms with optimal results, we aim to satisfy +# the strictest requirement for this family of optimizations - we ensures that +# every collective op invocation is guaranteed to operate on the same +# allocation, at the same offset, in every iteration. +# +# For eligible collective ops, we identify communication buffers at lowering +# time and optionally choose to lower the op to a different kernel +# (communication libraries like NCCL handle both registered and non-registered +# buffers transparently within the same op, though some may require different +# ops for different cases). Later, the codegen will perform "persistent +# allocation" to satisfy the aforementioned constraints, and optionally, +# perform buffer planning to optimize overall memory usage. +def can_realize_as_comm_buffer( + x: ir.TensorBox, comm_buffer_type: ir.CommBufferType +) -> bool: + """ + Check if an input can be realized as a comm buffer of the specified + `comm_buffer_type`. 
+ """ + data = _get_data(x) + + if isinstance(data, ir.Loops): + return True + + layout = data.get_output_spec() + if isinstance(layout, ir.CommBufferLayout): + return True + + if isinstance(layout, ir.FlexibleLayout) and not is_symbolic(data.get_numel()): + return True + + return False + + +def realize_as_comm_buffer( + x: ir.TensorBox, + comm_buffer_type: ir.CommBufferType, + group_name: "torch.distributed.distributed_c10d.GroupName", +) -> None: + """ + Realize an input as a comm buffer of the specified `comm_buffer_type`. + + Specifically, this realizes the underlying buffer if it's still unrealized + and changes the layout of the buffer to `ir.CommBufferLayout`. + """ + x.realize() + buffer = _get_data(x) + assert isinstance(buffer, ir.Buffer) + + layout = buffer.get_output_spec() + if isinstance(layout, ir.CommBufferLayout): + return + + if not isinstance(layout, ir.FlexibleLayout): + raise AssertionError( + "A buffer can only be realized as a comm buffer if it " + f"has `FlexibleLayout` (got {layout})." + ) + + if is_symbolic(buffer.get_numel()): + raise AssertionError( + "A buffer with symbolic shape cannot be converted to " + f"a comm buffer (got {layout})." + ) + + buffer.layout = ir.CommBufferLayout( + layout=layout, + comm_buffer_type=comm_buffer_type, + group_name=group_name, + ) + + +def _get_data(x: ir.TensorBox) -> ir.IRNode: + if isinstance(x.data, ir.BaseView): + # TensorBox -> *View -> StorageBox -> IRNode + node = x.data.unwrap_view() + assert isinstance(node, (ir.BaseView, ir.MutableBox)) + return node.data + elif isinstance(x.data, ir.StorageBox): + # TensorBox -> StorageBox -> IRNode + return x.data.data + else: + raise AssertionError( + "Expect the data attr of a `TensorBox` to be either " + f"an `ir.BaseView` or `ir.StorageBox` (got {x.data})." 
+ ) + + +_bufs_to_skip_wait = OrderedSet[tuple[int, str]]() + + +def mark_as_skip_wait(x: ir.IRNode) -> None: + """ + If a non-blocking collective is lowered as a blocking collective, the wait + node in the original graph becomes useless and we can skip the lowering it. + """ + _bufs_to_skip_wait.add((id(V.graph), x.get_name())) + + +def should_skip_wait(x: ir.IRNode) -> bool: + return (id(V.graph), x.get_name()) in _bufs_to_skip_wait + + +def _should_lower_as_one_shot_all_reduce( + inp: ir.TensorBox, + reduce_op: str, + group_name: "torch.distributed.distributed_c10d.GroupName", +): + from torch.distributed._symmetric_memory import is_symm_mem_enabled_for_group + + inp_size = inp.get_numel() * inp.get_dtype().itemsize + return ( + config._collective.auto_select + and is_symm_mem_enabled_for_group(group_name) + and can_realize_as_comm_buffer(inp, ir.CommBufferType.SYMM_MEM) + and reduce_op == "sum" + and inp_size <= config._collective.one_shot_all_reduce_threshold_bytes + ) + + +def _one_shot_all_reduce(inp: ir.TensorBox, reduce_op, group_name): + realize_as_comm_buffer(inp, ir.CommBufferType.SYMM_MEM, group_name) + return pytree.tree_map( + ir.TensorBox.create, + ir.FallbackKernel.create( + torch.ops.symm_mem.one_shot_all_reduce.default, + inp, + reduce_op, + group_name, + ), + ) + + +def register_comm_lowerings(): + """ + Register lowerings for the comm subsystem. 
+ """ + try: + torch.ops._c10d_functional.all_reduce + except AttributeError: + log.info( + "Inductor support for distributed collectives depends on building " + "torch.distributed" + ) + return + + from .lowering import ( + add_layout_constraint, + clone, + constrain_to_fx_strides, + copy_, + register_lowering, + ) + + def register_comm_lowering(fn): + add_layout_constraint(fn, constrain_to_fx_strides) + return register_lowering(fn) + + c10d = torch.ops._c10d_functional + + @register_comm_lowering(c10d.all_reduce) # type: ignore[misc] + def _all_reduce( + inp: ir.TensorBox, + reduce_op: str, + group_name: "torch.distributed.distributed_c10d.GroupName", + ) -> ir.TensorBox: + if _should_lower_as_one_shot_all_reduce(inp, reduce_op, group_name): + return _one_shot_all_reduce(inp, reduce_op, group_name) + + # Lower as c10d.all_reduce_ + inp = clone(inp) + if config.reorder_for_compute_comm_overlap: + # The horizontal fusion of this clone often severely delays the + # scheduling of the all_reduce_ node. Horizontally fusing this + # clone can almost never out-perform scheduling the all_reduce_ + # earlier. Also in most cases, this clone is eliminated via + # in-place reuse. Therefore, we tell the scheduler to not fuse it. + inp.realize() + V.graph.no_fuse_buffer_names.add(inp.get_name()) + # pyrefly: ignore [bad-assignment] + inp = ir.ExternKernel.require_contiguous(inp) + # Because we are lowering as inplace c10d.all_reduce_, we should generate + # _AllReduce_Kernel instead of _AllReduceKernel. 
+ ir._AllReduce_Kernel.create_inplace( + c10d.all_reduce_.default, + inp, # type: ignore[arg-type] + reduce_op, + group_name, # type: ignore[arg-type] + ) + return inp # type: ignore[return-value] + + @register_comm_lowering(c10d.all_reduce_) # type: ignore[misc] + def _all_reduce_( + inp: ir.TensorBox, + reduce_op: str, + group_name: "torch.distributed.distributed_c10d.GroupName", + ) -> ir.TensorBox: + if _should_lower_as_one_shot_all_reduce(inp, reduce_op, group_name): + ret = copy_( + inp, + _one_shot_all_reduce(inp, reduce_op, group_name), + ) + mark_as_skip_wait(ret) + return inp + + # Lower as c10d.all_reduce_ + # pyrefly: ignore [bad-assignment] + inp = ir.ExternKernel.require_contiguous(inp) + ir._AllReduce_Kernel.create_inplace( + c10d.all_reduce_.default, + inp, # type: ignore[arg-type] + reduce_op, + group_name, # type: ignore[arg-type] + ) + return inp # type: ignore[return-value] + + @register_comm_lowering(c10d.all_reduce_coalesced) + def _all_reduce_coalesced(inputs, reduce_op, group_name): + inputs = [clone(inp) for inp in inputs] + ir._CollectiveKernel.create_inplace( + c10d.all_reduce_coalesced_.default, + inputs, + reduce_op, + group_name, + ) + return inputs + + @register_comm_lowering(c10d.all_reduce_coalesced_) + def _all_reduce_coalesced_(inputs, reduce_op, group_name): + ir._CollectiveKernel.create_inplace( + c10d.all_reduce_coalesced_.default, + inputs, + reduce_op, + group_name, + ) + return inputs + + def _create_out_of_place(kernel, inputs, *args) -> ir.IRNode: + node = ir._CollectiveKernel.create_out_of_place(kernel, inputs, *args) + assert isinstance(node, ir.IRNode) + return ir.TensorBox.create(node) + + @register_comm_lowering(c10d.all_gather_into_tensor) + def _all_gather_into_tensor(inp, group_size, group_name): + return _create_out_of_place( + c10d.all_gather_into_tensor.default, + inp, + group_size, + group_name, + ) + + @register_comm_lowering(c10d.all_gather_into_tensor_coalesced) + def 
_all_gather_into_tensor_coalesced(inputs, group_size, group_name): + return pytree.tree_map( + ir.TensorBox.create, + ir._CollectiveKernel.create_out_of_place( + c10d.all_gather_into_tensor_coalesced.default, + inputs, + group_size, + group_name, + ), + ) + + @register_comm_lowering(c10d.all_gather_into_tensor_out) + def _all_gather_into_tensor_out(inp, group_size, group_name, *, out): + ir._CollectiveKernel.create_inplace( + c10d.all_gather_into_tensor_out.default, + inp, + group_size, + group_name, + out=out, + ) + return out + + @register_comm_lowering(c10d.reduce_scatter_tensor) + def _reduce_scatter_tensor(inp, reduce_op, group_size, group_name): + return _create_out_of_place( + c10d.reduce_scatter_tensor.default, + inp, + reduce_op, + group_size, + group_name, + ) + + @register_comm_lowering(c10d.reduce_scatter_tensor_out) + def _reduce_scatter_tensor_out(inp, reduce_op, group_size, group_name, *, out): + ir._CollectiveKernel.create_inplace( + c10d.reduce_scatter_tensor_out.default, + inp, + reduce_op, + group_size, + group_name, + out=out, + ) + return out + + @register_comm_lowering(c10d.reduce_scatter_tensor_coalesced) + def _reduce_scatter_tensor_coalesced(inputs, reduce_op, group_size, group_name): + return pytree.tree_map( + ir.TensorBox.create, + ir._CollectiveKernel.create_out_of_place( + c10d.reduce_scatter_tensor_coalesced.default, + inputs, + reduce_op, + group_size, + group_name, + ), + ) + + @register_comm_lowering(c10d.all_to_all_single) + def _all_to_all_single(inp, output_split_sizes, input_split_sizes, group_name): + return _create_out_of_place( + c10d.all_to_all_single.default, + inp, + output_split_sizes, + input_split_sizes, + group_name, + ) + + @register_comm_lowering(c10d.broadcast) + def _broadcast(inp, src, group_name): + inp = clone(inp) + ir._CollectiveKernel.create_inplace( + c10d.broadcast_.default, inp, src, group_name + ) + return inp + + @register_comm_lowering(c10d.broadcast_) + def _broadcast_(inp, src, group_name): + 
ir._CollectiveKernel.create_inplace( + c10d.broadcast_.default, inp, src, group_name + ) + return inp + + @register_comm_lowering(torch.ops._dtensor.shard_dim_alltoall) + def _shard_dim_alltoall(inp, gather_dim, shard_dim, group_name): + return _create_out_of_place( + torch.ops._dtensor.shard_dim_alltoall.default, + inp, + gather_dim, + shard_dim, + group_name, + ) + + @register_comm_lowering(c10d.wait_tensor) + def _wait_tensor(inp): + if should_skip_wait(inp): + return inp + + ir._WaitKernel.create_wait(c10d.wait_tensor.default, inp) + return inp diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/comms.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/comms.py new file mode 100644 index 0000000000000000000000000000000000000000..ba2571f266244c75449e1869d709282053c95820 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/comms.py @@ -0,0 +1,2652 @@ +# mypy: allow-untyped-defs +# pyre-strict +from __future__ import annotations + +import heapq +import importlib +import itertools +import logging +import operator +import sys +import time +from collections import defaultdict +from dataclasses import dataclass +from typing import Any, Optional, TYPE_CHECKING, Union + +import torch +from torch._logging import trace_structured +from torch.multiprocessing.reductions import StorageWeakRef +from torch.utils._ordered_set import OrderedSet + +from . 
def align_runtime_estimations_across_all_distributed_ranks(
    snodes: list[BaseSchedulerNode],
):
    """
    Make every rank use the same runtime estimate for each scheduler node.

    Each rank contributes its local `get_estimated_runtime()` values (in
    `snodes` order) through `all_gather_object` on the default process group.
    The per-position median over ranks is then stored on each node as
    `override_estimated_runtime`.
    """
    local_estimates = {snode: snode.get_estimated_runtime() for snode in snodes}

    import torch.distributed as dist
    from torch.distributed.distributed_c10d import _get_default_group

    num_ranks = dist.get_world_size()
    default_pg = _get_default_group()
    # One slot per rank, filled in by all_gather_object.
    per_rank_estimates: list[list[float]] = [[] for _ in range(num_ranks)]
    dist.all_gather_object(
        per_rank_estimates, list(local_estimates.values()), default_pg
    )
    # dim=0 is the rank dimension, so the median is taken per node position.
    medians = torch.median(torch.tensor(per_rank_estimates), dim=0).values.tolist()
    for position, snode in enumerate(snodes):
        snode.override_estimated_runtime = medians[position]
+ """ + return _schedule_for_comm( + snodes, raise_comms=True, sink_waits=False, reorder_for_overlap=False + ) + + +def reorder_compute_for_overlap( + snodes: list[BaseSchedulerNode], +) -> list[BaseSchedulerNode]: + """ + This achieves the following overall scheduling procedure: + Step 1: Given that we've currently scheduled comm N, we now schedule all compute nodes + that are required for comm N + 1 but do not depend on comm N, to run at the same time with comm N. + Step 2: If all those compute nodes are sufficient to overlap comm N, we're done. + Otherwise, we now need to look elsewhere to find compute that overlaps with comm N. + We prioritize compute nodes that are needed sooner. + Step 3: We schedule the compute nodes dependent on comm N and required for comm N + 1. + Step 4: We schedule comm N + 1. + Repeat this for subsequent comm nodes. + """ + return _schedule_for_comm( + snodes, raise_comms=True, sink_waits=True, reorder_for_overlap=True + ) + + +def reorder_communication_preserving_peak_memory( + snodes: list[BaseSchedulerNode], +) -> list[BaseSchedulerNode]: + """ + Reorders communication ops relative to computation ops to improve communication-compute overlapping and hide comm + latency. Stops moving a particular op if it reaches a point that would have increased the peak memory footprint. + + Currently, follows these heuristics (subject to change or tune): + - never reorders collectives relative to one another, for SPMD safety + - has an option for per-collective prefetch limit, but does not enable it by default + - limits the total number of reorder steps to some factor of the graph size to prevent worst-case quadratic + performance + + Prerequisite: sink_comms_and_waits - ensure comm and wait nodes are scheduled as late as possible, respecting data + dependencies. That allows reorder_communication_preserving_peak_memory to take a best case peak-memory snapshot, + and then monotonically improve latency by moving collectives backward in time. 
@dataclass
class ReorderInfo:
    """
    Per-node debug record describing what happened to a scheduler node
    during reordering.
    """

    # Why the node stopped moving; the string "None" when nothing was recorded.
    limiting_factor: str = "None"
    # Number of swaps performed for this node.
    moves: int = 0
    # Number of neighbouring nodes that were grouped with this node.
    grouped: int = 0
    # Names of the nodes in the group, as a printable string.
    grouped_info: str = ""
    # Estimated communication / overlapping compute time; -1.0 means "not set".
    comm_time: float = -1.0
    comp_time: float = -1.0
    # Exposed communication time before and after reordering; -1.0 means "not set".
    initial_exposed: float = -1.0
    final_exposed: float = -1.0
    # Printable trace of which nodes were counted as overlap.
    overlap_info: str = "None"

    @property
    def improvement(self):
        # Positive when reordering reduced the exposed communication time.
        exposed_before, exposed_after = self.initial_exposed, self.final_exposed
        return exposed_before - exposed_after
def wait_exposed_communication_time(
    snodes_to_wait: list[BaseSchedulerNode], runtimes: dict[BaseSchedulerNode, float]
) -> tuple[float, float, str]:
    """
    Estimate the communication and overlapping compute time for a wait node.

    The Wait node must be the last element of `snodes_to_wait`. The list is
    walked backwards from the node just before the wait until the collective
    that the wait depends on is found. Runtimes of the nodes visited on the way
    are accumulated as compute that overlaps with that collective.

    Compute stops counting as overlap when another collective sits between it
    and our collective:
    - a synchronous collective, or
    - an async collective whose own wait was already seen during the walk.
    In both cases the accumulated compute time is reset to 0 and the
    collective itself is not counted as compute.

    Multiple process groups are not modeled so far.

    Args:
        snodes_to_wait: nodes in schedule order, ending with the wait node
            (at least two nodes)
        runtimes: estimated runtime for every scheduler node

    Returns:
        Tuple of (comm_time, comp_time, overlap_info). comm_time stays 0.0 if
        no corresponding collective is found in `snodes_to_wait`.
    """
    wait_snode = snodes_to_wait[-1]
    assert is_wait(wait_snode.node)
    assert len(snodes_to_wait) > 1
    comm_time = 0.0
    comp_time = 0.0
    overlap_info = ""
    waits_found: list[BaseSchedulerNode] = []

    def accumulate_time(_snode: BaseSchedulerNode) -> None:
        nonlocal comp_time
        comp_time += runtimes[_snode]

    # Walk backwards, starting from the node right before the wait.
    for c in reversed(snodes_to_wait[:-1]):
        if contains_wait(c):
            waits_found.append(c)
        if contains_collective(c):
            if is_corresponding_collective_wait(c, wait_snode):
                comm_time = runtimes[c]
                overlap_info += f"->C[{c.get_name()}]"
                break

            # Sync Collective, or an async Collective that is already waited on
            # before our Wait: compute seen so far cannot overlap with our
            # Collective. The reset has to skip the whole iteration; a
            # `continue` inside a loop over waits_found would only continue
            # that inner loop and still count this node as compute.
            if not contains_async_collective(c) or any(
                is_corresponding_collective_wait(c, w) for w in waits_found
            ):
                comp_time = 0.0
                continue

        comp_time_before = comp_time
        _temp_group_visit_leaves(c, accumulate_time)
        overlap_info += f"+{c.get_name()}[{comp_time - comp_time_before}]"

    return comm_time, comp_time, overlap_info
def coll_exposed_communication_time(
    snodes: list[BaseSchedulerNode],
    runtimes: dict[BaseSchedulerNode, float],
) -> tuple[float, float, str]:
    """
    Estimate the communication and overlapping compute time for a collective.

    The Collective node must be the first element of `snodes`. The following
    nodes are scanned in order and their runtimes are accumulated as compute
    that overlaps with the collective, until one of these is reached:
    - a node with a (non-fake) unmet dependency on an output of the collective
    - a synchronous collective
    - a wait that corresponds to another async collective seen during the scan;
      compute after such a wait does not overlap with the original collective

    Async collectives met during the scan are remembered and not counted as
    compute.

    Multiple process groups are not modeled so far.

    Args:
        snodes: nodes in schedule order, starting with the collective
        runtimes: estimated runtime for every scheduler node

    Returns:
        Tuple of (comm_time, comp_time, overlap_info).
    """
    collective_snode = snodes[0]
    comm_time = runtimes[collective_snode]
    comp_time = 0.0
    collective_outs: OrderedSet[str] = OrderedSet(
        o.get_name() for o in collective_snode.get_outputs()
    )
    overlap_info = ""
    collectives_found: list[BaseSchedulerNode] = []

    def accumulate_time(_snode: BaseSchedulerNode) -> None:
        nonlocal comp_time
        comp_time += runtimes[_snode]

    for snode in snodes[1:]:
        # We may have some ops without Wait,
        # e.g. DTensor torch.ops._dtensor.shard_dim_alltoall
        unmet_deps = OrderedSet(
            d.name for d in snode.unmet_dependencies if not _is_fake_dep(d)
        )

        if unmet_deps & collective_outs:
            overlap_info += f"->W[{snode.get_name()}]"
            break

        if contains_collective(snode):
            if not contains_async_collective(snode):
                break
            collectives_found.append(snode)
            continue

        # A wait for one of the collectives found after ours.
        # The check must use each found collective; the wait of the original
        # collective is already handled by the dependency check above.
        if contains_wait(snode) and any(
            is_corresponding_collective_wait(_coll, snode)
            for _coll in collectives_found
        ):
            # Any compute after not overlapping original Collective
            break

        comp_time_before = comp_time
        _temp_group_visit_leaves(snode, accumulate_time)
        overlap_info += f"+{snode.get_name()}[{comp_time - comp_time_before}]"

    return comm_time, comp_time, overlap_info
def is_async_collective(snode):
    """
    Tell whether a node that is classified as a Collective is asynchronous.

    Used as a filter for contains_collective. The op
    torch.ops._dtensor.shard_dim_alltoall contains both Collective and Wait
    inside, so it is treated as synchronous; everything else passes.
    """
    kernel_name = getattr(snode.node, "python_kernel_name", None)
    if not kernel_name:
        # No kernel name to inspect: nothing marks the node as synchronous.
        return True
    return "torch.ops._dtensor.shard_dim_alltoall.default" not in kernel_name
+ + Args: + head: Starting node of the segment + tail: Ending node of the segment (inclusive) + next_dict: Dictionary mapping each node to its next node + + Returns: + List of nodes from head to tail (inclusive) + """ + ret = [] + n = head + while True: + if n is not None: + ret.append(n) + if n == tail: + break + n = next_dict[n] # type: ignore[index] + return ret + + +def _perform_double_linked_list_swap( + candidate: BaseSchedulerNode, + group_head: BaseSchedulerNode, + group_tail: BaseSchedulerNode, + prev_dict: dict[BaseSchedulerNode, Optional[BaseSchedulerNode]], + next_dict: dict[BaseSchedulerNode, Optional[BaseSchedulerNode]], + head: BaseSchedulerNode, +) -> BaseSchedulerNode: + """ + Swap positions of candidate and group in doubly-linked list. + + Transforms: + candidate_prev -> candidate -> group_head...group_tail -> group_tail_next + Into: + candidate_prev -> group_head...group_tail -> candidate -> group_tail_next + + Args: + candidate: Node to swap with group + group_head: First node of group + group_tail: Last node of group + prev_dict: Dictionary mapping nodes to their previous nodes + next_dict: Dictionary mapping nodes to their next nodes + head: Current head of the linked list + + Returns: + New head of the linked list (may change if candidate was the head) + """ + # 0: Update candidate's previous node + candidate_prev = prev_dict[candidate] + if candidate_prev: + next_dict[candidate_prev] = group_head + prev_dict[group_head] = candidate_prev + + # 2: Update group_tail's next node + group_tail_next = next_dict[group_tail] + if group_tail_next: + prev_dict[group_tail_next] = candidate + next_dict[candidate] = group_tail_next + + # 1: Link group_tail to candidate + prev_dict[candidate] = group_tail + next_dict[group_tail] = candidate + + # Update head if candidate was the head + if head == candidate: + return group_head + return head + + +def _calculate_potential_peak_memory_reorder( + candidate: BaseSchedulerNode, + gns: list[BaseSchedulerNode], + 
group_tail: BaseSchedulerNode, + group_peak_memory: int, + candidate_delta_mem: int, + candidate_allocfree: SNodeMemory, + group_n_to_bufs_after_swap_dealloc_by_candidate: dict, + curr_memory: dict, +) -> tuple[int, dict[BaseSchedulerNode, int]]: + """ + Calculate potential peak memory after swapping candidate with group (reorder version). + + Computes new memory levels for all affected nodes and returns the potential + peak memory along with cached post-allocation memory values for each node. + + Args: + candidate: Node being moved + gns: Group nodes + group_tail: Last node of group + group_peak_memory: Current peak memory within the group + candidate_delta_mem: Net memory change from candidate (alloc - free) + candidate_allocfree: Candidate's allocation/free info + group_n_to_bufs_after_swap_dealloc_by_candidate: Buffers whose deallocation moves to candidate + curr_memory: Current memory state dict + + Returns: + Tuple of (potential_peak_memory, post_alloc_update_dict) + """ + # Caching calculations of memory for group nodes and candidate, + # to apply without recalculation after swap. 
+ _post_alloc_update: dict[BaseSchedulerNode, int] = {} + potential_peak: int = 0 + if not group_n_to_bufs_after_swap_dealloc_by_candidate: + # Not accounting for buffers last use change + potential_peak = max( + group_peak_memory - candidate_delta_mem, + curr_memory[group_tail][1] + - candidate_delta_mem + + candidate_allocfree.size_alloc, + ) + return potential_peak, _post_alloc_update + + # If candidate will be after group, the starting memory level of group nodes + # changes to the -(candidate.size_alloc - candidate.size_free) + mem_after_reorder_delta: int = -candidate_delta_mem + for gn in gns: + gn_post_alloc_mem = curr_memory[gn][0] + mem_after_reorder_delta + _post_alloc_update[gn] = gn_post_alloc_mem + potential_peak = max(potential_peak, gn_post_alloc_mem) + + bufs = group_n_to_bufs_after_swap_dealloc_by_candidate.get(gn) + if bufs is not None: + for buf in bufs: + # Candidate will deallocate those buffers + mem_after_reorder_delta += buf.mpi_buffer.size_free + + candidate_mem_post_alloc = ( + curr_memory[group_tail][1] + + mem_after_reorder_delta + + candidate_allocfree.size_alloc + ) + _post_alloc_update[candidate] = candidate_mem_post_alloc + potential_peak = max(potential_peak, candidate_mem_post_alloc) + return potential_peak, _post_alloc_update + + +def _update_memory_tracking_after_swap_reorder( + candidate: BaseSchedulerNode, + gns: list[BaseSchedulerNode], + group_tail: BaseSchedulerNode, + candidate_delta_mem: int, + candidate_allocfree: SNodeMemory, + group_n_to_bufs_after_swap_dealloc_by_candidate: dict, + post_alloc_update: dict[BaseSchedulerNode, int], + curr_memory: dict, + buf_to_snode_last_use: dict, + snodes_allocfree: dict, +) -> None: + """ + Update memory tracking structures after swap (reorder version). + + Updates curr_memory, buf_to_snode_last_use, and snodes_allocfree dictionaries + to reflect the new memory state after swapping candidate with group. 
+ + Args: + candidate: Node that was moved + gns: Group nodes + group_tail: Last node of group + candidate_delta_mem: Net memory change from candidate (alloc - free) + candidate_allocfree: Candidate's allocation/free info + group_n_to_bufs_after_swap_dealloc_by_candidate: Buffers whose deallocation moves to candidate + post_alloc_update: Cached post-allocation memory values + curr_memory: Current memory state dict (mutated) + buf_to_snode_last_use: Buffer to last-use node mapping (mutated) + snodes_allocfree: Node allocation/free info dict (mutated) + """ + if not group_n_to_bufs_after_swap_dealloc_by_candidate: + for gn in gns: + cm = curr_memory[gn] + curr_memory[gn] = ( + cm[0] - candidate_delta_mem, + cm[1] - candidate_delta_mem, + ) + _candidate_post_alloc_mem = ( + curr_memory[group_tail][1] + candidate_allocfree.size_alloc + ) + _candidate_post_free_mem = ( + _candidate_post_alloc_mem - candidate_allocfree.size_free + ) + curr_memory[candidate] = ( + _candidate_post_alloc_mem, + _candidate_post_free_mem, + ) + return + + # Candidate becomes last use of some bufs + for bufs in group_n_to_bufs_after_swap_dealloc_by_candidate.values(): + for buf in bufs: + buf_to_snode_last_use[buf] = candidate + + size_free_to_move_to_candidate_sum: int = 0 + for n in gns: + _gn_post_alloc_mem: int = post_alloc_update[n] + size_free_to_move_to_candidate: int = sum( + buf.mpi_buffer.size_free + for buf in group_n_to_bufs_after_swap_dealloc_by_candidate[n] + ) + size_free_to_move_to_candidate_sum += size_free_to_move_to_candidate + # group node does not deallocate this after swap + snodes_allocfree[n].size_free -= size_free_to_move_to_candidate + gn_post_free_mem: int = _gn_post_alloc_mem - snodes_allocfree[n].size_free + curr_memory[n] = (_gn_post_alloc_mem, gn_post_free_mem) + _candidate_post_alloc_mem = post_alloc_update[candidate] + snodes_allocfree[candidate].size_free += size_free_to_move_to_candidate_sum + candidate_post_free_mem = ( + _candidate_post_alloc_mem - 
snodes_allocfree[candidate].size_free + ) + curr_memory[candidate] = ( + _candidate_post_alloc_mem, + candidate_post_free_mem, + ) + + +def _find_buffers_with_changed_last_use( + candidate: BaseSchedulerNode, + gns: list[BaseSchedulerNode], + buf_to_snode_last_use: dict, +) -> dict[BaseSchedulerNode, list[Union[FreeableInputBuffer, Any]]]: + """ + Find buffers whose last use will change after swapping candidate with group. + + When we swap [candidate [group]] to [[group] candidate], some buffers that + were last used by a group node will now be last used by candidate instead. + This affects memory deallocation timing. + + Args: + candidate: The node being moved + gns: Group nodes being swapped with candidate + buf_to_snode_last_use: Mapping of buffers to their current last-use nodes + + Returns: + Dict mapping group nodes to buffers that will change their last-use node + """ + group_n_to_bufs_after_swap_dealloc_by_candidate: dict[ + BaseSchedulerNode, list[Union[FreeableInputBuffer, Any]] + ] = defaultdict(list) + for ( + buf, + snode_last_use, + ) in buf_to_snode_last_use.items(): + succ_nodes = buf.mpi_buffer.succ_nodes + if candidate not in succ_nodes: + continue + + if not any(gn == snode_last_use for gn in gns): + continue + + group_n_to_bufs_after_swap_dealloc_by_candidate[snode_last_use].append(buf) + + return group_n_to_bufs_after_swap_dealloc_by_candidate + + +def _is_node_groupable_for_reorder( + candidate: BaseSchedulerNode, +) -> tuple[bool, Optional[str]]: + """ + Check if a candidate node can be grouped with collective during reordering. + + This pass processes collectives left to right, so we avoid grouping with + already-processed collectives based on configuration. + + Args: + candidate: Node to check for groupability + + Returns: + Tuple of (is_groupable, reason_if_not_groupable) + """ + # This pass processes collectives left to right, + # Do not group with processed collectives. 
def _format_and_log_reordering_stats(
    stats: dict[BaseSchedulerNode, ReorderInfo],
    head: BaseSchedulerNode,
    next_dict: dict[BaseSchedulerNode, Optional[BaseSchedulerNode]],
    original_snodes_num: int,
    peak_memory: int,
    name_to_freeable_input_buf: dict,
    graph_outputs: OrderedSet[str],
) -> list[BaseSchedulerNode]:
    """
    Format reordering statistics, log them, and return final node list.

    Computes improvement metrics, creates a formatted table (using tabulate if
    available), validates the reordered node count, recalculates peak memory,
    and logs all information to the overlap artifact logger and the
    structured trace.

    Args:
        stats: Per-node reordering statistics
        head: Head of the reordered linked list
        next_dict: Linked list next pointers
        original_snodes_num: Original number of nodes (for validation)
        peak_memory: Initial peak memory before reordering
        name_to_freeable_input_buf: Buffer memory tracking info
        graph_outputs: Graph output names

    Returns:
        Final reordered list of scheduler nodes
    """
    # `import importlib` does not guarantee that the `util` submodule is
    # loaded, so import it explicitly before calling find_spec.
    import importlib.util

    total_improvement = sum(node_info.improvement for node_info in stats.values())
    total_moves = sum(node_info.moves for node_info in stats.values())

    reorder_log_str = (
        f"reorder_communication_preserving_peak_memory improved overlap by {total_improvement} ns"
        f" after {total_moves} reorders.\n"
    )
    headers = [
        "Collective node",
        "comm_time(us)",
        "comp_time(us)",
        "initial exposed(us)",
        "final exposed(us)",
        "improvement(us)",
        "limiting factor",
        "moves",
        "grouped",
        "grouped_info",
        "overlap_info",
    ]
    # Times are divided by 1e3 to match the "(us)" column headers.
    rows = [
        [
            node_summary(snode),
            node_info.comm_time / 1e3,
            node_info.comp_time / 1e3,
            node_info.initial_exposed / 1e3,
            node_info.final_exposed / 1e3,
            node_info.improvement / 1e3,
            node_info.limiting_factor,
            node_info.moves,
            node_info.grouped,
            node_info.grouped_info,
            node_info.overlap_info,
        ]
        for snode, node_info in stats.items()
    ]
    if importlib.util.find_spec("tabulate"):
        # pyrefly: ignore[import-error]
        from tabulate import tabulate

        reorder_log_str += tabulate(
            rows,
            headers=headers,
        )
    else:
        reorder_log_str += (
            "Please `pip install tabulate` to nicely render overlap stats.\n"
        )
        reorder_log_str += str(headers) + "\n"
        reorder_log_str += "\n".join(map(str, rows))

    # tail=None walks the linked list to its end.
    new_snodes = _group_nodes_from_linked_list(head, None, next_dict)
    # Reordering must neither drop nor duplicate nodes.
    assert len(new_snodes) == original_snodes_num
    new_peak_memory, _, _, _ = estimate_peak_memory_allocfree(
        new_snodes, name_to_freeable_input_buf, graph_outputs
    )
    reorder_log_str += f"\n peak_memory_before:{peak_memory}"
    reorder_log_str += f"\n peak_memory_after:{new_peak_memory}"

    overlap_log.info(reorder_log_str)
    trace_structured(
        "artifact",
        metadata_fn=lambda: {
            "name": "reorder_communication_preserving_peak_memory",
            "encoding": "string",
        },
        payload_fn=lambda: reorder_log_str,
    )

    return new_snodes
+ Returns: + - reordered snodes list + - dict {snode: ReorderInfo} + """ + has_collectives = False + for snode in snodes: + if contains_collective(snode): + has_collectives = True + break + if not has_collectives: + return snodes, {} + + from torch._inductor.scheduler import GroupedSchedulerNode + + original_snodes_num = len(snodes) + # heuristic to avoid degenerating to quadratic time + graph_inputs: OrderedSet[str] = OrderedSet(V.graph.graph_inputs.keys()) + graph_outputs: OrderedSet[str] = OrderedSet(V.graph.get_output_names()) + ( + peak_memory, + _curr_memory, + snodes_allocfree, + buf_to_snode_last_use, + name_to_freeable_input_buf, + ) = _initialize_memory_tracking(snodes, graph_inputs, graph_outputs) + + runtimes: dict[BaseSchedulerNode, float] = { + snode: estimate_op_runtime(snode) * _op_runtime_estimate_mult(snode) + for snode in snodes + } + # debug stats + stats: dict[BaseSchedulerNode, ReorderInfo] = {} + + total_moves = 0 + + _prev, _next, _head = _initialize_double_linked_list(snodes) + + debug_num_collectives_to_reorder: Optional[int] = ( + config_comms.reorder_iterative_debug_limit_to_reorder + ) + + num_processed_collectives: int = 0 + curr: Optional[BaseSchedulerNode] = _head + debug_iterative_memory_recompute = ( + config_comms.reorder_iterative_debug_memory_recompute + ) + iterative_recompute_error = False + + while curr is not None and _next[curr] is not None: + _next_curr = _next[curr] + if iterative_recompute_error: + break + # pyrefly: ignore [bad-argument-type] + if not contains_async_collective(curr): + curr = _next_curr + continue + + if debug_num_collectives_to_reorder is not None and ( + num_processed_collectives >= debug_num_collectives_to_reorder + ): + break + num_processed_collectives += 1 + + info = stats[curr] = ReorderInfo() + comm_time, comp_time, overlap_info = coll_exposed_communication_time( + _group_nodes_from_linked_list(curr, None, _next), runtimes + ) + info.comm_time = comm_time + info.comp_time = comp_time + 
info.initial_exposed = info.final_exposed = comm_time - comp_time + info.overlap_info = overlap_info + + candidate = _prev[curr] + group_head = curr + group_tail = curr + group_waits = {} + group_runtime = 0.0 + group_peak_memory = _curr_memory[curr][0] # post_alloc memory + + while candidate is not None: + if config_comms.reorder_iterative_use_runtime_estimations and ( + info.final_exposed + < -config_comms.reorder_iterative_extra_comm_comp_overlap + * info.comm_time + ): + info.limiting_factor = "unexposed by runtime estimations" + break + + if ( + not config_comms.reorder_iterative_unsafe_collectives_reorder + and contains_collective(candidate) + ): + info.limiting_factor = "collective ordering" + break + + gns: list[BaseSchedulerNode] = _group_nodes_from_linked_list( + group_head, group_tail, _next + ) + group = GroupedSchedulerNode( + curr.scheduler, + gns, + temp_grouping=True, + ) + + # We can have multiple deps with the same name. + # As we ignore WeakDep(is_fake=True) => + # filter them out first to avoid overwriting of real dep. 
+ data_deps = { + d.name: d for d in group.unmet_dependencies if not _is_fake_dep(d) + } + + candidate_outs = candidate.get_outputs() + data_dep = None + for o in candidate_outs: + if d := data_deps.get(o.get_name(), None): + data_dep = d + break + + if data_dep is not None: + is_groupable_result, grouping_reason = _is_node_groupable_for_reorder( + candidate + ) + if is_groupable_result: + group_head = candidate + # pyrefly: ignore[unbound-name] + if config_comms.reorder_iterative_use_runtime_estimations: + if contains_wait(candidate): + comm_time, comp_time, _ = wait_exposed_communication_time( + _group_nodes_from_linked_list(_head, candidate, _next), + runtimes, + ) + group_waits[candidate] = comm_time, comp_time + if not contains_async_collective(candidate): + group_runtime += runtimes[candidate] + + group_peak_memory = max( + group_peak_memory, _curr_memory[candidate][0] + ) + info.grouped += 1 + info.grouped_info = _group_names(gns) + candidate = _prev[candidate] + continue + else: + msg = ( + f"data dependency {data_dep}(dep_names:{list(data_deps.keys())})" + f"\n candidate:{candidate.get_name()}(outs:{[candidate.get_buffer_names()]})" + f"dep on {_group_names(gns)}" + f"\n non_group_reason:{grouping_reason}" + ) + info.limiting_factor = msg + break + + # pyrefly: ignore[unbound-name] + if config_comms.reorder_iterative_use_runtime_estimations: + # Check if candidate has sync runtime + if not contains_async_collective(candidate): + c_runtime = runtimes[candidate] + + if c_runtime > 0 and len(group_waits) > 0: + # pyrefly: ignore[no-matching-overload] + exposed_before = max(0, info.comm_time - info.comp_time) + # pyrefly: ignore[no-matching-overload] + exposed_after = max( + 0, info.comm_time - info.comp_time - c_runtime + ) + exposed_delta = exposed_after - exposed_before + for gw_comm_time, gw_comp_time in group_waits.values(): + gw_exposed_before = max(0, gw_comm_time - gw_comp_time) + gw_exposed_after = max( + 0, gw_comm_time - gw_comp_time + c_runtime + ) 
+ + exposed_delta += gw_exposed_after - gw_exposed_before + + if exposed_delta > 0: + info.limiting_factor = ( + f"candidate has compute {c_runtime}," + f" group contains waits, total_exposed_delta {exposed_delta}" + ) + break + else: + # Update all group_colls comm_time, comp_time + for gw, ( + gw_comm_time, + gw_comp_time, + ) in group_waits.items(): + group_waits[gw] = ( + gw_comm_time, + gw_comp_time - c_runtime, + ) + else: + # Candidate is async_collective + + # Unsafe collectives reordering + # Cj -> [...group_runtime..., Ci] -> Wj + # Checking that we are not increasing exposed time of Cj + if group_runtime > 0: + comm_time, comp_time, _ = coll_exposed_communication_time( + _group_nodes_from_linked_list(candidate, None, _next), + runtimes, + ) + # pyrefly: ignore[no-matching-overload] + exposed_before = max(0, comm_time - comp_time) + # pyrefly: ignore[no-matching-overload] + exposed_after = max(0, comm_time - comp_time + group_runtime) + exposed_delta = exposed_after - exposed_before + if exposed_delta > 0: + info.limiting_factor = ( + f"candidate {candidate.get_name()} is collective," + f" group_runtime:{group_runtime}," + f" exposed_delta:{exposed_delta} c_comm_time:{comm_time} c_comp_time:{comp_time}" + ) + break + + candidate_allocfree: SNodeMemory = snodes_allocfree[candidate] + candidate_delta_mem: int = ( + candidate_allocfree.size_alloc - candidate_allocfree.size_free + ) + # candidate and one of group nodes are successors of the same buffer + # and last use of the buffer happen in group nodes. + # This last use deallocates it. + # If we swap [candidate [group]] to [[group] candidate], + # candidate becomes the last use + # and deallocated this buffer instead of group node. + # we need to update size_free accordingly to group_node and candidate, + # and recalculate post_alloc, post_free for them. + # + # Buf that changes its last use snode, + # after swap will be deallocated only by candidate, + # while before it was deallocated by group node. 
+ group_n_to_bufs_after_swap_dealloc_by_candidate = ( + _find_buffers_with_changed_last_use( + candidate, gns, buf_to_snode_last_use + ) + ) + + potential_peak, _post_alloc_update = ( + _calculate_potential_peak_memory_reorder( + candidate, + gns, + group_tail, + group_peak_memory, + candidate_delta_mem, + candidate_allocfree, + group_n_to_bufs_after_swap_dealloc_by_candidate, + _curr_memory, + ) + ) + + if ( + potential_peak - peak_memory + # pyrefly: ignore[unbound-name] + > peak_memory * config_comms.reorder_iterative_peak_memory_budget + ): + info.limiting_factor = ( + f"peak memory new:{potential_peak} vs base:{peak_memory}" + ) + break + info.moves += 1 + total_moves += 1 + + _head = _perform_double_linked_list_swap( + candidate, group_head, group_tail, _prev, _next, _head + ) + + comm_time, comp_time, overlap_info = coll_exposed_communication_time( + _group_nodes_from_linked_list(curr, None, _next), runtimes + ) + info.comm_time = comm_time + info.comp_time = comp_time + info.overlap_info = overlap_info + info.final_exposed = comm_time - comp_time + + _update_memory_tracking_after_swap_reorder( + candidate, + gns, + group_tail, + candidate_delta_mem, + candidate_allocfree, + group_n_to_bufs_after_swap_dealloc_by_candidate, + _post_alloc_update, + _curr_memory, + buf_to_snode_last_use, + snodes_allocfree, + ) + + if debug_iterative_memory_recompute: + # Compare iteratively recomputed memory data + # with full run of estimate_peak_memory + + from .comms_debug import _debug_iterative_memory_recompute + + iterative_recompute_error = _debug_iterative_memory_recompute( + candidate, + gns, + _group_names(gns), + _group_nodes_from_linked_list(_head, None, _next), + name_to_freeable_input_buf, + graph_outputs, + peak_memory, + _curr_memory, + snodes_allocfree, + "reorder_communication_preserving_peak_memory", + group_n_to_bufs_after_swap_dealloc_by_candidate, + ) + if iterative_recompute_error: + break + candidate = _prev[group_head] + curr = _next_curr + + 
new_snodes = _format_and_log_reordering_stats( + stats, + _head, + _next, + original_snodes_num, + peak_memory, + name_to_freeable_input_buf, + graph_outputs, + ) + + return new_snodes, stats + + +def _schedule_for_comm( + snodes: list[BaseSchedulerNode], + raise_comms: bool, + sink_waits: bool, + reorder_for_overlap: bool, +) -> list[BaseSchedulerNode]: + """ + Schedule `snodes` for various comm optimization objectives. + + Args: + snodes: the nodes to be scheduled. + raise_comms: whether to greedily schedule collectives as early as possible + sink_wait: whether to greedily schedule waits as late as possible + reorder_compute_for_overlap: whether to reorder compute nodes to + optimize for compute/communication overlapping. + + Returns: + The new schedule order. + + Some notes on the synergy between different options: + - `raise_comms` provides more overlapping oppurtunies for `reorder_compute_for_overlap`. + - When both `raise_comms` and `sink_waits` is `True`, `raise_comms` is prioritized. + """ + # We assign each node a tuple of scores (score_0, score_1, score_2), + # decreasing in importance, with a lower value indicating a higher ranking: + # + # - score_0: the lowest comm_idx among the comm nodes that the node blocks. + # If a node doesn't block any comm nodes, its score_0 is set to + # sys.maxsize. This score ensures that comm nodes get scheduled as early as + # possible. + # - score_1: 1 if the node is a wait node, 0 otherwise. This score ensures + # that wait nodes are deferred as late as possible. + # - score_2: the index of the node in the original topological order. This + # score provides stability in case of ties. + # + # When only raise_comms is True, only score_0 and score_2 are considered. + # When only sink_waits is True, only score_1 and score_2 are considered. + # When neither is True, the original order is yielded. 
+ buf_name_to_snode = {} + name_to_fused_node = {} + scores_0, scores_1, scores_2 = {}, {}, {} + for idx, snode in enumerate(snodes): + for buf_name in snode.get_buffer_names(): + buf_name_to_snode[buf_name] = snode + + for op_name in snode.get_operation_names(): + name_to_fused_node[op_name] = snode + name_to_fused_node[snode.get_name()] = snode + + node_name = snode.get_name() + scores_0[node_name] = sys.maxsize + scores_1[node_name] = 0 + scores_2[node_name] = idx + + comm_idx = 0 + for snode in snodes: + if raise_comms and contains_collective(snode): + scores_0[snode.get_name()] = comm_idx + for ancestor in snode.ancestors: + anc_fused_name = name_to_fused_node[ancestor].get_name() + scores_0[anc_fused_name] = min(scores_0[anc_fused_name], comm_idx) + comm_idx += 1 + elif sink_waits and contains_wait(snode): + scores_1[snode.get_name()] = 1 + + class Runnable: + def __init__(self, snode) -> None: + self.snode = snode + name = next(iter(snode.get_operation_names())) + fused_name = name_to_fused_node[name].get_name() + self.score = ( + scores_0[fused_name], + scores_1[fused_name], + scores_2[fused_name], + ) + + def __lt__(self, other): + return self.score < other.score + + unmet_deps: dict[BaseSchedulerNode, OrderedSet[str]] = { + snode: OrderedSet(dep.name for dep in snode.unmet_dependencies) + for snode in snodes + } + + ready: list[Runnable] = [] + buffer_users: dict[str, OrderedSet[BaseSchedulerNode]] = defaultdict(OrderedSet) + snode_to_cost = {snode: estimate_op_runtime(snode) for snode in snodes} + + for snode, deps in unmet_deps.items(): + if len(deps) == 0: + heapq.heappush(ready, Runnable(snode)) + for dep in deps: + buffer_users[dep].add(snode) + + scheduled = [] + + def schedule(snode): + """ + Schedules `snode` and put all unblocked nodes onto the ready queue. 
+ """ + scheduled.append(snode) + for buf_name in snode.get_buffer_names(): + for snode in buffer_users[buf_name]: + unmet_deps[snode].remove(buf_name) + if len(unmet_deps[snode]) == 0: + heapq.heappush(ready, Runnable(snode)) + + def get_overlapping_candidate(): + """ + Return the next node in the ready queue that's neither a collective or + a wait. + """ + candidates = [ + x + for x in ready + if not contains_collective(x.snode) and not contains_wait(x.snode) + ] + if len(candidates) == 0: + return None + return min(candidates, key=lambda x: x.score) + + def schedule_collective_for_overlap(snode): + """ + Schedules collective node `snode`, along with one or more compute nodes + to overlap with it. The strategy is described in the comment of + `reorder_compute_for_overlap`. + """ + assert contains_collective(snode) + schedule(snode) + + collective_cost = snode_to_cost[snode] + while ( + collective_cost > 0 + and (candidate := get_overlapping_candidate()) is not None + ): + ready.remove(candidate) + + schedule(candidate.snode) + + collective_cost -= snode_to_cost[candidate.snode] + heapq.heapify(ready) + + while ready: + snode = heapq.heappop(ready).snode + if reorder_for_overlap and contains_collective(snode): + schedule_collective_for_overlap(snode) + else: + schedule(snode) + + for deps in unmet_deps.values(): + assert len(deps) == 0, ( + f"Detected unscheduled nodes. Nodes with unmet dependencies: {unmet_deps}" + ) + return scheduled + + +def decide_global_ordering_of_comms( + nodes: list[BaseSchedulerNode], name_to_buf, name_to_fused_node +) -> list[BaseSchedulerNode]: + """ + Decide global ordering of comms, by just enforcing the ordering that's in the input graph + (might not be the same ordering as the eager mode program). 
+ TODO: Come up with a better approach + """ + if not torch.distributed.is_available(): + return nodes + + comm_nodes = [n for n in nodes if contains_collective(n)] + + for i in range(1, len(comm_nodes)): + # Enforce ordering by making previous comm a `WeakDep` dependency of the next comm + mutating_buf = next(iter(comm_nodes[i].get_buffer_names())) + for buf in comm_nodes[i - 1].get_buffer_names(): + comm_nodes[i].add_fake_dep( + WeakDep(buf, mutating_buf=mutating_buf, is_fake=True) + ) + + return nodes + + +@dataclass +class SinkWaitInfo: + grouped: int = 0 + grouped_info: str = "" + moves: int = 0 + moves_info: str = "" + limiting_factor: str = "None" + comm_time: float = -1.0 + comp_time: float = -1.0 + initial_exposed: float = -1.0 + final_exposed: float = -1.0 + overlap_info: str = "None" + + @property + def improvement(self): + return self.initial_exposed - self.final_exposed + + +def _is_node_groupable_for_sink_waits( + candidate: BaseSchedulerNode, +) -> tuple[bool, Optional[str]]: + """ + Check if a candidate node can be grouped during sink_waits pass. + + Sink Waits traverses waits right to left, so we don't group with + processed waits on the right or with async collectives. + + Args: + candidate: Node to check for groupability + + Returns: + Tuple of (is_groupable, reason_if_not_groupable) + """ + # Sink Waits traverse Waits right to left, + # => we do not group with processed Waits on the right. + if contains_wait(candidate): + return False, f"candidate contains wait {candidate.get_name()}" + if contains_async_collective(candidate): + return ( + False, + f"candidate contains_async_collective {candidate.get_name()}", + ) + + # pyrefly: ignore[unbound-name] + if not config_comms.sink_iterative_use_runtime_estimations: + # Heuristics pre-use_runtime_estimations: + # TODO(ivankobzarev): Remove them after confirming, + # that using runtime estimations always give better results. + # We do not want to group with collectives to not reorder them forward. 
+ if contains_collective(candidate): + return ( + False, + f"candidate contains collective {candidate.get_name()}", + ) + if contains_gemm_like(candidate): + return ( + False, + f"candidate contains gemm_like {candidate.get_name()}", + ) + return True, None + + +def _update_memory_tracking_after_swap_sink_waits( + candidate: BaseSchedulerNode, + gns: list[BaseSchedulerNode], + candidate_delta_mem: int, + candidate_allocfree: SNodeMemory, + group_n_to_bufs_after_swap_dealloc_instead_of_candidate: dict, + post_alloc_update: dict[BaseSchedulerNode, int], + size_free_delta_update: dict[BaseSchedulerNode, int], + curr_memory: dict, + snodes_allocfree: dict, +) -> None: + """ + Update memory tracking structures after swap (sink_waits version). + + Updates curr_memory and snodes_allocfree dictionaries to reflect the new + memory state after swapping candidate with group. + + Args: + candidate: Node that was moved + gns: Group nodes + candidate_delta_mem: Net memory change from candidate (alloc - free) + candidate_allocfree: Candidate's allocation/free info + group_n_to_bufs_after_swap_dealloc_instead_of_candidate: Buffers whose deallocation moves from candidate to group + post_alloc_update: Cached post-allocation memory values + size_free_delta_update: Cached size-free delta values + curr_memory: Current memory state dict (mutated) + snodes_allocfree: Node allocation/free info dict (mutated) + """ + group_head = gns[0] + pre_group_mem = curr_memory[group_head][0] - snodes_allocfree[group_head].size_alloc + if not group_n_to_bufs_after_swap_dealloc_instead_of_candidate: + candidate_post_alloc = pre_group_mem + candidate_allocfree.size_alloc + curr_memory[candidate] = ( + candidate_post_alloc, + candidate_post_alloc - candidate_allocfree.size_free, + ) + for gn in gns: + cm = curr_memory[gn] + curr_memory[gn] = ( + cm[0] + candidate_delta_mem, + cm[1] + candidate_delta_mem, + ) + return + + for n in [candidate, *gns]: + post_alloc = post_alloc_update[n] + 
snodes_allocfree[n].size_free += size_free_delta_update.get(n, 0) + curr_memory[n] = ( + post_alloc, + post_alloc - snodes_allocfree[n].size_free, + ) + + +def _calculate_potential_peak_memory_sink_waits( + candidate: BaseSchedulerNode, + gns: list[BaseSchedulerNode], + group_head: BaseSchedulerNode, + group_peak_memory: int, + candidate_delta_mem: int, + candidate_allocfree: SNodeMemory, + group_n_to_bufs_after_swap_dealloc_instead_of_candidate: dict, + curr_memory: dict, + snodes_allocfree: dict, +) -> tuple[int, dict[BaseSchedulerNode, int], dict[BaseSchedulerNode, int]]: + """ + Calculate potential peak memory after swapping candidate with group (sink_waits version). + + Computes new memory levels for all affected nodes and returns the potential + peak memory along with cached post-allocation and size-free delta values. + + Args: + candidate: Node being moved + gns: Group nodes + group_head: First node of group + group_peak_memory: Current peak memory within the group + candidate_delta_mem: Net memory change from candidate (alloc - free) + candidate_allocfree: Candidate's allocation/free info + group_n_to_bufs_after_swap_dealloc_instead_of_candidate: Buffers whose deallocation moves from candidate to group + curr_memory: Current memory state dict + snodes_allocfree: Allocation/free info for all nodes + + Returns: + Tuple of (potential_peak_memory, post_alloc_update_dict, size_free_delta_update_dict) + """ + pre_group_mem = curr_memory[group_head][0] - snodes_allocfree[group_head].size_alloc + # Stash memory tracing updates to not recompute them after swap + _post_alloc_update: dict[BaseSchedulerNode, int] = {} + _size_free_delta_update: dict[BaseSchedulerNode, int] = {} + + potential_peak = 0 + if not group_n_to_bufs_after_swap_dealloc_instead_of_candidate: + # Not accounting for buffers liveliness change + potential_peak = max( + group_peak_memory + candidate_delta_mem, + pre_group_mem + candidate_allocfree.size_alloc, + ) + return potential_peak, 
_post_alloc_update, _size_free_delta_update + + candidate_post_alloc = pre_group_mem + candidate_allocfree.size_alloc + _post_alloc_update[candidate] = candidate_post_alloc + potential_peak = candidate_post_alloc + candidate_size_free_to_move = sum( + buf.mpi_buffer.size_free # type: ignore[attr-defined] + for buf in itertools.chain.from_iterable( + group_n_to_bufs_after_swap_dealloc_instead_of_candidate.values() + ) + ) + _size_free_delta_update[candidate] = -candidate_size_free_to_move + delta_mem = candidate_delta_mem + candidate_size_free_to_move + for gn in gns: + gn_post_alloc = curr_memory[gn][0] + delta_mem + _post_alloc_update[gn] = gn_post_alloc + potential_peak = max(potential_peak, gn_post_alloc) + gn_size_free_to_add = 0 + if gn in group_n_to_bufs_after_swap_dealloc_instead_of_candidate: + bufs = group_n_to_bufs_after_swap_dealloc_instead_of_candidate[gn] + for buf in bufs: + gn_size_free_to_add += buf.mpi_buffer.size_free + _size_free_delta_update[gn] = gn_size_free_to_add + delta_mem -= gn_size_free_to_add + return potential_peak, _post_alloc_update, _size_free_delta_update + + +def _perform_double_linked_list_swap_sink_waits( + candidate: BaseSchedulerNode, + group_head: BaseSchedulerNode, + group_tail: BaseSchedulerNode, + prev_dict: dict[BaseSchedulerNode, Optional[BaseSchedulerNode]], + next_dict: dict[BaseSchedulerNode, Optional[BaseSchedulerNode]], + head: BaseSchedulerNode, +) -> BaseSchedulerNode: + """ + Swap positions of candidate and group in doubly-linked list (sink_waits version). 
+ + Transforms (moves candidate to the left): + group_head_prev -> group_head...group_tail -> candidate -> candidate_next + Into: + group_head_prev -> candidate -> group_head...group_tail -> candidate_next + + Args: + candidate: Node to swap with group + group_head: First node of group + group_tail: Last node of group + prev_dict: Dictionary mapping nodes to their previous nodes + next_dict: Dictionary mapping nodes to their next nodes + head: Current head of the linked list + + Returns: + New head of the linked list (may change if group_head was the head) + """ + # 0: Update group_head's previous node + group_head_prev = prev_dict[group_head] + if group_head_prev: + next_dict[group_head_prev] = candidate + prev_dict[candidate] = group_head_prev + + # 2: Update candidate's next node + candidate_next = next_dict[candidate] + if candidate_next: + prev_dict[candidate_next] = group_tail + next_dict[group_tail] = candidate_next + + # 1: Link candidate to group_head + prev_dict[group_head] = candidate + next_dict[candidate] = group_head + + # Update head if group_head was the head + if group_head == head: + return candidate + return head + + +def _format_and_log_sink_waits_stats( + stats: dict[BaseSchedulerNode, SinkWaitInfo], + head: BaseSchedulerNode, + next_dict: dict[BaseSchedulerNode, Optional[BaseSchedulerNode]], + original_snodes_num: int, + peak_memory: int, + name_to_freeable_input_buf: dict, + graph_outputs: OrderedSet[str], +) -> list[BaseSchedulerNode]: + """ + Format sink_waits statistics, log them, and return final node list. + + Computes improvement metrics, creates a formatted table (using tabulate if + available), validates the reordered node count, recalculates peak memory, + and logs all information. 
+ + Args: + stats: Per-node sink_waits statistics + head: Head of the reordered linked list + next_dict: Linked list next pointers + original_snodes_num: Original number of nodes (for validation) + peak_memory: Initial peak memory before reordering + name_to_freeable_input_buf: Buffer memory tracking info + graph_outputs: Graph output names + + Returns: + Final reordered list of scheduler nodes + """ + headers = [ + "Wait node", + "comm_time(us)", + "comp_time(us)", + "initial exposed(us)", + "final exposed(us)", + "improvement(us)", + "limiting factor", + "grouped", + "grouped_info", + "moves", + "moves_info", + "overlap_info", + ] + rows = [ + [ + node_summary(snode), + info.comm_time / 1e3, + info.comp_time / 1e3, + info.initial_exposed / 1e3, + info.final_exposed / 1e3, + info.improvement / 1e3, + info.limiting_factor, + info.grouped, + info.grouped_info, + info.moves, + info.moves_info, + info.overlap_info, + ] + for snode, info in stats.items() + ] + log_str = "" + if importlib.util.find_spec("tabulate"): + # pyrefly: ignore[import-error] + from tabulate import tabulate + + log_str += tabulate( + rows, + headers=headers, + ) + else: + log_str += "Please `pip install tabulate` to nicely render overlap stats.\n" + log_str += str(headers) + "\n" + log_str += "\n".join(map(str, rows)) + overlap_log.info(log_str) + new_snodes = _group_nodes_from_linked_list(head, None, next_dict) + assert len(new_snodes) == original_snodes_num + new_peak_memory, _, _, _ = estimate_peak_memory_allocfree( + new_snodes, name_to_freeable_input_buf, graph_outputs + ) + log_str += f"\n sink_waits_iterative peak_memory_before:{peak_memory}" + log_str += f"\n sink_waits_iterative peak_memory_after:{new_peak_memory}" + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "sink_waits_iterative_info", + "encoding": "string", + }, + payload_fn=lambda: log_str, + ) + return new_snodes + + +def _find_buffers_with_changed_last_use_sink_waits( + candidate: BaseSchedulerNode, + gns: 
list[BaseSchedulerNode], + buf_to_snode_last_use: dict, +) -> dict[BaseSchedulerNode, list[Union[FreeableInputBuffer, Any]]]: + """ + Find buffers whose last use will change after swapping in sink_waits pass. + + When we swap [group] candidate to candidate [group], some buffers that + were last used by candidate will now be last used by a group node instead. + This is the opposite direction from the reorder version. + + Args: + candidate: The node being moved (currently last use) + gns: Group nodes being swapped with candidate + buf_to_snode_last_use: Mapping of buffers to their current last-use nodes + + Returns: + Dict mapping group nodes to buffers that will change their last-use node + """ + group_n_to_bufs_after_swap_dealloc_instead_of_candidate: dict[ + BaseSchedulerNode, list[Union[FreeableInputBuffer, Any]] + ] = defaultdict(list) + for ( + buf, + snode_last_use, + ) in buf_to_snode_last_use.items(): + succ_nodes = buf.mpi_buffer.succ_nodes + if snode_last_use != candidate: # noqa: E711 + continue + # candidate is last use of buf + last_succ_gn = None + for gn in gns: + if gn in succ_nodes: + last_succ_gn = gn + if last_succ_gn is None: + continue + + # gn has successors of buf that after potential swap will become + # last use of buf and start deallocating buf instead of candidate + group_n_to_bufs_after_swap_dealloc_instead_of_candidate[last_succ_gn].append( + buf + ) + + return group_n_to_bufs_after_swap_dealloc_instead_of_candidate + + +def _sink_waits_iterative_internal( + snodes: list[BaseSchedulerNode], +) -> tuple[list[BaseSchedulerNode], dict[BaseSchedulerNode, SinkWaitInfo]]: + from torch._inductor.scheduler import GroupedSchedulerNode + + original_snodes_num = len(snodes) + if original_snodes_num == 0: + return snodes, {} + graph_inputs: OrderedSet[str] = OrderedSet(V.graph.graph_inputs.keys()) + graph_outputs: OrderedSet[str] = OrderedSet(V.graph.get_output_names()) + ( + peak_memory, + _curr_memory, + snodes_allocfree, + buf_to_snode_last_use, 
+ name_to_freeable_input_buf, + ) = _initialize_memory_tracking(snodes, graph_inputs, graph_outputs) + + _prev, _next, _head = _initialize_double_linked_list(snodes) + + stats: dict[BaseSchedulerNode, SinkWaitInfo] = {} + + runtimes: dict[BaseSchedulerNode, float] = { + snode: estimate_op_runtime(snode) * _op_runtime_estimate_mult(snode) + for snode in snodes + } + + curr: Optional[BaseSchedulerNode] = snodes[-1] + + processed_waits = OrderedSet() # type: ignore[var-annotated] + debug_iterative_memory_recompute = ( + config_comms.reorder_iterative_debug_memory_recompute + ) + debug_num_sink_waits_to_reorder: Optional[int] = ( + config_comms.sink_waits_iterative_debug_limit_to_sink + ) + + iterative_recompute_error = False + while curr is not None and _prev[curr] is not None: + _prev_curr = _prev[curr] + if iterative_recompute_error: + break + if ( + debug_num_sink_waits_to_reorder is not None + and len(processed_waits) >= debug_num_sink_waits_to_reorder + ): + break + + # pyrefly: ignore [bad-argument-type] + if not (contains_wait(curr) and curr not in processed_waits): + curr = _prev_curr + continue + + processed_waits.add(curr) + info = stats[curr] = SinkWaitInfo() + comm_time, comp_time, overlap_info = wait_exposed_communication_time( + _group_nodes_from_linked_list(_head, curr, _next), runtimes + ) + info.initial_exposed = info.final_exposed = comm_time - comp_time + info.comm_time = comm_time + info.comp_time = comp_time + info.overlap_info = overlap_info + + candidate = _next[curr] + wait_snode = curr + group_head = curr + group_tail = curr + group_colls = {} + group_runtime = 0.0 + group_peak_memory = _curr_memory[curr][0] + + while candidate is not None: + if config_comms.sink_iterative_use_runtime_estimations and ( + info.final_exposed + < -config_comms.sink_iterative_extra_comm_comp_overlap * info.comm_time + ): + info.limiting_factor = "unexposed by runtime estimations" + break + + gns: list[BaseSchedulerNode] = _group_nodes_from_linked_list( + 
group_head, group_tail, _next + ) + group = GroupedSchedulerNode( + wait_snode.scheduler, + gns, + temp_grouping=True, + ) + + # We can have multiple deps with the same name. + # As we ignore WeakDep(is_fake=True) => + # filter them out first to avoid overwriting of real dep. + data_deps = { + d.name: d for d in candidate.unmet_dependencies if not _is_fake_dep(d) + } + + group_outs = group.get_outputs() + data_dep = None + for o in group_outs: + if d := data_deps.get(o.get_name(), None): + data_dep = d + break + # Conservative sink wait, limiting by space before next collective. + # The global strategy is that bucketing should create space. + # For 2D we can experiment with allowing to sink Wait beyond non current group collective. + # pyrefly: ignore[unbound-name] + if not config_comms.sink_waits_iterative_swap_with_collectives: + if contains_async_collective(candidate): + info.limiting_factor = ( + f"candidate contains_async_collective {candidate.get_name()}" + ) + break + + # 1. If we have data_dep - we can not swap => trying to group + # 2. 
If swap candidate and current node both contain collectives => trying to group + if data_dep is not None or ( + both_contain_comms := ( + contains_collective(group) and contains_collective(candidate) + ) + ): + _is_groupable, groupable_reason = _is_node_groupable_for_sink_waits( + candidate + ) + if _is_groupable: + group_tail = candidate + if ( + # pyrefly: ignore[unbound-name] + config_comms.sink_iterative_use_runtime_estimations + and contains_collective(candidate) + ): + comm_time, comp_time, _ = coll_exposed_communication_time( + _group_nodes_from_linked_list(candidate, None, _next), + runtimes, + ) + group_colls[candidate] = (comm_time, comp_time) + if not contains_async_collective(candidate): + group_runtime += runtimes[candidate] + + group_peak_memory = max( + group_peak_memory, _curr_memory[candidate][0] + ) + info.grouped += 1 + info.grouped_info = _group_names(gns) + candidate = _next[candidate] + continue + elif data_dep is None: + if ( + # pyrefly: ignore[unbound-name] + not config_comms.sink_waits_iterative_unsafe_collectives_reorder + and both_contain_comms + ): + info.limiting_factor = ( + f"collective ordering {_group_names(gns)}" + f"\n with candidate:{candidate.get_name()}" + ) + break + else: + info.limiting_factor = ( + f"data dependency {data_dep}(dep_names:{list(data_deps.keys())})" + f"\n candidate:{candidate.get_name()}(os:{[candidate.get_buffer_names()]})" + f"\n dep on {_group_names(gns)}" + f"\n outs:{[o.get_name() for o in group_outs]}" + f"\n non_group_reason:{groupable_reason}" + ) + break + + # pyrefly: ignore[unbound-name] + if config_comms.sink_iterative_use_runtime_estimations: + if is_wait(candidate.node): + # Corresponding collective is before the group, + # Swap can increase exposed time of corresponding collective + comm_time, comp_time, _ = wait_exposed_communication_time( + _group_nodes_from_linked_list(_head, candidate, _next), runtimes + ) + # pyrefly: ignore[no-matching-overload] + exposed_before = max(0, comm_time - 
comp_time) + # pyrefly: ignore[no-matching-overload] + exposed_after = max(0, comm_time - comp_time + group_runtime) + # We do not know how much we can sink more after this swap, + # Just comparing advantage at the moment for now. + if exposed_after > exposed_before: + info.limiting_factor = ( + "candidate is wait," + f" exposed_before:{exposed_before} vs exposed_after:{exposed_after}" + ) + break + + # Check if candidate has sync runtime + if not contains_async_collective(candidate): + # If candidate has sync runtime, + # Waits of gorup_colls are on the right from group. + # Swap can increase their exposed time. + c_runtime = runtimes[candidate] + + if c_runtime > 0 and len(group_colls) > 0: + # Advantage for current Wait to do the Swap + # pyrefly: ignore[no-matching-overload] + exposed_delta = max( + 0, + info.comm_time - info.comp_time, + ) + # pyrefly: ignore[no-matching-overload] + -max(0, info.comm_time - info.comp_time - c_runtime) + for gc_comm_time, gc_comp_time in group_colls.values(): + exposed_delta += max(0, gc_comm_time - gc_comp_time) - max( + 0, gc_comm_time - gc_comp_time + c_runtime + ) + if exposed_delta > 0: + info.limiting_factor = ( + f"candidate has compute {c_runtime}, group contains collectives," + f" total_exposed_delta {exposed_delta}" + ) + break + else: + # Update all group_colls comm_time, comp_time + for gc, ( + gc_comm_time, + gc_comp_time, + ) in group_colls.items(): + group_colls[gc] = ( + gc_comm_time, + gc_comp_time - c_runtime, + ) + + candidate_allocfree: SNodeMemory = snodes_allocfree[candidate] + candidate_delta_mem = ( + candidate_allocfree.size_alloc - candidate_allocfree.size_free + ) + # [group] candidate -> candidate [group] + # Check for buffers with successors in group and candidate last successor + # + # Buf that changes its last use snode, + # It was deallocated by candidate, + # but after swap it will be deallocated by group node. 
+ group_n_to_bufs_after_swap_dealloc_instead_of_candidate = ( + _find_buffers_with_changed_last_use_sink_waits( + candidate, gns, buf_to_snode_last_use + ) + ) + + potential_peak, _post_alloc_update, _size_free_delta_update = ( + _calculate_potential_peak_memory_sink_waits( + candidate, + gns, + group_head, + group_peak_memory, + candidate_delta_mem, + candidate_allocfree, + group_n_to_bufs_after_swap_dealloc_instead_of_candidate, + _curr_memory, + snodes_allocfree, + ) + ) + if ( + potential_peak - peak_memory + # pyrefly: ignore[unbound-name] + > peak_memory * config_comms.sink_iterative_peak_memory_budget + ): + info.limiting_factor = ( + f"peak memory new:{potential_peak} vs base:{peak_memory}" + ) + break + + info.moves += 1 + info.moves_info += f"+{candidate.get_name()}" + + _head = _perform_double_linked_list_swap_sink_waits( + candidate, group_head, group_tail, _prev, _next, _head + ) + + comm_time, comp_time, overlap_info = wait_exposed_communication_time( + _group_nodes_from_linked_list(_head, curr, _next), runtimes + ) + info.comm_time = comm_time + info.comp_time = comp_time + info.final_exposed = comm_time - comp_time + info.overlap_info = overlap_info + + _update_memory_tracking_after_swap_sink_waits( + candidate, + gns, + candidate_delta_mem, + candidate_allocfree, + group_n_to_bufs_after_swap_dealloc_instead_of_candidate, + _post_alloc_update, + _size_free_delta_update, + _curr_memory, + snodes_allocfree, + ) + + if debug_iterative_memory_recompute: + from .comms_debug import _debug_iterative_memory_recompute + + iterative_recompute_error = _debug_iterative_memory_recompute( + candidate, + gns, + _group_names(gns), + _group_nodes_from_linked_list(_head, None, _next), + name_to_freeable_input_buf, + graph_outputs, + peak_memory, + _curr_memory, + snodes_allocfree, + "sink_waits_iterative", + group_n_to_bufs_after_swap_dealloc_instead_of_candidate, + ) + if iterative_recompute_error: + break + + candidate = _next[group_tail] + curr = _prev_curr + + 
new_snodes = _format_and_log_sink_waits_stats( + stats, + _head, + _next, + original_snodes_num, + peak_memory, + name_to_freeable_input_buf, + graph_outputs, + ) + + return new_snodes, stats + + +def sink_waits_iterative(snodes: list[BaseSchedulerNode]) -> list[BaseSchedulerNode]: + """ + Similarly to reorder_communication_preserving_peak_memory this pass will try to iteratively + push Wait nodes later, recomputing estimated peak memory before each swap, + and preventing peak memory regressions. + + Pass will be applied to every Wait node. If there are immediate dependencies with next node, + pass will try to group them together and on the next step to swap the group with next candidate. + + If _inductor.config_comms.sink_iterative_use_runtime_estimations is set True, + pass will stop reordering of Wait once corresponding Collective is unexposed, + based on runtime estimations. + + inductor.config_comms.sink_iterative_peak_memory_budget allows to tune how much pass + can regress initial peak memory. + E.g.: + sink_iterative_peak_memory_budget == 0.0 - No regression of initial peak memory is allowed + sink_iterative_peak_memory_budget == 0.2 - Pass can improve comm-compute overlap, sacrificing + 20% of initial peak memory value. 
+ + inductor.config_comms.sink_iterative_extra_comm_comp_overlap config allows to more aggressively + sink waits, stopping only when overlap_compute >= (1 + extra_comm_comp_overlap) * comm_time + """ + return _sink_waits_iterative_internal(snodes)[0] + + +def estimate_op_runtime(snode: BaseSchedulerNode) -> float: + """ + Returns estimated op runtime in milliseconds (ms) + """ + if config.estimate_op_runtime == "default": + runtime = snode.get_estimated_runtime() + else: + assert callable(config.estimate_op_runtime) + runtime = config.estimate_op_runtime(snode) + return runtime + + +def node_summary(snode): + snodes = snode.get_nodes() + if len(snodes) == 1: + detail = "" + if isinstance(snode.node, (ir.ExternKernelOut, ir._CollectiveKernel)): + outs_str = f"outs:{[o.get_name() for o in snode.get_outputs()]}" + ins_str = f"ins:{[d.name for d in snode.unmet_dependencies]}" + detail = f" {snode.get_name()} ({snode.node.python_kernel_name})\n {outs_str}({ins_str})" + layouts = [child.node.get_output_spec() for child in snode.get_nodes()] + out_tensor_info = ",".join( + [ + f" (size={layout.size}, stride={layout.stride})" + if isinstance(layout, ir.Layout) + else "" + for layout in layouts + ] + ) + try: + node_name = snode.node.maybe_get_name() + except AttributeError: + # TODO: node_summary was written without FusedSchedulerNode in mind, generally needs to be hardened + node_name = "" + return f"{snode.node.__class__.__name__}{detail}{out_tensor_info} ({node_name} ({snode.get_estimated_runtime():.0f} ns)" + + # Flatten the summaries for Fused/Foreach/Grouped nodes + summaries = [] + for child_snode in snodes: + summaries.append(node_summary(child_snode)) + return f"{snode.__class__.__name__}: {', '.join(summaries)}" + + +def visualize_overlap(order): + # TODO - this function probably doesn't do a very good job estimating the runtime because it doesn't carefully model + # streams and overlap. For now its mostly useful as a debug visualization. 
+ + total_est_runtime: float = 0.0 + cur_comm_node = None + + def step_log(step, msg): + overlap_log.debug(f"{step:>6}: {msg}") # noqa: G004 + + for step, snode in enumerate(order): + if cur_comm_node is None: + if contains_collective(snode): + total_est_runtime += estimate_op_runtime(snode) + cur_comm_node = snode.node + elif is_wait(snode.node): + # raise AssertionError( + # "Wait is not expected when there is no collective running" + # ) + pass + else: # exposed compute op + total_est_runtime += estimate_op_runtime(snode) + step_log(step, f"{node_summary(snode)}") + else: # cur_comm_node is not None + if contains_collective(snode): + total_est_runtime += estimate_op_runtime(snode) + cur_comm_node = snode.node + step_log(step, f"{node_summary(snode)}") # noqa: G004 + elif is_wait(snode.node): # end of this comm op + step_log(step, f"{node_summary(snode)}") + cur_comm_node = None + else: # overlapped compute op + step_log(step, f"| {node_summary(snode)}") + overlap_log.debug( + f"Est. runtime (ms): {total_est_runtime / 1000 / 1000}" # noqa: G004 + ) + + +def reorder_compute_and_comm_for_overlap( + snodes: list[BaseSchedulerNode], +) -> list[BaseSchedulerNode]: + order = snodes + graph_inputs: OrderedSet[str] = OrderedSet(V.graph.graph_inputs.keys()) + graph_outputs: OrderedSet[str] = OrderedSet(V.graph.get_output_names()) + for p in config.reorder_for_compute_comm_overlap_passes: + if isinstance(p, str) and p in globals(): + p = globals()[p] # it is a builtin pass + assert callable(p), ( + f"Invalid reorder_compute_and_comm_for_overlap pass: {p} is not callable" + ) + peak_memory, _ = estimate_peak_memory( + snodes, get_freeable_input_buf(snodes, graph_inputs), graph_outputs + ) + if torch.distributed.get_rank() == 0: + overlap_log.debug( + f"==== Visualize overlap before reordering pass {p}, {peak_memory=} ====" # noqa: G004 + ) + try: + visualize_overlap(order) + except Exception as e: + overlap_log.debug("", exc_info=e) + t0 = time.time() + order = p(order) # 
type: ignore[operator] + t = time.time() - t0 + if torch.distributed.get_rank() == 0: + overlap_log.debug( + f"==== Visualize overlap after reordering pass {p} (ran in {t} sec)====" # noqa: G004 + ) + try: + visualize_overlap(order) + except Exception as e: + overlap_log.debug("", exc_info=e) + peak_memory, _ = estimate_peak_memory( + snodes, get_freeable_input_buf(snodes, graph_inputs), graph_outputs + ) + print(f"final {peak_memory=}") + # pyrefly: ignore [bad-return] + return order + + +def remove_fsdp2_unsharded_param_graph_input_usage(graph: torch.fx.Graph): + """ + This FX graph pass replaces uses of FSDP2 unsharded params with their corresponding + graph intermediates that were fsdp.copy_ into the unsharded params in the original graph. + + NOTE: Can only apply this pass to any of the FSDP2 unsharded params that have this pattern + (or repetition of): `resize_(full) -> copy_ -> resize_(0)`. Because of this, for partial-graph case + where `resize_(full) -> copy_` is in one graph and `resize_(0)` is in another graph, we can't + remove these resize and copy ops and thus we will have worse performance there. + + In other words, "do we try to remove all the resize_(full) -> copy_ -> resize_(0) nodes for this unsharded param" + is actually a per-unsharded-param decision, since for each unsharded param, we look at its resize sequence pattern + (in `check_resize_pattern()`) to determine if its set of resize and copy nodes can be removed. 
+ """ + node_list = list(graph.nodes) + + # Find all graph inputs and their resize counts + graph_input_to_resized_to_full_node_idxes = defaultdict(list) + graph_input_to_resized_to_0_node_idxes = defaultdict(list) + for idx, node in enumerate(node_list): + if ( + node.op == "call_function" + and node.target is torch.ops.inductor.resize_storage_bytes_.default + ): + assert node.args[0].op == "placeholder", f"""\ +Resize can only operate on graph inputs, but got {node} which is resizing non-graph-input {node.args[0]} +""" + graph_input = node.args[0] + new_size = node.args[1] + if new_size > 0: + graph_input_to_resized_to_full_node_idxes[graph_input].append(idx) + else: + graph_input_to_resized_to_0_node_idxes[graph_input].append(idx) + + def check_resize_pattern(graph_input): + # Check the number of resize-to-full and resize-to-0 nodes are equal, + # and that for each (resize-to-full, resize-to-0) pair, the resize-to-full node + # always happens before the resize-to-0 node. + # This is the precondition for being able to remove all the resize and copy nodes + # for this specific unsharded param. + resized_to_full_idxes = graph_input_to_resized_to_full_node_idxes.get( + graph_input, [] + ) + resized_to_0_idxes = graph_input_to_resized_to_0_node_idxes.get(graph_input, []) + + if len(resized_to_full_idxes) != len(resized_to_0_idxes): + log.warning( + f""" +Unequal number of resize-to-full and resize-to-0 nodes for graph input {graph_input}: +{len(resized_to_full_idxes)} vs. {len(resized_to_0_idxes)}. +Skipping `remove_fsdp2_unsharded_param_graph_input_usage` FX graph pass. 
+""" # noqa: G004 + ) + return False + + # Check the sequence: (resize_to_full -> resize_to_0)+ + for resize_to_full_idx, resize_to_0_idx in zip( + resized_to_full_idxes, resized_to_0_idxes + ): + if resize_to_full_idx >= resize_to_0_idx: + log.warning( + f""" +For graph input {graph_input}: resize-to-full node {node_list[resize_to_full_idx]} at index {resize_to_full_idx} +happens after resize-to-0 node {node_list[resize_to_0_idx]} at index {resize_to_0_idx}. +Skipping `remove_fsdp2_unsharded_param_graph_input_usage` FX graph pass for that unsharded param. +""" # noqa: G004 + ) + return False + return True + + # Find all eligible unsharded params and their corresponding graph intermediates. + unsharded_param_to_fsdp_copy_node_idxes = defaultdict(list) + for idx, node in enumerate(node_list): + if node.op == "call_function" and node.target is torch.ops.fsdp.copy_.default: + fsdp_copy_node = node + unsharded_param = node.args[0] + assert unsharded_param.op == "placeholder", f""" +Assumed all FSDP2 `unsharded_param`s to be graph input, but it's not true! +Offending node: {unsharded_param}. Graph: {graph} +""" + if check_resize_pattern(unsharded_param): + unsharded_param_to_fsdp_copy_node_idxes[unsharded_param].append(idx) + + def is_allowed_mutation(node): + return ( + node.target is torch.ops.fsdp.copy_.default + or node.target is torch.ops.inductor.resize_storage_bytes_.default + ) + + def is_node_mutating_unsharded_param_or_its_alias(node, unsharded_params): + # Check whether the node is mutating any of the unsharded params or their aliases. 
+ mutated_arg_idxes = ( + [ + i + for i, x in enumerate(node.target._schema.arguments) + if x.alias_info is not None and x.alias_info.is_write + ] + if isinstance(node.target, torch._ops.OpOverload) + else [] + ) + mutated_node_arg_storages = OrderedSet( + [ + StorageWeakRef(node.args[i].meta["val"].untyped_storage()) + for i in mutated_arg_idxes + ] + ) + storages_of_unsharded_params = OrderedSet( + [ + StorageWeakRef(unsharded_param.meta["val"].untyped_storage()) + for unsharded_param in unsharded_params + ] + ) + return len(mutated_node_arg_storages & storages_of_unsharded_params) > 0 + + # Check no user mutation on any unsharded_param + for node in node_list: + if ( + node.op == "call_function" + and isinstance(node.target, torch._ops.OpOverload) + and node.target._schema.is_mutable + and not is_allowed_mutation(node) + ): + assert not is_node_mutating_unsharded_param_or_its_alias( + node, unsharded_param_to_fsdp_copy_node_idxes.keys() + ), f"""\ +User mutation on FSDP2 unsharded param is not allowed when Traceable FSDP2 is used. Violating node: {node} +""" + + # For each `fsdp.copy_(unsharded_param, Y)`, replace downstream usage of `unsharded_param` with `Y`. + # + # NOTE: Because of "layer reuse" use case, there could be multiple `fsdp.copy_` to the same `unsharded_param` graph input. + # e.g. + # ``` + # fsdp_copy_1 = fsdp.copy_(unsharded_param_1, Y1) + # ... (use of unsharded_param_1) -> Subgraph 1 + # fsdp_copy_2 = fsdp.copy_(unsharded_param_1, Y2) + # ... (use of unsharded_param_1) -> Subgraph 2 + # fsdp_copy_3 = fsdp.copy_(unsharded_param_1, Y3) + # ... (use of unsharded_param_1) -> Subgraph 3 + # ``` + # We must do the replacement only within each subgraph. 
+ for ( + unsharded_param, + fsdp_copy_node_idxes, + ) in unsharded_param_to_fsdp_copy_node_idxes.items(): + for i, fsdp_copy_node_idx in enumerate(fsdp_copy_node_idxes): + fsdp_copy_node = node_list[fsdp_copy_node_idx] + assert fsdp_copy_node.args[0] is unsharded_param + _, replacement = fsdp_copy_node.args + # subgraph_start_idx is exclusive + subgraph_start_idx = fsdp_copy_node_idx + 1 + # subgraph_end_idx is exclusive (also intentionally don't replace args in return op) + subgraph_end_idx = ( + fsdp_copy_node_idxes[i + 1] + if i < len(fsdp_copy_node_idxes) - 1 + else len(node_list) - 1 + ) + subgraph_nodes = node_list[subgraph_start_idx:subgraph_end_idx] + assert not any( + is_node_mutating_unsharded_param_or_its_alias(node, [unsharded_param]) + for node in subgraph_nodes + ), f"""\ +Assumed no ops mutating unsharded param {unsharded_param} in subgraph {subgraph_nodes}, but it's not true! +Graph: {graph} +""" + for node in subgraph_nodes: + if ( + node.op == "call_function" + and unsharded_param in node.args + and node.target != torch.ops.inductor.resize_storage_bytes_.default + ): # TODO(yf225): implement replacement in kwargs + new_args = tuple( + replacement if arg is unsharded_param else arg + for arg in node.args + ) + node.args = new_args + + # Delete `fsdp.copy_(unsharded_param, Y)` nodes + for fsdp_copy_node_idxes in unsharded_param_to_fsdp_copy_node_idxes.values(): + for fsdp_copy_node_idx in fsdp_copy_node_idxes: + fsdp_copy_node = node_list[fsdp_copy_node_idx] + graph.erase_node(fsdp_copy_node) + + # Delete `resize_(unsharded_param, ...)` nodes + for node in node_list: + if ( + node.op == "call_function" + and node.target is torch.ops.inductor.resize_storage_bytes_.default + and node.args[0] in unsharded_param_to_fsdp_copy_node_idxes + ): + graph.erase_node(node) + + +def reinplace_fsdp_all_gather(graph: torch.fx.Graph) -> None: + try: + import torch.distributed.fsdp._fully_shard._fsdp_collectives + + assert torch.distributed.is_available() + # 
Assert existence of these ops + assert ( + torch.ops._c10d_functional.all_gather_into_tensor + and torch.ops._c10d_functional.all_gather_into_tensor_out + ) + except (ImportError, AttributeError, AssertionError): + return + + from .pattern_matcher import ( + CallFunction, + KeywordArg, + Match, + PatternMatcherPass, + register_graph_pattern, + ) + + """ + all_gather_copy_in = torch.ops.fsdp.all_gather_copy_in.default(...); + getitem = all_gather_copy_in[0]; + (getitem_1 = all_gather_copy_in[1];) # optional + + all_gather_into_tensor = torch.ops._c10d_functional.all_gather_into_tensor.default(getitem, ...); + + -> + + all_gather_copy_in = torch.ops.fsdp.all_gather_copy_in.default(...); + getitem = all_gather_copy_in[0]; + getitem_1 = all_gather_copy_in[1]; + + all_gather_into_tensor = torch.ops._c10d_functional.all_gather_into_tensor_out.default(getitem, ..., out=getitem_1); + """ + + def remove_unused_getitem(g): + # Remove `getitem_X = all_gather_copy_in[1]` which is never used. + node_list = list(g.nodes) + for n in node_list: + if ( + n.target is operator.getitem + and n.args[0].target is torch.ops.fsdp.all_gather_copy_in.default + and n.args[1] == 1 + ): + g.erase_node(n) + + graph_pass = PatternMatcherPass() + + @register_graph_pattern( + CallFunction( + torch.ops._c10d_functional.all_gather_into_tensor.default, + CallFunction( + operator.getitem, + CallFunction( + torch.ops.fsdp.all_gather_copy_in.default, + KeywordArg("all_gather_inputs"), + KeywordArg("all_gather_output"), + KeywordArg("inp_split_sizes"), + KeywordArg("all_gather_input_numel"), + KeywordArg("rank"), + ), + KeywordArg("item_idx"), + ), + KeywordArg("group_size"), + KeywordArg("group_name"), + ), + # pyrefly: ignore [bad-argument-type] + pass_dict=graph_pass, + extra_check=lambda match: match.kwargs["item_idx"] == 0, + ) + def reinplace_all_gather(match: Match, *args, **kwargs): + def repl( + *args, + ): + copy_in_args = args[:-2] + group_size = args[-2] + group_name = args[-1] + 
all_gather_copy_in = torch.ops.fsdp.all_gather_copy_in.default( + *copy_in_args + ) + getitem = all_gather_copy_in[0] + getitem_1 = all_gather_copy_in[1] + all_gather_into_tensor = ( + torch.ops._c10d_functional.all_gather_into_tensor_out.default( + getitem, group_size, group_name, out=getitem_1 + ) + ) + return all_gather_into_tensor + + match.replace_by_example( + # pyrefly: ignore [bad-argument-type] + repl, + [ + kwargs["all_gather_inputs"], + kwargs["all_gather_output"], + kwargs["inp_split_sizes"], + kwargs["all_gather_input_numel"], + kwargs["rank"], + kwargs["group_size"], + kwargs["group_name"], + ], + ) + + remove_unused_getitem(graph) + graph_pass.apply(graph) # type: ignore[arg-type] + + +def get_op_idx(snode): + assert not isinstance( + snode, + ( + torch._inductor.scheduler.FusedSchedulerNode, + torch._inductor.scheduler.GroupedSchedulerNode, + ), + ) + return int(snode.get_name()[2:]) + + +def enforce_comm_ordering_for_fsdp( + snodes: list[torch._inductor.scheduler.BaseSchedulerNode], + name_to_buf: dict[str, torch._inductor.scheduler.SchedulerBuffer], + name_to_fused_node: dict[str, BaseSchedulerNode], +) -> list[torch._inductor.scheduler.BaseSchedulerNode]: + from . 
import scheduler + + new_order: list[BaseSchedulerNode] = [] + scheduled = OrderedSet[Any]() + ag_exists = False + rs_exists = False + ag_grouped_node_to_wait_grouped_node = {} + rs_grouped_node_to_wait_grouped_node = {} + snode_name_to_final_snode = {} + + def _create_group_node(snodes_to_group): + group_node = scheduler.GroupedSchedulerNode.create(snodes_to_group) + for snode in snodes_to_group: + snode_name_to_final_snode[snode.get_name()] = group_node + snode_name_to_final_snode[group_node.get_name()] = group_node + return group_node + + # Create grouped nodes for specific sets of ops + for snode in snodes: + # Case 1: Handle AllGather + if is_collective( + snode.node, op=torch.ops._c10d_functional.all_gather_into_tensor_out.default + ) and any( + is_fallback_op( + name_to_fused_node[x].node, torch.ops.fsdp.all_gather_copy_in.default + ) + for x in snode.ancestors + ): + ag_exists = True + ag_snode = snode + ag_related_snode_set: OrderedSet[scheduler.BaseSchedulerNode] = OrderedSet() + + # Find the "cast + copy_in + getitem + all_gather" code block + find_recursive_deps_of_node( + ag_snode, + ag_related_snode_set, + name_to_buf, + name_to_fused_node, + ) + + # Find the "all_gather + all_gather_wait_tensor + copy_out" code block + allowed_ops = OrderedSet( + [ + torch.ops._c10d_functional.all_gather_into_tensor_out.default, + torch.ops._c10d_functional.wait_tensor.default, + torch.ops.fsdp.split_with_sizes_copy.default, + ] + ) + find_recursive_users_of_node( + ag_snode, + ag_related_snode_set, + name_to_buf, + name_to_fused_node, + criteria_cb=lambda x: not ( + isinstance(x, scheduler.NopKernelSchedulerNode) + or ( + isinstance(x, scheduler.ExternKernelSchedulerNode) + and x.node.op_overload in allowed_ops # type: ignore[union-attr] + ) + ), + ) + + # sort nodes by original operation order + ag_related_snodes = sorted( + ag_related_snode_set, key=lambda x: get_op_idx(x) + ) + + # In the "reuse layer" case, some ops in the 2nd all-gather code block could also + 
# depend on ops in the 1st all-gather code block, and we don't want to group them together. + end_idx_of_current_ag_block = len(ag_related_snodes) + copy_out_count = 0 + for i in range(len(ag_related_snodes)): + cur_snode = ag_related_snodes[i] + if is_fallback_op( + cur_snode.node, torch.ops.fsdp.split_with_sizes_copy.default + ): + copy_out_count += 1 + if copy_out_count > 1: + end_idx_of_current_ag_block = i + break + + ag_related_snodes = ag_related_snodes[:end_idx_of_current_ag_block] + + # Group "cast + copy_in + getitem + all_gather" into one GroupedSchedulerNode + wait_node_idx = None + for i in range(len(ag_related_snodes) - 1): + if isinstance(ag_related_snodes[i + 1].node, ir._WaitKernel): + wait_node_idx = i + 1 + break + assert wait_node_idx is not None + ag_group_node = _create_group_node(ag_related_snodes[:wait_node_idx]) + + # Group "all_gather_wait_tensor + copy_out" into one GroupedSchedulerNode + ag_wait_group_node = _create_group_node(ag_related_snodes[wait_node_idx:]) + + ag_grouped_node_to_wait_grouped_node[ag_group_node] = ag_wait_group_node + + # Case 2: Handle ReduceScatter + elif is_fallback_op(snode.node, torch.ops.fsdp.chunk_cat.default): + rs_exists = True + rs_snode = snode + + # Find the "reduce_scatter copy-in + reduce_scatter comm + reduce_scatter wait" code block + rs_related_snode_set: OrderedSet[scheduler.BaseSchedulerNode] = OrderedSet() + find_recursive_users_of_node( + rs_snode, + rs_related_snode_set, + name_to_buf, + name_to_fused_node, + ) + + # sort nodes by original operation order + rs_related_snodes = sorted( + rs_related_snode_set, key=lambda x: get_op_idx(x) + ) + + # Group "reduce_scatter copy-in + reduce_scatter comm" into one GroupedSchedulerNode + wait_node_idx = None + for i in range(len(rs_related_snodes) - 1): + if isinstance(rs_related_snodes[i + 1].node, ir._WaitKernel): + wait_node_idx = i + 1 + break + assert wait_node_idx is not None + rs_group_node = _create_group_node(rs_related_snodes[:wait_node_idx]) + 
+ # Group "reduce_scatter wait + related output nodes" into one GroupedSchedulerNode + rs_wait_group_node = _create_group_node(rs_related_snodes[wait_node_idx:]) + + rs_grouped_node_to_wait_grouped_node[rs_group_node] = rs_wait_group_node + + assert len(snode_name_to_final_snode) > 0 + if ag_exists: + assert len(ag_grouped_node_to_wait_grouped_node) > 0 + if rs_exists: + assert len(rs_grouped_node_to_wait_grouped_node) > 0 + + # Build the new node schedule, taking GroupedSchedulerNode into account + for snode in snodes: + if snode.get_name() in snode_name_to_final_snode: + snode = snode_name_to_final_snode[snode.get_name()] + if snode in scheduled: + continue + new_order.append(snode) + scheduled.add(snode) + + # Enforce AllGather ordering: previous AllGather's "wait then copy_out" group node must run + # before next AllGather's "copy_in then AG" group node + prev_ag_wait = None + for ag_group_node, wait_group_node in ag_grouped_node_to_wait_grouped_node.items(): + if prev_ag_wait is not None: + mutating_buf = next(iter(ag_group_node.get_buffer_names())) + for o in prev_ag_wait.get_outputs(): + ag_group_node.add_fake_dep( + WeakDep(o.get_name(), mutating_buf=mutating_buf, is_fake=True) + ) + prev_ag_wait = wait_group_node + + # Enforce ReduceScatter ordering: previous ReduceScatter's "wait" group node must run + # before next ReduceScatter's "copy_in then RS" group node + prev_rs_wait = None + for rs_group_node, wait_group_node in rs_grouped_node_to_wait_grouped_node.items(): + if prev_rs_wait is not None: + mutating_buf = next(iter(rs_group_node.get_buffer_names())) + for o in prev_rs_wait.get_outputs(): + rs_group_node.add_fake_dep( + WeakDep(o.get_name(), mutating_buf=mutating_buf, is_fake=True) + ) + prev_rs_wait = wait_group_node + + return new_order # type: ignore[return-value] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/comms_debug.py 
b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/comms_debug.py new file mode 100644 index 0000000000000000000000000000000000000000..20c9779a4ef3f0e75e35c602b99f64e3df285c60 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/comms_debug.py @@ -0,0 +1,112 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Union + +from torch._logging import trace_structured + +from .memory import estimate_peak_memory_allocfree + + +if TYPE_CHECKING: + from torch.utils._ordered_set import OrderedSet + + from .memory import FreeableInputBuffer, SNodeMemory + from .scheduler import BaseSchedulerNode, SchedulerBuffer + + +def _debug_iterative_memory_recompute( + candidate: BaseSchedulerNode, + gns: list[BaseSchedulerNode], + group_names: str, + snodes: list[BaseSchedulerNode], + name_to_freeable_input_buf: dict[str, FreeableInputBuffer], + graph_outputs: OrderedSet[str], + peak_memory: int, + iter_curr_memory: dict[BaseSchedulerNode, tuple[int, int]], + snodes_allocfree: dict[BaseSchedulerNode, SNodeMemory], + tlparse_name: str, + gn_to_bufs_last_use: dict[ + BaseSchedulerNode, list[Union[FreeableInputBuffer, SchedulerBuffer]] + ], +) -> bool: + iterative_recompute_error = False + candidate_allocfree = snodes_allocfree[candidate] + est_peak_memory, snodes_curr_memory, snodes_allocfree, _ = ( + estimate_peak_memory_allocfree( + snodes, name_to_freeable_input_buf, graph_outputs + ) + ) + est_curr_memory = dict(zip(snodes, snodes_curr_memory)) + iter_cm = iter_curr_memory[candidate] + new_cm = est_curr_memory[candidate] + log = "" + if est_peak_memory > peak_memory: + log = "ITERATIVE PEAK DOES NOT MATCH" + iterative_recompute_error = True + if iter_cm != new_cm: + log = "ITERATIVE CURR MEMORY CANDIDATE DOES NOT MATCH" + iterative_recompute_error = True + for gn in gns: + iter_gnm = iter_curr_memory[gn] + new_gnm = est_curr_memory[gn] + if iter_gnm != new_gnm: + log = f"ITERATIVE GN CURR MEMORY 
DOES NOT MATCH:{gn.get_name()}" + iterative_recompute_error = True + if iterative_recompute_error: + log += ( + f"\nCANDIDATE:{candidate.get_name()}" + f"\nGROUP:{group_names}" + f"\nPEAK_MEMORY_BEFORE:{peak_memory}" + f"\nPEAK_MEMORY_AFTER_SWAP:{est_peak_memory}" + f"\nCANDIDATE:{candidate.debug_str()}" + f"\nCANDIDATE_ITER_CURR_MEMORY:{iter_cm}" + f"\nCANDIDATE_NEW__CURR_MEMORY:{new_cm}" + f"\nCANDIDATE_ITER_ALLOCFREE:{candidate_allocfree}" + f"\nCANDIDATE_NEW_ALLOCFREE:{snodes_allocfree[candidate]}" + ) + peak_log = "" + for i, (pre, _post) in enumerate(snodes_curr_memory): + if est_peak_memory == pre: + n = snodes[i] + peak_log = ( + f"\nNEW_PEAK:{est_peak_memory}(BASE:{peak_memory})" + f" @ SNODE[{i}/{len(snodes)}]:{n.get_name()} {n.debug_str()}" + ) + break + group_log = "" + for i, gn in enumerate(gns): + iter_gnm = iter_curr_memory[gn] + new_gnm = est_curr_memory[gn] + group_log += ( + f"\nGROUP_NODE[{i}]:{gn.debug_str()}" + f"\nGROUP_NODE[{i}] ITER_GNM[{gn.get_name()}]:{iter_gnm}" + f"\nGROUP_NODE[{i}] ESTM_GNM[{gn.get_name()}]:{new_gnm}" + f"\nGROUP_NODE[{i}] ITER_allocfree:{snodes_allocfree[gn]}" + f"\nGROUP_NODE[{i}] ESTM_allocfree:{snodes_allocfree[gn]}" + ) + log += peak_log + log += group_log + log += f"\nGN_TO_BUFS_LAST_USE:{gn_to_bufs_last_use}" + log += "\n\n".join( + [ + ( + f"\nSNODE[{i}]\n{n.debug_str()}" + f"\nITER_cur_mem:{iter_curr_memory[n]}" + f"\nESTM_cur_mem:{est_curr_memory[n]}" + f"\nITER_allocfree:{snodes_allocfree[n]}" + f"\nESTM_allocfree:{snodes_allocfree[n]}" + ) + for i, n in enumerate(snodes) + ] + ) + tname = f"{tlparse_name}_ITERATIVE_RECOMPUTE_ERROR" + print(f"{tname}:\n{log}") + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": tname, + "encoding": "string", + }, + payload_fn=lambda: log, + ) + return iterative_recompute_error diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/compile_fx.py 
b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/compile_fx.py new file mode 100644 index 0000000000000000000000000000000000000000..ea740d1493dc7fb797ecb8db79d3001c5d0589e9 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/compile_fx.py @@ -0,0 +1,3006 @@ +from __future__ import annotations + +import contextlib +import copy +import enum +import functools +import io +import itertools +import json +import logging +import os +import sys +import time +import warnings +from abc import ABC, abstractmethod +from collections import defaultdict +from contextlib import AbstractContextManager +from dataclasses import dataclass +from inspect import currentframe +from itertools import count +from operator import attrgetter +from typing import Any, Optional, TYPE_CHECKING, TypeVar, Union +from typing_extensions import Never, override, ParamSpec, Protocol, TypedDict, Unpack +from unittest import mock + +import torch._inductor.async_compile +import torch.fx +import torch.utils._pytree as pytree +from functorch.compile import min_cut_rematerialization_partition +from torch import fx +from torch._dispatch.python import enable_python_dispatcher +from torch._dynamo import ( + compiled_autograd, + config as dynamo_config, + logging as dynamo_logging, + utils as dynamo_utils, +) +from torch._dynamo.device_interface import get_interface_for_device +from torch._dynamo.repro.after_aot import wrap_compiler_debug +from torch._dynamo.utils import ( + chromium_event_timed, + CompileEventLogger, + counters, + detect_fake_mode, + dynamo_timed, + flatten_graph_inputs, + get_metrics_context, + lazy_format_graph_code, + set_feature_use, +) +from torch._functorch import config as functorch_config +from torch._functorch._aot_autograd.subclass_parametrization import ( + unwrap_tensor_subclass_parameters, +) +from torch._functorch.aot_autograd import ( + aot_export_module, + GraphOutputName, + make_boxed_func, + 
SerializableAOTDispatchCompiler, +) +from torch._inductor.codecache import code_hash, FxGraphCache, output_code_log +from torch._inductor.cudagraph_utils import ( + BoxedDeviceIndex, + format_default_skip_message, + log_cudagraph_skip_and_bump_counter, + PlaceholderInfo, +) +from torch._inductor.custom_graph_pass import CustomPartitionerFn +from torch._inductor.debug import ( + create_mapping_pre_post_grad_nodes, + save_args_for_compile_fx_inner, +) +from torch._inductor.output_code import ( + CompiledAOTI, + CompiledFxGraph, + CompiledFxGraphConstantsWithGm, + get_expanded_dims, + index_expanded_dims, + OutputCode, +) +from torch._inductor.runtime.cache_dir_utils import cache_dir +from torch._inductor.utils import ( + BoxedBool, + count_tangents, + fresh_cache, + get_all_devices, + InputType, + is_gpu, + should_assume_input_aligned, + should_use_remote_fx_graph_cache, + tensor_is_aligned, +) +from torch._library.fake_class_registry import FakeScriptObject +from torch._library.opaque_object import is_opaque_type +from torch._logging import trace_structured +from torch._utils_internal import compile_time_strobelight_meta +from torch.fx import GraphModule +from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols, SymExprPrinter +from torch.fx.passes.fake_tensor_prop import FakeTensorProp +from torch.monitor import _WaitCounter +from torch.utils._ordered_set import OrderedSet + +from .._dynamo.backends.common import aot_autograd +from .._dynamo.exc import ShortenTraceback, SkipFrame +from ..fx._lazy_graph_module import _use_lazy_graph_module +from ..fx.graph import _PyTreeCodeGen +from ..utils._triton import has_triton +from . 
import config, distributed_autotune, metrics +from .codegen.common import get_wrapper_codegen_for_device, init_backend_registration +from .debug import DebugContext +from .decomposition import select_decomp_table +from .exc import InductorError +from .fx_passes.joint_graph import joint_graph_passes +from .fx_passes.post_grad import post_grad_passes, view_to_reshape +from .fx_passes.pre_grad import pre_grad_passes +from .graph import GraphLowering +from .ir import get_device_type, IRNode +from .output_code import complex_memory_overlap # noqa: F401 +from .triton_bundler import TritonBundler +from .utils import ( + align_inputs_from_check_idxs, + clone_preserve_strides, + copy_misaligned_inputs, + get_cloned_parameter_buffer_name, + get_first_incompatible_cudagraph_node, + maybe_get_suppress_shape_guards_ctx, + output_node, + remove_unaligned_input_idxs, + shape_env_from_inputs, +) +from .virtualized import V + + +if TYPE_CHECKING: + from collections.abc import Callable, Generator, Sequence + + from torch._inductor.output_code import _StrideExprStr + from torch._ops import OpOverload + from torch.export.pt2_archive._package_weights import Weights + + from .ir import ExternKernelNode + + +_P = ParamSpec("_P") +_T = TypeVar("_T") + +if TYPE_CHECKING or not config.is_fbcode(): + # no-op decorator + def time_and_log(attr: str) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: + return dynamo_utils.identity + + def log_optimus_to_scuba(*args: object, **kwargs: object) -> None: + pass + +else: + from torch._inductor.fb.utils import log_optimus_to_scuba, time_and_log + +if TYPE_CHECKING: + import types + + from torch._functorch._aot_autograd.schemas import ( + FQN, + GraphInputName, + GraphSignature, + ) + + CompileFxOutput = Union[ + Callable[[list[object]], Sequence[torch.Tensor]], + str, + list[str], + Weights, + ] + + +class FxCompileMode(enum.Enum): + NORMAL = 0 + # For testing - use the serde FxCompile scheme to debug serialization and + # deserialization of 
GraphMoule and CompiledFxGraph. + SERIALIZE = 1 + # Compile using a subprocess instead of in-process. + SUBPROCESS = 2 + + +@dataclass +class FxCompileConfig: + mode: FxCompileMode + use_async: bool + use_progressive: bool + + +def _fx_compile_mode_default() -> FxCompileConfig: + name = "TORCHINDUCTOR_FX_COMPILE_MODE" + value = os.environ.get(name) + if value is None: + return FxCompileConfig(FxCompileMode.NORMAL, False, False) + + use_async = False + use_progressive = False + + if value.lower().startswith("progressive+"): + use_progressive = True + value = value[12:] + if value.lower().startswith("async+"): + use_async = True + value = value[6:] + + try: + value = value.upper() + return FxCompileConfig(FxCompileMode[value], use_async, use_progressive) + except KeyError: + import logging + + log = logging.getLogger(__name__) + log.error( + "Invalid value of %s for %s. Expected one of %s. Using default.", + value, + name, + ", ".join(sorted(repr(x) for x in FxCompileMode.__members__)), + ) + # Remove from the environment so subprocesses don't ALSO complain. 
+ os.environ.pop(name) + return FxCompileConfig(FxCompileMode.NORMAL, False, False) + + +def _get_progression_configs() -> list[dict[str, Any]]: + # TODO make this configurable + return [ + {"max_autotune": True}, + ] + + +_fx_compile_config = _fx_compile_mode_default() +fx_compile_mode = _fx_compile_config.mode +fx_compile_async = _fx_compile_config.use_async +fx_compile_progressive = _fx_compile_config.use_progressive + +log = logging.getLogger(__name__) +perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints") +pre_grad_graphs_log = torch._logging.getArtifactLogger(__name__, "pre_grad_graphs") +post_grad_graphs_log = torch._logging.getArtifactLogger(__name__, "post_grad_graphs") +static_inputs_log = torch._logging.getArtifactLogger( + __name__, "cudagraph_static_inputs" +) +inductor_metrics_log = torch._logging.getArtifactLogger(__name__, "inductor_metrics") + + +def get_static_input_idxs(num_fixed: int) -> list[int]: + # If we are inlining NNModules, we treat all torch.nn.Parameters as static for the purposes + # of cudagraphs. Rather than copying these into cudagraph-owned memory + # like we do for normal inputs on each run, we will re-record a cudagraph if these + # parameter locations change. 
+ context = torch._guards.TracingContext.try_get() + fixed = list(range(num_fixed)) + if not context or not context.fw_metadata: + return fixed + + return context.fw_metadata.static_input_indices + + +def record_original_output_strides(gm: GraphModule) -> None: + output_node = gm.graph.find_nodes(op="output")[0] + output_strides = [] + + if not isinstance(output_node.args[0], torch.fx.Node): + output_node_args = output_node.args[0] + else: + output_node_args = output_node.args + + for output in output_node_args: + if ( + isinstance(output, torch.fx.Node) + and (val := output.meta.get("val")) is not None + and isinstance(val, torch.Tensor) + ): + output_strides.append(val.stride()) + else: + # pyrefly: ignore [bad-argument-type] + output_strides.append(None) + output_node.meta["original_output_strides"] = output_strides + + +def _recursive_record_original_output_strides(gm: GraphModule) -> None: + # invoke_subgraph HOP requires output strides to be respected + for node in gm.graph.find_nodes( + op="call_function", target=torch.ops.higher_order.invoke_subgraph + ): + subgraph = getattr(gm, node.args[0].target) + _recursive_record_original_output_strides(subgraph) + + record_original_output_strides(gm) + + +def _recursive_record_user_visible_output_idxs(gm: GraphModule) -> None: + # invoke_subgraph HOP requires output strides to be respected + for node in gm.graph.find_nodes( + op="call_function", target=torch.ops.higher_order.invoke_subgraph + ): + subgraph = getattr(gm, node.args[0].target) + + for node in subgraph.graph.find_nodes(op="output"): + node.meta["user_visible_output_idxs"] = [ + idx + for idx in range(len(node.args[0])) + if isinstance(node.args[0][idx], torch.fx.Node) + ] + _recursive_record_user_visible_output_idxs(subgraph) + + +@functools.lru_cache(None) +def _step_logger() -> Callable[..., None]: + return dynamo_logging.get_step_logger(log) + + +@functools.cache +def _warn_tf32_disabled() -> None: + if ( + torch.cuda.is_available() + and not 
torch.backends.cuda.matmul.allow_tf32 + and torch.cuda.get_device_capability() >= (8, 0) + ): + warnings.warn( + "TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. " + "Consider setting `torch.set_float32_matmul_precision('high')` for better performance." + ) + + +def _resolve_name_collision(mod: GraphModule, gm: GraphModule) -> None: + """ + In aot_export_module (make_fx), we create get_attr nodes with name prefix + "_tensor_constant" and "_torchbind_obj". See Tracer.create_arg() in + torch/fx/_symbolic_trace.py + + However, this might result in name collision if the original mod already + has a different buffer with the same name. + + We resolve this potential name collision here by changing the target name + with a new number post fix. + """ + + existing_keys = OrderedSet( + [name for name, val in mod.named_parameters(remove_duplicate=False)] + ) + existing_keys.update( + OrderedSet([name for name, val in mod.named_buffers(remove_duplicate=False)]) + ) + + def find_smallest_i(graph: fx.Graph, prefix: str) -> int: + i = 0 + for node in graph.nodes: + if node.op == "get_attr" and node.target.startswith(prefix): + if len(node.target) > len(prefix): + post_fix = node.target.split(prefix)[-1] + if post_fix.isdigit(): + i = max(i, int(post_fix)) + for key in existing_keys: + if key.startswith(prefix): + if len(key) > len(prefix): + post_fix = key.split(prefix)[-1] + if post_fix.isdigit(): + i = max(i, int(post_fix)) + return i + 1 + + for node in gm.graph.nodes: + if node.op == "get_attr": + target_name = node.target + if not target_name.startswith( + "_tensor_constant" + ) and not target_name.startswith("_torchbind_obj"): + continue + + if not hasattr(mod, target_name): + continue + gm_target = attrgetter(target_name)(gm) + model_target = attrgetter(target_name)(mod) + if isinstance(gm_target, FakeScriptObject): + if ( + isinstance(model_target, FakeScriptObject) + and gm_target.real_obj is model_target.real_obj + ): + continue + 
elif ( + gm_target.device == model_target.device + and gm_target.dtype == model_target.dtype + and torch.equal(gm_target, model_target) + ): + # If tensors with same name from gm and model are indeed the same, we don't need to rename + # Check device first, to avoid torch.equal(wrapper_CUDA__equal) raise when different device + continue + + prefix = ( + "_tensor_constant" + if target_name.startswith("_tensor_constant") + else "_torchbind_obj" + ) + new_id = find_smallest_i(gm.graph, prefix) + new_target_name = f"{prefix}{new_id}" + node.target = new_target_name + setattr(gm, new_target_name, gm_target) + existing_keys.add(new_target_name) + + +def _unlift_graph( + mod: GraphModule, gm: GraphModule, graph_signature: GraphSignature +) -> GraphModule: + from torch.export.unflatten import _assign_attr, _AttrKind + + _resolve_name_collision(mod, gm) + + state_dict: dict[str, Union[torch.nn.parameter.Parameter, torch.Tensor]] = {} + for name, param in mod.named_parameters(remove_duplicate=False): + state_dict[name] = param + _assign_attr( + param, + gm, + name, + attr_kind=_AttrKind.PARAMETER, + ) + for name, buffer in mod.named_buffers(remove_duplicate=False): + state_dict[name] = buffer + _assign_attr( + buffer, + gm, + name, + attr_kind=_AttrKind.BUFFER, + ) + + placeholder_nodes = gm.graph.find_nodes(op="placeholder") + lifted_inputs: list[Optional[FQN]] = [] + + # In AOTI, module parameters and buffers are not lifted as graph inputs. + # As a result, mutation to buffers has side effect which makes their initial + # values different from Eager. So we clone them here as a copy. + # We are not cloning for parameters, although it will be needed if we want to + # support training. 
+ for node in placeholder_nodes: + node_name = node.name + if node_name in graph_signature.inputs_to_parameters: + parameter_name = graph_signature.inputs_to_parameters[node_name] + lifted_inputs.append(parameter_name) + elif node_name in graph_signature.inputs_to_buffers: + buffer_name = graph_signature.inputs_to_buffers[node_name] + lifted_inputs.append(buffer_name) + gm.meta[get_cloned_parameter_buffer_name(buffer_name)] = ( + clone_preserve_strides(state_dict[buffer_name]) + ) + else: + assert node_name in graph_signature.user_inputs + lifted_inputs.append(None) + + from torch.export._unlift import _unlift + + outputs: tuple[torch.fx.Node, ...] = tuple(gm.graph.output_node().args[0]) # type: ignore[arg-type] + mutated_outputs = [] + buffer_mutations = graph_signature.buffers_to_mutate + user_input_mutations = graph_signature.user_inputs_to_mutate + output_tokens = graph_signature.output_tokens + for idx, out in enumerate(outputs): + value: Optional[Union[FQN, GraphInputName]] = None + + if idx < len(buffer_mutations) + len(user_input_mutations) + len(output_tokens): + name = GraphOutputName(out.name) + if name in buffer_mutations: + value = buffer_mutations[name] + elif name in user_input_mutations: + value = user_input_mutations[name] + + mutated_outputs.append(value) + + unlifted_gm = _unlift( + gm, + lifted_inputs, + mutated_outputs, + pytree.treespec_leaf(), + None, + ) + # After unlifting, the buffer mutation information is lost. Pass the information + # so that Inductor can do optimizations correctly. 
+ unlifted_gm.meta["mutated_named_buffers"] = OrderedSet(buffer_mutations.values()) + return unlifted_gm + + +def _get_subgraph_names( + gm: GraphModule, skip_invoke_subgraph: bool = False +) -> Generator[str, None, None]: + all_subgraph_names: OrderedSet[str] = OrderedSet( + x.target for x in gm.graph.find_nodes(op="get_attr") + ) + fx_subgraph_names: OrderedSet[str] = OrderedSet() + for child_name, child_module in gm.named_children(): + # Sometimes an owning_module can have unused children. Skip them + # by checking them from get_attr node targets. + if child_name in all_subgraph_names and isinstance( + child_module, torch.fx.GraphModule + ): + fx_subgraph_names.add(child_name) + + if skip_invoke_subgraph: + for node in gm.graph.find_nodes( + op="call_function", target=torch.ops.higher_order.invoke_subgraph + ): + fx_subgraph_names.discard(node.args[0].target) + + yield from fx_subgraph_names + + +def _recursive_pre_grad_passes( + gm: GraphModule, + example_inputs: Sequence[InputType], +) -> GraphModule: + with dynamo_timed( + "_recursive_pre_grad_passes", + log_pt2_compile_event=True, + dynamo_compile_column_us="pre_grad_pass_time_us", + ): + if not config.use_pre_grad_passes: + return gm + + add_passes = config.add_pre_grad_passes + remove_passes = config.remove_pre_grad_passes + for subgraph_name in _get_subgraph_names(gm): + subgraph = getattr(gm, subgraph_name) + # as we don't have recursive example inputs, passing empty set here + new_subgraph = _recursive_pre_grad_passes(subgraph, ()) + setattr(gm, subgraph_name, new_subgraph) + return pre_grad_passes(gm, example_inputs, add_passes, remove_passes) + + +def _recursive_joint_graph_passes( + gm: GraphModule, skip_invoke_subgraph: bool = False +) -> None: + with dynamo_timed( + "_recursive_joint_graph_passes", + log_pt2_compile_event=True, + dynamo_compile_column_us="joint_graph_pass_time_us", + ): + if not config.use_joint_graph_passes: + return + + # invoke_subgraph already runs the 
_recursive_joint_graph_passes. In + # AOTAutograd, `run_joint_graph_passes_on_hops` partitions the + # invoke_subgraph HOP before calling the partitioner on the outer graph. + # AOTAutograd has access to partition_fn, which internally calls the + # `_recursive_joint_graph_passes` for the subgraph. So, skip recursing + # skip_invoke_subgraph. + for subgraph_name in _get_subgraph_names(gm, skip_invoke_subgraph): + subgraph = getattr(gm, subgraph_name) + _recursive_joint_graph_passes(subgraph, skip_invoke_subgraph) + joint_graph_passes(gm) + + +def _recursive_post_grad_passes(gm: GraphModule, is_inference: bool = False) -> None: + with dynamo_timed( + "_recursive_post_grad_passes", + log_pt2_compile_event=True, + dynamo_compile_column_us="post_grad_pass_time_us", + ): + if not config.use_post_grad_passes: + return + + for subgraph_name in _get_subgraph_names(gm): + subgraph = getattr(gm, subgraph_name) + _recursive_post_grad_passes(subgraph, is_inference) + post_grad_passes(gm, is_inference) + + +def split_const_gm( + gm: GraphModule, + skip_constructor: bool = True, + lifted_constant_names: Optional[list[str]] = None, + skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]] = None, +) -> tuple[GraphModule, dict[str, int]]: + """ + This function takes an GraphModule input "gm". + The gm will be split into 2 components, + 1) const_gm, which consists the subgraph of gm that can be constant folded. + 2) gm (being inplace modified,) which returns the graph after constant folding. + + If an additional "lifted_constants" argument is passed in, we will assume the gm has + been lifted and run the transformation accordingly. + + When a "skip_folding_node_fn" callback is passed, we will skip constant folding on + the nodes for which the callback returns True. + + const_output_index is a mapping of corresponding node name from gm to the + output index of const_gm. 
+ Returns (const_gm, const_output_index) + """ + from torch._inductor.constant_folding import ( + CONST_MODULE_TAG, + META_TAG, + MODULE_TAG, + replace_node_with_constant, + run_and_get_constant_graph, + ) + + const_gm = run_and_get_constant_graph( + gm, skip_constructor, lifted_constant_names, skip_folding_node_fn + ) + const_result = const_gm() if lifted_constant_names is None else None + + const_outputs = { + x.name: idx for idx, x in enumerate(tuple(const_gm.graph.nodes)[-1].args[0]) + } + + to_erase_node = [] + to_replace_node = [] + const_output_index = {} + for node in gm.graph.nodes: + if node.name in const_outputs: + to_replace_node.append(node) + elif node.meta[META_TAG] == CONST_MODULE_TAG and node.op != "placeholder": + to_erase_node.append(node) + + for node in to_replace_node: + new_const_name = "_FOLDED_CONST_" + node.name + replace_node_with_constant( + gm, + node, + ( + const_result[const_outputs[node.name]] # type:ignore[index] + if lifted_constant_names is None + else None + ), + new_const_name, + ) + const_output_index[new_const_name] = const_outputs[node.name] + for node in to_erase_node[::-1]: + if node.users: + for n in node.users: + assert n.meta[META_TAG] == MODULE_TAG, f"node: {node} user not empty." 
+ else: + gm.graph.erase_node(node) + gm.recompile() + + return const_gm, const_output_index + + +def is_tf32_warning_applicable(gm: GraphModule) -> bool: + aten = torch.ops.aten + tf32_ops = OrderedSet( + [ + aten.mm.default, + aten.addmm.default, + aten.bmm.default, + aten.baddbmm.default, + ] + ) + for target in tf32_ops: + for node in gm.graph.find_nodes(op="call_function", target=target): + if ( + isinstance(node.meta.get("val", None), torch.Tensor) + and node.meta["val"].dtype == torch.float32 + and node.meta["val"].device.type == "cuda" + ): + return True + return False + + +def maybe_disable_comprehensive_padding( + example_inputs: Sequence[InputType], +) -> AbstractContextManager[None, None]: + """ + For CPU backend, enable comprehensive padding causes some unit tests + fail due to changing number of generated kernels. Skip for now. + """ + has_gpu = any( + is_gpu(t.device.type) for t in example_inputs if isinstance(t, torch.Tensor) + ) + + if config.disable_padding_cpu and config.comprehensive_padding and not has_gpu: + perf_hint_log.info("Skip comprehensive padding on CPU") + return config.patch(comprehensive_padding=False) + elif config.aot_inductor.use_runtime_constant_folding: + perf_hint_log.info( + "Skip comprehensive padding for use_runtime_constant_folding" + ) + return config.patch(comprehensive_padding=False) + else: + return contextlib.nullcontext() + + +def maybe_disable_graph_partition( + cpp_wrapper: bool, aot_mode: bool +) -> AbstractContextManager[None, None]: + """ + graph partition does not support cpp_wrapper and aot_mode yet. + """ + if cpp_wrapper or aot_mode: + return config.patch(graph_partition=False) + else: + return contextlib.nullcontext() + + +def fake_tensor_prop( + gm: GraphModule, + example_inputs: Sequence[InputType], + force_allow_non_fake_inputs: bool = False, +) -> torch._subclasses.FakeTensorMode: + """ + If we can not detect fake mode from the context of inputs, create one. + + The created fake mode will be returned. 
+ """ + # Ensure that decomps that support symbolic shapes are used + with enable_python_dispatcher(): + fake_mode = detect_fake_mode(example_inputs) + if not fake_mode: + fake_mode = torch._subclasses.FakeTensorMode(allow_non_fake_inputs=True) + FakeTensorProp(gm, mode=fake_mode).propagate(*example_inputs) + else: + ctx = ( + contextlib.nullcontext() + if not force_allow_non_fake_inputs + else mock.patch.object(fake_mode, "allow_non_fake_inputs", True) + ) + with ctx: # type: ignore[attr-defined] + FakeTensorProp(gm, mode=fake_mode).propagate_dont_convert_inputs( + *example_inputs + ) + + return fake_mode + + +# pass config dict back to user +def get_patched_config_dict( + config_patches: Optional[Union[str, dict[str, Any]]] = None, +) -> dict[str, Any]: + with config.patch(config_patches): + return config.get_config_copy() + + +@contextlib.contextmanager +def with_fresh_cache_if_config() -> Generator[None, None, None]: + if config.force_disable_caches: + # Don't delete the cache dir because it has to survive beyond the + # compile_fx call. Let's put the temp dirs under the default cache + # dir so they're easier to locate. + with fresh_cache(dir=cache_dir(), delete=False): + yield + else: + yield + + +class _CompileFxKwargs(TypedDict, total=False): + cudagraphs: Optional[BoxedBool] + static_input_idxs: Sequence[int] + is_backward: bool + graph_id: Optional[int] + cpp_wrapper: bool + aot_mode: bool + is_inference: bool + layout_opt: Optional[bool] + extern_node_serializer: Optional[Callable[[list[ExternKernelNode]], Any]] + boxed_forward_device_index: Optional[BoxedDeviceIndex] + fx_wrapper: bool + + +class _CompileFxCallable(Protocol): + def __call__( + self, + gm: GraphModule, + example_inputs: Sequence[InputType], + **kwargs: Unpack[_CompileFxKwargs], + ) -> OutputCode: ... 
+ + +def compile_fx_inner( + gm: GraphModule, + example_inputs: Sequence[InputType], + **kwargs: Unpack[_CompileFxKwargs], +) -> OutputCode: + kwargs.setdefault("cudagraphs", None) + kwargs.setdefault("static_input_idxs", ()) + kwargs.setdefault("is_backward", False) + kwargs.setdefault("graph_id", None) + kwargs.setdefault("cpp_wrapper", False) + kwargs.setdefault("fx_wrapper", False) + kwargs.setdefault("is_inference", False) + kwargs.setdefault("boxed_forward_device_index", None) + kwargs.setdefault("layout_opt", None) + kwargs.setdefault("extern_node_serializer", None) + + # Need with_fresh_cache_if_config for compile_fx_inner even if we already have one for + # compile_fx. The reason is the compilation for backward graph may happen after + # compile_fx return and we may want to use the _LazyGraphModule for compiling + # the backward graph as well. + with contextlib.ExitStack() as stack: + stack.enter_context(torch.utils._python_dispatch._disable_current_modes()) + stack.enter_context(_use_lazy_graph_module(dynamo_config.use_lazy_graph_module)) + stack.enter_context( + dynamo_utils.dynamo_timed( + "compile_fx_inner", + phase_name="inductor_compile", + log_pt2_compile_event=True, + log_waitcounter=True, + waitcounter_name_override="compile_inductor", + dynamo_compile_column_us="inductor_cumulative_compile_time_us", + ) + ) + stack.enter_context(with_fresh_cache_if_config()) + stack.enter_context(DebugContext()) + CompileEventLogger.pt2_compile( + "inductor_compile", + is_backward=kwargs["is_backward"], + ) + return wrap_compiler_debug(_compile_fx_inner, compiler_name="inductor")( + gm, + example_inputs, + **kwargs, + ) + + +@time_and_log(attr="compilation time (in seconds)") +def _compile_fx_inner( + gm: GraphModule, + example_inputs: Sequence[InputType], + **graph_kwargs: Unpack[_CompileFxKwargs], +) -> OutputCode: + """ + Inductor API that compiles a single graph. 
+ + If you change the argument list for this function, make sure you + also update the call to save_args_for_compile_fx_inner below accordingly. + """ + aot_mode: bool = V.aot_compilation + + # Clean up Compiled Triton Kernels per inductor compile, as the future objects + # may not be valid for use after they are run/autotuned + torch._inductor.async_compile.CompiledTritonKernels.cache_clear() + + if dynamo_utils.count_calls(gm.graph) == 0 and not aot_mode: + # trigger the real recompilation for _LazyGraphModule before returning + # the forward method. + from torch._dynamo.utils import CompileEventLogLevel + from torch.fx._lazy_graph_module import _LazyGraphModule + + _LazyGraphModule.force_recompile(gm) + compile_id = torch._guards.CompileContext.current_compile_id() + CompileEventLogger.log_instant_event( + "backward no-op", + metadata={"compile_id": compile_id}, + log_level=CompileEventLogLevel.PT2_COMPILE, + ) + + return make_boxed_func(gm.forward) + + static_input_idxs: Sequence[int] = graph_kwargs.setdefault("static_input_idxs", ()) + static_inputs_log.debug("static input idxs compile_fx_inner: %s", static_input_idxs) + inputs_to_check = get_input_idxs_to_check(example_inputs, static_input_idxs) + + assert isinstance(next(iter(reversed(gm.graph.nodes))).args[0], (tuple, list)), ( + f"inductor can only compile FX graphs which return a tuple/list, but got {gm.graph}" + ) + + if graph_kwargs.get("cudagraphs") is None: + graph_kwargs["cudagraphs"] = BoxedBool(config.triton.cudagraphs) + if config.save_args: + save_args_for_compile_fx_inner( + gm, + example_inputs, + **graph_kwargs, + ) + + start = time.time() + + fx_graph_remote_cache = should_use_remote_fx_graph_cache() + + # Check if the registered backend(s) support caching. 
+ init_backend_registration() + backends_support_caching = all( + backend.supports_caching + for backend in ( + get_wrapper_codegen_for_device( + device.type, config.cpp_wrapper, config.fx_wrapper + ) + for device in get_all_devices(gm) + ) + if backend is not None + ) + + with dynamo_timed( + "fx_codegen_and_compile", log_pt2_compile_event=True, log_waitcounter=True + ): + use_cache = ( + not config.force_disable_caches + and (config.fx_graph_cache or fx_graph_remote_cache) + and not aot_mode + and backends_support_caching + and not torch._functorch.config.bundled_autograd_cache + ) + local = config.fx_graph_cache + remote = fx_graph_remote_cache + set_feature_use("fx_cache", use_cache) + + log.debug( + "FX cache status: use_cache=%s, local=%s, remote=%s, aot_mode=%s, force_disable_caches=%s", + use_cache, + local, + remote, + aot_mode, + config.force_disable_caches, + ) + + # TODO: This is a hack purely to get some info to extract_tensor_metadata_for_cache_key, + # figure out how to not have to modify example inputs + for i, input in enumerate(example_inputs): + if ( + isinstance(input, torch.Tensor) + and is_gpu(input.device.type) + and i in static_input_idxs + ): + input._is_inductor_static = True # type: ignore[attr-defined] + + mb_compiled_graph: Optional[OutputCode] = None + key_info = None + cache_info = None + remote_cache = None + constants = CompiledFxGraphConstantsWithGm(gm) + # TODO: this time will be slightly inconsistent with the one computed + # in prepare_key/load_with_key, dump those settings of "cache_event_time" + start_time = time.time_ns() + + if use_cache: + (key_info, cache_info) = FxGraphCache.prepare_key( + gm, example_inputs, graph_kwargs, inputs_to_check, remote + ) + + # Attempt a cache lookup + if key_info is not None: + key, debug_lines = key_info + log.debug("FX cache key generated: %s", key) + if remote: + remote_cache = FxGraphCache.get_remote_cache() + log.debug("Using remote FX cache") + mb_compiled_graph, cache_info = 
FxGraphCache.load_with_key( + key, + debug_lines, + example_inputs, + local, + remote_cache, + is_backward=graph_kwargs.get("is_backward", False), + constants=constants, + ) + else: + log.debug("Failed to generate FX cache key") + + if torch._functorch.config.bundled_autograd_cache: + assert mb_compiled_graph is None + assert cache_info is None + # When using bundled autograd cache, we still want + # to use the TritonBundler, but we don't want to save + # the results here. The results will get saved directly + # to AOTAutogradCache. + TritonBundler.begin_compile() + try: + mb_compiled_graph = fx_codegen_and_compile( + gm, example_inputs, inputs_to_check, **graph_kwargs + ) + assert mb_compiled_graph is not None + ( + triton_bundle, + triton_bundler_meta, + ) = TritonBundler.collect() + mb_compiled_graph.set_triton_bundle(triton_bundle) + except (ShortenTraceback, SkipFrame): + raise + except Exception as e: + raise InductorError(e, currentframe()).with_traceback( + e.__traceback__ + ) from None + finally: + TritonBundler.end_compile() + + # CACHE BYPASS: Compile the graph, don't save it to the cache + # (this can happen either because cache was disabled, or we + # determined the input is uncacheable) + elif cache_info is None or cache_info["cache_state"] == "bypass": + assert mb_compiled_graph is None + log.debug( + "FX cache bypass reason: %s", + ( + cache_info.get("cache_bypass_reason", "unknown") + if cache_info is not None + else "FX cache disabled or key generation failed" + ), + ) + try: + mb_compiled_graph = fx_codegen_and_compile( + gm, example_inputs, inputs_to_check, **graph_kwargs + ) + except Exception as e: + raise InductorError(e, currentframe()).with_traceback( + e.__traceback__ + ) from None + + # CACHE MISS: Compile the graph and save to cache + elif cache_info["cache_state"] == "miss": + assert mb_compiled_graph is None + assert key_info is not None + log.debug("FX cache miss, compiling and saving to cache") + TritonBundler.begin_compile() + try: 
+ mb_compiled_graph = fx_codegen_and_compile( + gm, example_inputs, inputs_to_check, **graph_kwargs + ) + assert mb_compiled_graph is not None + mb_compiled_graph._time_taken_ns = time.time_ns() - start_time + cache_key, debug_lines = key_info + mb_compiled_graph._fx_graph_cache_key = cache_key + mb_compiled_graph._fx_graph_cache_debug_lines = debug_lines + ( + triton_bundle, + triton_bundler_meta, + ) = TritonBundler.collect() + mb_compiled_graph.set_triton_bundle(triton_bundle) + except (ShortenTraceback, SkipFrame): + raise + except Exception as e: + raise InductorError(e, currentframe()).with_traceback( + e.__traceback__ + ) from None + finally: + TritonBundler.end_compile() + if triton_bundler_meta is not None: + cache_info["triton_bundler_meta"] = str(triton_bundler_meta) + cache_info["time_taken_ns"] = mb_compiled_graph._time_taken_ns + log.debug("Saving compiled graph to FX cache with key: %s", cache_key) + FxGraphCache._save_graph( + cache_key, + mb_compiled_graph, + example_inputs, + local, + remote_cache, + ) + + # CACHE HIT: not much to really do, just make sure the cache key + # is recorded on the graph + else: + assert cache_info["cache_state"] == "hit" + assert mb_compiled_graph is not None + assert key_info is not None + (cache_key, debug_lines) = key_info + log.debug("FX cache hit with key: %s", cache_key) + mb_compiled_graph._fx_graph_cache_key = cache_key + mb_compiled_graph._fx_graph_cache_debug_lines = debug_lines + + assert mb_compiled_graph is not None + compiled_graph = mb_compiled_graph + + # Logging and observability: we log a single chromium event + # and a tlparse log for every cache action. + # In the event of a bypass, we also logged to the remote table earlier + # with log_cache_bypass. 
+ cache_state = ( + cache_info["cache_state"] if cache_info is not None else "disabled" + ) + # Here for grepping: + # fx_graph_cache_hit + # fx_graph_cache_miss + # fx_graph_cache_bypass + # fx_graph_cache_disabled + CompileEventLogger.instant( + f"fx_graph_cache_{cache_state}", + metadata=cache_info or {}, + time_ns=start_time, + ) + # Add event data about cache hits/miss + # TODO: add remote cache get/put timings here too + CompileEventLogger.pt2_compile( + "inductor_compile", + cache_state=cache_state, + cache_event_time=start_time, + key=cache_info.get("key") if cache_info else None, + components=cache_info.get("components") if cache_info else None, + cache_bypass_reason=( + cache_info.get("cache_bypass_reason") + if cache_info + else "cache not enabled" + ), + remote_cache_enabled=remote, + local_cache_enabled=local, + ) + + # Don't clog up the main tlparse output with disabled cache + if cache_info is not None: + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": f"fx_graph_cache_{cache_state}", + "encoding": "json", + }, + payload_fn=lambda: json.dumps(cache_info), + ) + compiled_graph.post_compile(example_inputs, constants, graph_kwargs) + + log.debug("FX codegen and compilation took %.3fs", time.time() - start) + + # This message is for printing overview information of inductor mm counts, shapes,etc after lowering + if log.isEnabledFor(logging.INFO): + mm_table_data = [] + for key, value in counters["aten_mm_info"].items(): + parts = key.split("_") + if len(parts) < 3: + # Unexpected format, show as-is + mm_table_data.append([key, "-", "?", "?", "?", value]) + continue + + # Determine if this is a batched operation by checking the operation name + name = "_".join(parts[:-4]) if len(parts) >= 4 else "_".join(parts[:-3]) + is_batched = name.endswith(("bmm", "baddbmm")) + + if is_batched and len(parts) >= 4: + # Batched operation: last 4 parts are batch, m, n, k + batch, m, n, k = parts[-4:] + name = "_".join(parts[:-4]) + 
mm_table_data.append([name, batch, m, n, k, value]) + else: + # Non-batched operation: last 3 parts are m, n, k + m, n, k = parts[-3:] + name = "_".join(parts[:-3]) + mm_table_data.append([name, "-", m, n, k, value]) + + log.info("Overview info of inductor aten mms: ") + log.info( + "{:<30} | {:<20} | {:<20} | {:<20} | {:<20} | {:<20}".format( # noqa: G001 + "Name", "B", "M", "N", "K", "Count" + ) + ) + log.info("-" * 130) + for row in mm_table_data: + # pyrefly: ignore [not-iterable] + log.info("{:<30} | {:<20} | {:<20} | {:<20} | {:<20} | {:<20}".format(*row)) # noqa: G001 + log.info("-" * 130) + + # Not strictly necessary, but good to clean up straggling futures + # that are unused to reclaim memory. + torch._inductor.async_compile.CompiledTritonKernels.cache_clear() + + _step_logger()( + logging.INFO, + "torchinductor done compiling " + f"{'BACKWARDS' if graph_kwargs['is_backward'] else 'FORWARDS'} " + f"graph {graph_kwargs['graph_id']}", + ) + return compiled_graph + + +class _FxCompileStat: + # Count of successful compiles of this type + codegen_and_compile: int = 0 + + def __repr__(self) -> str: + return f"codegen_and_compile: {self.codegen_and_compile}" + + +class FxCompile(ABC): + """ + An FxCompile represents a mechanism that can turn a GraphModule into an + OutputCode. + """ + + # Some stats for logging/debugging + _compile_stats: dict[type[FxCompile], _FxCompileStat] = defaultdict(_FxCompileStat) + + # TODO: We should probably eventually add some kind of async version of this + # so we can kick off a compile and then go do other things - but we'll need + # to know what kind of API we want for that first. + @abstractmethod + def codegen_and_compile( + self, + gm: GraphModule, + example_inputs: Sequence[InputType], + inputs_to_check: Sequence[int], + graph_kwargs: _CompileFxKwargs, + ) -> OutputCode: ... 
+ + @classmethod + def _reset_stats(cls) -> None: + cls._compile_stats.clear() + + +class _InProcessFxCompile(FxCompile): + @override + def codegen_and_compile( + self, + gm: GraphModule, + example_inputs: Sequence[InputType], + inputs_to_check: Sequence[int], + graph_kwargs: _CompileFxKwargs, + ) -> OutputCode: + """ + Generates the OutputCode from the GraphModule and example_inputs. + """ + # Sorry about the mess, we need graph_kwargs to continue to be able + # to propagate it further on + # TODO: _CompileFxKwargs actually has stronger types than in the + # signature, need to tighten it up + + assert "cudagraphs" in graph_kwargs and graph_kwargs["cudagraphs"] is not None + cudagraphs: BoxedBool = graph_kwargs["cudagraphs"] + static_input_idxs: Sequence[int] = graph_kwargs.get("static_input_idxs", ()) + is_backward: bool = graph_kwargs.get("is_backward", False) + graph_id: Optional[int] = graph_kwargs.get("graph_id", None) + cpp_wrapper: bool = graph_kwargs.get("cpp_wrapper", False) + fx_wrapper: bool = graph_kwargs.get("fx_wrapper", False) + aot_mode: bool = V.aot_compilation + is_inference: bool = graph_kwargs.get("is_inference", False) + extern_node_serializer: Optional[Callable[[list[ExternKernelNode]], Any]] = ( + graph_kwargs.get("extern_node_serializer", None) + ) + + with ( + _WaitCounter("pytorch.wait_counter.actual_codegen_and_compile").guard(), + dynamo_utils.preserve_rng_state(), + ): + if (sleep_sec := config.sleep_sec_TESTING_ONLY) is not None: + import time + + log.warning( + "Sleeping for %s since sleep_sec_TESTING_ONLY is set", sleep_sec + ) + time.sleep(sleep_sec) + + if is_tf32_warning_applicable(gm): + _warn_tf32_disabled() + + inductor_counters = counters["inductor"].copy() + + # lift the maximum depth of the Python interpreter stack + # to adapt large/deep models + sys.setrecursionlimit(max(sys.getrecursionlimit(), 2000)) + + _step_logger()( + logging.INFO, + "torchinductor compiling " + f"{'BACKWARDS' if is_backward else 'FORWARDS'} " + 
f"graph {graph_id}", + ) + + fd = io.StringIO() + torch._dynamo.repro.after_aot.save_graph_repro( + fd, gm, example_inputs, "inductor", save_dir=None + ) + runnable_graph_str = fd.getvalue() + + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "fx_graph_runnable", + "encoding": "string", + }, + payload_fn=lambda: runnable_graph_str, + ) + + V.debug.fx_graph(gm, example_inputs) + # TODO: Should we actually dump this? It should be redundant with the aot + # structured logs... + # trace_structured("inductor_input_graph", payload_fn=lambda: gm.print_readable(print_output=False)) + + shape_env = gm.shape_env + if shape_env is None: + shape_env = shape_env_from_inputs(example_inputs) + + # Convert view to reshape in the graph. This is necessary primarily for + # layout optimization. Do it unconditionally for uniformity. + # + # It's needed because when we do layout optimization, an contiguous tensor + # in eager mode may becomes a channels last tensor. A view op previously + # can be applied to the contiguous tensor may not be able to be applied + # on the channels tensor any more. An error like + # RuntimeError: view size is not compatible with input tensor's size and stride + # (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead. + # will be printed. + # + # Replace view op to reshape op in this case. + # As an example, timm_resnest/botnet26t_256/convnext_base etc. will fail if we don't do this. + # + # Also this has to be done before FakeTensorProp below to avoid the failed + # .view() call. + view_to_reshape(gm) + + with dynamo_timed( + "additional_fake_tensor_prop", log_pt2_compile_event=True + ): + # It is safe to run FakeTensorProp under no_grad because by the time + # we're in inductor, we assume that AOTAutograd has already "taken care" + # of autograd, so there should be no more autograd-related API's in the + # graph. 
+ with torch.no_grad(): + fake_mode = fake_tensor_prop(gm, example_inputs) + + _recursive_record_original_output_strides(gm) + + # pattern matcher passes might not preserve striding information + # on node.meta["val"]. if in the future we rely on these being + # correct we will need to fix. + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "before_post_grad_graph", + "encoding": "string", + }, + payload_fn=lambda: gm.print_readable( + print_output=False, include_stride=True, include_device=True + ), + ) + with V.set_fake_mode(fake_mode): + # has some issues with memory in training + cuda_context = get_cuda_device_context(gm) + with cuda_context: + _recursive_post_grad_passes(gm, is_inference=is_inference) + V.debug.fx_graph_transformed(gm, example_inputs) + post_grad_graphs_log.debug( + "%s", + lazy_format_graph_code( + "AFTER POST GRAD", + gm, + include_stride=True, + include_device=True, + colored=True, + ), + ) + + # We're printing the graph to be used as a cache key - so a + # printer which is a little less readable but faster is + # appropriate. + inductor_post_grad_graph_str = gm.print_readable( + print_output=False, + include_stride=True, + include_device=True, + fast_sympy_print=True, + ) + # "inductor_post_grad_graph" is used in inductor provenance + # tracking highlighter front-end. 
+ trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "inductor_post_grad_graph", + "encoding": "string", + }, + payload_fn=lambda: inductor_post_grad_graph_str, + ) + if config.trace.provenance_tracking_level != 0: + provenance_tracking_json = ( + torch.fx.traceback.get_graph_provenance_json(gm.graph) + ) + torch._inductor.debug._inductor_post_to_pre_grad_nodes = ( + create_mapping_pre_post_grad_nodes( + torch._inductor.debug._pre_grad_graph_id, + provenance_tracking_json, + ) + ) + + metrics_context = get_metrics_context() + if metrics_context.in_progress(): + num_graph_breaks = counters["graph_break"].total() + CompileEventLogger.compilation_metric( + overwrite=True, num_graph_breaks=num_graph_breaks + ) + if config.is_fbcode(): + try: + log_optimus_to_scuba( + extra_logging={ + "pt2_configs": str(get_patched_config_dict()) + } + ) + except Exception: + # TODO(T216453900): need to work around for now to support vllm + # See details in vllm/compilation/pass_manager.py. + log.warning("failed to log pt2_configs") + + with ( + V.set_fake_mode(fake_mode), + maybe_disable_comprehensive_padding(example_inputs), + maybe_disable_graph_partition(cpp_wrapper, aot_mode), + ): + const_output_index = None + const_graph = None + const_wrapper_code = None + const_kernel_code = None + + if aot_mode and config.aot_inductor.use_runtime_constant_folding: + # torchbind objects have name that starts with _torchbind_obj + # See caffe2/torch/fx/_symbolic_trace.py?lines=406 + const_gm, const_output_index = split_const_gm( + gm, + skip_folding_node_fn=lambda node: node.op == "get_attr" + and isinstance(node.target, str) + and ( + node.target.startswith("_torchbind_obj") + or isinstance(node.meta.get("val", None), FakeScriptObject) + ), + ) + + const_graph = GraphLowering( + const_gm, + example_inputs=[], + shape_env=shape_env, + graph_id=graph_id, + cpp_wrapper=cpp_wrapper, + aot_mode=aot_mode, + extern_node_serializer=extern_node_serializer, + is_inference=is_inference, + 
is_backward=is_backward, + is_const_graph=True, + fx_wrapper=fx_wrapper, + ) + with ( + V.set_graph_handler(const_graph), + V.set_extern_kernel_nodes([]), + ): + assert cpp_wrapper, "AOT mode only supports C++ wrapper" + const_graph.run() + const_wrapper_code, const_kernel_code = ( + const_graph.codegen_with_cpp_wrapper() + ) + + graph = GraphLowering( + gm, + # example_inputs will be used by AOTInductor to dry-run the generated code for Triton kernel tuning. + # For the forward pass, we have the real inputs to be used as example_inputs. For the backward pass, + # we currently use fake tensors and defake them later. + example_inputs=example_inputs, + shape_env=shape_env, + graph_id=graph_id, + cpp_wrapper=cpp_wrapper, + aot_mode=aot_mode, + extern_node_serializer=extern_node_serializer, + is_inference=is_inference, + is_backward=is_backward, + const_output_index=const_output_index, + const_wrapper_code=( + const_wrapper_code.value if const_wrapper_code else None + ), + const_kernel_code=( + const_kernel_code.value if const_kernel_code else None + ), + const_module=const_graph, + inputs_to_check=inputs_to_check, + fx_wrapper=fx_wrapper, + ) + metrics_helper = metrics.CachedMetricsHelper() + + # We are going to start code generating runtime asserts, so make sure + # you don't start adding new ones in the lowering process + graph.freeze_runtime_asserts() + with ( + V.set_graph_handler(graph), + V.set_extern_kernel_nodes([]), + distributed_autotune.graph_context(), + ): + graph.run(*example_inputs) + output_strides: list[Optional[tuple[_StrideExprStr, ...]]] = [] + if graph.graph_outputs is not None: + # We'll put the output strides in the compiled graph so we + # can later return them to the caller via TracingContext + p = SymExprPrinter() + for out in graph.graph_outputs: + if ( + isinstance(out, IRNode) + and out.has_tensor_output() + and len(free_unbacked_symbols(out.get_stride())) == 0 + ): + # Convert to string for eval on the load path + output_strides.append( + 
tuple(p.doprint(s) for s in out.get_layout().stride) + ) + else: + output_strides.append(None) + + _check_triton_bf16_support(graph) + + # TODO: The switching between AOT mode and not here is a bit + # messy, but it's localized to the block of code below so I'm + # not going to touch it for now + + compiled_fn: Any + compiled_fn_runner = None + with dynamo_timed( + "GraphLowering.compile_to_fn", log_pt2_compile_event=True + ): + if graph.aot_mode and graph.fx_wrapper: + assert not graph.cpp_wrapper + compiled_fn = graph.codegen()[0].gm # type: ignore[attr-defined] + output_code_log.debug( + "Output graph module: \n%s", + compiled_fn.print_readable(print_output=False), + ) + + elif graph.aot_mode: + from .codecache import AotCodeCompiler + + assert graph.cpp_wrapper, ( + "AOT mode only supports C++ wrapper" + ) + wrapper_code, kernel_code = graph.codegen_with_cpp_wrapper() + output_code_log.debug( + "Output wrapper code: \n%s", wrapper_code.value + ) + if kernel_code.value: + output_code_log.debug( + "Output kernel code:\n%s", kernel_code.value + ) + + serialized_extern_kernel_nodes = None + if V.extern_kernel_nodes: + serialized_extern_kernel_nodes = ( + graph.extern_node_serializer(V.extern_kernel_nodes) + ) + output_code_log.debug( + "Serialized Extern Kernel Nodes: \n%s", + serialized_extern_kernel_nodes, + ) + + with dynamo_timed( + "AotCodeCompiler.compile", log_pt2_compile_event=True + ): + # Directly return the file path with the compiled code + compiled_fn = AotCodeCompiler.compile( + graph, + wrapper_code.value, + kernel_code.value, + serialized_extern_kernel_nodes, + device_type=graph.device_type, + additional_files=[ + *dict.fromkeys( + graph.wrapper_code.additional_files + + ( + const_graph.wrapper_code.additional_files + if const_graph + else [] + ) + ) + ], + ) + else: + compiled_module = graph.compile_to_module() + compiled_fn = compiled_module.call + compiled_fn_runner = getattr( + compiled_module, "runner", None + ) + + # Dump provenance artifacts 
for debugging trace + inductor_provenance_tracking_node_mappings = None + inductor_kernel_stack_trace_str = None + if config.trace.provenance_tracking_level != 0: + inductor_provenance_tracking_node_mappings = json.dumps( + torch._inductor.debug.dump_inductor_provenance_info() + ) + inductor_kernel_stack_trace_str = json.dumps( + torch._inductor.debug._inductor_kernel_stack_trace + ) + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "inductor_provenance_tracking_node_mappings", + "encoding": "json", + }, + payload_fn=lambda: inductor_provenance_tracking_node_mappings, + ) + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "inductor_provenance_tracking_kernel_stack_traces", + "encoding": "json", + }, + payload_fn=lambda: inductor_kernel_stack_trace_str, + ) + if inductor_kernel_stack_trace_str: + metrics_context = get_metrics_context() + if metrics_context.in_progress(): + metrics_context.add_to_set( + "inductor_provenance", + inductor_kernel_stack_trace_str, + ) + + node_runtimes = None + if inductor_metrics_log.isEnabledFor(logging.INFO): + num_bytes, nodes_num_elem, node_runtimes = graph.count_bytes() + # pyrefly: ignore [bad-assignment] + metrics.num_bytes_accessed += num_bytes + metrics.node_runtimes += node_runtimes + metrics.nodes_num_elem += nodes_num_elem + inductor_metrics_log.info( + "Graph Metrics:\n%s", + { + "num_bytes_accessed": num_bytes, + "nodes_num_elem": nodes_num_elem, + "node_runtimes": node_runtimes, + }, + ) + + # Collect and dump op runtimes and tensor metadata for TLParse + if config.log_tlparse: + _, _, node_runtimes = graph.count_bytes() + torch._inductor.debug.log_runtime_and_tensor_meta(node_runtimes) + + # Collect and dump collective-op schedule for external diagnostics + torch._inductor.debug.log_collective_schedule(graph.scheduler.nodes) + + # When graph_partition is enabled, skip this check - partitioning handles dynamic shapes + if ( + cudagraphs + and config.triton.cudagraph_skip_dynamic_graphs + 
and not config.graph_partition + and not V.graph.disable_cudagraphs_reason + and torch._inductor.utils.any_is_symbolic(*example_inputs) + ): + stack_trace = None + for node in gm.graph.nodes: + meta_val = node.meta.get("val", None) + if ( + node.op == "placeholder" + or not isinstance(meta_val, torch.Tensor) + or not torch._inductor.utils.any_is_symbolic(meta_val) + ): + continue + + if stack_trace := node.meta.get("stack_trace", None): + break + disable = "graph with symbolic shapes inputs and config.triton.cudagraph_skip_dynamic_graphs=True." + if stack_trace: + disable = f"{disable} Found from {stack_trace}\n" + else: + disable = f"{disable}\n" + # pyrefly: ignore [unbound-name] + V.graph.disable_cudagraphs_reason = disable + + # pyrefly: ignore [unbound-name] + # When graph_partition is enabled, skip this check - partitioning handles incompatible ops + if ( + cudagraphs + # pyrefly: ignore [unbound-name] + and not config.graph_partition + # pyrefly: ignore [unbound-name] + and not V.graph.disable_cudagraphs_reason + ): + maybe_incompat_node = get_first_incompatible_cudagraph_node(gm) + if maybe_incompat_node: + disable = f"disabling cudagraphs due to incompatible op {maybe_incompat_node.target}" + if stack_trace := maybe_incompat_node.meta.get( + "stack_trace", None + ): + disable = f"{disable} Found from {stack_trace}\n" + # pyrefly: ignore [unbound-name] + V.graph.disable_cudagraphs_reason = disable + + # pyrefly: ignore [unbound-name] + if V.aot_compilation: + assert isinstance( + compiled_fn, + # pyrefly: ignore [unbound-name] + (str, list, torch.fx.GraphModule), + ), type(compiled_fn) + return CompiledAOTI( + filename=compiled_fn, device_type=graph.device_type + ) + + # TODO: Hoist this above V.aot_compilation + # pyrefly: ignore [unbound-name] + if cudagraphs and not V.graph.disable_cudagraphs_reason: + from torch._inductor.cudagraph_utils import ( + check_lowering_disable_cudagraph, + ) + + # pyrefly: ignore [unbound-name] + 
V.graph.disable_cudagraphs_reason = ( + check_lowering_disable_cudagraph( + # pyrefly: ignore [unbound-name] + V.graph.device_node_mapping + ) + ) + + self._compile_stats[type(self)].codegen_and_compile += 1 + + if ( + # pyrefly: ignore [unbound-name] + torch._inductor.debug.RECORD_GRAPH_EXECUTION + # pyrefly: ignore [unbound-name] + and torch._inductor.debug.GRAPH_COMPILE_IDS is not None + ): + compile_id = str( + # pyrefly: ignore [unbound-name] + torch._guards.CompileContext.current_compile_id() + ) + graph_id = graph_kwargs.get("graph_id") + if graph_id is not None: + # pyrefly: ignore [unbound-name] + torch._inductor.debug.GRAPH_COMPILE_IDS[graph_id] = ( + compile_id + ) + + return CompiledFxGraph( + # pyrefly: ignore [bad-argument-type] + compiled_fn, + graph, + gm, + output_strides, + # pyrefly: ignore [unbound-name] + V.graph.disable_cudagraphs_reason, + metrics_helper.get_deltas(), + counters["inductor"] - inductor_counters, + cudagraphs, + example_inputs, + static_input_idxs, + graph_kwargs, + inputs_to_check, + runnable_graph_str, + inductor_post_grad_graph_str, + compiled_fn_runner, + inductor_provenance_tracking_node_mappings, + inductor_kernel_stack_trace_str, + ) + + +def fx_codegen_and_compile( + gm: GraphModule, + example_inputs: Sequence[InputType], + # This is derivable from the other inputs to this function, but we pass it + # in explicitly because it's nontrivial to compute + inputs_to_check: Sequence[int], + **graph_kwargs: Unpack[_CompileFxKwargs], +) -> OutputCode: + scheme: FxCompile + + if fx_compile_mode == FxCompileMode.NORMAL: + scheme = _InProcessFxCompile() + elif fx_compile_mode == FxCompileMode.SERIALIZE: + from .compile_fx_ext import _DebugSerdeFxCompile + + scheme = _DebugSerdeFxCompile() + elif fx_compile_mode == FxCompileMode.SUBPROCESS: + from .compile_fx_subproc import _SubprocessFxCompile + + scheme = _SubprocessFxCompile() + + if fx_compile_async: + from .compile_fx_async import _AsyncFxCompile + from .compile_fx_ext import 
_OutOfProcessFxCompile + + # pyrefly: ignore [unbound-name] + assert isinstance(scheme, _OutOfProcessFxCompile), ( + "async is only valid with an out-of-process compile mode" + ) + # pyrefly: ignore [unbound-name] + scheme = _AsyncFxCompile(scheme) + + if fx_compile_progressive: + from .compile_fx_async import _ProgressiveFxCompile + from .compile_fx_ext import _OutOfProcessFxCompile + + # pyrefly: ignore [unbound-name] + assert isinstance(scheme, _OutOfProcessFxCompile), ( + "progressive is only valid with an out-of-process compile mode" + ) + + progression_configs = _get_progression_configs() + + # Use in-process compile for the fast version + fast_scheme = _InProcessFxCompile() + + # pyrefly: ignore [unbound-name] + scheme = _ProgressiveFxCompile(fast_scheme, scheme, progression_configs) + + # pyrefly: ignore [unbound-name] + return scheme.codegen_and_compile(gm, example_inputs, inputs_to_check, graph_kwargs) + + +def get_input_idxs_to_check( + inputs: Sequence[InputType], + static_input_idxs: Sequence[int], +) -> Sequence[int]: + """ + This function runs at compile time, and generates a list of indices for which we + might need to do a copy to preserve alignment requirements. + """ + ids_to_check = [] + + for i, input in enumerate(inputs): + if not isinstance(input, torch.Tensor): + # non-tensors don't need alignment + continue + if not is_gpu(input.device.type): + # right now we only care for gpu tensors + continue + with maybe_get_suppress_shape_guards_ctx(): + # suppress guards so that tensor_is_aligned and should_assume_input_aligned + # do not add guards on input's storage offset + if i in static_input_idxs and tensor_is_aligned(input): + continue + if not should_assume_input_aligned(input): + continue + + # if we get here, then + # (a) our triton code assumes that the input is aligned + # (b) we can't be sure ahead of time that the input will actually be aligned. 
+ # therefore, at runtime, we'll need to check that the input is aligned + # (and if not, clone it to make it aligned.) + ids_to_check.append(i) + + return ids_to_check + + +def cudagraphify( + model: Callable[..., Any], + static_input_idxs: Sequence[int] = (), + *, + device_index: int, + stack_traces: list[Optional[str]], + is_backward: bool, + is_inference: bool, + constants: tuple[torch.Tensor, ...] = (), + placeholders: Sequence[PlaceholderInfo] = (), + mutated_input_idxs: tuple[int, ...] = (), +) -> Callable[..., Any]: + from torch._inductor.cudagraph_trees import ( + cudagraphify_impl as new_cudagraphify_impl, + ) + + cudagraphify_fn: Callable[..., Any] + if config.triton.cudagraph_trees: + cudagraphify_fn = functools.partial( + new_cudagraphify_impl, + device_index=device_index, + stack_traces=stack_traces, + is_backward=is_backward, + is_inference=is_inference, + constants=constants, + placeholders=placeholders, + mutated_input_idxs=mutated_input_idxs, + compile_id=torch._guards.CompileContext.current_compile_id(), + ) + else: + cudagraphify_fn = cudagraphify_impl + + compiled_fn = None + + def run(new_inputs: Sequence[InputType]) -> Any: + nonlocal compiled_fn + if compiled_fn is None: + with dynamo_utils.preserve_rng_state(): + compiled_fn = cudagraphify_fn(model, new_inputs, static_input_idxs) # type: ignore[arg-type] + return compiled_fn(new_inputs) # type: ignore[arg-type] + + return run + + +def static_input(x: torch.Tensor) -> torch.Tensor: + """ + Copy and input while preserving strides + """ + return torch.empty_strided(x.size(), x.stride(), dtype=x.dtype, device=x.device) + + +def index_expanded_dims_and_copy_( + dst: torch.Tensor, + src: torch.Tensor, + expanded_dims: list[int], +) -> None: + "Index into expanded dimensions of both dst and src then copy_" + dst = index_expanded_dims(dst, expanded_dims) + src = index_expanded_dims(src, expanded_dims) + dst.copy_(src) + + +def cudagraphify_impl( + model: Callable[..., Any], + inputs: 
list[torch.Tensor], + static_input_idxs: Sequence[int] = (), +) -> Callable[[list[InputType]], Any]: + """ + Assumes inputs[static_input_idxs[i]] are always the same memory address + """ + check_input_idxs = get_input_idxs_to_check(inputs, static_input_idxs) # type: ignore[arg-type] + # pyrefly: ignore [annotation-mismatch] + static_input_idxs: OrderedSet[int] = OrderedSet( + remove_unaligned_input_idxs(inputs, static_input_idxs) # type: ignore[arg-type] + ) + copy_misaligned_inputs(inputs, check_input_idxs) # type: ignore[arg-type] + + assert isinstance(inputs, list) + + inps_expanded_dims = [ + get_expanded_dims(x) if idx not in static_input_idxs else [] + for idx, x in enumerate(inputs) + ] + + # allocate static tensor inputs + static_inputs = [ + ( + x + if not isinstance(x, torch.Tensor) + else static_input(x) + if idx not in static_input_idxs + else x.detach() + ) + for idx, x in enumerate(inputs) + ] + + # copy over input values for fresh allocations + for idx, (x, expanded_dims) in enumerate(zip(inputs, inps_expanded_dims)): + if isinstance(x, torch.Tensor) and idx not in static_input_idxs: + index_expanded_dims_and_copy_(static_inputs[idx], x, expanded_dims) + + # warmup + torch.cuda.synchronize() + stream = torch.cuda.Stream() + stream.wait_stream(torch.cuda.current_stream()) + # copy static_inputs because it will be cleared in model + with torch.cuda.stream(stream): + model(list(static_inputs)) + stream.synchronize() + torch.cuda.current_stream().wait_stream(stream) + torch.cuda.synchronize() + + # record + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph, stream=stream, capture_error_mode="thread_local"): + static_outputs = model(list(static_inputs)) + if not isinstance(static_outputs, (list, tuple)): + static_outputs = (static_outputs,) + + if config.size_asserts: + + def run(new_inputs: list[InputType]) -> Callable[[list[InputType]], Any]: + assert len(static_inputs) == len(new_inputs) + for idx, (dst, src, expanded_dims) in enumerate( + 
zip(static_inputs, new_inputs, inps_expanded_dims) + ): + if not isinstance(dst, torch.Tensor): + continue + assert isinstance(src, torch.Tensor) + if idx in static_input_idxs: + assert dst.data_ptr() == src.data_ptr() + else: + # TODO - could make one single op of multiple slices + # and avoid dispatch. + # Could also pre-index the `dst` tensors + index_expanded_dims_and_copy_(dst, src, expanded_dims) + new_inputs.clear() + graph.replay() + # pyrefly: ignore [bad-return] + return static_outputs + + else: + copy_indices = [ + idx for idx in range(len(static_inputs)) if idx not in static_input_idxs + ] + + def run(new_inputs: list[InputType]) -> Callable[[list[InputType]], Any]: + for idx in copy_indices: + expanded_dims = inps_expanded_dims[idx] + src = new_inputs[idx] + assert isinstance(src, torch.Tensor) + index_expanded_dims_and_copy_(static_inputs[idx], src, expanded_dims) + new_inputs.clear() + graph.replay() + # pyrefly: ignore [bad-return] + return static_outputs + + return align_inputs_from_check_idxs(run, check_input_idxs, OrderedSet()) + + +def compile_fx_aot( + model_: GraphModule, + example_inputs_: list[InputType], + inner_compile: _CompileFxCallable = compile_fx_inner, + config_patches: Optional[dict[str, Any]] = None, +) -> Union[list[Union[str, Weights]], str, GraphModule]: + assert isinstance(model_, GraphModule), model_ + + # [See NOTE] Unwrapping subclasses AOT + unwrap_tensor_subclass_parameters(model_) + + # pyrefly: ignore [annotation-mismatch] + config_patches: dict[str, Any] = copy.deepcopy(config_patches or {}) + + if not (config_patches.get("fx_wrapper", False) or config.fx_wrapper): + # If fx_wrapper is not set, then set cpp_wrapper + config_patches["cpp_wrapper"] = True + + output_path = config_patches.get( + "aot_inductor.output_path", config.aot_inductor.output_path + ) + + if output_path: + assert not output_path.endswith(".pt2"), ( + "The output path for aot_compile should not have an extension with .pt2 " + "this is for specifying 
the output path for the .so in AOTInductor. " + "If you would like to package the AOTInductor generated files " + "into a pt2, please call `torch._inductor.aoti_compile_and_package`." + ) + else: + config_patches = { + **config_patches, + "aot_inductor.output_path": code_hash(model_.code), + } + + from .utils import maybe_aoti_standalone_config + + config_patches = maybe_aoti_standalone_config(config_patches) + + extern_node_serializer = config_patches.pop("extern_node_serializer", None) + saved_compile_id = model_.meta.get("dynamo_compile_id", None) + saved_compile_context = torch._guards.CompileContext(saved_compile_id) + with ( + V.set_aot_compilation(True), + torch._guards.compile_context(saved_compile_context), + chromium_event_timed( + "compile_fx_aot", + log_pt2_compile_event=True, + reset_event_log_on_exit=True, + ), + get_metrics_context(), + ): + compiled_artifacts = compile_fx( + model_, + example_inputs_, + inner_compile=functools.partial( + inner_compile, + extern_node_serializer=extern_node_serializer, + ), + config_patches=config_patches, + ) + + assert isinstance(compiled_artifacts, CompiledAOTI) + + return compiled_artifacts.filename + + +_graph_counter = count(0) + + +def fw_compiler_freezing( + aot_autograd_model: GraphModule, + aot_example_inputs: Sequence[InputType], + dynamo_model: GraphModule, + num_example_inputs: int, + inner_compile: Callable[..., Any], + cudagraphs: BoxedBool, + graph_id: int, + forward_device: BoxedDeviceIndex, +) -> Callable[[list[object]], Sequence[torch.Tensor]]: + from torch._inductor.freezing import convert_conv_weights_to_channels_last, freeze + + # partition_fn won't be called + _recursive_joint_graph_passes(aot_autograd_model) + + layout_opt = GraphLowering.decide_layout_opt(aot_autograd_model, is_inference=True) + if layout_opt: + # make sure meta['val'] is properly setup + fake_tensor_prop(aot_autograd_model, aot_example_inputs, True) + convert_conv_weights_to_channels_last(aot_autograd_model) + + opt_model, 
preserved_arg_indices = freeze( + dynamo_model, + aot_autograd_model, + aot_example_inputs, # type: ignore[arg-type] + ) + + aot_example_inputs = [aot_example_inputs[ind] for ind in preserved_arg_indices] + + fake_mode = detect_fake_mode(aot_example_inputs) + + # for freezing, all graph outputs should be user visible + *_, model_outputs_node = opt_model.graph.nodes + model_outputs = model_outputs_node.args[0] + model_outputs_node.meta["user_visible_output_idxs"] = [ + idx for idx, n in enumerate(model_outputs) if isinstance(n, torch.fx.Node) + ] + + static_input_idxs: list[Any] = [] + # constant params will be real tensors, not fake + tracing_context = torch._guards.TracingContext.try_get() + unwrapped_args_offsets = [0] + max_offset_idx = 0 + if tracing_context is not None: + assert tracing_context.params_flat_unwrap_subclasses is not None + params_flat_unwrap = tracing_context.params_flat_unwrap_subclasses + max_offset_idx = max(0, len(params_flat_unwrap) - 1) + preserved_indices_params_flat = OrderedSet[int]() + unwrapped_idxs = tracing_context.params_unwrapped_to_flat_index + assert unwrapped_idxs is not None + current_offset = 0 + if len(params_flat_unwrap) > 0: + unwrapped_args_offsets = [] + + for i in range(len(params_flat_unwrap)): + if i not in preserved_arg_indices: + params_flat_unwrap[i] = None + if i > 0 and unwrapped_idxs[i] == unwrapped_idxs[i - 1]: + current_offset += 1 + else: + preserved_indices_params_flat.add(unwrapped_idxs[i]) + unwrapped_args_offsets.append(current_offset) + + # Deallocate wrapped params, if all subelements were deallocated + assert tracing_context.params_flat is not None + for i in range(len(tracing_context.params_flat)): + if i not in preserved_indices_params_flat: + tracing_context.params_flat[i] = None + + if tracing_context.fw_metadata: + static_input_idxs = tracing_context.fw_metadata.static_input_indices + + with mock.patch.object(fake_mode, "allow_non_fake_inputs", True): + optimized_function = inner_compile( + 
opt_model, + aot_example_inputs, + static_input_idxs=static_input_idxs, + cudagraphs=cudagraphs, + graph_id=graph_id, + is_inference=True, + boxed_forward_device_index=forward_device, + layout_opt=layout_opt, + ) + + # aot_inductor codegens a call that takes in just the inputs, so we don't return a wrapper + # that drops constant-ified params + if V.aot_compilation: + return optimized_function + + def wrapper(args: list[object]) -> Sequence[torch.Tensor]: + args_new = [ + args[i - unwrapped_args_offsets[min(i, max_offset_idx)]] + for i in preserved_arg_indices + ] + args.clear() + return optimized_function(args_new) + + wrapper._boxed_call = True # type: ignore[attr-defined] + + return wrapper + + +def get_cpp_wrapper_config() -> dict[str, object]: + if config.triton.cudagraphs: + log_cudagraph_skip_and_bump_counter( + format_default_skip_message("cpp wrapper enabled") + ) + + return { + # Set autotune_at_compile_time to True as default if the option is not explicitly set + "triton.autotune_at_compile_time": ( + config.triton.autotune_at_compile_time + if config.triton.autotune_at_compile_time is not None + else has_triton() + ), + "triton.autotune_cublasLt": False, + "triton.cudagraphs": False, # TODO: to be removed + "triton.store_cubin": True, + } + + +def get_cuda_device_context(gm: torch.fx.GraphModule) -> AbstractContextManager[None]: + """ + Returns a cuda device context manager if there is a single device in the graph + """ + if not torch.cuda.is_available(): + return contextlib.nullcontext() + + cuda_devices: OrderedSet[torch.device] = OrderedSet( + device for device in get_all_devices(gm) if device.type == "cuda" + ) + + return ( + torch.cuda.device(next(iter(cuda_devices))) # type: ignore[return-value] + if len(cuda_devices) == 1 + else contextlib.nullcontext() + ) + + +def partition_fn( + gm: GraphModule, + joint_inputs: Sequence[object], + **kwargs: object, +) -> tuple[GraphModule, GraphModule]: + cuda_context = get_cuda_device_context(gm) + with 
cuda_context: + # We can skip the invoke_subgraph because the + # entire_partition_fn is called recursively for invoke_subgraph + # in partitioning. + _recursive_joint_graph_passes(gm, skip_invoke_subgraph=True) + + static_lifetime_input_indices: Optional[list[int]] = kwargs.pop( # type: ignore[assignment] + "static_lifetime_input_indices", None + ) + + if config.custom_partitioner_fn is None: + with dynamo_utils.dynamo_timed( + "min_cut_rematerialization_partition", log_pt2_compile_event=True + ): + return min_cut_rematerialization_partition( + gm, + joint_inputs, + compiler="inductor", + static_lifetime_input_indices=static_lifetime_input_indices, + **kwargs, + ) + else: + assert isinstance(config.custom_partitioner_fn, CustomPartitionerFn) + with dynamo_utils.dynamo_timed( + config.custom_partitioner_fn.__class__.__name__, + log_pt2_compile_event=True, + ): + return config.custom_partitioner_fn( + gm, + joint_inputs, + compiler="inductor", + static_lifetime_input_indices=static_lifetime_input_indices, + **kwargs, + ) + + +def get_num_model_outputs(model: GraphModule) -> int: + model_outputs_node = output_node(model) + model_outputs = pytree.arg_tree_leaves(*model_outputs_node.args) + return len(model_outputs) + + +@dataclass(frozen=True) +class CompilerConfigExtra: + cudagraphs: BoxedBool + graph_id: int + forward_device: BoxedDeviceIndex + + +def create_compiler_config_extra(config: types.ModuleType) -> CompilerConfigExtra: + # Although cudagraphs may have been enabled via config, various + # conditions (which are tested within the bowels of Inductor) may + # force cudagraphs to be disabled. This mutable box lets us retrieve + # the final determination if cudagraphs actually can be used or not. + cudagraphs = BoxedBool(config.triton.cudagraphs) + + # TODO: The modern style is to use CompileId from TracingContext to + # identify Inductor compilation. 
However, this CompileId cannot + # uniquely identify multiple Inductor compilations that arise from + # DDPOptimizer + graph_id = next(_graph_counter) + + # See [Backward Generation Handling] + forward_device = BoxedDeviceIndex(None) + + return CompilerConfigExtra( + cudagraphs=cudagraphs, + graph_id=graph_id, + forward_device=forward_device, + ) + + +def compile_fx_forward( + gm: GraphModule, + example_inputs: Sequence[InputType], + num_orig_model_outputs: int, + num_example_inputs: int, + compiler_config_extra: CompilerConfigExtra, + inner_compile: Callable[..., OutputCode] = compile_fx_inner, + is_inference: bool = False, +) -> OutputCode: + """ + Compile the forward graph of the given graph module. + + Args: + gm: The graph module to compile. + example_inputs: The example inputs to use for compilation. + num_orig_model_outputs: The number of model outputs from the original dynamo graph. + num_example_inputs: The number of example inputs from the original dynamo graph. + compiler_config_extra: Extra configuration for the compiler. + inner_compile: The inner compile function to use. + is_inference: Whether this is an inference graph. 
+ """ + + if is_inference: + # partition_fn won't be called + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "before_joint_graph", + "encoding": "string", + }, + payload_fn=lambda: gm.print_readable( + print_output=False, include_stride=True, include_device=True + ), + ) + + _recursive_joint_graph_passes(gm) + + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "after_joint_graph", + "encoding": "string", + }, + payload_fn=lambda: gm.print_readable( + print_output=False, include_stride=True, include_device=True + ), + ) + + fixed = torch._inductor.utils.num_fw_fixed_arguments( + num_example_inputs, len(example_inputs) + ) + + model_outputs_node = output_node(gm) + if config.keep_output_stride: + model_outputs = pytree.arg_tree_leaves(*model_outputs_node.args) + num_model_outputs = len(model_outputs) + + context = torch._guards.TracingContext.try_get() + # See Note [User Outputs in the inductor graph] + if context is not None and context.fw_metadata and not is_inference: + original_output_start_index = ( + context.fw_metadata.num_mutated_inp_runtime_indices + ) + else: + original_output_start_index = 0 + + assert num_orig_model_outputs <= num_model_outputs + + # Note [User Outputs in the inductor graph] + # We makes the following assumption + # For inference + # len(orig_model_outputs) == len(model_outputs) + # For training + # len(orig_model_outputs) <= len(model_outputs) + # During training, most of the time the model_outputs starts with + # original module's outputs followed by saved activations. + # But this can be not true if the model have inplace updated tensors. + # AOTAutograd will make those tensors being returned before the original + # module's output. + # To make things safe, we'll use original_output_start_index field + # set by AOTAutograd to decide where the original module outputs start. 
+ orig_output_end_idx = original_output_start_index + num_orig_model_outputs + # Sanity check: we are about to splice out the "user" outputs from the full set + # of "graph" outputs. Make sure we're within bounds. + assert orig_output_end_idx <= num_model_outputs + + model_outputs_node.meta["user_visible_output_idxs"] = [ + idx + for idx in range(original_output_start_index, orig_output_end_idx) + if isinstance(model_outputs[idx], torch.fx.Node) + ] + else: + model_outputs_node.meta["user_visible_output_idxs"] = [] + + # We also mark the invoke_subgraph outputs as user_visible to + # force the outputs of invoke_subgraph subgraph to follow the + # original strides + _recursive_record_user_visible_output_idxs(gm) + + return inner_compile( + gm, + example_inputs, + static_input_idxs=get_static_input_idxs(fixed), + cudagraphs=compiler_config_extra.cudagraphs, + graph_id=compiler_config_extra.graph_id, + is_inference=is_inference, + boxed_forward_device_index=compiler_config_extra.forward_device, + ) + + +def compile_fx_backward( + gm: GraphModule, + example_inputs: Sequence[InputType], + compiler_config_extra: CompilerConfigExtra, + inner_compile: Callable[..., OutputCode] = compile_fx_inner, +) -> OutputCode: + """ + Compile the backward graph of the given graph module. + + Args: + gm: The graph module to compile. + example_inputs: The example inputs to use for compilation. + compiler_config_extra: Extra configuration for the compiler. + inner_compile: The inner compile function to use. 
+ """ + from torch._dynamo.convert_frame import compile_lock + + with compile_lock: + model_outputs_node = output_node(gm) + if config.bw_outputs_user_visible: + model_outputs = pytree.arg_tree_leaves(*model_outputs_node.args) + model_outputs_node.meta["user_visible_output_idxs"] = [ + idx + for idx, n in enumerate(model_outputs) + if isinstance(n, torch.fx.Node) + ] + else: + model_outputs_node.meta["user_visible_output_idxs"] = [] + + fixed = count_tangents(gm) + with ( + config.patch(get_cpp_wrapper_config()) + if config.cpp_wrapper + else contextlib.nullcontext() + ): + return inner_compile( + gm, + example_inputs, + static_input_idxs=list(range(fixed)), + cudagraphs=compiler_config_extra.cudagraphs, + is_backward=True, + graph_id=compiler_config_extra.graph_id, + boxed_forward_device_index=compiler_config_extra.forward_device, + ) + + +def run_pre_grad_passes( + model_: GraphModule, example_inputs_: Sequence[InputType] +) -> GraphModule: + # "before_pre_grad_graph" is used in inductor provenance + # tracking highlighter front-end. 
+ trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "before_pre_grad_graph", + "encoding": "string", + }, + payload_fn=lambda: model_.print_readable( + print_output=False, include_stride=True, include_device=True + ) + + f"\n\n # graph id: {id(model_.graph)}", + ) + pre_grad_graphs_log.debug( + "%s", + lazy_format_graph_code( + "BEFORE PRE GRAD", + model_, + include_stride=True, + include_device=True, + colored=True, + ), + ) + torch._inductor.debug._pre_grad_graph_id = id(model_.graph) + + if config.trace.provenance_tracking_level == 1: + for node in model_.graph.nodes: + if node.stack_trace: + torch._inductor.debug._inductor_pre_grad_node_stack_trace[node.name] = ( + node.stack_trace + ) + + model_ = _recursive_pre_grad_passes(model_, example_inputs_) + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "after_pre_grad_graph", + "encoding": "string", + }, + payload_fn=lambda: model_.print_readable( + print_output=False, include_stride=True, include_device=True + ) + + f"\n\n # graph id: {id(model_.graph)}", + ) + return model_ + + +def compile_fx( + model_: GraphModule, + example_inputs_: Sequence[InputType], + inner_compile: Callable[..., OutputCode] = compile_fx_inner, + config_patches: Optional[dict[str, Any]] = None, + decompositions: Optional[dict[OpOverload, Callable[..., Any]]] = None, + ignore_shape_env: bool = False, +) -> CompileFxOutput: + """ + Main entry point for compiling given FX graph. Despite the fact that this + lives in :mod:`torch._inductor`, this function is responsible for calling + into AOT Autograd (and we will eventually get a callback to + ``inner_compile`` to perform actual compilation. In other words, this + function orchestrates end-to-end compilation for the inductor backend when + you use :func:`torch.compile`. + + NB: This function TAKES OWNERSHIP of the input ``model_`` and can potentially + mutate it! Make a copy if you need to preserve the original GraphModule. 
+ """ + # Some arguments trigger a recursive call to compile_fx. Handle these + # short circuits first, before anything else + + from torch._inductor.compiler_bisector import CompilerBisector + + if CompilerBisector.disable_subsystem("inductor", "pre_grad_graph"): + return model_ + + if config_patches: + with config.patch(config_patches): + return compile_fx( + model_, + example_inputs_, + # need extra layer of patching as backwards is compiled out of scope + inner_compile=config.patch(config_patches)(inner_compile), + decompositions=decompositions, + ignore_shape_env=ignore_shape_env, + ) + + # Wake up the AsyncCompile subproc pool as early as possible (if there's cuda). + if any( + isinstance(e, torch.Tensor) and e.device.type in ("cuda", "xpu") + for e in example_inputs_ + ): + torch._inductor.async_compile.AsyncCompile.wakeup() + + if config.cpp_wrapper or config.fx_wrapper: + from torch._export.non_strict_utils import _fakify_script_objects + + cpp_wrapper_config = config.cpp_wrapper + fx_wrapper_config = config.fx_wrapper + + with ( + config.patch(get_cpp_wrapper_config()), + V.set_real_inputs(example_inputs_), + ): + inputs_: Sequence[InputType] = ( + _extract_inputs_from_exported_gm(model_, example_inputs_) + if isinstance(model_, GraphModule) + else example_inputs_ + ) + fake_mode = detect_fake_mode(inputs_) + with _fakify_script_objects(model_, inputs_, {}, fake_mode) as ( + patched_mod, + fake_args, + _, + _, + _, + ): + return _maybe_wrap_and_compile_fx_main( + patched_mod, + fake_args, + inner_compile=functools.partial( + inner_compile, + cpp_wrapper=cpp_wrapper_config, + fx_wrapper=fx_wrapper_config, + ), + decompositions=decompositions, + ignore_shape_env=ignore_shape_env, + ) + + return _maybe_wrap_and_compile_fx_main( + model_, + example_inputs_, + inner_compile, + decompositions, + ignore_shape_env, + ) + + +def _extract_inputs_from_exported_gm( + gm: GraphModule, example_inputs_: Sequence[InputType] +) -> Sequence[InputType]: + fake_inputs = [ + 
node.meta.get("val") for node in gm.graph.nodes if node.op == "placeholder" + ] + + if not config.fx_wrapper: + # Replace non-tensor inputs with Nones + # constant scalars embedded in the graph + # symbolic scalars (symint) are not supported in non-fx_wrapper mode + fake_inputs = [ + inp if isinstance(inp, torch.Tensor) else None for inp in fake_inputs + ] + + if any(v is not None for v in fake_inputs): + # Validate devices before switching to fake tensors. + for idx, fi, i in zip(count(), fake_inputs, example_inputs_): + if fi is not None and isinstance(fi, torch.Tensor): + assert isinstance(i, torch.Tensor) + if fi.device != i.device: + raise ValueError( + f"Device mismatch between fake input and example input at position #{idx}: " + f"{fi.device} vs {i.device}. If the model was exported via torch.export(), " + "make sure torch.export() and torch.aot_compile() run on the same device." + ) + return fake_inputs + + return example_inputs_ + + +def _maybe_wrap_and_compile_fx_main( + model_: GraphModule, + example_inputs_: Sequence[InputType], + inner_compile: Callable[..., OutputCode], + decompositions: Optional[dict[OpOverload, Callable[..., Any]]], + ignore_shape_env: bool, +) -> CompileFxOutput: + """ + Part of compile_fx, called after patching configs. + + Ultimately we want to call _compile_fx_main, where the actual work happens. + But under various conditions, various forms of wrapping might be needed + around _compile_fx_main. + """ + # Each wrapper below takes a self-contained compile_gm function which is + # called inside the wrapper. This just recursively calls this function. 
+ compile_gm = functools.partial( + _maybe_wrap_and_compile_fx_main, + inner_compile=inner_compile, + decompositions=decompositions, + ignore_shape_env=ignore_shape_env, + ) + if not graph_returns_tuple(model_): + return make_graph_return_tuple(model_, example_inputs_, compile_gm) + + if isinstance(model_, GraphModule) and isinstance( + model_.graph._codegen, _PyTreeCodeGen + ): + # this graph is the result of dynamo.export() + return handle_dynamo_export_graph(model_, example_inputs_, compile_gm) + + if any(isinstance(x, (list, tuple, dict)) for x in example_inputs_): + # NB: this short circuit never occurs for Dynamo produced graphs + # (which are pre-flattened) + return flatten_graph_inputs(model_, example_inputs_, compile_gm) + + # Finally do the actual work! + return _compile_fx_main( + model_, + example_inputs_, + inner_compile, + decompositions, + ignore_shape_env, + ) + + +def _compile_fx_main( + model_: GraphModule, + example_inputs_: Sequence[InputType], + inner_compile: Callable[..., OutputCode], + decompositions: Optional[dict[OpOverload, Callable[..., Any]]], + ignore_shape_env: bool, +) -> CompileFxOutput: + """ + Main part of compile_fx, called after wrapping is done. + + Roughly speaking, here the steps will be: + (1) apply pre-grad passes + (2) create `fw_compiler` and `bw_compiler` functions out of `inner_compile` + (3) call aot_autograd, which: + - (3a) creates a joint graph with `decompositions`, + - (3b) partitions it with `partition_fn` into fw and bw graphs (applying joint-graph passes), + - (3c) calls `fw_compiler` and `bw_compiler` on those graphs (applying post-grad passes) + - (3d) finally, assembles the fw and bw compiled functions back together and returns. 
+ """ + with ( + _use_lazy_graph_module(dynamo_config.use_lazy_graph_module), + enable_python_dispatcher(), + torch.fx.traceback.preserve_node_meta( + config.trace.provenance_tracking_level == 1 + ), + torch._inductor.debug.reset_provenance_globals(), + ): + # Pre-grad passes cannot be run if we weren't given a GraphModule. + # Dynamo will always produce a GraphModule, but this handles cases + # where a user directly passes a plain Module with the intention of + # having AOTAutograd trace it. + # TODO: Get rid of this? + if isinstance(model_, GraphModule): + model_ = run_pre_grad_passes(model_, example_inputs_) + + assert not config._raise_error_for_testing + + num_example_inputs = len(example_inputs_) + + compiler_config_extra = create_compiler_config_extra(config) + + decompositions = ( + decompositions if decompositions is not None else select_decomp_table() + ) + + def fw_compiler_base( + gm: GraphModule, + example_inputs: Sequence[InputType], + is_inference: bool, + ) -> OutputCode: + with dynamo_utils.dynamo_timed("compile_fx..fw_compiler_base"): + if isinstance(model_, GraphModule): + num_orig_model_outputs = get_num_model_outputs(model_) + else: + num_orig_model_outputs = get_num_model_outputs(gm) + return compile_fx_forward( + gm, + example_inputs, + num_orig_model_outputs=num_orig_model_outputs, + num_example_inputs=num_example_inputs, + compiler_config_extra=compiler_config_extra, + inner_compile=inner_compile, + is_inference=is_inference, + ) + + fw_compiler: Callable[[GraphModule, Sequence[InputType]], OutputCode] = ( + functools.partial(fw_compiler_base, is_inference=False) + ) + fw_compiler = SerializableAOTDispatchCompiler(OutputCode, fw_compiler) + + if config.freezing and not torch.is_grad_enabled(): + inference_compiler: Callable[..., Any] = functools.partial( + fw_compiler_freezing, + dynamo_model=model_, + num_example_inputs=num_example_inputs, + inner_compile=inner_compile, + cudagraphs=compiler_config_extra.cudagraphs, + 
graph_id=compiler_config_extra.graph_id, + forward_device=compiler_config_extra.forward_device, + ) + else: + inference_compiler = functools.partial(fw_compiler_base, is_inference=True) + inference_compiler = SerializableAOTDispatchCompiler( + OutputCode, inference_compiler + ) + + @compile_time_strobelight_meta(phase_name="backward") + def bw_compiler( + gm: GraphModule, example_inputs: Sequence[InputType] + ) -> OutputCode: + with ( + dynamo_utils.dynamo_timed("compile_fx..bw_compiler"), + ): + return compile_fx_backward( + gm, + example_inputs, + compiler_config_extra=compiler_config_extra, + inner_compile=inner_compile, + ) + + bw_compiler = SerializableAOTDispatchCompiler(OutputCode, bw_compiler) + + fake_mode = detect_fake_mode( + example_inputs_ + ) or torch._subclasses.FakeTensorMode(allow_non_fake_inputs=True) + tracing_context = ( + torch._guards.TracingContext.try_get() + or torch._guards.TracingContext(fake_mode) + ) + + if V.aot_compilation and not config.enable_autograd_for_aot: + from .utils import is_valid_aoti_model_name + + is_valid_aoti_model_name() + + with functorch_config.patch( + unlift_effect_tokens=True, + selective_decompose=config.selective_decompose, + ): + gm, graph_signature = aot_export_module( + model_, + example_inputs_, + trace_joint=False, + decompositions=decompositions, + ) + + from torch._export.utils import _detect_fake_mode_from_gm + + fake_mode = _detect_fake_mode_from_gm(gm) # type: ignore[assignment] + # aot_export_module doesn't account for constant tensor attributes + # so we end up having tensors that don't have fake vals attached. + # This can happen when upstream export is non-strict where we + # preserve the original module params/buffers. Once AOTI switches + # to ep.run_decompositions() flow to lower to post-autograd opset + # this will go away. 
+ for node in gm.graph.nodes: + if node.op == "get_attr" and "val" not in node.meta: + target = attrgetter(node.target)(gm) + if isinstance(target, torch.Tensor): + assert fake_mode is not None + node.meta["val"] = fake_mode.from_tensor( + target, static_shapes=True + ) + elif isinstance(target, torch.ScriptObject) or is_opaque_type( + type(target) + ): + node.meta["val"] = ( + torch._library.fake_class_registry.maybe_to_fake_obj( + fake_mode, target + ) + ) + elif isinstance(target, FakeScriptObject): + node.meta["val"] = target + + unlifted_gm = _unlift_graph(model_, gm, graph_signature) + if "dynamo_flat_name_to_original_fqn" in model_.meta: + unlifted_gm.meta["dynamo_flat_name_to_original_fqn"] = model_.meta[ + "dynamo_flat_name_to_original_fqn" + ] + + if "dynamo_compile_id" in model_.meta: + unlifted_gm.meta["dynamo_compile_id"] = model_.meta["dynamo_compile_id"] + + # Disable amp as in aot_dispatch_autograd (https://github.com/pytorch/pytorch/pull/86515) + # In inference_compiler (fw_compiler_base), _recursive_joint_graph_passes will call into + # _sfdp_init() to register patterns. + # When fallback_random is set to True, the sdpa patterns will be traced during runtime. + # If amp is turned on, the traced FP32 patterns will have prims.convert_element_type which + # will be the same as the generated FP16 patterns. 
+ disable_amp = torch._C._is_any_autocast_enabled() + context = ( + torch._C._DisableAutocast if disable_amp else contextlib.nullcontext + ) + with V.set_fake_mode(fake_mode), compiled_autograd._disable(), context(): + return inference_compiler(unlifted_gm, example_inputs_) + + with ( + V.set_fake_mode(fake_mode), + torch._guards.tracing(tracing_context), + compiled_autograd._disable(), + functorch_config.patch( + unlift_effect_tokens=True, + selective_decompose=config.selective_decompose, + ), + ): + try: + return aot_autograd( + fw_compiler=fw_compiler, + bw_compiler=bw_compiler, + inference_compiler=inference_compiler, + decompositions=decompositions, + partition_fn=partition_fn, + keep_inference_input_mutations=True, + cudagraphs=compiler_config_extra.cudagraphs, + boxed_forward_device_index=compiler_config_extra.forward_device, + ignore_shape_env=ignore_shape_env, + )(model_, example_inputs_) + except ShortenTraceback as e: + # We will also shorten the traceback inside dynamo. + # This is only useful if inductor is called directly with an FX graph. + raise e.remove_dynamo_frames() from None # see TORCHDYNAMO_VERBOSE=1 + + +def graph_returns_tuple(gm: GraphModule) -> bool: + """True if a FX graph returns a tuple""" + if not isinstance(gm, GraphModule): + return True # can't check this, assume true + (rv,) = output_node(gm).args + if isinstance(rv, (list, tuple)): + return True + if ( + isinstance(rv, torch.fx.node.Node) + and hasattr(rv.target, "_schema") + and len(rv.target._schema.returns) > 1 + and all(str(ret.type) == "Tensor" for ret in rv.target._schema.returns) + ): + # for graphs whose result is one node with multiple outputs + return True + return False + + +def make_graph_return_tuple( + gm: GraphModule, + inputs: Sequence[InputType], + compile_gm: Callable[..., Any], +) -> Callable[..., Any]: + """ + Mutate gm so it returns a tuple. This is only needed for graphs + not created by torchdynamo that return non-tuples. 
+ """ + node = output_node(gm) + (rv,) = node.args + rv, spec = pytree.tree_flatten(rv) + with gm.graph.inserting_before(node): + gm.graph.output(rv) + gm.graph.erase_node(node) + assert graph_returns_tuple(gm) + + compiled_fn = compile_gm(gm, inputs) + + @functools.wraps(compiled_fn) + def wrapper(*args: Any, **kwargs: Any) -> Any: + return pytree.tree_unflatten(compiled_fn(*args, **kwargs), spec) + + return wrapper + + +def handle_dynamo_export_graph( + gm: GraphModule, + inputs: Sequence[InputType], + compile_gm: Callable[..., Any], +) -> Callable[..., Any]: + """ + `torch._dynamo.export` embeds pytrees in the FX graph codegen object, + convert that to a normal FX graph so inductor can compile it. + """ + codegen = gm.graph._codegen + gm.graph._codegen = torch.fx.graph.CodeGen() + gm.recompile() + + compiled_fn = compile_gm(gm, codegen.process_inputs(*inputs)) + + @functools.wraps(compiled_fn) # type: ignore[misc] + def wrapper(*args: Any) -> Any: + return codegen.process_outputs(compiled_fn(*codegen.process_inputs(*args))) + + return wrapper + + +def _check_triton_bf16_support(graph: GraphLowering) -> None: + def warn_and_skip(device: Optional[torch.device]) -> Never: + from torch._dynamo.exc import SkipFrame + + assert device is not None + + device_interface = get_interface_for_device(device.type) + device_props = device_interface.get_device_properties(device) + warnings.warn( + f"{device_props.name} does not support bfloat16 compilation natively, skipping" + ) + raise SkipFrame("BF16 is not supported") + + for node in itertools.chain(graph.graph_inputs.values(), graph.graph_outputs): + if not isinstance(node, IRNode): + continue + device_type = get_device_type(node) + if ( + not device_type + or not is_gpu(device_type) + or node.get_dtype() != torch.bfloat16 + ): + continue + # Print warning and skip frame if attempting to compile for bfloat16 + # on device without hardware support for dtype + device_interface = get_interface_for_device(device_type) + if 
device_interface.is_bf16_supported(including_emulation=False): + return + warn_and_skip(node.get_device()) + + +def _aoti_flatten_inputs( + gm: torch.fx.GraphModule, + args: Union[list[Any], tuple[Any, ...]], + kwargs: Optional[dict[str, Any]] = None, + *, + options: Optional[dict[str, Any]] = None, +) -> tuple[list[Any], dict[str, Any]]: + """ + Flatten the inputs to the graph module and return the flat inputs and options. + Add "aot_inductor.serialized_in_spec" and "aot_inductor.serialized_out_spec" to the options. + """ + # pyrefly: ignore [missing-module-attribute] + from .compile_fx import graph_returns_tuple + + assert graph_returns_tuple(gm), ( + "Graph output must be a tuple(). This is so that we can avoid " + "pytree processing of the outputs. Please change the module to " + "have tuple outputs." + ) + + # We will serialize the pytree info into the .so as constant strings + in_spec = None + out_spec = None + if isinstance(gm.graph._codegen, torch.fx.graph._PyTreeCodeGen): + codegen = gm.graph._codegen + gm.graph._codegen = torch.fx.graph.CodeGen() + gm.recompile() + + if codegen.pytree_info.in_spec is not None: + in_spec = codegen.pytree_info.in_spec + if codegen.pytree_info.out_spec is not None: + out_spec = codegen.pytree_info.out_spec + + else: + if hasattr(gm, "_in_spec"): + in_spec = gm._in_spec + if hasattr(gm, "_out_spec"): + out_spec = gm._out_spec + + serialized_in_spec = pytree.treespec_dumps(in_spec) if in_spec is not None else "" + serialized_out_spec = ( + pytree.treespec_dumps(out_spec) if out_spec is not None else "" + ) + + flat_args_with_path, received_spec = pytree.tree_flatten_with_path( + (args, kwargs or {}) + ) + + if any(isinstance(x[1], torch.ScriptObject) for x in flat_args_with_path): + from torch._dynamo.exc import UserError, UserErrorType + + raise UserError( + UserErrorType.INVALID_INPUT, + "TorchBind objects found in inputs. TorchBind object inputs are not supported in AOTInductor. 
" + "TorchBind objects can only be attributes.", + ) + + # Replace non-tensor (constant) inputs with Nones, since these are not being + # used anyways by the graph + flat_example_inputs = [ + x[1] if isinstance(x[1], torch.Tensor) else None for x in flat_args_with_path + ] + + if in_spec is not None and received_spec != in_spec: + raise ValueError( # noqa: B904 + "Trying to flatten user inputs with exported input tree spec: \n" + f"{in_spec}\n" + "but actually got inputs with tree spec of: \n" + f"{received_spec}" + ) + + options = ( + { + "aot_inductor.serialized_in_spec": serialized_in_spec, + "aot_inductor.serialized_out_spec": serialized_out_spec, + } + if options is None + else { + **options, + "aot_inductor.serialized_in_spec": serialized_in_spec, + "aot_inductor.serialized_out_spec": serialized_out_spec, + } + ) + return flat_example_inputs, options diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/compile_fx_async.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/compile_fx_async.py new file mode 100644 index 0000000000000000000000000000000000000000..95a0832349b1c0df53f5a2e429cb41d382557342 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/compile_fx_async.py @@ -0,0 +1,398 @@ +from __future__ import annotations + +from collections import deque +from dataclasses import dataclass +from typing import Any, Optional, TYPE_CHECKING +from typing_extensions import final, override + +import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools +from torch._inductor.output_code import CompiledFxGraphConstants, OutputCode + +from .compile_fx import _CompileFxKwargs, _InProcessFxCompile, FxCompile +from .output_code import complex_memory_overlap # noqa: F401 + + +# When async compile works with cache, remove the disabling below +BUG_CACHES_DONT_WORK_WITH_ASYNC = True + + +if TYPE_CHECKING: + from collections.abc import Callable, 
Sequence + from concurrent.futures import Future + + from torch._inductor.utils import InputType + from torch.fx import GraphModule + + from .compile_fx_ext import _OutOfProcessFxCompile, _WireProtocolPickledOutput + + +@dataclass +class _PostCompileData: + example_inputs: Sequence[InputType] + constants: CompiledFxGraphConstants + graph_kwargs: _CompileFxKwargs + + +@dataclass +class ProgressiveCompilationState: + progression_futures: deque[Future[_WireProtocolPickledOutput]] + callback: Callable[[_WireProtocolPickledOutput], OutputCode] + post_compile_data: Optional[_PostCompileData] + + def check_and_get_ready_stage(self) -> int: + """Check if any progression stage is ready and return its index, or -1 if none are ready.""" + if not self.progression_futures: + return -1 + + stage_index = -1 + if self.post_compile_data: + for i, future in enumerate(self.progression_futures): + if future.done(): + stage_index = i + + return stage_index + + def switch_to_progression_stage(self, stage_index: int) -> tuple[OutputCode, bool]: + """ + Switch to the specified progression stage and return the optimized output code. + Returns a tuple of (optimized_output_code, should_clear_compilation_state). + """ + future = self.progression_futures[stage_index] + assert future is not None + optimized_output_code = self.callback(future.result()) + + if pcd := self.post_compile_data: + optimized_output_code.post_compile( + pcd.example_inputs, pcd.constants, pcd.graph_kwargs + ) + + # Clear earlier progression futures to free memory + for _ in range(stage_index + 1): + self.progression_futures.popleft() + + # Return whether all compilation state should be cleared + should_clear_state = not self.progression_futures + return optimized_output_code, should_clear_state + + +# _AsyncOutputCode handles the actual management of waiting for an +# out-of-process compile to finish and then switching over to it. 
+@final +class _AsyncOutputCode(OutputCode): + _eager_fn: Optional[Callable[..., Any]] + _output_code: Optional[OutputCode] + _future: Optional[Future[_WireProtocolPickledOutput]] + _callback: Callable[[_WireProtocolPickledOutput], OutputCode] + _post_compile_data: Optional[_PostCompileData] = None + _boxed_call: bool # Copied from the forward/output_code + + def __init__( + self, + # eager_fn is run until the future is finished. + eager_fn: Callable[..., Any], + # this responds with the result of the out-of-process compile when it's + # ready. + future: Future[_WireProtocolPickledOutput], + # this callback gets called to turn the _WireProtocolPickledOutput into an OutputCode + callback: Callable[[_WireProtocolPickledOutput], OutputCode], + ) -> None: + self._eager_fn = eager_fn + self._boxed_call = getattr(eager_fn, "_boxed_call", False) + self._output_code = None + + self._future = future + self._callback = callback + + @override + def __call__(self, *args: Any) -> Any: + if self._future is not None and self._future.done(): + args = self._switch_to_compiled_fn(args) + + if eager_fn := self._eager_fn: + _AsyncFxCompile._stat_eager_runs += 1 + return eager_fn(*args) + + else: + _AsyncFxCompile._stat_compiled_runs += 1 + assert self._output_code is not None + return self._output_code.__call__(*args) + + # Takes and returns the args (converted to the "right" boxed mode) + def _switch_to_compiled_fn(self, args: tuple[Any, ...]) -> tuple[Any, ...]: + assert self._future is not None + + # TODO: If the future ended in an exception do we want to continue + # running eager or hit the exception now? 
+ f, self._future = self._future, None + output_code = self._callback(f.result()) + + if pcd := self._post_compile_data: + self._post_compile_data = None + + output_code.post_compile( + pcd.example_inputs, pcd.constants, pcd.graph_kwargs + ) + + self._output_code = output_code + self._eager_fn = None + boxed_call = getattr(output_code, "_boxed_call", False) + + if self._boxed_call != boxed_call: + if self._boxed_call: + # Was boxed, now unboxed + args = args[0] if len(args) > 0 else () + else: + # Was unboxed, now boxed + args = (args,) + + self._boxed_call = boxed_call + return args + + @override + def post_compile( + self, + example_inputs: Sequence[InputType], + constants: CompiledFxGraphConstants, + graph_kwargs: _CompileFxKwargs, + ) -> None: + if self._eager_fn is not None: + self._post_compile_data = _PostCompileData( + example_inputs, constants, graph_kwargs + ) + else: + assert self._output_code is not None + self._output_code.post_compile(example_inputs, constants, graph_kwargs) + + +# Given an FxCompile for an out-of-process compile _AsyncFxCompile will run +# eager until the compiled artifact is ready then it will automatically switch +# over to using the compiled version. +@final +class _AsyncFxCompile(FxCompile): + _compile: _OutOfProcessFxCompile + + # Some debugging stats: + # Number of times we started a background compile. + _stat_bg_started: int = 0 + # Number of times we finished a background compile. 
+ _stat_bg_finished: int = 0 + # Number of times we ran "eager" + _stat_eager_runs: int = 0 + # Number of times we ran our compiled (out-of-process) artifact + _stat_compiled_runs: int = 0 + + def __init__(self, compile: _OutOfProcessFxCompile) -> None: + self._compile = compile + + @classmethod + def _reset_stats(cls) -> None: + cls._stat_bg_started = 0 + cls._stat_bg_finished = 0 + cls._stat_eager_runs = 0 + cls._stat_compiled_runs = 0 + + @override + def codegen_and_compile( + self, + gm: GraphModule, + example_inputs: Sequence[InputType], + inputs_to_check: Sequence[int], + graph_kwargs: _CompileFxKwargs, + ) -> OutputCode: + eager_output_code = _InProcessFxCompile().codegen_and_compile( + gm, example_inputs, inputs_to_check, graph_kwargs + ) + + # This is similar to _SerializedFxCompile.codegen_and_compile() but + # handles the async routing. + + serialized = self._compile.serialize_compile( + gm, example_inputs, inputs_to_check, graph_kwargs + ) + if not serialized: + # We can't serialize - just return the eager OutputCode + return eager_output_code + + inputs, constants = serialized + + _AsyncFxCompile._stat_bg_started += 1 + f = self._compile._send_to_child_async(inputs) + + # This is called by _switch_to_compiled_fn() when f has a result... + def callback(pickled_output: _WireProtocolPickledOutput) -> OutputCode: + _AsyncFxCompile._stat_bg_finished += 1 + output = pickled_output.deserialize(constants) + self._compile._postprocess(output) + return output.graph + + return _AsyncOutputCode(eager_output_code, f, callback) + + +# _ProgressiveOutputCode handles running a fast compile first, then hot-swapping +# to a more optimized version when the expensive compile finishes. 
+@final +class _ProgressiveOutputCode(OutputCode): + _fast_output_code: Optional[OutputCode] + _optimized_output_code: Optional[OutputCode] + _compilation_state: Optional[ProgressiveCompilationState] + # _boxed_call state is effectively cached (we sometimes wrap unboxed w/ + # lambdas to box them) so we can't change it mid-way. Since _boxed_call=True + # is more common let's default to that and we'll convert if necessary. + _boxed_call: bool = True + + def __init__( + self, + # Fast compile that runs faster than the progressive compiles + fast_output_code: OutputCode, + # Futures for the progressive optimized compiles + progression_futures: Sequence[Future[_WireProtocolPickledOutput]], + # Callback to convert the optimized result to OutputCode + callback: Callable[[_WireProtocolPickledOutput], OutputCode], + ) -> None: + self._fast_output_code = fast_output_code + self._optimized_output_code = None + self._compilation_state = ProgressiveCompilationState( + progression_futures=deque(progression_futures), + callback=callback, + post_compile_data=None, + ) + + @override + def __call__(self, args: Sequence[Any]) -> Any: + # Check if any newer progression stage is ready and switch to it + self._check_and_switch_progression() + + if self._optimized_output_code is not None: + _ProgressiveFxCompile._stat_optimized_runs += 1 + output_code = self._optimized_output_code + else: + _ProgressiveFxCompile._stat_fast_runs += 1 + assert self._fast_output_code is not None + output_code = self._fast_output_code + + boxed_call = getattr(output_code, "_boxed_call", False) + if boxed_call: + res = output_code.__call__(args) + else: + res = output_code.__call__(*args) + return res + + def _check_and_switch_progression(self) -> None: + if not self._compilation_state: + return + + stage_index = self._compilation_state.check_and_get_ready_stage() + if stage_index == -1: + # no futures are ready + return + + self._switch_to_progression_stage(stage_index) + + def 
_switch_to_progression_stage(self, stage_index: int) -> None: + assert self._compilation_state is not None + optimized_output_code, should_clear_state = ( + self._compilation_state.switch_to_progression_stage(stage_index) + ) + + self._optimized_output_code = optimized_output_code + self._fast_output_code = None + + # Clear all compilation state if no more progression futures are left + if should_clear_state: + self._compilation_state = None + + @override + def post_compile( + self, + example_inputs: Sequence[InputType], + constants: CompiledFxGraphConstants, + graph_kwargs: _CompileFxKwargs, + ) -> None: + assert self._fast_output_code is not None + self._fast_output_code.post_compile(example_inputs, constants, graph_kwargs) + + assert self._compilation_state is not None + # Store for later when optimized version is ready + self._compilation_state.post_compile_data = _PostCompileData( + example_inputs, constants, graph_kwargs + ) + + +# _ProgressiveFxCompile runs a fast compile immediately, then kicks off +# progressive compiles in the background and hot-swaps when they're ready. 
+@final +class _ProgressiveFxCompile(FxCompile): + _fast_compile: FxCompile + _optimized_compile: _OutOfProcessFxCompile + _progression_configs: list[dict[str, Any]] + + # Debugging stats + _stat_bg_started: int = 0 + _stat_bg_finished: int = 0 + _stat_fast_runs: int = 0 + _stat_optimized_runs: int = 0 + + def __init__( + self, + fast_compile: FxCompile, + optimized_compile: _OutOfProcessFxCompile, + progression_configs: list[dict[str, Any]], + ) -> None: + self._fast_compile = fast_compile + self._optimized_compile = optimized_compile + self._progression_configs = progression_configs + + @classmethod + def _reset_stats(cls) -> None: + cls._stat_bg_started = 0 + cls._stat_bg_finished = 0 + cls._stat_fast_runs = 0 + cls._stat_optimized_runs = 0 + + @override + def codegen_and_compile( + self, + gm: GraphModule, + example_inputs: Sequence[InputType], + inputs_to_check: Sequence[int], + graph_kwargs: _CompileFxKwargs, + ) -> OutputCode: + import torch._inductor.config as inductor_config + + progression_futures: list[Future[_WireProtocolPickledOutput]] = [] + + for config in self._progression_configs: + with inductor_config.patch(config): + _ProgressiveFxCompile._stat_bg_started += 1 + + # Start the progressive compiles in the background + serialized = self._optimized_compile.serialize_compile( + gm, example_inputs, inputs_to_check, graph_kwargs + ) + + if not serialized: + continue + + inputs, constants = serialized + future = self._optimized_compile._send_to_child_async(inputs) + progression_futures.append(future) + + fast_output_code = self._fast_compile.codegen_and_compile( + gm, example_inputs, inputs_to_check, graph_kwargs + ) + + if not progression_futures: + # All async compile attempts failed - just return the fast version + return fast_output_code + + # Callback to handle the optimized result. 
+ # This callback may be called multiple times, once for each progressive level completed, + # but may be skipped if a level either never completes or if a more optimal level + # completes before a less optimal one is switched to. + def callback(pickled_output: _WireProtocolPickledOutput) -> OutputCode: + _ProgressiveFxCompile._stat_bg_finished += 1 + output = pickled_output.deserialize(constants) + self._optimized_compile._postprocess(output) + return output.graph + + return _ProgressiveOutputCode(fast_output_code, progression_futures, callback) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/compile_fx_ext.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/compile_fx_ext.py new file mode 100644 index 0000000000000000000000000000000000000000..24048ccdda12ca0bc7b0173abc7cde4b057051fb --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/compile_fx_ext.py @@ -0,0 +1,683 @@ +from __future__ import annotations + +import contextlib +import dataclasses +import functools +import logging +import os +import queue +import sys +import warnings +from abc import abstractmethod +from dataclasses import dataclass +from typing import Any, Optional, TYPE_CHECKING, TypeGuard, Union +from typing_extensions import final, override, Self + +import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools +import torch.fx +from torch._inductor.codecache import BypassFxGraphCache, FxGraphCache +from torch._inductor.metrics import CachedMetricsDeltas, CachedMetricsHelper +from torch._inductor.output_code import ( + CompiledFxGraph, + CompiledFxGraphConstants, + CompiledFxGraphConstantsWithGm, + OutputCode, +) +from torch._subclasses import FakeTensorMode +from torch.utils._ordered_set import OrderedSet + +from . 
import config +from .compile_fx import _CompileFxKwargs, _InProcessFxCompile, FxCompile, log +from .debug import DebugContext +from .graph import GraphLowering +from .output_code import complex_memory_overlap # noqa: F401 +from .virtualized import V + + +if TYPE_CHECKING: + import types + from collections.abc import Generator, Mapping, Sequence + from concurrent.futures import Future + + from torch._inductor.utils import InputType + from torch.fx import GraphModule + + +@dataclass +class _VirtualizedSerializer: + """ + This handles the data for serializing Virtualized. + """ + + # The values here get serialized. We don't grab everything because some of + # the fields can't be serialized. + aot_compilation: Any = None + choices: Any = None + local_buffer_context: Any = None + ops: Any = None + kernel: Any = None + current_node: Any = None + + @classmethod + def serialize(cls) -> _VirtualizedSerializer: + """ + Turn the current state of torch._inductor.virtualized.V into a + serializable structure. + """ + kwargs = {} + for f in dataclasses.fields(cls): + kwargs[f.name] = getattr(V, f.name) + return _VirtualizedSerializer(**kwargs) + + def patch(self) -> _VirtualizedSerializerContextManager: + """ + Returns a context manager which patches the saved values into the + current environment. While patched, any value not listed above will be + poisoned so that reads will raise an error. 
+ """ + return _VirtualizedSerializerContextManager(self) + + +class _VirtualizedSerializerContextManager(contextlib.ExitStack): + """ + Helper for _VirtualizedSerializer.patch() + """ + + def __init__(self, virtualized: _VirtualizedSerializer) -> None: + super().__init__() + self.virtualized = virtualized + + @override + def __enter__(self) -> Self: + super().__enter__() + + for set_name in dir(V): + if not set_name.startswith("set_"): + continue + name = set_name[4:] + name = name.removesuffix("_handler") + set_handler = getattr(V, set_name) + if hasattr(self.virtualized, name): + value = getattr(self.virtualized, name) + else: + # poison any values that we don't serialize so that any + # unset accesses are caught. + value = torch._inductor.virtualized._PoisonedVirtual + self.enter_context(set_handler(value)) + + return self + + +def _is_fallback_handler(op: object) -> bool: + try: + return op._is_fallback_handler # type: ignore[attr-defined] + except AttributeError: + return False + + +class _LoweringSerializer: + """ + This handles the data for serializing lowering.lowering + """ + + # A full implementation would make sure that all lowerings are copied over + # (or at least detected and raise a bypass when a non-standard lowering is + # used). For now we just handle tests by looking for lowerings that were + # overridden with a forced fallback. + fallbacks: OrderedSet[str] + + def __init__(self) -> None: + from . import lowering + + self.fallbacks = OrderedSet( + str(k) for k, v in lowering.lowerings.items() if _is_fallback_handler(v) + ) + + def patch(self) -> _LoweringSerializerContextManager: + return _LoweringSerializerContextManager(self) + + +class _LoweringSerializerContextManager(contextlib.ExitStack): + """ + Helper for _LoweringSerializer.patch() + """ + + def __init__(self, lowering: _LoweringSerializer) -> None: + super().__init__() + self.lowering = lowering + + @override + def __enter__(self) -> Self: + super().__enter__() + + from . 
import lowering + + for k, v in lowering.lowerings.items(): + name = str(k) + if name in self.lowering.fallbacks: + if not _is_fallback_handler(v): + self.enter_context(lowering.force_fallback(k)) # type: ignore[arg-type] + + return self + + +@dataclass +class _FakeTensorModeSerializer: + allow_non_fake_inputs: bool + + def __init__(self, fake_mode: FakeTensorMode) -> None: + self.allow_non_fake_inputs = fake_mode.allow_non_fake_inputs + self.shape_env = fake_mode.shape_env + + @contextlib.contextmanager + def patch(self, fake_mode: FakeTensorMode) -> Generator[None, None, None]: + saved_allow_non_fake_inputs = fake_mode.allow_non_fake_inputs + fake_mode.allow_non_fake_inputs = self.allow_non_fake_inputs + + yield + + fake_mode.allow_non_fake_inputs = saved_allow_non_fake_inputs + + +@dataclass +class _WireProtocolInput: + """ + For _SerializedFxCompile - encapsulates all the data being transferred + (sent) from the parent to the child. + """ + + gm: torch.fx.GraphModule + example_inputs: Sequence[InputType] + inputs_to_check: Sequence[int] + graph_kwargs: _CompileFxKwargs + tracing_context: Optional[torch._guards.TracingContext] + config: dict[str, object] + virtualized: _VirtualizedSerializer + deterministic_guard_for_testing: Optional[ # type: ignore[name-defined] # mypy bug + torch.testing._internal.common_utils.DeterministicGuard + ] + logger_state: _LoggerState + lowering: _LoweringSerializer + fake_tensor_mode: _FakeTensorModeSerializer + + def serialize(self) -> _WireProtocolPickledInput: + """ + Turns this object into a _WireProtocolPickledInput which can be + directly transferred across a stream. 
+ """ + from torch.fx._graph_pickler import GraphPickler + + return _WireProtocolPickledInput(GraphPickler.dumps(self)) + + +def _current_fake_mode() -> FakeTensorMode: + fake_mode = None + if context := torch._guards.TracingContext.try_get(): + fake_mode = context.fake_mode + if fake_mode is not None: + return fake_mode + + shape_env = torch.fx.experimental.symbolic_shapes.ShapeEnv() + return FakeTensorMode(shape_env=shape_env) + + +@dataclass +class _WireProtocolPickledInput: + value: bytes + + def deserialize(self) -> _WireProtocolInput: + """ + Turn this streamable object back into a _WireProtocolInput. + """ + from torch.fx._graph_pickler import GraphPickler + + fake_mode = _current_fake_mode() + result = GraphPickler.loads(self.value, fake_mode) + assert isinstance(result, _WireProtocolInput) + return result + + +@dataclass +class _WireProtocolOutput: + """ + For _SerializedFxCompile - encapsulates all the data being transferred + (returned) back from the child to the parent. + """ + + graph: OutputCode + metrics: CachedMetricsDeltas + logs: list[logging.LogRecord] + warning_replay: Optional[list[warnings.WarningMessage]] + shape_env: Optional[torch.fx.experimental.symbolic_shapes.ShapeEnv] + + def serialize(self) -> _WireProtocolPickledOutput: + """ + Turns this object into a _WireProtocolPickledOutput which can be + directly transferred across a stream. + """ + from torch.fx._graph_pickler import GraphPickler + + if isinstance(self.graph, CompiledFxGraph): + self.graph.prepare_for_serialization() + return _WireProtocolPickledOutput(GraphPickler.dumps(self)) + + +@dataclass +class _WireProtocolPickledOutput: + value: bytes + + def deserialize(self, constants: CompiledFxGraphConstants) -> _WireProtocolOutput: + """ + Turn this streamable object back into a _WireProtocolOutput. 
+ """ + from torch.fx._graph_pickler import GraphPickler + + fake_mode = _current_fake_mode() + result = GraphPickler.loads(self.value, fake_mode) + assert isinstance(result, _WireProtocolOutput) + if isinstance(result.graph, CompiledFxGraph): + result.graph.after_deserialization(constants) + return result + + +class _LoggerState: + """ + This class is for tracking logging that happens during an out-of-process + compile so we can "replay" those messages when the compile is done. Used as + a context manager which returns the captured logs (object). + """ + + loggers: dict[str, int] + # The actual log capturing mechanism - this should be None when we're not + # actively capturing logs. + captured_logs: Optional[_CapturedLogs] = None + + def __init__(self) -> None: + # Mapping from logger name to level. + self.loggers = {} + + def filter( + logger: Union[logging.Logger, logging.PlaceHolder], + ) -> TypeGuard[logging.Logger]: + if not isinstance(logger, logging.Logger): + # Assume that Placeholders propagate + return False + # We only want to track torch._inductor logging + if not logger.name.startswith("torch._inductor"): + return False + # If this logger propagates then assume we'll track its parent + if logger.propagate: + return False + return True + + root = logging.getLogger("torch._inductor") + if sys.version_info < (3, 12): + # logging.getChildren() doesn't exist until 3.12 + logging._acquireLock() # type: ignore[attr-defined] + try: + for logger in root.manager.loggerDict.values(): + if filter(logger): + self.loggers[logger.name] = logger.level + finally: + logging._releaseLock() # type: ignore[attr-defined] + else: + q = [root] + while q: + logger = q.pop() + if filter(logger): + self.loggers[logger.name] = logger.level + q.extend(logger.getChildren()) + + def __enter__(self) -> _CapturedLogs: + assert self.captured_logs is None + self.captured_logs = _CapturedLogs(self) + self.captured_logs.apply() + return self.captured_logs + + def __exit__( + self, + 
exc_type: Optional[type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[types.TracebackType], + ) -> None: + assert self.captured_logs is not None + self.captured_logs.remove() + + +class _CapturedLogs: + """ + Helper for _LoggerState - this class actually attaches to the logger in + the child process and grabs the log messages themselves. + """ + + state: _LoggerState + queue: queue.Queue[logging.LogRecord] + handlers: Optional[dict[str, logging.Handler]] + + def __init__(self, state: _LoggerState) -> None: + self.state = state + # A queue of the log entries + # TODO: For memory purposes should we log to a file and then respond with that? + self.queue = queue.Queue(-1) + # Mapping from name to handler (only valid when applied) + self.handlers = None + + def finish(self) -> list[logging.LogRecord]: + assert self.handlers is None + logs = [] + try: + while True: + logs.append(self.queue.get_nowait()) + except queue.Empty: + pass + return logs + + def remove(self) -> None: + assert self.handlers is not None + handlers, self.handlers = self.handlers, None + for name, handler in handlers.items(): + logger = logging.getLogger(name) + logger.removeHandler(handler) + + def apply(self) -> None: + from logging.handlers import QueueHandler + + assert self.handlers is None + self.handlers = {} + for name, level in self.state.loggers.items(): + logger = logging.getLogger(name) + handler = QueueHandler(self.queue) + self.handlers[name] = handler + logger.addHandler(handler) + if level != logging.NOTSET: + logger.setLevel(level) + + +class _SerializedFxCompile(FxCompile): + """ + This is used to represent an FxCompile which occurs across a serialized + boundary. 
+ """ + + @override + def codegen_and_compile( + self, + gm: GraphModule, + example_inputs: Sequence[InputType], + inputs_to_check: Sequence[int], + graph_kwargs: _CompileFxKwargs, + ) -> OutputCode: + # If this code changes it's likely _AsyncFxCompile.codegen_and_compile() + # will also need to match. + + serialized = self.serialize_compile( + gm, example_inputs, inputs_to_check, graph_kwargs + ) + if not serialized: + return _InProcessFxCompile().codegen_and_compile( + gm, example_inputs, inputs_to_check, graph_kwargs + ) + + inputs, constants = serialized + output = self._send_to_child(inputs).deserialize(constants) + + self._postprocess(output) + self._compile_stats[type(self)].codegen_and_compile += 1 + + # TODO: Do we need to figure out what changed in TracingContext in the + # child and plumb that back up to the parent? + + return output.graph + + def serialize_compile( + self, + gm: GraphModule, + example_inputs: Sequence[InputType], + inputs_to_check: Sequence[int], + graph_kwargs: _CompileFxKwargs, + ) -> Optional[tuple[_WireProtocolPickledInput, CompiledFxGraphConstantsWithGm]]: + """ + Prepare a _WireProtocolInput to compile. If None is returned then it + wasn't possible to serialize and we should fallback to in-process. 
+ """ + try: + # _check_for_hop raises BypassFxGraphCache when it detects something + # we can't cache (or serialize) + FxGraphCache._check_for_hop(gm) + except BypassFxGraphCache as e: + log.debug("Skipping %s compile: %s", type(self), e) # noqa: G200 + return None + + context = torch._guards.TracingContext.try_get() + constants = CompiledFxGraphConstantsWithGm(gm) + logger_state = _LoggerState() + lowering = _LoweringSerializer() + + # If we're running tests then grab the DeterministicGuard (don't want to + # import this if it isn't already imported because it has side-effects) + deterministic_guard_for_testing: Optional[ # type: ignore[name-defined] # mypy bug + torch.testing._internal.common_utils.DeterministicGuard + ] = None + try: + deterministic_guard_for_testing = ( + torch.testing._internal.common_utils.DeterministicGuard._current_state() # type: ignore[attr-defined] # mypy bug + ) + except AttributeError: + pass + + fake_mode = _current_fake_mode() + fake_tensor_mode = _FakeTensorModeSerializer(fake_mode) + + from pickle import PicklingError + + try: + input = _WireProtocolInput( + gm, + example_inputs, + inputs_to_check, + graph_kwargs, + context, + config.save_config_portable(), + _VirtualizedSerializer.serialize(), + deterministic_guard_for_testing, + logger_state, + lowering, + fake_tensor_mode, + ).serialize() + return (input, constants) + except (AttributeError, BypassFxGraphCache, PicklingError): + # For example: AttributeError: Can't pickle local object + # 'make_opaque_unary_fn..OpaqueUnaryFn' + + # TODO: scuba record about not being able to do this? + log.warning("Unable to pickle input graph or example inputs", exc_info=True) + + return None + + @abstractmethod + def _send_to_child( + self, pickled_input: _WireProtocolPickledInput + ) -> _WireProtocolPickledOutput: + # The implementation of this should transfer `input` to the child, call + # `_run_in_child(input)` and transfer the result back. + ... 
+ + def _postprocess(self, output: _WireProtocolOutput) -> None: + pass + + @classmethod + def _run_in_child( + cls, + pickled_input: _WireProtocolPickledInput, + extra_env: Optional[Mapping[str, str]] = None, + ) -> _WireProtocolPickledOutput: + metrics = CachedMetricsHelper() + + with contextlib.ExitStack() as stack: + if extra_env is not None: + import unittest + + stack.enter_context(unittest.mock.patch.dict("os.environ", extra_env)) + + # Save warnings to "replay" in the parent + warning_replay = stack.enter_context(warnings.catch_warnings(record=True)) + + # TODO: Should we split the input into multiple sections where each + # section sets up state for the previous section? (i.e. a Config section + # which we decode and apply, followed by a FakeTensorMode section which + # we decode and apply, etc) + input = pickled_input.deserialize() + + stack.enter_context(input.virtualized.patch()) + stack.enter_context(input.lowering.patch()) + stack.enter_context(config.patch(input.config)) + captured_logs = stack.enter_context(input.logger_state) + if input.deterministic_guard_for_testing: + stack.enter_context(input.deterministic_guard_for_testing) + stack.enter_context(torch._guards.tracing(input.tracing_context)) + stack.enter_context(DebugContext()) + + fake_mode = _current_fake_mode() + stack.enter_context(input.fake_tensor_mode.patch(fake_mode)) + + output_graph = _InProcessFxCompile().codegen_and_compile( + input.gm, + input.example_inputs, + input.inputs_to_check, + input.graph_kwargs, + ) + + logs = captured_logs.finish() + + return _WireProtocolOutput( + output_graph, + metrics.get_deltas(), + logs, + warning_replay, + fake_mode.shape_env, + ).serialize() + + +# This is a debugging/testing implementation of FxCompile which serializes the +# input and output but still runs the FxCompile in-process. 
+@final +class _DebugSerdeFxCompile(_SerializedFxCompile): + @override + def _send_to_child( + self, pickled_input: _WireProtocolPickledInput + ) -> _WireProtocolPickledOutput: + # For debugging just serde the input and output but don't run in a + # subprocess. + return self._run_in_child(pickled_input) + + +class _OutOfProcessFxCompile(_SerializedFxCompile): + """ + Represents an FxCompile which is run outside the current process (in + either a subprocess or possibly even a separate machine). + """ + + @override + @final + def _send_to_child( + self, pickled_input: _WireProtocolPickledInput + ) -> _WireProtocolPickledOutput: + f = self._send_to_child_async(pickled_input) + + # For debugging: If we want to print status updates... + # last = time.time() + # while not f.done(): + # print("tick...") + # time.sleep(0.125) + # now = time.time() + # if now - last > 1: + # last = now + + return f.result() + + @abstractmethod + def _send_to_child_async( + self, pickled_input: _WireProtocolPickledInput + ) -> Future[_WireProtocolPickledOutput]: ... + + def _postprocess(self, output: _WireProtocolOutput) -> None: + # Since our metrics were gathered in a subprocess make sure to add them + # here. + CachedMetricsHelper.apply_deltas(output.metrics) + + # This is used by tests to check the output for specific details. For + # remote things (subproc and RE) we need to do the `save_output_code` + # here since it didn't happen earlier in-process. In the future if this + # doesn't have "source_code" (it's a CompiledAOTI, for example) and we + # need it we'll have to grab it and serialize it separately from the + # child. + if GraphLowering.save_output_code is not None: + GraphLowering.save_output_code(output.graph.source_code) # type: ignore[attr-defined] + + # And forward our collected logs. The cache is cleared when the outer + # function exits. 
+ @functools.cache + def getLogger(name: str) -> logging.Logger: + return logging.getLogger(name) + + if output.warning_replay: + for w in output.warning_replay: + warnings.warn_explicit( + message=w.message, + category=w.category, + filename=w.filename, + lineno=w.lineno, + source=w.source, + ) + + for record in output.logs: + logger = getLogger(record.name) + logger.handle(record) + + +# For debugging - create a _FxCompile which writes the serialized data to a file +# and then exits. +# +# TODO: make this a FxCompileMode value? +# +# The "child runner" should look something like this: +# +# import torch +# from torch._inductor import compile_fx +# idx = 0 +# with open(f"/tmp/pytorch_compile_fx_tmp_input_{idx}.bin", "rb") as f: +# input = compile_fx._WireProtocolPickledInput(f.read()) +# result = compile_fx._SubprocessFxCompile._run_in_child(input) +# with open(f"/tmp/pytorch_compile_fx_tmp_output_{idx}.bin", "wb") as f: +# f.write(result.value) +# +@final +class _DebugFileFxCompile(_SerializedFxCompile): + file_index = 0 + + @override + def _send_to_child( + self, pickled_input: _WireProtocolPickledInput + ) -> _WireProtocolPickledOutput: + idx = _DebugFileFxCompile.file_index + _DebugFileFxCompile.file_index += 1 + + name = f"/tmp/aorenste/pytorch_compile_fx_tmp_input_{idx}.bin" + with open(name, "wb") as f: + f.write(pickled_input.value) + print(f"Wrote to {name}") + + if False: + name = f"/tmp/aorenste/pytorch_compile_fx_tmp_actual_{idx}.bin" + actual = self._run_in_child(pickled_input) + with open(name, "wb") as f: + f.write(actual.value) + return actual + elif False: + name = f"/tmp/aorenste/pytorch_compile_fx_tmp_output_{idx}.bin" + with open(name, "rb") as f: + result = _WireProtocolPickledOutput(f.read()) + print(f"Read from {name}") + return result + else: + os._exit(-1) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/compile_fx_subproc.py 
b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/compile_fx_subproc.py new file mode 100644 index 0000000000000000000000000000000000000000..58d5195046fd1500e18df1435c0db12c9cddfaec --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/compile_fx_subproc.py @@ -0,0 +1,93 @@ +from __future__ import annotations + +import atexit +import functools +import os +from typing import Optional, TYPE_CHECKING +from typing_extensions import final, override + +import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools +import torch.fx +from torch._inductor.compile_worker.subproc_pool import ( + AnyPool, + SubprocKind, + SubprocPool, +) +from torch._inductor.utils import clear_caches + +from .compile_fx_ext import ( + _OutOfProcessFxCompile, + _WireProtocolPickledInput, + _WireProtocolPickledOutput, +) +from .output_code import complex_memory_overlap # noqa: F401 + + +if TYPE_CHECKING: + from collections.abc import Mapping + from concurrent.futures import Future + + +@final +class _SubprocessFxCompile(_OutOfProcessFxCompile): + @override + def _send_to_child_async( + self, input: _WireProtocolPickledInput + ) -> Future[_WireProtocolPickledOutput]: + # TODO: Do we need to copy across some kind of logging IDs? (ChromiumEventLogger) + + pool = self.process_pool() + + # TODO: This is probably the wrong thing to do long-term - but for now + # let's share the cache so we can identify tests broken by this later. + env_vars = ["TORCHINDUCTOR_CACHE_DIR", "TRITON_CACHE_DIR"] + extra_env = {v: os.environ[v] for v in env_vars if v in os.environ} + + return pool.submit( + _SubprocessFxCompile._run_in_child_subprocess, input, extra_env + ) + + @staticmethod + @functools.cache + def process_pool() -> AnyPool: + pool = SubprocPool( + # TODO: Consider raising this limit if we start using async w/ + # subprocess and want to compile multiple graphs in parallel. 
+ 1, + kind=SubprocKind.SPAWN, + ) + + atexit.register(pool.shutdown) + + return pool + + @classmethod + def _run_in_child_subprocess( + cls, + pickled_input: _WireProtocolPickledInput, + extra_env: Optional[Mapping[str, str]], + ) -> _WireProtocolPickledOutput: + # TODO: In subprocess mode we need to clear the inductor caches. + # The problem: + # 1. We compile in worker A which fills stuff in tmpdir + # 2. parent clears inductor caches which deletes tmpdirs and tells + # cpp_prefix_path() to clear its LRU cache + # 3. We compile a second time in subproc A - but since we never told + # cpp_prefix_path() in worker A to clear its LRU it thinks the + # tmpdir still exists and fails to compile. + # + # TODO: We probably should be using a separate tmpdir in the worker + # anyway... but we should probably still respect clear_caches() + # in the parent... maybe? + # + # TODO: We could be less aggressive by keeping a clock which gets + # incremented when we clear the cache, send the clock to the worker and + # only clear caches if the clock changed since last time. 
+ # + clear_caches() + torch._inductor.metrics.reset() + + # TODO: turn off config.fx_graph_async_compile + + result = cls._run_in_child(pickled_input, extra_env) + return result diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/compiler_bisector.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/compiler_bisector.py new file mode 100644 index 0000000000000000000000000000000000000000..c27717bb54ec37ed4ae9951dd512d4c0607bba4d --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/compiler_bisector.py @@ -0,0 +1,644 @@ +import atexit +import collections +import dataclasses +import functools +import os +import shutil +import sys +import tempfile +from collections.abc import Callable +from dataclasses import dataclass, field +from typing import Optional + +from torch._inductor.runtime.cache_dir_utils import cache_dir + + +# Set the subdirectory name +SUBDIR_NAME = "bisect" + + +@dataclass +class Subsystem: + name: str + + +@dataclass +class BisectSubsystem(Subsystem): + pass + + +@dataclass +class BinarySubsystem(Subsystem): + pass + + +@dataclass +class ConfigChange(BinarySubsystem): + name: str = field(init=False) + config_name: str + config_field: str + config_value: object + + def __post_init__(self) -> None: + self.name = f"{self.config_name}_{self.config_field}" + + +# Dictionary of backend -> subsystems +BACKENDS: dict[str, list[Subsystem]] = { + # run dynamo without aot_autograd + "eager": [], + # run dynamo with aot_autograd, but no partitioner or decomps + "aot_eager": [], + # run dynamo with aot autograd, decompositions and partitioner + "aot_eager_decomp_partition": [ + ConfigChange("aot_eager_decomp_partition", "cse", False), + BisectSubsystem( + "decomposition" + ), # number of decompositions we apply in tracing + ], # TODO - add cse ? 
+ # applies CrossRefFakeMode on invocation + "aot_eager_decomp_partition_crossref": [], + "inductor": [ + BisectSubsystem("pre_grad_passes"), # passes applied on pre-grad IR + BisectSubsystem("joint_graph_passes"), # passes applied on joint graph + BisectSubsystem( + "post_grad_passes" + ), # passes applied individually on forward, and backward in inductor + ConfigChange("inductor", "fallback_random", True), + ConfigChange("inductor", "emulate_precision_casts", True), + ConfigChange("inductor", "layout_optimization", False), + ConfigChange("inductor", "comprehensive_padding", False), + BisectSubsystem("lowerings"), # lowering aten operators to inductor + ], # TODO - add more - fusions ? +} + +subsystem_call_counter: dict[str, int] = collections.Counter() +call_counter_debug_info: dict[int, str] = {} + + +def reset_counters() -> None: + subsystem_call_counter.clear() + call_counter_debug_info.clear() + + +@functools.cache +def get_env_val(env_str: str) -> Optional[str]: + return os.environ.get(env_str, None) + + +@dataclasses.dataclass +class BisectionResult: + """ + backend: torch.compile backend responsible for failure + subsystem: optional, registered component identified for failure + bisect_number: optional, number of times the subsystem needed to be applied to trigger failure + debug_info: associated info of the triggering bisect application of subsystem + """ + + backend: str + subsystem: Optional[str] = None + bisect_number: Optional[int] = None + debug_info: Optional[str] = None + + +class CompilerBisector: + """ + This class iteratively runs torch.compile backends (eager, aot_eager, inductor) to find the + first backend that can repro an issue. + + Once it discovers the offending backend it will iteratively disable subsystems within the backend. + For subsystems which are applied repeatedly, such as the number of post grad passes or number + of lowering of nodes to inductor ir, it will bisect to find the offending application. 
+ + The idiomatic way to run it is with `do_bisect`. You can also use it by setting the env flags + `TORCH_BISECT_BACKEND`, `TORCH_BISECT_SUBSYSTEM` and `TORCH_BISECT_MAX`. + + It also supports a CLI interface, although this is less well tested. + + You must run python compiler_bisector.py [start | good | bad | end] + """ + + bisection_enabled: bool = False + + in_process_cache: Optional[str] = None + + @classmethod + def get_dir(cls) -> str: + return f"{cache_dir() if not cls.in_process_cache else cls.in_process_cache}/{SUBDIR_NAME}" + + @classmethod + def write_lines_to_file(cls, file_path: str, lines: list[str]) -> None: + os.makedirs(os.path.dirname(file_path), exist_ok=True) + with open(file_path, "w") as file: + file.writelines(lines) + + @classmethod + def read_lines_from_file(cls, file_path: str) -> list[str]: + if os.path.exists(file_path): + with open(file_path) as file: + return file.readlines() + return [] + + @classmethod + def update_run_state( + cls, backend_name: str, subsystem: Subsystem, run_state: str + ) -> None: + file_path = os.path.join( + cls.get_dir(), backend_name, f"{subsystem.name}_run_state.txt" + ) + if isinstance(subsystem, ConfigChange): + assert run_state == "test_disable" + cls.set_config_values( + backend_name, + subsystem.name, + {subsystem.config_field: subsystem.config_value}, + ) + + cls.write_lines_to_file(file_path, [run_state]) + + @classmethod + def set_config_values( + cls, backend: str, subsystem: str, config_data: dict[str, object] + ) -> None: + file_path = os.path.join(cls.get_dir(), backend, f"{subsystem}_config.txt") + lines = [f"{k}={v}\n" for k, v in config_data.items()] + cls.write_lines_to_file(file_path, lines) + + @classmethod + def update_bisect_status(cls, backend_name: str, subsystem_name: str) -> None: + assert isinstance(subsystem_name, str) + file_path = os.path.join(cls.get_dir(), "bisect_status.txt") + lines = [f"backend={backend_name}\n", f"subsystem={subsystem_name}\n"] + 
cls.write_lines_to_file(file_path, lines) + + @classmethod + def update_bisect_range( + cls, backend_name: str, subsystem_name: str, low: int, high: int + ) -> None: + assert isinstance(subsystem_name, str) + file_path = os.path.join( + cls.get_dir(), backend_name, f"{subsystem_name}_bisect_range.txt" + ) + lines = [f"low={low}\n", f"high={high}\n"] + cls.write_lines_to_file(file_path, lines) + + @classmethod + def get_backend(cls) -> Optional[str]: + """ + Returns the active backend, if any + """ + if val := get_env_val("TORCH_BISECT_BACKEND"): + return val + + file_path = os.path.join(cls.get_dir(), "bisect_status.txt") + lines = cls.read_lines_from_file(file_path) + for line in lines: + if line.startswith("backend="): + return line.strip().split("=")[1] + return None + + @classmethod + def get_subsystem(cls) -> Optional[str]: + """ + Returns the active subsystem, if any + """ + + if val := get_env_val("TORCH_BISECT_SUBSYSTEM"): + return val + + file_path = os.path.join(cls.get_dir(), "bisect_status.txt") + lines = cls.read_lines_from_file(file_path) + for line in lines: + if line.startswith("subsystem="): + out = line.strip().split("=")[1] + return out if out else None + return None + + @classmethod + def get_subsystem_object(cls, backend_name: str, subsystem_name: str) -> Subsystem: + return next(obj for obj in BACKENDS[backend_name] if obj.name == subsystem_name) + + @classmethod + def get_run_state(cls, backend_name: str, subsystem_name: str) -> Optional[str]: + """ + Returns the current stage of bisecting, if Any + """ + + file_path = os.path.join( + cls.get_dir(), backend_name, f"{subsystem_name}_run_state.txt" + ) + lines = cls.read_lines_from_file(file_path) + if lines: + out = lines[0].strip() + assert out in ("test_disable", "find_max_bounds", "bisect") + return out + return None + + @classmethod + def get_bisect_range( + cls, backend_name: str, subsystem_name: str + ) -> tuple[int, int]: + file_path = os.path.join( + cls.get_dir(), backend_name, 
f"{subsystem_name}_bisect_range.txt" + ) + lines = cls.read_lines_from_file(file_path) + low = None + high = None + # pyrefly: ignore [bad-assignment] + for line in reversed(lines): + if line.startswith("low="): + low = int(line.strip().split("=")[1]) + elif line.startswith("high="): + high = int(line.strip().split("=")[1]) + + if low is not None and high is not None: + break + + if low is None or high is None: + raise RuntimeError( + f"Trying to get bisect range when it is not set: subsystem {subsystem_name}" + ) + + return low, high + + @classmethod + def update_config_change(cls, backend: str, subsystem: ConfigChange) -> None: + file_path = os.path.join(cls.get_dir(), backend, f"{subsystem.name}_config.txt") + lines = [ + f"config_name={subsystem.config_name}\n", + f"config_field={subsystem.config_field}\n", + f"config_value={subsystem.config_value}\n", + ] + cls.write_lines_to_file(file_path, lines) + + @classmethod + def get_config_change(cls, config_name: str) -> Optional[dict[str, object]]: + backend = cls.get_backend() + subsystem = cls.get_subsystem() + + if not backend or not subsystem: + return None + + file_path = os.path.join(cls.get_dir(), backend, f"{subsystem}_config.txt") + + if not os.path.exists(file_path): + return None + + lines = cls.read_lines_from_file(file_path) + config_data = {} + for line in lines: + key, value = line.strip().split("=", 1) + config_data[key] = eval(value) + + return config_data + + @classmethod + def delete_bisect_status(cls) -> None: + # in process_cache we have created if it exists, just the subdirectory of non created dir + dir_name = cls.in_process_cache if cls.in_process_cache else cls.get_dir() + if os.path.exists(dir_name): + shutil.rmtree(dir_name) + print("Bisection status deleted.") + else: + print("No bisection status found.") + + @classmethod + def get_system_counter(cls, name: str, increment: bool = True) -> int: + global subsystem_call_counter + curr = subsystem_call_counter[name] + if increment: + 
subsystem_call_counter[name] += 1 + return curr + + @classmethod + def disable_subsystem( + cls, + backend: str, + subsystem: str, + debug_info: Optional[Callable[[], str]] = None, + ) -> bool: + if not cls.bisection_enabled: + return False + + if cls.get_backend() != backend: + return False + + if cls.get_subsystem() != subsystem: + return False + + if val := get_env_val("TORCH_BISECT_MAX"): + counter = cls.get_system_counter(subsystem, increment=True) + return counter > int(val) + + run_state = cls.get_run_state(backend, subsystem) + if run_state == "test_disable": + # First run, disable completely + return True + elif run_state == "find_max_bounds": + # Second run, update bisection range and return True to enable the subsystem + cls.update_bisect_range( + backend, + subsystem, + 0, + cls.get_system_counter(subsystem, increment=True), + ) + return False + else: + assert run_state == "bisect" + # If the environment variable is not set, use the bisection range midpoint + low, high = cls.get_bisect_range(backend, subsystem) + # if high - low <= 2: + midpoint = (low + high) // 2 + call_counter = cls.get_system_counter(subsystem) + + if ( + call_counter >= low + and call_counter <= high + and (low - high) <= 2 + and debug_info is not None + ): + call_counter_debug_info[call_counter] = debug_info() + + return call_counter > midpoint + + @classmethod + def advance_subsystem( + cls, curr_backend: str, curr_subsystem: Subsystem + ) -> Optional[Subsystem]: + """ + Tries to move to the next subsystem within the current system. 
+ """ + print(f"Disabling {curr_subsystem.name} did not fix the issue.") + + current_subsystems = BACKENDS[curr_backend] + current_subsystem_index = next( + i + for i, subsystem in enumerate(current_subsystems) + if subsystem.name == curr_subsystem.name + ) + + if current_subsystem_index < len(current_subsystems) - 1: + next_subsystem = current_subsystems[current_subsystem_index + 1] + cls.update_bisect_status(curr_backend, next_subsystem.name) + cls.update_run_state(curr_backend, next_subsystem, "test_disable") + print( + f"Moving to the next subsystem: {curr_backend} - {next_subsystem.name}" + ) + return next_subsystem + else: + print( + f"All subsystems in {curr_backend} have been checked. The issue is not in this system." + ) + return None + + @classmethod + def advance_backend(cls, curr_backend: str) -> Optional[str]: + """ + Tries Move to the next backend. + """ + current_system_index = list(BACKENDS.keys()).index(curr_backend) + + if current_system_index < len(BACKENDS) - 1: + curr_backend = list(BACKENDS.keys())[current_system_index + 1] + cls.update_bisect_status(curr_backend, "") + print(f"Moving to the next system: {curr_backend}") + return curr_backend + else: + return None + + @classmethod + def process_subsystem( + cls, + curr_backend: str, + curr_subsystem: Subsystem, + fn: Callable[[], bool], + cli_interface: bool = True, + ) -> bool: + """ + Process the current subsystem. Returns True if the issue is found, False otherwise. 
+ """ + assert isinstance(curr_subsystem, Subsystem) + while True: + run_state = cls.get_run_state(curr_backend, curr_subsystem.name) + reset_counters() + if run_state == "test_disable": + if not fn(): + next_subsystem = cls.advance_subsystem(curr_backend, curr_subsystem) + if not next_subsystem: + return False + curr_subsystem = next_subsystem + else: + if isinstance(curr_subsystem, ConfigChange): + print( + f"Setting config {curr_subsystem.config_name} field {curr_subsystem.config_field} " + f"to {curr_subsystem.config_value} fixed the issue" + ) + else: + print(f"Disabling {curr_subsystem.name} fixed the issue.") + if isinstance(curr_subsystem, BinarySubsystem): + return True + print("Starting bisect by getting upper bound.") + cls.update_run_state( + curr_backend, curr_subsystem, "find_max_bounds" + ) + elif run_state == "find_max_bounds": + if fn(): + raise RuntimeError( + f"Function succeeded with 'find_max_bounds' status for {curr_backend} - {curr_subsystem.name}." + ) + else: + _, high = cls.get_bisect_range(curr_backend, curr_subsystem.name) + print(f"Upper bound of {high} found for {curr_backend}.") + cls.update_run_state(curr_backend, curr_subsystem, "bisect") + elif run_state == "bisect": + low, high = cls.get_bisect_range(curr_backend, curr_subsystem.name) + midpoint = (low + high) // 2 + print( + f"Bisecting {curr_backend} - {curr_subsystem.name} (Range: [{low}, {high}], Midpoint: {midpoint})" + ) + if fn(): + cls.update_bisect_range( + curr_backend, curr_subsystem.name, midpoint + 1, high + ) + else: + cls.update_bisect_range( + curr_backend, curr_subsystem.name, low, midpoint + ) + low, high = cls.get_bisect_range(curr_backend, curr_subsystem.name) + if low == high: + print( + f"Binary search completed for {curr_backend} - {curr_subsystem.name}. The bisect number is {low}. 
" + f"Debug info: {call_counter_debug_info.get(low, 'not found')}" + ) + return True + else: + raise RuntimeError(f"Unexpected run_state {run_state}") + + if cli_interface: + sys.exit(0) + + @classmethod + def initialize_system(cls) -> None: + curr_backend = next(iter(BACKENDS.keys())) + curr_subsystem = "" + cls.update_bisect_status(curr_backend, curr_subsystem) + print(f"Starting bisection process with system: {curr_backend}") + + @classmethod + def do_bisect( + cls, fn: Callable[[], bool], cli_interface: bool = False + ) -> Optional[BisectionResult]: + """ + Run fn repeatedly attempting to bisect torch.compile. fn should return True on success and False on failure. + """ + + # TODO graph bisecting is not well composed with lowering + # bisector so far. Use a config to opt-in + import torch._inductor.config as inductor_config + + if inductor_config.test_configs.bisect_pre_grad_graph: + BACKENDS["inductor"].insert(0, BisectSubsystem("pre_grad_graph")) + + if not cli_interface: + bisection_enabled_orig = cls.bisection_enabled + cls.delete_bisect_status() + cls.bisection_enabled = True + cls.in_process_cache = tempfile.mkdtemp() + + def cleanup() -> None: + cls.bisection_enabled = bisection_enabled_orig + cls.delete_bisect_status() + cls.in_process_cache = None + + if BACKENDS["inductor"][0].name == "pre_grad_graph": + del BACKENDS["inductor"][0] + + cleanup_handler = atexit.register(cleanup) + + class DisableBisect: + def __del__(self) -> None: + cleanup() + atexit.unregister(cleanup_handler) + + _cleanup = DisableBisect() + + curr_backend = cls.get_backend() + curr_subsystem_name = cls.get_subsystem() + + if not curr_backend: + cls.initialize_system() + curr_backend = cls.get_backend() + assert curr_backend is not None + curr_subsystem_name = cls.get_subsystem() + + curr_subsystem = ( + cls.get_subsystem_object(curr_backend, curr_subsystem_name) + if curr_subsystem_name is not None + else None + ) + while True: + assert curr_backend is not None + reset_counters() 
+ if curr_subsystem: + result = cls.process_subsystem( + curr_backend, curr_subsystem, fn, cli_interface=cli_interface + ) + if result: + curr_subsystem = cls.get_subsystem_object( + curr_backend, + cls.get_subsystem(), # type: ignore[arg-type] + ) + + if isinstance(curr_subsystem, BinarySubsystem): + return BisectionResult( + curr_backend, + curr_subsystem.name, + 0, + curr_subsystem.name, + ) + + low, _ = cls.get_bisect_range(curr_backend, curr_subsystem.name) + return BisectionResult( + curr_backend, + curr_subsystem.name, + low, + call_counter_debug_info.get(low), + ) + + next_subsystem = cls.advance_subsystem(curr_backend, curr_subsystem) + if not next_subsystem: + print( + f"The issue is in the {curr_backend} system, but could not identify subsystem." + ) + assert curr_backend is not None + return BisectionResult(curr_backend) + + curr_subsystem = next_subsystem + else: + if fn(): + next_backend = cls.advance_backend(curr_backend) + if not next_backend: + print("All systems have been checked.") + return None + + curr_backend = next_backend + else: + current_subsystems = BACKENDS[curr_backend] + if current_subsystems: + curr_subsystem = current_subsystems[0] + cls.update_bisect_status(curr_backend, curr_subsystem.name) + cls.update_run_state( + curr_backend, curr_subsystem, "test_disable" + ) + print( + f"The issue is in the {curr_backend} system. 
Moving to the first subsystem: {curr_subsystem}" + ) + else: + print(f"The issue is in the {curr_backend} system.") + return BisectionResult(curr_backend) + + if cli_interface: + sys.exit(0) + + +def command_line_usage() -> None: + if len(sys.argv) < 2: + print("Usage: python bisect_update.py ") + sys.exit(1) + + bisection_manager = CompilerBisector() + command = sys.argv[1] + + if command == "end": + bisection_manager.delete_bisect_status() + sys.exit(0) + + if command == "start": + bisection_manager.delete_bisect_status() + bisection_manager.initialize_system() + sys.exit(0) + + if command not in ["good", "bad"]: + print("Invalid command. Must be 'good', 'bad', 'start', or 'end'.") + sys.exit(1) + + def test_function() -> bool: + return command == "good" + + if not bisection_manager.get_backend(): + raise ValueError("Must call start prior to good or bad") + + bisection_manager.do_bisect(test_function, cli_interface=True) + + +def get_is_bisection_enabled() -> bool: + return ( + CompilerBisector.get_subsystem() is not None + or CompilerBisector.get_backend() is not None + ) + + +CompilerBisector.bisection_enabled = get_is_bisection_enabled() + +if __name__ == "__main__": + command_line_usage() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/config.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/config.py new file mode 100644 index 0000000000000000000000000000000000000000..6d2ec9fffee0e891c4cf0132416f303c20e55879 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/config.py @@ -0,0 +1,2290 @@ +import os +import sys +from collections.abc import Callable +from typing import Any, Literal, Optional, TYPE_CHECKING, Union + +import torch +import torch._inductor.custom_graph_pass +from torch._environment import is_fbcode +from torch.utils._config_module import Config, get_tristate_env, install_config_module + + +if TYPE_CHECKING: + from 
torch._inductor.choices import InductorChoices + +inplace_padding = os.environ.get("TORCHINDUCTOR_INPLACE_PADDING", "1") == "1" +can_inplace_pad_graph_input = False # ease testing + + +def fx_graph_remote_cache_default() -> Optional[bool]: + return get_tristate_env("TORCHINDUCTOR_FX_GRAPH_REMOTE_CACHE") + + +def vec_isa_ok_default() -> Optional[bool]: + if os.environ.get("TORCHINDUCTOR_VEC_ISA_OK") == "1": + return True + if os.environ.get("TORCHINDUCTOR_VEC_ISA_OK") == "0": + return False + return None + + +def autotune_remote_cache_default() -> Optional[bool]: + return get_tristate_env("TORCHINDUCTOR_AUTOTUNE_REMOTE_CACHE") + + +def bundled_autotune_remote_cache_default() -> Optional[bool]: + return get_tristate_env("TORCHINDUCTOR_BUNDLED_AUTOTUNE_REMOTE_CACHE") + + +def bundle_triton_into_fx_graph_cache_default() -> Optional[bool]: + return get_tristate_env( + "TORCHINDUCTOR_BUNDLE_TRITON_INTO_FX_GRAPH_CACHE", + True if not is_fbcode() else None, + ) + + +def static_cuda_launcher_default() -> bool: + STATIC_CUDA_LAUNCHER_VERSION = 2 + + if "TORCHINDUCTOR_USE_STATIC_CUDA_LAUNCHER" in os.environ: + return os.environ.get("TORCHINDUCTOR_USE_STATIC_CUDA_LAUNCHER") == "1" + elif is_fbcode(): + version = torch._utils_internal.justknobs_getval_int( + "pytorch/inductor:static_cuda_launcher_version" + ) + return version <= STATIC_CUDA_LAUNCHER_VERSION + else: + # Default true in OSS + return True + + +def prologue_fusion_enabled() -> bool: + ENABLE_PROLOGUE_FUSION_VERSION = 0 + + if "TORCHINDUCTOR_PROLOGUE_FUSION" in os.environ: + return os.environ.get("TORCHINDUCTOR_PROLOGUE_FUSION") == "1" + elif is_fbcode(): + jk_name = "pytorch/inductor:prologue_fusion_version" + version = torch._utils_internal.justknobs_getval_int(jk_name) + return version <= ENABLE_PROLOGUE_FUSION_VERSION + else: + return True + + +# Enable auto_functionalized_v2 (enabled by default) +enable_auto_functionalized_v2 = ( + os.environ.get("TORCHDYNAMO_AUTO_FUNCTIONALIZED_V2", "1") == "1" +) + +# add 
some debug printouts +debug = False + +# Whether to disable a progress bar for autotuning +disable_progress = True + +# Whether to enable printing the source code for each future +verbose_progress = False + +# Configurable compile worker logging path for subproc_pool +worker_log_path = ( + "/logs/dedicated_log_torch_compile_worker_rank" if is_fbcode() else None +) + +# precompilation timeout +precompilation_timeout_seconds: int = 60 * 60 + +# use fx aot graph codegen cache +fx_graph_cache: bool = Config( + justknob="pytorch/remote_cache:enable_local_fx_graph_cache", + env_name_default="TORCHINDUCTOR_FX_GRAPH_CACHE_DEFAULT", + env_name_force="TORCHINDUCTOR_FX_GRAPH_CACHE", + default=True, +) + +remote_gemm_autotune_cache: bool = False + +# use remote fx aot graph codegen cache +# False: Disables the cache +# True: Enables the cache +# None: Not set -- Off for OSS, JustKnobs based for internal +fx_graph_remote_cache: Optional[bool] = fx_graph_remote_cache_default() + +# should we bundle triton caching into fx graph cache +bundle_triton_into_fx_graph_cache: Optional[bool] = ( + bundle_triton_into_fx_graph_cache_default() +) + +non_blocking_remote_cache_write: bool = Config( + justknob="pytorch/remote_cache:enable_non_blocking_remote_cache_write_v2", + env_name_force="TORCHINDUCTOR_NON_BLOCKING_REMOTE_CACHE_WRITE", + default=True, +) + +# Enable autotune local cache. +# +# See bundled_autotune_remote_cache for the effect this flag has on the bundled +# remote cache. +autotune_local_cache: bool = True + +# Enable autotune remote cache. +# +# Enables/disables the autotune remote cache regardless of the state of +# autotune_local_cache. If both local and remote are enabled then on write both +# are written and on read local is checked first and only on a cache miss is +# remote read. 
+# +# False: Disables the cache +# True: Enables the cache +# None: Not set -- Off for OSS, JustKnobs based for internal +autotune_remote_cache: Optional[bool] = autotune_remote_cache_default() + +# Enable bundled autotune cache. +# +# Enables/disables the bundled autotune cache regardless of the state of +# autotune_remote_cache. However it does depend on the local cache for local +# state management - as a result if the local cache is disabled this will also +# disable the bundled autotune cache. +# +# False: Disables the cache +# True: Enables the cache (requires autotune_local_cache) +# None: Not set -- Off for OSS, JustKnobs based for internal +bundled_autotune_remote_cache: Optional[bool] = bundled_autotune_remote_cache_default() + +# See torch.compiler.config.force_disable_caches +force_disable_caches: bool = Config(alias="torch.compiler.config.force_disable_caches") + +# Unsafe way to skip dynamic shape guards to get faster cache load +unsafe_skip_cache_dynamic_shape_guards: bool = False + +# Unsafe way to mark non torch functions as safe to cache +# dictionary is from function name -> cache key +# Any function name in the dictionary will be allowed to be cacheable +# by AOTAutogradCache and FxGraphCache. +# changing the cache key value will change the resulting +# FXGraphCache key. +# Example usage: +# torch._inductor.config.unsafe_marked_cacheable_functions = { +# 'torch.ops.my_function' : torch.__version__ +# } +# The above example causes the custom op torch.ops.my_function to be cacheable, +# and for cache keys to be keyed by the current torch version +unsafe_marked_cacheable_functions: dict[str, str] = {} + +# sleep in inductor for testing +sleep_sec_TESTING_ONLY: Optional[int] = None + +# The default layout constraint for user-defined triton kernels. +# See "The default layout constraint for custom operators" for options. 
+triton_kernel_default_layout_constraint: Literal[ + "needs_fixed_stride_order", "flexible_layout" +] = "needs_fixed_stride_order" + +# use cpp wrapper instead of python wrapper +# incompatible with disable_cpp_codegen +cpp_wrapper: bool = os.environ.get("TORCHINDUCTOR_CPP_WRAPPER", "0") == "1" + +# controls whether to compile entry and kernel separately for cpp_wrapper mode. +# turn on this option to compile entry and kernel separately and minimize compile time of the entry part. +# see https://github.com/pytorch/pytorch/pull/148773 +# Note: compiling entry and kernel separately may have a non-negligible impact on the performance. +# see https://github.com/pytorch/pytorch/issues/156037 +cpp_wrapper_build_separate: bool = ( + os.environ.get("TORCHINDUCTOR_CPP_WRAPPER_BUILD_SEPARATE", "0") == "1" +) + +fx_wrapper: bool = os.environ.get("TORCHINDUCTOR_FX_WRAPPER", "0") == "1" + +# Controls automatic precompiling of common include files for codecache.CppCodeCache +# (i.e. for cpp_wrapper mode and for cpp kernels on CPU). AOTI header precompiling is +# controlled by a separate flag. 
+cpp_cache_precompile_headers: bool = not is_fbcode() + +online_softmax = os.environ.get("TORCHINDUCTOR_ONLINE_SOFTMAX", "1") == "1" + +# dead code elimination +dce = False + +# assume weight tensors are fixed size +static_weight_shapes = True + +# put correctness assertions in generated code +size_asserts = os.environ.get("TORCHINDUCTOR_SIZE_ASSERTS", "1") == "1" +nan_asserts = os.environ.get("TORCHINDUCTOR_NAN_ASSERTS") == "1" +runtime_triton_nan_asserts = ( + os.environ.get("TORCHINDUCTOR_RUNTIME_TRITON_NAN_ASSERTS") == "1" +) +scalar_asserts = os.environ.get("TORCHINDUCTOR_SCALAR_ASSERTS", "1") == "1" + +# Disable by default in fbcode +alignment_asserts = ( + os.environ.get("TORCHINDUCTOR_ALIGNMENT_ASSERTS", "0" if is_fbcode() else "1") + == "1" +) + +# enable loop reordering based on input orders +pick_loop_orders = True + +# reuse a kernel input as the output +inplace_buffers = True + +# reuse a buffer for an unrelated purpose +allow_buffer_reuse = True + +# Enable pooled allocations for non-output tensors +memory_planning = os.environ.get("TORCHINDUCTOR_MEMORY_PLANNING", "0") == "1" + +# Enable to allow using ftz variant of exponenet instruction in triton codegen. 
+use_fast_math = os.environ.get("TORCHINDUCTOR_USE_FAST_MATH") == "1" + +# How to organize memory under memory_planning=True: +# - "none": do not try to pool storage, just reuse +# - "intermediates": all non-outputs share storage, outputs each get unique storage +# - "outputs": two pools, one for intermediates (freed on return) and one for outputs +# - "combined": a single pool for both intermediates and outputs +memory_pool: Literal["none", "intermediates", "outputs", "combined"] = os.environ.get( + "TORCHINDUCTOR_MEMORY_POOL", "intermediates" +) # type: ignore[assignment] + +# codegen benchmark harness +benchmark_harness = True + +# fuse pointwise into templates epilogues +epilogue_fusion = True + +# fuse pointwise into template prologues +prologue_fusion = prologue_fusion_enabled() + +# do epilogue fusions before other fusions +epilogue_fusion_first = False + +# enable pattern match+replace optimizations +pattern_matcher = True + +# set to True to enable the back-to-back GEMM pass +b2b_gemm_pass = False + +# register custom graph optimization pass hook. so far, pre/post passes are +# only applied before/after pattern_matcher in post_grad_passes. +# +# Implement CustomGraphPass to allow Inductor to graph compiled artifacts +# to which your custom passes have been applied: +post_grad_custom_pre_pass: torch._inductor.custom_graph_pass.CustomGraphPassType = None +post_grad_custom_post_pass: torch._inductor.custom_graph_pass.CustomGraphPassType = None + +# Allow users to pass in custom partition function +custom_partitioner_fn: torch._inductor.custom_graph_pass.CustomPartitionerFnType = None + +# Registers a custom joint graph pass. +joint_custom_pre_pass: torch._inductor.custom_graph_pass.CustomGraphPassType = None +joint_custom_post_pass: torch._inductor.custom_graph_pass.CustomGraphPassType = None + +# Registers a custom pregrad pass. Note that the pre-grad IR is 1. +# non-functional, 2. non-normalized, and 3. prone to change. 
Ideally we should +# use post-grad passes. +pre_grad_custom_pass: Optional[Callable[[torch.fx.graph.Graph], None]] = None + +# Registers a custom pass to be run right before fusion in Inductor scheduler. +# WARNING: Inductor scheduler IR is at prototype stage and subject to change, +# hence custom IR passes built on top of it might break in the future. +_pre_fusion_custom_pass: Optional[ + Callable[ + [list["torch._inductor.scheduler.BaseSchedulerNode"]], + list["torch._inductor.scheduler.BaseSchedulerNode"], + ] +] = None + +# Registers a custom pass to be run right after fusion in Inductor scheduler. +# WARNING: Inductor scheduler IR is at prototype stage and subject to change, +# hence custom IR passes built on top of it might break in the future. +_post_fusion_custom_pass: Optional[ + Callable[ + [list["torch._inductor.scheduler.BaseSchedulerNode"]], + list["torch._inductor.scheduler.BaseSchedulerNode"], + ] +] = None + +# Deprecated +split_cat_fx_passes = True + +# Optimize conv-batchnorm if batchnorm is in eval mode. Slightly reduces numerical stability. +efficient_conv_bn_eval_fx_passes = False + +# Enable predispatch aten IR for export +is_predispatch = False + +# Deprecated +group_fusion = False + +# Deprecated +batch_fusion = True + +# Pre grad fusion and options in order, set to empty dict to disable fusion. +# Call `torch._inductor.fx_passes.group_batch_fusion.list_group_batch_fusions()` to see available fusions. +# batch fusion options: +# batch_linear +# batch_linear_lhs +# batch_layernorm +# batch_tanh +# batch_relu +# batch_sigmoid + +# split cat fusion options: +# normalization_pass +# remove_split_with_size_one_pass +# merge_getitem_cat_pass +# merge_stack_tahn_unbind +# merge_splits_pass +# mutate_cat_pass +# split_cat_pass +pre_grad_fusion_options: dict[str, dict[str, Any]] = {} + +# Post grad fusion and options, set to empty dict to disable fusion. 
+# Call `torch._inductor.fx_passes.group_batch_fusion.list_group_batch_fusions(False)` to see available fusions. +post_grad_fusion_options: dict[str, dict[str, Any]] = {} + +# enable reordering pass for improving memory locality +reorder_for_locality = True + +# Scale down Rn_BLOCK for better occupancy +dynamic_scale_rblock = os.environ.get("TORCHINDUCTOR_DYNAMIC_SCALE_RBLOCK", "1") == "1" + +# this forces fusion for int_mm with mul. Needed when you want to avoid realizing the int32 +# but the mul gets fused with other pointwise ops instead. +force_fuse_int_mm_with_mul = False + +# DEPRECATED. This setting is ignored. +use_mixed_mm = True + +# enable runtime numeric check for pre/post grad fx passes +# floating point provides limited accuracy (about 7 decimal digits for single precision +# floating point numbers,about 16 decimal digits for double precision floating point numbers) +# according to PyTorch documentation. +# https://pytorch.org/docs/stable/notes/numerical_accuracy.html#batched-computations-or-slice-computations +fx_passes_numeric_check: dict[str, Any] = { + "pre_grad": False, + "precision": 1e-4, + "num_iterations": 1, + "requires_optimizer": True, +} + +# DEPRECATED. This setting is ignored. +mixed_mm_choice: Literal["default", "triton", "aten", "heuristic"] = "heuristic" + +# enable reordering pass for increasing overlap between compute and communication +reorder_for_compute_comm_overlap = False + +# passes (in execution order) for increasing overlap between compute and communication +# for built-in passes, use string name; for user-defined passes, pass in the function handle +# WARNING: Inductor scheduler IR is at prototype stage and subject to change, +# hence custom IR passes built on top of it might break in the future. +# +# See aten_distributed_optimizations, it is recommended way for distributed optimizations. 
+# +# Recommended configuration for reorder_for_compute_comm_overlap_passes: +# [ +# "reorder_communication_preserving_peak_memory", +# "sink_waits_iterative", +# "reorder_communication_preserving_peak_memory", +# ] +reorder_for_compute_comm_overlap_passes: list[ + Union[ + str, + Callable[ + [list["torch._inductor.scheduler.BaseSchedulerNode"]], + list["torch._inductor.scheduler.BaseSchedulerNode"], + ], + ] +] = [] + +# Maximum number of positions to advance a given collective, unlimited by default +reorder_prefetch_limit: Optional[int] = None + +# enable operator reordering for peak memory optimization +reorder_for_peak_memory = True +reorder_for_peak_memory_debug = False + +# In some cases, when all the nodes that can be scheduled are quite large, +# it is beneficial to switch the scheduling strategy. So instead of using +# size as the criterion, we choose a node that can unlock more nodes to +# become schedulable by analyzing their successor nodes. The default value +# is zero, which turns off this optimization. 
+size_threshold_for_succ_based_strategy: int = 0 + + +bucket_all_gathers_fx: Literal["none", "all", "only_fsdp"] = "none" +# By default torch._inductor.fx_passes.bucketing.bucket_size_determinator is used +bucket_all_gathers_fx_bucket_size_determinator: Optional[Callable[[int], int]] = None + +bucket_reduce_scatters_fx: Literal["none", "all"] = "none" +# By default torch._inductor.fx_passes.bucketing.bucket_size_determinator is used +bucket_reduce_scatters_fx_bucket_size_determinator: Optional[Callable[[int], int]] = ( + None +) + +bucket_all_reduces_fx: Literal["none", "all"] = "none" +# By default torch._inductor.fx_passes.bucketing.bucket_size_determinator is used +bucket_all_reduces_fx_bucket_size_determinator: Optional[Callable[[int], int]] = None + +# runtime estimation function for ops +# for built-in estimation function, pass in "default"; for user-defined estimation function, pass in the function handle +estimate_op_runtime = "default" + +runtime_estimations_mms_benchmark: bool = False + +# unit: GB/s, uni-directional P2P bandwidth per card +# default value is NVLink +intra_node_bw = 300 + +# unit: GB/s, uni-directional P2P bandwidth per node +# default value is InfiniBand +inter_node_bw = 25 + +# unit: GB/s, uni-directional CPU<>GPU bandwidth +# default value is PCIe; modify for your hardware or measured bandwidth +cpu_gpu_bw = 50.0 + +# use Inductor's experimental benchmarker (runtime/benchmarking.py) +# to benchmark kernels during autotuning, otherwise fall back to +# Triton's `do_bench`. the experimental benchmarker may produce +# results that are not consistent with `do_bench`'s results +use_experimental_benchmarker: bool = Config( + default=True, + env_name_force="TORCHINDUCTOR_USE_EXPERIMENTAL_BENCHMARKER", + justknob="pytorch/inductor:use_experimental_benchmarker", +) + +# Enable distributed autotuning. 
When this is enabled we will distribute the +# autotuning across distributed ranks in the same program group - so instead of +# each rank autotuning every kernel they only autotune 1/world size kernels and +# then share the results. +distributed_max_autotune_gemm = ( + os.environ.get("TORCHINDUCTOR_DISTRIBUTED_MAX_AUTOTUNE_GEMM") == "1" +) + +# enable slow autotuning passes to select algorithms +max_autotune = os.environ.get("TORCHINDUCTOR_MAX_AUTOTUNE") == "1" + +# enable slow autotuning passes to select pointwise/reductions algorithms +max_autotune_pointwise = os.environ.get("TORCHINDUCTOR_MAX_AUTOTUNE_POINTWISE") == "1" + +# enable slow autotuning passes to select gemm algorithms +max_autotune_gemm = os.environ.get("TORCHINDUCTOR_MAX_AUTOTUNE_GEMM") == "1" + +# Modifies the number of autotuning choices displayed, set to None for all +autotune_num_choices_displayed: Optional[int] = 10 + +# Report the autotune choices and their benchmark results. Default is True. +max_autotune_report_choices_stats = ( + os.environ.get("TORCHINDUCTOR_MAX_AUTOTUNE_REPORT_CHOICES_STATS", "1") == "1" +) + +# Prune configs that require more shared memory than the hardware limit +max_autotune_prune_choices_based_on_shared_mem = ( + os.environ.get("TORCHINDUCTOR_MAX_AUTOTUNE_PRUNE_CHOICES_BASED_ON_SHARED_MEM", "1") + == "1" +) + +# Disable triton from trying to initialize and detect devices on the host +triton_disable_device_detection = ( + os.environ.get("TORCHINDUCTOR_TRITON_DISABLE_DEVICE_DETECTION", "0") == "1" +) + +# enable inductor graph partition to allow multiple inductor graphs for the same dynamo graph +graph_partition: bool = ( + os.environ.get("TORCHINDUCTOR_GRAPH_PARTITION", "1" if not is_fbcode() else "0") + == "1" +) + +# register ops upon which inductor should partition the graph. name format should be +# "namespace::kernel_name" (e.g., aten::mm) for op overload packet, or +# "namespace::kernel_name.overload" (e.g., aten::mm.default). 
+custom_should_partition_ops: list[str] = [] + +# whether template autotuning should allow flexible layouts if possible (e.g. only extern choices) +max_autotune_allow_flexible_layouts: bool = False + +# force cublas and triton to use the same precision; cublas supports TF32 for matmul operations +# when m, n, k are multiples of 16, 16, 8, whereas triton supports TF32 for matmul operations +# for any combinations of m, n, k, regardless of their alignment. setting this flag will ensure +# that triton does not use TF32 wherever cublas would not use TF32 +# DEPRECATED. cuBLAS no longer has the above alignment requirements. will remove in the future. +force_same_precision: bool = Config( + justknob="pytorch/compiler:force_same_precision", + env_name_force="TORCHINDUCTOR_FORCE_SAME_PRECISION", + default=False, +) + +# Size hints for multi-kernel dispatch. +# A reasonable default value of this config would be [64, 256, 4096] +# TODO: @bobrenjc93 to roll this out to a few internal models to ensure this works +# as expected before turning it on for everyone. +multi_kernel_hints: list[int] = [] + +# Specify candidate backends for gemm autotune. +# Possible choices are combinations of: ATen, Triton, CUTLASS, CK, CKTILE, CPP. +# ATen: default Pytorch ATen kernels. +# Triton: Triton templates defined in torch inductor (AMD and NVidia GPUs). +# CUTLASS: Cutlass templates and kernels (NVidia GPUs only). +# CK: Composable Kernel templates and kernels (AMD Instinct GPUs only). +# CKTILE: Composable Kernel templates and kernels, new API (AMD Instinct GPUs only). +# CPP: CPP templates and kernels for CPU. +max_autotune_gemm_backends = os.environ.get( + "TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_BACKENDS", "ATEN,TRITON,CPP" +).upper() + + +# As above, specify candidate backends for conv autotune. 
+# NB: in some cases for 1x1 convs we emit as matmul, +# which will use the backends of `max_autotune_gemm_backends` +max_autotune_conv_backends = os.environ.get( + "TORCHINDUCTOR_MAX_AUTOTUNE_CONV_BACKENDS", "ATEN,TRITON" +).upper() + + +# Specify the size of the search space for GEMM autotuning. +# DEFAULT - balance between compile time overhead and performance +# EXHAUSTIVE - maximize performance +max_autotune_gemm_search_space: Literal["DEFAULT", "EXHAUSTIVE"] = os.environ.get( + "TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_SEARCH_SPACE", "DEFAULT" +).upper() # type: ignore[assignment] + +# Specify the size of the search space for flex attention autotuning. +# DEFAULT - balance between compile time overhead and performance +# EXHAUSTIVE - maximize performance +max_autotune_flex_search_space: Literal["DEFAULT", "EXHAUSTIVE"] = os.environ.get( + "TORCHINDUCTOR_MAX_AUTOTUNE_FLEX_SEARCH_SPACE", "DEFAULT" +).upper() # type: ignore[assignment] + + +# Fall back to ATen for all ops by default, except those nodes that users explicitly +# annotated with regional inductor compile. Please read torch.fx.passes.regional_inductor +# on to explicitly annotate. This is currently only used by inductor lite mode. +# Different from default inductor mode that fuses all nodes, this config enables an +# opt-in mode that only fuse for user-specified nodes. The motivation is to provide +# guaranteed numeric correctness and give full control to users. +fallback_by_default: bool = False + + +# This config allows selective decomposition of certain operators in the graph. +# Currently the only use case is to patch the same-name config in functorch, for +# inductor lite mode. 
See more details in [Note: Selective Decomposition] +selective_decompose: bool = False + + +# Use dead code elimination +use_dce: bool = True + + +# Use fx graph passes +use_pre_grad_passes: bool = True +use_joint_graph_passes: bool = True +use_post_grad_passes: bool = True + + +cutedsl_enable_autotuning: bool = ( + os.environ.get("CUTEDSL_ENABLE_AUTOTUNING", "0") == "1" +) + +# DEPRECATED. This setting is ignored. +autotune_fallback_to_aten = False + +# the value used as a fallback for the unbacked SymInts +# that can appear in the input shapes (e.g., in autotuning) +unbacked_symint_fallback = 8192 + +# DEPRECATED. This setting is ignored. +search_autotune_cache = False + +save_args = os.environ.get("TORCHINDUCTOR_SAVE_ARGS") == "1" + +# We will disable creating subprocess for autotuning if this is False +autotune_in_subproc = os.environ.get("TORCHINDUCTOR_AUTOTUNE_IN_SUBPROC") == "1" + +# The following three timeouts are applicable if autotune_in_subproc is True: + +# Max time that a valid benchmark result may take during autotuning +max_autotune_subproc_result_timeout_seconds = 60.0 +# DEPRECATED. This setting is ignored. +max_autotune_subproc_graceful_timeout_seconds = 0.0 +# DEPRECATED. This setting is ignored. 
+max_autotune_subproc_terminate_timeout_seconds = 0.0 + +# If autotuning in subprocess, whether to use multiple devices +autotune_multi_device = os.environ.get("TORCHINDUCTOR_AUTOTUNE_MULTI_DEVICE") == "1" + +# Number of benchmark runs for collective operations +collective_benchmark_nruns = int( + os.environ.get("TORCHINDUCTOR_COLLECTIVE_BENCHMARK_NRUNS", "50") +) + +# Timeout in seconds for collective benchmarking +collective_benchmark_timeout = float( + os.environ.get("TORCHINDUCTOR_COLLECTIVE_BENCHMARK_TIMEOUT", "30") +) + +coordinate_descent_tuning = ( + os.environ.get("TORCHINDUCTOR_COORDINATE_DESCENT_TUNING") == "1" +) +coordinate_descent_check_all_directions = ( + os.environ.get("TORCHINDUCTOR_COORDINATE_DESCENT_CHECK_ALL_DIRECTIONS") == "1" +) +coordinate_descent_search_radius = int( + os.environ.get("TORCHINDUCTOR_COORDINATE_DESCENT_RADIUS", "1") +) + +# AutoHeuristic is a framework that allows one to collect data from autotuning, use the data to learn a heuristic, and +# generate the learned heuristic to code which is shipped with the compiler +# Specify a list of comma separated optimizations to collect data for +autoheuristic_collect = os.environ.get("TORCHINDUCTOR_AUTOHEURISTIC_COLLECT", "") +# Specify a list of comma separated optimizations to use learned heuristics for +autoheuristic_use = os.environ.get("TORCHINDUCTOR_AUTOHEURISTIC_USE", "mixed_mm") + +# If set to 1, will run a JIT post compile hook if one is set. 
+run_jit_post_compile_hook = ( + os.environ.get("TORCHINDUCTOR_RUN_JIT_POST_COMPILE_HOOK", "0") == "1" +) + + +def run_autoheuristic(name: str) -> bool: + return collect_autoheuristic(name) or use_autoheuristic(name) + + +def collect_autoheuristic(name: str) -> bool: + return name in torch._inductor.config.autoheuristic_collect.split(",") + + +def use_autoheuristic(name: str) -> bool: + return name in torch._inductor.config.autoheuristic_use.split(",") + + +# If set to "DEFAULT", this will use the default log path specified in autoheuristic.py. +# If set to another path, autoheuristic will instead log results to the given path. +autoheuristic_log_path = os.environ.get( + "TORCHINDUCTOR_AUTOHEURISTIC_LOG_PATH", "DEFAULT" +) + +# Disabled by default on ROCm, opt-in if model utilises NHWC convolutions +layout_opt_default = "1" if not torch.version.hip else "0" +layout_optimization = ( + os.environ.get("TORCHINDUCTOR_LAYOUT_OPTIMIZATION", layout_opt_default) == "1" +) + +force_layout_optimization = os.environ.get("TORCHINDUCTOR_FORCE_LAYOUT_OPT", "0") == "1" + + +# Whether to keep the output strides the same as eager after layout optimization. +keep_output_stride = os.environ.get("TORCHINDUCTOR_KEEP_OUTPUT_STRIDE", "1") == "1" + +# Enabling this will let compiler print warning messages if a generated triton +# kernel has inputs with mixed layouts. This is helpful for perf debugging +# since kernel with mixed layout inputs may run much slower then one whose inputs +# have uniform layouts. +warn_mix_layout = os.environ.get("TORCHINDUCTOR_WARN_MIX_LAYOUT") == "1" + +# control store vs recompute heuristic +# For fanouts, rematerialization can lead to exponential blowup. 
So, have +# smaller threshold +realize_reads_threshold = 4 +realize_opcount_threshold = 30 + +# Threshold to prevent excessive accumulation of ops in one buffer during lowering +realize_acc_reads_threshold = 8 +realize_acc_reads_size_threshold: Optional[int] = ( + None # TODO(xuanzh): harden this to make it non optional +) + +# fallback to eager for random/dropout, this is slow but useful for debugging +fallback_random = False + +# fallback embedding_bag_byte_unpack to eager +fallback_embedding_bag_byte_unpack = False + +# automatically create fallbacks when encountering an unhandled op +implicit_fallbacks = True +assume_unaligned_fallback_output = ( + os.environ.get("TORCHINDUCTOR_ASSUME_UNALIGNED_FALLBACK_OUTPUT") == "1" +) + +# Custom InductorChoices callable to use (can be a class or functools.partial with kwargs) +inductor_choices_class: Optional[Callable[[], "InductorChoices"]] = None + +# fuse even in cases without common reads +aggressive_fusion = False + +# For each fused kernel in the wrapper, comment with the nodes that get fused. +# Useful for debugging fusion. +debug_fusion: bool = os.environ.get("TORCHINDUCTOR_DEBUG_FUSION") == "1" +benchmark_fusion: bool = os.environ.get("TORCHINDUCTOR_BENCHMARK_FUSION") == "1" +enabled_metric_tables = os.environ.get("TORCHINDUCTOR_ENABLED_METRIC_TABLES", "") +loop_ordering_after_fusion: bool = ( + os.environ.get( + "TORCHINDUCTOR_LOOP_ORDERING_AFTER_FUSION", "0" if is_fbcode() else "1" + ) + == "1" +) + + +# When trying to fuse two nodes, one with: +# a[contiguous_writes] = fn(...) +# and another node: +# b[contiguous_writes] = a[discontiguous_reads] +# If b is unary, and we can figure out an inverse formula for +# discontiguous writes, invert b as : +# b[inverse(discontiguous_writes)] = a[contiguous_reads] +# so that the nodes can fuse. 
for more details: https://gist.github.com/eellison/6f9f4a7ec10a860150b15b719f9285a9 +loop_index_inversion_in_fusion: bool = True + +# If fusing two nodes only save less then score_fusion_memory_threshold memory, +# we should not bother fusing the nodes. +# +# This is especially helpful to resolve https://github.com/pytorch/pytorch/issues/133242 +# Previously we fuse two nodes because of common read of a scalar tensor. +# If we skip it, the loop ordering after fusion mechanism kicks in and can +# brings more savings. +# +# For the cases loop ordering after fusion does not help, we don't lose much. +score_fusion_memory_threshold = 10 + +# For Triton Templates, select fastest of best template + epilogue vs best template + separate epilogue kernel +benchmark_epilogue_fusion = ( + os.environ.get("TORCHINDUCTOR_BENCHMARK_EPILOGUE_FUSION", "1") == "1" +) + +# Take how many of the top triton kernels to benchmark epilogue +max_epilogue_benchmarked_choices = 1 + +# how many nodes to allow into a single fusion +max_fusion_size = 64 + +# how many nodes to attempt pairwise fusion with in a buffer group +max_fusion_buffer_group_pairwise_attempts = 64 + +# maximum number of unique input/output buffers allowed in fused kernels. +# The check is disabled if set to None. +max_fusion_unique_io_buffers: Optional[int] = None + +# max number of inputs to generate cat as a pointwise op with masked loads +max_pointwise_cat_inputs = 8 + +# force concat to be generated as a pointwise op with masked loads +force_pointwise_cat = False + +# replace small reductions with pointwise, disable with `= 1` +unroll_reductions_threshold = 8 + +# Add extra comments to output code (causes compile cache misses) +comment_origin = False + +# Convert 1x1 convs into matmuls +conv_1x1_as_mm = False + +# For reductions with a small output size (usually 1, e.g. x.sum()) there is not enough +# parallelism to saturate the GPU. 
We have two ways of handling this, either `split_reductions` +# or `triton.cooperative_reductions` which are mutually exclusive. +# split_reductions: uses multiple kernels to gain more parallelism +# triton.cooperative_reductions: uses cross thread-block synchronization to gain more parallelism +# enabling both of these will implicitly disable split_reductions +split_reductions = os.getenv("TORCHINDUCTOR_SPLIT_REDUCTIONS", "1") == "1" + +# A deterministic mode that skips any on device benchmarking in Inductor +# if we know they affect numerics. WARNING: Expect perf hit in this mode. +deterministic = os.getenv("TORCHINDUCTOR_DETERMINISTIC") == "1" + +# When we do split reduction, this number control the minimum value for +# num_split. Too small num_split make the split reduction less efficient. +# It's a much bigger problem when we compile a dynamic shape kernel with +# non-representative inputs. +min_num_split = int(os.environ.get("TORCHINDUCTOR_MIN_NUM_SPLIT", 0)) + +benchmark_kernel = os.environ.get("TORCHINDUCTOR_BENCHMARK_KERNEL", "0") == "1" + +# Enable constant and index_expr folding +constant_and_index_propagation = True + +# we always add constants into graph.constants without +# performing any constant-inlining optimization +always_keep_tensor_constants = False + +# assert that indirect indexing does not read / write out of bounds +assert_indirect_indexing = True + +# compute CSE bounds on variables that do not appear in the FX graph +compute_all_bounds = False + +# enable the combo kernel that combines data-independent kernels (additional +# to foreach kernels) into a single one (Experimental) +combo_kernels = False +# benchmark combo kernels and only allow ones with perf gains +benchmark_combo_kernel = False +# combo_kernel autotuning options: 0 - disable, 1 - enable except for foreach, +# 2 - enable for all +combo_kernels_autotune = 1 +# Enable masking for combining kernels of mixed sizes: 0 - disable, 1 - enable +# for all except for foreach, 2 - 
enable for all +combo_kernel_allow_mixed_sizes = 1 +# Enable dynamic shapes for foreach kernels +combo_kernel_foreach_dynamic_shapes = True +# Maximum number of arguments (read/write buffers) allowed in a combo kernel +combo_kernel_max_num_args = 250 + +# constant folding on the joint graph +joint_graph_constant_folding = True + +# Enable indirect_indexing asserts for decompositions and lowerings +debug_index_asserts = False + +# Mode to emulate PyTorch eager numerics when doing lower precision compute +# (fp16, bf16). PyTorch eager computes bf16/fp16 by upcasting inputs to fp32 +# and downcasting after. When two low precision operators are fused together, +# Inductor will elide the downcast-upcast pairs (effectively a precision +# truncation) that would occur between these two operators. Typically, +# Inductor's behavior should be closer to fp64 ref numerics. However, with +# this knob you can ensure the downcast-upcast are preserved so that you can +# emulate the eager numerics. +emulate_precision_casts = ( + os.environ.get("TORCHINDUCTOR_EMULATE_PRECISION_CASTS", "0") == "1" +) + +# x / y in Triton is lowered to div.full which is approx +# PyTorch eager uses the equivalent of Triton's div_rn, which can +# come at a performance penalty +emulate_divison_rounding = ( + os.environ.get("TORCHINDUCTOR_EMULATE_DIVISION_ROUNDING", "0") == "1" +) + +# warnings intended for PyTorch developers, disable for point releases +is_nightly_or_source = "dev" in torch.__version__ or "git" in torch.__version__ +developer_warnings = is_fbcode() or is_nightly_or_source + +# This pattern matches a special usage of scatter +# 1. It's applied to a constant tensor +# 2. The index tensor has size 1 in the scatter dimension +# Such pattern generates a sparse matrix when the const tensor is all-zero. +# We can lower this pattern to a pointwise kernel for more fusion opportunities +# and saving memory footprint. 
+optimize_scatter_upon_const_tensor = ( + os.environ.get("TORCHINDUCTOR_OPTIMIZE_SCATTER_UPON_CONST_TENSOR", "1") == "1" +) + +# options in caffe2/torch/_inductor/fx_passes/pre_grad.py +add_pre_grad_passes: Optional[str] = None +remove_pre_grad_passes: Optional[str] = None + + +# The multiprocessing start method to use for inductor workers in the codecache. +def decide_worker_start_method() -> str: + if "TORCHINDUCTOR_WORKER_START" in os.environ: + start_method = os.environ["TORCHINDUCTOR_WORKER_START"] + else: + start_method = "subprocess" + assert start_method in ( + "subprocess", + "fork", + "spawn", + ), f"Invalid start method: {start_method}" + return start_method + + +worker_start_method: str = decide_worker_start_method() + +# Threshold to decide if a kernel has small memory access in bytes +# Default value is 16 MB which is arbitrarily selected. +small_memory_access_threshold: int = 16777216 + +# Whether to log from subprocess workers that are launched. +worker_suppress_logging: bool = Config( + justknob="pytorch/compiler:worker_suppress_logging", + env_name_force="TORCHINDUCTOR_WORKER_SUPPRESS_LOGGING", + default=True, +) + +# Log per-operation runtime estimates for TLParse analysis. +log_tlparse: bool = Config( + env_name_force="LOG_TLPARSE", + default=False, +) + +# Flags to turn on all_reduce fusion. These 2 flags should be automatically turned +# on by DDP and should not be set by the users. +_fuse_ddp_communication = False +_fuse_ddp_bucket_size = 25 + +# Flag to control which fusion passes to apply. Functions in the list will +# be applied in order. There are two different different fusion passes +# --"fuse_ddp_with_concat_op" and "fuse_ddp_with_coalesced_op". The default +# one is "fuse_ddp_with_concat_op". Users can also change this to a customized +# fusion function. +# +# The fusion currently does not support multiple DDP with different PG or +# data type. This feature will be added in the future PRs. 
+# +# "schedule_comm_wait" is used to delay the wait ops to maximize comm/comp +# overlapping. At this moment, this pass performs better than +# reorder_for_compute_comm_overlap_passes but we will add the logic of +# "schedule_comm_wait" in the future and remove the one here. +_fuse_ddp_communication_passes: list[Union[Callable[..., None], str]] = [ + "fuse_ddp_with_concat_op", + "schedule_comm_wait", +] + +_micro_pipeline_tp: bool = False + + +class _collective: + auto_select: bool = False + one_shot_all_reduce_threshold_bytes: int = 128 * 1024 + + +class aten_distributed_optimizations: + """Configuration for distributed optimization passes on ATen FX graphs.""" + + # Enable overlap scheduling pass + enable_overlap_scheduling: bool = False + + # Enable overlap-preserving collective bucketing + collective_bucketing: Optional[bool] = None + + # Insert ordering dependencies to preserve overlap relationships. This should only be used if + # compiling with inductor, or for subsequent passes before removing the ops prior to execution + insert_overlap_deps: Optional[bool] = None + + # Maximum compute node prefetch distance for overlap scheduling + max_compute_pre_fetch: Optional[int] = None + + compute_overlap_multipler: Optional[float] = None + + # Custom runtime estimation function for ops + # For user-defined estimation function, pass in the function handle + # None means use default estimations + # TODO - need estimated and profile based version + custom_runtime_estimation: Optional[Callable[[torch.fx.Node], Optional[float]]] = ( + None + ) + + # Method for estimating collective runtime + # "analytical": Use bandwidth formulas (default) + # "benchmark": Use CUDA events with power-of-2 rounding and interpolation + collective_estimator: Literal["analytical", "benchmark"] = "analytical" + + # Maximum memory increase above baseline for prefetch operations + # Uses minimum of absolute cap and ratio of baseline + max_memory_increase_gb: Optional[float] = None # Absolute 
cap in GB + max_memory_increase_ratio: Optional[float] = None # Ratio of baseline peak memory + + # Maximum GB of concurrent collective data in flight. Too much in flight memory + # can cause memory fragmentation within the CUDA Caching Allocator. + max_in_flight_gb: Optional[float] = None + + # Maximum prefetch or bucketing candidates. Mainly intended for compile time. + max_coll_distance: Optional[int] = None + + +def parallel_compile_enabled_internally() -> bool: + """ + TODO: Remove when parallel compiled is fully enabled internally. For rollout, use a + knob to enable / disable. The justknob should not be performed at import, however. + So for fbcode, we assign compile_threads to 'None' below and initialize lazily in + async_compile.py. + """ + ENABLE_PARALLEL_COMPILE_VERSION = 1 + + jk_name = "pytorch/inductor:enable_parallel_compile_version" + version = torch._utils_internal.justknobs_getval_int(jk_name) + return ENABLE_PARALLEL_COMPILE_VERSION >= version + + +def decide_compile_threads() -> int: + """ + Here are the precedence to decide compile_threads + 1. User can override it by TORCHINDUCTOR_COMPILE_THREADS. One may want to disable async compiling by + setting this to 1 to make pdb happy. + 2. Set to 1 if it's win32 platform + 3. decide by the number of CPU cores + """ + import logging + + # Defined locally so install_config_module doesn't try to parse + # as a config option. 
+ log = logging.getLogger(__name__) + + if "TORCHINDUCTOR_COMPILE_THREADS" in os.environ: + compile_threads = int(os.environ["TORCHINDUCTOR_COMPILE_THREADS"]) + log.info("compile_threads set to %d via env", compile_threads) + elif sys.platform == "win32": + compile_threads = 1 + log.info("compile_threads set to 1 for win32") + elif is_fbcode() and not parallel_compile_enabled_internally(): + compile_threads = 1 + log.info("compile_threads set to 1 in fbcode") + else: + cpu_count = ( + len(os.sched_getaffinity(0)) + if hasattr(os, "sched_getaffinity") + else os.cpu_count() + ) + assert cpu_count + compile_threads = min(32, cpu_count) + log.info("compile_threads set to %d", compile_threads) + + return compile_threads + + +# TODO: Set directly after internal rollout. +compile_threads: Optional[int] = None if is_fbcode() else decide_compile_threads() + +# Whether to quiesce the Triton-compile subprocess pool at the end of each compilation. +quiesce_async_compile_pool: bool = Config( + justknob="pytorch/inductor:quiesce_async_compile_pool", + env_name_force="TORCHINDUCTOR_QUIESCE_ASYNC_COMPILE_POOL", + default=True, +) + +# Time in seconds to wait before quiescing +quiesce_async_compile_time: int = Config( + default=60, +) + +# Whether or not to enable statically launching CUDA kernels +# compiled by triton (instead of using triton's own launcher) +use_static_cuda_launcher: bool = static_cuda_launcher_default() + +# Attempt to statically launch user defined triton kernels +# Requires use_static_cuda_launcher +static_launch_user_defined_triton_kernels: bool = Config( + justknob="pytorch/inductor:static_launch_user_defined_triton_kernels", + env_name_force="TORCHINDUCTOR_STATIC_LAUNCH_USER_DEFINED_TRITON_KERNELS", + default=False, +) + +# Raise error if we bypass the launcher +strict_static_cuda_launcher: bool = ( + os.environ.get("TORCHINDUCTOR_STRICT_STATIC_CUDA_LAUNCHER", "0") == "1" +) + +# gemm autotuning global cache dir +global_cache_dir: Optional[str] +if 
is_fbcode(): + try: + from libfb.py import parutil + + if __package__: + global_cache_dir = parutil.get_dir_path( + os.path.join(__package__.replace(".", os.sep), "fb/cache") + ) + else: + global_cache_dir = parutil.get_dir_path("fb/cache") + except (ValueError, ImportError): + global_cache_dir = None + +else: + global_cache_dir = None + +# If kernel is fused, the name is generated from the origin node op names +# for larger kernels limit this +kernel_name_max_ops = 10 + +# Pad input tensors of matmul/bmm/addmm to leverage Tensor Cores in NVIDIA GPUs +shape_padding = os.environ.get("TORCHINDUCTOR_SHAPE_PADDING", "1") == "1" + +# Control if we will do padding for pointwise/reductions +comprehensive_padding = ( + os.environ.get("TORCHINDUCTOR_COMPREHENSIVE_PADDING", "1") == "1" +) +pad_channels_last = False + +# Control if we will do padding on dynamic shapes +pad_dynamic_shapes = False + +# Disable comprehensive padding on the CPU +disable_padding_cpu = True + +# Control if we will expand the dimension of pointwise nodes to fuse +expand_dimension_for_pointwise_nodes = False + +# The width of comprehensive padding, in bytes. +# CUDA max memory transaction size is 128 bytes for a warp. +padding_alignment_bytes = 128 + +# Threshold on the minimum stride that will be padded. +# +# Don't align a too small stride since that causes too much memory increase. +# Pad too small stride may also cause perf loss. We may result in many tiny data blocks +# with gaps in between. That causes less coalesced GPU memory access! +# +# Initially we pick 320 as the threshold since for alignment=16, +# that results in at most 5% memory cost. +# +# But later on we raise the threshold to 1024 to avoid interfere with persistent reduction. +# Let's say an inner reduction has a row size 513. Inductor will generate +# persistent reduction code. +# If we do padding, the strides are not contiguous any more. 
Inductor +# uses a much smaller threshold for persistent reduction in this case and +# generates potentially worse non-persistent reduction code. +# +# This change turns HF AllenaiLongformerBase amp training from a loss of 1.09x to a win of 1.05x. +# (baseline: 71.09ms, padding w/o this change: 77.38ms, padding with this change: 67.77ms) +padding_stride_threshold = 1024 + +# Enable padding outputs, even if they would not be padded in eager mode. +# By default, we use the same strides as eager mode. +pad_outputs = False + +# Whether to treat output of the backward graph as user visible. +# For user visible outputs, inductor will make sure the stride matches with eager. +bw_outputs_user_visible = True + +# Whether to always use shape padding if it is enabled and possible +force_shape_pad: bool = False + +# Fx-based linear/matmul/bmm + permute/transpose vertical fusion +permute_fusion = os.environ.get("TORCHINDUCTOR_PERMUTE_FUSION", "0") == "1" + +# Mark the wrapper call in PyTorch profiler +profiler_mark_wrapper_call = False + +# Generate hook calls to torch._inductor.hooks.run_intermediate_hooks for +# every intermediate for which we can correlate it with an intermediate +# from the original FX graph +generate_intermediate_hooks = False + +# Populate traceback field on IRNode; good for debugging why origin_node is +# not populated, or finding out where an IRNode was constructed +debug_ir_traceback = False + +# used for debugging to make sure config is properly set +_raise_error_for_testing = False + +_profile_var = os.environ.get("TORCHINDUCTOR_PROFILE", "") +profile_bandwidth = _profile_var != "" +profile_bandwidth_regex = "" if _profile_var == "1" else _profile_var +# Specify a file where we print out the profiling results. +# None means we do not dump results to a file. 
+profile_bandwidth_output: Optional[str] = os.environ.get( + "TORCHINDUCTOR_PROFILE_OUTPUT", None +) +# Switch to do_bench_using_profiling to exclude the CPU overheads +profile_bandwidth_with_do_bench_using_profiling = ( + os.environ.get("TORCHINDUCTOR_PROFILE_WITH_DO_BENCH_USING_PROFILING") == "1" +) + + +# TODO: remove later +# incompatible with cpp_wrapper +disable_cpp_codegen = False + + +# Freezing will attempt to inline weights as constants in optimization +# and run constant folding and other optimizations on them. After freezing, weights +# can no longer be updated. +freezing: bool = os.environ.get("TORCHINDUCTOR_FREEZING", "0") == "1" + +# Make freezing invalidate the eager Parameters of nn modules, to avoid memory overhead +# of potentially keeping multiple copies of weights. +freezing_discard_parameters: bool = False + +# decompose some memory bound matmul/bmm to mul +decompose_mem_bound_mm: bool = False + +# Wrap compiled regions in inductor_compiled_code HOP to make them visible to +# TorchDispatchModes like DebugMode and Selective Activation Checkpointing. +wrap_inductor_compiled_regions: bool = False + +# assume_aligned_inputs means that we assume that inputs will be aligned; we generate +# code using this assumption, and clone tensors before use if they aren't aligned. +# In the common case, most inputs will be aligned. +assume_aligned_inputs: bool = False + +# assume_32bit_indexing means that we assume 32-bit indexing is always safe; we always +# use 32-bit indices regardless of tensor sizes. If assume_32bit_indexing contradicts +# with example inputs we throw. This is useful when all dynamic shapes are unbacked and +# you know you only operate with 32-bit sizes. 
+assume_32bit_indexing: bool = False + +# For the user-written Triton kernels compiled with the model, ignore the unsupported +# arguments passed to the @triton.autotune in the user's code; this is unsafe, as +# ignoring the unsupported args may lead to unexpected autotuning behavior: don't +# set unless you know what you're doing. +unsafe_ignore_unsupported_triton_autotune_args: bool = False + +# When True, we will check in scheduler.py _codegen that there are no "loops" +# in the call stack; that is to say, the same frame multiple times. This +# ensures that a cProfile trace to this frame will be a straight line without +# any cycles. Incompatible with cpp_wrapper. +check_stack_no_cycles_TESTING_ONLY: bool = False + +# When True, complex_memory_overlap always reports True +always_complex_memory_overlap_TESTING_ONLY: bool = False + +# enable linear binary folding +enable_linear_binary_folding = ( + os.environ.get("TORCHINDUCTOR_ENABLE_LINEAR_BINARY_FOLDING", "0") == "1" +) + + +# Adds NVTX annotations around training phases +annotate_training: bool = os.environ.get("TORCHINDUCTOR_ANNOTATE_TRAINING", "0") == "1" + +# Enable caching codegen of triton templates. 
+enable_caching_generated_triton_templates: bool = True + +# Lookup table for overriding autotune configs based on hash of Triton source code +autotune_lookup_table: dict[str, dict[str, Any]] = {} + +file_lock_timeout: int = int(os.environ.get("TORCHINDUCTOR_FILE_LOCK_TIMEOUT", "600")) + +enable_autograd_for_aot: bool = False + +_debug_cpu_to_tpu_pallas: bool = Config( + env_name_force="PALLAS_TARGET_TPU", default=False +) +pallas_take_first_jax_device_only: bool = Config( + env_name_force="PALLAS_TAKE_FIRST_JAX_DEVICE_ONLY", default=True +) + + +def get_worker_log_path() -> Optional[str]: + log_loc = None + if is_fbcode(): + mast_job_name = os.environ.get("MAST_HPC_JOB_NAME", None) + global_rank = os.environ.get("ROLE_RANK", "0") + + if mast_job_name is not None: + log_loc = f"/logs/dedicated_log_torch_compile_worker_rank{global_rank}" + + return log_loc + + +torchinductor_worker_logpath: str = Config( + env_name_force="TORCHINDUCTOR_WORKER_LOGPATH", + default="", +) + + +# config specific to codegen/cpp.py +class cpp: + """ + Settings for cpp backend. + This class provides a centralized location for managing cpp backend settings. + """ + + # set to torch.get_num_threads() + threads = -1 + + # Do not generate loops when the condition doesn't hold, like: + # for(long i0=4096; i0<4096; i0+=1) + no_redundant_loops = ( + os.environ.get("TORCHINDUCTOR_CPP_NO_REDUNDANT_LOOPS", "1") == "1" + ) + + # Assume number of threads is dynamic, don't specialize thread number. + # Kernels don't recompile on thread number changes with this flag on. + # For single-threaded workload, turning it on would incur a slight + # performance degradation. 
+ dynamic_threads = os.environ.get("TORCHINDUCTOR_CPP_DYNAMIC_THREADS", "0") == "1" + + simdlen: Optional[int] = None + min_chunk_size = int(os.environ.get("TORCHINDUCTOR_CPP_MIN_CHUNK_SIZE", "512")) + + cxx: tuple[None, str] = ( + None, # download gcc12 from conda-forge if conda is installed + os.environ.get("CXX", "clang++" if sys.platform == "darwin" else "g++"), + ) # type: ignore[assignment] + + # Allow kernel performance profiling via PyTorch profiler + enable_kernel_profile = ( + os.environ.get("TORCHINDUCTOR_CPP_ENABLE_KERNEL_PROFILE", "0") == "1" + ) + + # enable weight prepacking to get a better performance; may lead to large memory footprint + weight_prepack = os.environ.get("TORCHINDUCTOR_CPP_WEIGHT_PREPACK", "1") == "1" + + # Inject a bug into our relu implementation; useful for testing our repro + # extraction and minification functionality. + # Valid values: "compile_error", "runtime_error", "accuracy" + inject_relu_bug_TESTING_ONLY: Optional[str] = None + inject_log1p_bug_TESTING_ONLY: Optional[str] = None + + # If None, autodetect whether or not AVX512/AVX2 can be used. Otherwise, + # force usage as specified, without testing. Default None. + vec_isa_ok: Optional[bool] = get_tristate_env("TORCHINDUCTOR_VEC_ISA_OK") + + # similar to config.triton.descriptive_names + descriptive_names: Literal["torch", "original_aten", "inductor_node"] = ( + "original_aten" + ) + + # how many nodes to allow into a single horizontal fusion + max_horizontal_fusion_size = int( + os.environ.get("TORCHINDUCTOR_CPP_MAX_HORIZONTAL_FUSION_SIZE", "16") + ) + + # Make scatter_reduce fallback when reduce is sum to avoid performance regression + # using atomic_add. 
+ fallback_scatter_reduce_sum = ( + os.environ.get("TORCHINDUCTOR_CPP_FALLBACK_SCATTER_REDUCE_SUM", "1") == "1" + ) + + # Use funsafe-math-optimizations when compiling + enable_unsafe_math_opt_flag = ( + os.environ.get("TORCHINDUCTOR_CPP_ENABLE_UNSAFE_MATH_OPT_FLAG", "0") == "1" + ) + + # Use ffp-contract when compiling + # Options: "off" (default), "on", "fast" + # Per https://godbolt.org/z/bf4bvfc9r , clang/gcc has different behavior for "fast" + enable_floating_point_contract_flag = os.environ.get( + "TORCHINDUCTOR_CPP_ENABLE_FLOATING_POINT_CONTRACT_FLAG", "off" + ) + + # Disable the tiling select heuristic + enable_tiling_heuristics = ( + os.environ.get("TORCHINDUCTOR_CPP_ENABLE_TILING_HEURISTIC", "1") == "1" + ) + + # Enable the Grouped GEMM Fusion + enable_grouped_gemm_template = False + + # Maximal allowed number of slices on K-dim for a GEMM kernel. This controls + # the maximal parallelism of K-slicing. Since K-slicing requires extra thread + # synchronization and buffers, the maximal number of slices is limited to + # mitigate the sync overhead and memory usage. + # When set to 0, the number of slices is unlimited. + gemm_max_k_slices = int(os.environ.get("TORCHINDUCTOR_CPP_GEMM_MAX_K_SLICES", "1")) + + # For perf tuning and debugging purpose, configure the pre-defined cache blocking for + # MxNxK dims respectively. The blockings are separated by comma and the unit is + # the number of register blocks. + # For example, "4,1,10" means 4 register blocks on M, 1 on N and 10 on K respectively. + gemm_cache_blocking = os.environ.get("TORCHINDUCTOR_CPP_GEMM_CACHE_BLOCKING", None) + + # For perf tuning and debugging purpose, configure the pre-defined thread blocking factors for + # MxNxK dims respectively. The factors are separated by comma and their product + # should be the same as the total number of threads. + # For example, if the total number of threads is 56, "7,4,2" means the work is + # decomposed into 7x4x2 thread blocks along MxNxK of a GEMM. 
+ gemm_thread_factors = os.environ.get("TORCHINDUCTOR_CPP_GEMM_THREAD_FACTORS", None) + + # Whether to enable masked vectorization for the tail_loop. + enable_loop_tail_vec = True + + # Whether to enable concat linear for cpu device + # Currently concat linear on CPU not always have benefit, depends on linear'shape or + # computing resource. We set this default to False to avoid regressions. User and + # enable this feature by their need. + enable_concat_linear = False + + # Whether to use decomposed tanh for cpu device + # Disable by default due to https://github.com/pytorch/pytorch/issues/148241 + use_decompose_tanh = ( + os.environ.get("TORCHINDUCTOR_CPP_USE_DECOMPOSE_TANH", "0") == "1" + ) + + # Use a small dequant buffer for wgt of woq int4 size as: [q_group_size, Nr] + use_small_dequant_buffer = False + + force_inline_kernel = ( + os.environ.get("TORCHINDUCTOR_CPP_FORCE_INLINE_KERNEL", "0") == "1" + ) + + # Use static constexpr or static const for int array + use_constexpr_for_int_array = ( + os.environ.get("TORCHINDUCTOR_CPP_USE_CONSTEXPR_FOR_INT_ARRAY", "1") == "1" + ) + + +class triton: + """ + Config specific to codegen/triton.py + """ + + # Use cudagraphs on output code + cudagraphs = os.environ.get("TORCHINDUCTOR_CUDAGRAPHS") == "1" + + # Use cudagraph trees for memory pooling if `cudagraphs` is True + cudagraph_trees = True + + # Should we skip cudagraphing graphs with dynamic shape inputs + # If False, we will re-record a graph for each unique set of shape inputs + cudagraph_skip_dynamic_graphs = False + + # Specify dynamic shapes to capture cudagraphs and skip cudagraph for other shapes. + # Default to None, which means we capture cudagraphs for all shapes. 
+ cudagraph_capture_sizes: Optional[tuple[Union[int, tuple[int, ...]]]] = None + + # assertions not on the fast path, steady state + slow_path_cudagraph_asserts = True + + # TODO - need to debug why this prevents cleanup + cudagraph_trees_history_recording = False + + # Enable cudagraph support for mutated inputs from prior cudagraph pool + cudagraph_support_input_mutation = not is_fbcode() + + # Maximal number of allowed cudagraph re-record for a function and + # a cudagraph node due to static input tensor address changes or + # cudagraph managed tensor data pointer changed. + # i.e., allow num_recording <= cudagraph_unexpected_rerecord_limit + # note: we are conservative here and choose a large limit. + cudagraph_unexpected_rerecord_limit = 128 + + # Warn loudly when the number of cudagraphs due to dynamic shape + # exceeds this limit + cudagraph_dynamic_shape_warn_limit: Optional[int] = 8 + + # synchronize after cudagraph invocation + force_cudagraph_sync = False + + # always run cudagraphs in the eager warmup stage + # instead of recording and executing cudagraphs + force_cudagraphs_warmup = False + + # If False (default), torch.compile skips cudagraph for a graph if it + # contains cudagraph-unsafe ops. If True, we require that all cuda ops + # be captured into cudagraph. If this is not possible, this will raise + # an error. + cudagraph_or_error: bool = Config( + env_name_force="TORCHINDUCTOR_CUDAGRAPH_OR_ERROR", + default=False, + ) + + # reorder nodes to minimize the number of graph partitions while + # not incurring large memory overhead + reorder_for_reducing_graph_partitions: bool = True + + # assertions on the fast path + fast_path_cudagraph_asserts = False + + # skip warmup for cudagraph trees + skip_cudagraph_warmup = False + + # Synchronize before and after every compiled graph. 
+ debug_sync_graph = False + + # Synchronize after every kernel launch, to help pinpoint bugs + debug_sync_kernel = False + + # Always load full blocks (rather than broadcasting inside the block) + dense_indexing = False + + # TODO - enable by default + coalesce_tiling_analysis: bool = ( + os.environ.get( + "TORCHINDUCTOR_COALESCE_TILING_ANALYSIS", "1" if not is_fbcode() else "0" + ) + == "1" + ) + + # limit tiling dimensions + # - max_tiles=1 disables tiling + # - max_tiles=2 + # - max_tiles=3 is experimental and may have bugs + # higher values are unsupported + + # We use a max of 3 if coalesce_tiling_analysis is True, and 2 otherwise. + # Note - coalesce_tiling_analysis does not yet apply to dynamic shapes. + max_tiles: Optional[int] = None + + # Prefer higher dimensional tilings. This simplifies indexing expressions, making + # it easier to identify block pointers. + prefer_nd_tiling: bool = False + + # use triton.autotune for pointwise ops with complex layouts + # this should only be disabled for debugging/testing + autotune_pointwise = True + + # max autotune gemm with cublasLt + autotune_cublasLt = True + + # Tune the generated Triton kernels at compile time instead of first time they run + # Setting to None means uninitialized + autotune_at_compile_time: Optional[bool] = None + + # We use random tensors for autotune by default. Setting this as true will let us + # use inputs from sample inputs to autotune user defined triton kernels. + # Side effect for this option is increased memory footprint during first pass compilation. + autotune_with_sample_inputs: bool = False + + # Allows tiling reductions into multiple dimensions. + # For best results, this should be used with prefer_nd_tiling. + tile_reductions: bool = False + + # Codegen matmul natively with tl.dot without using a template. + # This option makes Inductor generate matrix multiplication from scratch, + # instead of calling predefined Triton templates (mm, bmm, mm_plus_mm). 
+ # Compile time may be longer because native matmul benchmarks more Triton configs + # than regular pointwise or reduction kernels. + # Native matmul often aggressively fuses operations around the matrix multiply, + # which can make it faster or slower depending on your program. + # + # This option takes priority over other GEMM implementations. If Inductor determines + # that a matmul can be generated, it will always generate it with native_matmul. + # That means optimized kernels such as decompose_k or persistent_tma_matmul will + # not be called when this option is enabled. + # + # Note: Native matmul does not currently support block pointers or TMA matmul. + # If both native_matmul and (use_block_ptr or enable_persistent_tma_matmul) are enabled, + # an error will be thrown. + native_matmul: bool = False + + # should we stop a fusion to allow better tiling? + tiling_prevents_pointwise_fusion = True + tiling_prevents_reduction_fusion = True + + # should we give different names to kernels + # Note: This is orthogonal to descriptive_names - this is deciding whether + # our triton kernel names should all be `triton_` (to maximize caching) or + # whether they should be unique. + unique_kernel_names = ( + os.environ.get("TORCHINDUCTOR_UNIQUE_KERNEL_NAMES", "1") == "1" + ) + + # similar to the option above, but this is specific to user defined kernels, + # while unique_kernel_name is for kernels generated by inductor. + # We have this option because sometimes we reuse user's kernel code with different + # configs which would result in the same name. + # Note: This MODIFIES the user's kernel function name within inductor phase. + unique_user_kernel_names = ( + os.environ.get("TORCHINDUCTOR_UNIQUE_USER_KERNEL_NAMES", "0") == "1" + ) + + # should we put op names in kernel names + # "torch": Maps to the fx op in the Dynamo graph (module name, method name, etc.) + # "original_aten": Maps to the highest-level aten op (i.e. 
pre-decompositions) + # "inductor_node": Maps to the node name in the FX graph passed to Inductor + descriptive_names: Literal["torch", "original_aten", "inductor_node"] = ( + "original_aten" + ) + + # use alternate codegen for smaller reductions + persistent_reductions = ( + os.environ.get("TORCHINDUCTOR_PERSISTENT_REDUCTIONS", "1") == "1" + ) + + # For small output size reductions uses cross thread-block synchronization to gain more parallelism + cooperative_reductions = ( + os.environ.get("TORCHINDUCTOR_COOPERATIVE_REDUCTIONS", "0") == "1" + ) + + # used for debugging cooperative reduction codegen, always generate cooperative_reductions + force_cooperative_reductions = False + + # 0: disable + # 1/True: enable, use tuning to pick between different subkernels + # 2: enable, force using persistent reduction (for debugging) + # 3: enable, force using non-persistent reduction (for debugging) + multi_kernel: Literal[0, 1, 2, 3] = int( + os.environ.get("TORCHINDUCTOR_MULTI_KERNEL", "0") + ) # type: ignore[assignment] + + # hint to Triton when arguments are divisible by 16 + divisible_by_16 = os.environ.get("TORCHINDUCTOR_DIVISIBLE_BY_16", "1") == "1" + + # Minimum R0_BLOCK to be used for a TritonSplitScanKernel + # NOTE: This also indirectly controls the size of workspace buffer required + min_split_scan_rblock = 256 + + # Store the generated cubin files for cpp wrapper code to load + store_cubin = False + + # the max number of spills we allow for the configs we benchmark. + # Setting this to 0 means we skip a config if it spills even a single + # register. + # Setting it to a larger value allows a config spilling a small amount + # of registers being benchmarked. + # + # NOTE: triton will always report >0 register spills for kernels using sin/cos. + # (check this issue https://github.com/triton-lang/triton/issues/1756 ) + # So far we see a fixed 8 spilled registers for kernels using sin/cos. + # Raise the threshold to 16 to be safe. 
+ # We should revisit this once we understand more of the source of register spills. + spill_threshold: int = 16 + + # Generate code containing the newer tl.make_block_ptr() API for loads/store + use_block_ptr = False + + # (Experimental) + # Generate code using the tl.make_tensor_descriptor() API for loads/store + # [Note: TMA API Restrictions] Currently the TMA API requires the following: + # - For Nvidia GPUs, the compute capability should be >= 9.0 + # - The innermost stride of a descriptor should be 1 + # - The size of the block shape in the innermost dimension should load / store + # at least 16 bytes. + # - Tensors are 16 byte aligned. Enabling this option therefore requires + # assume_aligned_inputs to also be enabled + # TMA descriptors are only going to be generated if the above conditions + # can be satisfied, along with any existing requirements for index expressions + use_tensor_descriptor = False + + # (Experimental) + # Whether to allow reordering tensor descriptor matches with descending + # strides, at the expense of transposing values after load / before store. + transpose_discontiguous_tensor_descriptor = True + + # Inject a bug into our relu implementation; useful for testing our repro + # extraction and minification functionality. + # Valid values: "compile_error", "runtime_error", "accuracy" + inject_relu_bug_TESTING_ONLY: Optional[str] = None + + # Whether to upcast float16 / bfloat16 to float32 in triton codegen (Experimental) + codegen_upcast_to_fp32 = True + + # Whether persistent matmul kernels should be enabled this flag only has effect when on h100 + # with a version of triton new enough to support TMA + enable_persistent_tma_matmul = ( + os.environ.get("ENABLE_PERSISTENT_TMA_MATMUL", "0") == "1" + ) + # Should TMA store be enable from templates. TODO: Remove once we + # can autotune over the result. + enable_template_tma_store = os.environ.get("ENABLE_TEMPLATE_TMA_STORE", "0") == "1" + # Use epilogue subtiling. 
We allow disabling it due to limited B200 testing. + enable_epilogue_subtiling = os.environ.get("ENABLE_EPILOGUE_SUBTILING", "1") == "1" + # Skip L1 cache for buffers that are used only once. Disabled by default + skip_l1_cache = os.environ.get("TORCHINDUCTOR_SKIP_L1", "0") == "1" + + # During autotuning, if one of the kernels/configs fails for some reason, + # Inductor will usually skip it (and assign its latency to inf). + # For testing it's helpful to be able to assert that none of the configs fail. + # Note: it may also need to be used with config.compile_threads = 1 + disallow_failing_autotune_kernels_TESTING_ONLY = False + + # specify number of splits to autotune on for decompose_k. 0 disables decompose_k + num_decompose_k_splits = int( + os.environ.get("TORCHINDUCTOR_NUM_DECOMPOSE_K_SPLITS", "10") + ) + + # specify minimum ratio of K to M AND N in order to autotune on decompose_k. 0 enables + # it as an autotuning choice for all matmuls + decompose_k_threshold = int( + os.environ.get("TORCHINDUCTOR_DECOMPOSE_K_THRESHOLD", "32") + ) + + # Programmatic Dependent Launch improves launch latency on Nvidia Hopper+ devices + # If set to true, will generate PDL code on devices that support it. + # If set to false, will never generate PDL code. 
+ enable_pdl = False + + mix_order_reduction = ( + os.environ.get("TORCHINDUCTOR_MIX_ORDER_REDUCTION", "0" if is_fbcode() else "1") + == "1" + ) + mix_order_reduction_initial_xblock = 1 + + mix_order_reduction_split_size: Optional[int] = None + mix_order_reduction_autotune_split_size = ( + os.environ.get("TORCHINDUCTOR_MIX_ORDER_REDUCTION_AUTOTUNE_SPLIT_SIZE", "0") + == "1" + ) + + +class aot_inductor: + """ + Settings for Ahead-Of-Time Inductor Compilation + """ + + # AOTInductor output path + # If an absolute path is specified, the generated lib files will be stored under the directory; + # If a relative path is specified, it will be used as a subdirectory under the default caching path; + # If not specified, a temp directory will be created under the default caching path. + # If the specified path contains something like "model.so", the sub-string will be used + # to name the generated library. + output_path = "" + + debug_compile = os.environ.get("AOT_INDUCTOR_DEBUG_COMPILE", "0") == "1" + debug_symbols = os.environ.get("AOT_INDUCTOR_DEBUG_SYMBOLS", "0") == "1" + + # Annotate generated main wrapper function, i.e. AOTInductorModel::run_impl, + # to use which cpp compiler optimization level, default to O1 + compile_wrapper_opt_level = os.environ.get( + "AOT_INDUCTOR_COMPILE_WRAPPER_OPT_LEVEL", "O1" + ) + + # option for debug printing/saving for intermediate tensor values for aot inductor + # 0: disable debug dumping + # 1: enable saving intermediate tensor values + # 2: enable printing intermediate tensor values + # 3: enable printing kernel names only (useful for pinpointing troublesome kernels) + debug_intermediate_value_printer: Literal["0", "1", "2", "3"] = os.environ.get( + "AOT_INDUCTOR_DEBUG_INTERMEDIATE_VALUE_PRINTER", "0" + ) # type: ignore[assignment] + + # filtered nodes to be printed for debug values. 
Specify this option when debug_intermediate_value_printer is set to 2 + filtered_kernel_names = os.environ.get( + "AOT_INDUCTOR_FILTERED_KERNELS_TO_PRINT", None + ) + + # Serialized tree spec for flattening inputs + # TODO: Move this into metadata + serialized_in_spec = "" + + # Serialized tree spec for flattening outputs + # TODO: Move this into metadata + serialized_out_spec = "" + + # flag to decide whether to create a submodule for constant graph. + use_runtime_constant_folding: bool = False + + # flag to force weight to be appended to the shared library and mapped by the runtime + # rather than embedded into the data section. Needed to support 1B+ parameter models + force_mmap_weights: bool = False + + # Default value of use_consts_asm_build is True, it will build by assembly language. + # When the value is False, it will build by c++ language. + use_consts_asm_build = True + + package: bool = False + package_cpp_only: Optional[bool] = None + + # If package_cpp_only is True, whether cpp files will be compiled to a + # dynamically linked library or static linked library + dynamic_linkage: bool = True + + # Dictionary of metadata users might want to save to pass to the runtime. + # TODO: Move this somewhere else, since it's no longer really a config + metadata: dict[str, str] = {} + + # fbcode only. Whether to raise error if C++ codegen is too big to optimize + raise_error_on_ignored_optimization: bool = ( + os.environ.get("AOTINDUCTOR_RAISE_ERROR_ON_IGNORED_OPTIMIZATION", "1") == "1" + ) + + # Whether to check lowerbound constraints on dynamic shapes during runtime. + # When disabled, allows models with dynamic sizes of 0 or 1 to work with + # AOTI_RUNTIME_CHECK_INPUTS=1, avoiding errors from the [2+, ...] lowerbound + # restriction when backed_size_oblivious is off. 
+ check_lowerbound: bool = True + + # dump an aoti minifier if program errors + dump_aoti_minifier: bool = os.environ.get("DUMP_AOTI_MINIFIER", "0") == "1" + + # Compiler compilation debug info + # 1: Dumps the original graph out to repro.py if compilation fails + # 2: Dumps a minifier_launcher.py if aoti fails. + # 3: Always dumps a minifier_launcher.py. Good for segfaults. + # 4: Dumps a minifier_launcher.py if the accuracy fails. + repro_level: int = int(os.environ.get("AOTINDUCTOR_REPRO_LEVEL", 2)) + + # Dictionary of presets that can be passed in + presets: dict[str, Any] = {} + + # Kill switch for allowing temporary tensors to be allocated as stack arrays. Tests + # should be run with this flag both on and off to make sure we have coverage. + allow_stack_allocation: bool = False + + # Enables an alternate DSO interface (the "minimal ArrayRef interface") intended + # to maximize performance for use cases that it can accommodate at the expense of + # generality. In brief: + # - inputs and outputs are ArrayRefTensor (note that strides are required, but the + # tensor must be contiguous) + # - constant handling is unchanged because it is not a per-inference-iteration bottleneck + # + # When the DSO is generated in this mode, the usual interface will also be supported, + # but performance for that interface may be degraded. + use_minimal_arrayref_interface: bool = False + + # Set to True if we want to use Pytorch's CUDACachingAllocator for weight management + weight_use_caching_allocator: bool = ( + os.environ.get("AOT_INDUCTOR_WEIGHT_USE_CACHING_ALLOCATOR", "0") == "1" + ) + + # Experimental. Flag to control whether to include weight in .so + # Not supported for cross_target_platform="windows". + package_constants_in_so: bool = True + + # Experimental. Flag to control whether to package weight separately on disk and which + # format to package it in. + # Options: + # None: + # Do not package weight separately on disk. 
+ # "pickle_weights": + # Each weight is pickled and stored separately in data/weights. We also store the + # FQN names of each weight in a weights_config.json in each model's data/aot_inductor/model folder. + # Can only be load back from python using torch._inductor.aoti_load_package API now. + # "binary_blob": + # Stores all weights in a single binary blob in data/aot_inductor/model folder for each model. + # This option and config.aot_inductor.force_mmap_weights cannot both be True + package_constants_on_disk_format: Optional[str] = None + + # Experimental. Controls automatic precompiling of common AOTI include files. + precompile_headers: bool = not is_fbcode() + + # Embed generated kernel binary files into model.so + embed_kernel_binary: Optional[bool] = None + + # Generate kernel files that support multiple archs + # For CUDA, this means generating fatbin files for kernels, and the fatbin files + # contains PTX and SASS for the current architecture. + emit_multi_arch_kernel: Optional[bool] = None + + # If not None, the generated files with use this name in file stem. + # If None, we will use a hash to name files. + # + # If package_cpp_only, this name is also used for the target name in CMakelists.txt + # The default target name is "aoti_model" + # + # If compile_standalone, the aoti model class name is f"AOTInductorModel{name}" + # + # This name can only contain letters, numbers, and underscores. 
+ model_name_for_generated_files: Optional[str] = None + + # Custom ops that have implemented C shim wrappers, defined as an op to C shim declaration dict + custom_ops_to_c_shims: dict[torch._ops.OpOverload, list[str]] = {} + # custom op libs that have implemented C shim wrappers + custom_op_libs: Optional[list[str]] = None + + # Whether to enable link-time-optimization + enable_lto = os.environ.get("AOT_INDUCTOR_ENABLE_LTO", "0") == "1" + + # Whether the compiled .so should link to libtorch + link_libtorch: bool = True + + # Currently the only valid option is "windows". + # We'll use x86_64-w64-mingw32-gcc to cross-compile a .dll file + # If using cuda, you also need to set WINDOWS_CUDA_HOME env var + # to point to windows CUDA toolkit. + # Example: WINDOWS_CUDA_HOME=cuda-windows-base/cuda_cudart/cudart/ + # The path should contain lib cuda and lib cudart + cross_target_platform: Optional[str] = None + + # If link_libtorch is False and cross_target_platform is windows, + # a library needs to be provided to provide the shim implementations. + aoti_shim_library: Optional[str | list[str]] = None + aoti_shim_library_path: Optional[str] = None + + +# a convenient class that automatically sets a group of the configs in aot_inductor +# it should only control the flags in aot_inductor. +# it should not do anything else. +class aot_inductor_mode: + # dynamic_linkage=False + # link_libtorch=False + # package_cpp_only=True + # embed_kernel_binary=True + # emit_multi_arch_kernel=True + compile_standalone: bool = False + + +class cuda: + """Settings for cuda backend, today this consists of cutlass""" + + # CUDA arch to use for CUDA template kernel compilation. + # e.g. "70", "75", "80", "90", etc. + # When arch is None, Inductor uses torch.cuda.get_device_capability(0). + arch: Optional[str] = None + + # CUDA version to use for CUDA template kernel compilation. + # e.g. "11.4", "12.1", etc. + # When version is None, Inductor uses torch.version.cuda. 
+ version: Optional[str] = None + + # Optimization level for the host compiler. + compile_opt_level: Literal["-O0", "-O1", "-O2", "-O3", "-OS"] = "-O1" + + # Whether to enable device LTO (link-time-optimization). + enable_cuda_lto = False + + # Whether to keep intermediate files dring compilation. + enable_ptxas_info = False + + # Whether to enable debug info, e.g. line number, cutlass debug info. + enable_debug_info = False + + # Whether to use fast math. + use_fast_math = False + + # Path to the CUTLASS repo root directory. + # The default path only works under PyTorch local development environment. + cutlass_dir = os.path.realpath( + os.environ.get( + "TORCHINDUCTOR_CUTLASS_DIR", + os.path.join(os.path.dirname(torch.__file__), "../third_party/cutlass/"), + ) + ) + + # Configures the maximum number of CUTLASS configs to profile in max_autotune. + # By default it's None, so that all CUTLASS configs are tuned. + # This is mainly used to reduce test time in CI. + cutlass_max_profiling_configs: Optional[int] = None + + # The L2 swizzle values to consider when profiling CUTLASS configs in max_autotune. + cutlass_max_profiling_swizzle_options: list[int] = [1, 2, 4, 8] + + # Whether to use CUTLASS EVT for epilogue fusion + cutlass_epilogue_fusion_enabled = ( + os.environ.get("CUTLASS_EPILOGUE_FUSION", "0") == "1" + ) + + # Whether to only use TMA-compatible kernels in CUTLASS + cutlass_tma_only = False + + # Path to CUDA NVCC. + # NVCC search order: + # 1) cuda_cxx set in this config + # 2) CUDACXX environment variable + # 3) CUDA_HOME environment variable + # 4) default system search PATH. + cuda_cxx: Optional[str] = None + + # Minimum value of M*N*K to consider the CUTLASS backend for GEMM ops. + cutlass_backend_min_gemm_size: int = 1 + + # enable generation of inline standalone runner in CUDA CPP generated code + # which allows to compile the generated code into a standalone executable. 
+ generate_test_runner: bool = ( + os.environ.get("INDUCTOR_CUDA_BACKEND_GENERATE_TEST_RUNNER_CODE", "0") == "1" + ) + + # Keep only Cutlass op configs which contain this regular expression pattern + # Set this to "warpspecialized_cooperative_epi_tma" to enable only SM90 TMA Cutlass Kernels for large GEMMs + cutlass_op_allowlist_regex: Optional[str] = os.environ.get( + "TORCHINDUCTOR_CUTLASS_ALLOWLIST" + ) + + # Note: Names of Cutlass ops names can be obtained by calling + # op.configuration_name() on a Cutlass op instance, for example those + # returned from cutlass_utils.gen_ops() or the op argument passed to + # CUTLASSGemmTemplate.render(...) + + # Filter Cutlass configs which contain this regular expression pattern + # Set this to "pingpong" to avoid numerical issues + # caused by the op ordering of the "pingpong" memory access + # pattern used by some Cutlass Kernels. + cutlass_op_denylist_regex: Optional[str] = os.environ.get( + "TORCHINDUCTOR_CUTLASS_DENYLIST" + ) + + # Non-negative integer which determines how many kernels are instantiated. + # 0 = 0000 generates the fewest kernels, 9999 generates all possible combinations. + # increasing first digit reduces schedule / mixed type pruning, + # increasing second digit generates more cluster sizes, + # increasing third digit generates more MMA multipliers, + # increasing fourth digit generates more instruction shapes. + cutlass_instantiation_level: str = os.environ.get( + "TORCHINDUCTOR_CUTLASS_INSTANTIATION_LEVEL", "0" + ) + + # use compile command to create kernel .cu and .so name + cutlass_hash_with_compile_cmd: bool = ( + os.environ.get("TORCHINDUCTOR_CUTLASS_HASH_WITH_COMPILE_CMD", "0") == "1" + ) + + # Experimental. Prescreen top x configs before tuning on swizzle. 
+ cutlass_prescreening: bool = ( + os.environ.get("TORCHINDUCTOR_CUTLASS_PRESCREENING", "1") == "1" + ) + + # Specify which operations should use CUTLASS backend + # Comma-separated list like "mm,addmm,bmm", "all" for all operations, and "" for none. + # Acceptable operations: mm, int_mm, addmm, sparse_semi_structured_mm, bmm, scaled_mm + cutlass_enabled_ops: str = os.environ.get( + "TORCHINDUCTOR_CUTLASS_ENABLED_OPS", "all" + ) + + # Whether to consult the binary remote cache + use_binary_remote_cache: bool = True + + # Whether to upload compiled kernels to remote cache + upload_to_binary_remote_cache: bool = False + + # Whether to force upload if the key already exists + # Use this to overwrite and handle cache pollution + binary_remote_cache_force_write: bool = False + + # Enable caching codegen of cuda templates. + enable_caching_codegen: bool = True + + +class rocm: + # Offload arch list for device code compilation, e.g. ["gfx90a", "gfx942"]. + # If empty, the `native` arch is used + arch: list[str] = [] + + # Enable the CK backend for CDNA2 and CDNA3 only (for now) + # Processor name reference: https://llvm.org/docs/AMDGPUUsage.html#processors + ck_supported_arch: list[Literal["gfx90a", "gfx942", "gfx950"]] = [ + "gfx90a", + "gfx942", + "gfx950", + ] + + # Optimization level, use to balance compilation speed and runtime performance. + # The type will not necessarily be comprehensive and won't be enforced at runtime. + compile_opt_level: Literal[ + "-O0", "-O1", "-O2", "-O3", "-Os", "-Oz", "-Omin", "-Ofast", "-Omax" + ] = "-O2" + + # Flag to keep debug information in compiled objects + is_debug = False + + # Flag to keep intermediate files (assembly listings, preprocessed sources, etc.) 
+ save_temps = False + + # Flag to add `-ffast-math`` to compile flags + use_fast_math = True + + # Flag to add `-fgpu-flush-denormals-to-zero` to compile flags + flush_denormals = True + + # Flag to print register and LDS usage during compilation + print_kernel_resource_usage = False + + # Path to ROCm installation, if None, use env variable ROCM_HOME. + # In fbcode see triton/fb/TARGETS for how ROCM_HOME gets set. + rocm_home: Optional[str] = None + + # Path to Composable Kernel library. + # Install with `pip install git+https://github.com/rocm/composable_kernel@develop`. + ck_dir = os.environ.get("TORCHINDUCTOR_CK_DIR") + + # generate standalone executables for instances generated with the CK backend + generate_test_runner: bool = ( + os.environ.get("INDUCTOR_CK_BACKEND_GENERATE_TEST_RUNNER_CODE", "0") == "1" + ) + + # Deprecated, use CK and/or CK-tile specific settings + n_max_profiling_configs: Optional[int] = None + + # Number of op instance choices to trade off between runtime perf and compilation time + # For CK Kernels + ck_max_profiling_configs: Optional[int] = None + + # Number of op instance choices to trade off between runtime perf and compilation time + # For CK-Tile Kernels + ck_tile_max_profiling_configs: Optional[int] = None + + # Flag to use a short list of CK instances which perform well across a variety of shapes. + # Currently RCR and F16 only + use_preselected_instances: bool = False + + # List to determine kBatch parameters to sweep over. 
By default, we calculate one in splitK + # scenarios, and run on kBatch=1 in non-splitK scenarios + kBatch_sweep: Optional[list[int]] = None + + # The threshold at which we trigger a splitK config - K // max(M,N) has to be greater than this + split_k_threshold: int = 16 + + # The threshold at which we trigger a contiguous subgraph transformation + contiguous_threshold: int = 16 + + +# Backend to use for CPU codegen either "cpp" or "triton" (experimental) or "halide" (experimental) or "pallas" (experimental) +cpu_backend: Literal["cpp", "triton", "halide", "pallas"] = "cpp" + +# Backend to use for CUDA codegen either +# "triton", "halide" (experimental) or "pallas" (experimental) +cuda_backend: Literal["triton", "halide", "pallas"] = "triton" + +# Backend to use for XPU codegen either "triton" +xpu_backend: Literal["triton"] = "triton" + + +class halide: + # Base halide target to use for CPU devices + cpu_target = "host" + + # Base halide target to use for CUDA devices + gpu_target = "host-cuda" + + # Halide autoscheduler to use, choices are: + # "Anderson2021" (gpu-only), "Li2018", "Adams2019" (cpu-only), or "Mullapudi2016" (cpu-only) + scheduler_cuda: Literal["Anderson2021", "Li2018", "Adams2019", "Mullapudi2016"] = ( + "Anderson2021" + ) + scheduler_cpu: Literal["Anderson2021", "Li2018", "Adams2019", "Mullapudi2016"] = ( + "Adams2019" + ) + + # Controls `no_asserts` flag passed to Halide target (warning: can false positive) + asserts = False + + # Controls `debug` flag passed to Halide target + debug = False + + # Enable (or fallback on) scan kernels such as cumsum + # Halide autoschedulers struggle with these kernels + scan_kernels = False + + +# create a directory containing lots of debug information +class trace: + # master switch for all debugging flags below + enabled = os.environ.get("TORCH_COMPILE_DEBUG", "0") == "1" + + # save real tensors + save_real_tensors = os.environ.get("TORCH_COMPILE_DEBUG_SAVE_REAL", "0") == "1" + + # Save debug information to a 
temporary directory + # If not specified, a temp directory will be created by system + debug_dir: Optional[str] = None + + # Save python logger call >=logging.DEBUG + debug_log = False + + # Save python logger call >=logging.INFO + info_log = False + + # Save input FX graph (post decomps, pre optimization) + fx_graph = True + + # Save FX graph after transformations + fx_graph_transformed = True + + # Save TorchInductor IR before fusion pass + ir_pre_fusion = True + + # Save TorchInductor IR after fusion pass + ir_post_fusion = True + + # Copy generated code to trace dir + output_code = True + + # SVG figure showing post-fusion graph + graph_diagram = os.environ.get("INDUCTOR_POST_FUSION_SVG", "0") == "1" + + # SVG figure showing fx with fusion + draw_orig_fx_graph = os.environ.get("INDUCTOR_ORIG_FX_SVG", "0") == "1" + + # We draw our fx graphs with the "record" shape attribute by default. + # Sometimes, when the graph is very complex, we may hit dot errors like below: + # "flat edge between adjacent nodes one of which has a record shape - + # replace records with HTML-like labels" + # and thus fail to generate a graph. So, let's give the user an option + # to specify the shape attribute for the dot graph. For example, passing + # INDUCTOR_DOT_GRAPH_SHAPE_SVG = "none" would let us generate HTML-like labels + # to workaround the above failure. 
+ dot_graph_shape = os.environ.get("INDUCTOR_DOT_GRAPH_SHAPE_SVG", None) + + # If not None, this is the URL that saves the SVG files of the input/output + # graph of each pass that changed the graph + # The nodes that are being transformed in each pass will be colored in yellow + # URL only supports local directory for now + log_url_for_graph_xform = os.environ.get("INDUCTOR_LOG_URL_FOR_GRAPH_XFORM", None) + + # Store cProfile (see snakeviz to view) + compile_profile = False + + # Upload the .tar.gz file + # Needs to be overridden based on specific environment needs + upload_tar: Optional[Callable[[str], None]] = None + + log_autotuning_results = os.environ.get("LOG_AUTOTUNE_RESULTS", "0") == "1" + + # Save mapping info from inductor generated kernel to post_grad/pre_grad fx nodes + # Levels: + # 0 - disabled (default) + # 1 - normal + # 2 - basic + # Backward compatibility: + # If TORCH_COMPILE_DEBUG=1, level is set to at least 1. + # If INDUCTOR_PROVENANCE is set, use its integer value. 
+ provenance_tracking_level: int = int( + os.environ.get( + "INDUCTOR_PROVENANCE", os.environ.get("TORCH_COMPILE_DEBUG", "0") + ) + ) + + +_save_config_ignore: list[str] = [ + # workaround: "Can't pickle " + "trace.upload_tar", + "joint_custom_pre_pass", + "joint_custom_post_pass", + "pre_grad_custom_pass", + "aot_inductor.repro_level", + "aot_inductor.dump_aoti_minifier", + "post_grad_custom_pre_pass", + "post_grad_custom_post_pass", + "_fuse_ddp_communication_passes", + "_pre_fusion_custom_pass", +] + +_cache_config_ignore_prefix: list[str] = [ + # trace functions are not relevant to config caching + "trace", + # uses absolute path + "cuda.cutlass_dir", + # not relevant + "worker_start_method", + "compile_threads", + # see CustomGraphPass; these are handled specially + "post_grad_custom_post_pass", + "post_grad_custom_pre_pass", + "joint_custom_pre_pass", + "joint_custom_post_pass", + "_fuse_ddp_communication_passes", + "_pre_fusion_custom_pass", + # tests assume that changes here don't invalidate cache + "always_complex_memory_overlap_TESTING_ONLY", + # cache related options are not relevant to cache results + "fx_graph_cache", + "fx_graph_remote_cache", + "autotune_local_cache", + "autotune_remote_cache", +] + +# External callable for matmul tuning candidates +external_matmul: list[Callable[[torch.Tensor, torch.Tensor, torch.Tensor], None]] = [] + +write_are_deterministic_algorithms_enabled = ( + os.getenv("TORCHINDUCTOR_WRITE_ARE_DETERMINISTIC_ALGORITHMS_ENABLED", "1") == "1" +) + + +class lookup_table: + # Lookup table for template config overrides + table: Optional[dict[str, list[dict[str, Any]]]] = None + + # Enable template src_hash checking in lookup table to prevent using stale configs. + # If True, configs with 'template_hash' field will be compared against the template's + # src_hash at runtime and filtered out if they don't match. If False, no + # hash checking is performed. 
+ check_src_hash: bool = True + + +class test_configs: + force_extern_kernel_in_multi_template: bool = False + + max_mm_configs: Optional[int] = None + + runtime_triton_dtype_assert = False + runtime_triton_shape_assert = False + static_cpp_dtype_assert = False + + # regex to control the set of considered autotuning + # choices (aka configs) by name and / or description + autotune_choice_name_regex: Optional[str] = None + autotune_choice_desc_regex: Optional[str] = None + + graphsafe_rng_func_ignores_fallback_random = False + + track_memory_lifecycle: Optional[Literal["assert", "log"]] = None + + # If set to True, AOTI-generated CMakelists.txt will still use libtorch + # for unit testing + use_libtorch = False + + # Assume bucketing reduces latency (mostly for testing) + assume_bucketing_reduces_latency: bool = True + + # A test config to ease the test for perf of reduction config filtering + force_filter_reduction_configs = ( + os.getenv("TORCHINDUCTOR_FORCE_FILTER_REDUCTION_CONFIGS") == "1" + ) + + # a testing config to distort benchmarking result + # - empty string to disable + # - "inverse" to inverse the numbers + # - "random" return a random value + distort_benchmarking_result = os.getenv( + "TORCHINDUCTOR_DISTORT_BENCHMARKING_RESULT", "" + ) + + bisect_pre_grad_graph = False + bisect_keep_custom_backend_for_inductor = False + + +if TYPE_CHECKING: + from torch.utils._config_typing import * # noqa: F401, F403 + + +# adds patch, save_config, etc +install_config_module(sys.modules[__name__]) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/config_comms.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/config_comms.py new file mode 100644 index 0000000000000000000000000000000000000000..31f38b867dd5e80fb26f5c0b09144894e66e1dd2 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/config_comms.py @@ -0,0 +1,71 @@ +import os +import sys +from typing import 
Optional + +from torch.utils._config_module import install_config_module + + +# Whether to use c10d._time_estimator for collectives runtime estimations. +runtime_estimations_use_nccl_lib_estimations: bool = False + +# Config to enable sync of runtime estimations across distributed ranks, +# To prevent passes using this runtime estimations to make different +# decisions on different distributed ranks. +runtime_estimations_align_across_all_distributed_ranks: bool = False + +reorder_iterative_debug_memory_recompute: bool = False +reorder_iterative_debug_limit_to_reorder: Optional[int] = ( + None + # pyrefly: ignore[unbound-name] + if (env_str := os.getenv("PYTORCH_REORDER_COLLECTIVES_LIMIT")) is None + else int(env_str) +) +sink_waits_iterative_debug_limit_to_sink: Optional[int] = ( + # pyrefly: ignore[unbound-name] + None if (env_str := os.getenv("PYTORCH_SINK_WAITS_LIMIT")) is None else int(env_str) +) + + +# Should be used with config.runtime_estimations_mms_benchmark = True +reorder_iterative_use_runtime_estimations: bool = False +sink_iterative_use_runtime_estimations: bool = False + +# Broadcast runtime estimations doing real Collective operation between all ranks. +# If non-deterministic runtime estimations are used this must be used to make +# all ranks to do identical decisions and prevent global Collectives reordering, +# (that will result un NCCL hangs) +reorder_for_compute_comm_overlap_broadcast_runtime_estimations: bool = False + +# Block of Ratios to workaround imperfection of current runtime estimations +# for collectives and compute for different scenarios. 
+# Multiplier of collectives estimated durations +reorder_sink_runtime_estimations_comm_mult: float = 2.0 +# Multiplier of compute estimated durations +reorder_sink_runtime_estimations_non_comm_mult: float = 1.0 +# The reordering will stop to reorder +# when overlap_comp >= (1 + extra_overlap_ratio) * comm_time +# Allows to configure more aggressive overlap +reorder_iterative_extra_comm_comp_overlap: float = 0.5 +# The sink waits reordering will stop to reorder +# when overlap_comp >= (1 + extra_overlap_ratio) * comm_time +# Allows to configure more aggressive sink waits +sink_iterative_extra_comm_comp_overlap: float = 0.5 + +# Allow reorder iterative pass to increase peak memory +# up to peak_memory_before_pass * (1 + budget) +reorder_iterative_peak_memory_budget: float = 0.2 +# Allow sink waits iterative pass to increase peak memory +# up to peak_memory_before_pass * (1 + budget) +sink_iterative_peak_memory_budget: float = 0.2 + +# Experimental unsafe configuration that allows changing relative collectives order. 
+# Must be used with runtime_estimations_align_across_all_distributed_ranks = True +reorder_iterative_unsafe_collectives_reorder: bool = True +sink_waits_iterative_unsafe_collectives_reorder: bool = True + +# Allow group and move other collectives during reordering +reorder_iterative_group_with_collectives: bool = False +sink_waits_iterative_swap_with_collectives: bool = False + +# adds patch, save_config, etc +install_config_module(sys.modules[__name__]) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/constant_folding.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/constant_folding.py new file mode 100644 index 0000000000000000000000000000000000000000..1e473a7826ce09dc529e47274ca3daa4113fa1d9 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/constant_folding.py @@ -0,0 +1,416 @@ +import collections +from collections.abc import Callable +from typing import Any, Optional + +import torch +import torch.utils._pytree as pytree +from torch._inductor.freezing_utils import maybe_set_is_frozen_param +from torch.utils._ordered_set import OrderedSet + + +aten = torch.ops.aten + +# We would like to split modules into two subgraphs for runtime weight updates to work correctly. 
+# The use case and more information could be found at: +# https://docs.google.com/document/d/1inZC-8KarJ6gKB7G9egmYLx1V_dKX_apxon0w4zPC0Q/edit?usp=sharing +META_TAG = "MODULE_TYPE" +MODULE_TAG = "_MAIN_MODULE" +CONST_MODULE_TAG = "_CONST_MODULE" + +_dont_constant_fold: list[torch.fx.node.Target] = [] + + +def add_dont_constant_fold(op: torch.fx.node.Target) -> None: + global _dont_constant_fold + _dont_constant_fold.append(op) + + +def clear_dont_constant_fold() -> None: + global _dont_constant_fold + _dont_constant_fold.clear() + + +def replace_node_with_constant( + gm: torch.fx.GraphModule, + node: torch.fx.Node, + constant: Optional[torch.Tensor] = None, + name: Optional[str] = None, +) -> None: + g = gm.graph + + if name: + qualname = name + else: + if not hasattr(gm, "_frozen_param_count"): + gm._frozen_param_count = 0 # type: ignore[assignment] + i = gm._frozen_param_count + + while True: + qualname = f"_frozen_param{i}" + if not hasattr(gm, qualname): + break + i += 1 # type: ignore[assignment, operator] + + gm._frozen_param_count = i + 1 # type: ignore[assignment, operator] + + with g.inserting_before(node): + if constant is not None: + new_input_node = g.create_node("get_attr", qualname, (), {}) + else: + # this is the case for lifted constants + new_input_node = g.create_node("placeholder", qualname, (), {}) + node.replace_all_uses_with(new_input_node) + new_input_node.meta.update(node.meta) + g.erase_node(node) + new_input_node.name = node.name + + if constant is not None: + # needed to suppress `does not reference an nn.Module, nn.Parameter, or buffer` warning + gm.register_buffer(qualname, constant) + setattr(gm, qualname, constant) + # mark any constants created during freezing + maybe_set_is_frozen_param(constant) + + +def is_const_source( + node: torch.fx.Node, lifted_constant_names: Optional[list[str]] +) -> bool: + return node.op == "get_attr" or node.name in (lifted_constant_names or ()) + + +class ConstantFolder(torch.fx.Interpreter): + def 
__init__( + self, + gm: torch.fx.GraphModule, + skip_constructors: bool = False, + lifted_constant_names: Optional[list[str]] = None, + skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]] = None, + ) -> None: + super().__init__(gm) + self.node_replacements: dict[torch.fx.Node, Any] = {} + self.replaced_uses: dict[torch.fx.Node, int] = collections.Counter() + self.unknown_value = object() + self.skip_constructors: bool = skip_constructors + + # overwrite this to deallocate env values if their only remaining use + # is the output + self.user_to_last_uses = self.node_to_last_non_output_use() + self.lifted_constant_names = lifted_constant_names + self.deferred_value = object() + self.skip_folding_node_fn = skip_folding_node_fn + + def _support_dynamic_shape(self) -> bool: + # ConstantFolder not support dynamic shape now + return False + + def _deduce_value(self, node: torch.fx.Node) -> Any: + if self.lifted_constant_names is None: + return super().run_node(node) + # if lifted_constant_names is passed in, no concrete value is available + # so we just check if all inputs have values + if self.skip_folding_node_fn is not None and self.skip_folding_node_fn(node): + return self.unknown_value + flattened_node_inps = pytree.arg_tree_leaves(*node.args, **node.kwargs) + for inp in flattened_node_inps: + if ( + isinstance(inp, torch.fx.Node) + and inp.name not in (self.lifted_constant_names or ()) + and self.env[inp] != self.deferred_value + ): + return self.unknown_value + return self.deferred_value + + def is_impure(self, node: torch.fx.node.Node) -> bool: + def is_woq_int8_pattern(node: torch.fx.node.Node) -> bool: + return ( + node.target is torch.ops.prims.convert_element_type.default # type: ignore[return-value] + and isinstance(node.args[0], torch.fx.Node) + and "val" in node.args[0].meta + and node.args[0].meta["val"].dtype == torch.int8 # type: ignore[union-attr] + and node.args[1] == torch.bfloat16 + ) + + if ( + is_woq_int8_pattern(node) + or ( + 
node.target is torch.ops.aten.permute.default + and len(node.users) == 1 + and is_woq_int8_pattern(next(iter(node.users))) + ) + ) and is_const_source( + node.args[0], # type: ignore[arg-type] + self.lifted_constant_names, + ): + # Case 1: int8_weight -> dq -> bf16_weight + # Case 2: int8_weight -> permute -> dq -> bf16_weight + return True + + quant_registered = ( + getattr(torch.ops.quantized_decomposed, "dequantize_per_channel", None) + is not None + ) + if quant_registered and node.target in [ + torch.ops.quantized_decomposed.dequantize_per_channel.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor, + torch.ops.quantized_decomposed.convert_element_type.no_fuse, + ]: + # For the pattern fp32_weight -> q -> dq + # We only folding fp32_weight -> q + # int8_weight and leave dq in graph to be fused + return True + + if node.target in _dont_constant_fold: + return True + return False + + def node_to_last_non_output_use(self) -> dict[torch.fx.Node, list[torch.fx.Node]]: + last_non_output_use = collections.defaultdict(list) + seen_uses = OrderedSet[torch.fx.Node]() + output_node = next(iter(reversed(self.module.graph.nodes))) # type: ignore[arg-type, union-attr] + + for node in reversed(self.module.graph.nodes): # type: ignore[arg-type, union-attr] + if node.target == "output": + continue + + def add_use(inp: torch.fx.Node) -> None: + if inp in seen_uses: + return + + seen_uses.add(inp) + last_non_output_use[node].append(inp) + + # In-place is fine since we don't mutate + pytree.tree_map_only_(torch.fx.Node, add_use, (node.args, node.kwargs)) + + # if this node is only used in output, we want to gc it right away + if len(node.users) == 1 and output_node in node.users: + last_non_output_use[node].append(node) + + return last_non_output_use + + def run_node(self, node: torch.fx.Node) -> Any: + if node.target == "output": + # because we remove nodes from env on last non output use, + # 
re-define them now or we'll get error in interpreter + def set_env(arg: torch.fx.Node) -> None: + self.env[arg] = self.unknown_value + + # In-place is fine since we don't mutate + pytree.tree_map_only_(torch.fx.Node, set_env, node.args) + return super().run_node(node) + + args, kwargs = self.fetch_args_kwargs_from_env(node) + flattened_inputs = pytree.arg_tree_leaves(*args, **kwargs) + + # We need to do this weird thing because in cases where flattened_inputs + # contains a ScriptObject, equality checking results in a type error if + # the types are different. + if any( + type(self.unknown_value) is type(input_) and self.unknown_value == input_ + for input_ in flattened_inputs + ): + return self.unknown_value + + # TODO - fix errors with this + if ( + node.op == "call_function" + and node.target is aten._efficientzerotensor.default + ): + return self.unknown_value + + # TODO - constant folding triton kernel returns the inputs -- fix this + if ( + node.op == "call_function" + and node.name == "triton_kernel_wrapper_functional_proxy" + ): + return self.unknown_value + + # skip constructors, since inductor generates optimal code for them already + # and turning into tensor would result in an additional global memory read + # TODO - more complicated strategy + if ( + self.skip_constructors + and not is_const_source(node, self.lifted_constant_names) + and not any(isinstance(e, torch.Tensor) for e in flattened_inputs) + ): + return self.unknown_value + + # All mutations should either be removed or on inputs which we did not make constant + if ( + isinstance(node.target, torch._ops.OpOverload) + and torch.Tag.nondeterministic_seeded in node.target.tags + ): + return self.unknown_value + + if node.op == "call_function" and isinstance( + node.target, torch._ops.HigherOrderOperator + ): + return self.unknown_value + + out = self._deduce_value(node) + + if isinstance(out, torch._C.ScriptObject): + return out + + if out == self.unknown_value: + return self.unknown_value + + if 
not is_const_source(node, self.lifted_constant_names) and ( + isinstance(out, torch.Tensor) or out == self.deferred_value + ): + if out != self.deferred_value and out.device.type == "meta": + return out + + if not self.insertable_tensor_check(out): + return out + + if self.is_impure(node): + return self.unknown_value + + self.add_node_replacement(node, out) + + flattened_node_inps = pytree.arg_tree_leaves(*node.args, **node.kwargs) + + for n in flattened_node_inps: + if not isinstance(n, torch.fx.Node): + continue + + self.replaced_uses[n] += 1 + + for to_delete in self.user_to_last_uses.get(node, []): + if self.replaced_uses[to_delete] == len(to_delete.users): + self.node_replacements.pop(to_delete, None) + + return out + + def insertable_tensor_check(self, tensor: torch.Tensor) -> bool: + return True + + def add_node_replacement(self, node: torch.fx.Node, tensor: torch.Tensor) -> None: + self.node_replacements[node] = tensor + + def run(self) -> Any: # type: ignore[override] + env: dict[torch.fx.Node, Any] = {} + self.insert_placerholder_values(env) + return super().run(initial_env=env) + + def insert_placerholder_values(self, env: dict[torch.fx.Node, Any]) -> None: + for n in self.module.graph.find_nodes(op="placeholder"): # type: ignore[operator, union-attr] + env[n] = self.unknown_value # type: ignore[assignment] + if self.lifted_constant_names is None: + return + for n in self.module.graph.nodes: # type: ignore[union-attr] + if n.name in (self.lifted_constant_names or ()): + env[n] = self.deferred_value + + +def constant_fold( + gm: torch.fx.GraphModule, + constraint_fn: Optional[Callable[[torch.fx.Node], bool]] = None, +) -> None: + with torch.utils._python_dispatch._disable_current_modes(): + cf = ConstantFolder(gm, skip_constructors=True) + cf.run() + + for node, constant in cf.node_replacements.items(): + if constraint_fn is not None and not constraint_fn(node): + continue + replace_node_with_constant(gm, node, constant) + + erased_params = [] + for node 
in gm.graph.find_nodes(op="get_attr"): + if len(node.users) == 0: + if hasattr(gm, node.target): + delattr(gm, node.target) + erased_params.append(node) + + for node in erased_params: + gm.graph.erase_node(node) + + gm.graph.eliminate_dead_code() + gm.graph.lint() + gm.recompile() + + +def constant_graph_tag( + gm: torch.fx.GraphModule, + skip_constructors: bool = True, + lifted_constant_names: Optional[list[str]] = None, + skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]] = None, +) -> None: + with torch.utils._python_dispatch._disable_current_modes(): + cf = ConstantFolder( + gm, + skip_constructors=skip_constructors, + lifted_constant_names=lifted_constant_names, + skip_folding_node_fn=skip_folding_node_fn, + ) + cf.run() + + for node in gm.graph.nodes: + if skip_folding_node_fn is not None and skip_folding_node_fn(node): + node.meta[META_TAG] = MODULE_TAG + continue + if ( + is_const_source(node, lifted_constant_names) + or node in cf.node_replacements + or node in cf.replaced_uses + ): + node.meta[META_TAG] = CONST_MODULE_TAG + else: + node.meta[META_TAG] = MODULE_TAG + + +def run_and_get_constant_graph( + gm: torch.fx.GraphModule, + skip_constructors: bool = True, + lifted_constant_names: Optional[list[str]] = None, + skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]] = None, +) -> torch.fx.GraphModule: + """ + Construct a GraphModule which corresponds to the part which could be + constant folded in provided gm. + """ + + constant_graph_tag( + gm, skip_constructors, lifted_constant_names, skip_folding_node_fn + ) + + def untag(node: torch.fx.Node) -> bool: + used_to_fold = False + for u in node.users: + if u.meta[META_TAG] == CONST_MODULE_TAG: + used_to_fold = True + break + if not used_to_fold: + node.meta[META_TAG] = MODULE_TAG + return used_to_fold + + # We rewrite the tags, if it's a constant being directly consumed, without + # any folding opportunity, we keep it in main gm. 
+ for node in gm.graph.nodes: + if node.op == "get_attr" or (node.name in (lifted_constant_names or ())): + untag(node) + + new_graph = torch.fx.Graph() + + node_remapping: dict[torch.fx.Node, torch.fx.Node] = {} + output_nodes = [] + for node in gm.graph.nodes: + if node.meta[META_TAG] == MODULE_TAG: + continue + + new_node = new_graph.node_copy(node, lambda x: node_remapping[x]) + node_remapping[node] = new_node + + for user in node.users: + if user.meta[META_TAG] == MODULE_TAG: + output_nodes.append(new_node) + break + + new_graph.output(tuple(output_nodes)) + new_graph.lint() + new_gm = torch.fx.GraphModule(gm, new_graph) + + return new_gm diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/cpp_builder.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/cpp_builder.py new file mode 100644 index 0000000000000000000000000000000000000000..6a6b7d15ae3eabc0a378feb4ccd2f232d003a275 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/cpp_builder.py @@ -0,0 +1,2355 @@ +# This CPP builder is designed to support both Windows and Linux OS. 
+# The design document please check this RFC: https://github.com/pytorch/pytorch/issues/124245 + +import copy +import ctypes +import errno +import functools +import json +import locale +import logging +import os +import platform +import re +import shlex +import shutil +import subprocess +import sys +import sysconfig +import tempfile +import textwrap +import warnings +from collections.abc import Sequence +from ctypes import cdll, wintypes +from ctypes.util import find_library +from pathlib import Path +from typing import Any, Optional, Union + +import torch +from torch._dynamo.utils import dynamo_timed +from torch._inductor import config, exc +from torch._inductor.cpu_vec_isa import invalid_vec_isa, VecISA +from torch._inductor.runtime.runtime_utils import cache_dir +from torch.torch_version import TorchVersion + + +if config.is_fbcode(): + from triton.fb.build import _run_build_command, build_paths + + from torch._inductor.fb.utils import ( + log_global_cache_errors, + log_global_cache_stats, + log_global_cache_vals, + use_global_cache, + ) +else: + + def log_global_cache_errors(*args: Any, **kwargs: Any) -> None: # type: ignore[misc] + pass + + def log_global_cache_stats(*args: Any, **kwargs: Any) -> None: # type: ignore[misc] + pass + + def log_global_cache_vals(*args: Any, **kwargs: Any) -> None: # type: ignore[misc] + pass + + def use_global_cache() -> bool: # type: ignore[misc] + return False + + +# Windows need setup a temp dir to store .obj files. 
+_BUILD_TEMP_DIR = "CxxBuild" +_HERE = os.path.abspath(__file__) +_TORCH_PATH = os.path.dirname(os.path.dirname(_HERE)) +_LINKER_SCRIPT = os.path.join(_TORCH_PATH, "_inductor/script.ld") + +# initialize variables for compilation +_IS_LINUX = sys.platform.startswith("linux") +_IS_MACOS = sys.platform.startswith("darwin") +_IS_WINDOWS = sys.platform == "win32" + +MINGW_GXX = "x86_64-w64-mingw32-g++" + +SUBPROCESS_DECODE_ARGS = (locale.getpreferredencoding(),) if _IS_WINDOWS else () + +log = logging.getLogger(__name__) + + +# =============================== toolchain =============================== +@functools.lru_cache(1) +def cpp_compiler_search(search: str) -> str: + from torch._inductor.codecache import get_lock_dir, LOCK_TIMEOUT + + for cxx in search: + try: + if cxx is None: + # gxx package is only available for Linux + # according to https://anaconda.org/conda-forge/gxx/ + if sys.platform != "linux": + continue + # Do not install GXX by default + if not os.getenv("TORCH_INDUCTOR_INSTALL_GXX"): + continue + from torch.utils._filelock import FileLock + + lock_dir = get_lock_dir() + lock = FileLock( + os.path.join(lock_dir, "g++.lock"), timeout=LOCK_TIMEOUT + ) + with lock: + cxx = install_gcc_via_conda() + subprocess.check_output([cxx, "--version"]) + return cxx + except (subprocess.SubprocessError, FileNotFoundError, ImportError): + continue + raise exc.InvalidCxxCompiler + + +def install_gcc_via_conda() -> str: + """On older systems, this is a quick way to get a modern compiler""" + prefix = os.path.join(cache_dir(), "gcc") + cxx_path = os.path.join(prefix, "bin", "g++") + if not os.path.exists(cxx_path): + log.info("Downloading GCC via conda") + conda = os.environ.get("CONDA_EXE", "conda") + if conda is None: + conda = shutil.which("conda") + if conda is not None: + subprocess.check_call( + [ + conda, + "create", + f"--prefix={prefix}", + "--channel=conda-forge", + "--quiet", + "-y", + "python=3.8", + "gxx", + ], + stdout=subprocess.PIPE, + ) + return cxx_path 
+ + +@functools.cache +def check_compiler_exist_windows(compiler: str) -> None: + """ + Check if compiler is ready, in case end user not activate MSVC environment. + """ + try: + subprocess.check_output([compiler, "/help"], stderr=subprocess.STDOUT) + except FileNotFoundError as exc: + raise RuntimeError(f"Compiler: {compiler} is not found.") from exc + except subprocess.SubprocessError: + # Expected that some compiler(clang, clang++) is exist, but they not support `/help` args. + pass + + +class WinPeFileVersionInfo: + def __init__(self, file_path: str) -> None: + self.file_path = file_path + self.version_dll = ctypes.WinDLL("version.dll") # type: ignore[attr-defined] + self._setup_functions() + self._get_version_info() + + def _setup_functions(self) -> None: + self.version_dll.GetFileVersionInfoSizeW.argtypes = [ + wintypes.LPCWSTR, + wintypes.LPDWORD, + ] + self.version_dll.GetFileVersionInfoSizeW.restype = wintypes.DWORD + + self.version_dll.GetFileVersionInfoW.argtypes = [ + wintypes.LPCWSTR, + wintypes.DWORD, + wintypes.DWORD, + wintypes.LPVOID, + ] + self.version_dll.GetFileVersionInfoW.restype = wintypes.BOOL + + self.version_dll.VerQueryValueW.argtypes = [ + wintypes.LPCVOID, + wintypes.LPCWSTR, + ctypes.POINTER(ctypes.c_void_p), + ctypes.POINTER(wintypes.UINT), + ] + self.version_dll.VerQueryValueW.restype = wintypes.BOOL + + def _get_version_info(self) -> None: + dummy = wintypes.DWORD() + size = self.version_dll.GetFileVersionInfoSizeW( + self.file_path, ctypes.byref(dummy) + ) + + if size == 0: + raise RuntimeError(f"Can't get version info size of {self.file_path}.") + + self.version_info = ctypes.create_string_buffer(size) + success = self.version_dll.GetFileVersionInfoW( + self.file_path, 0, size, self.version_info + ) + + if not success: + raise RuntimeError(f"Can't get version info of {self.file_path}.") + + def get_language_id(self) -> int: + lp_buffer = ctypes.c_void_p() + u_len = wintypes.UINT() + + success = self.version_dll.VerQueryValueW( + 
self.version_info, + r"\VarFileInfo\Translation", + ctypes.byref(lp_buffer), + ctypes.byref(u_len), + ) + + if not success or u_len.value == 0: + return 0 + + translations = [] + lang_id: int = 0 + if lp_buffer.value is not None: + for i in range(u_len.value // 4): + offset = i * 4 + data = ctypes.string_at(lp_buffer.value + offset, 4) + lang_id = int.from_bytes(data[:2], "little") + code_page = int.from_bytes(data[2:4], "little") + translations.append((lang_id, code_page)) + else: + # Handle the case where lp_buffer.value is None + print("Buffer is None") + + return lang_id + + +@functools.cache +def check_msvc_cl_language_id(compiler: str) -> None: + """ + Torch.compile() is only work on MSVC with English language pack well. + Check MSVC's language pack: https://github.com/pytorch/pytorch/issues/157673#issuecomment-3051682766 + """ + + def get_msvc_cl_path() -> tuple[bool, str]: + """ + Finds the path to cl.exe using vswhere.exe. + """ + vswhere_path = os.path.join( + os.environ.get("ProgramFiles(x86)", "C:\\Program Files (x86)"), + "Microsoft Visual Studio", + "Installer", + "vswhere.exe", + ) + if not os.path.exists(vswhere_path): + vswhere_path = os.path.join( + os.environ.get("ProgramFiles", "C:\\Program Files"), + "Microsoft Visual Studio", + "Installer", + "vswhere.exe", + ) + if not os.path.exists(vswhere_path): + return False, "" # vswhere.exe not found + + try: + # Get the Visual Studio installation path + cmd = [ + vswhere_path, + "-latest", + "-prerelease", + "-products", + "*", + "-requires", + "Microsoft.VisualStudio.Component.VC.Tools.x86.x64", + "-property", + "installationPath", + ] + vs_install_path = subprocess.check_output( + cmd, text=True, encoding="utf-8" + ).strip() + + if not vs_install_path: + return False, "" + + # Find the latest MSVC toolset version within the installation + msvc_tools_path = os.path.join(vs_install_path, "VC", "Tools", "MSVC") + if not os.path.exists(msvc_tools_path): + return False, "" + + # Get the latest toolset 
version directory + toolset_versions = [ + d + for d in os.listdir(msvc_tools_path) + if os.path.isdir(os.path.join(msvc_tools_path, d)) + ] + if not toolset_versions: + return False, "" + latest_toolset_version = sorted(toolset_versions, reverse=True)[0] + + # Construct the full cl.exe path + cl_path = os.path.join( + msvc_tools_path, + latest_toolset_version, + "bin", + "HostX64", + "x64", + "cl.exe", + ) + if os.path.exists(cl_path): + return True, cl_path + else: + # Fallback for older versions or different architectures if needed + cl_path = os.path.join( + msvc_tools_path, + latest_toolset_version, + "bin", + "HostX86", + "x86", + "cl.exe", + ) + if os.path.exists(cl_path): + return True, cl_path + + except (subprocess.CalledProcessError, FileNotFoundError): + return False, "" + + return False, "" + + if not _is_msvc_cl(compiler): + return + + if os.path.exists(compiler): + # Passed compiler with path. + cl_exe_path = compiler + else: + b_ret, cl_exe_path = get_msvc_cl_path() + if b_ret is False: + return + + version_info = WinPeFileVersionInfo(cl_exe_path) + lang_id = version_info.get_language_id() + if lang_id != 1033: + # MSVC English language id is 0x0409, and the DEC value is 1033. + raise RuntimeError( + "Torch.compile() is only support MSVC with English language pack," + "Please reinstall its language pack to English." + ) + + +@functools.cache +def check_mingw_win32_flavor(compiler: str) -> str: + """ + Check if MinGW `compiler` exists and return it's flavor (win32 or posix). 
+ """ + try: + out = subprocess.check_output( + [compiler, "-v"], stderr=subprocess.STDOUT, text=True + ) + except FileNotFoundError as e: + raise RuntimeError(f"Compiler: {compiler} is not found.") from e + except Exception as e: + raise RuntimeError(f"Failed to run {compiler} -v") from e + + flavor: str | None = None + for line in out.splitlines(): + if "Thread model" in line: + flavor = line.split(":", 1)[-1].strip().lower() + + if flavor is None: + raise RuntimeError( + f"Cannot determine the flavor of {compiler} (win32 or posix). No Thread model found in {compiler} -v" + ) + + if flavor not in ("win32", "posix"): + raise RuntimeError( + f"Only win32 and pofix flavor of {compiler} is supported. The flavor is {flavor}" + ) + + return flavor + + +def get_cpp_compiler() -> str: + if ( + config.aot_inductor.cross_target_platform == "windows" + and sys.platform != "win32" + ): + # we're doing cross-compilation + compiler = MINGW_GXX + if not config.aot_inductor.package_cpp_only: + check_mingw_win32_flavor(compiler) + return compiler + + if _IS_WINDOWS: + compiler = os.environ.get("CXX", "cl") + compiler = normalize_path_separator(compiler) + check_compiler_exist_windows(compiler) + check_msvc_cl_language_id(compiler) + else: + if config.is_fbcode(): + return build_paths.cc + if isinstance(config.cpp.cxx, (list, tuple)): + search = tuple(config.cpp.cxx) + else: + search = (config.cpp.cxx,) + compiler = cpp_compiler_search(search) + return compiler + + +def get_ld_and_objcopy(use_relative_path: bool) -> tuple[str, str]: + if _IS_WINDOWS: + raise RuntimeError("Windows is not supported yet.") + else: + if config.is_fbcode(): + ld = build_paths.ld + objcopy = ( + build_paths.objcopy_fallback + if use_relative_path + else build_paths.objcopy + ) + else: + ld = "ld" + objcopy = "objcopy" + return ld, objcopy + + +def convert_cubin_to_obj( + cubin_file: str, + kernel_name: str, + ld: str, + objcopy: str, +) -> str: + obj_file = cubin_file + ".o" + # Convert .cubin to .o + 
cmd = f"{ld} -r -b binary -z noexecstack -o {obj_file} {cubin_file}" + subprocess.run(cmd.split(), capture_output=True, text=True, check=True) + # Rename .data to .rodata + cmd = f"{objcopy} --rename-section .data=.rodata,alloc,load,readonly,data,contents {obj_file}" + subprocess.run(cmd.split(), capture_output=True, text=True, check=True) + # By default objcopy will create *_start, *_size, *_end symbols using the full path + # Rename to use the unique kernel name + file_name = re.sub(r"[\W]", "_", cubin_file) + cmd = ( + objcopy + + f" --redefine-sym _binary_{file_name}_start=__{kernel_name}_start " + + f"--redefine-sym _binary_{file_name}_size=__{kernel_name}_size " + + f"--redefine-sym _binary_{file_name}_end=__{kernel_name}_end " + + obj_file + ) + subprocess.run(cmd.split(), capture_output=True, text=True, check=True) + return obj_file + + +@functools.cache +def _is_apple_clang(cpp_compiler: str) -> bool: + version_string = subprocess.check_output([cpp_compiler, "--version"]).decode("utf8") + return "Apple" in version_string.splitlines()[0] + + +@functools.cache +def _is_clang(cpp_compiler: str) -> bool: + # Mac OS apple clang maybe named as gcc, need check compiler info. + if sys.platform == "darwin": + return _is_apple_clang(cpp_compiler) + elif _IS_WINDOWS: + # clang suite have many compilers, and only clang-cl is supported. + if re.search(r"((clang$)|(clang\+\+$))", cpp_compiler): + raise RuntimeError( + "Please use clang-cl, due to torch.compile only support MSVC-like CLI (compiler flags syntax)." + ) + return bool(re.search(r"(clang-cl)", cpp_compiler)) + return bool(re.search(r"(clang|clang\+\+)", cpp_compiler)) + + +@functools.cache +def _is_gcc(cpp_compiler: str) -> bool: + # Since "clang++" ends with "g++", the regex match below would validate on it. 
+ if _is_clang(cpp_compiler): + return False + return bool(re.search(r"(gcc|g\+\+|gnu-c\+\+)", cpp_compiler)) + + +@functools.cache +def _is_msvc_cl(cpp_compiler: str) -> bool: + if not _IS_WINDOWS: + return False + + try: + output_msg = ( + subprocess.check_output([cpp_compiler, "/help"], stderr=subprocess.STDOUT) + .strip() + .decode(*SUBPROCESS_DECODE_ARGS) + ) + return "Microsoft" in output_msg.splitlines()[0] + except FileNotFoundError: + return False + + return False + + +@functools.cache +def _is_intel_compiler(cpp_compiler: str) -> bool: + def _check_minimal_version(compiler_version: TorchVersion) -> None: + """ + On Windows: early version icx has `-print-file-name` issue, and can't preload correctly for inductor. + """ + min_version = "2024.2.1" if _IS_WINDOWS else "0.0.0" + if compiler_version < TorchVersion(min_version): + raise RuntimeError( + f"Intel Compiler error: less than minimal version {min_version}." + ) + + try: + output_msg = ( + subprocess.check_output( + [cpp_compiler, "--version"], stderr=subprocess.DEVNULL + ) + .strip() + .decode(*SUBPROCESS_DECODE_ARGS) + ) + is_intel_compiler = "Intel" in output_msg.splitlines()[0] + if is_intel_compiler: + if _IS_WINDOWS: + if re.search(r"((icx$)|(icx-cc$))", cpp_compiler): + raise RuntimeError( + "Please use icx-cl, due to torch.compile only support MSVC-like CLI (compiler flags syntax)." + ) + + # Version check + icx_ver_search = re.search(r"(\d+[.]\d+[.]\d+[.]\d+)", output_msg) + if icx_ver_search is not None: + icx_ver = icx_ver_search.group(1) + _check_minimal_version(TorchVersion(icx_ver)) + + return is_intel_compiler + except FileNotFoundError: + return False + except subprocess.SubprocessError: + # --version args not support. 
+ return False + + return False + + +@functools.cache +def is_gcc() -> bool: + return _is_gcc(get_cpp_compiler()) + + +@functools.cache +def is_clang() -> bool: + return _is_clang(get_cpp_compiler()) + + +@functools.cache +def is_intel_compiler() -> bool: + return _is_intel_compiler(get_cpp_compiler()) + + +@functools.cache +def is_apple_clang() -> bool: + return _is_apple_clang(get_cpp_compiler()) + + +@functools.cache +def is_msvc_cl() -> bool: + return _is_msvc_cl(get_cpp_compiler()) + + +@functools.cache +def get_compiler_version_info(compiler: str) -> str: + env = os.environ.copy() + env["LC_ALL"] = "C" # Don't localize output + try: + version_string = subprocess.check_output( + [compiler, "-v"], stderr=subprocess.STDOUT, env=env + ).decode(*SUBPROCESS_DECODE_ARGS) + except Exception: + try: + version_string = subprocess.check_output( + [compiler, "--version"], stderr=subprocess.STDOUT, env=env + ).decode(*SUBPROCESS_DECODE_ARGS) + except Exception: + return "" + # Multiple lines to one line string. 
+ version_string = version_string.replace("\r", "_") + version_string = version_string.replace("\n", "_") + return version_string + + +# =============================== cpp builder =============================== +def _append_list(dest_list: list[str], src_list: list[str]) -> None: + dest_list.extend(copy.deepcopy(item) for item in src_list) + + +def _remove_duplication_in_list(orig_list: list[str]) -> list[str]: + new_list: list[str] = [] + for item in orig_list: + if item not in new_list: + new_list.append(item) + return new_list + + +def _create_if_dir_not_exist(path_dir: str) -> None: + if not os.path.exists(path_dir): + try: + Path(path_dir).mkdir(parents=True, exist_ok=True) + except OSError as exc: # Guard against race condition + if exc.errno != errno.EEXIST: + raise RuntimeError(f"Fail to create path {path_dir}") from exc + + +def _remove_dir(path_dir: str) -> None: + if os.path.exists(path_dir): + for root, dirs, files in os.walk(path_dir, topdown=False): + for name in files: + file_path = os.path.join(root, name) + os.remove(file_path) + for name in dirs: + dir_path = os.path.join(root, name) + os.rmdir(dir_path) + os.rmdir(path_dir) + + +def _run_compile_cmd(cmd_line: str, cwd: str) -> None: + cmd = shlex.split(cmd_line) + try: + subprocess.run( + cmd, cwd=cwd, check=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT + ) + except subprocess.CalledProcessError as e: + output = e.stdout.decode(*SUBPROCESS_DECODE_ARGS) + openmp_problem = "'omp.h' file not found" in output or "libomp" in output + if openmp_problem and sys.platform == "darwin": + instruction = ( + "\n\nOpenMP support not found. 
Please try one of the following solutions:\n" + "(1) Set the `CXX` environment variable to a compiler other than Apple clang++/g++ " + "that has builtin OpenMP support;\n" + "(2) install OpenMP via conda: `conda install llvm-openmp`;\n" + "(3) install libomp via brew: `brew install libomp`;\n" + "(4) manually setup OpenMP and set the `OMP_PREFIX` environment variable to point to a path" + " with `include/omp.h` under it." + ) + output += instruction + raise exc.CppCompileError(cmd, output) from e + + +def run_compile_cmd(cmd_line: str, cwd: str) -> None: + with dynamo_timed("compile_file"): + _run_compile_cmd(cmd_line, cwd) + + +def normalize_path_separator(orig_path: str) -> str: + if _IS_WINDOWS: + return orig_path.replace(os.sep, "/") + return orig_path + + +class BuildOptionsBase: + """ + This is the Base class for store cxx build options, as a template. + Actually, to build a cxx shared library. We just need to select a compiler + and maintains the suitable args. + """ + + def __init__( + self, + compiler: str = "", + definitions: Optional[list[str]] = None, + include_dirs: Optional[list[str]] = None, + cflags: Optional[list[str]] = None, + ldflags: Optional[list[str]] = None, + libraries_dirs: Optional[list[str]] = None, + libraries: Optional[list[str]] = None, + passthrough_args: Optional[list[str]] = None, + aot_mode: bool = False, + use_relative_path: bool = False, + compile_only: bool = False, + precompiling: bool = False, + preprocessing: bool = False, + ) -> None: + self._compiler = compiler + self._definitions: list[str] = definitions or [] + self._include_dirs: list[str] = include_dirs or [] + self._cflags: list[str] = cflags or [] + self._ldflags: list[str] = ldflags or [] + self._libraries_dirs: list[str] = libraries_dirs or [] + self._libraries: list[str] = libraries or [] + # Some args are hard to abstract to OS compatible, passthrough directly. 
+ self._passthrough_args: list[str] = passthrough_args or [] + + # Optionally, the path to a precompiled header which should be included on the + # build command line. + self.precompiled_header: Optional[str] = None + + self._aot_mode: bool = aot_mode + self._use_relative_path: bool = use_relative_path + self._compile_only: bool = compile_only + self._precompiling: bool = precompiling + self._preprocessing: bool = preprocessing + + def _process_compile_only_options(self) -> None: + if self._compile_only: + self._libraries_dirs = [] + self._libraries = [] + + def _remove_duplicate_options(self) -> None: + self._definitions = _remove_duplication_in_list(self._definitions) + self._include_dirs = _remove_duplication_in_list(self._include_dirs) + self._cflags = _remove_duplication_in_list(self._cflags) + self._ldflags = _remove_duplication_in_list(self._ldflags) + self._libraries_dirs = _remove_duplication_in_list(self._libraries_dirs) + self._libraries = _remove_duplication_in_list(self._libraries) + self._passthrough_args = _remove_duplication_in_list(self._passthrough_args) + + def _finalize_options(self) -> None: + self._process_compile_only_options() + self._remove_duplicate_options() + + def get_compiler(self) -> str: + return self._compiler + + def get_definitions(self) -> list[str]: + return self._definitions + + def get_include_dirs(self) -> list[str]: + return self._include_dirs + + def get_cflags(self) -> list[str]: + return self._cflags + + def get_ldflags(self) -> list[str]: + return self._ldflags + + def get_libraries_dirs(self) -> list[str]: + return self._libraries_dirs + + def get_libraries(self) -> list[str]: + return self._libraries + + def get_passthrough_args(self) -> list[str]: + return self._passthrough_args + + def get_aot_mode(self) -> bool: + return self._aot_mode + + def get_use_relative_path(self) -> bool: + return self._use_relative_path + + def get_compile_only(self) -> bool: + return self._compile_only + + def get_precompiling(self) -> 
bool: + return self._precompiling + + def get_preprocessing(self) -> bool: + return self._preprocessing + + def save_flags_to_json(self, file: str) -> None: + attrs = { + "compiler": self.get_compiler(), + "definitions": self.get_definitions(), + "include_dirs": self.get_include_dirs(), + "cflags": self.get_cflags(), + "ldflags": self.get_ldflags(), + "libraries_dirs": self.get_libraries_dirs(), + "libraries": self.get_libraries(), + "passthrough_args": self.get_passthrough_args(), + "aot_mode": self.get_aot_mode(), + "use_relative_path": self.get_use_relative_path(), + "compile_only": self.get_compile_only(), + } + + with open(file, "w") as f: + json.dump(attrs, f) + + +def _get_warning_all_cflag(warning_all: bool = True) -> list[str]: + if not _IS_WINDOWS: + return ["Wall"] if warning_all else [] + else: + return [] + + +def _get_cpp_std_cflag(std_num: str = "c++17") -> list[str]: + if _IS_WINDOWS: + """ + On Windows, only c++20 can support `std::enable_if_t`. + Ref: https://learn.microsoft.com/en-us/cpp/overview/cpp-conformance-improvements-2019?view=msvc-170#checking-for-abstract-class-types # noqa: B950 + Note: + Only setup c++20 for Windows inductor. I tried to upgrade all project to c++20, but it is failed: + https://github.com/pytorch/pytorch/pull/131504 + """ + std_num = "c++20" + return [f"std:{std_num}"] + else: + return [f"std={std_num}"] + + +def _get_os_related_cpp_cflags(cpp_compiler: str) -> list[str]: + if _IS_WINDOWS: + cflags = [ + "wd4819", + "wd4251", + "wd4244", + "wd4267", + "wd4275", + "wd4018", + "wd4190", + "wd4624", + "wd4067", + "wd4068", + "EHsc", + # For Intel oneAPI, ref: https://learn.microsoft.com/en-us/cpp/build/reference/zc-cplusplus?view=msvc-170 + "Zc:__cplusplus", + # Enable max compatible to msvc for oneAPI headers. 
+ # ref: https://github.com/pytorch/pytorch/blob/db38c44ad639e7ada3e9df2ba026a2cb5e40feb0/cmake/public/utils.cmake#L352-L358 # noqa: B950 + "permissive-", + ] + else: + cflags = ["Wno-unused-variable", "Wno-unknown-pragmas"] + if _is_clang(cpp_compiler): + ignored_optimization_argument = ( + "Werror=ignored-optimization-argument" + if config.aot_inductor.raise_error_on_ignored_optimization + else "Wno-ignored-optimization-argument" + ) + cflags.append(ignored_optimization_argument) + if _is_gcc(cpp_compiler): + # Issue all the warnings demanded by strict ISO C and ISO C++. + # Ref: https://github.com/pytorch/pytorch/issues/153180#issuecomment-2986676878 + cflags.append("pedantic") + return cflags + + +def _get_os_related_cpp_definitions(cpp_compiler: str) -> list[str]: + os_definitions: list[str] = [] + if _IS_WINDOWS: + # On Windows, we need disable min/max macro to avoid C2589 error, as PyTorch CMake: + # https://github.com/pytorch/pytorch/blob/9a41570199155eee92ebd28452a556075e34e1b4/CMakeLists.txt#L1118-L1119 + os_definitions.append("NOMINMAX") + return os_definitions + + +def _get_ffast_math_flags() -> list[str]: + if _IS_WINDOWS: + flags = [] + else: + # ffast-math is equivalent to these flags as in + # https://github.com/gcc-mirror/gcc/blob/4700ad1c78ccd7767f846802fca148b2ea9a1852/gcc/opts.cc#L3458-L3468 + # however gcc<13 sets the FTZ/DAZ flags for runtime on x86 even if we have + # -ffast-math -fno-unsafe-math-optimizations because the flags for runtime + # are added by linking in crtfastmath.o. This is done by the spec file which + # only does globbing for -ffast-math. 
+ flags = [ + "fno-trapping-math", + "funsafe-math-optimizations", + "ffinite-math-only", + "fno-signed-zeros", + "fno-math-errno", + ] + + flags.append("fno-finite-math-only") + if not config.cpp.enable_unsafe_math_opt_flag: + flags.append("fno-unsafe-math-optimizations") + flags.append(f"ffp-contract={config.cpp.enable_floating_point_contract_flag}") + + if is_gcc(): + flags.append("fexcess-precision=fast") + + return flags + + +def _get_inductor_debug_symbol_cflags() -> tuple[list[str], list[str]]: + """ + When we turn on generate debug symbol. + On Windows, it should create a [module_name].pdb file. It helps debug by WinDBG. + On Linux, it should create some debug sections in binary file. + """ + cflags: list[str] = [] + ldflags: list[str] = [] + + if _IS_WINDOWS: + cflags = ["ZI", "_DEBUG"] + ldflags = ["DEBUG", "ASSEMBLYDEBUG ", "OPT:REF", "OPT:ICF"] + else: + cflags.append("g") + + return cflags, ldflags + + +def _get_optimization_cflags( + cpp_compiler: str, min_optimize: bool = False +) -> tuple[list[str], list[str]]: + cflags: list[str] = [] + ldflags: list[str] = [] + + should_use_optimized_flags = not ( + config.aot_inductor.debug_compile + or os.environ.get("TORCHINDUCTOR_DEBUG_COMPILE", "0") == "1" + ) + should_add_debug_symbol_flags = ( + config.aot_inductor.debug_compile + or config.aot_inductor.debug_symbols + or os.environ.get("TORCHINDUCTOR_DEBUG_COMPILE", "0") == "1" + or os.environ.get("TORCHINDUCTOR_DEBUG_SYMBOL", "0") == "1" + ) + if should_use_optimized_flags: + if _IS_WINDOWS: + cflags += ["O1" if min_optimize else "O2"] + else: + cflags += [ + config.aot_inductor.compile_wrapper_opt_level if min_optimize else "O3", + "DNDEBUG", + ] + else: + if _IS_WINDOWS: + cflags += ["Od", "Ob0", "Oy-"] + else: + cflags += ["O0"] + + if should_add_debug_symbol_flags: + debug_cflags, debug_ldflags = _get_inductor_debug_symbol_cflags() + cflags += debug_cflags + ldflags += debug_ldflags + + cflags += _get_ffast_math_flags() + + if _IS_WINDOWS: + pass + 
else: + if sys.platform != "darwin": + # on macos, unknown argument: '-fno-tree-loop-vectorize' + if _is_gcc(cpp_compiler): + cflags.append("fno-tree-loop-vectorize") + # https://stackoverflow.com/questions/65966969/why-does-march-native-not-work-on-apple-m1 + # `-march=native` is unrecognized option on M1 + if not config.is_fbcode(): + if platform.machine() == "ppc64le": + cflags.append("mcpu=native") + elif platform.machine() == "riscv64": + cflags.append("march=rv64gc") + elif platform.machine() == "riscv32": + cflags.append("march=rv32gc") + else: + cflags.append("march=native") + + if config.aot_inductor.enable_lto and _is_clang(cpp_compiler): + cflags.append("flto=thin") + + return cflags, ldflags + + +def _get_shared_cflags(do_link: bool) -> list[str]: + if _IS_WINDOWS: + """ + MSVC `/MD` using python `ucrtbase.dll` lib as runtime. + https://learn.microsoft.com/en-us/cpp/c-runtime-library/crt-library-features?view=msvc-170 + """ + return ["DLL", "MD"] + if platform.system() == "Darwin" and "clang" in get_cpp_compiler(): + # This causes undefined symbols to behave the same as linux + return ["shared", "fPIC", "undefined dynamic_lookup"] + flags = [] + if do_link: + flags.append("shared") + + flags.append("fPIC") + return flags + + +def get_cpp_options( + cpp_compiler: str, + do_link: bool, + warning_all: bool = True, + extra_flags: Sequence[str] = (), + min_optimize: bool = False, +) -> tuple[list[str], list[str], list[str], list[str], list[str], list[str], list[str]]: + definitions: list[str] = [] + include_dirs: list[str] = [] + cflags: list[str] = [] + ldflags: list[str] = [] + libraries_dirs: list[str] = [] + libraries: list[str] = [] + passthrough_args: list[str] = [] + + opt_cflags, opt_ldflags = _get_optimization_cflags(cpp_compiler, min_optimize) + + cflags = ( + opt_cflags + + _get_shared_cflags(do_link) + + _get_warning_all_cflag(warning_all) + + _get_cpp_std_cflag() + + _get_os_related_cpp_cflags(cpp_compiler) + ) + + definitions += 
_get_os_related_cpp_definitions(cpp_compiler) + + if not _IS_WINDOWS and config.aot_inductor.enable_lto and _is_clang(cpp_compiler): + ldflags.append("fuse-ld=lld") + ldflags.append("flto=thin") + + passthrough_args.append(" ".join(extra_flags)) + + if config.aot_inductor.cross_target_platform == "windows": + passthrough_args.extend(["-static-libstdc++", "-static-libgcc"]) + if check_mingw_win32_flavor(MINGW_GXX) == "posix": + passthrough_args.append("-Wl,-Bstatic -lwinpthread -Wl,-Bdynamic") + + return ( + definitions, + include_dirs, + cflags, + ldflags + opt_ldflags, + libraries_dirs, + libraries, + passthrough_args, + ) + + +class CppOptions(BuildOptionsBase): + """ + This class is inherited from BuildOptionsBase, and as cxx build options. + This option need contains basic cxx build option, which contains: + 1. OS related args. + 2. Toolchains related args. + 3. Cxx standard related args. + Note: + 1. This Options is good for assist modules build, such as x86_isa_help. + """ + + def __init__( + self, + compile_only: bool = False, + warning_all: bool = True, + extra_flags: Sequence[str] = (), + use_relative_path: bool = False, + compiler: str = "", + min_optimize: bool = False, + precompiling: bool = False, + preprocessing: bool = False, + ) -> None: + super().__init__( + compile_only=compile_only, + use_relative_path=use_relative_path, + precompiling=precompiling, + preprocessing=preprocessing, + ) + self._compiler = compiler if compiler else get_cpp_compiler() + + ( + definitions, + include_dirs, + cflags, + ldflags, + libraries_dirs, + libraries, + passthrough_args, + ) = get_cpp_options( + cpp_compiler=self._compiler, + do_link=not (compile_only or precompiling or preprocessing), + extra_flags=extra_flags, + warning_all=warning_all, + min_optimize=min_optimize, + ) + + _append_list(self._definitions, definitions) + _append_list(self._include_dirs, include_dirs) + _append_list(self._cflags, cflags) + _append_list(self._ldflags, ldflags) + 
_append_list(self._libraries_dirs, libraries_dirs) + _append_list(self._libraries, libraries) + _append_list(self._passthrough_args, passthrough_args) + self._finalize_options() + + +def _get_torch_cpp_wrapper_definition() -> list[str]: + return ["TORCH_INDUCTOR_CPP_WRAPPER", "STANDALONE_TORCH_HEADER"] + + +def _use_custom_generated_macros() -> list[str]: + return [" C10_USING_CUSTOM_GENERATED_MACROS"] + + +def _use_fb_internal_macros() -> list[str]: + if not _IS_WINDOWS: + if config.is_fbcode(): + fb_internal_macros = [ + "C10_USE_GLOG", + "C10_USE_MINIMAL_GLOG", + "C10_DISABLE_TENSORIMPL_EXTENSIBILITY", + ] + return fb_internal_macros + else: + return [] + else: + return [] + + +def _setup_standard_sys_libs( + cpp_compiler: str, + aot_mode: bool, + use_relative_path: bool, +) -> tuple[list[str], list[str], list[str]]: + cflags: list[str] = [] + include_dirs: list[str] = [] + passthrough_args: list[str] = [] + if _IS_WINDOWS: + return cflags, include_dirs, passthrough_args + + if config.is_fbcode(): + # TODO(T203137008) Can we unify these flags with triton_cc_command? 
+ cflags.append("nostdinc") + # Note that the order of include paths do matter, as a result + # we need to have several branches interleaved here + include_dirs.append(build_paths.sleef_include) + include_dirs.append(build_paths.openmp_include) + include_dirs.append(build_paths.python_include) + include_dirs.append(build_paths.cc_include) + include_dirs.append(build_paths.libgcc_include) + include_dirs.append(build_paths.libgcc_arch_include) + include_dirs.append(build_paths.libgcc_backward_include) + include_dirs.append(build_paths.glibc_include) + include_dirs.append(build_paths.linux_kernel_include) + include_dirs.append("include") + + if aot_mode and not use_relative_path: + linker_script = _LINKER_SCRIPT + else: + linker_script = os.path.basename(_LINKER_SCRIPT) + + if _is_clang(cpp_compiler): + passthrough_args.append(" --rtlib=compiler-rt") + passthrough_args.append(" -fuse-ld=lld") + passthrough_args.append(f" -Wl,--script={linker_script}") + passthrough_args.append(" -B" + build_paths.glibc_lib) + passthrough_args.append(" -L" + build_paths.glibc_lib) + + return cflags, include_dirs, passthrough_args + + +def _get_build_args_of_chosen_isa(vec_isa: VecISA) -> tuple[list[str], list[str]]: + macros: list[str] = [] + build_flags: list[str] = [] + if vec_isa != invalid_vec_isa: + # Add Windows support later. 
+ macros.extend(copy.deepcopy(x) for x in vec_isa.build_macro()) + + build_flags = [vec_isa.build_arch_flags()] + + if config.is_fbcode(): + cap = str(vec_isa).upper() + macros = [ + f"CPU_CAPABILITY={cap}", + f"CPU_CAPABILITY_{cap}", + f"HAVE_{cap}_CPU_DEFINITION", + ] + + return macros, build_flags + + +def _get_torch_related_args( + include_pytorch: bool, aot_mode: bool +) -> tuple[list[str], list[str], list[str]]: + from torch.utils.cpp_extension import include_paths, TORCH_LIB_PATH + + libraries = [] + include_dirs = include_paths() + + if config.aot_inductor.link_libtorch: + libraries_dirs = [TORCH_LIB_PATH] + if sys.platform != "darwin" and not config.is_fbcode(): + libraries.extend(["torch", "torch_cpu"]) + if not aot_mode: + libraries.append("torch_python") + else: + libraries_dirs = [] + if config.aot_inductor.cross_target_platform == "windows": + aoti_shim_library = config.aot_inductor.aoti_shim_library + + assert aoti_shim_library, ( + "'config.aot_inductor.aoti_shim_library' must be set when 'cross_target_platform' is 'windows'." + ) + if isinstance(aoti_shim_library, str): + libraries.append(aoti_shim_library) + else: + assert isinstance(aoti_shim_library, list) + libraries.extend(aoti_shim_library) + + if config.aot_inductor.cross_target_platform == "windows": + assert config.aot_inductor.aoti_shim_library_path, ( + "'config.aot_inductor.aoti_shim_library_path' must be set to the path of the AOTI shim library", + " when 'cross_target_platform' is 'windows'.", + ) + libraries_dirs.append(config.aot_inductor.aoti_shim_library_path) + + if _IS_WINDOWS: + libraries.append("sleef") + + return include_dirs, libraries_dirs, libraries + + +def _get_python_include_dirs() -> list[str]: + include_dir = Path(sysconfig.get_path("include")) + # On Darwin Python executable from a framework can return + # non-existing /Library/Python/... 
include path, in which case + # one should use Headers folder from the framework + if not include_dir.exists() and platform.system() == "Darwin": + std_lib = Path(sysconfig.get_path("stdlib")) + include_dir = (std_lib.parent.parent / "Headers").absolute() + if not (include_dir / "Python.h").exists(): + warnings.warn(f"Can't find Python.h in {str(include_dir)}") + return [str(include_dir)] + + +def _get_python_related_args() -> tuple[list[str], list[str]]: + python_include_dirs = _get_python_include_dirs() + python_include_path = sysconfig.get_path( + "include", scheme="nt" if _IS_WINDOWS else "posix_prefix" + ) + if python_include_path is not None: + python_include_dirs.append(python_include_path) + + if _IS_WINDOWS: + python_lib_path = [ + str( + ( + Path(sysconfig.get_path("include", scheme="nt")).parent / "libs" + ).absolute() + ) + ] + else: + python_lib_path = [sysconfig.get_config_var("LIBDIR")] + + if config.is_fbcode(): + python_include_dirs.append(build_paths.python_include) + + return python_include_dirs, python_lib_path + + +@functools.cache +def is_conda_llvm_openmp_installed() -> bool: + try: + command = "conda list llvm-openmp --json" + output = subprocess.check_output(command.split()).decode("utf8") + return len(json.loads(output)) > 0 + except (subprocess.SubprocessError, FileNotFoundError): + return False + + +@functools.cache +def homebrew_libomp() -> tuple[bool, str]: + try: + # check if `brew` is installed + if shutil.which("brew") is None: + return False, "" + # get the location of `libomp` if it is installed + # this is the location that `libomp` **would** be installed + # see https://github.com/Homebrew/brew/issues/10261#issuecomment-756563567 for details + libomp_path = ( + subprocess.check_output(["brew", "--prefix", "libomp"]) + .decode("utf8") + .strip() + ) + # check if `libomp` is installed + omp_available = os.path.exists(libomp_path) + return omp_available, libomp_path + except subprocess.SubprocessError: + return False, "" + + 
+@functools.cache +def perload_clang_libomp_win(cpp_compiler: str, omp_name: str) -> None: + try: + output = subprocess.check_output([cpp_compiler, "-print-file-name=bin"]).decode( + "utf8" + ) + omp_path = os.path.join(output.rstrip(), omp_name) + if os.path.isfile(omp_path): + os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" + cdll.LoadLibrary(omp_path) + except subprocess.SubprocessError: + pass + + +@functools.cache +def perload_icx_libomp_win(cpp_compiler: str) -> None: + def _load_icx_built_in_lib_by_name(cpp_compiler: str, lib_name: str) -> bool: + try: + output = subprocess.check_output( + [cpp_compiler, f"-print-file-name={lib_name}"], + stderr=subprocess.DEVNULL, + ).decode(*SUBPROCESS_DECODE_ARGS) + omp_path = output.rstrip() + if os.path.isfile(omp_path): + os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" + cdll.LoadLibrary(omp_path) + return True + except subprocess.SubprocessError: + pass + return False + + """ + Intel Compiler implemented more math libraries than clang, for performance proposal. + We need preload them like openmp library. 
+ """ + preload_list = [ + "libiomp5md.dll", # openmp + "svml_dispmd.dll", # svml library + "libmmd.dll", # libm + ] + + for lib_name in preload_list: + _load_icx_built_in_lib_by_name(cpp_compiler, lib_name) + + +def _get_openmp_args( + cpp_compiler: str, +) -> tuple[list[str], list[str], list[str], list[str], list[str], list[str]]: + cflags: list[str] = [] + ldflags: list[str] = [] + include_dir_paths: list[str] = [] + lib_dir_paths: list[str] = [] + libs: list[str] = [] + passthrough_args: list[str] = [] + + if config.aot_inductor.cross_target_platform == "windows": + return cflags, ldflags, include_dir_paths, lib_dir_paths, libs, passthrough_args + if _IS_MACOS: + # Per https://mac.r-project.org/openmp/ right way to pass `openmp` flags to MacOS is via `-Xclang` + cflags.append("Xclang") + cflags.append("fopenmp") + + # only Apple builtin compilers (Apple Clang++) require openmp + omp_available = not _is_apple_clang(cpp_compiler) + + # check the `OMP_PREFIX` environment first + omp_prefix = os.getenv("OMP_PREFIX") + if omp_prefix is not None: + header_path = os.path.join(omp_prefix, "include", "omp.h") + valid_env = os.path.exists(header_path) + if valid_env: + include_dir_paths.append(os.path.join(omp_prefix, "include")) + lib_dir_paths.append(os.path.join(omp_prefix, "lib")) + else: + warnings.warn("environment variable `OMP_PREFIX` is invalid.") + omp_available = omp_available or valid_env + + if not omp_available: + libs.append("omp") + + # prefer to use openmp from `conda install llvm-openmp` + conda_prefix = os.getenv("CONDA_PREFIX") + if not omp_available and conda_prefix is not None: + omp_available = is_conda_llvm_openmp_installed() + if omp_available: + conda_lib_path = os.path.join(conda_prefix, "lib") + include_dir_paths.append(os.path.join(conda_prefix, "include")) + lib_dir_paths.append(conda_lib_path) + # Prefer Intel OpenMP on x86 machine + if os.uname().machine == "x86_64" and os.path.exists( + os.path.join(conda_lib_path, "libiomp5.dylib") + ): 
+ libs.append("iomp5") + + # next, try to use openmp from `brew install libomp` + if not omp_available: + omp_available, libomp_path = homebrew_libomp() + if omp_available: + include_dir_paths.append(os.path.join(libomp_path, "include")) + lib_dir_paths.append(os.path.join(libomp_path, "lib")) + + # if openmp is still not available, we let the compiler to have a try, + # and raise error together with instructions at compilation error later + elif _IS_WINDOWS: + """ + On Windows, `clang` and `icx` have their specific openmp implenmention. + And the openmp lib is in compiler's some sub-directory. + For dynamic library(DLL) load, the Windows native APIs are `LoadLibraryA` and `LoadLibraryExA`, and their search + dependencies have some rules: + https://learn.microsoft.com/en-us/windows/win32/api/libloaderapi/nf-libloaderapi-loadlibraryexa#searching-for-dlls-and-dependencies + In some case, the rules may not include compiler's sub-directories. + So, it can't search and load compiler's openmp library correctly. + And then, the whole application would be broken. + + To avoid the openmp load failed, we can automatic locate the openmp binary and preload it. + 1. For clang, the function is `perload_clang_libomp_win`. + 2. For icx, the function is `perload_icx_libomp_win`. 
+ """ + if _is_clang(cpp_compiler): + cflags.append("openmp") + libs.append("libomp") + perload_clang_libomp_win(cpp_compiler, "libomp.dll") + elif _is_intel_compiler(cpp_compiler): + cflags.append("Qiopenmp") + libs.append("libiomp5md") + perload_icx_libomp_win(cpp_compiler) + else: + # /openmp, /openmp:llvm + # llvm on Windows, new openmp: https://devblogs.microsoft.com/cppblog/msvc-openmp-update/ + # msvc openmp: https://learn.microsoft.com/zh-cn/cpp/build/reference/openmp-enable-openmp-2-0-support?view=msvc-170 + cflags.append("openmp") + cflags.append("openmp:experimental") # MSVC CL + else: + if config.is_fbcode(): + include_dir_paths.append(build_paths.openmp_include) + + openmp_lib = build_paths.openmp_lib_so + fb_openmp_extra_flags = f"-Wp,-fopenmp {openmp_lib}" + passthrough_args.append(fb_openmp_extra_flags) + + libs.append("omp") + else: + if _is_clang(cpp_compiler): + # TODO: fix issue, can't find omp.h + cflags.append("fopenmp") + libs.append("gomp") + elif _is_intel_compiler(cpp_compiler): + cflags.append("fiopenmp") + else: + cflags.append("fopenmp") + libs.append("gomp") + + return cflags, ldflags, include_dir_paths, lib_dir_paths, libs, passthrough_args + + +def _get_libstdcxx_args() -> tuple[list[str], list[str]]: + """ + For fbcode cpu case, we should link stdc++ instead assuming the binary where dlopen is executed is built with dynamic stdc++. 
+ """ + lib_dir_paths: list[str] = [] + libs: list[str] = [] + if config.is_fbcode(): + lib_dir_paths = [sysconfig.get_config_var("LIBDIR")] + libs.append("stdc++") + + return lib_dir_paths, libs + + +def get_mmap_self_macro( + use_mmap_weights: bool, use_mmap_weights_external: bool +) -> list[str]: + macros = [] + + if use_mmap_weights and use_mmap_weights_external: + raise RuntimeError( + "Only one of use_mmap_weights and use_mmap_weights_external should be true" + ) + if use_mmap_weights: + macros.append(" USE_MMAP_SELF") + elif use_mmap_weights_external: + macros.append(" USE_MMAP_EXTERNAL") + return macros + + +def get_caching_allocator_macro() -> list[str]: + from torch._inductor import config + + macros = [] + if config.aot_inductor.weight_use_caching_allocator: + macros.append(" AOT_INDUCTOR_USE_CACHING_ALLOCATOR") + return macros + + +def get_cpp_torch_options( + cpp_compiler: str, + vec_isa: VecISA, + include_pytorch: bool, + aot_mode: bool, + use_relative_path: bool, + use_mmap_weights: bool, + use_mmap_weights_external: bool, +) -> tuple[list[str], list[str], list[str], list[str], list[str], list[str], list[str]]: + """ + This function is used to get the build args of torch related build options. + 1. Torch include_directories, libraries, libraries_directories. + 2. Python include_directories, libraries, libraries_directories. + 3. OpenMP related. + 4. Torch MACROs. + 5. MISC + 6. 
Return the build args + """ + definitions: list[str] = [] + include_dirs: list[str] = [] + cflags: list[str] = [] + ldflags: list[str] = [] + libraries_dirs: list[str] = [] + libraries: list[str] = [] + passthrough_args: list[str] = [] + + torch_cpp_wrapper_definitions = _get_torch_cpp_wrapper_definition() + use_custom_generated_macros_definitions = _use_custom_generated_macros() + + ( + sys_libs_cflags, + sys_libs_include_dirs, + sys_libs_passthrough_args, + ) = _setup_standard_sys_libs(cpp_compiler, aot_mode, use_relative_path) + + isa_macros, isa_ps_args_build_flags = _get_build_args_of_chosen_isa(vec_isa) + + ( + torch_include_dirs, + torch_libraries_dirs, + torch_libraries, + ) = _get_torch_related_args(include_pytorch=include_pytorch, aot_mode=aot_mode) + + python_include_dirs, python_libraries_dirs = _get_python_related_args() + + ( + omp_cflags, + omp_ldflags, + omp_include_dir_paths, + omp_lib_dir_paths, + omp_lib, + omp_passthrough_args, + ) = _get_openmp_args(cpp_compiler) + + fb_macro_passthrough_args = _use_fb_internal_macros() + + mmap_self_macros = get_mmap_self_macro(use_mmap_weights, use_mmap_weights_external) + caching_allocator_macros = get_caching_allocator_macro() + + definitions = ( + torch_cpp_wrapper_definitions + + use_custom_generated_macros_definitions + + isa_macros + + fb_macro_passthrough_args + + mmap_self_macros + + caching_allocator_macros + ) + include_dirs = ( + sys_libs_include_dirs + + python_include_dirs + + torch_include_dirs + + omp_include_dir_paths + ) + cflags = sys_libs_cflags + omp_cflags + ldflags = omp_ldflags + libraries_dirs = python_libraries_dirs + torch_libraries_dirs + omp_lib_dir_paths + libraries = torch_libraries + omp_lib + passthrough_args = ( + sys_libs_passthrough_args + isa_ps_args_build_flags + omp_passthrough_args + ) + + return ( + definitions, + include_dirs, + cflags, + ldflags, + libraries_dirs, + libraries, + passthrough_args, + ) + + +class CppTorchOptions(CppOptions): + """ + This class is 
inherited from CppTorchOptions, which automatic contains + base cxx build options. And then it will maintains torch related build + args. + 1. Torch include_directories, libraries, libraries_directories. + 2. Python include_directories, libraries, libraries_directories. + 3. OpenMP related. + 4. Torch MACROs. + 5. MISC + """ + + def __init__( + self, + vec_isa: VecISA = invalid_vec_isa, + include_pytorch: bool = False, + warning_all: bool = True, + aot_mode: bool = False, + compile_only: bool = False, + use_relative_path: bool = False, + use_mmap_weights: bool = False, + use_mmap_weights_external: bool = False, + shared: bool = True, + extra_flags: Sequence[str] = (), + compiler: str = "", + min_optimize: bool = False, + precompiling: bool = False, + preprocessing: bool = False, + ) -> None: + super().__init__( + compile_only=compile_only, + warning_all=warning_all, + extra_flags=extra_flags, + use_relative_path=use_relative_path, + compiler=compiler, + min_optimize=min_optimize, + precompiling=precompiling, + preprocessing=preprocessing, + ) + + self._aot_mode = aot_mode + + ( + torch_definitions, + torch_include_dirs, + torch_cflags, + torch_ldflags, + torch_libraries_dirs, + torch_libraries, + torch_passthrough_args, + ) = get_cpp_torch_options( + cpp_compiler=self._compiler, + vec_isa=vec_isa, + include_pytorch=include_pytorch, + aot_mode=aot_mode, + use_relative_path=use_relative_path, + use_mmap_weights=use_mmap_weights, + use_mmap_weights_external=use_mmap_weights_external, + ) + + _append_list(self._definitions, torch_definitions) + _append_list(self._include_dirs, torch_include_dirs) + _append_list(self._cflags, torch_cflags) + _append_list(self._ldflags, torch_ldflags) + _append_list(self._libraries_dirs, torch_libraries_dirs) + _append_list(self._libraries, torch_libraries) + _append_list(self._passthrough_args, torch_passthrough_args) + self._finalize_options() + + +def _set_gpu_runtime_env() -> None: + if ( + config.is_fbcode() + and torch.version.hip 
is None + and "CUDA_HOME" not in os.environ + and "CUDA_PATH" not in os.environ + ): + os.environ["CUDA_HOME"] = build_paths.sdk_home + + +@functools.lru_cache(8) +def _find_libcudart_static(path: str) -> Optional[Path]: + lib_dirs = list(Path(path).rglob("libcudart_static.a")) + if lib_dirs: + return lib_dirs[0].resolve().parent + log_msg = f'"libcudart_static.a" not found under {path}' + log.info(log_msg) + return None + + +def _transform_cuda_paths(lpaths: list[str]) -> None: + # This handles two cases: + # 1. Cases where libs are in (e.g.) lib/cuda-12 and lib/cuda-12/stubs + # 2. Linux machines may have CUDA installed under either lib64/ or lib/ + for i, path in enumerate(lpaths): + if "CUDA_HOME" in os.environ and path.startswith(os.environ["CUDA_HOME"]): + lib_dir: Optional[Path] = _find_libcudart_static(path) + if lib_dir is None: + continue + lpaths[i] = str(lib_dir) + stub_dir = lib_dir / "stubs" + if stub_dir.exists(): + lpaths.append(str(stub_dir)) + + +def get_cpp_torch_device_options( + device_type: str, + aot_mode: bool = False, + compile_only: bool = False, +) -> tuple[list[str], list[str], list[str], list[str], list[str], list[str], list[str]]: + """ + This function is used to get the build args of device related build options. + 1. Device include_directories, libraries, libraries_directories. + 2. Device MACROs. + 3. MISC + 4. 
Return the build args + """ + definitions: list[str] = [] + include_dirs: list[str] = [] + cflags: list[str] = [] + ldflags: list[str] = [] + libraries_dirs: list[str] = [] + libraries: list[str] = [] + passthrough_args: list[str] = [] + if ( + config.is_fbcode() + and "CUDA_HOME" not in os.environ + and "CUDA_PATH" not in os.environ + ): + os.environ["CUDA_HOME"] = build_paths.sdk_home + + _set_gpu_runtime_env() + from torch.utils import cpp_extension + + include_dirs = cpp_extension.include_paths( + device_type, config.aot_inductor.link_libtorch is None + ) + link_libtorch = config.aot_inductor.link_libtorch + libraries_dirs = cpp_extension.library_paths( + device_type, + torch_include_dirs=link_libtorch, + cross_target_platform=config.aot_inductor.cross_target_platform, + ) + if device_type == "cuda": + definitions.append(" USE_ROCM" if torch.version.hip else " USE_CUDA") + + if torch.version.hip is not None: + if config.is_fbcode() or not link_libtorch: + libraries += ["amdhip64"] + else: + libraries += ["torch_hip"] + definitions.append(" __HIP_PLATFORM_AMD__") + else: + if config.is_fbcode() or not link_libtorch: + libraries += ["cuda"] + else: + libraries += ["cuda", "torch_cuda"] + if config.aot_inductor.cross_target_platform == "windows": + libraries += ["cudart"] + _transform_cuda_paths(libraries_dirs) + + if device_type == "xpu": + definitions.append(" USE_XPU") + xpu_error_string = ( + "Intel GPU driver is not properly installed, please follow the instruction " + "in https://github.com/pytorch/pytorch?tab=readme-ov-file#intel-gpu-support." 
+ ) + if _IS_WINDOWS: + ze_root = os.getenv("LEVEL_ZERO_V1_SDK_PATH") + if ze_root is None: + raise OSError(xpu_error_string) + include_dirs += [os.path.join(ze_root, "include")] + libraries_dirs += [os.path.join(ze_root, "lib")] + else: + # Suppress multi-line comment warnings in sycl headers + cflags += ["Wno-comment"] + if not find_library("ze_loader"): + raise OSError(xpu_error_string) + + libraries += ["ze_loader", "sycl"] + if link_libtorch: + libraries += ["torch_xpu"] + + if device_type == "mps": + definitions.append(" USE_MPS") + + if config.is_fbcode(): + include_dirs.append(build_paths.sdk_include) + + if aot_mode and device_type == "cuda": + if torch.version.hip is None: + if not compile_only: + # Only add link args, when compile_only is false. + passthrough_args = ["-Wl,-Bstatic -lcudart_static -Wl,-Bdynamic"] + + if device_type == "cpu": + ( + stdcxx_lib_dir_paths, + stdcxx_libs, + ) = _get_libstdcxx_args() + libraries_dirs += stdcxx_lib_dir_paths + libraries += stdcxx_libs + + if config.aot_inductor.custom_op_libs: + libraries += config.aot_inductor.custom_op_libs + + return ( + definitions, + include_dirs, + cflags, + ldflags, + libraries_dirs, + libraries, + passthrough_args, + ) + + +class CppTorchDeviceOptions(CppTorchOptions): + """ + This class is inherited from CppTorchOptions, which automatic contains + base cxx build options and torch common build options. And then it will + maintains cuda/xpu device related build args. 
+ """ + + def __init__( + self, + vec_isa: VecISA = invalid_vec_isa, + include_pytorch: bool = False, + device_type: str = "cuda", + aot_mode: bool = False, + compile_only: bool = False, + use_relative_path: bool = False, + use_mmap_weights: bool = False, + use_mmap_weights_external: bool = False, + shared: bool = True, + extra_flags: Sequence[str] = (), + min_optimize: bool = False, + precompiling: bool = False, + preprocessing: bool = False, + ) -> None: + super().__init__( + vec_isa=vec_isa, + include_pytorch=include_pytorch, + aot_mode=aot_mode, + compile_only=compile_only, + use_relative_path=use_relative_path, + use_mmap_weights=use_mmap_weights, + use_mmap_weights_external=use_mmap_weights_external, + extra_flags=extra_flags, + min_optimize=min_optimize, + precompiling=precompiling, + preprocessing=preprocessing, + ) + + device_definitions: list[str] = [] + device_include_dirs: list[str] = [] + device_cflags: list[str] = [] + device_ldflags: list[str] = [] + device_libraries_dirs: list[str] = [] + device_libraries: list[str] = [] + device_passthrough_args: list[str] = [] + + ( + device_definitions, + device_include_dirs, + device_cflags, + device_ldflags, + device_libraries_dirs, + device_libraries, + device_passthrough_args, + ) = get_cpp_torch_device_options( + device_type=device_type, + aot_mode=aot_mode, + compile_only=compile_only, + ) + _append_list(self._definitions, device_definitions) + _append_list(self._include_dirs, device_include_dirs) + _append_list(self._cflags, device_cflags) + _append_list(self._ldflags, device_ldflags) + _append_list(self._libraries_dirs, device_libraries_dirs) + _append_list(self._libraries, device_libraries) + _append_list(self._passthrough_args, device_passthrough_args) + self._finalize_options() + + def _finalize_options(self) -> None: + super()._finalize_options() + if config.is_fbcode(): + # Re-order library search paths in case there are lib conflicts + # that also live in the FBCode python lib dir. 
+ _, python_lib_dirs = _get_python_related_args() + assert len(python_lib_dirs) == 1, f"Python lib dirs: {python_lib_dirs}" + if python_lib_dirs[0] in self._libraries_dirs: + self._libraries_dirs.remove(python_lib_dirs[0]) + self._libraries_dirs.append(python_lib_dirs[0]) + + +def get_name_and_dir_from_output_file_path( + file_path: str, +) -> tuple[str, str]: + """ + This function help prepare parameters to new cpp_builder. + Example: + input_code: /tmp/tmpof1n5g7t/5c/c5crkkcdvhdxpktrmjxbqkqyq5hmxpqsfza4pxcf3mwk42lphygc.cpp + name, dir = get_name_and_dir_from_output_file_path(input_code) + Run result: + name = c5crkkcdvhdxpktrmjxbqkqyq5hmxpqsfza4pxcf3mwk42lphygc + dir = /tmp/tmpof1n5g7t/5c/ + + put 'name' and 'dir' to CppBuilder's 'name' and 'output_dir'. + CppBuilder --> get_target_file_path will format output path according OS: + Linux: /tmp/tmppu87g3mm/zh/czhwiz4z7ca7ep3qkxenxerfjxy42kehw6h5cjk6ven4qu4hql4i.so + Windows: [Windows temp path]/tmppu87g3mm/zh/czhwiz4z7ca7ep3qkxenxerfjxy42kehw6h5cjk6ven4qu4hql4i.dll + """ + name_and_ext = os.path.basename(file_path) + name, _ext = os.path.splitext(name_and_ext) + dir = os.path.dirname(file_path) + + return name, dir + + +class CppBuilder: + """ + CppBuilder is a cpp jit builder, and it supports both Windows, Linux and MacOS. + Args: + name: + 1. Build target name, the final target file will append extension type automatically. + 2. Due to the CppBuilder is supports multiple OS, it will maintains ext for OS difference. + sources: + Source code file list to be built. + BuildOption: + Build options to the builder. + output_dir: + 1. The output_dir the target file will output to. + 2. The default value is empty string, and then the use current dir as output dir. + 3. 
Final target file: output_dir/name.ext + """ + + @staticmethod + def __get_python_module_flags() -> tuple[str, str]: + extension = ".pyd" if _IS_WINDOWS else ".so" + output_flags = "/Fe" if _IS_WINDOWS else "-o" + return extension, output_flags + + @staticmethod + def __get_object_flags() -> tuple[str, str]: + extension = ".obj" if _IS_WINDOWS else ".o" + output_flags = "/c /Fo" if _IS_WINDOWS else "-c -o" # codespell:ignore + return extension, output_flags + + @staticmethod + def __get_precompiled_header_flags() -> tuple[str, str]: + extension = ".pch" if _IS_WINDOWS or not is_gcc() else ".gch" + output_flags = "/Fp" if _IS_WINDOWS else "-o" + return extension, output_flags + + @staticmethod + def __get_preprocessor_output_flags() -> tuple[str, str]: + extension = ".i" + output_flags = "/EP /P" if _IS_WINDOWS else "-E -P -o" + return extension, output_flags + + def __init__( + self, + name: str, + sources: Union[str, list[str]], + BuildOption: BuildOptionsBase, + output_dir: str = "", + ) -> None: + self._compiler = "" + self._cflags_args = "" + self._definitions_args = "" + self._include_dirs_args = "" + self._ldflags_args = "" + self._libraries_dirs_args = "" + self._libraries_args = "" + self._passthrough_parameters_args = "" + + # When relative path is used, we need to maintain the source dir list. + self._orig_source_paths = [] + self._output_dir = "" + self._target_file = "" + + self._use_relative_path: bool = False + self._aot_mode: bool = False + + self._name = name + self._target_name = ( + config.aot_inductor.model_name_for_generated_files or "aoti_model" + ) + + # Code start here, initial self internal variables firstly. 
+ self._build_option = BuildOption + self._compiler = BuildOption.get_compiler() + self._use_relative_path = BuildOption.get_use_relative_path() + self._aot_mode = BuildOption.get_aot_mode() + + self._output_dir = output_dir + + self._compile_only = BuildOption.get_compile_only() + self._precompiling = BuildOption.get_precompiling() + self._preprocessing = BuildOption.get_preprocessing() + # Only one of these options (if any) should be true at any given time. + assert sum((self._compile_only, self._precompiling, self._preprocessing)) <= 1 + self._do_link = not ( + self._compile_only or self._precompiling or self._preprocessing + ) + + # MSVC produces two files when precompiling: the actual .pch file, as well as an + # object file which must be linked into the final library. This class assumes + # only one output file of note, so for now we'll error out here. + assert not _IS_WINDOWS or not self._precompiling, ( + "Cannot currently precompile headers on Windows!" + ) + + if self._compile_only: + file_ext, output_flags = self.__get_object_flags() + elif self._precompiling: + file_ext, output_flags = self.__get_precompiled_header_flags() + elif self._preprocessing: + file_ext, output_flags = self.__get_preprocessor_output_flags() + else: + file_ext, output_flags = self.__get_python_module_flags() + self._target_file = os.path.join(self._output_dir, f"{self._name}{file_ext}") + + relative_target_file = ( + os.path.basename(self._target_file) + if self._use_relative_path + else self._target_file + ) + if _IS_WINDOWS: + if self._preprocessing: + # The target file name is automatically determined by MSVC. 
+ self._output = output_flags + else: + self._output = f"{output_flags}{relative_target_file}" + else: + self._output = f"{output_flags} {relative_target_file}" + + if isinstance(sources, str): + sources = [sources] + + # Use relative paths only when requested (typically for remote builds) + if config.is_fbcode() and self._use_relative_path: + # Will create another temp directory for building, so do NOT use the + # absolute path. + self._orig_source_paths = list(sources) + sources = [os.path.basename(i) for i in sources] + + if self._precompiling: + assert len(sources) == 1 + # See above; we can currently assume this is not on MSVC. + self._sources_args = f"-x c++-header {sources[0]}" + else: + self._sources_args = " ".join(sources) + + for cflag in BuildOption.get_cflags(): + if _IS_WINDOWS: + self._cflags_args += f"/{cflag} " + else: + self._cflags_args += f"-{cflag} " + + for definition in BuildOption.get_definitions(): + if _IS_WINDOWS: + self._definitions_args += f"/D {definition} " + else: + self._definitions_args += f"-D {definition} " + + if precompiled_header := BuildOption.precompiled_header: + if _IS_WINDOWS: + log.warning( + "Precompiled header support for MSVC is currently unavailable; ignoring %s", + precompiled_header, + ) + else: + self._include_dirs_args = f"-include {precompiled_header} " + + for inc_dir in BuildOption.get_include_dirs(): + if _IS_WINDOWS: + self._include_dirs_args += f'/I "{inc_dir}" ' + else: + self._include_dirs_args += f"-I{shlex.quote(inc_dir)} " + + for ldflag in BuildOption.get_ldflags(): + if _IS_WINDOWS: + self._ldflags_args += f"/{ldflag} " + else: + self._ldflags_args += f"-{ldflag} " + + for lib_dir in BuildOption.get_libraries_dirs(): + if _IS_WINDOWS: + self._libraries_dirs_args += f'/LIBPATH:"{lib_dir}" ' + else: + self._libraries_dirs_args += f"-L{lib_dir} " + + for lib in BuildOption.get_libraries(): + if _IS_WINDOWS: + self._libraries_args += f'"{lib}.lib" ' + else: + self._libraries_args += f"-l{lib} " + + for 
passthrough_arg in BuildOption.get_passthrough_args(): + self._passthrough_parameters_args += f"{passthrough_arg} " + + def get_command_line(self) -> str: + def format_build_command( + compiler: str, + sources: str, + include_dirs_args: str, + definitions_args: str, + cflags_args: str, + ldflags_args: str, + libraries_args: str, + libraries_dirs_args: str, + passthrough_args: str, + output: str, + ) -> str: + if _IS_WINDOWS: + # https://learn.microsoft.com/en-us/cpp/build/walkthrough-compile-a-c-program-on-the-command-line?view=msvc-1704 + # https://stackoverflow.com/a/31566153 + cmd = ( + f"{compiler} {include_dirs_args} {definitions_args} {cflags_args} " + f"{sources} {passthrough_args} {output}" + ) + if self._do_link: + cmd += f" /LD /link {libraries_dirs_args} {libraries_args} {ldflags_args}" + cmd = normalize_path_separator(cmd) + else: + cmd = ( + f"{compiler} {sources} {definitions_args} {cflags_args} " + f"{include_dirs_args} {passthrough_args} {output}" + ) + if self._do_link: + cmd += f" {ldflags_args} {libraries_args} {libraries_dirs_args}" + return cmd + + command_line = format_build_command( + compiler=self._compiler, + sources=self._sources_args, + include_dirs_args=self._include_dirs_args, + definitions_args=self._definitions_args, + cflags_args=self._cflags_args, + ldflags_args=self._ldflags_args, + libraries_args=self._libraries_args, + libraries_dirs_args=self._libraries_dirs_args, + passthrough_args=self._passthrough_parameters_args, + output=self._output, + ) + return command_line + + def get_target_file_path(self) -> str: + return normalize_path_separator(self._target_file) + + def build_fbcode_re( + self, + ) -> None: + with dynamo_timed("compile_file"): + command = self.get_command_line().split() + try: + output_path = self._target_file + # When we build remotely, we need to make sure to carefully copy any files + # that are required during the compilation process into our build directly. 
+ # This is where all of the ATen/c10/Torch includes come from. + torch_includes_path = os.path.join(_TORCH_PATH, "include") + with tempfile.TemporaryDirectory() as tmp_dir: + # Copy everything to tmp compilation folder + shutil.copy(_LINKER_SCRIPT, os.path.join(tmp_dir, "script.ld")) + for src in self._orig_source_paths: + shutil.copy(src, os.path.join(tmp_dir, os.path.basename(src))) + dest_include_path = os.path.join(tmp_dir, "include") + shutil.copytree(torch_includes_path, dest_include_path) + # Run the build + tmp_output_path = _run_build_command( + command, tmp_dir, os.path.basename(output_path) + ) + # Copy output from the build + if os.path.exists(output_path): + os.remove(output_path) + shutil.copy(tmp_output_path, output_path) + if output_path.endswith(".o"): + os.chmod(output_path, 0o644) + elif output_path.endswith(".so"): + os.chmod(output_path, 0o755) + except subprocess.CalledProcessError as e: + output = e.output.decode("utf-8") + raise exc.CppCompileError(command, output) from e + + def build(self) -> None: + """ + It is must need a temporary directory to store object files in Windows. + After build completed, delete the temporary directory to save disk space. + """ + if self._use_relative_path: + # remote build uses relative path + return self.build_fbcode_re() + _create_if_dir_not_exist(self._output_dir) + _build_tmp_dir = os.path.join( + self._output_dir, f"{self._name}_{_BUILD_TEMP_DIR}" + ) + _create_if_dir_not_exist(_build_tmp_dir) + + build_cmd = self.get_command_line() + run_compile_cmd(build_cmd, cwd=_build_tmp_dir) + _remove_dir(_build_tmp_dir) + + def save_compile_cmd_to_cmake( + self, + cmake_path: str, + device_type: str, + ) -> None: + """ + Save global cmake settings here, e.g. compiler options. + If targeting CUDA, also emit a custom function to embed CUDA kernels. 
+ """ + + definitions = " ".join(self._build_option.get_definitions()) + target_library_type = ( + "STATIC" if not config.aot_inductor.dynamic_linkage else "SHARED" + ) + + contents = textwrap.dedent( + f""" + cmake_minimum_required(VERSION 3.27 FATAL_ERROR) + project({self._target_name} LANGUAGES CXX) + set(CMAKE_CXX_STANDARD 17) + + # Set a library target + add_library({self._target_name} {target_library_type}) + + """ + ) + + if config.aot_inductor.link_libtorch or config.test_configs.use_libtorch: + # When compile_standalone is True, the generated cpp project should + # not use Torch. But for unit testing purpose, we need to use Torch here. + contents += textwrap.dedent( + """ + # May need to point CMAKE_PREFIX_PATH to the right torch location + find_package(Torch REQUIRED) + + """ + ) + # flags and macros here are mostly CPU specific. Not emitting them for GPU models + # will make the generated CMake file more portable and won't really hurt performance. + # NOTE: standalone focuses on GPU now. For CPU, some of the flags and macros may + # be still needed. 
+ contents += textwrap.dedent( + f""" + # Add macro definitions + target_compile_definitions({self._target_name} PRIVATE {definitions}) + + # Add compile flags + target_compile_options({self._target_name} PRIVATE {self._cflags_args}) + + # Backend-specific flags + target_compile_options({self._target_name} PRIVATE {self._passthrough_parameters_args} -c) + + """ + ) + else: + # When compile_standalone is True, use TorchStandalone instead of Torch + contents += textwrap.dedent( + f""" + find_package(TorchStandalone REQUIRED) + # Set up include directories to find headers at the correct paths + target_include_directories({self._target_name} PRIVATE ${{TorchStandalone_INCLUDE_DIRS}}) + target_include_directories({self._target_name} PRIVATE ${{TorchStandalone_INCLUDE_DIRS}}/standalone) + + """ + ) + + if device_type == "cuda" and torch.version.hip is None: + from torch._inductor.codecache import _nvcc_arch_as_compile_option + + current_arch = _nvcc_arch_as_compile_option() + contents += textwrap.dedent( + f""" + enable_language(CUDA) + set(CMAKE_CUDA_STANDARD 17) + find_package(CUDAToolkit REQUIRED) + target_include_directories({self._target_name} PRIVATE ${{CUDAToolkit_INCLUDE_DIRS}}) + target_compile_definitions({self._target_name} PRIVATE USE_CUDA) + target_link_libraries({self._target_name} PRIVATE cuda CUDA::cudart_static) + + find_program(OBJCOPY_EXECUTABLE objcopy) + if(NOT OBJCOPY_EXECUTABLE) + message(FATAL_ERROR "objcopy not found. 
Cannot embed fatbin as object file") + endif() + + set(KERNEL_TARGETS "") + set(KERNEL_OBJECT_FILES "") + # Function to embed a single kernel + function(embed_gpu_kernel KERNEL_NAME PTX_FILE) + set(FATBIN_BASENAME ${{KERNEL_NAME}}.fatbin) + set(FATBIN_FILE ${{CMAKE_CURRENT_BINARY_DIR}}/${{FATBIN_BASENAME}}) + set(OBJECT_BASENAME ${{KERNEL_NAME}}.fatbin.o) + set(OBJECT_FILE ${{CMAKE_CURRENT_BINARY_DIR}}/${{OBJECT_BASENAME}}) + + # --- Define UNIQUE C symbol names --- + set(SYMBOL_START __${{KERNEL_NAME}}_start) + set(SYMBOL_END __${{KERNEL_NAME}}_end) + set(SYMBOL_SIZE __${{KERNEL_NAME}}_size) + string(REGEX REPLACE "[^a-zA-Z0-9]" "_" MANGLED_BASENAME ${{FATBIN_FILE}}) + set(OBJCOPY_START_SYM _binary_${{MANGLED_BASENAME}}_start) + set(OBJCOPY_END_SYM _binary_${{MANGLED_BASENAME}}_end) + set(OBJCOPY_SIZE_SYM _binary_${{MANGLED_BASENAME}}_size) + + # --- PTX to FATBIN Command & Target --- + add_custom_command( + OUTPUT ${{FATBIN_FILE}} + COMMAND ${{CUDAToolkit_NVCC_EXECUTABLE}} --fatbin ${{PTX_FILE}} -o ${{FATBIN_FILE}} ${{NVCC_GENCODE_FLAGS}} + -gencode arch=compute_{current_arch},code=compute_{current_arch} + -gencode arch=compute_{current_arch},code=sm_{current_arch} + DEPENDS ${{PTX_FILE}} + ) + + # --- FATBIN to Object File (.o) Command --- + add_custom_command( + OUTPUT ${{OBJECT_FILE}} + COMMAND ${{CMAKE_LINKER}} -r -b binary -z noexecstack -o ${{OBJECT_FILE}} ${{FATBIN_FILE}} + COMMAND ${{OBJCOPY_EXECUTABLE}} --rename-section .data=.rodata,alloc,load,readonly,data,contents + ${{OBJECT_FILE}} + COMMAND ${{OBJCOPY_EXECUTABLE}} + --redefine-sym ${{OBJCOPY_START_SYM}}=${{SYMBOL_START}} + --redefine-sym ${{OBJCOPY_END_SYM}}=${{SYMBOL_END}} + --redefine-sym ${{OBJCOPY_SIZE_SYM}}=${{SYMBOL_SIZE}} + ${{OBJECT_FILE}} + DEPENDS ${{FATBIN_FILE}} + ) + add_custom_target(build_kernel_object_${{KERNEL_NAME}} DEPENDS ${{OBJECT_FILE}}) + + # --- Add to a list for linking later --- + set(KERNEL_TARGETS ${{KERNEL_TARGETS}} build_kernel_object_${{KERNEL_NAME}} PARENT_SCOPE) + 
set(KERNEL_OBJECT_FILES ${{KERNEL_OBJECT_FILES}} ${{OBJECT_FILE}} PARENT_SCOPE) + endfunction() + + """ + ) + + with open(cmake_path, "w") as f: + f.write(contents) + + def save_src_to_cmake(self, cmake_path: str, src_path: str) -> None: + # Remove the directory part of file_path + src_path = "${CMAKE_CURRENT_SOURCE_DIR}/" + Path(src_path).name + with open(cmake_path, "a") as f: + f.write(f"target_sources({self._target_name} PRIVATE {src_path})\n") + + def save_kernel_asm_to_cmake(self, cmake_path: str, asm_files: list[str]) -> None: + # TODO: make this work beyond CUDA + with open(cmake_path, "a") as f: + for asm_file in asm_files: + kernel_name = Path(asm_file).name.split(".")[0] + asm_file = f"${{CMAKE_CURRENT_SOURCE_DIR}}/{Path(asm_file).name}" + contents = textwrap.dedent( + f""" + embed_gpu_kernel({kernel_name} {asm_file}) + """ + ) + f.write(contents) + if asm_files: + f.write(f"add_dependencies({self._target_name} ${{KERNEL_TARGETS}})\n") + f.write( + f"target_link_libraries({self._target_name} PRIVATE ${{KERNEL_OBJECT_FILES}})\n" + ) + + def save_link_cmd_to_cmake(self, cmake_path: str) -> None: + lflags = " ".join(self._build_option.get_ldflags()) + libs = " ".join(self._build_option.get_libraries()) + contents = textwrap.dedent( + f""" + # Add linker flags + target_link_options({self._target_name} PRIVATE {lflags}) + + # Add libraries + target_link_libraries({self._target_name} PRIVATE {libs}) + """ + ) + + assert os.path.exists(cmake_path), ( + f"save_link_cmd_to_cmakefile expects {cmake_path} to already exist" + ) + with open(cmake_path, "a") as f: + f.write(contents) + + +def run_asm_build_object(src: str, target: str, cwd: str) -> None: + def get_asm_compiler() -> str: + if _IS_WINDOWS: + ASM_CC = "ml64" + else: + ASM_CC = get_cpp_compiler() + # Intel compiler is not support to compile asm, switch to gcc. 
+ if _is_intel_compiler(ASM_CC): + ASM_CC = "gcc" + return ASM_CC + + def get_command_line(asm_cc: str, src: str, target: str) -> str: + if _IS_WINDOWS: + # Format reference: + # https://learn.microsoft.com/en-us/cpp/assembler/masm/ml-and-ml64-command-line-reference?view=msvc-170 + cmd = f"{asm_cc} {src} /c /Fo {target}" # codespell:ignore /Fo + else: + cmd = f"{asm_cc} -c {src} -o {target}" + + return cmd + + asm_cc = get_asm_compiler() + cmd = get_command_line( + asm_cc=asm_cc, + src=normalize_path_separator(src), + target=normalize_path_separator(target), + ) + run_compile_cmd(cmd, cwd=normalize_path_separator(cwd)) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/cpu_vec_isa.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/cpu_vec_isa.py new file mode 100644 index 0000000000000000000000000000000000000000..46fd76c529e204afb4898e1a144b59d149be9abc --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/cpu_vec_isa.py @@ -0,0 +1,566 @@ +# mypy: allow-untyped-defs +import dataclasses +import functools +import os +import platform +import re +import subprocess +import sys +import warnings +from collections.abc import Callable +from typing import Any, Union + +import torch +from torch._inductor import config +from torch._inductor.utils import python_subprocess_env + + +_IS_WINDOWS = sys.platform == "win32" + + +def _get_isa_dry_compile_fingerprint(isa_flags: str) -> str: + # ISA dry compile will cost about 1 sec time each startup time. + # Please check the issue: https://github.com/pytorch/pytorch/issues/100378 + # Actually, dry compile is checking compile capability for ISA. + # We just record the compiler version, isa options and pytorch version info, + # and generated them to output binary hash path. + # It would optimize and skip compile existing binary. 
+ from torch._inductor.cpp_builder import get_compiler_version_info, get_cpp_compiler + + compiler_info = get_compiler_version_info(get_cpp_compiler()) + torch_version = torch.__version__ + fingerprint = f"{compiler_info}={isa_flags}={torch_version}" + return fingerprint + + +class VecISA: + _bit_width: int + _macro: list[str] + _arch_flags: str + _dtype_nelements: dict[torch.dtype, int] + + # Note [Checking for Vectorized Support in Inductor] + # TorchInductor CPU vectorization reuses PyTorch vectorization utility functions + # Hence, TorchInductor would depend on Sleef* to accelerate mathematical functions + # like exp, pow, sin, cos and etc. + # But PyTorch and TorchInductor might use different compilers to build code. If + # PyTorch uses gcc-7/g++-7 to build the release package, the libtorch_cpu.so + # will not expose the Sleef* AVX512 symbols since gcc-7/g++-7 cannot pass + # avx512 check in CMake - FindAVX.cmake. But TorchInductor install the latest + # gcc/g++ compiler by default while it could support the AVX512 compilation. + # Therefore, there would be a conflict sleef version between PyTorch and + # TorchInductor. Hence, we dry-compile the following code to check whether current + # HW platform and PyTorch both could support AVX512 or AVX2. And suppose ARM + # also needs the logic + # In fbcode however, we are using the same compiler for pytorch and for inductor codegen, + # making the runtime check unnecessary. 
+ _avx_code = """ +#if defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_ZVECTOR) || defined(CPU_CAPABILITY_NEON) || defined(CPU_CAPABILITY_VSX) || defined(CPU_CAPABILITY_SVE) +#include +#include +#endif + +alignas(64) float in_out_ptr0[16] = {0.0}; + +extern "C" void __avx_chk_kernel() { + auto tmp0 = at::vec::Vectorized(1); + auto tmp1 = tmp0.exp(); + tmp1.store(in_out_ptr0); +} +""" # noqa: B950 + + _avx_py_load = """ +import torch +from ctypes import cdll +cdll.LoadLibrary("__lib_path__") +""" + + def bit_width(self) -> int: + return self._bit_width + + def nelements(self, dtype: torch.dtype = torch.float) -> int: + return self._dtype_nelements[dtype] + + def build_macro(self) -> list[str]: + return self._macro + + def build_arch_flags(self) -> str: + return self._arch_flags + + def __hash__(self) -> int: + return hash(str(self)) + + def check_build(self, code: str) -> bool: + from torch._inductor.codecache import get_lock_dir, LOCK_TIMEOUT, write + from torch._inductor.cpp_builder import ( + CppBuilder, + CppTorchOptions, + normalize_path_separator, + ) + + key, input_path = write( + code, + "cpp", + extra=_get_isa_dry_compile_fingerprint(self._arch_flags), + ) + from torch.utils._filelock import FileLock + + lock_dir = get_lock_dir() + lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT) + with lock: + output_dir = os.path.dirname(input_path) + buid_options = CppTorchOptions(vec_isa=self, warning_all=False) + x86_isa_help_builder = CppBuilder( + key, + [input_path], + buid_options, + output_dir, + ) + try: + # Check if the output file exist, and compile when not. 
+ output_path = normalize_path_separator( + x86_isa_help_builder.get_target_file_path() + ) + if not os.path.isfile(output_path): + x86_isa_help_builder.build() + + # Check build result + subprocess.check_call( + [ + sys.executable, + "-c", + VecISA._avx_py_load.replace("__lib_path__", output_path), + ], + cwd=output_dir, + stderr=subprocess.DEVNULL, + env=python_subprocess_env(), + ) + except Exception: + return False + + return True + + def __bool__(self) -> bool: + return self.__bool__impl(config.cpp.vec_isa_ok) + + @functools.cache # noqa: B019 + def __bool__impl(self, vec_isa_ok) -> bool: + if vec_isa_ok is not None: + return vec_isa_ok + + if config.is_fbcode(): + return True + + return self.check_build(VecISA._avx_code) + + +@dataclasses.dataclass +class VecNEON(VecISA): + _bit_width = 128 # This is required to leverage the compute implemented in aten/src/ATen/cpu/vec/vec128/vec128_float_neon.h + _macro = ["CPU_CAPABILITY_NEON", "AT_BUILD_ARM_VEC256_WITH_SLEEF"] + _arch_flags = "" # Unused + _dtype_nelements = {torch.float: 4, torch.bfloat16: 8, torch.float16: 8} + + def __str__(self) -> str: + if config.is_fbcode(): + return "neon" + return "asimd" # detects the presence of advanced SIMD on armv8-a kernels + + __hash__: Callable[[VecISA], Any] = VecISA.__hash__ # type: ignore[assignment] + + +@dataclasses.dataclass +class VecSVE256(VecISA): + # this function can be repurposed for SVE with variable vec length + _bit_width = 256 + _macro = [ + "CPU_CAPABILITY_SVE", + "CPU_CAPABILITY_SVE256", + "AT_BUILD_ARM_VEC256_WITH_SLEEF", + "__ARM_FEATURE_BF16", + ] + _arch_flags = "-march=armv8-a+sve+bf16 -msve-vector-bits=256" + + _dtype_nelements = {torch.float: 8, torch.bfloat16: 16, torch.float16: 16} + + def __str__(self) -> str: + if config.is_fbcode(): + return "neon" + return "asimd" + + __hash__: Callable[[VecISA], Any] = VecISA.__hash__ # type: ignore[assignment] + + +@dataclasses.dataclass +class VecAVX512(VecISA): + _bit_width = 512 + _macro = 
["CPU_CAPABILITY_AVX512"] + _arch_flags = ( + "-mavx512f -mavx512dq -mavx512vl -mavx512bw -mfma" + if not _IS_WINDOWS + else "/arch:AVX512" + ) # TODO: use cflags + _dtype_nelements = {torch.float: 16, torch.bfloat16: 32, torch.float16: 32} + _is_avx512_bf16_supported = False + + def __str__(self) -> str: + return "avx512" + + __hash__: Callable[[VecISA], Any] = VecISA.__hash__ # type: ignore[assignment] + + _avx512_bf16_code = """ +#include +#include + +extern "C" __m512bh __avx512_bf16_chk_kernel(__m512 a, __m512 b) { + return _mm512_cvtne2ps_pbh(a, b); +} +""" + + @functools.cache # noqa: B019 + # pyrefly: ignore [bad-override] + def __bool__(self) -> bool: + if super().__bool__(): + if config.is_fbcode(): + return False + # check avx512_bf16 + if torch.cpu._is_avx512_bf16_supported() and not _IS_WINDOWS: + # save _arch_flags + base_flags = self._arch_flags + # temporarily change _arch_flags for avx512_bf16 check_build + self._arch_flags += " -mavx512bf16" + if self.check_build(self._avx512_bf16_code): + self._is_avx512_bf16_supported = True + # restore _arch_flags + self._arch_flags = base_flags + + return True + return False + + @functools.lru_cache(None) # noqa: B019 + def is_avx512_bf16_supported(self) -> bool: + return self._is_avx512_bf16_supported + + def build_arch_flags(self) -> str: + if self._is_avx512_bf16_supported: + return self._arch_flags + " -mavx512bf16" + else: + return self._arch_flags + + +@dataclasses.dataclass +class VecAVX512VNNI(VecAVX512): + _bit_width = 512 + _arch_flags = VecAVX512._arch_flags + " -mavx512vnni -mavx512vl" + _dtype_nelements = { + torch.float: 16, + torch.bfloat16: 32, + torch.float16: 32, + torch.int8: 64, + torch.uint8: 64, + } + + def __str__(self) -> str: + return super().__str__() + " avx512_vnni" + + __hash__: Callable[[VecISA], Any] = VecISA.__hash__ # type: ignore[assignment] + + _avx512_vnni_code = """ +#include +#include + +extern "C" __m256i __avx512_vnni_chk_kernel_1(__m256i src, __m256i a, __m256i b) { + 
return _mm256_dpbusd_epi32(src, a, b); +} + +extern "C" __m512i __avx512_vnni_chk_kernel_2(__m512i src, __m512i a, __m512i b) { + return _mm512_dpbusd_epi32(src, a, b); +} +""" + + @functools.cache # noqa: B019 + # pyrefly: ignore [bad-override] + def __bool__(self) -> bool: + if super().__bool__(): + if config.is_fbcode(): + return False + if ( + torch.cpu._is_vnni_supported() + and not _IS_WINDOWS + and self.check_build(self._avx512_vnni_code) + ): + return True + return False + + def build_arch_flags(self) -> str: + return self._arch_flags + + +@dataclasses.dataclass +class VecAMX(VecAVX512VNNI): + _arch_flags = VecAVX512VNNI._arch_flags + " -mamx-tile -mamx-bf16 -mamx-int8" + # check amx_fp16 separately since it is not always supported when amx is supported + # amx_fp16 intrinsic compilation need gcc >=13 on platforms which support amx_fp16 + _is_amx_fp16_supported = False + + def __str__(self) -> str: + return super().__str__() + " amx_tile" + + __hash__: Callable[[VecISA], Any] = VecISA.__hash__ + + _amx_code = """ +#include +#include + +struct amx_tilecfg { + uint8_t palette_id; + uint8_t start_row; + uint8_t reserved_0[14]; + uint16_t colsb[16]; + uint8_t rows[16]; +}; + +extern "C" void __amx_chk_kernel() { + amx_tilecfg cfg = {0}; + _tile_loadconfig(&cfg); + _tile_zero(0); + _tile_dpbf16ps(0, 1, 2); + _tile_dpbusd(0, 1, 2); +} +""" + + _amx_fp16_code = _amx_code.replace("_tile_dpbf16ps", "_tile_dpfp16ps") + + @functools.cache # noqa: B019 + def __bool__(self) -> bool: + if super().__bool__(): + if config.is_fbcode(): + return False + if self.check_build(VecAMX._amx_code) and torch.cpu._init_amx(): + # check amx-fp16 as well when check amx + if torch.cpu._is_amx_fp16_supported(): + # save _arch_flags + base_flags = self._arch_flags + # temporarily change _arch_flags for amx-fp16 check_build + self._arch_flags += " -mamx-fp16" + if self.check_build(VecAMX._amx_fp16_code): + self._is_amx_fp16_supported = True + # restore _arch_flags + self._arch_flags = 
base_flags + + return True + return False + + @functools.lru_cache(None) # noqa: B019 + def is_amx_fp16_supported(self) -> bool: + return self._is_amx_fp16_supported + + def build_arch_flags(self) -> str: + extra_flags = "" + if self._is_avx512_bf16_supported: + # avx512_bf16 is not among the base flags, so we need to check and add it here + # And we need this flag in the WOQ case for dequantization + extra_flags += " -mavx512bf16" + if self._is_amx_fp16_supported: + extra_flags += " -mamx-fp16" + return self._arch_flags + extra_flags + + +@dataclasses.dataclass +class VecAVX2(VecISA): + _bit_width = 256 + _macro = ["CPU_CAPABILITY_AVX2"] + _arch_flags = ( + "-mavx2 -mfma -mf16c" if not _IS_WINDOWS else "/arch:AVX2" + ) # TODO: use cflags + _dtype_nelements = {torch.float: 8, torch.bfloat16: 16, torch.float16: 16} + + def __str__(self) -> str: + return "avx2" + + __hash__: Callable[[VecISA], Any] = VecISA.__hash__ # type: ignore[assignment] + + +@dataclasses.dataclass +class VecZVECTOR(VecISA): + _bit_width = 256 + _macro = [ + "CPU_CAPABILITY_ZVECTOR", + "CPU_CAPABILITY=ZVECTOR", + "HAVE_ZVECTOR_CPU_DEFINITION", + ] + _arch_flags = "-mvx -mzvector" + _dtype_nelements = {torch.float: 8, torch.bfloat16: 16, torch.float16: 16} + + def __str__(self) -> str: + return "zvector" + + __hash__: Callable[[VecISA], Any] = VecISA.__hash__ # type: ignore[assignment] + + +@dataclasses.dataclass +class VecVSX(VecISA): + _bit_width = 256 # VSX simd supports 128 bit_width, but aten is emulating it as 256 + _macro = ["CPU_CAPABILITY_VSX"] + _arch_flags = "-mvsx" + _dtype_nelements = {torch.float: 8, torch.bfloat16: 16, torch.float16: 16} + + def __str__(self) -> str: + return "vsx" + + __hash__: Callable[[VecISA], Any] = VecISA.__hash__ # type: ignore[assignment] + + +class InvalidVecISA(VecISA): + _bit_width = 0 + _macro = [""] + _arch_flags = "" + _dtype_nelements = {} + + def __str__(self) -> str: + return "INVALID_VEC_ISA" + + def __bool__(self) -> bool: # type: 
ignore[override] + return False + + __hash__: Callable[[VecISA], Any] = VecISA.__hash__ # type: ignore[assignment] + + +def x86_isa_checker() -> list[str]: + supported_isa: list[str] = [] + + def _check_and_append_supported_isa( + dest: list[str], isa_supported: bool, isa_name: str + ) -> None: + if isa_supported: + dest.append(isa_name) + + Arch = platform.machine() + """ + Arch value is x86_64 on Linux, and the value is AMD64 on Windows. + """ + if Arch != "x86_64" and Arch != "AMD64": + return supported_isa + + avx2 = torch.cpu._is_avx2_supported() + avx512 = torch.cpu._is_avx512_supported() + avx512_vnni = avx512 and torch.cpu._is_vnni_supported() + amx_tile = torch.cpu._is_amx_tile_supported() + + _check_and_append_supported_isa(supported_isa, avx2, "avx2") + _check_and_append_supported_isa(supported_isa, avx512, "avx512") + _check_and_append_supported_isa(supported_isa, avx512_vnni, "avx512_vnni") + _check_and_append_supported_isa(supported_isa, amx_tile, "amx_tile") + + return supported_isa + + +invalid_vec_isa = InvalidVecISA() +supported_vec_isa_list = [ + VecAMX(), + VecAVX512VNNI(), + VecAVX512(), + VecAVX2(), + VecNEON(), + VecSVE256(), +] + + +def get_isa_from_cpu_capability( + capability: Union[str, None], + vec_isa_list: list[VecISA], + invalid_vec_isa: InvalidVecISA, +): + # AMX setting is not supported in eager + # VecAMX will be prioritized for selection when setting ATEN_CPU_CAPABILITY to avx512 + # TODO add sve256 support + capability_to_isa_str = { + "default": "INVALID_VEC_ISA", + "zvector": "zvector", + "vsx": "vsx", + "avx2": "avx2", + "avx512": "avx512", + } + if capability in capability_to_isa_str: + # pyrefly: ignore [index-error] + isa_str = capability_to_isa_str[capability] + if isa_str == "INVALID_VEC_ISA": + return invalid_vec_isa + for vec_isa in vec_isa_list: + if isa_str in str(vec_isa): + return vec_isa + + if capability: + warnings.warn(f"ignoring invalid value for ATEN_CPU_CAPABILITY {capability}") + + return vec_isa_list[0] + + 
+# Cache the cpuinfo to avoid I/O overhead. Meanwhile, the cpuinfo content +# might have too much redundant content that is useless for ISA check. Hence, +# we only cache some key isa information. +@functools.cache +def valid_vec_isa_list() -> list[VecISA]: + isa_list: list[VecISA] = [] + if sys.platform == "darwin" and platform.processor() == "arm": + isa_list.append(VecNEON()) + + if sys.platform not in ["linux", "win32"]: + return isa_list + + arch = platform.machine() + if arch == "s390x": + with open("/proc/cpuinfo") as _cpu_info: + while True: + line = _cpu_info.readline() + if not line: + break + # process line + featuresmatch = re.match(r"^features\s*:\s*(.*)$", line) + if featuresmatch: + for group in featuresmatch.groups(): + if re.search(r"[\^ ]+vxe[\$ ]+", group): + isa_list.append(VecZVECTOR()) + break + elif arch == "ppc64le": + isa_list.append(VecVSX()) + elif arch == "aarch64": + if torch.backends.cpu.get_cpu_capability() == "SVE256": + isa_list.append(VecSVE256()) + else: + isa_list.append(VecNEON()) + + elif arch in ["x86_64", "AMD64"]: + """ + arch value is x86_64 on Linux, and the value is AMD64 on Windows. 
+ """ + _cpu_supported_x86_isa = x86_isa_checker() + isa_list.extend( + isa + for isa in supported_vec_isa_list + if all(flag in _cpu_supported_x86_isa for flag in str(isa).split()) and isa + ) + + return isa_list + + +def pick_vec_isa() -> VecISA: + if config.is_fbcode() and (platform.machine() in ["x86_64", "AMD64"]): + return VecAVX2() + + _valid_vec_isa_list: list[VecISA] = valid_vec_isa_list() + if not _valid_vec_isa_list: + return invalid_vec_isa + + # If the simdlen is None, set simdlen based on the environment ATEN_CPU_CAPABILITY + # to control CPU vec ISA + if config.cpp.simdlen is None: + return get_isa_from_cpu_capability( + os.getenv("ATEN_CPU_CAPABILITY"), _valid_vec_isa_list, invalid_vec_isa + ) + + for isa in _valid_vec_isa_list: + if config.cpp.simdlen == isa.bit_width(): + return isa + + return invalid_vec_isa diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/cudagraph_trees.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/cudagraph_trees.py new file mode 100644 index 0000000000000000000000000000000000000000..72d0bcc69e3d0e1a98af0e63bdf56634b1701e23 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/cudagraph_trees.py @@ -0,0 +1,2600 @@ +""" +CUDA graph trees are a safety abstraction over CUDAGraphs, similar to make_graph_callables, +which share the same memory pool. Sharing a memory pool is an extremely +important optimization when chaining multiple CUDA graphs together, as it +prevents you from needing to copy intermediate tensors from one graph to the +next, and reduces overall memory usage by allowing dead memory from the first +pool to be reused in the second. + +The standard graph/make_graph_callables support sharing memory pool, but +with a lot of caveats. CUDA graph trees remove these restrictions: + +* Previously, if you recorded graphs A, B, you had to replay A, B in that + order. 
With CUDA graph trees, after replaying A, you can change your + mind and record/replay a different graph B'; we will support efficient + execution of both A, B and A, B', using only max(mem(A, B), mem(A, B')). In + other words: we support arbitrary trees of CUDA graph operations, not just + sequences (this is why this feature is called CUDA graph trees.) + +* Previously, if you executed graph A, some non-CUDA graph code, and then + graph B, after executing graph B, it was not safe to retain any references + to intermediates produced by A. With CUDA graph trees, we track if any +outputs of graph A are still live by the time graph B is run, and make + sure graph B doesn't clobber there memory when reusing the CUDA graphs + pool. You'll get a separate recording of B depending on what tensors + stay live or dead. + +CUDA graph trees are flexible enough to be used in Dynamo across graph breaks, +which is their primary use case. + +The ability to switch from replay to record is fairly nontrivial: remember that +when you replay a CUDA graph, you only replay CUDA operations; no CPU side state +is updated. In particular, the CPU-side book-keeping for the allocator is not +reconstructed. However, to record a new child CUDA graph, we must restore this +book-keeping. This is what checkpoint pool state is used for. 
+""" + +from __future__ import annotations + +import contextlib +import dataclasses +import functools +import gc +import itertools +import operator +import sys +import threading +import traceback +import warnings +import weakref +from collections import defaultdict +from contextlib import AbstractContextManager +from enum import auto, Enum +from typing import Any, cast, Optional, TYPE_CHECKING, TypeVar, Union + +import torch.fx +from torch import Tensor +from torch._dynamo.callback import CallbackTrigger +from torch._dynamo.mutation_guard import GenerationTracker +from torch._dynamo.utils import counters, dynamo_timed, preserve_rng_state +from torch._inductor.compile_fx import ( + align_inputs_from_check_idxs, + copy_misaligned_inputs, + get_expanded_dims, + get_input_idxs_to_check, + index_expanded_dims, + remove_unaligned_input_idxs, + static_input, +) +from torch._inductor.cudagraph_utils import ( + check_for_mutation, + CheckInvariantStatus, + FunctionID, + log_cudagraph_skip_and_bump_counter, + log_data_ptr_mismatch, + maybe_warning_due_to_dynamic_shape, + ModelType, + OutputType, + PlaceholderInfo, + WrappedFunction, +) +from torch.multiprocessing.reductions import StorageWeakRef +from torch.storage import UntypedStorage +from torch.utils import _pytree as pytree +from torch.utils._ordered_set import OrderedSet +from torch.utils.weak import TensorWeakRef + + +if TYPE_CHECKING: + from collections.abc import Callable, Generator, Iterator, Sequence + + from torch._guards import CompileId + from torch._inductor.utils import InputType + from torch.cuda import _POOL_HANDLE + from torch.types import _bool + +StorageWeakRefPointer = int +StorageDataPtr = int +NBytes = int +S = TypeVar("S", bound="StorageWeakRefWrapper") + + +if torch.backends.cuda.is_built(): + from torch._C import ( + _cuda_CUDAAllocator_AllocatorState as AllocatorState, + _set_cached_tensors_enabled, + ) +else: + + class AllocatorState: # type: ignore[no-redef] + pass + + def 
_set_cached_tensors_enabled(enabled: _bool) -> None: + pass + + +log = torch._logging.getArtifactLogger(__name__, "cudagraphs") + + +from . import config + + +@dataclasses.dataclass(frozen=True) +class GraphID: + "Unique counter of a cuda graph recording" + + id: int + + +def clear_cublass_cache() -> None: + """ + Cublas keeps a persistent workspace allocation for running matmuls. This poses a problem for + doing warmup within a CUDAGraph private pool because we do not want persistent allocations from + one one run to the next. When we begin a new run of a cudagraphs path (generation), all tensors + from the previous generation are freed. This frees them the memory pool, but not elsewhere. + A tensor in the cublas workspace would continue to be in use the workspace but would also get allocated + in the next run. The memory would be in use in two places. + + To solve this, we clear cublas caches before and after warming up or recording. If a workspace is required + it will be allocated to the cudagraph private pool and accounted for in the allocator for the duration of the + program. There is no overhead to this on replay since cudagraphs removes allocation overhead. 
+ """ + torch._C._cuda_clearCublasWorkspaces() + + +@contextlib.contextmanager +def clear_cublas_manager() -> Generator[None, None, None]: + "Context manager around clearing cublas caches that will clear on enter and exit" + clear_cublass_cache() + try: + yield + finally: + clear_cublass_cache() + + +@contextlib.contextmanager +def disable_conv_cache_emptying() -> Generator[None, None, None]: + prev = torch._C._cuda_get_conv_benchmark_empty_cache() + torch._C._cudnn_set_conv_benchmark_empty_cache(False) + try: + yield + finally: + torch._C._cudnn_set_conv_benchmark_empty_cache(prev) + + +@contextlib.contextmanager +def enable_history_recording() -> Generator[None, None, None]: + "Turns on history recording in the CUDA Caching Allocator" + enabled = torch._C._cuda_isHistoryEnabled() + try: + if not enabled: + torch.cuda.memory._record_memory_history() + yield + finally: + if not enabled: + torch.cuda.memory._record_memory_history(None) + + +def get_history_recording() -> AbstractContextManager[None]: + # TODO - remove, prevents cleanup + if not config.triton.cudagraph_trees_history_recording: + return contextlib.nullcontext() + return enable_history_recording() + + +class TreeManagerContainer: + """ + Manages the lifetime of the tree manager. Like `PrivatePool` in cuda caching allocator, + the tree and its corresponding memory pool should be kept alive as long as any outstanding + graph or tensor which is an output of a graph remains alive. + + There is a single tree manager container per device. + + The lifecycle of a tree_manager is: + - Is constructed, no graph, no fns, no tensors + - Tree manager is fetched, resulting in tree manager being allocated + - We generate a bunch of functions, calling add_strong_reference + - These functions die, calling finalize_reference + - When all the functions die, we finalize_tree_manager. 
+ + TODO: in the future, we would like to do the following once storage weak refs land + - We look for all the live storages and add references to THOSE + - We count as storages die + - All the storages are dead, we deallocate the tree manager + """ + + def __init__(self, device_index: int) -> None: + # This class keeps a strong reference to tree_manager, + # but upon all other strong references to the tree_manager will reset it to None. + # We need a strong reference so that we can still access its attributes upon cleanup. + self.tree_manager: Optional[CUDAGraphTreeManager] = None + + # Number of outstanding references to the current tree manager + self.live_cudagraphify_fns = 0 + + self.device_index = device_index + + # Following two objects are only set in the case that Tensor outputs outlive + # the cudagraphify_fns. Reference to the Graph is needed to keep the private pool from + # deallocation. + self.live_storages_count = 0 + self.graph: Optional[torch.cuda.CUDAGraph] = None + + self.lock = threading.Lock() + + def _finalize_tensor(self) -> None: + with self.lock: + self.live_storages_count -= 1 + if self.live_storages_count == 0: + self.graph = None + + # manager was used again after existing cleanup, + # we shouldn't set it to None + if self.live_cudagraphify_fns == 0: + self.tree_manager = None + + def finalize_cudagraphify_fn(self) -> None: + with self.lock: + self.live_cudagraphify_fns -= 1 + if self.live_cudagraphify_fns == 0: + self._finalize_tree_manager() + + def _finalize_tree_manager(self) -> None: + assert self.lock.locked() + self.tree_manager = None + + # TODO - when issue #91395 is landed, we can set a weakref on + # storages and trigger a deallocation when all outputs of the + # cudagraph are dead. 
+ + # live_storages = list( + # tree_manager.live_cudagraph_pool_storages_in_curr_execution() + # ) + + # # Maintain reference to graph to keep tensors alive + # assert len(tree_manager.roots) > 0, "expected at least one use" + # root = next(tree_manager.get_roots()) + # self.graph = root.graph + # seen_storages = set() + # for stor in live_storages: + # if stor in seen_storages: + # continue + # seen_storages.add(stor) + # self.live_storages_count += 1 + # . weakref.finalize(stor, self._finalize_tensor) + + def add_strong_reference(self, fn: Callable[..., Any]) -> None: + with self.lock: + self.live_cudagraphify_fns += 1 + + weakref.finalize(fn, self.finalize_cudagraphify_fn) + + def get_tree_manager(self) -> CUDAGraphTreeManager: + with self.lock: + if self.tree_manager is None: + self.tree_manager = CUDAGraphTreeManager(self.device_index) + return self.tree_manager + + +local = threading.local() + +# one tree manager per device +local.tree_manager_containers = {} +local.tree_manager_locks = defaultdict(threading.Lock) + + +# only incremented by user call of mark_step_begin +class MarkStepBox: + mark_step_counter = 0 + + +# We need to register this as an object that will be copied over as TLS when new +# threads are created in autograd +torch._C._stash_obj_in_tls("tree_manager_containers", local.tree_manager_containers) +torch._C._stash_obj_in_tls("tree_manager_locks", local.tree_manager_locks) + + +def mark_step_begin() -> None: + "Indicates that a new iteration of inference or training is about to begin." 
+ + # iterate down to distinguish from GenerationTracking counter + MarkStepBox.mark_step_counter -= 1 + + +def reset_cudagraph_trees() -> None: + "Clear all cudagraph trees" + # see shutdown below for why this is necessary + container_dict = get_obj(local, "tree_manager_containers") + locks_dict = get_obj(local, "tree_manager_locks") + for device, lock in locks_dict.items(): + with lock: + container = container_dict.get(device) + if not container or not container.tree_manager: + continue + + container.tree_manager.shutdown() + + _set_cached_tensors_enabled(False) + container_dict.clear() + + MarkStepBox.mark_step_counter = 0 + + +def get_obj(local: Any, attr_name: str) -> Any: + if hasattr(local, attr_name): + return getattr(local, attr_name) + else: + assert torch._C._is_key_in_tls(attr_name) + return torch._C._get_obj_in_tls(attr_name) + + +def get_container(device_index: int) -> TreeManagerContainer: + container_dict = get_obj(local, "tree_manager_containers") + lock = get_obj(local, "tree_manager_locks")[device_index] + + with lock: + if device_index not in container_dict: + container_dict[device_index] = TreeManagerContainer(device_index) + + return container_dict[device_index] + + +def get_manager( + device_index: int, create_if_none_exists: bool = True +) -> Optional[CUDAGraphTreeManager]: + if create_if_none_exists: + return get_container(device_index).get_tree_manager() + return get_container(device_index).tree_manager + + +def is_cudagraph_capture_sizes(int_key: Union[int, tuple[int, ...]]) -> bool: + """ + Returns true if all dynamic shapes should be captured or the dynamic shape + int_key should be captured. 
+ """ + return ( + config.triton.cudagraph_capture_sizes is None + or int_key in config.triton.cudagraph_capture_sizes + ) + + +def cudagraphify_impl( + model: ModelType, + inputs: list[InputType], + static_input_idxs: Sequence[int], + *args: Any, + **kwargs: Any, +) -> ModelType: + fn_cache: dict[tuple[int, ...], Callable[..., Any]] = {} + + # Detect int inputs: we need to index on these + int_key = [i for i, v in enumerate(inputs) if isinstance(v, int)] + get_ints: Any = operator.itemgetter(*int_key) if int_key else lambda _: None + + has_warn = False + + del inputs + + def deferred_cudagraphify(inputs: list[InputType]) -> OutputType: + nonlocal has_warn + + int_key = get_ints(inputs) + + if not is_cudagraph_capture_sizes(int_key): + return model(inputs) + + fn = fn_cache.get(int_key) + if fn is not None: + return fn(inputs) + + if int_key is None: + log.info("recording cudagraph tree for graph without symints") + else: + log.info("recording cudagraph tree for symint key %s", int_key) + + if not has_warn: + has_warn = maybe_warning_due_to_dynamic_shape(fn_cache, int_key) + + # first get indices we need to check to align, then update our static inputs, + # and finally copy + check_input_idxs = get_input_idxs_to_check(inputs, static_input_idxs) + new_static_input_idxs = remove_unaligned_input_idxs(inputs, static_input_idxs) + copy_misaligned_inputs(inputs, check_input_idxs) + + fn, out = cudagraphify(model, inputs, new_static_input_idxs, *args, **kwargs) + # cudagraph will already clones input locally, no need to copy back + mutated_input_idxs: OrderedSet[int] = OrderedSet() + fn = align_inputs_from_check_idxs( + fn, inputs_to_check=check_input_idxs, mutated_input_idxs=mutated_input_idxs + ) + # pyrefly: ignore [unsupported-operation] + fn_cache[int_key] = fn + + return out + + return deferred_cudagraphify + + +@contextlib.contextmanager +def dynamo_timed_cudagraph( + name: str, + compile_id: Optional[CompileId], + mode: Optional[CompilationMode], +) -> 
Generator[Any, None, None]: + """ + Makes usages of dynamo_timed in this file less verbose. NOTE: This CM sums + all durations into a single column in the dynamo_compile table. Use only if + you consider the timed region to be part of the runtime overhead associated + with the compiler. + """ + with dynamo_timed( + name, + log_pt2_compile_event=True, + compile_id=compile_id, + is_backward=mode == CompilationMode.BACKWARD, + dynamo_compile_column_us="runtime_cudagraphify_time_us", + ): + yield + + +def cudagraphify( + model: ModelType, + inputs: list[InputType], + static_input_idxs: Sequence[int] = (), + *, + device_index: int, + is_backward: bool, + is_inference: bool, + stack_traces: Optional[StackTraces] = None, + constants: tuple[torch.Tensor, ...] = (), + placeholders: tuple[PlaceholderInfo, ...] = (), + mutated_input_idxs: tuple[int, ...] = (), + compile_id: Optional[CompileId] = None, +) -> tuple[ModelType, OutputType]: + assert not (is_backward and is_inference) + mode = ( + CompilationMode.BACKWARD + if is_backward + else (CompilationMode.INFERENCE if is_inference else CompilationMode.FORWARD) + ) + + with dynamo_timed_cudagraph("cudagraphify.get_container", compile_id, mode): + manager = get_container(device_index).get_tree_manager() + + return manager.add_function( + model, + inputs, + static_input_idxs, + stack_traces, + mode, + constants, + placeholders, + mutated_input_idxs, + compile_id, + ) + + +class StorageWeakRefWrapper: + """ + Wrapper around a storage weak ref. Will deallocate it upon expiration if invoked. + """ + + __slots__ = ["ref", "_data_ptr", "extra_ref_check"] + + storage_ref: Optional[StorageWeakRef] + + def __init__( + self, + inp: Union[Tensor, UntypedStorage], + extra_ref_check: Optional[Callable[[], bool]] = None, + ) -> None: + """ + extra_ref_check is an additional check we need to run to check if the + weak ref has expired. in checking storage use count we assume extra_ref_check + will hold an additional reference to the storage. 
+ """ + if isinstance(inp, Tensor): + stor = inp.untyped_storage() + else: + assert isinstance(inp, UntypedStorage) + stor = inp + self.ref = StorageWeakRef(stor) + self._data_ptr = stor.data_ptr() + self.extra_ref_check = extra_ref_check + + @classmethod + def from_weakref_and_data_ptr( + cls: type[StorageWeakRefWrapper], + cdata: Any, + data_ptr: int, + extra_ref_check: Optional[Callable[[], bool]] = None, + ) -> StorageWeakRefWrapper: + instance = cls.__new__(cls) + instance._data_ptr = data_ptr + instance.ref = StorageWeakRef.from_weakref(cdata) + instance.extra_ref_check = extra_ref_check + return instance + + def __call__(self) -> Optional[StorageWeakRefPointer]: + if self.expired(): + return None + + return self.ref.cdata + + def swap_weakref(self, cdata: Any) -> None: + self.ref.__del__() + self.ref.cdata = cdata + + def data_ptr(self) -> int: + "NB: returns the data ptr even if the storage has expired" + return self._data_ptr + + def remove_extra_reference(self) -> None: + self.extra_ref_check = None + + def expired(self) -> bool: + if self.extra_ref_check is not None and not self.extra_ref_check(): + return False + + stor_count = torch._C._storage_Use_Count(self.ref.cdata) + if self.extra_ref_check is not None: + # if extra_ref_check is not None we expect two additional references: + # - one from the Python storage object + # - one from the cached Tensor + stor_count -= 2 + assert stor_count >= 0 + return stor_count == 0 + + def __repr__(self) -> str: + if self.ref is None or self.ref.expired(): + return f"StorageWeakRefWrapper to {self.data_ptr()}; dead" + else: + return f"StorageWeakRefWrapper to {self.data_ptr()}; alive" + + +def is_live(weak_ref: Optional[StorageWeakRefWrapper]) -> bool: + return maybe_deref(weak_ref) is not None + + +def maybe_deref( + weak_ref: Optional[StorageWeakRefWrapper], +) -> Optional[tuple[StorageWeakRefPointer, int]]: + if weak_ref is None: + return None + r = weak_ref() + if r is None: + return None + # NB: r.data_ptr() 
does not necessarily equal weak_ref.data_ptr() + return r, weak_ref.data_ptr() + + +@contextlib.contextmanager +def _use_cuda_memory_pool_manager( + device: int, mem_pool: tuple[int, int], stream: torch.cuda.Stream +) -> Generator[None, None, None]: + """ + Context manager to use cuda graph pool for new allocations. If you use this manager + all cudagraph tensors in use should be reflected in the allocator or they will be overwritten. + existing_graph should already have been used in a capture, and the mem_pool must already exist, + because this manager will not preserve a reference to the pool which keeps it alive. + """ + torch.cuda.synchronize() + stream.wait_stream(torch.cuda.current_stream()) + + with torch.cuda.stream(stream), torch.device(device): + # Begin allocate to mem pool for all memory allocation on the current thread. + # This is thread safe since a thread can only warmup or record 1 cudagraph + # at the same time. + torch._C._cuda_beginAllocateCurrentThreadToPool(device, mem_pool) + try: + yield + finally: + torch._C._cuda_endAllocateToPool(device, mem_pool) + torch._C._cuda_releasePool(device, mem_pool) + + torch.cuda.current_stream().wait_stream(stream) + + +def map_to_ref(t: Optional[Tensor]) -> Optional[StorageWeakRefWrapper]: + if not isinstance(t, torch.Tensor): + assert t is None + return None + return StorageWeakRefWrapper(t) + + +# A path index of (depth, offset) indices into a graph that is `depth`` number of nodes from the root +# at graph output offset +PathOutputIndex = tuple[int, int] + +# For each node in the path, for each output, is the output alive +PathLiveness = list[list[bool]] + +StackTraces = list[Optional[str]] + + +class CUDAWarmupNode: + """ + Simplified Wrapper around A CUDA Model that wraps outputs in storage refs and exposes + apis to get the live storages in the current chain of warmup. 
+ + A CUDAWarmupNode may have either CUDAGraphNode or CUDAWarmupNode as a parent, but may only have + CUDAWarmupNode as children, because we cannot record or execute with tensors which do not have stable + memory addresses. + + CUDAWarmupNode and CUDAGraphNode have a number of differences that make it easier to use separate classes. + - Much of the CUDAGraphNode logic & initialization is based on the tensor properties of first recording. In the + first instance of warmup, these are not finalized yet. + - All Inputs to the RecordedFunction must be copied over to the cuda graph memory pool, this is unnecessary in warmup. + - CUDAWarmup is only used once and so does not need to optimize as much bookkeeping. It is much simpler. + + NB: this class and CUDAGraphNode need to expose `path_live_weakrefs`, `all_outputs_are_dead`, and + `self.outputs_weakrefs`, `stack_traces`, and `tensor_weakrefs` for compatibility. + """ + + def __init__( + self, + wrapped_function: WrappedFunction, + parent: Optional[Union[CUDAGraphNode, CUDAWarmupNode]], + cuda_graphs_pool: tuple[int, int], + existing_cuda_graph: Optional[torch.cuda.CUDAGraph], + device_index: int, + stack_traces: Optional[StackTraces], + stream: torch.cuda.Stream, + already_warm: bool, + id: GraphID, + ) -> None: + self.wrapped_function = wrapped_function + self.parent: Optional[Union[CUDAGraphNode, CUDAWarmupNode]] = parent + self.cuda_graphs_pool = cuda_graphs_pool + self.outputs_weakrefs: list[Optional[StorageWeakRefWrapper]] = [] + self.tensor_weakrefs: list[Optional[TensorWeakRef]] = [] + self.existing_cuda_graph = existing_cuda_graph + self.has_run = False + self.device_index = device_index + self.stack_traces = stack_traces + self.stream = stream + self.already_warm = already_warm + self.id = id + + def run(self, new_inputs: Any) -> OutputType: + assert not self.has_run, "Wrapped function should never be run twice" + + # See: output_is_alias_of_persistent_static_inputs below. 
We should only be returning freshly created + # storages in path_live_weakrefs. + existing_path_data_ptrs = OrderedSet( + [t.data_ptr() for t in self.path_live_weakrefs() if t()] + ) + + def get_non_cudagraph_inps() -> list[weakref.ReferenceType[UntypedStorage]]: + non_cudagraph_inps = [ + weakref.ref(t.untyped_storage()) + for t in itertools.chain(new_inputs, self.wrapped_function.constants) + if isinstance(t, torch.Tensor) + and t.untyped_storage().data_ptr() not in existing_path_data_ptrs + ] + return non_cudagraph_inps + + non_cudagraph_inps_storages = get_non_cudagraph_inps() + + if config.triton.slow_path_cudagraph_asserts and not self.already_warm: + refs = list(self.path_live_weakrefs()) + check_memory_pool(self.device_index, self.cuda_graphs_pool, refs) + + with ( + torch.cuda.device(self.device_index), + disable_conv_cache_emptying(), + clear_cublas_manager(), + _use_cuda_memory_pool_manager( + self.device_index, self.cuda_graphs_pool, self.stream + ), + get_history_recording(), + ): + out = self.wrapped_function.model(new_inputs) + + # We need to know which outputs are allocated within the cudagraph pool + # so that we can deallocate them at the beginning of the next cudagraph step, + # and set their access to error. + # We use a weakref to the inputs storage, in case a block which was previously + # allocated to the general caching allocator pool gets reallocated to a private pool. 
+ + non_cudagraph_inps_storage_ptrs = OrderedSet[Any]() + for storage in non_cudagraph_inps_storages: + s = storage() + if s is not None: + non_cudagraph_inps_storage_ptrs.add(s._cdata) + + assert len(new_inputs) == 0 + + # sdpa returns cpu tensors when not recording cuda graph + def add_ref(o: Any) -> bool: + return ( + isinstance(o, torch.Tensor) + and o.is_cuda + and o.untyped_storage()._cdata not in non_cudagraph_inps_storage_ptrs + and o.untyped_storage().data_ptr() != 0 + ) + + self.outputs_weakrefs.extend( + [map_to_ref(o) if add_ref(o) else None for o in out] + ) + self.tensor_weakrefs.extend( + [TensorWeakRef(o) if add_ref(o) else None for o in out] + ) + + if config.triton.slow_path_cudagraph_asserts and not self.already_warm: + out_refs = list(self.path_live_weakrefs()) + check_memory_pool(self.device_index, self.cuda_graphs_pool, out_refs) + + return out + + @property + def _path_from_root( + self, + ) -> Generator[Union[CUDAGraphNode, CUDAWarmupNode], None, None]: + nodes = [] + node: Union[CUDAGraphNode, CUDAWarmupNode] = self + while node: + nodes.append(node) + node = node.parent # type: ignore[assignment] + + yield from reversed(nodes) + + def path_live_weakrefs(self) -> Iterator[StorageWeakRefWrapper]: + "Returns all live storages weakrefs that created by nodes in this path" + for node in self._path_from_root: + for output in node.outputs_weakrefs: + if is_live(output): + yield output # type: ignore[misc] + + def all_outputs_are_dead(self) -> bool: + return not list(self.path_live_weakrefs()) + + def _is_cuda_graph_recorded_tensor(self, t: torch.Tensor) -> bool: + for storage_weak_ref in self.path_live_weakrefs(): + if t.untyped_storage().data_ptr() == storage_weak_ref.data_ptr(): + return True + return False + + +# Aliases for List that say what the indices denote +InputList = list # input indexes +OutputList = list # output indexes +LevelList = list # levels (distance from root of tree) + + +class OutputAliasInfo: + __slots__ = [] + + +class 
_UnaliasedStorage(OutputAliasInfo): + "Singleton to mark that the graph output constructs a new alias or is None" + + +UnaliasedStorage = _UnaliasedStorage() + + +class AliasesPriorGraphOutput(OutputAliasInfo): + "Marks that the graph output aliases an output of a prior graph" + + __slots__ = ["index"] + + index: PathOutputIndex + + def __init__(self, index: PathOutputIndex) -> None: + assert isinstance(index, tuple) + self.index = index + + +class AliasesNewOutput(OutputAliasInfo): + "Marks that the graph output aliases an index in the new, returned outputs" + + __slots__ = ["index"] + + index: int + + def __init__(self, index: int) -> None: + assert isinstance(index, int) + self.index = index + + +class CUDAGraphNode: + """ + A single recording of a function into a CUDA Graph. Recordings of CUDA Graphs share a single memory pool + and are structured into a tree, where there is a single recording that can precede it (parent) and multiple + subsequent recordings that may follow (children). A node will have no parent if it is the first recording + in a tree; i.e., when it is first recorded, there are no live tensors from a previous recording which + would force a dependency. + + On first recording, all of the live tensors in the current CUDA Graph Node path will be + reflected in the corresponding private pool. On subsequent executions, the caching allocator + is unaffected when the graph is replayed. + + In order to support recording a subsequent cuda graph recording after execution of this graph, + we checkpoint the state of the memory pool so that it may later be resumed. + + WrappedFunction should have already been warmed up prior to invocation. 
+ + See [setCheckpointPoolState] for further explanation, as well as + https://user-images.githubusercontent.com/13564/222815509-374f3400-f83d-4f7d-8fa6-4a092b3250bb.png + """ + + def __init__( + self, + wrapped_function: WrappedFunction, + id: GraphID, + parent: Optional[CUDAGraphNode], + inputs: list[InputType], + cuda_graphs_pool: _POOL_HANDLE, + device_index: int, + stack_traces: Optional[StackTraces], + stream: torch.cuda.Stream, + mode: Optional[CompilationMode], + compile_id: Optional[CompileId], + ) -> None: + assert isinstance(inputs, (list, tuple)) + + self.wrapped_function = wrapped_function + self.id = id + self.device = device_index + self.stack_traces = stack_traces + self.stream = stream + + # Enable re-record a cudagraph when static tensor address changed. + # if not we should error when it changed. + self.rerecord_if_static_inputs_change = ( + torch._dynamo.config.inline_inbuilt_nn_modules + or torch._inductor.config.triton.cudagraph_support_input_mutation + ) + + # if this is a root parent will be None. use weakref to prevent reference cycle + self._parent = weakref.ref(parent) if parent is not None else None + # reference to the shared memory pool for the entire cuda graphs tree + self.cuda_graphs_pool = cuda_graphs_pool + + # A single wrapped function may be recorded multiple times if memory patterns or + # invariants change from one execution to the next + self.children: dict[FunctionID, list[CUDAGraphNode]] = defaultdict(list) + + # StorageWeakRef maintains whether the Storage C++ object remains allocated, + # not whether the corresponding memory has been deallocated. In order + # to use them to track memory deallocations we must maintain a single StorageWeakRef + # for all Storages that reference that memory (even if we are constructing Storages + # that do not have a deallocator function). We maintain one single storage_cache + # as we execute any tree path. 
When we retrieve a storage from the cache we + # check that it is still alive, and we hash based on observed recording data ptr + # and storage cdata. + + # we preserve a single reference to executed outputs that is then referenced + # in children to avoid children having to chase parent pointers in the hot path + # DO NOT reassign output_weakrefs, only call `clear()` + # Path is a series of nodes from root to the current node + self.outputs_weakrefs: OutputList[Optional[StorageWeakRefWrapper]] = [] + self.path_weakrefs: LevelList[OutputList[Optional[StorageWeakRefWrapper]]] = [ + node.outputs_weakrefs for node in self._path_from_root + ] + self.path_stacktraces: LevelList[Optional[StackTraces]] = [ + node.stack_traces for node in self._path_from_root + ] + self.tensor_weakrefs: OutputList[Optional[TensorWeakRef]] = [] + + # tensors which are outputs of previous graphs in the tree + self.cudagraph_managed_idxs: list[int] = [ + idx + for idx, t in enumerate(inputs) + if isinstance(t, torch.Tensor) and self._is_cuda_graph_recorded_tensor(t) + ] + + # (depth, offset) of live tensors which are alias of previous graph outputs + self.live_cudagraph_managed_path_refs: InputList[Optional[PathOutputIndex]] = [ + ( + self._is_alias_of_live_recorded_tensor(t) + if isinstance(t, torch.Tensor) + else None + ) + for t in inputs + ] + + # when replay, preserve the liveness of an input if it AliasesPriorGraphOutput + # and also aliases an output of the current CUDAGraphNode + self.preserved_aliased_inputs: InputList[bool] = [False] * len(inputs) + + self.static_input_idxs: list[int] = list( + OrderedSet(wrapped_function.static_input_idxs) + | OrderedSet(self.cudagraph_managed_idxs) + ) + + self.non_static_input_idx: LevelList[int] = [ + i for i in range(len(inputs)) if i not in self.static_input_idxs + ] + + counters["inductor"]["cudagraph_recorded_non_static_inputs"] += len( + self.non_static_input_idx + ) + + self.non_managed_static_input_idxs: LevelList[int] = [ + i + for i in 
wrapped_function.static_input_idxs + if i not in self.cudagraph_managed_idxs + ] + + def maybe_get_static_data_ptr( + idx: int, + inputs: list[InputType], + static_input_idxs: list[int], + ) -> Optional[int]: + inp = inputs[idx] + if isinstance(inp, torch.Tensor) and idx in static_input_idxs: + return inp.data_ptr() + return None + + self.static_input_data_ptrs: InputList[Optional[int]] = [ + # pyrefly: ignore [bad-argument-type] + maybe_get_static_data_ptr(i, inputs, self.static_input_idxs) + for i in range(len(inputs)) + ] + + # When we checkpoint, and free generations, we will be manually freeing the outputs + # of CUDAGraphNodes. We should not be freeing parameters, not do we need to account for + # their liveness (they are static), so we need to compute which outputs are aliases of + # parameters. Some static inputs are saved tensors from the forward that die in the backward. + # Their locations are static but lifetimes are not. We only include the persistent static + # data ptrs below because the non persistent data ptrs may be outputs of this record and + # fresh allocations. + + # precompute expanded dims to avoid computing in the hot path + self.expanded_dims: list[list[int]] = [ + get_expanded_dims(x) + if isinstance(x, torch.Tensor) and idx not in self.static_input_idxs + else [] + for idx, x in enumerate(inputs) + ] + + # For each node in path, which outputs were observed to be live + # before invoking graph recording, and after graph recording + self.recorded_liveness_before_graph: LevelList[OutputList[bool]] = [] + self.recorded_liveness_after_graph: LevelList[OutputList[bool]] = [] + + # List of tuples of (depth, output_index) that index into node at depth + # number of nodes from root and output_index of outputs. Will index into + # path_weakrefs. 
+ self.expected_dead_indices_before_graph: list[PathOutputIndex] = [] + self.expected_dead_indices_after_graph: list[PathOutputIndex] = [] + + # all live indices after graph recording + self.live_indices_after_graph: list[PathOutputIndex] = [] + + if self.parent is not None: + previous_liveness = self.parent.recorded_liveness_after_graph + curr_liveness = self._get_liveness(self.path_weakrefs) + + different_indices = self._get_different_indices( + previous_liveness, curr_liveness + ) + + self.recorded_liveness_before_graph = curr_liveness + self.expected_dead_indices_before_graph = different_indices + + rng_states = [inp for inp in inputs if isinstance(inp, torch.Generator)] + # pyrefly: ignore [bad-argument-type] + recording_inputs = self._allocate_and_copy_recording_inputs(inputs) + # recording inputs will copy over memory, so we can free non recording inputs + # pyrefly: ignore [missing-attribute] + inputs.clear() + del inputs + + # graph used for recording model invocation + self.graph: Optional[torch.cuda.CUDAGraph] = torch.cuda.CUDAGraph() + + # TODO: register_generator_state should potentially take explicit device + with torch.cuda.device(self.device): + for rng_state in rng_states: + self.graph.register_generator_state(rng_state) + + # we allocate non-static inputs within the same memory pool as the CUDAGraph + # which we will record the model with. For memory efficiency, it is important + # to reclaim the input memory when the inputs are no longer live. To accomplish this, + # we reconstruct tensors at the correct data pointers of our inputs which are + # non owning and do not prevent deallocation. On subsequent executions, input values + # will be copied over to these tensors. + self.reconstructed_inputs: list[InputType] = [ + self._reconstruct_from_tensor_metadata(self._tensor_metadata(x)) + if isinstance(x, torch.Tensor) + else x + for x in recording_inputs + ] + + # DO THE RECORDING!!! 
+ # We record the CUDA graph in the constructor of CUDAGraphNode, which + # gives you what the CPU side compute of the function would do. We + # don't throw the recording outputs away: their memory is + # correctly accounted for in the CUDAGraphs caching allocator. This + # means on the very FIRST run of the CUDA graph node, we can directly + # do more recording, because we have a valid caching allocator state. + # NB: This relies on run() being called immediately after the + # constructor, otherwise this optimization would not be valid. + + # initialized below in _record + + self.checkpointed_caching_state: Optional[AllocatorState] = None + + # Output Storage Alias information, can be: + # - A new, unaliased storage, or the output is None + # - An alias of an output of a prior graph + # - An alias of an output already created in the reconstructed outputs + # This is None if the output in question is an int + self.output_storage_alias: OutputList[Optional[OutputAliasInfo]] = [] + + # is the output Storage unaliased in subsequent outputs, of all subsequent paths + # if it is, we cached the output tensor and adjust storage liveness tracking to also + # check if the output tensor does not have an additional python reference. + # If a descendent node discovers it has an alias of a prior output, then the output + # will no longer be cached in the ancestor. + # The large majority of tensors are unaliased, and preserving aliased output tensors would add + # significant additional complexity with marginal gains + # The cached tensor outputs are added on the first execution, and cleared whenever we need + # to do subsequent recording + self.unaliased_in_all_paths: OutputList[bool] = [] + self.cached_tensor_outputs: OutputList[Optional[Tensor]] = [] + + # if an output aliases a static, persistent input then the corresponding Tensor will + # be set here. These are different than cached tensors, because they are tensors that + # are aliases of parameters that are always live. 
+ self.static_output_tensors: OutputList[Optional[Tensor]] = [] + + # Cleared after recording + with dynamo_timed_cudagraph("CUDAGraphNode.record", compile_id, mode): + self.recording_outputs: Optional[OutputType] = self._record( + wrapped_function.model, recording_inputs + ) + self.outputs_metadata: OutputList[Union[dict[str, Any], int, None]] = [] + + # As with inputs, we do not want to keep the outputs permanently alive because that would prevent + # their memory being reclaimed in subsequent cuda graph recordings. We record the tensor metadata + # needed to reconstruct instead. + assert self.recording_outputs is not None + for out in self.recording_outputs: + if isinstance(out, torch.Tensor): + self.outputs_metadata.append( + self._tensor_metadata(out, ignore_storage_offset=False) + ) + else: + assert isinstance(out, (int, type(None))), type(out) + self.outputs_metadata.append(out) + + self.graph.replay() + + def _copy_inputs_and_remove_from_src( + self, dsts: list[InputType], srcs: list[InputType] + ) -> None: + dst_tensors = [] + src_tensors = [] + for idx in self.non_static_input_idx: + if not isinstance(srcs[idx], torch.Tensor): + continue + expanded_dims = self.expanded_dims[idx] + dst_tensors.append(index_expanded_dims(dsts[idx], expanded_dims)) # type: ignore[arg-type] + src_tensors.append(index_expanded_dims(srcs[idx], expanded_dims)) # type: ignore[arg-type] + srcs[idx] = None # type: ignore[call-overload] + # Fails on empty lists + if dst_tensors: + torch._foreach_copy_(dst_tensors, src_tensors) + + def check_static_inputs_are_stable(self, new_inputs: list[InputType]) -> None: + # avoid checking managed tensor static points since we already checked those in check_invariants + if ( + not self.rerecord_if_static_inputs_change + and not torch._C._tensors_data_ptrs_at_indices_equal( + new_inputs, # type: ignore[arg-type] + self.static_input_data_ptrs, + self.non_managed_static_input_idxs, + ) + ): + # this should error + error_msg = log_data_ptr_mismatch( 
+ self.wrapped_function.placeholders, + new_inputs, + self.static_input_data_ptrs, + self.non_managed_static_input_idxs, + CheckInvariantStatus.StaticInputIdxMismatch, + ) + torch._check(False, lambda: error_msg) + + def run_first_inputs(self, new_inputs: list[InputType]) -> OutputType: + if config.triton.fast_path_cudagraph_asserts: + self.debug_check_invariants_before_invocation() + + # graph is already invoked in the __init__ + # inputs are copied over in _allocate_recording_inputs and subsequently cleared + assert len(new_inputs) == 0 + outputs = self.recording_outputs + self.recording_outputs = None + assert outputs is not None + return outputs + + def run(self, new_inputs: list[InputType]) -> OutputType: + self.check_static_inputs_are_stable(new_inputs) + + self._copy_inputs_and_remove_from_src(self.reconstructed_inputs, new_inputs) + + self.run_graph() + + outputs = self.reconstruct_outputs() + new_inputs.clear() + + if config.triton.fast_path_cudagraph_asserts: + self.debug_check_invariants_after_invocation() + + if config.triton.force_cudagraph_sync: + torch.cuda.synchronize() + + # Reset this to run the check in the future + self.static_inputs_stable = False + + return outputs + + def reconstruct_outputs(self) -> OutputType: + "Reconstruct output tensors according to their saved metadata and alias information" + + # Cached tensors will not yet be set on the first execution + # They are also cleared in checkpointing, so if we checkpoint this node + # and then execute it again we will need to repopulate cached tensors + if not self.cached_tensor_outputs: + self._initialize_cached_tensors() + + outputs: OutputType = [] + + for i, (storage_info, metadata) in enumerate( + zip(self.output_storage_alias, self.outputs_metadata) + ): + if not isinstance(metadata, dict): # tensor metadata + assert isinstance(metadata, (int, type(None))) + outputs.append(metadata) + continue + + cached_t = self.cached_tensor_outputs[i] + if cached_t is not None: + # this output 
represents a fresh allocated tensor. + # We return the same TensorImpl from run to run to avoid overhead. + # autograd.Function will reset the Autograd meta of output tensors + # as part of aot_autograd, but _backward_hooks are stored on tensors separately, + # so we need to manually reset hooks. + if cached_t._backward_hooks is not None: + cached_t._backward_hooks = None + + # No need to update weakrefs, already correctly initialized + outputs.append(cached_t) + continue + + static_t = self.static_output_tensors[i] + if static_t is not None: + assert self.outputs_weakrefs[i] is None + outputs.append(static_t) + continue + + storage = self.prepare_alias_info_for_tensor_construction( + storage_info, metadata + ) + + if isinstance(storage, UntypedStorage) or storage is None: + out = self._reconstruct_from_tensor_metadata(metadata, storage) + else: + assert isinstance(storage, int) + out = self._reconstruct_from_tensor_metadata( + metadata, cast(torch.Tensor, outputs[storage]).untyped_storage() + ) + + outputs.append(out) + w = self.outputs_weakrefs[i] + assert w is not None + w.swap_weakref(out.untyped_storage()._weak_ref()) + + return outputs + + def prepare_alias_info_for_tensor_construction( + self, + out_alias_info: Optional[OutputAliasInfo], + metadata: Union[dict[str, Any], int, None], + ) -> Union[UntypedStorage, None, int]: + if ( + isinstance(metadata, (int, type(None))) + or out_alias_info is UnaliasedStorage + ): + return None + + if isinstance(out_alias_info, AliasesPriorGraphOutput): + depth, existing_output_index = out_alias_info.index + ref = self.path_weakrefs[depth][existing_output_index] + assert ref is not None + return torch.UntypedStorage._new_with_weak_ptr(ref()) + + assert isinstance(out_alias_info, AliasesNewOutput) + return out_alias_info.index + + def prepare_storages_for_construction( + self, + ) -> list[Union[UntypedStorage, None, int]]: + output_storages = [] + for output_storage_alias, metadata in zip( + self.output_storage_alias, 
self.outputs_metadata + ): + output_storages.append( + self.prepare_alias_info_for_tensor_construction( + output_storage_alias, metadata + ) + ) + + return output_storages + + def run_graph(self) -> None: + assert self.graph is not None + self.graph.replay() + + def all_outputs_are_dead(self) -> bool: + "All outputs of the path from this node to its root are dead" + for depth, output_index in self.live_indices_after_graph: + if is_live(self.path_weakrefs[depth][output_index]): + return False + return True + + def _record(self, model: ModelType, inputs: list[InputType]) -> OutputType: + "Record the model" + assert self.graph is not None + + def static_input_iter() -> Generator[torch.Tensor, None, None]: + for i in self.wrapped_function.static_input_idxs: + _inp = inputs[i] + if isinstance( + _inp, torch.Tensor + ) and not self._is_cuda_graph_recorded_tensor(_inp): + yield _inp + + # see: output_is_alias_of_persistent_static_inputs above + static_input_persistent_storage_ptrs: dict[int, StorageWeakRefWrapper] = { + inp.untyped_storage().data_ptr(): StorageWeakRefWrapper(inp) + for inp in itertools.chain( + static_input_iter(), self.wrapped_function.constants + ) + } + + if config.triton.slow_path_cudagraph_asserts: + # need to use parent live weakrefs because live_indices isn't set yet + memory = ( + [] if self.parent is None else list(self.parent.path_live_weakrefs()) + ) + memory += [ + StorageWeakRefWrapper(elem) + for i, elem in enumerate(inputs) + if isinstance(elem, torch.Tensor) + and i not in self.wrapped_function.static_input_idxs + and elem.untyped_storage().data_ptr() != 0 + ] + check_memory_pool(self.device, self.cuda_graphs_pool, memory) + + with ( + preserve_rng_state(), + torch.cuda.device(self.device), + clear_cublas_manager(), + torch.cuda.graph( + self.graph, + stream=self.stream, + pool=self.cuda_graphs_pool, + capture_error_mode="thread_local", + ), + get_history_recording(), + ): + static_outputs = model(inputs) + + # running model should reclaim 
memory + assert len(inputs) == 0 + + if not isinstance(static_outputs, (list, tuple)): + static_outputs = (static_outputs,) + + # pyrefly: ignore [bad-argument-type] + self._add_first_outputs(static_outputs, static_input_persistent_storage_ptrs) + + # pyrefly: ignore [bad-return] + return static_outputs + + def _add_first_outputs( + self, + outputs: OutputType, + static_input_persistent_storage_ptrs: dict[int, StorageWeakRefWrapper], + ) -> None: + "Add the outputs from the first invocation of the node and set up metadata" + + # getting liveness before we have added the outputs to path, so the length + # of the two lists is equal + prev_liveness = self.recorded_liveness_before_graph + curr_liveness = self._get_liveness(self.path_weakrefs) + + delta = self._get_different_indices(prev_liveness, curr_liveness) + self.expected_dead_indices_after_graph = delta + + assert len(self.outputs_weakrefs) == 0 + # index from data pointer to index in outputs + output_new_storages_index: dict[StorageDataPtr, int] = {} + + self.unaliased_in_all_paths = [False for _ in range(len(outputs))] + self.static_output_tensors = [None for _ in range(len(outputs))] + + for i, o in enumerate(outputs): + if o is None or not isinstance(o, torch.Tensor): + self.output_storage_alias.append(UnaliasedStorage) + continue + + torch._check( + o.is_cuda or o.untyped_storage().data_ptr() == 0, + lambda: ( + "Expected all cuda outputs in cuda graph recording. 
Non cuda output " + f"from {self.stack_traces[i] if self.stack_traces else '(unknown)'}" + ), + ) + + ref = static_input_persistent_storage_ptrs.get( + o.untyped_storage().data_ptr(), None + ) + # also treat empty storages as static outputs because we do not need to manage their lifetime + # and they should not participate in checkpointing + is_empty_storage = o.untyped_storage().data_ptr() == 0 + if (ref and ref() is not None) or is_empty_storage: + self.output_storage_alias.append(None) + self.static_output_tensors[i] = o + continue + + path_ref = self._is_alias_of_live_recorded_tensor(o) + if path_ref is not None: + self._mark_prior_graph_output_as_aliased(path_ref) + + for idx, inp_path_ref in enumerate( + self.live_cudagraph_managed_path_refs + ): + if path_ref == inp_path_ref: + self.preserved_aliased_inputs[idx] = True + self.output_storage_alias.append(AliasesPriorGraphOutput(path_ref)) + continue + + if o.untyped_storage().data_ptr() in output_new_storages_index: + index = output_new_storages_index[o.untyped_storage().data_ptr()] + self.unaliased_in_all_paths[index] = False + self.output_storage_alias.append(AliasesNewOutput(index)) + continue + + output_new_storages_index[o.untyped_storage().data_ptr()] = i + self.output_storage_alias.append(UnaliasedStorage) + self.unaliased_in_all_paths[i] = True + + if self.stack_traces is None: + self.stack_traces = [None for _ in range(len(outputs))] + else: + assert len(self.stack_traces) == len(outputs), ( + "Wrong number of stack traces passed in" + ) + + assert not self.outputs_weakrefs + for out, static_output_tensor in zip(outputs, self.static_output_tensors): + if not isinstance(out, torch.Tensor) or static_output_tensor is not None: + self.outputs_weakrefs.append(None) + self.tensor_weakrefs.append(None) + else: + self.outputs_weakrefs.append(StorageWeakRefWrapper(out)) + self.tensor_weakrefs.append(TensorWeakRef(out)) + + self.recorded_liveness_after_graph = self._get_liveness(self.path_weakrefs) + 
self.checkpointed_caching_state = torch._C._cuda_getCheckpointState( + self.device, self.cuda_graphs_pool + ) + + # now, get liveness with outputs added + for depth in range(len(self.path_weakrefs)): + for output_index in range(len(self.path_weakrefs[depth])): + if is_live(self.path_weakrefs[depth][output_index]): + self.live_indices_after_graph.append((depth, output_index)) + + self.debug_check_invariants_after_invocation() + if config.triton.slow_path_cudagraph_asserts: + check_memory_pool( + self.device, self.cuda_graphs_pool, list(self.path_live_weakrefs()) + ) + + def _mark_prior_graph_output_as_aliased(self, index: PathOutputIndex) -> None: + "Remove a graph output from the unaliased, cached tensors in an ancestor node" + depth, output_index = index + node = list(self._path_from_root)[depth] + node.unaliased_in_all_paths[output_index] = False + x = self.path_weakrefs[depth][output_index] + assert x is not None + x.remove_extra_reference() + + def _initialize_cached_tensors(self) -> None: + # we should not be clearing output_weakrefs, and they should be set in the first + # record run + assert len(self.outputs_weakrefs) == len(self.outputs_metadata) + + for i, (storage_info, metadata, make_cached) in enumerate( + zip( + self.output_storage_alias, + self.outputs_metadata, + self.unaliased_in_all_paths, + ) + ): + if not make_cached: + self.cached_tensor_outputs.append(None) + continue + + assert storage_info is UnaliasedStorage + assert isinstance(metadata, dict) + s = self.create_storage(metadata) + out = self._reconstruct_from_tensor_metadata(metadata, storage=s) # type: ignore[arg-type] + + # XXX: let autograd know that there will be an additional reference to the tensor + # that can be ignored when deciding whether to do gradient buffer inplacing. + # Otherwise, inplacing could differ between tracing and subsequent execution. + # For some models we tested this led to inputs no longer being in cudagraph pools, + # leading to spurious re-recordings. 
+ # It also tells AMP cache that even though the tensor impls cannot be cached + # in dtype conversions. + + torch._C._add_cached_tensor(out) + + self_ref = weakref.ref(self) + + # one reference in our array, and calling sys.getrefcount bumps the refcount by one + def check_refcount(i: int) -> bool: + self_loc = self_ref() + if self_loc is None: + return False + refcount = self_loc.get_output_refcount(i) + # pyrefly: ignore + if self_loc.cached_tensor_outputs[i]._use_count() > 1: + # c10::Tensor may also holds one reference count + assert refcount >= 3 + return refcount == 3 + else: + assert refcount >= 2 + return refcount == 2 + + check = functools.partial(check_refcount, i=i) + + self.outputs_weakrefs[i] = StorageWeakRefWrapper(out, extra_ref_check=check) + self.cached_tensor_outputs.append(out) + + def get_output_refcount(self, index: int) -> int: + return sys.getrefcount(self.cached_tensor_outputs[index]) + + @property + def parent(self) -> Optional[CUDAGraphNode]: + "unwraps the weakref to _parent" + return self._parent() if self._parent is not None else None + + @property + def _path_to_root(self) -> Generator[CUDAGraphNode, None, None]: + "Returns all nodes in the path starting at self and ending at root" + node = self + while node: + yield node + node = node.parent # type: ignore[assignment] + + @property + def _path_from_root(self) -> Generator[CUDAGraphNode, None, None]: + "Returns all nodes in the path starting at the root and ending at self" + nodes = reversed(list(self._path_to_root)) + yield from nodes + + def _is_cuda_graph_recorded_tensor(self, t: torch.Tensor) -> bool: + "Is this tensor an output of a node in this path" + for output_refs in self.path_weakrefs: + for storage_weak_ref in output_refs: + if storage_weak_ref is None: + continue + # don't need to check liveness of storage since the cuda graph managed + # memory is never released. 
+ data_ptr = storage_weak_ref.data_ptr() + if t.untyped_storage().data_ptr() == data_ptr: + return True + + return False + + def _is_alias_of_live_recorded_tensor( + self, t: torch.Tensor + ) -> Optional[PathOutputIndex]: + for depth, output_refs in enumerate(self.path_weakrefs): + for output_index, storage_ref in enumerate(output_refs): + if (storage_and_ptr := maybe_deref(storage_ref)) is not None: + _storage, ptr = storage_and_ptr + if ptr == t.untyped_storage().data_ptr(): + return (depth, output_index) + + return None + + @staticmethod + def _check_liveness( + indices: list[PathOutputIndex], + output_refs: list[list[Optional[StorageWeakRefWrapper]]], + ) -> bool: + "Check that all of the indices specified are dead references" + for depth, output_index in indices: + w = output_refs[depth][output_index] + assert w is not None + if w() is not None: + return False + return True + + def add_child(self, function_id: FunctionID, node: CUDAGraphNode) -> None: + "Adds node as a a child of self" + self.children[function_id].append(node) + + @staticmethod + def _get_different_indices( + prev: list[list[bool]], curr: list[list[bool]] + ) -> list[PathOutputIndex]: + "Find indices where the two lists differ." 
+ dead_indices = [] + assert len(prev) <= len(curr) + for i, (outputs1, outputs2) in enumerate(zip(prev, curr)): + assert len(outputs1) == len(outputs2) + for j, (output1, output2) in enumerate(zip(outputs1, outputs2)): + if output1 != output2: + dead_indices.append((i, j)) + + return dead_indices + + @staticmethod + def _get_liveness( + weakrefs: list[list[Optional[StorageWeakRefWrapper]]], + ) -> list[list[bool]]: + "Maps weakrefs to true if the reference is alive and false otherwise" + if len(weakrefs) == 0: + return [] + + return [pytree.tree_map(is_live, outputs) for outputs in weakrefs] + + def debug_assert_invariants( + self, expected_liveness: list[list[bool]], newly_dead: list[PathOutputIndex] + ) -> None: + if not config.triton.fast_path_cudagraph_asserts: + return + + for i, node in enumerate(self._path_from_root): + assert self.path_weakrefs[i] is node.outputs_weakrefs + + nodes = list(self._path_from_root) + + live_blocks = get_block_addrs(self.cuda_graphs_pool) + + live_storage_data_ptrs = OrderedSet[Any]() + live_storage_weak_ptrs = OrderedSet[Any]() + + for depth, outputs_liveness in enumerate(expected_liveness): + for output_idx, output_liveness in enumerate(outputs_liveness): + # tensor can die early, but it can't be alive when it should be dead + w = self.path_weakrefs[depth][output_idx] + if (stor_weak_ptr_and_data_ptr := maybe_deref(w)) is not None: + assert output_liveness + stor_weak_ptr, stor_data_ptr = stor_weak_ptr_and_data_ptr + assert (stor_data_ptr in live_storage_data_ptrs) == ( + stor_weak_ptr in live_storage_weak_ptrs + ) + live_storage_data_ptrs.add(stor_data_ptr) + live_storage_weak_ptrs.add(stor_weak_ptr) + + is_persistent_alias = ( + nodes[depth].static_output_tensors[output_idx] is not None + ) + + if is_persistent_alias: + assert stor_data_ptr not in live_blocks + + for depth, output_index in newly_dead: + assert not is_live(self.path_weakrefs[depth][output_index]) + + def debug_check_invariants_before_invocation(self) -> None: 
+ self.debug_assert_invariants( + self.recorded_liveness_before_graph, self.expected_dead_indices_before_graph + ) + + def debug_check_invariants_after_invocation(self) -> None: + self.debug_assert_invariants( + self.recorded_liveness_before_graph, self.expected_dead_indices_after_graph + ) + + def data_ptrs_dead_since_invocation(self) -> list[int]: + """ + Since this node was invoked, return data ptrs of all tensor outputs that have died + in the current executing tree path. + """ + curr_liveness = self._get_liveness(self.path_weakrefs) + _get_different_indices = self._get_different_indices( + self.recorded_liveness_after_graph, curr_liveness + ) + + path = list(self._path_from_root) + ptrs_to_deallocate = [] + for depth, output_index in _get_different_indices: + ptrs_to_deallocate.append( + path[depth].outputs_metadata[output_index]["data_ptr"] # type: ignore[index] + ) + + return ptrs_to_deallocate + + def path_live_weakrefs(self) -> Iterator[StorageWeakRefWrapper]: + for i, j in self.live_indices_after_graph: + out = self.path_weakrefs[i][j] + if out is not None and is_live(out): + yield out + + def remove_node_cached_tensors(self) -> None: + for t in self.cached_tensor_outputs: + if t is not None: + torch._C._remove_cached_tensor(t) + self.cached_tensor_outputs.clear() + + for i, unaliased in enumerate(self.unaliased_in_all_paths): + if unaliased: + n = self.outputs_weakrefs[i] + assert n is not None + n.remove_extra_reference() + + def remove_path_cached_tensors(self) -> None: + for node in self._path_from_root: + node.remove_node_cached_tensors() + + def clear_path_state(self) -> None: + "Clear the path state in this current executing node" + # this doesn't actually do anything right now, leaving it as placeholder + + @staticmethod + def _tensor_metadata( + x: torch.Tensor, ignore_storage_offset: bool = True + ) -> dict[str, Any]: + assert isinstance(x, torch.Tensor) + # We ignore the storage offset for inputs, but not for outputs + # TODO: - should we make 
the storage resizable ? + return { + "nbytes": x.untyped_storage().nbytes(), + "data_ptr": x.untyped_storage().data_ptr(), + "size": x.shape, + "stride": x.stride(), + "dtype": x.dtype, + "device": x.device, + "storage_offset": x.storage_offset() if not ignore_storage_offset else 0, + } + + def _reconstruct_from_tensor_metadata( + self, metadata: dict[str, Any], storage: Optional[UntypedStorage] = None + ) -> Tensor: + s = self.create_storage(metadata) if storage is None else storage + return torch._C._construct_CUDA_Tensor_From_Storage_And_Metadata(metadata, s) # type: ignore[arg-type] + + def create_storage(self, metadata: dict[str, Any]) -> torch.types.Storage: + return torch._C._construct_storage_from_data_pointer( + metadata["data_ptr"], metadata["device"], metadata["nbytes"] + ) + + def _allocate_and_copy_recording_inputs( + self, inputs: list[InputType] + ) -> list[InputType]: + """ + Allocate inputs for non static, non cudagraph managed tensors in the memory pool + and copy over the tensor values. + """ + + torch.cuda.synchronize() + self.stream.wait_stream(torch.cuda.current_stream()) + recording_inputs: list[InputType] = [] + + with ( + warnings.catch_warnings(record=True), + torch.cuda.device(self.device), + _use_cuda_memory_pool_manager( + self.device, + mem_pool=self.cuda_graphs_pool, + stream=self.stream, + ), + ): + for i, inp in enumerate(inputs): + if not isinstance(inp, torch.Tensor): + assert isinstance(inp, (int, torch.Generator)) + # pyrefly: ignore [bad-argument-type] + recording_inputs.append(inp) + elif i not in self.static_input_idxs: + # static_input does an allocation! + recording_inputs.append(static_input(inp)) + else: + recording_inputs.append(inp) + + self._copy_inputs_and_remove_from_src(recording_inputs, inputs) + + return recording_inputs + + def check_invariants( + self, inputs: list[InputType] + ) -> tuple[CheckInvariantStatus, Callable[..., str]]: + """ + Checks if this node can be run. 
The same pattern of tensor liveness, static inputs, + and tensors managed in the cudagraph private pool must remain stable. + """ + + _logger = functools.partial( + log_data_ptr_mismatch, + self.wrapped_function.placeholders, + inputs, + self.static_input_data_ptrs, + ) + + # previously managed data pointers remain stable + # this is on the hot path so moved to C++. equivalent to: + # return all(t.data_ptr() == data_ptr for (t, data_ptr) in zip(tensors, data_ptrs)) + if not torch._C._tensors_data_ptrs_at_indices_equal( + inputs, # type: ignore[arg-type] + self.static_input_data_ptrs, + self.cudagraph_managed_idxs, + ): + status = CheckInvariantStatus.CudagraphManagedIdxMismatch + _logger = functools.partial( + _logger, + self.cudagraph_managed_idxs, + status, + ) + return status, _logger + + if not self._check_liveness( + self.expected_dead_indices_before_graph, self.path_weakrefs + ): + status = CheckInvariantStatus.ExpectedDeadIndicesBeforeGraphMismatch + return status, lambda: f"{status}" + + # static input data pointers should remain stable + # if we are inlining builtin nn modules we re-record in this case + # if we are not inlining builtin nn modules, we check this in check_static_inputs_are_stable + # and error if they are not stable + if ( + self.rerecord_if_static_inputs_change + and not torch._C._tensors_data_ptrs_at_indices_equal( + inputs, # type: ignore[arg-type] + self.static_input_data_ptrs, + self.static_input_idxs, + ) + ): + status = CheckInvariantStatus.StaticInputIdxMismatch + _logger = functools.partial( + _logger, + self.static_input_idxs, + status, + ) + return status, _logger + + # the cudagraph managed tensors which died upon recording must also die upon + # this invocation. it is too late to check after we've replayed the graph, + # because we would have already written over their memory. 
+ for idx in self.cudagraph_managed_idxs: + if not self.preserved_aliased_inputs[idx]: + inputs[idx] = None # type: ignore[call-overload] + + torch._check( + self._check_liveness( + self.expected_dead_indices_after_graph, self.path_weakrefs + ), + lambda: "TODO: graph recording observed an input tensor deallocate during graph " + " recording that did not occur during replay. Please file an issue.", + ) + return CheckInvariantStatus.SUCCESS, lambda: f"{CheckInvariantStatus.SUCCESS}" + + def num_descendants(self) -> int: + "Total number of descendents of this node" + num_desc = 0 + for children in self.children.values(): + for child in children: + num_desc += 1 + num_desc += child.num_descendants() + return num_desc + + +def get_cudagraph_segments(pool_id: tuple[int, int]) -> Any: + segments = torch.cuda.memory_snapshot() + return [segment for segment in segments if segment["segment_pool_id"] == pool_id] + + +def get_block_addrs(pool_id: tuple[int, int], live_only: bool = True) -> list[int]: + blocks = [] + + for segment in get_cudagraph_segments(pool_id): + addr = segment["address"] + for block in segment["blocks"]: + if block["state"] == "active_allocated" or not live_only: + blocks.append(addr) + + addr += block["size"] + + return blocks + + +def format_tb(frames: list[Any]) -> str: + formatted_traceback = [ + traceback.FrameSummary(entry["filename"], entry["line"], entry["name"]) + for entry in frames + ] + + return "".join(traceback.format_list(formatted_traceback)) + + +def check_memory_pool( + device: int, + pool_id: tuple[int, int], + live_storages_ptrs: list[StorageWeakRefWrapper], +) -> None: + assert all(isinstance(elem, StorageWeakRefWrapper) for elem in live_storages_ptrs) # noqa: C419 + unique_storages = {stor.data_ptr() for stor in live_storages_ptrs if stor()} # noqa: set_linter + + # check if there is a divergence first, then do the expensive snapshot call after + # we know it will error + if torch._C._cuda_checkPoolLiveAllocations(device, pool_id, 
unique_storages): + return + + # at this point we are past the fast-path. we have seen rare cases where a dead tensor is dead, + # but hasn't been gc'd yet, and gives false positive for allocated_not_in_live_storages + gc.collect() + torch.cuda.synchronize() + + segments = get_cudagraph_segments(pool_id) + + allocated_not_in_live_storages = {} + + for segment in segments: + addr = segment["address"] + for block in segment["blocks"]: + if block["state"] == "active_allocated": + if addr not in unique_storages: + allocated_not_in_live_storages[addr] = block + else: + unique_storages.remove(addr) + + addr += block["size"] + + torch._check( + len(unique_storages) == 0, + lambda: f"These storage data ptrs are not allocated in pool {pool_id} but should be {unique_storages}", + ) + + if len(allocated_not_in_live_storages) != 0: + formatted = [] + for dp, block in allocated_not_in_live_storages.items(): + trace = format_tb(block.get("frames", [])) + # pyrefly: ignore [bad-argument-type] + formatted.append(f"Data Pointer: {dp}, history: \n{trace}") + formatted_s = "\n".join(formatted) + msg = ( + f"These live storage data ptrs are in the cudagraph pool but not " + f"accounted for as an output of cudagraph trees: \n\n{formatted_s}" + ) + raise RuntimeError(msg) + + +class ExecutionState(Enum): + """ + Represents the state of the CUDAGraph Tree. Will be None if there is no live current memory allocated + in the cuda graph pool. Otherwise will reflect the state of the most recently executed node. + """ + + NONE = auto() + WARMUP = auto() + RECORDING = auto() + EXECUTION = auto() + + +class CompilationMode(Enum): + FORWARD = auto() + BACKWARD = auto() + INFERENCE = auto() + + +class CUDAGraphTreeManager: + """ + Groups individual recordings or executions of cuda graphs into a tree of recordings, + and checks required invariants, and manages warmups of graphs. 
+ + When graphs are recorded in the same tree, it enforces subsequent execution + to follow the same order and have the same output tensor livespans. To remove + unnecessary coupling of cuda graphs (and additional imposed invariants), + the tree manager will end a currently recording tree whenever it is valid - when + the memory pool no longer has any live allocations. + + We ignore outputs from a previous generation that correspond to prior model outputs. + Currently this is hardcoded `GenerationTracker.generation` tracked in torch dynamo. + # TODO: make generation increment configurable, warn on overwrite. + + We run graph warmups in the cudagraph memory pool and return the result on the first invocation + of a function. For many models it is important to reclaim activations as you run the backward. + If we were to warm up the model and keep an extra copy of the inputs around to subsequently + use for recording, we would incur a memory penalty. Additionally, if we are part way through training + your model and need to recompile, memory will be allocated to the cuda graph pool, so we run this + warmup run in the cuda graph memory pool. As for recording, warm up needs the state of live tensors + to be accurately reflected so we checkpoint the allocator state if we need to warm up following graph + replay. + """ + + def __init__(self, device_index: int) -> None: + # roots are functions which have no dependencies on an other node. I.e., + # when they are first invoked, none of their inputs are outputs are outputs + # of another node, nor are there any live outputs of another node whose + # liveness would create a dependency. 
+ self.roots: dict[FunctionID, list[CUDAGraphNode]] = defaultdict(list) + + # mapping from function id to wrapped function + self.ids_to_funcs: dict[FunctionID, WrappedFunction] = {} + + self.ids_to_stack_traces: dict[FunctionID, Optional[StackTraces]] = {} + + self.warmed_up_functions: OrderedSet[FunctionID] = OrderedSet() + # if we fail to increment generation, and are stuck warming up, + # only warn on each function once + self.warned_functions: OrderedSet[FunctionID] = OrderedSet() + torch._C._set_cached_tensors_enabled(True) + + # warn only once if a function mutates inputs + self.warned_mutation: OrderedSet[FunctionID] = OrderedSet() + + # NB: cuda caching allocator will remember the stream a segment is allocated to + # and only allocate that segment to the same stream. we need to use a single stream + # for all allocations to the memory pool, otherwise the allocations to separate streams + # will not be reused; separate recordings would have use the same memory pool, but not + # the same memory. + + with torch.cuda.device(device_index): + torch.cuda.synchronize() + self.stream = torch.cuda.Stream() + self.stream.wait_stream(torch.cuda.current_stream()) + + # Keeps Memory Pool Alive + self.graph: Optional[torch.cuda.CUDAGraph] = torch.cuda.CUDAGraph() + self.cuda_graphs_thread_pool = torch.cuda.graph_pool_handle() + + with ( + warnings.catch_warnings(record=True), + torch.cuda.graph( + self.graph, + pool=self.cuda_graphs_thread_pool, + stream=self.stream, + capture_error_mode="thread_local", + ), + ): + pass + + self.graph_counter = itertools.count(0) + self.func_counter = itertools.count(0) + + # mapping from graph_id to (function id to mutation type hint) since we are + # specializing on a particular combination of Parent Node -> Function ID. 
+ self.non_cudagraph_managed_mutation_hint: dict[ + Optional[GraphID], dict[FunctionID, bool] + ] = defaultdict(dict) + self.warmup_node_counter = itertools.count(start=-1, step=-1) + + # mapping from graph_id to (function id to re-record count). We fall back to + # eager function if a function is re-recorded frequently on a node. + self.num_rerecord: dict[Optional[GraphID], dict[FunctionID, int]] = defaultdict( + lambda: defaultdict(lambda: 0) + ) + + # whether we the current node is in a state of warmup, recording, execution. If + # there is no current node the state will be ExecutionState.None. + self.path_state = ExecutionState.NONE + self.device_index = device_index + + # the most recently invoked cudagraph wrapping of a function. Will be None + # when there is no output from a previous recording or execution whose memory + # we need to respect in the cuda caching allocation. If you incremented generation, + # this will also be none, as ignore those allocations. + self.current_node: Optional[Union[CUDAGraphNode, CUDAWarmupNode]] = None + + # current generation of cudagraph invocations. when torch.compile is run + # we increment the current generation. are willing to ignore live outputs + # of a previous generation in checking liveness. + self.current_gen: int = -1 + + # number of instances we are in execution and failed to match to an + # existing child + self.debug_fail_counter = 0 + # number of instances we had to checkpoint the function + self.debug_checkpointing_counter = 0 + + self.id_to_mode: dict[FunctionID, CompilationMode] = {} + self.id_to_compile_id: dict[FunctionID, Optional[CompileId]] = {} + + # Note: [Backward Generation Handling] + # We generally perform a sequence of forward executions followed by backward executions. + # If multiple torch.compile wrapped forwards are executed with their backwards pending, + # we should not disregard the outputs from a prior torch.compile since the entire training + # loop hasn't completed. 
Occasionally, a backward pass corresponding to a forward pass may + # not be executed, so we cannot wait for all pending forward pass backward completions, so + # we cannot wait for all backwards to have been invoked. Instead we wait for a single backward + # invocation. Triggering a backward pass typically doesn't lead to another torch.compile + # invocation, making it less likely for the generation to increase between multiple + # backward calls. The following use case is covered by this approach: + # mod1 = torch.compile(...) + # mod2 = torch.compile(...) + # mod2(mod1(x)).sum().backward() + + self.running_forwards_with_pending_backwards = False + self.mode: Optional[CompilationMode] = None + + self.disable_invalidate_aliases = ( + False + if not torch._environment.is_fbcode() + else torch._utils_internal.justknobs_check( + "pytorch/inductor:disable_cudagraph_alias_invalidation" + ) + ) + + def run(self, new_inputs: list[InputType], function_id: FunctionID) -> OutputType: + assert self.graph is not None, "Running CUDAGraph after shutdown" + self.mode = self.id_to_mode[function_id] + self.compile_id = self.id_to_compile_id[function_id] + out = self._run(new_inputs, function_id) + + # The forwards are only pending following invocation, not before + if self.mode == CompilationMode.FORWARD: + self.running_forwards_with_pending_backwards = True + elif self.mode == CompilationMode.BACKWARD: + self.running_forwards_with_pending_backwards = False + + return out + + def set_to_running_backward(self) -> None: + self.running_forwards_with_pending_backwards = False + self.mode = CompilationMode.BACKWARD + + def _get_cuda_graph_recorded_tensor_checker(self) -> Callable[[Tensor], bool]: + return ( + self.current_node._is_cuda_graph_recorded_tensor + if isinstance(self.current_node, (CUDAGraphNode, CUDAWarmupNode)) + else lambda _: False + ) + + def new_warmup_node_id(self) -> GraphID: + return GraphID(next(self.warmup_node_counter)) + + def 
_update_non_cudagraph_managed_mutation( + self, function_id: FunctionID, inputs: list[InputType] + ) -> None: + node_id = self._get_node_id() + if maybe_mutation_str := check_for_mutation( + self.ids_to_funcs[function_id], + inputs, + self._get_cuda_graph_recorded_tensor_checker(), + ): + self.non_cudagraph_managed_mutation_hint[node_id][function_id] = True + # warn once per function_id + if function_id in self.warned_mutation: + return + self.warned_mutation.add(function_id) + log_cudagraph_skip_and_bump_counter(maybe_mutation_str) + else: + self.non_cudagraph_managed_mutation_hint[node_id][function_id] = False + + def _get_node_id(self) -> Optional[GraphID]: + if self.current_node is None: + return None + elif isinstance(self.current_node, (CUDAGraphNode, CUDAWarmupNode)): + return self.current_node.id + else: + raise RuntimeError(f"Unknown node type {type(self.current_node)}") + + def exceed_rerecord_limit( + self, node_id: Optional[GraphID], function_id: FunctionID + ) -> bool: + if torch._dynamo.config.inline_inbuilt_nn_modules: + return False + + return ( + self.num_rerecord[node_id][function_id] + > torch._inductor.config.triton.cudagraph_unexpected_rerecord_limit + ) + + def _run(self, new_inputs: list[InputType], function_id: FunctionID) -> OutputType: + # we will try to end the current execution lazily, since + # we dont want to do unnecessary checking of the existing outputs + # on the hot path, but both recording and warmup only happen once + # so we check up front + if self.in_recording: + self.try_end_curr_recording(function_id) + + if self.in_warmup: + self.try_end_curr_warmup(function_id) + + node_id = self._get_node_id() + if function_id not in self.non_cudagraph_managed_mutation_hint[node_id]: + self._update_non_cudagraph_managed_mutation(function_id, new_inputs) + + # Early exit if the function mutates inputs which are neither parameters/buffers nor + # cudagraph recorded tensors. 
This check should happen after `try_end_curr_recording` + # and `try_end_curr_warmup` which may change self.current_node. + if self.non_cudagraph_managed_mutation_hint[node_id][ + function_id + ] or self.exceed_rerecord_limit(node_id, function_id): + return self.ids_to_funcs[function_id].model(new_inputs) + + # warming up a function and subsequentally recording may use different memory addresses + # because both depend on the state of the caching allocator. if we warm up graph A, + # then warm up graph B and make more allocations, the subsequent recording of A will not + # necessarily use the same addresses as in the warm up. Thus any warm up of a node can only + # be followed by warm up runs. + if ( + ( + not ( + function_id in self.warmed_up_functions + or config.triton.skip_cudagraph_warmup + ) + ) + or self.in_warmup + or config.triton.force_cudagraphs_warmup + ): + # If we are in the middle of executing cuda graphs, then we need to checkpoint memory state. + # Both Recording and Warmup will be reflected in the allocator and dont need changes + if self.path_state == ExecutionState.EXECUTION: + self.apply_checkpoint_execution_state_in_allocator() + + return self.run_eager(new_inputs, function_id) + + assert not isinstance(self.current_node, CUDAWarmupNode) + child_nodes = ( + self.roots if self.current_node is None else self.current_node.children + ) + + if not self.in_recording: + unexpected_rerecord, unexpected_rerecord_reason = False, lambda: "" + for child in child_nodes[function_id]: + # here we are checking memory consistency between recording and execution, + # as well as things like stability of tensor locations, etc + # and other + status, status_logger = child.check_invariants(new_inputs) + if status == CheckInvariantStatus.SUCCESS: + return self.execute_node(child, new_inputs) + + if ( + status == CheckInvariantStatus.StaticInputIdxMismatch + or status == CheckInvariantStatus.CudagraphManagedIdxMismatch + ): + unexpected_rerecord = True + 
unexpected_rerecord_reason = status_logger + + # now that we know the new function can't be run as a child of the + # current node, if it is a root, try to end the current execution. + # as noted above, we want to do this lazily to avoid having to + # check all existing outputs + if self.current_node is not None and function_id in self.roots: + self.try_end_curr_execution() + + # run again to hit the root matching case which must succeed + if self.current_node is None: + return self.run(new_inputs, function_id) + + if len(self.ids_to_funcs[function_id].mutated_input_idxs) > 0: + self._update_non_cudagraph_managed_mutation(function_id, new_inputs) + if self.non_cudagraph_managed_mutation_hint[self._get_node_id()][ + function_id + ]: + return self.ids_to_funcs[function_id].model(new_inputs) + + # nb: run before checkpointing because checkpointing is slow, and we will + # be using the eager caching allocator pool which does not require live + # accounting of tensors in cudagraph allocator + if unexpected_rerecord: + curr_node_id = self._get_node_id() + self.num_rerecord[curr_node_id][function_id] += 1 + if self.exceed_rerecord_limit(curr_node_id, function_id): + _id = curr_node_id.id if curr_node_id else None + log_cudagraph_skip_and_bump_counter( + f"skipping cudagraph due to function {function_id.id} exceeding max " + f"re-recording limit " + f"(={torch._inductor.config.triton.cudagraph_unexpected_rerecord_limit}) " + f"on cudagraph node {_id} due to {unexpected_rerecord_reason()}." + ) + return self.ids_to_funcs[function_id].model(new_inputs) + + # at this point, we necessarily will do a new recording + self.debug_fail_counter += 1 + + self.try_end_curr_execution() + if self.current_node is not None: + self.apply_checkpoint_execution_state_in_allocator() + + # now, we are in a recording state ! + return self.record_function(new_inputs, function_id) + + def shutdown(self) -> None: + """ + Remove all cached tensors in all nodes. 
Because cached tensors can hold gradients which in turn + might reference a backward which invokes a CUDA Graph Node, we have to manually clear them on shutdown + to avoid a reference cycle. + """ + nodes = [] + for roots in self.roots.values(): + nodes.extend(roots) + + while nodes: + node = nodes.pop() + for children in node.children.values(): + nodes.extend(children) + node.remove_node_cached_tensors() + node.graph = None + + self.graph = None + self.roots = None # type: ignore[assignment] + self.current_node = None + + def record_function( + self, new_inputs: list[InputType], function_id: FunctionID + ) -> OutputType: + assert not isinstance(self.current_node, CUDAWarmupNode) + with torch._dynamo.callback_handler.install_callbacks( + CallbackTrigger.CUDAGRAPH_RECORDING, str(self.compile_id) + ): + graph_id = self.new_graph_id() + log.debug( + "Recording function %d of graph recording id %d", + function_id.id, + graph_id.id, + ) + torch.cuda.synchronize() + node = CUDAGraphNode( + self.ids_to_funcs[function_id], + graph_id, + self.current_node, + new_inputs, + self.cuda_graphs_thread_pool, + self.device_index, + self.ids_to_stack_traces[function_id], + self.stream, + self.mode, + self.compile_id, + ) + if self.current_node is None: + self.roots[function_id].append(node) + else: + self.current_node.add_child(function_id, node) + self.current_node = node + self.path_state = ExecutionState.RECORDING + self.update_generation() + torch.cuda.synchronize() + return node.run_first_inputs(new_inputs) + + def execute_node( + self, node: CUDAGraphNode, new_inputs: list[InputType] + ) -> OutputType: + self.current_node = node + self.path_state = ExecutionState.EXECUTION + self.update_generation() + return node.run(new_inputs) + + def run_eager( + self, new_inputs: list[InputType], function_id: FunctionID + ) -> OutputType: + # this is only stored on current node, because when we start a new path, + # we will deallocate it + already_warm = function_id in 
self.warmed_up_functions + if not already_warm: + log.debug("Running warmup of function %d", function_id.id) + else: + log.debug( + "Running eager of function %d because ancestor needed to warm up", + function_id.id, + ) + self.warmed_up_functions.add(function_id) + node = CUDAWarmupNode( + self.ids_to_funcs[function_id], + self.current_node, + self.cuda_graphs_thread_pool, + self.graph, + self.device_index, + self.ids_to_stack_traces[function_id], + self.stream, + already_warm, + self.new_warmup_node_id(), + ) + self.current_node = node + self.path_state = ExecutionState.WARMUP + self.update_generation() + return node.run(new_inputs) + + def new_graph_id(self) -> GraphID: + return GraphID(next(self.graph_counter)) + + def new_func_id(self) -> FunctionID: + return FunctionID(next(self.func_counter)) + + def add_function( + self, + model: ModelType, + inputs: list[InputType], + static_input_idxs: Sequence[int], + stack_traces: Optional[StackTraces], + mode: CompilationMode, + constants: tuple[torch.Tensor, ...], + placeholders: tuple[PlaceholderInfo, ...], + mutated_input_idxs: tuple[int, ...], + compile_id: Optional[CompileId], + ) -> tuple[ + ModelType, + OutputType, + ]: + id = self.new_func_id() + self.ids_to_stack_traces[id] = stack_traces + self.ids_to_funcs[id] = WrappedFunction( + model, + list(static_input_idxs), + id, + tuple(t for t in constants if isinstance(t, torch.Tensor) and t.is_cuda), + placeholders, + mutated_input_idxs, + ) + self.id_to_mode[id] = mode + self.id_to_compile_id[id] = compile_id + fn = functools.partial(self.run, function_id=id) + + # container needs to set clean up when fn dies + get_container(self.device_index).add_strong_reference(fn) + return fn, fn(inputs) + + @property + def in_recording(self) -> bool: + return self.path_state == ExecutionState.RECORDING + + @property + def in_warmup(self) -> bool: + return self.path_state == ExecutionState.WARMUP + + def get_roots(self) -> Iterator[CUDAGraphNode]: + for nodes in 
self.roots.values(): + yield from nodes + + @property + def current_node(self) -> Optional[Union[CUDAGraphNode, CUDAWarmupNode]]: + return self._current_node + + @current_node.setter + def current_node( + self, value: Optional[Union[CUDAGraphNode, CUDAWarmupNode]] + ) -> None: + self._current_node = value + if value is None: + self.path_state = ExecutionState.NONE + + def update_generation(self) -> None: + self.current_gen = self.get_curr_generation() + + @staticmethod + def get_curr_generation() -> int: + if MarkStepBox.mark_step_counter != 0: + return MarkStepBox.mark_step_counter + + return GenerationTracker.generation + + @staticmethod + def user_invoked_mark_step() -> bool: + return MarkStepBox.mark_step_counter != 0 + + def can_start_new_generation(self) -> bool: + if not self.in_new_torch_compile_invocation(): + return False + + if self.user_invoked_mark_step(): + return True + + return not self.running_forwards_with_pending_backwards + + def in_new_torch_compile_invocation(self) -> bool: + return self.current_gen != self.get_curr_generation() + + def try_end_curr_recording(self, function_id: FunctionID) -> None: + """ + Check if the current recording can be terminated, either because all outputs of the + previously recorded node are dead or because it was executed in a different + generation. Will set current_node to None and in_recording to False if successful. 
+ """ + assert self.in_recording + assert self.current_node is not None + + # multiple invocations, allow overwriting the previous generation + if self.can_start_new_generation(): + self.dealloc_current_path_weakrefs() + self.clear_current_path_state_and_set_to_none() + return + + if self.current_node.all_outputs_are_dead(): + self.clear_current_path_state_and_set_to_none() + return + + self.check_warn_on_unable_to_start_executing(function_id) + + def try_end_curr_execution(self) -> None: + """ + Check if the current executing node can be terminated, either because all outputs of the + previously executed node are dead or because it was executed in a different generation. + Will set current_node to None if successful. + """ + + assert not self.in_recording + if self.current_node is None: + return + + if self.can_start_new_generation(): + self.clear_current_path_state_and_set_to_none() + return + + if self.current_node.all_outputs_are_dead(): + self.clear_current_path_state_and_set_to_none() + + def try_end_curr_warmup(self, function_id: FunctionID) -> None: + if self.can_start_new_generation(): + self.dealloc_current_path_weakrefs() + self.current_node = None + return + + assert self.current_node is not None + if self.current_node.all_outputs_are_dead(): + self.current_node = None + return + + self.check_warn_on_unable_to_start_executing(function_id) + + def check_warn_on_unable_to_start_executing(self, function_id: FunctionID) -> None: + "Warn if we in a potential loop where we are unable to hit fast path" + if ( + function_id in self.warned_functions + or not self.in_new_torch_compile_invocation() + ): + return + + assert self.current_node is not None + existing_nodes = [ + node + for node in self.current_node._path_from_root + if node.wrapped_function.id == function_id + ] + + if len(existing_nodes) <= 1: + return + + # repeated same pattern + parents = OrderedSet( + [ + n.parent.wrapped_function.id + for n in itertools.chain(existing_nodes, 
(self.current_node,)) + if n.parent is not None + ] + ) + if len(parents) == len(existing_nodes): + return + + self.warned_functions.add(function_id) + warnings.warn( + "Unable to hit fast path of CUDAGraphs because of pending, uninvoked backwards. " + "Consider running with torch.no_grad() or using torch.compiler.cudagraph_mark_step_begin() " + "before each model invocation" + ) + + @staticmethod + def format_dealloc_msg(stack_trace: Optional[str]) -> str: + stack_trace = ( + stack_trace.strip() if stack_trace else "[Could not find stack trace]" + ) + return ( + "Error: accessing tensor output of CUDAGraphs that has been overwritten by a subsequent run. " + f"Stack trace: {stack_trace}. " + "To prevent overwriting, clone the tensor outside of torch.compile() " + "or call torch.compiler.cudagraph_mark_step_begin() before each model invocation." + ) + + def dealloc_current_path_weakrefs(self) -> None: + assert self.current_node is not None + # TODO: we could also allow the these weak refs to continue to be allocated, + # but that adds some complications. + + stor_stack_trace: dict[int, Optional[str]] = {} + for node in self.current_node._path_from_root: + assert node.stack_traces is not None + assert len(node.tensor_weakrefs) == len(node.stack_traces) + for t, stack_trace in zip(node.tensor_weakrefs, node.stack_traces): + ten = None if t is None else t() + if ten is None: + continue + + torch._C._set_storage_access_error_msg( + ten, self.format_dealloc_msg(stack_trace) + ) + + # we would to enable the following assertion, but an internal model failed with a command + # that does not repro. 
len(node.outputs_weakrefs) == len(node.stack_traces) + # so, pessimistically assume that they might differ by doing the debug info + # loop separately from the dealloc loop + if self.disable_invalidate_aliases: + continue + + for storage_ref, stack_trace in zip( + node.outputs_weakrefs, node.stack_traces + ): + if not storage_ref: + continue + + stor_stack_trace[storage_ref.data_ptr()] = stack_trace + + deleted = OrderedSet[Any]() + for storage_ref in self.current_node.path_live_weakrefs(): + _storage_deref = storage_ref() + if _storage_deref and storage_ref.data_ptr() not in deleted: + deleted.add(storage_ref.data_ptr()) + + msg = self.format_dealloc_msg( + stor_stack_trace.get(storage_ref.data_ptr()) + ) + torch._C._free_And_Remove_DeleterFn(_storage_deref) + + if self.disable_invalidate_aliases: + continue + + torch._C._set_storage_data_ptr_access_error_msg(_storage_deref, msg) + + def clear_current_path_state_and_set_to_none(self) -> None: + assert isinstance(self.current_node, CUDAGraphNode) + self.current_node.clear_path_state() + self.current_node = None + + def apply_checkpoint_execution_state_in_allocator(self) -> None: + """ + Checkpoint the current execution state in the caching allocator so that + additional cudagraph recordings can be made respecting existent live storages. + """ + assert isinstance(self.current_node, CUDAGraphNode) + self.debug_checkpointing_counter += 1 + log.debug( + "Checkpointing cuda caching allocator state. 
Number of checkpoints %d", + self.debug_checkpointing_counter, + ) + + state = self.current_node.checkpointed_caching_state + device = self.current_node.device + assert state is not None and device is not None + + # currently we deallocate on instead of allowing stale recordings + stale_storages: list[int] = [] + + # remove cached tensors, otherwise they would prevent memory from being + # reclaimed in subsequent recordings + self.current_node.remove_path_cached_tensors() + live_storages_wrappers = list(self.current_node.path_live_weakrefs()) + + # path_live_weakrefs guarantees that t() will not be None + live_storages_weak_refs: list[int] = [t() for t in live_storages_wrappers] # type: ignore[misc] + ptrs_to_deallocate = self.current_node.data_ptrs_dead_since_invocation() + torch._C._cuda_setCheckpointPoolState( + device, + # pyrefly: ignore [bad-argument-type] + state, + stale_storages, + live_storages_weak_refs, + ) + + # NB: deduplicate aliased outputs + for ptr in OrderedSet(ptrs_to_deallocate): + torch._C._cuda_cudaCachingAllocator_raw_delete(ptr) + + # Now the live blocks should be exactly equal to the live storages in private pool + if config.triton.slow_path_cudagraph_asserts: + check_memory_pool( + self.device_index, self.cuda_graphs_thread_pool, live_storages_wrappers + ) + for wrapper in live_storages_wrappers: + storage_ptr = wrapper() + assert storage_ptr is not None + assert torch._C._has_Standard_Deleter(storage_ptr) + assert wrapper.data_ptr() not in ptrs_to_deallocate + + def live_cudagraph_pool_storages_in_curr_execution( + self, + ) -> list[StorageWeakRefPointer]: + if self.current_node is None: + return [] + # explicitly ignoring previous recorded outputs from past path + # path_live_weakrefs() guarantees that t() will not be None + return [t() for t in self.current_node.path_live_weakrefs()] # type: ignore[misc] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/cudagraph_utils.py 
b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/cudagraph_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..50d986d48e6c22f25e0e997e99edbef17cdf5bd3 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/cudagraph_utils.py @@ -0,0 +1,423 @@ +# mypy: disallow-untyped-defs +from __future__ import annotations + +import dataclasses +from collections.abc import Callable +from enum import Enum +from typing import Any, Optional, TYPE_CHECKING, Union + +import torch +from torch._dynamo.utils import counters, get_metrics_context +from torch._inductor.utils import GraphPartitionMap, InputType +from torch.utils._ordered_set import OrderedSet + +from .utils import is_using_cudagraph_partition + + +if TYPE_CHECKING: + from collections.abc import Sequence, Set as AbstractSet + + +perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints") +static_inputs_log = torch._logging.getArtifactLogger( + __name__, "cudagraph_static_inputs" +) + + +OutputType = list[Optional[Union[int, torch.Tensor]]] +ModelType = Callable[[list[InputType]], OutputType] + + +@dataclasses.dataclass(frozen=True) +class FunctionID: + "Unique counter of a function wrapped in cudagraphify_impl" + + id: int + + +@dataclasses.dataclass(frozen=True) +class PlaceholderInfo: + """ + A serializable version of torch.fx.Node that contains information + pertinent to placeholder stack traces. We use these in logging and error messages + related to cudagraphs, and will cache these results. 
+ """ + + name: str + stack_trace: Optional[str] + # This field is recursive, but never cyclic (since a node never uses itself) + users: list[PlaceholderInfo] + mutating_use_stack_trace: Optional[str] + + +@dataclasses.dataclass(frozen=True) +class WrappedFunction: + """ + Represents a function that you want to record for CUDA graph replay, + with a little more metadata so we can identify if we have an applicable + CUDA graph in our CUDA graph tree for it. + """ + + model: Callable[..., Any] + static_input_idxs: Sequence[int] + id: FunctionID + constants: tuple[torch.Tensor, ...] + placeholders: Sequence[PlaceholderInfo] + mutated_input_idxs: Sequence[int] + + +def get_mutating_use_stack_trace_from_node( + placeholder_node: torch.fx.Node, +) -> Optional[str]: + # reinplaced uses might have a single, non-copy_ use + if len(placeholder_node.users) == 1: + return next(iter(placeholder_node.users)).meta.get("stack_trace", None) + + for use in placeholder_node.users: + if use.target is torch.ops.aten.copy_.default: + if stack_trace := use.meta.get("stack_trace", None): + return stack_trace + + return None + + +def get_mutating_use_stack_trace(placeholder_info: PlaceholderInfo) -> Optional[str]: + return placeholder_info.mutating_use_stack_trace + + +def to_placeholder_info(placeholder_node: torch.fx.Node) -> PlaceholderInfo: + name = placeholder_node.name + stack_trace = placeholder_node.meta.get("stack_trace", None) + users = [] + mutating_use_stack_trace = None + # Only recurse to users once, since we only care about user's stack traces + if placeholder_node.op == "placeholder": + users = [to_placeholder_info(i) for i in placeholder_node.users] + mutating_use_stack_trace = get_mutating_use_stack_trace_from_node( + placeholder_node + ) + + return PlaceholderInfo(name, stack_trace, users, mutating_use_stack_trace) + + +def get_placeholder_info(graph: torch.fx.Graph) -> list[PlaceholderInfo]: + return [ + to_placeholder_info(node) for node in graph.nodes if node.op == 
"placeholder" + ] + + +def format_default_skip_message(reason: str) -> str: + return f"skipping cudagraphs due to {reason}" + + +def get_mutation_stack_trace( + placeholders: Sequence[PlaceholderInfo], + mutation_indices: Union[AbstractSet[int], Sequence[int]], +) -> str: + stack_trace: Optional[str] = "" + + for idx in mutation_indices: + placeholder = placeholders[idx] + if stack_trace := get_mutating_use_stack_trace(placeholder): + break + + msg = format_default_skip_message( + f"mutated inputs ({len(mutation_indices)} instances)" + ) + if stack_trace: + return f"{msg}. Found from : \n {stack_trace}" + + return msg + + +def check_for_mutation( + func: WrappedFunction, + inputs: list[InputType], + is_cuda_graph_recorded_tensor: Callable[[torch.Tensor], bool], +) -> Optional[str]: + # doesn't work for non-trees because the warmup run would apply mutation twice + if torch._inductor.config.triton.cudagraph_trees: + # checking if mutation is only on parameters/static inputs + mutation_indices: Sequence[int] = [ + idx + for idx in func.mutated_input_idxs + if not ( + idx in func.static_input_idxs + or is_cuda_graph_recorded_tensor(inputs[idx]) # type: ignore[arg-type] + ) + ] + else: + mutation_indices = func.mutated_input_idxs + + static_inputs_log.debug( + "check mutation static input indices: %s", func.static_input_idxs + ) + static_inputs_log.debug("check mutation mutation indices: %s", mutation_indices) + + return ( + get_mutation_stack_trace(func.placeholders, mutation_indices) + if mutation_indices + else None + ) + + +def _get_use_stack_trace(node: torch.fx.Node) -> Optional[str]: + for use in node.users: + if stack_trace := use.meta.get("stack_trace", None): + return stack_trace + return None + + +def check_multiple_devices_or_any_cpu_nodes( + device_node_mapping: dict[torch.device, torch.fx.Node], +) -> Optional[str]: + # meta tensors are supported since there is no compute + device_node_mapping.pop(torch.device("meta"), None) + + # dynamo cudagraph does not 
support graph partition + if is_using_cudagraph_partition(): + # graph partition supports splitting on cpu op. So we can ignore cpu nodes. + device_node_mapping.pop(torch.device("cpu"), None) + + if cpu_node := device_node_mapping.get(torch.device("cpu")): + msg = f"cpu device ({cpu_node.name})" + if stack_trace := _get_use_stack_trace(cpu_node): + return format_default_skip_message(f"{msg}. Found from : \n {stack_trace}") + + return format_default_skip_message(msg) + + if ( + len(device_node_mapping) == 1 + and next(iter(device_node_mapping.keys())).type == "cuda" + ): + return None + + keys_repr = (repr(key) for key in device_node_mapping) + return format_default_skip_message(f"multiple devices: {', '.join(keys_repr)}") + + +def check_lowering_disable_cudagraph( + device_node_mapping: dict[torch.device, torch.fx.Node], +) -> Optional[str]: + return check_multiple_devices_or_any_cpu_nodes(device_node_mapping) + + +def log_cudagraph_skip_and_bump_counter(msg: str) -> None: + perf_hint_log.warning(msg) + counters["inductor"]["cudagraph_skips"] += 1 + + if torch._inductor.config.triton.cudagraph_or_error: + raise RuntimeError(msg) + + metrics_context = get_metrics_context() + if metrics_context.in_progress(): + metrics_context.set("cudagraph_skip_reason", msg, overwrite=True) + + +@dataclasses.dataclass +class BoxedDeviceIndex: + value: Optional[int] + + def set(self, device_idx: Optional[int]) -> None: + assert device_idx is None or isinstance(device_idx, int) + self.value = device_idx + + +def check_for_mutation_ignore_cuda_graph_managed_tensor( + gm: torch.fx.GraphModule, + mutated_inputs: OrderedSet[str], + mutated_input_idxs: OrderedSet[int], + static_input_idxs: Sequence[int], +) -> Optional[str]: + default_msg = format_default_skip_message("mutated inputs") + + # doesn't work for non-trees because the warmup run would apply mutation twice + if torch._inductor.config.triton.cudagraph_trees: + unique_idxs = OrderedSet(static_input_idxs) + # checking if mutation 
is only on parameters/static inputs + mutation_indices = [idx for idx in mutated_input_idxs if idx not in unique_idxs] + has_mutation = len(mutation_indices) != 0 + if not has_mutation: + return None + placeholders = get_placeholder_info(gm.graph) + return get_mutation_stack_trace(placeholders, mutation_indices) + + else: + has_mutation = len(mutated_inputs) != 0 + return None if not has_mutation else default_msg + + +def get_placeholder_stack_trace(placeholder: PlaceholderInfo) -> Optional[str]: + """ + Gets the first non-empty stack trace of a placeholder or its users. + """ + if placeholder.stack_trace: + return placeholder.stack_trace + + for user in placeholder.users: + if user.stack_trace: + return user.stack_trace + + return None + + +class CheckInvariantStatus(Enum): + # Check invariant succeeded + SUCCESS = 1 + + # Previously managed data pointers are not stable + CudagraphManagedIdxMismatch = 2 + + # Static tensor input addresses are not stable + StaticInputIdxMismatch = 3 + + # Expected dead indices before graph are live + ExpectedDeadIndicesBeforeGraphMismatch = 4 + + def __str__(self) -> str: + if self.name == "CudagraphManagedIdxMismatch": + return "cudagraph managed tensor data pointer changed" + elif self.name == "StaticInputIdxMismatch": + return "static input data pointer changed" + elif self.name == "ExpectedDeadIndicesBeforeGraphMismatch": + return "expected dead indices before graph are live" + else: + return f"{self.name}: {self.value}" + + +def log_data_ptr_mismatch( + placeholders: Sequence[PlaceholderInfo], + inputs: list[InputType], + recorded_data_ptr: Sequence[Optional[int]], + target_idxs: Sequence[int], + mismatch: CheckInvariantStatus, +) -> str: + """ + Logs the mismatch between input data pointers and recorded data pointers. + This checks only idxs in target_idxs. 
+ """ + assert len(inputs) == len(recorded_data_ptr) and len(inputs) == len(placeholders), ( + "length mismatch between inputs, recorded_data_ptr, and placeholders" + ) + + t_tensors = [inputs[i] for i in target_idxs] + t_data_ptrs = [recorded_data_ptr[i] for i in target_idxs] + error_msg = f"{mismatch}.\n" + for i, (tensor, data_ptr) in enumerate(zip(t_tensors, t_data_ptrs)): + assert isinstance(tensor, torch.Tensor) + index = target_idxs[i] + if tensor.data_ptr() != data_ptr: + placeholder = placeholders[index] + error_msg = ( + f"{error_msg}input name: {placeholder.name}. " + f"data pointer changed from {data_ptr} to {tensor.data_ptr()}. " + f"input stack trace: {get_placeholder_stack_trace(placeholder)}\n" + ) + return error_msg + + +def maybe_warning_due_to_dynamic_shape( + fn_cache: dict[tuple[int, ...], Callable[..., Any]], + new_int_key: Any, +) -> bool: + num_cudagraphs = len(fn_cache.keys()) + 1 + + def warn_msg() -> str: + return ( + "CUDAGraph supports dynamic shapes by recording a new graph for each " + "distinct input size. Recording too many CUDAGraphs may lead to " + f"extra overhead. We have observed {num_cudagraphs} distinct sizes. " + "Please consider the following options for better performance: " + "a) padding inputs to a few fixed number of shapes; or b) set " + "torch._inductor.config.triton.cudagraph_skip_dynamic_graphs=True. " + "Set torch._inductor.config.triton.cudagraph_dynamic_shape_warn_limit=None " + "to silence this warning." 
+ ) + + if ( + torch._inductor.config.triton.cudagraph_dynamic_shape_warn_limit + and num_cudagraphs + > torch._inductor.config.triton.cudagraph_dynamic_shape_warn_limit + ): + perf_hint_log.warning(warn_msg()) + return True + + return False + + +@dataclasses.dataclass(frozen=True) +class CudagraphCachedInfo: + """ + Info needed to realign inputs + """ + + placeholders: Sequence[PlaceholderInfo] + stack_traces: list[Optional[str]] + cudagraph_fail_reasons: list[str] + + +@dataclasses.dataclass(frozen=True) +class CudagraphMetadata: + """ + Metadata for recording a CUDA graph. + """ + + placeholders: Sequence[PlaceholderInfo] + static_input_idxs: OrderedSet[int] + mutated_input_idxs: OrderedSet[int] + stack_traces: list[Optional[str]] + constants: dict[str, torch.Tensor] + + +def get_partition_cudagraph_metadata( + partition_map: GraphPartitionMap, + metadata: CudagraphMetadata, +) -> CudagraphMetadata: + """ + Convert the cudagraph metadata at the graph level to the graph partition level, + given the graph partition info (i.e., mapping from partition input/output index + to graph input/output index). 
+ """ + + partition_placeholders = [] + partition_static_input_idxs: OrderedSet[int] = OrderedSet() + partition_mutated_input_idxs: OrderedSet[int] = OrderedSet() + for partition_input_idx, graph_input_idx in enumerate( + partition_map.input_index_mapping + ): + if graph_input_idx in metadata.static_input_idxs: + partition_static_input_idxs.add(partition_input_idx) + + if graph_input_idx in metadata.mutated_input_idxs: + partition_mutated_input_idxs.add(partition_input_idx) + + if graph_input_idx is not None: + placeholder = metadata.placeholders[graph_input_idx] + else: + # create a dummy placeholder info since this partition input is not a graph input + placeholder = PlaceholderInfo( + name=f"partition_{partition_map.id}_placeholder_{partition_input_idx}", + stack_trace=None, + users=[], + mutating_use_stack_trace=None, + ) + partition_placeholders.append(placeholder) + + partition_stack_traces = [] + for graph_output_idx in partition_map.output_index_mapping: + if graph_output_idx is not None: + partition_stack_traces.append(metadata.stack_traces[graph_output_idx]) + else: + partition_stack_traces.append(None) + + partition_constants = { + name: metadata.constants[name] for name in partition_map.constant_names + } + + return CudagraphMetadata( + partition_placeholders, + partition_static_input_idxs, + partition_mutated_input_idxs, + partition_stack_traces, + partition_constants, + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/custom_graph_pass.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/custom_graph_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..53baac7bd9a8f95b28665f33fc7ffdacd09e04a8 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/custom_graph_pass.py @@ -0,0 +1,158 @@ +import hashlib +from abc import ABC, abstractmethod +from collections.abc import Callable, Sequence +from functools import lru_cache +from 
typing import Any, Optional, TypeAlias, Union + +import torch.fx.graph + + +class CustomGraphPass(ABC): + """ + Implement this interface for custom Graph passes: + + 1) The __call__() method contains the implementation of the custom pass. + + 2) The uuid() method enables inductor to cache compiled graphs when your custom + passes are applied. This method can return any identifier as long as it uniquely + identifies your implementation (and can be pickled). The caching logic includes this + identifier in its key calculation, i.e., any new value will effectively invalidate + existing entries. We expect custom passes would typically depend purely on the + textual representation of the implementation. In that case, we recommend using the + 'get_hash_for_files' helper below to compute a unique hash from the contents of a + static list of source files, i.e., the source(s) containing the custom pass + implementation. That approach ensures that any change to the implementation will + mean a new uuid. + + ** IMPORTANT ** If your custom pass's behavior depends on some external state, then + you'll need to implement something more complicated (or disable caching). + + EXAMPLE: + + class MyCustomGraphPass(CustomGraphPass): + def __call__(self, graph: torch.fx.graph.Graph) -> None: + # my custom graph optimization pass + # ... + + def uuid(self) -> Optional[Any]: + return get_hash_for_files((__file__,)) + + """ + + @abstractmethod + def __call__(self, graph: torch.fx.graph.Graph) -> None: + """ + Implementation of the custom pass. + """ + + @abstractmethod + def uuid(self) -> Optional[Any]: + """ + Return an ID to uniquely identify your custom pass implementation. Return None + to skip inductor code caching entirely. + """ + + +class CustomGraphModulePass(ABC): + """ + Implement this interface for custom Graph passes: + + 1) The __call__() method contains the implementation of the custom pass. 
+ + 2) The uuid() method enables inductor to cache compiled graphs when your custom + passes are applied. This method can return any identifier as long as it uniquely + identifies your implementation (and can be pickled). The caching logic includes this + identifier in its key calculation, i.e., any new value will effectively invalidate + existing entries. We expect custom passes would typically depend purely on the + textual representation of the implementation. In that case, we recommend using the + 'get_hash_for_files' helper below to compute a unique hash from the contents of a + static list of source files, i.e., the source(s) containing the custom pass + implementation. That approach ensures that any change to the implementation will + mean a new uuid. + """ + + @abstractmethod + def __call__(self, gm: torch.fx.GraphModule) -> None: + """ + Implementation of the custom pass. + """ + + @abstractmethod + def uuid(self) -> Optional[Any]: + """ + Return an ID to uniquely identify your custom pass implementation. Return None + to skip inductor code caching entirely. + """ + + +CustomGraphPassType: TypeAlias = Optional[ + Union[CustomGraphPass, Callable[[torch.fx.graph.Graph], None]] +] + + +@lru_cache(1) +def get_hash_for_files(paths: tuple[str, ...], extra: str = "") -> bytes: + """ + Helper to compute a unique string by hashing the contents of a list of files. + """ + hasher = hashlib.sha256() + hasher.update(extra.encode("utf-8")) + for path in paths: + with open(path, "rb") as f: + hasher.update(f.read()) + return hasher.digest() + + +class CustomPartitionerFn(ABC): + """ + Implement this interface for custom partitioner: + + 1) The __call__() method contains the implementation of the custom partitioner. + + 2) The uuid() method enables inductor to cache compiled graphs when your custom + partitioner are applied. This method can return any identifier as long as it uniquely + identifies your implementation (and can be pickled). 
The caching logic includes this + identifier in its key calculation, i.e., any new value will effectively invalidate + existing entries. We expect custom partitioner would typically depend purely on the + textual representation of the implementation. In that case, we recommend using the + 'get_hash_for_files' helper below to compute a unique hash from the contents of a + static list of source files, i.e., the source(s) containing the custom partitioner + implementation. That approach ensures that any change to the implementation will + mean a new uuid. + + EXAMPLE: + + from torch._inductor.custom_graph_pass import get_hash_for_files + + class MyCustomPartitionerFn(CustomPartitionerFn): + def __call__( + self, + gm: torch.fx.GraphModule, + joint_inputs: Sequence[object], + **kwargs: Any + ) -> tuple[torch.fx.GraphModule, torch.fx.GraphModule]: + # my custom partitioner implementation + # ... + + def uuid(self) -> Optional[Any]: + return get_hash_for_files((__file__,)) + + """ + + @abstractmethod + def __call__( + self, gm: torch.fx.GraphModule, joint_inputs: Sequence[object], **kwargs: Any + ) -> tuple[torch.fx.GraphModule, torch.fx.GraphModule]: + """ + Implementation of the custom partitioner. + """ + + @abstractmethod + def uuid(self) -> Optional[Any]: + """ + Return an ID to uniquely identify your custom partitioner implementation. + Return None to skip inductor code caching entirely. 
+ """ + + +CustomPartitionerFnType: TypeAlias = Optional[CustomPartitionerFn] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/debug.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/debug.py new file mode 100644 index 0000000000000000000000000000000000000000..39c90bdea94ffeed2a3b8bf97e86f473bd05049b --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/debug.py @@ -0,0 +1,1336 @@ +import collections +import contextlib +import copy +import dataclasses +import functools +import io +import itertools +import json +import logging +import os +import os.path +import pickle +import pstats +import shutil +import traceback +from collections.abc import Callable, Iterator, Sequence +from typing import Any, IO, Optional, Union +from unittest.mock import patch + +import torch +from functorch.compile import draw_graph, get_aot_graph_name, get_graph_being_compiled +from torch import fx +from torch._dynamo.repro.after_aot import save_graph_repro +from torch._dynamo.utils import get_debug_dir +from torch._inductor import utils +from torch._logging import getArtifactLogger +from torch._logging._internal import trace_structured +from torch._utils_internal import signpost_event +from torch.fx.graph_module import GraphModule +from torch.fx.passes.shape_prop import _extract_tensor_metadata, TensorMetadata +from torch.fx.passes.tools_common import legalize_graph +from torch.types import FileLike +from torch.utils._ordered_set import OrderedSet +from torch.utils._pytree import tree_map + +from . 
import config, ir # noqa: F811, this is needed +from .ir import ExternKernel +from .scheduler import ( + BaseSchedulerNode, + FusedSchedulerNode, + NopKernelSchedulerNode, + OutputNode, + SchedulerNode, +) +from .virtualized import V + + +log = logging.getLogger(__name__) + +# Graph execution tracking for debugging +GRAPH_EXECUTION_ORDER: Optional[list[dict[str, object]]] = None +RECORD_GRAPH_EXECUTION: bool = False +GRAPH_COMPILE_IDS: Optional[dict[int, Optional[str]]] = None + +ir_pre_fusion_log = getArtifactLogger(__name__, "ir_pre_fusion") +ir_post_fusion_log = getArtifactLogger(__name__, "ir_post_fusion") +SchedulerNodeList = list[Any] +BufMeta = collections.namedtuple("BufMeta", ["name", "n_origin"]) +GRAPHVIZ_COMMAND_SCALABLE = ["dot", "-Gnslimit=2", "-Gnslimit1=2", "-Gmaxiter=5000"] + + +@functools.cache +def has_dot() -> bool: + return shutil.which("dot") is not None + + +def draw_buffers( + nodes: list[BaseSchedulerNode], + print_graph: bool = False, + fname: Optional[str] = None, +) -> None: + """ + Draw a graph in fname.svg. 
+ """ + if not has_dot(): + log.warning("draw_buffers() requires `graphviz` package") + return + + if fname is None: + fname = get_graph_being_compiled() + + graph = create_fx_from_snodes(nodes) + + for node in graph.nodes: + if "fusion_meta" not in node.meta: + continue + group = node.meta["fusion_meta"].group + if isinstance(group, tuple): + if isinstance(group[1], int): + group = (group[1],) + else: + group = group[1] + + # gather meta data + dtype = None + if isinstance(node, ir.ComputedBuffer): + dtype = node.data.dtype + + metadata = TensorMetadata(group, dtype, None, None, None, None, None) # type: ignore[arg-type] + # pyrefly: ignore [missing-attribute] + node.meta["tensor_meta"] = metadata + + if print_graph: + print(graph) + + gm = GraphModule({}, graph) + legalize_graph(gm) + gm.graph.lint() + draw_graph( + gm, fname, clear_meta=False, dot_graph_shape=config.trace.dot_graph_shape + ) + + +def create_fx_from_snodes(snodes: list[BaseSchedulerNode]) -> fx.Graph: + """ + Creates a FX Graph from a list of SchedulerNode objects. 
+ """ + + def get_fake_func(name: str) -> Callable[..., int]: + def func1(*args: Any) -> int: + return 0 + + func1.__name__ = name + return func1 + + FusionMeta = collections.namedtuple("FusionMeta", ["group", "snode", "type"]) + + buf_to_fx_node = {} + node_to_fx_node = {} + graph = torch.fx.Graph() + first_node = None + + outputs = [] + group: Any = None + # create call_function node for each Buffer and Kernel + for snode in snodes: + if snode.is_extern(): + node_type = "extern" + group = node_type + elif snode.is_template(): + node_type = "template" + group = node_type + elif isinstance(snode, NopKernelSchedulerNode): + node_type = "nop" + group = node_type + elif isinstance(snode, SchedulerNode): + node_type = "compute" + group = snode.group + elif isinstance(snode, FusedSchedulerNode): + node_type = "fused" + group = snode.group + else: + raise RuntimeError("Unknown node type") + + fused_name = torch._inductor.utils.get_fused_kernel_name( + snode.get_nodes(), "original_aten" + ) + func_name = f"{node_type}: {fused_name}" + node_func = get_fake_func(func_name) + kwargs = {} + if hasattr(snode, "get_device"): + kwargs = {"device": snode.get_device()} + fx_node = graph.call_function(node_func, args=(), kwargs=kwargs) # type: ignore[arg-type] + + def in_output(snode: Union[BaseSchedulerNode, FusedSchedulerNode]) -> bool: + if isinstance(snode, FusedSchedulerNode): + return any(in_output(x) for x in snode.snodes) + return any( + isinstance(user.node, OutputNode) + for buf in snode.get_outputs() + for user in buf.users + ) + + if in_output(snode): + outputs.append(fx_node) + name = snode.get_name() + fx_node.name = name + + fx_node.meta["fusion_meta"] = FusionMeta(group, snode, node_type) + + node_to_fx_node[name] = fx_node + for buf in snode.get_outputs(): + buf_to_fx_node[buf.get_name()] = fx_node + + if first_node is None: + first_node = fx_node + + # create edges between nodes + for snode in snodes: + name = snode.get_name() + deps = snode.read_writes.reads + + 
fx_node = node_to_fx_node[name] + new_args = [] + for dep in deps: + if dep.name in buf_to_fx_node: + dep_node = buf_to_fx_node[dep.name] + else: + with graph.inserting_before(first_node): + dep_node = graph.placeholder(dep.name) + buf_to_fx_node[dep.name] = dep_node + if dep_node == fx_node: # to avoid cycles + continue + new_args.append(dep_node) + + fx_node.args = tuple(new_args) + + graph.output(outputs[0] if len(outputs) == 1 else tuple(outputs)) + return graph + + +def update_orig_fx_node_name_to_buf_name( + nodes: Optional[SchedulerNodeList], + node_name_to_buf_name: dict[str, str], + parent_buf_name: Optional[str] = None, + n_origins: int = 0, +) -> None: + if nodes is None: + return + for node in nodes: + # for FusedSchedulerNode, traverse recursively into get_nodes() + buf_name = node.get_name() + children_nodes = node.get_nodes() + if children_nodes is not None and len(children_nodes) > 1: + update_orig_fx_node_name_to_buf_name( + children_nodes, + node_name_to_buf_name, + buf_name if parent_buf_name is None else parent_buf_name, + ) + continue + else: + # pyrefly: ignore [bad-argument-type, unsupported-operation] + assert len(children_nodes) == 1 and children_nodes[0] == node + + ir_node = node.node + if ir_node is None or ir_node.origins is None: + continue + for origin in ir_node.origins: + node_name = origin.name + # when buf1 and buf2 both have origin=node1 + # we draw node1 according to buf1 + if node_name not in node_name_to_buf_name: + node_name_to_buf_name[node_name] = ( + buf_name if parent_buf_name is None else parent_buf_name + ) + + +def get_node_name_to_buf_meta( + node_name_to_buf_name: dict[str, str], +) -> dict[str, BufMeta]: + buf_name_to_n_node = {} + for node_name, buf_name in node_name_to_buf_name.items(): + if buf_name not in buf_name_to_n_node: + buf_name_to_n_node[buf_name] = OrderedSet([node_name]) + else: + # pyrefly: ignore [missing-attribute] + buf_name_to_n_node[buf_name].add(node_name) + + node_name_to_buf_meta = {} + for 
node_name, buf_name in node_name_to_buf_name.items(): + n_node = len(buf_name_to_n_node[buf_name]) + node_name_to_buf_meta[node_name] = BufMeta(buf_name, n_node) + return node_name_to_buf_meta + + +def annotate_orig_fx_with_snodes( + gm: torch.fx.GraphModule, + snodes: SchedulerNodeList, +) -> None: + """ + Creates a FX Graph from a list of SchedulerNode objects. + """ + node_name_to_buf_name: dict[str, str] = {} + update_orig_fx_node_name_to_buf_name(snodes, node_name_to_buf_name) + if node_name_to_buf_name is None: + return + node_name_to_buf_meta = get_node_name_to_buf_meta(node_name_to_buf_name) + for node in gm.graph.nodes: + if node.name in node_name_to_buf_meta: + node.meta["buf_meta"] = node_name_to_buf_meta.get(node.name) + + +@contextlib.contextmanager +def enable_aot_logging() -> Iterator[None]: + compile_debug = os.environ.get("TORCH_COMPILE_DEBUG", "0") == "1" + + import torch._functorch.aot_autograd + + log = logging.getLogger(torch._functorch.aot_autograd.__name__) + + stack = contextlib.ExitStack() + if not compile_debug: + try: + yield + finally: + stack.close() + return + + # Enable all graphs to be logged to a file by setting the flags to True + # and the log level of the file logger to DEBUG + stack.enter_context(patch("functorch.compile.config.debug_partitioner", True)) + + path = os.path.join(get_debug_dir(), "torchinductor") + os.makedirs(path, exist_ok=True) + + fh = logging.FileHandler( + os.path.join( + path, + f"aot_{get_aot_graph_name()}_debug.log", + ) + ) + fh.setLevel(logging.DEBUG) + fh.setFormatter( + logging.Formatter("[%(filename)s:%(lineno)d %(levelname)s] %(message)s") + ) + log.addHandler(fh) + try: + yield + finally: + log.removeHandler(fh) + stack.close() + + +# Used for provenance tracking +# They are not stored in DebugContext because they are not set in +# _inductor_triton_kernel_to_post_grad_node_info's Debug Context +_inductor_post_to_pre_grad_nodes: dict[str, dict[str, list[str]]] = {} 
+_inductor_triton_kernel_to_post_grad_node_info: dict[str, list[str]] = {} +_pre_grad_graph_id: Optional[int] = None +_inductor_pre_grad_node_stack_trace: dict[str, str] = {} +_inductor_kernel_stack_trace: dict[str, list[str]] = {} +_inductor_kernel_provenance_debug_handle: int = 0 + + +def reset_inductor_kernel_provenance_debug_handle() -> None: + global _inductor_kernel_provenance_debug_handle + _inductor_kernel_provenance_debug_handle = 0 + + +@contextlib.contextmanager +def reset_provenance_globals() -> Iterator[None]: + """Context manager that resets provenance tracking globals upon entering + and restores their original values when exiting.""" + global _pre_grad_graph_id + global _inductor_post_to_pre_grad_nodes + global _inductor_triton_kernel_to_post_grad_node_info + global _inductor_pre_grad_node_stack_trace + global _inductor_kernel_stack_trace + global _inductor_kernel_provenance_debug_handle + + # Store original values + original_pre_grad_graph_id = _pre_grad_graph_id + original_post_to_pre_grad_nodes = _inductor_post_to_pre_grad_nodes.copy() + original_triton_kernel_to_post_grad_node_info = ( + _inductor_triton_kernel_to_post_grad_node_info.copy() + ) + original_inductor_pre_grad_node_stack_trace = ( + _inductor_pre_grad_node_stack_trace.copy() + ) + original_inductor_kernel_stack_trace = _inductor_kernel_stack_trace.copy() + original_inductor_kernel_provenance_debug_handle = ( + _inductor_kernel_provenance_debug_handle + ) + + # Reset to default values + _pre_grad_graph_id = -1 + _inductor_post_to_pre_grad_nodes = {} + _inductor_triton_kernel_to_post_grad_node_info = {} + _inductor_pre_grad_node_stack_trace = {} + _inductor_kernel_stack_trace = {} + _inductor_kernel_provenance_debug_handle = 0 + + try: + yield + finally: + # Restore original values + _pre_grad_graph_id = original_pre_grad_graph_id + _inductor_post_to_pre_grad_nodes = original_post_to_pre_grad_nodes + _inductor_triton_kernel_to_post_grad_node_info = ( + 
original_triton_kernel_to_post_grad_node_info + ) + _inductor_kernel_stack_trace = original_inductor_kernel_stack_trace + _inductor_pre_grad_node_stack_trace = ( + original_inductor_pre_grad_node_stack_trace + ) + _inductor_kernel_provenance_debug_handle = ( + original_inductor_kernel_provenance_debug_handle + ) + + +class DebugContext: + _counter = itertools.count() + + @staticmethod + def create_debug_dir(folder_name: str) -> Optional[str]: + debug_dir = config.trace.debug_dir or get_debug_dir() + for n in DebugContext._counter: + dirname = os.path.join( + debug_dir, + "torchinductor", + f"{folder_name}.{n}", + ) + if not os.path.exists(dirname): + os.makedirs(dirname) + return dirname + return None + + def __init__(self) -> None: + self._prof = None + self._path = None + self._stack = contextlib.ExitStack() + + def copy(self, new_path: str) -> None: + if not self._path: + return + assert new_path.endswith(".debug"), new_path + from filelock import FileLock + + try: + with FileLock(f"{new_path}.lock"): + if os.path.exists(new_path): + shutil.rmtree(new_path) + shutil.copytree(self._path, new_path) + except OSError: + log.warning( + "Failed to copy debug files from %s to %s", self._path, new_path + ) + + def fopen( + self, + filename: str, + write_mode: str = "w", + *args: Any, + **kwargs: Any, + ) -> IO[Any]: + assert self._path + return open(os.path.join(self._path, filename), write_mode, *args, **kwargs) + + @contextlib.contextmanager + def fopen_context( + self, + filename: str, + write_mode: str = "w", + *args: Any, + **kwargs: Any, + ) -> Iterator[IO[Any]]: + assert self._path + with open(os.path.join(self._path, filename), write_mode, *args, **kwargs) as f: + yield f + + def filename(self, suffix: str) -> str: + assert self._path + return os.path.join(self._path, suffix) + + def upload_tar(self) -> None: + if config.trace.upload_tar is not None: + import tarfile + + assert self._path + tar_file = os.path.join( + self._path, 
f"{os.path.basename(self._path)}.tar.gz" + ) + with tarfile.open(tar_file, "w:gz") as tar: + tar.add(self._path, arcname=os.path.basename(self._path)) + config.trace.upload_tar(tar_file) + + def __enter__(self) -> None: + if config.debug: + log = logging.getLogger("torch._dynamo") + prev_level = log.level + log.setLevel(logging.DEBUG) + + def reset_log_level(level: Any) -> None: + log.setLevel(level) + + self._stack.callback(reset_log_level, prev_level) + + self._stack.enter_context(V.set_debug_handler(self)) + + if not config.trace.enabled: + return + + self._path = self.create_debug_dir(get_aot_graph_name()) # type: ignore[assignment] + + if config.trace.debug_log: + self._setup_log_capture("debug.log", logging.DEBUG) + if config.trace.info_log: + self._setup_log_capture("info.log", logging.INFO) + + def _setup_log_capture( + self, + filename: str, + level: int, + ) -> None: + log = logging.getLogger("torch._inductor") + fd = self._stack.enter_context(self.fopen(filename)) + ch = logging.StreamHandler(fd) + ch.setLevel(level) + ch.setFormatter( + logging.Formatter("[%(filename)s:%(lineno)d %(levelname)s] %(message)s") + ) + log.addHandler(ch) + log.setLevel(min(log.level, level)) + self._stack.callback(log.removeHandler, ch) + + def __exit__( + self, + exc_type: Optional[type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[Any], + ) -> None: + if self._prof: + self._prof.disable() + self._save_profile_data() + + if self._path: + self.upload_tar() + log.warning("%s debug trace: %s", get_graph_being_compiled(), self._path) + self._stack.close() + + def _save_profile_data(self) -> None: + assert self._prof + self._prof.dump_stats(self.filename("compile.prof")) + with self.fopen("compile.stats") as fd: + stats = pstats.Stats(self._prof, stream=fd) + stats.strip_dirs() + stats.sort_stats("cumtime") + stats.print_stats(100) + stats.sort_stats("tottime") + stats.print_stats(100) + + def __getattr__(self, name: str) -> Optional[Callable[..., 
None]]: + if config.trace.enabled and getattr(config.trace, name): + try: + return getattr(DebugFormatter(self), name) + except Exception: + log.warning("Ignoring exception in debug code", exc_info=True) + return None + else: + + def ignored(*args: Any, **kwargs: Any) -> None: + pass + + return ignored + + +class DebugFormatter: + def __init__(self, handler: DebugContext) -> None: + self.fopen = handler.fopen + self.fopen_context = handler.fopen_context + self.filename = handler.filename + self.handler = handler + + def fx_graph( + self, + gm: torch.fx.GraphModule, + inputs: list[torch.Tensor], + ) -> None: + with self.fopen("fx_graph_runnable.py") as fd: + save_dir = None + if torch._inductor.config.trace.save_real_tensors: + inputs = torch._subclasses.fake_utils.try_convert_fake_to_real(inputs) + save_dir = os.path.dirname(fd.name) + + # dont try to use stable hash torchinductor compilation if saving real tensors + # and avoid recursively trying to save real tensors inside of the inductor compilation + # regardless + stable_hash = torch._inductor.config.trace.save_real_tensors + with torch._inductor.config.patch( + {"trace.enabled": False, "trace.save_real_tensors": False} + ): + save_graph_repro( + fd, + gm, + inputs, + "inductor", + save_dir=save_dir, + stable_hash=stable_hash, + ) + + with self.fopen("fx_graph_readable.py") as fd: + fd.write(gm.print_readable(print_output=False)) + + def fx_graph_transformed( + self, + gm: torch.fx.GraphModule, + inputs: list[torch.Tensor], + ) -> None: + with self.fopen("fx_graph_transformed.py") as fd: + fd.write(gm.print_readable(print_output=False)) + + def ir_pre_fusion(self, nodes: SchedulerNodeList) -> None: + with self.fopen("ir_pre_fusion.txt") as fd: + fd.write(self._write_ir(nodes)) + + def ir_post_fusion(self, nodes: SchedulerNodeList) -> None: + with self.fopen("ir_post_fusion.txt") as fd: + fd.write(self._write_ir(nodes)) + + @staticmethod + def _write_ir(nodes: SchedulerNodeList) -> str: + buf = io.StringIO() + 
for node in nodes: + buf.write(node.debug_str()) + buf.write("\n\n\n") + return buf.getvalue() + + def graph_diagram(self, nodes: SchedulerNodeList) -> None: + draw_buffers(nodes, fname=self.filename("graph_diagram.svg")) + + def draw_orig_fx_graph( + self, + gm: torch.fx.GraphModule, + nodes: SchedulerNodeList, + ) -> None: + annotate_orig_fx_with_snodes(gm, nodes) + draw_graph( + gm, + fname=self.filename("orig_fx_graph_diagram.svg"), + clear_meta=False, + prog=GRAPHVIZ_COMMAND_SCALABLE, + parse_stack_trace=True, + dot_graph_shape=config.trace.dot_graph_shape, + ) + + def output_code(self, filename: str, extension: str = "py") -> None: + shutil.copy(filename, self.filename(f"output_code.{extension}")) + + def log_autotuning_results( + self, + name: str, + input_nodes: list[ir.IRNode], + timings: dict["ChoiceCaller", float], # type: ignore[name-defined] # noqa: F821 + elapse: float, + precompile_elapse: float, + prescreening_elapse: Optional[float], + ) -> None: + from .ir import FixedLayout + + def build_node_info(node: ir.IRNode) -> dict[str, str]: + if hasattr(node, "name"): + node_name = node.name + else: + node_name = "" + node_info = { + "name": node_name, + "type": type(node).__name__, + } + try: + layout = node.get_output_spec() + if isinstance(layout, FixedLayout): + offset = 0 + try: + offset = int(layout.offset) + except Exception: + try: + offset = V.graph.sizevars.size_hint( + layout.offset, fallback=0 + ) + except Exception: + pass + static_layout = FixedLayout( + layout.device, + dtype=layout.dtype, + size=[*V.graph.sizevars.size_hints(layout.size)], + stride=[*V.graph.sizevars.size_hints(layout.stride)], + offset=offset, + ) + node_info["layout"] = str(static_layout) + else: + node_info["layout"] = str(layout) + except Exception: + pass + try: + node_info["dtype"] = str(node.get_dtype()) + except Exception: + pass + try: + node_info["device"] = str(node.get_device()) + except Exception: + pass + try: + node_info["stride"] = str( + 
V.graph.sizevars.size_hints(node.get_stride()) + ) + except Exception: + pass + try: + node_info["size"] = str(V.graph.sizevars.size_hints(node.get_size())) # type: ignore[arg-type] + except Exception: + pass + try: + node_info["numel"] = str(V.graph.sizevars.size_hint(node.get_numel())) + except Exception: + pass + if hasattr(node, "data") and isinstance(node.data, ir.IRNode): + node_info["data"] = build_node_info(node.data) + return node_info + + general_properties = { + "op_name": name, + "cuda_device_name": torch.cuda.get_device_name(), + "cuda_device_count": torch.cuda.device_count(), + "input_nodes": [build_node_info(node) for node in input_nodes], + "autotuning_time": elapse, + "precompile_time": precompile_elapse, + "prescreening_time": prescreening_elapse, + } + with self.fopen_context( + "autotuning_result_json_list.txt", "at", encoding="utf-8" + ) as fd: + for caller, time in timings.items(): + info_dict = dict(caller.info_dict()) + info_dict.update(general_properties) + info_dict["benchmark_result"] = time + json.dump(info_dict, fd) + fd.write("\n") + + +def log_ir_pre_fusion(nodes: SchedulerNodeList) -> None: + if ir_pre_fusion_log.isEnabledFor(logging.INFO): + ir_pre_fusion_log.info("BEFORE FUSION\n%s", DebugFormatter._write_ir(nodes)) + + V.debug.ir_pre_fusion(nodes) + + +def log_ir_post_fusion(nodes: SchedulerNodeList) -> None: + if ir_post_fusion_log.isEnabledFor(logging.INFO): + ir_post_fusion_log.info("AFTER FUSION\n%s", DebugFormatter._write_ir(nodes)) + + V.debug.ir_post_fusion(nodes) + + +def _dump_collective_schedule(schedule: list[Union[str, None]]) -> None: + try: + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "inductor_collective_schedule", + "encoding": "json", + }, + payload_fn=lambda: schedule, + ) + except Exception: + log.debug( + "Failed to log inductor_collective_schedule via structured logging", + exc_info=True, + ) + + +def log_collective_schedule(nodes: Sequence[BaseSchedulerNode]) -> None: + schedule = [ + 
getattr(op, "python_kernel_name", None) + for node in nodes + if isinstance(op := getattr(node, "node", None), ir._CollectiveKernel) + ] + + # Only log when there is at least one collective op + if schedule: + _dump_collective_schedule(schedule) + + +def log_runtime_and_tensor_meta(node_runtimes: Sequence[tuple[Any, float]]) -> None: + """Log per-op runtime estimates and output tensor metadata for TLParse.""" + + try: + to_size_hints = V.graph.sizevars.size_hints + + def to_list(x: Optional[Sequence[Any]]) -> list[Any]: + return list(to_size_hints(x)) if x is not None else [] + + def dtype_to_str(dtype: Any) -> Optional[str]: + if dtype is None: + return None + s = str(dtype) + s = s.removeprefix("torch.") + return s + + ops: list[dict[str, Any]] = [] + for s, runtime_ns in node_runtimes: + name = getattr(s.node, "python_kernel_name", s.get_name()) + op_type = "collective" if utils.is_collective(s.node) else "compute" + + # Build outputs metadata if available + outputs: list[dict[str, Any]] = [] + try: + for buf in s.get_outputs(): + irnode = buf.node + shape = irnode.maybe_get_size() + stride = ( + irnode.get_stride() + if isinstance(irnode.layout, ir.Layout) + else None + ) + dtype = irnode.maybe_get_dtype() + outputs.append( + { + "shape": to_list(shape), + "stride": to_list(stride), + "dtype": dtype_to_str(dtype), + } + ) + except Exception: + pass + + ops.append( + { + "name": name, + "type": op_type, + "estimated_runtime_ns": runtime_ns, + "outputs": outputs, + } + ) + + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "inductor_runtime_and_tensor_meta", + "encoding": "json", + }, + payload_fn=lambda: {"ops": ops}, + ) + except Exception: + log.debug("Failed to log inductor_runtime_and_tensor_meta", exc_info=True) + + +def log_graph_execution() -> None: + """Emit a structured artifact with the graph execution order.""" + if not GRAPH_EXECUTION_ORDER: + return + try: + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": 
"graph_execution", + "encoding": "json", + }, + payload_fn=lambda: {"graph_execution_order": GRAPH_EXECUTION_ORDER}, + ) + except Exception: + log.debug("Failed to log graph_execution", exc_info=True) + + +@contextlib.contextmanager +def record_and_log_graph_execution_order() -> Iterator[None]: + """Record graph execution order and log it once on exit.""" + global RECORD_GRAPH_EXECUTION, GRAPH_EXECUTION_ORDER, GRAPH_COMPILE_IDS + GRAPH_EXECUTION_ORDER = [] + GRAPH_COMPILE_IDS = {} + RECORD_GRAPH_EXECUTION = True + try: + yield + finally: + log_graph_execution() + RECORD_GRAPH_EXECUTION = False + GRAPH_EXECUTION_ORDER = None + GRAPH_COMPILE_IDS = None + + +@dataclasses.dataclass +class TensorMetadataHolder: + tensor_metadata: TensorMetadata + device: torch.device + + +save_args_cnt = itertools.count() + + +def create_mapping_pre_post_grad_nodes( + pre_grad_graph_id: Optional[int], + post_to_pre_grad_nodes_json: dict[str, Any], +) -> dict[str, dict[str, list[str]]]: + """ + Create bidirectional mappings between pre_grad graph nodes + and post_grad graph code nodes, and vice versa. 
+ """ + # return a dummy dict if there's any error + empty_return: dict[str, dict[str, list[str]]] = { + "preToPost": {}, + "postToPre": {}, + } + + if not isinstance(post_to_pre_grad_nodes_json, dict): + log.error("Provenance tacking error: post_to_pre_grad_nodes_json is not a dict") + return empty_return + + if not isinstance(pre_grad_graph_id, int): + # pre_grad_graph_id may be empty if there's no pre_grad graph + # and there's only a backward graph from backward pass engine + return empty_return + + pre_to_post: dict[str, Any] = collections.defaultdict(OrderedSet) + post_to_pre: dict[str, Any] = collections.defaultdict(OrderedSet) + + try: + + def check_format(node: dict[str, Any]) -> bool: + if not isinstance(node, dict): + log.error( + "Provenance tacking error: node provenance in post_to_pre_grad_nodes_json is not a dict" + ) + return False + if "graph_id" not in node or "name" not in node or "from_node" not in node: + log.error( + "Provenance tacking error: node provenance in post_to_pre_grad_nodes_json has wrong format" + ) + return False + return True + + for outer_key, node_array in post_to_pre_grad_nodes_json.items(): + if not isinstance(node_array, list): + log.error( + "Provenance tacking error: post_to_pre_grad_nodes_json value is not a list" + ) + return empty_return + for node in node_array: + if not check_format(node): + return empty_return + # Check the current node first + if node.get("graph_id") == pre_grad_graph_id: + pre_to_post[node["name"]].add(outer_key) + post_to_pre[outer_key].add(node["name"]) + + # Check nested from_node array recursively, add node with the right graph_id to the map + stack = [(n, outer_key) for n in node.get("from_node", [])] + while stack: + current_node, parent_key = stack.pop() + if not check_format(current_node): + return empty_return + if current_node.get("graph_id") == pre_grad_graph_id: + pre_to_post[current_node["name"]].add(parent_key) + post_to_pre[parent_key].add(current_node["name"]) + stack.extend( + (n, 
parent_key) for n in current_node.get("from_node", []) + ) + + def convert_sets_to_lists(d: dict[str, Any]) -> None: + for key in d: + d[key] = list(d[key]) + d = dict(d) + + # convert to list because set is not JSON serializable + convert_sets_to_lists(pre_to_post) + convert_sets_to_lists(post_to_pre) + return { + "preToPost": pre_to_post, + "postToPre": post_to_pre, + } + except Exception as e: + # Since this is just logging code, it should never interfere with regular + # program execution, so we use this try-except to guard against any error + signpost_event( + "inductor", + "provenance_tracking_error", + { + "function": "create_mapping_pre_post_grad_nodes", + "error_msg": str(e), + "stack_trace": traceback.format_exc(), + }, + ) + log.error("post_to_pre_grad_nodes_json: %s", post_to_pre_grad_nodes_json) + log.error("pre_grad_graph_id: %s", pre_grad_graph_id) + return empty_return + + +def create_node_mapping_kernel_to_post_grad( + triton_kernel_to_post_grad_json: dict[str, Any], +) -> dict[str, dict[str, Any]]: + """Create bidirectional mappings between triton kernel name and post_grad + graph code nodes, and vice versa. 
+ """ + + # return a dummy dict if there's any error + empty_return: dict[str, dict[str, Any]] = { + "cppCodeToPost": {}, + "postToCppCode": {}, + } + + if not isinstance(triton_kernel_to_post_grad_json, dict): + log.error( + "Provenance tacking error: triton_kernel_to_post_grad_json is not a dict" + ) + return empty_return + + post_to_cpp_code: dict[str, Any] = collections.defaultdict(OrderedSet) + + try: + for outer_key, node_array in triton_kernel_to_post_grad_json.items(): + if not isinstance(node_array, list): + log.error( + "Provenance tacking error: triton_kernel_to_post_grad_json value is not a list" + ) + return empty_return + for curr_node in node_array: + post_to_cpp_code[curr_node].add(outer_key) + + def convert_sets_to_lists(d: dict[str, Any]) -> None: + for key in d: + d[key] = list(d[key]) + d = dict(d) + + # convert to list because set is not JSON serializable + convert_sets_to_lists(post_to_cpp_code) + return { + "cppCodeToPost": triton_kernel_to_post_grad_json, + "postToCppCode": post_to_cpp_code, + } + except Exception as e: + # Since this is just logging code, it should never interfere with regular + # program execution, so we use this try-except to guard against any error + signpost_event( + "inductor", + "provenance_tracking_error", + { + "function": "create_mapping_kernel_to_post_grad", + "error_msg": str(e), + "stack_trace": traceback.format_exc(), + }, + ) + log.error( + "triton_kernel_to_post_grad_json: %s", triton_kernel_to_post_grad_json + ) + return empty_return + + +def dump_inductor_provenance_info() -> dict[str, Any]: + try: + global _pre_grad_graph_id + global _inductor_post_to_pre_grad_nodes + global _inductor_triton_kernel_to_post_grad_node_info + node_mapping: dict[str, Any] = {} + if _pre_grad_graph_id: + node_mapping_kernel = create_node_mapping_kernel_to_post_grad( + _inductor_triton_kernel_to_post_grad_node_info + ) + node_mapping = { + **_inductor_post_to_pre_grad_nodes, + **node_mapping_kernel, + } + if 
config.trace.enabled: + with V.debug.fopen( + "inductor_provenance_tracking_node_mappings.json", "w" + ) as fd: + json.dump(node_mapping, fd) + # we need to update the node mapping version when node mapping format changes + # so the tlparse tool knows which node mapping version it is looking at + node_mapping["version"] = 2.0 + return node_mapping + except Exception as e: + # Since this is just debugging, it should never interfere with regular + # program execution, so we use this try-except to guard against any error + signpost_event( + "inductor", + "provenance_tracking_error", + { + "function": "dump_inductor_provenance_info", + "error_msg": str(e), + "stack_trace": traceback.format_exc(), + }, + ) + return {} + + +def create_kernel_information_json() -> dict[str, dict[str, list[str]]]: + """Create kernel information JSON""" + try: + global _inductor_post_to_pre_grad_nodes + global _inductor_kernel_stack_trace + global _inductor_triton_kernel_to_post_grad_node_info + + post_to_pre = _inductor_post_to_pre_grad_nodes.get("postToPre", {}) + all_kernels = OrderedSet(_inductor_kernel_stack_trace.keys()) | OrderedSet( + _inductor_triton_kernel_to_post_grad_node_info.keys() + ) + + result = {} + for kernel_name in all_kernels: + post_grad_nodes = _inductor_triton_kernel_to_post_grad_node_info.get( + kernel_name, [] + ) + + pre_grad_nodes: OrderedSet[str] = OrderedSet() + for post_node in post_grad_nodes: + pre_grad_nodes.update(post_to_pre.get(post_node, [])) + + result[kernel_name] = { + "stack_traces": _inductor_kernel_stack_trace.get(kernel_name, []), + "post_grad_nodes": post_grad_nodes, + "pre_grad_nodes": list(pre_grad_nodes), + } + + return result + except Exception as e: + signpost_event( + "inductor", + "provenance_tracking_error", + { + "function": "create_kernel_information_json", + "error_msg": str(e), + "stack_trace": traceback.format_exc(), + }, + ) + return {} + + +def set_kernel_post_grad_provenance_tracing( + node_schedule: 
Union[Sequence[BaseSchedulerNode], ExternKernel], + kernel_name: str, + is_extern: bool = False, +) -> Optional[int]: + """ + Set the mapping between `kernel_name` and the post_grad nodes in `node_schedule`. + + Returns a unique int debug handler for each call to this function. + """ + + if config.trace.provenance_tracking_level == 0: + return None + + try: + from .codegen.simd_kernel_features import DisableReduction, EnableReduction + + global _inductor_triton_kernel_to_post_grad_node_info + global _inductor_kernel_stack_trace + global _inductor_kernel_provenance_debug_handle + + _inductor_kernel_provenance_debug_handle += 1 + stack_traces: list[str] = [] + kernel_name = f"{kernel_name}:{_inductor_kernel_provenance_debug_handle}" + if is_extern: + assert isinstance(node_schedule, ExternKernel) + curr_node_info = _inductor_triton_kernel_to_post_grad_node_info.setdefault( + kernel_name, [] + ) + # 'origins' on IR nodes gives what FX IR nodes contributed to any given fused kernel. + # "origin_node" is more precise and says that the contents of this node corresponds + # EXACTLY to the output of a particular FX node, but it's not always available + if node_schedule.origin_node: + origin_node_name = node_schedule.origin_node.name + if origin_node_name not in curr_node_info: + curr_node_info.append(origin_node_name) + else: + curr_node_info.extend( + origin.name + for origin in node_schedule.origins + if origin.name not in curr_node_info + ) + stack_traces = list(node_schedule.get_stack_traces()) + else: + assert isinstance(node_schedule, list) + stack_traces_set: OrderedSet[str] = OrderedSet() + for snode in node_schedule: + if snode not in (EnableReduction, DisableReduction): + if snode.node is not None: + curr_node_info = ( + _inductor_triton_kernel_to_post_grad_node_info.setdefault( + kernel_name, [] + ) + ) + # pyrefly: ignore [missing-attribute] + stack_traces_set.update(snode.node.get_stack_traces()) + curr_node_info.extend( + origin.name + # pyrefly: ignore 
[missing-attribute] + for origin in snode.node.origins + if origin.name not in curr_node_info + ) + stack_traces = list(stack_traces_set) + _inductor_kernel_stack_trace.setdefault(kernel_name, []).extend(stack_traces) + return _inductor_kernel_provenance_debug_handle + except Exception as e: + # Since this is just debugging, it should never interfere with regular + # program execution, so we use this try-except to guard against any error + signpost_event( + "inductor", + "provenance_tracking_error", + { + "function": "set_kernel_post_grad_provenance_tracing", + "error_msg": str(e), + "stack_trace": traceback.format_exc(), + }, + ) + return None + + +def save_args_for_compile_fx_inner(*args: Any, **kwargs: Any) -> None: + """ + This function is used to save arguments for a compile_fx_inner function call + to the file system. Later on one can replay the compile_fx_inner call + with the saved arguments using load_args_and_run_compile_fx_inner. + """ + + folder = "/tmp/inductor_saved_args" + if not os.path.exists(folder): + os.mkdir(folder) + + def handle_tensor(x: Any) -> Any: + """ + Pickle FakeTensor will result in error: + AttributeError: Can't pickle local object 'WeakValueDictionary.__init__..remove' + + Convert all Tensor to metadata. This may also makes pickle faster. + """ + if isinstance(x, torch.Tensor): + return TensorMetadataHolder(_extract_tensor_metadata(x), x.device) + else: + return x + + args_to_save, kwargs_to_save = tree_map(handle_tensor, (args, kwargs)) + + fn_name = "compile_fx_inner" + path = f"{folder}/{fn_name}_{next(save_args_cnt)}.pkl" + with open(path, "wb") as f: + pickle.dump((args_to_save, kwargs_to_save), f) + + if log.isEnabledFor(logging.DEBUG): + message = f""" +Arguments for a compile_fx_inner call is saved to {path}. To replay the call, +run the following: + +from torch._inductor.debug import load_args_and_run_compile_fx_inner +load_args_and_run_compile_fx_inner({path!r}) + """ + # call print rather than log.debug. 
log.debug will print message + # prefix for each line which makes the code snippet harder to be + # copied. + # Not a big deal since the code is already been guarded by checking + # the log level. + print(message) + + +def load_args_and_run_compile_fx_inner(path: str) -> Any: + from torch._inductor.compile_fx import compile_fx_inner + + with open(path, "rb") as f: + args, kwargs = pickle.load(f) + + def handle_tensor(x: Any) -> Any: + if isinstance(x, TensorMetadataHolder): + return torch._dynamo.testing.rand_strided( + x.tensor_metadata.shape, + x.tensor_metadata.stride, + x.tensor_metadata.dtype, + x.device, + ) + else: + return x + + fake_mode = torch._subclasses.FakeTensorMode(allow_non_fake_inputs=True) + with fake_mode, config.patch("save_args", False): + args, kwargs = tree_map(handle_tensor, (args, kwargs)) + return compile_fx_inner(*args, **kwargs) + + +def aot_inductor_minifier_wrapper( + func: Callable[..., str], + exported_program: torch.export.ExportedProgram, + *, + inductor_configs: dict[str, Any], + package_path: Optional[FileLike] = None, +) -> str: + from torch._dynamo.debug_utils import AccuracyError + from torch._dynamo.repro.aoti import dump_to_minify + from torch._inductor import config + from torch._inductor.compile_fx import _aoti_flatten_inputs + + use_minifier = config.aot_inductor.dump_aoti_minifier + + gm = exported_program.module(check_guards=False) + assert isinstance(gm, torch.fx.GraphModule) + + args, kwargs = exported_program.example_inputs + + try: + if use_minifier and config.aot_inductor.repro_level == 3: + # Always dump the original module in case we have segfaults + dump_to_minify( + exported_program, + "aot_inductor", + options=inductor_configs, + ) + if use_minifier and config.aot_inductor.repro_level == 4: + # Check for accuracy + # We will first flatten the inputs before compiling and checking for accuracy. + # This is ok because we will flatten the inputs in the minifier anyway. 
+ gm_copy = copy.deepcopy(gm) + example_inputs_copy = copy.deepcopy(exported_program.example_inputs) + config_copy = copy.deepcopy(inductor_configs) + flat_example_inputs, config_copy = _aoti_flatten_inputs( + gm_copy, + example_inputs_copy[0], + example_inputs_copy[1], + options=config_copy, + ) + tuple_inputs = tuple(flat_example_inputs) + flattened_ep = torch.export.export(gm_copy, tuple_inputs, strict=False) + func( + flattened_ep.module(check_guards=False), + tuple_inputs, + inductor_configs=config_copy, + package_path=package_path, + load_and_run=True, + check_accuracy="accuracy", + ) + + return func( + gm, + args, + kwargs, + inductor_configs=inductor_configs, + package_path=package_path, + load_and_run=use_minifier, + ) + except AccuracyError as e: + dump_to_minify( + exported_program, + "aot_inductor_accuracy", + command="minify", + options=inductor_configs, + ) + log.warning("Accuracy failed") + raise e + except Exception as e: + if use_minifier: + command = "minify" + + if config.aot_inductor.repro_level == 1: + command = "run" + + dump_to_minify( + exported_program, + "aot_inductor", + command=command, + options=inductor_configs, + ) + raise e diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/decomposition.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/decomposition.py new file mode 100644 index 0000000000000000000000000000000000000000..25e0ea31649d4e5749719041cc27ac61d7e84e76 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/decomposition.py @@ -0,0 +1,1259 @@ +# mypy: allow-untyped-decorators +import functools +import logging +import math +import operator +import sys +import typing +from collections.abc import Callable +from typing import Any, Optional, TypeAlias, TypeVar, Union +from typing_extensions import ParamSpec + +import torch +import torch._decomp as decomp +import torch._prims_common as utils +import 
torch.ao.quantization.fx._decomposed +from torch._decomp import ( + core_aten_decompositions, + get_decompositions, + remove_decompositions, +) +from torch._decomp.decompositions import ( + _grid_sampler_2d as decomp_grid_sampler_2d, + _index_add, + embedding_dense_backward as decomp_embedding_dense_backward, + pw_cast_for_opmath, + pw_cast_for_opmath_non_tensor_args, +) +from torch._decomp.decompositions_for_rng import extra_random_decomps +from torch._dynamo.utils import counters +from torch._environment import is_fbcode +from torch._higher_order_ops.out_dtype import out_dtype +from torch._inductor.utils import pad_listlike +from torch._prims_common import ( + elementwise_dtypes, + ELEMENTWISE_TYPE_PROMOTION_KIND, + type_to_dtype, +) +from torch._refs import native_layer_norm as decomp_native_layer_norm +from torch.fx.experimental.symbolic_shapes import guard_or_false, statically_known_true + +from . import config, inductor_prims +from .utils import ( + is_gpu, + needs_fallback_due_to_atomic_add_limitations, + use_scatter_fallback, +) + + +_T = TypeVar("_T") +_P = ParamSpec("_P") + +_GenericOperator: TypeAlias = Union[ + torch._ops.OperatorBase, torch._ops.OpOverloadPacket +] + +log = logging.getLogger(__name__) +aten = torch.ops.aten +prims = torch.ops.prims +quantized = torch.ops.quantized +_quantized = torch.ops._quantized +quantized_decomposed = torch.ops.quantized_decomposed + +inductor_decompositions = get_decompositions( + [ + aten._adaptive_avg_pool2d_backward, + aten.index_select, + aten.addmv, + aten.arange, + aten.bitwise_and_, + aten.bitwise_or_, + aten.clamp_min_, + aten.dist, + aten.elu, + aten.empty_like, + aten.flip, + aten.gelu, + aten.hardtanh, + aten.lcm, + aten.leaky_relu, + aten.linalg_vector_norm, + aten._log_softmax, + aten.max_pool2d_with_indices_backward, + aten._native_batch_norm_legit, + aten._native_batch_norm_legit_functional, + aten._native_batch_norm_legit_no_training, + aten._batch_norm_with_update, + 
aten._batch_norm_with_update_functional, + aten._batch_norm_no_update, + aten.batch_norm_backward, + aten.native_batch_norm, + aten.native_group_norm, + aten.native_layer_norm, + aten.nll_loss2d_backward, + aten.permute_copy, + aten.rrelu_with_noise_backward, + aten._softmax, + aten.sin_, + aten.sqrt_, + out_dtype, + aten._to_copy, + aten.tril_indices, + aten.triu_indices, + aten.unbind_copy.int, + aten.upsample_bilinear2d.vec, + quantized.linear_dynamic_fp16_unpacked_weight, + _quantized.wrapped_quantized_linear, + ] +) +decompositions = {**core_aten_decompositions(), **inductor_decompositions} + +# Remove unwanted decompositions included via the core ATen decompositions from +# the Inductor decomp table. +decomps_to_exclude: list[Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket]] = [ + aten._unsafe_index, + aten._unsafe_masked_index, + aten._unsafe_masked_index_put_accumulate, + aten._scaled_dot_product_flash_attention_for_cpu.default, # See comments in torch/_decomp/decompositions.py + aten._softmax_backward_data, + aten.clamp_max, + aten.clamp_min, + aten.embedding_dense_backward, # we fall back on xpu + aten.native_layer_norm, # we fall back on mtia + aten.index_add, # we conditionally call this decomp + aten.glu, # inductor lowers this directly + aten.select_scatter, # need to be in the ATen graph in order for it to work with the re-inplacing pass + aten.slice_scatter, # need to be in the ATen graph in order for it to work with the re-inplacing pass + aten.split.Tensor, # inductor lowers this directly + aten.squeeze, # inductor lowers this directly + aten.sum, # inductor lowers this directly + aten.unbind, # inductor lowers this directly + aten.baddbmm, # upcasts to fp32, perf issue +] + +remove_decompositions(decompositions, decomps_to_exclude) + + +def register_decomposition( + ops: Union[_GenericOperator, list[_GenericOperator]], +) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: + for op in ops if isinstance(ops, list) else [ops]: + if op in 
decompositions: + log.warning("duplicate decomp: %s", ops) + return decomp.register_decomposition(ops, decompositions) + + +@register_decomposition([aten.embedding_dense_backward]) +def _embedding_dense_backward( + grad_output: torch.Tensor, + indices: torch.Tensor, + num_weights: int, + padding_idx: int, + scale_grad_by_freq: bool, +) -> torch.Tensor: + # TODO: check if XE4 still need this fallback + # check torch.xpu.get_device_properties(grad_output.device).architecture + if grad_output.is_xpu: + return NotImplemented + # We can write a util function to update decomp table if we have more ops to fallback. + return decomp_embedding_dense_backward( + grad_output, indices, num_weights, padding_idx, scale_grad_by_freq + ) + + +@register_decomposition(aten.native_layer_norm) +def _native_layer_norm( + input: torch.Tensor, + normalized_shape: utils.ShapeType, + weight: Optional[torch.Tensor], + bias: Optional[torch.Tensor], + eps: float, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if input.is_mtia: + return NotImplemented + # We can write a util function to update decomp table if we have more ops to fallback. 
+ return decomp_native_layer_norm(input, normalized_shape, weight, bias, eps) + + +@register_decomposition([aten.sym_constrain_range_for_size.default]) +def sym_constrain_range_for_size( + symbol: torch.SymInt, + *, + min: Optional[torch.types.Number] = None, + max: Optional[torch.types.Number] = None, +) -> None: + return + + +@register_decomposition([aten.clamp]) +@pw_cast_for_opmath_non_tensor_args +def clamp( + x: torch.Tensor, + min: Optional[torch.types.Number] = None, + max: Optional[torch.types.Number] = None, +) -> torch.Tensor: + if min is not None: + x = x.clamp_min(min) + if max is not None: + x = x.clamp_max(max) + return x + + +@register_decomposition([aten.full]) +def full( + size: list[Union[int, torch.SymInt]], + fill_value: torch.types.Number, + **kwargs: Any, +) -> torch.Tensor: + dtype = kwargs.get("dtype") + if dtype is None: + kwargs["dtype"] = type_to_dtype(type(fill_value)) + return torch.full(size, fill_value, **kwargs) + return NotImplemented + + +@register_decomposition([aten.index_add]) +def index_add( + x: torch.Tensor, + dim: int, + index: torch.Tensor, + tensor: torch.Tensor, + *, + alpha: torch.types.Number = 1, +) -> torch.Tensor: + # If we are not in fbcode and dtype is bfloat16 + # fallback to index_add kernel + # see https://github.com/pytorch/pytorch/issues/137425 for details + if not is_fbcode() and x.dtype == torch.bfloat16: + return NotImplemented + else: + return _index_add(x, dim, index, tensor, inplace=False, alpha=alpha) + + +# Not really sure how to put this into the main library. 
PrimTorch wants +# empty_permuted to go to the prim, and typically users don't really want +# to decompose to empty_strided (but inductor is OK with it, because we are +# cool with strides and everything goes to empty_strided) +@register_decomposition([aten.empty_permuted.default]) +def empty_permuted( + size: list[Union[int, torch.SymInt]], + physical_layout: list[int], + **kwargs: Any, +) -> torch.Tensor: + is_identity = list(physical_layout) == list(range(len(physical_layout))) + + if is_identity: + return torch.empty(size, **kwargs) + else: + perm = [0] * len(size) + for p, l in enumerate(physical_layout): + perm[l] = p + return torch.empty([size[l] for l in physical_layout], **kwargs).permute(perm) + + +@register_decomposition([aten.convolution_backward]) +def convolution_backward( + grad_output: torch.Tensor, + input: torch.Tensor, + weight: torch.Tensor, + bias_sizes: list[int], + stride: Union[int, list[int]], + padding: Union[int, list[int]], + dilation: Union[int, list[int]], + transposed: bool, + output_padding: list[int], + groups: int, + output_mask: list[bool], +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if not output_mask[2] or not is_gpu(grad_output.device.type): + return NotImplemented + grad_bias = aten.sum(grad_output, [0] + list(range(2, grad_output.dim()))) + grad_inp, grad_weight, _ = aten.convolution_backward( + grad_output, + input, + weight, + bias_sizes, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + [output_mask[0], output_mask[1], False], + ) + return (grad_inp, grad_weight, grad_bias) + + +@register_decomposition([aten.round.decimals]) +def round_dec(x: torch.Tensor, decimals: int = 0) -> torch.Tensor: + ten_pow_decimals = 10.0**decimals + return aten.round(x * ten_pow_decimals) * (1.0 / ten_pow_decimals) + + +@register_decomposition([aten.bmm]) +@pw_cast_for_opmath +def bmm( + self: torch.Tensor, + batch2: torch.Tensor, + out_dtype: Optional[torch.dtype] = None, +) -> torch.Tensor: + # TODO: 
Re-enable for mps once our reductions are performant enough + # (https://github.com/pytorch/pytorch/issues/150121) + if config.coordinate_descent_tuning and self.device.type not in ["cpu", "mps"]: + if statically_known_true(self.shape[1] == 1) or statically_known_true( + batch2.shape[2] == 1 + ): + out = (self.unsqueeze(-1) * batch2.unsqueeze(1)).sum(dim=2) + return out + if self.device.type == "cpu": + if statically_known_true(self.size(1) == 1) and statically_known_true( + batch2.size(-1) == 1 + ): + counters["inductor"]["decompose_bmm"] += 1 + return torch.sum( + self.squeeze(1) * batch2.squeeze(-1), dim=1, keepdim=True + ).unsqueeze(1) + return NotImplemented + + +@register_decomposition([aten.addmm]) +@pw_cast_for_opmath +def addmm( + self: torch.Tensor, + mat1: torch.Tensor, + mat2: torch.Tensor, + out_dtype: Optional[torch.dtype] = None, + beta: torch.types.Number = 1, + alpha: torch.types.Number = 1, +) -> torch.Tensor: + if self.device.type == "cpu": + if statically_known_true(mat1.size(0) == 1) and statically_known_true( + mat2.size(-1) == 1 + ): + counters["inductor"]["decompose_addmm"] += 1 + out = torch.sum( + mat1.squeeze(0) * mat2.squeeze(-1), dim=0, keepdim=True + ).unsqueeze(0) + return alpha * out + beta * self + if ( + statically_known_true(mat1.size(0) == 1) + and guard_or_false(mat2.size(0) <= 16) + and guard_or_false(mat2.size(1) <= 16) + ): + counters["inductor"]["decompose_addmm"] += 1 + out = (mat1.T * mat2).sum(dim=0, keepdim=True) + return alpha * out + beta * self + return NotImplemented + + +@register_decomposition([aten.mm]) +@pw_cast_for_opmath +def mm( + self: torch.Tensor, + input2: torch.Tensor, + out_dtype: Optional[torch.dtype] = None, +) -> torch.Tensor: + # Our matrix vector multiplies only achieve peak bandwidth with coordinate descent tuning. 
+ # todo: Look into why and fix it (hopefully) + + # TODO: Re-enable for mps once our reductions are performant enough + # (https://github.com/pytorch/pytorch/issues/150121) + if config.coordinate_descent_tuning and self.device.type not in ["cpu", "mps"]: + if statically_known_true(self.shape[0] == 1) or statically_known_true( + input2.shape[1] == 1 + ): + return (self.unsqueeze(2) * input2.unsqueeze(0)).sum(dim=1) + if self.device.type == "cpu": + if ( + statically_known_true(self.size(-1) == 1) + and statically_known_true(self.size(0) > 0) + and statically_known_true(input2.size(0) == 1) + and (self.dtype == input2.dtype) + and guard_or_false((torch.numel(self) + torch.numel(input2)) <= 32) + ): + counters["inductor"]["decompose_mm"] += 1 + return self * input2 + if statically_known_true(self.size(0) == 1) and statically_known_true( + input2.size(-1) == 1 + ): + counters["inductor"]["decompose_mm"] += 1 + return torch.sum( + self.squeeze(0) * input2.squeeze(-1), dim=0, keepdim=True + ).unsqueeze(0) + return NotImplemented + + +# This pass does two things: +# - Eliminate cat when there is only one tensor input +# - Normalize cat calls, so that legacy empty 1-D tensors are removed (NB: we +# don't remove ALL empty tensors, only the naughty ones) +@register_decomposition([aten.cat.default]) +def cat( + tensors: list[torch.Tensor], + dim: int = 0, +) -> torch.Tensor: + def non_empty_tensor(x: torch.Tensor) -> bool: + # For better or worse, this is a valid cat: + # + # torch.cat([torch.randn(2, 2, 4), torch.randn(0), torch.randn(3, 2, 4)]) + # + # We'd like to eliminate naughtiness like this for downstream passes + # like split_cat. The easiest way is to just drop such inputs + # (guarding that they are non-zero). + # + # Is it permissible for this filtering to be size-oblivious? A case + # where this could matter is cat([(2, 2), (u0,)], dim=0); if u0 + # happened to be zero, we would have liked to have filtered it out. 
+ # But actually, the ONLY way this could have passed is if u0 == 0, + # so by the time we get here we have already installed a deferred + # runtime assert forcing u0 to be zero. So if this hasn't happened, + # we know that the unbacked SymInt has appropriate size and there are + # no problems. + if len(x.shape) == 1 and guard_or_false(x.shape[0] == 0): + return False + + if dim < len(x.shape) and guard_or_false(x.shape[dim] == 0): + return False + + return True + + filtered_tensors = list(filter(non_empty_tensor, tensors)) + + if len(filtered_tensors) == 1: + # check dtype promotion + promoted_dtype = elementwise_dtypes( + *tensors, + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + )[1] + filtered_t = filtered_tensors[0] + return ( + filtered_t.clone() + if promoted_dtype == filtered_t.dtype + else filtered_t.to(dtype=promoted_dtype) + ) + elif 1 < len(filtered_tensors) < len(tensors): + # on the first call, when we remove empty tensors, we redispatch recursively + return aten.cat.default(filtered_tensors, dim) + + # optimization, avoid concat for single, repeated input + if len(filtered_tensors) > 1 and all( + t is filtered_tensors[0] for t in filtered_tensors + ): + inp = filtered_tensors[0] + shape = list(inp.shape) + dim = dim + len(inp.shape) if dim < 0 else dim + shape.insert(dim, len(filtered_tensors)) + return inp.unsqueeze(dim).expand(*shape).flatten(dim, dim + 1).clone() + + # when no 'filtering' has occurred, we raise to prevent infinite recursion (no more decomposition needed) + return NotImplemented + + +@register_decomposition([aten.angle]) +def angle(x: torch.Tensor) -> torch.Tensor: + if x.is_complex(): + return torch.where( + torch.isnan(x.real), float("nan"), torch.atan2(x.imag, x.real) + ) + + # when x is real number + # if x >= 0, return 0 + # if x < 0, return pi + # if x is nan, return nan + _, dtype = elementwise_dtypes( + x, + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + ) + pi = 
torch.scalar_tensor(math.pi, dtype=dtype, device=x.device) + ret = torch.where(x < 0, pi, 0.0) + return torch.where(torch.isnan(x), float("nan"), ret) + + +@register_decomposition([aten.add]) +def add( + x: torch.Tensor, + y: torch.Tensor, + *, + alpha: Optional[torch.types.Number] = None, +) -> torch.Tensor: + # Require both x and y to be complex tensors. + x_is_complex_tensor = torch.is_tensor(x) and x.is_complex() + y_is_complex_tensor = torch.is_tensor(y) and y.is_complex() + if not x_is_complex_tensor or not y_is_complex_tensor: + return NotImplemented + + def _requires_fallback(tensor: torch.Tensor) -> bool: + if tensor.ndim == 0: + return False + # Viewing complex tensors as their real dtype requires the last stride to be 1. + return tensor.stride()[-1] != 1 + + output_size_zero = False + if x.ndim == 0 and y.ndim == 0: + output_size_zero = True + + if x.ndim == 0: + x = x.reshape(1) + if y.ndim == 0: + y = y.reshape(1) + + z = y + if alpha is not None: + z = alpha * y + complex_type = torch.promote_types(x.dtype, y.dtype) + + if _requires_fallback(x) or _requires_fallback(z): + return NotImplemented + + # For complex typed `x`, `x.view(x.real.dtype)` doubles the last dimension and can cause problem + # when broadcasting the add. + def reshape_tensor_complex(tensor: torch.Tensor) -> torch.Tensor: + """Reshape tensor from [*initial_dims, last_dim] to *initial_dims, last_dim/2, 2]""" + # Get the current shape of the tensor + *initial_dims, last_dim = tensor.shape + + # Check if the last dimension is even. We should never reach here since `x.view(x.real.dtype)` + # doubles the last dimension for complex numbers. 
+ if last_dim % 2 != 0: + raise AssertionError( + "The size of the last dimension must be even to reshape it to [..., last_dim/2, 2]" + ) + + # Reshape the tensor + new_shape = (*initial_dims, last_dim // 2, 2) + reshaped_tensor = tensor.view(new_shape) + return reshaped_tensor + + # Manually resolve complex tensors, as .is_conj() is unreliable after cloning during compilation. + x = x + 0 + z = z + 0 + + x_reshaped = reshape_tensor_complex(x.view(x.real.dtype)) + z_reshaped = reshape_tensor_complex(z.view(y.real.dtype)) + result = torch.flatten(x_reshaped + z_reshaped, start_dim=-2).view(complex_type) + + if output_size_zero: + return result[0] + return result + + +@register_decomposition([aten.conj_physical]) +def conj_physical(self: torch.Tensor) -> torch.Tensor: + if self.is_complex(): + return NotImplemented + return self + + +@register_decomposition([aten.lift, aten.detach_]) +def lift(self: torch.Tensor) -> torch.Tensor: + return self + + +@register_decomposition([aten.fmin, prims.fmin]) +def fmin(self: torch.Tensor, other: torch.Tensor) -> torch.Tensor: + return torch.where(torch.isnan(other) | (other > self), self, other) + + +@register_decomposition([aten.fmax, prims.fmax]) +def fmax(self: torch.Tensor, other: torch.Tensor) -> torch.Tensor: + return torch.where(torch.isnan(other) | (other < self), self, other) + + +@register_decomposition(aten.amax) +def amax( + self: torch.Tensor, + dim: Optional[int] = None, + keepdim: bool = False, +) -> torch.Tensor: + if self.dtype == torch.bool: + return torch.any(self, dim=dim, keepdim=keepdim) + return NotImplemented + + +@register_decomposition(aten.amin) +def amin( + self: torch.Tensor, + dim: Optional[int] = None, + keepdim: bool = False, +) -> torch.Tensor: + if self.dtype == torch.bool: + return torch.all(self, dim=dim, keepdim=keepdim) + return NotImplemented + + +@register_decomposition([aten.narrow_copy]) +def narrow_copy( + self: torch.Tensor, + dim: int, + start: int, + length: int, +) -> torch.Tensor: + 
return torch.narrow(self, dim, start, length).clone() + + +@register_decomposition([aten.view_copy.default]) +def view_copy_default( + self: torch.Tensor, + size: list[Union[int, torch.SymInt]], +) -> torch.Tensor: + return aten.view(self, size).clone() + + +@register_decomposition([aten.view_copy.dtype]) +def view_copy_dtype( + self: torch.Tensor, + dtype: torch.dtype, +) -> torch.Tensor: + return self.to(dtype).clone() + + +def _get_shape_permutation_like( + self: torch.Tensor, +) -> tuple[utils.ShapeType, utils.StrideType]: + physical_layout, _ = utils.compute_elementwise_output_logical_to_physical_perm(self) + shape = [self.shape[l] for l in physical_layout] + + permutation = [0] * len(shape) + for p, l in enumerate(physical_layout): + permutation[l] = p + + return (shape, permutation) + + +@register_decomposition(aten.full_like) +def full_like( + self: torch.Tensor, + fill_value: Union[int, float], + *, + dtype: Optional[torch.dtype] = None, + layout: Optional[torch.layout] = None, + device: Optional[torch.device] = None, + pin_memory: bool = False, + requires_grad: bool = False, + memory_format: torch.memory_format = torch.preserve_format, +) -> torch.Tensor: + dtype = self.dtype if dtype is None else dtype + layout = self.layout if layout is None else layout + device = self.device if device is None else device + + if memory_format != torch.preserve_format: + result = torch.full( + self.shape, + fill_value, + dtype=dtype, + layout=layout, + device=device, + pin_memory=pin_memory, + requires_grad=requires_grad, + ) + return result.to(memory_format=memory_format) + + else: + assert layout == torch.strided + shape, permutation = _get_shape_permutation_like(self) + result = torch.full( + shape, + fill_value, + dtype=dtype, + layout=layout, + device=device, + pin_memory=pin_memory, + requires_grad=requires_grad, + ) + if permutation == list(range(len(permutation))): + return result + return result.permute(permutation).clone() + + +def _rand_like( + rand_fn: 
Callable[..., torch.Tensor], + self: torch.Tensor, + *, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + memory_format: torch.memory_format = torch.preserve_format, + **kwargs: Any, +) -> torch.Tensor: + dtype = self.dtype if dtype is None else dtype + device = self.device if device is None else device + + if memory_format != torch.preserve_format: + return rand_fn( + self.shape, + dtype=dtype, + device=device, + **kwargs, + ).to(memory_format=memory_format) + + shape, permutation = _get_shape_permutation_like(self) + result = rand_fn( + shape, + dtype=dtype, + device=device, + **kwargs, + ) + if permutation == list(range(len(permutation))): + return result + return result.permute(permutation).clone() + + +@register_decomposition(aten.rand_like) +def rand_like(self: torch.Tensor, **kwargs: Any) -> torch.Tensor: + return _rand_like(torch.rand, self, **kwargs) + + +@register_decomposition(aten.randn_like) +def randn_like(self: torch.Tensor, **kwargs: Any) -> torch.Tensor: + return _rand_like(torch.randn, self, **kwargs) + + +@register_decomposition(aten.randint_like.default) +def randint_like(self: torch.Tensor, high: int, **kwargs: Any) -> torch.Tensor: + return _rand_like(functools.partial(aten.randint.low, 0, high), self, **kwargs) + + +@register_decomposition(aten.randint_like.low_dtype) +def randint_like_low( + self: torch.Tensor, low: int, high: int, **kwargs: Any +) -> torch.Tensor: + return _rand_like(functools.partial(aten.randint.low, low, high), self, **kwargs) + + +@register_decomposition(aten.randint.default) +def randint( + high: int, + size: list[Union[int, torch.SymInt]], + **kwargs: Any, +) -> torch.Tensor: + return aten.randint.low(0, high, size, **kwargs) + + +@register_decomposition(quantized.linear_dynamic_fp16_unpacked_weight.default) +def linear_dynamic_fp16_unpacked_weight( + input: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: + packed_weight = 
torch.ops._quantized.wrapped_fbgemm_pack_gemm_matrix_fp16(weight) + return torch.ops._quantized.wrapped_fbgemm_linear_fp16_weight( + input, packed_weight, bias, weight.size()[0] + ) + + +@register_decomposition(_quantized.wrapped_quantized_linear.default) +def wrapped_quantized_linear( + input: torch.Tensor, + input_scale: torch.Tensor, + input_zero_point: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_zero_point: torch.Tensor, + bias: torch.Tensor, + out_scale: torch.Tensor, + out_zero_point: torch.Tensor, + out_channel: int, +) -> torch.Tensor: + packed_weight = torch.ops._quantized._wrapped_linear_prepack( + weight, weight_scale, weight_zero_point, bias + ) + return torch.ops._quantized._wrapped_quantized_linear_prepacked( + input, + input_scale, + input_zero_point, + packed_weight, + out_scale, + out_zero_point, + out_channel, + ) + + +@register_decomposition(torch.ops.quantized.embedding_bag_byte_unpack) +def q_embedding_bag_byte_unpack_decomp(packed: torch.Tensor) -> torch.Tensor: + def bitcast_u8_to_f32(u8: torch.Tensor) -> torch.Tensor: + x, y, z, w = (u8[..., n].to(torch.int32) for n in (0, 1, 2, 3)) + if sys.byteorder == "little": + return (x + (y << 8) + (z << 16) + (w << 24)).view(torch.float32)[..., None] + else: + return ((x << 24) + (y << 16) + (z << 8) + w).view(torch.float32)[..., None] + + scales = bitcast_u8_to_f32(packed[..., -8:-4]) + offsets = bitcast_u8_to_f32(packed[..., -4:]) + return packed[..., :-8].to(torch.float32) * scales + offsets + + +@register_decomposition([aten.grid_sampler_2d]) +@pw_cast_for_opmath +def grid_sampler_2d( + a: torch.Tensor, + grid: torch.Tensor, + interpolation_mode: int = 0, + padding_mode: int = 0, + align_corners: bool = False, +) -> torch.Tensor: + # We do not expand the grid (_expand_grid=False) on cpu for performance reasons + # Experimenting locally it was found that compiled CUDA code is accelerated by ~5x + # and CPU code by ~2x on bicubic mode, if we expand the grid from (N, 
H, W, 2) into (N, C, H, W, 2) + # However, this leads to a slowdown around ~0.8x on CPU bilinear mode, channels first. + # Thus we apply this hack to not expand the grid for this case. + _expand_grid = not ( + a.device == torch.device("cpu") + and interpolation_mode == 0 + and a.is_contiguous(memory_format=torch.contiguous_format) + ) + + output = decomp_grid_sampler_2d( + a, + grid=grid, + interpolation_mode=interpolation_mode, + padding_mode=padding_mode, + align_corners=align_corners, + _expand_grid=_expand_grid, + ) + return output + + +@register_decomposition(aten._foreach_addcmul.Scalar) +def _foreach_addcmul_scalar( + self: list[torch.Tensor], + left_tensors: list[torch.Tensor], + right_tensors: list[torch.Tensor], + scalar: float = 1, +) -> list[torch.Tensor]: + return aten._foreach_add.List( + self, aten._foreach_mul.List(left_tensors, right_tensors), alpha=scalar + ) + + +@register_decomposition(aten._foreach_addcdiv.Scalar) +def _foreach_addcdiv_scalar( + self: list[torch.Tensor], + left_tensors: list[torch.Tensor], + right_tensors: list[torch.Tensor], + scalar: float = 1, +) -> list[torch.Tensor]: + return aten._foreach_add.List( + self, aten._foreach_div.List(left_tensors, right_tensors), alpha=scalar + ) + + +@register_decomposition(aten._foreach_lerp.Scalar) +def _foreach_lerp_scalar( + start_tensors: list[torch.Tensor], + end_tensors: list[torch.Tensor], + weight: torch.types.Number, +) -> list[torch.Tensor]: + return aten._foreach_add.List( + start_tensors, + aten._foreach_mul.Scalar( + aten._foreach_sub.List(end_tensors, start_tensors), weight + ), + ) + + +@register_decomposition(aten._foreach_lerp.ScalarList) +def _foreach_lerp_scalarlist( + start_tensors: list[torch.Tensor], + end_tensors: list[torch.Tensor], + scalars: list[torch.types.Number], +) -> list[torch.Tensor]: + return aten._foreach_add.List( + start_tensors, + aten._foreach_mul.ScalarList( + aten._foreach_sub.List(end_tensors, start_tensors), scalars + ), + ) + + 
+@aten.miopen_batch_norm.default.py_impl(torch._C.DispatchKey.Autograd) +@register_decomposition(aten.miopen_batch_norm) +def miopen_batch_norm( + input: torch.Tensor, + weight: torch.Tensor, + bias: typing.Optional[torch.Tensor], + running_mean: typing.Optional[torch.Tensor], + running_var: typing.Optional[torch.Tensor], + training: bool, + exponential_average_factor: float, + epsilon: float, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + a, b, c = aten.native_batch_norm( + input, + weight, + bias, + running_mean, + running_var, + training, + exponential_average_factor, + epsilon, + ) + + if training: + return (a, b, c) + return ( + a, + weight.new_zeros((0,)), + weight.new_zeros((0,)), + ) + + +@functools.cache +def fast_random_decomps() -> dict[Any, Callable[..., Any]]: + return {**decompositions, **extra_random_decomps} + + +# TODO(aakhundov): replace this (and the above) Any by more +# specific type and fix all the cascading mypy errors +def select_decomp_table() -> dict[Any, Callable[..., Any]]: + """decomps can change based on config""" + if config.fallback_random: + return decompositions + if config.fallback_embedding_bag_byte_unpack: + # remove q_embedding_bag_byte_unpack_decomp from decompositions + decompositions.pop(torch.ops.quantized.embedding_bag_byte_unpack.default, None) + return decompositions + return fast_random_decomps() + + +@register_decomposition(aten.masked_scatter) +def masked_scatter( + self: torch.Tensor, + mask: torch.Tensor, + source: torch.Tensor, +) -> torch.Tensor: + from .codegen.common import BackendFeature, has_backend_feature + + if has_backend_feature(self.device, BackendFeature.MASKED_SCATTER_WITH_INDEX): + # This two-step algorithm is the same as eager CUDA, for eager CPU we + # use a 1-shot serial iteration. 
+ self, mask = aten.broadcast_tensors([self, mask]) + source_idx = mask.reshape(-1).cumsum(0) - 1 + self_flat, mask_flat, source_flat = (x.flatten() for x in (self, mask, source)) + result = aten._unsafe_masked_index(source_flat, mask_flat, [source_idx], 0) + return torch.where(mask_flat, result, self_flat).view(self.shape) + return NotImplemented + + +@register_decomposition(quantized_decomposed.choose_qparams.tensor) +def choose_qparams_tensor( + input: torch.Tensor, + quant_min: int, + quant_max: int, + eps: float, + dtype: torch.dtype, +) -> tuple[torch.Tensor, torch.Tensor]: + min_val, max_val = torch.aminmax(input) + scale = (max_val - min_val) / float(quant_max - quant_min) + scale = torch.max(scale, torch.Tensor([eps])) + zero_point = quant_min - torch.round(min_val / scale).to(torch.int) + zero_point = torch.clamp(zero_point, quant_min, quant_max) + return scale.to(torch.float64), zero_point.to(torch.int64) + + +@register_decomposition(aten.put) +def put( + self: torch.Tensor, + index: torch.Tensor, + source: torch.Tensor, + accumulate: bool = False, +) -> torch.Tensor: + flattened = self.flatten() + flattened = torch.index_put( + flattened, [index], source.reshape(index.shape), accumulate + ) + return flattened.reshape(self.shape) + + +@register_decomposition(aten.put_) +def put_( + self: torch.Tensor, + index: torch.Tensor, + source: torch.Tensor, + accumulate: bool = False, +) -> torch.Tensor: + out = aten.put(self, index, source, accumulate=accumulate) + return self.copy_(out) + + +@register_decomposition(aten._softmax_backward_data.default) +@pw_cast_for_opmath +def _softmax_backward_data( + grad_output: torch.Tensor, + output: torch.Tensor, + dim: int, + input_dtype: torch.dtype, +) -> torch.Tensor: + new_grad_output = grad_output * output + sum_new_grad = torch.sum(new_grad_output, dim=dim, keepdim=True) + # grad_input = new_grad_output - output * sum_new_grad + grad_input = inductor_prims.fma(-output, sum_new_grad, new_grad_output) + + # CPU kernel 
doesn't respect input_dtype, but following check doesn't work for meta tensor + # if grad_output.device == torch.device("cpu"): + # return grad_input.contiguous() + + if grad_output.dtype != input_dtype: + grad_input = grad_input.to(input_dtype) + return grad_input.contiguous() + + +@register_decomposition(aten.index_reduce) +def index_reduce( + self: torch.Tensor, + dim: int, + index: torch.Tensor, + src: torch.Tensor, + reduction_type: str, + *, + include_self: bool = True, +) -> torch.Tensor: + if reduction_type == "mean" and not needs_fallback_due_to_atomic_add_limitations( + self.dtype + ): + true_division = self.dtype.is_floating_point or self.dtype.is_complex + ones = torch.ones_like(src) + if include_self: + out = self + counts = torch.ones_like(self).index_add(dim, index, ones) + else: + out = self.index_fill(dim, index, 0) + counts = torch.zeros_like(self).index_add(dim, index, ones) + counts = counts.masked_fill(counts < 1, 1) + out = out.index_add(dim, index, src) + return out / counts if true_division else out // counts + + if use_scatter_fallback( + aten.scatter_reduce_.two, + reduction_type, + self.dtype, + src.dtype, + src.device.type, + True, + ): + return NotImplemented + + repeats = self.shape[dim + 1 :].numel() * self.shape[:dim].numel() + index_shape = (index.numel(), *self.shape[dim + 1 :], *self.shape[:dim]) + perm = (*range(self.ndim - dim, self.ndim), 0, *range(1, self.ndim - dim)) + scatter_index = ( + index.to(torch.int64) + .repeat_interleave(repeats) + .reshape(index_shape) + .permute(perm) + ) + return self.scatter_reduce( + dim, + scatter_index, + src, + reduction_type, + include_self=include_self, + ) + + +def _max_pool_with_indices( + x: torch.Tensor, + kernel_size: list[int], + stride: Optional[Union[int, list[int]]], + padding: Union[int, list[int]], + dilation: Union[int, list[int]], + ceil_mode: bool, + dim: int, +) -> tuple[torch.Tensor, torch.Tensor]: + if dilation == 1: + dilation = [1] * dim + + if padding == 0: + padding = 
[0] * dim + + if not stride: + stride = kernel_size + + # pyrefly: ignore [bad-assignment] + kernel_size = pad_listlike(kernel_size, dim) + # pyrefly: ignore [bad-assignment] + dilation = pad_listlike(dilation, dim) + # pyrefly: ignore [bad-assignment] + padding = pad_listlike(padding, dim) + # pyrefly: ignore [bad-assignment] + stride = pad_listlike(stride, dim) + + window_size = functools.reduce(operator.mul, kernel_size) + # We fallback when using non-default dilation or when the window size is too large + if ( + torch._inductor.lowering.should_fallback_max_pool_with_indices( + kernel_size, n_dim=dim + ) + or window_size > torch.iinfo(torch.int8).max + ): + return NotImplemented + + vals, offsets = prims._low_memory_max_pool_with_offsets( + x, + kernel_size, + stride, + padding, + dilation, + ceil_mode, + ) + indices = prims._low_memory_max_pool_offsets_to_indices( + offsets, + kernel_size, + x.shape[-dim:], + stride, + padding, + dilation, + ) + return vals, indices + + +@register_decomposition(aten.max_pool2d_with_indices) +def max_pool2d_with_indices( + x: torch.Tensor, + kernel_size: list[int], + stride: Optional[Union[int, list[int]]] = None, + padding: Union[int, list[int]] = 0, + dilation: Union[int, list[int]] = 1, + ceil_mode: bool = False, +) -> tuple[torch.Tensor, torch.Tensor]: + return _max_pool_with_indices( + x, kernel_size, stride, padding, dilation, ceil_mode, dim=2 + ) + + +@register_decomposition(aten.max_pool3d_with_indices) +def max_pool3d_with_indices( + x: torch.Tensor, + kernel_size: list[int], + stride: Optional[Union[int, list[int]]] = None, + padding: Union[int, list[int]] = 0, + dilation: Union[int, list[int]] = 1, + ceil_mode: bool = False, +) -> tuple[torch.Tensor, torch.Tensor]: + return _max_pool_with_indices( + x, kernel_size, stride, padding, dilation, ceil_mode, dim=3 + ) + + +@register_decomposition(aten.adaptive_max_pool2d) +def adaptive_max_pool2d( + x: torch.Tensor, output_size: list[int] +) -> tuple[torch.Tensor, 
torch.Tensor]: + *batch, h_in, w_in = x.shape + h_out, w_out = output_size + + if h_out == 0 or w_out == 0: + o_size = [*batch, h_out, w_out] + return x.new_empty(o_size), x.new_empty(o_size, dtype=torch.int64) + + if h_in % h_out == 0 and w_in % w_out == 0: + kernel_size = [h_in // h_out, w_in // w_out] + return aten.max_pool2d_with_indices(x, kernel_size) + + return NotImplemented + + +@register_decomposition(aten.searchsorted.Scalar) +def searchsorted_scalar( + sorted_sequence: torch.Tensor, + self: torch.types.Number, + *, + out_int32: bool = False, + right: bool = False, + side: Optional[str] = None, + sorter: Optional[torch.Tensor] = None, +) -> torch.Tensor: + return aten.searchsorted( + sorted_sequence, + torch.tensor([self], device=sorted_sequence.device), + out_int32=out_int32, + right=right, + side=side, + sorter=sorter, + )[0] + + +@register_decomposition(aten.rrelu_with_noise_functional) +def rrelu_with_noise_functional( + self: torch.Tensor, + noise: torch.Tensor, + lower: float = 0.125, + upper: float = 0.3333333333333333, + training: bool = False, + generator: Optional[torch.Generator] = None, +) -> tuple[torch.Tensor, torch.Tensor]: + if training: + not_positive = self <= 0 + r = aten.uniform(self, lower, upper, generator=generator) + output = torch.where(not_positive, self * r, self) + noise_out = torch.where(not_positive, r, 1) + return output, noise_out + else: + negative_slope = (lower + upper) / 2 + return aten.leaky_relu(self, negative_slope), torch.Tensor() + + +@register_decomposition(aten.repeat_interleave.Tensor) +def repeat_interleave_Tensor( + repeat: torch.Tensor, + output_size: Optional[int] = None, +) -> torch.Tensor: + if config.triton.autotune_at_compile_time: + # We can't compile-time auto-tune this because + # it expects specific data in `repeat` + return NotImplemented + if output_size is None or type(output_size) is not int: + return NotImplemented + if repeat.device.type == "mps": + return NotImplemented + assert repeat.dtype 
in [torch.int32, torch.int64] + assert repeat.ndim == 1 + cumsum = repeat.cumsum(0) + pos = torch.arange(output_size, device=repeat.device) + indices = torch.searchsorted( + cumsum, pos, out_int32=(repeat.dtype == torch.int32), right=True + ) + return torch.clamp(indices, max=repeat.size(0) - 1) + + +# intentionally not regiestered +def conv1d_to_conv2d( + input: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + stride: tuple[int] = (1,), + padding: tuple[int] = (0,), + dilation: tuple[int] = (1,), + groups: int = 1, +) -> torch.Tensor: + # Shapes: + # input: (N, C_in, L_in) + # weight: (C_out, C_in // groups, K) + # bias: (C_out,) + assert input.dim() == 3 and weight.dim() == 3, ( + "Expect (N,C_in,L) and (C_out,C_in//groups,K)" + ) + + # pyrefly: ignore [bad-assignment] + stride = stride[0] + # pyrefly: ignore [bad-assignment] + padding = padding[0] + # pyrefly: ignore [bad-assignment] + dilation = dilation[0] + + # Unsqueeze to make input 2D: (N,C,L) -> (N,C,L,1) + input_2d = input.unsqueeze(-1) + # Unsqueeze kernel: (C_out,C_in/groups,K) -> (C_out,C_in/groups,K,1) + weight_2d = weight.unsqueeze(-1) + + # Call conv2d with adjusted args + out_2d = aten.conv2d.default( + input_2d, + weight_2d, + bias, + stride=(stride, 1), + padding=(padding, 0), + dilation=(dilation, 1), + groups=groups, + ) + + # Squeeze dummy dimension back out: (N,C_out,L_out,1) -> (N,C_out,L_out) + return out_2d.squeeze(-1) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/dependencies.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/dependencies.py new file mode 100644 index 0000000000000000000000000000000000000000..3495bc35d137c4a340cbaf40efac517886c6920e --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/dependencies.py @@ -0,0 +1,890 @@ +import abc +import dataclasses +import itertools +import logging +import re +from collections.abc import 
Callable, Iterable, Sequence +from typing import Any, Optional, TypeVar, Union +from typing_extensions import Self +from unittest.mock import patch + +import sympy + +import torch +from torch._inductor.utils import get_free_symbols +from torch.fx.experimental.symbolic_shapes import free_symbols, free_unbacked_symbols +from torch.utils._ordered_set import OrderedSet + +from ..utils._sympy.symbol import make_symbol, SymT +from .codegen.common import index_prevent_reordering +from .ops_handler import DefaultHandler +from .utils import ( + get_dtype_size, + reduction_num_outputs, + sympy_index_symbol, + sympy_subs, + VarRanges, +) +from .virtualized import ReductionType, V + + +T = TypeVar("T") + +log = logging.getLogger(__name__) +is_indirect = re.compile(r"indirect|tmp").search + + +class Dep(abc.ABC): + name: str + index: sympy.Expr + + @abc.abstractmethod + def get_free_symbol_uses( + self, unbacked_only: bool = False + ) -> OrderedSet[sympy.Symbol]: + pass + + @abc.abstractmethod + def rename(self, renames: dict[str, str]) -> Self: + pass + + @abc.abstractmethod + def get_numel(self) -> sympy.Expr: + pass + + @abc.abstractmethod + def numbytes_hint(self) -> int: + pass + + @abc.abstractmethod + def numel_hint(self) -> int: + pass + + @abc.abstractmethod + def has_unbacked_symbols(self) -> bool: + pass + + @abc.abstractmethod + def is_contiguous(self) -> bool: + pass + + def normalize_with_stride_order(self, prefix: str = "t") -> Self: + return self + + +@dataclasses.dataclass(frozen=True) +class MemoryDep(Dep): + # pyrefly: ignore [bad-override] + name: str + # pyrefly: ignore [bad-override] + index: sympy.Expr + var_names: tuple[sympy.Symbol, ...] + size: tuple[sympy.Expr, ...] 
+ mode: Optional[str] = None + + def get_free_symbol_uses( + self, unbacked_only: bool = False + ) -> OrderedSet[sympy.Symbol]: + return ( + get_free_symbols(self.index, unbacked_only) + | get_free_symbols(self.size, unbacked_only) + | get_free_symbols(self.var_names, unbacked_only) + ) + + def __repr__(self) -> str: + maybe_mode = "" + if self.mode is not None: + maybe_mode = f", {self.mode}" + return f"MemoryDep({self.name!r}, {self.index}, {self.ranges}{maybe_mode})" + + @property + def num_vars(self) -> int: + return len(self.var_names) + + def decide_loop_order_to_match(self, other: "MemoryDep") -> Optional[list[int]]: + """ + Can return None if not able to decide loop orders. + """ + assert self.num_vars == other.num_vars + + # ignore broadcast for now since broadcast causes extra 0 strides + # which makes it hard to decide the correct loop orders. + if self.num_vars != len(self.index.free_symbols): + return None + if other.num_vars != len(other.index.free_symbols): + return None + + # bail out if any size is 0 or 1 + # For size == 0, it's an empty tensor, any strides for that dimension + # are equivalent. Skip for simplicity and it may not matter that much. + # + # For size == 1, it cause cause tie for strides of different dimensions. + # Also when we first time create LoopBody in ComputedBuffer.simplify_and_reorder + # we can dependencies.index_vars_squeeze which should already sqeeuze + # the size == 1 dimensions. + if any(s == 0 or s == 1 for s in itertools.chain(self.size, other.size)): + return None + + # Extract strides for both expression + self_strides = V.graph.sizevars.stride_hints(self.index, self.var_names) + other_strides = V.graph.sizevars.stride_hints(other.index, other.var_names) + + # Even if the shape contains no 0/1, some complex index expression may + # still have duplicate stride values. 
Here is an example: + # https://gist.github.com/shunting314/511a7e1ec88aa2e1a8ec85d8445ab129 + # We don't reorder the loop for these cases for now, but in theory + # we could improve the algorithm to detect the correct loop orders. + if len(OrderedSet(self_strides)) != len(self_strides) or len( + OrderedSet(other_strides) + ) != len(other_strides): + log.debug( + "unable to decide loop order. self_dep=%s v.s. other_dep=%s, self_strides=%s v.s. other_strides=%s", + self, + other, + self_strides, + other_strides, + ) + return None + + # May happen if self and other are as follows + # MemoryDep('addmm_6', 393216*d0 + 768*d1 + d2, {d0: 16, d1: 512, d2: 768}, None) + # MemoryDep('addmm_6', 98304*d0 + d1 + 768*d2, {d0: 64, d1: 768, d2: 128}, None) + if OrderedSet(self_strides) != OrderedSet(other_strides): + return None + + stride_to_index = {s: i for i, s in enumerate(self_strides)} + order = [stride_to_index[s] for s in other_strides] + + assert OrderedSet(order) == OrderedSet(range(self.num_vars)) + return order + + def get_offset(self) -> sympy.Expr: + """ + Return the offset by setting every variable to be 0. + """ + return sympy_subs(self.index, dict.fromkeys(self.var_names, 0)) + + def normalize(self) -> "MemoryDep": + """ + Normalize by merging loops. The different to normalize_with_stride_order is, + this method does not reorder loops while normalize_with_stride_order reorder + loops based on stride order. + """ + return MemoryDep( + self.name, + *_RecordLoadStoreInner._normalize(self.index, self.ranges), # type: ignore[arg-type] + self.mode, + ) + + def normalize_with_stride_order(self, prefix: str = "t") -> "MemoryDep": + r""" + Used to decide if two MemoryDep does not equal due to different loop orders. + More specifically, when dep1 and dep2 are not equal, we can normalize + both and check if they are equal after that. If yes, then the mismatch is + caused by different loop orders. 
+ """ + # import here to avoid circular import + from torch._inductor import ir + + strides = V.graph.sizevars.stride_hints(self.index, self.var_names) + + # pick a loop order with stride ordered decreasingly + order = sorted(range(len(strides)), key=strides.__getitem__, reverse=True) + stride_reorder = ir.same_reorder(order) + sizes = self.size + var_names = self.var_names + + new_reordered_sizes = stride_reorder(sizes) + new_reordered_var_names = stride_reorder(var_names) + + new_simplified_sizes, reindex, _prune = V.graph.sizevars._simplify_loops( + new_reordered_var_names, + new_reordered_sizes, + index_prevent_reordering( + [self.index], new_reordered_var_names, new_reordered_sizes + ), + ) + + # now let's create new symbols with the passed in prefix + var_ranges, add_var = var_builder(prefix) + replacement = dict( + zip( + new_reordered_var_names, + reindex([add_var(x) for x in new_simplified_sizes]), + ) + ) + new_index = sympy_subs(sympy.expand(self.index), replacement) # type: ignore[arg-type] # next PR + + out = MemoryDep( + self.name, new_index, tuple(var_ranges.keys()), tuple(var_ranges.values()) + ) # type: ignore[arg-type] + return out + + @property + def ranges(self) -> dict[sympy.Symbol, sympy.Expr]: + """{c0: 128, c1: 512, ...}""" + return dict(zip(self.var_names, self.size)) + + def simplify_with_ranges(self) -> "MemoryDep": + return MemoryDep( + name=self.name, + index=V.graph.sizevars.simplify_with_ranges(self.index, self.ranges), + var_names=self.var_names, + size=self.size, + mode=self.mode, + ) + + def get_numel(self) -> sympy.Expr: + if self.is_indirect(): + numel = V.graph.get_numel(self.name) + else: + vars: OrderedSet[sympy.Basic] = OrderedSet(self.index.free_symbols) + numel = sympy.S.One + for var, size in zip(self.var_names, self.size): + if var in vars: + numel = numel * size + return numel # type: ignore[return-value] + + def rename(self, renames: dict[str, str]) -> "MemoryDep": + if self.name in renames: + return MemoryDep( + 
renames[self.name], + self.index, + var_names=self.var_names, + size=self.size, + mode=self.mode, + ) + return self + + def numbytes_hint(self) -> int: + try: + return V.graph.sizevars.size_hint(self.get_numel()) * get_dtype_size( + V.graph.get_dtype(self.name) + ) + except NotImplementedError: # NoneLayout + return 0 + + def numel_hint(self) -> int: + try: + return V.graph.sizevars.size_hint(self.get_numel(), fallback=0) + except NotImplementedError: # NoneLayout + return 0 + + def has_unbacked_symbols(self) -> bool: + return len(free_unbacked_symbols(self.get_numel())) > 0 + + def is_contiguous(self) -> bool: + if isinstance(self.index, sympy.Integer): + return True + return isinstance(self.index, sympy.Symbol) and self.index in self.var_names + + def stride1_for_last_dim(self, result_for_complex_expression: bool = True) -> bool: + """ + Whether the stride for the last dimension is 1. + """ + # python test/inductor/test_torchinductor_opinfo.py -k test_comprehensive_masked_scatter_cuda_float16 + # will exercise thru this corner case. + if len(self.var_names) == 0: + return True + + terms = self.index.args if isinstance(self.index, sympy.Add) else [self.index] + + last_sym = self.var_names[-1] + for term in terms: + if term == last_sym: + return True + + # Having a >1 stride for the last dimension is bad for perf + # return False. 
+ if ( + isinstance(term, sympy.Mul) + and len(term.args) == 2 + and term.args[1] == last_sym + and isinstance(term.args[0], (int, sympy.Integer)) + and term.args[0] > 1 + ): + return False + + return result_for_complex_expression + + def is_scalar(self) -> bool: + if isinstance(self.index, sympy.Symbol): + return self.index not in self.var_names and not self.is_indirect() + return isinstance(self.index, (int, sympy.Integer)) + + def is_indirect(self) -> bool: + return any(is_indirect(v.name) for v in self.index.free_symbols) # type: ignore[attr-defined] + + +@dataclasses.dataclass(frozen=True) +class StarDep(Dep): + # pyrefly: ignore [bad-override] + name: str + mode: Optional[str] = None + + # depends on the entire buffer + @property + # pyrefly: ignore [bad-override] + def index(self) -> sympy.Expr: + raise NotImplementedError("StarDep does not have an index") + + def get_numel(self) -> sympy.Expr: + return V.graph.get_numel(self.name) # type: ignore[return-value] + + def rename(self, renames: dict[str, str]) -> "StarDep": + if self.name in renames: + return StarDep(renames[self.name], self.mode) + return self + + def get_free_symbol_uses( + self, unbacked_only: bool = False + ) -> OrderedSet[sympy.Symbol]: + return OrderedSet() + + def numbytes_hint(self) -> int: + try: + return V.graph.sizevars.size_hint(self.get_numel()) * get_dtype_size( + V.graph.get_dtype(self.name) + ) + except NotImplementedError: + return 0 # NoneLayout, MultiOutputLayout, etc + + def numel_hint(self) -> int: + try: + return V.graph.sizevars.size_hint(self.get_numel(), fallback=0) + except NotImplementedError: + return 0 # NoneLayout, MultiOutputLayout, etc + + def has_unbacked_symbols(self) -> bool: + return len(free_unbacked_symbols(self.get_numel())) > 0 + + def is_contiguous(self) -> bool: + return False + + def is_scalar(self) -> bool: + return False + + def is_indirect(self) -> bool: + return False + + +# Used for tracking mutation ordering +# if A reads a buffer and B mutates it 
+# B must be ordered after A +# +# This is useful for a variety of reasons. +# For example, if A's read is never actually used, we can eliminate it. +# Another case is if A's buffer ends up being fused away, we never need to +# materialize that buffer +@dataclasses.dataclass(frozen=True) +class WeakDep(Dep): + # Fake dependency on unused buffer + # pyrefly: ignore [bad-override] + name: str + # Buffer that is doing the mutation + mutating_buf: str + # WeakDep's are also used to add dependencies to prevent some specific reordering, + # E.g. collectives global ordering. + # But if other pass guarantees proper ordering by its logic, + # This additional "fake" deps will be holding optimizations. + # This flag is used to identify those additional deps. + is_fake: bool = False + + def get_free_symbol_uses( + self, unbacked_only: bool = False + ) -> OrderedSet[sympy.Symbol]: + return OrderedSet() + + @property + # pyrefly: ignore [bad-override] + def index(self) -> sympy.Expr: + raise NotImplementedError("WeakDep does not have an index") + + def get_numel(self) -> sympy.Expr: + return sympy.S.One + + def rename(self, renames: dict[str, str]) -> "WeakDep": + if self.name in renames: + return WeakDep(renames[self.name], self.mutating_buf, self.is_fake) + return self + + def numbytes_hint(self) -> int: + return 1 # Purely inserted for ordering, not an actual dep + + def numel_hint(self) -> int: + return 1 # Purely inserted for ordering, not an actual dep + + def has_unbacked_symbols(self) -> bool: + return False + + def is_contiguous(self) -> bool: + return False + + +@dataclasses.dataclass(frozen=True) +class IndexExprDep: + index: sympy.Expr # type: ignore[assignment] + var_names: tuple[sympy.Symbol, ...] + size: tuple[sympy.Expr, ...] 
+ + +@dataclasses.dataclass +class ReadWrites: + reads: OrderedSet[Dep] + writes: OrderedSet[Dep] + index_exprs: OrderedSet[IndexExprDep] + range_vars: Optional[list[sympy.Expr]] = None + var_ranges: Optional[VarRanges] = None + + def rename(self, renames: dict[str, str]) -> "ReadWrites": + return ReadWrites( + OrderedSet(dep.rename(renames) for dep in self.reads), + OrderedSet(dep.rename(renames) for dep in self.writes), + self.index_exprs, + self.range_vars, + self.var_ranges, + ) + + def with_read(self, dep: Union[Dep, OrderedSet[Dep]]) -> "ReadWrites": + assert isinstance(dep, (WeakDep, StarDep, OrderedSet)) + if not isinstance(dep, OrderedSet): + dep = OrderedSet([dep]) + return ReadWrites( + OrderedSet.union(self.reads, dep), + self.writes, + self.index_exprs, + self.range_vars, + self.var_ranges, + ) + + def merge(self, other: "ReadWrites") -> "ReadWrites": + reads = OrderedSet.union(self.reads, other.reads) + writes = OrderedSet.union(self.writes, other.writes) + index_exprs = OrderedSet.union(self.index_exprs, other.index_exprs) + return ReadWrites(reads - writes, writes, index_exprs) + + @staticmethod + def merge_list(read_writes: list["ReadWrites"]) -> "ReadWrites": + all_writes = OrderedSet.union(*[rw.writes for rw in read_writes]) + all_reads = OrderedSet.union(*[rw.reads for rw in read_writes]) - all_writes + all_index_exprs = OrderedSet.union(*[rw.index_exprs for rw in read_writes]) + return ReadWrites(all_reads, all_writes, all_index_exprs) + + def remove_reads(self, rem_reads: OrderedSet[Dep]) -> "ReadWrites": + return ReadWrites( + self.reads - rem_reads, + self.writes, + self.index_exprs, + self.range_vars, + self.var_ranges, + ) + + def reads_and_writes(self) -> Iterable[Dep]: + return itertools.chain(self.reads, self.writes) + + def buffer_names(self, ignore_integer_index: bool = True) -> OrderedSet[str]: + """ + Integer index is used for load_seed. 
+ """ + names: OrderedSet[str] = OrderedSet() + for dep in self.reads_and_writes(): + if not isinstance(dep, MemoryDep): + continue + if not ignore_integer_index or not isinstance( + dep.index, (int, sympy.Integer) + ): + names.add(dep.name) + return names + + def get_free_symbol_uses( + self, unbacked_only: bool = False + ) -> OrderedSet[sympy.Symbol]: + result: OrderedSet[sympy.Symbol] = OrderedSet() + + for dep in self.reads_and_writes(): + result |= dep.get_free_symbol_uses(unbacked_only) + return result + + +class _RecordLoadStoreInner(V.MockHandler): # type: ignore[name-defined] + def __init__(self, var_ranges: VarRanges, normalize: bool) -> None: + super().__init__() + self._reads: OrderedSet[Dep] = OrderedSet() + self._writes: OrderedSet[MemoryDep] = OrderedSet() + self._index_exprs: OrderedSet[IndexExprDep] = OrderedSet() + self._var_ranges: VarRanges = var_ranges + self._should_normalize: bool = normalize + + @staticmethod + def drop_unused_symbols( + index: Union[int, sympy.Expr], + var_names: list[sympy.Expr], + sizes: list[sympy.Expr], + ) -> None: + """ + Reduction has last (reduced) dim in its sizes, but + downstream users won't. Normalize this away. + """ + if not isinstance(index, sympy.Expr): + # index can be an int + return + free_symbols = index.free_symbols + while var_names and var_names[-1] not in free_symbols: + var_names.pop() + sizes.pop() + + @classmethod + def _normalize( + cls, index: sympy.Expr, var_ranges: VarRanges + ) -> tuple[sympy.Expr, tuple[sympy.Symbol, ...], tuple[sympy.Expr, ...]]: + # Try to further simplify the indexes even if simplify_loops didn't + # convert it to the simplest form because of the interference from + # different indexing formulas. 
+ index_vars = [*var_ranges.keys()] + sizes = tuple(var_ranges.values()) # type: ignore[assignment] + new_sizes, reindex, _prune = V.graph.sizevars._simplify_loops( + index_vars, + sizes, + index_prevent_reordering([index], index_vars, sizes), + ) + + # assign new variables each dimension to deal with numbering mismatches + # d0, d1, d2 could become d0, d2 -- which won't match d0, d1 + new_vars, add_var = var_builder(canonicalization_prefix()) + replacement = dict(zip(index_vars, reindex([add_var(x) for x in new_sizes]))) + index = sympy_subs(sympy.expand(index), replacement) + + new_vars = [*new_vars.keys()] + new_sizes = [*new_sizes] + cls.drop_unused_symbols(index, new_vars, new_sizes) + return index, tuple(new_vars), tuple(new_sizes) # type: ignore[arg-type] + + def canonicalize( + self, index: sympy.Expr + ) -> tuple[sympy.Expr, tuple[sympy.Symbol, ...], tuple[sympy.Expr, ...]]: + if not self._should_normalize: + sizes = [V.graph.sizevars.simplify(x) for x in self._var_ranges.values()] + var_names = [k for k, v in zip(self._var_ranges.keys(), sizes) if v != 1] + sizes = [v for v in sizes if v != 1] + + self.drop_unused_symbols(index, var_names, sizes) + + return index, tuple(var_names), tuple(sizes) # type: ignore[return-value, arg-type] + var_ranges = { + k: V.graph.sizevars.simplify(v) + for k, v in self._var_ranges.items() + # TODO(jansel): explore this further normalization + # if k in free_symbols + } + return self._normalize(index, var_ranges) + + def load(self, name: str, index: sympy.Expr) -> None: + self._reads.add(MemoryDep(name, *self.canonicalize(index))) + + def load_seed(self, name: str, index: int) -> None: + assert isinstance(index, int) + self.load(name, sympy.Integer(index)) + + def store( + self, name: str, index: sympy.Expr, value: str, mode: Optional[str] = None + ) -> None: + self._writes.add(MemoryDep(name, *self.canonicalize(index), mode=mode)) + + def store_reduction(self, name: str, index: sympy.Expr, value: str) -> None: + 
self.store(name, index, f"store_reduction({value})") + + def index_expr(self, index: sympy.Expr, dtype: Optional[torch.dtype]) -> None: + self._index_exprs.add(IndexExprDep(*self.canonicalize(index))) + + def bucketize( + self, + values: T, + boundaries: tuple[str, sympy.Expr, sympy.Expr, sympy.Expr], + boundary_indices: T, + indexing_dtype: torch.dtype, + right: bool, + sorter: Optional[tuple[str, sympy.Expr]] = None, + sorter_indices: Optional[T] = None, + ) -> None: + """Records the names of the buffers that bucketize will read from.""" + self._reads.add(StarDep(boundaries[0])) + if sorter is not None: + self._reads.add(StarDep(sorter[0])) + + +class RecordLoadStore(V.KernelFormatterHandler): # type: ignore[name-defined] + def __init__(self, var_ranges: VarRanges, normalize: bool) -> None: + parent_handler = _RecordLoadStoreInner( + var_ranges=var_ranges, normalize=normalize + ) + super().__init__(parent_handler=parent_handler) + + +# TODO: check call sites +def var_builder(prefix: str) -> tuple[VarRanges, Callable[[sympy.Expr], sympy.Symbol]]: + cnt = itertools.count() + var_ranges: VarRanges = {} + + def add_var(length: sympy.Expr) -> sympy.Symbol: + v = sympy_index_symbol(f"{prefix}{next(cnt)}") + var_ranges[v] = length + return v + + return var_ranges, add_var + + +def index_vars_no_squeeze( + *argsizes: Sequence[sympy.Expr], prefix: str +) -> tuple[list[list[sympy.Symbol]], VarRanges]: + var_ranges, add_var = var_builder(prefix) + args: list[list[sympy.Symbol]] = [list(map(add_var, size)) for size in argsizes] + return args, var_ranges + + +def index_vars_squeeze( + *argsizes: Sequence[sympy.Expr], prefix: str = "d" +) -> tuple[list[Sequence[sympy.Expr]], VarRanges]: + from .ir import SqueezeView + + var_ranges, add_var = var_builder(prefix) + args: list[Sequence[sympy.Expr]] = [] + new_sizes: list[Sequence[sympy.Expr]] = [] + for size in argsizes: + new_size, reindex = SqueezeView.squeezer(size) + new_sizes.append(new_size) + 
args.append(reindex(list(map(add_var, new_size)))) + return args, var_ranges + + +def extract_read_writes( + fn: Callable[..., Any], + *argsizes: Sequence[sympy.Expr], + normalize: bool = False, + prefix: str = "d", + hidden_args: Sequence[list[sympy.Expr]] = (), +) -> ReadWrites: + args, var_ranges = index_vars_squeeze(*argsizes, prefix=prefix) + + from .loop_body import LoopBody + + if isinstance(fn, LoopBody): + inner = extract_loop_body_with_args( + fn, + [*args, *hidden_args], # type: ignore[list-item] + var_ranges, + normalize, + ) + else: + # Slow path tracing the function + rw = RecordLoadStore(var_ranges, normalize=normalize) + with V.set_ops_handler(rw): + fn(*args, *hidden_args) + inner = rw.parent_handler + + if normalize: + range_vars = [] # Number of vars could differ due to normalization + else: + range_vars = [*itertools.chain.from_iterable(args)] + + return ReadWrites( + # pyrefly: ignore [missing-attribute] + OrderedSet(inner._reads), + # pyrefly: ignore [missing-attribute] + OrderedSet(inner._writes), + # pyrefly: ignore [missing-attribute] + inner._index_exprs, + range_vars, + var_ranges, + ) + + +def extract_loop_body_with_args( + fn: Any, + args: list[list[sympy.Expr]], + var_ranges: VarRanges, + normalize: bool = False, +) -> _RecordLoadStoreInner: + from .loop_body import MemoryUsageType + + # Fast path to avoid tracing when we already have a LoopBody + inner = _RecordLoadStoreInner(var_ranges=var_ranges, normalize=normalize) + name_to_index = fn.indexing_from_args(args) + if fn.indirect_vars: + # mimic the `tmpX` naming tracing gives us + repl = {v: make_symbol(SymT.TMP, i) for i, v in enumerate(fn.indirect_vars)} + name_to_index = {k: sympy_subs(v, repl) for k, v in name_to_index.items()} # type: ignore[arg-type] + for entry in fn.memory_usage[MemoryUsageType.LOAD]: + inner.load(entry.buffer_name, name_to_index[entry.index_name]) # type: ignore[arg-type] + for entry in fn.memory_usage[MemoryUsageType.LOAD_SEED]: + 
inner.load_seed(entry.buffer_name, int(name_to_index[entry.index_name])) # type: ignore[arg-type] + for entry in fn.memory_usage[MemoryUsageType.STORE]: + inner.store( + entry.buffer_name, + name_to_index[entry.index_name], + None, # type: ignore[arg-type] + entry.mode, + ) + for entry in fn.memory_usage[MemoryUsageType.STORE_REDUCTION]: + inner.store_reduction( + entry.buffer_name, + name_to_index[entry.index_name], + None, # type: ignore[arg-type] + ) + for entry in fn.memory_usage[MemoryUsageType.INDEX_EXPR]: + inner.index_expr(name_to_index[entry.index_name], None) + for entry in fn.memory_usage[MemoryUsageType.BUCKETIZE]: + # All that matters is that we record the buffer name, so place it in the + # "boundaries" name position to ensure that it's recorded. + inner.bucketize( + None, + (entry.buffer_name, None, None, None), + None, + None, # type: ignore[arg-type] + None, # type: ignore[arg-type] + ) + # fn.memory_usage[MemoryUsageType.CHECK_BOUNDS] intentionally skipped + return inner + + +def extract_input_node_reduction_ranges( + input_node: "torch._inductor.ir.IRNode", +) -> tuple[Optional[list[sympy.Expr]], Optional[list[sympy.Expr]]]: + """ + Returns the size and reduction size of all inputs, if the sizes and reduction_sizes (if exist) are all the same. + It's possible that a node has multiple inputs, some are Reduction nodes and others are Pointwise nodes. + In this case, reduction_sizes of the Reduction nodes need to be the same. + Otherwise returns (None, None). + """ + + from .ir import ComputedBuffer, ExternKernel, Loops + + size: Optional[list[sympy.Expr]] + reduction_size: Optional[list[sympy.Expr]] + + if isinstance(input_node.get_defining_op(), ComputedBuffer): + # Input node has already been realized. Return its size and reduction_size. 
+ size = [*input_node.get_size()] + reduction_size = [*input_node.get_reduction_size()] + if len(reduction_size) > 0: + return (size, reduction_size) + else: + return (None, None) + + if not isinstance(input_node.data.data, Loops): # type: ignore[attr-defined] + # Other IRNodes do not have reduction_ranges. + return (None, None) + + # There is one issue: what if there are views / permutations between the input node and its dependent realized nodes? + # The current method still uses reduction ranges from the dependent realized node, which is not ideal. + # Is there a way to check whether there are permutations in between? + reads = input_node.get_reads() + reduction_size: Optional[list[sympy.Expr]] = None + size: Optional[list[sympy.Expr]] = None + while reduction_size is None and len(reads) > 0: + seen: OrderedSet[str] = OrderedSet() + new_reads: list[Dep] = [] + for read in reads: + if not isinstance(read, MemoryDep): + continue + if read.name in seen: + continue + seen.add(read.name) + buffer = V.graph.try_get_buffer(read.name) + if buffer is None: + continue + op = buffer.get_defining_op() + if op is None or isinstance(op, ExternKernel): + continue + + if isinstance(op, ComputedBuffer) and len(op.get_reduction_size()) > 0: + if reduction_size is None: + reduction_size = [*op.get_reduction_size()] + size = [*op.get_size()] + elif reduction_size != [*op.get_reduction_size()] or size != [ + *op.get_size() + ]: + return (None, None) + else: + new_reads.extend(op.get_reads()) + if reads == new_reads: + return (size, reduction_size) + else: + reads = OrderedSet(new_reads) + return (size, reduction_size) + + +def canonicalization_prefix() -> str: + return "c" + + +# ops handler which computes all the free symbols for an IR +class FreeSymbolsOpsHandler(DefaultHandler): + symbols: OrderedSet[sympy.Symbol] + + def __init__(self, unbacked_only: bool = True) -> None: + self.symbols = OrderedSet() + self.get_symbols = free_unbacked_symbols if unbacked_only else free_symbols 
+ + def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: + for a in itertools.chain(args, kwargs.values()): + if isinstance(a, (sympy.Expr, sympy.logic.boolalg.Boolean)): + self.symbols |= self.get_symbols(a) + + def indirect_indexing( + self, + index_var: Any, + size: Union[int, sympy.Expr], + check: bool = True, + wrap_neg: bool = True, + ) -> sympy.Symbol: + assert not isinstance(index_var, (sympy.Expr, sympy.logic.boolalg.Boolean)) + self.symbols |= self.get_symbols(size) + return sympy_index_symbol(f"({str(index_var)})") + + def frexp(self, x: Any) -> tuple[None, ...]: + return (None,) * 2 + + def scan( + self, dtypes: Any, combine_fn: Any, values: Sequence[Any] + ) -> tuple[None, ...]: + return (None,) * len(values) + + def sort( + self, dtypes: Any, values: Sequence[Any], stable: Any, descending: Any + ) -> tuple[None, ...]: + return (None,) * len(values) + + def reduction( + self, + dtype: torch.dtype, + src_dtype: torch.dtype, + reduction_type: ReductionType, + value: Union[None, tuple[None, ...]], + ) -> Union[None, tuple[None, ...]]: + num_values = reduction_num_outputs(reduction_type) + return (None,) * num_values if num_values > 1 else None + + def masked(self, mask: Any, body: Callable[..., Any], other: Any) -> None: + assert callable(body), "masked body must always be callable." + # The body can make additional calls, for e.g. 
ops.indirect_indexing + body() + + +def extract_free_symbols( + fn: Callable[..., Any], + index: Sequence[sympy.Expr], + rindex: Optional[Sequence[sympy.Expr]] = None, + unbacked_only: bool = True, +) -> OrderedSet[sympy.Symbol]: + from .ir import FlexibleLayout + + args = [index, rindex] if rindex is not None else [index] + handler = FreeSymbolsOpsHandler(unbacked_only) + # NB: I cargo culted the allow_indexing patch here, I don't understand why + # people do this all over + with ( + V.set_ops_handler(handler), + patch.object(FlexibleLayout, "allow_indexing", True), + ): + fn(*args) + return handler.symbols diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/distributed_autotune.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/distributed_autotune.py new file mode 100644 index 0000000000000000000000000000000000000000..ec53d25efcd5b5a2ac19adbdc8ac3a3e352caa0f --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/distributed_autotune.py @@ -0,0 +1,386 @@ +from __future__ import annotations + +import contextlib +import dataclasses +from typing import Any, TYPE_CHECKING, Union +from unittest.mock import patch + +import sympy + +import torch._logging +import torch.distributed as dist +import torch.fx +from torch.utils._ordered_set import OrderedSet + +from . 
import config, select_algorithm +from .ir import ( + Buffer, + ChoiceCaller, + Layout, + MultiTemplateBuffer, + OperationBuffer, + ShapeAsConstantBuffer, + StorageBox, + TensorBox, +) +from .kernel_inputs import KernelInputs, MMKernelInputs +from .scheduler import SchedulerNode +from .virtualized import NullHandler, V + + +if TYPE_CHECKING: + from collections.abc import Generator, Sequence + + +_DISTRIBUTED_AUTOTUNE_KEY = "distributed_autotune" + +_AUTOTUNE_PG: dist.ProcessGroup | None = None + + +@dataclasses.dataclass +class _DistributedAutotuneState: + """ + State used to track autotuning during a graph_context() + """ + + # This is the next operator index. Used to figure out which rank should do + # the autotuning. + autotuned_index: int = 0 + + # For debugging - used to make sure that we autotune the same number of + # local operators that we expected to. + autotuned_local_count: int = 0 + + +@dataclasses.dataclass +class _DistributedAutotuneInfo: + index: int + local: bool + + +def get_autotune_pg() -> dist.ProcessGroup | None: + if dist.is_available() and dist.is_initialized(): + global _AUTOTUNE_PG + if _AUTOTUNE_PG is None: + _AUTOTUNE_PG = dist.distributed_c10d._new_group_with_tag( + pg_tag="pt2_distributed_autotune_pg" + ) + return _AUTOTUNE_PG + + return None + + +def schedule(scheduler: torch._inductor.scheduler.Scheduler) -> None: + """ + Finish the distributed autotuning by propagating the autotuning results + between the ranks and then replacing the placeholder with the real Buffer. + """ + assert config.distributed_max_autotune_gemm + autotune_results = _autotune_local_nodes(scheduler) + choices_by_index = _sync(autotune_results) + _autotune_remote_nodes(scheduler, choices_by_index) + + +@contextlib.contextmanager +def graph_context() -> Generator[None, None, None]: + """ + Wrapped around processing a graph, sets up figuring out which ranks tune + which shapes. 
+ """ + assert not isinstance( + V.get_distributed_autotune_state(check_poisoned=False), # type: ignore[call-arg] + _DistributedAutotuneState, + ) + V.set_distributed_autotune_state(_DistributedAutotuneState()) + try: + yield + finally: + V.set_distributed_autotune_state(NullHandler()) + + +def maybe_autotune_remote( + name: str, choices: list[ChoiceCaller], inputs: list[Buffer], layout: Layout +) -> TensorBox | ShapeAsConstantBuffer | None: + """ + Used by an op (like `mm`) to determine if the op should be autotuned + locally (returns None) or remotely (returns a placeholder Buffer). + """ + if not config.distributed_max_autotune_gemm: + return None + + if not (autotune_pg := get_autotune_pg()): + return None + + if len(choices) <= 1: + return None + + state = V.distributed_autotune_state + index = state.autotuned_index + state.autotuned_index += 1 + local = index % autotune_pg.size() == autotune_pg.rank() + + V.current_node.meta[_DISTRIBUTED_AUTOTUNE_KEY] = _DistributedAutotuneInfo( + index, local + ) + if local: + state.autotuned_local_count += 1 + return None + + return torch._inductor.ir.TensorBox.create( + _DistributedAutotuneBuffer(name, inputs, layout) + ) + + +class _DistributedAutotuneBuffer(MultiTemplateBuffer): + """ + A MultiTemplateBuffer which represents a kernel being autotuned on a + different rank. When `schedule` is called this will be replaced by the + "real" buffer. + """ + + # Name of the kernel being autotuned. + _kernel_name: str + + def __init__( + self, + kernel_name: str, + inputs: list[Buffer], + layout: Layout, + ) -> None: + super().__init__( + layout, + inputs, + choice_timings_fn=self._dummy_choice_timings, + unfiltered_choices=[], + allowed_prologue_inps=OrderedSet({}), + ) + + self._kernel_name = kernel_name + + def _dummy_choice_timings( + self, _hint_override: int | None + ) -> dict[ChoiceCaller, float]: + # This should never get called. It means that a remote autotune was + # scheduled but never filled in. 
+ raise NotImplementedError + + def autotune(self, ser_choice: _SerializedChoice) -> TensorBox: + """ + Given a _SerializedChoice (autotune results from another rank) + compute the final TensorBox. + """ + + from .select_algorithm import autotune_select_algorithm + + with patch.object(V.graph, "scheduler", None): + kernel_inputs = MMKernelInputs([*self.original_inputs]) + assert isinstance(self.layout, Layout) + choice = ser_choice.get_choice(self.layout, kernel_inputs) + buffer = autotune_select_algorithm( + self._kernel_name, + [choice], + kernel_inputs.nodes(), + self.layout, + ) + assert isinstance(buffer, TensorBox) + return buffer + + +# Can we make this async? +def _sync(autotune_results: list[_SerializedChoice]) -> Sequence[_SerializedChoice]: + """ + Perform the all_gather to collect the autotune results from all the ranks. + """ + + autotune_pg = get_autotune_pg() + assert autotune_pg + + # Perform allgather + all_states: list[list[_SerializedChoice]] = [None] * autotune_pg.size() # type: ignore[list-item] + torch.distributed.all_gather_object(all_states, autotune_results, group=autotune_pg) + + node_count = sum(len(x) for x in all_states) + # It's faster to briefly lie about the type than to unzip the results and append. + choices_by_index: list[_SerializedChoice] = [None] * node_count # type: ignore[list-item] + + check_count = 0 + for other_results in all_states: + for choice in other_results: + assert isinstance(choice, _SerializedChoice) + assert choices_by_index[choice.index] is None + choices_by_index[choice.index] = choice + check_count += 1 + + assert node_count == check_count, f"count mismatch: {node_count} != {check_count}" + return choices_by_index + + +class _SerializedChoice: + """ + This is a serializer for the autotune choice. KernelTemplateChoice can't + be serialized directly (the template and inputs prevent this) so we need to + serialize it by parts and reconstruct later on. 
+ """ + + def __init__(self, index: int, choice: ChoiceCaller) -> None: + self.index = index + self.template_uid = _SerializedChoice._template_uid_from_choice(choice) + self.kwargs = self._compute_kwargs(choice.description) + + def get_choice(self, layout: Layout, inputs: KernelInputs) -> ChoiceCaller | None: + """ + Deserialize the ChoiceCaller and return it. + """ + + template = self._template_from_uid() + + kwargs = {**self.kwargs} + if "BLOCK_K" in kwargs: + # TODO: Do we really need to externally compute this value? If it's + # needed I'm surprised it's not just part of the original template + # description. + # This needs the actual 'k' to figure out the value. + k = inputs.nodes()[0].get_size()[1] + kwargs["EVEN_K"] = sympy.gcd(k, kwargs["BLOCK_K"]) == kwargs["BLOCK_K"] + + extra_kwargs: dict[str, Any] = {} + from .kernel_template_choice import ( + DictKernelTemplateParams, + KernelTemplateChoice, + ) + + params = DictKernelTemplateParams(kwargs) + ktc = KernelTemplateChoice(template, params, extra_kwargs, layout, inputs) + return ktc.choice + + @staticmethod + def _compute_kwargs(description: str) -> dict[str, Union[int, str, bool]]: + """ + Given a template description turn it into input kwargs. + """ + if not description: + return {} + + # TODO: It seems like it would be better if the template could provide + # this directly instead of having to parse a string. + kwargs: dict[str, Union[int, str, bool]] = {} + for cfg in description.split(","): + key, val = cfg.split("=", 1) + key, val = key.strip(), val.strip() + if val == "True": + kwargs[key] = True + elif val == "False": + kwargs[key] = False + elif val.isdigit(): + kwargs[key] = int(val) + else: + assert val.startswith("'") and val.endswith("'") + kwargs[key] = val[1:-1] + return kwargs + + @staticmethod + def _template_uid_from_choice(choice: ChoiceCaller) -> str: + """ + Given a ChoiceCaller figure out which template represents it. This + is reversed by _template_from_uid(). 
+ """ + + # We need a better way to do this - right now we need to add each + # supported template directly. + if isinstance(choice, select_algorithm.ExternKernelCaller): + if choice.choice.name == "mm": + return "torch._inductor.kernel.mm.aten_mm" + else: + raise RuntimeError(f"TODO: kernel {choice.choice.name!r}") + elif isinstance(choice, select_algorithm.TritonTemplateCaller): + return "torch._inductor.kernel.mm.mm_template" + else: + raise RuntimeError(f"TODO: {type(choice)}") + + def _template_from_uid(self) -> Any: + """ + See _template_uid_from_choice(). + """ + parts = self.template_uid.split(".") + obj = globals()[parts[0]] + for k in parts[1:]: + obj = getattr(obj, k) + return obj + + +def _autotune_local_nodes( + scheduler: torch._inductor.scheduler.Scheduler, +) -> list[_SerializedChoice]: + """ + Go through the nodes in the scheduler and autotune the kernels which + should be autotuned by this rank. + """ + + autotune_results: list[_SerializedChoice] = [] + + for node in scheduler.nodes: + if not isinstance(node, SchedulerNode): + continue + + if (inner_node := node.node) is None: + continue + + if isinstance(inner_node, _DistributedAutotuneBuffer): + # This is marked for remote autotuning. 
+ continue + + if not isinstance(inner_node, MultiTemplateBuffer): + continue + + if (origin_node := inner_node.origin_node) is None: + continue + + if (meta := origin_node.meta) is None: + continue + + info = meta.get(_DISTRIBUTED_AUTOTUNE_KEY) + if info is None: + continue + + assert info.local + + # We force autotuning here + # Still takes advantage of async precompile + # We need all the configs before fusion + min_choice, _ = inner_node.get_min_choice() + + choice = _SerializedChoice(info.index, min_choice) + autotune_results.append(choice) + + state = V.distributed_autotune_state + assert len(autotune_results) == state.autotuned_local_count, ( + f"incorrect local autotuned nodes found ({len(autotune_results)} != {state.autotuned_local_count})" + ) + return autotune_results + + +def _autotune_remote_nodes( + scheduler: torch._inductor.scheduler.Scheduler, + choices_by_index: Sequence[_SerializedChoice], +) -> None: + """ + Go through the nodes in the scheduler and autotune the nodes that were + autotuned on remote ranks. 
+ """ + + for i, node in enumerate(scheduler.nodes): + if isinstance(node, SchedulerNode) and isinstance( + (dist_node := node.node), _DistributedAutotuneBuffer + ): + assert dist_node.origin_node is not None + info = dist_node.origin_node.meta[_DISTRIBUTED_AUTOTUNE_KEY] + out_tensorbox = dist_node.autotune(choices_by_index[info.index]) + + out_storage = out_tensorbox.data + assert isinstance(out_storage, StorageBox) + out_buffer = out_storage.data + assert isinstance(out_buffer, OperationBuffer) + + assert out_buffer.layout == dist_node.layout + + scheduler._replace_node(out_buffer, dist_node, i, node) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/dtype_propagation.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/dtype_propagation.py new file mode 100644 index 0000000000000000000000000000000000000000..7e8583c9804e633485f06363cc363e53d39c259f --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/dtype_propagation.py @@ -0,0 +1,400 @@ +# mypy: allow-untyped-defs +import functools +from collections.abc import Callable, Sequence +from typing import Any, Optional, Protocol, TYPE_CHECKING, TypeVar, Union + +import sympy + +import torch +from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND, type_to_dtype +from torch.utils._ordered_set import OrderedSet + +from .ops_handler import OP_NAMES, OpsHandler +from .utils import upcast_compute_type +from .virtualized import OpsValue, V + + +T = TypeVar("T") + + +class DTypeVar(Protocol): + @property + def dtype(self) -> torch.dtype: ... 
+ + +DTypeArg = Union[DTypeVar, torch.types.Number, str, OpsValue] + + +# Inputs need to be cacheable (e.g., not a CSEVar) in order for the cache to be effective +# So first decompose CSEVars -> tuple before calling this + + +@functools.cache +def get_promoted_dtype( + *args: Sequence[tuple[torch.dtype, bool]], + type_promotion_kind: Optional[ELEMENTWISE_TYPE_PROMOTION_KIND] = None, +): + def construct_input(inp): + if inp[1]: + return torch.empty([], dtype=inp[0]) + else: + return torch.empty([1], dtype=inp[0]) + + inps = [construct_input(arg) for arg in args] + _, dtype = torch._prims_common.elementwise_dtypes( + *inps, + type_promotion_kind=( + type_promotion_kind + if type_promotion_kind + else ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ), + ) + return dtype + + +def promote_types( + args: Sequence[DTypeArg], + type_promotion_kind: Optional[ELEMENTWISE_TYPE_PROMOTION_KIND] = None, +): + dtype_prop_candidates = [] + + # pyrefly: ignore [bad-assignment] + for arg in args: + assert not isinstance(arg, str) + if isinstance(arg, OpsValue): + arg = arg.value + assert isinstance(arg, torch._prims_common.Number) or hasattr(arg, "dtype") + + if isinstance(arg, torch._prims_common.Number): + dtype_prop_candidates.append((type_to_dtype(type(arg)), True)) + continue + + # pyrefly: ignore [missing-attribute] + dtype_prop_candidates.append((arg.dtype, getattr(arg, "is_scalar", False))) + + dtype = get_promoted_dtype( + *dtype_prop_candidates, + type_promotion_kind=type_promotion_kind, + ) + + return dtype + + +class DtypePropagationOpsHandler: + """ + Propagate dtype from args to output + """ + + # Singleton DtypePropagationOpsHandler, because we meta program over a number of op rules. + # Those are only defined after other inductor state has run. 
+ + _instance: Optional["DtypePropagationOpsHandler"] = None + + def __new__(cls): + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + def __init__(self) -> None: + for op, rule in torch._inductor.utils.op_dtype_propagation_rules.items(): + fn = ( + functools.partial(self.return_dtype, dtype=rule.override_return_dtype) + if rule.override_return_dtype + else functools.partial( + self.op_dtype_rule, type_promotion_kind=rule.type_promotion_kind + ) + ) + setattr(self, op, fn) + + # Set pointwise operation rules + for op in torch._inductor.codegen.common.pointwise_overrides_data.values(): + if not hasattr(self, op.name): + setattr( + self, + op.name, + functools.partial( + self.op_dtype_rule, type_promotion_kind=op.type_promotion_kind + ), + ) + + # Set boolean operation rules + for op in torch._inductor.utils.boolean_ops(): + if not hasattr(self, op): + setattr( + self, op, functools.partial(self.return_dtype, dtype=torch.bool) + ) + + unimplemented_ops = OP_NAMES - OrderedSet(dir(self)) + torch._check( + len(unimplemented_ops) == 0, + lambda: f"Unimplemented dtype rule for ops: {unimplemented_ops}", + ) + + # metaprogrammed in __init__ + + @staticmethod + def op_dtype_rule( + *args: DTypeArg, type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND + ) -> torch.dtype: + return promote_types(args, type_promotion_kind=type_promotion_kind) + + @staticmethod + def return_dtype(*args: DTypeArg, dtype: torch.dtype) -> torch.dtype: + return dtype + + # op rules + + @staticmethod + def constant(value: torch.types.Number, dtype: torch.dtype) -> torch.dtype: + return upcast_compute_type(dtype) + + @staticmethod + def load_seed(name: str, offset: int) -> torch.dtype: + return upcast_compute_type(V.graph.get_dtype(name)) + + @staticmethod + def randint64(seed: int, offset: int, low: int, high: int) -> torch.dtype: + return torch.int64 + + @staticmethod + def masked( + mask: DTypeArg, body: Callable[[], DTypeArg], other: DTypeArg + ) -> 
torch.dtype: + from .loop_body import LoopBodyBlock + + assert isinstance(body, LoopBodyBlock), "body must be a LoopBodyBlock" + # TODO - we avoid calling this in codegen, needs work for non codegen use cases + loads = body.graph.find_nodes(op="call_method", target="load") + if len(loads) <= 1: + return promote_types([other]) + + return upcast_compute_type(V.graph.get_dtype(loads[-1].args[1])) + + @staticmethod + def where(a: DTypeArg, b: DTypeArg, c: DTypeArg) -> torch.dtype: + return promote_types([b, c]) + + @staticmethod + def index_expr(expr: sympy.Expr, dtype: torch.dtype) -> torch.dtype: + # TODO - TODO - rationalize index_expr. The dtype is not always used and we are inconsistent about int32 or int64 + # in lowerings. cpp just uses the dtype + if dtype not in (torch.int32, torch.int64) or not hasattr( + V.kernel, "index_dtype" + ): + return upcast_compute_type(dtype) + + return V.kernel.get_index_dtype_as_torch_dtype() + + @staticmethod + def to_dtype( + x: DTypeArg, + dtype: torch.dtype, + src_dtype: Optional[torch.dtype] = None, + use_compute_types=True, + ) -> torch.dtype: + return upcast_compute_type(dtype) if use_compute_types else dtype + + @staticmethod + def to_dtype_bitcast( + x: DTypeArg, dtype: torch.dtype, src_dtype: torch.dtype + ) -> torch.dtype: + return upcast_compute_type(dtype) + + @staticmethod + def gelu(x: DTypeArg) -> torch.dtype: + return promote_types([x]) + + @staticmethod + def mul(a: DTypeArg, b: DTypeArg) -> torch.dtype: + return promote_types([a, b]) + + @staticmethod + def truediv(a: DTypeArg, b: DTypeArg) -> torch.dtype: + return promote_types([a, b]) + + @staticmethod + def pow(a: DTypeArg, b: DTypeArg) -> torch.dtype: + return promote_types([a, b]) + + @staticmethod + def mod(a: DTypeArg, b: DTypeArg) -> torch.dtype: + return promote_types([a, b]) + + @staticmethod + def indirect_indexing( + x: DTypeArg, size: int, check: bool = True, wrap_neg: bool = True + ) -> torch.dtype: + return torch.int64 + + @staticmethod + def 
randn(seed: int, offset: int) -> torch.dtype: + return torch.float + + @staticmethod + def rand(seed: int, offset: int) -> torch.dtype: + return torch.float + + @staticmethod + def store_reduction(name: str, index, value: DTypeArg) -> None: + return None + + @staticmethod + def reduction( + dtype: torch.dtype, src_dtype: torch.dtype, reduction_type: str, value: DTypeArg + ) -> torch.dtype: + return dtype + + @staticmethod + def store(name: str, index, value: DTypeArg, mode: Optional[str] = None) -> None: + return None + + @staticmethod + def partial_accumulate( + name: str, + reduction_type: str, + value: DTypeArg, + extra_meta: dict[str, Any], + ) -> None: + return None + + @staticmethod + def load(name: str, index) -> torch.dtype: + return upcast_compute_type(V.graph.get_dtype(name)) + + @staticmethod + def floor(x: DTypeArg) -> torch.dtype: + return promote_types( + [x], type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ) + + @staticmethod + def ceil_to_int(x: DTypeArg, dtype: torch.dtype) -> torch.dtype: + return dtype + + @staticmethod + def int_truediv(x: DTypeArg, y: DTypeArg) -> torch.dtype: + return promote_types( + [x, y], type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ) + + @staticmethod + def scan( + dtypes: tuple[torch.dtype, ...], + combine_fn: Callable[[tuple[T, ...], tuple[T, ...]], tuple[T, ...]], + values: tuple[T, ...], + ) -> tuple[torch.dtype, ...]: + return dtypes + + @staticmethod + def fmod(x: DTypeArg, y: DTypeArg) -> torch.dtype: + return promote_types([x, y]) + + @staticmethod + def round_to_int(x: DTypeArg, dtype: torch.dtype) -> torch.dtype: + return dtype + + @staticmethod + def identity(x: DTypeArg) -> torch.dtype: + return promote_types([x]) + + @staticmethod + def frexp(x: DTypeArg) -> tuple[torch.dtype, torch.dtype]: + # TODO - need to handle multiple outputs + return (promote_types([x]), torch.int32) + + @staticmethod + def sort( + dtypes: tuple[torch.dtype, ...], + values: tuple[T, ...], + 
stable: bool, + descending: bool, + ) -> tuple[torch.dtype, ...]: + return dtypes + + @staticmethod + def trunc(x: DTypeArg) -> torch.dtype: + return promote_types([x]) + + @staticmethod + def bucketize( + values: DTypeArg, + boundaries: tuple[str, sympy.Expr, sympy.Expr, sympy.Expr], + boundary_indices: DTypeArg, + indexing_dtype: torch.dtype, + right: bool, + sorter: Optional[tuple[str, sympy.Expr]] = None, + sorter_indices: Optional[T] = None, + ) -> torch.dtype: + return indexing_dtype + + @staticmethod + def rshift(x: DTypeArg, y: DTypeArg) -> torch.dtype: + return promote_types([x]) + + @staticmethod + def round(x: DTypeArg) -> torch.dtype: + return promote_types( + [x], type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ) + + @staticmethod + def trunc_to_int(x: DTypeArg, dtype: torch.dtype) -> torch.dtype: + return dtype + + @staticmethod + def floor_to_int(x: DTypeArg, dtype: torch.dtype) -> torch.dtype: + return dtype + + @staticmethod + def truncdiv(x: DTypeArg, y: DTypeArg) -> torch.dtype: + return promote_types([x, y]) + + @staticmethod + def floordiv(x: DTypeArg, y: DTypeArg) -> torch.dtype: + return promote_types([x, y]) + + @staticmethod + def halide_clamp(value, size, check): + # TODO - way of registering dtype for op in backend + return torch.int32 + + @staticmethod + def dot(x: DTypeArg, y: DTypeArg) -> torch.dtype: + # triton tl.dot out_dtype is tl.float32 by default. 
+ return torch.float32 + + @staticmethod + def inline_asm_elementwise( + *inputs, asm, constraints=None, dtype=torch.float32, is_pure=True, pack=1 + ): + return dtype + + @staticmethod + def lshift(x: DTypeArg, y: DTypeArg) -> torch.dtype: + return promote_types([x]) + + @staticmethod + def check_bounds( + expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool + ) -> None: + return None + + def output(self, *args: DTypeArg) -> None: + raise AssertionError( + f"{type(self).__name__}: ops.output should not appear here" + ) + + def placeholder(self, index: int) -> torch.dtype: + raise AssertionError( + f"{type(self).__name__}: ops.placeholder should not appear here" + ) + + @staticmethod + def device_assert_async(cond, msg: str) -> None: + return None + + +if TYPE_CHECKING: + + class _typecheck_DtypePropagation(DtypePropagationOpsHandler, OpsHandler[Any]): + pass # mypy will error if we got any of the signatures wrong diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/exc.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/exc.py new file mode 100644 index 0000000000000000000000000000000000000000..8c932c0369897b7be92e8e9dbcb797ce1ce88230 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/exc.py @@ -0,0 +1,160 @@ +from __future__ import annotations + +import os +import tempfile +import textwrap +from functools import lru_cache +from typing import Any, Optional, TYPE_CHECKING + +from torch._dynamo.exc import BackendCompilerFailed, ShortenTraceback + + +if TYPE_CHECKING: + import types + + from torch.cuda import _CudaDeviceProperties + +if os.environ.get("TORCHINDUCTOR_WRITE_MISSING_OPS") == "1": + + @lru_cache(None) + def _record_missing_op(target: Any) -> None: + with open(f"{tempfile.gettempdir()}/missing_ops.txt", "a") as fd: + fd.write(str(target) + "\n") + +else: + + def _record_missing_op(target: Any) -> None: # type: ignore[misc] + pass + + +class 
OperatorIssue(RuntimeError): + @staticmethod + def operator_str(target: Any, args: list[Any], kwargs: dict[str, Any]) -> str: + lines = [f"target: {target}"] + [ + f"args[{i}]: {arg}" for i, arg in enumerate(args) + ] + if kwargs: + lines.append(f"kwargs: {kwargs}") + return textwrap.indent("\n".join(lines), " ") + + +class MissingOperatorWithoutDecomp(OperatorIssue): + def __init__(self, target: Any, args: list[Any], kwargs: dict[str, Any]) -> None: + _record_missing_op(target) + super().__init__(f"missing lowering\n{self.operator_str(target, args, kwargs)}") + + +class MissingOperatorWithDecomp(OperatorIssue): + def __init__(self, target: Any, args: list[Any], kwargs: dict[str, Any]) -> None: + _record_missing_op(target) + super().__init__( + f"missing decomposition\n{self.operator_str(target, args, kwargs)}" + + textwrap.dedent( + f""" + + There is a decomposition available for {target} in + torch._decomp.get_decompositions(). Please add this operator to the + `decompositions` list in torch._inductor.decomposition + """ + ) + ) + + +class LoweringException(OperatorIssue): + def __init__( + self, exc: Exception, target: Any, args: list[Any], kwargs: dict[str, Any] + ) -> None: + super().__init__( + f"{type(exc).__name__}: {exc}\n{self.operator_str(target, args, kwargs)}" + ) + + +class SubgraphLoweringException(RuntimeError): + pass + + +class InvalidCxxCompiler(RuntimeError): + def __init__(self) -> None: + from . 
import config + + super().__init__( + f"No working C++ compiler found in {config.__name__}.cpp.cxx: {config.cpp.cxx}" + ) + + +class CppWrapperCodegenError(RuntimeError): + def __init__(self, msg: str) -> None: + super().__init__(f"C++ wrapper codegen error: {msg}") + + +class CppCompileError(RuntimeError): + def __init__(self, cmd: list[str], output: str) -> None: + if isinstance(output, bytes): + output = output.decode("utf-8") + + self.cmd = cmd + self.output = output + + super().__init__( + textwrap.dedent( + """ + C++ compile error + + Command: + {cmd} + + Output: + {output} + """ + ) + .strip() + .format(cmd=" ".join(cmd), output=output) + ) + + def __reduce__(self) -> tuple[type, tuple[list[str], str]]: + return (self.__class__, (self.cmd, self.output)) + + +class CUDACompileError(CppCompileError): + pass + + +class TritonMissing(ShortenTraceback): + def __init__(self, first_useful_frame: Optional[types.FrameType]) -> None: + super().__init__( + "Cannot find a working triton installation. " + "Either the package is not installed or it is too old. " + "More information on installing Triton can be found at: https://github.com/triton-lang/triton", + first_useful_frame=first_useful_frame, + ) + + +class GPUTooOldForTriton(ShortenTraceback): + def __init__( + self, + # pyrefly: ignore [not-a-type] + device_props: _CudaDeviceProperties, + first_useful_frame: Optional[types.FrameType], + ) -> None: + super().__init__( + f"Found {device_props.name} which is too old to be supported by the triton GPU compiler, " + "which is used as the backend. 
Triton only supports devices of CUDA Capability >= 7.0, " + f"but your device is of CUDA capability {device_props.major}.{device_props.minor}", + first_useful_frame=first_useful_frame, + ) + + +class InductorError(BackendCompilerFailed): + backend_name = "inductor" + + def __init__( + self, + inner_exception: Exception, + first_useful_frame: Optional[types.FrameType], + ) -> None: + self.inner_exception = inner_exception + ShortenTraceback.__init__( + self, + f"{type(inner_exception).__name__}: {inner_exception}", + first_useful_frame=first_useful_frame, + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/extern_node_serializer.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/extern_node_serializer.py new file mode 100644 index 0000000000000000000000000000000000000000..0e5f42e7309e85035a8db51e1f5acc782336ddb0 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/extern_node_serializer.py @@ -0,0 +1,24 @@ +import json + +from torch._export.serde.schema import ExternKernelNode, ExternKernelNodes, Node +from torch._export.serde.serialize import _dataclass_to_dict, EnumEncoder +from torch._inductor.ir import ExternKernelNode as inductor_ExternKernelNode + + +def serialize_extern_kernel_node( + extern_kernel_node: inductor_ExternKernelNode, +) -> ExternKernelNode: + assert isinstance(extern_kernel_node.node, Node) + return ExternKernelNode( + name=extern_kernel_node.name, + node=extern_kernel_node.node, + ) + + +def extern_node_json_serializer( + extern_kernel_nodes: list[inductor_ExternKernelNode], +) -> str: + serialized_nodes = ExternKernelNodes( + nodes=[serialize_extern_kernel_node(node) for node in extern_kernel_nodes] + ) + return json.dumps(_dataclass_to_dict(serialized_nodes), cls=EnumEncoder) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/freezing.py 
b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/freezing.py new file mode 100644 index 0000000000000000000000000000000000000000..70ebe6e9ead06394ad949e22a38ebadf55ab7e3a --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/freezing.py @@ -0,0 +1,292 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import itertools +import logging +import weakref +from typing import Any, Optional + +import torch +import torch.utils._pytree as pytree +from torch._dynamo.utils import dynamo_timed, lazy_format_graph_code +from torch._functorch.aot_autograd import MutationType +from torch._functorch.compile_utils import fx_graph_cse +from torch._inductor.constant_folding import constant_fold, replace_node_with_constant +from torch._inductor.freezing_utils import enter_freezing, record_has_frozen_params +from torch._inductor.fx_passes.freezing_patterns import freezing_passes +from torch._inductor.fx_passes.post_grad import view_to_reshape + +from . import config + + +aten = torch.ops.aten +prims = torch.ops.prims + +log = logging.getLogger(__name__) + + +def replace_params_with_constants( + gm: torch.fx.GraphModule, + flat_params: list[Any], + fw_metadata: torch._functorch.aot_autograd.ViewAndMutationMeta, +) -> list[int]: + """ + Replaces the parameters of a PyTorch GraphModule with constants wherever possible. + Returns a list of indices representing the input parameters that were not converted to constants. 
+ """ + params = gm.graph.find_nodes(op="placeholder") + fake_inp_nodes = params[: len(params)] + preserved_arg_indices = [] + aliased_input_args = [ + out_info.base_idx + for out_info in fw_metadata.output_info + if out_info.base_idx is not None + ] + + # TODO (tmanlaibaatar) figure out why this is different + # from mutated_inp_runtime_indices + mutated_inps = [ + i + for i, m in enumerate(fw_metadata.input_info) + if m.mutation_type + in (MutationType.MUTATED_IN_GRAPH, MutationType.MUTATED_OUT_GRAPH) + ] + + static_indices_new = [] + static_indices_offset = 0 + for i, (real_input, node) in enumerate(zip(flat_params, fake_inp_nodes)): + if i in mutated_inps or i in aliased_input_args: + preserved_arg_indices.append(i) + if i in fw_metadata.static_input_indices: + new_static_index = i - static_indices_offset + static_indices_new.append(new_static_index) + else: + replace_node_with_constant(gm, node, real_input) + static_indices_offset += 1 + # add on non param inputs + preserved_arg_indices.extend(range(len(flat_params), len(params))) + # is this necessary ? + fw_metadata.static_input_indices = static_indices_new + gm.recompile() + return preserved_arg_indices + + +def freeze( + dynamo_gm: torch.fx.GraphModule, + aot_autograd_gm: torch.fx.GraphModule, + example_inputs: list[torch._subclasses.FakeTensor], +) -> tuple[torch.fx.GraphModule, list[int]]: + """ + Inlines parameters that are not mutated into constants and optimizes the graph through constant propagation + and other techniques. If enabled, the function also discards the original parameters of the module for memory efficiency. + + Assumes that this function is run in dynamo tracing post aot_autograd. + + Args: + dynamo_gm (torch.fx.GraphModule): The Dynamo constructed GraphModule. + aot_autograd_gm (torch.fx.GraphModule): The aot_autograd constructed GraphModule to be frozen. + example_inputs (List[torch.Tensor]): A list of example input tensors to be used in the freezing process. 
+ + Returns: + Tuple[torch.fx.GraphModule, List[int]]: A tuple containing the frozen GraphModule and a list of indices + of the inputs that were preserved (not turned into constants). + """ + with enter_freezing(): + return _freeze(dynamo_gm, aot_autograd_gm, example_inputs) + + +def _freeze( + dynamo_gm: torch.fx.GraphModule, + aot_autograd_gm: torch.fx.GraphModule, + example_inputs: list[torch._subclasses.FakeTensor], +) -> tuple[torch.fx.GraphModule, list[int]]: + # We have convert conv's weight to channels last which may meet error for .view + # when doing fake_tensor_prop. So we need to convert view to reshape first. + # See the details in fx_codegen_and_compile of compile_fx.py. + view_to_reshape(aot_autograd_gm) + + if tracing_context := torch._guards.TracingContext.try_get(): + fw_metadata = tracing_context.fw_metadata + assert tracing_context.params_flat_unwrap_subclasses is not None + params_flat = tracing_context.params_flat_unwrap_subclasses + assert fw_metadata is not None and params_flat is not None + + preserved_arg_indices = replace_params_with_constants( + aot_autograd_gm, params_flat, fw_metadata + ) + else: + inputs = aot_autograd_gm.graph.find_nodes(op="placeholder") + preserved_arg_indices = list(range(len(inputs))) + + # TODO - further restrict cse ? 
right now needed to dedup aliasing ops + cse_graph = fx_graph_cse(aot_autograd_gm.graph) + aot_autograd_gm.graph = cse_graph + aot_autograd_gm.recompile() + + aot_example_inputs = [example_inputs[ind] for ind in preserved_arg_indices] + freezing_passes(aot_autograd_gm, aot_example_inputs) + + constant_fold(aot_autograd_gm) + # invalidate nn Modules + if config.freezing_discard_parameters: + invalidate_eager_modules() + discard_traced_gm_params(dynamo_gm) + + log.debug( + "%s", lazy_format_graph_code("FROZEN GRAPH", aot_autograd_gm, colored=True) + ) + + record_has_frozen_params(aot_autograd_gm) + return aot_autograd_gm, preserved_arg_indices + + +class ErasedTensor(torch.Tensor): + @staticmethod + def __new__(cls, elem, name, owning_mod): + return super().__new__(cls, elem.to(device="meta")) + + def __init__(self, elem, name: Optional[str], mod) -> None: + self.erased_name = name + self.owning_mod_ref = weakref.ref(mod) + + @classmethod + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): # type: ignore[override] + erased_tensors = [ + e + # pyrefly: ignore [bad-unpacking] + for e in pytree.arg_tree_leaves(*args, **kwargs) + if isinstance(e, ErasedTensor) + ] + assert len(erased_tensors) > 0 + e = erased_tensors[0] + + raise RuntimeError( + f"Trying to run Pytorch Eager Module after Dynamo Freezing. " + "The original parameters have been discarded for memory efficiency. 
" + f"Found in op {func} for erased parameter {e.erased_name} of {e.owning_mod_ref()}" + ) + + +def invalidate_eager_modules(): + with torch.utils._python_dispatch._disable_current_modes(): + for ( + mod + ) in torch._guards.TracingContext.get().module_context.nn_modules.values(): + if not isinstance(mod, torch.nn.Module): + continue + + for attr_name, tensor in list( + itertools.chain( + mod.named_parameters(recurse=False), + # pyrefly: ignore [bad-argument-type] + mod.named_buffers(recurse=False), + ) + ): + with torch._dispatch.python.no_python_dispatcher(): + e_t = ErasedTensor(tensor, attr_name, mod) + if isinstance(tensor, torch.nn.Parameter): + e_t.requires_grad_(True) + e_t._is_param = True + setattr(mod, attr_name, e_t) + + +def discard_traced_gm_params(mod: torch.fx.GraphModule): + with torch.utils._python_dispatch._disable_current_modes(): + for attr_name, tensor in list( + itertools.chain( + mod.named_parameters(recurse=False), + # pyrefly: ignore [bad-argument-type] + mod.named_buffers(recurse=False), + ) + ): + with torch._dispatch.python.no_python_dispatcher(): + e_t = ErasedTensor(tensor, attr_name, mod) + if isinstance(tensor, torch.nn.Parameter): + e_t.requires_grad_(True) + e_t._is_param = True + setattr(mod, attr_name, e_t) + + +def enforce_output_layout(gm: torch.fx.GraphModule): + """ + Make sure the output node's layout does not change due to compiler optimizations + by adding aten.as_strided nodes with the expected strides. + + Only used for inference so we can assume all graph outputs are model outputs. 
+ """ + *_, output_node = gm.graph.nodes + out_list = output_node.args[0] + with gm.graph.inserting_before(output_node): + for n in out_list: + if not isinstance( + n.meta["val"], torch.Tensor + ) or not torch._prims_common.is_non_overlapping_and_dense(n.meta["val"]): + continue + + # add a node to enforce eager layout + ft = n.meta["val"] + new_node = gm.graph.call_function( + prims.inductor_force_stride_order.default, (n, ft.stride()) + ) + + # can not call + # n.replace_all_uses_with(new_node) + # since it will replace the usage of n in new_node itself. + output_node.replace_input_with(n, new_node) + + gm.graph.lint() + gm.recompile() + + +def enforce_as_strided_input_layout(gm: torch.fx.GraphModule): + """ + Make sure the as_strided node's input's layout does not change due to compiler + optimizations, because the as_strided strides info depends on input tensor stride info. + """ + + as_strided_ops = [ + torch.ops.aten.as_strided.default, + torch.ops.aten.as_strided_.default, + torch.ops.aten.as_strided_scatter.default, + ] + strided_nodes = [n for n in gm.graph.nodes if n.target in as_strided_ops] + for n in strided_nodes: + with gm.graph.inserting_before(n): + # add a node to enforce eager layout + ft = n.args[0].meta["val"] + new_node = gm.graph.call_function( + prims.inductor_force_stride_order.default, (n.args[0], ft.stride()) + ) + n.replace_input_with(n.args[0], new_node) + + gm.graph.lint() + gm.recompile() + + +def convert_conv_weights_to_channels_last(gm: torch.fx.GraphModule): + """ + Convert 4d convolution weight tensor to channels last format. + + This pass is performed before freezing so the added nodes can be constant + folded by freezing. 
+ """ + with dynamo_timed("convert_conv_weights_to_channels_last"): + convs = [n for n in gm.graph.nodes if n.target is aten.convolution.default] + for conv in convs: + weight_node = conv.args[1] + if len(weight_node.meta["val"].size()) != 4 or weight_node.meta[ + "val" + ].is_contiguous(memory_format=torch.channels_last): + # not a 4d tensor or already channels last, skip + continue + + with gm.graph.inserting_before(conv): + new_node = gm.graph.call_function( + aten.clone.default, + (weight_node,), + {"memory_format": torch.channels_last}, + ) + conv.replace_input_with(weight_node, new_node) + + enforce_as_strided_input_layout(gm) + enforce_output_layout(gm) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/freezing_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/freezing_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8a14890aacbd76acd0e49726d9eba99c590e83c8 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/freezing_utils.py @@ -0,0 +1,55 @@ +import contextlib +import threading +from collections.abc import Generator +from typing import Any + +import torch + + +_TLS = threading.local() + + +def _freezing_active() -> bool: + return getattr(_TLS, "freezing_active", False) + + +@contextlib.contextmanager +def enter_freezing() -> Generator[Any, None, None]: + """ + Context manager to designate when freezing is active. + """ + prev = _freezing_active() + _TLS.freezing_active = True + try: + yield + finally: + _TLS.freezing_active = prev + + +def record_has_frozen_params(gm: torch.fx.GraphModule) -> None: + """ + Mark the gm as having frozen params. + """ + gm._has_frozen_params = True # type: ignore[assignment] + + +def has_frozen_params(gm: torch.fx.GraphModule) -> bool: + """ + Return True if the gm has frozen parameters. 
+ """ + return getattr(gm, "_has_frozen_params", False) + + +def maybe_set_is_frozen_param(t: torch.Tensor) -> None: + """ + Mark the provided tensor as a frozen param if freezing is active. + """ + if _freezing_active(): + t._is_frozen_param = True # type: ignore[attr-defined] + + +def is_frozen_param(t: torch.Tensor) -> bool: + """ + Return True if the tensor is a frozen param. + """ + return getattr(t, "_is_frozen_param", False) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fuzzer.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fuzzer.py new file mode 100644 index 0000000000000000000000000000000000000000..152dce202676623960af408fe63577846691ccb3 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fuzzer.py @@ -0,0 +1,1008 @@ +import importlib +import itertools +import logging +import pickle +import random +import signal +import string +import traceback +from collections.abc import Callable, KeysView, Sequence +from enum import Enum +from functools import partial, wraps +from types import FrameType +from typing import Any, get_args, get_origin, Literal, Optional, TypeVar, Union + +import torch +from functorch.compile import min_cut_rematerialization_partition +from torch._inductor.custom_graph_pass import CustomGraphPass, CustomPartitionerFn +from torch._inductor.scheduler import BaseSchedulerNode +from torch.utils._config_module import _ConfigEntry, ConfigModule +from torch.utils._ordered_set import OrderedSet + + +log = logging.getLogger(__name__) + + +def is_type(type_hint, comp_type) -> bool: # type: ignore[no-untyped-def] + """ + Determines if type_hint is comp_type. There are some type annotations that this doesn't work for. + I think it's because some Type annotations are Type Objects and some are Special Forms, but not sure. + There's definite room for improvement to make this more general for someone who deeply understands + Python types. 
+ """ + return type_hint is comp_type or get_origin(type_hint) is comp_type + + +def is_optional_type(type_hint) -> bool: # type: ignore[no-untyped-def] + """ + Special case of is_type. + """ + origin = get_origin(type_hint) + + if origin is Union: + args = get_args(type_hint) + return type(None) in args + + return False + + +def is_callable_type(type_hint) -> bool: # type: ignore[no-untyped-def] + """ + Special Case of is_type. + """ + return type_hint.__name__ == "Callable" + + +class DummyPass(CustomGraphPass): + """ + A Dummy pass to be used by ConfigFuzzer + """ + + def __call__(self, graph: torch.fx.graph.Graph) -> None: + return None + + def uuid(self) -> Optional[Any]: + return None + + +class DummyPartitionerFn(CustomPartitionerFn): + """ + A Dummy partitioner function to be used by ConfigFuzzer + """ + + def __call__( + self, gm: torch.fx.GraphModule, joint_inputs: Sequence[object], **kwargs: Any + ) -> tuple[torch.fx.GraphModule, torch.fx.GraphModule]: + return min_cut_rematerialization_partition(gm, joint_inputs, **kwargs) + + def uuid(self) -> Optional[Any]: + return None + + +T = TypeVar("T") + + +class TypeExemplars: + """ + This class returns examples of a Type, given its class name. + """ + + TYPE_EXEMPLARS: dict[str, Any] = { + CustomGraphPass.__name__: DummyPass(), + CustomPartitionerFn.__name__: DummyPartitionerFn(), + torch.fx.graph.Graph.__name__: torch.fx.graph.Graph(), + BaseSchedulerNode.__name__: BaseSchedulerNode(None), # type: ignore[arg-type] + } + + @staticmethod + def example(t: type[T]) -> Optional[T]: + """ + Return an example of a class. 
+ """ + # pyrefly: ignore [bad-argument-type, bad-argument-count] + return TypeExemplars.TYPE_EXEMPLARS.get(t.__name__, None) + + @staticmethod + def contains(t: type[T]) -> bool: + # pyrefly: ignore [bad-argument-type, bad-argument-count] + return t.__name__ in TypeExemplars.TYPE_EXEMPLARS + + +def check_halide_import() -> bool: + """checks if we have halide available""" + try: + importlib.import_module("halide") + return True + except ModuleNotFoundError: + return False + + +if check_halide_import(): + CUDA_BACKEND = ["triton", "halide"] +else: + CUDA_BACKEND = ["triton"] + + +class Status(Enum): + """ + The Status return value enum for Config Fuzzer + """ + + # ConfigFuzzer skipped the test + SKIPPED = "skipped" + # ConfigFuzzer compiled and ran the test and function it passed. + PASSED = "passed" + # ConfigFuzzer failed to compile the test function + FAILED_COMPILE = "failed_compile" + # ConfigFuzzer compiled the test function and running it raised an exception + FAILED_RUN_COMPILE_EXCEPTION = "failed_run_compile_exception" + # ConfigFuzzer ran eager and it raised an exception + FAILED_RUN_EAGER_EXCEPTION = "failed_run_eager_exception" + # ConfigFuzzer compiled the test function, but the return value indicated that the compiled value didn't match the + # value from eager (or however else you set up the comparison in the test function) + FAILED_RUN_RETURN = "failed_run_return" + + def failing(self) -> bool: + """ + Convenience method to check whether these status represent failure. 
+ """ + return ( + self == Status.FAILED_COMPILE + or self == Status.FAILED_RUN_EAGER_EXCEPTION + or self == Status.FAILED_RUN_COMPILE_EXCEPTION + or self == Status.FAILED_RUN_RETURN + ) + + +# Sometime the types of configs aren't expressive enough to be captured by python type system, so the options can be +# manually specified here: +# TODO this needs to be indexed to the module, like inductor or dynamo, for name collisions +TYPE_OVERRIDES: dict[str, list[Any]] = { + "cuda_backend": CUDA_BACKEND, + "post_grad_fusion_options": [ + { + "batch_linear_post_grad": { + "shape_broadcast_batch_linear": True, + "fuse_nodes_with_same_users": True, + }, + "batch_aten_mul": {"fuse_nodes_with_same_parent": False}, + "batch_aten_sigmoid": {"fuse_nodes_with_same_parent": True}, + "batch_aten_add": {"fuse_nodes_with_same_parent": True}, + "normalization_aten_pass": {}, + "unbind_stack_aten_pass": {}, + }, + { + "batch_aten_add": {}, + "batch_aten_mul": {}, + "batch_aten_sub": {}, + "batch_aten_div": {}, + "group_linear": {"require_fbgemm": True}, + }, + ], + "autoheuristic_collect": ["pad_mm", "mixed_mm"], + "autoheuristic_use": ["pad_mm", "mixed_mm"], + "traceable_tensor_subclasses": [OrderedSet()], + "nontraceable_tensor_subclasses": [OrderedSet()], +} +SamplingType = Callable[[str, type[Any], Any], Any] + + +class SamplingMethod(Enum): + """ + This class handles the process of assigning concrete values to type annotations. So a type annotation of + ```python + foo: Optional[int] = None + ``` + Will be assigned an int if the dispatch function gets TOGGLE, or a 50/50 split between an int and None if it gets + RANDOM. + """ + + TOGGLE = "TOGGLE" # toggle to the opposite value + RANDOM = "RANDOM" # randomly choose an option + + @staticmethod + def _generate_value_for_type( + random_sample: bool, field_name: str, type_hint: type[Any], default: Any + ) -> Any: + """ + Generates a value of a type based on the setting. 
+ """ + # look for name in type overrides + if field_name in TYPE_OVERRIDES: + return random.choice(TYPE_OVERRIDES[field_name]) + + if type_hint is bool: + return random.choice([True, False]) if random_sample else not default + elif type_hint is int: + # NOTE initially tried to use negation of the value, but it doesn't work because most types are ints + # when they should be natural numbers + zero. Python types to cover these values aren't super convenient. + return random.randint(0, 1000) + elif type_hint is float: + return random.uniform(0, 1000) + elif type_hint is str: + characters = string.ascii_letters + string.digits + string.punctuation + return "".join( + random.choice(characters) for _ in range(random.randint(1, 20)) + ) + elif is_type(type_hint, list): + elem_type = getattr( + type_hint, + "__args__", + [type(default[0])] if default and len(default) else [type(None)], + )[0] + new_default = default[0] if default and len(default) > 0 else None + return [ + SamplingMethod._generate_value_for_type( + random_sample, field_name, elem_type, new_default + ) + for _ in range(random.randint(1, 3)) + ] + elif is_type(type_hint, set): # noqa: set_linter + indexable = list(default) + elem_type = getattr( + type_hint, + "__args__", + [type(indexable[0])] if default and len(default) else [type(None)], + )[0] + new_default = indexable[0] if default and len(default) > 0 else None + return { # noqa: set_linter + SamplingMethod._generate_value_for_type( + random_sample, field_name, elem_type, new_default + ) + for _ in range(random.randint(1, 3)) + } + elif is_type(type_hint, OrderedSet): + indexable = list(default) + elem_type = getattr( + type_hint, + "__args__", + [type(indexable[0])] if default and len(default) else [type(None)], + )[0] + new_default = indexable[0] if default and len(default) > 0 else None + return OrderedSet( + [ + SamplingMethod._generate_value_for_type( + random_sample, field_name, elem_type, new_default + ) + for _ in range(random.randint(1, 3)) + 
] + ) + elif is_type(type_hint, dict): + key_type, value_type = getattr( + type_hint, + "__args__", + map(type, next(iter(default.items()))) + if (default is not None and len(default)) + else (type(None), type(None)), + ) + if default is not None and len(default.items()) > 0: + default_key, default_val = next(iter(default.items())) + else: + default_key, default_val = None, None + return { + SamplingMethod._generate_value_for_type( + random_sample, field_name, key_type, default_key + ): SamplingMethod._generate_value_for_type( + random_sample, field_name, value_type, default_val + ) + for _ in range(random.randint(0, 3)) + } + elif is_type(type_hint, Union): + # do whatever is not the type of default + try: + assert len(type_hint.__args__) > 1 + except AttributeError as err: + raise ValueError("Union type with no args") from err + if random_sample: + new_type = random.choice(type_hint.__args__) + else: + new_type = random.choice( + [t for t in type_hint.__args__ if t is not type(default)] + ) + try: + new_default = new_type() + except Exception: + # if default constructor doesn't work, try None + new_default = None + + return SamplingMethod._generate_value_for_type( + random_sample, field_name, new_type, new_default + ) + elif is_type(type_hint, tuple): + args = getattr( + type_hint, + "__args__", + tuple(map(type, default)), + ) + zipped = zip(args, default) + return tuple( + map( # noqa: C417 + lambda x: SamplingMethod._generate_value_for_type( + random_sample, field_name, x[0], x[1] + ), + zipped, + ) + ) + elif is_type(type_hint, Literal): + try: + if random_sample: + return random.choice(type_hint.__args__) + else: + choices = [t for t in type_hint.__args__ if t != default] + if choices: + return random.choice(choices) + else: + return default + except AttributeError as err: + raise ValueError("Literal type with no args") from err + elif is_optional_type(type_hint): + try: + elem_type = type_hint.__args__[0] + except AttributeError as err: + raise 
ValueError("Optional type with no args") from err + if random_sample: + return random.choice( + [ + None, + SamplingMethod._generate_value_for_type( + random_sample, field_name, elem_type, default + ), + ] + ) + else: + if default is None: + return SamplingMethod._generate_value_for_type( + random_sample, field_name, elem_type, None + ) + else: + return None + elif type_hint is type(None): + return None + elif is_callable_type(type_hint): + try: + return_type = list(type_hint.__args__)[-1] + except AttributeError as err: + raise ValueError("Callable type with no args") from err + + @wraps(lambda *args, **kwargs: None) + def dummy_function(*args, **kwargs): # type: ignore[no-untyped-def] + return SamplingMethod._generate_value_for_type( + random_sample, field_name, return_type, None + ) + + return dummy_function + elif type_hint == torch._ops.OpOverload: + return torch.ops.aten.add.default + elif TypeExemplars.contains(type_hint): + return TypeExemplars.example(type_hint) + elif type_hint == Any: + return 1 if default != 1 else 2 + else: + raise ValueError(f"Unable to process type {type_hint}. PRs welcome :)") + + @staticmethod + def dispatch(sm: "SamplingMethod") -> SamplingType: + """ + Returns a function that will generate values from a type, based on the SamplingMethod passed in. + """ + if sm == SamplingMethod.RANDOM: + return partial(SamplingMethod._generate_value_for_type, True) + elif sm == SamplingMethod.TOGGLE: + return partial(SamplingMethod._generate_value_for_type, False) + else: + raise ValueError(f"malformed sampling method: {sm}") + + +class Default: + """ + Singleton default object that will cause the ConfigFuzzer to always use the default value set in the config. + """ + + +DEFAULT = Default() + +# The combination of config settings being set (based on their strings) +ComboType = tuple[str, ...] + + +class ResultType: + """ + The mapping of the combo strings to the result status after running the config fuzzer. 
+ """ + + _vals: dict[ComboType, Status] + + def __repr__(self) -> str: + return f"ResultType[{self._vals}]" + + def __init__(self) -> None: + self._vals = {} + + def __len__(self) -> int: + return len(self._vals) + + def num_ran(self) -> int: + """ + Returns how many combos actually ran (weren't skipped). + """ + ret = len(self._vals) + for status in self._vals.values(): + if status == Status.SKIPPED: + ret -= 1 + return ret + + def set(self, combo: ComboType, status: Status) -> None: + combo = tuple(sorted(combo)) + self._vals[combo] = status + + def lookup(self, combo: ComboType) -> Optional[Status]: + combo = tuple(sorted(combo)) + return self._vals.get(combo, None) + + def keys(self) -> KeysView[ComboType]: + return self._vals.keys() + + +# Type that maps config strings to their default value +ConfigType = dict[str, Any] +# Callable that returns a bool +FactoryOutputType = Callable[[], bool] +# input function factory +FactoryType = Callable[[], FactoryOutputType] + +# Why are some configs disabled by default? Because if we don't the fuzzer produces uninteresting results. +# It will always hone-in on these failures, even with the most basic model, making it useless for +# debugging more complex models. +# +# More explicit explanations are below: +# Out of Scope: We can't fuzz, say, the cuda version because that comes from the environment and will +# produce a failure if not aligned with env. +# Known Failure: Disabled due to known failure. Hopefully re-enable. Known failures are listed in the +# docstring of this file. +# Required: Required for the fuzzer to operate (removing caching, etc.) +# FSDP: Flag meant for FSDP that fails in non FSDP envs. Re-enable these if you're testing FSDP. +# Typing: disabled because the type annotation of the config isn't constrained enough to produce +# meaningful fuzz values. These could be improved. +# Timing: These take too long to compile, feel free to enable. 
+MODULE_DEFAULTS: dict[str, ConfigType] = { + "torch._inductor.config": { + "force_disable_caches": True, # Required + "cpp.cxx": DEFAULT, # Out of Scope + "TYPE_CHECKING": DEFAULT, # Not a config + "max_autotune_pointwise": DEFAULT, # Timing + "max_autotune_gemm": DEFAULT, # Timing, re-enable when autotune speed improvements merged. + "max_autotune_gemm_backends": DEFAULT, # Timing + "max_autotune_conv_backends": DEFAULT, # Timing + "max_autotune_gemm_search_space": DEFAULT, # Timing + "max_autotune_subproc_result_timeout_seconds": DEFAULT, # Timing + "max_autotune_subproc_graceful_timeout_seconds": DEFAULT, # Timing + "max_autotune_subproc_terminate_timeout_seconds": DEFAULT, # Timing + "aot_inductor.presets": DEFAULT, # Typing + "cuda.arch": DEFAULT, # Out of Scope + "cuda.version": DEFAULT, # Out of Scope + "cuda.cutlass_dir": DEFAULT, # Out of Scope + "cuda.cuda_cxx": DEFAULT, # Out of Scope + "rocm.arch": DEFAULT, # Out of Scope + "rocm.ck_supported_arch": DEFAULT, # Out of Scope + "rocm.ck_dir": DEFAULT, # Out of Scope + "rocm.rocm_home": DEFAULT, # Out of Scope + "check_stack_no_cycles_TESTING_ONLY": DEFAULT, # Testing + "sleep_sec_TESTING_ONLY": DEFAULT, # Testing + "triton.inject_relu_bug_TESTING_ONLY": DEFAULT, # Testing + "reorder_for_compute_comm_overlap": DEFAULT, # FSDP + "enabled_metric_tables": DEFAULT, # Typing + "triton.debug_sync_graph": DEFAULT, # Known Failure + "triton.debug_sync_kernel": DEFAULT, # Known Failure + "profile_bandwidth_regex": DEFAULT, # Known Failure + "disable_cpp_codegen": DEFAULT, # Known Failure + "trace.save_real_tensors": DEFAULT, # Known Failure + "pre_grad_fusion_options": DEFAULT, # Typing + "external_matmul": DEFAULT, # Typing, need to add this to type overrides or type exemplars. 
+ "test_configs.autotune_choice_name_regex": DEFAULT, # Typing + "test_configs.autotune_choice_desc_regex": DEFAULT, # Typing + "cpp.enable_floating_point_contract_flag": DEFAULT, # Typing + "post_grad_custom_pre_pass": DEFAULT, # Typing + "post_grad_custom_post_pass": DEFAULT, # Typing + "reorder_for_compute_comm_overlap_passes": DEFAULT, # Typing + "joint_custom_post_pass": DEFAULT, # Typing + "joint_custom_pre_pass": DEFAULT, # Typing + "pre_grad_custom_pass": DEFAULT, # Typing + "custom_partitioner_fn": DEFAULT, # Typing + "inductor_choices_class": DEFAULT, # Typing + }, + "torch._dynamo.config": { + "traceable_tensor_subclasses": DEFAULT, # Typing + "nontraceable_tensor_subclasses": DEFAULT, # Typing + "compiled_autograd_kwargs_override": DEFAULT, # Typing + "fail_on_recompile_limit_hit": DEFAULT, # fails in combo with suppress_errors + "suppress_errors": DEFAULT, + "caching_precompile": False, # Required + }, +} + + +class ConfigFuzzer: + """ + This tool makes it easy to search through config state-space with a minimal reproduction or test, either for + debugging or just bug hunting. + It has two entry points: + - bisect, which randomly flips configs and tries to find the minimal reproduction upon failure. + - fuzz_n_tuple, which tries every combination of n configs. This grows quickly as a function of n, so beware. + bisect is recommended, but fuzz_n_tuple can give you peace of mind that a new config will compose with + every other config. + + The main interface is a function factory that will return Callables to be torch.compiled. This function factory + should return a test function when it's called. Said test function returns a boolean, which determines whether + the ConfigFuzzer considers it a successful run or not. Throwing an exception from within the function will be + considered a failure as well. 
+ + # Example usage: + + ```python + import torch._inductor.config as cfg + + + def create_simple_test_model_gpu() -> FactoryOutputType: + batch_size = 32 + seq_length = 50 + hidden_size = 768 + + def test_fn() -> bool: + inp = torch.randn(batch_size, seq_length, hidden_size, device="cuda") + weight = torch.randn(hidden_size, hidden_size, device="cuda") + matmul_output = inp @ weight + final_output = torch.nn.LayerNorm(hidden_size, device="cuda")(matmul_output) + return True + + return test_fn + + + fuzzer = ConfigFuzzer(cfg, create_simple_test_model_gpu, seed=2) + + # Test every pair of configs: + results = fuzzer.fuzz_n_tuple(n, max_combinations=10000000) + + visualize_results(n, results) + + # Test random configs with bisection: + ret = fuzzer.bisect(num_attempts=10) + + # reproduce a failing config + fuzzer.reproduce( + [{"triton.autotune_pointwise": ..., "coordinate_descent_tuning": ...}] + ) + ``` + + The list of known failures on inductor config are: + cpp_wrapper, triton_debug_sync_graph + cpp_wrapper, triton_debug_sync_kernel + cpp_wrapper, disable_cpp_codegen + combo_kernels, benchmark_combo_kernel, profile_bandwidth, profile_bandwidth_regex + trace.enabled, trace.save_real_tensors + """ + + sample: SamplingType + default: ConfigType + + def __init__( + self, + config_module: ConfigModule, + test_model_fn_factory: FactoryType, + seed: int, + default: Optional[ConfigType] = None, + sm: SamplingMethod = SamplingMethod.TOGGLE, + test_timeout: int = 3600, + ): + """ + Args: + config_module: The module containing the configs to fuzz + test_model_fn_factory: Function that returns a test model, which runs and returns True if successful, or + the outputs if they should be compared with eager + seed: Randomness seed. + default: Default values for the config. Inductor has preset based on know failures. + sm: How type value samples are generated, default TOGGLE. + test_timeout: max time a test can take. 
+ """ + self.seed = seed + self.test_timeout = test_timeout + self.detailed_results: dict[ComboType, dict[str, Any]] = {} + self.config_module = config_module + self.test_model_fn_factory = test_model_fn_factory + self.fields: dict[str, _ConfigEntry] = self.config_module._config + self.sample = SamplingMethod.dispatch(sm) + + if default is None: + if self.config_module.__name__ in MODULE_DEFAULTS: + self.default = MODULE_DEFAULTS[self.config_module.__name__] + else: + raise ValueError("No default passed to ConfigFuzzer.") + else: + self.default = default + + def __repr__(self) -> str: + return ( + f"ConfigFuzzer(config_module={self.config_module}, " + f"test_model_fn_factor={self.test_model_fn_factory}, seed={self.seed}, default={self.default})" + ) + + def _set_config(self, field_name: str, value: Any) -> None: + """Set a config value in the module.""" + setattr(self.config_module, field_name, value) + + def _reset_configs(self) -> None: + """Reset all configs to their default values.""" + for field_name, field_obj in self.fields.items(): + self._set_config(field_name, field_obj.default) + + def new_config(self) -> ConfigType: + """creates a new config from the default""" + ret = { + name: val if val != DEFAULT else self.fields[name].default + for name, val in self.default.items() + } + return ret + + def reproduce(self, configs: Sequence[ConfigType]) -> ResultType: + """entrypoint to reproduce any failure""" + results = ResultType() + for conf in configs: + self._reproduce_single_helper(conf, results) + return results + + def _reproduce_single_helper(self, conf: ConfigType, results: ResultType) -> None: + print(f"Starting repro of {conf}") + new_config = self.new_config() + new_config.update(conf) + self.test_config(results, new_config) + print(f"Status of {conf}:\n{results.lookup(tuple(conf.keys()))}") + + def reproduce_single(self, config: ConfigType) -> ResultType: + results = ResultType() + self._reproduce_single_helper(config, results) + return results + + 
def _fuzz_helper(self, results: ResultType, combo: ComboType) -> Status: + print(combo) + if st := results.lookup(combo): + # we already processed this config + return st + + config = self.new_config() + + skip = False + for field_name in combo: + if field_name in config: + # don't break here because we need to build the config dict + skip = True + if field_name.startswith("_"): + skip = True + field = self.fields[field_name] + value = self.sample(field_name, field.value_type, field.default) + config[field_name] = value + if skip: + results.set(combo, Status.SKIPPED) + return Status.SKIPPED + + return self.test_config(results, config) + + def fuzz_n_tuple(self, n: int, max_combinations: int = 1000) -> ResultType: + """ + Test every combination of n configs. + + returns a dict of this shape: {(config-1, config-2... config-n): status} + """ + results = ResultType() + print(f"Starting {n}-tuple testing with seed {self.seed}") + random.seed(self.seed) + + for combo in itertools.combinations(self.fields, n): + st = self._fuzz_helper(results, combo) + if st != Status.SKIPPED: + max_combinations -= 1 + if max_combinations <= 0: + print("Reached maximum combinations limit") + break + + return results + + def save_state(self, filename: str = "fuzzer_state.pkl") -> None: + """Save the current fuzzer state to a file""" + with open(filename, "wb") as f: + pickle.dump( + {"results": self.results, "detailed_results": self.detailed_results}, f + ) + + def load_state(self, filename: str = "fuzzer_state.pkl") -> None: + """Load fuzzer state from a file""" + with open(filename, "rb") as f: + state = pickle.load(f) + self.results = state["results"] + self.detailed_results = state.get("detailed_results", {}) + + def timeout_handler(self, signum: int, frame: Optional[FrameType]) -> None: + raise TimeoutError("Test execution timed out") + + def test_config(self, results: ResultType, config: ConfigType) -> Status: + """ + Tests a config by calling the function produced by the factory 
function. + """ + original_handler = signal.signal(signal.SIGALRM, self.timeout_handler) + signal.alarm(self.test_timeout) + print(f"Testing config {config}") + config_tuple = tuple(config.keys()) + if ret := results.lookup(config_tuple): + signal.signal(signal.SIGALRM, original_handler) + return ret + + def print_config() -> None: + for field, value in config.items(): + print(f"{field} = {value}") + + def get_error_info(exc: Exception) -> dict[str, Any]: + return { + "exception": str(exc), + "traceback": traceback.format_exc(), + "config": config.copy(), + } + + def handle_return( + message: str, + return_status: Status, + print_traceback: bool, + exc: Optional[Exception], + ) -> Status: + signal.signal(signal.SIGALRM, original_handler) + print(f"{message} with config combination:") + print_config() + if exc: + self.detailed_results[config_tuple] = get_error_info(exc) + if print_traceback: + traceback.print_exc() + results.set(config_tuple, return_status) + return return_status + + # reset config + torch._dynamo.reset() + self._reset_configs() + for name, value in config.items(): + self._set_config(name, value) + + # try running eager + test_model_fn = self.test_model_fn_factory() + try: + test_model_fn() + except Exception as exc: + return handle_return( + "Eager exception", Status.FAILED_RUN_EAGER_EXCEPTION, True, exc + ) + + # try compilation + try: + test_model_fn2 = self.test_model_fn_factory() + comp = torch.compile(test_model_fn2, backend="inductor") + except Exception as exc: + return handle_return( + "Exception compiling", Status.FAILED_COMPILE, True, exc + ) + + # try running compiled + try: + compile_result = comp() + except Exception as exc: + return handle_return( + "Exception running compiled", + Status.FAILED_RUN_COMPILE_EXCEPTION, + True, + exc, + ) + + # bool return value means don't compare with eager + if not compile_result: + return handle_return( + "Function returned False", Status.FAILED_RUN_RETURN, False, None + ) + else: + return 
handle_return("Function succeeded", Status.PASSED, False, None) + + def bisect(self, num_attempts: int = 100, p: float = 0.5) -> list[ConfigType]: + """ + Test configs and bisect to minimal failing configuration. + """ + print(f"Starting random testing with bisection, seed {self.seed}, and p {p}") + random.seed(self.seed) + self._reset_configs() + results = ResultType() + ret: list[ConfigType] = [] + + for attempt in range(num_attempts): + print(f"Random attempt {attempt + 1}/{num_attempts}") + + config = self.new_config() + + for field_name, config_entry in self.fields.items(): + if ( + field_name not in config + and not field_name.startswith("_") + and "TESTING_ONLY" not in field_name + and random.random() < p + ): + value = self.sample( + field_name, config_entry.value_type, config_entry.default + ) + config[field_name] = value + + status = self.test_config(results, config) + if status not in OrderedSet([Status.PASSED, Status.SKIPPED]): + if minimal_failing_config := self._bisect_failing_config( + results, config + ): + print(f"Minimum failing config: {minimal_failing_config}") + ret.append(minimal_failing_config) + + return ret + + def _bisect_failing_config( + self, results: ResultType, failing_config: ConfigType + ) -> Optional[ConfigType]: + return self._bisect_failing_config_helper(results, list(failing_config.items())) + + def _bisect_failing_config_helper( + self, results: ResultType, failing_config: list[tuple[str, Any]] + ) -> Optional[ConfigType]: + """ + Bisect a failing configuration to find minimal set of configs that cause failure. + + Splits it into halves, then fourths, then tries dropping configs one-by-one. 
+ """ + print(f"bisecting config: {failing_config}") + + if not failing_config: + return None + + def test(x: list[tuple[str, Any]]) -> Status: + d = dict(x) + result = self.test_config(results, d) + return result + + if len(failing_config) <= 1: + return dict(failing_config) if test(failing_config).failing() else None + + random.shuffle(failing_config) + + mid = len(failing_config) // 2 + first_half = failing_config[:mid] + second_half = failing_config[mid:] + if test(first_half).failing(): + return self._bisect_failing_config_helper(results, first_half) + if test(second_half).failing(): + return self._bisect_failing_config_helper(results, second_half) + + if len(failing_config) >= 8: + low = len(failing_config) // 4 + high = mid + low + quart1 = failing_config[low:] + if test(quart1).failing(): + return self._bisect_failing_config_helper(results, quart1) + quart2 = failing_config[:low] + second_half + if test(quart2).failing(): + return self._bisect_failing_config_helper(results, quart2) + quart3 = first_half + failing_config[:high] + if test(quart3).failing(): + return self._bisect_failing_config_helper(results, quart3) + quart4 = failing_config[high:] + if test(quart4).failing(): + return self._bisect_failing_config_helper(results, quart4) + # try dropping one value at a time + for i in range(len(failing_config)): + new_list = [x for j, x in enumerate(failing_config) if j != i] + if test(new_list).failing(): + return self._bisect_failing_config_helper(results, new_list) + # we have the minimal set + return dict(failing_config) + + +def visualize_results( + n: int, results: ResultType, filename: str = "results.html" +) -> None: + """ + Creates an HTML document representing the results of running the fuzzer with fuzz_n_tuple, with n = 2. 
+ """ + # TODO support more dimensions + assert n == 2 + assert len(results) > 0 + + input_set: OrderedSet[str] = OrderedSet({}) + for key in results.keys(): # noqa: SIM118 + input_set.add(key[0]) + input_set.add(key[1]) + input_list = sorted(input_set) + + # Start the HTML content + html_content = """ + + + + + + Fuzzer Visualization + + + +

Fuzzer Visualization

+ + + """ + + html_content += "" + for col_name in input_list: + col = "
".join(col_name) + html_content += f"" + html_content += "" + + # Add table rows + for row_name in input_list: + html_content += f"" + for col_name in input_list: + # Determine the status class for the cell + status_enum = results.lookup((row_name, col_name)) + status_class = "" + status_val = "" + if status_enum == Status.SKIPPED: + status_class = "skipped" + status_val = "-" + elif status_enum == Status.PASSED: + status_class = "passed" + status_val = "O" + elif status_enum == Status.FAILED_RUN_EAGER_EXCEPTION: + status_class = "failed" + status_val = "e" + elif status_enum == Status.FAILED_RUN_COMPILE_EXCEPTION: + status_class = "failed" + status_val = "E" + elif status_enum == Status.FAILED_RUN_RETURN: + status_class = "failed" + status_val = "R" + elif status_enum == Status.FAILED_COMPILE: + status_class = "failed" + status_val = "C" + else: + status_class = "skipped" + status_val = "-" + + html_content += f'' + html_content += "" + + html_content += """ + +
\\{col}
{row_name}{status_val}
+ + + """ + + with open(filename, "w") as file: + file.write(html_content) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..88cb9c7ea08bb8d61e46bd06d26b7ab57f7c83bb --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/fx_utils.py @@ -0,0 +1,346 @@ +# mypy: allow-untyped-defs +import contextlib +import operator +from collections import defaultdict +from collections.abc import Callable +from typing import Any, Optional + +import sympy + +import torch +import torch.fx +from torch._dispatch.python import enable_python_dispatcher +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.fx.experimental.symbolic_shapes import ( + compute_unbacked_bindings, + rebind_unbacked, + statically_known_true, + sym_eq, +) +from torch.utils import _pytree as pytree +from torch.utils._ordered_set import OrderedSet +from torch.utils._pytree import tree_map +from torch.utils.flop_counter import flop_registry + +from .virtualized import V + + +# Check the pattern: (nn.module, F.function/torch.Tensor.method) matched. +# Works for length 2 patterns with 1 module and 1 function/method. 
+def matches_module_function_pattern( + pattern: tuple[type[torch.nn.modules.Module], Callable[..., Any]], + node: torch.fx.node.Node, + modules: dict[str, torch.nn.modules.Module], +) -> bool: + if len(node.args) == 0: + return False + if not isinstance(node.args[0], torch.fx.Node) or not isinstance( + node, torch.fx.Node + ): + return False + # the first node is call_module + if node.args[0].op != "call_module": + return False + if not isinstance(node.args[0].target, str): + return False + if node.args[0].target not in modules: + return False + if type(modules[node.args[0].target]) is not pattern[0]: + return False + # the second node is call_function or call_method + if node.op != "call_function" and node.op != "call_method": + return False + if node.target != pattern[1]: + return False + # make sure node.args[0] output is only used by current node. + if len(node.args[0].users) > 1: + return False + return True + + +class FakeTensorUpdater: + """ + The main idea here is that it's difficult to maintain accurate fake + tensors (our primary form of metadata) for each node in our graph as we + transform it. + + The most reliable way to obtain this information is by rerunning + faketensor propagation. However, in general, faketensor propagation is + fairly expensive. So, instead we'd like to only rerun faketensor + propagation on nodes that have changed. + + In order to detect which nodes have changed, we first hash its node, + target, and argument lists (which are immutable in FX). + + Then, whenever we call incremental_update, we check which FX nodes have a + new hash, and recompute the faketensor metadata for that node. Then, we + continue to recursively compute the faketensors for all users until the + fake tensors stop changing. 
+ """ + + def __init__(self, graph: torch.fx.Graph) -> None: + self.processed_hashes = OrderedSet[Any]() + self.graph = graph + + for node in self.graph.nodes: + self.processed_hashes.add(self.hash_node(node)) + + def hash_node(self, node: torch.fx.Node): + # todo(chilli): Not a great hash function + return (node, node.target, id(node.args), id(node.kwargs)) + + def incremental_update(self): + """Update FakeTensors on self.graph. We will try to do the minimum amount of work.""" + existing_storages: defaultdict[Optional[int], int] = defaultdict(int) + for node in self.graph.nodes: + existing_storages[get_node_storage(node)] += 1 + + def is_intlist_same(new, old): + return statically_known_true(sym_eq(new, old)) + + def is_fake_tensor_same(new, old, *, node): + if type(new) is not type(old): + return False + if isinstance(new, (list, tuple)): + if len(new) != len(old): + return False + return all( + is_fake_tensor_same(new_i, old_i, node=node) + for new_i, old_i in zip(new, old) + ) + if new is None: + return old is None + if not isinstance(new, torch.Tensor): + assert isinstance(new, (torch.SymInt, torch.SymBool, torch.SymFloat)), ( + f"Unknown type {type(new)} in {self.graph}" + ) + return ( + new.node.shape_env._maybe_evaluate_static( + sympy.Eq(new.node.expr, old.node.expr) + ) + == sympy.true + ) + if not is_intlist_same(new.shape, old.shape) or new.layout != old.layout: + return False + if new.layout == torch.strided and ( + not is_intlist_same(new.stride(), old.stride()) + or not statically_known_true( + new.storage_offset() == old.storage_offset() + ) + ): + return False + + if new.device != old.device: + return False + + if get_storage(new) == get_storage(old): + return True + + def any_user_may_alias(node): + if not isinstance(node.meta["val"], torch.Tensor): + # analysis too complicated on lists, can support in the future + return True + for user in node.users: + if not ( + isinstance( + user.target, + (torch._ops.OpOverload, 
torch._ops.HigherOrderOperator), + ) + or user.target + is torch._inductor.fx_passes.reinplace._generalized_scatter + ): + return True + if isinstance(user.target, torch._ops.HigherOrderOperator): + # HOPs that survive until inductor are all non-aliasing HOPs. + # We will likely never support HOPs that are aliasing. + continue + # Strategy: do a FakeTensor prop, see if the storage aliases. + # If Inductor ever gets tighter invariants on OpOverloads + # (that is, we ban things like torch.ops.aten.reshape calls in the graph), + # Then this could just be a fast schema lookup. + is_valid, args, kwargs = get_fake_args_kwargs(user) + if not is_valid: + return True + with ( + V.fake_mode, + enable_python_dispatcher(), + contextlib.ExitStack() as stack, + ): + # Ignore unbacked symbols (if they exist): we're making + # this FakeTensor and then throwing it away. + shape_env = V.fake_mode.shape_env + if shape_env is not None: + stack.enter_context( + shape_env.ignore_fresh_unbacked_symbols() + ) + new_fake_tensor = user.target(*args, **kwargs) + if not isinstance(new_fake_tensor, torch.Tensor): + # analysis too complicated on lists, can support in the future + return True + if get_storage(new_fake_tensor) == get_storage(node.meta["val"]): + return True + return False + + # This is the case where it returns a completely fresh storage that's used nowhere else. + # If the FakeTensor's storage is fresh and none of the node's users can alias it, then + # we don't need to update this node. + if ( + existing_storages[get_storage(old)] == 1 + and get_storage(new) not in existing_storages + and not any_user_may_alias(node) + ): + return True + + return False + + def should_process_node(node): + # node.target for nodes returning true from this function + # are called under fake mode and does not work for inductor + # lowerings. We check if the node.target is an aten operator + # or operator.getitem which is used when returning multiple + # tensors from an op. 
+ return node.op == "call_function" and ( + isinstance(node.target, torch._ops.OpOverload) + or node.target is operator.getitem + or node.target + is torch._inductor.fx_passes.reinplace._generalized_scatter + ) + + to_process = OrderedSet[int]() + for node in self.graph.nodes: + # NB: Be very careful about skipping nodes (via continues) here + # and ask for a careful review when changing this code. The + # consequence for incorrect FakeTensor metadata is difficult-to-debug + # silent incorrectness. + if ( + self.hash_node(node) in self.processed_hashes + and id(node) not in to_process + ): + continue + + if not should_process_node(node): + continue + + is_valid, args, kwargs = get_fake_args_kwargs(node) + if not is_valid: + continue + with V.fake_mode, enable_python_dispatcher(): + new_fake_tensor = node.target(*args, **kwargs) + + if "val" in node.meta and is_fake_tensor_same( + new_fake_tensor, node.meta["val"], node=node + ): + continue + + rebind_unbacked(V.fake_mode.shape_env, node, new_fake_tensor) + + node.meta["val"] = new_fake_tensor + if (shape_env := V.fake_mode.shape_env) and ( + symbol_to_path := compute_unbacked_bindings(shape_env, new_fake_tensor) + ): + # Refresh the bindings to the new symbols + + node.meta["unbacked_bindings"] = symbol_to_path + + existing_storages[get_node_storage(node)] += 1 + + to_process.update([id(user) for user in node.users]) + + self.processed_hashes.add(self.hash_node(node)) + + +def get_storage(t: torch.Tensor) -> int: + return t.untyped_storage()._cdata + + +def get_node_storage(node: torch.fx.Node) -> Optional[int]: + if "val" not in node.meta: + return None + if not isinstance(node.meta["val"], torch.Tensor): + return None + if not torch._C._has_storage(node.meta["val"]): + return None + return get_storage(node.meta["val"]) + + +def get_fake(x): + if isinstance(x, torch.fx.Node): + if "val" not in x.meta: + return x + return x.meta["val"] + return x + + +def get_fake_args_kwargs(x: torch.fx.Node) -> tuple[bool, 
tuple[Any], dict[str, Any]]: + """ + First value returns a boolean if any of the input nodes don't have a faketensor. + """ + args, kwargs = tree_map(get_fake, (x.args, x.kwargs)) + if any( + isinstance(a, torch.fx.Node) for a in pytree.arg_tree_leaves(*args, **kwargs) + ): + return False, args, kwargs + return True, args, kwargs + + +def is_node_realized(node: torch.fx.Node) -> bool: + """Returns true if a node is always realized when lowered to inductor IR. + + NOTE: This may return some false negatives. e.g. it doesn't + handle buffers realized heuristically during lowering, or + buffers realized indirectly through view ops. + """ + from torch._inductor.lowering import fallbacks, needs_realized_inputs + + def is_buffer(node: torch.fx.Node) -> bool: + if node.op == "call_function" and node.target is operator.getitem: + # For nodes with multiple outputs, we get the fx graph: + # foo = torch.ops.aten.foo(...) + # getitem = foo[0] + # getitem_1 = foo[1] + # where we need to check if foo is a fallback kernel + return is_buffer(node.args[0]) # type: ignore[arg-type] + return node.op in ("placeholder", "output") or node.target in fallbacks + + if is_buffer(node): + return True + + def realizes_inputs(node: torch.fx.Node) -> bool: + return node.op == "output" or node.target in needs_realized_inputs + + if any(realizes_inputs(user) for user in node.users): + return True + + # Otherwise, assume node isn't realized + return False + + +def count_flops_fx(node: torch.fx.Node) -> Optional[int]: + if not countable_fx(node) or isinstance(node.target, str): + return None + with FakeTensorMode(allow_non_fake_inputs=True): + success, args, kwargs = get_fake_args_kwargs(node) + + if success: + with torch.utils.flop_counter.FlopCounterMode( + display=False + ) as flop_counter_mode: + node.target(*args, **kwargs) + + counted_flops = flop_counter_mode.get_total_flops() + return counted_flops + return None + + +def countable_fx(node: torch.fx.Node) -> bool: + """ + Whether or not we 
can count the flops of an FX node. + """ + assert isinstance(node, torch.fx.Node) + if not hasattr(node, "target"): + return False + target = node.target + if not hasattr(target, "overloadpacket"): + return target in flop_registry + packet = target.overloadpacket + return packet in flop_registry diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/graph.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/graph.py new file mode 100644 index 0000000000000000000000000000000000000000..68b2f05f2c414b2bba191ef72d80d3f04974e445 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/graph.py @@ -0,0 +1,2569 @@ +from __future__ import annotations + +import contextlib +import functools +import itertools +import logging +import operator +import os +import re +import sys +import time +from collections import defaultdict +from contextlib import contextmanager +from typing import Any, NoReturn, Optional, TYPE_CHECKING, Union + +import sympy +from sympy import Expr + +import torch +import torch._logging +import torch.fx +from torch import device, Tensor +from torch._decomp import get_decompositions +from torch._dynamo.utils import defake, dynamo_timed +from torch._library.fake_class_registry import FakeScriptObject +from torch._library.utils import get_layout_constraint_tag +from torch._logging import LazyString, trace_structured +from torch._prims_common import ( + compute_required_storage_length, + make_channels_last_strides_for, +) +from torch._subclasses.fake_tensor import FakeTensor +from torch._utils_internal import full_aoti_runtime_assert +from torch.fx.experimental._backward_state import BackwardState +from torch.fx.experimental.sym_node import magic_methods, method_to_operator +from torch.fx.experimental.symbolic_shapes import ( + _get_placeholder_expr, + free_unbacked_symbols, + has_free_symbols, + resolve_unbacked_bindings, + RuntimeAssert, + ShapeEnv, + 
SympyBoolean, + SymTypes, +) +from torch.fx.node import Node +from torch.fx.passes.reinplace import _is_view_op +from torch.utils._mode_utils import no_dispatch +from torch.utils._ordered_set import OrderedSet +from torch.utils._sympy.numbers import int_oo + +from . import config, ir, metrics +from .codegen.common import ( + BackendFeature, + DeviceOpOverrides, + FileBackedGraphModule, + get_backend_features, + get_device_op_overrides, + get_wrapper_codegen_for_device, + init_backend_registration, + WorkspaceArg, +) +from .exc import ( + CppWrapperCodegenError, + LoweringException, + MissingOperatorWithDecomp, + MissingOperatorWithoutDecomp, +) +from .fx_utils import count_flops_fx +from .ir import ( + assign_origin_node, + Constant, + DonatedBuffer, + FixedLayout, + get_device_type, + GraphPartitionSignature, + InputBuffer, + Pointwise, + Reduction, + ShapeAsConstantBuffer, + StorageBox, + TensorBox, + TorchBindObject, +) +from .lowering import ( + constrain_to_fake_tensors, + constrain_to_fx_strides, + FALLBACK_ALLOW_LIST, + fallback_handler, + fallback_node_due_to_unsupported_type, + lowerings, + make_fallback, + maybe_layout_constraints, + needs_realized_inputs, + require_contiguous, + tag_to_layout_constraint, + unsupported_output_tensor, +) +from .runtime import autotune_cache +from .runtime.autotune_cache import AutotuneCacheBundler +from .sizevars import SizeVarAllocator +from .utils import ( + convert_shape_to_inductor, + gather_origins, + get_cloned_parameter_buffer_name, + get_donated_idxs, + get_sympy_Expr_dtype, + GraphPartitionMap, + is_same_tensor, + maybe_get_suppress_shape_guards_ctx, + normalize_name, + should_assume_input_aligned, + should_fallback_by_default, + SUPPORTED_MKLDNN_DEVICES, + ValueWithLineMap, +) +from .virtualized import NullHandler, V + + +if TYPE_CHECKING: + from collections.abc import Callable, Iterable, Iterator, Sequence + from types import ModuleType + + from torch._higher_order_ops.effects import _EffectType + from torch.fx 
import GraphModule + from torch.fx.graph import Graph + + from .codegen.wrapper import PythonWrapperCodegen + from .dependencies import Dep + from .scheduler import BaseSchedulerNode + + CompiledModule = Union[ModuleType, FileBackedGraphModule] + +from torch._inductor.codecache import output_code_log + + +log = logging.getLogger(__name__) +perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints") + +aten = torch.ops.aten + +_post_grad_graph_counter = itertools.count() + +if config.is_fbcode(): + from torch._inductor.fb.utils import log_module_code +else: + + def log_module_code(*args: Any, **kwargs: Any) -> None: + pass + + +def may_get_constant_buffer_dtype(constant_buffer: sympy.Expr) -> Optional[torch.dtype]: + assert isinstance( + constant_buffer, (sympy.Symbol, sympy.Expr, sympy.core.numbers.Integer) + ), ( + "get_constant_buffer_dtype only supports input of sympy.Symbol, sympy.Expr or sympy.core.numbers.Integer" + ) + if isinstance(constant_buffer, sympy.core.numbers.Integer): + return torch.int64 + + if isinstance(constant_buffer, sympy.Expr): + return get_sympy_Expr_dtype(constant_buffer) + + if constant_buffer.is_integer: + return torch.int64 + elif constant_buffer.is_float: + return torch.float32 + else: + return None + + +def is_magic_method(op: Any) -> bool: + magic_ops = OrderedSet(method_to_operator(m) for m in magic_methods) + return op in magic_ops + + +def getattr_recursive( + obj: GraphModule, target: str +) -> Union[Tensor, torch._C.ScriptObject, GraphModule]: + target_atoms = target.split(".") + attr_itr = obj + for i, atom in enumerate(target_atoms): + if not hasattr(attr_itr, atom): + raise RuntimeError( + f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}" + ) + attr_itr = getattr(attr_itr, atom) + return attr_itr + + +def get_user_visible_output_strides(g: Graph) -> dict[Node, tuple[int, ...]]: + ret: dict[Node, tuple[int, ...]] = {} + output_node = g.find_nodes(op="output")[0] + + if "user_visible_output_idxs" 
not in output_node.meta: + return ret + + if not isinstance(output_node.args[0], torch.fx.Node): + output_node_args = output_node.args[0] + else: + output_node_args = output_node.args + + for idx, node in enumerate(output_node_args): + if idx in output_node.meta["user_visible_output_idxs"]: + ret[node] = output_node.meta["original_output_strides"][idx] + return ret + + +def extend_user_visible_output_strides( + user_visible_outputs: dict[Node, tuple[int, ...]], +) -> dict[Node, object]: + """ + Extend user_visible_output_strides to include view ops that lead to user-visible outputs. + """ + result: dict[Node, object] = {**user_visible_outputs} + queue = [*result.keys()] + visited = OrderedSet([*queue]) + while queue: + current = queue.pop() + if ( + _is_view_op(current.target) + and current.args + and isinstance(current.args[0], torch.fx.Node) + ): + base = current.args[0] + if base not in visited: + result.setdefault(base, None) + visited.add(base) + queue.append(base) + return result + + +def mark_nodes_dislike_padding( + g: Graph, user_visible_output_strides: dict[Node, tuple[int, ...]] +) -> None: + """ + Nodes like convolution/convolution_backward want its input to be dense. + If we pad their inputs, we result in extra calls to copy kernels! On the other hand, padding usually helps reduction. + + The pass finds nodes that dislike padding. These are nodes that can be reached + from a convolution/convolution_backward in the backward direction without + going thru a reduction. + """ + if not config.comprehensive_padding: + return + + extended_user_visible_nodes = extend_user_visible_output_strides( + user_visible_output_strides + ) + ops_dislike_padding = OrderedSet( + [ + aten.convolution, + aten.convolution_backward, + aten._scaled_mm, + ] + ) + # what's a better way to collect the reduction ops? 
+ ops_like_padding = OrderedSet( + [ + aten.var_mean, + aten.sum, + aten.mean, + aten.prod, + aten.any, + aten.amin, + aten.amax, + aten.min, + aten.max, + aten.argmin, + aten.argmax, + aten.scatter_reduce, + ] + ) + + def _get_overload_packet( + node: torch.fx.Node, + ) -> Optional[torch._ops.OpOverloadPacket]: + return ( + node.target._overloadpacket + if node.op == "call_function" + # hasattr on OpOverloadPacket is slow, do isinstance first + and isinstance(node.target, torch._ops.OpOverload) + and hasattr(node.target, "_overloadpacket") + else None + ) + + for cur in reversed(g.nodes): + if isinstance( + cur.target, + torch._higher_order_ops.triton_kernel_wrap.TritonKernelWrapperMutation, + ): + cur.meta["dislike_padding"] = True + continue + + if ( + isinstance(cur.target, torch._ops.OpOverload) + and get_layout_constraint_tag(cur.target) + == torch._C.Tag.needs_exact_strides + ): + cur.meta["dislike_padding"] = True + continue + + op = _get_overload_packet(cur) + if not op: + continue + if op in ops_dislike_padding: + cur.meta["dislike_padding"] = True + + if cur.meta.get("dislike_padding", False): + # propagate + for prior in cur.all_input_nodes: + prior_op = _get_overload_packet(prior) + if not prior_op: + continue + if prior_op not in ops_like_padding: + prior.meta["dislike_padding"] = True + # We only want to mark output nodes. So, move it after the above prior nodes process. 
+ if not config.pad_outputs and cur in extended_user_visible_nodes: + cur.meta["dislike_padding"] = True + + +class GraphLowering(torch.fx.Interpreter): + graph_outputs: list[ir.IRNode] + + def __init__( + self, + gm: torch.fx.GraphModule, + example_inputs: Optional[Sequence[object]] = None, + shape_env: Optional[ShapeEnv] = None, + graph_id: Optional[int] = None, + cpp_wrapper: bool = False, + aot_mode: bool = False, + layout_opt: Optional[bool] = None, + extern_node_serializer: Optional[ + Callable[[list[ir.ExternKernelNode]], Any] + ] = None, + is_inference: bool = False, + is_backward: bool = False, + is_const_graph: bool = False, + const_output_index: Optional[dict[str, int]] = None, + const_wrapper_code: Optional[str] = None, + const_kernel_code: Optional[str] = None, + const_module: Optional[GraphLowering] = None, + name: Optional[str] = None, + inputs_to_check: Optional[Sequence[int]] = None, + fx_wrapper: bool = False, + ) -> None: + super().__init__(gm) + self.example_inputs = example_inputs + self.layout_opt = ( + layout_opt + if layout_opt is not None + else self.decide_layout_opt(gm, is_inference=is_inference) + ) + self.num_channels_last_conv = 0 + self.is_inference = is_inference + self.is_backward = is_backward + self.is_const_graph = is_const_graph + self.const_wrapper_code = const_wrapper_code + self.const_kernel_code = const_kernel_code + self.const_module = const_module + self.inputs_to_check = inputs_to_check + + self.extra_traceback = False # we do our own error wrapping + if shape_env is None: + shape_env = ShapeEnv() + self.reuse_shape_env = False + else: + self.reuse_shape_env = True + self._shape_env = shape_env + # We're going to mutate ras_by_symbol as we finish generating them + self.ras_by_symbol: dict[Optional[sympy.Symbol], list[RuntimeAssert]] = ( + shape_env.deferred_runtime_asserts.copy() + ) + self.bound_unbacked_symbols = OrderedSet[sympy.Symbol]() + + self.sizevars = SizeVarAllocator(shape_env) + self.graph_input_names: 
list[str] = [] + self.graph_inputs: dict[str, Union[TensorBox, TorchBindObject, sympy.Expr]] = {} + self.graph_inputs_original: dict[str, InputBuffer] = {} + self.partition_maps: Optional[list[GraphPartitionMap]] = None + self.zero_dim_cpu_tensor_list: OrderedSet[str] = OrderedSet() + self.device_types: OrderedSet[str] = ( + const_module.device_types if const_module else OrderedSet() + ) + self.device_idxs: OrderedSet[int] = ( + const_module.device_idxs if const_module else OrderedSet() + ) + self.device_type = "cpu" + self.additional_buffer_deps: dict[str, OrderedSet[str]] = defaultdict( + OrderedSet + ) + self.additional_star_deps: dict[str, OrderedSet[str]] = defaultdict(OrderedSet) + + # Inplace padding may require Inductor to allocate slightly larger + # tensor for padding. + self.buffer_to_padded_size: dict[str, list[int]] = {} + + self.buffers: list[ir.Buffer] = [] + self.operations: list[ir.Operation] = [] + self.const_output_index: dict[str, int] = ( + const_output_index if const_output_index else {} + ) + self.folded_constants: OrderedSet[str] = ( + OrderedSet(const_output_index.keys()) + if const_output_index + else OrderedSet() + ) + self.constants: dict[str, torch.Tensor] = ( + const_module.constants if const_module else {} + ) + self.named_buffers: dict[str, torch.Tensor] = ( + const_module.named_buffers if const_module else {} + ) + self.mutated_named_buffers: OrderedSet[torch.Tensor] = gm.meta.get( + "mutated_named_buffers", OrderedSet() + ) + self.named_parameters: dict[str, torch.Tensor] = ( + const_module.named_parameters if const_module else {} + ) + self.torchbind_constants: dict[ + str, Union[torch._C.ScriptObject, FakeScriptObject] + ] = {} + self.opaque_value_type_classes: dict[str, type] = {} + self.seen_subgraphs: dict[str, ir.Subgraph] = {} + self.constant_reprs: dict[str, str] = {} + self.removed_operations: OrderedSet[str] = OrderedSet() + self.removed_buffers: OrderedSet[str] = OrderedSet() + self.removed_inplace_buffers: 
OrderedSet[str] = OrderedSet() + self.mutated_buffers: OrderedSet[str] = OrderedSet() + self.never_reuse_buffers: OrderedSet[str] = OrderedSet() + self.inplaced_to_remove: OrderedSet[str] = OrderedSet() + self.device_ops: DeviceOpOverrides = None # type: ignore[assignment] + self.wrapper_code: PythonWrapperCodegen = None # type: ignore[assignment] + + from torch._inductor.extern_node_serializer import extern_node_json_serializer + + self.extern_node_serializer: Callable[[list[ir.ExternKernelNode]], Any] = ( + extern_node_serializer + if config.is_fbcode() and extern_node_serializer + else extern_node_json_serializer + ) + + self.current_node: torch.fx.Node = None # type: ignore[assignment] + self.lists: dict[str, list[str]] = {} + self.mutated_inputs: OrderedSet[str] = OrderedSet() + self.mutated_input_idxs: list[int] = [] + self.name_to_buffer: dict[str, ir.Buffer] = {} + self.name_to_users: defaultdict[str, list[ir.IRNode]] = defaultdict(list) + self.name_to_op: dict[str, ir.Operation] = {} + self.creation_time = time.time() + self.name = name # type: ignore[assignment] + self.cpp_wrapper = cpp_wrapper + self.fx_wrapper = fx_wrapper + + # record multi_kernel choice for cpp_wrapper so the second pass knows + # which sub-kernel is picked. Copy cpp_wrapper to another variable + # since cpp_wrapper flag is OrderedSet to false for the first pass of codegen. + self.record_multi_kernel_choice = cpp_wrapper + self.multi_kernel_to_choice: dict[str, str] = {} + + self.aot_mode = aot_mode + self.graph_id = graph_id + self.post_grad_graph_id = next(_post_grad_graph_counter) + self.scheduler: torch._inductor.scheduler.Scheduler = None # type: ignore[assignment] + + # record intermediate results for input of UsedDefinedTritonKernels + # This will be used if autotuning is done in one pass. 
+ self.autotuning_inputs: Optional[list[torch.Tensor]] = None + self.autotuning_mapping: Optional[dict[str, dict[str, int]]] = None + self.autotuning_grids: Optional[dict[str, Any]] = None + + # current_device is set only during codegen of a device-specific kernel + # a graph can have many devices + self.current_device: Optional[torch.device] = None + + self.nodes_prefer_channels_last = ( + self.find_nodes_prefer_channels_last() if self.layout_opt else OrderedSet() + ) + self._warned_fallback = OrderedSet(["aten.convolution_backward"]) + self.user_visible_output_strides = get_user_visible_output_strides(gm.graph) + mark_nodes_dislike_padding(gm.graph, self.user_visible_output_strides) + self.cache_key: str = "" # This is the cache key for the compiled artifact + self.cache_path: str = "" # This is the path in the filesystem where the compiled artifact is stored + self.cache_linemap: list[ + tuple[int, str] + ] = [] # This is the linemap used by the profiler to mark custom compiled kernels getting run + # Used if lowering encounters cases where cudagraphs are not supported + self.disable_cudagraphs_reason: Optional[str] = None + + # only keeping one node per device for stack trace purposes + self.device_node_mapping: dict[torch.device, torch.fx.Node] = {} + self.orig_gm: torch.fx.GraphModule = gm.__copy__() + for k, v in self.orig_gm.named_buffers(): + self.named_buffers[k] = v + for k, v in self.orig_gm.named_parameters(): + self.named_parameters[k] = v + self.dynamo_flat_name_to_original_fqn = self.module.meta.get( # type: ignore[operator, union-attr] + "dynamo_flat_name_to_original_fqn", {} + ) + self.allocated_constant_name: dict[str, str] = ( + const_module.allocated_constant_name if const_module is not None else {} + ) + init_backend_registration() + self.get_backend_features = functools.lru_cache(None)(get_backend_features) + + self.effectful_ops: dict[_EffectType, ir.Buffer] = {} + # Track the buffers that we know is unaligned + # This can either be a graph 
input or the output of fallback + # kernels. + self.unaligned_buffers: OrderedSet[str] = OrderedSet() + self.no_fuse_buffer_names: OrderedSet[str] = OrderedSet() + + self.low_precision_codegen_ops: OrderedSet[str] = OrderedSet() + # more aggressive prologue fusion + self.invoke_quant_ops: OrderedSet[str] = OrderedSet() + + # Below field is related to printing debug intermediate tensor values info for debugging + self.all_codegen_kernel_names: OrderedSet[str] = OrderedSet() + + # state used by for KernelArgs.workspace + self.workspace_id = itertools.count() + + # track the current placeholder index that we are processing + self.placeholder_idx = -1 + + self.bw_donated_idxs = get_donated_idxs() + + # Cache for dep size hints to avoid expensive recomputation + self.dep_size_hint_cache: dict[tuple[Dep, bool], int] = {} + + def freeze_runtime_asserts(self) -> None: + self._shape_env.freeze_runtime_asserts() + + def symbolic_sizes_strides( + self, ex: torch.Tensor + ) -> tuple[Sequence[Union[int, Expr]], Sequence[Union[int, Expr]]]: + """ + Support dynamic shapes and dynamic strides by assigning variables + to each dimension. We duck-shape tensors, so if two tensors + have the same size they get assigned the same symbolic variable. + """ + if self.reuse_shape_env: + return convert_shape_to_inductor(ex.size()), convert_shape_to_inductor( + ex.stride() + ) + else: + from torch._dynamo.source import ConstantSource + + # TODO: this should not be needed once #93059 lands + # https://github.com/pytorch/pytorch/pull/94031#discussion_r1096044816 + # TODO: make a dedicated UnknownSource for this? 
+ # NB: This is using the legacy default behavior from + # create_symbolic_sizes_strides_storage_offset but we hope we can + # just delete this entirely + source = ConstantSource( + f"__inductor_unknown_tensor_{len(self._shape_env.var_to_val)}" + ) + ( + size, + stride, + _, + ) = self._shape_env.create_symbolic_sizes_strides_storage_offset( + ex, + source, + ) + + r_size = [i.node.expr if isinstance(i, torch.SymInt) else i for i in size] + r_stride = [i.node.expr if isinstance(i, torch.SymInt) else i for i in stride] + return r_size, r_stride + + def static_sizes_strides( + self, ex: torch.Tensor + ) -> tuple[list[sympy.Expr], list[sympy.Expr]]: + """ + Primarily used to weights + """ + size = [sympy.Integer(i) for i in ex.size()] + stride = [sympy.Integer(i) for i in ex.stride()] + return size, stride + + def get_allocation_size( + self, + node: Union[ + ir.TensorBox, ir.StorageBox, ir.Buffer, WorkspaceArg, ir.TorchBindObject + ], + ) -> Sequence[Expr]: + if isinstance(node, ir.TensorBox): + node = node.data # type: ignore[assignment] + if isinstance(node, ir.StorageBox): + node = node.data # type: ignore[assignment] + if ( + isinstance(node, ir.ComputedBuffer) + and node.name in self.buffer_to_padded_size + ): + # pyrefly: ignore [index-error] + return self.buffer_to_padded_size[node.name] + else: + return node.get_size() + + def get_allocation_storage_size( + self, node: Union[ir.Buffer, WorkspaceArg, ir.TorchBindObject] + ) -> Expr: + layout = node.get_layout() + size = self.get_allocation_size(node) # consider inplace padding + stride = layout.stride + offset = layout.offset + return compute_required_storage_length(size, stride, offset) # type: ignore[arg-type] + + def has_feature( + self, + device: Union[torch._inductor.ir.IRNode, device, None], + feature: BackendFeature, + ) -> bool: + assert isinstance(feature, BackendFeature), feature + return feature in self.get_backend_features(get_device_type(device)) + + def get_dep_size_hint(self, dep: Dep, 
count_bytes: bool = True) -> int: + """ + Get the size hint for a dependency with caching to avoid expensive recomputation. + """ + if (dep, count_bytes) not in self.dep_size_hint_cache: + res = 0 + try: + if not dep.has_unbacked_symbols(): + if count_bytes: + res = dep.numbytes_hint() + else: + res = dep.numel_hint() + except KeyError: + # In at least one test (test/inductor/test_torchbind.py) we + # create a StarDep that doesn't exist in the graph and calling + # `has_unbacked_symbols()` throws an error. + pass + self.dep_size_hint_cache[(dep, count_bytes)] = res + return self.dep_size_hint_cache[(dep, count_bytes)] + + def get_current_device_or_throw(self) -> torch.device: + if device := self.current_device: + return device + else: + raise RuntimeError("No current device") + + @contextlib.contextmanager + def set_current_device(self, device: torch.device) -> Iterator[None]: + prior = self.current_device + self.current_device = device + try: + yield + finally: + self.current_device = prior + + def get_training_phase(self) -> str: + if self.is_inference: + return "inference" + if self.is_backward: + return "backward" + return "forward" + + @staticmethod + def decide_layout_opt(gm: GraphModule, *, is_inference: bool) -> bool: + """ + Decide if we should enable layout optimization for this graph based on + heuristics. + """ + if not config.layout_optimization: + return False + + if config.force_layout_optimization: + return True + + conv_nodes = [ + n for n in gm.graph.nodes if n.target is torch.ops.aten.convolution.default + ] + nconv = len(conv_nodes) + + if nconv == 0: + return False + + # For cpu backend and mkldnn enabled, we always use channels_last for better performance. 
+ if ( + torch.backends.mkldnn.enabled + and torch.backends.mkldnn.is_available() + and all( + n.args[idx].meta["val"].device.type in SUPPORTED_MKLDNN_DEVICES + for n in conv_nodes + for idx in [0, 1] + ) + ): + return True + + # Following models are skipped due to this: + # jx_nest_base + # volo_d1_224 + if len(list(gm.graph.nodes)) >= 300 * nconv: + log.debug("Skipped layout opt because only a few conv") + return False + + if any( + has_free_symbols(n.args[idx].meta["val"]) + for n in conv_nodes + for idx in [0, 1] + ): + log.debug( + "See perf regression with dynamic shape. Follow up in https://github.com/pytorch/pytorch/issues/102670" + ) + return False + + def is_grouped(n: Any) -> bool: + meta_val = n.args[1].meta["val"] # type: ignore[union-attr, operator] + assert isinstance(meta_val, torch.Tensor) + return n.args[-1] > 1 and meta_val.size(1) > 1 # type: ignore[union-attr, operator] + + def is_in_out_channel(n: torch.fx.Node) -> bool: + return ( + n.args[1].meta["val"].size(0) * 2 <= n.args[1].meta["val"].size(1) # type: ignore[union-attr, operator] + and n.args[1].meta["val"].size(2) > 1 # type: ignore[union-attr, operator] + ) + + def is_small_channel(n: torch.fx.Node) -> bool: + return ( + n.args[1].meta["val"].size(0) <= 64 # type: ignore[union-attr, operator] + and n.args[1].meta["val"].size(1) <= 64 # type: ignore[union-attr, operator] + ) + + # only grouped convolutions benchmarked as slower in conv samples for inference only + if is_inference: + flop_counts: dict[str, float] = defaultdict(float) + for node in conv_nodes: + counted_flops = count_flops_fx(node) + if counted_flops is None: + continue + + if is_grouped(node): + node_type = "grouped" + elif is_small_channel(node): + node_type = "small" + elif is_in_out_channel(node): + node_type = "in_out" + else: + node_type = "default" + + flop_counts[node_type] += counted_flops + else: + log.debug("Conv inputs meta not found") + + # average benchmarked channels last speedup / slowdown, < 1 is speedup. 
+ # taken from the set of convolution inputs in benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/ + # To regenerate these numbers follow https://gist.github.com/eellison/55d7a6ed6f39829d68ac56f95f4df5bb + GROUPED_MULTIPLIER = 1.358 + DEFAULT_MULTIPLIER = 0.823 + IN_OUT_MULTIPLIER = 0.725 + SMALL_MULTIPLIER = 0.783 + + total_flops = sum(flop_counts.values()) + # TODO - get different values per hardware + weighted_flops = ( + flop_counts["grouped"] * GROUPED_MULTIPLIER + + flop_counts["small"] * SMALL_MULTIPLIER + + flop_counts["in_out"] * IN_OUT_MULTIPLIER + + flop_counts["default"] * DEFAULT_MULTIPLIER + ) + do_layout_opt = weighted_flops <= total_flops + if not do_layout_opt: + log.debug( + "Skipped layout opt in inference because weighted flops indicate slowdown, default: %d, channels last: %d", + total_flops, + weighted_flops, + ) + return do_layout_opt + + # Channels last layout can dramatically hurt grouped conv perf. E.g. + # Conv with arguments like + # {"input_shape": [32, 224, 112, 112], "weight_shape": [224, 112, 3, 3], + # "stride": [2, 2], "padding": [1, 1], "groups": 2} + # slows down 31x using channels last.. + + # But a lot of timm models use depthwise separable convolution which will + # result in grouped convolution with in-channel size == 1. + # For those grouped convolution, channels last still helps a lot. + # E.g. + # Conv with arguments + # {"input_shape": [128, 58, 56, 56], "weight_shape": [58, 1, 3, 3], + # "stride": [2, 2], "padding": [1, 1], "groups": 58} + # get 1.86x speedup with channels last layout. + # + # The following heuristics skip using channels-last if the model contains + # grouped convolution with in-channels > 1. + if any(map(is_grouped, conv_nodes)): + log.debug( + "Skip layout opt because found grouped convolution with >1 in_channels!" + ) + return False + + # For some models that contain convolution with larger in-channel than out-channel, applying + # channels last hurts performance. 
def qualify_name(self, name: str) -> str:
    """
    Scope ``name`` to this graph.

    Returns ``"<graph name>_<name>"`` when the graph has a name, and
    ``name`` unchanged when ``self.name`` is ``None``.
    """
    graph_name = self.name
    return name if graph_name is None else f"{graph_name}_{name}"
if one of its user prefers channels last + + We have rule 1 because cudnn runs a faster convolution kernel for channels last inputs; + Rule 2 is also important. It makes sure that indirect inputs to convolution also prefers + channels last. + + Consider the scenario: conv -> batch-norm -> relu -> conv + Without rule 2, batch-norm output may use a contiguous layout. That will cause 2 extra copies: + 1. the output of batch-norm should be channels last initially since its input is a conv's output. + Forcing the batch-norm's output to be contiguous results in the first copy + 2. The second conv's input is initially contiguous. This layout is propagated from the batch-norm's output. + We need convert it to channels last layout which results in the second copy. + With rule 2, we makes sure all the tensors in the chain uses channels last layout. So both copies + can be saved. + """ + last_conv = None + nodes_cannot_propagate = [torch.ops.aten.bmm.default] + output_set = OrderedSet[Node]() + for n in reversed(self.module.graph.nodes): # type: ignore[arg-type, union-attr] + if n.target is torch.ops.aten.convolution.default: + output_set.add(n) + if last_conv is None: + last_conv = n + continue + if n.target in nodes_cannot_propagate: + continue + for user in n.users: + if user in output_set: + output_set.add(n) + break + + # need a second pass to add downstream nodes of those channel last nodes to the sets. + # This pass is especially needed to avoid mix-layout kernel inputs in backward pass. + # + # Let's say a conv-batchnorm 's output is passed to relu whose output is in turn returned + # from the fwd graph. Without this second pass, we will force relu's output to be contiguous. + # Then in the kernel in backward pass, the contiguous output of relu may be mix with other channels last + # tensors and passed to a kernel. + # + # This pass improve yolov3 training speedup from 1.116x (worse than disabling layout optimization speedup 1.196x) to 1.457x. 
def try_get_buffer(
    self, buffer_name: str
) -> Optional[Union[ir.TensorBox, ir.Buffer, ir.TorchBindObject]]:
    """
    Look up ``buffer_name`` without raising.

    Registered buffers (``name_to_buffer``) are consulted first, then graph
    inputs. A name found in ``self.constants`` does not return a stored
    object: a new ``ir.ConstantBuffer`` with a fixed layout is built from the
    tensor held in ``V.graph.constants``. Unknown names yield ``None``.
    """
    for registry in (self.name_to_buffer, self.graph_inputs):
        if buffer_name in registry:
            return registry[buffer_name]

    if buffer_name not in self.constants:
        return None

    # Membership is tested on self.constants; the tensor itself is read
    # from V.graph.constants.
    tensor = V.graph.constants[buffer_name]
    fixed_layout = ir.FixedLayout(
        tensor.device, tensor.dtype, *V.graph.static_sizes_strides(tensor)
    )
    return ir.ConstantBuffer(name=buffer_name, layout=fixed_layout)
def get_dtype(self, buffer_name: str) -> torch.dtype:
    """
    Resolve the dtype associated with ``buffer_name``.

    Lookup order:
      1. graph constants;
      2. if the scheduler exposes ``mutation_real_name`` and maps this name
         to another buffer, the dtype of that buffer (when it is a registered
         buffer or a graph input);
      3. registered buffers, then graph inputs;
      4. names of the form ``as_strided(<buf>, ...)`` or
         ``reinterpret_tensor(<buf>, ...)``, which resolve to the dtype of
         the wrapped ``<buf>``.

    Raises:
        KeyError: if none of the rules above matches.
    """
    if buffer_name in self.constants:
        return self.constants[buffer_name].dtype
    # For a mutation op we should return the dtype of the buffer being mutated
    if (
        hasattr(self.scheduler, "mutation_real_name")
        and buffer_name in self.scheduler.mutation_real_name
    ):
        mutated_buf = self.scheduler.mutation_real_name[buffer_name]
        if mutated_buf in self.name_to_buffer:
            return self.name_to_buffer[mutated_buf].get_dtype()
        if mutated_buf in self.graph_inputs:
            return self.graph_inputs[mutated_buf].get_dtype()
    if buffer_name in self.name_to_buffer:
        return self.name_to_buffer[buffer_name].get_dtype()
    if buffer_name in self.graph_inputs:
        return self.graph_inputs[buffer_name].get_dtype()
    m = re.match(r"(as_strided|reinterpret_tensor)\(([a-zA-Z0-9_]+),", buffer_name)
    if m:
        # Group 1 is the wrapper's function name; the wrapped buffer name is
        # group 2. Recursing on group 1 could never resolve and always ended
        # in KeyError.
        return self.get_dtype(m.group(2))
    raise KeyError(f"could not find {buffer_name}")
def mark_buffer_mutated(self, name: str) -> None:
    """
    Record ``name`` as mutated and realize everything that reads it.

    Reads of the pre-mutation value have to be realized before the
    mutation takes place, so every user registered under ``name`` in
    ``self.name_to_users`` is realized here.
    """
    assert isinstance(name, str)
    self.mutated_buffers.add(name)

    # Test membership before indexing so that a name without recorded
    # users is never looked up by subscript.
    if name in self.name_to_users:
        for reader in self.name_to_users[name]:
            reader.realize()
def allocate_non_dup_const_name(
    self, name: Optional[str], data: Union[Tensor]
) -> str:
    """
    Register ``data`` as a graph constant and return the name it lives under.

    Unless ``config.aot_inductor.use_runtime_constant_folding`` is set, an
    already-registered constant for which ``is_same_tensor`` holds is reused
    and its existing name is returned. Otherwise a new name is derived:
    a missing ``name`` becomes ``constant<N>`` (N = current constant count),
    a name starting with a digit is prefixed with ``constant_``, and the
    result is qualified with the graph name and passed through
    ``normalize_name``. Collisions are resolved by appending ``_0``,
    ``_1``, ... The requested (pre-qualification) name is remembered in
    ``self.allocated_constant_name``.
    """
    dedup_enabled = not config.aot_inductor.use_runtime_constant_folding
    if dedup_enabled:
        for existing_name, existing in self.constants.items():
            if is_same_tensor(data, existing):
                return existing_name

    requested = f"constant{len(self.constants)}" if name is None else name
    candidate = f"constant_{requested}" if requested[0].isdigit() else requested
    # A variable name may be generated for each constant during codegen,
    # so only the normalized form of the qualified name is kept.
    base = normalize_name(self.qualify_name(candidate))

    unique = base
    suffix = 0
    while unique in self.constants:
        unique = f"{base}_{suffix}"
        suffix += 1

    self.constants[unique] = data
    self.constant_reprs[unique] = (
        f"{data.device!r} {data.dtype!r} "
        f"{tuple(data.size())!r} {tuple(data.stride())!r} "
        f"{hash(data):x}"
    )
    self.allocated_constant_name[unique] = requested  # type: ignore[assignment]
    return unique
# pyrefly: ignore [bad-override]
def placeholder(
    self,
    target: str,  # type: ignore[override]
    args: tuple[object],  # type: ignore[override]
    kwargs: dict[str, object],
) -> Union[Expr, TensorBox, None]:
    """
    Lower one ``placeholder`` node into a graph input.

    Increments ``self.placeholder_idx``, obtains the example value from the
    base-class ``placeholder`` and qualifies ``target`` with the graph name.
    Every branch appends the qualified name to ``self.graph_input_names``;
    what is stored in ``self.graph_inputs`` and returned depends on the type
    of the example value:

    - ``SymTypes``: a sympy expression taken from the symbolic node.
    - ``int`` / ``bool`` / ``float``: ``sympy.sympify(example)``.
    - ``FakeScriptObject``: a ``TorchBindObject``.
    - ``None`` or ``BackwardState``: nothing is stored; returns ``None``.
    - ``torch.Generator``: an ``ir.GeneratorState``.
    - ``torch.Tensor``: a ``TensorBox`` around a ``DonatedBuffer`` or an
      ``InputBuffer`` with a ``FixedLayout``; the unwrapped buffer is also
      kept in ``self.graph_inputs_original``.

    NOTE(review): the return annotation lists only ``Expr``, ``TensorBox``
    and ``None``, while the ``FakeScriptObject`` and ``torch.Generator``
    branches return ``TorchBindObject`` / ``ir.GeneratorState`` instances.
    """
    # Counter is bumped before anything else; it is compared against
    # self.bw_donated_idxs further down.
    self.placeholder_idx += 1
    example = super().placeholder(target, args, kwargs)  # type: ignore[arg-type]
    target = self.qualify_name(target)
    if isinstance(example, SymTypes):
        # TODO fix partitioning issue and re-enable for backward
        # https://github.com/pytorch/pytorch/issues/155468.
        if not V.graph.is_backward:
            expr = _get_placeholder_expr(example.node)
        else:
            expr = example.node.expr
        self.graph_inputs[target] = expr
        self.graph_input_names.append(target)
        return expr
    elif isinstance(example, (int, bool, float)):
        # Plain Python scalars become sympy constants.
        expr = sympy.sympify(example)
        self.graph_inputs[target] = expr
        self.graph_input_names.append(target)
        return expr
    elif isinstance(example, FakeScriptObject):
        obj = TorchBindObject(name=target, value=example)
        self.graph_inputs[target] = obj
        self.graph_input_names.append(target)
        return obj
    elif example is None:
        # Only the name is recorded; no entry is added to graph_inputs.
        self.graph_input_names.append(target)
        return None
    if isinstance(example, BackwardState):
        # Ignored arg, must be unused
        # Alternately we could filter this out in AotAutograd
        self.graph_input_names.append(target)
        return None
    # See note: Note: [Generator arguments in AOTDispatcher]
    elif isinstance(example, torch.Generator):
        # A generator input must have exactly one user, and that user must be
        # one of the two targets listed below.
        assert len(V.graph.current_node.users) == 1 and next(
            iter(V.graph.current_node.users)
        ).target in (
            torch._prims.rng_prims.graphsafe_run_with_rng_state,
            torch.ops.higher_order.invoke_subgraph,
        )
        gen = ir.GeneratorState(name=target, device=example.device)
        self.graph_inputs[target] = gen  # type: ignore[assignment]
        self.graph_input_names.append(target)
        return gen

    # Everything that reaches this point has to be a tensor.
    assert isinstance(example, torch.Tensor), example
    # todo(chilli): We can remove the last check once we turn buffers into
    # static shape tensors. That's a hack to workaround Inductor believing
    # the buffer should be static but us passing in a fake tensor with
    # symbolic shapes.
    if not example._has_symbolic_sizes_strides:
        # the first N inputs are weights
        sizes, strides = self.static_sizes_strides(example)
    else:
        sizes, strides = self.symbolic_sizes_strides(example)  # type: ignore[assignment]

    # In a backward graph, inputs whose placeholder index is listed in
    # bw_donated_idxs are wrapped in DonatedBuffer instead of InputBuffer.
    if (
        self.is_backward
        and self.bw_donated_idxs
        and self.placeholder_idx in self.bw_donated_idxs
    ):
        tensor = TensorBox.create(
            DonatedBuffer(
                name=target,
                layout=FixedLayout(example.device, example.dtype, sizes, strides),
            )
        )
    else:
        # TODO(jansel): handle input aliasing
        tensor = TensorBox.create(
            InputBuffer(
                name=target,
                layout=FixedLayout(example.device, example.dtype, sizes, strides),
            )
        )

    self.graph_inputs[target] = tensor
    self.graph_input_names.append(target)
    # Keep the unwrapped buffer (TensorBox -> .data -> .data) under the same name.
    self.graph_inputs_original[target] = tensor.data.data  # type: ignore[union-attr]
    if self.current_node.users:  # cudagraphs should work with an unused CPU input
        self.add_device_info(example.device)

    # Note: [Input Alignment handling in Inductor]
    # Alignment matters for generating efficient code. Some operations,
    # e.g. vectorized loads, can only be performed on aligned inputs.
    #
    # But if we codegen assuming aligned inputs and then get unaligned
    # inputs at runtime, then we are forced to clone - which is bad for
    # both perf and memory usage.
    #
    # One option would be to guard on storage_offset%ALIGNMENT, and then
    # codegen based on this. But storage_offset guards turned out to be
    # expensive and cause recompiles; Instead, we're generating code
    # based on the alignment of the example input without guarding.
    with maybe_get_suppress_shape_guards_ctx():
        if not should_assume_input_aligned(example):
            self.unaligned_buffers.add(target)
    return tensor
+ decided_constraint: Optional[Callable[..., tuple[Any, Any]]] = ( + require_contiguous + ) + else: + default_tag: torch._C.Tag = get_layout_constraint_tag( + target, with_default=True + ) + decided_constraint = tag_to_layout_constraint(default_tag) + + make_fallback(target, layout_constraint=decided_constraint) + + elif get_decompositions([target]): + # There isn't a good way to dynamically patch this in + # since AOT Autograd already ran. The error message tells + # the user how to fix it. + raise MissingOperatorWithDecomp(target, args, kwargs) + else: + raise MissingOperatorWithoutDecomp(target, args, kwargs) + + try: + log.debug(" via %s", lowerings[target]) # type: ignore[index] + + n = self.current_node + layout_constraints = maybe_layout_constraints(target) + if layout_constraints: + old_args, old_kwargs = args, kwargs + if layout_constraints is constrain_to_fake_tensors: + # only constrain_to_fake_tensor if this exists. + # otherwise, no constraints at all: the implication is + # that this operator was inserted by a custom pass + # so we'll give them the freedom. + if "eager_input_vals" in n.meta: + fake_args, fake_kwargs = n.meta["eager_input_vals"] + + # (fake_args, fake_kwargs) might not align with (args, kwargs). 
@staticmethod
def can_inline_constant(t: torch.Tensor) -> bool:
    """
    Tell whether ``t`` is a small constant attr that will be inlined.

    Only one-dimensional tensors with at most 8 elements qualify.
    """
    dims = t.shape
    if len(dims) != 1:
        return False
    return dims[0] <= 8
# pyrefly: ignore [bad-override]
def output(
    self,
    target: str,  # type: ignore[override]
    args: tuple[object],  # type: ignore[override]
    kwargs: dict[str, object],
) -> None:
    """
    Handle the graph's ``output`` node and fix up the final outputs.

    Steps, in order:
      1. Evaluate the output through the base class and normalise it (and
         the FX node's own ``args[0]``) to a sequence.
      2. Realize every output; for tensor-like outputs either copy them
         (comm-buffer layouts) or try to match their insignificant strides
         against the strides recorded in the FX node's ``meta["val"]``.
      3. Store the result in ``self.graph_outputs``.
      4. Realize every ``TensorBox`` graph input; when an input no longer
         wraps its own ``InputBuffer`` it is treated as mutated: the value is
         written back into the original input and any graph output that
         referred to it is replaced by the original input.
      5. Call ``self.finalize()`` and emit a debug log line.

    Returns nothing; the results are left on ``self.graph_outputs``.
    """
    result = super().output(target, args, kwargs)  # type: ignore[arg-type]
    if not isinstance(result, (tuple, list)):
        # nested subgraphs can have singleton outputs
        result = (result,)
    assert isinstance(result, (tuple, list)), type(result)
    # Only the IR / scalar types listed below are accepted as graph outputs.
    assert all(
        isinstance(
            x,
            (
                TensorBox,
                ir.Constant,
                type(None),
                ir.ConstantBuffer,
                sympy.Expr,
                sympy.logic.boolalg.Boolean,
                int,
                ir.EffectfulKernel,
                ir.ShapeAsConstantBuffer,
            ),
        )
        for x in result
    ), result

    fx_node_args = V.graph.current_node.args[0]  # type: ignore[arg-type]
    if not isinstance(fx_node_args, (tuple, list)):
        # nested subgraphs can have singleton outputs
        fx_node_args = (fx_node_args,)
    result = [ir.ExternKernel.realize_input(x) for x in result]
    result_correct_strides = []

    # Lowered outputs and FX output args are paired positionally below.
    assert len(fx_node_args) == len(result)
    for r, fx_node in zip(result, fx_node_args):
        if not isinstance(r, (ir.TensorBox, ir.BaseView)):
            # Non-tensor outputs are passed through untouched.
            result_correct_strides.append(r)
        elif isinstance(r.get_output_spec(), ir.CommBufferLayout):
            # Active references to persistent comm buffers are not allowed
            # outside of graphs
            result_correct_strides.append(ir.ExternKernel.copy_input(r))
        else:
            # AOT Autograd tries to detect stride divergence of inductor from output metadata.
            # Here, we try to avoid spurious divergence by matching insignificant strides such as
            # NOTE(review): the sentence above is cut off in the original;
            # presumably it refers to strides that do not affect the memory
            # layout - confirm against ir.try_match_insignificant_strides.

            # should have already been realized
            assert torch._inductor.ir.is_storage_and_layout(r)
            # Strides recorded on the FX node; SymInt entries are replaced by
            # their underlying sympy expression.
            meta_strides = [
                s.node.expr if isinstance(s, torch.SymInt) else s
                for s in fx_node.meta["val"].stride()
            ]
            result_correct_strides.append(
                ir.try_match_insignificant_strides(r, meta_strides)
            )

    self.graph_outputs = result_correct_strides
    value: ir.IRNode
    for name, value in self.graph_inputs.items():
        if isinstance(value, TorchBindObject):
            continue
        assert isinstance(
            value, (TensorBox, sympy.Expr, torch._inductor.ir.GeneratorState)
        ), f"Unsupported inductor graph input type: {type(value)}"
        if not isinstance(value, TensorBox):
            continue
        value.realize()
        assert isinstance(value, TensorBox)
        # Unwrap TensorBox -> StorageBox -> underlying buffer.
        value = value.data
        assert isinstance(value, ir.StorageBox)
        value_storage_box = value
        value = value.data
        if not isinstance(value, InputBuffer) or value.get_name() != name:
            # one of our inputs was mutated, need to turn that into a copy
            ir.MutationLayoutSHOULDREMOVE.realize_into(
                value, self.graph_inputs_original[name]
            )
            # replace output with mutated input
            try:
                ind = self.graph_outputs.index(value_storage_box)
                self.graph_outputs[ind] = self.graph_inputs_original[name]
            except ValueError:
                # list.index found no match: the mutated input is not itself
                # a graph output, so there is nothing to replace.
                pass

    self.finalize()
    log.debug(
        "Force channels last inputs for %d conv for the current graph with id %d",
        self.num_channels_last_conv,
        self.graph_id if self.graph_id is not None else -1,
    )
+ """ + assert len(old_args) == len(new_args) + assert len(old_kwargs) == len(new_kwargs) + + if fx_node.target is torch.ops.higher_order.triton_kernel_wrapper_mutation: + kwargs = fx_node.kwargs["kwargs"] + assert isinstance(kwargs, dict) + mutated = torch._higher_order_ops.triton_kernel_wrap.get_mutated_tensors( + old_kwargs["kernel_idx"], + old_kwargs["constant_args_idx"], + { + k: v.meta["val"] if isinstance(v, torch.fx.Node) else v + for k, v in kwargs.items() + }, + old_kwargs["tma_descriptor_metadata"], + ) + for name in mutated: + old_arg = old_kwargs["kwargs"][name] + new_arg = new_kwargs["kwargs"][name] + if old_arg is new_arg: + continue + + self.call_function(torch.ops.aten.copy_.default, (old_arg, new_arg), {}) + return + + assert isinstance(fx_node.target, torch._ops.OpOverload) + + def maybe_propagate( + schema_arg: torch._C.Argument, old_arg: ir.IRNode, new_arg: ir.IRNode + ) -> None: + if old_arg is new_arg: + return + if schema_arg.alias_info is not None and schema_arg.alias_info.is_write: + # The lowering for copy_ is smart enough to "replace" old_arg with + # new_arg in all future uses so a copy_ kernel never gets emitted. 
+ # old_arg, new_arg may be immutable_list + if isinstance(old_arg, ir.IRNode): + old_arg = (old_arg,) # type: ignore[assignment] + new_arg = (new_arg,) # type: ignore[assignment] + + for old_arg_item, new_arg_item in zip(old_arg, new_arg): # type: ignore[call-overload] + if old_arg_item is new_arg_item: + continue + self.call_function( + torch.ops.aten.copy_.default, (old_arg_item, new_arg_item), {} + ) + + schema = fx_node.target._schema + for idx, (old_arg, new_arg) in enumerate(zip(old_args, new_args)): + schema_arg = schema.arguments[idx] + maybe_propagate(schema_arg, old_arg, new_arg) + + schema_kwargs = {arg.name: arg for arg in schema.arguments} + + for key in old_kwargs: + old_arg = old_kwargs[key] + new_arg = new_kwargs[key] + schema_arg = schema_kwargs[key] + maybe_propagate(schema_arg, old_arg, new_arg) + + def run_node(self, n: torch.fx.Node) -> object: + def debug(msg: str) -> None: + log.debug("lowering %s %s", LazyString(n.format_node), msg) # type: ignore[arg-type] + + from torch._inductor.compiler_bisector import CompilerBisector + + buffer_watermark = len(self.buffers) + operation_watermark = len(self.operations) + + # origins: OrderedSet[Union[Node, ir.IRNode]] = OrderedSet([n]) + origins: OrderedSet[Any] = OrderedSet([n]) + is_call_function = n.op == "call_function" + if is_call_function: + args, kwargs = self.fetch_args_kwargs_from_env(n) + origins |= gather_origins(args, kwargs) + with ( + ir.IRNode.current_origins(origins), + self.set_current_node(n), + V.set_current_node(n), + ): + if ( + n.op == "call_function" + # this path only for built-in operators + and n.target + and isinstance(n.target, torch._ops.OpOverload) + and torch._library.utils.is_builtin(n.target) + and ( + fallback_node_due_to_unsupported_type(n) + or CompilerBisector.disable_subsystem( + "inductor", "lowerings", lambda: repr(n) + ) + ) + ): + debug("fallback_handler") + result = fallback_handler(n.target, add_to_fallback_set=False)( + *args, # type: 
ignore[possibly-undefined] + **kwargs, # type: ignore[possibly-undefined] + ) + elif ( + n.op == "call_function" + and isinstance( + n.target, (torch._ops.OpOverload, torch._ops.HigherOrderOperator) + ) + and should_fallback_by_default(n) + ): + # this path supports fallback due to inductor lite mode. It supports + # both OpOverload and HOPs (e.g., triton_kernel_wrapper_functional). + debug("fallback_handler") + result = fallback_handler(n.target, add_to_fallback_set=False)( + *args, # type: ignore[possibly-undefined] + **kwargs, # type: ignore[possibly-undefined] + ) + elif ( + n.op == "call_function" + and n.target is torch.ops.higher_order.triton_kernel_wrapper_mutation + and config.triton_kernel_default_layout_constraint != "flexible_layout" + ): + debug("user_defined_triton_kernel_layout_constraints") + if ( + config.triton_kernel_default_layout_constraint + == "needs_fixed_stride_order" + ): + old_args = args # type: ignore[possibly-undefined] + old_kwargs = kwargs # type: ignore[possibly-undefined] + + if eager_input_vals := n.meta.get("eager_input_vals"): + inp_args = eager_input_vals[0] + inp_kwargs = eager_input_vals[1] + args, kwargs = constrain_to_fake_tensors( + # pyrefly: ignore [unbound-name] + args, + # pyrefly: ignore [unbound-name] + kwargs, + inp_args, + inp_kwargs, + ) + else: + args, kwargs = constrain_to_fx_strides(n, *args, **kwargs) # type: ignore[index] + result = self.call_function(n.target, args, kwargs) # type: ignore[arg-type] + self.propagate_mutation(n, old_args, old_kwargs, args, kwargs) # type: ignore[possibly-undefined] + else: + raise RuntimeError( + f"Unknown triton_kernel_default_layout_constraint: {config.triton_kernel_default_layout_constraint}" + ) + elif is_magic_method(n.target): + # TODO: this is sus, it probably should be handled in the + # lowerings themselves similarly to sym_size/sym-stride + # https://github.com/pytorch/pytorch/issues/127789 + debug("is_magic_method") + if isinstance( + n.meta["val"], (torch.SymInt, 
torch.SymFloat, torch.SymBool) + ): + result = n.meta["val"].node.expr + else: + result = super().run_node(n) + else: + debug("") + result = super().run_node(n) + + # require the same stride order for dense outputs, + # 1. user-land view() will not throw because inductor + # output different strides than eager + # long term the solution is to make view() always succeed + # with infallible strides. + # 2: as_strided ops, we need make sure its input has same size/stride with + # eager model to align with eager behavior. + as_strided_ops = [ + torch.ops.aten.as_strided.default, + torch.ops.aten.as_strided_.default, + torch.ops.aten.as_strided_scatter.default, + torch.ops.aten.resize.default, + torch.ops.aten.resize_as.default, + ] + is_output = any(user.op == "output" for user in n.users) + is_user_visible = n in self.user_visible_output_strides + is_input_for_as_strided = any( + user.target in as_strided_ops for user in n.users + ) + + if n.meta.get("inductor_realize_to_strides", False) and isinstance( + result, TensorBox + ): + result.realize() + strides = n.meta["val"].stride() + sym_strides = torch._inductor.utils.any_is_symbolic(*strides) + if result.maybe_get_stride() != strides and not sym_strides: + stride_order = ir.get_stride_order(strides) + result = ir.ExternKernel.require_stride_order(result, stride_order) + if ( + is_output + and isinstance(result, TensorBox) + and isinstance(result.data, ir.BaseView) + ): + # Realize so that outputs are correctly aliased + result.realize() + + if (is_output or is_input_for_as_strided) and isinstance( + n.meta["val"], torch.Tensor + ): + if is_user_visible: + strides = self.user_visible_output_strides.get(n) + else: + strides = n.meta["val"].stride() + + if strides is not None and len(strides) > 0: + allow_padding = ( + config.pad_outputs or not is_user_visible + ) and not is_input_for_as_strided + dense = torch._prims_common.is_non_overlapping_and_dense( + n.meta["val"] + ) + unbacked_symbols_in_strides = ( + 
len(free_unbacked_symbols(strides)) > 0 + ) + if ( + not unbacked_symbols_in_strides + and dense + and len(result.get_size()) == 4 + and n in self.nodes_prefer_channels_last + and not is_user_visible + and not is_input_for_as_strided + ): + strides = ir.FlexibleLayout.stride_ordered_for_memory_format( + result.get_size(), torch.channels_last + ) + if not unbacked_symbols_in_strides and len(strides): + # To avoid converting possible view ops to a copy kernel, we use the previous + # require_exact_strides to handle views. But ultimately it's better to require + # the right strides at the tensor definition. + if n.meta["val"]._is_view() or isinstance( + # pyrefly: ignore [missing-attribute] + result.data, + ir.BaseView, + ): + result = ir.ExternKernel.require_stride_order( + result, + ir.get_stride_order(strides), + allow_padding=allow_padding, + ) + else: + # Fix for 0-d tensors: if result size is empty, + # strides should also be empty + if len(result.get_size()) == 0 and len(strides) > 0: + strides = [] + else: + strides = [ + s.node.expr if isinstance(s, torch.SymInt) else s + for s in strides + ] + result = ir.ExternKernel.require_exact_strides( + result, strides, allow_padding=allow_padding + ) + + # Realize if (1) any user need inputs realized, or (2) there is + # already too many reads and rematerializing can be bad. + num_users = len(OrderedSet(n.users)) + if num_users > 1 and isinstance(result, TensorBox): + for user in n.users: + if user.target in needs_realized_inputs: + result.realize_hint() + # This inclusion is somewhat controversial (from + # discussion between Horace, Natalia, and Elias). + # Currently, it's not very clear why this is helpful. + # The general idea here is that even though a node may + # have FlexibleLayout, we still often *treat* it as if + # it was contiguous. This appears to sometimes result in + # suboptimal behavior. + # + # When we do a better job selecting layout, we should + # revisit this. 
+ need_fixed_layout = [ + torch.ops.aten.convolution_backward.default, + torch.ops.aten.mm.default, + torch.ops.aten._int_mm.default, + ] + need_fixed_channels_last_layout = [] + if not self.layout_opt: + need_fixed_layout.append(torch.ops.aten.convolution.default) + if torch._C._has_mkldnn: + need_fixed_layout += [ + torch.ops.mkldnn._linear_pointwise.default, + torch.ops.mkldnn._linear_pointwise.binary, + torch.ops.aten.mkldnn_rnn_layer.default, + torch.ops.onednn.qlinear_pointwise.default, + torch.ops.onednn.qlinear_pointwise.tensor, + torch.ops.onednn.qlinear_pointwise.binary, + torch.ops.onednn.qlinear_pointwise.binary_tensor, + ] + need_fixed_channels_last_layout += [ + torch.ops.mkldnn._convolution_pointwise.default, + torch.ops.mkldnn._convolution_pointwise.binary, + torch.ops.mkldnn._convolution_pointwise_.binary, + torch.ops.mkldnn._convolution_transpose_pointwise.default, + torch.ops.onednn.qconv_pointwise.default, + torch.ops.onednn.qconv2d_pointwise.binary, + ] + if torch._C.has_mkl: + need_fixed_layout += [torch.ops.mkl._mkl_linear.default] + if user.target in need_fixed_layout: + result = ir.ExternKernel.require_stride_order( + result, + ir.get_stride_order(n.meta["val"].stride()), + allow_padding=True, + ) + if ( + user.target in need_fixed_channels_last_layout + and n is user.args[0] + ): + result = ir.ExternKernel.require_stride_order( + result, + ir.get_stride_order( + make_channels_last_strides_for(n.meta["val"].shape) + ), + ) + if user.op == "output": + # pyrefly: ignore [missing-attribute] + if isinstance(result.data.data, (Pointwise, Reduction)): + result.realize() + + # TODO(jansel): introduce a store vs inline choice + result.mark_reuse(len(n.users)) + + # Realize if the IRNode already has accumulated lots of reads + if isinstance(result, TensorBox) and result.has_exceeded_max_reads(): + # Prevent excessive accumulation in a computed buffer, when + # there are multiple branches each with small number of memory + # reads, but they converge 
to a user. + result.realize_hint() + + # Realize if a Pointwise has too much stuff to be inlined. + # As this may cause RecursionError during Inductor's evaluation. + if isinstance(result, TensorBox) and isinstance(result.data, StorageBox): + curr = result.data.data + if isinstance(curr, Pointwise): + # Use inner fn as a rough proxy. Good enough. + if curr.has_large_inner_fn(threshold=100): + result.realize() + + assign_origin_node(result, n) + self.register_users_of(result) + + new_unbacked_defs = OrderedSet[sympy.Symbol]() + for buf in self.buffers[buffer_watermark:]: + new_unbacked_defs |= buf.get_unbacked_symbol_defs() + for op in self.operations[operation_watermark:]: + new_unbacked_defs |= op.get_unbacked_symbol_defs() + + shape_env = V.graph.sizevars.shape_env + + # An input can be unbacked symint i.e.: when mark_unbacked is used. + # in that case add it to new_unbacked_defs. + if ( + n.op == "placeholder" + and isinstance(result, sympy.Symbol) + and shape_env.is_unbacked_symint(result) + ): + new_unbacked_defs.add(result) + + def format_new_defs() -> str: + r = [ + f"unbacked_symbol_defs={buf.get_unbacked_symbol_defs()} in:\n{buf}\n" + for buf in self.buffers[buffer_watermark:] + ] + r.extend( + f"unbacked_symbol_defs={op.get_unbacked_symbol_defs()} in:\n{op}\n" + for op in self.operations[operation_watermark:] + ) + return "***\n".join(r) + + # We do not skip unbacked symints that are input for backward see the note below. + if V.graph.is_backward and n.op == "placeholder": + return result + + # Note [Backwards runtime asserts] + # Backwards poses an interesting problem for deferred runtime + # asserts. In the easy case, we may solely close over data + # dependent sized tensors, and there are no binding sites for + # unbacked SymInts. In this case, we can just drop all the + # runtime asserts on the floor: no non-placeholder bindings, no + # problem. + # + # However, it is *possible* for a fresh runtime assert to show up + # between forwards and backwards. 
Right now, the freezing process + # that happens when we lower forwards means that we will freeze + # runtime asserts, and then the moment the backwards lowering + # process attempts to add a new deferred runtime assert, we will + # fail. Let's say you remove that assert. Now when we get here, + # we need to make sure we actually emit these asserts (because we + # can't emit them in forwards, we already compiled it). So we + # have to do something here. But we don't want to reemit ALL + # deferred runtime asserts, we only want to emit the NEW ones. + # Therefore needing some sort of stratification in the ShapeEnv. + # This is all doable, it just hasn't been done yet. + + unbacked_bindings = resolve_unbacked_bindings( + V.graph.sizevars.shape_env, n.meta.get("unbacked_bindings", {}) + ) + assert unbacked_bindings is not None + # When we do lowering, it is possible we reallocate unbacked SymInts. + # So we need to line up the unbacked SymInts when performing the test + # here + # + # In principle, we could permit lowering to introduce MORE unbacked + # SymInts: as long as all the old unbacked ones are accounted for, + # it's fine for inductor to introduce extra calls to item()/unbacked() + # whatever. This actually happens in practice when an unbacked SymInt + # gets memoized away; naively, when Inductor reprocesses a kernel, it + # doesn't know that the memo still applies, and ends up allocating a + # new symbol. However, this is generally a bad thing: we may still + # end up needing to test equalities on the symbols, and a fresh + # symbol is likely to hit lots of GuardOnDataDependent errors that + # we already know facts for. 
+ renamed_unbacked_bindings = OrderedSet( + V.fake_mode.shape_env.unbacked_renamings.get(s, s) + for s in unbacked_bindings + ) + + assert new_unbacked_defs >= renamed_unbacked_bindings, ( + f"failed {new_unbacked_defs} >= {renamed_unbacked_bindings} (inductor >= fx)\n" + f"fx node is: {n.format_node()}\n" + f"new operations are:\n\n{format_new_defs()}" + ) + self.create_deferred_runtime_asserts(n, new_unbacked_defs) + return result + + def create_deferred_runtime_asserts( + self, n: torch.fx.Node, new_unbacked_defs: OrderedSet[sympy.Symbol] + ) -> None: + # [NOTE] Codegen runtime asserts in Inductor + # + # We need to generate runtime asserts directly in Inductor instead + # of just reusing the asserts from input graphs because we reuse the + # same ShapeEnv as before. In particular, on subsequent graph passes, + # we would immediately turn all of these assertions into noops, + # because when we evaluated their expressions, we would see that + # because we had a deferred runtime assert in the ShapeEnv, we + # know "oh, of course this expression is True" already. + # One example is below: + # + # class Model(torch.nn.Module): + # def forward(self, a, b, c): + # nz = torch.nonzero(a) + # ones = a.new_ones([nz.size(0), b.size(0)]) + # torch._check(ones.size(0) >= 1) + # equals = torch.add(ones, c) + # return equals + # torch._dynamo.mark_dynamic(c, 0) + # When we reuse the ShapeEnv in Inductor lowering, the check that checks + # a and nonzero have the same shape would be evaluated to True after we resolve + # unbacked bindings using the ShapeEnv. + # See test_unbacked_equals_input_size_runtime_assertion in test_aot_inductor. + # + # + # In addition to the Inductor generated runtime asserts, we also + # need the runtime asserts from the input graph, because some derived + # runtime asserts on backed symints are not generated in Inductor. One example is + # this: `y = x.reshape(100, -1).clone()`. x.shape[0] needs to be a multiple of 100. 
+ # See test_aoti_runtime_asserts_backed_symint in test_aot_inductor. + + def make_assert(expr: SympyBoolean, msg: str) -> None: + assert_op = ir.AssertScalar(expr, msg) + self.register_buffer(assert_op, set_name=True) + self.register_operation(assert_op) + + if ( + full_aoti_runtime_assert() + and n.target is torch.ops.aten._assert_scalar.default + and self.aot_mode + ): + node_args, _ = self.fetch_args_kwargs_from_env(n) + if node_args[0] != True: # noqa: E712 + make_assert(node_args[0], f"{node_args[0]} to be True") + else: + # bound_unbacked_symbols tracks the symbols that are created so far, + # we use it to make sure that runtime assertions are added after all + # symbols used in them are defined. + self.bound_unbacked_symbols |= new_unbacked_defs + + shape_env = V.graph.sizevars.shape_env + + # Emit code for runtime asserts that can be inserted at this point. + for i0 in new_unbacked_defs: + ras = self.ras_by_symbol.pop(i0, []) + # NB: size-like not needed, we won't retrace + vr = shape_env.var_to_range[i0] + if not shape_env._default_unspecified_value_range().issubset(vr): + + def is_convertible(s: Expr) -> bool: + if s in (int_oo, -int_oo): + return False + try: + int(s) + return True + except TypeError: + return False + + if is_convertible(vr.lower): + make_assert(i0 >= vr.lower, f"{i0} >= {vr.lower}") + if is_convertible(vr.upper): + make_assert(i0 <= vr.upper, f"{i0} <= {vr.upper}") + + for ra in ras: + fvs = free_unbacked_symbols(ra.expr) + missing = fvs - self.bound_unbacked_symbols + if missing: + i1 = min(missing, key=str) + self.ras_by_symbol.setdefault(i1, []).append(ra) + else: + make_assert(ra.expr, f"{ra.expr}") + + def validate_can_generate_cpp_wrapper(self) -> None: + if config.disable_cpp_codegen: + raise CppWrapperCodegenError("C++ codegen is disabled") + + if sys.platform not in ("linux", "darwin", "win32"): + raise CppWrapperCodegenError(f"Unsupported platform {sys.platform}") + + def init_wrapper_code( + self, + is_subgraph: bool = 
False, + subgraph_name: Optional[str] = None, + parent_wrapper_code: Optional[PythonWrapperCodegen] = None, + partition_signatures: Optional[GraphPartitionSignature] = None, + ) -> None: + device_types = self.device_types.copy() + device_types.discard("cpu") + device_types.discard("meta") + # TODO(Eikan): Only support mixing cpu and other device now. + assert len(device_types) <= 1, "Does not support mixing {}".format( + "+".join(device_types) + ) + only_cpu = len(device_types) == 0 + self.device_type = "cpu" if only_cpu else device_types.pop() + + if self.cpp_wrapper: + self.validate_can_generate_cpp_wrapper() + + self.device_ops = get_device_op_overrides(self.device_type) + wrapper_code_gen_cls = get_wrapper_codegen_for_device( + self.device_type, self.cpp_wrapper, self.fx_wrapper + ) + assert wrapper_code_gen_cls is not None, ( + f"Device {self.device_type} not supported" + ) + self.wrapper_code = wrapper_code_gen_cls.create( + is_subgraph, + subgraph_name, + parent_wrapper_code, + partition_signatures, + ) + + if self.const_module: + self.wrapper_code._names_iter = self.const_module.wrapper_code._names_iter + + def extract_autotune_inputs( + self, example_inputs: list[Union[int, float, torch.Tensor]] + ) -> None: + import copy + + cloned_gm = copy.deepcopy(self.orig_gm) + example_inputs = copy.deepcopy(example_inputs) + triton_nodes = [] + for node in cloned_gm.graph.nodes: + if ( + node.op == "call_function" + and node.target is torch.ops.higher_order.triton_kernel_wrapper_mutation + ): + triton_nodes.append(node) + + # Store grid related nodes + grid_inputs: list[torch.fx.Node] = [] + visited_grids: dict[torch.fx.Node, int] = {} + # Store kwargs related nodes + triton_inputs: dict[str, Any] = {} + kwargs_inputs: list[torch.fx.Node] = [] + visited_kwargs: dict[Any, int] = {} + for node in triton_nodes: + # first check whether we have fx node in grid settings. 
+ for grid in node.kwargs["grid"]: + for val in grid: + if val in visited_grids: + continue + + if isinstance(val, torch.fx.Node): + visited_grids[val] = len(grid_inputs) + grid_inputs.append(val) + + kwargs = node.kwargs["kwargs"] + # identify which args might be mutated, those should be cloned. + mutated = torch._higher_order_ops.triton_kernel_wrap.get_mutated_tensors( + node.kwargs["kernel_idx"], + node.kwargs["constant_args_idx"], + { + k: v.meta["val"] if isinstance(v, torch.fx.Node) else v + for k, v in kwargs.items() + }, + node.kwargs["tma_descriptor_metadata"], + ) + + new_kwargs: dict[str, int] = {} + with cloned_gm.graph.inserting_before(node): + for k, v in kwargs.items(): + if k in mutated: + new_node = cloned_gm.graph.call_function(torch.clone, args=(v,)) + new_kwargs[k] = len(kwargs_inputs) + kwargs_inputs.append(new_node) + continue + + if v in visited_kwargs: + new_kwargs[k] = visited_kwargs[v] + continue + visited_kwargs[v] = len(kwargs_inputs) + kwargs_inputs.append(v) + new_kwargs[k] = visited_kwargs[v] + triton_inputs[node.name] = new_kwargs + + new_outputs = kwargs_inputs + grid_inputs + for node in cloned_gm.graph.nodes: + if node.op == "output": + node.args = (tuple(new_outputs),) + break + + cloned_gm.recompile() + runner = torch.fx.Interpreter(cloned_gm) + returned_outputs = runner.run(example_inputs) + # Extract and store the grid for autotuning + if len(grid_inputs) > 0: + grid_outputs = returned_outputs[len(kwargs_inputs) :] + self.autotuning_grids = {} + for node in triton_nodes: + dynamic_grid = False + new_grids: list[tuple[Any]] = [] + for grid in node.kwargs["grid"]: + new_grid = [] + for val in grid: + if not isinstance(val, torch.fx.Node): + new_grid.append(val) + continue + dynamic_grid = True + new_grid.append(grid_outputs[visited_grids[val]]) + # pyrefly: ignore [bad-argument-type] + new_grids.append(tuple(new_grid)) + + if dynamic_grid: + self.autotuning_grids[node.name] = new_grids + # Store the kwargs input for autotuning + 
self.autotuning_inputs = returned_outputs[: len(kwargs_inputs)] + self.autotuning_mapping = triton_inputs + + def codegen_with_cpp_wrapper( + self, + ) -> tuple[ValueWithLineMap, ValueWithLineMap]: + """ + For GPU, Triton kernels are autotuned and stored as cubin files + """ + if any(device in self.device_types for device in ["cuda", "xpu"]): + + def extract_real_inputs() -> list[Union[int, float, torch.Tensor]]: + def materialize( + x: Union[torch.SymInt, torch.SymFloat, torch.Tensor], + ) -> Union[int, float, torch.Tensor]: + if x is None: + # pyrefly: ignore [bad-return] + return None + elif isinstance(x, (torch.SymInt, torch.SymFloat)): + # Need concrete value to run dynamic shapes and tune the result + return x.node.hint + elif isinstance(x, FakeTensor): + return defake(x) + else: + assert isinstance(x, torch.Tensor), ( + "Unknown type when creating real inputs" + str(type(x)) + ) + return x + + tracing_context = torch._guards.TracingContext.try_get() + if tracing_context is not None and not isinstance( + V.real_inputs, NullHandler + ): + if tracing_context.output_strides: + tracing_context.output_strides.clear() + + params_flat = [ + param + for param in tracing_context.params_flat # type: ignore[union-attr] + if param is not None + ] + real_inputs = [ + materialize(x) + for x in itertools.chain(params_flat, V.real_inputs) + ] + else: + # In the backward pass, V.real_inputs is not OrderedSet. + # Generating random inputs based on self.example_inputs sometimes can be problematic, + # e.g. illegal memory access. A comprehensive fix is to autotune in a separate process. 
+ real_inputs = [ + materialize(x) # type:ignore[arg-type] + for x in ( + self.example_inputs # type:ignore[union-attr] + if isinstance(V.real_inputs, NullHandler) + else V.real_inputs + ) + ] + + if self.mutated_inputs: + from .compile_fx import clone_preserve_strides + + mutated_input_idxs = [ + idx + for idx, name in enumerate(self.graph_inputs) + if name in self.mutated_inputs + and isinstance(real_inputs[idx], torch.Tensor) + ] + for idx in mutated_input_idxs: + # clone mutated Tensor inputs to avoid mutating them in + # the first pass of the CPP wrapper-based compilation, as + # this will lead to a side effect on the example inputs: + # e.g. if torch.compile(f)(x) if called on input-mutating + # f, the inputs x will be mutated twice in the process: + # once here, and again when running the compiled model; + # this will also lead to a numerically incorrect output + mutated_inp = real_inputs[idx] + assert isinstance(mutated_inp, torch.Tensor) + real_inputs[idx] = clone_preserve_strides(mutated_inp) + del mutated_inp + return real_inputs + + if config.triton.autotune_at_compile_time: + # If autotune_at_compile_time is True, we can do the codegen in one-pass + # We will construct the autotuning values if user defined kernel exists. 
+ if config.triton.autotune_with_sample_inputs: + user_defined_kernels = False + for op in self.operations: + if isinstance(op, ir.UserDefinedTritonKernel): + user_defined_kernels = True + break + if user_defined_kernels: + real_inputs = extract_real_inputs() + self.extract_autotune_inputs(real_inputs) + return self.codegen() + else: + # first pass + self.cpp_wrapper = False + compiled = self.compile_to_module().call + + real_inputs = extract_real_inputs() + with torch.utils._python_dispatch._disable_current_modes(): + compiled(real_inputs) + del real_inputs + + # second pass + self.cpp_wrapper = True + self.removed_buffers.clear() + self.removed_operations.clear() + self.inplaced_to_remove.clear() + V.graph.sizevars.precomputed_replacements.clear() + V.graph.sizevars.inv_precomputed_replacements.clear() + metrics.reset() + with config.patch({"triton.autotune_at_compile_time": False}): + return self.codegen() + else: + # cpu + return self.codegen() + + def _update_scheduler(self) -> None: + """ + (Re)initializes the scheduler member. When initializing the scheduler, no CUBIN + files should be generated (to avoid biasing any benchmarks and pessimizing + fusion decisions). + """ + from .scheduler import Scheduler + + with config.patch("triton.store_cubin", False): + self.scheduler = Scheduler(self.operations) + + def codegen(self) -> tuple[ValueWithLineMap, ValueWithLineMap]: + with dynamo_timed("GraphLowering.codegen", log_pt2_compile_event=True): + self.init_wrapper_code() + + self._update_scheduler() + V.debug.draw_orig_fx_graph(self.orig_gm, self.scheduler.nodes) + + self.wrapper_code.push_codegened_graph(self) + self.scheduler.codegen() + + log.debug( + "Finished codegen for all nodes. 
The list of kernel names available: %s", + V.graph.all_codegen_kernel_names, + ) + + result = self.wrapper_code.generate(self.is_inference) + self.wrapper_code.pop_codegened_graph() + return result + + def codegen_subgraph(self, parent_graph: GraphLowering) -> None: + """ + This is a more compact version of the `codegen()` above + where we codegen this graph as a subgraph of some parent + graph. The parent graph is passed as an argument: the + intention is to inline codegening of the subgraph in + the parent graph's wrapper code (including the generated + kernels). The wrapper code is not finalized (via `.generate()` + call), as this will be done in the parent graph's `codegen()`. + """ + with dynamo_timed("GraphLowering.codegen_subgraph", log_pt2_compile_event=True): + self.wrapper_code = parent_graph.wrapper_code + self.device_ops = parent_graph.device_ops + self.cpp_wrapper = parent_graph.cpp_wrapper + self.device_types = parent_graph.device_types + self.device_idxs = parent_graph.device_idxs + self.device_type = parent_graph.device_type + + self._update_scheduler() + self.scheduler.codegen() + + def count_bytes( + self, + ) -> tuple[ + int, list[tuple[BaseSchedulerNode, int]], list[tuple[BaseSchedulerNode, float]] + ]: + total_bytes = 0 + node_counts = [] + node_runtimes = [] + for node in self.scheduler.nodes: + num_bytes = node.get_read_write_buffers_sizes() + total_bytes += num_bytes + node_counts.append((node, num_bytes // 4)) + node_runtimes.append((node, node.get_estimated_runtime())) + + return total_bytes, node_counts, node_runtimes + + # No-op to be patched for unit tests + save_output_code: Optional[Callable[[str], None]] = None + + def compile_to_module(self) -> CompiledModule: + with dynamo_timed( + "GraphLowering.compile_to_module", + phase_name="code_gen", + log_pt2_compile_event=True, + dynamo_compile_column_us="inductor_code_gen_cumulative_compile_time_us", + ): + return self._compile_to_module() + + def _compile_to_module(self) -> 
CompiledModule: + # If we're here, we don't have to worry about the kernel code, which is only + # returned separately in AOTInductor mode. + wrapper_code, _ = ( + self.codegen_with_cpp_wrapper() if self.cpp_wrapper else self.codegen() + ) + + if isinstance(wrapper_code, ValueWithLineMap): + mod = self._compile_to_module_lines(wrapper_code) + elif isinstance(wrapper_code, FileBackedGraphModule): + mod = wrapper_code + else: + raise NotImplementedError( + f"Unrecognized wrapper code type: {type(wrapper_code)}" + ) + + # Logged twice as per https://github.com/pytorch/pytorch/pull/99038#discussion_r1167826029 + # TODO. Revisit this once the logging API is more mature + assert mod.__file__ is not None + + log_module_code(mod.__file__) + log.debug("Output code written to: %s", mod.__file__) + output_code_log.info("Output code written to: %s", mod.__file__) + if config.benchmark_kernel: + print(f"Compiled module path: {mod.__file__}", file=sys.stderr) + if isinstance(wrapper_code, FileBackedGraphModule): + V.debug.output_code(mod.__file__) + V.debug.copy(os.path.splitext(mod.__file__)[0] + ".debug") + + return mod + + def _compile_to_module_lines( + self, wrapper_code: ValueWithLineMap + ) -> CompiledModule: + from .codecache import PyCodeCache + + if config.triton.autotune_at_compile_time: + # sanitize docstrings in kernel defs (#155006) + kernel_autotune_defs = self.wrapper_code.kernel_autotune_defs.getvalue() + kernel_autotune_defs = kernel_autotune_defs.replace('"""', '\\"\\"\\"') + + tuning_code = ( + 'r"""\n' + + "Compile-time auto-tuning block: \n" + + kernel_autotune_defs + + self.wrapper_code.kernel_autotune_calls.getvalue() + + '"""\n' + ) + wrapper_code.value = tuning_code + wrapper_code.value + if GraphLowering.save_output_code is not None: + GraphLowering.save_output_code(wrapper_code.value) + output_code_log.debug("Output code: \n%s", wrapper_code.value) + + inductor_meta = autotune_cache.inductor_meta_from_config() + 
AutotuneCacheBundler.begin_compile(inductor_meta, code=wrapper_code.value) + + try: + linemap = [ + (line_no, node.stack_trace) # type: ignore[attr-defined] + for line_no, node in wrapper_code.line_map + ] + key, path = PyCodeCache.write(wrapper_code.value) + output_code_log.debug("Output code written to: %s", path) + + V.debug.output_code(path) + V.debug.copy(os.path.splitext(path)[0] + ".debug") + except Exception: + trace_structured( + "inductor_output_code", + # Just omit the filename, I still want the code though! + payload_fn=lambda: wrapper_code.value, + ) + raise + else: + trace_structured( + "inductor_output_code", + lambda: { + "filename": path, + "file_path": os.path.abspath(path), + }, + payload_fn=lambda: wrapper_code.value, + ) + with dynamo_timed("PyCodeCache.load_by_key_path", log_pt2_compile_event=True): + mod = PyCodeCache.load_by_key_path( + key, + path, + linemap=linemap, # type: ignore[arg-type] + attrs={ + **self.constants, + **self.torchbind_constants, + **self.opaque_value_type_classes, + }, + ) + self.cache_key = key + self.cache_path = path + self.cache_linemap = linemap # type: ignore[assignment] + + if config.benchmark_harness and config.profile_bandwidth_output: + # run the inputs code gen to get the bandwidth info + mod.benchmark_compiled_module(times=1, repeat=1) + + return mod + + def _get_output_names(self, graph_outputs: list[ir.IRNode]) -> list[str]: + names = [] + shape_counter = itertools.count(0) + none_counter = itertools.count(0) + for node in graph_outputs: + if isinstance(node, ir.NoneAsConstantBuffer): + names.append(f"{self.name}_none{next(none_counter)}") + elif isinstance(node, ir.ShapeAsConstantBuffer): + names.append(f"{self.name}_shape{next(shape_counter)}") + else: + names.append(node.get_name()) + return names + + def get_output_names(self) -> list[str]: + return self._get_output_names(self.graph_outputs) + + def is_unspec_arg(self, name: str) -> bool: + # dynamo wraps unspec variable as 0d CPU tensor, + # need to 
convert to scalar during codegen (triton only) + return ( + name in self.graph_inputs + and self.graph_inputs[name].get_numel() == 1 + and len(self.graph_inputs[name].get_size()) == 0 + and get_device_type(self.graph_inputs[name]) == "cpu" + ) or name in self.zero_dim_cpu_tensor_list + + +class SubgraphLowering(GraphLowering): + """ + Mostly a helper class for the subgraph lowering. The main goal is to call + init_wrapper_code with the subgraph related arguments. + """ + + def __init__(self, parent: GraphLowering, *args: Any, **kwargs: Any) -> None: + self.parent = parent + super().__init__(*args, **kwargs) + + def init_wrapper_code( + self, + is_subgraph: bool = False, + subgraph_name: Optional[str] = None, + parent_wrapper_code: Optional[PythonWrapperCodegen] = None, + partition_signatures: Optional[GraphPartitionSignature] = None, + ) -> None: + super().init_wrapper_code( + is_subgraph=True, + subgraph_name=self.name, + parent_wrapper_code=self.parent.wrapper_code, + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/hooks.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/hooks.py new file mode 100644 index 0000000000000000000000000000000000000000..72a935fb5d272adc39e9bf5116f452d66addccdd --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/hooks.py @@ -0,0 +1,31 @@ +# mypy: allow-untyped-defs +import contextlib +from collections.abc import Callable +from typing import TYPE_CHECKING + + +if TYPE_CHECKING: + import torch + +# Executed in the order they're registered +INTERMEDIATE_HOOKS: list[Callable[[str, "torch.Tensor"], None]] = [] + + +@contextlib.contextmanager +def intermediate_hook(fn): + INTERMEDIATE_HOOKS.append(fn) + try: + yield + finally: + INTERMEDIATE_HOOKS.pop() + + +def run_intermediate_hooks(name, val): + global INTERMEDIATE_HOOKS + hooks = INTERMEDIATE_HOOKS + INTERMEDIATE_HOOKS = [] + try: + for hook in hooks: + hook(name, val) + 
finally: + INTERMEDIATE_HOOKS = hooks diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/index_propagation.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/index_propagation.py new file mode 100644 index 0000000000000000000000000000000000000000..3711266ae93b06d6ff5992712fa9e8e9cd8279cd --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/index_propagation.py @@ -0,0 +1,381 @@ +# mypy: allow-untyped-defs +"""This file implements the IndexPropagation ops handler, which wraps an +underlying handler to add a limited form of constant propagation, as well as +propagation of sympy expressions downstream of ops.index_expr calls. + +For example, say we have the IR: + + tmp0 = ops.index_expr(x, torch.int32) + tmp1 = ops.constant(2, torch.int32) + tmp2 = ops.mul(tmp0, tmp1) + tmp3 = ops.indirect_indexing(tmp2, x_size) + tmp4 = ops.load("buf0", tmp3) + +The underlying handler would just see: + + ops.load("buf0", x * 2) + +This is limited by the set of operators handled in the sympy expression +printers. So simple operations like minimum and maximum cannot be translated to +SymPy expressions yet, despite sympy.Min and sympy.Max existing. 
+ +""" + +import itertools +from collections.abc import Sequence +from dataclasses import dataclass +from typing import Any, Literal, Optional, overload, TypeAlias, Union + +import sympy + +import torch +from torch._prims_common import dtype_to_type, is_integer_dtype +from torch.utils._sympy.functions import FloorDiv, ModularIndexing, Where +from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges + +from .ops_handler import DefaultHandler +from .sizevars import statically_known_true +from .utils import generate_assert +from .virtualized import V + + +_ExprType = Union[sympy.Expr, float, int, bool] + + +def _is_constant(val: _ExprType): + if isinstance(val, sympy.Basic): + return val.is_number + return isinstance(val, (int, float, bool)) + + +def upper_bound(val: _ExprType): + return bound_sympy(val).upper if isinstance(val, sympy.Expr) else val + + +@dataclass +class TypedExpr: + """A SymPy expression with associated type""" + + expr: _ExprType + dtype: torch.dtype + + def is_constant(self): + return _is_constant(self.expr) + + def __post_init__(self): + if _is_constant(self.expr): + expr = self.expr + if isinstance(expr, sympy.Expr): + expr = expr.expand(identity=True) + expr = dtype_to_type(self.dtype)(expr) + if is_integer_dtype(self.dtype): + bits = torch.iinfo(self.dtype).bits + if self.dtype.is_signed: + expr = expr + 2 ** (bits - 1) + expr = expr % 2**bits + if self.dtype.is_signed: + expr = expr - 2 ** (bits - 1) + self.expr = expr + + +class SymPyOps: + """An ops handler where all IR values are SymPy expressions + + When a value cannot be represented as a SymPy expression, the method is + either not defined, or returns NotImplemented + + """ + + @staticmethod + def identity(value: Any) -> Any: + return value + + @staticmethod + def constant(value: Union[int, float, bool], dtype: torch.dtype) -> TypedExpr: + return TypedExpr(value, dtype) + + @staticmethod + def index_expr(value: Union[sympy.Expr, int], dtype: torch.dtype) -> TypedExpr: + 
return TypedExpr(value, dtype) + + @staticmethod + def to_dtype( + value: TypedExpr, + dtype: torch.dtype, + src_dtype: Optional[torch.dtype] = None, + use_compute_types: bool = False, + ) -> TypedExpr: + return TypedExpr(value.expr, dtype) + + @staticmethod + def abs(x: TypedExpr) -> TypedExpr: + return TypedExpr(abs(x.expr), x.dtype) # type: ignore[arg-type] + + @staticmethod + def square(x: TypedExpr) -> TypedExpr: + return TypedExpr(x.expr * x.expr, x.dtype) + + @staticmethod + def add(x: TypedExpr, y: TypedExpr) -> TypedExpr: + result_type = torch.promote_types(x.dtype, y.dtype) + return TypedExpr(x.expr + y.expr, result_type) + + @staticmethod + def sub(x: TypedExpr, y: TypedExpr) -> TypedExpr: + result_type = torch.promote_types(x.dtype, y.dtype) + return TypedExpr(x.expr - y.expr, result_type) + + @staticmethod + def mul(x: TypedExpr, y: TypedExpr) -> TypedExpr: + result_type = torch.promote_types(x.dtype, y.dtype) + return TypedExpr(x.expr * y.expr, result_type) + + @staticmethod + def neg(x: TypedExpr) -> TypedExpr: + return TypedExpr(-x.expr, x.dtype) + + @staticmethod + def floordiv(x: TypedExpr, y: TypedExpr) -> TypedExpr: + result_type = torch.promote_types(x.dtype, y.dtype) + if not is_integer_dtype(result_type): + return NotImplemented + + return TypedExpr(FloorDiv(x.expr, y.expr), result_type) + + @staticmethod + def mod(x: TypedExpr, y: TypedExpr) -> Optional[TypedExpr]: + result_type = torch.promote_types(x.dtype, y.dtype) + if not is_integer_dtype(result_type): + return NotImplemented + + result_expr = ModularIndexing(x.expr, sympy.S.One, y.expr) + return TypedExpr(result_expr, result_type) + + @staticmethod + def remainder(x: TypedExpr, y: TypedExpr) -> Optional[TypedExpr]: + result_type = torch.promote_types(x.dtype, y.dtype) + if not is_integer_dtype(result_type): + return NotImplemented + + x_expr = sympy.sympify(x.expr) + y_expr = sympy.sympify(y.expr) + # In these cases, remainder in Python == remainder in C++, so this transformation + # 
is sound + if ( + x_expr.is_nonnegative is not None + and x_expr.is_nonnegative == y_expr.is_positive + ): + result_expr = ModularIndexing(x.expr, sympy.S.One, y.expr) + return TypedExpr(result_expr, result_type) + return NotImplemented + + @staticmethod + def minimum(x: TypedExpr, y: TypedExpr) -> TypedExpr: + result_type = torch.promote_types(x.dtype, y.dtype) + return TypedExpr(sympy.Min(x.expr, y.expr), result_type) + + @staticmethod + def maximum(x: TypedExpr, y: TypedExpr) -> TypedExpr: + result_type = torch.promote_types(x.dtype, y.dtype) + return TypedExpr(sympy.Max(x.expr, y.expr), result_type) + + +@dataclass +class IndexPropVar: + value: Any # Either an IR value, or TypedExpr if is_symbolic is true + is_symbolic: bool = False + + @staticmethod + def new_symbolic(expr: TypedExpr) -> "IndexPropVar": + return IndexPropVar(expr, is_symbolic=True) + + def __post_init__(self): + assert not self.is_symbolic or isinstance(self.value, TypedExpr), ( + "Symbolic IndexPropVar must contain a TypedExpr" + ) + + +IndexPropResult: TypeAlias = Union[IndexPropVar, tuple["IndexPropResult", ...]] + + +class IndexPropagation(DefaultHandler): + """Ops wrapper that tries to propagate constant and index_expr values through the computation. + + This aims to maximize the compile time simplification possible, and convert + indirect indexing from arange into normal static indexing. 
+ + """ + + def __init__( + self, + inner: Any, + iter_ranges: dict[sympy.Symbol, sympy.Expr], + indirect_var_ranges: dict[sympy.Symbol, sympy.Expr], + ) -> None: + self._inner = inner + self.shape_env = V.graph.sizevars.shape_env + + var_to_range = { + k: ValueRanges(0, upper_bound(v) - 1) for k, v in iter_ranges.items() + } + self.var_to_range = tuple( + itertools.chain(self.shape_env.var_to_range.items(), var_to_range.items()) + ) + # NOTE: this is intentionally kept as a reference so the caller can + # update it in-place + self.indirect_var_ranges = indirect_var_ranges + + axioms = [] + for x, s in iter_ranges.items(): + axioms.append(0 <= x) + axioms.append(x < s) + self.axioms = tuple(axioms) + self.shape_env.get_axioms() + + def materialize_expr(self, expr: sympy.Expr, dtype: torch.dtype) -> Any: + # Construct a new constant/index_expr from the SymPy expression + if _is_constant(expr): + val = dtype_to_type(dtype)(expr) + return self._inner.constant(val, dtype) + return self._inner.index_expr(expr, dtype) + + def unwrap(self, a: Union[Any, IndexPropVar]) -> Any: + if isinstance(a, (list, tuple)): + return tuple(self.unwrap(v) for v in a) + + if not isinstance(a, IndexPropVar): + return a + + # Prefer the sympy representation if possible + if a.is_symbolic: + return self.materialize_expr(a.value.expr, a.value.dtype) + + return a.value + + def wrap(self, a) -> IndexPropResult: + if isinstance(a, (list, tuple)): + return tuple(self.wrap(v) for v in a) + return IndexPropVar(a) + + @overload + def fallback( + self, + name: Literal["indirect_indexing"], + args: Sequence[Any], + kwargs: dict[str, Any], + ) -> IndexPropVar: ... + + @overload + def fallback( + self, name: str, args: Sequence[Any], kwargs: dict[str, Any] + ) -> IndexPropResult: ... 
+ + def fallback( + self, name: str, args: Sequence[Any], kwargs: dict[str, Any] + ) -> IndexPropResult: + # Fallback to the wrapped handler + new_args = [self.unwrap(a) for a in args] + new_kwargs = {k: self.unwrap(v) for k, v in kwargs.items()} + return self.wrap(getattr(self._inner, name)(*new_args, **new_kwargs)) + + def propagate_sympy( + self, name: str, args: Sequence[Any], kwargs: dict[str, Any] + ) -> IndexPropResult: + # Build a new SymPy expression from this ops call + def unwrap(a: Union[Any, IndexPropVar]) -> Any: + if not isinstance(a, IndexPropVar): + return a + return a.value + + new_args = [unwrap(a) for a in args] + new_kwargs = {k: unwrap(v) for k, v in kwargs.items()} + new_expr = getattr(SymPyOps, name)(*new_args, **new_kwargs) + is_valid_expr = new_expr is not NotImplemented and ( + # Inductor doesn't expect floating point in sympy expressions, but + # allow floating point constants to be propagated + new_expr.is_constant() or new_expr.expr.is_integer + ) + if not is_valid_expr: + return self.fallback(name, args, kwargs) + return IndexPropVar.new_symbolic(new_expr) + + def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: + if not hasattr(SymPyOps, name): + return self.fallback(name, args, kwargs) + + var_arguments = [ + a + for a in itertools.chain(args, kwargs.values()) + if isinstance(a, IndexPropVar) + ] + if not all(v.is_symbolic for v in var_arguments): + return self.fallback(name, args, kwargs) + + return self.propagate_sympy(name, args, kwargs) + + def statically_true(self, e): + """ + Given some iter_ranges, return a function that given an expression, returns whether + it is true or false using value ranges, guard knowledge and runtime_asserts. + + FIXME I think this may not be entirely right, as we may not be able to use all runtime_asserts + If this is an issue, just use guards in `self.axioms`. 
+ + The proper way of handling this would be to have a global shape_env that adds + runtime_asserts as they happen in the code. Then, it should be used in SimplifyIndexing + to perform wrap_expr and in CSEProxy.check_bounds to elide upper / lower bounds also + for indirect_indexing + """ + var_to_range = ( + *self.var_to_range, + *( + (k, ValueRanges(0, upper_bound(v) - 1)) + for k, v in self.indirect_var_ranges.items() + ), + ) + # pyrefly: ignore [bad-argument-type] + return statically_known_true(self.shape_env, e, self.axioms, var_to_range) + + def indirect_indexing( + self, + index: Union[Any, IndexPropVar], + size: Any, + check: bool = True, + wrap_neg=True, + ) -> Any: + if isinstance(index, IndexPropVar) and index.is_symbolic: + # If we find something we can convert into a direct indexing we do so + # We still need to (perhaps) wrap the expression and add bound checks + # We want to do this "constant folding", as we don't allow to fuse + # kernels into indirect indexing + + expr = sympy.sympify(index.value.expr) + + # TODO Perhaps move this logic to the simplify indexing pass + def wrap_expr(expr): + # Positive, negative, mixed + if self.statically_true(0 <= expr): + return expr + elif self.statically_true(expr < 0): + return expr + size + else: + return Where(expr < 0, expr + size, expr) + + # Sometimes it's easier to prove 0 <= expr than the weaker -size <= expr + can_prove_lower = self.statically_true(0 <= expr) or self.statically_true( + -size <= expr + ) + can_prove_upper = self.statically_true(expr < size) + if wrap_neg: + expr = wrap_expr(expr) + if generate_assert(check): + self.fallback( + "check_bounds", + (expr, size), + dict(lower=not can_prove_lower, upper=not can_prove_upper), + ) + return expr + + indirect_var = self.fallback( + "indirect_indexing", (index, size, check, wrap_neg), {} + ).value + return indirect_var diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/inductor_prims.py 
b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/inductor_prims.py new file mode 100644 index 0000000000000000000000000000000000000000..458c881ef0e74a76247dc3e76d77494a568a2e8f --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/inductor_prims.py @@ -0,0 +1,226 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import functools +import logging +import operator +from typing import Optional, TYPE_CHECKING + +import torch +from torch import _prims, Tensor + + +if TYPE_CHECKING: + from collections.abc import Sequence + + +log = logging.getLogger(__name__) + + +def make_prim( + schema: str, + impl_aten, + return_type=_prims.RETURN_TYPE.NEW, + doc: str = "", + tags: Optional[Sequence[torch.Tag]] = None, +): + if isinstance(return_type, tuple): + + def meta(*args, **kwargs): + return tuple(_prims.TensorMeta(o) for o in impl_aten(*args, **kwargs)) + + else: + + def meta(*args, **kwargs): + return _prims.TensorMeta(impl_aten(*args, **kwargs)) + + return _prims._make_prim( + schema=schema, + return_type=return_type, + meta=meta, + impl_aten=impl_aten, + doc=doc, + tags=tags, + ) + + +def eager_force_stride(input_tensor: Tensor, stride) -> Tensor: + if input_tensor.stride() == stride: + return input_tensor + new_tensor = input_tensor.clone().as_strided( + input_tensor.shape, + stride, + ) + new_tensor.copy_(input_tensor) + return new_tensor + + +def eager_prepare_softmax(x: Tensor, dim: int) -> tuple[Tensor, Tensor]: + amax = torch.amax(x, dim, keepdim=True) + return amax, torch.sum(torch.exp(x - amax), dim, keepdim=True) + + +# Custom prims used for handling randomness +seed = make_prim( + "inductor_seed(Device device) -> Tensor", + lambda device: torch.randint(2**63 - 1, [], device=device), + doc="create a fresh seed (one per call) for use with inductor_rand", + tags=(torch.Tag.nondeterministic_seeded,), +) +seeds = make_prim( + "inductor_seeds(int count, Device device) -> Tensor", + 
lambda count, device: torch.randint(2**63 - 1, [count], device=device), + doc="Horizontal fusion of many inductor_seed() calls", + tags=(torch.Tag.nondeterministic_seeded,), +) +lookup_seed = make_prim( + # if inductor_lookup_seed changes, update partitioners.py + "inductor_lookup_seed(Tensor seeds, int index) -> Tensor", + lambda seeds, index: seeds[index].clone(), + doc="Extract a single seed from the result of inductor_seeds()", +) +# inductor_random() doesn't accept a dtype. +# instead, its lowering always burns in float32, and conversions to a different type +# are explicit in the graph. We therefore need this impl (used during tracing) to hardcoded +# the dtype, so it always faithfully produces a float32 tensor during tracing, +# even if the default dtype is set to something else. +random = make_prim( + "inductor_random(SymInt[] size, Tensor seed, str mode) -> Tensor", + lambda size, seed, mode: getattr(torch, mode)( + size, device=seed.device, dtype=torch.float32 + ), + doc="torch.rand()/torch.randn() using backend-specific RNG that can be fused", +) +randint = make_prim( + "inductor_randint(SymInt low, SymInt high, SymInt[] size, Tensor seed) -> Tensor", + lambda low, high, size, seed: torch.randint(low, high, size, device=seed.device), + doc="torch.randint() using backend-specific RNG that can be fused", +) +force_stride_order = make_prim( + "inductor_force_stride_order(Tensor input, SymInt[] stride) -> Tensor", + eager_force_stride, + doc="Force the stride order for input tensor. No-op if the input tensor already has the stride. Do a copy otherwise", +) +_unsafe_index_put_ = make_prim( + "_unsafe_index_put_(Tensor(a!) 
self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor(a!)", + lambda self, indices, values, accumulate=False: torch.ops.aten.index_put_( + self, indices, values, accumulate + ), + doc="Unsafe index_put_ (doesn't issue device asserts)", +) +fma = make_prim( + "fma(Tensor a, Tensor b, Tensor c) -> Tensor", + lambda a, b, c: (a * b) + c, + doc="Fused multiply add: fma(a, b, c) -> (a * b) + c without rounding after the multiplication", + tags=(torch.Tag.pointwise,), +) +prepare_softmax_online = make_prim( + "prepare_softmax_online(Tensor a, int dim) -> (Tensor, Tensor)", + eager_prepare_softmax, + return_type=(_prims.RETURN_TYPE.NEW, _prims.RETURN_TYPE.NEW), + doc="Prepare the softmax by computing the max and sum.", +) + + +def _flattened_index_to_nd(indices, width): + import sympy + + from torch.utils._sympy.functions import FloorDiv + + dim = len(width) + + if dim == 1: + return [indices] + elif dim >= 2: + m = functools.reduce(operator.mul, width[1:]) + if isinstance(indices, sympy.Expr) or isinstance(m, sympy.Expr): + ih = FloorDiv(indices, m) + else: + ih = indices // m + indices_new = indices - (ih * m) + return [ih, *_flattened_index_to_nd(indices_new, width[1:])] + else: + raise ValueError(f"Unknown dim: {dim}") + + +def _flatten_index(indices, width): + result = indices[0] + for d in range(1, len(indices)): + result = width[d] * result + indices[d] + return result + + +def _low_memory_max_pool_with_offsets_aten( + self, + kernel_size, + stride, + padding, + dilation, + ceil_mode, +): + dim = len(kernel_size) + if dim == 2: + vals, indices = torch.ops.aten.max_pool2d_with_indices( + self, kernel_size, stride, padding, dilation, ceil_mode + ) + else: + vals, indices = torch.ops.aten.max_pool3d_with_indices( + self, kernel_size, stride, padding, dilation, ceil_mode + ) + + idhw = _flattened_index_to_nd(indices, self.shape[-dim:]) + + dhw_inc = [] + + for d in range(dim): + bh_shape = [1] * self.ndim + bh_shape[-dim + d] = -1 + bh = 
torch.arange( + indices.shape[-dim + d], dtype=torch.int64, device=self.device + ).view(bh_shape) + hbase = bh * stride[d] - padding[d] + h_inc = (idhw[d] - hbase) // dilation[d] + dhw_inc.append(h_inc) + + offsets = _flatten_index(dhw_inc, kernel_size) + + return vals, offsets.to(torch.int8) + + +def _low_memory_max_pool_offsets_to_indices_aten( + offsets, + kernel_size, + input_size, + stride, + padding, + dilation, +): + dim = len(kernel_size) + offsets = offsets.to(torch.int64) + dhw_inc = _flattened_index_to_nd(offsets, kernel_size) + + idhw = [] + for d in range(dim): + bh_shape = [1] * offsets.ndim + bh_shape[-dim + d] = -1 + bh = torch.arange( + offsets.shape[-dim + d], dtype=torch.int64, device=offsets.device + ).view(bh_shape) + hbase = bh * stride[d] - padding[d] + idhw.append(hbase + dhw_inc[d] * dilation[d]) + + return _flatten_index(idhw, input_size) + + +_low_memory_max_pool_with_offsets = make_prim( + "_low_memory_max_pool_with_offsets(Tensor self, SymInt[] kernel_size, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool ceil_mode) -> (Tensor, Tensor)", # noqa: B950 + _low_memory_max_pool_with_offsets_aten, + return_type=(_prims.RETURN_TYPE.NEW, _prims.RETURN_TYPE.NEW), + doc="Instead of returning indices, returns indices offsets.", +) + +_low_memory_max_pool_offsets_to_indices = make_prim( + "_low_memory_max_pool_offsets_to_indices(Tensor self, SymInt[] kernel_size, SymInt[] input_size, SymInt[] stride, SymInt[] padding, SymInt[] dilation) -> Tensor", # noqa: B950 + _low_memory_max_pool_offsets_to_indices_aten, + doc="Convert small int offsets to regular indices.", +) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/invert_expr_analysis.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/invert_expr_analysis.py new file mode 100644 index 0000000000000000000000000000000000000000..816482dba020c80b732bf35e88b210417aa4b77e --- /dev/null +++ 
b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/invert_expr_analysis.py @@ -0,0 +1,208 @@ +from dataclasses import dataclass +from typing import Optional + +import sympy + +from torch._inductor.utils import _IntLike, argsort_sym +from torch.utils._sympy.functions import FloorDiv, ModularIndexing + +from .virtualized import V + + +def static_eq(a: _IntLike, b: _IntLike) -> bool: + return V.graph.sizevars.statically_known_equals(a, b) + + +@dataclass +class Term: + coefficient: _IntLike + range: Optional[_IntLike] # None for unbounded + original_expr: sympy.Expr + reconstruction_multiplier: _IntLike # The multiplier needed for reconstruction + + +def generate_inverse_formula( + expr: sympy.Expr, var: sympy.Symbol +) -> Optional[sympy.Expr]: + """ + Analyze an expression to see if it matches a specific invertible pattern that we + know how to reverse. + + We're looking for expressions that are sums of terms where each term extracts a + distinct bounded range from the input variable, like: + + y = c₀*a₀ + c₁*a₁ + c₂*a₂ + ... + cₙ*aₙ + + where each aᵢ must be one of these specific patterns: + - ModularIndexing(var, divisor, modulo) + - FloorDiv(ModularIndexing(var, 1, modulo), divisor) + - FloorDiv(var, divisor) + - var (the variable itself) + + The key pattern we need is: + - Coefficients are strictly decreasing: c₀ > c₁ > c₂ > ... > cₙ + - Each coefficient matches the product of ranges of later terms (mixed-radix property) + - Each term extracts a bounded range, creating non-overlapping "slots" + + If we find this pattern, we can generate the reconstruction transformation that + decomposes the variable and rebuilds it using the correct multipliers. 
+ + EXAMPLE: + Input: 100*((p//100)) + 10*((p%100)//10) + (p%10) + + Returns the reconstruction expression: + remainder₀ = p + component₀ = remainder₀ // 100 # hundreds digit + remainder₁ = remainder₀ % 100 + component₁ = remainder₁ // 10 # tens digit + remainder₂ = remainder₁ % 10 + component₂ = remainder₂ # ones digit + result = component₀*100 + component₁*10 + component₂*1 + + This decomposes p into its components and rebuilds it using the original + multipliers, which should equal the input expression. + + Args: + expr: Expression to analyze (sum of terms with ModularIndexing, FloorDiv, etc.) + var: The variable being decomposed + + Returns: + None if not invertible, or the reconstruction expression + + References: + Mixed-radix systems: https://en.wikipedia.org/wiki/Mixed_radix + """ + # Step 1: Parse all terms + terms = parse_terms(expr, var) + if not terms: + return None + + # Step 2: Sort by coefficient (descending) + coeffs = [t.coefficient for t in terms] + idxs = reversed(argsort_sym(V.graph.sizevars.shape_env, coeffs)) + terms = [terms[i] for i in idxs] + + # Step 3: Check invertibility conditions + if not check_invertibility(terms): + return None + + return generate_reconstruction_expr(terms, var) + + +def parse_terms(expr: sympy.Expr, var: sympy.Symbol) -> Optional[list[Term]]: + """Parse expression into terms.""" + if not isinstance(expr, sympy.Add): + # Single term + term = parse_single_term(expr, var) + return [term] if term else [] + + terms = [] + for arg in expr.args: + term = parse_single_term(arg, var) + if term: + terms.append(term) + else: + return None # If any term fails to parse, fail completely + + return terms + + +def parse_single_term(term: sympy.Expr, var: sympy.Symbol) -> Optional[Term]: + """Parse a single term and extract coefficient, range, and reconstruction multiplier.""" + # Extract coefficient and expression parts + coefficient, expr_parts = term.as_coeff_mul() + + if len(expr_parts) == 0: + # Pure constant term + return 
Term( + coefficient=coefficient, + range=1, + original_expr=1, + reconstruction_multiplier=0, + ) + elif len(expr_parts) == 1: + expr = expr_parts[0] + else: + # Multiple non-constant factors, too complex + return None + + # Now determine the range and reconstruction multiplier + range_val, reconstruction_multiplier = analyze_expression_properties(expr, var) + if reconstruction_multiplier is None: + return None + + return Term( + coefficient=coefficient, + range=range_val, + original_expr=expr, + reconstruction_multiplier=reconstruction_multiplier, + ) + + +def analyze_expression_properties( + expr: sympy.Expr, var: sympy.Symbol +) -> tuple[Optional[_IntLike], Optional[_IntLike]]: + """Analyze an expression to determine its range and reconstruction multiplier.""" + # ModularIndexing(var, divisor, modulo) = (var // divisor) % modulo + if isinstance(expr, ModularIndexing): + x, div, mod = expr.args + if static_eq(x, var): + return mod, div # Range is mod, multiplier is div + + # FloorDiv cases + if isinstance(expr, FloorDiv): + base, divisor = expr.args + + # FloorDiv(ModularIndexing(var, 1, mod), div) = (var % mod) // div + if isinstance(base, ModularIndexing): + x, inner_div, mod = base.args + if static_eq(x, var) and static_eq(inner_div, 1): + range_val = FloorDiv(mod, divisor) + return range_val, divisor # Range is mod//div, multiplier is div + + # FloorDiv(var, divisor) = var // divisor (unbounded) + elif static_eq(base, var): + return None, divisor # Unbounded range, multiplier is div + + return None, None + + +def check_invertibility(terms: list[Term]) -> bool: + """Check if the terms represent an invertible transformation.""" + if not terms: + return False + + # Coefficients must be strictly decreasing + coeffs = [t.coefficient for t in terms] + if argsort_sym(V.graph.sizevars.shape_env, coeffs) != list( + reversed(range(len(coeffs))) + ): + return False + + # Check mixed-radix property: each coeff[i] = coeff[i+1] * range[i+1] + expected_coeff = 1 + for term 
in reversed(terms): + if not static_eq(term.coefficient, expected_coeff): + return False + if term.range is not None: + expected_coeff *= term.range + + return True + + +def generate_reconstruction_expr(terms: list[Term], var: sympy.Symbol) -> sympy.Expr: + y = var + reconstruction = sympy.S.Zero + remainder = y + + for i, term in enumerate(terms): + if i < len(terms) - 1: + component = FloorDiv(remainder, term.coefficient) + remainder = ModularIndexing(remainder, 1, term.coefficient) + else: + # Last term should also divide by its coefficient + component = FloorDiv(remainder, term.coefficient) + + reconstruction += component * term.reconstruction_multiplier + + return reconstruction diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/ir.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/ir.py new file mode 100644 index 0000000000000000000000000000000000000000..b091b95abdf14bef71b98fb43c92f06546c04486 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/ir.py @@ -0,0 +1,9716 @@ +from __future__ import annotations + +import contextlib +import dataclasses +import functools +import itertools +import logging +import operator +import os +import textwrap +import traceback +from collections.abc import Callable, Container, Generator, Iterable, Iterator, Sequence +from contextlib import AbstractContextManager, nullcontext +from enum import Enum +from functools import partial +from typing import ( + Any, + cast, + ClassVar, + Literal, + Optional, + overload, + SupportsFloat, + SupportsInt, + TYPE_CHECKING, + TypeAlias, + TypeVar, + Union, +) +from typing_extensions import assert_never, Never, override, ParamSpec, Self, TypeIs +from unittest.mock import patch + +import sympy +from sympy import Expr, Integer, Symbol + +import torch._export.serde.schema as export_schema +import torch._library.utils as library_utils +import torch._logging +import torch.fx +import 
torch.utils._pytree as pytree +from torch._dynamo.utils import identity +from torch._export.serde.serialize import GraphModuleSerializer +from torch._higher_order_ops.auto_functionalize import can_auto_functionalize +from torch._inductor import metrics +from torch._inductor.utils import get_free_symbols +from torch._prims_common import ( + compute_required_storage_length, + is_boolean_dtype, + is_float_dtype, + make_channels_last_strides_for, + StrideType, +) +from torch.fx.experimental.symbolic_shapes import ( + _remove_effect_token_unbacked_bindings, + compute_unbacked_bindings, + free_symbols, + free_unbacked_symbols, + IterateExprs, + rebind_unbacked, + resolve_unbacked_bindings, + ShapeEnv, + SymTypes, +) +from torch.fx.node import Node +from torch.utils._ordered_set import OrderedSet +from torch.utils._python_dispatch import _disable_current_modes +from torch.utils._sympy.functions import CleanDiv, FloorDiv, Mod, ModularIndexing +from torch.utils._sympy.symbol import SymT + +from . 
import config, dependencies +from .codegen.common import ( + BackendFeature, + CodegenSymbol, + get_scheduling_for_device, + index_prevent_reordering, + Kernel, +) +from .dependencies import ( + Dep, + extract_free_symbols, + extract_input_node_reduction_ranges, + extract_read_writes, + var_builder, +) +from .loop_body import LoopBody +from .ops_handler import OpCounterCSE, OpCountResult, ReductionType, StoreMode +from .runtime.benchmarking import benchmarker +from .runtime.hints import DeviceProperties, ReductionHint +from .utils import ( + argsort, + argsort_sym, + cache_on_self, + cache_on_self_and_args, + ceildiv, + convert_shape_to_inductor, + convert_shape_to_symint, + developer_warning, + do_bench_using_profiling, + dtype_from_size, + get_dtype_size, + get_kernel_metadata, + GPU_ALIGN_BYTES, + ir_dataclass, + is_dynamic, + is_gpu, + sympy_dot, + sympy_index_symbol, + sympy_index_symbol_with_prefix, + sympy_product, + sympy_subs, + tensor_is_aligned, +) +from .virtualized import ops, OpsValue, V + + +if TYPE_CHECKING: + from torch._library.fake_class_registry import FakeScriptObject + from torch.fx.experimental.symbolic_shapes import SympyBoolean + from torch.fx.node import Argument + + from .codegen.cuda.cuda_template import CUDATemplate + from .codegen.wrapper import PythonWrapperCodegen + from .graph import GraphLowering + from .utils import IndentedBuffer + +else: + CUDATemplate: TypeAlias = object + + +try: + import triton + + triton_version = triton.__version__ + has_triton = True +except ImportError: + triton_version = None + has_triton = False + + +_P = ParamSpec("_P") +_T = TypeVar("_T") +_U = TypeVar("_U") +_V = TypeVar("_V") + +_IntLike: TypeAlias = Union[int, Expr] +_NumLike: TypeAlias = Union[int, float, Expr] + +_OpOverloads: TypeAlias = Union[torch._ops.OpOverload, torch._ops.HigherOrderOperator] + +log = logging.getLogger(__name__) +indent = functools.partial(textwrap.indent, prefix=" ") +aten = torch.ops.aten + +autotune_warmup = 
int(os.getenv("TORCH_AUTOTUNE_WARMUP", 25)) +autotune_rep = int(os.getenv("TORCH_AUTOTUNE_REP", 100)) + +""" [Note: Inductor IR] + +Inductor's IR is produced by executing 'lowering' code (see lowering.py). Each +lowering is registered to a particular aten operator, and expects inputs that +correspond to the aten schema. However, in place of torch Tensor inputs, lowerings +expect Inductor TensorBox inputs. + +TensorBox IR represents torch tensors. Tensors are sometimes single objects owning +storage, and sometimes views of another Tensor's storage. Mutating tensor operations +(such as add_()) affect the underlying storage and any associated views. Other operations +(such as .t_()) update metadata about the current view but don't modify the underlying storage. + +To model this in Inductor, the IR distinguishes between TensorBox, View, StorageBox and Buffer. + +TensorBox is the top level IR construct that any lowering should produce and maps to a torch.Tensor +output from an operation. But just as torch.Tensors take different forms, TensorBox IR can +reference View IR or directly reference StorageBox IRs. + +Some Inductor lowerings produce new sets of 'Box'es, while others (such as .t() or other view ops) +may take an existing TensorBox and point it to a new underlying View IR. + +Tensors that directly own storage are represented as a chain of: +TensorBox -> StorageBox -> Buffer +where Buffer is a simple (1D) allocation, and StorageBox introduces the concept of a Layout. + +If you mutate the data of such a tensor, we swing the StorageBox pointer to point to a new buffer +(leaving the old buffer unmodified and functionalizing the operation). + +Tensors backed by views add one more indirection to the IR. +TensorBox -> View -> StorageBox -> Buffer +In these cases, the underlying StorageBox/Buffer will be shared with the pre-view TensorBox. + +Computation is represented by Operation nodes, with each operation producing 1 +or more output Buffers. 
In the case of mutations, these will be new Buffers that have the +mutated buffer listed in its get_mutation_names(). + +It is also possible to have an InputBuffer for which there is no corresponding Operation, +e.g. it may be a graph input or compile time constant. + +""" + + +_NodeOrNodes: TypeAlias = Union[ + int, + "TensorBox", + dict[str, "TensorBox"], + "Symbol", + "IRNode", + Sequence[ + Optional[Union[int, dict[str, "TensorBox"], "TensorBox", "Symbol", "IRNode"]] + ], +] + + +def _is_static(x: object) -> TypeIs[Union[int, Integer]]: + return isinstance(x, (int, Integer)) + + +@dataclasses.dataclass(frozen=True) +class GraphPartitionSignature: + # symbol inputs that are necessary for codegen + symbol_inputs: OrderedSet[sympy.Symbol] + + # mapping from partition input name to IRNode or Expr. Need the name str since + # we cannot get name from Expr. + input_nodes: dict[str, Union[IRNode, sympy.Expr, TorchBindObject]] + output_nodes: list[IRNode] + + # mapping from partition input name to a boolean for whether deallocating it + # in the partition function + input_deallocation: dict[str, bool] + skip_cudagraph: bool + + # name of constants read/written by the graph partition + constant_names: list[str] + + +def validate_ir(node_or_nodes: Optional[_NodeOrNodes]) -> None: + def _check_tensorbox(nodes: Optional[_NodeOrNodes]) -> None: + # Could expand this to check deeper properties + # (e.g. TensorBox points to View or StorageBox) + if nodes is None: + pass + elif isinstance(nodes, (list, tuple)): + for node in nodes: + _check_tensorbox(node) + elif isinstance(nodes, dict): + for node in nodes.values(): + _check_tensorbox(node) + else: + assert isinstance( + nodes, + ( + ExpandView, + DynamicScalar, + AssertScalar, + TensorBox, + sympy.logic.boolalg.Boolean, + Expr, + int, + EffectfulKernel, + ShapeAsConstantBuffer, + ), + ), ( + f"Found {type(nodes)}, which is not a supported top level IR node. 
See [Note: Inductor IR]" + ) + + # Be picky about the accepted data structure (don't use pytree here) + _check_tensorbox(node_or_nodes) + + +def ops_wrapper(name: str) -> Callable[..., OpsValue]: + assert isinstance(name, str), type(name) + + def fn(*args: object, **kwargs: object) -> OpsValue: + return getattr(ops, name)(*args, **kwargs) + + return fn + + +def inverse_reorder(order: Sequence[int]) -> Callable[[Sequence[_T]], Sequence[_T]]: + inv_order = dict(zip(order, range(len(order)))) + + def reindex(index: Sequence[_T]) -> Sequence[_T]: + assert len(index) == len(inv_order) + return [index[inv_order[i]] for i in range(len(index))] + + return reindex + + +def same_reorder(order: Sequence[int]) -> Callable[[Sequence[_T]], Sequence[_T]]: + def reindex(index: Sequence[_T]) -> Sequence[_T]: + assert len(index) == len(order) + return [index[order[i]] for i in range(len(index))] + + return reindex + + +def fuse_reindexing( + reindex1: Callable[[Sequence[_U]], Sequence[_V]], + reindex2: Callable[[Sequence[_T]], Sequence[_U]], +) -> Callable[[Sequence[_T]], Sequence[_V]]: + def reindex(index: Sequence[_T]) -> Sequence[_V]: + return reindex1(reindex2(index)) + + return reindex + + +NHWC_STRIDE_ORDER = [3, 0, 2, 1] +NHWDC_STRIDE_ORDER = [4, 0, 3, 2, 1] + + +def get_fill_order( + seq: Sequence[Union[int, torch.SymInt, Expr]], shape_env: Optional[ShapeEnv] = None +) -> Sequence[int]: + """ + Convert strides to fill order (argsort) + """ + if shape_env is None or all(isinstance(s, (int, sympy.Integer)) for s in seq): + sorted_idx: Sequence[int] = argsort(seq) + else: + # argsort_sym handles unbacked symints (with the help of the shape_env) + sorted_idx = argsort_sym(shape_env, seq) + return sorted_idx + + +def stride_order2fill_order(order: Sequence[Union[int, Integer]]) -> Sequence[int]: + """ + Convert stride order to fill order + For channel last format, + + stride order = [3, 0, 2, 1] and fill order = [1, 3, 2, 0] + """ + lookup = {pos: idx for idx, pos in 
enumerate(order)} + fill_order = [lookup[i] for i in range(len(order))] + return fill_order + + +def get_stride_order( + seq: Sequence[Union[int, torch.SymInt, Expr]], shape_env: Optional[ShapeEnv] = None +) -> Sequence[int]: + """ + Convert strides to stride order + """ + sorted_idx: Sequence[int] = get_fill_order(seq, shape_env) + out = [0 for _ in range(len(seq))] + for i, elem in enumerate(sorted_idx): + out[elem] = i + return out + + +@overload +def ir_node_to_tensor(x: None, guard_shape: bool = True) -> None: ... + + +@overload +def ir_node_to_tensor(x: IRNode, guard_shape: bool = True) -> torch.Tensor: ... + + +def ir_node_to_tensor( + x: Optional[IRNode], guard_shape: bool = True +) -> Optional[torch.Tensor]: + if x is None: + return None + + shape_fn: Callable[[Union[int, Expr]], Union[int, Expr]] + if not guard_shape: + shape_fn = V.graph.sizevars.size_hint + else: + shape_fn = identity + size = [shape_fn(s) for s in x.get_size()] + stride: StrideType + if is_storage_and_layout(x): + stride = [shape_fn(s) for s in x.get_layout().stride] + else: + stride = FlexibleLayout.contiguous_strides(size) + dtype = x.get_dtype() + device = x.get_device() + size = convert_shape_to_symint(size) + # pyrefly: ignore [bad-assignment] + stride = convert_shape_to_symint(stride) + with V.graph.sizevars.shape_env.suppress_guards(): + t = torch.empty_strided( + size=size, stride=stride, dtype=dtype, device=device + ).zero_() + return t + + +def may_convert_to_optional( + value: Optional[Sequence[_T]], +) -> Optional[Sequence[Optional[_T]]]: + if isinstance(value, list) and not value: + # [None] makes sure the cpp wrapper codegen will generate something like + # {std::nullopt} instead of {} + return [None] + return value + + +def get_device_type( + x: Union[IRNode, OutputSpec, torch.device, None, str], +) -> Optional[str]: + if isinstance(x, str) or x is None: + return x + elif isinstance(x, torch.device): + return x.type + elif isinstance(x, (IRNode, OutputSpec)): + return 
get_device_type(x.get_device()) + # pyrefly: ignore [bad-argument-type] + assert_never(f"get_device_type({x}: {type(x).__name__})") + + +def is_triton(x: Union[IRNode, torch.device, None, str]) -> bool: + device = get_device_type(x) + # Special case cpu and cuda as using the method below + # to determine if the scheduler is a triton scheduler subclass + # requires instantiating a scheduler for them + if device in ["cpu", "cuda"]: + if getattr(config, f"{device}_backend") == "triton": + return True + return False + if ( + device is None + or (device_scheduling := get_scheduling_for_device(device)) is None + ): + return False + from .codegen.triton import TritonScheduling + + assert isinstance(device_scheduling, type), type(device_scheduling) + return issubclass(device_scheduling, TritonScheduling) + + +def is_cpu(x: Union[IRNode, torch.device, None, str]) -> bool: + return get_device_type(x) == "cpu" + + +def is_aligned_realized_tensor(x: Union[Buffer, TensorBox], alignment: int) -> bool: + if ( + not isinstance(x, IRNode) + or x.maybe_get_stride() is None + or free_unbacked_symbols(x.get_stride()) + or free_unbacked_symbols(x.get_size()) + ): + return False + + aligned_strides = sympy.And( + *(sympy.Eq(Mod(s, alignment), 0) for s in x.get_stride()[:-1]) + ) + aligned_last_dim = sympy.Or( + sympy.Eq(x.get_stride()[-1], 1), sympy.Le(x.get_size()[-1], 1) + ) + is_aligned = sympy.And(aligned_strides, aligned_last_dim) + + # Make sure to guard to recompile when necessary. + return V.graph.sizevars.guard_or_false(is_aligned) + + +def significant_strides_equal( + strides1: Sequence[_IntLike], + strides2: Sequence[_IntLike], + shape: Sequence[_IntLike], +) -> bool: + """ + Returns true if the strides are equal, ignoring dimensions of size 1 . 
+ """ + assert len(shape) == len(strides1) and len(strides1) == len(strides2) + for dim, s1, s2 in zip(shape, strides1, strides2): + if V.graph.sizevars.statically_known_leq(dim, 1): + continue + + if not V.graph.sizevars.statically_known_equals( + s1, s2 + ) and V.graph.sizevars.symbolic_hint(s1) != V.graph.sizevars.symbolic_hint(s2): + return False + + return True + + +def try_match_insignificant_strides( + tensor: IRNode, + strides: Sequence[Union[int, torch.SymInt]], +) -> IRNode: + """ + Tries to match the strides of the tensor to those in the meta_strides. Strides of insignificant + dimensions - size 0 or 1 - will be updated. + + If there are real stride differences (NHWC vs NCHW), or the tensor is not realized, then the input will be returned + """ + if not is_storage_and_layout(tensor): + return tensor + + if all( + V.graph.sizevars.statically_known_equals(s1, s2) + for s1, s2 in zip(strides, tensor.get_stride()) + ): + return tensor + + if not significant_strides_equal(strides, tensor.get_stride(), tensor.get_size()): + return tensor + + storage, old_layout = as_storage_and_layout(tensor) + new_stride = [*old_layout.stride] + for i, s in enumerate(tensor.get_size()): + if V.graph.sizevars.statically_known_leq(s, 1): + new_stride[i] = strides[i] + + new_layout = FixedLayout( + old_layout.device, + old_layout.dtype, + old_layout.size, + new_stride, + old_layout.offset, + old_layout.is_pinned, + ) + return TensorBox(ReinterpretView(data=storage, layout=new_layout)) + + +def gm_original_output_strides(gm: torch.fx.GraphModule) -> None: + output_node = gm.graph.find_nodes(op="output")[0] + output_node.meta["user_visible_output_idxs"] = [ + idx for idx, _ in enumerate(output_node.args) + ] + from torch._inductor.compile_fx import record_original_output_strides + + record_original_output_strides(gm) + + +def get_symbolic_inputs(inputs: Sequence[IRNode]) -> list[Expr]: + sym_vars: OrderedSet[Expr] = OrderedSet() + for inp in inputs: + sym_vars |= 
get_free_symbols(inp.get_size(), unbacked_only=False) + sym_vars |= get_free_symbols(inp.get_stride(), unbacked_only=False) + + return list(sym_vars) + + +def try_get_name(x): + if isinstance(x, TensorBox): + x = x.data + if isinstance(x, BaseView): + x = x.unwrap_view() + if isinstance(x, StorageBox): + x = x.data + return x.get_name() if isinstance(x, Buffer) else None + + +class IRNode: + """Base class for all intermediate representation (IR) nodes in TorchInductor. + + Note: + This is an abstract base class. Most methods raise NotImplementedError + and must be overridden by concrete subclasses. + """ + + _current_origins: ClassVar[OrderedSet[Any]] = OrderedSet() + + # NB: These are kinda weird, + origins: OrderedSet[Any] = dataclasses.field(init=False) + # traces back to where the IRNode is created in Inductor + traceback: Optional[list[str]] = dataclasses.field(init=False) + origin_node: Optional[torch.fx.Node] = dataclasses.field(init=False) + + @staticmethod + @contextlib.contextmanager + def current_origins(origins: OrderedSet[Node]) -> Generator[None, None, None]: + old = IRNode._current_origins + IRNode._current_origins = old | origins + try: + yield + finally: + IRNode._current_origins = old + + @staticmethod + def is_realized_node(node: IRNode) -> bool: + return isinstance( + node, + ( + ComputedBuffer, + InputsKernel, + InputBuffer, + ReinterpretView, + TemplateBuffer, + ), + ) + + def _post_init_setattr(self, attr: str, value: Any) -> None: + # Intended for use in __post_init__ for enforcing an invariant on a dataclass + # If you must, can also be used for setting provenance info + # We would like to try and minimize these usages though + object.__setattr__(self, attr, value) + + def __post_init__(self) -> None: + origins = OrderedSet(self._current_origins) + self._post_init_setattr("origins", origins) + self._post_init_setattr( + "traceback", traceback.format_stack() if config.debug_ir_traceback else None + ) + self._post_init_setattr("origin_node", 
None) + + def get_read_names(self) -> OrderedSet[str]: + return OrderedSet(dep.name for dep in self.get_reads()) + + def get_traceback(self) -> Optional[list[str]]: + return self.traceback + + def get_origin_node(self) -> Optional[torch.fx.Node]: + return self.origin_node + + def get_defining_op(self) -> Optional[Operation]: + return None + + def get_stack_traces(self) -> OrderedSet[str]: + # Return stack traces to user model code + # A single IRNode could correspond to multiple lines of code + stack_traces: OrderedSet[str] = OrderedSet() + origins = self.origins + if isinstance(self, ExternKernel): + origin_node = self.get_origin_node() + if self.origin_node: + origins = OrderedSet([origin_node]) + for node in origins: + if hasattr(node, "stack_trace") and node.stack_trace: + # nodes in the backward graph don't have mapping to pre_grad_graph + stack_traces.add(node.stack_trace) + else: + pre_grad_nodes = ( + torch._inductor.debug._inductor_post_to_pre_grad_nodes.get( + "postToPre", + {}, + # pyrefly: ignore [missing-attribute] + ).get(node.name, []) + ) + if not isinstance(pre_grad_nodes, list): + continue + for node_name in pre_grad_nodes: + stack_trace = ( + torch._inductor.debug._inductor_pre_grad_node_stack_trace.get( + node_name, None + ) + ) + if stack_trace: + stack_traces.add(stack_trace) + return stack_traces + + def common_repr(self, shorten: bool = True) -> Sequence[str]: + origins = f"origins={getattr(self, 'origins', '')}" + if shorten and len(origins) > 64: + # this can get *very* long + origins = f"{origins[:61]}..." 
+ if not self.get_stack_traces(): + return [origins] + + stack_trace_str = [] + for stack_trace in self.get_stack_traces(): + stack_trace_str.append("stack_traces = {") + stack_trace_str += stack_trace.split("\n") + stack_trace_str.append("}") + return [origins] + stack_trace_str + + def str_helper( + self, lines: Sequence[object], shorten: bool = True, multiline: bool = True + ) -> str: + lines = list(lines) + list(self.common_repr(shorten)) + lines = list(map(str, lines)) + if multiline: + # pyrefly: ignore [no-matching-overload] + new_lines = indent(",\n".join(lines)) + return f"{type(self).__name__}(\n{new_lines}\n)" + else: + return f"{type(self).__name__}({lines})" + + def get_dtype(self) -> torch.dtype: + return self.dtype + + def maybe_get_dtype(self) -> Optional[torch.dtype]: + try: + return self.get_dtype() + except NotImplementedError: + return None + + def get_layout(self) -> Layout: + raise NotImplementedError(f"get_layout() is not implemented by {type(self)}!") + + def maybe_get_layout(self) -> Optional[Layout]: + try: + return self.get_layout() + except NotImplementedError: + return None + + def get_output_spec(self) -> OutputSpec: + return self.get_layout() + + def maybe_get_output_spec(self) -> Optional[OutputSpec]: + try: + return self.get_output_spec() + except NotImplementedError: + return None + + def has_tensor_output(self) -> bool: + """True for single tensor output (excludes MultiOutput)""" + return isinstance(self.maybe_get_output_spec(), Layout) + + def get_size(self) -> Sequence[Expr]: + raise NotImplementedError(f"get_size() is not implemented by {type(self)}!") + + def maybe_get_size(self) -> Optional[Sequence[_IntLike]]: + try: + return self.get_size() + except NotImplementedError: + return None + + @property + def shape(self) -> Union[_IntLike, sympy.Rel, Sequence[_IntLike]]: + return self.get_size() + + def get_numel(self) -> Expr: + return sympy_product(self.get_size()) + + def is_zero_elements(self) -> bool: + return 
V.graph.sizevars.statically_known_true(sympy.Eq(self.get_numel(), 0)) + + def realize(self) -> Optional[str]: + """ + If the IRNode refers to data which has not been materialized (e.g., + it is a Pointwise/Reduction that could potentially have more + compute fused into it), realize the IRNode into physical memory, + ending the possibility of fusing into it, but allowing, e.g., multiple + users to access the data without having to recompute. + + Check StorageBox.realize for a particularly notable implementation. + + TODO(ezyang): I think, in principle, every IRNode should have an + implementation of this, and most of the time no-op is OK, but you + really do have to audit each IRNode for this, so for now, raise + an error if it's not implemented. Note that some code in graph.py + will catch this thrown error and suppress it with a warning. + """ + raise NotImplementedError(f"realize NYI on {type(self)}") + + def codegen_reference(self, writer: Optional[IndentedBuffer] = None) -> str: + raise NotImplementedError(f"codegen_reference NYI on {type(self)}") + + def get_device(self) -> Optional[torch.device]: + return None + + def get_device_or_error(self) -> torch.device: + device = self.get_device() + assert device is not None + return device + + def has_exceeded_max_reads(self) -> bool: + return False + + def make_loader(self) -> Callable[[Sequence[Expr]], OpsValue]: + raise NotImplementedError(type(self).__name__) + + def make_indexer(self) -> Callable[[Sequence[Expr]], Expr]: + raise NotImplementedError(type(self).__name__) + + def get_stride(self) -> Sequence[_IntLike]: + raise NotImplementedError(type(self).__name__) + + def maybe_get_stride(self) -> Optional[Sequence[_IntLike]]: + try: + return self.get_stride() + except NotImplementedError: + return None + + def get_name(self) -> str: + raise NotImplementedError(type(self).__name__) + + def maybe_get_name(self) -> Optional[str]: + try: + return self.get_name() + except NotImplementedError: + return None + + def 
is_input_buffer(self) -> bool: + try: + return self.get_name() in V.graph.graph_inputs + except NotImplementedError: + return False + + def has_large_inner_fn(self, threshold: Optional[int] = None) -> bool: + return False + + def mark_reuse(self, users: int) -> None: + pass + + def realize_hint(self) -> None: + pass + + def unwrap_view(self) -> IRNode: + raise NotImplementedError(type(self).__name__) + + def freeze_layout(self) -> None: + raise NotImplementedError(type(self).__name__) + + def freeze_layout_with_stride_order( + self, order: Sequence[int], allow_padding: bool = False + ) -> None: + raise NotImplementedError(type(self).__name__) + + def freeze_layout_with_fill_order(self, order: Sequence[int]) -> None: + raise NotImplementedError(type(self).__name__) + + def freeze_layout_with_same_order(self, stride: Sequence[_IntLike]) -> None: + raise NotImplementedError(type(self).__name__) + + def freeze_layout_with_exact_strides( + self, exact_strides: Sequence[_IntLike], allow_padding: bool = False + ) -> None: + raise NotImplementedError(type(self).__name__) + + def get_read_writes(self) -> dependencies.ReadWrites: + raise NotImplementedError(type(self).__name__) + + def get_reads(self) -> OrderedSet[Dep]: + return self.get_read_writes().reads + + def num_reads(self) -> int: + return len(self.get_reads()) + + def get_storage_numel(self) -> _IntLike: + raise NotImplementedError(type(self).__name__) + + def get_free_symbol_uses( + self, unbacked_only: bool = False + ) -> OrderedSet[sympy.Symbol]: + raise NotImplementedError(type(self).__name__) + + def get_reduction_type(self) -> Optional[str]: + raise NotImplementedError(type(self).__name__) + + def get_reduction_size(self) -> Sequence[Expr]: + raise NotImplementedError(type(self).__name__) + + def is_extern(self) -> bool: + return False + + def is_no_op(self) -> bool: + return False + + def constant_to_device(self, device: torch.device) -> IRNode: + raise NotImplementedError(type(self).__name__) + + def 
get_mutation_names(self) -> Sequence[str]: + raise NotImplementedError(type(self).__name__) + + def get_operation_name(self) -> str: + raise NotImplementedError(type(self).__name__) + + def get_inputs_that_alias_output(self) -> Sequence[str]: + raise NotImplementedError(type(self).__name__) + + if TYPE_CHECKING: + + @property + def dtype(self) -> torch.dtype: ... + + +@ir_dataclass(frozen=False) +class Operation: + def __post_init__(self) -> None: + self.operation_name: Optional[str] = None + + def get_device(self) -> Optional[torch.device]: + raise NotImplementedError + + def get_origin_node(self) -> Optional[torch.fx.Node]: + assert hasattr(self, "origin_node") + return self.origin_node + + def get_origins(self) -> OrderedSet[Any]: + assert hasattr(self, "origins") + return self.origins + + def get_operation_name(self) -> str: + assert self.operation_name is not None + return self.operation_name + + def is_extern(self) -> bool: + return False + + def is_no_op(self) -> bool: + return False + + def get_read_writes(self) -> dependencies.ReadWrites: + raise NotImplementedError + + def is_user_of(self, name: str) -> bool: + return name in self.get_read_names() + + def get_read_names(self) -> OrderedSet[str]: + return OrderedSet(dep.name for dep in self.get_reads()) + + def get_reads(self) -> OrderedSet[Dep]: + return self.get_read_writes().reads + + def get_outputs(self) -> list[Buffer]: + raise NotImplementedError + + def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: + return OrderedSet() + + def get_free_symbol_uses( + self, unbacked_only: bool = False + ) -> OrderedSet[sympy.Symbol]: + """ + When unbacked_only=True: + Returns the unbacked symbols which are required to be in scope in + order to successfully perform codegen for this buffer. For example, + a buffer that corresponds to an extern kernel call that takes i0 as + an argument would return {i0} here. 
This is used to generate necessary + dependencies that ensure we actually bind i0 in codegen before you + try to use it. + + Note that this is NOT transitive; in particular, if this buffer takes + in as input another buffer with dynamic shape (e.g., (i0,)), we will + not report it here, because you will already have a dependency + on that buffer, which will eventually have a dependency on i0 if + necessary. + + When unbacked_only=False: + Similar to `unbacked_only=True` but including all free symbols + instead of only free unbacked symbols. + """ + return OrderedSet() + + def get_workspace_size(self) -> int: + """ + Gets extra global memory size needed by this buffer. + Some algorithms (e.g. group gemm) may require extra global memory in the generated code. + """ + return 0 + + +@ir_dataclass +class Loops(IRNode): + device: torch.device + dtype: torch.dtype + inner_fn: Callable[..., Any] + ranges: Sequence[_IntLike] + + @cache_on_self_and_args("Loops") + def get_free_symbol_uses( + self, unbacked_only: bool = False + ) -> OrderedSet[sympy.Symbol]: + return OrderedSet().union( + *(get_free_symbols(e, unbacked_only) for e in self.ranges), + self.inner_fn_free_symbols(unbacked_only), + ) + + def _to_str(self, names: Sequence[str]) -> str: + return self.str_helper( + [ + f"'{self.device.type}'", + str(self.dtype), + self.inner_fn_str(), + ] + + [f"{name}={getattr(self, name)}" for name in names] + + [f"origin_node={self.origin_node!r}"] + ) + + def __str__(self) -> str: + return self._to_str(("ranges",)) + + __repr__ = __str__ + + def get_device(self) -> Optional[torch.device]: + return self.device + + def get_origin_node(self) -> Optional[torch.fx.Node]: + return self.origin_node + + def get_size(self) -> Sequence[Expr]: + return self.ranges + + def get_pointwise_size(self) -> Sequence[Expr]: + return self.ranges + + @classmethod + def create( + cls, *args: Any, **kwargs: Any + ) -> Union[TensorBox, ShapeAsConstantBuffer]: + origin_node = kwargs.pop("origin_node", 
None) + tb = kwargs.pop("traceback", None) + r = cls(*args, **kwargs) + # Need to explicitly set origin_node here to propagate it down. + # todo(chilli): I think it would be better for IRNode to directly set + # origin_node + r._post_init_setattr("origin_node", origin_node) + r._post_init_setattr("traceback", tb or r.traceback) + return TensorBox.create(r) + + @staticmethod + def _index(ranges: Sequence[_IntLike], prefix: SymT = SymT.INDEX) -> Sequence[Expr]: + return [ + sympy.S.Zero if s == 1 else sympy_index_symbol_with_prefix(prefix, n) + for n, s in enumerate(ranges) + ] + + @cache_on_self + def inner_fn_opcount(self) -> OpCountResult: + opcounter = OpCounterCSE(V.MockHandler()) + with ( + V.set_ops_handler(opcounter), + patch.object(FlexibleLayout, "allow_indexing", True), + ): + self.inner_fn(*self.inner_fn_args()) + return opcounter.getvalue() + + def inner_fn_args(self) -> Sequence[Sequence[_IntLike]]: + return (self._index(self.ranges),) + + @cache_on_self + def inner_fn_str(self) -> str: + return V.KernelFormatterHandler.ir_to_string( + self.inner_fn, *self.inner_fn_args() + ) + + def has_large_inner_fn(self, threshold: Optional[int] = None) -> bool: + if threshold is None: + threshold = 0 + threshold = max(threshold, config.realize_opcount_threshold) + return self.inner_fn_opcount().num_ops > threshold + + def inner_fn_free_symbols(self, unbacked_only: bool = False) -> OrderedSet[Symbol]: + index = self._index(self.ranges) + return extract_free_symbols(self.inner_fn, index, unbacked_only=unbacked_only) + + def get_reads(self) -> OrderedSet[Dep]: + with patch.object(FlexibleLayout, "allow_indexing", True): + if self.get_reduction_type(): + return extract_read_writes( + self.make_loader(), + self.get_size(), + self.get_reduction_size(), + ).reads + else: + return extract_read_writes( + self.make_loader(), + self.get_size(), + ).reads + + def get_read_names(self) -> OrderedSet[str]: + return OrderedSet(self.inner_fn_opcount().read_buffers) + + def 
num_reads(self) -> int: + return len(self.inner_fn_opcount().read_buffers) + + def get_reduction_size(self) -> Sequence[Expr]: + raise NotImplementedError( + f"get_reduction_size() is not implemented by {type(self)}!" + ) + + def get_reduction_type(self) -> Optional[str]: + raise NotImplementedError( + f"get_reduction_type() is not implemented by {type(self)}!" + ) + + def constant_to_device(self, device: torch.device) -> IRNode: + raise NotImplementedError( + f"constant_to_device() is not implemented by {type(self)}!" + ) + + +def nop_loader_fn(idx: Union[Expr, Sequence[Expr]], *, dtype: torch.dtype) -> OpsValue: + if dtype.is_floating_point: + return ops.constant(float("nan"), dtype) + else: + return ops.constant(0, dtype) + + +@ir_dataclass +class Pointwise(Loops): + def make_loader(self) -> Callable[[Sequence[Expr]], OpsValue]: + # Make zero-element loops into a no-op + if self.is_zero_elements(): + return partial(nop_loader_fn, dtype=self.dtype) + + return self.inner_fn + + def __str__(self) -> str: + return self._to_str(("ranges",)) + + __repr__ = __str__ + + def get_reduction_size(self) -> Sequence[sympy.Expr]: + return [] + + def get_reduction_type(self) -> Optional[str]: + return None + + def store_output( + self, + output_name: Optional[str], + indexer: Callable[[Sequence[Expr]], Never], + vars: Sequence[Expr], + ) -> None: + loader = self.make_loader() + return ops.store(output_name or "unnamed", indexer(vars), loader(vars)) + + def constant_to_device(self, device: torch.device) -> IRNode: + """Move this to a given device. 
Requires that all reads are to constants.""" + loader = self.make_loader() + loader = patch.object(ConstantBuffer, "override_device", device)(loader) + return Pointwise( + device=device, + dtype=self.dtype, + inner_fn=loader, + ranges=self.ranges, + ) + + +@ir_dataclass +class Scatter(Pointwise): + output_indexer: Callable[[Sequence[Expr]], Expr] + scatter_mode: StoreMode = None + + def constant_to_device(self, device: torch.device) -> IRNode: + """Move this to a given device. Requires that all reads are to constants.""" + loader = self.make_loader() + loader = patch.object(ConstantBuffer, "override_device", device)(loader) + return Scatter( + device=device, + dtype=self.dtype, + inner_fn=loader, + ranges=self.ranges, + output_indexer=self.output_indexer, + scatter_mode=self.scatter_mode, + ) + + def store_output( + self, + output_name: Optional[str], + indexer: Callable[[Sequence[Expr]], Never], + vars: Sequence[Expr], + ) -> Any: + loader = self.make_loader() + if output_name is None: + output_name = "unnamed" + return ops.store( + output_name, + indexer(self.output_indexer(vars)), + loader(vars), + mode=self.scatter_mode, + ) + + +REDUCTION_COMBINE_FN: dict[str, Callable[..., OpsValue]] = { + "any": ops_wrapper("logical_or"), + "max": ops_wrapper("maximum"), + "min": ops_wrapper("minimum"), + "prod": ops_wrapper("mul"), + "sum": ops_wrapper("add"), + "dot": ops_wrapper("add"), + "xor_sum": ops_wrapper("bitwise_xor"), +} + + +def get_reduction_combine_fn( + reduction_type: str, dtype: torch.dtype, arg_break_ties_left: bool = True +) -> Callable[..., object]: + if reduction_type in REDUCTION_COMBINE_FN: + return REDUCTION_COMBINE_FN[reduction_type] + + elif reduction_type in ("argmax", "argmin"): + + def argmax_combine_fn( + a: tuple[object, object], b: tuple[object, object] + ) -> tuple[OpsValue, OpsValue]: + a_value, a_index = a + b_value, b_index = b + + if reduction_type == "argmin": + mask = ops.lt(a_value, b_value) + else: + mask = ops.gt(a_value, b_value) + 
+ equal = ops.eq(a_value, b_value) + if is_float_dtype(dtype): + a_isnan = ops.ne(a_value, a_value) + b_isnan = ops.ne(b_value, b_value) + mask = ops.logical_or(mask, ops.gt(a_isnan, b_isnan)) + equal = ops.logical_or(equal, ops.logical_and(a_isnan, b_isnan)) + + tie = ( + ops.lt(a_index, b_index) + if arg_break_ties_left + else ops.gt(a_index, b_index) + ) + mask = ops.logical_or(mask, ops.logical_and(equal, tie)) + return ( + ops.where(mask, a_value, b_value), + ops.where(mask, a_index, b_index), + ) + + return argmax_combine_fn + + elif reduction_type == "welford_combine": + + def welford_combine_fn( + a: tuple[OpsValue, OpsValue, OpsValue], + b: tuple[OpsValue, OpsValue, OpsValue], + ) -> tuple[OpsValue, OpsValue, OpsValue]: + a_mean, a_m2, a_weight = a + b_mean, b_m2, b_weight = b + + delta = b_mean - a_mean + new_weight = a_weight + b_weight + w2_over_w = b_weight / new_weight + return ( + a_mean + delta * w2_over_w, + a_m2 + b_m2 + delta * delta * a_weight * w2_over_w, + new_weight, + ) + + return welford_combine_fn + + else: + raise NotImplementedError(f"unknown reduction_type={reduction_type}") + + +@ir_dataclass +class Reduction(Loops): + reduction_ranges: Sequence[_IntLike] + reduction_type: ReductionType + # self.dtype represents the dst dtype + src_dtype: torch.dtype + reduction_hint: ReductionHint + + def __str__(self) -> str: + return self._to_str(("ranges", "reduction_ranges", "reduction_type")) + + __repr__ = __str__ + + @cache_on_self_and_args("Reduction") + def get_free_symbol_uses(self, unbacked_only: bool = False) -> OrderedSet[Symbol]: + return super().get_free_symbol_uses(unbacked_only) | OrderedSet().union( + *(get_free_symbols(e, unbacked_only) for e in self.reduction_ranges) + ) + + def get_reduction_size(self) -> Sequence[Expr]: + return self.reduction_ranges + + def get_reduction_type(self) -> Optional[str]: + return self.reduction_type + + def store_reduction( + self, + output_name: Optional[str], + indexer: Callable[[Sequence[Expr]], 
Never], + vars: Sequence[Expr], + reduction_vars: Sequence[Symbol], + ) -> None: + value = ops.reduction( + self.dtype, + self.src_dtype, + self.reduction_type, + self.inner_fn(vars, reduction_vars), + ) + ops.store_reduction(output_name or "unnamed", indexer(vars), value) + + def index_length(self) -> int: + return len(self.ranges) + len(self.reduction_ranges) + + def inner_fn_args(self) -> Sequence[Sequence[Expr]]: + index = self._index(self.ranges) + rindex = self._index(self.reduction_ranges, SymT.R0_INDEX) + return (index, rindex) + + def inner_fn_free_symbols(self, unbacked_only: bool = False) -> OrderedSet[Symbol]: + index = self._index(self.ranges) + rindex = self._index(self.reduction_ranges, SymT.R0_INDEX) + return extract_free_symbols( + self.inner_fn, index, rindex, unbacked_only=unbacked_only + ) + + def constant_to_device(self, device: torch.device) -> IRNode: + """Move this to a given device. Requires that all reads are to constants.""" + loader = self.make_loader() + loader = patch.object(ConstantBuffer, "override_device", device)(loader) + return Reduction( + device=device, + dtype=self.dtype, + inner_fn=loader, + ranges=self.ranges, + reduction_ranges=self.reduction_ranges, + reduction_type=self.reduction_type, + src_dtype=self.src_dtype, + reduction_hint=ReductionHint.DEFAULT, + ) + + @staticmethod + def num_splits( + device: torch.device, + dst_dtype: torch.dtype, + src_dtype: torch.dtype, + inner_fn: Callable[_P, OpsValue], + ranges: Sequence[_IntLike], + reduction_ranges: Sequence[_IntLike], + reduction_type: Union[ReductionType, Literal["scan"]], + reduction_numel: Expr, + input_node: Optional[IRNode] = None, + ) -> tuple[ReductionHint, _IntLike]: + reduction_numel_hint = V.graph.sizevars.symbolic_hint(reduction_numel) + numel_hint = V.graph.sizevars.symbolic_hint(sympy_product(ranges)) + + should_split = reduction_type == "scan" or ( + not V.graph.has_feature(device, BackendFeature.REDUCE_TO_SINGLE_ELEMENT) + and reduction_type + not in ( + 
"argmax", + "argmin", + ) + and config.split_reductions + ) + + if not (_is_static(reduction_numel_hint) and _is_static(numel_hint)): + # We don't support unbacked symints + return ReductionHint.DEFAULT, 1 + + if reduction_type == "dot": + # Don't split when doing native matmul + return ReductionHint.DEFAULT, 1 + + props = DeviceProperties.create(device) + num_sm = props.multi_processor_count + min_elements_per_thread = 32 + if should_split: + inner_reduction_splits: Callable[[int, int], int] = functools.partial( + V.choices.reduction_split_factor, device, inner_reduction=True + ) + outer_reduction_splits: Callable[[int, int], int] = functools.partial( + V.choices.reduction_split_factor, device, inner_reduction=False + ) + else: + + def inner_reduction_splits( + reduction_numel_hint: int, + numel_hint: int, + ) -> int: + return 1 + + outer_reduction_splits = inner_reduction_splits + + # easy cases + if numel_hint == 1: + split = inner_reduction_splits(reduction_numel_hint, numel_hint) + if split == 1: + # No need to split. + return ReductionHint.INNER, split + if input_node is not None and isinstance(input_node, TensorBox): + with patch.object(FlexibleLayout, "allow_indexing", True): + ( + new_ranges, + new_reduction_ranges, + ) = extract_input_node_reduction_ranges(input_node) + if new_ranges is not None and new_reduction_ranges is not None: + extracted_numel_hint = V.graph.sizevars.symbolic_hint( + sympy_product(new_ranges + new_reduction_ranges) + ) + if reduction_numel_hint == extracted_numel_hint: + log.debug( + "Use previous IRNode's range and reduction_ranges instead of split. " + "current ranges: %s, current reduction ranges: %s, current split: %d, " + "new ranges: %s, new reduction ranges: %s", + ranges, + reduction_ranges, + split, + new_ranges, + new_reduction_ranges, + ) + # If the input_node or its dependent nodes are also Reduction nodes, + # use reduction_sizes of this node or its dependent nodes directly. 
+ return ReductionHint.INNER, -1 + return ReductionHint.INNER, split + if ( + reduction_numel_hint <= min_elements_per_thread + or numel_hint >= num_sm * 2 * 32 + ): + return ReductionHint.DEFAULT, 1 + + r = Reduction( + device=device, + dtype=dst_dtype, + inner_fn=inner_fn, + ranges=ranges, + reduction_ranges=reduction_ranges, + reduction_type=reduction_type if reduction_type != "scan" else "sum", + src_dtype=src_dtype, + reduction_hint=ReductionHint.DEFAULT, + ) + + def get_read_indices(r: Reduction) -> tuple[Sequence[Expr], bool]: + device = r.get_device() + assert device is not None + cb = ComputedBuffer( + name=None, + layout=FlexibleLayout( + device=device, + dtype=r.get_dtype(), + size=r.get_size(), + ), + data=r, + ) + read_writes = cb.get_read_writes() + # try finding the full size producer + # TODO this will fail for something like ((1, N) * (N, 1)).sum() + # this would also possibly be wrong for producers with the different contiguity but we hope those cases are rare + assert read_writes.range_vars is not None + range_vars = [ + r + for r in read_writes.range_vars + if isinstance(r, Expr) and not isinstance(r, sympy.Number) + ] + indices = [] + changed = False + for md in sorted(read_writes.reads, key=lambda x: x.name): + if all(r in md.index.free_symbols for r in range_vars): + indices.append(md.index) + if md.name in V.graph.name_to_buffer: + buf = V.graph.name_to_buffer[md.name] + original_stride = getattr(buf.layout, "stride", None) + buf.decide_layout() + if getattr(buf.layout, "stride", None) != original_stride: + changed = True + return indices, changed + + indices, changed = get_read_indices(r) + if changed: + indices, _ = get_read_indices(r) + + if len(indices) == 0: + # TODO determine splits when all inputs are broadcast + return ReductionHint.DEFAULT, 1 + + (_, reduction_vars), ranges1 = dependencies.index_vars_squeeze( + r.get_size(), r.get_reduction_size() + ) + num_outer = 0 + num_inner = 0 + for i in indices: + j = 
V.graph.sizevars.simplify_with_ranges(i, ranges1) + strides = V.graph.sizevars.stride_hints( + j, reduction_vars, list(ranges1.keys()) + ) + # A 0 stride does not make a reduction contiguous. + # This can happen when the reduction ranges contains a 1. + outer = all(s == 0 or s > 1 for s in strides) + if outer: + num_outer += 1 + else: + num_inner += 1 + if num_inner > num_outer: + return ReductionHint.INNER, inner_reduction_splits( + reduction_numel_hint, numel_hint + ) + else: + return ReductionHint.OUTER, outer_reduction_splits( + reduction_numel_hint, numel_hint + ) + + @staticmethod + def _unroll_reduction_fn( + inner_fn: Callable[[Sequence[_IntLike], Sequence[_IntLike]], OpsValue], + reduction_ranges: Sequence[_IntLike], + reduction_type: str, + src_dtype: torch.dtype, + ) -> Callable[[Sequence[_IntLike]], OpsValue]: + """Convert inner_fn from a reduction to an pointwise""" + reduction_ranges = V.graph.sizevars.guard_int_seq(reduction_ranges) + + combine_fn = get_reduction_combine_fn(reduction_type, src_dtype) + + def fn(index: Sequence[_IntLike]) -> Any: + return functools.reduce( + combine_fn, + ( + value_fn(index, rindex) + for rindex in itertools.product( + *[range(x) for x in reduction_ranges] + ) + ), + ) + + value_fn: Callable[[Sequence[_IntLike], Sequence[_IntLike]], Any] + if reduction_type in ("argmin", "argmax"): + flatten_index = _fixed_indexer( + reduction_ranges, + FlexibleLayout.contiguous_strides(reduction_ranges), + ) + + def value_fn( + index: Sequence[_IntLike], rindex: Sequence[_IntLike] + ) -> tuple[OpsValue, OpsValue]: + rindex = [sympy.expand(i) for i in rindex] + return ( + inner_fn(index, rindex), + ops.index_expr(flatten_index(rindex), torch.int64), + ) + + return lambda index: fn(index)[1] + else: + value_fn = inner_fn + return fn + + @classmethod + # pyrefly: ignore [bad-override] + def create( + cls, + device: torch.device, + dst_dtype: torch.dtype, + src_dtype: torch.dtype, + inner_fn: Callable[..., Any], + ranges: Sequence[Expr], 
+ reduction_ranges: Sequence[Expr], + reduction_type: ReductionType, + reduction_hint: ReductionHint = ReductionHint.DEFAULT, + input_node: Optional[IRNode] = None, + ) -> Union[TensorBox, ShapeAsConstantBuffer]: + """ + Create a reduction node. May split the reduction to multiple layers to expose + more parallelism. + """ + reduction_numel = V.graph.sizevars.simplify(sympy_product(reduction_ranges)) + + if reduction_numel == 0: + # N.B. This is a hack to generate the literal of the given type + # Ideally, we should be fixing `def constant` in triton.py + # but it breaks due to hardcoded dtypes in other places + def py_cnst(val: object) -> Union[bool, float, int]: + if dst_dtype == torch.bool: + return bool(val) + elif dst_dtype.is_floating_point: + assert isinstance(val, SupportsFloat), type(val) + return float(val) + else: + assert isinstance(val, SupportsInt), type(val) + return int(val) + + rtypes_to_inits = { + "sum": py_cnst(0), + "xor_sum": py_cnst(0), + "prod": py_cnst(1), + "any": py_cnst(0), + # "all" is desugared to `!any(!val)` + } + + assert reduction_type in rtypes_to_inits, ( + f"{reduction_type} not supported for zero-dimension tensors!" 
+ ) + + def const_fn(index: int) -> OpsValue: + return ops.constant(rtypes_to_inits[reduction_type], dst_dtype) + + return Pointwise.create( + device=device, + dtype=src_dtype, + inner_fn=const_fn, + ranges=list(ranges), + ) + + if reduction_numel == 1: + # this reduction is actually a pointwise op + if reduction_type in ("argmin", "argmax"): + + def fn(index: int) -> OpsValue: + return ops.constant(0, dst_dtype) + + else: + + def fn(index: int) -> OpsValue: + reduction_index = [sympy.S.Zero for _ in reduction_ranges] + return inner_fn(index, reduction_index) + + return Pointwise.create( + device=device, dtype=dst_dtype, inner_fn=fn, ranges=ranges + ) + + if ( + isinstance(reduction_numel, Integer) + and V.graph.sizevars.size_hint_or_throw(reduction_numel) + < config.unroll_reductions_threshold + and (sympy_product(ranges) != 1 or is_gpu(device.type)) + and reduction_type != "dot" + ): + # When native matmul, don't unroll the dot reduction. + + # NB: This works around https://github.com/pytorch/pytorch/issues/140457 + # since turning reductions into pointwise ops can exacerbate this problem + return Pointwise.create( + device=device, + dtype=dst_dtype, + inner_fn=cls._unroll_reduction_fn( + inner_fn, reduction_ranges, reduction_type, src_dtype + ), + ranges=ranges, + ) + + # triton doesn't support reduce to single element well, so break it up + hint, split = cls.num_splits( + device, + dst_dtype, + src_dtype, + inner_fn, + ranges, + reduction_ranges, + reduction_type, + reduction_numel, + input_node, + ) + + def _maybe_increase_split(split: int) -> int: + # don't apply min_num_split constraint for static shape case. 
+ if _is_static(reduction_numel): + return split + if split > 1: + return max(split, config.min_num_split) + else: + return split + + split = _maybe_increase_split(split) + + # intermediate reduction in split can contain complex indexing, + # and num_splits will fail to correctly set the hint + # reuse the passed hint if available + if reduction_hint == ReductionHint.DEFAULT: + reduction_hint = hint + if split == -1: + assert input_node is not None + with patch.object(FlexibleLayout, "allow_indexing", True): + new_ranges, new_reduction_ranges = extract_input_node_reduction_ranges( + input_node + ) + assert new_ranges is not None + assert new_reduction_ranges is not None + return cls.create_multilayer_existing_ranges( + device, + dst_dtype, + src_dtype, + inner_fn, + ranges, + reduction_ranges, + new_ranges, + new_reduction_ranges, + reduction_type, + reduction_hint, + ) + elif split > 1: + # triton doesn't support reduce to single element well, so break it up + out = cls.create_multilayer( + device, + dst_dtype, + src_dtype, + inner_fn, + ranges, + reduction_ranges, + reduction_type, + split, + reduction_hint, + input_node, + ) + + # Find the reduction that get split + split_reduction = None + if config.triton.mix_order_reduction and isinstance(out, TensorBox): + + def _find_split_reduction( + cur_node: TensorBox, + ) -> Optional[ComputedBuffer]: + read_names = cur_node.get_read_names() + if len(read_names) != 1: + return None + + bufname = next(iter(read_names)) + if bufname not in V.graph.name_to_buffer: + return None + buf = V.graph.name_to_buffer[bufname] + if not isinstance(buf, ComputedBuffer): + return None + + assert buf.data.get_reduction_type() is not None + + return buf + + split_reduction = _find_split_reduction(out) + + if split_reduction: + # If a reduction is split to more than 2 layers, + # say there are 3 layers, + # we always have the correct setting for layer1 (top layer). 
+ # The setting on layer2 may be incorrect but it's fine + # since they are never get used. + # TODO: should we skip setting these fields for layer2 + assert isinstance(split_reduction.data, Reduction), ( + f"{type(split_reduction.data)}" + ) + split_reduction._split_size = split_reduction.data.reduction_ranges[0] + split_reduction._original_inner_fn = inner_fn + split_reduction._original_ranges = ranges + split_reduction._original_reduction_ranges = reduction_ranges + return out + + out = TensorBox.create( + Reduction( + device=device, + dtype=dst_dtype, + inner_fn=inner_fn, + ranges=ranges, + reduction_ranges=reduction_ranges, + reduction_type=reduction_type, + src_dtype=src_dtype, + reduction_hint=reduction_hint, + ) + ) + return out + + @staticmethod + def default_accumulator( + reduction_type: str, dtype: torch.dtype + ) -> Union[_NumLike, Sequence[_NumLike]]: + if reduction_type in ("max", "argmax"): + if is_float_dtype(dtype): + return float("-inf") + elif is_boolean_dtype(dtype): + return False + else: + return torch.iinfo(dtype).min + if reduction_type in ("min", "argmin"): + if is_float_dtype(dtype): + return float("inf") + elif is_boolean_dtype(dtype): + return True + else: + return torch.iinfo(dtype).max + + zero = False if is_boolean_dtype(dtype) else 0 + one = True if is_boolean_dtype(dtype) else 1 + return { + "sum": zero, + "prod": one, + "dot": zero, + "xor_sum": zero, + "any": zero, + "welford_reduce": (zero, zero, zero), + "welford_combine": (zero, zero, zero), + "online_softmax_reduce": (float("-inf"), zero), + }[reduction_type] + + @staticmethod + def default_value( + reduction_type: str, dtype: torch.dtype + ) -> Union[_NumLike, Sequence[_NumLike]]: + if reduction_type == "welford_reduce": + return 0 + return Reduction.default_accumulator(reduction_type, dtype) + + @staticmethod + def _multilayer_second_step_hint( + split: _IntLike, numel_hint: int, reduction_hint: ReductionHint + ) -> ReductionHint: + if split == -1: + return reduction_hint + 
if split <= 512 and numel_hint <= 512 and reduction_hint == ReductionHint.OUTER: + return ReductionHint.OUTER_TINY + if ( + split <= 1024 + and numel_hint <= 256 + and reduction_hint == ReductionHint.OUTER + ): + return ReductionHint.OUTER_TINY + + return reduction_hint + + @classmethod + def check_for_split_dense_dim_reindexing( + cls, reduction_numel: _IntLike, input_node: Optional[IRNode] + ) -> Optional[int]: + """ + If we are reducing over the full tensor, and it is non-dense in the last dimension, + reindex so we reduce over the dense dimension. initially just handle complete + reduction case + """ + if input_node is None: + return None + + if not V.graph.sizevars.statically_known_equals( + input_node.get_numel(), reduction_numel + ): + return None + + input_node.realize() + try: + # finalize layout + as_storage_and_layout(input_node) + except NotImplementedError: + return None + + strides = input_node.get_stride() + + for i, s in enumerate(strides[:-1]): + if V.graph.sizevars.statically_known_equals(s, 1): + return i + + return None + + @classmethod + def _multilayer_wrap_loader( + cls, + loader: Callable[..., OpsValue], + reduction_ranges: Sequence[_IntLike], + reduction_numel: _IntLike, + split: _IntLike, + block_size: _IntLike, + default: Union[_NumLike, Sequence[_NumLike]], + input_node: Optional[IRNode] = None, + ) -> Callable[..., object]: + dense_index = cls.check_for_split_dense_dim_reindexing( + reduction_numel, input_node + ) + reindex = View.dynamic_reshape_indexer( + reduction_ranges, [reduction_numel], dense_index + ) + need_mask = not V.graph.sizevars.statically_known_true( + sympy.Eq(reduction_numel % split, 0) + ) + + def wrapper_fn( + index: Sequence[Symbol], reduction_index: Sequence[Symbol] + ) -> OpsValue: + (reduction_index,) = reduction_index + *new_index, reduction_block = index + indices = block_size * reduction_block + reduction_index + + def body() -> OpsValue: + return loader(new_index, reindex([indices])) + + if need_mask: + 
index_dtype = dtype_from_size(reduction_numel) + mask = ops.lt( + ops.index_expr(indices, index_dtype), + ops.index_expr(reduction_numel, index_dtype), + ) + return ops.masked(mask, body, default) + else: + return body() + + return wrapper_fn + + @classmethod + def _multilayer_wrap_loader_existing_ranges( + cls, + loader: Callable[[Sequence[Expr], Sequence[Expr]], OpsValue], + original_ranges: Sequence[Expr], + original_reduction_ranges: Sequence[Expr], + new_ranges: Sequence[Integer], + new_reduction_ranges: Sequence[Integer], + ) -> Callable[[Sequence[sympy.Expr], Sequence[sympy.Expr]], OpsValue]: + assert all(r == 1 for r in original_ranges), ( + f"Only enabled for numel_hint == 1, found {original_ranges=}" + ) + reindex = View.dynamic_reshape_indexer( + original_reduction_ranges, tuple(new_ranges) + tuple(new_reduction_ranges) + ) + + def wrapper_fn( + merged_index: Sequence[Expr], + new_reduction_index: Sequence[Expr], + ) -> OpsValue: + original_idx = merged_index[: len(original_ranges)] + new_index = merged_index[len(original_ranges) :] + return loader( + original_idx, + reindex(tuple(new_index) + tuple(new_reduction_index)), + ) + + return wrapper_fn + + @classmethod + def create_multilayer_helper( + cls, + device: torch.device, + dst_dtype: torch.dtype, + src_dtype: torch.dtype, + wrapper_fn: Callable[..., Any], + original_ranges: Sequence[Expr], + original_reduction_ranges: Sequence[Expr], + new_ranges: list[Expr], + new_reduction_ranges: list[Integer], + reduction_type: ReductionType, + split: _IntLike, + reduction_hint: ReductionHint, + ) -> Union[TensorBox, ShapeAsConstantBuffer]: + """ + Break a large reduction up into multiple smaller reductions + recursively + """ + # triton will automatically compute reductions in fp32 if reducing over fp16/bf16 + # within the kernel. 
keep the intermediate in fp32 so as to keep the whole reduction + # in fp32 and not reduce precision by breaking up the kernel into multiple layers + intermediate_dtype = ( + dst_dtype + if dst_dtype not in (torch.float16, torch.bfloat16) + else torch.float + ) + intermediate = Reduction.create( + device, + intermediate_dtype, + src_dtype, + wrapper_fn, + new_ranges, + new_reduction_ranges, + reduction_type, + reduction_hint, + ) + intermediate.realize() + intermediate_loader = intermediate.make_loader() + + def intermediate_fn( + index: Sequence[_IntLike], reduction_index: Sequence[_IntLike] + ) -> OpsValue: + return intermediate_loader([*index, *reduction_index]) + + numel_hint = V.graph.sizevars.size_hint(sympy_product(original_ranges)) + reduction_hint = cls._multilayer_second_step_hint( + split, numel_hint, reduction_hint + ) + + assert original_ranges == new_ranges[: len(original_ranges)] + return TensorBox.create( + Reduction( + device=device, + dtype=dst_dtype, + inner_fn=intermediate_fn, + ranges=original_ranges, + reduction_ranges=new_ranges[len(original_ranges) :], + reduction_type=reduction_type, + src_dtype=src_dtype, + reduction_hint=reduction_hint, + ) + ) + + @classmethod + def create_multilayer( + cls, + device: torch.device, + dst_dtype: torch.dtype, + src_dtype: torch.dtype, + inner_fn: Callable[..., Any], + ranges: Sequence[Expr], + reduction_ranges: Sequence[Expr], + reduction_type: ReductionType, + split: _IntLike, + reduction_hint: ReductionHint, + input_node: Optional[IRNode] = None, + ) -> Union[TensorBox, ShapeAsConstantBuffer]: + """ + Break a large reduction up into multiple smaller reductions + recursively + """ + # TODO(jansel): realize the reduction so we can do dynamic indexing + reduction_numel = sympy_product(reduction_ranges) + block_size = FloorDiv(reduction_numel + (split - 1), split) + default = cls.default_value(reduction_type, dst_dtype) + wrapper_fn = cls._multilayer_wrap_loader( + inner_fn, + reduction_ranges, + 
reduction_numel, + split, + block_size, + default, + input_node, + ) + + return cls.create_multilayer_helper( + device, + dst_dtype, + src_dtype, + wrapper_fn, + ranges, + reduction_ranges, + [*ranges, split], + [block_size], + reduction_type, + split, + reduction_hint, + ) + + @classmethod + def create_multilayer_existing_ranges( + cls, + device: torch.device, + dst_dtype: torch.dtype, + src_dtype: torch.dtype, + inner_fn: Callable[..., Any], + original_ranges: Sequence[Expr], + original_reduction_ranges: Sequence[Expr], + new_ranges: list[Integer], + new_reduction_ranges: list[Integer], + reduction_type: ReductionType, + reduction_hint: ReductionHint, + ) -> Union[TensorBox, ShapeAsConstantBuffer]: + """ + Break a large reduction up into multiple smaller reductions + recursively + """ + wrapper_fn = cls._multilayer_wrap_loader_existing_ranges( + inner_fn, + original_ranges, + original_reduction_ranges, + new_ranges, + new_reduction_ranges, + ) + return cls.create_multilayer_helper( + device, + dst_dtype, + src_dtype, + wrapper_fn, + original_ranges, + original_reduction_ranges, + [*original_ranges, *new_ranges], + new_reduction_ranges, + reduction_type, + -1, + reduction_hint, + ) + + +def _fixed_indexer( + size: Sequence[int], + stride: Optional[Sequence[int]] = None, + offset: Expr = Integer(0), +) -> Callable[[Sequence[Expr]], Expr]: + """A closure containing math to read a given element""" + + def indexer(index: Sequence[int]) -> int: + assert stride is not None and len(index) == len(stride) + assert len(index) == len(size) + result = offset + for idx, st, sz in zip(index, stride, size): + if sz != 1: + result = result + idx * st + return result + + return indexer + + +INNER_FN_TY: TypeAlias = Callable[[Sequence[Expr], Sequence[Expr]], OpsValue] + + +class MultiOutputReduction(Reduction): + output_index: int + + def __init__( + self, + device: torch.device, + dst_dtype: torch.dtype, + inner_fns: Union[INNER_FN_TY, Sequence[INNER_FN_TY]], + ranges: 
Sequence[Integer], + reduction_ranges: Sequence[Integer], + reduction_type: ReductionType, + src_dtype: torch.dtype, + reduction_hint: ReductionHint, + output_index: int, + ): + if callable(inner_fns): + inner_fns = (inner_fns,) + + loader: Callable[[Sequence[Expr], Sequence[Expr]], Any] + if len(inner_fns) == 1: + loader = inner_fns[0] + else: + + def loader( + idx: Sequence[Expr], reduction_idx: Sequence[Expr] + ) -> tuple[OpsValue, ...]: + return tuple(fn(idx, reduction_idx) for fn in inner_fns) + + super().__init__( + device=device, + dtype=dst_dtype, + inner_fn=loader, + ranges=ranges, + reduction_ranges=reduction_ranges, + reduction_type=reduction_type, + src_dtype=src_dtype, + reduction_hint=reduction_hint, + ) + self.output_index = output_index + + def store_reduction( + self, + output_name: Optional[str], + indexer: Callable[[Sequence[Expr]], Never], + vars: Sequence[Expr], + reduction_vars: Sequence[Symbol], + ) -> Any: + values = ops.reduction( + self.dtype, + self.src_dtype, + self.reduction_type, + self.inner_fn(vars, reduction_vars), + ) + assert isinstance(values, (tuple, list)), type(values) + value = values[self.output_index] + return ops.store_reduction(output_name or "unnamed", indexer(vars), value) + + +class OnlineSoftmaxReduction(MultiOutputReduction): + @classmethod + def create( # type: ignore[override] + cls, + device: torch.device, + dst_dtype: torch.dtype, + src_dtype: torch.dtype, + inner_fn: Callable[..., Any], + ranges: Sequence[Expr], + reduction_ranges: Sequence[Expr], + num_output: int, + reduction_hint: ReductionHint = ReductionHint.DEFAULT, + input_node: Optional[IRNode] = None, + ) -> Sequence[Union[TensorBox, ShapeAsConstantBuffer]]: + """ + Create the reduction disregarding splitting. 
+ """ + results = tuple( + TensorBox.create( + MultiOutputReduction( + device, + dst_dtype, + inner_fn, + ranges, + reduction_ranges, + "online_softmax_reduce", + src_dtype, + reduction_hint, + output_idx, + ) + ) + for output_idx in range(num_output) + ) + for t in results: + t.realize() + return results + + +class WelfordReduction(MultiOutputReduction): + @classmethod + def create( # type: ignore[override] + cls, + device: torch.device, + dtype: torch.dtype, + inner_fns: Sequence[Callable[..., Any]], + ranges: list[Integer], + reduction_ranges: list[Integer], + reduction_type: ReductionType, + reduction_hint: ReductionHint = ReductionHint.DEFAULT, + ) -> Sequence[Union[TensorBox, ShapeAsConstantBuffer]]: + assert reduction_type in ("welford_reduce", "welford_combine") + + reduction_numel = V.graph.sizevars.simplify(sympy_product(reduction_ranges)) + + def const(val: int) -> Union[TensorBox, ShapeAsConstantBuffer]: + def inner_fn(idx: Sequence[Expr]) -> OpsValue: + return ops.constant( + val, + dtype, + ) + + return Pointwise.create( + device=device, + dtype=dtype, + inner_fn=inner_fn, + ranges=list(ranges), + ) + + if reduction_numel == 0: + mean = const(0) + m2 = const(0) + weight = const(0) + return mean, m2, weight + + if reduction_numel == 1: + + def copy( + loader: Callable[[Sequence[Expr], Sequence[Expr]], OpsValue], + ) -> Union[TensorBox, ShapeAsConstantBuffer]: + def inner_fn(idx: Sequence[Expr]) -> OpsValue: + reduction_index = [sympy.S.Zero for _ in reduction_ranges] + return loader(idx, reduction_index) + + return Pointwise.create( + device=device, + dtype=dtype, + inner_fn=inner_fn, + ranges=list(ranges), + ) + + if reduction_type == "welford_reduce": + return copy(inner_fns[0]), const(0), const(1) + else: + return tuple(copy(fn) for fn in inner_fns) + + # TODO: Unrolled reduction + # if ( + # isinstance(reduction_numel, Integer) + # and V.graph.sizevars.size_hint(reduction_numel) + # < config.unroll_reductions_threshold + # and sympy_product(ranges) 
!= 1 + # ): + # return Pointwise.create( + # device, + # dst_dtype, + # cls._unroll_reduction_fn( + # inner_fn, reduction_ranges, reduction_type, src_dtype, + # ), + # ranges, + # ) + + # triton doesn't support reduce to single element well, so break it up + hint, split = Reduction.num_splits( + device, + dtype, + dtype, + inner_fns[0], + ranges, + reduction_ranges, + reduction_type=reduction_type, + reduction_numel=reduction_numel, + ) + # intermediate reduction in split can contain complex indexing, + # and num_splits will fail to correctly set the hint + # reuse the passed hint if available + if reduction_hint == ReductionHint.DEFAULT: + reduction_hint = hint + if split > 1: + # triton doesn't support reduce to single element well, so break it up + return cls.create_multilayer( + device, + dtype, + inner_fns, + ranges, + reduction_ranges, + reduction_type, + split, + reduction_hint, + ) + + results = [ + TensorBox.create( + WelfordReduction( + device, + dtype, + inner_fns, + ranges, + reduction_ranges, + reduction_type, + dtype, + reduction_hint, + output_idx, + ) + ) + for output_idx in range(3) + ] + for t in results: + t.realize() + return results + + @staticmethod + def default_value( + reduction_type: str, dtype: torch.dtype + ) -> Union[_NumLike, Sequence[_NumLike]]: + return (0, 0, 0) + + @classmethod + def create_multilayer( # type: ignore[override] + cls, + device: torch.device, + dtype: torch.dtype, + inner_fns: Sequence[Callable[..., Any]], + ranges: list[Integer], + reduction_ranges: list[Integer], + reduction_type: ReductionType, + split: _IntLike, + reduction_hint: ReductionHint, + ) -> Sequence[Union[TensorBox, ShapeAsConstantBuffer]]: + """ + Break a large reduction up into multiple smaller reductions + recursively + """ + reduction_numel = sympy_product(reduction_ranges) + need_mask = not V.graph.sizevars.statically_known_true( + sympy.Eq(reduction_numel % split, 0) + ) + + if need_mask and reduction_type != "welford_combine": + # If we need 
mask, then "welford_reduce" doesn't work because + # masked inputs shouldn't count towards the welford weight + + def constant( + idx: Sequence[Expr], reduction_idx: Sequence[Expr], value: int + ) -> OpsValue: + return ops.constant(value, dtype) + + return cls.create_multilayer( + device=device, + dtype=dtype, + inner_fns=( + inner_fns[0], + partial(constant, value=0), + partial(constant, value=1), + ), + ranges=ranges, + reduction_ranges=reduction_ranges, + reduction_type="welford_combine", + split=split, + reduction_hint=reduction_hint, + ) + + block_size = FloorDiv(reduction_numel + (split - 1), split) + intermediates = WelfordReduction.create( + device, + dtype, + tuple( + cls._multilayer_wrap_loader( + loader, + reduction_ranges, + reduction_numel, + split, + block_size, + default=0, + ) + for loader in inner_fns + ), + [*ranges, split], + [block_size], + reduction_type, + reduction_hint, + ) + for i in intermediates: + i.realize() + + def intermediate_loader_fn( + index: Sequence[Expr], + reduction_index: Sequence[Expr], + loader: Callable[[Sequence[Expr]], OpsValue], + ) -> OpsValue: + return loader([*index, *reduction_index]) + + numel_hint = V.graph.sizevars.size_hint(sympy_product(ranges)) + reduction_hint = cls._multilayer_second_step_hint( + split, numel_hint, reduction_hint + ) + return WelfordReduction.create( + device, + dtype, + tuple( + partial(intermediate_loader_fn, loader=i.make_loader()) + for i in intermediates + ), + ranges, + [split], + # welford_reduce turns one input into three outputs, which are combined with welford_combine + "welford_combine", + reduction_hint, + ) + + +@ir_dataclass +class Scan(Loops): + scan_ranges: list[Integer] + size: list[Integer] + combine_fn: Callable[[tuple[Any, ...], tuple[Any, ...]], tuple[Any, ...]] + reindex: Callable[[Sequence[_IntLike], Sequence[_IntLike]], Sequence[_IntLike]] + reduction_hint: ReductionHint + output_index: int + # output_index indexes the following tuples + dtypes: tuple[torch.dtype, 
...] + inner_fns: tuple[Callable[..., Any], ...] + + # HACK we mimic reduction + + @cache_on_self_and_args("Scan") + def get_free_symbol_uses(self, unbacked_only: bool = False) -> OrderedSet[Symbol]: + # TODO: Can combine_fn/reindex close over unbacked symbols? If so, we + # need to explicitly represent the closure so we can pull out unbacked + # symbols here + return ( + super().get_free_symbol_uses(unbacked_only) + | OrderedSet().union( + *(get_free_symbols(e, unbacked_only) for e in self.scan_ranges) + ) + | OrderedSet().union( + *(get_free_symbols(e, unbacked_only) for e in self.size) + ) + ) + + def __post_init__(self) -> None: + assert len(self.ranges) + len(self.scan_ranges) == len(self.size) + super().__post_init__() + + def store_reduction( + self, + output_name: Optional[str], + indexer: Callable[[Sequence[_IntLike]], Never], + vars: Sequence[Expr], + scan_vars: Sequence[Symbol], + ) -> Any: + idx = self.reindex(vars, scan_vars) + values = tuple(inner_fn(idx) for inner_fn in self.inner_fns) + result = ops.scan(self.dtypes, self.combine_fn, values) + return ops.store( + output_name or "unnamed", indexer(idx), result[self.output_index] + ) + + def get_reduction_type(self) -> Optional[str]: + # return self.scan_op + return "custom" + + def get_reduction_size(self) -> Sequence[Expr]: + return self.scan_ranges + + def get_size(self) -> Sequence[Expr]: + return self.size + + def get_pointwise_size(self) -> Sequence[Expr]: + return self.ranges + + def index_length(self) -> int: + return len(self.ranges) + len(self.scan_ranges) + + def inner_fn_args(self) -> Sequence[Sequence[_IntLike]]: + index = self._index(self.ranges) + rindex = self._index(self.scan_ranges, SymT.R0_INDEX) + idx = self.reindex(index, rindex) + return (idx,) + + def inner_fn_free_symbols(self, unbacked_only: bool = False) -> OrderedSet[Symbol]: + index = self._index(self.ranges) + rindex = self._index(self.scan_ranges, SymT.R0_INDEX) + idx = self.reindex(index, rindex) + return 
extract_free_symbols(self.inner_fn, idx, unbacked_only=unbacked_only) + + @classmethod + def create( # type: ignore[override] + cls, + device: torch.device, + dtypes: tuple[torch.dtype, ...], + inner_fns: tuple[Callable[[Sequence[Expr]], Any], ...], + size: list[Integer], + axis: int, + combine_fn: Callable[[tuple[Any, ...], tuple[Any, ...]], tuple[Any, ...]], + reduction_hint: ReductionHint = ReductionHint.DEFAULT, + *, + # Whether we have the option to fallback to aten + can_fallback_to_aten: bool = True, + **kwargs: Any, + ) -> Sequence[Optional[Union[TensorBox, ShapeAsConstantBuffer]]]: + pointwise_ranges = [*size[:axis], *size[axis + 1 :]] + scan_ranges = [size[axis]] + + if not V.graph.has_feature(device, BackendFeature.SCAN): + return [None] * len(dtypes) + + if len(dtypes) > 1 and not V.graph.has_feature( + device, BackendFeature.TUPLE_REDUCTION + ): + return [None] * len(dtypes) + + sizevars = V.graph.sizevars + scan_numel = sizevars.simplify(sympy_product(scan_ranges)) + + assert len(dtypes) == len(inner_fns) + + # Scan with a single element is just a copy + if sizevars.statically_known_true(sympy.Le(scan_numel, 1)): + return [ + Pointwise.create( + device=device, + dtype=dtypes[output_index], + inner_fn=inner_fns[output_index], + ranges=size, + ) + for output_index in range(len(dtypes)) + ] + + reduction_hint, num_splits = cls.num_splits( + device=device, + dtype=dtypes[0], + inner_fn=inner_fns[0], + axis=axis, + pointwise_ranges=pointwise_ranges, + scan_ranges=scan_ranges, + combine_fn=combine_fn, + scan_numel=scan_numel, + ) + scan_type = Scan + if num_splits > 1: + supports_split = ( + # pyrefly: ignore [unsupported-operation] + torch.version.hip is None or (has_triton and triton_version >= "3.3.0") + ) and (len(dtypes) == 1) + if not supports_split: + if can_fallback_to_aten: + # Fallback to ATen + return [None] * len(dtypes) + else: + num_splits = 1 + else: + scan_type = SplitScan + + def reindex(index: Sequence[Expr], scan_index: Sequence[Expr]) -> 
list[Expr]: + assert len(scan_index) == len(scan_ranges) + assert len(index) == len(pointwise_ranges) + return [*index[:axis], *scan_index, *index[axis:]] + + results = [ + TensorBox.create( + scan_type( + device=device, + dtype=dtypes[output_index], + dtypes=dtypes, + inner_fn=inner_fns[output_index], + inner_fns=inner_fns, + size=size, + ranges=pointwise_ranges, + scan_ranges=scan_ranges, + combine_fn=combine_fn, + reindex=reindex, + reduction_hint=reduction_hint, + output_index=output_index, + **kwargs, + ) + ) + for output_index in range(len(dtypes)) + ] + + for result in results: + result.realize() + + return results + + @classmethod + def num_splits( + cls, + device: torch.device, + dtype: torch.dtype, + inner_fn: Callable[[Sequence[Expr]], OpsValue], + axis: int, + pointwise_ranges: list[Integer], + scan_ranges: list[Integer], + combine_fn: Callable[[tuple[Any, ...], tuple[Any, ...]], tuple[Any, ...]], + scan_numel: Expr, + ) -> tuple[ReductionHint, _IntLike]: + # TODO: custom splitting heuristic for scan + def wrapper_fn(idx: Sequence[Expr], reduction_idx: Sequence[Expr]) -> OpsValue: + return inner_fn([*idx[:axis], *reduction_idx, *idx[axis:]]) + + return Reduction.num_splits( + device=device, + dst_dtype=dtype, + src_dtype=dtype, + inner_fn=wrapper_fn, + ranges=pointwise_ranges, + reduction_ranges=scan_ranges, + reduction_type="scan", + reduction_numel=scan_numel, + ) + + +# This signifies a scan op that should go through TritonSplitScanKernel codegen on CUDA. +@ir_dataclass +class SplitScan(Scan): + pass + + +@ir_dataclass +class Sort(Loops): + # Sorts a tuple of key, value pairs + sort_ranges: list[Integer] + size: list[Integer] + reindex: Callable[[Sequence[Expr], Sequence[Expr]], Sequence[Expr]] + reduction_hint: ReductionHint + output_index: int + # output_index indexes the following tuples + dtypes: tuple[torch.dtype, ...] + inner_fns: tuple[Callable[..., Any], ...] 
+ + stable: bool + descending: bool + + # HACK we mimic reduction + + @cache_on_self_and_args("Sort") + def get_free_symbol_uses(self, unbacked_only: bool = False) -> OrderedSet[Symbol]: + return ( + super().get_free_symbol_uses(unbacked_only) + | OrderedSet().union( + *(get_free_symbols(e, unbacked_only) for e in self.sort_ranges) + ) + | OrderedSet().union( + *(get_free_symbols(e, unbacked_only) for e in self.size) + ) + ) + + def __post_init__(self) -> None: + assert len(self.ranges) + len(self.sort_ranges) == len(self.size) + super().__post_init__() + + def store_reduction( + self, + output_name: Optional[str], + indexer: Callable[[Sequence[Expr]], Expr], + vars: Sequence[Expr], + reduction_vars: Sequence[Expr], + ) -> Any: + idx = self.reindex(vars, reduction_vars) + values = tuple(inner_fn(idx) for inner_fn in self.inner_fns) + result = ops.sort(self.dtypes, values, self.stable, self.descending) + return ops.store( + output_name or "unnamed", indexer(idx), result[self.output_index] + ) + + def get_reduction_type(self) -> Optional[str]: + return "sort" + + def get_reduction_size(self) -> Sequence[Expr]: + return self.sort_ranges + + def get_size(self) -> Sequence[Expr]: + return self.size + + def get_pointwise_size(self) -> Sequence[Expr]: + return self.ranges + + def index_length(self) -> int: + return len(self.ranges) + len(self.sort_ranges) + + def inner_fn_args(self) -> Sequence[Sequence[Expr]]: + index = self._index(self.ranges) + rindex = self._index(self.sort_ranges, SymT.R0_INDEX) + idx = self.reindex(index, rindex) + return (idx,) + + def inner_fn_free_symbols(self, unbacked_only: bool = False) -> OrderedSet[Symbol]: + index = self._index(self.ranges) + rindex = self._index(self.sort_ranges, SymT.R0_INDEX) + idx = self.reindex(index, rindex) + return extract_free_symbols(self.inner_fn, idx, unbacked_only=unbacked_only) + + @classmethod + def create( # type: ignore[override] + cls, + device: torch.device, + dtypes: tuple[torch.dtype, ...], + inner_fns: 
tuple[Callable[[list[Expr]], Any], ...], + size: list[Integer], + axis: int, + stable: bool, + descending: bool, + reduction_hint: ReductionHint = ReductionHint.DEFAULT, + **kwargs: Any, + ) -> Sequence[Optional[Union[TensorBox, ShapeAsConstantBuffer]]]: + pointwise_ranges = [*size[:axis], *size[axis + 1 :]] + sort_ranges = [size[axis]] + + if not V.graph.has_feature(device, BackendFeature.SORT): + return [None] * len(dtypes) + + sizevars = V.graph.sizevars + sort_numel = sizevars.simplify(sympy_product(sort_ranges)) + + # Heuristic, smallest rblock where triton usually outperforms aten.sort + # It also isn't bandwidth bound so fusion is unlikely to help. + max_rblock = 512 + is_persistent_kernel = ( + config.triton.persistent_reductions + and sizevars.statically_known_true(sympy.Le(sort_numel, max_rblock)) + ) + if not is_persistent_kernel: + # We only support persistent triton kernels + return [None] * len(dtypes) + + assert len(dtypes) == len(inner_fns) + + # Sort with a single element is just a copy + if sizevars.statically_known_true(sympy.Le(sort_numel, 1)): + return [ + Pointwise.create( + device=device, + dtype=dtypes[output_index], + inner_fn=inner_fns[output_index], + ranges=size, + ) + for output_index in range(len(dtypes)) + ] + + def reindex(index: Sequence[Expr], sort_index: Sequence[Expr]) -> list[Expr]: + assert len(sort_index) == len(sort_ranges) + assert len(index) == len(pointwise_ranges) + return [*index[:axis], *sort_index, *index[axis:]] + + results = [ + TensorBox.create( + Sort( + device=device, + dtype=dtypes[output_index], + dtypes=dtypes, + inner_fn=inner_fns[output_index], + inner_fns=inner_fns, + size=size, + ranges=pointwise_ranges, + sort_ranges=sort_ranges, + reindex=reindex, + reduction_hint=reduction_hint, + output_index=output_index, + stable=stable, + descending=descending, + **kwargs, + ) + ) + for output_index in range(len(dtypes)) + ] + + for result in results: + result.realize() + + return results + + +def 
is_storage_and_layout(x: IRNode) -> bool: + try: + as_storage_and_layout(x, freeze=False) + return True + except NotImplementedError: + return False + + +def is_contiguous_storage_and_layout(x: IRNode) -> bool: + try: + _buffer, layout = as_storage_and_layout(x, freeze=False) + # pad the stride here so we will NOT claim an tensor as contiguous + # if a padding is gonna happen. + if layout.should_pad_strides(): + layout.pad_strides() + return layout.is_contiguous() + except NotImplementedError: + return False + + +def as_storage_and_layout( + x: IRNode, + freeze: bool = True, + want_contiguous: bool = False, + stride_order: Optional[Sequence[Union[int, Integer]]] = None, + allow_padding: bool = False, + exact_strides: Optional[Sequence[Union[int, Integer]]] = None, +) -> tuple[StorageBox, Layout]: + """ + Try to simplify x into a StorageBox and a Layout. + + allow_padding only affect how we apply stride_order. When allow_padding + is True, we have the freedom to add padding when applying the stride_order. 
+ """ + if isinstance(x, TensorBox): + return as_storage_and_layout( + x.data, + freeze=freeze, + want_contiguous=want_contiguous, + stride_order=stride_order, + allow_padding=allow_padding, + exact_strides=exact_strides, + ) + if isinstance(x, StorageBox): + _, layout = as_storage_and_layout( + x.data, + freeze=freeze, + want_contiguous=want_contiguous, + stride_order=stride_order, + allow_padding=allow_padding, + exact_strides=exact_strides, + ) + return x, x.data.get_layout() + if isinstance(x, Buffer): + if freeze: + if want_contiguous: + x.freeze_layout() + assert x.get_layout().is_contiguous() + elif stride_order is not None: + x.freeze_layout_with_stride_order( + stride_order, allow_padding=allow_padding + ) + elif exact_strides is not None: + x.freeze_layout_with_exact_strides( + exact_strides, allow_padding=allow_padding + ) + else: + x.decide_layout() + return StorageBox(x), x.get_layout() + if isinstance(x, ReinterpretView): + # making the base of x contiguous or stride_ordered will not necessarily make + # the ReinterpretView either, so don't pass along those arguments + buffer, _ = as_storage_and_layout( + x.data, + freeze=freeze, + ) + return buffer, x.layout + raise NotImplementedError + + +def is_stride_order_storage_and_layout( + x: IRNode, stride_order: Sequence[Union[int, Integer]] +) -> bool: + try: + _buffer, layout = as_storage_and_layout(x, freeze=False) + return layout.is_stride_ordered(stride_order) + except NotImplementedError: + return False + + +def is_unaligned(node: IRNode) -> bool: + if isinstance(node, (TensorBox, StorageBox)): + return is_unaligned(node.data) + + if isinstance(node, ReinterpretView): + layout = node.layout + has_unaligned_layout = not V.graph.sizevars.statically_known_multiple_of( + layout.offset * get_dtype_size(layout.dtype), GPU_ALIGN_BYTES + ) + return is_unaligned(node.data) or has_unaligned_layout + + if isinstance(node, Buffer): + return node.get_name() in V.graph.unaligned_buffers + + # assume to be aligned 
otherwise + return False + + +@ir_dataclass +class BaseView(IRNode): + data: IRNode + + @cache_on_self_and_args("BaseView") + def get_free_symbol_uses(self, unbacked_only: bool = False) -> OrderedSet[Symbol]: + return self.data.get_free_symbol_uses(unbacked_only) + + def make_reindexer(self) -> Callable[[Sequence[Expr]], Sequence[Expr]]: + raise NotImplementedError(f"make_reindexer NYI on {self}") + + def make_indexer(self) -> Callable[[Sequence[Expr]], Expr]: + inner = self.data.make_indexer() + reindex = self.make_reindexer() + + def indexer(idx: Sequence[Expr]) -> Expr: + return inner(reindex(idx)) + + return indexer + + def make_loader(self) -> Callable[[Sequence[Expr]], OpsValue]: + inner = self.data.make_loader() + reindex = self.make_reindexer() + + def loader(idx: Sequence[Expr]) -> OpsValue: + return inner(reindex(idx)) + + return loader + + @property + def dtype(self) -> torch.dtype: + return self.data.get_dtype() + + def get_layout(self) -> Layout: + return self.data.get_layout() + + def get_device(self) -> Optional[torch.device]: + return self.data.get_device() + + def get_origin_node(self) -> Optional[torch.fx.Node]: + return None + + def get_name(self) -> str: + return self.data.get_name() + + def get_pointwise_size(self) -> Sequence[Expr]: + return self.get_size() + + def mark_reuse(self, users: int) -> None: + return self.data.mark_reuse(users) + + def has_exceeded_max_reads(self) -> bool: + return self.data.has_exceeded_max_reads() + + def realize(self) -> Optional[str]: + return self.data.realize() + + def realize_hint(self) -> None: + self.data.realize_hint() + + def get_storage_numel(self) -> _IntLike: + return self.data.get_storage_numel() + + def is_extern(self) -> bool: + return self.data.is_extern() + + def is_module_buffer(self) -> bool: + assert isinstance(self.data, BaseView), type(self.data) + return self.data.is_module_buffer() + + def get_read_names(self) -> OrderedSet[str]: + return self.data.get_read_names() + + def get_reads(self) 
-> OrderedSet[Dep]: + with patch.object(FlexibleLayout, "allow_indexing", True): + return extract_read_writes( + self.make_loader(), + self.get_size(), + ).reads + + def unwrap_view(self) -> IRNode: + x: IRNode = self + while isinstance(x, BaseView): + x = x.data + return x + + def constant_to_device(self, device: torch.device) -> IRNode: + """Move this to a given device. Requires that all reads are to constants.""" + loader = self.make_loader() + loader = patch.object(ConstantBuffer, "override_device", device)(loader) + return Pointwise( + device=device, + dtype=self.get_dtype(), + inner_fn=loader, + ranges=self.get_size(), + ) + + +@ir_dataclass +class ExpandView(BaseView): + size: Sequence[Expr] + + @staticmethod + def _normalize_size(x: IRNode, new_size: Sequence[_IntLike]) -> Sequence[_IntLike]: + """Replace `-1` with correct sizes""" + sizevars = V.graph.sizevars + new_size = [sympy.expand(s) for s in new_size] + old_size = x.get_size() + old_size = [None] * (len(new_size) - len(old_size)) + list(old_size) + assert len(new_size) == len(old_size) + for i in range(len(new_size)): + if new_size[i] == -1: + assert old_size[i] is not None + new_size[i] = old_size[i] + elif old_size[i] is None or V.graph.sizevars.is_size_one_or_false( + old_size[i] + ): + pass + else: + # Sanity check: Expect broadcast compatibility + # + # NB: new_size[i] == old_size[i] is expected to already be + # guarded because the meta formula was expected to have taught + # us this equality. 
+ # pyrefly: ignore [unsupported-operation] + assert sizevars.size_hint(new_size[i] - old_size[i], fallback=0) == 0, ( + f"Broadcast failed in ExpandView({x.get_size()}, {new_size}) on dimension {i}" + ) + return new_size + + @classmethod + def create(cls, x: IRNode, new_size: Sequence[_IntLike]) -> BaseView: + new_size = cls._normalize_size(x, new_size) + + if is_storage_and_layout(x): + storage, old_layout = as_storage_and_layout(x) + skip = len(new_size) - len(old_layout.size) + assert skip >= 0 + new_stride = [sympy.S.Zero] * skip + for stride, size in zip(old_layout.stride, old_layout.size): + new_stride.append( + stride + if not V.graph.sizevars.is_size_one_or_false(size) + else sympy.S.Zero + ) + new_layout = FixedLayout( + old_layout.device, + old_layout.dtype, + list(new_size), + new_stride, + old_layout.offset, + old_layout.is_pinned, + ) + return ReinterpretView(data=storage, layout=new_layout) + + return ExpandView(data=x, size=new_size) + + def get_size(self) -> Sequence[Expr]: + return self.size + + def make_reindexer( + self, + ) -> Callable[[Sequence[Expr]], Sequence[Expr]]: + target = self.get_size() + actual = self.data.get_size() + skip = len(target) - len(actual) + + def reindex( + index: Sequence[Expr], + ) -> Sequence[Expr]: + index = list(index[skip:]) + assert len(index) == len(actual) + for i in range(len(actual)): + if actual[i] == 1: + # zero out broadcast dimension + index[i] = sympy.S.Zero + return index + + return reindex + + +@ir_dataclass +class PermuteView(BaseView): + dims: list[Expr] + + @classmethod + def create(cls, x: IRNode, dims: Sequence[int]) -> BaseView: + dims = cls._map_neg_dims(dims) + assert OrderedSet(dims) == OrderedSet(range(len(dims))) + + if is_storage_and_layout(x): + storage, old_layout = as_storage_and_layout(x) + new_layout = FixedLayout( + old_layout.device, + old_layout.dtype, + [old_layout.size[i] for i in dims], + [old_layout.stride[i] for i in dims], + old_layout.offset, + old_layout.is_pinned, + ) + 
return ReinterpretView(data=storage, layout=new_layout) + + return PermuteView(data=x, dims=dims) + + @classmethod + def _map_neg_dims(cls, dims: Sequence[int]) -> list[int]: + return [dim if dim >= 0 else len(dims) + dim for dim in dims] + + def get_size(self) -> Sequence[Expr]: + assert OrderedSet(self._map_neg_dims(self.dims)) == OrderedSet( + range(len(self.dims)) + ) + size = self.data.get_size() + return [size[i] for i in self.dims] + + def make_reindexer( + self, + ) -> Callable[[Sequence[Expr]], Sequence[Expr]]: + inv = {j: i for i, j in enumerate(self.dims)} + inv = [inv[i] for i in range(len(self.dims))] + assert OrderedSet(inv) == OrderedSet(range(len(self.dims))) + + def reindex( + index: Sequence[Expr], + ) -> Sequence[Expr]: + return [index[i] for i in inv] + + return reindex + + +@ir_dataclass +class SqueezeView(BaseView): + @classmethod + def create(cls, x: IRNode, *, dim: Optional[int] = None) -> IRNode: + if is_storage_and_layout(x): + storage, old_layout = as_storage_and_layout(x) + new_size = [] + new_stride = [] + if dim is not None: + assert isinstance(dim, int), type(dim) + assert 0 <= dim and dim < len(old_layout.size) + + for i, (size, stride) in enumerate(zip(old_layout.size, old_layout.stride)): + if dim is None: + # Only append if dim is not squeezed out + if not V.graph.sizevars.is_size_one_or_false(size): + new_size.append(size) + new_stride.append(stride) + else: + if i != dim: + new_size.append(size) + new_stride.append(stride) + else: + assert size == 1, "expected squeezed size to be 1" + + new_layout = FixedLayout( + old_layout.device, + old_layout.dtype, + new_size, + new_stride, + old_layout.offset, + old_layout.is_pinned, + ) + return ReinterpretView(data=storage, layout=new_layout) + + if dim is None: + return View.create( + x, + [ + s + for s in x.get_size() + if not V.graph.sizevars.is_size_one_or_false(s) + ], + ) + else: + assert x.get_size()[dim] == 1 + return View.create(x, [s for i, s in enumerate(x.get_size()) if i != 
dim]) + + @staticmethod + def squeezer( + size: Sequence[Expr], + ) -> tuple[list[int], Callable[[Sequence[Expr]], tuple[Expr, ...]]]: + new_size = [s for s in size if s != 1] + not_one = [i for i, s in enumerate(size) if s != 1] + length = len(size) + + def reindex(index: Sequence[Expr]) -> tuple[Expr, ...]: + assert len(index) == len(not_one), f"{index} {not_one}" + new_index: list[Expr] = [sympy.S.Zero] * length + for idx, s in zip(not_one, index): + new_index[idx] = s + return tuple(new_index) + + return new_size, reindex + + def __init__(self, data: Any) -> None: + raise AssertionError("use SqueezeView.create()") + + +@ir_dataclass +class GenericView(BaseView): + size: Sequence[Expr] + reindex: Callable[[Sequence[Expr]], Sequence[Expr]] + + def make_reindexer( + self, + ) -> Callable[[Sequence[Expr]], Sequence[Expr]]: + return self.reindex + + def reindex_str(self) -> str: + index_old = [ + sympy_index_symbol_with_prefix(SymT.INDEX, n) for n in range(len(self.size)) + ] + index_new = list(self.reindex(index_old)) + return f"lambda {', '.join(map(str, index_old))}: {index_new}" + + def __str__(self) -> str: + return self.str_helper( + [self.data, f"size={self.size}", f"reindex={self.reindex_str()}"] + ) + + __repr__ = __str__ + + @classmethod + def create( + cls, + x: IRNode, + new_size: Sequence[Expr], + reindex: Callable[[Sequence[Expr]], Sequence[Expr]], + ) -> BaseView: + return cls(data=x, size=list(new_size), reindex=reindex) + + def get_size(self) -> Sequence[Expr]: + return self.size + + +@ir_dataclass +class View(GenericView): + @staticmethod + def handle_negative_index(idx: Expr, size: Expr) -> Expr: + idx = sympy.expand(idx) + size = sympy.expand(size) + evaluate_expr = V.graph.sizevars.shape_env.evaluate_expr + if evaluate_expr(sympy.Lt(idx, 0)): + idx = idx + size + return idx + + @classmethod + def create(cls, x: IRNode, new_size: Sequence[Expr]) -> IRNode: # type: ignore[override] + assert isinstance(new_size, Sequence), type(new_size) + 
old_size, new_size = cls.resolve_negative_size(x.get_size(), new_size) + + # Skip pointless views + if V.graph.sizevars.statically_known_list_equals(old_size, new_size): + return x + + unbacked_symbols_in_sizes = False + if ( + len(free_unbacked_symbols(old_size)) > 0 + or len(free_unbacked_symbols(new_size)) > 0 + ): + unbacked_symbols_in_sizes = True + + if 0 in new_size: + + def fake_reindex(index: Any) -> tuple[int, ...]: + return tuple([0] * len(old_size)) + + return cls(data=x, size=list(new_size), reindex=fake_reindex) + # TODO: a new class for FixedTransferLayout that output layout is constrained by input layout + elif is_contiguous_storage_and_layout(x) or unbacked_symbols_in_sizes: + if unbacked_symbols_in_sizes and (not is_contiguous_storage_and_layout(x)): + # realize x; otherwise, the dynamic_reshape_indexer below will fail + # due to the size_hint's inability to process unbacked SymInts + # TODO: unbacked should not diverge from backed in determining striding + # Need to require contiguous here instead of realize, see: + # https://github.com/pytorch/pytorch/issues/145561 + x = ExternKernel.require_contiguous(x) + + storage, old_layout = as_storage_and_layout(x, want_contiguous=True) + new_layout = FixedLayout( + old_layout.device, + old_layout.dtype, + new_size, + FlexibleLayout.contiguous_strides(new_size), + old_layout.offset, + old_layout.is_pinned, + ) + return ReinterpretView(data=storage, layout=new_layout) + + reindex = cls.dynamic_reshape_indexer(old_size, new_size) + return cls(data=x, size=list(new_size), reindex=reindex) + + @staticmethod + def resolve_negative_size( + old_size: Sequence[Expr], new_size: Sequence[Expr] + ) -> tuple[list[Expr], list[Expr]]: + new_size = [V.graph.sizevars.simplify(x) for x in new_size] + old_size = [V.graph.sizevars.simplify(x) for x in old_size] + + new_size = list(new_size) + for i in range(len(new_size)): + if new_size[i] == -1: + new_size[i] = sympy.S.One + new_size[i] = CleanDiv(sympy_product(old_size), 
sympy_product(new_size)) + break + + V.graph.sizevars.check_equals(sympy_product(old_size), sympy_product(new_size)) + return old_size, new_size + + @classmethod + def dynamic_reshape_indexer( + cls, + old_size: Sequence[_IntLike], + new_size: Sequence[_IntLike], + dense_dim: Optional[int] = None, + ) -> Callable[[Sequence[_T]], Sequence[_V]]: + try: + reindex = cls._dynamic_reshape_indexer(old_size, new_size, dense_dim) + except (AssertionError, IndexError): + # optimistic algorithm failed, lets do a fallback + flat = [sympy_product(old_size)] + reindex1 = cls._dynamic_reshape_indexer(old_size, flat) + reindex2 = cls._dynamic_reshape_indexer(flat, new_size) + reindex = fuse_reindexing(reindex1, reindex2) + return reindex + + @staticmethod + def _dynamic_reshape_indexer( + old_size: Sequence[Expr], + new_size: Sequence[Expr], + dense_dim: Optional[int] = None, + ) -> Callable[[Sequence[Expr]], Sequence[Expr]]: + """ + Perform a reshape entirely by modifying indexing math + """ + size_hint = V.graph.sizevars.size_hint + # TODO: These symbols may not escape, if they don't assert so and + # treat them as temporary + vars = [ + sympy_index_symbol_with_prefix(SymT.VIEW, i) for i in range(len(new_size)) + ] + + stack_new = list(zip(vars, new_size)) + stack_old = list(old_size) + + # process the dense dim first + reordering_dense_dim = ( + dense_dim is not None + and dense_dim != len(stack_old) - 1 + and len(new_size) == 1 + ) + if reordering_dense_dim: + assert dense_dim is not None # mypy + old_dim = stack_old.pop(dense_dim) + stack_old.append(old_dim) + + view_expr = [] + while stack_new and stack_old: + size_old = stack_old.pop() + var, size_new = stack_new.pop() + if size_old == 1: + view_expr.append(sympy.S.Zero) + stack_new.append((var, size_new)) # re-add + elif size_new == 1: + stack_old.append(size_old) # re-add + elif size_hint(size_new) == size_hint(size_old): + view_expr.append(var) + V.graph.sizevars.check_equals(size_new, size_old) + elif 
size_hint(size_new) < size_hint(size_old): + while size_hint(size_new) < size_hint(size_old): + var2, size_new2 = stack_new.pop() + var = var2 * size_new + var + size_new = size_new * size_new2 + view_expr.append(var) + V.graph.sizevars.check_equals(size_new, size_old) + elif size_hint(size_new) > size_hint(size_old): + divisor = sympy.S.One + modulus = size_old + view_expr.append(ModularIndexing(var, divisor, modulus)) + divisor = divisor * modulus + while size_hint(size_new) > size_hint(size_old): + modulus = stack_old.pop() + view_expr.append(ModularIndexing(var, divisor, modulus)) + divisor = divisor * modulus + size_old = size_old * modulus + V.graph.sizevars.check_equals(size_new, size_old) + else: + raise AssertionError + + while stack_old: + size_old = stack_old.pop() + V.graph.sizevars.check_equals(size_old, 1) + view_expr.append(sympy.S.Zero) + + while stack_new: + var, size_new = stack_new.pop() + V.graph.sizevars.check_equals(size_new, 1) + + if dense_dim is not None and len(new_size) == 1: + view_expr.reverse() + # Move the last expression (dense dim) to its original position + dense_expr = view_expr.pop() + view_expr.insert(dense_dim, dense_expr) + else: + view_expr.reverse() + + assert len(view_expr) == len(old_size) + + def reindex( + index: Sequence[Expr], + ) -> Sequence[Expr]: + assert len(index) == len(vars), (len(index), len(vars)) + replacements = dict(zip(vars, index)) + return tuple(sympy_subs(x, replacements) for x in view_expr) + + return reindex + + +@ir_dataclass +class ReinterpretView(BaseView): + """Pretend our storage has a different layout""" + + layout: Layout + + def __post_init__(self) -> None: + super().__post_init__() + if isinstance(self.data, BaseView): + object.__setattr__(self, "data", self.data.unwrap_view()) + + def __str__(self) -> str: + return self.str_helper( + [ + self.data, + self.layout, + ] + ) + + __repr__ = __str__ + + def get_name(self) -> str: + return self.data.get_name() + + def get_device(self) -> 
Optional[torch.device]: + return self.layout.device + + def get_origin_node(self) -> Optional[torch.fx.Node]: + return None + + @property + def dtype(self) -> torch.dtype: + return self.layout.dtype + + def get_size(self) -> Sequence[Expr]: + return list(self.layout.size) + + def get_stride(self) -> Sequence[Expr]: + return list(self.layout.stride) + + def make_loader(self) -> Callable[[Sequence[Expr]], OpsValue]: + def loader(index: Sequence[Expr]) -> OpsValue: + indexer = self.layout.make_indexer() + tmp_loader = ops.load(self.get_name(), indexer(index)) + if self.layout.dtype != self.data.dtype: + return ops.to_dtype_bitcast(tmp_loader, self.dtype, self.data.dtype) + else: + return tmp_loader + + return loader + + def make_indexer(self) -> Callable[[Sequence[Expr]], Expr]: + return self.layout.make_indexer() + + def get_layout(self) -> Layout: + return self.layout + + def freeze_layout(self) -> None: + pass + + @cache_on_self_and_args("ReinterpretView") + def get_free_symbol_uses( + self, unbacked_only: bool = False + ) -> OrderedSet[sympy.Symbol]: + return ( + get_free_symbols(self.layout.size, unbacked_only) + | get_free_symbols(self.layout.stride, unbacked_only) + | get_free_symbols(self.layout.offset, unbacked_only) + ) + + def codegen_reference(self, writer: Optional[IndentedBuffer] = None) -> str: + # reinterpret_tensor is similar to as_strided except: + # - offset is added to the existing offset (rather than replacing it) + # - view tracking is disabled similar to unsafe_view + return V.graph.wrapper_code.codegen_reinterpret_view( + self.data, + self.layout.size, + self.layout.stride, + self.layout.offset, + writer.writeline if writer is not None else V.graph.wrapper_code.writeline, + dtype=self.layout.dtype, + ) + + def num_reads(self) -> int: + return 1 + + +@ir_dataclass +class DtypeView(BaseView): + """Pretend our storage has a different type""" + + target_dtype: torch.dtype + + @classmethod + def create(cls, x: IRNode, new_dtype: torch.dtype) -> 
BaseView: + if is_storage_and_layout(x): + storage, old_layout = as_storage_and_layout(x) + new_layout = FixedLayout( + old_layout.device, + new_dtype, + old_layout.size, + old_layout.stride, + old_layout.offset, + old_layout.is_pinned, + ) + return ReinterpretView(data=storage, layout=new_layout) + return DtypeView(data=x, target_dtype=new_dtype) + + def __str__(self) -> str: + return self.str_helper([self.data, self.target_dtype]) + + __repr__ = __str__ + + @property + def dtype(self) -> torch.dtype: + return self.target_dtype + + def get_size(self) -> Sequence[Expr]: + return self.data.get_size() + + def make_loader(self) -> Callable[[Sequence[Expr]], OpsValue]: + inner = self.data.make_loader() + + def loader(idx: Sequence[Expr]) -> OpsValue: + return ops.to_dtype_bitcast(inner(idx), self.target_dtype, self.data.dtype) + + return loader + + +class SliceView(View): + @classmethod + def normalize_start_end( + cls, x: IRNode, dim: int, start: int, end: int + ) -> tuple[int, int]: + """ + Normalize start and end such that both are in the range + [0, x.get_size()[dim]] and start <= end. + """ + sizevars = V.graph.sizevars + dim_size = x.get_size()[dim] + + if any(free_unbacked_symbols(x) for x in (start, end, dim_size)): + min_func = sympy.Min + max_func = sympy.Max + else: + min_func = sizevars.evaluate_min + max_func = sizevars.evaluate_max + + def clamp(x: Expr, lower: int, upper: int) -> Expr: + clamped_lower = ( + x if sizevars.statically_known_geq(x, lower) else max_func(x, lower) + ) + clamped_full = ( + clamped_lower + if sizevars.statically_known_leq(clamped_lower, upper) + else min_func(clamped_lower, upper) + ) + return clamped_full + + def clamp_wrap( + val: Union[int, None], lower: int, upper: int, default: Union[Expr, int] + ) -> Union[Expr, int]: + if val is None: + # TODO(rec): can this really happen? 
+ return default + val = cls.handle_negative_index(val, dim_size) + return clamp(val, lower, upper) + + start = clamp_wrap(start, 0, dim_size, 0) + end = clamp_wrap(end, start, dim_size, dim_size) + return start, end + + @classmethod + def create( # type: ignore[override] + cls, + x: IRNode, + dim: int, + start: int, + end: int, + step: int = 1, + clamp: bool = True, + ) -> IRNode: + step = sympy.expand(step) + assert isinstance(step, Expr) or step > 0, step + try: + if start == 0 and end >= 2**63 - 1 and step == 1: + return x + except TypeError: + pass + + new_size = list(x.get_size()) + + # NB: Ordinarily we default to clamping. + # We only don't clamp for split_with_sizes. For split_with_sizes, sizes should be already valid + # failing in this situation is ok, since invalid sizes could trigger silent errors. + if clamp: + start, end = cls.normalize_start_end(x, dim, start, end) + + new_size[dim] = FloorDiv(end - start + (step - 1), step) + + if is_storage_and_layout(x): + # Fast path + storage, old_layout = as_storage_and_layout(x) + new_stride = list(old_layout.stride) + new_stride[dim] = new_stride[dim] * step + new_layout = FixedLayout( + old_layout.device, + old_layout.dtype, + new_size, + new_stride, + old_layout.offset + old_layout.stride[dim] * start, + old_layout.is_pinned, + ) + return ReinterpretView(data=storage, layout=new_layout) + + def reindex( + index: Sequence[Expr], + ) -> Sequence[Expr]: + assert len(index) == len(new_size), f"wrong ndim {index} {new_size}" + index = list(index) + index[dim] = index[dim] * step + start + return index + + # redirect to a generic view + return SliceView(data=x, size=new_size, reindex=reindex) + + +@ir_dataclass +class BaseConstant(IRNode): + dtype: torch.dtype + device: torch.device + + def get_size(self) -> Sequence[Expr]: + return () + + def get_device(self) -> Optional[torch.device]: + return self.device + + def get_origin_node(self) -> Optional[torch.fx.Node]: + return None + + def get_reads(self) -> 
OrderedSet[Dep]: + return OrderedSet() + + +@ir_dataclass +class Constant(BaseConstant): + value: Any + dtype: torch.dtype + device: torch.device + + def make_loader(self) -> Callable[[Sequence[Expr]], OpsValue]: + def loader(index: Sequence[Expr]) -> OpsValue: + return ops.constant(self.value, self.dtype) + + return loader + + def realize(self) -> Optional[str]: + pass + + def constant_to_device(self, device: torch.device) -> IRNode: + return Constant(value=self.value, dtype=self.dtype, device=device) + + +@ir_dataclass +class IndexingConstant(BaseConstant): + index: Any + dtype: torch.dtype + device: torch.device + + def make_loader(self) -> Callable[[Sequence[Expr]], OpsValue]: + def loader(index: Sequence[Expr]) -> OpsValue: + return ops.index_expr(self.index, self.dtype) + + return loader + + def constant_to_device(self, device: torch.device) -> IRNode: + return IndexingConstant(index=self.index, dtype=self.dtype, device=device) + + +def is_contiguous_strides_for_shape( + stride: Sequence[_IntLike], shape: Sequence[_IntLike] +) -> bool: + expected_stride = 1 + expected_stride_max = 1 + for x, y in reversed(tuple(zip(shape, stride))): + if x == 1: + continue + + if not V.graph.sizevars.statically_known_equals( + y, expected_stride + ) and not V.graph.sizevars.statically_known_equals(y, expected_stride_max): + return False + + expected_stride_max *= sympy.Max(1, x) + expected_stride *= x + + return True + + +def get_align_for_dtype(dtype: torch.dtype) -> int: + return config.padding_alignment_bytes // dtype.itemsize + + +class OutputSpec: + """Abstract base for Layout, MultiOutputLayout, NoneLayout. 
+ Represents the memory layout of the output of an Operation.""" + + def get_device(self) -> Optional[torch.device]: + raise NotImplementedError(type(self).__name__) + + def storage_size(self) -> int: + raise NotImplementedError(type(self).__name__) + + def get_free_symbol_uses( + self, unbacked_only: bool = False + ) -> OrderedSet[sympy.Symbol]: + raise NotImplementedError(type(self).__name__) + + +@ir_dataclass +class Layout(OutputSpec): + """ + Layout base class + + Carries tensor meta-information including offset and + whether it is pinned. + """ + + def __init__( + self, + device: torch.device, + dtype: torch.dtype, + size: Sequence[Expr], + stride: Optional[Sequence[Expr]] = None, + offset: Expr = Integer(0), + is_pinned: bool = False, + ) -> None: + if stride is None: + stride = FlexibleLayout.contiguous_strides(size) + # pyrefly: ignore [read-only] + self.device = device + self.dtype = dtype + assert len(size) == len(stride), f"size={size}, stride={stride}" + assert all(isinstance(s, (Expr, int)) for s in size) + self._size = size + self._stride = stride + self._offset = offset + self.is_pinned = is_pinned + # is_pinned implies cpu + assert (not self.is_pinned) or (self.device.type == "cpu") + + @property + def size(self) -> Sequence[Expr]: + return self._size + + @size.setter + def size(self, value: Sequence[Expr]) -> None: + self._size = value + + @property + def stride(self) -> Sequence[Expr]: + return self._stride + + @stride.setter + def stride(self, value: Sequence[Expr]) -> None: + self._stride = value + + @property + def offset(self) -> Expr: + return self._offset + + @offset.setter + def offset(self, value: Expr) -> None: + self._offset = value + + def __str__(self) -> str: + offset = "" + if self.offset != 0: + offset = f", offset={self.offset}" + + device_index_str = "" if self.device.index is None else f":{self.device.index}" + is_pinned_str = "" + if self.is_pinned: + is_pinned_str = f", is_pinned={self.is_pinned}" + return ( + 
f"{type(self).__name__}('{self.device.type}{device_index_str}', {self.dtype}, " + f"size={self.size}, stride={self.stride}{offset}{is_pinned_str})" + ) + + __repr__ = __str__ + + def get_device(self) -> torch.device: + return self.device + + def get_example(self) -> torch.Tensor: + with V.fake_mode: + return torch.empty_strided( + convert_shape_to_symint(self.size), + convert_shape_to_symint(self.stride), + dtype=self.dtype, + device=self.device, + pin_memory=self.is_pinned, + ) + + def is_contiguous(self) -> bool: + return is_contiguous_strides_for_shape(self.stride, self.size) + + @staticmethod + def is_channels_last_contiguous( + shape: Sequence[_IntLike], strides: Sequence[_IntLike] + ) -> bool: + ndim = len(shape) + if ndim not in [4, 5] or shape[1] == 1: + return False + for left, right, size in zip( + strides, make_channels_last_strides_for(shape), shape + ): + if size != 1 and left != right: + return False + return True + + def is_transposed(self) -> bool: + for left, right, size in zip( + self.stride, + reversed(FlexibleLayout.contiguous_strides(list(reversed(self.size)))), + self.size, + ): + if size != 1 and left != right: + return False + return True + + def is_stride_ordered(self, order: Sequence[int]) -> bool: + assert len(self.stride) == len(order) + + # ignore dimensions of size 1, they dont affect layout + non_1_indices = [ + i + for i, dim in enumerate(self.size) + if V.graph.sizevars.size_hint(dim, fallback=2) != 1 + ] + + stride = [self.stride[i] for i in non_1_indices] + order: Sequence[int] = [order[i] for i in non_1_indices] + + def sorted_indices(arr: Sequence[int]) -> Sequence[int]: + sorted_arr = sorted(arr) + return [sorted_arr.index(element) for element in arr] + + # since we may have removed dimensions, need to re-sort & re-index order + order = sorted_indices(order) + + # reorder the stride given order + stride_ordered = [-1] * len(order) + for i in range(len(order)): + stride_ordered[order[i]] = stride[i] + # check if it is in 
ascending order + for i in range(len(order) - 1): + expr = stride_ordered[i] > stride_ordered[i + 1] + if not isinstance(expr, bool): + expr = V.graph._shape_env.evaluate_expr( + stride_ordered[i] > stride_ordered[i + 1], size_oblivious=True + ) + if expr: + return False + return True + + def is_channels_last_stride_ordered(self) -> bool: + # create channels_last order(NCHW, NCDHW, the C is the first order). + order = [0] + list(reversed(range(1, len(self.stride) - 1))) + order = [len(order)] + order + return self.is_stride_ordered(order) + + @staticmethod + def _pad_strides( + in_strides: Sequence[int], size: Sequence[Expr], dtype: torch.dtype + ) -> Sequence[int]: + """ + The padding does not change stride order but makes sure all strides larger + than the threshold are multiple of align. + """ + align = get_align_for_dtype(dtype) + if len(in_strides) == 0: + return in_strides + + if not config.pad_channels_last and Layout.is_channels_last_contiguous( + size, in_strides + ): + return in_strides + + current_fx_node = V.get_current_node() + if hasattr(current_fx_node, "meta") and current_fx_node.meta.get( + "dislike_padding", False + ): + return in_strides + + # Skip padding the strides for dynamic shapes based on config.pad_dynamic_shape + # Checking both shape and strides, as there are cases where only one is dynamic + is_dynamic = not all( + isinstance(s, (int, sympy.Integer)) + for s in itertools.chain(in_strides, size) + ) + if not config.pad_dynamic_shapes and is_dynamic: + return in_strides + + shape_env = V.graph._shape_env if hasattr(V.graph, "_shape_env") else None + + def contains_unbacked_symints(expr: sympy.Expr | int) -> bool: + if shape_env is None: + return False + if not isinstance(expr, sympy.Expr): + return False + return any(shape_env.is_unbacked_symint(s) for s in expr.free_symbols) + + # Skip padding the strides when it contains unbacked symints for now. 
+ if shape_env and any(contains_unbacked_symints(s) for s in in_strides): + return in_strides + + stride_order = get_stride_order(in_strides, shape_env) + fill_order = stride_order2fill_order(stride_order) + + new_strides = [0 for _ in range(len(in_strides))] + # since we pad when the layout is flexible, we can decide the + # smallest stride to be 1. + new_strides[fill_order[0]] = 1 + + padded = False + for rank, idx in enumerate(fill_order[1:], start=1): + prev_idx = fill_order[rank - 1] + stride = new_strides[prev_idx] * size[prev_idx] + # Static stride and meets padding conditions OR + # Dynamic stride and config.pad_dynamic_shape=True + require_padding = ( + isinstance(stride, (int, sympy.Integer)) + and stride > config.padding_stride_threshold + and stride % align != 0 + ) or (isinstance(stride, sympy.Expr) and config.pad_dynamic_shapes) + new_strides[idx] = stride + if require_padding: + new_strides[idx] = ceildiv(stride, align) * align + padded = True + + if not padded: + # Consider a tensor with shape [256, 1, 5, 5] + # Avoid strides like [25, 5, 5, 1] being padded to equivalent strides + # [25, 25, 5, 1]. 
+ return in_strides + + # pyrefly: ignore [bad-assignment] + metrics.num_comprehensive_padding += 1 + return new_strides + + def pad_strides(self) -> None: + assert isinstance(self, FlexibleLayout), type(self) + assert self.stride is not None + self.stride = self._pad_strides(self.stride, self.size, self.dtype) + + def should_pad_strides(self) -> bool: + return config.comprehensive_padding and isinstance(self, FlexibleLayout) + + def as_fixed(self) -> FixedLayout: + if isinstance(self, FixedLayout): + return self + + if self.should_pad_strides(): + self.pad_strides() + return FixedLayout( + self.device, + self.dtype, + self.size, + self.stride, + self.offset, + self.is_pinned, + ) + + def make_indexer(self) -> Callable[[Sequence[Expr]], Expr]: + assert FlexibleLayout.allow_indexing, ( + f"convert {type(self).__name__} to FixedLayout first" + ) + return self.as_fixed().make_indexer() + + def __eq__(self, other: object) -> bool: + return ( + isinstance(other, Layout) + and self.device == other.device + and self.dtype == other.dtype + and self.size == other.size + and self.stride == other.stride + and self.offset == other.offset + and self.is_pinned == other.is_pinned + ) + + def storage_size(self) -> Expr: + return compute_required_storage_length(self.size, self.stride, self.offset) # type: ignore[arg-type] + + @cache_on_self_and_args("Layout") + def get_free_symbol_uses( + self, unbacked_only: bool = False + ) -> OrderedSet[sympy.Symbol]: + return ( + get_free_symbols(self.size, unbacked_only) + | get_free_symbols(self.stride, unbacked_only) + | get_free_symbols(self.offset, unbacked_only) + ) + + +class FixedLayout(Layout): + """A Tensor layout we cannot change""" + + def make_indexer(self) -> Callable[[Sequence[Expr]], Expr]: + """A closure containing math to read a given element""" + return _fixed_indexer(self.size, self.stride, self.offset) + + +class FlexibleLayout(Layout): + """ + A Tensor layout that we are allowed to change + + Assumption: layout change 
should NOT add or remove free symbols + """ + + allow_indexing = False + + # WARNING! This doesn't handle zero size tensors correctly + @staticmethod + def contiguous_strides(sizes: Sequence[int]) -> list[Expr]: + if len(sizes) == 0: + return [] + reversed_strides = [sympy.S.One] + for size in reversed(sizes[1:]): + reversed_strides.append(size * reversed_strides[-1]) + return list(reversed(reversed_strides)) + + @staticmethod + def fill_ordered(sizes: Sequence[int], order: Sequence[int]) -> list[Expr]: + """ + Create a stride based on the order the dimensions should be filled in. + + In this format, channels last would be: + [1, 3, 2, 0] + """ + assert OrderedSet(range(len(sizes))) == OrderedSet(order), (sizes, order) + next_stride = sympy.S.One + strides = [None] * len(order) + + for i in order: + strides[i] = next_stride + next_stride = next_stride * sizes[i] + return strides + + @staticmethod + def stride_ordered(sizes: Sequence[int], order: Sequence[int]) -> Sequence[Expr]: + """ + Create a stride based on the sorted order of a permuted range. + + In this format, channels last would be: + [3, 0, 2, 1] + """ + assert OrderedSet(range(len(sizes))) == OrderedSet(order) + fill_order = stride_order2fill_order(order) + return FlexibleLayout.fill_ordered(sizes, fill_order) + + @staticmethod + def stride_ordered_for_memory_format( + sizes: Sequence[int], memory_format: torch.memory_format + ) -> Sequence[Expr]: + """ + Create a stride based on a memory format. 
+ + Memory format is translasted into a stride order, + so channels_last is the same as: + FlexibleLayout.stride_ordered(sizes, [3, 0, 2, 1]) + + This interface does not support memory_format `torch.preserve_format` + which should be used to deduce a format from another source + """ + if memory_format == torch.channels_last: + return FlexibleLayout.stride_ordered(sizes, NHWC_STRIDE_ORDER) + elif memory_format == torch.channels_last_3d: + return FlexibleLayout.stride_ordered(sizes, NHWDC_STRIDE_ORDER) + elif memory_format == torch.contiguous_format: + return FlexibleLayout.contiguous_strides(sizes) + else: + log.debug( + "stride_ordered_for_memory_format, unsuppored memory_format: %s", + memory_format, + ) + raise NotImplementedError + + @staticmethod + def same_ordered( + sizes: Sequence[int], stride: Sequence[_IntLike] + ) -> Sequence[Expr]: + """ + Create a stride that has the same stride order as given stride + + For example, if given stride is [1000, 1, 100, 10], + the fill order should be [1, 3, 2, 0] + """ + assert len(sizes) == len(stride) + stride = [V.graph.sizevars.size_hint_or_throw(x) for x in stride] + fill_order = sorted(range(len(stride)), key=stride.__getitem__) + return FlexibleLayout.fill_ordered(sizes, fill_order) + + @property + def size(self) -> Sequence[Expr]: + return self._size + + @size.setter + def size(self, value: Sequence[Expr]) -> None: + self.assert_free_symbol_uses_unchanged("size", value) + self._size = value + + @property + def stride(self) -> Sequence[Expr]: + return self._stride + + @stride.setter + def stride(self, value: Sequence[Expr]) -> None: + self.assert_free_symbol_uses_unchanged("stride", value) + self._stride = value + + @property + def offset(self) -> Expr: + return self._offset + + @offset.setter + def offset(self, value: Expr) -> None: + self.assert_free_symbol_uses_unchanged("offset", value) + self._offset = value + + def as_stride_order( + self, order: Sequence[int], allow_padding: bool = False + ) -> FixedLayout: 
+ new_stride = self.stride_ordered(self.size, order) + if self.should_pad_strides() and allow_padding: + new_stride = self._pad_strides(new_stride, self.size, self.dtype) + + return FixedLayout( + self.device, + self.dtype, + self.size, + new_stride, + self.offset, + self.is_pinned, + ) + + def as_exact_strides( + self, exact_strides: Sequence[_IntLike], allow_padding: bool = False + ) -> FixedLayout: + new_stride = exact_strides + if self.should_pad_strides() and allow_padding: + new_stride = self._pad_strides(new_stride, self.size, self.dtype) + + return FixedLayout( + self.device, + self.dtype, + self.size, + new_stride, + self.offset, + self.is_pinned, + ) + + def as_fill_order(self, order: Sequence[int]) -> FixedLayout: + new_stride: Sequence[int] = self.fill_ordered(self.size, order) + if self.should_pad_strides(): + new_stride = self._pad_strides(new_stride, self.size, self.dtype) + return FixedLayout( + self.device, + self.dtype, + self.size, + new_stride, + self.offset, + self.is_pinned, + ) + + def as_same_order(self, stride: Sequence[_IntLike]) -> FixedLayout: + new_stride = self.same_ordered(self.size, stride) + if self.should_pad_strides(): + new_stride = self._pad_strides(new_stride, self.size, self.dtype) + return FixedLayout( + self.device, + self.dtype, + self.size, + new_stride, + self.offset, + self.is_pinned, + ) + + def get_initial_free_symbol_uses(self) -> dict[tuple[str, bool], sympy.Symbol]: + initial_free_symbols = {} + for name in ["size", "stride", "offset"]: + for unbacked_only in [True, False]: + key = (name, unbacked_only) + initial_free_symbols[key] = OrderedSet( + get_free_symbols(getattr(self, name), unbacked_only) + ) + + return initial_free_symbols + + def assert_free_symbol_uses_unchanged(self, name: str, value: IterateExprs) -> None: + for unbacked_only in [True, False]: + old_free_symbols = self.initial_free_symbols[(name, unbacked_only)] + new_free_symbols = OrderedSet(get_free_symbols(value, unbacked_only)) + assert 
new_free_symbols == old_free_symbols, ( + f"Expected free symbols unchanged, but got {new_free_symbols} vs {old_free_symbols}" + ) + + def __init__( + self, + device: torch.device, + dtype: torch.dtype, + size: Sequence[Expr], + stride_order: Optional[Sequence[Union[int, Integer]]] = None, + is_pinned: bool = False, + ) -> None: + if stride_order: + strides = FlexibleLayout.fill_ordered(size, stride_order) + else: + strides = FlexibleLayout.contiguous_strides(size) + super().__init__(device, dtype, size, strides, is_pinned=is_pinned) + + # record the initial free symbols to check that we do not add new free symbols + # later when modifying sizes, strides, and offsets. + self.initial_free_symbols = self.get_initial_free_symbol_uses() + + +class NonOwningLayout(Layout): + """Is a view into the storage of another tensor""" + + def __init__(self, view: Union[BaseView, TensorBox]) -> None: + layout = view.get_layout() + super().__init__( + layout.device, + layout.dtype, + layout.size, + layout.stride, + ) + self.view = view + + def make_indexer(self) -> Callable[[Sequence[Expr]], Expr]: + return self.as_fixed().make_indexer() + + def maybe_guard_aligned(self) -> bool: + offset = self.view.get_layout().offset + if offset == 0: + return True + from .utils import ALIGNMENT + + return V.graph.sizevars.statically_known_multiple_of(offset, ALIGNMENT) + + @cache_on_self_and_args("NonOwningLayout") + def get_free_symbol_uses( + self, unbacked_only: bool = False + ) -> OrderedSet[sympy.Symbol]: + assert isinstance(self.view, ReinterpretView) + box = self.view.data + assert isinstance(box, StorageBox), type(box) + input_buffer = box.data + assert isinstance(input_buffer, Buffer), type(box) + return input_buffer.layout.get_free_symbol_uses(unbacked_only) + + +class CommBufferType(Enum): + SYMM_MEM = "symm_mem" + + +class CommBufferLayout(FixedLayout): + """ + A layout that signifies the buffer is a comm buffer. + In terms of striding, the layout is identical to `FixedLayout`. 
+ + Buffers with this layout do not participate in in-place reuse - it can be + neither the source nor the target for in-place reuse. + + For detailed motivation and usage of this layout, see + NOTE [lowering-time collective optimization]. + """ + + comm_buffer_type: CommBufferType + group_name: str + + def __init__( + self, + layout: FlexibleLayout, + comm_buffer_type: CommBufferType, + group_name: str, + ): + if not isinstance(layout, FlexibleLayout): + raise AssertionError( + "A `CommBufferLayout` can only be initialized with " + f"a `FlexibleLayout` (got {layout})." + ) + + fixed = layout.as_fixed() + super().__init__( + device=fixed.device, + dtype=fixed.dtype, + size=fixed.size, + stride=fixed.stride, + offset=fixed.offset, + is_pinned=fixed.is_pinned, + ) + self.comm_buffer_type = comm_buffer_type + self.group_name = group_name + + +@ir_dataclass +class NoneLayout(OutputSpec): + # This is janky, I figured out what fields to populate by just running + # the model I was interested in and adding properties/methods as needed. + # This doesn't inherit from Layout because Layout assumes you have stuff + # like sizes, but I don't really have anything here. 
+ # + # If you have an ir.Node with NoneLayout, you probably need to setup + # dependencies manually in scheduler + + device: Optional[torch.device] + size: list[int] = dataclasses.field(default_factory=lambda: [0]) + stride: list[int] = dataclasses.field(default_factory=lambda: [0]) + + def storage_size(self) -> int: + return 0 + + def as_fixed(self) -> OutputSpec: + return self + + def get_device(self) -> Optional[torch.device]: + return self.device + + +class MutationLayoutSHOULDREMOVE(Layout): + def __init__(self, target: IRNode) -> None: + super().__init__( + target.get_device_or_error(), + target.get_dtype(), + target.get_size(), + None, + ) + self.target = target + name = self.get_buffer().get_name() + V.graph.mark_buffer_mutated(name) + + @property + def stride(self) -> Sequence[Expr]: # type: ignore[override] + return self.real_layout().stride + + @stride.setter # type: ignore[override] + def stride(self, value: Never) -> None: + pass # ignore setting of stride + + def storage_size(self) -> Expr: + return self.real_layout().storage_size() + + def get_buffer(self) -> Buffer: + def unwrap_views(target: Any) -> Any: + if isinstance(target, MutationLayoutSHOULDREMOVE): + return unwrap_views(target.target) + if isinstance(target, BaseView): + return unwrap_views(target.unwrap_view()) + if isinstance(target, MutableBox): + return unwrap_views(target.data) + return target + + result = unwrap_views(self.target) + assert isinstance(result, Buffer), type(result) + return result + + def real_layout(self) -> Layout: + layout = self.get_buffer().layout + assert isinstance(layout, Layout) + return layout + + @classmethod + def realize_into( + cls, src: IRNode, dst: IRNode, unsafe_alias: bool = False + ) -> IRNode: + dst.realize() + # NOTE: We must realize users of `dst` before we realize `src`, since + # realization order determines scheduling order. Otherwise, src's + # mutation would be scheduled before the existing users of dst! 
+ V.graph.mark_buffer_mutated(dst.get_name()) + + if isinstance(src, TensorBox): + src = src.data + + # We copy the contents of src into dst. In most cases this should + # be fused into a single kernel by the scheduler. + # NOTE: We cannot change src's layout to mutate dst directly as this + # would alias src to dst, which is not correct as further mutations to + # dst would effect users of src. However if there are no more users of + # dst, we can alias src to dst. + src.realize_hint() + + if not unsafe_alias: + node = Pointwise.create( + device=src.get_device(), + dtype=src.get_dtype(), + inner_fn=src.make_loader(), + ranges=[ + V.graph.sizevars.check_equals_and_simplify(a, b) + for a, b in zip(src.get_size(), dst.get_size()) + ], + ) + assert isinstance(node, (BaseView, MutableBox)) + src = node.data + + src.realize() + assert hasattr(src, "data"), src + assert isinstance(src.data.layout, FlexibleLayout), type(src.data.layout) + src.data.layout = MutationLayoutSHOULDREMOVE(dst) + return src.data + + def as_fixed(self) -> Self: # type: ignore[override] + return self + + def make_indexer(self) -> Callable[[Sequence[Expr]], Expr]: + return self.target.make_indexer() + + +@ir_dataclass(frozen=False) +class Buffer(IRNode, CodegenSymbol): + # Name is sometimes None; e.g., ForceInPlace, where there isn't + # a meaningful name + name: Optional[str] + layout: OutputSpec + + # Multi-output buffers will define 'outputs: List[Buffer]'. Confusingly, + # MultiOutput does NOT define this! 
+ + def __post_init__(self) -> None: + super().__post_init__() + self._post_init_setattr("origin_node", None) + + def make_indexer(self) -> Callable[[Sequence[Expr]], Expr]: + return self.get_layout().make_indexer() + + def get_name(self) -> str: + assert self.name, self + return self.name + + def get_example(self) -> Union[torch.Tensor, torch.SymInt]: + if isinstance(self.layout, Layout): + return self.layout.get_example() + raise NotImplementedError(type(self.layout).__name__) + + def get_device(self) -> Optional[torch.device]: + return self.get_output_spec().get_device() + + def get_defining_op(self) -> Optional[Operation]: + return None + + @property + def dtype(self) -> torch.dtype: + return self.get_layout().dtype + + def get_size(self) -> Sequence[Expr]: + return [*self.get_layout().size] + + def get_stride(self) -> list[Expr]: + return [*self.get_layout().stride] + + def get_offset(self) -> Expr: + return self.get_layout().offset + + def get_layout(self) -> Layout: + if isinstance(self.layout, Layout): + return self.layout + raise NotImplementedError(type(self.layout).__name__) + + def get_output_spec(self) -> OutputSpec: + return self.layout + + def get_storage_numel(self) -> int: + return self.get_numel() + + def get_is_pinned(self) -> bool: + return self.get_layout().is_pinned + + def freeze_layout(self) -> None: + if isinstance(self.layout, Layout) and not isinstance( + self.layout, NonOwningLayout + ): + self.layout = self.layout.as_fixed() + + def freeze_layout_with_stride_order( + self, order: Sequence[int], allow_padding: bool = False + ) -> None: + assert isinstance(self.layout, FlexibleLayout), type(self.layout) + self.layout = self.layout.as_stride_order(order, allow_padding=allow_padding) + + def freeze_layout_with_fill_order(self, order: Sequence[int]) -> None: + assert isinstance(self.layout, FlexibleLayout), type(self.layout) + self.layout = self.layout.as_fill_order(order) + + def freeze_layout_with_same_order(self, stride: Sequence[int]) -> 
None: + assert isinstance(self.layout, FlexibleLayout), type(self.layout) + self.layout = self.layout.as_same_order(stride) + + def freeze_layout_with_exact_strides( + self, exact_strides: Sequence[int], allow_padding: bool = False + ) -> None: + assert isinstance(self.layout, FlexibleLayout), type(self.layout) + self.layout = self.layout.as_exact_strides( + exact_strides, allow_padding=allow_padding + ) + + def is_zero_elements(self) -> bool: + return V.graph.sizevars.statically_known_true(sympy.Eq(self.get_numel(), 0)) + + def make_loader(self) -> Callable[[Sequence[Expr]], OpsValue]: + # Loading from a zero-element buffer is a no-op + if self.is_zero_elements(): + return partial(nop_loader_fn, dtype=self.get_dtype()) + + def loader(index: Sequence[Expr]) -> OpsValue: + indexer = self.make_indexer() + return ops.load(self.name or "unnamed", indexer(index)) + + return loader + + def codegen_reference(self, writer: Optional[IndentedBuffer] = None) -> str: + return self.get_name() + + def decide_layout(self) -> None: + pass + + def get_inputs_that_alias_output(self) -> Sequence[str]: + if isinstance(self.layout, NonOwningLayout): + return [self.layout.view.get_name()] + return () + + def get_mutation_names(self) -> Sequence[str]: + if isinstance(self.layout, MutationLayoutSHOULDREMOVE): + return [self.layout.target.get_name()] + return () + + def get_read_names(self) -> OrderedSet[str]: + return OrderedSet([self.get_name()]) + + @cache_on_self_and_args("Buffer") + def get_free_symbol_uses( + self, unbacked_only: bool = False + ) -> OrderedSet[sympy.Symbol]: + return OrderedSet() + + def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: + return OrderedSet() + + def realize(self) -> Optional[str]: + pass + + def should_allocate(self) -> bool: + # Returns False by default. 
+ return False + + +@ir_dataclass(frozen=False) +class OperationBuffer(Buffer, Operation): + # An operation that produces a single output buffer + def get_outputs(self) -> list[Buffer]: + return [self] + + def get_defining_op(self) -> Operation: + return self + + # Skip implementation in Buffer + get_operation_name = Operation.get_operation_name + + def __post_init__(self) -> None: + Buffer.__post_init__(self) + Operation.__post_init__(self) + + +class InputBuffer(Buffer): + def num_reads(self) -> int: + return 1 + + +class DonatedBuffer(InputBuffer): + """ + Represents a donated buffer which is a saved tensor that is not alias to any + fwd inputs, fwd user outputs, and bwd outputs. We generally cannot inplace + reuse the input tensor memory during backward since it might be used in another + function. However, donated buffer can be inplace reused during backward + to save memory. + """ + + +class ConstantBuffer(InputBuffer): + override_device: Optional[torch.device] = None + + def make_loader(self) -> Callable[[Sequence[Expr]], OpsValue]: + def loader(index: Sequence[Expr]) -> OpsValue: + indexer = self.get_layout().make_indexer() + return ops.load( + V.graph.constant_name(self.get_name(), self.override_device), + indexer(index), + ) + + return loader + + def constant_to_device(self, device: torch.device) -> IRNode: + return ConstantBuffer( + name=V.graph.constant_name(self.get_name(), device), layout=self.layout + ) + + +@ir_dataclass +class NoneAsConstantBuffer(IRNode): + def get_reads(self) -> OrderedSet[Dep]: + return OrderedSet() + + @cache_on_self_and_args("NoneAsConstantBuffer") + def get_free_symbol_uses( + self, unbacked_only: bool = False + ) -> OrderedSet[sympy.Symbol]: + return OrderedSet() + + def codegen_reference(self, writer: Optional[IndentedBuffer] = None) -> str: + return V.graph.wrapper_code.none_str + + def get_output_spec(self) -> OutputSpec: + return NoneLayout(device=None) + + def has_tensor_output(self) -> bool: + return False + + 
+@ir_dataclass +class ShapeAsConstantBuffer(IRNode): + expr: Expr + + @cache_on_self_and_args("ShapeAsConstantBuffer") + def get_free_symbol_uses( + self, unbacked_only: bool = False + ) -> OrderedSet[sympy.Symbol]: + return get_free_symbols(self.expr, unbacked_only) + + def codegen_reference(self, writer: Optional[IndentedBuffer] = None) -> str: + return V.graph.wrapper_code.codegen_sizevar(self.expr) + + def has_tensor_output(self) -> bool: + return False + + +@ir_dataclass(frozen=False) +class ComputedBuffer(OperationBuffer): + """ + Represents a buffer that is computed during kernel execution rather than being an input. + """ + + data: Loops + _force_realize: ClassVar[bool] = False + + # fields for split reduction + _split_size: Optional[int] = None + _original_inner_fn: Optional[Callable[..., Any]] = None + _original_ranges: Optional[Sequence[_IntLike]] = None + _original_reduction_ranges: Optional[Sequence[_IntLike]] = None + + @contextlib.contextmanager + def with_original_inner_fn(self) -> Iterator[None]: + assert self._split_size is not None + assert self._original_inner_fn is not None + assert self._original_ranges is not None + assert self._original_reduction_ranges is not None + + assert isinstance(self.data, Reduction), f"{type(self.data)}" + old_data = self.data + old_layout = self.layout + try: + new_data = Reduction( + device=old_data.device, + dtype=old_data.dtype, + inner_fn=self._original_inner_fn, + ranges=self._original_ranges, + reduction_ranges=self._original_reduction_ranges, + reduction_type=old_data.reduction_type, + src_dtype=old_data.src_dtype, + reduction_hint=old_data.reduction_hint, + ) + self.data = new_data + # this layout does not matter since we skip tl.store + # later + self.layout = FixedLayout( + old_data.device, + old_data.dtype, + self._original_ranges, + ) + self.get_default_sizes_body.clear_cache(self) + yield + finally: + self.data = old_data + self.layout = old_layout + + @staticmethod + @contextlib.contextmanager + def 
force_realize() -> Iterator[None]: + old_value = ComputedBuffer._force_realize + try: + ComputedBuffer._force_realize = True + yield + finally: + ComputedBuffer._force_realize = old_value + + def get_computed_buffer_name(self) -> Optional[str]: + """ + Returns self.name if it exists, otherwise returns the name of the data node if that exists. + If neither exist, returns None. + """ + if self.name is not None: + return self.name + if hasattr(self.data, "name"): + return self.data.name + return None + + def num_reads(self) -> int: + return self.data.num_reads() + + def get_reads(self) -> OrderedSet[Dep]: + return self.data.get_reads() + + def get_read_names(self) -> OrderedSet[str]: + return self.data.get_read_names() + + def get_read_writes(self) -> dependencies.ReadWrites: + if not isinstance(self.data, (Reduction, Scan, Sort, Pointwise)): + return dependencies.ReadWrites( + reads=OrderedSet(), + writes=OrderedSet(), + index_exprs=OrderedSet(), + ) + + with patch.object(FlexibleLayout, "allow_indexing", True): + if self.data.get_reduction_type(): + return extract_read_writes( + self.get_store_function(), + self.data.get_pointwise_size(), + self.data.get_reduction_size(), + ) + else: + return extract_read_writes( + self.get_store_function(), + self.data.get_size(), + ) + + @cache_on_self_and_args("ComputedBuffer") + def get_free_symbol_uses( + self, unbacked_only: bool = False + ) -> OrderedSet[sympy.Symbol]: + # Ordinarily, we'd like to just peek at the arguments list, + # but ComputedBuffers have no argument list. + # + # Morally, this logic needs to be synchronized with the + # KernelArgs.size calls, which are responsible for making symbols make + # there way as kernel arguments (and it is precisely passing in one of + # those symbols that establishes a dependency). However, we haven't + # started codegen yet so we can't directly reuse that logic. + # + # One thing you might wonder is if this is enough for a ComputedBuffer + # denoting a reduction over i0. 
Empirically, it is enough, but for an + # unusual reason: we only need accurate dependencies for item() call, + # but it's impossible to end up with a reduction over i0 from an + # item() call without a regular non-reduction buffer first. + result = self.layout.get_free_symbol_uses( + unbacked_only + ) | self.data.get_free_symbol_uses(unbacked_only) + + if self.has_store_function(): + result |= self.get_read_writes().get_free_symbol_uses(unbacked_only) + return result + + def make_loader(self) -> Callable[[Sequence[Expr]], OpsValue]: + if ( + not self.get_reduction_type() + and self.name not in V.graph.mutated_buffers + and self.num_reads() == 0 + and not self._force_realize + ): + # inline this op rather than generating ops.load() + return self.data.make_loader() + return super().make_loader() + + def has_store_function(self) -> bool: + return isinstance(self.data, (Reduction, Scan, Sort, Pointwise)) + + def get_store_function(self) -> Callable[..., None]: + indexer = self.get_layout().as_fixed().make_indexer() + if isinstance(self.data, (Reduction, Scan, Sort)): + return partial(self.data.store_reduction, self.name, indexer) + else: + assert isinstance(self.data, Pointwise), type(self.data) + return partial(self.data.store_output, self.name, indexer) + + def get_fill_order(self) -> Optional[list[int]]: + """ + If our layout is still flexible, try to determine the stride order based on stride orders of reads. + + TODO(jansel): A better algorithm here would look at downstream consumers of this + value and try to do global graph-level layout optimization. + This is also something just begging to be autotuned. 
+ """ + if isinstance(self.layout, FlexibleLayout): + (index_vars, reduction_vars), _ = dependencies.index_vars_squeeze( + self.data.get_pointwise_size(), self.data.get_reduction_size() + ) + reads = self.get_read_writes().reads + # only consider reads to buffer of same size + # ignore StarDeps because they don't contribute stride information + assert all( + isinstance(r, (dependencies.StarDep, dependencies.MemoryDep)) + for r in reads + ) + reads = [ + sympy_subs(r.index, {v: sympy.S.Zero for v in reduction_vars if v != 0}) + for r in reads + if isinstance(r, dependencies.MemoryDep) + ] + + if reads: + if isinstance(self.data, (Scan, Sort)): + indices = self.data.reindex(index_vars, reduction_vars) + else: + indices = index_vars + stride_lengths = [ + V.graph.sizevars.stride_hints(expr, indices) for expr in reads + ] + from .scheduler import pick_loop_order + + return pick_loop_order(stride_lengths, self.get_size()) + + return None + + def decide_layout(self) -> None: + if isinstance(self.layout, FlexibleLayout): + order = self.get_fill_order() + if order: + self.freeze_layout_with_fill_order(order) + else: + self.freeze_layout() + + @cache_on_self + def get_default_sizes_body( + self, + ) -> tuple[ + tuple[list[Expr], list[Expr]], + LoopBody, + tuple[list[Expr], list[Expr]], + ]: + args, var_ranges = dependencies.index_vars_squeeze( + self.get_pointwise_size(), self.get_reduction_size(), prefix="q" + ) + with patch.object(ConstantBuffer, "override_device", self.get_device()): + body = LoopBody( + self.get_store_function(), + (args if self.get_reduction_type() else args[:1]), + var_ranges, + *args, + ) + index_vars = [] + reduce_vars: list[Any] = [] + index_size = [] + reduce_size = [] + for v, s in var_ranges.items(): + if v in args[0]: + assert not reduce_vars + index_vars.append(v) + index_size.append(s) + else: + assert v in args[1] + reduce_vars.append(v) + reduce_size.append(s) + return (index_size, reduce_size), body, (index_vars, reduce_vars) + + def 
simplify_and_reorder( + self, + extra_indexing_constraints: Optional[tuple[dict[Any, Any], list[Any]]] = None, + recompute_sizes_body_func: Optional[Callable[..., Any]] = None, + ) -> tuple[tuple[list[Expr], list[Expr]], Optional[LoopBody]]: + """ + This is a main place where we do loop transformations in a + backend-agnostic way. + + Here we: + 1) Remove any 1 dimensions + 2) Fuse contiguous dimensions together + 3) Reorder dimensions based on stride orders + + Optional argument extra_indexing_constraints can be used to append additional + indexing expressions to existing ones derived from buffer's body. This can be useful + to fuse scheduler nodes with compatible ranges, e.g. (s0*s1*...,) and (s0, s1, s2, ...) + on CPU by preventing indexing simplifications and obtaining index/reduce ranges for + the scheduler node compatible with other nodes. + Optional argument recompute_sizes_body_func can be used to recompute sizes and body + on the default body. This can be useful to append additional loop transformations. 
+ """ + ( + (index_size, reduce_size), + body, + (index_vars, reduce_vars), + ) = self.get_default_sizes_body() + + if recompute_sizes_body_func: + ( + (index_size, reduce_size), + body, + (index_vars, reduce_vars), + ) = recompute_sizes_body_func( + (index_size, reduce_size), body, (index_vars, reduce_vars) + ) + + index_formulas = [*body.indexing_exprs.values()] + if extra_indexing_constraints is not None: + assert ( + isinstance(extra_indexing_constraints, tuple) + and len(extra_indexing_constraints) == 2 + ) + extra_indexing_ranges, extra_indexing_expr = extra_indexing_constraints + assert isinstance(extra_indexing_ranges, dict), type(extra_indexing_ranges) + assert isinstance(extra_indexing_expr, list), type(extra_indexing_expr) + assert all(isinstance(f, Expr) for f in extra_indexing_expr) + + expected_var_ranges = body.var_ranges + assert expected_var_ranges == extra_indexing_ranges, ( + expected_var_ranges, + extra_indexing_ranges, + ) + # remove already existing expressions + extra_indexing_expr = [ + e for e in extra_indexing_expr if e not in index_formulas + ] + index_formulas += extra_indexing_expr + + memory_addrs = [*body.get_write_exprs()] + if not V.graph.has_feature(self, BackendFeature.PREFER_STORE_LOOP_ORDER): + memory_addrs.extend(body.get_read_exprs()) + + def simplify_and_reorder( + x_vars: Sequence[sympy.Symbol], + support_vars: Sequence[sympy.Symbol], + sizes: Sequence[int], + simplify_loops: bool, + ) -> tuple[ + list[int], + Callable[[Sequence[int]], Sequence[int]], + Callable[[Sequence[int]], Sequence[int]], + ]: + newsizes, reindex0, reindex1 = self._apply_loop_reordering( + x_vars, support_vars, sizes, memory_addrs + ) + + # When using native matmul, the codegen assumes the following loop order, + # regardless of the stride of A and B: + # + # for z -> y -> x -> r: C[z, y, x] += A[z, y, r] * B[z, r, x] + # or + # for z -> x -> y -> r: C[z, y, x] += A[z, y, r] * B[z, r, x] + # + # The critical point is the position of the "z" (batch) 
axis in bmm. + # It is fine to swap the y and x axes (e.g., (z, y, x, r) or (z, x, y, r)), + # but reordering the z axis (e.g., (y, x, z, r)) breaks codegen. + # + # Therefore, if loop reordering changes the "z" location in bmm, + # it should be reverted to the default. + # This may not always produce the optimal loop order when strides + # do not align with the default assumption. + # + # TODO: Consider extending tl.dot codegen to support arbitrary loop orders. + if self.get_reduction_type() == "dot" and len(sizes) == 3: + order = list(range(len(sizes))) # default order + + # if z axis is not the outermost, use the default reorder. + if reindex0(order)[0] != 0: + newsizes = [sizes[i] for i in order] + reindex0 = same_reorder(order) + reindex1 = inverse_reorder(order) + + # for NHWC: reindex0([0,1,2,3]) = [0,2,3,1], reindex1([0,1,2,3]) = [0,3,2,1] + x_vars = reindex0(x_vars) + + if simplify_loops: + newsizes, reindex2, _prune = V.graph.sizevars._simplify_loops( + x_vars, + newsizes, + index_prevent_reordering(index_formulas, x_vars, newsizes), + ) + reindex = fuse_reindexing(reindex1, reindex2) + else: + reindex = reindex1 + return newsizes, reindex, reindex1 + + support_vars = index_vars + reduce_vars + should_merge_loops = ( + not is_gpu(get_device_type(self)) or not config.loop_ordering_after_fusion + ) + iter_ranges, iter_reindex, _ = simplify_and_reorder( + index_vars, + support_vars, + index_size, + should_merge_loops, + ) + + # Like iteration dimensions, we may also want to delay merging reduction dimensions. + # E.g., if we reduce a tensor [M, N, K] for its M and N dimensions followed by a pointwise + # kernel, merging M and N dimension too early makes it hard to decide what loop order + # we should pick for the piontwise kernel so that it is fusible with the reduction. 
+ reduce_ranges, reduce_reindex, _ = simplify_and_reorder( + reduce_vars, support_vars, reduce_size, should_merge_loops + ) + + # retrace the loop body with simplification and reordering applied + (iter_vars, reduce_vars), var_ranges = dependencies.index_vars_no_squeeze( + iter_ranges, + reduce_ranges, + prefix="p", + ) + body = LoopBody( + body, + [iter_reindex(iter_vars), reduce_reindex(reduce_vars)], + var_ranges, + iter_vars, + reduce_vars, + ) + return (iter_ranges, reduce_ranges), body + + @staticmethod + def _apply_loop_reordering( + index_vars: Sequence[sympy.Symbol], + support_vars: Sequence[sympy.Symbol], + sizes: Sequence[int], + memory_addrs: list[sympy.Expr], + priority_idx: Optional[list[int]] = None, + ) -> tuple[ + list[int], + Callable[[Sequence[int]], Sequence[int]], + Callable[[Sequence[int]], Sequence[int]], + ]: + """ + Shuffle the order of loops around to hopefully improve performance. + """ + from .scheduler import pick_loop_order + + if priority_idx is None: + priority_idx = [] + + try: + strides = [ + V.graph.sizevars.stride_hints(expr, index_vars, support_vars) + for expr in memory_addrs + ] + assert len(strides) == len(memory_addrs) and len(strides[0]) == len( + index_vars + ) + order = list(reversed(pick_loop_order(strides, sizes, priority_idx))) + except Exception: + if config.debug: + log.warning( + "Did not simplify complex index:\n%s\n%s", + dict(zip(index_vars, sizes)), + memory_addrs, + ) + order = list(range(len(sizes))) + sizes = [sizes[i] for i in order] + return sizes, same_reorder(order), inverse_reorder(order) + + def get_pointwise_size(self) -> Sequence[Expr]: + return self.data.get_pointwise_size() + + def get_reduction_size(self) -> Sequence[Expr]: + return self.data.get_reduction_size() + + def get_reduction_type(self) -> Optional[str]: + return self.data.get_reduction_type() + + def is_no_op(self) -> bool: + return self.data.is_zero_elements() + + def should_allocate(self) -> bool: + return True + + def 
constant_to_device(self, device: torch.device) -> IRNode: + """Move this to a given device. Requires that all reads are to constants.""" + return self.data.constant_to_device(device) + + +class TemplateBuffer(OperationBuffer): + """ + Represents a Triton (in the future other type) of template operator + that we can fuse an epilogue onto. + """ + + def __init__( + self, + layout: OutputSpec, + inputs: Sequence[IRNode], + make_kernel_render: Optional[Callable[..., Any]], + ) -> None: + super().__init__(name=None, layout=layout) + self.inputs = InputsKernel.unwrap_storage(inputs) + self.make_kernel_render = make_kernel_render + self.name = V.graph.register_buffer(self) + V.graph.register_operation(self) + + def get_read_writes(self) -> dependencies.ReadWrites: + return self.extract_read_writes(normalize=True) + + def extract_read_writes(self, normalize: bool = False) -> dependencies.ReadWrites: + name = self.get_name() + indexer = self.get_layout().make_indexer() + + def dummy(index: Sequence[Any], rindex: Sequence[Any]) -> Any: + assert len(rindex) == 0 + return ops.store(name, indexer(index), "fake") + + deps = dependencies.extract_read_writes( + dummy, self.get_size(), (), normalize=normalize + ) + + for inp in self.inputs: + assert isinstance(inp, (ReinterpretView, Buffer)), type(inp) + assert isinstance(inp.layout, Layout), type(inp.layout) + + indexer = inp.layout.make_indexer() + + def dummy(index: Sequence[Any], rindex: Sequence[Any]) -> Any: + assert len(rindex) == 0 + # pyrefly: ignore [missing-attribute] + return ops.load(inp.get_name(), indexer(index)) + + deps.reads |= dependencies.extract_read_writes( + dummy, inp.get_size(), (), normalize=normalize + ).reads + + return deps + + def get_reduction_size(self) -> Sequence[Expr]: + return sympy.S.One + + def get_reduction_type(self) -> Optional[str]: + return None + + def should_allocate(self) -> bool: + return True + + def simplify_and_reorder( + self, + extra_indexing_constraints: Optional[tuple[dict[Any, 
Any], list[Any]]] = None, + recompute_sizes_body_func: Optional[Callable[..., Any]] = None, + ) -> tuple[tuple[Sequence[Expr], list[Expr]], Optional[LoopBody]]: + return ( + ( + self.get_size(), + [], + ), + None, + ) + + +class TritonTemplateBuffer(TemplateBuffer): + def __init__( + self, + layout: Layout, + inputs: Sequence[IRNode], + make_kernel_render: Optional[Callable[_P, _T]], + mutated_inputs: Optional[Iterable[IRNode]] = None, + allowed_prologue_inps: Optional[OrderedSet[str]] = None, + ) -> None: + """ + NOTE:[TritonTemplates with multiple outputs] + We want the ability for TritonTemplates to output multiple tensors. Triton + kernels have no notion of outputs and this is done by creating tensors that + are then mutated by the kernel. Currently our STORE_OUTPUT codegen doesn't + support creating multinode outputs for triton templates. + We work around this by creating an extra input buffer during the lowering + and we mark them as mutated inputs. + """ + super().__init__(layout, inputs, make_kernel_render) + self.mutated_inputs = mutated_inputs + self.outputs: list[Buffer] = [self] + if mutated_inputs is not None: + # Ensure that the mutated inputs are only allowed for certain nodes + allowed_set = ( + torch.ops.higher_order.flex_attention, + torch.ops.higher_order.flex_attention_backward, + ) + current_node = V.graph.current_node.target + assert current_node in allowed_set, ( + f"Mutated inputs are only allowed for {allowed_set} but got {current_node}" + ) + assert isinstance(self.inputs[0], IRNode), type(self.inputs[0]) + device = self.inputs[0].get_device() + self.outputs += [ + MutationOutput(NoneLayout(device=device), buf, self) + for buf in mutated_inputs + ] + + self.allowed_prologue_inps = ( + allowed_prologue_inps if allowed_prologue_inps else OrderedSet() + ) + + self.subgraph_inps: Optional[list[Optional[Union[IRNode, sympy.Expr]]]] = None + self.subgraph_outs: Optional[list[Optional[IRNode]]] = None + + 
@cache_on_self_and_args("TritonTemplateBuffer") + def get_free_symbol_uses( + self, unbacked_only: bool = False + ) -> OrderedSet[sympy.Symbol]: + res = super().get_free_symbol_uses(unbacked_only) + subgraph_outs = self.subgraph_outs if self.subgraph_outs else [] + subgraph_inps = self.subgraph_inps if self.subgraph_inps else [] + + for inp in subgraph_inps: + if isinstance(inp, sympy.Expr): + res.update(get_free_symbols(inp, unbacked_only)) + elif isinstance(inp, IRNode): + res.update(inp.get_free_symbol_uses(unbacked_only)) + else: + assert inp is None + + for out in subgraph_outs: + if isinstance(out, IRNode): + res.update(out.get_free_symbol_uses(unbacked_only)) + else: + assert out is None + + return res + + def get_outputs(self) -> list[Buffer]: + return self.outputs + + def get_allowed_prologue_inps(self) -> OrderedSet[str]: + return self.allowed_prologue_inps + + def __str__(self) -> str: + out = f"TritonTemplateBuffer(layout={self.layout})" + return out + + +PrimitiveInfoType = Union[int, float, bool, str, list[Union[int, str, float, bool]]] + + +class ChoiceCaller: + """ + Represents a possible choice used in autotune_process.py. + During autotuning, self.benchmark() is first called to get benchmark result, + and if this choice is selected, self.output_node() is called to get the output_node. + + Children classes: TritonTemplateCaller, CUDATemplateCaller. 
+ """ + + def __init__( + self, + name: str, + input_nodes: list[Buffer], + layout: Layout, + description: str, + ) -> None: + super().__init__() + self.name = name + self.layout = layout + self.input_nodes = input_nodes + # An additional description used to describe the choice (useful for + # knowing what autotuning is choosing) + self.description = description + self.failed: bool = False + # A place to store annotations that can be read post benchmarking + # Use this to shuttle information between ChoieCaller generation + # and the end of benchmarking + self.annotations: dict[Any, Any] = {} + + def benchmark(self, *args: Any, out: torch.Tensor) -> float: + algo = self.to_callable() + benchmark_configs = { + "warmup": autotune_warmup, + "rep": autotune_rep, + } + if config.profile_bandwidth_with_do_bench_using_profiling: + return do_bench_using_profiling(lambda: algo(*args), **benchmark_configs) # type: ignore[arg-type] + return benchmarker.benchmark( + algo, args, {"out": out}, device=None, **benchmark_configs + ) + + def call_name(self) -> str: + raise NotImplementedError + + def to_callable(self) -> Callable[..., Any]: + raise NotImplementedError + + def kernel_hash_key(self) -> str: + """ + Hash key for the underlying kernel. By default, we assume there are no + runtime params, so kernel hash key defaults to choice caller's hash key. + """ + return self.hash_key() + + def hash_key(self) -> str: + raise NotImplementedError + + def output_node(self) -> Union[TensorBox, ShapeAsConstantBuffer]: + raise NotImplementedError + + def info_dict(self) -> dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType]]]: + """Information returned here is logged to the autotune log file when that is enabled.""" + return {} + + def autoheuristic_id(self) -> str: + return "unsupported_choice" + + def mark_failed(self) -> None: + """ + Mark the choice as failed so that it can be + removed later. Useful for when we decouple + compilation and tuning. 
+ """ + self.failed = True + + +class TritonTemplateCallerBase(ChoiceCaller): + def get_make_kernel_render(self) -> Any: + raise NotImplementedError + + +class MultiTemplateBuffer(TritonTemplateBuffer): + """ + Represents a Buffer with multiple backing implementation choices. + + Choices can be TritonTemplates or ExternKernels. During scheduling if there is a potential + epilogue we will benchmark each of the choices with the epilogue to determine an implementation. + Otherwise, the fastest base choice will be chosen. + """ + + def __init__( + self, + layout: Layout, + inputs: Sequence[IRNode], + choice_timings_fn: Callable[[Optional[int]], dict[ChoiceCaller, float]], + unfiltered_choices: list[ChoiceCaller], + allowed_prologue_inps: OrderedSet[str], + ) -> None: + super().__init__( + layout=layout, + inputs=inputs, + make_kernel_render=None, + allowed_prologue_inps=allowed_prologue_inps, + ) + self._choice_timings_fn = choice_timings_fn + self._choice_timings: dict[Optional[int], dict[ChoiceCaller, float]] = {} + self.original_inputs = inputs + self._output_plannable = all( + isinstance(choice, TritonTemplateCallerBase) + or ( + isinstance(choice, torch._inductor.select_algorithm.ExternKernelCaller) + and choice.has_out_variant + ) + for choice in unfiltered_choices + ) + self._make_kernel_renders: dict[Optional[int], Any] = {} + + @property + def output_plannable(self) -> bool: + """ + Are all possible choices TritonTemplates or Extern Kernels with out variants + """ + return self._output_plannable + + def choice_timings( + self, hint_override: Optional[int] = None + ) -> dict[ChoiceCaller, float]: + if hint_override not in self._choice_timings: + self._choice_timings[hint_override] = self._choice_timings_fn(hint_override) + return self._choice_timings[hint_override] + + @contextlib.contextmanager + def swap_as_triton_caller(self, caller: TritonTemplateCallerBase) -> Iterator[None]: + assert isinstance( + caller, 
torch._inductor.select_algorithm.TritonTemplateCaller + ), type(caller) + assert self.layout == caller.layout + + render = self.make_kernel_render + self.make_kernel_render = caller.get_make_kernel_render() + try: + yield + finally: + self.make_kernel_render = render + + def finalize_as_triton_caller(self, caller: TritonTemplateCallerBase) -> None: + assert isinstance( + caller, torch._inductor.select_algorithm.TritonTemplateCaller + ), type(caller) + assert self.get_size() == caller.layout.size + assert self.get_stride() == caller.layout.stride + self.make_kernel_render = caller.get_make_kernel_render() + + def get_min_choice( + self, hint_override: Optional[int] = None + ) -> tuple[ChoiceCaller, float]: + timings = self.choice_timings(hint_override=hint_override) + min_choice = min(timings, key=timings.get) # type: ignore[arg-type] + return (min_choice, timings[min_choice]) + + def finalize_as_triton_callers( + self, callers: dict[Optional[int], TritonTemplateCallerBase] + ) -> None: + """Finalize with multiple callers for different hint overrides""" + for hint_override, caller in callers.items(): + self._make_kernel_renders[hint_override] = caller.get_make_kernel_render() + + # Set the default to be the one without hint override + self.make_kernel_render = self._make_kernel_renders[None] + + +class CUDATemplateBuffer(TemplateBuffer): + def __init__( + self, + layout: Layout, + inputs: Sequence[IRNode], + make_kernel_render: Callable[_P, _T], + workspace_size: int, + template: CUDATemplate, + supports_epilogue_fusion: bool, + ) -> None: + super().__init__(layout, inputs, make_kernel_render) + # Global memory (in bytes) needed for this template. 
+ self.workspace_size = workspace_size + self.template = template + self.supports_epilogue_fusion = supports_epilogue_fusion + + def get_workspace_size(self) -> int: + return self.workspace_size if self.workspace_size is not None else 0 + + def emulate_store_fn(self) -> None: + for output in self.get_outputs(): + ops.store(output.get_name(), None, None) + + +class CppTemplateBuffer(TemplateBuffer): + def __init__( + self, + layout: Layout, + inputs: Sequence[IRNode], + make_kernel_render: Callable[_P, _T], + template: CUDATemplate, + choice: Any, + ) -> None: + super().__init__(layout, inputs, make_kernel_render) + self.template = template + self.choice = choice + self.outputs: Optional[list[Buffer]] = None + + def get_layout(self) -> Layout: + if isinstance(self.layout, MultiOutputLayout): + assert isinstance(self.outputs, Iterable), type(self.outputs) + # pyrefly: ignore [index-error] + first_output = self.outputs[0] + assert isinstance(first_output, Buffer), type(first_output) + layout = first_output.layout + assert isinstance(layout, Layout), type(layout) + return layout + else: + return super().get_layout() + + +class CuteDSLTemplateBuffer(TemplateBuffer): + """ + Buffer for CuteDSL (CUTLASS Python DSL) template kernels. + Similar to other template buffers but specialized for CuteDSL operations. 
+ """ + + def __init__( + self, + layout: Layout, + inputs: Sequence[IRNode], + make_kernel_render: Callable[_P, _T], + template: Any, + mutated_inputs: Optional[Iterable[IRNode]] = None, + ) -> None: + super().__init__(layout, inputs, make_kernel_render) + self.template = template + self.mutated_inputs = mutated_inputs + self.outputs: list[Buffer] = [self] + + if mutated_inputs is not None: + assert isinstance(self.inputs[0], IRNode), type(self.inputs[0]) + device = self.inputs[0].get_device() + self.outputs += [ + MutationOutput(NoneLayout(device=device), buf, self) + for buf in mutated_inputs + ] + + def get_outputs(self) -> list[Buffer]: + return self.outputs + + +def is_node_sequence( + nodes: Sequence[Union[IRNode, Sequence[IRNode]]], +) -> TypeIs[Sequence[IRNode]]: + return all(isinstance(n, IRNode) for n in nodes) + + +@ir_dataclass(frozen=False) +class InputsKernel(OperationBuffer): + inputs: Sequence[Union[IRNode, Sequence[IRNode]]] + + def input_name(self, i: int) -> str: + input = self.inputs[i] + assert isinstance(input, IRNode) + return input.get_name() + + def get_read_writes(self) -> dependencies.ReadWrites: + reads = OrderedSet[dependencies.Dep]() + StarDep = dependencies.StarDep + for input in self.inputs: + if isinstance(input, Sequence): + reads.update(StarDep(x.get_name()) for x in input) + elif isinstance(input, ShapeAsConstantBuffer): + # Skip creating dependency for symbolics as they're visible globally + continue + else: + reads.add(StarDep(input.get_name())) + + writes = OrderedSet[dependencies.Dep]( + StarDep(buf.get_name()) for buf in self.get_outputs() + ) + + return dependencies.ReadWrites( + reads=reads, + writes=writes, + index_exprs=OrderedSet(), + ) + + def get_reads(self) -> OrderedSet[Dep]: + return self.get_read_writes().reads + + @classmethod + def unwrap_storage_for_input(cls, x: IRNode) -> IRNode: + if isinstance(x, TensorBox): + x = x.data + if isinstance(x, StorageBox): + x = x.data + if isinstance(x, BaseView) and not 
isinstance(x, ReinterpretView): + x = ExternKernel.realize_input(x) + if isinstance(x, TensorBox): + # when converting to ReinterpretView fails in the + # realize_input call above, the result will be wrapped + # into TensorBox / StorageBox pair as a result of the + # cls.copy_input call; so we should unwrap recursively + return cls.unwrap_storage_for_input(x) + if isinstance(x, TorchBindObject): + return x + assert isinstance(x, (Buffer, ReinterpretView)), type(x) + return x + + @staticmethod + def unwrap_storage( + inputs: Sequence[Union[IRNode, Sequence[IRNode]]], + ) -> list[Union[IRNode, Sequence[IRNode]]]: + inputs_new: list[Union[IRNode, Sequence[IRNode]]] = [] + for x in inputs: + if isinstance(x, Sequence): + x = [InputsKernel.unwrap_storage_for_input(i) for i in x] + else: + x = InputsKernel.unwrap_storage_for_input(x) + inputs_new.append(x) + return inputs_new + + def is_extern(self) -> bool: + return True + + def num_reads(self) -> int: + return 1 + + @cache_on_self_and_args("InputsKernel") + def get_free_symbol_uses( + self, unbacked_only: bool = False + ) -> OrderedSet[sympy.Symbol]: + r = OrderedSet[sympy.Symbol]() + for inp in self.inputs: + if isinstance(inp, IRNode): + r |= inp.get_free_symbol_uses(unbacked_only) + else: + for inner_inp in inp: + r |= inner_inp.get_free_symbol_uses(unbacked_only) + return r + + +class NopKernel(InputsKernel): + def is_no_op(self) -> bool: + return True + + def get_reads(self) -> OrderedSet[Dep]: + return OrderedSet() + + +class ConcatKernel(NopKernel): + """ + There isn't actually a real kernel for concat, we just change the + storage for the upstream data. 
+ """ + + @classmethod + def create(cls, inputs: Sequence[IRNode], dim: int) -> StorageBox: + """ + Create the concat kernel from inputs + """ + device = inputs[0].get_device() + dtype = inputs[0].get_dtype() + new_size = list(inputs[0].get_size()) + offsets_start = [0] + offsets_end = [new_size[dim]] + assert 0 <= dim < len(new_size) + for i in range(1, len(inputs)): + input_size = inputs[i].get_size() + offsets_start.append(new_size[dim]) + assert len(input_size) == len(new_size) + assert inputs[i].get_dtype() == dtype + assert inputs[i].get_device() == device + for j in range(len(new_size)): + if j == dim: + new_size[j] = new_size[j] + input_size[j] + else: + new_size[j] = V.graph.sizevars.check_equals_and_simplify( + new_size[j], input_size[j] + ) + offsets_end.append(new_size[dim]) + + output_stride: Sequence[int] = FlexibleLayout.contiguous_strides(new_size) + if config.comprehensive_padding: + # Ensure the output stride matches the alignment requirements + output_stride = Layout._pad_strides( + output_stride, new_size, inputs[0].dtype + ) + + # If any of the inputs is in CL format, use CL format for the output + for i in range(len(inputs)): + x = inputs[i] + if is_storage_and_layout(x): + layout = x.get_layout() + if isinstance( + layout, FixedLayout + ) and Layout.is_channels_last_contiguous(layout.size, layout.stride): + # use CL stride for the output + output_stride = make_channels_last_strides_for(new_size) + break + any_input_is_storage_and_layout = any(is_storage_and_layout(x) for x in inputs) + fx_node_args = V.graph.current_node.args[0] + assert isinstance(fx_node_args, list), type(fx_node_args) + # If any of the inputs has meta tensor and the meta tensor is in CL format, use CL format for the output + if any_input_is_storage_and_layout is False and any( + "val" in arg.meta + and ( + arg.meta["val"].is_contiguous(memory_format=torch.channels_last) + or arg.meta["val"].is_contiguous(memory_format=torch.channels_last_3d) + ) + for arg in fx_node_args + 
): + output_stride = make_channels_last_strides_for(new_size) + + is_pinned = all( + is_storage_and_layout(x) and x.get_layout().is_pinned for x in inputs + ) + + assert device is not None + concat_kernel = ConcatKernel( + name=None, + layout=FixedLayout( + device=device, + dtype=dtype, + size=new_size, + stride=output_stride, + is_pinned=is_pinned, + ), + inputs=[], + ) + kernel = StorageBox(concat_kernel) + op_names = [] + for i, inp in enumerate(inputs): + assert isinstance(inp, (BaseView, MutableBox)), type(inp) + input_buffer = cls.realize_into( + inp, + SliceView.create( + kernel, dim, offsets_start[i], offsets_end[i], clamp=False + ), + ) + assert isinstance(input_buffer, Buffer), type(input_buffer) + assert isinstance(concat_kernel.inputs, list), type(concat_kernel.inputs) + concat_kernel.inputs.append(input_buffer) + + if isinstance(inp.data, BaseView): + input_unwrapped = inp.data.unwrap_view() + else: + input_unwrapped = inp.data + + if ( + isinstance(input_unwrapped, StorageBox) + and input_unwrapped.is_input_buffer() + and (dev := inp.get_device()) is not None + and is_gpu(dev.type) + and not is_dynamic(input_buffer) + ): + op_names.append(input_buffer.get_operation_name()) + + if len(op_names) > 1 and V.graph.has_feature(device, BackendFeature.FOREACH): + V.graph.register_operation_list(op_names) + + concat_kernel.name = V.graph.register_buffer(concat_kernel) + concat_kernel.inputs = cls.unwrap_storage(concat_kernel.inputs) + V.graph.register_operation(concat_kernel) + + return kernel + + @classmethod + def can_realize_into_without_copy( + cls, src: IRNode, dst: Optional[IRNode] = None + ) -> bool: + if isinstance(src, TensorBox): + # unwrap a TensorBox + return cls.can_realize_into_without_copy(src.data, dst) + + assert isinstance(src, (BaseView, StorageBox)), type(src) + if isinstance(src.data, MultiTemplateBuffer): + if ( + not isinstance(src.data.layout, FixedLayout) + or not src.data.output_plannable + ): + return False + + # we call 
can_realize_into_without_copy in cat lowering before we've decided + # on output format, optimistically assume layout matches + if dst is None: + return True + + # otherwise, check equality of layouts + if len(src.get_stride()) != len(dst.get_stride()): + return False + + return all( + V.graph.sizevars.statically_known_equals(s1, s2) + for s1, s2 in zip(src.get_stride(), dst.get_stride()) + ) + + return ( + hasattr(src.data, "layout") + and isinstance(src.data.layout, FlexibleLayout) + and not isinstance(src.data, ExternKernelAlloc) + ) + + @cache_on_self_and_args("ConcatKernel") + def get_free_symbol_uses( + self, unbacked_only: bool = False + ) -> OrderedSet[sympy.Symbol]: + return NopKernel.get_free_symbol_uses(self, unbacked_only) + + @classmethod + def realize_into(cls, src: IRNode, dst: IRNode) -> IRNode: + # Attempt to turn this into a ReinterpretView rather than assert. + # This has concessions around layout, as as_storage_and_layout + # can cause us to go from flexible to fixed layout. 
+ if not isinstance(dst, ReinterpretView): + if is_storage_and_layout(dst): + storage, layout = as_storage_and_layout(dst) + dst = ReinterpretView(data=storage, layout=layout) + assert isinstance(dst, ReinterpretView), type(dst) + if isinstance(src, TensorBox): + # unwrap a TensorBox + return cls.realize_into(src.data, dst) + + if isinstance(src, StorageBox): + src.realize() + # ExternKernelAlloc has specific requirements for output layout, should create a copy + assert hasattr(src.data, "layout") + if cls.can_realize_into_without_copy(src, dst): + # pyrefly: ignore [missing-attribute] + src.data.layout = NonOwningLayout(dst) + return src.data + # introduce a copy + pw = Pointwise.create( + device=src.get_device(), + dtype=src.get_dtype(), + inner_fn=src.make_loader(), + ranges=[ + V.graph.sizevars.check_equals_and_simplify(a, b) + for a, b in zip(src.get_size(), dst.get_size()) + ], + ) + return cls.realize_into(pw, dst) + + def should_allocate(self) -> bool: + return True + + +@ir_dataclass(frozen=False) +class ExternKernel(InputsKernel): + """ + A class that represents Kernels which are not directly lowered to Inductor + Loop Level IR, such as custom operators, or aten operators which we fallback to. + """ + + constant_args: Sequence[Any] = () + kwargs: dict[str, Any] = dataclasses.field(default_factory=dict) + output_view: Optional[ReinterpretView] = None + python_kernel_name: Optional[str] = None + cpp_kernel_name: Optional[str] = None + # FIXME: in some cases we sill need to explicitly pass in ordered_kwargs_for_cpp_kernel + # We shouldn't need to do this since the information can be retrieved from op_overload._schema. 
+ ordered_kwargs_for_cpp_kernel: Iterable[str] = dataclasses.field( + default_factory=list + ) + op_overload: Optional[_OpOverloads] = None + arg_properties: Optional[list[dict[str, Any]]] = None + allarg_properties: dict[str, dict[str, Any]] = dataclasses.field( + default_factory=dict + ) + kwarg_properties: Optional[dict[str, dict[str, Any]]] = None + unbacked_bindings: dict[sympy.Symbol, pytree.KeyPath] = dataclasses.field( + default_factory=dict + ) + mutation_outputs: list[MutationOutput] = dataclasses.field(default_factory=list) + + def __init__( + self, + name: Optional[str], + layout: OutputSpec, + inputs: Sequence[Union[IRNode, Sequence[IRNode]]], + constant_args: Sequence[Any] = (), + kwargs: Optional[dict[str, Any]] = None, + output_view: Optional[ReinterpretView] = None, + python_kernel_name: Optional[str] = None, + cpp_kernel_name: Optional[str] = None, + ordered_kwargs_for_cpp_kernel: Iterable[str] = (), + op_overload: Optional[_OpOverloads] = None, + ) -> None: + super().__init__( + name=name, + layout=layout, + inputs=inputs, + ) + self.constant_args = constant_args + self.kwargs = kwargs if kwargs else {} + self.output_view = output_view + self.op_overload = op_overload + self.set_cpp_kernel_name(cpp_kernel_name) + self.set_python_kernel_name(python_kernel_name) + self.ordered_kwargs_for_cpp_kernel = ordered_kwargs_for_cpp_kernel + self.collect_arg_kwarg_properties() + self.unbacked_bindings = {} + self.mutation_outputs = [] + self.fx_node = V.graph.current_node + + def get_outputs(self) -> list[Buffer]: + return [self, *self.mutation_outputs] + + def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: + return OrderedSet() + + def collect_arg_kwarg_properties(self) -> None: + # if self.op_overload is torch._ops.OpOverload, we can use its schema to collect additional + # information for args and kwargs, e.g. 
type and default value, to help with the cpp wrapper codegen + self.arg_properties = ( + [ + { + "name": x.name, + "type": x.real_type, + "default_value": x.default_value, + } + for x in self.op_overload._schema.arguments + if not x.kwarg_only + ] + if isinstance(self.op_overload, torch._ops.OpOverload) + else [{} for i in range(len(self.inputs))] + ) + self.allarg_properties = ( + { + x.name: {"type": x.real_type, "default_value": x.default_value} + for x in self.op_overload._schema.arguments + } + if isinstance(self.op_overload, torch._ops.OpOverload) + else {} + ) + # FIXME: self.kwargs does not always match kwargs defined in schema, so sometimes + # ordered_kwargs_for_cpp_kernel is explicitly passed in. + if isinstance(self.op_overload, torch._ops.OpOverload): + if not self.ordered_kwargs_for_cpp_kernel: + self.ordered_kwargs_for_cpp_kernel = [ + x.name for x in self.op_overload._schema.arguments if x.kwarg_only + ] + self.schema_kwargs = [ + x for x in self.op_overload._schema.arguments if x.kwarg_only + ] + else: + self.schema_kwargs = [] + + def decide_layout(self) -> None: + if isinstance(self.layout, FlexibleLayout): + self.apply_constraint() + self.freeze_layout() + + def codegen_comment( + self, wrapper: PythonWrapperCodegen, kernel_name: Optional[str] = None + ) -> None: + origin_str, _detailed_origin_str = get_kernel_metadata(self, wrapper) + if origin_str: + wrapper.make_comment(origin_str) + + if not kernel_name: + kernel_name = self.try_get_kernel_name() + if kernel_name: + from .debug import set_kernel_post_grad_provenance_tracing + + debug_handle = set_kernel_post_grad_provenance_tracing( + self, kernel_name, is_extern=True + ) + wrapper.write_provenance_debug_handle(kernel_name, debug_handle) + + def codegen(self, wrapper: PythonWrapperCodegen) -> None: + raise NotImplementedError + + def set_cpp_kernel_name(self, cpp_kernel_name: Optional[str] = None) -> None: + self.cpp_kernel_name = cpp_kernel_name + if not V.graph.cpp_wrapper or not 
isinstance( + self.op_overload, torch._ops.OpOverload + ): + return + + kernel = self.op_overload + if self.cpp_kernel_name is None: + # Try to construct cpp_kernel_name from op_overload + if kernel.namespace == "aten": + # Calling with the default kernel name can lead to ambiguous behavior like the following example. + # repeat_interleave(const at::Tensor & repeats, std::optional output_size=std::nullopt) + # repeat_interleave(const at::Tensor & self, int64_t repeats, + # std::optional dim=std::nullopt, std::optional output_size=std::nullopt) + opname = ( + kernel.__name__.split(".")[0] + if kernel._overloadname == "default" + else kernel.__name__.replace(".", "_") + ) + self.cpp_kernel_name = f"at::_ops::{opname}::call" + else: + self.cpp_kernel_name = kernel._schema.name + + def set_python_kernel_name(self, python_kernel_name: Optional[str]) -> None: + self.python_kernel_name = python_kernel_name + if python_kernel_name is not None: + return + + kernel = self.op_overload + if kernel is None: + pass + elif isinstance(kernel, torch._ops.HigherOrderOperator): + self.python_kernel_name = f"torch.ops.higher_order.{kernel.__name__}" + else: + self.python_kernel_name = ( + f"{kernel.__module__.replace('._ops.', '.ops.')}.{kernel.__name__}" + ) + + def try_get_kernel_name(self) -> Optional[str]: + from .codegen.cpp_wrapper_cpu import CppWrapperCpu + + device = d.type if (d := self.get_device()) else V.graph.device_type + if V.graph.fx_wrapper: + return self.python_kernel_name + elif V.graph.cpp_wrapper: + assert isinstance(V.graph.wrapper_code, CppWrapperCpu), type( + V.graph.wrapper_code + ) + if self.cpp_kernel_name is None: + return None + return V.graph.wrapper_code.get_c_shim_func_name( + self.cpp_kernel_name, device + ) + else: + return self.python_kernel_name + + def get_kernel_name(self) -> str: + name = self.try_get_kernel_name() + assert name is not None + return name + + @staticmethod + def copy_input(x: IRNode) -> Union[TensorBox, ShapeAsConstantBuffer]: + 
pw = Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=x.make_loader(), + ranges=x.get_size(), + origin_node=x.get_origin_node(), + traceback=x.get_traceback(), + ) + pw.realize() + return pw + + @classmethod + def process_kernel( + cls, kernel: _OpOverloads, *args: Any, **kwargs: Any + ) -> tuple[ + Any, + list[Any], + list[Any], + Callable[[Any, Any], Any], + Optional[dict[sympy.Symbol, pytree.KeyPath]], + ]: + binded_args = {"args": args, "kwargs": kwargs} + + args_flat, args_spec = pytree.tree_flatten(binded_args) + + is_arg_tensor = [] + # tensor_args can be either tensor or torchbind objects + tensor_args = [] + non_tensor_args: list[Any] = [] + for arg in args_flat: + is_arg_tensor.append( + isinstance(arg, IRNode) and not isinstance(arg, GeneratorState) + ) + if is_arg_tensor[-1]: + tensor_args.append(arg) + else: + if isinstance(arg, Expr): + arg = V.graph.sizevars.shape_env.create_symintnode(arg, hint=None) + non_tensor_args.append(arg) + + def unflatten_args( + new_tensor_args: Sequence[_T], new_non_tensor_args: Sequence[_T] + ) -> tuple[list[_T], dict[str, _T]]: + result = [] + it_tensors = iter(new_tensor_args) + it_non_tensors = iter(new_non_tensor_args) + for is_tensor in is_arg_tensor: + if is_tensor: + result.append(next(it_tensors)) + else: + result.append(next(it_non_tensors)) + r = pytree.tree_unflatten(result, args_spec) + return r.get("args", []), r.get("kwargs", {}) + + tensor_args = [cls.realize_input(x) for x in tensor_args] + + # freeze layout otherwise our output stride calculation might + # become incorrect + for x in tensor_args: + if is_storage_and_layout(x): + as_storage_and_layout(x, freeze=True) + + # Rerun fake tensor propagation, because Inductor may have changed the + # strides of inputs and we need to determine accurately what the + # output stride will be. 
+ example_args: list[ + Union[ + torch.Tensor, torch._C.ScriptObject, FakeScriptObject, torch.Generator + ] + ] = [] + + # We need to retain the constant values of fake tensors that we originally + # propagated the graph with, because for some operators running without a + # constant would trigger an error / DataDependentException + for x in tensor_args: + # if x is a view of a constant, we need to realize the view + # (we can't pass the constant into the kernel directly) + if not isinstance(x, BaseView) and x.get_name() in V.graph.constants: + example_args.append(V.graph.constants[x.get_name()]) + elif ( + not isinstance(x, BaseView) + and x.get_name() in V.graph.torchbind_constants + ): + example_args.append(V.graph.torchbind_constants[x.get_name()]) + elif isinstance(x, TorchBindObject): + example_args.append(x.get_value()) + elif isinstance(x, torch._inductor.ir.GeneratorState): + device_index = x.device.index + assert x.device.type == "cuda" and device_index is not None + example_args.append( + torch.cuda.default_generators[device_index].clone_state() + ) + else: + example_args.append(ir_node_to_tensor(x, guard_shape=True)) + + new_args, new_kwargs = unflatten_args(example_args, non_tensor_args) + example_output = kernel(*new_args, **new_kwargs) + + unbacked_bindings: Optional[dict[sympy.Symbol, pytree.KeyPath]] = None + if shape_env := V.fake_mode.shape_env: + node_meta_val = V.current_node.meta.get("val") + ctx: AbstractContextManager[None] = nullcontext() + if V.current_node.target is torch._higher_order_ops.effects.with_effects: + # remove the first effect token in meta["val"] and meta["unbacked_bindings"] + node_meta_val = node_meta_val[1] + ctx = _remove_effect_token_unbacked_bindings(V.current_node) + + with ctx: + rebind_unbacked(shape_env, V.current_node, example_output) + unbacked_bindings = compute_unbacked_bindings( + shape_env, example_output, node_meta_val + ) + + example_out_li = ( + [example_output] + if not isinstance(example_output, (list, 
tuple)) + else example_output + ) + # When graph_partition is enabled, skip - partitioning handles sparse outputs + for t in example_out_li: + if ( + isinstance(t, torch.Tensor) + and t.is_sparse + and not config.graph_partition + ): + msg = "sparsity not handled. Please file issue for sparse inference weights." + if stack_trace := V.graph.current_node.meta.get("stack_trace", None): + msg = f"{msg} Found from : \n {stack_trace}" + V.graph.disable_cudagraphs_reason = msg + + return ( + example_output, + tensor_args, + non_tensor_args, + unflatten_args, + unbacked_bindings, + ) + + @classmethod + def convert_to_reinterpret_view(cls, x: IRNode) -> ReinterpretView: + """ + In order to pass this to an extern kernel we need a + ReinterpretView not a View. This allows us to avoid some + unneeded copies. + """ + assert isinstance(x, BaseView), type(x) + if isinstance(x, ReinterpretView): + return x + + # NOTE: Don't use extract_read_writes here as it fails when + # make_loader() inlines the computation + x_unwrap_view = x.unwrap_view() + buf = V.graph.get_buffer(x_unwrap_view.get_name()) + assert buf is not None + x_unwrap_view_fx_node = buf.get_origin_node() + # Prefer channels last format according to how the format is set from eager. 
+ if ( + x_unwrap_view_fx_node is not None + and "val" in x_unwrap_view_fx_node.meta + and isinstance(x_unwrap_view, (ReinterpretView, Buffer, MutableBox)) + and isinstance(x_unwrap_view.layout, FlexibleLayout) + and ( + x_unwrap_view_fx_node.meta["val"].is_contiguous( + memory_format=torch.channels_last + ) + or x_unwrap_view_fx_node.meta["val"].is_contiguous( + memory_format=torch.channels_last_3d + ) + ) + ): + x_unwrap_view.freeze_layout_with_same_order( + make_channels_last_strides_for(x_unwrap_view.get_size()) + ) + else: + x_unwrap_view.freeze_layout() + + index_args, var_ranges = dependencies.index_vars_squeeze( + x.get_size(), prefix="r" + ) + range_vars = index_args[0] + index = x.make_indexer()(range_vars) + + index = V.graph.sizevars.simplify_with_ranges(index, var_ranges) + strides = V.graph.sizevars.stride_vars(index, range_vars) + offset = V.graph.sizevars.offset_var(index, range_vars) + expected = sympy_dot(range_vars, strides) + offset + + if index != expected: + log.debug( + "convert_to_reinterpret_view failed: stride=%s offset=%s index=%s", + strides, + offset, + index, + ) + raise NotImplementedError + + return ReinterpretView( + data=x.data, + layout=FixedLayout( + device=x.get_device_or_error(), + dtype=x.get_dtype(), + size=x.get_size(), + stride=strides, + offset=offset, + is_pinned=False, + ), + ) + + @classmethod + def realize_input(cls, x: IRNode) -> IRNode: + if x is None: + return NoneAsConstantBuffer() + if isinstance(x, (Expr, sympy.logic.boolalg.Boolean, int)): + return ShapeAsConstantBuffer(expr=x) + if isinstance(x, Constant): + # We need to unset fake mode, or else the torch.tensor() call will + # turn into a FakeTensor + with _disable_current_modes(): + return V.graph.add_tensor_constant( + torch.tensor(x.value, dtype=x.get_dtype(), device=x.get_device()) + ) + if isinstance(x, ConstantBuffer): + return x + if isinstance(x, TensorBox): + return cls.realize_input(x.data) + if isinstance(x, ReinterpretView): + return 
ReinterpretView( + data=cls.realize_input(x.data), layout=x.get_layout() + ) + if isinstance(x, BaseView): + x.realize() + if is_storage_and_layout(x.unwrap_view()): + try: + return cls.convert_to_reinterpret_view(x) + except NotImplementedError: + pass + if isinstance(x, StorageBox): + # TODO(jansel): impose layout preference on realized buffer + x.realize() + return x + if isinstance(x, (NonTensorObj, ShapeAsConstantBuffer)): + return x + return cls.copy_input(x) + + @classmethod + def require_stride1(cls, x: IRNode) -> IRNode: + if is_storage_and_layout(x): + if len(x.get_stride()) == 0: + return x + for stride in x.get_stride(): + if stride == 1: + return x + return cls.copy_input(x) + + @classmethod + def require_strides( + cls, + x: IRNode, + order: Optional[Sequence[int]] = None, + exact_strides: Optional[Sequence[_IntLike]] = None, + allow_padding: bool = False, + ) -> IRNode: + assert order is not None or exact_strides is not None + # Layout generally doesn't matter, but some consuming external ops might have requirements + if x.get_numel() in (0, 1) and not exact_strides: + return x + + # require x to have the layout + if is_storage_and_layout(x): + if isinstance(x.get_layout(), FlexibleLayout): + if order: + # If the FlexibleLayout already has the size and stride in the required order, + # freeze it to a FixedLayout by using its current size and stride. + # The behavior of using its current size and stride or the given order can be different + # if the size and stride has ambiguilty, for example for a 4D input where the iC = 1: + # size=[s0, 1, 28, 28], stride=[784, 784, 28, 1]. If the required order is [3, 0, 2, 1] (channels last), + # the current size and stride already satisfies this order. + # However by freezing it to the required order, the layout will be changed to: + # size=[s0, 1, 28, 28], stride=[784, 1, 28, 1]), which is not actually necessary. 
+ use_current_stride_order = is_stride_order_storage_and_layout( + x, order + ) and not free_unbacked_symbols(x.get_layout().stride) + # fix flexiblelayout to be FixedLayout with stride_order + as_storage_and_layout( + x, + freeze=True, + want_contiguous=False, + stride_order=( + get_stride_order( + V.graph.sizevars.size_hints_or_throw( + x.get_layout().stride + ) + ) + if use_current_stride_order + else order + ), + allow_padding=allow_padding, + ) + return x + else: + # If the exact_strides is given, freeze the FlexibleLayout to a FixedLayout with the exact_strides. + as_storage_and_layout( + x, + freeze=True, + want_contiguous=False, + stride_order=None, + allow_padding=allow_padding, + exact_strides=exact_strides, + ) + return x + elif isinstance(x.get_layout(), (FixedLayout, NonOwningLayout)) and ( + (order and x.get_layout().is_stride_ordered(order)) + or ( + exact_strides + and significant_strides_equal( + exact_strides, x.get_layout().stride, x.get_size() + ) + ) + ): + return ( + try_match_insignificant_strides(x, exact_strides) + if exact_strides is not None + else x + ) + elif isinstance( + (mutation_layout := x.get_layout()), MutationLayoutSHOULDREMOVE + ): + if isinstance( + (real_layout := mutation_layout.real_layout()), FlexibleLayout + ): + raise AssertionError( + "the MutationLayoutSHOULDREMOVE's real layout shouldn't be FlexibleLayout" + ) + elif isinstance(real_layout, FixedLayout) and ( + (order and real_layout.is_stride_ordered(order)) + or ( + exact_strides + and significant_strides_equal( + exact_strides, real_layout.stride, x.get_size() + ) + ) + ): + return x + + # TODO - Storage to InputBuffer + if isinstance(x, InputBuffer) and ( + (order and x.get_layout().is_stride_ordered(order)) + or ( + exact_strides + and significant_strides_equal( + exact_strides, x.get_layout().stride, x.get_size() + ) + ) + ): + return x + if ( + isinstance(x, TensorBox) + and isinstance(x.data, BaseView) + and not isinstance(x.data, ReinterpretView) + and 
is_storage_and_layout(unwrap_view := x.unwrap_view()) + and hasattr(unwrap_view, "data") + and not isinstance(unwrap_view.data, ExternKernelAlloc) + ): + try: + x.data = cls.convert_to_reinterpret_view(x.data) + if order: + return cls.require_stride_order( + x, order, allow_padding=allow_padding + ) + elif exact_strides: + return cls.require_exact_strides( + x, exact_strides, allow_padding=allow_padding + ) + except NotImplementedError: + pass + + # Preserve ExpandView representation that would be lost during copy_input + # Without representation of the expand in inductor IR, in codegen we end up + # launching a grid for the full size tensor and doing redundant computation + # across expanded dims. + # TODO: could also be good to have a codegen fix to recognize overlapping elements + + expanded_dims: Optional[list[int]] = None + orig_size = x.get_size() + if exact_strides is not None: + sizevars = V.graph.sizevars + expanded_dims = [ + i + for i in range(len(x.get_size())) + if sizevars.statically_known_equals(exact_strides[i], 0) + and sizevars.statically_known_geq(x.get_size()[i], 2) + ] + + for dim in expanded_dims: + x = torch._inductor.lowering.slice_(x, dim, 0, 1) + + # Although this is a clone, inductor is good about fusing clones into previous + # operations if they weren't realized and their layouts were flexible. 
+ x = cls.copy_input(x) + + as_storage_and_layout( + x, + freeze=True, + want_contiguous=False, + stride_order=order, + allow_padding=allow_padding, + exact_strides=exact_strides, + ) + if order: + assert is_stride_order_storage_and_layout(x, order) + elif expanded_dims: + assert orig_size is not None and exact_strides is not None + x = torch._inductor.lowering.expand(x, orig_size) + # the expand will sometimes may change insignificant strides, so match them back + return try_match_insignificant_strides(x, exact_strides) + + return x + + @classmethod + def require_exact_strides( + cls, x: IRNode, exact_strides: Sequence[_IntLike], allow_padding: bool = False + ) -> IRNode: + return cls.require_strides( + x, exact_strides=exact_strides, allow_padding=allow_padding + ) + + @classmethod + def require_stride_order( + cls, x: IRNode, order: Sequence[int], allow_padding: bool = False + ) -> IRNode: + return cls.require_strides(x, order=order, allow_padding=allow_padding) + + @classmethod + def require_channels_last(cls, x: IRNode) -> IRNode: + return cls.require_stride_order(x, NHWC_STRIDE_ORDER) + + @classmethod + def require_channels_last_3d(cls, x: IRNode) -> IRNode: + return cls.require_stride_order(x, NHWDC_STRIDE_ORDER) + + @classmethod + def require_contiguous(cls, x: IRNode) -> IRNode: + def is_mkldnn_tensor(x: IRNode) -> bool: + try: + name = x.get_name() + except (AttributeError, NotImplementedError): + return False + + return name in V.graph.constants and V.graph.constants[name].is_mkldnn + + # TODO move this to the more proper places + if is_mkldnn_tensor(x): + return x + else: + return cls.require_exact_strides( + x, FlexibleLayout.contiguous_strides(x.get_size()) + ) + + @classmethod + def require_contiguous_strides(cls, x: IRNode) -> IRNode: + # TODO: combine this with require_contiguous after + # https://github.com/pytorch/pytorch/pull/148235 lands. 
+ return cls.require_exact_strides( + x, FlexibleLayout.contiguous_strides(x.get_size()) + ) + + def apply_constraint(self) -> None: + pass + + def fill_non_provided_args( + self, args: Sequence[Any], kwargs: dict[str, Any] + ) -> Sequence[Any]: + # Previously, we want to maintain forward-compatibility by skipping + # default args in the serialized artifacts in fbcode. However, + # some of our shim interfaces require default values being OrderedSet. + # Discussed with Sherlock offline and we decided to allow serializing + # default args into the C++ wrapper code for now. We will refine this + # part if we see real FC requirement. More details related to FC + # can be found at: + # https://docs.google.com/document/d/1FzWm-sHYwmRi3x_g036kOxd99KaYquUsA-L5JwOn8ys/edit?usp=sharing + assert isinstance(args, Sequence), type(args) + if not isinstance(args, list): + args = list(args) + assert self.arg_properties, "ExternKernel.arg_properties should not be empty" + + n_args = len(args) + n_pos_args = len(self.arg_properties) + # For cpp wrapper, if some positional args are not provided, we need to check + # if they're in the kwargs or use their default value + if n_args < n_pos_args: + log.debug( + "%s has %d unprovided positional arguments. " + "Will check if they are in the keyword arguments or will use default values.", + self.op_overload, + n_pos_args - n_args, + ) + for i in range(n_args, n_pos_args): + arg_name = self.arg_properties[i]["name"] + args.append( + kwargs[arg_name] + if arg_name in kwargs + else self.arg_properties[i]["default_value"] + ) + return args + + def codegen_const_args(self, names: Optional[list[str]] = None) -> list[str]: + if V.graph.cpp_wrapper: + result = [] + # Aten ops follow the convention that tensor args are before non-tensor args, + # in which case the following 'len(self.inputs) + i' logic works. 
But this + # may not be true for other ops, and if that is the case, caller needs to + # pass in a list of const arg names for arg_properties lookup. + name_to_arg_properties = None + if names and self.arg_properties: + assert len(self.constant_args) == len(names), ( + "names passed to codegen_const_args does not match self.constant_args" + ) + name_to_arg_properties = { + arg.get("name"): arg for arg in self.arg_properties + } + + for i, x in enumerate(self.constant_args): + if name_to_arg_properties is not None: + assert names is not None + prop = name_to_arg_properties.get(names[i]) + type_ = prop.get("type") if prop else None + else: + idx = len(self.inputs) + i + type_ = ( + self.arg_properties[idx].get("type") + if self.arg_properties and idx < len(self.arg_properties) + else None + ) + result.append(V.graph.wrapper_code.val_to_arg_str(x, type_)) + return result + else: + return [V.graph.wrapper_code.val_to_arg_str(a) for a in self.constant_args] + + def codegen_args(self) -> list[str]: + if V.graph.cpp_wrapper and self.op_overload is not None: + # cpp wrapper needs special logic to fill in missing args with default values + inputs = self.fill_non_provided_args( + [*self.inputs, *self.constant_args], self.kwargs + ) + # fill_non_provided_args has handled constant args, so no need to codegen for that later + need_codegen_constant_args = False + else: + inputs = self.inputs + need_codegen_constant_args = True + + args = [] + for i, x in enumerate(inputs): + if V.graph.cpp_wrapper: + assert self.arg_properties and i < len(self.arg_properties), ( + "Invalid access to ExternKernel.arg_properties" + ) + type_ = self.arg_properties[i].get("type") + args.append(V.graph.wrapper_code.val_to_arg_str(x, type_)) + else: + args.append(V.graph.wrapper_code.val_to_arg_str(x)) + if need_codegen_constant_args: + args.extend(self.codegen_const_args()) + return args + + def get_kwargs_value(self, arg_name: str, **kwargs: Any) -> Any: + """Given an argument name, queries for 
values in (in order): + 1. any provided kwargs for this function. + 2. the class self.kwargs member. + 3. any available default arguments in self.allarg_properties.""" + if arg_name in kwargs: + return kwargs.get(arg_name) + if arg_name in self.kwargs: + return self.kwargs.get(arg_name) + if (arg := self.allarg_properties.get(arg_name)) is not None: + return arg.get("default_value") + raise AssertionError(f"{arg_name} not in self.allarg_properties") + + def codegen_kwargs(self, skip_out: bool = False) -> list[str]: + if V.graph.cpp_wrapper: + if self.op_overload is not None and len(self.schema_kwargs) == 0: + # All the args should have been generated by fill_non_provided_args in codegen_args + return [] + + kwargs = [] + for arg_name in self.ordered_kwargs_for_cpp_kernel: + if skip_out and arg_name == "out": + # ExternKernelOut has its own logic for inserting the out parameter + continue + + v = self.get_kwargs_value(arg_name) + if isinstance(v, Expr): + kwargs.append(v) + else: + assert self.allarg_properties is not None + type_ = self.allarg_properties.get(arg_name, {}).get("type") + kwargs.append(V.graph.wrapper_code.val_to_arg_str(v, type_)) + else: + kwargs = [ + f"{k}={V.graph.wrapper_code.val_to_arg_str(v)}" + for k, v in self.kwargs.items() + ] + return kwargs + + def get_op_name(self) -> str: + if self.fx_node is not None: + target = self.fx_node.target + op_namespace = getattr(target, "__module__", "unknown_namespace") + op_namespace = op_namespace.replace("._ops.", ".ops.") + op_namespace = op_namespace.rsplit(".", 1)[0] + op_name = f"{op_namespace}.{target}" + else: + op_name = "unknown_op" + return op_name + + def codegen_size_asserts(self, wrapper: PythonWrapperCodegen) -> None: + if config.size_asserts and not V.graph.cpp_wrapper: + # comparing strides for 0 size tensor is tricky. Ignore them for now. 
+ if sympy_product(self.get_size()) == 0: + return + size = V.graph.wrapper_code.codegen_shape_tuple(self.get_size()) + stride = V.graph.wrapper_code.codegen_shape_tuple(self.get_stride()) + op_name = self.get_op_name() + wrapper.writeline( + f"assert_size_stride({self.get_name()}, {size}, {stride}, {op_name!r})" + ) + + def codegen_alignment_asserts(self, wrapper: PythonWrapperCodegen) -> None: + if config.alignment_asserts and not V.graph.cpp_wrapper: + name = self.get_name() + aligned = name not in V.graph.unaligned_buffers + op_name = self.get_op_name() + if aligned: + wrapper.writeline( + f"assert_alignment({name}, {GPU_ALIGN_BYTES}, {op_name!r})" + ) + else: + wrapper.writeline( + f"# buffer {name} (op: {op_name}) is assumed to be not aligned" + ) + + def codegen_memory_tracking(self, wrapper: PythonWrapperCodegen) -> None: + """ + Track outputs of fallback operators if config.test_configs.track_memory_lifecycle + """ + if not config.test_configs.track_memory_lifecycle or V.graph.cpp_wrapper: + return + + wrapper.write_memory_track_allocation_once() + name = self.get_name() + wrapper.writeline(f"track_tensor({name}, '{name}')") + + def get_group_stride(self) -> tuple[list[Sequence[Expr]], list[Expr]]: + """ + get output sizes and strides, for template_codegen + """ + _size = self.get_size() + _stride = self.get_stride() + # iter_ranges = _size of output tensor, reduce_range = [] because no reduction + return [_size, []], _stride + + def canonicalize(self) -> tuple[Expr, Sequence[Expr]]: + """ + Manually get canonicalization of the output index + """ + # manually generate index formula for conv + sizevars = V.graph.sizevars + sizes = self.get_size() + strides = self.get_stride() + strides = [sizevars.size_hint(x) for x in strides] + # TODO: I can't tell if the symbols here are temporary + index_vars = [sympy_index_symbol(f"d{i}") for i in range(len(sizes))] + # reorder index vars according to stride + index_order = sorted(range(len(strides)), 
key=strides.__getitem__, reverse=True) + lookup = {pos: idx for idx, pos in enumerate(index_order)} + order = [lookup[i] for i in range(len(lookup))] + index_vars = [index_vars[i] for i in order] + indexer = self.make_indexer() + index = indexer(index_vars) + + new_sizes, reindex, _prune = V.graph.sizevars._simplify_loops( + index_vars, sizes, [index] + ) + + # assign new variables each dimension to deal with numbering mismatches + # d0, d1, d2 could become d0, d2 -- which won't match d0, d1 + _, add_var = var_builder("c") + replacement = dict(zip(index_vars, reindex([add_var(x) for x in new_sizes]))) + + index = sympy_subs(sympy.expand(index), replacement) + return index, tuple(new_sizes) + + @cache_on_self_and_args("ExternKernel") + def get_free_symbol_uses( + self, unbacked_only: bool = False + ) -> OrderedSet[sympy.Symbol]: + # NB: It's not necessary to check regular inputs as we automatically + # have dependencies on them + maybe_get_symbols = ( + maybe_free_unbacked_symbols if unbacked_only else maybe_free_symbols + ) + r = InputsKernel.get_free_symbol_uses(self, unbacked_only) + for arg in self.constant_args: + r |= maybe_get_symbols(arg) + for arg in self.kwargs.values(): + r |= maybe_get_symbols(arg) + return r + + def __str__(self) -> str: + kernel_name = getattr(self, "python_kernel_name", None) + lines = [ + f"python_kernel_name={kernel_name!r}", + ] + lines += [ + f"{field.name}={getattr(self, field.name)}" + for field in dataclasses.fields(self) + ] + lines.append(f"origin_node={self.origin_node!r}") + return self.str_helper(lines) + + __repr__ = __str__ + + +@ir_dataclass(frozen=False) +class ExternKernelOut(ExternKernel): + def codegen(self, wrapper: PythonWrapperCodegen) -> None: + wrapper.generate_extern_kernel_out(self) + + def __init__( + self, + layout: Layout, + inputs: Sequence[IRNode], + constant_args: Sequence[Any] = (), + kwargs: Optional[dict[str, Any]] = None, + output_view: Optional[ReinterpretView] = None, + python_kernel_name: 
Optional[str] = None, + cpp_kernel_name: Optional[str] = None, + ordered_kwargs_for_cpp_kernel: Sequence[Any] = (), + op_overload: Optional[_OpOverloads] = None, + ) -> None: + unwrapped_inputs = self.unwrap_storage(inputs) + assert isinstance(unwrapped_inputs, Sequence), type(unwrapped_inputs) + super().__init__( + None, + layout, + unwrapped_inputs, + constant_args, + kwargs or {}, + None, + python_kernel_name, + cpp_kernel_name, + ordered_kwargs_for_cpp_kernel, + op_overload, + ) + self.name = V.graph.register_buffer(self) + V.graph.register_operation(self) + + def should_allocate(self) -> bool: + return True + + +class RandomSeeds(ExternKernelOut): + def __init__(self, count: int, device: torch.device) -> None: + limits = torch.iinfo(torch.int64) + super().__init__( + layout=FixedLayout( + device=device, + dtype=torch.int64, + size=[count], + ), + inputs=[], + constant_args=[limits.min, limits.max, [count]], + python_kernel_name="aten.randint.low_out", + # FIXME: Ideally we should only use at::_ops::randint_low_out::call here, + # but the signature is different from is at::randint_out. Again, + # we can simplify the code when only keeping an ABI-compatible version. 
+ cpp_kernel_name="at::_ops::randint_low_out::call", + op_overload=aten.randint.low_out, + ) + + +class ExternKernelAlloc(ExternKernel): + def codegen(self, wrapper: PythonWrapperCodegen) -> None: + wrapper.generate_extern_kernel_alloc(self) + + def __init__( + self, + layout: OutputSpec, + inputs: Sequence[IRNode], + constant_args: Sequence[Any] = (), + kwargs: Optional[dict[str, Any]] = None, + python_kernel_name: Optional[str] = None, + cpp_kernel_name: Optional[str] = None, + ordered_kwargs_for_cpp_kernel: Sequence[Any] = (), + op_overload: Optional[_OpOverloads] = None, + ) -> None: + unwrapped_inputs = self.unwrap_storage(inputs) + assert all(isinstance(i, IRNode) for i in unwrapped_inputs) + super().__init__( + None, + layout, + cast(Sequence[IRNode], unwrapped_inputs), + constant_args, + kwargs or {}, + None, + python_kernel_name, + cpp_kernel_name, + ordered_kwargs_for_cpp_kernel, + op_overload, + ) + # We need output buffers for generating kernel arguments in the + # abi-compatible mode, where we retrieve outputs by pass each individual + # output through the abi-compatible interface. 
+ self.outputs: Sequence[Any] = [] + self.name = V.graph.register_buffer(self) + V.graph.register_operation(self) + + def should_allocate(self) -> bool: + return False + + def apply_constraint(self) -> None: + raise NotImplementedError + + +class MutationOutput(Buffer): + """ + An output buffer that represents the mutation of a pre-existing buffer + """ + + def __init__( + self, layout: OutputSpec, mutated_node: IRNode, mutating_node: Operation + ) -> None: + super().__init__(name=None, layout=layout) + mutated_node_name = mutated_node.get_name() + V.graph.mark_buffer_mutated(mutated_node_name) + self.mutation_names = [mutated_node_name] + self.mutating_node: Operation = mutating_node + self.name = V.graph.register_buffer(self) + + def get_defining_op(self) -> Operation: + return self.mutating_node + + def get_mutation_names(self) -> Sequence[str]: + return self.mutation_names + + def should_allocate(self) -> bool: + return False + + def get_mutation_buffers(self) -> Sequence[IRNode]: + mutation_names = self.get_mutation_names() + return [ + buf + for buf in (V.graph.try_get_buffer(name) for name in mutation_names) + if buf is not None + ] + + +class TMADescriptor(ExternKernel): + """ + An IR node representing a generic host-side TMA descriptor in the Triton API + Mostly useful for user-defined Triton kernels relying on host-side TMA; + but can, in principle, be used for Inductor's Triton templates, too. 
+ + See TMADescriptorExperimental and TMADescriptorStable for the two implementations + (the old API and the new API) + """ + + # as TMA descriptors are immutable, + # we can dedup them by the input args + _CACHE: dict[Any, TMADescriptor] = {} + + @classmethod + def _create_impl( + cls, tensor: IRNode, tma_meta: tuple[str, tuple[Any, ...]] + ) -> TMADescriptor: + assert len(tma_meta) == 2 + if tma_meta[0] == "experimental": + return TMADescriptorExperimental(tensor, *tma_meta[1]) + else: + assert tma_meta[0] == "stable" + return TMADescriptorStable(tensor, *tma_meta[1]) + + @classmethod + def create( + cls, tensor: IRNode, tma_meta: tuple[str, tuple[Any, ...]] + ) -> TMADescriptor: + key = (id(tensor), tma_meta) + if key not in cls._CACHE: + cls._CACHE[key] = cls._create_impl(tensor, tma_meta) + return cls._CACHE[key] + + def __init__( + self, tensor: IRNode, inputs: Sequence[Any], constant_args: Sequence[Any] + ) -> None: + super().__init__( + None, + # link back to the underlying tensor in terms of ownership + # to avoid getting the underlying tensor deleted *before* + # the TMADescriptor node can be deleted. + NonOwningLayout( + ReinterpretView( + data=tensor, + layout=tensor.get_layout(), + ) + ), + cast(Sequence[Buffer], inputs), + tuple(constant_args), + None, + ) + + self.tensor = tensor + self.name = V.graph.register_buffer(self) + V.graph.register_operation(self) + + def codegen(self, wrapper: PythonWrapperCodegen) -> None: + wrapper.generate_tma_descriptor(self) + + def get_tensor(self) -> IRNode: + return self.tensor + + +class TMADescriptorExperimental(TMADescriptor): + """ + the new host-side TMA Descriptor API: + (the ones obtained via create_{1d,2d}_tma_descriptor calls). + + See also TMADescriptorStable for the new API. 
+ """ + + def __init__( + self, + tensor: IRNode, + dims: list[Union[int, torch.SymInt]], + block_dims: list[Union[int, torch.SymInt]], + element_size: Optional[int] = None, + ) -> None: + assert len(dims) in (1, 2) + assert len(dims) == len(block_dims) + + if element_size is None: + element_size = tensor.get_dtype().itemsize + + self.dims = dims + self.block_dims = block_dims + self.element_size = element_size + self.rank = len(self.dims) + + inputs = [tensor] + constant_args = [ + *self.dims, + *self.block_dims, + self.element_size, + ] + + super().__init__( + tensor=tensor, + inputs=inputs, + constant_args=constant_args, + ) + + +class TMADescriptorStable(TMADescriptor): + """ + the new host-side TMA descriptor API + (the ones obtained via TensorDescriptor.from_tensor). + + See also TMADescriptorExperimental for the old API. + """ + + def __init__(self, tensor: IRNode, block_shape: list[Union[int, torch.SymInt]]): + self.block_shape = block_shape + + super().__init__( + tensor=tensor, + inputs=[tensor], + constant_args=block_shape, + ) + + +class SubgraphBuffer(ExternKernel): + def __init__( + self, + layout: Layout, + input_nodes: list[Buffer], + gm: torch.fx.GraphModule, + example_inputs: list[Any], + subgraph_name: str, + ): + super().__init__(None, layout, input_nodes) + self.gm = gm + self.example_inputs = example_inputs + self.name = V.graph.register_buffer(self) + V.graph.register_operation(self) + + self.subgraph = V.graph.make_subgraph(self.gm, example_inputs, subgraph_name) + + assert is_node_sequence(self.inputs) + sym_inputs = get_symbolic_inputs(self.inputs) + + for sym_inp in sym_inputs: + self.subgraph.graph_inputs[sym_inp.name] = sym_inp + self.subgraph.graph_input_names.append(sym_inp.name) + + self.sym_inputs = [sym_var.name for sym_var in sym_inputs] + + import torch._inductor.config as inductor_config + + with V.set_graph_handler(self.subgraph): + # Don't bother autotuning on Triton here + with inductor_config.patch( + max_autotune=False, + 
max_autotune_gemm=False, + max_autotune_gemm_backends="ATEN", + ): + self.subgraph.run(*self.example_inputs) + + def codegen(self, wrapper: PythonWrapperCodegen) -> None: + class CodegenGraph: + def __init__(self, graph: GraphLowering): + self.graph = graph + self.name = graph.name + + assert is_node_sequence(self.inputs) + outer_inputs = [t.codegen_reference() for t in self.inputs] + wrapper.codegen_subgraph_with_flattened_outputs( + CodegenGraph(self.subgraph), + [*self.sym_inputs, *outer_inputs], + [self.name], + ) + + +class UserDefinedTritonKernel(ExternKernel): + def get_kernel_and_metadata(self) -> tuple[Kernel, Any, list[str], list[str]]: + from triton.runtime.autotuner import Autotuner + + from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table + + kernel = kernel_side_table.get_kernel(self.kernel_idx) + configs = [] + restore_value_args: list[str] = [] + reset_to_zero_args: list[str] = [] + if isinstance(kernel, Autotuner): + # https://github.com/triton-lang/triton/pull/5083 + # changes kernel.restore_idx to kernel.restore_value + if hasattr(kernel, "restore_idx"): + restore_value_args.extend( + kernel.fn.arg_names[i] for i in kernel.restore_idx + ) + else: + assert hasattr(kernel, "restore_value") + restore_value_args.extend(kernel.restore_value) + + if hasattr(kernel, "reset_idx"): + for i in kernel.reset_idx: + reset_to_zero_args.append(kernel.fn.arg_names[i]) + else: + assert hasattr(kernel, "reset_to_zero") + reset_to_zero_args.extend(kernel.reset_to_zero) + + configs = kernel.configs + kernel = kernel.fn + # pyrefly: ignore # bad-return + return kernel, configs, restore_value_args, reset_to_zero_args + + @override + def codegen(self, wrapper: PythonWrapperCodegen) -> None: + """Overrides the parent member. 
+ See https://github.com/pytorch/pytorch/issues/151692""" + + from torch._inductor.utils import triton_version_uses_attrs_dict + + ( + kernel, + configs, + restore_value_args, + reset_to_zero_args, + ) = self.get_kernel_and_metadata() + + # Definition of kernel + ( + new_name, + triton_meta, + extra_launch_args, + ) = wrapper.define_user_defined_triton_kernel( + kernel, + configs, + self.kwargs, + restore_value_args, + reset_to_zero_args, + self.grid, + ) + named_args = { + k: self.get_kwargs_value(k) for k in self.ordered_kwargs_for_cpp_kernel + } + arg_names = [p.name for p in kernel.params] # type: ignore[attr-defined] + constexprs = [p.num for p in kernel.params if p.is_constexpr] # type: ignore[attr-defined] + constexpr_names = OrderedSet(arg_names[i] for i in constexprs) + + args: list[Any] = [] + arg_types: list[Any] = [] + raw_keys_filtered: list[Any] = [] + raw_args_filtered: list[Any] = [] + for name, arg in itertools.chain( + named_args.items(), zip(itertools.repeat(""), extra_launch_args) + ): + if name in constexpr_names and triton_version_uses_attrs_dict(): + # see #160000 - we don't pass in constexpr args to speed up runtime. + continue + raw_keys_filtered.append(name) + raw_args_filtered.append(arg) + if isinstance(arg, IRNode): + args.append(arg.codegen_reference()) + arg_types.append(arg.get_dtype()) + elif isinstance(arg, (int, float, bool, sympy.Expr)): + args.append(arg) + arg_types.append(type(arg)) + elif name in constexpr_names: + # insert a dummy value for constexpr args of unsupported type + # constexprs will end up getting baked into the kernel at compile time + args.append(-1) + arg_types.append(int) + elif arg is None: + """ + Filter out None args. + + see https://github.com/pytorch/pytorch/issues/115344 + + Two cases for a None arg: + 1. The arg is already tl.constexpr, so leave it in + 2. 
The arg is not tl.constexpr so we have to remove it + """ + if triton_version_uses_attrs_dict(): + args.append(-1) + arg_types.append(int) + else: + raw_keys_filtered.pop() + raw_args_filtered.pop() + else: + raise NotImplementedError(f"Unsupported arg type: {type(arg)}: {arg}") + + self.codegen_comment(wrapper, new_name) + wrapper.generate_kernel_call( + new_name, + args, + arg_types=arg_types, + raw_args=raw_args_filtered, + raw_keys=raw_keys_filtered, + triton_meta=triton_meta, + triton=True, + device=self.get_device(), + original_fxnode_name=self.fx_node.name, + ) + + @cache_on_self_and_args("UserDefinedTritonKernel") + def get_free_symbol_uses( + self, unbacked_only: bool = False + ) -> OrderedSet[sympy.Symbol]: + # add unbacked symbols used in the grid to the ones used + # in the kwargs (the latter is generated by ExternKernel) + return super().get_free_symbol_uses(unbacked_only) | get_free_symbols( + self.grid, unbacked_only + ) + + def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: + return OrderedSet() + + def __init__( + self, + *, + kernel_idx: int, + grid: Any, + tma_descriptor_metadata: dict[str, Any], + kernel_args: dict[str, Any], + ) -> None: + inputs: list[IRNode] = [] + kwargs: dict[str, IRNode] = {} + constant_args: list[IRNode] = [] + + for k, v in kernel_args.items(): + if isinstance(v, TensorBox): + t = InputsKernel.unwrap_storage_for_input(self.realize_input(v)) + if k in tma_descriptor_metadata: + t = TMADescriptor.create(t, tma_descriptor_metadata[k]) + inputs.append(t) + kwargs[k] = t + else: + constant_args.append(v) + kwargs[k] = v + + assert len(inputs) != 0 + self.device = inputs[0].get_device() + + assert isinstance(inputs, Sequence), type(inputs) + super().__init__( + None, + NoneLayout(device=self.device), + inputs, + tuple(constant_args), + kwargs, + ) + self.kernel_idx = kernel_idx + self.grid = grid + + kernel, configs, _, _ = self.get_kernel_and_metadata() + + # If we are autotuning, not all arguments will be passed 
+ assert hasattr(kernel, "arg_names") + self.ordered_kwargs_for_cpp_kernel = [ + arg for arg in kernel.arg_names if arg in kernel_args + ] + + from torch._higher_order_ops.triton_kernel_wrap import identify_mutated_tensors + + autotuned_kwargs = configs[0].kwargs if len(configs) > 0 else {} + self.mutable_args = [ + kernel_args[key] + for key in identify_mutated_tensors( + # pyrefly: ignore # bad-argument-type + kernel, + {**kernel_args, **autotuned_kwargs}, + tma_descriptor_metadata, + ) + ] + + self.mutation_outputs = [ + MutationOutput(NoneLayout(device=self.device), buf, self) + for buf in self.mutable_args + ] + V.graph.register_operation(self) + + def get_outputs(self) -> list[Buffer]: + return list(self.mutation_outputs) + + def get_device(self) -> Optional[torch.device]: + return self.device + + +class InplaceBernoulliFallback(ExternKernel): + """ + This needs to be a custom class to handle mutation properly + """ + + def codegen(self, wrapper: PythonWrapperCodegen) -> None: + assert all(isinstance(t, IRNode) for t in self.inputs) + (x,) = (cast(IRNode, t).codegen_reference() for t in self.inputs) + + if V.graph.cpp_wrapper: + # Inductor doesn't really support aten Generator, so the Generator kwarg is always NULL here, + # which needs to be explicitly generated for cpp wrapper + wrapper.writeline( + f"{self.get_kernel_name()}({x}, {', '.join(map(repr, self.constant_args))}, NULL){wrapper.ending}" + ) + else: + wrapper.writeline( + f"{self.get_kernel_name()}({x}, {', '.join(map(repr, self.constant_args))}){wrapper.ending}" + ) + + def should_allocate(self) -> bool: + return False + + def get_mutation_names(self) -> Sequence[str]: + return [self.input_name(0)] + + def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: + return OrderedSet() + + def __init__( + self, op_overload: _OpOverloads, x: IRNode, *constant_args: Any + ) -> None: + super().__init__( + None, + NoneLayout(device=x.get_device()), + self.unwrap_storage([x]), + constant_args, + 
op_overload=op_overload, + ) + V.graph.mark_buffer_mutated(x.get_name()) + self.name = V.graph.register_buffer(self) + V.graph.register_operation(self) + + +# Used to deal with torch.complex types +class InplaceCopyFallback(ExternKernel): + """ + This needs to be a custom class to handle mutation properly + """ + + def codegen(self, wrapper: PythonWrapperCodegen) -> None: + (dst, src, non_blocking) = self.codegen_args() + wrapper.codegen_device_copy(src, dst, non_blocking) + + def should_allocate(self) -> bool: + return False + + def get_mutation_names(self) -> Sequence[str]: + return [self.input_name(0)] + + def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: + return OrderedSet() + + def __init__( + self, + layout: OutputSpec, + inputs: Sequence[IRNode], + constant_args: Sequence[Any], + ) -> None: + super().__init__( + None, + layout, + inputs, + constant_args, + python_kernel_name="aten.copy_", + cpp_kernel_name="aoti_torch_copy_", + ) + V.graph.mark_buffer_mutated(inputs[0].get_name()) + self.name = V.graph.register_buffer(self) + V.graph.register_operation(self) + + @classmethod + def create( + cls, dst: IRNode, src: IRNode, non_blocking: bool = False + ) -> InplaceCopyFallback: + inputs = [cls.realize_input(t) for t in [dst, src]] + constant_args = (non_blocking,) + result = InplaceCopyFallback( + NoneLayout(device=dst.get_device()), + inputs, + constant_args, + ) + return result + + +class MutatingFirstArgExternKernel(ExternKernel): + """ + This needs to be a custom class to handle mutation properly + """ + + def codegen(self, wrapper: PythonWrapperCodegen) -> None: + assert is_node_sequence(self.inputs) + argrefs = [ + *(t.codegen_reference() for t in self.inputs), + *map(repr, self.constant_args), + ] + wrapper.writeline( + f"{self.get_kernel_name()}({', '.join(argrefs)}){wrapper.ending}" + ) + + def should_allocate(self) -> bool: + return False + + def get_mutation_names(self) -> Sequence[str]: + return [self.input_name(0)] + + def 
get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: + return OrderedSet() + + def has_side_effects(self) -> bool: + return True + + +class ResizeStorageBytes(MutatingFirstArgExternKernel): + def __init__(self, variable: IRNode, new_size: int) -> None: + assert isinstance(new_size, int), "TODO: dynamic shapes" + super().__init__( + None, + NoneLayout(device=variable.get_device()), + self.unwrap_storage([variable]), + constant_args=(new_size,), + ) + V.graph.mark_buffer_mutated(variable.get_name()) + self.name = V.graph.register_buffer(self) + V.graph.register_operation(self) + self.python_kernel_name = "inductor_ops.resize_storage_bytes_" + self.cpp_kernel_name = "torch::inductor::resize_storage_bytes_" + assert isinstance(variable, (BaseView, StorageBox, TensorBox)), type(variable) + V.graph.never_reuse_buffers.add(variable.data.get_name()) + + +class SetSourceTensorKernel(ExternKernelAlloc): + def __init__(self, self_tensor: IRNode, storage_tensor: IRNode) -> None: + storage_tensor.freeze_layout() + super().__init__( + storage_tensor.get_layout(), + [self_tensor, storage_tensor], + python_kernel_name="torch.ops.aten.set_.source_Tensor", + op_overload=torch.ops.aten.set_.source_Tensor, + ) + assert isinstance(self_tensor, (BaseView, StorageBox, TensorBox)), type( + self_tensor + ) + V.graph.never_reuse_buffers.add(self_tensor.data.get_name()) + V.graph.never_reuse_buffers.add(storage_tensor.get_name()) + V.graph.never_reuse_buffers.add(self.get_name()) + device = storage_tensor.get_device() + self.mutation_outputs = [ + MutationOutput(NoneLayout(device=device), self_tensor, self), + MutationOutput(NoneLayout(device=device), storage_tensor, self), + ] + + def get_inputs_that_alias_output(self) -> Sequence[str]: + return [self.input_name(0), self.input_name(1)] + + +class ScatterFallback(ExternKernel): + """ + This needs to be a custom class to handle mutation properly. + This class handles both aten.scatter_ and aten.scatter_reduce_. 
+ It also handle the case `src` being a scalar properly. + """ + + def codegen(self, wrapper: PythonWrapperCodegen) -> None: + wrapper.generate_scatter_fallback(self) + + def should_allocate(self) -> bool: + return False + + def get_mutation_names(self) -> list[str]: + inp = self.inputs[0] + assert isinstance(inp, IRNode) + return [inp.get_name()] + + def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: + return OrderedSet() + + def __init__( + self, + op_overload: _OpOverloads, + x: IRNode, + dim: int, + index: IRNode, + src: IRNode, + *, + reduce: Optional[str] = None, + include_self: bool = True, + ) -> None: + self.src_is_tensor = isinstance(src, TensorBox) + + constant_args: tuple[Any, ...] + if self.src_is_tensor: + tensors = [self.realize_input(t) for t in [x, index, src]] + constant_args = (dim,) + else: + tensors = [self.realize_input(t) for t in [x, index]] + constant_args = (dim, src) + + super().__init__( + None, + NoneLayout(device=x.get_device()), + self.unwrap_storage(tensors), + constant_args, + {"reduce": reduce, "include_self": include_self}, + python_kernel_name=str(op_overload), + ordered_kwargs_for_cpp_kernel=["reduce", "include_self"], + op_overload=op_overload, + ) + V.graph.mark_buffer_mutated(x.get_name()) + self.name = V.graph.register_buffer(self) + V.graph.register_operation(self) + + +class IndexPutFallback(ExternKernel): + """ + This needs to be a custom class to handle mutation and indices properly + """ + + def codegen(self, wrapper: PythonWrapperCodegen) -> None: + wrapper.generate_index_put_fallback(self) + + def should_allocate(self) -> bool: + return False + + def get_mutation_names(self) -> Sequence[str]: + return [self.input_name(0)] + + def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: + return OrderedSet() + + def __init__( + self, + op_overload: torch._ops.OpOverload, + x: IRNode, + indices: list[Any], + values: Sequence[Any], + accumulate: Any, + ) -> None: + self.indices = indices + valid_indices = 
[i for i in indices if i is not None] + # pyrefly: ignore [bad-argument-type] + tensors = [self.realize_input(x) for x in [x, values, *valid_indices]] + cpp_kernel_name = "aoti_torch_index_put_out" + super().__init__( + None, + NoneLayout(device=x.get_device()), + self.unwrap_storage(tensors), + (accumulate,), + python_kernel_name="aten.index_put_", + cpp_kernel_name=cpp_kernel_name, + op_overload=op_overload, + ) + V.graph.mark_buffer_mutated(self.input_name(0)) + self.name = V.graph.register_buffer(self) + V.graph.register_operation(self) + + +class DeviceCopy(ExternKernelOut): + @classmethod + def create(cls, x: IRNode, device: torch.device, non_blocking: bool) -> IRNode: + if ( + not x.is_extern() + # Can not apply this optimization if x has been mutated + and try_get_name(x) not in V.graph.mutated_buffers + and all(r in V.graph.constants for r in x.get_read_names()) + and not config.aot_inductor.use_runtime_constant_folding + ): + return x.constant_to_device(device) + + V.graph.add_device_info(device) + x_device = x.get_device() + assert x_device is not None + V.graph.add_device_info(x_device) + + developer_warning("DeviceCopy in input program") + constant_args = (non_blocking,) + # Device Copy should keep the same layout as input + x = ExternKernel.require_contiguous(x) + stride = None + if x.get_size(): + # x.get_stride() may be unimplemented if x's size is empty + stride = x.get_stride() + is_destination_pinned = ( + is_gpu(x_device.type) and device.type == "cpu" and non_blocking + ) + is_source_pinned = ( + x_device.type == "cpu" and is_gpu(device.type) and non_blocking + ) + if is_source_pinned and is_storage_and_layout(x): + x.get_layout().is_pinned = True + return DeviceCopy( + FixedLayout( + device, + x.get_dtype(), + x.get_size(), + stride, + is_pinned=is_destination_pinned, + ), + [cls.realize_input(x)], + constant_args, + ) + + def codegen(self, wrapper: PythonWrapperCodegen) -> None: + args = self.codegen_args() + assert len(args) == 2 + if 
self.output_view: + wrapper.codegen_device_copy( + args[0], self.output_view.codegen_reference(), args[1] + ) + else: + wrapper.codegen_device_copy(args[0], self.codegen_reference(), args[1]) + + +class DynamicSelectStorageOffset(ExternKernel): + """ + The result of computing a dynamic selection index is determined as follows: when the index in the + select operation is unbacked, the actual index calculation is ambiguous for negative indices + (index + size) versus non-negative indices (just index). To resolve this, we allocate an unbacked + SymInt to represent the storage offset and decompose the select operation into a call to as_strided, + computing the storage offset at runtime with this node. + """ + + def get_reads(self) -> OrderedSet[Dep]: + return OrderedSet() + + def should_allocate(self) -> bool: + return False + + def __init__( + self, + unbacked_offset_symbol: sympy.Symbol, + index: sympy.Symbol, + base_offset: Union[sympy.Symbol, int], + base_dim_stride: Union[sympy.Symbol, int], + size: Union[sympy.Symbol, int], + clamp: bool, + ) -> None: + super().__init__(None, NoneLayout(device=torch.device("cpu")), []) + # This node codegen the following: + # unbacked_offset_symbol = base_offset + base_dim_stride * (index if index >=0 else index + size) + self.unbacked_offset_symbol = unbacked_offset_symbol + self.index = index + self.base_offset = base_offset + self.base_dim_stride = base_dim_stride + self.size = size + self.clamp = clamp + + def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: + return OrderedSet([self.unbacked_offset_symbol]) + + @cache_on_self_and_args("DynamicSelectStorageOffset") + def get_free_symbol_uses( + self, unbacked_only: bool = False + ) -> OrderedSet[sympy.Symbol]: + return get_free_symbols(self.index, unbacked_only) + + def codegen(self, wrapper: PythonWrapperCodegen) -> None: + wrapper.codegen_dynamic_select_index(self, clamp=self.clamp) + + +class DynamicSliceSize(ExternKernel): + """ + Computes the output size of a 
slice call, handling the correct semantics in codegen. + We do this for flexible handling for unbacked indices (to not data-dependent error). + + Slicing has 4 semantics for indices, i.e. x[start:] could be: + 1) start < -x.size(0) -> x[0:] # negative out-of-bounds + 2) start in [-x.size(0), 0) -> x[x.size(0) + start:] # negative slicing + 3) start in [0, x.size(0)) -> x[start:] # standard slicing + 4) start >= x.size(0) -> empty slice # positive out-of-bounds + + If the appropriate semantics are known beforehand, the output size is computed based on + the start & end indices. If not (with unbacked indices), a new unbacked symbol is created + to represent the output size, and codegen handles computing the correct case. + """ + + def get_reads(self) -> OrderedSet[Dep]: + return OrderedSet() + + def should_allocate(self) -> bool: + return False + + def __init__( + self, + unbacked_size_symbol: sympy.Symbol, + start: Union[sympy.Symbol, int], + end: Union[sympy.Symbol, int], + step: Union[sympy.Symbol, int], + size: Union[sympy.Symbol, int], + ): + super().__init__(None, NoneLayout(device=torch.device("cpu")), []) + # This node codegen + self.unbacked_size_symbol = unbacked_size_symbol + self.start = start + self.end = end + self.step = step + self.size = size + + def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: + return OrderedSet([self.unbacked_size_symbol]) + + @cache_on_self_and_args("DynamicSliceSize") + def get_free_symbol_uses( + self, unbacked_only: bool = False + ) -> OrderedSet[sympy.Symbol]: + return get_free_symbols(self.start, unbacked_only).union( + get_free_symbols(self.end, unbacked_only) + ) + + def codegen(self, wrapper: PythonWrapperCodegen) -> None: + wrapper.codegen_dynamic_slice_size(self) + + +class DynamicScalar(ExternKernel): + """ + The result of a call to aten._local_scalar_dense. 
+ """ + + def get_reads(self) -> OrderedSet[Dep]: + return OrderedSet() + + def should_allocate(self) -> bool: + return False + + def __init__( + self, sym: sympy.Symbol, keypath: pytree.KeyPath, data: IRNode + ) -> None: + data.realize() + super().__init__( + None, NoneLayout(device=torch.device("cpu")), self.unwrap_storage([data]) + ) + self.sym = sym + self.keypath = keypath + + def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: + return OrderedSet([self.sym]) + + def codegen(self, wrapper: PythonWrapperCodegen) -> None: + wrapper.codegen_dynamic_scalar(self) + + +class AssertScalar(ExternKernel): + """ + The result of a call to aten._assert_scalar + """ + + def get_reads(self) -> OrderedSet[Dep]: + return OrderedSet() + + def should_allocate(self) -> bool: + return False + + def __init__(self, scalar: SympyBoolean, msg: str) -> None: + super().__init__( + # Buffer(name, layotu) + None, + NoneLayout(device=torch.device("cpu")), + # InputsKernel(inputs) + [], + ) + self.scalar = scalar + self.msg = msg + + def has_side_effects(self) -> bool: + return True + + @cache_on_self_and_args("AssertScalar") + def get_free_symbol_uses( + self, unbacked_only: bool = False + ) -> OrderedSet[sympy.Symbol]: + return get_free_symbols(self.scalar, unbacked_only) + + def codegen(self, wrapper: PythonWrapperCodegen) -> None: + if not config.scalar_asserts: + return + # NB: It is EXTREMELY important not to simplify the scalar under assertion here, + # because simplify is done with respect to runtime asserts. So if you have + # "u0 == 0" in the runtime asserts, if you subsequently try to + # simplify(u0 == 0), you will get True (because we've already runtime assert'ed + # that it's true). But we're code generating the actual runtime assert here!! 
+ symbol = next(iter(self.get_free_symbol_uses(unbacked_only=False))) + if V.graph.fx_wrapper: + # TODO fix + pass + elif V.graph.cpp_wrapper: + symbol_str = f"std::to_string({symbol})" + sizevar = V.graph.wrapper_code.codegen_cpp_sizevar( + self.scalar, simplify=False + ) + # TODO: when we start compiling in C++20, annotate with [[unlikely]]. + wrapper.writeline( + f'if (!({sizevar})) {{ throw std::runtime_error("Expected {self.msg} but received " + {symbol_str}); }}' + ) + else: + sizevar = V.graph.wrapper_code.codegen_python_sizevar( + self.scalar, simplify=False + ) + wrapper.writeline(f"if not ({sizevar}):") + wrapper.writeline(f" raise RuntimeError({repr(self.msg)})") + # No one should ever use this buffer, but for uniformity + # define the variable and assign it None + wrapper.writeline(f"{self.get_name()} = None") + + +@ir_dataclass(frozen=False) +class ExternKernelNode: + name: str + node: export_schema.Node + + +class FallbackKernel(ExternKernelAlloc): + """ + A class that represents a fallback kernel for handling operators that are not + directly support by inductor. It currently supports functional ops, view ops, + inplace aten ops, and mutating ops that are auto-functionalizable. 
+ """ + + def __init__( + self, + layout: OutputSpec, + kernel: _OpOverloads, + tensor_args: Sequence[IRNode], + nontensor_args: Sequence[Any], + unflatten_args: Callable[..., Any], + kwargs: Optional[dict[str, Any]] = None, + *, + unbacked_bindings: Optional[dict[sympy.Symbol, pytree.KeyPath]] = None, + ) -> None: + super().__init__( + layout, + tuple(tensor_args), + tuple(nontensor_args), + op_overload=kernel, + ) + + self.use_runtime_dispatch = False + self.unbacked_bindings = unbacked_bindings or {} + + assert isinstance( + kernel, (torch._ops.OpOverload, torch._ops.HigherOrderOperator) + ), f"Fails to create FallbackKernel for {kernel}: {type(kernel)} not supported" + self.op_overload = kernel + self.unflatten_args = unflatten_args + self.kwargs = {} if kwargs is None else kwargs + assert self.python_kernel_name is not None + V.graph.warn_fallback(self.python_kernel_name) + + # args that are aliased + self.alias_names: list[str] = [] + # args that are mutated AND returned from the op + self.mutation_names: list[str] = [] + + if isinstance(self.op_overload, torch._ops.HigherOrderOperator): + # We assume here that HOPs with FallbackKernel are functional. + # This may not always be true! HOPs must individually opt-in to + # FallbackKernel, so please check this if you opt-in. + return + + if "_c10d_functional" in self.op_overload.name(): + # _c10d_functional kernels are lowered into _CollectiveKernel which + # derives from FallbackKernel for the cpp codegen. The kernels + # don't pass the can_auto_functionalize check, but their mutation + # is handled properly by _CollectiveKernel. + return + + schema = self.op_overload._schema + + # NOTE: [FallbackKernel supported operators] + # We only support three types of operators: + # - functional ops + # - view ops + # - inplace aten ops + # - mutating ops that are auto-functionalizable. That is, + # the operator may mutate any number of inputs, but its outputs + # may not alias any of the inputs. 
+ # + # The unsupported cases usually do not show up here (because + # AOTAutograd functionalized them away); the only way for an in-place + # op to show up here is if a lowering or pass introduced it. + if torch._library.utils.mutates_and_returns_first_arg(self.op_overload): + self.mutation_names.append(tensor_args[0].get_name()) + return + + if schema.is_mutable and not can_auto_functionalize(kernel): + raise NotImplementedError( + f"NYI: Can't generate FallbackKernel for {kernel}" + ) + + args, kwargs = self.unflatten_args(self.inputs, self.constant_args) + + def handle_aliasing_and_mutation(info: torch._C.Argument, arg: Any) -> None: + # Assertions to make sure we didn't mismatch args + if isinstance(info.type, torch.ListType): + assert isinstance(arg, (list, tuple)), type(arg) + if library_utils.is_tensor_like_type(info.type): + # PyTorch also accepts None and scalar types for args marked as "Tensor". + # We're not going to check all of them here. + assert not isinstance(arg, (tuple, list)) + + if arg is None: + return + if info.alias_info is None: + return + + def add_alias(t: IRNode) -> None: + self.alias_names.append(t.get_name()) + assert info.alias_info is not None + if info.alias_info.is_write: + self.mutation_outputs.append( + MutationOutput(NoneLayout(device=t.get_device()), t, self) + ) + + if library_utils.is_tensorlist_like_type(info.type): + if arg is not None: + for optional_tensor_arg in arg: + add_alias(optional_tensor_arg) + else: + assert library_utils.is_tensor_like_type(info.type) + # pyrefly: ignore [bad-argument-type] + add_alias(arg) + + for info, arg in torch._library.utils.zip_schema(schema, args, kwargs): + handle_aliasing_and_mutation(info, arg) + + def get_read_writes(self) -> dependencies.ReadWrites: + read_writes = super().get_read_writes() + + if self.op_overload is torch._prims.rng_prims.graphsafe_run_with_rng_state: + for arg in self.constant_args: + if isinstance(arg, GeneratorState): + read_writes = read_writes.with_read( + 
dependencies.StarDep(arg.get_name()) + ) + + return read_writes + + def codegen_unbacked_symbol_defs(self, wrapper: PythonWrapperCodegen) -> None: + return wrapper.codegen_unbacked_symbol_defs_for_outputs( + self.get_name(), self.outputs, getattr(self, "unbacked_bindings", None) + ) + + def get_unbacked_symbol_defs(self) -> Container[sympy.Symbol]: # type: ignore[override] + if unbacked_bindings := getattr(self, "unbacked_bindings", None): + resolved = resolve_unbacked_bindings( + V.graph.sizevars.shape_env, unbacked_bindings + ) + assert resolved is not None + return resolved.keys() + else: + return OrderedSet() + + def codegen_args(self) -> list[str]: + @dataclasses.dataclass + class Shim: + ref: Any + + def __repr__(self) -> str: + return self.ref + + assert is_node_sequence(self.inputs) + tensor_args = [Shim(x.codegen_reference()) for x in self.inputs] + args, kwargs = self.unflatten_args(tensor_args, self.constant_args) + if V.graph.cpp_wrapper and isinstance(self.op_overload, torch._ops.OpOverload): + args = self.fill_non_provided_args(args, kwargs) + args = [ + V.graph.wrapper_code.val_to_arg_str(x, param.real_type) + for param, x in zip(self.op_overload._schema.arguments, args) + ] + else: + args = [V.graph.wrapper_code.val_to_arg_str(x) for x in args] + + # let self.codegen_kwargs handle kwargs + self.kwargs.update(kwargs) + return args + + @staticmethod + def find_device( + tensor_args: Optional[Sequence[torch.Tensor]], example_output: Sequence[Any] + ) -> Any: + non_torch_bind_tensor_args = ( + [t for t in tensor_args if not isinstance(t, TorchBindObject)] + if tensor_args + else None + ) + if non_torch_bind_tensor_args: + assert tensor_args + devices = [arg.get_device() for arg in tensor_args if arg.get_device()] + return devices[0] + if isinstance(example_output, torch.Tensor): + return example_output.device + if isinstance(example_output, (list, tuple)): + device_set = OrderedSet( + FallbackKernel.find_device(None, x) for x in example_output + ) + # 
Remove None + devices = [device for device in device_set if device] + if len(devices) == 1: + return devices[0] + for device in devices: + assert isinstance(device, torch.device) + if is_gpu(device.type): + return device + return devices[0] + return None + + def has_side_effects(self) -> bool: + from torch._library.utils import is_impure + + # Note: We don't pass args/kwargs here because they're IRNodes, not actual values + # The check is done on the op_overload itself + return is_impure(self.op_overload) # pyrefly: ignore[bad-argument-type] + + def get_inputs_that_alias_output(self) -> Sequence[str]: + assert isinstance( + self.op_overload, (torch._ops.OpOverload, torch._ops.HigherOrderOperator) + ), ( + f"Fails to create FallbackKernel for {self.op_overload}: " + f"{type(self.op_overload)} not supported" + ) + + # See [Note: FallbackKernel supported operators]: for a mutating + # op that is auto-functionalizable, its outputs does NOT + # alias any of the inputs. + if ( + not isinstance(self.op_overload, torch._ops.HigherOrderOperator) + and "_c10d_functional" not in self.op_overload.name() + and self.op_overload._schema.is_mutable + and can_auto_functionalize(self.op_overload) + ): + return [] + else: + return self.alias_names + + def get_mutation_names(self) -> Sequence[str]: + assert len(self.mutation_names) <= 1 + return self.mutation_names + + def export_extern_kernel_node(self): # type: ignore[no-untyped-def] + """ + ProxyExecutor Design Note + We export the ExternFallbackNodes (for custom ops) into a serialized file + and run it with a host side proxy executor to address the ABI problem + This is currently only implemented for fbcode. Eventually, we will also make this work for OSS. 
+ Detailed design doc can be found at + https://docs.google.com/document/d/1wC4DOZFaYym2t1Esz0X5yxlLI3RDnSiyRbUus3bkJ64/edit?usp=sharing + """ + log.debug( + "Extern kernel node added for node %s with target %s.", + self.get_name(), + self.op_overload, + ) + + assert isinstance(self, FallbackKernel), type(self) + args, kwargs = self.unflatten_args(self.inputs, self.constant_args) + args = self.fill_non_provided_args(args, kwargs) + ordered_kwargs = [ + self.get_kwargs_value(key, **kwargs) + for key in self.ordered_kwargs_for_cpp_kernel + ] + target = self.op_overload + + if not V.graph.aot_mode: + # No need to serialize in the cpp wrapper JIT mode + return [*args, *ordered_kwargs] + + serializer = GraphModuleSerializer(None, []) # type: ignore[arg-type] + named_arguments = serializer.serialize_inputs(target, args, kwargs) + + # serialize_outputs + def handle_single_output( + return_type: Union[torch.TensorType, torch.ListType, torch.JitType], + output: Union[IRNode, Sequence[IRNode]], + ) -> export_schema.Argument: + if isinstance(return_type, (torch.TensorType, torch.NoneType)): + # For single Tensor or None + out = output + if isinstance(output, (list, tuple)): + assert len(output) == 1 + out = output[0] + if isinstance(return_type, torch.TensorType): + assert isinstance(out, IRNode) + return export_schema.Argument.create( + as_tensor=export_schema.TensorArgument(name=out.get_name()) + ) + else: # NoneType + assert out is None + return export_schema.Argument.create(as_none=True) + elif isinstance(return_type, torch.ListType) and isinstance( + return_type.getElementType(), torch.TensorType + ): + assert isinstance(output, Sequence), type(output) + # For single TensorList + return export_schema.Argument.create( + as_tensors=[ + export_schema.TensorArgument(name=out.get_name()) + for out in output + ] + ) + elif isinstance(return_type, torch.OptionalType) and isinstance( + return_type.getElementType(), torch.TensorType + ): + # For OptionalTensor + if output is 
None: + return export_schema.Argument.create( + as_optional_tensor=export_schema.OptionalTensorArgument.create( + as_none=True + ) + ) + else: + assert isinstance(output, IRNode) + return export_schema.Argument.create( + as_optional_tensor=export_schema.OptionalTensorArgument.create( + as_tensor=export_schema.TensorArgument( + name=output.get_name() + ) + ) + ) + elif isinstance(return_type, torch.IntType): + return export_schema.Argument.create(as_int=output) + else: + raise RuntimeError(f"Unsupported return type {type(return_type)}") + + if isinstance(target, torch._higher_order_ops.torchbind.CallTorchBind): + returns = target.schema(args[0], args[1]).returns + else: + returns = target._schema.returns # type: ignore[union-attr] + if len(returns) == 1: + # NOTE: [special handling of all_reduce_coalesced_'s return value] + # all_reduce_coalesced_ return a list of tensors via self.mutation_outputs + outputs = self.outputs if self.outputs else self.mutation_outputs + return_type = returns[0].real_type + output_arguments = [handle_single_output(return_type, outputs)] + else: + # For tuple returns, e.g "-> (Tensor, Tensor)" or "-> (Tesnor, Tensor[])" + # Not generating output args for self.mutation_outputs + output_arguments = [ + handle_single_output( + return_schema.real_type, # type: ignore[attr-defined] + output, + ) + for return_schema, output in zip(returns, self.outputs) + ] + + assert self.op_overload is not None + node = ExternKernelNode( + name=self.get_name(), + node=export_schema.Node( + target=self.op_overload.name(), + inputs=named_arguments, + outputs=output_arguments, + metadata={}, + ), + ) + + V.extern_kernel_nodes.append(node) + + return [*args, *ordered_kwargs] + + @override + def codegen(self, wrapper: PythonWrapperCodegen) -> None: + """Overrides the parent member. 
+ See https://github.com/pytorch/pytorch/issues/151692""" + kernel = self.op_overload + assert kernel is not None + if kernel.namespace == "aten": + # Aten Fallback Ops + assert isinstance(kernel, torch._ops.OpOverload), type(kernel) + if V.graph.cpp_wrapper: + from torchgen.aoti.fallback_ops import inductor_fallback_ops + + if str(kernel) not in inductor_fallback_ops: + # C shim v2 is torchgen-ed, which should cover all aten ops. + # If you do hit a missed op, please update fallback_ops.py. + log.warning( + "%s is missing a c-shim implementation, using proxy executor as fallback", + kernel, + ) + self.use_runtime_dispatch = True + elif kernel.namespace == "_quantized": + # Internal Quantized Fallback Ops + assert isinstance(kernel, torch._ops.OpOverload), type(kernel) + elif V.graph.cpp_wrapper: + # For non-aten OpOverload, i.e. custom ops + # If the op is in custom_ops_to_c_shims, generate direct function call + self.use_runtime_dispatch = ( + kernel not in config.aot_inductor.custom_ops_to_c_shims + ) + + # Handle the special case where a complex number is input to a C-shim kernel for + # a scalar input. The torchgen'ed shim API will use type "double", which is + # incompatible with complex numbers, forcing a fallback to runtime dispatch. + if ( + V.graph.cpp_wrapper + and isinstance(kernel, torch._ops.OpOverload) + and not self.use_runtime_dispatch + ): + + def is_number(t: torch.JitType) -> bool: + if isinstance(t, torch.OptionalType): + return is_number(t.getElementType()) + return isinstance(t, torch.NumberType) + + # Using unflatten_args is a bit of a hack, but all the complex arguments we + # care about are in self.constant_args, and calling unflatten_args puts them + # in the correct order without triggering codegen. + args, kwargs = self.unflatten_args(self.inputs, self.constant_args) + # Append kwarg values to args. ordered_kwargs_for_cpp_kernel is guaranteed + # to be set, since this is an OpOverload kernel. 
+ args_iter = itertools.chain( + args, + ( + self.get_kwargs_value(k, **kwargs) + for k in self.ordered_kwargs_for_cpp_kernel + ), + ) + self.use_runtime_dispatch = any( + isinstance(v, complex) and is_number(a.real_type) + for v, a in zip(args_iter, kernel._schema.arguments) + ) + + self.codegen_comment(wrapper) + if self.use_runtime_dispatch: + exported_args = self.export_extern_kernel_node() + assert self.python_kernel_name is not None + assert self.op_overload is not None + + wrapper.generate_fallback_kernel_with_runtime_lookup( + self.get_name(), + self.python_kernel_name, + lambda: [*self.codegen_args(), *self.codegen_kwargs()], + self.op_overload, + exported_args, + # NOTE: [special handling of all_reduce_coalesced_'s return value] + self.outputs if self.outputs else self.mutation_outputs, + ) + else: + wrapper.generate_fallback_kernel(self) + if isinstance(self.layout, Layout): + self.codegen_size_asserts(wrapper) + self.codegen_alignment_asserts(wrapper) + self.codegen_memory_tracking(wrapper) + + self.codegen_unbacked_symbol_defs(wrapper) + + @staticmethod + def tensor_to_layout(output: torch.Tensor) -> FixedLayout: + is_pinned = False + try: + is_pinned = output.is_pinned() + except RuntimeError: + # dispatch not implemented + pass + return FixedLayout( + output.device, + output.dtype, + convert_shape_to_inductor(output.size()), + convert_shape_to_inductor(output.stride()), + is_pinned=is_pinned, + ) + + @classmethod + def create(cls, kernel: _OpOverloads, *args: Any, **kwargs: Any) -> FallbackKernel: + """Create an instance of FallbackKernel from an _OpOverloads""" + fake_incorrect_kernels = (aten._fused_moving_avg_obs_fq_helper_functional,) + if kernel not in fake_incorrect_kernels: + context = cast(AbstractContextManager[None], V.graph.fake_mode) + else: + context = nullcontext() + + with context: + ( + example_output, + tensor_args, + non_tensor_args, + unflatten_args, + unbacked_bindings, + ) = cls.process_kernel(kernel, *args, **kwargs) + + # We 
need this extra check for input alignment since the example + # inputs we created are always aligned. + has_unaligned_input = any(is_unaligned(arg) for arg in tensor_args) + + device = cls.find_device(tensor_args, example_output) + + if not device and isinstance( + kernel, torch._higher_order_ops.torchbind.CallTorchBind + ): + # use CPU device for torchbind methods that don't take in or output any tensor, e.g. size() + device = torch.device("cpu") + + if example_output is None: + packed = cls( + NoneLayout(device=device), + kernel, + tensor_args, + non_tensor_args, + unflatten_args, + unbacked_bindings=unbacked_bindings, + ) + + else: + assert device, "Not sure where to find device info" + packed = cls( + MultiOutputLayout(device=device), + kernel, + tensor_args, + non_tensor_args, + unflatten_args, + unbacked_bindings=unbacked_bindings, + ) + + def generate_output(output: Any, indices: list[tuple[Any, int]]) -> Any: + if isinstance(output, (list, tuple)): + return type(output)( + generate_output(output[i], indices + [(type(output), i)]) + for i in range(len(output)) + ) + elif isinstance(output, dict): + return { + key: generate_output(val, indices + [(type(output), key)]) + for key, val in output.items() + } + elif isinstance(output, torch.Tensor): + buf = MultiOutput( + cls.tensor_to_layout(output), + packed, + indices, + ) + if ( + config.assume_unaligned_fallback_output + or has_unaligned_input + or not tensor_is_aligned(output) + ): + V.graph.unaligned_buffers.add(buf.name) # type: ignore[arg-type] + return buf + elif isinstance(output, int): + return output + elif isinstance(output, torch.SymInt): + return output.node.expr + else: + assert output is None, ( + f"FallbackKernel output type {type(output)} is not supported" + ) + return None + + outputs = generate_output(example_output, []) + if isinstance(outputs, (list, tuple)): + packed.outputs = outputs + elif isinstance(outputs, dict): + packed.outputs = tuple(outputs) + else: + packed.outputs = [outputs] + 
# pyrefly: ignore [bad-return] + return outputs + + +@ir_dataclass(frozen=False) +class ComplexView(FallbackKernel): + """View a complex number as two dtyped numbers or vice versa""" + + def should_allocate(self) -> bool: + return False + + def get_inputs_that_alias_output(self) -> Sequence[str]: + # Signal to codegen that our output buffer isn't safe to reuse + return [self.input_name(0)] + + def __init__( + self, + layout: OutputSpec, + kernel: _OpOverloads, + tensor_args: Sequence[IRNode], + nontensor_args: Sequence[Any], + unflatten_args: Callable[..., Any], + *, + unbacked_bindings: Optional[dict[sympy.Symbol, pytree.KeyPath]] = None, + ) -> None: + super().__init__( + layout, + kernel, + tensor_args, + nontensor_args, + unflatten_args, + unbacked_bindings=unbacked_bindings, + ) + + +class MemoryCheckKernel(FallbackKernel): + """ + Custom kernel for memory checking that generates direct function calls + + TODO - the custom op was erroring with str inputs. should be able to custom op directly. 
+ """ + + def codegen(self, wrapper: PythonWrapperCodegen) -> None: + """Override codegen to write direct function call""" + # Extract our arguments from nontensor_args + wrapper.write_memory_track_allocation_once() + alive_list, dead_list, is_final_step = self.constant_args + + alive_repr = repr(alive_list) + dead_repr = repr(dead_list) + if is_final_step: + wrapper.writeline( + "# note: dont currently distinguish between buffers returned and dealloc'd in last step" + ) + call = f"check_memory_step(allocated={alive_repr}, freed={dead_repr}, is_final_step={is_final_step})" + else: + call = f"check_memory_step(allocated={alive_repr}, freed={dead_repr})" + wrapper.writeline(call) + + +@ir_dataclass +class MultiOutputLayout(OutputSpec): + device: torch.device + + def get_device(self) -> Optional[torch.device]: + return self.device + + +class MultiOutput(ExternKernel): + def codegen(self, wrapper: PythonWrapperCodegen) -> None: + wrapper.codegen_multi_output(self) + if not self.skip_size_stride_alignment_checks: + self.codegen_size_asserts(wrapper) + self.codegen_alignment_asserts(wrapper) + + def __init__( + self, + layout: OutputSpec, + input: IRNode, + indices: list[tuple[Any, ...]], + skip_size_stride_alignment_checks: bool = False, + ) -> None: + super().__init__(None, layout, [input], ()) + self.name = V.graph.register_buffer(self) + V.graph.register_operation(self) + self.indices = indices + self.skip_size_stride_alignment_checks = skip_size_stride_alignment_checks + + @cache_on_self_and_args("MultiOutput") + def get_free_symbol_uses( + self, unbacked_only: bool = False + ) -> OrderedSet[sympy.Symbol]: + input_node = self.inputs[0] + assert isinstance(input_node, IRNode), input_node + return input_node.get_free_symbol_uses(unbacked_only) + + def should_allocate(self) -> bool: + return len(self.inputs) == 1 and ( + isinstance(self.inputs[0], CppTemplateBuffer) # Grouped GEMM + ) + + def get_inputs_that_alias_output(self) -> Sequence[str]: + return [ + 
inp.get_name() + for inp in self.inputs + if isinstance(inp, FallbackKernel) + and len(inp.get_inputs_that_alias_output()) > 0 + ] + + +# We just use a normal dataclass for MutableBox/TensorBox/StorageBox since +# they're mainly lowering-time constructs that we expect to mutate and such. +@dataclasses.dataclass +class MutableBox(IRNode): + """ + TensorBox / StorageBox allow in-place mutation of Tensors + """ + + data: IRNode + + def has_exceeded_max_reads(self) -> bool: + return self.data.has_exceeded_max_reads() + + def get_device(self) -> Optional[torch.device]: + return self.data.get_device() + + def make_loader(self) -> Callable[[Sequence[Expr]], OpsValue]: + return self.data.make_loader() + + def make_indexer(self) -> Callable[[Sequence[Expr]], Expr]: + return self.data.make_indexer() + + def get_stride(self) -> Sequence[_IntLike]: + return self.data.get_stride() + + def get_name(self) -> str: + return self.data.get_name() + + def has_large_inner_fn(self, threshold: Optional[int] = None) -> bool: + return self.data.has_large_inner_fn(threshold) + + def mark_reuse(self, users: int) -> None: + return self.data.mark_reuse(users) + + def realize_hint(self) -> None: + return self.data.realize_hint() + + def unwrap_view(self) -> IRNode: + return self.data.unwrap_view() + + def is_input_buffer(self) -> bool: + return self.data.is_input_buffer() + + def freeze_layout(self) -> None: + return self.data.freeze_layout() + + def freeze_layout_with_stride_order( + self, order: Sequence[int], allow_padding: bool = False + ) -> None: + return self.data.freeze_layout_with_stride_order(order, allow_padding) + + def freeze_layout_with_fill_order(self, order: Sequence[int]) -> None: + return self.data.freeze_layout_with_fill_order(order) + + def freeze_layout_with_same_order(self, stride: Sequence[_IntLike]) -> None: + return self.data.freeze_layout_with_same_order(stride) + + def freeze_layout_with_exact_strides( + self, exact_strides: Sequence[_IntLike], allow_padding: bool = 
False + ) -> None: + return self.data.freeze_layout_with_exact_strides(exact_strides, allow_padding) + + def get_read_writes(self) -> dependencies.ReadWrites: + return self.data.get_read_writes() + + def get_reads(self) -> OrderedSet[Dep]: + return self.data.get_reads() + + def num_reads(self) -> int: + return self.data.num_reads() + + def get_storage_numel(self) -> _IntLike: + return self.data.get_storage_numel() + + def get_reduction_type(self) -> Optional[str]: + return self.data.get_reduction_type() + + def get_reduction_size(self) -> Sequence[Expr]: + return self.data.get_reduction_size() + + def is_extern(self) -> bool: + return self.data.is_extern() + + def is_no_op(self) -> bool: + return self.data.is_no_op() + + def constant_to_device(self, device: torch.device) -> IRNode: + return self.data.constant_to_device(device) + + def get_mutation_names(self) -> Sequence[str]: + return self.data.get_mutation_names() + + def get_operation_name(self) -> str: + return self.data.get_operation_name() + + def get_inputs_that_alias_output(self) -> Sequence[str]: + return self.data.get_inputs_that_alias_output() + + def realize(self) -> Optional[str]: + return self.data.realize() + + @cache_on_self_and_args("MutableBox") + def get_free_symbol_uses( + self, unbacked_only: bool = False + ) -> OrderedSet[sympy.Symbol]: + return self.data.get_free_symbol_uses(unbacked_only) + + def get_read_names(self) -> OrderedSet[str]: + return self.data.get_read_names() + + def get_defining_op(self) -> Optional[Operation]: + return self.data.get_defining_op() + + def codegen_reference(self, writer: Optional[IndentedBuffer] = None) -> str: + return self.data.codegen_reference(writer) + + @property + def layout(self) -> OutputSpec: + # we intentionally call get_output_spec (rather than get_layout) since Buffer.layout is an OutputSpec + return self.data.get_output_spec() + + def get_layout(self) -> Layout: + return self.data.get_layout() + + def get_output_spec(self) -> OutputSpec: + return 
self.data.get_output_spec() + + def get_size(self) -> Sequence[Expr]: + return self.data.get_size() + + @property + def dtype(self) -> torch.dtype: + return self.data.dtype + + def __str__(self) -> str: + if isinstance(self.data, MutableBox): + line0 = f"{type(self).__name__}({type(self.data).__name__}(" + endl = "))" + inner = self.data.data + else: + line0 = f"{type(self).__name__}(" + inner = self.data + endl = ")" + + lines = [ + line0, + indent(str(inner)), + endl, + ] + return "\n".join(lines) + + __repr__ = __str__ + + +class TensorBox(MutableBox): + @staticmethod + def create(data: IRNode) -> Union[TensorBox, ShapeAsConstantBuffer]: + if isinstance(data, ShapeAsConstantBuffer): + return data + return TensorBox(StorageBox(data)) + + +class StorageBox(MutableBox): + """ + StorageBox allow in-place mutation of Tensors + """ + + def is_input_buffer(self) -> bool: + if isinstance(self.data, (InputBuffer, ReinterpretView)): + return self.data.get_name() in V.graph.graph_inputs + return False + + def is_module_buffer(self) -> bool: + return ( + isinstance(self.data, (ConstantBuffer)) + and self.data.get_name() in V.graph.constants + ) + + def realize(self) -> Optional[str]: + if IRNode.is_realized_node(self.data): + return self.data.get_name() + + assert isinstance(self.data, (Pointwise, Reduction, Scan, Sort)), type( + self.data + ) + origin_node = self.data.get_origin_node() + traceback = self.data.get_traceback() + device = self.data.get_device() + assert device is not None + + self.data = ComputedBuffer( + name=None, + layout=FlexibleLayout( + device=device, + dtype=self.data.get_dtype(), + size=self.data.get_size(), + is_pinned=False, + ), + data=self.data, + ) + self.data.name = V.graph.register_buffer(self.data) + V.graph.register_operation(self.data) + self.data.origins = self.origins + self.data.origin_node = origin_node + self.data.traceback = traceback + return self.data.name + + def realize_hint(self) -> None: + """ + Called on buffers we expect to be 
forced to realize later. + """ + if ( + isinstance(self.data, (Pointwise, Reduction)) + and self.data.inner_fn_opcount().nontrivial_read_count > 1 + ): + self.realize() + + def has_accumulated_enough_reads_by_size(self, threshold: int) -> bool: + from torch._inductor.utils import is_nonfreeable_buffers + + size_of_reads = [ + V.graph.get_dep_size_hint(dep) + for dep in self.get_reads() + if not is_nonfreeable_buffers(dep) + ] + if not size_of_reads: + return False + total_size = sum(size_of_reads) + max_size = max(size_of_reads) + min_size = min(size_of_reads) + return ( + total_size >= threshold + and total_size / max_size >= 2 + and max_size == min_size + ) + + def has_exceeded_max_reads(self) -> bool: + return isinstance(self.data, Pointwise) and ( + self.num_reads() > config.realize_acc_reads_threshold + or self.has_large_inner_fn() + or ( + config.realize_acc_reads_size_threshold is not None + and self.has_accumulated_enough_reads_by_size( + config.realize_acc_reads_size_threshold + ) + ) + ) + + def should_realize_on_reuse(self, users: int) -> bool: + """ + A heuristic to decide if we should realize a tensor + that is used multiple times. 
+ """ + if users > 1 and isinstance(self.data, (Pointwise, Reduction)): + if is_cpu(self.data): + # Heuristic for realizing reused result of heavy ops on cpu + opcount = self.data.inner_fn_opcount() + heavy_ops = ["exp", "sigmoid"] # a list of heavy ops + if any(x in opcount.used_ops for x in heavy_ops): + return True + return ( + self.num_reads() > config.realize_reads_threshold + or self.has_large_inner_fn() + ) + return False + + def mark_reuse(self, users: int) -> None: + if self.should_realize_on_reuse(users): + self.realize() + + def num_reads(self) -> int: + return self.data.num_reads() + + +@ir_dataclass(frozen=False) +class Subgraph(IRNode): + name: str + graph_module: torch.fx.GraphModule + graph: Optional[GraphLowering] = None + + +def _has_aliased_buffers(buffers: Sequence[IRNode]) -> bool: + buffers = [ + buffer.unwrap_view() if isinstance(buffer, ReinterpretView) else buffer + for buffer in buffers + ] + # assuming the same buffer is represented by the same IRNode object + return len(OrderedSet(id(buffer) for buffer in buffers)) < len(buffers) + + +@ir_dataclass(frozen=False) +class InvokeSubgraph(ExternKernel): + """ + Ir node for the invoke_subgraph HOP. 
+ """ + + subgraph: Optional[Subgraph] = None + operands: Optional[Sequence[IRNode]] = None + outputs: Optional[Sequence[IRNode]] = None + + def __init__( + self, subgraph: Subgraph, operands: Sequence[IRNode], layout: MultiOutputLayout + ) -> None: + super().__init__( + name=None, + layout=layout, + inputs=operands, + ) + self.subgraph = subgraph + self.name = V.graph.register_buffer(self) + V.graph.register_operation(self) + + @classmethod + def create( + cls, subgraph: Subgraph, *operands: IRNode + ) -> list[Union[ShapeAsConstantBuffer, NoneAsConstantBuffer, MultiOutput]]: + """For each operand, get a realized input, force it to have the same + strides as the subgraph inputs, then use an InvokeSubgraph""" + from .lowering import constrain_to_fake_tensor + + # TODO(anijain2305) - Support sym expr as operands in future. + current_node = V.graph.current_node + + fake_operands = None + if eager_input_vals := current_node.meta.get("eager_input_vals"): + # eager_input_vals is (args_values, kwargs_values). We need args for invoke_subgraph + offset = 2 + if current_node.target is torch.ops.higher_order.with_effects: + # Aruguments eagerly are (token, subgraph, identifier, *operands) + assert current_node.args[1] is torch.ops.higher_order.invoke_subgraph + offset = 3 + fake_operands = eager_input_vals[0][offset:] + else: + offset = 2 + if current_node.target is torch.ops.higher_order.with_effects: + # with_effects args: (token, invoke_subgraph, subgraph, identifier, *operands) + assert current_node.args[1] is torch.ops.higher_order.invoke_subgraph + offset = 4 + + # For the partitioned backward graph, we do not have + # eager_input_vals. Here, we rely on the recorded example values. + fx_operands = current_node.args[offset:] + fake_operands = [x.meta["val"] for x in fx_operands] # type: ignore[union-attr] + + # Realize the inputs. Also intermediates can have different strides than + # the inputs of the subgraph. 
So, force the intermediates to have same + # strides as that of subgraph inputs. + # pyrefly: ignore [annotation-mismatch] + operands: list[IRNode] = [cls.realize_input(x) for x in operands] + new_operands: list[IRNode] = [] + + for idx, operand in enumerate(operands): + if isinstance(operand, (ShapeAsConstantBuffer, GeneratorState)): + new_operands.append(operand) + else: + new_operands.append( + constrain_to_fake_tensor(operand, fake_operands[idx]) + ) + + # pyrefly: ignore [bad-assignment] + operands = new_operands + + if subgraph.graph is None: + # create and lower subgraphs + subgraph.graph = V.graph.make_subgraph( + gm=subgraph.graph_module, + example_inputs=fake_operands, + subgraph_name=subgraph.name, + ) + with V.set_graph_handler(subgraph.graph): + subgraph.graph.run(*fake_operands) + + outputs = subgraph.graph.graph_outputs + + # Find the device - operands could be integers from shapes, so we can't + # use operands[0] + device = None + for operand in operands: + if not isinstance(operand, ShapeAsConstantBuffer): + device = operand.get_device() + break + assert device is not None + invoke_subgraph = InvokeSubgraph( + subgraph=subgraph, + operands=operands, + layout=MultiOutputLayout(device=device), + ) + + def create_output( + output: IRNode, ind: int + ) -> Union[ShapeAsConstantBuffer, NoneAsConstantBuffer, MultiOutput]: + if isinstance(output, (ShapeAsConstantBuffer, NoneAsConstantBuffer)): + return output + else: + device = output.get_device() + assert device is not None + + return MultiOutput( + FixedLayout( + device=device, + dtype=output.get_dtype(), + size=output.get_size(), + stride=output.get_stride(), + offset=output.get_layout().offset, + is_pinned=output.get_layout().is_pinned, + ), + invoke_subgraph, # type: ignore[has-type] + [(list, ind)], + skip_size_stride_alignment_checks=True, + ) + + outs = [create_output(output, i) for i, output in enumerate(outputs)] + invoke_subgraph.outputs = outs # type: ignore[assignment] + return outs + + def 
codegen(self, wrapper: PythonWrapperCodegen) -> None: + wrapper.codegen_invoke_subgraph(self) + + +@ir_dataclass(frozen=False) +class Conditional(ExternKernel): + predicate: Optional[IRNode] = None + operands: Optional[Sequence[IRNode]] = None + true_subgraph: Optional[Subgraph] = None + false_subgraph: Optional[Subgraph] = None + outputs: Optional[Sequence[MultiOutput]] = None + + def __init__( + self, + predicate: IRNode, + operands: Sequence[IRNode], + true_subgraph: Subgraph, + false_subgraph: Subgraph, + layout: MultiOutputLayout, + unbacked_bindings: Optional[dict[sympy.Symbol, pytree.KeyPath]], + ) -> None: + self.predicate = predicate + self.operands = operands + self.true_subgraph = true_subgraph + self.false_subgraph = false_subgraph + + sym_args, tensor_args = _split_by_sym_type([predicate, *operands]) + + super().__init__( + name=None, + layout=layout, + inputs=tensor_args, + constant_args=sym_args, + ) + if unbacked_bindings is not None: + self.unbacked_bindings = unbacked_bindings + + self.name = V.graph.register_buffer(self) + V.graph.register_operation(self) + + @staticmethod + def _maybe_expr(s: Union[int, torch.SymInt]) -> Union[int, sympy.Expr]: + if isinstance(s, int): + return s + return s.node.expr + + @classmethod + def create( + cls, + predicate: TensorBox, + true_fn: Subgraph, + false_fn: Subgraph, + operands: list[Union[TensorBox, ShapeAsConstantBuffer]], + ) -> Sequence[IRNode]: + """Create a Sequence of IRNodes from a conditional statement (see .lowering.cond)""" + # pyrefly: ignore [bad-assignment] + predicate = cls.realize_input(predicate) + # pyrefly: ignore [bad-assignment] + operands = [cls.realize_input(x) for x in operands] + fx_operands: Argument = V.graph.current_node.args[-1] + + assert isinstance(fx_operands, Sequence), type(fx_operands) + assert all(isinstance(n, Node) for n in fx_operands) + fake_operands = [cast(Node, x).meta["val"] for x in fx_operands] + + for subgraph in (true_fn, false_fn): + if subgraph.graph is None: 
+ # create and lower subgraphs + subgraph.graph = V.graph.make_subgraph( + gm=subgraph.graph_module, + example_inputs=fake_operands, + subgraph_name=subgraph.name, + ) + with V.set_graph_handler(subgraph.graph): + subgraph.graph.run(*fake_operands) + + assert true_fn.graph is not None + assert false_fn.graph is not None + true_outputs = true_fn.graph.graph_outputs + false_outputs = false_fn.graph.graph_outputs + + for name, outputs in (("true_fn", true_outputs), ("false_fn", false_outputs)): + if _has_aliased_buffers(true_outputs): + raise AssertionError( + "Output aliasing is currently not supported in compiled torch.cond. " + f"The outputs of the {name} subgraph of torch.cond are aliased: {outputs}" + ) + + # make sure true and false outputs are structurally equivalent + assert len(true_outputs) == len(false_outputs), (true_outputs, false_outputs) + for i, (t_o, f_o) in enumerate(zip(true_outputs, false_outputs)): + assert t_o.get_device() == f_o.get_device(), (i, t_o, f_o) + assert t_o.get_dtype() == f_o.get_dtype(), (i, t_o, f_o) + assert t_o.get_layout().offset == f_o.get_layout().offset, (i, t_o, f_o) + + # Determine device from operands and predicate + # The predicate can be on a different device (e.g., CPU for control flow) + # while the data operands and outputs should be on the compute device, so + # using predicate device as a fallback. 
+ device = next( + o.get_device() + for o in operands + [predicate] + if not isinstance(o, ShapeAsConstantBuffer) + ) + unbacked_bindings = resolve_unbacked_bindings( + V.graph.sizevars.shape_env, + V.graph.current_node.meta.get("unbacked_bindings", None), + ) + assert device is not None, "cannot determine device" + conditional = Conditional( + predicate=predicate, + operands=operands, + true_subgraph=true_fn, + false_subgraph=false_fn, + layout=MultiOutputLayout(device=device), + unbacked_bindings=unbacked_bindings, + ) + + outputs = [ + MultiOutput( + FixedLayout( + device=output.get_device() + if output.get_device() is not None + else device, # type: ignore[arg-type] + dtype=output.get_dtype(), + size=[Conditional._maybe_expr(sz) for sz in merged_output.size()], + stride=[ + Conditional._maybe_expr(sz) for sz in merged_output.stride() + ], + offset=output.get_layout().offset, + is_pinned=output.get_layout().is_pinned, + ), + conditional, + [(list, i)], + ) + # as the true and false outputs are equivalent, + # we can use either of them here as a "template" + for i, (output, merged_output) in enumerate( + zip(true_outputs, V.graph.current_node.meta["val"]) + ) + ] + + conditional.outputs = outputs # type: ignore[assignment] + return outputs + + def codegen(self, wrapper: PythonWrapperCodegen) -> None: + wrapper.codegen_conditional(self) + wrapper.codegen_unbacked_symbol_defs_for_outputs( + self.get_name(), self.outputs, getattr(self, "unbacked_bindings", {}) + ) + + def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: + if unbacked_bindings := getattr(self, "unbacked_bindings", None): + resolved = resolve_unbacked_bindings( + V.graph.sizevars.shape_env, unbacked_bindings + ) + assert resolved is not None + return OrderedSet(resolved.keys()) + else: + return OrderedSet() + + +def _split_by_sym_type( + args: list[Any], +) -> tuple[list[ShapeAsConstantBuffer], list[Any]]: + non_sym_args = [] + sym_args = [] + for arg in args: + if isinstance(arg, 
ShapeAsConstantBuffer): + sym_args.append(arg.expr) + else: + non_sym_args.append(arg) + + return sym_args, non_sym_args + + +@ir_dataclass(frozen=False) +class WhileLoop(ExternKernel): + """The IR node for while_loop and while_loop_stack_output. It supports input mutation.""" + + carried_inputs: Optional[Sequence[IRNode]] = None + additional_inputs: Optional[Sequence[IRNode]] = None + cond_subgraph: Optional[Subgraph] = None + body_subgraph: Optional[Subgraph] = None + outputs: Optional[Sequence[MultiOutput]] = None + + def __init__( + self, + carried_inputs: Sequence[IRNode], + additional_inputs: Sequence[IRNode], + cond_subgraph: Subgraph, + body_subgraph: Subgraph, + layout: MultiOutputLayout, + unbacked_bindings: Optional[dict[sympy.Symbol, pytree.KeyPath]], + stack_output: bool, + ) -> None: + self.carried_inputs = carried_inputs + self.additional_inputs = additional_inputs + self.cond_subgraph = cond_subgraph + self.body_subgraph = body_subgraph + + sym_args, tensor_args = _split_by_sym_type( + [*carried_inputs, *additional_inputs] + ) + super().__init__( + name=None, + layout=layout, + inputs=tensor_args, + constant_args=sym_args, + ) + if unbacked_bindings is not None: + self.unbacked_bindings = unbacked_bindings + self.stack_output = stack_output + + self.name = V.graph.register_buffer(self) + V.graph.register_operation(self) + + # Accidental aliasing can be created due to cse, where the empty buffers we + # allocated for backward to use gets csed into the same buffer in function fx_graph_cse. + # See test_scan_multiple_layers_gradient for a concrete example. 
+ @staticmethod + def _clone_aliased_inputs(carried_inputs: Sequence[IRNode]) -> Sequence[IRNode]: + if not _has_aliased_buffers(carried_inputs): + return carried_inputs + + # Import clone from lowering module + + # Unwrap views to get the underlying buffers for comparison + unwrapped_buffers = [ + buffer.unwrap_view() if isinstance(buffer, ReinterpretView) else buffer + for buffer in carried_inputs + ] + + # Track which buffers we've seen and their indices + seen_buffers: OrderedSet[int] = OrderedSet() + result: list[Union[IRNode, TensorBox, ShapeAsConstantBuffer]] = [] + + for original_input, unwrapped_buffer in zip(carried_inputs, unwrapped_buffers): + if id(unwrapped_buffer) in seen_buffers: + result.append(ExternKernel.copy_input(original_input)) + else: + seen_buffers.add(id(unwrapped_buffer)) + result.append(original_input) + + return result + + @staticmethod + def _maybe_wrap_as_tensor_box(out: IRNode) -> IRNode: + if isinstance(out, TensorBox): + return out + elif isinstance(out, (StorageBox, ReinterpretView)): + return TensorBox(out) + elif isinstance(out, MultiOutput): + return TensorBox.create(out) + else: + raise RuntimeError(f"NYI unsupported output type: {type(out)}") + + @classmethod + def create( + cls, + cond_fn: Subgraph, + body_fn: Subgraph, + carried_inputs: Sequence[IRNode], + additional_inputs: Sequence[IRNode], + stack_output: bool, + ) -> Union[IRNode, Sequence[IRNode]]: + """create the while_loop IR node. stack_output controls whether it stack + each iterations' output, which is necessary for training. 
+ """ + from torch._higher_order_ops.utils import check_input_alias_and_mutation + + def _require_exact_strides( + tensor_boxes: Sequence[IRNode], + fake_tensors: list[Union[int, torch.SymInt, torch.Tensor]], + ) -> list[IRNode]: + assert len(tensor_boxes) == len(fake_tensors) + ret = [] + for tb, fk in zip(tensor_boxes, fake_tensors): + if isinstance(fk, torch.Tensor): + # Subgraph lowering always return StorageBox as graph_outputs because + # it realizes the outputs. + # + # However, require_exact_strides is expecting TensorBox + # e.g. in require_exact_strides when an expand happens, + # the fake tensor's stride is (0, 0, 0) but the storage + # box might have a different stride so lowering.slice_ + # is used to make the stride consistent and it expects input to + # be TensorBox. + # + # So we wrap the inputs as tensor boxes if they're not yet. + new_tb = WhileLoop._maybe_wrap_as_tensor_box(tb) + ret.append( + ExternKernel.require_exact_strides( + new_tb, fk.stride(), allow_padding=False + ) + ) + else: + ret.append(tb) + return ret + + fx_carried_inputs = V.graph.current_node.args[-2] + fx_additional_inputs = V.graph.current_node.args[-1] + fx_all_inputs = fx_carried_inputs + fx_additional_inputs # type: ignore[operator] + fake_all_inputs = [x.meta["val"] for x in fx_all_inputs] # type: ignore[union-attr] + fake_carried_inputs = [x.meta["val"] for x in fx_carried_inputs] # type: ignore[union-attr] + fake_additional_inputs = [x.meta["val"] for x in fx_additional_inputs] # type: ignore[union-attr] + + carried_inputs_ = [cls.realize_input(x) for x in carried_inputs] + carried_inputs_ = WhileLoop._clone_aliased_inputs(carried_inputs_) + carried_inputs_ = _require_exact_strides(carried_inputs_, fake_carried_inputs) + additional_inputs_ = [cls.realize_input(x) for x in additional_inputs] + additional_inputs_ = _require_exact_strides( + additional_inputs_, fake_additional_inputs + ) + all_inputs = carried_inputs_ + additional_inputs_ + + for subgraph in (cond_fn, 
body_fn): + if subgraph.graph is None: + # create and lower subgraphs + assert isinstance(fx_all_inputs, Sequence), type(fx_all_inputs) + subgraph.graph = V.graph.make_subgraph( + gm=subgraph.graph_module, + example_inputs=fx_all_inputs, # type: ignore[arg-type] + subgraph_name=subgraph.name, + ) + with V.set_graph_handler(subgraph.graph): + subgraph.graph.run(*fake_all_inputs) + # For body_fn, we require its output to have the exact same stride + # as inputs because the previous output is the input of next iteration. + # + # This cannot be automatically done in graph lowering because body_fn's graph outputs + # are not user-facing so the special handling for strides of user-facing output in graph + # lowering is not applicable. + if subgraph is body_fn: + assert len(subgraph.graph.graph_outputs) == len( + fake_carried_inputs + ) + subgraph.graph.graph_outputs = _require_exact_strides( # type: ignore[assignment] + subgraph.graph.graph_outputs, + fake_carried_inputs, + ) + + assert cond_fn.graph and body_fn.graph + cond_outputs = cond_fn.graph.graph_outputs + body_outputs = body_fn.graph.graph_outputs + + if _has_aliased_buffers(body_outputs): + raise AssertionError( + "Output aliasing is currently not supported in compiled torch.while_loop. " + f"The outputs of the body_fn subgraph of torch.while_loop are aliased: {body_outputs}" + ) + + # make sure cond_fn returns a boolean scalar Tensor + assert len(cond_outputs) == 1, cond_outputs + p = cond_outputs[0] + if not isinstance(p, ShapeAsConstantBuffer): + assert p.get_dtype() == torch.bool, p + assert len(p.get_size()) == 0, p + + assert len(all_inputs) > 0, ( + "torch.while_loop is assumed to have at least one operand." 
+ ) + + device = all_inputs[0].get_device() + + assert device is not None # to make linter happy + # make sure carried_inputs_ and body outputs are structurally equivalent + assert len(carried_inputs_) == len(body_outputs), ( + carried_inputs_, + body_outputs, + ) + for i, (op, bo) in enumerate(zip(carried_inputs_, body_outputs)): + + def _guard_list_equals( + lhs_exprs: Sequence[Union[int, sympy.Expr]], + rhs_exprs: Sequence[Union[int, sympy.Expr]], + ) -> None: + assert len(lhs_exprs) == len(rhs_exprs) + for lhs, rhs in zip(lhs_exprs, rhs_exprs): + V.graph.sizevars.check_equals(lhs, rhs) + + _guard_list_equals(op.get_size(), bo.get_size()) + _guard_list_equals(op.get_stride(), bo.get_stride()) + # assume all carried_inputs_ and outputs are on the same device + # as the MultiOutputLayout below requires single device + assert op.get_device() == bo.get_device(), (i, op, bo, device) + assert op.get_dtype() == bo.get_dtype(), (i, op, bo) + + assert device is not None + + unbacked_bindings = resolve_unbacked_bindings( + V.graph.sizevars.shape_env, + V.graph.current_node.meta.get("unbacked_bindings", None), + ) + + while_loop = WhileLoop( + carried_inputs=carried_inputs_, + additional_inputs=additional_inputs_, + cond_subgraph=cond_fn, + body_subgraph=body_fn, + # asserted above that there is at least one operand + layout=MultiOutputLayout(device=device), + unbacked_bindings=unbacked_bindings, + stack_output=stack_output, + ) + + assert body_fn.graph is not None and isinstance( + body_fn.graph.module, torch.fx.GraphModule + ) # to make linter happy + + # Handling input mutations + mutated_idxs = check_input_alias_and_mutation( + body_fn.graph.module, fake_all_inputs + )[3] + mutated_idx_set = OrderedSet(mutated_idxs) + mutated_inputs = [all_inputs[idx] for idx in mutated_idx_set] + + # Create all outputs first + mutated_inputs_iter = iter(mutated_inputs) + all_outputs: list[IRNode] = [] + while_loop.outputs = [] + while_loop.mutation_outputs = [] + if stack_output: + 
assert len(mutated_idx_set) == 0, ( + "NYI: while_loop_stack_output input mutations." + ) + for idx, output in enumerate(V.graph.current_node.meta["val"]): + # Create MultiOutput for regular outputs + multi_out = MultiOutput( + FixedLayout( + device=output.device, # type: ignore[arg-type] + dtype=output.dtype, + size=[Conditional._maybe_expr(sz) for sz in output.size()], + stride=[Conditional._maybe_expr(st) for st in output.stride()], + ), + while_loop, + [(list, idx)], + ) + while_loop.outputs.append(multi_out) + all_outputs.append(multi_out) + else: + for idx, output in enumerate(body_outputs): + if idx in mutated_idx_set: + assert idx < len(carried_inputs), "only carries can be mutated." + # Create MutationOutput for mutated inputs + mutated_input = next(mutated_inputs_iter) + while_loop.mutation_outputs.append( + MutationOutput(mutated_input.layout, mutated_input, while_loop) # type: ignore[attr-defined, union-attr] + ) + all_outputs.append(mutated_input) + else: + multi_out = MultiOutput( + FixedLayout( + device=output.get_device(), # type: ignore[arg-type] + dtype=output.get_dtype(), + size=output.get_size(), + stride=output.get_stride(), + offset=output.get_layout().offset, + ), + while_loop, + [(list, idx)], + ) + while_loop.outputs.append(multi_out) + all_outputs.append(multi_out) + + for inp, out in zip(carried_inputs, all_outputs): + if inp.get_name() in V.graph.graph_inputs: + # if a carried input of the while_loop is a graph input, + # it can be returned as is when the number of iterations + # is zero. due to this, we can't (generally) reuse the + # output buffers corresponding to the graph inputs, as + # the inputs may end up being mutated. 
+ V.graph.never_reuse_buffers.add(out.get_name()) + return all_outputs + + def codegen(self, wrapper: PythonWrapperCodegen) -> None: + wrapper.codegen_while_loop(self, self.stack_output) + wrapper.codegen_unbacked_symbol_defs_for_outputs( + self.get_name(), self.outputs, getattr(self, "unbacked_bindings", {}) + ) + + def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: + if unbacked_bindings := getattr(self, "unbacked_bindings", None): + resolved = resolve_unbacked_bindings( + V.graph.sizevars.shape_env, unbacked_bindings + ) + assert resolved is not None + return OrderedSet(resolved.keys()) + else: + return OrderedSet() + + +class EffectfulKernel(FallbackKernel): + def __init__( + self, + layout: OutputSpec, + kernel: _OpOverloads, + tensor_args: Sequence[IRNode], + nontensor_args: Sequence[Any], + unflatten_args: Callable[..., Any], + kwargs: Optional[dict[str, Any]] = None, + *, + unbacked_bindings: Optional[dict[sympy.Symbol, pytree.KeyPath]] = None, + ) -> None: + super().__init__( + layout, + kernel, + tensor_args, + nontensor_args, + unflatten_args, + kwargs=None, + unbacked_bindings=unbacked_bindings, + ) + + from torch._higher_order_ops.effects import _get_effect + + effect_type = _get_effect(kernel) + assert effect_type is not None + self.effect_type = effect_type + self.prev_effect_buffer = V.graph.effectful_ops.get(effect_type, None) + V.graph.effectful_ops[effect_type] = self + + def get_read_writes(self) -> dependencies.ReadWrites: + read_writes = super().get_read_writes() + + if self.prev_effect_buffer is not None: + read_writes.reads.add( + dependencies.StarDep(self.prev_effect_buffer.get_name()) + ) + + return read_writes + + def has_side_effects(self) -> bool: + return True + + +class NonTensorObj(IRNode): + @cache_on_self_and_args("NonTensorObj") + def get_free_symbol_uses( + self, unbacked_only: bool = False + ) -> OrderedSet[sympy.Symbol]: + return OrderedSet() + + +@ir_dataclass +class TorchBindObject(NonTensorObj): + name: str + 
value: Union[FakeScriptObject, torch.ScriptObject] + + def get_name(self) -> str: + return self.name + + def codegen_reference(self, writer: Optional[IndentedBuffer] = None) -> str: + return self.name + + def get_value(self) -> Union[FakeScriptObject, torch.ScriptObject]: + return self.value + + def get_real_obj(self) -> torch.ScriptObject: + if isinstance(self.value, torch.ScriptObject): + return self.value + else: + return self.value.real_obj + + def get_buf_bytes(self) -> int: + # Returns the sum of all tensors in the flattened object + real_script_obj = self.get_real_obj() + + if real_script_obj is None: + return 0 + + assert hasattr(real_script_obj, "__obj_flatten__") + flat_dict = dict(real_script_obj.__obj_flatten__()) + flat_elems = pytree.tree_flatten(flat_dict)[0] + flat_sizes = [ + x.element_size() * x.numel() + for x in flat_elems + if isinstance(x, torch.Tensor) + ] + return functools.reduce(operator.add, flat_sizes, 0) + + +@ir_dataclass +class GeneratorState(NonTensorObj): + name: str + device: torch.device + + def get_name(self) -> str: + return self.name + + def codegen_reference(self, writer: Optional[IndentedBuffer] = None) -> str: + return self.name + + +class _CollectiveKernel(FallbackKernel): + def should_allocate(self) -> bool: + return False + + def has_side_effects(self) -> bool: + return True + + # This is identical to FallbackKernel.set_cpp_kernel(), minus the + # part that checks against input aliasing and mutation. 
+ def set_cpp_kernel_name(self, cpp_kernel_name: Optional[str] = None) -> None: + assert type(self.op_overload) is torch._ops.OpOverload, ( + "Setting cpp kernel needs a valid op_overload" + ) + kernel = self.op_overload + if cpp_kernel_name is not None: + self.cpp_kernel_name = cpp_kernel_name + else: + self.cpp_kernel_name = kernel._schema.name + + self.ordered_kwargs_for_cpp_kernel = [ + x.name for x in kernel._schema.arguments if x.kwarg_only + ] + + # NOTE: [In-Place Collective Safety] + # Between the initiation and completion of an in-place collective, the + # input buffers are subject to both volatile reads and volatile writes. + # They must not be read, written to or reused by another kernel. To ensure + # the constraints, we model collective -> wait_tensor as as two-step + # mutation of the input buffers. + @classmethod + def create_inplace( + cls, + kernel: _OpOverloads, + inputs: Union[IRNode, list[IRNode]], + *args: Any, + **kwargs: Any, + ) -> None: + with V.graph.fake_mode: + ( + _example_output, + tensor_args, + non_tensor_args, + unflatten_args, + unbacked_bindings, + ) = cls.process_kernel(kernel, inputs, *args, **kwargs) + assert not unbacked_bindings, f"{kernel} {unbacked_bindings}" + for tensor_arg in tensor_args: + tensor_arg.realize() + V.graph.mark_buffer_mutated(tensor_arg.get_name()) + + device = tensor_args[0].get_device() + packed = cls( + NoneLayout(device=device), + kernel, + tensor_args, + non_tensor_args, + unflatten_args, + ) + + inps = pytree.tree_leaves(inputs) + packed.mutation_outputs.extend( + [MutationOutput(NoneLayout(device=device), buf, packed) for buf in inps] + ) + + # For inplace collective ops, the input is guaranteed to be alias of the returned value of op. 
+ packed.alias_names.extend([inp.get_name() for inp in inps]) + if "out" in kwargs: + packed.mutation_outputs.append( + MutationOutput(NoneLayout(device=device), kwargs["out"], packed) + ) + # For out-variant collective ops, the `out=` arg is guaranteed to be alias of the returned value of op. + packed.alias_names.append(kwargs["out"].get_name()) + + # NOTE: [Out-of-Place Collective Safety] + # Between the initiation and completion of an out-of-place collective: + # + # Input buffers: + # - Are subject to volatile reads + # - Can be read by another kernel + # - Must not be written to or reused by another kernel + # + # Output buffers: + # - Are subject to volatile writes + # - Must not be read, written to or reused by another kernel + # + # To ensure the safety of input buffers without sacrificing read + # availability, we add input buffers as read deps of wait_tensor kernels. + # + # To ensure the safety of output buffers, we model wait_tensor as a + # mutation to the output buffer. Note we also assumes the user program being + # correct and the output buffer is not consumed by kernels other than + # wait_tensor. + # + # TODO(yifu): add a pre-grad pass to validate the correctness of collective + # usage in the user program. 
+ @classmethod + def create_out_of_place( + cls, + kernel: _OpOverloads, + inputs: Union[TensorBox, list[TensorBox]], + *args: Any, + **kwargs: Any, + ) -> Union[list[MultiOutput], _CollectiveKernel]: + with V.graph.fake_mode: + ( + example_output, + tensor_args, + non_tensor_args, + unflatten_args, + unbacked_bindings, + ) = cls.process_kernel(kernel, inputs, *args, **kwargs) + assert not unbacked_bindings, f"{kernel}, {unbacked_bindings}" + for tensor_arg in tensor_args: + tensor_arg.realize() + + if isinstance(example_output, list): + device = cls.find_device(tensor_args, example_output) + assert device is not None + packed = cls( + MultiOutputLayout(device=device), + kernel, + tensor_args, + non_tensor_args, + unflatten_args, + ) + packed.outputs = [ + MultiOutput( + cls.tensor_to_layout(tensor), + packed, + [(list, i)], + ) + for i, tensor in enumerate(example_output) + ] + for buf, tensor in zip(packed.outputs, example_output): + if config.assume_unaligned_fallback_output or not tensor_is_aligned( + tensor + ): + V.graph.unaligned_buffers.add(buf.name) # type: ignore[arg-type] + return packed.outputs + else: + packed = cls( + cls.tensor_to_layout(example_output), + kernel, + tensor_args, + non_tensor_args, + unflatten_args, + ) + if config.assume_unaligned_fallback_output or not tensor_is_aligned( + example_output + ): + V.graph.unaligned_buffers.add(packed.name) # type: ignore[arg-type] + packed.outputs = [packed] + return packed + + +class _AllReduce_Kernel(_CollectiveKernel): + def __init__( + self, + layout: OutputSpec, + kernel: _OpOverloads, + tensor_args: Sequence[IRNode], + nontensor_args: Sequence[Any], + unflatten_args: Callable[..., Any], + kwargs: Optional[dict[str, Any]] = None, + *, + unbacked_bindings: Optional[dict[sympy.Symbol, pytree.KeyPath]] = None, + ) -> None: + super().__init__( + layout, + kernel, + tensor_args, + nontensor_args, + unflatten_args, + kwargs=None, + unbacked_bindings=unbacked_bindings, + ) + 
self.set_cpp_kernel_name("aoti_torch_cpu__c10d_functional_all_reduce_") + + def codegen(self, wrapper: PythonWrapperCodegen) -> None: + wrapper.include_extra_header("torch/csrc/inductor/aoti_torch/c/shim_cpu.h") + wrapper.generate_extern_kernel_alloc(self) + + if isinstance(self.layout, Layout): + self.codegen_size_asserts(wrapper) + + +class _AllReduceKernel(_CollectiveKernel): + def __init__( + self, + layout: OutputSpec, + kernel: _OpOverloads, + tensor_args: Sequence[IRNode], + nontensor_args: Sequence[Any], + unflatten_args: Callable[..., Any], + kwargs: Optional[dict[str, Any]] = None, + *, + unbacked_bindings: Optional[dict[sympy.Symbol, pytree.KeyPath]] = None, + ) -> None: + super().__init__( + layout, + kernel, + tensor_args, + nontensor_args, + unflatten_args, + kwargs=None, + unbacked_bindings=unbacked_bindings, + ) + self.set_cpp_kernel_name("aoti_torch_cpu__c10d_functional_all_reduce") + + def codegen(self, wrapper: PythonWrapperCodegen) -> None: + wrapper.include_extra_header("torch/csrc/inductor/aoti_torch/c/shim_cpu.h") + wrapper.generate_extern_kernel_alloc(self) + + if isinstance(self.layout, Layout): + self.codegen_size_asserts(wrapper) + + +class _WaitKernel(_CollectiveKernel): + def __init__( + self, + layout: OutputSpec, + kernel: _OpOverloads, + tensor_args: Sequence[IRNode], + nontensor_args: Sequence[Any], + unflatten_args: Callable[..., Any], + kwargs: Optional[dict[str, Any]] = None, + *, + unbacked_bindings: Optional[dict[sympy.Symbol, pytree.KeyPath]] = None, + ) -> None: + super().__init__( + layout, + kernel, + tensor_args, + nontensor_args, + unflatten_args, + kwargs=None, + unbacked_bindings=unbacked_bindings, + ) + self.set_cpp_kernel_name("aoti_torch_cpu__c10d_functional_wait_tensor") + + def codegen(self, wrapper: PythonWrapperCodegen) -> None: + wrapper.include_extra_header("torch/csrc/inductor/aoti_torch/c/shim_cpu.h") + wrapper.generate_extern_kernel_alloc(self) + + if isinstance(self.layout, Layout): + 
self.codegen_size_asserts(wrapper) + + def get_volatile_reads(self) -> Sequence[IRNode]: + inp = self.inputs[0] + assert isinstance(inp, IRNode) + if isinstance(inp, _CollectiveKernel): + # Out-of-place single-output + i = inp.inputs[0] + assert isinstance(i, IRNode), type(i) + return [i] + elif isinstance(inp, MultiOutput): + # This can be two things: + # 1. Out-of-place multi-output coll + # 2. In-place coll with inputs coming from another MultiOutput + coll = inp.inputs[0] + # Case 1 + if isinstance(coll, _CollectiveKernel): + _, idx = inp.indices[0] + # pyrefly: ignore [bad-return] + return [coll.inputs[idx]] + # Case 2 + return [] + else: + # In-place requires no additional deps handling for volatile + # reads since the inputs are mutated. + return [] + + @classmethod + def create_wait(cls, kernel: _OpOverloads, inp: TensorBox) -> None: + with V.graph.fake_mode: + ( + _example_output, + tensor_args, + non_tensor_args, + unflatten_args, + unbacked_bindings, + ) = cls.process_kernel(kernel, inp) + assert not unbacked_bindings, f"{kernel} {unbacked_bindings}" + packed = cls( + NoneLayout(device=inp.get_device()), + kernel, + tensor_args, + non_tensor_args, + unflatten_args, + ) + packed.mutation_outputs.append( + MutationOutput(NoneLayout(device=inp.get_device()), inp, packed) + ) + + def get_read_writes(self) -> dependencies.ReadWrites: + read_writes = super().get_read_writes() + # See [Out-of-Place Collective Safety]. 
+ volatile_reads = self.get_volatile_reads() + for vr in volatile_reads: + read_writes.reads.add(dependencies.StarDep(vr.get_name())) + return read_writes + + +# NB: recursive structure here reflects val_to_arg_str, avoid +# calling free_unbacked_symbols on "exotic" types that don't get pexpr +# treatment +def maybe_free_unbacked_symbols(s: object) -> OrderedSet[Symbol]: + if isinstance(s, (SymTypes, Expr)): + # This branch should be impossible in return position + return free_unbacked_symbols(s) + elif isinstance(s, (tuple, list)): + r = OrderedSet[sympy.Symbol]() + for t in s: + r |= maybe_free_unbacked_symbols(t) + return r + elif isinstance(s, torch.Tensor): + # This branch is impossible in constant-args position + return free_unbacked_symbols(s) + else: + return OrderedSet() + + +def maybe_free_symbols(s: object) -> OrderedSet[Symbol]: + if isinstance(s, (SymTypes, Expr)): + # This branch should be impossible in return position + return free_symbols(s) + elif isinstance(s, (tuple, list)): + r = OrderedSet[sympy.Symbol]() + for t in s: + r |= maybe_free_symbols(t) + return r + elif isinstance(s, torch.Tensor): + # This branch is impossible in constant-args position + return free_symbols(s) + else: + return OrderedSet() + + +def assign_origin_node(result: Any, n: torch.fx.Node) -> None: + # This is not complete, but it doesn't have to be: origin_node + # tracking is best effort. The logic here critically relies on direct + # TensorBox -> StorageBox denoting a non-view; we don't bother trying + # to get views to work. Feel free to add any extra cases as needed. + # + # Note: we can't YOLO tree_map over this result, because if there are + # buffers or a view involved, we might not be able to validly assign + # the origin_node here. 
+ if isinstance(result, TensorBox) and isinstance(result.data, StorageBox): + if isinstance(result.data.data, Loops): + result.data.data._post_init_setattr("origin_node", n) + elif isinstance(result.data.data, Buffer): + result.data.data._post_init_setattr("origin_node", n) + if isinstance(result.data.data, ComputedBuffer) and isinstance( + result.data.data.data, Loops + ): + result.data.data.data._post_init_setattr("origin_node", n) + # Not really multi-output, can straightforwardly recurse in + elif ( + isinstance(result.data.data, MultiOutput) + and not result.data.data.indices + ): + if isinstance(result.data.data.inputs[0], Buffer): + result.data.data.inputs[0]._post_init_setattr("origin_node", n) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/jagged_lowerings.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/jagged_lowerings.py new file mode 100644 index 0000000000000000000000000000000000000000..86cdc42ee88eb0b4615368ab6a0b04b90bb6208f --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/jagged_lowerings.py @@ -0,0 +1,270 @@ +# mypy: allow-untyped-defs +from typing import Optional, Union + +import sympy + +import torch + +from .ir import Pointwise, ShapeAsConstantBuffer, TensorBox +from .virtualized import ops + + +# pyre-ignore[2,3] +def dense_idx_to_jagged_idx(batch_idx, seq_idx, offsets_loader, jagged_len): + # jagged_len + 1 is used as the upper bound, + # because the last sequence length may be zero + begin_idx = ops.indirect_indexing( + offsets_loader([batch_idx]), + jagged_len + 1, + ) + end_idx = offsets_loader([batch_idx + 1]) + jagged_idx = begin_idx + seq_idx + return jagged_idx, end_idx + + +def get_inverse_offsets( + offsets: TensorBox, + jagged_len: Union[int, sympy.Expr], + realize: bool = True, +) -> Union[TensorBox, ShapeAsConstantBuffer]: + """ + Returns "inverse_offsets" - the inverse of the offsets array. 
+ offsets maps batch index (dense) to jagged index (i.e. offset into jagged tensor). + inverse_offsets maps jagged index to batch index. + + e.g. for offsets [0, 3, 4, 9, 10] this will return + inverse_offsets = [0, 0, 0, 1, 2, 2, 2, 2, 2, 3] + + For the given offsets, the computed inverse_offsets are cached + on the first call and reused in the further calls. + """ + + if hasattr(offsets, "inverse_offsets"): + # inverse_offsets are already computed + # for these offsets: can reuse + return offsets.inverse_offsets + + # ops.bucketize takes offsets.get_name() which doesn't exist on Pointwise + # kernels, i.e. we need to realize it before using. In other words, we need + # offsets to be in global memory so that we can binary search over the + # entire tensor + offsets.realize() + device: torch.device = offsets.get_device_or_error() + dtype: torch.dtype = offsets.get_dtype() + + # pyre-ignore[2,3] + def inner_fn(index): + idx = index[0] + bucket = ops.bucketize( + values=ops.index_expr(idx, dtype), + boundaries=( + offsets.get_name(), + offsets.get_size()[-1], + offsets.get_size()[0] * offsets.get_stride()[0], + offsets.get_stride()[-1], + ), + boundary_indices=0, + indexing_dtype=dtype, + right=True, + ) + # ops.bucketize above returns 1-based bucket indices, + # but we need 0-based, hence we subtract 1 from batch + return bucket - 1 + + inverse_offsets = Pointwise.create( + device=device, + dtype=dtype, + inner_fn=inner_fn, + ranges=[jagged_len], + ) + + if realize: + # "freeze" the node so that it doesn't get inlined downstream. 
+ inverse_offsets.realize() + + # cache inverse_offsets for further reuse + offsets.inverse_offsets = inverse_offsets # type: ignore[attr-defined] + + return inverse_offsets + + +def jagged_idx_to_dense_idx( + jagged_idx, # pyre-ignore[2] + inverse_offsets_loader, # pyre-ignore[2] + offsets_loader, # pyre-ignore[2] + batch_size: Union[int, sympy.Expr], + max_seq_len: Union[int, sympy.Expr], + offsets_dtype: torch.dtype, +) -> tuple[sympy.Expr, sympy.Expr]: + batch_idx = ops.indirect_indexing( + inverse_offsets_loader([jagged_idx]), + batch_size + 1, + ) + batch_start = offsets_loader([batch_idx]) + seq = ops.index_expr(jagged_idx, offsets_dtype) - batch_start + # check=False because there may be sequences longer than max_seq_len + seq_idx = ops.indirect_indexing(seq, max_seq_len, check=False) + return batch_idx, seq_idx + + +def register_jagged_ops(): + # Avoid circular import by importing here + from .lowering import fallback_handler, is_integer_type, register_lowering + + # pyre-ignore[56] + @register_lowering(torch.ops.aten._jagged_to_padded_dense_forward.default) + def _jagged_to_padded_dense_forward( + jagged_values: TensorBox, + jagged_offsets: list[TensorBox], + max_lengths: list[int], # list of ints/SymInts + padding_value: float = 0.0, + ) -> Union[TensorBox, ShapeAsConstantBuffer]: + device = jagged_values.get_device_or_error() + dtype = jagged_values.get_dtype() + + jagged_values_size = jagged_values.get_size() + + # only handle the common case of a single jagged dimension + if ( + len(jagged_offsets) != 1 + or device.type != "cuda" + or device != jagged_offsets[0].get_device() + or len(jagged_values_size) != 2 + or len(jagged_offsets[0].get_size()) != 1 + or len(max_lengths) != len(jagged_offsets) + or not is_integer_type(jagged_offsets[0]) + ): + return fallback_handler( + torch.ops.aten._jagged_to_padded_dense_forward.default, + add_to_fallback_set=False, + )( + jagged_values, + jagged_offsets, + max_lengths, + padding_value, + ) + + offsets: 
TensorBox = jagged_offsets[0] # type: ignore[assignment] + offsets_len = offsets.get_size()[0] + offsets_dtype = offsets.get_dtype() + batch_size = offsets_len - 1 + max_seq_len = max_lengths[0] + embedding_len = jagged_values_size[1] + jagged_len = jagged_values_size[0] + + output_size = [batch_size, max_seq_len, embedding_len] + + values_loader = jagged_values.make_loader() + offsets_loader = offsets.make_loader() + + # pyre-ignore[2,3,53] + def inner_fn(index): + # dense tensor size: [B, N, D] + batch_idx, seq_idx, emb_idx = index + jagged_idx, end_idx = dense_idx_to_jagged_idx( + batch_idx=batch_idx, + seq_idx=seq_idx, + offsets_loader=offsets_loader, + jagged_len=jagged_len, + ) + return ops.masked( + ops.lt( + ops.index_expr(jagged_idx, offsets_dtype), + end_idx, + ), + lambda: values_loader([jagged_idx, emb_idx]), + padding_value, + ) + + return Pointwise.create( + device=device, + dtype=dtype, + inner_fn=inner_fn, + ranges=output_size, + ) + + def _dense_to_jagged_forward_impl( + fallback_op, # pyre-ignore[2] + dense: TensorBox, + jagged_offsets: list[TensorBox], + jagged_len: Optional[int] = None, + ) -> Union[TensorBox, ShapeAsConstantBuffer]: + device = dense.get_device_or_error() + dtype = dense.get_dtype() + + dense_size = dense.get_size() + + # only handle the common case of a single jagged dimension + if ( + len(jagged_offsets) != 1 + or device.type != "cuda" + or device != jagged_offsets[0].get_device() + or len(jagged_offsets[0].get_size()) != 1 + or len(dense_size) != 3 + or jagged_len is None + or not is_integer_type(jagged_offsets[0]) + ): + return fallback_handler(fallback_op, add_to_fallback_set=False)( + dense, + jagged_offsets, + jagged_len, + ) + + offsets: TensorBox = jagged_offsets[0] # type: ignore[assignment] + offsets_dtype = offsets.get_dtype() + batch_size = dense_size[0] + max_seq_len = dense_size[1] + embedding_len = dense_size[-1] + + output_size = [jagged_len, embedding_len] + + dense_loader = dense.make_loader() + offsets_loader 
= offsets.make_loader() + + inverse_offsets = get_inverse_offsets( + offsets=offsets, + jagged_len=jagged_len, + ) + inverse_offsets_loader = inverse_offsets.make_loader() + + # pyre-ignore[2,3,53] + def inner_fn(index): + # jagged tensor size: [sum_B(N_B), D] + jagged_idx, emb_idx = index + batch_idx, seq_idx = jagged_idx_to_dense_idx( + jagged_idx=jagged_idx, + offsets_loader=offsets_loader, + inverse_offsets_loader=inverse_offsets_loader, + batch_size=batch_size, + max_seq_len=max_seq_len, + offsets_dtype=offsets_dtype, + ) + return ops.masked( + ops.lt( + ops.index_expr(seq_idx, offsets_dtype), + ops.index_expr(max_seq_len, offsets_dtype), + ), + lambda: dense_loader([batch_idx, seq_idx, emb_idx]), + 0.0, # jagged sequence longer than max_seq_len + ) + + return Pointwise.create( + device=device, + dtype=dtype, + inner_fn=inner_fn, + ranges=output_size, + ) + + # pyre-ignore[56] + @register_lowering(torch.ops.aten._padded_dense_to_jagged_forward) + def _dense_to_jagged_forward( + dense: TensorBox, + jagged_offsets: list[TensorBox], + jagged_len: Optional[int] = None, + ) -> Union[TensorBox, ShapeAsConstantBuffer]: + return _dense_to_jagged_forward_impl( + fallback_op=torch.ops.aten._padded_dense_to_jagged_forward.default, + dense=dense, + jagged_offsets=jagged_offsets, + jagged_len=jagged_len, + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel_inputs.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel_inputs.py new file mode 100644 index 0000000000000000000000000000000000000000..c579cf756577282a3fb498c342e9385079eb8947 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel_inputs.py @@ -0,0 +1,338 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any, Optional, TYPE_CHECKING, Union + +import torch +import torch._inductor.config +from torch._inductor import ir +from 
torch._inductor.virtualized import V + +from .ir import FixedLayout, FlexibleLayout, Layout + + +if TYPE_CHECKING: + from collections.abc import Sequence + + import sympy + + +class KernelInputs(ABC): + """ + Class to store and provide access to input nodes for kernels. + This class takes in a tuple of input nodes and provides methods to access + information about these nodes, such as their device type and device. + """ + + def __init__( + self, + input_nodes: list[Any], + scalars: Optional[dict[str, Union[float, int]]] = None, + out_dtype: Optional[torch.dtype] = None, + ): + """ + Initialize with a tuple of input nodes. + + Args: + input_nodes: A tuple of input nodes to store + out_dtype: Optional output dtype to store + """ + self._input_nodes = input_nodes + self._device_name: Optional[str] = None + self._scalars = scalars if scalars is not None else {} + self._out_dtype = out_dtype + assert len(input_nodes) > 0, "Expected at least one input node" + + def nodes(self, reorder: Optional[Sequence[int]] = None) -> list[Any]: + """ + Return the stored input nodes, optionally reordered. + + Args: + reorder: Optional sequence of indices to reorder the nodes. + For example, (2, 0, 1) would return nodes in that order. + + Returns: + The tuple of input nodes, optionally reordered + """ + if reorder is None: + return self._input_nodes + assert len(self._input_nodes) == len(reorder), ( + f"Reorder length mismatch: {len(self._input_nodes)} vs {len(reorder)}" + ) + return [self._input_nodes[i] for i in reorder] + + @property + def count(self) -> int: + """ + Get the number of input nodes. + + Returns: + The number of input nodes + """ + return len(self._input_nodes) + + @property + def device_type(self) -> Optional[str]: + """ + Get the device type of the first node. + + Returns: + The device type (e.g., 'cuda', 'cpu') + """ + + return ir.get_device_type(self._input_nodes[0]) + + def device(self) -> torch.device: + """ + Get the device of the first node. 
+ + Returns: + The device of the first node + """ + return self._input_nodes[0].get_device() + + def device_name(self) -> Optional[str]: + """ + Get the device name information. + + Returns: + A tuple of (gpu_name, vendor, model) + """ + if self._device_name is None: + device = self.device() + if self.device_type == "cuda": + device_properties = torch.cuda.get_device_properties(device) + self._device_name = device_properties.gcnArchName + return self._device_name + + def shapes_symbolic(self) -> tuple[tuple[Any, ...], ...]: + """ + Get the symbolic shapes of all input nodes. + + Returns: + A tuple of shape tuples for each input node + """ + return tuple(node.get_size() for node in self._input_nodes) + + def shapes_hinted(self) -> tuple[tuple[int, ...], ...]: + """ + Get the size hints for shapes of all input nodes. + + Returns: + A tuple of shape tuples with integer hints for each input node + """ + return tuple( + V.graph.sizevars.size_hints( + node.get_size(), + fallback=torch._inductor.config.unbacked_symint_fallback, + ) + for node in self._input_nodes + ) + + def strides_symbolic(self) -> tuple[tuple[sympy.Integer, ...], ...]: + """ + Get the symbolic strides of all input nodes. + + Returns: + A tuple of stride tuples for each input node + """ + return tuple(node.get_stride() for node in self._input_nodes) + + def strides_hinted(self) -> tuple[tuple[int, ...], ...]: + """ + Get the size hints for strides of all input nodes. + + Returns: + A tuple of stride tuples with integer hints for each input node + """ + return tuple( + V.graph.sizevars.size_hints( + node.get_stride(), + fallback=torch._inductor.config.unbacked_symint_fallback, + ) + for node in self._input_nodes + ) + + def dtypes(self) -> tuple[torch.dtype, ...]: + """ + Get the dtypes of all input nodes. 
+ + Returns: + A tuple of dtypes for each input node + """ + return tuple(node.get_dtype() for node in self._input_nodes) + + def dtype(self, idx: int = 0) -> torch.dtype: + """ + Get the dtype of a specific input node. + + Args: + idx: Index of the node to get the dtype from (default: 0) + + Returns: + The dtype of the specified input node + """ + return self._input_nodes[idx].get_dtype() + + @abstractmethod + def out_dtype(self) -> torch.dtype: + """ + Get the output dtype, whether passed in or inferred from the nodes + + Returns: + The output dtype + """ + + def get_scalar(self, name: str) -> Union[float, int]: + """ + Get the scalar value for a given name. + + Args: + name: Name of the scalar to get + + Returns: + The scalar value + """ + assert name in self._scalars, f"Scalar {name} not found, but required" + return self._scalars[name] + + @abstractmethod + def output_layout(self, flexible: bool = True) -> Layout: + """ + Abstract method to handle output layout generation. + + Args: + out_dtype: Optional output dtype. If not provided, infer from inputs + flexible: If True, return FlexibleLayout, otherwise FixedLayout + """ + + +class MMKernelInputs(KernelInputs): + """ + Specialized KernelInputs for matrix multiplication operations. + Provides additional methods to access M, N, K dimensions. + """ + + def __init__( + self, + input_nodes: list[Any], + scalars: Optional[dict[str, Union[float, int]]] = None, + out_dtype: Optional[torch.dtype] = None, + mat1_idx: int = -2, + mat2_idx: int = -1, + ): + """ + Initialize with a tuple of input nodes. + + By default, we assume the last 2 input nodes are mat1 and mat2, but + the caller can adjust when necessary + """ + super().__init__(input_nodes, scalars, out_dtype) + # for mm, we need at least 2 nodes, and we need to know which nodes + # are the main matrixes e.g. addmm is (bias, mat1, mat2) whereas others + # might be (mat1, mat2, scale), etc. 
+ assert len(self._input_nodes) >= 2, "Expected at least 2 input nodes" + + # Adjust assertions to handle negative indices + m1_idx, m2_idx = mat1_idx, mat2_idx + if mat1_idx < 0: + m1_idx += len(input_nodes) + if mat2_idx < 0: + m2_idx += len(input_nodes) + + assert 0 <= m1_idx < len(input_nodes), f"Invalid mat1_idx: {mat1_idx}" + assert 0 <= m1_idx < len(input_nodes), f"Invalid mat2_idx: {mat2_idx}" + + self._mat1_idx = mat1_idx + self._mat2_idx = mat2_idx + + def mnk_symbolic( + self, + ) -> tuple[sympy.Integer, sympy.Integer, sympy.Integer]: + """ + Get the symbolic M, N, K dimensions for matrix multiplication. + Handles both 2D (MM) and 3D (BMM) tensors. + + M is extracted from the second-to-last dimension of the first operand (mat1). + N is extracted from the last dimension of the second operand (mat2). + K is extracted from the last dimension of the first operand (mat1). + + Returns: + A tuple of (M, N, K) dimensions + """ + mat1 = self.nodes()[self._mat1_idx] + mat2 = self.nodes()[self._mat2_idx] + + m = mat1.get_size()[-2] # M from second-to-last dimension of mat1 + k = mat1.get_size()[-1] # K from last dimension of mat1 + n = mat2.get_size()[-1] # N from last dimension of mat2 + + # Ensure K dimensions match between operands + k0 = mat2.get_size()[-2] # K from second-to-last dimension of mat2 + V.graph.sizevars.check_equals(k, k0) + return (m, n, k) + + def out_dtype(self) -> torch.dtype: + """ + Get the output dtype, whether passed in or inferred from the nodes + + Returns: + The output dtype + """ + if self._out_dtype is not None: + return self._out_dtype + return self.mat1mat2()[0].get_dtype() + + def output_layout(self, flexible: bool = True) -> Layout: + """ + Handle output layout generation for matrix multiplication. + + Args: + out_dtype: Optional output dtype. 
If not provided, infer from inputs + flexible: If True, return FlexibleLayout, otherwise FixedLayout + """ + mat1, mat2 = self.mat1mat2() + out_dtype = self.out_dtype() + # NOTE: taken from mm_common.mm_args + *b1, m, k1 = mat1.get_size() + *b2, k2, n = mat2.get_size() + b = [V.graph.sizevars.check_equals_and_simplify(a, b) for a, b in zip(b1, b2)] + size = [*b, m, n] + if flexible: + return FlexibleLayout(self.device(), out_dtype, size) + else: + return FixedLayout(self.device(), out_dtype, size) + + def mat1mat2(self) -> tuple[Any, Any]: + """ + Get the mat1 and mat2 nodes. + + Returns: + A tuple of (mat1, mat2) nodes + """ + nodes = self.nodes() + return nodes[self._mat1_idx], nodes[self._mat2_idx] + + def mnk_hinted(self) -> tuple[int, int, int]: + """ + Get the hinted M, N, K dimensions for matrix multiplication. + Handles both 2D (MM) and 3D (BMM) tensors. + + Uses shapes_hinted from the base class to get integer hints for dimensions. + + Returns: + A tuple of (M, N, K) dimensions as integers + """ + hinted_shapes = self.shapes_hinted() + mat1_shape = hinted_shapes[self._mat1_idx] + mat2_shape = hinted_shapes[self._mat2_idx] + + m = mat1_shape[-2] # M from second-to-last dimension of mat1 + k = mat1_shape[-1] # K from last dimension of mat1 + n = mat2_shape[-1] # N from last dimension of mat2 + + # Ensure K dimensions match between operands + k_check = mat2_shape[-2] # K from second-to-last dimension of mat2 + assert k == k_check, f"K dimensions don't match: {k} vs {k_check}" + + return (m, n, k) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel_template_choice.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel_template_choice.py new file mode 100644 index 0000000000000000000000000000000000000000..8f90157c6c1a0de9ef21dd044cdc40f5c8f82e4e --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/kernel_template_choice.py @@ -0,0 +1,99 @@ +from 
__future__ import annotations + +from typing import Any, Optional, TYPE_CHECKING, Union + +from .template_heuristics.params import DictKernelTemplateParams + + +if TYPE_CHECKING: + from collections.abc import Generator + + from .codegen.common import KernelTemplate + from .ir import ChoiceCaller, Layout + from .kernel_inputs import KernelInputs + from .select_algorithm import ExternKernelChoice + from .template_heuristics.params import KernelTemplateParams + + +class KernelTemplateChoice: + """ + A class that encapsulates all the components needed to create a ChoiceCaller from a template. + + This class implements lazy evaluation for the choice property - the actual ChoiceCaller + is only created when first accessed via the choice property. + """ + + def __init__( + self, + template: Union[KernelTemplate, ExternKernelChoice], + params: KernelTemplateParams, + extra_kwargs: dict[str, Any], + layout: Layout, + inputs: KernelInputs, + ): + self.template = template + self.params = params + self.extra_kwargs = extra_kwargs + self.layout = layout + self.inputs = inputs + self.annotations: dict[str, Any] = {"ktc": self} + + @property + def choice(self) -> Optional[ChoiceCaller]: + """ + Lazily evaluate and return the ChoiceCaller for this template choice. + + On first access, calls template.choice_or_none() with the stored parameters. + If successful, caches and returns the ChoiceCaller. If it fails, caches + and returns None. Subsequent accesses return the cached value. 
+ + Returns: + ChoiceCaller if the template choice succeeds, None otherwise + """ + if not hasattr(self, "_choice"): + # First time accessing choice - try to generate it + kwargs = self.params.to_kwargs() + self._choice = self.template.choice_or_none( + **kwargs, + **self.extra_kwargs, + layout=self.layout, + input_nodes=self.inputs.nodes(), + ) + if self._choice is not None: + self._choice.annotations = self.annotations + return self._choice + + +def make_ktc_generator( + template: Union[KernelTemplate, ExternKernelChoice], + cs: Generator[KernelTemplateParams, None, None], + extra_kwargs: dict[str, Any], + overrides: dict[str, Any], + layout: Layout, + inputs: KernelInputs, +) -> Generator[KernelTemplateChoice, None, None]: + """ + Create a generator of KernelTemplateChoice objects for a given template. + + Args: + template: The template object (KernelTemplate or ExternKernelChoice) + cs: Generator of KernelTemplateParams from template heuristic + overrides: Override kwargs for the template + layout: Layout value for the template + inputs: KernelInputs for the op + + Yields: + KernelTemplateChoice objects + """ + for params in cs: + # Apply overrides to params + base_kwargs = params.to_kwargs() + final_kwargs = {**base_kwargs, **overrides} + final_params = DictKernelTemplateParams(final_kwargs) + yield KernelTemplateChoice( + template=template, + params=final_params, + extra_kwargs=extra_kwargs, + layout=layout, + inputs=inputs, + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/loop_body.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/loop_body.py new file mode 100644 index 0000000000000000000000000000000000000000..3921aa955a8360e9f6e53d121ad4dfcc35632e5c --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/loop_body.py @@ -0,0 +1,789 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import collections +import functools +import 
itertools +import re +from enum import auto, Enum +from typing import Any, NamedTuple, Optional, TYPE_CHECKING, TypeVar + +import sympy + +import torch.fx +from torch._dynamo.utils import identity +from torch.fx.proxy import Scope, TracerBase +from torch.utils._sympy.symbol import SymT + +from . import config, dependencies +from .codegen.common import index_prevent_reordering +from .ops_handler import DefaultHandler, OpsHandler, WrapperHandler +from .utils import ( + cache_on_self, + reduction_num_outputs, + sympy_index_symbol_with_prefix, + sympy_subs, +) +from .virtualized import ops, V + + +if TYPE_CHECKING: + from collections.abc import Callable, Sequence + + +T = TypeVar("T") + + +class InterpreterShim(torch.fx.Interpreter): + @staticmethod + @functools.cache + def _dummy_gm(): + return torch.fx.symbolic_trace(identity) + + def __init__(self, graph, submodules): + # call super() with a placeholder to avoid constructing a + # GraphModule which is very expensive (it does codegen). + super().__init__(self._dummy_gm(), garbage_collect_values=False) + self.module = self # type: ignore[assignment] + self.graph = graph + self.submodules = submodules + self.extra_traceback = False + self.fetch_attr = submodules.__getitem__ # type: ignore[method-assign] + self.current_node = None + + def run_node(self, n: torch.fx.Node) -> Any: + # pyrefly: ignore [bad-assignment] + self.current_node = n + return super().run_node(n) + + def run(self, *args, **kwargs): + with V.set_interpreter_handler(self): + return super().run(*args, **kwargs) + + +# We don't need the nn.Module and constant handling in Tracer +class LightTracer(TracerBase): + def __init__(self): + super().__init__() + self.graph = torch.fx.Graph(tracer_cls=self.__class__) # type: ignore[arg-type] + self.scope = Scope("", None) + self.module_stack = {} # type: ignore[assignment] + self.node_name_to_scope = {} + + +class MemoryEntry(NamedTuple): + index_name: str # LoopBody.indexing_exprs[index_name] + buffer_name: 
Optional[str] + mode: Optional[str] # V.ops.store(..., mode=mode) + + +class MemoryUsageType(Enum): + # These are 1:1 with the opcode generating the usage + LOAD = auto() + LOAD_SEED = auto() + STORE = auto() + STORE_REDUCTION = auto() + INDEX_EXPR = auto() + CHECK_BOUNDS = auto() + BUCKETIZE = auto() + + +class LoopBody: + """ + Captures the body of a Loops subclass into an FX graph. Persists any + indexing simplifications and makes it easier to analyze loop bodies. + """ + + indexing_exprs: dict[str, sympy.Expr] + submodules: dict[str, Any] + subblocks: dict[str, LoopBodyBlock] + indirect_vars: list[sympy.Symbol] + indirect_var_ranges: dict[sympy.Symbol, sympy.Expr] + root_block: LoopBodyBlock + memory_usage: dict[MemoryUsageType, list[MemoryEntry]] + op_counts: collections.Counter[str] + + # defined only temporarily + indexing_exprs_name: dict[sympy.Expr, str] + + def __init__( + self, + fn, + args, + var_ranges, + iter_vars, + reduce_vars, + allow_same_symbol_in_index=False, + ): + super().__init__() + + _flat_sizes = tuple(var_ranges.values()) + self.sizes = ( + _flat_sizes[: len(iter_vars)], + _flat_sizes[len(iter_vars) :], + ) + + self.iter_vars = iter_vars + self.reduce_vars = reduce_vars + self.var_ranges = var_ranges + + if isinstance(fn, LoopBody): + self._init_with_copy(fn, args, allow_same_symbol_in_index) + else: + self._init_with_tracing(fn, args) + + self.indexing = None + + def get_original_num_rdims(self) -> int: + assert self.has_partial_accumulate + node = self.root_block.graph.find_nodes( + op="call_method", target="partial_accumulate" + )[0] + meta = node.args[-1] + return meta["num_reduction_dims"] + + def extract_pw_from_reduction(self): + self.root_block = self.root_block.extract_pw_from_reduction() + self.has_partial_accumulate = True + self.iter_vars = self.iter_vars + self.reduce_vars + self.reduce_vars = [] + self.sizes = (self.sizes[0] + self.sizes[1], tuple()) + return self + + def _init_with_tracing(self, fn, args): + """Do an FX 
trace of an arbitrary callable to construct self""" + self.indexing_exprs = {} + self.indexing_exprs_name = {} + self.submodules = {"get_index": self.get_index} + self.subblocks = {} + self.indirect_vars = [] + self.indirect_var_ranges: dict[sympy.Symbol, sympy.Expr] = {} + self.memory_usage = {t: [] for t in MemoryUsageType} + self.op_counts = collections.Counter() + self.root_block = LoopBodyBlock(self, fn, args) # traces + self.has_partial_accumulate = self.root_block.graph.find_nodes( + op="call_method", target="partial_accumulate" + ) + del self.indexing_exprs_name # not used after _init_with_tracing + + def _init_with_copy(self, other: LoopBody, args, allow_same_symbol_in_index): + """ + _init_with_tracing() is slow, so this is a fast path in the case + where we are just reordering/merging/splitting the args of an + existing LoopBody. + """ + indexing_exprs = other.indexing_from_args(args, allow_same_symbol_in_index) + self.indexing_exprs = { + name: V.graph.sizevars.simplify_with_ranges(expr, self.var_ranges) + for name, expr in indexing_exprs.items() + } + self.subblocks = {k: v.clone(self) for k, v in other.subblocks.items()} + self.indirect_vars = other.indirect_vars + self.indirect_var_ranges = other.indirect_var_ranges + self.memory_usage = other.memory_usage + self.op_counts = other.op_counts + self.root_block = other.root_block.clone(self) + self.has_partial_accumulate = other.has_partial_accumulate + + submodules = {**other.submodules} + submodules.pop("get_index") + self.submodules = { + "get_index": self.get_index, + **{k: v.clone(self) for k, v in submodules.items()}, # type: ignore[attr-defined] + } + + def has_op(self, name: str): + return self.op_counts.get(name, 0) > 0 + + def merge_loops(self) -> LoopBody: + """ + Merge both iteration and reduction loops and return a new LoopBody. 
+ """ + old_body = self + old_sizes = self.sizes + old_iter_vars, old_reduce_vars = old_body.vars + old_iter_sizes, old_reduce_sizes = old_sizes + + index_exprs = [*old_body.indexing_exprs.values()] + + iter_sizes, iter_reindex, _ = V.graph.sizevars._simplify_loops( + old_iter_vars, + old_iter_sizes, + index_prevent_reordering(index_exprs, old_iter_vars, old_iter_sizes), + ) + + reduce_sizes, reduce_reindex, _ = V.graph.sizevars._simplify_loops( + old_reduce_vars, + old_reduce_sizes, + index_prevent_reordering(index_exprs, old_reduce_vars, old_reduce_sizes), + ) + + if iter_sizes == old_iter_sizes and reduce_sizes == old_reduce_sizes: + return old_body + + ( + ( + iter_vars, + reduce_vars, + ), + var_ranges, + ) = dependencies.index_vars_no_squeeze(iter_sizes, reduce_sizes, prefix="p") + new_body = LoopBody( + old_body, + [iter_reindex(iter_vars), reduce_reindex(reduce_vars)], + var_ranges, + iter_vars, + reduce_vars, + allow_same_symbol_in_index=True, + ) + + return new_body + + def expand_dimension_for_pointwise_node( + self, dimension: int, new_range: int + ) -> LoopBody: + """ + Expand node on `dimension` to `new_range` and rely on index modular to avoid + out-of-boundary access. 
+ """ + + old_body = self + old_sizes = self.sizes + + iter_size, reduce_size = old_sizes + original_range = iter_size[dimension] + new_iter_size = list(iter_size) + new_iter_size[dimension] = new_range + new_sizes = (new_iter_size, reduce_size) + + (iter_vars, reduce_vars), var_ranges = dependencies.index_vars_no_squeeze( + *new_sizes, + prefix="t", # type: ignore[arg-type] + ) + + def new_body(*indices: Sequence[sympy.Expr]) -> Any: + index = [*itertools.chain.from_iterable(indices)] + assert len(index) == len(iter_size) + len(reduce_size) + iter_idx = index[: len(iter_size)] + reduce_idx = index[len(iter_size) :] + + new_iter_idx = list(iter_idx) + new_iter_idx[dimension] = iter_idx[dimension] % original_range + + return old_body(new_iter_idx, reduce_idx) + + loop_body = LoopBody( + new_body, (iter_vars, reduce_vars), var_ranges, iter_vars, reduce_vars + ) + + # use the original symbol prefix so we can do multiple round of reordering + (iter_vars2, reduce_vars2), var_ranges2 = dependencies.index_vars_no_squeeze( + *new_sizes, + prefix="p", # type: ignore[arg-type] + ) + new_body = LoopBody( + loop_body, (iter_vars2, reduce_vars2), var_ranges2, iter_vars2, reduce_vars2 + ) + return new_body + + def reorder_iter_loops(self, new_order) -> LoopBody: + """ + Reorder iteration loops and return a new LoopBody. 
+ """ + from .ir import same_reorder + + old_body = self + old_sizes = self.sizes + assert len(old_sizes[0]) == len(new_order) + reorder_fn = same_reorder(new_order) + + iter_size, reduce_size = old_sizes + new_iter_size = reorder_fn(iter_size) + + new_sizes = (new_iter_size, reduce_size) + + (iter_vars, reduce_vars), var_ranges = dependencies.index_vars_no_squeeze( + *new_sizes, + prefix="p", # type: ignore[arg-type] + ) + + inverse_order = {b: a for a, b in enumerate(new_order)} + inverse_order = [inverse_order[i] for i in range(len(new_order))] + + def new_body(*indices: Sequence[sympy.Expr]) -> Any: + index = [*itertools.chain.from_iterable(indices)] + assert len(index) == len(iter_size) + len(reduce_size) + iter_idx = index[: len(iter_size)] + reduce_idx = index[len(iter_size) :] + iter_idx = [iter_idx[i] for i in inverse_order] + return old_body(iter_idx, reduce_idx, allow_same_symbol_in_index=True) + + return LoopBody( + new_body, + (iter_vars, reduce_vars), + var_ranges, + iter_vars, + reduce_vars, + ) + + @property + def vars(self): + assert self.iter_vars is not None + assert self.reduce_vars is not None + return self.iter_vars, self.reduce_vars + + @cache_on_self + def get_nodes(self): + all_graphs = itertools.chain( + (self.root_block.graph,), + (block.graph for block in self.subblocks.values()), + ) + return [node for graph in all_graphs for node in graph.nodes] + + @cache_on_self + def bounds(self): + # Doing a local import to avoid dumping all the code here + from .bounds import BoundVars + + return BoundVars(self) + + def get_read_expr(self, buffer_name): + # reversed to match old behavior + for entry in reversed(self.memory_usage[MemoryUsageType.LOAD]): + if entry.buffer_name == buffer_name: + return self.indexing_exprs[entry.index_name] + raise KeyError(buffer_name) + + def get_write_expr(self, buffer_name): + for entry in itertools.chain( + self.memory_usage[MemoryUsageType.STORE], + self.memory_usage[MemoryUsageType.STORE_REDUCTION], + ): + if 
entry.buffer_name == buffer_name: + return self.indexing_exprs[entry.index_name] + raise KeyError(buffer_name) + + def get_read_exprs(self): + return [ + self.indexing_exprs[entry.index_name] + for entry in self.memory_usage[MemoryUsageType.LOAD] + ] + + def get_all_read_expr(self, buffer_name): + # reversed to match old behavior + out = [] + for entry in reversed(self.memory_usage[MemoryUsageType.LOAD]): + if entry.buffer_name == buffer_name: + out.append(self.indexing_exprs[entry.index_name]) + return out + + def get_write_exprs(self): + return [ + self.indexing_exprs[entry.index_name] + for entry in itertools.chain( + self.memory_usage[MemoryUsageType.STORE], + self.memory_usage[MemoryUsageType.STORE_REDUCTION], + ) + ] + + def get_all_write_expr(self, buffer_name): + out = [] + for entry in itertools.chain( + self.memory_usage[MemoryUsageType.STORE], + self.memory_usage[MemoryUsageType.STORE_REDUCTION], + ): + if entry.buffer_name == buffer_name: + out.append(self.indexing_exprs[entry.index_name]) + return out + + def debug_str(self): + lines = [f"var_ranges = {dict(self.var_ranges)}"] + lines.extend([f"{name} = {val}" for name, val in self.indexing_exprs.items()]) + lines.extend( + [ + block.debug_str(name) + for name, block in itertools.chain( + [("body", self.root_block)], self.subblocks.items() + ) + ] + ) + return "\n".join(lines) + + def is_memory_copy(self) -> bool: + """ + True of this contains only a single loads and store. + Note, this could involve a layout change. 
+ """ + return ( + len(self.memory_usage[MemoryUsageType.LOAD]) == 1 + and len(self.memory_usage[MemoryUsageType.STORE]) == 1 + and len(self.submodules) == 1 # get_index + and self.root_block.contains_only_ops(("load", "store")) + ) + + __repr__ = debug_str + + def add_index_expr( + self, + expr: sympy.Expr, + mtype: MemoryUsageType, + buffer_name: Optional[str] = None, + mode: Optional[str] = None, + ): + name = self.indexing_exprs_name.get(expr) + if not name: + name = f"index{len(self.indexing_exprs)}" + self.indexing_exprs_name[expr] = name + self.indexing_exprs[name] = expr + self.memory_usage[mtype].append(MemoryEntry(name, buffer_name, mode)) + return name + + def add_submodule(self, block, prefix): + """Not actually for nn.Modules, but subblocks in generated code are mapped to FX call_module opcodes""" + if prefix[-1].isnumeric() and prefix not in self.submodules: + name = prefix + else: + name = f"{prefix}{len(self.submodules)}" + self.submodules[name] = block + return name + + def add_indirect(self, size): + var = sympy_index_symbol_with_prefix(SymT.INDIRECT, len(self.indirect_vars)) + assert var not in self.indirect_var_ranges + self.indirect_vars.append(var) + self.indirect_var_ranges[var] = size + return var + + def replace_indirect(self, old, new): + """Swap in a variable used in indirect indexing""" + if str(old) == str(new): + return + assert self.indexing is not None + # pyrefly: ignore [bad-assignment] + self.indexing = {k: sympy_subs(v, {old: new}) for k, v in self.indexing.items()} + + def get_index(self, name): + assert self.indexing is not None + return self.indexing[name] + + def indexing_from_args(self, indices, allow_same_symbol_in_index=False): + index = [*itertools.chain.from_iterable(indices)] + assert len(index) == len(self.var_ranges), (index, self.var_ranges) + assert allow_same_symbol_in_index or all( + v not in self.var_ranges for v in index + ), f"{self.var_ranges=}, {indices=}" + + replacements = dict(zip(self.var_ranges.keys(), 
index)) + return { + name: sympy_subs(expr, replacements) + for name, expr in self.indexing_exprs.items() + } + + def __call__(self, *indices, allow_same_symbol_in_index=False): + self.indexing = self.indexing_from_args(indices, allow_same_symbol_in_index) + result = self.root_block() + self.indexing = None + return result + + def bind_set_indirect_shim(self, var, size, check, wrap_neg): + def set_indirect(new_var): + self.replace_indirect( + var, V.ops.indirect_indexing(new_var, size, check, wrap_neg) + ) + + set_indirect.clone = functools.partial( # type: ignore[attr-defined] + LoopBody.bind_set_indirect_shim, + var=var, + size=size, + check=check, + wrap_neg=wrap_neg, + ) + return set_indirect + + def bind_scan_shim(self, combine_fn): + def shim(dtypes, values): + return V.ops.scan(dtypes, combine_fn, values) + + shim.clone = functools.partial(LoopBody.bind_scan_shim, combine_fn=combine_fn) # type: ignore[attr-defined] + return shim + + def bind_masked_shim(self, name): + def shim(mask, other): + return V.ops.masked(mask, self.subblocks[name], other) + + shim.clone = functools.partial(LoopBody.bind_masked_shim, name=name) # type: ignore[attr-defined] + return shim + + +class LoopBodyBlock: + """ + Captures the body of a Loops subclass into an FX graph. + In normal cases there will be a 1:1 mapping between LoopBody and + LoopBodyBlock, however in the case of ops.masked() the masked out + operations will manifest as an extra LoopBodyBlock. 
+ """ + + def __init__(self, body: LoopBody, fn: Callable[..., Any], args: list[Any]): + self.body = body + + tracer = LightTracer() + proxy_ops = tracer.create_proxy("placeholder", "ops", (), {}) + + from .index_propagation import IndexPropagation + + handler: Any = CountOps( + CaptureIndexing(proxy_ops, body, tracer), + body.op_counts, + ) + if config.constant_and_index_propagation: + handler = IndexPropagation( + handler, self.body.var_ranges, self.body.indirect_var_ranges + ) + + with V.set_ops_handler(handler): + # This indirection is just a cute way to get IndexPropagation to + # unwrap the return value. + ops.output(fn(*args)) + self.graph = tracer.graph + + def extract_pw_from_reduction(self): + red = None + store = None + for node in self.graph.nodes: + if node.target == "reduction": + assert not red + red = node + if node.target == "store_reduction": + assert not store + store = node + assert red + assert store + reduction_type = red.args[-2] + red_arg = red.args[-1] + buf = store.args[1] + ops = store.args[0] + + extra_meta = { + "num_reduction_dims": len(self.body.reduce_vars), + } + with self.graph.inserting_after(store): + self.graph.call_method( + "partial_accumulate", (ops, buf, reduction_type, red_arg, extra_meta) + ) + self.graph.erase_node(store) + self.graph.erase_node(red) + return self + + def __call__(self): + graph = self.graph + submodules = self.body.submodules + + return InterpreterShim(graph, submodules).run(V.get_ops_handler()) + + def debug_str(self, name="block"): + code = torch.fx.GraphModule(self.body.submodules, self.graph).code + return re.sub( + # strip `; del var0` suffixes to make output prettier + r";[^\n]*", + "", + code.strip().replace("def forward(", f"def {name}("), + ) + + def contains_only_ops(self, allowed_ops) -> bool: + return all( + node.target in allowed_ops + for node in self.graph.find_nodes(op="call_method") + ) + + def clone(self, body: LoopBody): + """Shallow copy with a new parent LoopBody""" + copy = 
LoopBodyBlock.__new__(LoopBodyBlock) + copy.__dict__.update({**self.__dict__, "body": body}) + return copy + + +class CountOps(DefaultHandler): + def __init__(self, inner: OpsHandler[Any], counts: collections.Counter[str]): + self._inner = inner + self._counts = counts + + def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: + self._counts[name] += 1 + return getattr(self._inner, name)(*args, **kwargs) + + +class CaptureIndexing(WrapperHandler): + name = "CaptureIndexing" + + def __init__( + self, + inner: OpsHandler[Any], + body: LoopBody, + tracer: LightTracer, + ): + super().__init__(inner) + self.body = body + self.tracer = tracer + + def _add_index(self, expr: sympy.Expr, mtype: MemoryUsageType, **kwargs: Any): + return self.tracer.create_proxy( + "call_module", + "get_index", + (self.body.add_index_expr(expr, mtype, **kwargs),), + {}, + ) + + def _simplify(self, expr: sympy.Expr) -> sympy.Expr: + return V.graph.sizevars.simplify_with_ranges(expr, self.body.var_ranges) + + def load(self, name: str, index: sympy.Expr): + index = self._simplify(index) + index = self._add_index(index, MemoryUsageType.LOAD, buffer_name=name) + return self._inner.load(name, index) + + def load_seed(self, name: str, index: int): + assert isinstance(index, int) + self.body.add_index_expr( + sympy.Integer(index), MemoryUsageType.LOAD_SEED, buffer_name=name + ) + return self._inner.load_seed(name, index) + + def store(self, name, index, value, mode=None): + index = self._simplify(index) + index = self._add_index( + index, MemoryUsageType.STORE, buffer_name=name, mode=mode + ) + return self._inner.store(name, index, value, mode) + + def store_reduction(self, name, index, value): + index = self._simplify(index) + index = self._add_index( + index, MemoryUsageType.STORE_REDUCTION, buffer_name=name + ) + return self._inner.store_reduction(name, index, value) + + def reduction(self, dtype, src_dtype, reduction_type, value): + result = 
self._inner.reduction(dtype, src_dtype, reduction_type, value) + num_outputs = reduction_num_outputs(reduction_type) + if num_outputs > 1: + return tuple(result[i] for i in range(num_outputs)) + return result + + def index_expr(self, index, dtype): + index = self._simplify(index) + if isinstance(index, (int, sympy.Integer)): + return self._inner.constant(int(index), dtype) + index = self._add_index(index, MemoryUsageType.INDEX_EXPR) + return self._inner.index_expr(index, dtype) + + def check_bounds(self, index, size, lower, upper): + index = self._simplify(index) + index = self._add_index(index, MemoryUsageType.CHECK_BOUNDS) + size = self._add_index(size, MemoryUsageType.CHECK_BOUNDS) + return self._inner.check_bounds(index, size, lower, upper) + + def bucketize( + self, + values: T, + boundaries: tuple[str, sympy.Expr, sympy.Expr, sympy.Expr], + boundary_indices: T, + indexing_dtype: torch.dtype, + right: bool, + sorter: Optional[tuple[str, sympy.Expr]] = None, + sorter_indices: Optional[T] = None, + ) -> T: + """ + See [Note: Inductor bucketize op] + """ + boundaries = ( + boundaries[0], + self._add_index( + boundaries[1], + MemoryUsageType.BUCKETIZE, + buffer_name=boundaries[0], + ), + self._add_index( + boundaries[2], + MemoryUsageType.BUCKETIZE, + buffer_name=boundaries[0], + ), + self._add_index( + boundaries[3], + MemoryUsageType.BUCKETIZE, + buffer_name=boundaries[0], + ), + ) + if sorter is not None: + sorter = ( + sorter[0], + self._add_index( + sorter[1], MemoryUsageType.BUCKETIZE, buffer_name=sorter[0] + ), + ) + + return self._inner.bucketize( + values, + boundaries, + boundary_indices, + indexing_dtype, + right, + sorter, + sorter_indices, + ) + + def masked(self, mask_proxy, masked_body: Callable[..., Any], other_proxy): + """ + Recursively capture the masked out body in another LoopBodyBlock + """ + name = self.body.add_submodule(None, "masked_subblock") + self.body.submodules[name] = self.body.bind_masked_shim(name) + self.body.subblocks[name] = 
LoopBodyBlock(self.body, masked_body, []) + return self.tracer.create_proxy( + "call_module", name, (mask_proxy, other_proxy), {} + ) + + def scan( + self, + dtype_proxy, + combine_fn: Callable[[tuple[Any, ...], tuple[Any, ...]], tuple[Any, ...]], + value_proxy, + ): + shim = self.body.bind_scan_shim(combine_fn) + name = self.body.add_submodule(shim, "scan") + result = self.tracer.create_proxy( + "call_module", + name, + (dtype_proxy, value_proxy), + {}, + ) + # Proxies are iterable, but some methods expect tuples/lists + return tuple(result[i] for i in range(len(value_proxy))) + + def sort(self, dtypes, values, stable, descending): + result = self._inner.sort(dtypes, values, stable, descending) + # Proxies are iterable, but some methods expect tuples/lists + return tuple(result[i] for i in range(len(values))) + + def frexp(self, value_proxy): + result = self._inner.frexp(value_proxy) + # Proxies are iterable, but some methods expect tuples/lists + return (result[0], result[1]) + + def indirect_indexing(self, index_proxy, size, check=True, wrap_neg=True): + """ + Flow data from tensors into indexing formulas. + Introduce a call_module to update the indexing. 
+ """ + + var = self.body.add_indirect(size) + set_indirect = self.body.bind_set_indirect_shim(var, size, check, wrap_neg) + self.tracer.create_proxy( + "call_module", + self.body.add_submodule(set_indirect, f"set_{var}"), + (index_proxy,), + {}, + ) + return var + + def output(self, *result): + self.tracer.create_proxy("output", "output", result, {}) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/lowering.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/lowering.py new file mode 100644 index 0000000000000000000000000000000000000000..8d5c8ce444acaa48622a9ad99f4d1ae4ff1bf618 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/lowering.py @@ -0,0 +1,7683 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import contextlib +import dataclasses +import functools +import itertools +import logging +import math +import operator +import os +import textwrap +import warnings +from collections import defaultdict +from collections.abc import Callable, Collection, Iterable, Sequence +from typing import Any, cast, Optional, TYPE_CHECKING, TypeGuard, TypeVar, Union +from typing_extensions import ParamSpec +from unittest.mock import patch + +import sympy + +import torch +import torch.ao.quantization.fx._decomposed +import torch.fx +import torch.utils._pytree as pytree +from torch._dynamo.utils import counters +from torch._higher_order_ops.associative_scan import associative_scan_op +from torch._higher_order_ops.triton_kernel_wrap import triton_kernel_wrapper_mutation +from torch._library.fake_class_registry import FakeScriptObject +from torch._library.utils import get_layout_constraint_tag +from torch._prims_common import ( # pyrefly: ignore # deprecated; pyrefly: ignore [deprecated] + canonicalize_dim, + canonicalize_dims, + check, + dtype_to_type, + elementwise_dtypes, + ELEMENTWISE_TYPE_PROMOTION_KIND, + get_computation_dtype, + is_boolean_dtype, 
+ is_float_dtype, + is_integer_dtype, + Number, +) +from torch.fx.experimental.sym_node import magic_methods, method_to_operator +from torch.fx.experimental.symbolic_shapes import ( + free_unbacked_symbols, + has_free_unbacked_symbols, + resolve_unbacked_bindings, +) +from torch.utils._ordered_set import OrderedSet +from torch.utils._sympy.functions import ( + CeilDiv, + FloorDiv, + Identity, + Mod, + ModularIndexing, +) + +from .._dynamo.utils import import_submodule +from . import config, inductor_prims, ir, test_operators # NOQA: F401 +from .decomposition import decompositions, get_decompositions +from .ir import ( + BaseView, + DtypeView, + ExpandView, + IndexingConstant, + IRNode, + is_triton, + MutableBox, + OnlineSoftmaxReduction, + ops_wrapper, + PermuteView, + Pointwise, + Reduction, + ShapeAsConstantBuffer, + SqueezeView, + TensorBox, + validate_ir, + View, +) +from .utils import ( + ceildiv, + decode_device, + is_dynamic, + is_gpu, + is_pointwise_use, + is_view, + needs_fallback_due_to_atomic_add_limitations, + pad_listlike, + register_op_dtype_propagation_rules, + register_op_requires_libdevice_fp64, + sympy_product, + use_scatter_fallback, +) +from .virtualized import ops, V + + +if TYPE_CHECKING: + from .ops_handler import ReductionType + + +_T = TypeVar("_T") +_P = ParamSpec("_P") + +# TODO(jansel): we should implement decomps or lowerings for these +# https://github.com/pytorch/torchdynamo/issues/327 +FALLBACK_ALLOW_LIST = OrderedSet( + [ + "torchvision::roi_align", + "aten::index_add", + ] +) + +log = logging.getLogger(__name__) +lowerings: dict[Union[Callable[..., Any], str], Callable[..., Any]] = {} +# Use maybe_layout_constraints to access this dict, we lazily register tag-based layout constraints +_maybe_layout_constraints: dict[ + torch._ops.OpOverload, Optional[Callable[..., Any]] +] = {} +fallbacks = OrderedSet[torch._ops.OpOverload]() +aten = torch.ops.aten +tr_c10d = torch.ops.tr_c10d +prims = torch.ops.prims +needs_realized_inputs = 
OrderedSet[torch._ops.OpOverload]() +foreach_ops = OrderedSet[torch._ops.OpOverload]( + [torch._higher_order_ops._foreach_map] # type: ignore[list-item] +) +# TODO(rec): torch._higher_order_ops._foreach_map is not an OpOverload +# so why is it in foreach_ops? +inplace_foreach_ops = OrderedSet[torch._ops.OpOverload]() +inplaceable_foreach_ops: dict[torch._ops.OpOverload, torch._ops.OpOverload] = {} +quantized_decomposed = torch.ops.quantized_decomposed + + +def cur_node_has_non_foreach_users() -> bool: + for node in V.graph.current_node.users: + for user in node.users: + if not (user.op == "call_function" and (user.target in foreach_ops)): + return True + + return False + + +# group by device, whether any of the inputs are dynamic +# note arg_pairs may or may not be a pair +# foreach_map for example just passes output buffers here +def group_foreach_args( + arg_pairs: Iterable[Any], +) -> defaultdict[tuple[Any, bool], list[tuple[int, Any]]]: + out = defaultdict(list) + unpack_args = False + for i, args in enumerate(arg_pairs): + if not isinstance(args, Iterable): + unpack_args = True + args = (args,) + use_foreach = ( + not is_dynamic(*args) or config.combo_kernel_foreach_dynamic_shapes + ) + device = None + for t in args: + if isinstance(t, TensorBox): + device = t.data.get_device() + break + assert device is not None, "foreach op should have at least one tensor arg" + if unpack_args: + # pyrefly: ignore [bad-unpacking] + (args,) = args + out[(device, use_foreach)].append((i, args)) + return out + + +def maybe_layout_constraints(fn: Callable[..., Any]) -> Optional[Callable[..., Any]]: + """Get layout constraints. Returns None if there are no layout constraints.""" + if not isinstance(fn, torch._ops.OpOverload): + # Only OpOverloads have layout constraints. 
+ return None + + if maybe_layout_tag := get_layout_constraint_tag(fn, with_default=False): + return tag_to_layout_constraint(maybe_layout_tag) + + if fn in _maybe_layout_constraints: + return _maybe_layout_constraints[fn] + return None + + +def tag_to_layout_constraint( + tag: torch._C.Tag, +) -> Optional[Callable[..., tuple[Any, Any]]]: + if tag == torch._C.Tag.needs_exact_strides: + return constrain_to_fake_tensors + if tag == torch._C.Tag.needs_contiguous_strides: # type: ignore[attr-defined] + return require_contiguous_strides + if tag == torch._C.Tag.needs_fixed_stride_order: + return constrain_to_fx_strides + if tag == torch._C.Tag.flexible_layout: + return None + raise AssertionError(f"Unknown layout constraint tag: {tag}") + + +def assert_nyi(cond: bool, msg: str) -> None: + if not cond: + raise NotImplementedError(f"inductor does not support {msg}") + + +def add_needs_realized_inputs( + fn: Union[ + Collection[Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket]], + torch._ops.OpOverload, + torch._ops.OpOverloadPacket, + ], +) -> Optional[list[Any]]: + if isinstance(fn, (list, set, tuple, OrderedSet)): # noqa: set_linter + return [add_needs_realized_inputs(x) for x in fn] + if isinstance(fn, torch._ops.OpOverload): + needs_realized_inputs.add(fn) + elif isinstance(fn, torch._ops.OpOverloadPacket): + needs_realized_inputs.update( + getattr(fn, overload) for overload in fn.overloads() + ) + return None + + +def add_layout_constraint( + fn: Union[torch._ops.OpOverloadPacket, torch._ops.OpOverload], + constraint: Callable[..., tuple[Any, Any]], +) -> None: + if isinstance(fn, torch._ops.OpOverloadPacket): + for overload in fn.overloads(): + _maybe_layout_constraints[getattr(fn, overload)] = constraint + else: + _maybe_layout_constraints[fn] = constraint + + +add_needs_realized_inputs( + [ + aten.as_strided, + aten.as_strided_copy, + aten.avg_pool2d, + aten.avg_pool2d_backward, + aten.bmm, + aten.convolution, + aten.convolution_backward, + 
aten.max_pool2d_with_indices, + aten.max_pool3d_with_indices, + aten.max_pool2d_with_indices_backward, + aten.mm, + aten.upsample_nearest2d, + aten._upsample_nearest_exact2d, + aten._int_mm, + ] +) + +# TODO(jansel): ezyang says we won't need this in the future, try removing it +# based on https://github.com/pytorch/pytorch/blob/9e3eb329df8f701/c10/core/ScalarType.h#L28 +DTYPE_ID_LOOKUP = { + 0: torch.uint8, + 1: torch.int8, + 2: torch.int16, + 3: torch.int32, + 4: torch.int64, + 5: torch.float16, + 6: torch.float32, + 7: torch.float64, + 8: torch.complex32, + 9: torch.complex64, + 10: torch.complex32, + 11: torch.bool, + 15: torch.bfloat16, + # TODO(jansel): add quantized types? + # _(c10::qint8, QInt8) /* 12 */ + # _(c10::quint8, QUInt8) /* 13 */ + # _(c10::qint32, QInt32) /* 14 */ + # _(c10::quint4x2, QUInt4x2) /* 16 */ + # _(c10::quint2x4, QUInt2x4) /* 17 */ +} + + +def decode_dtype(dtype: Union[int, torch.dtype]) -> torch.dtype: + if not isinstance(dtype, int): + return dtype + assert dtype in DTYPE_ID_LOOKUP, f"id {dtype} missing from DTYPE_ID_LOOKUP" + # pyrefly: ignore [bad-assignment] + dtype = DTYPE_ID_LOOKUP[dtype] + return dtype + + +def is_integer_type(x: Any) -> TypeGuard[Union[TensorBox, sympy.Expr, int]]: + if isinstance(x, TensorBox): + return is_integer_dtype(x.get_dtype()) or is_boolean_dtype(x.get_dtype()) + elif isinstance(x, sympy.Expr): + return x.is_integer is True # type: ignore[attr-defined] + else: + return isinstance(x, int) + + +def is_boolean_type(x: Any) -> TypeGuard[Union[TensorBox, bool]]: + if isinstance(x, TensorBox): + return is_boolean_dtype(x.get_dtype()) + else: + return isinstance(x, bool) + + +def get_promoted_dtype( + *args: Any, type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND +) -> torch.dtype: + def construct_input(inp: Any) -> Any: + if isinstance(inp, (Number, sympy.Basic)): + return inp + else: + dim = len(inp.get_size()) + # construct a tmp tensor to feed into torch.result_type + return torch.zeros([1] * dim, 
dtype=inp.get_dtype()) + + inps = [construct_input(arg) for arg in args] + _, dtype = elementwise_dtypes(*inps, type_promotion_kind=type_promotion_kind) + return dtype + + +def get_overloads(aten_fn): + if not isinstance(aten_fn, (list, tuple)): + aten_fn = [aten_fn] + else: + aten_fn = list(aten_fn) + + for fn in list(aten_fn): + if isinstance(fn, torch._ops.OpOverloadPacket): + for overload in fn.overloads(): + other_fn = getattr(fn, overload) + if other_fn not in lowerings: + aten_fn.append(other_fn) + + return aten_fn + + +def in_namespace( + op: Union[Any, torch._ops.OpOverloadPacket, torch._ops.OpOverload], namespace: str +) -> bool: + if isinstance(op, torch._ops.OpOverloadPacket): + return namespace in op._qualified_op_name + elif isinstance(op, torch._ops.OpOverload): + return namespace in op.name() + return False + + +def maybe_copy_cpu_scalar(x: TensorBox, device: torch.device) -> TensorBox: + """ + Copy cpu scalar if doesn't not match with given `device` + """ + if not isinstance(x.data, ir.ReinterpretView) or has_free_unbacked_symbols( + x.get_size() + ): + return x + size = [V.graph.sizevars.size_hint_or_throw(s) for s in x.get_size()] + cur_device = x.get_device() + if ( + cur_device is not None + and cur_device.type == "cpu" + and cur_device != device + and (len(size) == 0 or (len(size) == 1 and size[0] == 1)) + ): + return TensorBox(ir.StorageBox(ir.DeviceCopy.create(x, cur_device, False))) + return x + + +def transform_args( + args: list[Any], + kwargs: dict[str, Any], + broadcast: bool, + type_promotion_kind: Optional[ELEMENTWISE_TYPE_PROMOTION_KIND], + convert_input_to_bool: bool, +) -> tuple[list[Any], dict[str, Any]]: + """ + Transforms arguments for broadcasting and type promotion + """ + + args_indices = [i for i, x in enumerate(args) if isinstance(x, TensorBox)] + kwargs_indices = [k for k, v in kwargs.items() if isinstance(v, TensorBox)] + # check that there's something to transform + if not args_indices and not kwargs_indices: + return 
args, kwargs + + if type_promotion_kind or convert_input_to_bool: + if convert_input_to_bool: + dtype = torch.bool + else: + # FIXME this is a crude approximation for promoting args + promoting_args = [ + a + for a in args + if isinstance(a, (Number, sympy.Basic)) or hasattr(a, "dtype") + ] + # only consider tensor kwargs for promotion, for now + promoting_args.extend(a for a in kwargs.values() if hasattr(a, "dtype")) + dtype = get_promoted_dtype( + *promoting_args, + type_promotion_kind=type_promotion_kind, # type: ignore[arg-type] + ) + + device = ( + args[args_indices[0]] if args_indices else kwargs[kwargs_indices[0]] + ).get_device() + + for i in args_indices: + args[i] = maybe_copy_cpu_scalar(args[i], device) + + for k in kwargs_indices: + kwargs[k] = maybe_copy_cpu_scalar(kwargs[k], device) + + # sometimes args are an immutable list so we can't mutate them + def promote(arg: Any) -> Any: + if isinstance(arg, TensorBox): + return to_dtype(arg, dtype) + elif isinstance(arg, ir.Constant): + return ir.Constant(value=arg.value, dtype=dtype, device=device) + else: + return arg + + args = [promote(a) for a in args] + kwargs = {k: promote(v) for k, v in kwargs.items()} + + if broadcast: + broadcasted = broadcast_tensors( + *list( + itertools.chain( + (args[i] for i in args_indices), + (kwargs[k] for k in kwargs_indices), + ) + ) + ) + size = list(broadcasted[0].get_size()) + + for i, x in zip(args_indices, broadcasted[: len(args_indices)]): + args[i] = x + for k, x in zip(kwargs_indices, broadcasted[len(args_indices) :]): + kwargs[k] = x + + for i in range(len(args)): + if isinstance(args[i], ir.Constant): + args[i] = ExpandView.create(args[i], size) + for k in kwargs: + if isinstance(kwargs[k], ir.Constant): + kwargs[k] = ExpandView.create(kwargs[k], size) + + return args, kwargs + + +def _register_foreach_lowering( + aten_fn: torch._ops.OpOverload, decomp_fn: Callable[..., Any] +) -> Callable[..., Any]: + """ + Add a foreach lowering to lowerings dict. 
+ + Arguments: + aten_fn: torch.ops.aten.* fn we are lowering + decomp_fn: alternate implementation on our IR + broadcast: True to apply broadcasting to tensor inputs + type_promotion_kind: kind of type promotion applied to tensor inputs, `None` means no type promotion + convert_input_to_bool: some logical ops require inputs are converted to bool + """ + + @functools.wraps(decomp_fn) + def wrapped(*args: Any, **kwargs: Any) -> Any: + assert len(args) <= 2 + out = decomp_fn(*args, **kwargs) + validate_ir(out) + return out + + aten_fns = get_overloads(aten_fn) + foreach_ops.update(aten_fns) + lowerings.update(dict.fromkeys(aten_fns, wrapped)) + return wrapped + + +def _register_lowering( + aten_fn, + decomp_fn: Callable[..., Any], + broadcast: bool, + type_promotion_kind: Optional[ELEMENTWISE_TYPE_PROMOTION_KIND], + convert_input_to_bool: bool, + lowering_dict: dict[Union[Callable[..., Any], str], Callable[..., Any]], +): + """ + Add a lowering to lowerings dict + + Arguments: + aten_fn: torch.ops.aten.* fn we are lowering + decomp_fn: alternate implementation on our IR + broadcast: True to apply broadcasting to tensor inputs + type_promotion_kind: kind of type promotion applied to tensor inputs, `None` means no type promotion + convert_input_to_bool: some logical ops require inputs are converted to bool + """ + + @functools.wraps(decomp_fn) + def wrapped(*args, **kwargs): + args: list[Any] = list(args) + kwargs: dict[str, Any] = dict(kwargs) + unpacked = False + # TODO maybe we need to use pytrees here + if len(args) == 1 and isinstance(args[0], (list, tuple)): + unpacked = True + args = list(args[0]) + + if not all( + (fn in fallbacks or in_namespace(fn, "_c10d_functional")) for fn in aten_fn + ): + # explicitly assert for "out=" ops for better error messages + assert not any(x == "out" for x in kwargs), "out= ops aren't yet supported" + + args, kwargs = transform_args( + args, kwargs, broadcast, type_promotion_kind, convert_input_to_bool + ) + + if unpacked: + 
args = [args] + + out = decomp_fn(*args, **kwargs) + validate_ir(out) + + return out + + aten_fn = get_overloads(aten_fn) + + lowering_dict.update(dict.fromkeys(aten_fn, wrapped)) + return wrapped + + +def register_lowering( + aten_fn, + broadcast=False, + type_promotion_kind: Optional[ + ELEMENTWISE_TYPE_PROMOTION_KIND + ] = ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + convert_input_to_bool=False, + lowering_dict=lowerings, +) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: + """ + Shim to support decorator syntax. + """ + return functools.partial( + _register_lowering, + aten_fn, + broadcast=broadcast, + type_promotion_kind=type_promotion_kind, + convert_input_to_bool=convert_input_to_bool, + lowering_dict=lowering_dict, + ) + + +def broadcast_symbolic_shapes(a, b): + """ + Broadcasting logic based on symbolic shapes. + + We give the shapes 0 and 1 concrete values, while all other shapes + are symbolic sympy formulas. + """ + output = [] + for x, y in itertools.zip_longest(reversed(a), reversed(b), fillvalue=sympy.S.One): + if V.graph.sizevars.is_size_one_or_false(y): + output.append(x) + elif V.graph.sizevars.is_size_one_or_false(x): + output.append(y) + else: + V.graph.sizevars.check_equals(x, y) + if len(sympy.expand(y).free_symbols) < len(sympy.expand(x).free_symbols): + output.append(y) # prefer shorter formula + else: + output.append(x) + return tuple(reversed(output)) + + +def promote_constants(inputs, override_return_dtype=None, type_promotion_kind=None): + assert override_return_dtype is None or type_promotion_kind is None, ( + "only one of override_return_dtype or type_promotion_kind may be given" + ) + + if override_return_dtype is None and type_promotion_kind is None: + type_promotion_kind = ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + + if not any(isinstance(x, (sympy.Basic, int, float)) for x in inputs): + return inputs + if all(isinstance(x, (int, float, sympy.Basic)) for x in inputs): + dtype = override_return_dtype or get_promoted_dtype( + *inputs, 
+ # pyrefly: ignore [bad-argument-type] + type_promotion_kind=type_promotion_kind, + ) + + def const_func(x): + if isinstance(x, sympy.Basic): + return ir.IndexingConstant( + index=x, dtype=dtype, device=decode_device(None) + ) + else: + return ir.Constant(value=x, dtype=dtype, device=decode_device(None)) + + return [const_func(x) for x in inputs] + ex = next(x for x in inputs if isinstance(x, (TensorBox, ExpandView, ir.Constant))) + out = [] + for x in inputs: + if isinstance(x, (int, float)): + out.append( + ExpandView.create( + ir.Constant( + value=x, dtype=ex.get_dtype(), device=ex.get_device_or_error() + ), + list(ex.get_size()), + ) + ) + elif isinstance(x, sympy.Basic): + out.append( + ExpandView.create( + IndexingConstant( + index=x, dtype=ex.get_dtype(), device=ex.get_device_or_error() + ), + list(ex.get_size()), + ) + ) + else: + out.append(x) + + return out + + +def make_pointwise( + fn, + override_return_dtype=None, + override_device=None, + override_fn_when_input_bool=None, + allow_alpha=False, + triton_fallback=None, +): + def inner(*inputs: TensorBox, alpha=None): + if triton_fallback is not None and any( + isinstance(inp, IRNode) and is_triton(inp) for inp in inputs + ): + assert not allow_alpha # not implemented + return triton_fallback(*inputs) + + inputs = promote_constants(inputs, override_return_dtype) + if allow_alpha: + if alpha is not None and alpha != 1: + # pyrefly: ignore [bad-assignment] + inputs = list(inputs) + # pyrefly: ignore [unsupported-operation] + inputs[-1] = mul(inputs[-1], alpha) + else: + assert alpha is None + loaders = [x.make_loader() for x in inputs] + ranges = inputs[0].get_size() + dtype = override_return_dtype or inputs[0].get_dtype() + + for other in inputs[1:]: + assert isinstance(other, ir.BaseConstant) or len(ranges) == len( + other.get_size() + ), f"ndim mismatch {fn} {ranges} {other.get_size()}" + + # in tracing, we will annotate pointwise nodes that correspond to the output of + # a pointwise node that would 
have been run in eager. intermediary pointwise nodes + # during decompositions are not annotated. + low_pr_fp = (torch.bfloat16, torch.float16) + emulate_precision_casts = ( + V.graph is not None + and getattr(V.graph, "current_node", None) is not None + and V.graph.current_node.meta is not None + and V.graph.current_node.meta.get("low_precision_pointwise_barrier", False) + ) + emulate_output_cast = emulate_precision_casts and dtype in low_pr_fp + + def inner_fn(index): + assert len(index) == len(ranges), f"wrong ndim {index} {ranges}" + if dtype == torch.bool and override_fn_when_input_bool is not None: + return override_fn_when_input_bool(*[load(index) for load in loaders]) + else: + inputs_loaded = [] + for inp_index, load in enumerate(loaders): + out = load(index) + inp_dtype = inputs[inp_index].get_dtype() + if emulate_precision_casts and inp_dtype in low_pr_fp: + downcast = ops.to_dtype(out, inp_dtype, use_compute_types=False) + out = ops.to_dtype(downcast, inp_dtype) + inputs_loaded.append(out) + + out = fn(*inputs_loaded) + if emulate_output_cast: + # fp16/bf16 kernels are computed in fp32. Casting down to fp16/bf16 here, + # then upcasting again, to emulate casts that eager would do. 
+ downcast = ops.to_dtype(out, dtype, use_compute_types=False) + return ops.to_dtype(downcast, dtype) + return out + + if not override_device: + device = None + for i in inputs: + # pyrefly: ignore [missing-attribute] + if is_gpu(i.get_device().type): + device = i.get_device() + break + if not device: + device = inputs[0].get_device() + + # pyrefly: ignore [unbound-name] + device = override_device or device + + return Pointwise.create( + device=device, # type: ignore[arg-type] + dtype=dtype, + inner_fn=inner_fn, + ranges=ranges, + ) + + return inner + + +def make_foreach_pointwise(pw_fn, allow_alpha=False): + def inner(*inputs: list[list[TensorBox]], alpha=1): + realize_outputs = ( + len(V.graph.current_node.users) == 0 + or V.graph.current_node.target in inplace_foreach_ops + or cur_node_has_non_foreach_users() + ) + + a_list_input = None + for input in inputs: + if isinstance(input, (list, tuple)): + a_list_input = input + break + assert a_list_input is not None, ( + "at least one input must be a list to a foreach op" + ) + + # broadcast scalar inputs to match length of list inputs + broadcast_inputs = [] + for input in inputs: + if not isinstance(input, (list, tuple)): + broadcast_inputs.append([input] * len(a_list_input)) + else: + broadcast_inputs.append(input) + + groups = group_foreach_args(zip(*broadcast_inputs)) + + outputs = [None] * len(a_list_input) + for (device, use_foreach), group in groups.items(): + operation_list: list[str] = [] + for ( + output_ind, + args, + ) in group: + if allow_alpha: + output = pw_fn(*args, alpha=alpha) + else: + output = pw_fn(*args) + + outputs[output_ind] = output + + if ( + # pyrefly: ignore [unbound-name] + V.graph.has_feature(device, BackendFeature.FOREACH) + and use_foreach + and realize_outputs + ): + output.realize() + operation_list.append(output.get_operation_name()) + + if operation_list: + # pyrefly: ignore [unbound-name] + V.graph.register_operation_list(operation_list) + + assert all(x is not None for x in 
outputs) + return outputs + + return inner + + +def to_dtype( + x: Union[TensorBox, ShapeAsConstantBuffer], dtype: torch.dtype, copy: bool = False +): + src_dtype = x.get_dtype() + if src_dtype == dtype: + return clone(x) if copy else x + + def _to_dtype(x): + return ops.to_dtype(x, dtype, src_dtype=src_dtype) + + return make_pointwise(_to_dtype, override_return_dtype=dtype)(x) + + +@register_lowering(torch._higher_order_ops._foreach_map, type_promotion_kind=None) +def _foreach_map(subgraph, *args, **kwargs): + """ + This lowers an invocation of foreach_map + The way this works is that an arbitrary N-arg func is provided by the user, looped over by the + polyfill with the same semantics as a foreach op (a loop applying an n-ary function to n args) + and then traced into a subgraph by dynamo. + This code allows us to inline the subgraph into the main graph lowering using the PontwiseSubgraphLowering. + The graph outputs represent the vertically fused sequence of ops, and then register_operation_list + below registers the buffers as horizontally fuseable in the scheduler. 
+ """ + from .subgraph_lowering import PointwiseSubgraphLowering + + inputs = args + + gm = subgraph.graph_module + pw_subgraph = PointwiseSubgraphLowering(gm, root_graph_lowering=V.graph) + with V.set_graph_handler(pw_subgraph): # type: ignore[arg-type] + pw_subgraph.run(*inputs) + + sub_outputs = pw_subgraph.graph_outputs + # group outputs by device and register as foreach + assert sub_outputs # mypy lol + groups = group_foreach_args(sub_outputs) + + outputs = [None] * len(sub_outputs) + for (device, use_foreach), group in groups.items(): + operation_list: list[str] = [] + for ( + output_ind, + output, + ) in group: + outputs[output_ind] = output + + if V.graph.has_feature(device, BackendFeature.FOREACH) and use_foreach: + output.realize() + operation_list.append(output.get_operation_name()) + + if operation_list: + V.graph.register_operation_list(operation_list) + + assert all(x is not None for x in outputs) + return outputs + + +@register_lowering(prims.convert_element_type, type_promotion_kind=None) +def _convert_element_type(x: TensorBox, dtype: torch.dtype): + if dtype.is_complex or x.get_dtype().is_complex: + if x.get_size(): + # Decompose since aa aten fallback is more friendly for c++ codegen. + # This decomposition doesn't work for empty tensor, which needs more investigation. 
+ dst = empty_like(x, dtype=dtype) + ir.InplaceCopyFallback.create(dst, x) + return dst + else: + return fallback_handler( + prims.convert_element_type.default, add_to_fallback_set=False + )(x, dtype) + return to_dtype(x, dtype, copy=True) + + +def to_dtype_bitcast(x: TensorBox, dtype: torch.dtype, *, copy=False): + x_dtype = x.get_dtype() + if x_dtype == dtype: + return clone(x) if copy else x + + def _get_primitive_bitwidth(dtype): + if dtype.is_floating_point: + return torch.finfo(dtype).bits + else: + return torch.iinfo(dtype).bits + + src_bits = _get_primitive_bitwidth(x_dtype) + dst_bits = _get_primitive_bitwidth(dtype) + if src_bits != dst_bits: + # fallback to aten eager implementation for differing bitwidths + return fallback_handler(aten.view.dtype)(x, dtype) + else: + return TensorBox(DtypeView.create(x, dtype)) + + +@register_lowering(aten.view.dtype, type_promotion_kind=None) +def _view_dtype(x: TensorBox, dtype: torch.dtype): + if dtype.is_complex or x.get_dtype().is_complex: + return TensorBox.create( + ir.ComplexView.create(torch.ops.aten.view.dtype, x, dtype) + ) + return to_dtype_bitcast(x, dtype) + + +def to_device(x: TensorBox, device: torch.device, *, copy=False, non_blocking=False): + device = decode_device(device) + if x.get_device() == device: + return clone(x) if copy else x + return TensorBox.create(ir.DeviceCopy.create(x, device, non_blocking)) + + +@register_lowering(prims.device_put, type_promotion_kind=None) +def _device_put(x: TensorBox, device: torch.device, non_blocking=False): + return to_device(x, device, copy=True, non_blocking=non_blocking) + + +def register_pointwise( + aten_fn, + name=None, + broadcast=True, + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + convert_input_to_bool=False, + override_return_dtype=None, + override_fn_when_input_bool=None, + allow_alpha=False, + triton_fallback=None, +): + """A pointwise function that maps ops.{name} to inputs""" + name = name or aten_fn.__name__ + fn = 
ops_wrapper(name) + + register_op_dtype_propagation_rules( + name, type_promotion_kind, override_return_dtype + ) + + if override_fn_when_input_bool is not None: + override_fn_when_input_bool = ops_wrapper(override_fn_when_input_bool) + + fn = make_pointwise( + fn, + override_return_dtype=override_return_dtype, + override_fn_when_input_bool=override_fn_when_input_bool, + allow_alpha=allow_alpha, + triton_fallback=triton_fallback, + ) + fn = register_lowering( + aten_fn, + broadcast=broadcast, + type_promotion_kind=type_promotion_kind, + convert_input_to_bool=convert_input_to_bool, + )(fn) + + if hasattr(prims, name): + register_lowering( + getattr(prims, name), + type_promotion_kind=None, + convert_input_to_bool=convert_input_to_bool, + )(fn) + return fn + + +def register_frexp(): + """A pointwise function that maps ops.frexp to inputs""" + name = "frexp" + frexp = ops_wrapper("frexp") + + def frexp0(*args, **kwargs): + return frexp(*args, **kwargs)[0] # type: ignore[index] + + def frexp1(*args, **kwargs): + return frexp(*args, **kwargs)[1] # type: ignore[index] + + pw_fns = [ + make_pointwise(frexp0), + make_pointwise(frexp1, override_return_dtype=torch.int32), + ] + + def fn(*args, **kwargs): + return pw_fns[0](*args, **kwargs), pw_fns[1](*args, **kwargs) + + fn = register_lowering( + aten.frexp, + )(fn) + + if hasattr(prims, name): + register_lowering( + getattr(prims, name), + type_promotion_kind=None, + )(fn) + return fn + + +register_frexp() + + +def register_foreach_pointwise( + aten_fn, + pointwise_lowering_fn, + allow_alpha=False, +): + fn = make_foreach_pointwise(pointwise_lowering_fn, allow_alpha=allow_alpha) + fn = _register_foreach_lowering(aten_fn, fn) + return fn + + +@register_lowering(aten.where, broadcast=False, type_promotion_kind=None) +def where(cond, a, b): + def fn(*args): + return ops.where(*args) + + if isinstance(a, (float, int)): + a = constant_like(a)(b) + if isinstance(b, (float, int)): + b = constant_like(b)(a) + + args = [cond, a, b] 
+ dtype = get_promoted_dtype( + args[1], args[2], type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ) + indices = [i for i, x in enumerate(args) if isinstance(x, TensorBox)] + for i, x in zip(indices, broadcast_tensors(*[args[i] for i in indices])): + args[i] = x + for i in range(len(args)): + if isinstance(args[i], ir.Constant): + args[i] = ExpandView.create(args[i], list(args[indices[0]].get_size())) + return make_pointwise(fn, override_return_dtype=dtype)( + args[0], to_dtype(args[1], dtype), to_dtype(args[2], dtype) + ) + + +@register_lowering(aten.broadcast_tensors, broadcast=False, type_promotion_kind=None) +def broadcast_tensors(*inputs): + if len(inputs) == 1 and isinstance(inputs[0], (list, tuple)): + return broadcast_tensors(*inputs[0]) + target: list[sympy.Expr] = functools.reduce( + broadcast_symbolic_shapes, [x.get_size() for x in inputs], [] + ) + outputs = [] + for x in inputs: + sizes = x.get_size() + + if len(sizes) != len(target) or any( + V.graph.sizevars.is_size_one_or_false(a) + != V.graph.sizevars.is_size_one_or_false(b) + for a, b in zip(sizes, target) + ): + x = expand(x, target) + outputs.append(x) + return outputs + + +@register_lowering([aten.alias, aten.detach, aten.detach_, aten.lift, prims.view_of]) +def nop(x): + return x # AOT autograd handles this for us + + +if hasattr(aten, "lift_fresh"): + register_lowering(aten.lift_fresh)(nop) + + +@register_lowering(aten.squeeze, type_promotion_kind=None) +def squeeze(x, dim=None): + assert isinstance(x, TensorBox) + if dim is None: + return TensorBox(SqueezeView.create(x.data)) + + dim = ( + V.graph.sizevars.guard_int(dim) + if isinstance(dim, (int, sympy.Expr)) + else tuple(V.graph.sizevars.guard_int(d) for d in dim) + ) + dim = canonicalize_dims(len(x.get_size()), dim) # type: ignore[call-overload] + dims = OrderedSet((dim,) if not isinstance(dim, tuple) else dim) + + new_shape = [] + for d, s in enumerate(x.get_size()): + if not (d in dims and 
V.graph.sizevars.guard_or_false(sympy.Eq(s, 1))): + new_shape.append(s) + + # squeeze does nothing if the size isn't 1 + return view(x, new_shape) if new_shape != x.get_size() else x + + +@register_lowering(aten.squeeze_copy, type_promotion_kind=None) +def squeeze_copy(x, dim=None): + return clone(squeeze(x, dim)) + + +@register_lowering([aten.squeeze_]) +def squeeze_(x, dim=None): + val = squeeze(x, dim) + assert isinstance(x, TensorBox) + assert isinstance(val, TensorBox) + x.data = val.data + return x + + +@register_lowering(aten.isinf) +def isinf(x): + if is_integer_type(x): + return full_like(x, False, dtype=torch.bool) + fn = ops_wrapper("isinf") + return make_pointwise(fn, override_return_dtype=torch.bool)(x) + + +@register_lowering(aten.isnan) +def isnan(x): + if is_integer_type(x): + return full_like(x, False, dtype=torch.bool) + fn = ops_wrapper("isnan") + return make_pointwise(fn, override_return_dtype=torch.bool)(x) + + +@register_lowering(aten.ceil) +def ceil(x): + if is_integer_type(x): + return clone(x) + fn = ops_wrapper("ceil") + return make_pointwise(fn)(x) + + +@register_lowering(aten.floor) +def floor(x): + if is_integer_type(x): + return clone(x) + fn = ops_wrapper("floor") + return make_pointwise(fn)(x) + + +@register_lowering(aten.round.default) +def round(x): + if is_integer_type(x): + return clone(x) + else: + fn = ops_wrapper("round") + return make_pointwise(fn)(x) + + +@register_lowering(aten.trunc) +def trunc(x): + if is_integer_type(x): + return clone(x) + fn = ops_wrapper("trunc") + return make_pointwise(fn)(x) + + +@register_lowering(aten.expand, type_promotion_kind=None) +def expand(x, sizes): + (x,) = promote_constants([x]) + if isinstance(x, ir.BaseConstant): + return ExpandView.create(x, tuple(sizes)) + assert isinstance(x, TensorBox) + assert isinstance(sizes, (list, tuple)) + if tuple(x.get_size()) == tuple(sizes): + return x + + if not free_unbacked_symbols(x.get_size()): + x_size_product = V.graph.sizevars.size_hint_or_throw( 
+ sympy_product(x.get_size()) + ) + # TODO: It would be better to realize the input if any of its sizes + # are unbacked, because typically the size will be non-zero. However, + # this cannot be done directly as below as we'll choke on the size_hint + # here + if x_size_product > 0 and not free_unbacked_symbols(sizes): + # maybe realize input before broadcasting it + x.mark_reuse( + V.graph.sizevars.size_hint_or_throw(sympy_product(sizes)) + // x_size_product + ) + return TensorBox(ExpandView.create(x.data, tuple(sizes))) + + +@register_lowering(prims.broadcast_in_dim, type_promotion_kind=None) +def broadcast_in_dim(a, shape, broadcast_dimensions): + s = list(shape) + for broadcast_dimension in broadcast_dimensions: + s[broadcast_dimension] = -1 + + v = a + for idx, x in enumerate(s): + if x != -1: + v = unsqueeze(v, idx) + + return expand(v, shape) + + +@register_lowering(aten.expand_as, type_promotion_kind=None) +def expand_as(x, y): + return expand(x, y.get_size()) + + +@register_lowering(aten.repeat) +def repeat(x, repeats): + old_size = list(x.get_size()) + if len(repeats) > len(old_size): + old_size = [sympy.S.One] * (len(repeats) - len(old_size)) + old_size + x = view(x, list(old_size)) + assert len(repeats) == len(x.get_size()) + + new_size = list(x.get_size()) + + zero_tensor = False + for i in range(len(repeats)): + if repeats[i] == 0: + zero_tensor = True + new_size[i] = new_size[i] * repeats[i] + + if zero_tensor: + return empty(new_size, dtype=x.get_dtype(), device=x.get_device()) + if all((a == 1 or b == 1) for a, b in zip(repeats, old_size)): + return clone(expand(x, new_size)) + + x_loader: Callable[[Any], Any] + + def inner_fn(index): + assert len(index) == len(repeats) + index = list(index) + for i in range(len(repeats)): + if repeats[i] != 1: + if old_size[i] == 1: + index[i] = sympy.S.Zero + else: + index[i] = ModularIndexing(index[i], 1, old_size[i]) + return x_loader(index) + + if not free_unbacked_symbols(old_size) and not 
free_unbacked_symbols(new_size): + old_size_product = V.graph.sizevars.size_hint_or_throw(sympy_product(old_size)) + if old_size_product > 0: + # maybe realize the input but skip for unbacked symints since it'll + # choke on the size hint. + x.mark_reuse( + V.graph.sizevars.size_hint_or_throw(sympy_product(new_size)) + // old_size_product + ) + + x_loader = x.make_loader() + return Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=inner_fn, + ranges=list(new_size), + ) + + +@register_lowering(aten._unsafe_view, type_promotion_kind=None) +@register_lowering(aten.view, type_promotion_kind=None) +@register_lowering(aten.reshape, type_promotion_kind=None) +def view(x: TensorBox, sizes: Sequence[sympy.Expr]) -> TensorBox: + return TensorBox(View.create(x.data, sizes)) + + +@register_lowering(aten.permute, type_promotion_kind=None) +def permute(x, dims): + assert isinstance(x, TensorBox) + assert isinstance(dims, (list, tuple)) + return TensorBox(PermuteView.create(x.data, tuple(dims))) + + +@register_lowering(aten.slice, type_promotion_kind=None) +def slice_(x, dim=0, start=0, end=2**63, step=1, clamp=True): + """ + Lowers a slice call, creating ExternKernels for the output size & storage offset symbols, + if the indices are unbacked and appropriate semantics aren't known. + If they are known (indices are static/backed/unbacked with info), a SliceView is created. 
+ """ + + from torch.fx.experimental.symbolic_shapes import ( + CallMethodKey, + resolve_unbacked_bindings, + ) + + assert isinstance(x, TensorBox) + dim = _validate_dim(x, dim, 0) + size = x.get_size()[dim] + step = sympy.expand(step) + assert isinstance(step, sympy.Expr) or step > 0, step + + # maybe apply slice optimization + try: + if ( + start == 0 + and V.graph.sizevars.statically_known_leq(size, end) + and step == 1 + ): + return x + except TypeError: + pass + + # try to avoid dynamic (unbacked) slice + def compute_slice_index(index, size, default=None): + if index is None: + return default + + fn = lambda x: V.graph.sizevars.guard_or_false(x) # noqa: E731 + index = sympy.expand(index) + size = sympy.expand(size) + if fn(sympy.Ge(index, 0)) and fn(sympy.Le(index, size)): + return index + elif fn(sympy.Lt(index, 0)) and fn(sympy.Ge(index, -size)): + return index + size + elif fn(sympy.Gt(index, size)): + return size + elif fn(sympy.Lt(index, -size)): + return 0 + return None + + start_index, end_index = None, None + ambiguous_slice = clamp + if ambiguous_slice: + start_index = compute_slice_index(start, size, 0) + end_index = compute_slice_index(end, size, size) + if start_index is not None and end_index is not None: + start, end = start_index, end_index + ambiguous_slice = False + + # ambiguous_slice=False means we know what semantics this slice call follows, + # and don't need to generate an extern kernel to represent the output size. 
+ # This is assumed True for clamp=False + # (meant to follow standard indexing semantics: 0 <= index < size) + if not ambiguous_slice: + return TensorBox( + ir.SliceView.create(x.data, dim, start, end, step, clamp=clamp) + ) # go to SliceView/ReinterpretView + + # unbacked territory: create DynamicSlice ExternKernel + # clamp is True, unbacked start / end + assert clamp + unbacked_bindings = resolve_unbacked_bindings( + V.graph.sizevars.shape_env, V.graph.current_node.meta["unbacked_bindings"] + ) + assert unbacked_bindings is not None + assert len(unbacked_bindings) <= 2, unbacked_bindings + sym_size, sym_storage = None, None + for sym, keypath in unbacked_bindings.items(): + if keypath == (CallMethodKey("size"), pytree.SequenceKey(dim)): + sym_size = sym + elif keypath == (CallMethodKey("storage_offset"),): + sym_storage = sym + + assert start_index is None or end_index is None + b_size = ir.DynamicSliceSize( + sym_size, + start, + end, + step, + x.get_size()[dim], + ) + b_size.name = V.graph.register_buffer(b_size) + V.graph.register_operation(b_size) + new_size = sym_size + + if x.maybe_get_layout() is None: + # realize tensor before accessing layout + x.realize() + + if start_index is not None: + # we shouldn't have allocated storage offset symbol if start index was determinable + assert sym_storage is None + new_storage_offset = x.get_layout().offset + start_index * x.get_stride()[dim] + else: + b_storage = ir.DynamicSelectStorageOffset( + sym_storage, + start, + x.get_layout().offset, + x.get_stride()[dim], + x.get_size()[dim], + clamp=True, + ) + b_storage.name = V.graph.register_buffer(b_storage) + V.graph.register_operation(b_storage) + new_storage_offset = sym_storage + + new_sizes = list(x.get_size()) + new_strides = list(x.get_stride()) + new_sizes[dim] = new_size + new_strides[dim] *= step + return as_strided(x, new_sizes, new_strides, new_storage_offset) + + +@register_lowering(aten.as_strided, type_promotion_kind=None) +def as_strided(x, size, 
stride, storage_offset=None): + new_device = None + new_dtype = None + if isinstance(x, TensorBox) and isinstance(x.data, ir.BaseView): + # Note: Merging views + # When we use as_strided, we can rewrite the size/stride/offset + # of the incoming buffer x. If x is a view, we would overwrite + # its metadata. Except for dtype, which we need to propagate. + + # Technically device is not needed because it is not possible + # to have a cross-device view today. + new_device = x.get_device() + new_dtype = x.dtype + x = x.data.unwrap_view() + x.realize() + if not ir.is_storage_and_layout(x): + raise NotImplementedError(f"unrealized as_strided({x}, ...)") + storage, old_layout = ir.as_storage_and_layout(x) + new_layout = ir.FixedLayout( + new_device if new_device else old_layout.device, + new_dtype if new_dtype else old_layout.dtype, + [sympy.expand(s) for s in size], + [sympy.expand(s) for s in stride], + sympy.expand(storage_offset or 0), + ) + return TensorBox(ir.ReinterpretView(data=storage, layout=new_layout)) + + +@register_lowering(aten.as_strided_, type_promotion_kind=None) +def as_strided_(x, size, stride, storage_offset=None): + assert isinstance(x, TensorBox) + x.data = as_strided(x, size, stride, storage_offset).data + return x + + +@register_lowering(aten.as_strided_copy, type_promotion_kind=None) +def as_strided_copy(x, size, stride, storage_offset=None): + result = as_strided(x, size, stride, storage_offset) + return clone(result) + + +def pointwise_cat(inputs, dim=0): + # (inclusive, exclusive) + inputs_ranges: list[tuple[sympy.Expr, sympy.Expr]] = [] + prev_end = 0 + for inp in inputs: + inputs_ranges.append((prev_end, prev_end + inp.get_size()[dim])) # type: ignore[arg-type] + prev_end = inputs_ranges[-1][-1] # type: ignore[assignment] + + inputs_loaders = [inp.make_loader() for inp in inputs] + + def inner_fn(idx): + idx_dim = ops.index_expr(idx[dim], torch.int64) + + masks = [] + masked_loads = [] + for i in range(len(inputs)): + start = ( + 
ops.constant(0, torch.int64) + if i == 0 + else ops.index_expr(inputs_ranges[i][0], torch.int64) + ) + end = ops.index_expr(inputs_ranges[i][1], torch.int64) + + start_cond = ops.ge(idx_dim, start) + end_cond = ops.lt(idx_dim, end) + if i == 0: + mask = end_cond + elif i == len(inputs) - 1: + mask = start_cond + else: + mask = ops.and_(start_cond, end_cond) + + masks.append(mask) + idx_load = list(idx) + + # if we're concatting [4], [2] + # when we index the second tensor for 5 we want to index 5 - 4 + # Use Identity to prevent expansion of index * stride to keep expression + # in same int bitwidth as shape + idx_load[dim] = Identity(idx_load[dim] - inputs_ranges[i][0]) + + masked_loads.append( + ops.masked( + mask, + lambda: inputs_loaders[i](idx_load), + 0.0, # this value should be unused + ), + ) + + next_val = masked_loads[-1] + for i in range((len(inputs)) - 2, -1, -1): + next_val = ops.where( + masks[i], + masked_loads[i], + next_val, + ) + return next_val + + new_size = list(inputs[0].get_size()) + new_size[dim] = inputs_ranges[-1][-1] + + return Pointwise.create( + device=inputs[0].get_device(), + dtype=inputs[0].get_dtype(), + inner_fn=inner_fn, + ranges=new_size, + ) + + +@register_lowering(quantized_decomposed.quantize_per_channel, type_promotion_kind=None) +def quantized_decomposed_quantize_per_channel( + input: TensorBox, + scales: TensorBox, + zero_points: TensorBox, + axis: int, + quant_min: int, + quant_max: int, + dtype: torch.dtype, +) -> Union[TensorBox, ShapeAsConstantBuffer]: + assert len(scales.get_size()) == 1, "expect scales 1 dim" + assert len(zero_points.get_size()) == 1, "expect zero_points 1 dim" + + if input.get_dtype() == torch.bfloat16: + input = to_dtype(input, torch.float32) + assert input.get_dtype() == torch.float32, ( + f"Expecting input to have dtype torch.float32, but got dtype: {input.get_dtype()}" + ) + assert axis < len(input.get_size()), ( + f"Expecting axis to be < {len(input.get_size())}" + ) + + input_loader = 
input.make_loader() + scales_loader = scales.make_loader() + zero_points_loader = zero_points.make_loader() + + def inner_fn(idx): + channel_idx = (idx[axis],) + + input = input_loader(idx) + scale = scales_loader(channel_idx) + zero_point = zero_points_loader(channel_idx) + qmin, qmax = _create_constants(quant_min, quant_max, dtype=torch.float32) + + if scales.dtype != torch.float32: + scale = ops.to_dtype(scale, torch.float32) + if zero_points.dtype != torch.int32: + zero_point = ops.to_dtype(zero_point, torch.int32) + inv_scale = ops.reciprocal(scale) + val = ops.round(input * inv_scale) + zero_point + clamped = ops.maximum(qmin, ops.minimum(qmax, val)) + return ops.to_dtype(clamped, dtype) + + return Pointwise.create( + device=input.get_device(), + dtype=dtype, + inner_fn=inner_fn, + ranges=input.get_size(), + ) + + +def _assert_async(cond, msg): + cond.realize() + cond = to_dtype(cond, torch.bool) + + def inner_fn(index): + with ir.ComputedBuffer.force_realize(): + return ops.device_assert_async(cond.make_loader()(index), msg) + + assertion_op = Pointwise.create( + device=cond.get_device(), + dtype=cond.get_dtype(), + inner_fn=inner_fn, + ranges=list(cond.get_size()), + ) + assertion_op.realize() + return assertion_op + + +@register_lowering(aten._assert_async.msg) +def lower_assert_async(cond, msg): + return _assert_async(cond, msg) + + +@register_lowering(aten._functional_assert_async.msg) +def lower_assert_functional_async(cond, msg): + return _assert_async(cond, msg) + + +@register_lowering( + quantized_decomposed.dequantize_per_channel, type_promotion_kind=None +) +def quantized_decomposed_dequantize_per_channel( + input: TensorBox, + scales: TensorBox, + zero_points: TensorBox, + axis: int, + quant_min: int, + quant_max: int, + dtype: torch.dtype, + *, + out_dtype: Optional[torch.dtype] = None, +) -> Union[TensorBox, ShapeAsConstantBuffer]: + assert len(scales.get_size()) == 1, "expect scales 1 dim" + assert len(zero_points.get_size()) == 1, "expect 
zero_points 1 dim" + assert input.get_dtype() == dtype, ( + f"Expecting input to have dtype {dtype}, but got dtype: {input.get_dtype()}" + ) + assert axis < len(input.get_size()), ( + f"Expecting axis to be < {len(input.get_size())}" + ) + + if out_dtype is None: + out_dtype = torch.float32 + + input_loader = input.make_loader() + scales_loader = scales.make_loader() + zero_points_loader = zero_points.make_loader() + + def inner_fn(idx): + channel_idx = (idx[axis],) + + input = input_loader(idx) + scale = scales_loader(channel_idx) + zero_point = zero_points_loader(channel_idx) + + if scales.dtype != torch.float32: + scale = ops.to_dtype(scale, torch.float32) + if zero_points.dtype != torch.float32: + zero_point = ops.to_dtype(zero_point, torch.float32) + val = ops.sub(ops.to_dtype(input, torch.float32), zero_point) * scale + val = ops.to_dtype(val, out_dtype) + return val + + return Pointwise.create( + device=input.get_device(), + dtype=out_dtype, + inner_fn=inner_fn, + ranges=input.get_size(), + ) + + +@register_lowering( + quantized_decomposed.quantize_per_tensor.default, type_promotion_kind=None +) +def quantized_decomposed_quantize_per_tensor_default( + input: TensorBox, + scale: float, + zero_point: int, + quant_min: int, + quant_max: int, + dtype: torch.dtype, +) -> Union[TensorBox, ShapeAsConstantBuffer]: + if input.get_dtype() == torch.bfloat16: + input = to_dtype(input, torch.float32) + assert input.get_dtype() == torch.float32, ( + f"Expecting input to have dtype torch.float32, but got dtype: {input.get_dtype()}" + ) + + input_loader = input.make_loader() + + def inner_fn(idx, scale, zero_point): + input = input_loader(idx) + inv_scale, zero_point = _create_constants( + 1.0 / scale, zero_point, dtype=torch.float32 + ) + val = ops.round(input * inv_scale) + zero_point + qmin, qmax = _create_constants(quant_min, quant_max, dtype=torch.float32) + clamped = ops.minimum(ops.maximum(val, qmin), qmax) + return ops.to_dtype(clamped, dtype) + + return 
Pointwise.create( + device=input.get_device(), + dtype=dtype, + inner_fn=functools.partial( + inner_fn, scale=float(scale), zero_point=int(zero_point) + ), + ranges=input.get_size(), + ) + + +@register_lowering( + quantized_decomposed.dequantize_per_tensor.default, type_promotion_kind=None +) +def quantized_decomposed_dequantize_per_tensor_default( + input: TensorBox, + scale: float, + zero_point: int, + quant_min: int, + quant_max: int, + dtype: torch.dtype, + *, + out_dtype: Optional[torch.dtype] = None, +) -> Union[TensorBox, ShapeAsConstantBuffer]: + assert input.get_dtype() == dtype, ( + f"Expecting input to have dtype {dtype}, but got dtype: {input.get_dtype()}" + ) + + if out_dtype is None: + out_dtype = torch.float32 + + input_loader = input.make_loader() + + def inner_fn(idx, scale, zero_point): + input = input_loader(idx) + scale, zero_point = _create_constants(scale, zero_point, dtype=torch.float32) + val = ops.sub(ops.to_dtype(input, torch.float32), zero_point) * scale + val = ops.to_dtype(val, out_dtype) + return val + + return Pointwise.create( + device=input.get_device(), + dtype=out_dtype, + inner_fn=functools.partial( + inner_fn, scale=float(scale), zero_point=int(zero_point) + ), + ranges=input.get_size(), + ) + + +@register_lowering( + quantized_decomposed.quantize_per_tensor.tensor, type_promotion_kind=None +) +def quantized_decomposed_quantize_per_tensor_tensor( + input: TensorBox, + scale: TensorBox, + zero_point: TensorBox, + quant_min: int, + quant_max: int, + dtype: torch.dtype, +) -> Union[TensorBox, ShapeAsConstantBuffer]: + if input.get_dtype() == torch.bfloat16: + input = to_dtype(input, torch.float32) + assert input.get_dtype() == torch.float32, ( + f"Expecting input to have dtype torch.float32, but got dtype: {input.get_dtype()}" + ) + assert len(scale.get_size()) == 0 or ( + len(scale.get_size()) == 1 and scale.get_size()[0] == 1 + ), "expect scale as scalar tensor" + assert len(zero_point.get_size()) == 0 or ( + 
len(zero_point.get_size()) == 1 and zero_point.get_size()[0] == 1 + ), "expect zero_point as scalar tensor" + + input_loader = input.make_loader() + scale_loader = scale.make_loader() + zero_point_loader = zero_point.make_loader() + + def inner_fn(idx): + input = input_loader(idx) + _scale = scale_loader((0,) if len(scale.get_size()) == 1 else ()) + _zero_point = zero_point_loader((0,) if len(scale.get_size()) == 1 else ()) + if scale.dtype != torch.float32: + _scale = ops.to_dtype(_scale, torch.float32) + if zero_point.dtype != torch.float32: + _zero_point = ops.to_dtype(_zero_point, torch.float32) + val = ops.round(input * ops.reciprocal(_scale)) + _zero_point + qmin, qmax = _create_constants(quant_min, quant_max, dtype=torch.float32) + clamped = ops.minimum(ops.maximum(val, qmin), qmax) + return ops.to_dtype(clamped, dtype) + + return Pointwise.create( + device=input.get_device(), + dtype=dtype, + inner_fn=inner_fn, + ranges=input.get_size(), + ) + + +@register_lowering( + quantized_decomposed.dequantize_per_tensor.tensor, type_promotion_kind=None +) +def quantized_decomposed_dequantize_per_tensor_tensor( + input: TensorBox, + scale: TensorBox, + zero_point: TensorBox, + quant_min: int, + quant_max: int, + dtype: torch.dtype, + *, + out_dtype: Optional[torch.dtype] = None, +) -> Union[TensorBox, ShapeAsConstantBuffer]: + assert len(scale.get_size()) == 0 or ( + len(scale.get_size()) == 1 and scale.get_size()[0] == 1 + ), "expect scale as scalar tensor" + assert len(zero_point.get_size()) == 0 or ( + len(zero_point.get_size()) == 1 and zero_point.get_size()[0] == 1 + ), "expect zero_point as scalar tensor" + assert input.get_dtype() == dtype, ( + f"Expecting input to have dtype {dtype}, but got dtype: {input.get_dtype()}" + ) + + if out_dtype is None: + out_dtype = torch.float32 + + input_loader = input.make_loader() + scale_loader = scale.make_loader() + zero_point_loader = zero_point.make_loader() + + def inner_fn(idx): + input = input_loader(idx) + _scale = 
scale_loader((0,) if len(scale.get_size()) == 1 else ()) + _zero_point = zero_point_loader((0,) if len(scale.get_size()) == 1 else ()) + if scale.dtype != torch.float32: + _scale = ops.to_dtype(_scale, torch.float32) + if zero_point.dtype != torch.float32: + _zero_point = ops.to_dtype(_zero_point, torch.float32) + val = ops.sub(ops.to_dtype(input, torch.float32), _zero_point) * _scale + val = ops.to_dtype(val, out_dtype) + return val + + return Pointwise.create( + device=input.get_device(), + dtype=out_dtype, + inner_fn=inner_fn, + ranges=input.get_size(), + ) + + +@register_lowering(aten.cat) +def cat(inputs, dim=0): + cpu_device = inputs[0].get_device().type == "cpu" + if cpu_device and all( + input.get_dtype() in [torch.int8, torch.uint8] for input in inputs + ): + # TODO Remove this fallback when we support vectorization + # code gen with uint8 data type directly. + for input in inputs: + input.realize() + if all(len(input.get_size()) == 4 for input in inputs): + inputs, _ = require_channels_last(aten.cat, *inputs) + return fallback_handler(aten.cat.default)(inputs, dim) + + if len(inputs) == 1: + return clone(inputs[0]) + + dim = _validate_dim(inputs[0], dim, 0) + dtype = get_promoted_dtype( + *inputs, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ) + inputs = [to_dtype(inp, dtype) for inp in inputs] + + def unwrap_tensor(x: Union[TensorBox, ir.StorageBox]) -> ir.IRNode: + if isinstance(x, TensorBox): + if isinstance(x.data, ir.BaseView): + return x.data.unwrap_view() + else: + return x.data + + if isinstance(x, ir.StorageBox): + return x.data + + return x + + def is_reduction(t): + return isinstance(t, ir.ComputedBuffer) and isinstance(t.data, ir.Reduction) + + def can_fuse_reduction(t): + if isinstance(t, (TensorBox, ir.StorageBox)): + return can_fuse_reduction(unwrap_tensor(t)) + return ( + is_reduction(t) + or isinstance(t, ir.Pointwise) + and any( + can_fuse_reduction(V.graph.get_buffer(read)) + for read in t.get_read_names() + ) + ) + + # 
fusing reducutions into computed concat buffer can cause regressions. + fusable_reduction = any(can_fuse_reduction(t) for t in inputs) + + def should_lower_cat_input(x) -> bool: + # Unrealized inputs will not be storage and layouts, and we dont want to realize + # them in case we want to fuse + if ir.is_storage_and_layout(x): + storage, _ = ir.as_storage_and_layout(x, freeze=False) + return not ir.ConcatKernel.can_realize_into_without_copy(storage) + + if isinstance(x, (TensorBox, ir.StorageBox)): + return should_lower_cat_input(unwrap_tensor(x)) + + if isinstance(x, ir.Pointwise): + return True + + return False + + if config.force_pointwise_cat: + return pointwise_cat(inputs, dim) + + # TODO: We observed negative performance impact of pointwise_cat optimization on CPU so disabled it. + # We will revisit this later after enabling vectorization on index_expr. + if cpu_device: + return TensorBox(ir.ConcatKernel.create(inputs, dim)) + + def op_count(x): + if isinstance(x, (TensorBox, ir.StorageBox)): + return op_count(unwrap_tensor(x)) + + # this will correspond to a direct memory read + if not isinstance(x, ir.Pointwise): + return 0 + + count = x.inner_fn_opcount().num_ops + for read in x.get_read_names(): + count += op_count(V.graph.get_buffer(read)) + + return count + + # as of inputs increase, possibility for register spilling also increases + # past a certain threshold of inputs we only fuse if the if the input kernels + # are simple + # not sure if we want to expose to users via config since logic may change in future + MAX_COMPLEX_POINTWISE_CAT = 8 + MAX_SIMPLE_OP_COUNT = 2 + + def additional_pointwise_ops(op: torch._ops.OpOverload): + return op in (aten.cat.default, aten.constant_pad_nd.default) + + if len(inputs) <= MAX_COMPLEX_POINTWISE_CAT or ( + (len(inputs) <= config.max_pointwise_cat_inputs) + and all(op_count(t) <= MAX_SIMPLE_OP_COUNT for t in inputs) + ): + pointwise_uses = all( + is_pointwise_use(use, additional_pointwise_ops) + for use in 
V.current_node.users + ) + # fuse in case we will be used in a pointwise node, and there are any inputs we + # we can prevent materialization of. + fuse_pointwise_use = ( + any(should_lower_cat_input(inp) for inp in inputs) and pointwise_uses + ) + + # horizontal fuse in case all inputs will require a copy kernel anyway. + # only horizontally fuse pointwise kernels + horizontal_fuse_cat = all( + should_lower_cat_input(inp) for inp in inputs + ) and not any(can_fuse_reduction(t) for t in inputs) + if fuse_pointwise_use or (horizontal_fuse_cat and not fusable_reduction): + return pointwise_cat(inputs, dim) + + return TensorBox(ir.ConcatKernel.create(inputs, dim)) + + +@register_lowering(aten.diagonal, type_promotion_kind=None) +def diagonal(input, offset: int = 0, dim1: int = 0, dim2: int = 1): + original_shape = input.get_size() + num_dims = len(original_shape) + dim1 = canonicalize_dim(idx=dim1, rank=num_dims) + dim2 = canonicalize_dim(idx=dim2, rank=num_dims) + + check( + dim1 != dim2, lambda: f"diagonal dimensions cannot be identical {dim1}, {dim2}" + ) + + offset_negative = V.graph.sizevars.evaluate_expr(sympy.Lt(offset, 0)) + if offset_negative: + diag_size = V.graph.sizevars.evaluate_max( + V.graph.sizevars.evaluate_min( + original_shape[dim1] + offset, original_shape[dim2] + ), + 0, # type: ignore[arg-type] + ) + else: + diag_size = V.graph.sizevars.evaluate_max( + V.graph.sizevars.evaluate_min( + original_shape[dim1], original_shape[dim2] - offset + ), + 0, # type: ignore[arg-type] + ) + + base_idx = (0, 0) + if offset_negative: + base_idx = (-offset, 0) + else: + base_idx = (0, offset) + + sizes = [s for i, s in enumerate(original_shape) if i not in (dim1, dim2)] + sizes.append(diag_size) + + def reindexer(idx): + diag_idx = idx[-1] + original_idx = [0] * len(original_shape) + cur_dim = 0 + for d in range(num_dims): + if d == dim1: + original_idx[d] = diag_idx + base_idx[0] + elif d == dim2: + original_idx[d] = diag_idx + base_idx[1] + else: + 
original_idx[d] = idx[cur_dim] + cur_dim += 1 + + assert cur_dim == len(original_shape) - 2 + return original_idx + + return TensorBox(ir.GenericView.create(input, sizes, reindexer)) + + +@register_lowering(aten.diagonal_copy, type_promotion_kind=None) +def diagonal_copy(input, offset: int = 0, dim1: int = 0, dim2: int = 1): + return clone(diagonal(input, offset, dim1, dim2)) + + +@register_lowering(aten.diagonal_scatter, type_promotion_kind=None) +def diagonal_scatter(input, src, offset: int = 0, dim1: int = 0, dim2: int = 1): + output = clone(input) + target = diagonal(output, offset, dim1, dim2) + mutate_to(target, src) + return output + + +@register_lowering(aten.select, type_promotion_kind=None) +def select(x, dim, idx): + idx = sympy.expand(idx) + size = sympy.expand(x.get_size()[dim]) + actual_index = None + + if V.graph.sizevars.guard_or_false(sympy.Lt(idx, 0)): + actual_index = idx + size + elif V.graph.sizevars.guard_or_false(sympy.Ge(idx, 0)): + actual_index = idx + + if actual_index is not None: + if has_free_unbacked_symbols(idx): + # Inductor could generate incorrect views for tensors with unbacked symbols here; + # Squeeze operations are translated to views, resulting in incorrect strides. + # Additionally, we want to avoid accidental unbacked unsqueeze semantics. To resolve this, + # we use as_strided instead. + # Removing this branch will cause test_unbacked_select_index_with_check to fail. + + # before accessing size, stride, and offset we need to realize. 
+ x.realize() + new_size = x.get_size() + new_stride = x.get_stride() + new_storage_offset = x.get_layout().offset + new_stride[dim] * actual_index + + del new_size[dim] + del new_stride[dim] + return as_strided(x, new_size, new_stride, new_storage_offset) + else: + # no need to clamp, this function handles negative indexing itself + slice_result = slice_(x, dim, actual_index, actual_index + 1, clamp=False) + return squeeze(slice_result, dim) + + # Unbacked Semantics: + # When the index idx is unbacked (e.g., u0), we compute the index dynamically + # during the lowering of the select operation using DynamicSelectStorageOffset. + + unbacked_bindings = resolve_unbacked_bindings( + V.graph.sizevars.shape_env, V.graph.current_node.meta["unbacked_bindings"] + ) + assert unbacked_bindings is not None + assert len(unbacked_bindings) == 1, unbacked_bindings + unbacked_offset_sym, _ = next(iter(unbacked_bindings.items())) + + # before accessing size, stride, and offset we need to realize. + x.realize() + new_size = x.get_size() + new_stride = x.get_stride() + new_storage_offset = unbacked_offset_sym + buffer = ir.DynamicSelectStorageOffset( + unbacked_offset_sym, + idx, + x.get_layout().offset, + new_stride[dim], + x.get_size()[dim], + clamp=False, + ) + buffer.name = V.graph.register_buffer(buffer) + V.graph.register_operation(buffer) + + del new_size[dim] + del new_stride[dim] + return as_strided(x, new_size, new_stride, new_storage_offset) + + +@register_lowering(aten.split, type_promotion_kind=None) +def split(x, sizes, dim=0): + dim = _validate_dim(x, dim, 0) + sizes_ = sizes + + # If sizes is an integer (or a SymInt), we turn it into a list of sizes + # by computing what the actual size of each chunk should be. + if not isinstance(sizes, (list, tuple)): + x_size = x.get_size()[dim] + chunks = V.graph.sizevars.guard_int(FloorDiv(x_size + sizes - 1, sizes)) + sizes_ = [sizes] * chunks + # The last chunk might have a smaller size than the rest. 
+ sizes_[-1] = x_size - (chunks - 1) * sizes + + # From this point, we assume that the sum of the sizes of all chunks + # equals the size of the base tensor. + result = [] + start = 0 + for size in sizes_: + end = start + size + # No need for clamping here, since we compute the exact + # start and end values. + result.append(slice_(x, dim, start, end, clamp=False)) + start = end + return result + + +@register_lowering(aten.split_with_sizes, type_promotion_kind=None) +def split_with_sizes(x, sizes, dim=0): + return split(x, sizes, dim) + + +@register_lowering(aten.unbind, type_promotion_kind=None) +def unbind(x, dim=0): + dim = _validate_dim(x, dim, 0) + x_size = V.graph.sizevars.guard_int(x.get_size()[dim]) + result = [select(x, dim, i) for i in range(x_size)] + return result + + +@register_lowering(aten.unfold, type_promotion_kind=None) +def unfold(x, dimension, size, step): + sizes = x.get_size() + ndim = len(sizes) + dim = canonicalize_dim(ndim, dimension) + + if ndim == 0: + return slice_(unsqueeze(x, 0), end=size, clamp=False) + + dim_size = sizes[dim] + sizevars = V.graph.sizevars + sizevars.check_leq(size, dim_size) + sizevars.check_lt(0, step) # type: ignore[arg-type] + + new_dim_size = FloorDiv(dim_size - size, step) + 1 + if sizevars.size_hint_or_throw(dim_size) > 0: + x.mark_reuse( + sizevars.size_hint_or_throw(CeilDiv(new_dim_size * size, dim_size)) + ) + + out_size = [*sizes[:dim], new_dim_size, *sizes[dim + 1 :], size] + + def reindexer(idx): + dim_idx = idx[-1] + idx[dim] * step + return (*idx[:dim], dim_idx, *idx[dim + 1 : -1]) + + return TensorBox(ir.GenericView.create(x, out_size, reindexer)) + + +@register_lowering(aten.unsqueeze, type_promotion_kind=None) +def unsqueeze(x, dim): + dim = _validate_dim(x, dim, 1) + new_shape = list(x.get_size()) + new_shape.insert(dim, sympy.S.One) + return view(x, new_shape) + + +@register_lowering(aten.unsqueeze_, type_promotion_kind=None) +def unsqueeze_(x, dim): + val = unsqueeze(x, dim) + assert isinstance(x, 
TensorBox) + assert isinstance(val, TensorBox) + x.data = val.data + return x + + +def _validate_dim(x, dim, offset=0): + dim = V.graph.sizevars.shape_env.evaluate_expr(sympy.sympify(dim)) + ndim = len(x.get_size()) + if dim < 0: + dim += ndim + offset + assert 0 <= dim < ndim + offset + return dim + + +@register_lowering(aten.glu) +def glu(x, dim=-1): + dim = _validate_dim(x, dim, 0) + # TODO: don't guard on static shape here + new_len = V.graph.sizevars.guard_int(x.get_size()[dim]) // 2 + # no need to clamp, index is int based on input size + a = slice_(x, dim, 0, new_len, clamp=False) + b = slice_(x, dim, new_len, new_len * 2, clamp=False) + return mul(a, sigmoid(b)) + + +def fallback_handler(kernel, add_to_fallback_set=True): + if add_to_fallback_set: + fallbacks.add(kernel) + + def handler(*args, **kwargs): + def wrap_tensors(x): + return TensorBox.create(x) if isinstance(x, ir.IRNode) else x + + return pytree.tree_map( + wrap_tensors, ir.FallbackKernel.create(kernel, *args, **kwargs) + ) + + # This lets us detect that a lowering is a fallback handler. + handler._is_fallback_handler = True # type: ignore[attr-defined] + + return handler + + +@functools.cache +def _warn_complex_not_supported(): + warnings.warn( + "Torchinductor does not support code generation for complex operators. Performance may be worse than eager." + ) + + +# There are some types (CPU) which we accept as input but not as +# output. 
+def unsupported_input_tensor(t: torch.Tensor, node=None): + "Do not support reading or writing to this tensor" + if t.is_complex(): + # Complex views are supported with IR ComplexView + _warn_complex_not_supported() + return True + + if t.is_meta: + return True + + if t.is_sparse: + return True + + if t.dtype == torch.float8_e8m0fnu: + if not node: + return True + + # allow bitcast, views, memory movement, but not arithmetic + # TODO: delete once triton adds native support + return not ( + isinstance(node.target, torch._ops.OpOverload) + and node.target + in ( + aten.view.dtype, + aten.cat.default, + aten.clone.default, + aten._scaled_mm.default, + ) + or (isinstance(node.target, torch._ops.OpOverload) and is_view(node.target)) + ) + + return False + + +def unsupported_output_tensor(t: torch.Tensor, node=None): + "Do not support writing tensor but can read from it" + supported_complex_views = ( + aten.view.dtype, + torch.ops.prims.convert_element_type.default, + ) + if node is not None and node.target in supported_complex_views and t.is_complex(): + return False + if unsupported_input_tensor(t, node): + return True + return t.is_cpu and config.disable_cpp_codegen + + +def fallback_node_due_to_unsupported_type(node: torch.fx.Node, allow_cpu_inputs=True): + # Custom fallback lowering + if node.target is aten.view_as_complex.default: + return False + + if node.op == "placeholder": + return False + + # We should be able to remove this special case once `disable_cpp_codegen` is killed. 
+ if node.target is aten.lift_fresh_copy.default: + return False + + def check_skip_condition(inp_out_node, is_output): + if not isinstance(inp_out_node, torch.fx.Node): + return False + + if "val" not in inp_out_node.meta: + return False + + for meta in pytree.tree_leaves(inp_out_node.meta["val"]): + if not isinstance(meta, torch._subclasses.FakeTensor): + continue + + if is_output: + if unsupported_output_tensor(meta, node): + return True + else: + if unsupported_input_tensor(meta, node): + return True + + return False + + # only skip codegen if there is a cpu output, not input + for arg in pytree.arg_tree_leaves(*node.args, **node.kwargs): + if check_skip_condition(arg, is_output=False): + return True + + return check_skip_condition(node, is_output=True) + + +def make_fallback(op, layout_constraint=None, warn=True, override_decomp=False): + assert op not in decompositions or override_decomp, ( + f"both a fallback and a decomp for same op: {op}" + ) + if ( + warn + and bool(os.getenv("CI")) + and get_decompositions([op]) + # if fallback_random, we allow not decomposing random + and not ( + config.fallback_random + and op in torch._decomp.decompositions_for_rng.extra_random_decomps + ) + and not override_decomp + ): + # Note: 'warn' is holdover from when this was a warning, but for ops that previously + # set warn=False we do not want a CI error. + # Ignore the 'suppress errors' configs in CI, as this particular warning happens on startup anyway and is not + # likely to be triggered preferentially on one CI config over another. + if torch._dynamo.config.suppress_errors: + torch._dynamo.config.suppress_errors = False + log.warning( + "A make_fallback error occurred in suppress_errors config," + " and suppress_errors is being disabled to surface it." + ) + raise AssertionError( + f"make_fallback({op}): a decomposition exists, we should switch to it." 
+ " To fix this error, either add a decomposition to core_aten_decompositions (preferred)" + " or inductor_decompositions, and delete the corresponding `make_fallback` line." + " Get help from the inductor team if unsure, don't pick arbitrarily to unblock yourself.", + ) + + def register_fallback(op_overload): + add_needs_realized_inputs(op_overload) + if layout_constraint is not None: + add_layout_constraint(op_overload, layout_constraint) + return register_lowering(op_overload, type_promotion_kind=None)( + fallback_handler(op_overload) + ) + + if isinstance(op, torch._ops.OpOverloadPacket): + for ol in op.overloads(): + op_overload = getattr(op, ol) + register_fallback(op_overload) + elif isinstance(op, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)): + register_fallback(op) + else: + raise RuntimeError(f"Unsupported fallback {op} with type {type(op)}") + + +def philox_rand_offset(shape): + """ + TorchInductor offset calculation differs from PyTorch eager offset + calculation for random ops (tl.rand vs torch.rand). In future, we should + strive for same impl for tl.rand and torch.rand. + """ + numel = 1 + for s in shape: + numel = numel * s + return tensor(numel, dtype=torch.int64) + + +@register_lowering(torch.ops.rngprims.philox_rand, type_promotion_kind=None) +def philox_rand(size, seed, offset, stride, device, dtype): + # stride arg is optional and will be used in future for distributed random + # ops. Currently, its unused. + random_pos = ir.FixedLayout( + device, + dtype, + size, + ir.FlexibleLayout.contiguous_strides(size), + ).make_indexer() + seed_loader = seed.make_loader() + offset_loader = offset.make_loader() + + def inner_fn(index): + # Both seed and offset in the philox_rand op are tensors. 
+ # torch seed and offsets are of type int64, but tl.rand accepts int32 + seed_index_expr = ops.to_dtype(seed_loader([]), torch.int32) + offset_index_expr = ops.to_dtype(offset_loader([]), torch.int32) + # Get the offset'd position + rand_index_expr = ops.add( + ops.index_expr(random_pos(index), torch.int32), offset_index_expr + ) + result = ops.rand( + seed_index_expr, + rand_index_expr, + ) + return ops.to_dtype(result, dtype) + + random_values_node = Pointwise.create( + device=device, + dtype=dtype, + inner_fn=inner_fn, + ranges=list(size), + ) + + offset_node = philox_rand_offset(size) + return random_values_node, offset_node + + +@register_lowering(aten.native_dropout, type_promotion_kind=None) +def native_dropout(x, p, train): + if config.fallback_random: + return pytree.tree_map( + TensorBox.create, + ir.FallbackKernel.create(aten.native_dropout.default, x, p, train), + ) + else: + raise AssertionError("should be handled in replace_random.py") + + +@register_lowering(aten.bernoulli_, type_promotion_kind=None) +def bernoulli_(x, *args): + assert config.fallback_random or x.get_device() == torch.device("cpu"), ( + "this should be handled in decomps unless config.fallback_random or the device is CPU" + ) + x.realize() + op_overload = ( + aten.bernoulli_.float + if len(args) == 0 or isinstance(args[0], float) + else aten.bernoulli_.Tensor + ) + ir.InplaceBernoulliFallback(op_overload, x, *args) + return x + + +@register_lowering(aten.bernoulli.p, type_promotion_kind=None) +def bernoulli_p(x, *args): + assert config.fallback_random or x.get_device() == torch.device("cpu"), ( + "this should be handled in decomps unless config.fallback_random or the device is CPU" + ) + return bernoulli_(clone(x), *args) + + +# This shouldn't be called in general +@register_lowering(aten._foobar) +def _foobar(_): + raise AssertionError + + +@functools.lru_cache(1) +def _warn_triton_random(salt): + log.info("using triton random, expect difference from eager") + + +def 
warn_triton_random(): + # only warn once per graph + _warn_triton_random(V.graph.creation_time) + + +fallback_rand_default = fallback_handler(aten.rand.default) +fallback_rand_generator = fallback_handler(aten.rand.generator) +fallback_randn_default = fallback_handler(aten.randn.default) +fallback_randn_generator = fallback_handler(aten.randn.generator) +make_fallback(aten.randint) + +# TODO: mlazos reevaluate if we want to codegen something different +make_fallback(torch.ops.streams.record_event.default) +make_fallback(torch.ops.streams.wait_event.default) + + +@register_lowering(aten.rand) +def rand(*args, **kwargs): + if kwargs.get("generator") is not None: + return fallback_rand_generator(*args, **kwargs) + elif config.fallback_random: + kwargs.pop("generator", None) + return fallback_rand_default(*args, **kwargs) + raise AssertionError("should have been handled in replace_random.py") + + +@register_lowering(aten.randn) +def randn(*args, **kwargs): + if kwargs.get("generator") is not None: + return fallback_randn_generator(*args, **kwargs) + elif config.fallback_random: + kwargs.pop("generator", None) + return fallback_randn_default(*args, **kwargs) + raise AssertionError("should have been handled in replace_random.py") + + +@register_lowering(inductor_prims.force_stride_order, type_promotion_kind=None) +def inductor_force_stride_order(input_tensor, stride): + stride_order = ir.get_stride_order(stride) + return ir.ExternKernel.require_stride_order(input_tensor, stride_order) + + +@register_lowering(inductor_prims.seed, type_promotion_kind=None) +def inductor_seed(device: torch.device): + raise AssertionError("should be handled in fuse_seed_creation_pass()") + + +@register_lowering(inductor_prims.seeds, type_promotion_kind=None) +def inductor_seeds(count, device): + warn_triton_random() + return TensorBox.create(ir.RandomSeeds(count, decode_device(device))) + + +@register_lowering(inductor_prims.lookup_seed, type_promotion_kind=None) +def 
inductor_lookup_seed(seeds, index): + def inner_fn(_): + return ops.load_seed(seeds.get_name(), index) + + return Pointwise.create( + device=seeds.get_device(), + dtype=seeds.get_dtype(), + inner_fn=inner_fn, + ranges=[], + ) + + +@register_lowering(inductor_prims.random, type_promotion_kind=None) +def inductor_random(size: list[int], seed: TensorBox, mode: str, *, offset: int = 0): + assert not config.fallback_random + assert mode in ("rand", "randn") + size = [*size] + dtype = torch.float32 + device = seed.get_device_or_error() + random_pos = ir.FixedLayout( + device, dtype, size, ir.FlexibleLayout.contiguous_strides(size), offset=offset + ).make_indexer() + seed_loader = seed.make_loader() + + def inner_fn(index): + return getattr(ops, mode)( + seed_loader([]), + ops.index_expr(random_pos(index), torch.int32), + ) + + result = Pointwise.create( + device=device, + dtype=dtype, + inner_fn=inner_fn, + ranges=[*size], + ) + result.realize() + return result + + +@register_lowering(inductor_prims.randint, type_promotion_kind=None) +def inductor_randint( + low: int, high: int, size: list[int], seed: TensorBox, *, offset: int = 0 +): + assert not config.fallback_random + size = [*size] + dtype = torch.int64 + device = seed.get_device_or_error() + random_pos = ir.FixedLayout( + device, dtype, size, ir.FlexibleLayout.contiguous_strides(size), offset=offset + ).make_indexer() + seed_loader = seed.make_loader() + + def inner_fn(index): + return ops.randint64( + seed_loader([]), + ops.index_expr(random_pos(index), torch.int32), + ops.index_expr(low, torch.int64), + ops.index_expr(high, torch.int64), + ) + + return Pointwise.create( + device=device, + dtype=dtype, + inner_fn=inner_fn, + ranges=[*size], + ) + + +def _boundaries_helper(tb: TensorBox) -> tuple[str, sympy.Expr, sympy.Expr, sympy.Expr]: + # Calculate the maximum offset for the boundaries tensor + # For a strided tensor, this is sum((size[i] - 1) * stride[i]) + stride[-1] + # This ensures the mask check in 
bucketize_binary_search works correctly + # for both contiguous and non-contiguous tensors. + size = tb.get_size() + stride = tb.get_stride() + max_offset = sum((s - 1) * st for s, st in zip(size, stride)) + stride[-1] + return ( + tb.get_name(), + size[-1], + max_offset, + stride[-1], + ) + + +def _sorter_helper(tb: TensorBox) -> tuple[str, sympy.Expr]: + return tb.get_name(), tb.get_stride()[-1] + + +@register_lowering(aten.searchsorted.Tensor, type_promotion_kind=None) +def searchsorted( + sorted_sequence: TensorBox, + self: TensorBox, + *, + out_int32: bool = False, + right: bool = False, + side: Optional[str] = None, + sorter: Optional[TensorBox] = None, +) -> Union[TensorBox, ShapeAsConstantBuffer]: + validate_bucketize = lambda tb: V.graph.has_feature( # noqa: E731 + tb, BackendFeature.BUCKETIZE + ) + if ( + not validate_bucketize(sorted_sequence) + or not validate_bucketize(self) + or (sorter is not None and not validate_bucketize(sorter)) + ): + return fallback_handler(aten.searchsorted.Tensor, add_to_fallback_set=False)( + sorted_sequence, + self, + out_int32=out_int32, + right=right, + side=side, + sorter=sorter, + ) + + # If side is present, override the value of right if needed. This assumes that + # validation of the two options being non-contradictory is already done by the + # searchsorted meta-function. + if side is not None and side == "right": + right = True + + index_dtype = torch.int32 if out_int32 else torch.int64 + values_loader = self.make_loader() + + # The entire sorted_sequence tensor needs to be used by ops.bucketize, so we need to + # realize it into global memory; or in other words, we can't guarantee that + # sorted_sequence.get_name() (used below) will exist unless we call + # sorted_sequence.realize(). 
+ sorted_sequence.realize() + + if sorter is not None: + sorter.realize() + + if len(sorted_sequence.get_size()) == 1: + + def inner_fn(idx): + val = values_loader(idx) + return ops.bucketize( + val, + _boundaries_helper(sorted_sequence), + 0, + index_dtype, + right, + sorter=None if sorter is None else _sorter_helper(sorter), + sorter_indices=None if sorter is None else 0, + ) + + else: + + def inner_fn(idx): + val = values_loader(idx) + + # Get index to the beginning of the sorted sequence within a flattened + # version of the array. + def get_flattened_index(tb: TensorBox): + strides = tb.get_stride() + return ops.index_expr( + functools.reduce( + operator.add, (s * i for s, i in zip(strides[:-1], idx[:-1])) + ), + index_dtype, + ) + + return ops.bucketize( + val, + _boundaries_helper(sorted_sequence), + get_flattened_index(sorted_sequence), + index_dtype, + right, + sorter=None if sorter is None else _sorter_helper(sorter), + sorter_indices=None if sorter is None else get_flattened_index(sorter), + ) + + device = self.get_device() + result = Pointwise.create( + device=device, + dtype=index_dtype, + inner_fn=inner_fn, + ranges=self.shape, + ) + # see [NOTE: inductor bucketize realize] + result.realize() + + return result + + +@register_lowering( + aten.bucketize, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH +) +def bucketize( + input: TensorBox, + boundaries: TensorBox, + *, + out_int32: bool = False, + right: bool = False, +): + assert len(boundaries.get_size()) == 1 + + if not ( + V.graph.has_feature(input, BackendFeature.BUCKETIZE) + and V.graph.has_feature(boundaries, BackendFeature.BUCKETIZE) + ): + return fallback_handler(aten.bucketize.Tensor, add_to_fallback_set=False)( + input, boundaries, out_int32=out_int32, right=right + ) + + # The entire boundaries tensor needs to be used by ops.bucketize, so we + # need to realize it into global memory; or in other words, we can't + # guarantee that boundaries.get_name() (used below) will exist 
unless + # we call boundaries.realize(). + boundaries.realize() + device = input.get_device() + input_loader = input.make_loader() + + index_dtype = torch.int32 if out_int32 else torch.int64 + + def inner_fn(index): + val = input_loader(index) + indices = ops.bucketize( + val, + _boundaries_helper(boundaries), + 0, + index_dtype, + right, + ) + + return indices + + result = Pointwise.create( + device=device, + dtype=index_dtype, + inner_fn=inner_fn, + ranges=input.get_size(), + ) + + # [NOTE: inductor bucketize realize] + # bucketize_binary_search is relatively expensive, so we don't want to re-compute + # it unnecessarily. If we run bucketize() and then broadcast the result, we don't + # want this to be fused into a large number of duplicate bucketize() computations + # for each of the elements in the result. + # + # If no broadcasting occurs, fusions can still occur in scheduler.py + result.realize() + + return result + + +def require_dense(_, *args, **kwargs): + args, kwargs = pytree.tree_map_only( + ir.IRNode, ir.ExternKernel.require_stride1, (args, kwargs) + ) + return args, kwargs + + +def require_contiguous(_, *args, **kwargs): + args, kwargs = pytree.tree_map_only( + ir.IRNode, ir.ExternKernel.require_contiguous, (args, kwargs) + ) + return args, kwargs + + +def require_contiguous_strides(_, *args, **kwargs): + # TODO: combine this with require_contiguous after + # https://github.com/pytorch/pytorch/pull/148235 lands. 
+ args, kwargs = pytree.tree_map_only( + ir.IRNode, ir.ExternKernel.require_contiguous_strides, (args, kwargs) + ) + return args, kwargs + + +def require_channels_last(_, *args, **kwargs): + args, kwargs = pytree.tree_map_only( + ir.IRNode, ir.ExternKernel.require_channels_last, (args, kwargs) + ) + return args, kwargs + + +def constrain_to_fake_tensor(arg, fake_arg): + if fake_arg is None: + return arg + if isinstance(fake_arg, FakeScriptObject): + return arg + if isinstance(arg, ir.IRNode): + meta_stride_expr = [ + s.node.expr if isinstance(s, torch.SymInt) else s for s in fake_arg.stride() + ] + return ir.ExternKernel.require_exact_strides(arg, meta_stride_expr) + if isinstance(arg, dict): + return {key: constrain_to_fake_tensor(arg[key], fake_arg[key]) for key in arg} + elif isinstance(arg, (tuple, list)): + return type(arg)( + constrain_to_fake_tensor(a, f_a) for (a, f_a) in zip(arg, fake_arg) + ) + return arg + + +def constrain_to_fake_tensors(args, kwargs, fake_args, fake_kwargs): + args = tuple( + constrain_to_fake_tensor(arg, fake_arg) + for arg, fake_arg in zip(args, fake_args) + ) + kwargs = {k: constrain_to_fake_tensor(v, fake_kwargs[k]) for k, v in kwargs.items()} + return args, kwargs + + +def constrain_to_fx_strides(fx_node, *args, **kwargs): + def apply_constraint(arg, fx_arg): + if isinstance(arg, ir.IRNode): + stride_order = ir.get_stride_order( + fx_arg.meta["val"].stride(), V.graph.sizevars.shape_env + ) + return ir.ExternKernel.require_stride_order(arg, stride_order) + if isinstance(arg, dict): + return {key: apply_constraint(arg[key], fx_arg[key]) for key in arg} + return arg + + args = tuple( + apply_constraint(arg, fx_arg) for arg, fx_arg in zip(args, fx_node.args) + ) + kwargs = {k: apply_constraint(v, fx_node.kwargs[k]) for k, v in kwargs.items()} + return args, kwargs + + +def sdpa_constraint(fx_node, *args, **kwargs): + # sdpa requires dense last dimension] + + def apply_constraint(idx, arg, fx_arg): + if not isinstance(arg, ir.IRNode): 
+ return arg + + meta_val = fx_arg.meta["val"] + meta_stride_expr = [ + s.node.expr if isinstance(s, torch.SymInt) else s for s in meta_val.stride() + ] + shape_env = V.graph.sizevars.shape_env + stride_order = ir.get_stride_order(meta_val.stride(), shape_env) + + if stride_order and stride_order[-1] != 0: + # contiguous stride order + stride_order = list(reversed(range(len(arg.get_size())))) + + if ( + fx_node.target + == aten._scaled_dot_product_efficient_attention_backward.default + and idx in (0, 5) + ): + assert len(stride_order) == 4 + # The 0 and 5th arguments for aten._scaled_dot_product_efficient_attention_backward.default + # are for out and gradient_out. They have to be in + # (3, 1, 2, 0) stride order. Otherwise the kernel will crash. + # Check https://github.com/pytorch/pytorch/issues/138772 + stride_order = (3, 1, 2, 0) + + if not meta_val.is_cuda: + return ir.ExternKernel.require_stride_order(arg, stride_order) + + # This is the minimum alignment required by SDPA kernels for attention_bias. 
+ # This value can be found in pytorch/aten/src/ATen/native/transformers/attention.cpp preprocess_mask + ALIGNMENT = 8 + + # effn_attn_fwd does requires dense last dim, not just alignment + effn_attn_fwd_bias = ( + fx_node.target + == torch.ops.aten._scaled_dot_product_efficient_attention.default + and idx == 3 + ) + + assert isinstance(arg, TensorBox) + if len(arg.get_size()) not in (3, 4): + return arg + + is_aligned_tensor = ir.is_aligned_realized_tensor(arg, ALIGNMENT) + if is_aligned_tensor: + return ir.try_match_insignificant_strides( + ir.ExternKernel.realize_input(arg), meta_stride_expr + ) + + if ( + isinstance(arg, IRNode) + and arg.maybe_get_stride() is not None + and is_aligned_tensor + ): + return ir.try_match_insignificant_strides( + ir.ExternKernel.realize_input(arg), meta_stride_expr + ) + + if effn_attn_fwd_bias: + out_size = list(arg.get_size()) + + expanded_dims = [] + # We require a dense last dimension, but the other strides + # can be expanded, which results in a smaller tensor + maybe_stride = arg.maybe_get_stride() + for i in range(len(arg.get_size()) - 1): + if V.graph.sizevars.statically_known_equals(meta_stride_expr[i], 0) or ( + maybe_stride is not None + and V.graph.sizevars.statically_known_equals(maybe_stride[i], 0) + ): + expanded_dims.append(i) + + # Now, pad strides to alignment + out_strides = [-1] * len(out_size) + out_strides[-1] = 1 + stride = 1 + for i in range(len(out_size) - 2, -1, -1): + if out_strides[i + 1] != 0: + stride = stride * out_size[i + 1] + + # the expanded dims still need to be aligned, if they are, + # we can make them expanded by setting the stride equal to 0 + if i in expanded_dims: + if V.graph.sizevars.statically_known_equals( + out_strides[i + 1] % ALIGNMENT, 0 + ): + out_strides[i] = 0 + continue + + if not V.graph.sizevars.statically_known_equals(stride % ALIGNMENT, 0): + stride = ceildiv(stride, ALIGNMENT) * ALIGNMENT + + out_strides[i] = stride + + return ir.ExternKernel.require_exact_strides(arg, 
out_strides) + + if is_aligned_tensor: + return ir.try_match_insignificant_strides( + ir.ExternKernel.realize_input(arg), meta_stride_expr + ) + + if ( + isinstance(arg, IRNode) + and arg.maybe_get_stride() is not None + and is_aligned_tensor + ): + return ir.try_match_insignificant_strides( + ir.ExternKernel.realize_input(arg), meta_stride_expr + ) + + def is_aligned(x): + return V.graph.sizevars.guard_or_false( + sympy.Eq(Mod(x.get_size()[-1], ALIGNMENT), 0) + ) + + if isinstance(arg.data, ir.BaseView): + if not is_aligned(arg): + if is_aligned(arg.unwrap_view()): + return ir.try_match_insignificant_strides( + ir.ExternKernel.realize_input(arg), meta_stride_expr + ) + + return ir.ExternKernel.require_stride_order(arg, stride_order) + + args = tuple( + apply_constraint(idx, arg, fx_arg) + for idx, (arg, fx_arg) in enumerate(zip(args, fx_node.args)) + ) + kwargs = {k: apply_constraint(-1, v, fx_node.kwargs[k]) for k, v in kwargs.items()} + return args, kwargs + + +# WIP +make_fallback(aten._adaptive_avg_pool3d) # @isuruf +make_fallback(aten.adaptive_max_pool3d) # @isuruf +make_fallback(aten._scaled_dot_product_attention_math_for_mps) # @malfet + + +# 1) Easy +make_fallback(aten.uniform, warn=False) +make_fallback(aten.exponential.default, warn=False) # (fails accuracy on test_torch.py) +make_fallback(aten._pdist_forward) # Has decomp. Needs benchmarks +make_fallback(aten.soft_margin_loss_backward, warn=False) # py_impl? 
+make_fallback(aten._fused_rms_norm, warn=False) # (MPS-only and faster than decomp) +if torch.xpu.is_available(): + make_fallback( + aten.embedding_dense_backward, warn=False + ) # (XPU-only and faster than decomp) + +if torch.mtia._is_compiled(): + make_fallback( + aten.native_layer_norm, warn=False + ) # (MTIA-only and faster than decomp) + +# 1.5) Easy or Impossible +make_fallback(aten._cdist_forward) # p=2 should be feasible +make_fallback(aten._cdist_backward) + +# 2) Medium +make_fallback(aten._trilinear) + + +# 3) Difficult +# Scans +# See the discussion at +# https://dev-discuss.pytorch.org/t/pytorch-sparse-gnn-compiler-rfc/1644/19 +make_fallback(aten.segment_reduce.default) +make_fallback(aten._segment_reduce_backward.default) + +# Histogram (need to implement Histogram IR) +make_fallback(aten.histc) +make_fallback(aten.histogram.bin_ct) +make_fallback(aten._histogramdd_bin_edges.default) +make_fallback(aten._histogramdd_from_bin_cts.default) + +# Need templated kernel +make_fallback(aten.addbmm) +make_fallback(aten._addmm_activation, warn=False) + +make_fallback(aten._grouped_mm, require_dense) + +# Need templated kernel. 
Probably impossible to write efficiently +make_fallback(aten.convolution_backward, constrain_to_fx_strides) +make_fallback(aten._cudnn_rnn, require_dense) +make_fallback(aten._cudnn_rnn_backward, require_contiguous) + +# Haven't checked but sound difficult / impossible +make_fallback(aten._embedding_bag, require_contiguous) +make_fallback(aten._embedding_bag_forward_only, require_contiguous) +make_fallback(aten._embedding_bag_backward) +make_fallback(aten._embedding_bag_per_sample_weights_backward) +make_fallback(aten._embedding_bag_per_sample_weights_backward) +make_fallback(aten._fused_moving_avg_obs_fq_helper) +make_fallback(aten._fused_moving_avg_obs_fq_helper_functional) + + +# 4) Backwards (try py_impl'ing them) when fwd is written as a decomp +make_fallback(aten.max_pool3d_with_indices_backward) +make_fallback(aten._adaptive_avg_pool2d_backward, require_dense) +make_fallback(aten._adaptive_avg_pool3d_backward) +make_fallback(aten.adaptive_max_pool2d_backward) +make_fallback(aten.adaptive_max_pool3d_backward) +make_fallback(aten.fractional_max_pool2d_backward) +make_fallback(aten.fractional_max_pool3d_backward) +make_fallback(aten.replication_pad1d_backward) +make_fallback(aten.replication_pad2d_backward) +make_fallback(aten.upsample_linear1d_backward) +make_fallback(aten.upsample_bicubic2d_backward, require_contiguous) +make_fallback(aten.upsample_trilinear3d_backward) +make_fallback(aten.grid_sampler_2d_backward) +make_fallback(aten._pdist_backward) + + +# 5) Impossible (missing triton/CPU features) + +# Sorting / Sorting-like +make_fallback(aten.sort) +make_fallback(aten.sort.stable) +make_fallback(aten.kthvalue) +make_fallback(aten.topk) +make_fallback(aten.mode) +make_fallback(aten.median) +make_fallback(aten.nanmedian) +make_fallback(aten.randperm) +# see: https://github.com/pytorch/pytorch/pull/121354 +make_fallback(aten.resize_) +make_fallback(aten.resize_as_) + +# Linalg +make_fallback(aten._linalg_det) +make_fallback(aten.linalg_householder_product) 
+make_fallback(aten.linalg_inv_ex) +make_fallback(aten.linalg_ldl_factor_ex) +make_fallback(aten.linalg_ldl_solve) +make_fallback(aten.linalg_lu) +make_fallback(aten.linalg_lu_factor_ex) +make_fallback(aten.linalg_lu_solve) +make_fallback(aten.linalg_matrix_exp) +make_fallback(aten.linalg_qr) +make_fallback(aten._linalg_slogdet) +make_fallback(aten._linalg_solve_ex) +make_fallback(aten.linalg_solve_triangular) +make_fallback(aten._linalg_svd) +make_fallback(aten.lu_unpack) +make_fallback(aten.ormqr) +make_fallback(aten._linalg_check_errors) +make_fallback(aten.linalg_pinv.atol_rtol_tensor) +make_fallback(aten._linalg_eigh) +make_fallback(aten.triangular_solve) +make_fallback(aten.linalg_cholesky_ex) +make_fallback(aten.cholesky_inverse) +make_fallback(aten.cholesky_solve) +make_fallback(aten.geqrf) +make_fallback(aten._fft_r2c) # needs complex as well + +# Data dependent (are these necessary?) +make_fallback(aten.nonzero.default) + +# Misc +make_fallback(aten.gcd.default, warn=False) +make_fallback(aten._thnn_fused_lstm_cell, require_dense) +make_fallback(torch._prims.rng_prims.run_and_save_rng_state) +make_fallback(torch._prims.rng_prims.run_with_rng_state) +make_fallback(torch._prims.rng_prims.graphsafe_run_with_rng_state) + + +# Implemented / Half implemented +# Scans. 
Implemented for CUDA, missing CPU +make_fallback(aten.masked_scatter) +make_fallback(aten.masked_scatter_backward) + +# Complex number support +make_fallback(aten.view_as_complex, require_contiguous) +make_fallback(aten.angle) # needs complex + +# Needs efficentzerotensor +make_fallback(aten._efficientzerotensor) + +# Needs Sparse +make_fallback(aten._sparse_coo_tensor_with_dims_and_tensors) +make_fallback(aten.to_sparse) +make_fallback(aten._to_sparse) + +# Needs dimname support +make_fallback(aten.zeros.names) + +# 6) Pattern-matched +make_fallback( + aten._scaled_dot_product_efficient_attention.default, + sdpa_constraint, + warn=False, +) +make_fallback( + aten._scaled_dot_product_efficient_attention_backward.default, + sdpa_constraint, + warn=False, +) +make_fallback( + aten._scaled_dot_product_flash_attention.default, + sdpa_constraint, + warn=False, +) +make_fallback( + aten._scaled_dot_product_flash_attention_backward.default, + sdpa_constraint, + warn=False, +) +make_fallback( + aten._scaled_dot_product_cudnn_attention.default, + sdpa_constraint, + warn=False, +) +make_fallback( + aten._scaled_dot_product_cudnn_attention_backward.default, + sdpa_constraint, + warn=False, +) +make_fallback( + aten._scaled_dot_product_flash_attention_for_cpu.default, + sdpa_constraint, + warn=False, +) +make_fallback( + aten._scaled_dot_product_flash_attention_for_cpu_backward.default, + sdpa_constraint, + warn=False, +) +make_fallback( + aten._scaled_dot_product_fused_attention_overrideable.default, + sdpa_constraint, + warn=False, +) +make_fallback( + aten._scaled_dot_product_fused_attention_overrideable_backward.default, + sdpa_constraint, + warn=False, +) +make_fallback(aten._flash_attention_forward.default, sdpa_constraint) +make_fallback(aten._flash_attention_backward.default, sdpa_constraint) +make_fallback(aten._efficient_attention_forward.default, sdpa_constraint) +make_fallback(aten._efficient_attention_backward.default, sdpa_constraint) + +# index_reduce requires 
fallback when use_scatter_fallback(...) returns True +make_fallback(aten.index_reduce) +make_fallback(aten.repeat_interleave.Tensor, override_decomp=True) + +make_fallback(aten._weight_norm_interface_backward.default, require_contiguous) + + +# Register with type_promotion_kind None. +# For example, fp16.copy_(fp32) should **not** promote the first input's dtype. +@register_lowering(aten.copy, type_promotion_kind=None) +def copy(self, src, non_blocking=False): + if not isinstance(src, ir.IRNode): + src = tensor(src, dtype=self.get_dtype(), device=self.get_device()) + x = src + if self.get_device() != src.get_device(): + # pyrefly: ignore [bad-argument-type] + x = to_device(x, self.get_device()) + if self.get_dtype() != src.get_dtype(): + # pyrefly: ignore [bad-argument-type] + x = to_dtype(x, self.get_dtype()) + + if self.get_size() != src.get_size(): + out = expand(x, self.get_size()) + return clone(out) + return clone(x) + + +@register_lowering(aten.clone) +def clone(x, *, memory_format=None): + # TODO(jansel): memory format + return Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=x.make_loader(), + ranges=list(x.get_size()), + ) + + +def clone_preserve_reinterpret_view(x): + reinterpret_view_layouts = [] + if isinstance(x, TensorBox) and isinstance(x.data, ir.ReinterpretView): + x = x.data # unwrap TensorBox + # pyrefly: ignore [bad-assignment] + while isinstance(x, ir.ReinterpretView): + reinterpret_view_layouts.append(x.get_layout()) + x = x.data + x = TensorBox(x) + + x = clone(x) + + if reinterpret_view_layouts: + x = x.data # unwrap TensorBox + for layout in reinterpret_view_layouts[::-1]: + x = ir.ReinterpretView(data=x, layout=layout) + x = TensorBox(x) + + return x + + +if hasattr(aten, "lift_fresh_copy"): + register_lowering(aten.lift_fresh_copy)(clone) + + +@register_lowering(prims.iota) +def iota( + length, + *, + start, + step, + dtype, + device, + requires_grad, +): + def fn(index): + return ops.index_expr(step * 
index[0] + start, dtype=dtype) + + return Pointwise.create( + device=decode_device(device), + dtype=dtype, + inner_fn=fn, + ranges=[length], + ) + + +@register_lowering(aten.select_scatter, type_promotion_kind=None) +def select_scatter(x, src, dim: int, index: int): + assert x.get_dtype() == src.get_dtype() + x_loader = x.make_loader() + dim = _validate_dim(x, dim, 0) + if V.graph.sizevars.guard_or_false(sympy.Lt(index, 0)): + index = index + x.get_size()[dim] + elif V.graph.sizevars.guard_or_false(sympy.Ge(index, 0)): + pass + else: + # unbacked index + return fallback_handler(aten.select_scatter.default)(x, src, dim, index) + + V.graph.sizevars.check_leq(0, index) # type: ignore[arg-type] + V.graph.sizevars.check_lt(index, x.get_size()[dim]) # type: ignore[arg-type] + src = expand(unsqueeze(src, dim), x.get_size()) + src_loader = src.make_loader() + + def inner_fn(idx): + return ops.where( + ops.eq( + ops.index_expr(idx[dim], torch.int32), + ops.index_expr(index, torch.int32), + ), + src_loader(idx), + x_loader(idx), + ) + + return Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=inner_fn, + ranges=list(x.get_size()), + ) + + +@register_lowering(aten.slice_scatter, type_promotion_kind=None) +def slice_scatter(x, src, dim=0, start=None, end=None, step=1): + src = to_dtype(src, x.get_dtype()) + x_loader = x.make_loader() + dim = _validate_dim(x, dim, 0) + dim_size = x.get_size()[dim] + + # pyrefly: ignore [bad-argument-type] + start, end = ir.SliceView.normalize_start_end(x, dim, start, end) + + src_size = list(x.get_size()) + src_size[dim] = FloorDiv(end - start + (step - 1), step) + src = expand(src, src_size) + src_loader = src.make_loader() + + def inner_fn(idx): + if start == 0 and end == dim_size and step == 1: + # selecting every element is the same as just src.clone() + return src_loader(idx) + + idx_dim = ops.index_expr(idx[dim], torch.int64) + src_idx = list(idx) + src_idx[dim] = FloorDiv(idx[dim] - start, step) + + mask = [] + 
if start != 0: + mask.append( + ops.ge( + idx_dim, + ops.index_expr(sympy.expand(start), torch.int64), + ) + ) + if end != dim_size: + mask.append( + ops.lt( + idx_dim, + ops.index_expr(sympy.expand(end), torch.int64), + ) + ) + if step != 1: + mask.append( + ops.eq( + ops.index_expr( + ModularIndexing(idx[dim] - start, 1, step), torch.int64 + ), + ops.constant(0, torch.int64), + ) + ) + assert mask + mask = functools.reduce(ops.and_, mask) + src_val = ops.masked( + mask, + lambda: src_loader(src_idx), + 0 if is_integer_type(x) else 0.0, + ) + return ops.where( + mask, + src_val, + x_loader(idx), + ) + + return Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=inner_fn, + ranges=list(x.get_size()), + ) + + +def _unwrap(x): + if isinstance(x, (list, tuple)) and len(x) > 0: + return _unwrap(x[0]) + return x + + +@register_lowering([torch.tensor, aten.scalar_tensor]) +def tensor(data, *, dtype=None, device=None, layout=None, pin_memory=False): + assert_nyi(layout in (None, torch.strided), f"layout={layout}") + assert_nyi(not pin_memory, "pin_memory") + if isinstance(_unwrap(data), int): + dtype = dtype or torch.int64 + else: + dtype = dtype or torch.get_default_dtype() + + ranges: list[sympy.Expr] = [] + + if isinstance(data, sympy.Basic): + + def inner_fn(index): + return ops.index_expr(data, dtype) + + elif isinstance(data, (float, int)): + + def inner_fn(index): + return ops.constant(data, dtype) + + elif len(data) == 0 or isinstance(data[0], (float, int)) and len(data) <= 8: + # inline small tensors + ranges.append(sympy.Integer(len(data))) + + def inner_fn(index): + def binary_search(start, end): + assert start < end + if end - start == 1: + return ops.constant(data[start], dtype) + mid = (end - start) // 2 + start + return ops.where( + ops.lt( + ops.index_expr(index[0], torch.int64), + ops.constant(mid, torch.int64), + ), + binary_search(start, mid), + binary_search(mid, end), + ) + + if len(data) == 0: + return ops.constant(0, dtype) 
+ return binary_search(0, len(data)) + + else: + return V.graph.add_tensor_constant( + torch.tensor(data, dtype=dtype, device=device) + ) + + return Pointwise.create( + device=decode_device(device), + dtype=dtype, + inner_fn=inner_fn, + ranges=ranges, + ) + + +@register_lowering(torch.as_tensor) +def as_tensor(data, dtype=None, device=None): + if isinstance(data, TensorBox): + if dtype is not None: + data = to_dtype(data, dtype) + if device is not None: + data = to_device(data, device) + return data + return tensor(data, dtype=dtype, device=device) + + +@register_lowering(torch.LongTensor) +def long_tensor(data): + return tensor(data, dtype=torch.int64) + + +@register_lowering(aten._local_scalar_dense) +def _local_scalar_dense(data): + # This is interesting! Most lowerings return tensors, so you can just + # return the buffer you allocated and it will get used (or not used, if + # it's dead.) But _local_scalar_dense (aka item) returns an int, + # not a Tensor, so you would have a type mismatch if you return a buffer; + # we are obligated to return a sympy expression instead. However, + # we need to actually codegen the .item() call somehow. We do this + # by registering a faux buffer for the DynamicScalar IR node, which is + # solely responsible for generating this .item(). The buffer is + # not used for anything (notice we discard it); at codegen time, + # the "buffer" just gets assigned None. + unbacked_bindings = resolve_unbacked_bindings( + V.graph.sizevars.shape_env, V.graph.current_node.meta["unbacked_bindings"] + ) + assert unbacked_bindings is not None + assert len(unbacked_bindings) == 1, unbacked_bindings + # NB: Have to be very careful here. V.graph.current_node.meta["val"] + # seemingly also contains a symbol which you want to do binding for, + # but it actually isn't. In particular, if we have later performed + # a deferred runtime assert saying that u0 == s0, you will actually + # see s0 from expr! 
This is bad because we need to actually generate + # the assert that says u0 == s0, so we need to know where to get u0 + # from (this call). In particular, we must use unbacked_bindings, which + # is guaranteed to have the original, unreplaced symbol in question. + # + # NB2: Another thing we have to be very careful about are symbol bindings + # that require nontrivial refinement, e.g., when you have a binding site + # x: Sym(u0 * 4) = y.item(). Here, the code generation must do a division + # in order to appropriately bind u0. This is communicated via the keypath + # in unbacked_bindings, and we need to hold onto it in order to generate + # code appropriately for this case. + binding_sym, keypath = next(iter(unbacked_bindings.items())) + buffer = ir.DynamicScalar(binding_sym, keypath, data) + buffer.name = V.graph.register_buffer(buffer) + V.graph.register_operation(buffer) + # NB: the replaced expr is OK to use directly downstream, we want + # simplifications in this case! + val = V.graph.current_node.meta["val"] + if isinstance(val, (torch.SymInt, torch.SymFloat, torch.SymBool)): + return val.node.expr + else: + return sympy.sympify(val) + + +@register_lowering(aten._assert_scalar) +def _assert_scalar(data, msg): + # NB: These will be handled at codegen time + # Not sure if we are guaranteed to be able to serve out truth from the + # deferred_runtime_asserts, TODO: try this assert out + # See [NOTE] Codegen runtime asserts in Inductor + # assert bool(data.scalar), data + return None + + +@register_lowering(aten._assert_tensor_metadata) +def _assert_tensor_metadata( + a, size=None, stride=None, dtype=None, *, device=None, layout=None +): + return None + + +def _full(fill_value, device, dtype, size): + value = fill_value + if not isinstance(fill_value, (int, float)) and hasattr(value, "value"): + value = value.value + + if isinstance(value, (int, float)): + + def inner_fn(index): + return ops.constant(value, dtype) + + elif isinstance(value, sympy.Basic): + + def 
inner_fn(index): + return ops.index_expr(value, dtype) + + else: + assert len(value.get_size()) == 0 + value_loader = value.make_loader() + + def inner_fn(index): + return value_loader([]) + + return Pointwise.create( + device=device, + dtype=dtype, + inner_fn=inner_fn, + ranges=list(size), + ) + + +def full_like(x, fill_value, **kwargs): + return create_tensor_like(tensor_constructor(fill_value))(x, **kwargs) + + +def tensor_constructor(fill_value): + # torch.zeros, torch.ones, etc + def inner( + *size, + names=None, + dtype=None, + device=None, + layout=None, + pin_memory=False, + memory_format=None, + ): + assert_nyi(names is None, "named tensors") + assert_nyi(layout in (None, torch.strided), f"layout={layout}") + assert_nyi(not pin_memory, "pin_memory") + device = decode_device(device) + dtype = dtype or torch.get_default_dtype() + if len(size) == 1 and isinstance(size[0], (list, tuple, torch.Size)): + size = tuple(size[0]) + # See https://github.com/pytorch/pytorch/issues/118102 + # All sizes at lowering time should be sympy.Symbol, not SymInt! + for s in size: + assert not isinstance(s, torch.SymInt) + size = [sympy.expand(s) for s in size] + return _full(fill_value, device, dtype, size) + + return inner + + +@register_lowering([torch.empty, aten.empty]) +def empty( + *size, + names=None, + dtype=None, + layout=None, + device=None, + pin_memory=None, + memory_format=None, +): + assert_nyi(names is None, "named tensors") + device = decode_device(device) + if len(size) == 1 and isinstance(size[0], (list, tuple, torch.Size)): + size = tuple(size[0]) + return empty_strided( + size, None, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory + ) + + +def create_tensor_like(creation_fn): + """ + Shim to convert X_like(...) into X(...). For example zeros_like() into zeros(). 
+ """ + + def _constant_like( + x, *, dtype=None, device=None, layout=None, pin_memory=False, memory_format=None + ): + assert_nyi(not pin_memory, "pin_memory") + assert_nyi(layout in (None, torch.strided), f"layout={layout}") + if dtype is None: + dtype = x.get_dtype() + else: + dtype = decode_dtype(dtype) + device = device or x.get_device() + size = list(x.get_size()) + return creation_fn( + size, dtype=dtype, device=device, layout=layout, pin_memory=pin_memory + ) + + return _constant_like + + +def constant_like(fill_value): + return create_tensor_like(tensor_constructor(fill_value)) + + +empty_like = register_lowering(aten.empty_like)(create_tensor_like(empty)) +ones_like = create_tensor_like(tensor_constructor(1)) +zeros_like = create_tensor_like(tensor_constructor(0)) + + +def new_constant(fill_value): + def _new_constant( + x, size, *, dtype=None, layout=None, device=None, pin_memory=None + ): + assert isinstance(size, (list, tuple)) + assert_nyi(not pin_memory, "pin_memory") + assert_nyi(layout in (None, torch.strided), f"layout={layout}") + # pyrefly: ignore [bad-argument-type] + dtype = decode_dtype(dtype) or x.get_dtype() + device = device or x.get_device() + size = [sympy.Integer(s) for s in size] + return _full(fill_value, decode_device(device), dtype, size) + + return _new_constant + + +@register_lowering(aten.new_empty) +def new_empty(x, size, *, dtype=None, layout=None, device=None, pin_memory=None): + if dtype is None: + dtype = x.get_dtype() + if device is None: + device = x.get_device() + return empty_strided( + size, + None, + dtype=dtype, + layout=layout, + device=decode_device(device), + pin_memory=pin_memory, + ) + + +@register_lowering(aten.empty_strided) +def empty_strided( + size, stride, *, dtype=None, layout=None, device=None, pin_memory=None +): + assert isinstance(size, (list, tuple)) + assert isinstance(stride, (list, tuple, type(None))) + assert_nyi(not pin_memory, "pin_memory") + assert_nyi(layout in (None, torch.strided), 
f"layout={layout}") + # pyrefly: ignore [bad-argument-type] + dtype = decode_dtype(dtype) or torch.get_default_dtype() + device = device or torch.tensor(0.0).device + device = decode_device(device) + pointwise = _full(fill_value=0, device=device, dtype=dtype, size=size) + pointwise.realize() + buffer = pointwise.data.data + # explicitly set ranges to zeros in order to make a NopKernelSchedulerNode + buffer.data = dataclasses.replace(buffer.data, ranges=[0] * len(size)) + assert isinstance(buffer, ir.ComputedBuffer) + size = [sympy.expand(s) for s in size] + stride = ( + [sympy.expand(s) for s in stride] + if stride + else ir.FlexibleLayout.contiguous_strides(size) + ) + buffer.layout = ir.FixedLayout( + device=device, + dtype=dtype, + size=size, + stride=stride, + ) + return pointwise + + +@register_lowering(aten.new_empty_strided) +def new_empty_strided( + x, size, stride, *, dtype=None, layout=None, device=None, pin_memory=None +): + if dtype is None: + dtype = x.get_dtype() + if device is None: + device = x.get_device() + return empty_strided( + size, + stride, + dtype=dtype, + layout=layout, + device=decode_device(device), + pin_memory=pin_memory, + ) + + +@register_lowering(prims.copy_strided.default) +def copy_strided(x, stride): + stride = [V.graph.sizevars.size_hint_or_throw(s) for s in stride] + stride_order = sorted(range(len(stride)), key=stride.__getitem__) + return ir.ExternKernel.require_stride_order(x, stride_order) + + +@register_lowering([torch.full, aten.full]) +def full(size, fill_value, **kwargs): + assert kwargs.get("dtype") is not None, "dtype should be handled by decomposition" + return tensor_constructor(fill_value)(size, **kwargs) + + +@register_lowering(aten.gather, type_promotion_kind=None) +def gather(x, dim, index, sparse_grad=False): + # sparse_grad doesn't affect forward computation, + # and backward tracing is taken care of by AOT Autograd + assert isinstance(x, TensorBox) + if index.get_numel() == 0: + # Empty index case. 
Return an empty array with the same shape + return new_empty(x, index.get_size()) + + size = x.get_size() + offset = len(size) == 0 + dim = _validate_dim(x, dim, offset) + + if offset: + x = expand(x, [1]) + size = [1] + + x_loader = x.make_loader() + index_loader = index.make_loader() + + def fn(idx): + idx = list(idx) + gather_idx = ops.indirect_indexing(index_loader(idx), size[dim]) + if len(idx) == 0: + idx = [gather_idx] + else: + idx[dim] = gather_idx + return x_loader(idx) + + return Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=fn, + ranges=index.get_size(), + ) + + +@register_lowering(aten.embedding, type_promotion_kind=None) +def embedding(weight, indices, padding_idx=-1, scale_grad_by_freq=False, sparse=False): + if sparse: + return fallback_handler(aten.embedding.default)( + weight, indices, padding_idx, scale_grad_by_freq, sparse + ) + + assert not sparse + assert isinstance(weight, TensorBox) + assert isinstance(indices, TensorBox) + assert "int" in str(indices.get_dtype()) + + weight_loader = weight.make_loader() + indices_loader = indices.make_loader() + indices_ndim = len(indices.get_size()) + weight_size = weight.get_size() + new_size = [*indices.get_size(), *weight_size[1:]] + + def fn(idx): + assert len(idx) == len(new_size), f"{idx} != {new_size}" + var_index = indices_loader(idx[:indices_ndim]) + weight_idx = [ops.indirect_indexing(var_index, weight_size[0])] + [ + *idx[indices_ndim:] + ] + return weight_loader(weight_idx) + + return Pointwise.create( + device=weight.get_device(), + dtype=weight.get_dtype(), + inner_fn=fn, + ranges=new_size, + ) + + +def check_and_broadcast_indices(indices, device): + assert all( + i.get_dtype() in (torch.int64, torch.int32, torch.bool, torch.uint8) + for i in indices + if i is not None + ), ( + f"indices must be int64, byte or bool. 
Got {[i.get_dtype() for i in indices if i is not None]}" + ) + if any( + i.get_dtype() in (torch.bool, torch.uint8) for i in indices if i is not None + ): + raise NotImplementedError("Fallback for bool indices") + + valid_idxs = [i for i, x in enumerate(indices) if isinstance(x, TensorBox)] + assert len(valid_idxs) > 0, "requires at least 1 non-None index" + new_indices = [None] * len(indices) + for i, x in zip(valid_idxs, broadcast_tensors(*[indices[i] for i in valid_idxs])): + # Eager allows indices to be CPU tensor when running on CUDA + # FIXME: Calling to_device(x, device) should work but + # test_advancedindex_mixed_cpu_devices still fails + if x.get_device() != device: + raise NotImplementedError("Fallback when indices is on a different device") + new_indices[i] = x + return new_indices, valid_idxs + + +def index_output_size_and_inner_fn( + x_size, + indices, + tensor_indices, + tensor_size, + indices_loaders, + indexed_size, + x_loader, + check, + wrap_neg=True, +): + # Note that behavior of indexing differs when there are non consecutive + # tensors. In this case, the tensor index is pulled to the beginning. + # + # Suppose a = torch.arange(3 * 4 * 5 * 6 * 7).view(3, 4, 5, 6, 7) + # x = torch.tensor[1,2] + # Then, a[:,x,:,x,:] will have shape 2,3,5,7 as due to x,:,x then 2 will + # be pulled to the front. 
+ non_consecutive_tensors = False + for previous, current in itertools.pairwise(tensor_indices): + if current - previous != 1: + non_consecutive_tensors = True + + output_size = [x_size[i] for i, val in enumerate(indices) if val is None] + output_size = [*output_size, *x_size[len(output_size) + len(tensor_indices) :]] + + first_tensor_index = tensor_indices[0] + if non_consecutive_tensors: + output_size = tensor_size + output_size + else: + output_size = ( + output_size[:first_tensor_index] + + tensor_size + + output_size[first_tensor_index:] + ) + + def fn(idx): + assert len(idx) == len(output_size) + assert len(indices_loaders) == len(indexed_size) + + rank = len(tensor_size) + new_index = [] + first_tensor_index = tensor_indices[0] + start_offset = 0 if non_consecutive_tensors else first_tensor_index + next_idx = 0 + for i in range(tensor_indices[-1] + 1): + if i == start_offset: + next_idx += rank + if indices[i] is None: + assert next_idx < len(idx) + new_index.append(idx[next_idx]) + next_idx += 1 + else: + loader = indices_loaders[i] + assert loader is not None + size = indexed_size[i] + new_index.append( + ops.indirect_indexing( + loader(idx[start_offset : start_offset + rank]), + size, + check=check, + wrap_neg=wrap_neg, + ) + ) + new_index = [ + *new_index, + *idx[next_idx:], + ] + return new_index if x_loader is None else x_loader(new_index) + + return output_size, fn + + +def index_impl(x, indices, check): + output_size, inner_fn, _ = index_impl_helper(x, indices, check) + + return Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=inner_fn, + ranges=output_size, + ) + + +def index_impl_helper(x, indices, check, wrap_neg=True): + assert isinstance(indices, (list, tuple)) + x_loader = x.make_loader() + indices, tensor_indices = check_and_broadcast_indices(indices, x.get_device()) + assert len(tensor_indices) > 0, "Must have at least one valid idx" + + indices_loaders = [i.make_loader() if i is not None else None for i in 
indices] + # no guards on output size, all the guards are set in broadcast_tensors + + # We can use the first one since they are all required to be the same size + tensor_size = list(indices[tensor_indices[0]].get_size()) + + x_size = x.get_size() + + indexed_size = [x_size[i] for i in range(len(indices)) if indices[i] is not None] + if check and 0 in indexed_size and 0 not in tensor_size: + raise IndexError("index is out of bounds for dimension with size 0") + + indexed_size = [x_size[i] for i in range(len(indices))] + output_size, index_inner_fn = index_output_size_and_inner_fn( + x_size, + indices, + tensor_indices, + tensor_size, + indices_loaders, + indexed_size, + None, + check=check, + wrap_neg=wrap_neg, + ) + + def inner_fn(idx): + return x_loader(index_inner_fn(idx)) + + return output_size, inner_fn, index_inner_fn + + +@register_lowering(aten.index, type_promotion_kind=None) +def index(x, indices): + try: + return index_impl(x, indices, check=True) + except NotImplementedError: + # Fallback to ATen for boolean indexing + x.realize() + return fallback_handler(aten.index.Tensor, add_to_fallback_set=False)( + x, indices + ) + + +@register_lowering(aten._unsafe_index, type_promotion_kind=None) +def _unsafe_index(x, indices): + return index_impl(x, indices, check=False) + + +# All the indexing decompositions are written in terms of index, index_put, and index_put_ +# We cannot have this lowering as a decomposition as it introduces +# mutation in the graph, which is bad for Aot Autograd. Aot Autograd runs dead +# code elimination and common subexpression elimination optimizations, which +# assume graphs to be side-effect free. 
More details at +# https://github.com/pytorch/torchdynamo/issues/1235 +# and +# https://github.com/pytorch/torchdynamo/issues/1863 +@register_lowering(aten.index_put, type_promotion_kind=None) +def index_put(x, indices, values, accumulate=False): + return index_put_impl_( + clone(x), indices, values, accumulate, check=True, may_realize=False + ) + + +@register_lowering(aten._unsafe_index_put) +def _unsafe_index_put(x, indices, values, accumulate=False): + return index_put_impl_( + clone(x), indices, values, accumulate, check=False, may_realize=False + ) + + +def index_put_as_masked_fill(self, indices, value, accumulate): + if value.get_device() != self.get_device(): + value = to_device(value, self.get_device()) + if accumulate: + value = add(self, value) + return mutate_to(self, where(indices[0], value, self)) + + +def index_put_fallback(self, indices, values, accumulate): + from .utils import _fx_node_is_input_dependent_cudagraph_unsafe + + op_overload = getattr(aten.index_put_, V.graph.current_node.target._overloadname) # type: ignore[union-attr] + + # Check if any index is a boolean tensor - if so, mark as cudagraph-unsafe + # because boolean indices trigger .nonzero() during CUDA graph capture + # When graph_partition is enabled, skip - partitioning handles this + fx_node = V.graph.current_node + if ( + not config.graph_partition + and fx_node is not None + and _fx_node_is_input_dependent_cudagraph_unsafe(fx_node) + ): + msg = "index_put_ fallback with boolean indexing is not compatible with CUDA graphs" + if stack_trace := fx_node.meta.get("stack_trace", None): + msg = f"{msg} Found from : \n {stack_trace}" + V.graph.disable_cudagraphs_reason = msg + + ir.IndexPutFallback(op_overload, self, indices, values, accumulate) + return self + + +@register_lowering(aten.index_put_, type_promotion_kind=None) +def index_put_(self, indices, values, accumulate=False): + return index_put_impl_( + self, indices, values, accumulate, check=True, may_realize=True + ) + + 
+@register_lowering(inductor_prims._unsafe_index_put_, type_promotion_kind=None) +def _unsafe_index_put_(self, indices, values, accumulate=False): + return index_put_impl_( + self, indices, values, accumulate, check=False, may_realize=True + ) + + +def index_put_impl_(self, indices, values, accumulate, check, may_realize=False): + if may_realize: + + def indice_slice_from_randperm(indice): + # Refer to: https://github.com/pytorch/pytorch/pull/139366#discussion_r1825424660 + # For this specific pattern, indices is unique as coming from torch.randperm. + # However, as the content of the indices is unknown, we have to check this specific pattern. + if isinstance(indice, TensorBox) and isinstance(indice.data, ir.BaseView): + indice = indice.data.unwrap_view() + return ( + isinstance(indice, ir.StorageBox) + and isinstance(indice.data, ir.ExternKernel) + and getattr(indice.data, "fx_node", None) + and indice.data.fx_node.target is torch.ops.aten.randperm.default + ) + return False + + if ir.try_get_name(self) in values.get_read_names() and not all( + indice_slice_from_randperm(indice) for indice in indices + ): + # Fix issue: https://github.com/pytorch/pytorch/issues/138908 + # When self and values have memory overlapping, indices may + # contain duplicate values, potentially causing incorrect results since + # the load of `values` might contain modified value from the store of `self`. + # To address this, store values in a temporary buffer in such cases. 
+ values.realize() + + # Dispatch to masked fill for single boolean index with single value + if ( + values.get_numel() == 1 + and len(indices) == 1 + and indices[0].get_dtype() in (torch.bool, torch.uint8) + ): + mask = indices[0] + for _ in range(len(mask.get_size()), len(self.get_size())): + mask = unsqueeze(mask, -1) + return index_put_as_masked_fill(self, [mask], values, accumulate) + + # Fallback in torch deterministic mode + if torch.are_deterministic_algorithms_enabled(): + return index_put_fallback(self, indices, values, accumulate) + + # Fallback if there is a boolean index + for index in indices: + if index is not None and index.get_dtype() in (torch.bool, torch.uint8): + return index_put_fallback(self, indices, values, accumulate) + + x_size = self.get_size() + x_ndim = len(x_size) + + if accumulate and needs_fallback_due_to_atomic_add_limitations(self.get_dtype()): + # self is an scalar Tensor + if x_ndim == 0: + self = view(self, [1]) + self = index_put_fallback(self, indices, values, accumulate) + if x_ndim == 0: + self = view(self, []) + return self + + values = to_dtype(values, self.get_dtype()) + + try: + # Note that code will only get here when dtype is uint32 + indices, tensor_indices = check_and_broadcast_indices( + indices, self.get_device() + ) + except NotImplementedError: + return index_put_fallback(self, indices, values, accumulate) + + indices_loaders = [i.make_loader() if i is not None else None for i in indices] + + assert isinstance(self, TensorBox) + self.realize() + + # self is an scalar Tensor + if x_ndim == 0: + self = view(self, [1]) + + # We can use the first one since they are all required to be the same size + tensor_size = list(indices[tensor_indices[0]].get_size()) + indexed_size = [x_size[i] for i in range(len(indices))] + + expected_vals_size, inner_fn = index_output_size_and_inner_fn( + x_size, + indices, + tensor_indices, + tensor_size, + indices_loaders, + indexed_size, + None, + check=check, + ) + + values = 
expand(values, expected_vals_size) + # all guards are set above during broadcast_tensors and expand + + device = self.get_device() + assert device is not None + scatter = ir.Scatter( + device=device, + dtype=self.get_dtype(), + inner_fn=values.make_loader(), + ranges=expected_vals_size, # iter_ranges, + output_indexer=inner_fn, + scatter_mode="atomic_add" if accumulate else None, + ) + buffer = ir.ComputedBuffer( + name=None, + layout=ir.MutationLayoutSHOULDREMOVE(self), + data=scatter, + ) + buffer.name = V.graph.register_buffer(buffer) + V.graph.register_operation(buffer) + + if x_ndim == 0: + self = view(self, []) + return self + + +fallback__unsafe_masked_index = fallback_handler( + aten._unsafe_masked_index.default, add_to_fallback_set=False +) + +fallback__unsafe_masked_index_put_accumulate = fallback_handler( + aten._unsafe_masked_index_put_accumulate.default, add_to_fallback_set=False +) + + +@register_lowering(aten._unsafe_masked_index, type_promotion_kind=None) +def _unsafe_masked_index(self, mask, indices, fill): + ranges, _, _unsafe_index_fn = index_impl_helper( + self, indices, check=False, wrap_neg=False + ) + mask_loader = mask.make_loader() + self_loader = self.make_loader() + + def inner_fn(idx): + if mask.dtype != torch.bool: + mask_val = ops.to_dtype(mask_loader(idx), torch.bool) + else: + mask_val = mask_loader(idx) + return ops.masked(mask_val, lambda: self_loader(_unsafe_index_fn(idx)), fill) + + return Pointwise.create( + device=self.get_device(), + dtype=self.get_dtype(), + inner_fn=inner_fn, + ranges=ranges, + ) + + +@register_lowering(aten._unsafe_masked_index_put_accumulate, type_promotion_kind=None) +def _unsafe_masked_index_put_accumulate(x, mask, indices, values): + masked_value = where(mask, values, 0) + shape = x.get_size() + clamped_indices = [ + clamp(indices[i], -shape[i], shape[i] - 1) if indices[i] else None + for i in range(len(indices)) + ] + # TODO: use a masked store for this. 
currently only triton + # supports masked stores and cpp backend does not. + return _unsafe_index_put(x, clamped_indices, masked_value, accumulate=True) + + +@make_pointwise +def clamp(a, min, max): + return ops.maximum(min, ops.minimum(max, a)) + + +@register_lowering(aten.as_strided_scatter, type_promotion_kind=None) +def as_strided_scatter(self, src, size, stride, storage_offset=None): + output = clone(self) + output_view = as_strided(output, size, stride, storage_offset) + copy_(output_view, src) + return output + + +@register_lowering(aten.scatter, type_promotion_kind=None) +def scatter(x, dim: int, index, src, **kwargs): + return scatter_(clone(x), dim, index, src, **kwargs) + + +def scatter_fallback( + op_overload: torch._ops.OpOverload, + self, + dim: int, + index, + src, + *, + reduce: Optional[str] = None, + include_self: bool = True, +): + src_is_tensor = isinstance(src, TensorBox) + if use_scatter_fallback( + op_overload, + reduce, + self.get_dtype(), + cast(torch.dtype, src.get_dtype() if src_is_tensor else type(src)), + src.get_device().type if src_is_tensor else "not impl", + src_is_tensor, + ): + ir.ScatterFallback( + op_overload, + self, + dim, + index, + src, + reduce=reduce, + include_self=include_self, + ) + return self + + return None + + +@register_lowering(aten.scatter_, type_promotion_kind=None) +def scatter_(self, dim: int, index, src, *, reduce: Optional[str] = None): + assert reduce in (None, "add", "multiply") + if reduce is None: + op_overload = getattr(aten.scatter_, V.graph.current_node.target._overloadname) # type: ignore[union-attr] + fallback_result = scatter_fallback( + op_overload, self, dim, index, src, reduce=reduce + ) + if fallback_result is not None: + return fallback_result + + if reduce == "add": + reduce = "sum" + elif reduce == "multiply": + reduce = "prod" + return scatter_reduce_(self, dim, index, src, reduce) + + +@register_lowering(aten.scatter_add, type_promotion_kind=None) +def scatter_add(x, dim: int, index, src): 
+ return scatter_add_(clone(x), dim, index, src) + + +@register_lowering(aten.scatter_add_, type_promotion_kind=None) +def scatter_add_(x, dim: int, index, src): + return scatter_reduce_(x, dim, index, src, "sum") + + +@register_lowering(aten.scatter_reduce, type_promotion_kind=None) +def scatter_reduce(x, dim: int, index, src, reduction_type, **kwargs): + return scatter_reduce_(clone(x), dim, index, src, reduction_type, **kwargs) + + +@register_lowering(aten.scatter_reduce_, type_promotion_kind=None) +def scatter_reduce_(self, dim: int, index, src, reduce, *, include_self: bool = True): + assert reduce in (None, "sum", "prod", "mean", "amax", "amin") + assert ( + len(aten.scatter_reduce_.overloads()) == 1 + and "two" in aten.scatter_reduce_.overloads() + ), "aten.scatter_reduce_.two is not the unique overload of aten.scatter_reduce_" + + if isinstance(src, Number): + src = full_like(self, src) + + fallback_result = scatter_fallback( + aten.scatter_reduce_.two, + self, + dim, + index, + src, + reduce=reduce, + include_self=include_self, + ) + + if fallback_result: + return fallback_result + + assert isinstance(self, TensorBox) + assert "int" in str(index.get_dtype()) + + ndim = len(self.get_size()) + if ndim == 0: + self = view(self, [1]) + + if isinstance(src, TensorBox) and len(src.get_size()) == 0: + src = view(src, [1]) + + if isinstance(index, TensorBox) and len(index.get_size()) == 0: + index = view(index, [1]) + + if index.get_numel() == 0: + return self + + dim = _validate_dim(self, dim) + + self.realize() + index_loader = index.make_loader() + src_loader = src.make_loader() if isinstance(src, TensorBox) else None + + def output_indexer(idx): + # self is captured from the end of the function, so it may have 0 dim + shape = self.get_size() + ndim = len(shape) + indirect_idx = list(idx) + indirect_idx[dim] = ops.indirect_indexing( + index_loader(idx), 1 if ndim == 0 else shape[dim], wrap_neg=False + ) + return indirect_idx + + def fn(idx): + if src_loader: + 
return src_loader(idx) + else: + # src is a scalar + # pyrefly: ignore [bad-argument-type] + return ops.constant(src, self.get_dtype()) + + def backend_reduce_str(reduce): + if reduce == "sum": + return "atomic_add" + else: + # TODO: Need to support more reduction type + assert reduce is None + return None + + device = self.get_device() + assert device is not None + + if not include_self: + # zero out the corresponding elements first + zero_out = ir.Scatter( + device=device, + dtype=self.get_dtype(), + inner_fn=lambda index: ops.constant(0, self.get_dtype()), + ranges=index.get_size(), + output_indexer=output_indexer, + scatter_mode=None, + ) + buffer = ir.ComputedBuffer( + name=None, + layout=ir.MutationLayoutSHOULDREMOVE(self), + data=zero_out, + ) + buffer.name = V.graph.register_buffer(buffer) + V.graph.register_operation(buffer) + + # self[index[i][j][k]][j][k] += src[i][j][k] # if dim == 0 + # self[i][index[i][j][k]][k] += src[i][j][k] # if dim == 1 + # self[i][j][index[i][j][k]] += src[i][j][k] # if dim == 2 + scatter = ir.Scatter( + device=device, + dtype=self.get_dtype(), + inner_fn=fn, + ranges=index.get_size(), + output_indexer=output_indexer, + scatter_mode=backend_reduce_str(reduce), + ) + buffer = ir.ComputedBuffer( + name=None, + layout=ir.MutationLayoutSHOULDREMOVE(self), + data=scatter, + ) + buffer.name = V.graph.register_buffer(buffer) + V.graph.register_operation(buffer) + + if ndim == 0: + self = view(self, []) + return self + + +def upsample_nearestnd( + x, + output_size, + scales_x: tuple[Optional[float], ...], + n: int = 2, + exact: bool = False, +): + x.realize_hint() # elements are reused + x_loader = x.make_loader() + i_sizes = x.get_size()[-n:] + batch = x.get_size()[:-n] + i_sizes = [V.graph.sizevars.guard_int(i) for i in i_sizes] + + assert len(scales_x) == n + o_sizes = output_size + + inv_scales = [i / o for i, o in zip(i_sizes, o_sizes)] + for i, scale in enumerate(scales_x): + if scale is not None: + inv_scales[i] = 1.0 / scale + + 
def scale_fn(x, scale, size): + # Nearest Exact: input_index = round(scale * (output_index + 0.5) - 0.5) + # = floor(scale * (output_index + 0.5)) + # Nearest: input_index = floor(scale * output_index) + x = ops.index_expr(x, torch.float32) + if exact: + x = ops.add(x, ops.constant(0.5, torch.float32)) + x = ops.mul(x, ops.constant(scale, torch.float32)) + x = ops.to_dtype(x, torch.int32) + return ops.indirect_indexing(x, size, check=False) + + def fn(idx): + x = idx[-n:] + b = idx[:-n] + return x_loader( + [*b, *[scale_fn(i, s, size) for i, s, size in zip(x, inv_scales, i_sizes)]] + ) + + return Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=fn, + ranges=[*batch, *o_sizes], + ) + + +@register_lowering(aten.upsample_nearest1d.default) +def upsample_nearest1d(x, output_size, scales: Optional[float] = None): + return upsample_nearestnd(x, output_size, (scales,), n=1) + + +@register_lowering(aten._upsample_nearest_exact1d.default) +def _upsample_nearest_exact1d(x, output_size, scales: Optional[float] = None): + return upsample_nearestnd(x, output_size, (scales,), n=1, exact=True) + + +@register_lowering(aten.upsample_nearest2d.default) +def upsample_nearest2d( + x, output_size, scales_h: Optional[float] = None, scales_w: Optional[float] = None +): + return upsample_nearestnd(x, output_size, (scales_h, scales_w), n=2) + + +@register_lowering(aten._upsample_nearest_exact2d.default) +def _upsample_nearest_exact2d( + x, output_size, scales_h: Optional[float] = None, scales_w: Optional[float] = None +): + return upsample_nearestnd(x, output_size, (scales_h, scales_w), n=2, exact=True) + + +@register_lowering(aten.upsample_nearest3d.default) +def upsample_nearest3d( + x, + output_size, + scales_d: Optional[float] = None, + scales_h: Optional[float] = None, + scales_w: Optional[float] = None, +): + return upsample_nearestnd(x, output_size, (scales_d, scales_h, scales_w), n=3) + + +@register_lowering(aten._upsample_nearest_exact3d.default) +def 
_upsample_nearest_exact3d( + x, + output_size, + scales_d: Optional[float] = None, + scales_h: Optional[float] = None, + scales_w: Optional[float] = None, +): + return upsample_nearestnd( + x, output_size, (scales_d, scales_h, scales_w), n=3, exact=True + ) + + +def _create_constants(*args, dtype): + return tuple(ops.constant(a, dtype) for a in args) + + +@register_lowering(prims.rev.default) +def rev(x, dims): + # note - dims pre-canonicalized + x_loader = x.make_loader() + sizes = x.get_size() + + def loader(idx): + idx = list(idx) + assert len(idx) == len(sizes) + for dim in dims: + idx[dim] = (sizes[dim] - 1) - idx[dim] + + return x_loader(idx) + + return Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=loader, + ranges=sizes, + ) + + +def inplace_constant_pad_nd( + x: TensorBox, padding: Sequence[int], fill_value: float +) -> Optional[TensorBox]: + """ + This optimization changes the semantics of padding from 'clone' + style to 'view' style. + + Thanks to functionalization, this change can still maintain numerical + correctness. + """ + + def _padding_can_be_fused(): + """ + Conservatively check if padding can be fused with downstream op. + 1. if the downstream op is a sum, then there is little benefit to + do inplace padding + 2. if the downstream op is a matmul, doing inplace padding can + save membw. + """ + current_node = V.graph.current_node + if current_node is None: + return True # be conservative + users = tuple(current_node.users) + if len(users) == 1 and users[0].target in ( + aten.mm.default, + aten.addmm.default, + ): + return False + + return True # be conservative + + if _padding_can_be_fused(): + return None + + # Only handle 2D case for now + if len(padding) != 4 or len(x.get_size()) != 2: + return None + + # No harm to realize since we already know that + # the op can not be fused into the single user. + # It need to be realized later anyways. + x.realize() + + # If x is a view (e.g. 
a SliceView), realizing it just realizing the + # underlying storage. x itself is still a view. + if ( + not isinstance(x, ir.TensorBox) + or not isinstance(x.data, ir.StorageBox) + or not ( + isinstance(x.data.data, ir.ComputedBuffer) + or ( + config.can_inplace_pad_graph_input + and isinstance(x.data.data, ir.InputBuffer) + ) + ) + or not x.data.data.name + ): + return None + x.freeze_layout() + + _, layout = ir.as_storage_and_layout(x) + strides = layout.stride + if strides[1] != 1: + return None + + if padding[0] != 0 or padding[2] != 0 or padding[3] != 0: + return None + + npad = padding[1] + if npad == 0: + return None + + stride0 = strides[0] + rowsize = layout.size[1] + + if stride0 < rowsize + npad: + return None + + bufname = x.data.data.name + padded_size = [layout.size[0], layout.size[1] + npad] + V.graph.buffer_to_padded_size[bufname] = padded_size + resized_x = as_strided( + x, + padded_size, + layout.stride, + layout.offset, + ) + + sliced_x = slice_(resized_x, dim=1, start=rowsize, end=rowsize + npad, clamp=False) + fill_(sliced_x, fill_value) + + counters["inductor"]["inplace_padding"] += 1 + return resized_x + + +@register_lowering(aten.constant_pad_nd, type_promotion_kind=None) +def constant_pad_nd(x, padding, fill_value=0): + assert (len(padding) % 2) == 0 + if all(p == 0 for p in padding): + return clone(x) + + if config.inplace_padding: + out = inplace_constant_pad_nd(x, padding, fill_value) + if out: + return out + # fall through if can not inplace the padding + + sizes = x.get_size() + + bounds = list(reversed(list(zip(padding[::2], padding[1::2])))) + n = len(sizes) - len(bounds) + + # if padding is a complicated expression, hoist it + bounds_precomp: list[tuple[sympy.Symbol, Any]] = [] + for l, h in bounds: + bounds_precomp.append((V.graph.sizevars.lookup_precomputed_size(l), h)) # type: ignore[arg-type] + + output_size = list(sizes[:n]) + mask_sizes = [] + for (low, high), size in zip(bounds, sizes[n:]): + mask_sizes.append(size) + 
output_size.append(sympy.expand(size + low + high)) + assert len(output_size) == len(sizes) + fill_value = dtype_to_type(x.get_dtype())(fill_value) + + def mask(index): + mask = [] + for idx, (low, high), length in zip(index[n:], bounds, mask_sizes): + if low != 0: + mask.append(range_mask_low(idx, 0)) + if high != 0: + mask.append(range_mask_high(idx, length)) + mask = functools.reduce(ops.and_, mask) + return ops.masked(mask, lambda: x_loader(index), fill_value) + + def offset_fn(index): + new_index = list(index[:n]) + for idx, (low, _high) in zip(index[n:], bounds_precomp): + new_index.append(idx - low) + assert len(new_index) == len(index) + return mask(new_index) + + x_loader = x.make_loader() + return Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=offset_fn, + ranges=output_size, + ) + + +def range_mask_low(i: sympy.Expr, low: Union[sympy.Expr, int]): + return ops.ge( + ops.index_expr(i, torch.int64), + ops.index_expr(sympy.Integer(low), torch.int64), + ) + + +def range_mask_high(i: sympy.Expr, high: sympy.Expr): + return ops.lt( + ops.index_expr(i, torch.int64), + ops.index_expr(high, torch.int64), + ) + + +def range_mask(i: sympy.Expr, high: sympy.Expr, low: sympy.Expr): + return ops.and_( + range_mask_low(i, low), + range_mask_high(i, high), + ) + + +def constant_boundary_condition( + x, fill_value, padding=None, pad_fill_value=1.0, dim=None +): + h = x.get_size()[-dim:] + x_loader = x.make_loader() + # pyrefly: ignore [unsupported-operation] + padding_h = padding or [0] * dim + + def load(index): + prefix = index[:-dim] + ih = index[-dim:] + + mask = functools.reduce( + ops.and_, + # pyrefly: ignore [no-matching-overload] + [range_mask(ih[i], h[i] + padding_h[i], -padding_h[i]) for i in range(dim)], + ) + return ( + ops.masked( + mask, + lambda: constant_boundary_condition(x, pad_fill_value, dim=dim)( + [*prefix, *ih] + ), + fill_value, + ) + if padding + else ops.masked(mask, lambda: x_loader([*prefix, *ih]), fill_value) + ) 
+ + return load + + +def pooling_size(x, i, kernel_size, stride, padding, ceil_mode, *, dilation=None): + if dilation is None: + dilation = [1] * len(padding) + + x_out = FloorDiv( + x + 2 * padding[i] - dilation[i] * (kernel_size[i] - 1) + (stride[i] - 1), + stride[i], + ) + + if ceil_mode: + x_alt = FloorDiv( + x + + 2 * padding[i] + - dilation[i] * (kernel_size[i] - 1) + + 2 * (stride[i] - 1), + stride[i], + ) + if V.graph.sizevars.size_hint((x_alt - 1) * stride[i] - x - padding[i]) >= 0: + # Sliding windows must start within the input or left padding + x_alt -= 1 # type: ignore[assignment] + V.graph.sizevars.check_leq(0, x_alt * stride[i] - x - padding[i]) # type: ignore[arg-type] + if V.graph.sizevars.size_hint(x_out - x_alt) == 0: + # ceil mode is actually a no-op, lets guard on that + V.graph.sizevars.check_equals(x_out, x_alt) + ceil_mode = False + else: + x_out = x_alt + return x_out, ceil_mode + + +def should_fallback_max_pool_with_indices(kernel_size, *, n_dim): + kernel_size = pad_listlike(kernel_size, n_dim) + window_size = functools.reduce(operator.mul, kernel_size) + return window_size > 25 + + +def max_pool_checks( + x, kernel_size, stride, padding, dilation, n_dim, *, assert_fallback=None +): + if padding == 0: + padding = [0] * n_dim + if dilation == 1: + dilation = [1] * n_dim + if not stride: + stride = kernel_size + + kernel_size = pad_listlike(kernel_size, n_dim) + stride = pad_listlike(stride, n_dim) + padding = pad_listlike(padding, n_dim) + dilation = pad_listlike(dilation, n_dim) + + assert isinstance(x, TensorBox) + assert len(kernel_size) == n_dim + assert len(stride) == n_dim + assert len(padding) == n_dim + assert len(dilation) == n_dim + assert len(x.get_size()) in (n_dim + 1, n_dim + 2) + + use_fallback = should_fallback_max_pool_with_indices(kernel_size, n_dim=n_dim) + if assert_fallback is not None: + assert use_fallback == assert_fallback + + return kernel_size, stride, padding, dilation, use_fallback + + +def 
_max_pool_with_offsets( + x, + kernel_size, + stride, + padding, + dilation, + ceil_mode, + *, + n_dim, +): + x.realize_hint() + batch = x.shape[:-n_dim] + dhw = x.shape[-n_dim:] + + dhw_out, ceil_mode = zip( + *[ + pooling_size( + dhw[d], d, kernel_size, stride, padding, ceil_mode, dilation=dilation + ) + for d in range(n_dim) + ] + ) + + dtype = x.dtype + min_value = ( + False + if dtype is torch.bool + else (float("-inf") if dtype.is_floating_point else torch.iinfo(dtype).min) + ) + + new_size = list(batch) + list(dhw_out) + if any(padding) or any(ceil_mode) or any(d > 1 for d in dilation): + x_loader = constant_boundary_condition(x, min_value, dim=n_dim) + else: + x_loader = x.make_loader() + + def fn_inner(idx, reduction_idx): + prefix = idx[:-n_dim] + bh = idx[-n_dim:] + ih = [ + (bh[i] * stride[i]) + (reduction_idx[i] * dilation[i]) - padding[i] + for i in range(n_dim) + ] + return x_loader([*prefix, *ih]) + + result = Reduction.create( + reduction_type="max", + input_node=x, + device=x.get_device(), + dst_dtype=dtype, + src_dtype=dtype, + inner_fn=fn_inner, + ranges=new_size, + reduction_ranges=kernel_size, + ) + offsets = Reduction.create( + reduction_type="argmax", + input_node=x, + device=x.get_device(), + dst_dtype=torch.int64, + src_dtype=dtype, + inner_fn=fn_inner, + ranges=new_size, + reduction_ranges=kernel_size, + ) + if isinstance(result.data.data, Reduction): # type: ignore[attr-defined, union-attr] + # Only realize if reduction isn't unrolled + result.realize() + if isinstance(offsets.data.data, Reduction): # type: ignore[attr-defined, union-attr] + # Only realize if reduction isn't unrolled + offsets.realize() + + return result, offsets + + +@register_lowering(prims._low_memory_max_pool_with_offsets, type_promotion_kind=None) +def _low_memory_max_pool_with_offsets( + x, + kernel_size, + stride, + padding, + dilation, + ceil_mode=False, +): + n_dim = len(kernel_size) + + # assert we are not on a fallback path, the inductor decomp should have 
guaranteed this + kernel_size, stride, padding, dilation, _ = max_pool_checks( + x, + kernel_size, + stride, + padding, + dilation, + n_dim, + assert_fallback=False, + ) + + with config.patch(unroll_reductions_threshold=25): + result, offsets = _max_pool_with_offsets( + x, + kernel_size, + stride, + padding, + dilation, + ceil_mode, + n_dim=n_dim, + ) + return result, to_dtype(offsets, torch.int8) + + +def _pool_offsets_to_indices( + offsets: TensorBox, + kernel_size: Sequence[Union[int, torch.SymInt]], + input_size: Sequence[Union[int, torch.SymInt]], + increments_to_index: Callable[ + [Sequence[Union[int, torch.SymInt]], Sequence[Union[int, torch.SymInt]]], + torch._inductor.virtualized.OpsValue, + ], +) -> Union[TensorBox, ShapeAsConstantBuffer]: + n_dim = len(kernel_size) + offsets_loader = offsets.make_loader() + window_size = sympy.sympify(functools.reduce(operator.mul, kernel_size)) + + def offsets_to_indices(idx): + offset = offsets_loader(idx) + offset_sympy = ops.indirect_indexing(offset, window_size) + reduction_idx = inductor_prims._flattened_index_to_nd(offset_sympy, kernel_size) + idhw = increments_to_index(idx, reduction_idx) + return ops.index_expr( + inductor_prims._flatten_index(idhw, input_size[-n_dim:]), torch.int64 + ) + + indices = Pointwise.create( + device=offsets.get_device(), + dtype=torch.int64, + inner_fn=offsets_to_indices, + ranges=offsets.get_size(), + ) + return indices + + +@register_lowering( + prims._low_memory_max_pool_offsets_to_indices, type_promotion_kind=None +) +def _low_memory_max_pool_offsets_to_indices( + offsets, kernel_size, input_size, stride, padding, dilation +): + # TODO: Generalize to other max pooling flavors + n_dim = len(kernel_size) + + def increments_to_index(idx, reduction_idx): + bh = idx[-n_dim:] + return [ + (bh[i] * stride[i]) + (reduction_idx[i] * dilation[i]) - padding[i] + for i in range(n_dim) + ] + + return _pool_offsets_to_indices( + offsets, kernel_size, input_size, increments_to_index + ) + + +def 
_max_pool_with_indices( + x, + kernel_size, + stride, + padding, + dilation, + ceil_mode, + n_dim, +): + kernel_size, stride, padding, dilation, _ = max_pool_checks( + x, kernel_size, stride, padding, dilation, n_dim=n_dim + ) + + out, offsets = _max_pool_with_offsets( + x, kernel_size, stride, padding, dilation, ceil_mode, n_dim=n_dim + ) + + indices = _low_memory_max_pool_offsets_to_indices( + offsets, + kernel_size, + x.shape[-n_dim:], + stride, + padding, + dilation, + ) + + return out, indices + + +# Fallback when we do not decompose to the low-memory path. +@register_lowering(aten.max_pool2d_with_indices, type_promotion_kind=None) +def max_pool2d_with_indices( + x, + kernel_size, + stride=None, + padding=0, + dilation=1, + ceil_mode=False, +): + return _max_pool_with_indices( + x, kernel_size, stride, padding, dilation, ceil_mode, n_dim=2 + ) + + +# Fallback when we do not decompose to the low-memory path. +@register_lowering(aten.max_pool3d_with_indices, type_promotion_kind=None) +def max_pool3d_with_indices( + x, + kernel_size, + stride=None, + padding=0, + dilation=1, + ceil_mode=False, +): + return _max_pool_with_indices( + x, kernel_size, stride, padding, dilation, ceil_mode, n_dim=3 + ) + + +fallback_max_pool2d_with_indices_backward = fallback_handler( + aten.max_pool2d_with_indices_backward.default, + add_to_fallback_set=False, +) + + +@register_lowering(aten.max_pool2d_with_indices_backward, type_promotion_kind=None) +def max_pool2d_with_indices_backward( + grad_output, x, kernel_size, stride, padding, dilation, ceil_mode, indices +): + if padding == 0: + padding = [0, 0] + if dilation == 1: + dilation = [1, 1] + if not stride: + stride = kernel_size + + assert isinstance(x, TensorBox) + assert len(kernel_size) == 2 + assert len(stride) == 2 + assert len(padding) == 2 + assert len(dilation) == 2 + assert len(x.get_size()) in (3, 4) + + # we will read this many times, so make sure it is computed + grad_output.realize_hint() + gO_stride = 
grad_output.maybe_get_stride() + x_stride: Optional[Sequence[Any]] + if isinstance(x, TensorBox) and isinstance(x.data.data, Pointwise): # type: ignore[attr-defined] + data = x.data.data # type: ignore[attr-defined] + device = data.get_device() + assert device is not None + x_buffer = ir.ComputedBuffer( + name=None, + layout=ir.FlexibleLayout( + device=device, + dtype=data.get_dtype(), + size=data.get_size(), + ), + data=data, + ) + x_buffer.decide_layout() + x_stride = x_buffer.get_stride() + else: + x_stride = x.maybe_get_stride() + + is_channels_last = (x_stride is not None and x_stride[1] == 1) or ( + gO_stride is not None and gO_stride[1] == 1 + ) + if any(d != 1 for d in dilation): + # dilation NYI + return fallback_max_pool2d_with_indices_backward( + grad_output, x, kernel_size, stride, padding, dilation, ceil_mode, indices + ) + + *_batch, _height, width = x.get_size() + *_, pooled_height, pooled_width = grad_output.get_size() + + indices_loader = indices.make_loader() + grad_loader = grad_output.make_loader() + new_size = list(x.get_size()) + + h_window_size = max( + max(FloorDiv(h, stride[0]) - max(0, FloorDiv(h - kernel_size[0], stride[0])), 1) + for h in range(kernel_size[0] * 2) + ) + w_window_size = max( + max(FloorDiv(w, stride[1]) - max(0, FloorDiv(w - kernel_size[1], stride[1])), 1) + for w in range(kernel_size[1] * 2) + ) + + window_size = h_window_size * w_window_size + + if window_size > 25: + # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback. 
+ return fallback_max_pool2d_with_indices_backward( + grad_output, x, kernel_size, stride, padding, dilation, ceil_mode, indices + ) + + indices_size = indices.get_size() + + def fn(idx): + *prefix, h, w = idx + index_test = ops.index_expr(h * width + w, torch.int32) + h = h + padding[0] + w = w + padding[1] + phstart = ops.index_expr( + FloorDiv(h - kernel_size[0] + stride[0], stride[0]), torch.int32 + ) + pwstart = ops.index_expr( + FloorDiv(w - kernel_size[1] + stride[1], stride[1]), torch.int32 + ) + phend = ops.index_expr(FloorDiv(h, stride[0]) + 1, torch.int32) + pwend = ops.index_expr(FloorDiv(w, stride[1]) + 1, torch.int32) + + phstart = ops.maximum(phstart, ops.constant(0, torch.int32)) + pwstart = ops.maximum(pwstart, ops.constant(0, torch.int32)) + phend = ops.minimum(phend, ops.index_expr(pooled_height, torch.int32)) + pwend = ops.minimum(pwend, ops.index_expr(pooled_width, torch.int32)) + + gradient = None + for ph_ in range(h_window_size): + for pw_ in range(w_window_size): + ph = ops.add(phstart, ops.constant(ph_, torch.int32)) + pw = ops.add(pwstart, ops.constant(pw_, torch.int32)) + grad_index = [ + *prefix, + ops.indirect_indexing( + ops.minimum(ph, ops.sub(phend, ops.constant(1, torch.int32))), + indices_size[-2], + check=False, + ), + ops.indirect_indexing( + ops.minimum(pw, ops.sub(pwend, ops.constant(1, torch.int32))), + indices_size[-1], + check=False, + ), + ] + + index_actual = indices_loader(grad_index) + grad_part = grad_loader(grad_index) + check = ops.eq(index_actual, index_test) + + if gradient is None: + # don't need mask for 0, 0 + gradient = ops.where( + check, grad_part, ops.constant(0.0, torch.float32) + ) + else: + mask = ops.and_( + ops.and_( + ops.lt(ph, phend), + ops.lt(pw, pwend), + ), + check, + ) + gradient = ops.where(mask, ops.add(gradient, grad_part), gradient) + assert gradient is not None + return gradient + + out = Pointwise.create( + device=grad_output.get_device(), + dtype=grad_output.get_dtype(), + inner_fn=fn, + 
ranges=new_size, + ) + if is_channels_last: + return ir.ExternKernel.require_channels_last(out) + else: + return out + + +def pad_adaptive_loader(x, pad_val=0.0): + x_loader = x.make_loader() + + def load(prefix, increments, start_indices, end_indices): + ih, iw = increments + h_start_index, w_start_index = start_indices + h_end_index, w_end_index = end_indices + + mask = ops.and_( + ops.lt( + ops.index_expr(h_start_index + ih, torch.int64), + ops.index_expr(h_end_index, torch.int64), + ), + ops.lt( + ops.index_expr(w_start_index + iw, torch.int64), + ops.index_expr(w_end_index, torch.int64), + ), + ) + + return ops.masked( + mask, + lambda: x_loader([*prefix, h_start_index + ih, w_start_index + iw]), + pad_val, + ) + + return load + + +def compute_indices_adaptive_pooling(start_index, end_index, h_in, w_in, h_out, w_out): + h_start_index = functools.partial(start_index, out_dim=h_out, inp_dim=h_in) + h_end_index = functools.partial(end_index, out_dim=h_out, inp_dim=h_in) + + w_start_index = functools.partial(start_index, out_dim=w_out, inp_dim=w_in) + w_end_index = functools.partial(end_index, out_dim=w_out, inp_dim=w_in) + + return h_start_index, h_end_index, w_start_index, w_end_index + + +def _adaptive_pooling_fn( + start_index, end_index, kernel_maxes, in_sizes, out_sizes, pooling_fn +): + h_in, w_in = in_sizes + h_out, w_out = out_sizes + + ( + h_start_index_fn, + h_end_index_fn, + w_start_index_fn, + w_end_index_fn, + ) = compute_indices_adaptive_pooling( + start_index, end_index, h_in, w_in, h_out, w_out + ) + + def fn(idx, loader): + *prefix, bh, bw = idx + + h_start_index = h_start_index_fn(bh) + h_end_index = h_end_index_fn(bh) + + w_start_index = w_start_index_fn(bw) + w_end_index = w_end_index_fn(bw) + + result = None + for ih, iw in itertools.product(range(kernel_maxes[0]), range(kernel_maxes[1])): + val = loader( + prefix, + [ih, iw], + [h_start_index, w_start_index], + [h_end_index, w_end_index], + ) + if result is None: + result = val + else: + 
result = pooling_fn(val, result) + return result + + return fn + + +def _adaptive_pooling_fn_with_idx( + start_index, end_index, kernel_maxes, in_sizes, out_sizes, pooling_fn +): + h_in, w_in = in_sizes + h_out, w_out = out_sizes + + ( + h_start_index_fn, + h_end_index_fn, + w_start_index_fn, + w_end_index_fn, + ) = compute_indices_adaptive_pooling( + start_index, end_index, h_in, w_in, h_out, w_out + ) + + def fn(idx, loader): + *prefix, bh, bw = idx + + h_start_index = h_start_index_fn(bh) + h_end_index = h_end_index_fn(bh) + + w_start_index = w_start_index_fn(bw) + w_end_index = w_end_index_fn(bw) + + maxval = None + maxindex = None + for ih, iw in itertools.product(range(kernel_maxes[0]), range(kernel_maxes[1])): + val = loader( + prefix, + [ih, iw], + [h_start_index, w_start_index], + [h_end_index, w_end_index], + ) + + index = ops.index_expr( + (h_start_index + ih) * w_in + w_start_index + iw, torch.int64 + ) + + if maxindex is None: + maxindex = index + else: + maxindex = ops.where(ops.gt(val, maxval), index, maxindex) + + if maxval is None: + maxval = val + else: + maxval = pooling_fn(val, maxval) + + return maxindex + + return fn + + +fallback_adaptive_avg_pool2d = fallback_handler( + aten._adaptive_avg_pool2d.default, add_to_fallback_set=False +) + + +@register_lowering(aten._adaptive_avg_pool2d) +def _adaptive_avg_pool2d(x, output_size): + if x.get_dtype() == torch.int64: + # not supported in eager + raise RuntimeError("'adaptive_avg_pool2d' not implemented for 'Long'") + assert isinstance(x, TensorBox) + assert len(output_size) == 2 + x.realize_hint() + + *batch, h_in, w_in = x.get_size() + + h_in = V.graph.sizevars.guard_int(h_in) + w_in = V.graph.sizevars.guard_int(w_in) + + h_out, w_out = output_size + + # no-op if the same input and output + if h_in == h_out and w_in == w_out: + return clone(x) + + if h_out == 0 or w_out == 0: + o_size = [*batch, h_out, w_out] + return empty(o_size, dtype=x.get_dtype(), device=x.get_device()) + if h_in % h_out == 0 
and w_in % w_out == 0: + kernel_size = [FloorDiv(h_in, h_out), FloorDiv(w_in, w_out)] + return avg_pool2d(x, kernel_size) + + h_kernel_max = ceildiv((h_in + h_out - 1), h_out) + w_kernel_max = ceildiv((w_in + w_out - 1), w_out) + + new_size = list(batch) + [h_out, w_out] + dtype = x.get_dtype() + + window_size = h_kernel_max * w_kernel_max + if window_size > 25: + # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback. + return fallback_adaptive_avg_pool2d(x, output_size) + + def start_index(index, out_dim, inp_dim): + return FloorDiv((index * inp_dim), out_dim) + + def end_index(index, out_dim, inp_dim): + return FloorDiv((index + 1) * inp_dim + out_dim - 1, out_dim) + + fn_sum = _adaptive_pooling_fn( + start_index=start_index, + end_index=end_index, + kernel_maxes=[h_kernel_max, w_kernel_max], + in_sizes=[h_in, w_in], + out_sizes=[h_out, w_out], + pooling_fn=ops.add, + ) + + ones_loader = pad_adaptive_loader(ones_like(x)) + + def fn(idx): + return ops.truediv( + fn_sum(idx, pad_adaptive_loader(x)), fn_sum(idx, ones_loader) + ) + + rv = Pointwise.create( + device=x.get_device(), + dtype=dtype, + inner_fn=fn, + ranges=new_size, + ) + # TODO: should we force these to be realized? 
+ return rv + + +fallback_adaptive_max_pool2d = fallback_handler( + aten.adaptive_max_pool2d.default, add_to_fallback_set=False +) + + +@register_lowering(aten.adaptive_max_pool2d) +def adaptive_max_pool2d(x, output_size): + if x.get_dtype() == torch.int64: + # not supported in eager + raise RuntimeError("adaptive_max_pool2d not implemented for Long") + assert isinstance(x, TensorBox) + assert len(output_size) == 2 + x.realize_hint() + + *batch, h_in, w_in = x.get_size() + + h_in = V.graph.sizevars.guard_int(h_in) + w_in = V.graph.sizevars.guard_int(w_in) + + h_out, w_out = output_size + + if h_out == 0 or w_out == 0: + o_size = [*batch, h_out, w_out] + return empty(o_size, dtype=x.get_dtype(), device=x.get_device()), empty( + o_size, dtype=torch.int64, device=x.get_device() + ) + + if h_in % h_out == 0 and w_in % w_out == 0: + # This is handled by a decomposition + raise ValueError + + h_kernel_max = ceildiv((h_in + h_out - 1), h_out) + w_kernel_max = ceildiv((w_in + w_out - 1), w_out) + + new_size = list(batch) + [h_out, w_out] + dtype = x.get_dtype() + + window_size = h_kernel_max * w_kernel_max + if window_size > 25: + # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback. 
+ return fallback_adaptive_max_pool2d(x, output_size) + + def start_index(index, out_dim, inp_dim): + return FloorDiv((index * inp_dim), out_dim) + + def end_index(index, out_dim, inp_dim): + return FloorDiv((index + 1) * inp_dim + out_dim - 1, out_dim) + + inner_func_max_val = _adaptive_pooling_fn( + start_index=start_index, + end_index=end_index, + kernel_maxes=[h_kernel_max, w_kernel_max], + in_sizes=[h_in, w_in], + out_sizes=[h_out, w_out], + pooling_fn=ops.maximum, + ) + + inner_func_max_idx = _adaptive_pooling_fn_with_idx( + start_index=start_index, + end_index=end_index, + kernel_maxes=[h_kernel_max, w_kernel_max], + in_sizes=[h_in, w_in], + out_sizes=[h_out, w_out], + pooling_fn=ops.maximum, + ) + + def inner_fn_max_val(idx): + return inner_func_max_val(idx, pad_adaptive_loader(x, float("-inf"))) + + def inner_fn_max_idx(idx): + return inner_func_max_idx(idx, pad_adaptive_loader(x, float("-inf"))) + + rv = Pointwise.create( + device=x.get_device(), + dtype=dtype, + inner_fn=inner_fn_max_val, + ranges=new_size, + ) + ri = Pointwise.create( + device=x.get_device(), + dtype=torch.int64, + inner_fn=inner_fn_max_idx, + ranges=new_size, + ) + return rv, ri + + +def _fractional_pooling_offsets(samples, in_sz, out_sz, kernel_sz, dim, ndims): + out_sz = out_sz[dim] + in_sz = in_sz[dim] + kernel_sz = kernel_sz[dim] + samples_loader = samples.make_loader() + + def load(prefix, i): + # Handle indexing for samples tensor correctly for different input dimensions + # samples tensor always has shape (N, C, 2) for fractional_max_pool2d where: + # - N=1 for 3D inputs (C,H,W), N=batch_size for 4D inputs (N,C,H,W) + # - C=num_channels + # - 2 for the two spatial dimensions (height, width) + samples_shape = samples.get_size() + + if len(samples_shape) == 3: # Expected: (N, C, 2) + if len(prefix) == 1: + # 3D input case: prefix=(channel,), samples=(1, C, 2) + # Access: samples[0, channel, dim] + sample = samples_loader([0, prefix[0], ndims - 1 - dim]) + elif len(prefix) >= 2: + 
# 4D+ input case: prefix=(batch, channel, ...), samples=(batch, C, 2) + # Access: samples[batch, channel, dim] + sample = samples_loader([prefix[0], prefix[1], ndims - 1 - dim]) + else: + # Edge case - shouldn't happen for valid fractional pooling + sample = samples_loader([0, 0, ndims - 1 - dim]) + else: + # Fallback for unexpected tensor shapes + sample = samples_loader([*prefix, ndims - 1 - dim]) + i_expr = ops.index_expr(i, samples.get_dtype()) + diff = ops.index_expr(in_sz - kernel_sz, torch.int64) + out_sz_expr = ops.index_expr(out_sz - 1, torch.int64) + alpha = ops.truediv( + ops.to_dtype(diff, torch.float64), ops.to_dtype(out_sz_expr, torch.float64) + ) + alpha = ops.where(ops.eq(out_sz_expr, 0), 0, alpha) + seq_i = ops.trunc((i_expr + sample) * alpha) - ops.trunc(sample * alpha) + seq_i = ops.to_dtype(seq_i, torch.int64) + mask = ops.lt(i_expr, out_sz_expr) + return ops.indirect_indexing(ops.where(mask, seq_i, diff), sympy.sympify(in_sz)) + + return load + + +@register_lowering(aten.fractional_max_pool2d) +def fractional_max_pool2d(x, kernel_size, output_size, random_samples): + return _fractional_max_pool(x, kernel_size, output_size, random_samples, n_dim=2) + + +@register_lowering(aten.fractional_max_pool3d) +def fractional_max_pool3d(x, kernel_size, output_size, random_samples): + return _fractional_max_pool(x, kernel_size, output_size, random_samples, n_dim=3) + + +def _fractional_max_pool(x, kernel_size, output_size, random_samples, n_dim): + x.realize_hint() + batch, inp_dhw = x.shape[:-n_dim], x.shape[-n_dim:] + + with config.patch(unroll_reductions_threshold=25): + dhw_index_fn = [ + _fractional_pooling_offsets( + samples=random_samples, + in_sz=inp_dhw, + out_sz=output_size, + kernel_sz=kernel_size, + ndims=n_dim, + dim=d, + ) + for d in range(n_dim) + ] + + x_loader = x.make_loader() + + def fn_inner(idx, reduction_idx): + prefix = idx[:-n_dim] + return x_loader([*prefix, *increments_to_index(idx, reduction_idx)]) + + def increments_to_index(idx, 
reduction_idx): + prefix = idx[:-n_dim] + bdhw = idx[-n_dim:] + return [ + dhw_index_fn[d](prefix, bdhw[d]) + reduction_idx[d] + for d in range(n_dim) + ] + + new_size = list(batch) + list(output_size) + dtype = x.get_dtype() + result = Reduction.create( + reduction_type="max", + input_node=x, + device=x.get_device(), + dst_dtype=dtype, + src_dtype=dtype, + inner_fn=fn_inner, + ranges=new_size, + reduction_ranges=kernel_size, + ) + offsets = Reduction.create( + reduction_type="argmax", + input_node=x, + device=x.get_device(), + dst_dtype=torch.int64, + src_dtype=dtype, + inner_fn=fn_inner, + ranges=new_size, + reduction_ranges=kernel_size, + ) + assert isinstance(result, TensorBox), result + if isinstance(result.data.data, Reduction): # type: ignore[attr-defined] + # Only realize if reduction isn't unrolled + result.realize() + assert isinstance(offsets, TensorBox), offsets + if isinstance(offsets.data.data, Reduction): # type: ignore[attr-defined] + # Only realize if reduction isn't unrolled + offsets.realize() + + indices = _pool_offsets_to_indices( + offsets, kernel_size, x.shape, increments_to_index + ) + return result, indices + + +@register_lowering(aten.upsample_nearest2d_backward.default) +def upsample_nearest2d_backward( + x, output_size=None, input_size=None, scales_h=None, scales_w=None +): + x.realize_hint() + + *_batch, inp_h, inp_w = x.get_size() + inp_h = V.graph.sizevars.guard_int(inp_h) + inp_w = V.graph.sizevars.guard_int(inp_w) + + # pyrefly: ignore [not-iterable] + *_batch, out_h, out_w = input_size + + if inp_h % out_h == 0 and inp_w % out_w == 0: + return avg_pool2d( + x, [FloorDiv(inp_h, out_h), FloorDiv(inp_w, out_w)], divisor_override=1 + ) + + h_kernel_max = ceildiv(inp_h, out_h) + w_kernel_max = ceildiv(inp_w, out_w) + + def start_index(index, out_dim, inp_dim): + return CeilDiv(index * inp_dim, sympy.sympify(out_dim)) + + def end_index(index, out_dim, inp_dim): + return start_index((index + 1), out_dim, inp_dim) + + fn_sum = 
_adaptive_pooling_fn( + start_index=start_index, + end_index=end_index, + kernel_maxes=[h_kernel_max, w_kernel_max], + in_sizes=[inp_h, inp_w], + out_sizes=[out_h, out_w], + pooling_fn=ops.add, + ) + + def fn(idx): + return fn_sum(idx, pad_adaptive_loader(x)) + + rv = Pointwise.create( + device=x.get_device(), + dtype=x.get_dtype(), + inner_fn=fn, + # pyrefly: ignore [no-matching-overload] + ranges=list(input_size), + ) + + return rv + + +fallback_avg_pool2d = fallback_handler( + aten.avg_pool2d.default, add_to_fallback_set=False +) +fallback_avg_pool3d = fallback_handler( + aten.avg_pool3d.default, add_to_fallback_set=False +) + + +@register_lowering(aten.avg_pool2d, type_promotion_kind=None) +def avg_pool2d( + x, + kernel_size, + stride=(), + padding=0, + ceil_mode=False, + count_include_pad=True, + divisor_override=None, +): + return _avg_poolnd( + x, + kernel_size, + stride, + padding, + ceil_mode, + count_include_pad, + divisor_override, + dim=2, + ) + + +@register_lowering(aten.avg_pool3d, type_promotion_kind=None) +def avg_pool3d( + x, + kernel_size, + stride=(), + padding=0, + ceil_mode=False, + count_include_pad=True, + divisor_override=None, +): + return _avg_poolnd( + x, + kernel_size, + stride, + padding, + ceil_mode, + count_include_pad, + divisor_override, + dim=3, + ) + + +def _avg_poolnd( + x, + kernel_size, + stride, + padding, + ceil_mode, + count_include_pad, + divisor_override, + dim, +): + if not stride: + stride = kernel_size + if not padding: + padding = [0] * dim + kernel_size = pad_listlike(kernel_size, dim) + stride = pad_listlike(stride, dim) + padding = pad_listlike(padding, dim) + + assert isinstance(x, TensorBox) + assert len(kernel_size) == dim + assert len(stride) == dim + assert len(padding) == dim + assert len(x.get_size()) in (dim + 1, dim + 2) + + x.realize_hint() + batch = x.get_size()[:-dim] + h = x.get_size()[-dim:] + + h_out, ceil_modes = zip( + *[ + pooling_size(h[i], i, kernel_size, stride, padding, ceil_mode) + for i in 
range(dim) + ] + ) + + if any(padding) or any(ceil_modes): + x_loader = constant_boundary_condition(x, 0.0, dim=dim) + had_padding = True + else: + x_loader = x.make_loader() + had_padding = False + + new_size = list(batch) + list(h_out) + dtype = x.get_dtype() + + window_size = functools.reduce(operator.mul, kernel_size) + if window_size > 25: + # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback. + if dim == 2: + fallback = fallback_avg_pool2d + elif dim == 3: + fallback = fallback_avg_pool3d + else: + raise ValueError(f"Unknown dim: {dim}") + + return fallback( + x, + kernel_size, + stride, + padding, + ceil_mode, + count_include_pad, + divisor_override, + ) + + def fn_sum(idx, loader): + prefix = idx[:-dim] + b = idx[-dim:] + total = None + for ih in itertools.product(*[range(kernel_size[i]) for i in range(dim)]): + inp = [b[i] * stride[i] + ih[i] - padding[i] for i in range(dim)] + val = loader([*prefix, *inp]) + if total is None: + total = val + else: + total = ops.add(val, total) + return total + + if not had_padding or divisor_override: + divisor = divisor_override if divisor_override else window_size + if dtype.is_floating_point: + scale = 1 / divisor + + def fn(idx): + return ops.mul(fn_sum(idx, x_loader), ops.constant(scale, dtype)) + + else: + + def fn(idx): + # C style integer division as done in native/cpu/AvgPoolKernel.cpp + return ops.truncdiv(fn_sum(idx, x_loader), ops.constant(divisor, dtype)) + + else: + + def fn(idx): + bh = idx[-dim:] + + divide_factors = [] + for i in range(dim): + hstart = bh[i] * stride[i] - padding[i] + hend = sympy.Min(hstart + kernel_size[i], h[i] + padding[i]) + if not count_include_pad: + hstart = sympy.Max(hstart, 0) + hend = sympy.Min(hend, h[i]) + factor = ops.index_expr(hend - hstart, torch.int32) + divide_factors.append(factor) + divide_factor = functools.reduce(ops.mul, divide_factors) + if dtype.is_floating_point: + return ops.truediv(fn_sum(idx, x_loader), divide_factor) + # C style 
integer division as done in native/cpu/AvgPoolKernel.cpp + return ops.truncdiv(fn_sum(idx, x_loader), divide_factor) + + rv = Pointwise.create( + device=x.get_device(), + dtype=dtype, + inner_fn=fn, + ranges=new_size, + ) + # TODO(jansel): should we force these to be realized? + return rv + + +fallback_avg_pool2d_backward = fallback_handler( + aten.avg_pool2d_backward.default, add_to_fallback_set=False +) + + +@register_lowering(aten.avg_pool2d_backward, type_promotion_kind=None) +def avg_pool2d_backward( + grad_output, + x, + kernel_size, + stride, + padding, + ceil_mode, + count_include_pad, + divisor_override=None, +): + assert divisor_override is None or divisor_override != 0, "divisor must be not zero" + if not stride: + stride = kernel_size + if not padding: + padding = [0, 0] + + assert isinstance(grad_output, TensorBox) + assert isinstance(x, TensorBox) + assert len(kernel_size) == 2 + assert len(stride) == 2 + assert len(padding) == 2 + assert len(x.get_size()) in (3, 4) + + grad_output.realize_hint() # we will read this many times, so make sure it is computed + + *_, height, width = x.get_size() + + _h_out, ceil_mode1 = pooling_size( + height, 0, kernel_size, stride, padding, ceil_mode + ) + _w_out, ceil_mode2 = pooling_size(width, 1, kernel_size, stride, padding, ceil_mode) + + grad_loader = grad_output.make_loader() + + had_padding = padding[0] or padding[1] or ceil_mode1 or ceil_mode2 + + *_, pooled_height, pooled_width = grad_output.get_size() + new_size = list(x.get_size()) + dtype = x.get_dtype() + + h_window_size = max( + max(FloorDiv(h, stride[0]) - max(0, FloorDiv(h - kernel_size[0], stride[0])), 1) + for h in range(kernel_size[0] * 2) + ) + w_window_size = max( + max(FloorDiv(w, stride[1]) - max(0, FloorDiv(w - kernel_size[1], stride[1])), 1) + for w in range(kernel_size[1] * 2) + ) + + window_size = h_window_size * w_window_size + if window_size > 25: + # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback. 
+ return fallback_avg_pool2d_backward( + grad_output, + x, + kernel_size, + stride, + padding, + ceil_mode, + count_include_pad, + divisor_override, + ) + + def compute_pool_size_without_padding(ph, pw): + """ + This computes the scaling factor that we will divide an element + by when `count_include_pad=False` + """ + stride_h = ops.constant(stride[0], torch.int32) + stride_w = ops.constant(stride[1], torch.int32) + pad_h = ops.constant(padding[0], torch.int32) + pad_w = ops.constant(padding[1], torch.int32) + kernel_h = ops.constant(kernel_size[0], torch.int32) + kernel_w = ops.constant(kernel_size[1], torch.int32) + hstart = ops.sub(ops.mul(ph, stride_h), pad_h) + wstart = ops.sub(ops.mul(pw, stride_w), pad_w) + hend = ops.minimum( + ops.add(hstart, kernel_h), + ops.add(ops.index_expr(height, torch.int32), pad_h), + ) + wend = ops.minimum( + ops.add(wstart, kernel_w), + ops.add(ops.index_expr(width, torch.int32), pad_w), + ) + hstart = ops.maximum(hstart, ops.constant(0, torch.int32)) + wstart = ops.maximum(wstart, ops.constant(0, torch.int32)) + hend = ops.minimum(hend, ops.index_expr(height, torch.int32)) + wend = ops.minimum(wend, ops.index_expr(width, torch.int32)) + divide_factor = ops.mul(ops.sub(hend, hstart), ops.sub(wend, wstart)) + return divide_factor + + def fn(idx): + *prefix, h, w = idx + h = h + padding[0] + w = w + padding[1] + phstart = ops.index_expr( + FloorDiv(h - kernel_size[0] + stride[0], stride[0]), torch.int32 + ) + pwstart = ops.index_expr( + FloorDiv(w - kernel_size[1] + stride[1], stride[1]), torch.int32 + ) + phend = ops.index_expr(FloorDiv(h, stride[0]) + 1, torch.int32) + pwend = ops.index_expr(FloorDiv(w, stride[1]) + 1, torch.int32) + + phstart = ops.maximum(phstart, ops.constant(0, torch.int32)) + pwstart = ops.maximum(pwstart, ops.constant(0, torch.int32)) + phend = ops.minimum(phend, ops.index_expr(pooled_height, torch.int32)) + pwend = ops.minimum(pwend, ops.index_expr(pooled_width, torch.int32)) + + gradient = None + for ph_ 
in range(h_window_size): + for pw_ in range(w_window_size): + ph = ops.add(phstart, ops.constant(ph_, torch.int32)) + pw = ops.add(pwstart, ops.constant(pw_, torch.int32)) + + if divisor_override is not None: + scale = divisor_override + elif count_include_pad or not had_padding: + scale = kernel_size[0] * kernel_size[1] + else: + scale = compute_pool_size_without_padding(ph, pw) + + part = ops.truediv( + grad_loader( + [ + *prefix, + ops.indirect_indexing( + ops.minimum( + ph, ops.sub(phend, ops.constant(1, torch.int32)) + ), + pooled_height, + check=False, + ), + ops.indirect_indexing( + ops.minimum( + pw, ops.sub(pwend, ops.constant(1, torch.int32)) + ), + pooled_width, + check=False, + ), + ] + ), + scale, + ) + + mask = ops.and_( + ops.lt(ph, phend), + ops.lt(pw, pwend), + ) + if gradient is None: + gradient = ops.where(mask, part, ops.constant(0.0, torch.float32)) + else: + gradient = ops.where(mask, ops.add(gradient, part), gradient) + assert gradient is not None + return gradient + + rv = Pointwise.create( + device=grad_output.get_device(), + dtype=dtype, + inner_fn=fn, + ranges=new_size, + ) + return rv + + +fallback_avg_pool3d_backward = fallback_handler( + aten.avg_pool3d_backward.default, add_to_fallback_set=False +) + + +@register_lowering(aten.avg_pool3d_backward, type_promotion_kind=None) +def avg_pool3d_backward( + grad_output, + x, + kernel_size, + stride, + padding, + ceil_mode, + count_include_pad, + divisor_override=None, +): + assert divisor_override is None or divisor_override != 0, "divisor must be not zero" + if not stride: + stride = kernel_size + if not padding: + padding = [0, 0, 0] + + assert isinstance(grad_output, TensorBox) + assert isinstance(x, TensorBox) + assert len(kernel_size) == 3 + assert len(stride) == 3 + assert len(padding) == 3 + assert len(x.get_size()) in (4, 5) + + grad_output.realize_hint() + + *_batch, depth, height, width = x.get_size() + + _d_out, ceil_mode_d = pooling_size( + depth, 0, kernel_size, stride, padding, 
ceil_mode + ) + _h_out, ceil_mode_h = pooling_size( + height, 1, kernel_size, stride, padding, ceil_mode + ) + _w_out, ceil_mode_w = pooling_size( + width, 2, kernel_size, stride, padding, ceil_mode + ) + + grad_loader = grad_output.make_loader() + had_padding = any(padding) or ceil_mode_d or ceil_mode_h or ceil_mode_w + + *_, pooled_depth, pooled_height, pooled_width = grad_output.get_size() + new_size = list(x.get_size()) + dtype = x.get_dtype() + + d_window_size, h_window_size, w_window_size = ( + max( + max(d // stride[i] - max(0, (d - kernel_size[i]) // stride[i]), 1) + for d in range(kernel_size[i] * 2) + ) + for i in range(3) + ) + + window_size = d_window_size * h_window_size * w_window_size + if window_size > 125: + # Kernel size too big. Results in hard-to-optimize Triton code. + return fallback_avg_pool3d_backward( + grad_output, + x, + kernel_size, + stride, + padding, + ceil_mode, + count_include_pad, + divisor_override, + ) + + def compute_pool_size_without_padding(pd, ph, pw): + stride_d, stride_h, stride_w = (ops.constant(s, torch.int32) for s in stride) + pad_d, pad_h, pad_w = (ops.constant(p, torch.int32) for p in padding) + kernel_d, kernel_h, kernel_w = ( + ops.constant(k, torch.int32) for k in kernel_size + ) + + dstart, hstart, wstart = ( + ops.sub(ops.mul(p, s), pad) + for p, s, pad in zip( + [pd, ph, pw], [stride_d, stride_h, stride_w], [pad_d, pad_h, pad_w] + ) + ) + dend, hend, wend = ( + ops.minimum( + ops.add(start, k), ops.add(ops.index_expr(dim, torch.int32), pad) + ) + for start, k, dim, pad in zip( + [dstart, hstart, wstart], + [kernel_d, kernel_h, kernel_w], + [depth, height, width], + [pad_d, pad_h, pad_w], + ) + ) + dstart, hstart, wstart = ( + ops.maximum(start, ops.constant(0, torch.int32)) + for start in [dstart, hstart, wstart] + ) + dend, hend, wend = ( + ops.minimum(end, ops.index_expr(dim, torch.int32)) + for end, dim in zip([dend, hend, wend], [depth, height, width]) + ) + divide_factor = ops.mul( + ops.mul(ops.sub(dend, 
dstart), ops.sub(hend, hstart)), ops.sub(wend, wstart) + ) + return divide_factor + + def fn(idx): + *prefix, d, h, w = idx + d, h, w = (v + pad for v, pad in zip([d, h, w], padding)) + + pdstart, phstart, pwstart = ( + ops.index_expr(FloorDiv(v - k + s, s), torch.int32) + for v, k, s in zip([d, h, w], kernel_size, stride) + ) + + pdend, phend, pwend = ( + ops.index_expr(FloorDiv(v, s) + 1, torch.int32) + for v, s in zip([d, h, w], stride) + ) + + pdstart, phstart, pwstart = ( + ops.maximum(pstart, ops.constant(0, torch.int32)) + for pstart in [pdstart, phstart, pwstart] + ) + pdend, phend, pwend = ( + ops.minimum(pend, ops.index_expr(pooled_dim, torch.int32)) + for pend, pooled_dim in zip( + [pdend, phend, pwend], [pooled_depth, pooled_height, pooled_width] + ) + ) + + gradient = None + # Iterate over the 3D region to accumulate gradients + for pd_ in range(d_window_size): + for ph_ in range(h_window_size): + for pw_ in range(w_window_size): + pd, ph, pw = ( + ops.add(pstart, ops.constant(p_, torch.int32)) + for pstart, p_ in zip( + [pdstart, phstart, pwstart], [pd_, ph_, pw_] + ) + ) + + if divisor_override is not None: + scale = divisor_override + elif count_include_pad or not had_padding: + scale = kernel_size[0] * kernel_size[1] * kernel_size[2] + else: + scale = compute_pool_size_without_padding(pd, ph, pw) + + part = ops.truediv( + grad_loader( + [ + *prefix, + ops.indirect_indexing( + ops.minimum( + pd, ops.sub(pdend, ops.constant(1, torch.int32)) + ), + pooled_depth, + check=False, + ), + ops.indirect_indexing( + ops.minimum( + ph, ops.sub(phend, ops.constant(1, torch.int32)) + ), + pooled_height, + check=False, + ), + ops.indirect_indexing( + ops.minimum( + pw, ops.sub(pwend, ops.constant(1, torch.int32)) + ), + pooled_width, + check=False, + ), + ] + ), + scale, + ) + + mask = ops.and_( + ops.and_(ops.lt(pd, pdend), ops.lt(ph, phend)), + ops.lt(pw, pwend), + ) + if gradient is None: + gradient = ops.where( + mask, part, ops.constant(0.0, torch.float32) + 
) + else: + gradient = ops.where(mask, ops.add(gradient, part), gradient) + assert gradient is not None + return gradient + + rv = Pointwise.create( + device=grad_output.get_device(), + dtype=dtype, + inner_fn=fn, + ranges=new_size, + ) + return rv + + +def _validate_reduction_axis(x, axis): + size = x.get_size() + if isinstance(axis, int): + axis = [axis] + elif not axis: + axis = range(len(size)) + if len(size) == 0: + assert tuple(axis) in [(), (0,), (-1,)], f"invalid axis: {axis}" + return [] + axis = list(axis) + for i in range(len(axis)): + if axis[i] < 0: + axis[i] += len(size) if len(size) else 1 + assert 0 <= axis[i] < len(size) or (len(size) == 0 and axis[i] == 0) + assert len(OrderedSet(axis)) == len(axis), "reduction axis not unique" + return axis + + +def _make_reduction_inner( + x, *, axis, keepdims, dtype, override_return_dtype, reduction_type=None +): + if dtype is not None: + x = to_dtype(x, dtype) + size = x.get_size() + axis = OrderedSet[int](_validate_reduction_axis(x, axis)) + + kept_sizes = [] + kept_idx = [] + reduced_sizes = [] + reduced_idx = [] + for i in range(len(size)): + if i in axis: + reduced_idx.append(i) + reduced_sizes.append(size[i]) + else: + kept_idx.append(i) + kept_sizes.append(size[i]) + + # For argmax/argmin compute logical indices when the tensor has non-contiguous layout. 
+ should_compute_logical_index = False + if ( + reduction_type in ("argmax", "argmin") + and len(reduced_sizes) > 1 + and is_triton(x) + ): + if isinstance(x.data, PermuteView): + should_compute_logical_index = True + elif isinstance(x.data, ir.ReinterpretView) or ( + isinstance(x.data, ir.StorageBox) and isinstance(x.data.data, ir.Buffer) + ): + layout = x.get_layout() + should_compute_logical_index = ( + layout.is_transposed() or not layout.is_contiguous() + ) + + def loader(index, reduction_index): + assert len(reduction_index) == len(reduced_idx) + if keepdims: + assert len(index) == len(size) + index = [index[i] for i in kept_idx] + assert len(index) == len(kept_idx) + new_index = [None] * (len(index) + len(reduction_index)) + for idx, var in itertools.chain( + zip(kept_idx, index), zip(reduced_idx, reduction_index) + ): + new_index[idx] = var + value = inner_loader(new_index) + + # For argmax/argmin, return tuple with logical linear index if needed + if should_compute_logical_index: + rindex = [sympy.expand(i) for i in reduction_index] + + # Compute linear index in row-major order + # For reduction_ranges = [4, 6]: linear_index = r0 * 6 + r1 + linear_idx = rindex[0] + for i in range(1, len(rindex)): + linear_idx = linear_idx * reduced_sizes[i] + rindex[i] + + return (value, ops.index_expr(linear_idx, torch.int64)) + + return value + + if keepdims: + new_size = list(size) + for i in reduced_idx: + new_size[i] = sympy.S.One + else: + new_size = kept_sizes + + inner_loader = x.make_loader() + return dict( + device=x.get_device(), + dst_dtype=override_return_dtype or x.get_dtype(), + src_dtype=x.get_dtype(), + inner_fn=loader, + ranges=new_size, + reduction_ranges=reduced_sizes, + ) + + +def make_reduction(reduction_type: ReductionType, override_return_dtype=None): + def inner(x, axis=None, keepdims=False, *, dtype=None): + kwargs = _make_reduction_inner( + x, + axis=axis, + keepdims=keepdims, + dtype=dtype, + override_return_dtype=override_return_dtype, + 
reduction_type=reduction_type, + ) + result = Reduction.create(reduction_type=reduction_type, input_node=x, **kwargs) + if isinstance( + result.data.data, # type: ignore[attr-defined, attr-type, union-attr] + Reduction, + ): # Only realize if reduction isn't unrolled + result.realize() + return result + + return inner + + +def _make_scan_inner(x, *, axis, dtype): + if dtype is not None: + x = to_dtype(x, dtype) + axis = _validate_dim(x, axis) + + return dict( + device=x.get_device(), + dtypes=(x.get_dtype(),), + inner_fns=(x.make_loader(),), + size=x.get_size(), + axis=axis, + ) + + +@register_lowering(aten.mean) +def mean(x, axis=None, keepdim=False, *, dtype=None): + if dtype is not None: + x = to_dtype(x, dtype) + size = x.get_size() + axis = _validate_reduction_axis(x, axis) + # compute in higher-precision until end of mean lowering + output_dtype = x.get_dtype() + if output_dtype in (torch.float16, torch.bfloat16): + x = to_dtype(x, torch.float) + sum_result = sum_(x, axis, keepdim) + denom = sympy_product(size[i] for i in axis) + denom = ir.IndexingConstant(index=denom, dtype=x.get_dtype(), device=x.get_device()) + denom = ExpandView.create(denom, list(sum_result.get_size())) + return to_dtype(div(sum_result, denom), output_dtype) + + +def var_mean_sum_(x, axis, correction, keepdim, return_mean): + if correction is None: + correction = 1 + + size = x.get_size() + axis = _validate_reduction_axis(x, axis) + x_mean = mean(x, axis, keepdim=True) + if return_mean: + x_mean.realize() + + diffs = square(sub(x, x_mean)) + sum_result = sum_(diffs, axis, keepdim) + + denom = sympy_product(size[i] for i in axis) + if correction: + denom = sympy.Max(denom - correction, 0) + denom = ir.IndexingConstant(index=denom, dtype=x.get_dtype(), device=x.get_device()) + denom = ExpandView.create(denom, list(sum_result.get_size())) + x_var = div(sum_result, denom) + if not return_mean: + return (x_var,) + + x_mean = x_mean if keepdim else squeeze(x_mean, axis) + return x_var, x_mean 
+ + +def use_two_step_variance(x, axis, keepdim): + # Instead of unrolling welford, just unroll the simpler two-step var + axis = _validate_reduction_axis(x, axis) + kwargs = _make_reduction_inner( + x, axis=axis, keepdims=keepdim, dtype=None, override_return_dtype=None + ) + + ranges = kwargs["ranges"] + reduction_numel = sympy_product(kwargs["reduction_ranges"]) + return ( + isinstance(reduction_numel, sympy.Integer) + and int(reduction_numel) < config.unroll_reductions_threshold + and sympy_product(ranges) != 1 + ) + + +def var_mean_welford_(x, axis, *, correction, keepdim, return_mean): + if correction is None: + correction = 1 + + kwargs = _make_reduction_inner( + x, axis=axis, keepdims=keepdim, dtype=None, override_return_dtype=None + ) + loader = kwargs.pop("inner_fn") + kwargs.pop("dst_dtype") + kwargs.pop("src_dtype") + + mean, m2, _ = ir.WelfordReduction.create( + inner_fns=(loader,), + reduction_type="welford_reduce", + dtype=x.get_dtype(), + **kwargs, + ) + m2.realize() + + dtype = x.get_dtype() + size = x.get_size() + axis = _validate_reduction_axis(x, axis) + rnumel = sympy_product(size[i] for i in axis) + + def get_constant_or_index_expr(x, dtype): + if isinstance(x, sympy.Expr) and not x.is_number: + return ops.to_dtype(ops.index_expr(x, torch.int64), dtype) + return ops.constant(x, dtype) + + def scale_fn(data): + c = get_constant_or_index_expr(correction, dtype) + N = get_constant_or_index_expr(rnumel, dtype) + zero = ops.constant(0, dtype) + return data / ops.maximum(zero, N - c) + + var = make_pointwise(scale_fn)(m2) + + if return_mean: + mean.realize() + return var, mean + return (var,) + + +def var_mean_helper_(x, *, axis, correction, keepdim, return_mean): + out_dtype = x.get_dtype() + compute_dtype = get_computation_dtype(out_dtype) + x = to_dtype(x, compute_dtype, copy=False) + kwargs = dict( + x=x, + axis=axis, + correction=correction, + keepdim=keepdim, + return_mean=return_mean, + ) + output = ( + var_mean_sum_(**kwargs) + if 
use_two_step_variance(x, axis=axis, keepdim=keepdim) + else var_mean_welford_(**kwargs) + ) + output = tuple(to_dtype(x, out_dtype, copy=False) for x in output) + return output[0] if not return_mean else output + + +@register_lowering([aten.var, prims.var]) +def var_(x, axis=None, *, correction=None, keepdim=False): + return var_mean_helper_( + x, axis=axis, correction=correction, keepdim=keepdim, return_mean=False + ) + + +@register_lowering(aten.var_mean) +def var_mean(x, axis=None, *, correction=None, keepdim=False): + return var_mean_helper_( + x, axis=axis, correction=correction, keepdim=keepdim, return_mean=True + ) + + +def pow_recursive(x, y, dtype): + if y < 0: + return pow_recursive(ops.reciprocal(x), -y, dtype) + if y == 0: + return ops.constant(1, dtype) + if y == 1: + return x + + result = pow_recursive(x, y // 2, dtype) + result = ops.mul(result, result) + if (y % 2) == 1: + result = ops.mul(result, x) + return result + + +@make_pointwise +def pow_native(a, b): + return ops.pow(a, b) + + +fallback_pow_tensor_tensor = fallback_handler( + aten.pow.Tensor_Tensor, add_to_fallback_set=False +) +fallback_pow_scalar = fallback_handler(aten.pow.Scalar, add_to_fallback_set=False) +fallback_pow_tensor_scalar = fallback_handler( + aten.pow.Tensor_Scalar, add_to_fallback_set=False +) + + +@register_lowering(aten.pow, broadcast=True) +def pow(a, b): + if isinstance(b, float) and b == int(b): + return pow(a, int(b)) + elif isinstance(b, float) and b == 0.5: + return sqrt(a) + elif isinstance(b, int) and b == 1: + return clone(a) + + # Type promotion ensures all tensor arguments have the same type + dtype = next(x.get_dtype() for x in (a, b) if isinstance(x, ir.TensorBox)) + is_integer_pow = is_integer_dtype(dtype) + + # Optimize away small fixed powers, or for integers avoid falling back to ATen + embed_exponent = isinstance(b, int) and ( + -32 < b < 32 or (is_integer_pow and b >= 0) + ) + if embed_exponent: + loader = a.make_loader() + + def fn(idx): + return 
pow_recursive(loader(idx), b, a.get_dtype()) + + return Pointwise.create( + device=a.get_device(), + dtype=a.get_dtype(), + inner_fn=fn, + ranges=a.get_size(), + ) + + if isinstance(a, Number): + if a == 1: + return full_like(b, 1) + # pyrefly: ignore [missing-attribute] + if a == 2 and is_float_dtype(b.get_dtype()): + return exp2(b) + + if is_integer_pow: + # ops.pow doesn't work for integers + if isinstance(a, Number): + return fallback_pow_scalar(a, b) + elif isinstance(b, Number): + return fallback_pow_tensor_scalar(a, b) + else: + return fallback_pow_tensor_tensor(a, b) + + return pow_native(a, b) + + +def mutate_to(changed, val, unsafe_alias=False): + if isinstance(changed, TensorBox): + changed_data = changed.data + else: + changed_data = changed + if isinstance(val, TensorBox): + val = val.data + + if not isinstance(val, ir.StorageBox): + # introduce a copy to handle views + node = Pointwise.create( + device=changed.get_device(), + dtype=changed.get_dtype(), + inner_fn=val.make_loader(), + ranges=changed.get_size(), + ) + assert isinstance(node, (BaseView, MutableBox)) + val = node.data + assert isinstance(val, ir.StorageBox) + + if isinstance(changed_data, ir.StorageBox) and not ( + changed_data.is_input_buffer() + # In AOTI, module parameters and buffers are not lifted as graph inputs + or changed_data.is_module_buffer() + or isinstance(changed_data.data, ir.NopKernel) + ): + # Fast path, just swing the data pointer + val.realize() + changed_data.data = val.data + return changed + + ir.MutationLayoutSHOULDREMOVE.realize_into( + val, changed_data, unsafe_alias=unsafe_alias + ) + return changed + + +@register_lowering(aten.fill_) +def fill_(x, fill_value): + return mutate_to(x, full_like(x, fill_value)) + + +@register_lowering(aten.copy_, type_promotion_kind=None) +def copy_(dst, src, non_blocking=False): + if dst is src: + # dst.copy_(dst) can happen from the reinplacing pass + return dst + src = to_device(src, dst.get_device()) + src = to_dtype(src, 
dst.get_dtype()) + src = expand(src, dst.get_size()) + return mutate_to(dst, src) + + +@make_pointwise +def floordiv(a, b): + return ops.floordiv(a, b) + + +@make_pointwise +def truncdiv(a, b): + return ops.truncdiv(a, b) + + +@register_lowering(aten.div, broadcast=True) +def div_mode(a, b, rounding_mode=None): + both_integer = is_integer_type(a) and is_integer_type(b) + both_boolean = is_boolean_type(a) and is_boolean_type(b) + + # floordiv and truncdiv need special handling for integer tensors on Triton, + # see the discussion at https://github.com/triton-lang/triton/issues/605 + if rounding_mode == "floor": + assert not both_boolean, "floordiv operands can not be boolean at the same time" + return floordiv(a, b) if both_integer else floor(div(a, b)) + if rounding_mode == "trunc": + assert not both_boolean, "truncdiv operands can not be boolean at the same time" + return truncdiv(a, b) if both_integer else trunc(div(a, b)) + return div(a, b) + + +@register_lowering([aten.mul], broadcast=True) +def mul(a, b): + both_bool = is_boolean_type(a) and is_boolean_type(b) + if both_bool: + return logical_and(a, b) + else: + fn = ops_wrapper(aten.mul.__name__) + return make_pointwise(fn)(a, b) + + +def get_constant_value(x: ir.IRNode) -> Optional[ir.Constant]: + """Try convert an arbitrary IR node into an ir.Constant value""" + + # First try unwrapping the IRNode to see if it is already an ir.Constant + # Optional step, but avoids unnecessary inner_fn evaluation. 
+ if isinstance(x, ir.MutableBox): + return get_constant_value(x.data) + if isinstance(x, ir.BaseView): + return get_constant_value(x.unwrap_view()) + if isinstance(x, ir.Constant): + return x + + # If the unwrapped node is not an ir.Constant, try evaluating inner_fn + # to see if the returned value is from an `ops.constant` call + if not isinstance(x, ir.Loops): + return None + + handler = torch._inductor.ops_handler.ExtractConstantsHandler(x.get_device()) + with ( + V.set_ops_handler(handler), + patch.object(ir.FlexibleLayout, "allow_indexing", True), + ): + out = x.inner_fn(*x.inner_fn_args()) + + assert isinstance(out, torch._inductor.virtualized.OpsValue) + if isinstance(out.value, ir.Constant): + return out.value + return None + + +# NOTE: prims.div maps to a / b in C, so performs truncation division on +# integer inputs and true division for floating and complex inputs. +@register_lowering([prims.div], broadcast=True) +def div_prim(a, b): + is_integral = all(is_boolean_type(x) or is_integer_type(x) for x in [a, b]) + + if is_integral: + return truncdiv(a, b) + + # Disable CPU optimization to avoid precision issues. 
+ # see https://github.com/pytorch/pytorch/issues/157959 + if (divisor := get_constant_value(b)) is not None and a.get_device().type != "cpu": + # Replace divide by constant with multiply by reciprocal + + if divisor.value == 0: + reciprocal = math.copysign(float("inf"), divisor.value) + else: + reciprocal = 1.0 / divisor.value + return mul(a, reciprocal) + + def fn(*args): + return ops.truediv(*args) + + return make_pointwise(fn)(a, b) + + +@register_lowering( + [aten.true_divide, aten.div.Tensor], + broadcast=True, + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def div(a, b): + a, b = promote_constants( + (a, b), type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ) + return div_prim(a, b) + + +@register_lowering([aten.fmod, prims.fmod], broadcast=True) +def fmod(a, b): + is_integral = is_boolean_type(a) or is_integer_type(a) + + if is_integral: + + def fn(a, b): + return ops.mod(a, b) + + else: + + def fn(a, b): + return ops.fmod(a, b) + + return make_pointwise(fn)(a, b) + + +@register_lowering([aten.sum, prims.sum]) +def sum_(x, axis=None, keepdims=False, *, dtype=None): + if ( + is_integer_dtype(x.get_dtype()) or is_boolean_dtype(x.get_dtype()) + ) and dtype is None: + dtype = torch.int64 + + fn = make_reduction("sum", override_return_dtype=dtype) + return fn(x, axis, keepdims, dtype=dtype) + + +fallback_cumsum = fallback_handler(aten.cumsum.default) +fallback_cumprod = fallback_handler(aten.cumprod.default) +fallback_logcumsumexp = fallback_handler(aten.logcumsumexp.default) +fallback_cummax = fallback_handler(aten.cummax.default) +fallback_cummin = fallback_handler(aten.cummin.default) + + +@register_lowering(aten.cumsum) +def cumsum(x, axis=None, dtype=None): + if ( + is_integer_dtype(x.get_dtype()) or is_boolean_dtype(x.get_dtype()) + ) and dtype is None: + dtype = torch.int64 + + if len(x.get_size()) == 0: + assert axis in [0, -1] + dtype = dtype or x.get_dtype() + return to_dtype(x, dtype, copy=True) + + def 
combine_fn(a_tuple, b_tuple): + (a,) = a_tuple + (b,) = b_tuple + return (ops.add(a, b),) + + kwargs = _make_scan_inner(x, axis=axis, dtype=dtype) + (result,) = ir.Scan.create(**kwargs, combine_fn=combine_fn) + if result is None: + return fallback_cumsum(x, dim=axis, dtype=dtype) + return result + + +@register_lowering(aten.cumprod) +def cumprod(x, axis=None, dtype=None): + if ( + is_integer_dtype(x.get_dtype()) or is_boolean_dtype(x.get_dtype()) + ) and dtype is None: + dtype = torch.int64 + + if len(x.get_size()) == 0: + assert axis in [0, -1] + dtype = dtype or x.get_dtype() + return to_dtype(x, dtype, copy=True) + + def combine_fn(a_tuple, b_tuple): + (a,) = a_tuple + (b,) = b_tuple + return (ops.mul(a, b),) + + kwargs = _make_scan_inner(x, axis=axis, dtype=dtype) + (result,) = ir.Scan.create(**kwargs, combine_fn=combine_fn) + if result is None: + return fallback_cumprod(x, dim=axis, dtype=dtype) + return result + + +@register_lowering(aten.logcumsumexp) +def logcumsumexp(x, dim): + def log_add_exp_helper(a_tuple, b_tuple): + (a,) = a_tuple + (b,) = b_tuple + min_v = ops.minimum(a, b) + max_v = ops.maximum(a, b) + mask = (min_v != max_v) | (~ops.isinf(min_v)) + return (ops.where(mask, ops.log1p(ops.exp(min_v - max_v)) + max_v, a),) + + dtype = x.get_dtype() + if len(x.get_size()) == 0: + assert dim in [0, -1] + return clone(x) + + kwargs = _make_scan_inner(x, axis=dim, dtype=dtype) + (result,) = ir.Scan.create(**kwargs, combine_fn=log_add_exp_helper) + if result is None: + return fallback_logcumsumexp(x, dim=dim) + return result + + +@register_lowering(aten.cummax, type_promotion_kind=None) +def cummax(x, axis=None): + if len(x.get_size()) == 0: + assert axis in [0, -1] + return clone(x), empty_like(x, dtype=torch.int64) + + dtype = x.get_dtype() + combine_fn = ir.get_reduction_combine_fn( + "argmax", dtype=dtype, arg_break_ties_left=False + ) + + kwargs = _make_scan_inner(x, axis=axis, dtype=dtype) + kwargs["dtypes"] = (dtype, torch.int64) + 
kwargs["inner_fns"] = ( + x.make_loader(), + lambda idx: ops.index_expr(idx[axis], torch.int64), + ) + values, indices = ir.Scan.create(**kwargs, combine_fn=combine_fn) # type: ignore[arg-type] + if values is None: + return fallback_cummax(x, dim=axis) + return values, indices + + +@register_lowering(aten.cummin, type_promotion_kind=None) +def cummin(x, axis=None): + if len(x.get_size()) == 0: + assert axis in [0, -1] + return clone(x), empty_like(x, dtype=torch.int64) + + dtype = x.get_dtype() + combine_fn = ir.get_reduction_combine_fn( + "argmin", dtype=dtype, arg_break_ties_left=False + ) + + kwargs = _make_scan_inner(x, axis=axis, dtype=dtype) + kwargs["dtypes"] = (dtype, torch.int64) + kwargs["inner_fns"] = ( + x.make_loader(), + lambda idx: ops.index_expr(idx[axis], torch.int64), + ) + values, indices = ir.Scan.create(**kwargs, combine_fn=combine_fn) # type: ignore[arg-type] + if values is None: + return fallback_cummin(x, dim=axis) + return values, indices + + +@register_lowering(aten.prod) +def prod(x, axis=None, keepdims=False, *, dtype=None): + if ( + is_integer_dtype(x.get_dtype()) or is_boolean_dtype(x.get_dtype()) + ) and dtype is None: + dtype = torch.int64 + + fn = make_reduction("prod", override_return_dtype=dtype) + return fn(x, axis, keepdims, dtype=dtype) + + +@register_lowering(aten.any) +def reduce_any(x, dim=None, keepdim=False): + x = to_dtype(x, torch.bool) + return make_reduction("any")(x, axis=dim, keepdims=keepdim) + + +@register_lowering(aten.max, type_promotion_kind=None) +def reduce_max(x, dim=None, keepdim=False): + if dim is not None: + return ( + reduce_amax(x, axis=dim, keepdims=keepdim), + reduce_argmax(x, axis=dim, keepdims=keepdim), + ) + + return reduce_amax(x, axis=None, keepdims=keepdim) + + +@register_lowering(aten.min, type_promotion_kind=None) +def reduce_min(x, dim=None, keepdim=False): + if dim is not None: + return ( + reduce_amin(x, axis=dim, keepdims=keepdim), + reduce_argmin(x, axis=dim, keepdims=keepdim), + ) + + 
return reduce_amin(x, axis=None, keepdims=keepdim) + + +register_lowering(prims.xor_sum)(make_reduction("xor_sum")) +reduce_amax = register_lowering(aten.amax)(make_reduction("max")) +reduce_amin = register_lowering(aten.amin)(make_reduction("min")) +reduce_argmax = register_lowering(aten.argmax)( + make_reduction("argmax", override_return_dtype=torch.int64) +) +reduce_argmin = register_lowering(aten.argmin)( + make_reduction("argmin", override_return_dtype=torch.int64) +) + +add = register_pointwise( + aten.add, allow_alpha=True, override_fn_when_input_bool="logical_or" +) + +sort_fallback = fallback_handler(aten.sort.stable, add_to_fallback_set=False) + + +@register_lowering(aten.sort.stable, type_promotion_kind=None) +def sort_stable(x, *, stable=None, dim=-1, descending=False): + if stable is None: + stable = False + + shape = x.get_size() + device = x.get_device() + dim = canonicalize_dim(len(shape), dim) + if len(shape) == 0: + return clone(x), _full(0, device, torch.int64, shape) + + dim_size = shape[dim] if len(shape) else 1 + if not V.graph.sizevars.statically_known_lt(dim_size, torch.iinfo(torch.int16).max): + return sort_fallback(x, stable=stable, dim=dim, descending=descending) + + indices = iota( + dim_size, start=0, step=1, dtype=torch.int16, device=device, requires_grad=False + ) + view_shape = [1] * len(shape) + if len(shape): + view_shape[dim] = dim_size + indices = view(indices, view_shape) + indices = expand(indices, shape) + + values, indices = ir.Sort.create( + device=device, + dtypes=(x.dtype, indices.dtype), + inner_fns=(x.make_loader(), indices.make_loader()), + size=shape, + axis=dim, + stable=stable, + descending=descending, + ) + if values is None: + return sort_fallback(x, stable=stable, dim=dim, descending=descending) + + assert indices is not None + return values, to_dtype(indices, torch.int64) + + +@register_lowering(aten.sort.default, type_promotion_kind=None) +def sort(x, dim=-1, descending=False): + return sort_stable(x, 
stable=False, dim=dim, descending=descending) + + +def register_pointwise_numeric(op, name=None, triton_fallback=None): + return register_pointwise( + op, + name=name, + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + triton_fallback=triton_fallback, + ) + + +def register_pointwise_numeric_ldf64(op: torch._ops.OpOverloadPacket): + register_op_requires_libdevice_fp64(op.__name__) + return register_pointwise( + op, + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + ) + + +rsqrt = register_pointwise_numeric(aten.rsqrt) +exp = register_pointwise_numeric_ldf64(aten.exp) +exp2 = register_pointwise_numeric(aten.exp2) +expm1 = register_pointwise_numeric(aten.expm1) +relu = register_pointwise(aten.relu) +sigmoid = register_pointwise_numeric_ldf64(aten.sigmoid) +sqrt = register_pointwise_numeric_ldf64(aten.sqrt) +square = register_pointwise(aten.square) +sub = register_pointwise(aten.sub, allow_alpha=True) +register_pointwise_numeric_ldf64(aten.cos) +register_pointwise_numeric_ldf64(aten.sin) +abs = register_pointwise(aten.abs) +bitwise_and = register_pointwise(aten.bitwise_and) +bitwise_left_shift = register_pointwise(aten.bitwise_left_shift) +bitwise_not = register_pointwise( + aten.bitwise_not, override_fn_when_input_bool="logical_not" +) +bitwise_or = register_pointwise(aten.bitwise_or) +bitwise_right_shift = register_pointwise(aten.bitwise_right_shift) +bitwise_xor = register_pointwise(aten.bitwise_xor) +register_pointwise_numeric(aten.lgamma) +erf = register_pointwise_numeric(aten.erf) +register_lowering( + aten.special_erf, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT +)(erf) + +register_pointwise_numeric(aten.log1p) +register_pointwise_numeric(aten.tan) +register_pointwise_numeric(aten.tanh) +register_pointwise_numeric_ldf64(aten.log) +logical_and = register_pointwise( + aten.logical_and, + type_promotion_kind=None, + convert_input_to_bool=True, + override_return_dtype=torch.bool, +) +logical_not = 
register_pointwise( + aten.logical_not, + type_promotion_kind=None, + convert_input_to_bool=True, + override_return_dtype=torch.bool, +) +logical_or = register_pointwise( + aten.logical_or, + type_promotion_kind=None, + convert_input_to_bool=True, + override_return_dtype=torch.bool, +) +logical_xor = register_pointwise( + aten.logical_xor, + type_promotion_kind=None, + convert_input_to_bool=True, + override_return_dtype=torch.bool, +) +maximum = register_pointwise(aten.maximum) +minimum = register_pointwise(aten.minimum) +register_lowering(aten.clamp_min)(maximum) +register_lowering(aten.clamp_max)(minimum) +neg = register_pointwise(aten.neg) +abs = register_pointwise(aten.abs) +reciprocal = register_pointwise_numeric(aten.reciprocal) +register_pointwise(aten.remainder) +sign = register_pointwise(aten.sign, override_fn_when_input_bool="identity") +register_pointwise(aten.ceil) +register_pointwise(aten.signbit, override_return_dtype=torch.bool) + +register_lowering(aten._neg_view)(neg) + +register_pointwise(aten.le, override_return_dtype=torch.bool) +register_pointwise(aten.lt, override_return_dtype=torch.bool) +register_pointwise(aten.ge, override_return_dtype=torch.bool) +gt = register_pointwise(aten.gt, override_return_dtype=torch.bool) +register_pointwise(aten.eq, override_return_dtype=torch.bool) +register_pointwise(aten.ne, override_return_dtype=torch.bool) + +register_pointwise_numeric(aten.cosh) +register_pointwise_numeric(aten.sinh) +register_pointwise_numeric(aten.acos) +register_pointwise_numeric(aten.acosh) +register_pointwise_numeric(aten.asin) +register_pointwise_numeric(aten.asinh) +register_pointwise_numeric(aten.atan2) +register_pointwise_numeric(aten.atan) +register_pointwise_numeric(aten.atanh) +register_pointwise_numeric(aten.copysign) +register_pointwise_numeric(aten.erfc) +register_pointwise_numeric(aten.erfinv) +register_pointwise_numeric(aten.hypot) +register_pointwise_numeric(aten.log10) +register_pointwise_numeric(aten.log2) 
register_pointwise_numeric(aten.nextafter)

from .codegen.common import BackendFeature, pointwise_overrides_data


def _get_pointwise_overrides(ns, name):
    """Yield one registration triple per op that ``name`` maps to inside ``ns``.

    The entry ``pointwise_overrides_data[name]`` supplies the attribute name to
    look up on ``ns``, the type promotion kind, and a ``triton`` field.

    Args:
        ns: namespace object searched with ``getattr`` (the callers below pass
            ``aten`` and ``prims``).
        name: key into ``pointwise_overrides_data``.

    Yields:
        ``(op, type_promotion_kind, triton_fallback)`` tuples. When the looked-up
        attribute is an ``OpOverloadPacket`` one tuple is produced per overload,
        otherwise a single tuple for the attribute itself. ``triton_fallback`` is
        ``fallback_handler(op)`` when the entry's ``triton`` field is ``None`` and
        ``None`` otherwise. Nothing is yielded when ``ns`` lacks the attribute.
    """
    entry = pointwise_overrides_data[name]
    target = getattr(ns, entry.name, None)
    if target is None:
        return

    def _fallback_for(candidate):
        # A fallback handler is only built when the entry has no triton field set.
        return fallback_handler(candidate) if entry.triton is None else None

    if isinstance(target, torch._ops.OpOverloadPacket):
        # Lazy on purpose: each overload is resolved right before it is yielded.
        candidates = (
            getattr(target, overload_name) for overload_name in target.overloads()
        )
    else:
        candidates = (target,)

    for candidate in candidates:
        yield candidate, entry.type_promotion_kind, _fallback_for(candidate)


# Register every override, visiting the aten ops first and then the prims ops for
# each name. itertools.chain keeps both generators lazy, so the prims lookup still
# happens only after all aten registrations for that name are done.
for name in pointwise_overrides_data:
    for op, type_promotion_kind, triton_fallback in itertools.chain(
        _get_pointwise_overrides(aten, name),
        _get_pointwise_overrides(prims, name),
    ):
        register_pointwise(
            op,
            name=name,
            type_promotion_kind=type_promotion_kind,
            triton_fallback=triton_fallback,
        )


# Foreach add: the List and Scalar results are kept in module-level names.
foreach_add_list = register_foreach_pointwise(
    aten._foreach_add.List, add, allow_alpha=True
)
foreach_add_scalar = register_foreach_pointwise(
    aten._foreach_add.Scalar, add, allow_alpha=True
)
register_foreach_pointwise(aten._foreach_add.Tensor, add, allow_alpha=True)

# Foreach mul: same pattern as add.
foreach_mul_list = register_foreach_pointwise(aten._foreach_mul.List, mul)
register_foreach_pointwise(aten._foreach_mul.Tensor, mul)
foreach_mul_scalar = register_foreach_pointwise(aten._foreach_mul.Scalar, mul)

# Remaining foreach overloads in this span; their results are not kept.
register_foreach_pointwise(aten._foreach_sub.List, sub)
register_foreach_pointwise(aten._foreach_sub.Scalar, sub)
register_foreach_pointwise(aten._foreach_neg.default, neg)
register_foreach_pointwise(aten._foreach_abs.default, abs)
register_foreach_pointwise(aten._foreach_pow.Scalar, pow)
register_foreach_pointwise(aten._foreach_pow.List, pow)
register_foreach_pointwise(aten._foreach_pow.ScalarAndTensor, pow)
# Foreach div: the List and Scalar results are kept in module-level names.
foreach_div_list = register_foreach_pointwise(aten._foreach_div.List, div)
register_foreach_pointwise(aten._foreach_div.Tensor, div)
foreach_div_scalar = register_foreach_pointwise(aten._foreach_div.Scalar, div)

# Unary foreach ops paired with the pointwise lowering of the same name.
register_foreach_pointwise(aten._foreach_sqrt, sqrt)
register_foreach_pointwise(aten._foreach_rsqrt, rsqrt)

# maximum / minimum, and the clamp variants registered with them
# (clamp_min is paired with maximum, clamp_max with minimum).
register_foreach_pointwise(aten._foreach_maximum.List, maximum)
register_foreach_pointwise(aten._foreach_maximum.Scalar, maximum)
register_foreach_pointwise(aten._foreach_minimum.List, minimum)
register_foreach_pointwise(aten._foreach_minimum.Scalar, minimum)
register_foreach_pointwise(aten._foreach_clamp_min.List, maximum)
register_foreach_pointwise(aten._foreach_clamp_min.Scalar, maximum)
register_foreach_pointwise(aten._foreach_clamp_max.List, minimum)
register_foreach_pointwise(aten._foreach_clamp_max.Scalar, minimum)

register_foreach_pointwise(aten._foreach_reciprocal, reciprocal)
register_foreach_pointwise(aten._foreach_sign, sign)
foreach_copy = register_foreach_pointwise(aten._foreach_copy, copy)


# these are only encountered as outputs of the graph
# reinplacing epilogue copies improves compile time
# by removing extra buffers sent to the scheduler.
def register_foreach_inplace(aten_op, outplace_aten_op, outplace_op):
    """Register an in-place foreach lowering built from its out-of-place lowering.

    Records ``outplace_aten_op -> aten_op`` in ``inplaceable_foreach_ops``, adds
    ``aten_op`` to ``inplace_foreach_ops``, and registers a lowering for
    ``aten_op`` through ``_register_foreach_lowering``.

    Args:
        aten_op: the in-place foreach overload being registered.
        outplace_aten_op: the out-of-place overload it corresponds to.
        outplace_op: callable lowering that computes the out-of-place results.
    """
    inplaceable_foreach_ops[outplace_aten_op] = aten_op
    inplace_foreach_ops.add(aten_op)

    def fn(*args, **kwargs):
        # Compute out of place, then write each result back into the matching
        # element of the first positional argument.
        computed = outplace_op(*args, **kwargs)
        return [
            mutate_to(destination, value, unsafe_alias=True)
            for destination, value in zip(args[0], computed)
        ]

    _register_foreach_lowering(aten_op, fn)


register_foreach_inplace(
    aten._foreach_add_.List, aten._foreach_add.List, foreach_add_list
)
register_foreach_inplace(
    aten._foreach_add_.Scalar, aten._foreach_add.Scalar, foreach_add_scalar
)
register_foreach_inplace(
    aten._foreach_mul_.List, aten._foreach_mul.List, foreach_mul_list
)
register_foreach_inplace(
    aten._foreach_mul_.Scalar, aten._foreach_mul.Scalar, foreach_mul_scalar
)
register_foreach_inplace(
    aten._foreach_div_.List, aten._foreach_div.List, foreach_div_list
)
register_foreach_inplace(
    aten._foreach_div_.Scalar, aten._foreach_div.Scalar, foreach_div_scalar
)
register_foreach_inplace(
    aten._foreach_copy_.default, aten._foreach_copy.default, foreach_copy
)


def register_inplace(aten_op, outplace_op):
    """Register a lowering for in-place ``aten_op`` in terms of ``outplace_op``.

    The registered lowering calls ``outplace_op`` with the same arguments, casts
    the result to the dtype of the first positional argument, and writes it into
    that argument with ``mutate_to``.

    Args:
        aten_op: the in-place op to register (with ``type_promotion_kind=None``).
        outplace_op: callable producing the out-of-place result.

    Returns:
        The value bound to ``fn`` after ``register_lowering`` decorates it.
    """

    @register_lowering(aten_op, type_promotion_kind=None)
    def fn(*args, **kwargs):
        # outplace_op runs before args[0] is inspected; the nested call keeps
        # that evaluation order.
        converted = to_dtype(outplace_op(*args, **kwargs), args[0].get_dtype())
        return mutate_to(args[0], converted)

    return fn


register_inplace(aten.add_, add)
register_inplace(aten.bitwise_and_, bitwise_and)
register_inplace(aten.bitwise_left_shift_, bitwise_left_shift)
register_inplace(aten.bitwise_not_, bitwise_not)
register_inplace(aten.bitwise_or_, bitwise_or)
register_inplace(aten.bitwise_right_shift_, bitwise_right_shift)
register_inplace(aten.bitwise_xor_, bitwise_xor)
register_inplace(aten.mul_, mul)
register_inplace(aten.div_.Tensor, div)
register_inplace(aten.div_.Tensor_mode, div_mode)
register_inplace(aten.logical_and_, logical_and)
register_inplace(aten.logical_not_, logical_not)
+register_inplace(aten.logical_or_, logical_or) +register_inplace(aten.logical_xor_, logical_xor) +register_inplace(aten.sub_, sub) +register_inplace(aten.relu_, relu) +register_inplace(aten.sigmoid_, sigmoid) + + +register_lowering(aten.__and__)(bitwise_and) +register_lowering(aten.__lshift__)(bitwise_left_shift) +register_lowering(aten.__or__)(bitwise_or) +register_lowering(aten.__rshift__)(bitwise_right_shift) +register_lowering(aten.__xor__)(bitwise_xor) + +register_inplace(aten.__iand__, aten.__and__) +register_inplace(aten.__ilshift__, aten.__lshift__) +register_inplace(aten.__ior__, aten.__or__) +register_inplace(aten.__irshift__, aten.__rshift__) +register_inplace(aten.__ixor__, aten.__xor__) + + +@register_lowering(aten.sym_constrain_range) +def sym_constrain_range(a, min=None, max=None): + return None + + +@register_lowering(aten.sym_size.int) +def sym_size(a, dim): + val = V.graph.current_node.meta["val"] + if isinstance(val, torch.SymInt): + return val.node.expr + else: + return int(val) + + +@register_lowering(aten.sym_stride.int) +def sym_stride(a, dim): + val = V.graph.current_node.meta["val"] + if isinstance(val, torch.SymInt): + return val.node.expr + else: + return int(val) + + +@register_lowering(aten.sym_numel) +def sym_numel(a): + return a.get_numel() + + +for method, func in magic_methods.items(): + register_lowering(method_to_operator(method))(func) # type: ignore[arg-type] + + +@register_lowering(torch.sym_sum) +def sym_sum(args): + return sympy.Add(*args) + + +@register_lowering(aten._foobar) +def foobar(self, *args, **kwargs): + raise NotImplementedError("Helpful for debugging") + + +@register_lowering(torch.ops._inductor_test.realize) +def _realize(x): + x.realize() + return clone(x) + + +@register_lowering(torch.ops.inductor.resize_storage_bytes_) +def resize_storage_bytes_(variable, new_size): + variable.realize() + ir.ResizeStorageBytes(variable, new_size) + return variable + + +@register_lowering(torch.ops.aten.set_.source_Tensor) 
+def set__source_tensor(self, source_tensor): + self.realize() + source_tensor.realize() + return TensorBox.create(ir.SetSourceTensorKernel(self, source_tensor)) + + +if hasattr(torch.ops.fsdp, "copy_"): + + @register_lowering(torch.ops.fsdp.copy_.default) + def fsdp_copy_(dst, src): + if dst is src: + # dst.copy_(dst) can happen from the reinplacing pass + return dst + src = to_device(src, dst.get_device()) + src = to_dtype(src, dst.get_dtype()) + src = expand(src, dst.get_size()) + return mutate_to(dst, src) + + +@register_lowering(torch.ops.aten.resize) +def resize(x, size, *, memory_format=None): + assert isinstance(x, TensorBox) + assert isinstance(size, (list, tuple)) + + if memory_format is None: + memory_format = torch.contiguous_format + if memory_format == torch.preserve_format: + raise RuntimeError(f"unsupported memory format: {memory_format}") + + if memory_format == torch.channels_last: + assert len(size) == 4 + if memory_format == torch.channels_last_3d: + assert len(size) == 5 + + old_numel = x.get_numel() + dtype = x.get_dtype() + device = x.get_device_or_error() + + if isinstance(x.data, ir.BaseView): + x.data = x.data.unwrap_view() + + if ( + torch.are_deterministic_algorithms_enabled() + and torch.utils.deterministic.fill_uninitialized_memory # type: ignore[attr-defined] + ): + if is_float_dtype(dtype): + uninitialized_val = float("nan") + elif is_integer_dtype(dtype): + uninitialized_val = torch.iinfo(dtype).max + else: + uninitialized_val = True + else: + # using zero as that is what empty does + uninitialized_val = 0.0 + + if V.graph.sizevars.statically_known_equals(old_numel, 0): # type: ignore[arg-type] + return full(size, uninitialized_val, dtype=dtype, device=device) + + x_flat = as_strided( + x, + [ + old_numel, + ], + [ + 1, + ], + ) + flat_loader = x_flat.make_loader() + out_stride = ir.FlexibleLayout.stride_ordered_for_memory_format(size, memory_format) + out_indexer = ir.FixedLayout(device, dtype, size, out_stride).make_indexer() + + 
def inner_fn(idx): + flat_index = out_indexer(idx) + flat_index_expr = ops.index_expr(flat_index, torch.int64) + limit = ops.index_expr(old_numel, torch.int64) + mask = ops.lt(flat_index_expr, limit) + return ops.masked(mask, lambda: flat_loader([flat_index]), uninitialized_val) + + out = Pointwise.create( + device=device, dtype=dtype, inner_fn=inner_fn, ranges=list(size) + ) + return out + + +from torch._higher_order_ops.auto_functionalize import auto_functionalized + + +make_fallback(auto_functionalized) + + +@register_lowering(triton_kernel_wrapper_mutation) +def triton_kernel_wrap_( + *, + kernel_idx, + constant_args_idx, + grid, + tma_descriptor_metadata, + kwargs, +): + from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table + + constant_args = kernel_side_table.get_constant_args(constant_args_idx) + ir.UserDefinedTritonKernel( + kernel_idx=kernel_idx, + grid=grid, + tma_descriptor_metadata=tma_descriptor_metadata, + kernel_args={**kwargs, **constant_args}, + ) + return {key: val for key, val in kwargs.items() if isinstance(val, TensorBox)} + + +@register_lowering(torch.ops.higher_order.cond, type_promotion_kind=None) +def cond( + pred, true_fn, false_fn, operands +) -> list[Union[ir.TensorBox, ir.ShapeAsConstantBuffer]]: + # TODO: when graph_partition is enabled, skip - partitioning handles control flow + # we run into memory cleanup issue + if any(isinstance(x, IRNode) and is_triton(x) for x in [pred, *operands]): + msg = "control flow operator: torch.cond." 
+ if stack_trace := V.graph.current_node.meta.get("stack_trace", None): + msg = f"{msg} Found from : \n {stack_trace}" + V.graph.disable_cudagraphs_reason = msg + + result = ir.Conditional.create(pred, true_fn, false_fn, operands) + return list(map(TensorBox.create, result)) + + +@register_lowering(torch.ops.higher_order.while_loop, type_promotion_kind=None) +def while_loop(cond_fn, body_fn, carried_inputs, additional_inputs, stack_output=False): + # TODO: when graph_partition is enabled, skip - partitioning handles control flow + # we run into memory cleanup issue + if not config.graph_partition and any( + isinstance(x, IRNode) and is_triton(x) + for x in carried_inputs + additional_inputs + ): + msg = "control flow operator: torch.while_loop." + if stack_trace := V.graph.current_node.meta.get("stack_trace", None): + msg = f"{msg} Found from : \n {stack_trace}" + V.graph.disable_cudagraphs_reason = msg + + result = ir.WhileLoop.create( + cond_fn, body_fn, carried_inputs, additional_inputs, stack_output + ) + assert isinstance(result, Sequence) + return list(map(ir.WhileLoop._maybe_wrap_as_tensor_box, result)) + + +register_lowering( + torch.ops.higher_order.while_loop_stack_output, type_promotion_kind=None +)(functools.partial(while_loop, stack_output=True)) + + +@register_lowering(torch.ops.higher_order.invoke_subgraph, type_promotion_kind=None) +def invoke_subgraph(subgraph_fn: ir.Subgraph, identifier: str, *operands): + result = ir.InvokeSubgraph.create(subgraph_fn, *operands) + return list(map(TensorBox.create, result)) # type: ignore[call-overload] + + +def process_subgraph_nodes(graph_module: torch.fx.GraphModule, args: list[Any]): + """Process nodes from a FX graph by executing them through V.graph. 
+ + This is a common pattern for executing a subgraph's nodes: + - Placeholder nodes are mapped to the provided args + - Output nodes return their result + - Other nodes are executed via V.graph.run_node + + """ + output = None + + for i, node in enumerate(graph_module.graph.nodes): + if node.op == "placeholder": + assert node not in V.graph.env + V.graph.env[node] = args[i] + continue + elif node.op == "output": + output_args, kwargs = V.graph.fetch_args_kwargs_from_env(node) + output = torch.fx.Interpreter.output(V.graph, node, output_args, kwargs) + else: + assert node not in V.graph.env + V.graph.env[node] = V.graph.run_node(node) + + if output is None: + raise RuntimeError("No output node found in graph") + + return output + + +# Import the control_deps_op HOP for lowering +from torch._inductor.fx_passes.control_dependencies import control_deps + + +@register_lowering(control_deps, type_promotion_kind=None) +def control_deps_op_lowering(additional_deps, subgraph_fn, *args): + """ + Lower control_deps_op by ensuring dependencies are realized and tracking them. + + The control_deps_op HOP makes dependencies explicit in the graph. During lowering: + 1. Realize all additional dependencies to ensure they're computed + 2. Execute the target operation normally + 3. 
Track the dependencies for the scheduler + """ + # Realize all additional dependencies + dep_names = [] + for dep in additional_deps: + if not isinstance(dep, IRNode): + continue + + dep.realize() + dep_names.append(dep.get_name()) + + original_args = V.graph.current_node.args + arg_offset = 2 # first two args (additional_deps, subgraph) + assert len(args) + arg_offset == len(original_args) + + operation_len = len(V.graph.operations) + assert len(subgraph_fn.graph_module.graph.find_nodes(op="placeholder")) == len(args) + + # Process subgraph nodes using the shared helper + output = process_subgraph_nodes(subgraph_fn.graph_module, list(args)) + + assert output is not None and additional_deps + + # some operators, like wait_tensor, just return their input, + # so its more robust to add dep to the operation itself, + # otherwise you can have a cycle of + # a = coll + # b = control_deps(a, mm, ...) + # c = control_deps(b, wait, ...) + # if c == a, then you have a cycle. + for op in V.graph.operations[operation_len:]: + for dep_name in dep_names: + op_name = op.operation_name + assert op_name is not None + V.graph.additional_buffer_deps[op_name].add(dep_name) + + return output + + +@register_lowering(torch._higher_order_ops.invoke_quant, type_promotion_kind=None) +def invoke_quant_tracer(subgraph_fn: ir.Subgraph, *operands, scheme=None): + output = None + quant_options = V.graph.current_node.meta.get("quant_options", None) + assert quant_options is not None + + for i, node in enumerate(subgraph_fn.graph_module.graph.nodes): + if node.op == "placeholder": + V.graph.env[node] = operands[i] + continue + # todo getattr + elif node.op == "output": + args, kwargs = V.graph.fetch_args_kwargs_from_env(node) + + for v in itertools.chain(args, kwargs.values()): + v.realize() + + if quant_options.codegen_low_precision: + V.graph.low_precision_codegen_ops.add(v.get_operation_name()) + + V.graph.invoke_quant_ops.add(v.get_operation_name()) + + output = 
torch.fx.Interpreter.output(V.graph, node, args, kwargs) + else: + V.graph.env[node] = V.graph.run_node(node) + + return output + + +@register_lowering(associative_scan_op, type_promotion_kind=None) +def associative_scan( + combine_fn: ir.Subgraph, xs, additional_inputs: tuple[torch.Tensor] +): + from .subgraph_lowering import InputDescriptor, lower_pointwise_subgraph + + if len(additional_inputs) > 0: + raise RuntimeError( + "Unable to generate code for associative_scan op, because there are lifted arguments" + ) + + subgraph_inputs = [ + InputDescriptor(dtype=x.get_dtype(), device=x.get_device()) + for x in itertools.chain(xs, xs) + ] + lowered_combine_fn = lower_pointwise_subgraph(combine_fn, subgraph_inputs) # type: ignore[var-annotated] + + def wrapped_combine_fn(lhs, rhs): + return lowered_combine_fn( + *pytree.tree_leaves(lhs), + *pytree.tree_leaves(rhs), + ) + + kwargs = _make_scan_inner(xs[0], axis=0, dtype=None) + kwargs["dtypes"] = tuple(x.get_dtype() for x in xs) + kwargs["inner_fns"] = tuple(x.make_loader() for x in xs) + result = ir.Scan.create( + combine_fn=wrapped_combine_fn, + can_fallback_to_aten=False, + **kwargs, + ) + if result[0] is None: + raise RuntimeError("Unable to generate code for associative_scan op") + return result + + +@register_lowering(torch.ops.prims._sink_tokens.default) +def _sink_tokens(tokens): + return None + + +@register_lowering(torch.ops.prims._make_token.default) +def _make_token(): + return None + + +@register_lowering(torch.ops.higher_order.with_effects, type_promotion_kind=None) +def with_effects(token, op, *args, **kwargs): + """ + We lower the operator directly, and then we add StarDep dependencies to all + the newly created nodes in the graph. 
+ """ + from torch._higher_order_ops.effects import _get_effect, _get_schema + + # Get effect type + effect_type = _get_effect(op) + if effect_type is None and op is torch.ops.higher_order.invoke_subgraph: + from torch._guards import InvokeSubgraphCache, TracingContext + + tracing_ctx = TracingContext.try_get() + if tracing_ctx: + invoke_subgraph_cache = tracing_ctx.hop_dispatch_set_cache.get_cache( + torch.ops.higher_order.invoke_subgraph + ) + if invoke_subgraph_cache: + assert isinstance(invoke_subgraph_cache, InvokeSubgraphCache) + # args[1] is identifier + effects = invoke_subgraph_cache.get_effects(args[1]) + if effects: + assert len(effects) == 1, "Multiple effects NYI" + effect_type = next(iter(effects)) + + # Track operations before + operation_len = len(V.graph.operations) + + # Lower the op + if op in lowerings: + result = lowerings[op](*args, **kwargs) + # Realize so that we can get the ops to show up in V.graph.operations + pytree.tree_map_only(TensorBox, lambda a: a.realize(), result) + else: + + def wrap_tensors(x): + return TensorBox.create(x) if isinstance(x, ir.IRNode) else x + + result = pytree.tree_map( + wrap_tensors, ir.FallbackKernel.create(op, *args, **kwargs) + ) + + # Get all the operations created during the lowering above, and add StarDeps + # to the previous node with the same effect + assert len(V.graph.operations[operation_len:]) > 0, ( + f"No operation nodes were generated when lowering effectful operator {op}." 
+ ) + if effect_type: + prev_effect_buffer = V.graph.effectful_ops.get(effect_type) + for new_op in V.graph.operations[operation_len:]: + # Patch has_side_effects to return True + new_op.has_side_effects = lambda: True # pyrefly: ignore[missing-attribute] + if prev_effect_buffer: + op_name = new_op.get_name() # pyrefly: ignore[missing-attribute] + V.graph.additional_star_deps[op_name].add(prev_effect_buffer.get_name()) + # Update the effectful ops chain to point to the latest operation + V.graph.effectful_ops[effect_type] = ( # pyrefly: ignore[missing-attribute] + new_op # pyrefly: ignore[unsupported-operation] + ) + + try: + args, kwargs = pytree.tree_map_only( + ir.TorchBindObject, lambda a: a.get_value(), (args, kwargs) + ) + schema = _get_schema(op, args, kwargs) + except RuntimeError as e: + error_msg = str(e) + log.warning( + "Failed to get schema for %s: %s. Assuming list output", op, error_msg + ) + return (token, *result) + + if len(schema.returns) == 0: + return (token, result) + elif len(schema.returns) == 1: + return (token, result) + else: + return (token, *result) + + +from .comm_lowering import register_comm_lowerings + + +register_comm_lowerings() + + +@register_lowering(inductor_prims.prepare_softmax_online, type_promotion_kind=None) +def prepare_softmax_online(x, dim): + """ + Lowering inductor_prims.prepare_softmax_online to compute max/sum in one pass if no split is needed. 
+ """ + kwargs = _make_reduction_inner( + x, axis=dim, keepdims=True, dtype=None, override_return_dtype=None + ) + + reduction_ranges = kwargs["reduction_ranges"] + rnumel = V.graph.sizevars.simplify(sympy_product(reduction_ranges)) + hint, num_split = ir.Reduction.num_splits( + **kwargs, + reduction_type="online_softmax_reduce", # type: ignore[arg-type] + reduction_numel=rnumel, + ) + + if num_split == 1 and V.graph.sizevars.statically_known_geq( + rnumel, config.unroll_reductions_threshold + ): + max_tensor, sum_tensor = OnlineSoftmaxReduction.create( + input_node=x, num_output=2, reduction_hint=hint, **kwargs + ) + return max_tensor, sum_tensor + else: + # Note: [Split online_softmax_reduce] + # We don't split reduction for online_softmax_reduce for now. + # On one hand, supporting split reduction makes things complex since + # the split out reuctions requires 2 inputs rather than one. + # On the other hand, during training the online_softmax_reduce should + # usually don't requires a split due to large batch size + # (more specifically batch size times sequence length). + # We should support split reduction if we find legit use cases to + # motivate the work. + # + # TODO: does inference need split online_softmax_reduce? + + warnings.warn( + textwrap.dedent( + """ + Online softmax is disabled on the fly since Inductor decides to + split the reduction. Cut an issue to PyTorch if this is an + important use case and you want to speed it up with online + softmax. + """ + ) + ) + amax = reduce_amax(x, dim, keepdims=True) + exp = lowerings[aten.exp](sub(x, amax)) + xsum = sum_(exp, dim, keepdims=True) + return amax, xsum + + +# populate lowerings defined in kernel/* +from . import kernel + + +import_submodule(kernel) + +from . import quantized_lowerings + + +quantized_lowerings.register_quantized_ops() +quantized_lowerings.register_woq_mm_ops() + +from . import mkldnn_lowerings + + +mkldnn_lowerings.register_onednn_fusion_ops() + +from . 
import jagged_lowerings + + +jagged_lowerings.register_jagged_ops() + + +@contextlib.contextmanager +def force_fallback(op: torch._ops.OpOverload): + """ + A context manager to force fallback an op. Used in unit test + for FallbackKernel. + """ + assert isinstance(op, torch._ops.OpOverload), ( + "Only OpOverload to make the clean up easier" + ) + old_handler = lowerings.get(op) + try: + register_lowering(op)(fallback_handler(op)) + yield + finally: + if old_handler: + lowerings[op] = old_handler + else: + lowerings.pop(op) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/memory.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/memory.py new file mode 100644 index 0000000000000000000000000000000000000000..4f587b23cda0c8f5993384d2b4b9a1f8521c7c6a --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/memory.py @@ -0,0 +1,1108 @@ +from __future__ import annotations + +import collections +import dataclasses +import heapq +import logging +from typing import Optional, TYPE_CHECKING, TypedDict, Union + +import torch +from torch._environment import is_fbcode +from torch._utils_internal import signpost_event +from torch.utils._ordered_set import OrderedSet + +from . 
import config +from .ir import MultiOutputLayout, NoneLayout +from .utils import get_dtype_size, is_nonfreeable_buffers +from .virtualized import V + + +if TYPE_CHECKING: + from collections.abc import Callable + + from .dependencies import Dep + from .scheduler import BaseSchedulerNode, SchedulerBuffer + +from .dependencies import WeakDep + + +torch_log = logging.getLogger(__name__) + + +@dataclasses.dataclass +class PeakMemoryResult: + order: list[BaseSchedulerNode] + peak_memory: int + method: str + + +@dataclasses.dataclass +class MemoryPlanningInfoForBuffer: + size_alloc: int = 0 + size_free: int = 0 + # succ_nodes used for buffer lifetime/freeing (excludes is_fake WeakDeps) + succ_nodes: OrderedSet[BaseSchedulerNode] = dataclasses.field( + default_factory=OrderedSet + ) + # succ_nodes used for node ordering (includes is_fake WeakDeps) + succ_nodes_for_ordering: OrderedSet[BaseSchedulerNode] = dataclasses.field( + default_factory=OrderedSet + ) + + def __post_init__(self) -> None: + torch._check( + len(self.succ_nodes) <= len(self.succ_nodes_for_ordering), + lambda: f"succ_nodes must be a subset of succ_nodes_for_ordering. 
" + f"len(succ_nodes)={len(self.succ_nodes)}, len(succ_nodes_for_ordering)={len(self.succ_nodes_for_ordering)}", + ) + + +@dataclasses.dataclass +class MemoryPlanningInfoForNode: + index: int = 0 + size: int = 0 + pred_buffers: OrderedSet[Union[SchedulerBuffer, FreeableInputBuffer]] = ( + dataclasses.field(default_factory=OrderedSet) + ) + pred_nodes: OrderedSet[BaseSchedulerNode] = dataclasses.field( + default_factory=OrderedSet + ) + succ_nodes: OrderedSet[BaseSchedulerNode] = dataclasses.field( + default_factory=OrderedSet + ) + + +@dataclasses.dataclass +class FreeableInputBuffer: + name: str + mpi_buffer: MemoryPlanningInfoForBuffer = dataclasses.field( + default_factory=MemoryPlanningInfoForBuffer + ) + + def get_name(self) -> str: + return self.name + + def __hash__(self) -> int: + return hash(self.name) + + +def get_freeable_input_buf( + nodes: list[BaseSchedulerNode], + graph_inputs: OrderedSet[str], +) -> dict[str, FreeableInputBuffer]: + """ + Create and keep track of all input buffers that can be freed during the program + + Returns: + A dictionary containing all freeable input buffers, keyed by their names. 
+ """ + + def _dep_size_hint(dep: Dep) -> int: + return V.graph.get_dep_size_hint(dep) + + # get freeable input buffers' successor nodes for memory lifetime (excludes is_fake WeakDeps) + # and for ordering (includes all deps) + dep_name_to_succ_nodes: dict[str, OrderedSet[BaseSchedulerNode]] = ( + collections.defaultdict(OrderedSet) + ) + dep_name_to_succ_nodes_for_ordering: dict[str, OrderedSet[BaseSchedulerNode]] = ( + collections.defaultdict(OrderedSet) + ) + dep_name_to_size: dict[str, int] = dict() + + for node in nodes: + for dep in node.read_writes.reads: + if dep.name in graph_inputs: + if not is_nonfreeable_buffers(dep): + # All deps contribute to ordering, but fake weak deps do not contribute to + # memory liveness + dep_name_to_succ_nodes_for_ordering[dep.name].add(node) + dep_name_to_size[dep.name] = _dep_size_hint(dep) + if not (isinstance(dep, WeakDep) and dep.is_fake): + dep_name_to_succ_nodes[dep.name].add(node) + + # create FreeableInputBuffer objects and add them to the returned dictionary + name_to_freeable_input_buf: dict[str, FreeableInputBuffer] = dict() + for dep_name in dep_name_to_succ_nodes_for_ordering: + name_to_freeable_input_buf[dep_name] = FreeableInputBuffer( + dep_name, + MemoryPlanningInfoForBuffer( + size_free=dep_name_to_size[dep_name], + succ_nodes=dep_name_to_succ_nodes[dep_name], + succ_nodes_for_ordering=dep_name_to_succ_nodes_for_ordering[dep_name], + ), + ) + return name_to_freeable_input_buf + + +def compute_size_for_scheduler_buffer( + name_to_buf: dict[str, SchedulerBuffer], +) -> dict[str, tuple[int, int]]: + """ + Compute the size of each scheduler buffer, including (1) memory allocated when + it is created and (2) memory deallocated when it is freed. + + We specially handle the case of MultiOutputLayout. + Consider the following case: + buf0 = some_ops_with_multi_outputs(...) 
+ buf1 = buf0[0] # assume 10 bytes + buf2 = buf0[1] # assume 20 bytes + In such cases, + buf0: at creation, 30 bytes allocated, when deleted, 0 bytes freed + buf1: at creation, 0 bytes allocated, when deleted, 10 bytes freed + buf2: at creation, 0 bytes allocated, when deleted, 20 bytes freed + + When an operation mutates a buffer in-place, the scheduler creates a new buffer name + to track the "before" and "after" states, even though they share the same memory. + + The mutated buffer represents a rename with zero allocation and deallocation cost. + During dependency tracking, we transfer dependencies from the mutated name back to + the original buffer, ensuring the original memory is only freed when all aliases + are done. + + This handles cases where a buffer has multiple non-overlapping aliases - rather than + trying to assign free costs to individual aliases, we forward all alias dependencies + to the original buffer. + + Consider: + buf0 = op0() + buf1 = mutation_op_(buf0) + del buf0 + ... + op(buf1) + del buf1 + + The only memory events are the creation prior to op0, and the deletion following buf1. + + Returns: + A dictionary mapping a scheduler buffer to a tuple of (size_alloc, size_free). 
+ """ + from .ir import MultiOutput + from .scheduler import OutputNode + + sched_buf_to_size: dict[str, tuple[int, int]] = dict() + + def _compute_and_update_buf_size( + sched_buf: SchedulerBuffer, user_of_MultiOutputLayout: bool = False + ) -> int: + if sched_buf.get_name() in V.graph.scheduler.mutation_real_name: + sched_buf_to_size[sched_buf.get_name()] = (0, 0) + return 0 + elif isinstance(sched_buf.node.layout, NoneLayout): + sched_buf_to_size[sched_buf.get_name()] = (0, 0) + return 0 + elif isinstance(sched_buf.node.layout, MultiOutputLayout): + size_alloc = 0 + for user in sched_buf.users: + if isinstance(user.node, OutputNode): + continue + for buf in user.node.get_outputs(): + if isinstance(buf.node, MultiOutput): + size_alloc += _compute_and_update_buf_size(buf, True) + sched_buf_to_size[sched_buf.get_name()] = ( + 0 if user_of_MultiOutputLayout else size_alloc, + 0, + ) + return size_alloc + else: + buf_size = V.graph.sizevars.size_hint( + sched_buf.node.get_numel(), fallback=0 + ) * get_dtype_size(sched_buf.node.get_dtype()) + sched_buf_to_size[sched_buf.get_name()] = ( + 0 if user_of_MultiOutputLayout else buf_size, + buf_size, + ) + return buf_size + + for sched_buf in name_to_buf.values(): + # skip if sched_buf is already processed as an user of another SchedulerBuffer + # whose layout is of the type MultiOutputLayout + if sched_buf.get_name() not in sched_buf_to_size: + _compute_and_update_buf_size(sched_buf) + + return sched_buf_to_size + + +def assign_memory_planning_info_for_scheduler_buffers( + nodes: list[BaseSchedulerNode], + name_to_buf: dict[str, SchedulerBuffer], +) -> None: + """ + For each SchedulerBuffer, assign its size info and successor nodes. + A buffer's successor nodes determines when a buffer can be freed. 
+ """ + # get buffer sizes + sched_buf_to_size = compute_size_for_scheduler_buffer(name_to_buf) + + # get buffer's successor nodes for memory lifetime (excludes is_fake WeakDeps) + # and for ordering (includes all deps) + dep_name_to_succ_nodes: dict[str, OrderedSet[BaseSchedulerNode]] = ( + collections.defaultdict(OrderedSet) + ) + dep_name_to_succ_nodes_for_ordering: dict[str, OrderedSet[BaseSchedulerNode]] = ( + collections.defaultdict(OrderedSet) + ) + for node in nodes: + for dep in node.unmet_dependencies: + # All deps contribute to ordering, but fake weak deps do not contribute to + # memory liveness + dep_name_to_succ_nodes_for_ordering[dep.name].add(node) + if not (isinstance(dep, WeakDep) and dep.is_fake): + dep_name_to_succ_nodes[dep.name].add(node) + + # iterate in reverse, so dependencies are picked up transitively. + for mutating_buf_name, real_buf_name in reversed( + V.graph.scheduler.mutation_real_name.items() + ): + dep_name_to_succ_nodes[real_buf_name] |= dep_name_to_succ_nodes[ + mutating_buf_name + ] + dep_name_to_succ_nodes_for_ordering[real_buf_name] |= ( + dep_name_to_succ_nodes_for_ordering[mutating_buf_name] + ) + + # populate the MemoryPlanningInfoForBuffer attribute to each scheduler buffer + # note: there are scheduler buffers not in dep_name_to_succ_nodes (e.g., graph outputs) + for buf_name in name_to_buf: + name_to_buf[buf_name].mpi_buffer = MemoryPlanningInfoForBuffer( + size_alloc=sched_buf_to_size[buf_name][0], + size_free=sched_buf_to_size[buf_name][1], + succ_nodes=dep_name_to_succ_nodes[buf_name], + succ_nodes_for_ordering=dep_name_to_succ_nodes_for_ordering[buf_name], + ) + + +def assign_memory_planning_info_for_scheduler_nodes( + nodes: list[BaseSchedulerNode], + name_to_fused_node: dict[str, BaseSchedulerNode], + name_to_buf: dict[str, SchedulerBuffer], + name_to_freeable_input_buf: dict[str, FreeableInputBuffer], +) -> None: + """ + Assign to each scheduler node its predecessor and successor nodes. 
+ """ + + node_to_pred_nodes: dict[BaseSchedulerNode, OrderedSet[BaseSchedulerNode]] = ( + collections.defaultdict(OrderedSet) + ) + node_to_succ_nodes: dict[BaseSchedulerNode, OrderedSet[BaseSchedulerNode]] = {} + node_to_pred_buffers: dict[ + BaseSchedulerNode, OrderedSet[SchedulerBuffer | FreeableInputBuffer] + ] = collections.defaultdict(OrderedSet) + + # collect all predecessors using existing successor mappings + for node in nodes: + succ_nodes = OrderedSet( + succ_node + for buffer in node.get_outputs() + for succ_node in buffer.mpi_buffer.succ_nodes_for_ordering + ) + node_to_succ_nodes[node] = succ_nodes + + # For each successor, add current node as its predecessor + for succ_node in succ_nodes: + node_to_pred_nodes[succ_node].add(node) + + # For each output buffer, add it as predecessor to its successor nodes + # Use succ_nodes (not succ_nodes_for_ordering) since pred_buffers is used + # for memory lifetime tracking, not ordering + for buffer in node.get_outputs(): + for succ_node in buffer.mpi_buffer.succ_nodes: + node_to_pred_buffers[succ_node].add(buffer) + + for freeable_buffer in name_to_freeable_input_buf.values(): + for succ_node in freeable_buffer.mpi_buffer.succ_nodes: + node_to_pred_buffers[succ_node].add(freeable_buffer) + + # Second pass: assign memory planning info using completed predecessor mappings + for index, node in enumerate(nodes): + size_alloc = sum(buffer.mpi_buffer.size_alloc for buffer in node.get_outputs()) + succ_nodes = node_to_succ_nodes[node] + pred_nodes = node_to_pred_nodes[node] + + # make sure we do not make node a successor or predecessor of itself + succ_nodes.discard(node) + pred_nodes.discard(node) + + node.mpi_node = MemoryPlanningInfoForNode( + index=index, + size=size_alloc, + pred_buffers=node_to_pred_buffers[node], + pred_nodes=node_to_pred_nodes[node], + succ_nodes=succ_nodes, + ) + + +# map each scheduler buffer to its size, start step, and end step +@dataclasses.dataclass +class BufferInfo: + buffer: 
Union[SchedulerBuffer, FreeableInputBuffer] + size_alloc: int + size_free: int + start_step: int + end_step: int + + +def compute_memory_timeline( + nodes: list[BaseSchedulerNode], + name_to_freeable_input_buf: dict[str, FreeableInputBuffer], + graph_outputs: OrderedSet[str], +) -> tuple[ + list[BufferInfo], + dict[BaseSchedulerNode, int], + dict[Union[FreeableInputBuffer, SchedulerBuffer], BaseSchedulerNode], +]: + """ + Compute buffer allocation and deallocation sizes and map their + lifetime to the node schedule + """ + + # get the execution step of each node, this will be used to determine + # the end_step of buffers + node_to_step: dict[BaseSchedulerNode, int] = { + node: step for step, node in enumerate(nodes) + } + + # get buffers' size and liveliness information + buf_info_list: list[BufferInfo] = [] + buf_to_snode_last_use: dict[ + Union[FreeableInputBuffer, SchedulerBuffer], BaseSchedulerNode + ] = {} + + def _get_end_step_and_snode( + buf: Union[FreeableInputBuffer, SchedulerBuffer], + ) -> tuple[int, Optional[BaseSchedulerNode]]: + max_step: int = -1 + max_step_snode: Optional[BaseSchedulerNode] = None + succ_nodes = buf.mpi_buffer.succ_nodes + if succ_nodes: + for succ_node in succ_nodes: + step = node_to_step[succ_node] + if step > max_step: + max_step = step + max_step_snode = succ_node + assert max_step_snode is not None + return max_step, max_step_snode + + # 1. for freeable input buffers + for buf_name, input_buf in name_to_freeable_input_buf.items(): + end_step = -1 + if buf_name not in graph_outputs: + end_step, end_step_snode = _get_end_step_and_snode(input_buf) + assert end_step_snode is not None + buf_to_snode_last_use[input_buf] = end_step_snode + + buf_info_list.append( + BufferInfo( + input_buf, + input_buf.mpi_buffer.size_free, + input_buf.mpi_buffer.size_free, + 0, + end_step, + ) + ) + + # 2. 
for scheduler buffers + for step, node in enumerate(nodes): + for sched_buf in node.get_outputs(): + # note: it is possible for a non-graph-output sched_buf to have no succ_nodes and + # to be only used by its defining op (e.g., due to fusion when all consumers of + # the buffer are fused with its defining op). In such cases, end_step is step. + buf_name = sched_buf.get_name() + end_step = -1 + if buf_name not in graph_outputs: + end_step, end_step_snode = _get_end_step_and_snode(sched_buf) + if end_step == -1: + end_step = step + buf_to_snode_last_use[sched_buf] = node + else: + assert end_step_snode is not None + buf_to_snode_last_use[sched_buf] = end_step_snode + + buf_info_list.append( + BufferInfo( + sched_buf, + sched_buf.mpi_buffer.size_alloc, + sched_buf.mpi_buffer.size_free, + step, + end_step, + ) + ) + + return buf_info_list, node_to_step, buf_to_snode_last_use + + +def estimate_peak_memory( + nodes: list[BaseSchedulerNode], + name_to_freeable_input_buf: dict[str, FreeableInputBuffer], + graph_outputs: OrderedSet[str], +) -> tuple[int, list[int]]: + """ + Given a list of nodes in their execution order, estimate the peak memory, by + keeping track of the liveliness of SchedulerBuffers and FreeableInputBuffers. + + Returns: + int: peak memory + List[int]: memory usage at each node (or each step). 
+ """ + + buf_info_list, _, _ = compute_memory_timeline( + nodes, name_to_freeable_input_buf, graph_outputs + ) + + # incremental memory changes at each step + memory = [0 for _ in range(len(nodes) + 1)] + + # for each buffer, update memory when created and when freed + for buf_info in buf_info_list: + memory[buf_info.start_step] += buf_info.size_alloc + memory[buf_info.end_step + 1] -= buf_info.size_free + + # get peak memory by compute the cumulative memories + max_memory = 0 + cur_memory = 0 + memories_at_nodes = [] + for t in range(len(nodes) + 1): + cur_memory += memory[t] + memories_at_nodes.append(cur_memory) + max_memory = max(max_memory, cur_memory) + + return (max_memory, memories_at_nodes) + + +@dataclasses.dataclass +class SNodeMemory: + size_alloc: int + size_free: int + + +def estimate_peak_memory_allocfree( + nodes: list[BaseSchedulerNode], + name_to_freeable_input_buf: dict[str, FreeableInputBuffer], + graph_outputs: OrderedSet[str], +) -> tuple[ + int, + list[tuple[int, int]], + dict[BaseSchedulerNode, SNodeMemory], + dict[Union[FreeableInputBuffer, SchedulerBuffer], BaseSchedulerNode], +]: + """ + Alternative version of estimate_peak_memory, that respects the fact, + that every SchedulerNode has multiple phases: + 1. alloc ( outputs ) + 2. run_kernel + 3. dealloc last_use buffers + estimate_peak_memory collapses memory into one value: size_alloc - size_free + While peak memory happens after alloc. + + Duplicating the code to not migrate all callsites at once, + In future usages of estimate_peak_memory will migrate to this version. 
+ """ + + buf_info_list, _, buf_to_snode_last_use = compute_memory_timeline( + nodes, name_to_freeable_input_buf, graph_outputs + ) + + # incremental memory changes at each step + step_idx_allocfree = [SNodeMemory(0, 0) for _ in range(len(nodes))] + + # for each buffer, update memory when created and when freed + for buf_info in buf_info_list: + step_idx_allocfree[buf_info.start_step].size_alloc += buf_info.size_alloc + if buf_info.end_step != -1: + step_idx_allocfree[buf_info.end_step].size_free += buf_info.size_free + + snodes_allocfree = {} + for i, node in enumerate(nodes): + snodes_allocfree[node] = step_idx_allocfree[i] + + max_memory = 0 + cur_memory = 0 + snodes_curr_memory = [] + for t in range(len(nodes)): + alloc = step_idx_allocfree[t].size_alloc + free = step_idx_allocfree[t].size_free + cur_memory += alloc + post_alloc = cur_memory + max_memory = max(max_memory, cur_memory) + cur_memory -= free + post_free = cur_memory + snodes_curr_memory.append((post_alloc, post_free)) + + return ( + max_memory, + snodes_curr_memory, + snodes_allocfree, + buf_to_snode_last_use, + ) + + +def topological_sort_lpmf( + nodes: list[BaseSchedulerNode], + name_to_freeable_input_buf: dict[str, FreeableInputBuffer], + name_to_buf: dict[str, SchedulerBuffer], + graph_outputs: OrderedSet[str], +) -> list[BaseSchedulerNode]: + """ + A bfs-based greedy topological order. LPMF stands for "Least Peak Memory First". + + The idea is from this paper: + Buffer memory optimization for video codec application modeled in Simulink + https://www.cs.york.ac.uk/rts/docs/DAC-1964-2006/PAPERS/2006/DAC06/PDFFILES/P0689.PDF + + The algorithm maintains the max memory so far. + At every iteration, for each scheduleable node, it computes: + - how much memory needs to be allocated for the output buffers of this node; + - how much memory can be freed as a result of executing this node. 
+ This gives us two values for each node: + (1) mem1: memory during the execution of the node; + (2) mem2: memory after executing the node, after some input buffers are freed. + The greedy approach select as follows: + (i) if there are nodes whose mem1 values are below the max memory so far, + then pick the node with the lowest mem2 value; + (ii) otherwise, pick the one with the lowest mem1 value. + """ + + class NodeInfo(TypedDict): + indegree: int + memory_to_free: int + + class BufferInfo(TypedDict): + outdegree: int + + node_info: dict[BaseSchedulerNode, NodeInfo] = dict() + buf_info: dict[Union[SchedulerBuffer, FreeableInputBuffer], BufferInfo] = dict() + + # compute nodes' number of unmet dependencies (for schedulability) + # initialize the list of nodes ready to be scheduled + nodes_to_schedule: OrderedSet[BaseSchedulerNode] = OrderedSet() + for node in nodes: + node_info[node] = { + "indegree": len(node.mpi_node.pred_nodes), + "memory_to_free": 0, + } + if node_info[node]["indegree"] == 0: + nodes_to_schedule.add(node) + + # compute buffers' number of unmet successors (used to decide when to free) + for buf in list(name_to_buf.values()) + list(name_to_freeable_input_buf.values()): + buf_info[buf] = { + "outdegree": len(buf.mpi_buffer.succ_nodes) + + (1 if buf.get_name() in graph_outputs else 0) + } + + # initialize memory estimations + live_memory = sum( + input_buf.mpi_buffer.size_free + for input_buf in name_to_freeable_input_buf.values() + ) + + # this is the total output memory, which is a lower bound for peak memory + # we do not include the memory of non freeable input buffers + output_memory = 0 + for buf_name in graph_outputs: + if buf_name in name_to_buf: + output_memory += name_to_buf[buf_name].mpi_buffer.size_free + elif buf_name in name_to_freeable_input_buf: + output_memory += name_to_freeable_input_buf[buf_name].mpi_buffer.size_free + max_memory = max(live_memory, output_memory) + memory_gap = max_memory - live_memory + + # compute the amount 
of memory that is allocated when a node is scheduled + # and the amount of memory that can be freed when a node is scheduled + for node in nodes: + # 1. if a buffer read by this node is last used by this node + for buf in node.mpi_node.pred_buffers: + if buf_info[buf]["outdegree"] == 1: + node_info[node]["memory_to_free"] += buf.mpi_buffer.size_free + # 2. if a buffer written by this node is used internally and not used later + for buf in node.get_outputs(): + if buf_info[buf]["outdegree"] == 0: + node_info[node]["memory_to_free"] += buf.mpi_buffer.size_free + + # schedule nodes one at a time + schedule: list[BaseSchedulerNode] = [] + size_threshold = config.size_threshold_for_succ_based_strategy + num_iters: int = 0 + while num_iters < len(nodes) and nodes_to_schedule: + # select a node to schedule: + if ( + size_threshold > 0 + and min(node.mpi_node.size for node in nodes_to_schedule) > size_threshold + ): + selected_node = min( + nodes_to_schedule, + key=lambda node: min( + ( + succ_node.mpi_node.index + for succ_node in node.mpi_node.succ_nodes + ), + default=len(nodes), + ), + ) + else: + selected_node = min( + nodes_to_schedule, + key=lambda node: ( + node.mpi_node.size if node.mpi_node.size > memory_gap else 0, + node.mpi_node.size - node_info[node]["memory_to_free"], + node.mpi_node.index, + ), + ) + nodes_to_schedule.remove(selected_node) + schedule.append(selected_node) + num_iters += 1 + + # update memory usage + live_memory += selected_node.mpi_node.size + max_memory = max(max_memory, live_memory) + live_memory -= node_info[selected_node]["memory_to_free"] + memory_gap = max_memory - live_memory + + # update successor nodes and nodes_to_schedule + for succ_node in selected_node.mpi_node.succ_nodes: + assert node_info[succ_node]["indegree"] > 0 + node_info[succ_node]["indegree"] -= 1 + if node_info[succ_node]["indegree"] == 0: + nodes_to_schedule.add(succ_node) + + # update predecessor nodes + for buf in selected_node.mpi_node.pred_buffers: + assert 
buf_info[buf]["outdegree"] > 0 + buf_info[buf]["outdegree"] -= 1 + if buf_info[buf]["outdegree"] == 1: + for succ_node in buf.mpi_buffer.succ_nodes: + node_info[succ_node]["memory_to_free"] += buf.mpi_buffer.size_free + + if num_iters > len(nodes): + raise RuntimeError("Failed to schedule, while loop ran too long for lpmf") + + return schedule + + +def topological_sort_bfs(nodes: list[BaseSchedulerNode]) -> list[BaseSchedulerNode]: + """ + A BFS topological sort that selects nodes whose dependencies are executed the + earliest. This follows a FIFO idea. Specifically, at every iteration, for each node + that is schedulable, we gather the order in which its predecessor nodes are executed, + and this sorted list of execution orders of predecessor nodes defines the priority. + We select the node whose predecessors nodes are executed the earliest. The FIFO + idea aims to reduce the liveness duration of buffers created. + """ + + class NodeInfo(TypedDict): + indegree: int + order: int + + node_info: dict[BaseSchedulerNode, NodeInfo] = dict() + + @dataclasses.dataclass + class NodeWithPriority: + priority: list[int] + node: BaseSchedulerNode + + def __lt__(self, other: NodeWithPriority) -> bool: + if self.priority == other.priority: + return self.node.mpi_node.index < other.node.mpi_node.index + return self.priority < other.priority + + def _node_priority(node: BaseSchedulerNode) -> list[int]: + # priority is the order in which predecessor nodes are executed + assert node_info[node]["indegree"] == 0 + exec_orders = sorted( + OrderedSet( + node_info[pred_node]["order"] for pred_node in node.mpi_node.pred_nodes + ) + ) + return exec_orders + + # compute nodes' number of unmet dependencies (for schedulability) + # initialize the list of nodes ready to be scheduled + nodes_to_schedule: list[NodeWithPriority] = [] + for node in nodes: + node_info[node] = {"indegree": len(node.mpi_node.pred_nodes), "order": -1} + if node_info[node]["indegree"] == 0: + heapq.heappush( + 
nodes_to_schedule, NodeWithPriority(_node_priority(node), node) + ) + + # schedule nodes one at a time + schedule: list[BaseSchedulerNode] = [] + num_iters: int = 0 + while num_iters < len(nodes) and nodes_to_schedule: + # select a node to schedule + selected_node = heapq.heappop(nodes_to_schedule).node + node_info[selected_node]["order"] = len(schedule) + schedule.append(selected_node) + num_iters += 1 + + # update successor nodes and nodes_to_schedule + for succ_node in selected_node.mpi_node.succ_nodes: + assert node_info[succ_node]["indegree"] > 0 + node_info[succ_node]["indegree"] -= 1 + if node_info[succ_node]["indegree"] == 0: + heapq.heappush( + nodes_to_schedule, + NodeWithPriority(_node_priority(succ_node), succ_node), + ) + + if num_iters > len(nodes): + raise RuntimeError("Failed to schedule, while loop ran too long for bfs") + + return schedule + + +def topological_sort_dfs(nodes: list[BaseSchedulerNode]) -> list[BaseSchedulerNode]: + """ + This is a DFS topological sort. The setup is similar to `topological_sort_schedule` + in scheduler.py. The difference is the order nodes are visited in the outer loop. + In `topological_sort_schedule`, nodes are visited in their original order. + In this function, nodes are visited based on their priority -- for each node, we + compute the total memory of all buffers it reads from or writes to, and we visit + the nodes in ascending order of this priority. 
+ """ + seen: OrderedSet[BaseSchedulerNode] = OrderedSet() + name_to_node: dict[str, BaseSchedulerNode] = dict() + result: list[BaseSchedulerNode] = [] + size_with_reads: dict[BaseSchedulerNode, int] = dict() + + def visit(n: BaseSchedulerNode) -> None: + if n not in seen: + seen.add(n) + dep_nodes = [ + name_to_node[dep.name] + for dep in n.unmet_dependencies + if dep.name in name_to_node + ] + for node in sorted( + dep_nodes, key=lambda n: (size_with_reads[n], n.mpi_node.index) + ): + visit(node) + result.append(n) + + for node in nodes: + for name in node.get_buffer_names(): + name_to_node[name] = node + + for node in nodes: + size_with_reads[node] = node.mpi_node.size + sum( + pred_buf.mpi_buffer.size_free for pred_buf in node.mpi_node.pred_buffers + ) + for node in sorted(nodes, key=lambda n: (size_with_reads[n], n.mpi_node.index)): + visit(node) + + return result + + +def validate_graph_acyclic(nodes: list[BaseSchedulerNode]) -> None: + """ + Validate that the graph is acyclic by checking predecessor relationships. 
+ + Raises: + RuntimeError: If a cycle is detected in the graph + """ + # DFS coloring scheme for cycle detection: + # WHITE (0): Node has not been visited yet + # GRAY (1): Node is currently being processed (in the recursion stack) + # BLACK (2): Node has been completely processed (finished exploring all its predecessors) + # A back edge (cycle) is detected when we encounter a GRAY node during DFS traversal + WHITE, GRAY, BLACK = 0, 1, 2 + color = dict.fromkeys(nodes, WHITE) + path: list[BaseSchedulerNode] = [] # Track current DFS path + + def dfs_visit(node: BaseSchedulerNode) -> None: + if color[node] == BLACK: + return + + if color[node] == GRAY: + path.append(node) + path_info = " -> ".join([node.get_name() for node in path]) + + raise RuntimeError( + f"Cycle detected in memory planning graph" + f"Path containing cycle (i -> j: j is a dependency of i): {path_info} " + f"This indicates invalid dependency relationships in the scheduler graph" + ) + + color[node] = GRAY + path.append(node) + + for pred_node in node.mpi_node.pred_nodes: + assert pred_node != node + dfs_visit(pred_node) + + path.pop() + color[node] = BLACK + + # Start DFS from all unvisited nodes + for node in nodes: + if color[node] == WHITE: + dfs_visit(node) + + +def validate_unique_buffer_names( + nodes: list[BaseSchedulerNode], + name_to_buf: dict[str, SchedulerBuffer], + name_to_freeable_input_buf: dict[str, FreeableInputBuffer], +) -> None: + """ + Validate that for each node's output buffer, the name_to_buf mapping is correct. + For each output buffer buf, we should have name_to_buf[buf.get_name()] == buf. + Also validate that no buffer names overlap with freeable input buffer names. 
+ + Raises: + RuntimeError: If buffer name mapping is incorrect or names overlap + """ + for node in nodes: + for buf in node.get_outputs(): + buf_name = buf.get_name() + + # Check if buffer name exists in the mapping + if buf_name not in name_to_buf: + raise RuntimeError( + f"{buf_name} from {node.get_name()} is not found in name_to_buf mapping." + f" This indicates a missing buffer mapping." + ) + + # Check if the mapping points to the correct buffer object + if name_to_buf[buf_name] != buf: + raise RuntimeError( + f"Buffer name mapping is incorrect for '{buf_name}'." + f"Expected name_to_buf['{buf_name}'] to be {buf.debug_str()}" + f"but got {name_to_buf[buf_name].debug_str()}" + f"This indicates some buffers share the same name" + ) + + # Check if buffer name conflicts with freeable input buffer names + if buf_name in name_to_freeable_input_buf: + raise RuntimeError( + f"Buffer name conflict detected: '{buf_name}' from node {node.get_name()} " + f"is also used as a freeable input buffer name. " + ) + + +def prepare_planning_info( + nodes: list[BaseSchedulerNode], + name_to_buf: dict[str, SchedulerBuffer], + name_to_fused_node: dict[str, BaseSchedulerNode], + graph_inputs: OrderedSet[str], + graph_outputs: OrderedSet[str], +) -> tuple[int, dict[str, FreeableInputBuffer]]: + """ + Prepare planning info. 
As nodes are scheduled one at a time, these help + keep track of when a buffer can be freed, and when a node can be scheduled + + Returns: + int: peak memory estimation + dict[str, FreeableInputBuffer]: name to freeable input buffer + """ + name_to_freeable_input_buf = get_freeable_input_buf(nodes, graph_inputs) + assign_memory_planning_info_for_scheduler_buffers(nodes, name_to_buf) + assign_memory_planning_info_for_scheduler_nodes( + nodes, name_to_fused_node, name_to_buf, name_to_freeable_input_buf + ) + + # the default + estimated_peak_memory, _ = estimate_peak_memory( + nodes, name_to_freeable_input_buf, graph_outputs + ) + + return estimated_peak_memory, name_to_freeable_input_buf + + +def reorder_for_peak_memory( + nodes: list[BaseSchedulerNode], + name_to_buf: dict[str, SchedulerBuffer], + name_to_fused_node: dict[str, BaseSchedulerNode], + graph_inputs: OrderedSet[str], + graph_outputs: OrderedSet[str], + methods: list[Callable[..., list[BaseSchedulerNode]]] = [ # noqa: B006 + topological_sort_lpmf, + topological_sort_bfs, + topological_sort_dfs, + ], +) -> list[BaseSchedulerNode]: + """ + Try a few heuristics based topological sort algorithms, and pick the one whose + resulting topological order has the lowest peak memory estimation. 
+ """ + + torch_log.info("Reordering for peak memory -- %d nodes", len(nodes)) + + estimated_peak_memory, name_to_freeable_input_buf = prepare_planning_info( + nodes, + name_to_buf, + name_to_fused_node, + graph_inputs, + graph_outputs, + ) + + # export graph for simulator if needed + if config.reorder_for_peak_memory_debug: + export_graph_for_simulator( + nodes, + name_to_freeable_input_buf, + name_to_fused_node, + graph_inputs, + graph_outputs, + ) + + # Validate planning info before proceeding with reordering + try: + validate_graph_acyclic(nodes) + validate_unique_buffer_names(nodes, name_to_buf, name_to_freeable_input_buf) + except RuntimeError: + torch_log.exception("Memory planning validation failed") + if not is_fbcode(): # TODO: remove after ensuring OSS side is safe + raise + + # keep track of the peak memory estimates of different methods + peak_memory_diff_methods: list[PeakMemoryResult] = [] + peak_memory_diff_methods.append( + PeakMemoryResult(nodes, estimated_peak_memory, "baseline") + ) + torch_log.info("Baseline peak memory: %d", estimated_peak_memory) + + # other methods + for method in methods: + try: + if method is topological_sort_lpmf: + order = method( + nodes, name_to_freeable_input_buf, name_to_buf, graph_outputs + ) + else: + order = method(nodes) + assert len(order) == len(nodes) + peak_memory, _ = estimate_peak_memory( + order, name_to_freeable_input_buf, graph_outputs + ) + peak_memory_diff_methods.append( + PeakMemoryResult(order, peak_memory, method.__name__) + ) + torch_log.info("%s peak memory: %d", method.__name__, peak_memory) + except Exception: + torch_log.exception("Failed to reorder for %s", method.__name__) + if not is_fbcode(): # TODO: remove after ensuring OSS side is safe + raise + + signpost_event( + category="inductor", + name="memory", + parameters={ + "orm": {elem.method: elem.peak_memory for elem in peak_memory_diff_methods}, + }, + ) + + # get the optimal one + best_result = min(peak_memory_diff_methods, key=lambda 
x: x.peak_memory) + + return best_result.order + + +def export_graph_for_simulator( + nodes: list[BaseSchedulerNode], + name_to_freeable_input_buf: dict[str, FreeableInputBuffer], + name_to_fused_node: dict[str, BaseSchedulerNode], + graph_inputs: OrderedSet[str], + graph_outputs: OrderedSet[str], +) -> None: + """ + This is for debugging purposes. It will dump a json file that records graph information. + The graph can then be used in a simulator: https://fburl.com/code/3l3d3qi4 + """ + + class ORMBuffer(TypedDict): + name: str + size_alloc: int + size_free: int + size: int # for backward compatibility + is_input: bool + is_output: bool + deps: list[str] + unmet_deps: list[str] + + class ORMNode(TypedDict): + name: str + buffer_names: list[str] + + class ORMGraph(TypedDict): + nodes: list[ORMNode] + buffers: list[ORMBuffer] + + orm_buffers: list[ORMBuffer] = [] + orm_nodes: list[ORMNode] = [] + + # get orm buffers for freeable input buffers + for buf_name, input_buf in name_to_freeable_input_buf.items(): + orm_buf_input_buffer: ORMBuffer = { + "name": buf_name, + "size_alloc": input_buf.mpi_buffer.size_free, + "size_free": input_buf.mpi_buffer.size_free, + "size": input_buf.mpi_buffer.size_free, + "is_input": True, + "is_output": buf_name in graph_outputs, + "deps": [], + "unmet_deps": [], + } + orm_buffers.append(orm_buf_input_buffer) + + # get orm buffers for scheduler buffers + name_to_buf: dict[str, SchedulerBuffer] = { + buf.get_name(): buf for node in nodes for buf in node.get_outputs() + } # need to reassign due to probably node pruning + for buf_name, sched_buf in name_to_buf.items(): + if sched_buf.defining_op is None: + continue + deps = [ + pred_buf.get_name() + for pred_buf in name_to_fused_node[ + sched_buf.defining_op.get_name() + ].mpi_node.pred_buffers + ] + orm_buf_scheduler_buffer: ORMBuffer = { + "name": buf_name, + "size_alloc": sched_buf.mpi_buffer.size_alloc, + "size_free": sched_buf.mpi_buffer.size_free, + "size": 
sched_buf.mpi_buffer.size_free, + "is_input": False, + "is_output": buf_name in graph_outputs, + "deps": deps, + "unmet_deps": [ + buf_name for buf_name in deps if buf_name not in graph_inputs + ], + } + orm_buffers.append(orm_buf_scheduler_buffer) + + # get orm nodes + for node in nodes: + orm_node: ORMNode = { + "name": node.get_name(), + "buffer_names": list(node.get_buffer_names()), + } + orm_nodes.append(orm_node) + + # create the graph object + g: ORMGraph = { + "nodes": orm_nodes, + "buffers": orm_buffers, + } + + # dump the graph + import json + import os + + import torch + from functorch.compile import get_graph_being_compiled + + name = os.path.splitext(get_graph_being_compiled())[0] + "_fused" + + g_str = json.dumps(g, indent=2) + + torch._logging.trace_structured( + "artifact", + metadata_fn=lambda: { + "name": name, + "encoding": "string", + }, + payload_fn=lambda: g_str, + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/metrics.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..36f83dc4ba3f22cac11d69048373b3f64b8ee4a4 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/metrics.py @@ -0,0 +1,485 @@ +from __future__ import annotations + +import csv +import dataclasses +import inspect +import os +import re +from dataclasses import dataclass +from functools import lru_cache +from typing import Optional, TYPE_CHECKING, Union + +from torch._inductor import config +from torch._inductor.utils import get_benchmark_name +from torch.utils._ordered_set import OrderedSet + + +# Prevent circular import +if TYPE_CHECKING: + from collections.abc import Callable + + from torch._inductor.runtime.triton_compat import Config + from torch._inductor.scheduler import BaseSchedulerNode + +# counter for tracking how many kernels have been generated +generated_kernel_count = 0 
+generated_cpp_vec_kernel_count = 0 +num_bytes_accessed = 0 +nodes_num_elem: list[ + tuple[ + BaseSchedulerNode, + int, + ] +] = [] +node_runtimes: list[tuple[BaseSchedulerNode, float]] = [] + +# counters for tracking fusions +ir_nodes_pre_fusion = 0 + +# counters for tracking to_dtype inserted +cpp_to_dtype_count = 0 + + +@dataclasses.dataclass +class CppOuterLoopFusedCount: + inner_kernel_number: int + local_buffer_number: int = 0 + + +# The length counts the number of outer loop fusions. +cpp_outer_loop_fused_inner_counts: list[CppOuterLoopFusedCount] = [] + +num_comprehensive_padding = 0 +num_matches_for_scatter_upon_const_tensor = 0 + +num_loop_reordering = 0 + +# counter for parallel reduction. +parallel_reduction_count = 0 + +codegen_mix_order_reduction = 0 + + +# reset all counters +def reset() -> None: + global generated_kernel_count + global generated_cpp_vec_kernel_count + global num_bytes_accessed, nodes_num_elem + global ir_nodes_pre_fusion + global cpp_to_dtype_count + global cpp_outer_loop_fused_inner_counts + global num_comprehensive_padding + global num_matches_for_scatter_upon_const_tensor + global num_loop_reordering + global parallel_reduction_count + global codegen_mix_order_reduction + + generated_kernel_count = 0 + generated_cpp_vec_kernel_count = 0 + num_bytes_accessed = 0 + nodes_num_elem.clear() + node_runtimes.clear() + ir_nodes_pre_fusion = 0 + cpp_to_dtype_count = 0 + cpp_outer_loop_fused_inner_counts.clear() + num_comprehensive_padding = 0 + num_matches_for_scatter_upon_const_tensor = 0 + num_loop_reordering = 0 + parallel_reduction_count = 0 + codegen_mix_order_reduction = 0 + + +@dataclass +class CachedMetricsDeltas: + """ + The subset of metrics we want update across cache hits, e.g., the + FxGraphCache. 
+ """ + + generated_kernel_count: int + generated_cpp_vec_kernel_count: int + ir_nodes_pre_fusion: int + cpp_to_dtype_count: int + num_bytes_accessed: int + num_matches_for_scatter_upon_const_tensor: int + + +def get_metric_fields() -> list[str]: + return [field.name for field in dataclasses.fields(CachedMetricsDeltas)] + + +class CachedMetricsHelper: + """ + A helper class to help calculate and apply counter deltas for those + metrics we want to save with cache entries (e.g., FxGraphCache) and + apply on a cache hit. + """ + + def __init__(self) -> None: + self.cached_metrics = {} + for metric in get_metric_fields(): + self.cached_metrics[metric] = globals()[metric] + + def get_deltas(self) -> CachedMetricsDeltas: + delta_metrics = {} + for metric in get_metric_fields(): + delta_metrics[metric] = globals()[metric] - self.cached_metrics[metric] + + return CachedMetricsDeltas(**delta_metrics) + + @staticmethod + def apply_deltas(delta: CachedMetricsDeltas) -> None: + for metric in get_metric_fields(): + globals()[metric] += getattr(delta, metric) + + +REGISTERED_METRIC_TABLES: dict[str, MetricTable] = {} + + +@dataclass +class MetricTable: + table_name: str + column_names: list[str] + + num_rows_added: int = 0 + + def add_row( + self, row_fn: Callable[[], dict[str, Optional[Union[str, float]]]] + ) -> None: + if self.table_name not in enabled_metric_tables(): + return + + row_dict = row_fn() + assert len(self.column_names) == len(row_dict), ( + f"{len(self.column_names)} v.s. {len(row_dict)}" + ) + assert OrderedSet(self.column_names) == OrderedSet(row_dict.keys()), ( + f"{OrderedSet(self.column_names)} v.s. 
{OrderedSet(row_dict.keys())}" + ) + + bn = get_benchmark_name() + # assert bn is not None + row = [bn] + [row_dict[column_name] for column_name in self.column_names] + assert all(isinstance(i, (str, float, type(None))) for i in row) + self._write_row(row) + + def output_filename(self) -> str: + return f"metric_table_{self.table_name}.csv" + + def write_header(self) -> None: + filename = self.output_filename() + with open(filename, "w") as fd: + writer = csv.writer(fd, lineterminator="\n") + writer.writerow(["model_name"] + self.column_names) + + def _write_row(self, row: list[str | float | None]) -> None: + filename = self.output_filename() + if self.num_rows_added == 0 and not os.path.exists(filename): + self.write_header() + + self.num_rows_added += 1 + + for idx, orig_val in enumerate(row): + if isinstance(orig_val, float): + new_val = f"{orig_val:.6f}" + elif orig_val is None: + new_val = "" + else: + new_val = orig_val + row[idx] = new_val + + with open(filename, "a") as fd: + writer = csv.writer(fd, lineterminator="\n") + writer.writerow(row) + + @staticmethod + def register_table(name: str, column_names: list[str]) -> None: + table = MetricTable(name, column_names) + REGISTERED_METRIC_TABLES[name] = table + + +MetricTable.register_table( + "slow_fusion", + [ + "kernel1_path", + "kernel1_latency", + "kernel2_path", + "kernel2_latency", + "fused_kernel_path", + "fused_kernel_latency", + "slow_down_ratio", + ], +) + +# track the fusion statistics for each graph +MetricTable.register_table( + "graph_stats", + [ + "graph_id", + "num_nodes_before_fusion", + "num_nodes_after_fusion", + ], +) + +# track the perf difference between persistent reduction and non-persistent +# reductions +MetricTable.register_table( + "persistent_red_perf", + [ + "kernel0_path", + "kernel1_path", + "kernel2_path", + "kernel3_path", + "kernel0_latency", + "kernel1_latency", + "kernel2_latency", + "kernel3_latency", + "size_hints", + "reduction_hint", + ], +) + +# Log the fusion failures 
due to indexing mismatch +MetricTable.register_table( + "fusion_failure_due_to_indexing_mismatch", + [ + "pre_grad_graph_id", + "post_grad_graph_id", + "node1_name", + "node2_name", + "node1_debug_str", + "node2_debug_str", + "common_buffer_names", + "failure_reason", + ], +) + +# Log metadata for pointwise/reduction kernels. E.g., model name, kernel path, numel, rnumel, reduction hint +MetricTable.register_table( + "kernel_metadata", + [ + "kernel_name", + "kernel_path", + "kernel_category", # pointwise/reduction/foreach etc. + "size_hints", + "reduction_hint", + "line_of_code", + "num_load", + "num_store", + "num_for_loop", + "num_atomic_add", + "num_args", + # xyz numel can be different to size_hints since size_hints are rounded + # up to the nearest power of 2. + # Inductor kernel will burn in the xyz numel in kernel code for static + # shape kernels. + # Logging them will be helpful to find unaligned shape for reduction + "xnumel", + "ynumel", + "rnumel", + "kernel_args_num_gb", + ], +) + + +def _parse_kernel_fn_code(kernel_module_code: str) -> str: + """ + The kernel_module_code is the python module that contains kernel function code. + kernel function is the proper triton kernel function annotated with + @triton.jit + """ + from .codecache import PyCodeCache + from .wrapper_benchmark import get_triton_kernel + + mod = PyCodeCache.load(kernel_module_code) + kernel = get_triton_kernel(mod) + # kernel is a CachingAutotune; kernel.fn is the JITFunction; + # kernel.fn.fn is the function being decorate by triton.jit + return inspect.getsource(kernel.fn.fn) + + +def _parse_kernel_line_of_code(proper_kernel_fn_code: str) -> int: + """ + Return the line of code for the kernel excluding the decorators. 
+ """ + return len(proper_kernel_fn_code.splitlines()) + + +def _parse_size_hints(kernel_module_code: str, kernel_category: str) -> Optional[str]: + if kernel_category == "foreach": + # foreach kernel does not have size_hints + return None + m = re.search(r"size_hints=(\[[0-9, ]*\]),", kernel_module_code) + assert m, "size_hints missing!" + return m.group(1) + + +def _parse_reduction_hint( + kernel_category: str, kernel_module_code: str +) -> Optional[str]: + if kernel_category not in ("reduction", "persistent_reduction"): + return None + m = re.search(r"reduction_hint=ReductionHint\.(\w*),", kernel_module_code) + assert m, "reduction_hint not found in kernel source code!" + return m.group(1) + + +def _count_pattern(proper_kernel_fn_code: str, pattern: str) -> int: + return proper_kernel_fn_code.count(pattern) + + +def _count_args(proper_kernel_fn_code: str) -> int: + def_line = proper_kernel_fn_code.splitlines()[0] + assert def_line.startswith("def ") + start_idx = def_line.index("(") + end_idx = def_line.index("):") + decl_csv = def_line[start_idx + 1 : end_idx] + comps = decl_csv.split(",") + return len(comps) + + +def _parse_proper_kernel_fn_code(kernel_fn_code: str) -> str: + """ + Skip decorators. + """ + start_pos = kernel_fn_code.index("def ") + return kernel_fn_code[start_pos:] + + +def _parse_numel(proper_kernel_fn_code: str, numel_arg_name: str) -> Optional[int]: + m = re.search(f"{numel_arg_name} = ([\\d]+)", proper_kernel_fn_code) + if m: + return int(m.group(1)) + else: + return None + + +def _parse_kernel_args_num_gb( + kernel_fn_code: str, kernel_category: str +) -> Optional[float]: + """ + inductor meta looks like: + inductor_meta={... 'mutated_arg_names': [], 'no_x_dim': False, 'kernel_num_gb': 2.0}, + """ + m = re.search(r".kernel_num_gb.:\s*([0-9.]+)", kernel_fn_code) + if m: + return float(m.group(1)) + else: + """ + There are a few cases that kernel_num_gdb field can be missing: + 1. 
the field will be missing if config.benchmark_kernel and + config.profile_bandwidth are false + 2. even if config.benchmark_kernel or config.profile_bandwidth is true. + foreach kernel does not have kernel_num_gb field in the metadata + """ + return None + + +def log_kernel_metadata( + kernel_name: str, kernel_path: str, kernel_module_code: str +) -> None: + """ + An utility to log kernel metadata. We may parse metadata from kernel source code here. + + It's fine to parse the generated kernel code here since the logging is + disabled by default. It would hurt compilation time. + """ + from .wrapper_benchmark import get_kernel_category_by_source_code + + kernel_category = get_kernel_category_by_source_code(kernel_module_code) + reduction_hint = _parse_reduction_hint(kernel_category, kernel_module_code) + size_hints = _parse_size_hints(kernel_module_code, kernel_category) + kernel_fn_code = _parse_kernel_fn_code(kernel_module_code) + + proper_kernel_fn_code = _parse_proper_kernel_fn_code(kernel_fn_code) + + # the line of code excluding the decortors + kernel_line_of_code = _parse_kernel_line_of_code(proper_kernel_fn_code) + + get_metric_table("kernel_metadata").add_row( + lambda: { + "kernel_name": kernel_name, + "kernel_path": kernel_path, + "kernel_category": kernel_category, + "size_hints": size_hints, + "reduction_hint": reduction_hint, + "line_of_code": kernel_line_of_code, + "num_load": _count_pattern(proper_kernel_fn_code, "tl.load"), + "num_store": _count_pattern(proper_kernel_fn_code, "tl.store"), + "num_for_loop": _count_pattern(proper_kernel_fn_code, "for "), + "num_atomic_add": _count_pattern(proper_kernel_fn_code, "tl.atomic_add"), + "num_args": _count_args(proper_kernel_fn_code), + "xnumel": _parse_numel(proper_kernel_fn_code, "xnumel"), + "ynumel": _parse_numel(proper_kernel_fn_code, "ynumel"), + "rnumel": _parse_numel(proper_kernel_fn_code, "rnumel"), + "kernel_args_num_gb": _parse_kernel_args_num_gb( + kernel_fn_code, kernel_category + ), + } + ) + + 
+def purge_old_log_files() -> None: + """ + Purge the old log file at the beginning when the benchmark script runs. + Should do it in the parent process rather than the child processes running + each individual model. + """ + for name, table in REGISTERED_METRIC_TABLES.items(): + if name in enabled_metric_tables(): + filename = table.output_filename() + if os.path.exists(filename): + os.unlink(filename) + + table.write_header() + + +def enabled_metric_tables() -> OrderedSet[str]: + return enabled_metric_tables_impl(config.enabled_metric_tables) + + +@lru_cache +def enabled_metric_tables_impl(config_str: str) -> OrderedSet[str]: + enabled: OrderedSet[str] = OrderedSet() + for name in config_str.split(","): + name = name.strip() + if not name: + continue + assert name in REGISTERED_METRIC_TABLES, ( + f"Metric table name {name} is not registered" + ) + enabled.add(name) + return enabled + + +def is_metric_table_enabled(name: str) -> bool: + return name in enabled_metric_tables() + + +def get_metric_table(name: str) -> MetricTable: + assert name in REGISTERED_METRIC_TABLES, f"Metric table {name} is not defined" + return REGISTERED_METRIC_TABLES[name] + + +MetricTable.register_table( + "kernel_autotune", + [ + "kernel_path", + "kernel_name", + "triton_config", + "latency_ms", + ], +) + + +def log_kernel_autotune_result( + kernel_path: str, kernel_name: str, config: Config, latency: float +) -> None: + get_metric_table("kernel_autotune").add_row( + lambda: { + "kernel_path": kernel_path, + "kernel_name": kernel_name, + "triton_config": str(config), + "latency_ms": latency, + } + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/mkldnn_ir.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/mkldnn_ir.py new file mode 100644 index 0000000000000000000000000000000000000000..0040d77a00afd9d156bc5824bcdffd9b8c0b7c02 --- /dev/null +++ 
b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/mkldnn_ir.py @@ -0,0 +1,1362 @@ +# mypy: allow-untyped-defs +from collections.abc import Sequence +from typing import Any, Optional, Union + +import sympy + +import torch +from torch._prims_common import make_channels_last_strides_for, StrideType +from torch.utils._ordered_set import OrderedSet + +from .ir import ( + ExternKernelAlloc, + FixedLayout, + FlexibleLayout, + get_device_type, + ir_node_to_tensor, + IRNode, + is_contiguous_storage_and_layout, + Layout, + may_convert_to_optional, + MultiOutput, + MultiOutputLayout, + MutationOutput, + NoneLayout, + ShapeAsConstantBuffer, + TensorBox, +) +from .utils import convert_shape_to_inductor, pad_listlike, SUPPORTED_MKLDNN_DEVICES +from .virtualized import V + + +def _prepare_convolution_fusion_create( + cls, + x: "TensorBox", + weight: "TensorBox", + bias: "TensorBox", + padding: Sequence[int], + stride: Sequence[int], + dilation: Sequence[int], + groups: int, + transposed: bool = False, + output_padding: Optional[Sequence[int]] = None, + quantize_args: Optional[list["TensorBox"]] = None, + other: Optional["TensorBox"] = None, +): + """ + This function is a helper function to prepare inputs, layout and constant args + for convolution post-op fusion's create function, including deciding the output + layout (channels first or channels last), realizing inputs and make them etc. The + function only supports the CPU/XPU device since conv post-op fusion kernel is only + supported on CPU/XPU right now. 
+ """ + + # Port from aten/src/ATen/native/ConvUtils.h: _conv_input_size + def _conv_input_size( + output_size, weight_size, padding, output_padding, stride, dilation, groups + ): + assert len(output_size) == len(weight_size), "Expect input dim == weight dim" + dim = len(output_size) + assert dim > 2, "Expect input dim > 2" + + BATCH_DIM = 0 + WEIGHT_INPUT_CHANNELS_DIM = 1 + input_size = [] + input_size.append(output_size[BATCH_DIM]) + input_size.append(weight_size[WEIGHT_INPUT_CHANNELS_DIM] * groups) + for d in range(2, dim): + kernel = (weight_size[d] - 1) * dilation[d - 2] + 1 + input_size_d = ( + (output_size[d] - 1) * stride[d - 2] + - (padding[d - 2] * 2) + + kernel + + output_padding[d - 2] + ) + input_size.append(input_size_d) + return list(map(int, input_size)) + + # Port from aten/src/ATen/native/ConvUtils.h: _conv_output_size + def _conv_output_size(input_size, weight_size, padding, stride, dilation=None): + has_dilation = dilation is not None + dim = len(input_size) + output_size = [] + output_size.append(input_size[0]) + output_size.append(weight_size[0]) + for d in range(2, dim): + # pyrefly: ignore [unsupported-operation] + dilation_ = dilation[d - 2] if has_dilation else 1 + kernel = dilation_ * (weight_size[d] - 1) + 1 + output_size_d = (input_size[d] + (2 * padding[d - 2]) - kernel) // stride[ + d - 2 + ] + 1 + output_size.append(output_size_d) + return output_size + + # The size of prepacked_weight is the prepacked weight size of deconv: + # Groups > 1: [g*o, i/g, ...] + # Groups == 1: [o, i, ...] + # Returns original weight size in [i, o, ...] 
+ def _original_deconv_weight_size( + prepacked_weight, + groups, + ): + prepacked_weight_size = prepacked_weight.size() + dim = len(prepacked_weight_size) + assert dim > 2, "Expect weight dim > 2" + if groups > 1: + weight_size = [] + weight_size.append(prepacked_weight_size[1] * groups) + weight_size.append(prepacked_weight_size[0] / groups) + weight_size.extend(prepacked_weight_size[d] for d in range(2, dim)) + else: + weight_size = prepacked_weight.transpose(0, 1).size() + return weight_size + + x.realize() + weight.realize() + if bias is not None: + bias.realize() + with V.graph.fake_mode: + # TODO cleaned up the fake_tensor trace as Linear implementation + x_fake = ir_node_to_tensor(x, guard_shape=True) + weight_fake = ir_node_to_tensor(weight, guard_shape=True) + dims = len(x_fake.size()) - 2 + assert 0 < len(padding) <= dims + assert 0 < len(dilation) <= dims + assert 0 < len(stride) <= dims + padding = pad_listlike(padding, dims) + dilation = pad_listlike(dilation, dims) + stride = pad_listlike(stride, dims) + if output_padding is None: + output_padding = pad_listlike([0], dims) + else: + assert 0 < len(output_padding) <= dims + output_padding = pad_listlike(output_padding, dims) + assert isinstance(groups, (int, sympy.core.numbers.Integer)) + if transposed: + # When transposed, the size of the prepacked oneDNN weight is different + # from the PyTorch weight. We're not able to run aten conv with such + # size. 
We infer the output size from the input params here: + weight_size = _original_deconv_weight_size(weight_fake, groups) + input_size = x_fake.size() + output_size = _conv_input_size( + input_size, + weight_size, + padding, + output_padding, + stride, + dilation, + groups, + ) + else: + x_shape = list(x_fake.shape) + weight_shape = list(weight_fake.shape) + if len(x_shape) != len(weight_shape): + assert len(x_shape) == 3 and len(weight_shape) == 4 + weight_shape.pop(2) + output_size = _conv_output_size( + x_shape, + weight_shape, + padding, + stride, + dilation, + ) + + req_stride_order = [0] + list(reversed(range(1, len(stride) + 1))) + req_stride_order = [len(req_stride_order)] + req_stride_order + + x = cls.require_stride_order(x, req_stride_order) + + # We won't do weight prepack for Conv if dynamic_shapes or if is xpu. + # In static shape cases, since weight is prepacked, we'll always force output to be channels last in the Conv kernel. + # In dynamic shape cases, for input with channels = 1, like tensor of size (s0, 1, 28, 28) and stride (784, 784, 28, 1), + # x = cls.require_stride_order(x, req_stride_order) where req_stride_order is in the channels last order + # won't change the stride of this tensor since stride for dimensions of size 1 is ignored. While in Conv kernel, + # this tensor is considered as channels first and the output will be in contiguous format. + # To align the behavior of the Conv kernel, we set the output_stride in such case to be contiguous instead of channels last. + dynamic_shapes = not all(isinstance(i, int) for i in (output_size)) + if ( + dynamic_shapes or get_device_type(x) == "xpu" + ) and is_contiguous_storage_and_layout(x): + output_stride: StrideType = FlexibleLayout.contiguous_strides(output_size) + # Currently we don't support channel last for the situation that stride of input's batch dim is 0, + # eg. input_size = (1, 1280, 64, 64), but input_stride=(0, 1, 81920, 1280). + # So we use NCHW hear instead. 
+ # Different with cpu, cpu conv always use channels_last for convolution when weight is prepacked, + # but xpu does not do the prepack, so the problem exposed here is only for xpu. + # TODO support channels_last for such zero stride input. + elif get_device_type(x) == "xpu" and x.get_stride()[0] == 0: + output_stride = FlexibleLayout.contiguous_strides(output_size) + else: + output_stride = make_channels_last_strides_for(output_size) + + assert get_device_type(x) == get_device_type(weight) + assert get_device_type(x) in SUPPORTED_MKLDNN_DEVICES + inputs = [x] + + if quantize_args is not None: + x_scale, x_zero_point, w_scale, w_zero_point = quantize_args + x_scale.realize() + x_zero_point.realize() + w_scale.realize() + w_zero_point.realize() + inputs = inputs + [x_scale, x_zero_point] + [weight] + [w_scale, w_zero_point] + else: + inputs += [weight] + + if other is not None: + other = cls.require_stride_order(other, req_stride_order) + assert isinstance(other, TensorBox) + inputs += [other] + + kernel_layout = FixedLayout( + x.get_device_or_error(), + x.get_dtype(), + convert_shape_to_inductor(output_size), + convert_shape_to_inductor(output_stride), + ) + constant_args = [padding, stride, dilation, groups] + if transposed: + constant_args.insert(1, output_padding) + + if bias is not None: + inputs.append(bias) + else: + constant_args.insert(0, bias) + return inputs, constant_args, kernel_layout, req_stride_order, other + + +def _prepare_linear_fusion_create( + cls, + x: "TensorBox", + weight: "TensorBox", + bias: "TensorBox", + quantize_args: Optional[list["TensorBox"]] = None, + other: Optional["TensorBox"] = None, + binary_sum: bool = False, +): + """ + This function is a helper function to prepare inputs, layout and constant args + for linear post-op fusion's create function. The function only supports the CPU device + since linear post-op fusion kernel is only supported on CPU right now. 
+ """ + x.realize() + weight.realize() + if bias is not None: + bias.realize() + + *m, _ = x.get_size() + # The weight has been transposed during the qlinear weight prepack process. + # https://github.com/pytorch/pytorch/blob/4979f9c0d72490970e2019bb1d2284f83d93f76b/ + # aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp#L291 + _, oc = weight.get_size() + output_size = list(m) + [oc] + req_stride_order = list(reversed(range(len(x.get_size())))) + + x = cls.require_stride_order(x, req_stride_order) + assert get_device_type(x) == get_device_type(weight) + assert get_device_type(x) in SUPPORTED_MKLDNN_DEVICES + inputs = [x] + + if quantize_args is not None: + x_scale, x_zero_point, w_scale, w_zero_point = quantize_args + x_scale.realize() + x_zero_point.realize() + w_scale.realize() + w_zero_point.realize() + inputs = inputs + [x_scale, x_zero_point] + [weight] + [w_scale, w_zero_point] + else: + inputs += [weight] + + if other is not None: + if binary_sum: + other = cls.require_stride_order(other, req_stride_order) + inputs = inputs + [other] + + output_stride = FlexibleLayout.contiguous_strides(output_size) + kernel_layout = FixedLayout( + x.get_device(), + x.get_dtype(), + output_size, + output_stride, + ) + constant_args: list[Any] = [] + + if bias is not None: + inputs.append(bias) + else: + constant_args.insert(0, bias) + return inputs, constant_args, kernel_layout, req_stride_order, other + + +def _create_output_node(packed): + output_ir = MultiOutput( + packed.get_layout(), + packed, + [], + ) + packed.layout = MultiOutputLayout(device=packed.get_device()) + packed.outputs = [output_ir] + return output_ir + + +class ConvolutionUnary(ExternKernelAlloc): + def __init__( + self, + layout, + inputs, + constant_args=(), + ) -> None: + self.device_type = get_device_type(inputs[0]) + super().__init__( + layout, + inputs, + constant_args, + None, + op_overload=torch.ops.mkldnn._convolution_pointwise.default, + 
cpp_kernel_name=f"aoti_torch_{self.device_type}_mkldnn__convolution_pointwise", + ) + + def codegen(self, wrapper): + wrapper.include_extra_header( + f"torch/csrc/inductor/aoti_torch/c/shim_{self.device_type}.h" + ) + super().codegen(wrapper) + + @classmethod + def create( + cls, + x: "TensorBox", + weight: "TensorBox", + bias: "TensorBox", + padding_: list[int], + stride_: list[int], + dilation_: list[int], + groups: int, + attr, + scalars: Optional[list[Any]], + algorithm, + ): + ( + inputs, + constant_args, + kernel_layout, + _, + _, + ) = _prepare_convolution_fusion_create( + cls, x, weight, bias, padding_, stride_, dilation_, groups + ) + constant_args = constant_args + [ + attr, + may_convert_to_optional(scalars), + algorithm, + ] + packed = ConvolutionUnary( + layout=kernel_layout, + inputs=inputs, + constant_args=constant_args, + ) + return _create_output_node(packed) + + +class ConvolutionBinary(ExternKernelAlloc): + def __init__( + self, + layout, + inputs, + constant_args=(), + cpp_constant_args=(), + ) -> None: + self.device_type = get_device_type(inputs[0]) + super().__init__( + layout, + inputs, + constant_args, + None, + op_overload=torch.ops.mkldnn._convolution_pointwise.binary, + cpp_kernel_name=f"aoti_torch_{self.device_type}_mkldnn__convolution_pointwise_binary", + ) + self.cpp_constant_args = cpp_constant_args + + def codegen(self, wrapper): + wrapper.include_extra_header( + f"torch/csrc/inductor/aoti_torch/c/shim_{self.device_type}.h" + ) + super().codegen(wrapper) + + @classmethod + def create( + cls, + x: "TensorBox", + other: "TensorBox", + weight: "TensorBox", + bias: "TensorBox", + padding_: list[int], + stride_: list[int], + dilation_: list[int], + groups: int, + binary_attr: str, + binary_alpha: Optional[float], + unary_attr: Optional[str], + unary_scalars: Optional[list[Any]], + unary_algorithm: Optional[str], + ): + ( + inputs, + constant_args, + kernel_layout, + req_stride_order, + _, + ) = _prepare_convolution_fusion_create( + cls, 
x, weight, bias, padding_, stride_, dilation_, groups + ) + # pyrefly: ignore [bad-assignment] + other = cls.require_stride_order(other, req_stride_order) + inputs.insert(1, other) + constant_args = constant_args + [ + binary_attr, + binary_alpha, + unary_attr, + may_convert_to_optional(unary_scalars), + unary_algorithm, + ] + packed = ConvolutionBinary( + layout=kernel_layout, + inputs=inputs, + constant_args=constant_args, + ) + return _create_output_node(packed) + + +class ConvolutionBinaryInplace(ExternKernelAlloc): + def __init__( + self, + kernel_layout, + inputs, + constant_args=(), + ) -> None: + # Due to constrain of op.call, other (Tensor&) should be at input[0] + self.device_type = get_device_type(inputs[0]) + reordered_inputs = [inputs[1], inputs[0]] + inputs[2:] + + super().__init__( + kernel_layout, + reordered_inputs, + constant_args, + None, + op_overload=torch.ops.mkldnn._convolution_pointwise_.binary, + cpp_kernel_name=f"aoti_torch_{self.device_type}_mkldnn__convolution_pointwise_binary_", + ) + + self.mutation_outputs = [ + MutationOutput(NoneLayout(device=inputs[0].get_device()), inputs[0], self), + MutationOutput(NoneLayout(device=inputs[1].get_device()), inputs[1], self), + ] + + def codegen(self, wrapper): + wrapper.include_extra_header( + f"torch/csrc/inductor/aoti_torch/c/shim_{self.device_type}.h" + ) + super().codegen(wrapper) + + def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: + return OrderedSet() + + @classmethod + def create( + cls, + x: "TensorBox", + other: "TensorBox", + weight: "TensorBox", + bias: "TensorBox", + padding_: list[int], + stride_: list[int], + dilation_: list[int], + groups: int, + binary_attr: str, + binary_alpha: Optional[float], + unary_attr: Optional[str], + unary_scalars: Optional[list[Any]], + unary_algorithm: Optional[str], + ): + ( + inputs, + constant_args, + _, + req_stride_order, + _, + ) = _prepare_convolution_fusion_create( + cls, x, weight, bias, padding_, stride_, dilation_, groups + ) 
+ # pyrefly: ignore [bad-assignment] + other = cls.require_stride_order(other, req_stride_order) + inputs.insert(1, other) + constant_args = constant_args + [ + binary_attr, + binary_alpha, + unary_attr, + may_convert_to_optional(unary_scalars), + unary_algorithm, + ] + packed = ConvolutionBinaryInplace( + kernel_layout=NoneLayout(device=inputs[1].get_device()), # type: ignore[arg-type] + inputs=inputs, + constant_args=constant_args, + ) + # This op mutates in place which means that the result is not the + # target but rather the input that is being mutated + # init reorders the inputs, so inputs[1] becomes packed.inputs[0] + return packed.inputs[0] + + +class ConvolutionTransposeUnary(ExternKernelAlloc): + def __init__( + self, + layout, + inputs, + constant_args=(), + ) -> None: + self.device_type = get_device_type(inputs[0]) + super().__init__( + layout, + inputs, + constant_args, + None, + op_overload=torch.ops.mkldnn._convolution_transpose_pointwise.default, + cpp_kernel_name=f"aoti_torch_{self.device_type}_mkldnn__convolution_transpose_pointwise", + ) + + def codegen(self, wrapper): + wrapper.include_extra_header( + f"torch/csrc/inductor/aoti_torch/c/shim_{self.device_type}.h" + ) + super().codegen(wrapper) + + @classmethod + def create( + cls, + x: "TensorBox", + weight: "TensorBox", + bias: "TensorBox", + padding_: list[int], + output_padding_: list[int], + stride_: list[int], + dilation_: list[int], + groups_: int, + attr, + scalars: Optional[list[Any]], + algorithm, + ): + transposed = True + ( + inputs, + constant_args, + kernel_layout, + _, + _, + ) = _prepare_convolution_fusion_create( + cls, + x, + weight, + bias, + padding_, + stride_, + dilation_, + groups_, + transposed, + output_padding_, + ) + constant_args = constant_args + [ + attr, + may_convert_to_optional(scalars), + algorithm, + ] + packed = ConvolutionTransposeUnary( + layout=kernel_layout, + inputs=inputs, + constant_args=constant_args, + ) + return _create_output_node(packed) + + +class 
QConvPointWisePT2E(ExternKernelAlloc): + def __init__( + self, + layout, + inputs, + constant_args=(), + ) -> None: + """ + if bias is not None + - inputs = [x, w, b, weight_scale, weight_zp] + - const_args is: [stride, padding, dilation, groups, x_scale, x_zp, o_scale, o_zp, + fp32_output, unary_attr, unary_scalars, unary_algorithm] + else + - inputs = [x, w, weight_scale, weight_zp] + - const_args is: [bias, stride, padding, dilation, groups, x_scale, x_zp, o_scale, o_zp, + fp32_output, unary_attr, unary_scalars, unary_algorithm] + """ + self.device_type = get_device_type(inputs[0]) + self.has_bias = len(inputs) == 5 + super().__init__( + layout, + inputs, + constant_args, + None, + op_overload=torch.ops.onednn.qconv_pointwise.tensor, + cpp_kernel_name=f"aoti_torch_{self.device_type}__qconv_pointwise_tensor", + ) + + def codegen(self, wrapper): + wrapper.include_extra_header( + f"torch/csrc/inductor/aoti_torch/c/shim_{self.device_type}.h" + ) + super().codegen(wrapper) + if isinstance(self.layout, Layout): + self.codegen_size_asserts(wrapper) + + @classmethod + def create( + cls, + qx: "TensorBox", + x_scale: Union["ShapeAsConstantBuffer", "TensorBox"], + x_zero_point: Union["ShapeAsConstantBuffer", "TensorBox"], + qw: "TensorBox", # qw + w_scale: "TensorBox", + w_zero_point, + bias: "TensorBox", + stride: list[int], + padding: list[int], + dilation: list[int], + groups: int, + output_scale: float, + output_zero_point: int, + output_dtype, + attr, + scalars, + algorithm, + ): + transposed = False + output_padding = None + ( + inputs, + constant_args, + kernel_layout, + _, + _, + ) = _prepare_convolution_fusion_create( + cls, + qx, + qw, + bias, + padding, + stride, + dilation, + groups, + transposed, + output_padding, + [x_scale, x_zero_point, w_scale, w_zero_point], # type: ignore[list-item] + ) + # swap padding and stride to align with functional conv arg order + if bias is None: + constant_args[1], constant_args[2] = constant_args[2], constant_args[1] + else: 
+ constant_args[0], constant_args[1] = constant_args[1], constant_args[0] + + constant_args = constant_args + [ + output_scale, + output_zero_point, + output_dtype, + attr, + may_convert_to_optional(scalars), + algorithm, + ] + + assert output_dtype is not None + if output_dtype in [torch.float32, torch.bfloat16]: + # in _prepare_convolution_fusion_create, we use x.dtype (uint8) to create kernel_layout + # if we set output_dtype is not None, the output buf should be output_dtype instead of uint8. + kernel_layout.dtype = output_dtype + + return QConvPointWisePT2E( + layout=kernel_layout, + inputs=inputs, + constant_args=constant_args, + ) + + +class QConvPointWiseBinaryPT2E(ExternKernelAlloc): + def __init__( + self, + layout, + inputs, + constant_args=(), + ) -> None: + """ + Needs input/weight/output qparams + if bias is not None + - inputs = [x, x_scale, x_zp, w, w_scale, w_zp, accum, b] + - const_args = [stride, padding, dilation, groups, o_scale, o_zp, + output_dtype, accum_scale, accum_zp, binary_attr, alpha, unary_attr, unary_scalars, unary_algorithm] + else + - inputs = [x, x_scale, x_zp, w, w_scale, w_zp, accum] + - const_args [b, stride, padding, dilation, groups, o_scale, o_zp, + output_dtype, accum_scale, accum_zp, binary_attr, alpha, unary_attr, unary_scalars, unary_algorithm] + """ + self.device_type = get_device_type(inputs[0]) + self.has_bias = len(inputs) == 8 + self.idx_for_inplace_sum = 6 + super().__init__( + layout, + inputs, + constant_args, + None, + op_overload=torch.ops.onednn.qconv2d_pointwise.binary_tensor, + cpp_kernel_name=( + f"aoti_torch_{self.device_type}__qconv2d_pointwise_binary_tensor" + ), + ) + + def codegen(self, wrapper): + wrapper.include_extra_header( + f"torch/csrc/inductor/aoti_torch/c/shim_{self.device_type}.h" + ) + super().codegen(wrapper) + if isinstance(self.layout, Layout): + self.codegen_size_asserts(wrapper) + + def get_mutation_names(self) -> Sequence[str]: + return [self.input_name(self.idx_for_inplace_sum)] + + 
def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: + return OrderedSet() + + @classmethod + def create( + cls, + qx: "TensorBox", + x_scale: "TensorBox", + x_zero_point: "TensorBox", + qw: "TensorBox", # packed_weight + w_scale, + w_zero_point, + qaccum: "TensorBox", + bias: "TensorBox", + stride: list[int], + padding: list[int], + dilation: list[int], + groups: int, + output_scale: "TensorBox", + output_zero_point: "TensorBox", + output_dtype, + accum_scale, + accum_zero_point, + binary_attr, + alpha, + unary_attr, + unary_scalars, + unary_algorithm, + ): + transposed = False + output_padding = None + ( + inputs, + constant_args, + _kernel_layout, + req_stride_order, + qaccum, + ) = _prepare_convolution_fusion_create( + cls, + qx, + qw, + bias, + padding, + stride, + dilation, + groups, + transposed, + output_padding, + [x_scale, x_zero_point, w_scale, w_zero_point], + qaccum, + ) + + # swap padding and stride to align with functional conv arg order + if bias is None: + constant_args[1], constant_args[2] = constant_args[2], constant_args[1] + else: + constant_args[0], constant_args[1] = constant_args[1], constant_args[0] + + constant_args = constant_args + [ + output_scale, + output_zero_point, + output_dtype, + accum_scale, + accum_zero_point, + binary_attr, + alpha, + unary_attr, + may_convert_to_optional(unary_scalars), + unary_algorithm, + ] + + assert binary_attr == "sum", ( + "For now, only post op sum is supported in QConvPointWiseBinaryPT2E." + ) + + V.graph.mark_buffer_mutated(qaccum.get_name()) + packed = QConvPointWiseBinaryPT2E( + layout=NoneLayout(device=qaccum.get_device()), + inputs=inputs, + constant_args=constant_args, + ) + + # Return accum since it has been inplace changed. 
+ return packed.inputs[packed.idx_for_inplace_sum] + + +class MKLPackedLinear(ExternKernelAlloc): + def __init__( + self, + layout, + inputs, + constant_args=(), + ) -> None: + super().__init__( + layout, + inputs, + constant_args, + None, + op_overload=torch.ops.mkl._mkl_linear.default, + ) + + def codegen(self, wrapper): + wrapper.include_extra_header("torch/csrc/inductor/aoti_torch/c/shim_cpu.h") + super().codegen(wrapper) + + @classmethod + def create(cls, x, packed_w, orig_w, B, batch_size): + x = cls.require_stride1(cls.realize_input(x)) + orig_w = cls.require_stride1(cls.realize_input(orig_w)) + *m, _ = x.get_size() + oc, _ = orig_w.get_size() + output_size = list(m) + [oc] + output_stride = FlexibleLayout.contiguous_strides(output_size) + inputs = [x, packed_w, orig_w] + constant_args = [batch_size] + if B is not None: + inputs += [B] + else: + constant_args.insert(0, None) + + device = x.get_device() + assert device is not None + return MKLPackedLinear( + layout=FixedLayout(device, x.get_dtype(), output_size, output_stride), + inputs=inputs, + constant_args=constant_args, + ) + + +class LinearUnary(ExternKernelAlloc): + def __init__( + self, + layout, + inputs, + constant_args=(), + ) -> None: + self.device_type = get_device_type(inputs[0]) + super().__init__( + layout, + inputs, + constant_args, + None, + op_overload=torch.ops.mkldnn._linear_pointwise.default, + cpp_kernel_name=f"aoti_torch_{self.device_type}__linear_pointwise", + ) + + def codegen(self, wrapper): + wrapper.include_extra_header( + f"torch/csrc/inductor/aoti_torch/c/shim_{self.device_type}.h" + ) + super().codegen(wrapper) + + @classmethod + def create(cls, x, w, B, attr, scalars, algorithm): + x = cls.require_contiguous(cls.realize_input(x)) + w = cls.require_contiguous(cls.realize_input(w)) + + *m, _ic = x.get_size() + oc, _ic = w.get_size() + output_size = list(m) + [oc] + inputs = [x, w] + constant_args = [attr, scalars if scalars else [-1], algorithm] + if B is not None: + B = 
cls.require_contiguous(cls.realize_input(B)) + inputs.append(B) + else: + constant_args.insert(0, None) + + device = x.get_device() + assert device is not None + + packed = LinearUnary( + layout=FixedLayout( + device=device, + dtype=x.get_dtype(), + size=output_size, + ), + inputs=inputs, + constant_args=constant_args, + ) + return _create_output_node(packed) + + def apply_constraint(self): + pass + + +class LinearBinary(ExternKernelAlloc): + kernel = "torch.ops.mkldnn._linear_pointwise.binary" + + def __init__( + self, + layout, + inputs, + constant_args=(), + ) -> None: + self.device_type = get_device_type(inputs[0]) + super().__init__( + layout, + inputs, + constant_args, + None, + op_overload=torch.ops.mkldnn._linear_pointwise.binary, + cpp_kernel_name=f"aoti_torch_{self.device_type}__linear_pointwise_binary", + ) + + def codegen(self, wrapper): + wrapper.include_extra_header( + f"torch/csrc/inductor/aoti_torch/c/shim_{self.device_type}.h" + ) + super().codegen(wrapper) + + @classmethod + def create(cls, x, y, w, B, attr): + x = cls.require_contiguous(cls.realize_input(x)) + y = cls.require_contiguous(cls.realize_input(y)) + w = cls.require_contiguous(cls.realize_input(w)) + + *m, _ic = x.get_size() + oc, _ic = w.get_size() + output_size = list(m) + [oc] + inputs = [x, y, w] + constant_args = [attr] + if B is not None: + B = cls.require_contiguous(cls.realize_input(B)) + inputs.append(B) + else: + constant_args.insert(0, B) + + device = x.get_device() + assert device is not None + packed = LinearBinary( + layout=FixedLayout( + device=device, + dtype=x.get_dtype(), + size=output_size, + ), + inputs=inputs, + constant_args=constant_args, + ) + return _create_output_node(packed) + + def apply_constraint(self): + pass + + +class QLinearPointwisePT2E(ExternKernelAlloc): + def __init__( + self, + layout, + inputs, + constant_args=(), + has_bias=True, + ) -> None: + """ + if bias is not None + - inputs = [x, w, b, weight_scale, weight_zp] + - const_args is: [x_scale, 
x_zp, o_scale, o_zp, + fp32_output, unary_attr, unary_scalars, unary_algorithm] + else + - inputs = [x, w, weight_scale, weight_zp] + - const_args is: [bias, x_scale, x_zp, o_scale, o_zp, + fp32_output, unary_attr, unary_scalars, unary_algorithm] + """ + self.device_type = get_device_type(inputs[0]) + self.has_bias = has_bias + super().__init__( + layout, + inputs, + constant_args, + None, + op_overload=(torch.ops.onednn.qlinear_pointwise.tensor), + cpp_kernel_name=( + f"aoti_torch_{self.device_type}__qlinear_pointwise_tensor" + ), + ) + + def codegen(self, wrapper): + wrapper.include_extra_header( + f"torch/csrc/inductor/aoti_torch/c/shim_{self.device_type}.h" + ) + super().codegen(wrapper) + + if isinstance(self.layout, Layout): + self.codegen_size_asserts(wrapper) + + @classmethod + def create( + cls, + qx: "TensorBox", + x_scale: "TensorBox", + x_zero_point: "TensorBox", + qw: "TensorBox", # packed_weight + w_scale: "TensorBox", + w_zero_point: "TensorBox", + bias: "TensorBox", + output_scale: float, + output_zero_point: int, + output_dtype, + post_op_name, + post_op_args, + post_op_algorithm, + ): + (inputs, constant_args, kernel_layout, _, _) = _prepare_linear_fusion_create( + cls, + qx, + qw, + bias, + [x_scale, x_zero_point, w_scale, w_zero_point], + ) + + constant_args = constant_args + [ + output_scale, + output_zero_point, + output_dtype, + post_op_name, + may_convert_to_optional(post_op_args), + post_op_algorithm, + ] + + assert output_dtype is not None + if output_dtype in [torch.float32, torch.bfloat16]: + # in _prepare_linear_fusion_create, we use x.dtype (uint8) to create kernel_layout + # if we set fp32_output, the output buf should be dtype float32 instead of uint8. 
+ kernel_layout.dtype = output_dtype + + return QLinearPointwisePT2E( + layout=kernel_layout, + inputs=inputs, + constant_args=constant_args, + has_bias=(bias is not None), + ) + + +class QLinearPointwiseBinaryPT2E(ExternKernelAlloc): + def __init__( + self, + layout, + inputs, + constant_args=(), + has_bias=True, + ) -> None: + """ + if bias is not None + - inputs = [x, w, x_scale, x_zp, weight_scale, weight_zp, x2, bias] + - const_args is: [o_scale, o_zp, + fp32_output, binary_attr, alpha, unary_attr, unary_scalars, unary_algorithm] + else + - inputs = [x, w, x_scale, x_zp, weight_scale, weight_zp, x2] + - const_args is: [bias, o_scale, o_zp, + fp32_output, binary_attr, alpha, unary_attr, unary_scalars, unary_algorithm] + """ + self.device_type = get_device_type(inputs[0]) + self.has_bias = has_bias + self.idx_for_inplace_sum = 6 + super().__init__( + layout, + inputs, + constant_args, + None, + op_overload=(torch.ops.onednn.qlinear_pointwise.binary_tensor), + cpp_kernel_name=f"aoti_torch_{self.device_type}__qlinear_pointwise_binary_tensor", + ) + + def codegen(self, wrapper): + wrapper.include_extra_header( + f"torch/csrc/inductor/aoti_torch/c/shim_{self.device_type}.h" + ) + super().codegen(wrapper) + if isinstance(self.layout, Layout): + self.codegen_size_asserts(wrapper) + + def get_mutation_names(self) -> Sequence[str]: + binary_post_op = self.constant_args[-5] + if binary_post_op == "sum": + input = self.inputs[self.idx_for_inplace_sum] + assert isinstance(input, IRNode) + return [input.get_name()] + else: + return [] + + @classmethod + def create( + cls, + qx: "TensorBox", + x_scale: "TensorBox", + x_zero_point: "TensorBox", + qw: "TensorBox", # packed_weight + w_scale: "TensorBox", + w_zero_point: "TensorBox", + other: "TensorBox", + bias: "TensorBox", + output_scale: float, + output_zero_point: int, + output_dtype, + other_scale, + other_zp, + binary_post_op, + binary_alpha, + unary_post_op, + unary_post_op_args, + unary_post_op_algorithm, + ): + ( + 
inputs, + constant_args, + kernel_layout, + req_stride_order, + other, + ) = _prepare_linear_fusion_create( + cls, + qx, + qw, + bias, + [x_scale, x_zero_point, w_scale, w_zero_point], + other, + binary_post_op == "sum", + ) + + constant_args = constant_args + [ + output_scale, + output_zero_point, + output_dtype, + other_scale, + other_zp, + binary_post_op, + binary_alpha, + unary_post_op, + may_convert_to_optional(unary_post_op_args), + unary_post_op_algorithm, + ] + + if binary_post_op == "sum": + V.graph.mark_buffer_mutated(other.get_name()) + packed = QLinearPointwiseBinaryPT2E( + layout=NoneLayout(device=other.get_device()), + inputs=inputs, + constant_args=constant_args, + has_bias=(bias is not None), + ) + # Return other since it has been inplace changed. + return packed.inputs[packed.idx_for_inplace_sum] + + assert output_dtype is not None + if output_dtype in [torch.float32, torch.bfloat16]: + # in _prepare_linear_fusion_create, we use x.dtype (uint8) to create kernel_layout + # if we set fp32_output, the output buf should be dtype float32 instead of uint8. 
+ kernel_layout.dtype = output_dtype + + return QLinearPointwiseBinaryPT2E( + layout=kernel_layout, + inputs=inputs, + constant_args=constant_args, + has_bias=(bias is not None), + ) + + +class MkldnnRnnLayer(ExternKernelAlloc): + def __init__( + self, + layout, + inputs, + constant_args=(), + ) -> None: + super().__init__( + layout, + inputs, + constant_args, + None, + op_overload=torch.ops.aten.mkldnn_rnn_layer.default, + ) + + @classmethod + def create( + cls, + x: "TensorBox", + w0: "TensorBox", + w1: "TensorBox", + w2: "TensorBox", + w3: "TensorBox", + hx: "TensorBox", + cx: "TensorBox", + reverse: bool, + batch_sizes: list[int], + mode: int, + hidden_size: int, + num_layers: int, + has_biases: bool, + bidirectional: bool, + batch_first: bool, + train: bool, + ): + # pyrefly: ignore [bad-assignment] + x = cls.require_stride1(cls.realize_input(x)) + # If batch_first, x has been permuted in lstm before entering the mkldnn_rnn_layer. + # Make sure x is contiguous in batch_first case. + x.freeze_layout() + # pyrefly: ignore [bad-assignment] + w0 = cls.require_stride1(cls.realize_input(w0)) + # pyrefly: ignore [bad-assignment] + w1 = cls.require_stride1(cls.realize_input(w1)) + # pyrefly: ignore [bad-assignment] + w2 = cls.require_stride1(cls.realize_input(w2)) + # pyrefly: ignore [bad-assignment] + w3 = cls.require_stride1(cls.realize_input(w3)) + # pyrefly: ignore [bad-assignment] + hx = cls.require_stride1(cls.realize_input(hx)) + hx.freeze_layout() + # pyrefly: ignore [bad-assignment] + cx = cls.require_stride1(cls.realize_input(cx)) + cx.freeze_layout() + + input_size = x.get_size() + assert len(input_size) == 3, "Expect lstm input to be 3D" + # batch_first is handled in the lstm OP. 
When entering + # rnn_layer here, we'll always have batch_first = False + seq_length, mini_batch, input_size = input_size + output_shape = [seq_length, mini_batch, hidden_size] + + hy_shape = hx.get_size() + cy_shape = cx.get_size() + + inputs = [x, w0, w1, w2, w3, hx, cx] + constant_args = [ + reverse, + batch_sizes, + mode, + hidden_size, + num_layers, + has_biases, + bidirectional, + batch_first, + train, + ] + + device = x.get_device() + assert device is not None + packed = MkldnnRnnLayer( + MultiOutputLayout(device=device), + inputs=inputs, + constant_args=constant_args, + ) + + def get_strides_of_lstm_output(output_shape, batch_first): + assert len(output_shape) == 3, "Expect output_shape to be 3D" + return FlexibleLayout.contiguous_strides(output_shape) + + # C shim call requires all the outputs to be passed in, and thus the last + # dummy return value is added. + output_sizes = [output_shape, hy_shape, cy_shape, [1]] + output_strides = [ + get_strides_of_lstm_output(output_shape, batch_first), + FlexibleLayout.contiguous_strides(hy_shape), + FlexibleLayout.contiguous_strides(cy_shape), + [1], + ] + output_ir = [ + MultiOutput( + FixedLayout( + x.get_device(), # type: ignore[arg-type] + x.get_dtype(), + output_size, + output_stride, + ), + packed, + [(tuple, i)], + ) + for i, (output_size, output_stride) in enumerate( + zip(output_sizes, output_strides) + ) + ] + packed.outputs = output_ir + + return output_ir + + def codegen(self, wrapper): + wrapper.include_extra_header("torch/csrc/inductor/aoti_torch/c/shim_cpu.h") + return super().codegen(wrapper) + + +# Add this IR so that we can include shim_cpu.h for cpp_wrapper +class WeightInt4PackMatmul(ExternKernelAlloc): + def __init__( + self, + layout, + inputs, + constant_args=(), + ) -> None: + """ + inputs = [x, w, qGroupSize, qScalesAndZeros] + constant_args = () + """ + assert len(inputs) == 4 + assert len(constant_args) == 0 + super().__init__( + layout, + inputs, + constant_args, + None, + 
op_overload=(torch.ops.quantized.int4mm_packed_weight_cpu.default), + cpp_kernel_name=("aoti_torch_cpu__weight_int4pack_mm_cpu_tensor"), + ) + + def codegen(self, wrapper): + wrapper.include_extra_header("torch/csrc/inductor/aoti_torch/c/shim_cpu.h") + super().codegen(wrapper) + + if isinstance(self.layout, Layout): + self.codegen_size_asserts(wrapper) + + @classmethod + def create( + cls, + x: "TensorBox", + w: "TensorBox", + qGroupSize: "TensorBox", + qScalesAndZeros: "TensorBox", + ): + inputs = [x, w, qGroupSize, qScalesAndZeros] + *m, _ = x.get_size() + n, _ = w.get_size() + output_size = list(m) + [n] + output_stride = FlexibleLayout.contiguous_strides(output_size) + kernel_layout = FixedLayout( + x.get_device(), # type: ignore[arg-type] + x.get_dtype(), + output_size, + output_stride, + ) + return WeightInt4PackMatmul( + layout=kernel_layout, + inputs=inputs, + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/mkldnn_lowerings.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/mkldnn_lowerings.py new file mode 100644 index 0000000000000000000000000000000000000000..b171de34ae02d70c639129c4b4233e6eaff1cc68 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/mkldnn_lowerings.py @@ -0,0 +1,1404 @@ +# mypy: allow-untyped-defs +import functools +from typing import Optional, Union + +import torch +import torch.utils._pytree as pytree +from torch._inductor.kernel.mm_common import mm_args + +from . 
import config, ir +from .codegen.cpp_gemm_template import CppGemmTemplate +from .codegen.cpp_grouped_gemm_template import CppGroupedGemmTemplate +from .codegen.cpp_utils import create_epilogue_with_attr +from .ir import TensorBox +from .lowering import ( + add, + add_needs_realized_inputs, + aten, + permute, + register_lowering, + to_dtype, + view, +) +from .select_algorithm import ( + autotune_select_algorithm, + ChoiceCaller, + ExternKernelChoice, +) +from .utils import use_aten_gemm_kernels, use_cpp_gemm_template +from .virtualized import ops, OpsValue, V + + +def create_int8_compensation( + W_tensor: torch.Tensor, + packed_weight: ir.TensorBox, + x_scale: ir.TensorBox, + x_zp: ir.TensorBox, + w_scale: ir.TensorBox, +) -> tuple[ + bool, + Union[ir.TensorBox, ir.ShapeAsConstantBuffer], + Optional[Union[ir.TensorBox, ir.ShapeAsConstantBuffer]], +]: + x_w_scale: Optional[Union[ir.TensorBox, ir.ShapeAsConstantBuffer]] = None + use_int8_fast_compensation_path = all( + isinstance(item, ir.TensorBox) + and item.get_name() in V.graph.constants + and hasattr(item.data, "data") + and isinstance(item.data.data, ir.ConstantBuffer) + for item in [x_scale, x_zp, w_scale] + ) + if use_int8_fast_compensation_path: + x_w_scale_tensor = ( + V.graph.constants[x_scale.get_name()] + * V.graph.constants[w_scale.get_name()] + ) + x_w_scale = V.graph.add_tensor_constant( + x_w_scale_tensor, + name=packed_weight.get_name() + "_x_w_compens", + ) + weight_compens_tensor = torch.sum(W_tensor.to(torch.float), dim=0) + x_zp_tensor = V.graph.constants[x_zp.get_name()] + weight_compens_tensor = weight_compens_tensor * x_w_scale_tensor * x_zp_tensor + weight_compens = V.graph.add_tensor_constant( + weight_compens_tensor, + name=packed_weight.get_name() + "_BMatrixCompens", + ) + else: + weight_compens_tensor = torch.sum(W_tensor.to(torch.float), dim=0) + weight_compens = V.graph.add_tensor_constant( + weight_compens_tensor, + name=packed_weight.get_name() + "_BMatrixCompens", + ) + return ( # 
type: ignore[return-type] + use_int8_fast_compensation_path, + weight_compens, + x_w_scale, + ) + + +def codegen_int8_gemm_template_compensation( + use_int8_fast_compensation_path: bool, + input: OpsValue, + _weight_compo: OpsValue, + _x_scale: Optional[OpsValue], + _x_zp: Optional[OpsValue], + _w_scale: Optional[OpsValue], + _x_w_scale: Optional[OpsValue], +) -> OpsValue: + if use_int8_fast_compensation_path: + temp = ops.sub( + ops.mul( + input, + _x_w_scale, + ), + _weight_compo, + ) + else: + temp = ops.mul( + ops.mul( + input, + _x_scale, + ), + _w_scale, + ) + # NOTE: We will apply compensation even if the x_zp is 0 for int8 quantization. + # That's because when torch.compile is invoked for dynamic quantization, + # x might coincidentally have such values that x_zp might be zero despite + # asymmetric quantization. + # Besides, if x_zp is dummy for int8 x, or if x is statically quantized, + # we'd still perform that redundant compute to avoid making the code messy + # because we discovered that redundant computation of compensation did not + # lead to performance degradation with the input shapes tested. 
+ temp = ops.sub( + temp, + ops.mul( + ops.mul( + ops.mul( + _x_scale, + _w_scale, + ), + _x_zp, + ), + _weight_compo, + ), + ) + return temp + + +def grouped_gemm_lowering( + x: TensorBox, + w: list[TensorBox], + b: list[TensorBox], + attr=None, + scalars=None, + algorithm=None, + layout=None, +): + x_size = x.get_size() + if len(x_size) > 2: + # GEMM template needs 2D input, normalize input shape here + x = view(x, [-1, x_size[-1]]) + num_gemm = len(w) + + assert config.max_autotune or config.max_autotune_gemm + # pyrefly: ignore [bad-assignment] + b = [bias if bias is None else ir.ExternKernel.realize_input(bias) for bias in b] + + choices: list[ChoiceCaller] = [] + *_, layout, x, _ = mm_args(x, permute(w[0], [1, 0]), layout=layout) + + kwargs = { + "has_bias": [bias is not None for bias in b], + "trans_w": True, + "epilogue_creator": None, + "act_mapping": dict.fromkeys(range(num_gemm), x), + } + + input_nodes = [x, *w] + input_nodes.extend([bias for bias in b if bias is not None]) + + CppGroupedGemmTemplate.add_choices( + choices, + layout, + input_nodes, + **kwargs, # type: ignore[arg-type] + ) + + assert len(choices) != 0 + result = autotune_select_algorithm( + "grouped_gemm", + choices, + input_nodes, + layout, + ) + template_buf = result.data.data + return_bufs = [ + ir.MultiOutput(layout, template_buf, [(list, gemm_idx)]) + for gemm_idx in range(num_gemm) + ] + # pyrefly: ignore [bad-argument-type] + template_buf.layout = ir.MultiOutputLayout(device=input_nodes[0].get_device()) + template_buf.outputs = return_bufs + return_tensors = [ + ir.TensorBox.create(return_bufs[gemm_idx]) for gemm_idx in range(num_gemm) + ] + if len(x_size) > 2: + for gemm_idx in range(num_gemm): + return_tensors[gemm_idx] = view( + return_tensors[gemm_idx], # type: ignore[arg-type] + (*x_size[:-1], return_tensors[gemm_idx].get_size()[-1]), + ) + return return_tensors + + +grouped_gemm_lowering._inductor_lowering_function = True # type: ignore[attr-defined] + + +def 
register_onednn_fusion_ops(): + if torch._C._has_mkldnn: + from . import mkldnn_ir + + aten_mkldnn_linear_unary = ExternKernelChoice( + torch.ops.mkldnn._linear_pointwise, + "mkldnn::_linear_pointwise", + has_out_variant=False, + kernel_creator=mkldnn_ir.LinearUnary.create, + ) + aten_mkldnn_linear_binary = ExternKernelChoice( + torch.ops.mkldnn._linear_pointwise.binary, + "mkldnn::_linear_pointwise", + has_out_variant=False, + kernel_creator=mkldnn_ir.LinearBinary.create, + ) + aten_mkldnn_qlinear_unary = ExternKernelChoice( + torch.ops.onednn.qlinear_pointwise, + "onednn::qlinear_pointwise", + has_out_variant=False, + kernel_creator=mkldnn_ir.QLinearPointwisePT2E.create, + ) + aten_mkldnn_qlinear_binary = ExternKernelChoice( + torch.ops.onednn.qlinear_pointwise.binary, + "onednn::qlinear_pointwise", + has_out_variant=False, + kernel_creator=mkldnn_ir.QLinearPointwiseBinaryPT2E.create, + ) + cpu_needs_realized_inputs: list[ + Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket] + ] = [ + torch.ops.mkldnn._convolution_pointwise, + torch.ops.mkldnn._convolution_pointwise_, + torch.ops.mkldnn._convolution_transpose_pointwise, + torch.ops.mkldnn._linear_pointwise, + aten.mkldnn_rnn_layer.default, + torch.ops.onednn.qconv_pointwise, + ] + + @register_lowering(torch.ops.mkldnn._convolution_pointwise) + def convolution_unary( + x: TensorBox, + weight: TensorBox, + bias: TensorBox, + padding, + stride, + dilation, + groups, + attr, + scalars, + algorithm, + ): + return TensorBox.create( + mkldnn_ir.ConvolutionUnary.create( + x, + weight, + bias, + padding, + stride, + dilation, + groups, + attr, + scalars, + algorithm, + ) + ) + + @register_lowering(torch.ops.mkldnn._convolution_pointwise.binary) + def convolution_binary( + x: TensorBox, + other: TensorBox, + weight: TensorBox, + bias: TensorBox, + padding, + stride, + dilation, + groups, + binary_attr, + binary_alpha, + unary_attr, + unary_scalars, + unary_algorithm, + ): + return TensorBox.create( + 
mkldnn_ir.ConvolutionBinary.create( + x, + other, + weight, + bias, + padding, + stride, + dilation, + groups, + binary_attr, + binary_alpha, + unary_attr, + unary_scalars, + unary_algorithm, + ) + ) + + @register_lowering(torch.ops.mkldnn._convolution_pointwise_.binary) + def convolution_binary_inplace( + x: TensorBox, + other: TensorBox, + weight: TensorBox, + bias: TensorBox, + padding, + stride, + dilation, + groups, + binary_attr, + binary_alpha, + unary_attr, + unary_scalars, + unary_algorithm, + ): + return TensorBox.create( + mkldnn_ir.ConvolutionBinaryInplace.create( + x, + other, + weight, + bias, + padding, + stride, + dilation, + groups, + binary_attr, + binary_alpha, + unary_attr, + unary_scalars, + unary_algorithm, + ) + ) + + @register_lowering(torch.ops.mkldnn._linear_pointwise) + def linear_unary( + x: TensorBox, + w: TensorBox, + b: TensorBox, + attr, + scalars, + algorithm, + layout=None, + ): + x_size = x.get_size() + if len(x_size) > 2: + # GEMM template needs 2D input, normalize input shape here + x = view(x, [-1, x_size[-1]]) + if b is not None: + b = ir.ExternKernel.realize_input(b) # type: ignore[assignment] + choices: list[ChoiceCaller] = [] + if config.max_autotune or config.max_autotune_gemm: + transposed_w = permute(w, [1, 0]) + *_, layout, x, transposed_w = mm_args(x, transposed_w, layout=layout) + if use_cpp_gemm_template(layout, x, transposed_w): + + def epilogue_creator(buf): + return create_epilogue_with_attr( + buf, attr, scalars=scalars, algorithm=algorithm + ) + + kwargs = { + "has_bias": b is not None, + "trans_w": True, + "epilogue_creator": ( + None if attr == "none" else epilogue_creator + ), + } + if b is not None: + kwargs["input_indices"] = [2, 0, 1] # type: ignore[assignment] + CppGemmTemplate.add_choices( + choices, + layout, + [x, w] if b is None else [x, w, b], + **kwargs, # type: ignore[arg-type] + ) + if len(choices) == 0 or use_aten_gemm_kernels(): + kwargs = dict(attr=attr, scalars=scalars, algorithm=algorithm) + 
if b is None: + kwargs["B"] = None + choices.append( + aten_mkldnn_linear_unary.bind( + [x, w] if b is None else [x, w, b], + layout, + **kwargs, + ) + ) + assert w.get_name() in V.graph.constants + input_gen_fns = { + 1: lambda x: V.graph.constants[x.get_name()], + } + result = autotune_select_algorithm( + "linear_unary", + choices, + [x, w] if b is None else [x, w, b], + layout, + input_gen_fns=input_gen_fns, + ) + if len(x_size) > 2: + result = view(result, (*x_size[:-1], result.get_size()[-1])) + return result + + @register_lowering(torch.ops.mkldnn._linear_pointwise.binary) + def linear_binary( + x: TensorBox, y: TensorBox, w: TensorBox, b: TensorBox, attr, layout=None + ): + x_size = x.get_size() + if len(x_size) > 2: + # GEMM template needs 2D input, normalize input shape here + x = view(x, [-1, x_size[-1]]) + y_size = y.get_size() + if len(y_size) > 2: + y = view(y, [-1, y_size[-1]]) + if b is not None: + b = ir.ExternKernel.realize_input(b) # type: ignore[assignment] + choices: list[ChoiceCaller] = [] + if config.max_autotune or config.max_autotune_gemm: + transposed_w = permute(w, [1, 0]) + *_, layout, x, transposed_w, y = mm_args( + x, transposed_w, y, layout=layout + ) + if use_cpp_gemm_template(layout, x, transposed_w): + + def epilogue_creator(buf): + return create_epilogue_with_attr(buf, attr, other=y) + + kwargs = { + "has_bias": b is not None, + "trans_w": True, + "epilogue_creator": epilogue_creator, + } + + # pyrefly: ignore [unsupported-operation] + kwargs["input_indices"] = [0, 2, 1] if b is None else [3, 0, 2, 1] + CppGemmTemplate.add_choices( + choices, + layout, + [x, y, w] if b is None else [x, y, w, b], + **kwargs, # type: ignore[arg-type] + ) + if len(choices) == 0 or use_aten_gemm_kernels(): + kwargs = dict(attr=attr) + if b is None: + kwargs["B"] = None + choices.append( + aten_mkldnn_linear_binary.bind( + [x, y, w] if b is None else [x, y, w, b], + layout, + **kwargs, + ) + ) + assert w.get_name() in V.graph.constants + input_gen_fns = 
{ + 2: lambda x: V.graph.constants[x.get_name()], + } + result = autotune_select_algorithm( + "linear_binary", + choices, + [x, y, w] if b is None else [x, y, w, b], + layout, + input_gen_fns=input_gen_fns, + ) + if len(x_size) > 2: + result = view(result, (*x_size[:-1], result.get_size()[-1])) + return result + + @register_lowering(torch.ops.mkldnn._convolution_transpose_pointwise) + def convolution_transpose_unary( + x: TensorBox, + weight: TensorBox, + bias: TensorBox, + padding, + output_padding, + stride, + dilation, + groups, + attr, + scalars, + algorithm, + ): + return TensorBox.create( + mkldnn_ir.ConvolutionTransposeUnary.create( + x, + weight, + bias, + padding, + output_padding, + stride, + dilation, + groups, + attr, + scalars, + algorithm, + ) + ) + + @register_lowering(aten.mkldnn_rnn_layer.default) + def mkldnn_rnn_layer( + x: TensorBox, + w0: TensorBox, + w1: TensorBox, + w2: TensorBox, + w3: TensorBox, + hx: TensorBox, + cx: TensorBox, + reverse: bool, + batch_sizes: list[int], + mode: int, + hidden_size: int, + num_layers: int, + has_biases: bool, + bidirectional: bool, + batch_first: bool, + train: bool, + ): + return pytree.tree_map( + TensorBox.create, + mkldnn_ir.MkldnnRnnLayer.create( + x, + w0, + w1, + w2, + w3, + hx, + cx, + reverse, + batch_sizes, + mode, + hidden_size, + num_layers, + has_biases, + bidirectional, + batch_first, + train, + ), + ) + + @register_lowering(torch.ops.onednn.qconv_pointwise, type_promotion_kind=None) + def qconvolution_unary( + x: TensorBox, + x_scale, + x_zp, + packed_weight: TensorBox, + w_scale: TensorBox, + w_zp, + bias: TensorBox, + stride, + padding, + dilation, + groups, + o_inv_scale, + o_zero_point, + output_dtype, + attr, + scalars, + algorithm, + ): + if not isinstance(x_scale, ir.TensorBox): + assert type(x_scale) is float + x_scale = V.graph.add_tensor_constant( + torch.tensor(x_scale, dtype=torch.float32), name="x_scale" + ) + + if x_zp is None: + x_zp = V.graph.add_tensor_constant( + 
torch.tensor(0, dtype=torch.int32), name="x_zp" + ) + if not isinstance(x_zp, ir.TensorBox): + assert type(x_zp) is int + x_zp = V.graph.add_tensor_constant( + torch.tensor(x_zp, dtype=torch.int32), name="x_zp" + ) + + if w_zp is None: + w_zp = V.graph.add_tensor_constant( + torch.tensor(0, dtype=torch.int32), name="w_zp" + ) + + return TensorBox.create( + mkldnn_ir.QConvPointWisePT2E.create( + x, + x_scale, + x_zp, + packed_weight, + w_scale, + w_zp, + bias, + stride, + padding, + dilation, + groups, + o_inv_scale, + o_zero_point, + output_dtype, + attr, + scalars, + algorithm, + ) + ) + + @register_lowering( + torch.ops.onednn.qconv2d_pointwise.binary, type_promotion_kind=None + ) + @register_lowering( + torch.ops.onednn.qconv2d_pointwise.binary_tensor, type_promotion_kind=None + ) + def qconvolution_binary( + x: TensorBox, + x_scale, + x_zp, + packed_weight: TensorBox, + w_scale: TensorBox, + w_zp, + accum: TensorBox, + bias: TensorBox, + stride, + padding, + dilation, + groups, + o_inv_scale, + o_zero_point, + output_dtype, + accum_scale, + accum_zp, + binary_attr, + alpha, + unary_attr, + unary_scalars, + unary_algorithmm, + ): + if not isinstance(x_scale, ir.TensorBox): + assert type(x_scale) is float + x_scale = V.graph.add_tensor_constant( + torch.tensor(x_scale, dtype=torch.float32), name="x_scale" + ) + + if x_zp is None: + x_zp = V.graph.add_tensor_constant( + torch.tensor(0, dtype=torch.int32), name="x_zp" + ) + if not isinstance(x_zp, ir.TensorBox): + assert type(x_zp) is int + x_zp = V.graph.add_tensor_constant( + torch.tensor(x_zp, dtype=torch.int32), name="x_zp" + ) + + if w_zp is None: + w_zp = V.graph.add_tensor_constant( + torch.tensor(0, dtype=torch.int32), name="w_zp" + ) + + if ( + binary_attr == "sum" + and output_dtype in [torch.float32, torch.bfloat16] + and accum.get_dtype() in [torch.float32, torch.bfloat16] + and accum.get_dtype() != output_dtype + ): + # For int8-mixed-bf16 quantization and inplace add, + # there is case when accum 
dtype is float32 but output dtype is bfloat16. + # Since the accum will be inplaced changed with post op sum, + # we will do accum dtype conversion here. + accum = to_dtype(accum, output_dtype) + return TensorBox.create( + mkldnn_ir.QConvPointWiseBinaryPT2E.create( + x, + x_scale, # type: ignore[arg-type] + x_zp, # type: ignore[arg-type] + packed_weight, + w_scale, + w_zp, + accum, + bias, + stride, + padding, + dilation, + groups, + o_inv_scale, + o_zero_point, + output_dtype, + accum_scale, + accum_zp, + binary_attr, + alpha, + unary_attr, + unary_scalars, + unary_algorithmm, + ) + ) + + @register_lowering(torch.ops.onednn.qlinear_pointwise, type_promotion_kind=None) + def qlinear_unary( + x: TensorBox, + x_scale, + x_zp, + packed_weight: TensorBox, + w_scale: TensorBox, + w_zp: TensorBox, + bias: TensorBox, + o_scale, + o_zero_point, + output_dtype, + attr, + scalars, + algorithm, + layout=None, + ): + assert packed_weight.get_dtype() in [torch.int8, torch.float8_e4m3fn], ( + "Only int8 and e4m3fn weights are supported by oneDNN qlinear." + ) + x_size = x.get_size() + if len(x_size) > 2: + # GEMM template needs 2D input, normalize input shape here + x = view(x, [-1, x_size[-1]]) + if not isinstance(x_scale, ir.TensorBox): + assert type(x_scale) is float + x_scale = V.graph.add_tensor_constant( + torch.tensor(x_scale, dtype=torch.float32), name="x_scale" + ) + else: + x_scale.realize() + if all(dim == 1 for dim in x_scale.get_size()): + # Corner-case discovered with LLaMA series. + # If all outer dims of x_scale are 1, make it a 0D tensor. + # Otherwise, epilogue creator will run into indexing issues. + x_scale = view(x_scale, []) + assert len(x_scale.get_size()) in [0, 1], "x_scale must be 0D or 1D" + + if x_zp is None: + # If x_zp is None, x is int8 quantized per-tensor and its scale is not reshaped, + # then the codegened code would segfault if we don't create a tensor for x_zp. + # It's safe to do so since x is a symmetrically quantized int8 tensor. 
+ # Moreover, oneDNN qlinear API doesn't accept None value for zp + x_zp = V.graph.add_tensor_constant( + torch.tensor(0, dtype=torch.int32), name="x_zp" + ) + if not isinstance(x_zp, ir.TensorBox): + assert type(x_zp) is int + x_zp = V.graph.add_tensor_constant( + torch.tensor(x_zp, dtype=torch.int32), name="x_zp" + ) + else: + x_zp.realize() + + assert x_zp.get_numel() == 1, "x_zp is incompatible with oneDNN qlinear" + + # When channels less than 8, w_scale/w_zp is Pointwise instead of ConstantBuffer + # Refer to + # https://github.com/pytorch/pytorch/blob/f353d17755ed23b02924c962a86ff99a3405fe10/torch/_inductor/graph.py#L570-L577 # noqa: B950 + if w_zp is None: + # If w_zp is None, then it's a dummy tensor created to denote the + # absence of a zero point, and thus w is int8 symmetrically quantized. + # Moreover, oneDNN qlinear API doesn't accept None value for zp + # pyrefly: ignore [bad-assignment] + w_zp = V.graph.add_tensor_constant( + torch.tensor(0, dtype=torch.int32), name="w_zp" + ) + w_scale.realize() + w_zp.realize() + if w_zp.get_dtype() != torch.int32 and isinstance( + ir.InputsKernel.unwrap_storage_for_input(w_zp), + ir.ConstantBuffer, + ): + # W_zp might be a ConstantBuffer with int64, convert it to int32 + w_zp_tensor = V.graph.constants[w_zp.get_name()].to(torch.int32) + w_zp = V.graph.add_tensor_constant( # type: ignore[assignment] + torch.tensor(w_zp_tensor, dtype=torch.int32), name=w_zp.get_name() + ) + + bias_dtype = None if bias is None else bias.get_dtype() + choices: list[ChoiceCaller] = [] + + if config.max_autotune or config.max_autotune_gemm: + *_, layout, x, packed_weight = mm_args( + x, packed_weight, layout=layout, out_dtype=output_dtype + ) + + if ( + # GEMM template currently only supports symmetrically quantized weights + isinstance( + ir.InputsKernel.unwrap_storage_for_input(w_zp), + ir.ConstantBuffer, + ) + and torch.equal( + torch.zeros_like(V.graph.constants[w_zp.get_name()]), + V.graph.constants[w_zp.get_name()], + ) + ) and 
use_cpp_gemm_template(layout, x, packed_weight): + W_tensor = V.graph.constants[packed_weight.get_name()].to_dense() + + ( + use_int8_fast_compensation_path, + weight_compens, + x_w_scale, + ) = create_int8_compensation( + W_tensor, + packed_weight, + # pyrefly: ignore [bad-argument-type] + x_scale, + # pyrefly: ignore [bad-argument-type] + x_zp, + w_scale, + ) + + def epilogue_creator(input_buffer): + # Epilogue to convert from s32 to f32 for u8s8f32 + assert output_dtype in [ + torch.float32, + torch.bfloat16, + torch.uint8, + torch.int8, + ] + input_loader = input_buffer.make_loader() + weight_compens_loader = weight_compens.make_loader() + x_w_scale_loader = None + if use_int8_fast_compensation_path: + assert x_w_scale is not None + x_w_scale_loader = x_w_scale.make_loader() + x_scale_loader = x_scale.make_loader() + w_scale_loader = w_scale.make_loader() + x_zp_loader = x_zp.make_loader() + nonlocal bias + bias_loader = None + if bias is not None: + bias_loader = bias.make_loader() + + def inner_fn(index): + nonlocal bias + input = input_loader(index) + # MicroKernel Output is with int32 + # cvt to FP32 before doing compensation + input = ops.to_dtype(input, torch.float32) + weight_compens_index = (index[-1],) + + _x_scale = None + _x_zp = None + _w_scale = None + if not use_int8_fast_compensation_path: + _x_scale = x_scale_loader(()) + _x_zp = x_zp_loader(()) + _w_scale = w_scale_loader(weight_compens_index) + _weight_compo = weight_compens_loader(weight_compens_index) + _x_w_scale = None + if use_int8_fast_compensation_path: + assert x_w_scale_loader is not None + _x_w_scale = x_w_scale_loader(weight_compens_index) + # Step 1: Compute s8s8->s32 or u8s8->s32 GEMM & then apply compensation + temp = codegen_int8_gemm_template_compensation( + use_int8_fast_compensation_path, + input, + _weight_compo, + _x_scale, + _x_zp, + _w_scale, + _x_w_scale, + ) + # Step 2: add Bias if applicable + if bias is not None: + # pyrefly: ignore [not-callable] + _bias = 
bias_loader(weight_compens_index) + nonlocal bias_dtype + assert bias_dtype in [torch.float32, torch.bfloat16] + if bias_dtype == torch.bfloat16: + _bias = ops.to_dtype(_bias, torch.float32) + temp = ops.add(temp, _bias) + + return temp + + output_buf = ir.Pointwise( + device=input_buffer.get_device(), + dtype=torch.float32, # Hardcode to FP32 for u8s8f32 & s8s8f32 + inner_fn=inner_fn, + ranges=input_buffer.get_size(), + ) + + # Step 3: Doing the unary post op fusion + if attr != "none": + output_buf = create_epilogue_with_attr( + output_buf, attr, scalars=scalars, algorithm=algorithm + ) + + # Step 4: Cast output to Target Dtype + if output_dtype == torch.bfloat16: + output_cast_loader = output_buf.make_loader() + + def inner_fn_cast_output_to_bf16(index): + input = output_cast_loader(index) + return ops.to_dtype(input, output_dtype) + + output_buf = ir.Pointwise( + device=output_buf.get_device_or_error(), + dtype=output_dtype, + inner_fn=inner_fn_cast_output_to_bf16, + ranges=output_buf.get_size(), + ) + elif output_dtype in [torch.uint8, torch.int8]: + from .lowering import _create_constants + + requant_input_loader = output_buf.make_loader() + + def inner_fn_requant(index, scale, zero_point): + input = requant_input_loader(index) + inv_scale, zero_point = _create_constants( + 1.0 / scale, zero_point, dtype=torch.float32 + ) + val = ops.round(input * inv_scale) + zero_point + if output_dtype == torch.uint8: + qmin, qmax = _create_constants( + 0, 255, dtype=torch.float32 + ) + else: + qmin, qmax = _create_constants( + -128, 127, dtype=torch.float32 + ) + clamped = ops.minimum(ops.maximum(val, qmin), qmax) + return ops.to_dtype(clamped, output_dtype) + + output_buf = ir.Pointwise( + device=output_buf.get_device_or_error(), + dtype=output_dtype, + inner_fn=functools.partial( + inner_fn_requant, + scale=float(o_scale), + zero_point=int(o_zero_point), + ), + ranges=output_buf.get_size(), + ) + + return output_buf + + assert x.get_dtype() in [torch.uint8, torch.int8] 
+ CppGemmTemplate.add_choices( + choices, + layout, + [x, x_scale, x_zp, packed_weight, w_scale, w_zp] + if bias is None + else [x, x_scale, x_zp, packed_weight, w_scale, w_zp, bias], + has_bias=bias is not None, + epilogue_creator=epilogue_creator, + input_indices=[0, 3, 1, 2, 4, 5] + if bias is None + else [6, 0, 3, 1, 2, 4, 5], + ) + if len(choices) == 0 or use_aten_gemm_kernels(): + kwargs = dict( + output_scale=o_scale, + output_zero_point=o_zero_point, + output_dtype=output_dtype, + post_op_name=attr, + post_op_args=scalars, + post_op_algorithm=algorithm, + ) + if bias is None: + kwargs["bias"] = None + choices.append( + aten_mkldnn_qlinear_unary.bind( + (x, x_scale, x_zp, packed_weight, w_scale, w_zp) + if bias is None + else (x, x_scale, x_zp, packed_weight, w_scale, w_zp, bias), + layout, + **kwargs, + ) + ) + assert packed_weight.get_name() in V.graph.constants + input_gen_fns = { + 3: lambda x: V.graph.constants[x.get_name()], # packed weight + 4: lambda x: V.graph.constants[x.get_name()], # weight scale + 5: lambda x: V.graph.constants[x.get_name()], # weight zp + 6: lambda x: V.graph.constants[x.get_name()], # bias + } + if isinstance( + ir.InputsKernel.unwrap_storage_for_input(x_scale), + ir.ConstantBuffer, + ): + # x is statically quantized + input_gen_fns[1] = lambda x: V.graph.constants[x.get_name()] + if isinstance( + ir.InputsKernel.unwrap_storage_for_input(x_zp), + ir.ConstantBuffer, + ): + input_gen_fns[2] = lambda x: V.graph.constants[x.get_name()] + + result = autotune_select_algorithm( + "qlinear_unary", + choices, + [x, x_scale, x_zp, packed_weight, w_scale, w_zp] + if bias is None + else [x, x_scale, x_zp, packed_weight, w_scale, w_zp, bias], + layout, + input_gen_fns=input_gen_fns, + ) + if len(x_size) > 2: + result = view(result, (*x_size[:-1], result.get_size()[-1])) + return result + + @register_lowering( + torch.ops.onednn.qlinear_pointwise.binary, type_promotion_kind=None + ) + @register_lowering( + 
torch.ops.onednn.qlinear_pointwise.binary_tensor, type_promotion_kind=None + ) + def qlinear_binary( + x: TensorBox, + x_scale, + x_zp, + packed_weight: TensorBox, + w_scale: TensorBox, + w_zp: TensorBox, + x2: TensorBox, + bias: TensorBox, + o_scale, + o_zero_point, + output_dtype, + x2_scale, + x2_zp, + binary_attr, + alpha, + unary_attr, + unary_scalars, + unary_algorithmm, + layout=None, + ): + x_size = x.get_size() + x2_size = x2.get_size() + assert len(x_size) == len(x2_size) + if len(x_size) > 2 and binary_attr in ["add", "sum"]: + # GEMM template needs 2D input, normalize input shape here + x = view(x, [-1, x_size[-1]]) + x2 = view(x2, [-1, x2_size[-1]]) + if not isinstance(x_scale, ir.TensorBox): + assert type(x_scale) is float + x_scale = V.graph.add_tensor_constant( + torch.tensor(x_scale, dtype=torch.float32), name="x_scale" + ) + else: + x_scale.realize() + if all(dim == 1 for dim in x_scale.get_size()): + # Corner-case discovered with LLaMA series. + # If all outer dims of x_scale are 1, make it a 0D tensor. + # Otherwise, epilogue creator will run into indexing issues. 
+ x_scale = view(x_scale, []) + assert len(x_scale.get_size()) in [0, 1], "x_scale must be 0D or 1D" + + if x_zp is None: + x_zp = V.graph.add_tensor_constant( + torch.tensor(0, dtype=torch.int32), name="x_zp" + ) + + if w_zp is None: + # pyrefly: ignore [bad-assignment] + w_zp = V.graph.add_tensor_constant( + torch.tensor(0, dtype=torch.int32), name="w_zp" + ) + + if not isinstance(x_zp, ir.TensorBox): + assert type(x_zp) is int + x_zp = V.graph.add_tensor_constant( + torch.tensor(x_zp, dtype=torch.int32), name="x_zp" + ) + else: + x_zp.realize() + + # When channels less than 8, w_scale/w_zp is Pointwise instead of ConstantBuffer + # Refer to + # https://github.com/pytorch/pytorch/blob/f353d17755ed23b02924c962a86ff99a3405fe10/torch/_inductor/graph.py#L570-L577 # noqa: B950 + w_scale.realize() + w_zp.realize() + if w_zp.get_dtype() != torch.int32 and isinstance( + ir.InputsKernel.unwrap_storage_for_input(w_zp), + ir.ConstantBuffer, + ): + w_zp_tensor = V.graph.constants[w_zp.get_name()].to(torch.int32) + w_zp = V.graph.add_tensor_constant( # type: ignore[assignment] + torch.tensor(w_zp_tensor, dtype=torch.int32), name=w_zp.get_name() + ) + if binary_attr == "sum": + if output_dtype in [ + torch.float32, + torch.bfloat16, + ] and x2.get_dtype() in [torch.float32, torch.bfloat16]: + if x2.get_dtype() != output_dtype: + # For int8-mixed-bf16 quantization and inplace add, + # there is case when accum dtype is float32 but output dtype is bfloat16. + # Since the accum will be inplaced changed with post op sum, + # we will do accum dtype conversion here. 
+ x2 = to_dtype(x2, output_dtype) + else: + assert x2.get_dtype() == output_dtype, ( + "dtype of accum for qlinear post op sum should be the same as output" + ) + x2_dtype = x2.get_dtype() + bias_dtype = bias.get_dtype() if bias is not None else None + choices: list[ChoiceCaller] = [] + if (config.max_autotune or config.max_autotune_gemm) and binary_attr in [ + "add", + "sum", + ]: + *_, layout, x, packed_weight, x2 = mm_args( + x, packed_weight, x2, layout=layout, out_dtype=output_dtype + ) + if ( + isinstance( + ir.InputsKernel.unwrap_storage_for_input(x_zp), + ir.ConstantBuffer, + ) + and len(x_zp.get_layout().size) == 0 # Per tensor quant of act + and isinstance( + ir.InputsKernel.unwrap_storage_for_input(w_zp), + ir.ConstantBuffer, + ) + and torch.equal( + torch.zeros_like(V.graph.constants[w_zp.get_name()]), + V.graph.constants[w_zp.get_name()], + ) # We only compensate MatrixB and assume B_zp is 0 to avoid the compensation of MatrixA + and use_cpp_gemm_template(layout, x, packed_weight) + ): + W_tensor = V.graph.constants[packed_weight.get_name()] + W_tensor = W_tensor.to_dense() + ( + use_int8_fast_compensation_path, + weight_compens, + x_w_scale, + ) = create_int8_compensation( + W_tensor, + packed_weight, + # pyrefly: ignore [bad-argument-type] + x_scale, + # pyrefly: ignore [bad-argument-type] + x_zp, + w_scale, + ) + + def epilogue_creator(input_buffer): + # Epilogue to convert from s32 to f32 for u8s8f32 + assert output_dtype in [ + torch.float32, + torch.bfloat16, + torch.uint8, + torch.int8, + ] + + input_loader = input_buffer.make_loader() + x2_loader = x2.make_loader() + weight_compens_loader = weight_compens.make_loader() + x_w_scale_loader = None + if use_int8_fast_compensation_path: + assert x_w_scale is not None + x_w_scale_loader = x_w_scale.make_loader() + x_scale_loader = x_scale.make_loader() + w_scale_loader = w_scale.make_loader() + x_zp_loader = x_zp.make_loader() + nonlocal bias + bias_loader = None + if bias is not None: + bias_loader 
= bias.make_loader() + + def inner_fn(index): + nonlocal bias + input = input_loader(index) + _x2 = x2_loader(index) + _x_scale = None + _x_zp = None + _w_scale = None + weight_compens_index = (index[-1],) + if not use_int8_fast_compensation_path: + _x_scale = x_scale_loader(()) + _x_zp = x_zp_loader(()) + _w_scale = w_scale_loader(weight_compens_index) + # MicroKernel Output is with int32: cvt to FP32 before doing compensation + input = ops.to_dtype(input, torch.float32) + _weight_compo = weight_compens_loader(weight_compens_index) + _x_w_scale = None + if use_int8_fast_compensation_path: + assert x_w_scale_loader is not None + _x_w_scale = x_w_scale_loader(weight_compens_index) + # Step 1: Doing compensation to cvt fp32 + temp = codegen_int8_gemm_template_compensation( + use_int8_fast_compensation_path, + input, + _weight_compo, + _x_scale, + _x_zp, + _w_scale, + _x_w_scale, + ) + # Step 2: add Bias if applicable + if bias is not None: + # pyrefly: ignore [not-callable] + _bias = bias_loader(weight_compens_index) + nonlocal bias_dtype + assert bias_dtype in [torch.float32, torch.bfloat16] + if bias_dtype == torch.bfloat16: + _bias = ops.to_dtype(_bias, torch.float32) + temp = ops.add(temp, _bias) + + # Step 3: Binary add + nonlocal x2_dtype + assert x2_dtype in [torch.float32, torch.bfloat16] + if x2_dtype == torch.bfloat16: + _x2 = ops.to_dtype(_x2, torch.float32) + temp = ops.add(temp, _x2) + + return temp + + output_buf = ir.Pointwise( + device=input_buffer.get_device(), + dtype=torch.float32, # Hardcode to FP32 for u8s8f32 + inner_fn=inner_fn, + ranges=input_buffer.get_size(), + ) + + # Step 4: Unary post op if has + if unary_attr != "none": + output_buf = create_epilogue_with_attr( + output_buf, + unary_attr, + scalars=unary_scalars, + algorithm=unary_algorithmm, + ) + + # Step 5: Cast output to Target Dtype + if output_dtype == torch.bfloat16: + output_cast_loader = output_buf.make_loader() + + def inner_fn_cast_output_to_bf16(index): + input = 
output_cast_loader(index) + return ops.to_dtype(input, output_dtype) + + output_buf = ir.Pointwise( + device=output_buf.get_device_or_error(), + dtype=output_dtype, + inner_fn=inner_fn_cast_output_to_bf16, + ranges=output_buf.get_size(), + ) + elif output_dtype in [torch.uint8, torch.int8]: + from .lowering import _create_constants + + requant_input_loader = output_buf.make_loader() + + def inner_fn_requant(index, scale, zero_point): + input = requant_input_loader(index) + inv_scale, zero_point = _create_constants( + 1.0 / scale, zero_point, dtype=torch.float32 + ) + val = ops.round(input * inv_scale) + zero_point + if output_dtype == torch.uint8: + qmin, qmax = _create_constants( + 0, 255, dtype=torch.float32 + ) + else: + qmin, qmax = _create_constants( + -128, 127, dtype=torch.float32 + ) + clamped = ops.minimum(ops.maximum(val, qmin), qmax) + return ops.to_dtype(clamped, torch.uint8) + + output_buf = ir.Pointwise( + device=output_buf.get_device_or_error(), + dtype=torch.uint8, + inner_fn=functools.partial( + inner_fn_requant, + scale=float(o_scale), + zero_point=int(o_zero_point), + ), + ranges=output_buf.get_size(), + ) + + return output_buf + + CppGemmTemplate.add_choices( + choices, + layout, + [x, x_scale, x_zp, packed_weight, w_scale, w_zp, x2] + if bias is None + else [x, x_scale, x_zp, packed_weight, w_scale, w_zp, x2, bias], + has_bias=bias is not None, + epilogue_creator=epilogue_creator, + # Reorder bias and x2 + input_indices=[0, 3, 1, 2, 4, 5, 6] + if bias is None + else [7, 0, 3, 1, 2, 4, 5, 6], + ) + + if len(choices) == 0 or use_aten_gemm_kernels(): + kwargs = dict( + output_scale=o_scale, + output_zero_point=o_zero_point, + output_dtype=output_dtype, + other_scale=x2_scale, + other_zp=x2_zp, + binary_post_op=binary_attr, + binary_alpha=alpha, + unary_post_op=unary_attr, + unary_post_op_args=unary_scalars, + unary_post_op_algorithm=unary_algorithmm, + ) + if bias is None: + kwargs["bias"] = None + choices.append( + 
aten_mkldnn_qlinear_binary.bind( + (x, x_scale, x_zp, packed_weight, w_scale, w_zp, x2) + if bias is None + else (x, x_scale, x_zp, packed_weight, w_scale, w_zp, x2, bias), + layout, + **kwargs, + ) + ) + assert packed_weight.get_name() in V.graph.constants + input_gen_fns = { + 3: lambda x: V.graph.constants[x.get_name()], + 4: lambda x: V.graph.constants[x.get_name()], + 5: lambda x: V.graph.constants[x.get_name()], + } + if bias is not None: + input_gen_fns[7] = lambda x: V.graph.constants[x.get_name()] # For bias + result = autotune_select_algorithm( + "qlinear_binary", + choices, + [x, x_scale, x_zp, packed_weight, w_scale, w_zp, x2] + if bias is None + else [x, x_scale, x_zp, packed_weight, w_scale, w_zp, x2, bias], + layout, + input_gen_fns=input_gen_fns, + ) + if ( + isinstance(result.data.data, ir.CppTemplateBuffer) + and binary_attr == "sum" + and result.data.data.layout == x2.get_layout() + ): + # In this case, since x2 is inplace updated when binary_attr is "sum" + # we update the layout of result to view of x2 + result = ir.TensorBox.create( + ir.CppTemplateBuffer( + layout=ir.NonOwningLayout( + ir.ReinterpretView(data=x2, layout=x2.get_layout()) + ), + inputs=result.data.data.inputs, # type: ignore[arg-type] + make_kernel_render=result.data.data.make_kernel_render, # type: ignore[arg-type] + template=result.data.data.template, + choice=result.data.data.choice, + ) + ) + if len(x_size) > 2 and binary_attr in ["add", "sum"]: + result = view(result, (*x_size[:-1], result.get_size()[-1])) # type: ignore[arg-type] + return result + + if torch._C.has_mkl: + aten_mkl_linear = ExternKernelChoice( + torch.ops.mkl._mkl_linear, + "mkl::_mkl_linear", + has_out_variant=False, + kernel_creator=mkldnn_ir.MKLPackedLinear.create, + ) + cpu_needs_realized_inputs.append(torch.ops.mkl._mkl_linear) + + @register_lowering(torch.ops.mkl._mkl_linear) + def mkl_packed_linear( + x: TensorBox, + packed_w: TensorBox, + orig_w: TensorBox, + b: Optional[TensorBox], + batch_size, + 
*, + layout=None, + ): + choices: list[ChoiceCaller] = [] + if config.max_autotune or config.max_autotune_gemm: + transposed_w = permute(orig_w, [1, 0]) + *_, layout, x, transposed_w = mm_args( + x, transposed_w, layout=layout + ) + if use_cpp_gemm_template(layout, x, transposed_w): + CppGemmTemplate.add_choices( + choices, + layout, + [x, packed_w, orig_w], + trans_w=True, + input_indices=[0, 2], + ) + + if len(choices) == 0 or use_aten_gemm_kernels(): + choices.append( + aten_mkl_linear.bind( + (x, packed_w, orig_w), layout, B=None, batch_size=batch_size + ) + ) + + assert packed_w.get_name() in V.graph.constants + assert orig_w.get_name() in V.graph.constants + # packed_w is a mkldnn tensor which we can't generate directly + # so we use the weights from the original tensor in autotune. + input_gen_fns = { + 1: lambda x: V.graph.constants[x.get_name()], + 2: lambda x: V.graph.constants[x.get_name()], + } + result: TensorBox = autotune_select_algorithm( + "packed_linear", + choices, + [x, packed_w, orig_w], + layout, + input_gen_fns=input_gen_fns, + ) + if b is not None: + result = add(result, b) + return result + + add_needs_realized_inputs(cpu_needs_realized_inputs) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/mock_cache.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/mock_cache.py new file mode 100644 index 0000000000000000000000000000000000000000..3d9c58f1db8bd9b4ed523a22d4ea6dfc7208a756 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/mock_cache.py @@ -0,0 +1,274 @@ +# mypy: ignore-errors + +from __future__ import annotations + +import contextlib +import dataclasses +import sys +import threading +from typing import Any, Optional, TYPE_CHECKING +from typing_extensions import override, Self +from unittest.mock import patch + +from torch._inductor import config +from torch._inductor.remote_cache import RemoteCacheBackend + + +if TYPE_CHECKING: 
+ from collections.abc import Callable + from types import TracebackType + + +@dataclasses.dataclass +class Stats: + num_put: int = 0 + num_get_hit: int = 0 + num_get_miss: int = 0 + + def __iadd__(self, other: Stats) -> Self: + self.num_put += other.num_put + self.num_get_hit += other.num_get_hit + self.num_get_miss += other.num_get_miss + return self + + def reset(self) -> None: + self.num_put = 0 + self.num_get_hit = 0 + self.num_get_miss = 0 + + def __str__(self) -> str: + return "".join( + ( + f"puts: {self.num_put}, ", + f"misses: {self.num_get_miss}, ", + f"hits: {self.num_get_hit}, ", + ) + ) + + def __eq__(self, other: object) -> bool: + # Dataclass's default __eq__ checks that the types are the same so can't + # be used with _GlobalItemStats. + return ( + isinstance(other, (Stats, _GlobalItemStats)) + and self.num_put == other.num_put + and self.num_get_hit == other.num_get_hit + and self.num_get_miss == other.num_get_miss + ) + + +class _GlobalItemStats(Stats): + cache: dict[str, object] + + def __init__(self) -> None: + super().__init__() + self.cache = {} + + def reset(self) -> None: + super().reset() + self.cache = {} + + +# The cache states are thread-local so if we're running multiple tests at once +# they won't cross contaminate. However - it needs to be "global" because we +# allow code to create new cache clients which refer to the same cache (because +# it's a remote cache). 
+ + +class _GlobalStats(threading.local): + def __init__(self) -> None: + self.autotune_local = _GlobalItemStats() + self.autotune_remote = _GlobalItemStats() + self.bundled_autotune = _GlobalItemStats() + self.fx_graph = _GlobalItemStats() + self.triton = _GlobalItemStats() + self.aot_autograd = _GlobalItemStats() + self.dynamo_pgo = _GlobalItemStats() + + def reset(self) -> None: + self.autotune_local.reset() + self.autotune_remote.reset() + self.bundled_autotune.reset() + self.fx_graph.reset() + self.triton.reset() + self.aot_autograd.reset() + self.dynamo_pgo.reset() + + def get_stat(self, name: str) -> _GlobalItemStats: + return getattr(self, name) + + def report(self): + subs = ( + ("autotune_local", self.autotune_local), + ("autotune_remote", self.autotune_remote), + ("bundled_autotune", self.bundled_autotune), + ("fx_graph", self.fx_graph), + ("triton", self.triton), + ("aot_autograd", self.aot_autograd), + ("dynamo_pgo", self.dynamo_pgo), + ) + + print("Cache Stats:", file=sys.stderr) + for name, sub in subs: + print(f" {name}: {sub}", file=sys.stderr) + + print("Cache Entries:", file=sys.stderr) + for name, sub in subs: + if sub.cache: + print(f" {name}:", file=sys.stderr) + for k, v in sorted(sub.cache.items()): + v = repr(v) + if len(v) > 100: + v = v[:100] + "..." 
+ print(f" {k!r}: {v}", file=sys.stderr) + + +global_stats = _GlobalStats() + + +class MockBackend(RemoteCacheBackend[Any]): + def __init__(self, name: str) -> None: + self._name = name + + @staticmethod + def with_name(name: str) -> Callable[[], MockBackend]: + def wrapper() -> MockBackend: + return MockBackend(name) + + return wrapper + + @override + def _get(self, key: str) -> Optional[Any]: + stat = global_stats.get_stat(self._name) + if key in stat.cache: + stat += Stats(num_get_hit=1) + return stat.cache.get(key) + else: + stat += Stats(num_get_miss=1) + return None + + @override + def _put(self, key: str, data: Any) -> None: + stat = global_stats.get_stat(self._name) + stat += Stats(num_put=1) + stat.cache[key] = data + + +# List of configs for each cache +_CACHE_CONFIG_EN = ( + "fx_graph_cache", + "fx_graph_remote_cache", + "autotune_local_cache", + "autotune_remote_cache", + "bundled_autotune_remote_cache", +) + + +class PatchCaches(contextlib.AbstractContextManager): + @classmethod + def setUp(cls): + # If this test is using PatchCaches then disable all the caches by + # default, letting the tests turn them on explicitly. This is because + # tests using PatchCaches will often want to check stats explicitly. 
+ cls._savedCacheState = {} + for name in _CACHE_CONFIG_EN: + if hasattr(config, name): + cls._savedCacheState[name] = getattr(config, name) + setattr(config, name, False) + + @classmethod + def tearDown(cls): + # Restore cache defaults + for name in _CACHE_CONFIG_EN: + delattr(config, name) + if name in cls._savedCacheState: + setattr(config, name, cls._savedCacheState[name]) + + def __init__(self) -> None: + self._stack = contextlib.ExitStack() + + def __enter__(self) -> Self: + global_stats.reset() + self._stack.__enter__() + + ctx = patch( + "torch._inductor.runtime.autotune_cache.LocalAutotuneCache.backend_override_cls", + MockBackend.with_name("autotune_local"), + ) + self._stack.enter_context(ctx) + + ctx = patch( + "torch._inductor.remote_cache.RemoteAutotuneCache.backend_override_cls", + MockBackend.with_name("autotune_remote"), + ) + self._stack.enter_context(ctx) + + ctx = patch( + "torch._inductor.remote_cache.RemoteBundledAutotuneCache.backend_override_cls", + MockBackend.with_name("bundled_autotune"), + ) + self._stack.enter_context(ctx) + + ctx = patch( + "torch._inductor.remote_cache.RemoteFxGraphCache.backend_override_cls", + MockBackend.with_name("fx_graph"), + ) + self._stack.enter_context(ctx) + + ctx = patch( + "torch._inductor.remote_cache.RemoteAOTAutogradCache.backend_override_cls", + MockBackend.with_name("aot_autograd"), + ) + self._stack.enter_context(ctx) + + ctx = patch( + "torch._inductor.remote_cache.RemoteDynamoPGOCache.backend_override_cls", + MockBackend.with_name("dynamo_pgo"), + ) + self._stack.enter_context(ctx) + + if config.is_fbcode(): + ctx = patch( + "torch._inductor.fb.remote_cache.FbRemoteAutotuneCache.backend_override_cls", + MockBackend.with_name("autotune_remote"), + ) + self._stack.enter_context(ctx) + + ctx = patch( + "torch._inductor.fb.remote_cache.FbRemoteBundledAutotuneCache.backend_override_cls", + MockBackend.with_name("bundled_autotune"), + ) + self._stack.enter_context(ctx) + + ctx = patch( + 
"torch._inductor.fb.remote_cache.FbRemoteFxGraphCache.backend_override_cls", + MockBackend.with_name("fx_graph"), + ) + self._stack.enter_context(ctx) + + ctx = patch( + "triton.fb.fb_memcache.FbMemcacheRemoteKernelCache.backend_override_cls", + MockBackend.with_name("triton"), + ) + self._stack.enter_context(ctx) + + ctx = patch( + "torch._inductor.fb.remote_cache.FbRemoteAOTAutogradCache.backend_override_cls", + MockBackend.with_name("aot_autograd"), + ) + self._stack.enter_context(ctx) + + ctx = patch( + "torch._inductor.fb.remote_cache.FbRemoteDynamoPGOCache.backend_override_cls", + MockBackend.with_name("dynamo_pgo"), + ) + self._stack.enter_context(ctx) + + return self + + def __exit__( + self, + exc_type: Optional[type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> None: + self._stack.__exit__(exc_type, exc_value, traceback) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/ops_handler.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/ops_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..725abe260598d63aa5e053e1bcdb36bdbe74d975 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/ops_handler.py @@ -0,0 +1,1183 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import inspect +import itertools +import re +import warnings +from io import StringIO +from typing import ( + Any, + Generic, + Literal, + NamedTuple, + Optional, + TYPE_CHECKING, + TypeVar, + Union, +) +from unittest.mock import patch + +import sympy + +import torch +import torch.utils._pytree as pytree + +from ..utils._ordered_set import OrderedSet +from .utils import IndentedBuffer, reduction_num_outputs, sympy_index_symbol, sympy_str + + +if TYPE_CHECKING: + from collections.abc import Callable + + +T = TypeVar("T") +StoreMode = Optional[Literal["atomic_add", "tma"]] +ReductionType = 
Literal[ + "argmax", + "argmin", + "welford_reduce", + "welford_combine", + "any", + "max", + "min", + "prod", + "sum", + "dot", + "xor_sum", + "online_softmax_reduce", +] + + +def _arg_str(a: object) -> str: + if isinstance(a, sympy.Expr): + return sympy_str(a) + return str(a) + + +# See OpDecompositions for superclass that desugars operations like reciprocal/square. +class OpsHandler(Generic[T]): + """ + Protocol describing the set of valid operations on ``torch._inductor.virtualized.ops``, + as well as the contract for op handlers. The type T signifies the domain + of the abstract analysis AKA what all the functions return / take as arguments + anywhere compute occurs. + + While these operators are typically dtype polymorphic (e.g., you can use mul + on both integers and floats), they do NOT do promotion and usually return the + same dtype as the input. You are expected to have handled type promotion + during ATen decompositions. Most operators correspond exactly to pointwise + operations as defined by torch, so when in doubt about semantics, check the + corresponding torch documentation. These are all scalar operations (so they + are defined to operate on a single element at a time.) + + For convenience, many operators take a src_dtype which indicates what the dtype + of the input argument is. Although in principle this can be derived by an + analysis, providing this for ops where it is useful helps avoid having to repeatedly + recompute dtype in code generation. + + Note that this often describes a class of static methods, for stateless + ops handlers. + + Handlers are often defined using metaprogramming (e.g. _initialize_pointwise_overrides), + which means you will not get type errors for those methods. We have tests in + test/inductor/test_op_completeness.py which check that all operators are implemented after + all the metaprogramming has run. 
+ """ + + def constant(self, value: Union[bool, float, int], dtype: torch.dtype) -> T: + """Produces a scalar constant of type dtype.""" + raise NotImplementedError + + def load_seed(self, name: str, offset: T) -> T: + """Computes inductor_prims.lookup_seed.""" + raise NotImplementedError + + def rand(self, seed: T, offset: T) -> T: + """Computes inductor_prims.random with mode="rand". offset has dtype int32.""" + raise NotImplementedError + + def randn(self, seed: T, offset: T) -> T: + """Computes inductor_prims.random with mode="randn". offset has dtype int32.""" + raise NotImplementedError + + def randint64(self, seed: T, offset: T, low: T, high: T) -> T: + """Computes inductor_prims.randint. offset has dtype int32.""" + raise NotImplementedError + + def masked(self, mask: T, body: Callable[[], T], other: T) -> T: + """ + Computes body, but only perform loads/stores if the boolean mask + evaluates to true. For example, you would use this if you needed to + perform an indirect load that may not be valid on some elements; + without masking, invalid accesses can cause IMAs. When mask is true, + the result is the result of body; otherwise it is other. Here, `other` + needs to be a constant. + + Contrast this with ops.where, which can multiplex between two values + that have been unconditionally computed. + """ + raise NotImplementedError + + def where(self, condition: T, input: T, other: T) -> T: + """ + Computes torch.where: when condition is true, return input; otherwise return other. + """ + raise NotImplementedError + + def index_expr(self, expr: sympy.Expr, dtype: torch.dtype) -> T: + """ + Converts a sympy expression into a scalar of type dtype. expr is typically + an indexing expression, thus the name; however, it can also be used in + non-indexing situations. 
+ """ + raise NotImplementedError + + def to_dtype( + self, + x: T, + dtype: torch.dtype, + src_dtype: Optional[torch.dtype] = None, + use_compute_types: bool = True, + ) -> T: + """ + Convert x to dtype. src_dtype can be optionally set to specify what the original + dtype of x was, which can improve code generation (used by torch to(dtype=dtype)). + """ + raise NotImplementedError + + def trunc_to_int(self, x: T, dtype: torch.dtype) -> T: + """ + Convert x to dtype with truncation semantics (similar to how the int + constructor works in Python). In Inductor codegen, this just decays + to trunc and then to_dtype, but this composite operation helps + roundtrips for Sympy evaluation. + + dtype is taken as an explicit parameter because the desired output + dtype is typically the index dtype, which may vary between int32 and + int64 depending on if we've shown that all the indexing operations can + be done in int32. + """ + raise NotImplementedError + + def ceil_to_int(self, x: T, dtype: torch.dtype) -> T: + """ + Convert x to dtype with ceiling semantics. See also trunc_to_int. + """ + raise NotImplementedError + + def floor_to_int(self, x: T, dtype: torch.dtype) -> T: + """ + Convert x to dtype with ceiling semantics. See also trunc_to_int. + """ + raise NotImplementedError + + def round_to_int(self, x: T, dtype: torch.dtype) -> T: + """ + Convert x to dtype with round-to-even semantics. See also trunc_to_int. + """ + raise NotImplementedError + + def to_dtype_bitcast(self, x: T, dtype: torch.dtype, src_dtype: torch.dtype) -> T: + """ + Reinterpret cast x to dtype (reinterpreting the bits in memory as another dtype.) + src_dtype must be the original type of x. + """ + raise NotImplementedError + + def identity(self, x: T) -> T: + """ + Returns x as is. This is used to trigger CSE. + """ + raise NotImplementedError + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # These operations are only available in a "kernel" context. 
Check + # torch._inductor.codegen.common.CSEProxy for their typical implementation + # in op handler (routing to their respective implementations in the kernel + # handler) + # + # Importantly, inside a kernel, indexing and mask variables are available + # in scope, which are typically used by sympy.Expr indexing. + + def indirect_indexing( + self, x: T, size: sympy.Expr, check: bool = True, wrap_neg=True + ) -> sympy.Expr: + """ + Convert an integral x into a sympy.Expr that can be subsequently used in + indexing computation. 'size' represents an upper bound on what valid + indexes can be; when 'check' is True, we check that the x is in bounds. + + NB: This is typically mandatory to implement for any analysis, because you + MUST return a valid sympy.Expr of some sort (even if it's a meaningless symbol). + """ + raise NotImplementedError + + def load(self, name: str, index: sympy.Expr) -> T: + """ + Load from the memory location 'name', offset by some indexing expression 'index'. + """ + raise NotImplementedError + + def store( + self, + name: str, + index: sympy.Expr, + value: T, + mode: StoreMode = None, + ) -> None: + """ + Store 'value' to the memory location 'name' offset by 'index'. If + specified, 'mode' can require the store to be an atomic addition. + """ + raise NotImplementedError + + # TODO: Better explain how the "collective" semantics of these ops; + # remember that the input value is a scalar, you can't reduce on it in the + # traditional sense! + def reduction( + self, + dtype: torch.dtype, + src_dtype: torch.dtype, + reduction_type: ReductionType, + value: T, + ) -> Union[T, tuple[T, ...]]: + """ + Perform a 'reduction_type' reduction on 'value' of dtype 'src_dtype', + using 'dtype' as the accumulation dtype for the reduction. The result + is an intermediate computation which should be stored to the final + location using 'ops.store_reduction'. + + Valid reduction types are the members of ReductionType.
For Welford reduction types, this + function returns multiple outputs; consult reduction_num_outputs to + determine the amount in metaprogramming applications. + """ + raise NotImplementedError + + # TODO: in practice, this seems to actually return None, but not returning + # a T makes common __getattr__ idioms not type correctly. Figure out if + # this should be returning something. + def store_reduction(self, name: str, index: sympy.Expr, value: T) -> None: + """ + Store the fully accumulated result of 'reduction' to the memory + location 'name' offset by 'index'. + """ + raise NotImplementedError + + def scan( + self, + dtypes: tuple[torch.dtype, ...], + combine_fn: Callable[[tuple[T, ...], tuple[T, ...]], tuple[T, ...]], + values: tuple[T, ...], + ) -> tuple[T, ...]: + """ + Perform an associative scan on 'values'. + """ + # TODO: Improve the description with some pseudocode + raise NotImplementedError + + def sort( + self, + dtypes: tuple[torch.dtype, ...], + values: tuple[T, ...], + stable: bool, + descending: bool, + ) -> tuple[T, ...]: + """ + Sort values along the reduction dimension. + """ + raise NotImplementedError + + def bucketize( + self, + values: T, + boundaries: tuple[str, sympy.Expr, sympy.Expr, sympy.Expr], + boundary_indices: T, + indexing_dtype: torch.dtype, + right: bool, + sorter: Optional[tuple[str, sympy.Expr]] = None, + sorter_indices: Optional[T] = None, + ) -> T: + # See [Note: Inductor bucketize op] + raise NotImplementedError + + def partial_accumulate( + self, + name: str, + reduction_type: ReductionType, + value: T, + extra_meta: dict[str, Any], + ) -> None: + raise NotImplementedError + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # The following ops have semantics that correspond exactly to the torch + # operation with the same corresponding name.
+ + def abs(self, x0: T) -> T: + raise NotImplementedError + + def exp(self, x0: T) -> T: + raise NotImplementedError + + def exp2(self, x0: T) -> T: + raise NotImplementedError + + def expm1(self, x0: T) -> T: + raise NotImplementedError + + def sqrt(self, x0: T) -> T: + raise NotImplementedError + + def relu(self, x0: T) -> T: + raise NotImplementedError + + def minimum(self, x0: T, x1: T) -> T: + raise NotImplementedError + + def maximum(self, x0: T, x1: T) -> T: + raise NotImplementedError + + def cos(self, x0: T) -> T: + raise NotImplementedError + + def sin(self, x0: T) -> T: + raise NotImplementedError + + def lgamma(self, x0: T) -> T: + raise NotImplementedError + + def erf(self, x0: T) -> T: + raise NotImplementedError + + def cosh(self, x0: T) -> T: + raise NotImplementedError + + def sinh(self, x0: T) -> T: + raise NotImplementedError + + def acos(self, x0: T) -> T: + raise NotImplementedError + + def acosh(self, x0: T) -> T: + raise NotImplementedError + + def asin(self, x0: T) -> T: + raise NotImplementedError + + def asinh(self, x0: T) -> T: + raise NotImplementedError + + def atan2(self, x0: T, x1: T) -> T: + raise NotImplementedError + + def atan(self, x0: T) -> T: + raise NotImplementedError + + def atanh(self, x0: T) -> T: + raise NotImplementedError + + def copysign(self, x0: T, x1: T) -> T: + raise NotImplementedError + + def erfc(self, x0: T) -> T: + raise NotImplementedError + + def erfinv(self, x0: T) -> T: + raise NotImplementedError + + def frexp(self, x0: T): + raise NotImplementedError + + def hypot(self, x0: T, x1: T) -> T: + raise NotImplementedError + + def log10(self, x0: T) -> T: + raise NotImplementedError + + def log2(self, x0: T) -> T: + raise NotImplementedError + + def nextafter(self, x0: T, x1: T) -> T: + raise NotImplementedError + + def logical_and(self, x0: T, x1: T) -> T: + raise NotImplementedError + + def logical_not(self, x0: T) -> T: + raise NotImplementedError + + def logical_or(self, x0: T, x1: T) -> T: + raise 
NotImplementedError + + def logical_xor(self, x0: T, x1: T) -> T: + raise NotImplementedError + + def bitwise_and(self, x0: T, x1: T) -> T: + raise NotImplementedError + + def bitwise_not(self, x0: T) -> T: + raise NotImplementedError + + def bitwise_or(self, x0: T, x1: T) -> T: + raise NotImplementedError + + def bitwise_xor(self, x0: T, x1: T) -> T: + raise NotImplementedError + + def bitwise_left_shift(self, x0: T, x1: T) -> T: + raise NotImplementedError + + def bitwise_right_shift(self, x0: T, x1: T) -> T: + raise NotImplementedError + + def rsqrt(self, x0: T) -> T: + raise NotImplementedError + + def log1p(self, x0: T) -> T: + raise NotImplementedError + + def tan(self, x0: T) -> T: + raise NotImplementedError + + def tanh(self, x0: T) -> T: + raise NotImplementedError + + def sigmoid(self, x0: T) -> T: + raise NotImplementedError + + def signbit(self, x0: T) -> T: + raise NotImplementedError + + def fmod(self, x0: T, x1: T) -> T: + raise NotImplementedError + + def log(self, x0: T) -> T: + raise NotImplementedError + + def isinf(self, x0: T) -> T: + raise NotImplementedError + + def isnan(self, x0: T) -> T: + raise NotImplementedError + + # NB: this returns a float, like the torch operation + # This rounds half to even to break ties + def round(self, x0: T) -> T: + raise NotImplementedError + + # NB: this returns a float, like the torch operation + def floor(self, x0: T) -> T: + raise NotImplementedError + + def sign(self, x0: T) -> T: + raise NotImplementedError + + # NB: this returns a float, like the torch operation + def trunc(self, x0: T) -> T: + raise NotImplementedError + + # NB: this returns a float, like the torch operation + def ceil(self, x0: T) -> T: + raise NotImplementedError + + def neg(self, x0: T) -> T: + raise NotImplementedError + + def reciprocal(self, x0: T) -> T: + raise NotImplementedError + + def eq(self, x0: T, x1: T) -> T: + raise NotImplementedError + + def ne(self, x0: T, x1: T) -> T: + raise NotImplementedError + + def lt(self, 
x0: T, x1: T) -> T: + raise NotImplementedError + + def gt(self, x0: T, x1: T) -> T: + raise NotImplementedError + + def le(self, x0: T, x1: T) -> T: + raise NotImplementedError + + def ge(self, x0: T, x1: T) -> T: + raise NotImplementedError + + def add(self, x0: T, x1: T) -> T: + raise NotImplementedError + + def sub(self, x0: T, x1: T) -> T: + raise NotImplementedError + + def mul(self, x0: T, x1: T) -> T: + raise NotImplementedError + + # NB: this returns a float, like the torch operation + def pow(self, x0: T, x1: T) -> T: + raise NotImplementedError + + def and_(self, x0: T, x1: T) -> T: + raise NotImplementedError + + def or_(self, x0: T, x1: T) -> T: + raise NotImplementedError + + def xor(self, x0: T, x1: T) -> T: + raise NotImplementedError + + # These are metaprogrammed by MockHandler._init_cls + def lshift(self, x0: T, x1: T) -> T: + raise NotImplementedError + + def rshift(self, x0: T, x1: T) -> T: + raise NotImplementedError + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # These are "special" operators. These only exist if the target + # language actually supports the operator. Keep this in sync with + # pointwise_overrides_data. 
+ + def airy_ai(self, x: T) -> T: + raise NotImplementedError + + def bessel_j0(self, x: T) -> T: + raise NotImplementedError + + def bessel_j1(self, x: T) -> T: + raise NotImplementedError + + def bessel_y0(self, x: T) -> T: + raise NotImplementedError + + def bessel_y1(self, x: T) -> T: + raise NotImplementedError + + def digamma(self, x: T) -> T: + raise NotImplementedError + + def erfcx(self, x: T) -> T: + raise NotImplementedError + + def fma(self, x: T, y: T, z: T) -> T: + raise NotImplementedError + + def igamma(self, x: T, y: T) -> T: + raise NotImplementedError + + def igammac(self, x: T, y: T) -> T: + raise NotImplementedError + + def gammainc(self, x: T, y: T) -> T: + raise NotImplementedError + + def gammaincc(self, x: T, y: T) -> T: + raise NotImplementedError + + def i0(self, x: T) -> T: + raise NotImplementedError + + def i0e(self, x: T) -> T: + raise NotImplementedError + + def i1(self, x: T) -> T: + raise NotImplementedError + + def i1e(self, x: T) -> T: + raise NotImplementedError + + def log_ndtr(self, x: T) -> T: + raise NotImplementedError + + def modified_bessel_i0(self, x: T) -> T: + raise NotImplementedError + + def modified_bessel_i1(self, x: T) -> T: + raise NotImplementedError + + def modified_bessel_k0(self, x: T) -> T: + raise NotImplementedError + + def modified_bessel_k1(self, x: T) -> T: + raise NotImplementedError + + def ndtr(self, x: T) -> T: + raise NotImplementedError + + def ndtri(self, x: T) -> T: + raise NotImplementedError + + def polygamma(self, x: T, y: T) -> T: + raise NotImplementedError + + def scaled_modified_bessel_k0(self, x: T) -> T: + raise NotImplementedError + + def scaled_modified_bessel_k1(self, x: T) -> T: + raise NotImplementedError + + def spherical_bessel_j0(self, x: T) -> T: + raise NotImplementedError + + def zeta(self, x: T, y: T) -> T: + raise NotImplementedError + + def chebyshev_polynomial_t(self, x: T, y: T) -> T: + raise NotImplementedError + + def chebyshev_polynomial_u(self, x: T, y: T) -> T: + 
raise NotImplementedError + + def chebyshev_polynomial_v(self, x: T, y: T) -> T: + raise NotImplementedError + + def chebyshev_polynomial_w(self, x: T, y: T) -> T: + raise NotImplementedError + + def legendre_polynomial_p(self, x: T, y: T) -> T: + raise NotImplementedError + + def shifted_chebyshev_polynomial_t(self, x: T, y: T) -> T: + raise NotImplementedError + + def shifted_chebyshev_polynomial_u(self, x: T, y: T) -> T: + raise NotImplementedError + + def shifted_chebyshev_polynomial_v(self, x: T, y: T) -> T: + raise NotImplementedError + + def shifted_chebyshev_polynomial_w(self, x: T, y: T) -> T: + raise NotImplementedError + + def hermite_polynomial_h(self, x: T, y: T) -> T: + raise NotImplementedError + + def hermite_polynomial_he(self, x: T, y: T) -> T: + raise NotImplementedError + + def laguerre_polynomial_l(self, x: T, y: T) -> T: + raise NotImplementedError + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # These operators are a bit special, because they are conventionally + # natively supported in both Python and C, but the semantics differ so + # care must be taken + + def truncdiv(self, x0: T, x1: T) -> T: + """C-style trunc division between integers only. Computes the true + division of two numbers and rounds the result to zero. + """ + raise NotImplementedError + + def floordiv(self, x0: T, x1: T) -> T: + """Python-style floor division between integers only. Computes the + true division of two numbers and floors the result. If you want + floor division for floats, do regular truediv and floor the result. + """ + raise NotImplementedError + + def truediv(self, x0: T, x1: T) -> T: + """True division between floats. Integer inputs are NOT valid. To + do Python-style (int, int) -> float division, use int_truediv""" + raise NotImplementedError + + def int_truediv(self, x0: T, x1: T) -> T: + """True division between integers. 
This is NOT the same as promoting + to float and doing integer division, there is a bespoke algorithm for + doing the division in higher precision than the above. + """ + raise NotImplementedError + + def mod(self, x0: T, x1: T) -> T: + """C-style modulus, take sign from LHS (x0).""" + raise NotImplementedError + + def remainder(self, x0: T, x1: T) -> T: + """Python-style modulus, take sign from RHS (x1).""" + raise NotImplementedError + + def square(self, x0: T) -> T: + raise NotImplementedError + + def check_bounds( + self, expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool + ) -> None: + raise NotImplementedError + + # halide-only + def halide_clamp(self, value: T, size: sympy.Expr, check: bool) -> T: + raise NotImplementedError + + # triton-only + def dot(self, x: T, y: T) -> T: + raise NotImplementedError + + # triton-only + def inline_asm_elementwise( + self, + *inputs: T, + asm: str, + constraints: Optional[str] = None, + dtype: torch.dtype = torch.float32, + is_pure: bool = True, + pack: int = 1, + ) -> T: + raise NotImplementedError + + def output(self, *args: T) -> None: + """This is a fake op used in analysis but not codegen""" + raise NotImplementedError + + def placeholder(self, index: int) -> T: + """This is a fake op used in analysis but not codegen""" + raise NotImplementedError + + def device_assert_async(self, cond: T, msg: str) -> T: + raise NotImplementedError + + +_ignore_op_re = re.compile(r"_.*|paren").fullmatch + + +def list_ops(cls: type[Any]): + return OrderedSet([x for x in dir(cls) if not _ignore_op_re(x)]) + + +OP_NAMES = list_ops(OpsHandler) + + +class DefaultHandler(OpsHandler[Any]): + def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: + """ + Default implementation for all ops. Override in a subclass to + provide generic op behavior. 
+ + Args: + name: name of the op, see OpHandler.{name} + args: positional args passed to the op + kwargs: keyword args passed to the op + + Returns: + return value of the op + + """ + raise NotImplementedError + + def __getattr__(self, name: str) -> Any: + def fallback(*args: Any, **kwargs: Any) -> Any: + return self._default(name, args, kwargs) + + # would like to remove this function entirely, but it's used in MTIA backend + warnings.warn(f"undefined OpHandler.{name}, please add missing op schema") + return fallback + + @staticmethod + def _call_default(target: str): + def call_default(self, *args, **kwargs): + return self._default(target, args, kwargs) + + call_default.__name__ = target + return call_default + + @classmethod + def _init_cls(cls): + """ + Here we codegen many functions of the form: + + def add(self, a, b): + return self._default('add', (a, b), {}) + + and install them in cls. This is the same as _call_default above, + but is about 1.2x faster since CPython varargs parsing is slow. 
+ """ + code = StringIO() + for target in OP_NAMES: + sig = inspect.signature(getattr(OpsHandler, target)) + if all( + p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD + and p.default is inspect.Parameter.empty + for p in sig.parameters.values() + ): + self_arg, *args = sig.parameters.keys() + assert self_arg == "self" + code.write( + f""" + def {target}(self, {", ".join(args)}): + return self._default({target!r}, ({", ".join(args)}, ), {{}}) + """.strip() + ) + code.write("\n\n") + else: + # slower fallback for ops with default or variadic arguments + setattr(cls, target, cls._call_default(target)) + + ctx: dict[str, Any] = {} + exec(code.getvalue(), ctx) + for target, impl in ctx.items(): + if target in OP_NAMES: + setattr(cls, target, impl) + + +DefaultHandler._init_cls() + + +class NoopHandler(DefaultHandler): + name = "NoopHandler" + + def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: + return None + + @staticmethod + def masked(mask, body, other) -> None: + return None + + @staticmethod + def frexp(x) -> tuple[None, None]: + return (None, None) + + @staticmethod + def scan(dtypes, combine_fn, values) -> tuple[None, ...]: + return (None,) * len(values) + + @staticmethod + def sort(dtypes, values, stable, descending) -> tuple[None, ...]: + return (None,) * len(values) + + @staticmethod + def indirect_indexing(index_var, size, check=True, wrap_neg=True) -> sympy.Symbol: + return sympy.S.Zero + + +class BasicMathOpsMixin: + @staticmethod + def add(a, b): + return f"{a} + {b}" + + @staticmethod + def sub(a, b): + return f"{a} - {b}" + + @staticmethod + def mul(a, b): + return f"{a} * {b}" + + @staticmethod + def floordiv(a, b): + return f"{a} // {b}" + + @staticmethod + def truediv(a, b): + return f"{a} / {b}" + + @staticmethod + def mod(a, b): + # careful, depending on target semantics varies + return f"{a} % {b}" + + @staticmethod + def pow(a, b): + return f"{a} ** {b}" + + @staticmethod + def lshift(a, b): + return f"{a} << 
{b}" + + @staticmethod + def rshift(a, b): + return f"{a} >> {b}" + + @staticmethod + def and_(a, b): + return f"{a} & {b}" + + @staticmethod + def or_(a, b): + return f"{a} | {b}" + + @staticmethod + def xor(a, b): + return f"{a} ^ {b}" + + @staticmethod + def eq(a, b): + return f"{a} == {b}" + + @staticmethod + def ne(a, b): + return f"{a} != {b}" + + @staticmethod + def lt(a, b): + return f"{a} < {b}" + + @staticmethod + def gt(a, b): + return f"{a} > {b}" + + @staticmethod + def le(a, b): + return f"{a} <= {b}" + + @staticmethod + def ge(a, b): + return f"{a} >= {b}" + + @staticmethod + def neg(a): + return f"-{a}" + + +class MockHandler(BasicMathOpsMixin, DefaultHandler): + name = "MockHandler" + + def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: + fargs = [*map(_arg_str, args)] + for k, v in kwargs.items(): + fargs.append(f"{k}={_arg_str(v)}") + return f"ops.{name}({', '.join(fargs)})" + + @staticmethod + def masked(mask, body, other) -> str: + return f"ops.masked({mask}, {body()}, {other})" + + @staticmethod + def frexp(x): + return (f"ops.frexp({x})[0]", f"ops.frexp({x})[1]") + + @staticmethod + def scan(dtypes, combine_fn, values): + return tuple( + f"ops.scan({dtypes}, {combine_fn}, {values})[{i}]" + for i in range(len(values)) + ) + + @staticmethod + def sort(dtypes, values, stable, descending): + return tuple( + f"ops.sort({dtypes}, {values}, stable={stable}, descending={descending})[{i}]" + for i in range(len(values)) + ) + + @staticmethod + def indirect_indexing(index_var, size, check=True, wrap_neg=True) -> sympy.Symbol: + return sympy_index_symbol(str(index_var)) + + +class KernelFormatterHandler(DefaultHandler): + def __init__(self, parent_handler: OpsHandler[Any]): + self.parent_handler = parent_handler + self._output = IndentedBuffer(1) + self.var_counter = itertools.count() + + @staticmethod + def ir_to_string(ir_fn, index, rindex=None) -> str: + from .ir import FlexibleLayout + from .virtualized import V + + 
args = [index, rindex] if rindex is not None else [index] + names = ["index", "rindex"] if rindex is not None else ["index"] + formatter = KernelFormatterHandler(MockHandler()) + + with formatter._output.indent(-1): + formatter._output.writeline(f"def inner_fn({', '.join(names)}):") + for name, arg in zip(names, args): + if arg: + lhs = ", ".join( + [ + str("_" if isinstance(v, (int, sympy.Integer)) else v) + for v in arg + ] + ) + formatter._output.writeline(f"{lhs} = {name}") + + with ( + V.set_ops_handler(formatter), + patch.object(FlexibleLayout, "allow_indexing", True), + ): + result = ir_fn(*args) + return formatter.getvalue(result) + + def indirect_indexing(self, *args, **kwargs) -> sympy.Symbol: + return self.parent_handler.indirect_indexing(*args, **kwargs) + + def _write(self, line): + # replace line with a new variable name + varname = f"tmp{next(self.var_counter)}" + self._output.writeline(f"{varname} = {line}") + return varname + + def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: + return pytree.tree_map( + self._write, getattr(self.parent_handler, name)(*args, **kwargs) + ) + + def reduction( + self, + dtype: torch.dtype, + src_dtype: torch.dtype, + reduction_type: ReductionType, + value: Union[str, tuple[str, ...]], + ) -> Union[str, tuple[str, ...]]: + line = self.parent_handler.reduction(dtype, src_dtype, reduction_type, value) + num_values = reduction_num_outputs(reduction_type) + varnames = [f"tmp{next(self.var_counter)}" for _ in range(num_values)] + self._output.writeline(f"{','.join(varnames)} = {line}") + return tuple(varnames) if num_values > 1 else varnames[0] + + def getvalue(self, result): + self._output.writeline(f"return {result}") + return self._output.getvalue() + + +class WrapperHandler(DefaultHandler): + def __init__(self, inner: OpsHandler[Any]): + self._inner = inner + + def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: + return getattr(self._inner, 
name)(*args, **kwargs) + + +class AddParenHandler(WrapperHandler): + def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: + val = getattr(self._inner, name)(*args, **kwargs) + if not val or isinstance(val, (sympy.Expr, tuple, list)): + return val + return f"({val})" + + +class OpCountResult(NamedTuple): + num_ops: int + used_ops: OrderedSet[str] + read_buffers: list[str] + nontrivial_read_count: int + + +class OpCounterCSE(DefaultHandler): + """Shim to count how many ops are used""" + + def __init__(self, inner: OpsHandler[Any]): + super().__init__() + self.parent_handler = inner + self.op_count = 0 + self.var_names: dict[str, str] = {} + self._used_ops: OrderedSet[str] = OrderedSet() + self._read_names: list[str] = [] + self._nontrivial_read_count = 0 + + def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: + self._used_ops.add(name) + return pytree.tree_map( + self._update_count, getattr(self.parent_handler, name)(*args, **kwargs) + ) + + def _update_count(self, val): + varname = self.var_names.get(val) + if not varname: + varname = f"tmp{self.op_count}" + self.op_count += 1 + self.var_names[val] = varname + return varname + + def indirect_indexing(self, *args, **kwargs): + self._used_ops.add("indirect_indexing") + return self.parent_handler.indirect_indexing(*args, **kwargs) + + def load(self, name: str, index: sympy.Expr) -> str: + val = self.parent_handler.load(name, index) + if val not in self.var_names: + self._used_ops.add("load") + self._read_names.append(name) + if not isinstance(index, (sympy.Integer, int)): + self._nontrivial_read_count += 1 + return self._update_count(val) + + def load_seed(self, name: str, offset: T): + val = self.parent_handler.load_seed(name, offset) + if val not in self.var_names: + self._used_ops.add("load_seed") + self._read_names.append(name) + return self._update_count(val) + + def bucketize( + self, + values: T, + boundaries: tuple[str, sympy.Expr, sympy.Expr, 
sympy.Expr], + boundary_indices: T, + indexing_dtype: torch.dtype, + right: bool, + sorter: Optional[tuple[str, sympy.Expr]] = None, + sorter_indices: Optional[T] = None, + ) -> T: + """ + See [Note: Inductor bucketize op] + """ + val = self.parent_handler.bucketize( + values, + boundaries, + boundary_indices, + indexing_dtype, + right, + sorter, + sorter_indices, + ) + if val not in self.var_names: + self._used_ops.add("bucketize") + self._read_names.append(boundaries[0]) + if sorter is not None: + self._read_names.append(sorter[0]) + return self._update_count(val) + + def getvalue(self): + return OpCountResult( + self.op_count, self._used_ops, self._read_names, self._nontrivial_read_count + ) + + +class ExtractConstantsHandler(NoopHandler): + def __init__(self, device: Optional[torch.device]): + self.device = device + + def constant(self, value: Any, dtype: torch.dtype) -> torch._inductor.ir.Constant: + from torch._inductor import ir + + return ir.Constant( + value=value, dtype=dtype, device=self.device or torch.get_default_device() + ) + + +class SimpleCSEHandler(WrapperHandler): + """Wraps the underlying handler with a CSE pass + + NOTE: Compared to codegen level CSE this is simplified as it + doesn't support stores which require load cache invalidation. 
+ """ + + def __init__(self, inner: Any): + super().__init__(inner) + self.cse_cache: dict[str, Union[Any, tuple[Any, ...]]] = {} + self.mock = MockHandler() + + def indirect_indexing(self, *args, **kwargs) -> sympy.Expr: + return super().indirect_indexing(*args, **kwargs) # type: ignore[misc] + + def store(self, *args, **kwargs) -> None: + raise NotImplementedError("store not implemented") + + def store_reduction(self, *args, **kwargs) -> None: + raise NotImplementedError("store not implemented") + + def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: + key = getattr(self.mock, name)(*args, **kwargs) + val = self.cse_cache.get(key) + if val is not None: + return val + + val = getattr(self._inner, name)(*args, **kwargs) + self.cse_cache[key] = val + return val + + def device_assert_async(self, *args, **kwargs) -> None: + raise NotImplementedError( + f"{type(self).__name__}: device_assert_async should be handled by CSEProxy" + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/optimize_indexing.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/optimize_indexing.py new file mode 100644 index 0000000000000000000000000000000000000000..67c2a74e886afb4b4c3f0f96079633e5bf97e6f5 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/optimize_indexing.py @@ -0,0 +1,126 @@ +import math +from typing import Any + +import sympy + +import torch +from torch.utils._sympy.value_ranges import ValueRanges + +from .loop_body import LoopBody +from .utils import dominated_nodes + + +def val_expressable_in_32_bits(val: Any) -> bool: + if getattr(val, "is_Boolean", False): + return True + + if isinstance(val, sympy.Expr): + assert val.is_number + if val.is_Integer or val.is_Boolean: + val = int(val) + else: + val = float(val) + + # bound within mantissa + if isinstance(val, float): + return val <= (2**24) and val >= -(2**24) + + if isinstance(val, 
int): + iinfo = torch.iinfo(torch.int32) + return val <= iinfo.max and val >= iinfo.min + + raise TypeError(f"Unexpected value {val}") + + +def range_expressable_in_32_bits(range: ValueRanges[sympy.Expr]) -> bool: + return val_expressable_in_32_bits(range.lower) and val_expressable_in_32_bits( + range.upper + ) + + +def try_to_reduce_precision( + node: Any, + bounds: dict[Any, Any], + indirect_vars: list[Any], + indices: dict[Any, sympy.Expr], + replacement_vals: dict[Any, ValueRanges[sympy.Expr]], +) -> None: + # if a downstream use of a node explicitly converts to int32, or float32/float64, + # then its precision is set for that chain of uses, and we don't need to consider those + # dominated values + def skip_filter(node: Any) -> bool: + return node.target == "to_dtype" and node.args[2] in ( + torch.int32, + torch.float32, + torch.float64, + ) + + # TODO - there are dominated uses whose dtype does not depend on whether + # we reduce the precision here, e.g. add(int64, int64) one of the args can be reduced to + # int32 without changing the output precision of the node. this case hasn't shown up + for dominated in dominated_nodes([node], skip_filter): + if dominated.target in ["store", "output"]: + continue + + if isinstance(dominated.target, str) and "set_indirect" in dominated.target: + idx = int(dominated.target[len("set_indirect") :]) + indirect_var = indirect_vars[idx] + + # We check that we can compute all the indices it's involved in with int32 + for index, expr in indices.items(): + if indirect_var in expr.free_symbols: + index_val = replacement_vals[index] + + if math.isinf(index_val.lower) or math.isinf(index_val.upper): + return + + # all indices are integers, so make sure that we + # use the bounds of integers instead of floats. + # TODO - not sure if we should be doing int/float casts while tracing, + # might interfere with sympy.
+ + index_val_int = ValueRanges[sympy.Expr]( + int(index_val.lower), int(index_val.upper) + ) + if not range_expressable_in_32_bits(index_val_int): + return + + if not range_expressable_in_32_bits(bounds[dominated]): + return + + args = list(node.args) + args[2] = torch.int32 + node.args = tuple(args) + + +def indexing_dtype_strength_reduction(loop_body: LoopBody) -> None: + """ + Performs Value Range Analysis on LoopBody's fx graph to reduce precision of + intermediaries from int64 to int32 + """ + bv = loop_body.bounds() + + int64_dtype_nodes = [ + node + for node in loop_body.get_nodes() + if ( + node.target == "to_dtype" + and node.args[2] == torch.int64 + and node not in bv.unbounded_vars + ) + ] + if not int64_dtype_nodes: + return + + bounds = bv.get_bounds() + + # TODO - if dominated node of one to_dtype is not expressible in int32, + # we should short circuit another to_dtype node if that node also dominates + for node in int64_dtype_nodes: + try_to_reduce_precision( + node, + bounds, + loop_body.indirect_vars, + loop_body.indexing_exprs, + bv.replacement_vals, + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/output_code.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/output_code.py new file mode 100644 index 0000000000000000000000000000000000000000..4d9bc98fc2220ab617547f66c1357cdf7c7f016b --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/output_code.py @@ -0,0 +1,1014 @@ +""" +This provides an abstract class which parametrizes over an "output code" concept +for Inductor. Intuitively, this represents the compiled callable which Inductor +produces which you can call to get optimized code. However, this callable +has some other capabilities: + +- It is serializable, so you can save/load this product from disk without + having to do compilation again. 
+ +- (When using remote cache) it is addressable, so you can save just a key + which you can use to load this product from remote cache later. + +This class is abstract because we have several different implementations of +serialized format: + +- Python wrapper (the default) + +- AOTInductor (this produces ABI stable binaries which work across PyTorch + versions) + +""" + +from __future__ import annotations + +import dataclasses +import logging +import os +from functools import partial +from typing import Any, Optional, TYPE_CHECKING, TypeAlias, Union + +import torch +from torch._dynamo.utils import counters, get_runtime_metrics_context +from torch._higher_order_ops.wrap import inductor_compiled_code +from torch._inductor.cudagraph_utils import ( + BoxedDeviceIndex, + CudagraphCachedInfo, + CudagraphMetadata, + get_partition_cudagraph_metadata, + get_placeholder_info, + log_cudagraph_skip_and_bump_counter, +) +from torch._inductor.freezing_utils import has_frozen_params, is_frozen_param +from torch._inductor.utils import ( + _unstable_customized_partition_wrapper, + align_inputs_from_check_idxs, + BoxedBool, + CUDAGraphWrapperMetadata, + GraphPartitionMap, + InputType, + output_node, + set_tracing_context_output_strides, +) +from torch.autograd.profiler import record_function +from torch.utils._ordered_set import OrderedSet +from torch.utils._python_dispatch import is_in_torch_dispatch_mode + +from . 
import config +from .runtime.autotune_cache import AutotuneCacheBundler + + +if TYPE_CHECKING: + from collections import Counter + from collections.abc import Callable, Sequence + + from torch._inductor import metrics + from torch._inductor.graph import GraphLowering + from torch._library.fake_class_registry import FakeScriptObject + from torch.export.pt2_archive._package_weights import Weights + + from .compile_fx import _CompileFxKwargs + from .triton_bundler import TritonBundle + +log = logging.getLogger(__name__) + + +@dataclasses.dataclass +class OutputCode: + # TODO: Remove underscores here + + # None if the output is not remote cacheable + _fx_graph_cache_key: Optional[str] = dataclasses.field(default=None, init=False) + _fx_graph_cache_debug_lines: Optional[list[str]] = dataclasses.field( + default=None, init=False + ) + + # How long it took to compile this OutputCode, end to end + _time_taken_ns: Optional[int] = dataclasses.field(default=None, init=False) + + def __call__(self, inputs: Sequence[Any]) -> Any: + raise NotImplementedError(type(self)) + + def prepare_for_serialization(self) -> None: + raise NotImplementedError(type(self)) + + def post_compile( + self, + example_inputs: Sequence[InputType], + constants: CompiledFxGraphConstants, + graph_kwargs: _CompileFxKwargs, + ) -> None: + raise NotImplementedError(type(self)) + + # TODO: Get rid of this + def set_triton_bundle(self, triton_bundle: Any) -> None: + raise NotImplementedError(type(self)) + + +_StrideExprStr: TypeAlias = str + + +# copy_ fails when trying to write to tensors with memory overlap, +# for expanded dimensions (a dimension which used to have size 1 -> ?) 
+# we can select one element from that dimension and write to it +# to achieve writing to all values of that dimension of the input tensor +def get_expanded_dims(t: torch.Tensor) -> list[int]: + if not isinstance(t, torch.Tensor): + # pyrefly: ignore [bad-return] + return None + return [i for i in range(t.ndim) if t.stride(i) == 0 and t.size(i) != 1] + + +def index_expanded_dims(t: torch.Tensor, expanded_dims: list[int]) -> torch.Tensor: + for expanded_dim in expanded_dims: + t = torch.ops.aten.slice(t, expanded_dim, 0, 1) + return t + + +def complex_memory_overlap(t: torch.Tensor) -> bool: + if config.always_complex_memory_overlap_TESTING_ONLY: + return True + + # if torch._debug_has_internal_overlap thinks this tensor potentially has + # memory overlap internally, let's dig deeper to find out whether it's true. + # + # Call squeeze() so that dimension with size 1 does not cause false positive. + t = index_expanded_dims(t, get_expanded_dims(t)).squeeze() + if torch._debug_has_internal_overlap(t) != 0: + strides = t.stride() + sizes = t.shape + indices = list(range(len(strides))) + indices = [x for _, x in sorted(zip(strides, indices))] + for i in range(len(strides)): + prev_stride = 1 if i == 0 else strides[indices[i - 1]] + prev_size = 1 if i == 0 else sizes[indices[i - 1]] + if strides[indices[i]] < prev_stride * prev_size: + return True + return False + + +def maybe_handle_backward_generation( + compiled_graph: CompiledFxGraph, + boxed_forward_device_index: Optional[BoxedDeviceIndex], +) -> None: + assert compiled_graph.current_callable is not None + is_backward = compiled_graph.fx_kwargs["is_backward"] + + # See [Backward Generation Handling] + # if cudagraph'd the forward and set the device, we need to let the cudagraph manager + # know we are we running the backward even if we will not run it in cudagraphs + if is_backward and config.triton.cudagraph_trees: + assert boxed_forward_device_index is not None + assert boxed_forward_device_index.value is not None 
+ compiled_graph_callable = compiled_graph.current_callable + + manager = torch._inductor.cudagraph_trees.get_manager( + boxed_forward_device_index.value, create_if_none_exists=False + ) + # should already exist from forward + assert manager is not None + + def compiled_artifact(new_inputs: list[Any]) -> Callable[..., Any]: + manager.set_to_running_backward() # type: ignore[union-attr] + return compiled_graph_callable(new_inputs) + + compiled_graph.current_callable = compiled_artifact + + +def prepare_cudagraph_post_compile( + compiled_graph: CompiledFxGraph, + example_inputs: Sequence[InputType], + boxed_forward_device_index: Optional[BoxedDeviceIndex], +) -> None: + if not config.triton.cudagraph_trees: + # Force specialize all inputs so that CUDA graphs will work + for t in example_inputs: + if isinstance(t, torch.SymInt): + int(t) # guard + + is_inference = compiled_graph.fx_kwargs["is_inference"] + is_backward = compiled_graph.fx_kwargs["is_backward"] + if boxed_forward_device_index is not None and not is_inference and not is_backward: + boxed_forward_device_index.set(next(iter(compiled_graph.device_idxs))) + + +def cudagraph_post_compile( + example_inputs: Sequence[InputType], + compiled_graph: CompiledFxGraph, + cudagraphs: BoxedBool, + constants: dict[str, Union[torch.Tensor, type]], + boxed_forward_device_index: Optional[BoxedDeviceIndex], +) -> None: + """ + Checks for any reasons not to run cudagraphs and then + runs it on compiled_graph. 
+ Mutates the `compiled_graph.current_callable` and `cudagraphs` + """ + assert compiled_graph.current_callable is not None + assert compiled_graph.cudagraph_info is not None + cached_info = compiled_graph.cudagraph_info + cudagraph_fail_reasons = cached_info.cudagraph_fail_reasons + is_inference = compiled_graph.fx_kwargs["is_inference"] + is_backward = compiled_graph.fx_kwargs["is_backward"] + + if not cudagraph_fail_reasons: + fx_kwargs = compiled_graph.fx_kwargs + static_input_idxs = fx_kwargs["static_input_idxs"] + + placeholders = cached_info.placeholders + stack_traces = cached_info.stack_traces + + prepare_cudagraph_post_compile( + compiled_graph, example_inputs, boxed_forward_device_index + ) + + from .compile_fx import cudagraphify + + current_callable = compiled_graph.current_callable + assert current_callable is not None + # Filter to only tensor constants (exclude opaque value type classes) + tensor_constants = { + k: v for k, v in constants.items() if isinstance(v, torch.Tensor) + } + compiled_graph.current_callable = cudagraphify( + current_callable, + static_input_idxs=static_input_idxs or (), + device_index=next(iter(compiled_graph.device_idxs)), + stack_traces=stack_traces, + is_backward=is_backward, + is_inference=is_inference, + constants=tuple(tensor_constants.values()), + placeholders=placeholders, + mutated_input_idxs=tuple(compiled_graph.mutated_input_idxs), + ) + + else: + BoxedBool.disable(cudagraphs) + maybe_handle_backward_generation(compiled_graph, boxed_forward_device_index) + + if "cuda" in compiled_graph.device_types: + # prefer better disable_cudagraphs_reason bc stack trace + # TODO: migrate all disable reasons to stack trace, refactor + if compiled_graph.disabled_cudagraphs_reason: + log_cudagraph_skip_and_bump_counter( + compiled_graph.disabled_cudagraphs_reason + ) + else: + log_cudagraph_skip_and_bump_counter( + f"skipping cudagraphs due to {cudagraph_fail_reasons}" + ) + + +def cudagraph_partition_post_compile( + 
example_inputs: Sequence[InputType], + compiled_graph: CompiledFxGraph, + cudagraphs: BoxedBool, + constants: dict[str, Union[torch.Tensor, type]], + boxed_forward_device_index: Optional[BoxedDeviceIndex], +) -> None: + """ + Cudagraphify each partition functions, which first prepares the necessary + metadata and then applies the cudagraphify function to each partition. + + Assuming all partition functions are cudagraphified and share the same order + as `compiled_graph.partition_maps`. See [Note: Graph Partition Map for CUDAGraph]. + """ + assert compiled_graph.cudagraph_info is not None + cudagraph_fail_reasons = compiled_graph.cudagraph_info.cudagraph_fail_reasons + + if ( + cudagraph_fail_reasons + or compiled_graph.partition_maps is None + or len(compiled_graph.partition_maps) == 0 + ): + # cudagraphify is not called if there are no partitions + BoxedBool.disable(cudagraphs) + maybe_handle_backward_generation(compiled_graph, boxed_forward_device_index) + return + + from .compile_fx import cudagraphify + + assert compiled_graph.current_callable is not None + assert compiled_graph.recursively_apply_fns is not None + is_inference = compiled_graph.fx_kwargs["is_inference"] + is_backward = compiled_graph.fx_kwargs["is_backward"] + static_input_idxs = OrderedSet(compiled_graph.fx_kwargs["static_input_idxs"] or ()) + mutated_input_idxs = compiled_graph.mutated_input_idxs + device_index = next(iter(compiled_graph.device_idxs)) + + # Filter to only tensor constants (exclude opaque value type classes) + tensor_constants = { + k: v for k, v in constants.items() if isinstance(v, torch.Tensor) + } + + graph_metadata = CudagraphMetadata( + compiled_graph.cudagraph_info.placeholders, + static_input_idxs, + mutated_input_idxs, + compiled_graph.cudagraph_info.stack_traces, + tensor_constants, + ) + + prepare_cudagraph_post_compile( + compiled_graph, example_inputs, boxed_forward_device_index + ) + + # cudagraphify each partition function, assuming every graph partition 
function + # is cudagraphable. Non-cudagraphable ops (e.g., cpu ops) are inlined into + # `call` function and not included in partition functions. + cudagraphify_fns = [] + for partition_map in compiled_graph.partition_maps: + partition_metadata = get_partition_cudagraph_metadata( + partition_map, + graph_metadata, + ) + + cudagraphify_fn = partial( + cudagraphify, + static_input_idxs=tuple(partition_metadata.static_input_idxs), + device_index=device_index, + stack_traces=partition_metadata.stack_traces, + is_backward=is_backward, + is_inference=is_inference, + constants=tuple(partition_metadata.constants.values()), + placeholders=partition_metadata.placeholders, + mutated_input_idxs=tuple(partition_metadata.mutated_input_idxs), + ) + cudagraphify_fns.append(cudagraphify_fn) + + compiled_graph.recursively_apply_fns(cudagraphify_fns) + + +def maybe_realign_inputs( + ran_cudagraphs: BoxedBool, + compiled_graph: CompiledFxGraph, + inputs_to_check: Sequence[int], + mutated_inputs_idxs: OrderedSet[int], +) -> None: + """ + Realigns input strides from inputs_to_check if + we didn't end up running cudagraphs. Mutates + `compiled_graph.current_callable` if cudagraphs + was run. Otherwise, does nothing. + """ + if not ran_cudagraphs: + assert compiled_graph.current_callable is not None + new_callable = align_inputs_from_check_idxs( + compiled_graph.current_callable, inputs_to_check, mutated_inputs_idxs + ) + if new_callable is not compiled_graph.current_callable: + compiled_graph.current_callable = new_callable + + +class CompiledFxGraphConstants: + """Wrapper class that unwraps constants from a compiled fx graph. This + version of the class only supports directly grabbing the saved constants off of + a CompiledFxGraph. + + With freezing, FxGraphCache doesn't store the constants of the input + GraphModule it gets from AOTAutograd. Instead, it saves just the **names** + of those constants, and grabs the constant values directly from the graph module + passed in at runtime. 
+ + Thing is, we don't always *have* the graph module available at runtime, hence + the existence of this class and its CompiledFxGraphConstantsWithGm counterpart. + + To support freezing, FXGraphCache gets passed a CompiledFxGraphConstantsWithGm during + post compile. Otherwise, CompiledFxGraphConstants supports the basic case of loading + the value of constants directly off of the original saved object. + """ + + def unwrap(self, g: CompiledFxGraph) -> dict[str, Union[torch.Tensor, type]]: + assert g.constants is not None + return {**g.constants, **g.opaque_value_type_classes} + + +class CompiledFxGraphConstantsWithGm(CompiledFxGraphConstants): + """ + This version of CompiledFxGraphConstants, instead of grabbing constants + directly saved on CompiledFxGraphs, will just grab their names. Then, it takes + a second GraphModule to grab the corresponding constant values out of. + + This is necessary for supporting freezing in FxGraphCache. + """ + + def __init__(self, gm: torch.fx.GraphModule) -> None: + self.gm = gm + + def unwrap(self, g: CompiledFxGraph) -> dict[str, Union[torch.Tensor, type]]: + frozen_params = { + name: getattr(self.gm, orig_name) + for name, orig_name in g.frozen_param_names.items() + } + constants = g.constants or {} + return {**constants, **frozen_params, **g.opaque_value_type_classes} + + +@dataclasses.dataclass +class CompiledFxGraph(OutputCode): + """ + Class holding a compiled FX graph. This is the object serialized on disk + to support FxGraph caching. 
+ """ + + current_callable: Optional[Callable[..., Any]] + recursively_apply_fns: Optional[Callable[..., Any]] + compiled_fn_runner: Optional[Any] + cache_key: str + source_code: str = dataclasses.field(repr=False) # Do not display source_code + runnable_graph_str: str = dataclasses.field(repr=False) # Do not display graph + inductor_post_grad_graph_str: str = dataclasses.field( + repr=False + ) # Do not display graph + cache_linemap: Optional[list[tuple[int, str]]] + device_types: OrderedSet[str] + device_idxs: OrderedSet[int] + mutated_inputs: OrderedSet[str] + mutated_input_idxs: OrderedSet[int] + constants: Optional[dict[str, torch.Tensor]] + frozen_param_names: dict[str, str] + torchbind_constants: dict[str, torch._C.ScriptObject | FakeScriptObject] + opaque_value_type_classes: dict[str, type] + output_strides: Optional[list[Optional[tuple[_StrideExprStr, ...]]]] + disabled_cudagraphs_reason: Optional[str] + metrics_deltas: metrics.CachedMetricsDeltas + counter_deltas: Counter[str] + # This is a string representation of an expression we serialize + # with the object so the guards can be evaluated in a different + # context in order to verify the validity of serving a cached + # fx graph. 
The expression must be generated by: + # ShapeEnv.produce_guards_expression() + guards_expr: Optional[str] + inductor_provenance_mapping_str: Optional[str] + inductor_provenance_stack_traces_str: Optional[str] + + cudagraph_info: Optional[CudagraphCachedInfo] + partition_maps: Optional[list[GraphPartitionMap]] + fx_kwargs: _CompileFxKwargs + inputs_to_check: Sequence[int] + + _boxed_call: Optional[bool] = None + _triton_bundle: Optional[TritonBundle] = None + _wrap_compiled_regions: bool = False + + def __init__( + self, + current_callable: Optional[Callable[..., Any]], + graph: GraphLowering, + gm: torch.fx.GraphModule, + output_strides: list[Optional[tuple[_StrideExprStr, ...]]], + disabled_cudagraphs_reason: Optional[str], + metrics_deltas: metrics.CachedMetricsDeltas, + counter_deltas: Counter[str], + cudagraphs: BoxedBool, + example_inputs: Sequence[InputType], + static_input_idxs: Sequence[int], + fx_kwargs: _CompileFxKwargs, + inputs_to_check: Sequence[int], + runnable_graph_str: str, + inductor_post_grad_graph_str: str, + compiled_fn_runner: Optional[Any] = None, + inductor_provenance_mapping_str: Optional[str] = None, + inductor_provenance_stack_traces_str: Optional[str] = None, + ) -> None: + self.current_callable = current_callable + self.compiled_fn_runner = compiled_fn_runner + self.recursively_apply_fns = ( + compiled_fn_runner.recursively_apply_fns + if compiled_fn_runner is not None + else None + ) + self.cache_key = graph.cache_key + if graph.cache_path: + with open(graph.cache_path) as f: + self.source_code = f.read() + self.runnable_graph_str = runnable_graph_str + self.inductor_post_grad_graph_str = inductor_post_grad_graph_str + self.inductor_provenance_mapping_str = inductor_provenance_mapping_str + self.inductor_provenance_stack_traces_str = inductor_provenance_stack_traces_str + self.cache_linemap = graph.cache_linemap + # TODO - ordered set + self.device_types = OrderedSet(graph.device_types) + self.device_idxs = 
OrderedSet(graph.device_idxs) + self.mutated_inputs = OrderedSet(graph.mutated_inputs) + self.mutated_input_idxs = OrderedSet(graph.mutated_input_idxs) + + # We store the constant attributes in the cache entry and re-attach them + # to the module created in PyCodeCache.load_by_key_path. In the case that + # the graph has frozen parameters, we save the mapping from the attribute + # names in the GraphLowering to the original name of the attribute in the + # GraphModule. When we create the module from the cache entry, we then + # look up the constants from the current GraphModule. This scheme allows + # us to support caching with freezing. + if not has_frozen_params(gm): + self.constants = graph.constants + self.frozen_param_names = {} + else: + self.constants = {} + self.frozen_param_names = {} + for k, v in graph.constants.items(): + if is_frozen_param(v): + self.frozen_param_names[k] = graph.allocated_constant_name[k] + else: + self.constants[k] = v + + self.torchbind_constants = graph.torchbind_constants + self.opaque_value_type_classes = graph.opaque_value_type_classes + self.output_strides = output_strides + self.disabled_cudagraphs_reason = disabled_cudagraphs_reason + self.metrics_deltas = metrics_deltas + self.counter_deltas = counter_deltas + self.guards_expr = None + self.cudagraph_info = None + self.partition_maps = graph.partition_maps + self.fx_kwargs = {} + self.inputs_to_check = () + + cudagraph_info = None + if cudagraphs: + # check cudagraph disabling reasons from inductor lowering + if self.disabled_cudagraphs_reason: + if "cuda" in self.device_types: + log_cudagraph_skip_and_bump_counter( + f"skipping cudagraphs due to {self.disabled_cudagraphs_reason}" + ) + else: + counters["inductor"]["cudagraph_skips"] += 1 + BoxedBool.disable(cudagraphs) + else: + complex_memory_overlap_inputs = any( + complex_memory_overlap(t) + for t in example_inputs + if isinstance(t, torch.Tensor) + ) + + if not config.triton.cudagraph_support_input_mutation: + # Skip 
supports for cudagraph-managed tensors + from torch._inductor.cudagraph_utils import ( + check_for_mutation_ignore_cuda_graph_managed_tensor, + ) + + has_mutation_str = ( + check_for_mutation_ignore_cuda_graph_managed_tensor( + gm, + self.mutated_inputs, + self.mutated_input_idxs, + static_input_idxs, + ) + ) + has_mutation = has_mutation_str is not None + + if has_mutation: + self.disabled_cudagraphs_reason = has_mutation_str + else: + # Check mutation later to support cudagraph-managed tensors + has_mutation = None + + cudagraph_tests = [ + (not has_mutation, "mutated inputs"), + (not complex_memory_overlap_inputs, "complex memory overlap"), + ( + all( + isinstance(t, (torch.Tensor, torch.SymInt, torch.Generator)) + for t in example_inputs + ), + "non-Tensor inputs", + ), + ] + output = output_node(gm) + # output args are tuple of first argument + assert len(output.args) == 1 + stack_traces = [ + (arg.stack_trace if isinstance(arg, torch.fx.node.Node) else None) + for arg in output.args[0] # type: ignore[union-attr] + ] + cudagraph_fail_reasons = [s for b, s in cudagraph_tests if not b] + placeholders = tuple(get_placeholder_info(gm.graph)) + cudagraph_info = CudagraphCachedInfo( + placeholders, stack_traces, cudagraph_fail_reasons + ) + + self.cudagraph_info = cudagraph_info + self.inputs_to_check = inputs_to_check + self.fx_kwargs = fx_kwargs + + # aot autograd needs to know to pass in inputs as a list + self._boxed_call = True + + # Store whether to wrap compiled regions in inductor_compiled_code HOP + # This is set at compile time to avoid runtime overhead + self._wrap_compiled_regions = config.wrap_inductor_compiled_regions + + def __del__(self) -> None: + if self.compiled_fn_runner is not None: + # For torch._inductor.config.graph_partition = True, + # self.compiled_fn_runner.partitions hold cudagraphified functions + # which prevents deallocation. 
When CompiledFxGraph is deleted, + # self.compiled_fn_runner will not be called in the future so we + # should also delete these partitions. + del self.compiled_fn_runner.partitions + + def __call__(self, inputs: Sequence[Any]) -> Any: + assert self.current_callable is not None + + if ( + torch._inductor.debug.RECORD_GRAPH_EXECUTION + and torch._inductor.debug.GRAPH_EXECUTION_ORDER is not None + ): + graph_id = self.fx_kwargs.get("graph_id") + compile_id = ( + torch._inductor.debug.GRAPH_COMPILE_IDS.get(graph_id) + if graph_id is not None + and torch._inductor.debug.GRAPH_COMPILE_IDS is not None + else None + ) + torch._inductor.debug.GRAPH_EXECUTION_ORDER.append( + { + "compile_id": compile_id, + } + ) + try: + # Checking the profiler directly is faster than nullcontext + if torch.autograd.profiler._is_profiler_enabled: + with record_function( + f"## Call CompiledFxGraph {self._fx_graph_cache_key} ##" + ): + return self.current_callable(inputs) + else: + return self.current_callable(inputs) + finally: + get_runtime_metrics_context().finish() + AutotuneCacheBundler.end_compile() + + def post_compile( + self, + example_inputs: Sequence[InputType], + constants: CompiledFxGraphConstants, + graph_kwargs: _CompileFxKwargs, + ) -> None: + """ + Run a set of post processing steps after loading from the cache. These involve: + - Setting the tracing context output strides + - Running cudagraphs if enabled + - Realigning inputs + + This runs whether or not we have a cache hit, and always runs directly after we get a CompiledFxGraph. + The results of this function are *not* saved in the cache itself. 
+ """ + if config.graph_partition and _unstable_customized_partition_wrapper.wrapper: + # Mechanically apply user-specified cudagraph wrappers without modification + assert self.recursively_apply_fns is not None + assert self.compiled_fn_runner is not None + num_partitions = len(self.compiled_fn_runner.partitions) + wrapper_metadatas = [ + CUDAGraphWrapperMetadata(num_partitions, i) + for i in range(num_partitions) + ] + customized_wrapper = _unstable_customized_partition_wrapper.wrapper + customized_wrappers_with_metadata = [ + lambda f, m=metadata: customized_wrapper(f, m) + for metadata in wrapper_metadatas + ] + self.recursively_apply_fns(customized_wrappers_with_metadata) + return + + set_tracing_context_output_strides(example_inputs, self) + assert graph_kwargs["cudagraphs"] is not None + assert graph_kwargs["is_backward"] is not None + is_backward = graph_kwargs["is_backward"] + cudagraphs: BoxedBool = graph_kwargs["cudagraphs"] + if cudagraphs: + # It's possible that cudagraphs is enabled, but was disabled + # during a previous compilation we're loading from the cache. + # If so, we need to disable it on this new process too. + if self.disabled_cudagraphs_reason: + if "cuda" in self.device_types: + log_cudagraph_skip_and_bump_counter( + f"skipping cudagraphs due to {self.disabled_cudagraphs_reason}" + ) + else: + counters["inductor"]["cudagraph_skips"] += 1 + BoxedBool.disable(cudagraphs) + else: + if is_backward: + assert "boxed_forward_device_index" in graph_kwargs + boxed_forward_device_index = graph_kwargs[ + "boxed_forward_device_index" + ] + else: + # On the forward we don't know whether or not + # boxed_forward_device_index is set yet + boxed_forward_device_index = graph_kwargs.get( + "boxed_forward_device_index", None + ) + + if config.graph_partition: + # with graph_partition=True, we skip some cudagraph checks if it's supported + # with partition. So we have to use cudagraph_partition_post_compile. 
+ cudagraph_partition_post_compile( + example_inputs, + self, + cudagraphs, + constants.unwrap(self), + boxed_forward_device_index, + ) + else: + cudagraph_post_compile( + example_inputs, + self, + cudagraphs, + constants.unwrap(self), + boxed_forward_device_index, + ) + inputs_to_check = self.inputs_to_check + # cudagraphs could have been disabled from the earlier conditions + # so we still need to realign inputs if that happens + maybe_realign_inputs( + cudagraphs, + self, + inputs_to_check, + self.mutated_input_idxs, + ) + + # Apply inductor_compiled_code HOP wrapper if configured + # This is done in post_compile to ensure it works with cached artifacts + if self._wrap_compiled_regions and self.current_callable is not None: + original_callable = self.current_callable + + def wrapped_callable(inputs): + if is_in_torch_dispatch_mode(): + return inductor_compiled_code(original_callable, inputs) + else: + return original_callable(inputs) + + self.current_callable = wrapped_callable + + def set_triton_bundle(self, triton_bundle: Any) -> None: + self._triton_bundle = triton_bundle + + def prepare_for_serialization(self) -> None: + # We can't really serialize callables that may be C++/Triton/etc., + # so we serialize their PyCodeCache disk cache location instead. + # TODO: This could be better if we're ever able to serialize compiled + # models to disk. + self.current_callable = None + self.recursively_apply_fns = None + self.compiled_fn_runner = None + + def write_to_disk(self) -> str: + from torch._dynamo.utils import counters + from torch._inductor.codecache import get_path, write_atomic + + # See _save_graph(); we don't store the callable in the cache entry so + # recreate it here from the PyCodeCache disk cache. 
+ artifact_path = get_path(self.cache_key, "py")[2] + code = self.source_code + if not os.path.exists(artifact_path): + counters["inductor"]["fxgraph_lookup_write_file"] += 1 + write_atomic(artifact_path, code, make_dirs=True) + return artifact_path + + def after_deserialization(self, constants: CompiledFxGraphConstants) -> str: + from torch._dynamo.utils import dynamo_timed + from torch._inductor.codecache import PyCodeCache + + artifact_path = self.write_to_disk() + + try: + with dynamo_timed( + "PyCodeCache.load_by_key_path", + log_pt2_compile_event=True, + ): + code_cache = PyCodeCache.load_by_key_path( + self.cache_key, + artifact_path, + self.cache_linemap, + constants.unwrap(self), + ) + self.current_callable = code_cache.call + self.recursively_apply_fns = getattr( + code_cache, "recursively_apply_fns", None + ) + self.compiled_fn_runner = getattr(code_cache, "runner", None) + except OSError: + log.error("Failed to load artifact: %s", artifact_path) + raise + + return artifact_path + + +@dataclasses.dataclass +class CompiledAOTI(OutputCode): + """ + Class holding an AOTInductor compiled so. 
+ """ + + filename: Union[str, list[Union[str, Weights]], torch.fx.GraphModule] + device_type: str + current_callable: Optional[Callable[..., Any]] = None + _cached_files: dict[str, bytes] = dataclasses.field(default_factory=dict) + + def __post_init__(self): + if not config.aot_inductor.link_libtorch: + return + + if ( + torch._inductor.cpp_builder._IS_MACOS + or torch._inductor.cpp_builder._IS_WINDOWS + ): + return + + if config.aot_inductor.cross_target_platform == "windows": + return + + if config.aot_inductor.package_cpp_only: + return + + if not config.enable_autograd_for_aot: + return + + if isinstance(self.filename, list): + current_callable = next( + fn for fn in self.filename if isinstance(fn, str) and fn.endswith(".so") + ) + else: + current_callable = self.filename + + if isinstance(current_callable, torch.fx.GraphModule): + self.current_callable = current_callable + return + + if self.device_type.startswith("cuda"): + current_callable = ( + torch._C._aoti.AOTIModelContainerRunnerCuda( # type: ignore[call-arg] + current_callable, + 1, + self.device_type, + "", + True, + ).run # type: ignore[attr-defined] + ) # type: ignore[attr-defined] + elif self.device_type == "cpu": + current_callable = ( + torch._C._aoti.AOTIModelContainerRunnerCpu( # type: ignore[call-arg] + current_callable, 1 + ).run # type: ignore[attr-defined] + ) # type: ignore[attr-defined] + else: + raise RuntimeError(f"unsupported device type {self.device_type}") + self.current_callable = current_callable + self._boxed_call = True + for file in self._cached_files: + if not os.path.exists(file): + with open(file, "wb") as f: + f.write(self._cached_files[file]) + + def __call__(self, inputs: Sequence[Any]) -> Any: + if self.current_callable is None: + raise RuntimeError("AOTInductor compiled so is not loaded") + return self.current_callable(inputs) + + def prepare_for_serialization(self) -> None: + self.current_callable = None + self._cached_files = {} + filenames: list[str] = [] + if 
isinstance(self.filename, list): + filenames = self.filename # type: ignore[assignment] + elif isinstance(self.filename, str): + filenames = [self.filename] + for name in filenames: + with open(name, "rb") as f: + self._cached_files[name] = f.read() + + def __getstate__(self): + state = self.__dict__.copy() + state["current_callable"] = None + return state + + def post_compile( + self, + example_inputs: Sequence[InputType], + constants: CompiledFxGraphConstants, + graph_kwargs: _CompileFxKwargs, + ) -> None: + if self.current_callable is None: + self.__post_init__() + + def set_triton_bundle(self, triton_bundle: Any) -> None: + pass + + +@dataclasses.dataclass +class MockFXGraphCacheOutput(OutputCode): + gm: Any = None + + def __post_init__(self) -> None: + self._boxed_call = True + + def post_compile( + self, + example_inputs: Sequence[InputType], + constants: CompiledFxGraphConstants, + graph_kwargs: _CompileFxKwargs, + ) -> None: + pass + + def __call__(self, inputs: Sequence[Any]) -> Any: + return self.gm(inputs) + + def set_triton_bundle(self, triton_bundle: Any) -> None: + pass + + +@dataclasses.dataclass +class RegionalOutputCode(OutputCode): + """ + OutputCode for regional inductor compilation results. + + Regional inductor returns a torch.fx.GraphModule that contains both + compiled regions (via standalone_compile) and eager regions. This needs + special serialization using GraphPickler instead of standard pickle. + + The serialization strategy stores the GraphModule as bytes using + GraphPickler.dumps(), which handles FakeTensors, AOTCompiledArtifacts, + and other special objects that standard pickle cannot handle. 
+ """ + + # The serialized graph module as bytes (using GraphPickler) + _serialized_graph_module: Optional[bytes] = dataclasses.field( + default=None, init=False + ) + + # The actual graph module (cleared during serialization) + _graph_module: Optional[torch.fx.GraphModule] = dataclasses.field( + default=None, init=False + ) + + def __init__(self, graph_module: torch.fx.GraphModule): + """ + Args: + graph_module: The torch.fx.GraphModule returned by regional_inductor + """ + super().__init__() + self._graph_module = graph_module + self._serialized_graph_module = None + + def __call__(self, inputs: Sequence[Any]) -> Any: + """Execute the regional compiled graph.""" + if self._graph_module is None: + raise RuntimeError( + "RegionalOutputCode has no graph module loaded. " + "Did you forget to call post_compile()?" + ) + return self._graph_module(*inputs) + + def post_compile( + self, + example_inputs: Sequence[InputType], + constants: CompiledFxGraphConstants, + graph_kwargs: _CompileFxKwargs, + ) -> None: + """ + Post-compile processing for regional inductor. + + This deserializes the GraphModule from bytes using GraphPickler, + extracting the fake_mode from example_inputs. + """ + if self._graph_module is not None: + return + assert self._serialized_graph_module is not None + # Get fake mode from example inputs + from torch._guards import detect_fake_mode + + fake_mode = detect_fake_mode(example_inputs) + if fake_mode is None: + raise RuntimeError( + "Could not detect fake mode from example inputs. " + "Regional inductor requires fake mode for deserialization." 
+ ) + + # Deserialize the graph module + from torch.fx._graph_pickler import GraphPickler + + gm = GraphPickler.loads(self._serialized_graph_module, fake_mode) + assert isinstance(gm, torch.fx.GraphModule) + gm.recompile() + self._graph_module = gm + + def set_triton_bundle(self, triton_bundle: Any) -> None: + """Regional inductor doesn't use triton bundles directly.""" + + def prepare_for_serialization(self) -> None: + """ + Prepare for serialization by converting the GraphModule to bytes. + + This uses GraphPickler to serialize the graph module since it contains + special objects like FakeTensors and AOTCompiledArtifacts that need + custom pickling. + """ + if self._graph_module is not None: + from torch.fx._graph_pickler import GraphPickler + + self._serialized_graph_module = GraphPickler.dumps(self._graph_module) + # Clear the graph module to avoid pickling it with standard pickle + self._graph_module = None diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/pattern_matcher.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/pattern_matcher.py new file mode 100644 index 0000000000000000000000000000000000000000..6c2c98a5609d16022a372b445311e29b54a9a425 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/pattern_matcher.py @@ -0,0 +1,2368 @@ +""" +# Inductor Pattern Matcher + +The pattern matcher enables search/replace within an FX graph. + +The main entrypoint to the pattern matcher is register_replacement(). Given a +search function and a replacement function this will register a replacement with +a pass (such as torch._inductor.fx_passes.joint_graph.patterns). + +Internally the pattern matcher represents patterns as a graph (a DAG). 
Creating +new patterns manually as a graph is cumbersome and error-prone so the standard +way to create patterns (using register_replacement()) is to provide a search +function and a replacement function which is traced and converted into a graph. + +Because the search functions are built somewhat generic (they tend to ignore +tensor sizes, for example) register_replacement() allows you to specify an +`extra_check` function which performs additional checks to verify that the +matched pattern fully matches before returning it. + +## Precompiled Patterns + +New patterns are added using register_replacement(). Patterns added in this way +can have a compile-time overhead because they need to be traced before +use. Patterns can be precompiled and added using gen_register_replacement() +instead. To do this you call gen_register_replacement() instead of +register_replacement(). The arguments are the same except for an additional +unique name which is used as a lookup key. + +## Internals + +The match DAG is represented by a graph of `PatternExpr` nodes. Each PatternExpr +implements a `_match` method which returns either a `Match` object for a +successful match or a `FailedMatch` object for a failure to match. 
+""" + +from __future__ import annotations + +import contextlib +import dataclasses +import functools +import importlib +import inspect +import itertools +import logging +import operator +import os +import re +import textwrap +import typing +from abc import ABC, abstractmethod +from collections import defaultdict +from collections.abc import Callable, Collection, Generator, Iterable, Mapping, Sequence +from pathlib import Path +from typing import Any, NoReturn, Optional, Protocol, TypeVar, Union +from typing_extensions import Self, TypeIs + +import torch +import torch._guards +import torch.fx +import torch.utils._pytree as pytree +from torch._dispatch.python import enable_python_dispatcher +from torch._dynamo.utils import counters +from torch._prims_common import is_integer_dtype +from torch._subclasses.fake_tensor import unset_fake_temporarily +from torch.fx.experimental.proxy_tensor import make_fx +from torch.fx.experimental.symbolic_shapes import guard_or_false, statically_known_true +from torch.fx.graph_module import _get_attr +from torch.fx.immutable_collections import immutable_dict, immutable_list +from torch.fx.passes.graph_transform_observer import GraphTransformObserver +from torch.fx.traceback import preserve_node_meta +from torch.utils._ordered_set import OrderedSet + +from .._functorch import config as functorch_config +from .._functorch.aot_autograd import aot_function, make_boxed_func +from .._functorch.partitioners import default_partition +from .._subclasses import FakeTensor, FakeTensorMode +from ..fx import Transformer +from . 
def _transfer_meta(
    new_meta: dict[str, Any], old_node: torch.fx.Node, pass_name: str = ""
) -> None:
    """
    Carry metadata from a replaced node over to the meta dict of its replacement.

    Only the keys listed in ``torch.fx.proxy._COPY_META_FIELDS`` are copied,
    plus ``"stack_trace"`` when the old node has one. "val" and "tensor_meta"
    are intentionally left alone: that information is too specific and is
    unlikely to stay accurate once pattern matching has rewritten the graph.

    Args:
        new_meta: meta dict of the replacement node; updated in place.
        old_node: the node that is being replaced.
        pass_name: name of the pass, recorded in the provenance entry.
    """
    from torch.fx.traceback import NodeSource, NodeSourceAction

    track_provenance = config.trace.provenance_tracking_level == 1

    if track_provenance:
        # "from_node" is handled specially so the new node records that it
        # comes from old_node. The existing list must be read *before* the
        # bulk copy below, which may overwrite new_meta["from_node"].
        provenance = new_meta.get("from_node", []).copy()
        provenance.append(NodeSource(old_node, pass_name, NodeSourceAction.REPLACE))

    for key, value in old_node.meta.items():
        if key in torch.fx.proxy._COPY_META_FIELDS:
            new_meta[key] = value

    if track_provenance:
        new_meta["from_node"] = provenance

    if "stack_trace" in old_node.meta:
        new_meta["stack_trace"] = old_node.meta["stack_trace"]
self.nodes.extend(other.nodes) + self.kwargs.update(other.kwargs) + self.targets.update(other.targets) + + def bundle(self) -> Match: + # Wrap args in an extra list + self.args = [tuple(self.args)] if self.args else [] + return self + + def __repr__(self) -> str: + return f"Match(..., {self.args}, {self.kwargs})" + + def erase_nodes(self) -> None: + graph = self.graph + for n in reversed(self.nodes): + if not n._erased and not n.users: + graph.erase_node(n) + + def output_nodes(self) -> list[Optional[torch.fx.Node]]: + return [ + (self.ctx.pattern_to_node[p] if p is not None else None) + for p in self.ctx.outputs + ] + + def output_node(self) -> torch.fx.Node: + return next(p for p in self.output_nodes() if p) + + def replace_with_graph( + self, replacement_graph: torch.fx.Graph, args: Sequence[Any] + ) -> None: + ReplacementPatternEntry.replace_with_graph( + self, self.ctx.graph, replacement_graph, args + ) + + def replace_by_example( + self, + replacement_fn: ReplaceFn, + args: Sequence[Any], + trace_fn: Optional[TraceFn] = None, + run_functional_passes: bool = True, + ) -> None: + """Replace with a graph generated by tracing the replacement_fn. + + Args: + run_functional_passes (bool). If we should run passes that + assume functional IR (like DCE, remove_noop_ops), on the + replacement graph. 
+ + """ + from torch._inductor.virtualized import NullHandler, V + + context = ( + V.fake_mode + if (not isinstance(V.fake_mode, NullHandler) or (V.fake_mode is None)) + else contextlib.nullcontext() + ) + + def should_propagate_eager_input_vals(nodes: list[torch.fx.Node]) -> bool: + if len(nodes) != 1: + return False + node = nodes[0] + if "eager_input_vals" not in node.meta: + return False + return node.target in OrderedSet( + [ + torch.ops.higher_order.triton_kernel_wrapper_functional, + torch.ops.higher_order.auto_functionalized, + torch.ops.higher_order.auto_functionalized_v2, + ] + ) + + # pyrefly: ignore [bad-context-manager] + with context: + if trace_fn is None: + trace_fn = functools.partial( + fwd_only, run_functional_passes=run_functional_passes + ) + + if should_propagate_eager_input_vals(self.nodes): + # Our strategy is: + # 1) trace out the graph with eager_input_vals (which have accurate eager-mode metadata) + # 2) trace out the graph with vals (which have the accurate Inductor metadata) + # 3) Propagate the eager_input_vals from the first graph to the second. + # 4) Use the second graph as the replacement graph. 
class FailedMatch(RuntimeError):
    """
    Represents an unsuccessful match.

    A `FailedMatch` is returned (or raised) to signal that a pattern did not
    match. It is always falsy, so a match result can be tested directly for
    success. The message is stored as a format string plus its arguments and
    is only rendered when `str()` is called on the object.
    """

    format_string: str

    def __init__(self, format_string: str, *args: Any, **kwargs: Any) -> None:
        # Error messages are built lazily rather than eagerly, because eager
        # construction can significantly worsen compile times. An over-long
        # format string is taken as a sign that the caller pre-rendered it.
        if len(format_string) > 200:
            raise RuntimeError(
                f"Format string too long - use lazy construction of strings instead. Format string is\n {format_string}"
            )
        self.format_string = format_string
        self.args = args
        self.kwargs = kwargs

    def __str__(self) -> str:
        template = self.format_string
        return template.format(*self.args, **self.kwargs)

    def __bool__(self) -> bool:
        # A failed match must never be mistaken for a successful one.
        return False
dict[PatternExpr, torch.fx.Node]: + return { + pattern: node + for pattern, node in self.pattern_to_node.items() + if pattern.has_multiple_users() and node is not None + } + + +class PatternExpr(ABC): + """ + Base class for types of patterns. + """ + + @abstractmethod + def _match(self, node: torch.fx.Node, ctx: MatchContext) -> MatchResult: ... + + def match(self, node: torch.fx.Node) -> MatchResult: + try: + return MatchContext([self], graph=node.graph).match(self, node) + except FailedMatch as e: + return e + + def has_multiple_users(self) -> bool: + return False + + def __repr__(self) -> str: + return self.__class__.__name__ + "()" + + def find_anchor_nodes( + self, ctx: MatchContext, searched: OrderedSet[torch.fx.Node] + ) -> Generator[Optional[torch.fx.Node], None, None]: + if self in ctx.pattern_to_node: + yield ctx.pattern_to_node[self] + + def pattern_eq(self, other: Any) -> bool: + """ + Compare two `PatternExpr`s and return true if they are the + same. Note this is NOT matching a pattern - it is comparing the pattern + structures (for debugging). + """ + return isinstance(other, self.__class__) + + +class Arg(PatternExpr): + """ + Capture an arg which will become an input to the handler. Args are + passed in depth first order. + """ + + def _match(self, node: NodeOrConstant, ctx: MatchContext) -> MatchResult: + return Match(ctx, self, args=[node]) # matches anything + + +class Ignored(PatternExpr): + """ + Match an arg, but don't pass it to handler + """ + + def _match(self, node: NodeOrConstant, ctx: MatchContext) -> MatchResult: + return Match(ctx, self) # matches anything + + def __repr__(self) -> str: + return "*" + + def pretty_print(self, pp: PatternPrettyPrinter) -> str: + return "Ignored()" + + +class KeywordArg(PatternExpr): + """ + Capture a kwarg which will become an input to the handler. 
class _TargetExpr(PatternExpr):
    """
    Base class for filtering match by node.target

    Attributes:
        fns: accepted targets, in registration order. An
            ``OpOverloadPacket`` entry is followed by each of its overloads.
        fns_set: the same targets as an ``OrderedSet`` for membership tests.
        users: required number of users of the matched node, or ``MULTIPLE``
            to accept any number.
    """

    fns: list[FnsType]
    fns_set: OrderedSet[FnsType]
    users: Union[Multiple, int]

    def __init__(
        self, fns: Union[FnsType, Sequence[FnsType]], users: Union[Multiple, int] = 1
    ) -> None:
        super().__init__()
        # A single callable or string is wrapped in a list; sequences are copied.
        fns = [fns] if callable(fns) or isinstance(fns, str) else list(fns)
        # The list is deliberately extended while it is being iterated: every
        # overload of a packet is appended so it can be matched as well.
        for fn in fns:
            if isinstance(fn, torch._ops.OpOverloadPacket):
                fns.extend(getattr(fn, overload) for overload in fn.overloads())  # noqa: B909

        self.fns = fns
        self.fns_set = OrderedSet(fns)
        self.users = users

    @property
    @abstractmethod
    def op(self) -> str: ...

    def fns_repr(self) -> str:
        """
        Return a short printable name for the targets of this pattern.

        Only the first target is named; when there are several the result is
        ``"[<first>, ...]"``. A single target that is an attribute of
        ``torch`` or ``operator`` gets the matching module prefix.
        """
        first_repr = self.fns[0]
        if not isinstance(first_repr, str):
            first_repr = first_repr.__name__

        if len(self.fns) > 1:
            return f"[{first_repr}, ...]"
        elif self.fns[0] is getattr(torch, first_repr, None):
            return f"torch.{first_repr}"
        elif self.fns[0] is getattr(operator, first_repr, None):
            return f"operator.{first_repr}"
        elif isinstance(self.fns[0], torch._ops.OpOverload):
            return str(self.fns[0])
        else:
            return first_repr

    def __repr__(self) -> str:
        if self.users is MULTIPLE:
            comma_users = ", MULTIPLE"
        elif self.users != 1:
            # No closing parenthesis here: the return statement below adds
            # the only one, keeping the repr balanced.
            comma_users = f", {self.users}"
        else:
            comma_users = ""
        return f"{self.__class__.__name__}({self.fns_repr()}{comma_users})"

    def has_multiple_users(self) -> bool:
        return isinstance(self.users, Multiple) or self.users > 1

    def find_anchor_nodes(
        self, ctx: MatchContext, searched: OrderedSet[torch.fx.Node]
    ) -> Generator[Optional[torch.fx.Node], None, None]:
        raise NotImplementedError

    def _match_fns(self, node: torch.fx.Node) -> bool:
        """True if `node` is an FX node with this pattern's op and an accepted target."""
        return (
            isinstance(node, torch.fx.Node)
            and node.op == self.op
            and extract_target(node) in self.fns_set
        )

    def _match_users(self, node: torch.fx.Node, ctx: MatchContext) -> bool:
        """
        True if the user count of `node` is acceptable.

        The count is not checked when this pattern is one of the context's
        outputs or when ``users`` is ``MULTIPLE``.
        """
        return (
            self in ctx.outputs
            or self.users is MULTIPLE
            or len(node.users) == self.users
        )

    def pattern_eq(self, other: Any) -> bool:
        other = typing.cast(Self, other)  # super makes sure this is true
        return (
            super().pattern_eq(other)
            and self.op == other.op
            and self.fns == other.fns
            and self.users == other.users
        )
class _TargetArgsExpr(_TargetExpr):
    """
    Base class for filtering match by node.{target,args,kwargs}

    The pattern's positional and keyword arguments may be nested
    `PatternExpr`s (matched recursively) or plain constants (compared for
    equality against the node's arguments).
    """

    def __init__(
        self,
        fns: Union[torch.fx.node.Target, str, Sequence[Any]],
        *args: Any,
        _users: Union[int, Multiple] = 1,
        **kwargs: Any,
    ) -> None:
        super().__init__(fns, _users)
        self.args = tuple(args)
        self.kwargs = dict(kwargs)
        # Pick the flattening strategy once: the pytree-based one is only
        # needed when some argument is itself a dict/list/tuple container.
        if any(
            isinstance(x, (dict, list, tuple))
            for x in itertools.chain(args, kwargs.values())
        ):
            self.flatten = self.pytree_flatten
        else:
            self.flatten = self.simple_flatten
        # Cached (flat values, spec) of the pattern's own arguments.
        self.flat_args_kwargs = self.flatten(self.args, self.kwargs)

    @staticmethod
    def simple_flatten(
        args: Sequence[Any], kwargs: Mapping[Any, Any]
    ) -> tuple[Sequence[Any], Union[_SimpleSpec, pytree.TreeSpec]]:
        """
        Flatten without descending into containers.

        Returns the positional values followed by the kwarg values, and a
        spec made of the positional count followed by the kwarg names.
        """
        values = (*args, *kwargs.values())
        spec = (len(args), *kwargs.keys())
        return values, spec

    @staticmethod
    def pytree_flatten(
        args: Sequence[Any], kwargs: Mapping[Any, Any]
    ) -> tuple[Sequence[Any], Union[_SimpleSpec, pytree.TreeSpec]]:
        """
        Flatten `(args, kwargs)` with pytree after normalizing container types.

        Lists and FX immutable lists become tuples, and FX immutable dicts
        become plain dicts, before the tree is flattened.
        """
        type_mapping: dict[type, type] = {
            immutable_list: tuple,
            list: tuple,
            immutable_dict: dict,
        }

        def convert_type(x: Any) -> Any:
            cls = type(x)
            convert_fn = type_mapping.get(cls)
            if convert_fn is not None:
                # Convert this container, then recurse into its children.
                return pytree.tree_map(
                    convert_type,
                    convert_fn(x),
                    is_leaf=lambda x: type(x) in type_mapping,
                )
            return x

        normalized_args_tree = pytree.tree_map(
            convert_type,
            (args, kwargs),
            is_leaf=lambda x: type(x) in type_mapping,
        )
        flat, spec = pytree.tree_flatten(normalized_args_tree)
        return flat, spec

    def __repr__(self) -> str:
        args = [
            self.fns_repr(),
            *map(repr, self.args),
            *[f"{k}={v}" for k, v in self.kwargs.items()],
        ]
        if self.users is MULTIPLE:
            args.append("_users=MULTIPLE")
        elif self.users != 1:
            args.append(f"_users={self.users}")
        return f"{self.__class__.__name__}({', '.join(args)})"

    def pretty_print(self, pp: PatternPrettyPrinter) -> str:
        """Render this pattern as a constructor call, delegating nested arguments to `pp`."""
        args = [
            self.fns_repr(),
            *(pp.pretty_print(x) for x in self.args),
            *[f"{k}={pp.pretty_print(v)}" for k, v in self.kwargs.items()],
        ]
        if self.users is MULTIPLE:
            args.append("_users=MULTIPLE")
        elif self.users != 1:
            args.append(f"_users={self.users}")

        joiner_str = ", "
        return f"{self.__class__.__name__}({joiner_str.join(args)})"

    def _match(self, node: torch.fx.Node, ctx: MatchContext) -> MatchResult:
        """
        Match `node` by target, user count, and then argument by argument.

        Returns a `Match` that accumulates the child matches and records
        `node`, or a `FailedMatch` describing the first mismatch.
        """
        if not self._match_fns(node) or len(node.args) != len(self.args):
            return FailedMatch("function_mismatch: node={}, pattern={}", node, self)

        if not self._match_users(node, ctx):
            return FailedMatch("multiple_users {}", self)

        _args = node.args
        _kwargs = node.kwargs
        if len(_kwargs) < len(self.kwargs):
            # The node carries fewer kwargs than the pattern expects, so ask
            # normalize_function for a normalized (args, kwargs) pair first.
            from torch.fx.operator_schemas import normalize_function

            assert callable(node.target)
            normalized_args_and_kwargs = normalize_function(
                node.target, node.args, node.kwargs
            )

            if normalized_args_and_kwargs is None:
                return FailedMatch("function_mismatch: node={}, pattern={}", node, self)
            else:
                _args, _kwargs = normalized_args_and_kwargs
                if len(_args) == len(self.args) and len(_kwargs) >= len(self.kwargs):
                    # Keep only the kwargs the pattern names.
                    _kwargs = {i: _kwargs[i] for i in _kwargs if i in self.kwargs}
                else:
                    return FailedMatch(
                        "function_mismatch: node={}, pattern={}", node, self
                    )
        else:
            # Keep only the kwargs the pattern names.
            _kwargs = {i: _kwargs[i] for i in _kwargs if i in self.kwargs}

        node_items, node_spec = self.flatten(_args, _kwargs)
        self_items, self_spec = self.flat_args_kwargs
        if node_spec != self_spec:
            return FailedMatch("args_structure {} {}", node_spec, self_spec)
        assert len(node_items) == len(self_items)

        m = Match(ctx, self)
        for pattern, child_node in zip(self_items, node_items):
            if isinstance(pattern, PatternExpr):
                child_match = ctx.match(pattern, child_node)
                if not is_match(child_match):
                    return child_match
                m.extend(child_match)
            elif isinstance(child_node, torch.fx.Node) or child_node != pattern:
                # A constant in the pattern never matches a graph node, and
                # must compare equal to a constant argument.
                return FailedMatch(
                    "constant_args: {} {!r}!={pattern!r}",
                    node,
                    child_node,
                    pattern=pattern,
                )
        m.nodes.append(node)
        m.targets[self] = node.target
        return m

    def find_anchor_nodes(
        self, ctx: MatchContext, searched: OrderedSet[torch.fx.Node]
    ) -> Generator[Optional[torch.fx.Node], None, None]:
        """
        This is used when we are matching a pattern with multiple outputs.
        There is a partial match (stored in ctx) and we want to walk
        this pattern to find a connection to an already-matched node.

        Yields candidate nodes that `self._match` might like.
        """
        if self in ctx.pattern_to_node:
            yield ctx.pattern_to_node[self]
            return

        for pattern in self.flat_args_kwargs[0]:
            if isinstance(pattern, PatternExpr):
                for other_node in pattern.find_anchor_nodes(ctx, searched):
                    if not isinstance(other_node, torch.fx.Node):
                        continue
                    # Users of an anchored argument are candidates for this
                    # pattern; `searched` prevents yielding a node twice.
                    for node in other_node.users:
                        if node not in searched:
                            if self._match_fns(node):
                                yield node
                                searched.add(node)

    def pattern_eq(self, other: Any) -> bool:
        other = typing.cast(Self, other)  # super makes sure this is true
        return (
            super().pattern_eq(other)
            and self.flat_args_kwargs[1] == other.flat_args_kwargs[1]
            and all(
                a.pattern_eq(b) if isinstance(a, PatternExpr) else a == b
                for a, b in zip(self.flat_args_kwargs[0], other.flat_args_kwargs[0])
            )
        )
class ListOf(PatternExpr):
    """
    Matches a repeated pattern

    The inner pattern is matched against each element of a list/tuple
    argument. With `partial=True`, elements that fail to match are skipped
    as long as at least one element matches; otherwise the first failing
    element fails the whole match.
    """

    def __init__(self, pattern: PatternExpr, partial: bool = False) -> None:
        super().__init__()
        assert isinstance(pattern, PatternExpr)
        self.pattern = pattern
        self.partial = partial

    def __repr__(self) -> str:
        cls_name = self.__class__.__name__
        return f"{cls_name}({self.pattern})"

    def _match(self, node: list[torch.fx.Node], ctx: MatchContext) -> MatchResult:  # type: ignore[override]
        if not isinstance(node, (list, tuple)) or len(node) == 0:
            return FailedMatch("non_list")

        combined = Match(ctx, self)
        # Patterns with multiple users are carried from one element to the
        # next so that the same nodes are not revisited.
        shared_patterns = ctx.filter_multi_user_patterns()
        any_matched = False
        for index, element in enumerate(node):
            element_ctx = MatchContext(
                ctx.outputs, shared_patterns, graph=element.graph
            )
            element_match = element_ctx.match(self.pattern, element)
            shared_patterns = element_ctx.filter_multi_user_patterns()
            if is_match(element_match):
                any_matched = True
                combined.extend(element_match.bundle())
            elif not self.partial:
                return FailedMatch("list[{}]: {}", index, element_match)

        if any_matched:
            return combined.bundle()
        return FailedMatch("list: no_match")

    def pattern_eq(self, other: Any) -> bool:
        other = typing.cast(Self, other)  # super makes sure this is true
        if not super().pattern_eq(other):
            return False
        if not self.pattern.pattern_eq(other.pattern):
            return False
        return self.partial == other.partial
list[Optional[PatternExpr]] + + def __init__(self, outputs: Sequence[Optional[PatternExpr]]) -> None: + super().__init__() + assert isinstance(outputs[0], _TargetExpr) + assert all(x is None or isinstance(x, PatternExpr) for x in outputs), outputs + self.outputs = list(outputs) + self.op = outputs[0].op + + @property + def fns(self) -> Union[Callable[..., Any], str, Sequence[Any]]: + # This cast is checked above in __init__() + output = typing.cast(_TargetExpr, self.outputs[0]) + return output.fns + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.outputs})" + + def pretty_print(self, pp: PatternPrettyPrinter) -> str: + args = [pp.pretty_print(x) for x in self.outputs] + joiner_str = f",\n{' '}" + str_out = f"{self.__class__.__name__}([{joiner_str.join(args)}" + str_out = f"{str_out}\n])" + return str_out + + def _match(self, node: torch.fx.Node, ctx: MatchContext) -> MatchResult: + output = typing.cast(_TargetExpr, self.outputs[0]) + m = ctx.match(output, node) + if not is_match(m): + return m + + for pattern in self.outputs[1:]: + if pattern is None: + continue + child_match = self._match_from_anchors(pattern, ctx) + if not is_match(child_match): + return child_match + m.extend(child_match) + + return m + + def _match_from_anchors( + self, pattern: PatternExpr, ctx: MatchContext + ) -> MatchResult: + prior = dict(ctx.pattern_to_node) + m: MatchResult = FailedMatch("no anchor found") + for node in pattern.find_anchor_nodes(ctx, OrderedSet()): + m = ctx.match(pattern, node) + if is_match(m): + return m + # revert any partial matches + ctx.pattern_to_node = dict(prior) + return m + + def match(self, node: torch.fx.Node) -> MatchResult: + try: + return MatchContext(self.outputs, graph=node.graph).match(self, node) + except FailedMatch as e: + return e + + def pattern_eq(self, other: Any) -> bool: + other = typing.cast(Self, other) # super makes sure this is true + return ( + super().pattern_eq(other) + and len(self.outputs) == 
class PatternPrettyPrinter:
    """
    Serializes Patterns to executable python.
    XXX: currently only used and tested for fuse attention patterns. May not cover
    all patterns.
    """

    def __init__(self) -> None:
        # Hands out unique variable names for memoized sub-patterns.
        self.namespace = torch.fx.graph._Namespace()
        # pattern -> variable name, and pattern -> its rendered source text.
        self.memoized_objs_names: dict[PatternExpr, str] = {}
        self.memoized_objs_pp: dict[PatternExpr, str] = {}

    @staticmethod
    @functools.cache
    def run(obj: PatternExpr, output_name: str = "output") -> str:
        """
        Serializes obj to python code with obj written out to `output_name`

        Each memoized sub-pattern is emitted as its own assignment, followed
        by the final `output_name = ...` line.
        """

        printer = PatternPrettyPrinter()
        assert hasattr(obj, "pretty_print")
        root_src = obj.pretty_print(pp=printer)

        lines = [
            f"{var_name} = {printer.memoized_objs_pp[pattern]}"
            for pattern, var_name in printer.memoized_objs_names.items()
        ]
        lines.append(f"{output_name} = {root_src}")

        return "\n".join(lines)

    def pretty_print(self, obj: Any) -> str:
        """
        Render `obj`: `_TargetArgsExpr`s are memoized and referred to by
        variable name, other objects with a `pretty_print` method render
        themselves, and everything else falls back to `repr`.
        """
        if isinstance(obj, _TargetArgsExpr):
            return self.memoized_objs_names.get(obj) or self.memoize(obj)
        if hasattr(obj, "pretty_print"):
            return obj.pretty_print(self)

        return repr(obj)

    def memoize(self, obj: _TargetArgsExpr) -> str:
        """Render `obj`, store it under a fresh variable name, and return that name."""
        rendered = obj.pretty_print(self)
        base_name = obj.fns_repr()
        for namespace_prefix in ("aten.", "torch.", "prims."):
            base_name = base_name.replace(namespace_prefix, "")

        var_name = self.namespace.create_name(base_name, None)
        self.memoized_objs_names[obj] = var_name
        self.memoized_objs_pp[obj] = rendered
        return var_name
@dataclasses.dataclass
class PatternEntry:
    """
    A pattern together with its `extra_check` predicate.

    Subclasses implement `apply`; `register` files the entry into one or
    more pass dictionaries keyed by `(op, target)`.
    """

    pattern: PatternExpr
    extra_check: Callable[[Match], bool]

    def apply(self, match: Match, graph: torch.fx.Graph, node: torch.fx.Node) -> None:
        raise NotImplementedError

    def register(
        self,
        pass_dicts: Union[_PassDictsType, Sequence[_PassDictsType]],
        target: Union[torch.fx.node.Target, None] = None,
        prepend: bool = False,
    ) -> None:
        """
        Add this entry to `pass_dicts`.

        Args:
            pass_dicts: a single pass dictionary (a dict or a
                PatternMatcherPass) or a sequence of them.
            target: the target to register under; when None, the entry is
                registered once for every target in `pattern.fns`.
            prepend: insert at the front of the entry list instead of
                appending to it.
        """
        if target is None:
            # No explicit target: register once per target of the pattern.
            assert hasattr(self.pattern, "fns")
            for fn in self.pattern.fns:
                self.register(pass_dicts, fn, prepend=prepend)
            return

        if isinstance(pass_dicts, (dict, PatternMatcherPass)):
            assert hasattr(self.pattern, "op")
            entries = pass_dicts[(self.pattern.op, target)]
            if prepend:
                entries.insert(0, self)
            else:
                entries.append(self)
            return

        # Otherwise treat it as a sequence of pass dictionaries.
        pass_dicts = typing.cast(Sequence[_PassDictsType], pass_dicts)
        for single_pass_dict in pass_dicts:
            self.register(single_pass_dict, target, prepend=prepend)
graph: torch.fx.Graph, + replacement_graph: Union[torch.fx.Graph, torch.fx.GraphModule], + args: Sequence[torch.fx.Node], + ) -> None: + """ + Inserts the replacement graph into the toplevel graph at the match + """ + + added_replacement_nodes: list[torch.fx.Node] = [] + + class Replacer(torch.fx.Interpreter): + call_method = None # type: ignore[assignment] + call_module = None # type: ignore[assignment] + get_attr = None # type: ignore[assignment] + + def run_node(self, node: torch.fx.Node) -> Any: + if node.op in ("placeholder", "output"): + return super().run_node(node) + target = node.target + args, kwargs = self.fetch_args_kwargs_from_env(node) + if node.op == "call_function": + assert callable(target) + result = graph.call_function(target, args, kwargs) + added_replacement_nodes.append(result) + _transfer_meta( + new_meta=result.meta, + old_node=node, + pass_name="Interpreter_Replacer", + ) + # This function copy-pastes the replacement graph into + # the graph. If the replacement graph had any eager_input_vals, + # or val/tensor_meta, we propagate those over. + if "eager_input_vals" in node.meta: + result.meta["eager_input_vals"] = node.meta["eager_input_vals"] + if "val" in node.meta and "val" not in result.meta: + result.meta["val"] = node.meta["val"] + if isinstance(node.meta["val"], torch.Tensor): + assert "tensor_meta" in node.meta + result.meta["tensor_meta"] = node.meta["tensor_meta"] + return result + if node.op == "get_attr": + # If the replacement graph contains a HOP, the subgraphs of the HOP are "get_attr" nodes. + # We need to fetch the subgraph of the HOP then register the subgraph to the replaced graph's root. + from torch._higher_order_ops.utils import ( + unique_graph_name_with_root, + ) + + sub_gm = super().get_attr(target, args, kwargs) + if not isinstance(sub_gm, torch.fx.GraphModule): + raise NotImplementedError( + f"NYI: replacement_graph.{target} is not a graph module. Got {sub_gm}." 
+ ) + assert graph.owning_module is not None + graph_name = None + for n, mod in graph.owning_module.named_modules(): + if sub_gm is mod: + graph_name = n + break + if graph_name is None: + assert isinstance(target, str) + _, graph_name = unique_graph_name_with_root( + # pyrefly: ignore [unbound-name] + graph.owning_module, + target, + ) + # pyrefly: ignore [unbound-name] + graph.owning_module.register_module(graph_name, sub_gm) + # pyrefly: ignore [unbound-name] + getattr_node = graph.get_attr(graph_name) + added_replacement_nodes.append(getattr_node) + return getattr_node + + raise NotImplementedError(f"unhandled {node}") + + output_nodes = match.output_nodes() + + if len(output_nodes) == 1: + last_node = output_nodes[0] + else: + assert output_nodes[0] + nodes = list(output_nodes[0].graph.nodes) + indices = [ + (nodes.index(n), n) + for n in output_nodes + if isinstance(n, torch.fx.Node) + ] + last_node = min(indices, key=operator.itemgetter(0))[1] + + def percolate_tags( + node: torch.fx.Node, + tag_name: str, + tag_value: str, + input_stops: OrderedSet[torch.fx.Node], + ) -> None: + queue = [node] + visited = OrderedSet[torch.fx.Node]() + + while queue: + arg = queue.pop() + if ( + arg not in visited + and arg not in input_stops + and hasattr(arg, "meta") + ): + visited.add(arg) + arg.meta[tag_name] = tag_value + queue.extend(arg.all_input_nodes) + + with graph.inserting_before(last_node): + assert isinstance(replacement_graph, torch.fx.GraphModule) + replacement = Replacer(replacement_graph).run(*args) + if isinstance(replacement, torch.fx.Node): + replacement = [replacement] + + def maybe_getitem(node: torch.fx.Node) -> Any: + if node.op != "call_function": + return None + if node.target != operator.getitem: + return None + assert len(node.args) == 2 + return node.args[1] + + def replace( + old: Union[torch.fx.Node, None], + new: Union[torch.fx.Node, Sequence[torch.fx.Node], None], + ) -> None: + def filter_nodes_in_newly_added_nodes(node: torch.fx.Node) -> 
bool: + # Do not replace the use of a node if it is being used by + # nodes in the replaced graph + return node not in added_replacement_nodes + + if old is None: + assert new is None + return + assert isinstance(old, torch.fx.Node) + if new is None: + old.replace_all_uses_with( + None, # type: ignore[arg-type] + delete_user_cb=filter_nodes_in_newly_added_nodes, + ) + if len(old.users) == 0: + graph.erase_node(old) + return + if isinstance(new, torch.fx.Node): + if "val" not in new.meta: + new.meta.update(old.meta) + + # Preserve the recompute tags in the replacement graph. We + # look at the recompute tags of the original output node to + # propagate the tag from the output all the way to the input + # args (named as args in the replace_with_graph). + # Note that this is best effort. Since patterns are from + # many to many, there is no easy way to correctly map the + # recomputable tags. It is possible in some scenarios that we + # incorrectly tag some nodes as recomputables. + for tag_name in ["recompute", "ac_graph_id"]: + if tag_name in old.meta: + percolate_tags( + new, tag_name, old.meta[tag_name], OrderedSet(args) + ) + + old.replace_all_uses_with( + new, delete_user_cb=filter_nodes_in_newly_added_nodes + ) + if len(old.users) == 0: + graph.erase_node(old) + return + + # `new` is not a node: it's a list of nodes. + # + # This happens when we want to replace a node that has a single + # packed return with multiple unpacked returns. We need to do + # some graph surgery here. + # + # Example: + # def original_graph(x): + # a = op(x) + # b = a[0] + # c = a[1] + # ... + # + # Assume that we want to replace op(x) with the graph + # def new_op(x): + # w = x + 1 + # z = x + 2 + # return (w, z) + # + # We need to replace `op` with the contents of `new_op`, + # and then rewrite a[0] to be w and a[1] to be z, as so: + # def new_graph(x): + # w = x + 1 + # z = x + 2 + # b = w + # c = z + # ... 
+ old_uses = list(old.users.keys()) + for user in old_uses: + idx = maybe_getitem(user) + if idx is None: + raise AssertionError( + "Deleted index from getitem, did you erase the index and not properly replace it?" + ) + replace(user, new[idx]) + graph.erase_node(old) + + if len(output_nodes) == len(replacement): + for old, new in zip(output_nodes, replacement): + replace(old, new) + else: + assert len(output_nodes) == 1 + replace(output_nodes[0], replacement) + + match.erase_nodes() + + def apply(self, match: Match, graph: torch.fx.Graph, node: torch.fx.Node) -> None: + assert match.replacement_graph is not None + self.replace_with_graph( + match, + graph, + match.replacement_graph, + self.normalize_args(*match.args, **match.kwargs), + ) + + +def _return_true(match: Match) -> bool: + return True + + +def log_trace_failure(search_fn: Callable[..., Any], e: RuntimeError) -> None: + log.info( + "Replacement pattern %s failed to apply due to shape mismatch: %s", + search_fn.__name__, + e, + ) + + +def check_and_add_duplicate_pattern( + pattern: PatternExpr, + graph: Optional[torch.fx.Graph], + seen_patterns: dict[str, list[Optional[str]]], + skip_duplicates: bool = False, +) -> bool: + """ + Check if a pattern is a duplicate. Because we ignore certain types in searching, but not + in matching, use the graph to distinguish equivalent search patterns. + + Returns True if a duplicate is found and `skip_duplicates=True` is passed in. Errors if + `skip_duplicates` is False and a duplicate is found. 
+ """ + + pattern_repr = PatternPrettyPrinter.run(pattern) + equiv_pattern_reprs = seen_patterns.get(pattern_repr) + if not equiv_pattern_reprs: + seen_patterns[pattern_repr].append(str(graph) if graph else None) + return False + + if graph is None: + if skip_duplicates: + return True + torch._check( + False, + lambda: f"Duplicate pattern: {pattern_repr} with no graph", + ) + + new_graph_str = str(graph) + for graph_str in equiv_pattern_reprs: + if new_graph_str != graph_str: + continue + if skip_duplicates: + return True + torch._check( + False, + lambda: f"Duplicate pattern: {pattern_repr} with duplicated match graph {graph_str} ", + ) + equiv_pattern_reprs.append(new_graph_str) + return False + + +def register_replacement( + search_fn: SearchFn, + replace_fn: ReplaceFn, + example_inputs: Iterable[Any], + trace_fn: TraceFn, + pass_dicts: Union[_PassDictsType, Sequence[_PassDictsType]], + extra_check: Callable[[Match], bool] = _return_true, + scalar_workaround: Union[dict[str, Union[float, int]], None] = None, + exclusive_arg_names: Sequence[str] = (), + search_fn_pattern: Union[PatternExpr, None] = None, + skip_duplicates: bool = False, +) -> bool: + """ + Create a replacement rule based on example functions that get traced + to create patterns. This supports both training and inference when + run on a joint forward+backward graph. 
+ + Args: + search_fn: traced to give original pattern + replace_fn: traced to give replacement graph + example_inputs: example inputs for initial trace + trace_fn: fwd_only or joint_fwd_bwd + pass_dict: dict of passes to register to + extra_check: additional check to run on match(using real shapes) + """ + argnames_static = [*inspect.signature(search_fn).parameters.keys()] + + if inspect.ismethod(search_fn): + search_fn = _wrap_bound_method(search_fn, argnames_static) + + if inspect.ismethod(replace_fn): + replace_argnames = [*inspect.signature(replace_fn).parameters.keys()] + replace_fn = _wrap_bound_method(replace_fn, replace_argnames) + + def check_fn(match: Match) -> bool: + """ + Often shapes get burned into the pattern, so our initial match ran with + `ignore_types=(int, ...)`. + + Recheck the match with the correct shapes. + """ + argnames = list(argnames_static) + for name in argnames: + if name not in match.kwargs: + raise RuntimeError( + f"Not all inputs to pattern found in match.kwargs. Perhaps one " + f"of the inputs is unused? argnames={argnames}, match.kwargs={match.kwargs}" + ) + + args = list( + torch.fx.map_arg( + [match.kwargs[name] for name in argnames], lambda n: n.meta["val"] + ) + ) + + sym_args: list[torch.SymInt] = [] + fake_mode = torch._dynamo.utils.detect_fake_mode(args) + assert fake_mode is not None + with fake_mode: + for i, grad in enumerate(requires_grad): + if isinstance(args[i], torch.Tensor): + if grad and is_integer_dtype(args[i].dtype): + return False + + args[i] = torch.empty_strided( + args[i].size(), + args[i].stride(), + dtype=args[i].dtype, + device=args[i].device, + requires_grad=grad, + ) + for v in itertools.chain(args[i].shape, args[i].stride()): + if isinstance(v, torch.SymInt) and all( + statically_known_true(v != a) for a in sym_args + ): + sym_args.append(v) + + # If we were given a pre-traced pattern then use that instead of + # retracing. Note that this means the pattern has to be independent + # of its args. 
+ specific_pattern = search_fn_pattern + + if not specific_pattern: + if sym_args: + # AOT Autograd and make fx will dedupe symbolic shape size + # accesses of sym ints that appear as inputs + # We don't want the sym_size uses to interfere with pattern matching + # so we provide them as inputs. + # Later, when we actually do the replacement, the symbolic shape + # sizes will get re-traced and added to the graph. + + def search_fn_new(*args_new: Any) -> Any: + return search_fn(*args_new[len(args_new) - len(args) :]) + + try: + # pyrefly: ignore [bad-argument-type] + specific_graph = trace_fn(search_fn_new, sym_args + args) + except RuntimeError as e: + log_trace_failure(search_fn, e) + return False + + # correct argnames in the graph + sym_arg_names = [] + for i, placeholder in zip( + range(len(sym_args) + len(args)), + specific_graph.graph.nodes, + ): + if i < len(sym_args): + sym_arg_names.append(placeholder.target) + continue + + with specific_graph.graph.inserting_after(placeholder): + new_node = specific_graph.graph.placeholder( + argnames[i - len(sym_args)] + ) + new_node.target = new_node.name + placeholder.replace_all_uses_with(new_node) + specific_graph.graph.erase_node(placeholder) + + argnames = sym_arg_names + argnames + else: + try: + specific_graph = trace_fn(search_fn, args) + except RuntimeError as e: + log_trace_failure(search_fn, e) + return False + + specific_pattern = fx_to_pattern( + specific_graph, + argnames=argnames, + exclusive_arg_names=exclusive_arg_names, + scalar_workaround=scalar_workaround, + ) + + node = match.output_nodes()[0] + assert node is not None + specific_pattern_match = specific_pattern.match(node) + + if os.environ.get("TORCHINDUCTOR_PATTERN_MATCH_DEBUG") == node.name: + log.warning( + "Specific pattern match: %s%s %s %s", + node, + node.args, + specific_pattern_match, + specific_pattern, + ) + + if is_match(specific_pattern_match) and extra_check(specific_pattern_match): + # trace the pattern using the shapes from the user 
program + match.replacement_graph = trace_fn(replace_fn, args) + if len(match.nodes) == 1: + for n in match.replacement_graph.graph.nodes: + _transfer_meta( + new_meta=n.meta, + old_node=match.nodes[0], + pass_name="replacement", + ) + return True + return False + + def normalize_args(**kwargs: Any) -> list[Any]: + args = [kwargs.pop(name) for name in argnames_static] + for i in range(1, len(kwargs) + 1): + if f"tangents_{i}" not in kwargs: + break + args.append(kwargs.pop(f"tangents_{i}")) + assert not kwargs, f"leftover kwargs: {kwargs!r}" + return args + + if trace_fn is joint_fwd_bwd: + # If inference mode is enabled during compilation, assume that we don't + # want to match on any training graph patterns + if torch.is_inference_mode_enabled(): + return False + + # TODO: Revisit the functionalize_rng_ops for lowmem dropout + with functorch_config.patch(functionalize_rng_ops=False): + requires_grad: list[bool] = [ + isinstance(x, torch.Tensor) and x.requires_grad for x in example_inputs + ] + if search_fn_pattern is None: + pattern, gm = gen_pattern_and_search_gm( + search_fn, + example_inputs, + trace_fn, + scalar_workaround, + exclusive_arg_names, + ) + else: + pattern = search_fn_pattern + gm = None + + for pattern_matcher_pass in ( + pass_dicts if isinstance(pass_dicts, Sequence) else [pass_dicts] + ): + if isinstance(pattern_matcher_pass, PatternMatcherPass): + if check_and_add_duplicate_pattern( + pattern, + gm.graph if gm else None, + pattern_matcher_pass.seen_patterns, + skip_duplicates=skip_duplicates, + ): + return False + + pattern = ReplacementPatternEntry( + pattern=pattern, + extra_check=check_fn, + normalize_args=normalize_args, + ) + pattern.register(pass_dicts) + return pattern.pattern # type: ignore[return-value] + + +_serialized_patterns: OrderedSet[str] = OrderedSet() + + +def _serialize_pattern( + unique_name: str, + search_fn: SearchFn, + example_inputs: Sequence[Any], + trace_fn: TraceFn, + scalar_workaround: Union[dict[str, Union[float, 
int]], None], +) -> PatternExpr: + def get_file_template() -> str: + auto_generated_msg = textwrap.dedent( + """\ + # This is an auto-generated file. Please do not modify it by hand. + # To re-generate, run: + # cd ~/pytorch && python torchgen/fuse/gen_patterns.py + """ + ) + + file_template = textwrap.dedent( + """\ + # mypy: ignore-errors + + # noqa: F401, E501 + {msg} + import torch + import torch._inductor + import operator + + aten = torch.ops.aten + prims = torch.ops.prims + + """ + ).format(msg=auto_generated_msg) + + pattern_matcher_imports = [] + for name in dir(torch._inductor.pattern_matcher): + attr = getattr(torch._inductor.pattern_matcher, name) + try: + if isinstance(attr, type) and issubclass( + attr, (PatternExpr, _TargetExpr) + ): + # pyrefly: ignore [bad-argument-type] + pattern_matcher_imports.append(name) + except TypeError: + pass + + formatted_imports = ",\n ".join(pattern_matcher_imports) + formatted_imports = f"from torch._inductor.pattern_matcher import (\n {formatted_imports},\n)\n" + return f"{file_template}{formatted_imports}" + + if not SERIALIZED_PATTERN_PATH.is_dir(): + raise RuntimeError( + f"Could not find serialized patterns directory at {SERIALIZED_PATTERN_PATH}" + ) + + pattern_name = search_fn.__name__ + + from torch._functorch import config as functorch_config + + with functorch_config.patch(functionalize_rng_ops=False): + pattern = gen_pattern(search_fn, example_inputs, trace_fn, scalar_workaround) + + serialized_pattern = PatternPrettyPrinter.run(pattern, output_name=unique_name) + if pattern_name not in _serialized_patterns: + write_mode = "w" + _serialized_patterns.add(pattern_name) + else: + write_mode = "a" + + file_template = get_file_template() + + with open(SERIALIZED_PATTERN_PATH / f"{pattern_name}.py", write_mode) as f: + if write_mode == "w": + f.write(file_template) + else: + f.write("\n\n") + f.write(serialized_pattern) + f.write("\n") + + return pattern + + +SERIALIZED_PATTERN_PATH = Path(__file__).parent / 
"fx_passes" / "serialized_patterns" + +# This is the set of serialized patterns that we've registered. Used by +# test_serialized_patterns_up_to_date() to ensure the patterns are up +# to date. +_known_precompiled_patterns: list[ + tuple[ + Any, + Iterable[Any], + Callable[[Callable[..., Any], Iterable[Any]], torch.fx.GraphModule], + Any, + PatternExpr, + ] +] = [] + + +def gen_register_replacement( + unique_name: str, + search_fn: SearchFn, + replace_fn: ReplaceFn, + example_inputs: Iterable[Any], + trace_fn: TraceFn, + pass_dicts: Union[_PassDictsType, Sequence[_PassDictsType]], + extra_check: Callable[[Match], bool] = _return_true, + scalar_workaround: Union[dict[str, Union[float, int]], None] = None, + exclusive_arg_names: Sequence[str] = (), + skip_duplicates: bool = False, +) -> None: + # Make sure the example_inputs is materialized. + example_inputs = tuple(example_inputs) + + if "PYTORCH_GEN_PATTERNS" in os.environ: + pat = _serialize_pattern( + unique_name, search_fn, example_inputs, trace_fn, scalar_workaround + ) + else: + pattern_name = search_fn.__name__ + m = importlib.import_module( + f"torch._inductor.fx_passes.serialized_patterns.{pattern_name}" + ) + if not m or not hasattr(m, unique_name): + log.warning( + "Precompiled pattern %r not found. Run torchgen/fuse/gen_patterns.py.", + unique_name, + ) + pat = getattr(m, unique_name) + + for arg in pytree.tree_iter(example_inputs): + if isinstance(arg, FakeTensor) and arg.constant is not None: + # This can be a problem - small fake tensors (e.g. `tensor(2)`) will + # hold onto their original constant value - and by stashing it here + # will cause a memory leak if the constant value is on GPU. + # Since this is just an optimization we can clear it out. 
+ arg.constant = None + + _known_precompiled_patterns.append( + (search_fn, example_inputs, trace_fn, scalar_workaround, pat) + ) + register_replacement( + search_fn, + replace_fn, + example_inputs, + trace_fn, + pass_dicts, + extra_check, + scalar_workaround, + exclusive_arg_names, + search_fn_pattern=pat, + skip_duplicates=skip_duplicates, + ) + + +@functorch_config.patch(functionalize_rng_ops=False) # type: ignore[misc] +def gen_pattern_and_search_gm( + search_fn: SearchFn, + example_inputs: Sequence[Any], + trace_fn: TraceFn, + scalar_workaround: Union[dict[str, Union[float, int]], None] = None, + exclusive_arg_names: Sequence[str] = (), +) -> tuple[PatternExpr, torch.fx.GraphModule]: + argnames = [*inspect.signature(search_fn).parameters.keys()] + + if scalar_workaround is None: + scalar_workaround = {} + flat_inputs = [] + input_idx = 0 # Positional arguments index + + for argname in argnames: + if argname in scalar_workaround: + flat_inputs.append(scalar_workaround[argname]) + else: + flat_inputs.append(example_inputs[input_idx]) + input_idx += 1 + + search_gm = trace_fn(search_fn, flat_inputs) + return ( + fx_to_pattern( + search_gm, + ignore_types=(int, float, list, torch.device, torch.dtype), + argnames=argnames, + scalar_workaround=scalar_workaround, + exclusive_arg_names=exclusive_arg_names, + ), + search_gm, + ) + + +def gen_pattern( + search_fn: SearchFn, + example_inputs: Sequence[Any], + trace_fn: TraceFn, + scalar_workaround: Union[dict[str, Union[float, int]], None] = None, + exclusive_arg_names: Sequence[str] = (), +) -> PatternExpr: + return gen_pattern_and_search_gm( + search_fn, example_inputs, trace_fn, scalar_workaround, exclusive_arg_names + )[0] + + +def register_lowering_pattern( + pattern: PatternExpr, + extra_check: Callable[[Match], bool] = _return_true, + *, + pass_dict: _PassDictsType, + prepend: bool = False, +) -> Callable[[Callable[..., Any]], Callable[..., Any]]: + """ + Register an aten to inductor IR replacement pattern. 
The decorated + function is saved and then called a lowering time allowing direct + pattern to inductor IR conversion. + """ + + def decorator(handler: Callable[..., Any]) -> Callable[..., Any]: + assert callable(handler) + LoweringPatternEntry( + pattern=pattern, extra_check=extra_check, handler=handler + ).register(pass_dict, prepend=prepend) + handler._inductor_lowering_function = True # type: ignore[attr-defined] + return handler + + return decorator + + +def register_graph_pattern( + pattern: PatternExpr, + extra_check: Callable[[Match], bool] = _return_true, + *, + pass_dict: _PassDictsType, + prepend: bool = False, +) -> Callable[[Callable[..., Any]], Callable[..., Any]]: + """ + Register a pattern that runs a function on the FX graph, allowing + custom transformation code. + """ + + def decorator(handler: Callable[..., Any]) -> Callable[..., Any]: + assert callable(handler) + GraphPatternEntry( + pattern=pattern, extra_check=extra_check, handler=handler + ).register(pass_dict, prepend=prepend) + return handler + + return decorator + + +def is_start_of_fx_graph(graph: torch.fx.Graph, node: torch.fx.Node) -> bool: + # first node in the graph + return node is next(iter(graph.nodes)) + + +# match: copy_, relu_, _set_grad_enabled, manual_seed, _enter_autocast, etc +# doesn't match: __rshift__, etc +_mutation_op_re = re.compile(r"(? bool: + if op.namespace != "inductor": + return False + + # TODO - fix schema + # Dont add any more ! 
+ return op in ( + torch.ops.inductor.accumulate_grad_.default, + torch.ops.inductor.resize_storage_bytes_.default, + ) + + +def is_mutation_op(node: torch.fx.Node) -> bool: + if isinstance( + node.target, torch._ops.OpOverload + ) and not fixme_incorrect_inductor_schema_op(node.target): + return node.target._schema.is_mutable + elif isinstance( + node.target, torch._higher_order_ops.auto_functionalize.AutoFunctionalized + ): + return False + if node.op == "call_function": + assert callable(node.target) + if _mutation_op_re.search(node.target.__name__): + return True + elif node.op == "call_method": + assert isinstance(node.target, str) + if _mutation_op_re.search(node.target): + return True + return node.kwargs.get("out") is not None + + +def same_mutation_regions(a: torch.fx.Node, b: torch.fx.Node) -> bool: + assert "mutation_region_id" in a.meta + assert "mutation_region_id" in b.meta + return a.meta["mutation_region_id"] == b.meta["mutation_region_id"] + + +def get_mutation_region_id(graph: torch.fx.Graph, node: torch.fx.Node) -> int: + n = node + while "mutation_region_id" not in n.meta and not is_start_of_fx_graph(graph, n): + n = n.prev + mutation_region_id = n.meta.get("mutation_region_id", 0) + while n is not node: + n = n.next + if is_mutation_op(n): + mutation_region_id += 1 + n.meta["mutation_region_id"] = mutation_region_id + return mutation_region_id + + +def should_compute_mutation_region_ids(graph: torch.fx.Graph) -> bool: + return "mutation_region_id" not in next(iter(graph.nodes)).meta + + +def compute_mutation_region_ids(graph: torch.fx.Graph) -> None: + mutation_region_id = 0 + for nd in graph.nodes: + if is_mutation_op(nd): + mutation_region_id += 1 + nd.meta["mutation_region_id"] = mutation_region_id + + +def _wrap_bound_method(fn: Any, argnames: list[str]) -> Any: + """ + Wrap a bound method to remove 'self' from its signature for FX tracing. 
+ """ + + def wrapper(*args: Any, **kwargs: Any) -> Any: + return fn(*args, **kwargs) + + params = [ + inspect.Parameter(name, inspect.Parameter.POSITIONAL_OR_KEYWORD) + for name in argnames + ] + wrapper.__signature__ = inspect.Signature(params) # type: ignore[attr-defined] + return wrapper + + +class PatternMatcherPass: + def __init__( + self, + pass_name: Optional[str] = None, + subsystem: Optional[str] = None, + ) -> None: + super().__init__() + self.patterns: defaultdict[ + tuple[str, torch.fx.node.Target], list[PatternEntry] + ] = defaultdict(list) + self.pass_name = pass_name + self.subsystem = subsystem + + # For a particular generated pattern repr, store all of the str representations + # of the graph used to generate them. Because we ignore certain patterns + # in searching, but not in matching, use the graph to distinguish if two equivalent + # searches are actually different. + self.seen_patterns: dict[str, list[Optional[str]]] = defaultdict(list) + + def __getitem__(self, item: tuple[str, torch.fx.node.Target]) -> list[PatternEntry]: + return self.patterns[item] + + def apply(self, gm: Union[torch.fx.GraphModule, torch.fx.Graph]) -> int: + if not self.patterns: + return 0 + if isinstance(gm, torch.fx.GraphModule): + graph = gm.graph + elif isinstance(gm, torch.fx.Graph): + graph = gm + gm = graph.owning_module + else: + raise RuntimeError( + f"The input to PatternMatcherPass must be a GraphModule or a Graph, but got {type(gm)}" + ) + if should_compute_mutation_region_ids(graph): + compute_mutation_region_ids(graph) + get_mutation_region_id_partial = functools.partial( + get_mutation_region_id, graph + ) + count = 0 + nodes = [] + has_call_module = False + for op, target in self.patterns: + if op == "call_module": + has_call_module = True + else: + nodes.append(graph.find_nodes(op=op, target=target, sort=False)) + if has_call_module: + nodes.append(graph.find_nodes(op="call_module", sort=False)) + pass_name = self.pass_name if self.pass_name is not None 
else "pattern_matcher" + assert isinstance(gm, torch.fx.GraphModule) + with GraphTransformObserver(gm, pass_name, self.subsystem): + for node in sorted(itertools.chain.from_iterable(nodes), reverse=True): + target = extract_target(node) + if node.op == "call_module": + if (node.op, target) not in self.patterns: + continue + + # conservatively not applying pattern for cpu input, + # since some of the patterns induce codegen and split nodes. + # Note: we will only skip cpu compute if disable_cpp_codegen=True + if fallback_node_due_to_unsupported_type(node, allow_cpu_inputs=False): + continue + + for entry in self.patterns[(node.op, target)]: + if node._erased: + break + m = entry.pattern.match(node) + # pattern match crosses mutation barrier - discard + if ( + is_match(m) + and len( + OrderedSet(map(get_mutation_region_id_partial, m.nodes)) + ) + != 1 + ): + continue + if os.environ.get("TORCHINDUCTOR_PATTERN_MATCH_DEBUG") == node.name: + log.warning("%s%s %s %s", node, node.args, m, entry.pattern) + + if is_match(m) and guard_or_false(entry.extra_check(m)): + count += 1 + entry.apply(m, graph, node) + counters[backend]["pattern_matcher_count"] += 1 + counters[backend]["pattern_matcher_nodes"] += len(m.nodes) + return count + + def clear(self) -> None: + self.patterns.clear() + + +def _not_implemented(*args: Any, **kwargs: Any) -> NoReturn: + raise NotImplementedError + + +def fx_to_pattern( + gm: Union[torch.fx.GraphModule, torch.fx.Graph], + ignore_types: Sequence[type[Any]] = (), + argnames: Sequence[str] = (), + scalar_workaround: Union[dict[str, Union[float, int]], None] = None, + exclusive_arg_names: Sequence[str] = (), +) -> PatternExpr: + """ + Convert an FX graph into a PatternExpr. This is useful for simple + patterns that can only match single functions and fixed-length lists. 
+ """ + # scalar_workaround is a hack to capture dropout_p + # see https://github.com/pytorch/pytorch/issues/97894 + scalar_workaround = scalar_workaround or {} + inv_scalar_workaround = {v: k for k, v in scalar_workaround.items()} + assert len(inv_scalar_workaround) == len(scalar_workaround) + + def process_arg( + x: T, ignore_types_override: Optional[Sequence[type[Any]]] = None + ) -> Union[T, KeywordArg, Ignored]: + current_ignore_types = ( + ignore_types_override if ignore_types_override is not None else ignore_types + ) + if isinstance(x, (float, int)) and x in inv_scalar_workaround: + return KeywordArg(inv_scalar_workaround[x]) + if type(x) in current_ignore_types: + return Ignored() + if isinstance(x, list) and all(isinstance(y, Ignored) for y in x) and x: + return Ignored() + return x + + argnum = itertools.count() + + class Converter(torch.fx.Interpreter): + # pyrefly: ignore [bad-override] + call_method = _not_implemented + # pyrefly: ignore [bad-override] + call_module = _not_implemented + # pyrefly: ignore [bad-override] + get_attr = _not_implemented + + # pyrefly: ignore [bad-override] + def placeholder( + self, + target: str, # type: ignore[override] + args: Sequence[Any], + kwargs: Mapping[str, Any], + ) -> Union[ExclusiveKeywordArg, KeywordArg]: + n = next(argnum) + if n < len(argnames): + name = argnames[n] + elif argnames: + assert target.startswith("tangent") + name = target + else: + target = re.sub(r"_\d+$", "", target) # de-mangle arg name + name = target + if name in exclusive_arg_names: + return ExclusiveKeywordArg(name) + else: + return KeywordArg(name) + + # pyrefly: ignore [bad-override] + def call_function( + self, + target: str, # type: ignore[override] + args: Sequence[Any], + kwargs: Mapping[str, Any], + ) -> PatternExpr: + process_arg_fn = process_arg + # Indexing is critical for matching getitem nodes, so we can't ignore int args here + if target is operator.getitem: + + def process_arg_fn_impl( + x: T, + ignore_types_override: 
Optional[Sequence[type[Any]]] = tuple( + t for t in ignore_types if t is not int + ), + ) -> Union[T, KeywordArg, Ignored]: + return process_arg(x, ignore_types_override) + + process_arg_fn = process_arg_fn_impl + + args, kwargs = pytree.tree_map(process_arg_fn, (args, kwargs)) + if list in ignore_types: + # Handle a burned in tensor size which are now [Ignored(), Ignored(), ...] + args = [process_arg_fn(a) for a in args] + kwargs = {k: process_arg_fn(a) for k, a in kwargs.items()} + return CallFunction(target, *args, **kwargs) + + def run_node(self, n: torch.fx.Node) -> Any: + rv = super().run_node(n) + if n.op == "output" and isinstance(rv, tuple): + args = n.args[0] + assert isinstance(args, Collection) + assert len(rv) == len(args) + for r, arg in zip(rv, args): + # pyrefly: ignore [missing-attribute] + r.users = len(arg.users) + else: + rv.users = len(n.users) + return rv + + assert isinstance(gm, torch.fx.GraphModule) + pattern = Converter(gm).run() + if not isinstance(pattern, PatternExpr): + return MultiOutputPattern(pytree.tree_leaves(pattern)) + return pattern + + +@torch.no_grad() +def fwd_only( + fn: Callable[..., Any], + args: Sequence[Any], + *, + run_functional_passes: bool = True, + get_decomp_fn: Optional[Callable[..., Any]] = None, +) -> torch.fx.GraphModule: + """Build a normalized inference graph, for use with fx_to_pattern""" + # TODO - look into using aot autograd, asserting no mutating ops here + with enable_python_dispatcher(), preserve_node_meta(): + decompositions = ( + get_decomp_fn() if get_decomp_fn is not None else select_decomp_table() + ) + gm = make_fx(fn, decompositions, tracing_mode="real")(*args) + + from .fx_passes.post_grad import remove_noop_ops + + if run_functional_passes: + remove_noop_ops(gm.graph) + gm.graph.eliminate_dead_code() + + gm.recompile() + return gm + + +@torch.enable_grad() +def joint_fwd_bwd(fn: Callable[..., Any], args: Sequence[Any]) -> torch.fx.GraphModule: + """Build a normalized training graph, for use 
with fx_to_pattern""" + gm: Optional[torch.fx.GraphModule] = None + + def record_joint_graph( + joint_graph: torch.fx.GraphModule, inputs: Sequence[Any], **kwargs: Any + ) -> tuple[torch.fx.GraphModule, torch.fx.GraphModule]: + nonlocal gm + assert not gm + gm = clone_graph(joint_graph) + return default_partition(joint_graph, inputs, **kwargs) + + with torch._guards.tracing(None): + aot_function( + fn, + lambda g, i: make_boxed_func(g), + partition_fn=record_joint_graph, + decompositions=select_decomp_table(), + keep_inference_input_mutations=True, + enable_log=False, + )(*args) + assert gm + + from .fx_passes.post_grad import remove_noop_ops + + remove_noop_ops(gm.graph) + + from .fx_passes.joint_graph import pointless_view + + matcher_pass = PatternMatcherPass() + + pattern = CallFunction( + torch.ops.aten.view.default, KeywordArg("arg"), KeywordArg("size") + ) + GraphPatternEntry( + pattern=pattern, + handler=pointless_view, + extra_check=_return_true, + # pyrefly: ignore [bad-argument-type] + ).register(matcher_pass.patterns) + matcher_pass.apply(gm.graph) + + # remove in/out specs + gm.graph._codegen = torch.fx.graph.CodeGen() + gm.graph.eliminate_dead_code() + gm.recompile() + return gm + + +def _args(n: torch.fx.Node) -> list[torch.fx.node.Argument]: + args: list[torch.fx.node.Argument] = [] + torch.fx.map_arg((n.args, n.kwargs), args.append) + return args + + +def stable_topological_sort(graph: torch.fx.Graph) -> None: + # Nodes are in exactly one of these three collections: + + # - Nodes in `pending` are waiting to be processed (in reverse order): + pending = list(reversed(graph.nodes)) + + # - Nodes in `ready` have been processed and are already in the correct + # order. + ready = OrderedSet[torch.fx.Node]() + + # - `waiting` is a mapping from a dependency to nodes which depend on that + # dependency. + waiting = defaultdict(list) + + # The cursor indicates the last processed node so we can add new nodes + # after it. 
+ cursor = None + while pending: + node = pending.pop() + waiting_for = [x for x in _args(node) if x not in ready] + if waiting_for: + # We have unprocessed input nodes. Might as well wait for the last + # arg so an already sorted list will only recheck this node once. + waiting[waiting_for[-1]].append(node) + else: + ready.add(node) + if cursor and cursor.next is not node: + cursor.append(node) + cursor = node + # Mark the nodes that have been waiting for this node to finish as + # ready to check again. + pending.extend(reversed(waiting.pop(node, ()))) + + assert not waiting and len(ready) == len(graph.nodes) + + +def init_once_fakemode(fn: Callable[..., Any]) -> Callable[[], Any]: + """Wrapper around lazy init functions in fx_passes/""" + + @functools.cache + @functools.wraps(fn) + def lazy_init() -> Any: + counters_ref = counters[backend].copy() + + with torch._guards.tracing(None), unset_fake_temporarily(), FakeTensorMode(): + result = fn() + + # clear view matches encountered during tracing + counters[backend] = counters_ref + + return result + + return lazy_init + + +def config_flag(name: str) -> Callable[[Match], Any]: + """Function for extra_check to put pass behind a flag""" + + def flag_check(match: Match) -> Any: + return getattr(config, name) + + return flag_check + + +def clone_graph(input_graph: torch.fx.GraphModule) -> torch.fx.GraphModule: + class CopyGraph(Transformer): + def run_node(self, old_node: torch.fx.Node) -> torch.fx.Node: + new_node = super().run_node(old_node) + if isinstance(new_node, torch.fx.Proxy): + new_node.node.meta.update(old_node.meta) + new_node.node.name = self.new_graph._graph_namespace.create_name( + old_node.name, None + ) + # pyrefly: ignore [bad-return] + return new_node + + return CopyGraph(input_graph).transform() + + +# TODO: remove in follow up diff, used internally +_seen_patterns: OrderedSet[str] = OrderedSet() + + +def get_arg_value( + node: torch.fx.Node, arg_number: int, kwarg_name: Optional[str] = None +) -> 
Any: + if len(node.args) > arg_number: + return node.args[arg_number] + elif kwarg_name is None: + return None + else: + return node.kwargs.get(kwarg_name) + + +def filter_nodes(nodes: Iterable[torch.fx.Node], fn: Any) -> list[torch.fx.Node]: + fns = [fn] + if isinstance(fn, torch._ops.OpOverloadPacket): + fns.extend([getattr(fn, overload) for overload in fn.overloads()]) + + return [node for node in nodes if node.target in fns] + + +def extract_target(node: torch.fx.Node) -> torch.fx.node.Target: + """For call_function and call_method, we directly use the target function; + For call_module, the target is string, and we treat the module class + as a function. + """ + if node.op == "call_module": + assert isinstance(node.target, str) + return _get_attr(node.graph.owning_module, node.target).__class__ + return node.target diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/quantized_lowerings.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/quantized_lowerings.py new file mode 100644 index 0000000000000000000000000000000000000000..5b6f8c12309b81202fc92a5def2d7f191f6641f8 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/quantized_lowerings.py @@ -0,0 +1,169 @@ +import logging +from typing import Any + +import torch +from torch._inductor.kernel.mm_common import mm_args + +from . 
import config, lowering +from .codegen.cpp_gemm_template import CppGemmTemplate, CppWoqInt4GemmTemplate +from .codegen.cpp_utils import create_epilogue_with_attr +from .lowering import expand, register_lowering +from .mkldnn_ir import WeightInt4PackMatmul +from .select_algorithm import ( + autotune_select_algorithm, + ExternKernelChoice, + realize_inputs, +) +from .utils import use_aten_gemm_kernels, use_cpp_gemm_template +from .virtualized import V + + +log = logging.getLogger(__name__) + +aten__weight_int8pack_mm = ExternKernelChoice( + torch._weight_int8pack_mm, "at::_weight_int8pack_mm", has_out_variant=False +) + +aten__weight_int4pack_mm_cpu = ExternKernelChoice( + torch.ops.quantized.int4mm_packed_weight_cpu, + "at::native::_weight_int4pack_mm_cpu_tensor", + has_out_variant=False, + kernel_creator=WeightInt4PackMatmul.create, +) + +quantized = torch.ops.quantized +_quantized = torch.ops._quantized +aten = torch.ops.aten + + +def register_quantized_ops() -> None: + lowering.add_needs_realized_inputs( + [ + quantized.max_pool2d, + _quantized.wrapped_fbgemm_pack_gemm_matrix_fp16, + _quantized.wrapped_fbgemm_linear_fp16_weight, + ] + ) + lowering.make_fallback(quantized.max_pool2d) + lowering.make_fallback(_quantized.wrapped_fbgemm_pack_gemm_matrix_fp16) + lowering.make_fallback(_quantized.wrapped_fbgemm_linear_fp16_weight) + + +def register_woq_mm_ops() -> None: + @register_lowering(aten._weight_int8pack_mm, type_promotion_kind=None) # type: ignore[misc] + def int8pack_mm( + input: torch.Tensor, + weight: torch.Tensor, + scale: torch.Tensor, + *, + layout: Any = None, + ) -> Any: + _, _, _, layout, mat1, mat2 = mm_args( + input, weight, layout=layout, mat2_transposed=True + ) + assert ( + mat1.get_dtype() in [torch.bfloat16, torch.float16, torch.float] + and mat2.get_dtype() == torch.int8 + ) + aten_layout = layout + + # options to tune from + choices = ( + [aten__weight_int8pack_mm.bind((mat1, mat2, scale), aten_layout)] + if use_aten_gemm_kernels() + else [] 
+ ) + + # scale is applied as an epilogue, and the scale tensor is expanded (with a view op) + # for broadcasting, as it's 1D. + def _mul_epilogue(buf: torch.Tensor) -> Any: + return create_epilogue_with_attr( + buf, "mul", other=realize_inputs(expand(scale, layout.size)) + ) + + if use_cpp_gemm_template(aten_layout, mat1, mat2, mat2_transposed=True): + CppGemmTemplate.add_choices( + choices, + aten_layout, + [mat1, mat2, scale], + trans_w=True, + epilogue_creator=_mul_epilogue, # type: ignore[arg-type] + ) + + return autotune_select_algorithm( + "_weight_int8pack_mm", choices, [mat1, mat2, scale], aten_layout + ) + + @register_lowering(aten._weight_int4pack_mm_for_cpu, type_promotion_kind=None) # type: ignore[misc] + def int4pack_mm_cpu( + input: torch.Tensor, + weight: torch.Tensor, + qGroupSize: int, + qScaleAndZeros: torch.Tensor, + *, + layout: Any = None, + ) -> Any: + _, _, _, layout, mat1, mat2 = mm_args( + input, weight, layout=layout, use_4x2_dim=True, mat2_transposed=True + ) + assert ( + mat1.get_dtype() in [torch.bfloat16, torch.float16, torch.float] + and mat2.get_dtype() == torch.uint8 + ) + group_size = V.graph.add_tensor_constant( + torch.tensor(qGroupSize, dtype=torch.int64), name=None + ) + aten_layout = layout + + # options to tune from + choices = ( + [ + aten__weight_int4pack_mm_cpu.bind( + (mat1, mat2, group_size, qScaleAndZeros), aten_layout + ) + ] + if use_aten_gemm_kernels() + else [] + ) + if ( + (config.max_autotune or config.max_autotune_gemm) + and use_cpp_gemm_template( + aten_layout, + mat1, + mat2, + mat2_transposed=True, + is_woq_int4=True, + q_group_size=qGroupSize, + ) + and mat2.get_layout().is_contiguous() + ): + # pyrefly: ignore [bad-specialization, missing-attribute, not-a-type] + CppWoqInt4GemmTemplate[qGroupSize].add_choices( + choices, + aten_layout, + [mat1, mat2, group_size, qScaleAndZeros], + ) + + # define functions to generate example inputs for weight and group size + # otherwise, autotuner generates example inputs 
of all zeros for them + def get_example_weight(x: torch._inductor.ir.IRNode) -> torch.Tensor: + assert x.get_layout().is_contiguous() + shape = x.get_size() + device = x.get_device() + return torch.randint(0, 255, shape, dtype=torch.uint8, device=device) + + input_gen_fns = { + 1: get_example_weight, # packed weight + 2: lambda x: V.graph.constants[x.get_name()], # group size + } + + return autotune_select_algorithm( + "_weight_int4pack_mm_for_cpu", + choices, + [mat1, mat2, group_size, qScaleAndZeros], + aten_layout, + input_gen_fns=input_gen_fns, + ) + + lowering.make_fallback(aten._dyn_quant_matmul_4bit) + lowering.make_fallback(aten._dyn_quant_pack_4bit_weight) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/remote_cache.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/remote_cache.py new file mode 100644 index 0000000000000000000000000000000000000000..d9a2d4af9d1be060d7dd7ab3654aa06be709c40f --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/remote_cache.py @@ -0,0 +1,432 @@ +from __future__ import annotations + +import atexit +import collections +import dataclasses +import functools +import json +import logging +import os +import sys +import typing +from abc import abstractmethod +from typing import Any, Generic, Optional, TypeAlias, TypeVar, Union +from typing_extensions import override + +from torch._dynamo.utils import dynamo_timed +from torch._inductor import config +from torch.monitor import _WaitCounter + + +if typing.TYPE_CHECKING: + from collections.abc import Callable + + +try: + import redis +except ImportError: + redis = None # type: ignore[assignment] + + +log = logging.getLogger(__name__) + + +if config.is_fbcode(): + from rfe.scubadata.scubadata_py3 import ( # type: ignore[import-not-found] + Sample as Sample_, + ) + + Sample: TypeAlias = Sample_ +else: + Sample: TypeAlias = type[object] # type: ignore[misc,no-redef] + + +_T = 
TypeVar("_T") +_U = TypeVar("_U") + + +remote_fx_cache_get_timed = functools.partial( + dynamo_timed, + "FbRemoteFxGraphCache.get", + phase_name="remote_fx_graph_cache_get", + log_pt2_compile_event=False, + dynamo_compile_column_us="remote_fx_graph_cache_get_time_us", + log_waitcounter=True, +) +remote_fx_cache_put_timed = functools.partial( + dynamo_timed, + "FbRemoteFxGraphCache.put", + phase_name="remote_fx_graph_cache_put", + log_pt2_compile_event=False, + dynamo_compile_column_us="remote_fx_graph_cache_put_time_us", + log_waitcounter=True, +) + + +class RemoteCacheBackend(Generic[_T]): + """ + A backend implementation for accessing a remote/distributed cache. Only + works with bytes in/out. For structured data use a RemoteCache. + """ + + def __init__(self) -> None: + self._name = f"backend:{type(self).__name__}" + + @abstractmethod + def _get(self, key: str) -> Optional[_T]: + pass + + @abstractmethod + def _put(self, key: str, data: _T) -> None: + pass + + def get(self, key: str) -> Optional[_T]: + try: + value = self._get(key) + cache_stats.get(self._name, value) + except Exception: + cache_stats.exception(self._name) + raise + return value + + def put(self, key: str, data: _T) -> None: + try: + self._put(key, data) + cache_stats.put(self._name) + except Exception: + cache_stats.exception(self._name) + raise + + +# Serde that encodes from _T to _U and decodes from _U to _T. 
+class RemoteCacheSerde(Generic[_T, _U]): + @abstractmethod + def encode(self, data: _T) -> _U: + pass + + @abstractmethod + def decode(self, data: _U) -> _T: + pass + + +JsonDataTy = Optional[ + Union[int, float, str, bool, dict[str, "JsonDataTy"], list["JsonDataTy"]] +] + + +class RemoteCacheJsonSerde(RemoteCacheSerde[JsonDataTy, bytes]): + def encode(self, data: JsonDataTy) -> bytes: + return bytes(json.dumps(data), "ascii") + + def decode(self, data: bytes) -> JsonDataTy: + return json.loads(data) + + +class RemoteCachePassthroughSerde(RemoteCacheSerde[_T, _T]): + def encode(self, data: _T) -> _T: + return data + + def decode(self, data: _T) -> _T: + return data + + +# This class is the top of a RemoteCache. A RemoteCache is fundamentally made of +# three parts: +# +# 1. The controller (this class). +# 2. A serializer/deserializer (instance of RemoteCacheSerde). +# 3. A backend (instance of RemoteCacheBackend). +# +# To write (`put`), the RemoteCache takes data, uses the RemoteCacheSerde to +# convert it for the backend and passes it to the backend. +# +# Conversely when reading (`get`), the RemoteCache takes data from the backend, +# uses the RemoteCacheSerde to convert it and returns it. +# +# The RemoteCacheBackend is generic on _U - which is the type of data the +# backend can directly cache (usually `bytes`). +# +# The RemoteCacheSerde is responsible for converting between _T (the type of +# data the RemoteCache accepts in `put` and returns in `get`) and _U. +# +# When instantiating a RemoteCache you should override, not directly create a +# RemoteCache. The reason is that when logging cache use (`TORCH_LOGS=cache`) we +# use the concrete type of the RemoteCache as the reported cache. See +# RemoteFxGraphCache below as an example. 
+class RemoteCache(Generic[_T]): + backend_override_cls: Optional[Callable[[], RemoteCacheBackend[Any]]] = None + + def __init__( + self, backend: RemoteCacheBackend[_U], serde: RemoteCacheSerde[_T, _U] + ) -> None: + # Support for testing to mock out the backend on a class-by-class basis. + if (override_cls := self.__class__.backend_override_cls) is not None: + self.backend = override_cls() + else: + self.backend = backend + # pyrefly: ignore [invalid-type-var] + self.serde = serde + + # See if the cache contains `key`. Returns `None` if the value is not + # present in the cache. + def get(self, key: str) -> Optional[_T]: + with _WaitCounter("pytorch.remote_cache.get").guard(): + sample = self._create_sample() + try: + result = self._get(key, sample) + cache_stats.get(type(self).__name__, result) + except Exception as e: + cache_stats.exception(type(self).__name__) + if sample: + sample.fail_reason = str(e) + raise + finally: + self._log_sample(sample) + return result + + # Add `value` to the cache with the key `key`. Note that `None` is not a + # valid value even if _T supports it (because you can't tell the difference + # between `None` and a missing cache entry). + def put(self, key: str, value: _T) -> None: + with _WaitCounter("pytorch.remote_cache.put").guard(): + assert value is not None + sample = self._create_sample() + try: + self._put(key, value, sample) + cache_stats.put(type(self).__name__) + except Exception as e: + cache_stats.exception(type(self).__name__) + if sample: + sample.fail_reason = str(e) + raise + finally: + self._log_sample(sample) + + # Used to convert data from the cache into structured data. + def _decode(self, data: _U, sample: Optional[Sample]) -> _T: # type: ignore[override] + return self.serde.decode(data) # type: ignore[arg-type] + + # Used to convert structured data into data for the cache. 
+ def _encode(self, value: _T, sample: Optional[Sample]) -> object: # returns _U + return self.serde.encode(value) + + # Get structured data from the cache. + # Separate from `get` so that it can be overridden. + def _get(self, key: str, sample: Optional[Sample]) -> Optional[_T]: + if data := self._backend_get(key): + return self._decode(data, sample) + return None + + # Get unstructured data from the cache. + # Separate from `get` so that it can be overridden. + # Returns _U - but we aren't actually generic on _U + def _backend_get(self, key: str) -> object: + return self.backend.get(key) + + # Put structured data into the cache. + # Separate from `put` so that it can be overridden. + def _put(self, key: str, value: _T, sample: Optional[Sample]) -> None: + data = self._encode(value, sample) + self._backend_put(key, data) + + # Put unstructured data into the cache. + # Separate from `put` so that it can be overridden. + # Takes data: _U - but we aren't actually generic on _U + def _backend_put(self, key: str, data: object) -> None: + self.backend.put(key, data) + + # Create a logging Sample - used with internal loggers to monitor cache + # effectiveness. + def _create_sample(self) -> Optional[Sample]: + return None + + # Write the logging Sample to the logger. + def _log_sample(self, sample: Optional[Sample]) -> None: + pass + + +class RedisRemoteCacheBackend(RemoteCacheBackend[bytes]): + """ + A Redis implementation of a remote/distributed cache. 
+ """ + + # pyrefly: ignore [missing-attribute] + _redis: Optional[redis.Redis] = None + + def __init__(self, cache_id: str) -> None: + super().__init__() + if not redis: + raise RuntimeError("redis not available but required for remote cache") + + if "TORCHINDUCTOR_REDIS_URL" in os.environ: + self._redis = redis.Redis.from_url(os.environ["TORCHINDUCTOR_REDIS_URL"]) + else: + self._redis = redis.Redis( + host=os.environ.get("TORCHINDUCTOR_REDIS_HOST", "localhost"), + port=int(os.environ.get("TORCHINDUCTOR_REDIS_PORT", 6379)), + ) + + @override + def _get(self, key: str) -> Optional[bytes]: + if not self._redis: + # Either redis wasn't found or we already had some trouble... + return None + + try: + # pyrefly: ignore [missing-attribute] + value = self._redis.get(key) + # pyrefly: ignore [missing-attribute] + except redis.exceptions.ConnectionError: + # Redis is lazy and doesn't actually attempt to connect until the + # first use. Mark is as unavailable now. + self._redis = None + return None + + # In theory redis.get() can return an Awaitable as well... + assert value is None or isinstance(value, bytes) + return value + + @override + def _put(self, key: str, data: bytes) -> None: + if not self._redis: + # Either redis wasn't found or we already had some trouble... + return + + try: + # pyrefly: ignore [missing-attribute] + self._redis.set(key, data) + # pyrefly: ignore [missing-attribute] + except redis.exceptions.ConnectionError: + # Redis is lazy and doesn't actually attempt to connect until the + # first use. Mark is as unavailable now. + self._redis = None + + +class RedisRemoteCache(RemoteCache[JsonDataTy]): + def __init__(self, cache_id: str) -> None: + # Special test handling: If we're just going to override the backend + # anyway don't require redis + if self.__class__.backend_override_cls: + # This is totally bogus but it works for now... 
+ backend = typing.cast(RemoteCacheBackend[bytes], None) + else: + backend = RedisRemoteCacheBackend(cache_id) + serde = RemoteCacheJsonSerde() + super().__init__(backend, serde) + version = 1 # consistency between various types of keys + self._key_fmt = f"pt2:{cache_id}::{{key}}:c{version}" + + def _get_key(self, key: str) -> str: + return self._key_fmt.format(key=key) + + @override + def _get(self, key: str, sample: Optional[Sample]) -> Optional[JsonDataTy]: + key = self._get_key(key) + return super()._get(key, sample) + + @override + def _put(self, key: str, value: JsonDataTy, sample: Optional[Sample]) -> None: + key = self._get_key(key) + super()._put(key, value, sample) + + +class RemoteAutotuneCache(RedisRemoteCache): + pass + + +class RemoteBundledAutotuneCache(RedisRemoteCache): + pass + + +class RemoteFxGraphCache(RedisRemoteCache): + pass + + +class RemoteAOTAutogradCache(RedisRemoteCache): + pass + + +class RemoteDynamoPGOCache(RedisRemoteCache): + pass + + +def create_cache( + key: str, + is_fbcode: bool, + fb_cache_cls: str, + oss_cache_cls: str, +) -> Optional[RemoteCache[JsonDataTy]]: + try: + if is_fbcode: + import torch._inductor.fb.remote_cache + + cache_cls = getattr(torch._inductor.fb.remote_cache, fb_cache_cls) + return cache_cls(key) + else: + this_module = sys.modules[__name__] + + cache_cls = getattr(this_module, oss_cache_cls) + return cache_cls(key) + + except Exception: + log.warning("Unable to create a remote cache", exc_info=True) + return None + + +# Some simple stat capture +@dataclasses.dataclass +class _CacheStat: + miss: int = 0 + hit: int = 0 + put: int = 0 + exception: int = 0 + + def __str__(self) -> str: + return f"{{hit: {self.hit}, miss: {self.miss}, put: {self.put}, exception: {self.exception}}}" + + +class _CacheStats: + _stats: dict[str, _CacheStat] + + def __init__(self) -> None: + self._stats = collections.defaultdict(_CacheStat) + + def miss(self, name: str, count: int = 1) -> None: + self._stats[name].miss += count + + 
def hit(self, name: str, count: int = 1) -> None: + self._stats[name].hit += count + + def get(self, name: str, value: Optional[object]) -> None: + if value is None: + self.miss(name) + else: + self.hit(name) + + def put(self, name: str, count: int = 1) -> None: + self._stats[name].put += count + + def exception(self, name: str, count: int = 1) -> None: + self._stats[name].exception += count + + +cache_stats = _CacheStats() + + +@atexit.register +def dump_cache_stats() -> None: + if not log.isEnabledFor(logging.INFO): + return + + import io + + out = io.StringIO() + + if not cache_stats._stats: + print(" None", file=out) + else: + print(file=out) + for k, v in sorted(cache_stats._stats.items()): + print(f" {k}: {v}", file=out) + + log.info("Cache Metrics:%s", out.getvalue()) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/remote_gemm_autotune_cache.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/remote_gemm_autotune_cache.py new file mode 100644 index 0000000000000000000000000000000000000000..0ef026269b10c86d58f72e53e998af4ba59b13bf --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/remote_gemm_autotune_cache.py @@ -0,0 +1,20 @@ +import asyncio +from typing import TypeVar + +import torch._inductor.config as config +from torch._inductor import ir + + +_T = TypeVar("_T") + + +def gen_best_config(mat1: ir.StorageBox, mat2: ir.StorageBox) -> asyncio.Task[_T]: + """ + Generate the best GEMM autotune config for the given matrices. 
+ """ + if config.is_fbcode(): + from torch._inductor.fb.remote_gemm_autotune_cache import gen_best_config + + return gen_best_config(mat1, mat2) + else: + raise NotImplementedError("Function gen_best_config is not yet implemented") diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/rocm_multiarch_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/rocm_multiarch_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a1a6103e1091511121cc7612d5fd5d0a99993056 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/rocm_multiarch_utils.py @@ -0,0 +1,264 @@ +""" +ROCm Multi-Architecture Support Utilities +Compile LLVM IR to multi-arch bundles that HIP can load automatically. +""" + +import os +import subprocess +from typing import Optional + +import torch +from torch.utils.cpp_extension import _join_rocm_home, ROCM_HOME + + +def get_rocm_compiler() -> str: + """ + Get path to ROCm's clang compiler. + Uses PyTorch's ROCM_HOME detection. + + Returns: + Path to clang compiler + + Raises: + RuntimeError: If ROCm is not found + """ + if ROCM_HOME is None: + raise RuntimeError( + "ROCm installation not found. " + "PyTorch was not built with ROCm support or ROCM_HOME is not set." + ) + + # ROCm's clang is at /llvm/bin/clang + clang_path = _join_rocm_home("llvm", "bin", "clang") + + if not os.path.exists(clang_path): + raise RuntimeError( + f"ROCm clang not found at {clang_path}. ROCM_HOME is set to {ROCM_HOME}" + ) + + return clang_path + + +def get_rocm_bundler() -> str: + """ + Get path to clang-offload-bundler. + Uses PyTorch's ROCM_HOME detection. + + Returns: + Path to bundler + + Raises: + RuntimeError: If bundler is not found + """ + if ROCM_HOME is None: + raise RuntimeError( + "ROCm installation not found. " + "PyTorch was not built with ROCm support or ROCM_HOME is not set." 
+ ) + + # Bundler is at /llvm/bin/clang-offload-bundler + bundler_path = _join_rocm_home("llvm", "bin", "clang-offload-bundler") + + if not os.path.exists(bundler_path): + raise RuntimeError( + f"clang-offload-bundler not found at {bundler_path}. " + f"ROCM_HOME is set to {ROCM_HOME}" + ) + + return bundler_path + + +def get_rocm_target_archs() -> list[str]: + """ + Get target architectures from environment or config. + Returns: List of architecture strings (e.g., ['gfx90a', 'gfx942']) + """ + # Check PYTORCH_ROCM_ARCH environment variable + env_archs = os.environ.get("PYTORCH_ROCM_ARCH", "").strip() + if env_archs: + archs = [arch.strip() for arch in env_archs.replace(";", ",").split(",")] + archs = [arch for arch in archs if arch] + if archs: + return archs + + # Try to get from inductor config + try: + from torch._inductor import config + + if hasattr(config, "rocm") and hasattr(config.rocm, "target_archs"): + archs = config.rocm.target_archs + if archs: + return archs + + except Exception: + pass + + return torch.cuda.get_arch_list() + + +def compile_llvm_ir_to_code_object( + llvm_ir_path: str, output_path: str, target_arch: str +) -> bool: + """ + Compile unbundled LLVM IR to a single-arch code object. 
+ + Args: + llvm_ir_path: Path to .ll file + output_path: Where to write .hsaco file + target_arch: Target architecture (e.g., 'gfx90a') + + Returns: + True if successful + """ + if not os.path.exists(llvm_ir_path): + return False + + os.makedirs(os.path.dirname(output_path), exist_ok=True) + + try: + clang = get_rocm_compiler() + except RuntimeError: + return False + + # Using clang and not hipcc since we are not compiling source code + # Instead we use the LLVM IR (.ll) provided by triton + cmd = [ + clang, + "-target", + "amdgcn-amd-amdhsa", + f"-mcpu={target_arch}", + llvm_ir_path, + "-o", + output_path, + ] + + try: + subprocess.run(cmd, capture_output=True, text=True, check=True) + + if not os.path.exists(output_path): + return False + + return True + + except subprocess.CalledProcessError: + return False + + +def create_multiarch_bundle(code_objects: dict, output_bundle_path: str) -> bool: + """ + Bundle multiple architecture code objects into a single multi-arch bundle. + + Uses clang-offload-bundler to create a fat binary that HIP runtime can load. + The runtime automatically selects the correct architecture at load time. + + Args: + code_objects: Dict mapping architecture to code object path + output_bundle_path: Path for output bundle + + Returns: + True if successful + """ + if not code_objects: + return False + + os.makedirs(os.path.dirname(output_bundle_path), exist_ok=True) + + try: + bundler = get_rocm_bundler() + except RuntimeError: + return False + + # Build targets and inputs lists for clang-offload-bundler + targets = ["host-x86_64-unknown-linux-gnu"] + + # We include a dummy host entry to satisfy the bundler format + inputs = ["/dev/null"] + + for arch, path in sorted(code_objects.items()): + if not os.path.exists(path): + continue + # hipv4 = HIP version 4 code object format + # amdgcn-amd-amdhsa = target triple for ROCm/HSA runtime + # arch = specific GPU (gfx90a, gfx942, etc.) 
+ targets.append(f"hipv4-amdgcn-amd-amdhsa--{arch}") + inputs.append(path) + + if len(inputs) == 1: # Only host, no device code + return False + + cmd = [ + bundler, + "--type=o", + # CRITICAL: HIP runtime expects 4096-byte alignment for loading bundles + # Without this, hipModuleLoadData gives segmentation fault + "-bundle-align=4096", # CRITICAL: Required by HIP runtime! + f"--targets={','.join(targets)}", + ] + + for input_file in inputs: + cmd.append(f"--input={input_file}") + + cmd.append(f"--output={output_bundle_path}") + + try: + subprocess.run(cmd, capture_output=True, text=True, check=True) + + if not os.path.exists(output_bundle_path): + return False + + return True + + except subprocess.CalledProcessError: + return False + + +def compile_multiarch_bundle_from_llvm_ir( + llvm_ir_path: str, output_bundle_path: str, target_archs: Optional[list[str]] = None +) -> bool: + """ + Complete workflow: LLVM IR → multiple code objects → bundle. + + This is the main entry point for multi-arch compilation. + + Args: + llvm_ir_path: Path to .ll file + output_bundle_path: Where to write bundle + target_archs: Optional list of architectures + + Returns: + True if successful + """ + if target_archs is None: + # Get architectures from environment variable or config + target_archs = get_rocm_target_archs() + + # Step 1: Compile LLVM IR to code object for each architecture + code_objects = {} + temp_dir = os.path.dirname(output_bundle_path) + kernel_name = os.path.splitext(os.path.basename(llvm_ir_path))[0] + + for arch in target_archs: + # Create temporary single-architecture code object + # Format: kernel_name_gfx90a.co, kernel_name_gfx942.co, etc. 
+ co_path = os.path.join(temp_dir, f"{kernel_name}_{arch}.co") + + # Compile with clang backend: LLVM IR → GPU machine code + if compile_llvm_ir_to_code_object(llvm_ir_path, co_path, arch): + code_objects[arch] = co_path + + if not code_objects: + return False + + # Step 2: Bundle all code objects together + # Uses clang-offload-bundler to create fat binary + success = create_multiarch_bundle(code_objects, output_bundle_path) + + # Step 3: Clean up temporary single-arch code objects + # The bundle contains all the code, so intermediates are no longer needed + for co_path in code_objects.values(): + try: + os.remove(co_path) + except Exception: + pass + + return success diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/scheduler.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..47323242901e9f681dade60625700f35ddf86953 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/scheduler.py @@ -0,0 +1,6576 @@ +from __future__ import annotations + +import collections +import contextlib +import dataclasses +import functools +import inspect +import itertools +import logging +import math +import operator +import os +import pprint +import textwrap +import traceback +import typing +from collections import Counter, defaultdict +from typing import Any, Generic, Optional, TYPE_CHECKING, TypeAlias, TypeVar, Union +from typing_extensions import ParamSpec + +from torch.utils._ordered_set import OrderedSet + +from .ir import ComputedBuffer + + +if TYPE_CHECKING: + from collections.abc import Callable, Iterator, Sequence + from types import ModuleType + +import sympy + +import torch +import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools +import torch.utils._pytree as pytree +from torch._dynamo.utils import counters, dynamo_timed +from torch._inductor.codecache 
import LambdaFuture, PyCodeCache +from torch._inductor.ir import TritonTemplateCallerBase +from torch._inductor.metrics import get_metric_table, is_metric_table_enabled +from torch.fx.experimental.symbolic_shapes import free_symbols +from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT +from torch.utils._triton import has_triton + +from . import comms, config, config_comms, dependencies, ir, metrics +from .analyze_preserves_zero_mask import can_codegen_without_upcasts +from .codegen.common import BackendFeature, get_scheduling_for_device, Kernel +from .comm_analysis import ( + estimate_nccl_collective_runtime, + estimate_nccl_collective_runtime_nccl_estimator, +) +from .dependencies import Dep, MemoryDep, StarDep, WeakDep +from .exc import GPUTooOldForTriton, TritonMissing +from .fx_utils import count_flops_fx +from .ir import ( + assign_origin_node, + get_device_type, + GraphPartitionSignature, + MultiOutput, + MultiOutputLayout, + NoneLayout, +) +from .loop_body import LoopBody +from .memory import MemoryPlanningInfoForBuffer, MemoryPlanningInfoForNode +from .runtime.hints import ReductionHint +from .runtime.runtime_utils import green_text, red_text +from .sizevars import SimplifyIndexing +from .utils import ( + _unstable_customized_partition_wrapper, + cache_on_self, + cmp, + device_need_guard, + get_current_backend, + get_device_tflops, + get_dtype_size, + get_gpu_dram_gbps, + GraphPartitionMap, + IndentedBuffer, + is_collective, + is_cudagraph_unsafe_op, + is_gpu, + is_multi_outputs_template, + is_output_of_multi_outputs_template, + is_wait, + maybe_log_cudagraph_partition, + sympy_product, +) +from .virtualized import V + + +log = logging.getLogger(__name__) +fusion_log = torch._logging.getArtifactLogger(__name__, "fusion") +loop_ordering_log = torch._logging.getArtifactLogger(__name__, "loop_ordering") +compute_dependencies_log = torch._logging.getArtifactLogger( + __name__, "compute_dependencies" +) + +PartitionType: TypeAlias = 
list["BaseSchedulerNode"] +_T = TypeVar("_T") +_P = ParamSpec("_P") + + +class MixOrderReduction: + """ + This class contains utility functions to decide if we should fuse reductions + reducing across different dimensions of the same input tensor. + """ + + @staticmethod + def is_split_reduction(node: BaseSchedulerNode) -> bool: + return node.is_reduction() and all( + subnode.node._split_size is not None + for subnode in node.get_nodes() + if isinstance(subnode, SchedulerNode) + and subnode.is_reduction() + and isinstance(subnode.node, ComputedBuffer) + ) + + @classmethod + def get_numel_rnumel(cls, node: BaseSchedulerNode) -> tuple[sympy.Expr, sympy.Expr]: + if cls.is_split_reduction(node): + xnumel = None + rnumel = None + for subnode in node.get_nodes(): + if not ( + isinstance(subnode, SchedulerNode) + and subnode.is_reduction() + and isinstance(subnode.node, ComputedBuffer) + ): + continue + + assert subnode.node._original_ranges is not None + curxnumel = V.graph.sizevars.simplify( + sympy_product(subnode.node._original_ranges) + ) + assert subnode.node._original_reduction_ranges is not None + currnumel = V.graph.sizevars.simplify( + sympy_product(subnode.node._original_reduction_ranges) + ) + + if xnumel is None: + xnumel = curxnumel + rnumel = currnumel + else: + assert V.graph.sizevars.statically_known_equals( + xnumel, curxnumel + ), f"{xnumel} v.s. {curxnumel}" + assert V.graph.sizevars.statically_known_equals( + rnumel, currnumel + ), f"{rnumel} v.s. 
{currnumel}" + + assert xnumel is not None + return (xnumel, rnumel) + else: + return node.group[1] # type: ignore[return-value] + + @classmethod + def has_mix_reduction_orders( + cls, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> bool: + g1 = cls.get_numel_rnumel(node1) + g2 = cls.get_numel_rnumel(node2) + + if len(g1) != 2 or len(g2) != 2 or g1 == g2: + return False + + return tuple(g1) == tuple(reversed(g2)) + + @classmethod + def _is_full_access(cls, buf: str, node: BaseSchedulerNode) -> bool: + """ + The access to 'buf' is not a broadcast access. + """ + found_dep = None + for dep in node.read_writes.reads: + if isinstance(dep, MemoryDep) and dep.name == buf: + found_dep = dep + break + + if not found_dep: + return False + + index = found_dep.index + var_ranges = node.read_writes.var_ranges + + if not var_ranges: + assert isinstance(node, FusedSchedulerNode), f"{type(node)}" + var_ranges = node.snodes[0].read_writes.var_ranges + + assert var_ranges + if not (OrderedSet(var_ranges) - OrderedSet(index.free_symbols)): + return True + + # cases that happen after merging loops: + # MemoryDep('arg0_1', c0, {c0: 25165824})]) + # var_ranges={d0: 32768, d1: 768} + if V.graph.sizevars.statically_known_equals( + sympy_product(found_dep.size), sympy_product(var_ranges.values()) + ): + return True + return False + + @classmethod + def get_common_read( + cls, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> list[str]: + out = [] + common_reads = node1.used_buffer_names() & node2.used_buffer_names() + for buf in common_reads: + if cls._is_full_access(buf, node1) and cls._is_full_access(buf, node2): + out.append(buf) + + return out + + @classmethod + def has_common_read( + cls, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> bool: + return len(cls.get_common_read(node1, node2)) > 0 + + @classmethod + def get_numel(cls, node: BaseSchedulerNode) -> int: + g1 = cls.get_numel_rnumel(node) + return V.graph.sizevars.size_hint(g1[0] * g1[1], 
fallback=0) + + @classmethod + def get_fusion_score( + cls, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> int: + # node2 is ignored for now + return cls.get_numel(node1) + + # TODO add a cache + @classmethod + def can_fuse(cls, node1: BaseSchedulerNode, node2: BaseSchedulerNode) -> bool: + """ + Check whether we can fuse two reductions with mix loop orders. + """ + if not config.triton.mix_order_reduction: + return False + + # TODO: Mix order reduction is not supported with cpp_wrapper yet + if V.graph.cpp_wrapper: + return False + + if not node1.is_gpu() or not node2.is_gpu(): + return False + device_type = node1.get_device().type # type: ignore[union-attr] + if ( + device_type not in ("cuda", "xpu") + or get_current_backend(device_type) != "triton" + ): + return False + if not node1.is_reduction() or not node2.is_reduction(): + return False + + if (node1.ancestors & node2.get_operation_names()) or ( + node2.ancestors & node1.get_operation_names() + ): + # the two reductions have no producer/consumer relationship + return False + + # check for mix reduction orders + if not cls.has_mix_reduction_orders(node1, node2): + return False + + # check common buffer accesses + common_reads = MixOrderReduction.get_common_read(node1, node2) + if len(common_reads) == 0: + return False + + g1 = cls.get_numel_rnumel(node1) + nrow = sympy.Max(g1[0], g1[1]) + ncol = sympy.Min(g1[0], g1[1]) + + # the fused version has worse perf than non-fused version for + # small workload. When a workload is small enough, data can be + # fully cached by L2 + size_thres = 5 * 2**20 + + # Call evaluate_expr rather than statically_known_geq since nrow can + # have dynamic shape in real models. + # Don't use hint directly since hint can be non-representative. 
+ if not V.graph.sizevars.evaluate_expr(sympy.Ge(nrow * ncol, size_thres)): + return False + + # We require more more row than columns since + # 1, we prefer doing persistent reduction for each row + # 2, we will split the reduction across the rows + if not V.graph.sizevars.evaluate_expr(sympy.Ge(nrow, ncol * 2)): + return False + + # When nrow is small, ncol should also be small (due to the check + # above). Thus the entire tensor should be well cached in L2. + # Mix order reduction is less beneficial. + if not V.graph.sizevars.evaluate_expr(sympy.Ge(nrow, 4096)): + return False + + contiguous_node, other_node = ( + (node1, node2) + if V.graph.sizevars.evaluate_expr(sympy.Eq(g1[1], ncol)) + else (node2, node1) + ) + + # We previously only check the contiguous_node has contiguous + # access to common_reads. But that turns out to be not enough. + # The contiguous node may access a buffer that's node use by + # other_ndoe. If that ascess is non-contiugous, generating + # mix-order reduction can be inefficient especially when we + # force XBLOCK to be 1 + # if not all( + # cls.is_contiguous_load(buf, contiguous_node) for buf in common_reads + # ): + # return False + if not all( + cls.is_contiguous_load(dep.name, contiguous_node) + for dep in contiguous_node.read_writes.reads + ): + return False + + # Make sure a persistent reduction will be generated + if any( + subnode.node.data.reduction_hint # type: ignore[union-attr] + not in ( + ReductionHint.INNER, + ReductionHint.DEFAULT, + ) + for subnode in contiguous_node.get_nodes() + if subnode.is_reduction() + ): + return False + + # rnumel so large that we will not generated persistent reduction + # We don't see real use cases with dynamic ncol. But if we do, + # we should call evaluete_expr here which adds guards. + if not V.graph.sizevars.statically_known_leq(ncol, 1024 * 16): + return False + + # Other reduction types like max/min is not supported yet. + # There are no real use case as well. 
+ out = all( + subnode.node.get_reduction_type() # type: ignore[union-attr] + in { + "sum", + "prod", + } + for subnode in other_node.get_nodes() + if subnode.is_reduction() + ) + return out + + @classmethod + def are_mix_order_reductions( + cls, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> bool: + return cls.can_fuse(node1, node2) + + @classmethod + def is_contiguous_load(cls, buf: str, parent_node: BaseSchedulerNode) -> bool: + from torch._inductor.loop_body import MemoryUsageType + + for node in parent_node.get_nodes(): + assert isinstance(node, SchedulerNode) + loop_body = node._body + entries = loop_body.memory_usage[MemoryUsageType.LOAD] + index_names = [e.index_name for e in entries if e.buffer_name == buf] + + if len(index_names) == 0: + continue + + # there can be multiple index_names some times + for index_name in index_names: + index_expr = loop_body.indexing_exprs[index_name] + var_ranges = loop_body.var_ranges + + # assumes the final symbol is for reduction + var_symbols = list(var_ranges.keys()) + stride_vars = V.graph.sizevars.stride_vars( + index_expr, + var_symbols, + var_symbols, + ) + + # stride==0 means a broadcast + if not (stride_vars[-1] == 0 or stride_vars[-1] == 1): + return False + return True + + +@dataclasses.dataclass +class SchedulerBuffer: + scheduler: Scheduler + node: ir.Buffer + defining_op: Optional[BaseSchedulerNode] + users: list[NodeUser] = dataclasses.field(default_factory=list) + mpi_buffer: MemoryPlanningInfoForBuffer = dataclasses.field( + default_factory=MemoryPlanningInfoForBuffer + ) + + def defining_op_name(self) -> str: + op = self.defining_op + assert op is not None + return op.get_name() + + def __hash__(self) -> int: + return hash(self.node.name) + + def debug_str(self) -> str: + result = IndentedBuffer() + name = self.get_name() + result.writeline(f"{name}: {type(self.node).__name__}") + result.writeline(f"{name}.layout = {self.node.layout}") + if self.get_aliases(): + result.writeline(f"{name}.aliases 
= {pformat(self.get_aliases())}") + if self.get_mutations(): + result.writeline(f"{name}.mutations = {pformat(self.get_mutations())}") + + if len(self.users) <= 1: + result.writeline(f"{name}.users = {self.users}") + else: + result.writeline(f"{name}.users = [") + with result.indent(1): + for user in self.users: + result.writeline(f"{user},") + result.writeline("]") + return result.getrawvalue() + + def get_name(self) -> str: + return self.node.get_name() + + def allocate(self) -> None: + assert self.node is not None + if not self.node.should_allocate(): + return + + if ( + self.node.get_inputs_that_alias_output() + or self.node.get_mutation_names() + or isinstance(self.node.get_output_spec(), ir.CommBufferLayout) + ): + V.graph.wrapper_code.codegen_allocation(self.node) + return + + # hacky check for if V.kernel is a real kernel or NullHandler + if ( + hasattr(V.kernel, "args") + and self.get_name() in V.kernel.inplace_update_buffers + ): + input_buffer: Union[ir.DonatedBuffer, ir.Buffer] + input_buffer_name = V.kernel.inplace_update_buffers[self.get_name()] + if input_buffer_name in self.scheduler.name_to_donated_buffer: + input_buffer = self.scheduler.name_to_donated_buffer[ + input_buffer_name + ].node + else: + input_buffer = self.scheduler.name_to_buf[input_buffer_name].node + V.graph.wrapper_code.codegen_inplace_reuse( + input_buffer, + self.node, + ) + else: + V.graph.wrapper_code.codegen_allocation(self.node) + + def can_free(self) -> bool: + # There's no real allocated buffer, no need to free it + assert self.node is not None + if isinstance(self.node.layout, ir.NoneLayout) or is_multi_outputs_template( + self.node + ): + return False + for use in self.users: + if isinstance(use.node, OutputNode): + return False + return True + + def set_users(self, users: list[NodeUser]) -> None: + # deduplicate + result: dict[int, NodeUser] = {} + for use in users: + if id(use.node) in result: + result[id(use.node)] = use.merge(result[id(use.node)]) + else: + 
result[id(use.node)] = use + self.users = list(result.values()) + + def get_aliases(self) -> Sequence[str]: + assert self.node is not None + return self.node.get_inputs_that_alias_output() + + def get_mutations(self) -> Sequence[str]: + assert self.node is not None + return self.node.get_mutation_names() + + def get_device(self) -> Optional[torch.device]: + return self.node.get_output_spec().get_device() + + +@dataclasses.dataclass +class SchedulerDonatedBuffer(SchedulerBuffer): + defining_op: Optional[BaseSchedulerNode] = None + + +class BaseSchedulerNode: + ancestors: OrderedSet[str] + group: tuple[torch.device, tuple[tuple[sympy.Expr, ...], ...]] + last_usage: OrderedSet[str] + # .min_order and .max_order are only relevant for "grouped" nodes such as FusedSchedulerNode. + # e.g. if the FusedSchedulerNode includes nodes (op_1, op_2, op_3), and op_X is X-th node + # in `self.scheduler.nodes`, then for this FusedSchedulerNode, .min_order is 1 and .max_order is 3. + # For non-"grouped" nodes (i.e. regular SchedulerNode), + # .min_order = .max_order = X if this node is X-th node in `self.scheduler.nodes`. 
+ min_order: int + max_order: int + mpi_node: MemoryPlanningInfoForNode + mutation_renames: dict[str, str] + node: Optional[ir.Operation] = None + outputs: list[SchedulerBuffer] + outputs_by_name: dict[str, SchedulerBuffer] + override_estimated_runtime: Optional[float] = None + read_writes: dependencies.ReadWrites + unmet_dependencies: OrderedSet[Dep] + written: bool = False + + def __init__(self, scheduler: Scheduler) -> None: + self.scheduler: Scheduler = scheduler + self.debug_device_str: Callable[[BaseSchedulerNode], list[str]] = ( + lambda *args, **kwargs: [] + ) + + def _init_from_node(self, node: ir.Operation) -> None: + self.node = node + self.ancestors = OrderedSet() + self.last_usage = OrderedSet[ + str + ]() # buffers that won't be used after this kernel + self.written = False + self.outputs = [ + SchedulerBuffer( + scheduler=self.scheduler, + node=output, + defining_op=self, + ) + for output in node.get_outputs() + ] + self.outputs_by_name = {buf.get_name(): buf for buf in self.outputs} + + # mutation_renames for the current node. Due to potential + # more mutations happening later, this can be different + # to Scheduler.mutation_renames. Also this dict should be small + # since only mutation information relevant to the deps for this + # node is stored here. 
+ self.mutation_renames = {} + + def __repr__(self) -> str: + return f"{type(self).__name__}(name={self.get_name()!r})" + + def debug_str(self) -> str: + """Longer form printout for trace logs""" + name = self.get_name() + buf = IndentedBuffer() + buf.splice( + f"""\ +{name}: {type(self).__name__}({type(getattr(self, "node", None)).__name__}) +{name}.writes = {pformat(self.read_writes.writes)} +{name}.unmet_dependencies = {pformat(self.unmet_dependencies)} +{name}.met_dependencies = {pformat(self.read_writes.reads - self.unmet_dependencies)} +{name}.outputs = [ + """ + ) + with buf.indent(): + for out in self.get_outputs(): + buf.splice(out.debug_str()) + buf.writeline("]") + + try: + buf.splice(self.debug_str_extra()) + except Exception: + log.warning("Ignoring error in debug_str()", exc_info=True) + + return buf.getrawvalue().rstrip() + + def debug_str_extra(self) -> str: + return "" + + def _debug_str_for_device(self) -> list[str]: + return self.debug_device_str(self) + + def debug_str_short(self) -> str: + maybe_data = getattr(self.node, "data", None) + data_str = "" + if isinstance(maybe_data, torch._inductor.ir.Pointwise): + data_str = ", " + maybe_data.str_helper( + [maybe_data.get_size()], shorten=False, multiline=False + ) + elif isinstance(maybe_data, torch._inductor.ir.Reduction): + data_str = ", " + maybe_data.str_helper( + [maybe_data.get_reduction_size(), maybe_data.get_reduction_type()], + shorten=False, + multiline=False, + ) + return f"{self}{data_str}" + + def log_details(self) -> None: + log.info( + "%s: unmet_dependencies = %s, writes = %s", + self, + self.unmet_dependencies, + self.read_writes.writes, + ) + + def reorder_loops_by_dep_pair( + self, self_dep: MemoryDep, other_dep: MemoryDep + ) -> bool: + return False + + def update_mutated_names(self, renames: dict[str, str]) -> None: + self.mutation_renames = { + name: renames[name] + for name in (dep.name for dep in self.read_writes.reads_and_writes()) + if name in renames + } + 
self.set_read_writes(self.read_writes.rename(self.mutation_renames)) + + def add_fake_dep(self, dep: Dep) -> None: + self.set_read_writes(self.read_writes.with_read(dep)) + + def has_aliasing_or_mutation(self) -> bool: + return any( + buf.get_aliases() or buf.get_mutations() for buf in self.get_outputs() + ) + + def set_read_writes(self, rw: dependencies.ReadWrites) -> None: + self.read_writes = rw + self.unmet_dependencies = self.read_writes.reads + self.prune_deps() + + def set_last_usage( + self, future_used_buffers: OrderedSet[str], mutation_real_name: dict[str, str] + ) -> None: + used_buffers = self.used_or_aliased_buffer_names() + used_buffers = OrderedSet(mutation_real_name.get(k, k) for k in used_buffers) + self.last_usage = used_buffers - future_used_buffers + + def mark_run(self) -> None: + for buf in self.outputs: + buf.allocate() + + def used_buffer_names(self) -> OrderedSet[str]: + return OrderedSet( + dep.name + for dep in itertools.chain(self.read_writes.reads, self.read_writes.writes) + ) + + def used_or_aliased_buffer_names(self) -> OrderedSet[str]: + """ + Returns buffer names used by this node, including aliases. + + Note: is_fake WeakDeps are excluded since they are purely for ordering + and should not affect buffer lifetime. 
+ """ + used_names: OrderedSet[str] = OrderedSet() + + deps = [ + dep.name + for dep in itertools.chain(self.read_writes.reads, self.read_writes.writes) + if not (isinstance(dep, WeakDep) and dep.is_fake) + ] + while len(deps) > 0: + dep = deps.pop() + used_names.add(dep) + if V.graph.name_to_buffer.get(dep): + deps.extend( + alias + for alias in V.graph.name_to_buffer[ + dep + ].get_inputs_that_alias_output() + if alias not in used_names + ) + return used_names + + def prune_deps(self) -> None: + self.unmet_dependencies = OrderedSet( + dep + for dep in self.unmet_dependencies + if dep.name not in self.scheduler.available_buffer_names + ) + + def prune_weak_deps(self) -> None: + # Prune weak dependencies on operations that have been removed + def should_prune(dep: Dep) -> bool: + if not isinstance(dep, WeakDep): + return False + op_name = self.scheduler.name_to_buf[dep.name].defining_op_name() + return op_name in V.graph.removed_operations + + to_remove = OrderedSet( + dep for dep in self.read_writes.reads if should_prune(dep) + ) + self.set_read_writes(self.read_writes.remove_reads(to_remove)) + + def prune_redundant_deps( + self, name_to_fused_node: dict[str, BaseSchedulerNode] + ) -> None: + _prune_redundant_deps(self, name_to_fused_node, self.scheduler.name_to_buf) + + def get_name(self) -> str: + assert self.node is not None + return self.node.get_operation_name() + + def get_first_name(self) -> str: + return self.get_name() + + @cache_on_self + def get_operation_names(self) -> OrderedSet[str]: + return OrderedSet(node.get_name() for node in self.get_nodes()) + + @cache_on_self + def get_buffer_names(self) -> OrderedSet[str]: + return OrderedSet(out.get_name() for out in self.outputs) + + @cache_on_self + def can_codegen_in_low_precision(self) -> bool: + return all( + isinstance(n, SchedulerNode) + and can_codegen_without_upcasts(n, disallow_fp32_ops=True) + for n in self.get_nodes() + ) + + @cache_on_self + def can_codegen_without_upcasts(self) -> bool: + 
return all( + isinstance(n, SchedulerNode) and can_codegen_without_upcasts(n) + for n in self.get_nodes() + ) + + def get_nodes(self) -> Sequence[BaseSchedulerNode]: + return [self] + + def get_outputs(self) -> Sequence[SchedulerBuffer]: + return self.outputs + + def get_output(self, buf_name: str) -> SchedulerBuffer: + return self.outputs_by_name[buf_name] + + def get_device(self) -> Optional[torch.device]: + assert self.node is not None + return self.node.get_device() + + def is_cpu(self) -> bool: + device = self.get_device() + return device is not None and device.type == "cpu" + + def is_gpu(self) -> bool: + device = self.get_device() + return device is not None and is_gpu(device.type) + + def is_reduction(self) -> bool: + return False + + def is_native_matmul(self) -> bool: + return False + + def is_split_scan(self) -> bool: + return False + + def is_template(self) -> bool: + return False + + def is_extern(self) -> bool: + return False + + def is_foreach(self) -> bool: + return False + + def can_inplace(self, read_dep: dependencies.Dep) -> bool: + return False + + def has_side_effects(self) -> bool: + return False + + def decide_inplace_update(self) -> None: + """ + Decide if there should be inplace updates for the node + and record the decision in the active kernel. 
+ """ + from .codegen.wrapper import can_match_buffer_size + + if not ( + isinstance(self, SchedulerNode) + and config.inplace_buffers + and V.graph.has_feature(self.get_device(), BackendFeature.INPLACE_BUFFERS) + and ( + not isinstance(V.kernel, torch._inductor.codegen.simd.SIMDKernel) + or getattr(V.kernel, "mutations", None) is not None + ) + # hacky check for if V.kernel is a real kernel or NullHandler + and hasattr(V.kernel, "args") + ): + return + + # NOTE remove V.graph.removed_operations once deps issue is fixed + inconsequential_nodes = ( + self.ancestors + | V.graph.removed_operations + | self.scheduler.completed_operations + ) + + def single_index_in_fused_node(buf_to_be_inplaced: SchedulerBuffer) -> bool: + # Inside of NodeUser, we track that the read and write are equivalent + # before deciding if the use can be inplace. + # But if that use is fused into a larger kernel, we need to check equivalence + # of other accesses in fused scheduler node as well. + fused_node = buf_to_be_inplaced.scheduler.get_fused_node(self) + buf_name = buf_to_be_inplaced.get_name() + # Dedup read/writes with equivalent indices + # TODO - would be nice if we could just cache accesses on ReadWrites, + # and enforce variant that this class & members are functional.. 
+ deps: OrderedSet[Dep] = OrderedSet() + for user in buf_to_be_inplaced.users: + user_node = user.node + if not isinstance(user_node, BaseSchedulerNode): + continue + + if ( + user_node.get_first_name() + not in buf_to_be_inplaced.scheduler.name_to_fused_node + or buf_to_be_inplaced.scheduler.get_fused_node(user_node) + is not fused_node + ): + continue + + deps |= ( + o + for o in user_node.read_writes.reads_and_writes() + if o.name == buf_name + ) + if len(deps) > 1: + return False + + return True + + for buf in self.get_outputs(): + buf_node = buf.node + assert buf_node is not None + if ( + not buf_node.should_allocate() + or buf_node.get_inputs_that_alias_output() + or buf_node.get_mutation_names() + or buf.get_name() in V.graph.removed_buffers + ): + continue + + for read in self.read_writes.reads: + input_buf: Optional[Union[SchedulerBuffer, SchedulerDonatedBuffer]] + if read.name in self.scheduler.name_to_donated_buffer: + input_buf = self.scheduler.name_to_donated_buffer[read.name] + else: + input_buf = self.scheduler.name_to_buf.get(read.name) + + if ( + input_buf + and V.graph.wrapper_code.can_reuse(input_buf, self) + and not isinstance(input_buf.defining_op, NopKernelSchedulerNode) + ): + assert input_buf.users is not None + remaining_uses = [ + x + for x in input_buf.users + if x.node.get_name() not in inconsequential_nodes + ] + if ( + len(remaining_uses) == 1 + and remaining_uses[0].can_inplace + and remaining_uses[0].node is self + and input_buf.node is not None + and not isinstance( + input_buf.node.get_output_spec(), + ( + ir.NoneLayout, + ir.MultiOutputLayout, + ir.MutationLayoutSHOULDREMOVE, + ), + ) + and not ( + input_buf.defining_op + and isinstance( + input_buf.defining_op.node, + (ir.FallbackKernel, ir.MultiOutput), + ) + and len(input_buf.node.get_inputs_that_alias_output()) > 0 + ) + and can_match_buffer_size(input_buf.node, buf.node) + and single_index_in_fused_node(input_buf) + ): + # if there isn't a triton kernel, then we don't need to 
call triton-specific things. + # but TODO this might be a convenient place to signal to the Collective kernels to inplace + # (and, can we make "kernel" less generic of a name?) + V.kernel.args.make_inplace(input_buf.get_name(), buf.get_name()) + # mutations not tracked in cpp kernels + if isinstance( + V.kernel, torch._inductor.codegen.simd.SIMDKernel + ): + V.kernel.mutations.add(input_buf.get_name()) + V.kernel.mutations.add(buf.get_name()) + + V.kernel.inplace_update_buffers[buf.get_name()] = ( + input_buf.get_name() + ) + break + + def codegen_originating_info( + self, buffer: IndentedBuffer, only_once: bool = True + ) -> None: + if not config.comment_origin: + return + + if only_once and self.written: + return + assert self.node is not None + origins = self.node.get_origins() + out_lines = [] + + for o in origins: + if o.op == "output": + # These are boring and samey + continue + + out_lines.append("") + # TODO(voz): Should the pragma be constant somewhere? + out_lines.append("#pragma CMT ORIGIN:") + op_info_str = f"#pragma CMT {o.op} {o.target}" + if "seq_nr" in o.meta: + op_info_str = op_info_str + f" seq_nr:{o.meta['seq_nr']}" + out_lines.append(op_info_str) + if "stack_trace" in o.meta: + stack_trace = f"{o.meta['stack_trace']}" + stack_trace_last_line = stack_trace.rsplit("|", maxsplit=1)[-1] + out_lines.append( + "#pragma CMT " + + stack_trace_last_line.replace("{", "{{") + .replace("}", "}}") + .replace("\n", "\\") + .replace( + "\\", "\\\\" + ) # For windows safe path, avoid for example \x, \U. + ) + out_lines.append("#pragma CMT END ORIGIN") + out_lines.append("") + + if len(out_lines) == 0: + return + + # TODO(voz): Ostensibly, we should not need this. But there are cases where C++ codegen does + # not use BracesBuffer, so we have no good indicator of a C++ buffer atm. 
+ buffer.writelines(out_lines) + self.written = True + + @cache_on_self + def get_read_write_buffers_sizes(self) -> int: + return self.get_read_write_buffers_sizes_impl( + include_reads=True, include_writes=True + ) + + @cache_on_self + def get_read_buffer_sizes(self) -> int: + return self.get_read_write_buffers_sizes_impl( + include_reads=True, include_writes=False + ) + + @cache_on_self + def get_write_buffer_sizes(self) -> int: + return self.get_read_write_buffers_sizes_impl( + include_reads=False, include_writes=True + ) + + def get_read_write_buffers_sizes_impl( + self, include_reads: bool, include_writes: bool + ) -> int: + return sum( + self.get_read_write_buffer_accesses( + include_reads=include_reads, include_writes=include_writes + ).values(), + start=0, + ) + + def get_read_write_buffer_accesses( + self, include_reads: bool, include_writes: bool + ) -> dict[str, int]: + """ + Counting the number of bytes accessed for a kernel is + surprisingly tricky. In particular, there is a differentiation + between 'theoretical' memory accesses and practical memory + accesses. For example, a layernorm kernel may actually access an + input 3 times, but in theory, it only needs to access its input + once (and may be optimized to do so through say, persistent + reductions) + + Another example is that even though a buffer is passed in, we may + not access the entire buffer. This may occur if we are accessing + a slice of the buffer. Another tricky case is for indirect + indexing, where the amount of bytes accessed depends on the + values of the input. + + What this function aims to compute is the memory accesses for + worst-case inputs, best-case optimization. What this means is + that for each buffer we compute the amount of potential accesses in two ways and take the minimum. + + 1. Numel in ranges multiplied by number of deps the buffer has + 2. The buffer size + + Returns memory accesses per buffer. 
+ """ + if isinstance(self, NopKernelSchedulerNode): + return {} + if isinstance(self, ExternKernelSchedulerNode) and isinstance( + self.node, MultiOutput + ): + # todo: Calculate this - it's kinda annoying. + return {} + if ( + isinstance(self, ExternKernelSchedulerNode) + and isinstance(self.node, ir.FallbackKernel) + and self.node.op_overload + is torch._prims.rng_prims.graphsafe_run_with_rng_state + ): + return {} + + def try_size_hint(s: sympy.Expr) -> int: + return V.graph.sizevars.size_hint(s, fallback=0) + + if isinstance(self, SchedulerNode): + node_numel = try_size_hint( + sympy_product(self.get_ranges()[0]) + * sympy_product(self.get_ranges()[1]), + ) + else: + node_numel = int(1e9) + buf_accesses = collections.defaultdict(list) + + if include_reads: + for dep in self.read_writes.reads: + buf_accesses[dep.name].append(dep) + + if include_writes: + for dep in self.read_writes.writes: + buf_accesses[dep.name].append(dep) + + reads = ( + OrderedSet(dep.name for dep in self.read_writes.reads) + if include_reads + else OrderedSet() + ) + writes = ( + OrderedSet(dep.name for dep in self.read_writes.writes) + if include_writes + else OrderedSet() + ) + + def is_materialized(buf: str, snodes: Sequence[BaseSchedulerNode]) -> bool: + users = self.scheduler.name_to_buf[buf].users + buf_uses = OrderedSet(user.node for user in users) + return len(buf_uses - OrderedSet(snodes)) > 0 + + if isinstance(self, FusedSchedulerNode): + removed_buffers = OrderedSet( + dep for dep in writes if not is_materialized(dep, self.snodes) + ) + writes = writes - removed_buffers + reads = reads - removed_buffers + + buf_byte_accesses: dict[str, int] = {} + + for buf_name in reads | writes: + buf_accessed_elems = sum(node_numel for dep in buf_accesses[buf_name]) + buf: Union[ir.Buffer, ir.TensorBox, ir.TorchBindObject] + if buf_name in V.graph.name_to_buffer: + buf = V.graph.name_to_buffer[buf_name] + elif buf_name in V.graph.graph_inputs: + buf = V.graph.graph_inputs[buf_name] + else: + 
continue + + def get_buf_bytes( + buf: Optional[Union[ir.Buffer, ir.TensorBox, ir.TorchBindObject]], + ) -> int: + if not buf: + return 0 + + if isinstance(buf, ir.TorchBindObject): + return buf.get_buf_bytes() + elif isinstance(buf.layout, MultiOutputLayout): + # Kind of a lazy way to get the MultiOutput nodes corresponding to + # a MultiOutputLayout + users = self.scheduler.name_to_buf[buf.get_name()].users + tot = 0 + for user in users: + assert isinstance(user.node, BaseSchedulerNode) + if isinstance(user.node.node, MultiOutput): + for sched_buf in user.node.get_outputs(): + tot += get_buf_bytes(sched_buf.node) + else: + # Buf is a MultiOutputLayout but not all of its + # users are MultiOutputs... + # TODO: Figure out what's going on + return 0 + return tot + elif isinstance(buf.layout, ir.NoneLayout): + return sum( + get_buf_bytes(V.graph.get_buffer(mut_name)) + for mut_name in buf.get_mutation_names() + ) + else: + buf_elems = try_size_hint(sympy_product(buf.get_size())) + return get_dtype_size(buf.get_dtype()) * min( + buf_accessed_elems, buf_elems + ) + + buf_bytes = get_buf_bytes(buf) + if buf_name not in buf_byte_accesses: + buf_byte_accesses[buf_name] = buf_bytes + else: + buf_byte_accesses[buf_name] += buf_bytes + + return buf_byte_accesses + + @cache_on_self + def estimate_flops(self) -> int | None: + if self.node is None: + return None + fx_node = self.node.get_origin_node() + if fx_node is None: + return None + + flops = count_flops_fx(fx_node) + if flops is None: + return None + + resolved_flops = V.graph.sizevars.size_hint(flops, fallback=0) + counters["inductor"]["flop_count"] += resolved_flops + return resolved_flops + + def get_estimated_runtime(self) -> float: + if self.override_estimated_runtime is not None: + return self.override_estimated_runtime + + return self._get_estimated_runtime() + + @cache_on_self + def _get_estimated_runtime(self) -> float: + """ + Returns estimated op runtime in milliseconds (ms) + """ + buf = 
self.get_nodes()[0].get_outputs()[0] + layout = buf.node.get_output_spec() + if not is_gpu(get_device_type(layout)): + # default to no reordering based on runtime + return 0 + + # Collective kernels + if is_collective(self.node): + assert isinstance(self.node, ir.IRNode) + try: + if config_comms.runtime_estimations_use_nccl_lib_estimations: + cache_key = get_estimate_runtime_cache_key_from_snode(self) + cache = get_estimate_runtime_cache() + cache_val = cache.lookup(cache_key) + if cache_val is not None: + assert isinstance(cache_val, float) + return cache_val + + ms = estimate_nccl_collective_runtime_nccl_estimator(self) + if ms is None: + # NCCL estimations fail: fallback to in-tree algorithmic estimation. + ms = estimate_nccl_collective_runtime(self.node) + + cache.set_value(cache_key, value=ms) + return ms + return estimate_nccl_collective_runtime(self.node) + except ValueError as e: + # We don't know how to estimate runtime for this collective, + # falling back to 0 + log.info(e) # noqa: G200 + return 0 + except TypeError as e: + # this happens when the collective is not of type ir._CollectiveKernel + log.info(e) # noqa: G200 + return 0 + + elif is_wait(self.node): + # ir.Wait is only used for collective ops. + # The time needed for the collective op is already estimated and considered + # when we are processing the collective op IR node, so ir.Wait takes 0 time + # since it doesn't take extra time to get the result after the collective is completed. + return 0 + + ret = maybe_estimate_runtime_benchmark(self) + if ret is not None: + return ret + + dtype = buf.node.maybe_get_dtype() + try: + gpu_memory_bandwidth = get_gpu_dram_gbps() + gpu_flops = get_device_tflops(dtype) * 10**12 + # If cudaGetDeviceProperties returns 0 for gpu_memory_bandwidth or gpu_flops + # there is a chance to continue execution successfully. Otherwise, it would fail with + # ZeroDivisionError below. 
+ if gpu_memory_bandwidth <= 0: + raise AssertionError( + f"gpu_memory_bandwidth cannot be <= 0, but got {gpu_memory_bandwidth}" + ) + if gpu_flops <= 0: + raise AssertionError(f"gpu_flops cannot be <= 0, but got {gpu_flops}") + except Exception: + return 0 + + flops_est = self.estimate_flops() + + if flops_est == 0 or flops_est is None: + # no flops estimate, so fall back to memory estimate + ns = self.get_read_write_buffers_sizes() / gpu_memory_bandwidth + ms = ns / 1e6 + return ms + + # TODO(xmfan): find a better heuristic to model FLOPS/latency relationship + factor = 1.0 + counted_bytes = self.get_read_write_buffers_sizes() + counted_bytes = 0 if counted_bytes is None else counted_bytes + compute_time = (factor * flops_est / gpu_flops) * 1e9 + transfer_time = counted_bytes / gpu_memory_bandwidth + + # Return estimated runtime in milliseconds + ns = max(compute_time, transfer_time) + ms = ns / 1e6 + return ms + + def get_template_node(self) -> Optional[ir.TemplateBuffer]: + return None + + def get_template_node_or_throw(self) -> ir.TemplateBuffer: + template = self.get_template_node() + assert template is not None + return template + + @staticmethod + def get_prologue_template_epilogue( + nodes: list[BaseSchedulerNode], + ) -> tuple[list[BaseSchedulerNode], BaseSchedulerNode, list[BaseSchedulerNode]]: + """ + For the list of nodes, get the prologue, template, and epilogue + """ + template_index = next(i for i, n in enumerate(nodes) if n.is_template()) + + prologue = nodes[:template_index] + template_node = nodes[template_index] + epilogue = nodes[template_index + 1 :] + return prologue, template_node, epilogue + + +@functools.cache +def get_estimate_runtime_cache() -> torch._inductor.codecache.LocalCache: + return torch._inductor.codecache.LocalCache() + + +def get_estimate_runtime_cache_key_from_snode(snode: BaseSchedulerNode) -> str: + python_kernel_name = getattr(snode.node, "python_kernel_name", "") + args = snode.node.inputs # type: ignore[union-attr] + 
args = snode.node.fill_non_provided_args( # type: ignore[union-attr] + [*args, *snode.node.constant_args], # type: ignore[union-attr] + snode.node.kwargs, # type: ignore[union-attr] + ) + kwargs = snode.node.kwargs # type: ignore[union-attr] + flat_args, flat_args_pytree_spec = pytree.tree_flatten((args, kwargs)) + + def _is_tensor_ir(x) -> bool: # type: ignore[no-untyped-def] + return isinstance(x, ir.IRNode) and not isinstance(x, ir.GeneratorState) + + cache_key = str( + (python_kernel_name,) + + tuple(tuple(a.get_size()) if _is_tensor_ir(a) else None for a in flat_args) + ) + return cache_key + + +def _get_mm_like_fn(snode: BaseSchedulerNode) -> Optional[Callable[[Any], Any]]: + if not isinstance(snode, ExternKernelSchedulerNode): + return None + mms_fns = { + "extern_kernels.mm": torch.ops.aten.mm, + "extern_kernels.bmm": torch.ops.aten.bmm, + "extern_kernels.addmm": torch.ops.aten.addmm, + } + python_kernel_name = getattr(snode.node, "python_kernel_name", "") + if python_kernel_name not in mms_fns: + return None + if not isinstance(snode.node, ir.ExternKernel): + return None + return mms_fns[python_kernel_name] + + +def maybe_estimate_runtime_benchmark(snode: BaseSchedulerNode) -> Optional[float]: + bench_fn = None + args_kwargs_fn = None + if config.runtime_estimations_mms_benchmark: + mm_fn = _get_mm_like_fn(snode) + if mm_fn is None: + return None + bench_fn = mm_fn + # pyrefly: ignore [unbound-name] + args_kwargs_fn = lambda: snode_args_kwargs(snode) # noqa: E731 + else: + return None + + cache_key = get_estimate_runtime_cache_key_from_snode(snode) + cache = get_estimate_runtime_cache() + cache_val = cache.lookup(cache_key) + if cache_val is not None: + assert isinstance(cache_val, float) + return cache_val + + from .utils import snode_args_kwargs + + args, kwargs = args_kwargs_fn() + from torch._inductor.runtime.benchmarking import benchmarker + + ms = benchmarker.benchmark(bench_fn, args, kwargs) # type: ignore[arg-type] + + cache.set_value(cache_key, 
value=ms) + return ms + + +@dataclasses.dataclass(slots=True) +class WhyNoFuse: + name1: str + name2: str + reason: str + args: tuple[Any, ...] + + def __init__(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode) -> None: + self.name1 = node1.get_name() + self.name2 = node2.get_name() + + def __call__(self, reason: str, *args: Any) -> None: + self.reason = reason + self.args = args + fusion_log.debug(self) + + def __str__(self) -> str: + return f"cannot fuse {self.name1} with {self.name2}: " + ( + self.reason % self.args + ) + + +def pformat(obj: Any) -> str: + if isinstance(obj, (OrderedSet, set)): # noqa: set_linter + # pformat has trouble with sets of sympy exprs + obj = sorted(obj, key=str) + result = pprint.pformat(obj, indent=4) + if "\n" in result: + return f"\n{textwrap.indent(result, ' ' * 4)}" + return result + + +class OutputNode: + def __init__(self, dep: StarDep) -> None: + self.unmet_dependencies = OrderedSet([dep]) + + def is_reduction(self) -> bool: + return False + + def get_inputs_that_alias_output(self) -> Sequence[str]: + return () + + def get_name(self) -> str: + return "OUTPUT" + + __repr__ = get_name + + +def _prune_redundant_deps( + node: BaseSchedulerNode, + name_to_fused_node: dict[str, BaseSchedulerNode], + name_to_buf: dict[str, SchedulerBuffer], +) -> None: + """ + Prunes weakdeps intended for mutation ordering + on an upstream fused node if after fusion there is another dependency + on the fused upstream node, making the weakdep redundant + + In essence this enforces an ordering on fusions. As fusions occur, weakdeps will + be incrementally removed, enabling other fusions, ensuring they are fused in order. 
+ """ + name_to_dep_count: Counter[str] = collections.Counter() + + for dep in node.unmet_dependencies: + if not isinstance(dep, WeakDep): + op_name = name_to_buf[dep.name].defining_op_name() + name_to_dep_count[name_to_fused_node[op_name].get_name()] += 1 + + def should_prune(dep: Dep) -> bool: + if isinstance(dep, WeakDep): + op_name = name_to_buf[dep.name].defining_op_name() + is_redundant = name_to_dep_count[ + name_to_fused_node[op_name].get_name() + ] > 0 and node.scheduler.fusable_weak_dep( + dep, name_to_fused_node[op_name], node + ) + # These can occur because fused nodes always gather deps from their snodes + # If B has a weakdep on A + # B gets fused with C, then any time BC is fused, the weakdep will reappear + is_self_dep = name_to_fused_node[op_name] == node + return is_redundant or is_self_dep + else: + return False + + deps_to_prune = OrderedSet( + dep for dep in node.unmet_dependencies if should_prune(dep) + ) + + if deps_to_prune: + node.unmet_dependencies = node.unmet_dependencies - deps_to_prune + node.set_read_writes(node.read_writes.remove_reads(deps_to_prune)) + + +class ExternKernelSchedulerNode(BaseSchedulerNode): + def __init__(self, scheduler: Scheduler, node: ir.Operation) -> None: + super().__init__(scheduler) + self._init_from_node(node) + self.set_read_writes(node.get_read_writes()) + + def debug_str_extra(self) -> str: + return f"{self.get_name()}.node.kernel = {getattr(self.node, 'python_kernel_name', None)}" + + def is_extern(self) -> bool: + return True + + def has_side_effects(self) -> bool: + assert self.node is not None + return hasattr(self.node, "has_side_effects") and self.node.has_side_effects() + + +class NopKernelSchedulerNode(BaseSchedulerNode): + def __init__(self, scheduler: Scheduler, node: ir.Operation) -> None: + super().__init__(scheduler) + self._init_from_node(node) + self.set_read_writes(node.get_read_writes()) + + +class SchedulerNode(BaseSchedulerNode): + """ + A SchedulerNode is a node for scheduling that 
encapsulates either + a ComputedBuffer or a TemplateBuffer. + """ + + _sizes: tuple[Sequence[sympy.Expr], ...] + _body: LoopBody + + def __init__( + self, + scheduler: Scheduler, + node: Union[ir.ComputedBuffer, ir.TemplateBuffer], + ) -> None: + super().__init__(scheduler) + self._init_from_node(node) + self._compute_attrs() + + def _compute_attrs( + self, + extra_indexing_constraints: Optional[tuple[dict[Any, Any], list[Any]]] = None, + recompute_sizes_body_func: Optional[Callable[_P, _T]] = None, + ) -> None: + assert isinstance(self.node, (ir.ComputedBuffer, ir.TemplateBuffer)) + self._sizes, body = self.node.simplify_and_reorder( + extra_indexing_constraints=extra_indexing_constraints, + recompute_sizes_body_func=recompute_sizes_body_func, + ) + self._body = body # type: ignore[assignment] + + device = self.node.get_device_or_error() + group_fn = self.scheduler.get_backend(device).group_fn + self.group = (device, group_fn(self._sizes)) + + # Don't normalize since normalization will merge loops which + # makes it hard to decide new loop orders. + should_normalize = not config.loop_ordering_after_fusion or not is_gpu( + device.type + ) + + if isinstance(self.node, ir.TemplateBuffer): + self.set_read_writes( + self.node.extract_read_writes(normalize=should_normalize) + ) + else: + self.set_read_writes( + dependencies.extract_read_writes( + self._body, *self._sizes, normalize=should_normalize + ) + ) + + def recompute_size_and_body( + self, + extra_indexing_constraints: Optional[tuple[dict[Any, Any], list[Any]]] = None, + recompute_sizes_body_func: Optional[Callable[..., Any]] = None, + ) -> None: + self._compute_attrs( + extra_indexing_constraints=extra_indexing_constraints, + recompute_sizes_body_func=recompute_sizes_body_func, + ) + + def refresh_dependencies( + self, normalize: bool, need_clear_tiling_cache: bool + ) -> None: + # Fake dependencies are added manually. They can not be analyzed from + # extract_read_writes. Find them out and apply manually. 
+ fake_deps: OrderedSet[Dep] = OrderedSet( + dep for dep in self.read_writes.reads if isinstance(dep, (WeakDep, StarDep)) + ) + + # don't normalize since the loop order may need to be further changed + # later + self.set_read_writes( + dependencies.extract_read_writes( + self._body, *self._sizes, normalize=normalize + ) + .with_read(fake_deps) + .rename(self.mutation_renames) + ) + + self.pointwise_read_writes.clear_cache(self) + + if need_clear_tiling_cache: + from .codegen.simd import SIMDScheduling + + # TODO(shunting) if this cause compilation time increase when + # enabling LOAF by default, try just clearing the specific cache + # entry by using a customized cache implementation rather than + # lru_cache. + SIMDScheduling.candidate_tilings.cache_clear() + + def apply_new_loop_order(self, new_order: Sequence[int]) -> None: + self._body = self._body.reorder_iter_loops( + new_order, + ) + self._sizes = self._body.sizes + + self.refresh_dependencies(normalize=False, need_clear_tiling_cache=True) + + def swap_pw_red_dimension(self) -> None: + num_rdims = self._body.get_original_num_rdims() + num_pwdims = len(self._body.iter_vars) - num_rdims + pwdims = tuple(range(num_pwdims)) + rdims = tuple(range(num_pwdims, num_pwdims + num_rdims)) + + self.apply_new_loop_order(rdims + pwdims) + assert len(self.group[1]) == 2 + self.group = self.group[0], (self.group[1][1], self.group[1][0]) + + def extract_pw_from_reduction(self) -> BaseSchedulerNode: + self._body = self._body.extract_pw_from_reduction() + return self + + def cancel_reduction_split(self) -> None: + if not MixOrderReduction.is_split_reduction(self): + return + assert isinstance(self.node, ir.ComputedBuffer) + with self.node.with_original_inner_fn(): + self._compute_attrs() + + def expand_dimension_for_pointwise_node( + self, dimension: int, new_range: int + ) -> None: + assert isinstance(self.node, (ir.ComputedBuffer, ir.TemplateBuffer)) + + self._body = self._body.expand_dimension_for_pointwise_node( + 
dimension, new_range + ) + self._sizes = self._body.sizes + + device = self.node.get_device_or_error() + group_fn = self.scheduler.get_backend(device).group_fn + self.group = (device, group_fn(self._sizes)) + + # Need normalize the prefix name to facilitate finding common dependencies + self.refresh_dependencies(normalize=True, need_clear_tiling_cache=True) + + def merge_loops(self) -> None: + self._body = self._body.merge_loops() + self._sizes = self._body.sizes + + # merge_loops is called after loop reordering. + # We still need retain fake dependencies since codegen the + # estimated amount of memory access rely on them. + # + # Merge loops does not affect the tiling decision. So we + # don't need clear the tiling cache. + self.refresh_dependencies(normalize=True, need_clear_tiling_cache=False) + + def reorder_loops_by_dep_pair( + self, self_dep: MemoryDep, other_dep: MemoryDep + ) -> bool: + new_order = None + self_sizes = self._sizes[0] + if len(self_sizes) == self_dep.num_vars == other_dep.num_vars: + new_order = self_dep.decide_loop_order_to_match(other_dep) + + if new_order: + # pyrefly: ignore [bad-assignment] + metrics.num_loop_reordering += 1 + loop_ordering_log.debug( + "Reorder loops for %s with order %s", self.get_name(), new_order + ) + self.apply_new_loop_order(new_order) + return True + else: + loop_ordering_log.debug( + "Don't reordering %s because we can not decide the suitable loop order", + self.get_name(), + ) + return False + + def debug_str_extra(self) -> str: + name = self.get_name() + lines = [ + f"{name}.group.device = {self.group[0]}", + f"{name}.group.iteration = {self.group[1]}", + f"{name}.sizes = {self._sizes}", + ] + for dep in self.read_writes.reads_and_writes(): + if not isinstance(dep, WeakDep): + buf_name = dep.name + buf = V.graph.get_buffer(buf_name) + if not isinstance(buf, ir.TorchBindObject): + lines.append(f"{buf_name}_layout = {pformat(buf.layout)}") + if isinstance(self._body, LoopBody): + lines.append(f"class 
{name}_loop_body:") + lines.append(textwrap.indent(self._body.debug_str(), " ")) + + assert self.node is not None + lines.extend(self._debug_str_for_device()) + + return "\n".join(lines) + + def get_ranges(self) -> Sequence[Sequence[sympy.Expr]]: + return self._sizes + + def is_reduction(self) -> bool: + assert isinstance(self.node, (ir.ComputedBuffer, ir.TemplateBuffer)), ( + f"{type(self.node)=}" + ) + + # self._body containing partial accumulate means the reduction is + # converted to a pointwise node. Need this extra check since + # we change self._body but didn't change self.node (IRNode) + # when converting a reduction to a pointwise + return bool(self.node.get_reduction_type()) and ( + self._body is None or not self._body.has_partial_accumulate + ) + + def is_native_matmul(self) -> bool: + assert isinstance(self.node, ir.ComputedBuffer), f"{type(self.node)=}" + return self.node.get_reduction_type() == "dot" + + def is_split_scan(self) -> bool: + assert isinstance(self.node, (ir.ComputedBuffer, ir.TemplateBuffer)), ( + f"{type(self.node)=}" + ) + return isinstance(self.node, ir.ComputedBuffer) and isinstance( + self.node.data, ir.SplitScan + ) + + def is_template(self) -> bool: + return isinstance(self.node, ir.TemplateBuffer) + + def get_template_node(self) -> Optional[ir.TemplateBuffer]: + return self.node if isinstance(self.node, ir.TemplateBuffer) else None + + def run(self, *index_vars: Sequence[sympy.Expr]) -> None: + self.decide_inplace_update() + self.mark_run() + self.codegen(index_vars) + + def ranges_from_index_vars( + self, index_vars: Sequence[Sequence[sympy.Expr]] + ) -> dict[sympy.Expr, sympy.Expr]: + sizes = self._sizes + assert sum(map(len, sizes)) == sum(map(len, index_vars)) + var_ranges = dict( + zip( + itertools.chain.from_iterable(index_vars), + itertools.chain.from_iterable(sizes), + ) + ) + return var_ranges + + def codegen(self, index_vars: Sequence[Sequence[sympy.Expr]]) -> None: + """ + Generate code for this node using the provided 
index variables. + + This method sets up the appropriate context for code generation, including + simplifying indexing expressions based on the variable ranges, and then + calls the node's body function with the index variables. + + Args: + index_vars: A sequence of sequences of sympy expressions representing + the index variables for each dimension of the computation. + """ + var_ranges = self.ranges_from_index_vars(index_vars) + try: + with ( + V.set_ops_handler(SimplifyIndexing(V.get_ops_handler(), var_ranges)), + V.kernel.set_current_node(self), + ): + self._body(*index_vars) + except Exception: + log.fatal("Error in codegen for %s", self.node) + raise + + def pointwise_or_reduction_read_writes( + self, pointwise: bool = True + ) -> dependencies.ReadWrites: + """ + Get the memory dependencies in either the pointwise or the reduction axes. + """ + keep_sizes, ignore_sizes = self._sizes if pointwise else reversed(self._sizes) + return dependencies.extract_read_writes( + self._body, keep_sizes, hidden_args=[[sympy.S.Zero] * len(ignore_sizes)] + ) + + @cache_on_self + def pointwise_read_writes(self) -> dependencies.ReadWrites: + """ + Get the memory dependencies in the non-reduction axes. + """ + return self.pointwise_or_reduction_read_writes(pointwise=True) + + @cache_on_self + def reduction_read_writes(self) -> dependencies.ReadWrites: + """ + Get the memory dependencies in the reduction axes. 
+ """ + return self.pointwise_or_reduction_read_writes(pointwise=False) + + def can_inplace(self, read_dep: dependencies.Dep) -> bool: + if self.is_template(): + return False + if any(out.get_aliases() for out in self.get_outputs()): + return False + if len(self.read_writes.writes) == 1 and isinstance( + read_dep, dependencies.MemoryDep + ): + write_dep = next(iter(self.read_writes.writes)) + assert isinstance(write_dep, dependencies.MemoryDep), f"{type(write_dep)=}" + return read_dep.index == write_dep.index and read_dep.size == write_dep.size + return False + + @cache_on_self + def _get_atomic_add_buffers(self) -> OrderedSet[str]: + buffers_store_as_atomic_add: OrderedSet[str] = OrderedSet() + if isinstance(self._body, LoopBody): + for node in self._body.get_nodes(): + if ( + node.op == "call_method" + and node.target == "store" + and ( + ("mode" in node.kwargs and node.kwargs["mode"] == "atomic_add") + or (len(node.args) == 5 and node.args[4] == "atomic_add") + ) + ): + buffers_store_as_atomic_add.add( + node.kwargs["name"] + if "name" in node.kwargs + else (node.args[1] if len(node.args) >= 2 else "") + ) + return buffers_store_as_atomic_add + + @cache_on_self + def has_side_effects(self) -> bool: + # self._body is None sometimes that's why this check was added + if self._body is not None and self._body.has_op("device_assert_async"): + return True + return super().has_side_effects() + + +def refresh_group_node_dependencies( + group_snode: Union[FusedSchedulerNode, GroupedSchedulerNode], +) -> None: + snodes = group_snode.snodes + group_snode.set_read_writes( + dependencies.ReadWrites.merge_list([x.read_writes for x in snodes]) + ) + + group_snode.unmet_dependencies = ( + OrderedSet( + dep + for dep in OrderedSet.union(*[x.unmet_dependencies for x in snodes]) + if dep.name not in group_snode.get_buffer_names() + ) + - group_snode.read_writes.writes + ) + + +def init_group_node( + group_snode: Union[FusedSchedulerNode, GroupedSchedulerNode], + scheduler: 
Scheduler, + snodes: list[BaseSchedulerNode], +) -> None: + assert isinstance(group_snode, (FusedSchedulerNode, GroupedSchedulerNode)) + group_snode.snodes = snodes + group_snode.scheduler = scheduler + group_snode.node = None + group_snode.ancestors = OrderedSet.union( + *[x.ancestors for x in snodes if x.ancestors is not None] + ) + + refresh_group_node_dependencies(group_snode) + + group_snode.min_order = min(x.min_order for x in group_snode.snodes) + group_snode.max_order = max(x.max_order for x in group_snode.snodes) + group_snode.outputs_by_name = { + buf.get_name(): buf for buf in group_snode.get_outputs() + } + + +class FusedSchedulerNode(BaseSchedulerNode): + """ + This is a "fake" scheduler node that represents a group of scheduler nodes + that are meant to be fused together. The way it does this is by maintaining + its unmet dependencies as the union of its constituent nodes. + """ + + snodes: list[BaseSchedulerNode] + + @classmethod + def fuse( + cls, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> FusedSchedulerNode: + assert node1.scheduler is node2.scheduler + assert isinstance(node1, (SchedulerNode, FusedSchedulerNode)) + if node1.is_template() and isinstance(node2, ExternKernelSchedulerNode): + # Fuse multi outputs template and its outputs + # * Node1 has memorydep of MultiOutput in reads + # * Node2 has StarDep of MultiOutput in writes + # Rewrite the Node2' StarDep to MemoryDep, because calculate score_fusion_memory + # of the template node and its epilogue requires the same type of dependencies + assert isinstance(node2.node, MultiOutput) + assert len(node2.read_writes.writes) == 1 + assert isinstance(next(iter(node2.read_writes.writes)), StarDep) + name = next(iter(node2.read_writes.writes)).name + template_nodes = [node for node in node1.get_nodes() if node.is_template()] + assert len(template_nodes) == 1 + template_node = template_nodes[0] + assert len(template_node.read_writes.writes) == 1 + write = 
next(iter(template_node.read_writes.writes)) + assert isinstance(write, MemoryDep) + node2.read_writes.writes = OrderedSet( + [ + MemoryDep( + name, write.index, write.var_names, write.size, write.mode + ), + ] + ) + else: + assert isinstance(node2, (SchedulerNode, FusedSchedulerNode)) + nodes = list(itertools.chain(node1.get_nodes(), node2.get_nodes())) + return cls(node1.scheduler, nodes) + + def extract_pw_from_reduction(self) -> BaseSchedulerNode: + for subnode in self.snodes: + assert isinstance(subnode, SchedulerNode) + assert subnode.is_reduction() + subnode.extract_pw_from_reduction() + return self + + def swap_pw_red_dimension(self) -> None: + for subnode in self.snodes: + assert isinstance(subnode, SchedulerNode) + subnode.swap_pw_red_dimension() + + @cache_on_self + def estimate_flops(self) -> int | None: + # don't increment counters in fused methods so we don't double count + fps = list( + filter( + None, + ( + node.estimate_flops() + for node in self.get_nodes() + if node.is_template() or node.is_extern() + ), + ) + ) + if len(fps) == 0: + return None + ret = sum(fps) + return ret + + def reorder_loops_by_dep_pair( + self, self_dep: MemoryDep, other_dep: MemoryDep + ) -> bool: + """ + Return true if a loop reordering is performed. 
+ """ + if self.is_template(): + # We can not really reorder loops for a triton template + return False + self_sizes = None + for snode in self.snodes: + assert isinstance(snode, SchedulerNode) + if self_sizes is not None and tuple(self_sizes) != tuple(snode._sizes[0]): + loop_ordering_log.debug( + "Can not reorder fused node due to different sizes" + ) + return False + self_sizes = snode._sizes[0] + new_order = None + + assert self_sizes is not None + if len(self_sizes) == self_dep.num_vars == other_dep.num_vars: + new_order = self_dep.decide_loop_order_to_match(other_dep) + + if not new_order: + loop_ordering_log.debug( + "Dont reordering fused node %s because we can not decide the suitable loop order", + self.get_name(), + ) + return False + # pyrefly: ignore [bad-assignment] + metrics.num_loop_reordering += 1 + loop_ordering_log.debug( + "Reorder loops for fused node %s with order %s", self.get_name(), new_order + ) + for snode in self.snodes: + assert isinstance(snode, SchedulerNode) + snode.apply_new_loop_order(new_order) + + refresh_group_node_dependencies(self) + return True + + def __init__(self, scheduler: Scheduler, snodes: list[BaseSchedulerNode]) -> None: + super().__init__(scheduler) + init_group_node(self, scheduler, snodes) + self.users: list[NodeUser] = [] + self.group = max(snodes, key=lambda x: int(x.is_reduction())).group + + @cache_on_self + def get_name(self) -> str: + return "_".join([x.get_name() for x in self.snodes]) + + def get_first_name(self) -> str: + return self.snodes[0].get_name() + + @cache_on_self + def get_buffer_names(self) -> OrderedSet[str]: + return OrderedSet.union(*[x.get_buffer_names() for x in self.snodes]) + + def get_outputs(self) -> list[SchedulerBuffer]: + result: list[SchedulerBuffer] = [] + for node in self.snodes: + result.extend(node.get_outputs()) + return result + + def debug_str_extra(self) -> str: + lines = [ + f"{self.get_name()}.snodes[{i}] =\n{node.debug_str()}" + for i, node in enumerate(self.snodes) + ] + 
node = self.snodes[0].node + if node is not None: + lines.extend(self._debug_str_for_device()) + + return textwrap.indent("\n".join(lines).rstrip(), " ") + + def debug_str_short(self) -> str: + snodes_str = [node.debug_str_short() for node in self.snodes] + return f"{self}, snodes: {snodes_str}" + + def set_last_usage( + self, future_used_buffers: OrderedSet[str], mutation_real_name: dict[str, str] + ) -> None: + # Set self.last_usage using the global information + # This will be used for inter-kernel optimisations + super().set_last_usage(future_used_buffers, mutation_real_name) + # Set self.last_usage on the snodes + # This will be used for optimisations within the kernel + future_used_buffers: OrderedSet[str] = OrderedSet() + for node in reversed(self.snodes): + node.set_last_usage(future_used_buffers, mutation_real_name) + future_used_buffers.update(node.last_usage) + + @cache_on_self + def used_buffer_names(self) -> OrderedSet[str]: + return OrderedSet.union(*[x.used_buffer_names() for x in self.snodes]) + + @cache_on_self + def used_or_aliased_buffer_names(self) -> OrderedSet[str]: + return OrderedSet.union( + *[x.used_or_aliased_buffer_names() for x in self.snodes] + ) + + def get_nodes(self) -> Sequence[BaseSchedulerNode]: + return self.snodes + + def __repr__(self) -> str: + return f"{type(self).__name__}(nodes={self.get_name()})" + + @cache_on_self + def is_reduction(self) -> bool: + return any(x.is_reduction() for x in self.snodes) + + @cache_on_self + def is_native_matmul(self) -> bool: + return any(x.is_native_matmul() for x in self.snodes) + + @cache_on_self + def is_split_scan(self) -> bool: + return any(x.is_split_scan() for x in self.snodes) + + @cache_on_self + def is_template(self) -> bool: + return any(x.is_template() for x in self.snodes) + + @cache_on_self + def get_template_node(self) -> Optional[ir.TemplateBuffer]: + for node in self.snodes: + if node.is_template(): + return node.get_template_node() + return None + + def get_device(self) -> 
torch.device: + return self.group[0] + + @cache_on_self + def has_aliasing_or_mutation(self) -> bool: + return any(x.has_aliasing_or_mutation() for x in self.snodes) + + # None of these need to be implemented, as a FusedSchedulerNode is just an + # abstraction for scheduling purposes + def update_mutated_names(self, renames: dict[str, str]) -> None: + raise NotImplementedError + + def add_fake_dep(self, name: Dep) -> None: + raise NotImplementedError + + def can_inplace(self, read_dep: dependencies.Dep) -> bool: + raise NotImplementedError + + def debug_str(self) -> str: + """Longer form printout for trace logs""" + name = self.get_name() + node_typestr = ",".join(type(n).__name__ for n in self.snodes) + buf = IndentedBuffer() + buf.splice( + f"""\ +{name}: {type(self).__name__}({node_typestr}) +{name}.writes = {pformat(self.read_writes.writes)} +{name}.unmet_dependencies = {pformat(self.unmet_dependencies)} +{name}.met_dependencies = {pformat(self.read_writes.reads - self.unmet_dependencies)} +{name}.outputs = [ + """ + ) + with buf.indent(): + for out in self.get_outputs(): + buf.splice(out.debug_str()) + buf.writeline("]") + + try: + buf.splice(self.debug_str_extra()) + except Exception: + log.warning("Ignoring error in debug_str()", exc_info=True) + + return buf.getrawvalue().rstrip() + + @cache_on_self + def has_side_effects(self) -> bool: + if self.snodes is not None: + return any(node.has_side_effects() for node in self.snodes) + return super().has_side_effects() + + +class FusedMixOrderReductions(FusedSchedulerNode): + def __init__(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode) -> None: + self.node1 = node1 + self.node2 = node2 + super().__init__( + node1.scheduler, list(node1.get_nodes()) + list(node2.get_nodes()) + ) + self.numel = MixOrderReduction.get_numel(self.node1) + + def sub_node_can_fuse( + self, + node1: BaseSchedulerNode, + node2: BaseSchedulerNode, + other_nodes: tuple[BaseSchedulerNode, ...], + ): + """ + node1 is from the current 
mix order reduction; node2 is another node we want to fuse in. + + other_nodes are passed in to check if fusion will introduce producer/consumer relationship + between the inner and outer reduction. If yes, we don't fuse. + """ + assert not isinstance(node1, FusedMixOrderReductions) + assert not isinstance(node2, FusedMixOrderReductions) + + # When we fuse extra nodes into a FusedMixOrderReductions node, + # we should not allow recursive mix-order reduction being + # created. + if not self.scheduler.can_fuse(node1, node2, allow_mix_order_reduction=False): + return False + + def _get_ancestors(nodes: tuple[BaseSchedulerNode, ...]) -> OrderedSet[str]: + out = OrderedSet() + return out.union(*(n.ancestors for n in nodes)) + + def _get_operation_names( + nodes: tuple[BaseSchedulerNode, ...], + ) -> OrderedSet[str]: + out = OrderedSet() + return out.union(*(n.get_operation_names() for n in nodes)) + + if other_nodes: + if (_get_ancestors((node1, node2)) & _get_operation_names(other_nodes)) or ( + _get_ancestors(other_nodes) & _get_operation_names((node1, node2)) + ): + return False + + return ( + not node2.is_reduction() + or typing.cast( + int, self.scheduler.score_fusion_memory(node1, node2, count_bytes=False) + ) + >= self.numel + ) + + def can_fuse_with(self, other: BaseSchedulerNode): + if not isinstance(other, FusedMixOrderReductions): + return self.sub_node_can_fuse( + self.node1, other, (self.node2,) + ) or self.sub_node_can_fuse(self.node2, other, (self.node1,)) + else: + # pass empty tuple for the second since the producer/consumer relationship has + # already been checked in the first call + return self.sub_node_can_fuse( + self.node1, other.node1, (self.node2, other.node2) + ) and self.sub_node_can_fuse(self.node2, other.node2, tuple()) + + def fuse_with(self, other: BaseSchedulerNode): + device = self.node1.get_device() + backend = self.scheduler.get_backend(device) + + if isinstance(other, FusedMixOrderReductions): + fused_node1 = backend.fuse(self.node1, 
other.node1) + fused_node2 = backend.fuse(self.node2, other.node2) + return FusedMixOrderReductions(fused_node1, fused_node2) + else: + if self.sub_node_can_fuse(self.node1, other, (self.node2,)): + fused_node = backend.fuse(self.node1, other) + return FusedMixOrderReductions(fused_node, self.node2) + else: + fused_node = backend.fuse(self.node2, other) + return FusedMixOrderReductions(self.node1, fused_node) + + +class ForeachKernelSchedulerNode(FusedSchedulerNode): + """ + This is a schedular node that consists of a set of scheduler nodes that + has no data dependencies among them and can be executed in parallel. + """ + + def get_consumer_subnode_for( + self, producer: BaseSchedulerNode + ) -> Optional[BaseSchedulerNode]: + for buf in producer.get_outputs(): + if buf.get_name() in self.read_to_node: + return self.read_to_node[buf.get_name()] + + return None + + def get_producer_subnode_for( + self, consumer: BaseSchedulerNode + ) -> Optional[BaseSchedulerNode]: + producers = OrderedSet[BaseSchedulerNode]() + for rd in consumer.read_writes.reads: + if rd.name not in self.scheduler.name_to_buf: + continue + + node_name = self.scheduler.name_to_buf[rd.name].defining_op_name() + if node_name in self.name_to_node: + producers.add(self.name_to_node[node_name]) + + # Don't permit fusion if there are multiple subnodes + # that this consumer reads from + if len(producers) == 1: + return next(iter(producers)) + else: + return None + + @classmethod + def can_fuse(cls, producer: BaseSchedulerNode, consumer: BaseSchedulerNode) -> bool: + why = WhyNoFuse(producer, consumer) + if producer.is_foreach() and consumer.is_foreach(): + producer = typing.cast(ForeachKernelSchedulerNode, producer) + consumer = typing.cast(ForeachKernelSchedulerNode, consumer) + foreach_match = len(producer.snodes) == len(consumer.snodes) + if not foreach_match: + why("foreach do not have same length") + return foreach_match and all( + producer.scheduler.can_fuse(l, r) + for l, r in 
zip(producer.snodes, consumer.snodes) + ) + elif consumer.is_foreach(): + if producer.is_reduction(): + why( + "candidate producer is a reduction, foreach ops cannot be fused with reductions currently" + ) + return False + + consumer = typing.cast(ForeachKernelSchedulerNode, consumer) + consumer_subnode = consumer.get_consumer_subnode_for(producer) + if consumer_subnode is not None: + return consumer.scheduler.can_fuse(producer, consumer_subnode) + + why("candidate producer is not dep of any foreach consumer") + return False + + elif producer.is_foreach(): + if consumer.is_reduction(): + why( + "candidate consumer is a reduction, foreach ops cannot be fused with reductions currently" + ) + return False + + producer = typing.cast(ForeachKernelSchedulerNode, producer) + producer_subnode = producer.get_producer_subnode_for(consumer) + if producer_subnode is not None: + return producer.scheduler.can_fuse(producer_subnode, consumer) + + why("candidate consumer has no dep in any foreach producer") + return False + + raise AssertionError( + "At least one node passed to ForeachKernelSchedulerNode.can_fuse should be a foreach node" + ) + + @classmethod + def fuse( + cls, producer: BaseSchedulerNode, consumer: BaseSchedulerNode + ) -> ForeachKernelSchedulerNode: + assert producer.is_foreach() or consumer.is_foreach() + if producer.is_foreach(): + producer = typing.cast(ForeachKernelSchedulerNode, producer) + use_custom_partition_algo = producer.use_custom_partition_algo + enable_autotune = producer.enable_autotune + else: + consumer = typing.cast(ForeachKernelSchedulerNode, consumer) + use_custom_partition_algo = consumer.use_custom_partition_algo + enable_autotune = consumer.enable_autotune + prev_node_1 = None + prev_node_2 = None + fused_nodes: list[BaseSchedulerNode] + if producer.is_foreach() and consumer.is_foreach(): + producer = typing.cast(ForeachKernelSchedulerNode, producer) + consumer = typing.cast(ForeachKernelSchedulerNode, consumer) + fused_nodes = [ + 
FusedSchedulerNode.fuse(l, r) + for l, r in zip(producer.snodes, consumer.snodes) + ] + elif producer.is_foreach(): + producer = typing.cast(ForeachKernelSchedulerNode, producer) + producer_subnode = producer.get_producer_subnode_for(consumer) + fused_nodes = [] + prev_node_1 = producer + prev_node_2 = None + for node in producer.snodes: + if node is producer_subnode: + new_node = FusedSchedulerNode.fuse(node, consumer) + prev_node_2 = new_node + fused_nodes.append(new_node) + else: + fused_nodes.append(node) + + elif consumer.is_foreach(): + consumer = typing.cast(ForeachKernelSchedulerNode, consumer) + consumer_subnode = consumer.get_consumer_subnode_for(producer) + fused_nodes = [] + prev_node_1 = consumer + prev_node_2 = None + + for node in consumer.snodes: + if node is consumer_subnode: + new_node = FusedSchedulerNode.fuse(producer, node) + prev_node_2 = new_node + fused_nodes.append(new_node) + else: + fused_nodes.append(node) + else: + raise AssertionError( + "At least one node passed to ForeachKernelSchedulerNode.fuse should be a foreach node" + ) + + return cls( + producer.scheduler, + fused_nodes, + use_custom_partition_algo=use_custom_partition_algo, + prev_node_1=prev_node_1, + prev_node_2=prev_node_2, + enable_autotune=enable_autotune, + ) + + def __init__( + self, + scheduler: Scheduler, + snodes: list[BaseSchedulerNode], + use_custom_partition_algo: bool, + prev_node_1: Optional[BaseSchedulerNode] = None, + prev_node_2: Optional[BaseSchedulerNode] = None, + enable_autotune: bool = False, + ) -> None: + self.read_to_node = {} + self.name_to_node = {} + + if prev_node_1 is None or prev_node_2 is None: + super().__init__(scheduler, snodes) + + for node in snodes: + for read in node.read_writes.reads: + self.read_to_node[read.name] = node + + for name in node.get_operation_names(): + self.name_to_node[name] = node + else: + self.scheduler = scheduler + self.snodes = snodes + self.node = None + self.users: list[NodeUser] = [] + + self.set_read_writes( + 
dependencies.ReadWrites.merge_list( + [prev_node_1.read_writes, prev_node_2.read_writes] + ) + ) + + self.unmet_dependencies = ( + OrderedSet( + dep + for dep in OrderedSet.union( + prev_node_1.unmet_dependencies, prev_node_2.unmet_dependencies + ) + if dep.name not in self.get_buffer_names() + ) + - self.read_writes.writes + ) + + self.min_order = min([prev_node_1.min_order, prev_node_2.min_order]) + self.max_order = max([prev_node_1.max_order, prev_node_2.max_order]) + + if prev_node_1.is_foreach(): + assert isinstance(prev_node_1, ForeachKernelSchedulerNode) + foreach_node, other_node = prev_node_1, prev_node_2 + else: + assert isinstance(prev_node_2, ForeachKernelSchedulerNode) + foreach_node, other_node = prev_node_2, prev_node_1 + + self.ancestors = foreach_node.ancestors + self.ancestors.update(other_node.ancestors) + + self.name_to_node = foreach_node.name_to_node + for name in other_node.get_operation_names(): + self.name_to_node[name] = other_node + + self.outputs_by_name: dict[str, SchedulerBuffer] = { + k: v for snode in self.snodes for k, v in snode.outputs_by_name.items() + } + + self.use_custom_partition_algo = use_custom_partition_algo + device = snodes[0].get_device() + assert device + self.group = (device, ((sympy.Expr("combo_kernel"),),)) + self.origins = OrderedSet[torch.fx.Node]() + self.enable_autotune = enable_autotune + + @classmethod + def combinable_nodes( + cls, nodes: list[BaseSchedulerNode] + ) -> list[BaseSchedulerNode]: + extern = [x for x in nodes if isinstance(x, ExternKernelSchedulerNode)] + if extern: + log.debug( + "ComboKernels: %d external nodes are filtered %s", + len(extern), + [node.node.get_origins() for node in extern if node.node is not None], + ) + grouped = [x for x in nodes if isinstance(x, GroupedSchedulerNode)] + if grouped: + log.debug( + "ComboKernels: %d grouped nodes are filtered", + len(grouped), + ) + filtered_nodes = [ + x + for x in nodes + if not isinstance( + x, + ( + NopKernelSchedulerNode, + 
ExternKernelSchedulerNode, + GroupedSchedulerNode, + ), + ) + ] + foreach_nodes = [ + x for x in filtered_nodes if isinstance(x, ForeachKernelSchedulerNode) + ] + if foreach_nodes: + log.debug("ComboKernels: %d foreach nodes are filtered", len(foreach_nodes)) + filtered_nodes = [ + x for x in filtered_nodes if not isinstance(x, ForeachKernelSchedulerNode) + ] + template_nodes = [x for x in filtered_nodes if x.is_template()] + if template_nodes: + log.debug( + "ComboKernels: %d template nodes are filtered: %s", + len(template_nodes), + template_nodes, + ) + filtered_nodes = [x for x in filtered_nodes if x not in template_nodes] + return filtered_nodes + + @staticmethod + def _default_group_nodes_for_combo_kernels( + scheduler: Scheduler, + ) -> list[list[BaseSchedulerNode]]: + """ + Returns a list of lists of nodes that are to be grouped together. + """ + sorted_nodes = scheduler._topological_sort_nodes() + grouped_nodes = [] + max_num_nodes = 8 + for nodes in sorted_nodes: + # Group nodes by device first to avoid mixed-device fusion + device_groups: dict[Optional[torch.device], list[BaseSchedulerNode]] = ( + defaultdict(list) + ) + for node in nodes: + device = node.get_device() + if device and (device.type == "mps" or device.type == "cpu"): + continue + device_groups[device].append(node) + + # Chunk each device group separately + for device_nodes in device_groups.values(): + grouped_nodes.extend( + [ + device_nodes[i : i + max_num_nodes] + for i in range(0, len(device_nodes), max_num_nodes) + ] + ) + + return grouped_nodes + + group_algorithm_for_combo_kernels: Callable[ + [Scheduler], list[list[BaseSchedulerNode]] + ] = _default_group_nodes_for_combo_kernels + + @staticmethod + def set_group_algorithm_for_combo_kernels( + custom_group_algorithm: Callable[[Scheduler], list[list[BaseSchedulerNode]]], + ) -> None: + ForeachKernelSchedulerNode.group_algorithm_for_combo_kernels = ( + custom_group_algorithm + ) + + @staticmethod + def group_nodes_for_combo_kernels( + 
scheduler: Scheduler, + ) -> list[list[BaseSchedulerNode]]: + return ForeachKernelSchedulerNode.group_algorithm_for_combo_kernels(scheduler) + + def mark_run(self) -> None: + raise NotImplementedError + + def codegen(self) -> None: + raise NotImplementedError + + def is_foreach(self) -> bool: + return True + + def get_subkernel_nodes(self) -> list[BaseSchedulerNode]: + """Returns a list of nodes which comprise the combo kernel. + These nodes may be vertically fused.""" + return list(self.snodes) + + def get_nodes(self) -> Sequence[BaseSchedulerNode]: + """Returns all nodes contained in this kernel, unpacking fused nodes + into their constituent scheduler nodes.""" + return list(itertools.chain.from_iterable(x.get_nodes() for x in self.snodes)) + + def get_first_name(self) -> str: + return self.snodes[0].get_first_name() + + def prune_redundant_deps( + self, name_to_fused_node: dict[str, BaseSchedulerNode] + ) -> None: + _prune_redundant_deps(self, name_to_fused_node, self.scheduler.name_to_buf) + + for node in self.snodes: + node.prune_redundant_deps(name_to_fused_node) + + +class GroupedSchedulerNode(BaseSchedulerNode): + """ + This is a "fake" scheduler node that represents a group of scheduler nodes + that are meant to be *grouped* together (it does not allow another node to be scheduled + in between its constituent nodes, nor does it allow another node to fuse into any of its constituent nodes). + The way it does this is by maintaining its unmet dependencies as the union of its constituent nodes. + Fusion will still happen among the nodes within each GroupedSchedulerNode. + At codegen time, this scheduler node will be unpacked and codegen is called on each constituent node. 
+ """ + + snodes: list[BaseSchedulerNode] + + @classmethod + def create(cls, snodes: list[BaseSchedulerNode]) -> GroupedSchedulerNode: + scheduler = snodes[0].scheduler + assert all(node.scheduler is scheduler for node in snodes) + grouped_snode = cls(scheduler, snodes) + for snode in snodes: + scheduler.name_to_fused_node[snode.get_name()] = grouped_snode + scheduler.name_to_fused_node[grouped_snode.get_name()] = grouped_snode + return grouped_snode + + def __init__( + self, + scheduler: Scheduler, + snodes: list[BaseSchedulerNode], + temp_grouping: bool = False, + ) -> None: + super().__init__(scheduler) + init_group_node(self, scheduler, snodes) + # This flag is introduced for "temporary" grouping during some passes, + # Where nodes are grouped and moved together. + # After the pass those nodes are flattened. + # Reusing calculation of grouped unmed_dependencies etc. + # No fusion logic in this case. + self.temp_grouping = temp_grouping + + def unpack(self) -> list[BaseSchedulerNode]: + """ + Do fusion among nodes within this GroupedSchedulerNode, + and then unpack this GroupedSchedulerNode into regular nodes. 
+ """ + if self.temp_grouping: + return self.snodes + + for snode in self.snodes: + self.scheduler.name_to_fused_node[snode.get_name()] = snode + del self.scheduler.name_to_fused_node[self.get_name()] + return self.scheduler.fuse_nodes(self.snodes) + + def add_fake_dep(self, fake_dep: Dep) -> None: + self.set_read_writes(self.read_writes.with_read(fake_dep)) + self.unmet_dependencies.add(fake_dep) + + @cache_on_self + def get_name(self) -> str: + return "_".join([x.get_name() for x in self.snodes]) + + def get_first_name(self) -> str: + return self.snodes[0].get_name() + + @cache_on_self + def get_buffer_names(self) -> OrderedSet[str]: + return OrderedSet.union(*[x.get_buffer_names() for x in self.snodes]) + + def get_outputs(self) -> list[SchedulerBuffer]: + result: list[SchedulerBuffer] = [] + for node in self.snodes: + result.extend(node.get_outputs()) + return result + + @cache_on_self + def estimate_flops(self) -> int | None: + # don't increment counters in fused methods so we don't double count + fps = list( + filter( + None, + ( + node.estimate_flops() + for node in self.get_nodes() + if node.is_template() or node.is_extern() + ), + ) + ) + if len(fps) == 0: + return None + ret = sum(fps) + return ret + + def get_nodes(self) -> Sequence[BaseSchedulerNode]: + return self.snodes + + def get_device(self) -> Optional[torch.device]: + return self.snodes[0].get_device() if self.snodes else None + + @classmethod + def can_fuse(cls, producer: BaseSchedulerNode, consumer: BaseSchedulerNode) -> bool: + # GroupedSchedulerNode cannot be fused with another node + return False + + +def pick_loop_order( + stride_lengths: list[list[int]], + sizes: Sequence[sympy.Expr], + priority_idx: Sequence[int] = (), +) -> list[int]: + """ + A heuristic to decide loop iteration orders. This has not been well + tuned and may be something we should autotune. 
+ """ + + @functools.cmp_to_key + def index_cmp(a: int, b: int) -> int: + if sizes[a] == 1 or sizes[b] == 1: + # 1-sizes don't matter, just move them to the end + return cmp(sizes[a] == 1, sizes[b] == 1) + + # Take abs, otherwise flipped dimensions are treated as smaller + # strides than contiguous dims + stride_len_a = [abs(sl[a]) for sl in stride_lengths] + stride_len_b = [abs(sl[b]) for sl in stride_lengths] + + # equivalent to + # np.logical_or(stride_lengths[:, b] == 0, stride_lengths[:, a] < stride_lengths[:, b]).all() + a_first = sum( + sl_b == 0 or sl_a < sl_b for sl_a, sl_b in zip(stride_len_a, stride_len_b) + ) + b_first = sum( + sl_a == 0 or sl_b < sl_a for sl_a, sl_b in zip(stride_len_a, stride_len_b) + ) + if a_first > b_first: + return -1 + if b_first > a_first: + return 1 + + # otherwise contiguous + return cmp(b, a) + + order = list(reversed(range(len(stride_lengths[0])))) + if len(priority_idx) > 0: + # if we have priority node, only use that node's order + stride_lengths = [stride_lengths[pi] for pi in priority_idx] + if config.pick_loop_orders: + order.sort(key=index_cmp) + return order + + +def _replace_operation_buffer( + orig_node: ir.MultiTemplateBuffer, new_node: ir.OperationBuffer +) -> None: + replaced_buf_name = new_node.get_name() + orig_buf_name = orig_node.get_name() + assert isinstance(orig_buf_name, str) and isinstance(replaced_buf_name, str) + + replaced_op_name = new_node.get_operation_name() + orig_op_name = orig_node.get_operation_name() + assert isinstance(orig_op_name, str) and isinstance(replaced_op_name, str) + + del V.graph.name_to_buffer[replaced_buf_name] + new_node.name = orig_buf_name + + del V.graph.name_to_op[replaced_op_name] + new_node.operation_name = orig_op_name + + orig = V.graph.buffers.index(orig_node) + V.graph.buffers.remove(new_node) + V.graph.buffers[orig] = new_node + V.graph.name_to_buffer[orig_buf_name] = new_node + + orig = V.graph.operations.index(orig_node) + V.graph.operations.remove(new_node) + 
V.graph.operations[orig] = new_node + V.graph.name_to_op[orig_op_name] = new_node + + +@dataclasses.dataclass +class NodeUser: + node: Union[BaseSchedulerNode, OutputNode] + can_inplace: bool = False + + # A weak user must be scheduled after a given node, but doesn't actually + # use the result + is_weak: bool = False + + def __hash__(self) -> int: + return hash((self.node.get_name(), self.can_inplace, self.is_weak)) + + def __eq__(self, other: object) -> bool: + return ( + isinstance(other, NodeUser) + and self.get_name() == other.get_name() + and self.can_inplace == other.can_inplace + and self.is_weak == other.is_weak + ) + + def get_name(self) -> str: + return self.node.get_name() + + def merge(self, other: NodeUser) -> NodeUser: + assert self.node is other.node + return NodeUser( + self.node, + self.can_inplace and other.can_inplace, + self.is_weak and other.is_weak, + ) + + +_post_grad_graph_counter = itertools.count() + + +def used_non_deterministic_runtime_estimations() -> bool: + return config.runtime_estimations_mms_benchmark + + +def get_layout_symints(node: ir.IRNode) -> OrderedSet[sympy.Symbol]: + """Get free symbols from a node's layout (size, stride, offset).""" + free_symbol_uses: OrderedSet[sympy.Symbol] = OrderedSet() + layout = node.maybe_get_layout() + if isinstance(layout, ir.Layout): + free_symbol_uses.update( + free_symbols(layout.size) + | free_symbols(layout.stride) + | free_symbols(layout.offset) + ) + if isinstance(layout, ir.MutationLayoutSHOULDREMOVE): + # symint may be used as index in layout.target + free_symbol_uses.update(get_layout_symints(layout.target)) + else: + assert layout is None, f"Expect layout to be None but found layout={layout}" + return free_symbol_uses + + +def get_scheduler_node_symbol_uses( + node: BaseSchedulerNode, +) -> OrderedSet[sympy.Symbol]: + """ + Gets symbols used in a scheduler node, including free symbols from + the node's operations and layout symints from outputs. 
+ """ + if isinstance(node, FusedSchedulerNode): + return OrderedSet().union( + *(get_scheduler_node_symbol_uses(snode) for snode in node.snodes) + ) + assert node.node is not None + free_symbol_uses = node.node.get_free_symbol_uses() + free_symbol_uses.update( + *(get_layout_symints(ir_node) for ir_node in node.node.get_outputs()) + ) + return free_symbol_uses + + +class Scheduler: + """ + A Scheduler is a graph of BaseSchedulerNodes. It is responsible for + optimizations such as fusion, reorder, and graph partition. + """ + + def __init__(self, nodes: list[ir.Operation]) -> None: + with dynamo_timed("Scheduler.__init__"): + self._init(nodes) + + def _init(self, nodes: list[ir.Operation]) -> None: + super().__init__() + V.graph.scheduler = self + self.backends: dict[torch.device, BaseScheduling] = {} + self.post_grad_graph_id = next(_post_grad_graph_counter) + self._graph_partition_counter = itertools.count() + + self.completed_operations: OrderedSet[str] = OrderedSet() + self.available_buffer_names = OrderedSet( + [ + *V.graph.graph_inputs.keys(), + *V.graph.constants.keys(), + *V.graph.torchbind_constants.keys(), + ] + ) + self.nodes = [self.create_scheduler_node(n) for n in nodes] + self.current_node: Optional[BaseSchedulerNode] = None + self.update_zero_dim_cpu_tensor() + # some new constants could have been created above + self.available_buffer_names.update(V.graph.constants.keys()) + for node in self.nodes: + node.prune_deps() + + # See [Note: Graph Partition Device Contexts] + self.default_device_context: Optional[torch.device] = None + + self.name_to_donated_buffer: dict[str, SchedulerDonatedBuffer] = ( + self.get_donated_buffers() + ) + self.name_to_node: dict[str, BaseSchedulerNode] = { + n.get_name(): n for n in self.nodes + } + self.name_to_buf: dict[str, SchedulerBuffer] = { + buf.get_name(): buf for node in self.nodes for buf in node.get_outputs() + } + self.name_to_fused_node: dict[str, BaseSchedulerNode] = self.name_to_node.copy() + + # 
mutation_real_name: Maps back to the original name for codegen + # Example: + # If you mutate buf0 inside of buf1's kernel, then: + # mutation_real_name = {"buf0" : "buf1"} + # all subsequent uses of buf0 become buf1's usage in dependency graph + self.mutation_real_name: dict[str, str] = {} + + # We handle mutation by renaming modified versions of the same + # buffer in the dependency graph to prevent cycles. + # mutation_renames: tracks the current name for a given buffer + # (changed once per mutation) + # Example: + # If you mutate buf0 inside of buf1's kernel, then: + # mutation_renames = {"buf1" : "buf0"} + # in codegen we only use buf0, never buf1 + self.mutation_renames: dict[str, str] = {} + + # Must run first to correctly set dependencies, before all other passes that rely on + # reading from .read_writes.reads or .unmet_dependencies + self.nodes = comms.decide_global_ordering_of_comms( + self.nodes, + self.name_to_buf, + self.name_to_fused_node, + ) + + self.compute_dependencies() + self.nodes = self.topological_sort_schedule(self.nodes) + self.dead_node_elimination() + self.name_to_fused_node = {n.get_name(): n for n in self.nodes} + self.compute_ancestors() + + # pyrefly: ignore [bad-assignment] + metrics.ir_nodes_pre_fusion += len(self.nodes) + from torch._inductor.debug import log_ir_post_fusion, log_ir_pre_fusion + + log_ir_pre_fusion(self.nodes) + self.num_orig_nodes = len(self.nodes) + self.create_foreach_nodes() + self.nodes = self.topological_sort_schedule(self.nodes) + self.logged_slow_fusion = OrderedSet[tuple[str, str]]() + if config._pre_fusion_custom_pass is not None: + self.nodes = config._pre_fusion_custom_pass(self.nodes) + + if config.distributed_max_autotune_gemm: + from . 
import distributed_autotune + + distributed_autotune.schedule(self) + self.compute_ancestors() + + self.nodes = self.fuse_nodes(self.nodes) + if config._post_fusion_custom_pass is not None: + self.nodes = config._post_fusion_custom_pass(self.nodes) + + self.merge_loops() + self.finalize_multi_template_buffers() + if config.combo_kernels: + with dynamo_timed( + "Scheduler.create_combo_kernel_nodes", + log_pt2_compile_event=True, + log_waitcounter=True, + ): + self.create_combo_kernel_nodes(num_ck_nodes=None) + + # Peak memory pass and overlap pass must run last, otherwise + # other reordering passes could undo their effects. + if config.reorder_for_peak_memory: + from .memory import reorder_for_peak_memory + + self.nodes = reorder_for_peak_memory( + self.nodes, + self.name_to_buf, + self.name_to_fused_node, + OrderedSet(V.graph.graph_inputs.keys()), + OrderedSet(V.graph.get_output_names()), + ) + + # reorder_for_compute_comm_overlap may do benchmarking to estimate + # op runtime. Disable it for now in deterministic mode. 
+ if not config.deterministic and config.reorder_for_compute_comm_overlap: + if not config.reorder_for_peak_memory: + from .memory import assign_memory_planning_info_for_scheduler_buffers + + assign_memory_planning_info_for_scheduler_buffers( + self.nodes, self.name_to_buf + ) + + if ( + used_non_deterministic_runtime_estimations() + and config_comms.runtime_estimations_align_across_all_distributed_ranks + and ( + config.runtime_estimations_mms_benchmark + or config_comms.runtime_estimations_use_nccl_lib_estimations + ) + ): + has_collectives = False + for node in self.nodes: + if is_collective(node.node): + has_collectives = True + break + if has_collectives: + from .comms import ( + align_runtime_estimations_across_all_distributed_ranks, + ) + + align_runtime_estimations_across_all_distributed_ranks(self.nodes) + + from torch._logging import trace_structured + + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "scheduler_nodes_before_comm_overlap", + "encoding": "string", + }, + payload_fn=lambda: "\n\n".join( + [ + f"snode[{i}]" + + n.debug_str() + + f" buffer_names:{n.get_buffer_names()}" + for i, n in enumerate(self.nodes) + ] + ), + ) + self.nodes = comms.reorder_compute_and_comm_for_overlap(self.nodes) + self.process_grouped_nodes() + + if ( + # pyrefly: ignore[unbound-name] + config.graph_partition + # pyrefly: ignore[unbound-name] + and config.triton.cudagraphs + # pyrefly: ignore[unbound-name] + and config.triton.reorder_for_reducing_graph_partitions + ): + self.nodes = self.maybe_reorder_for_minimizing_partition(self.nodes) + self.nodes = self.reorder_for_partition_with_simple_dependency(self.nodes) + + self.compute_last_usage() + + if torch._inductor.config.test_configs.track_memory_lifecycle: + self.insert_memory_check_nodes() + + log_ir_post_fusion(self.nodes) + # pyrefly: ignore[unbound-name] + V.debug.graph_diagram(self.nodes) + self.debug_draw_graph() + + # used during codegen: + self.buffer_names_to_free: OrderedSet[str] = 
OrderedSet() + + # fx graph node to the position it appears in the graph + # for debug attribution + self.origin_to_index: dict[torch.fx.Node, int] = {} + + get_metric_table("graph_stats").add_row( + lambda: { + "graph_id": self.post_grad_graph_id, + "num_nodes_before_fusion": self.num_orig_nodes, + "num_nodes_after_fusion": len(self.nodes), + } + ) + + # Unlike V.graph.removed_buffers, the op recorded here is removed but + # we still need the buffer (generated in alternative ways) + self.removed_ops: OrderedSet[str] = OrderedSet() + + def get_donated_buffers(self) -> dict[str, SchedulerDonatedBuffer]: + name_to_donated_buf = {} + for name in V.graph.graph_inputs_original: + if isinstance(V.graph.graph_inputs_original[name], ir.DonatedBuffer): + name_to_donated_buf[name] = SchedulerDonatedBuffer( + self, + V.graph.graph_inputs_original[name], + defining_op=None, + ) + return name_to_donated_buf + + @property + def current_device(self) -> Optional[torch.device]: + return V.graph.current_device + + @current_device.setter + def current_device(self, device: Optional[torch.device]) -> None: + V.graph.current_device = device + + def debug_draw_graph(self) -> None: + """Generate an image of the graph for debugging""" + if os.environ.get("INDUCTOR_WRITE_SCHEDULER_GRAPH", None) == "1": + from .debug import draw_buffers + + draw_buffers(self.nodes, print_graph=True) + + def debug_print_nodes(self, label: str) -> None: + if log.isEnabledFor(logging.INFO): + log.info("%s:", label) + for node in self.nodes: + node.log_details() + + def create_scheduler_node(self, node: ir.Operation) -> BaseSchedulerNode: + assert node.get_origins() is not None, ( + "All nodes passed to scheduling must have an origin" + ) + if node.is_no_op(): + return NopKernelSchedulerNode(self, node) + elif isinstance(node, (ir.ComputedBuffer, ir.TemplateBuffer)): + return SchedulerNode(self, node) + elif isinstance(node, ir.ExternKernel): + return ExternKernelSchedulerNode(self, node) + else: + raise 
NotImplementedError(node) + + def create_foreach_nodes(self) -> None: + removed_node_names: OrderedSet[str] = OrderedSet() + fe_nodes = [] + kept_node_names = self.name_to_fused_node.keys() + + for names in V.graph.lists.values(): + names = [ + name + for name in names + if name in kept_node_names + and not isinstance(self.name_to_node[name], NopKernelSchedulerNode) + ] + if not names: + # All nodes eliminated + continue + + removed_node_names.update(names) + snodes = [self.name_to_node[name] for name in names] + + enable_autotune = config.combo_kernels_autotune > 1 + fe_node = ForeachKernelSchedulerNode( + self, + snodes, + use_custom_partition_algo=False, + enable_autotune=enable_autotune, + ) + + fe_nodes.append(fe_node) + + for name in names: + self.name_to_fused_node[name] = fe_node + + self.nodes = [ + node for node in self.nodes if node.get_name() not in removed_node_names + ] + list(fe_nodes) + + def compute_dependencies(self) -> None: + """ + Create dependency edges between nodes, handling aliasing and + mutation properly. + """ + + class DedupList(Generic[_T]): + """ + This data structure behaves like a list except it makes sure the + elements remain unique. + Normally one could use a OrderedSet/dict for this purpose however + the list in question gets elements appended as it is being + iterated over which means that we need to keep the list + semantics. 
+ """ + + def __init__( + self, + items: Optional[list[_T]] = None, + membership: Optional[OrderedSet[_T]] = None, + ) -> None: + self.items = items or [] + self.membership = membership or OrderedSet() + + def append(self, node_user: _T) -> None: + if node_user in self.membership: + return + self.items.append(node_user) + self.membership.add(node_user) + + def __add__(self, other: DedupList[_T]) -> DedupList[_T]: + new_membership = OrderedSet.union(self.membership, other.membership) + new_items = self.items + [ + x for x in other.items if x not in self.membership + ] + return DedupList(new_items, new_membership) + + # pyrefly: ignore [not-a-type] + name_to_users: defaultdict[str, DedupList[NodeUser]] = collections.defaultdict( + DedupList + ) + + # handle aliasing by using python aliasing in name_to_users + # if foo aliases bar then we will make name_to_users["foo"] point + # to the same python list as name_to_users["bar"] + for node in self.nodes: + for buf1 in node.get_outputs(): + buf1_name = buf1.get_name() + # This is for handling auto functionized ops which return None + # and mutate more than 1 inputs, we shouldn't let them all + # point to the same user list since buffers in the aliases + # list might not be alias to each other. 
+ if ( + isinstance(buf1.node.layout, ir.NoneLayout) + and len(buf1.get_aliases()) > 1 + ): + continue + for buf2_name in buf1.get_aliases(): + if buf1_name in name_to_users and buf2_name in name_to_users: + # merge the two + list1 = name_to_users[buf1_name] + list2 = name_to_users[buf2_name] + combined = list1 + list2 + for key in name_to_users: + if ( + name_to_users[key] is list1 + or name_to_users[key] is list2 + ): + name_to_users[key] = combined + elif buf1_name in name_to_users: + name_to_users[buf2_name] = name_to_users[buf1_name] + else: + name_to_users[buf1_name] = name_to_users[buf2_name] + + # pyrefly: ignore [not-a-type] + def rename(n: str) -> str: + if n in self.mutation_renames: + return rename(self.mutation_renames[n]) + return n + + def add_user( + # pyrefly: ignore [not-a-type] + used_by_name: str, + user_node: Union[BaseSchedulerNode, OutputNode], + can_inplace: bool = False, + is_weak: bool = False, + ) -> None: + name_to_users[rename(used_by_name)].append( + NodeUser(user_node, can_inplace, is_weak) + ) + + # pyrefly: ignore [not-a-type] + unbacked_symbol_to_origin_node: dict[sympy.Symbol, Optional[str]] = {} + + # NB: None means that the dependency is on an input. 
Don't actually + # generate a dependency because if we do, Inductor will start trying + # to free the unbacked int but that's pointless + for val in V.graph.graph_inputs.values(): + if isinstance(val, sympy.Expr): + for fs in val.free_symbols: + unbacked_symbol_to_origin_node[fs] = None + elif isinstance(val, ir.TensorBox): + # We also need to add symbols from input size as well because + # AOTI doesn't lift the unbacked symints to inputs + sym_size = [s for s in val.get_size() if isinstance(s, sympy.Expr)] + for s in sym_size: + for fs in s.free_symbols: + unbacked_symbol_to_origin_node[fs] = None + + has_non_input_unbacked_defs = False + for node in self.nodes: + assert node.node is not None + # unbacked symbols don't follow ordinary buffer dependencies, so + # we track their def/uses separately + unbacked_symbol_defs = sorted( + node.node.get_unbacked_symbol_defs(), key=lambda x: x.name + ) + for s in unbacked_symbol_defs: + assert isinstance(s, sympy.Symbol) + # Pick the first definer as canonical. There may be multiple + # because if a MultiOutputLayout buffer propagates an unbacked + # symint to multiple outputs, they will all claim to def it. 
+ has_non_input_unbacked_defs = True + if s not in unbacked_symbol_to_origin_node: + unbacked_symbol_to_origin_node[s] = node.get_name() + + for node in self.nodes: + log.debug("scheduling %s", node.node) + + if has_non_input_unbacked_defs: + assert node.node is not None + + unbacked_symbol_uses = sorted( + node.node.get_free_symbol_uses(unbacked_only=True), + key=lambda x: x.name, + ) + # if a kernel takes unbacked symints, register dependencies + for s in unbacked_symbol_uses: + assert s in unbacked_symbol_to_origin_node, ( + f"{s} not in {unbacked_symbol_to_origin_node}" + ) + if (r := unbacked_symbol_to_origin_node[s]) is not None: + for buf in self.name_to_node[r].get_outputs(): + node.add_fake_dep(StarDep(buf.get_name())) + + if ( + len(node.read_writes.writes) == 1 + and (dep := next(iter(node.read_writes.writes))) + and isinstance(dep, MemoryDep) + ): + node_mode = dep.mode + else: + node_mode = None + + # Handle output mutations + for buf in node.get_outputs(): + # a node will mutate either 0 or 1 buffers + assert len(buf.get_mutations()) <= 1 + for alt_name in buf.get_mutations(): + alt_name = rename(alt_name) + # this node must run after the prior writer + add_user(alt_name, node) + node.add_fake_dep(StarDep(alt_name, mode=node_mode)) + for user in name_to_users[alt_name].items: + if user.get_name() == node.get_name(): + continue + + assert isinstance(user.node, BaseSchedulerNode) + for other_name in user.node.get_buffer_names(): + # this node must run after all prior readers + other_name = rename(other_name) + node.add_fake_dep( + WeakDep(other_name, mutating_buf=buf.get_name()) + ) + add_user(other_name, node, is_weak=True) + + for add_dep in V.graph.additional_buffer_deps[node.get_name()]: + add_user(add_dep, node, is_weak=True) + # is_fake=True because these are control dependencies for ordering only, + # they should not extend buffer lifetimes + node.add_fake_dep(WeakDep(add_dep, node.get_name(), is_fake=True)) + + for add_dep in 
V.graph.additional_star_deps[node.get_name()]: + add_user(add_dep, node, is_weak=False) # Strong dependency + node.add_fake_dep(StarDep(add_dep)) + + # add normal non-mutation dependencies + for read in node.read_writes.reads: + if not isinstance(read, WeakDep): + add_user(read.name, node, node.can_inplace(read)) + + node.update_mutated_names(self.mutation_renames) + + # update our renaming scheme for the next iteration + for buf in node.get_outputs(): + for alt_name in buf.get_mutations(): + self.mutation_renames[rename(alt_name)] = buf.get_name() + self.mutation_renames[alt_name] = buf.get_name() + self.mutation_real_name[buf.get_name()] = ( + self.mutation_real_name.get(alt_name, alt_name) + ) + + # make sure outputs aren't dead-code-eliminated + for buf_name in V.graph.get_output_names(): + log.debug("scheduling output %s", buf_name) + add_user(buf_name, OutputNode(StarDep(buf_name))) + + # make sure unbacked symints aren't dead-code-eliminated + if has_non_input_unbacked_defs: + for out in V.graph.graph_outputs: + for s in out.get_free_symbol_uses(unbacked_only=True): + assert s in unbacked_symbol_to_origin_node, ( + f"{s} not in {unbacked_symbol_to_origin_node.keys()}" + ) + if r := unbacked_symbol_to_origin_node[s]: + for buf_name in self.name_to_node[r].get_buffer_names(): + log.debug( + "scheduling output %s for unbacked symint %s", + buf_name, + s, + ) + add_user(buf_name, OutputNode(StarDep(buf_name))) + + # make sure input mutation isn't dead-code-eliminated + for name in self.mutation_renames: + if name in V.graph.graph_inputs: + add_user(name, OutputNode(StarDep(name))) + V.graph.mutated_inputs.add(name) + elif name in V.graph.constants: + # In AOTI, module parameters and buffers are not lifted as graph inputs + add_user(name, OutputNode(StarDep(name))) + + inp_names = { + name: index for index, name in enumerate(V.graph.graph_inputs.keys()) + } + V.graph.mutated_input_idxs = [ + inp_names[name] for name in V.graph.mutated_inputs + ] + + # copy users 
information onto the nodes + for node in self.nodes: + for buf in node.get_outputs(): + buf.set_users(name_to_users[buf.get_name()].items) + + for name in self.name_to_donated_buffer: + self.name_to_donated_buffer[name].set_users(name_to_users[name].items) + + # For debug logging + logbuf = IndentedBuffer() + logbuf.splice("{") + for key, value in name_to_users.items(): + with logbuf.indent(): + users = [v.get_name() for v in value.items] + logbuf.splice(f"'{key}': {users},") + logbuf.splice("}") + str = logbuf.getrawvalue().rstrip() + compute_dependencies_log.debug("BUFFER USER LIST\n") + compute_dependencies_log.debug("===== AFTER SCHEDULING =====\n%s", str) + + def insert_memory_check_nodes(self) -> None: + from .memory import ( + assign_memory_planning_info_for_scheduler_buffers, + compute_memory_timeline, + FreeableInputBuffer, + get_freeable_input_buf, + ) + + graph_inputs: OrderedSet[str] = OrderedSet(V.graph.graph_inputs.keys()) + name_to_freeable_input_buf: dict[str, FreeableInputBuffer] = ( + get_freeable_input_buf(self.nodes, graph_inputs) + ) + + if not torch._inductor.config.reorder_for_peak_memory: + assign_memory_planning_info_for_scheduler_buffers( + self.nodes, self.name_to_buf + ) + + graph_outputs: OrderedSet[str] = OrderedSet(V.graph.get_output_names()) + buf_info_list, _, _ = compute_memory_timeline( + self.nodes, + name_to_freeable_input_buf, + graph_outputs, + ) + + step_allocs_deallocs: list[tuple[list[str], list[str]]] = [ + ([], []) for _ in range(len(self.nodes)) + ] + for buf_info in buf_info_list: + # Skip zero-size buffers + if buf_info.size_alloc == 0 and buf_info.size_free == 0: + continue + + buf_name = buf_info.buffer.get_name() + + step_allocs_deallocs[buf_info.start_step][0].append(buf_name) + step_allocs_deallocs[buf_info.end_step][1].append(buf_name) + + from torch._inductor.runtime.debug_utils import register_check_mem_op + + register_check_mem_op() + + def construct_mem_check_node( + step_idx: int, is_final_step: bool + ) -> 
ExternKernelSchedulerNode: + expected_newly_alive = step_allocs_deallocs[step_idx][0] + expected_newly_dead = step_allocs_deallocs[step_idx][1] + + nontensor_args = [expected_newly_alive, expected_newly_dead, is_final_step] + + node = ir.MemoryCheckKernel( + layout=NoneLayout(device=torch.device("cpu")), + kernel=torch.ops._inductor_debug.check_memory_step.default, + tensor_args=[], + nontensor_args=nontensor_args, + unflatten_args=lambda tensor_args, constant_args: ( + tensor_args, + { + "alive": constant_args[0], + "dead": constant_args[1], + "is_final_step": constant_args[2], + }, + ), + ) + node.operation_name = f"mem_check_{self.nodes[step_idx].get_name()}" + return ExternKernelSchedulerNode(self, node) + + new_nodes = [] + + for i, node in enumerate(self.nodes): + new_nodes.append(node) + new_nodes.append( + construct_mem_check_node(i, is_final_step=(i == len(self.nodes) - 1)) + ) + + self.nodes = new_nodes + + def dead_node_elimination(self) -> None: + """ + Remove any nodes without users + """ + if not config.use_dce: + return + + # self.nodes is in topological order, so by iterating in reverse order + # we have visited (and potentially removed) all users before visiting a + # given node. 
+ updated_nodes = [] + for node in reversed(self.nodes): + + def can_eliminate_user(user: NodeUser) -> bool: + return user.is_weak or user.get_name() in V.graph.removed_operations + + active_buffers = False + for buf in node.get_outputs(): + can_eliminate = all(can_eliminate_user(u) for u in buf.users) + if can_eliminate: + log.debug("removed dead buffer: %s", buf.get_name()) + V.graph.removed_buffers.add(buf.get_name()) + else: + active_buffers = True + + can_eliminate = not node.has_side_effects() and not active_buffers + + if not can_eliminate: + updated_nodes.append(node) + else: + # dead code + log.debug("removed dead operation: %s", node.get_name()) + V.graph.removed_operations.add(node.get_name()) + for read in node.read_writes.reads: + if read.name in self.name_to_buf: + users = self.name_to_buf[read.name].users + self.name_to_buf[read.name].users = [ + u for u in users if u.node.get_name() != node.get_name() + ] + self.nodes = list(reversed(updated_nodes)) + + # Prune any WeakDeps no longer needed + for node in self.nodes: + node.prune_weak_deps() + + def topological_sort_schedule( + self, nodes: list[BaseSchedulerNode] + ) -> list[BaseSchedulerNode]: + """ + Ensure nodes is in topologically sorted order + """ + seen = OrderedSet[BaseSchedulerNode]() + name_to_node: dict[str, BaseSchedulerNode] = dict() + result: list[BaseSchedulerNode] = [] + + def visit(n: BaseSchedulerNode) -> None: + if n not in seen: + seen.add(n) + for dep in sorted(n.unmet_dependencies, key=lambda d: d.name): + # We only care about doing toposort within `nodes` + if dep.name not in name_to_node: + continue + visit(name_to_node[dep.name]) + result.append(n) + + for node in nodes: + for name in node.get_buffer_names(): + name_to_node[name] = node + for node in nodes: + visit(node) + return result + + def _get_unmet_dep_nodes(self, snode: BaseSchedulerNode) -> list[BaseSchedulerNode]: + unmet_deps: OrderedSet[str] = OrderedSet() + if isinstance( + snode, + ( + SchedulerNode, + 
ExternKernelSchedulerNode, + NopKernelSchedulerNode, + FusedSchedulerNode, + GroupedSchedulerNode, + ), + ): + for dep in snode.unmet_dependencies: + unmet_deps.add(dep.name) + else: + raise RuntimeError( + f"get_unmet_dep_nodes is not implemented for {type(snode)}." + ) + unmet_dep_ops = (self.name_to_buf[dep].defining_op_name() for dep in unmet_deps) + return list(OrderedSet(self.name_to_fused_node[n] for n in unmet_dep_ops)) + + def _topological_sort_nodes(self) -> list[list[BaseSchedulerNode]]: + """ + Sort nodes by their topological order, return a list of node lists. + """ + order = [] + nodes = dict.fromkeys(self.nodes, 0) + children: dict[Any, Any] = {} + for node in self.nodes: + deps = self._get_unmet_dep_nodes(node) + nodes[node] = len(deps) + for dep in deps: + c = children.get(dep, []) + c.append(node) + children[dep] = c + + zero_deg_nodes = [n for n, v in nodes.items() if v == 0] + while zero_deg_nodes: + order.append(zero_deg_nodes) + for n in zero_deg_nodes: + for user in children.get(n, []): + nodes[user] -= 1 + nodes.pop(n) + zero_deg_nodes = [n for n, v in nodes.items() if v == 0] + assert not nodes, "Topological sort failed!" 
+ return order + + def compute_ancestors(self) -> None: + """ + Populate each node.ancestors + """ + # note self.nodes is topologically sorted + name_to_ancestors: dict[str, OrderedSet[str]] = {} + for node in self.nodes: + ancestors: OrderedSet[str] = OrderedSet() + for dep in node.unmet_dependencies: + dep_node_name = self.name_to_buf[dep.name].defining_op_name() + ancestors.add(dep_node_name) + ancestors |= name_to_ancestors[dep_node_name] + name_to_ancestors[node.get_name()] = ancestors + node.ancestors = ancestors + + for order, node in enumerate(self.nodes): + node.min_order = order + node.max_order = order + + def merge_loops(self) -> None: + if not config.loop_ordering_after_fusion: + return + + for node in self.nodes: + # Even for CPU, if we are using the halide backend, we still need + # the merge loops steps below + if not isinstance(node, (SchedulerNode, FusedSchedulerNode)) or ( + not node.is_gpu() and config.cpu_backend != "halide" + ): + continue + for snode in node.get_nodes(): + # merge loops for the scheduler node + if not isinstance(snode, SchedulerNode) or snode.is_template(): + continue + + snode.merge_loops() + + # Note that for CPU backend, merging loops will change + # snode.group. It's fine for Triton backend. + # But if we simplify update snode.group like this: + # group_fn = self.get_backend(snode.node.get_device()).group_fn + # snode.group = (snode.node.get_device(), group_fn(snode._sizes)) + # There is still an issue due to different snode in a + # FusedSchedulerNode having different merged loops. + # Skip CPU backend for now. + + def fuse_nodes(self, nodes: list[BaseSchedulerNode]) -> list[BaseSchedulerNode]: + """ + Combine eligible nodes into FusedSchedulerNodes. 
+ """ + with dynamo_timed( + "Scheduler.fused_nodes", log_pt2_compile_event=True, log_waitcounter=True + ): + for i in range(10): + old_len = len(nodes) + fusion_log.debug( + "===== attempting fusion (%d/10): %d nodes =====", + i + 1, + old_len, + ) + nodes = self.fuse_nodes_once(nodes, is_reorder_round=False) + new_len = len(nodes) + fusion_log.debug( + "completed fusion round (%d/10): fused %d nodes into %d nodes\n", + i + 1, + old_len, + new_len, + ) + if new_len == old_len or new_len == 1: + fusion_log.debug( + "===== fusion complete (%d iterations) =====", i + 1 + ) + break + + if ( + config.loop_ordering_after_fusion + or config.loop_index_inversion_in_fusion + ): + nodes = self.fuse_nodes_once(nodes, is_reorder_round=True) + return nodes + + def process_grouped_nodes(self) -> None: + """ + Unpack GroupedSchedulerNode into regular nodes. + """ + new_nodes: list[BaseSchedulerNode] = [] + for node in self.nodes: + new_nodes.extend( + node.unpack() if isinstance(node, GroupedSchedulerNode) else [node] + ) + self.nodes = new_nodes + + def benchmark_fused_nodes( + self, nodes: Sequence[BaseSchedulerNode] + ) -> tuple[float, str]: + """ + Benchmark fused list of nodes and return the execution time + in milliseconds on randomly generated inputs. + """ + assert len(nodes) > 0 + device = nodes[0].get_device() + self.current_device = device + backend = self.get_backend(device) + with dynamo_timed( + "benchmark_fused_nodes", + log_pt2_compile_event=True, + dynamo_compile_column_us="compile_time_autotune_time_us", + ): + return backend.benchmark_fused_nodes(nodes) + + def generate_kernel_code_from_nodes( + self, + nodes: Sequence[BaseSchedulerNode], + benchmark_kernel: bool, + hint_override: Optional[int] = None, + ) -> str: + """ + Benchmark fused list of nodes and return the execution time + in milliseconds on randomly generated inputs. 
+ """ + assert len(nodes) > 0 + device = nodes[0].get_device() + self.current_device = device + backend = self.get_backend(device) + with dynamo_timed("benchmark_fused_nodes"): + return backend.generate_kernel_code_from_nodes( + nodes, benchmark_kernel, hint_override=hint_override + ) + + def benchmark_codegened_module( + self, module: ModuleType, device: torch.device + ) -> tuple[float, str]: + """ + Benchmark fused list of nodes and return the execution time + in milliseconds on randomly generated inputs. + """ + self.current_device = device + backend = self.get_backend(device) + with dynamo_timed("benchmark_fused_nodes"): + return backend.benchmark_codegened_module(module) + + def finalize_multi_template_buffers(self) -> None: + """ + Finalize a backing choice for MultiTemplateBuffers which did not already have a + choice finalized through fusion. In the case of an extern choice, this will result + in replacing the SchedulerNode. + + If a MultiTemplateBuffer did not have any fusion opportunities, finalizing a choice + will force completion of compilation and benchmarking. 
+ """ + + for i, node in enumerate(self.nodes): + if isinstance(node, SchedulerNode) and isinstance( + node.node, ir.MultiTemplateBuffer + ): + multi_node = node.node + if not config.test_configs.force_extern_kernel_in_multi_template: + min_node_unfused, _ = multi_node.get_min_choice() + else: + min_node_unfused = next( + ( + timing + for timing in multi_node.choice_timings() + if isinstance( + timing, + torch._inductor.select_algorithm.ExternKernelCaller, + ) + ), + ) + + if isinstance( + min_node_unfused, + torch._inductor.ir.TritonTemplateCallerBase, + ): + if config.multi_kernel_hints: + callers: dict[Optional[int], TritonTemplateCallerBase] = {} + callers[None] = min_node_unfused + + for hint in config.multi_kernel_hints: + timings = multi_node.choice_timings(hint_override=hint) + triton_timings = { + k: v + for k, v in timings.items() + if isinstance(k, TritonTemplateCallerBase) + } + choice = min(triton_timings.items(), key=lambda x: x[1])[0] + callers[hint] = choice + + node.node.finalize_as_triton_callers(callers) + else: + node.node.finalize_as_triton_caller(min_node_unfused) + continue + + with ir.IRNode.current_origins(multi_node.origins): + out_tensorbox = min_node_unfused.output_node() + out_storage = out_tensorbox.data # type: ignore[union-attr] + assert isinstance(out_storage, ir.StorageBox) + out_buffer = out_storage.data + assert isinstance(out_buffer, ir.OperationBuffer) + + if multi_node.origin_node: + assign_origin_node(out_tensorbox, multi_node.origin_node) + + out_buffer.layout = multi_node.layout + self._replace_node(out_buffer, multi_node, i, node) + + def _replace_node( + self, + out_buffer: ir.OperationBuffer, + multi_node: ir.MultiTemplateBuffer, + i: int, + node: SchedulerNode, + ) -> None: + _replace_operation_buffer(multi_node, out_buffer) + new_scheduler_node = self.create_scheduler_node(out_buffer) + + self.nodes[i] = new_scheduler_node + self.name_to_node[node.get_name()] = new_scheduler_node + 
self.name_to_fused_node[node.get_name()] = new_scheduler_node + + # We need to reflect the mutation renames that were recorded in the original node + mutation_renames = {} + for dep in itertools.chain(node.read_writes.reads, node.unmet_dependencies): + if real_name := self.mutation_real_name.get(dep.name, None): + mutation_renames[real_name] = dep.name + + def rename_deps(deps: OrderedSet[Dep]) -> OrderedSet[Dep]: + return OrderedSet(dep.rename(mutation_renames) for dep in deps) + + new_scheduler_node.unmet_dependencies = rename_deps( + new_scheduler_node.unmet_dependencies + ) + new_scheduler_node.read_writes.reads = rename_deps( + new_scheduler_node.read_writes.reads + ) + + for new_out, old_out in zip( + new_scheduler_node.get_outputs(), node.get_outputs() + ): + self.name_to_buf[old_out.get_name()] = new_out + new_out.users = old_out.users + + new_scheduler_node.min_order = node.min_order + new_scheduler_node.max_order = node.max_order + new_scheduler_node.ancestors = node.ancestors + new_scheduler_node.last_usage = node.last_usage + + def _any_atomic_add(self, node_list: Sequence[BaseSchedulerNode]) -> bool: + return any( + hasattr(n.node, "data") + and n.node is not None + and hasattr(n.node.data, "scatter_mode") + and n.node.data.scatter_mode == "atomic_add" + for n in node_list + ) + + def speedup_by_fusion( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> Union[bool, Callable[[], bool]]: + """ + If config.benchmark_fusion is False, always return True. + Otherwise, return True if fusion can brings speedup. 
+ """ + + is_multi_template = any( + n.is_template() + and isinstance(n.get_template_node(), ir.MultiTemplateBuffer) + for n in (node1, node2) + ) + if not config.benchmark_fusion and not is_multi_template: + return True + + if ( + node1.is_template() + and not isinstance(node1.get_template_node(), ir.TritonTemplateBuffer) + or node1.is_foreach() + or node2.is_foreach() + ): + # TODO support benchmarking epilogue fusion + return True + + node_list_1 = node1.get_nodes() + device = node_list_1[0].get_device() + assert device + + # don't support benchmark fusion for CPU C++ backend right now. + if device.type == "cpu" and config.cpu_backend != "triton": + return True + + node_list_2 = node2.get_nodes() + node_list_fused = list(itertools.chain(node_list_1, node_list_2)) + + # We can not accurately benchmark kernel using atomic_add + # due to how we generate random integer inputs. + # Skip benchmarking them by allowing fusion. + if self._any_atomic_add(node_list_fused): + return True + + from triton.compiler.errors import CompilationError + + why = WhyNoFuse(node1, node2) + + device = node_list_fused[0].get_device() + assert device is not None + + def log_fusion(ms_fused: float, ms1: float, ms2: float) -> None: + if fusion_log.isEnabledFor(logging.DEBUG): + if ms_fused < ms1 + ms2: + fusion_log.debug( + "can fuse (benchmark): fusing %s with %s cause %sx speedup", + node1.get_buffer_names(), + node2.get_buffer_names(), + green_text(f"{(ms1 + ms2) / ms_fused:.3f}"), + ) + else: + fusion_log.debug( + "cannot fuse (benchmark): fusing %s with %s cause %sx slowdown", + node1.get_buffer_names(), + node2.get_buffer_names(), + red_text(f"{ms_fused / (ms1 + ms2):.3f}"), + ) + + async_compile = torch._inductor.async_compile.AsyncCompile() + + def compile_kernel( + nodes: Sequence[BaseSchedulerNode], hint_override: Optional[int] = None + ) -> tuple[Optional[LambdaFuture], ModuleType]: + src_code = self.generate_kernel_code_from_nodes( + nodes, benchmark_kernel=True, 
hint_override=hint_override + ) + mod = PyCodeCache.load(src_code) + if not async_compile.use_process_pool(): + fut = None + else: + fut = async_compile.triton(kernel_name="triton_", source_code=src_code) + assert isinstance(fut, LambdaFuture) + + return (fut, mod) + + if is_multi_template and any( + n.get_template_node() is not None for n in (node1, node2) + ): + epilogue_fusion = node1.get_template_node() is not None + multi_node = ( + node1.get_template_node() + if epilogue_fusion + else node2.get_template_node() + ) + assert isinstance(multi_node, ir.MultiTemplateBuffer) + + hint_override_best_fusion_choice: dict[ + Optional[int], TritonTemplateCallerBase + ] = {} + future_choices: list[tuple[Any, Optional[LambdaFuture], ModuleType]] = [] + for hint_override in config.multi_kernel_hints: + choice_timings = multi_node.choice_timings(hint_override) + for choice, _ in sorted(choice_timings.items(), key=lambda x: x[1]): + if not isinstance( + choice, torch._inductor.select_algorithm.TritonTemplateCaller + ): + continue + with multi_node.swap_as_triton_caller(choice): + future_choices.append( + ( + choice, + *compile_kernel( + node_list_fused, hint_override=choice.hint_override + ), + ) + ) + + min_ms_fused = float("inf") + ms_fused_choice: Optional[TritonTemplateCallerBase] = None + new_timings = {} + for choice, future, mod_fused in future_choices: + try: + if future is not None: + future.result() + except Exception as e: + if fusion_log.isEnabledFor(logging.DEBUG): + fusion_log.debug( # noqa: G200 + "Exception in compiling %s: %s", + "prologue" if not epilogue_fusion else "epilogue", + str(e), + ) + continue + with multi_node.swap_as_triton_caller(choice): + ms_fused, path = self.benchmark_codegened_module( + mod_fused, device + ) + new_timings[choice] = ms_fused + if ms_fused < min_ms_fused: + min_ms_fused = ms_fused + ms_fused_choice = choice + multi_node._choice_timings[hint_override] = new_timings + assert isinstance(ms_fused_choice, TritonTemplateCallerBase) 
+ hint_override_best_fusion_choice[hint_override] = ms_fused_choice + + # Eagerly compile and benchmark non-template nodes + choice_timings = multi_node.choice_timings() + _, ms1 = multi_node.get_min_choice() + ms2, path2 = ( + self.benchmark_fused_nodes(node_list_2) + if epilogue_fusion + else self.benchmark_fused_nodes(node_list_1) + ) + + # Start compiling choices in parallel + future_choices: list[tuple[Any, Optional[LambdaFuture], ModuleType]] = [] + triton_choices = 0 + for choice, unfused_time in sorted( + choice_timings.items(), key=operator.itemgetter(1) + ): + if not isinstance(choice, torch._inductor.ir.TritonTemplateCallerBase): + continue + + # For prologue fusion we check if the underlying template of the choice + # supports all allowed prologue inputs. If not, we skip this choice in + # the fusion benchmark. + # TODO: Remove this check after all Triton templates support prologue fusion. + # Currently, persistent+TMA Triton template does not due to the TMA-based loads. + if ( + not epilogue_fusion + and hasattr(choice, "allowed_prologue_inps") + and choice.allowed_prologue_inps != multi_node.allowed_prologue_inps + ): + continue + + if unfused_time >= ms1 + ms2: + break + + triton_choices += 1 + if triton_choices > config.max_epilogue_benchmarked_choices: + break + + with multi_node.swap_as_triton_caller(choice): + future_choices.append((choice, *compile_kernel(node_list_fused))) + + if len(future_choices) == 0: + return False + + def benchmark_when_ready() -> bool: + min_ms_fused = float("inf") + ms_fused_choice = None + + new_timings = {} + # Benchmark each choice after compilation completes + for choice, future, mod_fused in future_choices: + try: + if future is not None: + future.result() + + # Ideally we would more narrowly catch Exceptions here but + # triton will unpredictably error with valid prologue fusions + except Exception as e: + if fusion_log.isEnabledFor(logging.DEBUG): + fusion_log.debug( # noqa: G200 + "Exception in compiling %s: 
%s", + "prologue" if not epilogue_fusion else "epilogue", + str(e), + ) + continue + # pyrefly: ignore [missing-attribute] + with multi_node.swap_as_triton_caller(choice): + ms_fused, path = self.benchmark_codegened_module( + mod_fused, + # pyrefly: ignore [bad-argument-type] + device, + ) + new_timings[choice] = ms_fused + if ms_fused < min_ms_fused: + min_ms_fused = ms_fused + ms_fused_choice = choice + + log_fusion(min_ms_fused, ms1, ms2) + + if min_ms_fused < (ms1 + ms2) and ms_fused_choice is not None: + if config.multi_kernel_hints: + hint_override_best_fusion_choice[None] = ms_fused_choice + # pyrefly: ignore [missing-attribute] + multi_node.finalize_as_triton_callers( + hint_override_best_fusion_choice + ) + else: + # pyrefly: ignore [missing-attribute] + multi_node.finalize_as_triton_caller(ms_fused_choice) + + # pyrefly: ignore [missing-attribute] + multi_node._choice_timings[None] = new_timings + return True + else: + return False + + return benchmark_when_ready + + else: + # Start parallel compilation for all three kernels + future_and_mod_l1 = compile_kernel(node_list_1) + future_and_mod_l2 = compile_kernel(node_list_2) + future_and_mod_l1_fused = compile_kernel(node_list_fused) + + def benchmark_when_ready() -> bool: + from torch._inductor.runtime.triton_heuristics import ( + NoTritonConfigsError, + ) + + try: + # Wait for all compilations to complete + for fut in ( + future_and_mod_l1[0], + future_and_mod_l2[0], + future_and_mod_l1_fused[0], + ): + if fut is not None: + fut.result() + + ms1, path1 = self.benchmark_codegened_module( + future_and_mod_l1[1], + # pyrefly: ignore [bad-argument-type] + device, + ) + if math.isinf(ms1): + why("register spilling of the first kernel") + return False + + ms2, path2 = self.benchmark_codegened_module( + future_and_mod_l2[1], + # pyrefly: ignore [bad-argument-type] + device, + ) + if math.isinf(ms2): + why("register spilling of the second kernel") + return False + + ms_fused, path_fused = 
self.benchmark_codegened_module( + future_and_mod_l1_fused[1], + # pyrefly: ignore [bad-argument-type] + device, + ) + if math.isinf(ms_fused): + why("register spilling of the fused kernel") + return False + + log_fusion(ms_fused, ms1, ms2) + + if ( + is_metric_table_enabled("slow_fusion") + and ms_fused >= ms1 + ms2 + and (path1, path2) not in self.logged_slow_fusion + ): + self.logged_slow_fusion.add((path1, path2)) + get_metric_table("slow_fusion").add_row( + lambda: { + "kernel1_path": path1, + "kernel1_latency": ms1, + "kernel2_path": path2, + "kernel2_latency": ms2, + "fused_kernel_path": path_fused, + "fused_kernel_latency": ms_fused, + "slow_down_ratio": ms_fused / (ms1 + ms2), + } + ) + + return ms_fused < ms1 + ms2 + + except NoTritonConfigsError: + return False + + except CompilationError as e: + if "Loop-carried variable" in str(e): + return True + raise + + return benchmark_when_ready + + def get_fused_node(self, node: BaseSchedulerNode) -> BaseSchedulerNode: + "Look up the node in Scheduler name_to_fused_node" + return self.name_to_fused_node[node.get_first_name()] + + def fuse_nodes_once( + self, + nodes: list[BaseSchedulerNode], + is_reorder_round: bool, + ) -> list[BaseSchedulerNode]: + """ + Combine eligible nodes into FusedSchedulerNodes. + + This relies on two key functions to control the logic: + - self.can_fuse(): checks if a fusion is legal + - self.score_fusion(): assigns priority to a given fusion + """ + self.prune_redundant_deps(nodes) + fused_nodes = OrderedSet(nodes) + if fusion_log.isEnabledFor(logging.DEBUG): + fusion_log.debug("fuse_nodes_once, candidates:") + for node in fused_nodes: + fusion_log.debug(" %s", node.debug_str_short()) + + # These are potential fusions which we are async compiling, + # and which we will benchmark profitability of. 
+ pending_fusions: dict[ + BaseSchedulerNode, + tuple[Callable[[], bool], BaseSchedulerNode, BaseSchedulerNode], + ] = {} + + def fuse_two_nodes( + node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> BaseSchedulerNode: + fusion_log.debug("fusing %s with %s", node1.get_name(), node2.get_name()) + + device = node1.get_device() + assert node2.get_device() == device + node3 = self.get_backend(device).fuse(node1, node2) + fused_nodes.remove(node1) + fused_nodes.remove(node2) + fused_nodes.add(node3) + self.name_to_fused_node.update( + {n.get_name(): node3 for n in node3.get_nodes()} + ) + return node3 + + def resolve_pending_fusions( + node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> None: + while ( + self.get_fused_node(node1) in pending_fusions + or self.get_fused_node(node2) in pending_fusions + ): + pending_fusion = pending_fusions.get( + self.get_fused_node(node1), + pending_fusions.get(self.get_fused_node(node2), None), + ) + assert pending_fusion is not None + + is_speedup, node_key1, node_key2 = pending_fusion + pending_fusions.pop(node_key1, None) + pending_fusions.pop(node_key2, None) + + assert self.get_fused_node(node_key1) is node_key1 + assert self.get_fused_node(node_key2) is node_key2 + + if not is_speedup() or self.will_fusion_create_cycle(node1, node2): + continue + + fuse_two_nodes(node_key1, node_key2) + + for node1, node2 in self.get_possible_fusions(nodes, is_reorder_round): + # if either node is in a pending fusion, resolve it. + # since we iterate on potential fusions based on profitability + # the first potential fusion should take precedence. 
+ resolve_pending_fusions(node1, node2) + node1 = self.get_fused_node(node1) + node2 = self.get_fused_node(node2) + + if self.can_fuse( + node1, node2, is_reorder_round + ) and not self.will_fusion_create_cycle(node1, node2): + speedup = self.speedup_by_fusion(node1, node2) + if callable(speedup): + pending_fusions[node1] = (speedup, node1, node2) + pending_fusions[node2] = (speedup, node1, node2) + continue + + if not speedup: + continue + + fuse_two_nodes(node1, node2) + + seen_pair_speedup_fn: OrderedSet[Callable[[], bool]] = OrderedSet() + for is_speedup_fn, node_key1, node_key2 in pending_fusions.values(): + if is_speedup_fn in seen_pair_speedup_fn: + continue + + seen_pair_speedup_fn.add(is_speedup_fn) + + assert self.get_fused_node(node_key1) is node_key1 + assert self.get_fused_node(node_key2) is node_key2 + + if is_speedup_fn() and not self.will_fusion_create_cycle( + node_key1, node_key2 + ): + fuse_two_nodes(node_key1, node_key2) + + nodes = sorted(fused_nodes, key=lambda x: x.min_order) + nodes = self.topological_sort_schedule(nodes) + return nodes + + def create_combo_kernel_nodes(self, num_ck_nodes: Optional[int] = None) -> None: + """ + Groups parallel nodes + """ + fused_nodes = OrderedSet(self.nodes) + count = 0 + num_nodes_orig = len(self.nodes) + log.debug("ComboKernels: Generating with num_ck_nodes = %s...", num_ck_nodes) + for num, node_list in enumerate( + ForeachKernelSchedulerNode.group_nodes_for_combo_kernels(self) + ): + node_list = ForeachKernelSchedulerNode.combinable_nodes(node_list) + if len(node_list) < 2: + continue + if num_ck_nodes is not None and count > num_ck_nodes: + break + if not self.speedup_by_combo_kernel(node_list): + log.debug("ComboKernels: Not speeding up %d-th group", num) + continue + count += 1 + enable_autotune = config.combo_kernels_autotune > 0 + group_snode = ForeachKernelSchedulerNode( + node_list[0].scheduler, + node_list, + use_custom_partition_algo=True, + enable_autotune=enable_autotune, + ) + log.info( + 
"ComboKernels: Combining %d nodes for %d-th group", + len(node_list), + num, + ) + for node in node_list: + fused_nodes.remove(node) + fused_nodes.add(group_snode) + self.name_to_fused_node.update( + {n.get_name(): group_snode for n in group_snode.get_nodes()} + ) + self.nodes = sorted(fused_nodes, key=lambda x: x.min_order) + self.nodes = self.topological_sort_schedule(self.nodes) + log.info( + "Generated ComboKernel nodes: %d ComboKernels, totally %d -> %d nodes", + count, + num_nodes_orig, + len(self.nodes), + ) + self.prune_redundant_deps(self.nodes) + + def prune_redundant_deps(self, nodes: list[BaseSchedulerNode]) -> None: + for node in nodes: + node.prune_redundant_deps(self.name_to_fused_node) + + def get_possible_fusions( + self, + nodes: list[BaseSchedulerNode], + is_reorder_round: bool, + ) -> list[tuple[BaseSchedulerNode, BaseSchedulerNode]]: + """ + Helper to find all legal fusion opportunities, sorted by self.score_fusion() + """ + possible_fusions = [] + seen = OrderedSet[tuple[BaseSchedulerNode, BaseSchedulerNode]]() + + def check_all_pairs(nodes: list[BaseSchedulerNode]) -> None: + for node1_index, node1 in enumerate(nodes): + for node2 in nodes[ + node1_index + 1 : node1_index + + 1 + + config.max_fusion_buffer_group_pairwise_attempts + ]: + key = (node1, node2) + if key in seen: + continue + seen.add(key) + + if self.can_fuse(node1, node2, is_reorder_round): + possible_fusions.append(key) + elif (node2.is_template() or node2.is_foreach()) and self.can_fuse( + node2, node1, is_reorder_round + ): + # foreach fusions and epilogue fusions are order dependent + possible_fusions.append((node2, node1)) + + buffer_names_grouping = collections.defaultdict(list) + for node in nodes: + if self.unfusable_node(node): + continue + for buf in node.used_buffer_names(): + buffer_names_grouping[buf].append(node) + for node_grouping in buffer_names_grouping.values(): + check_all_pairs(node_grouping) + + if config.aggressive_fusion: + group_grouping = 
collections.defaultdict(list) + for node in nodes: + group = getattr(node, "group", None) + if group: + group_grouping[group].append(node) + for node_grouping in group_grouping.values(): + check_all_pairs(node_grouping) + + possible_fusions = self.get_possible_fusions_with_highest_priority( + possible_fusions + ) + possible_fusions.sort(key=self.score_fusion_key, reverse=True) + fusion_log.debug("found %d possible fusions", len(possible_fusions)) + return possible_fusions + + def will_fusion_create_cycle( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> bool: + """ + Finds whether there's a path from node1 to node2 (or vice-versa) + caused indirectly by other fusions. + """ + # since we are just returning boolean here, use slightly faster, unordered set + visited = OrderedSet[FusedSchedulerNode]() + + def found_path(node: BaseSchedulerNode) -> bool: + # only fused nodes can introduce new ancestors. + if isinstance(node, FusedSchedulerNode) and node not in visited: + visited.add(node) + if node.get_operation_names().issubset(combined_ancestors): + # All fusion outputs are in ancestors of node1 and node2, thus + # cannot introduce new path: + # + # 1. if output is neither descendent of node1 or node2, the + # output cannot introduce a path + # 2. due to [can_fuse]: if WLOG output is descendent of node1, it cannot be + # on path(node1->node2), hence it cannot be ancestor of node2 + # 3. 
due to [acyclic]: if WLOG output is descendent of node1, it cannot be + # ancestor of node1 + return False + else: + # continue DFS of new ancestors introduced by the fusion + return bool(combined_names & node.ancestors) or any( + found_path(self.name_to_fused_node[n]) + for n in node.ancestors - combined_ancestors + ) + return False + + # as above - use slightly faster, unordered set + combined_names = ( + node1.get_operation_names()._dict.keys() + | node2.get_operation_names()._dict.keys() + ) + combined_ancestors = ( + node1.ancestors._dict.keys() | node2.ancestors._dict.keys() + ) - combined_names + cycle = any(found_path(self.name_to_fused_node[n]) for n in combined_ancestors) + if cycle: + WhyNoFuse(node1, node2)("will create cycle") + return cycle + + def can_fusion_increase_peak_memory( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> bool: + """ + Return true if fusing the two nodes can potentially increasing peak memory. + + The implementation is more like a heuristic since we don't really know if we are at peak + or not when trying to fuse these two nodes. The order of nodes may change later which makes the + peak memory estimation hard. + + Here is how we decide the LOWER BOUND of extra memory allocation if we fuse these 2 nodes: + 1. find all buffers read by each node with a single user. These buffers are supposed to + be reused if we don't fuses these 2 nodes + 2. find the intersection of these buffers for the two node and sum the total buffer size. + If we don't fuse these two nodes, we can at lease avoid this much memory allocation. + Note that the extra memory allocation is not necessarily causing peak memory increase. + This is just a heuristic. + + We return true only if the saving for fusion can not trade off the extra memory allocation. 
+ """ + + from .codegen.wrapper import buffer_reuse_key + + def _find_single_user_inputs( + node: BaseSchedulerNode, + ) -> list[ir.Buffer]: + output = [] + for rd in node.read_writes.reads: + buf = self.name_to_buf.get(rd.name) + if buf and len(buf.users) == 1 and buf.node.has_tensor_output(): + output.append(buf.node) + return output + + # Check inputs that can be potentially reused + lhs_dep_nodes = _find_single_user_inputs(node1) + rhs_dep_nodes = _find_single_user_inputs(node2) + + lhs_reuse_keys = OrderedSet(buffer_reuse_key(buf) for buf in lhs_dep_nodes) + rhs_reuse_keys = OrderedSet(buffer_reuse_key(buf) for buf in rhs_dep_nodes) + + common_reuse_keys = lhs_reuse_keys.intersection(rhs_reuse_keys) + + memory_overhead = 0 + for key in common_reuse_keys: + try: + memory_overhead += int(key[2]) + except ValueError: + # not an integer. Fallback is to fuse + return False + + bw_saving = self.score_fusion_memory(node1, node2) + + # The factor 32 here is quite arbitrary. + if V.graph.sizevars.statically_known_gt(memory_overhead, 32 * bw_saving): + return True + return False + + def fusion_prevent_too_many_reads_and_writes( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode, threshold: int + ) -> bool: + # After fusion, we need to calculate the unique I/O buffers + # accounting for buffers that become internal (removed through fusion) + + # Get all nodes that will be in the fused node + fused_node_names = OrderedSet( + [node.get_name() for node in node1.get_nodes()] + + [node.get_name() for node in node2.get_nodes()] + ) + + # Calculate node2 reads that can be removed through fusion, + # i.e. node2 reads that are outputs of node1 + node1_write_names = OrderedSet(dep.name for dep in node1.read_writes.writes) + node2_read_names = OrderedSet(dep.name for dep in node2.read_writes.reads) + reads_removed_through_fusion = node2_read_names & node1_write_names + + # Calculate node1 writes that can be removed through fusion, + # i.e. 
node1 writes that are only read by node2 + writes_removed_through_fusion: OrderedSet[str] = OrderedSet() + for write_dep in node1.read_writes.writes: + if self.can_buffer_be_removed_through_fusion( + write_dep.name, fused_node_names + ): + writes_removed_through_fusion.add(write_dep.name) + + # Get all unique reads (union of both nodes' reads) + all_read_names = OrderedSet( + dep.name for dep in node1.read_writes.reads + ) | OrderedSet(dep.name for dep in node2.read_writes.reads) + + # Get all unique writes (union of both nodes' writes) + all_write_names = OrderedSet( + dep.name for dep in node1.read_writes.writes + ) | OrderedSet(dep.name for dep in node2.read_writes.writes) + + # Remove reads that become internal + unique_reads = all_read_names - reads_removed_through_fusion + + # Remove writes that become internal + unique_writes = all_write_names - writes_removed_through_fusion + + # Get all unique buffer names (reads and writes combined, but no double counting) + unique_io_buffers = unique_reads | unique_writes + + return len(unique_io_buffers) > threshold + + def are_long_distant_nodes( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> bool: + """ + This function prevents fusion for nodes that can increase memory + footprint. This problem is more common in horizontal fusion, where nodes + that are far apart in the original order get fused, lengthening the live + intervals of tensors. This is very evident in models with activation + checkpointing, where the recomputed nodes from different checkpointed + regions get fused and significantly increase the memory footprint. + + The current attempt is a quick, possibly hacky, heuristic to prevent the + fusion of nodes that are far away in the original order. + + A better but difficult to implement heurisitic would be to use live + intervals of the buffers, find region of peak pressure in the original + program and prevent fusion that crosses that peak region. 
We might need + special care or good approximation in this implementation, as fusion of + node changes live intervals, and re-computing live intervals and peak + memory after each fusion can introduce large compilation overhead. + """ + proximity_score = max( + abs(node1.min_order - node2.max_order), + abs(node2.min_order - node1.max_order), + ) + return proximity_score > 64 + + def decide_fusion_fail_reason( + self, + node1: BaseSchedulerNode, + node2: BaseSchedulerNode, + common_buf_names: Union[tuple[str, ...], OrderedSet[str]], + ) -> str: + """ + Try to decide reasons why fusion fail due to no shared memory even though + there are common buffers. + """ + reasons = {} + node1_name2dep = {dep.name: dep for dep in node1.read_writes.reads_and_writes()} + node2_name2dep = {dep.name: dep for dep in node2.read_writes.reads_and_writes()} + + for buf_name in common_buf_names: + buf = V.graph.get_buffer(buf_name) + lhs_dep = node1_name2dep[buf_name] + rhs_dep = node2_name2dep[buf_name] + + if not isinstance(lhs_dep, MemoryDep) or not isinstance(rhs_dep, MemoryDep): + reasons[buf_name] = ( + f"not MemoryDep: {type(lhs_dep)} v.s. {type(rhs_dep)}" + ) + continue + + if lhs_dep.get_numel() != rhs_dep.get_numel(): + reasons[buf_name] = ( + f"different numel: {lhs_dep.get_numel()} v.s. {rhs_dep.get_numel()}" + ) + continue + + # same numel but different MemoryDep.size. Should be broadcasting + if sympy_product(lhs_dep.size) != sympy_product(rhs_dep.size): + reasons[buf_name] = "broadcast" + continue + + lhs_off = lhs_dep.get_offset() + rhs_off = rhs_dep.get_offset() + if lhs_off != rhs_off: + # One example is in transformer, we use a concatenated linear layer + # to project Q/K/V and then split the result. The 3 splits will + # point to the same buffer with different offsets. + reasons[buf_name] = f"different offset: {lhs_off} v.s. 
{rhs_off}" + continue + + if ( + lhs_dep.normalize_with_stride_order() + == rhs_dep.normalize_with_stride_order() + ): + reasons[buf_name] = f"Mismatch loop orders: {lhs_dep} v.s. {rhs_dep}" + continue + + # Add more rules here + layout_str = "" + if not isinstance(buf, ir.TorchBindObject): + layout_str = f"Layout: {buf.layout}" + reasons[buf_name] = ( + f"Unknown reason: {lhs_dep} v.s. {rhs_dep}. {layout_str}" + ) + + return str(reasons) + + def shared_data_after_inverting_indexing( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> int: + """ + Attempts to enable fusion between two nodes by inverting indexing patterns. + + This optimization targets cases where node1 has a contiguous write and + node2 has a contiguous write but discontiguous read. By inverting the + indexing in node2's read and write operations, we can make them compatible + with node1 for potential fusion. + + Args: + node1: First scheduler node (source) + node2: Second scheduler node (target for inversion) + + Returns: + int: Fusion score if successful, 0 if optimization not applicable + """ + + if not config.loop_index_inversion_in_fusion: + return -1 + + if any(n.is_cpu() for n in [node1, node2]): + return -1 + + # Check for shared buffers between nodes + node1_buffer_names = node1.read_writes.buffer_names() + node2_buffer_names = node2.read_writes.buffer_names() + common_buffer_names = node1_buffer_names & node2_buffer_names + + if not common_buffer_names: + return -1 + + # only invert if node1 is single unmet dep + node2_unmet_dependencies = OrderedSet( + dep.name for dep in node2.unmet_dependencies + ) + if node2_unmet_dependencies - node1_buffer_names: + return -1 + + if len(node2_unmet_dependencies) > 1: + return -1 + + # Currently only handle single read/write operations + if len(node2.read_writes.reads) > 1 or len(node2.read_writes.writes) > 1: + return -1 + + node2_read = next(iter(node2.read_writes.reads)) + node2_write = next(iter(node2.read_writes.writes)) + + if not 
isinstance(node2_read, MemoryDep) or not isinstance( + node2_write, MemoryDep + ): + return -1 + + node1_writes = {dep.name: dep for dep in node1.read_writes.writes} + if node2_read.name not in node1_writes: + return -1 + + node1_write = node1_writes[node2_read.name] + + if not isinstance(node1_write, MemoryDep): + return -1 + + # We are checking for compatibility with the normalized node1 write + # then modifying node2 reads/writes. since the node1 write will be just used + # for compatibility, while node2 will be used in actual modification, just + # normalize node1 not node2. + node1_write = node1_write.normalize() + + if ( + node1_write.index != node2_write.index + and node1_write.size != node2_write.size + ): + return -1 + + if node2_read.size != node2_write.size or len(node2_read.var_names) != 1: + return -1 + + # Verify we have exactly two indexing expressions (one read, one write) + if len(node2._body.indexing_exprs) != 2: # type: ignore[attr-defined] + return -1 + + # No subblocks allowed for this optimization + if node2._body.subblocks: # type: ignore[attr-defined] + return -1 + + assert ( + "index0" in node2._body.indexing_exprs # type: ignore[attr-defined] + and "index1" in node2._body.indexing_exprs # type: ignore[attr-defined] + ) + + # Extract and verify single read expression + node2_read_exprs = OrderedSet(expr for expr in node2._body.get_read_exprs()) # type: ignore[attr-defined] + if len(node2_read_exprs) != 1: + return -1 + + read_expr = next(iter(node2_read_exprs)) + + # Determine which index is for reading vs writing + if read_expr == node2._body.indexing_exprs["index0"]: # type: ignore[attr-defined] + read_expr_index = "index0" + write_expr_index = "index1" + else: + assert read_expr == node2._body.indexing_exprs["index1"] # type: ignore[attr-defined] + read_expr_index = "index1" + write_expr_index = "index0" + + from torch._inductor.invert_expr_analysis import generate_inverse_formula + + index_vars = node2._body.vars[0] # type: 
ignore[attr-defined] + if len(index_vars) != 1: + return -1 + + simplified_terms = [] + for term in sympy.Add.make_args(read_expr): + simplified_terms.append( + V.graph.sizevars.combine_modular_indexing_pairs(term) + ) + simplified_read_expr = sum(simplified_terms) + + inverse_formula = generate_inverse_formula(simplified_read_expr, index_vars[0]) + + # formula is not invertible + if inverse_formula is None: + return -1 + + # === Apply Inversion === + + # Swap the indexing expressions using the inverse formula + node2._body.indexing_exprs[read_expr_index] = node2._body.indexing_exprs[ # type: ignore[attr-defined] + write_expr_index + ] + node2._body.indexing_exprs[write_expr_index] = inverse_formula # type: ignore[attr-defined] + + # Refresh dependencies and calculate fusion score + node2.refresh_dependencies(True, False) # type: ignore[attr-defined] + score = self.score_fusion_memory(node1, node2) + assert isinstance(score, int) + + fusion_log.info("Shared memory after inversion: %d", score) + return score + + def shared_data_after_reordering_loop( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> int: + """ + Right now just greedily reorder the loop of node1 to be compatible with node2, + but ideally we should have some heuristics to reorder the loop for node2 + to be compatible with node1 if that's more efficient. + + Return the amount of shared data re-computed in this method. + If no such recomputation happens, return -1 (not return 0 since 0 is a valid + amount of shared data). + + """ + + # TODO Don't do loop reordering for CPU for now. + # Should debug more why it does not work for CPU codegen + if not config.loop_ordering_after_fusion or any( + n.is_cpu() for n in [node1, node2] + ): + return -1 + + # in some rare case, a template can be passed in. 
+ # Check test_interaction_with_multi_template in test_loop_ordering.py + # and https://github.com/pytorch/pytorch/issues/165579 + if node1.is_template() or node2.is_template(): + return -1 + + node1_buffer_names = node1.read_writes.buffer_names() + node2_buffer_names = node2.read_writes.buffer_names() + # Fast path: no common buffers. + common_buffer_names = node1_buffer_names & node2_buffer_names + if not common_buffer_names: + return -1 + + node1_name2dep = {dep.name: dep for dep in node1.read_writes.reads_and_writes()} + node2_name2dep = {dep.name: dep for dep in node2.read_writes.reads_and_writes()} + + # Find the commons buffers that has different loop orders + candidates = [] + for buffer_name in common_buffer_names: + lhs_dep = node1_name2dep[buffer_name] + rhs_dep = node2_name2dep[buffer_name] + if ( + lhs_dep.normalize_with_stride_order() + == rhs_dep.normalize_with_stride_order() + ): + candidates.append( + ( + V.graph.sizevars.size_hint(lhs_dep.get_numel(), fallback=0), + lhs_dep, + rhs_dep, + ) + ) + + if len(candidates) == 0: + return -1 + + # Pick the largest buffer to guide the loop reordering + _numel, lhs_dep, rhs_dep = max(candidates, key=operator.itemgetter(0)) + + if not isinstance(lhs_dep, MemoryDep) or not isinstance(rhs_dep, MemoryDep): + return -1 + + if lhs_dep.num_vars != rhs_dep.num_vars: + # this can happen due to we don't merge loops. + # We can not do loop reordering in this case right now + # Simply returning true if the two Deps are the same after + # normalization (merging loops) + if lhs_dep.normalize() == rhs_dep.normalize(): + return self.dep_size_hint(lhs_dep) + return -1 + + reordered = False + # Only reorder loops for pointwise for now + if not node1.is_reduction(): + reordered = node1.reorder_loops_by_dep_pair(lhs_dep, rhs_dep) + elif not node2.is_reduction(): + reordered = node2.reorder_loops_by_dep_pair(rhs_dep, lhs_dep) + else: + loop_ordering_log.debug( + "Don't reorder loops since both nodes are reductions: %s v.s. 
%s", + node1.get_name(), + node2.get_name(), + ) + + return ( + typing.cast(int, self.score_fusion_memory(node1, node2)) + if reordered + else -1 + ) + + def unfusable_node(self, node: BaseSchedulerNode) -> bool: + """ + Is this node unfusable under any conditions. + """ + return ( + isinstance(node, (ExternKernelSchedulerNode, NopKernelSchedulerNode)) + and not node.is_template() + and not is_output_of_multi_outputs_template(node.node) + ) + + def check_prologue_fusion_heuristics_fusable( + self, + prologue_node: BaseSchedulerNode, + template_node: BaseSchedulerNode, + why: WhyNoFuse, + ) -> bool: + """ + Heuristics to avoid benchmarking predictably slow prologue fusions + """ + # user opt into more aggressive prologue fusion, dont use heuristics + if prologue_node.get_operation_names() <= V.graph.invoke_quant_ops: + return True + + read_bytes = prologue_node.get_read_buffer_sizes() + write_bytes = prologue_node.get_write_buffer_sizes() + + # Initially, only do fusions which will result in fewer memory accesses inside of the template to avoid + # potential bad cache behavior and shared memory use. + # we also want to avoid benchmarking reliably unprofitable fusions like downcasts from fp32 -> fp16 inside kernel. + # allowing gathers by allowing increasing write_bytes by small factor + # TODO - make configurable per input, for instance, bias can fuse fp32 -> fp16 profitably + + BYTES_THRESHOLD_MULTIPLIER = 1.1 + if read_bytes > (write_bytes * BYTES_THRESHOLD_MULTIPLIER): + why("prologue fusion will not increase amount of bytes read in kernel") + return False + + # we want to avoid attempting to fuse predictably unprofitable prologues + # such as increasing the unaligned reads or writes. + # TODO - would be nice to generalize this, however, we would need more explicit + # knowledge of memory access patterns in the TritonTemplate in order to know + # the stride order to check alignment. 
+ origins = tuple( + e.target + for n in prologue_node.get_nodes() + if n.node is not None + for e in n.node.get_origins() + if e.op == "call_function" + ) + if origins == (torch.ops.aten.constant_pad_nd.default,): + why( + "prologue fusion will not increase attempt to fuse in padding bc it increases unaligned reads" + ) + return False + + def low_prec_fp(dtype: torch.dtype) -> bool: + return dtype.itemsize <= 2 and dtype.is_floating_point + + if ( + low_prec_fp(template_node.get_template_node_or_throw().dtype) + and not prologue_node.can_codegen_in_low_precision() + ): + why( + "prologue fusion that must be upcast to fp32 not profitable for low precision templates" + ) + return False + + return True + + def get_expand_dim_for_pointwise_nodes( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> Optional[tuple[int, SchedulerNode, sympy.Expr]]: + """ + Fusing two small pointwise nodes significantly reduces kernel overhead + and launch overhead. However, slightly different sizes would prevent fusion. + Here, we decide if expanding sizes of one node is profitible by allowing + fusion, and returns the dimension to expand, node with smaller sizes, + and new size after expand. + """ + # only support scheduler node + if not isinstance(node1, SchedulerNode) or not isinstance(node2, SchedulerNode): + return None + + # only support computued buffer + if not ( + isinstance(node1.node, ir.ComputedBuffer) + and isinstance(node2.node, ir.ComputedBuffer) + ): + return None + + # does not support mutation yet since relying on index mod to handle + # out-of-boundary access. 
+ if node1.has_aliasing_or_mutation() or node2.has_aliasing_or_mutation(): + return None + + # skip halide which does not support mod for index + if config.cpu_backend == "halide": + return None + + # only support pointwise nodes with the same reduction size + n1_sizes, n2_sizes = node1._sizes, node2._sizes + n1_iter_sizes, n1_reduce_sizes = n1_sizes + n2_iter_sizes, n2_reduce_sizes = n2_sizes + if ( + node1.is_reduction() + or node2.is_reduction() + or n1_reduce_sizes != n2_reduce_sizes + or len(n1_iter_sizes) != len(n2_iter_sizes) + ): + return None + + # only support nodes with 1 write for simplification + if len(node1.read_writes.writes) > 1 or len(node2.read_writes.writes) > 1: + return None + + # When memory access is small, reducing gpu kernel overhead is profitable over + # slightly larger memory access. + node1_write_memory = self.dep_size_hint(next(iter(node1.read_writes.writes))) + node2_write_memory = self.dep_size_hint(next(iter(node1.read_writes.writes))) + if ( + max(node1_write_memory, node2_write_memory) + > config.small_memory_access_threshold + ): + return None + + # does not support reinplace since `index % boundary` may lead to + # race condition + def has_reusable_buffer(node: BaseSchedulerNode) -> bool: + for read in node.read_writes.reads: + input_buf: Optional[Union[SchedulerBuffer, SchedulerDonatedBuffer]] + if read.name in self.name_to_donated_buffer: + input_buf = self.name_to_donated_buffer[read.name] + else: + input_buf = self.name_to_buf.get(read.name) + + if ( + input_buf + and V.graph.wrapper_code.can_reuse(input_buf, node) + and not isinstance(input_buf.defining_op, NopKernelSchedulerNode) + ): + return True + return False + + if has_reusable_buffer(node1) or has_reusable_buffer(node2): + return None + + # only support nodes with 1 mismatch dimension + mismatch_dimensions = [] + for idx, (n1_size, n2_size) in enumerate(zip(n1_iter_sizes, n2_iter_sizes)): + if n1_size != n2_size: + mismatch_dimensions.append(idx) + + if 
len(mismatch_dimensions) != 1: + return None + + mismatch_dim = mismatch_dimensions[0] + mismatch_size1, mismatch_size2 = ( + n1_iter_sizes[mismatch_dim], + n2_iter_sizes[mismatch_dim], + ) + if V.graph.sizevars.statically_known_lt(mismatch_size1, mismatch_size2): + return mismatch_dim, node1, mismatch_size2 + elif V.graph.sizevars.statically_known_lt(mismatch_size2, mismatch_size1): + return mismatch_dim, node2, mismatch_size1 + else: + return None + + def can_fuse( + self, + node1: BaseSchedulerNode, + node2: BaseSchedulerNode, + can_reorder: bool = False, + allow_mix_order_reduction: bool = True, + ) -> bool: + """ + Determine if it is possible to combine node1 and node2 into a + single fused node. + """ + if node1 is node2: + return False + + if isinstance(node1, FusedMixOrderReductions): + return node1.can_fuse_with(node2) + if isinstance(node2, FusedMixOrderReductions): + # We don't fuse something before a FusedMixOrderReductions + # right now + return False + + why = WhyNoFuse(node1, node2) + + if node1.is_template() and self.get_backend( + node1.get_device() + ).can_fuse_multi_outputs_template(node1, node2): + return True + + if isinstance(node1, GroupedSchedulerNode) or isinstance( + node2, GroupedSchedulerNode + ): + why("grouped node must not be fused with other nodes") + return False + if ( + isinstance(node1, (ExternKernelSchedulerNode, NopKernelSchedulerNode)) + and not node1.is_template() + ): + why("node1 is extern or nop") + return False + if ( + isinstance(node2, (ExternKernelSchedulerNode, NopKernelSchedulerNode)) + and not node2.is_template() + ): + why("node2 is extern or nop") + return False + + if node2.get_operation_names() & node1.ancestors: + why("node1 must go before node2") + return False + + if node2.is_template(): + if not config.prologue_fusion: + why("prologue fusion turned off") + return False + + if node1.is_reduction() or node1.is_template(): + why("prologue fusion only supported for pointwise nodes") + return False + + template = 
node2.get_template_node_or_throw() + if not isinstance(template, ir.TritonTemplateBuffer): + why("prologue fusion only supported for TritonTemplates") + return False + + allowed_prologue_inps = template.get_allowed_prologue_inps() + + unsupported_prologue_args = ( + OrderedSet(inp.get_name() for inp in template.inputs) # type: ignore[union-attr] + - allowed_prologue_inps + ) + + if node1.get_buffer_names() & unsupported_prologue_args: + why("prologue fusion not implemented for kernel for these inputs") + return False + + if node1.has_aliasing_or_mutation() or node1.has_aliasing_or_mutation(): + why("template prologue can only fuse functional pointwise nodes") + return False + + prologue_nodes = node1.get_nodes() + for node in prologue_nodes[:-1]: + node_outs = node.get_outputs() + for out in node_outs: + if not all(user.node in prologue_nodes for user in out.users): + why("template prologue can only fuse nodes with a single use") + return False + + template_snodes = ( + [node2] + if not isinstance(node2, FusedSchedulerNode) + else [n for n in node2.snodes if n.is_template()] + ) + assert len(template_snodes) == 1 + template_snode = template_snodes[0] + + if not ( + len(prologue_nodes[-1].outputs) == 1 + and len(prologue_nodes[-1].outputs[0].users) == 1 + and prologue_nodes[-1].outputs[0].users[0].node is template_snode + ): + why( + "template prologue can only fuse nodes with a single use into template" + ) + return False + + if not self.check_prologue_fusion_heuristics_fusable(node1, node2, why): + return False + + if node1.is_template() and ( + node2.has_aliasing_or_mutation() + or node2.is_reduction() + or not config.epilogue_fusion + ): + why("template epilogue not satisfied") + return False + + if (node1.get_buffer_names() & V.graph.no_fuse_buffer_names) or ( + node2.get_buffer_names() & V.graph.no_fuse_buffer_names + ): + why("fusion for buffer explicit disabled") + return False + device = node1.get_device() + device2 = node2.get_device() + if device != 
device2: + why("device mismatch (%s vs %s)", device, device2) + return False + del device2 + + shared_data_score = self.score_fusion_memory( + node1, node2, allow_mix_order_reduction=allow_mix_order_reduction + ) + assert isinstance(shared_data_score, int) + + if ( + can_reorder + and shared_data_score < config.score_fusion_memory_threshold + and config.loop_ordering_after_fusion + ): + new_shared_data_score = self.shared_data_after_reordering_loop(node1, node2) + if new_shared_data_score >= 0: + shared_data_score = new_shared_data_score + + if config.expand_dimension_for_pointwise_nodes and ( + expand_analysis := self.get_expand_dim_for_pointwise_nodes(node1, node2) + ): + (expand_dim, smaller_node, expand_size) = expand_analysis + smaller_node.expand_dimension_for_pointwise_node(expand_dim, expand_size) + shared_data_score = self.score_fusion_memory(node1, node2) + assert isinstance(shared_data_score, int) + + if ( + config.loop_index_inversion_in_fusion + and shared_data_score < config.score_fusion_memory_threshold + ): + new_shared_data_score = self.shared_data_after_inverting_indexing( + node1, node2 + ) + if new_shared_data_score >= 0: + shared_data_score = new_shared_data_score + + if loop_ordering_log.isEnabledFor(logging.DEBUG): + loop_ordering_log.debug( + "%s and %s has %s shared data", + node1.get_name(), + node2.get_name(), + shared_data_score, + ) + + if not V.choices.can_fuse(self, node1, node2, shared_data_score): + return False + + if node1.get_operation_names() & node2.ancestors: + # node2 depends on node1 outputs + return ( + self.can_fuse_vertical(node1, node2) + and V.choices.can_fuse_vertical(self, node1, node2, shared_data_score) + and self.get_backend(device).can_fuse_vertical(node1, node2) + ) + else: # nodes don't depend on each other, but may have common reads + return V.choices.can_fuse_horizontal( + self, node1, node2, shared_data_score + ) and self.get_backend(device).can_fuse_horizontal(node1, node2) + + def can_fuse_vertical( + self, 
node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> bool: + """ + Check if it is legal to fuse a consumer (node2) into a producer (node1). + + We can fuse them if all the reads of node2 either match + corresponding writes in node1, or are written by nodes that can + be scheduled before the fusion of node1 and node2. + """ + node1_buf_names = node1.get_buffer_names() + why = WhyNoFuse(node1, node2) + remaining_deps_by_name: dict[str, list[Dep]] = defaultdict(list) + + for dep in node2.unmet_dependencies: + name = self.mutation_renames.get(dep.name, dep.name) + if isinstance(dep, WeakDep) and self.fusable_weak_dep(dep, node1, node2): + continue + remaining_deps_by_name[name].append(dep) + + for cd in node1.read_writes.writes: + if not isinstance(cd, MemoryDep): + continue + remaining = remaining_deps_by_name.get( + self.mutation_renames.get(cd.name, cd.name) + ) + if remaining: + for rd in remaining: + if self.fusable_read_and_write(rd, cd): + remaining.remove(rd) # noqa: B909 + + remaining_deps = OrderedSet( + dep.name + for dep in itertools.chain.from_iterable(remaining_deps_by_name.values()) + ) + + if remaining_deps & node1_buf_names: + # MemoryDeps didn't match and read different locations of the same buffer. + # Examples here include: + # - MemoryDep("foo", x) != MemoryDep("foo", x + 1) + # - MemoryDep("foo", x) != StarDep("foo") + why("memory deps did not match") + return False + + node1_op_names = node1.get_operation_names() + for name in remaining_deps: + op_name = self.name_to_buf[name].defining_op_name() + if node1_op_names & self.name_to_fused_node[op_name].ancestors: + why("intermediate nodes between node1 & node2") + return False + + return True + + def fusable_weak_dep( + self, weak_dep: WeakDep, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> bool: + if weak_dep.name not in node1.get_buffer_names(): + return False + + # A weak dep can be fused if and only if the fused operation acts inplace + # on the buffer being mutated. i.e. 
the same index is being read then mutated + mutating_writes = [ + write + for write in node2.read_writes.writes + if write.name == weak_dep.mutating_buf + ] + if len(mutating_writes) != 1: + return False + write = mutating_writes[0] + if isinstance(write, StarDep): + return False + assert isinstance(write, MemoryDep) + + if free_symbol_is_type(write.index, SymT.TMP): + return False + + real_name = self.mutation_real_name[weak_dep.mutating_buf] + relevant_reading_nodes = [node1] + if isinstance(node1, ForeachKernelSchedulerNode): + relevant_reading_nodes = node1.snodes + num_concurrent_reads = 0 + for reading_node in relevant_reading_nodes: + relevant_reads = [ + read + for read in reading_node.read_writes.reads + if read.name == real_name + ] + if not relevant_reads: + continue + num_concurrent_reads += 1 + if not all( + isinstance(read, MemoryDep) + and not free_symbol_is_type(read.index, SymT.TMP) + and read.index == write.index + and read.size == write.size + for read in relevant_reads + ): + return False + return num_concurrent_reads <= 1 + + # StarDep doesn't match MemoryDep, different indices don't match + # However, broadcasting sometimes strips dimensions, and if that's the case + # we still can match unmet dep + # if there's indirect indexing, don't match it + def fusable_read_and_write(self, read: Dep, write: MemoryDep) -> bool: + if isinstance(read, MemoryDep): + read_name = self.mutation_renames.get(read.name, read.name) + + if ( + read_name != write.name + or free_symbol_is_type(read.index, SymT.TMP) + or free_symbol_is_type(write.index, SymT.TMP) + ): + return False + + if config.loop_ordering_after_fusion and read.num_vars != write.num_vars: + # Need merge loops if we do loop ordering after fusion since + # we have not merged the loops yet when creating the scheduler + # nodes. 
+ read = read.normalize() + write = write.normalize() + + return ( + read.index == write.index + and len(read.size) >= len(write.size) + and read.size[: len(write.size)] == write.size + ) + elif isinstance(read, StarDep): + read_name = self.mutation_renames.get(read.name, read.name) + write_name = self.mutation_renames.get(write.name, write.name) + if ( + read.mode == write.mode + and write.mode is not None + and read_name == write_name + ): + return True + return False + + def dep_size_hint(self, dep: Dep, count_bytes: bool = True) -> int: + return V.graph.get_dep_size_hint(dep, count_bytes) + + def score_fusion_memory( + self, + node1: BaseSchedulerNode, + node2: BaseSchedulerNode, + count_bytes: bool = True, + return_is_mix_order_reduction: bool = False, + allow_mix_order_reduction: bool = True, + ) -> int | tuple[int, bool]: + """ + The first term in our fusion score that estimates number of saved + memory operations. + """ + + def _construct_return_value(score, is_mix_order_reduction): + return ( + (score, is_mix_order_reduction) + if return_is_mix_order_reduction + else score + ) + + if allow_mix_order_reduction and MixOrderReduction.can_fuse(node1, node2): + # The fusion score for mix order reduction only count + # numel so far. It's actually fine. This makes other fusions + # sharing the same amount of numels go first; but make + # fusions only share weight/bias go later. 
+ score = MixOrderReduction.get_fusion_score(node1, node2) + return _construct_return_value(score, True) + + node1_dep_len = len(node1.read_writes.reads) + len(node1.read_writes.writes) + node2_dep_len = len(node2.read_writes.reads) + len(node2.read_writes.writes) + + # optimization: iter over smaller set + if min(node1_dep_len, node2_dep_len) * 4 < max(node1_dep_len, node2_dep_len): + if node1_dep_len > node2_dep_len: + node1, node2 = node2, node1 + + deps = [ + dep + for dep in node1.read_writes.reads | node1.read_writes.writes + if dep in node2.read_writes.reads or dep in node2.read_writes.writes + ] + + return _construct_return_value( + sum(self.dep_size_hint(dep, count_bytes) for dep in deps), False + ) + + common_memory_deps = (node1.read_writes.reads | node1.read_writes.writes) & ( + node2.read_writes.reads | node2.read_writes.writes + ) + return _construct_return_value( + sum(self.dep_size_hint(dep) for dep in common_memory_deps), False + ) + + def get_possible_fusions_with_highest_priority( + self, possible_fusions: list[tuple[BaseSchedulerNode, BaseSchedulerNode]] + ) -> list[tuple[BaseSchedulerNode, BaseSchedulerNode]]: + # Group the possible fusions based on their priority from the backend. + # Only return the group of possible fusions with highest priority. 
+ if len(possible_fusions) == 0: + return possible_fusions + possible_fusions_group_by_priority: dict[ + int, list[tuple[BaseSchedulerNode, BaseSchedulerNode]] + ] = {} + + for node1, node2 in possible_fusions: + assert node1.get_device() == node2.get_device() + device = node1.get_device() + fusion_pair_priority = int( + self.get_backend(device).get_fusion_pair_priority(node1, node2) + ) + if fusion_pair_priority not in possible_fusions_group_by_priority: + possible_fusions_group_by_priority[fusion_pair_priority] = [ + (node1, node2), + ] + else: + possible_fusions_group_by_priority[fusion_pair_priority].append( + (node1, node2) + ) + # return the possible fusions with highest priority + possible_fusions_with_highest_priority = min( + possible_fusions_group_by_priority.items(), key=operator.itemgetter(0) + )[1] + assert len(possible_fusions_with_highest_priority) > 0 + return possible_fusions_with_highest_priority + + def score_fusion_key( + self, nodes: tuple[BaseSchedulerNode, BaseSchedulerNode] + ) -> Any: + """ + Shim for list.sort(key=...) 
+ """ + return V.choices.score_fusion(self, *nodes) + + def compute_last_usage(self) -> None: + """ + Populate node.last_usage recursively (also for the nodes within a FusedSchedulerNode) + """ + + future_used_buffers = OrderedSet(V.graph.get_output_names()) + + for node in reversed(self.nodes): + node.set_last_usage(future_used_buffers, self.mutation_real_name) + future_used_buffers.update(node.last_usage) + + def free_buffers(self) -> None: + """Free any buffers that are no longer needed""" + for name in sorted( + self.buffer_names_to_free + - V.graph.removed_buffers + - V.graph.wrapper_code.freed # type: ignore[has-type] + ): + if name in self.name_to_buf: + buf = self.name_to_buf[name] + if buf.can_free(): + V.graph.wrapper_code.codegen_free(buf.node) + elif name in V.graph.graph_inputs: + inp = V.graph.graph_inputs[name] + if isinstance(inp, ir.TorchBindObject): + V.graph.wrapper_code.codegen_free(inp) + elif isinstance(inp, ir.GeneratorState): + continue + else: + storage = inp.data + assert ( + isinstance(storage, ir.StorageBox) and storage.is_input_buffer() + ) + V.graph.wrapper_code.codegen_free(storage.data) + + self.buffer_names_to_free.clear() + + def flush(self) -> None: + for backend in self.backends.values(): + backend.flush() + self.free_buffers() + + def codegen_extern_call(self, scheduler_node: ExternKernelSchedulerNode) -> None: + assert isinstance(scheduler_node, ExternKernelSchedulerNode) + # 'decide_inplace_update' stores the inplace update decisions in + # the current kernel from where 'allocate' retrieve those decisions. + # We have to make sure there is a non-NULL kernel handler to store + # those inplace update decisions. 
+ counters["inductor"]["extern_calls"] += 1 + with V.set_kernel_handler(Kernel(increase_kernel_count=False)): + scheduler_node.decide_inplace_update() + scheduler_node.mark_run() + node = scheduler_node.node + assert isinstance(node, ir.ExternKernel), f"{type(node)=}" + node.codegen(V.graph.wrapper_code) + self.free_buffers() + + def create_backend(self, device: torch.device) -> BaseScheduling: + assert not is_gpu(device.type) or device.index is not None, ( + f"{device} should have been normalized in lowering" + ) + V.graph.add_device_info(device) + + device_scheduling = get_scheduling_for_device(device.type) + if device_scheduling is None: + raise RuntimeError(f"Unsupported device type: {device.type}") + + if not has_triton(): + if ( + device.type == "cuda" + and (device_props := torch.cuda.get_device_properties(device)).major < 7 + ): + raise GPUTooOldForTriton(device_props, inspect.currentframe()) + elif is_gpu(device.type) and not device.type == "mps": + raise TritonMissing(inspect.currentframe()) + + return device_scheduling(self) + + def get_backend(self, device: Optional[torch.device]) -> BaseScheduling: + assert device is not None + if device not in self.backends: + self.backends[device] = self.create_backend(device) + return self.backends[device] + + def enter_context(self, node: BaseSchedulerNode) -> None: + def get_order(n: torch.fx.Node) -> int: + if n not in self.origin_to_index: + self.origin_to_index.update({n: i for i, n in enumerate(n.graph.nodes)}) + return self.origin_to_index[n] + + # Use a dict to have ordering + origins = { + (get_order(e), e): None + for n in node.get_nodes() + if n.node is not None + for e in n.node.get_origins() + } + origins = list(origins.keys()) + if origins: + _, last = max(origins, key=operator.itemgetter(0)) + V.graph.wrapper_code.enter_context(last) + + def can_buffer_be_removed_through_fusion( + self, name: str, fused_node_names: OrderedSet[str] + ) -> bool: + try: + users = self.name_to_buf[name].users + except 
KeyError: + return False + return ( + all(user.is_weak or user.get_name() in fused_node_names for user in users) + and name not in self.mutation_renames + and name not in self.mutation_real_name + ) + + def should_partition( + self, node: BaseSchedulerNode, should_log: bool = False + ) -> bool: + """Return True if we should partition the inductor graph on this node""" + + # Allow users to manually specify if a node should be partitioned + # Can only do this for FallbackKernels + ir_node = node.node + if isinstance(ir_node, torch._inductor.ir.FallbackKernel) and ( + op := ir_node.op_overload + ): + op_overload_packet_name = op.name() + op_overload_name = ( + f"{op_overload_packet_name}.{op._overloadname}" + if isinstance(op, torch._ops.OpOverload) + else op_overload_packet_name + ) + if ( + op_overload_packet_name in config.custom_should_partition_ops + or op_overload_name in config.custom_should_partition_ops + ): + assert isinstance(op, torch._ops.OpOverload) + return True + + # When not using cudagraphs, keep all kernels in the `call` function + # instead of graph partition functions, since graph partition only brings + # benefit to cudagraph + if ( + not torch._inductor.config.triton.cudagraphs + and _unstable_customized_partition_wrapper.wrapper is None + ): + return True + + # avoid duplicating logs when should_partition is called multiple times + # on the same node + def noop_log(msg: str, node: Optional[BaseSchedulerNode]) -> None: + return + + # Don't log partition reasons for CPU-only graphs since cudagraph + # partitioning is not relevant when there are no GPU devices + has_gpu_device = any(is_gpu(device) for device in V.graph.device_types) + log_partition_reason = ( + maybe_log_cudagraph_partition if should_log and has_gpu_device else noop_log + ) + + if isinstance(node, FusedSchedulerNode): + return any(self.should_partition(snode) for snode in node.snodes) + + assert node.node is not None + + if not node.is_gpu(): + log_partition_reason("non gpu ops", 
node=node) + + return True + + if isinstance(node.node, ir.DeviceCopy): + log_partition_reason("DeviceCopy ops", node=node) + return True + + if isinstance(node.node, ir.Conditional): + log_partition_reason("Conditional ops", node=node) + return True + + if getattr(node.node, "unbacked_bindings", None): + log_partition_reason("unbacked binding ops", node=node) + return True + + if is_cudagraph_unsafe_op(node.node): + log_partition_reason("CUDAGraph-unsafe custom ops", node=node) + return True + + # Partition around nodes with dynamic shapes when cudagraph_skip_dynamic_graphs is enabled + if config.triton.cudagraph_skip_dynamic_graphs: + if get_scheduler_node_symbol_uses(node): + log_partition_reason("dynamic shape ops", node=node) + return True + + return False + + def get_name_to_nodes( + self, + ) -> dict[str, Union[ir.IRNode, ir.TorchBindObject, sympy.Expr]]: + """ + Return a mapping from name strings to the corresponding graph inputs or + base scheduler node outputs. + """ + name_to_node: dict[str, Union[ir.IRNode, ir.TorchBindObject, sympy.Expr]] = {} + name_to_node.update(V.graph.graph_inputs) + + for node in self.nodes: + for name, scheduler_buffer in node.outputs_by_name.items(): + name_to_node[name] = scheduler_buffer.node + + return name_to_node + + def compute_graph_partition_maps( + self, + signatures: list[GraphPartitionSignature], + ) -> None: + """ + computes a mapping from partition input/output indices to graph input/output + indices for each partition. + """ + name_to_graph_input_index = { + name: idx for idx, name in enumerate(V.graph.graph_inputs) + } + name_to_graph_output_index = { + name: idx for idx, name in enumerate(V.graph.get_output_names()) + } + + V.graph.partition_maps = [] + for partition_id, signature in enumerate(signatures): + if signature.skip_cudagraph: + # Note: [Graph Partition Map for CUDAGraph] + # number of partition map should be the same as the number of generated + # partition functions. 
This assumption will be used when cudagraphify + # each partition function. + continue + + input_mapping = [] + for name in signature.input_nodes: + input_mapping.append(name_to_graph_input_index.get(name)) + + output_mapping = [] + for node in signature.output_nodes: + output_mapping.append(name_to_graph_output_index.get(node.get_name())) + + V.graph.partition_maps.append( + GraphPartitionMap( + partition_id, + input_mapping, + output_mapping, + signature.constant_names, + ) + ) + + def get_graph_partition_symbol_inputs( + self, + partition: PartitionType, + input_nodes: dict[str, Union[ir.IRNode, ir.TorchBindObject, sympy.Expr]], + ) -> OrderedSet[sympy.Symbol]: + """ + Returns all symbol inputs which are required to be in scope to successfully + perform codegen for this graph partition, including: + - free symbols used in partition nodes + - free symbols in partition input/node shapes, strides, and offsets. This is needed + for recording cudagraphs for tensors with dynamic shapes. + """ + + def get_input_node_symbols( + node: Union[ir.IRNode, sympy.Expr, ir.TorchBindObject], + ) -> OrderedSet[sympy.Symbol]: + """ + Gets symbols used in input node shapes, strides, and offsets. + """ + if isinstance(node, ir.TorchBindObject): + # TorchBindObject does not involve dynamic shapes yet + return OrderedSet() + elif isinstance(node, ir.IRNode): + return get_layout_symints(node) + else: + # node cannot be sympy.Expr since node comes from read_writes and + # read_writes does not contain sympy.Expr + raise NotImplementedError(f"Unsupported input node type: {type(node)}") + + def filter_symbols( + symbols: OrderedSet[sympy.Symbol], + ) -> OrderedSet[sympy.Symbol]: + """ + Filters a set of symbols that are required for codegen. Skip symbols + that are always internal to kernels, such as SymT.TMP, SymT.INDEX, + and SymT.R0_INDEX. 
+ """ + return OrderedSet( + s + for s in symbols + if symbol_is_type( + s, + ( + SymT.SIZE, + SymT.FLOAT, + SymT.UNBACKED_INT, + SymT.UNBACKED_FLOAT, + ), + ) + ) + + candidate_symbols: OrderedSet[sympy.Symbol] = OrderedSet().union( + *(get_scheduler_node_symbol_uses(node) for node in partition) + ) + candidate_symbols.union( + *(get_input_node_symbols(node) for _, node in input_nodes.items()) + ) + + candidate_symbols = filter_symbols(candidate_symbols) + + res: OrderedSet[sympy.Symbol] = OrderedSet() + for s in candidate_symbols: + symplified_s = V.graph.sizevars.simplify(s) + # use free_symbols only when s is simplified to an Integer or expr + res.update(symplified_s.free_symbols) + + return OrderedSet(sorted(res, key=operator.attrgetter("name"))) + + def get_graph_partition_signature( + self, partitions: list[PartitionType], skip_cudagraphs: list[bool] + ) -> list[GraphPartitionSignature]: + """ + Gets signature for each graph partition, including input nodes, output nodes, and + whether deallocating an input within graph partition. + """ + signatures = [] + + unmet_output_names = OrderedSet(V.graph.get_output_names()) + name_to_node = self.get_name_to_nodes() + + def is_unallocated_buffer(buf_name: str) -> bool: + """ + Checks if buf_name resolves to a NoneLayout buffer (following mutation_real_name). + Buffers with NoneLayout are not allocated so graph partition should not + take them as inputs or outputs. 
+ """ + buf = self.name_to_buf.get(buf_name, None) + + if buf is None: + return False + + if isinstance(buf.node.layout, NoneLayout): + # If there's a mutation real name, check the underlying buffer + # This handles both MutationOutput and other mutation ops like + # IndexPutFallback that have NoneLayout but mutate real buffers + if real_name := self.mutation_real_name.get(buf_name, None): + return is_unallocated_buffer(real_name) + + return True + + return False + + for partition, skip_cudagraph in zip( + reversed(partitions), reversed(skip_cudagraphs) + ): + output_names: OrderedSet[str] = OrderedSet() + + for node in partition: + output_names.update(node.outputs_by_name.keys()) + + returned_output_names = output_names.intersection(unmet_output_names) + + # all reads/writes are partition inputs except those generated + # within the partition and tensor constants + read_writes = dependencies.ReadWrites.merge_list( + [node.read_writes for node in partition] + ) + + # WeakDep is fake dependency on unused buffer. It should not appear + # in partition_input_names for inputs that are actually read or written. + partition_input_names = ( + OrderedSet( + [ + x.name + for x in read_writes.reads | read_writes.writes + if not isinstance(x, WeakDep) + ] + ) + - output_names + ) + + partition_input_names = OrderedSet( + self.mutation_real_name.get(name, name) + for name in partition_input_names + ) + + buffer_names_to_free: OrderedSet[str] = OrderedSet() + for node in partition: + buffer_names_to_free.update(node.last_usage) + + # buffer_names_to_free may contain buffers allocated in previous + # graph partitions. These buffers should also be a partition + # input. 
+ extra_input_names = [ + name + for name in (buffer_names_to_free - output_names) + if name in name_to_node + ] + partition_input_names.update(extra_input_names) + + input_nodes = { + name: name_to_node[name] + for name in partition_input_names + if name in name_to_node + } + input_deallocation = { + name: name in buffer_names_to_free + for name in partition_input_names + if name in name_to_node + } + + # if an input tensor is not freed in the partition function, it should + # also be returned as an output. This brings benefits to cudagraph + # since the returned output tensor is a cudagraph managed tensor with + # a static tensor address. + extra_output_names = [ + name + for name in partition_input_names + if name in name_to_node and name not in buffer_names_to_free + ] + + returned_output_names.update(extra_output_names) + + returned_output_names = OrderedSet( + self.mutation_real_name.get(name, name) + for name in returned_output_names + ) + + output_nodes = [ + name_to_node[name] + for name in returned_output_names + if not is_unallocated_buffer(name) + ] + + constant_names = [ + name for name in partition_input_names if name in V.graph.constants + ] + + symbol_inputs = self.get_graph_partition_symbol_inputs( + partition, input_nodes + ) + + partition_signature = GraphPartitionSignature( + symbol_inputs, + input_nodes, + output_nodes, + input_deallocation, + skip_cudagraph, + constant_names, + ) + + signatures.append(partition_signature) + + unmet_output_names = partition_input_names.union( + # pyrefly: ignore [unsupported-operation] + unmet_output_names - returned_output_names + ) + + return signatures[::-1] + + def clean_removed_buffer_from_partition_signatures( + self, signature: GraphPartitionSignature + ) -> GraphPartitionSignature: + """ + Updates the partition signature by removing buffers specified in + V.graph.removed_buffers. 
See [Note: Removed Graph Partition Arguments] + """ + input_nodes = { + name: buffer + for name, buffer in signature.input_nodes.items() + if name not in V.graph.removed_buffers + } + input_deallocation = { + name: val + for name, val in signature.input_deallocation.items() + if name not in V.graph.removed_buffers + } + output_nodes = [ + node + for node in signature.output_nodes + if node.maybe_get_name() not in V.graph.removed_buffers + ] + constant_names = [ + name + for name in signature.constant_names + if name not in V.graph.removed_buffers + ] + return GraphPartitionSignature( + signature.symbol_inputs, + input_nodes, + output_nodes, + input_deallocation, + signature.skip_cudagraph, + constant_names, + ) + + def reorder_for_minimizing_partition( + self, + nodes: list[BaseSchedulerNode], + ) -> list[BaseSchedulerNode]: + """ + Reorder nodes to minimize the number of partitions via a bfs + topological sort. This is the optimal reordering such that the + number of partitions cannot be reduced further. This may be + sub-optimal for other metrics such as peak memory. This does not + change relative orders of two cudagraphable nodes, nor the + relative order of two non_cudagraphable nodes. 
+ """ + import heapq + + node_to_indegree: dict[BaseSchedulerNode, int] = dict() + cudagraphable_nodes: list[tuple[int, BaseSchedulerNode]] = [] + non_cudagraphable_nodes: list[tuple[int, BaseSchedulerNode]] = [] + node_to_index = {node: idx for idx, node in enumerate(nodes)} + + def insert_pending_nodes(node: BaseSchedulerNode) -> None: + node_with_index = (node_to_index[node], node) + if self.should_partition(node): + heapq.heappush(non_cudagraphable_nodes, node_with_index) + else: + heapq.heappush(cudagraphable_nodes, node_with_index) + + def update_indegree(node: BaseSchedulerNode) -> None: + for succ_node in node.mpi_node.succ_nodes: + assert node_to_indegree[succ_node] > 0 + node_to_indegree[succ_node] -= 1 + if node_to_indegree[succ_node] == 0: + insert_pending_nodes(succ_node) + + for node in nodes: + node_to_indegree[node] = len(node.mpi_node.pred_nodes) + if node_to_indegree[node] == 0: + insert_pending_nodes(node) + + schedule: list[BaseSchedulerNode] = [] + num_iters: int = 0 + while num_iters < len(nodes) and ( + non_cudagraphable_nodes or cudagraphable_nodes + ): + while non_cudagraphable_nodes: + _, node = heapq.heappop(non_cudagraphable_nodes) + schedule.append(node) + update_indegree(node) + + while cudagraphable_nodes: + _, node = heapq.heappop(cudagraphable_nodes) + schedule.append(node) + update_indegree(node) + + num_iters += 1 + + if num_iters > len(nodes): + raise RuntimeError( + """ + Failed to schedule, while loop ran too long when + reordering for minimizing the num of partitions + """ + ) + + return schedule + + def maybe_reorder_for_minimizing_partition( + self, + nodes: list[BaseSchedulerNode], + ) -> list[BaseSchedulerNode]: + """ + Reorder nodes to minimize the number of partitions if this only slightly + increase peak memory. 
+ """ + from .memory import estimate_peak_memory, prepare_planning_info + + graph_outputs = OrderedSet(V.graph.get_output_names()) + + default_peak_memory, name_to_freeable_input_buf = prepare_planning_info( + nodes, + self.name_to_buf, + self.name_to_fused_node, + OrderedSet(V.graph.graph_inputs.keys()), + graph_outputs, + ) + + reordered_nodes = self.reorder_for_minimizing_partition(nodes) + reorder_peak_memory, _ = estimate_peak_memory( + reordered_nodes, name_to_freeable_input_buf, graph_outputs + ) + + # 1.1 here means 10% extra peak memory budget which is quite arbitrary + if reorder_peak_memory < default_peak_memory * 1.1: + return reordered_nodes + + return nodes + + def reorder_for_partition_with_simple_dependency( + self, nodes: list[BaseSchedulerNode] + ) -> list[BaseSchedulerNode]: + """ + Reorder a node if it should be partitioned and has simple dependency: + 1. move a partitioned node to the front if it has no dependency + 2. move a partitioned node to the back if it is only used by OutputNode + 3. otherwise do not reorder + """ + + front: list[BaseSchedulerNode] = [] + middle: list[BaseSchedulerNode] = [] + back: list[BaseSchedulerNode] = [] + + def only_output_user(node: BaseSchedulerNode) -> bool: + for buf in node.get_outputs(): + for use in buf.users: + if not isinstance(use.node, OutputNode): + return False + return True + + for node in nodes: + should_partition = self.should_partition(node) + if should_partition and len(node.unmet_dependencies) == 0: + front.append(node) + elif should_partition and only_output_user(node): + back.append(node) + else: + middle.append(node) + + return front + middle + back + + def graph_partition( + self, + ) -> tuple[list[PartitionType], list[GraphPartitionSignature]]: + """ + Given a list of BaseSchedulerNodes, split into a list of + graph partitions and compute partition input/output signatures. 
+ """ + partitions: list[PartitionType] = [] + skip_cudagraph = True + cur_partition: PartitionType = [] + skip_cudagraphs = [] + for node in self.nodes: + should_partition = self.should_partition(node, should_log=True) + if cur_partition and skip_cudagraph != should_partition: + partitions.append(cur_partition) + skip_cudagraphs.append(skip_cudagraph) + cur_partition = [] + + skip_cudagraph = should_partition + cur_partition.append(node) + + if cur_partition: + partitions.append(cur_partition) + skip_cudagraphs.append(skip_cudagraph) + + signatures = self.get_graph_partition_signature( + partitions=partitions, skip_cudagraphs=skip_cudagraphs + ) + self.compute_graph_partition_maps(signatures) + + return partitions, signatures + + def codegen(self) -> None: + with dynamo_timed("Scheduler.codegen"): + return ( + self._codegen_partitions() + if torch._inductor.config.graph_partition + else self._codegen(self.nodes) + ) + + def _codegen_partition_wrapper( + self, + partition: PartitionType, + signature: GraphPartitionSignature, + ) -> None: + """Codegen a partition given its inputs/outputs""" + from .codegen.wrapper import SubgraphPythonWrapperCodegen + + parent_wrapper_code = V.graph.wrapper_code + graph_partition_id = next(self._graph_partition_counter) + + with V.graph.set_current_wrapper_code(): + V.graph.init_wrapper_code( + is_subgraph=True, + subgraph_name=f"partition_{graph_partition_id}", + parent_wrapper_code=parent_wrapper_code, + partition_signatures=signature, + ) + self._codegen(partition) + + # Note: [Removed Graph Partition Arguments] + # Graph partition relies on node.read_writes to analyze the partition + # inputs and outputs. However, during codegen, we may decide some buffers + # are internal to a kernel (e.g., triton kernel) such that these buffers + # are never actually defined. This information is collected during codegen + # and recorded in V.graph.removed_buffers. 
So we cleanup signature and write + # prefix (i.e., generating call function and return outputs) after we have + # codegen the partition. + assert isinstance(V.graph.wrapper_code, SubgraphPythonWrapperCodegen) + signature = self.clean_removed_buffer_from_partition_signatures(signature) + V.graph.wrapper_code.partition_signatures = signature + V.graph.wrapper_code.write_prefix() + + graph_name = V.graph.name + partition_code, _ = V.graph.wrapper_code.generate(V.graph.is_inference) + + V.graph.wrapper_code.define_subgraph_launcher_fn(graph_name, partition_code) + + V.graph.wrapper_code.codegen_partition_call(graph_partition_id, signature) + V.graph.wrapper_code.allocated.update( # type: ignore[has-type] + [node.get_name() for node in signature.output_nodes] + ) + + def use_default_device_context( + self, partitions: list[PartitionType], signatures: list[GraphPartitionSignature] + ) -> contextlib.AbstractContextManager[None]: + @contextlib.contextmanager + def ctx() -> Iterator[None]: + self.update_graph_partition_default_device(partitions, signatures) + if self.default_device_context and device_need_guard( + self.default_device_context.type + ): + assert self.default_device_context.index is not None, ( + "device should have an index" + ) + V.graph.wrapper_code.codegen_device_guard_enter( + self.default_device_context.index + ) + + try: + yield + finally: + if self.default_device_context and device_need_guard( + self.default_device_context.type + ): + V.graph.wrapper_code.codegen_device_guard_exit() + self.default_device_context = None + + return ctx() + + def update_graph_partition_default_device( + self, partitions: list[PartitionType], signatures: list[GraphPartitionSignature] + ) -> None: + # Note: [Graph Partition Device Contexts] + # Entering a device context takes 60 microseconds and exiting a device + # context takes 20 microseconds. If all graph partitions and + # cudagraph-unsafe ops happen on the same device, we can share the + # device context. 
+ + if len(partitions) == 1 and not signatures[0].skip_cudagraph: + # If there is only 1 cudagraph partition, the device context + # should happen within the cudagraph partition, which + # would be removed by cudagraph. + return + + def get_cudagraph_partition_device(partition: PartitionType) -> torch.device: + partition_device = partition[0].get_device() + assert partition_device is not None + return partition_device + + def all_on_target_device( + partition: PartitionType, target_device: torch.device + ) -> bool: + for node in partition: + device = node.get_device() + if device != target_device: + return False + return True + + cudagraph_partition_device = None + for partition, signature in zip(partitions, signatures): + if not signature.skip_cudagraph: + cudagraph_partition_device = get_cudagraph_partition_device(partition) + break + + # all partitions skip cudagraph + if cudagraph_partition_device is None: + return + + for partition, signature in zip(partitions, signatures): + if signature.skip_cudagraph and not all_on_target_device( + partition, cudagraph_partition_device + ): + return + + self.default_device_context = cudagraph_partition_device + + def _codegen_partitions(self) -> None: + """ + Split nodes into partitions and codegen each partition into separate functions. + This allows further applying different optimizations (e.g., cudagraph) to + each function. 
+ """ + partitions, signatures = self.graph_partition() + + if len(partitions) > 1: + msg = f"cudagraph partition into {len(partitions)} partitions" + maybe_log_cudagraph_partition(msg=msg, prefix="") + counters["inductor"]["cudagraph_partitions"] += len(partitions) + + with self.use_default_device_context(partitions, signatures): + for partition, signature in zip(partitions, signatures): + assert len(partition) >= 1, ( + f"Each partition must have at least one node but found {len(partition)}" + ) + + if signature.skip_cudagraph: + self._codegen(partition) + else: + self._codegen_partition_wrapper(partition, signature) + + num_partitions = next(self._graph_partition_counter) + V.graph.wrapper_code.set_all_partition_names(num_partitions) + + # See [Note: Graph Partition Map for CUDAGraph] + if num_partitions > 0: + assert V.graph.partition_maps is not None + assert num_partitions == len(V.graph.partition_maps), ( + f"Expect {num_partitions} partition maps but got {len(V.graph.partition_maps)}" + ) + + def _codegen(self, nodes: list[BaseSchedulerNode]) -> None: + if config.check_stack_no_cycles_TESTING_ONLY: + import torch._dynamo.convert_frame + + stack = traceback.extract_stack() + seen: OrderedSet[tuple[str, int | None]] = OrderedSet() + for frame in reversed(stack): + # This is where maybe_cprofile is + if ( + frame.name == "_compile_inner" + and frame.filename == torch._dynamo.convert_frame.__file__ + ): + break + key = (frame.filename, frame.lineno) + assert key not in seen, ( + f"Duplicate stack frame {frame.filename}:{frame.lineno}; " + "did you add a decorator to one of the functions in this stack " + "trace? If so, try using a context manager instead." 
+ ) + seen.add(key) + + self.current_device = self.default_device_context + + # pyrefly: ignore [unbound-name] + if self.default_device_context and config.triton.autotune_at_compile_time: + V.graph.wrapper_code.write_get_raw_stream_header() + + for node in nodes: + if log.isEnabledFor(logging.DEBUG): + try: + log.debug( + "Generating code for node %s with estimated runtime %f", + node.get_name(), + node.get_estimated_runtime(), + ) + except Exception: + log.debug( + "Generating code for node %s with estimated runtime 0.0", + node.get_name(), + ) + + self.enter_context(node) + + if device := node.get_device(): + if ( + device != self.current_device + or node.is_extern() + or node.is_template() + ): + self.flush() + if device != self.current_device: + if self.current_device and device_need_guard( + self.current_device.type + ): + V.graph.wrapper_code.codegen_device_guard_exit() + self.current_device = device + if device_need_guard(device.type): + assert device.index is not None, "device should have an index" + V.graph.wrapper_code.codegen_device_guard_enter(device.index) + + self.current_node = node + self.buffer_names_to_free.update(node.last_usage) + + if node.is_template(): + prologue, template_node, epilogue = node.get_prologue_template_epilogue( + list(node.get_nodes()) + ) + # pyrefly: ignore [unbound-name] + self.get_backend(device).codegen_template( + template_node, epilogue, prologue + ) + elif node.is_extern(): + node = typing.cast(ExternKernelSchedulerNode, node) + self.codegen_extern_call(node) + elif node.is_foreach(): + node = typing.cast(ForeachKernelSchedulerNode, node) + # pyrefly: ignore [unbound-name] + backend_ = self.get_backend(device) + from .codegen.cuda_combined_scheduling import CUDACombinedScheduling + from .codegen.simd import SIMDScheduling + + if isinstance(backend_, (SIMDScheduling, CUDACombinedScheduling)): + backend = backend_ + else: + raise AssertionError(f"{type(self)=}") + backend.codegen_combo_kernel(node) + elif isinstance(node, 
FusedMixOrderReductions): + # pyrefly: ignore [unbound-name] + self.get_backend(device).codegen_mix_order_reduction(node) + elif isinstance(node, (FusedSchedulerNode, SchedulerNode)): + # pyrefly: ignore [unbound-name] + self.get_backend(device).codegen_node(node) + else: + assert isinstance(node, NopKernelSchedulerNode) + node.mark_run() + + # pyrefly: ignore [unbound-name] + if config.triton.debug_sync_kernel: + # pyrefly: ignore [unbound-name] + self.get_backend(device).codegen_sync() + + self.available_buffer_names.update(node.get_buffer_names()) + self.completed_operations.update(node.get_operation_names()) + + if not isinstance(node, NopKernelSchedulerNode): + device = node.get_device() + if ( + device is not None + and device.type != "meta" + and self.get_backend(device).ready_to_flush() + ): + self.flush() + + if self.current_device != self.default_device_context: + # when default_device_context is not None, we are codegen + # for graph partitions and all nodes must be on + # the same default device. + assert self.current_device is not None + if device_need_guard(self.current_device.type): + # exit the outermost CUDA device guard. this is + # important for nested indentation codegen-ing. + V.graph.wrapper_code.codegen_device_guard_exit() + + self.flush() + + def benchmark_combo_kernel( + self, node_list: Sequence[BaseSchedulerNode] + ) -> tuple[float, float, list[Optional[str]]]: + """ + Benchmark fused list of nodes and return the execution time + in milliseconds on randomly generated inputs. + """ + device = node_list[0].get_device() + V.graph.scheduler = self + self.current_device = device + assert device is not None + backend = self.get_backend(device) + return backend.benchmark_combo_kernel(node_list) + + def speedup_by_combo_kernel(self, nodes: list[BaseSchedulerNode]) -> bool: + """ + If config.benchmark_fusion is False, always return True. + Otherwise, return True if fusion can brings speedup. 
+ """ + + subkernel_nodes = nodes + device = subkernel_nodes[0].get_device() + + assert all(node.get_device() == device for node in subkernel_nodes), ( + "All nodes in a combo kernel group must be on the same device" + ) + + if not config.benchmark_combo_kernel: + return True + + from triton.compiler.errors import CompilationError + + ms1, path1_list = 0.0, [] + for i, snode in enumerate(subkernel_nodes): + node_list = snode.get_nodes() + # We can not accurately benchmark kernel using atomic_add + # due to how we generate random integer inputs. + if self._any_atomic_add(node_list): + fusion_log.debug( + "ComboKernel: benchmarking may not accurate due to atomic_add" + ) + + try: + ms, path = self.benchmark_fused_nodes(node_list) + if math.isinf(ms): + fusion_log.debug( + "ComboKernel benchmark: register spilling of %d-th subkernel", + i, + ) + return False + except CompilationError as e: + # workaround triton issue: https://github.com/triton-lang/triton/issues/2151 + if "Loop-carried variable" in str(e): + fusion_log.debug( + "ComboKernel benchmark: return True because of loop-carried variable" + ) + return True # allow fusion + else: + raise + ms1 += ms + path1_list.append(path) + + try: + ms2, ms2_clone, _path2_list = self.benchmark_combo_kernel(subkernel_nodes) + except CompilationError as e: + # workaround triton issue: https://github.com/triton-lang/triton/issues/2151 + if "Loop-carried variable" in str(e): + fusion_log.debug( + "ComboKernel benchmark: return True because of loop-carried variable" + ) + return True # allow fusion + else: + raise + + # small kernels are very likely to have speedup but hard to benchmark. So we skip benchmarking. 
+ small_kernel = ms2 - ms2_clone < 0.3 or ms1 < 0.3 + if fusion_log.isEnabledFor(logging.DEBUG): + if ms1 > ms2 or small_kernel: + fusion_log.debug( + "can fuse (benchmark): fusing causes %sx speedup", + green_text(f"{ms1 / ms2:.3f}"), + ) + else: + fusion_log.debug( + "cannot fuse (benchmark): fusing causes %sx slowdown", + red_text(f"{ms1 / ms2:.3f}"), + ) + # ms1 returned by benchmark_fused_nodes discounted clone time + return ms2 - ms2_clone < ms1 or small_kernel + + def get_buffer_layout(self, buf_name: str) -> ir.Layout: + buf = self.name_to_buf[buf_name] + assert buf.node is not None + return buf.node.get_layout() + + def update_zero_dim_cpu_tensor(self) -> None: + for node in self.nodes: + if node.is_gpu(): + for read in node.read_writes.reads: + buffer = V.graph.name_to_buffer.get(read.name) + if ( + buffer + and get_device_type(buffer) == "cpu" + and not isinstance( + buffer.layout, (NoneLayout, MultiOutputLayout) + ) + and buffer.get_size() == [] + ): + V.graph.zero_dim_cpu_tensor_list.add(read.name) + + +class BaseScheduling: # noqa: docstring_linter + def __init__(self, scheduler: Optional[Scheduler]): + super().__init__() + self.scheduler = scheduler + + def free_buffers_in_scheduler(self) -> None: + if self.scheduler: + self.scheduler.free_buffers() + + def get_backend_features(self, device: torch.device) -> OrderedSet[BackendFeature]: + """Return a set of .codegen.common.BackendFeature()""" + return OrderedSet() + + def can_fuse_vertical( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> bool: + """ + Check whether node1 and node2 can be vertically fused or not. + """ + raise NotImplementedError + + def can_fuse_horizontal( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> bool: + """ + Check whether node1 and node2 can be horizontally fused or not. 
+ """ + raise NotImplementedError + + def can_fuse_multi_outputs_template( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> bool: + """ + A Multi-Output Template (referenced in #144012) is a template node + with MultiOutputLayout, and its output buffers are instances of MultiOutput. + In this context, we verify whether node1 represents the Multi-Output Template + and node2 corresponds to one of its outputs. If so, we further check if + backend supports this fusion. + """ + return False + + def fuse( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> FusedSchedulerNode: + """ + Fuse two nodes + """ + if node1.is_foreach() or node2.is_foreach(): + return ForeachKernelSchedulerNode.fuse(node1, node2) + elif MixOrderReduction.are_mix_order_reductions(node1, node2): + return FusedMixOrderReductions(node1, node2) + elif isinstance(node1, FusedMixOrderReductions): + return node1.fuse_with(node2) + else: + return FusedSchedulerNode.fuse(node1, node2) + + def group_fn( + self, sizes: Sequence[Sequence[sympy.Expr]] + ) -> tuple[tuple[sympy.Expr, ...], ...]: + """ + Process the iteration sizes in case a transformation needs to be applied. + """ + raise NotImplementedError + + def codegen_template( + self, + template_node: BaseSchedulerNode, + epilogue_nodes: Sequence[BaseSchedulerNode], + prologue_nodes: Sequence[BaseSchedulerNode], + ) -> Optional[str]: + """ + Given a template node, generate a kernel. + + This function is only available for triton now. If the third-party backend behaves as a sub-class + of TritonScheduling, it can override it or reuse it. + """ + raise NotImplementedError + + def generate_kernel_code_from_nodes( + self, + nodes: Sequence[BaseSchedulerNode], + benchmark_kernel: bool, + hint_override: Optional[int] = None, + ) -> str: + """ + Generate a kernel given a list of pre-fused nodes. 
+ """ + raise NotImplementedError + + def codegen_node(self, node: Union[FusedSchedulerNode, SchedulerNode]) -> None: + """ + Generate a kernel given a list of pre-fused nodes. + """ + raise NotImplementedError + + def codegen_mix_order_reduction(self, node: FusedMixOrderReductions) -> None: + raise NotImplementedError + + def codegen_sync(self) -> None: + """ + Generate synchronization code for the kernel. This method depends on the hardware characteristics. + """ + raise NotImplementedError + + def ready_to_flush(self) -> bool: + """ + Check whether the backend is requesting the scheduler to flush the generated kernel. + If not supported, please return False. + """ + return False + + def flush(self) -> None: + """ + Flush the generated kernel and python wrapper code to the source code file. + """ + raise NotImplementedError + + def benchmark_fused_nodes( + self, nodes: Sequence[BaseSchedulerNode] + ) -> tuple[float, str]: + """ + Benchmark fused list of nodes and return the execution time + in milliseconds on randomly generated inputs. + """ + raise NotImplementedError + + def benchmark_codegened_module(self, module: ModuleType) -> tuple[float, str]: + """ + Benchmark a compiled module and return the execution time + in milliseconds on randomly generated inputs. + """ + raise NotImplementedError + + def get_fusion_pair_priority( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> int: + """ + Return an unsigned integer which represents the priority of this fusion pair. + The smaller is with higher priority. + """ + return 0 + + def benchmark_combo_kernel( + self, node_list: Sequence[BaseSchedulerNode] + ) -> tuple[float, float, list[Optional[str]]]: + """ + Benchmark the list of nodes to combine and return the execution time + and memory copy time in milliseconds on randomly generated inputs. 
+ """ + raise NotImplementedError + + def codegen_comment( + self, + node_schedule: Sequence[BaseSchedulerNode], + kernel_name: Optional[str] = None, + ) -> None: + if kernel_name: + from torch._inductor.debug import set_kernel_post_grad_provenance_tracing + + debug_handle = set_kernel_post_grad_provenance_tracing( + node_schedule, # type: ignore[arg-type] + kernel_name, + ) + V.graph.wrapper_code.write_provenance_debug_handle( + kernel_name, debug_handle + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/script.ld b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/script.ld new file mode 100644 index 0000000000000000000000000000000000000000..5a052e984fcd720526201aa93d6d13b0aba2107a --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/script.ld @@ -0,0 +1,8 @@ +SECTIONS { + /* By default, in LLD 16, .lrodata is placed immediately after .rodata. + * However, .lrodata can be very large in our compiled models, which leads to + * relocation out-of-range errors for relative relocations. So we place it + * after other the sections that are referenced from .text using relative + * relocations. This is the default behavior in GNU ld. 
*/ + .lrodata : { *(.lrodata) } + } INSERT AFTER .bss; diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/select_algorithm.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/select_algorithm.py new file mode 100644 index 0000000000000000000000000000000000000000..a54fb2263ec8387d3c9c9cb8cdf58bc82e89184c --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/select_algorithm.py @@ -0,0 +1,4557 @@ +# mypy: allow-untyped-defs +import contextlib +import dataclasses +import functools +import hashlib +import inspect +import itertools +import json +import logging +import math +import operator +import os +import re +import sys +import textwrap +import time +from collections.abc import Callable, Sequence +from concurrent.futures import as_completed, ThreadPoolExecutor +from io import StringIO +from pathlib import Path +from types import ModuleType +from typing import Any, NamedTuple, Optional, TYPE_CHECKING, Union +from typing_extensions import Self +from unittest.mock import patch + +import sympy + +import torch +import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools +from torch._dynamo.device_interface import get_interface_for_device +from torch._dynamo.testing import rand_strided +from torch._dynamo.utils import ( + counters, + dynamo_timed, + get_chromium_event_logger, + identity, + preserve_rng_state, +) +from torch._inductor.await_utils import await_sync +from torch._inductor.utils import clear_on_fresh_cache +from torch.utils._filelock import FileLock +from torch.utils._ordered_set import OrderedSet + +from ..utils._sympy.functions import CeilDiv +from . 
import config, ir +from .autotune_process import ( + TensorMeta, + TritonBenchmarkRequest, + TritonCPUBenchmarkRequest, + TritonGPUBenchmarkRequest, +) +from .codecache import code_hash, PersistentCache, PyCodeCache +from .codegen.common import ( + CSEVariable, + IndentedBuffer, + KernelTemplate, + OpOverrides, + WorkspaceArg, + WorkspaceZeroMode, +) +from .codegen.simd_kernel_features import SIMDKernelFeatures +from .codegen.subgraph import SubgraphChoiceCaller +from .codegen.triton import ( + gen_common_triton_imports, + texpr, + TMACompatibilityChecker, + TritonKernel, + TritonScheduling, +) +from .codegen.triton_utils import config_of, equal_1_arg_indices, signature_to_meta +from .codegen.wrapper import pexpr +from .exc import CUDACompileError +from .fx_utils import count_flops_fx +from .ir import ChoiceCaller, PrimitiveInfoType +from .ops_handler import StoreMode +from .runtime.hints import DeviceProperties +from .runtime.triton_compat import HAS_WARP_SPEC +from .runtime.triton_heuristics import FixedGrid +from .utils import ( + ceildiv, + do_bench_using_profiling, + FakeIndentedBuffer, + get_dtype_size, + is_gpu, + Placeholder, + restore_stdout_stderr, + sympy_dot, + sympy_index_symbol, + sympy_product, + triton_type, + triton_type_to_torch, + unique, +) +from .virtualized import V + + +log = logging.getLogger(__name__) + +# correctness checks struggle with fp16/tf32 +VERIFY: dict[str, Any] = {} +PRINT_AUTOTUNE = True +DEBUG = False + + +if TYPE_CHECKING: + import concurrent + + from torch._inductor.autotune_process import BenchmarkRequest + from torch._inductor.codegen.simd import IterationRangesEntry, IterationRangesRoot + + from .codegen.common import CSE + + +class KernelNamespace: + pass + + +# these objects are imported from the generated wrapper code +extern_kernels = KernelNamespace() + + +@dataclasses.dataclass +class BenchmarkTensors: + """Represents a set of inputs and outputs for autotuning with a template""" + + input_tensors: list[torch.Tensor] 
+ output_tensor: Optional[torch.Tensor] + + def unpack(self): + return self.input_tensors, self.output_tensor + + +@dataclasses.dataclass +class AutotuneArgs: + """During autotuning, we need to pass the same inputs to all choices. + Note: + Since we typically have a mix of external choices and triton choices, we create + two lists of inputs for the same underlying buffers: + - External inputs (for aten kernels): Include offset for sliced tensors + - Triton inputs: Use base pointer for sliced tensors, without offset + """ + + triton: BenchmarkTensors + extern: BenchmarkTensors + expected: Optional[torch.Tensor] = None + + def get_benchmark_tensors(self, extern=False) -> BenchmarkTensors: + """Returns the inputs and output tensors for a given choice.""" + bench_tensors = self.extern if extern else self.triton + return bench_tensors + + @classmethod + def from_choice_args( + cls, + example_inputs: list[torch.Tensor], + example_inputs_extern: list[torch.Tensor], + out: torch.Tensor, + out_extern: torch.Tensor, + expected: Optional[torch.Tensor] = None, + ) -> Self: + """Factory method to create AutotuneInputs from separate inputs/outputs""" + return cls( + triton=BenchmarkTensors(example_inputs, out), + extern=BenchmarkTensors(example_inputs_extern, out_extern), + expected=expected, + ) + + def verify(self, **kwargs): + """Verify the correctness of the benchmarking results""" + + torch.testing.assert_close(self.extern.output_tensor, self.expected, **kwargs) + + +class PartialRender: + """ + Some parts of a template need to be generated at the end, but + inserted into the template at the start. This allows doing a bunch + of replacements after the initial render. 
+ """ + + HookFn = Callable[[], str] + + def __init__( + self, code: str, replacement_hooks: dict[str, Optional[HookFn]] + ) -> None: + super().__init__() + self._code: str = code + self.replacement_hooks: dict[str, Optional[PartialRender.HookFn]] = ( + replacement_hooks + ) + + @property + def code(self) -> str: + """ + The fully rendered code. Will **error** if any hooks have yet to be + finalized. + """ + remaining_active_hooks = [ + key for key, fn in self.replacement_hooks.items() if fn is not None + ] + assert len(remaining_active_hooks) == 0, ( + f"The following hooks have not yet been finalized:\n {remaining_active_hooks=}" + ) + return self._code + + def finalize_hook(self, hook_key: str, strict: bool = True) -> None: + """ + Finalize a hook by name. + + :param strict: If ``True``, raise an error if the hook wasn't found. + + NOTE: Will **error** if the hook has already been finalized. + """ + if hook_key not in self.replacement_hooks: + if strict: + raise RuntimeError( + f"{hook_key} not registered in self.replacement_hooks" + ) + else: + return + + hook = self.replacement_hooks[hook_key] + assert hook is not None, f"Hook key {hook_key} can only be called once" + self._code = self._code.replace(hook_key, hook()) + + self.replacement_hooks[hook_key] = None + + def finalize_remaining(self) -> str: + """ + Finalize the remaining active hooks. This function can be used in cases + where the caller uses `finalize_hook` rather than `finalize_all`. + Note: `finalize_all` errors if a hook that has already been finalized + is attempted to be called again. This function only attempts to + finalize active hooks. + """ + for key, fn in self.replacement_hooks.items(): + if fn is not None: + self.finalize_hook(key) + return self.code + + def finalize_all(self) -> str: + """ + Finalize all active hooks. + + NOTE: unlike ``finalize_remaining``, this method will **error** if any + hook has already been finalized. 
+ """ + for key in self.replacement_hooks: + self.finalize_hook(key) + return self.code + + +# This is used to store info needed for lowering each subgraph in triton +# templates + + +@dataclasses.dataclass() +class SubgraphInfo: + body: IndentedBuffer + template_mask: Optional[str] = None + template_out_shape: Optional[Union[str, tuple[str]]] = None + compute: IndentedBuffer = dataclasses.field(default_factory=IndentedBuffer) + indexing_code: IndentedBuffer = dataclasses.field(default_factory=IndentedBuffer) + loads: IndentedBuffer = dataclasses.field(default_factory=IndentedBuffer) + stores: IndentedBuffer = dataclasses.field(default_factory=IndentedBuffer) + ops_handler: Optional[V.WrapperHandler] = None # type: ignore[name-defined] + cse: Optional["CSE[Any]"] = None + + # only copied over if not None + range_trees: Optional[list["IterationRangesRoot"]] = None + range_tree_nodes: Optional[dict[sympy.Symbol, "IterationRangesEntry"]] = None + numels: Optional[dict[str, sympy.Expr]] = None + + def __post_init__(self): + self.only_copy_if_non_none_fields = ( + "range_trees", + "range_tree_nodes", + "numels", + "cse", + ) + + def to_dict(self): + return { + field.name: getattr(self, field.name) for field in dataclasses.fields(self) + } + + +class ModificationWrapper(V.WrapperHandler): # type: ignore[name-defined] + """Handles placeholder substitutions during subgraph processing.""" + + def __init__( + self, + kernel, + subgraph_number: int, + fixed_inputs: dict[str, Any], + mask: Optional[str], + ): + super().__init__(V.ops) + self.name = f"PlaceholderSubstitution_{subgraph_number}" + self.kernel = kernel + self.fixed_inputs = fixed_inputs + self.mask = mask + + def load(self, name: str, index: sympy.Expr): + """Handle loading from tensor or fixed input.""" + if name not in self.fixed_inputs: + index_str = self._process_indexing(index) + var = self._add_kernel_input(name) + buffer = V.graph.get_buffer(name) + var_dtype = buffer.dtype + line = f"tl.load({var} + 
{index_str})" + + if ( + var_dtype in (torch.float16, torch.bfloat16) + and config.triton.codegen_upcast_to_fp32 + ): + line += ".to(tl.float32)" + var_dtype = torch.float32 + + out = self.kernel.cse.generate( + self.kernel.compute, line, dtype=var_dtype, shape=() + ) + return out + + return self.kernel.cse.generate( + self.kernel.compute, + f"({self.fixed_inputs[name]})", + dtype=torch.float32, + shape=(), + ) + + def indirect_indexing(self, index_var: str, size, check, wrap_neg=True): + """Convert index variable to symbolic form.""" + return sympy_index_symbol(str(index_var)) + + # pyrefly: ignore [bad-override] + def store( + self, name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None + ) -> str: + """Currently only supports stores for atomic adds coming from scatter nodes + This is used by flex_attention's backwards grad for captured buffers, see + zeros_and_scatter lowering + """ + assert self.mask is not None, ( + "Mask is required for inner stores in modifications" + ) + assert mode == "atomic_add", "Only atomic_add is supported for inner stores" + + buf_name = self._add_kernel_input(name) + index_str = self._process_indexing(index) + index_str = f"tl.broadcast_to({index_str}, {value}.shape)" + store = f"tl.atomic_add({buf_name} + {index_str}, {value}, {self.mask}, sem='relaxed')" + return store + + def _add_kernel_input(self, name: str): + """Add name as input to kernel and return input ref.""" + return self.kernel.args.input(name) + + def _process_indexing(self, index): + """Process and rename indexing, adding symbols as kernel inputs.""" + return self.kernel.kexpr(self.kernel.rename_indexing(index)) + + +# Function name, followed by args and kwargs. +RecordedEventsType = list[tuple[str, list[Any], dict[str, Any]]] + + +class TritonTemplateKernel(TritonKernel): + """ + A specialized kernel class for Triton templates that handles code generation + for templated Triton kernels. 
+ + This class extends TritonKernel to provide additional functionality for + template-based kernel generation, including support for subgraphs, workspace + arguments, and prologue/epilogue fusion. + """ + + def __init__( + self, + kernel_name, + input_nodes: tuple[ir.IRNode, ...], + output_node, + defines, + num_stages, + num_warps, + grid_fn, + meta, + call_sizes, + num_consumer_groups=0, + num_buffers_warp_spec=0, + use_jit=False, + tma_store=False, + transpose_discontiguous_tensor_descriptors_override=None, + prefix_args=0, + suffix_args=0, + epilogue_fn=identity, + subgraphs: Optional[list[ir.ComputedBuffer]] = None, + workspace_arg: Optional[WorkspaceArg] = None, + prologue_loads_all_inputs=False, + hint_override: Optional[int] = None, + ) -> None: + if tma_store: + pass + numel = sympy_product(output_node.get_size()) + if tma_store: + assert len(output_node.get_size()) == 2, ( + "TMA store only supported for 2D with templates" + ) + tiling = { + "x": output_node.get_size()[0], + "y": output_node.get_size()[1], + "r0_": sympy.S.One, + } + else: + tiling = { + "x": numel, + "r0_": sympy.S.One, + } + super().__init__( + tiling, + features=SIMDKernelFeatures([], numel), + hint_override=hint_override, + ) + if tma_store: + # By default `construct_range_trees` will return the range_trees in the order + # ["z", "y", "x", "r0_", "r1_"] (see simd.py:all_prefixes) + # and this order defines what the kernel block shape will be. So if the template + # input / output has requested e.g. ["x", "y"], `construct_range_trees` will still return the + # trees in the order ["y", "x"]. This would mean that the template would need to transpose + # the loaded value. 
+ # The below sorts the range trees according to that required by the caller + prefix_to_range_tree = {rt.prefix: rt for rt in self.range_trees} + pw_sorted_range_trees = [] + reduction_idx = None + for i, prefix in enumerate(tiling): + rt = prefix_to_range_tree[prefix] + # pyrefly: ignore # missing-argument + if rt.is_reduction: + reduction_idx = i + break + rt.index = i + rt.grid_dim = i + rt.tensor_dim = i + pw_sorted_range_trees.append(rt) + self.range_trees = pw_sorted_range_trees + self.range_trees[reduction_idx:] + + self.input_nodes = input_nodes + self.output_node = output_node + self.named_input_nodes = {} # type: ignore[var-annotated] + self.defines = defines + self.kernel_name = kernel_name + self.use_jit = use_jit + self.tma_store = tma_store + self.transpose_discontiguous_tensor_descriptors_override = ( + transpose_discontiguous_tensor_descriptors_override + ) + self.num_stages = num_stages + self.num_warps = num_warps + self.num_consumer_groups = num_consumer_groups + self.num_buffers_warp_spec = num_buffers_warp_spec + self.grid_fn = grid_fn + self.meta = meta + self.call_sizes = call_sizes + # for templates with fixed epilogues + self.prefix_args = prefix_args + self.suffix_args = suffix_args + # pyrefly: ignore [invalid-type-var] + self.epilogue_fn = epilogue_fn + self.render_hooks = {} # type: ignore[var-annotated] + self.triton_meta: Optional[dict[str, object]] = None + # For Templated Attention this can be a list of ir.Subgraph + self.subgraphs: Optional[list[ir.ComputedBuffer]] = subgraphs + + # Some templates use extra global memory as a workspace + self.workspace_arg = workspace_arg + if workspace_arg is not None: + self.args.workspace_args.append(workspace_arg) + + # The following attributes (body, template_mask, output_val) are all + # used for triton kernel codegen. 
+ # They are swapped onto the TritonTemplateKernel object by + # `set_subgraph_body` + self.subgraph_bodies: dict[str, SubgraphInfo] = {} + + # input buffers which we are allowed to prologue fuse into + self.prologue_supported_inputs: OrderedSet[str] = OrderedSet() + + # input buffers which we are fusing into + self.prologue_fused_inputs: OrderedSet[str] = OrderedSet() + # input buffers which we are fusing into, which preserve a zero mask + self.prologue_fused_inputs_preserve_zero: OrderedSet[str] = OrderedSet() + + # The following attributes are all used for triton kernel codegen. + # They are swapped onto the TritonTemplateKernel object by + # `set_subgraph_body` + # NB: the names here must match the fields in SubgraphInfo + self.body: IndentedBuffer = FakeIndentedBuffer() + self.compute: IndentedBuffer = FakeIndentedBuffer() + self.indexing_code: IndentedBuffer = FakeIndentedBuffer() + self.loads: IndentedBuffer = FakeIndentedBuffer() + self.stores: IndentedBuffer = FakeIndentedBuffer() + self.template_mask: Optional[str] = None + self.template_out_shape: Optional[Union[str, tuple[str]]] = None + self.ops_handler: Optional[V.WrapperHandler] = None # type: ignore[name-defined] + + # When caching is enabled, the generated code is not dependent on the input nodes names, or + # symbolic sizes names. + # However, some of the variables returned by generate_and_load that are computed during the + # triton template expansions (code generation) are dependent on those. + # In order to cache the code generation and avoid redoing it for similar inputs that varies only by + # input names or symbol names, we do a record and replay method. + # During template expansions we record all function calls that change input_dependent_preserved_state + # and replay them on a cache hit to regenerate them. + self.cached_replay_events: Optional[RecordedEventsType] = None + + # Update each time an input is marked frozen, used to replay the freezing of inputs on a cache hit. 
+ self.frozen_layouts_cnt = 0 + + # When prologue_loads_all_inputs is true, prologue_supported_inputs is populated during def_kernel + # by adding all inputs. + self.prologue_loads_all_inputs = prologue_loads_all_inputs + + # Extra functions to be exposed during partial template rendering. + self.extra_template_env_fns: list[Callable[..., Any]] = [] + + # Tracking for intermediate variables + self.tmp_var_ctr = itertools.count() + + def _gen_tmp_var(self) -> str: + return f"_tmp_var{next(self.tmp_var_ctr)}" + + def input_dependent_preserved_state(self) -> str: + # Not adding self.args.output_buffers on purpose. But we do not need to reproduce it on a cache hit. + # (never accessed). + return repr( + [ + self.args.input_buffers, + self.args.sizevars, + self.args.workspace_args, + self.prologue_supported_inputs, + self.frozen_layouts_cnt, + ] + ) + + def record_input_dependent_tracked_event(self) -> Callable[..., Any]: + def decorator(fn) -> Callable[..., Any]: + def wrapper(*args, **kwargs) -> Any: + pre_state = self.input_dependent_preserved_state() + result = fn(*args, **kwargs) + post_state = self.input_dependent_preserved_state() + if pre_state != post_state: + assert self.cached_replay_events is not None + self.cached_replay_events.append((fn.__name__, [*args], {**kwargs})) + return result + + return wrapper + + return decorator + + def replay_cached_events(self, events: RecordedEventsType) -> None: + for f, args, kwargs in events: + getattr(self, f)(*args, **kwargs) + + @contextlib.contextmanager + def set_subgraph_body(self, body_name: str): + assert all( + hasattr(self, field.name) for field in dataclasses.fields(SubgraphInfo) + ) + old_state = { + key.name: getattr(self, key.name) + for key in dataclasses.fields(SubgraphInfo) + } + + assert body_name in self.subgraph_bodies, body_name + + subgraph = self.subgraph_bodies[body_name] + for key, value in subgraph.to_dict().items(): + if value is None and key in subgraph.only_copy_if_non_none_fields: + continue 
+ setattr(self, key, value) + + context = ( + contextlib.nullcontext + if not self.ops_handler + # pyrefly: ignore [not-callable] + else lambda: V.set_ops_handler(self.ops_handler(V.get_ops_handler())) + ) + with context(): # type: ignore[operator] + yield + self.subgraph_bodies[body_name] = SubgraphInfo( + **{ + key.name: getattr(self, key.name) + for key in dataclasses.fields(SubgraphInfo) + } + ) + for key, value in old_state.items(): + setattr(self, key, value) + + @contextlib.contextmanager + def create_subgraph_body(self, body_name: str, clear_cse: bool = False): + assert body_name not in self.subgraph_bodies + self.subgraph_bodies[body_name] = SubgraphInfo( + IndentedBuffer(), None, None, cse=self.cse.clone() if clear_cse else None + ) + with self.set_subgraph_body(body_name): + yield + + def need_numel_args(self): + return False + + def estimate_kernel_num_bytes(self): + """ + Estimate the total number of bytes this kernel takes. + For in/out nodes, sizes are counted twice: once for reading and + once for writing. 
+ """ + ninplace_args = len(unique(self.args.inplace_buffers.values())) + num_bytes = [] + for i, inp in enumerate(itertools.chain(self.input_nodes, (self.output_node,))): + size = V.graph.sizevars.size_hints(inp.get_size(), fallback=0) + numel = functools.reduce(operator.mul, size, 1) + dtype_size = get_dtype_size(inp.get_dtype()) + num_bytes.append(numel * dtype_size * (1 + int(i < ninplace_args))) + return sum(num_bytes) + + def estimate_flops(self) -> int: + for node in self.input_nodes: + for fx_node in node._current_origins: + f = count_flops_fx(fx_node) + if f is not None: + return V.graph.sizevars.size_hint(f, fallback=0) + return 0 + + def jit_lines(self): + if self.use_jit: + return "@triton.jit" + + argdefs, _, signature, _ = self.args.python_argdefs() + triton_meta: dict[str, Any] = { + "signature": signature_to_meta( + signature, + size_dtype=self.index_dtype, + argdefs=argdefs, + is_template=True, + ), + "device": DeviceProperties.create(self.output_node.get_device()), + "constants": {}, + } + triton_meta["configs"] = [config_of(signature)] + for arg_num in equal_1_arg_indices(signature): # type: ignore[index] + triton_meta["constants"][signature[arg_num].name] = 1 # type: ignore[index,union-attr] + matrix_instr_nonkdim = self.meta.get("matrix_instr_nonkdim", None) + waves_per_eu = self.meta.get("waves_per_eu", None) + kpack = self.meta.get("kpack", None) + if matrix_instr_nonkdim: + triton_meta["matrix_instr_nonkdim"] = matrix_instr_nonkdim + if waves_per_eu: + triton_meta["waves_per_eu"] = waves_per_eu + if kpack: + triton_meta["kpack"] = kpack + + self.triton_meta = triton_meta + + inductor_meta = { + "kernel_name": str(Placeholder.DESCRIPTIVE_NAME), + **self.inductor_meta_common(), + **FixedGrid.setup_grid_as_args(), + } + if config.profile_bandwidth or config.benchmark_kernel: + num_gb = self.estimate_kernel_num_bytes() / 1e9 + inductor_meta["kernel_num_gb"] = num_gb + if config.benchmark_kernel: + flops = self.estimate_flops() + 
inductor_meta["kernel_flop"] = flops + + inductor_meta["config_args"] = self.meta + + template_args = f""" + num_stages={self.num_stages}, + num_warps={self.num_warps}, + triton_meta={triton_meta!r}, + inductor_meta={inductor_meta!r}, + """ + + if HAS_WARP_SPEC: + template_args += f""" + num_consumer_groups={self.num_consumer_groups}, + num_buffers_warp_spec={self.num_buffers_warp_spec}, + """ + + return f""" + @triton_heuristics.template( + {template_args} + ) + @triton.jit + """ + + def gen_argdefs(self): + def hook(): + # python_argdefs() cannot be run until after the rest of the template lazily adds more args + arg_defs, *_ = self.args.python_argdefs() + return f"{', '.join(x.full_name() for x in arg_defs)}" + + return self._register_hook("", hook, allow_overwriting=True) + + def gen_defines(self): + return self.defines + + def def_kernel(self, *argnames): + """ + Hook called from template code to generate function def and + needed args. + """ + assert all(isinstance(x, str) for x in argnames) + renames = IndentedBuffer(initial_indent=1) + + named_args = self.input_nodes[ + self.prefix_args : len(self.input_nodes) - self.suffix_args + ] + + assert len(argnames) == len(named_args), ( + len(argnames), + len(named_args), + self.prefix_args, + len(self.input_nodes), + ) + + for input_node in self.input_nodes[: self.prefix_args]: + # get args in correct order + self.args.input(input_node.get_name()) + + for name, input_node in zip(argnames, named_args): + arg_name = f"arg_{name}" + self.named_input_nodes[name] = input_node + if input_node.get_name() in V.graph.removed_buffers: + continue + if input_node.get_name() in self.prologue_fused_inputs: + continue + + self.args.input_buffers[input_node.get_name()] = arg_name + + # The args may be duplicated, so renaming must be after args are de-duplicated. 
+ for name in argnames: + input_node = self.named_input_nodes[name] + if self.prologue_loads_all_inputs: + self.prologue_supported_inputs.add(input_node.get_name()) + if input_node.get_name() in V.graph.removed_buffers: + continue + if input_node.get_name() in self.prologue_fused_inputs: + continue + + arg_name = self.args.input_buffers[input_node.get_name()] + if input_node.get_layout().offset == 0: + renames.writeline(f"{name} = {arg_name}") + else: + offset = texpr(self.rename_indexing(input_node.get_layout().offset)) + renames.writeline(f"{name} = {arg_name} + {offset}") + + for input_node in self.input_nodes[len(self.input_nodes) - self.suffix_args :]: + # get args in correct order + if input_node.get_name() in V.graph.removed_buffers: + continue + if input_node.get_name() in self.prologue_fused_inputs: + continue + + self.args.input(input_node.get_name()) + + def hook(): + # python_argdefs() cannot be run until after the rest of the template lazily adds more args + arg_defs, *_ = self.args.python_argdefs() + code = IndentedBuffer() + code.splice(gen_common_triton_imports()) + code.splice(self.jit_lines()) + code.writeline( + f"def {self.kernel_name}({', '.join(x.full_name() for x in arg_defs)}):" + ) + with code.indent(): + code.splice(self.defines) + code.splice(renames.getvalue()) + self.codegen_prologue(code) + return code.getvalue() + + return self._register_hook("", hook) + + def size(self, name: Optional[str], index: int): + """ + Hook called from template code to get the size of an arg. + Will add needed args to pass it in if it is dynamic. + """ + assert isinstance(index, int) + if name is None: + val = self.output_node.get_size()[index] + else: + assert isinstance(name, str) + val = self.named_input_nodes[name].get_size()[index] + return texpr(self.rename_indexing(val)) + + def stride(self, name, index=None): + """ + Hook called from template code to get the stride of an arg. + Will add needed args to pass it in if it is dynamic. 
+ """ + if name is None: + val = self.output_node.get_stride() + else: + assert isinstance(name, str) + val = self.get_stride_and_maybe_freeze_layout(self.named_input_nodes[name]) + + if isinstance(index, int): + return texpr(self.rename_indexing(val[index])) + return ", ".join([texpr(self.rename_indexing(i)) for i in val]) + + def _get_subgraph(self, subgraph_number: int): + assert isinstance(subgraph_number, int) + assert isinstance(self.subgraphs, list) + assert subgraph_number < len(self.subgraphs), ( + f"Invalid subgraph number provided to create_modification, {subgraph_number} must be < {len(self.subgraphs)}" + ) + assert self.body.getvalue() == "", ( + "Body should be clear before adding a modification" + ) + return self.subgraphs[subgraph_number] + + def _handle_scatter_graph(self, scatter_graph): + """Handle processing for a single scatter graph. + + Args: + scatter_graph: The scatter graph to process + """ + assert isinstance(scatter_graph, ir.ComputedBuffer), ( + f"scatter_graph must be an instance of ComputeBuffer but got {type(scatter_graph)}" + ) + + def contiguous_strides(x): + # We always create a fresh contiguous grad for scattering into + return sum( + x_i * stride for x_i, stride in zip(x, scatter_graph.get_stride()) + ) + + return scatter_graph.data.store_output( # type: ignore[attr-defined] + scatter_graph.name, contiguous_strides, [] + ) + + def modification( + self, + subgraph_number: int, + output_name: Optional[str], + mask: Optional[str] = None, + **fixed_inputs, + ) -> str: + """This creates a modification function for a subgraph. + To use this inside a template, the first argument should specify which subgraph to codegen for + + Args: + subgraph_number (int): The index of the subgraph in self.subgraphs + output_name (Optional[str]): The name of the output variable to store the result in + mask (Optional[str]): An optional mask to use for the store operation. If provided, this mask + will be applied to the store. 
+ """ + num = 0 + out = None + scatters = [] + while f"mod_{subgraph_number}_{num}" in self.subgraph_bodies: + num += 1 + with self.create_subgraph_body(f"mod_{subgraph_number}_{num}"): + subgraph = self._get_subgraph(subgraph_number) + modification_handler = ModificationWrapper( + self, subgraph_number, fixed_inputs, mask + ) + with V.set_ops_handler(modification_handler): + assert isinstance(subgraph, (ir.ComputedBuffer, list)), ( + f"Expected the subgraph to be a ComputedBuffer or a List[ComputedBuffer], got {type(subgraph)}" + ) + # Handle scatter stores + if isinstance(subgraph, list): + for scatter_graph in subgraph: + scatters.append(self._handle_scatter_graph(scatter_graph)) + elif isinstance(subgraph.data, ir.InputBuffer): + out = subgraph.data.make_loader()(()) + else: + out = subgraph.data.inner_fn(()) + + self.codegen_body() + if output_name is not None: + assert isinstance(output_name, str) + assert out is not None + self.body.writeline(f"{output_name} = {out.value}") + else: + assert out is None + for scatter in scatters: + self.body.writeline(str(scatter)) + + body_val = self.body.getvalue() + self.cse.invalidate(OrderedSet()) + return body_val + + def load_input( + self, + input_name: str, + output_name: str, + indices: Union[list[Any], tuple[Any]], + mask: Optional[str] = None, + other: Optional[Union[float, int]] = 0.0, + indent_width: int = 4, + index_shape: Optional[tuple[str]] = None, + ): + """Loads an input and applies any necessary preprocessing or masking. + + Args: + input_name (str): The name of the input to load. + indices (Union[List, Tuple]): The index for each dimension of the input. + val (str): The name of the variable to store the loaded value. + mask (Optional[str]): An optional mask to use for the load operation. + other (Optional[Union[float, int]]): The value to use for masked elements. Default is 0.0. + indent_width (int): The number of spaces to use for indentation. 
+ """ + + input_node = self.named_input_nodes[input_name] + if not self.prologue_loads_all_inputs: + self.prologue_supported_inputs.add(input_node.get_name()) + + tilings = (sympy_product(input_node.get_size()), sympy.Integer(1)) + groups = { + "x": tilings[0], + "r0_": tilings[1], + } + + range_trees = self.construct_range_trees( + pid_cache=None, + inside_reduction=False, + is_reduction=False, + numels=groups, + no_x_dim=False, + ) + load_code = None + + with self.create_subgraph_body(f""): + assert isinstance(indices, (list, tuple)) + assert isinstance(output_name, str) + assert isinstance(mask, (str, type(None))) + self.range_trees = range_trees + self.numels = {k: V.graph.sizevars.simplify(v) for k, v in groups.items()} + indices = list(map(OpOverrides.paren, indices)) + index_symbols = [sympy.Symbol(x, integer=True) for x in indices] + + lengths = [V.graph.sizevars.simplify(s) for s in input_node.get_size()] + assert len(indices) == len(lengths) + + index_symbols = [sympy.Symbol(x, integer=True) for x in indices] + assert len(indices) == len(lengths) + + # glue to make generated code use same indexing from template + + # TODO (from reviewers as well) + # in codegen_template, + # prologue_node.codegen(kernel.split_and_set_ranges(prologue_node.get_ranges())) + # the ranges need to reflect the group of the prologue input or it will error + # not sure if there is any difference between original range_tree_entry in + # and new one from correct lengths/groups... 
both actually seem to work + for name, range_tree_entry in zip( + indices, self.range_trees[0].construct_entries(lengths) + ): + range_tree_entry.set_name(name) + contiguous_index = sympy_dot( + ir.FlexibleLayout.contiguous_strides(lengths), index_symbols + ) + contiguous_index = self.rename_indexing(contiguous_index) + self.body.writeline("xindex = " + texpr(contiguous_index)) + + xindex_range_root = self.range_trees[0].lookup( + sympy.Integer(1), sympy_product(lengths) + ) + xindex_range_root.set_name("xindex") + + # Note - ["None" override_mask] + # MM Templates work by taking out of bounds index values and wrapping them around to 0 + # so that no mask is required on the load: offs_a_m = `rm % M` + # We should to override the mask to be "None" instead of inheriting the mask that would + # have been loaded otherwise. + # We are using "None" for clarity in output code, but + # we could alternatively emit `xmask = tl.full([xindex.shape], True, tl.int1)` + self.template_mask = mask if mask is not None else "None" + self.template_out_shape = index_shape if index_shape else "xindex" + self.template_indices = indices + self.cse.invalidate(OrderedSet()) + + template_mask = self.template_mask + + class StoreOutputSubstitution(V.WrapperHandler): # type: ignore[name-defined] + name = "StoreOutputSubstitution" + + def store( + self, + name: str, + index: sympy.Expr, + value: "CSEVariable", + mode: "StoreMode" = None, + ): + V.kernel.store_buffer_names.add(name) + V.kernel.cse.store_cache[name] = value + if name in V.kernel.prologue_fused_inputs: + # We load masked out values with 0, then apply a prologue. + # The masked out values may not necessariliy be 0 any more + # so we need to reapply the mask. 
+ value_dtype = value.dtype + value_str = str(value) + if template_mask != "None" and ( + name not in V.kernel.prologue_fused_inputs_preserve_zero + or other != 0 + ): + value_str = ( + f"tl.where({template_mask}, {value_str}, {other})" + ) + + if value_dtype != V.graph.get_buffer(name).dtype: + value_str = f"{value_str}.to({triton_type(V.graph.get_buffer(name).dtype)})" + + # TODO: we should have intermediary var shapes + V.kernel.compute.writeline( + f"{output_name} = {value_str}.broadcast_to(xindex.shape)" + ) + + # pyrefly: ignore [bad-assignment] + self.ops_handler = StoreOutputSubstitution + + input_node = self.named_input_nodes[input_name] + output_index = input_node.make_indexer()(index_symbols) + + # in def_kernel above we define the inputs with the storage offset adjusted + # creating the load in input_node.make_indexer() will also adjust by storage offset + # so subtract here to not double increment + if not V.graph.sizevars.statically_known_equals( + input_node.layout.offset, 0 + ): + output_index = output_index - self.rename_indexing( + input_node.get_layout().offset + ) + + output_index = self.rename_indexing(output_index) + + if output_index == contiguous_index: + output_index_str = "xindex" + else: + out_indexing = self.indexing( + output_index, + copy_shape=self.template_out_shape, + override_mask=self.template_mask, + ) + from .codegen.triton import IndexingOptions + + assert isinstance(out_indexing, IndexingOptions) + output_index_str = ( + f"({out_indexing.index_str}).broadcast_to(xindex.shape)" + ) + + # Generate load code + load_code = f"{output_name} = tl.load({input_name} + ({output_index_str})" + + if mask: + load_code += f", mask={mask}, other={other})" + else: + load_code += ")" + + hook_key = f"" + + def hook(): + with self.set_subgraph_body(hook_key): + self.cse.invalidate(OrderedSet()) + self.codegen_body() + self.cse.invalidate(OrderedSet()) + if input_node.get_name() not in self.prologue_fused_inputs: + assert load_code is not None + 
self.body.writeline(load_code) + + return textwrap.indent(self.body.getvalue(), " " * indent_width).strip() + + return self._register_hook(hook_key, hook) + + def _generate_index_from_tma_index( + self, + output_name: str, + offset_name: str, + tma_index: sympy.Symbol, + block_size: str, + dim: int, + num_dims: int, + block_name: Optional[str] = None, + ) -> list[str]: + """ + Generate the logic to compute the regular tl.load index from the provided + tma index. This is used to ensure variables can support fusions. + + Args: + output_name (str): The output variable name. + offset_name (str): The name used for the intermediate offset. + tma_index (sympy.Symbol): The symbol used for the original TMA index. + block_size (str): The block size of the index. + dim (int): Which dimension to project the index in. + num_dims (int): The total number of dimensions in the output. + block_name (Optional[str]): The name of the block variable. If not passed + in then we aren't reusing standard symbol names. + + Returns: + list[str]: The lines used to generate the index. + + """ + if block_name: + # Generate the expected names for the structure: + # XBLOCK/YBLOCK and xoffset/yoffset. We append XBLOCK/YBLOCK + # to the top of the kernel so we can safely extract the tensor + # descriptor construction to the top of the kernel. 
+ if block_name in self.prologue_cache: + assert self.prologue_cache[block_name] == block_size, ( + f"Constant {block_name} must be used for all stores" + ) + else: + self.prologue_cache[block_name] = block_size + self.prologue.writeline(f"{block_name}: tl.constexpr = {block_size}") + else: + block_name = block_size + line0 = f"{offset_name} = {texpr(tma_index)}" + expr = f"({offset_name} + tl.arange(0, {block_name}))" + prefix_none = "".join(["None, "] * dim) + suffix_none = ", ".join(["None"] * (num_dims - (dim + 1))) + line1 = f"{output_name} = {expr}[{prefix_none}:, {suffix_none}]" + return [line0, line1] + + def _generated_mask_for_tma( + self, + index_name: str, + shape_val: str, + output_name: str, + ) -> str: + """ + Generate the mask logic to feed to fusions for mask. The expectation + is that if we have X/Y there will be a variable named xmask and ymask. + + Args: + index_name (str): The index used in the mask. Should be one of + xindex or yindex. + shape_val (str): The expression for the upper bound shape. + output_name (str): The expression used for the output. + + Returns: + str: The mask generation line. + """ + return f"{output_name} = {index_name} < {shape_val}" + + def store_output( + self, + indices: Union[list[Any], tuple[Any]], + val: str, + mask: Optional[str] = None, + indent_width: int = 4, + val_shape: Optional[tuple[str]] = None, + block_indexing: bool = False, + ): + """Stores the final output and appends any epilogue fusions if the buffer hasn't been optimized away. + + Args: + indices (Union[List, Tuple]): The index for each dimension of the output. The dot product of + these indices and output strides must match `val`. + val (str): The value to store. + mask (Optional[str]): An optional mask to use for the store operation. If provided, this mask + will be applied to the store. + indent_width (int): The number of spaces to use for indentation. This is used when the call to + store_output is indented in the kernel definition. 
+ block_indexing (bool): Are the input indices presented as offsets for creating the block (e.g. + inputs to TMA) or are they tensors that should be passed in directly. + """ + subgraph_name = self._get_store_output_subgraph_name( + next(self.store_output_ctr) + ) + with self.create_subgraph_body(subgraph_name, clear_cse=True): + assert isinstance(indices, (list, tuple)) + assert isinstance(val, str) + assert isinstance(mask, (str, type(None))) + assert isinstance(val_shape, (tuple, type(None))) + assert isinstance(block_indexing, bool) + assert self.template_mask is None + indices = list(map(OpOverrides.paren, indices)) + index_symbols = [sympy.Symbol(x, integer=True) for x in indices] + lengths = [ + V.graph.sizevars.simplify(s) for s in self.output_node.get_size() + ] + assert len(indices) == len(lengths) + + output_layout = self.output_node.get_layout() + self.template_out = val + if block_indexing: + assert val_shape, "Blocking indexing requires passing in val_shape" + assert len(val_shape) == 2, ( + "Blocking indexing only supports 2D data at this time" + ) + assert not mask, "Mask is not supported with blocking indexing" + intermediate_lines: list[str] = [] + epilogue_index_symbols: list[sympy.Symbol] = [] + if self.tma_store: + val_shape_copy = list(val_shape) + for i, range_tree in enumerate(self.range_trees[:-1]): + name = range_tree.name + symbol = range_tree.symbol() + epilogue_index_symbols.append(symbol) + lookup_output = range_tree.lookup(sympy.S.One, lengths[i]) + old_name = lookup_output.symbol() + lookup_output.set_name(name) + # Update var_list and var_range + range_tree.var_list[range_tree.var_list.index(old_name)] = ( + symbol + ) + range_val = range_tree.var_ranges[old_name] + del range_tree.var_ranges[old_name] + range_tree.var_ranges[symbol] = range_val + intermediate_lines.extend( + self._generate_index_from_tma_index( + name, + "xoffset" if name == "xindex" else "yoffset", + index_symbols[i], + val_shape[i], + i, + len(val_shape), + # 
pyrefly: ignore [missing-argument] + block_name=range_tree.symt.name, + ) + ) + # Generate the xmask and ymask + intermediate_lines.append( + self._generated_mask_for_tma( + name, + self.size(None, i), + "xmask" if name == "xindex" else "ymask", + ) + ) + # Update the val_shape information to use consistent naming + # after the remapping. + # pyrefly: ignore [missing-argument] + val_shape_copy[i] = range_tree.symt.name + val_shape = tuple(val_shape_copy) + else: + mask_vars: list[str] = [] + for i, (index, shape) in enumerate(zip(index_symbols, val_shape)): + index_name = self._gen_tmp_var() + offset_name = self._gen_tmp_var() + intermediate_lines.extend( + self._generate_index_from_tma_index( + index_name, + offset_name, + index, + shape, + i, + len(index_symbols), + ) + ) + epilogue_index_symbols.append( + sympy.Symbol(index_name, integer=True) + ) + mask_name = self._gen_tmp_var() + intermediate_lines.append( + self._generated_mask_for_tma( + index_name, + self.size(None, i), + mask_name, + ) + ) + mask_vars.append(mask_name) + final_mask_var = self._gen_tmp_var() + final_mask_rhs = " & ".join( + f"{mask_name}" for mask_name in mask_vars + ) + intermediate_lines.append(f"{final_mask_var} = {final_mask_rhs}") + self.template_mask = final_mask_var + index_symbols = epilogue_index_symbols + contiguous_index = sympy_dot(output_layout.stride, index_symbols) + if not self.tma_store: + # Convert to just use xindex. 
+ contiguous_index = self.rename_indexing(contiguous_index) + intermediate_lines.append(f"xindex = {texpr(contiguous_index)}") + self.range_trees[0].lookup( + sympy.S.One, sympy_product(lengths) + ).set_name("xindex") + index_symbols = epilogue_index_symbols + output_index = contiguous_index + # Write out the intermediate lines + for line in intermediate_lines: + self.body.writeline(line) + else: + assert not self.tma_store, "TMA store requires block indexing" + # glue to make generated code use same indexing from template + for name, range_tree_entry in zip( + indices, self.range_trees[0].construct_entries(lengths) + ): + range_tree_entry.set_name(name) + contiguous_index = sympy_dot( + ir.FlexibleLayout.contiguous_strides(lengths), index_symbols + ) + contiguous_index = self.rename_indexing(contiguous_index) + self.body.writeline("xindex = " + texpr(contiguous_index)) + self.range_trees[0].lookup( + sympy.S.One, sympy_product(lengths) + ).set_name("xindex") + self.template_mask = mask + self.template_indices = indices + output_index = self.output_node.get_layout().make_indexer()( + index_symbols + ) + output_index = self.rename_indexing(output_index) + if output_index == contiguous_index: + output_index = sympy.Symbol("xindex", integer=True) + + # pyrefly: ignore [bad-assignment] + self.template_out_shape = val_shape if val_shape else val + acc_dtype = ( + triton_type_to_torch(self.meta["ACC_TYPE"]) + if "ACC_TYPE" in self.meta + else torch.float32 + ) + epilogue_args = [ + V.kernel.cse.namedvar(val, dtype=acc_dtype, shape=val_shape) + ] + for input_node in itertools.chain( + self.input_nodes[: self.prefix_args], + self.input_nodes[len(self.input_nodes) - self.suffix_args :], + ): + input_node.freeze_layout() + epilogue_arg = V.kernel.cse.generate( + self.compute, + input_node.make_loader()(index_symbols), + dtype=acc_dtype, + shape=input_node.get_size(), + ) + epilogue_args.append(epilogue_arg) + # We update frozen_layouts_cnt in order to replay this function on 
a cache hit. + self.frozen_layouts_cnt += 1 + + V.ops.store( + self.output_node.get_name(), + output_index, + self.epilogue_fn(*epilogue_args), + mode="tma" if self.tma_store else None, + ) + self.codegen_body() + + def hook(): + with self.set_subgraph_body(subgraph_name): + # more stuff might have been added since the codegen_body above + self.codegen_body() + self.cse.invalidate(OrderedSet()) + + return textwrap.indent(self.body.getvalue(), " " * indent_width).strip() + + return self._register_hook(subgraph_name, hook) + + def _register_hook( + self, + hook_name: str, + hook_fn: PartialRender.HookFn, + *, + allow_overwriting: bool = False, + ) -> str: + """ + Register a hook function with a name. + + ``hook_name`` should match the string that will be replaced via + ``hook_fn``, and should not already be in use for a hook. + + If ``allow_overwriting`` is ``False``, will assert that there isn't + currently a registered hook of the same name before registering the new + one. + """ + + if not allow_overwriting: + assert hook_name not in self.render_hooks, ( + f"Tried to register the hook {hook_name} multiple times. If " + "desired, pass allow_overwriting=True to _register_hook" + ) + self.render_hooks[hook_name] = hook_fn + return hook_name + + def _register_extra_template_env_fns(self, *fns: Callable[..., Any]): + """ + Register some extra functions to expose when performing the initial + template render, so that they're in scope to by used by jinja + expressions. + + These can be used to, for example, implement extra replacement hooks, + if the given function: + + * Returns the name of their hook, which should also be the string to + replace via the hook function. The convention is to use the format + . + * Assigns the corresponding entry in ``self.render_hooks`` to a hook + function. 
+ """ + self.extra_template_env_fns.extend(fns) + + def render(self, template, kwargs, record_input_dependent_tracked_event=False): + if record_input_dependent_tracked_event: + self.cached_replay_events = [] + + template_env = { + fn.__name__: ( + self.record_input_dependent_tracked_event()(fn) + if record_input_dependent_tracked_event + else fn + ) + for fn in [ + self.def_kernel, + self.size, + self.stride, + self.store_output, + self.load_input, + self.make_load, + self.modification, + self.gen_argdefs, + self.gen_defines, + *self.extra_template_env_fns, + ] + } + return PartialRender( + template.render(**template_env, **kwargs), + self.render_hooks, + ) + + def make_load(self, name, indices, mask): + """ + Optional helper called from template code to generate the code + needed to load from an tensor. + """ + assert isinstance(indices, (list, tuple)) + assert isinstance(name, str) + assert isinstance(mask, str) + stride = self.get_stride_and_maybe_freeze_layout(self.named_input_nodes[name]) + indices = list(map(OpOverrides.paren, indices)) + assert len(indices) == len(stride) + index = " + ".join( + f"{texpr(self.rename_indexing(s))} * {i}" for s, i in zip(stride, indices) + ) + return f"tl.load({name} + ({index}), {mask}, other=0.0)" + + def indexing( + self, + index: sympy.Expr, + *, + dense_indexing=False, + copy_shape=None, + override_mask=None, + block_ptr=False, + tma_compatibility_checker: Optional[TMACompatibilityChecker] = None, + ): + """ + Override the default indexing to use our custom mask and force + dense indexing. 
+ """ + return super().indexing( + index, + dense_indexing=False, + # We pass template_out as the shape to broadcast the indexing to as + # the mask might be broadcast to the output shape + copy_shape=self.template_out_shape, + override_mask=self.template_mask, + block_ptr=block_ptr, + tma_compatibility_checker=tma_compatibility_checker, + ) + + def codegen_range_tree(self): + pass # ignore default codegen + + def additional_call_args_and_types(self): + if isinstance(self.grid_fn, SymbolicGridFn): + grid_args = self.grid_fn.sympy_call(*self.call_sizes, self.meta) + assert len(grid_args) in (0, 3), "grid_fn should return 3 values" + return (grid_args, map(type, grid_args)) + elif all(isinstance(x, (int, sympy.Integer)) for x in self.call_sizes): + grid_args = self.grid_fn(*map(int, self.call_sizes), self.meta) + assert len(grid_args) in (0, 3), "grid_fn should return 3 values" + return (grid_args, map(type, grid_args)) + return ((), ()) + + def call_kernel( + self, name: str, node: Optional[ir.IRNode] = None, deallocate_ws: bool = True + ): + wrapper = V.graph.wrapper_code + _, call_args, _, arg_types = self.args.python_argdefs() + + additional_call_args, additional_arg_types = ( + self.additional_call_args_and_types() + ) + + if not additional_call_args: + assert not V.graph.cpp_wrapper, "cpp_wrapper requires SymbolicGridFn" + wrapper.add_import_once(f"import {self.grid_fn.__module__}") + meta = wrapper.add_meta_once(self.meta) + fn_name = f"{self.grid_fn.__module__}.{self.grid_fn.__name__}" + call_args.append( + f"*{fn_name}({', '.join(map(pexpr, self.call_sizes))}, {meta})" + ) + arg_types.append(None) + + call_args.extend(additional_call_args) + arg_types.extend(additional_arg_types) + + if self.workspace_arg is not None: + wrapper.generate_workspace_allocation(self.workspace_arg) + wrapper.generate_kernel_call( + name, + call_args, + arg_types=arg_types, + triton_meta=self.triton_meta, + triton=True, + ) + if self.workspace_arg is not None: + 
wrapper.generate_workspace_deallocation(self.workspace_arg) + + def kernel_benchmark_extra_args(self) -> list[str]: + return [ + str(x) + for x in self.grid_fn( + *V.graph.sizevars.size_hints(self.call_sizes), self.meta + ) + ] + + def get_stride_and_maybe_freeze_layout(self, node) -> list[int]: + node.data.freeze_layout() + return node.get_stride() + + +@functools.cache +def _jinja2_env(): + try: + import jinja2 + + return jinja2.Environment( + undefined=jinja2.StrictUndefined, + ) + except ImportError: + return None + + +class GenerateAndLoadResult(NamedTuple): + """ + Return type of TritonTemplate.generate_and_load. + """ + + mod: ModuleType + extra: str + input_call_args: tuple[str, ...] + prologue_supported_inputs: OrderedSet[str] + kernel_args_sizevars_keys: tuple[sympy.Expr, ...] + kernel_options: dict[str, Any] + + +class GeneratedCodeCacheEntry(NamedTuple): + code: str + extra: str + events: list[Any] + + +class GeneratedCodeCache: + """ + Cache for generated code. The cache key is a string representation of the input nodes, + number of stages, number of warps, and call sizes. The cache value is a tuple of the + generated code, extra code, and events. 
+ """ + + def __init__(self, *args, **kwargs): + self._cache: dict[str, GeneratedCodeCacheEntry] = {} + + def cache_clear(self) -> None: + self._cache.clear() + + def __repr__(self): + return repr(self._cache) + + def make_key( + self, + input_nodes: tuple[ir.IRNode, ...], + num_stages: int, + num_warps: int, + call_sizes: Sequence[sympy.core.symbol.Symbol], + prefix_args: int, + suffix_args: int, + epilogue_fn: Optional[Callable[..., Any]], + epilogue_fn_hash: Optional[str], + tma_store: bool, + transpose_discontiguous_tensor_descriptors_override: Optional[bool], + subgraphs: Optional[list[ir.Buffer]], # has to be none to cache + workspace_arg: Optional[WorkspaceArg], # has to be none to cache + layout: ir.Layout, + num_consumer_groups: int, + num_buffers_warp_spec: int, + kwargs: dict[str, Any], + hint_override: Optional[int] = None, + ) -> Optional[str]: + def layout_key(layout: ir.Layout) -> str: + assert not isinstance(layout, ir.FlexibleLayout) + return repr( + [ + layout.size, + layout.stride, + layout.dtype, + layout.device, + layout.offset, + ] + ) + + def has_flexible_layout() -> bool: + if isinstance(layout, ir.FlexibleLayout): + return True + + for input in input_nodes: + if isinstance(input.get_layout(), ir.FlexibleLayout): + return True + return False + + if epilogue_fn is identity: + assert epilogue_fn_hash is None + epilogue_fn_hash = "identity" + + # we do not cache under those conditions right now. 
+ if ( + has_flexible_layout() + or subgraphs is not None + or workspace_arg is not None + or epilogue_fn_hash is None + ): + return None + + return repr( + { + "input_nodes": [ + layout_key(input.get_layout()) for input in input_nodes + ], + "num_stages": num_stages, + "num_warps": num_warps, + "prefix_args": prefix_args, + "suffix_args": suffix_args, + "call_sizes": call_sizes, + "layout": layout_key(layout), + "num_consumer_groups": num_consumer_groups, + "num_buffers_warp_spec": num_buffers_warp_spec, + "epilogue_fn_hash": epilogue_fn_hash, + "tma_store": tma_store, + "transpose_discontiguous_tensor_descriptors_override": transpose_discontiguous_tensor_descriptors_override, + "kwargs": kwargs, + "hint_override": hint_override, + } + ) + + def get_entry(self, cache_key: Optional[str]) -> Optional[GeneratedCodeCacheEntry]: + if cache_key is None: + return None + + entry = self._cache.get(cache_key, None) + if entry is None: + torch._dynamo.utils.counters["inductor"]["generated_module_cache_miss"] += 1 + else: + torch._dynamo.utils.counters["inductor"]["generated_module_cache_hit"] += 1 + return entry + + def put_entry( + self, + cache_key: Optional[str], + code: str, + extra: str, + events: list[Any], + ) -> None: + if cache_key is None: + return + entry = GeneratedCodeCacheEntry(code, extra, events) + self._cache.update({cache_key: entry}) + + +class TritonTemplate(KernelTemplate): + """ + A Triton template is a template that can be used to generate a Triton kernel. 
+ """ + + # Allow subclasses to override the kernel type + kernel_type: type[Any] = TritonTemplateKernel + index_counter = itertools.count() + all_templates: dict[str, "TritonTemplate"] = {} + + def __init__( + self, + name: str, + grid: Any, + source: str, + debug=False, + cache_codegen_enabled_for_template=False, + prologue_loads_all_inputs=False, + ) -> None: + super().__init__(name, hash=hashlib.sha256(source.encode("utf-8")).hexdigest()) + self.grid = grid + self.template = self._template_from_string(source) + assert name not in self.all_templates, "duplicate template name" + TritonTemplate.all_templates[name] = self + self.debug = debug + self._cache_codegen_enabled_for_template = cache_codegen_enabled_for_template + self._generated_code_cache: GeneratedCodeCache = GeneratedCodeCache() + clear_on_fresh_cache(self._generated_code_cache) + # When prologue_loads_all_inputs is true, prologue_supported_inputs is populated during def_kernel + # by adding all inputs. + self.prologue_loads_all_inputs = prologue_loads_all_inputs + + # When this flag is on, we ensure that the cached results and the generated result if cache + # was not used are the same. + test_cache = False + + @property + def uid(self) -> str: + # unique by prefixing with triton + return f"triton::{self.name}" + + def maybe_append_choice( + self, choices: list[Any], **kwargs: Any + ) -> Optional[NotImplementedError]: + """ + Maybe generates a new ChoiceCaller and appends it into existing choices. + Returns None if success, otherwise returns the error. + + choices: A list of ChoiceCallers. + kwargs: Additional kwargs to be passed to self.generate() to generate a new ChoiceCaller. + """ + + try: + choice = self.generate(generate_with_caching=True, **kwargs) + if choice is not None: + choices.append(choice) + return None + except NotImplementedError as e: + log.info( # noqa: G200 + "Cannot Append Choice: %s. 
KernelTemplate type is %s", + e, + type(self), + stack_info=log.getEffectiveLevel() < logging.INFO, + ) + return e + + # NOTE: MAKE SURE THAT ANY ARGUMENT ADDED TO THIS FUNCTION IS PROPERLY HANDLED IN _generated_code_cache.make_key. + def generate_and_load( + self, + input_nodes: tuple[ir.IRNode, ...], + num_stages: int, + num_warps: int, + call_sizes: Sequence[sympy.core.symbol.Symbol], + prefix_args: int, + suffix_args: int, + epilogue_fn: Optional[Callable[..., Any]], + epilogue_fn_hash: Optional[str], + subgraphs: Optional[list[ir.Buffer]], + workspace_arg: Optional[WorkspaceArg], + num_consumer_groups: int, + num_buffers_warp_spec: int, + layout: ir.Layout, + kwargs: dict[str, Any], + generate_with_caching, + hint_override: Optional[int] = None, + tma_store: bool = False, + transpose_discontiguous_tensor_descriptors_override: Optional[bool] = None, + ) -> Optional[GenerateAndLoadResult]: + """Generate the python code and load it into the current process""" + caching_enabled = ( + generate_with_caching + and torch._inductor.config.enable_caching_generated_triton_templates + ) + + cache_key = None + if caching_enabled: + cache_key = self._generated_code_cache.make_key( + input_nodes, + num_stages, + num_warps, + call_sizes, + prefix_args, + suffix_args, + epilogue_fn, + epilogue_fn_hash, + tma_store, + transpose_discontiguous_tensor_descriptors_override, + subgraphs, + workspace_arg, + layout, + num_consumer_groups, + num_buffers_warp_spec, + kwargs, + ) + + assert self.template, "requires jinja2" + defines = StringIO() + + for name, val in kwargs.items(): + defines.write(f"{name} : tl.constexpr = {val}\n") + + fake_out = ir.Buffer(name="buf_out", layout=layout) + kernel_name = f"triton_{self.name}" + + numel = sympy_product(layout.size) + buffers = itertools.chain(input_nodes, (fake_out,)) + + if TritonScheduling.can_use_32bit_indexing(numel, buffers): + index_dtype = "tl.int32" + else: + index_dtype = "tl.int64" + + # Add index dtype to defines so it's 
available in the template + defines.write(f"INDEX_DTYPE : tl.constexpr = {index_dtype}\n") + defines = defines.getvalue() + + kernel_options = { + "input_nodes": input_nodes, + "defines": defines, + "num_stages": num_stages, + "num_warps": num_warps, + "grid_fn": self.grid, + "meta": kwargs, + "call_sizes": call_sizes, + "prefix_args": prefix_args, + "suffix_args": suffix_args, + "epilogue_fn": epilogue_fn, + "subgraphs": subgraphs, + "prologue_loads_all_inputs": self.prologue_loads_all_inputs, + } + + if HAS_WARP_SPEC: + kernel_options.update( + { + "num_consumer_groups": num_consumer_groups, + "num_buffers_warp_spec": num_buffers_warp_spec, + } + ) + + def make_kernel(): + return self.kernel_type( + kernel_name=kernel_name, + output_node=fake_out, + workspace_arg=workspace_arg, + use_jit=False, + hint_override=hint_override, + tma_store=tma_store, + transpose_discontiguous_tensor_descriptors_override=transpose_discontiguous_tensor_descriptors_override, + **kernel_options, + ) + + def generate_code(kernel) -> Optional[tuple[str, str]]: + def make_extra() -> str: + extra_parts = [ + f"{kwarg}={repr(kwargs[kwarg])}" for kwarg in sorted(kwargs.keys()) + ] + + extra_parts.extend( + [ + f"num_stages={num_stages}", + f"num_warps={num_warps}", + ] + ) + if HAS_WARP_SPEC: + extra_parts.extend( + [ + f"num_consumer_groups={num_consumer_groups}", + f"num_buffers_warp_spec={num_buffers_warp_spec}", + ] + ) + extra = "-".join(extra_parts) + "-" + return extra + + try: + template = kernel.render(self.template, kwargs, caching_enabled) + code = template.finalize_all() + except ZeroDivisionError: + # TODO(nmacchioni): fix sympy division by zero + return None + if self.debug: + print("Generated Code:\n", code) + + extra = make_extra() + return code, extra + + def maybe_test_cache(code: str, extra: str, kernel): + if self.test_cache or self.debug: + with ( + patch.object(V.graph, "get_dtype", self._fake_get_dtype(fake_out)), + V.graph.set_current_device(layout.device), + 
make_kernel() as kernel_test, + ): + result2 = generate_code(kernel_test) + assert result2 is not None + code_test, extra_test = result2 + assert ( + code == code_test + and extra == extra_test + and kernel.args.input_buffers == kernel_test.args.input_buffers + and kernel.prologue_supported_inputs + == kernel_test.prologue_supported_inputs + and kernel.args.sizevars == kernel_test.args.sizevars + ), "Generated code cache results in wrong output" + + # Generate code, extra. + code: Optional[str] = None + extra: Optional[str] = None + with ( + patch.object(V.graph, "get_dtype", self._fake_get_dtype(fake_out)), + V.graph.set_current_device(layout.device), + make_kernel() as kernel, + ): + cache_entry = self._generated_code_cache.get_entry(cache_key) + cache_hit = False + + if cache_entry is not None: + code, extra, events = cache_entry + kernel.replay_cached_events(events) + cache_hit = True + + else: + result = generate_code(kernel) + if result is None: # happens at ZeroDivisionError: + return None + code, extra = result + self._generated_code_cache.put_entry( + cache_key, code, extra, kernel.cached_replay_events + ) + + assert code is not None and extra is not None + + mod = PyCodeCache.load(code, extra) + + input_call_args = tuple(kernel.args.input_buffers.keys()) + prologue_supported_inputs = kernel.prologue_supported_inputs.copy() + kernel_args_sizevars_keys = tuple(kernel.args.sizevars.keys()) + + if cache_hit: + maybe_test_cache(code, extra, kernel) + + return GenerateAndLoadResult( + mod, + extra, + input_call_args, + prologue_supported_inputs, + kernel_args_sizevars_keys, + kernel_options, + ) + + def generate( # type: ignore[override] + self, + input_nodes: tuple[ir.IRNode, ...], + layout: ir.Layout, + num_stages: int, + num_warps: int, + num_consumer_groups: int = 0, + num_buffers_warp_spec: int = 0, + prefix_args: int = 0, + suffix_args: int = 0, + epilogue_fn: Optional[Callable[..., Any]] = identity, + epilogue_fn_hash: Optional[str] = None, + subgraphs: 
Optional[list[ir.Buffer]] = None, + mutated_inputs: Optional[list[ir.IRNode]] = None, + call_sizes: Optional[Sequence[sympy.core.symbol.Symbol]] = None, + workspace_arg: Optional[WorkspaceArg] = None, + generate_with_caching=False, + hint_override: Optional[int] = None, + tma_store: bool = False, + transpose_discontiguous_tensor_descriptors_override: Optional[bool] = None, + **kwargs, + ): + """This function generates a TritonTemplateCaller + + Args: + input_nodes: List of input nodes + layout: Output layout + num_stages: Number of stages for triton launch + num_warps: Number of warps for triton launch + prefix_args: Number of input nodes to be passed as arguments + suffix_args: Number of input nodes to be passed as arguments + epilogue_fn: Optional epilogue function to be called on the output + subgraphs: Optional subgraphs to be passed as arguments, these will be inlined + into the triton template string + mutated_inputs: Optional list of input nodes that are mutated by the kernel, this is helpful + if you need to return multiple outputs. You can pass them as inputs and mark them as + being mutated by the kernel. + """ + # HACK: Triton currently breaks if TF32 floats are requested, but the CUDA + # capability doesn't support them. This is a bug in Triton, but for now we'll + # patch around it here. See https://github.com/triton-lang/triton/issues/3011 + # for one example issue with this problem. 
+ if torch.cuda.is_available() and not torch.cuda.is_tf32_supported(): + kwargs["ALLOW_TF32"] = "False" + + if call_sizes is None: + call_sizes = layout.size + + result = self.generate_and_load( + input_nodes, + num_stages, + num_warps, + call_sizes, + prefix_args, + suffix_args, + epilogue_fn, + epilogue_fn_hash, + subgraphs, + workspace_arg, + num_consumer_groups, + num_buffers_warp_spec, + layout, + kwargs, + generate_with_caching and self._cache_codegen_enabled_for_template, + hint_override=hint_override, + tma_store=tma_store, + transpose_discontiguous_tensor_descriptors_override=transpose_discontiguous_tensor_descriptors_override, + ) + + # May happen as result of dev by 0. + if result is None: + return None + + # We expect the input_buffer order to be [*input_nodes, *captured_buffers] + expected_input_args = tuple(unique(x.get_name() for x in input_nodes)) + assert ( + result.input_call_args[: len(expected_input_args)] == expected_input_args + ), ( + result.input_call_args, + expected_input_args, + ) + + # `kernel_input_nodes` are the actual inputs that will be passed to the kernel, + # so e.g. views of the same input are not included. `codegen_input_nodes` + # includes views of inputs to preserve the kernel semantics. 
The shape and + # strides of `codegen_input_nodes` will be used to infer read/writes in + # TemplateBuffer.extract_read_writes + kernel_input_nodes = tuple( + [V.graph.get_buffer(k) for k in result.input_call_args] + ) + # Here we have (*input_nodes, *captured_buffers) + codegen_input_nodes = ( + tuple(input_nodes) + kernel_input_nodes[len(expected_input_args) :] + ) + extra_args = V.graph.sizevars.size_hints( + map(sympy.expand, result.kernel_args_sizevars_keys), + fallback=config.unbacked_symint_fallback, + hint_override=hint_override, + ) + + kernel_hash_name = f"triton_{self.name}_{next(self.index_counter)}" + + workspace_args = [] + if workspace_arg is not None: + # Create workspace tensor + workspace_size = workspace_arg.count + workspace_tensor = torch.empty_strided( + (workspace_size,), + (1,), + dtype=torch.uint8, + device=layout.device.type, + ) + + # Handle zero initialization if needed + if workspace_arg.zero_mode != WorkspaceZeroMode.UNINITIALIZED: + workspace_tensor.zero_() + + workspace_args.append(workspace_tensor) + + options = result.kernel_options + + def make_kernel_render(out_node, hint_override: Optional[int] = None): + assert result is not None + kernel = self.kernel_type( + kernel_name=str(Placeholder.KERNEL_NAME), + output_node=out_node, + workspace_arg=workspace_arg, + use_jit=False, + hint_override=hint_override, + tma_store=tma_store, + transpose_discontiguous_tensor_descriptors_override=transpose_discontiguous_tensor_descriptors_override, + **options, + ) + render = functools.partial( + kernel.render, + self.template, + kwargs, + ) + return kernel, render + + # create the BenchmarkRequest + assert result.mod.__file__ is not None + grid = self.grid( + *V.graph.sizevars.size_hints( + call_sizes, + fallback=config.unbacked_symint_fallback, + hint_override=hint_override, + ), + kwargs, + ) + bmreq_cls: type[TritonBenchmarkRequest] + if layout.device.type == "cpu": + bmreq_cls = TritonCPUBenchmarkRequest + else: + bmreq_cls = 
TritonGPUBenchmarkRequest + bmreq = bmreq_cls( + module_path=result.mod.__file__, + module_cache_key=result.mod.key, + kernel_name=f"triton_{self.name}", + extra_args=[*extra_args, *workspace_args, *grid], + num_stages=num_stages, + num_warps=num_warps, + num_consumer_groups=num_consumer_groups, + num_buffers_warp_spec=num_buffers_warp_spec, + matrix_instr_nonkdim=kwargs.get("matrix_instr_nonkdim", 0), + waves_per_eu=kwargs.get("waves_per_eu", 0), + kpack=kwargs.get("kpack", 2), + input_tensor_meta=TensorMeta.from_irnodes(kernel_input_nodes), # type: ignore[arg-type] + output_tensor_meta=TensorMeta.from_irnodes(layout), + ) + + return TritonTemplateCaller( + kernel_hash_name, + codegen_input_nodes, + layout, + make_kernel_render, + result.extra.strip("-").replace("-", ", "), + bmreq, + log_info={ + "tile_shape": str( + ( + kwargs.get("BLOCK_M", -1), + kwargs.get("BLOCK_K", -1), + kwargs.get("BLOCK_N", -1), + ) + ), + "num_stages": num_stages, + "num_warps": num_warps, + "GROUP_M": kwargs.get("GROUP_M", -1), + "allow_tf32": str(kwargs.get("ALLOW_TF32")), + "acc_type": str(kwargs.get("ACC_TYPE")), + "matrix_instr_nonkdim": kwargs.get("matrix_instr_nonkdim", 0), + "waves_per_eu": kwargs.get("waves_per_eu", 0), + "kpack": kwargs.get("kpack", 2), + **{ + k: kwargs[k] + for k in AlgorithmSelectorCache.FLEX_ATTENTION_TUNABLE_KEYS + if k in kwargs + }, + }, + mutated_inputs=mutated_inputs, + workspace_arg=workspace_arg, + allowed_prologue_inps=result.prologue_supported_inputs, + hint_override=hint_override, + ) + + +class ExternKernelChoice: + def __init__( + self, + kernel, + cpp_kernel=None, + *, + name=None, + has_out_variant=True, + op_overload=None, + use_fallback_kernel=False, + kernel_creator=None, + ) -> None: + super().__init__() + name = name or kernel.__name__ + assert callable(kernel) + assert not hasattr(extern_kernels, name), f"duplicate extern kernel: {name}" + self.name = name + self.cpp_kernel_name = cpp_kernel + self.has_out_variant = has_out_variant + 
setattr(extern_kernels, name, kernel) + self.op_overload = op_overload + self.use_fallback_kernel = use_fallback_kernel + self.kernel_creator = kernel_creator + # match the API for KernelTemplate as they can be treated the same + # There is no src hash for ExternKernelChoice in the traditional sense + # so we indicate this by returning None + self.src_hash = None + # By default GraphModule is None for extern kernels if not set + self.gm = None + + def to_callable(self): + return getattr(extern_kernels, self.name) + + def call_name(self): + return f"extern_kernels.{self.name}" + + @functools.cache # noqa: B019 + def hash_key(self): + fn = self.to_callable() + parts = [ + self.name, + getattr(fn, "__name__", ""), + getattr(fn, "__module__", ""), + ] + try: + parts.append(inspect.getsource(fn)) + except Exception: + pass + return code_hash("-".join(parts)) + + def bind( + self, + input_nodes, + layout, + ordered_kwargs_for_cpp_kernel=(), + **kwargs, + ): + self.ordered_kwargs_for_cpp_kernel = ordered_kwargs_for_cpp_kernel + return ExternKernelCaller( + self, input_nodes, layout, kwargs, has_out_variant=self.has_out_variant + ) + + @property + def uid(self) -> str: + # unique by prefixing with aten + return f"aten::{self.name}" + + def choice_or_none(self, **kwargs: Any) -> Optional[ChoiceCaller]: + """ + Maybe generates a new ChoiceCaller and returns it, or None if generation fails. + + kwargs: Additional kwargs to be passed to generate a new ChoiceCaller. 
+ """ + temp_choices: list[Any] = [] + result = self.maybe_append_choice(temp_choices, **kwargs) + if result is None and len(temp_choices) == 1: + return temp_choices[0] + return None + + def maybe_append_choice( + self, choices: list[Any], **kwargs: Any + ) -> Optional[NotImplementedError]: + # convenience function to match the Template interface, so that + # templates and ExternKernelChoice can be treated the same when + # generating choice callers + assert "input_nodes" in kwargs, "input_nodes argument required" + assert "layout" in kwargs, "layout argument required" + input_nodes = kwargs.pop("input_nodes") + layout = kwargs.pop("layout") + choices.append(self.bind(input_nodes=input_nodes, layout=layout, **kwargs)) + return None + + +class TritonTemplateCaller(ir.TritonTemplateCallerBase): + def __init__( + self, + name, + input_nodes, + layout, + make_kernel_render, + description, + bmreq, + log_info: Optional[ + dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType]]] + ] = None, + mutated_inputs=None, + workspace_arg: Optional[WorkspaceArg] = None, + allowed_prologue_inps: Optional[OrderedSet[str]] = None, + hint_override: Optional[int] = None, + ) -> None: + super().__init__(name, input_nodes, layout, description) + self.make_kernel_render = make_kernel_render + self.bmreq: TritonBenchmarkRequest = bmreq + if log_info is None: + log_info = {} + self.log_info: dict[str, Any] = log_info + self.log_info.update( + { + "backend": "Triton", + "num_stages": self.bmreq.num_stages, + "num_warps": self.bmreq.num_warps, + } + ) + self.mutated_inputs = mutated_inputs + self.workspace_arg = workspace_arg + self.allowed_prologue_inps = ( + allowed_prologue_inps if allowed_prologue_inps is not None else OrderedSet() + ) + self.hint_override = hint_override + + def benchmark(self, *args, out): + assert self.bmreq is not None + if config.profile_bandwidth_with_do_bench_using_profiling: + algo = self.bmreq.make_run_fn(*args, out=out) + return 
do_bench_using_profiling(algo) + return self.bmreq.benchmark(*args, out=out) + + def precompile(self): + assert self.bmreq is not None + self.bmreq.precompile() + + def __str__(self) -> str: + return f"TritonTemplateCaller({self.bmreq.module_path}, {self.description})" + + def call_name(self): + return f"template_kernels.{self.name}" + + def hash_key(self): + return "-".join( + [ + self.name.rsplit("_", 1)[0], + self.bmreq.module_cache_key, + ] + ) + + def output_node(self): + return ir.TensorBox.create( + ir.TritonTemplateBuffer( + layout=self.layout, + inputs=self.input_nodes, + make_kernel_render=self.make_kernel_render, + mutated_inputs=self.mutated_inputs, + allowed_prologue_inps=self.allowed_prologue_inps, + ) + ) + + def info_dict(self) -> dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType]]]: + """Information returned here is logged to the autotune log file when that is enabled.""" + return self.log_info + + def get_make_kernel_render(self): + return self.make_kernel_render + + def autoheuristic_id(self): + type_name = "triton" + info = self.info_dict() + # TODO(AlnisM): Does tile_shape always exist? 
+ tile = info["tile_shape"] + tile_vals = eval(tile) # type: ignore[arg-type] + BLOCK_M = tile_vals[0] + BLOCK_K = tile_vals[1] + BLOCK_N = tile_vals[2] + num_stages = info["num_stages"] + num_warps = info["num_warps"] + return f"type={type_name}_BLOCK-M={BLOCK_M}_BLOCK-K={BLOCK_K}_BLOCK-N={BLOCK_N}_numstages={num_stages}_numwarps={num_warps}" + + +class ExternKernelCaller(ChoiceCaller): + """ + Caller for external kernel implementations + """ + + def __init__( + self, + choice: ExternKernelChoice, + input_nodes, + layout, + kwargs=None, + *, + has_out_variant=True, + ) -> None: + super().__init__(choice.name, input_nodes, layout, description="") + self.choice = choice + self.kwargs = kwargs or {} + self.has_out_variant = has_out_variant + self.gm = choice.gm + self.bmreq: Optional[BenchmarkRequest] = None + + from torch._inductor.autotune_process import ( + ExternKernelBenchmarkRequest, + ExternKernelCPUBenchmarkRequest, + ExternKernelGPUBenchmarkRequest, + ) + + # Determine if this is a GPU or CPU kernel + if self.layout: + device = self.layout.device + else: + device = None + for inp_node in self.input_nodes: + dev = inp_node.get_device() + if dev and dev.type != "cpu": + device = dev + break + + if not device: + device = torch.device("cpu") + + self.input_tensor_meta: Union[list[TensorMeta], TensorMeta] + self.output_tensor_meta: Union[list[TensorMeta], TensorMeta] + self.input_tensor_meta, self.output_tensor_meta = [], [] + if device.type == "cpu": + benchmark_cls = ExternKernelCPUBenchmarkRequest + else: + try: + self.input_tensor_meta = TensorMeta.from_irnodes(self.input_nodes) + self.output_tensor_meta = TensorMeta.from_irnodes(self.layout) + except Exception: + log.warning( + "Constructing input/output tensor meta failed for Extern Choice" + ) + + benchmark_cls = ExternKernelGPUBenchmarkRequest + + self.bmreq: ExternKernelBenchmarkRequest = benchmark_cls( + kernel_name=self.choice.name, + input_tensor_meta=self.input_tensor_meta, + 
output_tensor_meta=self.output_tensor_meta, + extra_args=(), + callable_path=self.choice.call_name(), + kwargs=self.kwargs, + has_out_variant=self.has_out_variant, + ) + + def __str__(self) -> str: + return f"ExternKernelCaller({self.choice.call_name()})" + + def benchmark(self, *args, out): + # pyrefly: ignore [missing-attribute] + return self.bmreq.benchmark(*args, out=out) + + def benchmark_collective(self, *args, out): + """ + Called by benchmark_collective_choice, only run once, timing handled externally with barrier sync. + """ + if out.numel() == 0: + return + + algo = self.to_callable() + if self.has_out_variant: + algo(*args, out=out) + else: + algo(*args) + + def to_callable(self): + # pyrefly: ignore [missing-attribute] + return self.bmreq.to_callable() + + def hash_key(self): + return "-".join( + [ + self.choice.name, + *[ + f"{kwarg}={repr(self.kwargs[kwarg])}" + for kwarg in sorted(self.kwargs.keys()) + ], + self.choice.hash_key(), + ] + ) + + def output_node(self): + if self.choice.use_fallback_kernel: + assert self.choice.op_overload is not None, ( + "Please provide an op_overload to use ir.FallbackKernel" + ) + inner: ir.IRNode = ir.FallbackKernel.create( + self.choice.op_overload, *self.input_nodes, **self.kwargs + ) + elif self.choice.kernel_creator is not None: + inner = self.choice.kernel_creator(*self.input_nodes, **self.kwargs) + else: + cls = ir.ExternKernelOut if self.has_out_variant else ir.ExternKernelAlloc + inner = cls( + layout=self.layout, + inputs=self.input_nodes, + python_kernel_name=self.choice.call_name(), + cpp_kernel_name=self.choice.cpp_kernel_name, + ordered_kwargs_for_cpp_kernel=self.choice.ordered_kwargs_for_cpp_kernel, + op_overload=self.choice.op_overload, + kwargs=self.kwargs, + ) + + return ir.TensorBox.create(inner) + + def info_dict(self) -> dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType]]]: + """Information returned here is logged to the autotune log file when that is enabled.""" + return { + "backend": 
"extern", + "kernel_call_name": self.choice.call_name(), + } + + def autoheuristic_id(self): + return f"extern_{self.choice.name}" + + +@functools.cache +def get_mm_log_filename() -> Optional[str]: + mm_file_name = os.environ.get("TORCHINDUCTOR_MM_LOGGING_FILE", None) + if not mm_file_name: + return None + + if "json" not in mm_file_name: + mm_file_name = f"{mm_file_name}.json" + + return mm_file_name + + +@functools.cache +def get_flex_attention_log_filename() -> Optional[str]: + flex_attention_file_name = os.environ.get( + "TORCHINDUCTOR_FLEX_ATTENTION_LOGGING_FILE", None + ) + if not flex_attention_file_name: + return None + + return str(Path(flex_attention_file_name).with_suffix(".json")) + + +def append_to_log(filename, data): + lock_file = filename.replace(".json", ".lock") + lock = FileLock(lock_file) + with lock: + try: + with open(filename) as f: + log_data = json.load(f) + except (FileNotFoundError, json.JSONDecodeError): + log_data = [] + + log_data.append(data) + + with open(filename, "w") as f: + json.dump(log_data, f, indent=4) + + +class DataProcessorChoiceCallerWrapper: + def __init__(self, wrapped, preprocessor, postprocessor) -> None: + self._wrapped = wrapped + if preprocessor is not None: + self._preprocessor = preprocessor + else: + self._preprocessor = lambda x, y: (x, y) + if postprocessor is not None: + self._postprocessor = postprocessor + else: + self._postprocessor = lambda x: x + + def __getattr__(self, name): + return getattr(self._wrapped, name) + + def benchmark(self, *args, out) -> float: + new_args, new_out = self._preprocessor(args, out) + result = self._wrapped.benchmark(*new_args, out=new_out) + new_out = self._postprocessor(new_out) + if out is not new_out: + out.copy_(new_out) + return result + + def output_node(self) -> ir.TensorBox: + result = self._wrapped.output_node() + return self._postprocessor(result) + + def __repr__(self) -> str: + return f"DataProcessorChoiceCallerWrapper({self._wrapped})" + + +class 
DataProcessorTemplateWrapper: + """ + A wrapper class for a kernel template. + + This class together with `DataProcessorChoiceCallerWrapper` provides a convenient way to + preprocess and postprocess data before and after using the wrapped template. A typical + usage is to reorder or filter the input nodes in order to match the expected input of other + kernel choices like a ATen kernel. A more complicated usage is to prepack the weights. + See the example from :mod:`cpp_gemm_template` for more details. + """ + + def __init__( + self, + wrapped_template_cls, + preprocessor, + postprocessor, + **kwargs, + ) -> None: + if preprocessor is not None: + self._preprocessor = preprocessor + else: + self._preprocessor = lambda x, y: (x, y) + if postprocessor is not None: + self._postprocessor = postprocessor + else: + self._postprocessor = lambda x: x + assert "input_nodes" in kwargs + assert "layout" in kwargs + # pyrefly: ignore [not-callable] + kwargs["input_nodes"], kwargs["layout"] = preprocessor( + kwargs["input_nodes"], kwargs["layout"] + ) + self._wrapped = wrapped_template_cls(**kwargs) + + def __getattr__(self, name): + return getattr(self._wrapped, name) + + def maybe_append_choice(self, choices, **kwargs): + return type(self._wrapped).maybe_append_choice(self, choices, **kwargs) + + def generate(self, **kwargs): + choice_caller = self._wrapped.generate(**kwargs) + return DataProcessorChoiceCallerWrapper( + choice_caller, self._preprocessor, self._postprocessor + ) + + def __repr__(self) -> str: + return f"DataProcessorTemplateWrapper({self._wrapped})" + + +class ErrorFromChoice(RuntimeError): + def __init__(self, msg, choice: ChoiceCaller, inputs_str) -> None: + msg += f"\nFrom choice {choice}\n{inputs_str}" + super().__init__(msg) + self.choice = choice + + +class NoValidChoicesError(RuntimeError): + pass + + +@functools.cache +def get_num_workers() -> int: + if "TORCHINDUCTOR_COMPILE_THREADS" in os.environ: + return 
int(os.environ["TORCHINDUCTOR_COMPILE_THREADS"]) + + cpu_count = ( + len(os.sched_getaffinity(0)) + if hasattr(os, "sched_getaffinity") + else os.cpu_count() + ) + assert cpu_count + + # Divide the number of CPUs by the number of GPUs for distributed workloads + if ( + config.is_fbcode() + and torch.cuda.is_available() + and torch.cuda.device_count() > 0 + ): + cpu_count = cpu_count // torch.cuda.device_count() + + return cpu_count + + +def create_inputs_key(input_nodes) -> str: + return repr([AlgorithmSelectorCache.key_of(x) for x in input_nodes]) + + +def create_precompile_key( + name: str, inputs_key: str, choices: list[ChoiceCaller] +) -> str: + return ":".join( + [ + name, + inputs_key, + torch.get_float32_matmul_precision(), + ] + + [choice.kernel_hash_key() for choice in choices] + ) + + +# Args to FeedbackFunctions +# timings: mapping from choices to the benchmark time +# name: name of the op +# input_nodes: list of input ir.py Nodes +# choices: list of choices +# profiled time: Callable that returns a dict mapping from choices to the profiled time +FeedbackFunction = Callable[ + [ + dict[ChoiceCaller, float], + str, + list[Any], + list[ChoiceCaller], + Callable[[], dict[ChoiceCaller, float]], + ], + None, +] + +# Args to PreprocessingFunctions +# choices: list of ChoiceCaller objects to preprocess +# Returns: modified list of ChoiceCaller objects +PreprocessingFunction = Callable[[list[ChoiceCaller]], list[ChoiceCaller]] + + +def filter_choices_by_name_regex(choices: list[ChoiceCaller]) -> list[ChoiceCaller]: + """Filter choices based on autotune_choice_name_regex config.""" + if config.test_configs.autotune_choice_name_regex is not None: + return [ + c + for c in choices + if re.search( + config.test_configs.autotune_choice_name_regex, + c.name, + ) + ] + return choices + + +def filter_choices_by_desc_regex(choices: list[ChoiceCaller]) -> list[ChoiceCaller]: + """Filter choices based on autotune_choice_desc_regex config.""" + if 
config.test_configs.autotune_choice_desc_regex is not None: + return [ + c + for c in choices + if re.search( + config.test_configs.autotune_choice_desc_regex, + c.description, + ) + ] + return choices + + +class AlgorithmSelectorCache(PersistentCache): + """ + A persistent cache for algorithm selection results used in autotuning of GEMMs + and convolutions. + + This classes includes precompilation and benchmarking of the kernels. + + The cache is keyed by input characteristics (sizes, strides, dtypes, etc.) but + doesn't depend on the output layout. + """ + + FLEX_ATTENTION_TUNABLE_KEYS = tuple( + dict.fromkeys( + [ + "num_warps", + "num_stages", + "BLOCK_M", + "BLOCK_N", + "BLOCK_M1", + "BLOCK_N1", + "BLOCK_M2", + "BLOCK_N2", + "USE_TMA", + "kpack", + "matrix_instr_nonkdim", + "waves_per_eu", + ] + ) + ) + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + # the autotuning will get occur in the scheduler, so there is + # no guarantee that the first lowering for a given key will also be the + # first to benchmark it. share a single precompilation function for all lowerings + # of a particular key + self.precompile_cache: dict[str, Callable[[], None]] = {} + # cache for prescreening results to ensure deterministic candidate selection + self.prescreening_cache: dict[str, OrderedSet[str]] = {} + # list of callbacks that are called after benchmarking + self.feedback_saver_fns: list[FeedbackFunction] = [] + # list of callbacks that are called to preprocess choices + self.preprocessing_fns: list[PreprocessingFunction] = [] + + self._register_default_preprocessing_fns() + + # registers `self.cache_clear(...)` to be called when a fresh Inductor cache is requested + clear_on_fresh_cache(self) + + def _register_default_preprocessing_fns(self): + """Register default preprocessing functions.""" + # Note: broken out into its own function so that we can avoid clearing + # them (i.e. 
so we can restore them after clearing user provided ones) + self.add_preprocessing_fn(filter_choices_by_name_regex) + self.add_preprocessing_fn(filter_choices_by_desc_regex) + + def cache_clear(self) -> None: + self.precompile_cache.clear() + self.prescreening_cache.clear() + + def pick_deterministic_choice(self, choices: list[ChoiceCaller]) -> ChoiceCaller: + assert len(choices) >= 2 + externs = [ + choice for choice in choices if isinstance(choice, ExternKernelChoice) + ] + if len(externs) > 0: + # pyrefly: ignore [bad-return] + return externs[0] + else: + return choices[0] + + def __call__( + self, + name, + choices: list[ChoiceCaller], + input_nodes, + layout, + # optional dict mapping arg indices to the functions + # generating a torch.Tensor for that input from the + # corresponding ir.Buffer. if passed for a given + # arg, the function will be called instead of + # generating a random torch.Tensor for benchmarking. + input_gen_fns: Optional[dict[int, Callable[[ir.Buffer], torch.Tensor]]] = None, + precompilation_timeout_seconds: int = 60 * 60, + return_multi_template=False, + best_config_future=None, + return_choice=False, # TODO: return_choice is temporary and will be refactored soon + is_collective=False, + ): + from .codegen.cuda.cuda_kernel import CUDATemplateCaller + + # Run preprocessing functions on choices + for preprocessing_fn in self.preprocessing_fns: + choices = preprocessing_fn(choices) + + # Templates selected with input_gen_fns require specific input data to avoid IMA + # Passing custom input gen fns to benchmark_fusion NYI, so skip deferred template selection + # TODO(jgong5): support multi-template on CPU C++ backend + if input_gen_fns is not None or ( + layout.device.type == "cpu" and config.cpu_backend != "triton" + ): + return_multi_template = False + + # TODO - assert that we have not mutating kernels here + + if mm_file_name := get_mm_log_filename(): + M, K = input_nodes[-2].get_size()[:2] + N = input_nodes[-1].get_size()[-1] + 
append_to_log(mm_file_name, {"invoke": str((M, K, N))}) + + if len(choices) == 0: + raise self.create_no_valid_choices(name, "No choices exist for backend.") + log.debug("Max autotune selects from %s choices.", str(len(choices))) + + if len(choices) == 1: + if not isinstance(choices[0], CUDATemplateCaller): + # CUDATemplateCaller still needs to go through autotuning process to retrieve workspace size. + return choices[0].output_node() + + if config.deterministic: + return self.pick_deterministic_choice(choices).output_node() + + inputs_key = create_inputs_key(input_nodes) + + if config.autotune_in_subproc: + # Initialize the suprocess pool so it will warmup early. + torch._inductor.autotune_process.get_tuning_process_pool() + + precompile_fn = self.make_precompile_fn( + choices, + name, + inputs_key, + precompilation_timeout_seconds=precompilation_timeout_seconds, + ) + + if return_multi_template and (config.max_autotune or config.max_autotune_gemm): + + def get_timings(hint_override: Optional[int] = None): + filtered_choices = [ + c + for c in choices + if not hasattr(c, "hint_override") + or c.hint_override == hint_override + ] + timings = self.do_autotuning( + name, + input_nodes, + layout, + input_gen_fns, + inputs_key, + filtered_choices, + precompile_fn, + hint_override=hint_override, + best_config_future=best_config_future, + ) + min_extern_choice = float("inf") + for choice, timing in timings.items(): + if isinstance(choice, ExternKernelCaller): + min_extern_choice = min(min_extern_choice, timing) + + timings = { + choice: time + for choice, time in timings.items() + if ( + time <= min_extern_choice + or not isinstance(choice, ExternKernelCaller) + ) + } + + return timings + + # We take the union of allowed prologue inputs from all choices, + # and, within benchmark fusion, don't allow prologue fusion for + # choices which don't support the whole union. 
+ allowed_prologue_inps: OrderedSet[str] = OrderedSet() + for c in choices: + if isinstance(c, TritonTemplateCaller): + allowed_prologue_inps |= c.allowed_prologue_inps + + return torch._inductor.ir.TensorBox.create( + torch._inductor.ir.MultiTemplateBuffer( + layout, + input_nodes, + get_timings, + choices, + allowed_prologue_inps, + ) + ) + + timings = self.do_autotuning( + name, + input_nodes, + layout, + input_gen_fns, + inputs_key, + choices, + precompile_fn, + best_config_future=best_config_future, + is_collective=is_collective, + ) + # if timings is empty, we really have no choice but to return a semi-random + # choice. returning the first `ExternKernelCaller` is probably the safest bet + # in this case, since it will generally be the ATen kernel. if there are no + # `ExternKernelCaller`s to return, then returning the 0th kernel is our next + # best option (ideally we'd fail whenever there is no ATen kernel to fallback + # to, but that's not trivial to figure out) + if timings == {}: + for choice in choices: + if isinstance(choice, ExternKernelCaller): + node = choice.output_node() + log.debug( + "Autotuning returned empty timings, falling back to first `ExternKernelCaller`: %s", + node, + ) + if return_choice: + return node, choice + return node + node = choices[0].output_node() + choice = choices[0] + log.debug( + "Autotuning returned empty timings, falling back to first choice: %s", + node, + ) + if return_choice: + return node, choice + return node + + # if we got any timings at all, pick the best of those + choice = min(timings, key=timings.__getitem__) + node = choice.output_node() + + log.debug("Autotuning selected choice: %s", node) + if return_choice: + return node, choice + return node + + def benchmark( + self, + choices, + input_nodes, + layout, + input_gen_fns, + hint_override: Optional[int] = None, + is_collective=False, + ): + counters["inductor"]["select_algorithm_autotune"] += 1 + # TODO(nmacchioni): remove this layer of abstraction + # 
construct `benchmark_fn` which should pick between in-process and sub-process autotuning + benchmark_fn = self.make_benchmark_fn( + choices, + input_nodes, + layout, + input_gen_fns, + hint_override=hint_override, + is_collective=is_collective, + ) + # `benchmark_fn(choices)` will execute each choice, and return a dict[choice, timing] which + # maps each choice to its runtime, calculated by the specified benchmarker, in milliseconds + return benchmark_fn(choices) + + def autotune( + self, + name, + input_nodes, + layout, + input_gen_fns, + choices, + hint_override: Optional[int] = None, + is_collective=False, + ): + log.debug("Starting autotuning") + + with dynamo_timed( + f"{name}_template_autotuning", + log_pt2_compile_event=True, + dynamo_compile_column_us="compile_time_autotune_time_us", + metadata=_autotune_metadata(input_nodes), + ): + benchmark_results = self.benchmark( + choices, + input_nodes, + layout, + input_gen_fns, + hint_override=hint_override, + is_collective=is_collective, + ) + if config.max_autotune_report_choices_stats: + _log_autotune_choices_stats( + f"{name}_template_autotuning", benchmark_results + ) + return benchmark_results + + def do_autotuning( + self, + name, + input_nodes, + layout, + input_gen_fns, + inputs_key, + choices, + precompile_fn, + hint_override: Optional[int] = None, + best_config_future=None, + is_collective=False, + ): + """Execute the autotuning process for kernel algorithm selection. + + This method orchestrates the complete autotuning pipeline including precompilation, + prescreening, benchmarking, and feedback collection to select the optimal kernel + implementation for given inputs. + + Args: + name: Name identifier for the operation being autotuned (e.g., 'mm', 'convolution'). + input_nodes: List of input IR nodes used for benchmarking. + layout: Layout information specifying device and memory format for the operation. 
+ input_gen_fns: Optional dict mapping argument indices to functions that generate + torch.Tensor inputs from ir.Buffer for benchmarking. If provided, these are + used instead of random tensors. + inputs_key: Cache key representing the input characteristics (sizes, strides, dtypes). + choices: List of ChoiceCaller objects representing candidate kernel implementations. + precompile_fn: Callable that precompiles all kernel choices before benchmarking. + hint_override: Optional index to override which choice is selected, used for testing + or forced selection. + best_config_future: Optional future containing pre-determined best configuration to + filter choices by specific config parameters. + + Returns: + dict: Mapping from ChoiceCaller to benchmark timing in seconds. Choices with + non-finite timings (inf/nan) indicate failures. + + Raises: + NoValidChoicesError: When all choices fail to compile or benchmark, or when all + timing results are non-finite. + """ + if log.isEnabledFor(logging.DEBUG): + # Log shape information for debugging timeout issues + sizevars = V.graph.sizevars + shapes = [ + "x".join( + map( + str, + sizevars.size_hints( + node.get_size(), + fallback=config.unbacked_symint_fallback, + hint_override=hint_override, + ), + ) + ) + for node in input_nodes + ] + log.debug( + "[BENCHMARK DEBUG] Starting autotuning for '%s' with shapes: %s, device: %s", + name, + shapes, + layout.device.type if layout else "unknown", + ) + + precompile_start_ts = time.time() + with dynamo_timed( + f"{name}_template_precompiling", + log_pt2_compile_event=True, + dynamo_compile_column_us="compile_time_autotune_time_us", + ): + precompile_fn() + precompile_elapse = time.time() - precompile_start_ts + log.debug("Precompilation elapsed time: %.02fs", precompile_elapse) + # Prune anything that failed to compile + choices = [c for c in choices if not c.failed] + if len(choices) == 0: + raise self.create_no_valid_choices( + name, "All choices failed to compile for backend." 
+ ) + + candidates = self.prescreen_choices( + choices, name, inputs_key, self.prescreening_cache + ) + prescreening_elapse: Optional[float] = None + if candidates: + prescreening_start_ts = time.time() + timings = self.lookup( + candidates, + name, + inputs_key, + lambda choices: self.autotune( + name, + input_nodes, + layout, + input_gen_fns, + choices, + hint_override=hint_override, + ), + hint_override=hint_override, + ) + choices = self.prune_choices_postscreen( + choices, timings, name, inputs_key, self.prescreening_cache + ) + prescreening_elapse = time.time() - prescreening_start_ts + log.debug("Prescreening elapsed time: %.02fs", prescreening_elapse) + + autotune_start_ts = time.time() + + if best_config_future is not None: + best_config = await_sync(best_config_future) + + important_keys = [ + "ACC_TYPE", + "ALLOW_TF32", + "BLOCK_K", + "BLOCK_M", + "BLOCK_N", + "EVEN_K", + "GROUP_M", + "USE_FAST_ACCUM", + "num_stages", + "num_warps", + "num_consumer_groups", + "num_buffers_warp_spec", + ] + choices = [ + choice + for choice in choices + if all( + f"{k}={best_config[k]}" in choice.description + for k in important_keys + ) + for k in important_keys + ] + log.info("Filtered to %d choices based on best_config", len(choices)) + + has_autotuned: bool = False + + def track_has_autotuned(choices): + nonlocal has_autotuned + has_autotuned = True + return self.autotune( + name, + input_nodes, + layout, + input_gen_fns, + choices, + hint_override=hint_override, + is_collective=is_collective, + ) + + timings = self.lookup( + choices, + name, + inputs_key, + track_has_autotuned, + hint_override=hint_override, + ) + + autotune_elapse = time.time() - autotune_start_ts + log.debug("Autotuning elapsed time: %.02fs", autotune_elapse) + + # For collective: if any choice returned inf (timeout or failure), fallback to default + if is_collective and timings: + has_inf = any(not math.isfinite(timing) for timing in timings.values()) + if has_inf: + log.warning( + "At least one 
choice failed or timed out during collective benchmarking. " + "Falling back to default implementation." + ) + return {} + + # For regular: if all choices returned inf, raise error + if timings and all(not math.isfinite(timing) for timing in timings.values()): + raise NoValidChoicesError + + if ( + has_autotuned + or log.getEffectiveLevel() == logging.DEBUG + or config.trace.log_autotuning_results + ): + self.log_results( + name, + input_nodes, + timings, + autotune_elapse, + precompile_elapse, + prescreening_elapse, + hint_override=hint_override, + is_collective=is_collective, + ) + + def profiler_bench_function(): + # we're not running through the normal caching autotuner method here because we want to avoid returning + # the cached value. + # Avoid benchmarking in a separate process because it's not easy to signal to the TuningProcess that we + # should use the profiler. + with config.patch( + profile_bandwidth_with_do_bench_using_profiling=True, + autotune_in_subproc=False, + ): + return self.benchmark(choices, input_nodes, layout, input_gen_fns) + + for feedback_fn in self.feedback_saver_fns: + # re-benchmarking the same choices with profiler is a bit expensive, so pass it in as a thunk. + feedback_fn( + timings, + name, + input_nodes, + choices, + profiler_bench_function, + ) + + return timings + + def create_no_valid_choices(self, name: str, reason: str) -> NoValidChoicesError: + backend_config = ( + "max_autotune_gemm_backends" + if name != "convolution" + else "max_autotune_conv_backends" + ) + return NoValidChoicesError( + f"No choices to select. Provided reason: {reason} " + f"please consider adding ATEN into {backend_config} " + "config (defined in torch/_inductor/config.py) to allow at least one choice. " + ) + + def make_precompile_fn( + self, + choices, + name: str, + inputs_key: str, + precompilation_timeout_seconds: Optional[int] = 60 * 60, + ) -> Callable[[], None]: + """ + Returns a function that precompiles the given choices. 
+ """ + log.debug("Starting precompilation") + + def no_op(*args, **kwargs): + return + + if ( + precompilation_timeout_seconds is None + or precompilation_timeout_seconds <= 0 + ): + log.debug("Precompilation timeout is None or <= 0, returning no_op") + return no_op + + num_workers = min(get_num_workers(), len(choices)) + + if num_workers <= 0: + return no_op + + # https://github.com/python/cpython/issues/106905 + if ( + sys.version_info.major == 3 + and sys.version_info.minor == 11 + and sys.version_info.micro <= 8 + ): + return no_op + + # check local and global cache before precompiling + timings = self.lookup( + choices, + name, + inputs_key, + benchmark=None, + ) + + if timings and len(timings) == len(choices): + # compilation in precompile stage is much cheaper than that in + # autotuning stage + log.debug("Found all %d timings in cache, returning no_op", len(timings)) + return no_op + + precompile_key = create_precompile_key(name, inputs_key, choices) + if precompile_func := self.precompile_cache.get(precompile_key): + log.debug("Precompile function found in cache, returning it") + return precompile_func + + log.info( + "Multithreaded precompilation for %d choices using %d worker threads", + len(choices), + num_workers, + ) + + # In rare circumstances, because python threads inherit global state, + # thread pool executor can race and leave stdout/stderr in a state + # different than the original values. we explicitly restore the state + # here to avoid this issue. 
+ + def precompile_with_captured_stdout(choice) -> tuple[None, int]: + log.debug("Precompiling choice with captured stdout: %s", choice) + start_ns = time.time_ns() + with restore_stdout_stderr(): + choice.precompile() + elapsed_ns = time.time_ns() - start_ns + # Return tuple as triton async compile (_worker_compile_triton) + # returns tuple[CachingAutotuner, int] + return None, elapsed_ns // 1000 + + def on_complete(future): + if not future.exception(): + _, precompile_elapsed_us = future.result() + elapsed_seconds = precompile_elapsed_us / 1e6 + elapsed_times[future] = elapsed_seconds + log.debug( + "Precompilation complete for future: %s, elapsed time: %.02fs", + future, + elapsed_seconds, + ) + + executor = ThreadPoolExecutor(max_workers=num_workers) + async_compile = torch._inductor.async_compile.AsyncCompile() + + futures: dict[concurrent.futures.Future[Any], ChoiceCaller] = {} + elapsed_times: dict[concurrent.futures.Future[Any], float] = {} + + # Some choices only differ in runtime arguments, so we + # skip a choice if it has the same hash as a previously seen choice + seen_choices: OrderedSet[str] = OrderedSet() + for c in choices: + # Skip choices which we have already issued a precompile + if c.kernel_hash_key() in seen_choices: + log.debug("Skipping already seen choice: %s", c) + continue + else: + seen_choices.add(c.kernel_hash_key()) + + if hasattr(c, "precompile"): + triton_cuda_choice = isinstance(c, TritonTemplateCaller) and isinstance( + c.bmreq, TritonGPUBenchmarkRequest + ) + if triton_cuda_choice and async_compile.use_process_pool(): + with open(c.bmreq.module_path) as file: + source_code = file.read() + future = async_compile.triton( + kernel_name=c.bmreq.kernel_name, source_code=source_code + ).future + log.debug("Submitted triton async compile for choice: %s", c) + else: + future = executor.submit(precompile_with_captured_stdout, c) + log.debug("Submitted precompile for choice: %s", c) + + future.add_done_callback(on_complete) + 
futures[future] = c + + @functools.cache + @restore_stdout_stderr() + def wait_on_futures(): + log.debug("Waiting on futures") + counters["inductor"]["select_algorithm_precompile"] += 1 + exceptions: list[tuple[ChoiceCaller, BaseException]] = [] + try: + for future in as_completed( + futures, + timeout=precompilation_timeout_seconds, + ): + if e := future.exception(): + counters["inductor"][ + "select_algorithm_num_precompilation_exceptions" + ] += 1 + exceptions.append((futures[future], e)) + log.exception( # noqa: G202 + "Exception %s for benchmark choice %s", + e, + futures[future], + exc_info=e, + ) + futures[future].mark_failed() + else: + counters["inductor"]["select_algorithm_num_precompiles"] += 1 + log.info( + "Precompiling benchmark choice %s took %.02fs", + futures.get(future), + elapsed_times.get(future), + ) + except TimeoutError: + # Don't force the entire process to crash due to a timeout + # in compilation. Just mark those futures as failed. + completed_futures = OrderedSet([f for f in futures if f.done()]) + remaining_futures = OrderedSet(futures.keys()) - completed_futures + + log.warning( + "Precompilation timeout after %ds: %d of %d futures did not complete", + precompilation_timeout_seconds, + len(remaining_futures), + len(futures), + ) + + # Mark remaining futures as failed and log them + for future in remaining_futures: + choice = futures[future] + log.warning( + "Marking choice as failed due to timeout: %s", + choice, + ) + choice.mark_failed() + # Add timeout exception to the exceptions list + timeout_exc = TimeoutError( + f"Precompilation timed out after {precompilation_timeout_seconds}s" + ) + exceptions.append((choice, timeout_exc)) + if exceptions: + _log_autotune_exceptions(exceptions) + + executor.shutdown(wait=True) + + self.precompile_cache[precompile_key] = wait_on_futures + + return wait_on_futures + + @classmethod + def get_inputs( + cls, + choices: Sequence[ChoiceCaller], + input_nodes: list[ir.IRNode], + layout: ir.Layout, + 
input_gen_fns: Optional[dict[int, Callable[[ir.Buffer], torch.Tensor]]], + hint_override: Optional[int] = None, + ) -> AutotuneArgs: + """ + Factory method to create AutotuneArgs from a list of ChoiceCallers. + """ + if input_gen_fns is None: + input_gen_fns = {} + + # de-duplicate args + unique_example_inputs = { + x.get_name(): input_gen_fns.get( + i, + lambda x: cls.benchmark_example_value(x, hint_override=hint_override), + # pyrefly: ignore [bad-argument-type] + )(x) + for i, x in enumerate(input_nodes) + } + example_inputs = list(unique_example_inputs.values()) + example_inputs_extern = [] + for input_node in input_nodes: + if unique_example_inputs[input_node.get_name()].is_mkldnn: + example_inputs_extern.append( + unique_example_inputs[input_node.get_name()] + ) + else: + base = unique_example_inputs[input_node.get_name()] + base = base if base._base is None else base._base + sizes = tuple( + V.graph.sizevars.atomically_apply_size_hint( + size, + fallback=config.unbacked_symint_fallback, + hint_override=hint_override, + ) + for size in input_node.get_size() + ) + strides = tuple( + V.graph.sizevars.atomically_apply_size_hint( + stride, + fallback=config.unbacked_symint_fallback, + hint_override=hint_override, + ) + for stride in input_node.get_stride() + ) + storage_offset = V.graph.sizevars.atomically_apply_size_hint( + input_node.get_layout().offset, + fallback=config.unbacked_symint_fallback, + hint_override=hint_override, + ) + + # Check if the required storage size exceeds the current storage + # to avoid illegal memory access + needed_size = torch._prims_common.compute_required_storage_length( + sizes, strides, storage_offset + ) + current_size = base.storage().size() + + if needed_size > current_size: + # Create a new base tensor with sufficient storage + new_base = torch.randn( + needed_size, + dtype=base.dtype, + device=base.device, + requires_grad=base.requires_grad, + ) + base = new_base.as_strided( + base.size(), base.stride(), 
base.storage_offset() + ) + + example_inputs_extern.append( + torch.as_strided(base, sizes, strides, storage_offset) + ) + out = cls.benchmark_example_value(layout, hint_override=hint_override) + + # Also check the output tensor for storage size + out_base = out if out._base is None else out._base + out_offset = V.graph.sizevars.size_hint(layout.offset) + needed_out_size = torch._prims_common.compute_required_storage_length( + out.size(), out.stride(), out_offset + ) + current_out_size = out_base.storage().size() + + if needed_out_size > current_out_size: + # Create a new base tensor with sufficient storage + new_out_base = torch.randn( + needed_out_size, + dtype=out_base.dtype, + device=out_base.device, + requires_grad=out_base.requires_grad, + ) + out_base = new_out_base.as_strided( + out_base.size(), out_base.stride(), out_base.storage_offset() + ) + + out_extern = torch.as_strided(out_base, out.size(), out.stride(), out_offset) + expected = None + if VERIFY: + choices[0].benchmark(*example_inputs_extern, out=out_extern) + expected = out_extern.clone() + + return AutotuneArgs.from_choice_args( + example_inputs, + example_inputs_extern, + out, + out_extern, + expected, + ) + + @staticmethod + def _is_extern(choice: ChoiceCaller) -> bool: + return isinstance(choice, (ExternKernelCaller, SubgraphChoiceCaller)) + + @classmethod + def benchmark_choice( + cls, choice: ChoiceCaller, autotune_args: AutotuneArgs + ) -> float: + benchmark_tensors = autotune_args.get_benchmark_tensors(cls._is_extern(choice)) + inputs, output = benchmark_tensors.unpack() + output.zero_() + result = choice.benchmark(*inputs, out=output) + device_type = next( + (tensor.device.type for tensor in inputs if is_gpu(tensor.device.type)), + "cuda", + ) + device_interface = get_interface_for_device(device_type) + if device_interface.is_available(): + device_interface.synchronize() # shake out any CUDA errors + + if VERIFY and autotune_args.expected is not None: + autotune_args.verify(**VERIFY) + 
return result + + @classmethod + def _run_collective_benchmark( + cls, + choice: ChoiceCaller, + inputs: tuple, + output: torch.Tensor, + nruns: int, + process_group, + timeout, + ) -> float: + """ + Single function for benchmarking collective operations. + Used for both warmup and actual benchmarking. + + Returns total time in milliseconds, or raises TimeoutError if any collective times out. + """ + import torch.distributed as dist + + work = dist.barrier(group=process_group, async_op=True) + if not work.wait(timeout): + raise TimeoutError("Barrier timeout before benchmarking") + + torch.cuda.synchronize() + + total_time = 0.0 + + for i in range(nruns): + torch.cuda.synchronize() + + start_evt = torch.cuda.Event(enable_timing=True) + end_evt = torch.cuda.Event(enable_timing=True) + + start_evt.record() + choice.benchmark_collective(*inputs, out=output) # type: ignore[attr-defined] + end_evt.record() + end_evt.synchronize() + + total_time += start_evt.elapsed_time(end_evt) + + return total_time + + @classmethod + def benchmark_collective_choice( + cls, + choice: ChoiceCaller, + autotune_args: AutotuneArgs, + ) -> float: + """ + Benchmark a choice for collective operations with cross-rank synchronization. + This method ensures all ranks synchronize before benchmarking + to get accurate measurements for distributed collective operations. + + Timeout/Error handling: If ANY rank times out or encounters an error during + the collective operations, ALL ranks will naturally time out (since the collective + won't complete), allowing the autotuner to fall back to the default implementation. 
+ """ + from datetime import timedelta + + import torch.distributed as dist + + timeout_seconds = config.collective_benchmark_timeout + + nruns = config.collective_benchmark_nruns + nwarmup = ir.autotune_warmup + + # Use default process group (None = all ranks) + process_group = None + rank = dist.get_rank(process_group) + + benchmark_tensors: BenchmarkTensors = autotune_args.get_benchmark_tensors( + cls._is_extern(choice) + ) + inputs, output = benchmark_tensors.unpack() + output.zero_() + + timeout = timedelta(seconds=timeout_seconds) + + try: + # Do n warmups + cls._run_collective_benchmark( + choice, inputs, output, nwarmup, process_group, timeout + ) + + # Do n actual benchmarking runs + total_time = cls._run_collective_benchmark( + choice, inputs, output, nruns, process_group, timeout + ) + + avg_time = total_time / nruns + + # All-reduce to get avg time across ranks + time_tensor = torch.tensor( + [avg_time], dtype=torch.float32, device=f"cuda:{rank}" + ) + work = dist.all_reduce( + time_tensor, + op=dist.ReduceOp.AVG, + group=process_group, + async_op=True, + ) + if not work.wait(timeout): + raise TimeoutError( + "All-reduce timeout when collecting benchmark results" + ) + + timing = time_tensor.item() + + log.info( + "Collective benchmark for %s: %.6f ms", + choice.name, + timing, + ) + + return timing + + except Exception: + log.warning( + "Collective benchmark exception for choice %s. Skipping this choice.", + getattr(choice, "name", ""), + exc_info=True, + ) + return float("inf") + + @classmethod + def benchmark_choices( + cls, + choices: Sequence[ChoiceCaller], + autotune_args: AutotuneArgs, + is_collective: bool = False, + ) -> dict[ChoiceCaller, float]: + """ + Benchmark a list of choices and return timing dict. + """ + if is_collective: + import torch.distributed as dist + + if not dist.is_initialized(): + log.warning( + "Collective op detected but distributed not initialized. " + "Falling back to regular benchmarking." 
+ ) + is_collective = False + else: + rank = dist.get_rank(None) # Use default process group + log.debug( + "Using collective benchmarking for %d choices on rank %d", + len(choices), + rank, + ) + timings = {} + for choice in choices: + try: + if is_collective: + timing = cls.benchmark_collective_choice(choice, autotune_args) + else: + timing = cls.benchmark_choice(choice, autotune_args) + except CUDACompileError: + from torch._inductor.codegen.cuda.cuda_kernel import CUDATemplateCaller + + if not isinstance(choice, CUDATemplateCaller): + log.exception( + "CUDA compilation error during autotuning: \n%s. \nIgnoring this choice." + ) + timing = float("inf") + except NotImplementedError: + log.warning("Not yet implemented", exc_info=True) + timing = float("inf") + except RuntimeError as e: + from torch._inductor.codegen.cuda.cuda_kernel import CUDATemplateCaller + + msg = str(e) + if "invalid argument" in msg: + msg += "\n\nThis may mean this GPU is too small for max_autotune mode.\n\n" + elif "illegal memory access" in msg: + msg += "\n\nEither error in template or triton bug.\n" + elif "unspecified launch failure" in msg: + msg += "\n\nAn unrecoverable unspecified launch failure was caught during autotuning." + msg += "\nPlease try re-running with TORCHINDUCTOR_AUTOTUNE_IN_SUBPROC=1.\n\n" + + if isinstance(choice, CUDATemplateCaller): + log.debug( + "Runtime error during autotuning: \n%s. \nIgnoring this choice.", + msg, + exc_info=True, + ) + else: + log.error( + "Runtime error during autotuning: \n%s. 
\nIgnoring this choice.", + msg, + ) + timing = float("inf") + except AssertionError as e: + raise AssertionError( # noqa: B904 + f"Incorrect result from choice {choice}\n\n{e}" + ) + except Exception as e: + try: + from triton.runtime.autotuner import OutOfResources + + if isinstance(e, OutOfResources): + log.warning(e) # noqa: G200 + timing = float("inf") + else: + raise e + except ImportError: + raise e from None + + timings[choice] = timing + + # If a collective choice failed or timed out, skip the rest of the choices + if is_collective and not math.isfinite(timing): + log.warning( + "Choice %s failed or timed out during collective benchmarking. " + "Stopping further benchmarking to avoid NCCL corruption.", + getattr(choice, "name", ""), + ) + timings.update({c: float("inf") for c in choices if c not in timings}) + break + + return timings + + @classmethod + def benchmark_in_current_process( + cls, + choices: Sequence[ChoiceCaller], + input_nodes: list[ir.IRNode], + layout: ir.Layout, + input_gen_fns: Optional[dict[int, Callable[[ir.Buffer], torch.Tensor]]], + hint_override: Optional[int] = None, + is_collective=False, + ) -> dict[ChoiceCaller, float]: + inputs = cls.get_inputs( + choices, input_nodes, layout, input_gen_fns, hint_override=hint_override + ) + return cls.benchmark_choices( + choices, + inputs, + is_collective=is_collective, + ) + + @classmethod + def benchmark_in_sub_process( + cls, + choices: Sequence[ChoiceCaller], + input_nodes: list[ir.IRNode], + layout: ir.Layout, + input_gen_fns: Optional[dict[int, Callable[[ir.Buffer], torch.Tensor]]], + hint_override: Optional[int] = None, + ): + from . import autotune_process + + # only benchmark triton kernel in sub process for now. + # ATen/Extern kernel are still benchmarked in the current process. 
+ extern = [c for c in choices if cls._is_extern(c)] + triton = [c for c in choices if not cls._is_extern(c)] + + timings = cls.benchmark_in_current_process( + extern, input_nodes, layout, input_gen_fns, hint_override=hint_override + ) + timings.update(autotune_process.benchmark_in_sub_process(triton)) # type: ignore[arg-type] + return timings + + @classmethod + def make_benchmark_fn( + cls, + choices: Sequence[ChoiceCaller], + input_nodes: list[ir.IRNode], + layout: ir.Layout, + input_gen_fns: Optional[dict[int, Callable[[ir.Buffer], torch.Tensor]]], + hint_override: Optional[int] = None, + is_collective=False, + ): + if DEBUG: + print(f"{len(choices)} tuning requests:") + + # Collective ops must use current process + if is_collective or not config.autotune_in_subproc: + return functools.partial( + cls.benchmark_in_current_process, + input_nodes=input_nodes, + layout=layout, + input_gen_fns=input_gen_fns, + hint_override=hint_override, + is_collective=is_collective, + ) + else: + return functools.partial( + cls.benchmark_in_sub_process, + input_nodes=input_nodes, + layout=layout, + input_gen_fns=input_gen_fns, + hint_override=hint_override, + ) + + @staticmethod + def prescreen_choices( + choices: list[ChoiceCaller], + name: str, + inputs_key: str, + prescreen_cache: dict[str, OrderedSet[str]], + ) -> list[ChoiceCaller]: + """ + Figure out what choices need to be prescreened before autotuning with runtime + params. + + Prescreening is a process of reducing the number of autotuning for choices with + runtime params via a two stage autotuning process. First, we fix a set of runtime + params (here we use swizzle=2) and run autotuning to get a set of candidates. + Then, we run autotuning again with the candidates and the full set of runtime + params. + + Since have the concept of runtime params, we need to differentiate between + choice's hash_key and choice's kernel_hash_key. The former includes information + like runtime params, while the latter does not. 
prescreen_cache, if exists, stores + the set of hash_key that should win the prescreening. + + Right now, only CUTLASS choices have runtime params. + """ + # Create a cache key for prescreening results + prescreen_key = f"{name}:{inputs_key}" + + # Check if we have cached prescreening results (prescreen_winners) + if prescreen_key in prescreen_cache: + prescreen_winners = [ + choice + for choice in choices + if choice.hash_key() in prescreen_cache[prescreen_key] + ] + return prescreen_winners + + # prescreen cutlass + from .codegen.cuda.cuda_kernel import CUDATemplateCaller + + candidates = [] + if ( + config.cuda.cutlass_prescreening + and len(config.cuda.cutlass_max_profiling_swizzle_options) > 1 + ): + candidates.extend( + [ + c + for c in choices + if isinstance(c, CUDATemplateCaller) + # hardcoded to only look at swizzle=2 + if c.info_dict().get("swizzle") == "2" + ] + ) + + # skip prescreening if the number of candidates is too small + if len(candidates) < 10: + return [] + + return candidates # type: ignore[return-value] + + @staticmethod + def prune_choices_postscreen( + choices: list[ChoiceCaller], + candidate_timings: dict[ChoiceCaller, float], + name: str, + inputs_key: str, + prescreen_cache: dict[str, OrderedSet[str]], + ) -> list[ChoiceCaller]: + """ + Prune the choices after prescreening. 
+ """ + from .codegen.cuda.cuda_kernel import CUDATemplateCaller + + prescreen_key = f"{name}:{inputs_key}" + + # Check if we have cached postscreen results + if prescreen_key in prescreen_cache: + # candidate_timings are from choices that have won prescreening already + winner_kernel_hashes = [ + candidate.kernel_hash_key() for candidate in candidate_timings + ] + + pruned_choices = [ + choice + for choice in choices + if not isinstance(choice, CUDATemplateCaller) + or choice.kernel_hash_key() in winner_kernel_hashes + ] + return pruned_choices + + log.debug("Before pruning using prescreening timings, %d choices", len(choices)) + sorted_candidates = sorted( + candidate_timings.keys(), key=lambda choice: candidate_timings[choice] + ) + + # Print prescreening timings + if ( + candidate_timings + and PRINT_AUTOTUNE + and config.autotune_num_choices_displayed != 0 + ): + n = config.autotune_num_choices_displayed + top_k = sorted_candidates[:n] + best = top_k[0] + best_time = candidate_timings[best] + + lines = ["PRESCREENING CANDIDATE TIMINGS"] + for choice in top_k: + result = candidate_timings[choice] + if result: + lines.append( + f" {choice.name} {result:.4f} ms {best_time / result:.1%} {choice.description}" + ) + else: + lines.append( + f" {choice.name} {result:.4f} ms " + ) + + log.info("\n".join(lines)) + num_to_keep = max(int(math.sqrt(len(choices)) / 4), 8) + + # prune choices based on prescreening timings + candidates_to_prune = OrderedSet( + candidate.kernel_hash_key() for candidate in sorted_candidates[num_to_keep:] + ) + winner_hashes: OrderedSet[str] = OrderedSet() + for candidate in sorted_candidates[:num_to_keep]: + if candidate_timings[candidate] == float("inf"): + candidates_to_prune.add(candidate.kernel_hash_key()) + else: + winner_hashes.add(candidate.hash_key()) + if isinstance(candidate, CUDATemplateCaller): + candidate.bmreq.ensure_dll_loaded() + + pruned_choices = [ + choice + for choice in choices + if choice.kernel_hash_key() not in 
candidates_to_prune # type: ignore[attr-defined] + ] + + # Cache the hash_key of winners of prescreening + prescreen_cache[prescreen_key] = winner_hashes + + log.debug( + "After pruning using prescreening timings, %d choices", len(pruned_choices) + ) + return pruned_choices + + @staticmethod + def get_flex_attention_choice_info( + choice: ChoiceCaller, timings: dict[ChoiceCaller, float] + ) -> dict[str, Any]: + if isinstance(choice, torch._inductor.select_algorithm.ExternKernelCaller): + return {"type": "extern", "time": timings[choice]} + + assert isinstance(choice, torch._inductor.select_algorithm.TritonTemplateCaller) + + info = choice.info_dict() + result = { + "type": "triton", + "time": timings[choice], + } + + for key in AlgorithmSelectorCache.FLEX_ATTENTION_TUNABLE_KEYS: + if key in info: + # pyrefly: ignore [unsupported-operation] + result[key] = info[key] + + return result + + @staticmethod + def maybe_log_flex_attention_results( + name: str, input_nodes: list[ir.IRNode], timings: dict[ChoiceCaller, float] + ) -> None: + flex_attention_filename = get_flex_attention_log_filename() + if not flex_attention_filename or "flex_attention" not in name: + return + + if len(input_nodes) < 3: + return + + query_size = input_nodes[0].get_size() + key_size = input_nodes[1].get_size() + value_size = input_nodes[2].get_size() + + B = query_size[0] + Hq = query_size[1] + seq_len_q = query_size[2] + qk_head_dim = query_size[3] + Hkv = key_size[1] + seq_len_kv = key_size[2] + v_head_dim = value_size[3] + + kernel_type = "backward" if "backward" in name else "forward" + dims_key = str( + ( + kernel_type, + B, + Hq, + Hkv, + seq_len_q, + seq_len_kv, + qk_head_dim, + v_head_dim, + ) + ) + + sorted_choices = sorted(timings, key=timings.__getitem__) + out_dict = { + dims_key: [ + AlgorithmSelectorCache.get_flex_attention_choice_info(choice, timings) + for choice in sorted_choices + ] + } + append_to_log(flex_attention_filename, out_dict) + + @staticmethod + def log_results( + 
name: str, + input_nodes: list[ir.IRNode], + timings: dict[ChoiceCaller, float], + elapse: float, + precompile_elapse: float, + prescreening_elapse: Optional[float] = None, + hint_override: Optional[int] = None, + is_collective: bool = False, + ): + """Log the autotuning results, currently only handles mm and flex. Log Collective op autotuning result""" + if is_collective and timings: + import torch.distributed as dist + + # Only rank 0 logs to avoid duplicate logs + rank = dist.get_rank() if dist.is_initialized() else 0 + if rank == 0: + best_choice = min(timings, key=timings.__getitem__) + log.warning("[COLLECTIVE AUTOTUNING] All timings:") + for c, t in sorted(timings.items(), key=lambda x: x[1]): + choice_name = getattr(c, "name", str(c)) + log.warning( + " - %s: %.6f ms %s", + choice_name, + t if math.isfinite(t) else float("inf"), + "← SELECTED" if c == best_choice else "", + ) + + V.debug.log_autotuning_results( + name, input_nodes, timings, elapse, precompile_elapse + ) + if not (config.max_autotune or config.max_autotune_gemm) or not PRINT_AUTOTUNE: + return + sizes = ", ".join( + [ + "x".join( + map( + str, + V.graph.sizevars.size_hints( + n.get_size(), + fallback=config.unbacked_symint_fallback, # type: ignore[arg-type] + hint_override=hint_override, + ), + ) + ) + for n in input_nodes + ] + ) + + strides = ", ".join([str(n.get_stride()) for n in input_nodes]) + dtypes = ", ".join([str(n.get_dtype()) for n in input_nodes]) + if config.autotune_num_choices_displayed == 0: + return + + # when autotune_num_choices_displayed is None, [:None] means all + n = config.autotune_num_choices_displayed + top_k = sorted(timings, key=timings.__getitem__)[:n] + + best = top_k[0] + + def get_choice_info(choice): + if isinstance(choice, torch._inductor.select_algorithm.ExternKernelCaller): + return {"type": "cublas", "time": timings[choice]} + + assert isinstance( + choice, torch._inductor.select_algorithm.TritonTemplateCaller + ) + + info = choice.info_dict() + tile = 
info["tile_shape"] + + tile_vals = eval(tile) # type: ignore[arg-type] + BLOCK_M = tile_vals[0] + BLOCK_K = tile_vals[1] + BLOCK_N = tile_vals[2] + + return { + "type": "triton", + "time": timings[choice], + "BLOCK_M": BLOCK_M, + "BLOCK_K": BLOCK_K, + "BLOCK_N": BLOCK_N, + "num_stages": info["num_stages"], + "num_warps": info["num_warps"], + } + + mm_filename = get_mm_log_filename() + if mm_filename and "mm" in name: + M, K = input_nodes[-2].get_size()[:2] + N = input_nodes[-1].get_size()[-1] + + out_dict = {str((M, K, N)): [get_choice_info(choice) for choice in timings]} + + append_to_log(mm_filename, out_dict) + + AlgorithmSelectorCache.maybe_log_flex_attention_results( + name, input_nodes, timings + ) + + best_time = timings[best] + sys.stderr.write(f"AUTOTUNE {name}({sizes})\n") + sys.stderr.write(f"strides: {strides}\n") + sys.stderr.write(f"dtypes: {dtypes}\n") + + for choice in top_k: + result = timings[choice] + if result: + kernel_description = choice.description + sys.stderr.write( + f" {choice.name} {result:.4f} ms {best_time / result:.1%} {kernel_description}\n" + ) + else: + sys.stderr.write( + f" {choice.name} {result:.4f} ms \n" + ) + + autotune_type_str = ( + "SubProcess" if config.autotune_in_subproc else "SingleProcess" + ) + prescreening_msg = ( + f" and {prescreening_elapse:.4f} seconds prescreening" + if prescreening_elapse is not None + else "" + ) + sys.stderr.write( + f"{autotune_type_str} AUTOTUNE benchmarking takes {elapse:.4f} seconds and {precompile_elapse:.4f}" + f" seconds precompiling for {len(timings)} choices" + + prescreening_msg + + "\n" + ) + + @staticmethod + def benchmark_example_value(node, hint_override: Optional[int] = None): + """ + Convert an ir.Buffer into a concrete torch.Tensor we can use for + benchmarking. + """ + if isinstance(node, ir.Layout): + node = ir.Buffer(name="fake", layout=node) + # triton templates want the base tensor. 
+ if isinstance(node, ir.BaseView): + node = node.unwrap_view() + + # Inplace padding may reinterpret a tensor to a larger tensor if the + # stride is large enough. The V.graph.get_allocation_size takes this into account. + # So we need call as_strided in the end to 'view' the tensor with the correct + # sizes/strides + return AlgorithmSelectorCache.generate_example_value( + tuple( + V.graph.sizevars.atomically_apply_size_hint( + size, + fallback=config.unbacked_symint_fallback, + hint_override=hint_override, + ) + for size in node.get_size() + ), + tuple( + V.graph.sizevars.atomically_apply_size_hint( + stride, + fallback=config.unbacked_symint_fallback, + hint_override=hint_override, + ) + for stride in node.get_stride() + ), + node.get_device(), + node.get_dtype(), + V.graph.sizevars.atomically_apply_size_hint( + # pyrefly: ignore [missing-attribute] + node.layout.offset, + fallback=config.unbacked_symint_fallback, + hint_override=hint_override, + ), + tuple( + V.graph.sizevars.atomically_apply_size_hint( + size, + fallback=config.unbacked_symint_fallback, + hint_override=hint_override, + ) + # pyrefly: ignore [bad-argument-type] + for size in V.graph.get_allocation_size(node) + ), + ) + + @staticmethod + def generate_example_value( + size, stride, device, dtype, extra_size, allocation_size=None + ): + # preserve rng states to avoid the rand_strided call below changes + # the rng states for the real model code. + with preserve_rng_state(): + if allocation_size is None or allocation_size == size: + return rand_strided( + size, + stride, + device=device, + dtype=dtype, + extra_size=extra_size, + ) + else: + return rand_strided( + allocation_size, + stride, + device=device, + dtype=dtype, + extra_size=extra_size, + ).as_strided(size, stride) + + @staticmethod + def key_of(node): + """ + Extract the pieces of an ir.Buffer that we should invalidate cached + autotuning results on. 
+ """ + sizevars = V.graph.sizevars + return ( + node.get_device().type, + str(node.get_dtype()), + *sizevars.size_hints( + node.get_size(), + fallback=config.unbacked_symint_fallback, + ), + *tuple( + V.graph.sizevars.atomically_apply_size_hint( + stride, + fallback=config.unbacked_symint_fallback, + ) + for stride in node.get_stride() + ), + sizevars.size_hint( + node.get_layout().offset, + fallback=config.unbacked_symint_fallback, + ), + ) + + def add_feedback_saver(self, fn: FeedbackFunction): + self.feedback_saver_fns.append(fn) + + def clear_feedback_savers(self): + self.feedback_saver_fns = [] + + def add_preprocessing_fn(self, fn: PreprocessingFunction): + self.preprocessing_fns.append(fn) + + def clear_preprocessing_fns(self, clear_defaults: bool = False): + """Clear preprocessing functions. + + Args: + clear_defaults: If True, clears all functions including defaults. + If False, clears only user-added functions and re-registers defaults. + """ + self.preprocessing_fns.clear() + if not clear_defaults: + self._register_default_preprocessing_fns() + + +_ALGORITHM_SELECTOR_CACHE: Optional[AlgorithmSelectorCache] = None + + +def get_algorithm_selector_cache() -> AlgorithmSelectorCache: + """Get the global algorithm selector cache, creating it if it doesn't exist.""" + global _ALGORITHM_SELECTOR_CACHE + if _ALGORITHM_SELECTOR_CACHE is None: + _ALGORITHM_SELECTOR_CACHE = AlgorithmSelectorCache() + return _ALGORITHM_SELECTOR_CACHE + + +def autotune_select_algorithm(*args, **kwargs): + cache = get_algorithm_selector_cache() + + if "return_multi_template" not in kwargs: + kwargs["return_multi_template"] = ( + torch._inductor.config.benchmark_epilogue_fusion + ) + + if "precompilation_timeout_seconds" not in kwargs: + kwargs["precompilation_timeout_seconds"] = config.precompilation_timeout_seconds + + return cache(*args, **kwargs) + + +def add_feedback_saver( + fn: FeedbackFunction, +): + cache = get_algorithm_selector_cache() + cache.add_feedback_saver(fn) + + +def 
clear_feedback_savers(): + """Clear all feedback saver functions.""" + cache = get_algorithm_selector_cache() + cache.clear_feedback_savers() + + +def add_preprocessing_fn( + fn: PreprocessingFunction, +): + """Add a preprocessing function to be applied to choices before autotuning. + + Preprocessing functions are called sequentially in the order they were registered, + with each function receiving the output of the previous one. They can filter, + reorder, transform, or modify the list of choices in any way. + + Args: + fn: A function that takes a list of ChoiceCaller objects and returns + a modified list of ChoiceCaller objects. + + Example: + def my_filter(choices): + # Filter out choices with certain names + return [c for c in choices if 'slow' not in c.name.lower()] + + add_preprocessing_fn(my_filter) + """ + cache = get_algorithm_selector_cache() + cache.add_preprocessing_fn(fn) + + +def clear_preprocessing_fns(clear_defaults: bool = False): + """Clear preprocessing functions at module level. + + Args: + clear_defaults: If True, clears all functions including defaults. + If False, clears only user-added functions and re-registers defaults. + """ + cache = get_algorithm_selector_cache() + cache.clear_preprocessing_fns(clear_defaults) + + +def realize_inputs(*args): + if len(args) == 1: + return ir.ExternKernel.require_stride1(ir.ExternKernel.realize_input(args[0])) + return [realize_inputs(x) for x in args] + + +class SymbolicGridFn: + """ + Wrapper around a grid function that allows either int or sympy inputs. 
+ + @SymbolicGridFn + def grid(x, meta, *, cdiv): + return cdiv(x, meta["BLOCK_X"]) + """ + + def __init__(self, fn: Callable[..., tuple[Any, Any, Any]]): + self.fn = fn + self.kwargs_int = {} + self.kwargs_sym = {} + params = inspect.signature(fn).parameters + for name, fn_sym, fn_int in [ + ("cdiv", CeilDiv, ceildiv), + ("min", sympy.Min, min), + ("max", sympy.Max, max), + ]: + if name in params: + self.kwargs_int[name] = fn_int + self.kwargs_sym[name] = fn_sym + + def __call__(self, *args, **kwargs) -> tuple[int, int, int]: + return self.fn(*args, **kwargs, **self.kwargs_int) + + def sympy_call(self, *args, **kwargs): + return self.fn(*args, **kwargs, **self.kwargs_sym) + + +def _autotune_metadata(input_nodes): + """Helper function to extract autotune metadata from input nodes.""" + return { + "autotune_strides": ", ".join([str(n.get_stride()) for n in input_nodes]), + "autotune_dtypes": ", ".join([str(n.get_dtype()) for n in input_nodes]), + "autotune_shape": ", ".join( + ["x".join(map(str, n.get_size())) for n in input_nodes] + ), + "autotune_offset": ", ".join([str(n.get_layout().offset) for n in input_nodes]), + # TODO(coconutruben): replace this with taking KernelInputs as the + # argument, and extracting those out there directly + "autotune_strides_hinted": ", ".join( + [ + str( + V.graph.sizevars.size_hints( + n.get_stride(), + fallback=config.unbacked_symint_fallback, + ) + ) + for n in input_nodes + ] + ), + "autotune_shape_hinted": ", ".join( + [ + "x".join( + map( + str, + V.graph.sizevars.size_hints( + n.get_size(), + fallback=config.unbacked_symint_fallback, + ), + ) + ) + for n in input_nodes + ] + ), + } + + +def _log_autotune_choices_stats( + event_name: str, timings: dict[ChoiceCaller, float] +) -> None: + """Helper function to extract autotune metadata from benchmark results.""" + if not timings: + return None + + metadata: dict[str, Union[int, float, str]] = { + "num_choices": len(timings), + "num_triton_choices": len( + [c for c in timings if 
isinstance(c, TritonTemplateCaller)] + ), + } + + sorted_choices = sorted(timings, key=timings.__getitem__) + best_choice = sorted_choices[0] + metadata["best_kernel"] = best_choice.name + if best_choice.description: + metadata["best_kernel_desc"] = best_choice.description + metadata["best_time"] = timings[best_choice] + + best_triton_pos = next( + ( + i + for i, choice in enumerate(sorted_choices) + if isinstance(choice, TritonTemplateCaller) + ), + None, + ) + if best_triton_pos is not None: + metadata["best_triton_pos"] = best_triton_pos + best_triton_kernel = sorted_choices[best_triton_pos] + if best_triton_pos != 0: + metadata["best_triton_time"] = timings[best_triton_kernel] + metadata["best_triton_kernel"] = best_triton_kernel.name + if best_triton_kernel.description: + metadata["best_triton_kernel_desc"] = best_triton_kernel.description + + payload = json.dumps(metadata) + get_chromium_event_logger().add_event_data( + event_name, autotune_choices_stats=payload + ) + sys.stderr.write(f"Autotune Choices Stats:\n{payload}\n") + + +def _log_autotune_exceptions( + exceptions: list[tuple[ChoiceCaller, BaseException]], +) -> None: + """Log autotune exceptions to chromium event logger.""" + if not exceptions: + return + + try: + pt2_compile_substack = get_chromium_event_logger().get_pt2_compile_substack() + if not pt2_compile_substack: + return + + current_event = pt2_compile_substack[-1] + if not current_event.endswith("_template_precompiling"): + return + + exception_details = [] + for choice, exc in exceptions: + try: + choice_type = ( + "triton" if isinstance(choice, TritonTemplateCaller) else "other" + ) + data = { + "choice_type": choice_type, + "choice": choice.description, + "exception_message": str(exc), + } + + exc_type_match = re.search(r"(\w+):", str(exc)) + if exc_type_match: + data["exception"] = exc_type_match.group(1) + + if "OutOfMemoryError" in str(exc): + required_match = re.search(r"Required: (\d+)", str(exc)) + if required_match: + 
data["required_memory"] = required_match.group(1) + + limit_match = re.search(r"Hardware limit:\s*(\d+)", str(exc)) + if limit_match: + data["hardware_limit"] = limit_match.group(1) + + exception_details.append(data) + except Exception: + # Don't let logging errors break the main flow + continue + + if exception_details: + metadata = json.dumps({"exceptions": exception_details}) + get_chromium_event_logger().try_add_event_data( + current_event, metadata=metadata + ) + except Exception: + # Silently ignore logging errors to avoid breaking autotune + pass + + +# ensure lowering is imported so that `extern_kernels.*` is populated +from . import lowering # noqa: F401 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/shape_propagation.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/shape_propagation.py new file mode 100644 index 0000000000000000000000000000000000000000..23a771a024efa4924c5894636769ff05ca9095db --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/shape_propagation.py @@ -0,0 +1,154 @@ +import functools +from collections.abc import Callable, Sequence +from typing import Optional, Protocol, Union + +import sympy + +import torch + +from .virtualized import OpsValue, V + + +BlockShapeType = Optional[Sequence[Union[int, str]]] + + +class ShapeVar(Protocol): + @property + def shape(self) -> BlockShapeType: ... 
+ + +ShapeArg = Union[ShapeVar, torch.types.Number, str, OpsValue, torch.dtype] + +# Inputs need to be cacheable (e.g., not a CSEVar) in order for the cache to be effective +# So first decompose CSEVars -> tuple before calling this + + +@functools.lru_cache(None) +def get_broadcasted_shape(a: BlockShapeType, b: BlockShapeType) -> BlockShapeType: + assert isinstance(a, Sequence) + assert isinstance(b, Sequence) + if len(a) > len(b): + return get_broadcasted_shape(a, (*[1] * (len(a) - len(b)), *b)) + elif len(a) < len(b): + b, a = a, b + return get_broadcasted_shape(a, (*[1] * (len(a) - len(b)), *b)) + else: + + def _get_broadcasted_dim( + d1: Union[int, str], d2: Union[int, str] + ) -> Union[int, str]: + if str(d1) == "1": + return d2 + elif str(d2) == "1": + return d1 + assert str(d1) == str(d2) + return d1 + + return tuple(_get_broadcasted_dim(d1, d2) for d1, d2 in zip(a, b)) + + +def broadcast_shapes_for_args(args: Sequence[ShapeArg]) -> BlockShapeType: + result_shape: BlockShapeType = None + + for arg in args: + if hasattr(arg, "shape"): + shape = arg.shape + if shape is None: + return None + elif result_shape is None: + result_shape = tuple(shape) + else: + result_shape = get_broadcasted_shape(result_shape, tuple(shape)) + elif isinstance(arg, (int, float)): + if result_shape is None: + result_shape = () + elif isinstance(arg, torch.dtype): + continue + else: + from torch._inductor.loop_body import LoopBody, LoopBodyBlock + + if isinstance(arg, (LoopBodyBlock, LoopBody, OpsValue)): + # TODO: fix me + return None + raise TypeError(f"Unknown type: {type(arg)}") + + return result_shape + + +class ShapePropagationOpsHandler: + """ + Propagate shape from args to output + """ + + @staticmethod + def constant(value: torch.types.Number, dtype: torch.dtype) -> BlockShapeType: + # See implementation of constant for triton for the reason + from torch._inductor.codegen.triton import triton_compute_type, TritonKernel + + triton_type = triton_compute_type(dtype) + + if 
isinstance(V.kernel, TritonKernel) and triton_type != "tl.float32": + ndim = V.kernel.triton_tensor_ndim() + return tuple([1] * ndim) + else: + return () + + @staticmethod + def store_reduction(name: str, index: int, value: ShapeArg) -> None: + return None + + @staticmethod + def reduction( + dtype: torch.dtype, + src_dtype: torch.dtype, + reduction_type: str, + value: Union[ShapeArg, tuple[ShapeArg, ...]], + ) -> Union[BlockShapeType, tuple[BlockShapeType, ...]]: + raise NotImplementedError + + @staticmethod + def store( + name: str, index: int, value: ShapeArg, mode: Optional[str] = None + ) -> None: + return None + + @staticmethod + def to_dtype( + value: ShapeVar, + dtype: torch.dtype, + src_dtype: Optional[torch.dtype] = None, + use_compute_types: bool = True, + ) -> BlockShapeType: + return value.shape + + @staticmethod + def dot(a: sympy.Expr, b: sympy.Expr) -> BlockShapeType: + from torch._inductor.codegen.triton import TritonKernel + + assert isinstance(V.kernel, TritonKernel), "dot supports Triton only" + return ("YBLOCK", "XBLOCK") + + @staticmethod + def index_expr(expr: sympy.Expr, dtype: torch.dtype) -> BlockShapeType: + # shape is implicitly embedded in expr. 
+ return None + + @staticmethod + def load_seed(name: str, offset: int) -> BlockShapeType: + return () + + @staticmethod + def indirect_indexing( + var: ShapeArg, + size: Union[sympy.Expr, int], + check: bool = True, + wrap_neg: bool = True, + ) -> None: + return None + + def __getattr__(self, name: str) -> Callable[..., BlockShapeType]: + return lambda *args, **kwargs: broadcast_shapes_for_args(args) + + @staticmethod + def device_assert_async(cond: ShapeArg, msg: str) -> None: + return None diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/sizevars.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/sizevars.py new file mode 100644 index 0000000000000000000000000000000000000000..77526a38aeb37f3919612f1ce698787f4b0bc3fd --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/sizevars.py @@ -0,0 +1,1205 @@ +# mypy: allow-untyped-defs +import functools +import itertools +import logging +from collections import defaultdict +from collections.abc import Callable, Iterable, Sequence +from typing import Any, cast, Optional, Union + +import sympy +from sympy import Expr + +from torch.fx.experimental.symbolic_shapes import ( + free_symbols, + has_free_unbacked_symbols, + ShapeEnv, +) +from torch.utils._ordered_set import OrderedSet +from torch.utils._sympy.functions import FloorDiv, ModularIndexing +from torch.utils._sympy.symbol import symbol_is_type, SymT +from torch.utils._sympy.value_ranges import bound_sympy, IntInfinity, ValueRanges + +from .runtime.runtime_utils import is_power_of_2 +from .utils import ( + has_free_symbols, + sympy_index_symbol, + sympy_index_symbol_with_prefix, + sympy_subs, + VarRanges, +) +from .virtualized import V + + +log = logging.getLogger(__name__) + + +def statically_known_true( + shape_env: ShapeEnv, + expr: Union[sympy.Basic, bool], + axioms: Optional[tuple[sympy.Expr]] = None, + var_to_range: Optional[tuple[tuple[sympy.Symbol, 
ValueRanges[Any]]]] = None, +) -> bool: + if expr in (True, False): + return bool(expr) + + try: + simplified = shape_env._maybe_evaluate_static( + expr, + axioms=axioms, + var_to_range=var_to_range, + ) + if simplified is not None: + return bool(simplified) + except Exception: + log.debug("Could not simplify %s", expr, exc_info=True) + + return False + + +# This class is a little awkward, because ShapeEnv is doing most of the heavy +# lifting and in some cases we should be directly passing through to ShapeEnv, +# but there is some extra inductor logic that needs to be handled here +class SizeVarAllocator: + """ + A class that manages symbolic size variables and their relationships. + + This class works with the ShapeEnv to handle symbolic shape expressions, + simplify them, and provide utilities for guarding, checking, and evaluating + symbolic expressions. It also manages precomputed replacements and stride + calculations for tensor operations. + """ + + def __init__(self, shape_env=None) -> None: + super().__init__() + # Note: this can lead to bugs. Reasoning APIs depends on existing information in + # in the shape_env. For example! var_to_ranges can't be empty! + if shape_env is None: + shape_env = ShapeEnv() + self.shape_env = shape_env + self.var_to_val = self.shape_env.var_to_val + self.var_to_hint_override = self.shape_env.var_to_hint_override + self.replacements: dict[sympy.Symbol, Expr] = self.shape_env.replacements + self.unbacked_replacements: Optional[dict[Expr, Expr]] = None + # Maps of dynamic sizes that have to be precomputed on the host to the kernel args. + # The basic idea is if we have some complicated sympy expression + # f(s0), we may choose to precompute it on the host and then replace + # all occurrences of that sympy expression with ps0, so that when we + # codegen we simply reference ps0 directly without repeating + # f(s0). 
Unlike regular size variables, ps variables cannot be + # guarded upon; so if we are asked to guard on a Sympy expression + # which potentially could have already had a precomputed replacement + # on it, we are obligated to invert the precomputed replacements + # (inv_precomputed_replacements). + self.precomputed_replacements: dict[Expr, sympy.Symbol] = {} + self.inv_precomputed_replacements: dict[sympy.Symbol, Expr] = {} + self.stride_vars = self.make_stride_vars_cache() + self.simplify_with_ranges = self.make_simplify_with_ranges_cache() + self._simplify_loops = self.make_simplify_loops_cache() + + def simplify(self, expr: Expr): + return sympy.expand(expr).xreplace(self.replacements) + + def make_simplify_with_ranges_cache(self) -> Callable[[Expr, VarRanges], Expr]: + """ + self._simplify_with_ranges() can be expensive, cache its results + """ + cache: dict[tuple[Any, ...], Expr] = {} + replacement_count = len(self.replacements) + + def simplify_with_ranges(expr: Expr, var_ranges: VarRanges) -> Expr: + nonlocal replacement_count + if replacement_count != len(self.replacements): + # new replacements invalidates cached results + cache.clear() + replacement_count = len(self.replacements) + key = (expr, *var_ranges.items()) + result = cache.get(key) + if result is None: + result = self._simplify_with_ranges(expr, var_ranges) + cache[key] = result + if result != expr: + cache[(result, *var_ranges.items())] = result + return result + + return simplify_with_ranges + + def make_simplify_loops_cache(self): + """ + self._simplify_with_ranges() can be expensive, cache its results + """ + cache: dict[tuple[Any, ...], Any] = {} + replacement_count = len(self.replacements) + + def simplify_loops(index_vars, sizes, index_formulas): + nonlocal replacement_count + if replacement_count != len(self.replacements): + # new replacements invalidates cached results + cache.clear() + replacement_count = len(self.replacements) + key = (*index_vars, *sizes, *index_formulas) + result = 
cache.get(key) + if result is None: + result = self._simplify_loops_impl(index_vars, sizes, index_formulas) + cache[key] = result + return result + + return simplify_loops + + def _simplify_with_ranges(self, expr: Expr, var_ranges: VarRanges) -> Expr: + """ + Simplify indexing expression with knowledge of the ranges of + iteration variables. + """ + + expr = join_dimensions(self.simplify(expr)) + original_expr = expr + + var_to_range = dict(self.shape_env.var_to_range) + var_to_range.update( + { + k: ValueRanges( + 0, max(0, v - 1) if not has_free_symbols([v]) else IntInfinity() + ) + for k, v in var_ranges.items() + } + ) + for var in expr.free_symbols: + if var not in var_to_range: + var_to_range[var] = ValueRanges(0, IntInfinity()) + + var_to_range_tuple = cast( + tuple[tuple[sympy.Symbol, ValueRanges[sympy.Expr]]], + tuple(var_to_range.items()), + ) + + axioms = [] + for var, upper_bound in var_ranges.items(): + axioms.append(0 <= var) + axioms.append(var < upper_bound) + axioms = tuple(axioms) + self.shape_env.get_axioms() + + def statically_known(expr): + evaluated = self.shape_env._maybe_evaluate_static( + expr, + # pyrefly: ignore [bad-argument-type] + axioms=axioms, + var_to_range=var_to_range_tuple, + ) + return bool(evaluated) + + def remove_zero_terms(base, divisor): + """Symbols smaller than the divisor are zero""" + if not statically_known(base >= 0): + return base + + for v in base.free_symbols: + if v in var_ranges: + # var smaller than divisor can be removed + # if the rest is guaranteed to be multiple of divisor + rest = sympy.Wild("_rest", exclude=[v]) + m = base.match(v + rest) + if m and v not in m[rest].free_symbols: + gcd = sympy.gcd(m[rest], divisor) + if gcd == divisor: + if statically_known(v < divisor): + base = m[rest] + return base + + def visit_indexing_div(base, divisor): + return FloorDiv(remove_zero_terms(base, divisor), divisor) + + def visit_modular_indexing(base, divisor, modulus): + base = remove_zero_terms(base, divisor) + + 
can_remove_mod = statically_known(base >= 0) and statically_known( + base < modulus * divisor + ) + + if can_remove_mod: + return FloorDiv(base, divisor) + return ModularIndexing(base, divisor, modulus) + + if expr.has(ModularIndexing): + expr = expr.replace( + ModularIndexing( + sympy.Wild("base", integer=True), + sympy.Wild("divisor", integer=True), + sympy.Wild("modulus", integer=True), + ), + visit_modular_indexing, + ) + + if expr.has(FloorDiv): + expr = expr.replace( + FloorDiv( + sympy.Wild("base", integer=True), + sympy.Wild("divisor", integer=True), + ), + visit_indexing_div, + ) + + if expr != original_expr: + return self._simplify_with_ranges(expr, var_ranges) + return expr + + def _simplify_loops_impl( + self, index_vars: list[sympy.Symbol], sizes, index_formulas + ): + """ + Try to remove as many axis from loop iterations as possible, by: + 1) removing size==1 dimensions + 2) fuse contiguous dimensions into a single loop + If channel_last = True, we will prevent the last dim fused with other dims + """ + sizes = list(map(self.simplify, sizes)) + + strides = [ + # index_formulas may contain boolean expressions (e.g. s0 < 10), + # for which "strides" don't make sense so we ignore them here. + # NOTE: These expressions may still block merging dims in the sound + # substitution test performed in can_merge_dims. 
+ ( + self.stride_vars(x, index_vars) + if isinstance(x, sympy.Expr) + else [0] * len(index_vars) + ) + for x in index_formulas + ] + assert len(sizes) == len(strides[0]), (len(sizes), len(strides[0])) + + for i in range(len(sizes)): + if sizes[i] == 1: + # remove dim + sizes[i] = None + + def can_merge_dims(a, b): + for k in range(len(strides)): + if self.simplify(strides[k][a] * sizes[a]) == self.simplify( + strides[k][b] + ): + # approximate test passed, try sound version + va = index_vars[a] + vb = index_vars[b] + m1 = sympy_index_symbol("_merge_tester1") + m2 = sympy_index_symbol("_merge_tester2") + # NOTE: can't sub vb=0 here in case va * vb appears in the expression, + # in which case both expr1 and expr2 would be zero! + expr1 = sympy_subs(index_formulas[k], {va: m1 * sizes[a], vb: m2}) + expr2 = sympy_subs(index_formulas[k], {va: 0, vb: (m1 + m2)}) + if self.simplify(expr1) == self.simplify(expr2): + continue + return False + return True + + changed = True + while changed: + changed = False + for i, j in itertools.product( + reversed(range(len(sizes))), reversed(range(len(sizes))) + ): + if i == j or sizes[i] is None or sizes[j] is None: + continue + if can_merge_dims(i, j): + changed = True + sizes[i] = sizes[i] * sizes[j] + sizes[j] = None + + def reindex(index): + it = list(reversed(index)) + new_index = [] + for size in sizes: + if size is None: + new_index.append(sympy.S.Zero) + else: + new_index.append(it.pop()) + assert not it + return new_index + + def prune(index): + assert len(index) == len(sizes) + return [i for i, s in zip(index, sizes) if s is not None] + + return [x for x in sizes if x is not None], reindex, prune + + # Note - [On Statically Known] + # The statically_known_* family of functions below NEVER guard, they could return True if the + # asked questions can be answered without guarding otherwise they return False. 
+ # Those are similar to statically_known_true in symbolic_shapes.py but operate on sympy + # expressions instead of symnodes. + def statically_known_true(self, expr: Union[sympy.Basic, bool]) -> bool: + """ + Returns true if an expression is always true (symbolically or via guards), + false otherwise. Never add guards, or throw data dependent errors. + """ + return statically_known_true(self.shape_env, expr) + + def statically_known_equals( + self, left: Union[Expr, int], right: Union[Expr, int] + ) -> bool: + """ + Returns a bool indicating if it is sound to optimize as if left and right are equal. + """ + return self.statically_known_true(sympy.Eq(left, right)) # type: ignore[arg-type] + + def statically_known_list_equals( + self, left: Sequence[Expr], right: Sequence[Expr] + ) -> bool: + """ + Returns a bool indicating if it is sound to optimize as if left and right lists are equal. + """ + return len(left) == len(right) and all( + self.statically_known_equals(l, r) for l, r in zip(left, right) + ) + + def statically_known_leq(self, left: Expr, right: Union[Expr, int]) -> bool: + """ + Returns a bool indicating if it is sound to optimize as if left is less than or equal to right. + """ + expr = left <= right + return self.statically_known_true(expr) + + def statically_known_geq(self, left: Expr, right: Union[Expr, int]) -> bool: + """ + Returns a bool indicating if it is sound to optimize as if left is greater than or equal to right. + """ + expr = left >= right + return self.statically_known_true(expr) + + def statically_known_lt(self, left: Expr, right: Union[Expr, int]) -> bool: + """ + Returns a bool indicating if it is sound to optimize as if left is less than right. + """ + expr = left < right + return self.statically_known_true(expr) + + def statically_known_gt(self, left: Expr, right: Union[Expr, int]) -> bool: + """ + Returns a bool indicating if it is sound to optimize as if left is greater than right. 
+ """ + expr = left > right + return self.statically_known_true(expr) + + def statically_known_multiple_of( + self, numerator: Expr, denominator: Union[Expr, int] + ) -> bool: + """ + Return a bool indicating if it is sound to optimize for the numerator being a multiple of the denominator. + """ + # The reason we skip compute here is to avoid the cost of trying to eval this symbolically. + # see https://github.com/sympy/sympy/issues/28200 + if has_free_unbacked_symbols(numerator) or has_free_unbacked_symbols( + denominator + ): + return False + + if len(free_symbols(numerator)) > 20: + return False + + expr = sympy.Eq(numerator % denominator, 0) + return self.statically_known_true(expr) # type: ignore[arg-type] + + def statically_known_power_of_2(self, expr: Expr) -> bool: + """ + Returns a bool indicating if x is known to be a power of 2. + """ + return isinstance(expr, sympy.Integer) and is_power_of_2(int(expr)) + + # The expect/check functions require you to ALREADY KNOW that a particular + # condition holds. They are similar to expect_true in symbolic_shapes.py and + # torch.check but operates on sympy expressions instead of symnodes. + def expect_true(self, expr: Expr) -> bool: + """ + Use it when you already know that expr is true or should be true and want to + ensure that guards/runtime assertions are in place to ensure this in compiled + function. Unlike check, this WON'T raise an error if expr isn't actually true. + check Note [expect_true]. + """ + if not self.statically_known_true(expr): + return self.shape_env.guard_or_defer_runtime_assert( + expr, "sizevars.expect_true" + ) + return True + + def check(self, expr: Expr) -> None: + """ + Use it when you already know that expr is true or should be true and want to + ensure that guards/runtime assertions are in place to ensure this in compiled + function. Unlike expect_true, this WILL raise an error if expr isn't actually true. + check Note [expect_true]. 
+ """ + expr = sympy_subs(expr, self.inv_precomputed_replacements) + assert self.expect_true(expr) + + def check_equals(self, left: Expr, right: Expr) -> None: + """ + check(sympy.Eq(left, right)). + + """ + self.check(sympy.Eq(left, right)) + return left + + def check_equals_and_simplify(self, left: Expr, right: Expr) -> Expr: + """ + check(sympy.Eq(left, right)) and returns left after applying + inv_precomputed_replacements. + """ + self.check(sympy.Eq(left, right)) + return sympy_subs(left, self.inv_precomputed_replacements) + + def check_leq(self, left: Expr, right: Expr) -> None: + self.check(sympy.Le(left, right)) + + def check_lt(self, left: Expr, right: Expr) -> None: + self.check(sympy.Lt(left, right)) + + # Similar to the functions guard_or_false/guard_or_true in symbolic_shapes.py + # but operates on sympy expressions instead of symnodes. see Note [guard_or_]. + def guard_or_false(self, left): + import torch.fx.experimental._config as exp_config + + if exp_config.backed_size_oblivious: + static_val = self.shape_env._maybe_evaluate_static(left) + if static_val is not None: + return static_val + return False + return self.evaluate_expr(left, fallback_value=False) + + def guard_or_true(self, left): + import torch.fx.experimental._config as exp_config + + if exp_config.backed_size_oblivious: + static_val = self.shape_env._maybe_evaluate_static(left) + if static_val is not None: + return static_val + return True + return self.evaluate_expr(left, fallback_value=True) + + # The evaluate functions evaluate some symbolic sympy expression + # (NB: not necessarily an Expr) and return what the concrete result + # is, guarding on the expression being that result + + # NB: write evaluate_expr(sympy.Lt(a, b)) rather than evaluate_expr(a < b) + # as this will ensure that you actually have a sympy'ified expression, + # and will prevent you from incorrectly writing evaluate_expr(a == b) + # which does the wrong thing if a or b is a sympy expression + def evaluate_expr( + 
self, + left: Union[Expr, sympy.logic.boolalg.Boolean], + size_oblivious: bool = False, + fallback_value: Optional[bool] = None, + ) -> bool: + assert isinstance(left, (Expr, sympy.logic.boolalg.Boolean)), type(left) + return self.shape_env.evaluate_expr( + sympy.sympify(left), + size_oblivious=size_oblivious, + fallback_value=fallback_value, + ) + + def is_size_one_or_false(self, size: Expr) -> bool: + """Return True if size equals 1. + + Unbacked symbolic sizes return False without introducing a guard. + """ + return self.guard_or_false(sympy.Eq(size, 1)) + + def evaluate_min(self, left: Expr, right: Expr) -> Expr: + """return the smaller of left and right, and guard on that choice""" + if isinstance(left, Expr): + left = sympy_subs(left, self.inv_precomputed_replacements) # type: ignore[arg-type] + if isinstance(right, Expr): + right = sympy_subs(right, self.inv_precomputed_replacements) # type: ignore[arg-type] + try: + lv = self.size_hint_or_throw(left) + rv = self.size_hint_or_throw(right) + except TypeError: # unbacked symints + if left == right or self.statically_known_leq(left, right): + return left + if self.statically_known_leq(right, left): + return right + gcd = sympy.gcd(left, right) + if left == gcd: # handle `min(10*u0, u0)` etc + return left + if right == gcd: + return right + raise TypeError( + f"evaluate_min({left}, {right}) with unbacked symints" + ) from None + if lv <= rv: + self.check_leq(left, right) + return left + else: + self.check_leq(right, left) + return right + + def evaluate_max(self, left: Expr, right: Expr) -> Expr: + """return the larger of left and right, and guard on that choice""" + # Always choose the opposite of eval min for consistency + # This means min(a, b) and max(a, b) produce the same guards + min_val = self.evaluate_min(left, right) + return right if min_val is left else left + + def guard_int(self, expr: Union[Expr, int]) -> int: + """ + Similar to guard_int in symbolic_shapes.py, except this function works with 
SymPy + expressions instead of SymNodes. It extracts the value represented by expr from shapeEnv + and specialize the compiled graph on it. Raises an error if the result cannot be + determined due to unhinted or unbacked symbols. + """ + if isinstance(expr, int): + return expr + val = self.size_hint_or_throw(expr) + self.check_equals(expr, sympy.Integer(val)) + return int(val) + + def guard_int_seq(self, left: Sequence[Union[Expr, int]]) -> list[int]: + """ + Apply guard_int on a sequence of inputs. + """ + return [self.guard_int(x) for x in left] + + def remove_precomputed_replacements(self, expr: Expr) -> Expr: + if any(symbol_is_type(s, SymT.PRECOMPUTED_SIZE) for s in expr.free_symbols): # type: ignore[attr-defined] + return sympy_subs(expr, self.inv_precomputed_replacements) # type: ignore[arg-type] + return expr + + def symbolic_hint( + self, + expr: Union[Expr, int], + hint_override: Optional[int] = None, + # Only flip this flag if you don't plan on guarding/adding runtime + # asserts based on this value and promise to only use this value + # in a heuristic nature. 
+ use_user_provided_hint_override: bool = False, + ) -> Union[Expr, int]: + if isinstance(expr, int): + return expr + # Substitute all hints into expr, but leave unbacked symints alone + expr = self.simplify(expr) + if not isinstance(expr, Expr): + assert isinstance(expr, int) + return expr + free_symbols = expr.free_symbols + if not free_symbols: + try: + return int(expr) # type: ignore[return-value] + except TypeError: + return expr # inf/nan/I + + if hint_override: + return hint_override + + expr = self.remove_precomputed_replacements(expr) + + if use_user_provided_hint_override: + expr = sympy_subs(expr, self.var_to_hint_override) + + return sympy_subs(expr, self.var_to_val) + + def size_hint( + self, + expr: Union[Expr, int], + *, + fallback: Optional[int] = None, + hint_override: Optional[int] = None, + ) -> int: + out = self.symbolic_hint( + expr, + hint_override=hint_override, + use_user_provided_hint_override=fallback is not None, + ) + if not isinstance(out, (int, sympy.Integer)) and fallback is not None: + # Use the provided heuristic fallback hint + unbacked_sym_vrs = { + s: self.shape_env.var_to_range.get(s, None) for s in out.free_symbols + } + if all(vr is not None for vr in unbacked_sym_vrs.values()): + hint_vr = bound_sympy(out, unbacked_sym_vrs) # type: ignore[arg-type] + if isinstance(hint_vr.lower, (int, sympy.Integer)): + fallback = max(fallback, int(hint_vr.lower)) + if isinstance(hint_vr.upper, (int, sympy.Integer)): + fallback = min(fallback, int(hint_vr.upper)) + return fallback + + try: + return int(out) + except Exception: + log.debug("failed on: %s", out) + raise + + def size_hint_or_throw(self, expr: Union[Expr, int]) -> int: + # Like size_hint but there's no fallback for unbacked symints, so it throws. 
+ out = self.symbolic_hint(expr) + try: + return int(out) + except Exception: + log.debug("failed on: %s", out, exc_info=True) + raise + + def size_hints( + self, + exprs: Iterable[Union[Expr, int]], + *, + fallback: Optional[int] = None, + hint_override: Optional[int] = None, + ) -> tuple[int, ...]: + return tuple( + self.size_hint( + x, + fallback=fallback, + hint_override=hint_override, + ) + for x in exprs + ) + + def size_hints_or_throw( + self, + exprs: Iterable[Union[Expr, int]], + ) -> tuple[int, ...]: + # Like size_hints but there's no fallback for unbacked symints, so it throws. + return tuple(self.size_hint_or_throw(x) for x in exprs) + + def _lru_cache(self, fn, maxsize=None): + """ + Wrapper around functools.lru_cache that clears when replacements + has been invalidated. + """ + fn_cache = functools.lru_cache(maxsize)(fn) + prior_len = len(self.replacements) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + nonlocal prior_len + if prior_len != len(self.replacements): + prior_len = len(self.replacements) + fn_cache.cache_clear() + return fn_cache(*args, **kwargs) + + return wrapper + + def make_stride_vars_cache(self): + cache = self._lru_cache(self._stride_vars) + + def stride_vars( + index: Expr, + vars: Sequence[sympy.Symbol], + support_vars: Optional[Sequence[sympy.Symbol]] = None, + ) -> list[Expr]: + if not support_vars: + support_vars = vars + return cache(index, tuple(vars), tuple(support_vars)) + + return stride_vars + + def _stride_vars( + self, + index: Expr, + vars: Sequence[sympy.Symbol], + support_vars: Sequence[sympy.Symbol], + ) -> list[Expr]: + """Convert an indexing expression back into strides + + NOTE: This is only valid if the index is a standard strided offset + calculation. e.g. 
10 * ModularIndexing(i0 + 1, 1, 2) would give a + stride of -10 because the index wraps around after the first element + + """ + strides = [] + index = self.simplify(index) + # remove any offset + index = index - sympy_subs( + index, {v: sympy.S.Zero for v in support_vars if v != 0} + ) + for i in range(len(vars)): + # drop all the other dims + index_dim = sympy_subs( + index, + { + support_vars[j]: sympy.S.Zero + for j in range(len(support_vars)) + if vars[i] != support_vars[j] and support_vars[j] != 0 + }, + ) + v = vars[i] + if v == 0: + strides.append(sympy.S.Zero) + else: + # TODO(jansel): should we use sympy.diff here? + strides.append( + sympy_subs(index_dim, {v: sympy.S.One}) + - sympy_subs(index_dim, {v: sympy.S.Zero}) + ) + return strides + + def _get_unbacked_replacements(self) -> dict[Expr, Expr]: + if self.unbacked_replacements is not None: + return self.unbacked_replacements + + class CanonicalExprFinder: + """ + Purpose: + A disjoint-set/union-find data structure that can return the + "canonical" expression for a group of equivalent expressions. + - The canonical expression must come from the input eq_graph. + - The heuristics used to choose a leader determines which + expression becomes the canonical expression. + + Problem: + Given any unbacked expression, we should be able to find a size_hint + for the unbacked expression, that adheres to the ShapeEnv's deferred + runtime assertions. Otherwise, we may generate conflicting size hints. + In other words, even though we know u0 + s0 == u2, we may generate + size hints, such that, size_hint(u0 + s0) != size_hint(u2). + NOTE: At this time, only deferred runtime asserts that are equalities + (i.e. Eq(lhs, rhs)) are considered in this data structure. 
+ + Examples: + - u0 + u1 == 9000, then find_expr(u0 + u1) == find_expr(9000) + - u0 + u1 == s9, then find_expr(u0 + u1) == find_expr(s9) + - u0 + s0 == u10, then find_expr(u0 + s0) == find_expr(u10) + + Inputs: + - equality_graph: An adjacency set of expressions where the edge + connects two expressions that are found equal to each other. The + edges are sourced from ShapeEnv's deferred_runtime_asserts. + + Usage: + - Call union_expr(a, b) to merge a & b into a single set which + shares the same canonical expression. + - Call find_expr(x) to find the canonical expression for x. + """ + + def __init__(self, eq_graph: dict[Expr, OrderedSet[Expr]]): + self.eq_graph = eq_graph + self.expressions = list(eq_graph.keys()) + self.reverse_expressions = { + expr: i for i, expr in enumerate(self.expressions) + } + # Each node is its own leader/parent initially + self.leader = list(range(len(self.expressions))) + # Track rank for union-by-rank + self.rank = [1] * len(self.expressions) + + # Takes each edge from the undirected graph and starts merging them. 
+ self._build_canonical_expr_mapping() + + def _build_canonical_expr_mapping(self): + for expr, edges in self.eq_graph.items(): + for adj in edges: + self.union_expr(expr, adj) + + def union_expr(self, a: Expr, b: Expr): + return self.union( + self.reverse_expressions[a], self.reverse_expressions[b] + ) + + def union(self, a: int, b: int): + rootA = self.find(a) + rootB = self.find(b) + if rootA == rootB: + return False # already connected + leader, other = self.choose_leader(rootA, rootB) + self.leader[other] = leader + self.rank[leader] += self.rank[other] + return True + + def find_expr(self, expr: Expr): + parent = self.find(self.reverse_expressions[expr]) + return self.expressions[parent] + + def find(self, x: int): + # Path compression + if self.leader[x] != x: + self.leader[x] = self.find(self.leader[x]) + return self.leader[x] + + def choose_leader(self, a: int, b: int): + """ + The leader will become the canonical expression. + + Here are the heuristics used for choosing a leader: + 1. Backed expression or constants preferred over unbacked expr + 2. Simpler sub-expr when one contains the other + 3. Higher frequency across equalities from deferred runtime assertions + 4. Rank/size of the set + 5. Fallback to sympy.Basic.compare + """ + + def _choose(x: int, y: int) -> bool: + lhs, rhs = self.expressions[x], self.expressions[y] + + # Prefer replacing unbacked exprs with backed expressions/constants. + # Examples: + # u0 + s3 ==> s0 + s1, then leader is s0 + s1 + # u2 ==> 300, then leader is 300 + any_unbacked_lhs = has_free_unbacked_symbols(lhs) + any_unbacked_rhs = has_free_unbacked_symbols(rhs) + if any_unbacked_lhs != any_unbacked_rhs: + return bool(any_unbacked_rhs) + + # Handles cases where LHS contains the RHS. In other words, + # RHS is a sub-expression of LHS. For example: + # s1 * Max(2, u0) ==> Max(2, u0), then leader is Max(2, u0) + if lhs.has(rhs): + return False + elif rhs.has(lhs): + return True + + # Prefer expressions that come up more often. 
+ degrees_lhs = len(self.eq_graph[lhs]) + degrees_rhs = len(self.eq_graph[rhs]) + if degrees_lhs != degrees_rhs: + return degrees_lhs > degrees_rhs + + # Try to apply union-by-rank optimization to flatten the + # leader trees. + if self.rank[x] != self.rank[y]: + return self.rank[x] > self.rank[y] + + # Fallback to sympy.Basic.compare for a deterministic ordering. + return lhs.compare(rhs) == -1 + + if _choose(a, b): + return a, b + return b, a + + # Build an undirected graph using ShapeEnv's deferred runtime assertions. + self.equality_graph: dict[Expr, OrderedSet[Expr]] = defaultdict(OrderedSet) + for assertions in self.shape_env.deferred_runtime_asserts.values(): + for assertion in assertions: + if not isinstance(assertion.expr, sympy.Equality): + # We're ignoring other relationals for now. If you need to + # account for relationals, then you may need a solver solution. + continue + lhs = sympy.sympify(assertion.expr.lhs) # sympify helps with ints + rhs = sympy.sympify(assertion.expr.rhs) + self.equality_graph[lhs].add(rhs) + self.equality_graph[rhs].add(lhs) + + # Use the undirected graph to create a DSU data structure, so we can + # query for a "canonical" expression. + uf = CanonicalExprFinder(self.equality_graph) + + # Start building the unbacked replacements mapping using CanonicalExprFinder + # The mapping is from Expr to its "canonical" Expr. 
+ self.unbacked_replacements = {} + for expr in self.equality_graph: + canonical_expr = uf.find_expr(expr) + if expr != canonical_expr: + self.unbacked_replacements[expr] = canonical_expr + + return self.unbacked_replacements + + @functools.lru_cache # noqa: B019 + def _sub_unbacked_exprs(self, expr: Expr) -> Expr: + # it's fine to cache this fn since self is a singleton + replacements = self._get_unbacked_replacements() + + # consider making this threshold configurable + sub_cnt_limit = 30 + sub_cnt = 0 + while sub_cnt < sub_cnt_limit: + new_expr = expr.subs(replacements) + if new_expr == expr: + return new_expr + expr = sympy.factor(new_expr) + sub_cnt += 1 + + log.warning("Substitution limit (%d) reached w/ %s", sub_cnt_limit, expr) + return expr + + def atomically_apply_size_hint( + self, + expr: Union[Expr, int], + *, + fallback: Optional[int] = None, + hint_override: Optional[int] = None, + ) -> Union[Expr, int]: + if isinstance(expr, (int, sympy.Integer)): + return int(expr) + + if has_free_unbacked_symbols(expr): + # Make sure to substitute with the factored version + # e.g. 10*(s0 + u0) instead of 10*s0 + 10*u0 + expr = self._sub_unbacked_exprs(sympy.factor(expr)) + + # For multiple expressions that depend on an unbacked symint, + # we want to compute them consistently for a size hint we have chosen. + # So, recursively compute expressions via size hints of contained symbols. 
+ # For example: u1 * u2 - 10 ==> fallback * fallback - 10 + assert isinstance(expr, Expr), type(expr) + free_symbols = expr.free_symbols + size_dict = { + symbol: V.graph.sizevars.size_hint( + symbol, fallback=fallback, hint_override=hint_override + ) + for symbol in free_symbols + } + return expr.subs(size_dict) + + def offset_var(self, index: Expr, vars: Sequence[sympy.Symbol]) -> Expr: + """Extract offset part of an indexing expression""" + index = self.simplify(index) + return sympy_subs(index, {v: sympy.S.Zero for v in vars if v != 0}) + + def stride_hints( + self, + index: Expr, + vars: Sequence[sympy.Symbol], + support_vars: Optional[Sequence[sympy.Symbol]] = None, + ) -> list[int]: + for v in index.free_symbols: + if symbol_is_type(v, SymT.INDIRECT): # type: ignore[attr-defined] + index = sympy_subs(index, {v: 0}) # type: ignore[dict-item] + result = [] + for s in self.stride_vars(index, vars, support_vars): + try: + result.append(self.size_hint_or_throw(s)) + except TypeError: + result.append(0) + return result + + def stride_order(self, index: Expr, vars: list[sympy.Symbol]) -> list[int]: + strides = tuple(map(abs, self.stride_hints(index, vars))) + order = list(range(len(strides))) + order.sort(key=lambda x: (strides[x] == 0, strides[x])) + return order + + def lookup_precomputed_size(self, expr: Expr) -> Expr: + if ( + isinstance(expr, (int, sympy.Symbol, sympy.Number)) + or expr.is_number + or expr.is_symbol + ): + return expr + expr = self.remove_precomputed_replacements(expr) + if expr not in self.precomputed_replacements: + sym = sympy_index_symbol_with_prefix( + SymT.PRECOMPUTED_SIZE, len(self.precomputed_replacements) + ) + self.precomputed_replacements[expr] = sym + self.inv_precomputed_replacements[sym] = expr + return self.precomputed_replacements[expr] + + def free_symbols(self) -> OrderedSet[sympy.Symbol]: + return OrderedSet(self.var_to_val.keys()) - OrderedSet(self.replacements.keys()) + + def combine_modular_indexing_pairs(self, index: 
sympy.Expr) -> sympy.Expr: + """ + A pair of special ModularIndexing can be combined. + + E.g. ModularIndexing(ModularIndexing(x, 1, a), 1, b) + We can simplify this to ModuleIndexing(x, 1, b), if + 1. x is non negative integer + 2. a and b are positive integers + 3. a is a multiple of b. + """ + + def _check_args(x, div, mod, is_first): + if not isinstance(div, sympy.Integer) or not isinstance(mod, sympy.Integer): + return False + if div != 1: + return False + if mod <= 0: + return False + + if is_first: + # first ModularIndexing should contains a nested ModularIndex + if not isinstance(x, ModularIndexing): + return False + else: + # second ModularIndexing should contains a non-negative + # symbol + if not isinstance(x, sympy.Symbol) or not self.statically_known_geq( + x, 0 + ): + return False + return True + + if isinstance(index, ModularIndexing): + x, div, mod = index.args + + if not _check_args(x, div, mod, True): + return index + + x2, div2, mod2 = x.args + + if not _check_args(x2, div2, mod2, False): + return index + + if mod2 % mod != 0: + return index + + return ModularIndexing(x2, 1, mod) + + return index + + def expand_floor_div( + self, index: sympy.Expr + ) -> Union[bool, tuple[sympy.Expr, sympy.Expr]]: + """ + Expand the FloorDiv to the entire expression so that the expression may + be simplified. + + E.g., for a 2D contiguous tensor with shape [a, 2 * b], and index variables + x1, x2, index expression 'x1 * 2b + x2' can be easily combined. + But index expression 'x1 * b + x2 // 2' can not. + By expanding the FloorDiv to the entire expression, we get + '(x1 * 2b + x2) // 2'. This transformation allows us to merge loops + for the numerator! + + Return false if this optimization can be applied; + Return the new expression and the denominator otherwise. 
+ The original expression will be equivalent to 'new_expression // denominator' + """ + if not isinstance(index, sympy.Add): + return False + terms = index.args + + if len(terms) < 2: + return False + floor_div_index = -1 + varlist = [] + factorlist = [] + for idx, term in enumerate(terms): + if isinstance(term, sympy.Mul): + # For dynamic shape, term like '2*s1*x1' has 3 child nodes. + # - A integer for 2 + # - A symbol for s1 + # - A symbol for x1 + # Skip for now. + if len(term.args) != 2: + return False + factor, var = term.args + varlist.append(var) + factorlist.append(factor) + if not isinstance(factor, sympy.Integer) or not isinstance( + var, sympy.Symbol + ): + return False + # It's easier to reason about the correceness of the transformation + # for non-negative integers. + if not self.statically_known_geq(var, 0): + return False + elif isinstance(term, FloorDiv): + var, factor = term.args + if not isinstance(factor, sympy.Integer) or not isinstance( + var, sympy.Symbol + ): + return False + if not self.statically_known_geq(var, 0): + return False + if floor_div_index >= 0: + # can not handle multi FloorDiv yet + return False + + floor_div_index = idx + varlist.append(var) + # this factor is denominator + factorlist.append(factor) + else: + return False + + if floor_div_index < 0: + return False + + # Construct the new expression and remember the denominator + denominator = factorlist[floor_div_index] + new_index = sympy.S.Zero + + for var, factor, idx in zip(varlist, factorlist, itertools.count()): + if idx == floor_div_index: + new_index += var + else: + new_index += (factor * denominator) * var + + return new_index, denominator + + +def join_dimensions(expr: Expr) -> Expr: + if not isinstance(expr, sympy.Add) or not expr.has(ModularIndexing): + return expr # fast exit path + return _join_dimensions_cached(expr) + + +@functools.lru_cache(256) +def _join_dimensions_cached(expr: Expr) -> Expr: + """ + ModularIndexing(i0, 1, 32) + 32 * ModularIndexing(i0, 
32, 4) + becomes + ModularIndexing(i0, 1, 128) + ModularIndexing(i0, 1, 32) + 32 * FloorDiv(i0, 32) + becomes i0 + + + This type of pattern can come from view operations + """ + assert isinstance(expr, sympy.Add) + + scale = sympy.Wild("scale", exclude=[0], integer=True) + base = sympy.Wild("base", integer=True) + divisor = sympy.Wild("divisor", integer=True) + mod1 = sympy.Wild("modulus", integer=True) + mod2 = sympy.Wild("modulus2", integer=True) + for term1 in expr.args: + m1 = term1.match(scale * ModularIndexing(base, divisor, mod1)) + if m1: + for term2 in expr.args: + m2 = term2.match( + m1[scale] + * m1[mod1] + * ModularIndexing(m1[base], m1[divisor] * m1[mod1], mod2) + ) + if m2 and term1 != term2: + expr = join_dimensions( + expr + - term1 + - term2 + + m1[scale] + * ModularIndexing(m1[base], m1[divisor], m1[mod1] * m2[mod2]) + ) + return expr + for term1 in expr.args: + m1 = term1.match(scale * ModularIndexing(base, divisor, mod1)) + if m1: + for term2 in expr.args: + m2 = term2.match( + m1[scale] * m1[mod1] * FloorDiv(m1[base], m1[divisor] * m1[mod1]) + ) + if m2 is not None: # in case of success we get an empty dict here + expr = join_dimensions( + expr + - term1 + - term2 + + m1[scale] * FloorDiv(m1[base], m1[divisor]) + ) + return expr + return expr + + +class SimplifyIndexing(V.WrapperHandler): # type: ignore[name-defined] + """ + A wrapper around .virtualize.ops that uses var range information to + simplify ModularIndexing/FloorDiv. 
+ """ + + def __init__(self, inner, var_ranges: VarRanges) -> None: + super().__init__(inner) + self.name = "SimplifyIndexing" + self._simplify: Callable[[Expr], Expr] = ( + lambda index: V.graph.sizevars.simplify_with_ranges(index, var_ranges) + ) + + def load(self, name: str, index: sympy.Expr): + return self._inner.load(name, self._simplify(index)) + + def store(self, name, index, value, mode=None): + return self._inner.store(name, self._simplify(index), value, mode=mode) + + def store_reduction(self, name, index, value): + return self._inner.store_reduction(name, self._simplify(index), value) + + def index_expr(self, index, dtype): + return self._inner.index_expr(self._simplify(index), dtype) + + def check_bounds(self, index, size, lower, upper): + return self._inner.check_bounds(self._simplify(index), size, lower, upper) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/standalone_compile.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/standalone_compile.py new file mode 100644 index 0000000000000000000000000000000000000000..d07c5e704321355c7dd9bc5dec2d192b60dc96d9 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/standalone_compile.py @@ -0,0 +1,440 @@ +from __future__ import annotations + +import copy +import logging +import os +import pickle +import shutil +from abc import ABC, abstractmethod +from contextlib import AbstractContextManager, nullcontext +from typing import Any, Literal, Optional, TYPE_CHECKING + +import torch.fx +from torch._dynamo.aot_compile_types import BundledAOTAutogradSerializableCallable +from torch._dynamo.utils import dynamo_timed +from torch._inductor.cpp_builder import normalize_path_separator +from torch._inductor.cudagraph_utils import BoxedDeviceIndex +from torch._inductor.runtime.cache_dir_utils import temporary_cache_dir +from torch._inductor.utils import BoxedBool, InputType +from torch._subclasses import 
FakeTensorMode +from torch.fx.experimental.symbolic_shapes import ShapeEnv + +from . import config + + +if TYPE_CHECKING: + from collections.abc import Callable, Sequence + + from torch.compiler._cache import CacheInfo + from torch.fx import GraphModule + + +log = logging.getLogger(__name__) + + +class CompiledArtifact(ABC): + """ + CompiledArtifact class represents the inductor cache artifacts that + can be invoked in order to avoid repeated compilation. + + CompiledArtifact can be obtained by calling standalone_compile(gm, example_inputs) + to create a fresh CompiledArtifact from a GraphModule and example inputs. + + Later this CompiledArtifact can be saved to disk, either as a binary or unpacked + into the provided folder via the CompiledArtifact.save function. + + CompiledArtifact.load provides a way to create a CompiledArtifact from the + binary or unpacked data. + + Finally, the CompiledArtifact can be invoked via the __call__ method + to execute the cached artifact. + """ + + def __init__( + self, + compiled_fn: Callable[..., Any], + artifacts: Optional[tuple[bytes, CacheInfo]], + ): + self._compiled_fn = compiled_fn + self._artifacts = artifacts + + @abstractmethod + def __call__(self, *args: Any) -> Any: ... + + @abstractmethod + def save( + self, *, path: str, format: Literal["binary", "unpacked"] = "binary" + ) -> None: ... 
+ + @staticmethod + def load( + *, path: str, format: Literal["binary", "unpacked"] = "binary" + ) -> CompiledArtifact: + if format == "unpacked": + # If format is unpacked, it must be a CacheCompiledArtifact + return CacheCompiledArtifact.load(path=path, format=format) + + assert format == "binary" + with open(path, "rb") as file: + from torch.utils._appending_byte_serializer import BytesReader + + from .codecache import torch_key + + result_bytes = file.read() + reader = BytesReader(result_bytes) + header = reader.read_bytes() + if header == AOTCompiledArtifact.AOT_HEADER: + assert reader.read_bytes() == torch_key() + artifact = reader.read_bytes() + assert reader.is_finished() + return AOTCompiledArtifact.deserialize(artifact) + # Otherwise, it's in the CacheCompiledArtifact format + elif header == CacheCompiledArtifact.CACHE_HEADER: + assert reader.read_bytes() == torch_key() + key = reader.read_str() + artifact_bytes = reader.read_bytes() + assert reader.is_finished() + torch.compiler.load_cache_artifacts(artifact_bytes) + return CacheCompiledArtifact._load_impl(nullcontext(), key) + else: + raise RuntimeError( + "Invalid header, expected CacheCompiledArtifact or AOTCompiledArtifact, got: " + + header.decode("utf-8") + ) + + +class CacheCompiledArtifact(CompiledArtifact): + """ + CompiledArtifact that depends on torch.compiler.save_cache_artifacts + """ + + CACHE_HEADER = bytes("CacheCompiledArtifact", "utf-8") + + def __init__( + self, + compiled_fn: Callable[..., Any], + artifacts: Optional[tuple[bytes, CacheInfo]], + ): + self._compiled_fn = compiled_fn + self._artifacts = artifacts + + def __call__(self, *args: Any) -> Any: + return self._compiled_fn(*args) + + def save( + self, *, path: str, format: Literal["binary", "unpacked"] = "binary" + ) -> None: + with dynamo_timed("CompiledArtifact.save"): + if self._artifacts is None: + raise RuntimeError( + "CompiledArtifact.save failed to save since there's no artifact to save" + ) + artifact_bytes, cache_info 
= self._artifacts + assert len(cache_info.aot_autograd_artifacts) == 1, cache_info + key = cache_info.aot_autograd_artifacts[0] + + if format == "binary": + # can't assert that it is a file since it might not exist yet + assert not os.path.isdir(path) + + from torch.utils._appending_byte_serializer import BytesWriter + + from .codecache import torch_key + + writer = BytesWriter() + writer.write_bytes(CacheCompiledArtifact.CACHE_HEADER) + writer.write_bytes(torch_key()) + writer.write_str(key) + writer.write_bytes(artifact_bytes) + + from torch._inductor.codecache import write_atomic + + write_atomic(path, writer.to_bytes()) + else: + assert format == "unpacked" + if os.path.exists(path): + assert os.path.isdir(path) + shutil.rmtree(path, ignore_errors=True) + + from .codecache import FxGraphCache + + with temporary_cache_dir(path): + # This function unpacks the cache artifacts to disk + loaded_cache_info = torch.compiler.load_cache_artifacts( + artifact_bytes + ) + assert loaded_cache_info is not None + # Now write all the output_code artifacts to disk so that + # they can be inspected and modified + for key in loaded_cache_info.inductor_artifacts: + subdir = FxGraphCache._get_tmp_dir_for_key(key) + assert os.path.exists(subdir) + for path in sorted(os.listdir(subdir)): + with open(os.path.join(subdir, path), "rb") as f: + graph = pickle.load(f) + output_file = graph.write_to_disk() + log.info("Output code written to: %s", output_file) + + @staticmethod + def _load_impl( + cache_dir_ctx: AbstractContextManager[Any], key: str + ) -> CompiledArtifact: + with ( + cache_dir_ctx, + config.patch(unsafe_skip_cache_dynamic_shape_guards=True), + ): + with torch._functorch.config.patch(strict_autograd_cache=True): + from torch._functorch._aot_autograd.autograd_cache import ( + AOTAutogradCache, + ) + + result = AOTAutogradCache._lookup( + key, + local=True, + remote=False, + args=[], + cache_info={}, + aot_config=None, + ) + + assert result is not None + (entry, _) = result 
+ + from .compile_fx import _CompileFxKwargs + + fx_config = _CompileFxKwargs( + cudagraphs=BoxedBool(False), + boxed_forward_device_index=BoxedDeviceIndex(0), + ) + + context = torch._guards.TracingContext(FakeTensorMode(shape_env=ShapeEnv())) + with torch._guards.tracing(context): + compiled_fn = entry.wrap_post_compile( + [], entry.sanitized_aot_config, fx_config + ) + return CacheCompiledArtifact(lambda *args: compiled_fn(list(args)), None) + + @staticmethod + def _prepare_load( + *, path: str, format: Literal["binary", "unpacked"] = "binary" + ) -> tuple[str, AbstractContextManager[Any]]: + """ + Do format specific prep and loads, return a context manager and key + """ + path = normalize_path_separator(path) + with dynamo_timed("CompiledArtifact.load"): + if format == "binary": + # can't assert that it is a file since it might not exist yet + assert not os.path.isdir(path) + with open(path, "rb") as file: + artifacts = file.read() + from torch.utils._appending_byte_serializer import BytesReader + + from .codecache import torch_key + + reader = BytesReader(artifacts) + assert reader.read_bytes() == torch_key() + key = reader.read_str() + artifact_bytes = reader.read_bytes() + assert reader.is_finished() + + torch.compiler.load_cache_artifacts(artifact_bytes) + return key, nullcontext() + else: + assert format == "unpacked" + assert os.path.isdir(path) + autograd_cache_dir = os.path.join(path, "aotautograd") + assert os.path.isdir(autograd_cache_dir) + files = list(os.listdir(autograd_cache_dir)) + assert len(files) == 1 + key = files[0] + cache_dir_ctx = temporary_cache_dir(path) + return key, cache_dir_ctx + + @staticmethod + def load( + *, path: str, format: Literal["binary", "unpacked"] = "binary" + ) -> CompiledArtifact: + key, cache_dir_ctx = CacheCompiledArtifact._prepare_load( + path=path, format=format + ) + return CacheCompiledArtifact._load_impl(cache_dir_ctx, key) + + +class AOTCompiledArtifact(CompiledArtifact): + """ + Similar to CompiledArtifact, 
but the object is a single, bundled precompiled function. + This object is always a serializable callable function. + + This object is essentially a wrapper for BundledAOTAutogradSerializableCallable, which + is used by torch._dynamo.aot_compile for AOT Precompilation. + """ + + AOT_HEADER = bytes("AOTCompiledArtifact", "utf-8") + + def __init__( + self, + compiled_fn: Callable[..., Any], + ): + self.inner_fn = BundledAOTAutogradSerializableCallable(compiled_fn) + self._artifacts = ( + None # We don't need artifacts, the inner object handles everything + ) + + @staticmethod + def from_bundled_callable( + bundled_fn: BundledAOTAutogradSerializableCallable, + ) -> AOTCompiledArtifact: + return AOTCompiledArtifact(bundled_fn.compiled_fn) + + def __call__(self, *args: Any) -> Any: + return self.inner_fn(*args) + + def save( + self, *, path: str, format: Literal["binary", "unpacked"] = "binary" + ) -> None: + if format == "unpacked": + raise RuntimeError( + "AOTCompiledArtifact does not support unpacked format yet" + ) + result_bytes = self.serialize() + from torch.utils._appending_byte_serializer import BytesWriter + + from .codecache import torch_key + + writer = BytesWriter() + writer.write_bytes(AOTCompiledArtifact.AOT_HEADER) + writer.write_bytes(torch_key()) + writer.write_bytes(result_bytes) + + from torch._inductor.codecache import write_atomic + + # Save a sentinel file to indicate that this is AOT + write_atomic(path, writer.to_bytes()) + + def serialize(self) -> bytes: + return BundledAOTAutogradSerializableCallable.serialize_compile_artifacts( + self.inner_fn + ) + + @staticmethod + def deserialize(result_bytes: bytes) -> AOTCompiledArtifact: + deserialized = ( + BundledAOTAutogradSerializableCallable.deserialize_compile_artifacts( + result_bytes + ) + ) + assert isinstance(deserialized, BundledAOTAutogradSerializableCallable) + return AOTCompiledArtifact.from_bundled_callable(deserialized) + + @staticmethod + def load( + *, path: str, format: 
Literal["binary", "unpacked"] = "binary" + ) -> CompiledArtifact: + if format == "unpacked": + raise RuntimeError( + "AOTCompiledArtifact does not support unpacked format yet" + ) + with open(path, "rb") as file: + from torch.utils._appending_byte_serializer import BytesReader + + from .codecache import torch_key + + result_bytes = file.read() + reader = BytesReader(result_bytes) + header = reader.read_bytes() + assert header == AOTCompiledArtifact.AOT_HEADER + assert reader.read_bytes() == torch_key() + artifact = reader.read_bytes() + assert reader.is_finished() + return AOTCompiledArtifact.deserialize(artifact) + + +def standalone_compile( + gm: GraphModule, + example_inputs: Sequence[InputType], + *, + dynamic_shapes: Any, + options: Any, + aot: bool = False, # AOT mode, which uses BundledAOTAutogradCache +) -> CompiledArtifact: + """ + Implementation of torch.inductor.standalone_compile + """ + from torch.compiler._cache import CacheArtifactManager + + from .compile_fx import compile_fx + + ignore_shape_env = False + if dynamic_shapes == "from_example_inputs": + fake_mode = FakeTensorMode(shape_env=ShapeEnv()) + # tells compile_fx to ignore the shape_envs on the ambient context + # and the graph_module. + ignore_shape_env = True + elif dynamic_shapes == "from_tracing_context": + # Reuse fake_mode from the TracingContext. + # NB: The TracingContext only exists if we're currently in a torch.compile backend. + context = torch._guards.TracingContext.get() + assert context.fake_mode is not None + fake_mode = context.fake_mode + elif dynamic_shapes == "from_graph": + fake_mode = FakeTensorMode(shape_env=ShapeEnv()) + # Strategy: find a FakeTensor in the graph output, grab its FakeTensorMode. + # The graph passed to standalone_compile must be an Inductor-approved graph, + # which means that there is at least one Tensor output and the output node + # contains a flat list of Tensors. 
+ last_node = next(iter(reversed(gm.graph.nodes))) + assert last_node.op == "output" + assert len(last_node.args) == 1 + + def handle_node(node: torch.fx.Node) -> None: + nonlocal fake_mode + if "example_value" in node.meta: + maybe_tensor = node.meta["example_value"] + if isinstance(maybe_tensor, torch._subclasses.fake_tensor.FakeTensor): + fake_mode = maybe_tensor.fake_mode + + # If gm came from Dynamo, then last_node.args[0] is always a list, + # even in single-Tensor returns. + # + # It's possible to get into a situation where last_node.args[0] + # is a Node (and not a list!). This happens if you call split_module + # on the graph. We allow for this case since it is common. + if isinstance(last_node.args[0], torch.fx.Node): + handle_node(last_node.args[0]) + else: + for node in last_node.args[0]: + handle_node(node) + + else: + raise ValueError( + f"standalone_compile got unsupported `dynamic_shapes` value: dynamic_shapes={dynamic_shapes}." + ) + + context = torch._guards.TracingContext(fake_mode) + with ( + torch._guards.tracing(context), + CacheArtifactManager.with_fresh_cache(), + config.patch("triton.autotune_at_compile_time", True), + torch._functorch.config.patch("bundled_autograd_cache", aot), + ): + # compile_fx can mutate gm + gm = copy.deepcopy(gm) + compiled_fn = compile_fx( + gm, example_inputs, ignore_shape_env=ignore_shape_env, **options + ) + assert callable(compiled_fn) + if aot: + if not hasattr(compiled_fn, "serialize"): + raise RuntimeError( + "Compiled function should have serialize method when aot=True" + ) + return AOTCompiledArtifact(compiled_fn) + artifacts = torch.compiler.save_cache_artifacts() + if artifacts is None: + log.warning( + "standalone_compile artifact generation failed, cannot save. 
" + "Run with TORCH_LOGS=+torch._inductor.codecache to identify the problem" + ) + + return CacheCompiledArtifact(compiled_fn, artifacts) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/subgraph_lowering.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/subgraph_lowering.py new file mode 100644 index 0000000000000000000000000000000000000000..aa1b4d2db025da0ea8344e3185d4c65d7fda2aab --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/subgraph_lowering.py @@ -0,0 +1,208 @@ +"""Utilities for lowering subgraphs used by higher order operators""" + +import functools +import operator +from collections.abc import Callable, Generator +from contextlib import contextmanager +from dataclasses import dataclass +from typing import Any, Optional, TypeVar, Union +from typing_extensions import ParamSpec + +import torch +from torch.utils._ordered_set import OrderedSet + +from . import ir +from .exc import SubgraphLoweringException +from .graph import GraphLowering +from .ops_handler import SimpleCSEHandler +from .virtualized import ops, V, WrapperHandler + + +T = TypeVar("T") +_P = ParamSpec("_P") + +OpOverload = torch._ops.OpOverload +LoweringDict = dict[Union[OpOverload, str], Callable[..., Any]] +TargetType = Union[Callable[..., Any], str] + + +class PointwiseSubgraphLowering(torch.fx.Interpreter): + """ + Lowers a pointwise subgraph to a single set of buffers with a separate + lowering object. 
Errors if buffers are created unexpectedly + """ + + graph_outputs: Optional[list[ir.IRNode]] + root_graph: GraphLowering + _current_op: Optional[TargetType] + # For backwards of buffer_grads with scatters we allow mutations + allowed_mutations: Optional[OrderedSet[OpOverload]] + additional_lowerings: Optional[LoweringDict] + buffers: list[ir.Buffer] + mutated_buffers: OrderedSet[str] + + def __init__( + self, + gm: torch.fx.GraphModule, + root_graph_lowering: GraphLowering, + allowed_mutations: Optional[OrderedSet[OpOverload]] = None, + additional_lowerings: Optional[LoweringDict] = None, + ) -> None: + super().__init__(gm) + self.graph_outputs = None + self.root_graph = root_graph_lowering + self.allowed_mutations = allowed_mutations + self.additional_lowerings = additional_lowerings + self._current_op = None + + # Used to track buffers created during lowering + self.mutated_buffers = OrderedSet() + self.buffers = [] + + @contextmanager + def _op_context(self, op: TargetType) -> Generator[None, None, None]: + """Set which op is being processed in call function to know if we can mutate buffers""" + previous = self._current_op + self._current_op = op + try: + yield + finally: + self._current_op = previous + + def _approved_mutator(self) -> bool: + return ( + self.allowed_mutations is not None + and self._current_op in self.allowed_mutations + ) + + def mark_buffer_mutated(self, name: str) -> None: + if self._approved_mutator(): + self.mutated_buffers.add(name) + else: + raise SubgraphLoweringException( + f"Buffer mutation detected during lowering of {self._current_op}. " + "Buffer mutations are only allowed in approved mutation ops. " + "This is an error in the lowering of the subgraph, please file a bug report." 
+ ) + + def register_buffer(self, buffer: ir.Buffer, *, set_name: bool = False) -> str: + if self._approved_mutator(): + name = self.root_graph.register_buffer(buffer, set_name=set_name) + return name + else: + raise SubgraphLoweringException( + "Buffers cannot be created while lowering a pointwise subgraph. " + "This could be for a good reason (e.g. you're calling an op we can't codegen as a pointwise op), " + "but it could also be a bug. Please file a bug report if you think this should be supportable." + ) + + def __getattr__(self, name: str) -> Any: + return getattr(self.root_graph, name) + + def call_function( + self, + target: TargetType, + args: Any, + kwargs: dict[str, Any], + ) -> Any: + from .lowering import lowerings + + with self._op_context(target): + if target is operator.getitem and isinstance(args[0], (list, tuple, dict)): + return super().call_function(target, args, kwargs) + + # These takes precedence over the main lowerings + if self.additional_lowerings is not None: + if target in self.additional_lowerings: + assert isinstance(target, OpOverload) + return self.additional_lowerings[target](*args, **kwargs) + + if target not in lowerings: + raise SubgraphLoweringException( + f"{target} not supported in subgraph, (missing lowering)" + ) + return lowerings[target](*args, **kwargs) + + def output(self, target: str, args: tuple[Any], kwargs: dict[str, Any]) -> None: # type: ignore[override] + assert len(args) == 1 + self.graph_outputs = args[0] + + +@dataclass +class InputDescriptor: + dtype: torch.dtype + device: torch.device + + +class TracingOpsHandler(WrapperHandler): + def __init__(self, tracer: torch.fx.Tracer, num_inputs: int) -> None: + parent = tracer.create_proxy("placeholder", "ops", (), {}) + super().__init__(parent) + self.tracer = tracer + + self.placeholders = [ + self.tracer.create_proxy("placeholder", f"input{i}", (), {}) + for i in range(num_inputs) + ] + + def placeholder(self, idx: int) -> torch.fx.Proxy: + return 
self.placeholders[idx] + + def output(self, *args: tuple[object]) -> None: + self.tracer.create_node( + "output", "output", (tuple(self.tracer.create_arg(a) for a in args),), {} + ) + + +def lower_pointwise_subgraph( + subgraph: ir.Subgraph, inputs: list[InputDescriptor] +) -> Callable[_P, Any]: + # Lower subgraph to ir.Pointwise nodes + def fake_inner_fn( + loop_idx: int, input_idx: int + ) -> Union[ir.Expr, ir.TensorBox, None]: + return ops.placeholder(input_idx) + + graph_inputs = [ + ir.Pointwise.create( + device=desc.device, + dtype=desc.dtype, + inner_fn=functools.partial(fake_inner_fn, input_idx=i), + ranges=[], + ) + for i, desc in enumerate(inputs) + ] + gm = subgraph.graph_module + pw_subgraph = PointwiseSubgraphLowering(gm, root_graph_lowering=V.graph) + with V.set_graph_handler(pw_subgraph): # type: ignore[arg-type] + pw_subgraph.run(*graph_inputs) + + # Combine multiple pointwise computations into a single graph module + # Do this by tracing through each individually and doing CSE + tracer = torch.fx.Tracer() + tracer.graph = torch.fx.Graph(tracer_cls=tracer.__class__) + trace_ops = SimpleCSEHandler(TracingOpsHandler(tracer, len(inputs))) + assert pw_subgraph.graph_outputs is not None + + with V.set_ops_handler(trace_ops): + output_irs = [] + + for out_var in pw_subgraph.graph_outputs: + assert isinstance(out_var, ir.TensorBox), type(out_var) + assert out_var.get_size() == [] + assert isinstance(out_var.data, ir.StorageBox) + assert isinstance(out_var.data.data, ir.Pointwise) + + idx = () + ir_out = out_var.data.data.inner_fn(idx) + + output_irs.append(ir_out) + + ops.output(*output_irs) + + lowered_gm = torch.fx.GraphModule({}, tracer.graph) + + def inner_fn(*args: _P.args, **kwargs: _P.kwargs) -> Any: + return lowered_gm(V.get_ops_handler(), *args, **kwargs) + + return inner_fn diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/test_case.py 
b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/test_case.py new file mode 100644 index 0000000000000000000000000000000000000000..efdef48884cefebdad8c3a5dda07848fff1b9675 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/test_case.py @@ -0,0 +1,51 @@ +import contextlib +import os +from typing import Union + +from torch._dynamo.test_case import ( + run_tests as dynamo_run_tests, + TestCase as DynamoTestCase, +) +from torch._functorch import config as functorch_config +from torch._inductor import config +from torch._inductor.utils import fresh_cache + + +def run_tests(needs: Union[str, tuple[str, ...]] = ()) -> None: + dynamo_run_tests(needs) + + +class TestCase(DynamoTestCase): + """ + A base TestCase for inductor tests. Enables FX graph caching and isolates + the cache directory for each test. + """ + + def setUp(self) -> None: + super().setUp() + self._inductor_test_stack = contextlib.ExitStack() + self._inductor_test_stack.enter_context( + functorch_config.patch( + { + "enable_autograd_cache": True, + } + ) + ) + + if ( + "TORCHINDUCTOR_FX_GRAPH_CACHE" not in os.environ + and "TORCHINDUCTOR_FX_GRAPH_CACHE_DEFAULT" not in os.environ + ): + self._inductor_test_stack.enter_context( + config.patch({"fx_graph_cache": True}) + ) + + if ( + os.environ.get("INDUCTOR_TEST_DISABLE_FRESH_CACHE") != "1" + and os.environ.get("TORCH_COMPILE_DEBUG") != "1" + ): + self._inductor_test_stack.enter_context(fresh_cache()) + + def tearDown(self) -> None: + super().tearDown() + self._inductor_test_stack.close() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/test_operators.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/test_operators.py new file mode 100644 index 0000000000000000000000000000000000000000..bbdcf89d0ef866bec01114953bce4dfb379628e9 --- /dev/null +++ 
b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/test_operators.py @@ -0,0 +1,29 @@ +from typing import Any + +import torch.library +from torch import Tensor +from torch.autograd import Function + + +_test_lib_def = torch.library.Library("_inductor_test", "DEF") +_test_lib_def.define("realize(Tensor self) -> Tensor", tags=torch.Tag.pt2_compliant_tag) + +_test_lib_impl = torch.library.Library("_inductor_test", "IMPL") +for dispatch_key in ("CPU", "CUDA", "MPS", "Meta"): + _test_lib_impl.impl("realize", lambda x: x.clone(), dispatch_key) + + +class Realize(Function): + @staticmethod + # pyrefly: ignore [bad-override] + def forward(ctx: object, x: Tensor) -> Tensor: + return torch.ops._inductor_test.realize(x) + + @staticmethod + # types need to stay consistent with _SingleLevelFunction + def backward(ctx: Any, *grad_output: Any) -> Any: + return grad_output[0] + + +def realize(x: Tensor) -> Tensor: + return Realize.apply(x) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/tiling_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/tiling_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..89ad329abd70b0dae8b3e13f21bf8869e1580482 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/tiling_utils.py @@ -0,0 +1,814 @@ +import dataclasses +import itertools +from collections import Counter, defaultdict +from collections.abc import Callable +from typing import Literal, Optional, overload, TYPE_CHECKING, TypeVar, Union + +import sympy + +import torch +from torch._inductor import config +from torch._inductor.dependencies import index_vars_no_squeeze +from torch._inductor.utils import sympy_product, sympy_subs +from torch.utils._ordered_set import OrderedSet +from torch.utils._sympy.functions import Identity +from torch.utils._sympy.solve import try_solve +from torch.utils._sympy.symbol import symbol_is_type, 
SymT + +from .virtualized import V + + +T = TypeVar("T") +U = TypeVar("U") + + +Split = tuple[sympy.Expr, ...] +VarsAndRanges = tuple[list[sympy.Symbol], list[sympy.Expr]] + + +loop_tiling_log = torch._logging.getArtifactLogger(__name__, "loop_tiling") +from torch.utils._sympy.functions import FloorDiv, ModularIndexing + + +if TYPE_CHECKING: + from torch._inductor.scheduler import FusedSchedulerNode, SchedulerNode + + +def solve_for_zero(expr: sympy.Expr) -> Optional[sympy.Expr]: + """ + Given an expr with a single free symbol, solve for a constant relation that would make + this expression 0. + """ + if expr.is_constant(): + return None + elif isinstance(expr, FloorDiv): + return None + + assert len(expr.free_symbols) == 1 + free_symbol = next(iter(expr.free_symbols)) + if isinstance(expr, ModularIndexing): + out = try_solve(sympy.Eq(expr.args[0], expr.args[2]), free_symbol) + else: + out = try_solve(sympy.Eq(expr, 0), free_symbol) + if not out or not out[1].is_constant(): + return None + return out[1] + + +def solve_for_tiling(expr: sympy.Expr) -> Optional[sympy.Expr]: + """ + Giving an expr with a single free symbol, try to find a tiling that would + make the expression coalesced with respect to that symbol. + + Tiling an expression `x` by `y` means that the expression will now be indexed + by both the original (x) and by (x * y). So we are looking for a + multiplicative factor that will make ((x + 1) * y) - (x * y) == 1. + + To simplify things for sympy, we'll try just x * y == 1, check x(1) and x(0). 
+ """ + + if len(expr.free_symbols) == 0: + return None + + free_symbol = next(iter(expr.free_symbols)) + + def _solve_simple_expr(expr: sympy.Expr) -> Optional[sympy.Expr]: + assert not expr.has(ModularIndexing) and not expr.has(FloorDiv) + if len(expr.free_symbols) != 1: + return None + + out = try_solve(sympy.Eq(expr, 1), free_symbol) + if not out or not out[1].is_constant(): + return None + return out[1] + + # Sympy solving is very limited with ModularIndexing and FloorDiv, + # but good otherwise. + if not expr.has(ModularIndexing) and not expr.has(FloorDiv): + return _solve_simple_expr(expr) + + required_values = [] + eq_1_expressions = [] + + # very piecemeal solution if ModularIndexing or FloorDiv involved. + # Look for terms we'll try to make 0, and then other terms we'll try to make 1. + # Expand as needed. + for arg in sympy.Add.make_args(expr): + # Try to make mul terms 0 + if isinstance(arg, sympy.Mul): + seen = False + # TODO - only need one of these to be solvable to zero + # + for mul_arg in arg.args: + out = solve_for_zero(mul_arg) + if out is None: + continue + + assert out.is_constant() + seen = True + required_values.append(out) + + if not seen: + return None + else: + eq_1_expressions.append(arg) + + if not eq_1_expressions: + return None + + eq_1_expr = sum(eq_1_expressions) + + def indexing_div_rep( + x: sympy.Expr, + y: sympy.Expr, + z: Optional[sympy.Expr] = None, + ) -> sympy.Expr: + return x / y + + # For the purposes of tiling/coalesced access, approximate ModularIndexing and FloorDiv + # then check later + # pyrefly: ignore [missing-attribute] + eq_1_expr_simplified = eq_1_expr.replace(ModularIndexing, indexing_div_rep).replace( + FloorDiv, indexing_div_rep + ) + + out = _solve_simple_expr(eq_1_expr_simplified) + # since we approximated FloorDiv/ModularIndexing, double check here + if not out or sympy_subs(eq_1_expr, {free_symbol: out}) != 1: + return None + + required_values.append(out) + + if len(OrderedSet(required_values)) == 1: + 
return required_values[0] + + return None + + +def find_broadcast_var( + index: sympy.Expr, var_ranges: dict[sympy.Expr, int] +) -> Optional[sympy.Expr]: + """ + Try to find the variable that this index is broadcast over. + A broadcast pattern is one where consecutive values of a variable + access the same memory location (e.g., x // 10). + """ + # Approximate analysis by evaluating at 1 and 0 + variables: dict[sympy.Symbol, int] = {} + for v in index.free_symbols: + if v in var_ranges: + variables[v] = 0 + else: + variables[v] = get_hint(v) + + zero_index = sympy_subs(index, variables) + for v in var_ranges: + if v not in index.free_symbols: + continue + + variables[v] = 1 + try: + new_val = sympy_subs(index, variables) + except ZeroDivisionError: + loop_tiling_log.info("zero division error %s %s", index, variables) + continue + # Broadcast means the value doesn't change when the variable increments + if new_val == zero_index: + return v + variables[v] = 0 + + return None + + +def find_coalesced_var( + index: sympy.Expr, var_ranges: dict[sympy.Expr, int] +) -> Optional[sympy.Expr]: + """ + Try to find the symbol which coalesces this index + """ + top_level_terms = sympy.Add.make_args(index) + for v in var_ranges: + if v in top_level_terms: + return v + + # Approximate analysis by evaluating at 1 and 0 + variables: dict[sympy.Symbol, int] = {} + for v in index.free_symbols: + if v in var_ranges: + variables[v] = 0 + else: + variables[v] = get_hint(v) + + zero_index = sympy_subs(index, variables) + for v in var_ranges: + variables[v] = 1 + try: + new_val = sympy_subs(index, variables) + except ZeroDivisionError: + loop_tiling_log.info("zero division error %s %s", index, variables) + continue + if new_val - zero_index == 1: + variables[v] = 2 + # in some more complex expressions, 0->1 will be coalesced, + # but not 1->2 + if (sympy_subs(index, variables) - new_val) == 1: + return v + variables[v] = 0 + + return None + + +@dataclasses.dataclass(frozen=True) +class 
FusedNormalizedReadsWrites: + """ + Normalized reads and writes for nodes in the same FusedSchedulerNode. + """ + + index_vars: OrderedSet[sympy.Symbol] + reduce_vars: OrderedSet[sympy.Symbol] + reads: dict[sympy.Expr, OrderedSet[str]] + writes: dict[sympy.Expr, OrderedSet[str]] + var_ranges: dict[sympy.Symbol, int] + + +@overload +def get_pw_red_splits( + n: "SchedulerNode", + pointwise_numel: sympy.Expr, + red_numel: sympy.Expr, + none_if_not_divisible: Literal[True], +) -> Optional[tuple[VarsAndRanges, VarsAndRanges]]: ... + + +@overload +def get_pw_red_splits( + n: "SchedulerNode", + pointwise_numel: sympy.Expr, + red_numel: sympy.Expr, + none_if_not_divisible: Literal[False] = False, +) -> tuple[VarsAndRanges, VarsAndRanges]: ... + + +def get_pw_red_splits( + n: "SchedulerNode", + pointwise_numel: sympy.Expr, + red_numel: sympy.Expr, + none_if_not_divisible: bool = False, +) -> Optional[tuple[VarsAndRanges, VarsAndRanges]]: + if n.is_reduction() or sympy_product(n._body.sizes[0]) == pointwise_numel: + return ( + (n._body.iter_vars, n._body.sizes[0]), + (n._body.reduce_vars, n._body.sizes[1]), + ) # type: ignore[return-value] + + assert sympy_product(n._body.sizes[0]) == pointwise_numel * red_numel # type: ignore[operator] + i = len(n._body.sizes[0]) - 1 + prod = 1 + while i >= 0: + prod *= n._body.sizes[0][i] + if prod == red_numel: + break + i -= 1 + + if i >= 0: + pw_splits = n._body.sizes[0][0:i] + iter_vars = n._body.iter_vars[0:i] + + red_splits = n._body.sizes[0][i:] + red_vars = n._body.iter_vars[i:] + return (iter_vars, pw_splits), (red_vars, red_splits) # type: ignore[return-value] + + if none_if_not_divisible: + return None + else: + return ( + (n._body.iter_vars, n._body.sizes[0]), + (n._body.reduce_vars, n._body.sizes[1]), + ) # type: ignore[return-value] + + +class NodeSplitGetter: + """ + Finds a Pointwise, Reduction Split that compatible with all nodes in a SchedulerNode. 
+ """ + + def __init__( + self, + node: Union["FusedSchedulerNode", "SchedulerNode"], + ): + self.node = node + self.pointwise_numel: sympy.Expr = node.group[1][0] + self.red_numel: sympy.Expr = node.group[1][1] + + self.pw_split_options: dict[int, OrderedSet[Split]] = defaultdict(OrderedSet) + + self.reduction_split: Split = () + self.all_node_sizes: OrderedSet[tuple[Split, Split]] = OrderedSet() + + fused_group = node.group[1] + for n in reversed(node.get_nodes()): + if not isinstance(n, torch._inductor.scheduler.SchedulerNode): + continue + + # if we can't split the pw ranges into a (pw, red) split, + # dont add as a split option, but do make sure we check that this size + # is splittable + maybe_splits = get_pw_red_splits( + n, self.pointwise_numel, self.red_numel, none_if_not_divisible=True + ) + if maybe_splits is None: + self.all_node_sizes.add(n._body.sizes) + continue + + (_, n_pw_splits), (_, n_red_splits) = maybe_splits + + # fill in reduction size + n_pw_splits, n_red_splits = ( + torch._inductor.codegen.simd.SIMDKernel.prepare_split_iteration_lengths( + fused_group, (n_pw_splits, n_red_splits), self.red_numel + ) + ) + + self.pw_split_options[len(n_pw_splits)].add(tuple(n_pw_splits)) + + # initially, we are just going to do a single reduction split since + # reduction tiling is off by default. even if we miss a reduction split, + # we can recover it in the split var analysis. + # TODO: an earlier version for this code tried to iteratively try the maximum number + # of split vars, by iterating over both pointwise and reduction. but not worth + # the complexity yet. 
+ + if n_red_splits != (): + self.reduction_split = (sympy_product(n_red_splits),) + + n_size = (tuple(n_pw_splits), tuple(n_red_splits)) + self.all_node_sizes.add(n_size) + + self.seen_pw_splits: OrderedSet[Split] = OrderedSet() + + def get_node_splits(self) -> tuple[Split, Split]: + """ + Get a compatible pointwise, reduction split of the node + """ + + if len(self.all_node_sizes) == 1: + return next(iter(self.all_node_sizes)) + + max_pw_split = max(self.pw_split_options.keys()) + for pw_split_len in range(max_pw_split, 0, -1): + for pw_split in self.pw_split_options[pw_split_len]: + if out := self.try_split(pw_split, self.reduction_split): + return out + + # combine dims for next round + for pw_split in self.pw_split_options[pw_split_len]: + for i in range(len(pw_split) - 1): + new_split = tuple( + pw_split[0:i] + + (sympy_product(pw_split[i : i + 2]),) + + pw_split[i + 2 :] + ) + self.pw_split_options[len(new_split)].add(new_split) + + # if for whatever reason we couldn't split above, return default split + return ((self.pointwise_numel,), (self.red_numel,)) + + def try_split(self, pw: Split, red: Split) -> Optional[tuple[Split, Split]]: + """ + See if this split is compatible, and potentially returning a longer split + than the input. + """ + + from torch._inductor.codegen.simd import CantSplit, SIMDKernel + + if pw in self.seen_pw_splits: + return None + self.seen_pw_splits.add(pw) + + for n_pw, n_red in self.all_node_sizes: + try: + groups = pw + red + lengths = (n_pw, n_red) + splits, getters = SIMDKernel._split_iteration_ranges(groups, lengths) + except CantSplit: + return None + + assert len(getters) == 2 + pw_group_splits = splits[: len(pw)] + # if we had to divide a variable into two to do this split, + # then lets try the larger, induced split. + # e.g. 
splitting (12, 2) into (2, 12) will split the first var into: + # (2, 6) and produce an overall split of (2, 6, 2) + flattened_pw_splits = tuple(itertools.chain.from_iterable(pw_group_splits)) + if flattened_pw_splits != pw: + if out := self.try_split(flattened_pw_splits, red): + return out + + return pw, red + + +def apply_var_mapping( + iter_vars: list[sympy.Symbol], + red_vars: list[sympy.Symbol], + norm_pw_vars: list[sympy.Symbol], + norm_red_vars: list[sympy.Symbol], + new_ranges: list[list[sympy.Expr]], + return_getters_groups: list[list[Callable[[list[sympy.Expr]], sympy.Expr]]], +) -> dict[sympy.Symbol, sympy.Expr]: + """Maps original variables to expressions using normalized variables.""" + + # the output of split_iteration_range is a new_ranges, return_getters_groups + # new_ranges is a flattened list of ranges corresponding to the new pw and red vars + # for example, taking in pw vars of range (6, 6) to normalized range [36], + # new_ranges would be [[6, 6]] + # There is a return_getter callable for each input iter_var and red_vars. + # if you flatten out all of the ranges, and create a variable for each index, + # then applying the flattening vars to the callables in return_getters_groups + # gives you the mapping from input vars -> flattened vars. + # From there, we can compute the output, normalized variables. 
+ # For instance [6, 6] corresponding to flat vars v0, v1 will be + # v0 + 6 * v1 + + # Create flattened iteration variables + num_vars = sum(len(s) for s in new_ranges) + flat_vars = sympy.symbols(f"v_0:{num_vars}") + count = 0 + + if len(iter_vars) == 0 and len(red_vars) == 0: + return {} + + assert len(new_ranges) == len(norm_pw_vars + norm_red_vars) + apply_groups = [] + for group in return_getters_groups: + apply_groups.append([g(flat_vars) for g in group]) + + iter_vars_to_flat_vars = {} + for i, (group, var_group) in enumerate( + zip(apply_groups, (iter_vars, red_vars), strict=True) + ): + # if the node has sizes (p0, 1) and the fused node is (p0, r0) + # the reduction var gets filled in for split_iteration_range + if len(group) != len(var_group): + assert i == 1 + assert len(var_group) == 0 + continue + + iter_vars_to_flat_vars.update({v: g for g, v in zip(group, var_group)}) + + count = 0 + flat_vars_to_new_vars = {} + for new_range, new_var in zip( + new_ranges, norm_pw_vars + norm_red_vars, strict=True + ): + range_vars = [] + for _ in range(len(new_range)): + range_vars.append(flat_vars[count]) + count += 1 + + prod = 1 + for i in range(len(new_range) - 1, -1, -1): + flat_vars_to_new_vars[range_vars[i]] = new_var * prod + prod = new_range[i] * prod + + return { + k: sympy_subs(v, flat_vars_to_new_vars) + for k, v in iter_vars_to_flat_vars.items() + } + + +def extract_normalized_read_writes( + node: Union["FusedSchedulerNode", "SchedulerNode"], +) -> Optional[FusedNormalizedReadsWrites]: + """Extracts index variables, reduce variables, read/write expressions, and variable ranges from a fused node.""" + reads: dict[sympy.Expr, OrderedSet[str]] = defaultdict(OrderedSet) + writes: dict[sympy.Expr, OrderedSet[str]] = defaultdict(OrderedSet) + + all_output_names = node.get_buffer_names() + op_names = node.get_operation_names() + outputs: OrderedSet[str] = OrderedSet() + removed_buffers: OrderedSet[str] = OrderedSet() + for buf_name in all_output_names: + if 
V.graph.scheduler.can_buffer_be_removed_through_fusion(buf_name, op_names): + removed_buffers.add(buf_name) + else: + outputs.add(buf_name) + + inputs = OrderedSet( + dep.name for dep in node.read_writes.reads if dep.name not in removed_buffers + ) + + pointwise_numel: sympy.Expr = node.group[1][0] + red_numel: sympy.Expr = node.group[1][1] + + # TODO - a few dynamic shapes issues to resolve + if any( + (isinstance(var, sympy.Expr) and not var.is_constant()) + for var in (pointwise_numel, red_numel) + ): + return None + + pw_splits, red_splits = NodeSplitGetter(node).get_node_splits() + + # lets use different prefix (`n`) to distinguish + (norm_pw_vars, norm_red_vars), ranges = index_vars_no_squeeze( + pw_splits, red_splits, prefix="n" + ) + + for n in list(node.get_nodes()): + if not isinstance(n, torch._inductor.scheduler.SchedulerNode): + continue + + body = n._body + + # TODO - not handled well. indirect loads will not be coalesced, + # need to account for that in analysis. + if body.indirect_vars: + return None + + n_reads: dict[sympy.Expr, OrderedSet[str]] = defaultdict(OrderedSet) + n_writes: dict[sympy.Expr, OrderedSet[str]] = defaultdict(OrderedSet) + + # TODO - will the names for all the inputs/outputs accurately + # reflect mutation, or do I need to remap with mutation_real_name + for inp in inputs: + for expr in body.get_all_read_expr(inp): + n_reads[expr].add(inp) + + for out in outputs: + for expr in body.get_all_write_expr(out): + n_writes[expr].add(out) + + if not n_reads and not n_writes: + continue + + (iter_vars, n_pw_splits), (red_vars, n_red_splits) = get_pw_red_splits( + n, pointwise_numel, red_numel + ) + + groups = pw_splits + red_splits + lengths = (n_pw_splits, (n_red_splits)) + lengths = ( + torch._inductor.codegen.simd.SIMDKernel.prepare_split_iteration_lengths( + groups, lengths, red_numel + ) + ) + new_ranges, return_getters_groups = ( + torch._inductor.codegen.simd.SIMDKernel._split_iteration_ranges( + groups, lengths + ) + ) + 
var_map = apply_var_mapping( + iter_vars, + red_vars, + norm_pw_vars, + norm_red_vars, + new_ranges, + return_getters_groups, + ) + + # We create Identity sympy.Functions to prevent expansion to int64, + # unwrap for tiling analysis. + def remove_identity(expr: sympy.Expr) -> sympy.Expr: + return expr.replace(Identity, lambda x: x) + + n_reads_new = { + sympy_subs(remove_identity(read), var_map): v for read, v in n_reads.items() + } + n_writes_new = { + sympy_subs(remove_identity(write), var_map): v + for write, v in n_writes.items() + } + + for expr, buf_names in n_reads_new.items(): + reads[expr] |= buf_names + + for expr, buf_names in n_writes_new.items(): + writes[expr] |= buf_names + + reads = { + V.graph.sizevars.simplify_with_ranges(r, ranges): v for r, v in reads.items() + } + writes = { + V.graph.sizevars.simplify_with_ranges(w, ranges): v for w, v in writes.items() + } + + fused_out = FusedNormalizedReadsWrites( + norm_pw_vars, # type: ignore[arg-type] + norm_red_vars, # type: ignore[arg-type] + reads, + writes, + ranges, + ) + loop_tiling_log.info("Normalized Fused reads: %s", fused_out) + return fused_out + + +def get_score( + addr: sympy.Expr, var_ranges: dict[sympy.Symbol, int], buf_names: OrderedSet[str] +) -> int: + """ + Score addr according to its approximate size. 
+ """ + # TODO - deduplicate with candidate_tilings + var_sizes = [] + for v in addr.free_symbols: + v_size = var_ranges.get(v) + # TODO - reason about indirect vars + if not symbol_is_type(v, SymT.INDIRECT) and v_size is not None: + var_sizes.append(v_size) + from .virtualized import V + + return V.graph.sizevars.atomically_apply_size_hint( + sympy_product(var_sizes), fallback=config.unbacked_symint_fallback + ) + + +def try_get_buf_size(buf_name: str) -> Optional[int]: + buf = V.graph.try_get_buffer(buf_name) + if not buf: + return None + return V.graph.sizevars.atomically_apply_size_hint( + sympy_product(buf.get_size()), fallback=config.unbacked_symint_fallback + ) + + +def get_hint(v: Union[sympy.Expr, int]) -> int: + if isinstance(v, int): + return v + else: + return V.graph.sizevars.size_hint(v, fallback=config.unbacked_symint_fallback) + + +@dataclasses.dataclass(frozen=True) +class VarTiling: + """ + Tiling of a var by `tiling_factor` that yields additional coalesced mem accesses by `benefit_score` + """ + + var: sympy.Symbol + tiling_factor: int + score: int + + +@dataclasses.dataclass(frozen=True) +class CoalesceVarAnalysis: + # Var -> Memory Score - not strictly the amount of memory + # because we multiply writes x2 + # TODO: separate into dataclass that olds mem, dtype, is_write + coalesced_by_var: dict[sympy.Expr, int] + + uncoalesced_addrs: dict[sympy.Expr, int] + + norm_read_writes: FusedNormalizedReadsWrites + + suggested_split: Optional[VarTiling] = None + + +def analyze_memory_coalescing( + fused_node: Union["FusedSchedulerNode", "SchedulerNode"], +) -> Optional[CoalesceVarAnalysis]: + """ + Find variables that coalesce the reads and writes and score the total size. + + If uncoalesced memory expressions are found, look for additionally tiling of variables + which will coalesce memory accesses. + + For instance - for the following expression: + + (32*p0) // 2048 + + Tiling p0 by 64 will make this expression coalesced. 
+ """ + + norm_read_writes = extract_normalized_read_writes(fused_node) + + if norm_read_writes is None: + return None + + reads = norm_read_writes.reads + writes = norm_read_writes.writes + var_ranges = norm_read_writes.var_ranges + + coalesced_by_var: dict[sympy.Symbol, int] = Counter() + uncoalesced_addrs: dict[sympy.Expr, int] = Counter() + + for is_read, (memory_expr, buf_names) in itertools.chain( + ((True, item) for item in reads.items()), + ((False, item) for item in writes.items()), + ): + # skip memory deps with indirect vars - todo: better handling + indirect_expr = bool( + memory_expr.free_symbols - norm_read_writes.var_ranges.keys() + ) + + if indirect_expr: + continue + + size = get_score(memory_expr, var_ranges, buf_names) + + if size == 0: + continue + + maybe_coalesced_var = find_coalesced_var(memory_expr, var_ranges) + # while broadcasting vars are not technically coalesced, + # accesses at least stay in cache, so they provide most of the benefit. + # treat the same for now. + if maybe_coalesced_var is None: + maybe_coalesced_var = find_broadcast_var(memory_expr, var_ranges) + + total_score = 0 + for buf_name in buf_names: + if (buf := V.graph.try_get_buffer(buf_name)) and ( + buf_size := try_get_buf_size(buf_name) + ): + # constrain by buf size since we'll read at most that many elements + # score could be more through either masking or by broadcasting (e.g. 
x // 16) + total_score += min(buf_size, size) * buf.dtype.itemsize + + # coalesced writes more important + total_score *= 1 if is_read else 2 + + if maybe_coalesced_var: + coalesced_by_var[maybe_coalesced_var] += total_score + else: + uncoalesced_addrs[memory_expr] += total_score + + if not uncoalesced_addrs: + return CoalesceVarAnalysis( + coalesced_by_var=coalesced_by_var, + uncoalesced_addrs=uncoalesced_addrs, + norm_read_writes=norm_read_writes, + ) + + # map from var -> tiling -> total_score + tiling_scores: dict[sympy.Expr, dict[int, int]] = defaultdict(Counter) + + for uncoalesced_expr, addr_score in uncoalesced_addrs.items(): + expr_subs = dict.fromkeys(uncoalesced_expr.free_symbols, 0) + for v in uncoalesced_expr.free_symbols: + # skip non iter/reduce var variables + if v not in var_ranges: + continue + # skip small addrs + if addr_score == 0: + continue + del expr_subs[v] + single_var_expr = sympy_subs(uncoalesced_expr, expr_subs) + expr_subs[v] = 0 + tiling_factor = solve_for_tiling(single_var_expr) + if ( + tiling_factor is None + or not tiling_factor.is_constant() + or not tiling_factor.is_integer + ): + continue + + tiling_factor = int(tiling_factor) + if not V.graph.sizevars.statically_known_lt(tiling_factor, var_ranges[v]): + continue + + # TODO - if a var is in the middle, such as [n0, n1, n2] + # n1 can can be split beyond range + + MIN_TILING_BLOCK = 8 + if not all( + V.graph.sizevars.statically_known_lt(MIN_TILING_BLOCK, block) + for block in (tiling_factor, var_ranges[v] // tiling_factor) + ): + continue + + tiling_scores[v][tiling_factor] += addr_score + + if len(tiling_scores) == 0: + return CoalesceVarAnalysis( + coalesced_by_var=coalesced_by_var, + uncoalesced_addrs=uncoalesced_addrs, + norm_read_writes=norm_read_writes, + ) + + best_tiling: Optional[tuple[sympy.Expr, int]] = None + best_tiling_score = 0 + + for var, tiling_counter in tiling_scores.items(): + for tile, tile_score in tiling_counter.items(): + if tile_score > 
best_tiling_score: + best_tiling = (var, tile) + best_tiling_score = tile_score + + if best_tiling is None: + return CoalesceVarAnalysis( + coalesced_by_var=coalesced_by_var, + uncoalesced_addrs=uncoalesced_addrs, + norm_read_writes=norm_read_writes, + ) + + # TODO - for strictly pointwise fusions, + # we can consider just swizzling the var if the var we are going to tile + # does not coalesce a significant portion of global reads + # TODO - could also prefer index var splits to reduction, better tested + return CoalesceVarAnalysis( + coalesced_by_var=coalesced_by_var, + uncoalesced_addrs=uncoalesced_addrs, + norm_read_writes=norm_read_writes, + suggested_split=VarTiling(best_tiling[0], best_tiling[1], best_tiling_score), + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/triton_bundler.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/triton_bundler.py new file mode 100644 index 0000000000000000000000000000000000000000..5bf5210a2cf467240327c6fa78ead967f3d89156 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/triton_bundler.py @@ -0,0 +1,404 @@ +import copy +import dataclasses +import logging +import os +import shutil +import uuid +from pathlib import Path +from typing import Optional + +from torch._dynamo.utils import counters, dynamo_timed, set_feature_use +from torch._utils_internal import justknobs_check +from torch.utils._filelock import FileLock + +from .runtime.runtime_utils import triton_cache_dir +from .utils import _IS_WINDOWS, GPU_KERNEL_BIN_EXTS + + +log = logging.getLogger(__name__) + + +@dataclasses.dataclass(frozen=True) +class TritonBundleEntry: + """ + When we have compiled a triton kernel, we take note of that kernel by + its triton generated hash, its device, and where this kernel is located. + This is the minimum information we can use to later retrieve this kernel + from file system. 
+ """ + + kernel_hash: str + device: int + directory: str + + +@dataclasses.dataclass(frozen=True) +class TritonKernelArtifact: + """ + Artifact for an individual kernel converted to bytes. + Bytes could be a cubin, json, ttir, or ttgir. + """ + + filename: str + payload: bytes = dataclasses.field(repr=False) # Do not display binary + + +@dataclasses.dataclass(frozen=True) +class StaticallyLaunchedAutotuner: + """ + Represents a statically compiled CachingAutotuner object that we can + save directly in the cache. A CachingAutotuner is made up of a list of + StaticTritonCompileResults, each of which uses the cubin from a TritonKernelArtifact. + + Statically saved here have their cubin files saved by a corresponding TritonBundleEntry. + """ + + cache_key: str + kernel_name: str + kernel: "CachingAutotuner" # type: ignore[name-defined] # noqa: F821 + + +@dataclasses.dataclass(frozen=True) +class TritonKernelArtifacts: + """ + Collection of artifacts for a particular kernel. + """ + + kernel_hash: str + device: int + artifacts: list[TritonKernelArtifact] + + +@dataclasses.dataclass(frozen=True) +class TritonBundlerMetadata: + """ + Metadata used for instrumentation + """ + + cached_kernel_names: list[str] + statically_launched_kernel_names: list[str] + + +@dataclasses.dataclass(frozen=True) +class TritonBundle: + """ + Serializable bundle to save into FXGraphCache + """ + + kernel_artifacts: list[TritonKernelArtifacts] + static_autotuners: list[StaticallyLaunchedAutotuner] + + +class TritonBundler: + """ + Lightweight Triton Kernel bundler that notes each time we compile a triton + kernel. When collect is called, converts all the previously noted kernels and + their artifacts into a structured bytes blob, and later when write is called + it writes this structured blob back to file system. 
+ + Intended Life cycle: + - TritonBundler.begin_compile is called when we start compiling in Inductor + - TritonBundler.put is called each time a Triton Kernel is compiled + - TritonBundler.collect is called when a cache entry is being generated + - TritonBundler.end_compile is called to indicate bundling is completed, + collect will execute this function as well. + - TritonBundler.read_and_emit is called when a cache entry is read + """ + + _entries: Optional[list[TritonBundleEntry]] = None + _static_autotuners: Optional[list[StaticallyLaunchedAutotuner]] = None + + # __grp__kernel_name.json contains metadata with source code paths + # we use this as sentinel value for search and replace + _REPLACE_BYTES: bytes = b"[REPLACE]" + + @staticmethod + def is_enabled() -> bool: + from torch._inductor import config + + if config.force_disable_caches: + return False + + if (b := config.bundle_triton_into_fx_graph_cache) is not None: + return b + + if not config.is_fbcode(): + return False + + return justknobs_check( + "pytorch/remote_cache:bundle_triton_into_fx_graph_cache_v2" + ) + + @classmethod + def begin_compile(cls) -> None: + """ + Initializes the TritonBundler. + The current TritonBundler bundle is finalized by TritonBundler.collect. + """ + if not TritonBundler.is_enabled(): + return + log.debug("TritonBundler.begin_compile is called") + assert cls._entries is None + cls._entries = [] + cls._static_autotuners = [] + + @classmethod + def end_compile(cls) -> None: + """ + Finalizes the TritonBundler. If collect is not yet called, it + discards the current bundle. + """ + log.debug("TritonBundler.end_compile is called") + cls._entries = None + cls._static_autotuners = None + + @classmethod + def put(cls, kernel_hash: str, device: int) -> None: + """ + Lazily observes that we have seen a Triton kernel compilation. Remembers + it for when collect is later called. 
+ """ + if (entries := cls._entries) is not None: + entries.append( + TritonBundleEntry(kernel_hash, device, triton_cache_dir(device)) + ) + + @classmethod + def put_static_autotuner(cls, key: str, kernel: "CachingAutotuner") -> None: # type: ignore[name-defined] # noqa: F821 + from torch._inductor import config + + assert config.use_static_cuda_launcher + if (entries := cls._static_autotuners) is not None: + # Clear a bunch of unpicklable values and make a copy to save + # for FXGraphCache + old_values = kernel.prepare_for_pickle() + new_kernel = copy.deepcopy(kernel) + new_kernel.prepare_for_caching() + new_kernel._reload_kernel = None + + entries.append( + StaticallyLaunchedAutotuner( + key, + new_kernel.inductor_meta.get("kernel_name", "unknown_kernel"), + new_kernel, + ) + ) + + # Put the values back since we need it to use now + kernel.restore_after_unpickle(old_values) + + @classmethod + def collect_static_autotuners( + cls, + ) -> tuple[list[StaticallyLaunchedAutotuner], list[str]]: + if not cls._static_autotuners: + return [], [] + else: + log.info( + "Saving %d statically launchable CachingAutotuners", + len(cls._static_autotuners), + ) + static_autotuner_names = [i.kernel_name for i in cls._static_autotuners] + counters["inductor"]["triton_bundler_save_static_autotuner"] += 1 + return cls._static_autotuners, static_autotuner_names + + @classmethod + def load_autotuners( + cls, static_autotuners: Optional[list[StaticallyLaunchedAutotuner]] + ) -> list[str]: + """ + Load statically launchable CachingAutotuners into async_compile.CompiledTritonKernels + cache. 
+ """ + if not static_autotuners: + return [] + + from torch._inductor.async_compile import CompiledTritonKernels + from torch._inductor.codecache import StaticAutotunerFuture + + log.info("Loading %d statically launchable autotuners", len(static_autotuners)) + kernel_names = [] + with dynamo_timed("TritonBundler.load_cached_static_autotuners"): + for result in static_autotuners: + try: + # Make sure the cubin path exists and is valid + for compile_result in result.kernel.compile_results: + compile_result.reload_cubin_path() + except RuntimeError: + log.warning( + "Failed to reload cubin file statically launchable autotuner %s", + result.kernel_name, + exc_info=True, + ) + continue + # We make a future instead of returning the kernel here so that + # kernels that are not statically launchable (i.e. cache miss) + # can launch a worker without waiting on the blocking step of + # StaticAutotunerFuture.result(). + CompiledTritonKernels._cache[result.cache_key] = StaticAutotunerFuture( + result.kernel + ) + counters["inductor"]["triton_bundler_load_static_autotuner"] += 1 + kernel_names.append(result.kernel_name) + return kernel_names + + @classmethod + def collect( + cls, + ) -> tuple[TritonBundle, Optional[TritonBundlerMetadata]]: + """ + This is the main function called when a cache write happens. This function + converts all the previously remembered kernels into bundled format so that + it can be written into a cache entry. + This function also finalizes the current bundle. 
+ """ + from torch._inductor import config + + if not TritonBundler.is_enabled(): + cls.end_compile() + set_feature_use("triton_bundling", False) + return TritonBundle([], []), None + set_feature_use("triton_bundling", True) + + with dynamo_timed(key="TritonBundler.collect", log_pt2_compile_event=True): + entries = cls._entries + if entries is not None: + result: list[TritonKernelArtifacts] = [] + kernel_names: list[str] = [] + for entry in entries: + artifacts: list[TritonKernelArtifact] = [] + path = os.path.join(entry.directory, entry.kernel_hash) + if not os.path.exists(path): + continue + for filename in os.listdir(path): + filepath = os.path.join(path, filename) + try: + assert os.path.isfile(filepath) + with open(filepath, "rb") as file: + payload = file.read() + if filepath.endswith(".json"): + # Make sure there's no sentinel value + if TritonBundler._REPLACE_BYTES in payload: + log.warning( + "Bundle contains illegal %s, payload: %s", + TritonBundler._REPLACE_BYTES, + payload, + ) + raise AssertionError( + "Bundle contains illegal bytes" + ) + # Remove the path from payload + payload = payload.replace( + str.encode(path), TritonBundler._REPLACE_BYTES + ) + artifacts.append( + TritonKernelArtifact(filename, payload) + ) + counters["inductor"]["triton_bundler_save_kernel"] += 1 + except Exception: + log.debug("failed to collect triton kernel", exc_info=True) + extension = os.path.splitext(filename)[1] + if extension in GPU_KERNEL_BIN_EXTS.values(): + # Each kernel has bunch of files like .cubin(for cuda), .spv(for xpu), .json, .ttir + # Just append one of them without the extension + kernel_names.append(Path(filename).stem) + if artifacts: + result.append( + TritonKernelArtifacts( + entry.kernel_hash, + entry.device, + artifacts, + ) + ) + if config.use_static_cuda_launcher: + static_autotuners, static_kernel_names = ( + cls.collect_static_autotuners() + ) + else: + static_autotuners = [] + static_kernel_names = [] + cls.end_compile() + return 
TritonBundle(result, static_autotuners), TritonBundlerMetadata( + kernel_names, static_kernel_names + ) + return TritonBundle([], []), None + + @staticmethod + def read_and_emit(bundle: TritonBundle) -> Optional[TritonBundlerMetadata]: + """ + This is the main function called when a cache read happens. This function + converts the bundled format back into individual files and writes them + to the filesystem. + + NOTE: When we are writing to the filesystem, we assume exclusive access + to the target directory. + This means that if the target folder already exists and is non-empty, + we bail out. + Exclusive access means that no other process should be writing to + or reading from the target directory. + """ + from torch._inductor import config + + if not TritonBundler.is_enabled(): + return None + + with dynamo_timed( + key="TritonBundler.read_and_emit", log_pt2_compile_event=True + ): + kernel_names: list[str] = [] + + for artifacts in bundle.kernel_artifacts: + basedir = triton_cache_dir(artifacts.device) + directory = os.path.join(basedir, artifacts.kernel_hash) + + if os.path.exists(directory) and len(os.listdir(directory)) != 0: + # If directory already exists, we bail out and leave + # local disk to take care of caching + log.debug( + "Bailing out TritonBundler.read_and_emit, %s is non empty", + directory, + ) + continue + + Path(basedir).mkdir(parents=True, exist_ok=True) + + # Random ID to avoid any collisions + rnd_id = str(uuid.uuid4()) + tmp_dir = os.path.join(basedir, f"tmp.{rnd_id}") + os.makedirs(tmp_dir) + + for artifact in artifacts.artifacts: + filepath = os.path.join(tmp_dir, artifact.filename) + with open(filepath, "wb") as file: + payload = artifact.payload + if artifact.filename.endswith(".json"): + payload = payload.replace( + TritonBundler._REPLACE_BYTES, str.encode(directory) + ) + file.write(payload) + counters["inductor"]["triton_bundler_read_and_emit_kernel"] += 1 + extension = os.path.splitext(artifact.filename)[1] + if extension in 
GPU_KERNEL_BIN_EXTS.values(): + # Each kernel has bunch of files like .cubin(for cuda), spv(for xpu), .json, .ttir + # Just append one of them without the extension + kernel_names.append(Path(artifact.filename).stem) + + if _IS_WINDOWS: + with FileLock(directory + ".lock"): + if os.path.exists(directory): + shutil.rmtree(directory) + os.replace(tmp_dir, directory) + else: + # Atomic on POSIX systems + try: + os.replace(tmp_dir, directory) + except OSError: + log.warning("Directory %s is not empty - skipping!", tmp_dir) + + if config.use_static_cuda_launcher: + static_kernel_names = TritonBundler.load_autotuners( + bundle.static_autotuners + ) + else: + static_kernel_names = [] + return TritonBundlerMetadata(kernel_names, static_kernel_names) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..85a1d03a04f71a0ed1608cc943774bb1496fcd05 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/utils.py @@ -0,0 +1,4202 @@ +from __future__ import annotations + +import collections +import contextlib +import dataclasses +import enum +import functools +import importlib +import inspect +import io +import itertools +import logging +import math +import operator +import os +import platform +import re +import shutil +import statistics +import sys +import sysconfig +import tempfile +import textwrap +import time +import unittest +from collections.abc import ( + Callable, + Collection, + Generator, + Iterator, + Mapping, + MutableMapping, + MutableSet, +) +from datetime import datetime +from io import StringIO +from typing import ( + Any, + cast, + Concatenate, + Generic, + Literal, + NamedTuple, + Optional, + Protocol, + TYPE_CHECKING, + TypeAlias, + TypeGuard, + TypeVar, + Union, +) +from typing_extensions import dataclass_transform, ParamSpec, 
Self +from unittest import mock + +import sympy + +import torch +import torch.utils._pytree as pytree +from torch._inductor.analysis.device_info import datasheet_tops +from torch._inductor.runtime.hints import DeviceProperties +from torch.fx.passes.regional_inductor import _needs_inductor_compile +from torch.utils._dtype_abbrs import dtype_abbrs +from torch.utils._ordered_set import OrderedSet +from torch.utils._pytree import tree_flatten, tree_map_only + + +if TYPE_CHECKING: + from pathlib import Path + +OPTIMUS_EXCLUDE_POST_GRAD = [ + "activation_quantization_aten_pass", + "inductor_autotune_lookup_table", +] + +from torch.fx.experimental.symbolic_shapes import ( + free_symbols, + free_unbacked_symbols, + IterateExprs, + ShapeEnv, +) + + +if TYPE_CHECKING: + from collections.abc import Iterable, Sequence, ValuesView + + from torch import SymBool, SymFloat, SymInt + from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND + from torch.fx import GraphModule + from torch.fx.node import Node + + from .codegen.common import WorkspaceArg + from .codegen.wrapper import PythonWrapperCodegen + from .dependencies import Dep + from .graph import GraphLowering + from .ir import Buffer, ExternKernel, IRNode, Layout, Operation, ReinterpretView + from .output_code import CompiledFxGraph + from .scheduler import BaseSchedulerNode, SchedulerBuffer + + +GPU_TYPES = ["cuda", "mps", "xpu", "mtia"] +T = TypeVar("T") + + +# defines here before import torch._dynamo is for avoiding circular import +# when get_gpu_type is imported from dynamo +@functools.cache +def get_gpu_type() -> str: + avail_gpus = [x for x in GPU_TYPES if getattr(torch, x).is_available()] + assert len(avail_gpus) <= 1 + gpu_type = "cuda" if len(avail_gpus) == 0 else avail_gpus.pop() + return gpu_type + + +from torch._dynamo.device_interface import get_interface_for_device +from torch._dynamo.utils import detect_fake_mode +from torch.autograd import DeviceType +from torch.autograd.profiler_util import 
EventList +from torch.fx.passes.graph_transform_observer import GraphTransformObserver +from torch.fx.passes.shape_prop import ShapeProp +from torch.utils._sympy.functions import ( + CeilDiv, + CleanDiv, + FloorDiv, + Identity, + ModularIndexing, +) +from torch.utils._sympy.symbol import make_symbol, SymT +from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges + +from . import config +from .runtime.runtime_utils import ceildiv as runtime_ceildiv + + +_IS_WINDOWS = sys.platform == "win32" + +log = logging.getLogger(__name__) +perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints") + + +_T = TypeVar("_T") +VarRanges = dict[sympy.Expr, sympy.Expr] +InputType = Optional[Union[torch.Tensor, int, torch.SymInt]] + +GPU_KERNEL_BIN_EXTS = {"cuda": ".cubin", "xpu": ".spv"} + +GPU_ALIGN_BYTES = 16 +ALIGNMENT = 16 + +TMA_ALIGNMENT = 16 +TMA_DESCRIPTOR_SIZE = 128 + +ALIGN_BYTES = 64 +assert (ALIGN_BYTES & (ALIGN_BYTES - 1)) == 0 and ALIGN_BYTES >= 8, "must be power of 2" + + +def _align(nbytes: int) -> int: + """Round up to the nearest multiple of ALIGN_BYTES""" + return (nbytes + ALIGN_BYTES - 1) & -ALIGN_BYTES + + +def _is_aligned(v: sympy.Expr) -> bool: + """v can be statically proven to be a multiple of ALIGN_BYTES""" + if isinstance(v, (sympy.Add, sympy.Max)): + return all(map(_is_aligned, v.args)) + return isinstance(v, align) or sympy.gcd(v, ALIGN_BYTES) == ALIGN_BYTES + + +class align(sympy.Function): + """Symbolically round up to the nearest multiple of ALIGN_BYTES""" + + nargs = (1,) + is_integer = True + + @classmethod + def eval(cls, value: sympy.Expr) -> Optional[sympy.Expr]: + if isinstance(value, (int, sympy.Integer)): + return _align(int(value)) + if _is_aligned(value): + return value + + +@dataclasses.dataclass(frozen=True) +class GraphPartitionMap: + """ + Mapping from the partition info (e.g., input/output) to the graph info + """ + + # a unique id of graph partition + id: int + + # map partition input/output indices to graph 
input/output indices. None indicates + # a partition input/output is not a graph input/output. + input_index_mapping: list[Optional[int]] + output_index_mapping: list[Optional[int]] + + # name of constants read/written by the graph partition + constant_names: list[str] + + +def fp8_bench(fn: Callable[[], Any], warmup: int = 25, rep: int = 100) -> float: + """ + Returns benchmark results by examining torch profiler events. + This could be more accurate as it doesn't count CPU side overhead. + However, this also requires manually excluding irrelevant event, e.g. + vectorized_elementwise_kernel which is used to fill L2 cache, + various CUDA events, etc, so could also be fragile. + """ + + fn() + torch.cuda.synchronize() + cache = torch.empty(int(256e6 // 4), dtype=torch.float16, device="cuda") + + # Estimate the runtime of the function + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + for _ in range(5): + cache.zero_() + fn() + end_event.record() + torch.cuda.synchronize() + estimate_ms = start_event.elapsed_time(end_event) / 5 + + # compute number of warmup and repeat + n_warmup = max(1, int(warmup / estimate_ms)) + n_repeat = max(1, int(rep / estimate_ms)) + + # Warm-up + for _ in range(n_warmup): + fn() + + start_event = [torch.cuda.Event(enable_timing=True) for _ in range(n_repeat)] + end_event = [torch.cuda.Event(enable_timing=True) for _ in range(n_repeat)] + with torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CUDA, + ] + ) as p: + torch.cuda.synchronize() + for i in range(n_repeat): + cache.zero_() + start_event[i].record() + with torch.cuda.nvtx.range("RunCudaModule"): + fn() + end_event[i].record() + torch.cuda.synchronize() + times = torch.tensor( + [s.elapsed_time(e) for s, e in zip(start_event, end_event)] + ) + + res = torch.mean(times).item() + log.debug("raw events") + log.debug(p.key_averages().table(sort_by="self_device_time_total", row_limit=-1)) 
+ filtered_events = EventList( + [ + event + for event in p.events() + if ( + event.device_type == DeviceType.CUDA + and re.match(r"fused_abs_max_\d", event.name) is not None + ) + ] + ) + if filtered_events: + res -= ( + statistics.mean(event.device_time_total for event in filtered_events) + / 1000.0 + ) + + log.debug("profiling results: %s ms", res) + return res + + +def do_bench_using_profiling( + fn: Callable[[], Any], + warmup: int = 25, + rep: int = 100, + is_vetted_benchmarking: bool = False, +) -> float: + # We did't use decorator may_distort_benchmarking_result directly since that + # requires us to import torch._inductor.runtime.benchmarking into global scope. + # Importing torch._inductor.runtime.benchmarking will cause cuda initialization + # (because of calling torch.cuda.available in global scope) + # which cause failure in vllm when it create child processes. Check log: + # https://gist.github.com/shunting314/c194e147bf981e58df095c14874dd65a + # + # Another way to solve the issue is to just move do_bench_using_profiling + # to torch._inductor.runtime.benchmarking and change all the call site. + # But that's not trivial due to so many call sites in and out of pytorch. + + from torch._inductor.runtime.benchmarking import may_distort_benchmarking_result + + return may_distort_benchmarking_result(_do_bench_using_profiling)( + fn, warmup, rep, is_vetted_benchmarking + ) + + +def _do_bench_using_profiling( + fn: Callable[[], Any], + warmup: int = 25, + rep: int = 100, + is_vetted_benchmarking: bool = False, +) -> float: + """ + Returns benchmark results by examining torch profiler events. + This could be more accurate as it doesn't count CPU side overhead. + However, this also requires manually excluding irrelevant event, e.g. + vectorized_elementwise_kernel which is used to fill L2 cache, + various CUDA events, etc, so could also be fragile. 
+ """ + + if not is_vetted_benchmarking: + from torch._inductor.runtime.benchmarking import may_ban_benchmarking + + may_ban_benchmarking() + + fn() + torch.cuda.synchronize() + cache = torch.empty(int(256e6 // 4), dtype=torch.int, device="cuda") + + # Estimate the runtime of the function + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + for _ in range(5): + cache.zero_() + fn() + end_event.record() + torch.cuda.synchronize() + estimate_ms = start_event.elapsed_time(end_event) / 5 + + # compute number of warmup and repeat + n_warmup = max(1, int(warmup / estimate_ms)) + n_repeat = max(1, int(rep / estimate_ms)) + + # Warm-up + for _ in range(n_warmup): + fn() + + torch.cuda.synchronize() + + with torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CUDA, + ] + ) as p: + # Benchmark + for _ in range(n_repeat): + # we clear the L2 cache before each run + cache.zero_() + # record time of `fn` + fn() + # Record clocks + torch.cuda.synchronize() + + log.debug("raw events") + log.debug(p.key_averages().table(sort_by="self_device_time_total", row_limit=-1)) + + filtered_events = EventList( + [ + event + for event in p.events() + if event.device_type == DeviceType.CUDA and event.name != "Context Sync" + ] + ) + if len(filtered_events) % n_repeat != 0: + raise RuntimeError( + "Failed to divide all profiling events into #repeat groups. 
" + "#CUDA events: %d, #repeats: %s", + len(filtered_events), + n_repeat, + ) + num_event_per_group = len(filtered_events) / n_repeat + actual_events = EventList( + [ + event + for i, event in enumerate(filtered_events) + if i % num_event_per_group != 0 + ] + ) + actual_events._build_tree() + actual_events = actual_events.key_averages() + + log.debug("profiling time breakdown") + log.debug(actual_events.table(row_limit=-1)) + + res = sum(event.device_time_total for event in actual_events) / 1000.0 / n_repeat + log.debug("profiling results: %s ms", res) + return res + + +@functools.cache +def has_torchvision_roi_align() -> bool: + try: + from torchvision.ops import roi_align # noqa: F401 + + torch._C._dispatch_has_kernel_for_dispatch_key("torchvision::nms", "Meta") + return roi_align is not None and hasattr( + getattr(torch.ops, "torchvision", None), "roi_align" + ) + except ImportError: + return False + except RuntimeError as e: + assert "torchvision::nms does not exist" in str(e) + return False + + +def decode_device(device: Union[Optional[torch.device], str]) -> torch.device: + if device is None: + return torch.tensor(0.0).device # default device + if isinstance(device, str): + device = torch.device(device) + if device.type not in ("cpu", "meta") and device.index is None: + device_interface = get_interface_for_device(device.type) + return torch.device(device.type, index=device_interface.Worker.current_device()) + return device + + +def sympy_product(it: Iterable[sympy.Expr]) -> sympy.Expr: + return functools.reduce(operator.mul, it, sympy.S.One) + + +def sympy_dot(seq1: Sequence[sympy.Expr], seq2: Sequence[sympy.Expr]) -> sympy.Expr: + assert len(seq1) == len(seq2) + return sympy.expand(sum(a * b for a, b in zip(seq1, seq2))) + + +def unique(it: Iterable[_T]) -> ValuesView[_T]: + return {id(x): x for x in it}.values() + + +def ceildiv( + number: Union[int, sympy.Expr], denom: Union[int, sympy.Expr] +) -> Union[int, sympy.Expr]: + if isinstance(number, sympy.Expr) 
or isinstance(denom, sympy.Expr): + return CeilDiv(sympy.sympify(number), sympy.sympify(denom)) + # TODO: There is a bug in a call to this function, to repro: + # python benchmarks/dynamo/huggingface.py --inductor -d cuda --accuracy + # --amp --only YituTechConvBert --dynamic-shapes + assert isinstance(number, int) and isinstance(denom, int), ( + f"{number}: {type(number)}, {denom}: {type(denom)}" + ) + return runtime_ceildiv(number, denom) + + +def _type_of(key: Optional[torch.dtype]) -> str: + # Use the function here to get rid of dependencies on the Triton during the codegen. + # Refer to Triton implementation here: + # https://github.com/triton-lang/triton/blob/98b5945d2aef679e00ebca8e07c35c3658ec76de/python/triton/runtime/jit.py#L238 + # `None` is nullptr. Implicitly convert to *i8. + if key is None: + return "*i8" + dtype_str = str(key).split(".")[-1] + tys = { + "bool": "i1", + "float8e4nv": "fp8e4nv", + "float8e5": "fp8e5", + "float8e4b15": "fp8e4b15", + "float8e4b15x4": "fp8e4b15x4", + "float8_e4m3fn": "fp8e4nv", + "float8_e5m2": "fp8e5", + # TODO: remove when support is added in triton + # https://github.com/triton-lang/triton/issues/6054 + "float8_e8m0fnu": "u8", + "float4_e2m1fn_x2": "u8", + "float16": "fp16", + "bfloat16": "bf16", + "float32": "fp32", + "float64": "fp64", + "int8": "i8", + "int16": "i16", + "int32": "i32", + "int64": "i64", + "uint8": "u8", + "uint16": "u16", + "uint32": "u32", + "uint64": "u64", + } + # reinterpret can create triton type + tys.update({v: v for v in list(tys.values())}) + return key if isinstance(key, str) else f"*{tys[dtype_str]}" + + +def convert_shape_to_inductor( + lst: Iterable[Union[int, torch.SymInt]], +) -> list[sympy.Expr]: + """ + Gets the shape and stride of a tensor. For non-symbolic tensors, this is + trivial. But for symbolic tensors, we need to map from SymIntNode into + sympy.Expr. 
+ """ + return [sympy.sympify(i) for i in lst] + + +def convert_to_symint(i: Union[int, sympy.Expr]) -> Union[int, torch.SymInt]: + """ + Like convert_shape_to_symint, but operates on a single expression. + """ + from .virtualized import V + + return ( + i + if isinstance(i, int) + else ( + int(i) + if isinstance(i, sympy.Integer) + else V.graph.sizevars.shape_env.create_symintnode(i, hint=None) + ) + ) + + +def convert_shape_to_symint( + lst: Iterable[Union[int, sympy.Expr]], +) -> list[Union[int, torch.SymInt]]: + """ + Takes a list of shapes from Inductor and converts them into symints (or just + ints if all shapes are static). + """ + return [convert_to_symint(i) for i in lst] + + +def is_view(op: torch._ops.OpOverload) -> bool: + """ + Does this op overload have aliasing + """ + return any(a.alias_info is not None for a in op._schema.arguments) + + +def is_pointwise_use( + use: Node, + is_pointwise_fn: Callable[[torch._ops.OpOverload], bool] = lambda _: False, +) -> bool: + """ + Do all uses of this op have torch.Tag.pointwise or return True for optional `is_pointwise_fn` + + Uses in views ops will follow the views uses + """ + + if use.op != "call_function": + return False + if not ( + isinstance(use.target, torch._ops.OpOverload) or use.target is operator.getitem + ): + return False + + target = cast(torch._ops.OpOverload, use.target) + if target is operator.getitem or is_view(target): + return all(is_pointwise_use(u, is_pointwise_fn) for u in use.users) + + return torch.Tag.pointwise in target.tags or is_pointwise_fn(target) + + +def gen_gm_and_inputs( + target: Any, args: list[Any], kwargs: dict[str, Any] +) -> tuple[GraphModule, list[torch.Tensor]]: + g = torch.fx.Graph() + graph_args: list[torch.Tensor] = [] + + def add_tensor_arg(arg: torch.Tensor) -> Node: + graph_args.append(arg) + return g.placeholder(f"arg{len(graph_args)}") + + node = g.call_function( + target, *tree_map_only(torch.Tensor, add_tensor_arg, (args, kwargs)) + ) + if ( + 
len(target._schema.returns) == 1 + and str(target._schema.returns[0].type) == "Tensor" + ): + node = (node,) # type: ignore[assignment] + g.output(node) + + gm = torch.fx.GraphModule({}, g) + return gm, graph_args + + +def synchronize(device: str = "cuda") -> None: + if device == "cpu": + return + device_interface = get_interface_for_device(device) + if device_interface.is_available(): + device_interface.synchronize() + + +def timed( + model: Callable[..., Any], + example_inputs: Sequence[Any], + times: int = 1, + device: str = "cuda", +) -> float: + synchronize(device) + torch.manual_seed(1337) + t0 = time.perf_counter() + for _ in range(times): + result = model(*example_inputs) + synchronize(device) + t1 = time.perf_counter() + # GC the result after timing + assert result is not None # type: ignore[possibly-undefined] + return t1 - t0 + + +def print_performance( + model: Callable[..., Any], + example_inputs: Sequence[Any] = (), + times: int = 10, + repeat: int = 10, + baseline: float = 1.0, + device: str = "cuda", +) -> float: + timings = torch.tensor( + [timed(model, example_inputs, times, device) for _ in range(repeat)] + ) + took = torch.median(timings) / times + print(f"{took / baseline:.6f}") + return took.item() + + +def precompute_method(obj: Any, method: str) -> None: + """Replace obj.method() with a new method that returns a precomputed constant.""" + result = getattr(obj, method)() + setattr(obj, method, lambda: result) + + +def precompute_methods(obj: Any, methods: list[str]) -> None: + """Replace methods with new methods that returns a precomputed constants.""" + for method in methods: + precompute_method(obj, method) + + +def cmp(a: int, b: int) -> int: + return int(a > b) - int(a < b) + + +def pad_listlike(x: Union[int, Sequence[int]], size: int) -> Sequence[int]: + if isinstance(x, int): + return [x] * size + if len(x) == 1: + return type(x)([x[0]]) * size # type: ignore[call-arg, operator, return-value] + return x + + +# Used to ensure that 
iterating over a set is deterministic +def tuple_sorted(x: tuple[_T, ...]) -> list[_T]: + if len(x) == 0: + return [] + + def sort_func(elem: _T) -> str: + if isinstance(elem, str): + return elem + + from .scheduler import BaseSchedulerNode + + assert isinstance(elem, BaseSchedulerNode) + return elem.get_name() + + return sorted(x, key=sort_func) + + +P = ParamSpec("P") +RV = TypeVar("RV", covariant=True) +FN_TYPE = Callable[Concatenate[Any, P], RV] + + +class CachedMethod(Protocol, Generic[P, RV]): + @staticmethod + def clear_cache(cache: Any) -> None: ... + + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> RV: ... + + +# See https://github.com/python/mypy/issues/13222#issuecomment-1193073470 to understand the type signature +def cache_on_self(fn: Callable[Concatenate[Any, P], RV]) -> CachedMethod[P, RV]: + name = fn.__name__ + key = f"__{name}_cache" + + # wrapper is likely on the hot path, compile a specialized version of it + ctx = {"fn": fn} + exec( + f"""\ + def {name}_cache_on_self(self): + try: + return self.{key} + except AttributeError: + pass + rv = fn(self) + object.__setattr__(self, "{key}", rv) + return rv + """.lstrip(), + ctx, + ) + wrapper = functools.wraps(fn)(ctx[f"{name}_cache_on_self"]) + + def clear_cache(self: Any) -> None: + if hasattr(self, key): + delattr(self, key) + + wrapper.clear_cache = clear_cache # type: ignore[attr-defined] + return wrapper # type: ignore[return-value] + + +def cache_property_on_self(fn: Callable[P, RV]) -> CachedMethod[P, RV]: + """ + Variant of cache_on_self for properties. The only difference is the type signature. + """ + # pyrefly: ignore [bad-argument-type] + return cache_on_self(fn) + + +def cache_on_self_and_args( + class_name: str, +) -> Callable[[FN_TYPE[P, RV]], FN_TYPE[P, RV]]: + # include both class_name and fn_name in the key to support `super().fn(self, **args, **kwargs)` calls. 
+ + def wrapper( + fn: FN_TYPE[P, RV], + ) -> FN_TYPE[P, RV]: + key = f"__{class_name}_{fn.__name__}_cache" + + # wrapper is likely on the hot path, compile a specialized version of it + ctx = {"fn": fn} + exec( + f"""\ + def inner(self: Any, *args: P.args, **kwargs: P.kwargs) -> RV: + args_kwargs = (args, tuple(sorted(kwargs.items()))) + + if not hasattr(self, "{key}"): + object.__setattr__(self, "{key}", {{}}) + + cache = self.{key} + + try: + return cache[args_kwargs] + except KeyError: + pass + + rv = fn(self, *args, **kwargs) + + cache[args_kwargs] = rv + return rv + """.lstrip(), + ctx, + ) + inner = functools.wraps(fn)(ctx["inner"]) + + def clear_cache(self: Any) -> None: + if hasattr(self, key): + delattr(self, key) + + inner.clear_cache = clear_cache # type: ignore[attr-defined] + return inner + + return wrapper + + +def aggregate_origins( + node_schedule: Union[Sequence[BaseSchedulerNode], ExternKernel], +) -> OrderedSet[Node]: + from . import ir + + if isinstance(node_schedule, list): + return functools.reduce( + operator.or_, + [ + # pyrefly: ignore [missing-attribute] + node.node.origins + for node in node_schedule + if hasattr(node, "node") and node.node + ], + OrderedSet(), + ) + elif isinstance(node_schedule, ir.ExternKernel): + return node_schedule.origins + else: + return OrderedSet() + + +def get_fused_kernel_name( + node_schedule: Sequence[BaseSchedulerNode], + descriptive_names: Literal[True, "torch", "original_aten", "inductor_node"], +) -> str: + all_origins = aggregate_origins(node_schedule) + if descriptive_names == "original_aten": + + def get_origin_meta_str(origin): + original_aten = origin.meta["original_aten"] + key = "" + if isinstance(original_aten, torch._ops.OpOverload): + key = original_aten._overloadpacket.__name__ + elif isinstance(original_aten, torch._ops.HigherOrderOperator): + key = str(original_aten.name()) + return key + + # Bases the kernel name off of the top-level aten operator (i.e. 
pre-decompositions) + sources = [ + get_origin_meta_str(origin) + for origin in all_origins + if origin.op == "call_function" + and "original_aten" in origin.meta + and origin.meta["original_aten"] is not None + ] + sources = sorted(OrderedSet(sources)) + elif descriptive_names == "torch": + # Bases the kernel name off of the top-level "torch" operator (i.e. post-dynamo graph) + sources = [] + for origin in all_origins: + if origin.op == "call_function": + source_fn = None + suffix = "" + if "source_fn_stack" in origin.meta: + source_fn = origin.meta["source_fn_stack"][-1] + elif "fwd_source_fn_stack" in origin.meta: + # backward nodes have "fwd_source_fn_stack" instead + source_fn = origin.meta["fwd_source_fn_stack"][-1] + suffix = "backward" + if not source_fn: + continue + if isinstance(source_fn[1], str): + sources.append(source_fn[1] + suffix) + else: + sources.append(source_fn[1].__name__ + suffix) + + sources = sorted(OrderedSet(sources)) + elif descriptive_names == "inductor_node": + sources = [ + origin.name for origin in all_origins if origin.op == "call_function" + ] + else: + raise NotImplementedError + return "_".join(["fused"] + sources) + + +def get_kernel_metadata( + node_schedule: Union[Sequence[BaseSchedulerNode], ExternKernel], + wrapper: PythonWrapperCodegen, +) -> tuple[str, str]: + """ + Retrieves metadata information for a kernel. + Args: + node_schedule (Union[Sequence[BaseSchedulerNode], ExternKernel]): + Either a sequence of BaseSchedulerNode objects or an ExternKernel instance. + wrapper (PythonWrapperCodegen): + An instance of PythonWrapperCodegen, used to define the code comment format. + Returns: + tuple[str, str]: + A tuple containing two strings: + - The first string represents the kernel's metadata. + - The second string represent the kernel's detailed metadata. 
+ """ + + all_origins = aggregate_origins(node_schedule) + inductor_nodes = [origin for origin in all_origins if origin.op == "call_function"] + + from_node_dict = collections.defaultdict(list) + original_aten_dict = collections.defaultdict(list) + + # Attempt to sort `inductor_nodes` topologically. Note that the case + # where `inductor_nodes` contains nodes from multiple graph instances + # is not supported. An example of this is conditional statements. + single_graph = None + if inductor_nodes: + unique_graphs = OrderedSet(n.graph for n in inductor_nodes) + if len(unique_graphs) == 1: + single_graph = inductor_nodes[0].graph + # create a map of idx -> node and cache it + if not hasattr(single_graph, "_inductor_kernel_metadata_node_to_idx_map"): + node_to_idx_map = {n: idx for idx, n in enumerate(single_graph.nodes)} + single_graph._inductor_kernel_metadata_node_to_idx_map = node_to_idx_map # type: ignore[attr-defined] + inductor_nodes.sort( + key=lambda n: single_graph._inductor_kernel_metadata_node_to_idx_map[n] # type: ignore[attr-defined] + ) + + for node in inductor_nodes: + if "original_aten" in node.meta and node.meta["original_aten"] is not None: + original_aten = node.meta["original_aten"] + key = None + if isinstance(original_aten, torch._ops.OpOverload): + key = str(original_aten._overloadpacket) + elif isinstance(original_aten, torch._ops.HigherOrderOperator): + key = str(original_aten.name()) + if key: + original_aten_dict[key].append(node.name) + if "from_node" in node.meta: + key = node.meta["from_node"][0].name + from_node_dict[key].append(node.name) + elif node.meta.get("partitioner_tag") == "is_backward": + # backward nodes currently don't have a "from node" + from_node_dict[node.name].append(node.name) + sort_str = "Topologically Sorted" if single_graph is not None else "Unsorted" + metadata = ( + f"{wrapper.comment} {sort_str} Source Nodes: [{', '.join(from_node_dict.keys())}], " + f"Original ATen: [{', '.join(original_aten_dict.keys())}]" + ) 
+ + # trace back to original node here + detailed_metadata = [f"{wrapper.comment} Source node to ATen node mapping:"] + for original_node, nodes in sorted(from_node_dict.items()): + detailed_metadata.append( + f"{wrapper.comment} {original_node} => {', '.join(sorted(nodes))}" + ) + + # print the aot_autograd graph fragment + if single_graph is not None: + from . import ir + + detailed_metadata.append(f"{wrapper.comment} Graph fragment:") + all_reads: OrderedSet[str] = OrderedSet() + all_writes: list[str] = [] + if not isinstance(node_schedule, ir.ExternKernel): + from .virtualized import V + + def get_buffer_info( + buffer: Union[ir.TensorBox, ir.Buffer, ir.TorchBindObject], rw_name: str + ) -> tuple[str, ir.Layout | None]: + if isinstance(buffer, ir.TensorBox) and isinstance( + buffer.data, ir.StorageBox + ): + origin_node = buffer.data.data.origin_node + else: + origin_node = buffer.origin_node + if origin_node is None: + # use the read/write name if no origin node is found + name = rw_name + else: + name = origin_node.name + try: + layout = buffer.get_layout() + except NotImplementedError: + layout = None + return name, layout + + def stringify_shape(shape: Iterable[int]) -> str: + return f"[{', '.join([str(x) for x in shape])}]" + + def stringfy_layout(layout: ir.Layout | None) -> str: + if layout is None: + return "" + shape_annotation = f"{stringify_shape(layout.size)}" + stride_annotation = f"{stringify_shape(layout.stride)}" + device_annotation = f"{layout.device}" + + return ( + f'"{dtype_abbrs[layout.dtype]}{shape_annotation}' + f'{stride_annotation}{device_annotation}"' + ) + + for n in node_schedule: + if not hasattr(n, "read_writes") or n.read_writes is None: + continue + if hasattr(n.read_writes, "reads") and n.read_writes.reads is not None: + for r in n.read_writes.reads: + # Remove the dupricated inputs + if r.name in all_reads: + continue + all_reads.add(r.name) + buffer = V.graph.try_get_buffer(r.name) + if buffer is None: + continue + input_name, 
layout = get_buffer_info(buffer, r.name) + detailed_metadata.append( + f"{wrapper.comment} %{input_name} : Tensor " + f"{stringfy_layout(layout)} = PlaceHolder[target={input_name}]" + ) + + if ( + hasattr(n.read_writes, "writes") + and n.read_writes.writes is not None + ): + for w in n.read_writes.writes: + buffer = V.graph.try_get_buffer(w.name) + if buffer is None: + continue + output_name, _ = get_buffer_info(buffer, w.name) + + all_writes.append("%" + output_name) + + for node in inductor_nodes: + detailed_metadata.append( + f"{wrapper.comment} {node.format_node(include_tensor_metadata=True)}" + ) + + detailed_metadata.append(f"{wrapper.comment} return {','.join(all_writes)}") + + return metadata, "\n".join(detailed_metadata) + + +def dominated_nodes( + initial_queue: Iterable[torch.fx.Node], + skip_filter: Optional[Callable[[Any], bool]] = None, +) -> OrderedSet[torch.fx.Node]: + """Returns the set of nodes whose values depend on those within initial_queue""" + initial_queue = list(initial_queue) + dominated_set = OrderedSet(initial_queue) + + while initial_queue: + node = initial_queue.pop() + for user in node.users: + if skip_filter and skip_filter(user): + continue + if user not in dominated_set: + dominated_set.add(user) + initial_queue.append(user) + + return dominated_set + + +def gather_origins( + args: Sequence[IRNode], kwargs: dict[str, IRNode] +) -> OrderedSet[torch.fx.Node]: + from . 
import ir + + def is_unrealized_node(n: IRNode) -> bool: + if isinstance(n, ir.TensorBox): + return is_unrealized_node(n.data) + if isinstance(n, ir.StorageBox): + return is_unrealized_node(n.data) + return isinstance(n, ir.IRNode) and not isinstance( + n, + ( + ir.ComputedBuffer, + ir.InputsKernel, + ir.InputBuffer, + ir.TemplateBuffer, + ), + ) + + # kwargs and args may include a container of node, for example torch.cat([t1, t2]) + # flatten them before search the unrealized nodes + kwargs_flatten, _ = tree_flatten(kwargs) + kwargs_origins = [val.origins for val in kwargs_flatten if is_unrealized_node(val)] + args_flatten, _ = tree_flatten(args) + args_origins = [val.origins for val in args_flatten if is_unrealized_node(val)] + return OrderedSet(itertools.chain(*args_origins, *kwargs_origins)) + + +def sympy_str(expr: sympy.Expr) -> str: + """ + Normal sympy str is very slow, this is a lot faster. The result are + somewhat worse, as it doesn't do as much simplification. So don't + use this for final codegen. + """ + + def is_neg_lead(expr: sympy.Expr) -> bool: + return ( + isinstance(expr, sympy.Mul) and len(expr.args) == 2 and expr.args[0] == -1 + ) + + def sympy_str_add(expr: sympy.Expr) -> str: + if isinstance(expr, sympy.Add): + # Special case 'a - b'. Note that 'a - b - c' will still appear as + # 'a + -1 * b + -1 * c'. + if len(expr.args) == 2 and is_neg_lead(expr.args[1]): + return f"{sympy_str_mul(expr.args[0])} - {sympy_str_mul(expr.args[1].args[1])}" + else: + return " + ".join(map(sympy_str_mul, expr.args)) + else: + return sympy_str_mul(expr) + + def sympy_str_mul(expr: sympy.Expr) -> str: + if isinstance(expr, sympy.Mul): + if is_neg_lead(expr): + # Special case '-a'. Note that 'a * -b' will still appear as + # '-1 * a * b'. 
+ return f"-{sympy_str_atom(expr.args[1])}" + else: + return " * ".join(map(sympy_str_atom, expr.args)) + else: + return sympy_str_atom(expr) + + def sympy_str_atom(expr: sympy.Expr) -> str: + if isinstance(expr, sympy.Symbol): + return expr.name + elif isinstance(expr, (sympy.Add, sympy.Mul)): + return f"({sympy_str_add(expr)})" + elif isinstance(expr, (ModularIndexing, CleanDiv, FloorDiv, Identity)): + return f"{expr.func.__name__}({', '.join(map(sympy_str, expr.args))})" + else: + return str(expr) + + return sympy_str_add(expr) + + +def get_bounds_index_expr(index: sympy.Expr) -> ValueRanges[Any]: + from .virtualized import V + + # If this expression does not come from an FX node, we compute its bounds + if ( + config.compute_all_bounds + and (fx_node := getattr(V.interpreter, "current_node", None)) + and fx_node.target != "index_expr" + ): + return bound_sympy(index) + else: + return ValueRanges.unknown() + + +def prefix_is_reduction(prefix: str) -> bool: + return prefix[0] == "r" + + +def sympy_index_symbol_with_prefix(prefix: SymT, idx: int) -> sympy.Symbol: + """ + Used to generate an integer-nonnegative symbol. + """ + # This should never be used for creating shape/stride symbols, as those + # should all be allocated before Inductor. + assert prefix != SymT.SIZE + # NOTE: shape symbols are positive (> 0), but index variables are only + # non-negative (>= 0). + return make_symbol(prefix, idx, integer=True, nonnegative=True) + + +def generate_assert(check: bool) -> bool: + return (check or config.debug_index_asserts) and config.assert_indirect_indexing + + +def sympy_index_symbol(name: str) -> sympy.Symbol: + """ + Used to generate an integer-nonnegative symbol. + """ + # This should never be used for creating shape/stride symbols, as those + # should all be allocated before Inductor. + assert name[0] != "s" + # NOTE: shape symbols are positive (> 0), but index variables are only + # non-negative (>= 0). 
+ return sympy.Symbol(name, integer=True, nonnegative=True) + + +def sympy_subs(expr: sympy.Expr, replacements: dict[sympy.Expr, Any]) -> sympy.Expr: + """ + When the passed replacement symbol v is a string, it is converted to a symbol with name v that + have the same replaced expression integer and nonnegative properties. + """ + + def to_symbol( + replaced: sympy.Expr, replacement: Union[sympy.Expr, str] + ) -> sympy.Symbol: + assert isinstance(replaced, sympy.Expr) + if isinstance(replacement, str): + return sympy.Symbol( + replacement, + integer=replaced.is_integer, # type: ignore[attr-defined] + nonnegative=replaced.is_nonnegative, # type: ignore[attr-defined] + ) + else: + return replacement + + # xreplace is faster than subs, but is way more picky + return sympy.sympify(expr).xreplace( + {k: to_symbol(k, v) for k, v in replacements.items()} + ) + + +def is_symbolic(a: Any) -> TypeGuard[Union[torch.SymInt, torch.Tensor]]: + return isinstance(a, torch.SymInt) or ( + isinstance(a, torch.Tensor) + and any(is_symbolic(x) for x in itertools.chain(a.size(), a.stride())) + ) + + +def any_is_symbolic(*args: Any) -> bool: + return any(is_symbolic(a) for a in args) + + +# Ops that are fundamentally incompatible with CUDA graph capture +# (e.g., CPU synchronization, dynamic memory allocation, etc.) 
+FORBIDDEN_CUDAGRAPH_OPS = frozenset( + [ + "aten._fused_moving_avg_obs_fq_helper.default", + "aten._fused_moving_avg_obs_fq_helper_functional.default", + "fbgemm.dense_to_jagged.default", + "fbgemm.jagged_to_padded_dense.default", + "run_and_save_rng_state", + "run_with_rng_state", + "aten._local_scalar_dense", + # Technically, it's not necessary to ban this, because an + # assert_scalar with constant arguments can be validly run + # with CUDA graphs, but the operator is also pointless with + # constant arguments, so might as well ban + "aten._assert_scalar", + ] +) + + +def get_first_incompatible_cudagraph_node( + gm: torch.fx.GraphModule, +) -> Optional[torch.fx.Node]: + from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols + + for node in gm.graph.nodes: + if is_cudagraph_unsafe_fx_node(node): + return node + + if (val := node.meta.get("val")) is not None and free_unbacked_symbols(val): + return node + + return None + + +def output_node(gm: torch.fx.GraphModule) -> Node: + """Get the output node from an FX graph""" + last_node = next(iter(reversed(gm.graph.nodes))) + assert last_node.op == "output" + return last_node + + +def get_all_devices(gm: torch.fx.GraphModule) -> OrderedSet[torch.device]: + placeholder_nodes = gm.graph.find_nodes(op="placeholder") + input_devices: OrderedSet[torch.device] = OrderedSet( + node.meta["val"].device + for node in placeholder_nodes + if isinstance(node.meta.get("val"), torch.Tensor) + ) + + out_arg = output_node(gm).args[0] # type: ignore[union-attr] + out_args = out_arg if isinstance(out_arg, tuple) else (out_arg,) + out_devices: OrderedSet[torch.device] = OrderedSet( + arg.meta["val"].device + for arg in out_args + if isinstance(arg, torch.fx.Node) + and isinstance(arg.meta.get("val"), torch.Tensor) + ) + return input_devices | out_devices + + +import gc + + +def unload_xpu_triton_pyds() -> None: + # unload __triton_launcher.pyd + for module_name in list(sys.modules.keys()): + if not 
module_name.startswith("torch._inductor.runtime.compile_tasks."): + continue + m = sys.modules[module_name] + for attr_name in m.__dict__: + if attr_name.startswith("triton_"): + kernel = getattr(m, attr_name) + if isinstance( + kernel, torch._inductor.runtime.triton_heuristics.CachingAutotuner + ): + for result in kernel.compile_results: + if isinstance( + result, + torch._inductor.runtime.triton_heuristics.TritonCompileResult, + ): + # pyrefly: ignore [missing-attribute] + result.kernel.run.mod.__del__() + del sys.modules[module_name] + + # unload spirv_utils.pyd + if "triton.runtime.driver" in sys.modules: + mod = sys.modules["triton.runtime.driver"] + del type(mod.driver.active.utils).instance + del mod.driver.active.utils + + gc.collect() + + +_registered_caches: list[Any] = [] + + +def clear_on_fresh_cache(obj: Any) -> Any: + """ + Use this decorator to register any caches that should be cache_clear'd + with fresh_cache(). + """ + if not hasattr(obj, "cache_clear") or not callable(obj.cache_clear): + raise AttributeError(f"{obj} does not have a cache_clear method") + + _registered_caches.append(obj) + return obj + + +def clear_caches() -> None: + """ + Clear all registered caches. + """ + for obj in _registered_caches: + obj.cache_clear() + + +@contextlib.contextmanager +def fresh_cache( + cache_entries: Optional[dict[str, Any]] = None, + dir: Optional[str] = None, + delete: bool = True, +) -> Iterator[None]: + """ + Contextmanager that provides a clean tmp cachedir for pt2 caches. + + Optionally, pass a dict as 'cache_entries' to get a list of filenames and sizes + generated with this cache instance. 
+ """ + clear_caches() + + from torch._inductor.cpp_builder import normalize_path_separator + + inductor_cache_dir = normalize_path_separator(tempfile.mkdtemp(dir=dir)) + try: + with mock.patch.dict( + os.environ, {"TORCHINDUCTOR_CACHE_DIR": inductor_cache_dir} + ): + log.debug("Using inductor cache dir %s", inductor_cache_dir) + triton_cache_dir = normalize_path_separator( + os.path.join(inductor_cache_dir, "triton") + ) + with mock.patch.dict(os.environ, {"TRITON_CACHE_DIR": triton_cache_dir}): + yield + if isinstance(cache_entries, dict): + assert len(cache_entries) == 0, "expected empty cache_entries dict" + if os.path.exists(triton_cache_dir): + files = os.listdir(triton_cache_dir) + cache_entries.update( + { + f: os.path.getsize(os.path.join(triton_cache_dir, f)) + for f in files + if ".lock" not in f + } + ) + if delete: + if is_windows() and torch.xpu.is_available(): + unload_xpu_triton_pyds() + + shutil.rmtree( + inductor_cache_dir, + # Let's not fail if we can't clean up the temp dir. Also note that for + # Windows, we can't delete the loaded modules because the module binaries + # are open. + ignore_errors=is_windows(), + onerror=lambda func, path, exc_info: log.warning( + "Failed to remove temporary cache dir at %s", + inductor_cache_dir, + exc_info=exc_info, + ), + ) + except Exception: + log.warning("on error, temporary cache dir kept at %s", inductor_cache_dir) + raise + finally: + clear_caches() + + +# Deprecated functions -- only keeping them for BC reasons +clear_on_fresh_inductor_cache = clear_on_fresh_cache +clear_inductor_caches = clear_caches +fresh_inductor_cache = fresh_cache + + +def argsort(seq: Sequence[Any], *, reverse: bool = False) -> list[int]: + getter = seq.__getitem__ + a_r = range(len(seq)) + # preserve original order for equal strides + # e.g. if strides are [32, 8, 8, 1] + # argsort -> [3, 2, 1, 0], rather than + # [3, 1, 2, 0] + # i.e. 
for equal strides in ascending order (reverse=False) an + # inner dimension should come before an outer dimension, and vice versa + # for descending + sort_idx = list(sorted(a_r, key=getter, reverse=True)) # noqa: C413 + if not reverse: + return list(reversed(sort_idx)) + return sort_idx + + +def argsort_sym( + shape_env: ShapeEnv, + seq: Sequence[Union[int, torch.SymInt, sympy.Expr]], + *, + reverse: bool = False, +) -> list[int]: + def cmp(a: tuple[int, sympy.Expr], b: tuple[int, sympy.Expr]) -> int: + a_idx, a_val = a + b_idx, b_val = b + + def evaluate(expr: Union[bool, torch.SymInt, sympy.Expr]) -> bool: + if isinstance(expr, bool): + return expr + return shape_env.evaluate_expr(expr, size_oblivious=True) + + if evaluate(a_val < b_val): + return -1 + if evaluate(a_val > b_val): + return 1 + # If strides are the same, prefer the original order. + # (this matches argsort's algorithm). + # For strides = [2048, 2048, 16, 1], this is + # [3, 2, 1, 0]. + if a_idx < b_idx: + return 1 + if a_idx > b_idx: + return -1 + return 0 + + # Strategy: convert all symints to sympy.Expr, then use a custom comparator + exprs = [ + (idx, s.node.expr if isinstance(s, torch.SymInt) else s) + for idx, s in enumerate(seq) + ] + exprs = sorted(exprs, key=functools.cmp_to_key(cmp), reverse=reverse) + result = [idx for idx, _ in exprs] + return result + + +@functools.lru_cache(8) +def get_dtype_size(dtype: torch.dtype) -> int: + # TODO: Investigate why uint64 tensor creation causes overflow error: + # Workaround for RuntimeError in memory size calculation, but underlying cause unclear + if dtype == torch.uint64: + return 8 + return torch.empty((), dtype=dtype).element_size() + + +class LineContext(NamedTuple): + context: Any + + +@dataclasses.dataclass +class ValueWithLineMap: + value: str + line_map: list[tuple[int, LineContext]] + + +class IndentedBuffer: + tabwidth = 4 + + def __init__(self, initial_indent: int = 0) -> None: + self._lines: list[Union[DeferredLineBase, LineContext, 
str]] = [] + self._indent = initial_indent + + @contextlib.contextmanager + def set_tabwidth(self, tabwidth: int) -> Iterator[None]: + prev = self.tabwidth + try: + self.tabwidth = tabwidth + yield + finally: + self.tabwidth = prev + + def getvaluewithlinemap(self) -> ValueWithLineMap: + buf = StringIO() + p = 1 + linemap: list[tuple[int, LineContext]] = [] + for li in self._lines: + if isinstance(li, DeferredLineBase): + line = li() + if line is None: + continue + elif isinstance(li, LineContext): + linemap.append((p, li.context)) + continue + else: + line = li + assert isinstance(line, str) + buf.write(line) + buf.write("\n") + p += 1 + line.count("\n") + return ValueWithLineMap(buf.getvalue(), linemap) + + def getvalue(self) -> str: + return self.getvaluewithlinemap().value + + def getrawvalue(self) -> str: + buf = StringIO() + for li in self._lines: + if isinstance(li, DeferredLineBase): + line = li() + if line is None: + continue + elif isinstance(li, LineContext): + continue + else: + line = li + assert isinstance(line, str) + # backslash implies line continuation + if line.endswith("\\"): + buf.write(line[:-1]) + else: + buf.write(line) + buf.write("\n") + return buf.getvalue() + + def clear(self) -> None: + self._lines.clear() + + def __bool__(self) -> bool: + return bool(self._lines) + + def prefix(self) -> str: + return " " * (self._indent * self.tabwidth) + + def newline(self) -> None: + self.writeline("\n") + + def writeline(self, line: Union[LineContext, DeferredLineBase, str]) -> None: + if isinstance(line, LineContext): + self._lines.append(line) + elif isinstance(line, DeferredLineBase): + self._lines.append(line.with_prefix(self.prefix())) + elif line.strip(): + self._lines.append(f"{self.prefix()}{line}") + else: + self._lines.append("") + + def writelines( + self, lines: Sequence[Union[LineContext, DeferredLineBase, str]] + ) -> None: + for line in lines: + self.writeline(line) + + def indent(self, offset: int = 1) -> 
contextlib.AbstractContextManager[None]: + @contextlib.contextmanager + def ctx() -> Iterator[None]: + self._indent += offset + try: + yield + finally: + self._indent -= offset + + return ctx() + + def do_indent(self, offset: int = 1) -> None: + self._indent += offset + + def do_unindent(self, offset: int = 1) -> None: + self._indent -= offset + + def splice( + self, other_code: Union[IndentedBuffer, str], strip: bool = False + ) -> None: + if isinstance(other_code, IndentedBuffer): + dedent = float("inf") + # pyrefly: ignore [bad-assignment] + for line in other_code._lines: + if not isinstance(line, LineContext) and line: + dedent = min(dedent, len(line) - len(line.lstrip())) + if math.isinf(dedent): + dedent = 0 + for line in other_code._lines: + if isinstance(line, LineContext): + self._lines.append(line) + else: + IndentedBuffer.writeline(self, line[int(dedent) :]) + else: + other_code = textwrap.dedent(other_code) + if strip: + other_code = other_code.lstrip() + if not other_code: + return + other_code = other_code.rstrip() + for s in other_code.split("\n"): + self.writeline(s) + + def map(self, func: Callable[[Any], Any]) -> IndentedBuffer: + res = IndentedBuffer(initial_indent=self._indent) + res._lines = [func(line) for line in self._lines] + return res + + def __repr__(self) -> str: + return f"{type(self)}({self.getvalue()})" + + def __add__(self, other: Self) -> IndentedBuffer: + assert self._indent == other._indent + res = IndentedBuffer(initial_indent=self._indent) + # TODO(rec): or should this be self.__class__(initial_indent=self._indent)? 
class FakeIndentedBuffer(IndentedBuffer):
    """An ``IndentedBuffer`` stand-in that rejects every attribute access.

    Any attribute lookup other than ``__class__`` raises ``RuntimeError``; the
    message explains that the body must be selected explicitly through
    ``TritonTemplateKernel.set_subgraph_body(name)``.
    """

    def __init__(self) -> None:
        super().__init__()

    def __getattribute__(self, name: str) -> Any:
        if name == "__class__":  # Allow access to the class attribute
            return object.__getattribute__(self, name)
        # Adjacent literals are concatenated verbatim, so each fragment must end
        # with a space to keep the words of the message apart.
        raise RuntimeError(
            f"Tried to call self.{name} on FakeIndentedBuffer. This buffer "
            "is currently used on TritonTemplateKernel to prevent actual "
            "writes to the body without explicitly specifying the body with "
            "`TritonTemplateKernel.set_subgraph_body(name)`"
        )
@functools.cache
def is_big_gpu(index_or_device: Union[int, torch.device] = 0) -> bool:
    """Report whether a GPU qualifies for ``max_autotune_gemm`` mode.

    Args:
        index_or_device: A ``torch.device``, or a device index that is combined
            with ``get_gpu_type()`` to build one.

    Returns:
        On HIP builds, False (with a warning) when the device's major arch
        number is below 9 or equal to 10, True otherwise. On other builds,
        False (with a warning) when ``multi_processor_count`` is below the
        minimum (16 for ``xpu`` devices, 68 otherwise), True otherwise.

    Results are memoized per argument by ``functools.cache``.
    """
    device = (
        index_or_device
        if isinstance(index_or_device, torch.device)
        else torch.device(get_gpu_type(), index_or_device)
    )
    props = DeviceProperties.create(device)

    # SM logic is not relevant to ROCm gpus
    # Arbitrarily skipping the older models
    if torch.version.hip:
        arch_major = props.major
        assert arch_major is not None
        supported = not (arch_major < 9 or arch_major == 10)
        if not supported:
            log.warning("GPU arch does not support max_autotune_gemm mode usage")
        return supported

    required_sms = 16 if device.type == "xpu" else 68  # 3080
    available_sms = props.multi_processor_count
    if available_sms < required_sms:
        log.warning(
            "Not enough SMs to use max_autotune_gemm mode",
            extra={"min_sms": required_sms, "avail_sms": available_sms},
        )
        return False
    return True
def _use_template_for_gpu(
    layout: Layout, allowed_layout_dtypes: list[torch.dtype]
) -> bool:
    """Decide whether a template may be used for ``layout`` on a GPU.

    Returns True only when the layout's device type is a GPU (per ``is_gpu``),
    its dtype is one of ``allowed_layout_dtypes``, and ``is_big_gpu`` accepts
    the device. A debug message is logged when the dtype is not allowed.
    """
    dtype_allowed = layout.dtype in allowed_layout_dtypes
    if not dtype_allowed:
        log.debug(
            "Not using template since dtype %s is not in allowed layout dtypes %s",
            layout.dtype,
            allowed_layout_dtypes,
        )
    # Same evaluation order as before: device kind, then dtype, then GPU size.
    return is_gpu(layout.device.type) and dtype_allowed and is_big_gpu(layout.device)
def can_use_tma(
    *matrices: IRNode, output_layout: Optional[Layout] = None, add_guards: bool = False
) -> bool:
    """
    Return True iff *all* supplied tensors satisfy the CUDA-12.9 TMA constraints
    that Triton relies on today.
    * https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html

    Args:
        matrices: Input IR nodes; each one must pass the checks below.
        output_layout: Optional output layout. ``None`` skips the output check;
            otherwise the layout must pass the same checks, with FP32
            additionally accepted as a dtype.
        add_guards: When true, sizes and strides are converted with
            ``V.graph.sizevars.guard_int_seq``; otherwise
            ``V.graph.sizevars.symbolic_hint`` is applied to each entry.

    A tensor is accepted when:
      * 2 ≤ rank ≤ 5
      * dtype ∈ {FP16, BF16, FP8-E4M3FN} (plus FP32 for the output layout)
      * Every logical size ≥ 2
      * Base pointer 16-byte aligned (inputs: the buffer name is not listed in
        ``V.graph.unaligned_buffers``; output: ``layout.offset`` is a multiple
        of ``TMA_ALIGNMENT``)
      * All "outer" dims have 16-byte aligned strides
      * Exactly one dim (the “inner” dim) has stride 1 (contiguous)
      * The inner dim's byte width is a multiple of 16 bytes
      * For FP8 tensors, inner dim ≥ 32

    ``has_triton_tma_device()`` must also return True.
    """
    from torch.utils._triton import has_triton_tma_device

    from .virtualized import V

    # True when `expr_bytes` is statically known to be a multiple of TMA_ALIGNMENT.
    def _aligned(expr_bytes: Union[int, sympy.Expr]) -> bool:
        return V.graph.sizevars.statically_known_multiple_of(expr_bytes, TMA_ALIGNMENT)

    # Output-side check; a missing layout is treated as compatible.
    def _is_tma_compatible_layout(layout: Optional[Layout]) -> bool:
        if layout is None:
            return True
        sizes = layout.size
        strides = layout.stride
        dtype = layout.dtype

        # Verify the output is 16-byte aligned
        # NOTE(review): layout.offset is passed as-is (not scaled by itemsize);
        # confirm whether it is expressed in elements or bytes.
        if not _aligned(layout.offset):
            return False

        return _is_tma_compatible(sizes, strides, dtype, allow_float32=True)

    # Input-side check for a single IR node.
    def _is_tma_compatible_matrix(m: IRNode) -> bool:
        sizes = m.get_size()
        strides = m.get_stride()
        dtype = m.get_dtype()

        # Base pointer 16-byte aligned
        if m.get_name() in V.graph.unaligned_buffers:
            return False

        return _is_tma_compatible(sizes, strides, dtype, allow_float32=False)

    # Shared shape/stride/dtype validation used for both inputs and the output.
    def _is_tma_compatible(
        sizes: Sequence[sympy.Expr],
        strides: Sequence[_IntLike],
        dtype: torch.dtype,
        allow_float32: bool,
    ) -> bool:
        rank = len(sizes)
        itemsize = dtype.itemsize

        # 2 ≤ rank ≤ 5
        if rank < 2 or rank > 5:
            return False

        # dtype ∈ {FP16, BF16, FP8-E4M3FN}, or FP32 when allow_float32 is set
        if dtype not in (torch.float16, torch.bfloat16, torch.float8_e4m3fn) and (
            not allow_float32 or dtype != torch.float32
        ):
            return False

        # Resolve sizes/strides either through guard_int_seq or through hints,
        # depending on the caller's add_guards flag.
        if add_guards:
            sizes_i = V.graph.sizevars.guard_int_seq(sizes)
            strides_i = V.graph.sizevars.guard_int_seq(strides)
        else:
            sizes_i = [V.graph.sizevars.symbolic_hint(s) for s in sizes]
            strides_i = [V.graph.sizevars.symbolic_hint(st) for st in strides]

        # Every logical size ≥ 2
        if any(not V.graph.sizevars.statically_known_geq(s, 2) for s in sizes_i):
            return False

        # Find the single contiguous (“inner”) dim; zero or several stride-1
        # dims are both rejected.
        inner = [
            i
            for i, st in enumerate(strides_i)
            if V.graph.sizevars.statically_known_equals(st, 1)
        ]
        if len(inner) != 1:
            return False
        inner_idx = inner[0]

        # All "outer" dims must have 16-byte aligned strides
        for i, st in enumerate(strides_i):
            if i == inner_idx:
                continue
            if not _aligned(st * itemsize):
                return False

        # Inner dim byte width must still be a multiple of 16 B
        inner_dim = sizes_i[inner_idx]
        if not _aligned(inner_dim * itemsize):
            return False

        # FP8 special case: inner ≥ 32
        if dtype == torch.float8_e4m3fn and not V.graph.sizevars.statically_known_geq(
            inner_dim, 32
        ):
            return False

        return True

    # Device capability first, then every input, then the optional output.
    return (
        has_triton_tma_device()
        and all(_is_tma_compatible_matrix(m) for m in matrices)
        and _is_tma_compatible_layout(output_layout)
    )
@functools.lru_cache(maxsize=1)
def ensure_cute_available() -> bool:
    """Report whether the ``cutlass.cute`` module (CuTeDSL) can be located.

    The answer is cached. After installing CuTeDSL in the same interpreter,
    call ``ensure_cute_available.cache_clear()`` so the lookup is retried.

    Returns:
        True when ``importlib.util.find_spec`` finds ``cutlass.cute``; False
        when it finds nothing or the lookup raises ``ImportError``.
    """
    try:
        cute_spec = importlib.util.find_spec("cutlass.cute")
    except ImportError:
        # Raised e.g. when the parent package itself cannot be imported.
        return False
    return cute_spec is not None
Bias and Scale are not provided + """ + if not ensure_cute_available(): + return False + + if not _use_autotune_backend("CUTEDSL"): + return False + + from .codegen.cuda.cuda_env import is_datacenter_blackwell_arch + + if not is_gpu(layout.device.type): + return False + + if not is_datacenter_blackwell_arch(): + return False + + layout_dtypes = [torch.bfloat16] + if not _use_template_for_gpu(layout, layout_dtypes): + return False + + if not (config.max_autotune or config.max_autotune_gemm): + return False + + # Checks for 16B ptr and stride alignment + if not can_use_tma(mat_a, mat_b, output_layout=layout): + return False + + if any(is_dynamic(x) for x in [mat_a, mat_b]): + return False + + if not a_is_2d or b_is_2d: + return False + + if offs is None: + return False + + if bias is not None or scale_result is not None: + return False + + return True + + +def use_cutlass_template(layout: Layout, m: int, n: int, k: int) -> bool: + from .virtualized import V + + gemm_size = V.graph.sizevars.size_hint(m * n * k, fallback=-1) + if gemm_size <= 0 or gemm_size < config.cuda.cutlass_backend_min_gemm_size: + return False + from .codegen.cuda.cutlass_utils import try_import_cutlass + + # Do not use cutlass template on ROCm + if torch.version.hip: + return False + + # output dtype + # FP32 not supported: https://github.com/pytorch/pytorch/issues/145952 + layout_dtypes = [torch.float16, torch.bfloat16, torch.int32] + res = ( + _use_template_for_gpu(layout, layout_dtypes) + and (config.max_autotune or config.max_autotune_gemm) + and _use_autotune_backend("CUTLASS") + ) + + if res: + if not try_import_cutlass(): + log.warning( + "Failed to import CUTLASS lib. Please check whether " + "_inductor.config.cuda.cutlass_dir %s is set correctly. 
" + "Skipping CUTLASS backend for now.", + config.cuda.cutlass_dir, + ) + return False + return res + + +def _use_cutlass_for_op(op_name: str) -> bool: + """Check if CUTLASS should be used for the given operation.""" + enabled_ops = config.cuda.cutlass_enabled_ops.upper() + if enabled_ops == "ALL": + return True + return op_name.upper() in [x.strip() for x in enabled_ops.split(",")] + + +_IntLike: TypeAlias = Union[int, sympy.Expr] + + +@functools.cache +def use_decompose_k_choice( + m: _IntLike, n: _IntLike, k: _IntLike, threshold_multiple: int = 1 +) -> bool: + from torch._inductor.virtualized import V + + decompose_k_threshold = config.triton.decompose_k_threshold * threshold_multiple + + return ( + not torch.version.hip + and V.graph.sizevars.statically_known_true( + sympy.And( + sympy.Ge(k, decompose_k_threshold * m), + sympy.Ge(k, decompose_k_threshold * n), + ) + ) + and not V.graph.aot_mode # TODO: Support AOTI for decomposeK + and not V.graph.cpp_wrapper + and config.triton.num_decompose_k_splits > 0 + ) + + +@functools.cache +def use_contiguous(m: _IntLike, n: _IntLike, k: _IntLike) -> bool: + """ + Check if we should use the contiguous subgraph transform. + This transform makes the second matrix contiguous before the matmul. 
+ """ + contiguous_threshold = config.rocm.contiguous_threshold + + # Similar conditions to decompose_k but for contiguous transform + from torch._inductor.virtualized import V + + return ( + bool(torch.version.hip) # Only relevant on AMD + and V.graph.sizevars.statically_known_true( + sympy.And( + sympy.Ge(k, contiguous_threshold * m), + sympy.Ge(k, contiguous_threshold * n), + ) + ) + and not V.graph.aot_mode + and not V.graph.cpp_wrapper + ) + + +@functools.cache +def get_k_splits(m: _IntLike, n: _IntLike, k: _IntLike) -> list[int]: + # To limit compile time + k_splits_limit = config.triton.num_decompose_k_splits + + # Hand-tuned + default_k_splits = [16, 32, 64, 128, 256] + # If k is a sympy expression, we can't do any splitting + if isinstance(k, sympy.Expr) and not k.is_number: + return default_k_splits + elif k_splits_limit == 0: + return [] + + if (isinstance(m, sympy.Expr) and not m.is_number) or ( + isinstance(n, sympy.Expr) and not n.is_number + ): + max_k_split = 256 + else: + max_k_split = min(k // m, k // n) + + min_k_split = 2 + # Get all divisors of k, k has to be divisible by kPart + divisors = sympy.divisors(k) + + divisors = [ + divisor + for divisor in divisors + if divisor <= max_k_split and divisor >= min_k_split + ] + + pow_of_2_divisors, mul_of_32_divisors, rest_of_splits = [], [], [] + + for d in divisors: + kPart = k // d + + # Smaller than 128 might not even fit in a single tile, BLOCK_K can be 128 + if kPart < 128: + continue + + # Power of 2 divisors are best performing, conform to hardware + if (kPart & kPart - 1) == 0 and kPart >= 128: + pow_of_2_divisors.append(d) + # Else check if creates a multiple of 32 + elif kPart % 32 == 0: + mul_of_32_divisors.append(d) + # otherwise, take the smallest values + else: + rest_of_splits.append(d) + + if config.max_autotune_gemm_search_space == "EXHAUSTIVE": + return pow_of_2_divisors + mul_of_32_divisors + rest_of_splits + + best_splits = pow_of_2_divisors + mul_of_32_divisors + rest_of_splits + 
# Otherwise, conform results to k_splits_limit + return best_splits[:k_splits_limit] + + +@functools.cache +def _rocm_native_device_arch_name(device: str) -> str: + return torch.cuda.get_device_properties(device).gcnArchName + + +@functools.cache +def try_import_ck_lib() -> tuple[ + Optional[str], Callable[[], list[Any]], Callable[[], list[Any]], type[Any] +]: + try: + import ck4inductor # type: ignore[import] + from ck4inductor.universal_gemm.gen_instances import ( # type: ignore[import] + gen_ops_library, + gen_ops_preselected, + ) + from ck4inductor.universal_gemm.op import ( # type: ignore[import] + CKGemmOperation, + ) + + package_dirname = os.path.dirname(ck4inductor.__file__) + except ImportError: + + def gen_ops_library() -> list[Any]: + return [] + + def gen_ops_preselected() -> list[Any]: + return [] + + class CKGemmOperation: # type: ignore[no-redef] + pass + + package_dirname = None + return package_dirname, gen_ops_library, gen_ops_preselected, CKGemmOperation + + +def use_ck_template(layout: Layout) -> bool: + # config knobs check 1 + if not (config.max_autotune or config.max_autotune_gemm): + return False + # platform check + if not torch.version.hip: + return False + # tensors must be on GPU + if layout.device.type != "cuda": + return False + # hardware check + # if config arch list is not specified, get the native arch from the device properties + native_arch = _rocm_native_device_arch_name(layout.device) + requested_archs = {k.split(":")[0]: k for k in config.rocm.arch} or { + native_arch.split(":")[0]: native_arch + } + requested_supported_archs = [ + requested_archs[k] + for k in requested_archs.keys() & config.rocm.ck_supported_arch + ] + if not requested_supported_archs: + return False + # supported input dtypes + if layout.dtype not in [torch.float16, torch.bfloat16, torch.float32]: + return False + + ck_package_dirname, _, _, _ = try_import_ck_lib() + + if not ck_package_dirname: + log.warning("Please pip install Composable Kernel package") 
+ return False + + config.rocm.ck_dir = ck_package_dirname + + return True + + +def use_ck_gemm_template(layout: Layout, m: int, n: int, k: int) -> bool: + from .virtualized import V + + return ( + _use_autotune_backend("CK") + and use_ck_template(layout) + and V.graph.sizevars.size_hint(m * n * k, fallback=-1) > 0 + ) + + +def use_ck_tile_gemm_template(layout: Layout, m: int, n: int, k: int) -> bool: + from .virtualized import V + + return ( + _use_autotune_backend("CKTILE") + and use_ck_template(layout) + and V.graph.sizevars.size_hint(m * n * k, fallback=-1) > 0 + ) + + +def use_ck_conv_template(layout: Layout) -> bool: + return _use_conv_autotune_backend("CK") and use_ck_template(layout) + + +def _use_template_for_cpu(layout: Layout) -> bool: + return ( + config.max_autotune or config.max_autotune_gemm + ) and layout.device.type == "cpu" + + +def use_cpp_bmm_template( + layout: Layout, mat1: Union[ReinterpretView, Buffer], mat2: IRNode +) -> bool: + from .ir import Layout + + assert isinstance(mat1.layout, Layout) + + # In certain scenarios, such as when the first stride is 0, the entire tensor may not be contiguous. + # But the 2D matrix within each batch can still be contiguous, allowing us to apply max autotune. + # So here we specifically check for contiguity within the 2D matrix of each batch. 
+ mat1_size = mat1.layout.size + mat1_stride = mat1.layout.stride + mat1_each_batch_is_contiguous = ( + _use_template_for_cpu(layout) + and mat1.get_dtype() == torch.float32 + and (len(mat1_size) == 3) + and (len(mat1_stride) == 3) + and (mat1_stride[1] == mat1_size[2]) + and (mat1_stride[2] == 1) + ) + return use_cpp_gemm_template(layout, mat1, mat2, require_constant_mat2=False) and ( + mat1.layout.is_contiguous() or mat1_each_batch_is_contiguous + ) + + +def use_cpp_gemm_template( + layout: Layout, + mat1: IRNode, + mat2: IRNode, + mat2_transposed: bool = False, + require_constant_mat2: bool = True, + is_woq_int4: bool = False, + q_group_size: Optional[int] = None, +) -> bool: + from . import ir + from .codegen.cpp_micro_gemm import create_micro_gemm + from .codegen.cpp_utils import get_gemm_template_output_and_compute_dtype + from .kernel.mm_common import mm_args + + if not _use_template_for_cpu(layout) or not _use_autotune_backend("CPP"): + return False + + if not config.cpp.weight_prepack: + return False + + int8_gemm = mat1.get_dtype() in [torch.uint8, torch.int8] + layout_dtypes = [torch.float32, torch.bfloat16, torch.half, torch.uint8] + m, n, k, layout, mat1, mat2 = mm_args( + mat1, + mat2, + out_dtype=layout.dtype if int8_gemm else None, + mat2_transposed=mat2_transposed, + use_4x2_dim=is_woq_int4, + ) + + # TODO(jgong5): support dynamic shapes for n or k + if has_free_symbols((n, k)): + return False + + if isinstance(mat2, ir.BaseView): + mat2 = mat2.unwrap_view() + + output_dtype, _ = get_gemm_template_output_and_compute_dtype(mat1.get_dtype()) + micro_gemm = create_micro_gemm( + "micro_gemm", + m, + n, + k, + input_dtype=mat1.get_dtype(), + input2_dtype=mat2.get_dtype(), + output_dtype=output_dtype, + num_threads=parallel_num_threads(), + use_ref=not is_woq_int4, + q_group_size=q_group_size, + ) + + def is_last_dim_stride1(x: IRNode) -> bool: + x.freeze_layout() + return x.get_stride()[-1] == 1 + + return ( + layout.dtype in layout_dtypes + and 
class DebugDirManager:
    """Context manager that temporarily redirects dynamo's debug directory.

    On entry, ``torch._dynamo.config.debug_dir_root`` is replaced by
    ``"<previous root>_tmp_<id>"``, where ``id`` comes from a class-wide
    counter. On exit that directory is removed and the previous root is
    restored.
    """

    # Class-wide counter handing out a distinct id to every manager instance.
    counter = itertools.count(0)
    # Debug root that was active when the context was entered.
    prev_debug_name: str

    def __init__(self) -> None:
        self.id = next(DebugDirManager.counter)

    def __enter__(self) -> None:
        self.prev_debug_name = torch._dynamo.config.debug_dir_root
        self.new_name = f"{self.prev_debug_name}_tmp_{self.id}"
        torch._dynamo.config.debug_dir_root = self.new_name

    def __exit__(self, *args: Any) -> None:
        try:
            shutil.rmtree(self.new_name)
        except FileNotFoundError:
            # Nothing was written under the temporary root, so there is
            # nothing to clean up; do not mask an exception from the body.
            pass
        finally:
            # Always restore the previous root, even if removal failed.
            torch._dynamo.config.debug_dir_root = self.prev_debug_name
compilation or running.""" + from .graph import GraphLowering + + source_codes: list[str] = [] + + def save_output_code(code: str) -> None: + source_codes.append(code) + + def patched_compile_to_module(self: GraphLowering) -> Any: + class DummyModule: + """This is empty to replace the generated triton module""" + + def __init__(self) -> None: + pass + + def call(self, *args: Any, **kwargs: Any) -> None: + # Don't do anything when called + pass + + wrapper_code, kernel_code = ( + self.codegen_with_cpp_wrapper() if self.cpp_wrapper else self.codegen() + ) + # Skip all the actual compiling. + save_output_code(wrapper_code.value) + if kernel_code: + save_output_code(kernel_code.value) + + return DummyModule() + + with ( + mock.patch.object( + GraphLowering, "compile_to_module", patched_compile_to_module + ), + mock.patch.object(GraphLowering, "save_output_code", save_output_code), + ): + torch._dynamo.reset() + # Note the return here is None + _ = fn(*args, **kwargs) + + return source_codes + + +def get_triton_code(fn: Callable[P, _T], *args: P.args, **kwargs: P.kwargs) -> str: + # pyrefly: ignore [bad-argument-type] + source_codes = get_code(fn, *args, **kwargs) + # Can have two outputs if backwards was eagerly compiled + assert 1 <= len(source_codes) <= 2, ( + f"expected one or two code outputs got {len(source_codes)}" + ) + return source_codes[0] + + +def run_and_get_triton_code( + fn: Callable[P, _T], *args: P.args, **kwargs: P.kwargs +) -> str: + # pyrefly: ignore [bad-argument-type] + _, source_codes = run_and_get_code(fn, *args, **kwargs) + # Can have two outputs if backwards was eagerly compiled + assert 1 <= len(source_codes) <= 2, ( + f"expected one or two code outputs got {len(source_codes)}" + ) + return source_codes[0] + + +def run_and_get_graph_lowering( + fn: Callable[P, _T], *args: P.args, **kwargs: P.kwargs +) -> tuple[Any, list[GraphLowering]]: + from torch._inductor.graph import GraphLowering + from torch._inductor.output_code import CompiledFxGraph 
def get_benchmark_name() -> Optional[str]:
    """
    An experimental API used only when config.benchmark_kernel is true.

    The benchmark name is only available at codegen time, so it cannot be
    queried from benchmark_all_kernels, which runs after codegen.

    The name is taken from the command line: it is assumed to be the value of
    the ``--only`` option. This works for torchbench.py/huggingface.py/
    timm_models.py; for ad-hoc scripts the function may return None.

    Two spellings are recognised:
      1. ``--only model_name`` -- the following argument is used when it is
         non-empty and does not start with ``-``
      2. ``--only=model_name``

    Returns:
        The benchmark name, or None when neither spelling yields one.
    """
    argv = sys.argv

    # Flavor 1: "--only" followed by a separate, non-option argument.
    if "--only" in argv:
        value_pos = argv.index("--only") + 1
        if value_pos < len(argv):
            candidate = argv[value_pos]
            if candidate and not candidate.startswith("-"):
                return candidate

    # Flavor 2: "--only=<name>"; the first such argument wins.
    marker = "--only="
    return next((arg[len(marker) :] for arg in argv if arg.startswith(marker)), None)
def get_max_numwarps() -> int:
    """Return the maximum number of warps per block.

    Computed as ``max_threads_per_block // warp_size`` from the CUDA device
    properties when CUDA is available; otherwise from the defaults of 1024
    threads per block and a warp size of 32.
    """
    if not torch.cuda.is_available():
        # Defaults
        default_max_threads_per_block = 1024
        default_warp_size = 32
        return default_max_threads_per_block // default_warp_size

    props = torch.cuda.get_device_properties()
    return props.max_threads_per_block // props.warp_size
def pass_execution_and_save(
    func: Callable[..., Any], gm: GraphModule, inp: Sequence[Any], msg: str
) -> None:
    """Run a graph pass on ``gm`` and log a before/after comparison.

    Args:
        func: The pass; it is called with ``gm.graph``.
        gm: Graph module the pass is applied to; it is re-sorted, linted and
            recompiled afterwards.
        inp: Inputs used for shape propagation and fake-mode detection.
        msg: Label passed to ``GraphTransformObserver`` and used in the log line.

    The printed form of the graph before and after the pass is written to a
    temporary file, and an info-level log records the file name, whether the
    two printed forms are identical, and the elapsed time of the pass.
    """
    from .pattern_matcher import stable_topological_sort

    # NOTE(review): no `delete=False` is passed, so the temporary file is
    # presumably removed when this `with` block exits; the path logged below
    # may not outlive the call -- confirm this is intended.
    with tempfile.NamedTemporaryFile(
        mode="w",
        encoding="utf-8",
    ) as f:
        before_io = io.StringIO()
        after_io = io.StringIO()
        # Propagate shapes through the graph before taking the "before" snapshot.
        ShapeProp(gm=gm, fake_mode=detect_fake_mode(inp)).propagate(*inp)
        print(f"Before:\n{gm.graph}", file=f)
        print(gm.graph, file=before_io)
        # Only the pass itself (under the observer) is timed.
        start_time = datetime.now()
        with GraphTransformObserver(gm, msg):
            func(gm.graph)
        time_elapsed = datetime.now() - start_time
        # recompile graph
        stable_topological_sort(gm.graph)
        gm.graph.lint()
        gm.recompile()

        print(f"After:\n{gm.graph}", file=f)
        print(gm.graph, file=after_io)
        # Compare the printed graphs to tell whether the pass changed anything.
        t = before_io.getvalue() == after_io.getvalue()
        log.info(
            "%s, save before/after graph to %s, graph before/after are the same = %s, time elapsed = %s",
            msg,
            f.name,
            t,
            time_elapsed,
        )
def contains_collective(
    snode: BaseSchedulerNode,
    filter_fn: Optional[Callable[[BaseSchedulerNode], bool]] = None,
) -> bool:
    """
    Return True if ``snode`` is, or groups, a collective node accepted by ``filter_fn``.

    Args:
        snode: scheduler node to inspect. A GroupedSchedulerNode is searched
            member by member.
        filter_fn: optional extra predicate; when given, a collective node only
            counts if ``filter_fn(node)`` is truthy.
    """
    from torch._inductor.scheduler import GroupedSchedulerNode

    if isinstance(snode, GroupedSchedulerNode):
        # Forward the predicate so that grouped members are filtered exactly
        # like standalone nodes instead of silently ignoring filter_fn.
        return any(contains_collective(x, filter_fn) for x in snode.snodes)

    return is_collective(snode.node) and (filter_fn is None or filter_fn(snode))
def find_recursive_users_of_node(
    snode: BaseSchedulerNode,
    collected_node_set: MutableSet[BaseSchedulerNode],
    name_to_buf: dict[str, SchedulerBuffer],
    name_to_fused_node: dict[str, BaseSchedulerNode],
    criteria_cb: Callable[[Any], bool] = lambda snode: False,
) -> None:
    """
    Add ``snode`` and every node reachable through its output users to
    ``collected_node_set``.

    Traversal stops at any node for which ``criteria_cb`` returns True; such a
    node is not collected. Users named "OUTPUT" and users whose name is missing
    from ``name_to_fused_node`` are skipped.
    """
    if criteria_cb(snode):
        return
    collected_node_set.add(snode)

    for output_buf in snode.get_outputs():
        for consumer in output_buf.users:
            assert consumer.node is not None
            consumer_name = consumer.node.get_name()
            if consumer_name == "OUTPUT" or consumer_name not in name_to_fused_node:
                continue
            fused_consumer = name_to_fused_node[consumer_name]
            # Already-visited nodes are not descended into again.
            if fused_consumer not in collected_node_set:
                find_recursive_users_of_node(
                    fused_consumer,
                    collected_node_set,
                    name_to_buf,
                    name_to_fused_node,
                    criteria_cb=criteria_cb,
                )
def count_tangents(fx_g: torch.fx.GraphModule) -> int:
    """
    Count the placeholders of a backward graph that are treated as static inputs.

    A placeholder is treated as static (a saved tensor) unless its name contains
    one of the markers "tangents", "bwd_seed", "bwd_base_offset" or
    "bwd_rng_state". The static placeholders must come first among all
    placeholders; this is asserted.
    """
    non_saved_markers = ("tangents", "bwd_seed", "bwd_base_offset", "bwd_rng_state")

    def is_saved_tensor(x: Node) -> bool:
        return not any(marker in x.name for marker in non_saved_markers)

    placeholders = [n for n in fx_g.graph.nodes if n.op == "placeholder"]
    static_arg_idxs = [
        position
        for position, node in enumerate(placeholders)
        if is_saved_tensor(node)
    ]

    # Static inputs must occupy positions 0..k-1 with no gaps.
    assert static_arg_idxs == list(range(len(static_arg_idxs)))
    return len(static_arg_idxs)
def needs_fallback_due_to_atomic_add_limitations(dtype: torch.dtype) -> bool:
    """
    Decide whether ``dtype`` needs the fallback path because of atomic-add limits.

    For bfloat16 the answer depends on the runtime: with CUDA available it is
    True when the device capability is below (9, 0); otherwise, with XPU
    available, it is True. In every other case the answer is True only for
    int64 and bool.
    """
    if dtype == torch.bfloat16:
        # CUDA is consulted before XPU.
        if torch.cuda.is_available():
            return torch.cuda.get_device_capability() < (9, 0)
        if torch.xpu.is_available():
            return True
    return dtype in (torch.int64, torch.bool)
def tensor_is_aligned(tensor: torch.Tensor) -> bool:
    """
    Return True when the tensor's storage offset, measured in bytes, is
    statically known to be a multiple of GPU_ALIGN_BYTES.
    """
    # See Note: [Input Alignment handling in Inductor]
    # Guards on the alignment of the storage offset are deliberately not
    # created. At the time the original comment was written, non-symbolic
    # storage_offsets were not guarded on while symbolic ones were; evaluating
    # the check through statically_known_true keeps this consistent and avoids
    # introducing recompiles.
    from torch.fx.experimental.symbolic_shapes import statically_known_true

    offset_in_bytes = tensor.storage_offset() * get_dtype_size(tensor.dtype)
    return statically_known_true(offset_in_bytes % GPU_ALIGN_BYTES == 0)
def run_and_get_cpp_code(
    fn: Callable[P, _T], *args: P.args, **kwargs: P.kwargs
) -> tuple[_T, str]:
    """
    Call ``fn(*args, **kwargs)`` with ``config.debug`` enabled and capture
    everything logged to ``output_code_log`` while it runs.

    Returns:
        A ``(result, captured_log_text)`` tuple.

    Any exception raised by ``fn`` propagates; the logger's handler and level
    are restored in that case as well.
    """
    # We use the patch context manager instead of using it as a decorator.
    # In this way, we can ensure that the attribute is patched and unpatched correctly
    # even if this run_and_get_cpp_code function is called multiple times.
    with unittest.mock.patch.object(config, "debug", True):
        torch._dynamo.reset()
        import io
        import logging

        log_capture_string = io.StringIO()
        ch = logging.StreamHandler(log_capture_string)
        from torch._inductor.codecache import output_code_log

        output_code_log.addHandler(ch)
        prev_level = output_code_log.level
        output_code_log.setLevel(logging.DEBUG)
        try:
            result = fn(*args, **kwargs)
            s = log_capture_string.getvalue()
        finally:
            # Always undo the logger changes; otherwise a failing fn would leave
            # the capture handler attached and the level stuck at DEBUG.
            output_code_log.setLevel(prev_level)
            output_code_log.removeHandler(ch)
    return result, s
def clone_preserve_strides(x: torch.Tensor) -> torch.Tensor:
    """
    Clone ``x`` into fresh memory while keeping its exact size and strides.

    The elements spanned by ``x`` are copied as one flat 1-D buffer, which is
    then re-viewed with the original size and stride.
    """
    dims = x.size()
    steps = x.stride()
    # A tensor with a zero-length dimension spans no elements at all; otherwise
    # the span runs from the first element to the farthest addressable one.
    span = (
        0
        if 0 in dims
        else sum((extent - 1) * step for extent, step in zip(dims, steps)) + 1
    )
    flat_copy = torch.as_strided(x, (span,), (1,)).clone()
    return torch.as_strided(flat_copy, dims, steps)
def remove_unaligned_input_idxs(
    inputs: Sequence[InputType],
    static_input_idxs: Sequence[int],
) -> Sequence[int]:
    """
    Keep only the static input indices whose input is a tensor with a data
    pointer that is a multiple of ALIGNMENT.

    Returns the original ``static_input_idxs`` object when nothing was dropped,
    otherwise a new list containing the surviving indices in their original
    order.
    """
    kept_idxs = [
        idx
        for idx in static_input_idxs
        if isinstance(inputs[idx], torch.Tensor)
        and (inputs[idx].data_ptr() % ALIGNMENT) == 0
    ]
    if len(kept_idxs) == len(static_input_idxs):
        return static_input_idxs
    return kept_idxs
def set_tracing_context_output_strides(
    example_inputs: Sequence[Any], compiled_graph: CompiledFxGraph
) -> None:
    """
    Hand the compiled graph's output strides back to the caller through the
    active TracingContext.

    Nothing happens when there is no tracing context or when its
    ``output_strides`` list is None. Otherwise one entry is appended per graph
    output: None for outputs without stride information, else a tuple of
    evaluated (or, when ``fakify_first_call`` is set, deserialized) stride
    expressions.
    """
    # Return the output strides to the caller via TracingContext
    context = torch._guards.TracingContext.try_get()
    if context is None or context.output_strides is None:
        return

    assert len(context.output_strides) == 0
    shape_env = shape_env_from_inputs(example_inputs)
    assert compiled_graph.output_strides is not None

    # Both the flag and the mapping helper are loop-invariant, so they are
    # computed once instead of once per output.
    fakify_first_call = context.fakify_first_call if context else False

    def map_expr(e: Any) -> Union[float, int, SymInt, SymFloat, SymBool]:
        if shape_env is None:
            return int(e)
        if fakify_first_call:
            return shape_env.deserialize_symexpr(e)
        return shape_env.evaluate_symexpr(e)

    for exprs in compiled_graph.output_strides:
        if exprs is None:
            context.output_strides.append(None)
        else:
            context.output_strides.append(
                tuple(map_expr(e) for e in exprs)  # type: ignore[misc]
            )
# correct cases where Triton types names don't match PyTorch
_triton_type_mapping = {
    "tl.bool": "tl.int1",
    "tl.float8_e4m3fn": "tl.float8e4nv",
    "tl.float8_e5m2": "tl.float8e5",
    "tl.float8_e4m3fnuz": "tl.float8e4b8",
    "tl.float8_e5m2fnuz": "tl.float8e5b16",
    # TODO: remove when support is added in triton
    # https://github.com/triton-lang/triton/issues/6054
    "tl.float8_e8m0fnu": "tl.uint8",
    "tl.float4_e2m1fn_x2": "tl.uint8",
}
# Reverse lookup (Triton name -> torch-style name). "tl.uint8" is excluded: it is
# only a stand-in for several torch dtypes above, so inverting it would make a
# plain "tl.uint8" resolve to whichever stand-in entry happens to come last
# instead of to torch.uint8.
_torch_triton_mapping = {
    v: k for k, v in _triton_type_mapping.items() if v != "tl.uint8"
}


_triton_type_re = re.compile(r"^.*[.]")


def triton_type(dtype: torch.dtype) -> str:
    """Convert torch.dtype to triton type"""
    # str(dtype) looks like "torch.float32"; everything up to the last dot is
    # replaced by the "tl." prefix before the name corrections are applied.
    triton_type_name = _triton_type_re.sub("tl.", str(dtype))
    return _triton_type_mapping.get(triton_type_name, triton_type_name)


def triton_type_to_torch(dtype: str) -> torch.dtype:
    """
    Convert a Triton type name such as "tl.int1" to the matching torch.dtype.

    Names without a special-case entry are looked up on ``torch`` directly after
    the "tl." prefix is removed, so "tl.uint8" yields ``torch.uint8``.

    Raises:
        AttributeError: if ``torch`` has no attribute of the resulting name.
        AssertionError: if that attribute is not a ``torch.dtype``.
    """
    adjusted_type = _torch_triton_mapping.get(dtype, dtype)
    type_name = adjusted_type.replace("tl.", "")
    out_dtype = getattr(torch, type_name)
    assert isinstance(out_dtype, torch.dtype)
    return out_dtype
KeyType = TypeVar("KeyType")
ValType = TypeVar("ValType")


class ScopedDict(MutableMapping[KeyType, ValType]):
    """
    Mapping that layers scoped writes on top of an existing dictionary.

    Reads consult the overlay (``new_items``) first and fall back to
    ``original_dict``. Writes only ever touch the overlay, so the wrapped
    dictionary is never modified. Deleting keys is not supported.
    """

    def __init__(self, original_dict: Mapping[KeyType, ValType]):
        self.original_dict = original_dict
        self.new_items: dict[KeyType, ValType] = {}

    def __getitem__(self, key: KeyType) -> ValType:
        overlay = self.new_items
        return overlay[key] if key in overlay else self.original_dict[key]

    def __setitem__(self, key: KeyType, value: ValType) -> None:
        self.new_items[key] = value

    def __contains__(self, key: object) -> bool:
        if key in self.new_items:
            return True
        return key in self.original_dict

    def get(self, key: KeyType, default: Optional[ValType] = None) -> Optional[ValType]:  # type: ignore[override]
        overlay = self.new_items
        if key not in overlay:
            return self.original_dict.get(key, default)
        return overlay[key]

    def __len__(self) -> int:
        # Overlay keys that shadow an original key must not be counted twice.
        base = self.original_dict
        return len(base) + sum(1 for k in self.new_items if k not in base)

    def __iter__(self) -> Iterator[KeyType]:
        # Original keys come first, followed by keys that exist only in the overlay.
        base = self.original_dict
        yield from base
        yield from (k for k in self.new_items if k not in base)

    def __bool__(self) -> bool:
        return True if self.original_dict else bool(self.new_items)

    def __delitem__(self, key: KeyType) -> None:
        raise NotImplementedError
@functools.cache
def get_triton_attrs_descriptor_version() -> TritonAttrsDescriptorVersion:
    """
    Detect which AttrsDescriptor flavour the installed Triton exposes.

    The answer is derived from whether the ``triton`` package can be found and
    from which of its modules, if any, has an ``AttrsDescriptor`` attribute.
    The result is cached after the first call.
    """
    if importlib.util.find_spec("triton") is None:
        return TritonAttrsDescriptorVersion.V0_NO_TRITON

    import triton.backends.compiler
    import triton.compiler.compiler

    backends_module = triton.backends.compiler
    compiler_module = triton.compiler.compiler

    if hasattr(backends_module, "AttrsDescriptor"):
        # Triton 3.2.0
        # AttrsDescriptor was moved from triton.compiler.compiler to triton.backends.compiler.
        # AttrsDescriptor and its serialization format were also changed.

        # TODO: implement V3_BACKENDS_TUPLE
        # On Dec 9, 2024, tuple support (triton #5220) was implemented and breaks handling.
        # We don't have a way to detect this (and haven't implemented this version)
        return TritonAttrsDescriptorVersion.V2_BACKENDS

    if hasattr(compiler_module, "AttrsDescriptor"):
        # Triton 3.0.0
        return TritonAttrsDescriptorVersion.V1_COMPILER

    # After Jan 1, 2025
    # AttrsDescriptor was removed and replaced with a raw dict.
    return TritonAttrsDescriptorVersion.V4_DICT
def is_cudagraph_unsafe_fx_node(fx_node: torch.fx.Node) -> bool:
    """
    Check if an FX node is cudagraph-unsafe.

    A node is unsafe when any of the following holds:
    - its target, as a string, is listed in FORBIDDEN_CUDAGRAPH_OPS
    - its target is an OpOverload carrying the cudagraph_unsafe tag
    - it is unsafe because of its inputs (e.g., index_put with boolean indices)
    - its "val" metadata is, or contains, a sparse tensor
    """
    target = fx_node.target

    # Forbidden-op list is keyed by the string form of the target.
    if str(target) in FORBIDDEN_CUDAGRAPH_OPS:
        return True

    # Explicit cudagraph_unsafe tag on the overload.
    is_overload = isinstance(target, torch._ops.OpOverload)
    if is_overload and torch._C.Tag.cudagraph_unsafe in target.tags:  # type: ignore[attr-defined]
        return True

    # Unsafety that depends on the concrete arguments.
    if _fx_node_is_input_dependent_cudagraph_unsafe(fx_node):
        return True

    # Sparse tensor outputs.
    val = fx_node.meta.get("val")
    if val is None:
        return False
    outputs = val if isinstance(val, (list, tuple)) else [val]
    return any(isinstance(v, torch.Tensor) and v.is_sparse for v in outputs)
def get_ld_library_path() -> str:
    """
    Return the library search path taken from the LD_LIBRARY_PATH environment
    variable ("" when unset).

    In fbcode builds that report a runtime path, ``<runtime_path>/runtime/lib``
    is placed in front of the inherited value.
    """
    inherited = os.environ.get("LD_LIBRARY_PATH", "")
    if not config.is_fbcode():
        return inherited

    from libfb.py.parutil import get_runtime_path

    runtime_path = get_runtime_path()
    if not runtime_path:
        return inherited

    lib_path = os.path.join(runtime_path, "runtime", "lib")
    if not inherited:
        return lib_path
    return os.pathsep.join([lib_path, inherited])
def tabulate_2d(elements: Sequence[Sequence[T]], headers: Sequence[T]) -> str:
    """
    Render ``headers`` and the rows in ``elements`` as a plain-text table.

    Every column is as wide as the longest ``str()`` of its cells, cells are
    padded with one space on each side and separated by "|", and a line of
    dashes separates the header from the rows. Each row must have exactly as
    many cells as there are headers.
    """
    col_widths = [len(str(h)) for h in headers]
    for row in elements:
        assert len(row) == len(headers)
        col_widths = [
            max(width, len(str(cell))) for width, cell in zip(col_widths, row)
        ]

    def render(cells: Sequence[T]) -> str:
        return "|".join(f" {cell:{width}} " for cell, width in zip(cells, col_widths))

    # Column widths, plus two padding spaces per column, plus one "|" between
    # each pair of neighbouring columns.
    rule = "-" * (sum(col_widths) + 3 * len(col_widths) - 1)
    return "\n".join([render(headers), rule, *(render(row) for row in elements)])
def maybe_aoti_standalone_config(config_patches: dict[str, Any]) -> dict[str, Any]:
    """
    Ensures the configuration is internally consistent for standalone AOTInductor.

    If `aot_inductor_mode.compile_standalone` is set to True in the provided
    `config_patches` (or falls back to the global config), the returned patches
    are adjusted as follows:
    - filled in when the effective value is None, rejected when the effective
      value is falsy and differs from the required one:
        - `aot_inductor.package_cpp_only` -> True
        - `aot_inductor.embed_kernel_binary` -> True
        - `aot_inductor.emit_multi_arch_kernel` -> `not torch.version.hip`
        - `aot_inductor.model_name_for_generated_files` -> "aoti_model"
    - always overridden (a warning is logged when the value changes):
        - `aot_inductor.link_libtorch` -> `config.test_configs.use_libtorch`
        - `aot_inductor.dynamic_linkage` -> False

    Independently of standalone mode, Windows cross-compilation combined with
    `aot_inductor.package_constants_in_so` is rejected.

    Args:
        config_patches (dict[str, Any]): A dictionary of user-provided config
            overrides for AOTInductor compilation. It is not modified.

    Returns:
        dict[str, Any]: A copy of `config_patches`, possibly updated.

    Raises:
        RuntimeError: if a config conflicts with standalone mode, or if
            `package_constants_in_so` is used for Windows cross-compilation.
    """

    def patch_config(
        config_patches: dict[str, Any], config_name: str, config_value: Any
    ) -> None:
        # Effective value: the patch if present, otherwise the global config.
        value = config_patches.get(config_name, getattr(config, config_name))
        if value is None:
            config_patches[config_name] = config_value
        elif not value and value != config_value:
            # Report the conflicting value that is actually in effect, not the
            # value standalone mode would have required.
            raise RuntimeError(
                f"Invalid config: {config_name}={value} when aot_inductor_mode.compile_standalone is True."
            )

    def force_patch_config(
        config_patches: dict[str, Any], config_name: str, config_value: Any
    ) -> None:
        value = config_patches.get(config_name, getattr(config, config_name))
        if value != config_value:
            log.warning(
                "Overriding: %s=%s when aot_inductor_mode.compile_standalone is True.",
                config_name,
                config_value,
            )
            config_patches[config_name] = config_value

    compile_standalone = config_patches.get(
        "aot_inductor_mode.compile_standalone",
        config.aot_inductor_mode.compile_standalone,
    )
    # Make a copy of the config_patches to avoid modifying the original dictionary, needed for testing
    config_patches = config_patches.copy()
    if compile_standalone:
        # Standalone AOTInductor means only generate cpp project for building a standalone binary
        patch_config(config_patches, "aot_inductor.package_cpp_only", True)
        # Standalone AOTInductor needs to embed the kernel code in the binary
        patch_config(config_patches, "aot_inductor.embed_kernel_binary", True)
        # Default to use multi-arch kernel codegen for non-rocm GPU
        patch_config(
            config_patches, "aot_inductor.emit_multi_arch_kernel", not torch.version.hip
        )
        patch_config(
            config_patches, "aot_inductor.model_name_for_generated_files", "aoti_model"
        )
        # TODO: change these two configs to default to None and use patch_config
        force_patch_config(
            config_patches,
            "aot_inductor.link_libtorch",
            config.test_configs.use_libtorch,
        )
        force_patch_config(config_patches, "aot_inductor.dynamic_linkage", False)

    cross_target_platform = config_patches.get(
        "aot_inductor.cross_target_platform",
        config.aot_inductor.cross_target_platform,
    )

    package_constants_in_so = config_patches.get(
        "aot_inductor.package_constants_in_so",
        config.aot_inductor.package_constants_in_so,
    )

    if cross_target_platform == "windows" and package_constants_in_so:
        raise RuntimeError(
            "config.aot_inductor.package_constants_in_so is not supported for windows cross-compilation. "
            "Please use config.aot_inductor.package_constants_on_disk_format = binary_blob."
        )

    return config_patches
+ + """ + from torch._inductor import config + + model_name = config.aot_inductor.model_name_for_generated_files + + if model_name is None: + return True + + if not isinstance(model_name, str): + raise ValueError("Invalid AOTI model name: Model name must be a string") + + if model_name == "": + return True + + # Can only contain alphanumeric characters and underscores + if not re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", model_name): + raise ValueError( + "Invalid AOTI model name: Model name can only contain letters, numbers, and underscores" + ) + + return True + + +def get_free_symbols(x: IterateExprs, unbacked_only: bool) -> OrderedSet[sympy.Symbol]: + if unbacked_only: + return free_unbacked_symbols(x) + else: + return free_symbols(x) + + +def maybe_log_cudagraph_partition( + msg: str, + prefix: Optional[str] = "cudagraph partition due to ", + node: Optional[BaseSchedulerNode] = None, +) -> None: + """ + Cudagraph partition may lead to extra memory overhead so we + log partition reasons to help users understand the overhead. + """ + if not config.triton.cudagraphs: + return + + warning_msg = f"{prefix}{msg}" + + if ( + node + and (ir_node := node.node) + and (fx_node := ir_node.get_origin_node()) + and (stack_trace := fx_node.meta.get("stack_trace", None)) + ): + warning_msg = f"{warning_msg}. Found from : \n {stack_trace}" + + perf_hint_log.warning(warning_msg) + + +def python_subprocess_env() -> dict[str, str]: + """ + Get a base environment for running Python subprocesses. + """ + + env = { + # Inherit the environment of the current process. + **os.environ, + # Set the PYTHONPATH so the subprocess can find torch. + "PYTHONPATH": os.environ.get( + "TORCH_CUSTOM_PYTHONPATH", os.pathsep.join(sys.path) + ), + } + + # Set PYTHONHOME for internal builds, to account for builds that bundle the + # runtime. Otherwise they will use the libraries and headers from the + # platform runtime instead. + # + # This can't be done for external builds. 
The process can be run from a + # venv and that won't include Python headers. The process needs to be able + # to search for and find the platform runtime. + if config.is_fbcode(): + env["PYTHONHOME"] = sysconfig.get_path("data") + + return env + + +@dataclasses.dataclass(frozen=True) +class CUDAGraphWrapperMetadata: + """ + Metadata for Customized CUDAGraphWrapper. + + Currently assumes there is 1 dynamo graph and will extend to + multiple graphs in the future. + """ + + # The number of partitions that are cudagraphable. + num_partitions: int + + # Index of the current partition. + partition_index: int + + +PartitionFnType = Callable[..., Any] +CUDAGraphWrapperType = Callable[ + [PartitionFnType, CUDAGraphWrapperMetadata], PartitionFnType +] + + +# only incremented by user call of mark_step_begin +class CUDAGraphWrapper: + wrapper: Optional[CUDAGraphWrapperType] = None + + +# A customized partition wrappers from users. Interface should be: +# +# def wrapper(fn: PartitionFnType, metadata: CUDAGraphWrapperMetadata) -> PartitionFnType +# +# Inductor generates N wrapper functions for N partition functions, and mechanically wrap +# each partition fn with the generated wrapper function. Users need to handle all details +# such as static inputs, dynamic shapes, etc. +# Users could customize the wrapper based on the metadata. One example is to have special +# handle for the first and last wrapper function. +# +# Warning: This API is unstable and may change in the future. 
+_unstable_customized_partition_wrapper = CUDAGraphWrapper() + + +def set_customized_partition_wrappers(wrapper: CUDAGraphWrapperType) -> None: + _unstable_customized_partition_wrapper.wrapper = wrapper + + +def snode_args_kwargs(snode: BaseSchedulerNode) -> tuple[list[Any], dict[str, Any]]: + args = snode.node.inputs # type: ignore[union-attr] + args = snode.node.fill_non_provided_args( # type: ignore[union-attr] + [*args, *snode.node.constant_args], # type: ignore[union-attr] + snode.node.kwargs, # type: ignore[union-attr] + ) + kwargs = snode.node.kwargs # type: ignore[union-attr] + flat_args, flat_args_pytree_spec = pytree.tree_flatten((args, kwargs)) + + def _is_tensor_ir(x) -> bool: # type: ignore[no-untyped-def] + return isinstance(x, torch._inductor.ir.IRNode) and not isinstance( + x, torch._inductor.ir.GeneratorState + ) + + flat_args = [ + torch._inductor.ir.ir_node_to_tensor(a, guard_shape=False) + if _is_tensor_ir(a) + else a + for a in flat_args + ] + + def _tensor(size, dtype, device) -> torch.Tensor: # type: ignore[no-untyped-def] + return torch.empty(size, dtype=dtype, device=device) + + def to_real_tensor(e: Any) -> Any: + if not isinstance(e, torch.Tensor): + return e + out = _tensor(e.size(), e.dtype, e.device) + return out + + flat_args = [to_real_tensor(a) for a in flat_args] + args, kwargs = pytree.tree_unflatten(flat_args, flat_args_pytree_spec) + return args, kwargs + + +def is_nonfreeable_buffers(dep: Dep) -> bool: + from .virtualized import V + + dep_name = dep.name + # Subgraphs have a prefix for the name, cleanup the prefix + # before checking for known strings. 
+ if V.graph.name: + dep_name = dep_name.removeprefix(V.graph.name + "_") + return dep_name.startswith( + ("primals_", "arg", "fwd_rng_state", "bwd_rng_state", "tangents") + ) + + +# Make sure to also include your jinja templates within torch_package_data in setup.py, or this function won't be able to find them +def load_template(name: str, template_dir: Path) -> str: + """Load a template file and return its content.""" + with open(template_dir / f"{name}.py.jinja") as f: + return f.read() + + +def should_fallback_by_default(node: torch.fx.Node) -> bool: + """Decide whether fallback for a node. This is only used in inductor lite mode.""" + target = node.target + + assert isinstance( + target, (torch._ops.OpOverload, torch._ops.HigherOrderOperator) + ), f"Expected OpOverload or HigherOrderOperator, but found {type(target)}" + + if not config.fallback_by_default: + return False + + # some ops need special handle due to dynamic shapes. we can avoid + # fallback if they do not impact numerics. + skip_fallback_due_to_dynamic_shape = OrderedSet( + [ + torch.ops.aten._assert_scalar.default, + torch.ops.aten.lift_fresh_copy.default, + ] + ) + + if target in skip_fallback_due_to_dynamic_shape: + return False + + # Most hops have registered lowering. We should follow the lowering and not fallback. + # However, in rare cases, hops may not register lowering, such as + # torch.ops.higher_order.triton_kernel_wrapper_functional. We should fallback for + # these hops. 
+ fallback_hops = OrderedSet( + [torch.ops.higher_order.triton_kernel_wrapper_functional] + ) + + if isinstance(target, torch._ops.HigherOrderOperator): + return target in fallback_hops + + return not _needs_inductor_compile(node) + + +# Collective operation names for specialized benchmarking +COLLECTIVE_OPS = OrderedSet( + [ + "torch.ops._c10d_functional.all_reduce.default", + "torch.ops._c10d_functional.all_reduce_.default", + "torch.ops._c10d_functional.all_gather_into_tensor.default", + "torch.ops._c10d_functional.reduce_scatter_tensor.default", + "torch.ops._c10d_functional.all_to_all_single.default", + "torch.ops._c10d_functional_autograd.all_reduce.default", + "torch.ops._c10d_functional_autograd.all_gather_into_tensor.default", + "torch.ops._c10d_functional_autograd.reduce_scatter_tensor.default", + "torch.ops._c10d_functional_autograd.all_to_all_single.default", + ] +) + + +def is_collective_op(op_name: str) -> bool: + """Check if an operation is a collective operation.""" + return op_name in COLLECTIVE_OPS diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/virtualized.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/virtualized.py new file mode 100644 index 0000000000000000000000000000000000000000..f45e372e2b3a3d9adfeba23d5fb80a26b20cbe7c --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/virtualized.py @@ -0,0 +1,448 @@ +# mypy: allow-untyped-defs +""" +This file provides a number of "global" variables/handlers that are actually +thread local and dynamically scoped, with Inductor patching them to various +implementations depending on the situation. + +These handlers are interacted with in a fairly stylized way. 
Typically, +we will import V from this module:: + + from .virtualized import V + +Various handlers are accessible as attributes on this module; for example, +you might access ``V.graph.sizevars.size_hint`` to resolve a size hint associated with +a number. + +There are a few distinct usage patterns for virtualized global variables: + +1. Implicit argument passing. Examples: ``V.current_node``, ``V.aot_compilation``. + Use ``V.set_current_node`` to change what the current node is while we're + executing some region of code, so code inside that region can query ``V.current_node`` + to find out what it is. This is often more convenient than manually threading + the current node as an argument through all call stacks. + +2. Per-compilation global state. Examples: ``V.fake_mode``, ``V.graph``. For a + given ``compile_fx`` invocation, these typically don't change, but they are + associated with some internal state so they cannot just be global functions. + We install these objects at the beginning of compilation and then you can + conveniently access them without having to pass them around. + +3. Alternate define-by-run interpretations. Examples: ``V.ops``, ``V.kernel``. + A commonly used IR in Inductor is define-by-run: instead of maintaining + explicit syntax data structures, we instead represent loop bodies as + callable functions, which internally invoke operations defined on + ``V.ops``. To perform semantic analysis, print or code generate these + operations, we dynamically patch ``V.ops`` with an alternate handler with + the intended semantics and then run the callable function. For example, to + extract out a traditional (FX) graph representation of the define-by-run + IR, simply install a handler that records each ``ops`` call to a graph. + + TODO: Define a parent class / protocol that defines all of the operations + V.ops is expected to support. 
+ +It is typically an error to access a virtualized global without having installed +an appropriate handler (you will get a NullHandler), although in some cases we +provide a default implementation. + +One last thing: although most virtualized globals are accessed via ``V``, ``ops`` is +ubiquitous enough to have its own top level variable, so you will typically see +``ops.constant(...)`` rather than ``V.ops.constant(...)``. In fact, these are not +equivalent; the former interface supports arithmetic overloads like ``x + y`` +instead of forcing ``ops.add(x, y)``, so it should be preferred. + +Some operators are seemingly unused, but they are implicitly used by ops_wrapper. +In particular, we typically have an operator for every basic pointwise PyTorch operation +supported. +""" + +from __future__ import annotations + +from contextlib import AbstractContextManager, contextmanager +from threading import local +from typing import Any, cast, Generic, TYPE_CHECKING, TypeVar, Union + +from torch.utils._ordered_set import OrderedSet + +from .ops_handler import ( # noqa: F401 + DefaultHandler, + KernelFormatterHandler, + MockHandler, + OpsHandler, + ReductionType, + StoreMode, + WrapperHandler, +) + + +if TYPE_CHECKING: + from collections.abc import Callable + + import torch + from torch._inductor.choices import InductorChoices + from torch._inductor.codegen.cpp_utils import LocalBufferContext + from torch._inductor.debug import DebugContext + from torch._inductor.graph import GraphLowering + from torch._inductor.ir import ExternKernelNode + from torch._inductor.loop_body import InterpreterShim + from torch._subclasses import FakeTensorMode + + from .distributed_autotune import _DistributedAutotuneState + +threadlocal = local() + +T = TypeVar("T") + + +class NullHandler: + """ + Sentinel indicating that a global variable is unset ala None. 
Typically, + attempting to access the global variable before it's set is an error, but with + NullHandler it won't fail until you try to access an attribute on it. + """ + + +# If a virtualized value is set to _PoisonedVirtual then any attempt to get the +# value will result an an exception being raised. This is useful if we want to +# trap uninitialized reads of virtualized globals - for example when compiling +# in a subprocess we don't want the child reading globals that weren't copied +# from the parent. +_PoisonedVirtual = object() + + +class Virtualized(Generic[T]): + """ + Implements a global variable that redirects via thread local variable + (NB: construct this class to create the global variable; this is not + a singleton class!) + + This allows us to swap in different op implementations in codegen. + + NB: Despite the fact that we typically call these "handlers" (e.g., NullHandler is + the default value of the variable), we sometimes use these variables to + store other things, like booleans. + """ + + def __init__(self, vname: str, default: Union[Callable[[], T], type[NullHandler]]): + self._vname = vname + self._key: str = f"__torchinductor_{vname}" + self._default = default + + def _set_handler(self, value: T) -> AbstractContextManager[None]: + prior = self._get_handler(False) + setattr(threadlocal, self._key, value) + + @contextmanager + def ctx(): + try: + yield + finally: + self._set_handler(prior) + + return ctx() + + def _get_handler(self, check_poisoned: bool = True) -> T: + try: + value = getattr(threadlocal, self._key) + if check_poisoned and value is _PoisonedVirtual: + raise RuntimeError( + f"Attempt to use poisoned virtualized value '{self._vname}'." 
+ ) + return value + except AttributeError: + # TODO: To be honest, I feel we probably should just error in this + # case, instead of making a null handler that will probably error + # when you getattr on it + return self._default() # type: ignore[return-value] + + def __getattr__(self, name: str) -> Any: + return getattr(self._get_handler(), name) + + +class NullKernelHandler(NullHandler): + """ + We need access `V.kernel.removed_buffers` in DeferredLine class when there + is no kernel in the context. This happens when codegening the wrapper. + Initialize `removed_buffers` and `inplaced_to_remove` explicitly so we don't + need call 'getattr' with default value which is error prone to typo in + attribute name. + """ + + def __init__(self): + super().__init__() + self.removed_buffers = OrderedSet[Any]() + self.inplaced_to_remove = OrderedSet[Any]() + self.index_dtype = "tl.int64" + + def get_index_dtype_as_torch_dtype(self): + import torch + + if self.index_dtype == "tl.int64": + return torch.int64 + elif self.index_dtype == "tl.int32": + return torch.int32 + else: + raise ValueError(f"Unknown dtype: {self.index_dtype}") + + +_ops: Virtualized[OpsHandler[Any]] = Virtualized( + "ops", cast(type[OpsHandler[Any]], MockHandler) +) +_graph: Virtualized[GraphLowering] = Virtualized("graph", NullHandler) +_extern_kernel_nodes: Virtualized[list[ExternKernelNode]] = Virtualized( + "extern_kernel_nodes", NullHandler +) +_real_inputs: Virtualized[list[torch.Tensor]] = Virtualized("real_inputs", NullHandler) +_fake_mode: Virtualized[FakeTensorMode] = Virtualized("fake_mode", NullHandler) +_kernel: Virtualized[NullKernelHandler] = Virtualized( + "kernel", NullKernelHandler +) # TODO: improve type +_debug: Virtualized[DebugContext] = Virtualized("debug", NullHandler) +_interpreter: Virtualized[InterpreterShim] = Virtualized("interpreter", NullHandler) +_aot_compilation: Virtualized[bool] = Virtualized("aot_compilation", NullHandler) +_current_node: Virtualized[torch.fx.Node] = 
Virtualized("current_node", NullHandler) +_local_buffer_context: Virtualized[LocalBufferContext] = Virtualized( + "local_buffer_context", NullHandler +) +_distributed_autotune_state: Virtualized[_DistributedAutotuneState] = Virtualized( + "distributed_autotune_state", NullHandler +) + + +def _choices_default(): + """ + Lazy init the global choices handler + + We virtualize InductorChoices to allow changing inductor heuristics from out of tree. + """ + from torch._inductor import config + from torch._inductor.choices import InductorChoices + + if config.inductor_choices_class is not None: + rv = config.inductor_choices_class() + else: + rv = InductorChoices() + setattr(threadlocal, _choices._key, rv) + return rv + + +_choices: Virtualized[InductorChoices] = Virtualized("choices", _choices_default) + + +class OpsValue: + """The return type of most ops calls. + + This exists so we can overload magic methods, and write mathematical + expressions much more fluently. So instead of + + ops.add(ops.mul(ops.mul(ops.sub(ops.mul(_Ap2, x), _Ap3), x), x), _1) + + we can write + + (_Ap2 * x - _Ap3) * x * x + _1 + + """ + + value: Any + + def __init__(self, value): + self.value = value + + def __str__(self): + return str(self.value) + + def __repr__(self): + return f"OpsValue({self.value!r})" + + def __add__(self, other): + return ops.add(self, other) + + def __mul__(self, other): + return ops.mul(self, other) + + def __sub__(self, other): + return ops.sub(self, other) + + def __neg__(self): + return ops.neg(self) + + def __truediv__(self, other): + return ops.truediv(self, other) + + def __floordiv__(self, other): + return ops.floordiv(self, other) + + def __mod__(self, other): + return ops.mod(self, other) + + def __pow__(self, other): + return ops.pow(self, other) + + def __lt__(self, other): + return ops.lt(self, other) + + def __le__(self, other): + return ops.le(self, other) + + def __eq__(self, other): + return ops.eq(self, other) + + def __ne__(self, other): + return 
ops.ne(self, other) + + def __gt__(self, other): + return ops.gt(self, other) + + def __ge__(self, other): + return ops.ge(self, other) + + def __and__(self, other): + return ops.bitwise_and(self, other) + + def __or__(self, other): + return ops.bitwise_or(self, other) + + def __xor__(self, other): + return ops.bitwise_xor(self, other) + + def __invert__(self): + return ops.bitwise_not(self) + + def __rshfit__(self, n): + return ops.bitwise_right_shift(self, n) + + def __lshift__(self, n): + return ops.bitwise_left_shift(self, n) + + +class OpsWrapper(DefaultHandler): + """This wraps any returned IR values into an `OpsValue` instance, so that we + can overload the magic methods for writing mathematical expressions fluently. + """ + + def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: + new_args = [OpsWrapper._unwrap(a) for a in args] + new_kwargs = {k: OpsWrapper._unwrap(v) for k, v in kwargs.items()} + return OpsWrapper._wrap(getattr(_ops, name)(*new_args, **new_kwargs)) + + @staticmethod + def _unwrap(x): + if isinstance(x, (list, tuple)): + return tuple(OpsWrapper._unwrap(v) for v in x) + if isinstance(x, OpsValue): + return x.value + return x + + @staticmethod + def _wrap(x): + if isinstance(x, (list, tuple)): + return tuple(OpsValue(v) for v in x) + return OpsValue(x) + + @staticmethod + def indirect_indexing(index, size, check=True, wrap_neg=True): + # Returns a sympy value, not IR value + index = OpsWrapper._unwrap(index) + return _ops.indirect_indexing(index, size, check, wrap_neg) + + +ops: OpsHandler[Any] = OpsWrapper() + + +class _V: + MockHandler = MockHandler + KernelFormatterHandler = KernelFormatterHandler + WrapperHandler = WrapperHandler + + set_ops_handler: Callable[[OpsHandler[Any]], AbstractContextManager[None]] = ( + _ops._set_handler + ) + get_ops_handler: Callable[[], OpsHandler[Any]] = _ops._get_handler + set_graph_handler: Callable[[GraphLowering], Any] = _graph._set_handler + set_extern_kernel_nodes: 
Callable[[list[ExternKernelNode]], Any] = ( + _extern_kernel_nodes._set_handler + ) + set_real_inputs: Callable[[Any], Any] = _real_inputs._set_handler + get_real_inputs: Callable[[], Any] = _real_inputs._get_handler + set_fake_mode: Callable[[Any], Any] = _fake_mode._set_handler + get_fake_mode: Callable[[], Any] = _fake_mode._get_handler + set_kernel_handler: Callable[[Any], Any] = _kernel._set_handler + set_debug_handler: Callable[[Any], Any] = _debug._set_handler + set_interpreter_handler: Callable[[Any], Any] = _interpreter._set_handler + set_aot_compilation: Callable[[bool], Any] = _aot_compilation._set_handler + get_aot_compilation: Callable[[], Any] = _aot_compilation._get_handler + set_current_node: Callable[[Any], Any] = _current_node._set_handler + get_current_node: Callable[[], Any] = _current_node._get_handler + set_local_buffer_context: Callable[[Any], Any] = _local_buffer_context._set_handler + get_local_buffer_context: Callable[[], Any] = _local_buffer_context._get_handler + set_choices_handler: Callable[[Any], Any] = _choices._set_handler + set_distributed_autotune_state: Callable[[Any], Any] = ( + _distributed_autotune_state._set_handler + ) + get_distributed_autotune_state: Callable[[], Any] = ( + _distributed_autotune_state._get_handler + ) + + @property + def ops(self) -> OpsHandler[Any]: + """The operator handler specific to the current codegen task""" + return _ops._get_handler() + + @property + def graph(self) -> GraphLowering: + """The graph currently being generated""" + return _graph._get_handler() + + @property + def extern_kernel_nodes(self) -> list[ExternKernelNode]: + """ + The extern_kernel_nodes needed for the entire graph, including the + subgraphs. 
+ See `ProxyExecutor Design Note` in ir.py for more details + """ + return _extern_kernel_nodes._get_handler() + + @property + def real_inputs(self): + """non-fake example inputs""" + return _real_inputs._get_handler() + + @property + def fake_mode(self): + """The graph currently being generated""" + return _fake_mode._get_handler() + + @property + def kernel(self): + """The kernel currently being generated""" + return _kernel._get_handler() + + @property + def debug(self): + return _debug._get_handler() + + @property + def interpreter(self): + return _interpreter._get_handler() + + @property + def aot_compilation(self): + return _aot_compilation._get_handler() is True + + @property + def current_node(self): + return _current_node._get_handler() + + @property + def local_buffer_context(self): + return _local_buffer_context._get_handler() + + @property + def choices(self) -> InductorChoices: + return _choices._get_handler() + + @property + def distributed_autotune_state(self): + return _distributed_autotune_state._get_handler() + + +V = _V() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/wrapper_benchmark.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/wrapper_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..56adde809079f7083e49e5c1fbe32fb2895ac73d --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_inductor/wrapper_benchmark.py @@ -0,0 +1,521 @@ +import argparse +import datetime +import tempfile +from collections import defaultdict +from dataclasses import dataclass +from types import ModuleType +from typing import Any, Optional, Protocol + +import torch +from torch.autograd import DeviceType +from torch.utils._ordered_set import OrderedSet + +from .runtime.benchmarking import benchmarker +from .runtime.runtime_utils import create_bandwidth_info_str, get_num_bytes + + +class BenchmarkCallableType(Protocol): + def __call__(self, 
times: int, repeat: int) -> float: ... + + +_kernel_category_choices = [ + "foreach", + "persistent_reduction", + "pointwise", + "reduction", + "split_scan", + "template", +] + + +def get_kernel_category_by_source_code(src_code: str) -> str: + """ + Similar to get_kernel_category but use the source code. Call this API + if we have not compile the src_code to module yet. + """ + choices = [ + ch for ch in _kernel_category_choices if f"@triton_heuristics.{ch}" in src_code + ] + if len(choices) == 1: + return choices[0] + else: + return "unknown" + + +def get_kernel_category(kernel_mod: ModuleType) -> str: + """ + Given the module defining a triton kernel, return the category of the kernel. + Category can be one of: + - pointwise + - reduction + - persistent_reduction + + Currently we simply decide the category depending on what decorator is imported + by the kernel. + """ + choices = [ch for ch in _kernel_category_choices if ch in kernel_mod.__dict__] + if len(choices) == 1: + return choices[0] + else: + return "unknown" + + +def get_triton_kernel(mod: ModuleType): # type: ignore[no-untyped-def] + from torch._inductor.runtime.triton_heuristics import CachingAutotuner + + cand_list = [ + v + for k, v in mod.__dict__.items() + if k.startswith("triton_") and isinstance(v, CachingAutotuner) + ] + assert len(cand_list) == 1 + return cand_list[0] + + +def benchmark_all_kernels( + benchmark_name: str, benchmark_all_configs: Optional[dict[Any, Any]] +) -> None: + """ + An experimental API used only when config.benchmark_kernel is true. + + Run the kernel benchmarks for all the kernels cached in PyCodeCache. + Used in the compiled modules. + + Put this method here rather than codegen it for convenience since its implementation + does not change based on different graph modules being compiled. 
+ """ + from torch._inductor.codecache import PyCodeCache + + nfound = 0 + for kernel_mod in PyCodeCache.modules: + kernel_key = kernel_mod.key + if not hasattr(kernel_mod, "get_args") or not hasattr(kernel_mod, "call"): + continue + + triton_kernel = get_triton_kernel(kernel_mod) + device_type = triton_kernel.device_props.type + kernel_category = get_kernel_category(kernel_mod) + args = kernel_mod.get_args() + num_in_out_ptrs = len( + [ + arg_name + for arg_name in triton_kernel.fn.arg_names + if arg_name.startswith("in_out_ptr") + ] + ) + num_gb = triton_kernel.inductor_meta.get("kernel_num_gb", None) + if num_gb is None: + num_gb = get_num_bytes(*args, num_in_out_args=num_in_out_ptrs) / 1e9 + + def get_info_str( + ms: float, + n_regs: Optional[Any], + n_spills: Optional[Any], + shared: Optional[Any], + prefix: str = "", + ) -> str: + if not any(x is None for x in [n_regs, n_spills, shared]): + kernel_detail_str = ( + f" {n_regs:3} regs {n_spills:3} spills {shared:8} shared mem" + ) + else: + kernel_detail_str = "" + + gb_per_s = num_gb / (ms / 1e3) + return create_bandwidth_info_str( + ms, num_gb, gb_per_s, prefix=prefix, suffix=kernel_detail_str + ) + + kernel_desc = ( + f"{benchmark_name:20} {kernel_category[:3].upper()} {kernel_key[:10]}" + ) + if benchmark_all_configs: + assert hasattr(kernel_mod, "benchmark_all_configs") + bench_result = kernel_mod.benchmark_all_configs(args) + print(kernel_desc) + for launcher, ms in bench_result.items(): + print( + f" {get_info_str(ms, launcher.n_regs, launcher.n_spills, launcher.shared)} @ {launcher.config}" + ) + else: + ms = benchmarker.benchmark( + lambda: kernel_mod.call(args), + device=device_type, + rep=40, + ) + assert len(triton_kernel.launchers) == 1, ( + "Autotuner should have selected the best config" + ) + launcher = triton_kernel.launchers[0] + print( + get_info_str( + ms, + launcher.n_regs, + launcher.n_spills, + launcher.shared, + prefix=f"{kernel_desc} ", + ) + ) + + nfound += 1 + if nfound == 0: + print( 
+ "No kernel with benchmark functionality found. Make sure you run inductor with config.benchmark_kernel being True" + ) + + +@dataclass +class ProfileEvent: + category: str + key: str + self_device_time_ms: float + # the benchmark is run multiple times and we average the count across all the + # runs. It should be an integer but define a float just in case. + count: float + + +def parse_profile_event_list( + benchmark_name: str, + event_list: torch.autograd.profiler_util.EventList, + wall_time_ms: float, + nruns: int, + device_name: str, +) -> None: + """ + Parse and generate a report for an event_list. + """ + + def get_self_device_time( + ev: torch.autograd.profiler_util.EventList, + ) -> float: + """ + ev.self_device_time_total is in microsecond. Convert to millisecond. + """ + return ev.self_device_time_total / 1000 / nruns # type: ignore[attr-defined] + + all_events: dict[str, list[ProfileEvent]] = defaultdict(list) + + def add_event( + ev: torch.autograd.profiler_util.EventList, + category: str, + ) -> None: + profile_ev = ProfileEvent( + category=category, + key=ev.key, # type: ignore[attr-defined] + self_device_time_ms=get_self_device_time(ev), + count=ev.count / nruns, # type: ignore[operator] # average across all runs + ) + all_events[category].append(profile_ev) + + for ev in event_list: + assert not ev.is_legacy, "Don't support the legacy profiler" + if ev.device_type == DeviceType.CPU: + # ignore the event on CPU side + continue + + category = "unknown" + if ev.key.startswith("triton_"): + if ev.key.startswith("triton_poi"): + category = "triton_pointwise" + elif ev.key.startswith("triton_red"): + category = "triton_reduction" + elif ev.key.startswith("triton_per"): + category = "triton_persistent_reduction" + else: + category = "triton_unknown" + + add_event(ev, category) + + def report_category(category: str, profile_events: list[ProfileEvent]) -> float: + if not device_name: + return 0.0 + + from tabulate import tabulate + + 
profile_events.sort(key=lambda ev: ev.self_device_time_ms, reverse=True) + + rows = [] + total_time = 0.0 + print(f"\n == {category} category kernels == ") + for ev in profile_events: + total_time += ev.self_device_time_ms + percent = f"{ev.self_device_time_ms / wall_time_ms * 100:.2f}%" + rows.append([ev.key[:120], ev.self_device_time_ms, ev.count, percent]) + rows.append( + ["Total", total_time, "", f"{total_time / wall_time_ms * 100:.2f}%"] + ) + print( + tabulate( + rows, + headers=[ + "Kernel", + f"Self {device_name.upper()} TIME (ms)", + "Count", + "Percent", + ], + ) + ) + return total_time + + def report() -> None: + category_list = [ + "triton_pointwise", + "triton_reduction", + "triton_persistent_reduction", + "triton_unknown", + "unknown", + ] + assert OrderedSet(all_events.keys()).issubset(OrderedSet(category_list)), ( + f"{list(all_events.keys())}" + ) + + per_category_wall_time = {} + total_device_ms = 0.0 + for category in category_list: + if category in all_events: + _time = report_category(category, all_events[category]) + per_category_wall_time[category] = _time + total_device_ms += _time + + device_busy_percent = f"{total_device_ms / wall_time_ms * 100:.2f}%" + if device_name: + print( + f"\nPercent of time when {device_name.upper()} is busy: {device_busy_percent}" + ) + else: + print("No device detected") + + print(f"Total wall time {wall_time_ms:.3f} ms") + + # output such a line so we can gather such line from all compiled modules from all + # benchmarks and tabulate it! 
+ # Columns: benchmark_name, pointwise_percent, reduction_percent, persistent_reduction_percent, + # unknown_category_percent, device_busy_percent, wall_time_ms + tabulate_line = f"Output for tabulate: {benchmark_name}" + for category in category_list: + percent = ( + f"{per_category_wall_time.get(category, 0.0) / wall_time_ms * 100:.2f}%" + ) + tabulate_line += f", {percent}" + tabulate_line += f", {device_busy_percent}, {wall_time_ms:.3f}ms" + + print(tabulate_line) + + report() + + +PROFILE_DIR = tempfile.gettempdir() +PROFILE_PATH = f"{PROFILE_DIR}/compiled_module_profile.json" + + +def perf_profile( + wall_time_ms: float, + times: int, + repeat: int, + benchmark_name: str, + benchmark_compiled_module_fn: BenchmarkCallableType, +) -> None: + with torch.profiler.profile(record_shapes=True) as p: + benchmark_compiled_module_fn(times=times, repeat=repeat) + + path = PROFILE_PATH + p.export_chrome_trace(path) + print(f"Profiling result for a compiled module of benchmark {benchmark_name}:") + print(f"Chrome trace for the profile is written to {path}") + event_list = p.key_averages(group_by_input_shape=True) + print(event_list.table(sort_by="self_device_time_total", row_limit=10)) + parse_profile_event_list( + benchmark_name, event_list, wall_time_ms, times * repeat, p.use_device or "" + ) + + +def ncu_analyzer( + benchmark_name: str, + benchmark_compiled_module_fn: BenchmarkCallableType, + args: argparse.Namespace, +) -> None: + import inspect + import os + import subprocess + + kernel_regex = args.ncu_kernel_regex + metrics = args.ncu_metrics + + module_file = inspect.getfile(benchmark_compiled_module_fn) + module_dir = os.path.dirname(module_file) + module_name = os.path.splitext(os.path.basename(module_file))[0] + + ncu_dir = tempfile.gettempdir() + timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + ncu_output = os.path.join(ncu_dir, f"ncu_output_{timestamp}.ncu-rep") + python_cmd = ( + f"""import sys; sys.path.insert(0, '{module_dir}'); """ + 
f"""from {module_name} import benchmark_compiled_module; """ + """benchmark_compiled_module(times=1, repeat=1)""" + ) + + ncu_cmd = [ + "ncu", + "--target-processes", + "all", + "--replay-mode", + "kernel", + "--kernel-name-base", + "function", + "--print-units", + "base", + "--import-source", + "yes", + "--force-overwrite", + "--export", + ncu_output, + ] + + if kernel_regex: + ncu_cmd.extend(["--kernel-name", f"regex:{kernel_regex}"]) + + if metrics: + ncu_cmd.extend(["--metrics", metrics]) + else: + ncu_cmd.extend(["--set", "full"]) + + ncu_cmd.extend( + [ + "python", + "-c", + python_cmd, + ] + ) + + try: + subprocess.run(ncu_cmd, check=True) + print(f"\nNCU profiling results for benchmark {benchmark_name}:") + print(f"NCU report has been written to {ncu_output}") + + except subprocess.CalledProcessError as e: + print(f"NCU profiling failed with error: {e}") + return + + +def collect_memory_snapshot( + benchmark_compiled_module_fn: BenchmarkCallableType, +) -> None: + assert torch.cuda.is_available() + + torch.cuda.memory._record_memory_history(max_entries=100000) + benchmark_compiled_module_fn(times=10, repeat=1) # run 10 times + snapshot_path = f"{tempfile.gettempdir()}/memory_snapshot.pickle" + torch.cuda.memory._dump_snapshot(snapshot_path) + torch.cuda.memory._record_memory_history(enabled=None) + print(f"The collect memory snapshot has been written to {snapshot_path}") + + +# With AOTAutograd cache, we directly call the compiled module. So prevent +# Dynamo from reentering +@torch.compiler.disable # type: ignore[misc] +def compiled_module_main( + benchmark_name: str, benchmark_compiled_module_fn: BenchmarkCallableType +) -> None: + """ + This is the function called in __main__ block of a compiled module. 
+ """ + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "--benchmark-kernels", + "-k", + action="store_true", + help="Whether to benchmark each individual kernels", + ) + parser.add_argument( + "--benchmark-all-configs", + "-c", + action="store_true", + help="Whether to benchmark each individual config for a kernel", + ) + parser.add_argument( + "--profile", + "-p", + action="store_true", + help="Whether to profile the compiled module", + ) + parser.add_argument( + "--cuda-memory-snapshot", + action="store_true", + help=""" + Whether to collect CUDA memory snapshot. Refer to + "https://pytorch.org/blog/understanding-gpu-memory-1/ + for details about how to visualize the collected snapshot + """, + ) + parser.add_argument( + "--ncu", + action="store_true", + help="Whether to run ncu analysis", + ) + parser.add_argument( + "--ncu-kernel-regex", + type=str, + default=None, + help=( + "Filter kernels profiled by NCU using a regex (e.g., '^triton_.*'). " + "Maps to '--kernel-name regex:'. " + "If None, NCU will profile all kernels." + ), + ) + parser.add_argument( + "--ncu-metrics", + type=str, + default=None, + help=( + "Comma-separated list of NCU metrics to collect (e.g., 'dram__bytes.sum.per_second'). " + "If None, NCU will use '--set full'." 
+ ), + ) + parser.add_argument( + "--times", + type=int, + default=10, + help="Number of times to run each benchmark iteration", + ) + parser.add_argument( + "--repeat", + type=int, + default=10, + help="Number of repetitions of each benchmark run", + ) + + args = parser.parse_args() + + if args.benchmark_kernels: + benchmark_all_kernels(benchmark_name, args.benchmark_all_configs) + else: + times = args.times + repeat = args.repeat + + if torch.cuda.is_available(): + torch.cuda.reset_peak_memory_stats() + wall_time_ms = benchmark_compiled_module_fn(times=times, repeat=repeat) * 1000 + + if torch.cuda.is_available(): + peak_mem = torch.cuda.max_memory_allocated() + print(f"Peak GPU memory usage {peak_mem / 1e6:.3f} MB") + + if torch.cuda.is_available() and args.cuda_memory_snapshot: + collect_memory_snapshot(benchmark_compiled_module_fn) + + if args.profile: + perf_profile( + wall_time_ms, + times, + repeat, + benchmark_name, + benchmark_compiled_module_fn, + ) + if args.ncu: + ncu_analyzer( + benchmark_name, + benchmark_compiled_module_fn, + args=args, + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_library/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_library/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6f50d46dde0527d1e31e96d617920e26d9c69acb --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_library/__init__.py @@ -0,0 +1,6 @@ +import torch._library.autograd +import torch._library.fake_impl +import torch._library.simple_registry +import torch._library.utils +from torch._library.fake_class_registry import register_fake_class +from torch._library.triton import capture_triton, triton_op, wrap_triton diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_library/autograd.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_library/autograd.py new file mode 100644 index 
0000000000000000000000000000000000000000..c8da8a692648e0a5ca4d0bb6cb5892cf66ea71f5 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_library/autograd.py @@ -0,0 +1,235 @@ +# mypy: allow-untyped-defs +import dataclasses +from collections.abc import Callable +from dataclasses import dataclass +from typing import Any, Optional, Protocol + +from torch import _C, _ops, autograd, Tensor +from torch.utils import _pytree + +from . import utils + + +class InfoProtocol(Protocol): + _backward_fn: Optional[Callable] + _setup_context_fn: Optional[Callable] + + +@dataclasses.dataclass +class Info: + _backward_fn: Optional[Callable] + _setup_context_fn: Optional[Callable] + + +def make_autograd_impl(op: _ops.OpOverload, info: InfoProtocol) -> Callable: + name: str = f"GeneratedBackwardFor_{op._namespace}_{op._opname}_{op._overloadname}" + + has_kwarg_only_args = utils.has_kwarg_only_args(op._schema) + + @dataclass + class Metadata: + keyset: _C.DispatchKeySet + keyword_only_args: dict[str, Any] + + def forward_no_grad(*args): + metadata = args[-1] + args = args[:-1] + + with _C._AutoDispatchBelowAutograd(): + keyset = metadata.keyset + kwargs = metadata.keyword_only_args + result = op.redispatch(keyset & _C._after_autograd_keyset, *args, **kwargs) + return result + + def forward(ctx, *args): + metadata = args[-1] + args = args[:-1] + + with _C._AutoDispatchBelowAutograd(): + keyset = metadata.keyset + kwargs = metadata.keyword_only_args + result = op.redispatch(keyset & _C._after_autograd_keyset, *args, **kwargs) + if info._setup_context_fn: + # The Dispatcher will remove args that are equal to their default + # values from (args, kwargs). We're going to add it back so that + # the user can access them. + # + # This is OK to do: The Dispatcher removed the args for serialization + # FC/BC reasons (that is, a graph will not store args that are equal + # to their default values), but that doesn't matter here. 
If the user + # adds a new default arg, then they must update + # their setup_context (along with the rest of their operator + # registrations) + args, kwargs = utils.fill_defaults(op._schema, args, kwargs) + + if has_kwarg_only_args: + info._setup_context_fn( + ctx=ctx, inputs=args, keyword_only_inputs=kwargs, output=result + ) + else: + info._setup_context_fn(ctx=ctx, inputs=args, output=result) + return result + + def backward(ctx, *grads): + if info._backward_fn: + try: + prev_needs_input_grad = ctx.needs_input_grad + ctx.needs_input_grad = ctx.needs_input_grad[:-1] + result = info._backward_fn(ctx, *grads) + finally: + ctx.needs_input_grad = prev_needs_input_grad + if isinstance(result, tuple): + return (*result, None) + return result, None + raise RuntimeError( + f"Trying to backward through {op} but no autograd " + f"formula was registered. " + f"Please use register_autograd to add one." + ) + + Generated = type( + name, + (autograd.Function,), + { + "forward": staticmethod(forward), + "backward": staticmethod(backward), + }, + ) + + schema = op._schema + if any( + utils.is_tensorlist_like_type(a.type) + for a in (*schema.arguments, *schema.returns) + ): + Generated = supports_tensorlist(Generated) + + # The dispatcher passes any keyword-only-args as kwargs and the + # rest of the args (even if specified as kwargs) as args. + def autograd_impl(keyset, *args, **keyword_only_args): + if _C.is_grad_enabled() and _C._any_requires_grad(*args): + result = Generated.apply(*args, Metadata(keyset, keyword_only_args)) # type: ignore[attr-defined] + else: + result = forward_no_grad(*args, Metadata(keyset, keyword_only_args)) + return result + + return autograd_impl + + +def supports_tensorlist(cls: Any) -> Any: + """Allows a given autograd.Function class to support List[Tensor] inputs/outputs. + + Regular autograd.Function has a constraint that it only directly supports autograd for + Tensors. 
Applying @supports_tensorlist enables an autograd.Function to support + autograd for List[Tensor] inputs and outputs. + """ + orig_forward = cls.forward + orig_backward = cls.backward + orig_apply = cls.apply + + @dataclass + class Metadata: + input_spec: _pytree.TreeSpec + output_spec: Optional[_pytree.TreeSpec] = None + result_is_tuple: Optional[bool] = None + + def new_forward(ctx, *args): + metadata = args[-1] + args = args[:-1] + if not isinstance(metadata, Metadata): + raise NotImplementedError( + "NYI: calling supports_tensorlist autograd.Function.forward directly. " + "You should probably be calling .apply instead. " + "Please file an issue if not." + ) + args = _pytree.tree_unflatten(list(args), metadata.input_spec) + result = orig_forward(ctx, *args) + metadata.result_is_tuple = isinstance(result, tuple) + if not metadata.result_is_tuple: + result = (result,) + flat_result, output_spec = _pytree.tree_flatten(result, not_list_of_tensor) + metadata.output_spec = output_spec + + if hasattr(ctx, "_pt_metadata"): + raise RuntimeError( + "Please don't set ctx._pt_metadata; PyTorch uses it to store info" + ) + ctx._pt_metadata = metadata + + return tuple(flat_result) + + def new_backward(ctx, *grads): + if not hasattr(ctx, "_pt_metadata"): + raise NotImplementedError( + "NYI: calling supports_tensorlist autograd.Function.backward directly. " + "This will automatically get called by PyTorch autograd. " + "Please file an issue if you need this." + ) + + metadata = ctx._pt_metadata + grads = _pytree.tree_unflatten(list(grads), metadata.output_spec) + + # If the user's input is ([x, y, z], w), + # then needs_input_grad is (bool, bool, bool, bool, bool). + # We need to + # 1. get rid of the additional bool (which comes from the extra + # `metadata input`) + # 2. _pytree.tree_unflatten to get the right structure. 
+ prev_needs_input_grad = ctx.needs_input_grad + try: + ctx.needs_input_grad = _pytree.tree_unflatten( + list(ctx.needs_input_grad[:-1]), metadata.input_spec + ) + grad_inputs = orig_backward(ctx, *grads) + finally: + ctx.needs_input_grad = prev_needs_input_grad + + if not isinstance(grad_inputs, tuple): + grad_inputs = (grad_inputs,) + # Assume that any Nones in the backward are Tensors. + # If the forward has an arg that is [1, 2, 3], the backward should + # return None as the grad. + # If the forward has an arg that is [tensor, tensor], the backward + # may return [None, None], [grad, None], [None, grad], or [grad, grad]. + flat_grad_inputs, grad_inputs_spec = _pytree.tree_flatten( + grad_inputs, not_list_of_optional_tensor + ) + if grad_inputs_spec != metadata.input_spec: + raise RuntimeError( + f"Expected the return from backward to be of the same structure " + f"as the inputs. Got: {grad_inputs_spec} (return from backward), " + f"{metadata.input_spec} (inputs)" + ) + return tuple(flat_grad_inputs + [None]) + + def new_apply(*args): + flat_args, input_spec = _pytree.tree_flatten(args, is_leaf=not_list_of_tensor) + metadata = Metadata(input_spec) + result = orig_apply(*flat_args, metadata) # type: ignore[misc] + assert metadata.output_spec is not None + result = _pytree.tree_unflatten(list(result), metadata.output_spec) + if not metadata.result_is_tuple: + assert isinstance(result, tuple) + assert len(result) == 1 + return result[0] + return result + + cls.forward = new_forward + cls.backward = new_backward + cls.apply = new_apply + return cls + + +def not_list_of_tensor(tree): + if isinstance(tree, tuple): + return False + if isinstance(tree, list): + return any(not isinstance(l, Tensor) for l in tree) + return True + + +def not_list_of_optional_tensor(tree): + if isinstance(tree, tuple): + return False + if isinstance(tree, list): + return any(l is not None and not isinstance(l, Tensor) for l in tree) + return True diff --git 
a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_library/custom_ops.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_library/custom_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..a317297efba8842fe918db9dfcc4880133de6626 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_library/custom_ops.py @@ -0,0 +1,948 @@ +# mypy: allow-untyped-defs +import collections +import inspect +import logging +import warnings +import weakref +from collections.abc import Callable, Iterable, Sequence +from contextlib import contextmanager +from typing import Any, Optional, overload, Union + +import torch +from torch import _C, _ops, Tensor +from torch.types import _dtype +from torch.utils._exposed_in import exposed_in + +from . import autograd, utils +from .effects import EffectType + + +device_types_t = Optional[Union[str, Sequence[str]]] +log = logging.getLogger(__name__) + + +@overload +def custom_op( + name: str, + fn: None = None, + /, + *, + mutates_args: Union[str, Iterable[str]], + device_types: device_types_t = None, + schema: Optional[str] = None, + tags: Optional[Sequence[_C.Tag]] = None, +) -> Callable[[Callable[..., object]], "CustomOpDef"]: ... + + +@overload +def custom_op( + name: str, + fn: Callable[..., object], + /, + *, + mutates_args: Union[str, Iterable[str]], + device_types: device_types_t = None, + schema: Optional[str] = None, + tags: Optional[Sequence[_C.Tag]] = None, +) -> "CustomOpDef": ... + + +@exposed_in("torch.library") +def custom_op( + name: str, + fn: Optional[Callable] = None, + /, + *, + mutates_args: Union[str, Iterable[str]], + device_types: device_types_t = None, + schema: Optional[str] = None, + tags: Optional[Sequence[_C.Tag]] = None, +) -> Union[Callable[[Callable[..., object]], "CustomOpDef"], "CustomOpDef"]: + """Wraps a function into custom operator. 
+ + Reasons why you may want to create a custom op include: + - Wrapping a third-party library or custom kernel to work with PyTorch + subsystems like Autograd. + - Preventing torch.compile/export/FX tracing from peeking inside your function. + + This API is used as a decorator around a function (please see examples). + The provided function must have type hints; these are needed to interface + with PyTorch's various subsystems. + + Args: + name (str): A name for the custom op that looks like "{namespace}::{name}", + e.g. "mylib::my_linear". The name is used as the op's stable identifier + in PyTorch subsystems (e.g. torch.export, FX graphs). + To avoid name collisions, please use your project name as the namespace; + e.g. all custom ops in pytorch/fbgemm use "fbgemm" as the namespace. + mutates_args (Iterable[str] or "unknown"): The names of args that the function mutates. + This MUST be accurate, otherwise, the behavior is undefined. If "unknown", + it pessimistically assumes that all inputs to the operator are being mutated. + device_types (None | str | Sequence[str]): The device type(s) the function + is valid for. If no device type is provided, then the function + is used as the default implementation for all device types. + Examples: "cpu", "cuda". + When registering a device-specific implementation for an operator that accepts no Tensors, + we require the operator to have a "device: torch.device argument". + schema (None | str): A schema string for the operator. If None + (recommended) we'll infer a schema for the operator from its type + annotations. We recommend letting us infer a schema unless you + have a specific reason not to. + Example: "(Tensor x, int y) -> (Tensor, Tensor)". + + .. note:: + We recommend not passing in a ``schema`` arg and instead letting us infer + it from the type annotations. It is error-prone to write your own schema. + You may wish to provide your own schema if our interpretation of + the type annotation is not what you want. 
+ For more info on how to write a schema string, see + `here `_ + + Examples:: + >>> import torch + >>> from torch import Tensor + >>> from torch.library import custom_op + >>> import numpy as np + >>> + >>> @custom_op("mylib::numpy_sin", mutates_args=()) + >>> def numpy_sin(x: Tensor) -> Tensor: + >>> x_np = x.cpu().numpy() + >>> y_np = np.sin(x_np) + >>> return torch.from_numpy(y_np).to(device=x.device) + >>> + >>> x = torch.randn(3) + >>> y = numpy_sin(x) + >>> assert torch.allclose(y, x.sin()) + >>> + >>> # Example of a custom op that only works for one device type. + >>> @custom_op("mylib::numpy_sin_cpu", mutates_args=(), device_types="cpu") + >>> def numpy_sin_cpu(x: Tensor) -> Tensor: + >>> x_np = x.numpy() + >>> y_np = np.sin(x_np) + >>> return torch.from_numpy(y_np) + >>> + >>> x = torch.randn(3) + >>> y = numpy_sin_cpu(x) + >>> assert torch.allclose(y, x.sin()) + >>> + >>> # Example of a custom op that mutates an input + >>> @custom_op("mylib::numpy_sin_inplace", mutates_args={"x"}, device_types="cpu") + >>> def numpy_sin_inplace(x: Tensor) -> None: + >>> x_np = x.numpy() + >>> np.sin(x_np, out=x_np) + >>> + >>> x = torch.randn(3) + >>> expected = x.sin() + >>> numpy_sin_inplace(x) + >>> assert torch.allclose(x, expected) + >>> + >>> # Example of a factory function + >>> @torch.library.custom_op("mylib::bar", mutates_args={}, device_types="cpu") + >>> def bar(device: torch.device) -> Tensor: + >>> return torch.ones(3) + >>> + >>> bar("cpu") + + """ + + def inner(fn: Callable[..., object]) -> CustomOpDef: + import torch + + if schema is None: + schema_str = torch.library.infer_schema(fn, mutates_args=mutates_args) + else: + schema_str = schema + + namespace, opname = name.split("::") + result = CustomOpDef(namespace, opname, schema_str, fn, tags) + if schema is not None: + # Check that schema's alias annotations match those of `mutates_args`. 
+ expected = set() + for arg in result._opoverload._schema.arguments: + if arg.alias_info is not None and arg.alias_info.is_write: + expected.add(arg.name) + if expected != set(mutates_args): + raise ValueError( + f"Attempted to create a custom op with `mutates_args={mutates_args}` " + f"and `schema={schema}. The schema suggests that the op mutates {expected}" + f"which is different from what was provided to us in `mutates_args`. " + f"Please make these consistent." + ) + result.register_kernel(device_types)(fn) + return result + + if fn is None: + return inner + return inner(fn) + + +class CustomOpDef: + """CustomOpDef is a wrapper around a function that turns it into a custom op. + + It has various methods for registering additional behavior for this + custom op. + + You should not instantiate CustomOpDef directly; instead, use the + :func:`torch.library.custom_op` API. + """ + + def __init__( + self, + namespace: str, + name: str, + schema: str, + fn: Callable, + tags: Optional[Sequence[_C.Tag]] = None, + ) -> None: + # Fields used to interface with the PyTorch dispatcher + self._namespace = namespace + self._name = name + self._schema = schema + self._tags = tags if tags is not None else [] + + self._init_fn = fn + + self._backend_fns: dict[Union[str, None], Callable] = {} + self._abstract_fn: Optional[Callable] = None + self._setup_context_fn: Optional[Callable] = None + self._backward_fn: Optional[Callable] = None + self._torch_dispatch_fns: dict[type, Callable] = {} + self._vmap_fn: Optional[Callable] = None + self._autocast_cuda_dtype: Optional[_dtype] = None + self._autocast_cpu_dtype: Optional[_dtype] = None + + self._lib = get_library_allowing_overwrite(self._namespace, self._name) + self._register_to_dispatcher(self._tags) + self._disabled_kernel: set = set() + self._used_triton_kernels: list[Any] = list() + OPDEFS[self._qualname] = self + + @property + def _qualname(self) -> str: + return f"{self._namespace}::{self._name}" + + def __repr__(self) -> 
str: + return f"" + + @contextmanager + def set_kernel_enabled(self, device_type: str, enabled: bool = True): + """ + Disable or re-enable an already registered kernel for this custom operator. + + If the kernel is already disabled/enabled, this is a no-op. + + Note: + If a kernel is first disabled and then registered, it is disabled until enabled again. + + Args: + device_type (str): The device type to disable/enable the kernel for. + disable (bool): Whether to disable or enable the kernel. + + Example: + >>> inp = torch.randn(1) + >>> + >>> # define custom op `f`. + >>> @custom_op("mylib::f", mutates_args=()) + >>> def f(x: Tensor) -> Tensor: + >>> return torch.zeros(1) + >>> + >>> print(f(inp)) # tensor([0.]), default kernel + >>> + >>> @f.register_kernel("cpu") + >>> def _(x): + >>> return torch.ones(1) + >>> + >>> print(f(inp)) # tensor([1.]), CPU kernel + >>> + >>> # temporarily disable the CPU kernel + >>> with f.set_kernel_enabled("cpu", enabled = False): + >>> print(f(inp)) # tensor([0.]) with CPU kernel disabled + + """ + action = "enable" if enabled else "disable" + originally_disabled = device_type in self._disabled_kernel + if device_type not in self._backend_fns: + log.warning( + "Attempted to %s kernel for %s but no kernel was registered for this device type.", + action, + device_type, + ) + + if not enabled: + if originally_disabled: + log.warning( + "Attempted to disable kernel for %s but it was already disabled.", + device_type, + ) + else: + self._disabled_kernel.add(device_type) + else: # enable the kernel + if not originally_disabled: + log.warning( + "Attempted to enable kernel for %s but it was already enabled.", + device_type, + ) + else: + self._disabled_kernel.remove(device_type) + + try: + yield + finally: + # restore original state + if originally_disabled: + self._disabled_kernel.add(device_type) + else: + self._disabled_kernel.discard(device_type) + + def register_kernel( + self, device_types: device_types_t, fn: Optional[Callable] = 
None, / + ) -> Callable: + """Register an implementation for a device type for this operator. + + Some valid device_types are: "cpu", "cuda", "xla", "mps", "ipu", "xpu". + This API may be used as a decorator. + + Args: + fn (Callable): The function to register as the implementation for + the given device types. + device_types (str | Sequence[str]): The device device_types to register an impl to. + + Examples:: + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) + >>> import torch + >>> from torch import Tensor + >>> from torch.library import custom_op + >>> import numpy as np + >>> + >>> # Create a custom op that works on cpu + >>> @custom_op("mylib::numpy_sin", mutates_args=(), device_types="cpu") + >>> def numpy_sin(x: Tensor) -> Tensor: + >>> x_np = x.numpy() + >>> y_np = np.sin(x_np) + >>> return torch.from_numpy(y_np) + >>> + >>> # Add implementations for the cuda device + >>> @numpy_sin.register_kernel("cuda") + >>> def _(x): + >>> x_np = x.cpu().numpy() + >>> y_np = np.sin(x_np) + >>> return torch.from_numpy(y_np).to(device=x.device) + >>> + >>> x_cpu = torch.randn(3) + >>> x_cuda = x_cpu.cuda() + >>> assert torch.allclose(numpy_sin(x_cpu), x_cpu.sin()) + >>> assert torch.allclose(numpy_sin(x_cuda), x_cuda.sin()) + + """ + + def inner(fn): + if device_types is None or isinstance(device_types, str): + dtypes: list[Union[str, None]] = [device_types] + else: + dtypes = list(device_types) + for device_type in dtypes: + if device_type not in self._backend_fns: + + def backend_impl(*args, **kwargs): + result = self._backend_fns[device_type](*args, **kwargs) + + def get_module(): + fn = self._backend_fns[device_type] + return inspect.getmodule(fn) + + schema = self._opoverload._schema + if not schema._is_view_op(): + utils._c_check_aliasing_constraint( + self._name, + args, + kwargs, + result, + get_module, + ) + return result + + if device_type is None: + self._lib.impl( + self._name, backend_impl, "CompositeExplicitAutograd" + ) + else: + self._lib.impl( + 
self._name, + backend_impl, + _C._dispatch_key_for_device(device_type), + ) + + # Wrap function to choose between the default implementation or the device-specific + # implementation depending on if the kernel is disabled. + @torch._disable_dynamo + def wrapped_fn(*args, **kwargs): + if device_type in self._disabled_kernel: + return self._init_fn(*args, **kwargs) + else: + return fn(*args, **kwargs) + + self._backend_fns[device_type] = wrapped_fn + return fn + + if device_types is not None and not utils.has_tensor_arg( + self._opoverload._schema + ): + device_arg_index = utils.get_device_arg_index(self._opoverload._schema) + if device_arg_index is None: + raise ValueError( + "Functions without tensor inputs are required to have a `device: torch.device` argument" + ) + self._register_backend_select_dispatcher(device_arg_index) + + # See NOTE: [Supporting decorator and non-decorator usage] + if fn is None: + return inner + return inner(fn) + + def register_fake(self, fn: Callable, /) -> Callable: + r"""Register a FakeTensor implementation for this custom op. + + This is necessary to get the operator to work efficiently with torch.compile. + + The Fake impl (sometimes also known as a meta kernel or abstract impl) + specifies the behavior of this operator on Tensors that carry no data. + Given some input Tensors with certain properties + (sizes/strides/storage_offset/device), it specifies what the properties of + the output Tensors are. + + Please see :func:`torch.library.register_fake` for more details. + + Args: + fn (Callable): The function to register as the FakeTensor + implementation. 
+ + Examples: + >>> import torch + >>> import numpy as np + >>> from torch import Tensor + >>> + >>> # Example 1: an operator without data-dependent output shape + >>> @torch.library.custom_op("mylib::linear", mutates_args=()) + >>> def linear(x: Tensor, weight: Tensor, bias: Tensor) -> Tensor: + >>> return (x @ weight.t()) + bias + >>> + >>> @linear.register_fake + >>> def _(x, weight, bias): + >>> assert x.dim() == 2 + >>> assert weight.dim() == 2 + >>> assert bias.dim() == 1 + >>> assert x.shape[1] == weight.shape[1] + >>> assert weight.shape[0] == bias.shape[0] + >>> assert x.device == weight.device + >>> return x.new_empty(x.size(0), weight.size(0)) + >>> + >>> x = torch.randn(2, 2) + >>> weight = torch.randn(2, 2) + >>> bias = torch.randn(2) + >>> # xdoctest: +SKIP("Requires Python <= 3.11") + >>> out = torch.compile(linear, fullgraph=True)(x, weight, bias) + >>> # xdoctest: +SKIP("Requires Python <= 3.11") + >>> assert torch.allclose(out, torch.nn.functional.linear(x, weight, bias)) + >>> + >>> # Example 2: an operator with data-dependent output shape + >>> @torch.library.custom_op("mylib::nonzero", mutates_args=()) + >>> def nonzero(x: Tensor) -> Tensor: + >>> x_np = x.cpu().numpy() + >>> res = np.stack(np.nonzero(x_np), axis=1) + >>> return torch.tensor(res, device=x.device) + >>> + >>> @nonzero.register_fake + >>> def _(x): + >>> # Number of nonzero-elements is data-dependent. + >>> # Since we cannot peek at the data in an abstract impl, + >>> # we use the ctx object to construct a new symint that + >>> # represents the data-dependent size. 
+ >>> ctx = torch.library.get_ctx() + >>> nnz = ctx.new_dynamic_size() + >>> shape = [nnz, x.dim()] + >>> result = x.new_empty(shape, dtype=torch.int64) + >>> return result + >>> + >>> x = torch.tensor([0, 1, 2, 0, 0, 1]) + >>> # xdoctest: +SKIP("Requires Python <= 3.11") + >>> out = torch.compile(nonzero, fullgraph=True)(x) + >>> # xdoctest: +SKIP("Requires Python <= 3.11") + >>> assert torch.allclose(out, x.nonzero()) + + """ + self._abstract_fn = fn + return fn + + def register_effect(self, effect: Optional[EffectType]) -> None: + self._lib._register_effectful_op(self._qualname, effect) + + def register_torch_dispatch( + self, torch_dispatch_class: Any, fn: Optional[Callable] = None, / + ) -> Callable: + r"""Registers a torch_dispatch rule for the given operator and ``torch_dispatch_class``. + + This allows for open registration to specify the behavior between the operator + and the ``torch_dispatch_class`` without needing to modify the ``torch_dispatch_class`` + or the operator directly. + + Please see :func:`torch.library.register_torch_dispatch` for examples and more details. + """ + + def register(fn): + if torch_dispatch_class not in self._torch_dispatch_fns: + + def inner(*args, **kwargs): + return self._torch_dispatch_fns[torch_dispatch_class]( + *args, **kwargs + ) + + self._lib._register_torch_dispatch_rule( + self._name, torch_dispatch_class, inner + ) + self._torch_dispatch_fns[torch_dispatch_class] = fn + return fn + + if fn is None: + return register + else: + return register(fn) + + def register_autograd( + self, + backward: Callable, + /, + *, + setup_context: Optional[Callable] = None, + ) -> None: + r"""Register a backward formula for this custom op. + + In order for an operator to work with autograd, you need to register + a backward formula: + 1. You must tell us how to compute gradients during the backward pass + by providing us a "backward" function. + 2. 
If you need any values from the forward to compute gradients, you can + use `setup_context` to save values for backward. + + ``backward_fn`` runs during the backward pass. It accepts ``(ctx, *grads)``: + - ``grads`` is one or more gradients. The number of gradients matches + the number of outputs of the operator. + The ``ctx`` object is `the same ctx object `_ used by + :class:`torch.autograd.Function`. The semantics of ``backward_fn`` are the + same as :meth:`torch.autograd.Function.backward`. + + ``setup_context(ctx, inputs, output)`` runs during the forward pass. + Please save quantities needed for backward onto the ``ctx`` object via + either :meth:`torch.autograd.function.FunctionCtx.save_for_backward` + or assigning them as attributes of ``ctx``. If your custom op has + kwarg-only arguments, we expect the signature of ``setup_context`` + to be ``setup_context(ctx, inputs, keyword_only_inputs, output)``. + + Both ``setup_context_fn`` and ``backward_fn`` must be traceable. That is, + they may not directly access :meth:`torch.Tensor.data_ptr` and they must + not depend on or mutate global state. If you need a non-traceable backward, + you can make it a separate custom_op that you call inside ``backward_fn``. + + If you need different autograd behavior on different devices, then we + recommend creating two different custom operators, one for each device + that needs different behavior, and switching between them at runtime. 
+ + Examples: + >>> import torch + >>> import numpy as np + >>> from torch import Tensor + >>> + >>> @torch.library.custom_op("mylib::numpy_sin", mutates_args=()) + >>> def numpy_sin(x: Tensor) -> Tensor: + >>> x_np = x.cpu().numpy() + >>> y_np = np.sin(x_np) + >>> return torch.from_numpy(y_np).to(device=x.device) + >>> + >>> def setup_context(ctx, inputs, output) -> Tensor: + >>> x, = inputs + >>> ctx.save_for_backward(x) + >>> + >>> def backward(ctx, grad): + >>> x, = ctx.saved_tensors + >>> return grad * x.cos() + >>> + >>> numpy_sin.register_autograd(backward, setup_context=setup_context) + >>> + >>> x = torch.randn(3, requires_grad=True) + >>> y = numpy_sin(x) + >>> (grad_x,) = torch.autograd.grad(y, x, torch.ones_like(y)) + >>> assert torch.allclose(grad_x, x.cos()) + >>> + >>> # Example with a keyword-only arg + >>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=()) + >>> def numpy_mul(x: Tensor, *, val: float) -> Tensor: + >>> x_np = x.cpu().numpy() + >>> y_np = x_np * val + >>> return torch.from_numpy(y_np).to(device=x.device) + >>> + >>> def setup_context(ctx, inputs, keyword_only_inputs, output) -> Tensor: + >>> ctx.val = keyword_only_inputs["val"] + >>> + >>> def backward(ctx, grad): + >>> return grad * ctx.val + >>> + >>> numpy_mul.register_autograd(backward, setup_context=setup_context) + >>> + >>> x = torch.randn(3, requires_grad=True) + >>> y = numpy_mul(x, val=3.14) + >>> (grad_x,) = torch.autograd.grad(y, x, torch.ones_like(y)) + >>> assert torch.allclose(grad_x, torch.full_like(x, 3.14)) + + """ + schema = self._opoverload._schema + if not utils.is_functional_schema(schema, allow_valid_view=True): + raise RuntimeError( + f"Cannot register autograd formula for non-functional operator " + f"{self} with schema {schema}. Please create " + f"a functional operator and register an autograd formula for that." 
+ ) + + self._backward_fn = backward + self._setup_context_fn = setup_context + + def _register_to_dispatcher(self, tags: Sequence[_C.Tag]) -> None: + lib = self._lib + schema_str = self._name + self._schema + cpp_schema = _C.parse_schema(schema_str) + if utils.has_kwarg_only_tensors(cpp_schema): + # If you want to support this, the progression is: + # - supporting kwarg-only Tensors that are non-differentiable + # - supporting kwarg-only Tensors (regardless of differentiability) + raise NotImplementedError( + f"custom_op with kwarg-only Tensor args. Please make your " + f"tensors not kwarg-only. Got: {schema_str}" + ) + + lib.define( + schema_str, + tags=[_C.Tag.pt2_compliant_tag, *tags], + ) + self._opoverload = utils.lookup_op(self._qualname) + + def fake_impl(*args, **kwargs): + if self._abstract_fn is None: + if utils.can_generate_trivial_fake_impl(self._opoverload): + return None + raise RuntimeError( + f"There was no fake impl registered for {self}. " + f"This is necessary for torch.compile/export/fx tracing to work. " + f"Please use `{self._init_fn.__name__}.register_fake` to add an " + f"fake impl." 
+ ) + return self._abstract_fn(*args, **kwargs) + + lib._register_fake(self._name, fake_impl, _stacklevel=4) + + autograd_impl = autograd.make_autograd_impl(self._opoverload, self) + lib.impl(self._name, autograd_impl, "Autograd", with_keyset=True) + schema = self._opoverload._schema + + if schema._is_view_op() or schema.is_mutable: + lib.m.register_ad_inplace_or_view_fallback(self._name) # type: ignore[union-attr] + + if schema.is_mutable: + mutated_idxs, mutated_keys = utils.mutated_args_kwargs(schema) + + original_kernel = torch._C._dispatch_get_computed_kernel_for_dispatch_key( + f"{lib.ns}::{self._name}", "ADInplaceOrView" + ) + + def adinplaceorview_impl(keyset, *args, **kwargs): + # Handle the mutated idx the user gave us explicitly + + for idx in mutated_idxs: + increment_version(args[idx]) + for key in mutated_keys: + increment_version(kwargs[key]) + # Handle view + mutation that are in the schema + return original_kernel.call_boxed(keyset, *args, **kwargs) + + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + message="Warning only once for all operators", + category=UserWarning, + ) + lib.impl( + self._name, + adinplaceorview_impl, + "ADInplaceOrView", + with_keyset=True, + ) + + def _register_backend_select_dispatcher(self, device_arg_index: int): + """ + Switch on the device argument to select the correct backend to dispatch to. + """ + + def backend_select(keyset, *args, **kwargs): + device = args[device_arg_index].type + if device not in self._backend_fns: + raise RuntimeError( + f"{self._name} does not have a kernel registered for {device}. " + "Please use register_kernel to do so." 
+ ) + dispatch_key = _C._dispatch_key_for_device(device) + dispatch_key = getattr(_C.DispatchKey, dispatch_key) + return self._opoverload.redispatch( + _C.DispatchKeySet(dispatch_key), *args, **kwargs + ) + + self._lib.impl(self._name, backend_select, "BackendSelect", with_keyset=True) + + def __call__(self, *args, **kwargs): + return self._opoverload(*args, **kwargs) + + def register_vmap( + self, + func: Optional[Callable] = None, + ): + r"""Register a vmap implementation to support :func:`torch.vmap` for this custom op. + + This API may be used as a decorator. + + In order for an operator to work with :func:`torch.vmap`, you may need to register a + vmap implementation in the following signature: + + ``vmap_func(info, in_dims: Tuple[Optional[int]], *args, **kwargs)``, + + where ``*args`` and ``**kwargs`` are the arguments and kwargs for ``op``. + + It specifies how do we compute the batched version of ``op`` given inputs with an additional + dimension (specified by ``in_dims``). + + For each arg in ``args``, ``in_dims`` has a corresponding ``Optional[int]``. It is ``None`` + if the arg is not a Tensor or if the arg is not being vmapped over, otherwise, it is an integer + specifying what dimension of the Tensor is being vmapped over. + + ``info`` is a collection of additional metadata that may be helpful: + ``info.batch_size`` specifies the size of the dimension being vmapped over, while + ``info.randomness`` is the ``randomness`` option that was passed to :func:`torch.vmap`. + + The return of the function ``func`` is a tuple of ``(output, out_dims)``. Similar to ``in_dims``, + ``out_dims`` should be of the same structure as ``output`` and contain one ``out_dim`` + per output that specifies if the output has the vmapped dimension and what index it is in. 
+ + Examples: + >>> import torch + >>> import numpy as np + >>> from torch import Tensor + >>> from typing import Tuple + >>> + >>> def to_numpy(tensor): + >>> return tensor.cpu().numpy() + >>> + >>> lib = torch.library.Library("mylib", "FRAGMENT") + >>> @torch.library.custom_op("mylib::numpy_cube", mutates_args=()) + >>> def numpy_cube(x: Tensor) -> Tuple[Tensor, Tensor]: + >>> x_np = to_numpy(x) + >>> dx = torch.tensor(3 * x_np ** 2, device=x.device) + >>> return torch.tensor(x_np ** 3, device=x.device), dx + >>> + >>> def numpy_cube_vmap(info, in_dims, x): + >>> result = numpy_cube(x) + >>> return result, (in_dims[0], in_dims[0]) + >>> + >>> numpy_cube.register_vmap(numpy_cube_vmap) + >>> + >>> x = torch.randn(3) + >>> torch.vmap(numpy_cube)(x) + >>> + >>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=()) + >>> def numpy_mul(x: Tensor, y: Tensor) -> Tensor: + >>> return torch.tensor(to_numpy(x) * to_numpy(y), device=x.device) + >>> + >>> @numpy_mul.register_vmap + >>> def numpy_mul_vmap(info, in_dims, x, y): + >>> x_bdim, y_bdim = in_dims + >>> x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1) + >>> y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1) + >>> result = x * y + >>> result = result.movedim(-1, 0) + >>> return result, 0 + >>> + >>> + >>> x = torch.randn(3) + >>> y = torch.randn(3) + >>> torch.vmap(numpy_mul)(x, y) + """ + from torch._functorch.autograd_function import custom_function_call_vmap_helper + from torch._functorch.pyfunctorch import retrieve_current_functorch_interpreter + + def register(func): + need_register = self._vmap_fn is None + self._vmap_fn = func + + if need_register: + + def wrapped_func(keyset, *args, **kwargs): + interpreter = retrieve_current_functorch_interpreter() + return custom_function_call_vmap_helper( + interpreter, self._vmap_fn, self._opoverload, *args, **kwargs + ) + + self._lib.impl( + self._name, wrapped_func, "FuncTorchBatched", with_keyset=True + ) + + if func is 
None: + return register + else: + return register(func) + + def register_autocast( + self, + device_type: str, + cast_inputs: _dtype, + ): + r"""Register an autocast dispatch rule for this custom op. + + Valid `device_type` include: "cpu" and "cuda". + + Args: + op (str | OpOverload): The operator to register an autocast dispatch rule to. + device_type(str): Device type to use. 'cuda' or 'cpu'. + The type is the same as the `type` attribute of a :class:`torch.device`. + Thus, you may obtain the device type of a tensor using `Tensor.device.type`. + cast_inputs (:class:`torch.dtype`): When custom op runs in an autocast-enabled region, + casts incoming floating-point Tensors to the target dtype (non-floating-point Tensors + are not affected), then executes custom op with autocast disabled. + lib (Optional[Library]): If provided, the lifetime of this registration + + Examples:: + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) + >>> import torch + >>> from torch import Tensor + >>> from torch.library import custom_op + >>> + >>> # Create a custom op that works on cuda + >>> @torch.library.custom_op("mylib::my_sin", mutates_args=()) + >>> def my_sin(x: Tensor) -> Tensor: + >>> return torch.sin(x) + >>> + >>> # Register autocast dispatch rule for the cuda device + >>> torch.library.register_autocast("mylib::my_sin", "cuda", torch.float16) + >>> + >>> x = torch.randn(3, dtype=torch.float32, device="cuda") + >>> with torch.autocast("cuda", dtype=torch.float16): + >>> y = torch.ops.mylib.my_sin(x) + >>> assert y.dtype == torch.float16 + + """ + if not isinstance(device_type, str): + raise ValueError( + f"Expected `device_type` of type `str`, got: `{type(device_type)}`" + ) + if device_type not in ["cpu", "cuda"]: + raise ValueError(f"Unknown device type: {device_type}") + + need_register_cuda = self._autocast_cuda_dtype is None + need_register_cpu = self._autocast_cpu_dtype is None + if device_type == "cuda": + self._autocast_cuda_dtype = cast_inputs + else: + 
self._autocast_cpu_dtype = cast_inputs + + def kernel(_, *args, **kwargs): + assert len(kwargs) == 0, "Custom ops do not support kwargs yet." + autocast_keyset = torch._C.DispatchKeySet( + torch._C.DispatchKey.AutocastCPU + ) | torch._C.DispatchKeySet(torch._C.DispatchKey.AutocastCUDA) + with torch._C._ExcludeDispatchKeyGuard(autocast_keyset): + return self._opoverload(*_cast(args, device_type, cast_inputs)) + + if need_register_cuda and self._autocast_cuda_dtype: + self._lib.impl(self._name, kernel, "AutocastCUDA", with_keyset=True) + elif need_register_cpu and self._autocast_cpu_dtype: + self._lib.impl(self._name, kernel, "AutocastCPU", with_keyset=True) + + return kernel + + +# TODO: Merge this function with torch.amp.autocast_mode._cast, and refactor it +# into a utility function once custom ops support arbitrary input types. +def _cast(value, device_type: str, dtype: _dtype): + if isinstance(value, torch.Tensor): + is_eligible = ( + value.is_floating_point() + and value.device.type == device_type + and (value.dtype is not torch.float64) + ) + return value.to(dtype) if is_eligible else value + elif isinstance(value, (str, bytes)): + return value + elif isinstance(value, collections.abc.Iterable): + iterable = (_cast(v, device_type, dtype) for v in value) + if isinstance(value, (list, tuple)): + return type(value)(iterable) + else: + return iterable + else: + return value + + +def increment_version(val: Any) -> None: + if isinstance(val, Tensor): + torch.autograd.graph.increment_version(val) + elif isinstance(val, (tuple, list)): + for v in val: + if isinstance(v, Tensor): + torch.autograd.graph.increment_version(v) + + +# NOTE: [Supporting decorator and non-decorator usage] +# +# Some APIs may be both used as a decorator and not as a decorator. 
+# For example: +# +# >>> def fn(x): +# >>> return x.sin() +# >>> +# >>> # Usage 1: not as a decorator +# >>> numpy_sin.register_kernel("cuda", fn) +# >>> +# >>> # Usage 2: as a decorator +# >>> @numpy_sin.register_kernel("cuda") +# >>> def fn2(x): +# >>> return x.sin +# +# The way we support this is that `register_kernel` accepts an optional `fn`. +# If `fn` is provided (Usage 1), then we know that the user is using it not +# as a decorator. +# If `fn` is not provided (Usage 2), then `register_kernel` needs to return a +# decorator. + + +OPDEF_TO_LIB: dict[str, "torch.library.Library"] = {} +OPDEFS: weakref.WeakValueDictionary = weakref.WeakValueDictionary() + + +def get_library_allowing_overwrite( + namespace: str, name: str +) -> "torch.library.Library": + qualname = f"{namespace}::{name}" + + if qualname in OPDEF_TO_LIB: + OPDEF_TO_LIB[qualname]._destroy() + del OPDEF_TO_LIB[qualname] + + lib = torch.library.Library(namespace, "FRAGMENT") # noqa: TOR901 + OPDEF_TO_LIB[qualname] = lib + return lib + + +def _maybe_get_opdef( + op: Union[CustomOpDef, _ops.OpOverload, str], +) -> Optional[CustomOpDef]: + if isinstance(op, CustomOpDef): + return op + if isinstance(op, _ops.OpOverload): + op = op._name + assert isinstance(op, str) + if op in OPDEFS: + return OPDEFS[op] + return None diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_library/effects.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_library/effects.py new file mode 100644 index 0000000000000000000000000000000000000000..e69c361789b5da9c22f22629ac334f274488dfc6 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_library/effects.py @@ -0,0 +1,84 @@ +from enum import Enum +from typing import Optional + +import torch + + +class EffectType(Enum): + ORDERED = "Ordered" + + +from torch._library.utils import RegistrationHandle + + +# These classes do not have side effects as they just store quantization +# params, so we dont need 
to mark them as ordered +skip_classes = ( + "__torch__.torch.classes.quantized.Conv2dPackedParamsBase", + "__torch__.torch.classes.quantized.Conv3dPackedParamsBase", + "__torch__.torch.classes.quantized.EmbeddingPackedParamsBase", + "__torch__.torch.classes.quantized.LinearPackedParamsBase", + "__torch__.torch.classes.xnnpack.Conv2dOpContext", + "__torch__.torch.classes.xnnpack.LinearOpContext", + "__torch__.torch.classes.xnnpack.TransposeConv2dOpContext", +) + + +class EffectHolder: + """A holder where one can register an effect impl to.""" + + def __init__(self, qualname: str): + self.qualname: str = qualname + self._set_default_effect() + + def _set_default_effect(self) -> None: + self._effect: Optional[EffectType] = None + + # If the op contains a ScriptObject input, we want to mark it as having effects + namespace, opname = torch._library.utils.parse_namespace(self.qualname) + split = opname.split(".") + if len(split) > 1: + assert len(split) == 2, ( + f"Tried to split {opname} based on '.' but found more than 1 '.'" + ) + opname, overload = split + else: + overload = "" + + if namespace == "higher_order": + return + + opname = f"{namespace}::{opname}" + if torch._C._get_operation_overload(opname, overload) is not None: + # Since we call this when destroying the library, sometimes the + # schema will be gone already at that time. + schema = torch._C._get_schema(opname, overload) + for arg in schema.arguments: + if isinstance(arg.type, torch.ClassType): + type_str = arg.type.str() # pyrefly: ignore[missing-attribute] + if type_str in skip_classes: + continue + self._effect = EffectType.ORDERED + return + + @property + def effect(self) -> Optional[EffectType]: + return self._effect + + @effect.setter + def effect(self, _): + raise RuntimeError("Unable to directly set kernel.") + + def register(self, effect: Optional[EffectType]) -> RegistrationHandle: + """Register an effect + + Returns a RegistrationHandle that one can use to de-register this + effect. 
+ """ + self._effect = effect + + def deregister_effect(): + self._set_default_effect() + + handle = RegistrationHandle(deregister_effect) + return handle diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_library/fake_class_registry.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_library/fake_class_registry.py new file mode 100644 index 0000000000000000000000000000000000000000..57342a752a84b0bb10320a3eb97267d05a078ed4 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_library/fake_class_registry.py @@ -0,0 +1,417 @@ +# mypy: allow-untyped-defs +import copy +import logging +from typing import Any, Optional, Protocol, Union + +import torch +from torch._library.utils import parse_namespace +from torch.utils._python_dispatch import _disable_current_modes + + +log = logging.getLogger(__name__) + + +class FakeScriptObject: + def __init__( + self, wrapped_obj: Any, script_class_name: str, x: Optional[torch.ScriptObject] + ): + # Use object.__setattr__ to bypass our custom __setattr__ during initialization + object.__setattr__(self, "wrapped_obj", wrapped_obj) + object.__setattr__(self, "script_class_name", script_class_name) + try: + with _disable_current_modes(): + real_obj = copy.deepcopy(x) + except RuntimeError as e: + log.warning( # noqa: G200 + "Unable to deepcopy the custom object %s due to %s. " + "Defaulting to the user given object. This might be " + "dangerous as side effects may be directly applied " + "to the object.", + script_class_name, + str(e), + ) + real_obj = x + object.__setattr__(self, "real_obj", real_obj) + + def __getattribute__(self, name): + try: + return super().__getattribute__(name) + except AttributeError as e: + raise AttributeError( + f"Tried to call __getattr__ with attr '{name}' on a FakeScriptObject, " + "implying that you are calling this inside of a fake kernel. 
" + "The fake kernel should not depend on the contents of the " + "OpaqueObject at all, so we're erroring out. If you need this" + "functionality, consider creating a custom TorchBind Object instead" + "(but note that this is more difficult)." + ) from e + + def __setattr__(self, name, value): + raise AttributeError( + f"Tried to call __setattr__ with attr '{name}' on a FakeScriptObject, " + "implying that you are calling this inside of a fake kernel. " + "The fake kernel should not depend on the contents of the " + "OpaqueObject at all, so we're erroring out. If you need this" + "functionality, consider creating a custom TorchBind Object instead" + "(but note that this is more difficult)." + ) + + +class FakeScriptMethod: + def __init__( + self, + self_fake_obj: FakeScriptObject, + method_name: str, + schema: Optional[torch.FunctionSchema], + ): + self.self_fake_obj = self_fake_obj + self.method_name = method_name + self.schema = schema + + def __call__(self, *args, **kwargs): + from torch._higher_order_ops.torchbind import call_torchbind + + return call_torchbind(self.self_fake_obj, self.method_name, *args, **kwargs) + + +class HasStaticMethodFromReal(Protocol): + @classmethod + def from_real(cls, real_obj: torch.ScriptObject): + pass + + +class FakeClassRegistry: + def __init__(self) -> None: + self._registered_class: dict[str, Any] = {} + + def has_impl(self, full_qualname: str) -> bool: + return full_qualname in self._registered_class + + def get_impl(self, full_qualname: str) -> Any: + self._check_registered(full_qualname) + return self._registered_class[full_qualname] + + def register(self, full_qualname: str, fake_class=None) -> None: + if self.has_impl(full_qualname): + log.warning( + "%s is already registered. 
Previous fake class is overridden with %s.", + full_qualname, + fake_class, + ) + self._registered_class[full_qualname] = fake_class + + def deregister(self, full_qualname: str) -> Any: + if not self.has_impl(full_qualname): + log.warning( + "Cannot deregister %s. Please use register_fake_class to register it first." + " Or do you dereigster it twice?", + full_qualname, + ) + else: + return self._registered_class.pop(full_qualname) + + def clear(self) -> None: + self._registered_class.clear() + + def _check_registered(self, full_qualname: str) -> None: + if full_qualname not in self._registered_class: + raise RuntimeError( + f"{full_qualname} is not registered. Please use register_fake_class to register it first." + ) + + +global_fake_class_registry = FakeClassRegistry() + + +# TODO: add this check at compile time for __obj_flatten__. +def _check_valid_flat_script_obj(flat_x): + if not isinstance(flat_x, tuple): + raise RuntimeError("Expect flat x to be a tuple.") + + for tp in flat_x: + if not isinstance(tp, tuple): + raise RuntimeError("Expect flat x to be a tuple of tuples.") + + if not len(tp) == 2 or not isinstance(tp[0], str): + raise RuntimeError( + "Expect element of flat x to be a tuple of two elements with first element being a string" + ) + + +def tracing_with_real(x: torch.ScriptObject) -> bool: + if not hasattr(x, "tracing_mode"): + return False + + assert x.tracing_mode() in [ + "real", + "fake", + ], f"tracing_mode can be either real or fake but got {x.tracing_mode()}" + return x.tracing_mode() == "real" + + +def maybe_to_fake_obj( + fake_mode, + x: Any, +) -> Union[FakeScriptObject, torch.ScriptObject]: + import torch.utils._pytree as pytree + from torch.utils._python_dispatch import _disable_current_modes + + # When tracing with real mode, people should implement meta kernels that can + # handle the case of real script object + fake tensor inputs. 
+ if tracing_with_real(x): + return x + + from torch._library.opaque_object import ( + FakeOpaqueObject, + get_opaque_type_name, + is_opaque_type, + OpaqueTypeStr, + ) + + if x is None or is_opaque_type(type(x)): + # In order to make OpaqueObjects truly opaque, the fake kernel should + # not depend on the contents of the OpaqueObject at all. + type_name = OpaqueTypeStr if x is None else get_opaque_type_name(type(x)) + fake_x_wrapped = FakeScriptObject(FakeOpaqueObject(), type_name, None) + return fake_x_wrapped + else: + # x.__obj_flatten__() could be calling some tensor operations inside but we don't + # want to call these ops in surrounding dispatch modes when executing it. + # Otherwise, for example, the fake tensor modes will error out when the tensors inside + # script object execute some operations like clone if allow_non_fake_input flag is set. + with _disable_current_modes(): + flat_x = x.__obj_flatten__() # type: ignore[attr-defined] + + _check_valid_flat_script_obj(flat_x) + + with fake_mode: + from torch._higher_order_ops.utils import _tensor_storage + + storage_map = { + _tensor_storage(inp): i + for i, inp in enumerate(flat_x) + if isinstance(inp, torch.Tensor) + } + alias_map = { + i: storage_map[_tensor_storage(inp)] + for i, inp in enumerate(flat_x) + if isinstance(inp, torch.Tensor) + and storage_map[_tensor_storage(inp)] != i + } + if len(alias_map) > 0: + log.warning( + "Detected script object %s has aliasing relationship among its tensors. " + "Flattened obj: %s. Aliasing tensor indices: %s. " + "This is not supported and may cause unexpected behavior.", + x, + flat_x, + alias_map, + ) + + # This breaks the aliasing relationship among the tensors inside the torchbind object + # This is bad but since we don't need to preserve the aliasing relationship anyway and + # we state clearly that aliasing relationship is not preserved in the doc so this might be OK. 
+ fake_flattened = pytree.tree_map_only( + torch.Tensor, + lambda t: torch.empty_strided( + t.size(), + t.stride(), + device=t.device, + dtype=t.dtype, + requires_grad=t.requires_grad, + layout=t.layout, + ), + flat_x, + ) + + fake_x = _find_fake_class_for_script_object(x).__obj_unflatten__(fake_flattened) + + fake_x_wrapped = FakeScriptObject(fake_x, x._type().qualified_name(), x) # type: ignore[attr-defined] + + for name in x._method_names(): # type: ignore[attr-defined] + attr = getattr(fake_x, name, None) + if attr is not None: + if not callable(attr): + raise RuntimeError(f"Expect {name} to be a callable but got {attr}.") + + real_attr = getattr(x, name) # type: ignore[attr-defined] + + # real attr sometimes is not torch.ScriptMethod thus doesn't have schema e.g. __init___ or __eq__ + method_schema: Optional[torch.FunctionSchema] = None + if isinstance(real_attr, torch.ScriptMethod): + method_schema = real_attr.schema # type: ignore[attr-defined] + + # Bypasses our custom setattr function + object.__setattr__( + fake_x_wrapped, + name, + FakeScriptMethod(fake_x_wrapped, name, method_schema), + ) + else: + override_skip_list = {"__obj_flatten__", "__getstate__", "__setstate__"} + if name not in override_skip_list: + log.warning("fake object of %s doesn't implement method %s.", x, name) + return fake_x_wrapped + + +def register_fake_class(qualname, fake_class: Optional[HasStaticMethodFromReal] = None): + r"""Register a fake implementation for this class. + + It's in the same spirit of registering a fake implementation for + an operator but with the difference that it + associates a fake class with the original torch bind class (registered + with torch::class_). In this way, torch.compile can handle them properly + in components such as Dynamo and AOTAutograd. + + This API may be used as a decorator (see example). For the fake class, users + are required to provide a from_real classmethod that takes a real object and + returns an instance of the fake class. 
All tensors in the fake object should also + be properly fakified with to_fake_tensor() in from_real. + + + Examples: + # For a custom class Foo defined in test_custom_class_registration.cpp: + + TORCH_LIBRARY(_TorchScriptTesting, m) { + m.class_("_TensorQueue") + .def(torch::init()) + .def("push", &TensorQueue::push) + .def("pop", &TensorQueue::pop) + .def("top", &TensorQueue::top) + .def("size", &TensorQueue::size) + .def("clone_queue", &TensorQueue::clone_queue) + .def("__obj_flatten__", &TensorQueue::__obj_flatten__) + .def_pickle( + // __getstate__ + [](const c10::intrusive_ptr& self) + -> c10::Dict { + return self->serialize(); + }, + // __setstate__ + [](c10::Dict data) + -> c10::intrusive_ptr { + return c10::make_intrusive(std::move(data)); + }); + }; + # We could register a fake class FakeTensorQueue in Python as follows: + import torch + + @torch._library.register_fake_class("_TorchScriptTesting::_TensorQueue") + class FakeTensorQueue: + def __init__(self, queue): + self.queue = queue + + @classmethod + def __obj_unflatten__(cls, flattened_ctx): + return cls(**dict(ctx)) + + def push(self, x): + self.queue.append(x) + + def pop(self): + return self.queue.pop(0) + + def size(self): + return len(self.queue) + + In this example, the original TensorQeue need to add a __obj_flatten__ method + to the class TensorQueue and the flattened result is passed into FakeTensorQueue's + __obj_unflatten__ as inputs to create a fake class. This protocol allows pytorch to look + at the contents of the script object and properly handle them in the subsystems + like dynamo, aot_aotugrad or more. + """ + + def inner(fake_class: HasStaticMethodFromReal): + ns, name = parse_namespace(qualname) + + # This also checks whether the referred torch::class_ exists. 
+ torch._C._get_custom_class_python_wrapper(ns, name) + + from_method = getattr(fake_class, _CONVERT_FROM_REAL_NAME, None) + if not from_method: + raise RuntimeError( + f"{fake_class} doesn't define a classmethod {_CONVERT_FROM_REAL_NAME}." + ) + + if not isinstance(fake_class.__dict__[_CONVERT_FROM_REAL_NAME], classmethod): + raise RuntimeError( + f"{_CONVERT_FROM_REAL_NAME} method is not a classmethod." + ) + + global_fake_class_registry.register(_full_qual_class_name(qualname), fake_class) + return fake_class + + if fake_class is None: + return inner + return inner(fake_class) + + +def deregister_fake_class(qualname): + return global_fake_class_registry.deregister(_full_qual_class_name(qualname)) + + +def has_fake_class(full_qualname) -> bool: + return global_fake_class_registry.has_impl(full_qualname) + + +def find_fake_class(full_qualname) -> Optional[Any]: + if not has_fake_class(full_qualname): + return None + return global_fake_class_registry.get_impl(full_qualname) + + +def _full_qual_class_name(qualname: str) -> str: + ns, name = parse_namespace(qualname) + return "__torch__.torch.classes." + ns + "." + name + + +def _is_script_object(obj: Any) -> bool: + return isinstance( + obj, torch.ScriptObject + ) and obj._type().qualified_name().startswith( # type: ignore[attr-defined] + "__torch__.torch.classes" + ) + + +# Return the namespace and class name from fully qualified name. 
+def _ns_and_class_name(full_qualname: str) -> tuple[str, str]: + splits = full_qualname.split(".") + assert len(splits) == 5, f"Could not split {full_qualname=}" + _torch, _torch_ns, _classes, ns, class_name = splits + return ns, class_name + + +def _find_fake_class_for_script_object(x: torch.ScriptObject) -> Any: + full_qualname = x._type().qualified_name() # type: ignore[attr-defined] + ns, class_name = _ns_and_class_name(full_qualname) + fake_class = find_fake_class(full_qualname) + if fake_class is None: + raise RuntimeError( + f" ScriptObject's {full_qualname} haven't registered a fake class." + f" Please use register_fake_class({ns}::{class_name}) to annotate a fake class for the script obj." + f" Specifically, create a python class that implements a fake version for all the methods" + f" that're used in the program and put annotated class in the program e.g. after loading the library." + f" The fake methods can be written in the same way as a meta kernel for an operator but need to additionally" + f" simulate the object's states. Be sure to add a {_CONVERT_FROM_REAL_NAME} classmethod" + f" to enable creating a fake obj from a real one." + ) + return fake_class + + +_CONVERT_FROM_REAL_NAME = "__obj_unflatten__" + + +def _fake_obj_from_real(fake_mode, x) -> Any: + fake_class = _find_fake_class_for_script_object(x) + + from_real_method = getattr(fake_class, _CONVERT_FROM_REAL_NAME, None) + if not from_real_method: + raise RuntimeError( + f"{fake_class} must define a classmethod {_CONVERT_FROM_REAL_NAME}" + f" that converts the real object to the fake object." + ) + + # from_real defined by user need the ctx to fakify the tensor states. 
+ ctx = torch._library.fake_impl.FakeImplCtx(fake_mode, None) + with torch._library.fake_impl.set_ctx_getter(lambda: ctx): + return fake_class.from_real(x) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_library/fake_impl.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_library/fake_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..877ebb0c59122c161ffb789a44484551a0f54cba --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_library/fake_impl.py @@ -0,0 +1,227 @@ +# mypy: allow-untyped-defs +import contextlib +import functools +from collections.abc import Callable +from typing_extensions import deprecated + +import torch +from torch._library.utils import Kernel, RegistrationHandle + + +class FakeImplHolder: + """A holder where one can register an fake impl to.""" + + def __init__(self, qualname: str): + self.qualname: str = qualname + # kernels stores all registered fake kernels, ordered by registration + # time ascendingly (newer registration after older registration). If an + # operator library gets loaded that overrides an existing fake kernel, + # both kernels will be in the list, but the newest one will be the one + # that is run. If the library is unloaded, we will remove the kernel + # from this list. + self.kernels: list[Kernel] = [] + + @property + def kernel(self): + if len(self.kernels) == 0: + return None + return self.kernels[-1] + + @kernel.setter + def kernel(self, value): + raise RuntimeError("Unable to directly set kernel.") + + def register( + self, func: Callable, source: str, lib, *, allow_override=False + ) -> RegistrationHandle: + """Register an fake impl. + + Returns a RegistrationHandle that one can use to de-register this + fake impl. 
+ """ + + if not allow_override: + if self.kernel is not None: + raise RuntimeError( + f"register_fake(...): the operator {self.qualname} " + f"already has an fake impl registered at " + f"{self.kernel.source}." + ) + if torch._C._dispatch_has_kernel_for_dispatch_key(self.qualname, "Meta"): + raise RuntimeError( + f"register_fake(...): the operator {self.qualname} " + f"already has an DispatchKey::Meta implementation via a " + f"pre-existing torch.library or TORCH_LIBRARY registration. " + f"Please either remove that registration or don't call " + f"register_fake." + ) + + if torch._C._dispatch_has_kernel_for_dispatch_key( + self.qualname, "CompositeImplicitAutograd" + ): + raise RuntimeError( + f"register_fake(...): the operator {self.qualname} " + f"already has an implementation for this device type via a " + f"pre-existing registration to " + f"DispatchKey::CompositeImplicitAutograd." + f"CompositeImplicitAutograd operators do not need an fake " + f"impl; " + f"instead, the operator will decompose into its constituents " + f"and those " + f"can have fake impls defined on them." 
+ ) + + # Store the kernel in this holder + kernel = Kernel(func, source) + self.kernels.append(kernel) + + def deregister_fake_kernel(): + self.kernels.remove(kernel) + + meta_kernel = construct_meta_kernel(self.qualname, self) + lib.impl(self.qualname, meta_kernel, "Meta", allow_override=allow_override) + + handle = RegistrationHandle(deregister_fake_kernel) + return handle + + +def construct_meta_kernel(qualname: str, fake_impl_holder: FakeImplHolder) -> Callable: + assert fake_impl_holder.kernel is not None + + @functools.wraps(fake_impl_holder.kernel.func) + def meta_kernel(*args, **kwargs): + assert fake_impl_holder.kernel is not None + source = fake_impl_holder.kernel.source + + def error_on_ctx(): + raise RuntimeError( + f"{qualname} ({source}): You're trying to run this operator " + f"with meta Tensors (as opposed to FakeTensors), but this " + f"operator may return an output Tensor with data-dependent shape. Meta " + f"Tensors don't support operators with outputs that have data-dependent shapes " + f"but FakeTensors do. " + f"If your operator does not return an output with data-dependent shape, " + f"make sure the FakeTensor and/or meta kernel does not call " + f"torch.library.get_ctx(). Otherwise, please use FakeTensors." + ) + + with set_ctx_getter(error_on_ctx): + return fake_impl_holder.kernel(*args, **kwargs) + + return meta_kernel + + +def get_none(): + return None + + +global_ctx_getter: Callable = get_none + + +@contextlib.contextmanager +def set_ctx_getter(ctx_getter): + global global_ctx_getter + prev = global_ctx_getter + try: + global_ctx_getter = ctx_getter + yield + finally: + global_ctx_getter = prev + + +class FakeImplCtx: + """ + Context object for writing fake implementations for custom operators. 
+ """ + + def __init__(self, _fake_mode, _op): + self._fake_mode = _fake_mode + self._shape_env = _fake_mode.shape_env + self._op = _op + + @deprecated( + "`create_unbacked_symint` is deprecated, please use `new_dynamic_size` instead", + category=FutureWarning, + ) + def create_unbacked_symint(self, *, min=2, max=None) -> torch.SymInt: + return self.new_dynamic_size(min=min, max=max) + + def new_dynamic_size(self, *, min=0, max=None) -> torch.SymInt: + """Constructs a new symint (symbolic int) representing a data-dependent value. + + This is useful for writing the fake implementation (which is necessary + for torch.compile) for a CustomOp where an output Tensor has a size + that depends on the data of the input Tensors. + + Args: + min (int): A statically known inclusive lower bound for this symint. Default: 0 + max (Optional[int]): A statically known inclusive upper bound for this + symint. Default: None + + .. warning: + + It is important that the ``min`` and ``max`` (if not None) values are set + correctly, otherwise, there will be undefined behavior under + torch.compile. The default value of ``min`` is 2 due to torch.compile + specializing on 0/1 sizes. + + You must also verify that your implementation on concrete Tensors + (e.g. CPU/CUDA) only returns Tensors where the size that corresponds + to the symint also has respects these constraint. + The easiest way to do this is to add an assertion in the CPU/CUDA/etc + implementation that the size follows these bounds. + + Example:: + + >>> # An operator with data-dependent output shape + >>> lib = torch.library.Library("mymodule", "FRAGMENT") + >>> lib.define("mymodule::custom_nonzero(Tensor x) -> Tensor") + >>> + >>> @torch.library.register_fake("mymodule::custom_nonzero") + >>> def _(x): + >>> # Number of nonzero-elements is data-dependent. + >>> # Since we cannot peek at the data in an fake impl, + >>> # we use the ctx object to construct a new symint that + >>> # represents the data-dependent size. 
+ >>> ctx = torch.library.get_ctx() + >>> nnz = ctx.new_dynamic_size() + >>> shape = [nnz, x.dim()] + >>> result = x.new_empty(shape, dtype=torch.int64) + >>> return result + >>> + >>> @torch.library.impl(lib, "custom_nonzero", "CPU") + >>> def _(x): + >>> x_np = x.numpy() + >>> res = np.stack(np.nonzero(x_np), axis=1) + >>> return torch.tensor(res, device=x.device) + + """ + if ( + self._shape_env is None + or not self._shape_env.allow_dynamic_output_shape_ops + ): + raise torch._subclasses.fake_tensor.DynamicOutputShapeException(self._op) + + if isinstance(min, torch.SymInt) or isinstance(max, torch.SymInt): + raise ValueError( + f"ctx.new_dynamic_size(min={min}, max={max}): expected " + f"min and max to be statically known ints but got SymInt. " + f"This is not supported." + ) + + if min < 0: + raise ValueError( + f"ctx.new_dynamic_size(min={min}, ...): expected min to be " + f"greater than or equal to 0: this API can only create " + f"non-negative sizes." + ) + + return allocate_size(self._shape_env, min, max) + + +def allocate_size(shape_env, min_val=0, max_val=None): + result = shape_env.create_unbacked_symint() + torch.fx.experimental.symbolic_shapes._constrain_range_for_size( + result, min=min_val, max=max_val + ) + return result diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_library/fake_profile.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_library/fake_profile.py new file mode 100644 index 0000000000000000000000000000000000000000..984a996b90dc118f4d5fa1314b752f401d1af2c1 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_library/fake_profile.py @@ -0,0 +1,325 @@ +import contextlib +import io +import logging +import os +from collections.abc import Callable, Generator +from dataclasses import dataclass +from typing import Any, Optional, Union + +import torch +from torch._library.custom_ops import _maybe_get_opdef +from torch.types import FileLike + + +log = 
logging.getLogger(__name__) + + +class MissingOpProfile(RuntimeError): + """ + This is raised when we don't have an operator profile available for the + given inputs. + """ + + +@dataclass(frozen=True) +class TensorMetadata: + rank: int + dtype: torch.dtype + device: torch.device + layout: torch.layout + + @staticmethod + def maybe_from_tensor(t: Any) -> Optional["TensorMetadata"]: + if not isinstance(t, torch.Tensor): + return None + return TensorMetadata(t.dim(), t.dtype, t.device, t.layout) + + +@dataclass(frozen=True) +class OpProfile: + args_profile: tuple[Optional[TensorMetadata]] + out_profile: Union[TensorMetadata, tuple[TensorMetadata]] + + +def _generate_fake_kernel(op_name: str, op_profile: set[OpProfile]) -> Callable: + def _match_args(args_profile: tuple[Optional[TensorMetadata]], args: Any) -> bool: + return all( + TensorMetadata.maybe_from_tensor(arg) == args_profile[i] + for i, arg in enumerate(args) + ) + + def _generate_res( + out_profile: Union[TensorMetadata, tuple[TensorMetadata]], + ) -> Union[torch.Tensor, list[torch.Tensor]]: + ctx = torch.library.get_ctx() + + def _generate_tensor_out(t: TensorMetadata) -> torch.Tensor: + fake_shape = [ctx.new_dynamic_size() for _ in range(t.rank)] + fake_strides = [-1] * t.rank + expected = 1 + fake_stride = expected + for i in range(t.rank): + fake_strides[i] = fake_stride # type: ignore[assignment] + fake_stride = fake_stride * fake_shape[i] # type: ignore[assignment] + + return torch.empty_strided( + fake_shape, + fake_strides, + device=t.device, + dtype=t.dtype, + layout=t.layout, + ) + + if isinstance(out_profile, TensorMetadata): + return _generate_tensor_out(out_profile) + else: + return [_generate_tensor_out(t) for t in out_profile] + + def _fake_kernel(*args, **kwargs): # type: ignore[no-untyped-def] + for profile in op_profile: + if _match_args(profile.args_profile, (*args, *kwargs.values())): + return _generate_res(profile.out_profile) + + raise MissingOpProfile( + f"No fake kernel was found for 
{op_name}, and although we have " + "previously registered some profiles to generate a fake kernel, " + f"no profiles match the given inputs: {args, kwargs}." + ) + + return _fake_kernel + + +@contextlib.contextmanager +def unsafe_generate_fake_kernels(op_profiles: dict[str, set[OpProfile]]) -> Generator: + """ + Registers a fake kernel based on the given operator profiles. This fake + kernel registration will override any existing fake kernel registrations. + + The input is a dictionary mapping operator names to a set of operator + profiles, which we will use to generate fake kernels. The operator profiles + are a record of the input and output tensor metadata. Based on this + information we will match a given input to the recorded profile, and return + an output with the same metadata as in the recorded profile. If a profile + doesn't exist then an exception will be thrown. + + The fake kernel generation is considered unsafe because it relies on the + rigid, pre-defined operator profiles that do not account for potential + variations in output behavior. Specifically, the generated kernels assume a + fixed relationship between input and output ranks. However, in reality, it's + possible that data-dependent operations may produce outputs of different + ranks even when given inputs of the same rank. The generated fake kernels + are inflexible and unable to accommodate these nuances, making them + potentially unsafe. + + Args: + op_profiles (dict[str, set[OpProfile]]): A dictionary mapping operator + name to a set of operator profiles from which we will generate fake + kernels. 
+ + Examples: + + >>> # Example: Registering an op-profile from draft-export + >>> import torch + >>> from torch.export._draft_export import draft_export + >>> + >>> @torch.library.custom_op("mylib::foo", mutates_args=()) + >>> def foo(x: Tensor, y: Tensor) -> Tensor: + >>> return x + y + >>> + >>> class M(torch.nn.Module): + >>> def forward(self, a, b): + >>> res = torch.ops.mylib.foo(a, b) # no fake impl + >>> return res + >>> + >>> ep = draft_export(M(), (torch.ones(3, 4), torch.ones(3, 4)) + >>> + >>> with torch._library.fake_profile.unsafe_generate_fake_kernels(ep._report.op_profiles): + >>> decomp = ep.run_decompositions() + + """ + + libs: list[torch.library.Library] = [] + # Stores old fake impls from custom ops declared through @custom_op + old_fake_impls: dict[str, Callable] = {} + for op_name, profiles in op_profiles.items(): + log.warning( + "Registering fake profile for %s. This will override any existing " + "fake kernel registration.", + op_name, + ) + + op_name_split = op_name.split(".") + namespace, op_name_str = op_name_split[0], op_name_split[1] + op_str = f"{namespace}::{op_name_str}" + + fake_kernel = _generate_fake_kernel(op_str, profiles) + + if opdef := _maybe_get_opdef(op_str): + # If the op is a CustomOpDef, save the existing abstract_fn so that + # we can restore it after this contextmanager + if opdef._abstract_fn is not None: + old_fake_impls[op_str] = opdef._abstract_fn + opdef.register_fake(fake_kernel) + + else: + # Create a new library so that we can register a new fake impl. + # These libraries will then be destroyed after the contextmanager, + # which will automatically restore the previously registered fake + # impls. 
+ newlib = torch.library.Library(namespace, "FRAGMENT") # noqa: TOR901 + torch.library.register_fake( + op_str, fake_kernel, lib=newlib, allow_override=True + ) + libs.append(newlib) + + try: + yield libs + finally: + # Destroying the libraries will automatically restore the previously + # registered fake impls + for lib in libs: + lib._destroy() + + # Restore abstract_fns for CustomOpDefs + for op_str, old_fake in old_fake_impls.items(): + opdef = _maybe_get_opdef(op_str) + assert opdef is not None + opdef.register_fake(old_fake) + + +def get_torch_version() -> str: + version = torch.__version__.split(".") + return f"{int(version[0])}.{int(version[1])}" + + +def generate_yaml_from_profiles(op_profiles: dict[str, set[OpProfile]]) -> str: + """ + Generates a yaml string from the given operator profiles which can be saved + to a file. The yaml string can be loaded back into an operator profile + structure using `read_profiles_from_yaml`. + """ + + import yaml + + from torch._export.serde.serialize import ( + _TORCH_TO_SERIALIZE_DTYPE, + _TORCH_TO_SERIALIZE_LAYOUT, + ) + + def serialize_tensor_metadata(t: TensorMetadata) -> dict: + return { + "rank": t.rank, + "dtype": _TORCH_TO_SERIALIZE_DTYPE[t.dtype].value, + "device": str(t.device), + "layout": _TORCH_TO_SERIALIZE_LAYOUT[t.layout].value, + } + + def serialize_op_profile(op: OpProfile) -> dict: + return { + "args_profile": [ + serialize_tensor_metadata(arg) + for arg in op.args_profile + if arg is not None + ], + "out_profile": ( + serialize_tensor_metadata(op.out_profile) + if isinstance(op.out_profile, TensorMetadata) + else [serialize_tensor_metadata(out) for out in op.out_profile] + ), + } + + serialized_data = { + operator: [serialize_op_profile(profile) for profile in profiles] + for operator, profiles in op_profiles.items() + } + return yaml.dump( + {"torch_version": get_torch_version(), "operators": serialized_data}, + sort_keys=False, + ) + + +def save_op_profiles(op_profiles: dict[str, set[OpProfile]], f: 
FileLike) -> None: + """ + Serializes the given operator profiles into a yaml format and saves it to + the given file. The operator profile can be loaded back using `load_op_profiles`. + """ + yaml_str = generate_yaml_from_profiles(op_profiles) + + if isinstance(f, (str, os.PathLike)): + f = os.fspath(f) + + with open(f, "w") as file: + file.write(yaml_str) + + elif isinstance(f, io.BytesIO): + f.write(yaml_str.encode("utf-8")) + + else: + raise ValueError(f"Invalid type of file {f}") + + +def read_profiles_from_yaml(yaml_str: str) -> dict[str, set[OpProfile]]: + """ + Reads the yaml saved by `save_op_profiles` and returns the operator profiles. + """ + + import yaml + + from torch._export.serde.serialize import ( + _SERIALIZE_TO_TORCH_DTYPE, + _SERIALIZE_TO_TORCH_LAYOUT, + ) + + def deserialize_tensor_metadata(data: dict) -> TensorMetadata: + return TensorMetadata( + rank=data["rank"], + dtype=_SERIALIZE_TO_TORCH_DTYPE[data["dtype"]], + device=torch.device(data["device"]), + layout=_SERIALIZE_TO_TORCH_LAYOUT[data["layout"]], + ) + + def deserialize_op_profile(data: dict) -> OpProfile: + args_profile = tuple( + deserialize_tensor_metadata(arg) for arg in data["args_profile"] + ) + out_profile_data = data["out_profile"] + out_profile: Union[tuple[TensorMetadata], TensorMetadata] = ( + tuple(deserialize_tensor_metadata(out) for out in out_profile_data) # type: ignore[assignment] + if isinstance(out_profile_data, list) + else deserialize_tensor_metadata(out_profile_data) + ) + return OpProfile(args_profile=args_profile, out_profile=out_profile) # type: ignore[arg-type] + + loaded_data = yaml.safe_load(yaml_str) + loaded_torch_version = loaded_data["torch_version"] + + if loaded_torch_version != get_torch_version(): + raise RuntimeError( + "Unable to load outdated profile. 
It was saved with torch version: " + f"{loaded_torch_version} but the current torch version is: {get_torch_version()}" + ) + + operators_data = loaded_data["operators"] + return { + operator: {deserialize_op_profile(profile) for profile in profiles} + for operator, profiles in operators_data.items() + } + + +def load_op_profiles(f: FileLike) -> dict[str, set[OpProfile]]: + """ + Loads the saved operator profiles from `save_op_profiles`. + """ + if isinstance(f, (str, os.PathLike)): + f = os.fspath(f) + + with open(f) as file: + yaml_str = file.read() + + elif isinstance(f, io.BytesIO): + yaml_str = f.read().decode("utf-8") + + else: + raise ValueError(f"Invalid type of file {f}") + + return read_profiles_from_yaml(yaml_str) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_library/infer_schema.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_library/infer_schema.py new file mode 100644 index 0000000000000000000000000000000000000000..5e81f8e6fd4e71e6c794ec3068185801daa4abfa --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_library/infer_schema.py @@ -0,0 +1,353 @@ +# mypy: allow-untyped-defs +import collections +import inspect +import typing +from types import GenericAlias +from typing import Optional, Union + +import torch +from torch import device, dtype, Tensor, types +from torch.utils._exposed_in import exposed_in + +from .opaque_object import _OPAQUE_TYPES, is_opaque_type + + +# This is used as a negative test for +# test_custom_ops.py::TestTypeConversion::test_type_eval. +_TestTensor = torch.Tensor + + +@exposed_in("torch.library") +def infer_schema( + prototype_function: typing.Callable, + /, + *, + mutates_args, + op_name: Optional[str] = None, +) -> str: + r"""Parses the schema of a given function with type hints. The schema is inferred from the + function's type hints, and can be used to define a new operator. 
+ + We make the following assumptions: + + * None of the outputs alias any of the inputs or each other. + * | String type annotations "device, dtype, Tensor, types" without library specification are + | assumed to be torch.*. Similarly, string type annotations "Optional, List, Sequence, Union" + | without library specification are assumed to be typing.*. + * | Only the args listed in ``mutates_args`` are being mutated. If ``mutates_args`` is "unknown", + | it assumes that all inputs to the operator are being mutates. + + Callers (e.g. the custom ops API) are responsible for checking these assumptions. + + Args: + prototype_function: The function from which to infer a schema for from its type annotations. + op_name (Optional[str]): The name of the operator in the schema. If ``name`` is None, then the + name is not included in the inferred schema. Note that the input schema to + ``torch.library.Library.define`` requires a operator name. + mutates_args ("unknown" | Iterable[str]): The arguments that are mutated in the function. + + Returns: + The inferred schema. + + Example: + >>> def foo_impl(x: torch.Tensor) -> torch.Tensor: + >>> return x.sin() + >>> + >>> infer_schema(foo_impl, op_name="foo", mutates_args={}) + foo(Tensor x) -> Tensor + >>> + >>> infer_schema(foo_impl, mutates_args={}) + (Tensor x) -> Tensor + """ + UNKNOWN_MUTATES = "unknown" + pf_globals = prototype_function.__globals__ + pf_locals = None + # TODO: Once our minimum version is py3.10+ pass `eval_str=True` to + # inspect.signature() and we no longer need to deal with stringified + # annotations below. + sig = inspect.signature(prototype_function) + + def error_fn(what): + raise ValueError(f"infer_schema(func): {what} Got func with signature {sig})") + + def convert_type_string(annotation_type: str): + try: + return eval(annotation_type, pf_globals, pf_locals) + except Exception: + error_fn( + f"Unsupported type annotation {annotation_type}. It is not a type." 
+ ) + + def unstringify_types( + tys: tuple[Union[type[object], str], ...], + ) -> tuple[tuple[typing.Any, ...], bool]: + res = [] + changed = False + for ty in tys: + ty, ty_changed = unstringify_type(ty) + res.append(ty) + changed |= ty_changed + if changed: + return tuple(res), True + else: + return tys, False # type: ignore[return-value] + + def unstringify_type(ty: Union[type[object], str]) -> tuple[typing.Any, bool]: + # Dig through a generic type and if it contains a stringified type + # convert that to a real type. The second return value indicates if the + # type contained a string or not. + if isinstance(ty, str): + return convert_type_string(ty), True + elif origin := typing.get_origin(ty): + args, args_changed = unstringify_types(typing.get_args(ty)) + if args_changed: + return GenericAlias(origin, args), True + + return ty, False + + params = [] + seen_args = set() + saw_kwarg_only_arg = False + for idx, (name, param) in enumerate(sig.parameters.items()): + if not supported_param(param): + error_fn("We do not support positional-only args, varargs, or varkwargs.") + + if param.kind == inspect.Parameter.KEYWORD_ONLY: + # The first time we see a kwarg-only arg, add "*" to the schema. + if not saw_kwarg_only_arg: + params.append("*") + saw_kwarg_only_arg = True + + if param.annotation is inspect.Parameter.empty: + error_fn(f"Parameter {name} must have a type annotation.") + + # The annotation might be converted to a string by annotation, + # we convert it to the actual type. + annotation_type, _ = unstringify_type(param.annotation) + + schema_type = None + if annotation_type not in SUPPORTED_PARAM_TYPES: + if is_opaque_type(annotation_type): + schema_type = _OPAQUE_TYPES[annotation_type].class_name + elif annotation_type == torch._C.ScriptObject: + error_fn( + f"Parameter {name}'s type cannot be inferred from the schema " + "as it is a ScriptObject. Please manually specify the schema " + "using the `schema=` kwarg with the actual type of the ScriptObject." 
+ ) + elif ( + hasattr(annotation_type, "__origin__") + and annotation_type.__origin__ is tuple + ): + list_type = tuple_to_list(annotation_type) + example_type_str = "\n\n" + # Only suggest the list type if this type is supported. + if list_type in SUPPORTED_PARAM_TYPES: + example_type_str = f"For example, {list_type}.\n\n" + error_fn( + f"Parameter {name} has unsupported type {param.annotation}. " + f"We do not support Tuple inputs in schema. As a workaround, please try to use List instead. " + f"{example_type_str}" + f"The valid types are: {SUPPORTED_PARAM_TYPES.keys()}." + ) + else: + error_fn( + f"Parameter {name} has unsupported type {param.annotation}. " + f"The valid types are: {SUPPORTED_PARAM_TYPES.keys()}." + ) + else: + schema_type = SUPPORTED_PARAM_TYPES[annotation_type] + + assert schema_type is not None + + if type(mutates_args) is str: + if mutates_args != UNKNOWN_MUTATES: + raise ValueError( + "mutates_args must either be a sequence of the names of " + "the arguments that are mutated or the string 'unknown'. " + ) + if schema_type.startswith("Tensor"): + schema_type = f"Tensor(a{idx}!){schema_type[len('Tensor') :]}" + elif name in mutates_args: + if not schema_type.startswith("Tensor"): + error_fn( + f"Parameter {name} is in mutable_args but only Tensors or collections of Tensors can be mutated" + ) + schema_type = f"Tensor(a{idx}!){schema_type[len('Tensor') :]}" + seen_args.add(name) + if param.default is inspect.Parameter.empty: + # pyrefly: ignore [bad-argument-type] + params.append(f"{schema_type} {name}") + else: + default_repr = None + if param.default is None or isinstance(param.default, (int, float, bool)): + default_repr = str(param.default) + elif isinstance(param.default, (str, torch.device)): + default_repr = f'"{param.default}"' + elif isinstance(param.default, torch.dtype): + dtype_repr = str(param.default) + torch_dot = "torch." 
+ assert dtype_repr.startswith(torch_dot) + default_repr = dtype_repr[len(torch_dot) :] + else: + error_fn( + f"Parameter {name} has an unsupported default value type {type(param.default)}. " + f"Please file an issue on GitHub so we can prioritize this." + ) + # pyrefly: ignore [bad-argument-type] + params.append(f"{schema_type} {name}={default_repr}") + if mutates_args != UNKNOWN_MUTATES: + mutates_args_not_seen = set(mutates_args) - seen_args + if len(mutates_args_not_seen) > 0: + error_fn( + f"{mutates_args_not_seen} in mutates_args were not found in " + f"the custom op's signature. " + f"mutates_args should contain the names of all args that the " + f"custom op mutates, or just the string 'unknown' if you don't know." + ) + return_annotation, _ = unstringify_type(sig.return_annotation) + ret = parse_return(return_annotation, error_fn) + if op_name is not None: + return f"{op_name}({', '.join(params)}) -> {ret}" + return f"({', '.join(params)}) -> {ret}" + + +def derived_types( + base_type: Union[type, typing._SpecialForm], + cpp_type: str, + list_base: bool, + optional_base_list: bool, + optional_list_base: bool, +): + result: list[tuple[Union[type, typing._SpecialForm, GenericAlias], str]] = [ + (base_type, cpp_type), + # pyrefly: ignore [not-a-type] + (typing.Optional[base_type], f"{cpp_type}?"), + ] + + def derived_seq_types(typ: Union[type, typing._SpecialForm]): + return ( + typing.Sequence[typ], # type: ignore[valid-type] # noqa: UP006 + typing.List[typ], # type: ignore[valid-type] # noqa: UP006 + GenericAlias(collections.abc.Sequence, (typ,)), + GenericAlias(list, (typ,)), + ) + + if list_base: + result.extend( + (seq_typ, f"{cpp_type}[]") for seq_typ in derived_seq_types(base_type) + ) + if optional_base_list: + result.extend( + (seq_typ, f"{cpp_type}?[]") + # pyrefly: ignore [not-a-type] + for seq_typ in derived_seq_types(typing.Optional[base_type]) + ) + if optional_list_base: + result.extend( + (typing.Optional[seq_typ], f"{cpp_type}[]?") + for 
seq_typ in derived_seq_types(base_type) + ) + return result + + +def get_supported_param_types(): + # pyrefly: ignore [bad-assignment] + data: list[tuple[Union[type, typing._SpecialForm], str, bool, bool, bool]] = [ + # (python type, schema type, type[] variant, type?[] variant, type[]? variant + (Tensor, "Tensor", True, True, False), + (int, "SymInt", True, False, True), + (float, "float", True, False, True), + (bool, "bool", True, False, True), + (str, "str", False, False, False), + (types.Number, "Scalar", True, False, False), + (dtype, "ScalarType", False, False, False), + (device, "Device", False, False, False), + ] + + if torch.distributed.is_available(): + from torch.distributed.distributed_c10d import GroupName + + data.append((typing.cast(type, GroupName), "str", False, False, False)) + + result = [] + for line in data: + result.extend(derived_types(*line)) + return dict(result) + + +SUPPORTED_RETURN_TYPES = { + Tensor: "Tensor", + typing.List[Tensor]: "Tensor[]", # noqa: UP006 + list[Tensor]: "Tensor[]", + int: "SymInt", + float: "float", + bool: "bool", + types.Number: "Scalar", +} + + +def parse_return(annotation, error_fn): + if annotation is None: + return "()" + + if annotation is inspect.Parameter.empty: + error_fn("No return type annotation was provided. Please add one.") + + origin = typing.get_origin(annotation) + if origin is not tuple: + if annotation not in SUPPORTED_RETURN_TYPES: + error_fn( + f"Return has unsupported type {annotation}. " + f"The valid types are: {SUPPORTED_RETURN_TYPES}." + ) + # pyrefly: ignore [index-error] + return SUPPORTED_RETURN_TYPES[annotation] + + args = typing.get_args(annotation) + for arg in args: + if arg not in SUPPORTED_RETURN_TYPES: + error_fn( + f"Return has unsupported type {annotation}. " + f"The valid types are: {SUPPORTED_RETURN_TYPES}." 
+ ) + output_ty = ", ".join([SUPPORTED_RETURN_TYPES[arg] for arg in args]) + + # use (()) to represent tuple with single element + if len(args) == 1: + output_ty = "(" + output_ty + ")" + return "(" + output_ty + ")" + + +SUPPORTED_PARAM_TYPES = get_supported_param_types() + + +def supported_param(param: inspect.Parameter) -> bool: + return param.kind in ( + inspect.Parameter.POSITIONAL_OR_KEYWORD, + inspect.Parameter.KEYWORD_ONLY, + ) + + +def tuple_to_list(tuple_type: type[tuple]) -> type[list]: + """ + Convert `tuple_type` into a list type with the same type arguments. Assumes that `tuple_type` is typing.Tuple type. + """ + type_args = getattr(tuple_type, "__args__", None) + # Account for different python versions, e.g. python 3.8 would give () + # but python 3.12 would give None. + if ( + tuple_type is typing.Tuple # noqa: UP006 + or tuple_type is tuple + or type_args == () + or type_args is None + ): + # Handle the case of an empty tuple type + return list + elif len(type_args) == 1: + # General case: create a List with the same type arguments + return list[type_args[0]] # type: ignore[valid-type] + elif len(type_args) == 2 and type_args[1] is Ellipsis: + return list[type_args[0]] # type: ignore[valid-type] + else: + return list[typing.Union[tuple(type_args)]] # type: ignore[misc, return-value] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_library/opaque_object.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_library/opaque_object.py new file mode 100644 index 0000000000000000000000000000000000000000..847955b500244cbbcbed1179c6044aa84a64ae9e --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_library/opaque_object.py @@ -0,0 +1,204 @@ +""" +Note [Opaque Objects] + +Opaque objects are the way we allow custom operators to accept a user-defined +"black box" object as an input. + +There are two kinds of opaque types: VALUE type and REFERENCE type. 
+The distinction determines how torch.compile handles the object. + +REFERENCE TYPES (default): + +Reference-typed opaque objects represent mutable stateful objects and are +treated as black boxes. In torch.compile, since torch.compile cannot optimize +the anything (including tensors) within the object, the object must be an +input to the graph. + +You can register a custom class as being a reference-based opaque object class +through `register_opaque_type(MyClass, typ="reference")`. + +VALUE TYPES: + +Value-typed opaque objects represent constant values. +In torch.compile, the graph specializes on the object like how other constants +are. Therefore there are a couple of methods on the class that must be +implemented before registering it as a value-typed opaque object class: + - __eq__: torch.compile will create guards based on the equality of this + object, meaning that a recompilation will happen if __eq__ returns False. + - __hash__: This must be implemented for Fake Tensor caching + - __repr__: This must be implemented as it will be used in the FX graph's + codegen to reconstruct the object. The string representation must be able to + construct the object again through its __init__ method. + +You can register a custom class as being a reference-based opaque object class +through `register_opaque_type(MyClass, typ="value")`. +""" + +from dataclasses import dataclass +from typing import Any, Literal, NewType +from weakref import WeakKeyDictionary + +import torch + +from .fake_class_registry import register_fake_class + + +@register_fake_class("aten::OpaqueObject") +class FakeOpaqueObject: + def __init__(self) -> None: + pass + + @classmethod + def __obj_unflatten__(cls, flattened_ctx: dict[str, Any]) -> None: + raise RuntimeError( + "FakeOpaqueObject should not be created through __obj_unflatten__ " + "and should be special handled. Please file an issue to Github." 
+ ) + + +OpaqueTypeStr = "__torch__.torch.classes.aten.OpaqueObject" + +OpaqueType = NewType("OpaqueType", torch._C.ScriptObject) + + +@dataclass +class _OpaqueTypeInfo: + class_name: str + opaque_typ: Literal["reference", "value"] + + +# Mapping of type -> (string name, reference/value type) +_OPAQUE_TYPES: WeakKeyDictionary[Any, _OpaqueTypeInfo] = WeakKeyDictionary() +# Mapping of class_name -> (type, reference/value type) +_OPAQUE_TYPES_BY_NAME: dict[str, _OpaqueTypeInfo] = {} + + +def get_opaque_type_name(cls: Any) -> str: + """ + Gets the registered opaque type name for a given class. + + Args: + cls (type): The class to get the type name for. + + Returns: + str: The registered type name for the class. + + Raises: + ValueError: If the class is not registered as an opaque type. + """ + if cls not in _OPAQUE_TYPES: + raise ValueError( + f"Class {cls} is not registered as an opaque type. " + f"Call register_opaque_type({cls.__name__}) first." + ) + return _OPAQUE_TYPES[cls].class_name + + +def register_opaque_type(cls: Any, *, typ: str) -> None: + """ + Registers the given type as an opaque type which allows this to be consumed + by a custom operator. + + The type name will be automatically generated from the class's fully + qualified name (ex. my_module.MyClass). + + Args: + cls (type): The class to register as an opaque type. + typ (str): Either "reference" or "value". See Note [Opaque Objects] for + more details. + """ + import torch.utils._pytree as pytree + + # Prevent registration of built-in types (int, str, list, dict, etc.) and torch.Tensor + if cls.__module__ == "builtins" or cls is torch.Tensor: + raise ValueError( + f"Unable to register built-in type {cls} as an opaque type. " + "Please wrap it in a custom class and register the custom class as opaque." + ) + + if cls in pytree.SUPPORTED_NODES: + raise ValueError( + f"{cls} cannot be registered as an opaque object as it has been " + "registered as a pytree. Opaque objects must be pytree leaves." 
+ ) + + assert typ in ["reference", "value"], ( + "Opaque type must be either 'reference' or 'value'" + ) + + if typ == "value": + if cls.__eq__ is object.__eq__: # type: ignore[comparison-overlap] + raise TypeError( + f"Value-type opaque object of type {cls} is " + "expected to have a non-default `__eq__` " + "implementation as we will use this in torch.compile " + "to guard on the equality of objects." + ) + + # Class with a custom `__eq__` without `__hash__` won't inherit the default + # `__hash__` from object; see https://stackoverflow.com/a/1608907. + if cls.__hash__ is None: # type: ignore[comparison-overlap] + raise TypeError( + f"Value-type opaque object of type {cls} is " + "expected to have a non-default `__hash__` " + "implementation as we will use this in torch.compile " + "for FakeTensor caching." + ) + + if cls.__repr__ is object.__repr__: # type: ignore[comparison-overlap] + raise TypeError( + f"Value-type opaque object of type {cls} is " + "expected to have a non-default `__repr__` " + "implementation as we will use this to reconstruct " + "the object in the FX codegen." + ) + + # Generate a fully qualified name by combining module and qualname + name = f"{cls.__module__}.{cls.__qualname__}" + + type_info = _OpaqueTypeInfo(name, typ) + _OPAQUE_TYPES[cls] = type_info + _OPAQUE_TYPES_BY_NAME[name] = type_info + + torch._C._register_opaque_type(name) + + +def is_opaque_type(cls: Any) -> bool: + """ + Checks if the given type is an opaque type. + """ + if isinstance(cls, str): + return torch._C._is_opaque_type_registered(cls) + + if cls not in _OPAQUE_TYPES: + return False + + return torch._C._is_opaque_type_registered(_OPAQUE_TYPES[cls].class_name) + + +def is_opaque_value_type(cls: Any) -> bool: + """ + Checks if the given type is an opaque **value** type. + See Note [Opaque Objects] for more information. 
+ """ + if not is_opaque_type(cls): + return False + + if isinstance(cls, str): + return _OPAQUE_TYPES_BY_NAME[cls].opaque_typ == "value" + + return _OPAQUE_TYPES[cls].opaque_typ == "value" + + +def is_opaque_reference_type(cls: Any) -> bool: + """ + Checks if the given type is an opaque **reference** type. + See Note [Opaque Objects] for more information. + """ + if not is_opaque_type(cls): + return False + + if isinstance(cls, str): + return _OPAQUE_TYPES_BY_NAME[cls].opaque_typ == "reference" + + return _OPAQUE_TYPES[cls].opaque_typ == "reference" diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_library/simple_registry.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_library/simple_registry.py new file mode 100644 index 0000000000000000000000000000000000000000..466f6cc68e52b2b7fa4d72aaeb630ae21f3101c5 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_library/simple_registry.py @@ -0,0 +1,91 @@ +from collections.abc import Callable +from typing import Any, Optional + +from .effects import EffectHolder +from .fake_impl import FakeImplHolder +from .utils import RegistrationHandle + + +__all__ = ["SimpleLibraryRegistry", "SimpleOperatorEntry", "singleton"] + + +class SimpleLibraryRegistry: + """Registry for the "simple" torch.library APIs + + The "simple" torch.library APIs are a higher-level API on top of the + raw PyTorch DispatchKey registration APIs that includes: + - fake impl + + Registrations for these APIs do not go into the PyTorch dispatcher's + table because they may not directly involve a DispatchKey. For example, + the fake impl is a Python function that gets invoked by FakeTensor. + Instead, we manage them here. + + SimpleLibraryRegistry is a mapping from a fully qualified operator name + (including the overload) to SimpleOperatorEntry. 
+ """ + + def __init__(self) -> None: + self._data: dict[str, SimpleOperatorEntry] = {} + + def find(self, qualname: str) -> "SimpleOperatorEntry": + res = self._data.get(qualname, None) + if res is None: + self._data[qualname] = res = SimpleOperatorEntry(qualname) + return res + + +singleton: SimpleLibraryRegistry = SimpleLibraryRegistry() + + +class SimpleOperatorEntry: + """This is 1:1 to an operator overload. + + The fields of SimpleOperatorEntry are Holders where kernels can be + registered to. + """ + + def __init__(self, qualname: str) -> None: + self.qualname: str = qualname + self.fake_impl: FakeImplHolder = FakeImplHolder(qualname) + self.torch_dispatch_rules: GenericTorchDispatchRuleHolder = ( + GenericTorchDispatchRuleHolder(qualname) + ) + + self.effect: EffectHolder = EffectHolder(qualname) + + # For compatibility reasons. We can delete this soon. + @property + def abstract_impl(self) -> FakeImplHolder: + return self.fake_impl + + +class GenericTorchDispatchRuleHolder: + def __init__(self, qualname: str) -> None: + self._data: dict[type, Callable[..., Any]] = {} + self.qualname: str = qualname + + def register( + self, torch_dispatch_class: type, func: Callable[..., Any] + ) -> RegistrationHandle: + if self.find(torch_dispatch_class): + raise RuntimeError( + f"{torch_dispatch_class} already has a `__torch_dispatch__` rule registered for {self.qualname}" + ) + self._data[torch_dispatch_class] = func + + def deregister() -> None: + del self._data[torch_dispatch_class] + + return RegistrationHandle(deregister) + + def find(self, torch_dispatch_class: type) -> Optional[Callable[..., Any]]: + return self._data.get(torch_dispatch_class, None) + + +def find_torch_dispatch_rule( + op: Any, torch_dispatch_class: type +) -> Optional[Callable[..., Any]]: + return singleton.find(op.__qualname__).torch_dispatch_rules.find( + torch_dispatch_class + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_library/triton.py 
b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_library/triton.py new file mode 100644 index 0000000000000000000000000000000000000000..dc55cb9b34944c22db08f5d9aca9758a04eeab2d --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_library/triton.py @@ -0,0 +1,368 @@ +import ast +import contextlib +import inspect +import threading +from collections.abc import Callable, Generator, Iterable +from typing import Any, Optional, Union + +from torch.utils._exposed_in import exposed_in + +from .custom_ops import custom_op, CustomOpDef +from .infer_schema import infer_schema + + +triton_ops_to_kernels: dict[str, list[object]] = {} + + +def get_triton_kernels_for_op(name: str) -> list[object]: + return triton_ops_to_kernels.get(name, []) + + +def get_inner_triton_kernels(fn: Callable[..., Any]) -> list[object]: + """ + Inspect the source of an arbitrary callable passed to torch._library.triton_op, + and grab all of the triton kernels that are wrapped inside of it. + + TODO: This check is best effort. It does *not* handle the case where the triton + kernel is hidden behind recursive function calls. 
+ """ + + def find_triton_kernels(fn: Callable[..., Any]) -> list[object]: + try: + source = inspect.getsource(fn) + except (OSError, TypeError): + return [] # Source code not available + + from torch._inductor.utils import IndentedBuffer + + buffer = IndentedBuffer() + buffer.splice(source, strip=True) + tree = ast.parse(buffer.getrawvalue()) + + # Visitor to collect function calls and triton kernels + class Visitor(ast.NodeVisitor): + def __init__(self) -> None: + self.triton_kernels: list[Any] = [] + + def visit_Call(self, node: ast.Call) -> None: + triton_func_names = ("capture_triton", "wrap_triton") + if isinstance(node.func, ast.Attribute): + attr = node.func + if ( + isinstance(attr.value, ast.Attribute) + and isinstance(attr.value.value, ast.Name) + and attr.value.value.id == "torch" + and attr.value.attr == "_library" + and attr.attr in triton_func_names + ): + if node.args and isinstance(node.args[0], ast.Name): + self.triton_kernels.append(node.args[0].id) + + # Catch capture_triton, wrap_triton that's been + # imported directly + elif isinstance(node.func, ast.Name): + if node.func.id in triton_func_names: + if node.args and isinstance(node.args[0], ast.Name): + self.triton_kernels.append(node.args[0].id) + + self.generic_visit(node) + + collector = Visitor() + collector.visit(tree) + closure_vars = inspect.getclosurevars(fn) + resolved = [] + # First, resolve triton kernel names + for name in collector.triton_kernels: + if name in closure_vars.nonlocals: + resolved.append(closure_vars.nonlocals[name]) + elif name in closure_vars.globals: + resolved.append(closure_vars.globals[name]) + elif name in closure_vars.builtins: + resolved.append(closure_vars.builtins[name]) + return resolved + + return find_triton_kernels(fn) + + +@exposed_in("torch.library") +def triton_op( + name: str, + fn: Optional[Callable] = None, + /, + *, + mutates_args: Union[str, Iterable[str]], + schema: Optional[str] = None, +) -> Callable: + """Create a custom operator whose 
implementation is backed by 1+ triton kernels. + + This is a more structured way of using triton kernels with PyTorch. + Prefer using triton kernels with no ``torch.library`` custom operator wrappers + (like :func:`torch.library.custom_op`, :func:`torch.library.triton_op`) because + that is simpler; + only use :func:`torch.library.custom_op`/:func:`torch.library.triton_op` if you + want to create an operator that behaves like PyTorch built-in operators. + For example, you may use a ``torch.library`` wrapper API to define the + behavior of the triton kernel when passed a tensor subclass or under + a TorchDispatchMode. + + Use :func:`torch.library.triton_op` instead of :func:`torch.library.custom_op` + when the implementation + consists of 1+ triton kernels. :func:`torch.library.custom_op` treats + custom operators as opaque (:func:`torch.compile` and + :func:`torch.export.export` will never trace into them), but ``triton_op`` + makes the implementation visible to these subsystems, allowing them + to optimize the triton kernel(s). + + Note that ``fn`` must only consist of calls to PyTorch-understood + operators and triton kernels. Any triton kernels called inside ``fn`` + must be wrapped in a call to :func:`torch.library.wrap_triton`. + + Args: + name (str): A name for the custom op that looks like "{namespace}::{name}", + e.g. "mylib::my_linear". The name is used as the op's stable identifier + in PyTorch subsystems (e.g. torch.export, FX graphs). + To avoid name collisions, please use your project name as the namespace; + e.g. all custom ops in pytorch/fbgemm use "fbgemm" as the namespace. + mutates_args (Iterable[str] or "unknown"): The names of args that the function mutates. + This MUST be accurate, otherwise, the behavior is undefined. If "unknown", + it pessimistically assumes that all inputs to the operator are being mutated. + schema (None | str): A schema string for the operator. 
If None + (recommended) we'll infer a schema for the operator from its type + annotations. We recommend letting us infer a schema unless you + have a specific reason not to. + Example: "(Tensor x, int y) -> (Tensor, Tensor)". + + Example:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) + >>> import torch + >>> from torch.library import triton_op, wrap_triton + >>> + >>> import triton + >>> from triton import language as tl + >>> + >>> @triton.jit + >>> def add_kernel( + >>> in_ptr0, + >>> in_ptr1, + >>> out_ptr, + >>> n_elements, + >>> BLOCK_SIZE: "tl.constexpr", + >>> ): + >>> pid = tl.program_id(axis=0) + >>> block_start = pid * BLOCK_SIZE + >>> offsets = block_start + tl.arange(0, BLOCK_SIZE) + >>> mask = offsets < n_elements + >>> x = tl.load(in_ptr0 + offsets, mask=mask) + >>> y = tl.load(in_ptr1 + offsets, mask=mask) + >>> output = x + y + >>> tl.store(out_ptr + offsets, output, mask=mask) + >>> + >>> @triton_op("mylib::add", mutates_args={}) + >>> def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + >>> output = torch.empty_like(x) + >>> n_elements = output.numel() + >>> + >>> def grid(meta): + >>> return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) + >>> + >>> # NB: we need to wrap the triton kernel in a call to wrap_triton + >>> wrap_triton(add_kernel)[grid](x, y, output, n_elements, 16) + >>> return output + >>> + >>> @torch.compile + >>> def f(x, y): + >>> return add(x, y) + >>> + >>> x = torch.randn(3, device="cuda") + >>> y = torch.randn(3, device="cuda") + >>> + >>> z = f(x, y) + >>> assert torch.allclose(z, x + y) + + """ + + def dec(fn: Callable[..., object]) -> CustomOpDef: + def backend_fn(*args, **kwargs): # type: ignore[no-untyped-def] + # Optimization: we're passing regular Tensors into the triton kernel, so + # no need to go through HOP dispatch + with set_wrap_triton_enabled(False): + return fn(*args, **kwargs) + + result = custom_op( + name, + backend_fn, + mutates_args=mutates_args, + schema=infer_schema(fn, 
mutates_args=mutates_args), + ) + from .._subclasses.functional_tensor import FunctionalTensorMode + + # We require that the user pass us a function that is make_fx traceable, + # so we can just register it as the Fake/meta kernel. + result.register_fake(fn) + + # We decompose the operator when FunctionalTensorMode is active. + # The goal is to decompose the operator in AOTDispatcher. + # - With torch.compile, this means that the backend (usually Inductor) + # can see a call to the triton kernel(s) and so it can directly optimize + # them by inlining them into the lowering process. + def functional_decomp( # type: ignore[no-untyped-def] + mode, op, types, args, kwargs + ): + # NOTE [Export custom triton op] + # For torch.export (strict and non-strict), we don't do functional decomposition. + # Instead, we preserve the custom triton ops as custom ops. This is because we want + # the exported program to be high-level and serializable. If we decompose + # the custom op to a functional hop and make it a node in exported program, + # we need to figure out ways of serializing the hop and its arguments, which can be triton.jited + # functions and triton dtypes. This is undesirable because: + # - it can be tedious to maintain a layer that serializes the jited function (e.g. with a string) and dtypes. + # - exported program will contain the implementation detail (e.g. triton source code) for a specific + # backend (GPU), which is probably at a wrong level of abstraction. + # - changes to triton or the serialization logic for triton arguments can be BC breaking + # + # In the short term, we expect users to have a separate aot_compile stage that compiles the exported program + # into a Cubin file on the same machine that users call export, which does autotuning and removes triton + # dependency and serve the model with Cubin. This guarantees that triton changes won't break BC. 
+ # In the long term, we may export multiple cubins for the triton op directly + from torch.export._trace import custom_triton_ops_decomposition_disabled + + if custom_triton_ops_decomposition_disabled(): + return mode.__torch_dispatch__(op, types, args, kwargs) + else: + # TODO: https://github.com/pytorch/pytorch/issues/160333 + # We should deduplicate the unrecognized_types logic. + import torch._subclasses + + unrecognized_types = [ + t + for t in types + if not issubclass(t, torch._subclasses.FakeTensor) + and t + not in [ + torch.Tensor, + torch._subclasses.functional_tensor.FunctionalTensor, + ] + ] + + if unrecognized_types: + return NotImplemented + with mode: + return fn(*args, **kwargs) + + triton_kernels = get_inner_triton_kernels(fn) + triton_ops_to_kernels[name] = triton_kernels + result.register_torch_dispatch(FunctionalTensorMode, functional_decomp) + return result + + if fn is None: + return dec + else: + return dec(fn) + + +wrap_triton_enabled = threading.local() +wrap_triton_enabled_default = True + + +@contextlib.contextmanager +def set_wrap_triton_enabled(enabled: bool) -> Generator[None, None, None]: + """If triton kernels annotated with @wrap_triton should dispatch via HOP + or go straight to the triton kernel execution. 
+ + We have this switch because eager-mode performance of HOP dispatch is slow + enough to matter (~1ms) and we know that wrap_triton isn't necessary in + some situations (eager-mode with regular Tensors) + """ + try: + prev = is_wrap_triton_enabled() + wrap_triton_enabled.value = enabled + yield + finally: + wrap_triton_enabled.value = prev + + +def is_wrap_triton_enabled() -> bool: + return getattr(wrap_triton_enabled, "value", wrap_triton_enabled_default) + + +def capture_triton(triton_kernel: Callable, /) -> Any: + """This API has been renamed to wrap_triton""" + return wrap_triton(triton_kernel) + + +@exposed_in("torch.library") +def wrap_triton(triton_kernel: Callable, /) -> Any: + """Allows capture of a triton kernel into a graph via make_fx or + non-strict ``torch.export``. + + These technologies perform Dispatcher-based tracing (via + ``__torch_dispatch__``) and cannot see calls to raw triton kernels. + The ``wrap_triton`` API wraps a triton kernel into a callable that + can actually be traced into a graph. + + Please use this API together with :func:`torch.library.triton_op`. 
+ + Examples: + + >>> # xdoctest: +SKIP + >>> import torch + >>> import triton + >>> from triton import language as tl + >>> from torch.fx.experimental.proxy_tensor import make_fx + >>> from torch.library import wrap_triton + >>> + >>> @triton.jit + >>> def add_kernel( + >>> in_ptr0, + >>> in_ptr1, + >>> out_ptr, + >>> n_elements, + >>> BLOCK_SIZE: "tl.constexpr", + >>> ): + >>> pid = tl.program_id(axis=0) + >>> block_start = pid * BLOCK_SIZE + >>> offsets = block_start + tl.arange(0, BLOCK_SIZE) + >>> mask = offsets < n_elements + >>> x = tl.load(in_ptr0 + offsets, mask=mask) + >>> y = tl.load(in_ptr1 + offsets, mask=mask) + >>> output = x + y + >>> tl.store(out_ptr + offsets, output, mask=mask) + >>> + >>> def add(x, y): + >>> output = torch.empty_like(x) + >>> n_elements = output.numel() + >>> + >>> def grid_fn(meta): + >>> return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) + >>> + >>> wrap_triton(add_kernel)[grid_fn](x, y, output, n_elements, 16) + >>> return output + >>> + >>> x = torch.randn(3, device="cuda") + >>> y = torch.randn(3, device="cuda") + >>> gm = make_fx(add)(x, y) + >>> print(gm.code) + >>> # def forward(self, x_1, y_1): + >>> # empty_like = torch.ops.aten.empty_like.default(x_1, pin_memory = False) + >>> # triton_kernel_wrapper_mutation_proxy = triton_kernel_wrapper_mutation( + >>> # kernel_idx = 0, constant_args_idx = 0, + >>> # grid = [(1, 1, 1)], kwargs = { + >>> # 'in_ptr0': x_1, 'in_ptr1': y_1, 'out_ptr': empty_like, + >>> # 'n_elements': 3, 'BLOCK_SIZE': 16 + >>> # }) + >>> # return empty_like + + """ + from triton.runtime.autotuner import Autotuner + from triton.runtime.jit import JITFunction + + from torch._higher_order_ops.triton_kernel_wrap import TraceableTritonKernelWrapper + + if not isinstance(triton_kernel, (JITFunction, Autotuner)): + raise RuntimeError( + "wrap_triton only works on functions annotated with triton.jit or triton.autotune" + ) + if not is_wrap_triton_enabled(): + return triton_kernel + return 
TraceableTritonKernelWrapper(triton_kernel, None, None) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_library/utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_library/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d50c7664d60b6a6085b13a9f92a4b288be7b06eb --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_library/utils.py @@ -0,0 +1,647 @@ +# mypy: allow-untyped-defs +import dataclasses +import inspect +import sys +from collections.abc import Callable, Iterable, Iterator +from typing import Any, Literal, Optional, overload, Union + +import torch +import torch.utils._pytree as pytree +import torchgen +from torch import _C, _utils_internal +from torch._ops import OpOverload + + +@dataclasses.dataclass +class Kernel: + """Models a (function, source location)""" + + func: Callable + source: str + + def __call__(self, *args, **kwargs): + return self.func(*args, **kwargs) + + +class RegistrationHandle: + """Does something when someone calls .destroy() on it""" + + def __init__(self, on_destroy: Callable): + self._on_destroy = on_destroy + + def destroy(self) -> None: + self._on_destroy() + + +def get_source(stacklevel: int) -> str: + """Get a string that represents the caller. + + Example: "/path/to/foo.py:42" + + Use stacklevel=1 to get the caller's source + Use stacklevel=2 to get the caller's caller's source + etc. + """ + frame = inspect.getframeinfo(sys._getframe(stacklevel)) + source = f"{frame.filename}:{frame.lineno}" + return source + + +def parse_namespace(qualname: str) -> tuple[str, str]: + splits = qualname.split("::") + if len(splits) != 2: + raise ValueError( + f"Expected `qualname` to be of the form " + f'"namespace::name", but got {qualname}. ' + f"The qualname passed to the torch.library APIs must consist " + f"of a namespace and a name, e.g. 
aten::sin" + ) + return splits[0], splits[1] + + +def lookup_op(qualname: str) -> OpOverload: + namespace, name = parse_namespace(qualname) + if "." in name: + name, overload = name.split(".") + else: + overload = "default" + ns = getattr(torch.ops, namespace) + packet = getattr(ns, name) + return getattr(packet, overload) + + +def is_builtin(op: OpOverload) -> bool: + assert isinstance(op, OpOverload) + return op.namespace in {"aten", "prim", "prims"} + + +def is_functional_schema(schema: Any, *, allow_valid_view: bool = False) -> bool: + """Check if the schema is functional. + + An operator is functional if: + - it does not mutate any of its inputs + - If no view are allowed + - it does not return a view on any of its inputs + - If valid views are allowed + - it is not a view or a view with a single input Tensor and single output Tensor + - it has at least one return + """ + + def is_functional(schema): + if schema.is_mutable: + return False + rets = schema.returns + is_non_mutating_view = len(rets) > 0 and any( + r.alias_info is not None and not r.alias_info.is_write for r in rets + ) + num_tensor_inputs = 0 + num_tensor_outputs = 0 + + if isinstance(schema, torch.FunctionSchema): + for arg in schema.arguments: + if isinstance(arg.type, torch.TensorType): + num_tensor_inputs += 1 + + for ret in schema.returns: + if isinstance(ret.type, torch.TensorType): + num_tensor_outputs += 1 + + elif isinstance(schema, torchgen.model.FunctionSchema): + for argument in schema.arguments.flat_non_out: + if argument.type.is_tensor_like(): + num_tensor_inputs += 1 + + for ret_arg in schema.returns: + if ret_arg.type.is_tensor_like(): + num_tensor_outputs += 1 + + if is_non_mutating_view: + return allow_valid_view and ( + num_tensor_inputs == 1 and num_tensor_outputs == 1 + ) + if not schema.returns: + return False + return True + + if isinstance(schema, torch._C.FunctionSchema): + return is_functional(schema) + + # Lazy import because not all PyTorch builds have torchgen + from 
torchgen.model import FunctionSchema + + if isinstance(schema, str): + schema = FunctionSchema.parse(schema) + assert isinstance(schema, FunctionSchema) + return is_functional(schema) + + +# should be torch._C.JitType but that annotation is busted +def is_tensorlist_like_type(typ: Any) -> bool: + return ( + typ == _C.ListType(_C.TensorType.get()) + or typ == _C.ListType(_C.OptionalType(_C.TensorType.get())) + or typ == _C.OptionalType(_C.ListType(_C.TensorType.get())) + or typ == _C.OptionalType(_C.ListType(_C.OptionalType(_C.TensorType.get()))) + ) + + +# should be torch._C.JitType but that annotation is busted +def is_tensor_like_type(typ: Any) -> bool: + return typ == _C.TensorType.get() or typ == _C.OptionalType(_C.TensorType.get()) + + +def mutates_and_returns_first_arg(op: OpOverload): + """Check if an op is an inplace aten op, i.e. it mutates and returns the first arg. + + TODO: torchgen/model.py's FunctionSchema.parse is the source of truth for this, + but not all PyTorch builds have torchgen (due to the yaml dependency being weird). + Figure this out. + + Example: add_(Tensor(a!) 
x, Tensor y) -> Tensor(a) + """ + if op.namespace != "aten": + return False + schema = op._schema + if len(schema.returns) != 1: + return False + if schema.returns[0].alias_info is None: + return False + alias_set = schema.returns[0].alias_info.after_set + if len(alias_set) != 1: + return False + loc = next(iter(alias_set)) + if len(schema.arguments) < 1: + return False + first_arg = schema.arguments[0] + if first_arg.alias_info is None: + return False + if not first_arg.alias_info.is_write: + return False + alias_set = first_arg.alias_info.after_set + if len(alias_set) != 1: + return False + if loc != next(iter(alias_set)): + return False + for arg in schema.arguments[1:]: + if arg.alias_info is not None: + return False + return True + + +def fill_defaults(schema, args, kwargs): + new_args = [] + new_kwargs = {} + for i in range(len(schema.arguments)): + info = schema.arguments[i] + if info.kwarg_only: + if info.name in kwargs: + new_kwargs[info.name] = kwargs[info.name] + else: + new_kwargs[info.name] = info.default_value + else: + if i < len(args): + new_args.append(args[i]) + else: + new_args.append(info.default_value) + return tuple(new_args), new_kwargs + + +def zip_schema( + schema: _C.FunctionSchema, args: tuple[Any, ...], kwargs: dict[str, Any] +) -> Iterable[tuple[_C.Argument, Any]]: + """zips schema.arguments and (args, kwargs) together. + + Assumes that (args, kwargs) were the inputs to some torch._ops.OpOverload: + that is, (args, kwargs) must be bindable to the schema (args, kwargs). + """ + assert len(schema.arguments) >= len(args) + len(kwargs) + for i in range(len(schema.arguments)): + info = schema.arguments[i] + if info.kwarg_only: + if info.name in kwargs: + yield info, kwargs[info.name] + continue + if i >= len(args): + if not info.kwarg_only and info.name in kwargs: + yield info, kwargs[info.name] + # args that are equal to their default values are not populated + # if they are followed by args that are equal to their defaults. + # Skip these. 
+ continue + yield info, args[i] + return + + +def hop_schema_from_fx_node(node): + from torchgen.gen_schema_utils import FunctionSchemaGen + + hop = node.target + if not isinstance(hop, torch._ops.HigherOrderOperator): + raise RuntimeError("fx_node's target must be a hop.") + + def _collect_example_val(node): + meta_val = node.meta.get("val", None) + if meta_val is None: + assert node.op == "get_attr" + meta_val = getattr(node.graph.owning_module, node.target) + return meta_val + + example_inputs = [] + for arg in node.args: + if isinstance(arg, (torch.fx.Node, torch.fx.node.Node)): + example_inputs.append(_collect_example_val(arg)) + elif isinstance( + arg, (torch.fx.immutable_collections.immutable_list, list, tuple) + ): + example_inputs.append([_collect_example_val(x) for x in arg]) + else: + raise RuntimeError(f"Unsupported arg type {type(arg)}") + + # Bound the arguments to make sure number of inputs are correct + bound_args: inspect.BoundArguments = inspect.signature(hop.__call__).bind( + *example_inputs + ) + + # We treat example_output as a single value in return. This is to differentiate 1. return a single val + # vs 2. return a tuple with one element. + example_output = _collect_example_val(node) + return FunctionSchemaGen.from_example( + hop._name, tuple(bound_args.arguments.items()), (list(example_output),) + ) + + +def can_generate_trivial_fake_impl(op: OpOverload) -> bool: + assert isinstance(op, OpOverload) + if is_builtin(op): + # We control the built-ins. These may (in rare cases) + # do input metadata mutation (which we have banned on custom ops) + return False + schema = op._schema + # It's suspicious if the op is not mutable but returns nothing, so we return False out of an abundance of caution + if not schema.is_mutable: + return False + if len(schema.returns) > 0: + return False + # If the op returns nothing, then it has a trivial fake impl. 
+ return True + + +def requires_set_python_module() -> bool: + """If an op was defined in C++ and extended from Python using the + torch.library APIs, returns if we require that there have been a + m.set_python_module("mylib.ops") call from C++ that associates + the C++ op with a python module. + """ + return getattr(_utils_internal, "REQUIRES_SET_PYTHON_MODULE", True) + + +def handle_dispatch_mode(curr_mode, op_overload, *args, **kwargs): + assert isinstance(curr_mode, torch.utils._python_dispatch.TorchDispatchMode) + args_flattened, _ = torch.utils._pytree.tree_flatten((args, kwargs.values())) + # TODO: need to double check the semantics of the "types" argument to torch_dispatch. + # It's generated in PyInterpreter.cpp, but seems to be generated in two places, + # where in one case we only include tensors with the python key, and in another + # we include **all** tensors. + overload_types = [ + type(a) + for a in args_flattened + if isinstance(a, torch.Tensor) + and torch._C._dispatch_keys(a).has(torch._C.DispatchKey.Python) + ] + # TODO: check that I got these args correct (in C++, we pass in "0000"??) + + return curr_mode.__torch_dispatch__(op_overload, overload_types, args, kwargs) + + +def has_kwarg_only_args(schema: _C.FunctionSchema): + return any(a.kwarg_only for a in schema.arguments) + + +def has_kwarg_only_tensors(schema: _C.FunctionSchema): + for a in schema.arguments: + if not (is_tensor_like_type(a.type) or is_tensorlist_like_type(a.type)): + continue + if not a.kwarg_only: + continue + return True + return False + + +def has_tensor_arg(schema: _C.FunctionSchema) -> bool: + """ + Given a schema, returns True if the schema has a Tensor arg. + A Tensor arg is any arg with a type annotation that might involve Tensor. 
+ """ + return any( + (is_tensor_like_type(a.type) or is_tensorlist_like_type(a.type)) + for a in schema.arguments + ) + + +def get_device_arg_index(schema: _C.FunctionSchema) -> Union[int, None]: + """ + Given a schema, returns the id of the `device: torch.device` argument. + If it does not exist, returns None. + """ + for index, arg in enumerate(schema.arguments): + if arg.type is _C.DeviceObjType.get() and arg.name == "device": + return index + return None + + +def iter_tensors( + args: tuple[Any], kwargs: dict[str, Any], allowed_nesting: int = 1 +) -> Iterator[torch.Tensor]: + def check(arg): + if isinstance(arg, torch.Tensor): + yield arg + elif allowed_nesting > 0 and isinstance(arg, (tuple, list)): + yield from iter_tensors(tuple(arg), {}, allowed_nesting - 1) + + for arg in args: + yield from check(arg) + for kwarg in kwargs.values(): + yield from check(kwarg) + + +def check_aliasing_constraint(name, prev, result, get_module=lambda: "???"): + """ + custom operators' outputs must not alias any inputs or other outputs. + """ + storages = {t.untyped_storage()._cdata for t in prev if isinstance(t, torch.Tensor)} + tuple_result = result + if not isinstance(result, tuple): + tuple_result = (result,) + for tensor in iter_tensors(tuple_result, {}): + key = tensor.untyped_storage()._cdata + if tensor.untyped_storage()._cdata in storages: + raise RuntimeError( + f"{name} (with implementation in {get_module()}): " + f"The output of this custom operator (1) must not " + f"also be an input to this custom operator and " + f"(2) may not alias any inputs to this custom operator " + f"or other returns. " + f"The most common way to trigger this error is if " + f"we have y = custom_op(x) and y and x are the same Tensor. " + f"Please instead return a clone of the offending output " + f"tensor(s) (e.g. return x.clone()) or refactor the custom " + f"operator to not return y." 
+ ) + storages.add(key) + + +def _c_check_aliasing_constraint(name, args, kwargs, result, get_module=lambda: "???"): + """ + custom operators' outputs must not have any aliases + This version uses C++ implementation for perf. + Only List container is supported. + Tensors in Lists with not only Tensors are checked. + """ + tuple_result = result + if not isinstance(result, tuple): + tuple_result = (result,) + if _C._any_output_is_alias_to_input_or_output(args, kwargs, tuple_result): + raise RuntimeError( + f"{name} (with implementation in {get_module()}): " + f"The output of this custom operator (1) must not " + f"also be an input to this custom operator and " + f"(2) may not alias any inputs to this custom operator " + f"or other returns. " + f"The most common way to trigger this error is if " + f"we have y = custom_op(x) and y and x are the same Tensor. " + f"Please instead return a clone of the offending output " + f"tensor(s) (e.g. return x.clone()) or refactor the custom " + f"operator to not return y." + ) + + +class MutationChecker: + """ + Check if an operator mutated its arguments. 
+ Usage: + + checker = MutationChecker(op, flat_args, args_spec) + op(*args, **kwargs) + checker.check() + """ + + def __init__(self, op, flat_args, args_spec): + self.op = op + self.args_spec = args_spec + self.flat_args = flat_args + self.real_pre_hashes = [ + hash_tensor(a) if isinstance(a, torch.Tensor) else None for a in flat_args + ] + + def check(self): + real_post_hashes = [ + hash_tensor(a) if isinstance(a, torch.Tensor) else None + for a in self.flat_args + ] + was_mutated = [ + not torch.equal(pre, post) + and not (pre.isnan().all() and post.isnan().all()) + if isinstance(pre, torch.Tensor) and isinstance(post, torch.Tensor) + else None + for pre, post in zip(self.real_pre_hashes, real_post_hashes) + ] + was_mutated_args, was_mutated_kwargs = pytree.tree_unflatten( + was_mutated, self.args_spec + ) + for info, was_mutated in zip_schema( + self.op._schema, was_mutated_args, was_mutated_kwargs + ): + + def check_one(info, was_mutated): + if info.is_write == was_mutated: + return + raise RuntimeError( + f"{self.op._name}: for argument '{info.name}': the operator's schema " + f"{self.op._schema} specified that " + f"the operator {'mutates' if info.is_write else 'does not mutate'} " + f"the argument, but this seems to be empirically wrong. " + f"Please make the schema and operator behavior consistent. " + f"You can specify that an operator mutates a Tensor by " + f"e.g. changing its schema type from 'Tensor name' to 'Tensor(a!) name'" + f"(use different identifiers (a, b, c, ...) for different Tensors)" + ) + + if is_tensor_like_type(info.type): + check_one(info, was_mutated) + elif is_tensorlist_like_type(info.type): + was_any_mutated = False if was_mutated is None else any(was_mutated) + check_one(info, was_any_mutated) + + +def hash_tensor(t: torch.Tensor) -> torch.Tensor: + """Some inexpensive hash. 
Used as a quick and dirty indicator for tensor mutation""" + return t.detach().float().mean() + + +def has_fake_kernel(op: torch._ops.OpOverload) -> bool: + """If an operator (that stays alive until FakeTensorMode) has a Fake kernel. + Don't use this if the operator decomposes before FakeTensorMode. + """ + if can_generate_trivial_fake_impl(op): + return True + name = op._name + if torch._C._dispatch_has_kernel_for_dispatch_key( + name, "CompositeImplicitAutograd" + ): + return True + opdef = torch._library.custom_ops._maybe_get_opdef(name) + if opdef is None: + # the non-torch.library.custom_op path + if torch._C._dispatch_has_kernel_for_dispatch_key( + name, "CompositeExplicitAutograd" + ): + return True + entry = torch._library.simple_registry.singleton.find(name) + if entry.fake_impl.kernel is not None: + return True + if torch._C._dispatch_has_kernel_for_dispatch_key(name, "Meta"): + return True + else: + # the torch.library.custom_op path + if opdef._abstract_fn is not None: + return True + return False + + +def mutated_args_kwargs(schema: _C.FunctionSchema) -> tuple[list[int], list[str]]: + idxs = [] + keys = [] + for i, info in enumerate(schema.arguments): + if info.alias_info is not None and info.alias_info.is_write: + if info.kwarg_only: + keys.append(info.name) + else: + idxs.append(i) + return idxs, keys + + +tags_by_priority = [ + _C.Tag.needs_exact_strides, + _C.Tag.needs_contiguous_strides, + _C.Tag.needs_fixed_stride_order, + _C.Tag.flexible_layout, +] + + +# Case 1: with_default=True (or omitted). Return type is guaranteed to be a Tag. +@overload +def get_layout_constraint_tag( + fn: Any, *, with_default: Literal[True] = True +) -> _C.Tag: ... + + +# Case 2: with_default=False. Return type can be a Tag or None. +@overload +def get_layout_constraint_tag( + fn: Any, *, with_default: Literal[False] +) -> Optional[_C.Tag]: ... 
+ + +def get_layout_constraint_tag(fn, *, with_default=True): + for tag in tags_by_priority: + if tag in fn.tags: + return tag + if with_default: + if is_builtin(fn): + return _C.Tag.flexible_layout + import torch._functorch + from torch._functorch import config + + return getattr(torch._C.Tag, config.custom_op_default_layout_constraint) + return None + + +# List of random functions that should be treated as impure +_RANDOM_FUNCTIONS = { + torch.rand, + torch.randn, + torch.randint, + torch.randperm, + torch.rand_like, + torch.randn_like, + torch.randint_like, + torch.normal, + torch.poisson, + torch.bernoulli, + torch.multinomial, +} + + +def is_impure( + op: Callable, + *, + args: Optional[tuple[Any, ...]] = None, + kwargs: Optional[dict[str, Any]] = None, + impure_random: bool = True, +) -> bool: + """ + An operator is impure if it: + - Mutates its inputs (has a mutable schema) + - Has nondeterministic/random behavior that mutates RNG state + - Is explicitly marked as effectful via torch.library._register_effectful_op + + Args: + op: The operator to check (function, OpOverload, HigherOrderOperator, etc.) 
+ args: Optional arguments that would be passed to the callable + kwargs: Optional keyword arguments that would be passed to the callable + impure_random: Whether to treat random operations as impure (default: True) + + Returns: + bool: True if the callable has side effects, False otherwise + """ + # Import here to avoid circular dependencies + from torch._higher_order_ops.effects import _get_effect + from torch.fx.node import _side_effectful_functions + + if isinstance(op, torch._ops.OpOverload): + schema = getattr(op, "_schema", None) + if schema is not None and schema.is_mutable: + return True + + if op in _side_effectful_functions: + return True + + if _get_effect(op) is not None: + return True + + if isinstance(op, torch._ops.HigherOrderOperator): + if op in ( + torch.ops.higher_order.auto_functionalized, + torch.ops.higher_order.auto_functionalized_v2, + ): + # Check if the auto-functionalized operator (the first argument) is + # side-effectful + if args and len(args) > 0: + return args[0] in _side_effectful_functions + + if _get_effect(op) is not None: + return True + + return False + + # Impure since it mutates RNG state + if impure_random and getattr(op, "_nondeterministic_seeded", False): + return True + + # Handle Python random functions that don't have _nondeterministic_seeded + # but still affect global RNG state (issue #151524) + # These should be impure regardless of impure_random setting to maintain + # consistency between eager and compiled execution + # All random operations are impure to ensure consistent behavior + # between eager and compiled execution, regardless of generator usage + if op in _RANDOM_FUNCTIONS: + return True + + schema = getattr(op, "_schema", None) + if schema is not None and schema.is_mutable: + return True + + if op in _side_effectful_functions: + return True + + return False diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_refs/__init__.py 
b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_refs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..daf7c24c8d28c6a8bffa51dcdcbe11e001185bcd --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_refs/__init__.py @@ -0,0 +1,6876 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +import builtins +import collections +import inspect +import itertools +import math +import operator +import warnings +from collections.abc import Callable, Iterable, Sequence +from enum import Enum +from functools import partial, reduce, singledispatch, wraps +from typing import Any, cast, Optional, overload, Union + +import torch +import torch._prims as prims +import torch._prims_common as utils +import torch.utils._pytree as pytree +from torch import sym_float, sym_int +from torch._prims_common import ( + BoolLike, + DeviceLikeType, + Dim, + DimsSequenceType, + DimsType, + dtype_to_type, + ELEMENTWISE_TYPE_PROMOTION_KIND, + FloatLike, + FloatWithoutSymFloat, + IntLike, + is_contiguous_for_memory_format_or_false, + is_contiguous_or_false, + is_weakly_lesser_type, + Number, + NumberType, + RealNumberType, + REDUCTION_OUTPUT_TYPE_KIND, + ShapeType, + StrideType, + TensorLike, + TensorLikeType, + TensorOrNumberLikeType, + TensorSequenceType, +) +from torch._prims_common.wrappers import ( + _maybe_convert_to_dtype, + _maybe_resize_out, + _safe_copy_out, + elementwise_type_promotion_wrapper, + elementwise_unary_scalar_wrapper, + out_wrapper, +) + + +# Experimental module containing prototype Python references for existing +# PyTorch operations. 
+ +__all__ = [ + # + # Elementwise Unary References + # + "abs", + "acos", + "acosh", + "asinh", + "asin", + "atan", + "atanh", + "bitwise_not", + # "cbrt", # No corresponding torch operation + "ceil", + "conj_physical", + "cos", + "cosh", + "count_nonzero", + "deg2rad", + "digamma", + "erf", + "erfinv", + "erfc", + "exp", + "expm1", + "exponential", + "exp2", + "fill", + "fill_", + "floor", + "frac", + "geometric", + "index_add", + "index_copy", + "index_copy_", + "index_select", + "index_fill", + "index_fill_", + "isfinite", + "isinf", + "isposinf", + "isneginf", + "isnan", + "isreal", + "i0", + "lerp", + "lgamma", + "log", + "log1p", + "log2", + "log10", + "log_normal", + "log_softmax", + "mvlgamma", + "norm", + "normal", + "nan_to_num", + "neg", + "positive", + "rad2deg", + "reciprocal", + "round", # TODO: model kwargs + "sigmoid", + "sgn", + "sign", + "signbit", + "sin", + "sinc", + "sinh", + "softmax", + "sqrt", + "square", + "tan", + "tanh", + "trace", + "trunc", + # + # Elementwise Binary References + # + "add", + "atan2", + "bitwise_and", + "bitwise_left_shift", + "bitwise_or", + "bitwise_right_shift", + "bitwise_xor", + "clamp_min", + "clamp_max", + "copysign", + "div", + "eq", + "float_power", + "floor_divide", + "fmax", + "fmin", + "fmod", + "gcd", + "ge", + "gt", + "heaviside", + "hypot", + "igamma", + "igammac", + "imag", + "isclose", + "lcm", + # 'ldexp', + "le", + "logaddexp", + "logaddexp2", + "logical_and", + "logical_not", + "logical_or", + "logical_xor", + "logsumexp", + "lt", + # 'max', # implement with reductions + "maximum", + # 'min', # implement with reductions + "minimum", + "mul", + "ne", + "nextafter", + # 'polar', # abs, cos, sin + "pow", + "real", + "rpow", + "remainder", + "rsub", + "rtruediv", + "rfloordiv", + "sub", + "true_divide", + "trunc_divide", + "xlogy", + # + # Elementwise Ternary References + # + "addcdiv", + "addcmul", + "clamp", + # + # Conditional references + # + "masked_fill", + "masked_fill_", + "where", + # + # Data 
conversion and movement references + # + "clone", + "copy_to", # TODO: add OpInfo (or implement .to) + "item", + "to", + # + # Reduction ops + # + "all", + "amax", + "amin", + "any", + "cumsum", + "cumprod", + "mean", + "dot", + "vdot", + "std", + "std_mean", + "sum", + "sum_to_size", + "prod", + "var", + "var_mean", + # + # Linear algebra ops + # + "addr", + # + # View & Shape Ops + # + "alias", + "alias_copy", + "atleast_1d", + "atleast_2d", + "atleast_3d", + "as_strided", + "as_strided_copy", + "as_strided_scatter", + "block_diag", + "broadcast_shapes", + "broadcast_tensors", + "broadcast_to", + "cat", + "chunk", + "column_stack", + "conj", + "constant_pad_nd", + "contiguous", + "diag_embed", + "diag", + "diagonal", + "diagonal_copy", + "diagonal_scatter", + "dsplit", + "dstack", + "expand", + "expand_as", + "expand_copy", + "flatten", + "flip", + "fliplr", + "flipud", + "hsplit", + "hstack", + "meshgrid", + "movedim", + "narrow", + "narrow_copy", + "native_group_norm", + "native_layer_norm", + "permute", + "permute_copy", + "ravel", + "repeat", + "reshape", + "reshape_as", + "roll", + "rot90", + "rsqrt", + "split_with_sizes", + "stack", + "swap_axes", # alias for transpose + "squeeze", + "squeeze_copy", + "t", + "t_copy", + "T", + "take_along_dim", + "tensor_split", + "transpose", + "transpose_copy", + "unbind_copy", + "unfold", + "unfold_copy", + "unsqueeze", + "unsqueeze_copy", + "view", + "view_as", + "view_copy", + "vsplit", + "vstack", + "view_as_complex", + "unflatten", + "unbind", + "triu", + "tril", + "triu_indices", + "tril_indices", + # + # Tensor Creation + # + "arange", + "cauchy", + "empty", + "empty_like", + "empty_permuted", + "empty_strided", + "eye", + "full", + "full_like", + "linspace", + "logspace", + "new_empty", + "new_empty_strided", + "new_full", + "new_ones", + "new_zeros", + "ones", + "ones_like", + "randn", + "scalar_tensor", + "zero", + "zeros", + "zeros_like", + # + # Test-related functions + # + "allclose", + "equal", + # + # 
Statistical operations + # + "bucketize", + # + # Misc + # + "is_complex", + "renorm", + "stft", + "istft", +] + +Tensor = torch.Tensor +DispatchKey = torch._C.DispatchKey # type: ignore[attr-defined] +aten = torch._ops.ops.aten + +# Note that the docstrings for the public methods from this file are in +# torch/_torch_docs.py + + +def is_noncontiguous_supported(device): + return device is None or device.type != "hpu" + + +def handle_noncontiguous_outputs(input_tlist, output): + device = None + from torch._subclasses.fake_tensor import FakeTensor + + for t in input_tlist: + if isinstance(t, FakeTensor): + device = t.fake_device + break + + if not is_noncontiguous_supported(device): + output = output.contiguous() + + return output + + +def _broadcast_shapes(*_shapes): + from torch.fx.experimental.symbolic_shapes import ( + guard_or_false, + is_nested_int, + size_hint, + ) + + backed_so = torch.fx.experimental._config.backed_size_oblivious + + shapes = tuple( + (x,) if isinstance(x, IntLike) else x + for x in filter(lambda x: x is not None, _shapes) + ) + + # Short-circuits on no input + if len(shapes) == 0: + return None + + for shape in shapes: + if not isinstance(shape, Sequence): + raise RuntimeError( + "Input shapes should be of type ints, a tuple of ints, or a list of ints, got ", + shape, + ) + + # Computes common shape + common_shape: list[Union[int, torch.SymInt]] = [ + 1, + ] * reduce(max, (len(shape) for shape in shapes)) + for arg_idx, shape in enumerate(shapes): + for idx in range(-1, -1 - len(shape), -1): + # NB: handle nested ints specially to avoid invalid guarding on Ne(j0, 1). + if is_nested_int(shape[idx]): + # Broadcasting is allowed for (j0, 1) or (j0, j0); + # not (j0, j1), (j0, 5), etc. + if is_nested_int(common_shape[idx]) and guard_or_false( + shape[idx] == common_shape[idx] + ): + continue + else: + # When backed size oblivious is used, we specialize for broadcasting + # if its the only way to compile the example input. 
+ # i.e: s0:1, s1:1 ==> + # assert s0==s1, no specialization on ==1 or !=1. + # The non-broadcast path is picked + # s0:1, s1:4 ==> + # specialize(s0) to be 1. + # s0:4, s1:1 ==> + # specialize(s1) to be 1. + if backed_so: + a = size_hint(shape[idx], allow_none=True) + b = size_hint(common_shape[idx], allow_none=True) + if a == 1 and b != 1: + torch._check(shape[idx] == 1) + if b == 1 and a != 1: + torch._check(common_shape[idx] == 1) + if guard_or_false(shape[idx] == common_shape[idx]): + continue + + if guard_or_false(common_shape[idx] == 1): + if shape[idx] < 0: + raise ValueError( + "Attempting to broadcast a dimension with negative length!" + ) + common_shape[idx] = shape[idx] + + if not is_nested_int(shape[idx]) and guard_or_false(shape[idx] == 1): + # broadcast case . + continue + else: + # If broadcasting is undecided we pick non-broadcast path and add runtime assertion. + torch._check( + common_shape[idx] == shape[idx], + lambda: f"Attempting to broadcast a dimension of length {shape[idx]} at {idx}! " + f"Mismatching argument at index {arg_idx} had {shape}; but expected shape " + f"should be broadcastable to {common_shape}", + ) + + return common_shape + + +def _maybe_broadcast(*args, preserve_cpu_scalar_tensors=True): + # Computes common shape + common_shape = _broadcast_shapes( + *(t.shape if isinstance(t, TensorLike) else None for t in args) + ) + + def should_expand(a: ShapeType, b: ShapeType) -> bool: + from torch.fx.experimental.symbolic_shapes import ( + guard_or_false, + sym_and, + sym_or, + ) + + if len(a) != len(b): + return True + + for x, y in zip(a, b): + if guard_or_false(x != y): + # We know they are not the same. + return True + + # They are the same or we do not know if they are the same or not. + # 1==1 no-broadcast + # u0==1 and 1==u0 cases. We broadcast! + if guard_or_false(sym_and(x == 1, y == 1)): + pass + elif guard_or_false(sym_or(x == 1, y == 1)): + # assume broadcasting. + return True + + # u0==u1 assume the same, no broadcasting! 
+ torch._check( + x == y, + lambda: "sizes assumed to be the same due to unbacked broadcasting semantics", + ) + + return False + + def __maybe_broadcast(x, shape): + if x is None: + return None + elif isinstance(x, Number): + return x + elif isinstance(x, TensorLike): + if preserve_cpu_scalar_tensors and utils.is_cpu_scalar_tensor(x): + return x + + if should_expand(x.shape, common_shape): + return x.expand(common_shape) + + return x + else: + raise RuntimeError( + "Unexpected type when broadcasting: " + str(type(x)) + "!" + ) + + return tuple(__maybe_broadcast(x, common_shape) for x in args) + + +# Utilities should come BEFORE this import +from torch._decomp import register_decomposition + + +# +# Elementwise unary references +# + +infer_aten_op = object() + + +# TODO: add type promotion support +def _make_elementwise_unary_reference( + type_promotion_kind, + *, + aten_op=infer_aten_op, + extra_meta=None, + exact_dtype=False, +) -> Callable: + def inner(prim: Callable): + nonlocal aten_op + + @wraps(prim) + @out_wrapper(exact_dtype=exact_dtype) + @elementwise_unary_scalar_wrapper + @elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=type_promotion_kind, + ) + def _ref(a: TensorLikeType) -> TensorLikeType: + if extra_meta is not None: + extra_meta(a) + + output = prim(a) + return handle_noncontiguous_outputs([a], output) + + if aten_op is infer_aten_op: + aten_op = utils.get_aten_op(prim, prim.__name__) + if aten_op is not None: + register_decomposition(aten_op)(_ref) + + return _ref + + return inner + + +def _make_alias(fn, name): + """ + This function defines an alias of another function and sets its __name__ argument. + It also sets its __module__ argument to the module of the caller. + Note that when naively doing `alias = fn`, we have that `alias.__name__ == "fn"`, and + `alias.__module__ == fn.__module__`. 
+ """ + + def _fn(*args, **kwargs): + return fn(*args, **kwargs) + + _fn.__name__ = name + _fn.__module__ = inspect.currentframe().f_back.f_globals["__name__"] # type: ignore[union-attr] + return _fn + + +def _make_inplace(fn): + """ + Given a function with out variant (i.e. using `out_wrapper()), it returns its in-place variant + See https://github.com/pytorch/pytorch/wiki/Developer-FAQ#how-do-in-place-operations-work-in-pytorch + """ + + # nb. We use the name of the first argument used in the unary references + @wraps(fn) + def _fn(a, *args, **kwargs): + return fn(a, *args, out=a, **kwargs) + + inplace_name = f"{fn.__name__}_" + _fn.__name__ = inplace_name + _fn = register_decomposition(getattr(aten, inplace_name))(_fn) # type: ignore[assignment] + + # We access the __all__ attribute of the module where fn is defined + # There may be a cleaner way of doing this... + from inspect import getmodule + + _all = getmodule(fn).__all__ # type: ignore[union-attr] + if inplace_name not in _all: + _all.append(inplace_name) + return _fn + + +@_make_elementwise_unary_reference( + ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT, + exact_dtype=True, +) +def abs(a): + return prims.abs(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def acos(a): + return prims.acos(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def acosh(a): + return prims.acosh(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def asin(a): + return prims.asin(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def asinh(a): + return prims.asinh(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def atan(a): + return prims.atan(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def atanh(a): + return prims.atanh(a) + + 
+@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT) +def bitwise_not(a): + return prims.bitwise_not(a) + + +@_make_elementwise_unary_reference( + ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + exact_dtype=True, +) +def ceil(a): + return prims.ceil(a) + + +@register_decomposition(aten.is_complex) +def is_complex(input: TensorLikeType): + return utils.is_complex_dtype(input.dtype) + + +@register_decomposition(aten.conj_physical) +@out_wrapper() +def conj_physical(input: TensorLikeType): + if not utils.is_complex_dtype(input.dtype): + return input + return prims.conj_physical(input) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def cos(a): + return prims.cos(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def cosh(a): + return prims.cosh(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def digamma(a): + return prims.digamma(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def erf(a): + return prims.erf(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def erfinv(a): + return prims.erf_inv(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def erfc(a): + return prims.erfc(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def exp(a): + return prims.exp(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def expm1(a): + return prims.expm1(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def exp2(a): + return prims.exp2(a) + + +# Fill has its own implementation because it has a value parameter +# CompositeImplicitAutograd - don't register decomp +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH, +) 
+def fill(a: TensorLikeType, value: NumberType) -> TensorLikeType: + assert isinstance(a, TensorLike) + assert isinstance(value, Number) + + python_type = utils.dtype_to_type(a.dtype) + if not utils.is_weakly_lesser_type(type(value), python_type): + msg = f"value argument of type {type(value)} cannot be safely cast to type {python_type}!" + raise ValueError(msg) + + return prims.fill(a, value) + + +def fill_(a: TensorLikeType, value: NumberType) -> TensorLikeType: + r = prims.fill(a, value) + prims.copy_to(a, r) + return a + + +@register_decomposition(aten.zero) +@out_wrapper() +def zero(input: TensorLikeType) -> TensorLikeType: + return torch.zeros_like(input) + + +@_make_elementwise_unary_reference( + ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + exact_dtype=True, +) +def floor(a): + return prims.floor(a) + + +@_make_elementwise_unary_reference( + ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + exact_dtype=True, +) +def frac(x: TensorLikeType) -> TensorLikeType: + trunc_x = torch.mul(torch.floor(torch.abs(x)), torch.sign(x)) + return torch.sub(x, trunc_x) + + +# imag does not use _make_elementwise_unary_reference because it does not support out +def imag(a: TensorLikeType) -> TensorLikeType: + assert isinstance(a, TensorLike) + torch._check( + utils.is_complex_dtype(a.dtype), lambda: "imag only supports complex tensors." 
+ ) + return prims.imag(a) + + +@_make_elementwise_unary_reference( + ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, + aten_op=None, # CompositeImplicitAutograd +) +def isfinite(a: TensorLikeType) -> TensorLikeType: + if utils.is_float_dtype(a.dtype) or utils.is_complex_dtype(a.dtype): + return prims.isfinite(a) + + return ones_like(a, dtype=torch.bool) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL) +def isinf(a: TensorLikeType) -> TensorLikeType: + if utils.is_complex_dtype(a.dtype): + return torch.logical_or(isinf(torch.real(a)), isinf(torch.imag(a))) + if utils.is_float_dtype(a.dtype): + return torch.abs(a) == float("inf") + return torch.zeros_like(a, dtype=torch.bool) + + +@_make_elementwise_unary_reference( + ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, + exact_dtype=True, +) +def isposinf(a: TensorLikeType) -> TensorLikeType: + torch._check( + not utils.is_complex_dtype(a.dtype), + lambda: f"Complex dtype is not supported for isposinf, got dtype {a.dtype}", + ) + if utils.is_float_dtype(a.dtype): + return a == float("inf") + return torch.zeros_like(a, dtype=torch.bool) + + +@_make_elementwise_unary_reference( + ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, + exact_dtype=True, +) +def isneginf(a: TensorLikeType) -> TensorLikeType: + torch._check( + not utils.is_complex_dtype(a.dtype), + lambda: f"Complex dtype is not supported for isneginf, got dtype {a.dtype}", + ) + if utils.is_float_dtype(a.dtype): + return a == float("-inf") + return torch.zeros_like(a, dtype=torch.bool) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL) +def isnan(a: TensorLikeType) -> TensorLikeType: + return prims.ne(a, a) + + +# alias +mvlgamma = _make_alias(torch.special.multigammaln, "mvlgamma") # type: ignore[has-type] + + +@_make_elementwise_unary_reference( + ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, + aten_op=None, # CompositeImplicitAutograd +) +def isreal(a: TensorLikeType) -> TensorLikeType: + if 
utils.is_complex_dtype(a.dtype): + return torch.imag(a) == 0 + return torch.ones_like(a, dtype=torch.bool) + + +# TODO: if this is special maybe it should be defined there and imported here? +@_make_elementwise_unary_reference( + ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, aten_op=aten.i0 +) +def i0(a): + return prims.bessel_i0(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def lgamma(a): + return prims.lgamma(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def log(a): + return prims.log(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def log1p(a): + return prims.log1p(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def log2(a): + return prims.log2(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def log10(a): + return prims.log10(a) + + +# CompositeImplicitAutograd - don't register decomp +@out_wrapper() +def log_softmax( + a: TensorLikeType, + dim: int, + dtype: Optional[torch.dtype] = None, +) -> TensorLikeType: + result_dtype = dtype or a.dtype + computation_dtype = utils.get_computation_dtype(result_dtype) + a_ = _maybe_convert_to_dtype(a, computation_dtype) + return _maybe_convert_to_dtype(a_ - logsumexp(a_, dim, keepdim=True), result_dtype) # type: ignore[return-value] + + +@register_decomposition(aten.logsumexp) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("self",), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def logsumexp( + self: TensorLikeType, dim: DimsType, keepdim: bool = False +) -> TensorLikeType: + if not isinstance(dim, Iterable): + dim = (dim,) + if self.numel() == 0: + # pyrefly: ignore [no-matching-overload] + return torch.sum(torch.exp(self), dim, keepdim).log() + # pyrefly: ignore [bad-argument-type] + maxes = torch.amax(torch.real(self), dim, keepdim=True) + maxes = 
torch.masked_fill(maxes, maxes.abs() == float("inf"), 0) + # pyrefly: ignore [no-matching-overload] + maxes_squeezed = maxes if keepdim else torch.squeeze(maxes, dim) + # pyrefly: ignore [no-matching-overload] + result = torch.sum(torch.exp(self - maxes), dim, keepdim) + return result.log().add(maxes_squeezed) + + +@register_decomposition(aten.nan_to_num) +@out_wrapper() +def nan_to_num( + a: TensorLikeType, + nan: Optional[NumberType] = 0.0, + posinf: Optional[NumberType] = None, + neginf: Optional[NumberType] = None, +) -> TensorLikeType: + assert isinstance(a, TensorLike) + + if utils.is_boolean_dtype(a.dtype) or utils.is_integer_dtype(a.dtype): + return a.clone() + + if nan is None: + nan = 0.0 + + if posinf is None: + posinf = torch.finfo(a.dtype).max + + if neginf is None: + neginf = torch.finfo(a.dtype).min + + result = torch.where(torch.isnan(a), nan, a) # type: ignore[call-overload] + result = torch.where(torch.isneginf(a), neginf, result) # type: ignore[call-overload] + result = torch.where(torch.isposinf(a), posinf, result) # type: ignore[call-overload] + return result + + +def _neg_meta(a: TensorLikeType): + torch._check( + a.dtype is not torch.bool, + lambda: ( + "Negation, the `-` operator, on a bool tensor is not supported. " + "If you are trying to invert a mask, use the `~` or `logical_not()` " + "operator instead." + ), + ) + + +@_make_elementwise_unary_reference( + ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, extra_meta=_neg_meta +) +def neg(a): + return prims.neg(a) + + +# positive does not use _make_elementwise_unary_reference because it does not support out +# CompositeImplicitAutograd - don't register decomp +def positive(a: TensorLikeType) -> TensorLikeType: + assert isinstance(a, TensorLike) + if a.dtype is torch.bool: + msg = "positive does not support bool tensors." 
+ raise RuntimeError(msg) + return a + + +# real does not use _make_elementwise_unary_reference because it does not support out +def real(a: TensorLikeType) -> TensorLikeType: + assert isinstance(a, TensorLike) + if utils.is_complex_dtype(a.dtype): + return prims.real(a) + return a + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def reciprocal(a): + return prims.reciprocal(a) + + +@register_decomposition(aten.round) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def round(a: TensorLikeType, *, decimals: int = 0) -> TensorLikeType: + if decimals == 0: + return prims.round(a) + else: + ten_pow = 10**decimals + ten_neg_pow = 10 ** (-decimals) + return prims.mul(prims.round(prims.mul(a, ten_pow)), ten_neg_pow) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def rsqrt(a): + return prims.rsqrt(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def sigmoid(a: TensorLikeType) -> TensorLikeType: + return true_divide(1, add(1, exp(neg(a)))) + + +@_make_elementwise_unary_reference( + ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + exact_dtype=True, +) +def sgn(a): + if utils.is_complex_dtype(a.dtype): + a_abs = a.abs() + return torch.where(a_abs == 0, 0, a / a_abs) + else: + return a.sign() + + +@_make_elementwise_unary_reference( + ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + exact_dtype=True, +) +def sign(a): + return prims.sign(a) + + +@_make_elementwise_unary_reference( + ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, + exact_dtype=True, +) +def signbit(a): + return prims.signbit(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def sin(a): + return prims.sin(a) + + +# Autograd note: This will give the right first derivative at zero (by chance), +# but not the right second derivative 
+@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def sinc(a): + a = math.pi * a + return torch.where(a == 0, 1, torch.sin(a) / a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def sinh(a): + return prims.sinh(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def sqrt(a): + return prims.sqrt(a) + + +@_make_elementwise_unary_reference( + ELEMENTWISE_TYPE_PROMOTION_KIND.BOOL_TO_LONG, + aten_op=None, # CompositeImplicitAutograd, +) +def square(a: TensorLikeType) -> TensorLikeType: + return mul(a, a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def tan(a): + return prims.tan(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def tanh(a): + return prims.tanh(a) + + +@_make_elementwise_unary_reference( + ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + exact_dtype=True, +) +def trunc(a): + return prims.trunc(a) + + +# TODO: register this as a real ref/decomposition once TorchInductor supports complex! 
def view_as_complex(self: TensorLikeType) -> TensorLikeType:
    """
    Reinterpret a floating point tensor as a complex tensor.

    The element type is viewed as the corresponding complex dtype and the
    trailing dimension is then squeezed away.

    The input must satisfy every check below, each enforced with torch._check:
      * floating point dtype
      * at least one dimension, with the last dimension of size 2
      * last-dimension stride of 1
      * every other stride divisible by 2
      * storage offset divisible by 2
    """
    input_dtype = self.dtype
    torch._check(
        utils.is_float_dtype(input_dtype),
        # The trailing space after "point" matters: the two literals are
        # concatenated, and without it the message read "pointtensors".
        lambda: "view_as_complex is only supported for floating point "
        f"tensors, but got a tensor of scalar type: {input_dtype}",
    )
    sizes = self.size()
    torch._check(
        len(sizes) != 0,
        lambda: "Input tensor must have one or more dimensions",
    )
    torch._check(
        sizes[-1] == 2,
        lambda: "Tensor must have a last dimension of size 2",
    )

    old_strides = self.stride()
    torch._check(
        old_strides[-1] == 1,
        lambda: "Tensor must have a last dimension with stride 1",
    )
    # All strides except the last one.
    dims = old_strides[:-1]
    torch._check(
        builtins.all(stride % 2 == 0 for stride in dims),
        lambda: "Tensor must have a stride divisible by 2 for all but last dimension",
    )
    torch._check(
        self.storage_offset() % 2 == 0,
        lambda: "Tensor must have a storage_offset divisible by 2",
    )
    return prims.view_element_type(
        self, utils.corresponding_complex_dtype(input_dtype)
    ).squeeze(-1)


def _make_elementwise_binary_reference(
    type_promotion_kind,
    aten_op=infer_aten_op,
    name=None,
    has_out=True,
    supports_lhs_python_scalar=True,
    supports_rhs_python_scalar=True,
    supports_two_python_scalars=False,
    should_register_decomposition=True,
) -> Callable:
    """
    Build a decorator that turns a two-argument callable into a binary reference.

    Args:
        type_promotion_kind: forwarded to elementwise_type_promotion_wrapper,
            which is applied to the arguments "a" and "b".
        aten_op: op the reference is registered as a decomposition for. The
            default sentinel `infer_aten_op` means the op is looked up with
            utils.get_aten_op(prim, name); None disables registration.
        name: name given to the produced reference and used as the prefix of
            its error messages; defaults to the decorated callable's __name__.
        has_out: when true the reference is additionally wrapped in out_wrapper().
        supports_lhs_python_scalar: when false, a Number passed as `a` fails a
            torch._check_value.
        supports_rhs_python_scalar: when false, a Number passed as `b` fails a
            torch._check_value.
        supports_two_python_scalars: when false, passing Numbers for both
            arguments fails a torch._check_value.
        should_register_decomposition: when false, no decomposition is
            registered even if an aten op is known.

    Returns:
        The decorator (`inner`), which returns the wrapped reference.
    """

    def inner(prim: Callable):
        nonlocal aten_op, name
        if name is None:
            name = prim.__name__

        @wraps(prim)
        @elementwise_type_promotion_wrapper(
            type_promoting_args=("a", "b"),
            type_promotion_kind=type_promotion_kind,
        )
        def _ref(
            a: Union[Tensor, NumberType],
            b: Union[Tensor, NumberType],
        ) -> Tensor:
            torch._check_value(
                supports_lhs_python_scalar or not isinstance(a, Number),
                lambda: f"{name}: Received a lhs Python scalar to an elementwise binary "
                "operation that does not accept lhs scalars!",
            )
            torch._check_value(
                supports_rhs_python_scalar or not isinstance(b, Number),
                lambda: f"{name}: Received a rhs Python scalar to an elementwise binary "
                "operation that does not accept rhs scalars!",
            )
            torch._check_value(
                supports_two_python_scalars
                or not (isinstance(a, Number) and isinstance(b, Number)),
                lambda: f"{name}: Receive two Number inputs to an elementwise binary operation!",
            )
            a, b = _maybe_broadcast(a, b)
            output = prim(a, b)
            return handle_noncontiguous_outputs([a, b], output)

        if has_out:
            _ref = out_wrapper()(_ref)  # type: ignore[assignment]

        # wraps(prim) copied prim's name; overwrite it with the requested one.
        _ref.__name__ = name
        if aten_op is infer_aten_op:
            aten_op = utils.get_aten_op(prim, name)
        if aten_op is not None and should_register_decomposition:
            register_decomposition(aten_op)(_ref)

        return _ref

    return inner


# Add has its own implementation because it has an alpha argument
@register_decomposition(aten.add)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a", "b"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def add(
    a: Union[TensorLikeType, NumberType],
    b: Union[TensorLikeType, NumberType],
    *,
    alpha: Optional[NumberType] = None,
):
    """
    Reference implementation of torch.add

    Computes a + alpha * b; when alpha is None, b is used unscaled.
    Raises ValueError when alpha's Python type cannot be safely cast to the
    Python type of the operand dtype (the check is skipped for bool dtypes).
    """

    a, b = _maybe_broadcast(a, b)

    if alpha is not None:
        dtype = a.dtype if isinstance(a, TensorLike) else b.dtype  # type: ignore[union-attr]
        python_type = utils.dtype_to_type(dtype)
        if python_type is not bool and not utils.is_weakly_lesser_type(
            type(alpha), python_type
        ):
            msg = f"alpha argument of type {type(alpha)} cannot be safely cast to type {python_type}!"
            raise ValueError(msg)
        # Tensors are scaled through prims.mul; non-tensor b is scaled with
        # plain Python multiplication.
        if isinstance(b, TensorLike):
            b = prims.mul(b, alpha)
        else:
            b = b * alpha

    output = prims.add(a, b)
    return handle_noncontiguous_outputs([a, b], output)


# Forwards to prims.atan2; Python scalars are rejected on both sides.
@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    supports_lhs_python_scalar=False,
    supports_rhs_python_scalar=False,
)
def atan2(a, b):
    return prims.atan2(a, b)


# Forwards to prims.bitwise_and.
@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def bitwise_and(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    return prims.bitwise_and(a, b)


# Forwards to prims.shift_left.
@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def bitwise_left_shift(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    return prims.shift_left(a, b)


# Forwards to prims.bitwise_or.
@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def bitwise_or(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    return prims.bitwise_or(a, b)


# Forwards to prims.shift_right_arithmetic.
@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def bitwise_right_shift(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    return prims.shift_right_arithmetic(a, b)


# Forwards to prims.bitwise_xor.
@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def bitwise_xor(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    return prims.bitwise_xor(a, b)


# copysign: selects -abs(a) where signbit(b) is set and abs(a) elsewhere.
# A Number `b` is wrapped into a scalar tensor matching a's dtype and device;
# two tensors on different devices raise RuntimeError.
@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    supports_lhs_python_scalar=False,
)
def copysign(
    a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]
):
    if isinstance(b, Number) and isinstance(a, Tensor):
        # pyrefly: ignore [bad-argument-type]
        b = scalar_tensor(b, dtype=a.dtype, device=a.device)
    elif isinstance(a, Tensor) and isinstance(b, Tensor) and a.device != b.device:
        # NOTE(review): the wording speaks of divisor/dividend although this is
        # copysign; the message is kept verbatim in case callers match on it.
        msg = f"Expected divisor (b) to be on the same device ({a.device}) as dividend (a), but it is found on {b.device}!"
        raise RuntimeError(msg)
    # pyrefly: ignore [bad-argument-type]
    return where(signbit(b), neg(abs(a)), abs(a))


# complex = _make_elementwise_binary_reference(prims.complex, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)


@register_decomposition(aten.div)
@out_wrapper()
def div(
    a: Union[TensorLikeType, NumberType],
    b: Union[TensorLikeType, NumberType],
    *,
    rounding_mode: Optional[str] = None,
):
    """
    Reference implementation of torch.div

    Dispatches on rounding_mode: None -> true_divide, "trunc" -> trunc_divide,
    "floor" -> floor_divide. Any other value raises ValueError.
    """
    if rounding_mode is None:
        return true_divide(a, b)
    elif rounding_mode == "trunc":
        return trunc_divide(a, b)
    elif rounding_mode == "floor":
        return floor_divide(a, b)
    else:
        msg = f"div expected rounding_mode to be one of None, 'trunc', or 'floor' but found {rounding_mode}."
        raise ValueError(msg)


# Forwards to prims.eq; a Python scalar is rejected as the left operand.
@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
    supports_lhs_python_scalar=False,
)
def eq(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    return prims.eq(a, b)


# pow: special-cases a few scalar operands before falling back to prims.pow.
#   scalar exponent 1.0 -> clone of a, 2.0 -> a * a, 0.5 -> sqrt(a)
#   scalar base     1.0 -> b filled with True,
#                   2.0 -> exp2(b) when b has a float or complex dtype
@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.BOOL_TO_LONG,
)
def pow(
    a: Union[TensorLikeType, NumberType],
    b: Union[TensorLikeType, NumberType],
) -> TensorLikeType:
    assert isinstance(a, TensorLikeType) or isinstance(b, TensorLikeType)

    if isinstance(b, Number):
        if b == 1.0:
            return a.clone()  # type: ignore[return-value,union-attr]
        elif b == 2.0:
            return a * a  # type: ignore[return-value]
        elif b == 0.5:
            return torch.sqrt(a)  # type: ignore[arg-type]
    elif isinstance(a, Number):
        if a == 1.0:
            return torch.fill(b, True)
        if a == 2.0 and (
            utils.is_float_dtype(b.dtype) or utils.is_complex_dtype(b.dtype)
        ):
            return torch.exp2(b)

    return prims.pow(a, b)


# Float power has its own implementation because it has unique type promotion.
# CompositeImplicitAutograd - don't register decomp
@out_wrapper()
def float_power(
    a: Union[TensorLikeType, NumberType],
    b: Union[TensorLikeType, NumberType],
) -> Tensor:
    # Two plain Python numbers are rejected up front.
    both_scalars = isinstance(a, Number) and isinstance(b, Number)
    if both_scalars:
        raise ValueError(
            "Receive two Number inputs to an elementwise binary operation!"
        )

    # Type promotion: the computation runs in complex128 when the higher of
    # the two input dtypes is complex, and in float64 otherwise.
    promoted = utils.get_higher_dtype(a, b)
    assert promoted is not None
    compute_dtype = (
        torch.complex128 if utils.is_complex_dtype(promoted) else torch.float64
    )

    # Both operands are cast before broadcasting, to stay consistent with the
    # C++ implementation's cast behavior.
    base = _maybe_convert_to_dtype(a, compute_dtype)
    exponent = _maybe_convert_to_dtype(b, compute_dtype)

    base, exponent = _maybe_broadcast(base, exponent)
    # pyrefly: ignore [bad-return]
    return pow(base, exponent)


# >>> a = torch.tensor(-0.2500, dtype=torch.float64)
# tensor(-0.250000000000000, dtype=torch.float64)
#
# >>> b = torch.tensor(-0.0010, dtype=torch.float64)
# tensor(-0.001000000000000, dtype=torch.float64)
#
# Note: In this case, casting float to double will expand the float mantissa with zeros,
# while creating a double generates a distinct mantissa.
# >>> torch.tensor(-0.001).to(dtype=torch.float64)
# tensor(-0.001000000047497, dtype=torch.float64)
#
# Floor Division
# The difference is caused because torch.remainder(a, b) = -0.001.
+# +# >>> torch.floor(torch.true_divide(a, b)) +# tensor(250., dtype=torch.float64) +# +# >>> torch.div(a, b, rounding_mode='floor') +# tensor(249., dtype=torch.float64) +# +# Definition: a // b = (a - remainder(a, b)) / b +# >>> torch.true_divide(torch.sub(a, torch.remainder(a, b)), b) +# tensor(249., dtype=torch.float64) +# +# For reference, see CPython's implementation: +# https://github.com/python/cpython/blob/ace008c531dd685a30c1dd68f9b5ba35f20171cf/Objects/floatobject.c#L636 + + +@_make_elementwise_binary_reference( + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + supports_two_python_scalars=True, + should_register_decomposition=False, +) +def floor_divide( + a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType] +): + # Wrap scalars because some references only accept tensor arguments. + if isinstance(a, Number) and isinstance(b, Number): + # pyrefly: ignore [bad-argument-type] + a = scalar_tensor(a) + # pyrefly: ignore [bad-argument-type] + b = scalar_tensor(b) + elif isinstance(b, Number) and isinstance(a, Tensor): + # pyrefly: ignore [bad-argument-type] + b = scalar_tensor(b, dtype=a.dtype, device=a.device) + elif isinstance(a, Number) and isinstance(b, Tensor): + # pyrefly: ignore [bad-argument-type] + a = scalar_tensor(a, dtype=b.dtype, device=b.device) + elif isinstance(a, Tensor) and isinstance(b, Tensor) and a.device != b.device: + if a.device == torch.device("cpu"): + msg = f"Expected divisor (b) to be on the same device ({a.device}) as dividend (a), but it is found on {b.device}!" 
+ raise RuntimeError(msg) + else: + b = prims.device_put(b, device=a.device) + + assert isinstance(a, Tensor) and isinstance(b, Tensor) + dtype = a.dtype + if utils.is_float_dtype(dtype): + return _floor_divide_float(a, b) + elif utils.is_integer_dtype(dtype): + return _floor_divide_integer(a, b) + else: + torch._check(False, lambda: f"{dtype} not supported for floor_divide") + + +def _floor_divide_integer(a: Tensor, b: Tensor) -> Tensor: + a, b = _maybe_broadcast(a, b) + + if not a.dtype.is_signed: + return prims.div(a, b) + + # Convert truncation to flooring: + offset = (torch.signbit(a) != torch.signbit(b)).logical_and(torch.fmod(a, b) != 0) + return prims.div(a, b) - _maybe_convert_to_dtype(offset, a.dtype) + + +def _floor_divide_float(a: Tensor, b: Tensor) -> Tensor: + mod = fmod(a, b) + div = true_divide(sub(a, mod), b) + + # Ensure that the remainder has the same sign as denominator + different_signed_inputs = bitwise_xor(lt(a, 0), lt(b, 0)) + non_zero_remainder = ne(mod, 0) + mask = bitwise_and(non_zero_remainder, different_signed_inputs) + div = where(mask, sub(div, 1), div) + + # Map quotient to nearest integer value + floor_div = floor(div) + mask = gt(sub(div, floor_div), 0.5) + floor_div = where(mask, add(floor_div, 1), floor_div) + + basic_div = true_divide(a, b) + zero_tensor = scalar_tensor(0, dtype=basic_div.dtype, device=basic_div.device) + + # If quotient is zero, copy signbit from true_divide quotient + floor_div = where(ne(div, 0), floor_div, copysign(zero_tensor, basic_div)) + + # If denominator is zero, then follow true_divide behavior + return where(ne(b, 0), floor_div, basic_div) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + supports_lhs_python_scalar=False, + supports_rhs_python_scalar=False, +) +def fmax(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.fmax(a, b) + + +@_make_elementwise_binary_reference( + 
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + supports_lhs_python_scalar=False, + supports_rhs_python_scalar=False, +) +def fmin(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.fmin(a, b) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + supports_lhs_python_scalar=False, + supports_rhs_python_scalar=True, +) +def fmod(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.fmod(a, b) + + +@register_decomposition(aten.frexp) +@out_wrapper("mantissa", "exponent") +def frexp(self: TensorLikeType) -> tuple[TensorLikeType, TensorLikeType]: + return torch.return_types.frexp(prims.frexp(self)) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + supports_lhs_python_scalar=False, + supports_rhs_python_scalar=False, +) +def gcd(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.gcd(a, b) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, + supports_lhs_python_scalar=False, +) +def ge(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.ge(a, b) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, + supports_lhs_python_scalar=False, +) +def gt(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.gt(a, b) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + supports_lhs_python_scalar=False, + supports_rhs_python_scalar=False, +) +def heaviside(input: TensorLikeType, values: TensorLikeType) -> TensorLikeType: + input_eq_zero = torch.eq(input, 0) + input_lt_zero = torch.logical_or(torch.lt(input, 0), torch.isnan(input)) + zeros_and_ones = torch.where(input_lt_zero, 0, 1) + output = torch.where(input_eq_zero, values, zeros_and_ones) + return output + + +@_make_elementwise_binary_reference( + 
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + supports_lhs_python_scalar=False, + supports_rhs_python_scalar=False, +) +def hypot(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.hypot(a, b) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + supports_lhs_python_scalar=False, + supports_rhs_python_scalar=False, +) +def igamma(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.igamma(a, b) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + supports_lhs_python_scalar=False, + supports_rhs_python_scalar=False, +) +def igammac(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.igammac(a, b) + + +def _check_close_args( + name: str, + a: TensorLikeType, + b: TensorLikeType, + rtol: float, + atol: float, +) -> None: + torch._check_value( + a.dtype == b.dtype, + lambda: f"{name}: Attempting to compare tensors of different dtypes {a.dtype} and {b.dtype}!", + ) + torch._check( + rtol >= 0, + lambda: f"{name}: rtol must be greater than or equal to zero, but got {rtol}!", + ) + torch._check( + atol >= 0, + lambda: f"{name}: atol must be greater than or equal to zero, but got {atol}!", + ) + + +# CompositeImplicitAutograd - don't register decomp +def isclose( + a: TensorLikeType, + b: TensorLikeType, + rtol: float = 1e-05, + atol: float = 1e-08, + equal_nan: bool = False, +) -> TensorLikeType: + _check_close_args(name="torch.isclose", a=a, b=b, rtol=rtol, atol=atol) + + close = eq(a, b) + if equal_nan and (utils.is_float_dtype(a.dtype) or utils.is_complex_dtype(a.dtype)): + close = logical_or(close, logical_and(isnan(a), isnan(b))) + + # Note: In case of zero tolerances the closeness inequality degenerates to an equality check. + # In this case, the short-circuit prevents false positives as detailed in the paragraph below. 
+ if atol == 0 and rtol == 0: + return close + + # Note [closeness error computation] + # atol and rtol are provided as doubles, so the computation + # rtol * other will produce a float or complex tensor. + # When the difference (self - other) is compared to it then the + # tensor representing the difference will also be cast to float or complex. + # However, since (self - other) in uint8 is very likely to produce a + # negative value, this moves the cast forward so the difference is + # always computed in a float or complex type. + # If the values of the integer tensors cannot be exactly represented + # by the default scalar type then this may cause an incorrect result. + if not utils.is_float_dtype(a.dtype) and not utils.is_complex_dtype(a.dtype): + a = prims.convert_element_type(a, torch.get_default_dtype()) + b = prims.convert_element_type(b, torch.get_default_dtype()) + + allowed_error = add(atol, abs(mul(b, rtol))) + actual_error = abs(sub(a, b)) + + # Computes finite closeness + result = logical_or( + close, logical_and(isfinite(actual_error), le(actual_error, allowed_error)) + ) + + return result + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + supports_lhs_python_scalar=False, + supports_rhs_python_scalar=False, +) +def lcm(a: TensorLikeType, b: TensorLikeType): + dtype = a.dtype + # promoting to int32 to maintain 100% consistency with C++ and to + # prevent overflow in case of int8 and int16 + promote_to_int = dtype in (torch.int8, torch.int16) + if promote_to_int: + a = prims.convert_element_type(a, torch.int32) + b = prims.convert_element_type(b, torch.int32) + + g = torch.gcd(a, b) + # Avoid division by zero in case gcd(0, 0) == 0 + g = torch.where(g == 0, 1, g) + res = torch.abs(prims.div(a, g) * b) + return res if not promote_to_int else prims.convert_element_type(res, dtype) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, + 
supports_lhs_python_scalar=False, +) +def le(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.le(a, b) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + supports_lhs_python_scalar=False, + supports_rhs_python_scalar=False, +) +def logaddexp(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + # Nb. this implementation does not distribute the gradients evenly when a == b + mask = torch.real(a) >= torch.real(b) + max_ = torch.where(mask, a, b) + min_ = torch.where(mask, b, a) + inf_mask = torch.logical_and( + torch.logical_not(torch.isfinite(torch.real(a))), torch.real(a) == torch.real(b) + ) + if utils.is_complex_dtype(a.dtype) or utils.is_complex_dtype(b.dtype): + # are you wondering what this bunch of codes are for? edge cases! + neg_min_mask = torch.real(min_) < 0 + inf_vals = torch.where( + neg_min_mask, min_, torch.log(torch.exp(min_) + torch.exp(max_)) + ) + non_nan_vals = torch.where( + inf_mask, inf_vals, max_ + torch.log1p(torch.exp(min_ - max_)) + ) + # the type for full_like does not include tensor yet + nan_mask = torch.isnan(min_) + return torch.where(nan_mask, complex(float("nan"), float("nan")), non_nan_vals) # type: ignore[call-overload] + else: + return torch.where(inf_mask, a, max_ + torch.log1p(torch.exp(min_ - max_))) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + supports_lhs_python_scalar=False, + supports_rhs_python_scalar=False, +) +def logaddexp2(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + torch._check( + not (utils.is_complex_dtype(a.dtype) or utils.is_complex_dtype(b.dtype)), + lambda: "logaddexp2 doesn't support complex dtypes", + ) + # Nb. 
this implementation does not distribute the gradients evenly when a == b + mask = a >= b + max_ = torch.where(mask, a, b) + min_ = torch.where(mask, b, a) + inf_mask = torch.logical_and(torch.isinf(a), a == b) + inv_log_2 = 1.0 / math.log(2) + result = max_ + torch.log1p(torch.exp2(min_ - max_)) * inv_log_2 + return torch.where(inf_mask, a, result) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, +) +def logical_and(a: TensorLikeType, b: TensorLikeType): + if not utils.is_boolean_dtype(a.dtype): + a = a != 0 + if not utils.is_boolean_dtype(b.dtype): + b = b != 0 + return a & b + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL) +def logical_not(a: TensorLikeType): + if not utils.is_boolean_dtype(a.dtype): + return a == 0 + return ~a + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, +) +def logical_or(a: TensorLikeType, b: TensorLikeType): + if not utils.is_boolean_dtype(a.dtype): + a = a != 0 + if not utils.is_boolean_dtype(b.dtype): + b = b != 0 + return bitwise_or(a, b) + + +# TODO: skip unnecessary conversion of long to float +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, +) +def logical_xor(a: TensorLikeType, b: TensorLikeType): + if not utils.is_boolean_dtype(a.dtype): + a = a != 0 + if not utils.is_boolean_dtype(b.dtype): + b = b != 0 + return a ^ b + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, + supports_lhs_python_scalar=False, +) +def lt(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.lt(a, b) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def maximum(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.maximum(a, b) + + +@_make_elementwise_binary_reference( + 
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def minimum(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.minimum(a, b) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + supports_two_python_scalars=True, +) +def mul(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.mul(a, b) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, + supports_lhs_python_scalar=False, +) +def ne(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.ne(a, b) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH, + supports_lhs_python_scalar=False, + supports_rhs_python_scalar=False, +) +def nextafter(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.nextafter(a, b) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def remainder(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.remainder(a, b) + + +# reverse sub +@register_decomposition(aten.rsub) +@out_wrapper() +def rsub( + a: Union[TensorLikeType, NumberType], + b: Union[TensorLikeType, NumberType], + alpha: NumberType = 1, +): + if isinstance(a, Number): + msg = "Received a Number for the first argument, but expected a Tensor" + raise ValueError(msg) + + return torch.sub(b, a, alpha=alpha) + + +# TODO: consider refactoring this with add impl +# sub has its own implementation because it has an alpha argument +@register_decomposition(aten.sub) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a", "b"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def sub( + a: Union[TensorLikeType, NumberType], + b: Union[TensorLikeType, NumberType], + *, + alpha: NumberType = 1, +): + """ + Reference implementation of torch.sub + """ + + a, b = 
_maybe_broadcast(a, b) + + if isinstance(a, TensorLike) and isinstance(b, TensorLike): + torch._check( + not utils.is_boolean_dtype(a.dtype) and not utils.is_boolean_dtype(b.dtype), + lambda: ( + "Subtraction, the `-` operator, with two bool tensors is not supported. " + "Use the `^` or `logical_xor()` operator instead." + ), + ) + + if alpha != 1: + dtype = a.dtype if isinstance(a, TensorLike) else b.dtype # type: ignore[union-attr] + python_type = utils.dtype_to_type(dtype) + if not utils.is_weakly_lesser_type(type(alpha), python_type): + msg = f"alpha argument of type {type(alpha)} cannot be safely cast to type {python_type}!" + raise ValueError(msg) + if isinstance(b, torch.Tensor): + b = prims.mul(b, alpha) + else: + # Carefully not to use prims.mul if b is a scalar / symint. + # prims.mul always returns a tensor, + # which will mess with type promotion. + b = b * alpha + + output = prims.sub(a, b) + return handle_noncontiguous_outputs([a, b], output) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + name="true_divide", + aten_op=None, # CompositeImplicitAutograd + supports_two_python_scalars=True, +) +def true_divide(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.div(a, b) + + +@register_decomposition(aten.xlogy) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a", "b"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def xlogy(a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]): + torch._check( + isinstance(a, TensorLike) or isinstance(b, TensorLike), + lambda: 'Expected either argument a or b to be a Tensor"', + ) + + # Operations like eq and log do not handle scalar values, so we convert them to scalar_tensors. 
+ if isinstance(b, TensorLike) and isinstance(a, Number): + # pyrefly: ignore [bad-argument-type] + a = scalar_tensor(a, dtype=b.dtype, device=b.device) + elif isinstance(a, TensorLike) and isinstance(b, Number): + # pyrefly: ignore [bad-argument-type] + b = scalar_tensor(b, dtype=a.dtype, device=a.device) + + # mypy: expected "Tensor" + assert isinstance(a, TensorLike) + assert isinstance(b, TensorLike) + rhs = torch.where(torch.eq(a, 0), 0, torch.mul(a, torch.log(b))) + return torch.where(torch.isnan(b), float("nan"), rhs) + + +@_make_elementwise_binary_reference( + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + aten_op=None, # CompositeImplicitAutograd + supports_two_python_scalars=True, +) +def trunc_divide( + a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType] +): + dtype = utils.get_dtype(a) + if utils.is_integer_dtype(dtype): + return prims.div(a, b) + + return trunc(prims.div(a, b)) + + +# +# Elementwise Ternary References +# + + +@register_decomposition(aten.addcdiv) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("self", "tensor1", "tensor2"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def addcdiv( + self: TensorLikeType, + tensor1: TensorLikeType, + tensor2: TensorLikeType, + *, + value: NumberType = 1, +) -> TensorLikeType: + """ + Reference implementation of torch.addcdiv + """ + if value is not None: + dtype = self.dtype # no scalars allowed, see add + python_type = utils.dtype_to_type(dtype) + torch._check_value( + utils.is_weakly_lesser_type(type(value), python_type), + lambda: f"value argument of type {type(value)} cannot be safely cast to type {python_type}!", + ) + + return self + value * tensor1 / tensor2 + + +@register_decomposition(aten.addcmul) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("self", "tensor1", "tensor2"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def addcmul( + self: 
TensorLikeType, + tensor1: TensorLikeType, + tensor2: TensorLikeType, + *, + value: NumberType = 1, +) -> TensorLikeType: + """ + Reference implementation of torch.addcmul + """ + if value is not None: + dtype = self.dtype # no scalars allowed, see add + python_type = utils.dtype_to_type(dtype) + torch._check_value( + utils.is_weakly_lesser_type(type(value), python_type), + lambda: f"value argument of type {type(value)} cannot be safely cast to type {python_type}!", + ) + + return self + value * tensor1 * tensor2 + + +@register_decomposition(aten.clamp) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a", "min", "max"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def clamp( + a: TensorLikeType, + min: Optional[TensorOrNumberLikeType] = None, + max: Optional[TensorOrNumberLikeType] = None, +) -> TensorLikeType: + # NOTE: grad behavior with implementation `where` is not consistent on `nan` + if min is None and max is None: + msg = "clamp called but both min and max are none!" + raise ValueError(msg) + if min is not None: + a_isnan = torch.isnan(a) + condition = torch.bitwise_or(torch.ge(a, min), a_isnan) # type: ignore[arg-type] + # we should also propagate `nan` coming from boundaries. However, that's + # not necessary since `ge` would already `False` when either operands has + # a `nan`. 
So this line below is redundant + # `condition = bitwise_and(condition, bitwise_not(isnan(min)))` + a = torch.where(condition, a, min) # type: ignore[arg-type] + if max is not None: + a_isnan = torch.isnan(a) + # same as above, no need to adjust `nan` from `max` + condition = torch.bitwise_or(torch.le(a, max), a_isnan) # type: ignore[arg-type] + a = torch.where(condition, a, max) # type: ignore[arg-type] + + return a + + +@register_decomposition(aten.clamp_min) +@out_wrapper() +def clamp_min( + self: TensorLikeType, + min: Optional[TensorOrNumberLikeType] = None, +) -> TensorLikeType: + return torch.clamp(self, min=min) # type: ignore[arg-type] + + +@register_decomposition(aten.clamp_max) +@out_wrapper() +def clamp_max( + self: TensorLikeType, + max: Optional[TensorOrNumberLikeType] = None, +) -> TensorLikeType: + return torch.clamp(self, max=max) # type: ignore[arg-type] + + +# +# Conditional references +# + + +# https://pytorch.org/docs/stable/generated/torch.where.html +# TODO: implement where.default +@register_decomposition(aten.where.self) +@register_decomposition(aten.where.ScalarSelf) +@register_decomposition(aten.where.ScalarOther) +@register_decomposition(aten.where.Scalar) +@register_decomposition(aten.where.self_out) +@out_wrapper(exact_dtype=True) +@elementwise_type_promotion_wrapper( + type_promoting_args=("a", "b"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH, +) +def where( + pred: Tensor, + a: Optional[TensorOrNumberLikeType] = None, + b: Optional[TensorOrNumberLikeType] = None, +): + """ """ + + if a is None or b is None: + raise NotImplementedError + + utils.check_same_device(pred, a, b, allow_cpu_scalar_tensors=True) + torch._check( + pred.dtype is torch.bool, + lambda: f"expected predicate to be bool, got {pred.dtype}", + ) + + pred, a, b = _maybe_broadcast(pred, a, b) + return prims.where(pred, a, b) + + +# +# Data Movement References +# +@register_decomposition(aten.clone) +@out_wrapper() +def clone( + a: TensorLikeType, 
*, memory_format: torch.memory_format = torch.preserve_format +) -> TensorLikeType: + result = prims.clone(a, memory_format=memory_format) + return result + + +def copy_to(a: Tensor, b: Tensor, *, allow_cross_device=True): + if not allow_cross_device and a.device != b.device: + msg = f"Attempting to copy from device {b.device} to device {a.device}, but cross-device copies are not allowed!" + raise RuntimeError(msg) + + return prims.copy_to(a, b) + + +@register_decomposition(aten.item) +def item(a: TensorLikeType) -> NumberType: + if a.numel() != 1: + msg = f"Can't convert a tensor with {a.numel()} elements to a number!" + raise ValueError(msg) + + # NOTE: explicit conversion is necessary for bool! + # See https://github.com/pytorch/pytorch/issues/78071 + number_type = utils.dtype_to_type(a.dtype) + return number_type(prims.item(a)) + + +# fast path when `to` returns an alias to input. This mimics the same function in aten +def _to_will_alias( + a: TensorLikeType, + device: Optional[DeviceLikeType] = None, + dtype: Optional[torch.dtype] = None, + copy: Optional[bool] = None, + layout: Optional[torch.layout] = None, + memory_format: Optional[torch.memory_format] = None, + pin_memory: Optional[bool] = False, + non_blocking: bool = False, # not using non_blocking +) -> bool: + return ( + not copy + and (device is None or a.device == device) + and (dtype is None or a.dtype == dtype) + and (layout is None or a.layout == layout) + # is_pinned issue #84925 + # and (pin_memory is None or pin_memory == a.is_pinned()) + and ( + memory_format is None + or memory_format == torch.preserve_format + or utils.is_contiguous_for_memory_format(a, memory_format=memory_format) + ) + ) + + +@singledispatch +def _to_dispatch(*args, **kwargs): + raise NotImplementedError + + +@_to_dispatch.register +def _to_device( + device: torch.device, + dtype: torch.dtype, + non_blocking: bool = False, + copy: bool = False, + memory_format: Optional[torch.memory_format] = None, +) -> dict[str, Any]: + 
kwargs = { + "device": device, + "dtype": dtype, + "non_blocking": non_blocking, + "copy": copy, + "memory_format": memory_format, + } + return kwargs + + +@_to_dispatch.register +def _to_device_str( + device: str, + dtype: torch.dtype, + non_blocking: bool = False, + copy: bool = False, + memory_format: Optional[torch.memory_format] = None, +) -> dict[str, Any]: + kwargs = { + "device": torch.device(device), + "dtype": dtype, + "non_blocking": non_blocking, + "copy": copy, + "memory_format": memory_format, + } + return kwargs + + +@_to_dispatch.register +def _to_dtype( + dtype: torch.dtype, + non_blocking: bool = False, + copy: bool = False, + memory_format: Optional[torch.memory_format] = None, +) -> dict[str, Any]: + kwargs = { + "dtype": dtype, + "non_blocking": non_blocking, + "copy": copy, + "memory_format": memory_format, + } + return kwargs + + +@_to_dispatch.register +def _to_other( + other: Tensor, + non_blocking: bool = False, + copy: bool = False, + memory_format: Optional[torch.memory_format] = None, +) -> dict[str, Any]: + device = other.device + dtype = other.dtype + layout = other.layout + # is_pinned issue #84925 + # pin_memory = other.is_pinned() + kwargs = { + "device": device, + "dtype": dtype, + "layout": layout, + "non_blocking": non_blocking, + "copy": copy, + "memory_format": memory_format, + } + return kwargs + + +# remove to_kwargs that is already present in `a` +def _canonicalize_to_arguments(a: Tensor, to_kwargs: dict): + options_to_check = ["dtype", "device", "layout", "memory_format"] + # "device" option could be passed a str instead torch.device + if "device" in to_kwargs and isinstance(to_kwargs["device"], str): + to_kwargs["device"] = torch.device(to_kwargs["device"]) + + for kw in options_to_check: + if kw in to_kwargs: + if ( + (kw == "memory_format" and to_kwargs[kw] is torch.preserve_format) + or ( + kw == "device" + and to_kwargs[kw].type == a.device.type + and ( + not to_kwargs[kw].index or to_kwargs[kw].index == 
a.device.index + ) + ) + or ( + getattr(a, kw, None) == to_kwargs[kw] + ) # this also handles {"memory_format": None} + ): + to_kwargs.pop(kw) + + +def to(a: TensorLikeType, *args, **kwargs) -> TensorLikeType: + # handled dispatch via positional arguments + if len(args) != 0: + kwargs = _to_dispatch(*args, **kwargs) + + # TODO: is_pinned is not currently supported in refs or fake_tensor + # https://github.com/pytorch/pytorch/issues/84925 + assert "pin_memory" not in kwargs + _canonicalize_to_arguments(a, kwargs) + + if _to_will_alias(a, **kwargs): + return a + + copy = kwargs.pop("copy") if "copy" in kwargs else False + non_blocking = kwargs.pop("non_blocking") if "non_blocking" in kwargs else False + + # short-circuit to `prims.convert_element_type` when `to` is just a dtype change + if ( + (copy or (kwargs.get("dtype", a.dtype) != a.dtype)) + and (not non_blocking) + and ("memory_format" not in kwargs) + and ("device" not in kwargs) + and ("layout" not in kwargs) + # is_pinned issue #84925 + # and ("pin_memory" not in kwargs) + ): + return prims.convert_element_type(a, kwargs.get("dtype", a.dtype)) + + result = torch.empty_like(a, **kwargs) + # TODO: non_blocking should be handled by `copy_to` + copy_to(result, a) + return result + + +# +# Reduction references +# + + +def _reduction( + a: TensorLikeType, + prim: Callable, + *, + has_identity: bool = True, + accepts_dim_tuple: bool = True, # to handle min/argmin that accept single dim only + dims: Optional[DimsType] = None, + keepdims: bool = False, + dtype: Optional[torch.dtype] = None, # should be specified for ops that support it + out: Optional[Tensor] = None, + output_dtype_kind: REDUCTION_OUTPUT_TYPE_KIND, +) -> TensorLikeType: # it is usually SAME, but I want + # ref writers to actually think about what to put here + assert isinstance(a, TensorLike) + if a.ndim > 64: + raise RuntimeError( + f"Received a tensor with {a.ndim} dimensions, but only tensors with up to 64 dims are supported!" 
+ ) + + if out is not None: + assert isinstance(out, TensorLike) + if dtype is not None: + # TODO - this is true for eager mode currently, but it's wrong behavior for complex norms + if dtype != out.dtype: + raise RuntimeError( + "dtype argument and out dtype must match in reduction" + ) + if not accepts_dim_tuple: + assert dims is None or isinstance(dims, Dim) + if isinstance(dims, Dim): + dims = (dims,) # type: ignore[assignment] + dims = utils.reduction_dims(a.shape, dims) + if not has_identity: + from torch.fx.experimental.symbolic_shapes import sym_and + + valid_shape = a.ndim == 0 or sym_and(*(a.shape[i] > 0 for i in dims)) + torch._check( + valid_shape, + lambda: "reducing over zero-size dimension for reduction operation without identity", + ) + + computation_dtype, result_dtype = utils.reduction_dtypes( + a, output_dtype_kind, dtype + ) + a = _maybe_convert_to_dtype(a, computation_dtype) # type: ignore[method-assign] + result = prim(a, dims) + if keepdims: + output_shape = [a.shape[i] if i not in dims else 1 for i in range(a.ndim)] + broadcast_dims = [i for i in range(a.ndim) if i not in dims] + result = prims.broadcast_in_dim(result, output_shape, broadcast_dims) + + if out is not None: + assert result_dtype is not None + if dtype is not None and result_dtype != out.dtype: + raise RuntimeError( + "Expected the dtype of reduction result and out to match" + ) + out = _maybe_resize_out(out, result.shape) + return _safe_copy_out(copy_from=result, copy_to=out) # type: ignore[arg-type] + + if result.dtype != result_dtype and result_dtype is not None: + result = prims.convert_element_type(result, result_dtype) + + return result + + +def _make_copy_from_view(fn, return_none_on_out_variant=False): + """ + Given a view function (e.g. torch.diagonal) generates its copy variant (e.g. 
torch.diagonal_copy) + """ + aten_fn = getattr(aten, fn.__name__) + annotations = getattr(fn, "__annotations__", {}) + # view ops should not change dtypes, this ensures that the decomp path has + # the same error checks as eager. + fn = out_wrapper(exact_dtype=True)(aten_fn) + + @wraps(fn) + def _fn(*args, out=None, **kwargs): + result = fn(*args, out=out, **kwargs) + if return_none_on_out_variant and out is not None: + return None + if out is not None: + return result + + return pytree.tree_map( + lambda x: x.clone(memory_format=torch.contiguous_format), + result, + ) + + copy_name = f"{fn.__name__}_copy" + _fn.__name__ = copy_name + _fn.__annotations__.update(annotations) + register_decomposition(getattr(aten, copy_name))(_fn) + return _fn + + +@register_decomposition(aten.all) +@out_wrapper() +def all( + a: TensorLikeType, + dim: Optional[DimsType] = None, + keepdim: bool = False, +) -> TensorLikeType: + result = torch.logical_not(torch.any(torch.logical_not(a), dim, keepdim=keepdim)) + + if a.dtype == torch.uint8: + result = result.to(dtype=torch.uint8) + + return result + + +@register_decomposition(aten.any) +@out_wrapper() +def any( + a: TensorLikeType, + dim: Optional[DimsType] = None, + keepdim: bool = False, +) -> TensorLikeType: + a_ = _maybe_convert_to_dtype(a, torch.bool) + if isinstance(dim, (list, tuple)) and len(dim) == 0: + result = a_.clone() + else: + result = a_.sum(dim=dim, keepdim=keepdim).ne(False) + + # Preserves uint8 -- probably a legacy mask thing + if a.dtype is torch.uint8: + return prims.convert_element_type(result, torch.uint8) + + return result + + +@register_decomposition([aten.sum.dim_IntList, aten.sum.IntList_out]) +def sum( + a: TensorLikeType, + dim: Union[Optional[int], Optional[list[int]]] = None, + keepdim: bool = False, + *, + dtype: Optional[torch.dtype] = None, + out: Optional[Tensor] = None, +) -> TensorLikeType: + if dtype is None: + if out is not None: + dtype = out.dtype + elif utils.is_boolean_dtype(a.dtype) or 
utils.is_integer_dtype(a.dtype): + dtype = torch.int64 + else: + dtype = a.dtype + # reduces over all dimensions if dim=() is passed + if dim == () or dim == []: + dim = None + return _reduction( + a, + prims.sum, + dims=dim, + keepdims=keepdim, + dtype=dtype, + out=out, + output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.SAME, + ) + + +def sum_to_size( + a: Tensor, + *shape, +) -> Tensor: + shape = utils.extract_shape_from_varargs(shape, validate=False) + torch._check( + utils.is_expandable_to(shape, a.shape), + lambda: f'sum_to_size: size "{shape}" is not expandable to size "{a.shape}"', + ) + # In ATen scalar tensors are sent through sum and the result is returned as + # type promoted + if utils.is_same_shape(shape, a.shape) and len(shape) > 0: + return prims.view_of(a) + leading_dims = a.ndim - len(shape) + reduce_dims = tuple(range(leading_dims)) + tuple( + i + for i in range(leading_dims, len(shape)) + if shape[i - leading_dims] == 1 and a.shape[i] != 1 + ) + return torch.sum(a, dim=reduce_dims, keepdim=True, dtype=None) + + +@register_decomposition(aten.prod) +def prod( + a: TensorLikeType, + dim: Union[Optional[int], Optional[list[int]]] = None, + keepdim: bool = False, + *, + dtype=None, + out: Optional[Tensor] = None, +) -> TensorLikeType: + if dtype is None: + if out is not None: + dtype = out.dtype + elif utils.is_boolean_dtype(a.dtype) or utils.is_integer_dtype(a.dtype): + dtype = torch.int64 + else: + dtype = a.dtype + # reduces over all dimensions if dim=() is passed + if dim == () or dim == []: + dim = None + return _reduction( + a, + prims.prod, + dims=dim, + keepdims=keepdim, + dtype=dtype, + out=out, + output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.SAME, + ) + + +@register_decomposition(aten.amin) +def amin( + a: TensorLikeType, + dim: Optional[DimsType] = None, + keepdim: bool = False, + *, + out: Optional[Tensor] = None, +) -> TensorLikeType: + # reduces over all dimensions if dim=() is passed + if dim == () or dim == []: + dim = None + + return 
_reduction( + a, + prims.amin, + dims=dim, + keepdims=keepdim, + dtype=None, + out=out, + has_identity=False, + output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.SAME, + ) + + +@register_decomposition(aten.amax) +def amax( + a: TensorLikeType, + dim: Optional[DimsType] = None, + keepdim: bool = False, + *, + out: Optional[Tensor] = None, +) -> TensorLikeType: + # reduces over all dimensions if dim=() is passed + if dim == () or dim == []: + dim = None + + return _reduction( + a, + prims.amax, + dims=dim, + keepdims=keepdim, + dtype=None, + out=out, + has_identity=False, + output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.SAME, + ) + + +def _dim_var_dispatch(dim=None, unbiased=None): + # There's the following overload of torch.var: + # var(Tensor self, bool unbiased=True) -> (Tensor, Tensor) + # We need to explicitly convert bool dims to unbiased arg + if unbiased is None and isinstance(dim, bool): + unbiased = dim + dim = None + return dim, unbiased + + +@register_decomposition(aten.var) +@out_wrapper() +def var( + a: TensorLikeType, + dim: Optional[DimsType] = None, + unbiased: Optional[bool] = None, + keepdim: bool = False, + *, + correction: Optional[NumberType] = None, +) -> TensorLikeType: + dim, unbiased = _dim_var_dispatch(dim, unbiased) + correction = utils.set_correction(unbiased, correction) + # reduces over all dimensions if dim=() is passed + if dim == () or dim == []: + dim = None + + result = _reduction( + a, + partial(prims.var, correction=correction), + dims=dim, + keepdims=keepdim, + dtype=None, + out=None, + has_identity=True, + output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT, + ) + return result + + +@register_decomposition(aten.std) +@out_wrapper() +def std( + a: TensorLikeType, + dim: Union[Optional[int], Optional[list[int]]] = None, + unbiased: Optional[bool] = None, + keepdim: bool = False, + *, + correction: Optional[NumberType] = None, +) -> TensorLikeType: + dim, unbiased = _dim_var_dispatch(dim, unbiased) + correction = 
utils.set_correction(unbiased, correction) + + opmath_dtype, dtype = utils.reduction_dtypes( + a, REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT + ) + a = _maybe_convert_to_dtype(a, opmath_dtype) + a_var = torch.var(a, dim, correction=correction, keepdim=keepdim) + a_std = torch.sqrt(a_var) + assert dtype is not None + return _maybe_convert_to_dtype(a_std, dtype) + + +@register_decomposition(aten.mean) +def mean( + a: TensorLikeType, + dim: Optional[DimsType] = None, + keepdim: bool = False, + *, + dtype=None, + out=None, +) -> TensorLikeType: + # reduces over all dimensions if dim=() is passed + if dim == () or dim == []: + dim = None + orig_dtype = dtype + if dtype is None: + dtype = a.dtype + result = _reduction( + a, + prims.sum, + dims=dim, + keepdims=keepdim, + dtype=dtype, + out=None, + output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.KEEP_PROMOTED_TYPE, + ) + torch._check( + utils.is_float_dtype(dtype) or utils.is_complex_dtype(dtype), + lambda: ( + f"mean(): could not infer output dtype. " + f"{'Input' if orig_dtype is None else 'Optional'} dtype must be either " + f"a floating point or complex dtype. 
Got: {dtype}" + ), + ) + if isinstance(dim, Dim): + dim = (dim,) # type: ignore[assignment] + dims = utils.reduction_dims(a.shape, dim) # type: ignore[arg-type] + nelem = 1 if a.ndim == 0 else reduce(operator.mul, (a.shape[i] for i in dims), 1) + result = true_divide(result, nelem) + result_dtype = a.dtype if dtype is None else dtype + result = _maybe_convert_to_dtype(result, result_dtype) # type: ignore[method-assign] + if out is not None: + assert isinstance(out, TensorLike) + out = _maybe_resize_out(out, result.shape) + return _safe_copy_out(copy_from=result, copy_to=out) # type: ignore[arg-type] + return result + + +@register_decomposition(aten.std_mean) +@out_wrapper("out0", "out1") +def std_mean( + a: TensorLikeType, + dim: Optional[DimsType] = None, + *, + unbiased: Optional[bool] = None, + keepdim: bool = False, + correction: Optional[NumberType] = None, +): + dim, unbiased = _dim_var_dispatch(dim, unbiased) + correction = utils.set_correction(unbiased, correction) + opmath_dtype, dtype = utils.reduction_dtypes( + a, REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT + ) + original_dtype = a.dtype + a = _maybe_convert_to_dtype(a, opmath_dtype) + a_var, a_mean = torch.var_mean(a, dim, correction=correction, keepdim=keepdim) + a_std = torch.sqrt(a_var) + assert dtype is not None + return ( + _maybe_convert_to_dtype(a_std, dtype), + _maybe_convert_to_dtype(a_mean, original_dtype), + ) + + +@register_decomposition(aten.var_mean) +@out_wrapper("out0", "out1") +def var_mean( + a: TensorLikeType, + dim: Optional[DimsType] = None, + unbiased: Optional[bool] = None, + keepdim: bool = False, + *, + correction: Optional[NumberType] = None, +): + dim, unbiased = _dim_var_dispatch(dim, unbiased) + v = var(a, dim, unbiased, keepdim, correction=correction) + m = mean(a, dim, keepdim) + return v, m + + +@register_decomposition(aten.addr) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("self", "vec1", "vec2"), + 
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def addr( + self: TensorLikeType, + vec1: TensorLikeType, + vec2: TensorLikeType, + *, + beta: NumberType = 1, + alpha: NumberType = 1, +) -> TensorLikeType: + torch._check( + vec1.ndim == 1, + lambda: f"addr: Expected 1-D argument vec1, but got {vec1.ndim}-D", + ) + torch._check( + vec2.ndim == 1, + lambda: f"addr: Expected 1-D argument vec2, but got {vec2.ndim}-D", + ) + for arg, arg_name in ((alpha, "alpha"), (beta, "beta")): + if isinstance(arg, bool): + torch._check( + utils.is_boolean_dtype(self.dtype) + and utils.is_boolean_dtype(vec1.dtype) + and utils.is_boolean_dtype(vec2.dtype), + lambda: f"Boolean {arg_name} only supported for Boolean results.", + ) + self = self.expand(vec1.shape[0], vec2.shape[0]) + if utils.is_boolean_dtype(self.dtype): + # Integers are accepted for booleans + torch._check( + is_weakly_lesser_type(type(beta), int), + lambda: f"expected bool/int beta but got {type(beta)}", + ) + torch._check( + is_weakly_lesser_type(type(alpha), int), + lambda: f"expected bool/int alpha but got {type(beta)}", + ) + if not beta: + return torch.outer(vec1, vec2) if alpha else torch.full_like(self, False) + else: + return torch.logical_or( + self, + torch.outer(vec1, vec2) if alpha else torch.full_like(self, False), + ) + else: + torch._check( + is_weakly_lesser_type(type(beta), dtype_to_type(self.dtype)), + lambda: f"cannot safely convert {type(beta)} to {self.dtype}", + ) + torch._check( + is_weakly_lesser_type(type(alpha), dtype_to_type(self.dtype)), + lambda: f"cannot safely convert {type(alpha)} to {self.dtype}", + ) + if beta == 0: + # This means NaNs from self are dropped if beta is zero + return alpha * torch.outer(vec1, vec2) + else: + return beta * self + alpha * torch.outer(vec1, vec2) + + +# CompositeImplicitAutograd - don't register decomp +def atleast_1d( + arg: Union[TensorLikeType, Sequence[TensorLikeType]], *args: TensorLikeType +) -> Union[TensorLikeType, 
tuple[TensorLikeType, ...]]: + """Reference implementation of :func:`torch.atleast_1d`.""" + if not args and isinstance(arg, collections.abc.Sequence): + args_ = arg + else: + assert not isinstance(arg, collections.abc.Sequence) + args_ = (arg,) + args + res = tuple(a if a.ndim >= 1 else unsqueeze(a, 0) for a in args_) + return res if len(res) > 1 else res[0] + + +# Helper function with assert to avoid MyPy error +# of incompatible type passed to unsqueeze +def _unsqueeze_atleast( + at_least_fn: Callable, dim: int, arg: TensorLikeType +) -> TensorLikeType: + arg_ = at_least_fn(arg) + assert isinstance(arg_, TensorLike) + return unsqueeze(arg_, dim) + + +# CompositeImplicitAutograd - don't register decomp +def atleast_2d( + arg: Union[TensorLikeType, Sequence[TensorLikeType]], *args: TensorLikeType +) -> Union[TensorLikeType, tuple[TensorLikeType, ...]]: + """Reference implementation of :func:`torch.atleast_2d`.""" + if not args and isinstance(arg, collections.abc.Sequence): + args_ = arg + else: + assert not isinstance(arg, collections.abc.Sequence) + args_ = (arg,) + args + unsqueeze_atleast_1d = partial(_unsqueeze_atleast, atleast_1d, 0) + res = tuple(a if a.ndim >= 2 else unsqueeze_atleast_1d(a) for a in args_) + return res if len(res) > 1 else res[0] + + +# CompositeImplicitAutograd - don't register decomp +def atleast_3d( + arg: Union[TensorLikeType, Sequence[TensorLikeType]], *args: TensorLikeType +) -> Union[TensorLikeType, tuple[TensorLikeType, ...]]: + """Reference implementation of :func:`torch.atleast_3d`.""" + if not args and isinstance(arg, collections.abc.Sequence): + args_ = arg + else: + assert not isinstance(arg, collections.abc.Sequence) + args_ = (arg,) + args + unsqueeze_atleast_2d = partial(_unsqueeze_atleast, atleast_2d, -1) + res = tuple(a if a.ndim >= 3 else unsqueeze_atleast_2d(a) for a in args_) + return res if len(res) > 1 else res[0] + + +def as_strided( + a: TensorLikeType, + size: ShapeType, + stride: StrideType, + storage_offset: 
Optional[int] = None, +) -> TensorLikeType: + storage_offset_int = ( + storage_offset if storage_offset is not None else a.storage_offset() + ) + return prims.as_strided(a, size, stride, storage_offset_int) + + +@register_decomposition(aten.as_strided_scatter) +@out_wrapper() +def as_strided_scatter( + input: TensorLikeType, + src: TensorLikeType, + size: ShapeType, + stride: StrideType, + storage_offset: Optional[int] = None, +) -> TensorLikeType: + storage_offset_int = 0 if storage_offset is None else storage_offset + return prims.as_strided_scatter(input, src, size, stride, storage_offset_int) + + +def broadcast_shapes(*shapes) -> ShapeType: + return torch.Size(_broadcast_shapes(*shapes)) + + +@aten.broadcast_tensors.default.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten.broadcast_tensors.default.py_impl(DispatchKey.Meta) +def broadcast_tensors(*tensors) -> list[TensorLikeType]: + if len(tensors) == 1 and not isinstance(tensors[0], Tensor): + tensors = tensors[0] + return list(_maybe_broadcast(*tensors, preserve_cpu_scalar_tensors=False)) + + +# CompositeImplicitAutograd - don't register decomp +def broadcast_to(a: TensorLikeType, size: ShapeType) -> TensorLikeType: + start = len(size) - len(a.shape) + dims = tuple(range(start, len(a.shape) + start)) + return prims.broadcast_in_dim(a, size, dims) + + +@register_decomposition(aten.cat) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("tensors",), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH, +) +def cat(tensors: TensorSequenceType, dim: int = 0) -> TensorLikeType: + def cat_compute_output_memory_format(inputs): + format = None + for t in inputs: + f = utils.suggest_memory_format(t) + if f == torch.contiguous_format: + return f + if format is not None and format != f: + return torch.contiguous_format + format = f + assert format is not None + return format + + if len(tensors) == 0: + msg = "cat expects at least one tensor, but received zero!" 
+ raise ValueError(msg) + + for tensor in tensors: + assert isinstance(tensor, TensorLike) + + utils.check_same_device(*tensors, allow_cpu_scalar_tensors=False) + + from torch.fx.experimental.symbolic_shapes import guard_or_false + + # This is a bit tricky. Naively, you would expect to just pick one + # arbitrary tensor and check that all tensors match this tensor. However, + # there is legacy behavior which says that if you have a 1-D empty tensor + # (0,), this is permissible. So you can't assume that all the tensors + # have same dimensionality, and you can't assume that the first tensor is + # the correct stencil. + # + # We'll implement this in a few passes. First, we will try to infer the + # ndim of the cat output. If this ndim != 1, then we know that all ndim = + # 1 inputs must be empty, or are errors. If this ndim == 1, then life + # is easy (the legacy special case coincides with regular handling). + # + # NB: The regular implementation of cat just filters out empty inputs, + # but we do it slightly different here for better handling for unbacked + # SymInts + + example = None + # pyrefly: ignore [bad-assignment] + for i, t in enumerate(tensors): + if example is None: + if t.ndim != 1: + example = t + else: + if t.ndim != 1: + torch._check( + t.ndim == example.ndim, + lambda: "Number of dimensions of tensors must match. " + f"Expected {example.ndim}-D tensors, but got {t.ndim}-D for " + f"tensor number {i} in the list", + ) + + if example is None: + # example is None if everything is 1-D. If so, just arbitrarily pick + # the first one + example = tensors[0] + + shape = example.shape + filtered = [] + for tensor_idx, tensor in enumerate(tensors): + if len(shape) != len(tensor.shape): + assert tensor.ndim == 1 # we've already checked this above + # Don't suggest the legacy behavior in the error message + torch._check( + # NB: it is not enough to simply assert that tensor.shape[0] == 0; + # this MUST be true even under guard size oblivious. 
+ # Effectively, we must actually know that the shape is zero, + # passing an unbacked SymInt which we will defer a runtime + # assert on won't cut it. This is a policy decision (size + # oblivious semantics say that u0 tensors never are inferred + # to be zero size, even if they must be that for the cat to go + # through), and is load bearing for our Inductor lowerings + # (which assume that size oblivious tests are OK to determine + # if a shape is permissibly zero.) + guard_or_false(tensor.shape[0] == 0), + lambda: f"Number of dimensions of tensors must match. " + f"Expected {example.ndim}-D tensors, but got 1-D for " + f"tensor number {tensor_idx} in the list", + ) + else: + # Remove inputs that are 1-D, zero size + if tensor.ndim == 1 and guard_or_false(tensor.shape[0] == 0): + continue + # Don't bother checking size match, prims.cat will handle it + filtered.append(tensor) + + memory_format = cat_compute_output_memory_format(tensors) + + if len(filtered) == 0: + t = tensors[0] + + # TODO: fix this to work with meta tensors + try: + # BUG? This looks like it wants to call builtins.any() but is + # actually calling .any() (in this file). 
Changing to builtins.any() + # causes tests to fail: + # PYTORCH_OPINFO_SAMPLE_INPUT_INDEX=4 python test/test_ops.py -k \ + # TestFakeTensorCUDA.test_fake_crossref_backward_amp_cat_cuda_float32 + requires_grad = bool(any(x.requires_grad for x in tensors)) # type: ignore[arg-type] + except Exception: + requires_grad = False # type: ignore[assignment] + + return empty( + (0,), + dtype=t.dtype, + device=t.device, + requires_grad=requires_grad, + memory_format=memory_format, + ) + + dim = utils.canonicalize_dim(filtered[0].ndim, dim) + utils.validate_idx(filtered[0].ndim, dim) + + return prims.cat(filtered, dim).clone(memory_format=memory_format) + + +# CompositeImplicitAutograd - don't register decomp +@out_wrapper() +def column_stack(tensors: TensorSequenceType) -> TensorLikeType: + aligned_tensors = tuple( + x if x.ndim > 1 else x.reshape((x.numel(), 1)) for x in tensors + ) + return cat(aligned_tensors, 1) + + +def conj(input: TensorLikeType) -> TensorLikeType: + if not utils.is_complex_dtype(input.dtype): + return input + if input.is_sparse: + return torch.conj_physical(input) + return prims.conj(input) + + +# This replicates at::constant_pad_nd, defined in ATen/native/PadNd.cpp +@register_decomposition(aten.constant_pad_nd) +@out_wrapper() +def constant_pad_nd( + input: TensorLikeType, pad: list[int], value: NumberType = 0 +) -> TensorLikeType: + torch._check( + len(pad) % 2 == 0, + lambda: f"Length of pad must be even but instead it equals {len(pad)}", + ) + + input_sizes = input.shape + l_inp = len(input_sizes) + + l_pad = len(pad) // 2 + l_diff = l_inp - l_pad + + torch._check( + l_inp >= l_pad, + lambda: "Length of pad should be no more than twice the number of " + f"dimensions of the input. 
Pad length is {len(pad)} while the input has " + f"{l_inp} dimensions.", + ) + + c_input = input + for i in range(l_diff, l_inp): + pad_idx = 2 * (l_inp - i - 1) + if pad[pad_idx] < 0: + c_input = c_input.narrow(i, -pad[pad_idx], c_input.shape[i] + pad[pad_idx]) + + if pad[pad_idx + 1] < 0: + c_input = c_input.narrow(i, 0, c_input.shape[i] + pad[pad_idx + 1]) + + # If all the pads are negative we can return the result. + # Avoid early exiting if all pads = 0 to prevent specialization on export. + # During export, raw if statements are specialized on the input, meaning + # that we lose a branch depending on the example input used to export. + # Here, this is either the case where all pads = 0, or the case where at + # least one pad > 0 and the rest are >= 0. + # Avoiding the early exit when all pads = 0 ensures we can export + # constant_pad_nd for cases when all pads >= 0. + # Note: if any pads are negative, this code specializes due to the if statements above. + if builtins.all(p < 0 for p in pad): + return c_input.clone() + + new_shape = list(input_sizes[:l_diff]) + + for i in range(l_pad): + pad_idx = len(pad) - ((i + 1) * 2) + new_dim = input_sizes[l_diff + i] + pad[pad_idx] + pad[pad_idx + 1] + torch._check( + new_dim >= 0, + lambda: f"The input size {input_sizes[l_diff + i]}, plus negative padding " + f"{pad[pad_idx]} and {pad[pad_idx + 1]} resulted in a negative output size, " + f"which is invalid. 
Check dimension {l_diff + i} of your input.", + ) + new_shape.append(new_dim) + + memory_format = utils.suggest_memory_format(input) + output = torch.empty( + new_shape, + dtype=input.dtype, + device=input.device, + requires_grad=input.requires_grad, + memory_format=memory_format, + ) + + if value == 0 and input.dtype == torch.bool: + value = False + # torch.fill isn't typed to allow complex values + output = torch.fill(output, value) # type: ignore[arg-type] + + c_output = output + for i in range(l_diff, l_inp): + pad_idx = 2 * (l_inp - i - 1) + if pad[pad_idx] >= 0: + c_output = c_output.narrow( + i, pad[pad_idx], c_output.shape[i] - pad[pad_idx] + ) + if pad[pad_idx + 1] >= 0: + c_output = c_output.narrow(i, 0, c_output.shape[i] - pad[pad_idx + 1]) + + prims.copy_to(c_output, c_input) + return output + + +def contiguous( + a: Tensor, *, memory_format: torch.memory_format = torch.contiguous_format +) -> Tensor: + torch._check( + memory_format != torch.preserve_format, + lambda: "preserve memory format is unsupported by the contiguous operator", + ) + + # TODO: make logic consistent with aten contiguous + if is_contiguous_for_memory_format_or_false(a, memory_format=memory_format): + return a + + return torch.clone(a, memory_format=memory_format) + + +@out_wrapper() +def dstack(tensors: TensorSequenceType) -> TensorLikeType: + torch._check(len(tensors) > 0, lambda: "dstack expects a non-empty TensorList") + aligned_tensors = atleast_3d(*tensors) + return cat(aligned_tensors, 2) + + +@register_decomposition(aten.expand) +def expand(a: Tensor, *shape, implicit: bool = False) -> Tensor: + from torch.fx.experimental.symbolic_shapes import guard_or_false, size_hint, sym_or + + backed_so = torch.fx.experimental._config.backed_size_oblivious + + # NOTE: cannot use utils.extract_shape_from_varargs here + # because that also validates the shape, but the shape + # given to expand may be "invalid" + if len(shape) == 1 and isinstance(shape[0], Sequence): + shape = 
tuple(shape[0]) + + torch._check( + len(shape) >= len(a.shape), + lambda: "expand: the requested shape has too few dimensions!", + ) + + offset = len(shape) - len(a.shape) + shape_ = list(shape) + for idx, x in enumerate(a.shape): + offset_idx = idx + offset + requested_length = shape[offset_idx] + + # expand(in -> out) has 3 different semantics: + # 1) out == -1 -> size = in, stride unchanged + # 2) in == 1 -> size = out, stride = 0 + # 3) in == out -> size = in, stride unchanged + # + # the code below is written for unbacked semantics s.t. we assume unbacked symbols don't + # represent -1 unless explicitly specified, and the user is opting for case 2) or 3). + # the sym_or allows either case, but in the decomposition's current state, broadcast_in_dim() + # will either assume case 3) (via validate_shape() marking the expanded shape size-like), or will + # raise a data-dependent error trying to figure out if the stride is 0, requiring the user to manually + # select between the semantics of cases 2) and 3). + if guard_or_false(requested_length == -1): + shape_[offset_idx] = x + else: + # When backed size oblivious is used, we specialize for broadcasting + # if its the only way to compile the example input. + # i.e: x:1, requested_length:1 ==> + # assert x==requested_length, no specialization on ==1 or !=1. + # The non-broadcast path is picked + # x:1, requested_length:4 ==> + # specialize(x) to be 1. 
+ if backed_so: + x_hint = size_hint(x, allow_none=True) + requested_hint = size_hint(requested_length, allow_none=True) + if x_hint == 1 and requested_hint != 1: + torch._check(x == 1) + + torch._check( + sym_or(x == 1, requested_length == x), + lambda: f"expand: attempting to expand a dimension of length {x} -> {requested_length}!", + ) + torch._check(requested_length >= 0) + shape_[offset_idx] = requested_length + + # At this point shape must be valid + utils.validate_shape(shape_) + + return prims.broadcast_in_dim( + a, shape_, tuple(range(offset, len(a.shape) + offset)) + ) + + +# CompositeImplicitAutograd - don't register decomp +def expand_as(a: Tensor, b: Tensor) -> Tensor: + return a.expand(b.shape) + + +def chunk(a: TensorLikeType, chunks: int, dim: int = 0) -> tuple[TensorLikeType, ...]: + if chunks <= 0: + msg = f"Expected at least one chunk, but got {chunks}!" + raise ValueError(msg) + + dim = utils.canonicalize_dim(a.ndim, dim) + length = a.shape[dim] + chunk_size = math.ceil(length / chunks) + full_chunks = math.floor(length / chunk_size) + tail_chunk_size = length % chunk_size + + result = [narrow(a, dim, i * chunk_size, chunk_size) for i in range(full_chunks)] + + if tail_chunk_size != 0: + result.append(narrow(a, dim, full_chunks * chunk_size, tail_chunk_size)) + + return tuple(result) + + +# Note: flatten, unlike other shape operators, returns the input tensor on a no-op (unless +# a 0D tensor is flattened, in which case it's returned in 1D) +# CompositeImplicitAutograd - don't register decomp +def flatten(a: TensorLikeType, start_dim: int = 0, end_dim: int = -1) -> TensorLikeType: + start_dim = utils.canonicalize_dim(a.ndim, start_dim) + end_dim = utils.canonicalize_dim(a.ndim, end_dim) + + # Short-circuits on no-op + if start_dim == end_dim and a.ndim != 0: + return a + + # Tries to take a view + # TODO: we could look at directing collapse_view to skip its meta function here (unsafe_collapse_view) + # Unbacked semantics: if validity of in-place 
flattening is undecided we copy. + new_shape, _new_strides = prims._collapse_view_helper( + a, start_dim, end_dim, must_be_valid=None + ) + if new_shape is not None: + return prims.collapse_view(a, start_dim, end_dim) + + # Makes a copy if it can't make a view + return prims.collapse(a, start_dim, end_dim) + + +@register_decomposition(aten.flip) +@out_wrapper() +def flip(a: TensorLikeType, dims: DimsSequenceType) -> TensorLikeType: + if not isinstance(dims, tuple) and not isinstance(dims, list): + raise ValueError("dims has to be a sequence of ints") + dims = utils.canonicalize_dims(a.ndim, dims) # type: ignore[assignment] + utils.validate_no_repeating_dims(dims) + return prims.rev(a, dims) + + +# CompositeImplicitAutograd - don't register decomp +def fliplr(a: TensorLikeType) -> TensorLikeType: + if a.ndim < 2: + raise RuntimeError("Input must be >= 2-d.") + + return flip(a, (1,)) + + +# CompositeImplicitAutograd - don't register decomp +def flipud(a: TensorLikeType) -> TensorLikeType: + if a.ndim < 1: + raise RuntimeError("Input must be >= 1-d.") + + return flip(a, (0,)) + + +# CompositeImplicitAutograd - don't register decomp +def narrow( + a: TensorLikeType, dim: int, start: Union[int, TensorLikeType], length: int +) -> TensorLikeType: + # Supports Tensor overload that was added for XLA: + # https://github.com/pytorch/pytorch/issues/31558 + if isinstance(start, TensorLike): + torch._check( + start.dim() == 0 and utils.is_integer_dtype(start.dtype), + lambda: "start must be an 0-dim integral Tensor.", + ) + start = start.item() # type: ignore[assignment] + start = cast(int, start) + torch._check(a.dim() > 0, lambda: "narrow() cannot be applied to a 0-dim tensor.") + torch._check(length >= 0, lambda: "narrow(): length must be non-negative.") + dim = utils.canonicalize_dim(a.ndim, dim) + dim_length = a.size(dim) + torch._check_with( + IndexError, + -dim_length <= start and start <= dim_length, + lambda: f"start out of range (expected to be in range of 
[{-dim_length}, {dim_length}], but got {start})", + ) + if start < 0: + start = start + dim_length + torch._check( + start <= dim_length - length, + lambda: f"start ({start}) + length ({length}) exceeds dimension size ({dim_length}).", + ) + new_shape = list(a.shape) + new_shape[dim] = length + return a.as_strided( + new_shape, a.stride(), a.storage_offset() + a.stride(dim) * start + ) + + +def _normalize( + a: Tensor, norm_dims: DimsType, eps: float +) -> tuple[Tensor, Tensor, Tensor]: + """Computes mean and 1/std of a tensor along norm_dims. + + Used as a helper function for normalization layers. + + Args: + a (Tensor): input tensor + norm_dims (DimsType): dimensions to normalize over + eps (float): epsilon for numerical stability + + Returns: + out (Tensor): normalized tensor. + mean (Tensor): mean of the tensor along norm_dims. + rstd (Tensor): 1/std of the tensor along norm_dims. + """ + + norm_dims = utils.canonicalize_dims(a.ndim, norm_dims) + computation_dtype = utils.get_computation_dtype(a.dtype) + a_acc = _maybe_convert_to_dtype(a, computation_dtype) + assert isinstance(a_acc, TensorLike) # to avoid mypy error for var_mean + biased_var, mean = torch.var_mean( + a_acc, dim=norm_dims, unbiased=False, keepdim=True + ) + rstd = torch.rsqrt(biased_var + eps) + out = (a_acc - mean) * rstd + return out, mean, rstd + + +# add all specified dimensions +def _unsqueeze_multiple(x: TensorLikeType, dimensions: list[int]) -> TensorLikeType: + for dim in sorted(dimensions): + x = torch.unsqueeze(x, dim) + return x + + +@register_decomposition(aten.native_group_norm.default) +def native_group_norm( + input: Tensor, + weight: Optional[Tensor], + bias: Optional[Tensor], + batch_size: int, + num_channels: int, + flattened_inner_size: int, + num_groups: int, + eps: float, +) -> tuple[Tensor, Tensor, Tensor]: + torch._check( + input.ndim >= 2, + lambda: f"Expected at least 2 dimensions for input tensor but received {input.ndim}", + ) + torch._check( + num_channels % 
num_groups == 0, + lambda: "Expected number of channels in input to be divisible by num_groups, " + + f"but got input of shape {input.shape} and num_groups = {num_groups}", + ) + + computation_dtype = utils.get_computation_dtype(input.dtype) + input_acc = _maybe_convert_to_dtype(input, computation_dtype) + # num_channels / num_groups and flattened inner dimension are the reduction axes + reduction_dims = [2, 3] + input_reshaped = torch.reshape( + input_acc, + [batch_size, num_groups, num_channels // num_groups, flattened_inner_size], + ) + reduction_dims = utils.canonicalize_dims(input_reshaped.ndim, reduction_dims) + biased_var, mean = torch.var_mean( + input_reshaped, dim=reduction_dims, unbiased=False, keepdim=True + ) + rstd = torch.rsqrt(biased_var + eps) + if input.device.type == "cpu" and weight is not None: + weight_reshaped = torch.reshape( + weight, [1, num_groups, num_channels // num_groups, 1] + ) + w = rstd * weight_reshaped + b = -mean * w + if bias is not None: + bias_reshaped = torch.reshape( + bias, [1, num_groups, num_channels // num_groups, 1] + ) + b = b + bias_reshaped + w = w.contiguous().as_strided([batch_size, num_channels], [num_channels, 1]) + b = b.contiguous().as_strided([batch_size, num_channels], [num_channels, 1]) + broadcast_dims = list(range(2, input.ndim)) + unsqueeze_w = _unsqueeze_multiple(w, broadcast_dims) + unsqueeze_b = _unsqueeze_multiple(b, broadcast_dims) + out = input_acc * unsqueeze_w + unsqueeze_b + else: + out = (input_reshaped - mean) * rstd + out = out.view(input.shape) + broadcast_dims = [0] + list(range(2, input.ndim)) + if weight is not None: + unsqueeze_weight = _unsqueeze_multiple(weight, broadcast_dims) + out = out * unsqueeze_weight + if bias is not None: + unsqueeze_bias = _unsqueeze_multiple(bias, broadcast_dims) + out = out + unsqueeze_bias + + out = _maybe_convert_to_dtype(out, input.dtype) # type: ignore[assignment] + mean = _maybe_convert_to_dtype(mean, input.dtype) # type: ignore[assignment] + rstd = 
_maybe_convert_to_dtype(rstd, input.dtype) # type: ignore[assignment] + + # remove broadcast dimensions from mean and rstd + mean = torch.squeeze(mean, reduction_dims) + rstd = torch.squeeze(rstd, reduction_dims) + return (out, mean, rstd) + + +@register_decomposition(aten.native_layer_norm) +@out_wrapper("out0", "out1", "out2") +def native_layer_norm( + input: Tensor, + normalized_shape: ShapeType, + weight: Optional[Tensor], + bias: Optional[Tensor], + eps: float, +) -> tuple[Tensor, Tensor, Tensor]: + from torch.fx.experimental.symbolic_shapes import sym_eq + + normalized_ndim = len(normalized_shape) + torch._check( + normalized_ndim >= 1, + lambda: "Expected normalized_shape to be at least 1-dimensional, i.e., " + + "containing at least one element, but got normalized_shape = " + + str(normalized_shape), + ) + # torch.Size([1, 2, 3]) == [1, 2, 3] evaluates to False + # while torch.Size([1, 2, 3]) == (1, 2, 3) is True + # therefore we use tuple(normalized_shape) + torch._check( + # pyrefly: ignore [bad-argument-type] + weight is None or sym_eq(weight.shape, tuple(normalized_shape)), + lambda: "Expected weight to be of same shape as normalized_shape, but got " + + "weight of shape " + + str(weight.shape) # type: ignore[union-attr] + + " and normalized_shape = " + + str(normalized_shape), + ) + torch._check( + # pyrefly: ignore [bad-argument-type] + bias is None or sym_eq(bias.shape, tuple(normalized_shape)), + lambda: "Expected bias to be of same shape as normalized_shape, but got " + + "bias of shape " + + str(bias.shape) # type: ignore[union-attr] + + " and normalized_shape = " + + str(normalized_shape), + ) + torch._check( + input.ndim >= normalized_ndim + and sym_eq( + input.shape[(input.ndim - normalized_ndim) :], + # pyrefly: ignore [bad-argument-type] + tuple(normalized_shape), + ), + lambda: "Given normalized_shape=" + + str(normalized_shape) + + ", expected input with shape " + + str(normalized_shape) + + ", but got input of size " + + str(input.shape), 
+ ) + + input = contiguous(input) + if weight is not None: + weight = contiguous(weight) + if bias is not None: + bias = contiguous(bias) + + axis = input.ndim - normalized_ndim + reduction_dims = list(range(axis, input.ndim)) + out, mean, rstd = _normalize(input, reduction_dims, eps) + + if weight is None and bias is not None: + out = out + bias + elif weight is not None and bias is None: + out = out * weight + elif weight is not None and bias is not None: + out = out * weight + bias + + out = _maybe_convert_to_dtype(out, input.dtype) # type: ignore[assignment] + if input.device.type in ["cpu", "mtia"]: + mean = _maybe_convert_to_dtype(mean, input.dtype) # type: ignore[assignment] + rstd = _maybe_convert_to_dtype(rstd, input.dtype) # type: ignore[assignment] + return (out, mean, rstd) + + +@torch._subclasses.fake_impls.register_op_impl(aten.native_layer_norm.default) +def native_layer_norm_fake(fake_mode, func, *args, **kwargs): + return native_layer_norm(*args) + + +# TODO: Adding this as a meta function causes functorch tests to fail when compiled with debug mode. 
+# test/test_eager_transforms.py::TestFunctionalizeCPU::test_functionalize_fx_transpose_simple_cpu +@register_decomposition(aten.permute) +def permute(a: TensorLikeType, *dims) -> TensorLikeType: + _permutation = utils.canonicalize_dims( + a.ndim, utils.extract_dims_from_varargs(dims) + ) + return prims.transpose(a, _permutation) + + +@register_decomposition(aten.renorm) +@out_wrapper() +def renorm( + input: TensorLikeType, p: RealNumberType, dim: int, maxnorm: RealNumberType +) -> TensorLikeType: + torch._check(not isinstance(p, complex), lambda: "renorm: p must be real-valued") + torch._check(p > 0, lambda: "renorm: non-positive norm not supported") + torch._check( + not isinstance(maxnorm, complex), lambda: "renorm: maxnorm must be real-valued" + ) + torch._check( + maxnorm >= 0, lambda: f"renorm: expected maxnorm to be >= 0 but got {maxnorm}" + ) + ndim = input.ndim + torch._check( + ndim > 1, + lambda: f"renorm: input needs at least 2 dimensions, got {ndim} dimensions", + ) + + dim = utils.canonicalize_dim(ndim, dim) + reduce_dims = list(range(ndim)) + del reduce_dims[dim] + + # For half and bfloat16, calculate norm in float precision then cast + # normalization factor to half + acc_type = utils.get_computation_dtype(input.dtype) + if acc_type != input.dtype: + norm = torch.linalg.vector_norm( + input, p, reduce_dims, keepdim=True, dtype=acc_type + ) + else: + norm = torch.linalg.vector_norm(input, p, reduce_dims, keepdim=True) + + eps = 1e-7 + norm_factor = torch.where(norm > maxnorm, maxnorm / (norm + eps), 1.0) + if acc_type != input.dtype: + norm_factor = prims.convert_element_type(norm_factor, input.dtype) + return (input * norm_factor).contiguous() + + +# CompositeImplicitAutograd - don't register decomp +@aten.stft.center.py_impl(DispatchKey.CompositeImplicitAutograd) +def stft( + input: Tensor, + n_fft: int, + hop_length: Optional[int] = None, + win_length: Optional[int] = None, + window: Optional[Tensor] = None, + center: bool = True, + pad_mode: str 
= "reflect", + normalized: bool = False, + onesided: Optional[bool] = None, + return_complex: Optional[bool] = None, + align_to_window: Optional[bool] = None, +) -> Tensor: + torch._check( + window is None or window.device == input.device, + lambda: ( + f"stft input and window must be on the same device but got self on {input.device}" + + f" and window on {window.device}" # type: ignore[union-attr] + ), + ) + torch._check( + not center or align_to_window is None, + lambda: "stft only supports align_to_window for center = False.", + ) + + hop_length_ = hop_length if hop_length is not None else n_fft // 4 + win_length_ = win_length if win_length is not None else n_fft + + if return_complex is None: + return_complex_ = input.is_complex() or ( + window is not None and utils.is_complex_dtype(window.dtype) + ) + torch._check( + return_complex_, + lambda: ( + "stft requires the return_complex parameter be given for real inputs, " + + "and will further require that return_complex=True in a future PyTorch release." 
+ ), + ) + else: + return_complex_ = return_complex + + torch._check( + utils.is_float_dtype(input.dtype) or utils.is_complex_dtype(input.dtype), + lambda: "stft expected a tensor of floating point or complex values", + ) + torch._check(1 <= input.ndim <= 2, lambda: "stft expected a 1D or 2D tensor") + + original_ndim = input.ndim + if original_ndim == 1: + input = input.unsqueeze(0) + + if center: + extra_dims = 3 - input.ndim + pad_amount = n_fft // 2 + extended_shape = [*itertools.repeat(1, extra_dims), *input.shape] + input = aten.pad(input.view(extended_shape), [pad_amount, pad_amount], pad_mode) + input = input.view(input.size()[extra_dims:]) + + length = input.size(1) + torch._check( + 0 < n_fft <= length, + lambda: f"stft expected 0 < n_fft <= {length}, but got n_fft={n_fft}", + ) + torch._check( + hop_length_ > 0, + lambda: f"stft expected hop_length > 0 but got hop_length={hop_length_}", + ) + torch._check( + 0 < win_length_ <= n_fft, + lambda: f"stft expected 0 < win_length <= n_fft but got win_length={win_length_}", + ) + torch._check( + window is None or window.shape == (win_length_,), + lambda: ( + f"expected a 1D window tensor of size equal to win_length={win_length_}, " + + f"but got window with size {window.shape}" # type: ignore[union-attr] + ), + ) + + if win_length_ < n_fft: + if window is None: + window = torch.ones(win_length_, dtype=input.dtype, device=input.device) + left = (n_fft - win_length_) // 2 + window = aten.constant_pad_nd(window, [left, n_fft - win_length_ - left]) + + if not center and align_to_window: + input_pad_amount = (n_fft - win_length_) // 2 + input = aten.pad(input, [input_pad_amount, input_pad_amount], pad_mode) + + input = input.unfold(dimension=-1, size=n_fft, step=hop_length_) + + if window is not None: + input = input * window + + complex_fft = utils.is_complex_dtype(input.dtype) + onesided = onesided if onesided is not None else not complex_fft + norm = "ortho" if normalized else None + if onesided: + torch._check( 
+ not complex_fft, + lambda: "Cannot have onesided output if window or input is complex", + ) + out = torch.fft.rfft(input, dim=-1, norm=norm) + else: + out = torch.fft.fft(input, dim=-1, norm=norm) + + out.transpose_(1, 2) + + if original_ndim == 1: + out = out.squeeze_(0) + + return out if return_complex_ else torch.view_as_real(out) + + +# CompositeImplicitAutograd - don't register decomp +@aten.istft.default.py_impl(DispatchKey.CompositeImplicitAutograd) +def istft( + input: Tensor, + n_fft: int, + hop_length: Optional[int] = None, + win_length: Optional[int] = None, + window: Optional[Tensor] = None, + center: bool = True, + normalized: bool = False, + onesided: Optional[bool] = None, + length: Optional[int] = None, + return_complex=False, +) -> Tensor: + torch._check( + window is None or window.device == input.device, + lambda: ( + f"istft input and window must be on the same device but got self on {input.device}" + + f" and window on {window.device}" # type: ignore[union-attr] + ), + ) + + hop_length_ = hop_length if hop_length is not None else n_fft // 4 + win_length_ = win_length if win_length is not None else n_fft + + torch._check( + utils.is_complex_dtype(input.dtype), + lambda: ( + "istft input and window must be on the same device but got self on " + + f"{input.device} and window on {window.device}" # type: ignore[union-attr] + ), + ) + n_frames = input.size(-1) + fft_size = input.size(-2) + + expected_output_signal_len = n_fft + hop_length_ * (n_frames - 1) + torch._check(input.numel() > 0, lambda: "istft input tensor cannot be empty") + torch._check( + 2 <= input.ndim <= 3, + lambda: f"istft expected a tensor with 2 or 3 dimensions, but got {input.ndim}", + ) + onesided_ = onesided if onesided is not None else fft_size != n_fft + + if onesided_: + torch._check( + n_fft // 2 + 1 == fft_size, + lambda: ( + "istft expected the frequency dimension (3rd to the last) of the input tensor " + + f"to match n_fft / 2 + 1 when onesided=True, but got 
{fft_size}" + ), + ) + else: + torch._check( + n_fft == fft_size, + lambda: ( + "istft expected the frequency dimension (3rd to the last) of the input tensor " + + f"to match n_fft when onesided=False, but got {fft_size}", + ), + ) + + torch._check( + 0 < hop_length_ <= win_length_, + lambda: "istft expected 0 < hop_length <= win_length", + ) + torch._check( + 0 < win_length_ <= n_fft, lambda: "istft expected 0 < win_length <= n_fft" + ) + torch._check( + window is None or window.shape == (win_length_,), + lambda: "Invalid window shape. window has to be 1D and length of `win_length`", + ) + + if window is None: + real_dtype = utils.corresponding_real_dtype(input.dtype) + window_ = torch.ones(win_length_, dtype=real_dtype, device=input.device) + else: + window_ = window + + if win_length_ != n_fft: + left = (n_fft - win_length_) // 2 + window_ = aten.constant_pad_nd(window_, (left, n_fft - win_length_ - left), 0) + + original_ndim = input.ndim + if input.ndim == 2: + input = input.unsqueeze(0) + + input = input.transpose(1, 2) + norm = "ortho" if normalized else None + if return_complex: + torch._check( + not onesided_, + lambda: "cannot have onesided output if window or input is complex", + ) + input = torch.fft.ifft(input, dim=-1, norm=norm) + else: + torch._check( + window is None or not utils.is_complex_dtype(window.dtype), + lambda: "Complex windows are incompatible with return_complex=False", + ) + if not onesided_: + input = input.narrow(dim=-1, start=0, length=n_fft // 2 + 1) + input = torch.fft.irfft(input, dim=-1, norm=norm) + + assert input.size(2) == n_fft + + y_tmp = input * window_.view([1, 1, n_fft]) + y = aten.unfold_backward( + y_tmp, + input_sizes=(y_tmp.size(0), expected_output_signal_len), + dim=1, + size=n_fft, + step=hop_length_, + ) + window_envelop = aten.unfold_backward( + window_.pow(2).expand((1, n_frames, n_fft)), + input_sizes=(y_tmp.size(0), expected_output_signal_len), + dim=1, + size=n_fft, + step=hop_length_, + ) + + assert 
expected_output_signal_len == y.size(1) + assert expected_output_signal_len == window_envelop.size(1) + + start = n_fft // 2 if center else 0 + if length is not None: + end = start + length + elif center: + end = expected_output_signal_len - n_fft // 2 + else: + end = expected_output_signal_len + + length = max(0, end - start) + y = y.narrow(dim=1, start=start, length=length) + window_envelop = window_envelop.narrow(dim=1, start=start, length=length) + + y = y / window_envelop + if original_ndim == 2: + y = y.squeeze(0) + + if end > expected_output_signal_len: + warnings.warn( + "The length of signal is shorter than the length parameter. Result is being " + + "padded with zeros in the tail. Please check your center and hop_length settings", + stacklevel=2, + ) + y = aten.constant_pad_nd(y, (0, end - expected_output_signal_len), 0) + return y + + +# Get the new shape and stride after applying unfold to an input tensor +def _get_unfold_shape_stride( + a_shape: ShapeType, a_stride: StrideType, dimension: int, size: int, step: int +): + a_ndim = len(a_shape) + dim = utils.canonicalize_dim(a_ndim, dimension, wrap_scalar=True) + max_size = 1 if a_ndim == 0 else a_shape[dim] + last_stride = 1 if a_ndim == 0 else a_stride[dim] + + torch._check( + size <= max_size, + lambda: f"Maximum size for tensor at dimension {dim} is {max_size} but size is {size}", + ) + + torch._check( + step > 0, + lambda: f"Step is {step} but must be > 0", + ) + + shape = list(a_shape) + strides = list(a_stride) + shape.append(size) + strides.append(last_stride) + if dim < a_ndim: + shape[dim] = (shape[dim] - size) // step + 1 + strides[dim] *= step + return shape, strides + + +@register_decomposition(aten.repeat) +@out_wrapper() +def repeat(a: Tensor, *repeat_shape) -> Tensor: + repeat_shape = utils.extract_shape_from_varargs(repeat_shape, validate=False) + torch._check( + len(repeat_shape) >= len(a.shape), + lambda: "repeat: Number of dimensions of repeat dims can not be smaller than number of 
dimensions of tensor", + ) + + if len(repeat_shape) == 0: + return torch.clone(a) + + num_new_dimensions = len(repeat_shape) - a.ndim + padded_shape = [1] * num_new_dimensions + for dim_size in a.shape: + padded_shape.append(dim_size) + + target_shape = tuple( + padded_size * repeat_size + for padded_size, repeat_size in zip(padded_shape, repeat_shape) + ) + + # return an empty tensor if one of the repeat_shape dimensions is zero + if 0 in repeat_shape: + return torch.empty( + target_shape, + dtype=a.dtype, + device=a.device, + requires_grad=a.requires_grad, + memory_format=utils.suggest_memory_format(a), + ) + + urtensor_shape = target_shape + urtensor_stride = utils.make_contiguous_strides_for(target_shape) + for dim, dim_size in enumerate(padded_shape): + # repeat each dimension by using unfold_copy operation + urtensor_shape, urtensor_stride = _get_unfold_shape_stride( + urtensor_shape, urtensor_stride, dim, dim_size, max(dim_size, 1) + ) + + # derive permute order by sorting urtensor strides + enumerated_stride = list(enumerate(urtensor_stride)) + enumerated_stride.sort(key=operator.itemgetter(1), reverse=True) + permute_order, _sorted_stride = zip(*enumerated_stride) + + # add new and expand dimensions according to urtensor + repeat_xtensor = a.expand(urtensor_shape) + + # clone tensor to concretize expanded dimensions + cloned_result = torch.clone(repeat_xtensor) + + # transpose axis so strides are in sorted order + permuted_result = cloned_result.permute(permute_order) + + # reshape to get contiguous tensor with correct target shape + return permuted_result.reshape(target_shape) + + +def _reshape_view_helper_core_alg( + a: TensorLikeType, shape, allow_copy: bool +) -> TensorLikeType: + # NOTE [Reshape Algorithm] + # This algorithm works by attempting to greedily construct the desired dimensions in + # the output shape, left to right. 
It does this by, conceptually, accumulating + # dimensions of the original tensor, also left to right, until the dimension + # can be constructed using prims.split_dim. + # The algorithm also has special handling for tail squeezes/unsqueezes, like + # if a reshape from (5, 5) to (5, 5, 1) or vice versa. + # + # This algorithm does not flatten the original tensor and then split dims as appropriate + # because that would create copies more often than this algorithm. flatten is the only + # operation below which can create a view or a copy, and while it prefers creating + # views it may sometimes create a copy if the tensor's strides do not permit a view. + # As a result, this algorithm tries to minimize flattening. + # + # Note that a better version of this algorithm may exist. Regions which could be + # flattened without creating a copy can be identified in advance, and that might + # allow fewer flatten calls or faster short-circuiting to make a copy. + idx = 0 + a_ = a + for length in shape: + # Handles tail unsqueezes + if idx >= a_.ndim: + assert length == 1 + last_dim = a_.ndim - 1 + # NOTE: using split_dim instead of unsqueeze may seem silly here, + # but it's necessary to get the strides correct + a_ = prims.split_dim(a_, last_dim, a_.shape[last_dim]) + idx = idx + 1 + continue + + # Skips dimensions that are already the correct length + if length == a_.shape[idx]: + idx = idx + 1 + continue + + accum = a_.shape[idx] + end = idx + while accum % length != 0: + end += 1 + accum *= a_.shape[end] + if end != idx: + # NOTE: in this case multiple dimensions must be flatten to create the desired dimension + # This flattening is why reshape sometimes creates a copy -- because flattening + # may return a view of a copy + + # Checks if collapse can be a view and short-circuits to copying reshape if it can't + new_shape, _new_strides = prims._collapse_view_helper( + a_, idx, end, must_be_valid=None + ) + if new_shape is None: + if allow_copy: + return prims.reshape(a, 
shape) + + msg = f"Cannot view a tensor with shape {a.shape} and strides {a.stride()} as a tensor with shape {shape}!" + raise ValueError(msg) + + a_ = flatten(a_, idx, end) + + # Splits the (possibly flattened) dimension to create the desired dim length. + # guard_or_true is safe due to the tail unsqueeze routine. + if accum != length: + a_ = prims.split_dim(a_, idx, length) + + idx = idx + 1 + + # Squeezes tail + while idx < a_.ndim: + torch._check( + a_.shape[idx] == 1, + lambda: f"a.size({idx}) expected to be 1 but got {a_.shape[idx]}", + ) + a_ = squeeze(a_, idx) + + if a_ is a: + return prims.view_of(a) + else: + return a_ + + +def _reshape_view_helper(a: TensorLikeType, *shape, allow_copy: bool) -> TensorLikeType: + # Creates a valid shape + shape = utils.extract_shape_from_varargs(shape, validate=False) + # Reshape may be given a shape with a -1 length + # This indicates that the dimension's length should be inferred + shape = utils.infer_size(shape, a.numel()) + + # Special-cases tensors with no elements + if a.numel() == 0: + return as_strided(a, shape, utils.make_contiguous_strides_for(shape)) + + # Special-cases reshaping zero dim tensors + if a.ndim == 0: + _a = a + for length in shape: + assert length == 1 + _a = unsqueeze(_a, -1) + if _a is a: + return prims.view_of(a) + else: + return _a + + # Special-cases reshaping to zero dim tensors + if len(shape) == 0: + _a = a + for length in a.shape: + assert length == 1 + _a = squeeze(_a, -1) + if _a is a: + return prims.view_of(a) + else: + return _a + + if is_contiguous_or_false(a): + # Special-cases for nd_to_1d + if len(shape) == 1 and a.ndim > 1: + return torch.as_strided(a, [a.numel()], [1]) + # Special-cases for 1d_to_2d + if len(shape) == 2 and a.ndim == 1: + dim0 = shape[0] + dim1 = shape[1] + return torch.as_strided(a, [dim0, dim1], [dim1, 1]) + + shape_numel = reduce(operator.mul, shape, 1) + torch._check( + a.numel() == shape_numel, + lambda: f"Could not reshape a tensor with shape {a.shape} as 
a tensor with shape {shape}!", + ) + + # Handles general case: a 1+D tensor reshaped into a distinct 1+D shape + return _reshape_view_helper_core_alg(a, shape, allow_copy) + + +# CompositeImplicitAutograd - don't register decomp +# NOTE: shape is a vararg because Tensor.reshape can be called with as +# Tensor.reshape(a, b, c) or Tensor.reshape((a, b, c)) Function call +# torch.reshape doesn't support unpacked shapes +def reshape(a: TensorLikeType, *shape: ShapeType) -> TensorLikeType: + return _reshape_view_helper(a, *shape, allow_copy=True) + + +# CompositeImplicitAutograd - don't register decomp +def reshape_as(self: TensorLikeType, other: TensorLikeType) -> TensorLikeType: + return self.reshape(other.size()) + + +@register_decomposition(aten.roll) +@out_wrapper() +def roll(a: TensorLikeType, shifts: DimsType, dims: DimsType = ()) -> TensorLikeType: + """Reference implementation of :func:`torch.roll`.""" + + dims = utils.canonicalize_dims(a.ndim, dims) + # ATen specifies int[1] type for shifts and dims which expands integers to tuples of length 1 + if not isinstance(shifts, Iterable): + shifts = (shifts,) + if not isinstance(dims, Iterable): + dims = (dims,) + + # Avoid modulo by zero + if a.numel() == 0: + # Keeping this as ref for now as FakeTensor runs into some issues with complex tensors + return a.clone() + + # pyrefly: ignore [bad-argument-type] + if a.dim() == 0 and len(dims) > 0: + raise IndexError( + # pyrefly: ignore [index-error] + f"Dimension specified as {dims[0]} but tensor has no dimensions" + ) + + # pyrefly: ignore [bad-argument-type] + len_shifts = len(shifts) + # pyrefly: ignore [bad-argument-type] + len_dims = len(dims) + if len_shifts != 1 or len_dims != 1: + if len_shifts == 0: + raise RuntimeError("`shifts` required") + # Takes care of the case when dims is not specified (default) + # By default, the tensor is flattened before shifting, after which the original shape is restored + if len_dims == 0 and len_shifts == 1: + # pyrefly: ignore 
[bad-argument-type] + return torch.roll(torch.flatten(a), shifts, 0).view(a.shape) + if len_shifts != len_dims: + raise RuntimeError( + f"shifts and dimensions must align. shifts: {len_shifts}, dims: {len_dims}" + ) + assert len_dims > 1 + # pyrefly: ignore [index-error] + tail_shifts = shifts[1:] + # pyrefly: ignore [index-error] + tail_dims = dims[1:] + # pyrefly: ignore [index-error] + first_dim_rolled = torch.roll(a, (shifts[0],), dims[0]) + return torch.roll(first_dim_rolled, tail_shifts, tail_dims) + + # This path is taken when only one dimension is rolled + # For example to get `first_dim_rolled` above + # pyrefly: ignore [index-error] + dim = dims[0] + size = a.shape[dim] + # pyrefly: ignore [index-error] + start = (size - shifts[0]) % size + idx = torch.arange(size, device=a.device) + return a.index_select(dim, torch.fmod(start + idx, size)) + + +@register_decomposition(aten.rot90) +@out_wrapper() +def rot90( + a: TensorLikeType, k: int = 1, dims: DimsSequenceType = (0, 1) +) -> TensorLikeType: + """Reference implementation of :func:`torch.rot90`.""" + if len(dims) != 2: + raise RuntimeError( + f"expected total rotation dims == 2, but got dims = {len(dims)}" + ) + if a.ndim < 2: + raise RuntimeError(f"expected total dims >= 2, but got total dims = {a.ndim}") + + # Do this after the initial checks to be compatible with the behavior in + # core. 
+ dims = utils.canonicalize_dims(a.ndim, dims) + + if dims[0] == dims[1]: + raise RuntimeError( + f"expected rotation dims to be different, but got dim0 = {dims[0]} and dim1 = {dims[1]}" + ) + k = k % 4 # Rotation direction is from the second towards the first axis for k < 0 + if k == 1: + return torch.transpose(torch.flip(a, (dims[1],)), dims[0], dims[1]) + elif k == 2: + return torch.flip(a, dims) + elif k == 3: + return torch.transpose(torch.flip(a, (dims[0],)), dims[0], dims[1]) + else: + return a.clone(memory_format=torch.contiguous_format) + + +def _check_stack_inputs(tensors: TensorSequenceType) -> None: + from torch.fx.experimental.symbolic_shapes import sym_eq + + entry_shape = tensors[0].shape + for i in range(1, len(tensors)): + torch._check( + sym_eq(tensors[i].shape, entry_shape), + lambda: f"stack expects each tensor to be equal size, but got {entry_shape} at entry 0 ", + ) + + +@register_decomposition(aten.stack) +@out_wrapper() +def stack(tensors: TensorSequenceType, dim: int = 0) -> TensorLikeType: + assert len(tensors) > 0, "stack expects a non-empty TensorList" + wrapped_dim = utils.canonicalize_dim(tensors[0].ndim + 1, dim) + # Refs need sparse support to check other condition + if wrapped_dim < tensors[0].ndim: # and not tensors[0].is_sparse: + _check_stack_inputs(tensors) + result_sizes = list(tensors[0].shape) + result_sizes.insert(wrapped_dim, len(tensors)) + out = torch.cat(tensors, wrapped_dim) + return out.view(result_sizes) + + # If dim == tensors[0].ndim, view cannot efficiently handle it + return torch.cat([t.unsqueeze(wrapped_dim) for t in tensors], dim) + + +# CompositeImplicitAutograd - don't register decomp +@out_wrapper() +def softmax( + a: TensorLikeType, + dim: int, + dtype: Optional[torch.dtype] = None, +) -> TensorLikeType: + result_dtype = dtype or a.dtype + computation_dtype = utils.get_computation_dtype(result_dtype) + a_ = _maybe_convert_to_dtype(a, computation_dtype) + if a.numel() == 0: + a_exp = exp(a_) + else: + a_max 
= amax(a_, dim, keepdim=True) + a_exp = exp(a_ - a_max) + return _maybe_convert_to_dtype( + # pyrefly: ignore [no-matching-overload] + true_divide(a_exp, sum(a_exp, dim, keepdim=True)), + result_dtype, + ) # type: ignore[return-value] + + +# CompositeImplicitAutograd - don't register decomp +@out_wrapper() +def hstack(tensors: TensorSequenceType) -> TensorLikeType: + torch._check(len(tensors) > 0, lambda: "hstack expects a non-empty TensorList") + aligned_tensors = atleast_1d(*tensors) + if aligned_tensors[0].ndim == 1: + return cat(aligned_tensors, 0) + return cat(aligned_tensors, 1) + + +# CompositeImplicitAutograd - don't register decomp +@out_wrapper() +def vstack(tensors: TensorSequenceType) -> TensorLikeType: + torch._check(len(tensors) > 0, lambda: "vstack expects a non-empty TensorList") + aligned_tensors = atleast_2d(*tensors) + return cat(aligned_tensors, 0) + + +# CompositeImplicitAutograd - don't register decomp +def unflatten(a: TensorLikeType, dim: int, sizes: ShapeType) -> TensorLikeType: + dim = utils.canonicalize_dim(a.ndim, dim) + torch._check(len(sizes) != 0, lambda: "unflatten: sizes must be non-empty") + return a.view(tuple(a.shape[:dim]) + tuple(sizes) + tuple(a.shape[dim + 1 :])) + + +@register_decomposition(aten.unbind) +def unbind(t: TensorLikeType, dim: int = 0) -> TensorSequenceType: + dim = utils.canonicalize_dim(t.ndim, dim) + torch._check_index( + len(t.shape) > 0, + lambda: "Dimension specified as 0 but tensor has no dimensions", + ) + + # Note: t.shape[dim] can't be dynamic or unbacked, even if we use guard_or_false here we will fail + # later in the split since t.shape[dim] control the number of output tensors. 
+ if t.shape[dim] == 0: + return () + else: + return tuple( + torch.squeeze(s, dim) for s in torch.tensor_split(t, t.shape[dim], dim) + ) + + +@out_wrapper() +def index_copy(x: TensorLike, dim: int, index: TensorLike, tensor: TensorLike): + return x.clone(memory_format=torch.contiguous_format).index_copy_( + dim, index, tensor + ) + + +def index_copy_(x: TensorLike, dim: int, index: TensorLike, tensor: TensorLike): + dim = utils.canonicalize_dims(x.ndim, dim) + torch._check( + index.ndim <= 1, + lambda: f"Index should have dimension 1 or 0 (got {index.ndim})", + ) + # Treat scalars as elements of \R^1 + y = x.unsqueeze(0) if x.ndim == 0 else x + idx = (slice(None),) * dim + (index,) + y[idx] = tensor + return x + + +@register_decomposition(aten.index_fill) +@out_wrapper() +def index_fill( + x: TensorLike, dim: int, index: TensorLike, value: Union[NumberType, TensorLike] +): + return _index_fill(x, dim, index, value, inplace=False) + + +@register_decomposition(aten.index_fill_) +def index_fill_( + x: TensorLike, dim: int, index: TensorLike, value: Union[NumberType, TensorLike] +): + return _index_fill(x, dim, index, value, inplace=True) + + +def _index_fill( + x: TensorLike, + dim: int, + index: TensorLike, + value: Union[NumberType, TensorLike], + *, + inplace: bool, +): + torch._check( + index.ndim <= 1, + lambda: f"Index should have dimension 1 or 0 (got {index.ndim})", + ) + if isinstance(value, TensorLike): + torch._check( + value.ndim == 0, + lambda: "Only supports 0-dimensional value tensor. " # type: ignore[union-attr] + f"Got a tensor with {value.ndim} dimensions.", + ) # type: ignore[arg-type] + else: + value = torch.scalar_tensor( + value, + dtype=x.dtype, + layout=x.layout, + device=x.device, # type: ignore[arg-type] + ) + + # index_copy has some unnecessary preconditions when x is a scalar. 
We do this to work through them + zero_dim = x.ndim == 0 + y = x.unsqueeze(0) if zero_dim else x + # index_copy does not broadcast on value so we have to do it manually + shape = list(y.shape) + shape[dim] = index.numel() + value = value.expand(shape) + index_copy = Tensor.index_copy_ if inplace else torch.index_copy + out = index_copy(y, dim, index, value) # type: ignore[operator] + if inplace: + return x + else: + if zero_dim: + # The clone is necessary so that it returns a fresh tensor rather than a view + out = out.squeeze(0).clone() + # index_fill preserves the strides. index_copy always returns contiguous tensors + if out.stride() != x.stride(): + new_out = torch.empty_like(x) + new_out.copy_(out) + out = new_out + return out + + +@out_wrapper() +def index_add( + x: TensorLike, + dim: int, + index: TensorLike, + tensor: TensorLike, + *, + alpha: NumberType = 1, +): + # index_add always returns a new contiguous tensor + return x.clone(memory_format=torch.contiguous_format).index_add_( + dim, + index, + tensor, + alpha=alpha, # type: ignore[arg-type] + ) + + +@register_decomposition(aten.index_select) +@out_wrapper() +def index_select(x: TensorLike, dim: int, index: TensorLike): + dim = utils.canonicalize_dims(x.ndim, dim) + torch._check( + index.ndim <= 1, + lambda: f"Index should have dimension 1 or 0 (got {index.ndim})", + ) + if index.ndim == 0: + index = index.unsqueeze(0) + if x.ndim == 0: + # Treat scalars as elements of \R^1 + # We cannot use x[idx] here as it accesses item() (??), hence this awkward construction + return torch.empty_like(x).index_copy(0, index, x.expand_as(index)) + + idx = (slice(None),) * dim + (index,) + return x[idx].contiguous() + + +@register_decomposition(aten.squeeze.dims) +def squeeze(a: TensorLikeType, dim: Optional[DimsType] = None) -> TensorLikeType: + from torch.fx.experimental.symbolic_shapes import guard_or_false + + if dim is None: + dims = tuple(idx for idx, size in enumerate(a.shape) if size == 1) + return 
prims.squeeze(a, dims) if dims else prims.view_of(a) + + ndim = a.ndim + + dim = utils.canonicalize_dims(ndim, dim) + dims = (dim,) if isinstance(dim, Dim) else dim + # Short-circuits if the tensor has no dimensions + if ndim == 0: + assert len(dims) == 0 or dims == (0,) + return prims.view_of(a) + + # Note: squeeze does not modify tensors when the given dim is not a dimension of length 1 + # would it be better if we just not allow 1 for unbacked at runtiume? + dims = tuple(d for d in dims if guard_or_false(a.shape[d] == 1)) + if len(dims) == 0: + return prims.view_of(a) + if len(dims) == 1: + return prims.squeeze(a, dims) + dims_list = list(dims) + dims_list = sorted(dims_list, reverse=True) + for i in dims_list: + a = squeeze(a, i) + return a + + +@register_decomposition(aten.split_with_sizes) +def split_with_sizes( + self: Tensor, split_sizes: list[int], dim: int = 0 +) -> list[Tensor]: + # NB: Perform the check_is_size tests first so that the + # sum test does not try to do a replacement + for i in range(len(split_sizes)): + torch._check( + split_sizes[i] >= 0, + lambda: "split_with_sizes expects split_sizes have only non-negative entries", + ) + torch._check_with( + ValueError, + builtins.sum(split_sizes) == self.shape[dim], + lambda: f"Split sizes add up to {builtins.sum(split_sizes)} but got the tensor's size of {self.shape[dim]}", + ) + + splits = [] + offset = self.storage_offset() + + for split_size in split_sizes: + new_shape = list(self.shape) + new_shape[dim] = split_size + # We reimplement narrow here to avoid a lot of checks in the + # decomposition of narrow which calls slice_in_dim and slice + splits.append(self.as_strided(new_shape, self.stride(), offset)) + offset = offset + self.stride()[dim] * split_size + return splits + + +# Note: does not work with TensorMetas because of data-dependent control-flow +# CompositeImplicitAutograd - don't register decomp +def tensor_split( + a: TensorLikeType, + indices_or_sections: Union[Tensor, DimsType], + 
dim: int = 0, +) -> tuple[TensorLikeType, ...]: + _dim = utils.canonicalize_dim(a.ndim, dim) + if a.ndim == 0: + msg = "tensor_split: received a rank zero tensor, but expected a tensor of rank one or greater!" + raise ValueError(msg) + + # If indices_or_sections is a tensor, it must be a CPU Long tensor + if isinstance(indices_or_sections, TensorLike): + if indices_or_sections.device.type != "cpu": + msg = ( + f"tensor_split: if indices_or_sections is a tensor it must be on the CPU, " + f"but received one on {indices_or_sections.device}" + ) + raise ValueError(msg) + if indices_or_sections.dtype != torch.long: + msg = ( + "tensor_split: if indices_or_sections is a tensor it must have long dtype, " + f" but received one with dtype {indices_or_sections.dtype}" + ) + raise ValueError(msg) + + # Case 0 -- indices_or_sections is an integer or a scalar tensor n and a is split along dim into n parts of equal-ish length + if isinstance(indices_or_sections, IntLike) or ( + isinstance(indices_or_sections, TensorLike) and indices_or_sections.ndim == 0 + ): + sections: int = ( + indices_or_sections # type: ignore[assignment] + if isinstance(indices_or_sections, Number) + else indices_or_sections.item() + ) + + if sections <= 0: + msg = f"tensor_split: number of sections must be greater than 0, but was {sections}" + raise ValueError(msg) + + dim_size = a.shape[_dim] + min_split_size = math.floor(dim_size / sections) + num_splits_one_extra = dim_size % sections + + split_sizes = [] + for split_idx in range(sections): + split_size = ( + min_split_size + 1 + if (split_idx < num_splits_one_extra) + else min_split_size + ) + split_sizes.append(split_size) + + return tuple(aten.split_with_sizes(a, split_sizes, dim=_dim)) + # Case 1 -- indices_or_sections is a sequence of integers or a 1D tensor describing the splits + else: + indices = indices_or_sections + if isinstance(indices_or_sections, TensorLike): + if indices_or_sections.ndim != 1: + msg = ( + "tensor_split: non-scalar 
indices_or_sections tensors must have only one dimension, " + f"but received a tensor with {indices_or_sections.ndim} dimensions" + ) + raise ValueError(msg) + + indices = indices_or_sections.tolist() + + indices = [0] + list(indices) + [a.shape[_dim]] + split_sizes = [indices[i + 1] - indices[i] for i in range(len(indices) - 1)] + return tuple(aten.split_with_sizes(a, split_sizes, dim=_dim)) + + +# CompositeImplicitAutograd - don't register decomp +def hsplit( + a: TensorLikeType, indices_or_sections: DimsType +) -> tuple[TensorLikeType, ...]: + torch._check( + a.ndim >= 1, + lambda: ( + "torch.hsplit requires a tensor with at least 1 dimension, but got a tensor with " + + str(a.ndim) + + " dimensions!" + ), + ) + dim = 0 if a.ndim == 1 else 1 + if isinstance(indices_or_sections, IntLike): + split_size = indices_or_sections + torch._check( + # pyrefly: ignore [unsupported-operation] + (split_size != 0 and a.shape[dim] % split_size == 0), + lambda: ( + "torch.hsplit attempted to split along dimension " + + str(dim) + + ", but the size of the dimension " + + str(a.shape[dim]) + + " is not divisible by the split_size " + + str(split_size) + + "!" + ), + ) + # pyrefly: ignore [bad-argument-type] + return tensor_split(a, split_size, dim) + + torch._check_type( + isinstance(indices_or_sections, (list, tuple)), + lambda: ( + "hsplit(): received an invalid combination of arguments. " + "Expected indices_or_sections to be of type int, list of ints or tuple of ints " + f"but got type {type(indices_or_sections)}" + ), + ) + + split_sizes = indices_or_sections + return tensor_split(a, split_sizes, dim) + + +# CompositeImplicitAutograd - don't register decomp +def vsplit( + a: TensorLikeType, indices_or_sections: DimsType +) -> tuple[TensorLikeType, ...]: + torch._check( + a.ndim >= 2, + lambda: ( + "torch.vsplit requires a tensor with at least 2 dimension, but got a tensor with " + + str(a.ndim) + + " dimensions!" 
+ ), + ) + if isinstance(indices_or_sections, IntLike): + split_size = indices_or_sections + torch._check( + # pyrefly: ignore [unsupported-operation] + (split_size != 0 and a.shape[0] % split_size == 0), + lambda: ( + f"torch.vsplit attempted to split along dimension 0" + f", but the size of the dimension " + f"{a.shape[0]}" + f" is not divisible by the split_size " + f"{split_size}" + f"!" + ), + ) + # pyrefly: ignore [bad-argument-type] + return tensor_split(a, split_size, 0) + + torch._check_type( + isinstance(indices_or_sections, (list, tuple)), + lambda: ( + "vsplit(): received an invalid combination of arguments. " + "Expected indices_or_sections to be of type int, list of ints or tuple of ints " + f"but got type {type(indices_or_sections)}" + ), + ) + + split_sizes = indices_or_sections + return tensor_split(a, split_sizes, 0) + + +@register_decomposition(aten.diag.out) +@out_wrapper() +def diag( + self: TensorLikeType, + offset: int = 0, +) -> TensorLikeType: + ndim = self.dim() + torch._check( + ndim in (1, 2), lambda: f"diag(): Supports 1D or 2D tensors. Got {ndim}D" + ) + if ndim == 1: + return torch.diag_embed(self, offset) + else: + return torch.diagonal_copy(self, offset) + + +@register_decomposition(aten.diagonal_scatter) +@out_wrapper() +def diagonal_scatter( + input: TensorLikeType, + src: TensorLikeType, + offset: int = 0, + dim1: int = 0, + dim2: int = 1, +) -> TensorLikeType: + out = utils.clone_preserve_strides(input) + diag = out.diagonal(offset, dim1, dim2) + torch._check( + diag.shape == src.shape, + lambda: "expected src to have a size equal to the diagonal of the input." 
+ f"Got {src.shape} for a diagonal of shape {diag.shape}", + ) + copy_to(diag, src) + return out + + +@register_decomposition(aten.diagonal) +def diagonal( + self: TensorLikeType, + offset: int = 0, + dim1: int = 0, + dim2: int = 1, +) -> TensorLikeType: + """ + Reference implementation of torch.diagonal + """ + num_dims = self.dim() + dim1 = utils.canonicalize_dim(idx=dim1, rank=num_dims) + dim2 = utils.canonicalize_dim(idx=dim2, rank=num_dims) + + torch._check( + dim1 != dim2, lambda: f"diagonal dimensions cannot be identical {dim1}, {dim2}" + ) + + storage_offset = self.storage_offset() + + if offset >= 0: + diag_size = max(min(self.size()[dim1], self.size()[dim2] - offset), 0) + else: + diag_size = max(min(self.size()[dim1] + offset, self.size()[dim2]), 0) + + if diag_size > 0: + if offset >= 0: + storage_offset += offset * self.stride()[dim2] + else: + storage_offset -= offset * self.stride()[dim1] + + sizes = [s for i, s in enumerate(self.size()) if i not in (dim1, dim2)] + sizes.append(diag_size) + + strides = [s for i, s in enumerate(self.stride()) if i not in (dim1, dim2)] + strides.append(self.stride()[dim1] + self.stride()[dim2]) + + result = self.as_strided(size=sizes, stride=strides, storage_offset=storage_offset) + + return result + + +@register_decomposition(aten.diag_embed) +@out_wrapper() +def diag_embed( + t: TensorLikeType, + offset: int = 0, + dim1: int = -2, + dim2: int = -1, +) -> TensorLikeType: + """ + Reference implementation of torch.diag_embed + """ + # convert from negative dims + rank = t.ndim + 1 + dim1 = utils.canonicalize_dim(rank=rank, idx=dim1) + dim2 = utils.canonicalize_dim(rank=rank, idx=dim2) + + # as per the docs, exchanging dims is equivalent to changing the sign of + # offset + if dim1 > dim2: + dim1, dim2 = dim2, dim1 + offset = -offset + + torch._check( + dim1 != dim2, lambda: f"diagonal dimensions cannot be identical {dim1}, {dim2}" + ) + + # as per the docs, the size of last dim is placed at dim1 and dim2 + last_dim = 
t.size(-1) + + if offset != 0: + # add padding to match the new size + t_shape = list(t.shape) + t_shape[-1] = builtins.abs(offset) + z = torch.zeros(t_shape, dtype=t.dtype, device=t.device, requires_grad=False) + pair = (z, t) if offset > 0 else (t, z) + t = torch.cat(pair, dim=-1) + # make sure the diagonal always has the same size + last_dim += builtins.abs(offset) + + # preserve original data, but place 1 at dim1 and move last dim to dim2 + t = t.unsqueeze(dim1).movedim(-1, dim2) + + # generate ranges shifting indices based on offset + a_range = torch.arange(last_dim, device=t.device, dtype=torch.int64) + b_range = torch.arange( + offset, last_dim + offset, device=t.device, dtype=torch.int64 + ) + + # broadcast + cond = a_range == b_range.unsqueeze(-1) + cond_shape = [last_dim if i in (dim1, dim2) else 1 for i in range(len(t.shape))] + cond = cond.reshape(cond_shape) + + # aten.diag_embed always returns a new contiguous tensor + # contiguous() is needed to correctly model the output stride + return utils.mask_tensor(cond, t).contiguous() + + +@register_decomposition(aten.block_diag) +@out_wrapper() +def _block_diag_iterable(tensors: list[TensorLikeType]) -> TensorLikeType: + """ + Reference implementation of torch.block_diag + """ + tensors_2d = [ + tensor.view(1, -1) if tensor.dim() <= 1 else tensor for tensor in tensors + ] + + ncols = builtins.sum(tensor.shape[1] for tensor in tensors_2d) + device = tensors_2d[0].device + + result = [] + + col_start = 0 + for i, tensor in enumerate(tensors_2d): + torch._check( + tensor.dim() == 2, + lambda: "Input tensors must have 2 or fewer dimensions. " + f"Input {i} has {tensor.dim()} dimensions", + ) + torch._check( + tensor.device == device, + lambda: "Input tensors must all be on the same device. 
" + f"Input 0 is on device {device} and input {i} is on device {tensor.device}.", + ) + row, col = tensor.shape + left = torch.zeros((row, col_start), device=device, dtype=tensor.dtype) + right = torch.zeros( + (row, ncols - col_start - col), device=device, dtype=tensor.dtype + ) + result += [torch.cat((left, tensor, right), dim=1)] + col_start += col + + return torch.cat(result, dim=0) + + +def block_diag(*tensors: list[TensorLikeType]) -> TensorLikeType: + """ + This is used as an input to PythonRefInfo. `torch.block_diag` + expects arguments splatted, but `aten.block_diag` expects only + one argument that is a list of Tensors. + """ + return _block_diag_iterable(tensors) # type: ignore[arg-type] + + +# CompositeImplicitAutograd - don't register decomp +def dsplit(a: TensorLikeType, sections: DimsType) -> TensorSequenceType: + if a.ndim < 3: + raise RuntimeError( + f"torch.dsplit requires a tensor with at least 3 dimension, but got a tensor with {a.ndim} dimensions!" + ) + # pyrefly: ignore [unsupported-operation] + if isinstance(sections, IntLike) and (sections == 0 or a.shape[2] % sections != 0): + raise RuntimeError( + "torch.dsplit attempted to split along dimension 2, " + + f"but the size of the dimension {a.shape[2]} is not divisible by the split_size {sections}!" 
+ ) + return tensor_split(a, sections, 2) + + +@register_decomposition(aten.t.default) +def t(a: TensorLikeType): + # TODO: Add sparse support + # if a.is_sparse: + # sparse_dim = a.sparse_dim() + # dense_dim = a.dense_dim() + # if not (sparse_dim <= 2 and dense_dim == 0): + # raise RuntimeError( + # f"t() expects a tensor with <= 2 sparse and 0 dense dimensions, but got {sparse_dim} sparse and" + # f"{dense_dim} dense dimensions" + # ) + if a.ndim > 2: + raise RuntimeError( + f"t() expects a tensor with <= 2 dimensions, but self is {a.ndim}D" + ) + return torch.transpose(a, 0, 0 if a.ndim < 2 else 1) + + +# CompositeImplicitAutograd - don't register decomp +def T(a: TensorLikeType) -> TensorLikeType: + # n != 2 && n != 0 is deprecated in regular PyTorch. + torch._check( + a.ndim in (0, 2), + lambda: ( + "The use of `x.T` on tensors of dimension other than 0 or 2 " + "to reverse their shape is not supported." + ), + ) + return a.t() + + +@register_decomposition(aten.alias) +def alias(a: TensorLikeType) -> TensorLikeType: + return prims.view_of(a) + + +@register_decomposition(aten.transpose) +def transpose(a: TensorLikeType, dim0: int, dim1: int) -> TensorLikeType: + _dim0, _dim1 = utils.canonicalize_dims(a.ndim, (dim0, dim1)) # type: ignore[misc] + + if a.ndim <= 1 or dim0 == dim1: + return aten.alias.default(a) + + _permutation = list(range(a.ndim)) + _permutation[_dim0] = _dim1 + _permutation[_dim1] = _dim0 + return torch.permute(a, _permutation) + + +# Aliases for transpose +swap_axes = transpose + + +@register_decomposition(aten.unfold) +def unfold( + self: TensorLikeType, dimension: int, size: int, step: int +) -> TensorLikeType: + shape, strides = _get_unfold_shape_stride( + self.shape, self.stride(), dimension, size, step + ) + return self.as_strided(shape, strides) + + +@register_decomposition(aten.unfold_copy) +@out_wrapper() +def unfold_copy(self: TensorLikeType, dimension: int, size: int, step: int): + return self.unfold(dimension, size, step).clone( + 
memory_format=torch.contiguous_format + ) + + +def _cumsumprod_common( + func, + init, + a: TensorLikeType, + dim: int, + *, + dtype: Optional[torch.dtype] = None, + out: Optional[Tensor] = None, +) -> TensorLikeType: + # We implement all the kwargs of a reduction. ATen just handles dtype + # nb. This decomposition may not be as efficient as a backend-specific implementation + ndim = a.ndim + dim = utils.canonicalize_dim(ndim, dim) + if ndim == 0: + return func(a.unsqueeze(0), dim=0, dtype=dtype, out=out) + a = a.unsqueeze(dim + 1) + rg = torch.arange(a.shape[dim], device=a.device) + mask = rg.unsqueeze(1) <= rg + for _ in range(ndim - dim - 1): + mask = mask.unsqueeze(-1) + masked_a = torch.where(mask, a, init) + return func(masked_a, dim=dim, dtype=dtype, out=out) + + +@register_decomposition(aten.cumsum) +def cumsum( + a: TensorLikeType, + dim: int, + *, + dtype: Optional[torch.dtype] = None, + out: Optional[Tensor] = None, +) -> TensorLikeType: + return _cumsumprod_common(func=sum, init=0, a=a, dim=dim, dtype=dtype, out=out) + + +@register_decomposition(aten.cumprod) +def cumprod( + a: TensorLikeType, + dim: int, + *, + dtype: Optional[torch.dtype] = None, + out: Optional[Tensor] = None, +) -> TensorLikeType: + return _cumsumprod_common(func=prod, init=1, a=a, dim=dim, dtype=dtype, out=out) + + +# Note: although squeeze is documented as having the out= kwarg it doesn't +@register_decomposition(aten.unsqueeze) +def unsqueeze(a: TensorLikeType, dim: int) -> TensorLikeType: + # Note that unsqueeze canonicalizes with rank + 1 because it allows + # a new innermost dimension to be specified + ndim = a.ndim + 1 + dim = utils.canonicalize_dim(ndim, dim) + return prims.expand_dims(a, (dim,), ndim=ndim) + + +# NOTE: shape is a vararg because Tensor.reshape can be called with as +# Tensor.view(a, b, c) or Tensor.view((a, b, c)) Function call torch.view +# doesn't support unpacked shapes +# TODO: Turn this into a decomposition (currently fails on reshape meta tests) 
+@register_decomposition(aten.view.default) +def view(a: TensorLikeType, *shape: ShapeType) -> TensorLikeType: + return _reshape_view_helper(a, *shape, allow_copy=False) + + +# CompositeImplicitAutograd - don't register decomp +def view_as(self: TensorLikeType, other: TensorLikeType) -> TensorLikeType: + return self.view(other.size()) + + +# CompositeImplicitAutograd - don't register decomp +def ravel(a: TensorLikeType) -> TensorLikeType: + return reshape(a, (-1,)) + + +# CompositeImplicitAutograd - don't register decomp +# missing ref impl. for aten.gather +@out_wrapper() +def take_along_dim( + a: torch.Tensor, indices: torch.Tensor, dim: Optional[int] = None +) -> torch.Tensor: + torch._check( + a.ndim == indices.ndim, + lambda: ( + "torch.take_along_dim(): input and indices should have the same " + f"number of dimensions, but got {a.ndim} dimensions for input, and " + f"{indices.ndim} dimensions for indices" + ), + ) + + torch._check( + utils.is_integer_dtype(indices.dtype), + lambda: ( + "torch.take_along_dim(): dtype of indices should be int but got " + f"{indices.dtype} instead" + ), + ) + + if dim is None: + return torch.gather(a.view(-1), 0, indices.view(-1)) + else: + self_sizes = list(a.shape) + self_sizes[dim] = indices.size(dim) + broadcast_shape = utils.infer_size_shapes(self_sizes, indices.size()) + indices_broadcast = broadcast_to(indices, broadcast_shape) + + indices_sizes = list(indices.shape) + indices_sizes[dim] = a.size(dim) + broadcast_shape = utils.infer_size_shapes(indices_sizes, a.size()) + self_broadcast = broadcast_to(a, broadcast_shape) + + # wrap negative indices + dim_size = self_broadcast.size(dim) + indices_broadcast = indices_broadcast % dim_size + + return torch.gather(self_broadcast, dim, indices_broadcast) + + +@out_wrapper() +def empty( + *shape, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + device: Optional[DeviceLikeType] = None, + requires_grad: bool = False, + pin_memory: bool = False, + 
memory_format: torch.memory_format = torch.contiguous_format, +) -> TensorLikeType: + torch._check( + memory_format != torch.preserve_format, + lambda: "torch.empty: the Preserve memory format is not supported", + ) + + shape = utils.extract_shape_from_varargs(shape) + + if memory_format == torch.contiguous_format: + strides = utils.make_contiguous_strides_for(shape) + elif memory_format == torch.channels_last_3d: + strides = utils.make_channels_last_3d_strides_for(shape) + else: # memory_format == torch.channels_last + torch._check( + memory_format == torch.channels_last, + lambda: f"torch.empty: received an unknown memory format {memory_format}!", + ) + strides = utils.make_channels_last_2d_strides_for(shape) + + return torch.empty_strided( + shape, + strides, + dtype=dtype, + layout=layout, + device=device, + pin_memory=pin_memory, + requires_grad=requires_grad, + ) + + +@out_wrapper() +def empty_permuted( + shape, + physical_layout, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + device: Optional[DeviceLikeType] = None, + requires_grad: bool = False, + pin_memory: bool = False, +) -> TensorLikeType: + return prims.empty_permuted( + shape, + physical_layout, + dtype=dtype, + device=device, + requires_grad=requires_grad, + ) + + +@register_decomposition(aten.new_empty) +@out_wrapper() +def new_empty( + a: TensorLikeType, + size: ShapeType, + *, + dtype: Optional[torch.dtype] = None, + layout: Optional[torch.layout] = None, + device: Optional[DeviceLikeType] = None, + pin_memory: bool = False, +) -> TensorLikeType: + dtype = a.dtype if dtype is None else dtype + layout = a.layout if layout is None else layout + device = a.device if device is None else device + + return torch.empty( + size, + dtype=dtype, + device=device, + pin_memory=pin_memory, + layout=layout, + ) + + +@register_decomposition(aten.new_empty_strided) +@out_wrapper() +def new_empty_strided( + a: TensorLikeType, + size: ShapeType, + stride: StrideType, + *, + dtype: 
Optional[torch.dtype] = None, + layout: Optional[torch.layout] = None, + device: Optional[DeviceLikeType] = None, + pin_memory: bool = False, +) -> TensorLikeType: + """ + Reference implementation of torch.Tensor.new_empty_strided + """ + + dtype = a.dtype if dtype is None else dtype + layout = a.layout if layout is None else layout + device = a.device if device is None else device + + return torch.empty_strided( + size, + stride, + dtype=dtype, + device=device, + pin_memory=pin_memory, + layout=layout, + ) + + +@register_decomposition(aten.zeros.default) +@out_wrapper() +def zeros( + *size, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + device: Optional[DeviceLikeType] = None, + pin_memory: bool = False, + requires_grad: bool = False, +) -> TensorLikeType: + size = utils.extract_shape_from_varargs(size) + + if dtype is None: + dtype = torch.get_default_dtype() + + return torch.full( + size, + False if dtype == torch.bool else 0, + dtype=dtype, + layout=layout, + device=device, + pin_memory=pin_memory, + requires_grad=requires_grad, + ) + + +@register_decomposition(aten.new_zeros) +@out_wrapper() +def new_zeros( + a: TensorLikeType, + size: ShapeType, + *, + dtype: Optional[torch.dtype] = None, + layout: Optional[torch.layout] = None, + device: Optional[DeviceLikeType] = None, + pin_memory: bool = False, + requires_grad: bool = False, +) -> TensorLikeType: + dtype = a.dtype if dtype is None else dtype + layout = a.layout if layout is None else layout + device = a.device if device is None else device + + return torch.full( + size, + False if (dtype or a.dtype) == torch.bool else 0, + dtype=dtype, + layout=layout, + device=device, + pin_memory=pin_memory, + requires_grad=requires_grad, + ) + + +@register_decomposition(aten.ones.default) +@out_wrapper() +def ones( + *size, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + device: Optional[DeviceLikeType] = None, + pin_memory: bool = False, + 
requires_grad: bool = False, +) -> TensorLikeType: + size = utils.extract_shape_from_varargs(size) + + if dtype is None: + dtype = torch.get_default_dtype() + + return torch.full( + size, + True if dtype == torch.bool else 1, + dtype=dtype, + layout=layout, + device=device, + pin_memory=pin_memory, + requires_grad=requires_grad, + ) + + +@register_decomposition(aten.new_ones) +@out_wrapper() +def new_ones( + a: TensorLikeType, + size: ShapeType, + *, + dtype: Optional[torch.dtype] = None, + layout: Optional[torch.layout] = None, + device: Optional[DeviceLikeType] = None, + pin_memory: bool = False, + requires_grad: bool = False, +) -> TensorLikeType: + dtype = a.dtype if dtype is None else dtype + layout = a.layout if layout is None else layout + device = a.device if device is None else device + + return torch.full( + size, + True if (dtype or a.dtype) == torch.bool else 1, + dtype=dtype, + layout=layout, + device=device, + pin_memory=pin_memory, + requires_grad=requires_grad, + ) + + +@register_decomposition(aten.new_full) +@out_wrapper() +def new_full( + a: TensorLikeType, + size: ShapeType, + fill_value: NumberType, + *, + dtype: Optional[torch.dtype] = None, + layout: Optional[torch.layout] = None, + device: Optional[DeviceLikeType] = None, + pin_memory: bool = False, +) -> TensorLikeType: + dtype = a.dtype if dtype is None else dtype + layout = a.layout if layout is None else layout + device = a.device if device is None else device + + return torch.full( + size, + fill_value, + dtype=dtype, + layout=layout, + device=device, + pin_memory=pin_memory, + ) + + +@aten.empty.out.py_impl(DispatchKey.CompositeImplicitAutograd) +def empty_out( + size: TensorLikeType, + out: TensorLikeType, + memory_format: Optional[torch.memory_format] = None, +) -> TensorLikeType: + return out + + +@register_decomposition(aten.empty_like) +@out_wrapper() +def empty_like( + a: TensorLikeType, + *, + dtype: Optional[torch.dtype] = None, + device: Optional[DeviceLikeType] = None, + 
layout: Optional[torch.layout] = None, + pin_memory: bool = False, + requires_grad: bool = False, + memory_format: torch.memory_format = torch.preserve_format, +) -> TensorLikeType: + dtype = a.dtype if dtype is None else dtype + layout = a.layout if layout is None else layout + device = a.device if device is None else device + + if memory_format != torch.preserve_format: + return torch.empty( + a.shape, + dtype=dtype, + layout=layout, + device=device, + requires_grad=requires_grad, + pin_memory=pin_memory, + memory_format=memory_format, + ) + + # memory_format == torch.preserve_format + logical_to_physical_perm, _ = ( + utils.compute_elementwise_output_logical_to_physical_perm(a) + ) + # identity perm is [2, 1, 0] + return torch.empty_permuted( + a.shape, + logical_to_physical_perm, + dtype=dtype, + layout=layout, + device=device, + pin_memory=pin_memory, + requires_grad=requires_grad, + ) + + +@register_decomposition([aten.arange.start_step, aten.arange.start_out]) +@out_wrapper() +def arange( + start: NumberType = 0, + end: Optional[NumberType] = None, + step: NumberType = 1, + *, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + device: Optional[DeviceLikeType] = None, + pin_memory: bool = False, + requires_grad: bool = False, +) -> TensorLikeType: + utils.check_layout(layout) + utils.check_pin_memory(pin_memory) + device = torch.device(utils.device_or_default(device)) + + assert not isinstance(start, complex) + assert not isinstance(end, complex) + assert not isinstance(step, complex) + + # Case: torch.arange(5) + if end is None: + end = start + start = 0 + torch._check(step != 0, lambda: "step must be nonzero") + if step > 0: + torch._check( + end >= start, + lambda: "upper bound and lower bound inconsistent with step sign", + ) + elif step < 0: + torch._check( + end <= start, + lambda: "upper bound and lower bound inconsistent with step sign", + ) + + def is_finite(x): + return not isinstance(x, FloatWithoutSymFloat) or 
math.isfinite(x) + + torch._check( + is_finite(start) and is_finite(end), + lambda: f"unsupported range: {start} -> {end}", + ) + torch._check( + is_finite(step), + lambda: f"step must be finite but got {step}", + ) + + args = (start, end, step) + integer_args = builtins.all(isinstance(arg, IntLike) for arg in args) + + if dtype is None: + dtype = torch.int64 if integer_args else torch.get_default_dtype() + + is_integer = utils.is_integer_dtype(dtype) + if is_integer or integer_args: + xstart = sym_int(start) + xend = sym_int(end) + xstep = sym_int(step) + + # For int64 we truncate arguments to int before calculating length, but + # other integral dtypes we don't. Weird... but needed to match ATen shapes. + if dtype == torch.int64 or integer_args: + # Uses floordiv to avoid ceil in inductor. + sgn = bool(xstep > 0) - bool(xstep < 0) # type: ignore[possibly-undefined] + length = (xend - xstart + xstep - sgn) // xstep # type: ignore[possibly-undefined] + else: + length = math.ceil((end - start) / step) + + if is_integer: + return prims.iota( + length, + start=xstart, # type: ignore[possibly-undefined] + step=xstep, # type: ignore[possibly-undefined] + dtype=dtype, + device=device, + requires_grad=requires_grad, + ) + + index = prims.iota( + length, + start=0, + step=1, + dtype=torch.int64, + device=device, + requires_grad=False, + ) + + computation_dtype = ( + torch.long if integer_args else utils.get_acc_type(dtype, device) + ) + index = _maybe_convert_to_dtype(index, computation_dtype) + result = start + step * index + result = _maybe_convert_to_dtype(result, dtype) + + if requires_grad: + result.requires_grad_(True) + return result + + +@register_decomposition(aten.lerp) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("start", "end", "weight"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def lerp(start: Tensor, end: Tensor, weight: Union[Tensor, NumberType]): + inputs = [start, end] + if isinstance(weight, 
Number): + weight = start.new_full((), weight) # type: ignore[arg-type] + else: + inputs.append(weight) + assert isinstance(weight, Tensor) # mypy + # We implement it this way for numerical stability. We assume (in the stability optimisation) + # that 0 <= weight <= 1. We take the abs to deal with complex numbers + # We want to perform operations near zero, which is where floating points are most precise + # thus, we perform the following optimisation: + # If weight.abs() >= 0.5: + # return (1 - weight) * (start - end) + end + mask = weight.abs() >= 0.5 + coeff = torch.where(mask, weight - 1, weight) + base = torch.where(mask, end, start) + output = coeff * (end - start) + base + # make sure the decomposition output's stride is same as non-decomposition path. + stride = utils.compute_elementwise_output_strides(*_maybe_broadcast(*inputs)) + if output.stride() != stride: + output = prims.copy_strided(output, stride) + + return handle_noncontiguous_outputs(inputs, output) + + +@register_decomposition(aten.linspace) +@out_wrapper() +def linspace( + start: Union[NumberType, TensorLikeType], + end: Union[NumberType, TensorLikeType], + steps: NumberType, + *, + dtype: Optional[torch.dtype] = None, + device: Optional[DeviceLikeType] = None, + layout: torch.layout = torch.strided, + pin_memory: bool = False, + requires_grad: bool = False, +) -> TensorLikeType: + if isinstance(start, TensorLikeType): + torch._check( + start.dim() == 0, + lambda: "linspace only supports 0-dimensional start and end tensors", + ) + start = _maybe_convert_to_dtype(start, torch.float64) + if isinstance(end, TensorLikeType): + torch._check( + end.dim() == 0, + lambda: "linspace only supports 0-dimensional start and end tensors", + ) + end = _maybe_convert_to_dtype(end, torch.float64) + + if builtins.any(isinstance(arg, complex) for arg in (start, end, steps)): + default_complex_dtype = utils.corresponding_complex_dtype( + torch.get_default_dtype() + ) + if dtype is None: + dtype = 
default_complex_dtype + else: + torch._check( + utils.is_complex_dtype(dtype), + lambda: f"linspace(): inferred dtype {default_complex_dtype} can't be safely cast to passed dtype {dtype}", + ) + else: + dtype = dtype or torch.get_default_dtype() + assert isinstance(dtype, torch.dtype) + + # steps does not participate in the computation of the dtype + torch._check_type( + isinstance(steps, IntLike), + lambda: f"received an invalid combination of arguments - got \ +({type(start).__name__}, {type(end).__name__}, {type(steps).__name__})", + ) + assert isinstance(steps, IntLike) # for mypy + torch._check(steps >= 0, lambda: "number of steps must be non-negative") + + factory_kwargs = { + "layout": layout, + "device": device, + "pin_memory": pin_memory, + "requires_grad": requires_grad, + } + if steps == 0: + return torch.full((0,), 0, dtype=dtype, **factory_kwargs) # type: ignore[arg-type] + if steps == 1: + if isinstance(start, TensorLikeType): + empty_tensor = torch.empty((steps,), dtype=dtype, **factory_kwargs) # type: ignore[arg-type] + return torch.ops.aten.copy.default(empty_tensor, start) + else: + return torch.full((steps,), start, dtype=dtype, **factory_kwargs) # type: ignore[arg-type] + + # Perform in arange in int because some backends like ATen or Triton do not support all the dtypes + rg = torch.arange(0, steps, **factory_kwargs) # type: ignore[arg-type] + + # Small types need to be computed in higher precision as this is, at heart, an associative scan + dtype_red = ( + torch.int64 + if (utils.is_boolean_dtype(dtype) or utils.is_integer_dtype(dtype)) + else dtype + ) + computation_dtype, _ = utils.reduction_dtypes( + rg, REDUCTION_OUTPUT_TYPE_KIND.SAME, dtype_red + ) + cast_rg = partial(_maybe_convert_to_dtype, dtype=computation_dtype) + + # We implement torch.lerp without performing rg / (steps - 1) explicitly + # With this we get out[0] == start, out[-1] == end + step = (end - start) / (steps - 1) + out = torch.where( + rg < steps / 2, + start + step * 
cast_rg(rg), # type: ignore[arg-type,operator] + end - step * cast_rg((steps - 1) - rg), # type: ignore[arg-type,operator] + ) + return _maybe_convert_to_dtype(out, dtype) # type: ignore[return-value] + + +@register_decomposition(aten.logspace) +@out_wrapper() +def logspace( + start: Union[NumberType, TensorLikeType], + end: Union[NumberType, TensorLikeType], + steps: NumberType, + base: NumberType = 10, + *, + dtype: Optional[torch.dtype] = None, + device: Optional[DeviceLikeType] = None, + layout: torch.layout = torch.strided, + pin_memory: bool = False, + requires_grad: bool = False, +) -> TensorLikeType: + if dtype is None: + dtype = torch.get_default_dtype() + + # NB: NumPy doesn't have this cast + if prims.utils.is_integer_dtype(dtype): + if isinstance(start, FloatLike): + start = sym_int(start) + elif isinstance(start, TensorLikeType): + torch._check( + start.dim() == 0, + lambda: "logspace only supports 0-dimensional start and end tensors", + ) + start = _maybe_convert_to_dtype(start, dtype) + if isinstance(end, FloatLike): + end = sym_int(end) + elif isinstance(end, TensorLikeType): + torch._check( + end.dim() == 0, + lambda: "logspace only supports 0-dimensional start and end tensors", + ) + end = _maybe_convert_to_dtype(end, dtype) + + if builtins.any(isinstance(arg, complex) for arg in (start, end, steps)): + default_complex_dtype = utils.corresponding_complex_dtype( + torch.get_default_dtype() + ) + dtype = default_complex_dtype + _dtype = None # torch.linspace will update the correct dtype + else: + _dtype = torch.float64 + + assert not isinstance(base, complex) # for mypy + if base < 0: + raise NotImplementedError + ret = torch.linspace( # type: ignore[misc] + start, # type: ignore[arg-type] + end, # type: ignore[arg-type] + steps, # type: ignore[arg-type] + dtype=_dtype, + layout=layout, + device=device, + pin_memory=pin_memory, + requires_grad=requires_grad, + ) + return _maybe_convert_to_dtype(torch.pow(base, ret), dtype) # type: 
ignore[arg-type,return-value] + + +@overload +# pyrefly: ignore [inconsistent-overload] +def meshgrid(tensors: Sequence[TensorLikeType], indexing: str): + pass + + +@overload +def meshgrid(*tensors: TensorLikeType, indexing: str): + pass + + +@register_decomposition(aten.meshgrid) # type: ignore[misc] +def meshgrid( + *tensors: Union[TensorLikeType, list[TensorLikeType], tuple[TensorLikeType]], + indexing: str, +) -> list[TensorLikeType]: + # This ref simultaneously handles two overloads (see stubs above) + # The `indexing` argument is currently optional for torch.meshgrid, but we + # plan to make the argument required: https://github.com/pytorch/pytorch/issues/50276 + if isinstance(tensors[0], (list, tuple)): + assert len(tensors) == 1 + tensors = tuple(tensors[0]) + + torch._check( + builtins.all(isinstance(a, TensorLike) for a in tensors), + lambda: "meshgrid expects its inputs to be tensors", + ) + + torch._check(len(tensors) > 0, lambda: "meshgrid expects a non-empty TensorList") + + for i in range(len(tensors) - 1): + torch._check( + tensors[i].dtype == tensors[i + 1].dtype, # type: ignore[union-attr] + lambda: "meshgrid expects all tensors to have the same dtype", + ) + torch._check( + tensors[i].device == tensors[i + 1].device, # type: ignore[union-attr] + lambda: "meshgrid expects all tensors to have the same device", + ) + + swap_first_and_second_tensors = False + if indexing == "xy": + swap_first_and_second_tensors = len(tensors) >= 2 + if swap_first_and_second_tensors: + tensors = (tensors[1], tensors[0], *tensors[2:]) + else: + torch._check( + indexing == "ij", + lambda: ( + 'torch.meshgrid: indexing must be one of "xy" or "ij", ' + f"but received: {indexing}" + ), + ) + + result_shape: list[int] = [] + for t in tensors: + assert isinstance(t, TensorLike) # mypy + torch._check( + t.ndim == 0 or t.ndim == 1, + lambda: f"torch.meshgrid: Expected 0D or 1D tensor in the tensor list but got: {t}", + ) + result_shape.append(t.numel()) + + grids: 
list[TensorLikeType] = [] + for i, t in enumerate(tensors): + assert isinstance(t, TensorLike) # mypy + if t.ndim == 0: + t = t.view((1,)) + grids.append(prims.broadcast_in_dim(t, result_shape, (i,))) + + if swap_first_and_second_tensors: + # Swap outputs if we originally swapped at the beginning + grids[0], grids[1] = grids[1], grids[0] + + return grids + + +# CompositeImplicitAutograd - don't register decomp +def movedim( + input: TensorLikeType, + source: Union[int, DimsSequenceType], + destination: Union[int, DimsSequenceType], +) -> TensorLikeType: + """ + Reference implementation of torch.movedim + """ + if type(source) is int: + source = (source,) + if type(destination) is int: + destination = (destination,) + + # Converts to list to produce a compatible error message with core PyTorch, + # which prints sequences in square brackets. + torch._check( + len(source) == len(destination), # type: ignore[arg-type] + lambda: ( + "movedim: Invalid source or destination dims: source " # type: ignore[arg-type] + f"({list(source)} dims) should contain the same number " # type: ignore[arg-type] + f"of dims as destination ({list(destination)} dims)" # type: ignore[arg-type] + ), + ) + + rank = input.ndim + ss = tuple(utils.canonicalize_dims(rank=rank, indices=source)) # type: ignore[arg-type] + ds = tuple(utils.canonicalize_dims(rank=rank, indices=destination)) # type: ignore[arg-type] + + sss = set(ss) + dss = set(ds) + + # See above on why this converts to list in error messages. 
+ torch._check( + len(ss) == len(sss), + lambda: f"movedim: repeated dim in `source` ({list(source)})", # type: ignore[arg-type] + ) + torch._check( + len(ds) == len(dss), + lambda: f"movedim: repeated dim in `destination` ({list(destination)})", # type: ignore[arg-type] + ) + + m = dict(zip(ds, ss)) + dims = [] + si = 0 # source index + for di in range(rank): + # check if the destination index is in the mapping + s = m.get(di) + if s is not None: + # insert source index if found + dims.append(s) + else: + # insert source index sequentially, skipping indices from the mapping + while si in sss: + si += 1 + dims.append(si) + si += 1 + + result = torch.permute(input, tuple(dims)) + + return result + + +# NOTE: for convenience, shape can be a tuple of ints or a tuple containing a tuple of ints +@register_decomposition(aten.empty_strided) +@out_wrapper() +def empty_strided( + shape: Union[ShapeType, tuple[ShapeType]], + strides: StrideType, + *, + dtype: Optional[torch.dtype] = None, + device: Optional[DeviceLikeType] = None, + layout: torch.layout = torch.strided, + requires_grad: bool = False, + pin_memory: bool = False, +) -> TensorLikeType: + # Layout == strided, pin_memory is False + utils.check_layout(layout) + utils.check_pin_memory(pin_memory) + + shape = utils.extract_shape_from_varargs(shape) + dtype = torch.get_default_dtype() if dtype is None else dtype + device = torch.device("cpu") if device is None else device + + return prims.empty_strided( + shape, + strides, + dtype=dtype, + device=device, + requires_grad=requires_grad, + ) + + +def _strength_reduce_integer(val: int) -> torch.dtype: + for possible_dtype in (torch.uint8, torch.uint16, torch.int32): + if val <= torch.iinfo(possible_dtype).max: + return possible_dtype + return torch.int64 + + +@register_decomposition(aten.eye) +@out_wrapper() +def eye( + n: int, + m: Optional[int] = None, + *, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + device: Optional[DeviceLikeType] 
= None, + pin_memory: bool = False, + requires_grad: bool = False, # TODO: unused +) -> TensorLikeType: + """ + Reference implementation of torch.eye + """ + if m is None: + m = n + + torch._check(n >= 0, lambda: f"n must be greater or equal to 0, got {n}") + torch._check(m >= 0, lambda: f"m must be greater or equal to 0, got {m}") + + range_dtype = torch.int64 + if isinstance(n, utils.IntWithoutSymInt) and isinstance(m, utils.IntWithoutSymInt): + range_dtype = _strength_reduce_integer(max(n, m)) + range_n = torch.arange(n, dtype=range_dtype, device=device, requires_grad=False) + range_m = torch.arange(m, dtype=range_dtype, device=device, requires_grad=False) + + cond = range_n.unsqueeze(-1) == range_m + if layout in (torch.strided, None) and not pin_memory: + return cond.to(dtype or torch.get_default_dtype()) + else: + one = torch.ones( + (1,), + dtype=dtype, + layout=layout, + device=device, + pin_memory=pin_memory, + requires_grad=False, + ) + return torch.where(cond, one, 0) + # TODO: Use requires_grad. All refs taking the requires_grad kwarg must + # return a leaf tensor. 
+ # result.requires_grad_(requires_grad) + + +@register_decomposition([aten.full.default, aten.full.out]) +@out_wrapper() +def full( + shape: ShapeType, + fill_value: NumberType, + *, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + device: Optional[DeviceLikeType] = None, + pin_memory: bool = False, + requires_grad: bool = False, +) -> TensorLikeType: + utils.check_layout(layout) + utils.check_pin_memory(pin_memory) + + dtype = dtype if dtype is not None else utils.type_to_dtype(type(fill_value)) + device = device if device is not None else torch.device("cpu") + + e = empty( + shape, + dtype=dtype, + layout=layout, + device=device, + pin_memory=pin_memory, + requires_grad=requires_grad, + ) + return torch.fill(e, fill_value) # type: ignore[arg-type] + + +def full_like( + a: TensorLikeType, + fill_value: NumberType, + *, + dtype: Optional[torch.dtype] = None, + layout: Optional[torch.layout] = None, + device: Optional[DeviceLikeType] = None, + pin_memory: bool = False, + requires_grad: bool = False, + memory_format: torch.memory_format = torch.preserve_format, +) -> TensorLikeType: + e = torch.empty_like( + a, + dtype=dtype, + layout=layout, + device=device, + pin_memory=pin_memory, + requires_grad=requires_grad, + memory_format=memory_format, + ) + return fill(e, fill_value) + + +@register_decomposition(aten.zeros_like) +@out_wrapper() +def zeros_like( + a: TensorLikeType, + *, + dtype: Optional[torch.dtype] = None, + layout: Optional[torch.layout] = None, + device: Optional[DeviceLikeType] = None, + pin_memory: bool = False, + requires_grad: bool = False, + memory_format: torch.memory_format = torch.preserve_format, +) -> TensorLikeType: + return torch.full_like( + a, + False if (dtype or a.dtype) == torch.bool else 0, + dtype=dtype, + layout=layout, + device=device, + pin_memory=pin_memory, + requires_grad=requires_grad, + memory_format=memory_format, + ) + + +@register_decomposition(aten.ones_like) +@out_wrapper() +def ones_like( 
+ a: TensorLikeType, + *, + dtype: Optional[torch.dtype] = None, + layout: Optional[torch.layout] = None, + device: Optional[DeviceLikeType] = None, + pin_memory: bool = False, + requires_grad: bool = False, + memory_format: torch.memory_format = torch.preserve_format, +) -> TensorLikeType: + return torch.full_like( + a, + True if (dtype or a.dtype) == torch.bool else 1, + dtype=dtype, + layout=layout, + device=device, + pin_memory=pin_memory, + requires_grad=requires_grad, + memory_format=memory_format, + ) + + +@register_decomposition(aten.randn.default) +@out_wrapper() +def randn( + *shape, + dtype: Optional[torch.dtype] = None, + device: Optional[DeviceLikeType] = None, + layout: Optional[torch.layout] = None, + requires_grad: bool = False, + pin_memory: bool = False, +) -> TensorLikeType: + utils.check_pin_memory(pin_memory) + + shape_ = utils.extract_shape_from_varargs(shape) + + dtype = utils.dtype_or_default(dtype) + device = utils.device_or_default(device) + + return prims.normal( + shape_, + mean=0.0, + std=1.0, + dtype=dtype, + device=device, + requires_grad=requires_grad, + ) + + +def scalar_tensor( + a: NumberType, + *, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + device: Optional[DeviceLikeType] = None, + pin_memory: bool = False, +) -> TensorLikeType: + utils.check_layout(layout) + utils.check_pin_memory(pin_memory) + dtype = dtype if dtype is not None else utils.type_to_dtype(type(a)) + device = device if device is not None else torch.device("cpu") + return prims.scalar_tensor(a, dtype=dtype, device=device) + + +# +# Randomness References +# + + +def _uniform_helper( + shape: ShapeType, + low: Union[bool, int, float] = 0.0, + high: Union[bool, int, float] = 1.0, + *, + dtype: torch.dtype, + device: DeviceLikeType, +) -> TensorLikeType: + utils.validate_shape(shape) + + assert isinstance(low, Number) + assert isinstance(high, Number) + low = sym_float(low) + high = sym_float(high) + + assert isinstance(dtype, 
torch.dtype) + device = utils.canonicalize_device(device) + + return prims._uniform_helper(shape, low=low, high=high, dtype=dtype, device=device) + + +@register_decomposition(aten.masked_fill) +@out_wrapper() +def masked_fill(a: TensorLikeType, mask: TensorLikeType, value: TensorOrNumberLikeType): + python_type = utils.dtype_to_type(a.dtype) + if isinstance(value, Number): + value_type = type(value) + else: + # NOTE: Could not use value = item(value) as it resulted in + # RuntimeError: Cannot cast FakeTensor(cpu) to number + value_ndim = value.ndim + torch._check( + value_ndim == 0, + lambda: f"only supports a 0-dimensional value tensor, but got tensor with {value_ndim} dimension", + ) + # `masked_fill` allows cpu scalar to be moved to cuda, xpu and hpu but not otherwise. + is_cpu_scalar = ( + a.device.type + in ["cuda", "xpu", "mps", torch._C._get_privateuse1_backend_name(), "hpu"] + and value.device.type == "cpu" + ) + torch._check( + is_cpu_scalar or value.device == a.device, + lambda: "Expected `value` to be on same device as `a`", + ) + value_type = utils.dtype_to_type(value.dtype) + + if value_type is complex: + # only downcasting from complex to lower type is not allowed. + # We allow casting `value` to lower type for other case + # Eg. float -> int. 
+ # Ref: https://github.com/pytorch/pytorch/issues/79195 + torch._check( + utils.is_weakly_lesser_type(value_type, python_type), + lambda: f"could not convert to type {python_type} without overflow", + ) + + # Since `where` allows type-promotion, + # cast value to correct type before passing to `where` + # pyrefly: ignore [no-matching-overload] + value = _maybe_convert_to_dtype(value, a.dtype) + r = torch.where(mask, value, a) # type: ignore[arg-type] + + # aten.mask_fill always return a new contiguous tensor + # contiguous() is needed to correctly model the output stride + return r.contiguous() + + +@register_decomposition(aten.masked_fill_) +def masked_fill_( + a: TensorLikeType, mask: TensorLikeType, value: TensorOrNumberLikeType +) -> TensorLikeType: + b = torch.masked_fill(a, mask, value) # type: ignore[arg-type] + a.copy_(b) + return a + + +# CompositeImplicitAutograd - don't register decomp +def allclose( + a: TensorLikeType, + b: TensorLikeType, + rtol: float = 1e-05, + atol: float = 1e-08, + equal_nan: bool = False, +) -> bool: + """ + Reference implementation of torch.allclose + """ + _check_close_args(name="torch.allclose", a=a, b=b, rtol=rtol, atol=atol) + + return bool( + torch.all(torch.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)).item() + ) + + +def equal(a: TensorLikeType, b: TensorLikeType) -> bool: + utils.check_same_device(a, b, allow_cpu_scalar_tensors=False) + utils.check_same_dtype(a, b) + + # Shape check + if a.ndim != b.ndim: + return False + + for x, y in zip(a.shape, b.shape): + if x != y: + return False + + # Short-circuits if there are no elements to validate + if a.numel() == 0: + return True + + return item(all(eq(a, b))) # type: ignore[return-value] + + +@register_decomposition(aten.norm) +@out_wrapper(exact_dtype=True) +def norm( + input: TensorLikeType, + p: Optional[Union[float, str]] = "fro", + dim: Optional[DimsType] = None, + keepdim: bool = False, + *, + dtype: Optional[torch.dtype] = None, +) -> TensorLikeType: + 
# In these cases we compute the "Frobenius norm" + if ( + p == "fro" and (dim is None or isinstance(dim, Dim) or len(dim) <= 2) + ) or p is None: + p = 2 + if isinstance(dim, Dim): + dim = [dim] + if isinstance(p, str): + # Here we either call the nuclear norm, or we call matrix_norm with some arguments + # that will throw an error + if dim is None: + dim = tuple(range(input.ndim)) + return torch.linalg.matrix_norm(input, p, dim, keepdim, dtype=dtype) + else: + return torch.linalg.vector_norm(input, p, dim, keepdim, dtype=dtype) + + +@register_decomposition(aten.trace) +@out_wrapper() +def trace(self: TensorLikeType) -> TensorLikeType: + torch._check( + self.ndim == 2, + lambda: f"expected a matrix, but got tensor with dim {self.ndim}", + ) + return torch.sum(torch.diag(self, 0)) + + +def _make_r_binary_op(base_op): + def rop( + a: Union[TensorLikeType, NumberType], + b: Union[TensorLikeType, NumberType], + ) -> TensorLikeType: + return base_op(b, a) + + return rop + + +rtruediv = _make_r_binary_op(true_divide) +rfloordiv = _make_r_binary_op(floor_divide) +rpow = _make_r_binary_op(pow) + + +@register_decomposition(aten.triu) +@out_wrapper() +def triu(a: TensorLikeType, diagonal: int = 0) -> TensorLikeType: + torch._check( + a.ndim >= 2, lambda: "triu: input tensor must have at least 2 dimensions" + ) + h, w = a.shape[-2:] + mask = ( + torch.arange(w, device=a.device).unsqueeze(-2) + - torch.arange(h, device=a.device).unsqueeze(-1) + ) >= diagonal + + # aten.triu always returns a new contiguous tensor + # contiguous() is needed to correctly model the output stride + return utils.mask_tensor(mask, a).contiguous() + + +@register_decomposition(aten.tril) +@out_wrapper() +def tril(a: TensorLikeType, diagonal: int = 0) -> TensorLikeType: + torch._check( + a.ndim >= 2, lambda: "tril: input tensor must have at least 2 dimensions" + ) + h, w = a.shape[-2:] + mask = ( + torch.arange(w, device=a.device).unsqueeze(-2) + - torch.arange(h, device=a.device).unsqueeze(-1) + ) <= 
diagonal + + # aten.tril always returns a new contiguous tensor + # contiguous() is needed to correctly model the output stride + return utils.mask_tensor(mask, a).contiguous() + + +# This is based on get_tril_size in aten/src/ATen/native/TensorFactories.h +# The components of the matrix that belong to the lower triangle with offset +# form a pentagon that can be broken down into a top trapezoid and a bottom +# rectangle. For the implementation of tril_indices, we need the sizes of +# both of these, as well as the length of the top side of the trapezoid. +def _get_tril_sizes(row: int, col: int, offset: int) -> tuple[int, int, int]: + if row == 0 or col == 0: + return 0, 0, 0 + + m_first_row = min(col, 1 + offset) if offset > 0 else int(row + offset > 0) + m_last_row = max(0, min(col, row + offset)) + n_row_all = max(0, min(row, row + offset)) + n_row_trapezoid = m_last_row - m_first_row + 1 + + # Number of elements in top trapezoid + trapezoid_size = (m_first_row + m_last_row) * n_row_trapezoid // 2 + # Number of elements in bottom rectangle + diff_row = n_row_all - n_row_trapezoid + rectangle_size = max(0, diff_row * col) + + return trapezoid_size, rectangle_size, m_first_row + + +def _trilu_checks( + name: str, + row: int, + col: int, + dtype: torch.dtype, + layout: torch.layout, + pin_memory: bool, +): + torch._check(row >= 0, lambda: f"row must be non-negative, got {row}") + torch._check(col >= 0, lambda: f"col must be non-negative, got {col}") + torch._check( + dtype in (torch.int32, torch.int64), + lambda: f"\"{name}\" not implemented for '{dtype}'", + ) + + +# This is based on tril_indices_cuda in aten/src/ATen/native/cuda/TensorFactories.cu +@register_decomposition(aten.tril_indices) +@out_wrapper() +def tril_indices( + row: int, + col: int, + offset: int = 0, + *, + dtype: torch.dtype = torch.long, + layout: torch.layout = torch.strided, + device: DeviceLikeType = "cpu", + pin_memory: bool = False, +) -> TensorLikeType: + _trilu_checks("tril_indices", row, 
col, dtype, layout, pin_memory) + + trapezoid_size, rectangle_size, m_first_row = _get_tril_sizes(row, col, offset) + row_offset = max(0, -offset) + + arange_kw = partial( + torch.arange, layout=layout, device=device, pin_memory=pin_memory + ) + + # first we do the indices for top trapezoid + xs1 = arange_kw(0, trapezoid_size, dtype=torch.float64) + b = m_first_row - 0.5 + row_inds1 = torch.floor(-b + torch.sqrt(b * b + 2 * xs1)) + col_inds1 = torch.floor(xs1 - (2 * m_first_row - 1 + row_inds1) * row_inds1 * 0.5) + row_inds1 = _maybe_convert_to_dtype(row_inds1 + row_offset, dtype) + col_inds1 = _maybe_convert_to_dtype(col_inds1, dtype) + + # then bottom rectangle + xs2 = arange_kw(0, rectangle_size, dtype=dtype) + row_inds2 = xs2 // col + (col - m_first_row + 1 + row_offset) + col_inds2 = xs2 % col + + return torch.stack( + (torch.cat((row_inds1, row_inds2)), torch.cat((col_inds1, col_inds2))) + ) + + +# Similar to _get_tril_sizes above, but here there is a top trapezoid and +# a bottom rectangle instead. Note that you can't reduce this to +# _get_tril_sizes(col, row, -offset) because that would correspond to +# decomposing into a left trapezoid and right rectangle. 
+def _get_triu_sizes(row: int, col: int, offset: int) -> tuple[int, int, int]: + if row == 0 or col == 0: + return 0, 0, 0 + + m_first_row = max(0, col - offset) if offset > 0 else col + + # Number of elements in top rectangle + rectangle_size = max(0, min(row, -offset) * col) + + # Number of elements in bottom trapezoid + trapezoid_size_tril, rectangle_size_tril, _ = _get_tril_sizes(row, col, offset - 1) + triu_size = row * col - (trapezoid_size_tril + rectangle_size_tril) + trapezoid_size = triu_size - rectangle_size + + return trapezoid_size, rectangle_size, m_first_row + + +@register_decomposition(aten.triu_indices) +@out_wrapper() +def triu_indices( + row: int, + col: int, + offset: int = 0, + *, + dtype: torch.dtype = torch.long, + layout: torch.layout = torch.strided, + device: DeviceLikeType = "cpu", + pin_memory: bool = False, +) -> TensorLikeType: + _trilu_checks("triu_indices", row, col, dtype, layout, pin_memory) + + trapezoid_size, rectangle_size, m_first_row = _get_triu_sizes(row, col, offset) + col_offset = max(0, offset) + + arange_kw = partial( + torch.arange, layout=layout, device=device, pin_memory=pin_memory + ) + + # indices for top rectangle + xs2 = arange_kw(0, rectangle_size, dtype=dtype) + row_inds2 = xs2 // col + col_inds2 = xs2 % col + + # bottom trapezoid + xs1 = arange_kw(0, trapezoid_size, dtype=torch.float64) + b = -0.5 - m_first_row + row_inds1 = torch.floor(-b - torch.sqrt(b * b - 2 * xs1)) + col_inds1 = torch.floor(xs1 - ((2 * m_first_row - 1 - row_inds1) * row_inds1) * 0.5) + row_inds1 = _maybe_convert_to_dtype(row_inds1, dtype) + col_inds1 = _maybe_convert_to_dtype(col_inds1, dtype) + + if col: + row_inds1 = row_inds1 + (rectangle_size // col) + col_inds1 = col_inds1 + col_offset + + return torch.stack( + (torch.cat((row_inds2, row_inds1)), torch.cat((col_inds2, col_inds1))) + ) + + +@register_decomposition(aten.bucketize) +@out_wrapper(exact_dtype=True) +def bucketize( + a: TensorOrNumberLikeType, + boundaries: TensorLikeType, + 
*, + out_int32: bool = False, + right: bool = False, +): + torch._check( + boundaries.dim() == 1, + lambda: f"boundaries tensor must be 1 dimension but got dim({boundaries.dim()})", + ) + + a = a if isinstance(a, torch.Tensor) else torch.tensor(a) + out_dtype = torch.int32 if out_int32 else torch.int64 + n_boundaries = boundaries.shape[-1] + if n_boundaries == 0: + return torch.zeros_like(a) + # We are trying to find the bucket (defined by pairs of consecutive elements of `boundaries`) + # each element of `a` belongs to. We use binary search to achieve logarithmic complexity, + # but each step of the search is done "in parallel" over all elements of `a` + # can't use int32 as indexes, so we have to do all computations with int64 and convert at the end + start = torch.zeros(a.shape, device=a.device, dtype=torch.int64) + end = start + n_boundaries + # Max depth of the binary search + # Since we can't break out of the loop at different points for different elements of a, + # we just do the max amount of iterations that binary search requires and add condition + # tensor (cond_update below) to stop updating once the search terminates + + # For first iteration through loop we can skip some checks, we have separate implementation + mid = start + (end - start) // 2 + mid_val = boundaries[mid] + if right: + cond_mid = mid_val > a + else: + cond_mid = mid_val >= a + start = torch.where(cond_mid, start, mid + 1) + + if n_boundaries > 1: + cond_update = torch.ones_like(a, dtype=torch.bool) + niters = int(math.log2(n_boundaries)) + for _ in range(niters): + end = torch.where(cond_mid & cond_update, mid, end) + cond_update = start < end + # start might end up pointing to 1 past the end, we guard against that + mid = torch.where(cond_update, start + (end - start) // 2, 0) + mid_val = boundaries[mid] + # If right is true, the buckets are closed on the *left* + # (i.e., we are doing the equivalent of std::upper_bound in C++) + # Otherwise they are closed on the right 
(std::lower_bound) + if right: + cond_mid = mid_val > a + else: + cond_mid = mid_val >= a + start = torch.where((~cond_mid) & cond_update, mid + 1, start) + + return start.to(dtype=out_dtype) + + +@register_decomposition(aten.cauchy) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("self",), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def cauchy(self, median=0, sigma=1, generator=None): + assert generator is None + torch._check( + not utils.is_complex_dtype(self.dtype) + and not utils.is_integer_dtype(self.dtype) + and not utils.is_boolean_dtype(self.dtype), + lambda: f"Cauchy distribution is a continuous probability distribution. \ + dtype must be a floating point but you specified {self.dtype}", + ) + torch._check( + sigma > 0.0, + lambda: f"cauchy_ expects sigma > 0.0, but found sigma={sigma}", + ) + return median + sigma * torch.tan(math.pi * (torch.rand_like(self) - 0.5)) + + +@register_decomposition(aten.exponential) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("self",), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def exponential(self, rate=1, generator=None): + assert generator is None + torch._check( + not utils.is_complex_dtype(self.dtype) + and not utils.is_integer_dtype(self.dtype) + and not utils.is_boolean_dtype(self.dtype), + lambda: f"Exponential distribution is a continuous probability distribution. \ + dtype must be a floating point but you specified {self.dtype}", + ) + torch._check( + rate > 0.0, + lambda: f"exponential_ expects lambda > 0.0, but found lambda={rate}", + ) + + uniform_val = torch.rand_like(self) + + # copying numerics of transformation::exponential see comment: + # curand_uniform has (0,1] bounds. log(1) is 0 and exponential excludes 0. 
+ # we need log to be not 0, and not underflow when converted to half + # fast __logf approximation can underflow, so set log to -epsilon/2 for 1 or close to 1 args + epsilon = torch.finfo(uniform_val.dtype).eps / 2 + condition = uniform_val >= 1.0 - epsilon + log_uniform = torch.where(condition, -epsilon, torch.log(uniform_val)) + + return -1 / rate * log_uniform + + +@register_decomposition(aten.geometric) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("self",), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def geometric(self, p, generator=None): + assert generator is None + # TODO: fix inductor rand_like for integer, bool dtypes + torch._check( + not utils.is_complex_dtype(self.dtype) + and not utils.is_boolean_dtype(self.dtype), + lambda: f"geometric not implemented for {self.dtype}", + ) + torch._check( + 0 < p and p < 1, + lambda: f"geometric_ expects p to be in (0, 1), but got p={p}", + ) + return torch.floor(torch.log1p(-torch.rand_like(self)) / math.log1p(-p)) + 1 + + +@register_decomposition(aten.log_normal) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("self",), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def log_normal(self, mean=1, std=2, generator=None): + assert generator is None + torch._check( + not utils.is_complex_dtype(self.dtype) + and not utils.is_integer_dtype(self.dtype) + and not utils.is_boolean_dtype(self.dtype), + lambda: f"log_normal not implemented for {self.dtype}", + ) + torch._check( + 0 < std, + lambda: f"log_normal_ expects std > 0.0, but found std={std}", + ) + return torch.exp(std * torch.randn_like(self) + mean) + + +# TODO: add support for functionalization aten.normal_functional +# NOTE: the device and dtype will be ignored when shape is None +@register_decomposition(aten.normal) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=( + "mean", + "std", + ), + 
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def normal( + mean=0, + std=1, + size=None, + *, + generator=None, + dtype=None, + layout=None, + device=None, + pin_memory=None, +): + assert layout is None or layout == torch.strided + + if not isinstance(std, TensorLike): + torch._check( + std >= 0, lambda: f"normal expects std >= 0.0, but found std {std}" + ) + + if size is None: + tensors = tuple(t for t in (mean, std) if isinstance(t, TensorLike)) + torch._check( + len(tensors) > 0, + lambda: "normal expects that either mean or std is a tensor, or size is defined", + ) + torch._check( + layout is None and pin_memory is None, + lambda: "Cannot pass layout, or pin_memory without size", + ) + + size = _broadcast_shapes(*(t.shape for t in tensors)) + dtype = tensors[0].dtype + device = tensors[0].device + else: + torch._check( + not isinstance(mean, TensorLike) and not isinstance(std, TensorLike), + lambda: "normal expects mean and std to be scalars when size is defined", + ) + dtype = torch.get_default_dtype() if dtype is None else dtype + device = torch.device("cpu") if device is None else device + + normal_samples = prims.normal( + size, + mean=0.0, + std=1.0, + dtype=dtype, + device=device, + requires_grad=False, + generator=generator, + ) + return std * normal_samples + mean + + +@register_decomposition(aten.normal_) +def normal_(self, mean=0, std=1, *, generator=None): + return normal(mean, std, self.shape, out=self, generator=generator) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def rad2deg(self: TensorLikeType): + torch._check( + not utils.is_complex_dtype(self.dtype), + lambda: "rad2deg is not supported for complex tensors.", + ) + M_180_PI = 57.295779513082320876798154814105170332405472466564 + return self * M_180_PI + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def deg2rad(self: TensorLikeType): + torch._check( + not utils.is_complex_dtype(self.dtype), + 
lambda: "deg2rad is not supported for complex tensors.", + ) + M_PI_180 = 0.017453292519943295769236907684886127134428718885417 + return self * M_PI_180 + + +@register_decomposition(aten.count_nonzero) +@out_wrapper() +def count_nonzero(self, dim: Optional[DimsType] = None): + return (self != 0).sum(dim) + + +def _dot_check(self, other): + torch._check( + self.dim() == 1 and other.dim() == 1, + lambda: f"1D tensors expected, but got {self.dim()}D and {other.dim()}D tensors", + ) + + torch._check( + self.dtype == other.dtype, + lambda: "dot : expected both vectors to have same dtype, but found " + f"{self.dtype} and {other.dtype}", + ) + + def numel_error(): + return ( + f"inconsistent tensor size, expected tensor [{self.numel()}] and src [{other.numel()}] to have the" + f"same number of elements, but got {self.numel()} and {other.numel()} elements respectively" + ) + + torch._check(self.numel() == other.numel(), numel_error) + + +def _dot_check_wrapper(fn): + @wraps(fn) + def wrapper(self, other): + _dot_check(self, other) + return fn(self, other) + + return wrapper + + +@register_decomposition(aten.dot) +@out_wrapper(exact_dtype=True) +@_dot_check_wrapper +@elementwise_type_promotion_wrapper( + type_promoting_args=("self", "other"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def dot(self, other): + if self.is_complex(): + if self.is_conj(): + if other.is_conj(): + return torch.dot(self.conj(), other.conj()).conj() + else: + return torch.vdot(self.conj(), other) + elif other.is_conj(): + return torch.vdot(other.conj(), self) + + return (self * other).sum() + + +@register_decomposition(aten.vdot) +@out_wrapper(exact_dtype=True) +@_dot_check_wrapper +@elementwise_type_promotion_wrapper( + type_promoting_args=("self", "other"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def vdot(self, other): + if not self.is_complex(): + return torch.dot(self, other) + + if self.is_conj(): + if other.is_conj(): + return 
torch.vdot(other.conj(), self.conj()) + else: + return torch.dot(self.conj(), other) + elif other.is_conj(): + return torch.dot(self, other.conj()).conj() + + # The decomposition fails if you do self.conj()... not sure why + return (self.conj_physical() * other).sum() + + +@register_decomposition(aten.select_scatter) +@out_wrapper() +def select_scatter(x: TensorLikeType, src: TensorLikeType, dim: int, index: int): + dim = utils.canonicalize_dim(x.ndim, dim) + mask_shape = [1] * x.ndim + mask_shape[dim] = -1 + if index < 0: + index = index + x.shape[dim] + mask = torch.arange(x.shape[dim], device=x.device).view(mask_shape) == index + src = torch.unsqueeze(src, dim).expand(x.shape) + return torch.where(mask, src, x) + + +# inplace +abs_ = _make_inplace(abs) +acos_ = _make_inplace(acos) +acosh_ = _make_inplace(acosh) +add_ = _make_inplace(add) +addcmul_ = _make_inplace(addcmul) +addcdiv_ = _make_inplace(addcdiv) +asin_ = _make_inplace(asin) +asinh_ = _make_inplace(asinh) +atan_ = _make_inplace(atan) +atanh_ = _make_inplace(atanh) +atan2_ = _make_inplace(atan2) +bitwise_and_ = _make_inplace(bitwise_and) +bitwise_left_shift_ = _make_inplace(bitwise_left_shift) +bitwise_not_ = _make_inplace(bitwise_not) +bitwise_or_ = _make_inplace(bitwise_or) +bitwise_right_shift_ = _make_inplace(bitwise_right_shift) +bitwise_xor_ = _make_inplace(bitwise_xor) +ceil_ = _make_inplace(ceil) +clamp_ = _make_inplace(clamp) +clamp_min_ = _make_inplace(clamp_min) +clamp_max_ = _make_inplace(clamp_max) +conj_physical_ = _make_inplace(conj_physical) +copysign_ = _make_inplace(copysign) +cos_ = _make_inplace(cos) +cosh_ = _make_inplace(cosh) +cumsum_ = _make_inplace(cumsum) +cumprod_ = _make_inplace(cumprod) +deg2rad_ = _make_inplace(deg2rad) +digamma_ = _make_inplace(digamma) +div_ = _make_inplace(div) +eq_ = _make_inplace(eq) +erf_ = _make_inplace(erf) +erfc_ = _make_inplace(erfc) +erfinv_ = _make_inplace(erfinv) +exp_ = _make_inplace(exp) +exp2_ = _make_inplace(exp2) +expm1_ = 
_make_inplace(expm1) +float_power_ = _make_inplace(float_power) +floor_ = _make_inplace(floor) +floor_divide_ = _make_inplace(floor_divide) +fmod_ = _make_inplace(fmod) +frac_ = _make_inplace(frac) +gcd_ = _make_inplace(gcd) +ge_ = _make_inplace(ge) +gt_ = _make_inplace(gt) +heaviside_ = _make_inplace(heaviside) +hypot_ = _make_inplace(hypot) +igamma_ = _make_inplace(igamma) +igammac_ = _make_inplace(igammac) +i0_ = _make_inplace(i0) +lcm_ = _make_inplace(lcm) +le_ = _make_inplace(le) +lerp_ = _make_inplace(lerp) +lgamma_ = _make_inplace(lgamma) +log10_ = _make_inplace(log10) +log1p_ = _make_inplace(log1p) +log2_ = _make_inplace(log2) +log_ = _make_inplace(log) +logical_and_ = _make_inplace(logical_and) +logical_not_ = _make_inplace(logical_not) +logical_or_ = _make_inplace(logical_or) +logical_xor_ = _make_inplace(logical_xor) +lt_ = _make_inplace(lt) +mul_ = _make_inplace(mul) +mvlgamma_ = _make_inplace(mvlgamma) +nan_to_num_ = _make_inplace(nan_to_num) +ne_ = _make_inplace(ne) +neg_ = _make_inplace(neg) +nextafter_ = _make_inplace(nextafter) +pow_ = _make_inplace(pow) +rad2deg_ = _make_inplace(rad2deg) +reciprocal_ = _make_inplace(reciprocal) +remainder_ = _make_inplace(remainder) +rsqrt_ = _make_inplace(rsqrt) +sgn_ = _make_inplace(sgn) +sigmoid_ = _make_inplace(sigmoid) +sign_ = _make_inplace(sign) +sin_ = _make_inplace(sin) +sinc_ = _make_inplace(sinc) +sinh_ = _make_inplace(sinh) +sqrt_ = _make_inplace(sqrt) +square_ = _make_inplace(square) +sub_ = _make_inplace(sub) +tan_ = _make_inplace(tan) +tanh_ = _make_inplace(tanh) +tril_ = _make_inplace(tril) +triu_ = _make_inplace(triu) +true_divide_ = _make_inplace(true_divide) +trunc_ = _make_inplace(trunc) +xlogy_ = _make_inplace(xlogy) +cauchy_ = _make_inplace(cauchy) +exponential_ = _make_inplace(exponential) +geometric_ = _make_inplace(geometric) +log_normal_ = _make_inplace(log_normal) +zero_ = _make_inplace(zero) + +alias_copy = _make_copy_from_view(aten.alias) +as_strided_copy = 
_make_copy_from_view(aten.as_strided) +diagonal_copy = _make_copy_from_view(aten.diagonal) +expand_copy = _make_copy_from_view(aten.expand) +# TODO: This must return a sparse tensor if the input is sparse, but refs have +# no sparse support. See narrow_copy_sparse in core. +narrow_copy = _make_copy_from_view(aten.narrow) +squeeze_copy = _make_copy_from_view(aten.squeeze) +permute_copy = _make_copy_from_view(aten.permute) +t_copy = _make_copy_from_view(aten.t) +transpose_copy = _make_copy_from_view(aten.transpose) +unbind_copy = _make_copy_from_view(aten.unbind, return_none_on_out_variant=True) +unsqueeze_copy = _make_copy_from_view(aten.unsqueeze) +view_copy = _make_copy_from_view(aten.view) + + +# xref: isStorage in torch/csrc/DynamicTypes.cpp +def _isStorage(obj): + return isinstance(obj, (torch.TypedStorage, torch.UntypedStorage)) + + +# xref: compute_sizes in torch/csrc/utils/tensor_new.cpp +def _compute_sizes(seq, scalar_type): + MAX_DIMS = 128 + is_storage = _isStorage(seq) + sizes = [] + # TODO: this is inaccurate, we actually test PySequence_Check + while isinstance(seq, (list, tuple)): + length = len(seq) + if is_storage: + length //= scalar_type.itemsize + sizes.append(length) + if len(sizes) > MAX_DIMS: + raise ValueError(f"too many dimensions '{type(seq).__name__}'") + if length == 0: + break + try: + handle = seq[0] + except Exception: + raise ValueError( # noqa: B904 + f"could not determine the shape of object type '{type(seq).__name__}'" + ) + seq = handle + + return sizes + + +# xref: infer_scalar_type in torch/csrc/utils/tensor_new.cpp +def _infer_scalar_type(obj): + if isinstance(obj, FloatLike): + return torch.get_default_dtype() + if isinstance(obj, IntLike) and not isinstance(obj, bool): # careful! 
+ return torch.int64 + if isinstance(obj, BoolLike): + return torch.bool + if isinstance(obj, complex): + default_dtype = torch.get_default_dtype() + if default_dtype is torch.float: + return torch.cfloat + elif default_dtype is torch.double: + return torch.cdouble + elif default_dtype is torch.half: + return torch.chalf + else: + raise RuntimeError("invalid default scalar type for complex") + if isinstance(obj, torch.Tensor): + return obj.dtype + if isinstance(obj, str): + raise TypeError(f"new(): invalid data type '{type(obj).__name__}'") + # TODO: this is inaccurate, we actually test PySequence_Check + if isinstance(obj, (list, tuple)): + scalarType = None + length = len(obj) + # match NumPy semantics, except use default tensor type instead of + # double. + if length == 0: + return torch.get_default_dtype() + + for i in range(length): + cur_item = obj[i] + # TODO: test this + """ + if cur_item is obj: + raise TypeError("new(): self-referential lists are incompatible") + """ + item_scalarType = _infer_scalar_type(cur_item) # recurse! 
+ if scalarType is not None: + scalarType = torch.promote_types(scalarType, item_scalarType) + else: + scalarType = item_scalarType + if scalarType is torch.cdouble: + # this won't change (unless we hit undefined, but that will + # fail later) + return scalarType + return scalarType + raise RuntimeError(f"Could not infer dtype of {type(obj).__name__}") + + +# Analogous to recursive_store +# xref: recursive_store in torch/csrc/utils/tensor_new.cpp +def _recursive_build( + scalarType: torch.dtype, obj: Union[TensorOrNumberLikeType, TensorSequenceType] +): + if isinstance(obj, Tensor) and obj.numel() == 1: + return obj.detach().to(dtype=scalarType, device="cpu", copy=True).view(()) + elif isinstance(obj, Tensor): + # It is invalid to call ".tensor([...])" with a non-scalar tensor in eager mode + # >>> torch.tensor([torch.randn(2)]) + # ValueError: only one element tensors can be converted to Python scalars + # + # But it is possible with a NumPy array + # >>> torch.tensor([np.random.uniform(size=(2,))]).shape + # torch.Size([1, 2]) + return obj.detach().to(dtype=scalarType, device="cpu", copy=True) + elif isinstance(obj, Number): + # pyrefly: ignore [bad-argument-type] + return torch.scalar_tensor(obj, dtype=scalarType) + + # seq can be a list of tensors + seq = obj + return ( + torch.empty(0) + if not seq + else torch.stack([_recursive_build(scalarType, item) for item in seq]) + ) + + +# xref: internal_new_from_data in torch/csrc/utils/tensor_new.cpp +def _internal_new_from_data( + options, + scalar_type, + device_opt, + data, + copy_variables, + copy_numpy, + type_inference, + pin_memory=False, +): + if isinstance(data, torch.Tensor): + torch._check( + not pin_memory, lambda: "Can't pin tensor constructed from a variable" + ) + var = data + if copy_variables: + var = var.detach() + inferred_scalar_type = var.dtype if type_inference else scalar_type + device = device_opt if device_opt is not None else var.device + return var.to( + device=device, + 
dtype=inferred_scalar_type, + non_blocking=False, + copy=copy_variables, + ) + + # TODO + if hasattr(data, "__cuda_array_interface__"): + return NotImplemented + + # TODO: test for numpy input with PyArray_Check + + device = device_opt if device_opt is not None else options["device"] + inferred_scalar_type = _infer_scalar_type(data) if type_inference else scalar_type + + # NB: Don't need to avoid tracing, as we aren't going to do any manual + # pointer filling tricks + if _isStorage(data): + return NotImplemented + else: + if torch.device(device).type == "meta": + return NotImplemented + + # In the C implementation, we would directly start poking the memory + # of a freshly allocated CPU tensor. Here, we're going to do an + # alternate, heinously slow implementation: turn each individual + # scalar into a tensor, and then repeatedly cat them together + tensor = _recursive_build(inferred_scalar_type, data) + + tensor = tensor.to(device, inferred_scalar_type, non_blocking=False, copy=False) + + # NB: lift_fresh is not needed, because we built the tensor from scalars + # guaranteeing a fresh tensor in this case + return tensor + + +# xref: tensor_ctor in torch/csrc/utils/tensor_new.cpp +def tensor(data, *, dtype=None, device=None, pin_memory=False, requires_grad=False): + # TODO (or not): support names kwarg + if isinstance(data, torch.Tensor): + warnings.warn( + "To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() " + "or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor)", + UserWarning, + stacklevel=2, + ) + type_inference = dtype is None + new_tensor = _internal_new_from_data( + # device="cpu" because that's what you get with torch.tensor(2) no + # device by default + {"device": "cpu"}, # TODO: use torch.get_default_tensor_type + dtype if dtype is not None else torch.get_default_dtype(), + device, + data, + copy_variables=True, + copy_numpy=True, + type_inference=type_inference, + 
pin_memory=pin_memory, + ) + new_tensor.detach_() + if requires_grad: + new_tensor.requires_grad_(requires_grad) + return new_tensor + + +# Views +# We can't model these as above, as the pattern of doing `op(a, out=a)` does not work for a view function +# given that it does not reshape the input (it just copies the result into it) + +# squeeze_ = _make_inplace(squeeze) +# t_ = _make_inplace(t) +# transpose_ = _make_inplace(transpose) +# unsqueeze_ = _make_inplace(unsqueeze) + + +import torch._refs._conversions +import torch._refs.fft +import torch._refs.linalg +import torch._refs.nn.functional +import torch._refs.special diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_refs/_conversions.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_refs/_conversions.py new file mode 100644 index 0000000000000000000000000000000000000000..8092469741981efce3c53f424e3b2fb83a38e8eb --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_refs/_conversions.py @@ -0,0 +1,119 @@ +# mypy: allow-untyped-defs +import torch +import torch._prims_common as utils + +# Utilities should come BEFORE this import +from torch._decomp import register_decomposition +from torch._prims_common import TensorLikeType +from torch._prims_common.wrappers import out_wrapper +from torch._refs import _broadcast_shapes + + +# Data conversion references. +# +# Note: this module breaks the usual _refs to torch naming scheme where +# _refs.foo.bar is a ref for torch.foo.bar. The following definitions are not +# part of _refs/__init__.py to avoid name clashes with Python builtin types +# (like int). 
+ +__all__ = [ + # dtypes + "bfloat16", + "bool", + "byte", + "cdouble", + "cfloat", + "chalf", + "char", + "double", + "float", + "half", + "int", + "long", + "short", + # misc + "complex", + "polar", +] + + +def _make_conversion_method(name: str, dtype: torch.dtype): + def fn( + self: TensorLikeType, memory_format: torch.memory_format = torch.preserve_format + ) -> TensorLikeType: + return self.to(dtype, memory_format=memory_format) # type: ignore[call-overload] + + fn.__name__ = name + return fn + + +bfloat16 = _make_conversion_method("bfloat16", torch.bfloat16) + +bool = _make_conversion_method("bool", torch.bool) + +byte = _make_conversion_method("byte", torch.uint8) + +cdouble = _make_conversion_method("cdouble", torch.cdouble) + +cfloat = _make_conversion_method("cfloat", torch.cfloat) + +chalf = _make_conversion_method("chalf", torch.complex32) + +char = _make_conversion_method("char", torch.int8) + +double = _make_conversion_method("double", torch.double) + +float = _make_conversion_method("float", torch.float) + +half = _make_conversion_method("half", torch.half) + +int = _make_conversion_method("int", torch.int) + +long = _make_conversion_method("long", torch.long) + +short = _make_conversion_method("short", torch.short) + + +@register_decomposition(torch._ops.ops.aten.complex) +# Note: complex has type promotion tests disabled due to different semantics. +# exact_dtype is for compat with complex_check_dtype from core. 
+@out_wrapper(exact_dtype=True) +def complex(real: TensorLikeType, imag: TensorLikeType) -> TensorLikeType: + allowed_dtypes = (torch.float32, torch.float64, torch.float16) + torch._check( + real.dtype in allowed_dtypes and imag.dtype in allowed_dtypes, + lambda: ( + f"Expected both inputs to be Half, Float or Double tensors but got " + f"{real.dtype} and {imag.dtype}" + ), + ) + torch._check( + real.dtype == imag.dtype, + lambda: ( + f"Expected object of scalar type {real.dtype} but got " + f"scalar type {imag.dtype} for second argument" + ), + ) + result_dtype = utils.corresponding_complex_dtype(real.dtype) # type: ignore[arg-type] + common_shape = _broadcast_shapes(real.shape, imag.shape) + result = real.new_empty( + common_shape, + dtype=result_dtype, + layout=real.layout, + device=real.device, + # pin_memory=real.is_pinned(), # NYI + ) + result.real = real + result.imag = imag + return result + + +@register_decomposition(torch._ops.ops.aten.polar) +# Note: polar has type promotion tests disabled due to different semantics. +# exact_dtype is for compat with complex_check_dtype from core. 
+@out_wrapper(exact_dtype=True) +def polar(abs: TensorLikeType, angle: TensorLikeType) -> TensorLikeType: + result = torch.complex(abs, angle) + result.real = abs * torch.cos(angle) + result.imag = abs * torch.sin(angle) + return result diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_refs/fft.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_refs/fft.py new file mode 100644 index 0000000000000000000000000000000000000000..e4e300bee62aa5c5eeb8130853882f2dc674d935 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_refs/fft.py @@ -0,0 +1,593 @@ +import math +from collections.abc import Iterable, Sequence +from typing import Literal, NamedTuple, Optional, Union + +import torch +import torch._prims as prims +import torch._prims_common as utils +from torch._decomp import register_decomposition +from torch._prims_common import DimsType, ShapeType, TensorLikeType +from torch._prims_common.wrappers import _maybe_convert_to_dtype, out_wrapper + + +__all__ = [ + # Transforms + "fft", + "fft2", + "fftn", + "hfft", + "hfft2", + "hfftn", + "rfft", + "rfft2", + "rfftn", + "ifft", + "ifft2", + "ifftn", + "ihfft", + "ihfft2", + "ihfftn", + "irfft", + "irfft2", + "irfftn", + # Helpers + "fftshift", + "ifftshift", +] + +NormType = Union[None, Literal["forward", "backward", "ortho"]] +_NORM_VALUES = {None, "forward", "backward", "ortho"} +aten = torch._ops.ops.aten + + +def _apply_norm( + x: TensorLikeType, norm: NormType, signal_numel: int, forward: bool +) -> TensorLikeType: + """Apply normalization to the un-normalized FFT result""" + torch._check(norm in _NORM_VALUES, lambda: f"Invalid normalization mode: {norm}") + + if norm == "ortho": + return x * (1 / math.sqrt(signal_numel)) + + normalize = (not forward and (norm is None or norm == "backward")) or ( + forward and norm == "forward" + ) + return x * (1 / signal_numel) if normalize else x + + +def _promote_type_fft( + dtype: torch.dtype, 
require_complex: bool, device: torch.device +) -> torch.dtype: + """Helper to promote a dtype to one supported by the FFT primitives""" + if dtype.is_complex: + return dtype + + # Promote integral to default float type + if not dtype.is_floating_point: + dtype = torch.get_default_dtype() + + allowed_types = [torch.float32, torch.float64] + maybe_support_half = device.type in ["cuda", "meta"] + + if maybe_support_half: + allowed_types.append(torch.float16) + torch._check(dtype in allowed_types, lambda: f"Unsupported dtype {dtype}") + + if require_complex: + dtype = utils.corresponding_complex_dtype(dtype) + + return dtype + + +def _maybe_promote_tensor_fft( + t: TensorLikeType, require_complex: bool = False +) -> TensorLikeType: + """Helper to promote a tensor to a dtype supported by the FFT primitives""" + cur_type = t.dtype + new_type = _promote_type_fft(cur_type, require_complex, t.device) + return _maybe_convert_to_dtype(t, new_type) # type: ignore[return-value] + + +def _resize_fft_input( + x: TensorLikeType, dims: tuple[int, ...], sizes: tuple[int, ...] +) -> TensorLikeType: + """ + Fixes the shape of x such that x.size(dims[i]) == sizes[i], + either by zero-padding, or by slicing x starting from 0. 
+ """ + assert len(dims) == len(sizes) + must_copy = False + x_sizes = x.shape + pad_amount = [0] * len(x_sizes) * 2 + for i in range(len(dims)): + if sizes[i] == -1: + continue + + if x_sizes[dims[i]] < sizes[i]: + must_copy = True + pad_idx = len(pad_amount) - 2 * dims[i] - 1 + + pad_amount[pad_idx] = sizes[i] - x_sizes[dims[i]] + + if x_sizes[dims[i]] > sizes[i]: + x = x.narrow(dims[i], 0, sizes[i]) + + return torch.constant_pad_nd(x, pad_amount) if must_copy else x + + +def _fft_c2r( + func_name: str, + input: TensorLikeType, + n: Optional[int], + dim: int, + norm: NormType, + forward: bool, +) -> TensorLikeType: + """Common code for performing any complex to real FFT (irfft or hfft)""" + input = _maybe_promote_tensor_fft(input, require_complex=True) + dims = (utils.canonicalize_dim(input.ndim, dim, wrap_scalar=False),) + last_dim_size = n if n is not None else 2 * (input.shape[dim] - 1) + torch._check( + last_dim_size >= 1, + lambda: f"Invalid number of data points ({last_dim_size}) specified", + ) + + if n is not None: + input = _resize_fft_input(input, dims=dims, sizes=(last_dim_size // 2 + 1,)) + + if forward: + input = torch.conj(input) + + output = prims.fft_c2r(input, dim=dims, last_dim_size=last_dim_size) + return _apply_norm(output, norm=norm, signal_numel=last_dim_size, forward=forward) + + +def _fft_r2c( + func_name: str, + input: TensorLikeType, + n: Optional[int], + dim: int, + norm: NormType, + forward: bool, + onesided: bool, +) -> TensorLikeType: + """Common code for performing any real to complex FFT (rfft or ihfft)""" + torch._check( + not input.dtype.is_complex, + lambda: f"{func_name} expects a floating point input tensor, but got {input.dtype}", + ) + input = _maybe_promote_tensor_fft(input) + dims = (utils.canonicalize_dim(input.ndim, dim, wrap_scalar=False),) + dim_size = n if n is not None else input.shape[dim] + torch._check( + dim_size >= 1, lambda: f"Invalid number of data points ({dim_size}) specified" + ) + + if n is not None: + 
input = _resize_fft_input(input, dims, (n,)) + + ret = prims.fft_r2c(input, dim=dims, onesided=onesided) + ret = _apply_norm(ret, norm, dim_size, forward) + return ret if forward else torch.conj(ret) + + +def _fft_c2c( + func_name: str, + input: TensorLikeType, + n: Optional[int], + dim: int, + norm: NormType, + forward: bool, +) -> TensorLikeType: + """Common code for performing any complex to complex FFT (fft or ifft)""" + torch._check( + input.dtype.is_complex, + lambda: f"{func_name} expects a complex input tensor, but got {input.dtype}", + ) + dims = (utils.canonicalize_dim(input.ndim, dim, wrap_scalar=False),) + dim_size = n if n is not None else input.shape[dim] + torch._check( + dim_size >= 1, lambda: f"Invalid number of data points ({dim_size}) specified" + ) + + if n is not None: + input = _resize_fft_input(input, dims, (n,)) + + ret = prims.fft_c2c(input, dim=dims, forward=forward) + return _apply_norm(ret, norm, dim_size, forward) + + +@register_decomposition(aten.fft_fft) +@out_wrapper() +def fft( + input: TensorLikeType, + n: Optional[int] = None, + dim: int = -1, + norm: NormType = None, +) -> TensorLikeType: + if input.dtype.is_complex: + return _fft_c2c("fft", input, n, dim, norm, forward=True) + else: + return _fft_r2c("fft", input, n, dim, norm, forward=True, onesided=False) + + +@register_decomposition(aten.fft_ifft) +@out_wrapper() +def ifft( + input: TensorLikeType, + n: Optional[int] = None, + dim: int = -1, + norm: NormType = None, +) -> TensorLikeType: + if input.dtype.is_complex: + return _fft_c2c("ifft", input, n, dim, norm, forward=False) + else: + return _fft_r2c("ifft", input, n, dim, norm, forward=False, onesided=False) + + +@register_decomposition(aten.fft_rfft) +@out_wrapper() +def rfft( + input: TensorLikeType, + n: Optional[int] = None, + dim: int = -1, + norm: NormType = None, +) -> TensorLikeType: + return _fft_r2c("rfft", input, n, dim, norm, forward=True, onesided=True) + + +@register_decomposition(aten.fft_irfft) 
+@out_wrapper() +def irfft( + input: TensorLikeType, + n: Optional[int] = None, + dim: int = -1, + norm: NormType = None, +) -> TensorLikeType: + return _fft_c2r("irfft", input, n, dim, norm, forward=False) + + +@register_decomposition(aten.fft_hfft) +@out_wrapper() +def hfft( + input: TensorLikeType, + n: Optional[int] = None, + dim: int = -1, + norm: NormType = None, +) -> TensorLikeType: + return _fft_c2r("hfft", input, n, dim, norm, forward=True) + + +@register_decomposition(aten.fft_ihfft) +@out_wrapper() +def ihfft( + input: TensorLikeType, + n: Optional[int] = None, + dim: int = -1, + norm: NormType = None, +) -> TensorLikeType: + return _fft_r2c("ihfft", input, n, dim, norm, forward=False, onesided=True) + + +class _ShapeAndDims(NamedTuple): + shape: tuple[int, ...] + dims: tuple[int, ...] + + +def _canonicalize_fft_shape_and_dim_args( + input: TensorLikeType, shape: Optional[ShapeType], dim: Optional[DimsType] +) -> _ShapeAndDims: + """Convert the shape and dim arguments into a canonical form where neither are optional""" + input_dim = input.ndim + input_sizes = input.shape + + if dim is not None: + if not isinstance(dim, Sequence): + dim = (dim,) + ret_dims = utils.canonicalize_dims(input_dim, dim, wrap_scalar=False) + + # Check dims are unique + torch._check( + len(set(ret_dims)) == len(ret_dims), lambda: "FFT dims must be unique" + ) + + if shape is not None: + if not isinstance(shape, Sequence): + shape = (shape,) + + # Has shape, might have dim + torch._check( + dim is None or len(dim) == len(shape), + lambda: "When given, dim and shape arguments must have the same length", + ) + transform_ndim = len(shape) + + torch._check( + transform_ndim <= input_dim, + lambda: f"Got shape with {transform_ndim} values but input tensor " + f"only has {input_dim} dimensions.", + ) + + # If shape is given, dims defaults to the last len(shape) dimensions + if dim is None: + ret_dims = tuple(range(input_dim - transform_ndim, input_dim)) + + # Translate any -1 values in 
shape to the default length + ret_shape = tuple( + s if s != -1 else input_sizes[d] + for (s, d) in zip(shape, ret_dims) # type: ignore[possibly-undefined] + ) + elif dim is None: + # No shape, no dim + ret_dims = tuple(range(input_dim)) + ret_shape = tuple(input_sizes) + else: + # No shape, has dim + ret_shape = tuple(input_sizes[d] for d in ret_dims) # type: ignore[possibly-undefined] + + for n in ret_shape: + torch._check(n > 0, lambda: f"Invalid number of data points ({n}) specified") + + return _ShapeAndDims(shape=ret_shape, dims=ret_dims) # type: ignore[possibly-undefined] + + +def _prod(xs: Iterable[int]) -> int: + """Compute product of a list""" + prod = 1 + for x in xs: + prod *= x + return prod + + +def _fftn_c2c( + function_name: str, + input: TensorLikeType, + shape: tuple[int, ...], + dim: tuple[int, ...], + norm: NormType, + forward: bool, +) -> TensorLikeType: + """Common code for n-dimensional complex to complex FFTs (fftn or ifftn)""" + torch._check( + input.dtype.is_complex, + lambda: f"{function_name} expects a complex input tensor, " + f"but got {input.dtype}", + ) + x = _resize_fft_input(input, dim, shape) + output = prims.fft_c2c(x, dim=dim, forward=forward) + return _apply_norm(output, norm=norm, signal_numel=_prod(shape), forward=forward) + + +@register_decomposition(aten.fft_fftn) +@out_wrapper() +def fftn( + input: TensorLikeType, + s: Optional[ShapeType] = None, + dim: Optional[DimsType] = None, + norm: NormType = None, +) -> TensorLikeType: + (shape, dim) = _canonicalize_fft_shape_and_dim_args(input, s, dim) + x = _maybe_promote_tensor_fft(input, require_complex=True) + return _fftn_c2c("fftn", x, shape, dim, norm, forward=True) + + +@register_decomposition(aten.fft_ifftn) +@out_wrapper() +def ifftn( + input: TensorLikeType, + s: Optional[ShapeType] = None, + dim: Optional[DimsType] = None, + norm: NormType = None, +) -> TensorLikeType: + (shape, dim) = _canonicalize_fft_shape_and_dim_args(input, s, dim) + x = 
_maybe_promote_tensor_fft(input, require_complex=True) + return _fftn_c2c("ifftn", x, shape, dim, norm, forward=False) + + +@register_decomposition(aten.fft_rfftn) +@out_wrapper() +def rfftn( + input: TensorLikeType, + s: Optional[ShapeType] = None, + dim: Optional[DimsType] = None, + norm: NormType = None, +) -> TensorLikeType: + torch._check( + not input.dtype.is_complex, + lambda: f"rfftn expects a real-valued input tensor, but got {input.dtype}", + ) + shape, dim = _canonicalize_fft_shape_and_dim_args(input, s, dim) + input = _maybe_promote_tensor_fft(input, require_complex=False) + input = _resize_fft_input(input, dim, shape) + out = prims.fft_r2c(input, dim=dim, onesided=True) + return _apply_norm(out, norm=norm, signal_numel=_prod(shape), forward=True) + + +@register_decomposition(aten.fft_ihfftn) +@out_wrapper() +def ihfftn( + input: TensorLikeType, + s: Optional[ShapeType] = None, + dim: Optional[DimsType] = None, + norm: NormType = None, +) -> TensorLikeType: + torch._check( + not input.dtype.is_complex, + lambda: f"ihfftn expects a real-valued input tensor, but got {input.dtype}", + ) + shape, dim = _canonicalize_fft_shape_and_dim_args(input, s, dim) + torch._check(len(shape) > 0, lambda: "ihfftn must transform at least one axis") + input = _maybe_promote_tensor_fft(input, require_complex=False) + input = _resize_fft_input(input, dim, shape) + + tmp = prims.fft_r2c(input, dim=dim[-1:], onesided=True) + + if len(dim) == 1: + tmp = _apply_norm(tmp, norm=norm, signal_numel=shape[0], forward=False) + return prims.conj(tmp) + + tmp = prims.conj_physical(tmp) + tmp = prims.fft_c2c(tmp, dim=dim[:-1], forward=False) + return _apply_norm(tmp, norm=norm, signal_numel=_prod(shape), forward=False) + + +class _CanonicalizeC2rReturn(NamedTuple): + shape: tuple[int, ...] + dim: tuple[int, ...] 
+ last_dim_size: int + + +def _canonicalize_fft_c2r_shape_and_dim_args( + fname: str, + input: TensorLikeType, + s: Optional[ShapeType], + dim: Optional[DimsType], +) -> _CanonicalizeC2rReturn: + """Canonicalize shape and dim arguments for n-dimensional c2r transforms, + as well as calculating the last_dim_size which is shape[dim[-1]] for the output""" + (shape, dim) = _canonicalize_fft_shape_and_dim_args(input, s, dim) + torch._check(len(shape) > 0, lambda: f"{fname} must transform at least one axis") + + if s is None or s[-1] == -1: + last_dim_size = 2 * (input.shape[dim[-1]] - 1) + else: + last_dim_size = shape[-1] + + torch._check( + last_dim_size >= 1, + lambda: f"Invalid number of data points ({last_dim_size}) specified", + ) + + shape_list = list(shape) + shape_list[-1] = last_dim_size // 2 + 1 + return _CanonicalizeC2rReturn( + shape=tuple(shape_list), dim=dim, last_dim_size=last_dim_size + ) + + +@register_decomposition(aten.fft_irfftn) +@out_wrapper() +def irfftn( + input: TensorLikeType, + s: Optional[ShapeType] = None, + dim: Optional[DimsType] = None, + norm: NormType = None, +) -> TensorLikeType: + shape, dim, last_dim_size = _canonicalize_fft_c2r_shape_and_dim_args( + "irfftn", input, s, dim + ) + input = _maybe_promote_tensor_fft(input, require_complex=True) + input = _resize_fft_input(input, dim, shape) + out = prims.fft_c2r(input, dim=dim, last_dim_size=last_dim_size) + return _apply_norm(out, norm, _prod(out.shape[d] for d in dim), forward=False) + + +@register_decomposition(aten.fft_hfftn) +@out_wrapper() +def hfftn( + input: TensorLikeType, + s: Optional[ShapeType] = None, + dim: Optional[DimsType] = None, + norm: NormType = None, +) -> TensorLikeType: + shape, dim, last_dim_size = _canonicalize_fft_c2r_shape_and_dim_args( + "hfftn", input, s, dim + ) + input = _maybe_promote_tensor_fft(input, require_complex=True) + input = _resize_fft_input(input, dim, shape) + + tmp = prims.fft_c2c(input, dim=dim[:-1], forward=True) if len(dim) > 1 else 
input + tmp = _apply_norm(tmp, norm, _prod(shape[:-1]), forward=True) + tmp = prims.conj_physical(tmp) + out = prims.fft_c2r(tmp, dim=dim[-1:], last_dim_size=last_dim_size) + return _apply_norm(out, norm, last_dim_size, forward=True) + + +@register_decomposition(aten.fft_fft2) +@out_wrapper() +def fft2( + input: TensorLikeType, + s: Optional[ShapeType] = None, + dim: Optional[DimsType] = (-2, -1), + norm: NormType = None, +) -> TensorLikeType: + return torch.fft.fftn(input, s=s, dim=dim, norm=norm) + + +@register_decomposition(aten.fft_ifft2) +@out_wrapper() +def ifft2( + input: TensorLikeType, + s: Optional[ShapeType] = None, + dim: Optional[DimsType] = (-2, -1), + norm: NormType = None, +) -> TensorLikeType: + return torch.fft.ifftn(input, s=s, dim=dim, norm=norm) + + +@register_decomposition(aten.fft_rfft2) +@out_wrapper() +def rfft2( + input: TensorLikeType, + s: Optional[ShapeType] = None, + dim: Optional[DimsType] = (-2, -1), + norm: NormType = None, +) -> TensorLikeType: + return torch.fft.rfftn(input, s=s, dim=dim, norm=norm) + + +@register_decomposition(aten.fft_irfft2) +@out_wrapper() +def irfft2( + input: TensorLikeType, + s: Optional[ShapeType] = None, + dim: Optional[DimsType] = (-2, -1), + norm: NormType = None, +) -> TensorLikeType: + return torch.fft.irfftn(input, s=s, dim=dim, norm=norm) + + +@register_decomposition(aten.fft_hfft2) +@out_wrapper() +def hfft2( + input: TensorLikeType, + s: Optional[ShapeType] = None, + dim: Optional[DimsType] = (-2, -1), + norm: NormType = None, +) -> TensorLikeType: + return torch.fft.hfftn(input, s=s, dim=dim, norm=norm) + + +@register_decomposition(aten.fft_ihfft2) +@out_wrapper() +def ihfft2( + input: TensorLikeType, + s: Optional[ShapeType] = None, + dim: Optional[DimsType] = (-2, -1), + norm: NormType = None, +) -> TensorLikeType: + return torch.fft.ihfftn(input, s=s, dim=dim, norm=norm) + + +def _default_alldims(dim: Optional[DimsType], x: TensorLikeType) -> list[int]: + """Convert Optional[DimsType] to a 
simple list, defaulting to all dimensions""" + if dim is None: + return list(range(x.ndim)) + elif not isinstance(dim, Sequence): + return [dim] + else: + return list(dim) + + +@register_decomposition(aten.fft_fftshift) +def fftshift(input: TensorLikeType, dim: Optional[DimsType] = None) -> TensorLikeType: + dims = _default_alldims(dim, input) + shift = [input.shape[d] // 2 for d in dims] + return torch.roll(input, shift, dims) + + +@register_decomposition(aten.fft_ifftshift) +def ifftshift(input: TensorLikeType, dim: Optional[DimsType] = None) -> TensorLikeType: + dims = _default_alldims(dim, input) + shift = [(input.shape[d] + 1) // 2 for d in dims] + return torch.roll(input, shift, dims) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_strobelight/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_strobelight/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_strobelight/cli_function_profiler.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_strobelight/cli_function_profiler.py new file mode 100644 index 0000000000000000000000000000000000000000..a63a49c3938a1d4a3cda3f1a6b9b029a90a8a77e --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_strobelight/cli_function_profiler.py @@ -0,0 +1,322 @@ +# mypy: disallow-untyped-defs + +import functools +import logging +import os +import re +import subprocess +import time +from collections.abc import Callable, Sequence +from threading import Lock +from timeit import default_timer as timer +from typing import Any, Optional, TypeVar +from typing_extensions import ParamSpec + + +logger = logging.getLogger("strobelight_function_profiler") + +console_handler = logging.StreamHandler() +formatter = logging.Formatter( + "%(name)s, line %(lineno)d, %(asctime)s, 
%(levelname)s: %(message)s" +) +console_handler.setFormatter(formatter) + +logger.addHandler(console_handler) +logger.setLevel(logging.INFO) +logger.propagate = False + +_P = ParamSpec("_P") +_R = TypeVar("_R") + + +class StrobelightCLIProfilerError(Exception): + """ + Raised when an error happens during strobelight profiling + """ + + +def _pid_namespace_link(pid: Optional[int] = None) -> str: + """Returns the link to the process's namespace, example: pid:[4026531836]""" + PID_NAMESPACE_PATH = "/proc/{}/ns/pid" + pid = pid or os.getpid() + return os.readlink(PID_NAMESPACE_PATH.format(pid)) + + +def _pid_namespace(pid: Optional[int] = None) -> int: + """Returns the process's namespace id""" + pid = pid or os.getpid() + link = _pid_namespace_link(pid) + return int(link[link.find("[") + 1 : -1]) + + +def _command_to_string(command: Sequence[str]) -> str: + return " ".join(command) + + +class StrobelightCLIFunctionProfiler: + """ + Note: this is a Meta only tool. + + StrobelightCLIFunctionProfiler can be used to profile a python function and + generate a strobelight link with the results. It works on meta servers but + does not requires an fbcode target. + When stop_at_error is false(default), error during profiling does not prevent + the work function from running. + + Check function_profiler_example.py for an example. + """ + + # This lock is used to make sure only one thread is running the profiler at any point. + _lock = Lock() + + def __init__( + self, + *, + stop_at_error: bool = False, + max_profile_duration_sec: int = 60 * 10, + sample_each: float = 1e7, # sample each sample_each cycles. 
+ run_user_name: str = "pytorch-strobelight-ondemand", + timeout_wait_for_running_sec: int = 60, + timeout_wait_for_finished_sec: int = 60, + recorded_env_variables: Optional[list[str]] = None, + sample_tags: Optional[list[str]] = None, + stack_max_len: int = 127, + async_stack_max_len: int = 127, + ): + self.stop_at_error = stop_at_error + self.max_profile_duration_sec = max_profile_duration_sec + self.sample_each = sample_each + self.run_user_name = run_user_name + self.timeout_wait_for_running_sec = timeout_wait_for_running_sec + self.timeout_wait_for_finished_sec = timeout_wait_for_finished_sec + # Results of the most recent run. + # Tracks the strobelight run id of the most recent run + self.current_run_id: Optional[int] = None + self.profile_result: Optional[list[str]] = None + self.sample_tags = sample_tags + + def _run_async(self) -> None: + processId = os.getpid() + namespace = _pid_namespace(processId) + command = [ + "strobeclient", + "run", + "--profiler", + "pyperf", + "--event", + "cycles", + "--async", + "--sample-interval", + f"{int(self.sample_each)}", + "--duration-ms", + f"{int(self.max_profile_duration_sec * 1000)}", + "--pid", + f"{namespace}:{processId}", + ] + + if self.sample_tags: + command.append("--sample-tags") + command.append(",".join(self.sample_tags)) + + logger.debug("running command: %s", _command_to_string(command)) + result = subprocess.run(command, capture_output=True) + output = result.stderr.decode("utf-8") + logger.debug("output:\n{%s}", output) + + if result.returncode != 0: + raise StrobelightCLIProfilerError( + f"failed to start strobelight profiling, error in run_async:{output}" + ) + + if match := re.search(r"INFO Run Id: (-?\d+)", output): + self.current_run_id = int(match.group(1)) + return + + raise StrobelightCLIProfilerError( + f"failed to start strobelight profiling, unexpected result {output}" + ) + + def _wait_for_running(self, counter: int = 0) -> None: + if counter > 20: + raise StrobelightCLIProfilerError( + 
"wait_for_running called more than 20 times" + ) + + command = ["strobeclient", "getRunStatus", "--run-id", f"{self.current_run_id}"] + logger.debug("running command: %s", _command_to_string(command)) + result = subprocess.run(command, capture_output=True) + output = result.stderr.decode("utf-8") + logger.debug("output:\n{%s}", output) + + if result.returncode != 0: + raise StrobelightCLIProfilerError( + f"failed to start strobelight profiling, error in wait_for_running:{output}" + ) + + if match := re.search("Profile run status: (.*)", output): + current_status = match.group(1) + if current_status == "RUNNING": + return + elif current_status == "PREPARING": + time.sleep(10) + self._wait_for_running(counter + 1) + return + else: + raise StrobelightCLIProfilerError(f"unexpected {current_status} phase") + + raise StrobelightCLIProfilerError(f"unexpected output\n: {output} ") + + def _stop_run(self) -> None: + command = ["strobeclient", "stopRun", "--run-id", str(self.current_run_id)] + logger.debug("running command: %s", _command_to_string(command)) + result = subprocess.run(command, capture_output=True) + output = result.stderr.decode("utf-8") + logger.debug("output:\n{%s}", output) + + if result.returncode != 0: + raise StrobelightCLIProfilerError( + f"failed to stop strobelight profiling, return code is not 0 :{output}" + ) + + if match := re.search("INFO ::1:(.*)", output): + current_status = match.group(1) + if current_status.__contains__("Success!"): + return + else: + raise StrobelightCLIProfilerError( + f"failed to stop strobelight profiling, got {current_status} result" + ) + + raise StrobelightCLIProfilerError(f"unexpected output\n: {output} ") + + def _get_results(self) -> None: + command = ["strobeclient", "getRunStatus", "--run-id", str(self.current_run_id)] + logger.debug("running command: %s", _command_to_string(command)) + result = subprocess.run(command, capture_output=True) + output = result.stderr.decode("utf-8") + logger.debug("output:\n{%s}", 
output) + + if result.returncode != 0: + raise StrobelightCLIProfilerError( + f"failed to extract profiling results, return code is not 0 : {output}" + ) + + if match := re.search("INFO ::1:(.*)", output): + current_status = match.group(1) + if current_status.__contains__("Profile run status: PROCESSING"): + time.sleep(10) + self._get_results() + return + elif not current_status.__contains__("Profile run finished with SUCCESS"): + raise StrobelightCLIProfilerError( + f"failed to extract profiling results, unexpected response {output}" + ) + + self.profile_result = [] + for item in re.findall( + r"(Total samples(.*)|GraphProfiler(.*)|Icicle view \(python stack\)(.*))", + output, + ): + self.profile_result += item[0] + logger.info(item[0]) + + def _stop_strobelight_no_throw( + self, + collect_results: bool, + ) -> None: + try: + # call stop run + self._stop_run() + logger.info("strobelight profiling stopped") + + logger.debug("collection stopped") + + if not collect_results: + return + + self._get_results() + except Exception: + logger.warning("error during stop_strobelight", exc_info=True) + + # Return true if strobelight started and is running. Never throw. 
+ def _start_strobelight(self) -> bool: + strobelight_started = False + try: + self._run_async() + strobelight_started = True + logger.info("strobelight run id is: %s", self.current_run_id) + self._wait_for_running() + logger.info("strobelight profiling running") + return True + + except Exception: + logger.warning("error during start_strobelight:", exc_info=True) + if strobelight_started: + self._stop_strobelight_no_throw(collect_results=False) + return False + + def profile( + self, work_function: Callable[_P, _R], *args: _P.args, **kwargs: _P.kwargs + ) -> Optional[_R]: + self.current_run_id = None + self.profile_result = None + + if locked := StrobelightCLIFunctionProfiler._lock.acquire(False): + if not locked: + if self.stop_at_error: + raise StrobelightCLIProfilerError("concurrent runs not supported") + + logger.warning("concurrent runs not supported") + return work_function(*args, **kwargs) + + started = self._start_strobelight() + if not started: + if self.stop_at_error: + StrobelightCLIFunctionProfiler._lock.release() + raise StrobelightCLIProfilerError( + "failed to start strobelight profiling" + ) + result = work_function(*args, **kwargs) + StrobelightCLIFunctionProfiler._lock.release() + return result + + try: + logger.debug("collection started") + start = timer() + result = work_function(*args, **kwargs) + end = timer() + total_time = end - start # Time in seconds, e.g. 5.38091952400282 + logger.info("work function took %s seconds", total_time) + self._stop_strobelight_no_throw(collect_results=True) + StrobelightCLIFunctionProfiler._lock.release() + return result + except Exception as error: + logger.warning("work function throw exception", exc_info=True) + self._stop_strobelight_no_throw(collect_results=False) + StrobelightCLIFunctionProfiler._lock.release() + raise error + return None + + +# A function decorator that wraps profile, if no profiler is provided one with +# default args is created. 
A function can be annotated as: +# @strobelight() +# @strobelight(profiler = StrobelightFunctionProfiler(stop_at_error=True,..)) +# @strobelight(stop_at_error=True,...) +def strobelight( + profiler: Optional[StrobelightCLIFunctionProfiler] = None, **kwargs: Any +) -> Callable[[Callable[_P, _R]], Callable[_P, Optional[_R]]]: + if not profiler: + profiler = StrobelightCLIFunctionProfiler(**kwargs) + + def strobelight_inner( + work_function: Callable[_P, _R], + ) -> Callable[_P, Optional[_R]]: + @functools.wraps(work_function) + def wrapper_function(*args: _P.args, **kwargs: _P.kwargs) -> Optional[_R]: + # pyrefly: ignore [bad-argument-type] + return profiler.profile(work_function, *args, **kwargs) + + return wrapper_function + + return strobelight_inner diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_strobelight/compile_time_profiler.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_strobelight/compile_time_profiler.py new file mode 100644 index 0000000000000000000000000000000000000000..89b44632e27872704b6f6c3e6ea21c9c19416610 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_strobelight/compile_time_profiler.py @@ -0,0 +1,224 @@ +# mypy: disallow-untyped-defs + +import json +import logging +import os +import re +import subprocess +from datetime import datetime +from socket import gethostname +from typing import Any, Optional + +from torch._strobelight.cli_function_profiler import StrobelightCLIFunctionProfiler + + +logger = logging.getLogger("strobelight_compile_time_profiler") + +console_handler = logging.StreamHandler() +formatter = logging.Formatter( + "%(name)s, line %(lineno)d, %(asctime)s, %(levelname)s: %(message)s" +) +console_handler.setFormatter(formatter) + +logger.addHandler(console_handler) +logger.setLevel(logging.INFO) +logger.propagate = False + + +def get_fburl(url: str) -> str: + short_url = url + # Attempt to shorten the URL + try: + result = subprocess.run( + 
["fburl", url], capture_output=True, stdin=subprocess.DEVNULL + ) + if result.returncode == 0: + short_url = result.stdout.decode("utf-8") + except Exception as e: + logger.warning("URL shortening failed: %s, using long URL", repr(e)) + return short_url + + +def get_strobelight_url(identifier: str) -> str: + scuba_json = { + "aggregateList": [], + "aggregation_field": "async_stack_complete", + "b_constraints": [[]], + "c_constraints": [[]], + "cols": ["namespace_id", "namespace_process_id"], + "compare": "none", + "constraints": [ + [{"column": "sample_tags", "op": "all", "value": [f'["{identifier}"]']}] + ], + "derivedCols": [], + "end": "now", + "enumCols": [], + "filterMode": "DEFAULT", + "hideEmptyColumns": "false", + "ignoreGroupByInComparison": "false", + "is_timeseries": "false", + "mappedCols": [], + "metric": "count", + "modifiers": [], + "order": "weight", + "order_desc": "true", + "param_dimensions": [ + {"dim": "py_async_stack", "op": "edge", "param": "0", "anchor": "0"} + ], + "purposes": [], + "return_remainder": "false", + "samplingRatio": "1", + "should_pivot": "false", + "start": "-30 days", + "timezone": "America/Los_Angeles", + "top": 10000, + } + scuba_url_prefix = "https://www.internalfb.com/intern/scuba/query/?dataset=pyperf_experimental/on_demand&drillstate=" + scuba_url_suff = "&view=GraphProfilerView&&normalized=1726332703&pool=uber" + long_url = scuba_url_prefix + json.dumps(scuba_json) + scuba_url_suff + return get_fburl(long_url) + + +class StrobelightCompileTimeProfiler: + success_profile_count: int = 0 + failed_profile_count: int = 0 + ignored_profile_runs: int = 0 + inside_profile_compile_time: bool = False + enabled: bool = False + + # A regex that can be used to filter out what frames to profile. ex: "1/.*" + frame_id_filter: Optional[str] = os.environ.get("COMPILE_STROBELIGHT_FRAME_FILTER") + + # A unique identifier that is used as the run_user_name in the strobelight profile to + # associate all compile time profiles together. 
+ identifier: Optional[str] = None + + current_phase: Optional[str] = None + + profiler: Optional[Any] = None + + max_stack_length: int = int( + os.environ.get("COMPILE_STROBELIGHT_MAX_STACK_LENGTH", 500) + ) + max_profile_time: int = int( + os.environ.get("COMPILE_STROBELIGHT_MAX_PROFILE_TIME", 60 * 30) + ) + # Collect sample each x cycles. + sample_each: int = int( + float(os.environ.get("COMPILE_STROBELIGHT_SAMPLE_RATE", 1e7)) + ) + + @classmethod + def get_frame(cls) -> str: + from torch._guards import CompileContext + + return (str)(CompileContext.current_trace_id()) + + @classmethod + def enable(cls, profiler_class: Any = StrobelightCLIFunctionProfiler) -> None: + if cls.enabled: + logger.info("compile time strobelight profiling already enabled") + return + + logger.info("compile time strobelight profiling enabled") + + if profiler_class is StrobelightCLIFunctionProfiler: + import shutil + + if not shutil.which("strobeclient"): + logger.info( + "strobeclient not found, can't enable compile time strobelight profiling, seems" + "like you are not on a FB machine." + ) + return + + cls.enabled = True + cls._cls_init() + # profiler_class should have public API similar to that of StrobelightCLIFunctionProfiler. + # we have pass different functionProfilerClass for meta-internal fbcode targets. 
+ # NB: the actual implementation in Meta is at + # fbcode/caffe2/fb/strobelight/function_profiler.py + cls.profiler = profiler_class( + sample_each=cls.sample_each, + max_profile_duration_sec=cls.max_profile_time, + stack_max_len=cls.max_stack_length, + async_stack_max_len=cls.max_stack_length, + run_user_name="pt2-profiler/" + + os.environ.get("USER", os.environ.get("USERNAME", "")), + sample_tags={cls.identifier}, # pyrefly: ignore # bad-argument-type + ) + + @classmethod + def _cls_init(cls) -> None: + cls.identifier = "{date}{pid}{hostname}".format( + date=datetime.now().strftime("%Y-%m-%d-%H:%M:%S"), + pid=os.getpid(), + hostname=gethostname(), + ) + + logger.info("Unique sample tag for this run is: %s", cls.identifier) + logger.info( + "URL to access the strobelight profile at the end of the run: %s", + get_strobelight_url(cls.identifier), + ) + + @classmethod + def _log_stats(cls) -> None: + logger.info( + "%s strobelight success runs out of %s non-recursive compilation events.", + cls.success_profile_count, + cls.success_profile_count + cls.failed_profile_count, + ) + + # TODO use threadlevel meta data to tags to record phases. 
+ @classmethod + def profile_compile_time( + cls, func: Any, phase_name: str, *args: Any, **kwargs: Any + ) -> Any: + def skip() -> Any: + return func(*args, **kwargs) + + if not cls.enabled: + return skip() + + if cls.profiler is None: + logger.error("profiler is not set") + return + + frame_id = cls.get_frame() + + if cls.inside_profile_compile_time: + cls.ignored_profile_runs += 1 + logger.info( + "profile_compile_time is requested for phase: %s, frame %s, while already in running phase: %s," + "frame %s, recursive call ignored", + phase_name, + frame_id, + cls.current_phase, + frame_id, + ) + return skip() + + if cls.frame_id_filter is not None: + should_run = re.match(cls.frame_id_filter, frame_id) is not None + if not should_run: + logger.info( + "profiling frame %s is skipped due to frame_id_filter %s", + frame_id, + cls.frame_id_filter, + ) + return skip() + + cls.inside_profile_compile_time = True + cls.current_phase = phase_name + logger.info("profiling frame %s", frame_id) + work_result = cls.profiler.profile(func, *args, **kwargs) + + if cls.profiler.profile_result is not None: + cls.success_profile_count += 1 + else: + cls.failed_profile_count += 1 + + cls._log_stats() + cls.inside_profile_compile_time = False + return work_result diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_subclasses/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_subclasses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cdc42f39cbddaf5bdc919cef88d5f049fdba2634 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_subclasses/__init__.py @@ -0,0 +1,17 @@ +import torch +from torch._subclasses.fake_tensor import ( + DynamicOutputShapeException, + FakeTensor, + FakeTensorMode, + UnsupportedFakeTensorException, +) +from torch._subclasses.fake_utils import CrossRefFakeMode + + +__all__ = [ + "FakeTensor", + "FakeTensorMode", + "UnsupportedFakeTensorException", 
+ "DynamicOutputShapeException", + "CrossRefFakeMode", +] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_subclasses/_fake_tensor_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_subclasses/_fake_tensor_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..cffa4a2216532deac32fa51a3a770236f4040f68 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_subclasses/_fake_tensor_utils.py @@ -0,0 +1,263 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Optional, TYPE_CHECKING, Union + +import torch +from torch import SymInt +from torch.fx.experimental.sym_node import SymNode +from torch.types import py_sym_types, PySymType + + +if TYPE_CHECKING: + import sympy + + from torch.fx.experimental.symbolic_shapes import ShapeEnv + + from .fake_tensor import _DispatchCacheKey, _MetadataIntLike + + +@dataclass(frozen=True, slots=True) +class _DeconstructedSymNode: + """ + Represents a SymNode without the associated ShapeEnv + """ + + # n.b. 
keep the same protocol as SymNode + _expr: sympy.Expr + pytype: type + _hint: Optional[Union[int, float, bool]] + constant: Optional[Union[int, float, bool]] + fx_node: torch.fx.Node + + @staticmethod + def from_node(node: SymNode) -> _DeconstructedSymNode: + return _DeconstructedSymNode( + node._expr, + node.pytype, + node._hint, + node.constant, + # pyrefly: ignore [bad-argument-type] + node.fx_node, + ) + + def extract(self, shape_env: ShapeEnv) -> SymNode: + return SymNode( + self._expr, shape_env, self.pytype, self._hint, self.constant, self.fx_node + ) + + def __str__(self) -> str: + return str(self._expr) + + def __repr__(self) -> str: + return f"_DeconstructedSymNode{{{self._expr!r}, {self.pytype!r}, {self._hint!r}, {self.constant!r}, {self.fx_node!r}}}" + + def __eq__(self, other: object) -> bool: + raise NotImplementedError + + def __hash__(self) -> int: + raise NotImplementedError + + # _value_eq to match SymNode + def _value_eq(self, other: object) -> bool: + if isinstance(other, (SymNode, _DeconstructedSymNode)): + return ( + self._expr == other._expr + and self.pytype == other.pytype + and self._hint == other._hint + and self.constant == other.constant + and self.fx_node == other.fx_node + ) + else: + return False + + # _value_hash to match SymNode + def _value_hash(self) -> int: + return hash((self._expr, self.pytype, self._hint, self.constant, self.fx_node)) + + +@dataclass(frozen=True, slots=True) +class _DeconstructedSymType: + """ + Represents a SymInt, SymFloat, SymBool without the associated ShapeEnv + """ + + ty: type[PySymType] + node: _DeconstructedSymNode + + @staticmethod + def from_sym_type(value: PySymType) -> _DeconstructedSymType: + return _DeconstructedSymType(type(value), value.node) + + def extract(self, shape_env: ShapeEnv) -> PySymType: + return self.ty(self.node.extract(shape_env)) + + def __str__(self) -> str: + return f"{self.ty}({self.node})" + + def __repr__(self) -> str: + return f"_DeconstructedSymType({self.ty}, 
{self.node!r})" + + def __eq__(self, other: object) -> bool: + return NotImplemented + + def __hash__(self) -> int: + return NotImplemented + + +@dataclass(frozen=True, slots=True) +class _InputBackref: + value: int + + +@dataclass(slots=True) +class _PySymInputStub: + """ + Represents a SymInt in the cached key. Needed because SymInt doesn't + support __eq__ or __hash__ directly. + """ + + # value can be: + # PySymType: This is the 'normal' SymInt value, wrapped so we can use + # hash/eq as value hash/eq (normally SymInt does object + # hash/eq). + # _DeconstructedSymType: This is used when storing the _PySymInputStub in + # the cache to avoid cyclic ShapeEnv references. + # _InputBackref: This is a back-reference to a previous _PySymInputStub in + # the key. + value: Union[PySymType, _DeconstructedSymType, _InputBackref] + + def __init__( + self, value: Union[PySymType, _DeconstructedSymType, _InputBackref] + ) -> None: + # For inputs (values in the `key`) we need to keep the PySymType intact + # - this way if we need to reuse it as an output we can properly copy + # the original value. + self.value = value + + def strip_shape_env(self) -> None: + if isinstance(self.value, py_sym_types): + self.value = _DeconstructedSymType.from_sym_type(self.value) + + def extract(self, shape_env: ShapeEnv) -> PySymType: + if isinstance(self.value, _DeconstructedSymType): + return self.value.extract(shape_env) + else: + # We should never see an _InputBackref here - anyone extracting a + # value should be pulling from the original entry (the one this + # backref points at). 
+ assert not isinstance(self.value, _InputBackref) + return self.value + + def __str__(self) -> str: + return str(self.value) + + def __repr__(self) -> str: + return f"_PySymInputStub({self.value!r})" + + def __eq__(self, other: object) -> bool: + if not isinstance(other, _PySymInputStub): + return False + elif isinstance(self.value, _InputBackref) or isinstance( + other.value, _InputBackref + ): + return self.value == other.value + else: + return self.value.node._value_eq(other.value.node) + + def __hash__(self) -> int: + if isinstance(self.value, _InputBackref): + return hash(self.value) + else: + return self.value.node._value_hash() + + +@dataclass(slots=True) +class _SymIntOutputStub: + """ + Represents a SymInt in the cached output. + """ + + # This is either an `int` which represents the index in the key to copy the + # SymNode from or it's the deconstructed SymNode itself. + value: Union[int, _DeconstructedSymNode] + + def __init__(self, value: SymInt, key_path: Optional[int]) -> None: + if key_path is None: + self.value = _DeconstructedSymNode.from_node(value.node) + else: + self.value = key_path + + def extract(self, key: _DispatchCacheKey, shape_env: ShapeEnv) -> SymInt: + if isinstance(self.value, _DeconstructedSymNode): + return SymInt(self.value.extract(shape_env)) + else: + src = key.key[self.value] + assert isinstance(src, _PySymInputStub) and isinstance(src.value, SymInt) + return src.value + + def __repr__(self) -> str: + return f"_SymIntOutputStub({self.value!r})" + + def __eq__(self, other: object) -> bool: + raise NotImplementedError + + def __hash__(self) -> int: + raise NotImplementedError + + +@dataclass(slots=True) +class _CacheKeyState: + """ + State used while building our cache key. + """ + + # We track the SymNodes so when we get the output we can see if it exactly + # matches one of the inputs so we can uncache it properly. + sym_node_lookup: dict[int, int] # id(SymNode) -> index + + # This is a list of all seen input sympy.Symbols. 
We use it when building + # the cache entry to see if the output value has any symbols that we didn't + # see on input. See _has_unrepresented_symbols(). + known_symbols: set[sympy.Symbol] + + # There are cases where we're asked to perform an op when we have no + # ShapeEnv on the FakeTensorMode - but for SymNodes we MUST have a + # ShapeEnv. So as we scan if we see a SymNode (with a ShapeEnv) we record it + # here. + shape_env: Optional[ShapeEnv] + + def __init__(self, shape_env: Optional[ShapeEnv] = None) -> None: + self.sym_node_lookup = {} + self.known_symbols = set() + self.shape_env = shape_env + + def cache_on_shape_env(self) -> bool: + """ + Returns true if the CacheKey needs to be cached on the ShapeEnv + rather than the global cache. + + If our inputs contain a SymNode then we can't cache this operation on + the global cache because the cached output will implicitly depend on + guard values which might not be true on some other ShapeEnv. So unless + we're also going to cache the guards we need to cache this operation on + the ShapeEnv instead of globally. 
+ """ + return bool(self.sym_node_lookup) + + def convert_sym_int(self, result: list[object], arg: SymInt) -> None: + node_id = id(arg.node) + if node_id in self.sym_node_lookup: + result.append(_InputBackref(self.sym_node_lookup[node_id])) + else: + self.sym_node_lookup[node_id] = len(result) + self.known_symbols.update(arg.node.expr.free_symbols) + if self.shape_env is None: + self.shape_env = arg.node.shape_env + result.append(_PySymInputStub(arg)) + + def convert_output(self, arg: _MetadataIntLike) -> _MetadataIntLike: + if isinstance(arg, SymInt): + return _SymIntOutputStub(arg, self.sym_node_lookup.get(id(arg.node), None)) + else: + return arg diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_subclasses/fake_impls.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_subclasses/fake_impls.py new file mode 100644 index 0000000000000000000000000000000000000000..ff309af8a29e0e2afa4f94cae67fef775bf874bc --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_subclasses/fake_impls.py @@ -0,0 +1,1465 @@ +# mypy: ignore-errors + +import functools +import itertools +import math +import operator +import sys +from collections.abc import Callable +from functools import reduce +from typing import Optional, Union + +import torch +import torch._custom_op +import torch._logging +import torch._prims_common as utils +from torch._dispatch.python import no_python_dispatcher +from torch._ops import OpOverload +from torch._prims_common import ( + canonicalize_dim, + elementwise_dtypes, + ELEMENTWISE_TYPE_PROMOTION_KIND, + is_boolean_dtype, + is_contiguous, + is_contiguous_for_memory_format_or_false, + is_contiguous_or_false, + is_float_dtype, + is_integer_dtype, + make_contiguous_strides_for, +) +from torch._subclasses.fake_tensor import ( + DataDependentOutputException, + DynamicOutputShapeException, + FakeTensor, + in_kernel_invocation_manager, + run_fallback_kernel, + UnsupportedOperatorException, +) 
+from torch.fx.operator_schemas import normalize_function +from torch.utils._stats import count_label + + +pytree = torch.utils._pytree + +__all__ = [ + "op_implementations_checks", + "get_fast_op_impls", + "stride_incorrect_op", + "has_meta", +] + +op_implementations_dict = {} +op_implementations_checks = [] + + +aten = torch._ops.ops.aten + + +def ordered_set(*items): + return dict.fromkeys(items, True) + + +# This function indicates if the backend device +# supports non-contiguous tensors +def is_noncontiguous_supported(device): + return device.type != "hpu" + + +_like_tensor_constructors = ordered_set( + aten.empty_like.default, + aten.empty_like.out, + aten.full_like.default, + aten.full_like.out, + aten.ones_like.default, + aten.ones_like.out, + aten.rand_like.default, + aten.rand_like.generator, + aten.rand_like.out, + aten.rand_like.generator_out, + aten.randn_like.default, + aten.randn_like.generator, + aten.randn_like.out, + aten.randn_like.generator_out, + aten.randint_like.default, + aten.randint_like.generator, + aten.randint_like.Tensor, + aten.randint_like.Tensor_generator, + aten.randint_like.Tensor_out, + aten.randint_like.Tensor_generator_out, + aten.randint_like.out, + aten.randint_like.generator_out, + aten.randint_like.low_dtype, + aten.randint_like.low_generator_dtype, + aten.randint_like.low_dtype_out, + aten.randint_like.low_generator_dtype_out, + aten.zeros_like.default, + aten.zeros_like.out, + aten.new_empty.default, + aten.new_empty.out, + aten.new_empty_strided.default, + aten.new_empty_strided.out, + aten.new_full.default, + aten.new_full.out, + aten.new_zeros.default, + aten.new_zeros.out, + aten.new_ones.default, + aten.new_ones.out, +) + + +_device_not_kwarg_ops = ordered_set( + aten._resize_output_.default, + aten._nested_tensor_from_tensor_list.default, + aten._nested_tensor_from_tensor_list.out, + aten.pin_memory.default, + aten.to.device, + aten.to.prim_Device, + aten.is_pinned.default, + aten._pin_memory.default, + 
aten._pin_memory.out, + aten._resize_output.default, + aten._resize_output.out, +) + +# this op is never actually used +_non_kwarg_device_constructors = (aten._list_to_tensor,) + + +def contains_tensor_types(type): + tensor_type = torch._C.TensorType.get() + return type.isSubtypeOf(tensor_type) or any( + contains_tensor_types(e) for e in type.containedTypes() + ) + + +@functools.cache +def _is_tensor_constructor(func: OpOverload): + assert isinstance(func, OpOverload) + schema = func._schema + if any(contains_tensor_types(arg.type) for arg in schema.arguments): + return False + # TODO: no real reason to restrict multiple outputs + return ( + len(schema.returns) == 1 and schema.returns[0].type is torch._C.TensorType.get() + ) + + +def register_op_impl(run_impl_check: Union[Callable[[OpOverload], bool], OpOverload]): + def impl_decorator(op_impl): + if isinstance(run_impl_check, OpOverload): + assert run_impl_check not in op_implementations_dict, ( + f"duplicate registration: {run_impl_check}" + ) + op_implementations_dict[run_impl_check] = op_impl + elif isinstance(run_impl_check, (list, tuple)): + for op in run_impl_check: + register_op_impl(op)(op_impl) + else: + assert callable(run_impl_check) + op_implementations_checks.append((run_impl_check, op_impl)) + + return op_impl + + return impl_decorator + + +def _is_op_registered_to_fake_rule(op): + return op in op_implementations_dict + + +def _deregister_op_impl(op): + op_implementations_dict.pop(op, None) + for check, impl in op_implementations_checks: + if check is op: + op_implementations_checks.remove((check, impl)) + break + + +@register_op_impl(op_implementations_dict.__contains__) +def dispatch_to_op_implementations_dict(fake_mode, func, *args, **kwargs): + return op_implementations_dict[func](fake_mode, func, *args, **kwargs) + + +@register_op_impl(_is_tensor_constructor) +@register_op_impl([*_like_tensor_constructors]) +def constructors(fake_mode, func, *args, **kwargs): + assert func not in 
_non_kwarg_device_constructors + _, new_kwargs = normalize_function( + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + if "names" in kwargs: + raise UnsupportedOperatorException( + "torch.compile doesn't support named tensors" + ) + + if func in _like_tensor_constructors: + default_device = new_kwargs["input"].device + # TODO: file issue + args = (new_kwargs.pop("input"),) + else: + # cpu is default device if none is specified + default_device = torch.device("cpu") + args = () + out_device = new_kwargs.pop("device", None) + out_device = out_device if out_device is not None else default_device + new_kwargs["device"] = torch.device("meta") + # _like constructors have fake tensor inputs (maybe this causes the non-like + # to fail? hmmm) + with in_kernel_invocation_manager(fake_mode): + r = func(*args, **new_kwargs) + return FakeTensor(fake_mode, r, out_device) + + +@register_op_impl(aten.is_pinned.default) +def non_kwarg_is_pinned(fake_mode, func, *args, **kwargs): + _, new_kwargs = normalize_function( + func, args, kwargs, normalize_to_only_use_kwargs=True + ) + inp = new_kwargs.pop("input") + # we'll ignore device argument because it is deprecated and not + # actually used by is_pinned. 
+ with in_kernel_invocation_manager(fake_mode): + r = func(inp) + return r + + +@register_op_impl(aten.to.prim_Device) +@register_op_impl(aten.to.device) +def non_kwarg_to(fake_mode, func, *args, **kwargs): + _, new_kwargs = normalize_function( + func, args, kwargs, normalize_to_only_use_kwargs=True + ) + input_device = new_kwargs["device"] + out_device = input_device if input_device else new_kwargs["input"].device + new_kwargs["device"] = torch.device("meta") + inp = new_kwargs.pop("input") + with in_kernel_invocation_manager(fake_mode): + r = func(inp, **new_kwargs) + # TODO: I think this does the wrong thing if r is inp + return fake_mode.fake_tensor_converter.from_meta_and_device( + fake_mode, r, out_device + ) + + +def stride_incorrect_op(op): + return False + + +# These operators have meta implementations with incorrect strides +@register_op_impl(stride_incorrect_op) +def wordaround_stride_incorrect_op(fake_mode, func, *args, **kwargs): + # This is a workaround for meta implementations with incorrect strides + + def is_symbolic(x): + if isinstance(x, FakeTensor): + return x._has_symbolic_sizes_strides + if isinstance(x, (torch.SymInt, torch.SymFloat, torch.SymBool)): + return True + return False + + # For static shapes, we can fall back to eager for the real strides + if fake_mode.allow_fallback_kernels: + require_dynamic = any( + is_symbolic(x) for x in itertools.chain(args, kwargs.values()) + ) + if not require_dynamic: + flat_args, args_spec = pytree.tree_flatten((args, kwargs)) + return run_fallback_kernel(fake_mode, func, flat_args, args_spec, None) + + raise UnsupportedOperatorException(func) + + +# Dont default to default device handling, +# since the device of `the_template` is ignored +@register_op_impl(aten.resize_as_.default) +def resize_as_(fake_mode, func, *args, **kwargs): + with in_kernel_invocation_manager(fake_mode): + return func(*args, **kwargs) + + +@register_op_impl(aten._sparse_coo_tensor_with_dims_and_tensors.default) +def 
_sparse_coo_tensor_with_dims_and_tensors(fake_mode, func, *args, **kwargs): + # TODO: remove me + return constructors(fake_mode, func, *args, **kwargs) + + +# index.Tensor data-dependent in only some conditions +@register_op_impl( + lambda func: torch.Tag.dynamic_output_shape in func.tags + and func + not in [aten.index.Tensor, aten.nonzero.default, aten.repeat_interleave.Tensor] +) +def dyn_shape(fake_mode, func, *args, **kwargs): + raise DynamicOutputShapeException(func) + + +def _unique( + fake_mode, + func, + arg, + dim, + sorted=True, + return_inverse=False, + return_counts=False, + *, + unique_consecutive=False, +): + if ( + fake_mode.shape_env is None + or not fake_mode.shape_env.allow_dynamic_output_shape_ops + ): + # Without symints/symfloats, cannot handle this + raise DynamicOutputShapeException(func) + + nnz = arg.unique_consecutive_memo if unique_consecutive else arg.unique_memo + + # Do not use a memo for unique_dim + if dim is not None or nnz is None: + # Avoid importing sympy at a module level + from torch.fx.experimental.symbolic_shapes import ( + _constrain_range_for_size, + has_free_symbols, + ) + + if not has_free_symbols(arg.numel()) and arg.numel() == 0: + # If numel is zero, then the output size must be zero. + # In this case, we must not allocate an unbacked SymInt, + # because if we do, it will immediately get refined to + # zero, but this will be inconsistent with size oblivious + # tests (which will continue to claim that the unbacked + # symint cannot equal zero). We could also unconditionally + # allocate an unbacked SymInt and not refine its range, + # but this seems more precise. 
+ nnz = 0 + else: + nnz = fake_mode.shape_env.create_unbacked_symint() + + maxval = sys.maxsize - 1 + + numel = arg.numel() if dim is None else arg.size(dim) + if not has_free_symbols(numel): + maxval = int(numel) + + _constrain_range_for_size(nnz, max=maxval) + + if dim is None: + if unique_consecutive: + arg.unique_consecutive_memo = nnz + else: + arg.unique_memo = nnz + + if dim is None: + ret = [arg.new_empty((nnz,))] + else: + ret = [arg.new_empty(*arg.shape[:dim], nnz, *arg.shape[dim + 1 :])] + + return_if_dim_and_cpu = dim is not None and arg.fake_device == torch.device("cpu") + if return_inverse or return_if_dim_and_cpu: + inverse = arg.new_empty(arg.shape if dim is None else (arg.shape[dim],)) + else: + inverse = arg.new_empty(0) + ret.append(inverse) + + if return_counts or return_if_dim_and_cpu: + counts = arg.new_empty(ret[0].shape if dim is None else (ret[0].shape[dim],)) + else: + counts = arg.new_empty(0) + ret.append(counts) + + return tuple(ret) + + +@register_op_impl(aten._unique2.default) +def unique2( + fake_mode, func, arg, sorted=True, return_inverse=False, return_counts=False +): + return _unique(fake_mode, func, arg, None, sorted, return_inverse, return_counts) + + +@register_op_impl(aten.select.int) +def meta_select(fake_mode, func, self, dim, index): + from torch.fx.experimental.symbolic_shapes import guard_or_false + + if self.is_sparse: + return NotImplemented + + ndim = self.dim() + torch._check_index( + ndim != 0, + lambda: "select() cannot be applied to a 0-dim tensor.", + ) + + dim = dim if dim >= 0 else dim + ndim + size = self.size(dim) + + new_size = list(self.size()) + new_stride = list(self.stride()) + + new_storage_offset = None + if guard_or_false(index >= 0): + new_storage_offset = self.storage_offset() + index * new_stride[dim] + elif guard_or_false(index < 0): + new_storage_offset = self.storage_offset() + (index + size) * new_stride[dim] + + if new_storage_offset is None: + if fake_mode.shape_env is None or ( + not 
fake_mode.shape_env.allow_scalar_outputs + and not fake_mode.allow_scalar_outputs + ): + raise DataDependentOutputException(func) + + # index is data-dependent, we do not know which index we are accessing it could be index or index+size! + # we assign a new data-dependent symbol for the storage offset. + new_storage_offset = fake_mode.shape_env.create_unbacked_symint() + + del new_size[dim] + del new_stride[dim] + assert new_storage_offset is not None + return self.as_strided(new_size, new_stride, new_storage_offset) + + +@register_op_impl(aten.unique_dim.default) +def unique_dim( + fake_mode, func, arg, dim, sorted=True, return_inverse=False, return_counts=False +): + return _unique( + fake_mode, + func, + arg, + # normalize dim to be non-negative + dim if dim >= 0 else dim % max(arg.ndim, 1), + sorted, + return_inverse, + return_counts, + ) + + +@register_op_impl(aten.unique_consecutive.default) +def _(fake_mode, func, arg, return_inverse=False, return_counts=False, dim=None): + return _unique( + fake_mode, + func, + arg, + dim, + False, + return_inverse, + return_counts, + unique_consecutive=True, + ) + + +# This function is python match of computeStride_impl in TensorUtils.cpp +def _compute_stride(old_shape, old_stride, new_shape, size_oblivious=False): + from torch.fx.experimental.symbolic_shapes import ( + guard_or_false, + guard_or_true, + sym_eq, + ) + + def maybe_guard_or_false(x): + if size_oblivious: + return guard_or_false(x) + + return x + + def maybe_guard_or_true(x): + if size_oblivious: + return guard_or_true(x) + + return x + + if len(old_shape) == 0: + return [1] * len(new_shape) + + numel = reduce(operator.mul, old_shape, 1) + zero_numel = maybe_guard_or_false(numel == 0) + if zero_numel and maybe_guard_or_false(sym_eq(old_shape, new_shape)): + return old_stride + + new_stride = [0] * len(new_shape) + + if zero_numel: + for view_d in range(len(new_shape) - 1, -1, -1): + if view_d == len(new_shape) - 1: + new_stride[view_d] = 1 + else: + 
new_stride[view_d] = ( + max(new_shape[view_d + 1], 1) * new_stride[view_d + 1] + ) + return new_stride + + view_d = len(new_shape) - 1 + chunk_base_stride = old_stride[-1] + tensor_numel = 1 + view_numel = 1 + + for tensor_d in range(len(old_shape) - 1, -1, -1): + tensor_numel *= old_shape[tensor_d] + + if tensor_d == 0 or ( + maybe_guard_or_true(old_shape[tensor_d - 1] != 1) + and maybe_guard_or_true( + old_stride[tensor_d - 1] != tensor_numel * chunk_base_stride + ) + ): + while view_d >= 0 and ( + maybe_guard_or_true(view_numel < tensor_numel) + or maybe_guard_or_false(new_shape[view_d] == 1) + ): + new_stride[view_d] = view_numel * chunk_base_stride + view_numel *= new_shape[view_d] + view_d -= 1 + + if maybe_guard_or_true(view_numel != tensor_numel): + return None + + if tensor_d > 0: + chunk_base_stride = old_stride[tensor_d - 1] + tensor_numel = 1 + view_numel = 1 + if view_d != -1: + return None + return new_stride + + +def _view_has_unbacked_input(a, shape): + from torch.fx.experimental.symbolic_shapes import has_hint + + shape = utils.extract_shape_from_varargs(shape, validate=False) + + return ( + any(not has_hint(s) for s in a.size()) + or any(not has_hint(s) for s in a.stride()) + or any(not has_hint(s) for s in shape) + ) + + +def _view_unbacked_meta(a, shape, size_oblivious_enabled=True): + from torch._prims import view_of + from torch.fx.experimental.symbolic_shapes import guard_or_false, sym_eq + + # Creates a valid shape + shape = utils.extract_shape_from_varargs(shape, validate=False) + + # Reshape may be given a shape with a -1 length + # This indicates that the dimension's length should be inferred + shape = utils.infer_size(shape, a.numel()) + + # Special-cases reshaping zero dim tensors + if a.ndim == 0: + _a = a + for length in shape: + torch._check(length == 1) + _a = torch._refs.unsqueeze(_a, -1) + if _a is a: + return view_of(a) + else: + return _a + + # Special-cases reshaping to zero dim tensors + if len(shape) == 0: + _a = a + for 
length in a.shape: + torch._check(length == 1) + _a = torch._refs.squeeze(_a, -1) + if _a is a: + return view_of(a) + else: + return _a + + shape_numel = reduce(operator.mul, shape, 1) + + torch._check( + a.numel() == shape_numel, + lambda: f"Could not reshape a tensor with shape {a.shape} as a tensor with shape {shape}!", + ) + + if len(shape) == len(a.shape) and guard_or_false(sym_eq(shape, a.shape)): + return view_of(a) + + if is_contiguous_or_false(a) if size_oblivious_enabled else is_contiguous(a): + strides = make_contiguous_strides_for(shape) + return a.as_strided(shape, strides) + + new_strides = _compute_stride( + a.size(), a.stride(), shape, size_oblivious=size_oblivious_enabled + ) + + if new_strides is not None: + return a.as_strided(shape, new_strides) + + # If we fail to do size oblivious view, and backed_size_oblivious was on, + # then we redo everything by looking at hints and guarding instead of failing. + # Also if the expression has unbacked symbols, then we run again with size_oblivious_enabled=False + # to throw a data dependent error. + + if size_oblivious_enabled and ( + torch.fx.experimental._config.backed_size_oblivious + or _view_has_unbacked_input(a, shape) + ): + return _view_unbacked_meta(a, shape, size_oblivious_enabled=False) + + msg = f"Cannot view a tensor with shape {a.shape} and strides {a.stride()} as a tensor with shape {shape}!" 
+ raise ValueError(msg) + + +@register_op_impl(aten._reshape_copy.default) +def _reshape_copy(fake_mode, func, a, *shape): + if a.is_sparse or a.is_mkldnn: + return NotImplemented + + shape = utils.infer_size(*shape, a.numel()) + if is_contiguous_or_false(a): + view = _view_meta(fake_mode, func, a, *shape) + return view.clone(memory_format=torch.contiguous_format) + else: + return _view_meta( + fake_mode, func, a.clone(memory_format=torch.contiguous_format), *shape + ) + + +@register_op_impl(aten.view.default) +@register_op_impl(aten._unsafe_view.default) +def _view_meta(fake_mode, func, a, *shape): + if torch.fx.experimental._config.backed_size_oblivious or _view_has_unbacked_input( + a, shape + ): + return _view_unbacked_meta(a, shape) + else: + return torch._refs._reshape_view_helper(a, *shape, allow_copy=False) + + +@register_op_impl(aten.view_copy.default) +def _view_meta_copy(fake_mode, func, a, *shape, out=None): + result = _view_meta(fake_mode, func, a, *shape) + if out is not None: + return result + + return pytree.tree_map( + lambda x: x.clone(memory_format=torch.contiguous_format), + result, + ) + + +@register_op_impl(aten.repeat_interleave.Tensor) +def repeat_interleave_tensor(fake_mode, func, repeats, output_size=None): + if output_size is None: + if ( + fake_mode.shape_env is None + or not fake_mode.shape_env.allow_dynamic_output_shape_ops + ): + raise DynamicOutputShapeException(func) + + output_size = fake_mode.shape_env.create_unbacked_symint() + + # Avoid importing sympy at a module level + from torch.fx.experimental.symbolic_shapes import _constrain_range_for_size + + _constrain_range_for_size(output_size) + # TODO: consider a memo + return repeats.new_empty(output_size) + + +@register_op_impl(torch.ops.aten.item.default) +@register_op_impl(torch.ops.aten._local_scalar_dense.default) +def local_scalar_dense(fake_mode, func, arg): + if (r := arg.item_memo) is not None: + return r + if fake_mode.shape_env is None or ( + not 
fake_mode.shape_env.allow_scalar_outputs + and not fake_mode.allow_scalar_outputs + ): + # Without symints/symfloats, cannot handle this + raise DataDependentOutputException(func) + if is_float_dtype(arg.dtype): + r = fake_mode.shape_env.create_unbacked_symfloat() + elif is_integer_dtype(arg.dtype): + r = fake_mode.shape_env.create_unbacked_symint() + elif is_boolean_dtype(arg.dtype): + r = fake_mode.shape_env.create_unbacked_symbool() + else: + raise NotImplementedError(f"local_scalar_dense/item NYI for {arg.dtype}") + arg.item_memo = r + return r + + +@register_op_impl(torch.ops.aten.nonzero_numpy.default) +def nonzero_numpy(fake_mode, func, arg): + return torch.ops.aten.nonzero.default(arg).unbind(1) + + +@register_op_impl(torch.ops.aten.nonzero.default) +def nonzero(fake_mode, func, arg): + if ( + fake_mode.shape_env is None + or not fake_mode.shape_env.allow_dynamic_output_shape_ops + ): + # Without symints/symfloats, cannot handle this + raise DynamicOutputShapeException(func) + + if (nnz := arg.nonzero_memo) is None: + # Avoid importing sympy at a module level + from torch.fx.experimental.symbolic_shapes import ( + _constrain_range_for_size, + has_free_symbols, + ) + from torch.utils._sympy.numbers import IntInfinity + from torch.utils._sympy.value_ranges import bound_sympy + + if not has_free_symbols(arg.numel()) and arg.numel() == 0: + # If numel is zero, then the output size must be zero. + # In this case, we must not allocate an unbacked SymInt, + # because if we do, it will immediately get refined to + # zero, but this will be inconsistent with size oblivious + # tests (which will continue to claim that the unbacked + # symint cannot equal zero). We could also unconditionally + # allocate an unbacked SymInt and not refine its range, + # but this seems more precise. 
+ nnz = 0 + else: + nnz = fake_mode.shape_env.create_unbacked_symint() + + maxval = sys.maxsize - 1 + + if not has_free_symbols(arg.numel()): + maxval = int(arg.numel()) + else: + prod_node = math.prod(arg.shape).node + prod_range = bound_sympy( + prod_node.expr, prod_node.shape_env.var_to_range + ) + if isinstance(prod_range.upper, IntInfinity): + maxval = sys.maxsize - 1 + else: + maxval = prod_range.upper + + _constrain_range_for_size(nnz, max=maxval) + + arg.nonzero_memo = nnz + + return arg.new_empty_strided((nnz, arg.dim()), (1, nnz), dtype=torch.int64) + + +@register_op_impl(torch.ops.aten._padded_dense_to_jagged_forward.default) +def _padded_dense_to_jagged_forward(fake_mode, func, padded, offsets, total_L=None): + # only one jagged dim is supported for now + assert len(offsets) == 1 + + if not total_L: + if ( + fake_mode.shape_env is None + or not fake_mode.shape_env.allow_dynamic_output_shape_ops + ): + # Without symints/symfloats, cannot handle this + raise DynamicOutputShapeException(func) + + total_L = fake_mode.shape_env.create_unbacked_symint() + + maxval = sys.maxsize - 1 + + # Avoid importing sympy at a module level + from torch.fx.experimental.symbolic_shapes import ( + _constrain_range_for_size, + has_free_symbols, + ) + + if not has_free_symbols(padded.numel()): + maxval = int(padded.numel()) + + _constrain_range_for_size(total_L, min=0, max=maxval) + + output_shape = (total_L, *padded.shape[2:]) + return padded.new_empty(output_shape) + + +def _compute_slice_index(size, index): + from torch.fx.experimental.symbolic_shapes import guard_or_false, sym_and + + if guard_or_false(sym_and(index >= 0, index <= size)): + return index + elif guard_or_false(sym_and(index < 0, index >= -size)): + return index + size + elif guard_or_false(index < -size): + return 0 + elif guard_or_false(index > size): + return size + return None + + +@register_op_impl(torch.ops.aten.slice.Tensor) +def slice_forward( + fake_mode, + func, + self, + dim: int = 0, + start: 
Optional[int] = None, + end: Optional[int] = None, + step: int = 1, +): + from torch.fx.experimental.symbolic_shapes import ( + guard_or_false, + statically_known_true, + ) + + shape_env = fake_mode.shape_env + + ndim = self.dim() + if ndim == 0: + raise RuntimeError("slice() cannot be applied to a 0-dim tensor.") + dim = canonicalize_dim(self.dim(), dim) + sizes = list(self.size()) + strides = list(self.stride()) + + if step <= 0: + raise RuntimeError("slice step must be positive") + + # start, end + start_index = 0 if start is None else _compute_slice_index(sizes[dim], start) + end_index = ( + sizes[dim] + if statically_known_true(end == sys.maxsize) or end is None + else _compute_slice_index(sizes[dim], end) + ) + + # size + new_size = None + if start_index is not None and end_index is not None: + if guard_or_false(end_index >= start_index): + new_size = (end_index - start_index + step - 1) // step + elif guard_or_false(start_index >= end_index): + new_size = 0 + + # create unbacked if case unknown + if new_size is None: + new_size = shape_env.create_unbacked_symint() + torch._check(new_size >= 0) + torch._check(new_size <= sizes[dim]) + + # stride + new_stride = strides[dim] * step + + # storage offset + if start_index is not None: + storage_offset = self.storage_offset() + start_index * strides[dim] + else: + storage_offset = shape_env.create_unbacked_symint() + torch._check(storage_offset >= 0) + + sizes[dim] = new_size + strides[dim] = new_stride + if self.is_quantized: + raise NotImplementedError( + "Slice decomposition for quantized tensors aren't implemented" + ) + else: + return self.as_strided(sizes, strides, storage_offset) + + +@register_op_impl(torch.ops.aten.masked_select.default) +def masked_select(fake_mode, func, self, mask): + if ( + fake_mode.shape_env is None + or not fake_mode.shape_env.allow_dynamic_output_shape_ops + ): + # Without symints/symfloats, cannot handle this + raise DynamicOutputShapeException(func) + + nnz = 
fake_mode.shape_env.create_unbacked_symint() + + # see nonzero for commentary + maxval = sys.maxsize - 1 + + # Avoid importing sympy at a module level + from torch.fx.experimental.symbolic_shapes import ( + _constrain_range_for_size, + has_free_symbols, + ) + from torch.utils._sympy.numbers import IntInfinity + from torch.utils._sympy.value_ranges import bound_sympy + + # If num elements is expressed symbolically, calculate + # the concrete value based on upper bounds. Otherwise, + # we can set max val directly. + if not has_free_symbols(self.numel()): + num_elements = int(self.numel()) + else: + prod_node = math.prod(self.shape).node + prod_range = bound_sympy(prod_node.expr, prod_node.shape_env.var_to_range) + if isinstance(prod_range.upper, IntInfinity): + num_elements = sys.maxsize - 1 + else: + num_elements = prod_range.upper + if num_elements > 2: + maxval = num_elements + + _constrain_range_for_size(nnz, max=maxval) + + return self.new_empty((nnz,)) + + +@register_op_impl(torch.ops.aten._assert_tensor_metadata.default) +def assert_tensor_metadata( + fake_mode, + func, + t, + sizes=None, + strides=None, + dtype=None, + *, + device=None, + layout=None, +) -> None: + if sizes is not None: + assert t.size() == sizes, ( + f"Tensor sizes mismatch! Expected: {sizes}, Got: {t.size()}" + ) + if strides is not None: + assert t.stride() == strides, ( + f"Tensor strides mismatch! Expected: {strides}, Got: {t.stride()}" + ) + if dtype is not None: + assert t.dtype == dtype, ( + f"Tensor dtype mismatch! Expected: {dtype}, Got: {t.dtype}" + ) + if layout is not None: + assert t.layout == layout, ( + f"Tensor layout mismatch! Expected: {layout}, Got: {t.layout()}" + ) + if device is not None: + assert t.device == device, ( + f"Tensor device mismatch! 
Expected: {device}, Got: {t.device}" + ) + + +# NB: this must be ordered after local_scalar_dense +@register_op_impl(lambda func: torch.Tag.data_dependent_output in func.tags) +def data_dep(fake_mode, func, *args, **kwargs): + raise DataDependentOutputException(func) + + +# Bool Indices get Expanded as Masks +# See: IndexingUtils.h:expandTensors +def check_no_bool_index_tensors(func, self, indices): + for index in indices: + if index is not None and index.dtype in (torch.bool, torch.uint8): + raise DynamicOutputShapeException(func) + + +def run_and_return_new_tensor_of_input_device(fake_mode, func, args, kwargs): + _, new_kwargs = normalize_function( + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + out_device = new_kwargs["input"].device + with in_kernel_invocation_manager(fake_mode): + out = func(*args, **kwargs) + if not is_noncontiguous_supported(out_device): + out = out.new_empty(out.shape) + + if out is new_kwargs["input"]: + return out # copy_ + return FakeTensor(fake_mode, out, out_device) + + +_is_builtin_namespaces = ordered_set("aten", "prims", "prim") + + +def is_builtin(op): + return op.namespace in _is_builtin_namespaces + + +def has_meta(func): + return torch._C._dispatch_has_computed_kernel_for_dispatch_key(func.name(), "Meta") + + +# These are for the `torch._foreach_...` ops like `torch._foreach_add`. 
+@register_op_impl( + lambda func: is_builtin(func) + and func.name().startswith("aten::_foreach_") + and has_meta(func) +) +def foreach_run_and_map_input_device(fake_mode, func, *args, **kwargs): + tensor_lists = [ + arg + for arg in itertools.chain(args, kwargs.values()) + if isinstance(arg, (list, tuple)) + and len(arg) + and isinstance(arg[0], torch.Tensor) + ] + + try: + with in_kernel_invocation_manager(fake_mode): + out_meta = func(*args, **kwargs) + except NotImplementedError: + return NotImplemented + + if not out_meta: + return out_meta + + assert tensor_lists + out_fake = [] + + for i, meta_t in enumerate(out_meta): + device, _ = FakeTensor._find_common_device(func, [tl[i] for tl in tensor_lists]) + out_fake.append( + fake_mode.fake_tensor_converter.from_meta_and_device( + fake_mode, meta_t, device + ) + ) + + return out_fake + + +# Dont default to default device handling, +# Since op can take in non-zero sized cpu +# index tensors with cuda self +@register_op_impl(aten.index.Tensor) +def index_tensor(fake_mode, func, *args, **kwargs): + from torch._meta_registrations import meta_index_Tensor + + _, new_kwargs = normalize_function( + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + out_device = new_kwargs["input"].device + # ensure nonzero call goes to fake tensor + with fake_mode: + out = meta_index_Tensor(*args, **kwargs) + return out.to(out_device) + + +# Can take mixed meta/non-meta arguments; the meta registration +# will roughly do the right thing even when given real devices +@register_op_impl(aten._embedding_bag.default) +def embedding_bag(fake_mode, func, *args, **kwargs): + from torch._meta_registrations import meta_embedding_bag + + with fake_mode: + return meta_embedding_bag(*args, **kwargs) + + +# takes in multiple-devices, dont default to default device handling +@register_op_impl(aten._unsafe_index_put.default) +@register_op_impl(aten.copy.default) +@register_op_impl(aten.copy_.default) 
+@register_op_impl(aten.slice_scatter.default) +def multi_device_op_default(fake_mode, func, *args, **kwargs): + return run_and_return_new_tensor_of_input_device(fake_mode, func, args, kwargs) + + +# same with multi_device_op_default, but return the input +@register_op_impl(aten.copy.out) +@register_op_impl(aten.slice_scatter.out) +def multi_device_op_out(fake_mode, func, *args, **kwargs): + with in_kernel_invocation_manager(fake_mode): + func(*args, **kwargs) + + _, new_kwargs = normalize_function( + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + return new_kwargs["input"] + + +@register_op_impl(aten.index_put.default) +@register_op_impl(aten.index_put_.default) +def index_put_impl(fake_mode, func, *args, **kwargs): + _, new_kwargs = normalize_function( + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + values = new_kwargs["values"] + self_device = new_kwargs["input"].fake_device + torch._check( + self_device == values.fake_device or (values.ndim == 0 and values.numel() == 1), + lambda: f"Mismatching {func} device between self ({self_device}) and values ({values.device})", + ) + + out = run_and_return_new_tensor_of_input_device(fake_mode, func, args, kwargs) + if func is aten.index_put_.default: + return new_kwargs["input"] + else: + return out + + +@register_op_impl(aten._nested_tensor_from_tensor_list.default) +@register_op_impl(aten._nested_tensor_from_tensor_list.out) +@register_op_impl(aten._nested_view_from_buffer.default) +@register_op_impl(aten._nested_view_from_buffer_copy.default) +def nested_tensors_unsupported(fake_mode, func, *args, **kwargs): + raise UnsupportedOperatorException( + "torch.compile does not support strided NestedTensor" + ) + + +@register_op_impl( + [ + x + for x in _device_not_kwarg_ops + if x + not in ( + # these are already registered elsewhere + aten.is_pinned.default, + aten.to.device, + aten.to.prim_Device, + aten._nested_tensor_from_tensor_list.default, + 
aten._nested_tensor_from_tensor_list.out, + ) + ] +) +def nyi(fake_mode, func, *args, **kwargs): + assert func not in _device_not_kwarg_ops, f"NYI: {func}" + + +@register_op_impl([aten.convolution.default, aten.convolution_backward.default]) +def conv(fake_mode, func, *args, **kwargs): + _, kwargs = normalize_function( + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + device = kwargs["input"].fake_device + # need to re-enable mode so the tensors report fake device + with fake_mode: + # if the input is unsqueezed is done in Convolution.cpp we get segfault + k = kwargs["weight"].ndim + batch = kwargs["input"].shape[0] + + # Avoid importing sympy at a module level + from torch.fx.experimental.symbolic_shapes import has_hint + + if not has_hint(batch): + # TODO: We can make this a little more faithful with best effort + # channels last detection (but only if it's statically obvious!) + mem_fmt = None + else: + if func is aten.convolution.default: + conv_backend = torch._C._select_conv_backend(**kwargs) + else: + conv_backend = torch._C._select_conv_backend( + kwargs["input"], + kwargs["weight"], + bias=None, + stride=kwargs["stride"], + padding=kwargs["padding"], + dilation=kwargs["dilation"], + transposed=kwargs["transposed"], + output_padding=kwargs["output_padding"], + groups=kwargs["groups"], + bias_sizes=kwargs["bias_sizes"], + ) + # Expand 1d -> 2d. + # Note: Avoid expanding before calling _select_conv_backend, + # as the function handles 2D expansion internally. + if k == 3 and not kwargs["input"].is_mkldnn and not kwargs["input"].is_xpu: + # Note: Using input.to(memory_format=contiguous) does not work. 
+ kwargs["input"] = kwargs["input"].contiguous().unsqueeze(2) + kwargs["weight"] = kwargs["weight"].unsqueeze(2) + if len(kwargs["stride"]) == 1: + kwargs["stride"].insert(0, 1) + kwargs["padding"].insert(0, 0) + kwargs["dilation"].insert(0, 1) + kwargs["output_padding"].insert(0, 0) + mem_fmt = torch._C._conv_determine_backend_memory_format( + kwargs["input"], kwargs["weight"], conv_backend + ) + # revert 2d -> 1d + if k == 3 and not kwargs["input"].is_mkldnn and not kwargs["input"].is_xpu: + kwargs["input"] = kwargs["input"].squeeze(2) + kwargs["weight"] = kwargs["weight"].squeeze(2) + if len(kwargs["stride"]) == 2: + kwargs["stride"].pop(0) + kwargs["padding"].pop(0) + kwargs["dilation"].pop(0) + kwargs["output_padding"].pop(0) + + def convert(t, mem_fmt): + if t is None: + return t + if mem_fmt is not None: + # channels last only support 4d, try to expand dim then convert it back later. + if t.dim() == 3 and mem_fmt == torch.channels_last: + t = t.unsqueeze(2).to(memory_format=mem_fmt).squeeze(2) + else: + t = t.to(memory_format=mem_fmt) + return FakeTensor(fake_mode, t, device) + + with in_kernel_invocation_manager(fake_mode): + out = func(**kwargs) + + if func is aten.convolution.default: + return convert(out, mem_fmt) + else: + return ( + convert(out[0], mem_fmt), + convert(out[1], mem_fmt), + convert(out[2], None), + ) + + +@register_op_impl(torch.ops.aten.bincount.default) +def bincount(fake_mode, func, inputs, weights=None, minlength=0): + if ( + fake_mode.shape_env is None + or not fake_mode.shape_env.allow_dynamic_output_shape_ops + ): + # Without symints/symfloats, cannot handle this + raise DynamicOutputShapeException(func) + + new_size = fake_mode.shape_env.create_unbacked_symint() + + from torch.fx.experimental.symbolic_shapes import _constrain_range_for_size + + _constrain_range_for_size(new_size) + torch._check(new_size >= minlength) + return inputs.new_empty(new_size) + + +@register_op_impl(torch.ops.aten._pack_padded_sequence.default) +def 
_pack_padded_sequence(fake_mode, func, inputs, lengths, batch_first): + if ( + fake_mode.shape_env is None + or not fake_mode.shape_env.allow_dynamic_output_shape_ops + ): + # Without symints/symfloats, cannot handle this + raise DynamicOutputShapeException(func) + + new_batch_size = fake_mode.shape_env.create_unbacked_symint() + + from torch.fx.experimental.symbolic_shapes import _constrain_range_for_size + + _constrain_range_for_size(new_batch_size) + + if not batch_first: + # Inputs should have shape (batch_size, seq_len, *) + inputs = inputs.transpose(0, 1) + + res_size = inputs.shape[1:] + packed_data = inputs.new_empty(res_size) + batch_size = inputs.new_empty((new_batch_size,)) + return (packed_data, batch_size) + + +FAST_OP_IMPLEMENTATIONS = {} + + +# Unlike register_op_impl, these don't do the slow iteration for +# run_impl_check, and these run BEFORE decompositions +def register_fast_op_impl(func: OpOverload): + def impl_decorator(op_impl): + FAST_OP_IMPLEMENTATIONS[func] = op_impl + return op_impl + + return impl_decorator + + +# infer_size_impl in ExpandUtils +def infer_size(a, b): + from torch.fx.experimental.symbolic_shapes import guard_or_false + + dimsA = len(a) + dimsB = len(b) + ndim = max(dimsA, dimsB) + expandedSizes = [0] * ndim + for i in range(ndim - 1, -1, -1): + offset = ndim - 1 - i + dimA = dimsA - 1 - offset + dimB = dimsB - 1 - offset + sizeA = a[dimA] if dimA >= 0 else 1 + sizeB = b[dimB] if dimB >= 0 else 1 + + # NB: It is very important to test for broadcasting, before testing + # sizeA == sizeB. This is because the broadcasting tests are likely + # to be statically known (in particular, if sizeA/sizeB is unbacked + # but size-like, we will unsoundly assume they never equal 1), but + # the sizeA == sizeB test may not be statically known. 
However, once + # we have established that no broadcasting is happening, the + # sizeA == sizeB is now expect_true and we can defer it as a runtime + # assert (this works because Python will return the terminal + # expression of an or statement as-is, without bool()'ing it; if this + # were not the case, we'd need to write this using torch.sym_or() or + # something like that). + torch._check( + guard_or_false(sizeA == 1) or guard_or_false(sizeB == 1) or sizeA == sizeB, + lambda: f"The size of tensor a ({sizeA}) " + f"must match the size of tensor b ({sizeB}) " + f"at non-singleton dimension {i})", + ) + expandedSizes[i] = sizeB if guard_or_false(sizeA == 1) else sizeA + return tuple(expandedSizes) + + +def make_fast_binary_impl( + slow_ref, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT +): + def fast_binary_impl(mode, *args, **kwargs): + def slow(msg): + count_label(f"slow {msg}") + with mode: + return slow_ref(*args, **kwargs) + + count_label("attempt fast") + + # Fast path (based off of TensorIterator fast path). + # Unfortunately, there is no way to easily deduplicate + # this with either the TensorIterator C++ implementation + # (which we don't want to SymIntify, and also the algorithm + # here is slightly different from TensorIterator to allow + # for broadcasting), nor the PrimTorch implementation + # (which does not actually implement a fast path.) 
+ + operands = args + + # compute_shape + final_shape = None + for op in operands: + shape = op.shape if isinstance(op, torch.Tensor) else () + if final_shape is None: + final_shape = shape + # TODO: Minor optimization: track if the shapes + # were equal so you can skip the equality check + # below if unnecessary + final_shape = infer_size(final_shape, shape) + assert final_shape is not None + + from torch.fx.experimental.symbolic_shapes import guard_or_false, sym_eq + + # Do some extra safety checks to see if the output + # stride is obvious + for op in operands: + if ( + isinstance(op, torch.Tensor) + and len(op.shape) == len(final_shape) + # take the slow path if result is not determined. + and guard_or_false(sym_eq(op.shape, final_shape)) + ): + break + else: + # if we never break in the for loop above we take the slow path. + return slow("both tensors nontrivially broadcast") + + # compute_types + cpu = torch.device("cpu") + common_device = cpu + common_dtype = None + has_different_input_dtypes = False + for op in operands: + if not isinstance(op, torch.Tensor): + # Use elementwise_dtypes for the tricky case + has_different_input_dtypes = True + continue + if common_device == cpu and op.device.type != "cpu": + common_device = op.device + if common_dtype is None: + if type_promotion_kind != ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT: + has_different_input_dtypes = True + else: + common_dtype = op.dtype + elif common_dtype != op.dtype: + has_different_input_dtypes = True + + if has_different_input_dtypes: + # compute promotion + # TODO: we don't need the compute type + _, common_dtype = elementwise_dtypes( + *operands, type_promotion_kind=type_promotion_kind + ) + + # check all tensors on same device + # cpu scalars are assumed allow + current_cpu_scalars_on_non_cpu = 0 + max_cpu_scalars_on_non_cpu = 1 # hard coded atm + for op in operands: + if not isinstance(op, torch.Tensor): + continue + if common_device != cpu and op.dim() == 0 and op.device == cpu: + if 
current_cpu_scalars_on_non_cpu >= max_cpu_scalars_on_non_cpu: + return slow("error") + current_cpu_scalars_on_non_cpu += 1 + elif op.device != common_device: + return slow("error") + + # compute_fast_setup_type + definitely_contiguous = True + definitely_channels_last = True + + # TODO: is_non-overlapping_and_dense not bound from Python + # no inplace, no out, everything defined + + if is_noncontiguous_supported(common_device): + for op in operands: + if not isinstance(op, torch.Tensor): + continue + definitely_contiguous = ( + definitely_contiguous + and is_contiguous_for_memory_format_or_false( + op, memory_format=torch.contiguous_format + ) + ) + definitely_channels_last = ( + definitely_channels_last + and is_contiguous_for_memory_format_or_false( + op, memory_format=torch.channels_last + ) + ) + if definitely_contiguous: + # do contiguous + count_label("fast is_contiguous") + return FakeTensor( + mode, + torch.empty( + final_shape, + dtype=common_dtype, + device="meta", + memory_format=torch.contiguous_format, + ), + device=common_device, + ) + if definitely_channels_last: + count_label("fast channels_last") + # do channels last + return FakeTensor( + mode, + torch.empty( + final_shape, + dtype=common_dtype, + device="meta", + memory_format=torch.channels_last, + ), + device=common_device, + ) + + return slow("no contiguity match") + + return fast_binary_impl + + +# disable the python dispatcher to avoid decomposing detach() further +# (proxy_mode should still decompose detach() though) +def fast_detach(fake_mode, x, include_real=False): + with no_python_dispatcher(), in_kernel_invocation_manager(fake_mode): + out = torch.ops.aten.detach.default(x) + if include_real: + return FakeTensor(fake_mode, out, x.device, real_tensor=x.real_tensor) + return FakeTensor(fake_mode, out, x.device) + + +@functools.cache +def get_fast_op_impls(): + import torch._refs + + register_fast_op_impl(torch.ops.aten.add.Tensor)( + make_fast_binary_impl(torch._refs.add) + ) + 
register_fast_op_impl(torch.ops.aten.sub.Tensor)( + make_fast_binary_impl(torch._refs.sub) + ) + register_fast_op_impl(torch.ops.aten.mul.Tensor)( + make_fast_binary_impl(torch._refs.mul) + ) # type: ignore[has-type] + register_fast_op_impl(torch.ops.aten.div.Tensor)( + make_fast_binary_impl( + torch._refs.div, + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + ) + ) + register_fast_op_impl(torch.ops.aten.detach.default)(fast_detach) + return FAST_OP_IMPLEMENTATIONS diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py new file mode 100644 index 0000000000000000000000000000000000000000..23d222c5165e4e66775d117163c4b896144b32b9 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py @@ -0,0 +1,3422 @@ +# mypy: allow-untyped-decorators +from __future__ import annotations + +import atexit +import contextlib +import dataclasses +import functools +import logging +import math +import os +import threading +import traceback +import types +import typing +import weakref +from collections import defaultdict +from dataclasses import dataclass +from typing import ( + Any, + cast, + Literal, + Optional, + TYPE_CHECKING, + TypeGuard, + TypeVar, + Union, +) +from typing_extensions import Self +from weakref import ReferenceType + +import torch +import torch._library.utils as library_utils +from torch import SymBool, SymFloat, SymInt, Tensor +from torch._C._functorch import is_functorch_wrapped_tensor, is_legacy_batchedtensor +from torch._library.fake_class_registry import FakeScriptObject +from torch._library.fake_profile import MissingOpProfile +from torch._logging import dtrace_structured +from torch._prims_common import suggest_memory_format +from torch._subclasses.meta_utils import ( + assert_eq, + assert_metadata_eq, + is_sparse_any, + is_sparse_compressed, + 
MetaConverter, +) +from torch._utils import render_call +from torch.fx.immutable_collections import immutable_dict +from torch.fx.operator_schemas import normalize_function +from torch.multiprocessing.reductions import StorageWeakRef +from torch.overrides import TorchFunctionMode +from torch.types import IntLikeType, py_sym_types +from torch.utils._mode_utils import no_dispatch +from torch.utils._python_dispatch import ( + is_traceable_wrapper_subclass, + TorchDispatchMode, +) +from torch.utils._pytree import KeyPath, keystr, PyTree, tree_map, tree_map_, TreeSpec +from torch.utils._stats import count +from torch.utils._traceback import CapturedTraceback + +from ._fake_tensor_utils import _CacheKeyState, _PySymInputStub, _SymIntOutputStub + + +if TYPE_CHECKING: + from collections.abc import Callable, Generator, Iterable, Mapping, Sequence + from types import TracebackType + + from torch._guards import Source + from torch._ops import OpOverload + from torch.fx.experimental.symbolic_shapes import ShapeEnv, SymbolicContext + +log = logging.getLogger(__name__) +hc_log = torch._logging.getArtifactLogger(__name__, "hierarchical_compile") + +# TODO: Hack to unblock https://github.com/pytorch/pytorch/pull/108186 +# Proper fix tracked by https://github.com/pytorch/pytorch/issues/120105 +try: + not_implemented_log = torch._logging.getArtifactLogger(__name__, "not_implemented") +except ValueError as e: + if "'not_implemented' not registered" in str(e): + not_implemented_log = logging.getLogger(__name__ + ".not_implemented") + else: + raise e + + +DimList = list + +pytree = torch.utils._pytree +T = TypeVar("T") + +aten = torch._ops.ops.aten + +CONSTANT_NUMEL_LIMIT = 1 + +RECURSION_COUNT = 0 + + +# Small helper that increments recursion count, and +# resets it when the object goes out of scope. Useful +# if you don't want to increase indentation which is +# what a context manager would do. 
+class IncrementRecursionCount: + def __init__(self) -> None: + global RECURSION_COUNT + RECURSION_COUNT += 1 + + def __del__(self) -> None: + global RECURSION_COUNT + RECURSION_COUNT -= 1 + + +@dataclass +class UnsupportedFakeTensorException(RuntimeError): + reason: str + + +@dataclass +class DynamicOutputShapeException(RuntimeError): + func: OpOverload + + +@dataclass +class DataDependentOutputException(RuntimeError): + func: OpOverload + + +@dataclass +class UnsupportedOperatorException(RuntimeError): + func: OpOverload + + +@dataclass +class UnsupportedMutationAliasingException(RuntimeError): + reason: str + + +@dataclass +class MetadataMismatchError(RuntimeError): + reason: str + + +class FakeTensorTLS(threading.local): + # Default to None, otherwise it'll be used to override _all_ + # `FakeTensorMode.allow_non_fake_inputs` in this thread. + allow_non_fake_inputs_override: Optional[bool] + non_strict_export_fake_tensor_tracker: weakref.WeakSet + + def __init__(self) -> None: + self.allow_non_fake_inputs_override = None + self.non_strict_export_fake_tensor_tracker = weakref.WeakSet() + + +fake_tensor_tls = FakeTensorTLS() + + +def ordered_set(*items: T) -> dict[T, Literal[True]]: + return dict.fromkeys(items, True) + + +@contextlib.contextmanager +def unset_fake_temporarily() -> Generator[Optional[TorchDispatchMode], None, None]: + old = torch._C._unset_dispatch_mode(torch._C._TorchDispatchModeKey.FAKE) + try: + yield old + finally: + if old is not None: + torch._C._set_dispatch_mode(old) + + +@contextlib.contextmanager +def disable_fake_tensor_cache(fake_mode: FakeTensorMode) -> Generator[None, None, None]: + old_value: bool = fake_mode.cache_enabled + try: + fake_mode.cache_enabled = False + yield + finally: + fake_mode.cache_enabled = old_value + + +def get_plain_tensors( + subclass: Tensor, *, out: list[Union[Tensor, int, SymInt]] +) -> list[Union[Tensor, int, SymInt]]: + # This function is used in Runtime, do not add redundant asserts + todo = [subclass] + 
while todo: + curr = todo.pop() + if not is_traceable_wrapper_subclass(curr): + out.append(curr) + continue + + inner_keys, _ = curr.__tensor_flatten__() + todo.extend(getattr(curr, key) for key in reversed(inner_keys)) + + return out + + +def is_fake(x: object) -> TypeGuard[Tensor]: + from torch._subclasses.functional_tensor import FunctionalTensor + + if isinstance(x, FakeTensor): + return True + if is_traceable_wrapper_subclass(x): + attrs, _ = type(x).__tensor_flatten__(x) + flattened_tensors = [getattr(x, attr) for attr in attrs] + all_fake = all(is_fake(x) for x in flattened_tensors) + any_fake = any(is_fake(x) for x in flattened_tensors) + assert all_fake == any_fake, "got mixed fake and real tensors!" + return all_fake + elif isinstance(x, FunctionalTensor): + return is_fake(x.elem) + elif isinstance(x, Tensor) and torch._is_functional_tensor(x): + reapply_views = torch._C._functionalization_reapply_views_tls() + unwrapped = torch._C._functorch._unwrap_functional_tensor(x, reapply_views) + return is_fake(unwrapped) + elif isinstance(x, Tensor) and is_functorch_wrapped_tensor(x): + unwrapped = torch._C._functorch.get_unwrapped(x) + return is_fake(unwrapped) + return False + + +def maybe_get_fake_mode(t: object) -> Optional[FakeTensorMode]: + from torch._subclasses.functional_tensor import FunctionalTensor + + if isinstance(t, FakeTensor): + return t.fake_mode + if is_traceable_wrapper_subclass(t): + inner_tensor_names, _ = t.__tensor_flatten__() + modes = [ + maybe_get_fake_mode(getattr(t, t_name)) for t_name in inner_tensor_names + ] + m = modes[0] + assert all(m is x for x in modes) + return m + elif isinstance(t, FunctionalTensor): + return maybe_get_fake_mode(t.elem) + elif isinstance(t, Tensor) and torch._is_functional_tensor(t): + reapply_views = torch._C._functionalization_reapply_views_tls() + unwrapped = torch._C._functorch._unwrap_functional_tensor(t, reapply_views) + return maybe_get_fake_mode(unwrapped) + elif isinstance(t, Tensor) and 
is_functorch_wrapped_tensor(t): + unwrapped = torch._C._functorch.get_unwrapped(t) + return maybe_get_fake_mode(unwrapped) + return None + + +@functools.cache +def get_schema_info(func: OpOverload) -> torch._C._SchemaInfo: + return torch._C._SchemaInfo(func._schema) + + +# many of the decompositions registered to torch/_prims do not at the moment model +# aliasing or strides, so as an incremental step, just enable the decompositions in +# torch/_decomp/decompositions.py. +# decomps are used for aot autograd tracing so we would like to unify on their +# implementation and add additional testing to them +@functools.cache +def torch_decomp_decompositions(func: OpOverload) -> bool: + from torch._decomp import decomposition_table + + decompositions = torch._decomp.decompositions + # Note that the function in the decomposition table might be + # different from the one in the module because of the difference + # in out handling in aten API and torch public API + return decomposition_table[func].__module__.startswith( + "torch._decomp" + ) and decomposition_table[func].__name__ in dir(decompositions) + + +def tree_flatten_only(ty: type[T], tree: PyTree) -> list[T]: + flat_vals = pytree.tree_leaves(tree) + return [elem for elem in flat_vals if isinstance(elem, ty)] + + +def _is_plain_tensor(t: object) -> bool: + return ( + type(t) is Tensor + and t.layout == torch.strided + and not ( + t.is_sparse + or t.is_nested + or is_functorch_wrapped_tensor(t) + or is_legacy_batchedtensor(t) + or torch._is_functional_tensor(t) + ) + ) + + +# Similar to `MetaConverter`, this is a class for converting +# multiple tensors into fake tensors which share the same view/storage +# structure. Like `MetaConverter`, it uses `WeakIdRef` to +# hold a weak reference for all memoized tensors. 
+class FakeTensorConverter: + @property + def tensor_memo( + self, + ) -> weakref.WeakValueDictionary: + # not valid until py3.10 + # weakref.WeakValueDictionary["torch._subclasses.meta_utils.MetaTensorId", Optional["FakeTensor"]] + return self.meta_converter.tensor_memo + + meta_converter: MetaConverter + constant_storage_mapping: dict[StorageWeakRef, list[ReferenceType]] + export: bool + + def __init__(self, *, copy_data: bool = False, export: bool = False) -> None: + self.meta_converter = MetaConverter(copy_data=copy_data) + self.export = export + + # map from to storage to corresponding constant tensors + self.constant_storage_mapping = {} + + def add_constant_storage_mapping(self, fake_tensor: FakeTensor) -> None: + # when you have a constant, aliased tensor: + # const_tensor.add_(torch.rand([1])) + # all aliases of it must become no longer const + assert isinstance(fake_tensor, FakeTensor) and fake_tensor.constant is not None + weak_st = StorageWeakRef(fake_tensor.constant._typed_storage()) + + # we need a map from a weak storage to all of its corresponding + # constant tensors. 
python doesn't have the weak value equivalent + # of defaultdict(list), so we are using a WeakValueDictionary as one + if weak_st not in self.constant_storage_mapping: + self.constant_storage_mapping[weak_st] = [] + self.constant_storage_mapping[weak_st].append(weakref.ref(fake_tensor)) + + def invalidate_constant_aliases(self, tensor: Tensor) -> None: + assert not isinstance(tensor, FakeTensor) + + weak_st = StorageWeakRef(tensor._typed_storage()) + if weak_st not in self.constant_storage_mapping: + return + + for weak_tensor_ref in self.constant_storage_mapping[weak_st]: + ten = weak_tensor_ref() + if ten is not None: + ten._fix_weakref() + ten.constant = None + + del self.constant_storage_mapping[weak_st] + + def _get_memo(self, t: Tensor) -> Optional[FakeTensor]: + tid = self.meta_converter.describer.lookup_tensor.get(t) + if tid is None: + return None + return self.tensor_memo.get(tid) + + def set_tensor_memo(self, t: Tensor, v: FakeTensor) -> None: + tid = self.meta_converter.describer.get_tensor_id(t) + self.meta_converter.tensor_memo[tid] = v + + # You can have a real tensor that you need to convert into a fake tensor. + # If you have a meta tensor already, call from_meta_and_device. + # + # You're allowed to pass a meta tensor to be turned into a fake + # tensor; although an odd thing to do, this can occur if you're doing + # cross ref testing and the inner test is already operating on meta tensors. 
+ def from_real_tensor( + self, + fake_mode: FakeTensorMode, + t: Tensor, + make_constant: bool = False, + shape_env: Optional[ShapeEnv] = None, + *, + source: Optional[Source] = None, + symbolic_context: Optional[SymbolicContext] = None, + trace: bool = True, + ) -> FakeTensor: + # see note [Tensor Fakification and Symbol Caching] + if not symbolic_context and not source and shape_env: + if tracing_context := torch._guards.TracingContext.try_get(): + if t in tracing_context.tensor_to_context: + symbolic_context = tracing_context.tensor_to_context[t] + from torch.fx.experimental.symbolic_shapes import ( + StatefulSymbolicContext, + ) + + assert isinstance(symbolic_context, StatefulSymbolicContext) + source = symbolic_context.tensor_source + + maybe_memo = self._get_memo(t) + if maybe_memo is not None: + return maybe_memo + # not yet supported in metatensors + if t.is_quantized: + raise UnsupportedFakeTensorException("quantized nyi in meta tensors") + if type(t) is torch.nn.Parameter: + assert not make_constant + + constant = t if make_constant else None + + # This callback is used by both subclass and inner tensors. Require the + # caller to explicitly specify the device in case outer and inner tensors + # have different devices. + def mk_fake_tensor( + make_meta_t: Callable[[], object], device: Union[torch.device, str] + ) -> FakeTensor: + # NB: don't use in_kernel_invocation_manager. to + # ensure FakeTensor can internally do constant computation + # as necessary. Invocation manager is "more correct" as + # it works for more operators in make_meta_t, but + # invariant is that make_meta_t only calls factories + # for which it is not strictly necessary to use the + # invocation manager (I think!) + with no_dispatch(): + return FakeTensor( + fake_mode, + # pyrefly: ignore [bad-argument-type] + make_meta_t(), + # pyrefly: ignore [bad-argument-type] + device, + # TODO: callback might be used in recursive contexts, in + # which case using t is wrong! BUG! 
+ constant=constant, + ) + + out = self.meta_converter( + t, + shape_env=shape_env, + callback=mk_fake_tensor, + source=source, + symbolic_context=symbolic_context, + trace=trace, + ) + if out is NotImplemented: + raise UnsupportedFakeTensorException("meta converter nyi") + + from torch._dynamo.source import RandomValueSource + + value = None + if ( + not self.export + and _is_plain_tensor(t) # mostly, we want to know if item() works + and t.dim() == 0 + and t.device.type == "cpu" + # All integer types are fair game, because signed overflow is UB + # (and even int64 can overflow, since integers in Python are + # arbitrary precision). But only float64 is OK for float, because + # switching between float32 and float64 changes semantics in an + # observable way without hitting UB. + and t.dtype + in [torch.int64, torch.int32, torch.int16, torch.int8, torch.float64] + and source is not None + # Impede setting up item() on things coming from random. These + # are not "real" item() calls, instead UnspecializedPythonVariable + # is unsafely pretending an int is a tensor, which can sometimes + # implicitly cause an item call. The problem is this is pretty + # unsound: there's no reason substituting an int with a Tensor is + # going to give the same results. Today, you mostly get around + # this by typically not having capture_scalar_outputs on and graph + # breaking when someone tries to use the unspec variable in an + # int-y context. But allowing it through here would break that. + # So don't. + # + # Once random values are setup to be represented as + # SymNodeVariable, this condition can be removed. To check if + # you've done it right, this is a good test: + # + # PYTORCH_TEST_WITH_DYNAMO=1 python test/test_reductions.py -k + # TestReductionsCPU.test_dim_reduction_fns_fn_name_amax_cpu_bfloat16 + and not isinstance(source, RandomValueSource) + # In Dynamo, shape_env is never none (even with static shapes). 
+ # However, FakeTensorMode can be used by hand and in some cases + # ShapeEnv is not allocated. + and shape_env is not None + ): + from torch._dynamo.source import CallMethodItemSource, FloatTensorSource + from torch.fx.experimental.symbolic_shapes import DimDynamic + + with no_dispatch(): + value = t.item() + if not math.isnan(value) and not math.isinf(value): + # Peephole strip out unnecessary torch.as_tensor(x).item() + if isinstance(source, FloatTensorSource): + item_source = source.base + else: + item_source = CallMethodItemSource(source) + symbol = shape_env.create_unspecified_symbol( + value, + source=item_source, + dynamic_dim=DimDynamic.DYNAMIC, + symbolic_context=symbolic_context, + ) + # NB: reusing item_memo here ensures that we invalidate on + # mutation + if t.dtype == torch.int64: + out.item_memo = shape_env.create_symintnode( + symbol, + hint=value, + source=item_source, + ) + elif t.dtype == torch.float64: + out.item_memo = shape_env.create_symfloatnode( + symbol, + hint=value, + source=item_source, + ) + if make_constant: + self.add_constant_storage_mapping(out) + # NB: meta_converter set the memo + return out + + # If you specify the device, it MUST be a meta tensor. 
+ def from_meta_and_device( + self, + fake_mode: FakeTensorMode, + t: Tensor, + device: torch.device, + pytype: Optional[type[torch.Tensor]] = None, + dispatch_keys: Optional[torch.DispatchKeySet] = None, + ) -> FakeTensor: + assert t.device.type == "meta", ( + f"tensor's device must be `meta`, got {t.device.type} instead" + ) + # This is a bit abusive (this is not the "real" tensor) but whatever, + # the meta tensor should be fresh so there's no way to get it wrong + maybe_memo = self._get_memo(t) + if maybe_memo is not None: + return maybe_memo + out = FakeTensor( + fake_mode, t, device, pytype=pytype, dispatch_keys=dispatch_keys + ) + self.set_tensor_memo(t, out) + return out + + +@functools.cache +def init_gpu_context(device: torch.device) -> None: + # Backward will error with cuda Fake Tensors if no cuda tensors have been initialized first + if torch.cuda.is_available() or torch.xpu.is_available(): + ( + torch.empty(1, device=device) + if torch.version.hip is None + else torch.zeros(1, device=device) + ) + + +@contextlib.contextmanager +def in_kernel_invocation_manager( + fake_mode: FakeTensorMode, +) -> Generator[None, None, None]: + # See: note [Fake Tensor Dispatch Keys] + prev_in_kernel = fake_mode.in_kernel_invocation + meta_in_tls = torch._C._meta_in_tls_dispatch_include() + assert meta_in_tls == prev_in_kernel, f"{meta_in_tls}, {prev_in_kernel}" + + with torch._C._DisableTorchDispatch(): + fake_mode.in_kernel_invocation = True + # Unfortunately _set_meta_in_tls_dispatch_include(False) can leave + # `Dense` turned on (because it's implied by `Meta`) + with torch._C._PreserveDispatchKeyGuard(): + torch._C._set_meta_in_tls_dispatch_include(True) + try: + yield + finally: + fake_mode.in_kernel_invocation = prev_in_kernel + # torch._C._set_meta_in_tls_dispatch_include(prev_in_kernel) + + +# Return if the function allows Python numbers to bind to Tensors +def should_allow_numbers_as_tensors(func: OpOverload) -> bool: + return 
torch._C._should_allow_numbers_as_tensors( + func.name().split("::")[-1].split(".")[0] + ) + + +class FakeTensorConfig: + debug = os.environ.get("TORCH_FAKE_TENSOR_DEBUG", "0") == "1" + + +# This memorizes unbacked SymInt or SymFloats representing quantities like the +# number of nonzero elements in this tensor or learning rate. There is one +# instance of the descriptor per particular quantity to memoize. +# +# Memoization is helpful if you do something like x[mask] and y[mask]; +# mask.nonzero() gets repeatedly called and should give a consistent unbacked +# SymInt. It needs to be invalidated in the same way constant is. +# +# Making this a descriptor may seem overly fancy, but actually it's the most +# convenient way to ensure access to FakeTensor during access, which is +# required for testing version counter and epoch validity. +class SymNumberMemoDescriptor: + _name: str + + # By default, SymInts in this memo are invalidated across versions/epochs. + # nested_ints however are preserved across epochs and across versions. + # Preserving across versions is okay for nested int since the association + # of a nested int is agnostic to the underlying data and nested ints are not + # shared across multiple distinct tensors. + _is_nested_int: bool + + def __init__(self, *, is_nested_int: bool = False) -> None: + self._is_nested_int = is_nested_int + + def __set_name__(self, owner: str, name: str) -> None: + self._name = name + + def _memo(self, obj: FakeTensor) -> str: + return f"_{self._name}" + + def _memo_vc(self, obj: FakeTensor) -> str: + return f"_{self._name}_vc" + + # When we retrace, we need to invalidate all the memos so that we can + # accurately identify the first time unbacked SymInts are allocated. 
+ # This is only relevant for inputs; for intermediates, they will get fresh + # fake tensors so you won't have a memo anyway + def _memo_epoch(self, obj: FakeTensor) -> str: + return f"_{self._name}_epoch" + + def __get__( + self, obj: FakeTensor, objtype: Optional[type[FakeTensor]] = None + ) -> Optional[Union[torch.SymInt, torch.SymFloat]]: + if (r := getattr(obj, self._memo(obj))) is None: + return None + + # If backed, it's ok to preserve memo since we know it won't renumber. + if isinstance(r, torch.SymFloat) and r.node.hint is not None: + return r + + # Version counter based tracking isn't 100% sound but it's close + # enough + if ( + not self._is_nested_int and getattr(obj, self._memo_vc(obj)) != obj._version + ) or ( + not self._is_nested_int + and getattr(obj, self._memo_epoch(obj)) != obj.fake_mode.epoch + ): + setattr(obj, self._memo(obj), None) + return None + return r + + def __set__( + self, obj: FakeTensor, value: Optional[Union[torch.SymInt, torch.SymFloat]] + ) -> None: + if value is None: + setattr(obj, self._memo(obj), None) + setattr(obj, self._memo_vc(obj), None) + setattr(obj, self._memo_epoch(obj), None) + elif not obj.is_inference() or self._is_nested_int: + setattr(obj, self._memo(obj), value) + if not self._is_nested_int: + setattr(obj, self._memo_vc(obj), obj._version) + setattr(obj, self._memo_epoch(obj), obj.fake_mode.epoch) + + +class FakeTensor(Tensor): + """ + Meta tensors give you the ability to run PyTorch code without having to + actually do computation through tensors allocated on a `meta` device. + Because the device is `meta`, meta tensors do not model device propagation. + FakeTensor extends MetaTensors to also carry an additional `fake_device` + which tracks devices that would have been used. 
+ """ + + fake_device: torch.device + fake_mode: FakeTensorMode + constant: Optional[Tensor] + real_tensor: Optional[Tensor] + + # TODO: Generalize this as needed, e.g., into a trie of memos, if + # you do something like x[0].item() (x[0] is fresh each time, so + # memo mechanism here won't work) + nonzero_memo = SymNumberMemoDescriptor() + item_memo = SymNumberMemoDescriptor() + unique_memo = SymNumberMemoDescriptor() + unique_consecutive_memo = SymNumberMemoDescriptor() + + # We expect nested_int_memo to be None when an offsets is a graph + # intermediate, or an input that has never been associated with a + # nested int. + nested_int_memo = SymNumberMemoDescriptor(is_nested_int=True) + + # FakeTensor doesn't fully emulate the original tensor's Python type + # and dispatch key set, therefore sometimes we want to track them + # separately. + pytype: Optional[type[Tensor]] + dispatch_keys: Optional[torch.DispatchKeySet] + + # Indicates to our torch_dispatch dispatching infra that + # this is an "infra" mode with lower dispatching precedence. + _mode_key = torch._C._TorchDispatchModeKey.FAKE + + @property + # pyrefly: ignore [bad-override] + def device(self) -> torch.device: + if self.fake_mode.in_kernel_invocation: + return torch.device("meta") + else: + return self.fake_device + + @device.setter + def device(self, _: torch.device) -> None: + raise NotImplementedError + + # Note: [Fake Tensor Dispatch Keys] + # In order to model the behavior of device-specific autocast + # and autograd logic, we update the dispatch keys of FakeTensors + # to reflect their fake device. This includes the BackendComponent + # (DispatchKey::Meta -> DispatchKey::CUDA), and also the BackendComponent + # related Autocast and Autograd keys. __torch_dispatch__ sits below + # Autocast and Autograd, and is only invoked when we are at the + # kernel for the BackendComponent. 
# Then, we add Meta to the
    # thread-local dispatch include set to hit the meta kernel
    # instead of the kernel of the BackendComponent for the fake device.
    # The `device_for_backend_keys` does that below
    # NOTE: this probably will not do the right thing for backends
    # that have dispatch keys which are higher than the "meta" key:
    # https://github.com/pytorch/pytorch/blob/main/c10/core/DispatchKey.h#L189

    # We don't support named tensors; graph break
    @property
    # pyrefly: ignore [bad-override]
    def names(self) -> list[str]:
        """Unsupported: reading ``names`` always raises UnsupportedFakeTensorException."""
        raise UnsupportedFakeTensorException(
            "torch.compile doesn't support named tensors"
        )

    @names.setter
    def names(self, _: list[str]) -> None:
        """Unsupported: assigning ``names`` always raises NotImplementedError."""
        raise NotImplementedError

    @staticmethod
    def __new__(
        cls,
        fake_mode: FakeTensorMode,
        elem: Tensor,
        device: torch.device,
        constant: Optional[Tensor] = None,
        real_tensor: Optional[Tensor] = None,
        pytype: Optional[type[Tensor]] = None,
        dispatch_keys: Optional[torch.DispatchKeySet] = None,
    ) -> Self:
        """
        Build a FakeTensor as a subclass instance wrapping ``elem``.

        Args:
            fake_mode: stored on the instance as ``fake_mode``. Its
                ``_allow_unsafe_data_ptr_access`` and ``allow_meta`` flags are
                consulted during construction.
            elem: the tensor being wrapped; asserted to be on the "meta" device.
            device: the device the result records as ``fake_device``. Values that
                are not already a ``torch.device`` are passed to ``torch.device``.
            constant: stored on the instance as ``constant``.
            real_tensor: stored on the instance as ``real_tensor``; asserted not
                to be a FakeTensor itself.
            pytype: stored on the instance as ``pytype``.
            dispatch_keys: stored on the instance as ``dispatch_keys``.

        Returns:
            The new instance, with every memo attribute (``nonzero_memo``,
            ``item_memo``, ``unique_memo``, ``unique_consecutive_memo``,
            ``nested_int_memo``) set to None.
        """
        self = Tensor._make_subclass(
            cls,
            elem,
            elem.requires_grad,
            dispatch_device=True,
            device_for_backend_keys=device,
        )
        # Mutable data_ptr access either throws or emits a deprecation warning,
        # depending on the mode's _allow_unsafe_data_ptr_access flag.
        if not fake_mode._allow_unsafe_data_ptr_access:
            torch._C._set_throw_on_mutable_data_ptr(self)
        else:
            torch._C._set_warn_deprecated_on_mutable_data_ptr(self)

        assert elem.device.type == "meta", elem.device.type
        device = device if isinstance(device, torch.device) else torch.device(device)
        # NB: it is fine, if a little confusing, for device to be meta
        # (we are faking a meta tensor in that case).  However, it often
        # indicates some sort of confusion (e.g., you accidentally passed
        # in a meta tensor when you should have passed in the real tensor).
        # So by default we disallow meta, and if you are working in a situation
        # where it is helpful (e.g., crossref testing) you can turn it back
        # on
        if not fake_mode.allow_meta:
            assert device.type != "meta"
        # normalize device.
        if device.type in ["cuda", "xpu"]:
            init_gpu_context(device)

        # A device of one of the listed types given without an index gets an
        # explicit one: the backend's current device when that backend reports
        # itself initialized (never queried for "mps"), otherwise index 0.
        if (
            device.type
            in [
                "cuda",
                "hpu",
                "xpu",
                "mps",
                "mtia",
                torch._C._get_privateuse1_backend_name(),
            ]
            and device.index is None
        ):
            if device.type != "mps" and getattr(torch, device.type).is_initialized():
                device = torch.device(
                    f"{device.type}:{getattr(torch, device.type).current_device()}"
                )
            else:
                device = torch.device(f"{device.type}:0")
        # pyrefly: ignore [read-only]
        self.fake_device = device
        self.fake_mode = fake_mode
        self.constant = constant
        self.pytype = pytype
        self.dispatch_keys = dispatch_keys
        assert not isinstance(real_tensor, FakeTensor)
        self.real_tensor = real_tensor
        self.nonzero_memo = None
        self.item_memo = None
        self.unique_memo = None
        self.unique_consecutive_memo = None
        self.nested_int_memo = None

        # In debug configuration, remember where this tensor was created.
        if FakeTensorConfig.debug:
            self._debug_trace = CapturedTraceback.extract()  # type: ignore[attr-defined]
        return self

    # In some circumstances, a conventional Tensor constructor
    # will get rewritten to call into FakeTensor.  We must provide an
    # __init__ method that can accept the Python interpreter's initialization
    # in such a situation; we must also be able to handle direct fake
    # tensor construction via FakeTensor().
    #
    # In particular, the __init__ call will look funny in the following case:
    #
    #   with FakeTensorMode():
    #       x = Tensor([1, 2, 3])
    #
    # this desugars into:
    #
    #   with FakeTensorMode():
    #       x = Tensor.__new__([1, 2, 3])
    #       # NB: x is a fake tensor, because of the mode!
    #       x.__init__([1, 2, 3])  # not the normal fake tensor args!
#
    def __init__(self, *args: object, **kwargs: object) -> None:
        """
        Accept (and ignore) arbitrary constructor arguments.

        The base initializer is called without arguments. While exporting with
        ``detect_non_strict_fake_tensor_leaks`` enabled, the instance is also
        added to the thread-local non-strict-export fake tensor tracker.
        """
        super().__init__()
        if (
            torch.compiler.is_exporting()
            and torch._export.config.detect_non_strict_fake_tensor_leaks
        ):
            fake_tensor_tls.non_strict_export_fake_tensor_tracker.add(self)

    @staticmethod
    def from_tensor(t: Tensor, fake_mode: FakeTensorMode) -> FakeTensor:
        """Convenience wrapper that delegates to ``fake_mode.from_tensor(t)``."""
        return fake_mode.from_tensor(t)

    @classmethod
    @count
    def __torch_dispatch__(  # type: ignore[override] # TODO
        cls,
        func: OpOverload,
        types: Sequence[type],
        args: Sequence[object] = (),
        kwargs: Mapping[str, object] = immutable_dict(),
    ) -> object:
        """
        Subclass-level dispatch entry point for FakeTensor arguments.

        Handles ``prim.device`` and the ops registered in
        ``_DISPATCH_META_HANDLERS`` directly. Otherwise it returns
        ``NotImplemented`` when an unrecognized tensor subclass is involved or a
        fake mode is already active, and in the remaining case re-runs ``func``
        under the fake mode taken from the first FakeTensor argument.
        """
        # need to handle here to avoid infinite recursion
        # see [in_kernel_invocation]
        if func is torch.ops.prim.device.default:
            assert len(args) == 1 and isinstance(args[0], FakeTensor)
            if args[0].fake_mode.in_kernel_invocation:
                return torch.device("meta")
            else:
                return args[0].fake_device

        # this handler must be done inside FakeTensor subclass, not mode, because
        # we can end up dispatching here when we have a fake tensor with
        # symbolic sizes running under in_kernel_invocation_manager.
        # The subclass is asked to handle this query because size (not
        # sym_size) was called, but we are unable to serve it directly because
        # there are symbolic sizes in the class.  The use of
        # in_kernel_invocation_manager means it's incorrect to activate a
        # mode to actually handle this (this caused
        # https://github.com/pytorch/pytorch/issues/122772).
        if handler := _DISPATCH_META_HANDLERS.get(func):
            return handler(args)

        # Because fake mode can return NotImplemented (if it sees a subclass
        # it doesn't know how to deal with), this test here is important
        # because the next dispatch after a fake mode will attempt to use
        # subclasses of tensors to dispatch, and any FakeTensor arguments
        # will be considered eligible.
        unrecognized_types = [
            t for t in types if not issubclass(t, FakeTensor) and t is not Tensor
        ]
        if unrecognized_types:
            not_implemented_log.debug(
                "FakeTensor unrecognized subclass(es): %s", unrecognized_types
            )
            return NotImplemented

        # Use the mode of the first FakeTensor leaf found among the arguments.
        fake_mode = None
        for arg in pytree.arg_tree_leaves(*args, **kwargs):
            if isinstance(arg, FakeTensor):
                fake_mode = arg.fake_mode
                break

        assert fake_mode is not None

        # If the fake mode is already active, don't try to reapply it!
        # NotImplemented is the right thing to return here, because the
        # typical situation this can occur is if ProxyTensorMode returned a
        # NotImplemented because of a not implemented subclass; we may have
        # unluckily attempted to hit FakeTensor's dispatch first,
        # NotImplemented lets us keep chaining until we find the actual
        # subclass
        maybe_cur_fake_mode = torch._C._get_dispatch_mode(
            torch._C._TorchDispatchModeKey.FAKE
        )
        if maybe_cur_fake_mode:
            not_implemented_log.debug(
                "FakeTensor mode already active: %s in %s",
                fake_mode,
                maybe_cur_fake_mode,
            )
            return NotImplemented

        assert not fake_mode.in_kernel_invocation

        with fake_mode:
            return func(*args, **kwargs)

    @staticmethod
    def _find_common_device(
        func: OpOverload, flat_args: Sequence[object]
    ) -> tuple[torch.device, bool]:
        """
        Pick the single device that the FakeTensors in ``flat_args`` agree on.

        Non-FakeTensor arguments are ignored.

        Raises:
            RuntimeError: two FakeTensors are on different devices and none of
                the exceptions handled in ``merge_devices`` applies.
        """
        # Returns: (common_device, has_scalar_only_inputs)

        # cpu - zero-dim tensors can be called in cuda kernels,
        # so overwrite the common_device if the only existing
        # device comes from a cpu zero-dim tensor
        common_device = None
        has_scalar_only_inputs = False
        is_cpu_zero_dim = None

        # list of ops which can have args(tensor/tensorList) in mixed device
        mixed_device_fns = ordered_set(
            aten._foreach_copy.default,
        )

        # list of ops not using zero dim cpu tensor logic to align with the eager mode.
        bypass_zero_dim_cpu_tensor_check_ops = ordered_set(
            aten.nextafter.default,
        )

        def check_cpu_device(device: torch.device) -> bool:
            return device.type == "cpu"

        def cpu_zero_dim(t: Tensor) -> bool:
            return check_cpu_device(t.device) and t.dim() == 0

        def merge_devices(t: object) -> None:
            # Folds one argument into the running (common_device, is_cpu_zero_dim)
            # state held in the enclosing scope.
            nonlocal common_device
            nonlocal is_cpu_zero_dim
            if not isinstance(t, FakeTensor):
                return

            # First FakeTensor seen: it defines the initial candidate.
            if common_device is None:
                common_device = t.device
                is_cpu_zero_dim = cpu_zero_dim(t)
                return

            t_is_cpu_zero_dim = cpu_zero_dim(t)
            if t.device == common_device:
                # Same device: the candidate stays "cpu zero-dim" only while
                # every tensor seen on it is zero-dim.
                if is_cpu_zero_dim:
                    is_cpu_zero_dim = t_is_cpu_zero_dim
                return

            is_bypass_zero_dim_cpu_tensor_check_op = (
                func in bypass_zero_dim_cpu_tensor_check_ops
            )

            # mismatching devices !
            # if current tensor is cpu 0 dim, defer to existing device
            if t_is_cpu_zero_dim and not is_bypass_zero_dim_cpu_tensor_check_op:
                return

            # current device is from cpu 0 dim tensor, overwrite
            if is_cpu_zero_dim and not is_bypass_zero_dim_cpu_tensor_check_op:
                common_device = t.device
                is_cpu_zero_dim = t_is_cpu_zero_dim
                return

            # if still device mismatches we will check ops which can work
            # on different devices for ex. _foreach_copy, and one of the
            # device must be cpu in this case we will return from here without
            # throwing an error
            if func in mixed_device_fns:
                if any(map(check_cpu_device, (common_device, t.device))):
                    return

            # if prefer_device_type is set, prefer that device type over others
            prefer_device_type = torch._functorch.config.fake_tensor_prefer_device_type
            if prefer_device_type is not None:
                # NOTE: these are substring tests against the device type string.
                common_has_preferred = prefer_device_type in common_device.type
                t_has_preferred = prefer_device_type in t.device.type

                if not common_has_preferred and t_has_preferred:
                    # Switch to the preferred device type
                    common_device = t.device
                    is_cpu_zero_dim = t_is_cpu_zero_dim
                    return
                elif common_has_preferred and not t_has_preferred:
                    # Keep the existing preferred device type
                    return

            # mismatching devices of non-zero dim tensors, throw
            # This might be valid behavior and need to be explicitly modeled, e.g. reshape_as
            raise RuntimeError(
                f"Unhandled FakeTensor Device Propagation for {func}, found two different devices {common_device}, {t.device}"
            )

        for arg in flat_args:
            merge_devices(arg)

        # some functions that allow Python numbers to bind to Tensors
        # if we have failed to find a device, and we're running one of these operators,
        # we must have scalar only inputs
        if should_allow_numbers_as_tensors(func) and common_device is None:
            # ops with scalar only inputs always have result on cpu
            has_scalar_only_inputs = True
            common_device = torch.device("cpu")

        assert common_device is not None, f"Could not find common device for {func}"

        return common_device, has_scalar_only_inputs

    def get_nested_int(
        self,
        *,
        coeff: Union[int, torch.SymInt] = 1,
    ) -> torch.SymInt:
        """
        Return this tensor's nested int multiplied by ``coeff``.

        The nested int is created through the fake mode on first use and stored
        in ``nested_int_memo`` so that later calls reuse it.
        """
        if self.nested_int_memo is None:
            self.nested_int_memo = self.fake_mode.create_symbolic_nested_int(
                nt_tensor_id=None
            )
        assert isinstance(self.nested_int_memo, torch.SymInt)
        return self.nested_int_memo * coeff

    # Similar to
FunctionalTensor.tolist + def tolist(self) -> Any: + if self.dim() == 0: + return self.item() + elif self.dim() == 1: + return [elem.item() for elem in self] + else: + return [elem.tolist() for elem in self] + + +_MetadataIntLike = Union[IntLikeType, "_PySymInputStub", "_SymIntOutputStub"] + + +@dataclass(slots=True) +class TensorMetadata: + """ + The Tensor metadata relevant to hashing FakeTensors when caching. + """ + + dtype: torch.dtype + shape: tuple[_MetadataIntLike, ...] + stride: tuple[_MetadataIntLike, ...] + device: torch.device + layout: torch.layout + memory_format: Optional[torch.memory_format] + storage_offset: _MetadataIntLike + storage_bytes: Optional[_MetadataIntLike] + requires_grad: bool + is_quantized: bool + is_conj: bool + is_neg: bool + is_inference: bool + is_sparse: bool # read: is sparse COO + is_coalesced: Optional[bool] + dense_dim: Optional[int] + sparse_dim: Optional[int] + + def _flatten_into( + self, + result: list[object], + mode: FakeTensorMode, + state: _CacheKeyState, + ) -> None: + # Flatten the TensorMetadata out into `result`. Make sure to call + # state.convert_sym_int() on any SymInts. + for field in dataclasses.fields(self): + value = getattr(self, field.name) + if isinstance(value, (tuple, list, torch.Size)): + # This will recursively flatten the iterable, calling + # convert_sym_int() as necessary. + id_hashed_objects: list[object] = [] + mode._prep_args_for_hash(result, value, state, id_hashed_objects) + id_hashed_objects.clear() + elif isinstance(value, SymInt): + state.convert_sym_int(result, value) + else: + result.append(value) + + +def extract_tensor_metadata(t: Tensor) -> TensorMetadata: + """ + Extract the TensorMetadata of a tensor. + """ + memory_format = suggest_memory_format(t) + # Don't call is_contiguous() on a Tensor which has symbolic sizes or things + # will go badly (guards will be messed up?) 
+ if ( + t._has_symbolic_sizes_strides + or is_sparse_any(t) + or not t.is_contiguous(memory_format=memory_format) + ): + memory_format = None # type: ignore[assignment] + + storage_offset = t.storage_offset() + + return TensorMetadata( + t.dtype, + t.shape, + t.stride() if t.layout == torch.strided else (), + t.device, + t.layout, + memory_format, + storage_offset, + # Only set storage_bytes for tensors that have storage (not sparse) + t.untyped_storage().nbytes() if not is_sparse_any(t) else None, + t.requires_grad, + t.is_quantized, + t.is_conj(), + t.is_neg(), + t.is_inference(), + t.is_sparse, + t.is_coalesced() if t.is_sparse else None, + t.dense_dim() if is_sparse_any(t) else None, + t.sparse_dim() if is_sparse_any(t) else None, + ) + + +@dataclass(slots=True) +class _DispatchCacheKey: + """ + Key for the FakeTensor dispatch cache. + """ + + key: tuple[object, ...] + hashvalue: int + + def __init__(self, tup: tuple[object, ...]) -> None: + self.key = tup + self.hashvalue = hash(tup) + + def __eq__(self, other: object) -> bool: + return isinstance(other, _DispatchCacheKey) and self.key == other.key + + def __hash__(self) -> int: + return self.hashvalue + + def strip_shape_env(self) -> None: + # We need to strip the ShapeEnv from any values before we store in the + # cache so the cache doesn't keep our ShapeEnvs alive. + for v in self.key: + if isinstance(v, _PySymInputStub): + v.strip_shape_env() + + +# Default value for constant_value in _DispatchCacheEntryOutputInfo. This is +# only for checking and differentiates from None. +class SingletonConstant: + pass + + +@dataclass(frozen=True, slots=True) +class _DispatchCacheEntryOutputInfo: + """ + Entry type for the FakeTensor dispatch cache for an output. Accounts for three + possibilities: + 1) The op is inplace, and a hit means we need to alias the argument at a + given index. + 2) We need to synthesize a new FakeTensor given tensor metadata. For view + ops, we further capture the index of the arg to alias. 
+ 3) if the tensor related fields are None, then it is a constant value (e.g. + None or integer) + """ + + inplace_idx: Optional[int] + metadata: Optional[TensorMetadata] + view_idx: Optional[int] + constant_value: Optional[Any] = SingletonConstant + + +@dataclass(frozen=True, slots=True) +class _DispatchCacheValidEntry: + """ + Entry type for the FakeTensor dispatch cache. It supports two types of outputs + 1) tensor + 2) tuple of tensors + + is_output_tuple flag helps in differentiating the return type + """ + + output_infos: tuple[_DispatchCacheEntryOutputInfo] + is_output_tuple: bool = False + + +@dataclass(frozen=True, slots=True) +class _DispatchCacheBypassEntry: + """ + Entry type for a negative cache entry. + """ + + reason: str + + +if TYPE_CHECKING: + _DispatchCacheEntry = Union[_DispatchCacheValidEntry, _DispatchCacheBypassEntry] + + +@dataclass(frozen=True, slots=True) +class _BypassDispatchCache(Exception): + """ + Signals cases that should skip FakeTensor caching. + """ + + reason: str + + +@dataclass(frozen=True, slots=True) +class DispatchCacheInfo: + """ + Information about the state of the FakeTensor dispatch cache. + """ + + hits: int + misses: int + bypasses: dict[str, int] + size: int + + +# We keep one instantiation of `fake_tensor_converter` active +# for the duration of `with FakeTensorMode()`. +# This allows accurate storage aliasing across invocation of +# different operators. While this will keep all freshly allocated +# tensors alive during `FakeTensorMode`, there will be no +# new allocations of Tensors which have non-meta storage so +# memory should not significantly increase. 
class FakeTensorMode(TorchDispatchMode):
    """
    Dispatch mode for FakeTensors.

    The dispatch cache and its hit/miss/bypass counters are class attributes,
    so they are shared by every FakeTensorMode instance.
    """

    cache: dict[_DispatchCacheKey, _DispatchCacheEntry] = {}
    cache_hits: int = 0
    cache_misses: int = 0
    cache_bypasses: dict[str, int] = defaultdict(int)
    # Every time you retrace using the same fake tensor mode, you should
    # advance the epoch so we don't reuse unbacked memos
    epoch: int = 0
    in_kernel_invocation: bool = False
    static_shapes: bool
    shape_env: Optional[ShapeEnv]
    _stack: Optional[str]
    allow_meta: bool

    # NestedTensor uses a tensor_id_counter to uniquely identify offsets.
    # This counter is incremented when an offsets is used to create an NJT
    # for the first time. To avoid mutating eager state if we construct NJT
    # during tracing, we maintain a separate counter on the FakeTensorMode.
    # The initial count is set to the current eager tensor_id_counter value
    # upon initialization, and every time you retrace using the same fake tensor
    # mode, you should reset the counter to the initial count.
    nt_tensor_id_counter: int = -1
    nt_tensor_id_initial_count: int = -1

    def __init__(
        self,
        *,
        allow_fallback_kernels: bool = True,
        allow_non_fake_inputs: bool = False,
        shape_env: Optional[ShapeEnv] = None,
        static_shapes: Optional[bool] = None,
        # TODO: This is a temporary measure, see
        # https://github.com/pytorch/pytorch/pull/126245#discussion_r1604185748
        # We're currently solely using this to impede population of
        # item_memo for 0d scalar tensor inputs when export, because this
        # causes things that used to be deferred runtime asserts to turn into
        # guards, and then the guards are just lost.  We can potentially fix
        # this by ensuring guards also get put in the graph, but this is
        # pending a rework of how deferred runtime asserts in export.  Once
        # that's done, we can remove this.
        export: bool = False,
    ) -> None:
        """
        Create a fake tensor mode.

        Args:
            allow_fallback_kernels: stored as ``self.allow_fallback_kernels``.
            allow_non_fake_inputs: stored as ``self.allow_non_fake_inputs``; see
                the comment at the assignment below.
            shape_env: stored as ``self.shape_env``. When ``static_shapes`` is not
                given, static shapes are enabled exactly when this is None.
            static_shapes: explicit override for ``self.static_shapes``.
            export: forwarded to the FakeTensorConverter; see the TODO above.
        """
        log.debug("create_mode 0x%x", id(self))
        super().__init__()
        self.allow_fallback_kernels = allow_fallback_kernels

        import torch._dynamo.config
        import torch._functorch.config

        self.propagate_real_tensors = (
            torch._functorch.config.fake_tensor_propagate_real_tensors
        )
        self.fake_tensor_converter = FakeTensorConverter(
            copy_data=self.propagate_real_tensors,
            export=export,
        )

        if static_shapes is not None:
            self.static_shapes = static_shapes
        else:
            self.static_shapes = shape_env is None

        # This is temporarily patched to True in Dynamo to grandfather in some
        # places where we unconditionally allow scalar outputs, TO BE REMOVED
        self.allow_scalar_outputs = False

        self._allow_unsafe_data_ptr_access = (
            torch._functorch.config.fake_tensor_allow_unsafe_data_ptr_access
        )
        self.allow_meta = torch._functorch.config.fake_tensor_allow_meta
        # The dispatch cache is disabled whenever real tensors are propagated.
        self.cache_enabled: bool = (
            torch._dynamo.config.fake_tensor_cache_enabled
            and not self.propagate_real_tensors
        )
        self.cache_crosscheck_enabled = (
            torch._dynamo.config.fake_tensor_cache_crosscheck_enabled
        )

        # A flag that controls, whether we want to invoke ops on mix of
        # real weights/global variables and fake inputs
        self.allow_non_fake_inputs = allow_non_fake_inputs

        # [in_kernel_invocation]
        # when FakeTensor is invoked in user code, .device should return
        # the fake_device of the tensor so that code such as `if x.is_cuda`
        # or torch.zeros([10, 10], device=x.device) continues to execute as if
        # the FakeTensor were real. However, within kernel execution, we return
        # the `Meta` device because all computation within the kernels should
        # behave as if the Tensors are on meta devices. Kernels should allocate
        # new tensors on meta devices, and checks like `is_meta` should return true.
        # within python refs, we always return the real device by defining
        # the device property
        self.in_kernel_invocation = False

        # True if we entered and actually enabled fake tensor mode,
        # false if it was a no-op.  Not thread safe but neither is
        # in_kernel_invocation
        # If another fake mode was already active when we enter, we also stash it here.
        # That way when we exit, we know to re-enable the previous fake mode.
        self.enter_stack: list[
            tuple[bool, Optional[TorchDispatchMode], Optional[bool]]
        ] = []

        self.shape_env = shape_env

        # The creation stack is captured eagerly; `_stack` starts out as None.
        self._stack_trace = traceback.extract_stack()
        self._stack = None

        # Indicates to our torch_dispatch dispatching infra that
        # this is an "infra" mode with lower dispatching precedence.
        self._mode_key = torch._C._TorchDispatchModeKey.FAKE

        import torch.nested._internal.nested_tensor

        self.nt_tensor_id_initial_count = (
            torch.nested._internal.nested_tensor._tensor_id_counter
        )
        self.nt_tensor_id_counter = self.nt_tensor_id_initial_count

    def reset_nt_tensor_id_counter(self) -> None:
        """Reset the nested-tensor id counter to the value captured in ``__init__``."""
        self.nt_tensor_id_counter = self.nt_tensor_id_initial_count

    # Typically, there is only one fake tensor mode and you test for it by
    # doing an isinstance test.  However, in some situations, there might be
    # TWO fake tensor modes.  The canonical example of this is exporting
    # a fake model: there is an outer fake mode created by the user, and
    # an inner fake mode created by Dynamo.  The two phase process is required
    # because the outer fake mode typically won't have a ShapeEnv, even if
    # the user is interested in exporting with dynamic shapes (so the inner
    # fake mode will actually have a ShapeEnv and swap in symbolic sizes.)
    #
    # In this case, it's insufficient to test only one FakeTensor: you need
    # to distinguish between our fake tensor and other fake tensors.  That's
    # what this function does.
+ def is_our_fake(self, t: object) -> TypeGuard[FakeTensor]: + return isinstance(t, FakeTensor) and t.fake_mode is self + + # If we should avoid device init. This changes the behavior of various APIs: + # - We avoid constant-prop on Tensors with ops that move them to another device + # - We change the torch.tensor ctor contract to never materialize + # tensors on device + # (see NOTE: [torch.tensor, lift_fresh, and device movement]) + @property + def avoid_device_init(self) -> bool: + if torch.xpu._is_compiled(): + assert not torch.cuda._is_compiled() + return not torch.xpu.is_available() + + return not ( + torch.cuda.is_available() + or (hasattr(torch, "hpu") and torch.hpu.is_available()) + ) + + @property + def stack(self) -> str: + if self._stack is None: + self._stack = "".join(traceback.format_list(self._stack_trace)) + return self._stack + + @count + # pyrefly: ignore [bad-override] + def __torch_dispatch__( + self, + func: OpOverload, + types: Sequence[type], + args: Sequence[object] = (), + kwargs: Mapping[str, object] = immutable_dict(), + ) -> object: + # FakeTensorMode should not be set when we're inside of it. + assert ( + torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.FAKE) is None + ), func + try: + return self.dispatch(func, types, args, kwargs) + except TypeError: + log.exception("fake tensor raised TypeError") + raise + + # No-op if FakeTensorMode is already in use + def __enter__(self) -> Self: + import torch.nested._internal.nested_tensor + + prev_only_lift_cpu_tensors = None + if self.avoid_device_init: + # See NOTE: [torch.tensor, lift_fresh, and device movement] + prev_only_lift_cpu_tensors = torch._C._only_lift_cpu_tensors() + torch._C._set_only_lift_cpu_tensors(True) + + # In the case of CPU-only build or cuda device unavailable, + # we patch the cuda device guard to use NoOpDeviceGuardImpl. + # This enables us to trace over cuda kernels under FakeTensorMode. 
+ torch._C._ensureCUDADeviceGuardSet() + + maybe_prev_fake_mode = torch._C._unset_dispatch_mode(self._mode_key) + if self is not maybe_prev_fake_mode: + self.enter_stack.append( + (True, maybe_prev_fake_mode, prev_only_lift_cpu_tensors) + ) + return super().__enter__() + else: + # no-op (still need to re-set the fake mode though since we unset it) + torch._C._set_dispatch_mode(self) + self.enter_stack.append((False, None, prev_only_lift_cpu_tensors)) + + return self + + def __exit__( + self, + a: Optional[type[BaseException]], + b: Optional[BaseException], + c: Optional[TracebackType], + ) -> None: + ( + live, + maybe_prev_fake_mode, + maybe_prev_only_lift_cpu_tensors, + ) = self.enter_stack.pop() + if live: + super().__exit__(a, b, c) + + # Re-enable the previous fake mode, if there was one. + if maybe_prev_fake_mode is not None: + torch._C._set_dispatch_mode(maybe_prev_fake_mode) + if maybe_prev_only_lift_cpu_tensors is not None: + torch._C._set_only_lift_cpu_tensors(maybe_prev_only_lift_cpu_tensors) + + @classmethod + def is_infra_mode(cls) -> bool: + return True + + @classmethod + def cache_info(cls) -> DispatchCacheInfo: + """ + Query the state of the dispatch cache. + """ + return DispatchCacheInfo( + FakeTensorMode.cache_hits, + FakeTensorMode.cache_misses, + dict(FakeTensorMode.cache_bypasses), + len(FakeTensorMode.cache), + ) + + @classmethod + def cache_clear(cls) -> None: + """ + Clear the dispatch cache. + """ + cls.cache_hits = 0 + cls.cache_misses = 0 + cls.cache_bypasses.clear() + cls.cache.clear() + + def _cached_dispatch_impl( + self, + func: OpOverload, + types: Sequence[type], + args: Sequence[object], + kwargs: Mapping[str, object], + ) -> object: + """ + Lookup a cache entry for the given arguments. If none exists, dispatch + and cache the result (if the result is eligible for caching). 
+ """ + state = None + key = None + try: + state = _CacheKeyState(self.shape_env) + key = self._cache_key(state, func, args, kwargs) + except _BypassDispatchCache as e: + # We couldn't create the cache key at all + if ( + isinstance(func, torch._ops.HigherOrderOperator) + and func.name() == "invoke_subgraph" + ): + hc_log.debug( + "Fake tensor cache failed: identifier = %s, reason = %s", + args[1], + e.reason, + ) + FakeTensorMode.cache_bypasses[e.reason] += 1 + + if key is None: + # Do this dispatch outside the above except handler so if it + # generates its own exception there won't be a __context__ caused by + # the caching mechanism. + # pyrefly: ignore [bad-argument-type] + return self._dispatch_impl(func, types, args, kwargs) + + assert state is not None + if state.cache_on_shape_env(): + assert state.shape_env is not None + cache = state.shape_env.fake_tensor_cache + set_cache_key = _set_cache_key_for_shape_env + else: + cache = FakeTensorMode.cache + set_cache_key = _set_cache_key + entry = cache.get(key, None) + + if entry is not None: + if isinstance(entry, _DispatchCacheBypassEntry): + # This represents a negative cache entry - we already saw that the + # output is uncachable. Compute it from first principals. + FakeTensorMode.cache_bypasses[entry.reason] += 1 + # pyrefly: ignore [bad-argument-type] + return self._dispatch_impl(func, types, args, kwargs) + + # We have a cache entry. + # pyrefly: ignore [bad-argument-type] + output = self._output_from_cache_entry(state, entry, key, func, args) + FakeTensorMode.cache_hits += 1 + if self.cache_crosscheck_enabled: + # For debugging / testing: Validate that the output synthesized + # from the cache matches the output created by normal dispatch. + with disable_fake_tensor_cache(self): + # pyrefly: ignore [bad-argument-type] + self._crosscheck_cache_output(output, func, types, args, kwargs) + return output + + # We don't have a cache entry. 
+ # pyrefly: ignore [bad-argument-type] + output = self._dispatch_impl(func, types, args, kwargs) + + try: + # pyrefly: ignore [bad-argument-type] + entry = self._make_cache_entry(state, key, func, args, kwargs, output) + except _BypassDispatchCache as e: + # We ran "extra" checks on the cache key and determined that it's no + # good. Record the reason and mark it so we don't bother validating + # again. + if ( + isinstance(func, torch._ops.HigherOrderOperator) + and func.name() == "invoke_subgraph" + ): + hc_log.debug( + "Fake tensor cache failed: identifier = %s, reason = %s", + args[1], + e.reason, + ) + FakeTensorMode.cache_bypasses[e.reason] += 1 + set_cache_key(cache, key, _DispatchCacheBypassEntry(e.reason)) + return output + + set_cache_key(cache, key, entry) + FakeTensorMode.cache_misses += 1 + return output + + def _cache_key( + self, + state: _CacheKeyState, + func: OpOverload, + args: Sequence[object], + kwargs: Mapping[str, object], + ) -> _DispatchCacheKey: + """ + Create a cache key given the dispatch args. Raises _BypassDispatchCache + for any situation that precludes caching. + """ + is_tracing = torch.fx.experimental.proxy_tensor.get_proxy_mode() is not None + key_values = [ + func, + # Capture the default_dtype mode since that can affect the output tensor, + # e.g., when operating on constant float values. + torch.get_default_dtype(), + # Capture the current device to support, e.g., cache tensor creation, + # where there isn't necessarily a tensor to take the device from. + torch._C._get_default_device(), + # We want to create tensors from cached metadata only when the inference + # mode is the same. + torch.is_inference_mode_enabled(), + # Shape env settings could affect behavior. One example seen in the wild: + # Disallowing dynamic shapes can introduce a DynamicOutputShapeException + # where it wasn't seen on a previous instance of the same op. 
+ self.shape_env.settings if self.shape_env else None, + # ProxyTorchDispatchMode needs to track how SymNodes are constructed + # so we need to handle things a little different depending on + # whether we're tracing or not. + is_tracing, + ] + if state.known_symbols: + # If there are symbols then include the epoch - this is really more + # of a Shape env var which lives on the FakeTensorMode. + # pyrefly: ignore [bad-argument-type] + key_values.append(self.epoch) + # Collect the id_hashed objects to attach a weakref finalize later + id_hashed_objects: list[object] = [] + # Translate any FakeTensor args to metadata. + if args: + # pyrefly: ignore [bad-argument-type] + self._prep_args_for_hash(key_values, args, state, id_hashed_objects) + if kwargs: + # pyrefly: ignore [bad-argument-type] + self._prep_args_for_hash(key_values, kwargs, state, id_hashed_objects) + key = _DispatchCacheKey(tuple(key_values)) + + for id_hashed_obj in id_hashed_objects: + weakref.finalize( + id_hashed_obj, functools.partial(evict_fake_tensor_cache_key, key=key) + ) + id_hashed_objects.clear() + return key + + def _validate_cache_key( + self, + func: OpOverload, + args: Sequence[object], + kwargs: Mapping[str, object], + ) -> None: + """ + Validate that the cache key generated by _cache_key will be + reasonable. + """ + from torch._higher_order_ops.utils import registered_hop_fake_fns + + # For hops, we perform the validity check in _make_cache_entry because we + # need to have the output tensor. + if ( + isinstance(func, torch._ops.HigherOrderOperator) + and func in registered_hop_fake_fns + ): + return + + # Avoid caching for any ops that would require a more sophisticated + # caching implementation, e.g., data dependent ops or ops that modify + # the inputs. 
+ if torch.Tag.data_dependent_output in func.tags: + raise _BypassDispatchCache("data dependent output") + + if torch.Tag.dynamic_output_shape in func.tags: + if func is aten.index.Tensor: + _, new_kwargs = normalize_function( # type: ignore[misc] + func, + args=args, # type: ignore[arg-type] + kwargs=kwargs, # type: ignore[arg-type] + normalize_to_only_use_kwargs=True, + ) + for index in new_kwargs["indices"]: + # index calls nonzero for bool or int8 tensors, and + # therefore has a dynamic shape output. For other dtypes, + # the output shape depends on the input shape (and not data) + if isinstance(index, torch.Tensor) and index.dtype in ( + torch.bool, + torch.int8, + ): + raise _BypassDispatchCache("dynamic output shape") + return + + raise _BypassDispatchCache("dynamic output shape") + + if torch.Tag.inplace_view in func.tags: + raise _BypassDispatchCache("inplace view") + + if func is aten._unsafe_view.default: + raise _BypassDispatchCache("unsafe view") + + if func in self.lift_fns: + raise _BypassDispatchCache("lift") + + if func.name() == "inductor::resize_storage_bytes_": + raise _BypassDispatchCache("inductor::resize_storage_bytes_") + + if not torch._library.utils.is_builtin(func): + raise _BypassDispatchCache("non-builtin") + + # In order to handle storage aliasing, we need to establish the alias + # for any view op on a cache hit. But CompositeImplicitAutograd ops may + # or may not alias the input, so just punt on caching these. 
+ if func.is_view and torch._C._dispatch_has_kernel_for_dispatch_key( + func.name(), torch._C.DispatchKey.CompositeImplicitAutograd + ): + raise _BypassDispatchCache("CompositeImplicitAutograd") + + def _prep_args_for_hash( + self, + result: list[object], + args: Union[Mapping[str, object], Sequence[object], Iterable[object]], + state: _CacheKeyState, + id_hashed_objects: list[object], + ) -> None: + """ + Translate the provided args into a form suitable for caching at FakeTensor + dispatch, i.e., convert unhashable types like lists & dicts into tuples and + convert FakeTensors into metadata. Raises _BypassDispatchCache to signal + unsupported cases that should bypass caching. + """ + from torch._higher_order_ops.auto_functionalize import ( + FunctionalCallableWithEpilogue, + ) + from torch._higher_order_ops.utils import FunctionalizeCtxWrapper + + if isinstance(args, (list, tuple, dict)): + result.append(type(args)) + result.append(f"length_{len(args)}") + + if isinstance(args, dict): + self._prep_args_for_hash(result, args.keys(), state, id_hashed_objects) + self._prep_args_for_hash(result, args.values(), state, id_hashed_objects) + return + + for arg in args: + if isinstance(arg, FakeTensor): + if not self.is_our_fake(arg): + raise _BypassDispatchCache("not our fake") + if arg.constant is not None: + raise _BypassDispatchCache("constant attribute") + if is_sparse_any(arg): + raise _BypassDispatchCache(f"{arg.layout} tensor") + metadata = extract_tensor_metadata(arg) + metadata._flatten_into(result, self, state) + elif isinstance(arg, Tensor): + raise _BypassDispatchCache("non-fake tensor") + elif isinstance(arg, SymInt): + state.convert_sym_int(result, arg) + elif isinstance(arg, (SymBool, SymFloat)): + raise _BypassDispatchCache("symbolic shape") + elif isinstance(arg, (list, tuple, dict)): + self._prep_args_for_hash(result, arg, state, id_hashed_objects) + elif isinstance(arg, types.FunctionType): + raise _BypassDispatchCache("function argument") + elif 
isinstance(arg, torch.fx.GraphModule): + # This is used for invoke_subgraph where id(graph_module) allows + # us to cache fake outputs + result.append(type(arg)) + result.append(id(arg)) + id_hashed_objects.append(arg) + elif isinstance(arg, FunctionalizeCtxWrapper): + # Special case for AOT Dispatcher first pass, where the fake + # tensor is called on the functional wrapper of the subgraph. + result.append(hash(arg)) + # functional wrapper is destroyed after fake tensor prop. We + # need to put the finalizer on the subgraph. + id_hashed_objects.append(arg.subgraph) + elif isinstance(arg, FunctionalCallableWithEpilogue): + result.append(type(arg)) + result.append(hash(arg)) + id_hashed_objects.append(arg.orig_callable) + else: + # It's important to capture the type of the arg since, e.g., 1 and 1.0 + # hash to the same value, but can produce different dtypes for the + # output tensor. + result.append(type(arg)) + result.append(arg) + + def _validate_output_for_cache_entry( + self, + state: _CacheKeyState, + key: _DispatchCacheKey, + func: OpOverload, + args: Sequence[object], + kwargs: Mapping[str, object], + output: Optional[FakeTensor], + ) -> None: + # Is this even possible? According to the signature this can be None but + # not `int`. So either the signature is a lie or (part of) this line is + # unnecessary... + if isinstance(output, (int, type(None))): + return + + # Check for symbolic content that should bypass caching - raises + # _BypassDispatchCache if necessary. + _validate_symbolic_output_for_caching(state, output) + + # Some ops return tuples of Tensors, but it's rare, so avoid + # the complexity of caching other types. + if not isinstance(output, FakeTensor): + raise _BypassDispatchCache("non-FakeTensor output") + + # Avoid caching FakeTensors with constants attached since those + # can be invalidated. + if output.constant is not None: + raise _BypassDispatchCache("constant attribute") + + # TODO: support caching sparse outputs? 
+ if output.is_sparse: + raise _BypassDispatchCache("sparse output") + + if is_sparse_compressed(output): + raise _BypassDispatchCache("sparse compressed output") + + # Can an in-place op really reference a kwarg? If so, then we need + # to extend the implementation to handle it. + for kval in kwargs.values(): + if id(kval) == id(output): + raise _BypassDispatchCache("kwarg aliases output") + + def _get_output_info_for_cache_entry( + self, + state: _CacheKeyState, + key: _DispatchCacheKey, + func: OpOverload, + args: Sequence[object], + kwargs: Mapping[str, object], + output: FakeTensor, + ) -> _DispatchCacheEntryOutputInfo: + if isinstance(output, (int, torch.SymInt, type(None))): + return _DispatchCacheEntryOutputInfo( + inplace_idx=None, metadata=None, view_idx=None, constant_value=output + ) + + # If this is an in-place op, the entry records which input arg is aliased. + for idx in range(len(args)): + if id(args[idx]) == id(output): + return _DispatchCacheEntryOutputInfo( + inplace_idx=idx, metadata=None, view_idx=None + ) + + # Otherwise, create an entry that records the output tensor's metadata. + view_idx = None + if isinstance(func, torch._ops.OpOverload) and func.is_view: + idxs = [i for i, t in enumerate(args) if isinstance(t, Tensor)] + assert len(idxs) == 1 + view_idx = idxs[0] + + metadata = extract_tensor_metadata(output) + metadata.shape = tuple(state.convert_output(v) for v in metadata.shape) + metadata.stride = tuple(state.convert_output(v) for v in metadata.stride) + metadata.storage_offset = state.convert_output(metadata.storage_offset) + metadata.storage_bytes = ( + None + if metadata.storage_bytes is None + else state.convert_output(metadata.storage_bytes) + ) + + entry = _DispatchCacheEntryOutputInfo( + inplace_idx=None, + metadata=metadata, + view_idx=view_idx, + ) + + # N.B.: Some checks for bypassing the cache would be performed on the + # output tensor synthesized from the cached metadata. 
As an optimization, + # we can synthesize a tensor here and do the checks on that instance. + # This approach keeps the (more frequent) cache-hit path as lightweight + # as possible. + entry_for_synth_output = _DispatchCacheValidEntry( + output_infos=(entry,), is_output_tuple=False + ) + from torch.fx.experimental.symbolic_shapes import GuardOnDataDependentSymNode + + try: + synth_output = self._output_from_cache_entry( + state, entry_for_synth_output, key, func, args + ) + except GuardOnDataDependentSymNode: + # This should probably never really happen. If it does it means that + # although the original call didn't get a data-dependent error when + # we tried to reconstruct the output we did - that's almost + # certainly a bug. + raise _BypassDispatchCache("data dependent symnode") from None + + # Make sure the dispatch_key_set from the synthesized output tensor will + # be the same. + synth_key_set = torch._C._dispatch_key_set(synth_output) + key_set = torch._C._dispatch_key_set(output) + if synth_key_set != key_set: + raise _BypassDispatchCache("dispatch_key_set mismatch") + + return entry + + def _make_cache_entry( + self, + state: _CacheKeyState, + key: _DispatchCacheKey, + func: OpOverload, + args: Sequence[object], + kwargs: Mapping[str, object], + output: Optional[FakeTensor], + ) -> _DispatchCacheValidEntry: + """ + Make a cache entry object for the given 'output' Tensor. Raises + _BypassDispatchCache if the output tensor has characteristics that + prevent caching it. + """ + from torch._higher_order_ops.utils import registered_hop_fake_fns + from torch.fx.experimental.symbolic_shapes import has_free_unbacked_symbols + + self._validate_cache_key(func, args, kwargs) + + # For hops, lets look at the output tensor to find any unbacked symints. + # If there are none, then we rely on the existing checks to validate + # caching. 
+ # NB: Note that the HOPs that sta alive till FakeTensor are functional, + # once they support mutations, we will have to revisit this logic. + if ( + isinstance(func, torch._ops.HigherOrderOperator) + and func in registered_hop_fake_fns + ): + assert isinstance(output, tuple) + non_cacheable = any( + isinstance(o, (torch.Tensor, torch.SymInt)) + and has_free_unbacked_symbols(o) + for o in output + ) + if non_cacheable: + raise _BypassDispatchCache(f"unbacked symbol in HOP {func} output") + + if isinstance(output, (int, torch.SymInt, type(None))): + output_info = _DispatchCacheEntryOutputInfo( + inplace_idx=None, metadata=None, view_idx=None, constant_value=output + ) + return _DispatchCacheValidEntry( + output_infos=(output_info,), is_output_tuple=False + ) + + if isinstance(output, tuple): + for out_element in output: + self._validate_output_for_cache_entry( + state, + key, + # pyrefly: ignore [bad-argument-type] + func, + args, + kwargs, + out_element, + ) + else: + self._validate_output_for_cache_entry( + state, + key, + # pyrefly: ignore [bad-argument-type] + func, + args, + kwargs, + output, + ) + + if isinstance(output, tuple): + output_infos = [ + self._get_output_info_for_cache_entry( + state, + key, + # pyrefly: ignore [bad-argument-type] + func, + args, + kwargs, + out_elem, + ) + for out_elem in output + ] + return _DispatchCacheValidEntry( + # pyrefly: ignore [bad-argument-type] + output_infos=tuple(output_infos), + is_output_tuple=True, + ) + + else: + output_info = self._get_output_info_for_cache_entry( + state, + key, + # pyrefly: ignore [bad-argument-type] + func, + args, + kwargs, + output, + ) + return _DispatchCacheValidEntry( + output_infos=(output_info,), is_output_tuple=False + ) + + def _get_output_tensor_from_cache_entry( + self, + state: _CacheKeyState, + entry: _DispatchCacheEntryOutputInfo, + key: _DispatchCacheKey, + func: OpOverload, + args: Sequence[object], + ) -> Optional[FakeTensor]: + if ( + entry.inplace_idx is None + and 
entry.metadata is None + and entry.view_idx is None + ): + assert entry.constant_value is not SingletonConstant + return entry.constant_value + if entry.inplace_idx is not None: + # This is an in-place op; return the aliased arg. + inplace_arg = args[entry.inplace_idx] + assert isinstance(inplace_arg, FakeTensor) + return inplace_arg + + # Synthesize a new FakeTensor with the cached metadata. + metadata = entry.metadata + if metadata is None: + return None + + assert not is_sparse_any(metadata) + + def check_value( + value: _MetadataIntLike, state: _CacheKeyState + ) -> Union[IntLikeType]: + if isinstance(value, _SymIntOutputStub): + assert state.shape_env is not None + return value.extract(key, state.shape_env) + else: + assert not isinstance(value, _PySymInputStub) + return value + + shape = tuple(check_value(v, state) for v in metadata.shape) + stride = tuple(check_value(v, state) for v in metadata.stride) + storage_offset = check_value(metadata.storage_offset, state) + if metadata.storage_bytes is not None: + check_value(metadata.storage_bytes, state) + + maybe_suppress: Callable[[], typing.ContextManager] = contextlib.nullcontext + if self.shape_env is not None: + maybe_suppress = self.shape_env.suppress_guards + + with in_kernel_invocation_manager(self), maybe_suppress(): + empty = torch.empty_strided( + shape, + stride, + dtype=metadata.dtype, + layout=metadata.layout, + device="meta", + requires_grad=metadata.requires_grad, + ) + + if metadata.is_conj: + torch._C._set_conj(empty, True) + if metadata.is_neg: + torch._C._set_neg(empty, True) + + if isinstance(func, torch._ops.OpOverload) and func.is_view: + # For view ops, the storage should be the same as the tensor input. 
+ view_arg = args[cast(int, entry.view_idx)] + assert isinstance(view_arg, FakeTensor) + storage = view_arg.untyped_storage() + with in_kernel_invocation_manager(self), maybe_suppress(): + empty.set_(storage, storage_offset, shape, stride) + + return FakeTensor(self, empty, metadata.device) + + def _output_from_cache_entry( + self, + state: _CacheKeyState, + entry: _DispatchCacheValidEntry, + key: _DispatchCacheKey, + func: OpOverload, + args: Sequence[object], + ) -> Union[Optional[FakeTensor], tuple[Optional[FakeTensor], ...]]: + """ + Create a new FakeTensor from the cache entry. + """ + + if entry.is_output_tuple: + outputs = [ + self._get_output_tensor_from_cache_entry( + state, output_info, key, func, args + ) + for output_info in entry.output_infos + ] + return tuple(outputs) + else: + return self._get_output_tensor_from_cache_entry( + state, entry.output_infos[0], key, func, args + ) + + def _crosscheck_cache_output( + self, + output: Union[Optional[FakeTensor], tuple[Optional[FakeTensor], ...]], + func: OpOverload, + types: Sequence[type], + args: Sequence[object], + kwargs: Mapping[str, object], + ) -> None: + """ + Helper to validate that the output synthesized from the cache matches + the output created by normal dispatch. 
+ """ + + def assert_helper(a: Any, b: Any) -> None: + if isinstance(a, tuple): + assert isinstance(b, tuple) + assert len(a) == len(b) + for l, r in zip(a, b): + assert_helper(l, r) + elif isinstance(a, int): + assert isinstance(b, int) and a == b + elif a is None: + assert b is None + elif isinstance(a, py_sym_types): + assert type(a) is type(b) and a.node is b.node + elif isinstance(a, torch.Tensor): + assert isinstance(b, torch.Tensor) + assert_metadata_eq(assert_eq, a, b) + else: + raise RuntimeError(f"Unsupported type {type(a)}") + + try: + true_output = self._dispatch_impl(func, types, args, kwargs) + except Exception as e: + raise RuntimeError( + f"FakeTensor cache crosscheck failure: func={func}, " + f"args={args}, kwargs={kwargs}: Dispatch raised={e}" + ) from e + try: + assert_helper(true_output, output) + except Exception as e: + raise RuntimeError( + f"FakeTensor cache crosscheck failure: func={func}, " + f"args={args}, kwargs={kwargs}" + ) from e + + def dispatch( + self, + func: OpOverload, + types: Sequence[type], + args: Sequence[object] = (), + kwargs: Mapping[str, object] = immutable_dict(), + ) -> object: + kwargs = kwargs or {} + with no_dispatch(): + log.debug("%s %s %s", func, args, kwargs) + + if func in _DISPATCH_META_HANDLERS: + return _DISPATCH_META_HANDLERS[func](args) + + if log.getEffectiveLevel() <= logging.DEBUG: + log.debug( + "%sFakeTensorMode.__torch_dispatch__: %s", " " * RECURSION_COUNT, func + ) + # NOTE: incr is intentionally unused for a RAII pattern + incr = IncrementRecursionCount() # noqa: F841 + + # Some attribute queries that can be serviced directly + # See Note [is_coalesced is dispatched] + if func in _DISPATCH_HANDLE_DIRECTLY: + # NB: no_dispatch is ok here too, this func is very simple + with in_kernel_invocation_manager(self): + return func(*args, **kwargs) + + if self.cache_enabled: + return self._cached_dispatch_impl(func, types, args, kwargs) + else: + return self._dispatch_impl(func, types, args, kwargs) + + 
def _maybe_infer_fake( + self, func: OpOverload, path: KeyPath, fake: object, real: object + ) -> tuple[Optional[object], bool]: + """ + Helper to cross-check fake/real output properties & values, + and create new fake vals if mismatched. + Returns tuple of object & boolean, for whether or not it was overwrriten + """ + import sympy + + from torch._subclasses.fake_utils import _check_fake_real_tensors + + def _check_fake_real_vals(fake: Any, real: Any) -> None: + # use real values + ShapeEnv to check mismatches between potentially symbolic values + if isinstance(fake, (SymInt, SymFloat)): + # symbolic expression, ask ShapeEnv to substitute known backed/unbacked values + assert self.shape_env is not None + if ( + not fake.node.expr.free_symbols + - self.shape_env.var_to_val.keys() + - self.shape_env.unbacked_var_to_val.keys() + ): + if ( + self.shape_env._maybe_evaluate_static( + sympy.Eq(fake.node.expr, real), compute_hint=True + ) + is not sympy.S.true + ): + raise MetadataMismatchError( + f"mismatch between fake value {fake} and real value {real} " + ) + elif isinstance( + fake, (int, float, bool) + ): # concrete value, check direct equality + if fake != real: + raise MetadataMismatchError( + f"mismatch between fake value {fake} and real value {real} " + ) + + if isinstance(fake, torch.Tensor): + try: + _check_fake_real_tensors( + real, # type: ignore[arg-type] + fake, # type: ignore[arg-type] + context="Real tensor propagation found", + sizes=False, # manual check below + strides=False, # skip strides + storage_offset=True, + requires_grad=False, # issues with FakeTensorConverter preserving requires_grad + ) + except MetadataMismatchError as exc: + if torch._functorch.config.generate_fake_kernels_from_real_mismatches: + dtrace_structured( + "mismatched_fake_kernel", + metadata_fn=lambda: { + "op": str(func), + "reason": exc.reason, # noqa: F821 + }, + ) + return _infer_fake_from_real_tensor(self, func, real), True # type: ignore[arg-type] + raise 
MetadataMismatchError( + f"Real tensor propagation found a metadata mismatch between " + f"fake tensor {fake} and real tensor {real}, " + f" at output{keystr(path)}, for func: {func}" + ) from exc + + for j, (s_fake, s_real) in enumerate(zip(fake.size(), real.size())): # type: ignore[attr-defined] + try: + _check_fake_real_vals(s_fake, s_real) + except MetadataMismatchError as exc: + if torch._functorch.config.generate_fake_kernels_from_real_mismatches: + dtrace_structured( + "mismatched_fake_kernel", + metadata_fn=lambda: { + "op": str(func), + "reason": exc.reason, # noqa: F821 + }, + ) + return _infer_fake_from_real_tensor(self, func, real), True # type: ignore[arg-type] + raise MetadataMismatchError( + f"Real tensor propagation found an output size mismatch between " + f"fake shape {s_fake} and real shape {s_real}, " + f"at output{keystr(path)}.size({j}), for func: {func}" + ) from exc + elif fake is None and real is not None: + if torch._functorch.config.generate_fake_kernels_from_real_mismatches: + dtrace_structured( + "mismatched_fake_kernel", + metadata_fn=lambda: { + "op": str(func), + "reason": f"mismatch between fake value {fake} and real value {real}", # noqa: F821 + }, + ) + return _infer_fake_from_real_tensor(self, func, real), True # type: ignore[arg-type] + raise MetadataMismatchError( + f"Real tensor propagation found a metadata mismatch between " + f"fake tensor {fake} and real tensor {real}, " + f" at output{keystr(path)}, for func: {func}" + ) + else: + try: + _check_fake_real_vals(fake, real) + except MetadataMismatchError as exc: + raise MetadataMismatchError( + f"Real tensor propagation found an output value mismatch between " + f"fake output value {fake} and real output value {real}, " + f"at output{keystr(path)}, for func: {func}" + ) from exc + return fake, False + + def _maybe_infer_fake_kernel_from_pytree_out( + self, + func: OpOverload, + fake_in: object, + real_in: object, + fake_out: object, + real_out: object, + ) -> 
Optional[object]: + """ + Helper to cross-check fake/real output properties & values, + and create new fake vals if mismatched, but at the kernel level. + Means this handles pytree outputs & checks aliasing. + """ + from torch._subclasses.fake_utils import _check_alias_info + + # we might have to clear pending unbacked symbols, if we override the kernel + pending_unbacked = None + if self.shape_env: + pending_unbacked = list(self.shape_env.pending_fresh_unbacked_symbols) + + def _clear_pending_unbacked() -> None: + self.shape_env.pending_fresh_unbacked_symbols = list( # type: ignore[union-attr] + set(self.shape_env.pending_fresh_unbacked_symbols).difference( # type: ignore[union-attr] + pending_unbacked # type: ignore[arg-type] + ) + ) + + fake_paths_leaves, fake_spec = pytree.tree_flatten_with_path(fake_out) + real_leaves, _ = pytree.tree_flatten(real_out) + try: + # catch aliasing mismatches between fake/real tensors + _check_alias_info( + "Real tensor propagation found", real_out, real_in, fake_out, fake_in + ) + except MetadataMismatchError as exc: + # if mismatch found, optionally infer fake kernel + if torch._functorch.config.generate_fake_kernels_from_real_mismatches: + dtrace_structured( + "mismatched_fake_kernel", + metadata_fn=lambda: { + "op": str(func), + "reason": ( + f"Mismatched aliasing spec between fake kernel and real kernel: {exc.reason}" # noqa: F821 + ), + }, + ) + # if aliasing mismatches are found, it's likely that the fake tensor impl + # is incorrectly aliasing, since we don't support aliasing custom ops. + # in this case we can default to inferring non-aliasing fake kernels from the real outputs. 
+ _clear_pending_unbacked() + return tree_map( + lambda x: _infer_fake_from_real_tensor(self, func, x), real_out + ) + else: + raise MetadataMismatchError( + f"Real tensor propagation found an aliasing mismatch between " + f"fake output {fake_out} and real output {real_out}, " + f" for func: {func}" + ) from exc + + # if no errors raised, run cross checks on fake/real tensors, + # optionally overriding individual fake tensors, if individual meta kernel output is incorrect. + fake_leaves, overrides = zip( + *[ + self._maybe_infer_fake(func, _fake_path, _fake_out, _real_out) + for (_fake_path, _fake_out), _real_out in zip( + fake_paths_leaves, real_leaves + ) + ] + ) + if ( + any(overrides) and pending_unbacked + ): # only keep new pending unbacked symbols + _clear_pending_unbacked() + return pytree.tree_unflatten(fake_leaves, fake_spec) + + def _dispatch_impl( + self, + func: OpOverload, + types: Sequence[type], + args: Sequence[object], + kwargs: Mapping[str, object], + ) -> Optional[FakeTensor]: + from torch._higher_order_ops.utils import registered_hop_fake_fns + + flat_args, args_spec = pytree.tree_flatten((args, kwargs)) + + # DO NOT PUT LOGIC BEFORE UNRECOGNIZED TYPE CHECKING + # We must throw NotImplemented in case of unrecognized types to handle subclasses. + # Throwing the exception will pass the control to the next __torch_dispatch__. + # See [subclass inputs] below + # NB: If you're seeing a mysterious infinite loop involving fake + # tensor, it might be related to this line. Though I'm not sure + # how you'll know to read this comment, as this line won't show up + # in the stack trace. 
+ has_unrecognized_types = _check_for_subclass(flat_args) + if has_unrecognized_types: + unrecognized_types = [ + type(x) for x in flat_args if _check_for_subclass_arg(x) + ] + not_implemented_log.debug( + "FakeTensorMode unrecognized subclass(es): %s", unrecognized_types + ) + return NotImplemented + + flat_arg_fake_tensors = [t for t in flat_args if self.is_our_fake(t)] + has_symbolic_sizes = any( + i._has_symbolic_sizes_strides for i in flat_arg_fake_tensors + ) or any(isinstance(a, SymInt) for a in flat_args) + + converter = self.fake_tensor_converter + + is_lift_func = func in self.lift_fns + + # If we are trying to avoid device init, then we need to avoid constant + # prop on constant tensors for ops that change devices. + avoiding_device_init = False + if self.avoid_device_init: + if ( + func is torch.ops.aten._to_copy.default + and "device" in kwargs + and kwargs["device"].type != "cpu" # type: ignore[attr-defined] + ): + avoiding_device_init = True + if func is torch.ops.prims.device_put.default: + avoiding_device_init = True + + # skip const prop for aten._to_copy if + # 1. input tensor is on "meta" device + # 2. destination device is unavailable, captured by `avoiding_device_init` + device_conversion_skip_const_prop = ( + func is torch.ops.aten._to_copy.default + and isinstance(args[0], torch.Tensor) + and args[0].device.type == "meta" + ) or avoiding_device_init + + # To constant propagate through these functions: + # 1, If this is a lift due to a torch.tensor call, + # the input tensor is guaranteed to be a + # constant, so we keep a copy of the original argument along so + # we can query it if we're asked to item() it at some later point. + # (Note that you can always call a lift fn manually, so we do + # have to check if there are any fake tensors!) 
+ # 2, Some functions that allow Python numbers to bind to Tensors, e.g, torch.div + if (is_lift_func and not flat_arg_fake_tensors) or ( + should_allow_numbers_as_tensors(func) + and not has_symbolic_sizes + and not flat_arg_fake_tensors + and not device_conversion_skip_const_prop + ): + assert all(t.constant is not None for t in flat_arg_fake_tensors), ( + f"{func} should not have fake inputs without constants" + ) + const_flat_args = [ + a.constant if self.is_our_fake(a) else a for a in flat_args + ] + const_args, const_kwargs = pytree.tree_unflatten(const_flat_args, args_spec) + out = func(*const_args, **const_kwargs) + if type(out) is Tensor and self.may_turn_const(out): + # NB: not in_kernel_invocation_manager because we're doing real + # compute here + # NB: no_dispatch() here is VERY DANGEROUS (like, segfault + # dangerous) if this is actually a wrapper subclass tensor, + # therefore the exact type test above + with no_dispatch(): + out = out.clone() + return converter.from_real_tensor(self, out, make_constant=True) + + # if we are in the dispatch mode, we will enter this function even if the inputs + # are not FakeTensors. For now, throw if any non-Fake Tensor inputs + # and just support constructors. + + # this is generated from torch.tensor(), which does not use the + # dispatcher, to allow wrapper subclasses to wrap the new tensor + if is_lift_func: + assert len(kwargs) == 0 and len(args) == 1, f"{args} {kwargs}" + + if type(args[0]) is Tensor: + return converter.from_real_tensor(self, args[0]) + + # Recompute flat_arg_fake_tensors here again in case some of the inputs + # were real tensors and fakified in validate_and_convert_non_fake_tensors + (flat_args, flat_arg_fake_tensors) = self.validate_and_convert_non_fake_tensors( + func, converter, flat_args, args_spec + ) + del args, kwargs # Invalidated + + # The current constant handling only support tracing systems + # (aot autograd, torchdynamo) where each operation is run consecutively. 
+ # Because each operation is run in order, we can trace out and support + # sequences like: x = torch.tensor(0.); y = x.add_(1) + # Whenever a constant is written to but with inputs that cannot be evaluated + # statically, such as random_(), we invalidate all constants that alias the input + # We will rely on functionalization for use of fake tensors constants as persistent + # objects on an FX Graph. + + # We dispatch size/stride/numel on the FakeTensor not its constant, so bail on inplace_view + all_constant = all(e.constant is not None for e in flat_arg_fake_tensors) + if ( + isinstance(func, torch._ops.OpOverload) + and torch.Tag.nondeterministic_seeded not in func.tags + and torch.Tag.inplace_view not in func.tags + and all_constant + and len(flat_arg_fake_tensors) != 0 + and not has_symbolic_sizes + and not avoiding_device_init + and func is not aten._nested_tensor_from_tensor_list.default + ): + const_flat_args = [ + a.constant if self.is_our_fake(a) else a for a in flat_args + ] + const_args, const_kwargs = pytree.tree_unflatten(const_flat_args, args_spec) + + # NB: not in_kernel_invocation_manager(self) as we want to do REAL + # compute + with no_dispatch(): + out = func(*const_args, **const_kwargs) + + flat_out = pytree.tree_leaves(out) + flat_out_tensors = [t for t in flat_out if isinstance(t, Tensor)] + all_constant = all(self.may_turn_const(t) for t in flat_out_tensors) + + if all_constant: + return pytree.tree_map_only( + Tensor, + lambda t: converter.from_real_tensor(self, t, make_constant=True), + out, + ) + + # we weren't able to turn outputs to constants, + # so invalidate all constants that might be aliases of the outputs + for ten in flat_out_tensors: + converter.invalidate_constant_aliases(ten) + + # we are falling through to running non constant tensors, any input constant that + # is written to must be invalidated + args, kwargs = pytree.tree_unflatten(flat_args, args_spec) + + if ( + isinstance(func, torch._ops.HigherOrderOperator) + and 
func in registered_hop_fake_fns + ): + # Reenable the fake tensor mode for the registered fake function + maybe_ignore_fresh_unbacked_symbols = ( + contextlib.nullcontext + if self.shape_env is None + else self.shape_env.ignore_fresh_unbacked_symbols + ) + + with self, maybe_ignore_fresh_unbacked_symbols(): + # pyrefly: ignore [index-error] + return registered_hop_fake_fns[func](*args, **kwargs) + + self.invalidate_written_to_constants(func, flat_arg_fake_tensors, args, kwargs) + + def maybe_to_real_tensor( + t: T, + ) -> Optional[Union[T, Tensor, torch._C.ScriptObject]]: + if isinstance(t, FakeTensor): + return t.real_tensor + elif isinstance(t, py_sym_types): + assert self.shape_env is not None + return t.node.pytype( + t.node.expr.xreplace(self.shape_env.var_to_val).xreplace( + self.shape_env.unbacked_var_to_val + ) + ) + elif isinstance(t, FakeScriptObject): + return t.real_obj + else: + return t + + from torch.fx.experimental.symbolic_shapes import ( + compute_unbacked_bindings, + free_unbacked_symbols, + ) + + nil = object() + + real_out = nil + if ( + self.propagate_real_tensors + and all(e.real_tensor is not None for e in flat_arg_fake_tensors) + and not any( + ( + isinstance(a, py_sym_types) + and (syms := free_unbacked_symbols(a)) + and self.shape_env is not None + and any(s not in self.shape_env.unbacked_var_to_val for s in syms) + ) + for a in flat_args + ) + ): + log.debug("propagate_real_tensors %s", func) + real_flat_args = [maybe_to_real_tensor(a) for a in flat_args] + real_args, real_kwargs = pytree.tree_unflatten(real_flat_args, args_spec) + + is_builtin = library_utils.is_builtin(func) + if not is_builtin: + mutation_checker = library_utils.MutationChecker( + func, real_flat_args, args_spec + ) + + try: + real_out = func(*real_args, **real_kwargs) + except ZeroDivisionError as exc: + # we shouldn't broadly catch all errors here; + # some come from real-kernel mutation/aliasing checks we want to run. + # add more exception types as needed. 
+ log.debug( # noqa: G200 + "real-tensor fallback failed for %s: %s; silently ignoring", + func, + exc, + ) + + if not is_builtin: + mutation_checker.check() # type: ignore[possibly-undefined] + library_utils.check_aliasing_constraint(func._name, flat_args, real_out) + + elif self.propagate_real_tensors: + # This can happen occasionally legitimately, specifically when you + # are inside the meta of a data dependent operation and you create + # a tensor on an unbacked SymInt; at this point in time we don't + # know what the unbacked SymInt is, but we will know later. + # However, if there's a bug in the condition above, this condition + # will also trigger. + log.debug( + "SKIPPED propagate_real_tensors %s(%s, %s) %s", + func, + flat_arg_fake_tensors, + flat_args, + self.shape_env.unbacked_var_to_val if self.shape_env else None, + ) + + def maybe_propagate_real_tensors(fake_out: T) -> T: + import sympy + + log.debug("maybe_propagate_real_tensors %s", func) + + def go(t: object, real_t: Tensor) -> None: + if isinstance(t, FakeTensor): + # NB: unconditionally overwrite + log.debug( + "maybe_propagate_real_tensors %s -> %s", id(t), id(real_t) + ) + t.real_tensor = real_t + for s, real_s in zip(t.size(), real_t.size()): + go(s, real_s) # type: ignore[arg-type] + for s, real_s in zip(t.stride(), real_t.stride()): + go(s, real_s) # type: ignore[arg-type] + go(t.storage_offset(), real_t.storage_offset()) # type: ignore[arg-type] + elif isinstance(t, py_sym_types) and free_unbacked_symbols(t): + if isinstance(t.node.expr, sympy.Symbol): + assert self.shape_env is not None + self.shape_env.set_unbacked_var_to_val(t.node.expr, real_t) + elif ( + isinstance(s := t.node.expr, sympy.Eq) + and isinstance(s.lhs, sympy.Symbol) + and s.rhs == 1 + ): + assert self.shape_env is not None + + self.shape_env.set_unbacked_var_to_val(s, int(real_t)) + + if real_out is not nil: + # cross check fake/real outputs, and optionally override fake kernel mismatches + if not 
torch._functorch.config.generate_fake_kernels_from_real_mismatches: + self._maybe_infer_fake_kernel_from_pytree_out( + func, + (args, kwargs), + (real_args, real_kwargs), + fake_out, + real_out, + ) + else: + # this can override the output only when the flag is True + fake_out = self._maybe_infer_fake_kernel_from_pytree_out( # type: ignore[assignment] + func, + (args, kwargs), + (real_args, real_kwargs), + fake_out, + real_out, + ) + + # populate unbacked_var_to_val + if ( + not isinstance(fake_out, Tensor) + and not isinstance(real_out, Tensor) + and type(fake_out) is not type(real_out) + ): + # This can happen when decompositions have different return types, + # e.g. namedtuple vs. tuple vs. list. + tree_map_( + go, + tuple(pytree.tree_flatten(fake_out)), + tuple(pytree.tree_flatten(real_out)), + ) + else: + tree_map_(go, fake_out, real_out) + + # If a data-dependent op is used in a decomposition, we + # may need to get the unbacked settings "early" + # TODO: Is this really needed? + compute_unbacked_bindings(self.shape_env, fake_out, peek=True) + + # pyrefly: ignore [bad-return] + return fake_out + + # Try for fastpath + if has_symbolic_sizes: + fast_impl = get_fast_op_impls().get(func) + if fast_impl is not None: + return maybe_propagate_real_tensors(fast_impl(self, *args, **kwargs)) + + # If there's a Python meta, prefer that over the decomposition + from torch._decomp import meta_table + + if ( + func not in meta_table + and not self.cpp_meta_supports_symint(func) + and not ( + has_symbolic_sizes and func in self._unbacked_special_fake_handling_ops + ) + ): + from torch._decomp import decomposition_table + + # Prefer Python decompositions over C++ ones + if func in decomposition_table and ( + has_symbolic_sizes + or ( + # TODO: Remove these exclusions, so that we can remove + # this leg entirely + torch_decomp_decompositions(func) + and all(not is_sparse_any(e) for e in flat_arg_fake_tensors) + ) + ): + with self: + return maybe_propagate_real_tensors( + 
decomposition_table[func](*args, **kwargs) + ) + + with self: + # Decomposes CompositeImplicitAutograd ops + r = func.decompose(*args, **kwargs) + if r is not NotImplemented: + return maybe_propagate_real_tensors(r) + + # prims already wrap FakeTensor inputs to FakeTensor outputs + # and do device logic, we dont need do anything but run them + # and ensure that Meta kernels are dispatched to (see) + # Fake Tensor Dispatch Keys + # TODO - we should be use the prim aten impl + # TODO - fix prims complex ops + if ( + "prims::" in func._schema.name + and hasattr(func, "prim_meta_impl") + and not stride_incorrect_op(func) + ): + with self: + return maybe_propagate_real_tensors( + func.prim_meta_impl(*args, **kwargs) + ) + + profiles = torch._dynamo.config._custom_ops_profile + if profiles is not None: + if func in profiles.data: + return profiles.generic_fake_kernel(func, self, *args, **kwargs) + + if ( + self.propagate_real_tensors + and real_out is not nil + and not library_utils.is_builtin(func) + and self.shape_env is not None + ): + # Automatically infer a Fake kernel if there isn't one. + if not library_utils.has_fake_kernel(func): + result = inferred_fake_kernel_from_real_out(self, func, real_out) + + dtrace_structured( + "missing_fake_kernel", + metadata_fn=lambda: { + "op": str(func), + }, + ) + return maybe_propagate_real_tensors(result) + + # Users can register FakeTensor rules for custom operators + # Call them if they exist. 
+ maybe_fake_impl = torch._library.simple_registry.singleton.find( + func.name() + ).fake_impl.kernel + if maybe_fake_impl: + try: + ctx = torch._library.fake_impl.FakeImplCtx(self, func) + with torch._library.fake_impl.set_ctx_getter(lambda: ctx), self: + result = maybe_fake_impl(*args, **kwargs) + return maybe_propagate_real_tensors(result) + + except MissingOpProfile as e: + # If we have a fake kernel registered generated from OpProfiles + # but there doesn't exist a profile for the existing inputs, and we are in + if ( + self.propagate_real_tensors + and real_out is not nil + and not library_utils.is_builtin(func) + and self.shape_env is not None + ): + result = inferred_fake_kernel_from_real_out(self, func, real_out) + + dtrace_structured( + "missing_fake_kernel", + metadata_fn=lambda: { + "op": str(func), + }, + ) + return maybe_propagate_real_tensors(result) + else: + raise e + + # special handling for funcs registered through `register_op_impl`, + # e.g., manipulating args on constructor calls to construct meta tensors + # and then afterwards wrapping them to a FakeTensor + for run_impl_check, op_impl in op_implementations_checks: + if run_impl_check(func): + op_impl_out = op_impl(self, func, *args, **kwargs) + if op_impl_out is not NotImplemented: + return maybe_propagate_real_tensors(op_impl_out) + + def maybe_run_unsafe_fallback( + error: Optional[RuntimeError] = None, + ) -> Optional[FakeTensor]: + # We infer the meta of a custom ops that return None to just + # return None. custom ops are not allowed to mutate metadata + # of their inputs, so this is safe. 
+ if torch._library.utils.can_generate_trivial_fake_impl(func): + return None + # no meta kernel registered, fallback to kernel for the device + if has_symbolic_sizes or not self.can_run_unsafe_fallback(func): + raise UnsupportedOperatorException(func) + if error is None: + error = UnsupportedOperatorException(func) + return run_fallback_kernel(self, func, flat_args, args_spec, error) + + # Optimization: If there is no Meta kernel, it takes a surprisingly long + # amount of time to catch the NotImplementedError, so we check it here. + if not has_meta(func): + fallback = maybe_run_unsafe_fallback() + return maybe_propagate_real_tensors(fallback) + + # run kernel registered to meta for func, which include + # python meta registrations, prims, decomps, and c++ meta fns (structured kernels) + # It's possible that the kernel will return NotImplementedError + try: + with in_kernel_invocation_manager(self): + r = func(*args, **kwargs) + except NotImplementedError as not_implemented_error: + return maybe_run_unsafe_fallback(not_implemented_error) + except Exception: + log.exception("failed while attempting to run meta for %s", func) + raise + + return maybe_propagate_real_tensors( + self.wrap_meta_outputs_with_default_device_logic( + r, func, flat_args, device=kwargs.get("device") + ) + ) + + # WARNING: DO NOT add any additional namespaces/operators here if they refer to operators + # outside of the pytorch/pytorch library! Any pre-existing things here + # are either in the pytorch/pytorch library or have been grandfathered in. + # The fallback does not always work and MAY CRASH and emit unreadable error messages + # so it should not be allowed by default. 
+ _can_run_unsafe_fallback_allowed_namespaces = ordered_set( + "debugprims", + "prims", + "aten", + "xla", + "vision", + "torchtext", + "torchaudio", + "quantized", + ) + + def can_run_unsafe_fallback(self, func: OpOverload) -> bool: + if not self.allow_fallback_kernels: + return False + # It's OK to try the fallback for built-in ops (e.g. aten, prims) + # because we control and test these but the fallback leads to unexpected behavior + # in user-defined custom ops + return ( + func.namespace in self._can_run_unsafe_fallback_allowed_namespaces + or func.name() == "fbgemm::gmm" + ) + + def validate_and_convert_non_fake_tensors( + self, + func: OpOverload, + converter: FakeTensorConverter, + flat_args: Sequence[object], + args_spec: TreeSpec, + ) -> tuple[list[object], list[FakeTensor]]: + """ + Checks if the list of tensors are fake tensors. + If not, try to convert them to fake tensors. + Returns the original args, kwargs, and a flattened list of (args, kwargs) that are fake tensors. + """ + flat_arg_fake_tensors: list[FakeTensor] = [] + + def validate(x: T) -> Union[T, FakeTensor]: + if not isinstance(x, Tensor): + return x + + nonlocal flat_arg_fake_tensors + if not self.is_our_fake(x): + if hasattr(func, "tags") and torch.Tag.inplace_view in func.tags: + args, kwargs = pytree.tree_unflatten(flat_args, args_spec) + raise AssertionError( + f"Can't call metadata mutating ops on non-Fake Tensor inputs. Found in {render_call(func, args, kwargs)}" + ) + allow_non_fake_inputs = ( + self.allow_non_fake_inputs + if fake_tensor_tls.allow_non_fake_inputs_override is None + else fake_tensor_tls.allow_non_fake_inputs_override + ) + if not allow_non_fake_inputs: + if isinstance(x, FakeTensor) and x.fake_mode is not self: + raise AssertionError("Mixing fake modes NYI") + args, kwargs = pytree.tree_unflatten(flat_args, args_spec) + raise AssertionError( + f"Please convert all Tensors to FakeTensors first or instantiate FakeTensorMode " + f"with 'allow_non_fake_inputs'. 
Found in {render_call(func, args, kwargs)}" + ) + + out = converter.from_real_tensor(self, x) + else: + out = x + + flat_arg_fake_tensors.append(out) + return out + + validated_args = [validate(a) for a in flat_args] + return validated_args, flat_arg_fake_tensors + + def wrap_meta_outputs_with_default_device_logic( + self, + r: object, + func: OpOverload, + flat_args: Sequence[object], + device: torch.device, + ) -> PyTree: + converter = self.fake_tensor_converter + + # Lazily initialized, in case there are no tensor returns + common_device = None + has_scalar_only_inputs = False + + def wrap(e: T) -> Union[T, FakeTensor]: + nonlocal common_device + nonlocal has_scalar_only_inputs + + if not isinstance(e, Tensor): + return e + + if common_device is None: + ( + common_device, + has_scalar_only_inputs, + ) = FakeTensor._find_common_device(func, flat_args) + + is_our_fake = self.is_our_fake(e) + if is_our_fake: + torch._check( + e.device == common_device, + lambda: f"FakeTensor is wrapped to wrong device, found {e.device}, expected {common_device}", + ) + return cast(T, e) + elif converter is not None: + if has_scalar_only_inputs: + # Under FakeTensorMode, op accepts scalar only inputs, such as aten.add/sub/mul/div, + # returns a real scalar tensor on CPU. See TensorMeta() in _prims/__init__.py for details. + # We thus directly convert real tensor to fake tensor. 
+ return converter.from_real_tensor(self, e) + else: + return converter.from_meta_and_device( + self, e, device or common_device + ) + else: + # pyrefly: ignore [bad-return] + return e + + return tree_map(wrap, r) + + def create_symbolic_nested_int( + self, *, nt_tensor_id: Optional[int] = None + ) -> torch.SymInt: + # See Note: [Creating symbolic nested int] + # Returned nested int always has coeff=1; multiply the result by coeff if needed + import torch.nested._internal.nested_tensor + from torch.nested._internal.nested_int import NestedIntNode + + if nt_tensor_id is None: + nt_tensor_id = self.nt_tensor_id_counter + assert self.enter_stack, "should only called while FakeTensorMode is active" + self.nt_tensor_id_counter += 1 + hint = torch.SymInt(NestedIntNode(nt_tensor_id, 1)) + + src = torch._dynamo.source.EphemeralSource("intermediate_offsets_or_lengths") + assert self.shape_env is not None + ret = self.shape_env.create_symintnode( + sym=self.shape_env.create_symbol( + val=hint, + source=src, + ), + hint=hint, + source=src, + ) + return ret + + _cpp_meta_supports_symint = ordered_set( + aten.empty.memory_format, + aten.empty_strided.default, + aten.as_strided_scatter.default, + aten.as_strided.default, + aten.as_strided_.default, + aten.zeros.default, + aten.detach.default, + aten.view_as_real.default, + aten.view_as_complex.default, + aten.set_.source_Storage_storage_offset, + aten._sparse_coo_tensor_with_dims_and_tensors.default, + ) + + _unbacked_special_fake_handling_ops = ordered_set( + aten.view.default, + aten._unsafe_view.default, + aten.slice.Tensor, + ) + + def cpp_meta_supports_symint(self, func: OpOverload) -> bool: + if torch.Tag.view_copy in func.tags: + return True + return func in self._cpp_meta_supports_symint + + lift_fns = ordered_set(aten.lift_fresh.default, aten.lift_fresh_copy.default) + + def may_turn_const(self, t: Tensor) -> bool: + return ( + t.numel() <= CONSTANT_NUMEL_LIMIT + and not is_sparse_any(t) + and not self.is_our_fake(t) + 
and t.device.type != "meta" + ) + + def invalidate_written_to_constants( + self, + func: OpOverload, + flat_arg_fake_tensors: Sequence[FakeTensor], + args: Sequence[object], + kwargs: Mapping[str, object], + ) -> None: + any_constant = any(e.constant is not None for e in flat_arg_fake_tensors) + schema_info = get_schema_info(func) + if any_constant and schema_info.is_mutable(): + _, new_kwargs = normalize_function( # type: ignore[misc] + func, + args=args, # type: ignore[arg-type] + kwargs=kwargs, # type: ignore[arg-type] + normalize_to_only_use_kwargs=True, + ) + for k, v in new_kwargs.items(): + k = k if (k != "input" or schema_info.has_argument(k)) else "self" + if ( + self.is_our_fake(v) + and schema_info.is_mutable(k) + and v.constant is not None + ): + self.fake_tensor_converter.invalidate_constant_aliases(v.constant) + + def from_tensor( + self, + tensor: Tensor, + *, + static_shapes: Optional[bool] = None, + source: Optional[Source] = None, + symbolic_context: Optional[SymbolicContext] = None, + trace: bool = True, + ) -> FakeTensor: + shape_env: Optional[ShapeEnv] = self.shape_env + if static_shapes is None: + static_shapes = self.static_shapes + if static_shapes: + assert symbolic_context is None, ( + "cannot set both static_shapes and symbolic_context" + ) + shape_env = None + return self.fake_tensor_converter.from_real_tensor( + self, + tensor, + shape_env=shape_env, + source=source, + symbolic_context=symbolic_context, + trace=trace, + ) + + +_StoragePointer = object + + +def _validate_symbolic_output_for_caching( + state: _CacheKeyState, output: FakeTensor +) -> None: + """ + Validate symbolic content in output and raise _BypassDispatchCache if + caching should be bypassed. 
+ + Args: + state: Cache key state containing known symbols + output: Output to validate + proxy_mode_active: Whether PROXY dispatch mode is currently active + + Raises: _BypassDispatchCache: If output contains symbolic content that + prevents caching + + Details: + + If our output contains any symbols that didn't appear in the input then we + need to bypass. Usually this will be unbacked symbols which can't be + properly reconstructed but there could be "weird" cases where backed symbols + spontaneously appear (from non-input state)? + + If we're proxy (symbol) tracing and the output contains ANY symbols then we + need to bypass. The problem is that ProxyTorchDispatchMode relies on SymNode + object identity and being able to see the construction of SymNodes. + + We could improve the proxy tracing case in a few ways: + + 1. If the output SymNodes are directly copied from inputs then this is + actually fine - they're already tracked. This would probably be the + biggest bang/buck. + + 2. If the output (tensors) are all direct copies of the inputs then this is + also fine - since they're inputs they must be tracked. We already compute + this we just don't plumb it around enough. + + 3. If the output SymNodes are already tracked by the proxy then this is also + actually fine - they're properly tracked. This probably wouldn't be + common since for most outputs we use torch.empty_strided() and recompute + strides. + + 4. We could use the proxy to track "how" the SymNodes were computed and when + using the cache we could "replay" them properly to teach the proxy how to + build them. 
+ """ + from torch.fx.experimental.symbolic_shapes import _iterate_exprs, _iterate_nodes + + is_tracing = torch.fx.experimental.proxy_tensor.get_proxy_mode() is not None + if is_tracing: + # Check for SymNode types in PROXY mode - this should bypass caching + # regardless of whether symbols are known or not + for _ in _iterate_nodes(output): + raise _BypassDispatchCache("Proxy mode with SymNode output") + else: + # Check for unrepresented symbols in tensor expressions + for s in _iterate_exprs(output): + for symbol in s.free_symbols: + if symbol not in state.known_symbols: + raise _BypassDispatchCache("unrepresented symbol in output") + + +# NB: returns fake tensors +def run_fallback_kernel( + fake_mode: FakeTensorMode, + func: OpOverload, + flat_args: Sequence[object], + args_spec: PyTree, + orig_not_implemented_exception: RuntimeError, +) -> FakeTensor: + # these should all be supported, just to be safe + # avoid fallback for operators which inplace modify metadata + # because the input fake tensors would be umodified + if torch.Tag.inplace_view in func.tags: + raise orig_not_implemented_exception + + inp_impls = {} + + # Don't use in_kernel_invocation_manager(fake_mode) as we want to do + # REAL compute (not with meta device) + with no_dispatch(): + + def to_real_tensor(e: T) -> Union[T, Tensor]: + if fake_mode.is_our_fake(e): + out = torch.zeros_like(e, device=e.fake_device) + if e.is_sparse: + out._coalesced_(e.is_coalesced()) + inp_impls[id(out)] = e + return out + return e + + flat_args = [to_real_tensor(a) for a in flat_args] + args, kwargs = pytree.tree_unflatten(flat_args, args_spec) + + r = func(*args, **kwargs) + + storages: set[_StoragePointer] = set() + + for e in flat_args: + if isinstance(e, Tensor): + if not is_sparse_any(e): + storages.add(e._typed_storage()._cdata) + + # TODO: also check metadata change on inputs + # proper aliasing/metadata relationship between outputs and inputs will + # not be set up, bc of conversion to device, unless we can 
reuse an + # input impl + + def map_out(e: T) -> Union[T, FakeTensor]: + if id(e) not in inp_impls and ( + isinstance(e, Tensor) + and not is_sparse_any(e) + and e._typed_storage()._cdata in storages + ): + raise orig_not_implemented_exception + + if isinstance(e, Tensor): + if id(e) in inp_impls: + return inp_impls[id(e)] + else: + return fake_mode.fake_tensor_converter.from_real_tensor(fake_mode, e) + else: + return e + + return pytree.tree_map(map_out, r) + + +def _set_cache_key_for_shape_env( + cache: dict[_DispatchCacheKey, _DispatchCacheEntry], + key: _DispatchCacheKey, + entry: _DispatchCacheEntry, +) -> None: + key.strip_shape_env() + cache[key] = entry + + +def _set_cache_key( + cache: dict[_DispatchCacheKey, _DispatchCacheEntry], + key: _DispatchCacheKey, + entry: _DispatchCacheEntry, +) -> None: + cache[key] = entry + + +# Just for use to allow copying a module to fake tensors, +# does not apply elsewhere +class FakeCopyMode(TorchFunctionMode): + def __init__(self, fake_mode: FakeTensorMode) -> None: + self.fake_mode = fake_mode + + def __torch_function__( + self, + func: OpOverload, + types: Sequence[type], + args: Sequence[object] = (), + kwargs: Optional[Mapping[str, object]] = None, + ) -> FakeTensor: + kwargs = kwargs if kwargs else {} + + # clone will get called in Parameter deepcopy + if func is torch._C.TensorBase.clone: + assert isinstance(args[0], Tensor) + return func( + self.fake_mode.from_tensor(args[0], static_shapes=True), **kwargs + ) + elif func is Tensor.__deepcopy__: + assert len(args) == 2 and len(kwargs) == 0 + tensor = cast(Tensor, args[0]) + memo = cast(dict[int, FakeTensor], args[1]) + + if id(tensor) in memo: + return memo[id(tensor)] + + out = self.fake_mode.from_tensor(tensor, static_shapes=True) + memo[id(tensor)] = out + return out + else: + with torch._C.DisableTorchFunctionSubclass(): + return func(*args, **kwargs) + + +def _device_handler(args: Sequence[object]) -> torch.device: + # NB: Don't use is_our_fake, just serve 
the fake information + # as is. Notice we don't use 'self'; we use args[0].fake_mode + # because they may not be the same. It would also be possible + # to return NotImplemented here, in which case the FakeTensor + # handler on args[0] would handle it, but we're being nice and + # short-circuiting quickly. + assert len(args) == 1 and isinstance(args[0], FakeTensor) + if args[0].fake_mode.in_kernel_invocation: + return torch.device("meta") + else: + return args[0].fake_device + + +# [subclass inputs] +# Suppose we enable fake tensor mode. This means that fake tensor +# mode will run first. But what if we do an operation that +# involves a tensor subclass that will desugar into normal tensor +# operations? Without returning NotImplemented, fake tensor mode will run first, +# decide that a conversion was made (since there was a non fake +# tensor argument), and report an error that converting non +# fake tensor is not supported. What we actually wanted to happen +# was to give the subclass a chance to figure out what it wants to +# before erroring out. Returning NotImplemented here allows this. 
+def _check_for_subclass(flat_args: Sequence[object]) -> bool: + return any(_check_for_subclass_arg(x) for x in flat_args) + + +def _check_for_subclass_arg(x: object) -> bool: + return ( + not isinstance(x, FakeTensor) + and isinstance(x, Tensor) + and type(x) is not Tensor + and type(x) is not torch.nn.Parameter + ) + + +_DISPATCH_META_HANDLERS = { + torch.ops.prim.device.default: _device_handler, + torch.ops.aten.size.default: lambda args: tuple( + int(s) for s in cast(Tensor, args[0]).size() + ), + torch.ops.aten.stride.default: lambda args: tuple( + int(s) for s in cast(Tensor, args[0]).stride() + ), + torch.ops.aten.storage_offset.default: lambda args: int( + cast(Tensor, args[0]).storage_offset() + ), +} + +_DISPATCH_HANDLE_DIRECTLY = ordered_set( + torch.ops.aten.is_coalesced.default, + torch.ops.aten.dense_dim.default, + torch.ops.aten.sparse_dim.default, + # _RecordFunction doesn't support __eq__ so make sure not to attempt to + # cache it. + torch.ops.profiler._record_function_exit._RecordFunction, +) + +from torch._subclasses.fake_impls import ( # noqa: F401 + _device_not_kwarg_ops, + _is_tensor_constructor, + _like_tensor_constructors, + contains_tensor_types, + get_fast_op_impls, + has_meta, + op_implementations_checks, + stride_incorrect_op, +) + + +def evict_fake_tensor_cache_key(key: _DispatchCacheKey) -> None: + if key in FakeTensorMode.cache: + FakeTensorMode.cache.pop(key) + + +@atexit.register +def dump_cache_stats() -> None: + log.info("FakeTensor cache stats:") + log.info(" cache_hits: %s", FakeTensorMode.cache_hits) + log.info(" cache_misses: %s", FakeTensorMode.cache_misses) + bypasses = FakeTensorMode.cache_bypasses + if bypasses: + log.info(" cache_bypasses:") + width = max(len(k) for k in bypasses) + for k, v in sorted(bypasses.items(), key=lambda i: -i[1]): + log.info(" %-*s %s", width + 1, f"{k}:", v) + + +def _infer_fake_from_real_tensor( + mode: FakeTensorMode, op: torch._ops.OpOverload, real_out: torch.Tensor +) -> torch.Tensor: + 
def unsupported(reason: str) -> None: + raise RuntimeError( + f"propagate_real_tensors: we cannot infer a Fake kernel " + f"(meta kernel) for operator {op._name} because {reason}. " + f"Please use torch.library.register_fake to add a Fake kernel." + ) + + if real_out.storage_offset() != 0: + unsupported( + f"a return has a non-zero storage offset {real_out.storage_offset()}" + ) + + # Since PT2 is rank specialized, there's no such thing as a symbolic + # output rank. So we can assume the fake tensor has the same number of + # dimensions as the real tensor output. + # + # We shouldn't assume the Fake sizes/strides are exactly what we see on + # the real tensor output (perhaps we should give users a lever to toggle + # this). This is because there's a good amount of operators that return + # outputs with data-dependent output shape. + # So we infer the output sizes to all be unbacked symints + fake_shape = [ + torch._library.fake_impl.allocate_size(mode.shape_env) + for _ in range(real_out.dim()) + ] + + # We infer what the strides are. We had a couple of options for this: + # - assume the strides are computable from the sizes + # - use new fresh unbacked symints in the strides + # This doesn't work that well (PT2 doesn't support unbacked symint strides well) + # - use the real strides + # This can only be used if we assume the strides are static. + # We went with the first option. 
+ fake_strides = [-1] * real_out.dim() + strides = [(s, idx) for idx, s in enumerate(real_out.stride())] + strides.sort(key=lambda x: (x[0], -x[1])) + expected = 1 + fake_stride = expected + for s, idx in strides: + if s != expected: + unsupported( + f"a return was not dense in memory (sizes {real_out.shape} strides {real_out.stride()})" + ) + fake_strides[idx] = fake_stride + expected = expected * real_out.shape[idx] + fake_stride = fake_stride * fake_shape[idx] + + with mode: + return torch.empty_strided( + fake_shape, + fake_strides, + device=real_out.device, + dtype=real_out.dtype, + layout=real_out.layout, + ) + + +def inferred_fake_kernel_from_real_out( + mode: FakeTensorMode, op: torch._ops.OpOverload, real_out: Any +) -> Any: + assert mode.shape_env is not None + + # Only support operators that have all Tensor outputs + # This is a general limitation on custom ops that we impose for PT2 + # to avoid baking non-symbolic float/int outputs into the graph. + real_flat_out, spec = pytree.tree_flatten(real_out) + if not all(isinstance(t, torch.Tensor) for t in real_flat_out): + raise RuntimeError( + f"propagate_real_tensors: we don't support operators that return " + f"non-Tensors. 
Got {op._schema}" + ) + + fake_flat_out = [_infer_fake_from_real_tensor(mode, op, t) for t in real_flat_out] + return pytree.tree_unflatten(fake_flat_out, spec) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_subclasses/fake_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_subclasses/fake_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1212168b090498284e10f11d7c017a3ef8ba94c3 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_subclasses/fake_utils.py @@ -0,0 +1,305 @@ +# mypy: ignore-errors + +import functools +import warnings +from collections.abc import Callable +from typing import Any, Union + +import torch +import torch.utils._pytree as pytree +from torch._ops import OpOverload +from torch._subclasses.fake_tensor import ( + FakeTensor, + FakeTensorMode, + MetadataMismatchError, + tree_flatten_only, + UnsupportedFakeTensorException, +) +from torch.utils._python_dispatch import TorchDispatchMode + + +aten = torch._ops.ops.aten + + +def outputs_alias_inputs(outputs, inputs): + input_storages = { + inp._typed_storage()._cdata + for inp in tree_flatten_only(torch.Tensor, inputs) + if torch._C._has_storage(inp) + } + return any( + torch._C._has_storage(out) and out._typed_storage()._cdata in input_storages + for out in tree_flatten_only(torch.Tensor, outputs) + ) + + +def outputs_are_inputs(outputs, inputs): + input_ids = {id(inp) for inp in tree_flatten_only(torch.Tensor, inputs)} + return any(id(out) in input_ids for out in tree_flatten_only(torch.Tensor, outputs)) + + +def output_alias_each_other(outputs): + storages = set() + for out in tree_flatten_only(torch.Tensor, outputs): + if not torch._C._has_storage(out): + continue + stor = out._typed_storage()._cdata + if stor in storages: + return True + storages.add(stor) + return False + + +def _check_alias_info(context, real_out, real_in, fake_out, fake_in): + r_aliasing = 
outputs_alias_inputs(real_out, real_in) + f_aliasing = outputs_alias_inputs(fake_out, fake_in) + if r_aliasing != f_aliasing: + raise MetadataMismatchError( + f"{context} mismatch in outputs_alias_inputs check {f_aliasing} != {r_aliasing}" + ) + + r_identity_eq = outputs_are_inputs(real_out, real_in) + f_identity_eq = outputs_are_inputs(fake_out, fake_in) + if r_identity_eq != f_identity_eq: + raise MetadataMismatchError( + f"{context} mismatch in outputs_are_inputs check {f_identity_eq} != {r_identity_eq}" + ) + + r_output_alias_each_other = output_alias_each_other(real_out) + f_output_alias_each_other = output_alias_each_other(fake_out) + if r_output_alias_each_other != f_output_alias_each_other: + raise MetadataMismatchError( + f"{context} mismatch in outputs_alias_each_other check " + f"{f_output_alias_each_other} != {r_output_alias_each_other}" + ) + + +def is_sdpa_error(func, idx, e): + if ( + ( + func is aten._scaled_dot_product_flash_attention.default + or func is aten._flash_attention_forward.default + ) + and idx in (6, 7) + and "Devices" in repr(e) + ): + return True + if ( + ( + func is aten._scaled_dot_product_efficient_attention.default + or func is aten._efficient_attention_forward.default + ) + and idx in (2, 3) + and "Devices" in repr(e) + ): + return True + if ( + func is aten._scaled_dot_product_cudnn_attention.default + and idx in (6, 7) + and "Devices" in repr(e) + ): + return True + return False + + +def try_convert_fake_to_real( + ten_list: list[Union[FakeTensor, Any]], +) -> list[Union[FakeTensor, torch.Tensor, Any]]: + """ + Attempt to convert fake tensors to a corresponding real tensor with the correct underlying storage by looking up + the FakeTensorMode meta to real storage mapping. On failure to find the storage mapping, the FakeTensor will + remain in the list. 
+ + Note: this is not currently optimized (makes copies of the meta converter internal dictionaries) + """ + + fake_tensor = next( + (item for item in ten_list if isinstance(item, FakeTensor)), None + ) + if fake_tensor is None: + return ten_list + + fake_mode = fake_tensor.fake_mode + meta_converter = fake_mode.fake_tensor_converter.meta_converter + desc = meta_converter.describer + + storage_to_key = {v: k for k, v in meta_converter.storage_memo.items()} + key_to_real_storage = {v: k for k, v in desc.lookup_storage.items()} + out = [] + for t in ten_list: + if not isinstance(t, FakeTensor) or t.layout != torch.strided: + out.append(t) + continue + + key = storage_to_key.get(t.untyped_storage()) + real_storage = None if key is None else key_to_real_storage.get(key) + if real_storage is None: + out.append(t) + continue + + unhinted = False + + def map_symint(s): + nonlocal unhinted + if not isinstance(s, torch.SymInt): + return s + unhinted = unhinted if not unhinted else s.node.has_hint() + return s.node.hint + + stor_offset = map_symint(t.storage_offset()) + size = [map_symint(s) for s in t.shape] + stride = [map_symint(s) for s in t.stride()] + + if unhinted: + out.append(t) + continue + + new_tensor = torch.empty( + [], + dtype=t.dtype, + device=t.device, + ) + new_tensor.set_( + real_storage, + storage_offset=stor_offset, + size=size, + stride=stride, + ) + out.append(new_tensor.clone()) + + return out + + +def _check_fake_real_tensors( + real_out: torch.Tensor, + fake_out: FakeTensor, + context="", + sizes=True, + strides=False, + storage_offset=True, + requires_grad=True, +): + if requires_grad: + if real_out.requires_grad != fake_out.requires_grad: + raise MetadataMismatchError( + f"{context} mismatched requires_grad-ness of outputs. 
" + f"This usually means that you have added autograd support " + f"for your operator at a dispatch key other than Autograd, " + f"which will lead to problems" + ) + + if torch._C._has_storage(real_out): + r_offset = real_out.storage_offset() + f_offset = fake_out.storage_offset() + if r_offset != f_offset: + raise MetadataMismatchError(f"{context} mismatched storage offset") + + torch._prims.utils.compare_tensor_meta( + real_out, + fake_out, + check_sizes=sizes, + check_strides=strides, + allow_rhs_unbacked=True, + ) + + +class CrossRefFakeMode(TorchDispatchMode): + def __init__( + self, + ignore_op_fn: Union[Callable[[OpOverload], bool], None] = None, + *, + check_strides=True, + check_aliasing=True, + only_check_ops_with_meta=True, + ): + super().__init__() + self.ignore_op_fn = ( + ignore_op_fn if ignore_op_fn is not None else lambda fn: False + ) + self.check_strides = check_strides + self.check_aliasing = check_aliasing + self.only_check_ops_with_meta = only_check_ops_with_meta + + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + kwargs = kwargs or {} + + fake_r = None + + # empty_like excluded for now due to sparse complex + # aten._to_dense.default this one is getting called with csc + if ( + func + not in ( + aten.lift_fresh.default, + aten.lift_fresh_copy.default, + aten.set_.source_Storage_storage_offset, + ) + and not self.ignore_op_fn(func) + and ( + not self.only_check_ops_with_meta + or torch._subclasses.fake_impls.has_meta(func) + ) + and torch.Tag.dynamic_output_shape not in func.tags + and torch.Tag.inplace_view not in func.tags + and torch.Tag.data_dependent_output not in func.tags + ): + # Do not import symbolic_shapes at the top of the module as it imports sympy and that's slow + from torch.fx.experimental.symbolic_shapes import ShapeEnv + + try: + # TODO: enable_python_dispatcher() here + with FakeTensorMode(shape_env=ShapeEnv()) as fake_mode: + fake_args, fake_kwargs = pytree.tree_map_only( + torch.Tensor, + 
functools.partial(fake_mode.from_tensor, static_shapes=True), + (args, kwargs), + ) + with warnings.catch_warnings(): + fake_r = func(*fake_args, **fake_kwargs) + except UnsupportedFakeTensorException: + pass + + context = ( + f"When comparing the output of {func} on FakeTensor and concrete Tensors, " + f"found" + ) + r = func(*args, **kwargs) + if fake_r is not None: + r_flat = pytree.tree_leaves(r) + f_flat = pytree.tree_leaves(fake_r) + assert len(f_flat) == len(r_flat), ( + f"{context} mismatch in number of returns {len(f_flat)} != {len(r_flat)}" + ) + + if self.check_aliasing: + _check_alias_info( + context, r, (args, kwargs), fake_r, (fake_args, fake_kwargs) + ) + + for idx, (r_out, f_out) in enumerate( + zip(pytree.tree_leaves(r), pytree.tree_leaves(fake_r)) + ): + r_is_ten = isinstance(r_out, torch.Tensor) + assert r_is_ten == isinstance(f_out, torch.Tensor), ( + f"{context} mismatched number of tensor outputs" + ) + if r_is_ten: + try: + _check_fake_real_tensors( + r_out, + f_out, + sizes=True, + strides=self.check_strides, + storage_offset=True, + requires_grad=True, + ) + except Exception as e: + if is_sdpa_error(func, idx, e): + continue + error_message = ( + f"{context} mismatched tensor metadata: {e}" + if len(r_flat) == 1 + else f"{context} mismatched tensor metadata for output[{idx}]: {e}" + ) + raise MetadataMismatchError(error_message) from e + return r diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_subclasses/functional_tensor.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_subclasses/functional_tensor.py new file mode 100644 index 0000000000000000000000000000000000000000..b0aa1977b1093948223b9d56cbd43ba7a29416fa --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_subclasses/functional_tensor.py @@ -0,0 +1,837 @@ +# mypy: allow-untyped-defs +import contextlib +import warnings +import weakref +from abc import ABC, abstractmethod +from collections.abc import 
Callable +from contextlib import AbstractContextManager +from typing import Any, Optional, Union + +import torch +import torch.fx.traceback as fx_traceback +import torch.utils._pytree as pytree +from torch._C import _functionalization_reapply_views_tls as _reapply_views +from torch._ops import _get_dispatch_mode_pre_dispatch, TorchBindOpOverload +from torch._subclasses.meta_utils import is_sparse_any +from torch.utils._python_dispatch import ( + _detect_infra_mode, + _disable_infra_mode, + autograd_would_have_decomposed, + return_and_correct_aliasing, + TorchDispatchMode, +) + + +not_implemented_log = torch._logging.getArtifactLogger(__name__, "not_implemented") + + +# NOTE Some special handling for tensor conversion during export is needed. +# Normally, when tracing through the model with tensor.to(), the maybe-aliasing +# relationship between input and output tensors will be baked into the graph. +# For example, if we got a tensor with device cpu and call tensor.to("cpu"), +# it will become a no-op in the graph. For a whole graph capture, this is not +# sound so we need to do something different. Instead, in export we will try to +# preserve the tensor conversion by forcing a non-semantic-breaking aten::_to_copy +# operator to be traced in the graph, and subsequently banning mutations on all +# such converted tensors. +# In addition to patching .to() method call in functionalization, we will have to +# patch other similar methods like float() and cpu(), because they intentionally +# don't fall back to .to() methods, but have the same behavior as .to() according to +# pytorch document. https://pytorch.org/docs/stable/generated/torch.Tensor.float.html +# thus we simply force them to go through .to() call. 
+def _conversion_method_template(**extra_kwargs): + def _(self, *args, **kwargs): + return self.to(*args, **{**kwargs, **extra_kwargs}) + + return _ + + +class FunctionalTensor(torch.Tensor): + """ + Functional tensors represent tensors that will remove mutations + from a program. If you perform a mutable operation on a functional tensor, + it will re-dispatch to the functional variant of that operation. + + Historically, functionalization is implemented in C++ in the dispatcher. + This class is a lightweight python shim around the C++ functionalization logic. + + FunctionalTensor is required to be used with a corresponding + FunctionalTensormode active, because it relies + on using the mode for dispatch (which can properly handle factory functions). + """ + + elem: torch.Tensor + # Indicates to our torch_dispatch dispatching infra that + # this is an "infra" mode with lower dispatching precedence. + _mode_key = torch._C._TorchDispatchModeKey.FUNCTIONAL + + # Note: The reason we add these extra keys to our FunctionalTensor subclass + # is to mirror the behavior of C++ functionalization (we can choose to change this + # later, as long as it doesn't break anything). + # FunctionalTensorWrapper copies **all** dispatch keys from the inner tensor + # to the wrapper, excluding functorch and python dispatch keys. + # Here I'm trying to reuse the keyset the functorch wrapper subclasses copy, + # except that they don't include ZeroTensor so I'm manually adding it in. + _extra_dispatch_keys = torch._C._additional_keys_to_prop_for_wrapper_tensors.add( + torch._C.DispatchKey.ZeroTensor + ) + + # These are all aten ops that correspond to metadata queries. + # We want FunctionalTensor to be able to handle them directly. 
+ metadata_fns = [ + torch.ops.aten.is_contiguous.default, # type: ignore[has-type] + torch.ops.aten.is_contiguous.memory_format, # type: ignore[has-type] + torch.ops.aten.is_strides_like_format.default, # type: ignore[has-type] + torch.ops.aten.is_non_overlapping_and_dense.default, # type: ignore[has-type] + torch.ops.aten.size.default, # type: ignore[has-type] + torch.ops.aten.sym_size.default, # type: ignore[has-type] + torch.ops.aten.stride.default, # type: ignore[has-type] + torch.ops.aten.sym_stride.default, # type: ignore[has-type] + torch.ops.aten.storage_offset.default, # type: ignore[has-type] + torch.ops.aten.sym_storage_offset.default, # type: ignore[has-type] + torch.ops.aten.numel.default, # type: ignore[has-type] + torch.ops.aten.sym_numel.default, # type: ignore[has-type] + torch.ops.aten.dim.default, # type: ignore[has-type] + torch.ops.prim.device.default, # type: ignore[has-type] + ] + + # Used by auto_functionalize to determine base of tensors during inference mode. + _inference_mode_base: Optional["FunctionalTensor"] = None + + def __new__(cls, elem, mode): + assert torch._is_functional_tensor(elem) + + # In general, we'd like our functional tensor subclass to only be in charge of functionalization, + # and defer to the inner subclass for all other functionality. + # Example: If our inner tensor is a ZeroTensor, we would want to defer running the ZeroTensor fallback + # until after we redispatch to our inner ZeroTensor. + # However, there are a few keys that we need to mirror between the inner and outer tensors. + # Conjugate + # Negative + # Why? These keys are used to test metadata queries, like `.is_conj()` and `.is_neg()`. 
+ # We **need** calls to is_conj() to return the same thing on the outer and inner tensors, + # Because user code / framework code that branches like so needs to do the same thing + # when it sees the outer FunctionalTensor: + # if (x.is_conj()) { + # return at::view_as_real(x.resolve_conj()); + # } else { + # return at::view_as_real(x); + # } + extra_dispatch_keys = ( + FunctionalTensor._extra_dispatch_keys & torch._C._dispatch_keys(elem) + ) + + out = torch.Tensor._make_wrapper_subclass( + # TODO: right now, _make_wrapper_subclass's dynamic shape interaction is not great. + # Calling the overload that has kwargs causes us to go down the first overload path, + # which will **always** specialize sizes. + # We should probably eventually fix this so that the first overload can just handle dynamic shapes. + cls, + elem.shape, # sizes + elem.stride() if not is_sparse_any(elem) else None, # strides + ( + elem.storage_offset() if not is_sparse_any(elem) else None + ), # storage_offset + None, # memory_format + elem.dtype, # dtype + elem.layout, # layout + elem.device, # device + False, # pin_memory + elem.requires_grad, # requires_grad + None, # dispatch_sizes_strides_policy + False, # dispatch_device + False, # dispatch_layout + extra_dispatch_keys, # _extra_dispatch_keys + ) + torch._C._set_throw_on_mutable_data_ptr(out) + out.elem = elem + + if ( + torch._export.config.enable_auto_functionalized_v2_for_export + and torch.is_inference_mode_enabled() + and torch._inductor.config.enable_auto_functionalized_v2 + ): + if out.is_base_tensor(): + out._inference_mode_base = None + # This assumes that the FunctionalTensor.elem does not change its storage after this point. + # Otherwise this would be invalid. 
+ mode._storage_to_base[out.elem.untyped_storage()] = out + else: + out._inference_mode_base = mode._storage_to_base[ + out.elem.untyped_storage() + ] + assert out._inference_mode_base is not None + return out + + def __torch_dispatch__(self, func, types, args=(), kwargs=None): # type: ignore[override] + unrecognized_types = [ + t + for t in types + if t not in [torch.Tensor, torch._subclasses.FakeTensor, FunctionalTensor] + ] + if unrecognized_types: + not_implemented_log.debug( + "FunctionalTensor unrecognized subclass(es): %s", unrecognized_types + ) + return NotImplemented + + if kwargs is None: + kwargs = {} + + # FunctionalTensor needs to plumb all metadata requests to the inner tensor. + # In theory we don't have to do this - but if we want to service metadata requests here, + # we need to carefully make sure all metadata is accurate (including metadata mutations) + if func in FunctionalTensor.metadata_fns: + # All metadata accesses should be plumbed to the inner tensor, that way we don't have to worry + # about the problem of keeping metadata in sync between the wrapper and inner tensor. + # This also alleviates us from having to manually handle metadata mutations on the wrapper. + assert len(kwargs) == 0 + if func in [ + torch.ops.aten.is_strides_like_format.default, + torch.ops.aten.is_contiguous.memory_format, + ]: + assert len(args) == 2 and isinstance(args[0], FunctionalTensor) + return func(torch._from_functional_tensor(args[0].elem), args[1]) + assert len(args) == 1 and isinstance(args[0], FunctionalTensor) + + return func(torch._from_functional_tensor(args[0].elem)) + # Originally I tried to implement my subclass without giving it a torch_dispatch, but I gave up: + # - _make_wrapper_subclass requires a __torch_dispatch__ + # - If we want to use _make_subclass(), we have a problem: the subclass will share a TensorImpl with the inner tensor, + # which is of type FunctionalTensorWrapper! 
We explicitly do not want our wrapper to be a FunctionalTensorWrapper. + # - If we use the default tensor.__new__(), we have another problem: it returns inner_tensor.alias(), + # which causes every subclass created above autograd to have autograd view metadata + # (in addition to also being a FunctionalTensorWrapper). + raise RuntimeError( + "Attempting to use FunctionalTensor on its own. Instead, please use it with a corresponding FunctionalTensorMode()" + ) + + def __repr__(self) -> str: # type: ignore[override] + return f"FunctionalTensor({repr(self.elem)})" + + @staticmethod + def to_functional(x): + # We will do the wrapping for the user. + + assert not torch._is_functional_tensor(x) + # The only autograd metadata we care about on the FunctionalTensor is: + # - requires_grad (so autograd runs) + # - is_leaf (so that mutations on graph inputs that are not leaves are allowed by the autograd engine) + # this is handled by FunctionalTensor.to_functional + x_functional = torch._to_functional_tensor(x) + # Technically the FunctionalTensormode here is unnecessary, + # but it avoids spurious NotImplemented logs during `ProxyTorchDispatchMode` tracing. 
+ # _mirror_autograd_meta_to queries tensor sizes, + # and otherwise the sym_size() call will go to the proxy mode before hitting + # FunctionalTensor.__torch_dispatch__ + + functional_mode = _detect_infra_mode(torch._C._TorchDispatchModeKey.FUNCTIONAL) + assert functional_mode is not None + + with functional_mode: + torch._mirror_autograd_meta_to(x, x_functional) # type: ignore[attr-defined] + out = FunctionalTensor(x_functional, functional_mode) + torch._mirror_autograd_meta_to(x_functional, out) # type: ignore[attr-defined] + return out + + def from_functional(self): + torch._sync(self) + return torch._from_functional_tensor(self.elem) + + def is_base_tensor(self) -> bool: + return torch._is_functional_tensor_base(self.elem) + + def replace_(self, output) -> None: + torch._functionalize_replace(self.elem, output) + + def commit_update(self) -> None: + torch._functionalize_commit_update(self.elem) + + def sync(self) -> None: + torch._functionalize_sync(self.elem) + + def mark_mutation_hidden_from_autograd(self) -> None: + torch._functionalize_mark_mutation_hidden_from_autograd(self.elem) + + def tolist(self) -> Any: + if self.elem.dim() == 0: + return self.elem.item() + elif self.elem.dim() == 1: + return [elem.item() for elem in self.elem] + else: + return [elem.tolist() for elem in self.elem] + + def to(self, *args, **kwargs): + if _detect_infra_mode(torch._C._TorchDispatchModeKey.FUNCTIONAL).export: + torch.ops.aten._assert_tensor_metadata( + self, + dtype=self.dtype, + device=self.device, + layout=self.layout, + ) + # pyrefly: ignore [not-iterable] + return super().to(*args, **kwargs) + + def cuda(self, device=None, *args, **kwargs): + device = device or torch.cuda.current_device() + if len(args) > 0: + return self.to(device, *args, **kwargs) + else: + return self.to(device=device, **kwargs) + + char = _conversion_method_template(dtype=torch.int8) + cpu = _conversion_method_template(device=torch.device("cpu")) + bfloat16 = 
_conversion_method_template(dtype=torch.bfloat16) + byte = _conversion_method_template(dtype=torch.uint8) + double = _conversion_method_template(dtype=torch.float64) + float = _conversion_method_template(dtype=torch.float32) + bool = _conversion_method_template(dtype=torch.bool) + half = _conversion_method_template(dtype=torch.float16) + int = _conversion_method_template(dtype=torch.int32) + long = _conversion_method_template(dtype=torch.int64) + + # TODO(sparse-team): fixes #133174 but can we do without the relay? + def to_dense(self): # type: ignore[override] + return self.elem.to_dense() + + @property + def layout(self): # type: ignore[override] + return self.elem.layout + + def __bool__(self): + return bool(self.item()) + + +class FunctionalTensorMode(TorchDispatchMode): + def __init__(self, pre_dispatch=False, export=False, _allow_token_discovery=False): + super().__init__() + self.export = export + self.is_on_stack = False + self.enter_stack = [] + # Indicates to our torch_dispatch dispatching infra that + # this is an "infra" mode with lower dispatching precedence. + self._mode_key = torch._C._TorchDispatchModeKey.FUNCTIONAL + self.pre_dispatch = pre_dispatch + # This will be turned off later for pre-dispatch functionalization + self._dispatch_key = torch._C.DispatchKey.PreDispatch if pre_dispatch else None # type: ignore[attr-defined] + # Map of effect type (ex. _EffectType.ORDERED) to a token. The tokens help keep + # track of the ordering between side effectful operations. + self._tokens: dict[Any, torch.Tensor] = {} + + # Filled after forward tracing. 
+ self._tokens_forward_output: dict[Any, torch.Tensor] = {} + + # Functionalization runs twice in AOTAutograd, once in + # `run_functionalized_fw_and_collect_metadata` to collect metadata to + # see which tensors need to be functionalized and discover how many + # tokens we need, and another time in `make_fx` which does the actual + # tracing to replace ops with their functional variants and handling + # side-effectful ops. In the second stage there should be no token + # discovery. This flag distinguishes between the two stages. + self._allow_token_discovery = _allow_token_discovery + + self._storage_to_base: weakref.WeakKeyDictionary[ + torch.storage.UntypedStorage, Optional[FunctionalTensor] + ] = weakref.WeakKeyDictionary() + + # No-op if FunctionalTensorMode is already in use + def __enter__(self): + def _get_prev_mode(): + if self._dispatch_key == torch._C.DispatchKey.PreDispatch: + return _get_dispatch_mode_pre_dispatch( + torch._C._TorchDispatchModeKey.FUNCTIONAL + ) + return torch._C._get_dispatch_mode( + torch._C._TorchDispatchModeKey.FUNCTIONAL + ) + + if _get_prev_mode() is None: + self.enter_stack.append(True) + return super().__enter__() + else: + self.enter_stack.append(False) + return self + + def __exit__(self, a, b, c): + is_on_stack = self.enter_stack.pop() + if is_on_stack: + super().__exit__(a, b, c) + + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + + unrecognized_types = [ + t + for t in types + if not issubclass(t, torch._subclasses.FakeTensor) + and t not in [torch.Tensor, FunctionalTensor] + ] + + if unrecognized_types: + not_implemented_log.debug( + "FunctionalTensor unrecognized subclass(es): %s", unrecognized_types + ) + return NotImplemented + + def _can_decompose(func): + # See https://github.com/pytorch/pytorch/pull/115258#issuecomment-1900755832 + # Never decompose dropout in export + if self.export and func is torch.ops.aten.dropout.default: + return False + + # We 
unconditionally decompose ops that are maybe aliasing or mutating ops + from torch._decomp import _should_decompose_because_unsafe_op + + if _should_decompose_because_unsafe_op(func): + return True + + # (1) we unconditionally decompose maybe-aliasing or maybe-mutating ops, + # because we must know statically of an op mutates or aliasing in order to functionalize it properly + # (2) for mutating ops that have CompositeImplicit decomps, we choose to decompose them today. + # In theory, we could walk this back and avoid decomposing them later if we need to. + alias_info_present = any(arg.alias_info for arg in func._schema.arguments) + if alias_info_present or func._schema.is_mutable: + return True + + # If we are here, it means we are seeing functional composite op. + # For pre-dispatch IR, we don't want to decompose this op + # For post-dispatch IR, we do want to decompose this op. it is fine + # to decompose here even if you want to preserve a CIA in post-dispatch export + # because we already override decompose behaviour so it will do the + # right thing. + if self.export: + if self.pre_dispatch: + # If it is CIA custom op, we warn that we are assuming this op is indeed functional. + if func.namespace not in ["aten", "prim"] and func._can_decompose(): + warnings.warn( + f"At pre-dispatch tracing, we assume that any custom op marked with " + f"CompositeImplicitAutograd and have functional schema are safe to not decompose. 
" + f"Found {func} to be one such op.", + stacklevel=2, + ) + return False + return True + + # in normal torch.compile IR, we only decompose an op if autograd + # would have decomposed it (NB: autograd may have been skipped if + # we are in inference mode) + # TODO: the flatten here can potentially be deduped with the + # unwrapping pytree_map later + flat_args_kwargs, _ = pytree.tree_flatten((args, kwargs)) + return autograd_would_have_decomposed(func, flat_args_kwargs) + + if ( + func not in FunctionalTensor.metadata_fns + and _can_decompose(func) + # Not all funcs from __torch_dispatch__ are actual dispatcher ops, + # e.g. prim.device + and torch._C._dispatch_has_kernel(func.name()) + ): + with self: + r = func.decompose(*args, **kwargs) + if r is not NotImplemented: + return r + + def wrap(x): + # Only wrap our outputs in subclasses if the inner functionalization call + # also wrapped outputs into FunctionalTensorWrappers. + # When can this happen? e.g. `torch.div(2, 2)` + assert not isinstance(x, FunctionalTensor) + if isinstance(x, torch.Tensor) and torch._is_functional_tensor(x): + return FunctionalTensor(x, self) + return x + + def unwrap(x): + return x.elem + + from torch._higher_order_ops.auto_functionalize import ( + can_auto_functionalize, + do_auto_functionalize, + do_auto_functionalize_v2, + ) + + if can_auto_functionalize( + func + ) and not torch._C._dispatch_has_kernel_for_dispatch_key( + func.name(), torch._C.DispatchKey.Functionalize + ): + import torch._export.config as export_config + import torch._inductor.config as inductor_config + + if torch.compiler.is_exporting(): + if export_config.enable_auto_functionalized_v2_for_export: + return do_auto_functionalize_v2(self, func, args, kwargs) + + return do_auto_functionalize(self, func, args, kwargs) + + if inductor_config.enable_auto_functionalized_v2: + return do_auto_functionalize_v2(self, func, args, kwargs) + return do_auto_functionalize(self, func, args, kwargs) + + from 
torch._higher_order_ops.effects import handle_effects, has_effects + + if has_effects(func): + assert not torch._C._dispatch_has_kernel_for_dispatch_key( + func.name(), torch._C.DispatchKey.Functionalize + ) + return handle_effects( + self._allow_token_discovery, self._tokens, func, args, kwargs + ) + + args_unwrapped, kwargs_unwrapped = pytree.tree_map_only( + FunctionalTensor, unwrap, (args, kwargs) + ) + + # Expectation: functionalization should not **already** be enabled above our mode. + # Why would that be bad? when we return a FunctionalTensor here, we don't want functionalization + # to run above this mode and further wrap that output in **another** C++ FunctionalTensorWrapper. + is_included = torch._C._dispatch_tls_is_dispatch_key_included( + torch._C.DispatchKey.Functionalize + ) + is_excluded = torch._C._dispatch_tls_is_dispatch_key_excluded( + torch._C.DispatchKey.Functionalize + ) + assert is_excluded or not is_included + include_to_set = ( + torch._C._dispatch_tls_local_include_set() + | torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize) + ) + exclude_to_set = ( + torch._C._dispatch_tls_local_exclude_set().remove( + torch._C.DispatchKey.Functionalize + ) + - FunctionalTensor._extra_dispatch_keys + ) + + if isinstance(func, TorchBindOpOverload): + # When the function is a TorchBindOpOverload, meaning some of the + # inputs are FakeScriptObjects, we need to skip c++ dispatcher and + # dispatch in python because C++ dispatcher will check the schema + # and cannot recognize FakeScriptObject. + ctx = PythonFunctionalizeAPI() + fully_unwrapped_args = ctx.unwrap_tensors(args) + fully_unwrapped_kwargs = ctx.unwrap_tensors( + kwargs # pyrefly: ignore[bad-argument-type] + ) + outs_unwrapped = func( + *fully_unwrapped_args, + **fully_unwrapped_kwargs, + ) + outs_wrapped = ctx.wrap_tensors(outs_unwrapped) + else: + # All we want to do here is reuse the existing C++ functionalization logic. 
+ # This requires swizzling our TLS dispatch keys so that the Functionalize key is active. + with torch._C._ForceDispatchKeyGuard(include_to_set, exclude_to_set): + try: + # By default for python functionalization (for AOTAutograd), we reapply views. + old_apply_views = torch._functionalize_enable_reapply_views(True) # type: ignore[attr-defined] + + # Sometimes these functions cannot be directly dispatched to functionalize key + # because args are sometimes not functional tensors for some reason? + if func in FunctionalTensor.metadata_fns: + outs_unwrapped = func(*args_unwrapped, **kwargs_unwrapped) + outs_wrapped = pytree.tree_map_only( + torch.Tensor, wrap, outs_unwrapped + ) + else: + # Note: [Functionalization View Replay Annotation] + # When functionalization encounters a mutation, it handles aliases by lazily regenerating the aliases + # at the first time they are next used. + # This is a problem when plumbing user annotations during tracing. We want the view ops from view replay + # to have the same annotation that the user specified on the original views. But view replay in + # functionalization happens the next time the alias is used (e.g. second_op(alias_with_pending_mutation)), + # so when we regenerate views before calling into second_op, those views will end up getting the metadata + # for second_op! + # + # Instead, we need to remember the node metadata from the original views, and ensure that this node metadata + # is globally set when we lazily perform view replay. + # The globally set metadata will be used to populate the fx node created for the replayed operation. 
+ if m := torch._C._get_dispatch_mode( + torch._C._TorchDispatchModeKey.PROXY + ): + for a in pytree.tree_leaves([args, kwargs]): + if not isinstance(a, FunctionalTensor): + continue + curr_node = m.tracer.tensor_tracker[ + torch._from_functional_tensor(a.elem) + ].proxy.node + with fx_traceback.set_current_replay_node(curr_node): + torch._sync(a) + + # When we dispatch to the C++ functionalization kernel, we might need to jump back to the + # PreDispatch mode stack afterwards, to handle any other PreDispatch modes underneath + # FunctionalTensorMode. If we call func() directly, we would need to exclude PreDispatch + # from the TLS in order to avoid infinite looping, but this would prevent us from coming + # back to PreDispatch later + outs_unwrapped = func._op_dk( + torch._C.DispatchKey.Functionalize, + *args_unwrapped, + **kwargs_unwrapped, + ) + + if self.export: + if func is torch.ops.aten.dropout.default: + torch._freeze_functional_tensor(outs_unwrapped) # type: ignore[attr-defined] + outs_wrapped = pytree.tree_map_only( + torch.Tensor, wrap, outs_unwrapped + ) + finally: + torch._disable_functionalization() + torch._functionalize_enable_reapply_views(old_apply_views) # type: ignore[attr-defined] + + is_included = torch._C._dispatch_tls_is_dispatch_key_included( + torch._C.DispatchKey.Functionalize + ) + is_excluded = torch._C._dispatch_tls_is_dispatch_key_excluded( + torch._C.DispatchKey.Functionalize + ) + assert is_excluded or not is_included + + if ( + # If no outputs are our functional subclass, then don't try to fix up aliasing + not any( + isinstance(x, FunctionalTensor) + for x in pytree.tree_leaves(outs_wrapped) + ) + # Since lift_fresh lifts its argument into a functional tensor, we can skip the + # aliasing correction step. Otherwise, we would be setting the storage of a + # lifted tensor to that of an unlifted tensor. 
+ # Ref: https://github.com/pytorch/pytorch/issues/111506 + or func is torch.ops.aten.lift_fresh.default + ): + return outs_wrapped + # for metadata mutations, need to manually mutate the metadata of the FunctionalTensor wrapper + if ( + torch.Tag.inplace_view in func.tags + and func is not torch.ops.aten.set_.source_Tensor + ): + with torch.utils._mode_utils.no_dispatch(): + func(*args, **kwargs) + # Wrapper tensor subclasses do not have correct aliasing info! Use this util to manually correct the output aliasing. + # inplace ops like `aten.add_()` are expected to return inputs **directly**, instead of creating fresh tensor objects. + # Use this util to figure out the right thing to return. + # If none of our inputs were wrapped, then we have no FunctionalTensor outputs that we need to fix up storages for. + return return_and_correct_aliasing(func, args, kwargs, outs_wrapped) + + @classmethod + def is_infra_mode(cls) -> bool: + return True + + +@contextlib.contextmanager +def disable_functional_mode(): + return _disable_infra_mode(torch._C._TorchDispatchModeKey.FUNCTIONAL) + + +# This is similar to torch.func.functionalize, but: +# - It uses FunctionalTensorMode, and FunctionalTensor (a python subclass). +# One important advantage to using this mode is that it will let us +# run functionalization underneath __torch_dispatch__, +# which we need in AOTAutograd. +# - Doing so means that it does not automatically compose with other +# functorch transforms, since these transforms always run above __torch_dispatch__. +# That's why this util lives here, and not in functorch. 
+def dispatch_functionalize(func, mode: FunctionalTensorMode = FunctionalTensorMode()): + # TODO: pull these from aot autograd + def to_fun(t): + if isinstance(t, torch.Tensor): + return FunctionalTensor.to_functional(t) + return t + + def from_fun(t): + if not isinstance(t, FunctionalTensor): + # quick sanity assert + if isinstance(t, torch.Tensor): + assert not torch._is_functional_tensor(t) + return t + torch._sync(t) + return torch._from_functional_tensor(t.elem) + + def inner(*args, **kwargs): + disable_above = torch._C._ExcludeDispatchKeyGuard( + torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize) + ) + with disable_above, mode: + func_args = pytree.tree_map_only(torch.Tensor, to_fun, args) + func_kwargs = pytree.tree_map_only(torch.Tensor, to_fun, kwargs) + func_outputs = func(*func_args, **func_kwargs) + outputs = pytree.tree_map_only(FunctionalTensor, from_fun, func_outputs) + + return outputs + + return inner + + +class BaseFunctionalizeAPI(ABC): + @abstractmethod + def wrap_tensors(self, args: tuple[Any]) -> tuple[Any]: + pass + + @abstractmethod + def unwrap_tensors( + self, args: Union[torch.Tensor, tuple[torch.Tensor, ...]] + ) -> Any: + pass + + @abstractmethod + def functionalize(self, inner_f: Callable) -> Callable: + pass + + @abstractmethod + def redispatch_to_next(self) -> AbstractContextManager: + pass + + @abstractmethod + def replace(self, input_tensor, output_tensor) -> None: + pass + + @abstractmethod + def commit_update(self, tensor) -> None: + pass + + @abstractmethod + def sync(self, tensor) -> None: + pass + + @abstractmethod + def mark_mutation_hidden_from_autograd(self, tensor) -> None: + pass + + +class PythonFunctionalizeAPI(BaseFunctionalizeAPI): + def __init__( + self, mode: Optional[FunctionalTensorMode] = None, pre_dispatch: bool = False + ) -> None: + super().__init__() + self.mode = mode if mode else FunctionalTensorMode() + self.pre_dispatch = pre_dispatch + + def wrap_tensors(self, args: tuple[Any]) -> tuple[Any]: + 
with self.mode: + return torch.utils._pytree.tree_map_only( + torch.Tensor, FunctionalTensor.to_functional, args + ) + + def unwrap_tensors( + self, args: Union[torch.Tensor, tuple[torch.Tensor, ...], list[torch.Tensor]] + ) -> Any: + return torch.utils._pytree.tree_map_only( + FunctionalTensor, FunctionalTensor.from_functional, args + ) + + def functionalize(self, inner_f: Callable) -> Callable: + return dispatch_functionalize(inner_f, self.mode) + + def redispatch_to_next(self) -> AbstractContextManager: + # [NOTE] We don't do anything here because at the time + # we exercise this path, we would have already popped the + # FunctionalTensorMode from mode stack. Since FunctionalTensorMode + # is now stateful, it is better to explicitly pass in correct mode + # directly instead of globally setting it. + return contextlib.nullcontext() + + def replace(self, input_tensor, output_tensor) -> None: + assert isinstance(input_tensor, FunctionalTensor) + assert not isinstance(output_tensor, FunctionalTensor) + input_tensor.replace_(output_tensor) + + def commit_update(self, tensor) -> None: + assert isinstance(tensor, FunctionalTensor) + tensor.commit_update() + + def sync(self, tensor) -> None: + assert isinstance(tensor, FunctionalTensor) + tensor.sync() + + def mark_mutation_hidden_from_autograd(self, tensor) -> None: + assert isinstance(tensor, FunctionalTensor) + tensor.mark_mutation_hidden_from_autograd() + + +class CppFunctionalizeAPI(BaseFunctionalizeAPI): + def wrap_tensors(self, args: tuple[Any]) -> tuple[Any]: + from torch._functorch.eager_transforms import _wrap_all_tensors_to_functional + + return _wrap_all_tensors_to_functional(args, level=0) + + def unwrap_tensors( + self, args: Union[torch.Tensor, tuple[torch.Tensor, ...]] + ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]: + from torch._functorch.eager_transforms import ( + _unwrap_all_tensors_from_functional, + ) + + return _unwrap_all_tensors_from_functional(args, reapply_views=_reapply_views()) + + def 
functionalize(self, inner_f: Callable) -> Callable: + return torch.func.functionalize(inner_f) + + def redispatch_to_next(self) -> AbstractContextManager: + return torch._C._ExcludeDispatchKeyGuard( + torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize) + ) + + def replace(self, input_tensor, output_tensor) -> None: + torch._functionalize_replace(input_tensor, output_tensor) + + def commit_update(self, tensor) -> None: + torch._functionalize_commit_update(tensor) + + def sync(self, tensor) -> None: + torch._functionalize_sync(tensor) + + def mark_mutation_hidden_from_autograd(self, tensor) -> None: + torch._functionalize_mark_mutation_hidden_from_autograd(tensor) + + +class FunctorchFunctionalizeAPI(BaseFunctionalizeAPI): + def __init__(self, interpreter): + self.interpreter = interpreter + + def wrap_tensors(self, args: tuple[Any]) -> tuple[Any]: + from torch._functorch.eager_transforms import _wrap_all_tensors_to_functional + + return _wrap_all_tensors_to_functional(args, level=self.interpreter.level()) + + def unwrap_tensors( + self, args: Union[torch.Tensor, tuple[torch.Tensor, ...]] + ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]: + from torch._functorch.eager_transforms import ( + _unwrap_all_tensors_from_functional, + ) + + return _unwrap_all_tensors_from_functional( + args, reapply_views=self.interpreter.functionalize_add_back_views() + ) + + def functionalize(self, inner_f: Callable) -> Callable: + return torch.func.functionalize( + inner_f, + remove=( + "mutations_and_views" + if self.interpreter.functionalize_add_back_views() + else "mutations" + ), + ) + + def redispatch_to_next(self) -> AbstractContextManager: + return self.interpreter.lower() + + def replace(self, input_tensor, output_tensor) -> None: + torch._functionalize_replace(input_tensor, output_tensor) + + def commit_update(self, tensor) -> None: + torch._functionalize_commit_update(tensor) + + def sync(self, tensor) -> None: + torch._functionalize_sync(tensor) + + def 
mark_mutation_hidden_from_autograd(self, tensor) -> None: + torch._functionalize_mark_mutation_hidden_from_autograd(tensor) + + +def mb_unwrap_functional_tensor(tensor: torch.Tensor): + if isinstance(tensor, FunctionalTensor): + return torch._from_functional_tensor(tensor.elem) + return tensor diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_subclasses/meta_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_subclasses/meta_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1db028fdbe2eef67f79cf3a547c4930a23647b56 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_subclasses/meta_utils.py @@ -0,0 +1,1972 @@ +from __future__ import annotations + +import contextlib +import dataclasses +import functools +import threading +import typing +import weakref +from abc import abstractmethod +from contextlib import AbstractContextManager, contextmanager +from dataclasses import dataclass +from typing import ( + Any, + ClassVar, + Generic, + NewType, + Optional, + Protocol, + TYPE_CHECKING, + TypeGuard, + TypeVar, + Union, +) +from typing_extensions import override, TypedDict, TypeIs, Unpack + +import torch +from torch._C._autograd import CreationMeta +from torch._C._functorch import ( + _add_batch_dim, + _unwrap_functional_tensor, + _wrap_functional_tensor, + get_unwrapped, + is_batchedtensor, + is_functorch_wrapped_tensor, + is_gradtrackingtensor, + is_legacy_batchedtensor, + maybe_get_bdim, + maybe_get_level, + peek_interpreter_stack, +) +from torch._dispatch.python import enable_python_dispatcher +from torch._logging import trace_structured +from torch.utils._mode_utils import no_dispatch +from torch.utils._python_dispatch import is_traceable_wrapper_subclass +from torch.utils.weak import WeakIdKeyDictionary + + +if TYPE_CHECKING: + from collections.abc import Callable, Generator + + from torch._C._functorch import CInterpreter + from torch._guards import 
Source + from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode + + # Import here to avoid cycle + # Import the following modules during type checking to enable code intelligence features, + # Do not import unconditionally, as they import sympy and importing sympy is very slow + from torch.fx.experimental.symbolic_shapes import ShapeEnv, SymbolicContext + + +def _is_fake_tensor(t: object) -> TypeIs[FakeTensor]: + from torch._subclasses.fake_tensor import FakeTensor + + return isinstance(t, FakeTensor) + + +DimList = list +_TensorLikeT = TypeVar("_TensorLikeT", "MetaTensorDesc", torch.Tensor) +_T = TypeVar("_T") +_TensorT = TypeVar("_TensorT", bound=torch.Tensor) +_TensorT_cov = TypeVar("_TensorT_cov", bound=torch.Tensor, covariant=True) + + +def safe_is_leaf(t: Union[MetaTensorDesc, torch.Tensor]) -> bool: + try: + return t.is_leaf + except RuntimeError: + # inference mode can trigger this + return False + + +def safe_grad(t: _TensorLikeT) -> Optional[_TensorLikeT]: + with torch._logging.hide_warnings(torch._logging._internal.safe_grad_filter): + # pyrefly: ignore [bad-return] + return t.grad + + +def _expect_safe_grad(t: _TensorLikeT) -> _TensorLikeT: + grad = safe_grad(t) + assert grad is not None + return grad + + +def assert_eq(a: _T, b: _T) -> None: + assert a == b, f"{a} != {b}" + + +tls = threading.local() +# Turns off inference mode for fake tensor propagation. This is turned to True +# only for `torch.compile`. 
Also look at +# _dynamo.config.fake_tensor_disable_inference_mode +tls.disable_inference_mode = False + + +@contextmanager +def disable_inference_mode_for_fake_prop() -> Generator[None, None, None]: + prior = getattr(tls, "disable_inference_mode", False) + tls.disable_inference_mode = True + try: + yield + finally: + tls.disable_inference_mode = prior + + +def assert_metadata_eq( + assert_eq: Callable[[object, object], None], + m1: Union[MetaTensorDesc, torch.Tensor], + m2: torch.Tensor, + *, + skip_symbolic: bool = False, + skip_leaf: bool = False, +) -> None: + m1 = ( + MetaTensorDescriber().describe_tensor(m1) + if isinstance(m1, torch.Tensor) + else m1 + ) + + def go(m1: MetaTensorDesc, m2: torch.Tensor) -> None: + assert_eq(m1.dtype, m2.dtype) + if not skip_symbolic: + assert_eq(m1.shape, m2.shape) + assert_eq(m1.requires_grad, m2.requires_grad) + if not skip_leaf: + assert_eq(m1.is_leaf, m2.is_leaf) + # MetaTensorDesc doesn't store grad_fn; inferred from leaf + # assert_eq(m1.grad_fn is None, m2.grad_fn is None) + assert_eq(m1.is_sparse, m2.is_sparse) + if not getattr(tls, "disable_inference_mode", False): + assert_eq(m1.is_inference, m2.is_inference()) + else: + assert_eq(m1.is_inference, False) + assert_eq(m1.is_conj, m2.is_conj()) + assert_eq(m1.is_neg, m2.is_neg()) + assert_eq(m1.grad is not None, safe_grad(m2) is not None) + if m1.grad is not None: + go(m1.grad, _expect_safe_grad(m2)) + # TODO: move "assert_eq(m1.layout, m2.layout)" out of sparse + # branches (but not ready for prime time yet)... 
+ if m1.is_sparse: + assert_eq(m1.layout, m2.layout) + assert_eq(m1.dense_dim, m2.dense_dim()) + assert_eq(m1.sparse_dim, m2.sparse_dim()) + assert_eq(m1.is_coalesced, m2.is_coalesced()) + elif is_sparse_compressed(m1): + assert_eq(m1.layout, m2.layout) + assert_eq(m1.dense_dim, m2.dense_dim()) + assert_eq(m1.sparse_dim, m2.sparse_dim()) + else: + if not skip_symbolic: + assert_eq(m1.stride, m2.stride()) + assert_eq(m1.storage_offset, m2.storage_offset()) + assert_eq(m1.is_view, m2._is_view()) + if m1.is_view: + assert m1.base is not None + assert m2._base is not None + go(m1.base, m2._base) + # TODO: test if is resizable (no direct query for this atm) + # TODO: audit AutogradMeta to see if it matches + # TODO: test forward AD + + return go(m1, m2) + + +# TypeGuard (not TypeIs): False does not imply !torch.Tensor +def is_sparse_coo(t: object) -> TypeGuard[torch.Tensor]: + return isinstance(t, torch.Tensor) and t.layout is torch.sparse_coo + + +def is_sparse_compressed_layout(layout: torch.layout) -> bool: + return layout in { + torch.sparse_csr, + torch.sparse_csc, + torch.sparse_bsr, + torch.sparse_bsc, + } + + +# TypeGuard (not TypeIs): False does not imply !torch.Tensor +def is_sparse_compressed(t: object) -> TypeGuard[torch.Tensor]: + return isinstance(t, torch.Tensor) and is_sparse_compressed_layout(t.layout) + + +# TypeGuard (not TypeIs): False does not imply !torch.Tensor +def is_sparse_any(t: object) -> TypeGuard[torch.Tensor]: + return is_sparse_coo(t) or is_sparse_compressed(t) + + +def _checked_cast(ty: type[_T], obj: object) -> _T: + assert isinstance(obj, ty), f"expected {ty} but got {type(obj)}" + return obj + + +def _get_real_storage(base: torch.UntypedStorage) -> torch.UntypedStorage: + return base.real_storage # type: ignore[attr-defined] + + +def _set_real_storage( + base: torch.UntypedStorage, real_storage: torch.UntypedStorage +) -> None: + base.real_storage = real_storage # type: ignore[attr-defined] + + +# Don't use id() directly, because 
those can get reallocated over time. +MetaStorageId = NewType("MetaStorageId", int) +MetaTensorId = NewType("MetaTensorId", int) + + +_DescriberId = NewType("_DescriberId", int) +DESCRIBER_NEXT_ID = _DescriberId(0) + + +class MetaTensorDescriber: + """ + Given a Tensor/Storage, generate a MetaTensorDesc/MetaStorageDesc + for it, which is enough information to reconstruct a meta tensor/fake tensor + corresponding to a Tensor as faithfully as possible. + + This is a stateful conversion object because we keep track of the IDs + of the tensors/storages passed to us, so we can consistently give + the same ID when we see the same tensor/storage. + """ + + def __init__(self, *, copy_data: bool = False) -> None: + global DESCRIBER_NEXT_ID + self.id = DESCRIBER_NEXT_ID + DESCRIBER_NEXT_ID = _DescriberId(DESCRIBER_NEXT_ID + 1) + self.next_tensor_id: MetaTensorId = MetaTensorId(0) + self.next_storage_id: MetaStorageId = MetaStorageId(0) + # Tensor -> int + self.lookup_tensor = WeakIdKeyDictionary() + # Storage -> int + self.lookup_storage = WeakIdKeyDictionary() + self.copy_data = copy_data + self.traced_tensors: set[int] = set() + self.traced_storages: set[int] = set() + + def get_tensor_id(self, t: torch.Tensor) -> MetaTensorId: + if t not in self.lookup_tensor: + self.lookup_tensor[t] = self.next_tensor_id + self.next_tensor_id = MetaTensorId(self.next_tensor_id + 1) + return self.lookup_tensor[t] + + def get_storage_id(self, s: torch.UntypedStorage) -> MetaStorageId: + if s not in self.lookup_storage: + self.lookup_storage[s] = self.next_storage_id + self.next_storage_id = MetaStorageId(self.next_storage_id + 1) + return self.lookup_storage[s] + + def describe_storage( + self, s: torch.UntypedStorage, *, trace: bool = False + ) -> MetaStorageDesc: + r = MetaStorageDesc( + id=self.get_storage_id(s), + size=s.size(), + # NB: We don't do the copy yet; copy happens when we start + # creating the new storages + data=s if self.copy_data else None, + ) + if trace and r.id not in 
self.traced_storages: + trace_structured( + "describe_storage", + metadata_fn=lambda: r.as_json(self.id), + ) + self.traced_storages.add(r.id) + return r + + def describe_tensor( + self, t: torch.Tensor, *, recurse: bool = True, trace: bool = False + ) -> MetaTensorDesc: + is_leaf = safe_is_leaf(t) + is_view = t._is_view() + is_sparse = t.is_sparse + layout = t.layout + is_nested = t.is_nested + is_traceable_wrapper_subclass_v = is_traceable_wrapper_subclass(t) + is_functorch_wrapped = is_functorch_wrapped_tensor(t) + is_mkldnn = t.is_mkldnn + is_batchedtensor_v = is_batchedtensor(t) + is_legacy_batchedtensor_v = is_legacy_batchedtensor(t) + is_gradtrackingtensor_v = is_gradtrackingtensor(t) + is_functional = torch._is_functional_tensor(t) + + storage = None + # NB: For compatibility, I default this to zero, as sometimes people + # still have stuffed zero into storage offset even though the tensor + # doesn't meaningfully have an offset + storage_offset = 0 + if not ( + is_sparse + or is_sparse_compressed_layout(layout) + or (is_nested and not is_traceable_wrapper_subclass_v) + or is_mkldnn + # TODO: TBH, functorch wrapped tensors probably should have + # storage associated with them + or is_functorch_wrapped + or is_legacy_batchedtensor_v + ): + # NB: We actually don't use storage to do views, but might as well + # put it in for accuracy + storage = self.describe_storage(t.untyped_storage(), trace=trace) + storage_offset = t.storage_offset() # type: ignore[assignment] + + stride = None + if not ( + is_sparse + or is_sparse_compressed_layout(layout) + or (is_nested and not is_traceable_wrapper_subclass_v) + ): + # stride/storage_offset are called from is_functorch_wrapped, + # view_from_base, empty_create_subclass, + # sym_sizes_strides_storage_offset (empty_create) + stride = t.stride() + + # NB: this technically should refer to functorch unwrapped tensor, but + # I am (perhaps abusively) using it to store both the functorch and + # non-functorch functional tensor 
+ unwrapped = None + autograd_meta_from = None + current_level = None + if is_batchedtensor_v or is_gradtrackingtensor_v: + unwrapped = self.describe_tensor(get_unwrapped(t), trace=trace) + # xla and lazy tensors present as functional tensors, but we want them + # to be handled specially + elif is_functional and t.device.type not in ("xla", "lazy"): + if t._is_view(): + raise RuntimeError( + "Cannot safely fakify a view because this process drops the view information right now." + ) + if not is_functorch_wrapped: + torch._sync(t) + unwrapped = self.describe_tensor( + torch._from_functional_tensor(t), trace=trace + ) + autograd_meta_from = t + else: + reapply_views = torch._C._functionalization_reapply_views_tls() + # NB: has side effects! + unwrapped = self.describe_tensor( + _unwrap_functional_tensor(t, reapply_views), trace=trace + ) + # TODO: It's pretty suspicious that functional tensors don't have + # valid level and thus we just grab whatever the current level + # is + current_level = torch._C._functorch.current_level() + + maybe_functorch_stack = None + if is_functorch_wrapped: + with ( + torch._functorch.pyfunctorch.temporarily_clear_interpreter_stack() + ) as maybe_functorch_stack: + pass + + attrs = None + ctx = None + type_v = None + if is_traceable_wrapper_subclass_v: + assert hasattr(t, "__tensor_flatten__") + raw_attrs, ctx = t.__tensor_flatten__() + attrs = { + attr: self.describe_tensor(getattr(t, attr), trace=trace) + for attr in raw_attrs + } + type_v = type(t) + + from torch.nested._internal.nested_tensor import _tensor_symint_registry + + view_func = ViewFunc.from_tensor(t) + + # TODO: Is it important to enable torch.inference_mode before querying + # these values? 
+ is_inference_mode_disabled = getattr(tls, "disable_inference_mode", False) + r: MetaTensorDesc = MetaTensorDesc( + id=self.get_tensor_id(t), + storage=storage, + is_inference=False if is_inference_mode_disabled else t.is_inference(), + is_leaf=is_leaf, + requires_grad=t.requires_grad, + # NB: ndim should be OK too but there is a disaster at + # python test/dynamo/test_subclasses.py -k test_user_overridden_property_unsupported + # Actually, this means that we have a little bit of a problem + # here, which is that there is some sensitivity to how exactly an + # access is done if you have a __torch_function__ subclass. Maybe + # should disable torch function before doing accesses? + ndim=t.dim(), + dtype=t.dtype, + is_sparse=is_sparse, + is_mkldnn=is_mkldnn, + is_functorch_wrapped=is_functorch_wrapped, + is_batchedtensor=is_batchedtensor_v, + is_legacy_batchedtensor=is_legacy_batchedtensor_v, + is_gradtrackingtensor=is_gradtrackingtensor_v, + is_view=is_view, + is_conj=t.is_conj(), + is_neg=t.is_neg(), + is_parameter=isinstance(t, torch.nn.Parameter), + is_traceable_wrapper_subclass=is_traceable_wrapper_subclass_v, + is_nested=is_nested, + nested_int=( + _tensor_symint_registry[t].node.nested_int() + if t in _tensor_symint_registry + else None + ), + is_functional=is_functional, + layout=layout, + device=t.device, + size=t.size(), + stride=stride, + # pyrefly: ignore [bad-argument-type] + storage_offset=storage_offset, + dynamo_dynamic_indices=list(getattr(t, "_dynamo_dynamic_indices", set())), + dynamo_hint_overrides=getattr(t, "_dynamo_hint_overrides", {}), + sparse_dim=( + t.sparse_dim() if t.is_sparse or is_sparse_compressed(t) else None + ), + dense_dim=t.dense_dim() if t.is_sparse or is_sparse_compressed(t) else None, + is_coalesced=t.is_coalesced() if t.is_sparse else None, + # TODO: I actually think recursing here is correct, but we have at + # least an infinite cycle from base -> values -> base + # https://github.com/pytorch/pytorch/issues/122089 + 
crow_indices=( + self.describe_tensor(t.crow_indices(), recurse=False, trace=trace) + if recurse and t.layout in {torch.sparse_csr, torch.sparse_bsr} + else None + ), + col_indices=( + self.describe_tensor(t.col_indices(), recurse=False, trace=trace) + if recurse and t.layout in {torch.sparse_csr, torch.sparse_bsr} + else None + ), + ccol_indices=( + self.describe_tensor(t.ccol_indices(), recurse=False, trace=trace) + if recurse and t.layout in {torch.sparse_csc, torch.sparse_bsc} + else None + ), + row_indices=( + self.describe_tensor(t.row_indices(), recurse=False, trace=trace) + if recurse and t.layout in {torch.sparse_csc, torch.sparse_bsc} + else None + ), + values=( + self.describe_tensor(t.values(), recurse=False, trace=trace) + if recurse and is_sparse_compressed(t) + else None + ), + grad=( + self.describe_tensor(grad, trace=trace) + if (grad := safe_grad(t)) is not None + else None + ), + creation_meta=( + torch._C._autograd._get_creation_meta(t) if t._is_view() else None + ), + unwrapped=unwrapped, + level=( + maybe_get_level(t) + if is_batchedtensor_v or is_gradtrackingtensor_v + else None + ), + bdim=maybe_get_bdim(t) if is_batchedtensor_v else None, + base=( + self.describe_tensor(t._base, trace=trace) + if recurse and t._is_view() and t._base is not None + else None + ), + fake_mode=torch._subclasses.fake_tensor.maybe_get_fake_mode(t), + view_func=view_func, + attrs=attrs, + ctx=ctx, + type=type_v, + # NB: even if functorch is enabled, don't actually save the + # interpreter stack here unless we are actually functorch wrapped; + # it's irrelevant for non-functorch stuff + functorch_stack=maybe_functorch_stack, + autograd_meta_from=autograd_meta_from, + current_level=current_level, + data=t if self.copy_data else None, + ) + if trace and r.id not in self.traced_tensors: + trace_structured( + "describe_tensor", + metadata_fn=lambda: r.as_json(self.id), + ) + self.traced_tensors.add(r.id) + return r + + +@dataclass(frozen=True) +class MetaStorageDesc: + 
id: MetaStorageId + size: int + # NB: this is only populated with copy_data True, it is not directly + # serializable in JSON, you want to do something special here anyway + data: Optional[torch.UntypedStorage] + + def as_json(self, describer_id: _DescriberId) -> dict[str, object]: + return { + "id": self.id, + "describer_id": describer_id, + "size": self.size if isinstance(self.size, int) else repr(self.size), + } + + +@dataclass(frozen=True) +class ViewFunc(Generic[_TensorT]): + @abstractmethod + def apply( + self, + t: _TensorT, + new_base: _TensorT, + symint_visitor_fn: Optional[Callable[[int], int]] = None, + tensor_visitor_fn: Optional[Callable[[torch.Tensor], _TensorT]] = None, + ) -> _TensorT: ... + + @staticmethod + def from_tensor(t: torch.Tensor) -> ViewFunc: + if _is_fake_tensor(t): + return _FakeTensorViewFunc() + else: + return _CustomViewFunc(t._view_func_unsafe) + + +@dataclass(frozen=True) +class _FakeTensorViewFunc(ViewFunc["FakeTensor"]): + @override + def apply( + self, + t: torch.Tensor, + new_base: torch.Tensor, + symint_visitor_fn: Optional[Callable[[int], int]] = None, + tensor_visitor_fn: Optional[Callable[[torch.Tensor], FakeTensor]] = None, + ) -> FakeTensor: + return torch._subclasses.fake_tensor.FakeTensor._view_func_unsafe( + # pyrefly: ignore [bad-argument-type] + t, + new_base, + symint_visitor_fn, + tensor_visitor_fn, + ) + + +@dataclass(frozen=True) +class _CustomViewFunc(ViewFunc[_TensorT], Generic[_TensorT]): + func: Callable[ + [ + torch.Tensor, + Optional[Callable[[int], int]], + Optional[Callable[[torch.Tensor], _TensorT]], + ], + _TensorT, + ] + + @override + def apply( + self, + t: torch.Tensor, + new_base: torch.Tensor, + symint_visitor_fn: Optional[Callable[[int], int]] = None, + tensor_visitor_fn: Optional[Callable[[torch.Tensor], _TensorT]] = None, + ) -> _TensorT: + # ignore `t` + return self.func(new_base, symint_visitor_fn, tensor_visitor_fn) + + +# A callback where the device is either optional or required. 
+# All of these satisfy this protocol: +# def mk(arg: Callable[[], torch.Tensor], device: Union[torch.device, str]) +# def mk(arg: Callable[[], torch.Tensor], device: Union[torch.device, str] = "meta") +# def mk(arg: Callable[[], torch.Tensor], device: Optional[Union[torch.device, str]] = None) +class _MetaTensorCallback(Protocol, Generic[_TensorT_cov]): + def __call__( + self, arg: Callable[[], torch.Tensor], /, *, device: Union[torch.device, str] + ) -> _TensorT_cov: ... + + +class _MetaTensorCallbackKwargs(TypedDict, total=False): + device: Union[torch.device, str] + + +# A callback where the device may not be provided (is optional). +# All of these satisfy this protocol: +# def mk(arg: Callable[[], torch.Tensor], device: Union[torch.device, str] = "meta") +# def mk(arg: Callable[[], torch.Tensor], device: Optional[Union[torch.device, str]] = None) +class _MetaTensorCallbackOptDevice(Protocol, Generic[_TensorT_cov]): + def __call__( + self, + arg: Callable[[], torch.Tensor], + /, + **kwargs: Unpack[_MetaTensorCallbackKwargs], + ) -> _TensorT_cov: ... + + +@dataclass(frozen=True) +class MetaTensorDesc(Generic[_TensorT]): + id: MetaTensorId + ndim: int + dtype: torch.dtype + device: torch.device + + # NB: Sometimes, size, stride and storage_offset contain SymInt, in which + # case this is NOT serializable. That only happens when you're + # re-fakeifying a fake tensor with an existing ShapeEnv... maybe we + # can get rid of this use case entirely. Notably, even if we are + # fakeifying a real tensor into a fake tensor with symbolic shapes, the + # size here is NOT dynamic + # NB: These also contain SymInt because wrap_meta_outputs_with_default_device_logic + # goes through this codepath. But it really should not LOL. 
+ # NB: size could potentially be None as you can override it and make it + # throw an error, but we don't currently have any subclasses that do this + # except C++ nested tensor but we're going to have nested int to make this + # defined on NJT + size: tuple[int, ...] + dynamo_dynamic_indices: list[int] + dynamo_hint_overrides: dict[int, int] + + layout: torch.layout = torch.strided + is_inference: bool = False + is_leaf: bool = False + requires_grad: bool = False + is_sparse: bool = False + is_mkldnn: bool = False + is_functorch_wrapped: bool = False + is_batchedtensor: bool = False + is_legacy_batchedtensor: bool = False + is_gradtrackingtensor: bool = False + is_view: bool = False + is_nested: bool = False + # We eagerly symbolicize the associated nested int for e.g. offsets / lengths + # metadata if that offsets is already associated with a nested int. + # See test_construct_from_jagged_with_input_offsets_mixed_case. + nested_int: Optional[int] = None + is_traceable_wrapper_subclass: bool = False + is_functional: bool = False + is_conj: bool = False + is_neg: bool = False + is_parameter: bool = False + stride: Optional[tuple[int, ...]] = None + storage_offset: int = 0 + # NB: We have a choice whether or not to store the id or a direct pointer + # to the data structure. 
For ease of use, we store the data structure, + # but this means that when we serialize, we have to swizzle these pointers + # back into ids (so we have accurate aliasing relationships) + storage: Optional[MetaStorageDesc] = None + sparse_dim: Optional[int] = None # is_sparse, is_sparse_compressed + dense_dim: Optional[int] = None # is_sparse, is_sparse_compressed + is_coalesced: Optional[bool] = None # is_sparse + crow_indices: Optional[MetaTensorDesc] = None # is_sparse_compressed + col_indices: Optional[MetaTensorDesc] = None # is_sparse_compressed + ccol_indices: Optional[MetaTensorDesc] = None # is_sparse_compressed + row_indices: Optional[MetaTensorDesc] = None # is_sparse_compressed + values: Optional[MetaTensorDesc] = None # is_sparse_compressed + unwrapped: Optional[MetaTensorDesc] = None # is_functorch_wrapped + bdim: Optional[int] = None # is_functorch_wrapped + base: Optional[MetaTensorDesc] = None # is_view + attrs: Optional[dict[str, MetaTensorDesc]] = None # is_traceable_wrapper_subclass + creation_meta: Optional[CreationMeta] = None + grad: Optional[MetaTensorDesc] = None + + # Everything below is NOT serializable, need some more work + + _UNSERIALIZABLE: ClassVar[set[str]] = { + "ctx", + "type", + "fake_mode", + # view_func isn't serializable when it's a _CustomViewFunc + "view_func", + "level", + "current_level", + "functorch_stack", + "autograd_meta_from", + "data", + "nested_int", + } + + ctx: Optional[object] = None # is_traceable_wrapper_subclass + type: Optional[type] = None # is_traceable_wrapper_subclass + fake_mode: Optional[FakeTensorMode] = None + view_func: Optional[ViewFunc] = None + # level looks serializable, but actually it is meaningless without + # the functorch_stack below + level: Optional[int] = None # is_functorch_wrapped + current_level: Optional[int] = None + functorch_stack: Optional[list[CInterpreter]] = None + autograd_meta_from: Optional[torch.Tensor] = None + + # This is only populated on copy_data, and typically is not 
used at all, + # except for some of our meta-ification paths that don't properly use + # storage (pro-tip: you should use storage) + data: Optional[torch.Tensor] = None + + # Faithfully serializing functorch tensors will not be too difficult. + # We only need to consider grad/vmap interpreters, and their internal + # state is only bools (mostly what the grad enabled/disabled state + # should be in the lower layer). Beyond that, tensors just need to + # precisely indicate which particular interpreter they correspond + # to (we then replace level with a pointer to the interpreter stack.) + # However, this use of functorch is very "non-lexical" so it's not + # entirely clear how to make it all lexical again, so we haven't done + # it for now. + + # NB: This will reference numeric IDs, and it is assumed that you've + # already serialized everything this recursively references + def as_json(self, describer_id: _DescriberId) -> dict[str, object]: + def json(k: str, v: object) -> object: + # Some best-effort debugging serialization for unserializable + # fields (feel free to add other special cases as appropriate) + if k in ["data", "autograd_meta_from"]: + return None # never repr these + if k in MetaTensorDesc._UNSERIALIZABLE: + return repr(v) + if isinstance(v, (torch.device, torch.dtype, torch.layout)): + return repr(v) + if isinstance(v, torch.SymInt): + return repr(v) + if isinstance(v, (tuple, list)): + return [json(k, v1) for v1 in v] + if isinstance(v, (MetaStorageDesc, MetaTensorDesc)): + return v.id + if isinstance(v, CreationMeta): + return str(v) + if k == "attrs" and isinstance(v, dict): + return {k1: v1.id for k1, v1 in v.items()} + return v + + r = { + field.name: json(field.name, getattr(self, field.name)) + for field in dataclasses.fields(self) + if not ( + getattr(self, field.name) is field.default + or ( + field.name == "dynamo_dynamic_indices" + and not getattr(self, field.name) + ) + ) + } + r.update({"describer_id": describer_id}) + return r + + 
@property + def shape(self) -> tuple[int, ...]: + return self.size + + +# A more faithful reproduction would do a copy on the entire +# storage, but this needs to be done carefully because the +# underlying storage could have larger extent than is implied +# by size/stride. The real fix is to properly call +# meta_storage recursively here. +# +# These "safe" functions are intended to be used under no_dispatch() mode. +# The no_dispatch() here is intended to prevent ambient fake tensor mode from +# fakeifying the operation. But if we are given an honest to goodness +# FakeTensor as src, we MUST NOT run the copy/clone operation. A better way +# to do this would be to not use no_dispatch and instead just disable fake +# tensor mode only (allowing for subclass dispatch to occur) +def _safe_copy(dst: torch.Tensor, src: Optional[torch.Tensor]) -> None: + if type(src) is not torch.Tensor: + return + dst.copy_(src) + + +def _safe_clone(src: torch.Tensor) -> Optional[torch.Tensor]: + if type(src) is not torch.Tensor: + return None + return src.clone() + + +# This is a class for converting multiple tensors into meta tensors which +# share the same view/storage structure. The operation model is you allocate +# one of these, and then call it repeatedly on all the tensors you want to +# convert. It's important to use the same object for tensors you want to +# share storage because this is how we correlate shared storages to the same +# meta storages. This class will hold weak references to cached tenosrs +# and tensor storages. 
+class MetaConverter(Generic[_TensorT]): + def __init__(self, *, copy_data: bool = False) -> None: + # Maps MetaStorageId to UntypedStorage + self.storage_memo: weakref.WeakValueDictionary[ + MetaStorageId, torch.UntypedStorage + ] = weakref.WeakValueDictionary() + # Maps MetaTensorId to torch.Tensor (typically a meta tensor or + # FakeTensor) + self.tensor_memo: weakref.WeakValueDictionary[MetaTensorId, _TensorT] = ( + weakref.WeakValueDictionary() + ) + self.hit = 0 + self.miss = 0 + self.del_hook = None + self.arg_cnt = 0 + # Ensures real_storage/real_tensor are populated on the resulting + # metaified storage/tensor. The naming of this attribute is load + # bearing: FakeTensor relies on real tensor being set to exactly this + # value + self.copy_data = copy_data + self.describer = MetaTensorDescriber(copy_data=copy_data) + + def successful(self) -> bool: + return self.hit > 0 and self.miss == 0 + + def get_tensor_memo(self, t: MetaTensorDesc) -> Optional[torch.Tensor]: + return self.tensor_memo.get(t.id, None) + + def _checked_get_tensor_memo(self, t: MetaTensorDesc) -> _TensorT: + r = self.tensor_memo.get(t.id, None) + assert r is not None + return r + + def set_tensor_memo(self, t: MetaTensorDesc, v: _TensorT) -> None: + self.tensor_memo[t.id] = v + + def get_storage_memo(self, s: MetaStorageDesc) -> Optional[torch.UntypedStorage]: + return self.storage_memo.get(s.id, None) + + def set_storage_memo(self, s: MetaStorageDesc, v: torch.UntypedStorage) -> None: + self.storage_memo[s.id] = v + + def meta_storage( + self, + s: MetaStorageDesc, + callback: Callable[[Callable[[], torch.Tensor]], _TensorT], + ) -> torch.UntypedStorage: + # If we are fakeifying a tensor that has a secretly-zero-sized storage, + # Need to make sure to resize the meta storage too. 
+ if (memo := self.get_storage_memo(s)) is None: + r_s = callback( + lambda: torch.empty(s.size, dtype=torch.uint8, device="meta"), + ).untyped_storage() + if self.copy_data: + # NB: no_dispatch is needed because internally storage copy is + # implemented as Tensor operations + with torch.no_grad(), no_dispatch(): + assert s.data is not None + _set_real_storage(r_s, s.data.clone()) + self.set_storage_memo(s, r_s) + return r_s + else: + return memo + + @classmethod + def _checked_cast_tensor_t(cls, t: torch.Tensor) -> _TensorT: + # TODO: how to check _TensorT? + return typing.cast(_TensorT, t) + + @classmethod + def _identity_callable( + cls, + t: Callable[[], torch.Tensor], + device: Optional[Union[torch.device, str]] = None, + ) -> _TensorT: + return cls._checked_cast_tensor_t(t()) + + @classmethod + def _backward_error(cls, t: _TensorT) -> _TensorT: + errfn = torch._C._functions.DelayedError( + "Internal error: Tried to backward() through example input", + 1, + ) + err = errfn(t) + return typing.cast(_TensorT, err) + + # This function assumes that it's possible to do the conversion + # NB: name here is used in a conventional way by Dynamo; it corresponds + # precisely to the Source.name of the tensor we're fakeifying and + # corresponds to a valid Python expression. When we construct sub-names + # as part of this process, we will maintain this invariant! (Even though + # other users of this may not need it this property to be upheld.) + def meta_tensor( + self, + t: MetaTensorDesc, + shape_env: Optional[ShapeEnv], + callback_: _MetaTensorCallback[_TensorT], + source: Optional[Source], + symbolic_context: Optional[SymbolicContext], + ) -> _TensorT: + callback: _MetaTensorCallbackOptDevice = functools.partial( + callback_, device=t.device + ) + if source is None: + from torch._dynamo.source import ConstantSource + + # TODO: make a dedicated UnknownSource for this? 
+ source = ConstantSource( + f"__meta_utils_unknown_tensor{len(self.tensor_memo)}" + ) + + msg = ( + " This indicates you set no_dispatch() before calling into this" + " function. This is an error: we may be creating fake tensors and" + " will perform operations on them which need fake tensor mode to" + " be active. You will segfault if you are in a no_dispatch() block." + ) + assert not torch._C._dispatch_tls_local_exclude_set().has( + torch._C.DispatchKey.Python + ), msg + self.arg_cnt += 1 + + # When we make as_strided calls, we end up generating a guard + # that the new as_strided tensor is in bounds for the old storage + # for the base (since as_strided calls can "bust" out of their + # bounding box.) This guard is unnecessary: if a user is able + # to provide us a tensor with the view base setup this way, we + # don't need to produce a guard, because the fact that they + # were able to produce the view base means its in bounds. + # + # Now, ordinarily, this guard would be harmless. However, the + # generated guard refers to variables bound on the base variable. + # At the moment, Dynamo doesn't actually guard on x._base, because + # according to Voz this results in a lot of spurious invalidations, + # and also if the user doesn't directly make use of _base, its + # pointless anyway (because programs should be parametric over + # whether or not the input tensor is a view or not--unless you're + # mutating the input, but that's a whole 'nother ballgame). So + # for expediency, we suppress these guards so we don't have to + # deal with this (yet, anyway.) + # + # NB: An old version of this code suppressed guards for ALL operations + # happening during meta conversion, not just as_strided calls. + # This is too aggressive: we do duck sizing and 0/1 simplification + # as we allocate variables, and we do need to register guards for + # these cases. 
+ maybe_suppress: Callable[[], Any] = contextlib.nullcontext + if shape_env is not None: + maybe_suppress = shape_env.suppress_guards + + def sym_sizes_strides_storage_offset( + t: MetaTensorDesc, + src: torch._guards.Source, + symbolic_context: Optional[ + torch.fx.experimental.symbolic_shapes.SymbolicContext + ] = symbolic_context, + ) -> tuple[tuple[int, ...], tuple[int, ...], int]: + assert t.stride is not None + if shape_env is not None: + fake_mode = t.fake_mode + if fake_mode is not None and fake_mode.shape_env is shape_env: + # Don't reallocate the sizes; the shape envs are the same, + # so reuse the old sizes/strides/etc + return (t.size, t.stride, t.storage_offset) + else: + # TODO: deduplicate this + t_size = tuple( + shape_env._maybe_specialize_sym_int_with_hint(sz) + for sz in t.size + ) + t_stride = tuple( + shape_env._maybe_specialize_sym_int_with_hint(sd) + for sd in t.stride + ) + t_storage_offset = shape_env._maybe_specialize_sym_int_with_hint( + t.storage_offset + ) + return shape_env._create_symbolic_sizes_strides_storage_offset( + t_size, + t_stride, + t_storage_offset, + [d in t.dynamo_dynamic_indices for d in range(t.ndim)], + src, + symbolic_context=symbolic_context, + hint_overrides=t.dynamo_hint_overrides, + ) + else: + return (t.size, t.stride, t.storage_offset) + + def empty_create( + inner_t: MetaTensorDesc, + inner_src: torch._guards.Source, + symbolic_context: Optional[ + torch.fx.experimental.symbolic_shapes.SymbolicContext + ] = symbolic_context, + ) -> torch.Tensor: + ( + inner_sizes, + inner_strides, + _inner_storage_offset, + ) = sym_sizes_strides_storage_offset(inner_t, inner_src, symbolic_context) + return torch.empty_strided( + inner_sizes, + inner_strides, + dtype=inner_t.dtype, + device="meta", + ) + + # Creates a subclass instance with empty inner tensors according to the specified + # symbolic context. 
+ def empty_create_subclass( + t: MetaTensorDesc, + outer_size: tuple[int, ...], + outer_stride: tuple[int, ...], + symbolic_context: Optional[ + torch.fx.experimental.symbolic_shapes.SymbolicContext + ] = symbolic_context, + source: Optional[torch._guards.Source] = source, + ) -> _TensorT: + from torch._dynamo.source import AttrSource + from torch.fx.experimental.symbolic_shapes import SubclassSymbolicContext + + assert t.attrs is not None + assert t.type is not None + # NB: t.ctx could be None if the subclass in question has no + # meaningful context + + # Note: transform_subclass will use __tensor_unflatten__ to generate + # a fresh subclass wrapper with outer sizes / strides according to the + # outer symbolic context (passed in to this function). Inner size / stride + # / storage offset symbols are allocated according to the appropriate inner + # symbolic contexts, after which the checks in transform_subclass() will + # relate them to the outer metadata as possible. + # + # Morally, the code here is same as transform_subclass, but we've + # written it from scratch to read EmptyCreateSubclass + outer_size = outer_size if outer_size is not None else t.size + # pyrefly: ignore [bad-assignment] + outer_stride = outer_stride if outer_stride is not None else t.stride + + assert symbolic_context is None or isinstance( + symbolic_context, SubclassSymbolicContext + ) + + def _empty_create_subclass( + t: MetaTensorDesc, + outer_size: Optional[tuple[int, ...]], + outer_stride: Optional[tuple[int, ...]], + symbolic_context: Optional[ + torch.fx.experimental.symbolic_shapes.SymbolicContext + ], + callback: _MetaTensorCallbackOptDevice[_TensorT], + source: torch._guards.Source, + ) -> _TensorT: + # We are hitting plain meta_desc tensor so actually + # create a tensor here. 
+ if t.attrs is None: + return self.meta_tensor( + t, + shape_env, + callback, + source, + symbolic_context, + ) + + inner_tensors = {} + for attr, meta_tensor_desc in t.attrs.items(): + current_context = None + if symbolic_context is not None: + assert isinstance(symbolic_context, SubclassSymbolicContext) + if ( + current_context_ := symbolic_context.inner_contexts[attr] + ) is not None: + current_context = _checked_cast( + torch.fx.experimental.symbolic_shapes.SymbolicContext, + current_context_, + ) + + current_source = AttrSource(source, attr) + inner_callback = functools.partial( + callback, device=meta_tensor_desc.device + ) + new_empty_tensor = _empty_create_subclass( + meta_tensor_desc, + meta_tensor_desc.size, + meta_tensor_desc.stride, + current_context, + inner_callback, + current_source, + ) + inner_tensors[attr] = new_empty_tensor + + assert t.type is not None + return t.type.__tensor_unflatten__( # type: ignore[attr-defined] + inner_tensors, t.ctx, outer_size, outer_stride + ) + + assert source is not None + sub = _empty_create_subclass( + t, outer_size, outer_stride, symbolic_context, callback, source + ) + + # NB: Purposefully guard here to simplify the inner / outer symbols. + # Using sym_eq() for symbolic comparison can result in an expression that's too + # difficult to guard on, so we use == here. + assert sub.shape == outer_size, ( + f"Expected return value from {t.type}__tensor_unflatten__() to have " + f"shape equal to {outer_size}, but got: {sub.shape}" + ) + assert sub.stride() == outer_stride, ( + f"Expected return value from {t.type}__tensor_unflatten__() to have " + f"stride equal to {outer_stride}, but got: {sub.stride()}" + ) + + return sub + + # Returns an all-dynamic symbolic context used for metafying the given tensor with + # fully dynamic dims. 
This is useful when fake-ifying intermediate tensors in + # closed-over ViewFunc state, as we don't have symbolic contexts for them, but we + # don't want to over-specialize during view replay. + def all_dynamic_symbolic_context( + t: MetaTensorDesc, + source: torch._guards.Source, + shape_env: Optional[torch.fx.experimental.symbolic_shapes.ShapeEnv], + callback: _MetaTensorCallback[_TensorT], + ) -> torch.fx.experimental.symbolic_shapes.SymbolicContext: + from torch._dynamo.source import AttrSource + from torch.fx.experimental.symbolic_shapes import ( + DimDynamic, + StatelessSymbolicContext, + SubclassSymbolicContext, + ) + + view_base_context: Optional[ + torch.fx.experimental.symbolic_shapes.SymbolicContext + ] = None + if t.is_view: + assert t.base is not None + view_base_context = all_dynamic_symbolic_context( + t.base, AttrSource(source, "_base"), shape_env, callback + ) + + t_symbolic_context: torch.fx.experimental.symbolic_shapes.SymbolicContext + t_dynamic_sizes = [DimDynamic.DYNAMIC] * t.ndim + if t.is_traceable_wrapper_subclass: + assert t.attrs is not None + inner_contexts: dict[ + str, torch.fx.experimental.symbolic_shapes.SymbolicContext + ] = {} + for attr, inner in t.attrs.items(): + assert isinstance(attr, str) + inner_contexts[attr] = all_dynamic_symbolic_context( + inner, AttrSource(source, attr), shape_env, callback + ) + t_symbolic_context = SubclassSymbolicContext( + dynamic_sizes=t_dynamic_sizes, + constraint_sizes=[None] * t.ndim, + inner_contexts=inner_contexts, # type: ignore[arg-type] + tensor_source=source, + view_base_context=view_base_context, + ) + else: + t_symbolic_context = StatelessSymbolicContext( + dynamic_sizes=t_dynamic_sizes, + constraint_sizes=[None] * t.ndim, + view_base_context=view_base_context, + ) + + return t_symbolic_context + + # Returns a fake-ified version of an input view tensor t, given an already fake-ified + # base. At a high level, we want two things: + # 1. 
fake_t should have the same view relationship to the given fake base as the + # input t has to its _base. + # 2. fake_t should have symbolic sizes / strides / storage offset according to the + # appropriate symbolic context (i.e. from the automatic dynamic algorithm). + # + # We currently take different strategies across view types: + # * For dense -> dense views, accomplish both (1) and (2) simultaneously via an + # as_strided() call on the fake-ified base, passing symbolic metadata. + # * For views involving subclasses, perform view replay using view funcs to + # achieve (1). It's necessary for (2) to swap out any closed-over state in + # the view funcs with symbolicized SymInts and fake-ified tensors. Doing this + # avoids specialization (and thus over-eager simplification of symbols) that + # could occur during view replay on the fake-ified base. + # + # Examples: + # * t.unsqueeze(-1) with dense t is a dense -> dense view. It can be modeled + # with an as_strided() call on the fake base passing symbolic metadata. + # * sub.select(dim=0, index=3) is a subclass -> subclass view. The index arg + # is made symbolic to avoid invalid specialization and view replay is then + # done to reconstruct the view. + # * _nested_from_jagged(values, offsets) is a dense -> subclass view + # that returns a subclass instance from a dense values tensor. The offsets + # tensor is closed over in the view func, as it can be considered view metadata. + # First, the offsets tensor is fake-ified according to the inner symbolic + # context and with the correct relationship to the outer size / stride metadata. + # Then view replay is done, swapping in the fake offsets so the view replay output + # is fully fake with no invalid specialization. 
+ def view_from_base( + base: _TensorT, + t: MetaTensorDesc, + shape_env: Optional[ + torch.fx.experimental.symbolic_shapes.ShapeEnv + ] = shape_env, + ) -> _TensorT: + with enable_python_dispatcher(): + # fake-ify t's metadata according to the outer symbolic context + (sizes, strides, storage_offset) = sym_sizes_strides_storage_offset( + t, source + ) + if ( + not t.is_traceable_wrapper_subclass + and not is_traceable_wrapper_subclass(base) + ): + # Dense -> Dense view case uses as_strided() to construct view relationship. + # TODO: Change this logic to use view replay for consistency? + # It's likely there is no view func available. + with maybe_suppress(): + return self._checked_cast_tensor_t( + base.as_strided(sizes, strides, storage_offset) + ) + + from torch._dynamo.source import EphemeralSource + from torch.fx.experimental.symbolic_shapes import ( + StatelessSymbolicContext, + sym_eq, + ) + + def symint_visitor_fn(s: int) -> int: + nonlocal symbolic_context + from torch.fx.experimental.symbolic_shapes import DimDynamic + + all_static_sizes = ( + symbolic_context is not None + and isinstance(symbolic_context, StatelessSymbolicContext) + and all( + x is DimDynamic.STATIC + for x in symbolic_context.dynamic_sizes + ) + ) + # Can't just rely on shape env being None - dynamo always initializes it + if all_static_sizes or shape_env is None: + return s + + # NB: The symbol here is expected to be simplified out because we a priori + # allocate inner and outer symbols according to the appropriate symbolic + # contexts and prefer those over this symbol during symbol simplification + # (via usage of EphemeralSource below). This -shouldn't- happen, but if + # this symbol somehow leaks out beyond the view tensor's shape metadata, our + # assumption of it being simplified out will fail and it may be guarded on, + # which will hard error. 
+ sym_source = EphemeralSource("symint_visitor_fn") + + symbol = shape_env.create_symbol(s, sym_source, positive=None) + return shape_env.create_symintnode( + symbol, hint=s, source=sym_source + ) + + real_to_fake_mapping = {} + if t.is_traceable_wrapper_subclass: + assert t.attrs is not None + # NB: t.ctx could be None if the subclass in question has no + # meaningful context + assert t.type is not None + + # Fake-ify t naively here; this is only done so we can get fake-ified inner + # tensors with the correct relationships to the outer sizes / strides for use + # in view replay. It's done beforehand here because it's not easy to do when + # visiting tensors one-by-one during view replay. + # + # Example: + # Consider a Dense -> NJT view. NJT has (values, offsets) components and we + # want a view of values with the offsets closed over. As the offsets component + # is needed to describe the output view, it's important that it's fakeified + # correctly. + fake_t: _TensorT = empty_create_subclass( + t, outer_size=sizes, outer_stride=strides + ) + attrs, _ = fake_t.__tensor_flatten__() # type: ignore[attr-defined] + for attr in attrs: + real_to_fake_mapping[t.attrs[attr].id] = getattr(fake_t, attr) + + def tensor_visitor_fn( + visited_t: torch.Tensor, + # These arguments are never passed, we just use them to close + # over these relevant values + shape_env: Optional[ + torch.fx.experimental.symbolic_shapes.ShapeEnv + ] = shape_env, + callback: _MetaTensorCallbackOptDevice[_TensorT] = callback, + ) -> torch.Tensor: + # It's possible to close over an undefined tensor (e.g. NJT's lengths). + if visited_t is None: + # pyrefly: ignore [bad-return] + return None + + # NB: visited_t being a Tensor here is very naughty! Should + # have already been described + + # Fake inner tensors of view subclasses will come from the mapping built above. 
+ visited_id = self.describer.get_tensor_id(visited_t) + fake_visited_t = real_to_fake_mapping.get(visited_id) + if fake_visited_t is not None: + return fake_visited_t + + visited_desc = self.describer.describe_tensor(visited_t) + + # For other closed-over tensor state, fake-ify it as all dynamic with an + # ephemeral source. This avoids invalid specialization during view replay. + # If we find that in practice the usage of ephemeral sources isn't enough + # to guarantee that we don't have guards on these symbols, we may need to + # explicitly suppress guards (as is done for _base in the dense -> dense + # view case). + temp_source = EphemeralSource("tensor_visitor_fn") + return self.meta_tensor( + visited_desc, + shape_env, + callback, + temp_source, + all_dynamic_symbolic_context( + visited_desc, temp_source, shape_env, callback + ), + ) + + # Replay the view, swapping out any non-symbolic SymInts or real tensors + # for symbolic SymInts or fake tensors. + assert t.view_func is not None + # NB: we do NOT suppress guards here, we need to remove ephemeral + # sources + fake_t = t.view_func.apply( + t, base, symint_visitor_fn, tensor_visitor_fn + ) + + # Ensure the output has symbolic shapes according to the outer symbolic context. + # These checks should simplify out any symbols created for closed-over view func + # SymInts. 
+ torch._check(sym_eq(fake_t.size(), sizes)) + torch._check(sym_eq(fake_t.stride(), strides)) + torch._check(sym_eq(fake_t.storage_offset(), storage_offset)) + return fake_t + + if self.get_tensor_memo(t) is None: + GRAD_TENSOR_SENTINEL_VALUE = -2 + + with torch.inference_mode(t.is_inference): + if t.is_sparse: + is_leaf = t.is_leaf + + # The lambda function below is similar to + # `t.to(device='meta')` except the latter + # preserves nnz value + r = callback( + lambda: torch.ops.aten._sparse_coo_tensor_with_dims( + t.sparse_dim, + t.dense_dim, + t.size, + dtype=t.dtype, + layout=torch.sparse_coo, + device="meta", + ) + ) + if self.copy_data: + # Pray that sparse clone doesn't lose information + assert t.data is not None + with torch.no_grad(), no_dispatch(): + assert _is_fake_tensor(r) + r.real_tensor = _safe_clone(t.data) + assert safe_is_leaf(r), "the callback you passed in doesn't detach" + # Note [is_coalesced is dispatched] + # Strangely enough, is_coalesced() is a dispatched operator, + # which means that it will get caught by fake tensor mode. + # Ordinarily this would error, but there's some logic in + # fake tensor ensure this doesn't happen. + r._coalesced_(bool(t.is_coalesced)) + if t.requires_grad: + r.requires_grad = True + if t.requires_grad and not is_leaf: + # This should probably use DelayedError, + # but clone is fine for now for sparse tensors. 
+ # (DelayedError does not work for sparse because it causes + # the Fake sparse tensor to "lose" its fakeness) + r = self._checked_cast_tensor_t(r.clone()) + with torch.enable_grad(): + r._coalesced_(bool(t.is_coalesced)) + elif is_sparse_compressed_layout(t.layout): + is_leaf = t.is_leaf + + if t.layout in {torch.sparse_bsr, torch.sparse_bsc}: + assert t.sparse_dim is not None + assert t.dense_dim is not None + assert t.values is not None + batch_dim = t.ndim - t.sparse_dim - t.dense_dim + blocksize = t.values.shape[batch_dim + 1 : batch_dim + 3] + else: + blocksize = () + if t.layout in {torch.sparse_csr, torch.sparse_bsr}: + assert t.crow_indices is not None + index_dtype = t.crow_indices.dtype + else: + assert t.ccol_indices is not None + index_dtype = t.ccol_indices.dtype + + r = callback( + lambda: torch.ops.aten._sparse_compressed_tensor_with_dims( + 0, + t.dense_dim, + t.shape, + blocksize, + index_dtype, + layout=t.layout, + dtype=t.dtype, + device="meta", + ) + ) + if self.copy_data: + # Pray sparse clone doesn't lose information + assert t.data is not None + with torch.no_grad(), no_dispatch(): + assert _is_fake_tensor(r) + r.real_tensor = _safe_clone(t.data) + assert safe_is_leaf(r), "the callback you passed in doesn't detach" + if t.requires_grad: + r.requires_grad = True + if t.requires_grad and not is_leaf: + # pyrefly: ignore [bad-argument-type] + r = self._backward_error(r) + elif t.is_nested and not t.is_traceable_wrapper_subclass: + # TODO: Handle this better in Dynamo? + # There are checks there now, but this can still be triggered by a dense + # tensor graph input that is a view of a strided NT. 
+ from torch._dynamo.exc import unimplemented + + # NOTE this graph break will NOT be present in Dynamo's graph break registry + unimplemented( + gb_type="attempted to apply meta conversion to strided nested tensor", + context=str(t), + explanation="This is not supported.", + hints=[], + ) + elif t.is_mkldnn: + is_leaf = t.is_leaf + ( + sizes, + strides, + _storage_offset, + ) = sym_sizes_strides_storage_offset(t, source) + # TODO: This doesn't seem right, where's the MKLDNN'ness + # lol + r = callback( + lambda: torch.empty_strided( + sizes, strides, dtype=t.dtype, device="meta" + ) + ) + if self.copy_data: + with torch.no_grad(), no_dispatch(): + assert t.size is not None + assert t.stride is not None + assert _is_fake_tensor(r) + r.real_tensor = torch.empty_strided( + t.size, t.stride, dtype=t.dtype, device=t.device + ) + assert t.data is not None + _safe_copy(r.real_tensor, t.data) + assert safe_is_leaf(r), "the callback you passed in doesn't detach" + if t.requires_grad: + r.requires_grad = True + if t.requires_grad and not is_leaf: + # pyrefly: ignore [bad-argument-type] + r = self._backward_error(r) + elif t.is_functorch_wrapped: + if t.is_view: + from torch._dynamo.exc import unimplemented + + unimplemented( + gb_type="attempted to apply meta conversion to view functorch tensor", + context=str(t), + explanation="This is not supported.", + hints=[], + ) + + # Wraps a functorch tensor class (BatchedTensor, GradTrackingTensor) + # in a FakeTensor + def _to_fake_tensor(t: MetaTensorDesc) -> _TensorT: + # TODO: why aren't the recursive calls going to + # meta_tensor + r: _TensorT + if t.is_batchedtensor: + assert t.unwrapped is not None + assert t.level is not None + assert t.bdim is not None + ft = _to_fake_tensor(t.unwrapped) + lvl = t.level + bdim = t.bdim + # You cannot create functorch tensors without + # having the ambient funtorch interpreter stack + # available, as the level refers to things in the + # stack + with 
torch._functorch.pyfunctorch.temporarily_restore_interpreter_stack( + t.functorch_stack + ): + r = self._checked_cast_tensor_t( + _add_batch_dim(ft, bdim, lvl) + ) + elif t.is_gradtrackingtensor: + assert t.unwrapped is not None + assert t.level is not None + disable_functorch = torch._C._DisableFuncTorch + with disable_functorch(): + ft = _to_fake_tensor(t.unwrapped) + lvl = t.level + if lvl == GRAD_TENSOR_SENTINEL_VALUE: + r = ft + else: + with torch._functorch.pyfunctorch.temporarily_restore_interpreter_stack( + t.functorch_stack + ): + r = self._checked_cast_tensor_t( + torch._C._functorch._wrap_for_grad(ft, lvl), + ) + + is_leaf = t.is_leaf + if t.requires_grad and safe_is_leaf(r): + r.requires_grad = True + elif t.requires_grad and not is_leaf: + r = self._backward_error(r) + elif t.is_functional: + assert t.unwrapped is not None + assert t.current_level is not None + ft = self.meta_tensor( + t.unwrapped, + shape_env, + callback, + # NB: reuse these exactly, we treat the + # functional tensor as "invisible". + # TODO: Actually this all probably doesn't + # work, take a closer look. + source, + symbolic_context, + ) + r = self._checked_cast_tensor_t( + _wrap_functional_tensor(ft, t.current_level), + ) + # TODO: is_leaf/requires_grad? 
+ else: + assert t.stride is not None + + sizes = t.size + strides = t.stride + r = callback( + lambda: torch.empty_strided( + sizes, + strides, + dtype=t.dtype, + device="meta", + ), + # device="meta", + ) + if self.copy_data: + with torch.no_grad(), no_dispatch(): + r.real_tensor = torch.empty_strided( # type: ignore[attr-defined] + t.size, + t.stride, + dtype=t.dtype, + device=t.device, + ) + assert t.data is not None + _safe_copy(r.real_tensor, t.data) # type: ignore[attr-defined] + # pyrefly: ignore [bad-return] + return r + + r = _to_fake_tensor(t) + + elif t.is_functional and t.device.type not in ["xla", "lazy"]: + assert t.unwrapped is not None + assert not t.is_functorch_wrapped # handled above + unwrapped = self.meta_tensor( + t.unwrapped, + shape_env, + callback, + source, + symbolic_context, + ) + r = self._checked_cast_tensor_t( + torch._to_functional_tensor(unwrapped) + ) + torch._mirror_autograd_meta_to(t.autograd_meta_from, r) # type: ignore[attr-defined] + + elif t.is_view: + # Construct views in two steps: recursively meta-fy their + # base, and then create view(s) off that. NB: doing it + # directly from storage is WRONG because this won't cause + # version counters to get shared. + + assert t.base is not None + + base_symbolic_context = None + if shape_env and symbolic_context is not None: + from torch.fx.experimental.symbolic_shapes import ( + StatelessSymbolicContext, + ) + + assert isinstance(symbolic_context, StatelessSymbolicContext) + # NB: This should generally be set when the input is a view, + # but the exception right now is for fake-ifying grads, which is + # a work in progress. 
+ if symbolic_context.view_base_context is not None: + base_symbolic_context = symbolic_context.view_base_context + + base = self.meta_tensor( + t.base, + shape_env, + callback, + torch._dynamo.source.AttrSource(source, "_base"), + base_symbolic_context, + ) + + def is_c_of_r( + complex_dtype: torch.dtype, real_dtype: torch.dtype + ) -> bool: + return ( + utils.is_complex_dtype(complex_dtype) + and utils.corresponding_real_dtype(complex_dtype) + == real_dtype + ) + + # In some situations, MetaConverter may be called in a + # context where autograd is disabled. For the _is_view + # assert to pass, we have to setup the autograd view + # metadata anyway. Do this by reenabling the + # ADInplaceOrView key. This is kind of a hack. + old_exclude = torch._C._dispatch_tls_is_dispatch_key_excluded( + torch._C.DispatchKey.ADInplaceOrView + ) + torch._C._dispatch_tls_set_dispatch_key_excluded( + torch._C.DispatchKey.ADInplaceOrView, False + ) + try: + if base.dtype == t.dtype: + pass + elif is_c_of_r(base.dtype, t.dtype): + base = self._checked_cast_tensor_t(torch.view_as_real(base)) + elif is_c_of_r(t.dtype, base.dtype): + base = self._checked_cast_tensor_t( + torch.view_as_complex(base) + ) + else: + # This is not guaranteed to succeed. If it fails, it + # means there is another dtype-converting view function + # that hasn't been handled here + base = self._checked_cast_tensor_t(base.view(t.dtype)) + + # This is very tricky. Naively, you might expect this + # to hold: + # + # if t.requires_grad and not safe_is_leaf(t) + # assert t._base.requires_grad + # + # But it's not true! As you can see in the following + # program: + # + # x = torch.zeros(4) + # y = x.view(1, 4) + # y.requires_grad = True + # z = y.view(1, 1, 4) + # assert z._base is x + # + # So we may have to do *two* views out of the base to + # recreate this situation. 
+ if t.is_leaf: + # Leaf views that track view metadata are created by + # creating a view inside a no_grad block + with torch.no_grad(): + r = view_from_base(base, t) + # As it's a leaf, we can directly assign requires_grad + r.requires_grad = t.requires_grad + else: + if t.base.requires_grad == t.requires_grad: + # Easy case, just run the view op + with torch.enable_grad(): + r = view_from_base(base, t) + + # NB: We don't actually faithfully replicate + # autograd connectivity, but that doesn't matter + # today. See following for more info: + # https://gist.github.com/soulitzer/e03f015b314c3f5fcf80888c69390913 + else: + # Obscure case. Create a leaf view and give it the + # correct requires_grad, then do the final view. + # NB: Can't have a non-leaf without requiring grad! + assert t.requires_grad + with torch.no_grad(), enable_python_dispatcher(): + mid = self._checked_cast_tensor_t( + base.view(base.shape) + ) + mid.requires_grad = t.requires_grad + with torch.enable_grad(): + r = view_from_base(mid, t) + # The CreationMeta influences whether or not inplace + # mutation is an error or not. So we need to make + # sure we properly propagate this as well. + assert t.creation_meta is not None + torch._C._autograd._set_creation_meta(r, t.creation_meta) + finally: + torch._C._dispatch_tls_set_dispatch_key_excluded( + torch._C.DispatchKey.ADInplaceOrView, old_exclude + ) + + r.fake_device = t.device # type: ignore[attr-defined] + + else: + is_leaf = t.is_leaf + + # Graph-Break for wrapped tensors + if ( + not (t.is_batchedtensor or t.is_gradtrackingtensor) + and t.is_functorch_wrapped + ) or t.is_legacy_batchedtensor: + # pyrefly: ignore [bad-return] + return NotImplemented + + ( + sizes, + strides, + storage_offset, + ) = sym_sizes_strides_storage_offset(t, source, symbolic_context) + + # If we have a subclass that desugars into dense tensors, + # perform our callback on each inner tensor. 
+ if t.is_traceable_wrapper_subclass: + r = empty_create_subclass( + t, outer_size=sizes, outer_stride=strides + ) + else: + r = callback( + lambda: torch.empty_strided( + sizes, + strides, + dtype=t.dtype, + device="meta", + ) + ) + if self.copy_data: + with torch.no_grad(), no_dispatch(): + assert t.size is not None + assert t.stride is not None + assert _is_fake_tensor(r) + r.real_tensor = torch.empty_strided( + t.size, t.stride, dtype=t.dtype, device=t.device + ) + _safe_copy(r.real_tensor, t.data) + + assert safe_is_leaf(r), "the callback you passed in doesn't detach" + if t.requires_grad: + r.requires_grad = t.requires_grad + if not is_leaf: + # Fake up some autograd history. + # Note: we *used* to call .clone() here to mock up some autograd history. + # This is bad for subclasses. + # Consider the case where you have a wrapper subclass that is contiguous, + # but its inner tensor is noncontiguous(). + # .clone() (or other ops) will have the side effect of changing + # the metadata of the inner tensor. + # So instead, we now have a dedicated fn to set autograd history, + # without inadvertently changing other metadata. + # pyrefly: ignore [bad-argument-type] + r = self._backward_error(r) + + s = t.storage + assert s is not None + if s.id not in self.storage_memo and ( + r.is_nested + or ( + r.stride() == strides + and r.storage_offset() == storage_offset + ) + ): + # You're normal and happy, install the fresh storage into the memo + self.set_storage_memo(s, r.untyped_storage()) + if self.copy_data: + assert _is_fake_tensor(r) + assert r.real_tensor is not None + _set_real_storage( + r.untyped_storage(), r.real_tensor.untyped_storage() + ) + else: + # You're in crazy town; somehow you gave us a tensor + # that wasn't a view, but had nonzero storage offset, + # nontrivial strides (such that clone() couldn't + # preserve them), or already aliases with another + # tensor's storage. The most typical way to end + # up here is with set_. 
So use set_ to bludgeon this + # in. + r_s = self.meta_storage(s, callback=callback) + # NB: In principle, this should always work, but there + # is some subtle difference in the autograd metadata + # that means we will backprop the set_ call, even if + # r is declared as an input to grad. + # See https://github.com/pytorch/pytorch/issues/87956 + # for the reproducer. + # NB: The in_kernel_invocation_manager here is necessary + # for fake tensor. If we run the set_ call with fake + # tensor on, r will improperly report that it is NOT a + # meta tensor but a cpu tensor, and then the set_ call + # will fail due to device mismatch. no_dispatch() is + # not enough, because the fake tensor will still claim + # to be a CPU tensor and you'll end up in the CPU + # kernel. Arguably this is a hack; a cleaner way to + # solve this is to have a FakeStorage concept which + # would report it's CPU device--no problem now! But + # this is difficult to do because we don't have storage + # subclasses. Relevant test is + # DynamicShapesFunctionTests::test_add_dynamic_shapes in + # test/dynamo/test_dynamic_shapes.py + maybe_fake_mgr: AbstractContextManager[None] = ( + contextlib.nullcontext() + ) + from torch._subclasses.fake_tensor import ( + in_kernel_invocation_manager, + maybe_get_fake_mode, + ) + + mb_fake_mode = maybe_get_fake_mode(r) + if mb_fake_mode is not None: + maybe_fake_mgr = in_kernel_invocation_manager(mb_fake_mode) + with torch.no_grad(), maybe_suppress(): + with maybe_fake_mgr: + r.set_(r_s, storage_offset, sizes, strides) + if self.copy_data: + with torch.no_grad(), no_dispatch(): + assert _is_fake_tensor(r) + assert r.real_tensor is not None + assert t.stride is not None + r.real_tensor.set_( + _get_real_storage(r_s), + t.storage_offset, + t.size, + t.stride, + ) + + if t.grad is not None: + from torch._dynamo.source import AttrSource + + # TODO: Use a valid grad-specific symbolic context instead of recycling + # the one from t. This isn't correct if e.g. 
t._is_view() != t.grad._is_view(). + # pyrefly: ignore [unbound-name] + r.grad = self.meta_tensor( + t.grad, + shape_env, + callback, + AttrSource(source, "grad"), + symbolic_context, + ) + # pyrefly: ignore [unbound-name] + torch._C._set_conj(r, t.is_conj) + # pyrefly: ignore [unbound-name] + torch._C._set_neg(r, t.is_neg) + # This can be skipped if necessary for performance reasons + skip_leaf = ( + t.is_gradtrackingtensor and t.level == GRAD_TENSOR_SENTINEL_VALUE + ) + # pyrefly: ignore [unbound-name] + assert_metadata_eq(assert_eq, t, r, skip_symbolic=True, skip_leaf=skip_leaf) + # Thanks to storage resizing, it's possible to end up with a tensor + # that advertises a real size, but has a storage that actually has zero bytes. + # Need to reflect this in the generated FakeTensor. + from torch.fx.experimental.symbolic_shapes import guard_or_false + + if t.storage is not None and guard_or_false(t.storage.size == 0): + # pyrefly: ignore [unbound-name] + r.untyped_storage().resize_(0) + + if t.is_parameter: + # pyrefly: ignore [unbound-name] + r._is_param = True + + # See Note: [Creating symbolic nested int] + if t.nested_int is not None: + # pyrefly: ignore [unbound-name] + assert _is_fake_tensor(r) + # pyrefly: ignore [unbound-name] + r.nested_int_memo = r.fake_mode.create_symbolic_nested_int( + nt_tensor_id=t.nested_int + ) + + # pyrefly: ignore [bad-argument-type, unbound-name] + self.set_tensor_memo(t, r) + + return self._checked_get_tensor_memo(t) + + def __call__( + self, + t: torch.Tensor, + shape_env: Optional[ShapeEnv] = None, + *, + callback: Optional[_MetaTensorCallback[_TensorT]] = None, + source: Optional[Source] = None, + symbolic_context: Optional[SymbolicContext] = None, + # Controls whether or not we should dump the tensor metadata to structured logs + # when source is not None. Because we refakify after Dynamo is done, + # we don't want to dump info again from AOTAutograd, it is redundant. 
+ trace: bool = True, + ) -> _TensorT: + callback_: _MetaTensorCallback[_TensorT] + if callback is None: + callback_ = self._identity_callable + else: + callback_ = callback + # TODO: zero tensors? We appear to have eliminated them by + # excluding complex for now + + # Filter out cases we don't support + # TODO: This can probably be simplified quite a bit + if isinstance(t, torch.Tensor): + if ( + # Lazy tensors are not supported. Note that XLA is + # implemented on top of lazy tensor, not excluded here; we + # have some special handling for it; this is for XLA Dynamo + # integration + t.device.type == "lazy" + or + # Quantization is not supported + t.is_quantized + or + # Views out of sparse tensors not currently supported (plain + # sparse is supported htough) + (t._is_view() and t._base is not None and t._base.is_sparse) + ): + self.miss += 1 + # pyrefly: ignore [bad-return] + return NotImplemented + else: + self.hit += 1 + elif torch.overrides.is_tensor_like(t): + self.miss += 1 + # pyrefly: ignore [bad-return] + return NotImplemented + else: + # non-Tensor types don't count as hit or miss + return t + + if source is None: + trace = False + + # Describe the tensor. NB: do NOT disable ambient modes, we may need + # to query them when figuring out what to put in here + t_desc = self.describer.describe_tensor(t, trace=trace) + + if trace: + assert source is not None + trace_structured( + "describe_source", + metadata_fn=lambda: { + "describer_id": self.describer.id, + "id": t_desc.id, + "source": source.name, + }, + ) + + # Do the meta-fication. 
Here, we disable all the ambient modes, to + # better simulate what would be like to re-fakeify from a fresh + # process + with contextlib.ExitStack() as exit_stack: + exit_stack.enter_context(torch._dispatch.python.suspend_functionalization()) + st = peek_interpreter_stack() + if st is not None: + exit_stack.enter_context( + torch._functorch.pyfunctorch.temporarily_clear_interpreter_stack() + ) + + r = self.meta_tensor( + t_desc, + shape_env, + callback_, + source, + symbolic_context, + ) + + if type(t) is torch.nn.Parameter: + # NB: Cannot directly use Parameter constructor + # because that would force a detach, not desirable + r._is_param = True + + # TODO: return the description for later + return r + + +import torch._prims_common as utils diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_subclasses/schema_check_mode.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_subclasses/schema_check_mode.py new file mode 100644 index 0000000000000000000000000000000000000000..28bbb8f335ec0c98d2fb6688425309da45e718c9 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/_subclasses/schema_check_mode.py @@ -0,0 +1,230 @@ +# mypy: ignore-errors + +from collections import namedtuple +from copy import deepcopy +from itertools import combinations + +import torch +from torch.fx.operator_schemas import normalize_function +from torch.utils import _pytree as pytree +from torch.utils._python_dispatch import TorchDispatchMode +from torch.utils._pytree import tree_map + + +# Named Tuples used within SchemaCheckMode +Mutation = namedtuple("Mutation", ["op_name", "arg_name"]) +Aliasing = namedtuple("Aliasing", ["op_name", "arg_name", "output_number"]) + +# Simplified naming for C++ classes +SchemaArgument = torch._C._SchemaArgument +SchemaArgType = torch._C._SchemaArgType +SchemaInfo = torch._C._SchemaInfo + +# This TorchDispatchMode Subclass is used to verify op schemas +# This TorchDispatchMode Scubclass 
currently: +# - Records the called ops +# - Checks for mutations on all inputs +# - Checks for aliasing on all inputs + + +# move these 2 functions here to avoid numpy dependency in testing/_internal/common_utils.py + + +def is_iterable_of_tensors(iterable): + # Tensor itself is iterable so we check this first + if isinstance(iterable, torch.Tensor): + return False + try: + if len(iterable) == 0: + return False + for t in iter(iterable): + if not isinstance(t, torch.Tensor): + return False + except TypeError: + return False + return True + + +def clone_inputs(args): + inputs = [] + + for arg in args: + if isinstance(arg, torch.Tensor): + inputs.append(arg.detach().clone()) + elif is_iterable_of_tensors(arg): + inputs.append([t.detach().clone() for t in arg]) + else: + inputs.append(arg) + + return inputs + + +class SchemaCheckMode(TorchDispatchMode): + def __init__(self) -> None: + # Information recorded for testing purposes. For example: + # - incorrect schemas + # - overly conservative schemas + self.ops = [] + self.mutated = [] + self.aliasing = [] + + def reset_cache(self): + self.ops.clear() + self.mutated.clear() + self.aliasing.clear() + + def display_ops(self): + print(*self.ops, sep=",") + + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + def bitwise_equal(lhs, rhs): + if lhs.is_quantized: + # TODO: This is only OK if can't have NaN quantized; idk if + # this is actually true + return torch.equal(lhs, rhs) + else: + return torch.allclose(lhs, rhs, equal_nan=True) + + def has_mutated(before, after, md): + are_tensors = type(before) is torch.Tensor and type(after) is torch.Tensor + if ( + are_tensors + and before.layout != torch.sparse_csr + and after.layout != torch.sparse_csr + ): + return not ( + before.size() == after.size() + and bitwise_equal(before, after) + and md[0] == after.stride() + and md[1] == after._typed_storage()._cdata + ) + return False + + def has_aliased(lhs, rhs): + try: + return torch._C._overlaps(lhs, rhs) + except 
Exception as exception: + if str(exception).startswith("Cannot inspect value of type "): + return False + else: + raise exception + + def standardize_name(name): + return name if name != "self" else "input" + + def unwrap(e): + if isinstance(e, torch.Tensor) and type(e) is not torch.Tensor: + try: + return e.elem + except AttributeError: + return e + return e + + def parse_metadata(e): + if isinstance(e, torch.Tensor): + if type(e) is not torch.Tensor: + try: + current = e.elem + return ( + deepcopy(current.stride()), + current._typed_storage()._cdata, + ) + except AttributeError: + return None + # Sparse CSR tensors do not have strides or storage + elif e.layout != torch.sparse_csr: + return (deepcopy(e.stride()), e._typed_storage()._cdata) + return None + + self.ops.append(func._schema.name) + + # Clone and process arguments and outputs + pre_arguments = normalize_function( + func, args, kwargs, normalize_to_only_use_kwargs=True + ).kwargs + + c_p_args = dict(zip(pre_arguments.keys(), clone_inputs(pre_arguments.values()))) + cloned_arguments = { + name: tree_map(unwrap, c_p_args.get(name)) for name in c_p_args + } + cloned_metadata = { + name: [ + parse_metadata(a) for a in pytree.tree_leaves(pre_arguments.get(name)) + ] + for name in pre_arguments + } + + out = func(*args, **kwargs) + arguments = { + name: tree_map(unwrap, pre_arguments.get(name)) for name in pre_arguments + } + tuple_out = out if isinstance(out, tuple) else (out,) + tuple_out = tree_map(unwrap, tuple_out) + + schema_info = SchemaInfo(func._schema) + schema_info.add_argument_values(pre_arguments) + + # Process arguments with outputs + for i in range(len(func._schema.arguments)): + arg = func._schema.arguments[i] + name = standardize_name(arg.name) + if arguments.get(name) is not None: + before = cloned_arguments.get(name) + md = cloned_metadata.get(name) + after = arguments.get(name) + for j in range(len(tuple_out)): + # aten::_unsafe_view is intended to have incorrect aliasing notation (hence 
unsafe) + unsafe_ops = ("aten::_unsafe_view", "aten::unsafe_split") + if ( + has_aliased(tuple_out[j], after) + and func._schema.name not in unsafe_ops + ): + if not schema_info.may_contain_alias( + SchemaArgument(SchemaArgType.output, j), + SchemaArgument(SchemaArgType.input, i), + ): + raise RuntimeError( + f"Argument {name} is not defined to alias output but was aliasing" + ) + else: + self.aliasing.append( + Aliasing(func._schema.name, name, f"output_{j}") + ) + if after is tuple_out[j] and isinstance(after, torch.Tensor): + # Only mutable ops e.g. (add_, add.out) are allowed to directly return inputs. + if not schema_info.is_mutable( + SchemaArgument(SchemaArgType.input, i) + ) and func not in [ + torch.ops.aten.lift.default, + torch.ops.aten.lift_fresh.default, + ]: + raise RuntimeError( + f"""\ +Dispatcher operators below autograd are not allowed to directly return inputs. +However, we found that `outputs[{str(j)}] is {name}""" + ) + if any( + has_mutated(a, b, c) + for a, b, c in zip( + pytree.tree_leaves(before), pytree.tree_leaves(after), md + ) + ): + if not schema_info.is_mutable( + SchemaArgument(SchemaArgType.input, i) + ): + raise RuntimeError( + f"Argument {name} is not defined as mutable but was mutated" + ) + else: + self.mutated.append(Mutation(func._schema.name, name)) + + # Aliasing between outputs + for i, j in combinations(range(len(func._schema.returns)), 2): + if has_aliased(tuple_out[i], tuple_out[j]): + if not schema_info.may_contain_alias( + SchemaArgument(SchemaArgType.output, i), + SchemaArgument(SchemaArgType.output, j), + ): + raise RuntimeError(f"Outputs {i} and {j} alias unexpectedly") + + return out diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/accelerator/__pycache__/__init__.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/accelerator/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..657e7c35a9dd85a8682eed7f2544328043be0cc1 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/accelerator/__pycache__/__init__.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/accelerator/__pycache__/_utils.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/accelerator/__pycache__/_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e975f89ccdc69648d3efa8c155bda1dae87bab7a Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/accelerator/__pycache__/_utils.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/accelerator/__pycache__/memory.cpython-312.pyc b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/accelerator/__pycache__/memory.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4d2d85e97176dcc602659e7bffe39b5e0548db47 Binary files /dev/null and b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/accelerator/__pycache__/memory.cpython-312.pyc differ diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/compiler/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/compiler/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..442cd7d765b89136483e733f010c2a82d0fff18f --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/compiler/__init__.py @@ -0,0 +1,691 @@ +# mypy: allow-untyped-defs +import io +from collections.abc import Callable +from typing import Any, Optional, TYPE_CHECKING, TypeVar, Union +from typing_extensions import ParamSpec + +import torch + +from . 
import config + + +if TYPE_CHECKING: + from ._cache import CacheInfo + + +__all__ = [ + "compile", + "config", + "assume_constant_result", + "reset", + "allow_in_graph", + "substitute_in_graph", + "list_backends", + "disable", + "set_stance", + "set_enable_guard_collectives", + "cudagraph_mark_step_begin", + "load_compiled_function", + "wrap_numpy", + "is_compiling", + "is_dynamo_compiling", + "is_exporting", + "save_cache_artifacts", + "load_cache_artifacts", + "skip_guard_on_inbuilt_nn_modules_unsafe", + "skip_guard_on_all_nn_modules_unsafe", + "keep_tensor_guards_unsafe", + "skip_guard_on_globals_unsafe", + "skip_all_guards_unsafe", + "nested_compile_region", +] + + +_P = ParamSpec("_P") +_R = TypeVar("_R") +FuncType = Callable[..., Any] +F = TypeVar("F", bound=FuncType) + + +def compile(*args, **kwargs): + """ + See :func:`torch.compile` for details on the arguments for this function. + """ + # pyrefly: ignore [not-iterable] + return torch.compile(*args, **kwargs) + + +def reset() -> None: + """ + This function clears all compilation caches and restores the system to its initial state. + It is recommended to call this function, especially after using operations like `torch.compile(...)` + to ensure a clean state before another unrelated compilation + """ + import torch._dynamo + + torch._dynamo.reset() + + +def allow_in_graph(fn): + """ + Tells the compiler frontend (Dynamo) to skip symbolic introspection of the function + and instead directly write it to the graph when encountered. + + If you are using :func:`torch.compile` (with backend="inductor" (the default)), or + :func:`torch.export.export`, and trying to black-box a Python function throughout + all tracing, do not use this API. + Instead, please create a custom operator (see `PyTorch Custom Operators Landing Page + `_) + + .. warning:: + + If you're a typical torch.compile user (e.g. you're applying torch.compile to + a model to make it run faster), you probably don't want to use this function. 
+ :func:`allow_in_graph` is a footgun because it skips the compiler frontend + (Dynamo) that is responsible for doing safety checks (graph breaks, handling + closures, etc). Incorrect usage will lead to difficult-to-debug silent + incorrectness issues. + + Given a Python function with no allow_in_graph decorator, regular execution + of torch.compile traces through the function. :func:`allow_in_graph` changes + it so that the frontend does not trace inside the function, but the compiler + backend still traces through it. Compare this to custom operators, which + treats a function as a black box throughout the torch.compile stack. The following + table compares these mechanisms. + + +------------------------+-----------------------+--------------------------------+ + | Mechanism | Frontend (Dynamo) | Backend (AOTAutograd+Inductor) | + +========================+=======================+================================+ + | no decorator | trace inside | trace inside | + +------------------------+-----------------------+--------------------------------+ + | allow_in_graph | opaque callable | trace inside | + +------------------------+-----------------------+--------------------------------+ + | custom op | opaque callable | opaque callable | + +------------------------+-----------------------+--------------------------------+ + + One common use case for :func:`allow_in_graph()` is as an escape hatch for the compiler + frontend: if you know the function works w.r.t. to the downstream components of the + compilation stack (AOTAutograd and Inductor) but there is a Dynamo bug that prevents it from + symbolically introspecting the function properly (or if your code is in C/C++ and + therefore cannot be introspected with Dynamo), then one can decorate said function + with :func:`allow_in_graph` to bypass Dynamo. + + We require that ``fn`` adhere to the following restrictions. 
Failure to adhere + results in undefined behavior: + + - The inputs to ``fn`` must be Proxy-able types in the FX graph. Valid types include: + Tensor/int/bool/float/None/List[Tensor?]/List[int?]/List[float?] + Tuple[Tensor?, ...]/Tuple[int?, ...]/Tuple[float?, ...]/torch.dtype/torch.device + - The outputs to ``fn`` must be Proxy-able types in the FX graph (see previous bullet) + - all Tensors used inside of ``fn`` must be passed directly as inputs to ``fn`` + (as opposed to being captured variables). + + Args: + fn: A callable representing the function to be included in the graph. + If ``fn`` is a list or tuple of callables it recursively applies + :func:`allow_in_graph()` to each function and returns a new list or + tuple containing the modified functions. + + Example:: + + torch.compiler.allow_in_graph(my_custom_function) + + + @torch.compile(...) + def fn(x): + x = torch.add(x, 1) + x = my_custom_function(x) + x = torch.add(x, 1) + return x + + + fn(...) + + Will capture a single graph containing ``my_custom_function()``. + + """ + import torch._dynamo + + return torch._dynamo.allow_in_graph(fn) + + +def substitute_in_graph( + original_fn: Callable[_P, _R], + *, + can_constant_fold_through: bool = False, + skip_signature_check: bool = False, +) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]: + """ + Register a polyfill handler for a function, usually a C function from the C extension, to be + used in place of the original function when inlining the original function in the graph. + + .. note:: + + The polyfill handler is only used when inlining the original function. It is not used when + the original function is called directly. In the eager mode, the decorated function calls + the performant C function rather than the polyfill handler. + + The polyfill handler is a function that will be called in place of the original function when + inlining the original function. 
The polyfill handler should have the same signature and the same + behavior as the original function. + + Args: + original_fn (callable): The original function, usually a C function, to register a polyfill + handler for. + can_constant_fold_through (bool, optional): Whether the polyfill handler can be constant + folded through. That is, if the polyfill handler is a pure function and its arguments + are constant, the result of the polyfill handler can be constant folded during the + compilation. Defaults to ``False``. + skip_signature_check (bool, optional): Whether to skip the signature check between the + original function and the polyfill handler. Defaults to ``False``. + + Returns: + A decorator that registers the polyfill handler for the original function. + + Example:: + + >>> import operator + >>> operator.indexOf([1, 2, 3, 4, 5], 3) + 2 + >>> torch.compile(operator.indexOf, fullgraph=True)([1, 2, 3, 4, 5], 3) + ... # xdoctest: +SKIP("Long tracebacks") + Traceback (most recent call last): + ... + torch._dynamo.exc.Unsupported: ... + + >>> @torch.compiler.substitute_in_graph(operator.indexOf) + ... def indexOf(a, b, /): + ... for i, item in enumerate(a): + ... if item is b or item == b: + ... return i + ... raise ValueError("sequence.index(x): x not in sequence") + >>> + >>> torch.compile(operator.indexOf, fullgraph=True)([1, 2, 3, 4, 5], 3) + 2 + """ + import torch._dynamo + + return torch._dynamo.substitute_in_graph( + original_fn, + can_constant_fold_through=can_constant_fold_through, + skip_signature_check=skip_signature_check, + ) + + +def list_backends(exclude_tags=("debug", "experimental")) -> list[str]: + """ + Return valid strings that can be passed to `torch.compile(..., backend="name")`. + + Args: + exclude_tags(optional): A tuple of strings representing tags to exclude. 
+ """ + import torch._dynamo + + return torch._dynamo.list_backends(exclude_tags) + + +def assume_constant_result(fn): + """ + This function is used to mark a function `fn` as having a constant result. + This allows the compiler to optimize away your function. + Returns The same function `fn` + + Args: + fn: The function to be marked as having a constant result. + + .. warning:: + `assume_constant_result` can if invalid cause safety and soundness issues, :func:`torch.compile` + will not attempt to validate whether the constant assumption is true or not + + """ + import torch._dynamo + + return torch._dynamo.assume_constant_result(fn) + + +def disable(fn=None, recursive=True, *, reason=None): + """ + This function provides a decorator to disable compilation on a function. + It also provides the option of recursively disabling called functions. + + Args: + fn (optional): The function to disable + recursive (optional): A boolean value indicating whether the disabling should be recursive. + reason (optional): A string value indicating the reason for disabling the function. + """ + import torch._dynamo + + return torch._dynamo.disable(fn, recursive, reason=reason) + + +def set_stance( + stance: str = "default", + *, + skip_guard_eval_unsafe: bool = False, + force_backend: Union[str, Callable[..., Any], None] = None, +): + """ + Set the current stance of the compiler. + Can be used as a function, context manager, or decorator. + Do not use this function inside a `torch.compile` region - an error will be raised otherwise. + + .. code-block:: python + + @torch.compile + def foo(x): ... + + + @torch.compiler.set_stance("force_eager") + def bar(): + # will not be compiled + foo(...) + + + bar() + + with torch.compiler.set_stance("force_eager"): + # will also not be compiled + foo(...) + + torch.compiler.set_stance("force_eager") + # will also not be compiled + foo(...) + torch.compiler.set_stance("default") + + # will be compiled + foo(...) 
+ + Args: + stance: The stance to set the compiler to. Valid values are: + + - "default": The default stance, used for normal compilation. + - "force_eager": Ignore all `torch.compile` directives. + - "eager_on_recompile": Run code eagerly when a recompile is necessary. + If there is cached compiled code valid for the input, it will still be used. + - "fail_on_recompile": Raise an error when recompiling a function. + - "eager_then_compile": Run the first invocation in eager mode, then compile on + subsequent calls. This is beneficial for dynamic shapes as it allows inferring + dynamism from the first two invocations instead of wasting a static compile on + the first invocation. + - "aot_eager_then_compile": Run the first invocation with AOT eager to get memory + benefits from activation checkpointing, then compile on subsequent calls. Like + eager_then_compile, this improves handling of dynamic shapes by avoiding an + initial static compile. + + + skip_guard_eval_unsafe: A flag to run only differentiating guards. + CAUTION - This flag is unsafe and should only be used if your setup + meets the following conditions. + + torch.compile uses a guard system to support recompilations and + choose which compiled artifact to run at runtime. These guards, + though efficient, add some overhead, which may impact performance in + scenarios where you need to optimize for minimal guard processing + time. This API enables you to disable guard evaluation, assuming + that you have warmed up the compiled model with a sufficient variety + of inputs. This assumption means that, after the warmup phase, no + further recompilations will be necessary. If this assumption fails, + there is a risk of silently producing incorrect results (hence the + term "unsafe" in the API name). + + force_backend: If `stance` is "default", this argument can be used to force `torch.compile` + to use a specific backend. Otherwise, an error is raised. 
+ """ + import torch._dynamo + + return torch._dynamo.set_stance( + stance, + skip_guard_eval_unsafe=skip_guard_eval_unsafe, + force_backend=force_backend, + ) + + +# forbid in graph +set_stance._dynamo_forbidden = True # type: ignore[attr-defined] + + +def set_enable_guard_collectives(enabled: bool): + """ + Enables use of collectives *during* guard evaluation to synchronize behavior + across ranks. This is expensive: we have to issue a collective every time + we enter a compiled code region, even if no rank actually would need to + compile. This can help prevent NCCL hangs by ensuring that we never have a + situation where one rank starts recompiling while other ranks don't compile; + it is especially useful in conjunction with enable_compiler_collectives + where such a situation would immediately cause a hang (as it is necessary + for all ranks to compile at the same time to run compiler collectives). Like + compiler collectives, you can only run this on SPMD programs; you will hang + otherwise. Note that a guard collective is only issued if there is any + compiled code to guard on; if this the first time we encounter a frame or + the frame is skipped, we don't issue collectives. + + Returns the previous setting of enabled. + """ + from torch._C._dynamo.eval_frame import set_guard_complete_hook # noqa: F401 + from torch._dynamo.eval_frame import guard_collectives_hook + + if enabled: + return set_guard_complete_hook(guard_collectives_hook) is not None # type: ignore[arg-type] + else: + return set_guard_complete_hook(None) is not None + + +set_enable_guard_collectives._dynamo_forbidden = True # type: ignore[attr-defined] + + +def cudagraph_mark_step_begin(): + """ + Indicates that a new iteration of inference or training is about to begin. + + CUDA Graphs will free tensors of a prior iteration. A new iteration is started on each invocation of + torch.compile, so long as there is not a pending backward that has not been called. 
+ + If that heuristic is wrong, such as in the following example, manually mark it with this api. + + .. code-block:: python + + @torch.compile(mode="reduce-overhead") + def rand_foo(): + return torch.rand([4], device="cuda") + + + for _ in range(5): + torch.compiler.cudagraph_mark_step_begin() + rand_foo() + rand_foo() + + For more details, see `torch.compiler_cudagraph_trees `__ + """ + from torch._inductor import cudagraph_trees + + cudagraph_trees.mark_step_begin() + + +def wrap_numpy(fn): + r"""Decorator that turns a function from ``np.ndarray``s to ``np.ndarray``s into a function + from ``torch.Tensor``s to ``torch.Tensor``s. + + It is designed to be used with :func:`torch.compile` with ``fullgraph=True``. It allows to + compile a NumPy function as if it were a PyTorch function. This allows you to run NumPy code + on CUDA or compute its gradients. + + .. note:: + + This decorator does not work without :func:`torch.compile`. + + Example:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) + >>> # Compile a NumPy function as a Tensor -> Tensor function + >>> @torch.compile(fullgraph=True) + >>> @torch.compiler.wrap_numpy + >>> def fn(a: np.ndarray): + >>> return np.sum(a * a) + >>> # Execute the NumPy function using Tensors on CUDA and compute the gradients + >>> x = torch.arange(6, dtype=torch.float32, device="cuda", requires_grad=True) + >>> out = fn(x) + >>> out.backward() + >>> print(x.grad) + tensor([ 0., 2., 4., 6., 8., 10.], device='cuda:0') + """ + from torch._dynamo.external_utils import wrap_numpy as wrap + + return wrap(fn) + + +_is_compiling_flag: bool = False +_is_exporting_flag: bool = False + + +def is_compiling() -> bool: + """ + Indicates whether a graph is executed/traced as part of torch.compile() or torch.export(). 
+ + Note that there are 2 other related flags that should deprecated eventually: + * torch._dynamo.external_utils.is_compiling() + * torch._utils.is_compiling() + + Example:: + + >>> def forward(self, x): + >>> if not torch.compiler.is_compiling(): + >>> pass # ...logic that is not needed in a compiled/traced graph... + >>> + >>> # ...rest of the function... + """ + if torch.jit.is_scripting(): + return False + else: + return _is_compiling_flag + + +def is_dynamo_compiling() -> bool: + """ + Indicates whether a graph is traced via TorchDynamo. + + It's stricter than is_compiling() flag, as it would only be set to True when + TorchDynamo is used. + + Example:: + + >>> def forward(self, x): + >>> if not torch.compiler.is_dynamo_compiling(): + >>> pass # ...logic that is not needed in a TorchDynamo-traced graph... + >>> + >>> # ...rest of the function... + """ + return False + + +def is_exporting() -> bool: + """ + Indicated whether we're under exporting. + + It's stricter than is_compiling() flag, as it would only be set to True when + torch.export is used. + + Example:: + + >>> def forward(self, x): + >>> if not torch.compiler.is_exporting(): + >>> pass # ...logic that is not needed in export... + >>> + >>> # ...rest of the function... 
+ """ + return _is_exporting_flag + + +def save_cache_artifacts() -> Optional[tuple[bytes, "CacheInfo"]]: + """ + Serializes all the cache artifacts that were created during the compilation + + Example: + + - Execute torch.compile + - Call torch.compiler.save_cache_artifacts() + """ + from ._cache import CacheArtifactManager + + if torch._dynamo.config.caching_precompile: + from torch._dynamo.precompile_context import PrecompileContext + + PrecompileContext.save_to_dynamo_cache() + + return CacheArtifactManager.serialize() + + +def load_cache_artifacts(serialized_artifacts: bytes) -> Optional["CacheInfo"]: + """ + Hot loads cache artifacts that were previously serialized via + save_cache_artifacts + + Example: + + # From a previous invocation + artifacts = torch.compiler.save_cache_artifacts() + + torch.compiler.load_cache_artifacts(artifacts[0]) + """ + from ._cache import CacheArtifactManager, CacheInfo + + artifacts = CacheArtifactManager.deserialize(serialized_artifacts) + if artifacts is not None: + return CacheArtifactManager.populate_caches(artifacts) + return None + + +def skip_guard_on_inbuilt_nn_modules_unsafe(guard_entries): + """ + A common function to skip guards on the inbuilt nn modules like + torch.nn.Linear. This is unsafe to use by default. But for majority of + torch.compile users, the model code does not modify the inbuilt nn module + attributes. They can benefit from reduction in guard latency overhead using + this API. + + To use this API, use guard_filter_fn argument while calling torch.compile + + >> opt_mod = torch.compile( + >> mod, + >> options={"guard_filter_fn": torch.compiler.skip_guard_on_all_nn_modules_unsafe}, + >> ) + """ + return [ + not entry.orig_guard.source.is_unspecialized_builtin_nn_module() + for entry in guard_entries + ] + + +def skip_guard_on_all_nn_modules_unsafe(guard_entries): + """ + A common function to skip guards on all nn modules, both user defined as + well inbuilt nn modules (like torch.nn.Linear). 
This is unsafe to use by + default. But for majority of torch.compile users, the model code does not + modify the nn module attributes. They can benefit from reduction in guard + latency overhead using this API. + + To use this API, use guard_filter_fn argument while calling torch.compile + + >> opt_mod = torch.compile( + >> mod, + >> options={"guard_filter_fn": torch.compiler.skip_guard_on_all_nn_modules_unsafe}, + >> ) + """ + + return [ + not entry.orig_guard.source.is_unspecialized_nn_module() + for entry in guard_entries + ] + + +def keep_tensor_guards_unsafe(guard_entries, keep_parameters=False): + """ + A common function to keep tensor guards on all tensors. This is unsafe to + use by default. But if you don't expect any changes in the model code, you + can just keep the tensor guards. + + + >> opt_mod = torch.compile( + >> mod, + >> options={"guard_filter_fn": torch.compiler.keep_tensor_guards}, + >> ) + """ + + keep_flags = [] + for entry in guard_entries: + if entry.guard_type == "TENSOR_MATCH": + if not isinstance(entry.value, torch.nn.Parameter): + keep_flags.append(True) + elif keep_parameters: + keep_flags.append(True) + else: + keep_flags.append(False) + else: + keep_flags.append(False) + return keep_flags + + +def skip_guard_on_globals_unsafe(guard_entries): + """ + A common function to skip guards on all globals. This is unsafe to use by + default. But if you don't expect any changes in the globals, you can just + keep the tensor guards. + + >> opt_mod = torch.compile( + >> mod, + >> options={"guard_filter_fn": torch.compiler.skip_guard_on_globals}, + >> ) + """ + + return [not entry.is_global for entry in guard_entries] + + +def skip_all_guards_unsafe(guard_entries): + """ + A function for skipping all guards on a compiled function. + + WARNING: This function will drop all the safety guarantees from Dynamo + compiled function. Use this with caution. 
+ + To use this API, use guard_filter_fn argument while calling torch.compile + + >> opt_mod = torch.compile( + >> mod, + >> options={"guard_filter_fn": torch.compiler.skip_all_guards_unsafe}, + >> ) + """ + return [False for entry in guard_entries] + + +def nested_compile_region(fn=None): + """ + Tells **``torch.compile``** that the marked set of operations forms a nested + compile region (which is often repeated in the full model) whose code can be + compiled once and safely reused. ``nested_compile_region`` can also be used + as a decorator. + + During **``torch.compile``** tracing, the compiler applies *hierarchical + compilation* with ``nested_compile_region``: it emits optimized code for the + marked region the first time it is encountered and re-emits (or “stamps + out”) the previously compiled code on every subsequent invocation. This can + substantially reduce overall compile time for deeply-stacked, + structurally-identical components such as the transformer layers of a + large-language-model (LLM). + + Outside a ``torch.compile`` context—i.e., in standard eager execution—the + call is a no-op, so existing workflows remain unaffected. + + Note that ``nested_compile_region`` **does not** promise that a region will + be compiled exactly once. If the compiler detects that new input conditions + (shape, dtype, device, stride, globals etc.) make the cached version invalid + to reuse, it will transparently re-compile the region. Using it is + therefore *safe*: correctness is always preserved, and you pay the extra + compilation cost only when required. + """ + + from torch._higher_order_ops.invoke_subgraph import ( + mark_compile_region as _mark_compile_region, + ) + + return _mark_compile_region(fn) + + +def load_compiled_function( + file: io.IOBase, *, f_globals: Optional[dict[str, object]] = None +) -> Callable[..., Any]: + """ + Load an aot-compiled function from a file. + + .. warning:: + + This API is currently experimental and subject to change. 
+ + Args: + file: A file-like object containing the serialized compiled function. + f_globals: Optional globals to be loaded into the compiled function. + + Returns: + A torch-compiled function with compilation preloaded from disk. + """ + from torch._dynamo.aot_compile import AOTCompiledFunction + + data = file.read() + return AOTCompiledFunction.deserialize(data, f_globals) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/compiler/_cache.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/compiler/_cache.py new file mode 100644 index 0000000000000000000000000000000000000000..b525438d1bb5b74b7c6fe3c2e2df06366ff4bd7e --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/compiler/_cache.py @@ -0,0 +1,322 @@ +import copy +import dataclasses +import logging +from abc import ABC, abstractmethod +from collections import defaultdict +from collections.abc import Generator +from contextlib import contextmanager +from itertools import chain +from typing import Any, Optional + +from torch.utils._appending_byte_serializer import ( + AppendingByteSerializer, + BytesReader, + BytesWriter, +) +from torch.utils._ordered_set import OrderedSet + + +log = logging.getLogger(__name__) + + +@dataclasses.dataclass(frozen=True) +class CacheArtifact(ABC): + """ + Data for each cache artifact that will be serialized and deserialized + """ + + key: str + content: bytes = dataclasses.field(repr=False) # Do not display potential binary + + @staticmethod + def serialize(writer: BytesWriter, cls: "CacheArtifact") -> None: + writer.write_str(cls.key) + writer.write_bytes(cls.content) + + @staticmethod + def deserialize(artifact_type: str, reader: BytesReader) -> "CacheArtifact": + key = reader.read_str() + content = reader.read_bytes() + return CacheArtifactFactory.create(artifact_type, key, content) + + @staticmethod + def encode(content: Any) -> bytes: + assert isinstance(content, bytes), f"Expected bytes, got 
{type(content)}" + return content + + @abstractmethod + def populate_cache(self) -> None: + pass + + @staticmethod + def type() -> str: + """ + Returns the type of the artifact. Must be unique across all CacheArtifact classes. + + CacheArtifactFactory.register will add property method to CacheInfo based on this (def {type}_artifacts) + that returns all artifacts for specific cache. + """ + raise RuntimeError("CacheArtifact is an abstract class, please use a subclass") + + +class CacheArtifactFactory: + """ + Factory for creating CacheArtifact objects based on their type + """ + + _artifact_types: dict[str, type[CacheArtifact]] = {} + + @classmethod + def register(cls, artifact_cls: type[CacheArtifact]) -> type[CacheArtifact]: + artifact_type_key = artifact_cls.type() + assert artifact_cls.type() not in cls._artifact_types, ( + f"Artifact of type={artifact_type_key} already registered in mega-cache artifact factory" + ) + cls._artifact_types[artifact_type_key] = artifact_cls + setattr( + CacheInfo, + f"{artifact_type_key}_artifacts", + property(lambda self: self.artifacts[artifact_type_key]), + ) + return artifact_cls + + @classmethod + def _get_artifact_type(cls, artifact_type_key: str) -> type[CacheArtifact]: + assert artifact_type_key in cls._artifact_types, ( + f"Artifact of type={artifact_type_key} not registered in mega-cache artifact factory" + ) + return cls._artifact_types[artifact_type_key] + + @classmethod + def create(cls, artifact_type_key: str, key: str, content: bytes) -> CacheArtifact: + artifact_cls = cls._get_artifact_type(artifact_type_key) + # pyrefly: ignore [bad-instantiation] + return artifact_cls(key, content) + + @classmethod + def encode_create( + cls, artifact_type_key: str, key: str, content: Any + ) -> CacheArtifact: + artifact_cls = cls._get_artifact_type(artifact_type_key) + # pyrefly: ignore [bad-instantiation] + return artifact_cls(key, artifact_cls.encode(content)) + + +@dataclasses.dataclass +class CacheInfo: + """ + Return value 
of serialization and deserialization for the purpose of + instrumentation + """ + + artifacts: defaultdict[str, list[str]] = dataclasses.field( + default_factory=lambda: defaultdict(list) + ) + + # Methods set by CacheArtifactFactory.register based on CacheArtifact.type() + @property + def inductor_artifacts(self) -> list[str]: # type: ignore[empty-body] + ... + + @property + def autotune_artifacts(self) -> list[str]: # type: ignore[empty-body] + ... + + @property + def aot_autograd_artifacts(self) -> list[str]: # type: ignore[empty-body] + ... + + @property + def pgo_artifacts(self) -> list[str]: # type: ignore[empty-body] + ... + + @property + def precompile_artifacts(self) -> list[str]: # type: ignore[empty-body] + ... + + def add(self, artifact: CacheArtifact) -> None: + self.artifacts[artifact.type()].append(artifact.key) + + def clear(self) -> None: + self.artifacts.clear() + + def empty(self) -> bool: + return not self.artifacts + + +def _serialize_single_cache( + writer: BytesWriter, cls: "tuple[str, list[CacheArtifact]]" +) -> None: + writer.write_str(cls[0]) + writer.write_uint64(len(cls[1])) + for artifact in cls[1]: + CacheArtifact.serialize(writer, artifact) + + +def _deserialize_single_cache( + reader: BytesReader, +) -> "tuple[str, list[CacheArtifact]]": + artifacts = [] + artifact_type_key = reader.read_str() + num_artifacts = reader.read_uint64() + for _ in range(num_artifacts): + artifacts.append(CacheArtifact.deserialize(artifact_type_key, reader)) + + return artifact_type_key, artifacts + + +CacheArtifactsResult = dict[str, list[CacheArtifact]] + + +class CacheArtifactManager: + """ + Lightweight manager class for collecting and processing cache artifacts for + hot loading + + Intended Lifecycle: + - Execute code via torch.compile, this will call + CacheArtifactManager.record_artifact on each cache artifact + - Call CacheArtifactManager.serialize to convert all the cache artifacts + to portable format + - Call CacheArtifactManager.deserialize to 
hot load the cache artifacts on + a potentially different process + + NOTE: There's no FB/FC guarantees, results of cache artifacts will not be + used unless code version matches. + """ + + # Protected by the compile_lock + _new_cache_artifacts: CacheArtifactsResult = defaultdict(list) + # Keep a separate seen artifacts list to make avoid unnecessary duplicates + # This list will not be cleared between serialize() calls + _seen_artifacts: OrderedSet[CacheArtifact] = OrderedSet() + # When serialize() is called, artifacts are transferred from _cache_artifacts to + # internal data structure of the _serializer + # This allows us to only pay the cost of serialization if serialize() is called + _serializer: AppendingByteSerializer[tuple[str, list[CacheArtifact]]] = ( + AppendingByteSerializer(serialize_fn=_serialize_single_cache) + ) + _cache_info: CacheInfo = CacheInfo() + + @classmethod + def clear(cls) -> None: + cls._new_cache_artifacts.clear() + cls._seen_artifacts.clear() + cls._serializer.clear() + cls._cache_info.clear() + + @classmethod + @contextmanager + def with_fresh_cache(cls) -> Generator[None, None, None]: + original_new_cache_artifacts = cls._new_cache_artifacts + original_seen_artifacts = cls._seen_artifacts + original_serializer = cls._serializer + original_cache_info = cls._cache_info + + cls._new_cache_artifacts = defaultdict(list) + cls._seen_artifacts = OrderedSet() + cls._serializer = AppendingByteSerializer(serialize_fn=_serialize_single_cache) + cls._cache_info = cls._cache_info.__class__() + try: + yield + finally: + cls._new_cache_artifacts = original_new_cache_artifacts + cls._seen_artifacts = original_seen_artifacts + cls._serializer = original_serializer + cls._cache_info = original_cache_info + + @classmethod + def record_artifact( + cls, + artifact_type: str, + key: str, + content: Any, + ) -> None: + """ + Called from each caching operation to record the artifact in this + "mega" list + """ + artifact = 
CacheArtifactFactory.encode_create(artifact_type, key, content) + if artifact in cls._seen_artifacts: + return + log.debug("Recording %s", str(artifact)) + cls._new_cache_artifacts[artifact_type].append(artifact) + cls._seen_artifacts.add(artifact) + + @classmethod + def need_serialize(cls) -> bool: + """ + Have we seen new artifacts since last serialize call? + """ + return len(cls._new_cache_artifacts) != 0 + + @classmethod + def serialize(cls) -> Optional[tuple[bytes, CacheInfo]]: + """ + Converts the "mega" list into portable format + """ + for artifact in chain(*cls._new_cache_artifacts.values()): + log.debug("saving: %s", artifact) + cls._cache_info.add(artifact) + + if cls._cache_info.empty(): + # If there are not artifacts, dont just return bytes with + # version. + return None + + try: + # We deep copy cls._cache_info since later compilations + # can keep adding to cache_info + info = copy.deepcopy(cls._cache_info) + cls._serializer.extend(cls._new_cache_artifacts.items()) + artifact_bytes = cls._serializer.to_bytes() + cls._new_cache_artifacts.clear() + return artifact_bytes, info + except Exception: + log.warning("Failed to pickle cache artifacts", exc_info=True) + return None + + @staticmethod + def deserialize(serialized_artifacts: bytes) -> Optional[CacheArtifactsResult]: + """ + Converts the portable format back into CacheArtifacts + """ + try: + CacheArtifactManager._ensure_cache_artifacts_registered() + artifacts = dict( + AppendingByteSerializer.to_list( + serialized_artifacts, + deserialize_fn=_deserialize_single_cache, + ) + ) + except Exception: + log.warning("Failed to un-pickle cache artifacts", exc_info=True) + return None + + return artifacts + + @staticmethod + def populate_caches(artifacts: CacheArtifactsResult) -> CacheInfo: + info = CacheInfo() + for artifact in chain(*artifacts.values()): + log.debug("writing: %s", artifact) + info.add(artifact) + artifact.populate_cache() + + return info + + @classmethod + def 
_ensure_cache_artifacts_registered(cls) -> None: + """When deserializing caches in fresh process, we need to ensure that all + cache artifacts are registered in the cache registry. This is done by + simply importing all the cache artifacts already wrapped with register call. + """ + from torch._dynamo.package import PrecompileCacheArtifact # noqa: F401 + from torch._dynamo.pgo import PGOCacheArtifact # noqa: F401 + from torch._functorch._aot_autograd.autograd_cache import ( # noqa: F401 + AOTAutogradCacheArtifact, + ) + from torch._inductor.codecache import InductorCacheArtifact # noqa: F401 + from torch._inductor.runtime.autotune_cache import ( # noqa: F401 + AutotuneCacheArtifact, + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/compiler/config.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/compiler/config.py new file mode 100644 index 0000000000000000000000000000000000000000..e507ddc18052e2e9b53bd6de3a6001a1d17a1be0 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/compiler/config.py @@ -0,0 +1,281 @@ +""" +This is the top-level configuration module for the compiler, containing +cross-cutting configuration options that affect all parts of the compiler +stack. 
+ +You may also be interested in the per-component configuration modules, which +contain configuration options that affect only a specific part of the compiler: + +* :mod:`torch._dynamo.config` +* :mod:`torch._inductor.config` +* :mod:`torch._functorch.config` +* :mod:`torch.fx.experimental.config` +""" + +import sys +from typing import Optional + +from torch.utils._config_module import Config, install_config_module + + +__all__ = [ + "job_id", + "dynamic_shapes", + "assume_static_by_default", + "automatic_dynamic_shapes", + "recompile_limit", + "accumulated_recompile_limit", + "verbose", + "capture_scalar_outputs", + "capture_dynamic_output_shape_ops", + "log_file_name", + "fail_on_recompile_limit_hit", + "allow_unspec_int_on_nn_module", + "skip_tensor_guards_with_matching_dict_tags", + "enable_cpp_symbolic_shape_guards", + "wrap_top_frame", + "reorderable_logging_functions", + "force_disable_caches", +] + + +# NB: Docblocks go UNDER variable definitions! Use spacing to make the +# grouping clear. + +# FB-internal note: you do NOT have to specify this explicitly specify this if +# you run on MAST, we will automatically default this to +# mast:MAST_JOB_NAME:MAST_JOB_VERSION. +job_id: Optional[str] = Config( + env_name_default=["TORCH_COMPILE_JOB_ID", "TORCH_COMPILE_STICKY_PGO_KEY"], + default=None, +) +""" +Semantically, this should be an identifier that uniquely identifies, e.g., a +training job. You might have multiple attempts of the same job, e.g., if it was +preempted or needed to be restarted, but each attempt should be running +substantially the same workload with the same distributed topology. You can +set this by environment variable with :envvar:`TORCH_COMPILE_JOB_ID`. + +Operationally, this controls the effect of profile-guided optimization related +persistent state. 
PGO state can affect how we perform compilation across +multiple invocations of PyTorch, e.g., the first time you run your program we +may compile twice as we discover what inputs are dynamic, and then PGO will +save this state so subsequent invocations only need to compile once, because +they remember it is dynamic. This profile information, however, is sensitive +to what workload you are running, so we require you to tell us that two jobs +are *related* (i.e., are the same workload) before we are willing to reuse +this information. Notably, PGO does nothing (even if explicitly enabled) +unless a valid ``job_id`` is available. In some situations, PyTorch can +configured to automatically compute a ``job_id`` based on the environment it +is running in. + +Profiles are always collected on a per rank basis, so different ranks may have +different profiles. If you know your workload is truly SPMD, you can run with +:data:`torch._dynamo.config.enable_compiler_collectives` to ensure nodes get +consistent profiles across all ranks. +""" + +pgo_extra_read_key: Optional[str] = Config( + env_name_default="TORCH_COMPILE_STICKY_PGO_READ", default=None +) +pgo_extra_write_key: Optional[str] = Config( + env_name_default="TORCH_COMPILE_STICKY_PGO_WRITE", default=None +) +""" +Additional read/write keys for PGO. +Write key: Besides writing to the default local/remote PGO state, this also writes to the specified key. +Read key: Besides reading from the default state, this also reads from the specified key (if written to before) +and merges it with the default state. +""" + + +cache_key_tag: str = Config(env_name_default="TORCH_COMPILE_CACHE_KEY_TAG", default="") +""" +Tag to be included in the cache key generation for all torch compile caching. +A common use case for such a tag is to break caches. 
+""" + +force_disable_caches: bool = Config( + justknob="pytorch/remote_cache:force_disable_caches", + env_name_force=[ + "TORCHINDUCTOR_FORCE_DISABLE_CACHES", + "TORCH_COMPILE_FORCE_DISABLE_CACHES", + ], + default=False, +) +""" +Force disables all caching -- This will take precedence over and override any other caching flag +""" + +dynamic_sources: str = Config( + env_name_default="TORCH_COMPILE_DYNAMIC_SOURCES", default="" +) +""" +Comma delimited list of sources that should be marked as dynamic. Primarily useful for large +models with graph breaks where you need intermediate tensors and ints to be marked dynamic. + +This whitelist is dominant over all other flags dynamic=False, force_nn_module_property_static_shapes +and force_parameter_static_shapes. +""" + +unbacked_sources: str = Config( + env_name_default="TORCH_COMPILE_UNBACKED_SOURCES", default="" +) +""" +Comma delimited list of sources that should be marked as unbacked. Primarily useful for large +models with graph breaks where you need intermediate tensors marked unbacked. + +This whitelist is dominant over all other flags dynamic=False, force_nn_module_property_static_shapes +and force_parameter_static_shapes. +""" + +# force a python GC before recording cudagraphs +force_cudagraph_gc: bool = Config(env_name_default="TORCH_CUDAGRAPH_GC", default=False) +""" +If True (the backward-compatible behavior) then gc.collect() before recording +any cudagraph. +""" + + +# Cross-cutting configuration options that affect the entire compilation pipeline + +dynamic_shapes: bool = Config(alias="torch._dynamo.config.dynamic_shapes") +""" +Controls whether the compilation pipeline supports dynamic tensor shapes. +When enabled, the compiler can handle tensors with varying dimensions across +different invocations. This is a cross-cutting setting that affects shape +inference, guard generation, and code generation across the entire compilation +stack. 
+""" + +assume_static_by_default: bool = Config( + alias="torch._dynamo.config.assume_static_by_default" +) +""" +When enabled, all tensor dimensions are assumed to be static unless explicitly +marked as dynamic or detected as changing. This compilation-wide behavior affects +how the entire stack handles shape specialization and can improve performance +for static workloads. +""" + +automatic_dynamic_shapes: bool = Config( + alias="torch._dynamo.config.automatic_dynamic_shapes" +) +""" +Enables automatic detection and handling of dynamic shapes. When a tensor's +shape changes between compilations, the system automatically marks those +dimensions as dynamic rather than requiring manual specification. This +cross-cutting optimization improves the user experience by reducing recompilations. +""" + +recompile_limit: int = Config(alias="torch._dynamo.config.recompile_limit") +""" +Maximum number of recompilations allowed for a single function before falling +back to eager execution. This compilation performance control prevents excessive +recompilation overhead that can degrade overall performance. +""" + +accumulated_recompile_limit: int = Config( + alias="torch._dynamo.config.accumulated_recompile_limit" +) +""" +Global limit on total recompilations across all compiled functions to prevent +runaway recompilation scenarios. This safeguard protects against compilation +performance issues that could affect the entire program. +""" + +verbose: bool = Config(alias="torch._dynamo.config.verbose") +""" +Enables verbose debugging output for Dynamo. When enabled, provides detailed +information about Dynamo's compilation decisions, optimizations, and potential +issues. +""" + + +# TorchDynamo-specific configuration options + +capture_scalar_outputs: bool = Config( + alias="torch._dynamo.config.capture_scalar_outputs" +) +""" +Controls whether TorchDynamo captures operations that return scalar values (like .item()) +into the FX graph. 
When disabled, these operations cause graph breaks. This is a +TorchDynamo-specific tracing behavior that affects how the tracer handles +scalar-returning operations. +""" + +capture_dynamic_output_shape_ops: bool = Config( + alias="torch._dynamo.config.capture_dynamic_output_shape_ops" +) +""" +Controls whether TorchDynamo captures operations with dynamic output shapes (like +nonzero, unique) into the FX graph. When disabled, these operations cause graph breaks. +This is a TorchDynamo-specific setting for handling operations with unpredictable +output shapes during tracing. +""" + +log_file_name: Optional[str] = Config(alias="torch._dynamo.config.log_file_name") +""" +Specifies a file path for TorchDynamo-specific logging output. When set, internal +TorchDynamo debug information is written to this file rather than stdout. This is +useful for debugging TorchDynamo's internal tracing behavior. +""" + +fail_on_recompile_limit_hit: bool = Config( + alias="torch._dynamo.config.fail_on_recompile_limit_hit" +) +""" +Raises a hard error when recompile limits are exceeded instead of falling back +to eager execution. This is useful for detecting excessive recompilation in +performance-critical deployments where you want to ensure compilation overhead +is kept under control. +""" + +allow_unspec_int_on_nn_module: bool = Config( + alias="torch._dynamo.config.allow_unspec_int_on_nn_module" +) +""" +Allows integer attributes of nn.Module instances to be unspecialized through +the dynamic shape mechanism. By default, TorchDynamo specializes on all integer +module attributes, but this can cause excessive recompilation when integers +like step counters change frequently. +""" + +skip_tensor_guards_with_matching_dict_tags: bool = Config( + alias="torch._dynamo.config.skip_tensor_guards_with_matching_dict_tags" +) +""" +Optimizes guard generation by treating tensors as immutable when they are +dictionary values with consistent dictionary tags across invocations. 
This +reduces guard overhead for tensors stored in persistent data structures. +""" + +enable_cpp_symbolic_shape_guards: bool = Config( + alias="torch._dynamo.config.enable_cpp_symbolic_shape_guards" +) +""" +Uses C++ implementation for symbolic shape guard evaluation to improve performance. +The C++ guard manager can significantly speed up guard checking for symbolic shapes +in shape-polymorphic compilations. +""" + +wrap_top_frame: bool = Config(alias="torch._dynamo.config.wrap_top_frame") +""" +Wraps the top-level decorated function/module in a frame wrapper to ensure +nn.Module hooks are compiled within the same frame as the main function. This +improves compilation coverage for models that rely on hooks. +""" + +reorderable_logging_functions: set = Config( + alias="torch._dynamo.config.reorderable_logging_functions" +) +""" +A set of logging functions that can be reordered to execute after the compiled +portion of the graph, allowing larger graphs to be captured. Functions in this +set will have their execution deferred to avoid graph breaks, though this may +affect the timing of log output. In particular, mutated values will not be logged +at the right time, leading to incorrect logging. +""" + + +install_config_module(sys.modules[__name__]) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/cpu/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/cpu/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b42b7f0ff54bd7dafda3fb72cffe93a4e4e23645 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/cpu/__init__.py @@ -0,0 +1,202 @@ +# mypy: allow-untyped-defs +r""" +This package implements abstractions found in ``torch.cuda`` +to facilitate writing device-agnostic code. +""" + +from contextlib import AbstractContextManager +from typing import Any, Optional, Union + +import torch + +from .. import device as _device +from . 
import amp + + +__all__ = [ + "is_available", + "is_initialized", + "synchronize", + "current_device", + "current_stream", + "stream", + "set_device", + "device_count", + "Stream", + "StreamContext", + "Event", +] + + +def _is_avx2_supported() -> bool: + r"""Returns a bool indicating if CPU supports AVX2.""" + return torch._C._cpu._is_avx2_supported() + + +def _is_avx512_supported() -> bool: + r"""Returns a bool indicating if CPU supports AVX512.""" + return torch._C._cpu._is_avx512_supported() + + +def _is_avx512_bf16_supported() -> bool: + r"""Returns a bool indicating if CPU supports AVX512_BF16.""" + return torch._C._cpu._is_avx512_bf16_supported() + + +def _is_vnni_supported() -> bool: + r"""Returns a bool indicating if CPU supports VNNI.""" + # Note: Currently, it only checks avx512_vnni, will add the support of avx2_vnni later. + return torch._C._cpu._is_avx512_vnni_supported() + + +def _is_amx_tile_supported() -> bool: + r"""Returns a bool indicating if CPU supports AMX_TILE.""" + return torch._C._cpu._is_amx_tile_supported() + + +def _is_amx_fp16_supported() -> bool: + r"""Returns a bool indicating if CPU supports AMX FP16.""" + return torch._C._cpu._is_amx_fp16_supported() + + +def _init_amx() -> bool: + r"""Initializes AMX instructions.""" + return torch._C._cpu._init_amx() + + +def is_available() -> bool: + r"""Returns a bool indicating if CPU is currently available. + + N.B. This function only exists to facilitate device-agnostic code + + """ + return True + + +def synchronize(device: torch.types.Device = None) -> None: + r"""Waits for all kernels in all streams on the CPU device to complete. + + Args: + device (torch.device or int, optional): ignored, there's only one CPU device. + + N.B. This function only exists to facilitate device-agnostic code. + """ + + +class Stream: + """ + N.B. 
This class only exists to facilitate device-agnostic code + """ + + def __init__(self, priority: int = -1) -> None: + pass + + def wait_stream(self, stream) -> None: + pass + + def record_event(self) -> None: + pass + + def wait_event(self, event) -> None: + pass + + +class Event: + def query(self) -> bool: + return True + + def record(self, stream=None) -> None: + pass + + def synchronize(self) -> None: + pass + + def wait(self, stream=None) -> None: + pass + + +_default_cpu_stream = Stream() +_current_stream = _default_cpu_stream + + +def current_stream(device: torch.types.Device = None) -> Stream: + r"""Returns the currently selected :class:`Stream` for a given device. + + Args: + device (torch.device or int, optional): Ignored. + + N.B. This function only exists to facilitate device-agnostic code + + """ + return _current_stream + + +class StreamContext(AbstractContextManager): + r"""Context-manager that selects a given stream. + + N.B. This class only exists to facilitate device-agnostic code + + """ + + cur_stream: Optional[Stream] + + def __init__(self, stream): + self.stream = stream + self.prev_stream = _default_cpu_stream + + def __enter__(self): + cur_stream = self.stream + if cur_stream is None: + return + + global _current_stream + self.prev_stream = _current_stream + _current_stream = cur_stream + + def __exit__(self, type: Any, value: Any, traceback: Any) -> None: + cur_stream = self.stream + if cur_stream is None: + return + + global _current_stream + _current_stream = self.prev_stream + + +def stream(stream: Stream) -> AbstractContextManager: + r"""Wrapper around the Context-manager StreamContext that + selects a given stream. + + N.B. This function only exists to facilitate device-agnostic code + """ + return StreamContext(stream) + + +def device_count() -> int: + r"""Returns number of CPU devices (not cores). Always 1. + + N.B. 
This function only exists to facilitate device-agnostic code + """ + return 1 + + +def set_device(device: torch.types.Device) -> None: + r"""Sets the current device, in CPU we do nothing. + + N.B. This function only exists to facilitate device-agnostic code + """ + + +def current_device() -> str: + r"""Returns current device for cpu. Always 'cpu'. + + N.B. This function only exists to facilitate device-agnostic code + """ + return "cpu" + + +def is_initialized() -> bool: + r"""Returns True if the CPU is initialized. Always True. + + N.B. This function only exists to facilitate device-agnostic code + """ + return True diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..095e8e9bf2654e3e609554dec3fde496abe66a44 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/__init__.py @@ -0,0 +1,168 @@ +# mypy: allow-untyped-defs +import logging +import pdb +import sys +import traceback +import typing +from datetime import timedelta + +import torch + + +log = logging.getLogger(__name__) + + +def is_available() -> bool: + """ + Return ``True`` if the distributed package is available. + + Otherwise, + ``torch.distributed`` does not expose any other APIs. Currently, + ``torch.distributed`` is available on Linux, MacOS and Windows. Set + ``USE_DISTRIBUTED=1`` to enable it when building PyTorch from source. + Currently, the default value is ``USE_DISTRIBUTED=1`` for Linux and Windows, + ``USE_DISTRIBUTED=0`` for MacOS. 
+ """ + return hasattr(torch._C, "_c10d_init") + + +if is_available() and not torch._C._c10d_init(): + raise RuntimeError("Failed to initialize torch.distributed") + +# Custom Runtime Errors thrown from the distributed package +DistError = torch._C._DistError +DistBackendError = torch._C._DistBackendError +DistNetworkError = torch._C._DistNetworkError +DistStoreError = torch._C._DistStoreError +QueueEmptyError = torch._C._DistQueueEmptyError + +if is_available(): + from torch._C._distributed_c10d import ( + _broadcast_coalesced, + _compute_bucket_assignment_by_size, + _ControlCollectives, + _DEFAULT_FIRST_BUCKET_BYTES, + _make_nccl_premul_sum, + _register_builtin_comm_hook, + _register_comm_hook, + _StoreCollectives, + _test_python_store, + _verify_params_across_processes, + Backend as _Backend, + BuiltinCommHookType, + DebugLevel, + FileStore, + get_debug_level, + GradBucket, + Logger, + PrefixStore, + ProcessGroup as ProcessGroup, + Reducer, + set_debug_level, + set_debug_level_from_env, + Store, + TCPStore, + Work as _Work, + ) + + class _DistributedPdb(pdb.Pdb): + """ + Supports using PDB from inside a multiprocessing child process. + + Usage: + _DistributedPdb().set_trace() + """ + + def interaction(self, *args, **kwargs): + _stdin = sys.stdin + try: + with open("/dev/stdin") as sys.stdin: + pdb.Pdb.interaction(self, *args, **kwargs) + finally: + sys.stdin = _stdin + + _breakpoint_cache: dict[int, typing.Any] = {} + + def breakpoint(rank: int = 0, skip: int = 0, timeout_s=3600): + """ + Set a breakpoint, but only on a single rank. All other ranks will wait for you to be + done with the breakpoint before continuing. + + Args: + rank (int): Which rank to break on. Default: ``0`` + skip (int): Skip the first ``skip`` calls to this breakpoint. Default: ``0``. 
+ """ + if skip > 0: + key = hash(str(traceback.format_exc())) + counter = _breakpoint_cache.get(key, 0) + 1 + _breakpoint_cache[key] = counter + if counter <= skip: + log.warning("Skip the breakpoint, counter=%d", counter) + return + + # avoid having the default timeout (if short) interrupt your debug session + if timeout_s is not None: + for group in torch.distributed.distributed_c10d._pg_map: + torch.distributed.distributed_c10d._set_pg_timeout( + timedelta(seconds=timeout_s), group + ) + + if get_rank() == rank: + pdb = _DistributedPdb() + pdb.message( + "\n!!! ATTENTION !!!\n\n" + f"Type 'up' to get to the frame that called dist.breakpoint(rank={rank})\n" + ) + pdb.set_trace() + # If Meta/Python keys are in the TLS, we want to make sure that we ignore them + # and hit the (default) CPU/CUDA implementation of barrier. + meta_in_tls = torch._C._meta_in_tls_dispatch_include() + guard = torch._C._DisableTorchDispatch() # type: ignore[attr-defined] + torch._C._set_meta_in_tls_dispatch_include(False) + try: + barrier() + finally: + torch._C._set_meta_in_tls_dispatch_include(meta_in_tls) + del guard + + if sys.platform != "win32": + from torch._C._distributed_c10d import HashStore + + from .device_mesh import DeviceMesh, init_device_mesh + + # Variables prefixed with underscore are not auto imported + # See the comment in `distributed_c10d.py` above `_backend` on why we expose + # this. 
+ # pyrefly: ignore [deprecated] + from .distributed_c10d import * # noqa: F403 + from .distributed_c10d import ( # pyrefly: ignore # deprecated; pyrefly: ignore [deprecated] + _all_gather_base, + _coalescing_manager, + _CoalescingManager, + _create_process_group_wrapper, + _get_process_group_name, + _rank_not_in_group, + _reduce_scatter_base, + _time_estimator, + get_node_local_rank, + ) + from .remote_device import _remote_device + from .rendezvous import ( + _create_store_from_options, + register_rendezvous_handler, + rendezvous, + ) + + set_debug_level_from_env() + +else: + # This stub is sufficient to get + # python test/test_public_bindings.py -k test_correct_module_names + # working even when USE_DISTRIBUTED=0. Feel free to add more + # stubs as necessary. + # We cannot define stubs directly because they confuse pyre + + class _ProcessGroupStub: + pass + + sys.modules["torch.distributed"].ProcessGroup = _ProcessGroupStub # type: ignore[attr-defined] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_checkpointable.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_checkpointable.py new file mode 100644 index 0000000000000000000000000000000000000000..0594c20337b3bf1c73fb40e2218e0c71580b75c5 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_checkpointable.py @@ -0,0 +1,37 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +from typing_extensions import Protocol, runtime_checkable + +import torch + + +@runtime_checkable +class _Checkpointable(Protocol): # noqa: PYI046 + """ + Interface for checkpointable objects. + Implemented as a protocol, implicit subtyping is supported so subclasses do not need to inherit this explicitly. + This is to allow arbitrary objects/tensor subclasses to hook into DCP seamlessly through implementing the interface. 
+ """ + + def __create_write_items__(self, fqn: str, object: object) -> list[object]: + """ + Return a list of WriteItems based on object's contents. + """ + raise NotImplementedError( + "_Checkpointable._create_write_items is not implemented" + ) + + def __create_chunk_list__(self) -> list[object]: + """ + Return a list of `ChunkStorageMetadata` based on object's contents. + """ + raise NotImplementedError( + "_Checkpointable._create_chunk_list is not implemented" + ) + + def __get_tensor_shard__(self, index: int) -> torch.Tensor: + """ + Return a 'torch.Tensor' shard based on 'MetadataIndex'. + """ + raise NotImplementedError( + "_Checkpointable._get_tensor_shard is not implemented" + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_composable_state.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_composable_state.py new file mode 100644 index 0000000000000000000000000000000000000000..b91797536ec7fb969e9ad2c57cbbe9b7e0bd181c --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_composable_state.py @@ -0,0 +1,46 @@ +import weakref +from typing import cast + +import torch.nn as nn + + +class _State: + pass + + +_module_state_mapping: weakref.WeakKeyDictionary[ + nn.Module, weakref.ReferenceType[_State] +] = weakref.WeakKeyDictionary() + + +def _insert_module_state(module: nn.Module, state: _State) -> None: + global _module_state_mapping + if module in _module_state_mapping: + raise AssertionError(f"Inserting {module} more than once.") + _module_state_mapping[module] = weakref.ref(state) + + +def _get_module_state(module: nn.Module) -> _State | None: + """ + Return the ``_State`` in ``model``. + + Given a ``module``, this API finds out if the module is also a ``_State`` + instance or if the module is managed by a composable API. If the module + is also a ``_State``, ``module`` will be casted to ``_State` and returned. 
+ If it is managed by a composable API, the corresponding ``_State`` will + be returned. + """ + global _module_state_mapping + if isinstance(module, _State): + # pyrefly: ignore [redundant-cast] + return cast(_State, module) + else: + # https://github.com/pytorch/pytorch/issues/107054 + if module in _module_state_mapping: + state_ref = _module_state_mapping[module] + state = state_ref() + if state is None: + raise AssertionError("State has already been garbage collected") + return state + else: + return None diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_dist2.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_dist2.py new file mode 100644 index 0000000000000000000000000000000000000000..9ec53372c4d62bd24de7956ef95e1c033c3b3bd7 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_dist2.py @@ -0,0 +1,183 @@ +""" +This is an experimental new API for PyTorch Distributed. This is actively in development and subject to change or deletion entirely. + +This is intended as a proving ground for more flexible and object oriented distributed APIs. +""" + +from collections.abc import Generator +from contextlib import contextmanager +from datetime import timedelta +from typing import Protocol + +import torch +from torch._C._distributed_c10d import ( + _current_process_group, + _set_process_group, + ProcessGroup, + ReduceOp, + Store, +) +from torch.distributed.rendezvous import rendezvous + + +_BACKENDS: dict[str, "ProcessGroupFactory"] = {} + +__all__ = [ + "ProcessGroup", + "ReduceOp", + "ProcessGroupFactory", + "register_backend", + "new_group", + "current_process_group", + "process_group", +] + + +class ProcessGroupFactory(Protocol): + """Protocol for process group factories.""" + + def __call__( + self, + store: Store, + rank: int, + world_size: int, + timeout: timedelta, + device: torch.device, + **kwargs: object, + ) -> ProcessGroup: ... 
def register_backend(name: str, func: ProcessGroupFactory) -> None:
    """
    Store a process group factory in the backend registry.

    Args:
        name: Key under which the factory is stored in ``_BACKENDS``.
        func: Factory callable to associate with ``name``.

    Raises:
        ValueError: If ``name`` is already a key in ``_BACKENDS``.
    """
    already_registered = name in _BACKENDS
    if already_registered:
        raise ValueError(f"Backend {name} already registered")

    _BACKENDS[name] = func


def _gloo_factory(
    store: Store,
    rank: int,
    world_size: int,
    timeout: timedelta,
    device: torch.device,
    **kwargs: object,
) -> ProcessGroup:
    """
    Build a ``ProcessGroup`` whose default backend type is GLOO.

    A single ``ProcessGroupGloo`` instance is registered for ``device``, for
    the CPU device and, when ``torch.cuda.is_available()`` is true, for the
    CUDA device as well.

    Raises:
        AssertionError: If any extra keyword argument is supplied.
    """
    # Imported lazily, exactly as the factory has always done.
    from torch.distributed import ProcessGroupGloo

    if kwargs:
        raise AssertionError("Gloo backend received unexpected kwargs")

    gloo_backend = ProcessGroupGloo(store, rank, world_size, timeout)
    gloo_backend._set_sequence_number_for_group()

    group = ProcessGroup(store, rank, world_size)
    group._set_default_backend(ProcessGroup.BackendType.GLOO)

    def _attach(target: torch.device) -> None:
        # Every device shares the one Gloo backend instance created above.
        group._register_backend(target, ProcessGroup.BackendType.GLOO, gloo_backend)

    # register devices
    _attach(device)
    _attach(torch.device("cpu"))
    if torch.cuda.is_available():
        _attach(torch.device("cuda"))
    return group


def _nccl_factory(
    store: Store,
    rank: int,
    world_size: int,
    timeout: timedelta,
    device: torch.device,
    **kwargs: object,
) -> ProcessGroup:
    """
    Build a ``ProcessGroup`` whose default backend type is NCCL.

    ``timeout`` is written to the options object's ``_timeout`` attribute and
    each keyword argument is copied onto an attribute of the same name. The
    backend is then connected via ``eager_connect_single_device(device)`` and
    registered for ``device`` only.

    Raises:
        KeyError: If a keyword argument does not name an existing attribute
            of ``ProcessGroupNCCL.Options``.
    """
    # Imported lazily, exactly as the factory has always done.
    from torch.distributed import ProcessGroupNCCL

    options = ProcessGroupNCCL.Options()
    options._timeout = timeout
    for option_name, option_value in kwargs.items():
        if hasattr(options, option_name):
            setattr(options, option_name, option_value)
        else:
            raise KeyError(f"Unknown option {option_name}")

    nccl_backend = ProcessGroupNCCL(store, rank, world_size, options)
    nccl_backend._set_sequence_number_for_group()
    nccl_backend.eager_connect_single_device(device)

    group = ProcessGroup(store, rank, world_size)
    group._set_default_backend(ProcessGroup.BackendType.NCCL)
    group._register_backend(device, ProcessGroup.BackendType.NCCL, nccl_backend)

    return group


register_backend("gloo", _gloo_factory)
+register_backend("nccl", _nccl_factory) + + +def new_group( + backend: str, + timeout: timedelta, + device: str | torch.device, + **kwargs: object, +) -> ProcessGroup: + """ + Create a new process group with the given backend and options. This group is + independent and will not be globally registered and thus not usable via the + standard torch.distributed.* APIs. + + Args: + backend: The backend to use for the process group. + timeout: The timeout for collective operations. + device: The device to use for the process group. + **kwargs: All remaining arguments are passed to the backend constructor. + See the backend specific documentation for details. + + Returns: + A new process group. + """ + if backend not in _BACKENDS: + raise ValueError(f"Backend {backend} not registered") + + device = torch.device(device) + + store, rank, world_size = next(iter(rendezvous("env://"))) + store.set_timeout(timeout) + + return _BACKENDS[backend](store, rank, world_size, timeout, device, **kwargs) + + +def current_process_group() -> ProcessGroup: + """ + Get the current process group. Thread local method. + + Returns: + The current process group. + """ + return _current_process_group() + + +@contextmanager +def process_group(pg: ProcessGroup) -> Generator[None, None, None]: + """ + Context manager for process groups. Thread local method. + + Args: + pg: The process group to use. 
+ """ + prev_pg = current_process_group() + + _set_process_group(pg) + try: + yield + finally: + _set_process_group(prev_pg) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_functional_collectives.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_functional_collectives.py new file mode 100644 index 0000000000000000000000000000000000000000..24d7d5cf2748bd70b9df1a54895da03870185192 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_functional_collectives.py @@ -0,0 +1,1251 @@ +# mypy: allow-untyped-defs +import contextlib +import sys +import warnings +from typing import Any, cast, TYPE_CHECKING, Union + +import torch +import torch.distributed as dist +import torch.distributed.distributed_c10d as c10d +from torch.distributed.device_mesh import DeviceMesh +from torch.fx.experimental.proxy_tensor import get_proxy_mode + +from . import _functional_collectives_impl as fun_col_impl + + +try: + from torch.utils._cxx_pytree import tree_map_only +except ImportError: + from torch.utils._pytree import tree_map_only # type: ignore[no-redef] + + +try: + from torch.compiler import is_dynamo_compiling as is_torchdynamo_compiling +except Exception: + warnings.warn( + "Unable to import torchdynamo util `is_torchdynamo_compiling`, so won't support torchdynamo correctly", + stacklevel=2, + ) + + def is_torchdynamo_compiling(): # type: ignore[misc] + return False + return False + + +""" +New traceable, functional collectives. +RFC: https://github.com/pytorch/pytorch/issues/93173 + + compiler: trace these ops with plain-old-data schemas, then choose how to lower them. + eager: execute these 'functional' ops which in eager return AsyncCollectiveTensor subclasses, + automatically calling .wait() on underlying/hidden async 'work' obj only when fed to + a downstream op. + +Issues: +* Where should these ops live? 
Couldn't `import torch` if putting these ops in existing torch.distributed files +* Proper support for eager requires inplace ops. We should explore having it as an option for the API. +""" + +""" +Functional collectives are asynchronous only and we perform implicit stream synchronization +on behalf of the user. + +We use AsyncCollectiveTensor to wrap the result tensor of a collective and it lets us witness +first usage of the tensor and insert cross stream sync at the right place. + +The above are the easy bits, the hard one is how we match the Work object returned by +c10d and the tensor AsyncCollectiveTensor wraps. We allocate the tensor inside the collective +op implementation (see ``clone()`` call in ``_all_reduce``) and then it's handled by the +dispatcher which might call other implementations that are allowed to change the returned +tensor - even return a tensor with a different shape (see ``torch.vmap``). + +This means the caller of our ops receives a Tensor that is not guaranteed to be the same one +allocated by our implementations and that makes pairing the AsyncTensor to the original +tensor a lot harder. This pairing is needed so we can look up the Work object to use. + +Originally, we tried WeakKeyDictionary to map from Tensor to Work, but because Tensor's +identity is not stable across dispatch, the op caller would end up with a different Tensor +instance that would not match any in the dictionary. + +With Tensor identity out of the question, we decided to use the tensor data pointer, which +should be stable across all the Tensor changes done during dispatch. + +We have a dictionary of tensor::data_ptr -> Work that we insert right after we call into c10d. + +We use this dictionary when AsyncCollectiveTensor is used to invoke Work::wait() + +Finally, we set up a finalizer against the tensor wrapper to observe it getting collected so we +can clean up stale entries in the dictionary. 
+ +To eliminate the possibility of races we have a global version counter that is used by the finalizer. + +As a wise man said once: Don't cross the streams (https://www.youtube.com/watch?v=wyKQe_i9yyo) + +""" + +""" +Functional collectives can accept any of these types to describe the ranks participating in collectives. + +The different types will be desugared to a canonical format +""" +RANK_TYPES = Union[ + list[int], + list[list[int]], + dist.ProcessGroup, + DeviceMesh, + tuple["dist.tensor.DeviceMesh", int], + c10d.GroupName, +] + + +""" +User facing APIs for functional collectives +------------------------------------------- + +These apis are called by user code and expected to work both in eager execution and compilation, +but there are significant differences to how the two modes are implemented underneath. + +Eager execution is 'optimized' using a tensor subclass that schedules the synchronization (via wait_tensor() op) +just before the tensor is first used. Compiled tracing currently relies on the compiler to perform this optimization, +and cannot yet correctly trace the AsyncTensor wrapper class. In the future, these paths may be unified +if sufficient subclass support is added in dynamo. + +Example: all_reduce is an entrypoint API, and other collectives follow a similar pattern. + +Here's how it works under torch.compile/dynamo: +all_reduce(...) + |--> _expand_group(...) - desugars processgroup into canonical/traceable format + |--> c10d_functional.all_reduce(...) - dynamo captures this op call, doesn't trace deeper + |--> _maybe_wrap_tensor(...) - wait_tensor() op is immediately called, no AsyncTensor subclass needed + +And under eager execution: +all_reduce(...) + |--> _expand_group(...) - same as above, but less critical for eager + |--> c10d_functional.all_reduce(...) - dispatches to real kernel OR records op in trace + |--> _maybe_wrap_tensor(...) 
- AsyncTensor wrapper applied to returned tensor, + which issues wait_tensor() at the time of first use +""" + + +def wait_tensor(tensor): + """ + Wait on a tensor returned by the collectives ops. + + Waiting follows device semantics, which means blocking on CPU and synchronizing streams on CUDA. + """ + return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined] + + +def broadcast(self: torch.Tensor, src: int, group: RANK_TYPES, tag: str = ""): + """ + Broadcasts the tensor to all processes in the given process group. + + Args: + src (int): Source rank + group (ProcessGroup or List[int]): The process group to work on. + tag (str, optional): A unique identifier for the collective. Default: empty string + """ + group_name = _resolve_group_name(group, tag) + tensor = torch.ops._c10d_functional.broadcast(self, src, group_name) + return _maybe_wrap_tensor(tensor) + + +def all_reduce(self: torch.Tensor, reduceOp: str, group: RANK_TYPES, tag: str = ""): + """ + Reduces the tensor data across all machines in such a way that all get + the final result. + + The input tensor is left unmodified. + + Group can be one of: + List[int]: ranks participating in the collective. + List[List[int]]: 2D mesh of ranks taking part of this collective in MPMD. + ProcessGroup: Will perform a collective using the ranks and tag of the PG. + DeviceMesh: Do a SPMD collective over all ranks of the mesh + (DeviceMesh, int): Do a MPMD collective over one dimension of the DeviceMesh + + :: N.B. If you pass a PG or a 1D list to perform a MPMD collective, the compiler won't be able to recover + that information and perform collective algebraic optimization. Use other forms of input for that. 
+ """ + group_name = _resolve_group_name(group, tag) + tensor = torch.ops._c10d_functional.all_reduce(self, reduceOp.lower(), group_name) + return _maybe_wrap_tensor(tensor) + + +def all_gather_tensor( + self: torch.Tensor, + gather_dim: int, + group: RANK_TYPES, + tag: str = "", +) -> torch.Tensor: + """ + Gather tensor data across from all machines and concatenate over ``gather_dim``. + + Note that it currently only supports gather_dim = 0. + + The input tensor is left unmodified. + Group can be one of: + List[int]: ranks participating in the collective. + List[List[int]]: 2D mesh of ranks taking part of this collective in MPMD. + ProcessGroup: Will perform a collective using the ranks and tag of the PG. + DeviceMesh: Do a SPMD collective over all ranks of the mesh + (DeviceMesh, int): Do a MPMD collective over one dimension of the DeviceMesh + + :: N.B. If you pass a PG or a 1D list to perform a MPMD collective, the compiler won't be able to recover + that information and perform collective algebraic optimization. Use other forms of input for that. 
+ """ + if not self.is_contiguous(): + raise AssertionError("Tensor must be contiguous for all_gather_tensor") + group_name = _resolve_group_name(group, tag) + group_size = c10d._get_group_size_by_name(group_name) + tensor = torch.ops._c10d_functional.all_gather_into_tensor( + self, group_size, group_name + ) + res = _maybe_wrap_tensor(tensor) + # TODO this should be done inside AsyncCollectiveTensor to delay the wait() call + if gather_dim != 0: + # torch.cat access the data so we already need to wait here, first do wait + # and then chunk + cat avoid us going through ACT dispatching logic again + if isinstance(res, AsyncCollectiveTensor): + res = res.wait() # type: ignore[attr-defined] + res = torch.cat(torch.chunk(res, group_size, dim=0), dim=gather_dim) + return res + + +def all_gather_tensor_autograd( + self: torch.Tensor, + gather_dim: int, + group: RANK_TYPES, + tag: str = "", +): + """ + Gather tensor data across from all machines and concatenate over ``gather_dim``. + + Note that it currently only supports gather_dim = 0. + + This function is the same as all_gather_tensor but will propagate the + backwards gradient across workers. + + See all_gather_tensor for more details on usage. 
+ """ + group_name = _resolve_group_name(group, tag) + group_size = c10d._get_group_size_by_name(group_name) + + tensor = torch.ops._c10d_functional_autograd.all_gather_into_tensor( + self, group_size, group_name + ) + res = _FromTorchTensor.apply(tensor) + # TODO this should be done inside AsyncCollectiveTensor to delay the wait() call + if gather_dim != 0: + # torch.cat access the data so we already need to wait here, first do wait + # and then chunk + cat avoid us going through ACT dispatching logic again + if isinstance(res, AsyncCollectiveTensor): + res = res.wait() # type: ignore[attr-defined] + res = torch.cat(torch.chunk(res, group_size, dim=0), dim=gather_dim) + return res + + +def reduce_scatter_tensor( + self: torch.Tensor, + reduceOp: str, + scatter_dim: int, + group: RANK_TYPES, + tag: str = "", +): + """ + Reduces the tensor data across all machines in such a way that all get + the final result, then scatter the results to corresponding ranks. + + + The input tensor is left unmodified. + Group can be one of: + List[int]: ranks participating in the collective. + List[List[int]]: 2D mesh of ranks taking part of this collective in MPMD. + ProcessGroup: Will perform a collective using the ranks and tag of the PG. + DeviceMesh: Do a SPMD collective over all ranks of the mesh + (DeviceMesh, int): Do a MPMD collective over one dimension of the DeviceMesh + :: N.B. If you pass a PG or a 1D list to perform a MPMD collective, the compiler won't be able to recover + that information and perform collective algebraic optimization. Use other forms of input for that. 
+ """ + group_name = _resolve_group_name(group, tag) + group_size = c10d._get_group_size_by_name(group_name) + + if self.size(scatter_dim) % group_size != 0: + raise AssertionError( + f"input dimension 0 ({self.size(0)} must be a multiple of group_size {group_size})" + ) + if scatter_dim != 0: + tensor_list = torch.chunk(self, group_size, dim=scatter_dim) + self = torch.cat(tensor_list) + + tensor = torch.ops._c10d_functional.reduce_scatter_tensor( + self, + reduceOp.lower(), + group_size, + group_name, # type: ignore[possibly-undefined] + ) + res = _maybe_wrap_tensor(tensor) + return res + + +def reduce_scatter_tensor_autograd( + self: torch.Tensor, + reduceOp: str, + scatter_dim: int, + group: RANK_TYPES, + tag: str = "", +): + """ + Reduces the tensor data across all machines in such a way that all get + the final result, then scatter the results to corresponding ranks. + + This function is the same as reduce_scatter_tensor but will propagate the + backwards gradient across workers. + + Currently only the "sum" reduceOp is supported. + + See reduce_scatter_tensor for more details on usage. + """ + + group_name = _resolve_group_name(group, tag) + group_size = c10d._get_group_size_by_name(group_name) + + if self.size(scatter_dim) % group_size != 0: + raise AssertionError( + f"input dimension 0 ({self.size(0)} must be a multiple of group_size {group_size}" + ) + if scatter_dim != 0: + tensor_list = torch.chunk(self, group_size, dim=scatter_dim) + self = torch.cat(tensor_list) + + tensor = torch.ops._c10d_functional_autograd.reduce_scatter_tensor( + self, + reduceOp.lower(), + group_size, + group_name, # type: ignore[possibly-undefined] + ) + res = _FromTorchTensor.apply(tensor) + return res + + +def all_reduce_coalesced( + self: list[torch.Tensor], reduceOp: str, group: RANK_TYPES, tag: str = "" +) -> list[torch.Tensor]: + """ + Reduces a list of tensors across all machines in such a way that all get + the final result. 
+ + The all tensors in the input list are left unmodified. + + Group can be one of: + List[int]: ranks participating in the collective. + List[List[int]]: 2D mesh of ranks taking part of this collective in MPMD. + ProcessGroup: Will perform a collective using the ranks and tag of the PG. + DeviceMesh: Do a SPMD collective over all ranks of the mesh + (DeviceMesh, int): Do a MPMD collective over one dimension of the DeviceMesh + + :: N.B. If you pass a PG or a 1D list to perform a MPMD collective, the compiler won't be able to recover + that information and perform collective algebraic optimization. Use other forms of input for that. + """ + group_name = _resolve_group_name(group, tag) + tensor_list = torch.ops._c10d_functional.all_reduce_coalesced( # type: ignore[attr-defined] + self, + reduceOp.lower(), + group_name, + ) + return list(map(_maybe_wrap_tensor, tensor_list)) + + +def all_gather_into_tensor_coalesced( + self: list[torch.Tensor], group: RANK_TYPES, tag: str = "" +) -> list[torch.Tensor]: + """ + Gather a list of tensors across from all machines. + + Note that it currently only supports gather_dim = 0. + + The input tensor is left unmodified. + Group can be one of: + List[int]: ranks participating in the collective. + List[List[int]]: 2D mesh of ranks taking part of this collective in MPMD. + ProcessGroup: Will perform a collective using the ranks and tag of the PG. + DeviceMesh: Do a SPMD collective over all ranks of the mesh + (DeviceMesh, int): Do a MPMD collective over one dimension of the DeviceMesh + + :: N.B. If you pass a PG or a 1D list to perform a MPMD collective, the compiler won't be able to recover + that information and perform collective algebraic optimization. Use other forms of input for that. 
+ """ + group_name = _resolve_group_name(group, tag) + group_size = c10d._get_group_size_by_name(group_name) + tensor_list = torch.ops._c10d_functional.all_gather_into_tensor_coalesced( # type: ignore[attr-defined] + self, + group_size, + group_name, + ) + return list(map(_maybe_wrap_tensor, tensor_list)) + + +def reduce_scatter_tensor_coalesced( + inputs: list[torch.Tensor], + reduceOp: str, + scatter_dim: list[int], + group: RANK_TYPES, + tag: str = "", +) -> list[torch.Tensor]: + """ + Reduces a list of tensors across all machines in such a way that all get + the final result, then scatter the results to corresponding ranks. + + The input tensors are left unmodified. + Group can be one of: + List[int]: ranks participating in the collective. + List[List[int]]: 2D mesh of ranks taking part of this collective in MPMD. + ProcessGroup: Will perform a collective using the ranks and tag of the PG. + DeviceMesh: Do a SPMD collective over all ranks of the mesh + (DeviceMesh, int): Do a MPMD collective over one dimension of the DeviceMesh + + :: N.B. If you pass a PG or a 1D list to perform a MPMD collective, the compiler won't be able to recover + that information and perform collective algebraic optimization. Use other forms of input for that. 
+ """ + group_name = _resolve_group_name(group, tag) + group_size = c10d._get_group_size_by_name(group_name) + + if len(scatter_dim) != len(inputs): + raise AssertionError( + f"Length of scatter_dim ({len(scatter_dim)}) must equal length of inputs ({len(inputs)})" + ) + for idx, (dim, tensor) in enumerate(zip(scatter_dim, inputs)): + if tensor.size(dim) % group_size != 0: + raise AssertionError( + f"input dimension {dim} ({tensor.size(dim)} must be a multiple of group_size {group_size} for tensor at index {idx}" + ) + if dim != 0: + tensor_list = torch.chunk(tensor, group_size, dim=dim) + inputs[idx] = torch.cat(tensor_list) + + tensor_list = torch.ops._c10d_functional.reduce_scatter_tensor_coalesced( # type: ignore[attr-defined] + inputs, + reduceOp.lower(), + group_size, + group_name, # type: ignore[possibly-undefined] + ) + + return list(map(_maybe_wrap_tensor, tensor_list)) + + +# This is a bit unsafe: it checks if the first argument in the schema reports as a non-mutable alias. +# Today, this maps 1:1 with "aten ops that are views". +def _is_view_op(tgt): + if not isinstance(tgt, torch._ops.OpOverload): + raise AssertionError(f"Expected torch._ops.OpOverload, got {type(tgt)}") + # Don't apply the view optimization to any `CompositeImplicitAutograd` ops. + # See issue: https://github.com/pytorch/pytorch/issues/133421 + if torch._C._dispatch_has_kernel_for_dispatch_key( + tgt.name(), torch.DispatchKey.CompositeImplicitAutograd + ): + return False + schema = tgt._schema + if len(schema.arguments) > 0: + first_arg = schema.arguments[0] + # check if op is a view + return first_arg.alias_info is not None and not first_arg.alias_info.is_write + + +def all_to_all_single( + self: torch.Tensor, + output_split_sizes: list[int] | None, + input_split_sizes: list[int] | None, + group: RANK_TYPES, + tag: str = "", +) -> torch.Tensor: + """ + Each process splits input tensor and then scatters the split list + to all processes in a group. 
Then concatenate the received tensors from all + the processes in the group and return single output tensor. + + Group can be one of: + List[int]: ranks participating in the collective. + List[List[int]]: 2D mesh of ranks taking part of this collective in MPMD. + ProcessGroup: Will perform a collective using the ranks and tag of the PG. + DeviceMesh: Do a SPMD collective over all ranks of the mesh + (DeviceMesh, int): Do a MPMD collective over one dimension of the DeviceMesh + + :: N.B. If you pass a PG or a 1D list to perform a MPMD collective, the compiler won't be able to recover + that information and perform collective algebraic optimization. Use other forms of input for that. + """ + if output_split_sizes is not None: + if not all( + isinstance(size, (int, torch.SymInt)) for size in output_split_sizes + ): + raise AssertionError( + f"All output_split_sizes must be int or SymInt, got {output_split_sizes}" + ) + if input_split_sizes is not None: + if not all(isinstance(size, (int, torch.SymInt)) for size in input_split_sizes): + raise AssertionError( + f"All input_split_sizes must be int or SymInt, got {input_split_sizes}" + ) + group_name = _resolve_group_name(group, tag) + group_size = c10d._get_group_size_by_name(group_name) + if output_split_sizes is None or input_split_sizes is None: + if not (output_split_sizes is None and input_split_sizes is None): + raise AssertionError( + "output_split_sizes and input_split_sizes must either be " + "specified together or both set to None" + ) + output_split_sizes = [self.shape[0] // group_size] * group_size + input_split_sizes = output_split_sizes + tensor = torch.ops._c10d_functional.all_to_all_single( # type: ignore[attr-defined] + self, + output_split_sizes, + input_split_sizes, + group_name, + ) + return _maybe_wrap_tensor(tensor) + + +def all_to_all_single_autograd( + self: torch.Tensor, + output_split_sizes: list[int] | None, + input_split_sizes: list[int] | None, + group: RANK_TYPES, + tag: str = "", +) -> 
torch.Tensor: + """ + Same as all_to_all_single but supports autograd. + """ + if output_split_sizes is not None: + if not all( + isinstance(size, (int, torch.SymInt)) for size in output_split_sizes + ): + raise AssertionError( + f"All output_split_sizes must be int or SymInt, got {output_split_sizes}" + ) + if input_split_sizes is not None: + if not all(isinstance(size, (int, torch.SymInt)) for size in input_split_sizes): + raise AssertionError( + f"All input_split_sizes must be int or SymInt, got {input_split_sizes}" + ) + + group_name = _resolve_group_name(group, tag) + group_size = c10d._get_group_size_by_name(group_name) + if output_split_sizes is None or input_split_sizes is None: + if not (output_split_sizes is None and input_split_sizes is None): + raise AssertionError( + "output_split_sizes and input_split_sizes must either be " + "specified together or both set to None" + ) + output_split_sizes = [self.shape[0] // group_size] * group_size + input_split_sizes = output_split_sizes + tensor = torch.ops._c10d_functional_autograd.all_to_all_single( # type: ignore[attr-defined] + self, + output_split_sizes, + input_split_sizes, + group_name, + ) + return _FromTorchTensor.apply(tensor) + + +def permute_tensor( + self: torch.Tensor, + src_dst: list[int], + group: RANK_TYPES, + tag: str = "", +) -> torch.Tensor: + """ + Permutes the elements of the tensor according to the given source/destination pairs. `src_dst` should + be defined such that src_dst[m] == n means m sends to n. + + Group can be one of: + List[int]: ranks participating in the collective. + List[List[int]]: 2D mesh of ranks taking part of this collective in MPMD. + ProcessGroup: Will perform a collective using the ranks and tag of the PG. 
+ DeviceMesh: Do a SPMD collective over all ranks of the mesh + (DeviceMesh, int): Do a MPMD collective over one + """ + t, rankset, group_size = _expand_group(group, tag) + local_pg = c10d._find_or_create_pg_by_ranks_and_tag(t, rankset, group_size) + + output_split_sizes = [0] * group_size + input_split_sizes = [0] * group_size + for src, dst in enumerate(src_dst): + if src == dist.get_rank(local_pg): + input_split_sizes[dst] = self.numel() + if dst == dist.get_rank(local_pg): + output_split_sizes[src] = self.numel() + + return all_to_all_single(self, output_split_sizes, input_split_sizes, group, tag) + + +class AsyncCollectiveTensor(torch.Tensor): + r""" + A Tensor wrapper subclass that is used to trigger a call to wait + prior to first use of the underlying tensor. + Use it inside functional collective pytorch wrappers like the following: + def functional_collective(self, group, tag): + tag, rankset, group_size = _expand_group(group, tag) + tensor = torch.ops.c10d_functional.{collective}(self, tag, rankset, group_size) + return _maybe_wrap_tensor(tensor) + """ + + elem: torch.Tensor + completed: bool + + __slots__ = ["elem", "completed"] + + @staticmethod + def __new__(cls, elem: torch.Tensor): + r = torch.Tensor._make_wrapper_subclass( + cls, + elem.size(), + strides=elem.stride(), + storage_offset=elem.storage_offset(), + dtype=elem.dtype, + layout=elem.layout, + device=elem.device, + requires_grad=elem.requires_grad, + ) + r.elem = elem + r.completed = False + return r + + def __tensor_flatten__(self): + return ["elem"], None + + def tolist(self): + return self.trigger_wait().tolist() + + @staticmethod + def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride): + if meta is not None: + raise AssertionError( + "meta must be None for AsyncCollectiveTensor unflatten" + ) + elem = inner_tensors["elem"] + return AsyncCollectiveTensor(elem) + + def __coerce_same_metadata_as_tangent__( + self, expected_metadata: Any, expected_type: type | None = None 
+ ): + if expected_type is not torch.Tensor: + return None + + return self.trigger_wait() + + def __repr__(self) -> str: # type: ignore[override] + return f"AsyncCollectiveTensor({self.trigger_wait()})" + + def trigger_wait(self): + if not self.completed: + out = wait_tensor(self.elem) + self.completed = True + return out + else: + return self.elem + + def wait(self) -> torch.Tensor: + return wait_tensor(self.elem) + + def _get_acs_underlying_tensor(self): + """This method enables _functional_collectives_impl to test if a tensor is an ACS""" + return self.elem + + @classmethod + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): # type: ignore[override] + if func is torch.ops.aten.view.default: + # Fast handle aten.view as a lot of view related op goes to aten.view + # eventually, this avoids pytree slowdown + # pyrefly: ignore [index-error] + res = func(args[0].elem, args[1]) + wrapper_res = AsyncCollectiveTensor(res) + return wrapper_res + + is_view_op = _is_view_op(func) + + def unwrap(e: AsyncCollectiveTensor): + # wait_tensor is idepotent and will do stream sync only once + if not is_view_op: + return e.trigger_wait() + return e.elem + + def wrap(e: torch.Tensor): + # wait_tensor is idepotent and will do stream sync only once + if isinstance(e, AsyncCollectiveTensor): + raise AssertionError( + "Cannot wrap an AsyncCollectiveTensor inside another AsyncCollectiveTensor" + ) + res = AsyncCollectiveTensor(e) + return res + + unwrapped_args = tree_map_only(AsyncCollectiveTensor, unwrap, args) + unwrapped_kwargs = tree_map_only(AsyncCollectiveTensor, unwrap, kwargs) + + # we don't wrap the result as it doesn't need to be waited on. + out = func(*unwrapped_args, **unwrapped_kwargs) + + # View ops dont require a sync, so we should re-wrap the outputs. 
+ if is_view_op: + out = tree_map_only(torch.Tensor, wrap, out) + + return out + + def numpy(self): # type: ignore[override] + return self.wait().numpy() + + +""" +Utils and infrastructure for tracing support +""" + + +def _expand_group(group: RANK_TYPES, tag: str = "") -> tuple[str, list[int], int]: + """ + _expand_group desugars the different RANK_TYPES types into a canonical format that is traceable. + + By having this be part of the explicit eager codepath, we avoid having to specialize behavior inside + torchdynamo and can still interoperate with processgroup objects or other untraceable forms. + """ + # had to define this hack _inside_ expand_group to avoid + # graph_break [('torch.* op returned non-Tensor int + # caused by 'cast_*` functions being treated as 'torch.*' ops (iiuc) + if TYPE_CHECKING: + + def cast_listlistint(x): + return cast(list[list[int]], x) + + def cast_listint(x): + return cast(list[int], x) + + else: + # fake cast op for use at runtime since dynamo doesn't support real cast + # also, dynamo didn't like encountering 'typing' objects () + # NotImplementedError: argument of type: + def cast_listlistint(x): + return x + + def cast_listint(x): + return x + + rankset: list[int] + if isinstance(group, list): + if isinstance(group[0], list): + nested_list = cast_listlistint(group) + rankset = [] + group_size = -1 + for rs in nested_list: + rankset.extend(rs) + if group_size != -1 and group_size != len(rs): + raise ValueError( + f"group sizes must be identical found {group_size} and {len(rs)}" + ) + group_size = len(rs) + else: + rankset = cast_listint(group) + group_size = len(rankset) + elif isinstance(group, dist.ProcessGroup): + rankset = dist.get_process_group_ranks(group) + group_size = len(rankset) + tag = tag or c10d._get_group_tag(group) + elif isinstance(group, DeviceMesh): + if group.ndim != 1: + raise AssertionError( + "Only 1D mesh is supported, pass in (DeviceMesh, int) together if mesh > 1D" + ) + # TODO: it should run collective 
in the whole mesh instead of dim 0 + pg = group.get_group() + rankset = dist.get_process_group_ranks(pg) + group_size = len(rankset) + tag = tag or c10d._get_group_tag(pg) + elif isinstance(group, tuple): + if ( + len(group) == 2 + and isinstance(group[0], DeviceMesh) + and isinstance(group[1], int) + ): + dmesh = group[0] + dim = group[1] + pg = dmesh.get_group(dim) + rankset = dist.get_process_group_ranks(pg) + group_size = len(rankset) + tag = tag or c10d._get_group_tag(pg) + else: + raise ValueError("Invalid tuple for group must be (DeviceMesh, int)") + else: + raise ValueError( + "Invalid type for group, must be one of List, Processgroup, DeviceMesh or (DeviceMesh, int)." + ) + + return (tag, rankset, group_size) + + +def _resolve_group_name(group: RANK_TYPES, tag: str = "") -> c10d.GroupName: + """ + Given group in RANK_TYPES, return the group name. + """ + # `tag` will be deprecated. See details in: + # https://github.com/pytorch/pytorch/issues/93173#issuecomment-1907095208 + if isinstance(group, dist.ProcessGroup): + return group.group_name + elif isinstance(group, str): + # In some cases Dynamo doesn't like tracing through NewType constructors + # - so use a cast instead (the actual newtype representation is + # literally the underlying type so this is fine). I haven't been able to + # reproduce it in isolation (see T247631668). 
+ return cast(c10d.GroupName, group) # c10d.GroupName(group) + elif isinstance(group, DeviceMesh): + if group.ndim != 1: + raise AssertionError( + "Only 1D mesh is supported, pass in (DeviceMesh, int) together if mesh > 1D" + ) + return group._dim_group_names[0] + elif isinstance(group, tuple): + if ( + len(group) == 2 + and isinstance(group[0], DeviceMesh) + and isinstance(group[1], int) + ): + dmesh = group[0] + dim = group[1] + return dmesh._dim_group_names[dim] + else: + raise ValueError("Invalid tuple for group must be (DeviceMesh, int)") + elif isinstance(group, list): + if not is_torchdynamo_compiling(): + warnings.warn( + "The combination of ranks + tag as process group " + "identifier has been deprecated. Please switch to " + "using ProcessGroup, DeviceMesh, or group name instead.", + FutureWarning, + stacklevel=3, + ) + # pyrefly: ignore [redundant-cast] + return c10d._resolve_group_name_by_ranks_and_tag(cast(list[int], group), tag) + else: + raise ValueError(f"Unsupported group type: {type(group)}, {group}") + + +class _FromTorchTensor(torch.autograd.Function): + """ + _FromTorchTensor allows autograd to propagate from a normal Tensor to an + AsyncCollectiveTensor. + """ + + @staticmethod + def forward( # type: ignore[override] + ctx, # pyre-ignore[2]: Parameter must be annotated. + input: torch.Tensor, + ) -> torch.Tensor: + return _maybe_wrap_tensor(input) + + @staticmethod + def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor: # type: ignore[override] + return grad_output + + +def _are_we_tracing() -> bool: + if is_torchdynamo_compiling(): + return True + # If fake mode is turned on, we are almost definitely compiling/tracing. 
+ if torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.FAKE) is not None: + return True + # See Note [enable_python_dispatcher in dynamo] + if torch._C._dispatch_tls_is_dispatch_key_included( + torch._C.DispatchKey.PythonDispatcher + ): + return True + return get_proxy_mode() is not None + + +def _maybe_wrap_tensor(self) -> torch.Tensor: + if _are_we_tracing(): + return wait_tensor(self) + res = AsyncCollectiveTensor(self) + return cast(torch.Tensor, res) + + +@contextlib.contextmanager +def allow_inflight_collective_as_graph_input_ctx(value: bool = True): + """ + Context manager to temporarily set whether inflight collectives are allowed as torch.compile graph inputs. + Common use case is when the collective is issued in eager (with `async_op=True`) but waited in compiled region: + ``` + def all_reduce_eager(x): + y = x * x + req = dist.all_reduce(y, op=dist.ReduceOp.SUM, async_op=True) + return y + + + @torch.compile(fullgraph=True) + def all_reduce_wait_compiled(y): + torch.ops.c10d_functional.wait_tensor(y) + return y * y + + + x = torch.ones(1280, 1280, device="cuda") + self.rank + # the context manager ensures that `wait_tensor(y)` will wait on the correct work object + with allow_inflight_collective_as_graph_input_ctx(): + y = all_reduce_eager(x) + z = all_reduce_wait_compiled(y) + ``` + With this context manager, when a collective is called, under the hood the work object of the collective + will be registered in the work registry, and the wait_tensor() in compiled region called on + the output tensor of the collective will wait on the correct work object. 
+ """ + previous = torch._C._distributed_c10d._allow_inflight_collective_as_graph_input() + + try: + torch._C._distributed_c10d._set_allow_inflight_collective_as_graph_input(value) + yield + finally: + torch._C._distributed_c10d._set_allow_inflight_collective_as_graph_input( + previous + ) + + +def _make_all_gather_out_tensor(input, group_size): + out_size = list(input.size()) + if len(out_size) == 0: + out_size.append(group_size) + else: + out_size[0] *= group_size + out_tensor = input.new_empty(out_size) + return out_tensor + + +def _all_gather_into_tensor_coalesced_meta(self, tag, rankset, group_size): + return [_make_all_gather_out_tensor(t, group_size) for t in self] + + +# We now register meta kernels to deal with tracing +def _broadcast_meta(self, *args): + return torch.empty_like(self) + + +def _all_reduce_meta(self, *args): + return torch.empty_like(self) + + +def _wait_tensor_meta(self, *args): + return torch.empty_like(self) + + +def _all_gather_into_tensor_meta(shard, tag, rankset, group_size): + return _make_all_gather_out_tensor(shard, group_size) + + +def _reduce_scatter_tensor_meta(input, reduce_op, tag, rankset, group_size): + out_size = list(input.size()) + out_size[0] //= group_size + return input.new_empty(out_size) + + +def _all_reduce_coalesced_meta(self, *args): + return [torch.empty_like(t) for t in self] + + +def _all_reduce__meta(inp, *args): + return inp + + +def _broadcast__meta(inp, *args): + return inp + + +def _all_reduce_coalesced__meta(inputs, *args): + return inputs + + +def _reduce_scatter_tensor_coalesced_meta(inputs, reduceOp, tag, rankset, group_size): + def mk_out_tensor(input): + out_size = list(input.size()) + out_size[0] //= group_size + out_tensor = input.new_empty(out_size) + return out_tensor + + return [mk_out_tensor(t) for t in inputs] + + +# NB: We often say all_to_all has dynamic output size, but this is not +# technically true: instead, what typically happens is you manually +# communicate the output_split_sizes 
ahead of time (which is dynamic), +# but then you pass those sizes explicitly, and the all to all itself +# isn't dynamic, it just follows the specified output splits +def _all_to_all_single_meta( + input, output_split_sizes, input_split_sizes, *args, **kwargs +): + if output_split_sizes is None: + return input.new_empty(input.size()) + else: + for s in output_split_sizes: + torch._check(s >= 0) + out_size = list(input.size()) + out_size[0] = sum(output_split_sizes) + return input.new_empty(out_size) + + +def _all_gather_into_tensor_out_native_meta(input, group_size, group_name, *, out): + return _make_all_gather_out_tensor(input, group_size) + + +def _all_gather_into_tensor_native_meta(input, group_size, group_name): + return _make_all_gather_out_tensor(input, group_size) + + +def _all_gather_into_tensor_coalesced_native_meta(inputs, group_size, group_name): + return [ + _all_gather_into_tensor_native_meta(input, group_size, group_name) + for input in inputs + ] + + +def _reduce_scatter_tensor_native_meta(inp, reduce_op, group_size, group_name): + shape = list(inp.size()) + shape[0] //= group_size + return inp.new_empty(shape) + + +def _reduce_scatter_tensor_out_native_meta( + inp, reduce_op, group_size, group_name, *, out +): + shape = list(inp.size()) + shape[0] //= group_size + return inp.new_empty(shape) + + +def _reduce_scatter_tensor_coalesced_native_meta( + inputs, reduce_op, group_size, group_name +): + return [ + _reduce_scatter_tensor_native_meta(inp, reduce_op, group_size, group_name) + for inp in inputs + ] + + +# Library MUST be defined at module scope or it doesn't work +lib_impl = torch.library.Library("_c10d_functional", "IMPL") +lib_impl.impl("all_reduce", _all_reduce_meta, "Meta") +lib_impl.impl("all_reduce_", _all_reduce__meta, "Meta") +lib_impl.impl("all_reduce_coalesced", _all_reduce_coalesced_meta, "Meta") +lib_impl.impl("all_reduce_coalesced_", _all_reduce_coalesced__meta, "Meta") +lib_impl.impl("wait_tensor", _wait_tensor_meta, "Meta") 
+lib_impl.impl( + "all_gather_into_tensor_out", _all_gather_into_tensor_out_native_meta, "Meta" +) +lib_impl.impl("all_gather_into_tensor", _all_gather_into_tensor_native_meta, "Meta") +lib_impl.impl( + "all_gather_into_tensor_coalesced", + _all_gather_into_tensor_coalesced_native_meta, + "Meta", +) +lib_impl.impl("reduce_scatter_tensor", _reduce_scatter_tensor_native_meta, "Meta") +lib_impl.impl( + "reduce_scatter_tensor_out", _reduce_scatter_tensor_out_native_meta, "Meta" +) +lib_impl.impl( + "reduce_scatter_tensor_coalesced", + _reduce_scatter_tensor_coalesced_native_meta, + "Meta", +) +lib_impl.impl("all_to_all_single", _all_to_all_single_meta, "Meta") +lib_impl.impl("broadcast", _broadcast_meta, "Meta") +lib_impl.impl("broadcast_", _broadcast__meta, "Meta") + +# mark these ops has side effect so that they won't be removed by DCE +torch.fx.node.has_side_effect(torch.ops._c10d_functional.wait_tensor.default) # type: ignore[has-type] +torch.fx.node.has_side_effect(torch.ops._c10d_functional.wait_tensor) # type: ignore[has-type] + +# Register legacy ops for backward compatibility +# TODO(yifu): remove these in functional collective beta release +legacy_lib = torch.library.Library("c10d_functional", "DEF") +legacy_lib_impl = torch.library.Library("c10d_functional", "IMPL") +ops_defs = [ + "broadcast(Tensor self, int src, str tag, int[] ranks, int group_size) -> Tensor", + "all_reduce(Tensor self, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor", + "all_reduce_coalesced(Tensor[] self, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor[]", + "wait_tensor(Tensor self) -> Tensor", + "all_gather_into_tensor(Tensor shard, str tag, int[] ranks, int group_size) -> Tensor", + "all_gather_into_tensor_coalesced(Tensor[] input, str tag, int[] ranks, int group_size) -> Tensor[]", + "reduce_scatter_tensor(Tensor input, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor", + "reduce_scatter_tensor_coalesced(Tensor[] inputs, str reduceOp, str 
tag, int[] ranks, int group_size) -> Tensor[]", + "all_to_all_single(Tensor input, SymInt[]? output_split_sizes, SymInt[]? input_split_sizes, str tag, int[] ranks, int group_size) -> Tensor", # noqa: B950 +] + +my_module = sys.modules[__name__] +for op_def in ops_defs: + op_name = op_def[0 : op_def.index("(")] + backend_impl = getattr(fun_col_impl, f"_{op_name}") + legacy_lib.define(op_def, tags=torch.Tag.pt2_compliant_tag) + legacy_lib_impl.impl(op_name, backend_impl, "CompositeImplicitAutograd") + + +""" +Dynamo Remappings allow seamless translation from non-functional collectives of supportable form into +functional collective calls followed by inplace copy ops, allowing them to be traced into a functional graph. + +We implement this by writing a decomposition and teaching dynamo how to associate it to a corresponding op via +the mapping dict below. + +These schemas intentionally match torch.distributed.distributed_c10d.* ops that we are trying to remap from +""" + + +def all_gather_tensor_inplace( + output_tensor: torch.Tensor, + input_tensor: torch.Tensor, + group=None, # TODO add a type, + async_op: bool = False, + tag: str = "", + gather_dim: int = 0, +): + if async_op: + raise AssertionError( + "Can't remap async version of inplace op to functional collective" + ) + + group = group or dist.group.WORLD + if group is None: + raise AssertionError("group cannot be None") + + return output_tensor.copy_(all_gather_tensor(input_tensor, gather_dim, group, tag)) + + +def reduce_scatter_tensor_inplace( + output: torch.Tensor, + input: torch.Tensor, + op: str = "sum", # TODO type is actually c10d ReduceOp. is this ok? 
+ group=None, # TODO add a type + async_op: bool = False, + scatter_dim: int = 0, + tag: str = "", +): + if async_op: + raise AssertionError( + "Can't remap async version of inplace op to functional collective" + ) + + group = group or dist.group.WORLD + if group is None: + raise AssertionError("group cannot be None") + + return output.copy_(reduce_scatter_tensor(input, op, scatter_dim, group, tag)) + + +REDUCE_OP_TO_STR = { + dist.ReduceOp.SUM: "sum", + dist.ReduceOp.AVG: "avg", + dist.ReduceOp.PRODUCT: "product", + dist.ReduceOp.MIN: "min", + dist.ReduceOp.MAX: "max", + dist.ReduceOp.BAND: "band", + dist.ReduceOp.BOR: "bor", + dist.ReduceOp.BXOR: "bxor", +} + + +def all_reduce_inplace( + tensor: torch.Tensor, + op: str = "sum", + group=None, + async_op: bool = False, + tag: str = "", +): + if async_op: + raise AssertionError( + "Can't remap async version of inplace op to functional collective" + ) + + group = group or dist.group.WORLD + if group is None: + raise AssertionError("group cannot be None") + + return tensor.copy_(all_reduce(tensor, op, group, tag)) + + +def all_to_all_inplace( + output: torch.Tensor, + input: torch.Tensor, + output_split_sizes=None, + input_split_sizes=None, + group=None, + async_op=False, + tag: str = "", +): + if async_op: + raise AssertionError( + "Can't remap async version of inplace op to functional collective" + ) + + group = group or dist.group.WORLD + if group is None: + raise AssertionError("group cannot be None") + + return output.copy_( + all_to_all_single( + input, + output_split_sizes, + input_split_sizes, + group, + tag, + ) + ) + + +def all_gather_inplace( + tensor_list: list[torch.Tensor], + tensor: torch.Tensor, + group=None, + async_op=False, + tag: str = "", +): + if async_op: + raise AssertionError( + "Can't remap async version of inplace op to functional collective" + ) + if tensor.dim() != 0 and not all(t.size(0) == tensor.size(0) for t in tensor_list): + raise AssertionError("Remapping variable size all_gather is 
not yet supported") + + group = group or dist.group.WORLD + if group is None: + raise AssertionError("group cannot be None") + + output = all_gather_tensor(tensor, 0, group, tag) + + # Use aten.slice instead of aten.split because the latter causes + # tensor.shape(0) to be unnecessarily baked in when it's a SymInt. + output_splits = [] + offset = 0 + for t in tensor_list: + is_scalar = t.dim() == 0 + t_offset = 1 if is_scalar else t.size(0) + # pyrefly: ignore [unsupported-operation] + out = output[offset] if is_scalar else output[offset : offset + t_offset] + output_splits.append(out) + # pyrefly: ignore [unsupported-operation] + offset += t_offset + for dst, src in zip(tensor_list, output_splits): + dst.copy_(src) + return tensor_list + + +from torch.distributed.distributed_c10d import ( # pyrefly: ignore # deprecated; pyrefly: ignore [deprecated] + _all_gather_base as legacy_all_gather_base, + _reduce_scatter_base as legacy_reduce_scatter_base, + all_gather as legacy_all_gather, + all_gather_into_tensor as legacy_allgather, + all_reduce as legacy_allreduce, + all_to_all_single as legacy_all_to_all_single, + reduce_scatter_tensor as legacy_reducescatter, +) + + +# This dict should contain sets of functions that dynamo is allowed to remap. +# Functions in this set should accept the same args/kwargs 1:1 as their mapping. 
+traceable_collective_remaps = { + legacy_allgather: all_gather_tensor_inplace, # type: ignore[has-type] + legacy_reducescatter: reduce_scatter_tensor_inplace, # type: ignore[has-type] + legacy_allreduce: all_reduce_inplace, # type: ignore[has-type] + legacy_all_to_all_single: all_to_all_inplace, # type: ignore[has-type] + legacy_all_gather: all_gather_inplace, # type: ignore[has-type] + legacy_reduce_scatter_base: reduce_scatter_tensor_inplace, # type: ignore[has-type] + legacy_all_gather_base: all_gather_tensor_inplace, # type: ignore[has-type] +} diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_functional_collectives_impl.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_functional_collectives_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..fcb659b74bc0537b36e447f1a69628e70933d3e9 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_functional_collectives_impl.py @@ -0,0 +1,117 @@ +# mypy: allow-untyped-defs + +import torch +import torch.distributed.distributed_c10d as c10d + + +""" +This file contains the op impls for the legacy (c10d_functional) functional collectives. +These impls simply call into the native (_c10d_functional) functional collectives. 
+""" + + +def _broadcast(input, src, tag, ranks, group_size): + group_name = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag) + return torch.ops._c10d_functional.broadcast( + input, + src, + group_name, + ) + + +def _all_reduce(input, reduce_op, tag, ranks, group_size): + group_name = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag) + return torch.ops._c10d_functional.all_reduce( + input, + reduce_op, + group_name, + ) + + +def _all_reduce_coalesced(inputs, reduce_op, tag, ranks, group_size): + group_name = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag) + return torch.ops._c10d_functional.all_reduce_coalesced( + inputs, + reduce_op, + group_name, + ) + + +def _all_gather_into_tensor(input, tag, ranks, group_size): + group_name = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag) + return torch.ops._c10d_functional.all_gather_into_tensor( + input, + group_size, + group_name, + ) + + +def _all_gather_into_tensor_coalesced(input, tag, ranks, group_size): + group_name = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag) + return torch.ops._c10d_functional.all_gather_into_tensor_coalesced( + input, + group_size, + group_name, + ) + + +def _reduce_scatter_tensor( + input: torch.Tensor, + reduce_op: str, + tag: str, + ranks: list[int], + group_size: int, +): + group_name = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag) + return torch.ops._c10d_functional.reduce_scatter_tensor( + input, + reduce_op, + group_size, + group_name, + ) + + +def _reduce_scatter_tensor_coalesced( + inputs: list[torch.Tensor], + reduce_op: str, + tag: str, + ranks: list[int], + group_size: int, +): + group_name = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag) + return torch.ops._c10d_functional.reduce_scatter_tensor_coalesced( + inputs, + reduce_op, + group_size, + group_name, + ) + + +def _all_to_all_single( + input: torch.Tensor, + output_split_sizes: list[int] | None, + input_split_sizes: list[int] | None, + tag: str, + ranks: list[int], + group_size: int, 
+): + if output_split_sizes is None or input_split_sizes is None: + if not (output_split_sizes is None and input_split_sizes is None): + raise AssertionError( + "output_split_sizes and input_split_sizes must either be " + "specified together or both set to None" + ) + output_split_sizes = [input.shape[0] // group_size] * group_size + input_split_sizes = output_split_sizes + + group_name = c10d._resolve_group_name_by_ranks_and_tag(ranks, tag) + return torch.ops._c10d_functional.all_to_all_single( + input, + output_split_sizes, + input_split_sizes, + group_name, + ) + + +def _wait_tensor(tensor: torch.Tensor) -> torch.Tensor: + return torch.ops._c10d_functional.wait_tensor(tensor) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_mesh_layout.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_mesh_layout.py new file mode 100644 index 0000000000000000000000000000000000000000..38026b7d3d5e1dc94325e6c26b4bf4c4841b4994 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_mesh_layout.py @@ -0,0 +1,309 @@ +""" +Definition of CuTe inspired Layouts for DeviceMesh internal bookkeeping and functions to manipulate them +""" + +import math +from collections.abc import Iterator +from dataclasses import dataclass +from itertools import product + +import torch +from torch.distributed._pycute import ( + as_tuple, + coalesce, + complement, + composition, + flatten, + IntTuple, + is_int, + is_tuple, + Layout, + match_structure, +) + + +@dataclass(frozen=True, init=True) +class _MeshLayout(Layout): + """ + Utility class for representing an integer layout by borrowing ideas from CuTe Layout Algebra. + See https://docs.nvidia.com/cutlass/media/docs/cpp/cute/02_layout_algebra.html for more details. + + Each layout is represented as a list of sizes and strides. 
We use it as a way for mechanical bookkeeping + of the integers such as ranks in a SPMD mesh, and the transformation on top of it. + + Lots of methods of layout like coalesce, composition, complement, etc. are borrowed from pycute. + https://github.com/NVIDIA/cutlass/blob/6dd13d42784ee5bfa232d2441e6b9a021c5c6290/python/pycute/layout.py#L137,L257 + + Note this is a CuTe-inspired layout, because CuTe uses co-lexicographic way in linearization while PyTorch + is using lexicographic. So even though the CuTe documentation can still be referenced, the implementation will be + different from that of PyCute's. + """ + + # pyrefly: ignore [bad-override] + shape: IntTuple + # pyrefly: ignore [bad-override] + stride: IntTuple + + def __post_init__(self) -> None: + if not is_tuple(self.shape) and not is_int(self.shape): + raise TypeError(f"shape must be a tuple or int, got {type(self.shape)}") + if not is_tuple(self.stride) and not is_int(self.stride): + raise TypeError(f"stride must be a tuple or int, got {type(self.stride)}") + if not match_structure(self.shape, self.stride): + raise ValueError( + f"sizes {self.shape} and strides {self.stride} don't match" + ) + + @property + def sizes(self) -> IntTuple: + return self.shape + + @property + def strides(self) -> IntTuple: + return self.stride + + @property + def sizes_and_strides(self) -> Iterator[tuple[int, int]]: + return zip(flatten(self.shape), flatten(self.stride)) + + @property + def top_level_sizes(self) -> tuple[int, ...]: + return tuple(self[i].numel() for i in range(len(self))) + + def numel(self) -> int: + return math.prod(flatten(self.shape)) + + # # operator [] (get-i like tuples) + def __getitem__(self, i: int) -> "_MeshLayout": + if i < -len(self) or i >= len(self): + raise IndexError( + f"Dim {i} is out of range for layout with {len(self)} dimensions. " + f"Expected dim to be in range [{-len(self)}, {len(self) - 1}]." 
+ ) + layout = super().__getitem__(i) + return _MeshLayout(layout.shape, layout.stride) + + def nest(self) -> "_MeshLayout": + return _MeshLayout((self.shape,), (self.stride,)) + + def coalesce(self) -> "_MeshLayout": + """ + A layout is represented by (sizes):(strides), e.g. (3,2):(4,2). + Two consecutive dimensions can be "merged" into one if their + strides are contiguous/multiplicative (i.e., the inner stride * inner size + equals the next stride), we perform this kind of merge inside coalesce. + + Example 1 (simple): (3,2):(2,1) + - inner dimension: has stride=1, size=2 + - outer dimension: stride = inner_stride * inner_size = 2 + → coalesced = (6:1) # acts like a flat 1D array of length 6 + + Example 2 (non-coalescible): (3,2):(4,1) + - inner dimension: stride=1, size=2 → 2*1 = 2 + - outer dimension: stride=4, mismatch (≠ 2) + → cannot merge; result stays (3,2):(4,1) + """ + layout = coalesce(self) + return _MeshLayout(layout.shape, layout.stride) + + def composition(self, layout: "_MeshLayout") -> "_MeshLayout": + """ + By-dimension composition allows one layout to "select from" or "filter through" another layout. + Think of it as function composition: (self ∘ layout)(input) = self(layout(input)) + between two layouts. This function is a wrapper of pycute's composition. + + Mental model about how to understand the composition logic: + - The LEFT layout (self) defines the "output space" - what indices are possible + - The RIGHT layout (layout parameter) acts as a "selector" - which specific indices to pick + - The composition only generates indices that the left layout could originally produce, + but the right layout determines which indices to be picked. + - The stride of the composition layout will not be smaller than the stride of the right layout, + because when picking the indices the composition will at least follow the the right layout's stride + to move forward. 
+ + Example: + self = (6,2):(2,1) # sizes=(6,2), strides=(2,1) + layout = (3:2) # sizes=(3,), stride=(2,) + self o layout = (3:2) + + Returns: + Layout being composed. + """ + result = composition(self, layout) + return _MeshLayout(result.shape, result.stride) + + def complement(self, world_size: int) -> "_MeshLayout": + """ + Compute the "complement layout" relative to a given world_size. + A complement layout fills in the "missing" factor so that: self repeat a layout of complement(self, world_size) + will get a complete world_size. We use ⊗ to denote the repeat operation. + + Example: + self = (4:1) # size=4, stride=1 + world_size = 8 + Then: + complete needed factor = 8 / 4 = 2 + complement(self, 8) = (2:1) + + Together they form: + (4:1) ⊗ (2:1) = (4,2):(2,1) + which has world_size = 4 * 2 = 8, as required. + + In distributed terms, complement() is often used to derive the "other" + rank grouping when splitting processes into 2D meshes. + + For a visualized explanation, see https://x.com/ezyang/status/1962364978393981433/ + """ + layout = complement(self, world_size) + return _MeshLayout(layout.shape, layout.stride) + + def splice(self, start: int, end: int, layout: "_MeshLayout") -> "_MeshLayout": + sizes = list(as_tuple(self.sizes)) + strides = list(as_tuple(self.strides)) + sizes[start:end] = list(as_tuple(layout.sizes)) + strides[start:end] = list(as_tuple(layout.strides)) + return _MeshLayout(tuple(sizes), tuple(strides)) + + def all_ranks_from_zero(self) -> list[int]: + """ + This function computes the all ranks specified by the layout staring from zero. + + How it works: + 1. we enumerates every possible coordinate (like a nested for-loop). + If sizes = (2, 3), we get the following coordinates: + (0,0), (0,1), (0,2), (1,0), (1,1), (1,2) + + 2. 
For each coordinate, we compute a linear rank index as: + all_ranks_from_zero = sum(coord[i] * strides[i] for i in range(ndim)) + + Example A: + sizes = (2, 3) # 2 rows, 3 cols + strides = (3, 1) # row-major layout + coords = (0,0) -> 0*3 + 0*1 = 0 + (0,1) -> 0*3 + 1*1 = 1 + (0,2) -> 0*3 + 2*1 = 2 + (1,0) -> 1*3 + 0*1 = 3 + (1,1) -> 1*3 + 1*1 = 4 + (1,2) -> 1*3 + 2*1 = 5 + result = [0, 1, 2, 3, 4, 5] + + Example B: + sizes = (2, 3) + strides = (1, 2) # non-standard / strided layout + coords = (0,0) -> 0*1 + 0*2 = 0 + (0,1) -> 0*1 + 1*2 = 2 + (0,2) -> 0*1 + 2*2 = 4 + (1,0) -> 1*1 + 0*2 = 1 + (1,1) -> 1*1 + 1*2 = 3 + (1,2) -> 1*1 + 2*2 = 5 + result = [0, 2, 4, 1, 3, 5] + """ + return [ + sum(c * s for c, s in zip(coord, flatten(self.strides))) + for coord in product(*(range(s) for s in flatten(self.sizes))) + ] + + def global_ranks(self, world_size: int) -> list[list[int]]: + """ + Build global ranks specified by the layout via two-level ranks composition. + + The nested list forms the Cartesian product of all ranks for one layout and offset + regarding filling up the world_size with the layout. + The final global ranks are the addition of these two. The result is a + list of lists: one sublist per layout. This rank list will be used to build + the communicator underlying the layout and the given `world_size`. + + Example: + world_size = 16 + self.size = 4 + self.stride = 1 + ranks = [0, 1, 2, 3] + offsets = [0, 4, 8, 12] + result = [ + [0+0, 0+1, 0+2, 0+3], # → [0, 1, 2, 3] + [4+0, 4+1, 4+2, 4+3], # → [4, 5, 6, 7] + [8+0, 8+1, 8+2, 8+3], # → [8, 9, 10,11] + [12+0, 12+1, 12+2, 12+3], # → [12,13,14,15] + ] + """ + return [ + [offset + rank for rank in self.all_ranks_from_zero()] + for offset in self.complement(world_size).all_ranks_from_zero() + ] + + def check_non_overlap(self) -> bool: + """ + Check if the layout has any overlap between the ranks it generates. If there is overlap, + we return False, otherwise True. 
+ + The layout is supposed to be injective i.e, aside from indice 0, indices from each + dim of the layout must be non-overlapping. + + Example 1 - Valid (no overlap): + Layout: sizes=(2,3), strides=(6,1) + - Dim 1: stride=1, span=3*1=3, covers indices [0,1,2] + - Dim 0: stride=6, span=2*6=12, covers indices [0,6] + → No overlap since 6 > 3 + + Example 2 - Invalid (overlap): + Layout: sizes=(2,3), strides=(2,1) + - Dim 1: stride=1, span=3*1=3, covers indices [0,1,2] + - Dim 0: stride=2, span=2*2=4, covers indices [0,2] + → Overlap! stride=2 < span=3, so indices [0,2] are duplicated + + Example 3 - Invalid (overlap): + Layout: sizes=(4,2), strides=(1,1) + - Dim 1: stride=1, span=4, covers indices [0,1,2,3] + - Dim 0: stride=1, span=2, covers indices [0,1] + → Overlap! stride is same for two dims, so indices [0,2] are duplicated + + Returns: + bool: True if no overlap, False if overlap detected + """ + ranks = self.all_ranks_from_zero() + return len(ranks) == len(set(ranks)) + + def remap_to_tensor(self, rank_map: torch.Tensor) -> torch.Tensor: + """ + Leverage layout as an index for mesh tensor that re-maps the indexes after layout + transformation to actual device ranks. + + With this method, the cute layout serves as the backend of indices bookkeeping for the + mesh tensor when it comes to flatten, unflatten and slicing operations. The actual mesh + tensor still represents the actual device assignment and ranks. We need this function + to specify device allocation and create backend for a mesh. Although any transform of mesh tensors + can be treated as a view or subset of mesh tensor, we do need to use the actual view or + sub-tensor for DeviceMesh and its backend creation. + + The shape of the `rank_map` must be 1D and contiguous. 
+ + Examples: + + Case 1 - Consecutive ranks, full world: + original_mesh_tensor = [[0,1],[2,3]] # 2x2 mesh, ranks 0-3 + world_size = 4 + layout = Layout(2:2) + Return: [[0,2],[1,3]] + + Case 2 - Non-consecutive ranks: + original_mesh_tensor = [[10,20],[30,40]] # custom rank assignment + world_size = 4 + layout = Layout(2:2) + Return: [[[10,30],[20,40]]] + + Args: + rank_map: The concrete mesh tensor with actual device ranks + + Returns: + torch.Tensor: A tensor representing the actual device allocation from rank_map + """ + assert rank_map.ndim == 1 + assert rank_map.is_contiguous() + assert rank_map.numel() >= self.cosize() + + complement_layout = self.complement(rank_map.numel()) + + return rank_map.as_strided( + flatten(complement_layout.sizes) + flatten(self.sizes), + flatten(complement_layout.strides) + flatten(self.strides), + ).reshape(-1, *self.top_level_sizes) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_serialization.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_serialization.py new file mode 100644 index 0000000000000000000000000000000000000000..8f7043453be769cccfb70cd391dac87348508016 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_serialization.py @@ -0,0 +1,158 @@ +import pickle +from dataclasses import dataclass +from io import BufferedIOBase +from typing import Any + +import torch +import torch._weights_only_unpickler as _weights_only_unpickler +from torch.serialization import _load, _save, DEFAULT_PROTOCOL, MAP_LOCATION + + +__all__: list[str] = [] + + +@dataclass +class _Entry: + key: str + is_storage: bool + length: int + + +_weights_only_unpickler._add_safe_globals([_Entry]) + + +class _PseudoZipFile: + def __init__(self) -> None: + self.records: dict[str, tuple[object, int]] = {} + + def write_record(self, key: str, data: object, length: int) -> None: + self.records[key] = (data, length) + + def write_to(self, f: 
BufferedIOBase) -> None: + entries = [] + for key, (data, length) in self.records.items(): + entries.append( + _Entry( + key=key, + is_storage=isinstance(data, torch.UntypedStorage), + length=length, + ) + ) + + pickle.dump(entries, f, protocol=DEFAULT_PROTOCOL) + + for data, _ in self.records.values(): + if isinstance(data, bytes): + f.write(data) + elif isinstance(data, str): + f.write(data.encode("utf-8")) + elif isinstance(data, torch.UntypedStorage): + data._write_file(f, False, False, 1) + else: + raise TypeError(f"unknown type: {type(data)}") + + def read_from(self, f: BufferedIOBase) -> None: + entries = _weights_only_unpickler.load(f) + + for entry in entries: + data = f.read(entry.length) + if entry.is_storage: + if entry.length == 0: + storage = torch.UntypedStorage(0) + else: + storage = torch.frombuffer( + data, + dtype=torch.uint8, + ).untyped_storage() + + self.records[entry.key] = ( + storage, + entry.length, + ) + else: + self.records[entry.key] = (data, entry.length) + + def has_record(self, key: str) -> bool: + return key in self.records + + def get_record(self, key: str) -> object: + return self.records[key][0] + + def get_storage_from_record( + self, key: str, _length: int, _type: int + ) -> torch.Tensor: + return torch.tensor(self.records[key][0], dtype=torch.uint8) + + def serialization_id(self) -> str: + return "torchft" + + +def _streaming_save( + obj: object, + f: BufferedIOBase, + pickle_module: Any = pickle, + pickle_protocol: int = DEFAULT_PROTOCOL, +) -> None: + """ + Save the object to a file-like object in a streaming fashion compatible with + network sockets. + + This behaves similarly to :func:`torch.save` with a few notable differences: + + * A non-seekable file like object can be used when loading. + * No forwards/backwards compatibility is provided for the serialization + format. This is only intended to be used with a single version of PyTorch + with transient storage (i.e. sockets or temp files). 
+ * mmap is not supported + + See :func:`torch.save` for more details on specific arguments. + """ + + zip_file = _PseudoZipFile() + _save( + obj, + zip_file=zip_file, + pickle_module=pickle_module, + pickle_protocol=pickle_protocol, + _disable_byteorder_record=False, + ) + zip_file.write_to(f) + + +def _streaming_load( + f: BufferedIOBase, + map_location: MAP_LOCATION = None, + pickle_module: Any = None, + *, + weights_only: bool = True, + **pickle_load_args: Any, +) -> object: + """ + Load the object from a file-like object in a streaming fashion compatible with + network sockets. + + See :func:`_streaming_save` for more details about the streaming behavior. + + See :func:`torch.load` for more details on specific arguments. + """ + if weights_only: + if pickle_module is not None: + raise RuntimeError( + "Can not safely load weights when explicit pickle_module is specified" + ) + pickle_module = _weights_only_unpickler + else: + if pickle_module is None: + pickle_module = pickle + + if "encoding" not in pickle_load_args: + pickle_load_args["encoding"] = "utf-8" + + zip_file = _PseudoZipFile() + zip_file.read_from(f) + return _load( + zip_file=zip_file, + map_location=map_location, + pickle_module=pickle_module, + **pickle_load_args, + ) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_state_dict_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_state_dict_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5cb614e89cf9c29c8ed9a07d36b60e7f867002a8 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/_state_dict_utils.py @@ -0,0 +1,830 @@ +# mypy: allow-untyped-defs +import copy +import io +import math +import weakref +from collections.abc import Callable, Mapping, MutableMapping +from typing import Any, cast, NamedTuple, TYPE_CHECKING, Union + +import torch +import torch.cuda._pin_memory_utils as pin_memory_utils 
+import torch.distributed as dist +import torch.nn.functional as F +from torch.distributed._functional_collectives import AsyncCollectiveTensor + + +if dist.is_available() or TYPE_CHECKING: + from torch.distributed import distributed_c10d + from torch.distributed._shard.sharded_tensor import ShardedTensor + from torch.distributed.tensor import distribute_tensor, DTensor, Replicate + from torch.distributed.tensor._utils import compute_local_shape_and_global_offset + + +def _identity_func( + obj: torch.Tensor, + pg: dist.ProcessGroup | None, + device: torch.device | None, + companion_obj: Any, +) -> torch.Tensor: + return obj + + +def _all_gather_sharded_tensor( + sharded_tensor: "ShardedTensor", + pg: dist.ProcessGroup | None = None, + device: torch.device | None = None, +) -> torch.Tensor: + if pg is None: + pg = distributed_c10d._get_default_group() + world_size = dist.get_world_size(pg) + shards = sharded_tensor.local_shards() + dim_0_size = sharded_tensor.size()[0] # type: ignore[index] + tensor_numel = sharded_tensor.size().numel() # type: ignore[union-attr] + chunk_size = math.ceil(dim_0_size / world_size) * tensor_numel // dim_0_size + pg_device = ( + distributed_c10d._get_pg_default_device(pg) if device is None else device + ) + if shards: + local_tensor = shards[0].tensor.flatten() + if local_tensor.device.type != pg_device.type: + local_tensor = local_tensor.to(pg_device) + num_padding = chunk_size - local_tensor.numel() + if num_padding > 0: + local_tensor = F.pad(local_tensor, [0, num_padding]) + else: + local_tensor = torch.zeros( + chunk_size, dtype=sharded_tensor.dtype, device=pg_device + ) + + tensor = torch.empty( + chunk_size * world_size, + dtype=local_tensor.dtype, + device=pg_device, + ) + dist.all_gather_into_tensor(tensor, local_tensor, group=pg) + + tensor = tensor.narrow(0, 0, tensor_numel).reshape(sharded_tensor.size()) + return tensor + + +class CompanionMismatch(Exception): + pass + + +def _iterate_state_dict( + iter_object: Any, + 
sharded_tensor_func: Callable, + dtensor_func: Callable, + tensor_func: Callable, + *, + pg: dist.ProcessGroup | None = None, + device: torch.device | None = None, + cpu_offload: bool = False, + companion_obj: Any = None, + ranks_only: tuple[int, ...] = (), + type_check: bool = True, + non_blocking: bool = True, +) -> dict[str, Any]: + """Iterate through the state dict, applying the given functions to each tensor type. + + Args: + iter_object (Any): the target state_dict. + sharded_tensor_func (Callable): the function to apply to ShardedTensor + dtensor_func (Callable): the function to apply to DTensor + tensor_func (Callable): the function to apply to Tensor + pg (Optional[dist.ProcessGroup]): process group passed to tensor functions + device (Optional[torch.device]): device passed to tensor functions + cpu_offload (bool): whether to offload the tensors to CPU memory. This option is ignored + if a companion_obj is supplied. + companion_obj (Any): A companion object to the state dict. If this object + is supplied, we attempt to copy the tensor to the companion object. + ranks_only (Tuple[int, ...]): if this tuple is empty, all ranks will + have the same state_dicts. Otherwise only ranks that in ``ranks_only`` + have the same state_dicts. Other ranks will get empty state_dicts. + type_check (bool): check if the instance data type is a supported type + that can be saved by DCP. The current supported data types are + torch.Tensor, DTensor, int, float, str, list, dict, None. + non_blocking (bool): whether to use non-blocking copy when copying to the companion object. + """ + # TODO: should we use pytree? 
+ cpu_device = torch.device("cpu") + if isinstance(iter_object, ShardedTensor): + ret = sharded_tensor_func(iter_object, pg, device, companion_obj) + elif isinstance(iter_object, DTensor): + ret = dtensor_func(iter_object, pg, device, companion_obj) + elif isinstance(iter_object, torch.Tensor): + ret = tensor_func(iter_object, pg, device, companion_obj) + elif ( + isinstance(iter_object, (int, float, str, bytes, io.BytesIO)) + or iter_object is None + ): + ret = iter_object + elif isinstance(iter_object, dict): + if companion_obj is not None and ( + not isinstance(companion_obj, dict) + or set(companion_obj.keys()) != set(iter_object.keys()) + ): + msg = ( + "" + if isinstance(companion_obj, dict) + else f"{set(companion_obj.keys())=} {set(iter_object.keys())=}" + ) + raise CompanionMismatch(msg) + + ret = { + key: _iterate_state_dict( + value, + sharded_tensor_func, + dtensor_func, + tensor_func, + pg=pg, + device=device, + cpu_offload=cpu_offload, + companion_obj=companion_obj[key] if companion_obj is not None else None, + ranks_only=ranks_only, + type_check=type_check, + non_blocking=non_blocking, + ) + for key, value in iter_object.items() + } + elif isinstance(iter_object, (list, tuple)): + if companion_obj is not None and ( + not isinstance(companion_obj, (list, tuple)) + or len(companion_obj) != len(iter_object) + ): + raise CompanionMismatch + + ret = [ + _iterate_state_dict( + v, + sharded_tensor_func, + dtensor_func, + tensor_func, + pg=pg, + device=device, + cpu_offload=cpu_offload, + companion_obj=companion_obj[idx] if companion_obj is not None else None, + ranks_only=ranks_only, + type_check=type_check, + non_blocking=non_blocking, + ) + for idx, v in enumerate(iter_object) + ] + if isinstance(iter_object, tuple): + ret = tuple(ret) + elif not type_check: + ret = copy.deepcopy(iter_object) + else: + raise ValueError(f"Unexpected value type {type(iter_object)}") + + if not ranks_only or dist.get_rank(pg) in ranks_only: + if isinstance(ret, 
torch.Tensor): + if cpu_offload and companion_obj is None: + ret = ret.to(cpu_device) + + if companion_obj is not None: + if isinstance(companion_obj, DTensor): + if not isinstance(ret, DTensor): + raise AssertionError( + "ret must be a DTensor when companion_obj is a DTensor" + ) + companion_obj._local_tensor.copy_( + ret._local_tensor, non_blocking=non_blocking + ) + elif isinstance(companion_obj, ShardedTensor): + if not isinstance(ret, ShardedTensor): + raise AssertionError( + "ret must be a ShardedTensor when companion_obj is a ShardedTensor" + ) + for idx, shard in enumerate(companion_obj.local_shards()): + shard.tensor.copy_( + ret.local_shards()[idx].tensor, non_blocking=non_blocking + ) + else: + # pyrefly: ignore [missing-attribute] + companion_obj.copy_(ret, non_blocking=non_blocking) + ret = companion_obj + else: + ret = {} if isinstance(ret, dict) else None + + # pyrefly: ignore [bad-return] + return ret + + +def _gather_state_dict( + state_dict: dict[str, Any], + *, + pg: dist.ProcessGroup | None = None, + device: torch.device | None = None, + cpu_offload: bool = False, + ranks_only: tuple[int, ...] = (), + type_check: bool = True, +) -> dict[str, Any]: + """ + Given a state_dict, this API gathers all the ShardedTensors or DTensors in + the state_dict. + + + Args: + state_dict (Dict[str, Any]): the target sharded state_dict. + pg (Optional[dist.ProcessGroup]): the process group that is used to + gather ShardedTensor. Note that gathering a DTensor will use + the DeviceMesh. So this argument will be ignored when gathering a + DTensor. + device: (Optional[torch.device]): the device that is used to + perform allgather for ShardedTensor. Note that gathering a DTensor + will use the DeviceMesh. So this argument will be ignored when + gathering a DTensor. + cpu_offload (bool): whether to offload the tensors to CPU memory. The + default value is False. + ranks_only: (Tuple[int, ...]): if this tuple is empty, all ranks will + have the same state_dicts. 
Otherwise only ranks that in ``ranks_only`` + have the same state_dicts. Other ranks will get empty state_dicts. + type_check: (bool): check if the instance data type is a supported type + that can be saved by DCP. The current supported data types are + torch.Tensor, DTensor, int, float, str, list, dict, None. + + Returns: + The gathered state dictionary. + """ + + def sharded_tensor_func(value, pg, device, companion_obj): + # ShardedTensor does not seem to record the original device type. + # So if the tensor is moved to CPU, we won't know the original type. + # As a result, we have to rely on the user to tell us the correct one. + cpu_device = torch.device("cpu") + output_tensor = _all_gather_sharded_tensor(value, pg, device) + local_shard_device = ( + value.local_shards()[0].tensor.device + if value.local_shards() + else cpu_device + ) + if output_tensor.device != local_shard_device: + value = output_tensor.to(local_shard_device) + else: + value = output_tensor + return value + + def dtensor_func(value, pg, device, companion_obj): + if value.device != value.device_mesh.device_type: + value = value.to(value.device_mesh.device_type) + # FSDP all_gather: [Shard(0)] -> [Replicate()] + # HSDP all_gather: [Replicate(), Shard(0)] -> [Replicate(), Replicate()] + # 2D FSDP + TP all_gather: + # - [Shard(0), Shard(n)] -> [Replicate(), Replicate()] + # - [Shard(0), Replicate()] -> [Replicate(), Replicate()] + placements = [Replicate() for _ in value.placements] + value = value.redistribute( + device_mesh=value.device_mesh, + placements=placements, + ) + # Call `wait()` to force the tensor to be synchronous with respect + # to the main stream. + # See the discussion in https://github.com/pytorch/pytorch/pull/117799. 
+ value = value.to_local() + if isinstance(value, AsyncCollectiveTensor): + value = value.wait() + return value + + return _iterate_state_dict( + state_dict, + sharded_tensor_func, + dtensor_func, + _identity_func, + pg=pg, + device=device, + cpu_offload=cpu_offload, + ranks_only=ranks_only, + type_check=type_check, + ) + + +def _offload_state_dict_to_cpu( + state_dict: dict[str, Any], + *, + ranks_only: tuple[int, ...] = (), + type_check: bool = True, +) -> dict[str, Any]: + """ + Given a state_dict, this API offload all the tensors to CPU memory. + + Args: + state_dict (Dict[str, Any]): the target state_dict. + pg (Optional[dist.ProcessGroup]): the process group that is used to + gather ShardedTensor. Note that gathering a DTensor will use + the DeviceMesh. So this argument will be ignored when gathering a + DTensor. + ranks_only: (Tuple[int, ...]): if this tuple is empty, all ranks will + have the same state_dicts. Otherwise only ranks that in ``ranks_only`` + have the same state_dicts. Other ranks will get empty state_dicts. + type_check: (bool): check if the instance data type is a supported type + that can be saved by DCP. The current supported data types are + torch.Tensor, DTensor, int, float, str, list, dict, None. + + Returns: + The gathered state dictionary. + """ + + ret = _iterate_state_dict( + state_dict, + _identity_func, + _identity_func, + _identity_func, + pg=None, + device=None, + cpu_offload=True, + ranks_only=ranks_only, + type_check=type_check, + ) + return ret + + +@torch.no_grad() +def _copy_state_dict( + state_dict: dict[str, Any], + copy_state_dict: dict[str, Any], + non_blocking: bool = False, + type_check: bool = True, +) -> dict[str, Any]: + """ + Copies all tensors in a given state dict into a different state_dict with the + same structure. Additionally, a copied state dict with the same value references + is returned. 
Editing the keys on this state dict will not affect the + passed in copy_state_dict (but the value references are the same). + + .. warning:: + It is expected by this function that state_dict and copy_state_dict share + the same structure and data types. + + .. warning:: + The current supported data types are + torch.Tensor, DTensor, int, float, str, list, dict, None. + + Args: + state_dict (Dict[str, Any]): the target state_dict. + copy_state_dict (Dict[str, Any]): + The state dict we are copying into. This state_dict must have exactly + the same structure as the source `state_dict`. + non_blocking: (bool): Whether copy ops should be performed asynchronously + type_check (bool): check if the instance data type is a supported type + that can be saved by DCP. The current supported data types are + torch.Tensor, DTensor, int, float, str, list, dict, None. + + Returns: + State Dict copy + """ + + return _iterate_state_dict( + state_dict, + _identity_func, + _identity_func, + _identity_func, + pg=None, + device=None, + cpu_offload=False, + ranks_only=(), + companion_obj=copy_state_dict, + type_check=type_check, + non_blocking=non_blocking, + ) + + +@torch.no_grad() +def _create_cpu_state_dict( + state_dict: dict[str, Any], pin_memory: bool = False, share_memory: bool = False +) -> dict[str, Any]: + """ + Given a state_dict, create another state_dict with the same structure and elements. + However, all tensors in the returned state_dict are new tensors on CPU. These + tensors can be placed on pin_memory or share_memory based on the provided arguments. + + .. warning:: + Setting both `pin_memory` and `share_memory` to True significantly increases the + latency of this method because of the nuances which require us to register memory + as pinned directly as opposed to relying on the pin_memory cache allocator. This + option should only be used for long lived tensors which are required to be shared. 
+ This is not the case as long as at least one of `pin_memory` or `share_memory` is + set to False. + + """ + + def tensor_func( + obj: torch.Tensor, + pg: dist.ProcessGroup | None, + device: torch.device | None, + _: Any, + ) -> torch.Tensor: + if len(obj.size()) == 0: + return torch.tensor(0, dtype=obj.dtype) + + # sometimes, a tensor might have non-zero size and 0 numel. In this case, pinning memory will fail + # so we take a best guess at how to replicate the tensor below to maintain symmetry in the returned + # state dict. + if obj.numel() == 0 or obj.data_ptr() == 0: + t = torch.zeros_like(obj, device="cpu") + if share_memory: + t = t.share_memory_() + return t + + if share_memory: + t = torch.empty(*tuple(obj.size()), dtype=obj.dtype) + t = t.share_memory_() + if pin_memory: + pin_memory_utils.pin_memory(t.data_ptr(), t.numel() * t.element_size()) + weakref.finalize(t, pin_memory_utils.unpin_memory, t.data_ptr()) + + return t + elif pin_memory: + return torch.empty(*tuple(obj.size()), dtype=obj.dtype).pin_memory() + else: + return torch.empty(*tuple(obj.size()), dtype=obj.dtype) + + def dtensor_func( + obj: DTensor, + pg: dist.ProcessGroup | None, + device: torch.device | None, + _: Any, + ) -> DTensor: + if len(obj.size()) == 0: + return obj + + if obj.device != torch.device("cpu"): + ret = cast(DTensor, obj.to(device="cpu")) + else: + ret = copy.deepcopy(obj) + ret._local_tensor = tensor_func(ret._local_tensor, pg, device, None) + return ret + + def sharded_tensor_func( + obj: ShardedTensor, + pg: dist.ProcessGroup | None, + device: torch.device | None, + _: Any, + ) -> ShardedTensor: + if not obj.local_shards(): + return obj + + if obj.device != torch.device("cpu"): + ret = obj.to(device="cpu") + else: + ret = copy.deepcopy(obj) + + for shards in ret.local_shards(): + shards.tensor = tensor_func(shards.tensor, pg, device, None) + + return ret + + ret = _iterate_state_dict( + state_dict, + sharded_tensor_func, + dtensor_func, + tensor_func, + pg=None, + 
device=None, + cpu_offload=False, + ranks_only=(), + type_check=False, + ) + return ret + + +def _check_state_dict_similarity( + state_dict: dict[str, Any], + compared_state_dict: dict[str, Any], +) -> bool: + """ + Given two state_dicts, check if the structures are the same. And + if a [key, tensor] pair exist in one state_dict there must be + the a corresponding pait, [key, other_tensor], in the other state_dict, + where tensor and other_tensor have the same size and dtype. + + Return the check result. + """ + + def tensor_func( + obj: torch.Tensor, + pg: dist.ProcessGroup | None, + device: torch.device | None, + companion_obj: Any, + ) -> torch.Tensor: + if companion_obj.dtype != obj.dtype or companion_obj.size() != obj.size(): + raise CompanionMismatch + return obj + + try: + _iterate_state_dict( + state_dict, + _identity_func, + _identity_func, + tensor_func, + pg=None, + device=None, + cpu_offload=False, + ranks_only=(), + companion_obj=compared_state_dict, + type_check=False, + ) + except CompanionMismatch: + return False + + return True + + +class _TensorInfo(NamedTuple): + size: torch.Size + dtype: torch.dtype + + +def _broadcast_tensors( + full_state_dict: dict[str, Any], + local_state_dict: dict[str, Any], + keys: list[str], + device: torch.device, + pg: dist.ProcessGroup | None = None, +) -> None: + if pg is None: + pg = dist.distributed_c10d._get_default_group() + pg_device = ( + device + if device.type in {pg_device.type for pg_device in pg._device_types} + else pg._device_types[0] + ) + + tensors: list[torch.Tensor] = [] + for key in keys: + if dist.get_rank() == 0: + full_state = full_state_dict[key] + if not isinstance(full_state, torch.Tensor): + raise AssertionError("full_state must be a torch.Tensor") + full_tensor = full_state.detach().to(pg_device) + else: + tensor_info = full_state_dict[key] + full_tensor = torch.empty( + size=tensor_info.size, + device=pg_device, + dtype=tensor_info.dtype, + ) + + tensors.append(full_tensor) + + if 
(local_state := local_state_dict.get(key)) is None: + continue + + local_state_dict[key] = ( + (local_state, full_tensor) + if isinstance(local_state, DTensor) + else full_tensor + ) + + if len(tensors) > 1: + dist._broadcast_coalesced(pg, tensors, 500, 0) + else: + dist.broadcast(tensors[0], src=0, group=pg) + + if pg_device != device: + for key, full_tensor in zip(keys, tensors): + if (local_state := local_state_dict.get(key)) is not None: + local_state_dict[key] = ( + (local_state[0], full_tensor.to(device)) + if ( + isinstance(local_state, tuple) + and isinstance(local_state[0], DTensor) + ) + else full_tensor.to(device) + ) + + _distribute_tensors(local_state_dict, keys, device, pg) + + +def _distribute_tensors( + local_state_dict: dict[str, Any], + keys: list[str], + device: torch.device, + pg: dist.ProcessGroup | None = None, +) -> None: + if pg is None: + pg = dist.distributed_c10d._get_default_group() + for key in keys: + _local_state = local_state_dict.get(key) + if _local_state is None or torch.is_tensor(_local_state): + continue + + local_state = _local_state[0] + full_tensor = _local_state[1] + + shape, offset = compute_local_shape_and_global_offset( + full_tensor.shape, local_state.device_mesh, local_state.placements + ) + slices = [ + slice(cur_offset, cur_offset + cur_shape) + for cur_shape, cur_offset in zip(shape, offset) + ] + if local_state.is_meta: + # Use .clone() here rather than view to clone and return only the sliced portion, minimizing memory access and cost. + local_tensor = full_tensor[tuple(slices)].detach().clone() + # TODO: currently, we cannot handle strided sharding if the dp dimension is not even. For example, + # one of the case that is not yet supported is when placements = (Shard(0), _StridedShard(0, sf=2)). 
+ ret = DTensor.from_local( + local_tensor, + local_state.device_mesh, + local_state.placements, + shape=local_state.shape, + stride=local_state.stride(), + ) + else: + ret = local_state + # Copy full_tensor[slices] into local_state.to_local() to reduce memory footprint. + ret.to_local().copy_(full_tensor[tuple(slices)]) + local_state_dict[key] = ret + + +def _broadcast_state_dict( + full_state_dict: dict[str, Any], + local_state_dict: dict[str, Any], + device: torch.device, + pg: dist.ProcessGroup | None = None, + strict: bool = False, + cpu_offload: bool = False, +) -> None: + # Broadcast from rank0's `full_state_dict` to all ranks' `local_state_dict`. + # If strict is True, any keys in `local_state_dict` but not in `full_state_dict` + # will be removed from `local_state_dict`. + ret = {} + if dist.get_rank() == 0: + for key, value in full_state_dict.items(): + if not torch.is_tensor(value): + ret[key] = value + elif value.dim() == 0: + ret[key] = value.cpu() + else: + ret[key] = _TensorInfo(value.size(), value.dtype) + + broadcast_list = [ret] + dist.broadcast_object_list(broadcast_list, src=0, group=pg) + ret = broadcast_list[0] + # Gather values + keys = [] + local_state_dict_keys = set(local_state_dict.keys()) + global_keys = set() + for key, value in ret.items(): + global_keys.add(key) + if not isinstance(value, _TensorInfo): + if key in local_state_dict: + local_state_dict[key] = value + continue + + if dist.get_rank() == 0: + ret[key] = full_state_dict[key] + + keys.append(key) + # Broadcast every tensor to avoid OOM for now. 
+ if len(keys) >= 1: + _broadcast_tensors(ret, local_state_dict, keys, device, pg) + if cpu_offload: + for key in keys: + local_state_dict[key] = local_state_dict[key].cpu() + keys.clear() + + if strict: + if missing_keys := (local_state_dict_keys - global_keys): + for key in missing_keys: + local_state_dict.pop(key) + + if keys: + _broadcast_tensors(ret, local_state_dict, keys, device, pg) + if cpu_offload: + for key in keys: + local_state_dict[key] = local_state_dict[key].cpu() + + +def _distribute_state_dict( + full_state_dict: dict[str, Any], + local_state_dict: dict[str, Any], + device: torch.device, + pg: dist.ProcessGroup | None = None, +) -> None: + # Full_state_dict = True, broadcast_from_rank0 = False here. Each rank has + # full_state_dict. Skip the broadcast in ``_broadcast_state_dict`` and + # distribute tensors in each rank + for key, value in full_state_dict.items(): + if key not in full_state_dict: + continue + if not torch.is_tensor(value): + local_state_dict[key] = value + elif value.dim() == 0: + local_state_dict[key] = value.cpu() + else: + if not isinstance(value, torch.Tensor): + raise AssertionError("value must be a torch.Tensor") + local_state = local_state_dict.get(key) + if local_state is None: + continue + elif isinstance(local_state, DTensor): + local_state_dict[key] = distribute_tensor( + value.detach().to(device), + local_state.device_mesh, + local_state.placements, + ) + else: + local_state_dict[key] = value.detach().to(device) + + +# These APIs are from torch.distributed.checkpoint. +# TODO: We should consolidate the code here as some not all modules can depend on +# DCP. +PATH_ITEM = Union[str, int] +OBJ_PATH = tuple[PATH_ITEM, ...] 
+FLATTEN_MAPPING = dict[str, OBJ_PATH] +STATE_DICT_TYPE = dict[str, Any] +CONTAINER_TYPE = MutableMapping[PATH_ITEM, Any] + + +def _traverse_state_dict( + state_dict: STATE_DICT_TYPE, + visitor: Callable[[OBJ_PATH, Any], None], +) -> None: + """ + Invoke ``visitor`` for each value recursively in ``state_dict``. + Mapping, list, and tuple will be flattened and other value types are treated + as the terminal values and will invoke ``visitor``. + """ + + def _traverse_obj(path: OBJ_PATH, value: Any) -> None: + if isinstance(value, Mapping): + for k, v in value.items(): + _traverse_obj(path + (str(k),), v) + elif isinstance(value, (list, tuple)): + for i, v in enumerate(value): + _traverse_obj(path + (i,), v) + else: + visitor(path, value) + + for key, value in state_dict.items(): + _traverse_obj((str(key),), value) + + +def _flatten_state_dict( + state_dict: STATE_DICT_TYPE, +) -> tuple[STATE_DICT_TYPE, FLATTEN_MAPPING]: + """ + Flatten ``state_dict`` made of nested dicts and lists into a top level dictionary. + + Use ``unflatten_state_dict`` to revert this process. + Returns: + A tuple with the flatten state_dict and a mapping from original to new state_dict. + N.B. The new keys are derived from the object paths, joined by dot. + For example: ``{ 'a': {'b':...}}`` results in the key `a.b`. 
+ """ + flattened: STATE_DICT_TYPE = {} + mappings: FLATTEN_MAPPING = {} + + def flat_copy(path: OBJ_PATH, value: Any) -> None: + new_fqn = ".".join(map(str, path)) + if new_fqn in flattened: + raise ValueError(f"duplicated flatten key {new_fqn}") + flattened[new_fqn] = value + mappings[new_fqn] = path + + _traverse_state_dict(state_dict, flat_copy) + return flattened, mappings + + +def _set_element(root_dict: STATE_DICT_TYPE, path: OBJ_PATH, value: Any) -> None: + """Set ``value`` in ``root_dict`` along the ``path`` object path.""" + cur_container = cast(CONTAINER_TYPE, root_dict) + + def extend_list(lst: list[Any], idx: int) -> None: + while len(lst) <= idx: + lst.append(None) + + for i in range(1, len(path)): + prev_key = path[i - 1] + key = path[i] + def_val: CONTAINER_TYPE | list[Any] = {} if type(key) is str else [] + + if isinstance(cur_container, Mapping): + cur_container = cast( + CONTAINER_TYPE, cur_container.setdefault(prev_key, def_val) + ) + else: + # pyrefly: ignore [bad-argument-type] + extend_list(cur_container, prev_key) + if cur_container[prev_key] is None: + cur_container[prev_key] = def_val + cur_container = cur_container[prev_key] + + key = path[-1] + if type(key) is int: + extend_list(cast(list[Any], cur_container), key) + + cur_container[key] = value + + +def _unflatten_state_dict( + state_dict: STATE_DICT_TYPE, mapping: FLATTEN_MAPPING +) -> STATE_DICT_TYPE: + """Restore the original nested state_dict according to ``mapping`` and the flattened ``state_dict``.""" + nested: STATE_DICT_TYPE = {} + for key, value in state_dict.items(): + _set_element(nested, mapping[key], value) + return nested diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/argparse_util.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/argparse_util.py new file mode 100644 index 0000000000000000000000000000000000000000..c475eebf21273abb53ab99e3edcbdef18e9f0c8f --- /dev/null +++ 
b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/argparse_util.py @@ -0,0 +1,104 @@ +#!/usr/bin/env python3 +# mypy: allow-untyped-defs + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import os +from argparse import Action + + +class env(Action): + """ + Get argument values from ``PET_{dest}`` before defaulting to the given ``default`` value. + + For flags (e.g. ``--standalone``) + use ``check_env`` instead. + + .. note:: when multiple option strings are specified, ``dest`` is + the longest option string (e.g. for ``"-f", "--foo"`` + the env var to set is ``PET_FOO`` not ``PET_F``) + + Example: + :: + + parser.add_argument("-f", "--foo", action=env, default="bar") + + ./program -> args.foo="bar" + ./program -f baz -> args.foo="baz" + ./program --foo baz -> args.foo="baz" + PET_FOO="env_bar" ./program -f baz -> args.foo="baz" + PET_FOO="env_bar" ./program --foo baz -> args.foo="baz" + PET_FOO="env_bar" ./program -> args.foo="env_bar" + + parser.add_argument("-f", "--foo", action=env, required=True) + + ./program -> fails + ./program -f baz -> args.foo="baz" + PET_FOO="env_bar" ./program -> args.foo="env_bar" + PET_FOO="env_bar" ./program -f baz -> args.foo="baz" + """ + + def __init__(self, dest, default=None, required=False, **kwargs) -> None: + env_name = f"PET_{dest.upper()}" + default = os.environ.get(env_name, default) + + # ``required`` means that it NEEDS to be present in the command-line args + # rather than "this option requires a value (either set explicitly or default" + # so if we found default then we don't "require" it to be in the command-line + # so set it to False + if default: + required = False + + super().__init__(dest=dest, default=default, required=required, **kwargs) + + def __call__(self, parser, namespace, values, option_string=None): + 
setattr(namespace, self.dest, values) + + +class check_env(Action): + """ + Check whether the env var ``PET_{dest}`` exists before defaulting to the given ``default`` value. + + Equivalent to + ``store_true`` argparse built-in action except that the argument can + be omitted from the commandline if the env var is present and has a + non-zero value. + + .. note:: it is redundant to pass ``default=True`` for arguments + that use this action because a flag should be ``True`` + when present and ``False`` otherwise. + + Example: + :: + + parser.add_argument("--verbose", action=check_env) + + ./program -> args.verbose=False + ./program --verbose -> args.verbose=True + PET_VERBOSE=1 ./program -> args.verbose=True + PET_VERBOSE=0 ./program -> args.verbose=False + PET_VERBOSE=0 ./program --verbose -> args.verbose=True + + Anti-pattern (don't do this): + + :: + + parser.add_argument("--verbose", action=check_env, default=True) + + ./program -> args.verbose=True + ./program --verbose -> args.verbose=True + PET_VERBOSE=1 ./program -> args.verbose=True + PET_VERBOSE=0 ./program -> args.verbose=False + + """ + + def __init__(self, dest, default=False, **kwargs) -> None: + env_name = f"PET_{dest.upper()}" + default = bool(int(os.environ.get(env_name, "1" if default else "0"))) + super().__init__(dest=dest, const=True, default=default, nargs=0, **kwargs) + + def __call__(self, parser, namespace, values, option_string=None): + setattr(namespace, self.dest, self.const) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/c10d_logger.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/c10d_logger.py new file mode 100644 index 0000000000000000000000000000000000000000..1dfae5b92962f44f4dea3a3393cbcb6ae752999b --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/c10d_logger.py @@ -0,0 +1,100 @@ +#!/usr/bin/env python3 +# mypy: allow-untyped-defs + +# Copyright (c) Facebook, 
Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import functools +import logging +from collections.abc import Callable +from typing import Any, TypeVar +from typing_extensions import ParamSpec + +import torch +import torch.distributed as dist +from torch.distributed.logging_handlers import _log_handlers +from torch.monitor import _WaitCounter + + +__all__: list[str] = [] + +_DEFAULT_DESTINATION = "default" + + +def _get_or_create_logger(destination: str = _DEFAULT_DESTINATION) -> logging.Logger: + logging_handler, log_handler_name = _get_logging_handler(destination) + logger = logging.getLogger(f"c10d-{log_handler_name}") + logger.setLevel(logging.DEBUG) + formatter = logging.Formatter( + "%(asctime)s %(filename)s:%(lineno)s %(levelname)s p:%(processName)s t:%(threadName)s: %(message)s" + ) + logging_handler.setFormatter(formatter) + logger.propagate = False + logger.addHandler(logging_handler) + return logger + + +def _get_logging_handler( + destination: str = _DEFAULT_DESTINATION, +) -> tuple[logging.Handler, str]: + log_handler = _log_handlers[destination] + log_handler_name = f"{type(log_handler).__name__}-{destination}" + return (log_handler, log_handler_name) + + +# pyrefly: ignore [unknown-name] +global _c10d_logger +_c10d_logger = _get_or_create_logger() + + +def _get_msg_dict(func_name, *args, **kwargs) -> dict[str, Any]: + if dist.is_initialized(): + group = kwargs.get("group") or kwargs.get("process_group") + msg_dict = { + "func_name": f"{func_name}", + "pg_name": f"{dist._get_process_group_name(kwargs.get('pg'))}", # type: ignore[arg-type] + "backend": f"{dist.get_backend(group)}", + "world_size": f"{dist.get_world_size()}", + "group_size": f"{dist.get_world_size(group)}", + "global_rank": f"{dist.get_rank()}", + "local_rank": f"{dist.get_rank(group)}", + } + if msg_dict["backend"] == "nccl": + nccl_version = 
torch.cuda.nccl.version() + msg_dict["nccl_version"] = ".".join(str(v) for v in nccl_version) + else: + msg_dict = { + "func_name": f"{func_name}", + } + return msg_dict + + +_T = TypeVar("_T") +_P = ParamSpec("_P") + + +def _exception_logger(func: Callable[_P, _T]) -> Callable[_P, _T]: + @functools.wraps(func) + def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T: + try: + return func(*args, **kwargs) + except Exception as error: + msg_dict = _get_msg_dict(func.__name__, *args, **kwargs) + msg_dict["error"] = f"{error}" + _c10d_logger.debug(msg_dict) + raise + + return wrapper + + +def _time_logger(func: Callable[_P, _T]) -> Callable[_P, _T]: + @functools.wraps(func) + def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T: + with _WaitCounter(f"pytorch.wait_counter.c10d.{func.__name__}").guard(): + func_return = func(*args, **kwargs) + return func_return + + return wrapper diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/collective_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/collective_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..cb20c58f13309152c9e1cebaf38995bcc8b390fb --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/collective_utils.py @@ -0,0 +1,354 @@ +#!/usr/bin/env python3 + + +""" +A set of primitive functions for performing collective ops. + +Each should also handle single rank scenario. 
+""" + +from __future__ import annotations + +import importlib +import logging +from collections import defaultdict +from dataclasses import dataclass +from typing import Any, cast, Generic, TYPE_CHECKING, TypeVar + + +if TYPE_CHECKING: + from collections.abc import Callable, Iterable + +import torch +import torch.distributed as dist + + +__all__: list[str] = [ + "SyncPayload", + "broadcast", + "all_gather", + "all_gather_object_enforce_type", +] + +logger = logging.getLogger(__name__) + +T = TypeVar("T") + + +@dataclass +class SyncPayload(Generic[T]): + stage_name: str | None + success: bool + payload: T + exception: Exception | None = None + + +def broadcast( + data_or_fn: T | Callable[[], T], + *, + success: bool = True, + stage_name: str | None = None, + rank: int = 0, + pg: dist.ProcessGroup | None = None, +) -> T: + """ + Broadcasts the data payload from rank 0 to all other ranks. + Or if a function is passed, execute it in rank 0 and broadcast result to all other ranks. + + Can be used to broadcast a failure signal to stop all ranks. + + If the function raises an exception, all ranks will raise. + + Args: + data_or_fn: the data to broadcast or function to execute and broadcast result. + success: False to stop all ranks. + stage_name: the name of the logical stage for synchronization and debugging + rank: rank to broadcast data or execute function and broadcast results. 
+ pg: the process group for sync + Throws: + RuntimeError from original exception trace + Returns: + the value after synchronization + + Example usage: + >> id = broadcast(data_or_fn=allocate_id, rank=0, pg=ext_pg.my_pg) + """ + + if not success and data_or_fn is not None: + raise AssertionError( + "Data or Function is expected to be None if not successful" + ) + + payload: T | None = None + exception: Exception | None = None + # if no pg is passed then execute if rank is 0 + if (pg is None and rank == 0) or (pg is not None and pg.rank() == rank): + # determine if it is an executable function or data payload only + if callable(data_or_fn): + try: + payload = data_or_fn() + except Exception as e: + success = False + exception = e + else: + payload = data_or_fn + + # broadcast the exception type if any to all ranks for failure categorization + sync_obj = SyncPayload( + stage_name=stage_name, + success=success, + payload=payload, + exception=exception, + ) + + if pg is not None: + broadcast_list = [sync_obj] + dist.broadcast_object_list(broadcast_list, src=rank, group=pg) + if len(broadcast_list) != 1: + raise AssertionError( + f"Expected broadcast_list to have exactly 1 element, got {len(broadcast_list)}" + ) + sync_obj = broadcast_list[0] + + # failure in any rank will trigger a throw in every rank. + if not sync_obj.success: + error_msg = f"Rank {rank} failed" + if stage_name is not None: + error_msg += f": stage {sync_obj.stage_name}" + if sync_obj.exception is not None: + error_msg += f": exception {sync_obj.exception}" + # pyrefly: ignore [invalid-inheritance] + raise RuntimeError(error_msg) from sync_obj.exception + + return cast(T, sync_obj.payload) + + +def all_gather( + data_or_fn: T | Callable[[], T], + stage_name: str | None = None, + pg: dist.ProcessGroup | None = None, +) -> list[T]: + """ + A simple all_gather primitive with basic synchronization guard logic, + by checking payload from all ranks has the same stage name. 
+ + Args: + data_or_fn: the data to be all gathered across ranks or function to be executed + stage_name: the sync stage name for out-of-sync protection + pg: the process group for sync + Throws: + RuntimeError from original exception trace + Returns: + a list of synced data from all ranks + + Example usage: + >> all_ids = all_gather(data_or_fn=allocate_id, pg=ext_pg.my_pg) + """ + payload: T | None = None + exception: Exception | None = None + success = True + # determine if it is an executable function or data payload only + if callable(data_or_fn): + try: + payload = data_or_fn() + except Exception as e: + success = False + exception = e + else: + payload = data_or_fn + + sync_obj = SyncPayload( + stage_name=stage_name, + success=success, + payload=payload, + exception=exception, + ) + + if pg is not None: + # List of success/failure across all ranks. + total_list = [None] * dist.get_world_size(pg) + all_gather_object_enforce_type(pg, total_list, sync_obj) + # Each rank will throw RuntimeError in case of failure on any rank. 
+ stage_name = cast(SyncPayload[T], total_list[0]).stage_name + exception_list: list[tuple[int, Exception]] = [] + ret_list: list[T] = [] + error_msg: str = "" + + for i, sp in enumerate(cast(list[SyncPayload[T]], total_list)): + if sp.stage_name != stage_name: + error_msg += ( + f"Unexpected stage name received from rank {i}: {sp.stage_name} " + ) + continue + if not sp.success and sp.exception is not None: + exception_list.append((i, sp.exception)) + continue + ret_list.append(sp.payload) + + if len(exception_list) > 0: + raise RuntimeError( # type: ignore[misc] + error_msg, + exception_list, + # pyrefly: ignore [invalid-inheritance] + ) from exception_list[0] + return ret_list + else: + if not sync_obj.success: + raise RuntimeError( + f"all_gather failed with exception {sync_obj.exception}", + # pyrefly: ignore [invalid-inheritance] + ) from sync_obj.exception + return [sync_obj.payload] # type: ignore[list-item] + + +# Note: use Any for typing for now so users can pass in +# either a list of None or target type placeholders +# otherwise pyre would complain +def all_gather_object_enforce_type( + pg: dist.ProcessGroup, + # pyre-fixme[2]: Parameter must have a type that does not contain `Any` + object_list: list[Any], + # pyre-fixme[2]: Parameter must have a type other than `Any` + obj: Any, + # pyre-fixme[2]: Parameter must have a type that does not contain `Any` + type_checker: Callable[[Any, Any], bool] = lambda x, y: type(x) is type(y), +) -> None: + """ + Similar to plain all_gather_object but with additional type checking + AFTER gather is done to ensure basic consistency. + If check does not pass, all ranks will fail with exception. + + This is generally to prevent conditional logic leading to + unexpected messages being received. This is considered fatal code error, + but due to logic stacks this might happen implicitly in practice. 
+ + The default check does not check sub type (considered different) + or covariance (considered same) but users can pass in custom checker + if more complicated check is needed. + """ + dist.all_gather_object(object_list, obj, group=pg) + + # conservative check + list_len = len(object_list) + if list_len == 0: + return + first_obj = object_list[0] + for i in range(1, list_len): + if not type_checker(first_obj, object_list[i]): + raise TypeError( + f"Object type at index {i} is {type(object_list[i])}, " + f"while first object type is {type(first_obj)}" + ) + + +def _summarize_ranks(ranks: Iterable[int]) -> str: + ranks = sorted(ranks) + if min(ranks) < 0: + raise AssertionError("ranks should all be positive") + if len(set(ranks)) != len(ranks): + raise AssertionError("ranks should not contain duplicates") + curr: int | range | None = None + ranges = [] + while ranks: + x = ranks.pop(0) + if curr is None: + curr = x + elif isinstance(curr, int): + if x == curr + 1: + curr = range(curr, x + 1, 1) + else: + step = x - curr + curr = range(curr, x + step, step) + else: + if not isinstance(curr, range): + raise AssertionError("curr must be an instance of range") + if x == curr.stop: + curr = range(curr.start, curr.stop + curr.step, curr.step) + else: + ranges.append(curr) + curr = x + + if isinstance(curr, int): + ranges.append(range(curr, curr + 1, 1)) + elif isinstance(curr, range): + ranges.append(curr) + + result = [] + for r in ranges: + if len(r) == 1: + # pyrefly: ignore [bad-argument-type] + result.append(f"{r.start}") + elif r.step == 1: + # pyrefly: ignore [bad-argument-type] + result.append(f"{r.start}:{r.stop}") + else: + # pyrefly: ignore [bad-argument-type] + result.append(f"{r.start}:{r.stop}:{r.step}") + return ",".join(result) + + +def _check_philox_rng_sync( + generator: torch.Generator, group: dist.ProcessGroup +) -> tuple[dict[Any, set], str]: + local_state = generator.get_state() + all_states = [torch.empty_like(local_state) for _ in 
range(group.size())] + torch.distributed.all_gather(all_states, local_state) + seeds_offsets = [ + (state[:8].view(torch.uint64).item(), state[8:].view(torch.uint64).item()) + for state in all_states + ] + seed_offset_ranks = defaultdict(set) + for rank, (seed, offset) in enumerate(seeds_offsets): + seed_offset_ranks[(seed, offset)].add(rank) + return seed_offset_ranks, "(Seed, Offset)" + + +def _check_cpu_rng_sync( + generator: torch.Generator, group: dist.ProcessGroup +) -> tuple[dict[Any, set], str]: + # seed is returned as uint64_t from C impl, so may not fit in torch int64 tensor directly. + state_tensor = generator.get_state() + all_state_tensors = [torch.empty_like(state_tensor) for _ in range(group.size())] + torch.distributed.all_gather(all_state_tensors, state_tensor) + state_ranks = defaultdict(set) + for rank, state_tensor in enumerate(all_state_tensors): + # Summarize the state vector of the CPU rng. + # The properties that matter most are (1) its different if there is a state difference, (2) its printable + # (see desync table- not viable to print whole state vector of size 5k) + state_ranks[torch.hash_tensor(state_tensor).item()].add(rank) + return state_ranks, "Generator state hash" + + +def _check_rng_sync_internal( + generator: torch.Generator, group: dist.ProcessGroup +) -> tuple[dict[Any, set], str]: + if generator.device.type == "cuda": + return _check_philox_rng_sync(generator, group) + elif generator.device.type == "cpu": + return _check_cpu_rng_sync(generator, group) + else: + raise NotImplementedError( + f"Unsupported generator device: {generator.device.type}" + ) + + +def _desync_table_str(tag: str, value_ranks: dict[Any, set[int]]) -> str: + headers = ["Ranks", f"{tag} values"] + rank_values = [ + [_summarize_ranks(ranks), str(value)] for value, ranks in value_ranks.items() + ] + if importlib.util.find_spec("tabulate"): + from tabulate import tabulate + + return tabulate(rank_values, headers=headers) + row_str = "\n".join([str(row) for 
row in rank_values]) + return str(f"{headers}\n{row_str}") + + +def _check_rng_sync(generator: torch.Generator, group: dist.ProcessGroup) -> str | None: + value_ranks, value_header = _check_rng_sync_internal(generator, group) + log_str = None + if len(value_ranks) > 1: + log_str = f"Generator desync detected:\n{_desync_table_str(value_header, value_ranks)}" + logger.error(log_str) + return log_str diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/constants.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..0a077bd6d4e5e5b614c3651e16286c41a814d983 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/constants.py @@ -0,0 +1,25 @@ +from datetime import timedelta + +from torch._C._distributed_c10d import _DEFAULT_PG_TIMEOUT + + +__all__ = ["default_pg_timeout", "default_pg_nccl_timeout"] + +# Default process group wide timeout, if applicable. +# This only applies to the non-nccl backends +# To make an attempt at backwards compatibility with THD, we use an +# extraordinarily high default timeout, given that THD did not have timeouts. +default_pg_timeout: timedelta = _DEFAULT_PG_TIMEOUT +# Separate timeout for PGNCCL mainly because it's always been that way in the C++ layer, but until recently +# there was one default that applied across all backends in the python layer. +# Later, we could consider merging them back together at the c++ layer if we can align on a same value. +# (only if TORCH_NCCL_BLOCKING_WAIT or TORCH_NCCL_ASYNC_ERROR_HANDLING is set to 1). + +try: + from torch._C._distributed_c10d import _DEFAULT_PG_NCCL_TIMEOUT + + default_pg_nccl_timeout: timedelta | None = _DEFAULT_PG_NCCL_TIMEOUT +except ImportError: + # if C++ NCCL support is not compiled, we don't have access to the default nccl value. 
+ # if anyone is actually trying to use nccl in this state, it should error. + default_pg_nccl_timeout = None diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/device_mesh.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/device_mesh.py new file mode 100644 index 0000000000000000000000000000000000000000..95d8fe8b8d2d0a16787baffc6ed43fb0771cc2c0 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/device_mesh.py @@ -0,0 +1,1370 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates +import logging +import os +import threading +import warnings +from collections.abc import Iterator +from itertools import zip_longest +from typing import Optional, TYPE_CHECKING, Union + +import torch +from torch.distributed import is_available +from torch.distributed._mesh_layout import _MeshLayout +from torch.distributed._pycute import IntTuple, is_int, suffix_product +from torch.utils._typing_utils import not_none + + +__all__ = ["init_device_mesh", "DeviceMesh"] + + +if not is_available(): + import sys + + # We need to create the stubs when distributed is not available. + # Otherwise, we would fail the doc tests (```./.ci/pytorch/docs-test.sh```), + # since it would try to import ``torch.distributed.device_mesh`` or + # ``torch.distributed.init_device_mesh`` but cannot find them. 
+ + class _DeviceMeshStub: + pass + + def _init_device_mesh_stub(): + pass + + sys.modules["torch.distributed.device_mesh"].DeviceMesh = _DeviceMeshStub # type: ignore[attr-defined] + sys.modules[ + "torch.distributed.device_mesh" + ].init_device_mesh = _init_device_mesh_stub # type: ignore[attr-defined] + + +else: + from torch._C._distributed_c10d import Backend as C10dBackend + from torch.distributed.distributed_c10d import ( + _get_default_group, + _resolve_process_group, + get_backend, + get_process_group_ranks, + get_rank, + get_world_size, + GroupName, + init_process_group, + is_initialized, + new_group, + ProcessGroup, + split_group, + ) + + logger = logging.getLogger(__name__) + + # only import numpy typing when type checking + if TYPE_CHECKING: + try: + from numpy.typing import ArrayLike + except ImportError: + logger.warning( + "DeviceMesh requires numpy >= 1.21 to be installed for type checking" + ) + + BackendConfig = tuple[str | None, C10dBackend.Options | None] + torch.serialization.add_safe_globals([_MeshLayout]) + + class _MeshEnv(threading.local): + def __init__(self) -> None: + self.mesh_stack: list[DeviceMesh] = [] + + def get_current_mesh(self) -> "DeviceMesh": + if len(self.mesh_stack) == 0: + raise RuntimeError("No device mesh is currently active!") + return self.mesh_stack[-1] + + # TODO: to remove it once we move all use cases into new API. + def get_root_mesh(self, device_mesh: "DeviceMesh") -> "DeviceMesh": + # If a mesh could not be found in the child_to_root_mapping, it is a root mesh itself. + # A root mesh is not created through slicing. + # We considers the root mesh of a root mesh is itself. + # We keep this function for backward compatibility. + warnings.warn( + "This get_root_mesh API will be deprecated soon." 
+ "Please use `get_root_mesh` inside DeviceMesh instead.", + stacklevel=2, + ) + if not device_mesh: + return device_mesh + return device_mesh._get_root_mesh() + + @staticmethod + def num_devices_per_host(device_type: str) -> int: + return _get_device_handle(device_type).device_count() + + @staticmethod + def num_hosts(device_type: str) -> int: + # ProcessGroup can't tell us this info so we have to infer it, assume + # homogeneous hardware for now + return get_world_size() // _MeshEnv.num_devices_per_host(device_type) + + # TODO: to remove it once we move all use cases into new API. + # We keep this API for backward compatibility. + def _get_all_submeshes( + self, device_mesh: "DeviceMesh", mesh_dim_name: str + ) -> list["DeviceMesh"]: + warnings.warn( + "This _get_all_submeshes API will be deprecated soon." + "Please use `_get_all_submeshes` inside DeviceMesh instead.", + stacklevel=2, + ) + return device_mesh._get_all_submeshes(mesh_dim_name) + + _mesh_resources: _MeshEnv = _MeshEnv() + + def _get_device_handle(device_type: str = "cuda"): + """ + Get the module corresponding to the device_type which is cuda or cuda-like device. + For example, when the device_type is cuda, the module `torch.cuda` is returned. + Return None when there is no corresponding module for device_type, otherwise + return the corresponding module. + """ + return getattr(torch, device_type, None) + + class DeviceMesh: + """ + DeviceMesh represents a mesh of devices, where layout of devices could be + represented as a n-d dimension array, and each value of the n-d dimensional + array is the global id of the default process group ranks. + + DeviceMesh could be used to setup the N dimensional device connections across the cluster, + and manage the ProcessGroups for N dimensional parallelisms. Communications could happen on + each dimension of the DeviceMesh separately. DeviceMesh respects the device that user selects + already (i.e. 
if user call `torch.cuda.set_device` before the DeviceMesh initialization), + and will select/set the device for the current process if user does not set the device + beforehand. Note that manual device selection should happen BEFORE the DeviceMesh initialization. + + DeviceMesh can also be used as a context manager when using together with DTensor APIs. + + .. note:: + DeviceMesh follows SPMD programming model, which means the same PyTorch Python program + is running on all processes/ranks in the cluster. Therefore, users need to make sure the + `mesh` array (which describes the layout of devices) should be identical across all ranks. + Inconsistent `mesh` will lead to silent hang. + + Args: + device_type (str): The device type of the mesh. Currently supports: "cpu", "cuda/cuda-like". + mesh (ndarray): A multi-dimensional array or an integer tensor describing the layout + of devices, where the IDs are global IDs of the default process group. + _rank (int): (experimental/internal) + The global rank of the current process. If not provided, it will + be inferred from the default process group. + + Returns: + DeviceMesh: A :class:`DeviceMesh` object representing the device layout. + + The following program runs on each process/rank in an SPMD manner. In this example, we have 2 + hosts with 4 GPUs each. + A reduction over the first dimension of mesh will reduce across + columns (0, 4), .. and (3, 7), a reduction over the second dimension + of mesh reduces across rows (0, 1, 2, 3) and (4, 5, 6, 7). + + Example:: + + >>> # xdoctest: +SKIP("no rank") + >>> from torch.distributed.device_mesh import DeviceMesh + >>> + >>> # Initialize device mesh as (2, 4) to represent the topology + >>> # of cross-host(dim 0), and within-host (dim 1). + >>> mesh = DeviceMesh(device_type="cuda", mesh=[[0, 1, 2, 3],[4, 5, 6, 7]]) + """ + + _device_type: str + _rank_map: torch.Tensor + _mesh_dim_names: tuple[str, ...] 
| None + _layout: _MeshLayout + _root_mesh: Optional["DeviceMesh"] = None + # Record flatten mesh name to its flattened mesh in root mesh. + _flatten_mapping: dict[str, "DeviceMesh"] + + def __init__( + self, + device_type: str, + mesh: Union[torch.Tensor, "ArrayLike"] | None = None, + *, + mesh_dim_names: tuple[str, ...] | None = None, + backend_override: tuple[BackendConfig, ...] | None = None, + _init_backend: bool = True, + _rank: int | None = None, + _layout: _MeshLayout | None = None, + _rank_map: torch.Tensor | None = None, + _root_mesh: Optional["DeviceMesh"] = None, + ) -> None: + # no-op in OSS, logs API usage metrics in meta-internal runs + torch._C._log_api_usage_once( + "torch.distributed.device_mesh.DeviceMesh.__init__" + ) + if mesh is not None: + if _layout is not None or _rank_map is not None: + raise TypeError( + "Cannot provide _layout and/or _rank_map if passing explicit mesh" + ) + if isinstance(mesh, torch.Tensor) and mesh.device.type != "cpu": + raise ValueError(f"`mesh` must be a CPU tensor, got {mesh}") + mesh_tensor = ( + mesh.detach().to(dtype=torch.int).contiguous() + if isinstance(mesh, torch.Tensor) + else torch.tensor(mesh, device="cpu", dtype=torch.int) + ) + _layout = _MeshLayout(mesh_tensor.size(), mesh_tensor.stride()) + _rank_map = mesh_tensor.flatten() + else: + if _layout is None or _rank_map is None: + raise TypeError( + "The mesh argument is required except for PRIVATE USAGE ONLY!" + ) + + assert _layout.check_non_overlap(), ( + "Please use a non-overlapping layout when creating a DeviceMesh." 
+ ) + assert _rank_map.ndim == 1, "The rank map must be 1-dimensional" + assert _rank_map.is_contiguous(), "The rank map must be contiguous" + assert _rank_map.numel() >= _layout.cosize(), ( + f"The rank map contains {_rank_map.numel()} element, " + f"which isn't large enough for layout {_layout}" + ) + + self._device_type = device_type + self._layout = _layout + self._rank_map = _rank_map + self._mesh_dim_names = tuple(mesh_dim_names) if mesh_dim_names else None + self._root_mesh = _root_mesh + + if backend_override is None: + backend_override = ((None, None),) * len(self._layout) + elif len(backend_override) != len(self._layout): + raise ValueError( + f"backend_override should have the same length as the number of mesh dimensions, " + f"but got {len(backend_override)} and {len(self._layout)}." + ) + # Internal bookkeeping for the device mesh. + self._layout = ( + _layout + if _layout + else _MeshLayout(self.mesh.size(), self.mesh.stride()) + ) + if not self._layout.check_non_overlap(): + raise AssertionError( + "Please use a non-overlapping layout when creating a DeviceMesh." + ) + # Because we still need to support slicing of flattened dim from root mesh, so we don't check stride here. + if self._layout.numel() != self.mesh.numel(): + raise AssertionError( + "Please use a valid layout when creating a DeviceMesh." + f"The layout {self._layout} is not consistent with the mesh size {self.mesh.size()}." + ) + + # private field to pre-generate DeviceMesh's hash + self._flatten_rank_map = tuple(self._rank_map.tolist()) + self._thread_id = None + # Initialize instance-specific flatten mapping + self._flatten_mapping = {} + + # Skip process group initialization if xla device or init backend is False + # TODO(yeounoh) implement DeviceMesh backend and register XLA backend. + if device_type != "xla": + # always try to create default (world) pg, even if it is not initialized + # already. 
The world pg is used for device mesh identity (rank) on each + # process (we need to know if the current global rank is in the mesh or not). + if _init_backend: + self._setup_world_group_and_device() + self._dim_group_names = self._init_process_groups( + self._layout, + self._rank_map, + self._mesh_dim_names, + backend_override, + ) + + if is_initialized() and get_backend() == "threaded": + # pyrefly: ignore [bad-assignment] + self._thread_id = threading.get_ident() + + if _rank is None: + _rank = get_rank() + + # calculate the coordinates of the current global rank on the mesh + rank_coords = (self.mesh == _rank).nonzero() + if rank_coords.size(0) not in (0, 1): + raise AssertionError( + f"rank_coords.size(0) must be 0 or 1, got {rank_coords.size(0)}" + ) + self._coordinate_on_dim: list[int] | None = ( + rank_coords[0].tolist() if rank_coords.size(0) > 0 else None + ) + + @property + def device_type(self) -> str: + """Returns the device type of the mesh.""" + return self._device_type + + @property + def mesh(self) -> torch.Tensor: + """Returns the tensor representing the layout of devices.""" + full_mesh = self._layout.remap_to_tensor(self._rank_map) + if full_mesh.size(0) == 1: + return full_mesh[0] + my_coords = (full_mesh == get_rank()).nonzero() + if my_coords.size(0) > 0: + return full_mesh[my_coords[0, 0]] + raise RuntimeError( + "In order to get the mesh Tensor of a DeviceMesh it needs to " + "either have all its original dimensions (e.g., no slicing) " + "or it needs to contain the local rank" + ) + + @property + def mesh_dim_names(self) -> tuple[str, ...] 
| None: + """Returns the names of mesh dimensions.""" + return self._mesh_dim_names + + def _setup_world_group_and_device(self): + default_initialized = is_initialized() + # TODO: think about how to allow pg options to be passed to world group + # or mesh dimension groups + if not default_initialized: + init_process_group() + + world_size = get_world_size() + if self._layout.numel() > world_size: + raise RuntimeError( + f"Mesh should not be bigger than default world size {world_size}, but found {self._layout.numel()} ranks!" + ) + + # ONLY set the device if the current device is not initialized, if user already + # set the device before DeviceMesh init, we respect the user's choice. + device_handle = _get_device_handle(self._device_type) + if device_handle and not device_handle.is_initialized(): + # auto set the cuda/cuda-like device only if user has not set it, if there's LOCAL_RANK + # env variable from launchers, we use it to set the device. + if "LOCAL_RANK" in os.environ: + local_rank = int(os.environ["LOCAL_RANK"]) + logger.info( + "Setting default device for the current process based on LOCAL_RANK=%s", + local_rank, + ) + device_handle.set_device(local_rank) + else: + warnings.warn( + "It seems like you did not set/select the default device for the current process before the DeviceMesh " + "initialization or use a launcher (i.e. torchrun) which populates `LOCAL_RANK` environment variable. " + "It is recommended to set the current device for the process BEFORE the DeviceMesh initialization so that " + "the underlying communicator (i.e. NCCL) can be initialized properly. " + "Given that the current process has no default device selected, DeviceMesh will use a heuristic to set the " + "device_id via `global_rank % num_devices_per_host`, assuming homogeneous hardware cluster. 
", + stacklevel=2, + ) + # heuristic to set the current cuda/cuda-like device base on num of gpu devices available in each host + # NOTE: This device selection would only work for homogeneous hardware. + num_devices_per_host = device_handle.device_count() + if ( + world_size > num_devices_per_host + and world_size % num_devices_per_host != 0 + ): + raise RuntimeError( + f"DeviceMesh only support homogeneous hardware, but found " + f"{world_size} ranks and {num_devices_per_host} {self._device_type} devices!" + ) + device_handle.set_device(get_rank() % num_devices_per_host) + + return _get_default_group() + + @staticmethod + def _init_one_process_group( + sub_layout: _MeshLayout, + rank_map: torch.Tensor, + dim_name: str, + backend_override: BackendConfig, + ) -> GroupName | None: + # Generate a 2D global mesh tensor for the current dim for PG creation. + pg_ranks_by_dim = sub_layout.nest().remap_to_tensor(rank_map) + backend, pg_options = backend_override + # We need to explicitly pass in timeout when specified in option, otherwise + # the default timeout will be used to override the timeout set in option. + # TODO: remove this once we have fixed inside c10d level. + timeout = pg_options._timeout if pg_options else None + + # If we have a 2D mesh with mesh_dim_names ("dp", "tp"), the group description + # of the subgroups would be `mesh_dim_dp` and `mesh_name_tp`. + # If the mesh doesn't have a mesh_dim_names, then the group description of the + # subgroup would be `mesh_dim_0` and `mesh_dim_1`. + group_desc = f"mesh_{dim_name}" + + dim_group = None + default_group = _get_default_group() + + # Early return if there is only one sub_layout in the mesh layout. + if sub_layout.numel() == get_world_size() and backend_override == ( + None, + None, + ): + # Append the default pg to the first dim groups only if the default pg is compatible with `self._device_type`. + # Otherwise, create new pg. 
+ ranks = list(range(get_world_size())) + dim_group = ( + new_group( + backend="cpu:gloo,cuda:nccl", + ranks=ranks, + group_desc="mesh_default", + ) + if torch.cuda.is_available() + and get_backend(default_group) == "gloo" + else default_group + ) + return dim_group.group_name # type: ignore[union-attr] + + # If bound_device_id exists, it means the nccl communicator has been eagerly initialized + # so that we can use `split_group` to create subgroups through `ncclCommSplit`. + # In this case, we only need to make one API call (`split_group``) for the subgroup creation + # for each mesh dimension. In a 2 * 4 mesh, we only need to make two API calls per ranks to create + # all the subgroups. + # Otherwise, we need to make more than one API call (`new_group`) for subgroup creations. The + # numbers of API calls are equal to the number of subgroups for each mesh dimension. In a 2 * 4 + # mesh, we need to make two API calls per ranks to create all the subgroups. + if ( + getattr(default_group, "bound_device_id", None) is not None + and torch.cuda.is_available() + and ( + backend is None + or default_group._get_backend(torch.device("cuda")).name() + == backend + ) + ): + dim_group = split_group( + parent_pg=default_group, + timeout=timeout, + pg_options=pg_options, + split_ranks=pg_ranks_by_dim.tolist(), + group_desc=group_desc, + ) + return dim_group.group_name # type: ignore[union-attr] + + # If the subgroup has been already created through `split_group`, we simply loop over `pg_ranks_by_dim` + # and append the `group_name` to the `dim_group_names` list when the current rank is in the subgroup. + # Otherwise, we use `new_group` instead of `split_group` to create subgroups by looping over `pg_ranks_by_dim` + # along with appending information to the `dim_group_names` list whenever necessary. 
+ pg_name = None + for dim_mesh in pg_ranks_by_dim: + subgroup_ranks = dim_mesh.tolist() + dim_group = new_group( + ranks=subgroup_ranks, + timeout=timeout, + backend=backend, + pg_options=pg_options, + group_desc=group_desc, + ) + + # only add to dim_groups if the current rank in the subgroup + if get_rank() in subgroup_ranks: + if pg_name is not None: + raise RuntimeError( + f"Each device mesh dimension should get only one process group, but got {get_rank()} " + f"in {subgroup_ranks}!" + ) + pg_name = dim_group.group_name + return pg_name + + @staticmethod + def _init_process_groups( + layout: _MeshLayout, + rank_map: torch.Tensor, + mesh_dim_names: tuple[str, ...] | None, + backend_override: tuple[BackendConfig, ...], + ) -> list[GroupName]: + # group_name associated with each mesh dimension, each + # mesh dimension should have one sub-group per rank + dim_group_names: list[GroupName | None] = [] + # create sub pgs base on the mesh argument specified + for dim in range(len(layout)): + dim_name = mesh_dim_names[dim] if mesh_dim_names else f"dim_{dim}" + dim_group_names.append( + DeviceMesh._init_one_process_group( + layout[dim], rank_map, dim_name, backend_override[dim] + ) + ) + # Filter out None values. If any are None then they should all be None. + dim_non_none_group_names = [n for n in dim_group_names if n is not None] + assert not dim_non_none_group_names or len(dim_non_none_group_names) == len( + dim_group_names + ) + return dim_non_none_group_names + + def _get_root_mesh(self) -> "DeviceMesh": + return self._root_mesh if self._root_mesh else self + + def __enter__(self) -> "DeviceMesh": + # set this mesh as the current mesh in mesh env + _mesh_resources.mesh_stack.append(self) + return self + + # pyre-fixme[2]: Parameter must be annotated. 
+ def __exit__(self, exc_type, exc_value, exc_traceback) -> None: + # pop this mesh from mesh env + _mesh_resources.mesh_stack.pop() + + def __repr__(self) -> str: + device_mesh_repr = ( + f"({', '.join(f'{k}={v}' for k, v in zip(self._mesh_dim_names, self._layout.top_level_sizes))})" + if self._mesh_dim_names + else f"{self._layout.top_level_sizes}" + ) + device_mesh_repr = f"DeviceMesh({device_mesh_repr}, '{self.device_type}', stride={self._layout.strides}" + # We only print the mesh tensor if the debug mode is turned on. + if os.environ.get("TORCH_DISTRIBUTED_DEBUG", "") == "DETAIL": + device_mesh_repr += f", Mesh: {self.mesh.tolist()}" + return f"{device_mesh_repr})" + + def __hash__(self): + # lazily compute hash + self._hash = getattr(self, "_hash", None) + if not self._hash: + self._hash = hash( + ( + self._flatten_rank_map, + self._layout, + self._device_type, + self._mesh_dim_names, + self._thread_id, + ) + ) + return self._hash + + def __eq__(self, other: object) -> bool: + if self is other: + return True + if not isinstance(other, DeviceMesh): + return False + return ( + self._flatten_rank_map == other._flatten_rank_map + and self._layout == other._layout + and self._device_type == other._device_type + and self._mesh_dim_names == other._mesh_dim_names + and self._thread_id == other._thread_id + ) + + def __getitem__(self, mesh_dim_names: str | tuple[str, ...]) -> "DeviceMesh": + """ + Slice the current DeviceMesh based on the mesh_dim_names given to create a submesh. + The submesh created consists of the dimensions and the communicators indicated by + ``mesh_dim_names`` + + Args: + mesh_dim_names (Union[str, tuple[str, ...]]): the name or the tuple of names of the + mesh dimension of the DeviceMesh to create the submesh for. + Returns: + A :class:`DeviceMesh` object + + The following program runs on each process/rank in an SPMD manner in a world size of 8. 
+ In the first example: + Calling mesh_2d["tp"] on rank 0, 1, 2, 3 returns a 1D submesh of DeviceMesh:([0, 1, 2, 3]). + Calling mesh_2d["tp"] on rank 4, 5, 6, 7 returns a 1D submesh of DeviceMesh:([4, 5, 6, 7]). + Calling mesh_2d["dp"] on rank 0, 4 returns a 1D submesh of DeviceMesh:([0, 4]). + Calling mesh_2d["dp"] on rank 1, 5 returns a 1D submesh of DeviceMesh:([1, 5]). + Calling mesh_2d["dp"] on rank 2, 6 returns a 1D submesh of DeviceMesh:([2, 6]). + Calling mesh_2d["dp"] on rank 3, 7 returns a 1D submesh of DeviceMesh:([3, 7]). + + In the second example: + Calling mesh_3d["dp", "cp"] on rank 0, 1, 4, 5 returns a 2D submesh of DeviceMesh:([[0, 1], [4, 5]]). + Calling mesh_3d["dp", "cp"] on rank 2, 3, 6, 7 returns a 2D submesh of DeviceMesh:([[2, 3], [6, 7]]). + Calling mesh_3d["cp", "dp"] on rank 0, 1, 4, 5 returns a 2D submesh of DeviceMesh:([[0, 4], [1, 5]]). + Calling mesh_3d["cp", "dp"] on rank 2, 3, 6, 7 returns a 2D submesh of DeviceMesh:([[2, 6], [3, 7]]). + + Example:: + + >>> # xdoctest: +SKIP("no rank") + >>> from torch.distributed.device_mesh import DeviceMesh + >>> + >>> # Initialize a 2D device mesh as (2, 4) to represent the topology + >>> # of cross-host(dim 0), and within-host (dim 1). + >>> mesh_2d = init_device_mesh(device_type="cuda", (2,4), mesh_dim_names=("dp", "tp")) + >>> tp_mesh = mesh_2d["tp"] + >>> dp_mesh = mesh_2d["dp"] + >>> + >>> # Initialize a 3D mesh. + >>> mesh_3d = init_device_mesh(device_type="cuda", (2,2,2), mesh_dim_names=("dp", "pp", "cp")) + >>> # The order of the mesh_dim_names provided deteremines the order of dimensions in the submesh. 
+ >>> dp_cp_mesh = mesh_3d["dp", "cp"] + >>> cp_dp_mesh = mesh_3d["cp", "dp"] + """ + if not self._mesh_dim_names: + raise RuntimeError("Cannot slice a DeviceMesh without mesh_dim_names!") + + mesh_dim_names = ( + (mesh_dim_names,) if isinstance(mesh_dim_names, str) else mesh_dim_names + ) + + if mesh_dim_names == self._mesh_dim_names: + return self + else: + sliced_mesh_layout = self._get_slice_mesh_layout(mesh_dim_names) + # When using FakeTensorMode to trace the model, `_create_sub_mesh()` will + # fail as it will require a real tensor to manipulate. + # `unset_fake_temporarily()` will allow us to materialize the tensors + # within `_create_sub_mesh`, which should not affect modling. + # + # Note that this should be orthogonal to torch.compile(). But whether + # we can compile device_mesh `slicing` (no graph break) is not verified + # yet and need a follow-up, + # TODO: compiler + device_mesh slicing. + with torch._subclasses.fake_tensor.unset_fake_temporarily(): + submesh = self._create_sub_mesh(sliced_mesh_layout, mesh_dim_names) + return submesh + + def get_group(self, mesh_dim: int | str | None = None) -> ProcessGroup: + """ + Returns the single ProcessGroup specified by mesh_dim, or, if mesh_dim is not specified and the + DeviceMesh is 1-dimensional, returns the only ProcessGroup in the mesh. + + Args: + mesh_dim (str/int, optional): it can be the name of the mesh dimension or the index + of the mesh dimension. Default is None. + + Returns: + A :class:`ProcessGroup` object. 
+ """ + if not hasattr(self, "_dim_group_names"): + raise RuntimeError("DeviceMesh process groups not initialized!") + + if len(self._layout) > 1 and mesh_dim is None: + raise RuntimeError( + f"Found the DeviceMesh have {len(self._layout)} dimensions", + "Optional kwarg `mesh_dim` needs to be specified when device_mesh.ndim > 1.", + "If you want to get the list of all the ProcessGroups in the DeviceMesh," + "please use `get_all_groups()` instead.", + ) + + # Quick return if the current device_mesh is a 1D mesh. + if len(self._layout) == 1 and mesh_dim is None: + return not_none(_resolve_process_group(self._dim_group_names[0])) + + root_mesh = self._get_root_mesh() + root_to_flatten_mapping = root_mesh._flatten_mapping + if root_to_flatten_mapping and mesh_dim in root_to_flatten_mapping: + dim_group_name = root_to_flatten_mapping[ + mesh_dim # type: ignore[index] + ]._dim_group_names[0] + return not_none(_resolve_process_group(dim_group_name)) + else: + mesh_dim = ( + self._get_mesh_dim_by_name(mesh_dim) + if isinstance(mesh_dim, str) + else mesh_dim + ) + if not isinstance(mesh_dim, int): + raise AssertionError( + f"mesh_dim must be an int, got {type(mesh_dim)}" + ) + return not_none(_resolve_process_group(self._dim_group_names[mesh_dim])) + + def get_all_groups(self) -> list[ProcessGroup]: + """ + Returns a list of ProcessGroups for all mesh dimensions. + + Returns: + A list of :class:`ProcessGroup` object. 
+ """ + return [self.get_group(i) for i in range(len(self._layout))] + + def _create_sub_mesh( + self, + layout: _MeshLayout, + submesh_dim_names: tuple[str, ...], + ) -> "DeviceMesh": + root_mesh = self._get_root_mesh() + slice_dim_group_name = [] + for name in submesh_dim_names: + if name in not_none(self._mesh_dim_names): + slice_dim_group_name.append( + self._dim_group_names[ # type: ignore[has-type] + not_none(self._mesh_dim_names).index(name) + ] + ) + else: + # If device_mesh is not root_mesh, we already throw error in _get_slice_mesh_layout + # Since we will deprecate the slicing of flattened dim_name from root mesh soon, + # we don't want to optimize the code furthermore. + flatten_mesh = self._flatten_mapping[name] + slice_dim_group_name.append( + flatten_mesh._dim_group_names[ # type: ignore[has-type] + not_none(flatten_mesh._mesh_dim_names).index(name) + ] + ) + res_submesh = DeviceMesh( + self._device_type, + _layout=layout, + _rank_map=root_mesh._rank_map, + mesh_dim_names=submesh_dim_names, + _root_mesh=root_mesh, + _init_backend=False, + ) + res_submesh._dim_group_names = slice_dim_group_name + return res_submesh + + def _create_flatten_mesh( + self, + mesh_dim_name: str | None = None, + backend_override: BackendConfig = (None, None), + ) -> "DeviceMesh": + root_mesh = self._get_root_mesh() + + if not mesh_dim_name: + mesh_dim_name = "_".join(not_none(self._mesh_dim_names)) + + # Flatten a 1D device mesh into its original mesh_dim_name will return itself. + if self.ndim == 1 and mesh_dim_name in not_none(self._mesh_dim_names): + return self + + # Check whether the mesh_dim_name for flattened mesh is valid. + invalid_dim_names = not_none(root_mesh._mesh_dim_names) + if mesh_dim_name in invalid_dim_names: + raise ValueError( + f"{mesh_dim_name} already exists for submesh of the {root_mesh}. ", + f"The mesh_dim_names of submesh and flattened mesh are {invalid_dim_names}. 
" + f"Please specify another valid mesh_dim_name.", + ) + + flattened_mesh_layout = self._layout.coalesce() + if len(flattened_mesh_layout) > 1: + flattened_mesh_layout = flattened_mesh_layout.nest() + # Quick return if the flatten mesh has been created before. + if mesh_dim_name in root_mesh._flatten_mapping: + if ( + flattened_mesh_layout + == root_mesh._flatten_mapping[mesh_dim_name]._layout + ): + return root_mesh._flatten_mapping[mesh_dim_name] + else: + raise ValueError( + f"Flatten mesh with mesh_dim_name {mesh_dim_name} has been created before, " + f"Please specify another valid mesh_dim_name." + ) + + res_flattened_mesh = DeviceMesh( + root_mesh._device_type, + _layout=flattened_mesh_layout, + _rank_map=root_mesh._rank_map, + mesh_dim_names=(mesh_dim_name,), + _root_mesh=root_mesh, + backend_override=(backend_override,), + ) + root_mesh._flatten_mapping[mesh_dim_name] = res_flattened_mesh + + return res_flattened_mesh + + def _get_root_mesh_dim(self) -> int | None: + """ + Returns the index of the mesh dim in the root mesh. + The device_mesh passed in needs to be sliced out from the root mesh + or submesh of the root mesh. 
+ """ + root_mesh = self._get_root_mesh() + child_mesh_dim_names = self._mesh_dim_names + if root_mesh and child_mesh_dim_names: + if len(child_mesh_dim_names) != 1: + raise AssertionError("The submesh can only be a 1D mesh.") + child_mesh_dim_name = child_mesh_dim_names[0] + return root_mesh._get_mesh_dim_by_name(child_mesh_dim_name) + return None + + def _get_mesh_dim_by_name(self, mesh_dim_name: str) -> int: + if self._mesh_dim_names is None or len(self._mesh_dim_names) == 0: + raise KeyError( + "No `mesh_dim_names` found.", + ) + if mesh_dim_name not in self._mesh_dim_names: + raise KeyError( + f"Mesh dimension '{mesh_dim_name}' does not exist.", + f"Available mesh dimensions are: mesh_dim_names={self._mesh_dim_names}", + ) + return not_none(self._mesh_dim_names.index(mesh_dim_name)) + + def _get_slice_mesh_layout( + self, mesh_dim_names: tuple[str, ...] + ) -> _MeshLayout: + """ + Validate whether the mesh_dim_names is valid for slicing the given device_mesh. + If valid, return dim indexes of the slice mesh in the device mesh. + """ + slice_from_root = True + if self != self._get_root_mesh(): + slice_from_root = False + + # The slice mesh_dim_names should consist either the current device_mesh's mesh_dim_names + # or its flattened mesh's mesh_dim_names if it's root_mesh. + flatten_name_to_root_layout = ( + { + key: mesh._layout + for key, mesh in self._get_root_mesh()._flatten_mapping.items() + } + if slice_from_root + else {} + ) + valid_mesh_dim_names = [ + *not_none(self._mesh_dim_names), + *flatten_name_to_root_layout, + ] + + if not all( + mesh_dim_name in valid_mesh_dim_names + for mesh_dim_name in mesh_dim_names + ): + raise KeyError( + f"Invalid mesh_dim_names {mesh_dim_names} specified. " + f"Valid mesh_dim_names are {valid_mesh_dim_names}." 
+ ) + + layout_sliced = [] + for name in mesh_dim_names: + if name in not_none(self._mesh_dim_names): + layout_sliced.append( + self._layout[not_none(self._mesh_dim_names).index(name)] + ) + elif name in flatten_name_to_root_layout: + warnings.warn( + "Slicing a flattened dim from root mesh will be deprecated in PT 2.11. " + "Users need to bookkeep the flattened mesh directly. ", + stacklevel=2, + ) + layout_sliced.append(flatten_name_to_root_layout[name]) + + sliced_sizes = tuple(l.sizes for l in layout_sliced) + sliced_strides = tuple(l.strides for l in layout_sliced) + + # The check below is from DeviceMesh's implementation before adopting CuTe layout for internal + # bookkeeping and it can be removed but we need to define what is the expected behavior. + # TODO: Remove the below check and define the expected behavior. + # Validate the order of the slice mesh dim indices. + # This needs to be in ascending order. + pre_stride = -1 + for stride in reversed(sliced_strides): + # Note that with CuTe layout, we can support slicing flattened non-contiguous mesh dims with no problem. + # But this will make this behavior complicated so we decided to not support it for now. + if not is_int(stride): + raise NotImplementedError( + "Currently, this only allows slicing out a contiguous flattened dim." + ) + if stride < pre_stride: + raise KeyError( + f"Invalid mesh_dim_names {mesh_dim_names} specified. " + "Mesh dim indices should be in ascending order." + ) + pre_stride = stride + + # When users sliced dim_names outside from current mesh, we will check whether + # there is layout overlap. + # TODO: Eventually we will just directly throw error here because + # we will deprecate the slicing of flattened dim_name from root mesh. + layout_sliced = _MeshLayout(sliced_sizes, sliced_strides) + if not layout_sliced.check_non_overlap(): + raise RuntimeError( + f"Slicing overlapping dim_names {mesh_dim_names} is not allowed." 
+ ) + + return layout_sliced + + # TODO: to make this use case by other components public API in the future. + def _get_all_submeshes(self, mesh_dim_name: str) -> list["DeviceMesh"]: + """ + Return all the submeshes of a given mesh dimension of the device mesh. + """ + mesh_dim = self._get_mesh_dim_by_name(mesh_dim_name) + layout = self._layout[mesh_dim] + pg_ranks_by_dim = layout.remap_to_tensor(self._rank_map) + cur_rank = self.get_rank() + res_submeshes = [] + for mesh_1d in pg_ranks_by_dim: + submesh = DeviceMesh( + self._device_type, + mesh_1d, + mesh_dim_names=(mesh_dim_name,), + _init_backend=False, + ) + submesh._dim_group_names = ( # type: ignore[has-type] + [self._dim_group_names[mesh_dim]] # type: ignore[has-type] + if cur_rank in mesh_1d + else [] + ) + res_submeshes.append(submesh) + + return res_submeshes + + @staticmethod + def from_group( + group: ProcessGroup | list[ProcessGroup], + device_type: str, + mesh: Union[torch.Tensor, "ArrayLike"] | None = None, + *, + mesh_dim_names: tuple[str, ...] | None = None, + ) -> "DeviceMesh": + """ + Constructs a :class:`DeviceMesh` with ``device_type`` from an + existing :class:`ProcessGroup` or a list of existing :class:`ProcessGroup`. + + The constructed device mesh has number of dimensions equal to the + number of groups passed. For example, if a single process group is passed in, + the resulted DeviceMesh is a 1D mesh. If a list of 2 process groups is passed in, + the resulted DeviceMesh is a 2D mesh. + + If more than one group is passed, then the ``mesh`` and ``mesh_dim_names`` arguments + are required. The order of the process groups passed in determines the topology of + the mesh. For example, the first process group will be the 0th dimension of the DeviceMesh. + The `mesh` tensor passed in must have the same number of dimensions as the number of process + groups passed in, and the order of the dimensions in the `mesh` tensor must match the order + in the process groups passed in. 
+ + Args: + group (ProcessGroup or list[ProcessGroup]): the existing ProcessGroup + or a list of existing ProcessGroups. + device_type (str): The device type of the mesh. Currently supports: "cpu", + "cuda/cuda-like". Passing in a device type with a GPU index, such as "cuda:0", + is not allowed. + mesh (torch.Tensor or ArrayLike, optional): A multi-dimensional array or an + integer tensor describing the layout of devices, where the IDs are global IDs + of the default process group. Default is None. + mesh_dim_names (tuple[str, ...], optional): A tuple of mesh dimension names to assign + to each dimension of the multi-dimensional array describing the layout of devices. + Its length must match the length of `mesh_shape`. Each string in `mesh_dim_names` + must be unique. Default is None. + + Returns: + DeviceMesh: A :class:`DeviceMesh` object representing the device layout. + """ + + # 1D scenario + if isinstance(group, ProcessGroup): + group_ranks = get_process_group_ranks(group) + if ( + isinstance(mesh, torch.Tensor) and mesh.tolist() != group_ranks + ) or ( + mesh is not None + and not isinstance(mesh, torch.Tensor) + and mesh != group_ranks + ): + raise ValueError( + f"Invalid mesh {str(mesh)} for ProcessGroup with ranks {group_ranks}" + ) + mesh = torch.tensor(group_ranks, device="cpu", dtype=torch.int) + device_mesh = DeviceMesh( + device_type, + mesh, + mesh_dim_names=mesh_dim_names, + _init_backend=False, + ) + device_mesh._dim_group_names = [group.group_name] + return device_mesh + + # nD scenario + groups = list(group) + if len(groups) == 0: + raise ValueError("Expects at least one ProcessGroup to be passed") + if mesh is None: + raise ValueError("Must pass mesh if passing multiple ProcessGroups") + if mesh_dim_names is None: + raise ValueError( + "Must pass mesh_dim_names if passing multiple ProcessGroups" + ) + # When init a DeviceMesh with multiple ProcessGroups directly, we need to make sure + # the mesh tensor is contiguous. 
Otherwise, the layout we inferred from the mesh tensor + # will have larger span than the actual tensor. This is just internal implementation detail + # and does not affect user facing behavior. + mesh = ( + mesh.detach().to(dtype=torch.int, device="cpu") + if isinstance(mesh, torch.Tensor) + else torch.tensor(mesh, device="cpu", dtype=torch.int) + ) + if mesh.ndim != len(groups): + raise ValueError( + "Expects mesh with ndim equal to number of ProcessGroups but got " + f"mesh {mesh.tolist()} and {len(groups)} ProcessGroups" + ) + device_mesh = DeviceMesh( + device_type, mesh, mesh_dim_names=mesh_dim_names, _init_backend=False + ) + device_mesh._dim_group_names = [group.group_name for group in groups] + return device_mesh + + def size(self, mesh_dim: int | None = None) -> int: + if mesh_dim is not None: + return self._layout[mesh_dim].numel() + return self._layout.numel() + + @property + def ndim(self) -> int: + return len(self._layout) + + @property + def shape(self) -> tuple[int, ...]: + return self._layout.top_level_sizes + + def get_rank(self) -> int: + """ + Returns the current global rank. + """ + return get_rank() + + def get_local_rank(self, mesh_dim: int | str | None = None) -> int: + """ + Returns the local rank of the given mesh_dim of the DeviceMesh. + + Args: + mesh_dim (str/int, optional): it can be the name of the mesh dimension or the index + of the mesh dimension. Default is None. + + Returns: + An integer denotes the local rank. + + The following program runs on each process/rank in an SPMD manner. In this example, we have 2 + hosts with 4 GPUs each. + Calling mesh_2d.get_local_rank(mesh_dim=0) on rank 0, 1, 2, 3 would return 0. + Calling mesh_2d.get_local_rank(mesh_dim=0) on rank 4, 5, 6, 7 would return 1. + Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 0, 4 would return 0. + Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 1, 5 would return 1. + Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 2, 6 would return 2. 
+ Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 3, 7 would return 3. + + Example:: + + >>> # xdoctest: +SKIP("no rank") + >>> from torch.distributed.device_mesh import DeviceMesh + >>> + >>> # Initialize device mesh as (2, 4) to represent the topology + >>> # of cross-host(dim 0), and within-host (dim 1). + >>> mesh = DeviceMesh(device_type="cuda", mesh=[[0, 1, 2, 3],[4, 5, 6, 7]]) + """ + if self.ndim > 1 and mesh_dim is None: + raise RuntimeError( + f"Found the DeviceMesh have {len(self._layout)} dimensions", + "Optional kwarg `mesh_dim` needs to be specified when device_mesh.ndim > 1.", + ) + elif mesh_dim is None: + mesh_dim = 0 + + mesh_dim_group = not_none(self.get_group(mesh_dim)) + if not isinstance(mesh_dim_group, ProcessGroup): + raise AssertionError( + "We expect ProcessGroup before calling `get_rank`!" + ) + return not_none(get_rank(mesh_dim_group)) + + def get_coordinate(self) -> list[int] | None: + """ + Return the relative indices of this rank relative to all + dimensions of the mesh. If this rank is not part of the mesh, return None. + """ + return self._coordinate_on_dim if self._coordinate_on_dim else None + + def _flatten( + self, + mesh_dim_name: str | None = None, + backend_override: None + | str + | C10dBackend.Options + | tuple[str, C10dBackend.Options] = None, + ) -> "DeviceMesh": + """ + Returns a 1D DeviceMesh by flattening the current DeviceMesh. + + If no mesh_dim_name is provided, the default is a string concatenating the mesh_dim_names of the + given submesh with each mesh_dim_name separated by "_". For example, if we have a 3D mesh + DeviceMesh([[[0, 1], [2, 3]], [[4, 5], [6, 7]]], mesh_dim_names=("dp", "cp", "tp")), calling + mesh_3d["dp", "cp"]._flatten() will create a 1D submesh DeviceMesh([0, 2, 4, 6], mesh_dim_names=("dp_cp",)) + on rank 0, 2, 4, 6 and a 1D submesh DeviceMesh([1, 3, 5, 7], mesh_dim_names=("dp_cp",)) on rank 1, 3, 5, 7. 
+ + After the flattened dimension is created, to access the flattened dimension in mesh_3d, one can use the + existing slicing method to obtain the flattened mesh through calling mesh_3d["dp_cp"]. + """ + if not self._mesh_dim_names: + raise RuntimeError( + "Cannot flatten a DeviceMesh without mesh_dim_names!" + ) + + if backend_override is not None: + (backend_override_tuple,) = _normalize_backend_override( + {0: backend_override}, 1 + ) + else: + backend_override_tuple = (None, None) + + return self._create_flatten_mesh(mesh_dim_name, backend_override_tuple) + + def _create_unflatten_mesh( + self, + dim: int, + mesh_sizes: tuple[int, ...], + mesh_dim_names: tuple[str, ...], + backend_override: tuple[ + tuple[str | None, C10dBackend.Options | None], ... + ] = ((None, None),), + ) -> "DeviceMesh": + inner_layout = _MeshLayout(tuple(mesh_sizes), suffix_product(mesh_sizes)) + + if inner_layout.numel() != self._layout[dim].numel(): + raise ValueError( + f"The product of {mesh_sizes=} is {inner_layout.numel()}, " + f"but the original dimension at dim={dim} has size {self._layout[dim].numel()}. " + f"These must be equal for unflatten to work correctly." + ) + + partial_layout = self._layout[dim].composition(inner_layout) + unflattened_layout = self._layout.splice(dim, dim + 1, partial_layout) + unflattened_mesh_dim_names = list(not_none(self.mesh_dim_names)) + unflattened_mesh_dim_names[dim : dim + 1] = list(mesh_dim_names) + + root_mesh = self._get_root_mesh() + res_mesh = DeviceMesh( + self.device_type, + _layout=unflattened_layout, + _rank_map=root_mesh._rank_map, + mesh_dim_names=tuple(unflattened_mesh_dim_names), + _root_mesh=root_mesh, + _init_backend=False, + ) + + # If original mesh has initiated its backend, we need to initialize the backend + # of unflatten dims as well. + # TODO: To make backend init more efficient with cute layout representation and support + # per dim backend init. 
+ if hasattr(self, "_dim_group_names"): + dim_group_names = self._dim_group_names.copy() + dim_group_names[dim : dim + 1] = self._init_process_groups( + partial_layout, + root_mesh._rank_map, + mesh_dim_names, + backend_override, + ) + res_mesh._dim_group_names = dim_group_names + + return res_mesh + + def _unflatten( + self, + dim: int | str, + mesh_sizes: tuple[int, ...], + mesh_dim_names: tuple[str, ...], + backend_override: dict[ + str, str | C10dBackend.Options | tuple[str, C10dBackend.Options] + ] + | None = None, + ) -> "DeviceMesh": + """ + Returns a DeviceMesh by unflatten the current DeviceMesh. + + This api can be used to unflatten a N-D DeviceMesh into N-1+len(mesh_sizes)-D meshes or submeshes. + The dim is the dimension to be unflattened which can be either a string or an integer. + + The mesh_sizes is a tuple which specifies the shape of the mesh unflatten into for the given dim. + The mesh_dim_names is a list of strings which specifies the names of the dimensions of the mesh unflatten into. + Its length must match the length of mesh_sizes. + + For example, if we have a 1D mesh DeviceMesh([0, 1, 2, 3, 4, 5, 6, 7], mesh_dim_names=("world")), + calling mesh_1d._unflatten(0, (2, 2, 4), ["dp", "pp", "tp"]) will create a 3D mesh + DeviceMesh([[[0, 1], [2, 3]], [[4, 5], [6, 7]]], mesh_dim_names=("dp", "cp", "tp")). + + Note that after calling the unflatten, there is no access to the unflattened dimension in mesh_1d, one can only + use the newly unflattened mesh to slice out the unflattened mesh dims. + """ + if isinstance(dim, int) and dim >= self.ndim: + raise ValueError( + f"dim {dim} specified in `_unflatten` is out of range {self.ndim}" + ) + elif isinstance(dim, str) and dim in not_none(self.mesh_dim_names): + raise ValueError( + f"dim {dim} specified in `_unflatten` is not in {self.mesh_dim_names}" + ) + + if len(mesh_sizes) != len(mesh_dim_names): + raise RuntimeError( + "mesh_dim_names must have same length as mesh_sizes in _unflatten!" 
+ ) + + if isinstance(dim, str): + dim = not_none(self.mesh_dim_names).index(dim) + + if backend_override is not None: + backend_override_tuple = tuple( + _normalize_backend_override( + backend_override, # type: ignore[arg-type] + len(mesh_sizes), + mesh_dim_names, + ) + ) + else: + backend_override_tuple = ((None, None),) * len(mesh_dim_names) + + return self._create_unflatten_mesh( + dim, + mesh_sizes, + mesh_dim_names, + backend_override_tuple, + ) + + @staticmethod + def _concatenate(device_mesh_list: list["DeviceMesh"]) -> "DeviceMesh": + concat_dim_names: list[str] = [] + concat_sizes: list[IntTuple] = [] + concat_strides: list[IntTuple] = [] + concat_dim_group_name: list[GroupName] = [] + flatten_rank_map = device_mesh_list[0]._flatten_rank_map + for dm in device_mesh_list: + for i in range(len(dm._layout)): + concat_sizes.append(dm._layout[i].sizes) + concat_strides.append(dm._layout[i].strides) + concat_dim_names.extend(not_none(dm.mesh_dim_names)) + concat_dim_group_name.extend(not_none(dm._dim_group_names)) + # Concatenate device mesh having different root mesh tensors are meaningless + # because the concatenated indices should be indexed by the same root mesh tensor. 
+ if dm._flatten_rank_map != flatten_rank_map: + raise RuntimeError( + "Cannot concatenate DeviceMeshes derived from different device meshs" + ) + concat_mesh_layout = _MeshLayout(tuple(concat_sizes), tuple(concat_strides)) + if not concat_mesh_layout.check_non_overlap(): + raise RuntimeError( + f"Cannot concatenate overlapping meshes: {device_mesh_list}" + ) + res_mesh = DeviceMesh( + device_mesh_list[0].device_type, + _layout=concat_mesh_layout, + _rank_map=device_mesh_list[0]._rank_map, + mesh_dim_names=tuple(concat_dim_names), + _root_mesh=device_mesh_list[0]._get_root_mesh(), + _init_backend=False, + ) + res_mesh._dim_group_names = concat_dim_group_name + return res_mesh + + def _normalize_backend_override( + backend_override: dict[ + int | str, + str | C10dBackend.Options | tuple[str, C10dBackend.Options], + ], + ndim: int, + mesh_dim_names: tuple[str, ...] | None = None, + ) -> Iterator[BackendConfig]: + if mesh_dim_names is None: + mesh_dim_names = () + for dim_idx, dim_name in zip_longest(range(ndim), mesh_dim_names): + if dim_name is not None and dim_name in backend_override: + if dim_idx in backend_override: + raise RuntimeError( + f"Found redundant dim index {dim_idx} and " + f"name {dim_name} in backend_override" + ) + val = backend_override.pop(dim_name) + elif dim_idx in backend_override: + val = backend_override.pop(dim_idx) + else: + yield (None, None) + continue + + if isinstance(val, str): + yield (val, None) + elif isinstance(val, C10dBackend.Options): + yield (None, val) + else: + yield val + + if backend_override: + raise RuntimeError( + f"Found invalid keys in backend_override: got {list(backend_override.keys())}, " + f"expected integers in range [0, {ndim}) or one of {mesh_dim_names}" + ) + + def init_device_mesh( + device_type: str, + mesh_shape: tuple[int, ...], + *, + mesh_dim_names: tuple[str, ...] 
| None = None, + backend_override: dict[ + int | str, str | C10dBackend.Options | tuple[str, C10dBackend.Options] + ] + | None = None, + ) -> DeviceMesh: + """ + Initializes a `DeviceMesh` based on `device_type`, `mesh_shape`, and `mesh_dim_names` parameters. + + This creates a DeviceMesh with an n-dimensional array layout, where `n` is the length of `mesh_shape`. + If `mesh_dim_names` is provided, each dimension is labeled as `mesh_dim_names[i]`. + + .. note:: + `init_device_mesh` follows SPMD programming model, meaning the same PyTorch Python program + runs on all processes/ranks in the cluster. Ensure `mesh_shape` (the dimensions of the nD array + describing device layout) is identical across all ranks. Inconsistent `mesh_shape` may lead to hanging. + + .. note:: + If no process group is found, init_device_mesh will initialize distributed process group/groups + required for distributed communications behind the scene. + + Args: + device_type (str): The device type of the mesh. Currently supports: "cpu", "cuda/cuda-like", "xpu". + Passing in a device type with a GPU index, such as "cuda:0", is not allowed. + mesh_shape (Tuple[int]): A tuple defining the dimensions of the multi-dimensional array + describing the layout of devices. + mesh_dim_names (tuple[str, ...], optional): A tuple of mesh dimension names to assign to each dimension + of the multi-dimensional array describing the layout of devices. Its length must match the length + of `mesh_shape`. Each string in `mesh_dim_names` must be unique. + backend_override (Dict[int | str, tuple[str, Options] | str | Options], optional): Overrides for some or all of + the ProcessGroups that will be created for each mesh dimension. Each key can be either the index of a + dimension or its name (if mesh_dim_names is provided). Each value can be a tuple containing the name + of the backend and its options, or just one of these two components (in which case the other will be + set to its default value). 
+ + Returns: + DeviceMesh: A :class:`DeviceMesh` object representing the device layout. + + Example:: + + >>> # xdoctest: +SKIP("no rank") + >>> from torch.distributed.device_mesh import init_device_mesh + >>> + >>> mesh_1d = init_device_mesh("cuda", mesh_shape=(8,)) + >>> mesh_2d = init_device_mesh("cuda", mesh_shape=(2, 8), mesh_dim_names=("dp", "tp")) + + """ + if mesh_dim_names is not None: + if len(set(mesh_dim_names)) != len(mesh_dim_names): + raise RuntimeError( + "Each mesh_dim_name must be unique.", + f"Found repeated mesh_dim_name in mesh_dim_names {mesh_dim_names}", + ) + + if len(mesh_shape) != len(mesh_dim_names): + raise RuntimeError( + "mesh_shape and mesh_dim_names should have same length!", + f"Found len(mesh_dim_names): {len(mesh_dim_names)} and len(mesh_shape):{len(mesh_shape)}.", + ) + + if backend_override is not None: + backend_override_tuple = tuple( + _normalize_backend_override( + backend_override, len(mesh_shape), mesh_dim_names + ) + ) + else: + backend_override_tuple = None + + # assume valid device types are all letters + if device_type and not device_type.isalpha(): + raise RuntimeError( + f"Device type with index is not supported but got {device_type}. ", + "If you maintained a 'torch.device' object, it's recommended to pass in 'device.type'.", + ) + + layout = _MeshLayout(tuple(mesh_shape), suffix_product(tuple(mesh_shape))) + # Always initialize the (identity) rank map on CPU, regardless of what the + # external device type has been set to be (e.g. 
meta) + with torch.device("cpu"): + rank_map = torch.arange(layout.numel(), dtype=torch.int) + device_mesh = DeviceMesh( + device_type=device_type, + _layout=layout, + _rank_map=rank_map, + mesh_dim_names=mesh_dim_names, + backend_override=backend_override_tuple, + ) + + return device_mesh diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/distributed_c10d.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/distributed_c10d.py new file mode 100644 index 0000000000000000000000000000000000000000..6d29e77da50a8c6eaff9af2eda317ff1ce5156d1 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/distributed_c10d.py @@ -0,0 +1,6286 @@ +# mypy: allow-untyped-defs +"""Distributed Collective Communication (c10d).""" + +import collections.abc +import contextlib +import copy +import ctypes +import hashlib +import io +import itertools +import logging +import os +import pickle +import sys +import time +import warnings +from collections import namedtuple +from collections.abc import Callable +from datetime import timedelta +from typing import Any, NewType, TYPE_CHECKING +from typing_extensions import deprecated + +import torch +from torch._C import _DistStoreError as DistStoreError +from torch._C._distributed_c10d import ( + _DistributedBackendOptions, + _register_process_group, + _resolve_process_group, + _unregister_all_process_groups, + _unregister_process_group, + AllgatherOptions, + AllreduceCoalescedOptions, + AllreduceOptions, + AllToAllOptions, + BarrierOptions, + BroadcastOptions, + DebugLevel, + GatherOptions, + get_debug_level, + PrefixStore, + ProcessGroup, + ReduceOp, + ReduceOptions, + ReduceScatterOptions, + ScatterOptions, + Store, + Work, +) +from torch._utils_internal import set_pytorch_distributed_envs_from_justknobs +from torch.monitor import _WaitCounter +from torch.overrides import handle_torch_function, has_torch_function +from 
torch.utils._typing_utils import not_none + +from .c10d_logger import _exception_logger, _time_logger +from .constants import default_pg_nccl_timeout, default_pg_timeout +from .rendezvous import register_rendezvous_handler, rendezvous # noqa: F401 + + +__all__ = [ + "Backend", + "BackendConfig", + "GroupMember", + "P2POp", + "all_gather", + "all_gather_coalesced", + "all_gather_object", + "all_reduce", + "all_reduce_coalesced", + "all_to_all", + "all_to_all_single", + "barrier", + "batch_isend_irecv", + "broadcast", + "send_object_list", + "recv_object_list", + "broadcast_object_list", + "destroy_process_group", + "gather", + "gather_object", + "get_backend_config", + "get_backend", + "get_default_backend_for_device", + "get_rank", + "get_world_size", + "get_pg_count", + "group", + "init_process_group", + "irecv", + "is_gloo_available", + "is_initialized", + "is_mpi_available", + "is_backend_available", + "is_nccl_available", + "is_torchelastic_launched", + "is_ucc_available", + "is_xccl_available", + "isend", + "monitored_barrier", + "new_group", + "new_subgroups", + "new_subgroups_by_enumeration", + "recv", + "reduce", + "reduce_scatter", + "scatter", + "scatter_object_list", + "send", + "supports_complex", + "AllreduceCoalescedOptions", + "AllreduceOptions", + "AllToAllOptions", + "BarrierOptions", + "BroadcastOptions", + "GatherOptions", + "GroupName", + "PrefixStore", + "ProcessGroup", + "ReduceOp", + "ReduceOptions", + "ReduceScatterOptions", + "ScatterOptions", + "Store", + "DebugLevel", + "get_debug_level", + "Work", + "default_pg_timeout", + "get_group_rank", + "get_global_rank", + "get_process_group_ranks", + "reduce_op", + "all_gather_into_tensor", + "reduce_scatter_tensor", + "get_node_local_rank", + "split_group", + "shrink_group", +] + +_MPI_AVAILABLE = True +_NCCL_AVAILABLE = True +_GLOO_AVAILABLE = True +_UCC_AVAILABLE = True +_XCCL_AVAILABLE = True + +_pickler = pickle.Pickler +_unpickler = pickle.Unpickler + +GroupName = NewType("GroupName", str) 
+ + +# Change __module__ of all imported types from torch._C._distributed_c10d that are public +def _export_c_types() -> None: + _public_types_to_change_module = [ + AllreduceCoalescedOptions, + AllreduceOptions, + AllToAllOptions, + BarrierOptions, + BroadcastOptions, + GatherOptions, + PrefixStore, + ProcessGroup, + ReduceOp, + ReduceOptions, + ReduceScatterOptions, + ScatterOptions, + Store, + DebugLevel, + get_debug_level, + Work, + ] + for type in _public_types_to_change_module: + type.__module__ = "torch.distributed.distributed_c10d" + + +_export_c_types() + +try: + from torch._C._distributed_c10d import ProcessGroupMPI + + ProcessGroupMPI.__module__ = "torch.distributed.distributed_c10d" + __all__ += ["ProcessGroupMPI"] +except ImportError: + _MPI_AVAILABLE = False + +try: + from torch._C._distributed_c10d import ProcessGroupNCCL + + ProcessGroupNCCL.__module__ = "torch.distributed.distributed_c10d" + __all__ += ["ProcessGroupNCCL"] +except ImportError: + _NCCL_AVAILABLE = False + +try: + from torch._C._distributed_c10d import _ProcessGroupWrapper, ProcessGroupGloo + + ProcessGroupGloo.__module__ = "torch.distributed.distributed_c10d" + __all__ += ["ProcessGroupGloo"] +except ImportError: + _GLOO_AVAILABLE = False + +try: + from torch._C._distributed_c10d import ProcessGroupUCC + + ProcessGroupUCC.__module__ = "torch.distributed.distributed_c10d" + __all__ += ["ProcessGroupUCC"] +except ImportError: + _UCC_AVAILABLE = False + +try: + from torch._C._distributed_c10d import ProcessGroupXCCL + + ProcessGroupXCCL.__module__ = "torch.distributed.distributed_c10d" + __all__ += ["ProcessGroupXCCL"] +except ImportError: + _XCCL_AVAILABLE = False + +logger = logging.getLogger(__name__) + +PG_WRAPPER_STORE_PREFIX = "pg_wrapper" + + +# Some reduce ops are not supported by complex numbers and will result in an error. 
+# We currently provide complex support to the distributed API by viewing +# complex tensors as real (torch.view_as_real), meaning that calling +# these unsupported ops will return garbage values rather than error out. +# (e.g. max(2+3i, 3+2i) = 3+3i) +# We'd like calls to unsupported ops to error out accordingly, +# rather than returning garbage values. +def supports_complex(reduceOp: ReduceOp) -> bool: + """Return true if reduce ops is supported. False otherwise.""" + denyList = [ + ReduceOp.MAX, + ReduceOp.MIN, + ReduceOp.PRODUCT, + ReduceOp.BAND, + ReduceOp.BOR, + ReduceOp.BXOR, + ] + return reduceOp not in denyList + + +# TODO refactor into enum/strenum +class Backend(str): # noqa: SLOT000 + """ + An enum-like class for backends. + + Available backends: GLOO, NCCL, UCC, MPI, XCCL, and other registered backends. + + The values of this class are lowercase strings, e.g., ``"gloo"``. They can + be accessed as attributes, e.g., ``Backend.NCCL``. + + This class can be directly called to parse the string, e.g., + ``Backend(backend_str)`` will check if ``backend_str`` is valid, and + return the parsed lowercase string if so. It also accepts uppercase strings, + e.g., ``Backend("GLOO")`` returns ``"gloo"``. + + .. note:: The entry ``Backend.UNDEFINED`` is present but only used as + initial value of some fields. Users should neither use it directly + nor assume its existence. 
+ """ + + UNDEFINED = "undefined" + GLOO = "gloo" + NCCL = "nccl" + UCC = "ucc" + MPI = "mpi" + XCCL = "xccl" + + _BackendPlugin = namedtuple("_BackendPlugin", ["creator_fn", "extended_api"]) + + _plugins: dict[str, _BackendPlugin] = {} + + backend_list = [UNDEFINED, GLOO, NCCL, XCCL, UCC, MPI] + + # 3rd-party devices can register the default backend support here + default_device_backend_map: dict[str, str] = { + "cpu": GLOO, + "cuda": NCCL, + "xpu": XCCL, + "mps": GLOO, + } + + backend_capability: dict[str, list[str]] = { + GLOO: ["cpu", "cuda"], + NCCL: ["cuda"], + XCCL: ["xpu"], + UCC: ["cpu", "cuda"], + MPI: ["cpu", "cuda"], + } + + backend_type_map: dict[str, ProcessGroup.BackendType] = { + UNDEFINED: ProcessGroup.BackendType.UNDEFINED, + GLOO: ProcessGroup.BackendType.GLOO, + NCCL: ProcessGroup.BackendType.NCCL, + XCCL: ProcessGroup.BackendType.XCCL, + UCC: ProcessGroup.BackendType.UCC, + MPI: ProcessGroup.BackendType.MPI, + } + + def __new__(cls, name: str): + """Create and return a new instance of the class.""" + if not isinstance(name, str): + raise ValueError("Backend constructor parameter must be string-ish") + value = getattr(Backend, name.upper(), Backend.UNDEFINED) + + if value == Backend.UNDEFINED: + value = name.lower() + return value + + @classmethod + def register_backend( + cls, + name, + func, + extended_api: bool = False, + devices: str | list[str] | None = None, + ) -> None: + """ + Register a new backend with the given name and instantiating function. + + This class method is used by 3rd party ``ProcessGroup`` extension to + register new backends. + + Args: + name (str): Backend name of the ``ProcessGroup`` extension. It + should match the one in ``init_process_group()``. + func (function): Function handler that instantiates the backend. + The function should be implemented in the backend + extension and takes four arguments, including + ``store``, ``rank``, ``world_size``, and ``timeout``. 
+ extended_api (bool, optional): Whether the backend supports extended argument structure. + Default: ``False``. If set to ``True``, the backend + will get an instance of ``c10d::DistributedBackendOptions``, and + a process group options object as defined by the backend implementation. + device (str or list of str, optional): device type this backend + supports, e.g. "cpu", "cuda", etc. If `None`, + assuming both "cpu" and "cuda" + + .. note:: This support of 3rd party backend is experimental and subject to change. + + """ + # This takes care of CUSTOM Out-of-tree backend types, update in backend_list indicates availability + if not hasattr(Backend, name.upper()): + setattr(Backend, name.upper(), name.lower()) + if name.lower() not in Backend.backend_list: + Backend.backend_list.append(name.lower()) + + if devices is not None: + for device in devices: + if device not in Backend.default_device_backend_map: + Backend.default_device_backend_map[device] = name.lower() + Backend.backend_type_map[name.lower()] = ProcessGroup.BackendType.CUSTOM + + # Update device capability matrix in Backend class + if devices is None: + # This is more of a backward support for groups like `threaded`: + # assume default devices "cpu" and "cuda", but warn + warnings.warn( + f"Device capability of {name} unspecified, assuming `cpu` and " + "`cuda` or `xpu`. Please specify it via the `devices` argument of " + "`register_backend`.", + stacklevel=2, + ) + Backend.backend_capability[name.lower()] = ( + ["cpu", "cuda", "xpu"] if torch.xpu.is_available() else ["cpu", "cuda"] + ) + elif isinstance(devices, str): + # Single device string specified. Simply convert to list. 
+ Backend.backend_capability[name.lower()] = [devices] + else: + Backend.backend_capability[name.lower()] = devices + + Backend._plugins[name.upper()] = Backend._BackendPlugin(func, extended_api) + + +class BackendConfig: + """Backend configuration class.""" + + def __init__(self, backend: Backend): + """Init.""" + self.device_backend_map: dict[str, Backend] = {} + # pyrefly: ignore [bad-assignment] + backend = str(backend) + + if backend == Backend.UNDEFINED: + # Detect the accelerator on the machine. If no accelerator is + # available, it returns CPU. + device_type = torch._C._get_accelerator().type + try: + backend_str = Backend.default_device_backend_map[device_type] + self.device_backend_map[device_type] = Backend(backend_str) + except KeyError: + raise ValueError( + f"We detected accelerator {device_type} on your machine. " + f"But we don't know which communication backend to use for this accelerator. " + f"Please specify the `backend` argument in the `init_process_group` call." + ) from None + elif backend.lower() in Backend.backend_list: + # Cases for when backend is a single string (without device types) + # e.g. "nccl", "gloo", "ucc", "mpi" + supported_devices = Backend.backend_capability[backend.lower()] + backend_val = Backend(backend) + + self.device_backend_map = dict.fromkeys(supported_devices, backend_val) + elif ":" in backend.lower(): + # Backend specified in "device:backend" format + # make sure the backend string is in the correct format + # "{device_type1}:{backend1},{device_type2}:{backend2}" + # e.g. "cpu:gloo,cuda:nccl" + backend_str_error_message = f"""The custom backend string argument is invalid: {backend}. + Custom backend string is an experimental feature where the backend string must be in the format: + ":,:...". e.g. 
'cpu:gloo,cuda:nccl'""" + + # parse the backend string and populate the device_backend_map + for device_backend_pair_str in backend.lower().split(","): + device_backend_pair = device_backend_pair_str.split(":") + if len(device_backend_pair) != 2: + raise ValueError( + f"Invalid device:backend pairing: \ + {device_backend_pair_str}. {backend_str_error_message}" + ) + # pyrefly: ignore [bad-assignment] + device, backend = device_backend_pair + if device in self.device_backend_map: + raise ValueError( + f"Duplicate device type {device} \ + in backend string: {backend}. {backend_str_error_message}" + ) + self.device_backend_map[device] = Backend(backend) + else: + # User specified a single backend name whose device capability is + # unknown, assuming it can support the default devices of PyTorch + # (cpu and cuda) + warnings.warn( + f"Device capability of {backend} unknown, assuming `cpu` and " + "`cuda`. You can specify it in `device:backend` format in " + "`init_process_group` call.", + stacklevel=2, + ) + backend_val = Backend(backend) + self.device_backend_map = { + "cpu": backend_val, + "cuda": backend_val, + "xpu": backend_val, + } + + logger.info("Using backend config: %s", self.device_backend_map) + + def __repr__(self): + """Return all the device:backend pairs separated by commas.""" + return ",".join( + f"{device}:{backend}" for device, backend in self.device_backend_map.items() + ) + + def get_device_backend_map(self) -> dict[str, Backend]: + """Return backend map of the device.""" + return self.device_backend_map + + +class _reduce_op: + r""" + Deprecated enum-like class. + + For reduction operations: ``SUM``, ``PRODUCT``, ``MIN``, and ``MAX``. + + :class:`~torch.distributed.ReduceOp` is recommended to use instead. 
+ """ + + def __init__(self) -> None: + # __members__ is a dict storing key-value pairs for enum classes + for k, v in ReduceOp.RedOpType.__members__.items(): + setattr(self, k, v) + self.__members__ = ReduceOp.RedOpType.__members__ + + @deprecated( + "`torch.distributed.reduce_op` is deprecated, " + "please use `torch.distributed.ReduceOp` instead", + category=FutureWarning, + ) + def __getattribute__(self, key): + return object.__getattribute__(self, key) + + +reduce_op = _reduce_op() + + +class P2POp: + """ + A class to build point-to-point operations for ``batch_isend_irecv``. + + This class builds the type of P2P operation, communication buffer, peer rank, + Process Group, and tag. Instances of this class will be passed to + ``batch_isend_irecv`` for point-to-point communications. + + Args: + op (Callable): A function to send data to or receive data from a peer process. + The type of ``op`` is either ``torch.distributed.isend`` or + ``torch.distributed.irecv``. + tensor (Tensor): Tensor to send or receive. + peer (int, optional): Destination or source rank. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + tag (int, optional): Tag to match send with recv. + group_peer (int, optional): Destination or source rank. 
+ """ + + def __init__( + self, + op: Callable, + tensor: torch.Tensor, + peer: int | None = None, + group: ProcessGroup | None = None, + tag: int = 0, + group_peer: int | None = None, + ): + """Init.""" + self.op = op + self.tensor = tensor + self.group = _group_or_default_group(group) + self.peer = _canonicalize_group_rank( + self.group, peer, group_peer, return_global=True + ) + self.tag = tag + self.group_peer = _canonicalize_group_rank(self.group, peer, group_peer) + + def __new__( + cls, + op: Callable, + tensor: torch.Tensor, + peer: int | None = None, + group: ProcessGroup | None = None, + tag: int = 0, + group_peer: int | None = None, + ): + """Create and return a new instance of the class.""" + _check_op(op) + _check_single_tensor(tensor, "tensor") + + return object.__new__(cls) + + def __repr__(self): + my_group_rank = get_rank(self.group) + op_name = self.op.__name__ + group_name = self.group.group_name if self.group else "default_pg" + if "send" in op_name: + s = my_group_rank + d = self.group_peer + elif "recv" in op_name: + s = self.group_peer + d = my_group_rank + else: + return super().__repr__() + + return f"P2POp({op_name} pg={group_name}, group_src={s}, group_dst={d}, {self.tensor.shape}, {self.tensor.dtype})" + + +class _CollOp: + """ + A class to capture collective operations. + + Args: + op (Callable): A collective function, e.g. ``torch.distributed.all_reduce``. + tensor (Tensor): Tensor to operate on. + dst_tensor (Tensor, optional): Provided when source and destination tensors are not the same. + redop (ReduceOp, optional): reduce operation. + root (int, optional): root of broadcast or reduce. + """ + + def __init__( + self, + op: Callable, + tensor: torch.Tensor, + dst_tensor: torch.Tensor | None = None, + redop: ReduceOp | None = None, + root: int | None = None, + ): + self.op = op + self.tensor = tensor + self.dst_tensor = dst_tensor + self.redop = redop + self.root = root + + +# DO NOT USE THESE FIELDS DIRECTLY. 
+# Use them through the _world object to make sure the _world override mechanism +_pg_map: dict[ProcessGroup, tuple[str, Store]] = {} +_pg_names: dict[ProcessGroup, GroupName] = {} +_pg_group_ranks: dict[ProcessGroup, dict[int, int]] = {} +# For a pg, it is a map from ProcessGroup to BackendConfig +_pg_backend_config: dict[ProcessGroup, str] = {} +_group_count = 0 +_tags_to_pg: dict[str, list[ProcessGroup]] = {} +_pg_to_tag: dict[ProcessGroup, str] = {} +_backend: str | None = None + + +class _World: + """ + Container class for c10d process group state. + + This is used during registration and lookup of PG state. + + .. warning:: This is an experimental API intended to expose the inner workings + of c10d and is subject to change.. + """ + + def __init__(self) -> None: + self._default_pg = None + self._pg_coalesce_state: dict[ProcessGroup, list[_CollOp]] = {} + + @property + def default_pg(self) -> ProcessGroup | None: + """ + Process group that includes all ranks of the cluster. + + This default ProcessGroup is used by c10d APIs when a ProcessGroup is needed + but None is provided. + """ + return self._default_pg + + @default_pg.setter + def default_pg(self, value) -> None: + self._default_pg = value + + @property + def pg_map(self) -> dict[ProcessGroup, tuple[str, Store]]: + """ + Provide Mapping from ProcessGroup to backend name and store. + + For NCCL and GLOO pg, it is a map from ProcessGroup to (Backend, Store) + For MPI pg, it is a map from ProcessGroup to (Backend, None) + + TODO don't expose the map, expose fine grained ops + """ + global _pg_map + return _pg_map + + @property + def pg_names(self) -> dict[ProcessGroup, GroupName]: + """ + Process group's names, map from ProcessGroup to str. + + TODO don't expose the map, expose fine grained ops + """ + global _pg_names + return _pg_names + + @property + def pg_group_ranks(self) -> dict[ProcessGroup, dict[int, int]]: + """ + Process group's global rank to local rank mapping. 
+ + TODO don't expose the map, expose fine grained ops + """ + global _pg_group_ranks + return _pg_group_ranks + + @property + def pg_backend_config(self) -> dict[ProcessGroup, str]: + """ + Process group's backend config. + + TODO don't expose the map, expose fine grained ops + """ + global _pg_backend_config + return _pg_backend_config + + @property + def group_count(self) -> int: + """ + Process group count for default naming. + + TODO don't expose group_count, use something else instead + """ + global _group_count + return _group_count + + @group_count.setter + def group_count(self, value: int) -> None: + """Use to compute the name of ProcessGroups when using global synchronization.""" + global _group_count + _group_count = value + + @property + def tags_to_pg(self) -> dict[str, list[ProcessGroup]]: + global _tags_to_pg + return _tags_to_pg + + @property + def pg_to_tag(self) -> dict[ProcessGroup, str]: + global _pg_to_tag + return _pg_to_tag + + @property + def pg_coalesce_state(self) -> dict[ProcessGroup, list[_CollOp]]: + return self._pg_coalesce_state + + @property + def pg_config_info(self) -> list[dict[str, Any]]: + """ + Return a list of dict with process groups and backends. + + Along with their unique IDs and configurations (types and ranks). + """ + config_info: list[dict[str, Any]] = [] + default_pg_size = _get_group_size(None) + for pg in self.pg_map: + ranks = self.pg_group_ranks[pg] + config_info.append( + { + "pg_name": self.pg_names[pg], + "pg_desc": pg.group_desc, + "backend_config": self.pg_backend_config[pg], + "ranks": ( + list(ranks.keys()) if len(ranks) != default_pg_size else [] + ), # 'ranks' is an empty list when all ranks are involved in a pg + "group_size": len(ranks), + "group_count": self.group_count, + } + ) + return config_info + + +_world = _World() +"""Holds the singleton instance of ``_World`` used by c10. 
Experimental extension point to override it""" + + +class _WorldMeta(type): + """ + Meta class of ``group`` and ``GroupMember``. + + Allows them to have the class property ``WORLD``. + """ + + # Points to the default PG once initialized. + @property + def WORLD(cls) -> ProcessGroup | None: + return _world.default_pg + + @WORLD.setter + def WORLD(cls, pg: ProcessGroup | None): + _world.default_pg = pg + + +class group(metaclass=_WorldMeta): + """Group class. Placeholder.""" + + +class GroupMember(metaclass=_WorldMeta): + """Group member class.""" + + NON_GROUP_MEMBER = -100 + + +def _get_default_timeout(backend: Backend) -> timedelta: + # see note on nccl vs other backend timeout (constants.py) + if backend == Backend.NCCL: + if not isinstance(default_pg_nccl_timeout, timedelta): + # TODO moco benchmark on CPU initializes pgnccl backend today, triggered this assert in CI before it was + # changed to be a warning. We should fix the moco model. + warnings.warn( + "Attempted to get default timeout for nccl backend, but NCCL support is not compiled", + stacklevel=2, + ) + return default_pg_timeout + return default_pg_nccl_timeout + else: + return default_pg_timeout + + +def _check_valid_timeout(timeout: Any) -> None: + if not isinstance(timeout, timedelta): + raise TypeError( + f"Expected timeout argument to be of type datetime.timedelta, got {timeout}" + ) + + +# Default process group state +_default_pg_init_method: str | None = None + +STORE_BASED_BARRIER_PREFIX = "store_based_barrier_key" + + +def _get_object_coll_device(group: ProcessGroup | None = None) -> str: + """ + .. note:: This is an internal helper and does not have backward + compatibility, please use with caution. + + Return the device type to use with ``group`` for object collectives or + barrier. + + There are selection rules: + 1. If user specifies exactly one backend in ``init_process_group`` call: + use that backend + 2. 
Else if user specifies multiple "device:backend" pairs in init_process_group: + If "cpu" is among those pairs, use "cpu" (because the object is in cpu memory); + Otherwise, use the first backend (sort of a random pick). + + Args: + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + + Returns: + str: The device type to use for object collective with ``group``. + + """ + group = group or _get_default_group() + + if not isinstance(group, ProcessGroup): + warnings.warn( + f"You are using a Backend {type(group)} as a ProcessGroup. " + "This usage is deprecated since PyTorch 2.0. Please use a public API " + "of PyTorch Distributed instead.", + stacklevel=2, + ) + # Provide backward compatibility to cases where `group` passed in is + # actually a Backend (like `ProcessGroupGloo`) rather than a + # `ProcessGroup` in PT 2.0 sense + if isinstance(group, ProcessGroupGloo): + # RPC uses Gloo for object collectives + return "cpu" + else: + raise ValueError(f"Expecting a ProcessGroup, but got a {type(group)}.") + + """ + ``group._device_types`` is a property pybind that returns the devices + ("cpu", "cuda", etc) supported by ``group``. Can be multiple if the + ``group`` supports multiple devices. + """ + devices = group._device_types + + if len(devices) == 1: + # User fixed exactly one backend in `init_process_group` + return devices[0].type + elif len(devices) == 0: + # No backend has been registered with this PG (maybe because no + # collective has been run?) We pick cpu as the default and hopefully + # this would lazily init Gloo or other available cpu backend. + return "cpu" + elif torch.device("cpu") in devices: + # There are multiple backends in this PG and cpu is among them. + # cpu is preferred as the object is in cpu memory. No need for device + # copy. + return "cpu" + else: + # No cpu in the backend list. 
Randomly pick the first backend + return devices[0].type + + +def _get_pg_default_device(group: ProcessGroup | None = None) -> torch.device: + """ + .. note:: This method will be deprecated, it only stays for + backward-compatiblity reason. Alternatives: + + - If you need to find a device for object collectives, please use + `_get_object_coll_device(group)`. + + - If you need to query the device types supported by group, please use + `_device_capability(group)`. + + Return the device type registered with ``group``. + + For example, if `init_process_group("nccl", ...)` was called, the returned + value would be `torch.device("cuda")`. + + Errors out if no device has been registered. + + Args: + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + + Returns: + torch.device: The device type registered with ``group``. + """ + + warnings.warn( + "`_get_pg_default_device` will be deprecated, it only stays for " + "backward-compatiblity reason. If you need to find a device for object " + "collectives, please use `_get_object_coll_device`. If you need to query " + "the device types supported by group, please use " + "`_device_capability(group)`. ", + stacklevel=2, + ) + group = group or _get_default_group() + + if not isinstance(group, ProcessGroup): + # Provide backward compatibility to cases where `group` passed in is + # actually a Backend (like `ProcessGroupGloo`) rather than a + # `ProcessGroup` in PT 2.0 sense + warnings.warn( + f"You are using a Backend {type(group)} as a ProcessGroup. " + "This usage is deprecated since PyTorch 2.0. Please use a public API " + "of PyTorch Distributed instead.", + FutureWarning, + stacklevel=3, + ) + # Most users create Gloo with private API for object collectives + return torch.device("cpu") + + """ + ``group._device_types`` is a property pybind that returns the devices + ("cpu", "cuda", etc) supported by ``group``. 
Can be multiple if the + ``group`` supports multiple devices. + """ + devices = group._device_types + + if len(devices) == 1: + # User fixed exactly one backend in `init_process_group` + return devices[0] + elif len(devices) == 0: + raise RuntimeError( + "Default device not found, because no backend has been registered " + "with this ProcessGroup." + ) + else: + # There are multiple backends in this PG. + if torch.device("cpu") in devices: + rv = torch.device("cpu") + else: + rv = devices[0] + warnings.warn( + "Multiple backends are registered with this ProcessGroup. We cannot " + f"determine which one is the default. Returning {rv}. " + "Please consider using other APIs.", + stacklevel=2, + ) + return rv + + +def _device_capability(group: ProcessGroup | None = None) -> list[str]: + """ + Return the device type(s) supported by ``group``. + + Args: + group (ProcessGroup, optional): The process group to query. If None, + the default process group will be used. + + Returns: + List[str]: A list of device types supported by ``group``. + """ + group = group or _get_default_group() + return [device.type for device in group._device_types] + + +@_time_logger +def _store_based_barrier( + rank, + store, + group_name: GroupName, + rendezvous_count, + timeout, + logging_interval=timedelta(seconds=10), +) -> None: + """ + Store based barrier for synchronizing processes. + + Barrier based on store which is used for synchronizing processes after + ``init_process_group`` or ``new_group``. Intended to be used only with + those two methods and is not a generic alternative to ``barrier()``. + """ + store_key = f"{STORE_BASED_BARRIER_PREFIX}:{group_name}" + store.add(store_key, 1) + logger.debug("Added key: %s to store for rank: %s", store_key, rank) + + # Now wait for all workers to check in with the store. 
+ world_size = rendezvous_count + worker_count = store.add(store_key, 0) + + last_worker_key = f"{store_key}:last_worker" + if worker_count == world_size: + store.set(last_worker_key, "1") + + # adjust the timeout to be at least 10secs + 1sec per thousand ranks to reduce the odds of timeout + # this value was empirically found while scale testing. + logging_interval = max(logging_interval, timedelta(seconds=10 + world_size / 1000)) + + start = time.time() + while True: + try: + # This will throw an exception after the logging_interval in which we print out + # the status of the group or time out officially, throwing runtime error + store.wait([last_worker_key], logging_interval) + break + except RuntimeError as e: + worker_count = store.add(store_key, 0) + # Print status periodically to keep track. + logger.debug( # noqa: G200 + "Waiting in store based barrier to initialize process group for %s seconds" + "rank: %s, key: %s (world_size=%s, num_workers_joined=%s, timeout=%s error=%s)", + time.time() - start, + rank, + store_key, + world_size, + worker_count, + timeout, + e, + ) + + if timedelta(seconds=(time.time() - start)) > timeout: + raise DistStoreError( # noqa: B904 + "Timed out initializing process group in store based barrier on " + f"rank {rank}, for key: {store_key} (world_size={world_size}, " + f"num_workers_joined={worker_count}, timeout={timeout} error={e})" + ) + + logger.info( + "Rank %s: Completed store-based barrier for key:%s with %s nodes.", + rank, + store_key, + world_size, + ) + + +def _rank_not_in_group(group: ProcessGroup | None) -> bool: + """Check if the current process's rank is not in a given group.""" + if group is None: + return False + return group == GroupMember.NON_GROUP_MEMBER + + +def _warn_not_in_group(op_name) -> None: + global_rank = -1 if GroupMember.WORLD is None else GroupMember.WORLD.rank() + warnings.warn( + f"Running {op_name} on global rank {global_rank} which does not " + "belong to the given group.", + stacklevel=2, + ) 
+ + +def get_group_rank(group: ProcessGroup, global_rank: int) -> int: + """ + Translate a global rank into a group rank. + + ``global_rank`` must be part of ``group`` otherwise this raises RuntimeError. + + Args: + group (ProcessGroup): ProcessGroup to find the relative rank. + global_rank (int): Global rank to query. + + Returns: + Group rank of ``global_rank`` relative to ``group`` + + N.B. calling this function on the default process group returns identity + """ + if group is GroupMember.WORLD: + return global_rank + if group not in _world.pg_group_ranks: + raise ValueError( + f"Group {group} is not registered, please create group with torch.distributed.new_group API" + ) + group_ranks = _world.pg_group_ranks[group] + if global_rank not in group_ranks: + raise ValueError(f"Global rank {global_rank} is not part of group {group}") + + return group_ranks[global_rank] + + +def get_global_rank(group: ProcessGroup, group_rank: int) -> int: + """ + Translate a group rank into a global rank. + + ``group_rank`` must be part of `group` otherwise this raises RuntimeError. + + Args: + group (ProcessGroup): ProcessGroup to find the global rank from. + group_rank (int): Group rank to query. + + Returns: + Global rank of ``group_rank`` relative to ``group`` + + N.B. calling this function on the default process group returns identity + """ + if group is GroupMember.WORLD: + return group_rank + if group not in _world.pg_group_ranks: + raise ValueError( + f"Group {group} is not registered, please create group with torch.distributed.new_group API" + ) + for rank, grp_rank in _world.pg_group_ranks[group].items(): + if grp_rank == group_rank: + return rank + raise ValueError(f"Group rank {group_rank} is not part of group {group}") + + +# TODO: remove this once the ecosystem moves away from it. 
+@deprecated( + "`torch.distributed.distributed_c10d._get_global_rank` is deprecated, " + "please use `torch.distributed.distributed_c10d.get_global_rank` instead", + category=FutureWarning, +) +def _get_global_rank(group, rank) -> int: + """Use get_global_rank as this method is deprecated.""" + return get_global_rank(group, rank) + + +def get_process_group_ranks(group: ProcessGroup | None) -> list[int]: + """ + Get all ranks associated with ``group``. + + Args: + group (Optional[ProcessGroup]): ProcessGroup to get all ranks from. + If None, the default process group will be used. + + Returns: + List of global ranks ordered by group rank. + """ + return list(_world.pg_group_ranks[group or _get_default_group()].keys()) + + +def _get_group_size(group: ProcessGroup | None) -> int: + """Get a given group's world size.""" + if group is GroupMember.WORLD or group is None: + default_pg = _get_default_group() + return default_pg.size() + return group.size() + + +def _get_group_size_by_name(group_name: GroupName) -> int: + group = _resolve_process_group(group_name) + return group.size() + + +def _resolve_group_name_by_ranks_and_tag(ranks: list[int], tag: str) -> GroupName: + # TODO(yifu): remove this function once ranks + tag is not a supported + # identifier for process group for functional collectives. + group = _find_pg_by_ranks_and_tag(tag, ranks) + if group is None: + raise ValueError("") + return group.group_name + + +def _check_single_tensor(param, param_name: str) -> None: + """Check that the parameter ``param_name`` is a single tensor.""" + if not isinstance(param, torch.Tensor): + raise TypeError( + f"""Invalid function argument. Expected parameter `{param_name}` of type torch.Tensor + but got {type(param)} instead.""" + ) + + +def _check_tensor_list(param, param_name: str) -> None: + """Check that the parameter ``param_name`` is a list of tensors.""" + if not isinstance(param, list): + raise TypeError( + f"""Invalid function argument. 
Expected parameter `{param_name}` of type List[torch.Tensor] + but got {type(param)} instead.""" + ) + elif not all(isinstance(p, torch.Tensor) for p in param): + raise TypeError( + f"""Invalid function argument. Expected parameter `{param_name}` of type List[torch.Tensor] + but got {type(param)} with elements of type {[type(p) for p in param]}.""" + ) + + +def _group_or_default_group(group: ProcessGroup | None = None) -> ProcessGroup: + if group is None or group is GroupMember.WORLD: + group = _get_default_group() + return group + + +def _canonicalize_group_rank( + group: ProcessGroup, + global_rank: int | None = None, + group_rank: int | None = None, + return_global: bool = False, +) -> int: + """ + Helper method to take _either_ a global rank or a group rank and produce a group rank. + + If 'return_global' is true, produce a global rank instead of a group rank. + """ + + if group_rank is not None: + if global_rank is not None: + raise ValueError("Can't specify both group_rank and global_rank") + if return_global: + return get_global_rank(group, group_rank) + else: + if global_rank is None: + raise ValueError("Must specify global_rank or group_rank") + if return_global: + return global_rank + group_rank = get_group_rank(group, global_rank) + return group_rank + + +def _check_not_self_rank(group: ProcessGroup, rank: int, rank_type: str): + if group.rank() == rank: + raise ValueError( + f"Invalid {rank_type} rank: {rank_type} rank should not be the same as " + "the rank of the current process." 
+ ) + + +def _as_iterable(obj) -> collections.abc.Iterable: + return obj if isinstance(obj, list) else (obj,) + + +def _ensure_all_tensors_same_dtype(*tensors) -> None: + last_dtype = None + # pyrefly: ignore [bad-assignment] + for tensor in itertools.chain.from_iterable(map(_as_iterable, tensors)): + tensor_dtype = tensor.dtype + # Mixing complex and its element type is allowed + if tensor_dtype.is_complex: + tensor_dtype = ( + torch.float32 if tensor_dtype == torch.complex64 else torch.complex128 + ) + + if last_dtype is None: + last_dtype = tensor_dtype + else: + if last_dtype != tensor_dtype: + raise ValueError( + "Invalid usage of tensors with different dtypes" + f"Found {last_dtype} and {tensor.dtype}" + ) + + +def _check_op(op) -> None: + """Check that the ``op`` is either isend or irecv.""" + if op not in [isend, irecv]: + raise ValueError( + "Invalid ``op``. Expected ``op`` " + "to be of type ``torch.distributed.isend`` or " + "``torch.distributed.irecv``." + ) + + +def _check_p2p_op_list(p2p_op_list) -> None: + """ + Check that the ``p2p_op_list`` is a list of P2POp instances. + + Also, check that all ops use the same group. + """ + if not isinstance(p2p_op_list, list) or not all( + isinstance(p2p_op, P2POp) for p2p_op in p2p_op_list + ): + raise ValueError( + "Invalid ``p2p_op_list``. Each op is expected to " + "to be of type ``torch.distributed.P2POp``." 
+ ) + + group = p2p_op_list[0].group + if not all(group == p2p_op.group for p2p_op in p2p_op_list): + raise ValueError("All ops need to use the same group.") + + +def is_mpi_available() -> bool: + """Check if the MPI backend is available.""" + return _MPI_AVAILABLE + + +def is_nccl_available() -> bool: + """Check if the NCCL backend is available.""" + return _NCCL_AVAILABLE + + +def is_gloo_available() -> bool: + """Check if the Gloo backend is available.""" + return _GLOO_AVAILABLE + + +def is_ucc_available() -> bool: + """Check if the UCC backend is available.""" + return _UCC_AVAILABLE + + +def is_xccl_available() -> bool: + """Check if the XCCL backend is available.""" + return _XCCL_AVAILABLE + + +def _check_single_backend_availability(backend_name: str) -> bool: + """ + Helper function to check if a single backend is available. + """ + available_func = getattr( + torch.distributed, f"is_{str(backend_name).lower()}_available", None + ) + if available_func: + return available_func() + return str(backend_name).lower() in Backend.backend_list + + +def is_backend_available(backend: str) -> bool: + """ + Check backend availability. + + Checks if the given backend is available and supports the built-in backends or + third-party backends through function ``Backend.register_backend``. + + Args: + backend (str): Backend name. + Returns: + bool: Returns true if the backend is available otherwise false. 
+ """ + # If the backend has an ``is_backend_available`` function, return the result of that function directly + if ":" in backend.lower(): # composite backend like "cpu:gloo" + backend_config = BackendConfig(Backend(backend)) + device_backend_map = backend_config.get_device_backend_map() + return all( + _check_single_backend_availability(str(backend_name)) + for backend_name in device_backend_map.values() + ) + else: + # Handle simple backend strings like "nccl", "gloo" + return _check_single_backend_availability(backend) + + +def is_initialized() -> bool: + """Check if the default process group has been initialized.""" + return GroupMember.WORLD is not None + + +def is_torchelastic_launched() -> bool: + """ + Check whether this process was launched with ``torch.distributed.elastic`` (aka torchelastic). + + The existence of ``TORCHELASTIC_RUN_ID`` environment + variable is used as a proxy to determine whether the current process + was launched with torchelastic. This is a reasonable proxy since + ``TORCHELASTIC_RUN_ID`` maps to the rendezvous id which is always a + non-null value indicating the job id for peer discovery purposes.. + """ + return os.getenv("TORCHELASTIC_RUN_ID") is not None + + +def _is_barrier_after_init() -> int: + # Environment variable to control whether process group should perform a + # barrier after its init. Default value is 0, i.e. no barrier. If you + # experience issue with this setting, you may set + # `TORCH_DIST_INIT_BARRIER=1` to add the barrier. + return int(os.getenv("TORCH_DIST_INIT_BARRIER", "0")) + + +def _get_default_group() -> ProcessGroup: + """Get the default process group created by init_process_group.""" + if not is_initialized(): + raise ValueError( + "Default process group has not been initialized, " + "please make sure to call init_process_group." 
+ ) + if TYPE_CHECKING: + return not_none(GroupMember.WORLD) + else: + return GroupMember.WORLD + + +def _get_default_store() -> Store: + """Get the default store created by init_process_group.""" + if not is_initialized(): + raise ValueError( + "Default process group has not been initialized, " + "please make sure to call init_process_group." + ) + default_pg = _get_default_group() + _, default_store = _world.pg_map[default_pg] + return default_store + + +def _update_default_pg(pg: ProcessGroup | None) -> None: + _world.default_pg = pg + rank = pg.rank() if pg is not None and pg != GroupMember.NON_GROUP_MEMBER else -1 + torch._C._distributed_c10d._set_global_rank(rank) + + +def get_backend_config(group: ProcessGroup | None = None) -> str: + """ + Return the backend configuration of the given process group. + + Args: + group (ProcessGroup, optional): The process group to work on. The + default is the general main process group. If another specific group + is specified, the calling process must be part of :attr:`group`. + + Returns: + The backend configuration of the given process group as a lower case string. + + """ + pg = group or _get_default_group() + if _rank_not_in_group(pg): + raise ValueError("Invalid process group specified") + backend_config = _world.pg_backend_config.get(pg) + return str(not_none(backend_config)) + + +def get_backend(group: ProcessGroup | None = None) -> Backend: + """ + Return the backend of the given process group. + + Args: + group (ProcessGroup, optional): The process group to work on. The + default is the general main process group. If another specific group + is specified, the calling process must be part of :attr:`group`. + + Returns: + The backend of the given process group as a lower case string. 
+ + """ + pg = group or _get_default_group() + if _rank_not_in_group(pg): + raise ValueError("Invalid process group specified") + + pg_store = _world.pg_map.get(pg, None) + if pg_store is None: + raise ValueError( + f"Process group {pg} is not initialized in the world group map. Please initialize the group first." + ) + + return Backend(not_none(pg_store)[0]) + + +def get_default_backend_for_device(device: str | torch.device) -> str: + """ + Return the default backend for the given device. + + Args: + device (Union[str, torch.device]): The device to get the default backend for. + + Returns: + The default backend for the given device as a lower case string. + + """ + if isinstance(device, torch.device): + device_str = device.type + else: + device_str = torch.device(device).type + + backend = Backend.default_device_backend_map.get(device_str) + if backend is None: + raise ValueError(f"Default backend not registered for device : {device}") + + return backend + + +def _get_process_group_uid(pg: ProcessGroup) -> int: + backend = None + try: + backend = pg._get_backend(torch.device("cuda")) + except RuntimeError: + pass + if is_nccl_available() and isinstance(backend, ProcessGroupNCCL): + return backend.uid + return -1 + + +def _get_pg_config(group: ProcessGroup | None = None) -> dict[str, Any]: + """ + Return the pg configuration of the given process group. + + """ + pg = group or _get_default_group() + return { + "pg_name": _get_process_group_name(pg), + "pg_desc": pg.group_desc, + "backend_config": get_backend_config(pg), + "pg_size": _get_group_size(pg), + "ranks": get_process_group_ranks(pg), + } + + +def _get_all_pg_configs() -> list[dict[str, Any]]: + """ + Return the pg configuration of all the process groups. + + """ + config_info: list[dict[str, Any]] = [_get_pg_config(pg) for pg in _world.pg_map] + return config_info + + +def get_pg_count() -> int: + """ + Return the number of process groups. 
+ + """ + return _world.group_count + + +def get_node_local_rank(fallback_rank: int | None = None) -> int: + """ + Return the local rank of the current process relative to the node. + + Semantically, this is a useful concept for mapping processes to devices. + For example, on a node with 8 accelerator you could use the node local rank to decide + which accelerator device to bind the process to. + + In practice, the actual assignment of node local ranks is handled by the process launcher outside of pytorch, + and communicated via the `LOCAL_RANK` environment variable. + + Torchrun will automatically populate `LOCAL_RANK`, but other launchers may not. If `LOCAL_RANK` is unspecified, + this API will fall back to the provided kwarg 'fallback_rank' if specified, otherwise it will raise an error. The + intent is to allow writing an application that runs either in single or multi device contexts without error. + + """ + if "LOCAL_RANK" in os.environ: + return int(os.environ["LOCAL_RANK"]) + elif fallback_rank is not None: + return int(fallback_rank) + raise RuntimeError( + "LOCAL_RANK is not in the environment. Consider passing fallback_rank to allow `get_node_local_rank` to work, " + "assuming you are not running in a multi-device context and want the code to run locally instead." + ) + + +def _add_ephemeral_timeout_for_all_pgs(timeout: timedelta) -> None: + """ + This API adds an ephemeral timeout extension for all PGs locally + on one rank. The timeout gets reset when the first collective issued + after API called finished. + NOTE: We only support to set timeout for cuda backends for now. + NOTE: While this feature + provides flexibility in specific scenarios, it introduces statefulness + to timeout setting. Therefore, it is advisable to use this API sparingly + and consider alternative approaches, such as directly setting the timeout + or utilizing a barrier collective (one can set any timeout to the barrier), + whenever feasible. 
+ + Args: + timeout (timedelta): The delta of timeout to extend. + + Returns: + None. + """ + for pg in _world.pg_map: + devices = pg._device_types + if torch.device("cuda") in devices: + backend = pg._get_backend(torch.device("cuda")) + if is_nccl_available() and isinstance(backend, ProcessGroupNCCL): + backend._add_ephemeral_timeout(timeout) + + +def _set_pg_timeout(timeout: timedelta, group: ProcessGroup | None = None) -> None: + """ + Set the timeout for the given process group when users want to use a different timeout instead of + default values. + + Args: + timeout (timedelta): Timeout for operations executed against the process group which + users want to set. Default value is 10 minutes for NCCL and 30 minutes for other backends. + This is the duration after which collectives will be aborted asynchronously and the process will crash. + This is done since CUDA execution is async and it is no longer safe to continue executing user code since + failed async NCCL operations might result in subsequent CUDA operations running on corrupted data. + When TORCH_NCCL_BLOCKING_WAIT is set, the process will block and wait for this timeout. + + group (ProcessGroup, optional): The process group to work on. The + default is the general main process group. If another specific group + is specified, the calling process must be part of :attr:`group`. 
+ + Returns: + None + """ + if group is None: + group = _get_default_group() + if _rank_not_in_group(group): + raise ValueError("Invalid process group specified") + if not isinstance(group, ProcessGroup): + raise AssertionError(f"Expected ProcessGroup, got {type(group)}") + devices = group._device_types + backends = set() + if torch.device("cpu") in devices and is_gloo_available(): + backend = group._get_backend(torch.device("cpu")) + if isinstance(backend, ProcessGroupGloo): + backends.add(backend) + if torch.device("cuda") in devices: + backend = group._get_backend(torch.device("cuda")) + if is_nccl_available() and isinstance(backend, ProcessGroupNCCL): + backends.add(backend) # type: ignore[arg-type] + elif is_gloo_available() and isinstance(backend, ProcessGroupGloo): + backends.add(backend) # type: ignore[arg-type] + if len(backends) == 0: + warnings.warn( + "Set timeout is now only supported for either nccl or gloo.", stacklevel=2 + ) + for backend in backends: + backend._set_default_timeout(timeout) + + +@_exception_logger +@_time_logger +def init_process_group( + backend: str | None = None, + init_method: str | None = None, + timeout: timedelta | None = None, + world_size: int = -1, + rank: int = -1, + store: Store | None = None, + group_name: str = "", + pg_options: Any | None = None, + device_id: torch.device | int | None = None, + _ranks: list[int] | None = None, +) -> None: + """ + Initialize the default distributed process group. + + This will also initialize the distributed package. + + There are 2 main ways to initialize a process group: + 1. Specify ``store``, ``rank``, and ``world_size`` explicitly. + 2. Specify ``init_method`` (a URL string) which indicates where/how + to discover peers. Optionally specify ``rank`` and ``world_size``, + or encode all required parameters in the URL and omit them. + + If neither is specified, ``init_method`` is assumed to be "env://". + + + Args: + backend (str or Backend, optional): The backend to use. 
Depending on + build-time configurations, valid values include ``mpi``, ``gloo``, + ``nccl``, ``ucc``, ``xccl`` or one that is registered by a third-party + plugin. + Since 2.6, if ``backend`` is not provided, c10d will use a backend + registered for the device type indicated by the `device_id` kwarg + (if provided). The known default registrations today are: ``nccl`` + for ``cuda``, ``gloo`` for ``cpu``, ``xccl`` for ``xpu``. + If neither ``backend`` nor ``device_id`` is provided, c10d will + detect the accelerator on the run-time machine and use a backend + registered for that detected accelerator (or ``cpu``). + This field can be given as a lowercase string (e.g., ``"gloo"``), + which can also be accessed via :class:`Backend` attributes (e.g., + ``Backend.GLOO``). + If using multiple processes per machine with ``nccl`` backend, each + process must have exclusive access to every GPU it uses, as sharing + GPUs between processes can result in deadlock or NCCL invalid usage. + ``ucc`` backend is experimental. + Default backend for the device can be queried with + :func:`get_default_backend_for_device`. + init_method (str, optional): URL specifying how to initialize the + process group. Default is "env://" if no + ``init_method`` or ``store`` is specified. + Mutually exclusive with ``store``. + world_size (int, optional): Number of processes participating in + the job. Required if ``store`` is specified. + rank (int, optional): Rank of the current process (it should be a + number between 0 and ``world_size``-1). + Required if ``store`` is specified. + store(Store, optional): Key/value store accessible to all workers, used + to exchange connection/address information. + Mutually exclusive with ``init_method``. + timeout (timedelta, optional): Timeout for operations executed against + the process group. Default value is 10 minutes for NCCL and 30 minutes for other backends. 
+ This is the duration after which collectives will be aborted asynchronously and the process will crash. + This is done since CUDA execution is async and it is no longer safe to continue executing user code since + failed async NCCL operations might result in subsequent CUDA operations running on corrupted data. + When TORCH_NCCL_BLOCKING_WAIT is set, the process will block and wait for this timeout. + + group_name (str, optional, deprecated): Group name. This argument is ignored + pg_options (ProcessGroupOptions, optional): process group options + specifying what additional options need to be passed in during + the construction of specific process groups. As of now, the only + options we support is ``ProcessGroupNCCL.Options`` for the ``nccl`` + backend, ``is_high_priority_stream`` can be specified so that + the nccl backend can pick up high priority cuda streams when + there're compute kernels waiting. For other available options to config nccl, + See https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/types.html#ncclconfig-t + device_id (torch.device | int, optional): a single, specific device + this process will work on, allowing for backend-specific + optimizations. Currently this has two effects, only under + NCCL: the communicator is immediately formed (calling + ``ncclCommInit*`` immediately rather than the normal lazy + call) and sub-groups will use ``ncclCommSplit`` when + possible to avoid unnecessary overhead of group creation. If you + want to know NCCL initialization error early, you can also use this + field. If an `int` is provided, the API assumes that the accelerator + type at compile time will be used. + _ranks: The ranks in the process group. If provided, the process + group name will be the hash of all the ranks in the group. + + .. note:: To enable ``backend == Backend.MPI``, PyTorch needs to be built from source + on a system that supports MPI. + + .. note:: Support for multiple backends is experimental. 
Currently when no backend is + specified, both ``gloo`` and ``nccl`` backends will be created. The ``gloo`` backend + will be used for collectives with CPU tensors and the ``nccl`` backend will be used + for collectives with CUDA tensors. A custom backend can be specified by passing in + a string with format ":,:", e.g. + "cpu:gloo,cuda:custom_backend". + + """ + + global _world + + global _backend + global _default_pg_init_method + + if GroupMember.WORLD is not None: + raise ValueError("trying to initialize the default process group twice!") + + set_pytorch_distributed_envs_from_justknobs() + + # Depending on the import order, some trace_rules functions may be evaluated + # during the import phase. In such a case, these functions may not correctly + # add the distributed related rules due to import circular dependency. + # We need to clear the lru_cache during the runtime to ensure the correctness + # of these trace_rules. + # + # Since this API must be called before all distributed code being compiled, + # clearing the cache here should be safe. + if "torch._dynamo" in sys.modules: + torch._dynamo.trace_rules.clear_lru_cache() + + if not ((store is None) or (init_method is None)): + raise AssertionError("Cannot specify both init_method and store.") + + if store is not None: + if not world_size > 0: + raise AssertionError("world_size must be positive if using store") + if not rank >= 0: + raise AssertionError("rank must be non-negative if using store") + elif init_method is None: + init_method = "env://" + + # Get the compile-time accelerator type. + # None indicates no accelerator support. + acc = torch.accelerator.current_accelerator() + + # Auto complete device id + if isinstance(device_id, int): + if acc is None: + raise ValueError( + "device_id is an int, but no accelerator support is found from the current compilation. " + "Please use a different compiled version that supports your accelerator." 
+ ) + device_id = torch.device(acc.type, device_id) + + # Sanity check device_id + if device_id is not None and device_id.type != "cpu": + # Type + if acc is None or device_id.type != acc.type: + raise ValueError( + f"device_id {device_id} does not match the current compilation's accelerator support: {acc}. " + "Please use a different compiled version that supports your accelerator." + ) + # Index + if device_id.index is None: + raise ValueError("Please use a device_id with index.") + # Range + if device_id.index >= torch.accelerator.device_count(): + raise ValueError( + f"device_id {device_id} is out of range. Please use a device index less than " + f"the number of accelerators available: {torch.accelerator.device_count()}." + ) + + logger.info("Using device: %s", device_id) + + # If user did not provide a backend string but provided a device id, e.g. + # >>> init_process_group(device_id=device) + # we try to figure out the backend name based on the device type. + if backend is None and device_id is not None: + # Note: 3rd-party devices can register default backend through the + # default map below. + backend = Backend.default_device_backend_map.get(device_id.type) + + # If we still cannot figure it out, e.g. + # >>> init_process_group() + # we set it to `undefined` and rely on lazy init. + if backend is None: + backend = "undefined" + + # Convert string into `Backend` type + backend = Backend(backend) + + if timeout is None: + timeout = _get_default_timeout(backend) + + _check_valid_timeout(timeout) + + """ + Group name is not visible to users unless they access + internals of c10d. This means we can ignore the value + they provide as it not exposed in a public way. 
+ """ + if _ranks is None or len(_ranks) == 0: + group_name = _process_group_name([], use_hashed_name=False) + else: + group_name = _process_group_name(_ranks, use_hashed_name=True) + if backend == Backend.MPI: + if world_size != -1 or rank != -1: + warnings.warn( + f"For MPI backend, world_size ({world_size}) and rank ({rank}) " + "are ignored since they are assigned by the " + "MPI runtime.", + stacklevel=2, + ) + + default_pg, _ = _new_process_group_helper( + -1, + -1, + [], + backend, + Store(), # Placeholder value since store cannot be None + group_name, + timeout=timeout, + group_desc="default_pg", + ) + else: + # backward compatible API + if store is None: + if backend == "fake": + from torch.testing._internal.distributed.fake_pg import FakeStore + + store = FakeStore() + else: + rendezvous_iterator = rendezvous( + not_none(init_method), rank, world_size, timeout=timeout + ) + store, rank, world_size = next(rendezvous_iterator) + store.set_timeout(timeout) + + # Use a PrefixStore to avoid accidental overrides of keys used by + # different systems (e.g. RPC) in case the store is multi-tenant. 
+ store = PrefixStore("default_pg", store) + + default_pg, _ = _new_process_group_helper( + world_size, + rank, + [], + backend, + store, + group_name, + backend_options=pg_options, + timeout=timeout, + device_id=device_id, + group_desc="default_pg", + ) + + _update_default_pg(default_pg) + + _world.pg_group_ranks[GroupMember.WORLD] = { # type: ignore[index] + i: i + for i in range(GroupMember.WORLD.size()) # type: ignore[attr-defined] + } + _backend = _world.pg_map[not_none(GroupMember.WORLD)][0] + _default_pg_init_method = init_method + + old_hook = sys.excepthook + excepthook_prefix = f"[rank{get_rank()}]" + + def _distributed_excepthook(*args): + old_stderr = sys.stderr + sys.stderr = buf = io.StringIO() + try: + old_hook(*args) + finally: + sys.stderr = old_stderr + msg = buf.getvalue() + msg = "\n".join( + f"{excepthook_prefix}: {s}" if s != "" else "" for s in msg.split("\n") + ) + sys.stderr.write(msg) + sys.stderr.flush() + + sys.excepthook = _distributed_excepthook + + if _is_barrier_after_init() == 1: + # barrier at the end to ensure that once we return from this method, all + # process groups including global variables (if any) are updated + # correctly on all ranks. + # Update 04/2023: for large-scale runs, this barrier (esp. store-based + # barrier) may be costly and/or unscalable. Also, in a lot of cases, + # these barriers may be unnecessary, as proven by a green CI after + # removal. An environment variable `TORCH_DIST_INIT_BARRIER` has been + # added which enables this barrier only when set to 1. + logger.debug( + "Performing barrier after ProcessGroup initialization since " + "TORCH_DIST_INIT_BARRIER = 1" + ) + if backend == Backend.MPI: + # MPI backend doesn't use store. + barrier() + else: + # Use store based barrier here since barrier() used a bunch of + # default devices and messes up NCCL internal state. 
+ _store_based_barrier(rank, store, group_name, world_size, timeout) + + +def _get_split_source(pg: ProcessGroup): + split_from = None + if pg.bound_device_id: + split_from = pg._get_backend(pg.bound_device_id) + elif pg is _world.default_pg: + try: + # pyrefly: ignore [missing-attribute] + split_from = pg._get_backend(torch.device("cuda")) + except RuntimeError: + # no cuda device associated with this backend + pass + + if not split_from or not split_from.supports_splitting: + return None + + # If necessary, find a backend to split from by peeling process + # group wrappers from our potentially wrapped process group. + while _GLOO_AVAILABLE and isinstance(split_from, _ProcessGroupWrapper): + split_from = split_from.wrapped_pg + + return split_from + + +def _new_process_group_helper( + group_size, + group_rank, + global_ranks_in_group, + backend, + store, + group_name: GroupName, + backend_options=None, + timeout=None, + pg_tag=None, + device_id=None, + group_desc=None, +): + """ + Create a new distributed process group. + + This function must be called by ALL processes in the global group, even if + the calling process is not part of the newly created group. In that case, + this function returns GroupMember.NON_GROUP_MEMBER. + + This function is called with ``global_ranks_in_group == []`` for the default group. 
+ """ + global _world + + if group_name in _world.pg_names.values(): + raise ValueError( + "The specified group name has already been " + "created, please use a different group name" + ) + + if device_id is not None and (device_id.index is None or device_id.type == "cpu"): + raise ValueError( + "init_process_group device_id parameter must be an accelerator with an index" + ) + + # Note: _new_process_group_helper is only called from init_process_group, which always provides a timeout value + _check_valid_timeout(timeout) + + if pg_tag not in [None, ""]: + # creating with the same tag and rank set results in the same underlying PG + existing_group = _find_pg_by_ranks_and_tag(pg_tag, global_ranks_in_group) + if existing_group: + _, prefix_store = _world.pg_map[existing_group] + return existing_group, prefix_store + + group_desc = "undefined" if group_desc is None else group_desc + + # The list of group ranks is empty if we're creating the default group. + is_default_group = len(global_ranks_in_group) == 0 + + # nccl and potentially other backends allow creation of + # communicators based on pre-existing ones, which can save + # initialization time. Due to lazy initialization of + # communicators in some backends, we have to be careful and only + # split when we *know* the default PG has already started communicator initialization. + # We know this if we have bound a device id to the default pg (eager initialized). + if is_initialized() and _get_default_group().bound_device_id: + split_from = _get_split_source(_get_default_group()) + else: + split_from = None + + # If this is a subgroup (which means group_ranks is specified), + # we check if the current process is a member of the new group. 
+ if not is_default_group: + global_rank = _get_default_group().rank() + if global_rank not in global_ranks_in_group: + # If we are using `ncclCommSplit` (or similar split from + # other APIs) to create the communicator, we will need to + # call `ncclCommSplit` on *all* ranks in this new group's + # parent group, even those not in the new group. This is + # a requirement of the NCCL API as otherwise we would get + # out of sync. + if split_from: + split_from.perform_nocolor_split(_get_default_group().bound_device_id) + return GroupMember.NON_GROUP_MEMBER, None + + prefix_store = PrefixStore(f"{group_name}/", store) + # The backend for PG will be set later based on what's inside BackendConfig + # and timeout are set in each backend's option. + pg: ProcessGroup = ProcessGroup( + prefix_store, + group_rank, + group_size, + ) + backend_config = BackendConfig(backend) + # Set the default backend when single backend is passed in. + if "," not in str(backend) and ":" not in str(backend): + if backend not in Backend.backend_type_map: + raise AssertionError(f"Unknown backend type {backend}") + if backend == Backend.UNDEFINED: + # Currently when backend is UNDEFINED, only one backend will be initialized + # we use nccl (if cuda is available) or gloo as default backend + # so we can correctly call getDefaultBackend which in ProcessGroup. 
+ if Backend.NCCL in backend_config.get_device_backend_map().values(): + pg._set_default_backend(ProcessGroup.BackendType.NCCL) + else: + pg._set_default_backend(ProcessGroup.BackendType.GLOO) + else: + pg._set_default_backend(Backend.backend_type_map[backend]) + # In order to correctly call pg._has_hooks(), we should set the default backend + # when multi backend is passed in + else: + if Backend.NCCL in backend_config.device_backend_map.values(): + pg._set_default_backend(ProcessGroup.BackendType.NCCL) + elif Backend._plugins.keys(): + custom_backend = next(iter(Backend._plugins.keys())) + if custom_backend in backend_config.device_backend_map.values(): + pg._set_default_backend(ProcessGroup.BackendType.CUSTOM) + else: + pg._set_default_backend(ProcessGroup.BackendType.GLOO) + + if device_id: + pg.bound_device_id = device_id + backend_class: torch._C._distributed_c10d.Backend + for device, backend_str in backend_config.get_device_backend_map().items(): + # Use the group name as prefix in the default store, such that + # a single store can be reused by multiple groups. + backend_prefix_store = PrefixStore(f"{device}/", prefix_store) + + if backend_str == Backend.MPI: + if not is_mpi_available(): + raise RuntimeError( + "Distributed package doesn't have MPI built in." + " MPI is only included if you build PyTorch from" + " source on a host that has MPI installed." 
+ ) + backend_class = ProcessGroupMPI.create(global_ranks_in_group) + backend_type = ProcessGroup.BackendType.MPI + if not backend_class: + return GroupMember.NON_GROUP_MEMBER, None + # create new process group with accurate rank and size + if pg.rank() == -1 and pg.size() == -1: + pg = ProcessGroup( + backend_prefix_store, + backend_class.rank(), + backend_class.size(), + ) + pg._set_default_backend(backend_type) + elif backend_str == Backend.GLOO: + # TODO: remove this check after lazy initialization is supported + # if pg_options is not None: + # raise RuntimeError("GLOO options not supported") + if not is_gloo_available(): + raise RuntimeError("Distributed package doesn't have Gloo built in") + backend_class = ProcessGroupGloo( + backend_prefix_store, + group_rank, + group_size, + # pyrefly: ignore [bad-argument-type] + timeout=timeout, + ) + backend_class.options.global_ranks_in_group = global_ranks_in_group + backend_class.options.group_name = group_name + backend_type = ProcessGroup.BackendType.GLOO + elif backend_str == Backend.NCCL: + if not is_nccl_available(): + raise RuntimeError("Distributed package doesn't have NCCL built in") + if backend_options is not None: + if not isinstance(backend_options, ProcessGroupNCCL.Options): + raise AssertionError( + "Expected backend_options argument to be of type ProcessGroupNCCL.Options" + ) + if backend_options._timeout != timeout: + warnings.warn( + "backend_options._timeout was specified, " + "but timeout kwarg has a default value that will always override it. 
", + stacklevel=2, + ) + else: + # default backend_options for NCCL + backend_options = ProcessGroupNCCL.Options() + backend_options.is_high_priority_stream = False + # pyrefly: ignore [bad-argument-type] + backend_options._timeout = timeout + + if split_from: + backend_options.split_from = split_from + backend_options.split_color = _process_group_color( + global_ranks_in_group + ) + backend_options.global_ranks_in_group = global_ranks_in_group + backend_options.group_name = group_name + backend_class = ProcessGroupNCCL( + backend_prefix_store, group_rank, group_size, backend_options + ) + backend_type = ProcessGroup.BackendType.NCCL + elif backend_str == Backend.UCC and is_ucc_available(): + # TODO: once UCC plugin is fully deprecated, remove + # is_ucc_available() from above elif-condition and raise + # RuntimeError if is_ucc_available() returns false. + + backend_class = ProcessGroupUCC( + backend_prefix_store, + group_rank, + group_size, + # pyrefly: ignore [bad-argument-type] + timeout=timeout, + ) + backend_type = ProcessGroup.BackendType.UCC + elif backend_str == Backend.XCCL: + if not is_xccl_available(): + raise RuntimeError("Distributed package doesn't have XCCL built in") + backend_options = ProcessGroupXCCL.Options() + backend_options.global_ranks_in_group = global_ranks_in_group + backend_options.group_name = group_name + # pyrefly: ignore [bad-argument-type] + backend_options._timeout = timeout + backend_class = ProcessGroupXCCL( + backend_prefix_store, group_rank, group_size, backend_options + ) + backend_type = ProcessGroup.BackendType.XCCL + else: + if backend_str.upper() not in Backend._plugins: + raise AssertionError(f"Unknown c10d backend type {backend_str.upper()}") + + backend_plugin = Backend._plugins[backend_str.upper()] + creator_fn = backend_plugin.creator_fn + extended_api = backend_plugin.extended_api + backend_type = ProcessGroup.BackendType.CUSTOM + + if not extended_api: + backend_class = creator_fn( + backend_prefix_store, 
group_rank, group_size, timeout + ) + else: + dist_backend_opts = _DistributedBackendOptions() + dist_backend_opts.store = backend_prefix_store + dist_backend_opts.group_rank = group_rank + dist_backend_opts.group_size = group_size + # pyrefly: ignore [bad-argument-type] + dist_backend_opts.timeout = timeout + dist_backend_opts.group_id = group_name + dist_backend_opts.global_ranks_in_group = global_ranks_in_group + + backend_class = creator_fn(dist_backend_opts, backend_options) + + # Set sequence numbers for gloo and nccl backends. + if backend_str == Backend.GLOO: + if not isinstance(backend_class, ProcessGroupGloo): + raise AssertionError( + f"Expected ProcessGroupGloo, got {type(backend_class)}" + ) + backend_class._set_sequence_number_for_group() + elif backend_str == Backend.NCCL: + if not isinstance(backend_class, ProcessGroupNCCL): + raise AssertionError( + f"Expected ProcessGroupNCCL, got {type(backend_class)}" + ) + backend_class._set_sequence_number_for_group() + + # If the type is a subclass of ProcessGroup then return this process group immediately + # TODO: This defaults to the old behavior for PythonProcessGroups which overwrites the + # ProcessGroup instance + if issubclass(type(backend_class), ProcessGroup): + pg = backend_class # type: ignore[assignment] + break + + # Process group wrapper initialization for supported PGs when TORCH_DISTRIBUTED_DEBUG is set + if ( + backend_str in [Backend.GLOO, Backend.NCCL, Backend.UCC] + or backend_str.upper() in Backend._plugins + ): + # In debug mode and if GLOO is available, wrap in a wrapper PG that + # enables enhanced collective checking for debuggability. + if get_debug_level() == DebugLevel.DETAIL: + if not _GLOO_AVAILABLE: + logger.info( + """TORCH_DISTRIBUTED_DEBUG was set to DETAIL, but + GLOO is not available. 
def destroy_process_group(group: ProcessGroup | None = None):
    """
    Destroy a given process group, and deinitialize the distributed package.

    When ``group`` is ``None`` or equals ``GroupMember.WORLD``, every process
    group recorded in ``_world.pg_names`` is shut down and all bookkeeping held
    in ``_world`` is cleared. Otherwise only ``group`` is shut down and only its
    own entries are removed from ``_world``.

    Args:
        group (ProcessGroup, optional): The process group to be destroyed, if
                                        group.WORLD is given, all process
                                        groups including the default one will
                                        be destroyed.

    Raises:
        AssertionError: if no group was given and ``GroupMember.WORLD`` is ``None``.
        ValueError: if the resolved group has no entry in ``_world.pg_map``.
    """
    global _world

    # A rank that is not a member of the group has nothing to tear down.
    if group == GroupMember.NON_GROUP_MEMBER:
        return

    if group is None:
        pg = GroupMember.WORLD
    else:
        pg = group

    if pg is None:
        raise AssertionError("Process group cannot be None")
    if _world.pg_map.get(pg, None) is None:
        raise ValueError("Invalid process group specified")

    # When users register Python onCompletion hooks, those hooks will run on a
    # different thread than the main thread. Today, the ProcessGroup dtor does
    # wait for that thread. However, the dtor might finish after the Python
    # Interpreter exits. After that grabbing the GIL for the Python hook will crash.
    # We can either revive the interpreter when running hooks or keep the main one
    # alive until all works and hooks are done. The current implementation does the
    # latter. Therefore, we explicitly call _wait_for_pending_works() here to wait
    # for the pending hooks to finish.
    if type(pg) is ProcessGroup and pg._has_hooks():
        pg._wait_for_pending_works()

    if group is None or group == GroupMember.WORLD:
        # shutdown all backends in the order of pg names. shutting down in order because
        # ncclCommAbort() was a 'collective' call in some versions of NCCL.
        # The sort key is the registered group name, iterated in reverse order.
        for pg_to_shutdown in sorted(
            _world.pg_names, key=lambda x: _world.pg_names[x], reverse=True
        ):
            pg_to_shutdown.shutdown()

        # Drop the default group and wipe every per-group registry in `_world`.
        _update_default_pg(None)
        _world.pg_map.clear()
        _world.pg_names.clear()
        _world.pg_group_ranks.clear()
        _world.pg_backend_config.clear()
        _world.pg_to_tag.clear()
        _world.tags_to_pg.clear()
        _world.pg_coalesce_state.clear()
        _unregister_all_process_groups()

        # when process group doesn't have an explicit name (only WORLD (default)
        # process group can have an explicit name), we use global _world.group_count
        # to generate the name. We need to reset the counter on destruction to
        # allow consistent value to be generated when we re-create process
        # groups after some trainers recover from failure
        #
        # We only reset this when WORLD is being destroyed because if this
        # process group is in good state, we aren't dealing with failures.
        _world.group_count = 0
    else:
        # Single-group teardown: shut the group down, then remove only its entries.
        pg.shutdown()
        del _world.pg_map[pg]
        del _world.pg_names[pg]
        del _world.pg_group_ranks[pg]
        del _world.pg_backend_config[pg]
        if pg in _world.pg_coalesce_state:
            warnings.warn(
                "Some coalesced collectives haven't been launched when "
                "ProcessGroup is destroyed. They will be cleaned.",
                stacklevel=2,
            )
            del _world.pg_coalesce_state[pg]

        # NOTE(review): `.get` tolerates a missing tag, but the `del` on the next
        # line raises KeyError when `pg` has no entry in `pg_to_tag`.
        tag = _world.pg_to_tag.get(pg)
        del _world.pg_to_tag[pg]
        if tag is not None:
            # Best-effort removal from the tag index; any failure here is ignored.
            try:
                _world.tags_to_pg[tag].remove(pg)
                # Groups tagged "ptd:..." are also listed under the "" tag.
                if tag.startswith("ptd:"):
                    _world.tags_to_pg[""].remove(pg)
            except Exception:
                pass
        _unregister_process_group(pg.group_name)
+ """ + global _world + + if group == GroupMember.NON_GROUP_MEMBER: + return + + pg = group or GroupMember.WORLD + + if pg is None: + raise AssertionError("Process group cannot be None") + if _world.pg_map.get(pg, None) is None: + raise ValueError("Invalid process group specified or has been destroyed.") + + try: + backend = pg._get_backend(torch.device("cuda")) + except RuntimeError: + backend = None + + if group is None or group == GroupMember.WORLD: + # Abort all backends within a ncclGroupStart|End semantic. + # This ensures that different NCCL communicators' abort calls won't + # deadlock each other. + # For details, please see: https://github.com/pytorch/pytorch/issues/119797 + if is_nccl_available() and isinstance(backend, ProcessGroupNCCL): + backend._group_start() + for pg_to_abort in sorted( + _world.pg_names, key=lambda x: _world.pg_names[x], reverse=True + ): + pg_to_abort.abort() + if is_nccl_available() and isinstance(backend, ProcessGroupNCCL): + backend._group_end() + + _update_default_pg(None) + _world.pg_map.clear() + _world.pg_names.clear() + _world.pg_group_ranks.clear() + _world.pg_backend_config.clear() + _world.pg_to_tag.clear() + _world.tags_to_pg.clear() + _world.pg_coalesce_state.clear() + _unregister_all_process_groups() + + # when process group doesn't have an explicit name (only WORLD (default) + # process group can have an explicit name), we use global _world.group_count + # to generate the name. We need to reset the counter on destruction to + # allow consistent value to be generated when we re-create process + # groups after some trainers recover from failure + # + # We only reset this when WORLD is being destroyed because if this + # process group is in good state, we aren't dealing with failures. 
def get_rank(group: ProcessGroup | None = None) -> int:
    """
    Return the rank of the current process in ``group`` (the default group if omitted).

    Rank is a unique identifier assigned to each process within a distributed
    process group. Ranks are consecutive integers ranging from 0 to
    ``world_size - 1``.

    Args:
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.

    Returns:
        The rank of the current process within ``group``, or
        -1 if the current process is not part of the group.

    """
    if _rank_not_in_group(group):
        return -1

    # The rank in the default group is needed on both paths: it is the answer
    # for the default group and the lookup key for any other group.
    global_rank = _get_default_group().rank()
    wants_default = group is None or group is GroupMember.WORLD
    return global_rank if wants_default else get_group_rank(group, global_rank)
def isend(
    tensor: torch.Tensor,
    dst: int | None = None,
    group: ProcessGroup | None = None,
    tag: int = 0,
    group_dst: int | None = None,
) -> Work | None:
    """
    Send a tensor asynchronously.

    .. warning::
        Modifying ``tensor`` before the request completes causes undefined
        behavior.

    .. warning::
        ``tag`` is not supported with the NCCL backend.

    Unlike send, which is blocking, isend allows src == dst rank, i.e. send to self.

    Args:
        tensor (Tensor): Tensor to send.
        dst (int): Destination rank on global process group (regardless of ``group`` argument)
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
        tag (int, optional): Tag to match send with remote recv
        group_dst (int, optional): Destination rank on ``group``. Invalid to specify both ``dst`` and ``group_dst``

    Returns:
        A distributed request object.
        None, if not part of the group

    """
    pg = _group_or_default_group(group)
    peer = _canonicalize_group_rank(pg, dst, group_dst)
    _check_single_tensor(tensor, "tensor")

    if _rank_not_in_group(pg):
        _warn_not_in_group("isend")
        return None

    # Complex tensors are sent through their real-valued view.
    payload = torch.view_as_real(tensor) if tensor.is_complex() else tensor
    return pg.send([payload], peer, tag)
@_exception_logger
def send(
    tensor: torch.Tensor,
    dst: int | None = None,
    group: ProcessGroup | None = None,
    tag: int = 0,
    group_dst: int | None = None,
) -> None:
    """
    Send a tensor synchronously.

    This is :func:`isend` followed by a ``wait()`` on the returned request.

    .. warning::
        ``tag`` is not supported with the NCCL backend.

    Args:
        tensor (Tensor): Tensor to send.
        dst (int): Destination rank on global process group (regardless of ``group`` argument).
            Destination rank should not be the same as the rank of the current process.
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
        tag (int, optional): Tag to match send with remote recv
        group_dst (int, optional): Destination rank on ``group``. Invalid to specify both ``dst`` and ``group_dst``.

    """
    pg = _group_or_default_group(group)
    peer = _canonicalize_group_rank(pg, dst, group_dst)
    # The self-rank check runs here, before delegating to isend.
    _check_not_self_rank(pg, peer, "destination")

    request = isend(tensor, group=pg, tag=tag, group_dst=peer)
    if request is None:
        # isend returns None when this rank is not part of the group.
        return
    request.wait()
@contextlib.contextmanager
def _coalescing_manager(
    group: ProcessGroup | None = None,
    device: torch.device | None = None,
    async_ops: bool = False,
):
    """
    Context manager used to coalesce collectives or P2P operations when possible.

    While the context is active, the group has an entry in
    ``_world.pg_coalesce_state``. On exit, any ops recorded in that entry are
    issued as a single coalesced call ("fast path"). If ``device`` is given,
    the backend is additionally bracketed with
    ``_start_coalescing`` / ``_end_coalescing``.

    Args:
        group (`ProcessGroup`, optional): The process group to work on. If None,
            the default process group will be used.
        device (`torch.device`, optional): Default is None, set to a device if
            there isn't a `**_coalesced` implementation by the backend.
        async_ops (`bool`, optional): whether the coalesced ops are async ops.

    Yields:
        A ``_CoalescingManager``. When ``async_ops`` is True, the resulting work
        handle is appended to it on exit so that the caller can ``wait()`` on it.

    Raises:
        ValueError: if the group already has pending recorded ops on entry.
        AssertionError: if a recorded op has no fast-path implementation.

    Examples:
        >>> # xdoctest: +SKIP("no rank")
        >>> # Synchronous ops
        >>> with _coalescing_manager():
        >>>     for i in range(num_colls):
        >>>         dist.all_reduce(tensors[i])
        >>> # Asynchronous ops
        >>> with _coalescing_manager(async_ops=True) as cm:
        >>>     for i in range(num_colls):
        >>>         dist.all_reduce(tensors[i])
        >>> cm.wait()

    .. warning::
       :func:`_coalescing_manager` currently do not support coalescing
       all-reduces with different reduce operators, e.g.  `ReduceOp.SUM` mixed
       with `ReduceOp.PRODUCT`.
    """
    group = group or _get_default_group()
    op_list = _world.pg_coalesce_state.setdefault(group, [])
    if op_list:
        raise ValueError(
            "ProcessGroup has non-empty op list at the start of coalescing"
        )
    if device:
        group._start_coalescing(device)
    cm = _CoalescingManager()
    # NOTE(review): if the with-body raises, the code below is skipped and the
    # group's entry stays in `_world.pg_coalesce_state` -- confirm whether
    # cleanup on the exception path is wanted.
    yield cm
    work = None
    op_list = _world.pg_coalesce_state.pop(group)
    if op_list:
        # Collectives supporting "Fast Path" coalescing are captured.
        # See implementation in corresponding collective APIs.
        # Currently supported:
        # - coalesced `all_reduce`
        # - coalesced `all_gather_into_tensor`
        # - coalesced `reduce_scatter_tensor`
        # The first recorded op decides which coalesced call is issued.
        op0 = op_list[0].op
        if op0 is all_reduce:
            tensors = [op.tensor for op in op_list]
            all_reduce_opts = AllreduceCoalescedOptions()
            all_reduce_opts.reduceOp = not_none(op_list[0].redop)
            all_reduce_opts.asyncOp = async_ops
            work = group.allreduce_coalesced(tensors, all_reduce_opts)
        elif op0 is all_gather_into_tensor:
            inputs = []
            outputs = []
            for op in op_list:
                inputs.append(op.tensor)
                outputs.append(not_none(op.dst_tensor))
            all_gather_opts = AllgatherOptions()
            all_gather_opts.asyncOp = async_ops
            # Pass the options through so `async_ops` takes effect, as the
            # other two fast paths do.
            work = group.allgather_into_tensor_coalesced(
                outputs, inputs, all_gather_opts
            )
        elif op0 is reduce_scatter_tensor:
            inputs = []
            outputs = []
            for op in op_list:
                inputs.append(op.tensor)
                outputs.append(not_none(op.dst_tensor))
            reduce_opts = ReduceScatterOptions()
            reduce_opts.reduceOp = not_none(op_list[0].redop)
            reduce_opts.asyncOp = async_ops
            work = group.reduce_scatter_tensor_coalesced(outputs, inputs, reduce_opts)
        else:
            raise AssertionError(
                f"Coalescing manager does not support fast-path coalescing of {op0}, "
                f"yet {op0} is still recorded in op list. This is an internal error of c10d."
            )

    if device:
        # Old style of letting each coll inside the context manager to call into C++ counterpart via python binding
        work = group._end_coalescing(device)

    if async_ops:
        cm.append(work)
    elif (
        work is not None
    ):  # Backward compatible with backends that don't sync at CPP level
        work.wait()
    # Otherwise, the backend has sync'ed at CPP level
def batch_isend_irecv(p2p_op_list: list[P2POp]) -> list[Work]:
    """
    Send or Receive a batch of tensors asynchronously and return a list of requests.

    Process each of the operations in ``p2p_op_list`` and return the corresponding
    requests. NCCL, Gloo, and UCC backend are currently supported.

    The group and device of the first op in the list decide whether the whole
    batch is issued through the coalescing manager or one op at a time.

    Args:
        p2p_op_list: A list of point-to-point operations(type of each operator is
            ``torch.distributed.P2POp``). The order of the isend/irecv in the list
            matters and it needs to match with corresponding isend/irecv on the
            remote end.

    Returns:
        A list of distributed request objects returned by calling the corresponding
        op in the op_list.

    Examples:
        >>> # xdoctest: +SKIP("no rank")
        >>> send_tensor = torch.arange(2, dtype=torch.float32) + 2 * rank
        >>> recv_tensor = torch.randn(2, dtype=torch.float32)
        >>> send_op = dist.P2POp(dist.isend, send_tensor, (rank + 1) % world_size)
        >>> recv_op = dist.P2POp(
        ...     dist.irecv, recv_tensor, (rank - 1 + world_size) % world_size
        ... )
        >>> reqs = batch_isend_irecv([send_op, recv_op])
        >>> for req in reqs:
        >>>     req.wait()
        >>> recv_tensor
        tensor([2, 3])     # Rank 0
        tensor([0, 1])     # Rank 1

    .. note:: Note that when this API is used with the NCCL PG backend, users must set
        the current GPU device with `torch.cuda.set_device`, otherwise it will
        lead to unexpected hang issues.

        In addition, if this API is the first collective call in the ``group``
        passed to ``dist.P2POp``, all ranks of the ``group`` must participate in
        this API call; otherwise, the behavior is undefined. If this API call is
        not the first collective call in the ``group``, batched P2P operations
        involving only a subset of ranks of the ``group`` are allowed.
    """
    _check_p2p_op_list(p2p_op_list)

    first_op = p2p_op_list[0]
    group = first_op.group
    if group is None:
        group = _get_default_group()
    device = first_op.tensor.device

    def _launch(p2p_op: P2POp):
        # isend takes the peer as ``group_dst``; any other op takes ``group_src``.
        peer_key = "group_dst" if p2p_op.op is isend else "group_src"
        return p2p_op.op(
            p2p_op.tensor,
            group=p2p_op.group,
            tag=p2p_op.tag,
            **{peer_key: p2p_op.group_peer},
        )

    can_coalesce = (
        type(group) is ProcessGroup and group._get_backend(device).supports_coalescing
    )
    if not can_coalesce:
        # backend not support coalescing: launch one by one, keep truthy handles
        return [work for work in (_launch(p2p_op) for p2p_op in p2p_op_list) if work]

    # NCCL style coalescing
    with _coalescing_manager(group, device, async_ops=True) as cm:
        for p2p_op in p2p_op_list:
            _launch(p2p_op)

    return cm.works
@_exception_logger
def all_reduce(tensor, op=ReduceOp.SUM, group=None, async_op: bool = False):
    """
    Reduces the tensor data across all machines in a way that all get the final result.

    After the call ``tensor`` is going to be bitwise identical in all processes.

    Complex tensors are supported.

    Inside a coalescing context for ``group``, the collective is only recorded
    and not launched.

    Args:
        tensor (Tensor): Input and output of the collective. The function
            operates in-place.
        op (optional): One of the values from
            ``torch.distributed.ReduceOp``
            enum.  Specifies an operation used for element-wise reductions.
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
        async_op (bool, optional): Whether this op should be an async op

    Returns:
        Async work handle, if async_op is set to True.
        None, if not async_op or if not part of the group

    Raises:
        ValueError: if ``tensor`` is complex and ``op`` does not support complex tensors.

    Examples:
        >>> # xdoctest: +SKIP("no rank")
        >>> # All tensors below are of torch.int64 type.
        >>> # We have 2 process groups, 2 ranks.
        >>> device = torch.device(f"cuda:{rank}")
        >>> tensor = torch.arange(2, dtype=torch.int64, device=device) + 1 + 2 * rank
        >>> tensor
        tensor([1, 2], device='cuda:0') # Rank 0
        tensor([3, 4], device='cuda:1') # Rank 1
        >>> dist.all_reduce(tensor, op=ReduceOp.SUM)
        >>> tensor
        tensor([4, 6], device='cuda:0') # Rank 0
        tensor([4, 6], device='cuda:1') # Rank 1

        >>> # All tensors below are of torch.cfloat type.
        >>> # We have 2 process groups, 2 ranks.
        >>> tensor = torch.tensor(
        ...     [1 + 1j, 2 + 2j], dtype=torch.cfloat, device=device
        ... ) + 2 * rank * (1 + 1j)
        >>> tensor
        tensor([1.+1.j, 2.+2.j], device='cuda:0') # Rank 0
        tensor([3.+3.j, 4.+4.j], device='cuda:1') # Rank 1
        >>> dist.all_reduce(tensor, op=ReduceOp.SUM)
        >>> tensor
        tensor([4.+4.j, 6.+6.j], device='cuda:0') # Rank 0
        tensor([4.+4.j, 6.+6.j], device='cuda:1') # Rank 1

    """
    # Dynamo has built-in logic to map legacy distributed ops to functional collectives.
    # Let's redirect to a torch function mode that can mimic this logic outside Dynamo
    # (e.g., non-strict export implements such a torch function mode).
    overridable = (tensor,)
    if has_torch_function(overridable):
        return handle_torch_function(
            all_reduce,
            overridable,
            tensor,
            op=op,
            group=group,
            async_op=async_op,
        )

    _check_single_tensor(tensor, "tensor")
    if _rank_not_in_group(group):
        _warn_not_in_group("all_reduce")
        return

    if tensor.is_complex():
        if not supports_complex(op):
            raise ValueError(f"all_reduce does not support {op} on complex tensors")
        # Complex tensors are reduced through their real-valued view.
        tensor = torch.view_as_real(tensor)

    reduce_opts = AllreduceOptions()
    reduce_opts.reduceOp = op
    reduce_opts.asyncOp = async_op
    group = _get_default_group() if group is None else group

    if group in _world.pg_coalesce_state:
        # We are in coalescing context, do not issue single operation, just append a collective representation
        _world.pg_coalesce_state[group].append(
            _CollOp(all_reduce, tensor, None, op, None)
        )
        # Async callers get a placeholder handle; sync callers get nothing.
        return _IllegalWork() if async_op else None

    handle = group.allreduce([tensor], reduce_opts)

    if async_op:
        return handle
    if handle is not None:
        # Backward compatible with backends that don't sync at CPP level
        handle.wait()
    # Otherwise, the backend has sync'ed at CPP level
This lack + of shape checking results in significant performance improvements but users of this + function should take extra care to ensure that each node passes in tensors whose + shapes match across nodes. + + Reduces each tensor in tensors (residing on the same device) across all machines + in such a way that all get the final result. + + After the call each tensor in tensors is going to be bitwise identical + in all processes. + + Complex tensors are supported. + + Args: + tensors (Union[List[Tensor], Tensor]): Input and output of the collective. + The function operates in-place. + op (Optional[ReduceOp]): One of the values from + ``torch.distributed.ReduceOp`` enum. Specifies an operation used for + element-wise reductions. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + async_op (Optional[bool]): Whether this op should be an async op. + + Returns: + The result of ``work.get_future()``, if async_op is set to True. + None, if not async_op or if not part of the group.
+ + """ + if isinstance(tensors, torch.Tensor): + tensors = [tensors] + _check_tensor_list(tensors, "tensor") + _ensure_all_tensors_same_dtype(tensors) + if _rank_not_in_group(group): + _warn_not_in_group("all_reduce_coalesced") + return + + if any(t.is_complex() for t in tensors) and not supports_complex(op): + raise ValueError(f"all_reduce does not support {op} on complex tensors") + + tensors = [t if not t.is_complex() else torch.view_as_real(t) for t in tensors] + + opts = AllreduceCoalescedOptions() + opts.reduceOp = op + opts.asyncOp = async_op + group = group or _get_default_group() + work = group.allreduce_coalesced(tensors, opts) + + if async_op: + return work.get_future() + elif ( + work is not None + ): # Backward compatible with backends that don't sync at CPP level + work.wait() + # Otherwise, the backend has sync'ed at CPP level + + +@_exception_logger +def reduce( + tensor: torch.Tensor, + dst: int | None = None, + op=ReduceOp.SUM, + group: ProcessGroup | None = None, + async_op: bool = False, + group_dst: int | None = None, +): + """ + Reduces the tensor data across all machines. + + Only the process with rank ``dst`` is going to receive the final result. + + Args: + tensor (Tensor): Input and output of the collective. The function + operates in-place. + dst (int): Destination rank on global process group (regardless of ``group`` argument) + op (optional): One of the values from + ``torch.distributed.ReduceOp`` + enum. Specifies an operation used for element-wise reductions. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + async_op (bool, optional): Whether this op should be an async op + group_dst (int): Destination rank on ``group``. Must specify one of ``group_dst`` + and ``dst`` but not both. + + Returns: + Async work handle, if async_op is set to True. 
+ None, if not async_op or if not part of the group + + """ + group = _group_or_default_group(group) + group_dst = _canonicalize_group_rank(group, dst, group_dst, return_global=False) + _check_single_tensor(tensor, "tensor") + if _rank_not_in_group(group): + _warn_not_in_group("reduce") + return + + opts = ReduceOptions() + opts.reduceOp = op + opts.rootRank = group_dst + opts.asyncOp = async_op + work = group.reduce([tensor], opts) + if async_op: + return work + elif ( + work is not None + ): # Backward compatible with backends that don't sync at CPP level + work.wait() + # Otherwise, the backend has sync'ed at CPP level + + +def _object_to_tensor(obj, device, group): + with _WaitCounter("pytorch.wait_counter.c10d._object_to_tensor").guard(): + f = io.BytesIO() + _pickler(f).dump(obj) + byte_storage = torch.ByteStorage._from_buffer(f.getvalue()) # type: ignore[attr-defined] + # Do not replace `torch.ByteTensor` or `torch.LongTensor` with torch.tensor and specifying dtype. + # Otherwise, it will cause 100X slowdown. 
+ # See: https://github.com/pytorch/pytorch/issues/65696 + byte_tensor = torch.ByteTensor(byte_storage).to(device) + if get_debug_level() == DebugLevel.DETAIL and is_nccl_available(): + backend = get_backend(group) + if backend == Backend.NCCL: + hash = torch._C._distributed_c10d._hash_tensors([byte_tensor]) + logger.warning( + "_object_to_tensor size: %s hash value: %s", + byte_tensor.numel(), + hash, + ) + local_size = torch.LongTensor([byte_tensor.numel()]).to(device) + return byte_tensor, local_size + + +def _tensor_to_object(tensor, tensor_size, group): + with _WaitCounter("pytorch.wait_counter.c10d._tensor_to_object").guard(): + if get_debug_level() == DebugLevel.DETAIL and is_nccl_available(): + backend = get_backend(group) + if backend == Backend.NCCL: + hash = torch._C._distributed_c10d._hash_tensors([tensor]) + logger.warning( + "_tensor_to_object size: %s hash value: %s", tensor.numel(), hash + ) + tensor = tensor.cpu() + buf = tensor.numpy().tobytes()[:tensor_size] + return _unpickler(io.BytesIO(buf)).load() + + +@_exception_logger +def all_gather_object(object_list, obj, group=None): + """ + Gathers picklable objects from the whole group into a list. + + Similar to :func:`all_gather`, but Python objects can be passed in. + Note that the object must be picklable in order to be gathered. + + Args: + object_list (list[Any]): Output list. It should be correctly sized as the + size of the group for this collective and will contain the output. + obj (Any): Pickable Python object to be broadcast from current process. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. Default is ``None``. + + Returns: + None. If the calling rank is part of this group, the output of the + collective will be populated into the input ``object_list``. If the + calling rank is not part of the group, the passed in ``object_list`` will + be unmodified. + + .. 
note:: Note that this API differs slightly from the :func:`all_gather` + collective since it does not provide an ``async_op`` handle and thus + will be a blocking call. + + .. note:: For NCCL-based process groups, internal tensor representations + of objects must be moved to the GPU device before communication takes + place. In this case, the device used is given by + ``torch.cuda.current_device()`` and it is the user's responsibility to + ensure that this is set so that each rank has an individual GPU, via + ``torch.cuda.set_device()``. + + .. warning:: + Object collectives have a number of serious performance and scalability + limitations. See :ref:`object_collectives` for details. + + .. warning:: + :func:`all_gather_object` uses ``pickle`` module implicitly, which is + known to be insecure. It is possible to construct malicious pickle data + which will execute arbitrary code during unpickling. Only call this + function with data you trust. + + .. warning:: + Calling :func:`all_gather_object` with GPU tensors is not well supported + and inefficient as it incurs GPU -> CPU transfer since tensors would be + pickled. Please consider using :func:`all_gather` instead. + + Example:: + >>> # xdoctest: +SKIP("need process group init") + >>> # Note: Process group initialization omitted on each rank. + >>> import torch.distributed as dist + >>> # Assumes world_size of 3. + >>> gather_objects = ["foo", 12, {1: 2}] # any picklable object + >>> output = [None for _ in gather_objects] + >>> dist.all_gather_object(output, gather_objects[dist.get_rank()]) + >>> output + ['foo', 12, {1: 2}] + """ + if _rank_not_in_group(group): + _warn_not_in_group("all_gather_object") + return + + current_device = _get_object_coll_device(group) + input_tensor, local_size = _object_to_tensor(obj, current_device, group) + + # Gather all local sizes. This is so that we can find the max size, and index + # until the correct size when deserializing the tensors.
+ group_size = get_world_size(group=group) + object_sizes_tensor = torch.zeros( + group_size, dtype=torch.long, device=current_device + ) + object_size_list = [ + object_sizes_tensor[i].unsqueeze(dim=0) for i in range(group_size) + ] + # Allgather tensor sizes + all_gather(object_size_list, local_size, group=group) + max_object_size = int(max(object_size_list).item()) # type: ignore[type-var] + # Resize tensor to max size across all ranks. + input_tensor.resize_(max_object_size) + coalesced_output_tensor = torch.empty( + max_object_size * group_size, dtype=torch.uint8, device=current_device + ) + # Output tensors are nonoverlapping views of coalesced_output_tensor + output_tensors = [ + coalesced_output_tensor[max_object_size * i : max_object_size * (i + 1)] + for i in range(group_size) + ] + all_gather(output_tensors, input_tensor, group=group) + # Deserialize outputs back to object. + for i, tensor in enumerate(output_tensors): + tensor = tensor.type(torch.uint8) + tensor_size = object_size_list[i] + object_list[i] = _tensor_to_object(tensor, tensor_size, group) + + +@_exception_logger +def gather_object( + obj: Any, + object_gather_list: list[Any] | None = None, + dst: int | None = None, + group: ProcessGroup | None = None, + group_dst: int | None = None, +): + """ + Gathers picklable objects from the whole group in a single process. + + Similar to :func:`gather`, but Python objects can be passed in. Note that the + object must be picklable in order to be gathered. + + Args: + obj (Any): Input object. Must be picklable. + object_gather_list (list[Any]): Output list. On the ``dst`` rank, it + should be correctly sized as the size of the group for this + collective and will contain the output. Must be ``None`` on non-dst + ranks. (default is ``None``) + dst (int, optional): Destination rank on global process group (regardless of ``group`` argument). 
+ (If both ``dst`` and ``group_dst`` are None, default is global rank 0) + group: (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. Default is ``None``. + group_dst (int, optional): Destination rank on ``group``. Invalid to specify both ``dst`` and ``group_dst`` + + Returns: + None. On the ``dst`` rank, ``object_gather_list`` will contain the + output of the collective. + + .. note:: Note that this API differs slightly from the gather collective + since it does not provide an async_op handle and thus will be a blocking + call. + + .. note:: For NCCL-based process groups, internal tensor representations + of objects must be moved to the GPU device before communication takes + place. In this case, the device used is given by + ``torch.cuda.current_device()`` and it is the user's responsibility to + ensure that this is set so that each rank has an individual GPU, via + ``torch.cuda.set_device()``. + + .. warning:: + Object collectives have a number of serious performance and scalability + limitations. See :ref:`object_collectives` for details. + + .. warning:: + :func:`gather_object` uses ``pickle`` module implicitly, which is + known to be insecure. It is possible to construct malicious pickle data + which will execute arbitrary code during unpickling. Only call this + function with data you trust. + + .. warning:: + Calling :func:`gather_object` with GPU tensors is not well supported + and inefficient as it incurs GPU -> CPU transfer since tensors would be + pickled. Please consider using :func:`gather` instead. + + Example:: + >>> # xdoctest: +SKIP("need process group init") + >>> # Note: Process group initialization omitted on each rank. + >>> import torch.distributed as dist + >>> # Assumes world_size of 3. + >>> gather_objects = ["foo", 12, {1: 2}] # any picklable object + >>> output = [None for _ in gather_objects] + >>> dist.gather_object( + ... gather_objects[dist.get_rank()], + ...
output if dist.get_rank() == 0 else None, + ... dst=0 + ... ) + >>> # On rank 0 + >>> output + ['foo', 12, {1: 2}] + """ + group = _group_or_default_group(group) + if dst is None and group_dst is None: + dst = 0 + group_dst = _canonicalize_group_rank(group, dst, group_dst, return_global=False) + if _rank_not_in_group(group): + _warn_not_in_group("gather_object") + return + + # Ensure object_gather_list is specified appropriately. + my_group_rank = group.rank() + _validate_output_list_for_rank(my_group_rank, group_dst, object_gather_list) + current_device = _get_object_coll_device(group) + input_tensor, local_size = _object_to_tensor(obj, current_device, group) + + # Gather all local sizes. This is so that we can find the max size, and index + # until the correct size when deserializing the tensors. + group_size = get_world_size(group=group) + object_sizes_tensor = torch.zeros( + group_size, dtype=torch.long, device=current_device + ) + object_size_list = [ + object_sizes_tensor[i].unsqueeze(dim=0) for i in range(group_size) + ] + # Allgather tensor sizes. An all-gather is needed here despite this being a + # gather, since each rank needs to broadcast a tensor of the same (maximal) + # size. + all_gather(object_size_list, local_size, group=group) + max_object_size = int(max(object_size_list).item()) # type: ignore[type-var] + # Resize tensor to max size across all ranks. + input_tensor.resize_(max_object_size) + # Avoid populating output tensors if the result won't be gathered on this rank. + if my_group_rank == group_dst: + coalesced_output_tensor = torch.empty( + max_object_size * group_size, dtype=torch.uint8, device=current_device + ) + # Output tensors are nonoverlapping views of coalesced_output_tensor + output_tensors = [ + coalesced_output_tensor[max_object_size * i : max_object_size * (i + 1)] + for i in range(group_size) + ] + # All ranks call gather with equal-sized tensors. 
+ gather( + input_tensor, + gather_list=output_tensors if my_group_rank == group_dst else None, # type: ignore[possibly-undefined] + group_dst=group_dst, + group=group, + ) + if my_group_rank != group_dst: + return + + if object_gather_list is None: + raise AssertionError("Must provide object_gather_list on dst rank") + # pyrefly: ignore # unbound-name + for i, tensor in enumerate(output_tensors): + tensor = tensor.type(torch.uint8) + tensor_size = object_size_list[i] + object_gather_list[i] = _tensor_to_object(tensor, tensor_size, group) + + +@_exception_logger +def send_object_list( + object_list: list[Any], + dst: int | None = None, + group: ProcessGroup | None = None, + device: torch.device | None = None, + group_dst: int | None = None, + use_batch: bool = False, +): + """ + Sends picklable objects in ``object_list`` synchronously. + + Similar to :func:`send`, but Python objects can be passed in. + Note that all objects in ``object_list`` must be picklable in order to be + sent. + + Args: + object_list (List[Any]): List of input objects to sent. + Each object must be picklable. Receiver must provide lists of equal sizes. + dst (int): Destination rank to send ``object_list`` to. + Destination rank is based on global process group (regardless of ``group`` argument) + group: (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. Default is ``None``. + device (``torch.device``, optional): If not None, the objects are + serialized and converted to tensors which are moved to the + ``device`` before sending. Default is ``None``. + group_dst (int, optional): Destination rank on ``group``. + Must specify one of ``dst`` and ``group_dst`` but not both + use_batch (bool, optional): If True, use batch p2p operations instead of + regular send operations. This avoids initializing 2-rank communicators and + uses existing entire group communicators. See batch_isend_irecv for usage and + assumptions. Default is ``False``. 
+ Returns: + ``None``. + + .. note:: For NCCL-based process groups, internal tensor representations + of objects must be moved to the GPU device before communication takes + place. In this case, the device used is given by + ``torch.cuda.current_device()`` and it is the user's responsibility to + ensure that this is set so that each rank has an individual GPU, via + ``torch.cuda.set_device()``. + + .. warning:: + Object collectives have a number of serious performance and scalability + limitations. See :ref:`object_collectives` for details. + + .. warning:: + :func:`send_object_list` uses ``pickle`` module implicitly, which + is known to be insecure. It is possible to construct malicious pickle + data which will execute arbitrary code during unpickling. Only call this + function with data you trust. + + .. warning:: + Calling :func:`send_object_list` with GPU tensors is not well supported + and inefficient as it incurs GPU -> CPU transfer since tensors would be + pickled. Please consider using :func:`send` instead. + + Example:: + >>> # xdoctest: +SKIP("need process group init") + >>> # Note: Process group initialization omitted on each rank. + >>> import torch.distributed as dist + >>> # Assumes backend is not NCCL + >>> device = torch.device("cpu") + >>> if dist.get_rank() == 0: + >>> # Assumes world_size of 2. + >>> objects = ["foo", 12, {1: 2}] # any picklable object + >>> dist.send_object_list(objects, dst=1, device=device) + >>> else: + >>> objects = [None, None, None] + >>> dist.recv_object_list(objects, src=0, device=device) + >>> objects + ['foo', 12, {1: 2}] + """ + group = _group_or_default_group(group) + group_dst = _canonicalize_group_rank(group, dst, group_dst) + _check_not_self_rank(group, group_dst, "destination") + + if _rank_not_in_group(group): + _warn_not_in_group("send_object_list") + return + + # Current device selection. 
+ # To preserve backwards compatibility, ``device`` is default to ``None`` + # in which case we run current logic of device selection, i.e. + # ``current_device`` is CUDA if backend is NCCL otherwise CPU device. In the + # case it is not ``None`` we move the size and object tensors to be + # sent to this device. + current_device = device or _get_object_coll_device(group) + # Serialize object_list elements to tensors on src rank. + tensor_list, size_list = zip( + *[_object_to_tensor(obj, current_device, group) for obj in object_list] + ) + object_sizes_tensor = torch.cat(size_list) + + # Send object sizes + if use_batch: + batch_isend_irecv( + [P2POp(isend, object_sizes_tensor, group_peer=group_dst, group=group)] + ).pop().wait() + else: + send(object_sizes_tensor, group_dst=group_dst, group=group) + + # Concatenate and send serialized object tensors + # Note: torch.cat will do an extra memory copy to the current device, if the tensor_list + # has only one element, we can skip the copy. + if len(tensor_list) == 1: # type: ignore[possibly-undefined] + object_tensor = tensor_list[0] + else: + object_tensor = torch.cat(tensor_list) + + if use_batch: + batch_isend_irecv( + [P2POp(isend, object_tensor, group_peer=group_dst, group=group)] + ).pop().wait() + else: + send(object_tensor, group_dst=group_dst, group=group) + + +@_exception_logger +def recv_object_list( + object_list: list[Any], + src: int | None = None, + group: ProcessGroup | None = None, + device: torch.device | None = None, + group_src: int | None = None, + use_batch: bool = False, +): + """ + Receives picklable objects in ``object_list`` synchronously. + + Similar to :func:`recv`, but can receive Python objects. + + Args: + object_list (List[Any]): List of objects to receive into. + Must provide a list of sizes equal to the size of the list being sent. + src (int, optional): Source rank from which to recv ``object_list``. 
+ Source rank is based on global process group (regardless of ``group`` argument) + Will receive from any rank if set to None. Default is ``None``. + group: (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. Default is ``None``. + device (``torch.device``, optional): If not None, receives on this device. + Default is ``None``. + group_src (int, optional): Source rank on ``group``. Invalid to specify both ``src`` and ``group_src``. + use_batch (bool, optional): If True, use batch p2p operations instead of + regular recv operations. This avoids initializing 2-rank communicators and + uses existing entire group communicators. See batch_isend_irecv for usage and + assumptions. Default is ``False``. + + Returns: + Sender rank. -1 if rank is not part of the group. If rank is part of the group, + ``object_list`` will contain the sent objects from ``src`` rank. + + .. note:: For NCCL-based process groups, internal tensor representations + of objects must be moved to the GPU device before communication takes + place. In this case, the device used is given by + ``torch.cuda.current_device()`` and it is the user's responsibility to + ensure that this is set so that each rank has an individual GPU, via + ``torch.cuda.set_device()``. + + .. warning:: + Object collectives have a number of serious performance and scalability + limitations. See :ref:`object_collectives` for details. + + .. warning:: + :func:`recv_object_list` uses ``pickle`` module implicitly, which + is known to be insecure. It is possible to construct malicious pickle + data which will execute arbitrary code during unpickling. Only call this + function with data you trust. + + .. warning:: + Calling :func:`recv_object_list` with GPU tensors is not well supported + and inefficient as it incurs GPU -> CPU transfer since tensors would be + pickled. Please consider using :func:`recv` instead.
+ + Example:: + >>> # xdoctest: +SKIP("need process group init") + >>> # Note: Process group initialization omitted on each rank. + >>> import torch.distributed as dist + >>> # Assumes backend is not NCCL + >>> device = torch.device("cpu") + >>> if dist.get_rank() == 0: + >>> # Assumes world_size of 2. + >>> objects = ["foo", 12, {1: 2}] # any picklable object + >>> dist.send_object_list(objects, dst=1, device=device) + >>> else: + >>> objects = [None, None, None] + >>> dist.recv_object_list(objects, src=0, device=device) + >>> objects + ['foo', 12, {1: 2}] + """ + group = _group_or_default_group(group) + group_src = _canonicalize_group_rank(group, src, group_src) + _check_not_self_rank(group, group_src, "source") + + if _rank_not_in_group(group): + _warn_not_in_group("recv_object_list") + return -1 + + # Current device selection. + # To preserve backwards compatibility, ``device`` is default to ``None`` + # in which case we run current logic of device selection, i.e. + # ``current_device`` is CUDA if backend is NCCL otherwise CPU device. In the + # case it is not ``None`` we move the size and object tensors to be + # received to this device. + current_device = device or _get_object_coll_device(group) + object_sizes_tensor = torch.empty( + len(object_list), dtype=torch.long, device=current_device + ) + + # Receive object sizes + if use_batch: + work = batch_isend_irecv( + [ + P2POp( + irecv, + object_sizes_tensor, + group_peer=group_src, + group=group, + ) + ] + ).pop() + work.wait() + rank_sizes = get_global_rank(group, group_src) + else: + rank_sizes = recv(object_sizes_tensor, group=group, group_src=group_src) + + # Tensor to receive serialized objects into. 
+ object_tensor = torch.empty( # type: ignore[call-overload] + torch.sum(object_sizes_tensor).item(), # type: ignore[arg-type] + dtype=torch.uint8, + device=current_device, + ) + + if use_batch: + work = batch_isend_irecv( + [ + P2POp( + irecv, + object_tensor, + group_peer=group_src, + group=group, + ) + ] + ).pop() + work.wait() + rank_objects = get_global_rank(group, group_src) + else: + rank_objects = recv(object_tensor, group=group, group_src=group_src) + if rank_sizes != rank_objects: + raise AssertionError("Mismatch in return ranks for object sizes and objects.") + # Deserialize objects using their stored sizes. + offset = 0 + for i, obj_size in enumerate(object_sizes_tensor): + obj_view = object_tensor[offset : offset + obj_size] + obj_view = obj_view.type(torch.uint8) + offset += obj_size + object_list[i] = _tensor_to_object(obj_view, obj_size, group) + return rank_objects + + +@_exception_logger +def broadcast_object_list( + object_list: list[Any], + src: int | None = None, + group: ProcessGroup | None = None, + device: torch.device | None = None, + group_src: int | None = None, +): + """ + Broadcasts picklable objects in ``object_list`` to the whole group. + + Similar to :func:`broadcast`, but Python objects can be passed in. + Note that all objects in ``object_list`` must be picklable in order to be + broadcasted. + + Args: + object_list (List[Any]): List of input objects to broadcast. + Each object must be picklable. Only objects on the ``src`` rank will + be broadcast, but each rank must provide lists of equal sizes. + src (int): Source rank from which to broadcast ``object_list``. + Source rank is based on global process group (regardless of ``group`` argument) + group: (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. Default is ``None``. 
+ device (``torch.device``, optional): If not None, the objects are + serialized and converted to tensors which are moved to the + ``device`` before broadcasting. Default is ``None``. + group_src (int): Source rank on ``group``. Invalid to specify both ``group_src`` + and ``src``; if neither is given, ``src`` defaults to global rank 0. + + Returns: + ``None``. If rank is part of the group, ``object_list`` will contain the + broadcasted objects from ``src`` rank. + + .. note:: For NCCL-based process groups, internal tensor representations + of objects must be moved to the GPU device before communication takes + place. In this case, the device used is given by + ``torch.cuda.current_device()`` and it is the user's responsibility to + ensure that this is set so that each rank has an individual GPU, via + ``torch.cuda.set_device()``. + + .. note:: Note that this API differs slightly from the :func:`broadcast` + collective since it does not provide an ``async_op`` handle and thus + will be a blocking call. + + .. warning:: + Object collectives have a number of serious performance and scalability + limitations. See :ref:`object_collectives` for details. + + .. warning:: + :func:`broadcast_object_list` uses ``pickle`` module implicitly, which + is known to be insecure. It is possible to construct malicious pickle + data which will execute arbitrary code during unpickling. Only call this + function with data you trust. + + .. warning:: + Calling :func:`broadcast_object_list` with GPU tensors is not well supported + and inefficient as it incurs GPU -> CPU transfer since tensors would be + pickled. Please consider using :func:`broadcast` instead. + + Example:: + >>> # xdoctest: +SKIP("need process group init") + >>> # Note: Process group initialization omitted on each rank. + >>> import torch.distributed as dist + >>> if dist.get_rank() == 0: + >>> # Assumes world_size of 3.
+ >>> objects = ["foo", 12, {1: 2}] # any picklable object + >>> else: + >>> objects = [None, None, None] + >>> # Assumes backend is not NCCL + >>> device = torch.device("cpu") + >>> dist.broadcast_object_list(objects, src=0, device=device) + >>> objects + ['foo', 12, {1: 2}] + """ + group = _group_or_default_group(group) + if src is None and group_src is None: + src = 0 + group_src = _canonicalize_group_rank(group, src, group_src, return_global=False) + if _rank_not_in_group(group): + _warn_not_in_group("broadcast_object_list") + return + + # Current device selection. + # To preserve backwards compatibility, ``device`` is default to ``None`` + # in which case we run current logic of device selection, i.e. + # ``current_device`` is CUDA if backend is NCCL otherwise CPU device. In the + # case it is not ``None`` we move the size and object tensors to be + # broadcasted to this device. + current_device = device or _get_object_coll_device(group) + my_group_rank = group.rank() + # Serialize object_list elements to tensors on src rank. + if my_group_rank == group_src: + tensor_list, size_list = zip( + *[_object_to_tensor(obj, current_device, group) for obj in object_list] + ) + object_sizes_tensor = torch.cat(size_list) + else: + object_sizes_tensor = torch.empty( + len(object_list), dtype=torch.long, device=current_device + ) + + # Broadcast object sizes + broadcast(object_sizes_tensor, group_src=group_src, group=group) + + # Concatenate and broadcast serialized object tensors + # Note: torch.cat will do an extra memory copy to the current device, if the tensor_list + # has only one element, we can skip the copy. 
+ if my_group_rank == group_src: + if len(tensor_list) == 1: # type: ignore[possibly-undefined] + # pyrefly: ignore [unbound-name] + object_tensor = tensor_list[0] + else: + # pyrefly: ignore [unbound-name] + object_tensor = torch.cat(tensor_list) + else: + object_tensor = torch.empty( # type: ignore[call-overload] + torch.sum(object_sizes_tensor).item(), # type: ignore[arg-type] + dtype=torch.uint8, + device=current_device, + ) + + broadcast(object_tensor, group_src=group_src, group=group) + # Deserialize objects using their stored sizes. + offset = 0 + if my_group_rank != group_src: + for i, obj_size in enumerate(object_sizes_tensor): + obj_view = object_tensor[offset : offset + obj_size] + obj_view = obj_view.type(torch.uint8) + offset += obj_size + object_list[i] = _tensor_to_object(obj_view, obj_size, group) + + +@_exception_logger +def scatter_object_list( + scatter_object_output_list: list[Any], + scatter_object_input_list: list[Any] | None = None, + src: int | None = None, + group: ProcessGroup | None = None, + group_src: int | None = None, +): + """ + Scatters picklable objects in ``scatter_object_input_list`` to the whole group. + + Similar to :func:`scatter`, but Python objects can be passed in. On + each rank, the scattered object will be stored as the first element of + ``scatter_object_output_list``. Note that all objects in + ``scatter_object_input_list`` must be picklable in order to be scattered. + + Args: + scatter_object_output_list (List[Any]): Non-empty list whose first + element will store the object scattered to this rank. + scatter_object_input_list (List[Any], optional): List of input objects to scatter. + Each object must be picklable. Only objects on the ``src`` rank will + be scattered, and the argument can be ``None`` for non-src ranks. + src (int): Source rank from which to scatter ``scatter_object_input_list``. + Source rank is based on global process group (regardless of ``group`` argument). 
+ (If both ``src`` and ``group_src`` are None, default is global rank 0) + group: (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. Default is ``None``. + group_src (int, optional): Source rank on ``group``. Invalid to specify both ``src`` and ``group_src`` + + Returns: + ``None``. If rank is part of the group, ``scatter_object_output_list`` + will have its first element set to the scattered object for this rank. + + .. note:: Note that this API differs slightly from the scatter collective + since it does not provide an ``async_op`` handle and thus will be a + blocking call. + + .. warning:: + Object collectives have a number of serious performance and scalability + limitations. See :ref:`object_collectives` for details. + + .. warning:: + :func:`scatter_object_list` uses ``pickle`` module implicitly, which + is known to be insecure. It is possible to construct malicious pickle + data which will execute arbitrary code during unpickling. Only call this + function with data you trust. + + .. warning:: + Calling :func:`scatter_object_list` with GPU tensors is not well supported + and inefficient as it incurs GPU -> CPU transfer since tensors would be + pickled. Please consider using :func:`scatter` instead. + + Example:: + >>> # xdoctest: +SKIP("need process group init") + >>> # Note: Process group initialization omitted on each rank. + >>> import torch.distributed as dist + >>> if dist.get_rank() == 0: + >>> # Assumes world_size of 3. + >>> objects = ["foo", 12, {1: 2}] # any picklable object + >>> else: + >>> # Can be any list on non-src ranks, elements are not used. + >>> objects = [None, None, None] + >>> output_list = [None] + >>> dist.scatter_object_list(output_list, objects, src=0) + >>> # Rank i gets objects[i]. 
For example, on rank 2: + >>> output_list + [{1: 2}] + """ + group = _group_or_default_group(group) + if src is None and group_src is None: + src = 0 + group_src = _canonicalize_group_rank(group, src, group_src, return_global=False) + if _rank_not_in_group(group): + _warn_not_in_group("scatter_object_list") + return + + if ( + not isinstance(scatter_object_output_list, list) + or len(scatter_object_output_list) < 1 + ): + raise ValueError( + "Expected argument scatter_object_output_list to be a list of size at least 1." + ) + + my_group_rank = group.rank() + pg_device = _get_object_coll_device(group) + if my_group_rank == group_src: + if scatter_object_input_list is None: + raise ValueError( + "source rank must provide non-None scatter_object_input_list" + ) + tensor_list, tensor_sizes = zip( + *[ + _object_to_tensor(obj, pg_device, group) + for obj in scatter_object_input_list + ] + ) + tensor_list, tensor_sizes = list(tensor_list), list(tensor_sizes) + + # Src rank broadcasts the maximum tensor size. This is because all ranks are + # expected to call into scatter() with equal-sized tensors. 
+ max_tensor_size = max(tensor_sizes) # type: ignore[possibly-undefined] + for tensor in tensor_list: # type: ignore[possibly-undefined] + tensor.resize_(max_tensor_size) + else: + max_tensor_size = torch.tensor([0], dtype=torch.long, device=pg_device) + broadcast(max_tensor_size, group_src=group_src, group=group) + + # Scatter actual serialized objects + # pyrefly: ignore [no-matching-overload] + output_tensor = torch.empty( + max_tensor_size.item(), dtype=torch.uint8, device=pg_device + ) + scatter( + output_tensor, + scatter_list=None if my_group_rank != group_src else tensor_list, # type: ignore[possibly-undefined] + group_src=group_src, + group=group, + ) + + # Scatter per-object sizes to trim tensors when deserializing back to object + obj_tensor_size = torch.tensor([0], dtype=torch.long, device=pg_device) + scatter( + obj_tensor_size, + scatter_list=None if my_group_rank != group_src else tensor_sizes, # type: ignore[possibly-undefined] + group_src=group_src, + group=group, + ) + + # Deserialize back to object + scatter_object_output_list[0] = _tensor_to_object( + output_tensor, obj_tensor_size, group + ) + + +@_exception_logger +def all_gather(tensor_list, tensor, group=None, async_op=False): + """ + Gathers tensors from the whole group in a list. + + Complex and uneven sized tensors are supported. + + Args: + tensor_list (list[Tensor]): Output list. It should contain + correctly-sized tensors to be used for output of the collective. + Uneven sized tensors are supported. + tensor (Tensor): Tensor to be broadcast from current process. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + async_op (bool, optional): Whether this op should be an async op + + Returns: + Async work handle, if async_op is set to True. + None, if not async_op or if not part of the group + + Examples: + >>> # xdoctest: +SKIP("need process group init") + >>> # All tensors below are of torch.int64 dtype. 
+ >>> # We have 2 process groups, 2 ranks. + >>> device = torch.device(f"cuda:{rank}") + >>> tensor_list = [ + ... torch.zeros(2, dtype=torch.int64, device=device) for _ in range(2) + ... ] + >>> tensor_list + [tensor([0, 0], device='cuda:0'), tensor([0, 0], device='cuda:0')] # Rank 0 + [tensor([0, 0], device='cuda:1'), tensor([0, 0], device='cuda:1')] # Rank 1 + >>> tensor = torch.arange(2, dtype=torch.int64, device=device) + 1 + 2 * rank + >>> tensor + tensor([1, 2], device='cuda:0') # Rank 0 + tensor([3, 4], device='cuda:1') # Rank 1 + >>> dist.all_gather(tensor_list, tensor) + >>> tensor_list + [tensor([1, 2], device='cuda:0'), tensor([3, 4], device='cuda:0')] # Rank 0 + [tensor([1, 2], device='cuda:1'), tensor([3, 4], device='cuda:1')] # Rank 1 + + >>> # All tensors below are of torch.cfloat dtype. + >>> # We have 2 process groups, 2 ranks. + >>> tensor_list = [ + ... torch.zeros(2, dtype=torch.cfloat, device=device) for _ in range(2) + ... ] + >>> tensor_list + [tensor([0.+0.j, 0.+0.j], device='cuda:0'), tensor([0.+0.j, 0.+0.j], device='cuda:0')] # Rank 0 + [tensor([0.+0.j, 0.+0.j], device='cuda:1'), tensor([0.+0.j, 0.+0.j], device='cuda:1')] # Rank 1 + >>> tensor = torch.tensor( + ... [1 + 1j, 2 + 2j], dtype=torch.cfloat, device=device + ... ) + 2 * rank * (1 + 1j) + >>> tensor + tensor([1.+1.j, 2.+2.j], device='cuda:0') # Rank 0 + tensor([3.+3.j, 4.+4.j], device='cuda:1') # Rank 1 + >>> dist.all_gather(tensor_list, tensor) + >>> tensor_list + [tensor([1.+1.j, 2.+2.j], device='cuda:0'), tensor([3.+3.j, 4.+4.j], device='cuda:0')] # Rank 0 + [tensor([1.+1.j, 2.+2.j], device='cuda:1'), tensor([3.+3.j, 4.+4.j], device='cuda:1')] # Rank 1 + + """ + # Dynamo has built-in logic to map legacy distributed ops to functional collectives. + # Let's redirect to a torch function mode that can mimic this logic outside Dynamo + # (e.g., non-strict export implements such a torch function mode). 
+ relevant_args = (tensor,) + if has_torch_function(relevant_args): + return handle_torch_function( + all_gather, + relevant_args, + tensor_list, + tensor, + group=group, + async_op=async_op, + ) + + _check_tensor_list(tensor_list, "tensor_list") + _check_single_tensor(tensor, "tensor") + _ensure_all_tensors_same_dtype(tensor_list, tensor) + if _rank_not_in_group(group): + _warn_not_in_group("all_gather") + return + + tensor_list = [ + t if not t.is_complex() else torch.view_as_real(t) for t in tensor_list + ] + tensor = tensor if not tensor.is_complex() else torch.view_as_real(tensor) + + group = group or _get_default_group() + opts = AllgatherOptions() + opts.asyncOp = async_op + work = group.allgather([tensor_list], [tensor], opts) + + if async_op: + return work + elif ( + work is not None + ): # Backward compatible with backends that don't sync at CPP level + work.wait() + # Otherwise, the backend has sync'ed at CPP level + + +@_exception_logger +def all_gather_into_tensor(output_tensor, input_tensor, group=None, async_op=False): + """ + Gather tensors from all ranks and put them in a single output tensor. + + This function requires all tensors to be the same size on each process. + + Args: + output_tensor (Tensor): Output tensor to accommodate tensor elements + from all ranks. It must be correctly sized to have one of the + following forms: + (i) a concatenation of all the input tensors along the primary + dimension; for definition of "concatenation", see ``torch.cat()``; + (ii) a stack of all the input tensors along the primary dimension; + for definition of "stack", see ``torch.stack()``. + Examples below may better explain the supported output forms. + input_tensor (Tensor): Tensor to be gathered from current rank. + Different from the ``all_gather`` API, the input tensors in this + API must have the same size across all ranks. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. 
+ async_op (bool, optional): Whether this op should be an async op + + Returns: + Async work handle, if async_op is set to True. + None, if not async_op or if not part of the group + + Examples: + >>> # xdoctest: +SKIP("need process group init") + >>> # All tensors below are of torch.int64 dtype and on CUDA devices. + >>> # We have two ranks. + >>> device = torch.device(f"cuda:{rank}") + >>> tensor_in = torch.arange(2, dtype=torch.int64, device=device) + 1 + 2 * rank + >>> tensor_in + tensor([1, 2], device='cuda:0') # Rank 0 + tensor([3, 4], device='cuda:1') # Rank 1 + >>> # Output in concatenation form + >>> tensor_out = torch.zeros(world_size * 2, dtype=torch.int64, device=device) + >>> dist.all_gather_into_tensor(tensor_out, tensor_in) + >>> tensor_out + tensor([1, 2, 3, 4], device='cuda:0') # Rank 0 + tensor([1, 2, 3, 4], device='cuda:1') # Rank 1 + >>> # Output in stack form + >>> tensor_out2 = torch.zeros(world_size, 2, dtype=torch.int64, device=device) + >>> dist.all_gather_into_tensor(tensor_out2, tensor_in) + >>> tensor_out2 + tensor([[1, 2], + [3, 4]], device='cuda:0') # Rank 0 + tensor([[1, 2], + [3, 4]], device='cuda:1') # Rank 1 + """ + # Dynamo has built-in logic to map legacy distributed ops to functional collectives. + # Let's redirect to a torch function mode that can mimic this logic outside Dynamo + # (e.g., non-strict export implements such a torch function mode). 
+ relevant_args = (input_tensor,) + if has_torch_function(relevant_args): + return handle_torch_function( + all_gather_into_tensor, + relevant_args, + output_tensor, + input_tensor, + group=group, + async_op=async_op, + ) + + _check_single_tensor(input_tensor, "input_tensor") + _check_single_tensor(output_tensor, "output_tensor") + if _rank_not_in_group(group): + _warn_not_in_group("all_gather_into_tensor") + return + + output_tensor = ( + output_tensor + if not output_tensor.is_complex() + else torch.view_as_real(output_tensor) + ) + input_tensor = ( + input_tensor + if not input_tensor.is_complex() + else torch.view_as_real(input_tensor) + ) + + opts = AllgatherOptions() + opts.asyncOp = async_op + + group = group or _get_default_group() + + if group in _world.pg_coalesce_state: + # We are in coalescing context, do not issue single operation, just append a collective representation + coll = _CollOp(all_gather_into_tensor, input_tensor, output_tensor) + _world.pg_coalesce_state[group].append(coll) + if async_op: + return _IllegalWork() + else: + return None + + work = group._allgather_base(output_tensor, input_tensor, opts) + + if async_op: + return work + elif ( + work is not None + ): # Backward compatible with backends that don't sync at CPP level + work.wait() + # Otherwise, the backend has sync'ed at CPP level + + +@_exception_logger +@deprecated( + "`torch.distributed._all_gather_base` is a private function and will be deprecated. " + "Please use `torch.distributed.all_gather_into_tensor` instead.", + category=FutureWarning, +) +def _all_gather_base(output_tensor, input_tensor, group=None, async_op: bool = False): + """ + Single tensor all gather. Gathers a single tensor from all ranks, and puts them in a single output tensor. + + Args: + output_tensor (Tensor): Output tensor. It should contain + correctly-sized tensors to be used for output of the collective. + input_tensor (Tensor): Tensor to be broadcast from current process. 
+ group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + async_op (bool, optional): Whether this op should be an async op + + Returns: + Async work handle, if async_op is set to True. + None, if not async_op or if not part of the group + + .. warning:: + `_all_gather_base` is a private function. Users should use + `all_gather_into_tensor` instead. + + """ + return all_gather_into_tensor(output_tensor, input_tensor, group, async_op) + + +@_exception_logger +@deprecated( + "`torch.distributed.all_gather_coalesced` will be deprecated. If you must use it, " + "please revisit our documentation later at " + "https://pytorch.org/docs/main/distributed.html#collective-functions", + category=FutureWarning, +) +def all_gather_coalesced( + output_tensor_lists, input_tensor_list, group=None, async_op: bool = False +): + """ + Gathers input tensors from the whole group in a list in a coalesced manner. + + Complex tensors are supported. + + Args: + output_tensor_lists (list[list[Tensor]]): Output list. It should contain + correctly-sized tensors to be used for output of the collective. + input_tensor_list (list[Tensor]): Tensors to be broadcast from + current process. At least one tensor has to be non empty. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + async_op (bool, optional): Whether this op should be an async op. + + Returns: + Async work handle, if async_op is set to True. + None, if not async_op or if not part of the group + + Example: + we have 2 process groups, 2 ranks. 
+ rank 0 passes: + input_tensor_list = [[[1, 1], [1, 1]], [2], [3, 3]] + output_tensor_lists = + [[[[-1, -1], [-1, -1]], [-1], [-1, -1]], + [[[-1, -1], [-1, -1]], [-1], [-1, -1]]] + rank 1 passes: + input_tensor_list = [[[3, 3], [3, 3]], [5], [1, 1]] + output_tensor_lists = + [[[[-1, -1], [-1, -1]], [-1], [-1, -1]], + [[[-1, -1], [-1, -1]], [-1], [-1, -1]]] + both rank 0 and 1 get: + output_tensor_lists = + [[[1, 1], [1, 1]], [2], [3, 3]], + [[3, 3], [3, 3]], [5], [1, 1]]]. + + WARNING: at this time individual shape checking is not implemented across nodes. + For example, if the rank 0 node passes [torch.rand(4), torch.rand(2)] and the + rank 1 node passes [torch.rand(2), torch.rand(2), torch.rand(2)], the + all_gather_coalesced operation will proceed without complaint and return + erroneous outputs. This lack of shape checking results in significant + performance improvements but users of this function should take extra care + to ensure that each node passes in tensors whose shapes match across nodes. + """ + # We only check basic compatibility with C++ params here, C++ code will + # do shape and type checking. 
+ if _rank_not_in_group(group): + _warn_not_in_group("all_gather_coalesced") + return + _check_tensor_list(input_tensor_list, "input_tensor_list") + _ensure_all_tensors_same_dtype(input_tensor_list) + if not isinstance(output_tensor_lists, list): + raise TypeError( + "Invalid function argument: output_tensor_lists should be a list" + ) + for output_tensor_list in output_tensor_lists: + _check_tensor_list(output_tensor_list, "output_tensor_lists") + _ensure_all_tensors_same_dtype(output_tensor_list) + + output_tensor_lists = [ + [t if not t.is_complex() else torch.view_as_real(t) for t in l] + for l in output_tensor_lists + ] + input_tensor_list = [ + t if not t.is_complex() else torch.view_as_real(t) for t in input_tensor_list + ] + + group = group or _get_default_group() + opts = AllgatherOptions() + opts.asyncOp = async_op + work = group.allgather_coalesced(output_tensor_lists, input_tensor_list, opts) + + if async_op: + return work.get_future() + elif ( + work is not None + ): # Backward compatible with backends that don't sync at CPP level + work.wait() + # Otherwise, the backend has sync'ed at CPP level + + +def _validate_output_list_for_rank(my_rank: int, dst: int, gather_list): + if dst == my_rank: + if not gather_list: + raise ValueError( + "Argument ``gather_list`` must be specified on destination rank." + ) + elif gather_list: + raise ValueError( + "Argument ``gather_list`` must NOT be specified on non-destination ranks." + ) + + +@_exception_logger +def gather( + tensor: torch.Tensor, + gather_list: list[torch.Tensor] | None = None, + dst: int | None = None, + group: ProcessGroup | None = None, + async_op: bool = False, + group_dst: int | None = None, +): + """ + Gathers a list of tensors in a single process. + + This function requires all tensors to be the same size on each process. + + Args: + tensor (Tensor): Input tensor. 
+ gather_list (list[Tensor], optional): List of appropriately, + same-sized tensors to use for gathered data + (default is None, must be specified on the destination rank) + dst (int, optional): Destination rank on global process group (regardless of ``group`` argument). + (If both ``dst`` and ``group_dst`` are None, default is global rank 0) + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + async_op (bool, optional): Whether this op should be an async op + group_dst (int, optional): Destination rank on ``group``. Invalid to specify both ``dst`` and ``group_dst`` + + Returns: + Async work handle, if async_op is set to True. + None, if not async_op or if not part of the group + + .. note:: Note that all Tensors in gather_list must have the same size. + + Example:: + >>> # xdoctest: +SKIP("no rank") + >>> # We have 2 process groups, 2 ranks. + >>> tensor_size = 2 + >>> device = torch.device(f'cuda:{rank}') + >>> tensor = torch.ones(tensor_size, device=device) + rank + >>> if dist.get_rank() == 0: + >>> gather_list = [torch.zeros_like(tensor, device=device) for i in range(2)] + >>> else: + >>> gather_list = None + >>> dist.gather(tensor, gather_list, dst=0) + >>> # Rank 0 gets gathered data. + >>> gather_list + [tensor([1., 1.], device='cuda:0'), tensor([2., 2.], device='cuda:0')] # Rank 0 + None # Rank 1 + + """ + _check_single_tensor(tensor, "tensor") + + # Parameter ``gather_list`` may be left unspecified on non-dst ranks. 
+ if gather_list: + _check_tensor_list(gather_list, "gather_list") + else: + gather_list = [] + _ensure_all_tensors_same_dtype(tensor, gather_list) + group = _group_or_default_group(group) + if _rank_not_in_group(group): + _warn_not_in_group("gather") + return + if dst is None and group_dst is None: + dst = 0 + group_dst = _canonicalize_group_rank(group, dst, group_dst, return_global=False) + my_group_rank = group.rank() + _validate_output_list_for_rank(my_group_rank, group_dst, gather_list) + output_tensors = [gather_list] if group_dst == my_group_rank else [] + input_tensors = [tensor] + + opts = GatherOptions() + opts.rootRank = group_dst + opts.asyncOp = async_op + work = group.gather(output_tensors, input_tensors, opts) + + if async_op: + return work + elif ( + work is not None + ): # Backward compatible with backends that don't sync at CPP level + work.wait() + # Otherwise, the backend has sync'ed at CPP level + + +@_exception_logger +def scatter( + tensor: torch.Tensor, + scatter_list: list[torch.Tensor] | None = None, + src: int | None = None, + group: ProcessGroup | None = None, + async_op: bool = False, + group_src: int | None = None, +): + """ + Scatters a list of tensors to all processes in a group. + + Each process will receive exactly one tensor and store its data in the + ``tensor`` argument. + + Complex tensors are supported. + + Args: + tensor (Tensor): Output tensor. + scatter_list (list[Tensor]): List of tensors to scatter (default is + None, must be specified on the source rank) + src (int): Source rank on global process group (regardless of ``group`` argument). + (If both ``src`` and ``group_src`` are None, default is global rank 0) + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + async_op (bool, optional): Whether this op should be an async op + group_src (int, optional): Source rank on ``group``. 
Invalid to specify both ``src`` and ``group_src`` + + Returns: + Async work handle, if async_op is set to True. + None, if not async_op or if not part of the group + + .. note:: Note that all Tensors in scatter_list must have the same size. + + Example:: + >>> # xdoctest: +SKIP("need process group init") + >>> # Note: Process group initialization omitted on each rank. + >>> import torch.distributed as dist + >>> tensor_size = 2 + >>> device = torch.device(f'cuda:{rank}') + >>> output_tensor = torch.zeros(tensor_size, device=device) + >>> if dist.get_rank() == 0: + >>> # Assumes world_size of 2. + >>> # Only tensors, all of which must be the same size. + >>> t_ones = torch.ones(tensor_size, device=device) + >>> t_fives = torch.ones(tensor_size, device=device) * 5 + >>> scatter_list = [t_ones, t_fives] + >>> else: + >>> scatter_list = None + >>> dist.scatter(output_tensor, scatter_list, src=0) + >>> # Rank i gets scatter_list[i]. + >>> output_tensor + tensor([1., 1.], device='cuda:0') # Rank 0 + tensor([5., 5.], device='cuda:1') # Rank 1 + + """ + _check_single_tensor(tensor, "tensor") + # Parameter ``scatter_list`` may be left unspecified on non-src ranks. + if scatter_list: + _check_tensor_list(scatter_list, "scatter_list") + else: + scatter_list = [] + _ensure_all_tensors_same_dtype(tensor, scatter_list) + group = _group_or_default_group(group) + if src is None and group_src is None: + src = 0 + group_src = _canonicalize_group_rank(group, src, group_src, return_global=False) + if _rank_not_in_group(group): + _warn_not_in_group("scatter") + return + scatter_list = [ + t if not t.is_complex() else torch.view_as_real(t) for t in scatter_list + ] + tensor = tensor if not tensor.is_complex() else torch.view_as_real(tensor) + + my_group_rank = group.rank() + if group_src == my_group_rank: + if not scatter_list: + raise ValueError( + "Argument ``scatter_list`` must be specified on source rank." 
+ ) + input_tensors = [scatter_list] + output_tensors = [tensor] + else: + if scatter_list: + raise ValueError( + "Argument ``scatter_list`` must NOT be specified on non-source ranks." + ) + input_tensors = [] + output_tensors = [tensor] + + opts = ScatterOptions() + opts.rootRank = group_src + opts.asyncOp = async_op + work = group.scatter(output_tensors, input_tensors, opts) + + if async_op: + return work + elif ( + work is not None + ): # Backward compatible with backends that don't sync at CPP level + work.wait() + # Otherwise, the backend has sync'ed at CPP level + + +@_exception_logger +def reduce_scatter( + output, input_list, op=ReduceOp.SUM, group=None, async_op: bool = False +): + """ + Reduces, then scatters a list of tensors to all processes in a group. + + Args: + output (Tensor): Output tensor. + input_list (list[Tensor]): List of tensors to reduce and scatter. + op (optional): One of the values from + ``torch.distributed.ReduceOp`` + enum. Specifies an operation used for element-wise reductions. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + async_op (bool, optional): Whether this op should be an async op. + + Returns: + Async work handle, if async_op is set to True. + None, if not async_op or if not part of the group. 
+ + """ + _check_single_tensor(output, "output") + _check_tensor_list(input_list, "input_list") + _ensure_all_tensors_same_dtype(output, input_list) + if _rank_not_in_group(group): + _warn_not_in_group("reduce_scatter") + return + + opts = ReduceScatterOptions() + opts.reduceOp = op + opts.asyncOp = async_op + + group = group or _get_default_group() + work = group.reduce_scatter([output], [input_list], opts) + + if async_op: + return work + elif ( + work is not None + ): # Backward compatible with backends that don't sync at CPP level + work.wait() + # Otherwise, the backend has sync'ed at CPP level + + +@_exception_logger +def reduce_scatter_tensor(output, input, op=ReduceOp.SUM, group=None, async_op=False): + """ + Reduces, then scatters a tensor to all ranks in a group. + + Args: + output (Tensor): Output tensor. It should have the same size across all + ranks. + input (Tensor): Input tensor to be reduced and scattered. Its size + should be output tensor size times the world size. The input tensor + can have one of the following shapes: + (i) a concatenation of the output tensors along the primary + dimension, or + (ii) a stack of the output tensors along the primary dimension. + For definition of "concatenation", see ``torch.cat()``. + For definition of "stack", see ``torch.stack()``. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + async_op (bool, optional): Whether this op should be an async op. + + Returns: + Async work handle, if async_op is set to True. + None, if not async_op or if not part of the group. + + Examples: + >>> # xdoctest: +SKIP("need process group init") + >>> # All tensors below are of torch.int64 dtype and on CUDA devices. + >>> # We have two ranks. 
+ >>> device = torch.device(f"cuda:{rank}") + >>> tensor_out = torch.zeros(2, dtype=torch.int64, device=device) + >>> # Input in concatenation form + >>> tensor_in = torch.arange(world_size * 2, dtype=torch.int64, device=device) + >>> tensor_in + tensor([0, 1, 2, 3], device='cuda:0') # Rank 0 + tensor([0, 1, 2, 3], device='cuda:1') # Rank 1 + >>> dist.reduce_scatter_tensor(tensor_out, tensor_in) + >>> tensor_out + tensor([0, 2], device='cuda:0') # Rank 0 + tensor([4, 6], device='cuda:1') # Rank 1 + >>> # Input in stack form + >>> tensor_in = torch.reshape(tensor_in, (world_size, 2)) + >>> tensor_in + tensor([[0, 1], + [2, 3]], device='cuda:0') # Rank 0 + tensor([[0, 1], + [2, 3]], device='cuda:1') # Rank 1 + >>> dist.reduce_scatter_tensor(tensor_out, tensor_in) + >>> tensor_out + tensor([0, 2], device='cuda:0') # Rank 0 + tensor([4, 6], device='cuda:1') # Rank 1 + + """ + # Dynamo has built-in logic to map legacy distributed ops to functional collectives. + # Let's redirect to a torch function mode that can mimic this logic outside Dynamo + # (e.g., non-strict export implements such a torch function mode). 
+ relevant_args = (input,) + if has_torch_function(relevant_args): + return handle_torch_function( + reduce_scatter_tensor, + relevant_args, + output, + input, + op=op, + group=group, + async_op=async_op, + ) + + _check_single_tensor(output, "output") + _check_single_tensor(input, "input") + + if _rank_not_in_group(group): + _warn_not_in_group("reduce_scatter_tensor") + return + + opts = ReduceScatterOptions() + opts.reduceOp = op + opts.asyncOp = async_op + + group = group or _get_default_group() + + # Check if we are in coalescing context + # If we are, do not issue single operation, just append a collective representation + if group in _world.pg_coalesce_state: + coll = _CollOp(reduce_scatter_tensor, input, output, op, None) + _world.pg_coalesce_state[group].append(coll) + if async_op: + return _IllegalWork() + else: + return None + + work = group._reduce_scatter_base(output, input, opts) + + if async_op: + return work + elif ( + work is not None + ): # Backward compatible with backends that don't sync at CPP level + work.wait() + # Otherwise, the backend has sync'ed at CPP level + + +@deprecated( + "`torch.distributed._reduce_scatter_base` is a private function and will be deprecated. " + "Please use `torch.distributed.reduce_scatter_tensor` instead.", + category=FutureWarning, +) +def _reduce_scatter_base( + output, input, op=ReduceOp.SUM, group=None, async_op: bool = False +): + """ + Reduces, then scatters a flattened tensor to all processes in a group. + + Args: + output (Tensor): Output tensor. + input (Tensor): Input tensor that is of size output tensor size times world size + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + async_op (bool, optional): Whether this op should be an async op. + + Returns: + Async work handle, if async_op is set to True. + None, if not async_op or if not part of the group. + + .. warning:: + `_reduce_scatter_base` is a private function. 
Users should use + `reduce_scatter_tensor` instead. + + """ + return reduce_scatter_tensor(output, input, op, group, async_op) + + +@_exception_logger +def all_to_all_single( + output, + input, + output_split_sizes=None, + input_split_sizes=None, + group=None, + async_op: bool = False, +): + """ + Split input tensor and then scatter the split list to all processes in a group. + + Later the received tensors are concatenated from all the processes in the group + and returned as a single output tensor. + + Complex tensors are supported. + + Args: + output (Tensor): Gathered concatenated output tensor. + input (Tensor): Input tensor to scatter. + output_split_sizes: (list[Int], optional): Output split sizes for dim 0 + if specified None or empty, dim 0 of ``output`` tensor must divide + equally by ``world_size``. + input_split_sizes: (list[Int], optional): Input split sizes for dim 0 + if specified None or empty, dim 0 of ``input`` tensor must divide + equally by ``world_size``. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + async_op (bool, optional): Whether this op should be an async op. + + Returns: + Async work handle, if async_op is set to True. + None, if not async_op or if not part of the group. + + .. warning:: + `all_to_all_single` is experimental and subject to change. 
+ + Examples: + >>> # xdoctest: +SKIP("Undefined rank") + >>> input = torch.arange(4) + rank * 4 + >>> input + tensor([0, 1, 2, 3]) # Rank 0 + tensor([4, 5, 6, 7]) # Rank 1 + tensor([8, 9, 10, 11]) # Rank 2 + tensor([12, 13, 14, 15]) # Rank 3 + >>> output = torch.empty([4], dtype=torch.int64) + >>> dist.all_to_all_single(output, input) + >>> output + tensor([0, 4, 8, 12]) # Rank 0 + tensor([1, 5, 9, 13]) # Rank 1 + tensor([2, 6, 10, 14]) # Rank 2 + tensor([3, 7, 11, 15]) # Rank 3 + + >>> # Essentially, it is similar to following operation: + >>> scatter_list = list(input.chunk(world_size)) + >>> gather_list = list(output.chunk(world_size)) + >>> for i in range(world_size): + >>> dist.scatter(gather_list[i], scatter_list if i == rank else [], src = i) + + >>> # Another example with uneven split + >>> input + tensor([0, 1, 2, 3, 4, 5]) # Rank 0 + tensor([10, 11, 12, 13, 14, 15, 16, 17, 18]) # Rank 1 + tensor([20, 21, 22, 23, 24]) # Rank 2 + tensor([30, 31, 32, 33, 34, 35, 36]) # Rank 3 + >>> input_splits + [2, 2, 1, 1] # Rank 0 + [3, 2, 2, 2] # Rank 1 + [2, 1, 1, 1] # Rank 2 + [2, 2, 2, 1] # Rank 3 + >>> output_splits + [2, 3, 2, 2] # Rank 0 + [2, 2, 1, 2] # Rank 1 + [1, 2, 1, 2] # Rank 2 + [1, 2, 1, 1] # Rank 3 + >>> output = ... + >>> dist.all_to_all_single(output, input, output_splits, input_splits) + >>> output + tensor([ 0, 1, 10, 11, 12, 20, 21, 30, 31]) # Rank 0 + tensor([ 2, 3, 13, 14, 22, 32, 33]) # Rank 1 + tensor([ 4, 15, 16, 23, 34, 35]) # Rank 2 + tensor([ 5, 17, 18, 24, 36]) # Rank 3 + + + >>> # Another example with tensors of torch.cfloat type. + >>> input = torch.tensor( + ... [1 + 1j, 2 + 2j, 3 + 3j, 4 + 4j], dtype=torch.cfloat + ... 
) + 4 * rank * (1 + 1j) + >>> input + tensor([1+1j, 2+2j, 3+3j, 4+4j]) # Rank 0 + tensor([5+5j, 6+6j, 7+7j, 8+8j]) # Rank 1 + tensor([9+9j, 10+10j, 11+11j, 12+12j]) # Rank 2 + tensor([13+13j, 14+14j, 15+15j, 16+16j]) # Rank 3 + >>> output = torch.empty([4], dtype=torch.int64) + >>> dist.all_to_all_single(output, input) + >>> output + tensor([1+1j, 5+5j, 9+9j, 13+13j]) # Rank 0 + tensor([2+2j, 6+6j, 10+10j, 14+14j]) # Rank 1 + tensor([3+3j, 7+7j, 11+11j, 15+15j]) # Rank 2 + tensor([4+4j, 8+8j, 12+12j, 16+16j]) # Rank 3 + """ + # Dynamo has built-in logic to map legacy distributed ops to functional collectives. + # Let's redirect to a torch function mode that can mimic this logic outside Dynamo + # (e.g., non-strict export implements such a torch function mode). + relevant_args = (input,) + if has_torch_function(relevant_args): + return handle_torch_function( + all_to_all_single, + relevant_args, + output, + input, + output_split_sizes=output_split_sizes, + input_split_sizes=input_split_sizes, + group=group, + async_op=async_op, + ) + + if _rank_not_in_group(group): + _warn_not_in_group("all_to_all_single") + return + + opts = AllToAllOptions() + opts.asyncOp = async_op + _check_single_tensor(output, "output") + _check_single_tensor(input, "input") + _ensure_all_tensors_same_dtype(output, input) + + if input.is_complex(): + input = torch.view_as_real(input) + if output.is_complex(): + output = torch.view_as_real(output) + + output_split_sizes = [] if output_split_sizes is None else output_split_sizes + input_split_sizes = [] if input_split_sizes is None else input_split_sizes + + group = group or _get_default_group() + work = group.alltoall_base( + output, input, output_split_sizes, input_split_sizes, opts + ) + + if async_op: + return work + elif ( + work is not None + ): # Backward compatible with backends that don't sync at CPP level + work.wait() + # Otherwise, the backend has sync'ed at CPP level + + +@_exception_logger +def all_to_all( + output_tensor_list, 
input_tensor_list, group=None, async_op: bool = False +): + """ + Scatters list of input tensors to all processes in a group and return gathered list of tensors in output list. + + Complex tensors are supported. + + Args: + output_tensor_list (list[Tensor]): List of tensors to be gathered one + per rank. + input_tensor_list (list[Tensor]): List of tensors to scatter one per rank. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + async_op (bool, optional): Whether this op should be an async op. + + Returns: + Async work handle, if async_op is set to True. + None, if not async_op or if not part of the group. + + .. warning:: + `all_to_all` is experimental and subject to change. + + Examples: + >>> # xdoctest: +SKIP("Undefined rank") + >>> input = torch.arange(4) + rank * 4 + >>> input = list(input.chunk(4)) + >>> input + [tensor([0]), tensor([1]), tensor([2]), tensor([3])] # Rank 0 + [tensor([4]), tensor([5]), tensor([6]), tensor([7])] # Rank 1 + [tensor([8]), tensor([9]), tensor([10]), tensor([11])] # Rank 2 + [tensor([12]), tensor([13]), tensor([14]), tensor([15])] # Rank 3 + >>> output = list(torch.empty([4], dtype=torch.int64).chunk(4)) + >>> dist.all_to_all(output, input) + >>> output + [tensor([0]), tensor([4]), tensor([8]), tensor([12])] # Rank 0 + [tensor([1]), tensor([5]), tensor([9]), tensor([13])] # Rank 1 + [tensor([2]), tensor([6]), tensor([10]), tensor([14])] # Rank 2 + [tensor([3]), tensor([7]), tensor([11]), tensor([15])] # Rank 3 + + >>> # Essentially, it is similar to following operation: + >>> scatter_list = input + >>> gather_list = output + >>> for i in range(world_size): + >>> dist.scatter(gather_list[i], scatter_list if i == rank else [], src=i) + + >>> input + tensor([0, 1, 2, 3, 4, 5]) # Rank 0 + tensor([10, 11, 12, 13, 14, 15, 16, 17, 18]) # Rank 1 + tensor([20, 21, 22, 23, 24]) # Rank 2 + tensor([30, 31, 32, 33, 34, 35, 36]) # Rank 3 + >>> input_splits + [2, 2, 1, 1] # Rank 
0 + [3, 2, 2, 2] # Rank 1 + [2, 1, 1, 1] # Rank 2 + [2, 2, 2, 1] # Rank 3 + >>> output_splits + [2, 3, 2, 2] # Rank 0 + [2, 2, 1, 2] # Rank 1 + [1, 2, 1, 2] # Rank 2 + [1, 2, 1, 1] # Rank 3 + >>> input = list(input.split(input_splits)) + >>> input + [tensor([0, 1]), tensor([2, 3]), tensor([4]), tensor([5])] # Rank 0 + [tensor([10, 11, 12]), tensor([13, 14]), tensor([15, 16]), tensor([17, 18])] # Rank 1 + [tensor([20, 21]), tensor([22]), tensor([23]), tensor([24])] # Rank 2 + [tensor([30, 31]), tensor([32, 33]), tensor([34, 35]), tensor([36])] # Rank 3 + >>> output = ... + >>> dist.all_to_all(output, input) + >>> output + [tensor([0, 1]), tensor([10, 11, 12]), tensor([20, 21]), tensor([30, 31])] # Rank 0 + [tensor([2, 3]), tensor([13, 14]), tensor([22]), tensor([32, 33])] # Rank 1 + [tensor([4]), tensor([15, 16]), tensor([23]), tensor([34, 35])] # Rank 2 + [tensor([5]), tensor([17, 18]), tensor([24]), tensor([36])] # Rank 3 + + >>> # Another example with tensors of torch.cfloat type. + >>> input = torch.tensor( + ... [1 + 1j, 2 + 2j, 3 + 3j, 4 + 4j], dtype=torch.cfloat + ... 
) + 4 * rank * (1 + 1j) + >>> input = list(input.chunk(4)) + >>> input + [tensor([1+1j]), tensor([2+2j]), tensor([3+3j]), tensor([4+4j])] # Rank 0 + [tensor([5+5j]), tensor([6+6j]), tensor([7+7j]), tensor([8+8j])] # Rank 1 + [tensor([9+9j]), tensor([10+10j]), tensor([11+11j]), tensor([12+12j])] # Rank 2 + [tensor([13+13j]), tensor([14+14j]), tensor([15+15j]), tensor([16+16j])] # Rank 3 + >>> output = list(torch.empty([4], dtype=torch.int64).chunk(4)) + >>> dist.all_to_all(output, input) + >>> output + [tensor([1+1j]), tensor([5+5j]), tensor([9+9j]), tensor([13+13j])] # Rank 0 + [tensor([2+2j]), tensor([6+6j]), tensor([10+10j]), tensor([14+14j])] # Rank 1 + [tensor([3+3j]), tensor([7+7j]), tensor([11+11j]), tensor([15+15j])] # Rank 2 + [tensor([4+4j]), tensor([8+8j]), tensor([12+12j]), tensor([16+16j])] # Rank 3 + + """ + if _rank_not_in_group(group): + _warn_not_in_group("all_to_all") + return + + opts = AllToAllOptions() + opts.asyncOp = async_op + _check_tensor_list(output_tensor_list, "output_tensor_list") + _check_tensor_list(input_tensor_list, "input_tensor_list") + _ensure_all_tensors_same_dtype(output_tensor_list, input_tensor_list) + + input_tensor_list = [ + t if not t.is_complex() else torch.view_as_real(t) for t in input_tensor_list + ] + output_tensor_list = [ + t if not t.is_complex() else torch.view_as_real(t) for t in output_tensor_list + ] + + group = group or _get_default_group() + work = group.alltoall(output_tensor_list, input_tensor_list, opts) + + if async_op: + return work + elif ( + work is not None + ): # Backward compatible with backends that don't sync at CPP level + work.wait() + # Otherwise, the backend has sync'ed at CPP level + + +@_exception_logger +def barrier( + group: ProcessGroup | None = GroupMember.WORLD, + async_op: bool = False, + device_ids=None, +): + """ + Synchronize all processes. 
+ + This collective blocks processes until the whole group enters this function, + if async_op is False, or if async work handle is called on wait(). + + Args: + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + async_op (bool, optional): Whether this op should be an async op + device_ids ([int], optional): List of device/GPU ids. Only one id is expected. + + Returns: + Async work handle, if async_op is set to True. + None, if not async_op or if not part of the group + + .. note:: `ProcessGroupNCCL` now blocks the cpu thread till the completion of the barrier collective. + .. note:: `ProcessGroupNCCL` implements barrier as an all_reduce of a 1-element tensor. A device must be chosen + for allocating this tensor. The device choice is made by checking in this order (1) the first device passed to + `device_ids` arg of barrier if not None, (2) the device passed to init_process_group if not None, (3) the device + that was first used with this process group, if another collective with tensor inputs has been performed, (4) + the device index indicated by the global rank mod local device count. + """ + group = group or _get_default_group() + + if _rank_not_in_group(group): + _warn_not_in_group("barrier") + return + + opts = BarrierOptions() + opts.asyncOp = async_op + # Detect the accelerator on the machine. If no accelerator is available, it + # returns CPU. 
+ device = torch._C._get_accelerator() + if isinstance(device_ids, list): + opts.device_ids = device_ids + # use only the first device id + # pyrefly: ignore [read-only] + opts.device = torch.device(device.type, device_ids[0]) + elif getattr(group, "bound_device_id", None) is not None: + # Use device id from `init_process_group(device_id=...)` + opts.device = group.bound_device_id # type: ignore[assignment] + elif device.type == "cpu" or _get_object_coll_device(group) == "cpu": + # pyrefly: ignore [read-only] + opts.device = torch.device("cpu") + else: + # Use the current device set by the user. If user did not set any, this + # may use default device 0, causing issues like hang or all processes + # creating context on device 0. + # pyrefly: ignore [read-only] + opts.device = device + if group.rank() == 0: + warnings.warn( # warn only once + "barrier(): using the device under current context. " + "You can specify `device_id` in `init_process_group` to mute this warning.", + stacklevel=2, + ) + + work = group.barrier(opts=opts) + + if async_op: + return work + elif ( + work is not None + ): # Backward compatible with backends that don't sync at CPP level + work.wait() + # Otherwise, the backend has sync'ed at CPP level + + +def monitored_barrier( + group: ProcessGroup | None = GroupMember.WORLD, + timeout=None, + wait_all_ranks: bool = False, +): + """ + Synchronize processes similar to ``torch.distributed.barrier``, but consider a configurable timeout. + + It is able to report ranks that did not pass this barrier within the provided timeout. + Specifically, for non-zero ranks, will block until a send/recv is processed from rank 0. + Rank 0 will block until all send /recv from other ranks are processed, and will report + failures for ranks that failed to respond in time. Note that if one rank does not reach the + monitored_barrier (for example due to a hang), all other ranks would fail in monitored_barrier. 
+ + This collective will block all processes/ranks in the group, until the + whole group exits the function successfully, making it useful for debugging + and synchronizing. However, it can have a performance impact and should only + be used for debugging or scenarios that require full synchronization points + on the host-side. For debugging purposes, this barrier can be inserted + before the application's collective calls to check if any ranks are + desynchronized. + + .. note:: Note that this collective is only supported with the GLOO backend. + + Args: + group (ProcessGroup, optional): The process group to work on. If + ``None``, the default process group will be used. + timeout (datetime.timedelta, optional): Timeout for monitored_barrier. + If ``None``, the default process group timeout will be used. + wait_all_ranks (bool, optional): Whether to collect all failed ranks or + not. By default, this is ``False`` and ``monitored_barrier`` on rank 0 + will throw on the first failed rank it encounters in order to fail + fast. By setting ``wait_all_ranks=True`` ``monitored_barrier`` will + collect all failed ranks and throw an error containing information + about all failed ranks. + + Returns: + ``None``. + + Example:: + >>> # xdoctest: +SKIP("need process group init") + >>> # Note: Process group initialization omitted on each rank. + >>> import torch.distributed as dist + >>> if dist.get_rank() != 1: + >>> dist.monitored_barrier() # Raises exception indicating that + >>> # rank 1 did not call into monitored_barrier. + >>> # Example with wait_all_ranks=True + >>> if dist.get_rank() == 0: + >>> dist.monitored_barrier(wait_all_ranks=True) # Raises exception + >>> # indicating that ranks 1, 2, ... world_size - 1 did not call into + >>> # monitored_barrier. + """ + # Need to call rank not in group before using the group, otherwise + # "Invalid process group" error is raised. 
+ if _rank_not_in_group(group): + _warn_not_in_group("monitored_barrier") + return + + if get_backend(group) != Backend.GLOO: + raise ValueError("monitored_barrier is only implemented for GLOO backend.") + + if timeout is None: + timeout = _get_default_timeout(get_backend(group)) + elif isinstance(timeout, float): + # TODO(whc) apparently some existing test case for monitored_barrier passes in a timeout in float format? + warnings.warn( + "Please specify timeout arg as a timedelta. " + f"Converting current value of {timeout} assuming it represents seconds", + stacklevel=2, + ) + timeout = timedelta(seconds=timeout) + + _check_valid_timeout(timeout) + + group_to_use = _get_default_group() if group is None else group + return group_to_use.monitored_barrier( # type:ignore[attr-defined] + timeout, wait_all_ranks=wait_all_ranks + ) + + +def _create_process_group_wrapper( + wrapped_pg: torch._C._distributed_c10d.Backend, + store_prefix: str, + store: Store, + rank: int, + world_size: int, + timeout: timedelta = default_pg_timeout, +): + if not _GLOO_AVAILABLE: + raise AssertionError("ProcessGroupWrapper unsupported without GLOO backend.") + + # (whc) this appears to be just for the gloo backend? if so, `default_pg_timeout` is appropriate... + + # Create a separate prefix store for the helper process group. + prefix = f"{PG_WRAPPER_STORE_PREFIX}:{store_prefix}" + store = PrefixStore(prefix, store) + helper_pg = ProcessGroupGloo(store, rank, world_size, timeout=timeout) + # Wrap the underlying pg with ProcessGroupWrapper. 
+ wrapped_pg = _ProcessGroupWrapper(wrapped_pg, helper_pg) + return wrapped_pg + + +# helper function for deterministically hashing a list of ranks to a unique +# string +def _hash_ranks_to_str(ranks: list[int]) -> str: + rank_join: str = "_".join(map(str, ranks)) + # In case there is already a PG with the same rank composition + unique_str = "_".join([rank_join, str(len(_world.pg_names))]) + return hashlib.sha1(bytes(unique_str, "utf-8"), usedforsecurity=False).hexdigest() + + +# Takes a list of ranks and computes an integer color +def _process_group_color(ranks: list[int]) -> int: + # Convert list to tuple to make it hashable + # pyrefly: ignore [bad-assignment] + ranks = tuple(ranks) + hash_value = hash(ranks) + # Split color must be: + # - a non-negative integer; + # - a type compatible with C's int because we are pybinding to the latter. + # Thus, we limit the hash value within c_int's max value. + max_c_int = 2 ** (ctypes.sizeof(ctypes.c_int) * 8 - 1) + color = abs(hash_value) % max_c_int + return color + + +def _process_group_name(ranks, use_hashed_name) -> GroupName: + # Create name for a process group. + global _world + if use_hashed_name: + pg_name = GroupName(_hash_ranks_to_str(ranks)) + else: + pg_name = GroupName(str(_world.group_count)) + _world.group_count += 1 + # TODO: why is group count incremented only in the else path? + return pg_name + + +def _get_backend_from_str(backend: str | None = None) -> Backend: + # Default to the same backend as the global process group + # if backend is not specified. + if not backend: + backend = get_backend(_get_default_group()) + return Backend(backend) + + +def _is_safe_to_split() -> bool: + """ + Checks if it is safe to split the any process group in the world. + This is only safe if the default pg has a bound device id, otherwise + users must be aware that a pg is only splittable after the first collective is + issued. 
+ """ + return _get_default_group().bound_device_id is not None + + +@_time_logger +def split_group( + parent_pg: ProcessGroup | None = None, + split_ranks: list | None = None, + timeout: timedelta | None = None, + pg_options: Any | None = None, + group_desc: str | None = None, +) -> ProcessGroup | None: + """ + Create a new process group split from the given parent process group. + + warning:: This is an experimental API. Only the ``NCCL`` and custom plugin backends + are supported. Other backends will raise an error. + Users of this API must guarantee that all ranks in the parent group enter this API call, + and the split of the sub groups is the same across all ranks in the parent group. + + Args: + parent_pg (ProcessGroup, optional): The parent process group. If None, + the default process group will be used. Users need to guarantee that + the parent group is fully initialized (e.g, communicators are initialized) + split_ranks (list[list[int]]): the split ranks, which is a list of list of ranks. + Users need to make sure the validity of the split ranks such that one + split (represented by one inner list of ints) does not overlap with any other split. + Note that the ranks in each split is the group rank (instead of global rank) + in the parent pg. For example, if the parent group has 4 ranks, and split_ranks can be + [[0, 1], [2, 3]]. Note [[0,1]] is also a valid split, in which case ranks 2, 3 would + return a non-group member. + timeout (timedelta, optional): see `init_process_group` for details and default value. + pg_options (ProcessGroupOptions, optional): Additional options need to be passed in during + the construction of specific process groups. i.e.``is_high_priority_stream`` + can be specified so that process group can pick up high priority cuda streams. + group_desc (str, optional): a string to describe the process group. 
+ + Returns: + ProcessGroup if the current rank is within one split/subgroup given by split_ranks, + or None if the current rank is not part of any split_ranks`. + + """ + # check inputs + if split_ranks is None or len(split_ranks) == 0: + raise ValueError("split_ranks cannot be None or empty") + + global _world + default_pg = _get_default_group() + device_id = default_pg.bound_device_id + if not device_id: + raise RuntimeError( + "No device associated with the default pg, not safe to split any process groups" + ) + global_rank = default_pg.rank() + global_world_size = default_pg.size() + + if not parent_pg: + parent_pg = default_pg + if parent_pg not in _world.pg_group_ranks: + raise ValueError(f"Group {parent_pg} is not registered") + + parent_global_to_group_ranks = _world.pg_group_ranks[parent_pg] + parent_group_to_global_ranks = { + group_rank: global_rank + for global_rank, group_rank in parent_global_to_group_ranks.items() + } + + if global_rank not in parent_global_to_group_ranks: + raise ValueError( + f"Global rank {global_rank} is not part of the parent group {parent_pg}" + ) + + parent_group_rank = parent_global_to_group_ranks[global_rank] + parent_backend = parent_pg._get_backend(torch.device("cuda")) + + # if the parent backend does not support splitting, raise error + # currently this API only support NCCL backend + if not parent_backend or not parent_backend.supports_splitting: + raise RuntimeError( + "No backend for the parent process group or its backend does not support splitting" + ) + + # set the group_desc before the color or no_cloor split + if hasattr(parent_backend, "comm_split_count") and group_desc is None: + group_desc = f"{parent_pg.group_desc}:split:{parent_backend.comm_split_count()}" # type: ignore[attr-defined] + + parent_backend_str, _ = _world.pg_map[parent_pg] + # same type of backend as the parent process group + backend = Backend(parent_backend_str) + backend_config = BackendConfig(backend) + + if pg_options is None: + # default 
pg_options same as the parent process group + # A deep copy is needed because if the option will be modified inside split + # and if we split parent pg multiple times, we will run into device out of bound error. + pg_options = copy.deepcopy(parent_backend.options) + + # this timeout defaulting/validation is used for all the new_groups/new_subgroups variants, + # which may just pass their timeout value (or None) + if timeout is None: + timeout = _get_default_timeout(backend) + _check_valid_timeout(timeout) + + # find my group of ranks and my group local rank in split_ranks + # for ranks which are not in any split PGs, we just pass in this the first split group + # and None will be returned. + my_group = split_ranks[0] + + for split_group in split_ranks: + if len(split_group) == 0: + raise ValueError("the split group cannot be empty") + if len(split_group) > global_world_size: + raise ValueError( + "the split group's size should be less or equal to the world_size set by init_process_group" + ) + if len(split_group) != len(set(split_group)): + raise ValueError("the split group cannot have duplicate ranks") + split_group = sorted(split_group) + if parent_group_rank in split_group: + my_group = split_group + break + + # use_hashed_name is True to ensure that subgroups have unique names. + # This is needed as some backends (e.g. Gloo) use the group name as a + # PrefixStore prefix for initialization of splits. Thus, names have to be + # unique to avoid key collisions. 
+ group_name = _process_group_name(my_group, use_hashed_name=True) + split_pg = parent_pg.split_group( + my_group, + timeout=timeout, + opts=pg_options, + group_name=group_name, + group_desc=group_desc, + ) + if split_pg is None: + return None + + global_ranks_in_my_group = [parent_group_to_global_ranks[rank] for rank in my_group] + split_pg.bound_device_id = device_id # type: ignore[union-attr] + split_backend_class = split_pg._get_backend(torch.device("cuda")) + split_backend_class._set_sequence_number_for_group() + if split_pg.group_name != group_name: + raise AssertionError( + f"group name should be set to {group_name} but got {split_pg.group_name}" + ) + + # update global state + _world.pg_map[split_pg] = (backend, split_pg.get_group_store()) + _world.pg_names[split_pg] = group_name + _register_process_group(group_name, split_pg) + _world.pg_backend_config[split_pg] = str(backend_config) + pg_tag = f"ptd:{group_name}" + _world.tags_to_pg.setdefault(pg_tag, []).append(split_pg) + _world.pg_to_tag[split_pg] = pg_tag + + # Create the global rank to group rank mapping + _world.pg_group_ranks[split_pg] = { + global_rank: group_rank + for group_rank, global_rank in enumerate(global_ranks_in_my_group) + } + + return split_pg + + +@_time_logger +def new_group( + ranks=None, + timeout=None, + backend=None, + pg_options=None, + use_local_synchronization: bool = False, + group_desc=None, + device_id: torch.device | None = None, +): + """ + Create a new distributed group. + + This function requires that all processes in the main group (i.e. all + processes that are part of the distributed job) enter this function, even + if they are not going to be members of the group. Additionally, groups + should be created in the same order in all processes. + + .. warning:: + Safe concurrent usage: + When using multiple process groups with the ``NCCL`` backend, the user + must ensure a globally consistent execution order of collectives across + ranks. 
+ + If multiple threads within a process issue collectives, explicit + synchronization is necessary to ensure consistent ordering. + + When using async variants of torch.distributed communication APIs, + a work object is returned and the communication kernel is + enqueued on a separate CUDA stream, allowing overlap of communication + and computation. Once one or more async ops have been issued on one process + group, they must be synchronized with other cuda streams by calling `work.wait()` + before using another process group. + + See `Using multiple NCCL communicators concurrently + ` + for more details. + + Args: + ranks (list[int]): List of ranks of group members. If ``None``, will be + set to all ranks. Default is ``None``. + timeout (timedelta, optional): see `init_process_group` for details and default value. + backend (str or Backend, optional): The backend to use. Depending on + build-time configurations, valid values are ``gloo`` and ``nccl``. + By default uses the same backend as the global group. This field + should be given as a lowercase string (e.g., ``"gloo"``), which can + also be accessed via :class:`Backend` attributes (e.g., + ``Backend.GLOO``). If ``None`` is passed in, the backend + corresponding to the default process group will be used. Default is + ``None``. + pg_options (ProcessGroupOptions, optional): process group options + specifying what additional options need to be passed in during + the construction of specific process groups. i.e. for the ``nccl`` + backend, ``is_high_priority_stream`` can be specified so that + process group can pick up high priority cuda streams. For other available options to config nccl, + See https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/types.html#ncclconfig-tuse_local_synchronization + (bool, optional): perform a group-local barrier at the end of the process group creation. + This is different in that non-member ranks don't need to call into API and don't + join the barrier. 
+ group_desc (str, optional): a string to describe the process group. + device_id (torch.device, optional): a single, specific device + to "bind" this process to, The `new_group` call will try to initialize + a communication backend immediately for the device if this field is given. + + Returns: + A handle of distributed group that can be given to collective calls or + GroupMember.NON_GROUP_MEMBER if the rank is not part of ``ranks``. + + N.B. use_local_synchronization doesn't work with MPI. + + N.B. While use_local_synchronization=True can be significantly faster with larger + clusters and small process groups, care must be taken since it changes cluster behavior + as non-member ranks don't join the group barrier(). + + N.B. use_local_synchronization=True can lead to deadlocks when each rank creates + multiple overlapping process groups. To avoid that, make sure all ranks follow the + same global creation order. + """ + return _new_group_with_tag( + ranks, + timeout, + backend, + pg_options, + None, + use_local_synchronization=use_local_synchronization, + group_desc=group_desc, + device_id=device_id, + ) + + +def _new_group_with_tag( + ranks=None, + timeout=None, + backend=None, + backend_options=None, + pg_tag=None, + use_local_synchronization=False, + group_desc=None, + device_id: torch.device | None = None, +): + """ + Variant of ``new_group`` that exposes tag creation. + + :: N.B. The mechanism is experimental and tied to the functional collectives effort, see + ``torch.distributed._functional_collectives`` for reference on how to use it. + """ + global _world + + default_pg = _get_default_group() + if device_id is None: + device_id = default_pg.bound_device_id + elif default_pg.bound_device_id is not None: + if device_id != default_pg.bound_device_id: + raise AssertionError( + "Mismatched bound device between new pg and the default pg." 
+ ) + default_backend, default_store = _world.pg_map[default_pg] + global_rank = default_pg.rank() + global_world_size = default_pg.size() + + # Default to the same backend as the global process group + # if the backend is not specified. + if not backend: + backend = default_backend + backend = Backend(backend) + + # this timeout defaulting/validation is used for all the new_groups/new_subgroups variants, + # which may just pass their timeout value (or None) + if timeout is None: + timeout = _get_default_timeout(backend) + _check_valid_timeout(timeout) + + if use_local_synchronization: + # MPI backend doesn't have have a way for us to perform a partial sync + if backend == Backend.MPI: + raise ValueError( + "MPI backend doesn't support use_local_synchronization=True" + ) + if ranks is not None and get_rank() not in ranks: + return None + + # checks the input ranks + if ranks is not None: + ranks = sorted(ranks) + group_world_size = len(ranks) + if group_world_size > global_world_size: + raise ValueError( + "the new group's world size should be less or " + "equal to the world size set by " + "init_process_group" + ) + # check ranks' sanity + for rank in ranks: + if rank < 0 or rank >= global_world_size: + raise ValueError( + "The new group's rank should be within " + "the world_size set by init_process_group" + ) + if global_rank in ranks: + group_rank = ranks.index(global_rank) + else: + group_rank = None + else: + ranks = list(range(global_world_size)) + group_world_size = global_world_size + group_rank = global_rank + + group_name = _process_group_name(ranks, use_hashed_name=use_local_synchronization) + + pg, pg_store = _new_process_group_helper( + group_world_size, + group_rank, + ranks, + backend, + default_store, + group_name, + backend_options=backend_options, + timeout=timeout, + pg_tag=pg_tag, + device_id=device_id, + group_desc=group_desc, + ) + + # Create the global rank to group rank mapping + _world.pg_group_ranks[pg] = { + global_rank: group_rank for 
group_rank, global_rank in enumerate(ranks) + } + + if _is_barrier_after_init() == 1: + # barrier at the end to ensure that once we return from this method, all + # process groups including global variables (if any) are updated + # correctly on all ranks. + # Update 04/2023: for large-scale runs, this barrier (esp. store-based + # barrier) may be costly and/or unscalable. Also, in a lot of cases, + # these barriers may be unnecessary, as proven by a green CI after + # removal. An environment variable `TORCH_DIST_INIT_BARRIER` has been + # added which enables this barrier only when set to 1. + logger.info( + "Performing barrier after ProcessGroup initialization since " + "TORCH_DIST_INIT_BARRIER = 1" + ) + if backend == Backend.MPI: + # MPI doesn't have store. + barrier() + else: + barrier_store = pg_store if use_local_synchronization else default_store + world_size = len(ranks) if use_local_synchronization else get_world_size() + # Use store based barrier here since barrier() used a bunch of + # default devices and messes up NCCL internal state. + _store_based_barrier( + global_rank, barrier_store, group_name, world_size, timeout + ) + + return pg + + +def new_subgroups( + group_size=None, + group=None, + timeout=None, + backend=None, + pg_options=None, + group_desc=None, +): + """ + Create subgroups of equal size. + + By default, it creates intra-machine subgroups, + where each of which contains all the ranks of a machine, based on the assumption + that each machine has the same number of devices. + + This is a convenience API that calls ``new_group`` to generate multiple subgroups. + It requires that all processes in the main group (i.e. all + processes that are part of the distributed job) enter this function, even + if they are not going to be members of the group. + + .. warning:: + If ``group_size`` is passed in, the world size must be divisible by ``group_size``. 
+ If no ``group_size`` is passed in, it believe that you are creating a group based + on CUDA and determining the group size by number of CUDA devices, and if not all + the machines have the same number of devices, the subgroup division will be + different across nodes and can cause unexpected behaviors. Therefore, if you are + creating a subgroup that does not depend on CUDA (such as Gloo on CPU), please + pass in ``group_size`` correctly. + + .. warning:: + See warning `Safe concurrent usage` for `new_group` API for important details about + using multiple process groups concurrently in a safe manner. + + Args: + group_size (int, optional): The size of each subgroup. If ``None``, + the default subgroup size is equal to the number of devices on each machine, + based on the assumption that each machine has exactly the same + number of devices. Default is ``None``. + group (ProcessGroup, optional): The process group to work on. If + ``None``, the default process group will be used. Default is ``None``. + timeout (timedelta, optional): see `init_process_group` for details and default value. + backend (str or Backend, optional): The backend to use. Depending on + build-time configurations, valid values are ``gloo`` and ``nccl``. + By default uses the same backend as the global group. This field + should be given as a lowercase string (e.g., ``"gloo"``), which can + also be accessed via :class:`Backend` attributes (e.g., + ``Backend.GLOO``). If ``None`` is passed in, the backend + corresponding to the default process group will be used. Default is + ``None``. + pg_options (ProcessGroupOptions, optional): process group options + specifying what additional options need to be passed in during + the construction of specific process groups. i.e. for the ``nccl`` + backend, ``is_high_priority_stream`` can be specified so that + process group can pick up high priority cuda streams. + group_desc (str, optional): A string describing the group. 
Each subgroup will + inherit its group_desc + + Returns: + The subgroup containing the current rank, and all the subgroups used for cleanup. + + Examples: + >>> # Create intra-machine subgroups. + >>> # xdoctest: +SKIP("need process group init") + >>> cur_subgroup, subgroups = dist.new_subgroups() + >>> # Allreduce within the machine. + >>> rank = dist.get_rank() + >>> tensor = torch.ones(1, device=rank) * rank + >>> dist.all_reduce(tensor, group=cur_subgroup) + >>> tensor + tensor([28]) # Assume 8 CUDA devices per machine. 28 is sum(range(8)). + >>> # Cleanup. + >>> for subgroup in subgroups: + >>> dist.destroy_process_group(subgroup) + """ + if group_size is None: + if not torch.cuda.is_available(): + raise ValueError( + "Default group size only takes effect when CUDA is available." + "If your subgroup using a backend that does not depend on CUDA," + "please pass in 'group_size' correctly." + ) + group_size = torch.cuda.device_count() + if group_size <= 0: + raise ValueError(f"The arg 'group_size' ({group_size}) must be positive") + + world_size = get_world_size(group=group) + if world_size < group_size: + raise ValueError( + f"The arg 'group_size' ({group_size}) must not exceed the world size ({world_size})" + ) + if world_size % group_size != 0: + raise ValueError( + f"The world size ({world_size}) must be divisible by '{group_size=}'" + ) + + # TODO: Use itertools.batched(get_process_group_ranks(group=group), group_size) instead when Python 3.12 is supported. + ranks = get_process_group_ranks(group=group) + ranks_per_subgroup_list = [ + ranks[i : i + group_size] for i in range(0, len(ranks), group_size) + ] + return new_subgroups_by_enumeration( + ranks_per_subgroup_list, + timeout=timeout, + backend=backend, + pg_options=pg_options, + group_desc=group_desc, + ) + + +def new_subgroups_by_enumeration( + ranks_per_subgroup_list, + timeout=None, + backend=None, + pg_options=None, + group_desc=None, +): + """ + Create subgroups by dividing the global world. 
+ + The division is specified by a nested list of ranks. The subgroups cannot have + overlap, and some ranks may not have to be in any subgroup. + + This is a convenience API that calls ``new_group`` to generate multiple subgroups. + It requires that all processes in the main group (i.e. all + processes that are part of the distributed job) enter this function, even + if they are not going to be members of the group. + + .. warning:: + See warning `Safe concurrent usage` for `new_group` API for important details about + using multiple process groups concurrently in a safe manner. + + Args: + ranks_per_subgroup_list (list[list[int]]): A nested list of ranks of + group members. + timeout (timedelta, optional): see `init_process_group` for details and default value. + backend (str or Backend, optional): The backend to use. Depending on + build-time configurations, valid values are ``gloo`` and ``nccl``. + By default uses the same backend as the global group. This field + should be given as a lowercase string (e.g., ``"gloo"``), which can + also be accessed via :class:`Backend` attributes (e.g., + ``Backend.GLOO``). If ``None`` is passed in, the backend + corresponding to the default process group will be used. Default is + ``None``. + pg_options (ProcessGroupOptions, optional): process group options + specifying what additional options need to be passed in during + the construction of specific process groups. i.e. for the ``nccl`` + backend, ``is_high_priority_stream`` can be specified so that + process group can pick up high priority cuda streams. + group_desc (str, optional): A string describing the group. Each subgroup will + inherit its group_desc. + + Returns: + The subgroup containing the current rank, and all the subgroups used for cleanup. + + Examples: + >>> # Create two subgroups, where each has 2 processes. 
+ >>> # xdoctest: +SKIP("need process group init") + >>> cur_subgroup, subgroups = dist.new_subgroups(ranks=[[0, 2], [1, 3]]) + >>> rank = dist.get_rank() + >>> tensor = torch.ones(1, device=rank) * rank + >>> dist.all_reduce(tensor, group=cur_subgroup) + >>> tensor + tensor([2]) # Subgroup 0: ranks 0 and 2 + tensor([4]) # Subgroup 1: ranks 1 and 3 + """ + if ranks_per_subgroup_list is None or len(ranks_per_subgroup_list) == 0: + raise ValueError("The arg 'ranks_per_subgroup_list' cannot be empty") + + subgroups = [] + cur_subgroup = None + # Create a mapping from rank to subgroup to check if there is any subgroup overlap. + rank_to_ranks_dict = {} # type: ignore[var-annotated] + for ranks in ranks_per_subgroup_list: + subgroup = new_group( + ranks=ranks, + timeout=timeout, + backend=backend, + pg_options=pg_options, + group_desc=group_desc, + ) + subgroups.append(subgroup) + my_rank = get_rank() + for rank in ranks: + if rank in rank_to_ranks_dict: + raise ValueError( + f"Rank {rank} has appeared in both subgroup {rank_to_ranks_dict[rank]} and {ranks}" + ) + rank_to_ranks_dict[rank] = ranks + if my_rank == rank: + cur_subgroup = subgroup + logger.info("Rank %s is assigned to subgroup %s", rank, ranks) + + return cur_subgroup, subgroups + + +def _find_pg_by_ranks_and_tag(tag: str, ranks: list[int]) -> ProcessGroup | None: + if len(tag) > 0 and not tag.startswith("ptd:") and not tag.startswith("user:"): + tag = f"user:{tag}" + + for group in _world.tags_to_pg.get(tag, []): + if group.size() != len(ranks): + continue + + group_ranks = get_process_group_ranks(group) + good = all(r in group_ranks for r in ranks) + if good: + return group + return None + + +def _find_or_create_pg_by_ranks_and_tag( + tag: str, ranks: list[int], stride: int +) -> ProcessGroup: + if len(ranks) % stride != 0: + raise ValueError( + f"Ranks length ({len(ranks)}) must be divisible by stride ({stride})" + ) + + my_rank = get_rank() + my_ranks = None + + if stride == len(ranks): + my_ranks = 
ranks.copy() + if my_rank not in my_ranks: + raise AssertionError("rankset doesn't include the current node") + else: + for i in range(0, len(ranks), stride): + rank_set = ranks[i : i + stride] + if my_rank in rank_set: + my_ranks = rank_set + if my_ranks is None: + raise AssertionError("rankset doesn't include the current node") + + my_ranks = sorted(my_ranks) + + pg = _find_pg_by_ranks_and_tag(tag, my_ranks) + if pg is not None: + return pg + if tag == "": + raise ValueError("Cannot automatically create PG with empty tag") + # TODO copy settings and timeout from default PG + return _new_group_with_tag(my_ranks, pg_tag=tag) + + +def _get_group_tag(pg: ProcessGroup) -> str: + """Return the tag associated with ``pg``.""" + tag = _world.pg_to_tag[pg] + tag = tag.removeprefix("user:") + return tag + + +def _get_process_group_name(pg: ProcessGroup) -> str: + return _world.pg_names.get(pg, "None") + + +def _get_process_group_store(pg: ProcessGroup) -> Store: + return _world.pg_map[pg][1] + + +# Shrink flags for process group backends +SHRINK_DEFAULT = 0x00 +SHRINK_ABORT = 0x01 + + +@_time_logger +def shrink_group( + ranks_to_exclude: list[int], + group: ProcessGroup | None = None, + shrink_flags: int = SHRINK_DEFAULT, + pg_options: Any | None = None, +) -> ProcessGroup: + """ + Shrinks a process group by excluding specified ranks. + + Creates and returns a new, smaller process group comprising only the ranks + from the original group that were not in the ``ranks_to_exclude`` list. + + Args: + ranks_to_exclude (List[int]): A list of ranks from the original + ``group`` to exclude from the new group. + group (ProcessGroup, optional): The process group to shrink. If ``None``, + the default process group is used. Defaults to ``None``. + shrink_flags (int, optional): Flags to control the shrinking behavior. + Can be ``SHRINK_DEFAULT`` (default) or ``SHRINK_ABORT``. + ``SHRINK_ABORT`` will attempt to terminate ongoing operations + in the parent communicator before shrinking. 
+ Defaults to ``SHRINK_DEFAULT``. + pg_options (ProcessGroupOptions, optional): Backend-specific options to apply + to the shrunken process group. If provided, the backend will use + these options when creating the new group. If omitted, the new group + inherits defaults from the parent. + + Returns: + ProcessGroup: a new group comprised of the remaining ranks. If the + default group was shrunk, the returned group becomes the new default group. + + Raises: + TypeError: if the group’s backend does not support shrinking. + ValueError: if ``ranks_to_exclude`` is invalid (empty, out of bounds, + duplicates, or excludes all ranks). + RuntimeError: if an excluded rank calls this function or the backend + fails the operation. + + Notes: + - Only non-excluded ranks should call this function; excluded ranks + must not participate in the shrink operation. + - Shrinking the default group destroys all other process groups since + rank reassignment makes them inconsistent. + """ + # Step 1: Validate input parameters with comprehensive error checking + _validate_shrink_inputs(ranks_to_exclude, shrink_flags) + + # Step 2: Get target group and essential properties + target_group_info = _prepare_shrink_target_group(group) + + # Step 3: Validate backend requirements and availability + backend_impl = _validate_shrink_backend_requirements(target_group_info) + + # Step 4: Validate ranks against group and check for duplicates + excluded_ranks_set = _validate_and_process_excluded_ranks( + ranks_to_exclude, target_group_info + ) + + # Step 5: Execute the actual shrink operation (backend-specific) + new_backend = backend_impl.shrink( + sorted(excluded_ranks_set), + shrink_flags, + pg_options if pg_options is not None else None, + ) + + # Step 6: Handle cleanup and creation of new process group + target_group_info["pg_options_override"] = pg_options + return _finalize_shrunk_group(target_group_info, excluded_ranks_set, new_backend) + + +def _validate_shrink_inputs(ranks_to_exclude: 
list[int], shrink_flags: int) -> None: + """Validate input parameters for shrink_group.""" + if not isinstance(ranks_to_exclude, list): + raise TypeError( + f"ranks_to_exclude must be a list, but got {type(ranks_to_exclude).__name__}. " + f"Example: [1, 3, 5] to exclude ranks 1, 3, and 5." + ) + + if not ranks_to_exclude: + raise ValueError( + "ranks_to_exclude cannot be empty. To shrink a group, you must specify at least " + "one rank to exclude. Example: [failed_rank_id]" + ) + + # Validate shrink_flags with clear explanation of valid values + valid_flags = [SHRINK_DEFAULT, SHRINK_ABORT] + if not isinstance(shrink_flags, int) or shrink_flags not in valid_flags: + raise ValueError( + f"Invalid shrink_flags value: {shrink_flags}. Must be one of: " + f"SHRINK_DEFAULT ({SHRINK_DEFAULT}) or SHRINK_ABORT ({SHRINK_ABORT}). " + f"Use SHRINK_ABORT to abort ongoing operations before shrinking." + ) + + +def _prepare_shrink_target_group(group: ProcessGroup | None) -> dict: + """Prepare and validate the target group for shrinking.""" + target_pg = group if group is not None else _get_default_group() + + # Cache frequently accessed properties to avoid repeated calls + group_size = int(target_pg.size()) + group_info = { + "process_group": target_pg, + "is_default_group": (target_pg == _get_default_group()), + "group_size": group_size, + "current_rank": target_pg.rank(), + "group_name": _get_process_group_name(target_pg), + } + + # Validate that we have a valid process group + if group_size <= 1: + raise ValueError( + f"Cannot shrink a process group with size {group_size}. " + f"Group must have at least 2 ranks to support shrinking." + ) + + return group_info + + +def _validate_shrink_backend_requirements(group_info: dict) -> Any: + """Return the backend implementation for the target group or raise if unsupported.""" + target_pg = group_info["process_group"] + group_name = group_info["group_name"] + + # Get the group's backend directly via ProcessGroup API. 
Prefer a bound device if present, + # otherwise try CUDA then fall back to CPU. + try: + preferred_device = getattr(target_pg, "bound_device_id", None) + if preferred_device is not None: + backend_impl = target_pg._get_backend(preferred_device) + else: + # Try CUDA first if available, else CPU + try: + backend_impl = target_pg._get_backend(torch.device("cuda")) + except Exception: + backend_impl = target_pg._get_backend(torch.device("cpu")) + except RuntimeError as e: + raise RuntimeError( + f"Cannot access device backend for process group '{group_name}'. " + f"Ensure the process group was initialized with a compatible device backend and devices are available." + ) from e + + try: + supports = bool(backend_impl.supports_shrinking) + except Exception: + supports = False + if not supports: + raise TypeError( + f"Process group backend for '{group_name}' does not support shrinking operations." + ) + + return backend_impl + + +def _validate_and_process_excluded_ranks( + ranks_to_exclude: list[int], group_info: dict +) -> set: + """Validate excluded ranks and convert to set for efficient operations.""" + group_size = group_info["group_size"] + current_rank = group_info["current_rank"] + + # Use set for O(1) duplicate detection and membership testing + excluded_ranks_set = set() + + # Validate each rank with detailed error messages + for i, rank in enumerate(ranks_to_exclude): + if not isinstance(rank, int): + raise TypeError( + f"All elements in ranks_to_exclude must be integers. " + f"Element at index {i} is {type(rank).__name__}: {rank}" + ) + + if not (0 <= rank < group_size): + raise ValueError( + f"Rank {rank} at index {i} is out of bounds for group size {group_size}. " + f"Valid ranks are in range [0, {group_size - 1}]." + ) + + if rank in excluded_ranks_set: + raise ValueError( + f"Duplicate rank {rank} found in ranks_to_exclude at index {i}. " + f"Each rank can only be excluded once." 
+ ) + + excluded_ranks_set.add(rank) + + # Ensure we don't exclude all ranks + if len(excluded_ranks_set) >= group_size: + raise ValueError( + f"Cannot exclude all {group_size} ranks from process group. " + f"At least one rank must remain. Excluding {len(excluded_ranks_set)} ranks." + ) + + # Critical check: current rank should not be in excluded list + if current_rank in excluded_ranks_set: + raise RuntimeError( + f"Current rank {current_rank} is in the exclusion list and should not call shrink_group(). " + f"Only non-excluded ranks should participate in the shrinking operation. " + f"Excluded ranks should terminate their processes instead." + ) + + return excluded_ranks_set + + +def _finalize_shrunk_group( + group_info: dict, excluded_ranks_set: set, new_backend +) -> ProcessGroup: + """Clean up old group and create new shrunk process group.""" + target_pg = group_info["process_group"] + is_default_group = group_info["is_default_group"] + + # Handle default group dependencies - destroy other groups first + if is_default_group: + _destroy_all_other_groups(exclude_group=target_pg) + + # Gather original group metadata before cleanup + original_group_metadata = _extract_group_metadata(target_pg) + + # Calculate remaining ranks efficiently + original_ranks = get_process_group_ranks(target_pg) + remaining_ranks = [ + rank for rank in original_ranks if rank not in excluded_ranks_set + ] + + # Clean up the original group + _cleanup_original_group(target_pg, is_default_group) + + # Create and configure the new process group + new_pg = _create_shrunk_process_group( + new_backend, remaining_ranks, original_group_metadata, is_default_group + ) + + # Register the new group in global state + if is_default_group: + _update_default_pg(new_pg) + + # Update global state with new group information + rank_mapping = { + global_rank: group_rank + for group_rank, global_rank in enumerate(remaining_ranks) + } + _update_process_group_global_state( + pg=new_pg, + 
backend_name=original_group_metadata["backend_name"], + store=original_group_metadata["store"], + group_name=original_group_metadata["new_group_name"], + backend_config=original_group_metadata["backend_config"], + rank_mapping=rank_mapping, + ) + + return new_pg + + +def _extract_group_metadata(target_pg: ProcessGroup) -> dict: + """Extract metadata from the original group before cleanup.""" + original_backend_name, original_store = _world.pg_map[target_pg] + original_backend_config = _world.pg_backend_config.get(target_pg, "") + original_group_name = _get_process_group_name(target_pg) + + # Extract device binding information before cleanup to avoid accessing destroyed group + bound_device_id = None + if hasattr(target_pg, "bound_device_id"): + bound_device_id = target_pg.bound_device_id + + # Generate new group name for the shrunk group; hash for uniqueness across backends + remaining_ranks = list(get_process_group_ranks(target_pg)) + new_group_name = _process_group_name(remaining_ranks, use_hashed_name=True) + + return { + "backend_name": original_backend_name, + "store": original_store, + "backend_config": original_backend_config, + "original_group_name": original_group_name, + "new_group_name": new_group_name, + "bound_device_id": bound_device_id, # Safe to access after cleanup + } + + +def _cleanup_original_group(target_pg: ProcessGroup, is_default_group: bool) -> None: + """Clean up the original process group safely.""" + try: + destroy_process_group(target_pg) + except Exception: + group_type = "default" if is_default_group else "non-default" + logger.warning( + "Failed to destroy %s group during shrinking", group_type, exc_info=True + ) + + # Ensure global state cleanup even if destroy_process_group fails + _cleanup_process_group_global_state(target_pg) + + +def _create_shrunk_process_group( + new_backend, remaining_ranks: list[int], metadata: dict, is_default_group: bool +) -> ProcessGroup: + """Create and configure the new shrunk process group.""" + # 
Create new group properties + new_group_rank = new_backend.rank() + new_group_size = new_backend.size() + group_name = metadata["new_group_name"] + + # Generate descriptive group description + if is_default_group: + group_desc = "default:shrunken" + else: + group_desc = f"{metadata['original_group_name']}:shrunk" + + # Create process group with new communicator (clone the parent store like split does) + prefix_store = PrefixStore(f"{group_name}/", metadata["store"].clone()) + new_pg = ProcessGroup(prefix_store, new_group_rank, new_group_size) + + # Configure backend using the device type of the new backend's bound device if available, + # otherwise derive from the original group's bound device or fall back to CPU. + backend_device = metadata.get("bound_device_id") + if backend_device is None: + # Default to CPU if no bound device is present + backend_device = torch.device("cpu") + + # Choose backend enum based on device type + if backend_device.type == "cuda": + backend_type = ProcessGroup.BackendType.NCCL + else: + backend_type = ProcessGroup.BackendType.GLOO + + new_pg._register_backend(backend_device, backend_type, new_backend) + new_pg._set_default_backend(backend_type) + + # Inherit device binding from original group if it was bound + bound_device_id = metadata.get("bound_device_id") + if bound_device_id is not None: + new_pg.bound_device_id = bound_device_id + + # Set group metadata + new_pg._set_group_name(group_name) + new_pg._set_group_desc(group_desc) + + # Persist backend configuration overrides (if provided via shrink_group) + backend_config_override = metadata.get("backend_config") + if backend_config_override is not None: + # Store for introspection/debugging and potential backend hooks + _world.pg_backend_config[new_pg] = backend_config_override + + return new_pg + + +def _destroy_all_other_groups(exclude_group: ProcessGroup | None = None) -> None: + """ + Destroy all process groups except the excluded group and clean up all global state. 
+ + This is necessary when shrinking the default group because global ranks + are reassigned by NCCL, making all existing process groups inconsistent. + + Note: Uses abort for non-collective cleanup since excluded ranks may not + participate in collective operations. Backend cleanup is handled independently per group. + + Args: + exclude_group (ProcessGroup, optional): Process group to exclude from destruction. + If None, destroys all process groups. + """ + # Get list of groups to destroy (avoid modifying dict while iterating) + groups_to_destroy = [] + for pg in list(_world.pg_group_ranks.keys()): + if exclude_group is not None and pg == exclude_group: + continue + groups_to_destroy.append(pg) + + # Warn user about automatic destruction + if groups_to_destroy: + group_names = [_get_process_group_name(pg) for pg in groups_to_destroy] + logger.warning( + "Shrinking default group will destroy %d other process groups: %s. " + "This is necessary because shrinking the default group reassigns global ranks, " + "making existing groups inconsistent.", + len(groups_to_destroy), + ", ".join(group_names), + ) + + # Destroy each group and clean up global state + for pg in groups_to_destroy: + try: + # First call abort_process_group which handles the C++ cleanup non-collectively + _abort_process_group(pg) + except Exception: + # Log but don't fail - some groups might already be destroyed + logger.warning( + "Failed to abort process group %s", + _get_process_group_name(pg), + exc_info=True, + ) + + # Ensure all global state is cleaned up even if _abort_process_group fails + # or doesn't clean up everything + _cleanup_process_group_global_state(pg) + + +def _cleanup_process_group_global_state(pg: ProcessGroup) -> None: + """ + Clean up all global state associated with a process group. + + This function ensures complete cleanup of process group state from all + global dictionaries and registries, even if destroy_process_group fails + or doesn't clean up everything. 
This is critical when destroying multiple + groups to prevent inconsistent state. + + The cleanup removes the process group from: + - _world.pg_map (backend and store mapping) + - _world.pg_names (group name mapping) + - _world.pg_group_ranks (rank mappings) + - _world.pg_backend_config (backend configuration) + - _world.tags_to_pg and _world.pg_to_tag (tag mappings) + - _world.pg_coalesce_state (coalescing state) + - C++ internal registries via _unregister_process_group + + Args: + pg (ProcessGroup): The process group to clean up. + """ + try: + # Clean up main process group mappings + _world.pg_map.pop(pg, None) + _world.pg_group_ranks.pop(pg, None) + _world.pg_backend_config.pop(pg, None) + + # Clean up process group name mapping + group_name = _world.pg_names.pop(pg, None) + + # Clean up tag mappings + pg_tag = _world.pg_to_tag.pop(pg, None) + if pg_tag is not None and pg_tag in _world.tags_to_pg: + try: + _world.tags_to_pg[pg_tag].remove(pg) + # Remove the tag entry if list is empty + if not _world.tags_to_pg[pg_tag]: + _world.tags_to_pg.pop(pg_tag, None) + except (ValueError, KeyError): + # Process group was already removed from the list + pass + + # Clean up any registered process group names using C++ unregister function + if group_name is not None: + try: + _unregister_process_group(group_name) + except Exception: + # Process group name might not be registered or already unregistered + pass + + # Clean up coalesce state if present + _world.pg_coalesce_state.pop(pg, None) + + except Exception: + # Log cleanup failures but don't propagate - we want to continue with other cleanups + logger.warning( + "Failed to fully clean up global state for process group", exc_info=True + ) + + +def _update_process_group_global_state( + pg: ProcessGroup, + backend_name: str, + store: Store, + group_name: GroupName, + backend_config: str, + rank_mapping: dict[int, int] | None = None, + pg_tag: str | None = None, + user_tag: str | None = None, +) -> None: + """ + Update all 
global state dictionaries for a process group. + + This helper function consolidates the common pattern of updating multiple + global state dictionaries when creating or modifying process groups. + + Args: + pg (ProcessGroup): The process group to update state for. + backend_name (str): Backend name for pg_map. + store (Store): Store instance for pg_map. + group_name (str): Group name for pg_names and registration. + backend_config (str): Backend configuration string. + rank_mapping (Dict[int, int], optional): Global rank to group rank mapping. + If None, skips updating pg_group_ranks. + pg_tag (str, optional): Process group tag. If None, defaults to f"ptd:{group_name}". + user_tag (str, optional): User-provided tag for special tag handling. + If provided, creates "user:{user_tag}" tag and also adds to default "". + """ + # Update main process group mappings + _world.pg_map[pg] = (backend_name, store) + _world.pg_names[pg] = group_name + _world.pg_backend_config[pg] = backend_config + + # Register the process group name + _register_process_group(group_name, pg) + + # Update rank mapping if provided + if rank_mapping is not None: + _world.pg_group_ranks[pg] = rank_mapping + + # Handle tag management + if pg_tag is None: + pg_tag = f"ptd:{group_name}" + + if user_tag is not None: + # Special handling for user-provided tags + # Add to default "" tag first + _world.tags_to_pg.setdefault("", []).append(pg) + # Then create user-specific tag + user_pg_tag = f"user:{user_tag}" + _world.tags_to_pg.setdefault(user_pg_tag, []).append(pg) + _world.pg_to_tag[pg] = user_pg_tag + else: + # Standard process group tag + _world.tags_to_pg.setdefault(pg_tag, []).append(pg) + _world.pg_to_tag[pg] = pg_tag diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/launch.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/launch.py new file mode 100644 index 
0000000000000000000000000000000000000000..ad3307c13303d0319af710923669d119b4cff30c --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/launch.py @@ -0,0 +1,207 @@ +# mypy: allow-untyped-defs +r""" +Module ``torch.distributed.launch``. + +``torch.distributed.launch`` is a module that spawns up multiple distributed +training processes on each of the training nodes. + +.. warning:: + + This module is going to be deprecated in favor of :ref:`torchrun `. + +The utility can be used for single-node distributed training, in which one or +more processes per node will be spawned. The utility can be used for either +CPU training or GPU training. If the utility is used for GPU training, +each distributed process will be operating on a single GPU. This can achieve +well-improved single-node training performance. It can also be used in +multi-node distributed training, by spawning up multiple processes on each node +for well-improved multi-node distributed training performance as well. +This will especially be beneficial for systems with multiple Infiniband +interfaces that have direct-GPU support, since all of them can be utilized for +aggregated communication bandwidth. + +In both cases of single-node distributed training or multi-node distributed +training, this utility will launch the given number of processes per node +(``--nproc-per-node``). If used for GPU training, this number needs to be less +or equal to the number of GPUs on the current system (``nproc_per_node``), +and each process will be operating on a single GPU from *GPU 0 to +GPU (nproc_per_node - 1)*. + +**How to use this module:** + +1. Single-Node multi-process distributed training + +:: + + python -m torch.distributed.launch --nproc-per-node=NUM_GPUS_YOU_HAVE + YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3 and all other + arguments of your training script) + +2. Multi-Node multi-process distributed training: (e.g. 
two nodes) + + +Node 1: *(IP: 192.168.1.1, and has a free port: 1234)* + +:: + + python -m torch.distributed.launch --nproc-per-node=NUM_GPUS_YOU_HAVE + --nnodes=2 --node-rank=0 --master-addr="192.168.1.1" + --master-port=1234 YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3 + and all other arguments of your training script) + +Node 2: + +:: + + python -m torch.distributed.launch --nproc-per-node=NUM_GPUS_YOU_HAVE + --nnodes=2 --node-rank=1 --master-addr="192.168.1.1" + --master-port=1234 YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3 + and all other arguments of your training script) + +3. To look up what optional arguments this module offers: + +:: + + python -m torch.distributed.launch --help + + +**Important Notices:** + +1. This utility and multi-process distributed (single-node or +multi-node) GPU training currently only achieves the best performance using +the NCCL distributed backend. Thus NCCL backend is the recommended backend to +use for GPU training. + +2. In your training program, you must parse the command-line argument: +``--local-rank=LOCAL_PROCESS_RANK``, which will be provided by this module. +If your training program uses GPUs, you should ensure that your code only +runs on the GPU device of LOCAL_PROCESS_RANK. This can be done by: + +Parsing the local_rank argument + +:: + + >>> # xdoctest: +SKIP + >>> import argparse + >>> parser = argparse.ArgumentParser() + >>> parser.add_argument("--local-rank", "--local_rank", type=int) + >>> args = parser.parse_args() + +Set your device to local rank using either + +:: + + >>> torch.cuda.set_device(args.local_rank) # before your code runs + +or + +:: + + >>> with torch.cuda.device(args.local_rank): + >>> # your code to run + >>> ... + +.. versionchanged:: 2.0.0 + + The launcher will pass the ``--local-rank=<rank>`` argument to your script. + From PyTorch 2.0.0 onwards, the dashed ``--local-rank`` is preferred over the + previously used underscored ``--local_rank``. 
+ + For backward compatibility, it may be necessary for users to handle both + cases in their argument parsing code. This means including both ``"--local-rank"`` + and ``"--local_rank"`` in the argument parser. If only ``"--local_rank"`` is + provided, the launcher will trigger an error: "error: unrecognized arguments: + --local-rank=<rank>". For training code that only supports PyTorch 2.0.0+, + including ``"--local-rank"`` should be sufficient. + +3. In your training program, you are supposed to call the following function +at the beginning to start the distributed backend. It is strongly recommended +that ``init_method=env://``. Other init methods (e.g. ``tcp://``) may work, +but ``env://`` is the one that is officially supported by this module. + +:: + + >>> torch.distributed.init_process_group(backend='YOUR BACKEND', + >>> init_method='env://') + +4. In your training program, you can either use regular distributed functions +or use :func:`torch.nn.parallel.DistributedDataParallel` module. If your +training program uses GPUs for training and you would like to use +:func:`torch.nn.parallel.DistributedDataParallel` module, +here is how to configure it. + +:: + + >>> model = torch.nn.parallel.DistributedDataParallel(model, + >>> device_ids=[args.local_rank], + >>> output_device=args.local_rank) + +Please ensure that ``device_ids`` argument is set to be the only GPU device id +that your code will be operating on. This is generally the local rank of the +process. In other words, the ``device_ids`` needs to be ``[args.local_rank]``, +and ``output_device`` needs to be ``args.local_rank`` in order to use this +utility. + +5. Another way to pass ``local_rank`` to the subprocesses is via environment variable +``LOCAL_RANK``. This behavior is enabled when you launch the script with +``--use-env``. 
You must adjust the subprocess example above to replace +``args.local_rank`` with ``os.environ['LOCAL_RANK']``; the launcher +will not pass ``--local-rank`` when you specify this flag. + +.. warning:: + + ``local_rank`` is NOT globally unique: it is only unique per process + on a machine. Thus, don't use it to decide if you should, e.g., + write to a networked filesystem. See + https://github.com/pytorch/pytorch/issues/12042 for an example of + how things can go wrong if you don't do this correctly. + + + +""" + +from typing_extensions import deprecated as _deprecated + +from torch.distributed.run import get_args_parser, run + + +def parse_args(args): + parser = get_args_parser() + parser.add_argument( + "--use-env", + "--use_env", + default=False, + action="store_true", + help="Use environment variable to pass " + "'local rank'. For legacy reasons, the default value is False. " + "If set to True, the script will not pass " + "--local-rank as argument, and will instead set LOCAL_RANK.", + ) + return parser.parse_args(args) + + +def launch(args): + if args.no_python and not args.use_env: + raise ValueError( + "When using the '--no-python' flag, you must also set the '--use-env' flag." + ) + run(args) + + +@_deprecated( + "The module torch.distributed.launch is deprecated\n" + "and will be removed in future. Use torchrun.\n" + "Note that --use-env is set by default in torchrun.\n" + "If your script expects `--local-rank` argument to be set, please\n" + "change it to read from `os.environ['LOCAL_RANK']` instead. 
See \n" + "https://pytorch.org/docs/stable/distributed.html#launch-utility for \n" + "further instructions\n", + category=FutureWarning, +) +def main(args=None): + args = parse_args(args) + launch(args) + + +if __name__ == "__main__": + main() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/logging_handlers.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/logging_handlers.py new file mode 100644 index 0000000000000000000000000000000000000000..ed6832fd1ae834b6365a6b005b07bbbfffe90726 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/logging_handlers.py @@ -0,0 +1,16 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import logging + + +__all__: list[str] = [] + +_log_handlers: dict[str, logging.Handler] = { + "default": logging.NullHandler(), +} diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/remote_device.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/remote_device.py new file mode 100644 index 0000000000000000000000000000000000000000..3ad0076f5e8901644e56d14530dc36624ecf87a0 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/remote_device.py @@ -0,0 +1,122 @@ +# mypy: allow-untyped-defs + +import torch + + +class _remote_device: + """ + Represents a device on a remote worker. + + Args: + remote_device (str or torch.device): Represents a device on a remote worker. + The string format should be one of the following: + + 1. "/", where the device field can be parsed as torch.device type. + E.g., "trainer0/cpu", "trainer0", "ps0/cuda:0". + In addition, the device field can be optional and the default value is "cpu". + 2. 
"rank:/", where is the rank of the + process and device can be parsed as torch.device type. + E.g., "rank:0/cpu", "rank:0", "rank:0/cuda:0" + 3. and are optional and formats like "cpu" + and "cuda:1", just represent local devices. + """ + + def __init__(self, remote_device: str | torch.device): + PARSE_ERROR = ( + f"Could not parse remote_device: {remote_device}. The valid format is " + "'/' or 'rank:/' or ''" + ) + self._worker_name = None + self._rank = None + self._device: str | int | torch.device | None = None + + if isinstance(remote_device, torch.device): + self._device = remote_device + elif isinstance(remote_device, str): + fields = remote_device.split("/") + if len(fields) == 2: + # pyrefly: ignore [bad-assignment] + self._worker_name, self._device = fields + elif len(fields) == 1: + # Check if this is a valid device. + if _remote_device._is_valid_local_device(fields[0]): + self._device = fields[0] + else: + # pyrefly: ignore [bad-assignment] + self._worker_name = fields[0] + self._device = "cpu" + else: + raise ValueError(PARSE_ERROR) + else: + raise TypeError(f"Invalid type for remote_device: {type(remote_device)}") + + # Do some basic sanity check (no empty string) + if self._worker_name is not None and not self._worker_name: + raise ValueError(PARSE_ERROR) + + # Validate the device. + self._device = torch.device(self._device) + + # Check for rank based format. 
+ if self._worker_name is not None: + fields = self._worker_name.split(":") + if len(fields) == 2: + # rank:/device format, extract rank + if fields[0] == "rank" and fields[1].isdigit(): + self._rank = int(fields[1]) # type: ignore[assignment] + # pyrefly: ignore [bad-assignment] + self._worker_name = None + else: + raise ValueError(PARSE_ERROR) + elif len(fields) > 2: + raise ValueError(PARSE_ERROR) + + @staticmethod + def _is_valid_local_device(device): + # Check for torch.device + try: + torch.device(device) + return True + except Exception: + return False + + def worker_name(self) -> str | None: + """Return the name of remote worker representing the remote device and ``None`` if no worker name is available.""" + return self._worker_name + + def rank(self) -> int | None: + """ + Returns the rank of remote worker representing the remote device. + Returns ``None`` if no rank is available. + """ + return self._rank + + def device(self) -> torch.device: + """Return the local device on the remote worker.""" + return self._device # type: ignore[return-value] + + def __repr__(self): + if self._device is not None: + if self._worker_name is not None: + return f"{self._worker_name}/{self._device}" + elif self._rank is not None: + return f"rank:{self._rank}/{self._device}" + else: + return str(self._device) + else: + if self._worker_name is not None: + return f"{self._worker_name}" + elif self._rank is not None: + return f"{self._rank}" + else: + raise RuntimeError("Invalid state!") + + def __eq__(self, other): + return isinstance(other, _remote_device) and ( + self._worker_name == other._worker_name + and self._device == other._device + and self._rank == other._rank + ) + + def __hash__(self): + return hash(self._worker_name) ^ hash(self._device) ^ hash(self._rank) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/rendezvous.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/rendezvous.py new file mode 
100644 index 0000000000000000000000000000000000000000..f7913341175fbecd69fcd6e621aeb02f2cfc82b6 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/rendezvous.py @@ -0,0 +1,293 @@ +# mypy: allow-untyped-defs +try: + from urllib.parse import urlparse, urlunparse +except ImportError as e: + raise ImportError( + "urllib cannot be found, urlparse from python2 is no longer supported." + ) from e + +import numbers +import os +import sys +from collections.abc import Callable, Iterator +from datetime import timedelta + +from torch.distributed import FileStore, Store, TCPStore + +from .constants import default_pg_timeout + + +_rendezvous_handlers: dict[str, Callable[..., Iterator[tuple[Store, int, int]]]] = {} + +__all__ = ["register_rendezvous_handler", "rendezvous"] + + +def register_rendezvous_handler(scheme, handler): + """ + Register a new rendezvous handler. + + Before we can run collective algorithms, participating processes + need to find each other and exchange information to be able to + communicate. We call this process rendezvous. + + The outcome of the rendezvous process is a triplet containing a + shared key/value store, the rank of the process, and the total + number of participating processes. + + If none of the bundled rendezvous methods apply to your execution + environment you can opt to register your own rendezvous handler. + Pick a unique name and use the URL scheme to identify it when + calling the `rendezvous()` function. + + Args: + scheme (str): URL scheme to identify your rendezvous handler. + handler (function): Handler that is invoked when the + `rendezvous()` function is called with a URL that uses + the corresponding scheme. It must be a generator function + that yields the triplet. 
+ """ + global _rendezvous_handlers + if scheme in _rendezvous_handlers: + raise RuntimeError(f"Rendezvous handler for {scheme}:// already registered") + _rendezvous_handlers[scheme] = handler + + +# Query will have format "rank=0&world_size=1" and is +# converted into {"rank": 0, "world_size": 1} +def _query_to_dict(query: str) -> dict[str, str]: + return { + pair[0]: pair[1] + for pair in (pair.split("=") for pair in filter(None, query.split("&"))) + } + + +def _get_use_libuv_from_query_dict(query_dict: dict[str, str]) -> bool: + # libuv is the default backend for TCPStore. To enable the non-libuv backend, + # user can explicitly specify ``use_libuv=0`` in the URL parameter. + if sys.platform == "win32": + # PyTorch is built without libuv support on windows, so default to 0 + return query_dict.get("use_libuv", os.environ.get("USE_LIBUV", "0")) == "1" + return query_dict.get("use_libuv", os.environ.get("USE_LIBUV", "1")) == "1" + + +def _rendezvous_helper(url: str, rank: int, world_size_opt: int | None, **kwargs): + result = urlparse(url) + if world_size_opt is None: + world_size = -1 + if result.scheme == "env": + rank = int(os.environ.get("RANK", rank)) + # If the world_size env variable is not present then it is a dynamic group + world_size = int(os.environ.get("WORLD_SIZE", world_size)) + else: + world_size = world_size_opt + if rank != -1 or world_size != -1 or world_size_opt is None: + query_dict = _query_to_dict(result.query) + if "rank" in query_dict or "world_size" in query_dict: + raise AssertionError( + f"The url: {url} has node-specific arguments(rank, world_size) already." 
+ ) + if rank != -1: + query_dict["rank"] = str(rank) + if world_size != -1 or world_size_opt is None: + query_dict["world_size"] = str(world_size) + result = result._replace( + query=f"{'&'.join([f'{k}={v}' for k, v in query_dict.items()])}" + ) + # pyrefly: ignore [bad-assignment] + url = urlunparse(result) + + if result.scheme not in _rendezvous_handlers: + raise RuntimeError(f"No rendezvous handler for {result.scheme}://") + return _rendezvous_handlers[result.scheme](url, **kwargs) + + +def rendezvous(url: str, rank: int = -1, world_size: int = -1, **kwargs): + if not isinstance(url, (str, bytes)): + raise RuntimeError(f"`url` must be a string. {type(url)}: {url}") + + if not isinstance(rank, numbers.Integral): + raise RuntimeError(f"`rank` must be an integer. {rank}") + + if not isinstance(world_size, numbers.Integral): + raise RuntimeError(f"`world_size` must be an integer. {world_size}") + + # pyrefly: ignore [bad-argument-type] + return _rendezvous_helper(url, rank, world_size, **kwargs) + + +def _create_store_from_options(backend_options, rank): + store, _, _ = next(_rendezvous_helper(backend_options.init_method, rank, None)) + return store + + +def _rendezvous_error(msg): + return ValueError("Error initializing torch.distributed using " + msg) + + +def _file_rendezvous_handler(url: str, **kwargs): + def _error(msg): + return _rendezvous_error("file:// rendezvous: " + msg) + + result = urlparse(url) + path = result.path + if sys.platform == "win32": + import urllib.request + + full_path = result.netloc + result.path + path = urllib.request.url2pathname(full_path) + if path: + # Normalizing an empty string produces ".", which is not expected. 
+ path = os.path.normpath(path) + + if not path: + raise _error("path missing") + query_dict = _query_to_dict(result.query) + if "rank" not in query_dict: + raise _error("rank parameter missing") + if "world_size" not in query_dict: + raise _error("world size parameter missing") + + rank = int(query_dict["rank"]) + world_size = int(query_dict["world_size"]) + store = FileStore(path, world_size) + yield (store, rank, world_size) + + # If this configuration is invalidated, there is nothing we can do about it + raise RuntimeError("Unable to perform rerendezvous using file:// method") + + +def _torchelastic_use_agent_store() -> bool: + return os.environ.get("TORCHELASTIC_USE_AGENT_STORE", None) == str(True) + + +def _create_c10d_store( + hostname, port, rank, world_size, timeout, use_libuv=True +) -> Store: + """ + Smartly creates a c10d Store object on ``rank`` based on whether we need to reuse agent store. + + The TCPStore server is assumed to be hosted + on ``hostname:port``. + + By default, the TCPStore server uses the asynchronous implementation + ``LibUVStoreDaemon`` which utilizes libuv. + + If ``torchelastic_use_agent_store()`` is ``True``, then it is assumed that + the agent leader (node rank 0) hosts the TCPStore server (for which the + endpoint is specified by the given ``hostname:port``). Hence + ALL ranks will create and return a TCPStore client (e.g. ``start_daemon=False``). + + If ``torchelastic_use_agent_store()`` is ``False``, then rank 0 will host + the TCPStore (with multi-tenancy) and it is assumed that rank 0's hostname + and port are correctly passed via ``hostname`` and ``port``. All + non-zero ranks will create and return a TCPStore client. + """ + # check if port is uint16_t + if not 0 <= port < 2**16: + raise ValueError(f"port must have value from 0 to 65535 but was {port}.") + + if _torchelastic_use_agent_store(): + # We create a new TCPStore for every retry so no need to add prefix for each attempt. 
+ return TCPStore( + host_name=hostname, + port=port, + world_size=world_size, + is_master=False, + timeout=timeout, + ) + else: + start_daemon = rank == 0 + return TCPStore( + host_name=hostname, + port=port, + world_size=world_size, + is_master=start_daemon, + timeout=timeout, + multi_tenant=True, + use_libuv=use_libuv, + ) + + +def _tcp_rendezvous_handler( + url: str, timeout: timedelta = default_pg_timeout, **kwargs +): + def _error(msg): + return _rendezvous_error("tcp:// rendezvous: " + msg) + + result = urlparse(url) + if result.port is None: + raise _error("port number missing") + query_dict = _query_to_dict(result.query) + if "rank" not in query_dict: + raise _error("rank parameter missing") + if "world_size" not in query_dict: + raise _error("world size parameter missing") + + rank = int(query_dict["rank"]) + world_size = int(query_dict["world_size"]) + use_libuv = _get_use_libuv_from_query_dict(query_dict) + + if result.hostname is None: + raise AssertionError("hostname cannot be None") + + store = _create_c10d_store( + result.hostname, result.port, rank, world_size, timeout, use_libuv + ) + + yield (store, rank, world_size) + + # If this configuration is invalidated, there is nothing we can do about it + raise RuntimeError("Unable to perform re-rendezvous using tcp:// method") + + +def _env_rendezvous_handler( + url: str, timeout: timedelta = default_pg_timeout, **kwargs +): + def _error(msg): + return _rendezvous_error("env:// rendezvous: " + msg) + + def _env_error(var): + return _error(f"environment variable {var} expected, but not set") + + def _get_env_or_raise(env_var: str) -> str: + env_val = os.environ.get(env_var, None) + if not env_val: + raise _env_error(env_var) + else: + return env_val + + result = urlparse(url) + query_dict = _query_to_dict(result.query) + + rank: int + world_size: int + master_port: int + master_addr: str + + if "rank" in query_dict: + rank = int(query_dict["rank"]) + else: + rank = int(_get_env_or_raise("RANK")) + + if 
"world_size" in query_dict: + world_size = int(query_dict["world_size"]) + else: + world_size = int(_get_env_or_raise("WORLD_SIZE")) + + master_addr = _get_env_or_raise("MASTER_ADDR") + master_port = int(_get_env_or_raise("MASTER_PORT")) + use_libuv = _get_use_libuv_from_query_dict(query_dict) + + store = _create_c10d_store( + master_addr, master_port, rank, world_size, timeout, use_libuv + ) + + yield (store, rank, world_size) + + # If this configuration is invalidated, there is nothing we can do about it + raise RuntimeError("Unable to perform re-rendezvous using env:// method") + + +register_rendezvous_handler("tcp", _tcp_rendezvous_handler) +register_rendezvous_handler("env", _env_rendezvous_handler) +register_rendezvous_handler("file", _file_rendezvous_handler) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/run.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/run.py new file mode 100644 index 0000000000000000000000000000000000000000..3d8d0fb64276eb4ed8f53dea9a62a55d7c69f14f --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/run.py @@ -0,0 +1,995 @@ +#!/usr/bin/env python3 +# mypy: allow-untyped-defs + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Module ``torch.distributed.run``. + +``torch.distributed.run`` is a module that spawns up multiple distributed +training processes on each of the training nodes. + +``torchrun`` is a python +`console script `_ +to the main module +`torch.distributed.run `_ +declared in the ``entry_points`` configuration in +`setup.py `_. +It is equivalent to invoking ``python -m torch.distributed.run``. + +``torchrun`` can be used for single-node distributed training, in which one or +more processes per node will be spawned. 
It can be used for either +CPU training or GPU training. If it is used for GPU training, +each distributed process will be operating on a single GPU. This can achieve +well-improved single-node training performance. ``torchrun`` can also be used in +multi-node distributed training, by spawning up multiple processes on each node +for well-improved multi-node distributed training performance as well. +This will especially be beneficial for systems with multiple Infiniband +interfaces that have direct-GPU support, since all of them can be utilized for +aggregated communication bandwidth. + +In both cases of single-node distributed training or multi-node distributed +training, ``torchrun`` will launch the given number of processes per node +(``--nproc-per-node``). If used for GPU training, this number needs to be less +or equal to the number of GPUs on the current system (``nproc_per_node``), +and each process will be operating on a single GPU from *GPU 0 to +GPU (nproc_per_node - 1)*. + +.. versionchanged:: 2.0.0 + + ``torchrun`` will pass the ``--local-rank=`` argument to your script. + From PyTorch 2.0.0 onwards, the dashed ``--local-rank`` is preferred over the + previously used underscored ``--local_rank``. + + For backward compatibility, it may be necessary for users to handle both + cases in their argument parsing code. This means including both ``"--local-rank"`` + and ``"--local_rank"`` in the argument parser. If only ``"--local_rank"`` is + provided, ``torchrun`` will trigger an error: "error: unrecognized arguments: + --local-rank=". For training code that only supports PyTorch 2.0.0+, + including ``"--local-rank"`` should be sufficient. 
+ + :: + + >>> # xdoctest: +SKIP + >>> import argparse + >>> parser = argparse.ArgumentParser() + >>> parser.add_argument("--local-rank", "--local_rank", type=int) + >>> args = parser.parse_args() + +Usage +----- + +Single-node multi-worker +++++++++++++++++++++++++ + +:: + + torchrun + --standalone + --nnodes=1 + --nproc-per-node=$NUM_TRAINERS + YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...) + +.. note:: ``--nproc-per-node`` may be + ``"gpu"`` (spawn one process per GPU), + ``"cpu"`` (spawn one process per CPU), + ``"xpu"`` (spawn one process per XPU), + ``"auto"`` (equivalent to ``"gpu"`` if CUDA is available, + else equivalent to ``"xpu"`` if XPU is available, + else equivalent to ``"cpu"``), + or an integer specifying the number of processes. + See `torch.distributed.run.determine_local_world_size + `_ + for more details. + +Stacked single-node multi-worker +++++++++++++++++++++++++++++++++ + +To run multiple instances (separate jobs) of single-node, multi-worker on the +same host, we need to make sure that each instance (job) is +setup on different ports to avoid port conflicts (or worse, two jobs being merged +as a single job). To do this you have to run with ``--rdzv-backend=c10d`` +and specify a different port by setting ``--rdzv-endpoint=localhost:$PORT_k``. +For ``--nodes=1``, its often convenient to let ``torchrun`` pick a free random +port automatically instead of manually assigning different ports for each run. + +:: + + torchrun + --rdzv-backend=c10d + --rdzv-endpoint=localhost:0 + --nnodes=1 + --nproc-per-node=$NUM_TRAINERS + YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...) 
+ + +Fault tolerant (fixed sized number of workers, no elasticity, tolerates 3 failures) ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +:: + + torchrun + --nnodes=$NUM_NODES + --nproc-per-node=$NUM_TRAINERS + --max-restarts=3 + --rdzv-id=$JOB_ID + --rdzv-backend=c10d + --rdzv-endpoint=$HOST_NODE_ADDR + YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...) + +``HOST_NODE_ADDR``, in form [:] (e.g. node1.example.com:29400), specifies the node and +the port on which the C10d rendezvous backend should be instantiated and hosted. It can be any +node in your training cluster, but ideally you should pick a node that has a high bandwidth. + +.. note:: + If no port number is specified ``HOST_NODE_ADDR`` defaults to 29400. + +Elastic (``min=1``, ``max=4``, tolerates up to 3 membership changes or failures) +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +:: + + torchrun + --nnodes=1:4 + --nproc-per-node=$NUM_TRAINERS + --max-restarts=3 + --rdzv-id=$JOB_ID + --rdzv-backend=c10d + --rdzv-endpoint=$HOST_NODE_ADDR + YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...) + +``HOST_NODE_ADDR``, in form [:] (e.g. node1.example.com:29400), specifies the node and +the port on which the C10d rendezvous backend should be instantiated and hosted. It can be any +node in your training cluster, but ideally you should pick a node that has a high bandwidth. + +.. note:: + If no port number is specified ``HOST_NODE_ADDR`` defaults to 29400. + +Note on rendezvous backend +-------------------------- + +For multi-node training you need to specify: + +1. ``--rdzv-id``: A unique job id (shared by all nodes participating in the job) +2. ``--rdzv-backend``: An implementation of + :py:class:`torch.distributed.elastic.rendezvous.RendezvousHandler` +3. ``--rdzv-endpoint``: The endpoint where the rendezvous backend is running; usually in form + ``host:port``. 
+ +Currently ``c10d`` (recommended), ``etcd-v2``, and ``etcd`` (legacy) rendezvous backends are +supported out of the box. To use ``etcd-v2`` or ``etcd``, setup an etcd server with the ``v2`` api +enabled (e.g. ``--enable-v2``). + +.. warning:: + ``etcd-v2`` and ``etcd`` rendezvous use etcd API v2. You MUST enable the v2 API on the etcd + server. Our tests use etcd v3.4.3. + +.. warning:: + For etcd-based rendezvous we recommend using ``etcd-v2`` over ``etcd`` which is functionally + equivalent, but uses a revised implementation. ``etcd`` is in maintenance mode and will be + removed in a future version. + +Definitions +----------- + +1. ``Node`` - A physical instance or a container; maps to the unit that the job manager works with. + +2. ``Worker`` - A worker in the context of distributed training. + +3. ``WorkerGroup`` - The set of workers that execute the same function (e.g. trainers). + +4. ``LocalWorkerGroup`` - A subset of the workers in the worker group running on the same node. + +5. ``RANK`` - The rank of the worker within a worker group. + +6. ``WORLD_SIZE`` - The total number of workers in a worker group. + +7. ``LOCAL_RANK`` - The rank of the worker within a local worker group. + +8. ``LOCAL_WORLD_SIZE`` - The size of the local worker group. + +9. ``rdzv_id`` - A user-defined id that uniquely identifies the worker group for a job. This id is + used by each node to join as a member of a particular worker group. + +9. ``rdzv_backend`` - The backend of the rendezvous (e.g. ``c10d``). This is typically a strongly + consistent key-value store. + +10. ``rdzv_endpoint`` - The rendezvous backend endpoint; usually in form ``:``. + +A ``Node`` runs ``LOCAL_WORLD_SIZE`` workers which comprise a ``LocalWorkerGroup``. The union of +all ``LocalWorkerGroups`` in the nodes in the job comprise the ``WorkerGroup``. + +Environment Variables +--------------------- + +The following environment variables are made available to you in your script: + +1. 
``LOCAL_RANK`` - The local rank. + +2. ``RANK`` - The global rank. + +3. ``GROUP_RANK`` - The rank of the worker group. A number between 0 and ``max_nnodes``. When + running a single worker group per node, this is the rank of the node. + +4. ``ROLE_RANK`` - The rank of the worker across all the workers that have the same role. The role + of the worker is specified in the ``WorkerSpec``. + +5. ``LOCAL_WORLD_SIZE`` - The local world size (e.g. number of workers running locally); equals to + ``--nproc-per-node`` specified on ``torchrun``. + +6. ``WORLD_SIZE`` - The world size (total number of workers in the job). + +7. ``ROLE_WORLD_SIZE`` - The total number of workers that was launched with the same role specified + in ``WorkerSpec``. + +8. ``MASTER_ADDR`` - The FQDN of the host that is running worker with rank 0; used to initialize + the Torch Distributed backend. + +9. ``MASTER_PORT`` - The port on the ``MASTER_ADDR`` that can be used to host the C10d TCP store. + +10. ``TORCHELASTIC_RESTART_COUNT`` - The number of worker group restarts so far. + +11. ``TORCHELASTIC_MAX_RESTARTS`` - The configured maximum number of restarts. + +12. ``TORCHELASTIC_RUN_ID`` - Equal to the rendezvous ``run_id`` (e.g. unique job id). + +13. ``PYTHON_EXEC`` - System executable override. If provided, the python user script will + use the value of ``PYTHON_EXEC`` as executable. The `sys.executable` is used by default. + +Deployment +---------- + +1. (Not needed for the C10d backend) Start the rendezvous backend server and get the endpoint (to be + passed as ``--rdzv-endpoint`` to ``torchrun``) + +2. Single-node multi-worker: Start ``torchrun`` on the host to start the agent process which + creates and monitors a local worker group. + +3. Multi-node multi-worker: Start ``torchrun`` with the same arguments on all the nodes + participating in training. + +When using a job/cluster manager, the entry point command to the multi-node job should be ``torchrun``. 
+ +Failure Modes +------------- + +1. Worker failure: For a training job with ``n`` workers, if ``k<=n`` workers fail all workers + are stopped and restarted up to ``max_restarts``. + +2. Agent failure: An agent failure results in a local worker group failure. It is up to the job + manager to fail the entire job (gang semantics) or attempt to replace the node. Both behaviors + are supported by the agent. + +3. Node failure: Same as agent failure. + +Membership Changes +------------------ + +1. Node departure (scale-down): The agent is notified of the departure, all existing workers are + stopped, a new ``WorkerGroup`` is formed, and all workers are started with a new ``RANK`` and + ``WORLD_SIZE``. + +2. Node arrival (scale-up): The new node is admitted to the job, all existing workers are stopped, + a new ``WorkerGroup`` is formed, and all workers are started with a new ``RANK`` and + ``WORLD_SIZE``. + +Important Notices +----------------- + +1. This utility and multi-process distributed (single-node or + multi-node) GPU training currently only achieves the best performance using + the NCCL distributed backend. Thus NCCL backend is the recommended backend to + use for GPU training. + +2. The environment variables necessary to initialize a Torch process group are provided to you by + this module, no need for you to pass ``RANK`` manually. To initialize a process group in your + training script, simply run: + +:: + + >>> # xdoctest: +SKIP("stub") + >>> import torch.distributed as dist + >>> dist.init_process_group(backend="gloo|nccl") + +3. In your training program, you can either use regular distributed functions + or use :func:`torch.nn.parallel.DistributedDataParallel` module. If your + training program uses GPUs for training and you would like to use + :func:`torch.nn.parallel.DistributedDataParallel` module, + here is how to configure it. 
+ +:: + + local_rank = int(os.environ["LOCAL_RANK"]) + model = torch.nn.parallel.DistributedDataParallel( + model, device_ids=[local_rank], output_device=local_rank + ) + +Please ensure that ``device_ids`` argument is set to be the only GPU device id +that your code will be operating on. This is generally the local rank of the +process. In other words, the ``device_ids`` needs to be ``[int(os.environ("LOCAL_RANK"))]``, +and ``output_device`` needs to be ``int(os.environ("LOCAL_RANK"))`` in order to use this +utility + + +4. On failures or membership changes ALL surviving workers are killed immediately. Make sure to + checkpoint your progress. The frequency of checkpoints should depend on your job's tolerance + for lost work. + +5. This module only supports homogeneous ``LOCAL_WORLD_SIZE``. That is, it is assumed that all + nodes run the same number of local workers (per role). + +6. ``RANK`` is NOT stable. Between restarts, the local workers on a node can be assigned a + different range of ranks than before. NEVER hard code any assumptions about the stable-ness of + ranks or some correlation between ``RANK`` and ``LOCAL_RANK``. + +7. When using elasticity (``min_size!=max_size``) DO NOT hard code assumptions about + ``WORLD_SIZE`` as the world size can change as nodes are allowed to leave and join. + +8. It is recommended for your script to have the following structure: + +:: + + def main(): + load_checkpoint(checkpoint_path) + initialize() + train() + + + def train(): + for batch in iter(dataset): + train_step(batch) + + if should_checkpoint: + save_checkpoint(checkpoint_path) + +9. (Recommended) On worker errors, this tool will summarize the details of the error + (e.g. time, rank, host, pid, traceback, etc). On each node, the first error (by timestamp) + is heuristically reported as the "Root Cause" error. 
To get tracebacks as part of this + error summary print out, you must decorate your main entrypoint function in your + training script as shown in the example below. If not decorated, then the summary + will not include the traceback of the exception and will only contain the exitcode. + For details on torchelastic error handling see: https://pytorch.org/docs/stable/elastic/errors.html + +:: + + from torch.distributed.elastic.multiprocessing.errors import record + + + @record + def main(): + # do train + pass + + + if __name__ == "__main__": + main() +""" # noqa: E501 + +import os +import sys +import uuid +from argparse import ArgumentParser, REMAINDER +from collections.abc import Callable +from importlib import metadata + +import torch +from torch.distributed.argparse_util import check_env, env +from torch.distributed.elastic.multiprocessing import DefaultLogsSpecs, LogsSpecs, Std +from torch.distributed.elastic.multiprocessing.errors import record +from torch.distributed.elastic.rendezvous.utils import _parse_rendezvous_config +from torch.distributed.elastic.utils import macros +from torch.distributed.elastic.utils.logging import get_logger +from torch.distributed.launcher.api import elastic_launch, LaunchConfig +from torch.numa.binding import ( + AffinityMode as _AffinityMode, # Signify as private with _ + NumaOptions as _NumaOptions, +) +from torch.utils.backend_registration import _get_custom_mod_func + + +logger = get_logger(__name__) + + +def get_args_parser() -> ArgumentParser: + """Parse the command line options.""" + parser = ArgumentParser(description="Torch Distributed Elastic Training Launcher") + + def comma_separated_list(value): + placeholder = "" + value = value.replace(",,", placeholder) + items = value.split(",") + items = [item.replace(placeholder, ",") for item in items] + return items + + # + # Worker/node size related arguments. 
+ # + + parser.add_argument( + "--nnodes", + action=env, + type=str, + default="1:1", + help="Number of nodes, or the range of nodes in form :.", + ) + parser.add_argument( + "--nproc-per-node", + "--nproc_per_node", + action=env, + type=str, + default="1", + help="Number of workers per node; supported values: [auto, cpu, gpu, xpu, int].", + ) + + # + # Rendezvous related arguments + # + + parser.add_argument( + "--rdzv-backend", + "--rdzv_backend", + action=env, + type=str, + default="static", + help="Rendezvous backend.", + ) + parser.add_argument( + "--rdzv-endpoint", + "--rdzv_endpoint", + action=env, + type=str, + default="", + help="Rendezvous backend endpoint; usually in form :.", + ) + parser.add_argument( + "--rdzv-id", + "--rdzv_id", + action=env, + type=str, + default="none", + help="User-defined group id.", + ) + parser.add_argument( + "--rdzv-conf", + "--rdzv_conf", + action=env, + type=str, + default="", + help="Additional rendezvous configuration (=,=,...).", + ) + parser.add_argument( + "--standalone", + action=check_env, + help="Start a local standalone rendezvous backend that is represented by a C10d TCP store " + "on a free port. Useful when launching single-node, multi-worker job. If specified " + "--rdzv-backend, --rdzv-endpoint, --rdzv-id are auto-assigned and any explicitly set values " + "are ignored.", + ) + + # + # User-code launch related arguments. 
+ # + + parser.add_argument( + "--max-restarts", + "--max_restarts", + action=env, + type=int, + default=0, + help="Maximum number of worker group restarts before failing.", + ) + parser.add_argument( + "--monitor-interval", + "--monitor_interval", + action=env, + type=float, + default=0.1, + help="Interval, in seconds, to monitor the state of workers.", + ) + parser.add_argument( + "--start-method", + "--start_method", + action=env, + type=str, + default="spawn", + choices=["spawn", "fork", "forkserver"], + help="Multiprocessing start method to use when creating workers.", + ) + parser.add_argument( + "--event-log-handler", + "--event_log_handler", + action=env, + type=str, + default="null", + help="name of a registered event logging handler (see: https://docs.pytorch.org/docs/stable/elastic/events.html)", + ) + parser.add_argument( + "--role", + action=env, + type=str, + default="default", + help="User-defined role for the workers.", + ) + parser.add_argument( + "-m", + "--module", + action=check_env, + help="Change each process to interpret the launch script as a Python module, executing " + "with the same behavior as 'python -m'.", + ) + parser.add_argument( + "--no-python", + "--no_python", + action=check_env, + help="Skip prepending the training script with 'python' - just execute it directly. Useful " + "when the script is not a Python script.", + ) + + parser.add_argument( + "--run-path", + "--run_path", + action=check_env, + help="Run the training script with runpy.run_path in the same interpreter." + " Script must be provided as an abs path (e.g. /abs/path/script.py)." + " Takes precedence over --no-python.", + ) + parser.add_argument( + "--log-dir", + "--log_dir", + action=env, + type=str, + default=None, + help="Base directory to use for log files (e.g. /var/log/torch/elastic). 
The same " + "directory is reused for multiple runs (a unique job-level sub-directory is created with " + "rdzv_id as the prefix).", + ) + parser.add_argument( + "-r", + "--redirects", + action=env, + type=str, + default="0", + help="Redirect std streams into a log file in the log directory (e.g. [-r 3] redirects " + "both stdout+stderr for all workers, [-r 0:1,1:2] redirects stdout for local rank 0 and " + "stderr for local rank 1).", + ) + parser.add_argument( + "-t", + "--tee", + action=env, + type=str, + default="0", + help="Tee std streams into a log file and also to console (see --redirects for format).", + ) + + parser.add_argument( + "--local-ranks-filter", + "--local_ranks_filter", + action=env, + type=str, + default="", + help="Only show logs from specified ranks in console (e.g. [--local_ranks_filter=0,1,2] will " + "only show logs from rank 0, 1 and 2). This will only apply to stdout and stderr, not to" + "log files saved via --redirect or --tee", + ) + + parser.add_argument( + "--duplicate-stdout-filters", + "--duplicate_stdout_filters", + action=env, + type=comma_separated_list, + default=[], + help="Duplicates logs streamed to stdout to another specified file with a list of filters (e.g. " + "[--duplicate_stdout_filters 'apple,orange'] will duplicate log lines matching 'apple' " + "OR 'orange'. An empty filters list won't duplicate any lines. Use double comma to escape a comma) ", + ) + + parser.add_argument( + "--duplicate-stderr-filters", + "--duplicate_stderr_filters", + action=env, + type=comma_separated_list, + default=[], + help="Duplicates logs streamed to stderr to another specified file with a list of filters (e.g. " + "[--duplicate_stdout_filters 'apple,orange'] will duplicate log lines matching 'apple' " + "OR 'orange'. An empty filters list won't duplicate any lines. Use double comma to escape a comma) ", + ) + + # + # Backwards compatible parameters with caffe2.distributed.launch. 
+ # + + parser.add_argument( + "--node-rank", + "--node_rank", + type=int, + action=env, + default=0, + help="Rank of the node for multi-node distributed training.", + ) + parser.add_argument( + "--master-addr", + "--master_addr", + default="127.0.0.1", + type=str, + action=env, + help="Address of the master node (rank 0) that only used for static rendezvous. It should " + "be either the IP address or the hostname of rank 0. For single node multi-proc training " + "the --master-addr can simply be 127.0.0.1; IPv6 should have the pattern " + "`[0:0:0:0:0:0:0:1]`.", + ) + parser.add_argument( + "--master-port", + "--master_port", + default=29500, + type=int, + action=env, + help="Port on the master node (rank 0) to be used for communication during distributed " + "training. It is only used for static rendezvous.", + ) + parser.add_argument( + "--local-addr", + "--local_addr", + default=None, + type=str, + action=env, + help="Address of the local node. If specified, will use the given address for connection. " + "Else, will look up the local node address instead. Else, it will be default to local " + "machine's FQDN.", + ) + + parser.add_argument( + "--logs-specs", + "--logs_specs", + default=None, + type=str, + help="torchrun.logs_specs group entrypoint name, value must be type of LogsSpecs. " + "Can be used to override custom logging behavior.", + ) + + parser.add_argument( + "--numa-binding", + "--numa_binding", + type=str, + choices=[mode.value for mode in _AffinityMode], + default=None, + help=""" + If provided, we will affinitize the worker processes based on NUMA nodes + for better performance. (E.g., preferring to allocate memory locally and run on CPUs on the + same NUMA node.) + + NOTE: This is currently only supported for GPUs, and we assume + that the LOCAL_RANK process corresponds to the GPU with index LOCAL_RANK. If this is not + accurate for your workload, this feature may be a pessimization. 
+ + Available options are: + - node: Processes are bound to cpu cores within a NUMA node. This is a good starting point, + but other options may perform even slightly better in some cases. + - socket: Processes are bound to cpu cores within a socket. + - exclusive: Processes are bound to exclusive sets of cpu cores within a NUMA node. + - core-complex: Processes are bound to cpu cores in a core-complex. + NOTE: The core-complex option might not achieve optimal performance on architectures + featuring a single L3 cache per socket.""", + ) + + parser.add_argument( + "--signals-to-handle", + "--signals_to_handle", + action=env, + type=str, + default="SIGTERM,SIGINT,SIGHUP,SIGQUIT", + help="Comma-separated list of signals to handle and forward to subprocesses. " + "Default: SIGTERM,SIGINT,SIGHUP,SIGQUIT. " + "Common additional signals: SIGUSR1,SIGUSR2 (used in SLURM environments).", + ) + + parser.add_argument( + "--virtual-local-rank", + "--virtual_local_rank", + action=check_env, + help="Enable virtual local rank mode for workers. When enabled, LOCAL_RANK is set to 0 " + "for all workers and CUDA_VISIBLE_DEVICES is adjusted so each worker accesses its " + "assigned GPU at device index 0.", + ) + + # + # Positional arguments. + # + + parser.add_argument( + "training_script", + type=str, + help="Full path to the (single GPU) training program/script to be launched in parallel, " + "followed by all the arguments for the training script.", + ) + + # Rest from the training program. 
+ parser.add_argument("training_script_args", nargs=REMAINDER) + + return parser + + +def parse_args(args): + parser = get_args_parser() + return parser.parse_args(args) + + +def parse_min_max_nnodes(nnodes: str): + arr = nnodes.split(":") + + if len(arr) == 1: + min_nodes = max_nodes = int(arr[0]) + elif len(arr) == 2: + min_nodes = int(arr[0]) + max_nodes = int(arr[1]) + else: + raise RuntimeError(f'nnodes={nnodes} is not in "MIN:MAX" format') # noqa: E231 + + return min_nodes, max_nodes + + +def determine_local_world_size(nproc_per_node: str): + try: + logger.info("Using nproc_per_node=%s.", nproc_per_node) + return int(nproc_per_node) + except ValueError as e: + if nproc_per_node == "cpu": + num_proc = os.cpu_count() + device_type = "cpu" + elif nproc_per_node == "gpu": + if not torch.cuda.is_available(): + raise ValueError("Cuda is not available.") from e + device_type = "gpu" + num_proc = torch.cuda.device_count() + elif nproc_per_node == "xpu": + if not torch.xpu.is_available(): + raise ValueError("Xpu is not available.") from e + device_type = "xpu" + num_proc = torch.xpu.device_count() + elif nproc_per_node == torch._C._get_privateuse1_backend_name(): + if not _get_custom_mod_func("is_available")(): + raise ValueError(f"{nproc_per_node} is not available.") from e + device_type = nproc_per_node + num_proc = _get_custom_mod_func("device_count")() + elif nproc_per_node == "auto": + if torch.accelerator.is_available(): + num_proc = torch.accelerator.device_count() + device_type = torch.accelerator.current_accelerator().type # type: ignore[union-attr] + else: + num_proc = os.cpu_count() + device_type = "cpu" + else: + raise ValueError( + f"Unsupported nproc_per_node value: {nproc_per_node}" + ) from e + + logger.info( + "Using nproc_per_node=%s, setting nproc_per_node to %s since the instance has %s %s", + nproc_per_node, + num_proc, + num_proc, + device_type, + ) + return num_proc + + +def get_rdzv_endpoint(args): + if args.rdzv_backend == "static" and not 
args.rdzv_endpoint: + return f"{args.master_addr}:{args.master_port}" # noqa: E231 + return args.rdzv_endpoint + + +def get_use_env(args) -> bool: + """ + Retrieve ``use_env`` from the args. + + ``use_env`` is a legacy argument, if ``use_env`` is False, the + ``--node-rank`` argument will be transferred to all worker processes. + ``use_env`` is only used by the ``torch.distributed.launch`` and will + be deprecated in future releases. + """ + if not hasattr(args, "use_env"): + return True + return args.use_env + + +def _get_logs_specs_class(logs_specs_name: str | None) -> type[LogsSpecs]: + """ + Attempts to load `torchrun.logs_spec` entrypoint with key of `logs_specs_name` param. + Provides plugin mechanism to provide custom implementation of LogsSpecs. + + Returns `DefaultLogsSpecs` when logs_spec_name is None. + Raises ValueError when entrypoint for `logs_spec_name` can't be found in entrypoints. + """ + logs_specs_cls = None + if logs_specs_name is not None: + eps = metadata.entry_points() + group = eps.select(group="torchrun.logs_specs") + if group.select(name=logs_specs_name): + logs_specs_cls = group[logs_specs_name].load() + + if logs_specs_cls is None: + raise ValueError( + f"Could not find entrypoint under 'torchrun.logs_specs[{logs_specs_name}]' key" + ) + + logger.info( + "Using logs_spec '%s' mapped to %s", logs_specs_name, str(logs_specs_cls) + ) + else: + logs_specs_cls = DefaultLogsSpecs + + return logs_specs_cls + + +def config_from_args(args) -> tuple[LaunchConfig, Callable | str, list[str]]: + # If ``args`` not passed, defaults to ``sys.argv[:1]`` + min_nodes, max_nodes = parse_min_max_nnodes(args.nnodes) + if not (0 < min_nodes <= max_nodes): + raise AssertionError( + f"min_nodes must be > 0 and <= max_nodes, got min_nodes={min_nodes}, max_nodes={max_nodes}" + ) + if args.max_restarts < 0: + raise AssertionError("max_restarts must be >= 0") + + if ( + hasattr(args, "master_addr") + and args.rdzv_backend != "static" + and not args.rdzv_endpoint + 
): + logger.warning( + "master_addr is only used for static rdzv_backend and when rdzv_endpoint " + "is not specified." + ) + + nproc_per_node = determine_local_world_size(args.nproc_per_node) + if "OMP_NUM_THREADS" not in os.environ and nproc_per_node > 1: + omp_num_threads = 1 + logger.warning( + "\n*****************************************\n" + "Setting OMP_NUM_THREADS environment variable for each process to be " + "%s in default, to avoid your system being overloaded, " + "please further tune the variable for optimal performance in " + "your application as needed. \n" + "*****************************************", + omp_num_threads, + ) + # This env variable will be passed down to the subprocesses + os.environ["OMP_NUM_THREADS"] = str(omp_num_threads) + + log_line_prefix_template = os.getenv("TORCHELASTIC_LOG_LINE_PREFIX_TEMPLATE") + + rdzv_configs = _parse_rendezvous_config(args.rdzv_conf) + + if args.rdzv_backend == "static": + rdzv_configs["rank"] = args.node_rank + + rdzv_endpoint = get_rdzv_endpoint(args) + + ranks: set[int] | None = None + if args.local_ranks_filter: + try: + ranks = set(map(int, args.local_ranks_filter.split(","))) + if not ranks: + raise AssertionError("ranks set cannot be empty") + except Exception as e: + raise ValueError( + "--local_ranks_filter must be a comma-separated list of integers e.g. 
--local_ranks_filter=0,1,2" + ) from e + + logs_specs_cls: type[LogsSpecs] = _get_logs_specs_class(args.logs_specs) + # pyrefly: ignore [bad-instantiation] + logs_specs = logs_specs_cls( + log_dir=args.log_dir, + redirects=Std.from_str(args.redirects), + tee=Std.from_str(args.tee), + local_ranks_filter=ranks, + ) + numa_options = ( + None + if args.numa_binding is None + else _NumaOptions(affinity_mode=_AffinityMode(args.numa_binding)) + ) + + config = LaunchConfig( + min_nodes=min_nodes, + max_nodes=max_nodes, + nproc_per_node=nproc_per_node, + run_id=args.rdzv_id, + role=args.role, + rdzv_endpoint=rdzv_endpoint, + rdzv_backend=args.rdzv_backend, + rdzv_configs=rdzv_configs, + max_restarts=args.max_restarts, + monitor_interval=args.monitor_interval, + start_method=args.start_method, + log_line_prefix_template=log_line_prefix_template, + local_addr=args.local_addr, + logs_specs=logs_specs, + event_log_handler=args.event_log_handler, + numa_options=numa_options, + signals_to_handle=args.signals_to_handle, + duplicate_stdout_filters=args.duplicate_stdout_filters, + duplicate_stderr_filters=args.duplicate_stderr_filters, + virtual_local_rank=args.virtual_local_rank, + ) + + with_python = not args.no_python + cmd: Callable | str + cmd_args = [] + use_env = get_use_env(args) + if args.run_path: + cmd = run_script_path + cmd_args.append(args.training_script) + else: + if with_python: + cmd = os.getenv("PYTHON_EXEC", sys.executable) + cmd_args.append("-u") + if args.module: + cmd_args.append("-m") + cmd_args.append(args.training_script) + else: + if args.module: + raise ValueError( + "Don't use both the '--no-python' flag" + " and the '--module' flag at the same time." 
+ ) + cmd = args.training_script + if not use_env: + cmd_args.append(f"--local-rank={macros.local_rank}") + cmd_args.extend(args.training_script_args) + + return config, cmd, cmd_args + + +def run_script_path(training_script: str, *training_script_args: str): + """ + Run the provided `training_script` from within this interpreter. + + Usage: `script_as_function("/abs/path/to/script.py", "--arg1", "val1")` + """ + import runpy + import sys + + sys.argv = [training_script] + [*training_script_args] + runpy.run_path(sys.argv[0], run_name="__main__") + + +def run(args): + torch.multiprocessing._set_thread_name("pt_elastic") + + if args.standalone: + args.rdzv_backend = "c10d" + args.rdzv_endpoint = "localhost:0" + args.rdzv_id = str(uuid.uuid4()) + logger.info( + "\n**************************************\n" + "Rendezvous info:\n" + "--rdzv-backend=%s " + "--rdzv-endpoint=%s " + "--rdzv-id=%s\n" + "**************************************\n", + args.rdzv_backend, + args.rdzv_endpoint, + args.rdzv_id, + ) + + config, cmd, cmd_args = config_from_args(args) + elastic_launch( + config=config, + entrypoint=cmd, + )(*cmd_args) + + +@record +def main(args=None): + args = parse_args(args) + run(args) + + +if __name__ == "__main__": + main() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9422d05bf7e7d5c10b1e9d7bdaf56c21af17019b --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/distributed/utils.py @@ -0,0 +1,381 @@ +# mypy: allow-untyped-defs +import dataclasses +import traceback +from collections import OrderedDict +from collections.abc import Callable, Container +from typing import Any, Optional, overload, TypeVar + +import torch +import torch.distributed as dist +from torch import nn +from torch.nn.utils.rnn import PackedSequence + + 
+__all__ = [] # type: ignore[var-annotated] + + +def _pack_kwargs(*args: Any, **kwargs: Any) -> tuple[tuple[Any, ...], tuple[str, ...]]: + """ + Turn argument list into separate key list and value list (unpack_kwargs does the opposite). + + Inspiration: https://github.com/facebookresearch/fairscale/blob/eeb6684/fairscale/internal/containers.py#L70 + Usage:: + + kwarg_keys, flat_args = pack_kwargs(1, 2, a=3, b=4) + assert kwarg_keys == ("a", "b") + assert flat_args == (1, 2, 3, 4) + args, kwargs = unpack_kwargs(kwarg_keys, flat_args) + assert args == (1, 2) + assert kwargs == {"a": 3, "b": 4} + Returns: + Tuple[Tuple[Any, ...], Tuple[str, ...]]: The first tuple element gives + gives both positional args and kwarg values, where the positional args + proceed kwarg values and kwarg values are ordered consistently with the + kwarg keys. The second tuple element gives the kwarg keys. + The second tuple element's length is at most the first tuple element's length. + """ + kwarg_keys: list[str] = [] + flat_args: list[Any] = list(args) + for k, v in kwargs.items(): + kwarg_keys.append(k) + flat_args.append(v) + + return tuple(flat_args), tuple(kwarg_keys) + + +def _cast_forward_inputs( + dtype: torch.dtype | None, + *args: Any, + **kwargs: Any, +) -> tuple[Any, Any]: + """ + Cast floating point tensors in ``args`` and ``kwargs`` to ``input_dtype``. + + This respects the existing ``requires_grad`` on the tensors. + """ + if dtype is None: + return args, kwargs + + def cast_fn(x: torch.Tensor) -> torch.Tensor: + if not torch.is_floating_point(x) or x.dtype == dtype: + return x + + return x.to(dtype) + + return (_apply_to_tensors(cast_fn, args), _apply_to_tensors(cast_fn, kwargs)) + + +def _unpack_kwargs( + flat_args: tuple[Any, ...], kwarg_keys: tuple[str, ...] +) -> tuple[tuple[Any, ...], dict[str, Any]]: + """See _pack_kwargs.""" + if len(kwarg_keys) > len(flat_args): + raise AssertionError(f"too many keys {len(kwarg_keys)} vs. 
{len(flat_args)}") + if len(kwarg_keys) == 0: + return flat_args, {} + args = flat_args[: -len(kwarg_keys)] + kwargs = dict(zip(kwarg_keys, flat_args[-len(kwarg_keys) :])) + return args, kwargs + + +S = TypeVar("S", dict, list, tuple) +T = TypeVar("T", torch.Tensor, PackedSequence) + + +@overload +def _recursive_to( + inputs: S, target_device: torch.device, use_side_stream_for_tensor_copies: bool +) -> list[S]: ... + + +@overload +def _recursive_to( + inputs: T, target_device: torch.device, use_side_stream_for_tensor_copies: bool +) -> tuple[T]: ... + + +def _recursive_to(inputs, target_device, use_side_stream_for_tensor_copies): + r"""Recursively moves input to the target_device.""" + + def to_map(obj): + if isinstance(obj, (torch.Tensor, PackedSequence)): + device = obj.data.device if isinstance(obj, PackedSequence) else obj.device + if device == target_device: + return (obj,) + if not use_side_stream_for_tensor_copies: + return (obj.to(target_device),) + else: + # If the custom module is not registered to torch, stream is not used for acceleration + if device.type == "cpu": + return (obj.to(target_device),) + + from torch.nn.parallel._functions import _get_stream + + # Perform CPU -> target_device copies in a background stream. 
This code is + # motivated from similar logic in torch/nn/parallel/_functions.py + stream = _get_stream(target_device) + with stream: + output = obj.to(target_device) + # synchronize with the copy stream + with torch.accelerator.device_index(target_device.index): + current_stream = torch.accelerator.current_stream() + # Sync the current stream with the copy stream + current_stream.wait_stream(stream) + # Ensure tensor memory is not reused until work on + # main stream is complete + if isinstance(obj, PackedSequence): + output.data.record_stream(current_stream) # type: ignore[arg-type] + else: + if not isinstance(output, torch.Tensor): + raise AssertionError("output must be a torch.Tensor") + output.record_stream(current_stream) # type: ignore[arg-type] + return (output,) + + from torch.nn.parallel.scatter_gather import _is_namedtuple + + if _is_namedtuple(obj): + # pyrefly: ignore [no-matching-overload] + return [type(obj)(*args) for args in zip(*map(to_map, obj))] + if isinstance(obj, tuple) and len(obj) > 0: + # pyrefly: ignore [no-matching-overload] + return list(zip(*map(to_map, obj))) + if isinstance(obj, list) and len(obj) > 0: + # pyrefly: ignore [no-matching-overload] + return [list(i) for i in zip(*map(to_map, obj))] + if isinstance(obj, dict) and len(obj) > 0: + # pyrefly: ignore [no-matching-overload] + return [type(obj)(i) for i in zip(*map(to_map, obj.items()))] + return [obj] + + # Avoid reference cycle + try: + res = to_map(inputs) + finally: + to_map = None # type: ignore[assignment] + return res + + +def _p_assert(cond: Any, s: str, raise_assertion_error: bool = True) -> None: + """Alternate to ``assert`` when in the backward context to print the error message ``s`` since otherwise, it is swallowed.""" + if not cond: + print(s) + traceback.print_stack() + if raise_assertion_error: + raise AssertionError(s) + + +def _alloc_storage(tensor: torch.Tensor, size: torch.Size) -> None: + """ + Allocate storage for ``tensor`` with the given size. 
+ + Returns: + bool: ``True`` if this method allocated storage and ``False`` if the + storage was already allocated. + """ + with torch.no_grad(): + if not torch.distributed._functional_collectives.is_torchdynamo_compiling(): + already_allocated = tensor._typed_storage()._size() == size.numel() + if not already_allocated: + tensor_storage_size = tensor._typed_storage()._size() + _p_assert( + tensor_storage_size == 0, + "Tensor storage should have been resized to be 0 but got PLACEHOLDEr", + ) + tensor._typed_storage()._resize_(size.numel()) + + +def _free_storage(tensor: torch.Tensor): + """ + Frees the underlying storage of ``tensor``. + + Returns: + bool: ``True`` if the method freed the storage and ``False`` if the + storage was already freed. + """ + with torch.no_grad(): + if not torch.distributed._functional_collectives.is_torchdynamo_compiling(): + already_freed = tensor._typed_storage()._size() == 0 + if not already_freed: + _p_assert( + tensor.storage_offset() == 0, + "Freeing a tensor's storage is unsafe when it is not the sole occupant\n" + f"storage offset: {tensor.storage_offset()}\n" + f"storage size: {tensor._typed_storage()._size()}\n" + f"tensor shape: {tensor.shape}", + ) + tensor._typed_storage()._resize_(0) + + +Q = TypeVar("Q") +R = TypeVar("R", dict, list, tuple, set, OrderedDict, PackedSequence, Any) + + +@overload +def _apply_to_tensors( + fn: Callable[[torch.Tensor], Q], container: torch.Tensor +) -> Q: ... + + +@overload +def _apply_to_tensors(fn: Callable[[torch.Tensor], Any], container: R) -> R: ... 
+ + +def _apply_to_tensors(fn, container): + """Recursively apply to all tensor in different kinds of container types.""" + + def apply(x): + from torch.nn.parallel.scatter_gather import _is_namedtuple + + if isinstance(x, torch.Tensor): + return fn(x) + elif hasattr(x, "__dataclass_fields__"): + dc = dataclasses.replace(x) + changes = { + f.name: apply(getattr(dc, f.name)) for f in dataclasses.fields(dc) + } + return dataclasses.replace(dc, **changes) + elif isinstance(x, OrderedDict): + od = x.__class__() + for key, value in x.items(): + od[key] = apply(value) + return od + elif isinstance(x, PackedSequence): + apply(x.data) + return x + elif isinstance(x, dict): + return {key: apply(value) for key, value in x.items()} + elif _is_namedtuple(x): + res = (apply(el) for el in x) + return type(x)(*res) + elif isinstance(x, (list, tuple, set)): + return type(x)(apply(el) for el in x) + else: + return x + + return apply(container) + + +def _to_kwargs( + inputs: tuple[Any, ...], + kwargs: dict[str, Any] | None, + target_device: torch.device, + use_side_stream_for_tensor_copies: bool, +) -> tuple[tuple[Any, ...], tuple[dict[str, Any], ...]]: + moved_inputs = ( + _recursive_to(inputs, target_device, use_side_stream_for_tensor_copies) + if inputs + else [] + ) + moved_kwargs = ( + _recursive_to(kwargs, target_device, use_side_stream_for_tensor_copies) + if kwargs + else [] + ) + if len(moved_inputs) < len(moved_kwargs): + moved_inputs.extend([() for _ in range(len(moved_kwargs) - len(inputs))]) + elif len(moved_kwargs) < len(moved_inputs): + moved_kwargs.extend([{} for _ in range(len(moved_inputs) - len(moved_kwargs))]) + return tuple(moved_inputs), tuple(moved_kwargs) + + +def _verify_param_shape_across_processes( + process_group: dist.ProcessGroup, + tensors: list[torch.Tensor], + logger: Optional["dist.Logger"] = None, +): + return dist._verify_params_across_processes(process_group, tensors, logger) + + +def _sync_module_states( + module: nn.Module, + process_group: 
dist.ProcessGroup, + broadcast_bucket_size: int, + src: int, + params_and_buffers_to_ignore: Container[str], + broadcast_buffers: bool = True, +) -> None: + """ + Sync ``module``'s parameters and buffers state. + + Syncs ``module``'s parameters and buffers state so that all ranks contain + the same module state across all ranks. Note that this API assumes that all + parameter shapes are consistent before running the synchronization. This can + be checked with ``_verify_param_shape_across_processes``. + """ + module_states: list[torch.Tensor] = [] + for name, param in module.named_parameters(): + if name not in params_and_buffers_to_ignore: + module_states.append(param.detach()) + + if broadcast_buffers: + for name, buffer in module.named_buffers(): + if name not in params_and_buffers_to_ignore: + module_states.append(buffer.detach()) + + _sync_params_and_buffers(process_group, module_states, broadcast_bucket_size, src) + + +def _sync_params_and_buffers( + process_group: dist.ProcessGroup, + module_states: list[torch.Tensor], + broadcast_bucket_size: int, + src: int, +) -> None: + """Synchronize ``module_states`` (list of tensors) across all processes by broadcasting them from rank 0.""" + if len(module_states) > 0: + dist._broadcast_coalesced( + process_group, module_states, broadcast_bucket_size, src + ) + + +def _replace_by_prefix( + state_dict: dict[str, Any], + old_prefix: str, + new_prefix: str, +) -> None: + """ + Replace all keys that match a given old_prefix with a new_prefix (in-place). 
+ + Usage:: + + state_dict = {"layer.xyz": torch.tensor(1)} + replace_by_prefix_(state_dict, "layer.", "module.layer.") + assert state_dict == {"module.layer.xyz": torch.tensor(1)} + """ + if old_prefix == new_prefix: + raise ValueError("old_prefix and new_prefix must be distinct") + for key in list(state_dict.keys()): + if not key.startswith(old_prefix): + continue + new_key = new_prefix + key[len(old_prefix) :] + state_dict[new_key] = state_dict[key] + del state_dict[key] + + +def _data_ptr_allocated(tensor: torch.Tensor) -> bool: + return tensor.untyped_storage().data_ptr() > 0 + + +def _get_root_modules(modules: list[nn.Module]) -> list[nn.Module]: + """ + Returns the modules in ``modules`` that are root modules (i.e. + parent-less) with respect to the set ``modules``. In other words, these + are the modules in ``modules`` that are the not child of any other + module in ``modules``. + """ + root_modules: list[nn.Module] = [] + module_to_modules: dict[nn.Module, set[nn.Module]] = { + module: set(module.modules()) for module in modules + } + for candidate_module in modules: + is_root_module = True + for module, _modules in module_to_modules.items(): + is_child_module = ( + candidate_module is not module and candidate_module in _modules + ) + if is_child_module: + is_root_module = False + break + if is_root_module: + root_modules.append(candidate_module) + return root_modules diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9865ecf9c95bc36f4784fd3f0f63f0d9bdf27dbd --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/__init__.py @@ -0,0 +1,611 @@ +import logging +import os +import warnings +import zipfile +from collections.abc import Callable, Mapping +from typing import Any +from typing_extensions import deprecated + +import 
torch +import torch.utils._pytree as pytree +from torch.fx.passes.infra.pass_base import PassResult +from torch.types import FileLike + + +__all__ = [ + "AdditionalInputs", + "Constraint", + "CustomDecompTable", + "default_decompositions", + "Dim", + "dims", + "draft_export", + "export_for_training", + "export", + "ExportBackwardSignature", + "ExportedProgram", + "ExportGraphSignature", + "FlatArgsAdapter", + "load", + "ModuleCallEntry", + "ModuleCallSignature", + "register_dataclass", + "save", + "ShapesCollection", + "unflatten", + "UnflattenedModule", +] + +# To make sure export specific custom ops are loaded +import torch.export.custom_ops + +from .decomp_utils import CustomDecompTable +from .dynamic_shapes import AdditionalInputs, Constraint, Dim, dims, ShapesCollection +from .exported_program import ( + default_decompositions, + ExportedProgram, + ModuleCallEntry, + ModuleCallSignature, +) +from .graph_signature import ExportBackwardSignature, ExportGraphSignature +from .unflatten import FlatArgsAdapter, unflatten, UnflattenedModule + + +PassType = Callable[[torch.fx.GraphModule], PassResult | None] + +log: logging.Logger = logging.getLogger(__name__) + + +@deprecated( + "`torch.export.export_for_training` is deprecated and will be removed in PyTorch 2.10. " + "Please use `torch.export.export` instead, which is functionally equivalent.", + category=FutureWarning, +) +def export_for_training( + mod: torch.nn.Module, + args: tuple[Any, ...], + kwargs: Mapping[str, Any] | None = None, + *, + dynamic_shapes: dict[str, Any] | tuple[Any, ...] | list[Any] | None = None, + strict: bool = False, + preserve_module_call_signature: tuple[str, ...] 
= (), + prefer_deferred_runtime_asserts_over_guards: bool = False, +) -> ExportedProgram: + """ + :func:`export_for_training` takes any nn.Module along with example inputs, and produces a traced graph representing + only the Tensor computation of the function in an Ahead-of-Time (AOT) fashion, + which can subsequently be executed with different inputs or serialized. The + traced graph (1) produces normalized operators in the all ATen operator set + (as well as any user-specified custom operators), (2) has eliminated all Python control + flow and data structures (with certain exceptions), and (3) records the set of + shape constraints needed to show that this normalization and control-flow elimination + is sound for future inputs. This API is intended for PT2 quantization training use cases + and will soon be the default IR of torch.export.export in the near future. To read further about + the motivation behind this change, please refer to + https://dev-discuss.pytorch.org/t/why-pytorch-does-not-need-a-new-standardized-operator-set/2206 + With this API, and :func:`run_decompositions()`, you should be able to get inference IR with + your custom decomposition behaviour. + + **Soundness Guarantee** + + See :func:`export()` docstring for more details. + + Args: + mod: We will trace the forward method of this module. + + args: Example positional inputs. + + kwargs: Optional example keyword inputs. + + dynamic_shapes: + An optional argument where the type should either be: + 1) a dict from argument names of ``f`` to their dynamic shape specifications, + 2) a tuple that specifies dynamic shape specifications for each input in original order. + If you are specifying dynamism on keyword args, you will need to pass them in the order that + is defined in the original function signature. 
+ + The dynamic shape of a tensor argument can be specified as either + (1) a dict from dynamic dimension indices to :func:`Dim` types, where it is + not required to include static dimension indices in this dict, but when they are, + they should be mapped to None; or (2) a tuple / list of :func:`Dim` types or None, + where the :func:`Dim` types correspond to dynamic dimensions, and static dimensions + are denoted by None. Arguments that are dicts or tuples / lists of tensors are + recursively specified by using mappings or sequences of contained specifications. + + strict: When enabled (default), the export function will trace the program through + TorchDynamo which will ensure the soundness of the resulting graph. Otherwise, the + exported program will not validate the implicit assumptions baked into the graph and + may cause behavior divergence between the original model and the exported one. This is + useful when users need to workaround bugs in the tracer, or simply want incrementally + enable safety in their models. Note that this does not affect the resulting IR spec + to be different and the model will be serialized in the same way regardless of what value + is passed here. + WARNING: This option is experimental and use this at your own risk. + + preserve_module_call_signature: A list of submodule paths for which the original + calling conventions are preserved as metadata. The metadata will be used when calling + torch.export.unflatten to preserve the original calling conventions of modules. + + Returns: + An :class:`ExportedProgram` containing the traced callable. + + **Acceptable input/output types** + + Acceptable types of inputs (for ``args`` and ``kwargs``) and outputs include: + + - Primitive types, i.e. ``torch.Tensor``, ``int``, ``float``, ``bool`` and ``str``. + - Dataclasses, but they must be registered by calling :func:`register_dataclass` first. 
+ - (Nested) Data structures comprising of ``dict``, ``list``, ``tuple``, ``namedtuple`` and + ``OrderedDict`` containing all above types. + + """ + from ._trace import _export_for_training + + if not isinstance(mod, torch.nn.Module): + raise ValueError( + f"Expected `mod` to be an instance of `torch.nn.Module`, got {type(mod)}." + ) + if isinstance(mod, torch.jit.ScriptModule): + raise ValueError( + "Exporting a ScriptModule is not supported. " + "Maybe try converting your ScriptModule to an ExportedProgram " + "using `TS2EPConverter(mod, args, kwargs).convert()` instead." + ) + return _export_for_training( + mod, + args, + kwargs, + dynamic_shapes, + strict=strict, + preserve_module_call_signature=preserve_module_call_signature, + prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards, + ) + + +def export( + mod: torch.nn.Module, + args: tuple[Any, ...], + kwargs: Mapping[str, Any] | None = None, + *, + dynamic_shapes: dict[str, Any] | tuple[Any, ...] | list[Any] | None = None, + strict: bool = False, + preserve_module_call_signature: tuple[str, ...] = (), + prefer_deferred_runtime_asserts_over_guards: bool = False, +) -> ExportedProgram: + """ + :func:`export` takes any nn.Module along with example inputs, and produces a traced graph representing + only the Tensor computation of the function in an Ahead-of-Time (AOT) fashion, + which can subsequently be executed with different inputs or serialized. The + traced graph (1) produces normalized operators in the functional ATen operator set + (as well as any user-specified custom operators), (2) has eliminated all Python control + flow and data structures (with certain exceptions), and (3) records the set of + shape constraints needed to show that this normalization and control-flow elimination + is sound for future inputs. 
+ + **Soundness Guarantee** + + While tracing, :func:`export()` takes note of shape-related assumptions + made by the user program and the underlying PyTorch operator kernels. + The output :class:`ExportedProgram` is considered valid only when these + assumptions hold true. + + Tracing makes assumptions on the shapes (not values) of input tensors. + Such assumptions must be validated at graph capture time for :func:`export` + to succeed. Specifically: + + - Assumptions on static shapes of input tensors are automatically validated without additional effort. + - Assumptions on dynamic shape of input tensors require explicit specification + by using the :func:`Dim` API to construct dynamic dimensions and by associating + them with example inputs through the ``dynamic_shapes`` argument. + + If any assumption can not be validated, a fatal error will be raised. When that happens, + the error message will include suggested fixes to the specification that are needed + to validate the assumptions. For example :func:`export` might suggest the + following fix to the definition of a dynamic dimension ``dim0_x``, say appearing in the + shape associated with input ``x``, that was previously defined as ``Dim("dim0_x")``:: + + dim = Dim("dim0_x", max=5) + + This example means the generated code requires dimension 0 of input ``x`` to be less + than or equal to 5 to be valid. You can inspect the suggested fixes to dynamic dimension + definitions and then copy them verbatim into your code without needing to change the + ``dynamic_shapes`` argument to your :func:`export` call. + + Args: + mod: We will trace the forward method of this module. + + args: Example positional inputs. + + kwargs: Optional example keyword inputs. + + dynamic_shapes: + An optional argument where the type should either be: + 1) a dict from argument names of ``f`` to their dynamic shape specifications, + 2) a tuple that specifies dynamic shape specifications for each input in original order. 
+ If you are specifying dynamism on keyword args, you will need to pass them in the order that + is defined in the original function signature. + + The dynamic shape of a tensor argument can be specified as either + (1) a dict from dynamic dimension indices to :func:`Dim` types, where it is + not required to include static dimension indices in this dict, but when they are, + they should be mapped to None; or (2) a tuple / list of :func:`Dim` types or None, + where the :func:`Dim` types correspond to dynamic dimensions, and static dimensions + are denoted by None. Arguments that are dicts or tuples / lists of tensors are + recursively specified by using mappings or sequences of contained specifications. + + strict: When disabled (default), the export function will trace the program through + Python runtime, which by itself will not validate some of the implicit assumptions + baked into the graph. It will still validate most critical assumptions like shape + safety. When enabled (by setting ``strict=True``), the export function will trace + the program through TorchDynamo which will ensure the soundness of the resulting + graph. TorchDynamo has limited Python feature coverage, thus you may experience more + errors. Note that toggling this argument does not affect the resulting IR spec to be + different and the model will be serialized in the same way regardless of what value + is passed here. + + preserve_module_call_signature: A list of submodule paths for which the original + calling conventions are preserved as metadata. The metadata will be used when calling + torch.export.unflatten to preserve the original calling conventions of modules. + + Returns: + An :class:`ExportedProgram` containing the traced callable. + + **Acceptable input/output types** + + Acceptable types of inputs (for ``args`` and ``kwargs``) and outputs include: + + - Primitive types, i.e. ``torch.Tensor``, ``int``, ``float``, ``bool`` and ``str``. 
+ - Dataclasses, but they must be registered by calling :func:`register_dataclass` first. + - (Nested) Data structures comprising of ``dict``, ``list``, ``tuple``, ``namedtuple`` and + ``OrderedDict`` containing all above types. + + """ + from ._trace import _export + + if not isinstance(mod, torch.nn.Module): + raise ValueError( + f"Expected `mod` to be an instance of `torch.nn.Module`, got {type(mod)}." + ) + if isinstance(mod, torch.jit.ScriptModule): + raise ValueError( + "Exporting a ScriptModule is not supported. " + "Maybe try converting your ScriptModule to an ExportedProgram " + "using `TS2EPConverter(mod, args, kwargs).convert()` instead." + ) + + try: + return _export( + mod, + args, + kwargs, + dynamic_shapes, + strict=strict, + preserve_module_call_signature=preserve_module_call_signature, + pre_dispatch=True, + prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards, + ) + except Exception as e: + draft_export_msg = ( + "The error above occurred when calling torch.export.export. If you would " + "like to view some more information about this error, and get a list " + "of all other errors that may occur in your export call, you can " + "replace your `export()` call with `draft_export()`." 
+ ) + + # For errors that we know can be caught by draft-export, add the message + # to ask users to try out draft-export + if isinstance( + e, + ( + torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode, + torch._subclasses.fake_tensor.UnsupportedOperatorException, + torch._dynamo.exc.UserError, + torch.fx.experimental.symbolic_shapes.ConstraintViolationError, + ), + ): + new_msg = str(e) + "\n\n" + draft_export_msg + e.args = (new_msg,) + elif isinstance(e, RuntimeError) and "no fake impl registered" in str(e): + new_msg = str(e) + "\n\n" + draft_export_msg + e.args = (new_msg,) + raise e + + +DEFAULT_PICKLE_PROTOCOL = 2 + + +def save( + ep: ExportedProgram, + f: FileLike, + *, + extra_files: dict[str, Any] | None = None, + opset_version: dict[str, int] | None = None, + pickle_protocol: int = DEFAULT_PICKLE_PROTOCOL, +) -> None: + """ + + .. warning:: + Under active development, saved files may not be usable in newer versions + of PyTorch. + + Saves an :class:`ExportedProgram` to a file-like object. It can then be + loaded using the Python API :func:`torch.export.load `. + + Args: + ep (ExportedProgram): The exported program to save. + + f (str | os.PathLike[str] | IO[bytes]) A file-like object (has to + implement write and flush) or a string containing a file name. + + extra_files (Optional[Dict[str, Any]]): Map from filename to contents + which will be stored as part of f. 
+ + opset_version (Optional[Dict[str, int]]): A map of opset names + to the version of this opset + + pickle_protocol: can be specified to override the default protocol + + Example:: + + import torch + import io + + + class MyModule(torch.nn.Module): + def forward(self, x): + return x + 10 + + + ep = torch.export.export(MyModule(), (torch.randn(5),)) + + # Save to file + torch.export.save(ep, "exported_program.pt2") + + # Save to io.BytesIO buffer + buffer = io.BytesIO() + torch.export.save(ep, buffer) + + # Save with extra files + extra_files = {"foo.txt": b"bar".decode("utf-8")} + torch.export.save(ep, "exported_program.pt2", extra_files=extra_files) + + """ + if not isinstance(ep, ExportedProgram): + raise TypeError( + f"The 'ep' parameter must be an instance of 'ExportedProgram', got '{type(ep).__name__}' instead." + ) + + from torch.export.pt2_archive._package import package_pt2 + + package_pt2( + f, + exported_programs={"model": ep}, + extra_files=extra_files, + pickle_protocol=pickle_protocol, + opset_version=opset_version, + ) + + +def load( + f: FileLike, + *, + extra_files: dict[str, Any] | None = None, + expected_opset_version: dict[str, int] | None = None, +) -> ExportedProgram: + """ + + .. warning:: + Under active development, saved files may not be usable in newer versions + of PyTorch. + + .. warning:: + :func:`torch.export.load()` uses pickle under the hood to load models. **Never load data from an untrusted source.** + + Loads an :class:`ExportedProgram` previously saved with + :func:`torch.export.save `. + + Args: + f (str | os.PathLike[str] | IO[bytes]): A file-like object (has to + implement write and flush) or a string containing a file name. + + extra_files (Optional[Dict[str, Any]]): The extra filenames given in + this map would be loaded and their content would be stored in the + provided map. 
+ + expected_opset_version (Optional[Dict[str, int]]): A map of opset names + to expected opset versions + + Returns: + An :class:`ExportedProgram` object + + Example:: + + import torch + import io + + # Load ExportedProgram from file + ep = torch.export.load("exported_program.pt2") + + # Load ExportedProgram from io.BytesIO object + with open("exported_program.pt2", "rb") as f: + buffer = io.BytesIO(f.read()) + buffer.seek(0) + ep = torch.export.load(buffer) + + # Load with extra files. + extra_files = {"foo.txt": ""} # values will be replaced with data + ep = torch.export.load("exported_program.pt2", extra_files=extra_files) + print(extra_files["foo.txt"]) + print(ep(torch.randn(5))) + """ + if isinstance(f, (str, os.PathLike)): + f = os.fspath(f) + + extra_files = extra_files or {} + + from torch.export.pt2_archive._package import load_pt2, PT2ArchiveContents + + try: + pt2_contents = load_pt2( + f, + expected_opset_version=expected_opset_version, + ) + except RuntimeError: + log.warning("Ran into the following error when deserializing", exc_info=True) + pt2_contents = PT2ArchiveContents({}, {}, {}) + + if len(pt2_contents.exported_programs) > 0 or len(pt2_contents.extra_files) > 0: + for k, v in pt2_contents.extra_files.items(): + extra_files[k] = v + + return pt2_contents.exported_programs["model"] + + # TODO: For backward compatibility, we support loading a zip file from 2.7. Delete this path in 2.9(?) + with zipfile.ZipFile(f, "r") as zipf: + if "version" not in zipf.namelist(): + raise RuntimeError( + "We ran into an error when deserializing the saved file. " + "Please check the warnings above for possible errors. " + ) + + log.warning( + "Trying to deserialize for the older format. This version of file is " + "deprecated. Please generate a new pt2 saved file." 
+ ) + + # Check the version + version = zipf.read("version").decode().split(".") + from torch._export.serde.schema import ( + SCHEMA_VERSION, # todo change archive version to schema version + ) + + assert len(version) == len(SCHEMA_VERSION), ( + "Version in the saved file has incorrect length, double check if the file is generated by torch.export.save()" + ) + if version[0] != str(SCHEMA_VERSION[0]): + raise RuntimeError( + f"Serialized version {version} does not match our current " + f"schema version {SCHEMA_VERSION}." + ) + + from torch._export.serde.serialize import deserialize, SerializedArtifact + + # Load serialized_ep and serialized_state_dict from the zip file + + serialized_exported_program: bytes | None = None + serialized_state_dict: bytes | None = None + serialized_constants: bytes | None = None + serialized_example_inputs: bytes | None = None + + for file_info in zipf.infolist(): + file_content = zipf.read(file_info.filename) + + if file_info.filename == "serialized_exported_program.json": + serialized_exported_program = file_content + elif file_info.filename == "serialized_state_dict.json": + warnings.warn("This version of file is deprecated", stacklevel=2) + serialized_state_dict = file_content + elif file_info.filename == "serialized_constants.json": + warnings.warn("This version of file is deprecated", stacklevel=2) + serialized_constants = file_content + elif file_info.filename == "serialized_state_dict.pt": + serialized_state_dict = file_content + elif file_info.filename == "serialized_constants.pt": + serialized_constants = file_content + elif file_info.filename == "serialized_example_inputs.pt": + serialized_example_inputs = file_content + elif file_info.filename.startswith("extra_files"): + filename = file_info.filename.split("/", 1)[1] + extra_files[filename] = file_content.decode("utf-8") + + assert serialized_exported_program is not None + assert serialized_state_dict is not None + assert serialized_constants is not None + assert 
serialized_example_inputs is not None + artifact: SerializedArtifact = SerializedArtifact( + serialized_exported_program, + serialized_state_dict, + serialized_constants, + serialized_example_inputs, + ) + + # Deserialize ExportedProgram + ep = deserialize(artifact, expected_opset_version) + + return ep + + +def draft_export( + mod: torch.nn.Module, + args: tuple[Any, ...], + kwargs: Mapping[str, Any] | None = None, + *, + dynamic_shapes: dict[str, Any] | tuple[Any, ...] | list[Any] | None = None, + preserve_module_call_signature: tuple[str, ...] = (), + strict: bool = False, + prefer_deferred_runtime_asserts_over_guards: bool = False, +) -> ExportedProgram: + """ + A version of torch.export.export which is designed to consistently produce + an ExportedProgram, even if there are potential soundness issues, and to + generate a report listing the issues found. + """ + from ._draft_export import draft_export + + return draft_export( + mod=mod, + args=args, + kwargs=kwargs, + dynamic_shapes=dynamic_shapes, + preserve_module_call_signature=preserve_module_call_signature, + strict=strict, + prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards, + ) + + +def register_dataclass( + cls: type[Any], + *, + serialized_type_name: str | None = None, +) -> None: + """ + Registers a dataclass as a valid input/output type for :func:`torch.export.export`. + + Args: + cls: the dataclass type to register + serialized_type_name: The serialized name for the dataclass. This is + required if you want to serialize the pytree TreeSpec containing this + dataclass. 
+ + Example:: + + import torch + from dataclasses import dataclass + + + @dataclass + class InputDataClass: + feature: torch.Tensor + bias: int + + + @dataclass + class OutputDataClass: + res: torch.Tensor + + + torch.export.register_dataclass(InputDataClass) + torch.export.register_dataclass(OutputDataClass) + + + class Mod(torch.nn.Module): + def forward(self, x: InputDataClass) -> OutputDataClass: + res = x.feature + x.bias + return OutputDataClass(res=res) + + + ep = torch.export.export(Mod(), (InputDataClass(torch.ones(2, 2), 1),)) + print(ep) + + """ + pytree.register_dataclass(cls, serialized_type_name=serialized_type_name) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/_draft_export.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/_draft_export.py new file mode 100644 index 0000000000000000000000000000000000000000..3eb2621b707bb4bc047ce2cad284a0c9e33b6fb4 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/_draft_export.py @@ -0,0 +1,544 @@ +import getpass +import json +import logging +import os +import re +import tempfile +import time +from collections.abc import Callable, Mapping +from dataclasses import dataclass +from enum import IntEnum +from typing import Any + +import torch +import torch._logging._internal +import torch.utils._pytree as pytree +from torch._dynamo.exc import UserError, UserErrorType +from torch._export.passes.insert_custom_op_guards import ( + get_op_profiles, + insert_custom_op_guards, + OpProfile, +) +from torch._utils_internal import log_draft_export_usage + +from ._trace import _export, get_ep_stats +from .dynamic_shapes import _DimHint, _DimHintType, Dim +from .exported_program import ExportedProgram + + +log = logging.getLogger(__name__) + + +class FailureType(IntEnum): + MISSING_FAKE_KERNEL = 1 + DATA_DEPENDENT_ERROR = 2 + GUARD_ADDED = 3 + MISMATCHED_FAKE_KERNEL = 4 + + def __str__(self) -> str: + return self.name + + +def 
prettify_stack(stack: list[dict[str, str]], str_to_filename: dict[int, str]) -> str: + res = "" + for frame in stack: + if frame["filename"] not in str_to_filename: + continue + + res += f""" + File {str_to_filename[frame["filename"]]}, lineno {frame["line"]}, in {frame["name"]}""" # type: ignore[index] + + res += f"\n {stack[-1]['loc']}" + return res + + +def prettify_frame_locals( + loc: str, locals: dict[str, Any], symbols: dict[str, Any] +) -> str: + local_str = "\n".join(f" {k}: {v}" for k, v in locals.items()) + res = f""" + Locals: +{local_str} +""" + if any(v is not None for v in symbols.values()): + symbol_str = "\n".join( + f" {k}: {v}" for k, v in symbols.items() if v is not None + ) + res += f""" + Symbols: +{symbol_str} +""" + return res + + +def get_loc(filename: str, lineno: int) -> str | None: + try: + with open(filename) as f: + for i, line in enumerate(f): + if i == lineno - 1: + return line.strip() + except FileNotFoundError: + pass + return None + + +class FailureReport: + def __init__( + self, failure_type: FailureType, data: dict[str, Any], xfail: bool = False + ) -> None: + self.failure_type: FailureType = failure_type + self.data: dict[str, Any] = data + self.xfail: bool = xfail + + def __repr__(self) -> str: + return f"FailureReport(failure_type={self.failure_type}, xfail={self.xfail}, data={self.data})" + + def print(self, str_to_filename: dict[int, str]) -> str: + if self.failure_type == FailureType.MISSING_FAKE_KERNEL: + op = self.data["op"] + + return f"""Missing fake kernel. + torch.ops.{op} is missing a fake kernel implementation. + + Please refer to https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit#heading=h.ahugy69p2jmz for more detailed instructions on how to write a meta implementation. 
+""" # noqa: B950 + + elif self.failure_type == FailureType.GUARD_ADDED: + locals_info = ( + prettify_frame_locals(**self.data["frame_locals"]) + if self.data["frame_locals"] + else "" + ) + return f"""Guard Added. + A guard was added during tracing, which might've resulted in some incorrect + tracing or constraint violation error. + Specifically, this guard was added: {self.data["expr"]}, where {self.data["symbol_to_sources"]}. + This occurred at the following stacktrace: {prettify_stack(self.data["user_stack"], str_to_filename)}: + {locals_info} + And the following framework stacktrace: {prettify_stack(self.data["stack"], str_to_filename)}\n + Because of this, we have modified the dynamic shapes structure to be the + following. You can also use torch.export.Dim.AUTO instead to specify your + dynamic shapes, and we will automatically infer the dynamism for you. + ``` + dynamic_shapes = {self.data["new_dynamic_shapes"]} + ``` +""" + + elif self.failure_type == FailureType.DATA_DEPENDENT_ERROR: + locals_info = ( + prettify_frame_locals(**self.data["frame_locals"]) + if self.data["frame_locals"] + else "" + ) + return f"""Data dependent error. + When exporting, we were unable to evaluate the value of `{self.data["expr"]}`. + This was encountered {self.data["occurrences"]} times. + This occurred at the following user stacktrace: {prettify_stack(self.data["user_stack"], str_to_filename)} + {locals_info} + And the following framework stacktrace: {prettify_stack(self.data["stack"], str_to_filename)}\n + As a result, it was specialized to a constant (e.g. `{self.data["result"]}` in the 1st occurrence), and asserts were inserted into the graph. + + Please add `torch._check(...)` to the original code to assert this data-dependent assumption. + Please refer to https://docs.google.com/document/d/1kZ_BbB3JnoLbUZleDT6635dHs88ZVYId8jT-yTFgf3A/edit#heading=h.boi2xurpqa0o for more details. 
+""" # noqa: B950 + + elif self.failure_type == FailureType.MISMATCHED_FAKE_KERNEL: + op = self.data["op"] + reason = self.data["reason"] + return f"""Mismatched fake kernel. + torch.ops.{op} has a fake kernel implementation, but it has incorrect behavior, based on the real kernel. + The reason for the mismatch is: {reason}. + + Please refer to https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit#heading=h.ahugy69p2jmz for more detailed instructions on how to write a fake implementation. +""" # noqa: B950 + + else: + raise ValueError(f"Unknown failure type: {self.failure_type}") + + +class DraftExportReport: + def __init__( + self, + failures: list[FailureReport], + str_to_filename: dict[int, str], + expressions_created: dict[int, dict[str, Any]], + op_profiles: dict[str, set[OpProfile]], + ): + self.failures: list[FailureReport] = failures + self.str_to_filename = str_to_filename + self.expressions_created: dict[int, dict[str, Any]] = expressions_created + self.op_profiles = op_profiles + + def successful(self) -> bool: + return len(self.failures) == 0 or all( + failure.xfail for failure in self.failures + ) + + def __repr__(self) -> str: + return f"DraftExportReport({self.failures})" + + def __str__(self) -> str: + WARNING_COLOR = "\033[93m" + GREEN_COLOR = "\033[92m" + END_COLOR = "\033[0m" + + if self.successful(): + return f"""{GREEN_COLOR} +############################################################################################## +Congratuations: No issues are found during export, and it was able to soundly produce a graph. +You can now change back to torch.export.export() +############################################################################################## +{END_COLOR}""" + + error = f"""{WARNING_COLOR} +################################################################################################### +WARNING: {len(self.failures)} issue(s) found during export, and it was not able to soundly produce a graph. 
+Please follow the instructions to fix the errors. +################################################################################################### + +""" + + for i, failure in enumerate(self.failures): + error += f"{i + 1}. {failure.print(self.str_to_filename)}\n" + error += END_COLOR + return error + + def apply_suggested_fixes(self) -> None: + raise NotImplementedError("Not implemented yet") + + +@dataclass +class ExpressionCreatedNode: + result_id: int + argument_ids: list[int] + record: dict[str, object] + visited: bool = False + + +class LogRecord: + def __init__(self) -> None: + self.log_count: dict[int, int] = {} + self.logs: list[tuple[str, dict[str, Any]]] = [] + + def _hash(self, element: tuple[str, dict[str, Any]]) -> int: + key, data = element + + if key == "missing_fake_kernel": + return hash((key, data["op"])) + elif key == "mismatched_fake_kernel": + return hash((key, data["op"], data["reason"])) + elif key == "propagate_real_tensors_provenance": + return hash((key, json.dumps(data["user_stack"]))) + elif key == "guard_added": + return hash((key, json.dumps(data["user_stack"]))) + elif key == "create_unbacked_symbol": + return hash((key, json.dumps(data["user_stack"]))) + + return hash((key, json.dumps(data))) + + def try_add(self, element: tuple[str, dict[str, str]]) -> bool: + hash_value = self._hash(element) + if hash_value in self.log_count: + self.log_count[hash_value] += 1 + return False + + self.log_count[hash_value] = 1 + self.logs.append(element) + return True + + def get_log_count(self, element: tuple[str, dict[str, Any]]) -> int: + return self.log_count[self._hash(element)] + + +class CaptureStructuredTrace(torch._logging._internal.LazyTraceHandler): + def __init__(self) -> None: + self.specific_log_keys = [ + "str", + "exported_program", + "propagate_real_tensors_provenance", + "guard_added", + "missing_fake_kernel", + "mismatched_fake_kernel", + "expression_created", + "create_unbacked_symbol", + ] + self.log_record: LogRecord = 
LogRecord() + self.expression_created_logs: dict[int, ExpressionCreatedNode] = {} + self.symbol_to_expressions: dict[str, list[dict[str, Any]]] = {} + self.logger = logging.getLogger("torch.__trace") + self.prev_get_dtrace = False + + if root_dir := os.environ.get(torch._logging._internal.DTRACE_ENV_VAR): + super().__init__(root_dir) + else: + sanitized_username = re.sub(r'[\\/:*?"<>|]', "_", getpass.getuser()) + root_dir = os.path.join( + tempfile.gettempdir(), + "export_" + sanitized_username, + ) + super().__init__(root_dir) + + self.setFormatter(torch._logging._internal.TorchLogsFormatter(trace=True)) + + def __enter__(self) -> "CaptureStructuredTrace": + self.log_record = LogRecord() + self.expression_created_logs = {} + + # Remove the lazy trace handler if it exists + possible_lazy_trace_handlers = [ + handler + for handler in self.logger.handlers + if isinstance(handler, torch._logging._internal.LazyTraceHandler) + ] + for handler in possible_lazy_trace_handlers: + self.logger.removeHandler(handler) + + self.logger.addHandler(self) + self.prev_get_dtrace = torch._logging._internal.GET_DTRACE_STRUCTURED + # pyrefly: ignore [bad-assignment] + torch._logging._internal.GET_DTRACE_STRUCTURED = True + return self + + def __exit__(self, exc_type, exc_value, traceback) -> None: # type: ignore[no-untyped-def] + self.log_record = LogRecord() + self.expression_created_logs = {} + self.logger.removeHandler(self) + # pyrefly: ignore [bad-assignment] + torch._logging._internal.GET_DTRACE_STRUCTURED = self.prev_get_dtrace + self.prev_get_dtrace = False + + def emit(self, record: Any) -> None: + def _log_expression_created( + emit_func: Callable[[Any], None], sym_node_id: int + ) -> None: + # Log all the relevant expression_created logs + if sym_node_id is None: + return + if res := self.expression_created_logs.get(sym_node_id, None): + # Don't log the expression if we have already + # printed it beforehand + if not res.visited: + res.visited = True + for arg in 
res.argument_ids: + _log_expression_created(emit_func, arg) + + emit_func(res.record) + + metadata = record.metadata + for key in self.specific_log_keys: + if key in metadata: + if self.log_record.try_add((key, metadata[key])): + if key == "expression_created": + # We don't want to log all expression_created logs, only + # the ones that are relevant to the + # guards/propagate_real_tensor + self.expression_created_logs[metadata[key]["result_id"]] = ( + ExpressionCreatedNode( + metadata[key]["result_id"], + metadata[key].get("argument_ids", []), + record, + ) + ) + return + + elif key == "propagate_real_tensors_provenance": + _log_expression_created( + super().emit, metadata[key].get("expr_node_id") + ) + + elif key == "guard_added": + if len(metadata[key]["symbol_to_sources"]) == 0: + # We only want to include guards added that are relevant to + # the symbolic shapes corresponding to the inputs which were + # specified in the dynamic_shapes arg. These have a source. + return + elif metadata[key]["prefix"] == "runtime_assert": + # This should've been captured by a + # propagate_real_tensors log + return + + _log_expression_created( + super().emit, metadata[key].get("expr_node_id") + ) + + super().emit(record) + + +def draft_export( + mod: torch.nn.Module, + args: tuple[Any, ...], + kwargs: Mapping[str, Any] | None = None, + *, + dynamic_shapes: dict[str, Any] | tuple[Any] | list[Any] | None = None, + preserve_module_call_signature: tuple[str, ...] 
= (), + strict: bool = False, + pre_dispatch: bool = True, + prefer_deferred_runtime_asserts_over_guards: bool = False, +) -> ExportedProgram: + start_time = time.time() + kwargs = kwargs or {} + dynamic_shapes = dynamic_shapes or {} + + constraint_violation_msg = None + capture_structured_log = CaptureStructuredTrace() + + with ( + torch._functorch.config.patch( + fake_tensor_propagate_real_tensors=True, + generate_fake_kernels_from_real_mismatches=True, + ), + capture_structured_log, + ): + try: + new_shapes = None + ep = _export( + mod, + args, + kwargs, + dynamic_shapes=dynamic_shapes, + strict=strict, + pre_dispatch=pre_dispatch, + preserve_module_call_signature=preserve_module_call_signature, + prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards, + ) + except Exception as exc: + if ( + isinstance(exc, UserError) + and exc.error_type == UserErrorType.CONSTRAINT_VIOLATION + ): + constraint_violation_msg = exc.msg + + def convert_dim_to_auto(dim: Any) -> Any: + if isinstance(dim, Dim): + return Dim.AUTO(min=dim.min, max=dim.max) + elif isinstance(dim, _DimHint) and dim.type == _DimHintType.DYNAMIC: + return Dim.AUTO(min=dim.min, max=dim.max) + return dim + + new_shapes = pytree.tree_map(convert_dim_to_auto, dynamic_shapes) + ep = _export( + mod, + args, + kwargs, + dynamic_shapes=new_shapes, + strict=strict, + pre_dispatch=pre_dispatch, + preserve_module_call_signature=preserve_module_call_signature, + prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards, + ) + else: + log_draft_export_usage( + error=True, + export_time=time.time() - start_time, + strict=strict, + message=str(exc), + type=f"{type(exc).__name__}.{type(exc).__qualname__}", + ) + raise exc + + torch._logging.dtrace_structured("exported_program", payload_fn=lambda: str(ep)) + + str_to_filename: dict[int, str] = {} + failures: list[FailureReport] = [] + incorrect_custom_ops: set[str] = set() + expressions_created: dict[int, 
dict[str, Any]] = {} + + for log_name, log_contents in capture_structured_log.log_record.logs: + failure_type = None + + if log_name == "str": + str_to_filename[log_contents[1]] = log_contents[0] # type: ignore[index] + continue + + elif log_name == "propagate_real_tensors_provenance": + log_contents["occurrences"] = ( + capture_structured_log.log_record.get_log_count( + (log_name, log_contents) + ) + ) + + failure_type = FailureType.DATA_DEPENDENT_ERROR + + elif log_name == "guard_added": + if new_shapes is None: + continue + + failure_type = FailureType.GUARD_ADDED + log_contents["new_dynamic_shapes"] = new_shapes + elif log_name == "missing_fake_kernel": + failure_type = FailureType.MISSING_FAKE_KERNEL + incorrect_custom_ops.add(log_contents["op"]) + + elif log_name == "mismatched_fake_kernel": + failure_type = FailureType.MISMATCHED_FAKE_KERNEL + incorrect_custom_ops.add(log_contents["op"]) + + else: + continue + + assert failure_type is not None + failures.append( + FailureReport( + failure_type, + log_contents, + ) + ) + + for k, v in capture_structured_log.expression_created_logs.items(): + if v.visited: + expressions_created[k] = v.record + + op_profiles = get_op_profiles(ep.graph_module, incorrect_custom_ops) + report = DraftExportReport( + failures, str_to_filename, expressions_created, op_profiles + ) + + # Add asserts around custom ops + insert_custom_op_guards(ep.graph_module, incorrect_custom_ops) + + ep._report = report + if not report.successful(): + log_filename = capture_structured_log.stream.name + + warning_msg = f""" +################################################################################################### +WARNING: {len(report.failures)} issue(s) found during export, and it was not able to soundly produce a graph. +To view the report of failures in an html page, please run the command: + `tlparse {log_filename} --export` +Or, you can view the errors in python by inspecting `print(ep._report)`. 
+""" + + if len(report.op_profiles) > 0: + warning_msg += f""" +While tracing we found {len(report.op_profiles)} operator(s) which do not have a fake kernel registered. +If you intend to retrace the exported graph or run it with fake tensors, please run it under the +following context manager, which will register a fake kernel for those operators. +``` +with torch._library.fake_profile.unsafe_generate_fake_kernels(ep._report.op_profiles): + # run with fake tensors +``` +""" + + warning_msg += """#################################################################################################""" + + log.warning(warning_msg) + + else: + log.info( + """ +############################################################################################## +Congratuations: No issues are found during export, and it was able to soundly produce a graph. +You can now change back to torch.export.export() +############################################################################################## + """ + ) + + log_draft_export_usage( + error=False, + export_time=time.time() - start_time, + strict=strict, + constraint_violations=constraint_violation_msg, + report=ep._report, + **get_ep_stats(ep), + ) + return ep diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/_leakage_detection_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/_leakage_detection_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..fe211e1dc079c844748edda14ba269917e7e7847 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/_leakage_detection_utils.py @@ -0,0 +1,112 @@ +import gc +import types +import typing +import weakref + +import torch + + +""" +These functions are used to detect potential fake tensor leakage when using PT2 export. 
+See NOTE [export non-strict fake tensor leak detection] + +There are some complications that made this logic overly complicated: +1) Python 3.10 and Python 3.12 have different ways of implementing referrer so + we need to account for whether it is ref.__dict__ or the real ref object + +2) There are some internal PT2 references to fake tensors like `TrackedFake` +3) closures, generators, and bound methods can hold fake tensors. +4) global object can hold onto a fake tensor + +In general, these utils are our last resort to detect fake tensors. if the leak happens +within the model attributes, we have a separate mechanism to detect. This tool relies a bit +on garbage collector internal details, so I think it is unsafe to turn on by default, hence +this tool should be used as debugging tool. +""" + + +# Things we never want to flag as leaks +_SKIP_TYPES = ( + types.FrameType, + types.ModuleType, +) + + +def _is_globals_or_locals(obj: typing.Any) -> bool: + # These comparisons only make sense within this frame; still cheap to check. + return obj is globals() or obj is locals() + + +def _is_tracked_fake(obj: typing.Any) -> bool: + return isinstance(obj, torch.fx.experimental.symbolic_shapes.TrackedFake) + + +def _is_gm_meta_like_dict(d: dict, o: typing.Any) -> bool: + # Hope gm.meta was a custom dict we can assert on + return d.get("val") is o + + +def _dict_is_attr_of_tracked_fake(d: dict) -> bool: + """ + Python 3.10 quirk: sometimes the referrer is obj.__dict__ instead of obj. + Check if this dict is exactly the __dict__ of a TrackedFake. 
+ """ + for parent in gc.get_referrers(d): + if ( + hasattr(parent, "__dict__") + and parent.__dict__ is d + and _is_tracked_fake(parent) + ): + return True + return False + + +def find_legit_leaks_from_referrers(active_fakes: weakref.WeakSet) -> weakref.WeakSet: + legit_leak: weakref.WeakSet = weakref.WeakSet() + + # This is so that we don't falsely flag generator to be holding fake tensor + fake_list = list(active_fakes) + fake_list_id = id(fake_list) + + for act in fake_list: + # Track by id to avoid processing duplicate referrers + seen = set() + # Assume it's a leak unless we find only ignorable referrers + flagged = False + + for r in gc.get_referrers(act): + rid = id(r) + if rid in seen: + continue + seen.add(rid) + + # Skip our own fake_list + if rid == fake_list_id: + continue + + # Fast-path: skip obvious non-owners + if _is_globals_or_locals(r): + continue + if isinstance(r, _SKIP_TYPES): + continue + if _is_tracked_fake(r): + # TrackedFake should be ignored + continue + + # Handle dicts carefully (Python 3.10 sometimes shows __dict__) + if isinstance(r, dict): + if _is_gm_meta_like_dict(r, act): + continue + if _dict_is_attr_of_tracked_fake(r): + continue + flagged = True + break + + # Any other referrer we don't explicitly whitelist counts as a leak + flagged = True + break + + if flagged: + legit_leak.add(act) + + return legit_leak diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/_remove_auto_functionalized_pass.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/_remove_auto_functionalized_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..1b4833927656767940f3795d37f0e75b93b647da --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/_remove_auto_functionalized_pass.py @@ -0,0 +1,52 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import torch +from torch._higher_order_ops.auto_functionalize import ( + auto_functionalized, + auto_functionalized_v2, +) +from torch._inductor.fx_passes.post_grad import decompose_auto_functionalized +from torch.export import ExportedProgram +from torch.fx import Graph + + +def remove_self_clone(graph: Graph) -> None: + for node in graph.nodes: + if node.target is torch.ops.aten.copy_.default and node.args[0] == node.args[1]: + node.replace_all_uses_with(node.args[0]) + graph.erase_node(node) + + +def unsafe_remove_auto_functionalized_pass( + ep: ExportedProgram, +) -> ExportedProgram: + """ + This pass removes an instances of the higher order op 'auto_functionalized', + and modifies the calling EP inplace to have the original mutator op. + This pass doesn't perform safety checks to make sure that this inplace mutation is safe. + """ + + with ep.graph_module._set_replace_hook(ep.graph_signature.get_replace_hook()): + for module in ep.graph_module.modules(): + if not isinstance(module, torch.fx.GraphModule): + continue + for node in ep.graph.nodes: + if ( + node.op == "call_function" and node.target is auto_functionalized + ) or ( + node.op == "call_function" and node.target is auto_functionalized_v2 + ): + func = node.args[0] + assert isinstance(func, torch._ops.OpOverload) + # re-inplace everything + node.meta["only_clone_these_tensors"] = [] + decompose_auto_functionalized(ep.graph) + remove_self_clone(ep.graph) + ep.graph.eliminate_dead_code() + + return ep diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/_remove_effect_tokens_pass.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/_remove_effect_tokens_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..8504d1cbdb71fd6f2199776fb7712768893afed1 --- /dev/null +++ 
b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/_remove_effect_tokens_pass.py @@ -0,0 +1,212 @@ +# mypy: allow-untyped-defs +import operator + +import torch +from torch._higher_order_ops.effects import _get_schema, with_effects + +from .exported_program import ExportedProgram +from .graph_signature import ( + CustomObjArgument, + InputKind, + InputSpec, + OutputKind, + OutputSpec, + TokenArgument, +) + + +def _get_custom_obj_for_node(node, inputs_to_lifted_custom_objs, constants): + """Extract the custom object from a node's arguments.""" + custom_obj_node = node + custom_obj_meta = custom_obj_node.meta["val"] # type: ignore[union-attr] + assert isinstance(custom_obj_meta, CustomObjArgument) + + if custom_obj_meta.fake_val: + return custom_obj_meta.fake_val + elif custom_obj_node.name in inputs_to_lifted_custom_objs: # type: ignore[union-attr] + return constants[inputs_to_lifted_custom_objs[custom_obj_node.name]] # type: ignore[union-attr] + else: + raise RuntimeError(f"Unable to find custom obj for node {node}") + + +def _replace_with_effects_node( + node, ep, inputs_to_lifted_custom_objs, output_tokens, input_tokens, module +): + """Replace a with_effects node with the underlying function call.""" + # Get the input nodes + token_node, func, *node_args = node.args + if token_node.op == "placeholder": + input_tokens.append(token_node) + + assert isinstance(func, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)) + + # Get the schema for the function + if func is torch.ops.higher_order.call_torchbind: + custom_obj = _get_custom_obj_for_node( + node_args[0], inputs_to_lifted_custom_objs, ep.constants + ) + schema = _get_schema(func, [custom_obj] + node_args[1:]) + else: + schema = _get_schema(func, node_args) + + # Create the replacement node + with module.graph.inserting_before(node): + new_node = module.graph.call_function(func, tuple(node_args), node.kwargs) + + # Update getitem nodes that extract outputs from with_effects + for 
user in list(node.users.keys()): + assert user.target is operator.getitem + # getitem(with_effects, 0) is the token node + if user.args[1] == 0: + for user_user in list(user.users.keys()): + if user_user.op == "output": + output_tokens.append(user) + + # Copy metadata from old node to new node + for k, v in node.meta.items(): + new_node.meta[k] = v + if k == "unbacked_bindings": + # Remove the extra layer for effect token + old_bindings = new_node.meta[k] + new_bindings = { + k: path[1:] if path else path for k, path in old_bindings.items() + } + new_node.meta[k] = new_bindings + + # Fix up the getitem nodes based on return count + if len(schema.returns) == 1: + # Single return: replace getitem(with_effects, 1) with the node itself + for user in list(node.users.keys()): + if user.args[1] == 1: + user.replace_all_uses_with(new_node) + new_node.meta["val"] = node.meta["val"][1] + elif len(schema.returns) > 1: + # Multiple returns: shift getitem indices down by 1 + for user in list(node.users.keys()): + if user.args[1] >= 1: + user.args = (new_node, user.args[1] - 1) + new_node.meta["val"] = node.meta["val"][1:] + else: + # No returns + assert len(schema.returns) == 0 + assert len(new_node.users) == 0 + new_node.meta["val"] = None + + +def _replace_invoke_subgraph_node(node, module, output_tokens, input_tokens): + """Replace an invoke_subgraph node to remove the token argument.""" + assert node.args[0].op == "get_attr" + submod = getattr(module, node.args[0].target) + if not submod.meta.get("has_with_effects", False): + return + + # Remove token from inputs + subgraph, identifier, token, *operands = node.args + node.args = (subgraph, identifier, *operands) + if token.op == "placeholder": + input_tokens.append(token) + + # Update getitem nodes to account for removed token output + for user in list(node.users.keys()): + if user.args[1] >= 1: + user.args = (node, user.args[1] - 1) + elif user.args[1] == 0: + for user_user in list(user.users.keys()): + if user_user.op == 
"output": + output_tokens.append(user) + + +def _remove_effect_tokens(ep: ExportedProgram) -> ExportedProgram: + """ + Removes the existence of tokens from the exported program, including: + - Removes the input and output tokens + - Replaces with_effects(token, func, args) with just func(args) + + This function does an inplace modification on the given ExportedProgram. + """ + inputs_to_lifted_custom_objs = ep.graph_signature.inputs_to_lifted_custom_objs + + # mark submodules with effects as having effects. This will be used in the following pass to remove effects from subgraphs + for _, module in ep.graph_module.named_modules(): + if not isinstance(module, torch.fx.GraphModule): + continue + + with_effect_nodes = [ + node for node in module.graph.nodes if node.target is with_effects + ] + if len(with_effect_nodes) > 0: + module.meta["has_with_effects"] = True + + # Process each module with the replace hook to ensure graph signature is updated + with ep.graph_module._set_replace_hook(ep.graph_signature.get_replace_hook()): + for _, module in ep.graph_module.named_modules(): + if not isinstance(module, torch.fx.GraphModule): + continue + + input_tokens = [] + output_tokens = [] + + # Process with_effects and invoke_subgraph nodes + for node in module.graph.nodes: + if node.target is with_effects: + _replace_with_effects_node( + node, + ep, + inputs_to_lifted_custom_objs, + output_tokens, + input_tokens, + module, + ) + elif node.target is torch.ops.higher_order.invoke_subgraph: + _replace_invoke_subgraph_node( + node, module, output_tokens, input_tokens + ) + + # Remove tokens from the output node + if len(output_tokens) > 0: + output_node = next(reversed(module.graph.find_nodes(op="output"))) + output_args = output_node.args[0] + assert len(output_args) >= len(output_tokens), ( + f"{output_args} output arguments found\n" + f"{output_tokens} output tokens found\n" + f"{module.graph}" + ) + output_node.args = (tuple(output_args[len(output_tokens) :]),) + + 
module.graph.eliminate_dead_code() + + # Remove tokens from the input placeholders + for node in module.graph.nodes: + if node.op == "placeholder" and node in input_tokens: + module.graph.erase_node(node) + + module.recompile() + + num_tokens: int = 0 + input_token_names: list[str] = [] + new_input_specs: list[InputSpec] = [] + for inp in ep.graph_signature.input_specs: + if inp.kind == InputKind.TOKEN: + num_tokens += 1 + assert isinstance(inp.arg, TokenArgument) + input_token_names.append(inp.arg.name) + else: + new_input_specs.append(inp) + + num_out_tokens: int = 0 + new_output_specs: list[OutputSpec] = [] + output_token_names: list[OutputSpec] = [] + for out in ep.graph_signature.output_specs: + if out.kind == OutputKind.TOKEN: + num_out_tokens += 1 + output_token_names.append(out.arg.name) + else: + new_output_specs.append(out) + + # Update graph signature + ep.graph_signature.input_specs = new_input_specs + ep.graph_signature.output_specs = new_output_specs + + assert num_tokens == num_out_tokens + + return ep diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/_safeguard.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/_safeguard.py new file mode 100644 index 0000000000000000000000000000000000000000..76f22f369c566a97062fc60696ad7972dc2b260c --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/_safeguard.py @@ -0,0 +1,44 @@ +# mypy: allow-untyped-defs +import torch +from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode +from torch.overrides import TorchFunctionMode + + +class AutogradStateOpsFailSafeguard(TorchFunctionMode): + """ + Detect grad state ops during exporting the graph and fail the process by + raising an error, to avoid unexpected behavior. Those grad mode ops could be: + `torch.no_grad` + `torch.enable_grad` + `torch.set_grad_enabled` + + Export with predispatch mode is exempted. 
+ """ + + def __torch_function__(self, func, types, args=(), kwargs=None): + kwargs = kwargs or {} + unsupported_grad_mode_ops = [ + torch._C._set_grad_enabled, + ] + # It's only enabled while tracing, by confirming the torch dispatch mode is + # any active PROXY. This is to allow the autograd ops out of tracing. + current_state = torch._C.is_grad_enabled() + if func in unsupported_grad_mode_ops: + assert len(args) == 1 + changed_state = args[0] + mode = torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.PROXY) + # Intend to check if it's not the pre_dispatch mode. It's allowed to use + # autograd ops in pre_dispatch mode, e.g. `torch.no_grad` + if ( + mode + and isinstance(mode, ProxyTorchDispatchMode) + and not mode.pre_dispatch + and changed_state != current_state + ): + raise RuntimeError( + f"Encountered autograd state manager op {func} trying to change global autograd state " + "while exporting. This is unsafe because we don't capture this op in torch.export " + "today, hence we can't reflect the user intention soundly. You can fix this by " + "adding a torch.no_grad() context around the export call." 
+ ) + return func(*args, **kwargs) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/_swap.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/_swap.py new file mode 100644 index 0000000000000000000000000000000000000000..f5aca6305c7de783b82d3848d58bec12a77a6fb8 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/_swap.py @@ -0,0 +1,439 @@ +import logging +import operator +import types +from collections import defaultdict + +import torch +import torch.fx._pytree as fx_pytree +import torch.utils._pytree as pytree +from torch.export.exported_program import ( + ConstantArgument, + ExportedProgram, + ModuleCallSignature, +) +from torch.fx.passes.tools_common import legalize_graph, NodeList +from torch.fx.passes.utils.fuser_utils import erase_nodes, fuse_as_graphmodule + + +log = logging.getLogger(__name__) + + +def _get_getitem_users(node: torch.fx.Node) -> set[torch.fx.Node]: + node_users = list(node.users.keys()) + getitem_users = set() + for user in node_users: + if user.op == "output": + continue + + assert user.op == "call_function" and user.target is operator.getitem, ( + f"Expected getitem node as user for {node}, instead got {user}" + ) + getitem_users.update(list(user.users.keys())) + return getitem_users + + +def _try_remove_connecting_pytrees(curr_module_node: torch.fx.Node) -> None: + """ + We want to try to remove extraneous pytree flatten/unflatten calls between modules + calls. Instead of having the following: + graph(): + ... 
+ %foo : [num_users=1] = call_module[target=foo](args = (%getitem_1, %getitem_2), kwargs = {}) + %tree_flatten_spec : [num_users=1] = call_function[target=torch.fx._pytree.tree_flatten_spec](args = (%foo, %_spec_1), kwargs = {}) + %getitem_4 : [num_users=1] = call_function[target=operator.getitem](args = (%tree_flatten_spec, 0), kwargs = {}) + %tree_unflatten_1 : [num_users=2] = call_function[target=torch.utils._pytree.tree_unflatten](args = ([%getitem_4], %_spec_2), kwargs = {}) + %getitem_5 : [num_users=1] = call_function[target=operator.getitem](args = (%tree_unflatten_1, 0), kwargs = {}) + %getitem_7 : [num_users=0] = call_function[target=operator.getitem](args = (%tree_unflatten_1, 1), kwargs = {}) + %getitem_6 : [num_users=1] = call_function[target=operator.getitem](args = (%getitem_5, 0), kwargs = {}) + %bar : [num_users=1] = call_module[target=bar](args = (%getitem_6,), kwargs = {}) + ... + + We could do the following, if we know that all the outputs of `foo` feed into `bar`: + graph(): + ... + %foo : [num_users=1] = call_module[target=foo](args = (%getitem_1, %getitem_2), kwargs = {}) + %bar : [num_users=1] = call_module[target=bar](args = (%getitem_6,), kwargs = {}) + ... + + Currently this optimization only works for the case where all of the outputs + of `foo` go directly into `bar`, and `bar` has no other inputs. + """ # noqa: B950 + + log.debug("Trying to remove pytrees for module call %s", curr_module_node) + + curr_module_users = list(curr_module_node.users.keys()) + assert len(curr_module_users) == 1, ( + f"Expected only one user for module node, instead got {list(curr_module_users)}" + ) + flatten_node = curr_module_users[0] + assert ( + flatten_node.op == "call_function" + and flatten_node.target is fx_pytree.tree_flatten_spec + ) + + flatten_getitem_users = _get_getitem_users(flatten_node) + if len(flatten_getitem_users) != 1: + log.debug( + "More than one user found for flatten node, %s: %s. 
" + "Unable to fuse it with another unflatten call.", + flatten_node, + flatten_getitem_users, + ) + return + + unflatten_node = next(iter(flatten_getitem_users)) + if not ( + unflatten_node.op == "call_function" + and unflatten_node.target is pytree.tree_unflatten + ): + log.debug( + "Flatten node %s's user is not a pytree.tree_unflatten. " + "Instead it is: %s. Passing...", + flatten_node, + unflatten_node, + ) + return + + for i, arg in enumerate(unflatten_node.args[0]): # type: ignore[union-attr,arg-type] + if arg not in flatten_node.users: + log.debug( + "Module %s's outputs are not all directly used as inputs to " + "the subsequent module. Unable to fuse the connecting " + "flatten/unflatten. The inputs to the subsequent module are: %s. ", + curr_module_node, + unflatten_node.args[0], + ) + return + + if not ( + # pyrefly: ignore [missing-attribute] + arg.op == "call_function" + # pyrefly: ignore [missing-attribute] + and arg.target is operator.getitem + # pyrefly: ignore [missing-attribute] + and arg.args[1] == i + ): + log.debug( + "Module %s's outputs are not all directly used in the same " + "order as outputted. Unable to fuse the connecting " + "flatten/unflatten. The inputs to the " + "subsequent module are: %s. ", + curr_module_node, + unflatten_node.args[0], + ) + return + + # Unflatten has two levels of getitem, because it gets the args and kwargs + unflatten_getitem_getitem_users = set() + unflatten_getitem_users = _get_getitem_users(unflatten_node) + for unflatten_getitem_user in unflatten_getitem_users: + unflatten_getitem_getitem_users.update( + list(unflatten_getitem_user.users.keys()) + ) + + if len(unflatten_getitem_getitem_users) != 1: + log.debug( + "More than one user found for unflatten node, %s: %s. 
" + "Unable to fuse it with another flatten call.", + unflatten_node, + unflatten_getitem_getitem_users, + ) + return + + next_module_node = next(iter(unflatten_getitem_getitem_users)) + if next_module_node.op != "call_module": + log.debug( + "Unflatten node %s's user is not a call_module. " + "Instead it is: %s. Passing...", + unflatten_node, + next_module_node, + ) + return + + # Directly put the outputs of the current module into the next module + next_module_node.args = (curr_module_node,) + + +def _remove_extraneous_pytrees(gm: torch.fx.GraphModule) -> None: + """ + Remove extraneous pytree flatten/unflatten calls. + + We try a couple of optimizations here: + 1. Remove pytree flatten/unflatten calls between modules + 2. TODO: Remove module's in_spec + initial unflatten call + 3. TODO: Remove module's out_spec + final flatten call + """ + + for node in gm.graph.nodes: + if node.op == "call_module" and node.target != "_guards_fn": + _try_remove_connecting_pytrees(node) + + gm.graph.eliminate_dead_code() + + +def _construct_inputs( + gm: torch.fx.GraphModule, + signature: ModuleCallSignature, + node_name_map: dict[str, torch.fx.Node], +) -> tuple[list[torch.fx.Node], dict[str, torch.fx.Node]]: + tree_unflatten_args: list[torch.fx.Node | None] = [] + for input_ in signature.inputs: + if isinstance(input_, ConstantArgument) and input_.value is None: + # Constants should be directly embedded into the graph and not used + # as inputs + tree_unflatten_args.append(None) + elif input_.name not in node_name_map: + # For unused inputs + tree_unflatten_args.append(None) + else: + tree_unflatten_args.append(node_name_map[input_.name]) + + # Insert unflatten call + from .unflatten import _generate_unflatten + + unflatten_node = _generate_unflatten(gm, tree_unflatten_args, signature.in_spec) + + assert signature.in_spec.num_children == 2 + assert signature.in_spec.type is tuple + args_spec, kwargs_spec = signature.in_spec.children() + assert args_spec.type is tuple + assert 
kwargs_spec.type is dict + + args_node = gm.graph.call_function(operator.getitem, (unflatten_node, 0)) + args_nodes = [ + gm.graph.call_function(operator.getitem, (args_node, i)) + for i in range(args_spec.num_children) + ] + kwargs_node = gm.graph.call_function(operator.getitem, (unflatten_node, 1)) + kwargs_nodes = { + k: gm.graph.call_function(operator.getitem, (kwargs_node, k)) + for k in kwargs_spec.context + } + return args_nodes, kwargs_nodes + + +def _insert_call_module( + gm: torch.fx.GraphModule, + args_nodes: list[torch.fx.Node], + kwargs_nodes: dict[str, torch.fx.Node], + module_to_swap: torch.nn.Module, + name: str, +) -> torch.fx.Node: + from .unflatten import _assign_attr, _AttrKind + + _assign_attr(module_to_swap, gm, name, _AttrKind.MODULE) + module_node = gm.graph.call_module(name, tuple(args_nodes), kwargs_nodes) # type: ignore[arg-type] + return module_node + + +def _deconstruct_outputs( + gm: torch.fx.GraphModule, + signature: ModuleCallSignature, + module_node: torch.fx.Node, + node_name_map: dict[str, torch.fx.Node], + orig_outputs: tuple[torch.fx.Node, ...], +) -> None: + from .unflatten import _generate_flatten_spec + + flatten_node = _generate_flatten_spec(gm, module_node, signature.out_spec) + + for i, orig_output in enumerate(orig_outputs): + # Use Proxy to record getitem access. 
+ proxy_out = torch.fx.Proxy(flatten_node)[i].node # type: ignore[index] + orig_output.replace_all_uses_with(proxy_out, propagate_meta=True) + + node_name_map[orig_output.name] = proxy_out + + +def _swap_module_helper( + gm: torch.fx.GraphModule, + modules_to_swap: dict[str, torch.nn.Module], + module_call_graph: dict[str, ModuleCallSignature], +) -> torch.fx.GraphModule: + log.debug("Starting graph:") + log.debug(gm.graph) + + legalize_graph(gm) + + partitions: dict[str, NodeList] = defaultdict(list) + + node_name_map: dict[str, torch.fx.Node] = { + node.name: node for node in gm.graph.nodes + } + + # TODO: Handle the duplicate module case + for node in gm.graph.nodes: + if nn_module_stack := node.meta.get("nn_module_stack"): + for path, _ in nn_module_stack.values(): + if path in modules_to_swap: + partitions[path].append(node) + break + + for name, nodes in partitions.items(): + """ + Given a graph like the following, and we want to swap out the submodule "foo": + graph(): + %x : [num_users=1] = placeholder[target=x] + %y : [num_users=2] = placeholder[target=y] + %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%y, %x), kwargs = {}), nn_module_stack = {"foo": ("foo", torch.nn.Module)} + %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%y, %add), kwargs = {}), nn_module_stack = {"bar": ("bar", torch.nn.Module)} + return (sub,) + + We will first partition out foo's subgraph: + graph(): + %x : [num_users=1] = placeholder[target=x] + %y : [num_users=2] = placeholder[target=y] + %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%y, %x), kwargs = {}) + return add + + And then insert an unflatten + call_module + flatten to replace the subgraph: + graph(): + %x : [num_users=1] = placeholder[target=x] + %y : [num_users=1] = placeholder[target=y] + + %_spec_0 : [num_users=1] = get_attr[target=_spec_0] + %tree_unflatten : [num_users=2] = 
call_function[target=torch.utils._pytree.tree_unflatten](args = ([%x, %y], %_spec_0), kwargs = {}) + %getitem : [num_users=2] = call_function[target=operator.getitem](args = (%tree_unflatten, 0), kwargs = {}) + %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%getitem, 0), kwargs = {}) + %getitem_2 : [num_users=1] = call_function[target=operator.getitem](args = (%getitem, 1), kwargs = {}) + %getitem_3 : [num_users=0] = call_function[target=operator.getitem](args = (%tree_unflatten, 1), kwargs = {}) + %foo : [num_users=0] = call_module[target=foo](args = (%getitem_1, %getitem_2), kwargs = {}) + %_spec_1 : [num_users=1] = get_attr[target=_spec_1] + %tree_flatten_spec : [num_users=1] = call_function[target=torch.fx._pytree.tree_flatten_spec](args = (None, %_spec_1), kwargs = {}) + %getitem_4 : [num_users=1] = call_function[target=operator.getitem](args = (%tree_flatten_spec, 0), kwargs = {}) + + %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%y, %getitem_4), kwargs = {}) + return (%sub,) + + The `tree_unflatten` call will construct tensor inputs into the input + format needed by the swapped eager module. + The `call_module` node should now reference the swapped torch.nn.Module. + The `tree_flatten_spec` call will deconstruct the eager outputs of the + swapped module into tensors. 
+ """ # noqa: B950 + + submod_name = name.replace(".", "_") + sub_gm, orig_inputs, orig_outputs = fuse_as_graphmodule( + gm, nodes, f"fused_{submod_name}" + ) + + log.debug("Fused subgraph nodes:") + log.debug(sub_gm.graph) + + signature: ModuleCallSignature = module_call_graph[name] + + args_nodes, kwargs_nodes = _construct_inputs(gm, signature, node_name_map) + module_node = _insert_call_module( + gm, args_nodes, kwargs_nodes, modules_to_swap[name], name + ) + _deconstruct_outputs(gm, signature, module_node, node_name_map, orig_outputs) + + erase_nodes(gm, nodes) + + log.debug("Swapped graph:") + log.debug(gm.graph) + + legalize_graph(gm) + + log.debug("Before removing extraneous pytrees:") + log.debug(gm.graph) + + _remove_extraneous_pytrees(gm) + log.debug("After removing extraneous pytrees:") + log.debug(gm.graph) + + gm.recompile() + + return gm + + +def _fix_input_output_signature( + gm: torch.fx.GraphModule, signature: ModuleCallSignature +) -> None: + """ + Given the unlifted module from calling ep.module(), we want to remove the + pytree processing from the graph module's PyTreeCodeGen and instead make it + nodes inside of the graph. This allows us to do some optimizations, like + remove these pytree calls if it is unnecessary, and makes the PyTree part + more obvious to graph passes. 
+ """ + from torch.export.unflatten import _generate_flatten, _generate_unflatten + + # Remove the registered pytree codegen because we will take care of it + # through inserting pytree nodes into the graph + gm.graph._codegen = torch.fx.graph.CodeGen() + + old_placeholders = [node for node in gm.graph.nodes if node.op == "placeholder"] + + new_placeholders = [] + forward_arg_names = signature.forward_arg_names + if forward_arg_names is None: + forward_arg_names = [] + assert signature.in_spec.num_children == 2 + arg_spec = signature.in_spec.child(0) + kwarg_spec = signature.in_spec.child(1) + assert arg_spec.type is tuple + assert kwarg_spec.type is dict + for i in range(arg_spec.num_children): + forward_arg_names.append(f"arg_{i}") + forward_arg_names.extend(kwarg_spec.context) + + for arg in forward_arg_names: + with gm.graph.inserting_before(old_placeholders[0]): + new_placeholders.append(gm.graph.placeholder(arg)) + + # Insert flatten call for the inputs + with gm.graph.inserting_before(old_placeholders[0]): + flat_node = _generate_flatten(gm, tuple(new_placeholders)) + for i, old_placeholder in enumerate(old_placeholders): + old_placeholder.op = "call_function" + old_placeholder.target = operator.getitem + old_placeholder.args = (flat_node, i) + + # Insert unflatten call for the outputs + output_node = next(node for node in gm.graph.nodes if node.op == "output") + with gm.graph.inserting_before(output_node): + unflat = _generate_unflatten(gm, output_node.args[0], signature.out_spec) + output_node.args = (unflat,) + + gm.recompile() + + +def _swap_modules( + ep: ExportedProgram, modules_to_swap: dict[str, torch.nn.Module] +) -> torch.fx.GraphModule: + """ + Unlifts the given ExportedProgram into a fx.GraphModule, and then swaps + previously traced modules with new eager modules specified. Returns a + fx.GraphModule with a custom forward function. 
+ + Args: + ep (ExportedProgram): Exported program to modify + modules_to_swap (Dict[str, torch.nn.Module]): Mapping from module fqn to + eager module to swap with. The specified module fqn should have also + been specified in the `preserve_module_call_signature` argument to + torch.export so that we know how to restore the calling convention + to this argument. + run_with_interpreter: Whether or not to run the graph using + fx.Interpreter. Setting to true will help result in better error + messages and easier debugging, but it has found to result in a QPS + drop. + """ + module_call_graph = { + entry.fqn: entry.signature for entry in ep.module_call_graph if entry.signature + } + + gm = ep.module() + gm.validate_inputs = False # type: ignore[assignment] + gm.graph.eliminate_dead_code() # type: ignore[operator, union-attr] + assert isinstance(gm, torch.fx.GraphModule) + _fix_input_output_signature(gm, ep.module_call_graph[0].signature) + + gm.module_call_graph = ep.module_call_graph + gm.train = types.MethodType(type(gm).train, gm) # type: ignore[assignment] + gm.eval = types.MethodType(type(gm).eval, gm) # type: ignore[assignment] + + assert isinstance(gm, torch.fx.GraphModule) + gm = _swap_module_helper(gm, modules_to_swap, module_call_graph) + + return gm diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/_trace.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/_trace.py new file mode 100644 index 0000000000000000000000000000000000000000..ac401dc433bdaa04333917551e03900b7250ccc1 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/_trace.py @@ -0,0 +1,2463 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +import dataclasses +import functools +import inspect +import logging +import re +import sys +import time +import warnings +from collections.abc import Callable +from contextlib import contextmanager, ExitStack, nullcontext +from itertools import 
chain +from typing import Any, TYPE_CHECKING, TypeAlias +from unittest import mock + + +if TYPE_CHECKING: + import weakref + +import torch +import torch._dynamo +import torch.fx +import torch.utils._pytree as pytree +from torch._dispatch.python import enable_python_dispatcher +from torch._dynamo.exc import UserError, UserErrorType +from torch._export.db.logging import ( + exportdb_error_message, + get_class_if_classified_error, +) +from torch._export.non_strict_utils import ( + _fakify_module_inputs, + _fakify_script_objects, + _gather_constant_attrs, + _NonStrictTorchFunctionHandler, + _override_builtin_ops, + make_constraints, + make_fake_inputs, + produce_guards_and_solve_constraints, +) +from torch._export.passes.collect_tracepoints_pass import CollectTracepointsPass +from torch._export.passes.lift_constants_pass import ( + _materialize_and_lift_constants, + ConstantAttrMap, +) +from torch._export.utils import ( + _collect_param_buffer_metadata, + _compiling_state_context, + _fakify_params_buffers, + _populate_param_buffer_metadata_to_new_gm, + _update_gm_meta_if_possible, + apply_runtime_assertion_pass, + placeholder_naming_pass, + placeholder_prefixes, +) +from torch._export.verifier import SpecViolationError +from torch._export.wrappers import _wrap_submodules +from torch._functorch._aot_autograd.graph_capture_wrappers import create_functional_call +from torch._functorch._aot_autograd.input_output_analysis import ( + _graph_input_names, + _graph_output_names, +) +from torch._functorch._aot_autograd.schemas import GraphSignature +from torch._functorch._aot_autograd.subclass_utils import get_subclass_typing_container +from torch._functorch._aot_autograd.utils import ( + create_tree_flattened_fn, + register_buffer_assignment_hook, +) +from torch._functorch.aot_autograd import ( + _detect_attribute_assignment, + aot_export_joint_with_descriptors, +) +from torch._guards import detect_fake_mode, tracing, TracingContext +from torch._library.fake_class_registry 
import FakeScriptObject +from torch._logging import dtrace_structured +from torch._subclasses.fake_tensor import FakeTensorMode +from torch._utils_internal import log_export_usage +from torch.export._leakage_detection_utils import find_legit_leaks_from_referrers +from torch.export._unlift import _check_input_constraints_pre_hook +from torch.export.dynamic_shapes import ( + _check_dynamic_shapes, + _combine_args, + _DimHintType, + _IntWrapper, + _process_dynamic_shapes, +) +from torch.export.exported_program import OutputKind +from torch.fx._symbolic_trace import _ConstantAttributeType +from torch.fx.experimental.proxy_tensor import ( + get_proxy_slot, + make_fx, + PreDispatchTorchFunctionMode, + track_tensor_tree, +) +from torch.fx.experimental.symbolic_shapes import ( + ConstraintViolationError, + free_unbacked_symbols, + GuardOnDataDependentSymNode, + ShapeEnv, +) +from torch.fx.graph import _PyTreeInfo +from torch.utils._pytree import TreeSpec +from torch.utils._sympy.value_ranges import ValueRangeError + +from .exported_program import ( + _disable_prexisiting_fake_mode, + ExportedProgram, + InputKind, + ModuleCallEntry, + ModuleCallSignature, +) +from .graph_signature import _convert_to_export_graph_signature, ExportGraphSignature + + +log = logging.getLogger(__name__) + +# Type alias for dynamic shapes specification +_DynamicShapesSpec: TypeAlias = dict[str, Any] | tuple[Any, ...] | list[Any] + + +@dataclasses.dataclass +class ExportDynamoConfig: + """ + Manage Export-specific configurations of Dynamo. + """ + + allow_rnn: bool = True + reorderable_logging_functions: set[Callable] = dataclasses.field( + default_factory=set + ) + # Emit runtime asserts after AOTAutograd instead. + # This isn't really necessary, and isn't much more efficient since the runtime asserts pass does CSE, + # but if we want to reason more about what guards/runtime asserts to emit, + # this makes it a bit cleaner to do from the export side. Also no real point in running this twice. 
+ do_not_emit_runtime_asserts: bool = True + specialize_int: bool = True + specialize_float: bool = True + assume_static_by_default: bool = False + automatic_dynamic_shapes: bool = False + capture_dynamic_output_shape_ops: bool = True + capture_scalar_outputs: bool = True + prefer_deferred_runtime_asserts_over_guards: bool = False + replay_side_effects: bool = False + side_effect_replay_policy: str = "warn" + + +@dataclasses.dataclass +class ATenExportArtifact: + gm: torch.fx.GraphModule + sig: ExportGraphSignature + constants: dict[str, _ConstantAttributeType] + + +@dataclasses.dataclass(frozen=True) +class ExportArtifact: + aten: ATenExportArtifact + in_spec: TreeSpec + out_spec: TreeSpec + fake_mode: FakeTensorMode + module_call_specs: dict[str, dict[str, pytree.TreeSpec]] + + +DEFAULT_EXPORT_DYNAMO_CONFIG = ExportDynamoConfig() +DEFAULT_EXPORT_DYNAMO_CONFIG.reorderable_logging_functions = { + logging.critical, + logging.debug, + logging.error, + logging.exception, + logging.info, + logging.log, + logging.warning, + print, + warnings.warn, +} + + +@contextmanager +def _ignore_backend_decomps(): + orig_mkldnn_flag = torch.backends.mkldnn.set_flags(False) + orig_nnpack_flag = torch.backends.nnpack.set_flags(False) + try: + yield + finally: + torch.backends.mkldnn.set_flags(*orig_mkldnn_flag) + torch.backends.nnpack.set_flags(*orig_nnpack_flag) + + +@contextmanager +def _disable_custom_triton_op_functional_decomposition(): + old = torch._functorch.config.decompose_custom_triton_ops + try: + # pyrefly: ignore [bad-assignment] + torch._functorch.config.decompose_custom_triton_ops = False + yield torch._functorch.config.decompose_custom_triton_ops + finally: + torch._functorch.config.decompose_custom_triton_ops = old + + +def custom_triton_ops_decomposition_disabled(): + return not torch._functorch.config.decompose_custom_triton_ops + + +def _fixup_key(x): + return "L__self__" + _strip_root(x) + + +def _strip_root(x): + if isinstance(x, str) and 
x.startswith("_export_root"): + stripped = x[len("_export_root") :] + return stripped.removeprefix(".") + return x + + +def _is_bogus_const_name(name: str): + splitted_names = name.split(".") + if len(splitted_names) < 1: + return True + + return splitted_names[-1].startswith("lifted_tensor") + + +def _rewrite_tracepoint_node(gm: torch.fx.GraphModule): + """ + In-place modify input graph module by replacing the export tracepoint with a new node + that has the same target and args, but with the _export_root stripped from path. + """ + for node in gm.graph.nodes: + if node.target is torch.ops.higher_order._export_tracepoint: + if "path" in node.kwargs: + path = _strip_root(node.kwargs["path"]) + with gm.graph.inserting_before(node): + new_node = gm.graph.create_node( + "call_function", + torch.ops.higher_order._export_tracepoint, + args=node.args, + kwargs={ + "path": path, + "kind": node.kwargs["kind"], + }, + ) + new_node.meta = node.meta + node.replace_all_uses_with(new_node) + gm.graph.erase_node(node) + + +def detect_shape_env(inputs: Any = None): + shape_envs = [] + + for i, flat_input in enumerate(inputs): + if isinstance(flat_input, torch.SymInt): + shape_envs.append((flat_input.node.shape_env, "symint input", i)) + + if shape_envs: + shape_env, desc1, i1 = shape_envs[0] + for m, desc2, i2 in shape_envs[1:]: + assert shape_env is m, ( + f"shape env ({shape_env}) from {desc1} {i1} doesn't match mode ({m}) from {desc2} {i2}\n\n" + f"shape env from {desc1} {i1} allocated at:\n{shape_env.stack}\n" + f"shape env from {desc2} {i2} allocated at:\n{m.stack}" + ) + return shape_env + else: + return None + + +def _extract_fake_inputs(gm, args, kwargs): + """ + Given a graph module, extract fakified input tensors from the metadata of + its placeholders, and map them to the structure of given args and kwargs. + Also return the fake mode used to fakify those inputs. 
+ """ + fake_inps: list[Any] = [] + fake_vals: list[Any] = [] + for node in gm.graph.nodes: + if node.op == "placeholder": + fake_inps.append(node.meta.get("val")) + else: + fake_vals.append(node.meta.get("example_value")) + + if in_shuffle_graph := getattr(gm, "_in_shuffle_graph", None): + flat_args = pytree.tree_leaves((args, kwargs)) + node_map = { + node: i + for i, node in enumerate( + next(iter(reversed(in_shuffle_graph.graph.nodes))).args[0] + ) + if node.op == "placeholder" + } + new_fake_inps: list[Any] = [] + for i, node in enumerate( + in_shuffle_graph.graph.find_nodes(op="placeholder")[1:] + ): + if node in node_map: + new_fake_inps.append(fake_inps[node_map[node]]) + else: + new_fake_inps.append(flat_args[i]) + fake_inps = new_fake_inps + # We get both because now we might have a combination of symint and tensor + # inputs, and we want to check that the shape env is consistent between + # both. Unfortunately we can't see what fake mode is attached to the shape + # env, then we can just compare fake modes. 
+ detected_fake_mode = detect_fake_mode(fake_inps + fake_vals) + detected_shape_env = detect_shape_env(fake_inps + fake_vals) + + if detected_fake_mode: + if detected_shape_env: + assert detected_shape_env is detected_fake_mode.shape_env, ( + "Detected shape env does not match fake mode's shape env" + ) + fake_mode = detected_fake_mode + elif detected_shape_env: + fake_mode = FakeTensorMode(shape_env=detected_shape_env, export=True) + else: + fake_mode = FakeTensorMode(shape_env=ShapeEnv(), export=True) + + count = 0 + + def lookup_fake(x): + nonlocal count + val = fake_inps[count] if isinstance(x, (int, torch.Tensor)) else x + count += 1 + return val + + fake_args = pytree.tree_map(lookup_fake, args) + fake_kwargs = pytree.tree_map(lookup_fake, kwargs) + + return fake_args, fake_kwargs, fake_mode + + +def _replace_param_buffer_names(param_buffer_table, sig): + for spec in sig.input_specs: + if spec.kind in ( + InputKind.PARAMETER, + InputKind.BUFFER, + ): + spec.target = param_buffer_table[spec.target] + for spec in sig.output_specs: + if spec.kind in ( + OutputKind.BUFFER_MUTATION, + OutputKind.GRADIENT_TO_PARAMETER, + ): + spec.target = param_buffer_table[spec.target] + + +def _convert_to_positional_args(orig_arg_names, args, kwargs): + assert len(orig_arg_names) == len(args) + len(kwargs), ( + f"Total number of arg names is expected to be {len(orig_arg_names)} " + f"but got {len(args)} positional args, {len(kwargs)} kwargs." + ) + reordered_kwargs = [kwargs[kw_name] for kw_name in orig_arg_names[len(args) :]] + return ( + *args, + *reordered_kwargs, + ) + + +def _normalize_nn_module_stack(gm_torch_level, root_cls): + # Append a root module to every nn_module_stack. 
+ root = "L['self']" + root_key = re.sub(r"[^a-zA-Z0-9]", "_", root) + for gm in gm_torch_level.modules(): + if not isinstance(gm, torch.fx.GraphModule): + continue + for node in gm.graph.nodes: + if node.op in ["placeholder", "output"]: + continue + add_root = True + if nn_module_stack := node.meta.get("nn_module_stack", {}): + path, ty = next(iter(nn_module_stack.values())) + # After deserializing the class `ty` might not exist anymore so + # it could be a string + if inspect.isclass(ty) and issubclass(ty, torch.nn.Module): + # TODO Figure out why sometimes we have root sometimes we don't. + if path == root and ty is root_cls: + add_root = False + else: + assert isinstance(ty, str) + if add_root: + + def normalize_path(path): + if path == "L['self']": + return "" + if path.startswith("L['self']."): + return path[len("L['self'].") :] + return path + + nn_module_stack = { + root_key: (root, root_cls.__module__ + "." + root_cls.__qualname__), + # pyrefly: ignore [unbound-name] + **nn_module_stack, + } + node.meta["nn_module_stack"] = { + key: (normalize_path(path), ty) + for key, (path, ty) in nn_module_stack.items() + } + + +def _get_param_buffer_mapping( + original_module: torch.nn.Module, + traced_module: torch.nn.Module, +) -> dict[str, str]: + """ + Returns a mapping of parameter/buffer names from the new module to the + original model. This is to help with restoring the FQN for parameter/buffers + of a traced module to what the original module contains. + """ + + param_lookup: dict[int, str] = {} + buffer_lookup: dict[int, str] = {} + for name, param in original_module.named_parameters(remove_duplicate=False): + if param_lookup.get(id(param)) is None: + # we only want to keep the first occurrence of a parameter to guarantee parity of original and traced module. 
+ param_lookup[id(param)] = name + for name, buffer in original_module.named_buffers(remove_duplicate=False): + buffer_lookup[id(buffer)] = name + + param_buffer_table: dict[str, str] = {} + for dynamo_name, dynamo_param in traced_module.named_parameters( + remove_duplicate=False + ): + assert dynamo_name not in param_buffer_table + if id(dynamo_param) in param_lookup: + param_buffer_table[dynamo_name] = param_lookup[id(dynamo_param)] + + for dynamo_name, dynamo_buffer in traced_module.named_buffers( + remove_duplicate=False + ): + assert dynamo_name not in param_buffer_table + if id(dynamo_buffer) in buffer_lookup: + param_buffer_table[dynamo_name] = buffer_lookup[id(dynamo_buffer)] + + return param_buffer_table + + +def _preserve_requires_grad_pass( + gm: torch.fx.GraphModule, + sig: ExportGraphSignature, + fake_params_buffers: dict[str, torch.Tensor], + constants: dict[str, _ConstantAttributeType], + flat_fake_args: list[Any], +): + placeholders = [node for node in gm.graph.nodes if node.op == "placeholder"] + assert len(sig.input_specs) == len(placeholders) + i = 0 + for node, spec in zip(placeholders, sig.input_specs): + if spec.kind in ( + InputKind.PARAMETER, + InputKind.BUFFER, + ): + assert spec.target is not None + node.meta["val"].requires_grad = fake_params_buffers[ + spec.target + ].requires_grad + elif spec.kind == InputKind.USER_INPUT: + fake_arg = flat_fake_args[i] + if isinstance(fake_arg, torch.Tensor): + node.meta["val"].requires_grad = fake_arg.requires_grad + i += 1 + elif spec.kind == InputKind.CONSTANT_TENSOR: + assert spec.target is not None + constant = constants[spec.target] + if isinstance(constant, torch.Tensor): + # If the tensor is not leaf, it should already have a correct requires grad field + if node.meta["val"].is_leaf: + node.meta["val"].requires_grad = constant.requires_grad + else: + assert node.meta["val"].requires_grad == constant.requires_grad + elif spec.kind in (InputKind.CUSTOM_OBJ, InputKind.TOKEN): + continue + else: + 
raise AssertionError(spec.kind) + + +def _remap_constants( + orig_constant_attrs: ConstantAttrMap, + graph_signature: ExportGraphSignature, + constants: dict[str, _ConstantAttributeType], +) -> None: + """Rewrite the graph signature and constants table to use the FQN from the original module.""" + remap_table: dict[str, list[str]] = {} + for name, value in constants.items(): + if value in orig_constant_attrs: + remap_table[name] = orig_constant_attrs[value] + + for spec in graph_signature.input_specs: + if spec.kind in ( + InputKind.CONSTANT_TENSOR, + InputKind.CUSTOM_OBJ, + ): + orig_target = spec.target + assert orig_target is not None + targets = remap_table.get(orig_target, [orig_target]) + spec.target = targets[0] + + constant = constants[orig_target] + del constants[orig_target] + for target in targets: + constants[target] = constant + + +def _replace_unbacked_bindings(gm: torch.fx.GraphModule) -> None: + """ + When we run an interpreter-based pass over a GraphModule, execution of data-dependent operators + will produce example values with new unbacked symbols. To track that the new/old symbols are equivalent, + we used to rely on the unbacked_renamings mapping. This led to problematic metadata where the unbacked_bindings + keys mapped new symbols (u2) to paths containing old symbols (u0) in the example values, or worse, backed symbols + or constants (e.g. if the original unbacked was replaced/specialized). Additionally this created problems with + de/serialized programs, since we didn't comprehensively serialize ShapeEnv/unbacked renamings/node bindings. + + This pass attempts a simpler way of handling these for export, by throwing away the previously computed bindings, and re-running + the pattern match used in compute_unbacked_bindings. This ensures we keep the original symbols contained in the example values, + or delete bindings if they've been replaced/specialized. 
+ """ + from torch._export.utils import _get_shape_env_from_gm + from torch.fx.experimental.symbolic_shapes import _free_unbacked_symbols_with_path + from torch.utils._sympy.symbol import symbol_is_type, SymT + + if (shape_env := _get_shape_env_from_gm(gm)) is None: + return + + base_unbacked_symbols = { + symbol + for symbol in shape_env.var_to_range + if symbol_is_type(symbol, (SymT.UNBACKED_INT, SymT.UNBACKED_FLOAT)) + and symbol not in shape_env.unbacked_renamings + } + for node in gm.graph.nodes: + node.meta.pop("unbacked_bindings", None) + if (val := node.meta.get("val")) is not None and ( + unbacked_bindings := _free_unbacked_symbols_with_path( + val, + (), + shape_env=shape_env, + pending=base_unbacked_symbols, + simplify=True, + ) + ): + node.meta["unbacked_bindings"] = unbacked_bindings + + +def _produce_aten_artifact( + *, + gm: torch.fx.GraphModule, + mod, + constant_attrs, + graph_signature, + pre_dispatch, + fake_args, + fake_kwargs, + fake_params_buffers, + _prettify_placeholder_names=True, +) -> ATenExportArtifact: + """ + This is a helper function that is shared between export_to_aten_ir and export_to_aten_ir_make_fx + to produce the aten artifact. (export compatible graph module + signature) + + It does: + 1. Applies runtime assertion pass + 2. Recompute unbacked_bindings pass + 3. Populate meta val when missing + 4. Lift constants as placeholders + 5. Replace raw autograd and autocast ops with HOPs + 6. Prettify names for placeholders + 7. Preserve requires_grad value on node meta val + """ + # Run runtime asserts pass before creating input/output specs, since size-related CSE/DCE might affect output signature. + # Overwrite output specs afterwards. + flat_fake_args = pytree.tree_leaves((fake_args, fake_kwargs)) + gm, graph_signature = apply_runtime_assertion_pass(gm, graph_signature) + + # Simplify unbacked_bindings by recomputing them. + # Useful for any pass that's interpreter-based and might call rebind_unbacked(), + # e.g. 
AOTAutograd in this case. + _replace_unbacked_bindings(gm) + + total_non_user_inputs = ( + len(graph_signature.parameters) + + len(graph_signature.buffers) + + len(graph_signature.input_tokens) + ) + set_missing_meta_vals(gm, flat_fake_args, total_non_user_inputs) + + export_graph_signature: ExportGraphSignature | None + export_graph_signature = _convert_to_export_graph_signature( + graph_signature, gm, _get_non_persistent_buffers(mod) + ) + + # script objects are always stored in constants no matter whether they're initial inputs or + # they're lifted in aot" before rewrite_script_object_meta + constants = _materialize_and_lift_constants( + gm, export_graph_signature, constant_attrs + ) + + if pre_dispatch: + from torch._export.passes.replace_autocast_with_hop_pass import ( + replace_autocast_with_hop_pass, + ) + from torch._export.passes.replace_set_grad_with_hop_pass import ( + replace_set_grad_with_hop_pass, + ) + + # Note: replace_set_grad_with_hop_pass need to be after lift_constant_pass because + # a getattr of a constant tensor doesn't have meta["val"] until after lift_constant_pass. + # If replace_set_grad_with_hop_pass is before lift_constant_pass, + # and the constant_tensor is passed as input of the set grad hop, the placeholder's + # meta["val"] will be None and fails our verifier for placeholder. + gm, export_graph_signature = replace_set_grad_with_hop_pass( + gm, export_graph_signature + ) + + gm, export_graph_signature = replace_autocast_with_hop_pass( + gm, export_graph_signature + ) + + # Remove nn_module_stack, stack_trace metadata from all placeholders/inputs nodes. + for _mod in gm.modules(): + if not isinstance(_mod, torch.fx.GraphModule): + continue + for node in _mod.graph.nodes: + if node.op in ["placeholder", "output"]: + node.meta.pop("nn_module_stack", None) + node.meta.pop("stack_trace", None) + + # Prettify names for placeholder nodes. 
+ assert export_graph_signature is not None + if _prettify_placeholder_names: + placeholder_naming_pass( + gm, + export_graph_signature, + mod, + fake_args, + fake_kwargs, + fake_params_buffers, + constants, + ) + + _preserve_requires_grad_pass( + gm, export_graph_signature, fake_params_buffers, constants, flat_fake_args + ) + + return ATenExportArtifact( + gm, + export_graph_signature, + constants, + ) + + +def _rename_constants_nodes( + gm: torch.fx.GraphModule, + graph_signature: ExportGraphSignature, +) -> None: + """ + For strict mode, rename constants nodes that were previously annotated as buffers. + """ + # handle name collisions with existing constants + node_names = {node.name for node in gm.graph.nodes} + + def rename_constant(name): + if name in node_names: + n = 1 + while (dup_name := f"{name}_{n}") in node_names: + n += 1 + name = dup_name + node_names.add(name) + return name + + # use input specs to map names from buffers to constants + buffer_prefix = placeholder_prefixes[InputKind.BUFFER] + const_prefix = placeholder_prefixes[InputKind.CONSTANT_TENSOR] + buffer_to_constant = {} + for spec in graph_signature.input_specs: + if spec.kind == InputKind.CONSTANT_TENSOR and not spec.arg.name.startswith( + const_prefix + ): + if spec.arg.name.startswith(buffer_prefix): # map from buffer to constants + c_name = rename_constant( + const_prefix + spec.arg.name[len(buffer_prefix) :] + ) + else: # lifted constant + c_name = rename_constant(const_prefix + spec.arg.name) + buffer_to_constant[spec.arg.name] = c_name + spec.arg.name = c_name + for spec in graph_signature.output_specs: + if spec.arg.name in buffer_to_constant: + spec.arg.name = buffer_to_constant[spec.arg.name] + + # Rename constants nodes for all modules + for mod in gm.modules(): + if not isinstance(mod, torch.fx.GraphModule): + continue + for node in mod.graph.nodes: + if node.name in buffer_to_constant: + node.name = node.target = buffer_to_constant[node.name] + mod.recompile() + + +def 
_restore_state_dict( + original_module: torch.nn.Module, traced_module: torch.fx.GraphModule +) -> None: + """ + Restores the state dict of the traced module to that of the original module. + """ + param_buffer_table = _get_param_buffer_mapping(original_module, traced_module) + # Don't want to change the convention of previous call. + param_buffer_table_reverse = {v: k for k, v in param_buffer_table.items()} + + # Replace state dict attr names with the fqn + for name, _ in list( + chain( + original_module.named_parameters(remove_duplicate=False), + # pyrefly: ignore [bad-argument-type] + original_module.named_buffers(remove_duplicate=False), + ) + ): + if name in param_buffer_table_reverse: + dynamo_name = param_buffer_table_reverse[name] + param = torch.fx.graph_module._get_attr(traced_module, dynamo_name) + torch.fx.graph_module._assign_attr(param, traced_module, name) + torch.fx.graph_module._del_attr(traced_module, dynamo_name) + + # Replace graph getattr nodes with the correct name + for node in traced_module.graph.nodes: + if node.op == "get_attr": + attr_name = node.target + if attr_name in param_buffer_table: + node.target = param_buffer_table[attr_name] + + traced_module.recompile() + + +def _get_module_hierarchy(mod: torch.nn.Module) -> dict[str, str]: + return { + name: type(m).__name__ for name, m in mod.named_modules(remove_duplicate=False) + } + + +def _make_module_call_graph( + in_spec: TreeSpec, + out_spec: TreeSpec, + module_call_signatures: dict[str, ModuleCallSignature], + forward_arg_names: list[str] | None = None, +) -> list[ModuleCallEntry]: + original = [ + ModuleCallEntry(fqn=fqn, signature=module_call_signatures.get(fqn)) + for fqn in _EXPORT_MODULE_HIERARCHY # type: ignore[union-attr] + ] + assert original[0].fqn == "" + original[0].signature = ModuleCallSignature( + inputs=[], + outputs=[], + in_spec=in_spec, + out_spec=out_spec, + forward_arg_names=forward_arg_names, + ) + additional = [ + ModuleCallEntry(fqn=fqn, signature=signature) + 
for fqn, signature in module_call_signatures.items() + if fqn not in _EXPORT_MODULE_HIERARCHY # type: ignore[operator] + ] + return [*original, *additional] + + +class _ExportModuleSpecTrackerDict(dict): + pass + + +def _export_to_torch_ir( + f: Callable, + args: tuple[Any, ...], + kwargs: dict[str, Any] | None = None, + dynamic_shapes: dict[str, Any] | tuple[Any] | list[Any] | None = None, + *, + preserve_module_call_signature: tuple[str, ...] = (), + disable_constraint_solver: bool = False, + prefer_deferred_runtime_asserts_over_guards: bool = False, + restore_fqn: bool = True, + _log_export_usage: bool = True, + same_signature: bool = True, +) -> torch.fx.GraphModule: + """ + Traces either an nn.Module's forward function or just a callable with PyTorch + operations inside and produce a torch.fx.GraphModule in torch IR. + """ + + if _log_export_usage: + log_export_usage(event="export.private_api", flags={"_export_to_torch_ir"}) + + if not isinstance(args, tuple): + raise UserError( + UserErrorType.INVALID_INPUT, + f"Expecting `args` to be a tuple of example positional inputs, got {type(args)}", + ) + + kwargs = kwargs or {} + + # Map ints to a wrapper structure to help us mark it as dynamic, if it is + # dynamic. We will unwrap ints in fakify later. + args, kwargs = pytree.tree_map_only(int, _IntWrapper, (args, kwargs)) + + combined_args = _combine_args(f, args, kwargs) + _check_dynamic_shapes(combined_args, dynamic_shapes) + constraints = _process_dynamic_shapes(combined_args, dynamic_shapes) + + # Unwrap static ints -- in the case where we have an empty graph + # containing just integer computation, dynamo will run its generated + # bytecode with these args/kwargs, which will error because we cannot + # directly apply int operations on IntWrapper. So we will just unwrap + # them here. 
+ args, kwargs = pytree.tree_map_only( + _IntWrapper, + lambda a: a.val + if a.dynamism is None or a.dynamism.type == _DimHintType.STATIC + else a, + (args, kwargs), + ) + + dynamo_cfg = dataclasses.replace( + DEFAULT_EXPORT_DYNAMO_CONFIG, + prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards, + ) + + def use_legacy_dynamo_graph_capture() -> bool: + return bool( + constraints # dynamic shape + or dynamic_shapes # dynamic shape + or isinstance(f, torch.fx.GraphModule) # retracing + or preserve_module_call_signature # unflatten + or torch._functorch.config.fake_tensor_propagate_real_tensors # draft + or torch._export.config.use_legacy_dynamo_graph_capture + ) + + with torch._dynamo.config.patch(dataclasses.asdict(dynamo_cfg)): + try: + module_call_specs: dict[str, dict[str, pytree.TreeSpec]] = ( + _ExportModuleSpecTrackerDict() + ) + ctx = nullcontext() + if not isinstance(f, torch.fx.GraphModule): + ctx = _wrap_submodules( # type: ignore[assignment] + f, preserve_module_call_signature, module_call_specs + ) + with ctx, _ignore_backend_decomps(): + if torch._export.config.use_new_tracer_experimental: + from torch._dynamo.functional_export import ( + _dynamo_graph_capture_for_export, + dynamo_graph_capture_for_export, + ) + + if use_legacy_dynamo_graph_capture(): + dynamo_graph_capture = _dynamo_graph_capture_for_export( + f, constraints=constraints, dynamic_shapes=dynamic_shapes + ) + else: + dynamo_graph_capture = torch._dynamo.config.patch( + replay_side_effects=False + )(dynamo_graph_capture_for_export(f)) + # We can't serialize entire fake mode yet, so this is to make sure + # things like copy.deepcopy(ep.graph_module) not crash. + # see test_export.py::test_custom_tag_metadata_re_export + # Once we delete the old strict export, we can use + gm_torch_level = dynamo_graph_capture(*args, **kwargs) + # We can't serialize entire fake mode yet, so this is to make sure + # things like copy.deepcopy(ep.graph_module) not crash. 
+ # see test_export.py::test_custom_tag_metadata_re_export + # Once we delete the old strict export, we can use this fake mode in the + # subsequent logic when lowering to aten IR. + del gm_torch_level.meta["fake_mode"] + + else: + gm_torch_level, _ = torch._dynamo.export( + f, + dynamic_shapes=dynamic_shapes, # type: ignore[arg-type] + constraints=constraints, # type: ignore[arg-type] + assume_static_by_default=True, + tracing_mode="symbolic", + disable_constraint_solver=disable_constraint_solver, + prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards, + _log_export_usage=_log_export_usage, + same_signature=same_signature, + )( + *args, + **kwargs, + ) + gm_torch_level.meta["module_call_specs"] = module_call_specs + except (ConstraintViolationError, ValueRangeError) as e: + raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e)) # noqa: B904 + except GuardOnDataDependentSymNode as e: + raise UserError( # noqa: B904 + UserErrorType.ANTI_PATTERN, + f"Consider annotating your code using torch._check*(). 
{str(e)}", + case_name="constrain_as_size_example", + ) + + if isinstance(f, torch.nn.Module) and restore_fqn: + _restore_state_dict(f, gm_torch_level) + + return gm_torch_level + + +def _aot_export_joint_with_descriptors( + stack, + mod, + args, + *, + kwargs, + decompositions, + fake_params_buffers, + _record_nn_module_stack=True, +): + from torch._functorch._aot_autograd.graph_compile import aot_stage2_export + from torch._functorch._aot_autograd.input_output_analysis import ( + create_graph_signature, + ) + + joint_with_descriptors = aot_export_joint_with_descriptors( + stack, + mod, + args, + kwargs=kwargs, + decompositions=decompositions, + _record_nn_module_stack=_record_nn_module_stack, + ) + # Convert JointWithDescriptors to graph module and ViewAndMutationMeta + gm, fw_metadata = aot_stage2_export( + joint_with_descriptors._aot_state, + joint_with_descriptors._aot_graph_capture, + ) + + assert isinstance(gm, torch.fx.GraphModule) + + # Create GraphSignature from the metadata + graph_signature = create_graph_signature( + gm, + fw_metadata, + joint_with_descriptors.in_spec, + joint_with_descriptors.out_spec, + user_args_flat=pytree.tree_leaves((args, kwargs)), + params_and_buffers_flat=list(fake_params_buffers.values()), + param_names=joint_with_descriptors.params_spec, + buffer_names=joint_with_descriptors.buffers_spec, + trace_joint=False, + num_user_fw_outs=None, + loss_index=None, + ) + return gm, graph_signature + + +def _export_to_aten_ir( + mod: torch.nn.Module, + fake_args, + fake_kwargs, + fake_params_buffers, + constant_attrs: ConstantAttrMap, + produce_guards_callback=None, + *, + transform=lambda x: x, # TODO(zhxchen17) Revisit if this is needed later. 
+ pre_dispatch=False, + decomp_table=None, + _prettify_placeholder_names: bool = True, + decompose_custom_triton_ops: bool = False, +) -> ATenExportArtifact: + custom_triton_ops_decomposition_ctx = ( + nullcontext + if decompose_custom_triton_ops + else _disable_custom_triton_op_functional_decomposition + ) + # This _reparameterize_module makes sure inputs and module.params/buffers have the same fake_mode, + # otherwise aot_export_module will error out because it sees a mix of fake_modes. + # And we want aot_export_module to use the fake_tensor mode in dynamo to keep the pipeline easy to reason about. + with ExitStack() as stack: + stack.enter_context( + torch.nn.utils.stateless._reparametrize_module( + mod, + fake_params_buffers, + tie_weights=True, + strict=True, + stack_weights=True, + ) + ) + stack.enter_context(_ignore_backend_decomps()) + stack.enter_context(_compiling_state_context()) + stack.enter_context(custom_triton_ops_decomposition_ctx()) + stack.enter_context(torch.no_grad()) + + gm, graph_signature = transform(_aot_export_joint_with_descriptors)( + stack, + mod, + fake_args, + kwargs=fake_kwargs, + decompositions=decomp_table, + fake_params_buffers=fake_params_buffers, + _record_nn_module_stack=True, + ) + + def _maybe_fixup_gm_and_output_node_meta(old_gm, new_gm): + if isinstance(old_gm, torch.fx.GraphModule): + if hasattr(old_gm, "meta"): + new_gm.meta.update(old_gm.meta) + old_output_node = list(old_gm.graph.nodes)[-1] + new_output_node = list(new_gm.graph.nodes)[-1] + assert old_output_node.op == "output" and new_output_node.op == "output" + # make sure we don't override any meta + if "desc" in new_output_node.meta: + del new_output_node.meta["desc"] + new_output_node.meta.update(old_output_node.meta) + + # TODO unfortunately preserving graph-level metadata and output node's meta + # is not working well with aot_export. So we manually copy it. + # (The node-level meta is addressed above.) 
+ _maybe_fixup_gm_and_output_node_meta(mod, gm) + + # Run produce guards before we handle runtime asserts. + # This means we run the export solver before the runtime asserts pass. + # Right now this doesn't mean much - the export solver is only there for suggested fixes, + # and we won't even get to constraint solving if that's needed. + # But if in future we want to control what runtime asserts are emitted for export, + # or rely on produce_guards + solver for some simplification on runtime asserts, this probably makes sense. + if produce_guards_callback: + try: + produce_guards_callback(gm) + except (ConstraintViolationError, ValueRangeError) as e: + raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e)) # noqa: B904 + + return _produce_aten_artifact( + gm=gm, + mod=mod, + constant_attrs=constant_attrs, + graph_signature=graph_signature, + pre_dispatch=pre_dispatch, + fake_args=fake_args, + fake_kwargs=fake_kwargs, + fake_params_buffers=fake_params_buffers, + _prettify_placeholder_names=_prettify_placeholder_names, + ) + + +def _get_forward_arg_names( + mod: torch.nn.Module, + args: tuple[Any, ...], + kwargs: dict[str, Any] | None = None, +) -> list[str]: + """ + Gets the argument names to forward that are used, for restoring the + original signature when unlifting the exported program module. + - Positional args: retain the original argument names, and enumerate + *args as args_0, args_1, ... + - Keyword args: retain the original kwarg names in the order specified + by the user. This order seems to matter for the current state of + export lifted modules. 
+ """ + sig = inspect.signature(mod.forward) + _args = sig.bind_partial(*args).arguments + + names: list[str] = [] + for name, value in _args.items(): + # handle variable number of positional args + if sig.parameters[name].kind == inspect._ParameterKind.VAR_POSITIONAL: + names.extend([f"{name}_{i}" for i, _ in enumerate(value)]) + else: + names.append(name) + # order of kwargs matters for input spec + if kwargs: + names.extend([kwarg for kwarg, _ in kwargs.items()]) + + return names + + +def _get_non_persistent_buffers(mod: torch.nn.Module) -> set[str]: + """ + Returns set of non-persistent buffers in a module and its submodules. + """ + result: set[str] = set() + for name, m in mod.named_modules(remove_duplicate=False): + if name: + result.update(f"{name}.{b}" for b in m._non_persistent_buffers_set) + else: + result.update(m._non_persistent_buffers_set) + return result + + +def _rewrite_dynamo_tensor_constants( + orig_mod_buffers: set[torch.Tensor], + traced_mod_buffers: dict[str, torch.Tensor], + graph_signature: ExportGraphSignature, + constants: dict[str, _ConstantAttributeType], +) -> None: + """ + Dynamo erroneously marks tensor attributes on modules as buffers. + Rewrite them to be tensor constants. + """ + for spec in graph_signature.input_specs: + if spec.kind == InputKind.BUFFER: + assert spec.target is not None + value = traced_mod_buffers[spec.target] + if value not in orig_mod_buffers: + # This was a tensor constant erroneously marked as a buffer. + # Convert it into a constant in the graph signature, and add its + # value to the constants table. + spec.kind = InputKind.CONSTANT_TENSOR + constants[spec.target] = value # type: ignore[arg-type] + + +def _move_non_persistent_buffers_to_tensor_constants( + orig_mod: torch.nn.Module, + graph_signature: ExportGraphSignature, + constants: dict[str, _ConstantAttributeType], +) -> None: + """ + Moves non-persistent buffers to tensor constants. 
+ """ + for spec in graph_signature.input_specs: + if spec.kind == InputKind.BUFFER and not spec.persistent: + assert spec.target is not None + assert spec.target not in constants + constants[spec.target] = orig_mod.get_buffer(spec.target) # type: ignore[arg-type] + + +def _verify_nn_module_stack(graph_module: torch.fx.GraphModule) -> None: + """ + Perform nn_module_stack checks on the graph. + Current constraints: + For the top level graph: + - populated for 'call_function', 'get_attr' + - None for 'placeholder', 'output' + For submodule graphs: + - None for 'placeholder', output' + + TODO(pianpwk): make this a consistent node-level check once nn_module_stack is populated for cond submodules. + """ + # Check top-level graph for all nodes, all graphs for placeholder & output nodes + for i, mod in enumerate([graph_module] + list(graph_module.modules())): + if not isinstance(mod, torch.fx.GraphModule): + continue + for node in mod.graph.nodes: + if node.op in ["call_function", "get_attr"]: + if i == 0: + if ( + nn_module_stack := node.meta.get("nn_module_stack", None) + ) is None: + raise SpecViolationError( + f"Node {node} of type {node.op} is missing nn_module_stack metadata" + ) + if not all( + isinstance(k, str) + and isinstance(v, tuple) + and len(v) == 2 + and all(isinstance(x, str) for x in v) + for k, v in nn_module_stack.items() + ): + raise SpecViolationError( + f"Node {node} of type {node.op} has incorrect nn_module_stack metadata format" + f"expected Dict[str, Tuple[str, str]], but got {nn_module_stack}" + ) + elif node.op in ["placeholder", "output"]: + if node.meta.get("nn_module_stack", None): + raise SpecViolationError( + f"Node {node} of type {node.op} contains nn_module_stack metadata, this should be None" + ) + + +def _verify_stack_trace(graph_module: torch.fx.GraphModule) -> None: + """ + Perform stack trace checks on the graph. 
+ Constraints: + - None or non-empty str for 'call_function', 'get_attr' + - None for 'placeholder', 'output' + """ + for mod in [graph_module, *graph_module.modules()]: + if not isinstance(mod, torch.fx.GraphModule): + continue + for node in graph_module.graph.nodes: + stack_trace = node.meta.get("stack_trace", None) + if node.op in ["call_function", "get_attr"]: + if not (stack_trace is None or isinstance(stack_trace, str)): + raise SpecViolationError( + f"Node {node} of type {node.op} has invalid stack_trace metadata, " + f"expected a string or None but instead found: {stack_trace}" + ) + elif node.op in ["placeholder", "output"]: + if stack_trace: + raise SpecViolationError( + f"Node {node} of type {node.op} contains stack_trace metadata, " + f"expected None but instead found: {stack_trace}" + ) + + +def _verify_placeholder_names( + gm: torch.fx.GraphModule, sig: ExportGraphSignature +) -> None: + """ + Performs a sanity check on the placeholder node names. + - User input nodes: no restrictions, should match the original forward() signature + - Params/buffers/constants/custom_obj/token nodes: should start with prefixes defined in + """ + name_to_kind = {spec.arg.name: spec.kind for spec in sig.input_specs} + for mod in gm.modules(): + if not isinstance(mod, torch.fx.GraphModule): + continue + for node in mod.graph.nodes: + if node.op == "placeholder": + if node.name not in name_to_kind: + continue + node_kind = name_to_kind[node.name] + prefix = placeholder_prefixes[node_kind] + if not node.name.startswith(prefix): + raise SpecViolationError( + f"Placeholder node name {node.name} does not follow spec for {node_kind}, name should have prefix: {prefix}" + ) + + +def get_ep_stats(ep: ExportedProgram) -> dict[str, Any]: + op_count = 0 + op_set = set() + for m in ep.graph_module.modules(): + if not isinstance(m, torch.fx.GraphModule): + continue + for node in m.graph.nodes: + if node.op != "call_function": + continue + op_count += 1 + assert hasattr(node.target, 
"__module__") + assert hasattr(node.target, "__name__") + op_set.add(f"{node.target.__module__}.{node.target.__name__}") + return {"op_count": op_count, "op_set": op_set} + + +_EXPORT_FLAGS: set[str] | None = None +_EXPORT_MODULE_HIERARCHY: dict[str, str] | None = None + + +def _log_export_wrapper(fn): + @functools.wraps(fn) + def wrapper(*args, **kwargs): + global _EXPORT_FLAGS, _EXPORT_MODULE_HIERARCHY + try: + start = time.time() + ep = fn(*args, **kwargs) + end = time.time() + log_export_usage( + event="export.time", + metrics=end - start, + flags=_EXPORT_FLAGS, + **get_ep_stats(ep), + ) + except Exception as e: + t = type(e) + error_type = t.__module__ + "." + t.__qualname__ + case_name = get_class_if_classified_error(e) + if case_name is not None: + log.error(exportdb_error_message(case_name)) + log_export_usage( + event="export.error.classified", + type=error_type, + message=str(e), + flags=_EXPORT_FLAGS, + ) + else: + log_export_usage( + event="export.error.unclassified", + type=error_type, + message=str(e), + flags=_EXPORT_FLAGS, + ) + + if hasattr(e, "partial_fx_graph"): + print( + e.partial_fx_graph, + file=sys.stderr, + ) + + raise e + finally: + _EXPORT_FLAGS = None + _EXPORT_MODULE_HIERARCHY = None + + return ep + + return wrapper + + +def _process_jit_trace_inputs_for_export(example_inputs, example_kwarg_inputs): + if not isinstance(example_inputs, (tuple, list, dict)): + example_inputs = (example_inputs,) + + elif isinstance(example_inputs, list): + example_inputs = tuple(example_inputs) + + elif ( + isinstance(example_inputs, (torch.Tensor, dict)) + and example_kwarg_inputs is None + ): + example_inputs = (example_inputs,) + + if example_kwarg_inputs is None: + example_kwarg_inputs = {} + return example_inputs, example_kwarg_inputs + + +def _get_original_state_dict(mod: torch.nn.Module) -> dict[str, Any]: + # Explicitly not calling mode.state_dict() as we do not want the module state for serialization + # but the running module state so we can 
always match by id() the entries here with the graph inputs + named_parameters = dict(mod.named_parameters(remove_duplicate=False)) + named_buffers = dict(mod.named_buffers(remove_duplicate=False)) + original_state_dict = named_parameters | named_buffers + + non_persistent_buffers = _get_non_persistent_buffers(mod) + for k in non_persistent_buffers: + original_state_dict.pop(k, None) + + return original_state_dict + + +def _process_export_inputs( + mod: torch.nn.Module, + args: tuple[object, ...], + kwargs: dict[str, object] | None, + dynamic_shapes: _DynamicShapesSpec + | torch.export.AdditionalInputs + | torch.export.ShapesCollection + | None, +) -> tuple[ + tuple[object, ...], + dict[str, object], + TreeSpec, + _DynamicShapesSpec | None, + Callable[[ExportedProgram], None], +]: + """ + Process and validate export inputs for the torch.export API. + + This function validates the input arguments, normalizes kwargs, computes input tree specs, + and handles special dynamic shapes cases like AdditionalInputs and ShapesCollection. + + Args: + mod: The PyTorch module to be exported. + args: Tuple of example positional inputs for the module. + kwargs: Optional dictionary of example keyword inputs. + dynamic_shapes: Optional specification for dynamic shapes. Can be: + - dict mapping argument names to dynamic shape specifications + - tuple/list specifying dynamic shapes for each input in order + - torch.export.AdditionalInputs object with verification callback + - torch.export.ShapesCollection object + + Returns: + A tuple containing: + - args: Validated tuple of positional inputs + - kwargs: Normalized dictionary of keyword inputs (empty dict if None was passed) + - original_in_spec: TreeSpec representing the flattened input structure + - dynamic_shapes: Processed dynamic shapes specification + - verify_additional_inputs: Callback function for additional input verification + + Raises: + UserError: If args is not a tuple. 
+ """ + if not isinstance(args, tuple): + raise UserError( + UserErrorType.INVALID_INPUT, + f"Expecting `args` to be a tuple of example positional inputs, got {type(args)}", + ) + kwargs = kwargs if kwargs is not None else {} + if pytree.is_namedtuple_instance(args): + args = tuple(args) + + _, original_in_spec = pytree.tree_flatten((args, kwargs)) + + verify_additional_inputs: Callable[[ExportedProgram], None] + out_dynamic_shapes: _DynamicShapesSpec | None + if isinstance(dynamic_shapes, torch.export.AdditionalInputs): + verify_additional_inputs = dynamic_shapes.verify # type: ignore[assignment] + out_dynamic_shapes = dynamic_shapes.dynamic_shapes(mod, args, kwargs) # type: ignore[assignment] + else: + verify_additional_inputs = lambda ep: None # noqa: E731 + if isinstance(dynamic_shapes, torch.export.ShapesCollection): + out_dynamic_shapes = dynamic_shapes.dynamic_shapes(mod, args, kwargs) # type: ignore[assignment] + else: + out_dynamic_shapes = dynamic_shapes + + return args, kwargs, original_in_spec, out_dynamic_shapes, verify_additional_inputs + + +def _get_module_call_graph( + export_artifact: ExportArtifact, + preserve_module_call_signature: tuple[str, ...], + strict_mode_export: bool, + forward_arg_names: list[str] | None = None, +) -> tuple[torch.fx.GraphModule, list[ModuleCallEntry]]: + """ + In-place modify the graph module in export_artifact, remove _export_tracepoint nodes and + return module_call_graph. + """ + gm: torch.fx.GraphModule = export_artifact.aten.gm + export_graph_signature: ExportGraphSignature = export_artifact.aten.sig + module_call_specs: dict[str, dict[str, TreeSpec]] = ( + export_artifact.module_call_specs + ) + in_spec: TreeSpec = export_artifact.in_spec + out_spec: TreeSpec = export_artifact.out_spec + + # Make module signatures. 
+ module_call_signatures: dict[str, ModuleCallSignature] = {} + for fqn, specs in module_call_specs.items(): + mod_fqn = _strip_root(fqn) if not strict_mode_export else fqn + module_call_signatures[mod_fqn] = ModuleCallSignature( + inputs=[], + outputs=[], + in_spec=specs["in_spec"], + out_spec=specs["out_spec"], + forward_arg_names=None, # we only propagate forward_arg_names for the top level module + ) + + if len(preserve_module_call_signature) > 0: + if not strict_mode_export: + _rewrite_tracepoint_node(gm) + res = CollectTracepointsPass(module_call_signatures, export_graph_signature)(gm) + assert res is not None + gm = res.graph_module + + assert _EXPORT_MODULE_HIERARCHY is not None + module_call_graph = _make_module_call_graph( + in_spec, + out_spec, + module_call_signatures, + forward_arg_names, + ) + return gm, module_call_graph + + +def _get_range_constraints( + mod: torch.nn.Module, + export_artifact: ExportArtifact, + args, + kwargs, + dynamic_shapes, +): + gm: torch.fx.GraphModule = export_artifact.aten.gm + export_graph_signature: ExportGraphSignature = export_artifact.aten.sig + fake_mode: FakeTensorMode = export_artifact.fake_mode + num_lifted = next( + ( + i + for i, s in enumerate(export_graph_signature.input_specs) + if s.kind == InputKind.USER_INPUT + ), + len(export_graph_signature.input_specs), + ) + combined_args = _combine_args(mod, args, kwargs) + + # This is because we trace based on the kwargs passed in from user + # not based on the signature. I feel it would be better to just enforce + # one ordering at the start of tracing to avoid confusions, but that is + # bigger refactor, so do this to unblock for now. 
+ combined_args_traced_order = {} + for arg in combined_args: + if arg not in kwargs: + combined_args_traced_order[arg] = combined_args[arg] + + for key in kwargs: + combined_args_traced_order[key] = kwargs[key] + + combined_args = combined_args_traced_order + + range_constraints = make_constraints( + fake_mode, + gm, + combined_args, + dynamic_shapes, + num_lifted, + ) + return range_constraints + + +def _get_inline_constraints(fake_mode: FakeTensorMode): + assert fake_mode.shape_env is not None + return { + k: v + for k, v in fake_mode.shape_env.var_to_range.items() + if free_unbacked_symbols(k) + } + + +@contextmanager +def patch_forward(obj: torch.nn.Module, new_method): + """Helper method to make it easier to cleanly torch.export() a method on a + module that is not `forward`. + """ + # Save the original method + original_method = obj.forward + + # Patch the method + obj.forward = new_method.__get__(obj, obj.__class__) + + try: + yield + finally: + # Restore the original method + obj.forward = original_method + + +@contextmanager +def _temp_disable_texpr_fuser(): + original_state = torch._C._jit_texpr_fuser_enabled() + torch._C._jit_set_texpr_fuser_enabled(False) + try: + yield + finally: + torch._C._jit_set_texpr_fuser_enabled(original_state) + + +def _strict_export( + mod: torch.nn.Module, + args: tuple[Any, ...], + kwargs: dict[str, Any], + dynamic_shapes: dict[str, Any] | tuple[Any] | list[Any] | None, + preserve_module_call_signature: tuple[str, ...], + orig_in_spec: TreeSpec, + prefer_deferred_runtime_asserts_over_guards: bool, + _to_aten_func: Callable, +) -> ExportArtifact: + """ + _to_aten_func can either be `_export_to_aten_ir_make_fx` or `_export_to_aten_ir` + """ + + gm_torch_level = _export_to_torch_ir( + mod, + args, + kwargs, + dynamic_shapes, + preserve_module_call_signature=preserve_module_call_signature, + restore_fqn=False, # don't need to restore because we will do it later + 
prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards, + _log_export_usage=False, + ) + + # We detect the fake_mode by looking at gm_torch_level's placeholders, this is the fake_mode created in dynamo. + ( + fake_args, + fake_kwargs, + dynamo_fake_mode, + ) = _extract_fake_inputs(gm_torch_level, args, kwargs) + + fake_params_buffers = _fakify_params_buffers(dynamo_fake_mode, gm_torch_level) + + # First, we want to pass through the graph to try populating + # val field for getattr if there is anything missing. + # This can happen when quantization adds extra params and forgets + # to update "val" + for node in gm_torch_level.graph.nodes: + if node.op == "get_attr" and "val" not in node.meta: + attr = getattr(gm_torch_level, node.target) + # Checks if it is not a HigherOrderOp branch or a module + if not isinstance(attr, torch.nn.Module): + assert dynamo_fake_mode is not None, ( + "Cannot find dynamo_fake_mode. This could be due to the exported graph module have no placeholders." + ) + node.meta["val"] = dynamo_fake_mode.from_tensor( + attr, static_shapes=True + ) + + # Fix the graph output signature to be tuple if scalar + + # gm_torch_level.graph._codegen is made a _PyTreeCodeGen in rewrite_signature in eval_frame.py + assert isinstance(gm_torch_level.graph._codegen, torch.fx.graph._PyTreeCodeGen) + + # Calling gm_torch_level._out_spec is not safe because gm_torch_level might be + # a _LazyGraphModule, which does not populate _out_spec when calling recompile(). + # TODO: Fix recompile() in _LazyGraphModule. T207713214 + out_spec = orig_out_spec = gm_torch_level.graph._codegen.pytree_info.out_spec + + # Used to get rid of lint type error. + assert out_spec is not None + assert orig_out_spec is not None + + # aot_export expect the return type to always be a tuple. 
+ if out_spec.type not in (list, tuple): + out_spec = pytree.treespec_tuple([out_spec]) + + orig_arg_names = gm_torch_level.graph._codegen.pytree_info.orig_args # type: ignore[attr-defined] + + gm_torch_level.graph._codegen.pytree_info = _PyTreeInfo( + orig_arg_names, + gm_torch_level._in_spec, + out_spec, + ) + gm_torch_level.recompile() + + _normalize_nn_module_stack(gm_torch_level, type(mod)) + + params_buffers_to_node_meta = _collect_param_buffer_metadata(gm_torch_level) + + # When aot_export lifts the params, we lose metadata (e.g. source_fn_stack, stack_trace) + # from the param nodes as they are treated as fresh inputs + # Therefore, we manually extract them before calling into aot_export + # params_buffers_to_node_meta = _collect_param_buffer_metadata(gm_torch_level) + + constant_attrs = _gather_constant_attrs(mod) + param_buffer_table: dict[str, str] = _get_param_buffer_mapping(mod, gm_torch_level) + + # Dynamo does not track which buffers were registered as non-persistent. This info + # is available in the original module, so we transfer it to the traced module. Also, + # since we didn't restore original param/buffer names yet, we must use traced names. + non_persistent_buffers = _get_non_persistent_buffers(mod) + reverse_name_lookup = {orig: traced for traced, orig in param_buffer_table.items()} + gm_torch_level._non_persistent_buffers_set = { + reverse_name_lookup[name] + for name in non_persistent_buffers + if name in reverse_name_lookup + } + + tx = TracingContext(dynamo_fake_mode) + with ( + dynamo_fake_mode, + tracing(tx), + mock.patch.object(dynamo_fake_mode, "allow_non_fake_inputs", True), + ): + aten_export_artifact = _to_aten_func( + gm_torch_level, + # NOTE: graph module expects only positional args + _convert_to_positional_args(orig_arg_names, fake_args, fake_kwargs), + {}, + fake_params_buffers, + constant_attrs, + ) + + # Decompose for readability. 
+ gm = aten_export_artifact.gm + export_graph_signature = aten_export_artifact.sig + constants = aten_export_artifact.constants + + _populate_param_buffer_metadata_to_new_gm( + params_buffers_to_node_meta, gm, export_graph_signature + ) + + # Do some cleanups on the graph module to restore the state dict to the + # expected form. Each of these steps should probably get fixed upstream. + # 1. Remove tensor constants that were added as buffers. + _rewrite_dynamo_tensor_constants( + orig_mod_buffers=set(mod.buffers()), + traced_mod_buffers=dict(gm_torch_level.named_buffers()), + graph_signature=export_graph_signature, + constants=constants, + ) + # 2. Restore FQN of param/buffers + _replace_param_buffer_names(param_buffer_table, export_graph_signature) + + # 3. Move non-persistent buffers to tensor constants + _move_non_persistent_buffers_to_tensor_constants( + mod, export_graph_signature, constants + ) + + # 4. Rewrite constants to have the same FQN as the original module. + _remap_constants(constant_attrs, export_graph_signature, constants) + + # 5. 
Rename constants nodes in graph module from buffers to constants + _rename_constants_nodes(gm, export_graph_signature) + + return ExportArtifact( + aten=aten_export_artifact, + in_spec=orig_in_spec, + out_spec=orig_out_spec, + fake_mode=dynamo_fake_mode, + module_call_specs=gm_torch_level.meta["module_call_specs"], + ) + + +def _export_to_aten_ir_make_fx( + mod: torch.nn.Module, + fake_args, + fake_kwargs, + fake_params_buffers, + constant_attrs: ConstantAttrMap, + produce_guards_callback=None, + transform=lambda x: x, +) -> ATenExportArtifact: + def _make_fx_helper(stack, mod, args, kwargs, **flags): + kwargs = kwargs or {} + + named_parameters = dict(mod.named_parameters(remove_duplicate=False)) + named_buffers = dict(mod.named_buffers(remove_duplicate=False)) + + params_and_buffers = {**named_parameters, **named_buffers} + params_and_buffers_flat, params_spec = pytree.tree_flatten(params_and_buffers) + params_and_buffers_flat = tuple(params_and_buffers_flat) + + param_len = len(named_parameters) + buffer_len = len(named_buffers) + params_len = len(params_and_buffers) + + functional_call = create_functional_call( + mod, params_spec, params_len, store_orig_mod=True + ) + + params_buffers_args: list[Any] = [] + params_buffers_args.extend(params_and_buffers_flat) + params_buffers_args.extend(args) + + flat_fn, out_spec = create_tree_flattened_fn( + functional_call, params_buffers_args, kwargs + ) + flat_args, in_spec = pytree.tree_flatten((params_buffers_args, kwargs)) + + @functools.wraps(flat_fn) + def wrapped_fn(*args): + return tuple(flat_fn(*args)) + + with enable_python_dispatcher(): + ctx = nullcontext() + non_strict_root = getattr(mod, "_export_root", None) + if non_strict_root is not None: + ctx = _detect_attribute_assignment(non_strict_root) # type: ignore[assignment] + + # For any buffer that is assigned, we want to associate it to the final proxy node + # that it is assigned to. This node can then be copied into the buffer. 
+ assigned_buffers: dict[str, str] = {} + hook = register_buffer_assignment_hook( + non_strict_root, assigned_buffers + ) + + def custom_getattribute(self, attr, *, original_getattr, attrs_to_proxy): + """ + The idea here is that we override subclass getattr methods to proxy + inner tensors and metadata. Because of infinite loop shenanigans, we have + to manually construct the getattr proxy nodes without relying on torch function + system. + """ + out = original_getattr(self, attr) + if attr in attrs_to_proxy: + if torch._C._is_torch_function_mode_enabled(): + if isinstance(out, torch.Tensor): + # When we get here there is no guarantee that we will hit the + # PreDispatchTorchFunctionMode, so we manually peak into the torch + # function mode list and tweak the PreDispatchTorchFunctionMode. + # This has side effect of proxying stuff like + # proxy.node.meta["val"] = extract_val(val) because at that time, torch function + # mode is still active. It seems bad to turn it off inside proxy_tensor.py, so + # I guess we will just rely on DCE for now to remove extra stuff like detach + torch_function_mode_stack = ( + torch.overrides._get_current_function_mode_stack() + ) + for mode in torch_function_mode_stack: + if isinstance(mode, PreDispatchTorchFunctionMode): + tracer = mode.tracer + proxy = get_proxy_slot(self, tracer).proxy + inner_proxy = tracer.create_proxy( + "call_function", + torch.ops.export.access_subclass_inner_tensor.default, + (proxy, attr), + {}, + ) + track_tensor_tree( + out, inner_proxy, constant=None, tracer=tracer + ) + return out + + @contextmanager + def override_getattribute_for_subclasses(args): + """ + Context manager that temporarily monkey patches + tensor.__getattribute__ so that we can intercept it at + torch_function layer. + """ + + # Dictionary that tracks subclass type to original getattr function + # and the attributes we can proxy. 
+ tensor_type_to_old_getattribute: dict[ + type[torch.Tensor], tuple[Callable, set[str]] + ] = {} + for arg in args: + subclass_types_to_instances: dict[ + type[torch.Tensor], list[type[torch.Tensor]] + ] = get_subclass_typing_container(arg) + for subclass_type in subclass_types_to_instances: + if subclass_type not in tensor_type_to_old_getattribute: + assert len(subclass_types_to_instances[subclass_type]) > 0 + instance = subclass_types_to_instances[subclass_type][0] + # Query subclass specific attrs + attrs_to_proxy = set(dir(instance)) - set(dir(torch.Tensor)) + tensor_type_to_old_getattribute[subclass_type] = ( + subclass_type.__getattribute__, # type: ignore[attr-defined] + attrs_to_proxy, + ) + + try: + for k, ( + old_getattr, + attrs_to_proxy, + ) in tensor_type_to_old_getattribute.items(): + custom = functools.partialmethod( + custom_getattribute, + original_getattr=old_getattr, + attrs_to_proxy=attrs_to_proxy, + ) + k.__getattribute__ = custom # type: ignore[assignment, attr-defined] + yield + finally: + for k, (old_getattr, _) in tensor_type_to_old_getattribute.items(): + k.__getattribute__ = old_getattr # type: ignore[method-assign, attr-defined] + + @contextmanager + def _maybe_restore_grad_state(): + """ + When pre-dispatch export accidentally change grad state, we restore it back. + This can happen when we are calling torch._C._set_grad_enabled directly in the + forward. 
+ """ + old_state = torch.is_grad_enabled() + try: + yield + finally: + torch._C._set_grad_enabled(old_state) + + with ( + ctx, + override_getattribute_for_subclasses(flat_args), + _maybe_restore_grad_state(), + ): + gm = make_fx( + wrapped_fn, + record_module_stack=True, + pre_dispatch=True, + )(*flat_args) + + if non_strict_root is not None: + input_names = _graph_input_names(gm) + buffer_input_names = { + name: input_names[param_len + i] + for i, (name, buf) in enumerate(non_strict_root._buffers.items()) + if buf is not None + } + output_node = list(gm.graph.nodes)[-1] + # We copy nodes corresponding to buffer assignments to buffers in the graph. + for buf, name in assigned_buffers.items(): # type: ignore[possibly-undefined] + buf_node = _find_node(gm, buffer_input_names[buf]) + name_node = _find_node(gm, name) + with gm.graph.inserting_before(output_node): + new_node = gm.graph.create_node( + "call_function", + torch.ops.aten.copy_.default, + args=(buf_node, name_node), + ) + new_node.meta = name_node.meta + + hook.remove() # type: ignore[possibly-undefined] + + def _is_impure(node): + if node.op == "call_function" and node.target in ( + # In export, we ignore any op that is related to + # eager mode profiling call. The expectation is + # that either runtimes provide their own profiling + # OR user wrap the compiled region on a profiling in + # later stage. + torch.ops.profiler._record_function_enter.default, + torch.ops.profiler._record_function_enter_new.default, + torch.ops.profiler._record_function_exit._RecordFunction, + # In theory, we could fix this dead detach and getattr nodes + # from subclass tensors if we carefully rewrite track_tensor_tree + # in a way that it doesn't do any tensor methods. + torch.ops.aten.detach.default, + torch.ops.export.access_subclass_inner_tensor.default, + ): + return False + return True + + gm.graph.eliminate_dead_code(_is_impure) + + # create graph signature + assert out_spec.spec is not None, "out_spec.spec is None!" 
+ input_names = _graph_input_names(gm) + output_names = _graph_output_names(gm) + sig = GraphSignature( + parameters=list(named_parameters), + buffers=list(named_buffers), + user_inputs=input_names[params_len:], + user_outputs=output_names, + inputs_to_parameters=dict(zip(input_names[0:param_len], named_parameters)), + inputs_to_buffers=dict( + zip(input_names[param_len : param_len + buffer_len], named_buffers) + ), + buffers_to_mutate={}, + parameters_to_mutate={}, + user_inputs_to_mutate={}, + in_spec=in_spec, + out_spec=out_spec.spec, + backward_signature=None, + input_tokens=[], + output_tokens=[], + ) + return gm, sig + + # This _reparameterize_module makes sure inputs and module.params/buffers have the same fake_mode, + # otherwise aot_export_module will error out because it sees a mix of fake_modes. + # And we want aot_export_module to use the fake_tensor mode in dynamo to keep the pipeline easy to reason about. + with ExitStack() as stack: + stack.enter_context( + torch.nn.utils.stateless._reparametrize_module( + mod, + fake_params_buffers, + tie_weights=True, + strict=True, + stack_weights=True, + ) + ) + stack.enter_context(_ignore_backend_decomps()) + stack.enter_context(_compiling_state_context()) + gm, graph_signature = transform(_make_fx_helper)( + stack, + mod, + fake_args, + trace_joint=False, + kwargs=fake_kwargs, + ) + + # [NOTE] In training IR, we don't run + # any DCE as a result we preserve constant + # nodes in the graph. make_fx invariant is that + # they don't guarantee every node gets a meta['val'] + # field. Since the actual value is already hardcoded in + # graph, the node.meta here actually doesn't matter. But + # we do this to make spec verifier happy. 
+ for node in gm.graph.nodes: + if ( + node.op == "call_function" + and len(node.users) == 0 + and "val" not in node.meta + ): + node.meta["val"] = None + + if isinstance(mod, torch.fx.GraphModule) and hasattr(mod, "meta"): + gm.meta.update(mod.meta) + + # See comment in _export_to_aten_ir() + if produce_guards_callback: + try: + produce_guards_callback(gm) + except (ConstraintViolationError, ValueRangeError) as e: + raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e)) # noqa: B904 + + return _produce_aten_artifact( + gm=gm, + mod=mod, + constant_attrs=constant_attrs, + graph_signature=graph_signature, + pre_dispatch=True, + fake_args=fake_args, + fake_kwargs=fake_kwargs, + fake_params_buffers=fake_params_buffers, + ) + + +def set_missing_meta_vals(gm, flat_args, num_params_buffers): + # Sets missing metadata to address two problems: + # 1. aot_export adds symint metadata for placeholders with int values; since + # these become specialized, we replace such metadata with the original values. + # 2. any tensor attributes that are not params / buffers, i.e., are constants + # need to have their metadata set before lifting them because it is needed + # for computing the exported program's signature. 
+ index = 0 + for node in gm.graph.nodes: + if node.op == "placeholder": + if index >= num_params_buffers: + user_arg = flat_args[index - num_params_buffers] + if not isinstance(user_arg, torch.Tensor): + node.meta["val"] = user_arg + index += 1 + + +def _find_node(gm: torch.fx.GraphModule, name: str) -> torch.fx.Node: + return next(iter(node for node in gm.graph.nodes if node.name == name)) + + +def _non_strict_export( + mod: torch.nn.Module, + args: tuple[Any, ...], + kwargs: dict[str, Any], + dynamic_shapes: dict[str, Any] | tuple[Any] | list[Any] | None, + preserve_module_call_signature: tuple[str, ...], + orig_in_spec: TreeSpec, + prefer_deferred_runtime_asserts_over_guards: bool, + _to_aten_func: Callable, +) -> ExportArtifact: + """ + _to_aten_func can either be `_export_to_aten_ir_make_fx` or `_export_to_aten_ir` + """ + + out_spec: TreeSpec | None = None + in_spec: TreeSpec | None = None + + module_call_specs: dict[str, dict[str, pytree.TreeSpec]] = {} + + def _tuplify_outputs(aot_export): + def _aot_export_non_strict(stack, mod, args, *, kwargs=None, **flags): + kwargs = kwargs or {} + + class Wrapper(torch.nn.Module): + def __init__(self, mod): + super().__init__() + self._export_root = mod + + def forward(self, *args, **kwargs): + nonlocal out_spec + nonlocal in_spec + mod = self._export_root + _, in_spec = pytree.tree_flatten((args, kwargs)) + if isinstance(mod, torch.fx.GraphModule): + # NOTE: We're going to run this graph module with an fx interpreter, + # which will not run any forward hooks. Thus, ideally, we should run + # all forward hooks here. But the general logic for running them is + # complicated (see nn/module.py), and probably not worth duplicating. + # Instead we only look for, and run, an export-specific forward hook. 
+ if ( + _check_input_constraints_pre_hook + in mod._forward_pre_hooks.values() + ): + _check_input_constraints_pre_hook(mod, args, kwargs) + with torch.fx.traceback.preserve_node_meta(): + args = (*args, *kwargs.values()) + tree_out = torch.fx.Interpreter(mod).run(*args) + else: + tree_out = mod(*args, **kwargs) + flat_outs, out_spec = pytree.tree_flatten(tree_out) + return tuple(flat_outs) + + wrapped_mod = Wrapper(mod) + # Patch export_root to the signatures so that wrapper module correctly populates the + # in/out spec + new_preserved_call_signatures = [ + "_export_root." + i for i in preserve_module_call_signature + ] + ctx = nullcontext() + if not isinstance(mod, torch.fx.GraphModule): + ctx = _wrap_submodules( # type: ignore[assignment] + wrapped_mod, new_preserved_call_signatures, module_call_specs + ) + with ctx: + gm, sig = aot_export(stack, wrapped_mod, args, kwargs=kwargs, **flags) + log.debug("Exported program from AOTAutograd:\n%s", gm) + + sig.parameters = pytree.tree_map(_strip_root, sig.parameters) + sig.buffers = pytree.tree_map(_strip_root, sig.buffers) + sig.inputs_to_buffers = pytree.tree_map(_strip_root, sig.inputs_to_buffers) + sig.inputs_to_parameters = pytree.tree_map( + _strip_root, sig.inputs_to_parameters + ) + sig.buffers_to_mutate = pytree.tree_map(_strip_root, sig.buffers_to_mutate) + sig.parameters_to_mutate = pytree.tree_map( + _strip_root, sig.parameters_to_mutate + ) + + for node in gm.graph.nodes: + if "nn_module_stack" in node.meta: + nn_module_stack = node.meta["nn_module_stack"] + node.meta["nn_module_stack"] = { + _fixup_key(key): val + for key, val in pytree.tree_map( + _strip_root, nn_module_stack + ).items() + } + + return gm, sig + + return _aot_export_non_strict + + ( + fake_mode, + fake_args, + fake_kwargs, + equalities_inputs, + original_signature, + dynamic_shapes, + ) = make_fake_inputs( + mod, + args, + kwargs, + dynamic_shapes, + 
prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards, # for shape env initialization + ) + + fake_params_buffers = _fakify_params_buffers(fake_mode, mod) + + def _produce_guards_callback(gm): + return produce_guards_and_solve_constraints( + fake_mode=fake_mode, + gm=gm, + dynamic_shapes=dynamic_shapes, + equalities_inputs=equalities_inputs, + original_signature=original_signature, + ) + + tx = TracingContext(fake_mode) + + # We also need to attach dynamo configs as these will be used in HOOs that + # use torch.compile, like cond + dynamo_config = dataclasses.asdict(DEFAULT_EXPORT_DYNAMO_CONFIG) + dynamo_config["do_not_emit_runtime_asserts"] = ( + False # We want to emit runtime asserts + ) + + with ( + fake_mode, + _NonStrictTorchFunctionHandler(), + tracing(tx), + torch._dynamo.config.patch(dynamo_config), + ): + with ( + _fakify_script_objects(mod, fake_args, fake_kwargs, fake_mode) as ( + patched_mod, + new_fake_args, + new_fake_kwargs, + new_fake_constant_attrs, + map_fake_to_real, + ), + _fakify_module_inputs(fake_args, fake_kwargs, fake_mode), + _override_builtin_ops(), + ): + # _to_aten_func is _export_to_aten_ir when using the default non-strict export + # We need to pass positional args correctly + aten_export_artifact = _to_aten_func( + patched_mod, + new_fake_args, + new_fake_kwargs, + fake_params_buffers, + new_fake_constant_attrs, + produce_guards_callback=_produce_guards_callback, + transform=_tuplify_outputs, + ) + # aten_export_artifact.constants contains only fake script objects, we need to map them back + aten_export_artifact.constants = { + fqn: map_fake_to_real[obj] if isinstance(obj, FakeScriptObject) else obj + for fqn, obj in aten_export_artifact.constants.items() + } + + _move_non_persistent_buffers_to_tensor_constants( + mod, aten_export_artifact.sig, aten_export_artifact.constants + ) + + assert out_spec is not None + assert in_spec is not None + + return ExportArtifact( + aten=aten_export_artifact, + 
in_spec=in_spec, + out_spec=out_spec, + fake_mode=fake_mode, + module_call_specs=module_call_specs, + ) + + +@_log_export_wrapper +@_disable_prexisiting_fake_mode +def _export_for_training( + mod: torch.nn.Module, + args: tuple[Any, ...], + kwargs: dict[str, Any] | None = None, + dynamic_shapes: dict[str, Any] | tuple[Any] | list[Any] | None = None, + *, + strict: bool = True, + preserve_module_call_signature: tuple[str, ...] = (), + prefer_deferred_runtime_asserts_over_guards: bool = False, +) -> ExportedProgram: + global _EXPORT_MODULE_HIERARCHY + _EXPORT_MODULE_HIERARCHY = _get_module_hierarchy(mod) + + ( + args, + kwargs, + orig_in_spec, + dynamic_shapes, + verify_additional_inputs, + ) = _process_export_inputs(mod, args, kwargs, dynamic_shapes) + + original_state_dict = _get_original_state_dict(mod) + + has_ambient_mode = False + if not strict: + flat_args, _ = pytree.tree_flatten((args, kwargs)) + has_ambient_mode = torch._guards.detect_fake_mode(flat_args) is not None + + # Call the appropriate export function based on the strictness of tracing. + export_func = _strict_export if strict else _non_strict_export + + if not strict and torch._export.config.detect_non_strict_fake_tensor_leaks: + from torch._subclasses.fake_tensor import fake_tensor_tls + + fake_tensor_tls.non_strict_export_fake_tensor_tracker.clear() + + export_artifact = export_func( + mod=mod, + args=args, + kwargs=kwargs, + dynamic_shapes=dynamic_shapes, + preserve_module_call_signature=preserve_module_call_signature, + orig_in_spec=orig_in_spec, + prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards, + _to_aten_func=_export_to_aten_ir_make_fx, + ) + + # If we are tracing with fake inputs, it is expected to + # see fake tensor constants. 
+ if not strict and not has_ambient_mode: + for const, val in export_artifact.aten.constants.items(): + if isinstance( + val, torch._subclasses.fake_tensor.FakeTensor + ) and _is_bogus_const_name(const): + error_msg = ( + f"We found a fake tensor in the exported program constant's list. " + f"This typically means our tracing system encountered an op that " + f"we can't trace through. For the potential source, you can refer to " + f"following model attribute: {const}. " + f"Please file an issue on github. " + ) + if torch._export.config.error_on_lifted_constant_tensors: + raise RuntimeError(error_msg) + else: + warnings.warn(error_msg, stacklevel=2) + + export_graph_signature = export_artifact.aten.sig + + forward_arg_names = _get_forward_arg_names(mod, args, kwargs) + inline_constraints = _get_inline_constraints(export_artifact.fake_mode) + # The unbacked symint symbols are updated in aot_export + # so we serialize them here instead of inside dynamo. + # Note: _get_range_constraints depends on "inline_constraints" to be set. 
+ export_artifact.aten.gm.meta["inline_constraints"] = inline_constraints + range_constraints = _get_range_constraints( + mod, + export_artifact, + args, + kwargs, + dynamic_shapes, + ) + # The returned the gm is in-place modified + gm, module_call_graph = _get_module_call_graph( + export_artifact, + preserve_module_call_signature, + strict, + forward_arg_names, + ) + + _verify_nn_module_stack(gm) + _verify_stack_trace(gm) + _verify_placeholder_names(gm, export_graph_signature) + + _update_gm_meta_if_possible(gm, mod) + + from torch._export.verifier import TrainingIRVerifier + + exported_program = ExportedProgram( + root=gm, + graph=gm.graph, + graph_signature=export_graph_signature, + state_dict=original_state_dict, + range_constraints=range_constraints, + module_call_graph=module_call_graph, + example_inputs=(args, kwargs), + constants=export_artifact.aten.constants, + verifiers=[TrainingIRVerifier], + ) + + verify_additional_inputs(exported_program) + + if not strict and torch._export.config.detect_non_strict_fake_tensor_leaks: + # See NOTE [export non-strict fake tensor leak detection] + from torch._subclasses.fake_tensor import fake_tensor_tls + from torch.fx.experimental.proxy_tensor import ( + _FAKE_TENSOR_ID_TO_PROXY_MAP_FOR_EXPORT, + ) + + active_fakes = fake_tensor_tls.non_strict_export_fake_tensor_tracker + legit_leak: weakref.WeakSet = find_legit_leaks_from_referrers(active_fakes) + leak_sources: list[str] = [] + if len(legit_leak) > 0: + for fake_val in legit_leak: + if id(fake_val) in _FAKE_TENSOR_ID_TO_PROXY_MAP_FOR_EXPORT: + stack_trace = _FAKE_TENSOR_ID_TO_PROXY_MAP_FOR_EXPORT[ + id(fake_val) + ].meta.get("stack_trace", "") + + # Get shape and dtype info + shape_info = f"shape={fake_val.shape}, dtype={fake_val.dtype}" + leak_info = f"FakeTensor({shape_info}): {stack_trace}" + leak_sources.append(leak_info) + + # Format the warning message more nicely + leak_details = "\n ".join(leak_sources) + warnings.warn( + f"Detected {len(legit_leak)} fake 
tensors that are still alive after export.\n" + f"This is likely result of torch.export.export not being able to track side effects " + f"that is happening outside of model scope.\n\n" + f"Leaked tensors:\n {leak_details}\n\n" + f"Alternatively, please file a bug report to PyTorch team for further debugging help.", + stacklevel=2, + ) + + del legit_leak + + return exported_program + + +@_log_export_wrapper +@_disable_prexisiting_fake_mode +def _export( + mod: torch.nn.Module, + args: tuple[Any, ...], + kwargs: dict[str, Any] | None = None, + dynamic_shapes: dict[str, Any] | tuple[Any] | list[Any] | None = None, + *, + strict: bool = True, + preserve_module_call_signature: tuple[str, ...] = (), + pre_dispatch: bool = False, + prefer_deferred_runtime_asserts_over_guards: bool = False, +) -> ExportedProgram: + """ + Traces either an nn.Module's forward function or just a callable with PyTorch + operations inside and produce a ExportedProgram. + + Args: + mod: the `nn.Module` to trace. + + args: example positional inputs. + + kwargs: optional example keyword inputs. + + dynamic_shapes: + An optional argument where the type should either be: + 1) a dict from argument names of ``f`` to their dynamic shape specifications, + 2) a tuple that specifies dynamic shape specifications for each input in original order. + If you are specifying dynamism on keyword args, you will need to pass them in the order that + is defined in the original function signature. + + The dynamic shape of a tensor argument can be specified as either + (1) a dict from dynamic dimension indices to :func:`Dim` types, where it is + not required to include static dimension indices in this dict, but when they are, + they should be mapped to None; or (2) a tuple / list of :func:`Dim` types or None, + where the :func:`Dim` types correspond to dynamic dimensions, and static dimensions + are denoted by None. 
Arguments that are dicts or tuples / lists of tensors are + recursively specified by using mappings or sequences of contained specifications. + + preserve_module_call_signature: A list of submodule paths for which the original + calling conventions are preserved as metadata. + + prefer_deferred_runtime_asserts_over_guards: + With the current dynamic shapes language for dims and derived dims, we can run into constraints + that are not expressible with the language. For example, flattening a matrix and adding to a vector, + both fully dynamic (i.e. x.reshape([-1]) + y) emits a guard s0 * s1 = s2, which is not expressible. + By default, we either raise a constraint violation error or specialize to static values. + If this flag is set to True, we avoid erroring out and instead allow complex constraints to exist as runtime + assertions in the graph. The sympy interpreter (torch/utils/_sympy/interp.py) will produce the math ops + required to compute and assert the value of the guard (e.g. sym_size_int, eq, _assert_scalar). + Additionally, if TORCH_DYNAMO_DO_NOT_EMIT_RUNTIME_ASSERTS=1 is specified, we will allow complex constraints + while not emitting runtime asserts, returning a cleaner graph with lesser guarantees around dynamic shapes. + + Returns: + An ExportedProgram containing the traced module. + """ + + from torch._utils_internal import export_training_ir_rollout_check + + global _EXPORT_FLAGS, _EXPORT_MODULE_HIERARCHY + _EXPORT_MODULE_HIERARCHY = _get_module_hierarchy(mod) + + flags = set() + flags.add("strict" if strict else "non_strict") + flags.add("pre_dispatch" if pre_dispatch else "aot_dispatch") + _EXPORT_FLAGS = flags + + log_export_usage(event="export.enter", flags=_EXPORT_FLAGS) + + dtrace_structured("export", payload_fn=lambda: "start!") + + # NOTE Export training IR rollout + # Old export calls export._trace(pre_dispatch=True) + # and there are still lot of internal/OSS callsites that + # use export._trace(pre_dispatch=True) directly. 
Therefore, + # it makes more sense to do the switch here. + # export_training_ir_rollout_check returns True in OSS + # while internally it returns False UNLESS otherwise specified. + if pre_dispatch and export_training_ir_rollout_check(): + ep = _export_for_training( + mod, + args, + kwargs, + dynamic_shapes, + strict=strict, + preserve_module_call_signature=preserve_module_call_signature, + prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards, + ) + dtrace_structured("exported_program", payload_fn=lambda: str(ep)) + return ep + + ( + args, + kwargs, + original_in_spec, + dynamic_shapes, + verify_additional_inputs, + ) = _process_export_inputs(mod, args, kwargs, dynamic_shapes) + + original_state_dict = _get_original_state_dict(mod) + + # Call the appropriate export function based on the strictness of tracing. + export_func = _strict_export if strict else _non_strict_export + + export_artifact = export_func( # type: ignore[operator] + mod=mod, + args=args, + kwargs=kwargs, + dynamic_shapes=dynamic_shapes, + preserve_module_call_signature=preserve_module_call_signature, + orig_in_spec=original_in_spec, + prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards, + _to_aten_func=functools.partial( + _export_to_aten_ir, + pre_dispatch=pre_dispatch, + ), + ) + export_graph_signature: ExportGraphSignature = export_artifact.aten.sig + + forward_arg_names = _get_forward_arg_names(mod, args, kwargs) + inline_constraints = _get_inline_constraints(export_artifact.fake_mode) + # The unbacked symint symbols are updated in aot_export + # so we serialize them here instead of inside dynamo. + # Note: this step must be before _get_range_constraints. 
+ export_artifact.aten.gm.meta["inline_constraints"] = inline_constraints + range_constraints = _get_range_constraints( + mod, + export_artifact, + args, + kwargs, + dynamic_shapes, + ) + gm, module_call_graph = _get_module_call_graph( + export_artifact, + preserve_module_call_signature, + strict, + forward_arg_names, + ) + + _verify_nn_module_stack(gm) + _verify_stack_trace(gm) + _verify_placeholder_names(gm, export_graph_signature) + + # Remove Proxy because they cannot be deepcopied or pickled. + torch._export.utils.remove_proxy_from_state_dict(original_state_dict, in_place=True) + + from torch._export.verifier import Verifier + + _update_gm_meta_if_possible(gm, mod) + + exported_program = ExportedProgram( + root=gm, + graph=gm.graph, + graph_signature=export_graph_signature, + state_dict=original_state_dict, + range_constraints=range_constraints, + module_call_graph=module_call_graph, + example_inputs=(args, kwargs), + constants=export_artifact.aten.constants, + verifiers=[Verifier], + ) + + dtrace_structured("exported_program", payload_fn=lambda: str(exported_program)) + + verify_additional_inputs(exported_program) + return exported_program diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/_tree_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/_tree_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..fd3f55c30afeb2e05dd6dbaebb42b3149dd32e93 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/_tree_utils.py @@ -0,0 +1,65 @@ +from collections.abc import Callable +from typing import Any + +from torch.utils._pytree import Context, TreeSpec + + +def reorder_kwargs(user_kwargs: dict[str, Any], spec: TreeSpec) -> dict[str, Any]: + """Reorder user-provided kwargs to match the order in `spec`. `spec` is + expected to be the in_spec of an exported program, i.e. the spec that + results from flattening `(args, kwargs)`. 
+ + We need this to provide consistent input ordering, such so that users can + pass in foo(a=a, b=b) OR foo(b=b, a=a) and receive the same result. + """ + # Make sure that the spec is actually shaped like (args, kwargs) + assert spec.type is tuple + assert spec.num_children == 2 + kwargs_spec = spec.child(1) + assert kwargs_spec.type is dict + + if set(user_kwargs) != set(kwargs_spec.context): + raise ValueError( + f"Ran into a kwarg keyword mismatch: " + f"Got the following keywords {list(user_kwargs)} but expected {kwargs_spec.context}" + ) + + reordered_kwargs = {} + for kw in kwargs_spec.context: + reordered_kwargs[kw] = user_kwargs[kw] + + return reordered_kwargs + + +def is_equivalent( + spec1: TreeSpec, + spec2: TreeSpec, + equivalence_fn: Callable[[type | None, Context, type | None, Context], bool], +) -> bool: + """Customizable equivalence check for two TreeSpecs. + + Arguments: + spec1: The first TreeSpec to compare + spec2: The second TreeSpec to compare + equivalence_fn: A function to determine the equivalence of two + TreeSpecs by examining their types and contexts. It will be called like: + + equivalence_fn(spec1.type, spec1.context, spec2.type, spec2.context) + + This function will be applied recursively to all children. + + Returns: + True if the two TreeSpecs are equivalent, False otherwise. 
+ """ + if not equivalence_fn(spec1.type, spec1.context, spec2.type, spec2.context): + return False + + # Recurse on children + if spec1.num_children != spec2.num_children: + return False + + for child_spec1, child_spec2 in zip(spec1.children(), spec2.children()): + if not is_equivalent(child_spec1, child_spec2, equivalence_fn): + return False + + return True diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/_unlift.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/_unlift.py new file mode 100644 index 0000000000000000000000000000000000000000..84e4313d395b7a100784601ebce020760299f9ca --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/_unlift.py @@ -0,0 +1,878 @@ +# mypy: allow-untyped-defs +import copy +import inspect +import math +import warnings +from collections.abc import Sequence +from itertools import chain +from typing import Any + +import sympy + +import torch +import torch.utils._pytree as pytree +from torch._export.non_strict_utils import ( + _enter_enable_graph_inputs_of_type_nn_module, + _exit_enable_graph_inputs_of_type_nn_module, + _get_graph_inputs_of_type_nn_module, +) +from torch._export.passes.add_runtime_assertions_for_constraints_pass import ( + _convert_range_to_int, +) +from torch._export.utils import _check_input_constraints_for_graph +from torch.export.unflatten import _assign_attr, _AttrKind +from torch.fx.experimental.proxy_tensor import _pytree_subclasses_that_lose_info +from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo +from torch.fx.traceback import NodeSource, NodeSourceAction +from torch.utils._sympy.solve import try_solve +from torch.utils._sympy.value_ranges import ValueRanges + +from ._remove_effect_tokens_pass import _remove_effect_tokens +from ._tree_utils import reorder_kwargs +from .exported_program import ( + ExportedProgram, + ExportGraphSignature, + InputKind, + OutputKind, +) + + +def eq_spec(self: pytree.TreeSpec, 
other: pytree.TreeSpec) -> bool: + """ + Refinement of TreeSpec.__eq__ where, e.g., torch.Size(...) matches tuple(...). + See _pytree_subclasses_that_lose_info in proxy_tensor.py for more details. + """ + + def _normalize_type(t): + return str(_pytree_subclasses_that_lose_info.get(t, t)) + + def _match_normalized_structure(a, b): + if a is b: + return True + if _normalize_type(a.type) != _normalize_type(b.type): + return False + if a.type is dict and b.type is dict: + # in the case of dict, the context is list of keys and we allow the keys to be in any order + if set(a.context) != set(b.context): + return False + elif a.context != b.context: + return False + if a.num_children != b.num_children: + return False + return all( + _match_normalized_structure(a, b) + for a, b in zip(a.children(), b.children()) + ) + + return _match_normalized_structure(self, other) + + +def _check_inputs_match(args, kwargs, in_spec: pytree.TreeSpec) -> list: + reordered_kwargs = reorder_kwargs(kwargs, in_spec) + flat_args_with_path, received_spec = pytree.tree_flatten_with_path( + (args, reordered_kwargs) + ) + + if not eq_spec(received_spec, in_spec): + raise ValueError( # noqa: B904 + "Trying to flatten user inputs with exported input tree spec: \n" + f"{in_spec}\n" + "but actually got inputs with tree spec of: \n" + f"{received_spec}.\n" + "Please check that the inputs have the same number and type of " + "args and kwargs as the ones you used when tracing." + ) + + return flat_args_with_path + + +def _force_ep_signature_match(ep_guards_code: list[str], input_paths): + # TODO (tmanlaibaatar) + # This is band-aid solution to export new tracer replacing + # shape env sources to flat_args. The real fix should be replacing + # shape env sources to original user sources but this is quite + # involved because you need to carefully construct new sources using + # dynamo and replace all instances of it inside shape env. 
But it is + # lot easier to manipulate after we turn them into strings and only + # time we use these guards is during retracing or running exported program, + # so it is probably ok to have "not useful" guards on ep for now. + name_mapping = {} + for idx, path in enumerate(input_paths): + name_mapping[f"L['flat_args'][{idx}]"] = f"L{pytree.keystr(path)}" + + new_guards_code = [] + for guard in ep_guards_code: + for old_name, new_name in name_mapping.items(): + guard = guard.replace(old_name, new_name) + new_guards_code.append(guard) + + return new_guards_code + + +def _force_gm_signature_match(ep_guards_code: list[str], signature): + """ + The signature of the originally exported module may not match + the signature of the unlifted graph module extracted from the + exported program. The guards code extracted from the exported + program is based on the former, but the generated guards fn is + based on the latter; thus we need to reconcile any such diff. + """ + + import re + + # Handle case where signatures may differ in var args. + orig_arg_names = set() + for g in ep_guards_code: + # match substrings of the form L[''][] + orig_arg_names.update(re.findall(r"L\[\'([^\']+)\'\]\[([0-9]+)\]", g)) + + sig_arg_names = set() + for n in signature.parameters: + # match substrings of the form _ + sig_arg_names.update(re.findall(r"(.+)_([0-9]+)", n)) + + # replace L[''][] with L['_'] + new_guards_code = ep_guards_code + for match in orig_arg_names: + if match in sig_arg_names: + base, idx = match + new_guards_code = [ + g.replace(f"L['{base}'][{idx}]", f"L['{base}_{idx}']") + for g in new_guards_code + ] + + return new_guards_code + + +def _convert_guards_code_to_fn( + guards_code: list[str], + paths_of_placeholders: list[pytree.KeyPath], +): + """ + Generates Python code given guards code and paths of placeholders. + We assume that, based on source information, + - the tracer generates the guards code + - the input spec generates the paths of placeholders. 
+ + Example: + + Suppose we are given the guards code "L['z']['k'].size()[1] == 3" + and we are given that ['z']['k'] is the path of placeholder #2. + Then we will generate: + ``` + torch._assert( + args[2].size()[0] == 3, + "Guard failed: z['k'].size()[0] == 3", + ) + ``` + + FAQ: Why do we generate code based on (flattened) args instead of + the original (unflattened) inputs? Because this would require + inserting an additional pytree.unflatten call in our graph. + + FAQ: Why do we not emit RuntimeError on guard failure as we used to? + Because it is inconvenient :/, get used to AssertionError instead. + """ + + import ast + + from torch.fx.experimental.symbolic_shapes import SYMPY_INTERP + + actual_guards_code = [] + shadow_guards_code = [] + for c in guards_code: + a, s = c, c + for idx, path in enumerate(paths_of_placeholders): + # e.g., replace L['z']['k'] with args[2] for Python code (actual) + a = a.replace("L" + pytree.keystr(path), f"args[{idx}]") + # e.g., replace L['z']['k'] with z['k'] for error message (shadow) + s = s.replace( + "L" + pytree.keystr(path), + path[0].key + pytree.keystr(path[1:]), # type: ignore[attr-defined] + ) + actual_guards_code.append(a) + shadow_guards_code.append(s.replace("\n", "")) + + # generate function code as str + code_str = "\ndef _(*args):\n" + for actual, shadow in zip(actual_guards_code, shadow_guards_code): + # printing guards code may potentially introduce redundant parens; + # we can normalize them out for readability by parsing/unparsing + # NOTE: this is not necessary for correctness, just deemed desirable + _shadow = ast.unparse(ast.parse(shadow, mode="eval")) + # actual code and shadow error message + code_str += f' torch._assert({actual}, "Guard failed: {_shadow}")\n' + code_str += " return\n" + + # populate namespace with sympy globals, materialize function (named `_`) + namespace = {**SYMPY_INTERP} + exec(code_str, namespace) + + # create and return a module whose forward is the materialized function + # 
NOTE: we want Dynamo to trace through this module, to repopulate guards: + # otherwise we would lose them when retracing + # NOTE: calling this module will be a side effect (no users): so it must + # be marked impure to avoid being not cleaned up by DCE + guards_fn = GuardsFn() + guards_fn.forward = torch._dynamo.dont_skip_tracing(namespace["_"]) # type: ignore[call-overload, method-assign] + guards_fn._is_impure = True # type: ignore[assignment] + return guards_fn + + +@torch._dynamo.disable +def _check_input_constraints_for_module(self, args, kwargs): + flat_args_with_path = _check_inputs_match(args, kwargs, self._in_spec) + _check_input_constraints_for_graph( + self.graph.find_nodes(op="placeholder"), + flat_args_with_path, + self.range_constraints, + ) + + +def _check_input_constraints_pre_hook(self, args, kwargs): + # preserve current behavior for clients that do not want any validation + if not self.validate_inputs: + return + + # when a guards function exists, assume that the graph does calls it! + # so we do not need to check input constraints...but we still want + # to check inputs match, otherwise we'd get obscure pytree errors + if hasattr(self, "_guards_fn"): + _check_inputs_match(args, kwargs, self._in_spec) + return + + # NOTE: for some reason, Dynamo is tracing into this, we should see why and + # put compile at the right place. Until then, we can skip the input + # constraint checks. 
+ if not torch.compiler.is_dynamo_compiling(): + _check_input_constraints_for_module(self, args, kwargs) + + +def _unlift_inputs_as_getattr( + gm: torch.fx.GraphModule, + lifted_inputs: Sequence[str | None], +) -> tuple[dict[str, torch.fx.Node], dict[str, torch.fx.Node]]: + """ + Unlift inputs referring to params/buffers/constants as getattr nodes in the + graph + """ + unlifted_name_to_node = {} + input_name_to_node = {} + + placeholder_nodes = [node for node in gm.graph.nodes if node.op == "placeholder"] + assert len(lifted_inputs) == len(placeholder_nodes) + for input_node, lifted_node in zip(placeholder_nodes, lifted_inputs): + if lifted_node is None: + input_name_to_node[input_node.name] = input_node + + else: + with gm.graph.inserting_after(input_node): + # It is fine to ignore this warning because + # it is guaranteed that we will populate this + # attr later. + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + getattr_node = gm.graph.get_attr(lifted_node) + input_node.replace_all_uses_with(getattr_node) + metadata = input_node.meta + gm.graph.erase_node(input_node) + getattr_node.meta = metadata + getattr_node.meta["from_node"] = [ + NodeSource( + input_node, + "ExportedProgram.module().unlift()", + [NodeSourceAction.CREATE, NodeSourceAction.REPLACE], + ) + ] + unlifted_name_to_node[lifted_node] = getattr_node + + return unlifted_name_to_node, input_name_to_node + + +def _insert_copy_for_mutations( + gm: torch.fx.GraphModule, + mutated_outputs: Sequence[str | None], + unlifted_name_to_node: dict[str, torch.fx.Node], + input_name_to_node: dict[str, torch.fx.Node], +) -> None: + """ + Find the all the buffers and inputs that were mutated and insert copy_ + operators to reflect mutations. 
+ """ + output_node = gm.graph.output_node() + outputs = pytree.tree_flatten(output_node.args)[0] + assert len(outputs) == len(mutated_outputs) + + user_output_nodes = [] + return_nodes_to_copy = {} + for return_node, mutated_node_name in zip(outputs, mutated_outputs): + if mutated_node_name is None: + user_output_nodes.append(return_node) + continue + + if mutated_node_name in unlifted_name_to_node: + mutated_node = unlifted_name_to_node[mutated_node_name] + elif mutated_node_name in input_name_to_node: + mutated_node = input_name_to_node[mutated_node_name] + else: + raise RuntimeError( + f"Could not find {mutated_node_name} in either buffer or input nodes" + ) + + with gm.graph.inserting_before(output_node): + copy_node = gm.graph.call_function( + torch.ops.aten.copy_.default, (mutated_node, return_node) + ) + return_nodes_to_copy[return_node] = copy_node + + output_args = tuple( + return_nodes_to_copy.get(node, node) for node in user_output_nodes + ) + with gm.graph.inserting_before(output_node): + # Only return user outputs + new_output = gm.graph.output(output_args) + output_node.replace_all_uses_with(new_output) + gm.graph.erase_node(output_node) + new_output.name = output_node.name + new_output.meta.update(output_node.meta) + new_output.meta["from_node"] = [ + NodeSource( + output_node, + "ExportedProgram.module().unlift()", + [NodeSourceAction.CREATE, NodeSourceAction.REPLACE], + ) + ] + + +def _get_codegen( + in_spec: pytree.TreeSpec, + out_spec: pytree.TreeSpec | None, + forward_arg_names: list[str] | None = None, +) -> _PyTreeCodeGen: + """ + Create the codegen for the graph module based on the in/out specs + """ + if forward_arg_names: + names = forward_arg_names + elif ( + in_spec.type is tuple + and in_spec.num_children == 2 + and in_spec.child(0).type is tuple + and in_spec.child(1).type is dict + ): + # if in_spec contains the args (tuple) and kwargs (dict) + names = [f"arg_{i}" for i in range(in_spec.child(0).num_children)] + # add kwarg names + 
names.extend(in_spec.child(1).context) + else: + names = [f"arg_{i}" for i in range(in_spec.num_children)] + + return _PyTreeCodeGen( + _PyTreeInfo( + names, + in_spec, + out_spec, + ) + ) + + +def _unlift( + gm: torch.fx.GraphModule, + lifted_inputs: Sequence[str | None], + mutated_outputs: Sequence[str | None], + in_spec: pytree.TreeSpec, + out_spec: pytree.TreeSpec | None, + forward_arg_names: list[str] | None = None, +): + """ + Args: + lifted_inputs: A list matching the graph module's input nodes. For + an input node that is referring to a lifted parameter/buffer, this + list will contain the fqn the corresponding attribute. Otherwise, this + list will contain None. This is used to unlift the lifted parameters as + get_attr nodes. + + mutated_outputs: A list matching the graph module's output nodes. For + an output node that is referring to a mutated buffer or user input, this + list will contain the name of the corresponding buffer or user input + that needs to be mutated. Otherwise, this list will contain None. This + is used to re-insert an inplace copy_ operator to copy the mutated + values back to the original node. 
+ """ + unlifted_name_to_node, input_name_to_node = _unlift_inputs_as_getattr( + gm, lifted_inputs + ) + _insert_copy_for_mutations( + gm, mutated_outputs, unlifted_name_to_node, input_name_to_node + ) + gm.graph._codegen = _get_codegen(in_spec, out_spec, forward_arg_names) + gm.graph.lint() + gm.recompile() + return gm + + +def _register_attrs_to_new_gm( + new_gm: torch.fx.GraphModule, + graph_signature: ExportGraphSignature, + state_dict: dict[str, Any], + constants: dict[str, Any], +) -> None: + non_persistent_buffers = set(graph_signature.non_persistent_buffers) + for name in graph_signature.buffers: + if name in non_persistent_buffers: + persistent = False + value = constants[name] + else: + persistent = True + value = state_dict[name] + _assign_attr( + value, new_gm, name, attr_kind=_AttrKind.BUFFER, persistent=persistent + ) + for name in graph_signature.parameters: + value = state_dict[name] + _assign_attr( + value, + new_gm, + name, + attr_kind=_AttrKind.PARAMETER, + ) + + # Technically this doesn't account for the aliased multiple constants but + # it is ok because we have a separate pass later in the stack that populates + # the final gm. + for name in chain( + graph_signature.lifted_custom_objs, graph_signature.lifted_tensor_constants + ): + value = constants[name] + _assign_attr( + value, + new_gm, + name, + attr_kind=_AttrKind.CONSTANT, + ) + + +class _StatefulGraphModuleFactory(type): + """ + Metaclass that ensures a private constructor for _StatefulGraphModule + """ + + def __call__(cls, *args, **kwargs): + raise TypeError( + f"{cls.__module__}.{cls.__qualname__} has no public constructor. 
" + ) + + def _create(cls, root, graph, range_constraints=None): + return super().__call__( + root, + graph, + range_constraints=range_constraints, + ) + + +class _StatefulGraphModule(torch.fx.GraphModule, metaclass=_StatefulGraphModuleFactory): + def __init__(self, root, graph, range_constraints=None): + super().__init__(root, graph) + # Need to fix up non-persistent buffers. + self.range_constraints = range_constraints or [] + self.validate_inputs = True + + +def _create_stateful_graph_module( + plain_graph_module: torch.fx.GraphModule, + range_constraints, + ep: ExportedProgram, +) -> _StatefulGraphModule: + stateful_gm = _StatefulGraphModule._create( + plain_graph_module, + plain_graph_module.graph, + range_constraints=range_constraints, + ) + + module_types = _get_graph_inputs_of_type_nn_module(ep.example_inputs) + stateful_gm.register_forward_pre_hook( + lambda *args, **kwargs: _enter_enable_graph_inputs_of_type_nn_module( + module_types + ) + ) + stateful_gm.register_forward_pre_hook( + _check_input_constraints_pre_hook, with_kwargs=True + ) + + stateful_gm.register_forward_hook( + lambda *args, **kwargs: _exit_enable_graph_inputs_of_type_nn_module( + module_types + ), + always_call=True, + ) + + # When we have a constant that has requires_grad=True, we need to detach it + # when we unlift as the tensors that require gradients should be registered + # via parameters. But this is problematic when we have aliasing two constants + # because when we call detach, they will become different tensors. This dict + # keeps track of this logic. + original_tensor_to_detached_tensor = {} + + # Fix up lifted tensor constants. + # fx.GraphModule() constructor silently turns a constant attribute of plain_graph_module + # into a buffer in stateful_gm and creates an inconsistency with graph_signature. + # We fix this by de-registering these buffers in lifted_tensor_constants + # and call _assign_attr(attr_kind=CONSTANT) to register them as constants. 
+ for constant_fqn in ep.graph_signature.lifted_tensor_constants: + # Sometimes, the constant can require gradient, this is probably a bug in user code, + # e.g. `self.const = torch.randn(2, 2, requires_grad=True)`. + # We call detach on the constant_val since they're tensor constants and we don't need to + # compute their gradients anyway. + # Users should properly register it as parameter if they want it to require gradient. + buffer = stateful_gm.get_buffer(constant_fqn) + if buffer.requires_grad: + warnings.warn( + f"A model attribute `{constant_fqn}` requires gradient. " + f"but it's not properly registered as a parameter. " + f"torch.export will detach it and treat it as a constant tensor " + f"but please register it as parameter instead.", + stacklevel=2, + ) + detached_buffer = buffer.detach() + original_tensor_to_detached_tensor[buffer] = detached_buffer + buffer = detached_buffer + *prefix, field = constant_fqn.rsplit(".") + submod = torch.fx.graph_module._get_attr_via_attr_list(stateful_gm, prefix) + delattr(submod, field) + _assign_attr(buffer, stateful_gm, constant_fqn, attr_kind=_AttrKind.CONSTANT) + + # Constants are not preserved well when we create a new GraphModule unlike param/buffers + for const_name, value in ep.constants.items(): + if not torch.fx.graph_module._has_attr(stateful_gm, const_name): + if isinstance(value, torch.Tensor): + if value.requires_grad: + warnings.warn( + f"A model attribute `{const_name}` requires gradient " + f"but it's not properly registered as a parameter. 
" + f"torch.export will detach it and treat it as a constant tensor " + f"but please register it as parameter instead.", + stacklevel=2, + ) + if value in original_tensor_to_detached_tensor: + value = original_tensor_to_detached_tensor[value] + else: + detached_value = value.detach() + original_tensor_to_detached_tensor[value] = detached_value + value = detached_value + _assign_attr( + value, + stateful_gm, + const_name, + attr_kind=_AttrKind.CONSTANT, + ) + + # Fix up non-persistent buffers. torch.fx does not distinguish between + # persistent and non-persistent buffers, so we must restore that distinction + # here. + for buffer in ep.graph_signature.non_persistent_buffers: + _assign_attr( + plain_graph_module.get_buffer(buffer), + stateful_gm, + buffer, + attr_kind=_AttrKind.BUFFER, + persistent=False, + ) + + return stateful_gm + + +def _get_input_paths(example_inputs, signature): + """ + Generate paths of placeholders, needed for generating the guards function. + + NOTE: Here we make use of the example inputs used for export as well as + the signature of the unlifted graph module (not preserved by export). + """ + + args, kwargs = example_inputs + binded = signature.bind(*args, **kwargs) + binded.apply_defaults() + ctx = binded.arguments + flat_example_inputs_with_paths = pytree.tree_leaves_with_path(ctx) + return [path for path, _ in flat_example_inputs_with_paths] + + +def _replace_sources(result_str: str, flat_input_paths: list[Any]): + """ + Given user specified input paths, maybe fix up the guard string + to reflect user path instead of tracer path. 
+ """ + name_mapping = {} + for idx, path in enumerate(flat_input_paths): + name_mapping[f"L['flat_args'][{idx}]"] = f"L{pytree.keystr(path)}" + + replace = result_str + for key, val in name_mapping.items(): + replace = replace.replace(key, val) + return replace + + +def _get_input_guards_for_graph( + placeholders: list[torch.fx.Node], + range_constraints: dict[sympy.Symbol, ValueRanges], + paths_for_placeholders: list[pytree.KeyPath], +): + """ + Guards generated by the tracer include conditions observed in code, but + but do not include some additional checks we typically do in export. + For example, when dynamic shapes get specialized, are specified to be + within a range, or are specified to be in some equational relation, + corresponding input invalidation is done within a pre_hook, specifically, + `_check_input_constraints_for_graph`. + + Here we generate guards corresponding to the checks that happen in + `_check_input_constraints_for_graph`, and add them to the guards already + generated by the tracer. In the future, it may be worthwhile to separate + them so that we can allow clients to turn off one but not the other. + (Looking at you, AOTI.) + + NOTE: We should eventually reconcile this logic with `build_guards` that + is used by AOT Precompile. 
+ """ + + deferred_expressions = [] + new_guards_code = [] + sources: dict[sympy.Expr, str] = {} + + def handle_symint(expr, src): + if len(expr.free_symbols) == 1: + # complex equations (e.g., involving derived dims) need to + # handled later, since we may not have enough information + # just as we are passing through the placeholders in order + deferred_expressions.append((src, expr)) + if expr in sources: + # expressions that appear in multiple sources should force + # inputs corresponding to those sources to be equal + # e.g., x.shape[0] == y.shape[1] + orig_src = sources[expr] + new_guards_code.append(f"{src} == {orig_src}") + else: + sources[expr] = src + # process value ranges as elsewhere in export + min_val, max_val = _convert_range_to_int(range_constraints[expr]) + if min_val > 2: + new_guards_code.append(f"{src} >= {min_val}") + if max_val < math.inf: + new_guards_code.append(f"{src} <= {max_val}") + + for placeholder, path in zip(placeholders, paths_for_placeholders): + src = "L" + pytree.keystr(path) + meta = placeholder.meta["val"] + # specializations + if isinstance(meta, int): + new_guards_code.append(f"{src} == {meta}") + if isinstance(meta, float): + if meta == math.inf: + new_guards_code.append(f"{src} == math.inf") + elif meta == -math.inf: + new_guards_code.append(f"{src} == -math.inf") + else: + new_guards_code.append(f"{src} == {meta}") + elif isinstance(meta, str): + new_guards_code.append(f"{src} == '{meta}'") + # range constraints and equalities + elif isinstance(meta, torch.SymInt) and meta.node.expr in range_constraints: + handle_symint(meta.node.expr, src) + elif isinstance(meta, torch.Tensor): + for i, dim in enumerate(meta.shape): + src = "L" + pytree.keystr(path) + f".size()[{i}]" + if isinstance(dim, int): + # specializations + new_guards_code.append(f"{src} == {dim}") + elif ( + isinstance(dim, torch.SymInt) and dim.node.expr in range_constraints + ): + # range constraints and equalities + handle_symint(dim.node.expr, src) + + 
unification_map: dict[sympy.Symbol, sympy.Expr] = {} + py_printer = torch.utils._sympy.printers.PythonPrinter() + + # process complex equations (e.g., involving derived dims) + for src, expr in deferred_expressions: + # we know this is the only symbol in expr (see check above) + symbol = next(iter(expr.free_symbols)) + if symbol in sources: + # if s0 is already known to be directly sourced from inputs, + # e.g., z.shape[2], we do not need to do anything further + # (assume we have already processed constraints on s0 above) + continue + + # otherwise s0 has some "hidden" source like 'dim' + # example: src = y.shape[1], expr = s0 + 1 + if symbol in unification_map: + # suppose that we already know that s0 = x.shape[0] * 2 + # so we can emit the guard: x.shape[0] * 2 + 1 = y.shape[1] + substitution = expr.subs(unification_map) + new_guards_code.append( + py_printer.doprint(sympy.Eq(substitution, sympy.Symbol(src))) + ) + else: + # we do not yet know what s0 is, but given s0 + 1 = y.shape[1], + # we can solve for s0...now knowing that s0 = y.shape[1] - 1 + solution = try_solve(sympy.Eq(expr, sympy.Symbol(src)), symbol) + if solution is not None: + definition = solution[1] + unification_map[symbol] = definition + + return new_guards_code + + +def _ok_to_generate_guards_fn(): + patterns = [ + "executorch", + "modai", + "on_device_ai", + "torchao", + ] + # force check_guards=False for files matching `patterns` + # because they have too many calls to .module() and + # do not like any call modules in the graph + # TODO: fix these files to handle guard fns + frame = inspect.currentframe() + while frame is not None: + if any(path in frame.f_code.co_filename for path in patterns): + return False + frame = frame.f_back + + return True + + +def _unlift_exported_program_lifted_states( + ep: ExportedProgram, check_guards=True +) -> torch.fx.GraphModule: + check_guards = check_guards and _ok_to_generate_guards_fn() + + source_node_dict = { + node.name: node for node in 
ep.graph.nodes if node.op != "placeholder" + } + # placeholder node name might change after deepcopy + placeholder_source_node_dict = { + node.target: node for node in ep.graph.nodes if node.op == "placeholder" + } + + new_gm = torch.fx.GraphModule(ep.graph_module, copy.deepcopy(ep.graph)) + new_gm.meta.update(ep.graph_module.meta) + ep = copy.copy(ep) + ep._graph_module = new_gm + + # TODO T206340015 + if ep.verifiers[0].dialect != "TRAINING": + ep = _remove_effect_tokens(ep) + + _register_attrs_to_new_gm(new_gm, ep.graph_signature, ep.state_dict, ep.constants) + forward_arg_names = ( + sig.forward_arg_names if (sig := ep.module_call_graph[0].signature) else None + ) + lifted_inputs: list[str | None] = [ + ( + in_spec.target + if in_spec.kind + in ( + InputKind.BUFFER, + InputKind.CONSTANT_TENSOR, + InputKind.PARAMETER, + InputKind.CUSTOM_OBJ, + ) + else None + ) + for in_spec in ep.graph_signature.input_specs + ] + + mutated_outputs: list[str | None] = [ + ( + out_spec.target + if out_spec.kind + in ( + OutputKind.BUFFER_MUTATION, + OutputKind.USER_INPUT_MUTATION, + OutputKind.PARAMETER_MUTATION, + ) + else None + ) + for out_spec in ep.graph_signature.output_specs + ] + + for node in new_gm.graph.nodes: + source_node = None + if node.op == "placeholder": + source_node = placeholder_source_node_dict.get(node.target) + else: + if node.name in source_node_dict: + source_node = source_node_dict.get(node.name) + node.meta["from_node"] = [ + NodeSource( + source_node, + "ExportedProgram.module()", + NodeSourceAction.CREATE, + ) + ] + + assert ep.call_spec.in_spec is not None + new_gm = _unlift( + new_gm, + lifted_inputs, + mutated_outputs, + ep.call_spec.in_spec, + ep.call_spec.out_spec, + forward_arg_names=forward_arg_names, + ) + unlift_gm = _create_stateful_graph_module(new_gm, ep.range_constraints, ep) + unlift_gm.meta.update(ep.graph_module.meta) + + # create a _guards_fn submodule and insert a call to it after placeholders + graph = unlift_gm.graph + 
placeholders = graph.find_nodes(op="placeholder") + if check_guards and placeholders and ep.example_inputs: + sig = inspect.signature(unlift_gm.forward) + input_paths = _get_input_paths( + ep.example_inputs, + sig, + ) + + # TODO (tmanlaibaatar) + # This is band-aid solution to export new tracer replacing + # shape env sources to flat_args. The real fix should be replacing + # shape env sources to original user sources but this is quite + # involved because you need to carefully construct new sources using + # dynamo and replace all instances of it inside shape env. But it is + # lot easier to manipulate after we turn them into strings and only + # time we use these guards is during retracing or running exported program, + # so it is probably ok to have "not useful" guards on ep for now. + ep_guards = [] + for guard in ep._guards_code: + ep_guards.append(_replace_sources(guard, input_paths)) + + guards_code = _get_input_guards_for_graph( + placeholders, ep.range_constraints, input_paths + ) + + ep_guards_code = _force_ep_signature_match(ep._guards_code, input_paths) + ep_guards_code = _force_gm_signature_match(ep_guards_code, sig) + guards_code.extend(ep_guards_code) + unlift_gm._guards_fn = _convert_guards_code_to_fn(guards_code, input_paths) + + root_nn_module_stack = torch.fx._utils.first_call_function_nn_module_stack( + graph + ) + with graph.inserting_after(placeholders[-1]): + node = graph.call_module("_guards_fn", tuple(placeholders)) + node.meta["nn_module_stack"] = root_nn_module_stack + + unlift_gm.recompile() + + return unlift_gm + + +class GuardsFn(torch.nn.Module): + """ + Module class for guard functions. 
+ """ + + def forward(self, *args): + pass diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/_wrapper_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/_wrapper_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..bc27a8575a0a0d4d90fe9bcbc1a65180f0afdd18 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/_wrapper_utils.py @@ -0,0 +1,10 @@ +import torch + + +class _WrapperModule(torch.nn.Module): + def __init__(self, f): # type: ignore[no-untyped-def] + super().__init__() + self.f = f + + def forward(self, *args, **kwargs): # type: ignore[no-untyped-def] + return self.f(*args, **kwargs) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/custom_obj.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/custom_obj.py new file mode 100644 index 0000000000000000000000000000000000000000..8e7f2080a4ee705a2621386c9b69a089d507544a --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/custom_obj.py @@ -0,0 +1,16 @@ +from dataclasses import dataclass + + +__all__ = ["ScriptObjectMeta"] + + +@dataclass +class ScriptObjectMeta: + """ + Metadata which is stored on nodes representing ScriptObjects. + """ + + # Key into constants table to retrieve the real ScriptObject. 
+ constant_name: str + + class_fqn: str diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/custom_ops.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/custom_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..9df7988da9314c4b18863c88e503ad5b04ae07d4 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/custom_ops.py @@ -0,0 +1,49 @@ +# mypy: allow-untyped-defs +import importlib + +import torch + + +lib = torch.library.Library("export", "FRAGMENT") # noqa: TOR901 + +lib.define( + "access_subclass_inner_tensor(Tensor src_subclass_tensor, str attr) -> Tensor" +) + + +@torch.library.impl(lib, "access_subclass_inner_tensor", "Autograd") +# When running under torch.inference_mode(), we seem to skip AUtograd key +# so we should desugar this op as soon as we start tracing to post-dispatch. +@torch.library.impl(lib, "access_subclass_inner_tensor", "Python") +def _access_subclass_inner_tensor( + src_subclass_tensor: torch.Tensor, attr: str +) -> torch.Tensor: + from torch.utils._python_dispatch import is_traceable_wrapper_subclass + + assert is_traceable_wrapper_subclass(src_subclass_tensor) + val = getattr(src_subclass_tensor, attr, None) + if val is None or not isinstance(val, torch.Tensor): + raise RuntimeError( + f"Attribute {attr} is not a tensor or doesn't exist in {src_subclass_tensor}" + ) + return val + + +def _call_custom_autograd_function_in_pre_dispatch(function_cls_name, *args, **kwargs): + """ + Import a custom autograd function by string name and call it. This is pretty bad + because: + 1) There is no schema + + Ideally we should automatically wrap custom autograd functions with a custom op, but + that is too much work because we need to schematize custom autograd functions. For now, + we just hackily put it in the IR. 
+ """ + # Parse module and class name + module_name, class_name = function_cls_name.rsplit(".", 1) + + # Import the module and get the class + module = importlib.import_module(module_name) + function_cls = getattr(module, class_name) + assert hasattr(function_cls, "apply") + return function_cls.apply(*args, **kwargs) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/decomp_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/decomp_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d3097734c8a35adecf0423633452989c07f68e90 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/decomp_utils.py @@ -0,0 +1,160 @@ +# mypy: allow-untyped-defs +from collections.abc import Callable + +import torch +from torch._export.utils import ( + _collect_all_valid_cia_ops, + _collect_all_valid_cia_ops_for_aten_namespace, + _get_decomp_for_cia, + _is_aten_op, +) + + +__all__ = ["CustomDecompTable"] + + +""" +Core ATen ops with Composite Implicit Autograd dispatch that should be excluded from decomposition +by default. The decomposition logic should eventually exclude all core-tagged CIA ops, but until all +backends are ready, this list allows opt-in one at a time. +""" +PRESERVED_ATEN_CIA_OPS = { + torch.ops.aten.upsample_bilinear2d.vec, + torch.ops.aten.upsample_nearest2d.vec, + # NB: don't use the C++ decomp, because it is not functional! + torch.ops.aten.silu_backward.default, + torch.ops.aten.mish_backward.default, + torch.ops.aten._fused_rms_norm.default, +} + + +class CustomDecompTable(dict[torch._ops.OperatorBase, Callable]): + """ + This is a custom dictionary that is specifically used for handling decomp_table in export. + The reason we need this is because in the new world, you can only *delete* an op from decomp + table to preserve it. 
This is problematic for custom ops because we don't know when the custom + op will actually be loaded to the dispatcher. As a result, we need to record the custom ops operations + until we really need to materialize it (which is when we run decomposition pass.) + + Invariants we hold are: + 1. All aten decomp is loaded at the init time + 2. We materialize ALL ops when user ever reads from the table to make it more likely + that dispatcher picks up the custom op. + 3. If it is write operation, we don't necessarily materialize + 4. We load the final time during export, right before calling run_decompositions() + + """ + + def __init__(self): + super().__init__() + from torch._decomp import _core_aten_decompositions_post_autograd + + # For aten ops, we load them up in the beginning + self.decomp_table = _core_aten_decompositions_post_autograd() + + for op in _collect_all_valid_cia_ops_for_aten_namespace(): + if op not in PRESERVED_ATEN_CIA_OPS and op not in self.decomp_table: + self.decomp_table[op] = _get_decomp_for_cia(op) + + # This is to track the *pending* deleted custom ops that haven't been materialized yet + self.deleted_custom_ops = set() + # When this is true, there shouldn't be any pending operations in the table. 
+ self.has_materialized = False + + def __getitem__(self, key): + self._materialize_if_needed() + return self.decomp_table.__getitem__(key) + + def __setitem__(self, key, value) -> None: + self.decomp_table.__setitem__(key, value) + + if key in self.deleted_custom_ops: + self.deleted_custom_ops.remove(key) + + def keys(self): + self._materialize_if_needed() + return self.decomp_table.keys() + + def __delitem__(self, key) -> None: + self.pop(key) + + def update(self, other_dict): # type: ignore[override] + for k, v in other_dict.items(): + self.decomp_table.__setitem__(k, v) + + def __missing__(self, key) -> bool: + return not self.__contains__(key) + + def __contains__(self, key) -> bool: + self._materialize_if_needed() + return self.decomp_table.__contains__(key) + + def __len__(self) -> int: + self._materialize_if_needed() + return self.decomp_table.__len__() + + def __iter__(self): + self._materialize_if_needed() + return self.decomp_table.__iter__() + + def __reversed__(self): + self._materialize_if_needed() + return self.decomp_table.__reversed__() + + def copy(self) -> "CustomDecompTable": + new_dict = CustomDecompTable() + new_dict.decomp_table = self.decomp_table.copy() + new_dict.deleted_custom_ops = self.deleted_custom_ops.copy() + new_dict.has_materialized = self.has_materialized + return new_dict + + def pop(self, *args): + def _pop_if_can(key): + if _is_aten_op(key): + return self.decomp_table.pop(key) + + if key in self.decomp_table: + # Even if we materialized it, we should add it to the deleted + # custom ops list so that when we materialize next time, + # we should respect user's intention. + self.deleted_custom_ops.add(key) + return self.decomp_table.pop(key) + + if key in self.deleted_custom_ops: + raise KeyError(f"{key} doesn't exist in the table") + + self.deleted_custom_ops.add(key) + # We would come here when user pops off something that is + # not in the table. In this case, we just pretend that it + # was in the table. 
+ return _get_decomp_for_cia(key) + + if len(args) == 1: + return _pop_if_can(args[0]) + + if len(args) == 2: + try: + return _pop_if_can(args[0]) + except KeyError: + return args[1] + + def items(self): + self._materialize_if_needed() + return self.decomp_table.items() + + def materialize(self) -> dict[torch._ops.OperatorBase, Callable]: + for op in _collect_all_valid_cia_ops(): + if _is_aten_op(op): + continue + elif op in self.decomp_table: + continue + elif op not in self.deleted_custom_ops: + self.decomp_table[op] = _get_decomp_for_cia(op) + + self.has_materialized = True + self.deleted_custom_ops = set() + return {**self.decomp_table} + + def _materialize_if_needed(self) -> None: + if not self.has_materialized: + self.materialize() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/dynamic_shapes.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/dynamic_shapes.py new file mode 100644 index 0000000000000000000000000000000000000000..fe2e12ba0810b703dd2ccae1dc6da447132d3b5a --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/dynamic_shapes.py @@ -0,0 +1,1374 @@ +# mypy: allow-untyped-defs +import dataclasses +import inspect +import logging +import sys +from collections import defaultdict +from collections.abc import Callable +from enum import auto, Enum +from typing import Any, TYPE_CHECKING, Union + +import torch +from torch.utils._pytree import ( + _get_node_type, + BUILTIN_TYPES, + KeyPath, + keystr, + MappingKey, + SequenceKey, + SUPPORTED_NODES, + tree_iter, + tree_map, + tree_map_with_path, + tree_structure, + TreeSpec, +) + +from .exported_program import ExportedProgram + + +if TYPE_CHECKING: + from sympy import Symbol + + from torch._guards import Source + from torch.fx.experimental.symbolic_shapes import ShapeEnv, StrictMinMaxConstraint + +__all__ = [ + "Constraint", + "Dim", + "dims", + "refine_dynamic_shapes_from_suggested_fixes", + "AdditionalInputs", 
+] + + +log = logging.getLogger(__name__) + + +class _DimHintType(Enum): + """ + Enum for dynamic shape hints. + - AUTO means automatic inference of shape (static or dynamic). + - STATIC means static shape (always specialized). + - DYNAMIC means dynamic, will error out if specialized. + """ + + AUTO = auto() + STATIC = auto() + DYNAMIC = auto() + + +@dataclasses.dataclass +class _DimHint: + """ + Internal class for dynamic shape hints. + - min and max are optional. + - _factory is for UX only, below example: + auto_hint = _DimHint.AUTO() # _factory=True + bounded_hint = auto_hint(min=10, max=100) # Returns new instance with _factory=False + bounded_hint(min=5, max=50) # Will fail, non-factory instance cannot be called + """ + + type: _DimHintType + min: int | None = None + max: int | None = None + _factory: bool | None = True + + @staticmethod + def AUTO(): + return _DimHint(_DimHintType.AUTO) + + @staticmethod + def DYNAMIC(): + return _DimHint(_DimHintType.DYNAMIC) + + @staticmethod + def STATIC(): + return _DimHint(_DimHintType.STATIC) + + def __call__(self, min=None, max=None) -> "_DimHint": + if not self._factory: + raise TypeError(f"'{type(self)}' object is not callable") + assert min is None or min >= 0, "min must be non-negative" + assert max is None or max >= 0, "max must be non-negative" + assert min is None or max is None or min <= max, "min must be <= max" + return _DimHint(self.type, min=min, max=max, _factory=False) + + def __repr__(self): + parts = [self.type.name] + if self.min is not None: + parts.append(f"min={self.min}") + if self.max is not None: + parts.append(f"max={self.max}") + return f"DimHint({', '.join(parts)})" + + +class Dim: + """ + The ``Dim`` class allows users to specify dynamism in their exported + programs. By marking a dimension with a ``Dim``, the compiler associates the + dimension with a symbolic integer containing a dynamic range. + + The API can be used in 2 ways: Dim hints (i.e. 
automatic dynamic shapes: + ``Dim.AUTO``, ``Dim.DYNAMIC``, ``Dim.STATIC``), or named Dims (i.e. + ``Dim("name", min=1, max=2)``). + + Dim hints provide the lowest barrier to exportability, with the user only + needing to specify if a dimension if dynamic, static, or left for the + compiler to decide (``Dim.AUTO``). The export process will automatically + infer the remaining constraints on min/max ranges and relationships between + dimensions. + + Example:: + + class Foo(nn.Module): + def forward(self, x, y): + assert x.shape[0] == 4 + assert y.shape[0] >= 16 + return x @ y + + + x = torch.randn(4, 8) + y = torch.randn(8, 16) + dynamic_shapes = { + "x": {0: Dim.AUTO, 1: Dim.AUTO}, + "y": {0: Dim.AUTO, 1: Dim.AUTO}, + } + ep = torch.export(Foo(), (x, y), dynamic_shapes=dynamic_shapes) + + Here, export would raise an exception if we replaced all uses of ``Dim.AUTO`` with ``Dim.DYNAMIC``, + as ``x.shape[0]`` is constrained to be static by the model. + + More complex relations between dimensions may also be codegened as runtime assertion nodes by the compiler, + e.g. ``(x.shape[0] + y.shape[1]) % 4 == 0``, to be raised if runtime inputs do not satisfy such constraints. + + You may also specify min-max bounds for Dim hints, e.g. ``Dim.AUTO(min=16, max=32)``, ``Dim.DYNAMIC(max=64)``, + with the compiler inferring the remaining constraints within the ranges. An exception will be raised if + the valid range is entirely outside the user-specified range. + + Named Dims provide a stricter way of specifying dynamism, where exceptions are raised if the compiler + infers constraints that do not match the user specification. 
For example, exporting the previous + model, the user would need the following ``dynamic_shapes`` argument:: + + s0 = Dim("s0") + s1 = Dim("s1", min=16) + dynamic_shapes = { + "x": {0: 4, 1: s0}, + "y": {0: s0, 1: s1}, + } + ep = torch.export(Foo(), (x, y), dynamic_shapes=dynamic_shapes) + + Named Dims also allow specification of relationships between dimensions, up + to univariate linear relations. For example, the following indicates one + dimension is a multiple of another plus 4:: + + s0 = Dim("s0") + s1 = 3 * s0 + 4 + + """ + + AUTO = _DimHint.AUTO() + DYNAMIC = _DimHint.DYNAMIC() + STATIC = _DimHint.STATIC() + + def __init__(self, name: str, *, min: int | None = None, max: int | None = None): + from torch.utils._sympy.numbers import int_oo + + _min = 0 if min is None else min + _max = int_oo if max is None else max + assert _max > _min, f"Cannot create Dim with inconsistent min={min}, max={max}" + assert name.isidentifier(), f"Dim name must be a valid identifier, got {name}" + self.__name__ = name + self.min = _min + self.max = _max + + def __add__(self, other) -> "Dim": + # e.g., dim + 1 + if type(other) is not int: + raise NotImplementedError( + f"Attempted to add {other} to {self.__name__}, where an integer was expected. " + "(Only increasing linear operations with integer coefficients are supported.)" + ) + return self._derive(lambda x: x + other) + + def __radd__(self, other) -> "Dim": + return self + other + + def __sub__(self, other) -> "Dim": + # e.g., dim - 1 + if type(other) is not int: + raise NotImplementedError( + f"Attempted to subtract {other} from {self.__name__}, where an integer was expected. " + "(Only increasing linear operations with integer coefficients are supported.)" + ) + return self._derive(lambda x: x - other) + + def __rsub__(self, other) -> "Dim": + raise NotImplementedError( + f"Attempted to negate {self.__name__}. 
" + "(Only increasing linear operations with integer coefficients are supported.)" + ) + + def __mul__(self, other) -> "Dim": + # e.g., dim * 2 + if type(other) is not int or other <= 0: + raise NotImplementedError( + f"Attempted to multiply {other} with {self.__name__}, where a positive integer was expected. " + "(Only increasing linear operations with integer coefficients are supported.)" + ) + return self._derive(lambda x: x * other) + + def __rmul__(self, other) -> "Dim": + return self * other + + def _derived_name(self, fn) -> str: + from sympy import sympify + + return str(fn(sympify(self.__name__))) + + def _derive(self, fn) -> "Dim": + return _DerivedDim(self._derived_name(fn), self, fn) + + @staticmethod + def _readable(name: str, min_: int, max_: int) -> str: + from torch.utils._sympy.numbers import int_oo + + if min_ == 2: + min_ = None # type: ignore[assignment] + if max_ == int_oo: + max_ = None # type: ignore[assignment] + if min_ is None and max_ is None: + return f"Dim('{name}')" + if min_ is None: + return f"Dim('{name}', max={max_})" + if max_ is None: + return f"Dim('{name}', min={min_})" + return f"Dim('{name}', min={min_}, max={max_})" + + def __repr__(self): + return Dim._readable(self.__name__, self.min, self.max) + + +_Dim = Dim # TODO(pianpwk): remove after it's no longer internally breaking + + +class _StaticDim(Dim): + """ + Class for static :func:`Dim` types. + + This class is only for setting and checking static dim constraints, + and the user should never interact with it. + """ + + def __init__(self, value: int): + self.__name__ = str(value) + self.value = value + + @property + def min(self): # type: ignore[override] + return self.value # type: ignore[attr-defined] + + @property + def max(self): # type: ignore[override] + return self.value # type: ignore[attr-defined] + + +class _DerivedDim(Dim): + """ + Class for derived :func:`Dim` types. + + Currently we only support increasing linear expressions with integer coefficients. 
+ In other words, a derived Dim can always be written in the form Ax + B, where + x is a regular Dim (i.e., non-derived Dim), A and B are integers, and A is positive. + (In particular, the latter ensures that x < y => Ax + B < Ay + B.) + These restrictions on the form of derived Dims makes the metatheory simpler: e.g., + it simplifies computing ranges for derived Dims, solving for underlying regular Dims, + deciding equalities between derived Dims, and so on. + + The function lambda x: Ax + B is expressed by `fn`, where x is a normal Dim, `root`. + The range of a derived Dim is computed by mapping `fn` over the range of its `root`. + """ + + def __init__(self, name: str, root: Dim, fn: Callable): + self.__name__ = name + self.root = root + self.fn = fn + + @property + def min(self): # type: ignore[override] + # assume that self.fn is an increasing function + # TODO(avik): use sympy value range analysis instead? + from sympy import Integer + + from torch.utils._sympy.numbers import int_oo + + if self.root.min is -int_oo: # type: ignore[attr-defined] + return -int_oo # fn not needed cuz increasing + + _min_symint = self.fn(Integer(self.root.min)) # type: ignore[attr-defined] + root = self.root # type: ignore[attr-defined] + assert _min_symint >= 0, ( + f"Expected derived min value of {self.__name__} to be >= 0. " + f"Please specify an appropriate min value for {root.__name__} " + f"(currently {root.min})." + ) + return int(_min_symint) + + @property + def max(self): # type: ignore[override] + # assume that self.fn is an increasing function + # TODO(avik): use sympy value range analysis instead? 
+ from sympy import Integer + + from torch.utils._sympy.numbers import int_oo + + if self.root.max is int_oo: # type: ignore[attr-defined] + return int_oo # fn not needed cuz increasing + + _max_symint = self.fn(Integer(self.root.max)) # type: ignore[attr-defined] + root = self.root # type: ignore[attr-defined] + assert _max_symint <= sys.maxsize - 1, ( + f"Expected derived max value of {self.__name__} to be <= {sys.maxsize - 1}. " + f"Please specify an appropriate max value for {root.__name__} " + f"(currently {root.max})." + ) + return int(_max_symint) + + def _derive(self, fn): + # We support nesting, e.g., 2*dim + 1. + # This is implemented by composing operations on the same root. + # As a consequence, roots are always regular Dims (i.e., not derived Dims). + return _DerivedDim( + self._derived_name(fn), + self.root, + lambda x: fn(self.fn(x)), + ) + + def __repr__(self): + return self.__name__ + + +def dims( + *names: str, min: int | None = None, max: int | None = None +) -> tuple[Dim, ...]: + """ + Util to create multiple :func:`Dim` types. + + Returns: + A tuple of :func:`Dim` types. + """ + return tuple(Dim(name, min=min, max=max) for name in names) # type: ignore[misc] + + +@dataclasses.dataclass +class _ConstraintTarget: + """ + This represents input tensor dimensions. + """ + + t_id: int + dim: int + + +@dataclasses.dataclass +class _Constraint(_ConstraintTarget): + """ + This represents a Dim describing a constraint target. + + `name` is the name of the Dim. + `constraint_range` contains the min/max bounds of the Dim. 
+ """ + + name: str + constraint_range: "StrictMinMaxConstraint" + + def _clone_with_range(self, lower=0, upper=None): + # Import sympy locally + from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint + from torch.utils._sympy.numbers import int_oo + from torch.utils._sympy.value_ranges import ValueRanges + + if upper is None: + upper = int_oo + + constraint_range = StrictMinMaxConstraint( + vr=self.constraint_range.vr & ValueRanges(lower=lower, upper=upper), + warn_only=False, + ) + return _Constraint( + self.t_id, + self.dim, + self.name, + constraint_range, + ) + + def __ge__(self, lower): + return self._clone_with_range(lower=lower) + + def __gt__(self, lower): + return self._clone_with_range(lower=lower + 1) + + def __le__(self, upper): + return self._clone_with_range(upper=upper) + + def __lt__(self, upper): + return self._clone_with_range(upper=upper - 1) + + def __bool__(self): + # NOTE(avik): We do not support compound expressions like a <= x <= b. + # This is because Python implicitly desugars them into bool(a <= x) and bool(x <= b), + # and moreover, enforces that any overload of __bool__ must return True or False. + # FWIW, sympy also raises TypeError in this case. + raise TypeError( + "Cannot determine truth value of _Constraint. " + "If you are trying to combine _Constraint's with logical connectives, " + "you can specify them separately instead." + ) + + @property + def serializable_spec(self): + # We need a serialization compatible format of the constraint so that it + # can be savedin the graph module w/o breaking the module serialization. + # The saved constraints will be used directly for the post-exporting pass + # that converts constraints to runtime assertion. The saved constraints + # will not be saved in the serialized module. + # TODO: A better way is needed. 
Currently we use 't_id' to map the constraint, + # which is not reliable + return { + "t_id": self.t_id, + "dim": self.dim, + "min": self.constraint_range.vr.lower, + "max": self.constraint_range.vr.upper, + } + + +@dataclasses.dataclass +class _PhantomRoot: + """ + This represents the root of a derived Dim where the root does not directly + specify the shape of any input dimension, but the derived Dim does. + + e.g., the input shapes 2*dim and dim + 1 are related via a "phantom" dim. + + The fields `name`, `constraint_range`, and `val` carried by a phantom root + help create a symbol for it. Any derived dims with this phantom root are + backed by expressions over this symbol. + """ + + name: str + constraint_range: "StrictMinMaxConstraint" + val: int + + +@dataclasses.dataclass +class _DerivedConstraint(_ConstraintTarget): + """ + This represents a derived Dim, whose root is either a regular constraint target + (which directly specifies the shape of some input dimension) or a phantom root + (which does so indirectly). + + It can be thought of as a subclass of `_Constraint`, except that it does not + support <, <=, >, >= operations. + """ + + name: str + constraint_range: "StrictMinMaxConstraint" + root: _ConstraintTarget | _PhantomRoot + fn: Callable + + @property + def serializable_spec(self): + # same as _Constraint.serializable_spec + return { + "t_id": self.t_id, + "dim": self.dim, + "min": self.constraint_range.vr.lower, + "max": self.constraint_range.vr.upper, + } + + +@dataclasses.dataclass +class _RelaxedConstraint(_ConstraintTarget): + """ + This represents a dim marked with Dim.AUTO/DYNAMIC (i.e. mark_dynamic() or maybe_mark_dynamic()), + which leaves relations & min/max ranges for inference, instead of requiring explicit specification. + The intention is for constraint violations to not be raised if produce_guards() finds equalities or + relations between a _RelaxedConstraint and another type of _Constraint. 
+ """ + + @property + def serializable_spec(self): + return { + "t_id": self.t_id, + "dim": self.dim, + } + + +Constraint = _Constraint | _DerivedConstraint | _RelaxedConstraint + + +@dataclasses.dataclass +class _IntWrapper: + """ + Dummy wrapper class to wrap around integer inputs so that when we parse the + dynamic_shapes structure, we can mark if any of the integers were marked as + dynamic. + """ + + val: int + # Disallow specifying dynamism + dynamism: _DimHint | int | None = dataclasses.field(init=False, default=None) + + +def _process_equalities( + constraint: Constraint, + get_sources: Callable[[int, int], list["Source"]], + shape_env: "ShapeEnv", + names: dict[str, tuple[int, int]], + source_pairs: list[tuple["Source", "Source"]], + derived_equalities: list[tuple["Source", Union["Source", "Symbol"], Callable]], + phantom_symbols: dict[str, "Symbol"], + relaxed_sources: set["Source"], +): + """ + Updates `source_pairs`, `derived_equalities`, and `phantom_symbols` (which become + fields of `EqualityConstraint`) based on a given input `constraint`. + """ + + sources = get_sources(constraint.t_id, constraint.dim) + if not sources: # empty sources due to unused shapes + return + + source, *other_sources = sources + # When t.size()[dim] maps to src0, src1, ..., srcN, we add + # constraints that make src0 "equal" to src1, ..., srcN. 
+ source_pairs.extend((source, other_source) for other_source in other_sources) + if isinstance(constraint, _Constraint): + if constraint.name in names: + shared_t_id, shared_dim = names[constraint.name] + other_sources = get_sources(shared_t_id, shared_dim) + source_pairs.extend( + (source, other_source) for other_source in other_sources + ) + else: + names[constraint.name] = (constraint.t_id, constraint.dim) + elif isinstance(constraint, _DerivedConstraint): + # branch based on the root of the _DerivedConstraint + if not isinstance(constraint.root, _PhantomRoot): + # either root points to an input source + root = get_sources(constraint.root.t_id, constraint.root.dim)[0] + else: + # or root points to a phantom symbol + if constraint.root.name in phantom_symbols: + root = phantom_symbols[constraint.root.name] + else: + # create a phantom symbol in the shape env based on the _PhantomRoot + root = shape_env.create_symbol( + val=constraint.root.val, + source=torch._dynamo.source.ConstantSource(constraint.root.name), + dynamic_dim=torch.fx.experimental.symbolic_shapes.DimDynamic.DYNAMIC, + constraint_dim=constraint.root.constraint_range, + ) + phantom_symbols[constraint.root.name] = root + + fn = constraint.fn + # A derived equality (source, root, fn) informally corresponds to source = fn(root). + # Here source describes an input and root might describe another input or a phantom symbol. + derived_equalities.append((source, root, fn)) + elif isinstance(constraint, _RelaxedConstraint): + relaxed_sources.add(source) + + +def _tree_map_with_path( + func: Callable[..., Any], + tree: Any, + *dynamic_shapes: Any, + tree_name: str | None = None, +) -> Any: + """ + Customized tree_map for mapping pytrees to dynamic_shapes. + + For built-in types (e.g., standard collections) this behaves exactly like tree_map. 
+ + OTOH for a user-defined class C registered with pytree, we cannot assume that a C + containing tensors can be mapped to a C containing dynamic shapes (i.e., C may not + be a polymorphic container). In that case we use the flattened form of C instead. + Thus a C(**tensors) that flattens to (**tensors) will map to (**dynamic_shapes). + + Args: + func: function to apply to each (int, float, str, bool, None, torch.Tensor) + tree: input pytree + dynamic_shapes: zero or more (typically one) dynamic_shapes to match + + Returns: + output pytree mapping func to each (int, float, str, bool, None, torch.Tensor) + """ + + def is_leaf(t): + # BUILTIN_TYPES is a subset of SUPPORTED_NODES, the latter being all types + # registered with pytree. Types *not* in BUILTIN_TYPES include primitive types + # (int, float, str, bool, None, torch.Tensor), which are not in SUPPORTED_NODES, + # as well as user-defined classes registered with pytree, which are. + return _get_node_type(t) not in BUILTIN_TYPES + + def f(path, t, *dynamic_shapes): + typ = _get_node_type(t) + # typ is not in BUILTIN_TYPES + if typ in SUPPORTED_NODES: + # thus typ is a user-defined class registered with pytree, + # in which case flatten and recurse + return tree_map_with_path( + f, + SUPPORTED_NODES[typ].flatten_fn(t)[0], + *dynamic_shapes, + is_leaf=is_leaf, + ) + else: + return func(path, t, *dynamic_shapes) + + try: + return tree_map_with_path(f, tree, *dynamic_shapes, is_leaf=is_leaf) + except ValueError as e: + if "mismatch" in e.args[0]: + # When PyTree finds a structural mismatch between tree and dynamic_shapes, + # the error message is unfortunately quite horrible. Let's fix that. 
+ assert dynamic_shapes, "Cannot be a mismatch if there is no dynamic_shapes" + assert tree_name, "Must provide a tree_name when there might be a mismatch" + + def _key(type_, context, i): + # derive a PyTree key given the type, context, and child # of a TreeSpec + if type_ is dict: + return MappingKey(context[i]) + if type_ in (list, tuple): + assert context is None + return SequenceKey(i) + raise AssertionError(f"Did not expect type {type_}") + + def raise_mismatch_error(msg): + from torch._dynamo.exc import UserError, UserErrorType + + raise UserError( + UserErrorType.INVALID_INPUT, + f"Detected mismatch between the structure of `{tree_name}` and `dynamic_shapes`: {msg}", + case_name="dynamic_shapes_validation", + ) + + def _compare( + treespec: TreeSpec, other_treespec: TreeSpec, path: KeyPath + ) -> None: + # raise an error at the point where tree and dynamic_shapes differ, + # including the path to that point and the reason for the difference + rendered_path = keystr(path) + if treespec.is_leaf(): + return + if other_treespec.is_leaf(): + raise_mismatch_error( + f"`{tree_name}{rendered_path}` is a {treespec.type}, " + f"but `dynamic_shapes{rendered_path}` is not" + ) + if treespec.type != other_treespec.type: + raise_mismatch_error( + f"`{tree_name}{rendered_path}` is a {treespec.type}, " + f"but `dynamic_shapes{rendered_path}` is a {other_treespec.type}" + ) + if treespec.num_children != other_treespec.num_children: + raise_mismatch_error( + f"`{tree_name}{rendered_path}` has {treespec.num_children} elements, " + f"but `dynamic_shapes{rendered_path}` has {other_treespec.num_children} elements" + ) + if treespec.type is dict: + # context, children could be out of order + if set(treespec.context) != set(other_treespec.context): + raise_mismatch_error( + f"`{tree_name}{rendered_path}` has keys {treespec.context}, " + f"but `dynamic_shapes{rendered_path}` has keys {other_treespec.context}" + ) + _remap = dict( + zip(other_treespec.context, 
other_treespec.children()) + ) + other_children = [_remap[k] for k in treespec.context] + else: + other_children = other_treespec.children() + for i, (child, other_child) in enumerate( + zip(treespec.children(), other_children) + ): + _compare( + child, + other_child, + path + (_key(treespec.type, treespec.context, i),), + ) + + treespec = tree_structure(tree, is_leaf=is_leaf) + for other_tree in dynamic_shapes: + other_treespec = tree_structure(other_tree, is_leaf) + _compare(treespec, other_treespec, ()) + raise + + +def _combine_args(f, args, kwargs) -> dict[str, Any]: + # combine args and kwargs following the signature of f, as it happens + # in the body of f when called with *args, **kwargs + if isinstance(f, ExportedProgram): + f = f.module() + + signature = ( + inspect.signature(f.forward) + if isinstance(f, torch.nn.Module) + else inspect.signature(f) + ) + kwargs = kwargs if kwargs is not None else {} + return signature.bind(*args, **kwargs).arguments + + +class ShapesCollection: + """ + Builder for dynamic_shapes. + Used to assign dynamic shape specifications to tensors that appear in inputs. + + This is useful particularly when :func:`args` is a nested input structure, and it's + easier to index the input tensors, than to replicate the structure of :func:`args` in + the :func:`dynamic_shapes` specification. + + Example:: + + args = {"x": tensor_x, "others": [tensor_y, tensor_z]} + + dim = torch.export.Dim(...) + dynamic_shapes = torch.export.ShapesCollection() + dynamic_shapes[tensor_x] = (dim, dim + 1, 8) + dynamic_shapes[tensor_y] = {0: dim * 2} + # This is equivalent to the following (now auto-generated): + # dynamic_shapes = {"x": (dim, dim + 1, 8), "others": [{0: dim * 2}, None]} + + torch.export(..., args, dynamic_shapes=dynamic_shapes) + + To specify dynamism for integers, we need to first wrap the integers using + _IntWrapper so that we have a "unique identification tag" for each integer. 
+ + Example:: + + args = {"x": tensor_x, "others": [int_x, int_y]} + # Wrap all ints with _IntWrapper + mapped_args = pytree.tree_map_only(int, lambda a: _IntWrapper(a), args) + + dynamic_shapes = torch.export.ShapesCollection() + dynamic_shapes[tensor_x] = (dim, dim + 1, 8) + dynamic_shapes[mapped_args["others"][0]] = Dim.DYNAMIC + + # This is equivalent to the following (now auto-generated): + # dynamic_shapes = {"x": (dim, dim + 1, 8), "others": [Dim.DYNAMIC, None]} + + torch.export(..., args, dynamic_shapes=dynamic_shapes) + """ + + def __init__(self): + self._shapes = {} + + def __setitem__(self, t, shape): + assert isinstance(t, (torch.Tensor, _IntWrapper)), ( + f"Cannot assign shape to non-tensor or non-_IntWrapper type {type(t)}" + ) + + # TODO(avik): check that shape is indeed a Shape + + t_id = id(t) + if t_id in self._shapes: + _shape = self._shapes[t_id] + assert shape == _shape, ( + f"Shapes assigned to input do not match: expected {_shape}, got {shape}" + ) + else: + self._shapes[id(t)] = shape + + def __getitem__(self, t): + t_id = id(t) + if t_id not in self._shapes: + self._shapes[t_id] = {} + return self._shapes[t_id] + + def __len__(self): + return len(self._shapes) + + def dynamic_shapes(self, m, args, kwargs=None): + """ + Generates the :func:`dynamic_shapes` pytree structure according to :func:`args` and :func:`kwargs`. + """ + + t_ids = set() + + def find_shape(path, t): + t_id = id(t) + if t_id in self._shapes: + t_ids.add(t_id) + return self._shapes[t_id] + else: + return None + + combined_args = _combine_args(m, args, kwargs) + dynamic_shapes = _tree_map_with_path(find_shape, combined_args) + if any(t_id not in t_ids for t_id in self._shapes): + raise ValueError( + "Some tensors that were assigned shapes were not found in args. " + "Maybe such tensors were copied when passing them as args? " + "Maybe such tensors are contained in classes that were not registered with pytree?" 
+ ) + return dynamic_shapes + + +class AdditionalInputs: + """ + Infers dynamic_shapes based on additional inputs. + + This is useful particularly for deployment engineers who, on the one hand, may + have access to ample testing or profiling data that can provide a fair sense of + representative inputs for a model, but on the other hand, may not know enough + about the model to guess which input shapes should be dynamic. + + Input shapes that are different than the original are considered dynamic; conversely, + those that are the same as the original are considered static. Moreover, we verify + that the additional inputs are valid for the exported program. This guarantees that + tracing with them instead of the original would have generated the same graph. + + Example:: + + args0, kwargs0 = ... # example inputs for export + + # other representative inputs that the exported program will run on + dynamic_shapes = torch.export.AdditionalInputs() + dynamic_shapes.add(args1, kwargs1) + ... + dynamic_shapes.add(argsN, kwargsN) + + torch.export(..., args0, kwargs0, dynamic_shapes=dynamic_shapes) + """ + + def __init__(self): + self._examples = [] + + def add(self, args, kwargs=None): + """ + Additional input :func:`args` and :func:`kwargs`. + """ + + assert type(args) is tuple, f"Representative args {args} must be a tuple" + assert kwargs is None or type(kwargs) is dict, ( + f"Representative kwargs {kwargs} must be None or a dict" + ) + self._examples.append((args, kwargs)) + + def dynamic_shapes(self, m, args, kwargs=None): + """ + Infers a :func:`dynamic_shapes` pytree structure by merging shapes of the + original input :func:`args` and :func:`kwargs` and of each additional input + args and kwargs. 
+ """ + + dynamic_shapes, *other_dynamic_shapes = [ + _tree_map_with_path( + lambda path, t: tuple(t.shape) if isinstance(t, torch.Tensor) else t, + _combine_args(m, args, kwargs), + ) + for args, kwargs in [(args, kwargs), *self._examples] + ] + + def _mark_dynamism(v, *other_vs): + if not all(type(v) is type(other) for other in other_vs): + raise ValueError( + "The following inputs were found to have differing types, " + f"so they cannot be marked as dynamic: {(v,) + other_vs}." + ) + + if isinstance(v, int) and not isinstance(v, bool): + if all(other_v == v for other_v in other_vs): + return None + else: + return Dim.DYNAMIC + else: + if not all(other_v == v for other_v in other_vs): + raise ValueError( + "The following inputs were found to have differing values, " + f"but they cannot be marked as dynamic: {(v,) + other_vs}." + ) + return None + + return tree_map( + _mark_dynamism, + dynamic_shapes, + *other_dynamic_shapes, + is_leaf=lambda i: type(i) is int, + ) + + def verify(self, ep): + """ + Verifies that an exported program is valid for each additional input. + """ + + epm = ep.module() + for args, kwargs in self._examples: + torch.export._unlift._check_input_constraints_for_module( + epm, args, kwargs or {} + ) + + +def _warn_on_None_dynamic_shape_dimension(): + msg = ( + "Using None as a dynamic shape dimension is deprecated. " + "Please use Dim.STATIC instead" + ) + # TODO(avik): raise an error in the future + log.warning(msg) + + +def _check_dynamic_shapes( + combined_args: dict[str, Any], + dynamic_shapes: dict[str, Any] | tuple[Any] | list[Any] | None, +): + """ + Checks the dynamic_shapes specification for correctness, + using combined args + kwargs as reference for inputs structure. 
+ """ + from torch._dynamo.exc import UserError, UserErrorType + + if dynamic_shapes is None or len(dynamic_shapes) == 0: + return + if isinstance(dynamic_shapes, (tuple, list)): + combined_args = type(dynamic_shapes)(combined_args.values()) # type: ignore[assignment, misc] + + bounds: dict[str, tuple[int, int]] = {} + + def check_same_bounds(dim): + if dim.__name__ in bounds: + min_, max_ = bounds[dim.__name__] + if dim.min != min_ or dim.max != max_: + this_ = Dim._readable(dim.__name__, min_, max_) + that_ = Dim._readable(dim.__name__, dim.min, dim.max) + raise UserError( + UserErrorType.INVALID_INPUT, + f"Found different definitions {this_} and {that_} " + f"for the same symbolic dimension {dim}!", + ) + else: + bounds[dim.__name__] = (dim.min, dim.max) + + def check_symbols(path, tensor, shape): + if isinstance(shape, dict): + for i, dim in shape.items(): + if isinstance(dim, Dim): + check_same_bounds(dim) + elif dim is None: + _warn_on_None_dynamic_shape_dimension() + elif not (isinstance(dim, (int, _DimHint))): + raise UserError( + UserErrorType.INVALID_INPUT, + f"Unexpected dimension mapped to index {i} in input tensor shape {shape} " + f"specified at `dynamic_shapes{keystr(path)}` " + f"(expected None, an int, a Dim, Dim.AUTO, Dim.STATIC, or Dim.DYNAMIC, " + f" but got {dim!r} instead)", + case_name="dynamic_shapes_validation", + ) + elif isinstance(shape, (tuple, list)): + if len(shape) != len(tensor.shape): + raise UserError( + UserErrorType.INVALID_INPUT, + f"Expected dynamic shape spec {shape} specified at `dynamic_shapes{keystr(path)}` " + f"to have the same length as the actual tensor shape {tensor.shape} " + f"(expected {len(tensor.shape)}, but got {len(shape)} instead)", + case_name="dynamic_shapes_validation", + ) + for i, dim in enumerate(shape): + if isinstance(dim, Dim): + check_same_bounds(dim) + elif dim is None: + _warn_on_None_dynamic_shape_dimension() + elif not (isinstance(dim, (int, _DimHint))): + raise UserError( + 
UserErrorType.INVALID_INPUT, + f"Unexpected dimension #{i} in input tensor shape {shape} " + f"specified at `dynamic_shapes{keystr(path)}` " + f"(expected None, an int, a Dim, Dim.AUTO, Dim.STATIC, or Dim.DYNAMIC, " + f"but got {dim!r} instead)", + case_name="dynamic_shapes_validation", + ) + elif shape is not None: + raise UserError( + UserErrorType.INVALID_INPUT, + f"Unexpected input tensor shape {shape} specified at `dynamic_shapes{keystr(path)}` " + f"(expected either a list/tuple of dimensions, or a dict mapping indices to dimensions," + f" where each dimension is an int, a Dim, Dim.AUTO, Dim.STATIC, or Dim.DYNAMIC)", + case_name="dynamic_shapes_validation", + ) + + assert isinstance(dynamic_shapes, (dict, tuple, list)) + if isinstance(dynamic_shapes, dict): + got_keys = list(dynamic_shapes.keys()) + expected_arg_names = list(combined_args.keys()) + if sorted(got_keys) != sorted(expected_arg_names): + msg = ( + f"When `dynamic_shapes` is specified as a dict, its top-level keys " + f"must be the arg names {expected_arg_names} of `inputs`, but " + f"here they are {got_keys}. " + ) + if ( + len(combined_args) == 1 + and expected_arg_names[0] not in got_keys + and isinstance(combined_args[expected_arg_names[0]], dict) + ): + msg += ( + "Since here `inputs` is a list/tuple enclosing a single dict, " + "maybe you just forgot to enclose `dynamic_shapes` in a list/tuple?" + ) + else: + msg += ( + "Alternatively, you could also ignore arg names entirely " + "and specify `dynamic_shapes` as a list/tuple matching `inputs`." + ) + raise UserError( + UserErrorType.INVALID_INPUT, msg, case_name="dynamic_shapes_validation" + ) + + def check_shape(path, t, dynamic_shape): + if isinstance(t, torch.Tensor): + check_symbols(path, t, dynamic_shape) + elif isinstance(t, _IntWrapper): + if isinstance(dynamic_shape, _Dim): + raise ValueError( + "Unable to specify input integers as dynamic through named " + "Dims. Please use Dim.AUTO/DYNAMIC instead." 
+ ) + assert dynamic_shape is None or isinstance(dynamic_shape, (int, _DimHint)) + else: + if dynamic_shape is not None: + rendered_path = keystr(path) + raise UserError( + UserErrorType.INVALID_INPUT, + f"Cannot associate shape {dynamic_shape} specified at `dynamic_shapes{rendered_path}` " + f"to non-tensor type {type(t)} at `inputs{rendered_path}` (expected None)", + case_name="dynamic_shapes_validation", + ) + + _tree_map_with_path(check_shape, combined_args, dynamic_shapes, tree_name="inputs") + + +def _process_dynamic_shapes( + combined_args: dict[str, Any], + dynamic_shapes: dict[str, Any] | tuple[Any] | list[Any] | None, +) -> list[Constraint]: + """ + Reads the dynamic_shapes specification and produces a list of constraints. + """ + from torch._dynamo.exc import UserError, UserErrorType + + if dynamic_shapes is None or len(dynamic_shapes) == 0: + # we run with dynamic by default, so no need to produce constraints + return [] + if isinstance(dynamic_shapes, (tuple, list)): + combined_args = type(dynamic_shapes)(combined_args.values()) # type: ignore[assignment, misc] + + # map of Dim names representing input shape dimensions to constraints on them + symbols: dict[str, list[Constraint]] = defaultdict(list) + # track roots that do not directly represent input shape dimensions + phantom_roots: dict[str, _PhantomRoot] = {} + derived_constraints_with_phantom_root: list[_DerivedConstraint] = [] + # list of constraints to return + constraints: list[Constraint] = [] + + def to_constraint(dim, tensor, i): + import sympy + + from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint + from torch.utils._sympy.solve import try_solve + from torch.utils._sympy.value_ranges import ValueRanges + + def root_value(): + # given tensor.shape[i] is the value of dim = fn(root), + # find the value of root + symbol = sympy.Symbol(dim.root.__name__, integer=True) + expr = dim.fn(symbol) + solution = try_solve(sympy.Eq(expr, tensor.shape[i]), symbol) + if solution is 
not None: + return int(solution[1]) + else: + raise UserError( # noqa: B904 + UserErrorType.CONSTRAINT_VIOLATION, + f"Expected shape[{i}] = {tensor.shape[i]} of input Tensor to be " + f"of the form {expr}, where {symbol} is an integer", + ) + + if isinstance(dim, _DerivedDim): + # generate a _DerivedConstraint where the root is: + # - either a _ConstraintTarget (if dim.root directly describes an input shape) + # - or a _PhantomRoot (otherwise) + dim_root = dim.root # type: ignore[attr-defined] + if dim_root.__name__ in symbols: + # root represents an input shape dimension + root_constraint = symbols[dim_root.__name__][0] + root = _ConstraintTarget( + root_constraint.t_id, + root_constraint.dim, + ) + elif dim_root.__name__ not in phantom_roots: + # create a phantom root + root = _PhantomRoot( # type: ignore[assignment] + name=dim_root.__name__, + constraint_range=StrictMinMaxConstraint( + vr=ValueRanges(lower=dim_root.min, upper=dim_root.max), + warn_only=False, + ), + val=root_value(), + ) + phantom_roots[dim_root.__name__] = root # type: ignore[assignment] + else: + root = phantom_roots[dim_root.__name__] # type: ignore[assignment] + constraint = _DerivedConstraint( + id(tensor), + i, + dim.__name__, + StrictMinMaxConstraint( + vr=ValueRanges(lower=dim.min, upper=dim.max), + warn_only=False, + ), + root, + dim.fn, # type: ignore[attr-defined] + ) + if isinstance(root, _PhantomRoot): + # NOTE(avik): since we have not processed all inputs yet, we may replace this + # with a root that does represent an input shape dimension later (see below) + derived_constraints_with_phantom_root.append(constraint) + elif isinstance(dim, _StaticDim): + constraint = _Constraint( # type: ignore[assignment] + id(tensor), + i, + dim.__name__, + StrictMinMaxConstraint( + vr=ValueRanges(lower=dim.value, upper=dim.value), # type: ignore[attr-defined] + warn_only=False, + ), + ) + else: + assert isinstance(dim, Dim) + constraint = _Constraint( # type: ignore[assignment] + id(tensor), + i, 
+ dim.__name__, + StrictMinMaxConstraint( + vr=ValueRanges(lower=dim.min, upper=dim.max), # type: ignore[attr-defined] + warn_only=False, + ), + ) + return constraint + + def _parse_tensor_dim(tensor, idx, dim) -> None: + def _create_static_dim(tensor, i, value): + return _StaticDim(value) + + if isinstance(dim, (int, Dim)): + if isinstance(dim, int): + dim = _create_static_dim(tensor, idx, dim) + constraint = to_constraint(dim, tensor, idx) + symbols[dim.__name__].append(constraint) + elif isinstance(dim, _DimHint): + if dim.type == _DimHintType.AUTO: + torch._dynamo.maybe_mark_dynamic(tensor, idx) + elif dim.type == _DimHintType.STATIC: + torch._dynamo.mark_static(tensor, idx) + elif dim.type == _DimHintType.DYNAMIC: + torch._dynamo.mark_dynamic(tensor, idx) + constraints.append(_RelaxedConstraint(id(tensor), idx)) + elif dim is None: + torch._dynamo.mark_static(tensor, idx) + + def update_symbols(path, tensor, shape): + # clean out decorators from user side, or previous export call + # we also delete these attributes in non_strict_utils.py/make_constraints() + tensor._dynamo_weak_dynamic_indices = set() + tensor._dynamo_dynamic_indices = set() + tensor._dynamo_dynamic_range = set() + tensor._dynamo_static_indices = set() + tensor._dynamo_unbacked_indices = set() + + if isinstance(shape, dict): + for i, dim in shape.items(): + _parse_tensor_dim(tensor, i, dim) + elif isinstance(shape, (tuple, list)): + for i, dim in enumerate(shape): + _parse_tensor_dim(tensor, i, dim) + elif shape is None: + for i in range(tensor.dim()): + _parse_tensor_dim(tensor, i, None) + + def assoc_shape(path, t, dynamic_shape): + if isinstance(t, torch.Tensor): + update_symbols(path, t, dynamic_shape) + elif isinstance(t, _IntWrapper): + # If tensor dimensions are marked as dynamic, the tensors themselves + # get marked using mark_dynamic. 
However since we can't mark + # integers as dynamic, we first wrap integers in this class, and + # then set the `dim` field of the class with the dynamic shapes dim + # to mark the integer as dynamic. + t.dynamism = dynamic_shape + + _tree_map_with_path(assoc_shape, combined_args, dynamic_shapes, tree_name="inputs") + + for derived_constraint_with_phantom_root in derived_constraints_with_phantom_root: + phantom_root_name = derived_constraint_with_phantom_root.root.name # type: ignore[union-attr] + if phantom_root_name in symbols: + # We found an input shape dimension corresponding to this name, so we + # do not need a phantom symbol for it after all. + # NOTE(avik): Overall we want to maintain the invariant that roots that + # are phantom symbols are really "phantom," i.e., they cannot be represented + # by any input source. This is important when we are deciding derived equalities, + # since we can focus our attention exclusively on input sources: deciding + # derived equalities involving phantom symbols are, in comparison, trivial. + derived_constraint_with_phantom_root.root = symbols[phantom_root_name][0] + + for dynamic_dims in symbols.values(): + constraints.extend(dynamic_dims) + + return constraints + + +def _get_dim_name_mapping( + dynamic_shapes: dict[str, Any] | tuple[Any] | list[Any] | None, +): + name_to_dim = {} + for dim in tree_iter(dynamic_shapes, is_leaf=lambda x: isinstance(x, Dim)): + if dim is None: + # NOTE: this must denote a non-Tensor or automatic at this point. 
+ continue + if isinstance(dim, int): + continue + elif isinstance(dim, Dim): + name_to_dim[dim.__name__] = dim + if isinstance(dim, _DerivedDim): + name_to_dim[dim.root.__name__] = dim.root # type: ignore[attr-defined] + else: + assert isinstance(dim, _DimHint) + return name_to_dim + + +def refine_dynamic_shapes_from_suggested_fixes( + msg: str, + dynamic_shapes: dict[str, Any] | tuple[Any] | list[Any], +) -> dict[str, Any] | tuple[Any] | list[Any]: + """ + When exporting with :func:`dynamic_shapes`, export may fail with a ConstraintViolation error if the specification + doesn't match the constraints inferred from tracing the model. The error message may provide suggested fixes - + changes that can be made to :func:`dynamic_shapes` to export successfully. + + Example ConstraintViolation error message:: + + Suggested fixes: + + dim = Dim('dim', min=3, max=6) # this just refines the dim's range + dim = 4 # this specializes to a constant + dy = dx + 1 # dy was specified as an independent dim, but is actually tied to dx with this relation + + This is a helper function that takes the ConstraintViolation error message and the original :func:`dynamic_shapes` spec, + and returns a new :func:`dynamic_shapes` spec that incorporates the suggested fixes. 
+ + Example usage:: + + try: + ep = export(mod, args, dynamic_shapes=dynamic_shapes) + except torch._dynamo.exc.UserError as exc: + new_shapes = refine_dynamic_shapes_from_suggested_fixes( + exc.msg, dynamic_shapes + ) + ep = export(mod, args, dynamic_shapes=new_shapes) + + """ + + import re + + import sympy + + from torch._dynamo.exc import UserError, UserErrorType + from torch.fx.experimental.symbolic_shapes import _is_supported_equivalence + + try: + shape_fixes_msg = msg.split("Suggested fixes:")[1].strip() + except Exception as exc: + raise UserError( + UserErrorType.INVALID_INPUT, + "Suggested fixes not found in error message given to refine_dynamic_shapes_from_suggested_fixes()", + ) from exc + + # build shape_fixes dictionary + shape_fixes = {} + for fix in shape_fixes_msg.split("\n"): + fix = fix.strip() + if match := re.match(r"(.*) = Dim\('(.*)'.*\)", fix): + name = match.group(1) + _min, _max = None, None + if match_min := re.match(r".* = Dim\('.*', min\=([0-9]+).*\)", fix): + _min = int(match_min.group(1)) + if match_max := re.match(r".* = Dim\('.*'.*max\=([0-9]+)\)", fix): + _max = int(match_max.group(1)) + shape_fixes[name] = Dim(name, min=_min, max=_max) + else: + name, expr = fix.split(" = ") + expr = sympy.sympify(expr) + if isinstance(expr, sympy.Number): + # static, integer + shape_fixes[name] = int(expr) # type: ignore[assignment] + else: + # relation or derived dim + shape_fixes[name] = expr + + name_to_dim = _get_dim_name_mapping(dynamic_shapes) + + # track derived dim roots + roots: set[str] = set() + for k, c in shape_fixes.items(): + assert isinstance(c, (int, Dim, _DerivedDim, sympy.Expr)) + if isinstance(c, sympy.Expr): # check dim/derived dim expression + assert _is_supported_equivalence(c) + shape_fixes[k] = c + roots.add(str(next(iter(c.free_symbols)))) + if isinstance(c, _DerivedDim): + roots.add(c.root.__name__) # type: ignore[attr-defined] + + # check keys are existing dims or new roots + for k in shape_fixes: + assert k in 
name_to_dim or k in roots + + # cache so we don't produce multiple derived dim objects + derived_dim_cache: dict[str, _DerivedDim] = {} + + def apply_fixes(path, dim, dummy): + if dim is None or isinstance(dim, int): # not dynamic + return dim + elif dim.__name__ in shape_fixes: # directly fix + fix = shape_fixes[dim.__name__] + if isinstance(fix, sympy.Expr): # now derived or related + if str(fix) in derived_dim_cache: + return derived_dim_cache[str(fix)] + else: + symbol = next(iter(fix.free_symbols)) + # try to locate symbol + if symbol.name in shape_fixes: + root = shape_fixes[symbol.name] + else: + assert symbol.name in name_to_dim + root = name_to_dim[symbol.name] + # figure out value of fix + modulus, remainder = sympy.polys.polytools.div(fix, symbol) + dim = root + if modulus != 1: + dim = int(modulus) * dim + if remainder != 0: + dim = dim + int(remainder) + derived_dim_cache[str(fix)] = dim + return dim + else: + return fix + elif isinstance(dim, _DerivedDim) and dim.root.__name__ in shape_fixes: # type: ignore[attr-defined] + if dim.__name__ in derived_dim_cache: + return derived_dim_cache[dim.__name__] + else: # evaluate new derived value based on root + _dim = dim.fn(shape_fixes[dim.root.__name__]) # type: ignore[attr-defined] + derived_dim_cache[dim.__name__] = _dim + return _dim + return dim # unchanged dim + + return _tree_map_with_path(apply_fixes, dynamic_shapes, dynamic_shapes) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/exported_program.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/exported_program.py new file mode 100644 index 0000000000000000000000000000000000000000..47385fa2f088209d0e6d1c0aa8094885c39f6c5e --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/exported_program.py @@ -0,0 +1,1719 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +import contextlib +import copy +import dataclasses +import functools +import 
operator +import types +import warnings +from collections import defaultdict +from collections.abc import Callable, Iterator +from contextlib import contextmanager +from typing import Any, final, NamedTuple, TYPE_CHECKING + +from torch._guards import tracing, TracingContext +from torch._higher_order_ops.utils import autograd_not_implemented +from torch._library.fake_class_registry import FakeScriptObject +from torch._subclasses.fake_impls import ( + _deregister_op_impl, + _is_op_registered_to_fake_rule, + register_op_impl, +) +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.fx._symbolic_trace import _ConstantAttributeType +from torch.fx._utils import first_call_function_nn_module_stack +from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo +from torch.fx.immutable_collections import immutable_dict, immutable_list +from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts + + +if TYPE_CHECKING: + # Import the following modules during type checking to enable code intelligence features, + # such as auto-completion in tools like pylance, even when these modules are not explicitly + # imported in user code. 
+ + import sympy + + from torch.utils._sympy.value_ranges import ValueRanges + +import torch +import torch.utils._pytree as pytree +from torch._export.utils import ( + _build_cache, + _collect_all_valid_cia_ops, + _collect_and_set_constant_attrs, + _collect_param_buffer_metadata, + _detect_fake_mode_from_gm, + _fakify_params_buffers, + _get_decomp_for_cia, + _is_preservable_cia_op, + _name_hoo_subgraph_placeholders, + _override_graph_signature_for_temp_registered_constants, + _overwrite_signature_for_non_persistent_buffers, + _populate_param_buffer_metadata_to_new_gm, + _register_constants_as_buffers, + _rename_without_collisions, + _special_op_to_preserve_cia, + placeholder_naming_pass, +) +from torch._export.verifier import Verifier +from torch._guards import detect_fake_mode +from torch._subclasses.fake_tensor import unset_fake_temporarily +from torch.export._tree_utils import is_equivalent, reorder_kwargs +from torch.export.decomp_utils import CustomDecompTable +from torch.fx._compatibility import compatibility +from torch.fx.passes.infra.pass_base import PassResult +from torch.fx.passes.infra.pass_manager import PassManager + +from .graph_signature import ( # noqa: F401 + ArgumentSpec, + ConstantArgument, + CustomObjArgument, + ExportGraphSignature, + InputKind, + InputSpec, + OutputKind, + OutputSpec, + SymBoolArgument, + SymFloatArgument, + SymIntArgument, + TensorArgument, + TokenArgument, +) + + +__all__ = [ + "ExportedProgram", + "ModuleCallEntry", + "ModuleCallSignature", + "default_decompositions", +] + + +PassType = Callable[[torch.fx.GraphModule], PassResult | None] + + +@dataclasses.dataclass +class ModuleCallSignature: + inputs: list[ArgumentSpec] + outputs: list[ArgumentSpec] + in_spec: pytree.TreeSpec + out_spec: pytree.TreeSpec + forward_arg_names: list[str] | None = None + + def replace_all_uses_with(self, original_node, new_node): + for i in self.inputs: + if i.name == original_node.name: + i.name = new_node.name + for o in self.outputs: + if 
o.name == original_node.name: + o.name = new_node.name + + +@dataclasses.dataclass +class ModuleCallEntry: + fqn: str + signature: ModuleCallSignature | None = None + + +def _disable_prexisiting_fake_mode(fn): + @functools.wraps(fn) + def wrapper(*args, **kwargs): + with unset_fake_temporarily(): + return fn(*args, **kwargs) + + return wrapper + + +def _fx_collection_equivalence_fn( + spec1_type: type | None, + spec1_context: pytree.Context, + spec2_type: type | None, + spec2_context: pytree.Context, +) -> bool: + """Treat containers and their immutable variants as the same type. Otherwise + compare as normal. + """ + if spec1_type is None or spec2_type is None: + return spec1_type is spec2_type and spec1_context == spec2_context + + if issubclass(spec1_type, (dict, immutable_dict)) and issubclass( + spec2_type, (dict, immutable_dict) + ): + return spec1_context == spec2_context + + if issubclass(spec1_type, (list, immutable_list)) and issubclass( + spec2_type, (list, immutable_list) + ): + return spec1_context == spec2_context + + return spec1_type is spec2_type and spec1_context == spec2_context + + +# This list is compiled from DispatchKey.cpp. +# The idea is that we use these keys to override +# CIA decomp in export +_AUTOGRAD_ALIAS_BACKEND_KEYS_TO_OVERRIDE = [ + torch._C.DispatchKey.AutogradCPU, + torch._C.DispatchKey.AutogradCUDA, + torch._C.DispatchKey.AutogradMeta, + torch._C.DispatchKey.AutogradXLA, + torch._C.DispatchKey.AutogradLazy, + torch._C.DispatchKey.AutogradIPU, + torch._C.DispatchKey.AutogradXPU, + torch._C.DispatchKey.AutogradMPS, + torch._C.DispatchKey.AutogradHPU, + torch._C.DispatchKey.AutogradPrivateUse1, + torch._C.DispatchKey.AutogradPrivateUse2, + torch._C.DispatchKey.AutogradPrivateUse3, +] + + +# This list is compiled from DispatchKey.cpp. 
+# The idea is that we use these keys to add +# python kernels that directly uses default +# CIA decomp +# See NOTE Registering old CIA to Backend kernel +_BACKEND_KEYS_TO_OVERRIDE = [ + torch._C.DispatchKey.CPU, + torch._C.DispatchKey.CUDA, + torch._C.DispatchKey.Meta, + torch._C.DispatchKey.XLA, + torch._C.DispatchKey.Lazy, + torch._C.DispatchKey.IPU, + torch._C.DispatchKey.XPU, + torch._C.DispatchKey.MPS, + torch._C.DispatchKey.HPU, +] + + +@contextmanager +def _override_composite_implicit_decomp(cia_ops_to_callable): + # This function overrides CompositeImplicitAutograd decomp for + # functional composite ops that user specified. Ideally we want to not-decompose + # ALL composite ops but today's C++ functinalization relies on + # the fact that it is working with the opset after decomp is run. + # Hence we can only do it for functional ops. One caveat is that + # there are some composite ops that lie about their schema (claimed to be + # functional but not really aka dropout), for these cases, we just decompose. + saved_tables = {} + patched_ops = set() + for op_overload, decomp_callable in cia_ops_to_callable.items(): + saved_tables[op_overload] = op_overload.py_kernels.copy() + patched_ops.add(op_overload) + for override_dispatch_key in _AUTOGRAD_ALIAS_BACKEND_KEYS_TO_OVERRIDE: + if override_dispatch_key not in op_overload.py_kernels: + # TODO (tmanlaibaatar)https://github.com/pytorch/pytorch/issues/129430 + op_overload.py_impl(override_dispatch_key)( + autograd_not_implemented(op_overload, deferred_error=True) + ) + # See NOTE: Registering old CIA to Backend kernel + # It is important that we cache this before we override py_kernels. 
+ orig_cia_callable = _get_decomp_for_cia(op_overload) + if torch._C.DispatchKey.CompositeImplicitAutograd in op_overload.py_kernels: + del op_overload.py_kernels[torch._C.DispatchKey.CompositeImplicitAutograd] + + op_overload.py_impl(torch._C.DispatchKey.CompositeImplicitAutograd)( + decomp_callable + ) + + # [NOTE] Directly registering fake tensor rule to CIA ops + # The problem we are facing here is if your CIA custom rule + # says we want to preserve the op, we will return NotImplemented. + # Unfortunately, this will invoke meta device tracing in fake tensor + # resulting in divergent behaviour for CIA kernels that has device based + # branching (one case is torch.ops.aten.scaled_dot_product.attention) + # To get around this issue, we register direct fake impl so that we + # run the kernel before we actually try to decompose the op in FakeTensorMode. + # Note that is a no-op in most cases, because: + # 1) In post dispatch tracing, CIA would have already decomposed + # 2) Most CIA impl are device agnostic. + def _force_dispatch_to_orig_cia_callable(fake_tensor_mode, op, *args, **kwargs): + orig_cia_callable = kwargs["original_callable"] + del kwargs["original_callable"] + with fake_tensor_mode: + return orig_cia_callable(*args, **kwargs) + + if not _is_op_registered_to_fake_rule(op_overload): + register_op_impl(op_overload)( + functools.partial( + _force_dispatch_to_orig_cia_callable, + original_callable=orig_cia_callable, + ) + ) + + for key in _BACKEND_KEYS_TO_OVERRIDE: + if key not in op_overload.py_kernels: + # [NOTE] Registering old CIA to Backend kernel + # We always register original CIA behavior to the backend keys kernel + # The reason is when we are fake tensor prop-ing or executing real kernel, + # we end up calling an operator on respective backend, which in python dispatcher, + # will resolve into CIA key. (see resolve_key in torch/_ops.py) + # As a result, this CIA now will call into the custom user defined + # CIA which can cause a problem. 
+ # To make it more concrete, the case we are handling is: + # (1) there is a tensor constant we are performing constant propagation + # on during tracing + # (2) we invoke an op underneath autograd (either because we are below autograd, + # or we are tracing in inference mode), so one of the backend keys gets hit + # (3) the op we are invoking has a CIA impl that normally runs in eager mode + # (and the user wants to tweak this CIA impl during tracing, but during + # const-prop we want the original CIA to run + op_overload.py_impl(key)(orig_cia_callable) + + try: + yield + finally: + for op in patched_ops: + op.py_kernels.clear() + op.py_kernels.update(saved_tables[op]) + op._dispatch_cache.clear() + _deregister_op_impl(op) + + +def _split_decomp_table_to_cia_and_python_decomp( + decomp_table: dict[torch._ops.OperatorBase, Callable], +) -> tuple[dict[torch._ops.OperatorBase, Callable], ...]: + all_preservable_cia_ops = set(_collect_all_valid_cia_ops()) + cia_ops_to_callable = {} + + for op in list(decomp_table.keys()): + # TODO we are silently allowing non-safe(non-functional) ops through a crack + # due to core aten decomp table having non-functional entries. Once we have + # a tighter check around core aten decomp, we should warn users about them. + # Tracking issue: (https://github.com/pytorch/pytorch/issues/135759) + + # if it is a valid CIA op we can mess with in export, we check if it is: + # 1. Has been marked as to be decomposed. Example: + # decomp_table = decomp_table_to_core_aten() + # del decomp_table[aten.linear] + # In this case, user says decompose everything except for aten.linear + # 2. Has been marked with custom decomp behaviour. Example: + # decomp_table = {aten.linear: some_op} + # For (1), we want to remove all the CIA ops that weren't handled by user as + # it suggests they are safe to decompose, so we should remove from preservable_list. + # for (2), we just plumb the custom decomp to AOTDIspatcher. 
+ # In both cases, we want to remove this CIA op from the decomp_table as it is special + # handled. + if op in all_preservable_cia_ops: + cia_ops_to_callable[op] = decomp_table[op] + all_preservable_cia_ops.remove(op) + del decomp_table[op] + # If it is a custom op, we want to still preserve or do whatever + # with it if it is a functional CIA. The reason we don't remove + # from CIA list is because we don't query custom ops. + elif _is_preservable_cia_op(op): + op_name = op.name() + assert not op_name.startswith("aten"), "This should be a custom op" + cia_ops_to_callable[op] = decomp_table[op] + + # If we reached here, it means user intentionally deleted these CIA ops from + # decomp table. + for k in all_preservable_cia_ops: + cia_ops_to_callable[k] = _special_op_to_preserve_cia + + return cia_ops_to_callable, decomp_table + + +def default_decompositions() -> "CustomDecompTable": + """ + This is the default decomposition table which contains decomposition of + all ATEN operators to core aten opset. 
Use this API together with + :func:`run_decompositions()` + """ + return CustomDecompTable() + + +def _decompose_and_get_gm_with_new_signature_constants( + ep: "ExportedProgram", + *, + cia_to_decomp: dict[torch._ops.OperatorBase, Callable], + python_decomp_table: dict[torch._ops.OperatorBase, Callable], + joint_loss_index: int | None, + decompose_custom_triton_ops, +): + from torch._export.passes.lift_constants_pass import _materialize_and_lift_constants + from torch._functorch.aot_autograd import aot_export_module + from torch.export._trace import ( + _disable_custom_triton_op_functional_decomposition, + _export_to_aten_ir, + _ignore_backend_decomps, + _verify_nn_module_stack, + _verify_placeholder_names, + _verify_stack_trace, + ) + from torch.fx.experimental.symbolic_shapes import ShapeEnv + + def _is_joint_ir_decomp(ep, joint_loss_index): + return ( + joint_loss_index is not None + or ep.graph_signature.backward_signature is not None + ) + + if not _is_joint_ir_decomp(ep, joint_loss_index): + mod = ep.module() + + wrapped_params_buffers = { + **dict(mod.named_parameters(remove_duplicate=False)), + **dict(mod.named_buffers(remove_duplicate=False)), + } + + from torch._functorch._aot_autograd.subclass_parametrization import ( + unwrap_tensor_subclass_parameters, + ) + + # [NOTE] Unwrapping subclasses AOT + # In torch.compile, the subclass unwrapping/wrapping happen at runtime + # but at export, this is impossible as it is intended to be run on + # C++ environment. As a result, we unwrap subclass parameters AOT. After this, + # ExportedProgram state_dict won't be same as eager model because eager model + # could have subclass weights while ExportedProgram will have desugared versions. + # This is fine because run_decompositions is supposed to specialize to post-autograd + # graph where the subclass desugaring is supposed to happen. 
+ unwrap_tensor_subclass_parameters(mod) + unwrapped_params_buffers = { + **dict(mod.named_parameters(remove_duplicate=False)), + **dict(mod.named_buffers(remove_duplicate=False)), + } + + # TODO T204030333 + fake_mode = _detect_fake_mode_from_gm(ep.graph_module) + if fake_mode is None: + fake_mode = FakeTensorMode(shape_env=ShapeEnv(), export=True) + + # Fix the graph output signature to be tuple if scalar + out_spec = mod._out_spec + + assert isinstance(mod.graph._codegen, _PyTreeCodeGen) + orig_arg_names = mod.graph._codegen.pytree_info.orig_args + + # aot_export expect the return type to always be a tuple. + assert out_spec is not None + if out_spec.type not in (list, tuple): + out_spec = pytree.treespec_tuple([out_spec]) + + mod.graph._codegen = _PyTreeCodeGen( + _PyTreeInfo( + orig_arg_names, + mod._in_spec, + out_spec, + ) + ) + + mod.recompile() + + # the exported module will store constants & non-persistent buffers such that + # retracing treats them as persistent buffers, so we inform the constants lifting pass + # and overwrite the new graph signature using the previous program. + _collect_and_set_constant_attrs(ep.graph_signature, ep.constants, mod) + + # When we have a module with constant attributes, AotDispatcher doesn't actually + # wrap them as functional tensors, because dynamo would have already made it buffer. + # In non-strict case, however, AotDispatcher can intercept constants, causing it to not + # functionalize the operators that are operating on constant tensors. Since dynamo already + # wraps constants as buffers, we temporarily register the constants as buffers and undo this + # operation after AOTDispatcher is done. 
+ temp_registered_constants = _register_constants_as_buffers( + mod, ep.state_dict, ep.graph_signature.non_persistent_buffers + ) + + # get params & buffers after excluding constants + fake_params_buffers = _fakify_params_buffers(fake_mode, mod) + + params_buffers_to_node_meta = _collect_param_buffer_metadata(mod) + + # TODO (tmanlaibaatar) Ideally run_decomp should just call _non_strict_export + # but due to special handling of constants as non-persistent buffers make it little + # difficult. But we should unify this code path together. T206837815 + from torch._export.non_strict_utils import ( + _enable_graph_inputs_of_type_nn_module, + _fakify_script_objects, + ) + + retracing_args = [] + for node in mod.graph.nodes: + if node.op == "placeholder": + if isinstance(node.meta["val"], CustomObjArgument): + real_script_obj = None + if node.meta["val"].fake_val is None: + real_script_obj = ep.constants[node.meta["val"].name] + else: + real_script_obj = node.meta["val"].fake_val.real_obj + retracing_args.append(real_script_obj) + else: + retracing_args.append(node.meta["val"]) + + tx = TracingContext(fake_mode) + + with ( + fake_mode, + _override_composite_implicit_decomp( + cia_to_decomp, + ), + _enable_graph_inputs_of_type_nn_module(ep.example_inputs), + tracing(tx), + ): + retracing_args_unwrapped = pytree.tree_unflatten( + retracing_args, mod._in_spec + ) + # this requires empty kwargs, but not in pytree.flattened format + with _fakify_script_objects( + mod, + ( + *retracing_args_unwrapped[0], + *retracing_args_unwrapped[1].values(), + ), + {}, + fake_mode, + ) as ( + patched_mod, + new_fake_args, + new_fake_kwargs, + new_fake_constant_attrs, + map_fake_to_real, + ): + aten_export_artifact = _export_to_aten_ir( + patched_mod, + new_fake_args, + new_fake_kwargs, + fake_params_buffers, + new_fake_constant_attrs, + decomp_table=python_decomp_table, + _prettify_placeholder_names=False, + decompose_custom_triton_ops=decompose_custom_triton_ops, + ) + + # 
aten_export_artifact.constants contains only fake script objects, we need to map them back + aten_export_artifact.constants = { + fqn: ( + map_fake_to_real[obj] + if isinstance(obj, FakeScriptObject) + else obj + ) + for fqn, obj in aten_export_artifact.constants.items() + } + + gm = aten_export_artifact.gm + new_graph_signature = aten_export_artifact.sig + + # In the previous step, we assume constants as buffers for AOTDispatcher to + # functianalize properly, so undo that here + new_graph_signature = ( + _override_graph_signature_for_temp_registered_constants( + new_graph_signature, temp_registered_constants + ) + ) + + _populate_param_buffer_metadata_to_new_gm( + params_buffers_to_node_meta, gm, new_graph_signature + ) + + # overwrite signature for non-persistent buffers + new_graph_signature = _overwrite_signature_for_non_persistent_buffers( + ep.graph_signature, new_graph_signature + ) + + constants = _materialize_and_lift_constants( + gm, new_graph_signature, new_fake_constant_attrs + ) + + placeholder_naming_pass( + gm, + new_graph_signature, + patched_mod, + new_fake_args, + new_fake_kwargs, + fake_params_buffers, + constants, + ) + + _verify_nn_module_stack(gm) + _verify_stack_trace(gm) + _verify_placeholder_names(gm, new_graph_signature) + + gm, new_graph_signature = _remove_unnecessary_copy_op_pass( + gm, new_graph_signature + ) + + # When we apply parameterization rule to unwrap + # subclasses, the state dict will now have different + # desugared parameters. We need to manually filter those + # and update the ep.state_dict. Ideally, we should just return + # the state dict of ep.module but ep.module only stores params + # buffers that participate in forward. If we undo this behavior, + # it would break some downstream users. 
+ new_state_dict = { + **ep.state_dict, + **{ + name: p + for name, p in unwrapped_params_buffers.items() + if name not in wrapped_params_buffers + }, + } + + for name, p in wrapped_params_buffers.items(): + # Buffers can be persistent/non-persistent + if name not in new_state_dict: + assert not isinstance(p, torch.nn.Parameter) + + if name in new_state_dict: + if name not in unwrapped_params_buffers: + new_state_dict.pop(name) + + return gm, new_graph_signature, new_state_dict + + old_placeholders = [ + node for node in ep.graph_module.graph.nodes if node.op == "placeholder" + ] + fake_args = [node.meta["val"] for node in old_placeholders] + + buffers_to_remove = [name for name, _ in ep.graph_module.named_buffers()] + for name in buffers_to_remove: + delattr(ep.graph_module, name) + + # TODO(zhxhchen17) Return the new graph_signature directly. + fake_mode_det = detect_fake_mode(fake_args) + fake_mode_ctx = contextlib.nullcontext() if fake_mode_det is None else fake_mode_det # type: ignore[assignment] + custom_triton_ops_decomposition_ctx = ( + contextlib.nullcontext + if decompose_custom_triton_ops + else _disable_custom_triton_op_functional_decomposition + ) + with ( + _ignore_backend_decomps(), + fake_mode_ctx, + _override_composite_implicit_decomp(cia_to_decomp), + custom_triton_ops_decomposition_ctx(), + ): + gm, graph_signature = aot_export_module( + ep.graph_module, + fake_args, + decompositions=python_decomp_table, + trace_joint=joint_loss_index is not None, + output_loss_index=( + joint_loss_index if joint_loss_index is not None else None + ), + ) + gm.graph.eliminate_dead_code() + + # Update the signatures with the new placeholder names in case they + # changed when calling aot_export + def update_arg(old_arg, new_ph): + if isinstance(old_arg, ConstantArgument): + return old_arg + elif isinstance(old_arg, TensorArgument): + return TensorArgument(name=new_ph.name) + elif isinstance(old_arg, SymIntArgument): + return SymIntArgument(name=new_ph.name) + elif 
isinstance(old_arg, SymFloatArgument): + return SymFloatArgument(name=new_ph.name) + elif isinstance(old_arg, SymBoolArgument): + return SymBoolArgument(name=new_ph.name) + raise RuntimeError(f"Type of old_arg not supported: {type(old_arg)}") + + new_placeholders = [node for node in gm.graph.nodes if node.op == "placeholder"] + new_outputs: tuple[torch.fx.Node, ...] = tuple(gm.graph.output_node().args[0]) # type: ignore[arg-type] + + # rename the placeholders + assert len(new_placeholders) == len(old_placeholders) + for old_ph, new_ph in zip(old_placeholders, new_placeholders): + new_ph.name = new_ph.target = old_ph.name + + # handle name collisions with newly decomposed graph nodes + name_map = {} + find_available: dict[str, int] = defaultdict(int) + used_names: set[str] = set() + for ph in new_placeholders: + name_map[ph.name] = ph.name + _build_cache(ph.name, find_available, used_names) + for node in gm.graph.nodes: + if node.op == "placeholder": + continue + node.name = _rename_without_collisions( + name_map, find_available, used_names, node.name, node.name + ) + + # propagate names to higher order op subgraphs + _name_hoo_subgraph_placeholders(gm) + + # Run this pass before creating input/output specs, since size-related CSE/DCE might affect output signature. + # Overwrite output specs afterwards. 
+ from torch._export.passes._node_metadata_hook import ( + _node_metadata_hook, + _set_node_metadata_hook, + ) + from torch._functorch._aot_autograd.input_output_analysis import _graph_output_names + + if not torch._dynamo.config.do_not_emit_runtime_asserts: + stack_trace = ( + 'File "torch/fx/passes/runtime_assert.py", line 24, ' + "in insert_deferred_runtime_asserts" + ) + shape_env = _get_shape_env(gm) + if shape_env is not None: + with _set_node_metadata_hook( + gm, + functools.partial( + _node_metadata_hook, metadata={"stack_trace": stack_trace} + ), + ): + insert_deferred_runtime_asserts( + gm, + shape_env, + f"exported program: {first_call_function_nn_module_stack(gm.graph)}", + export=True, + ) + + # update output specs + gm.recompile() + for output, name in zip(new_outputs, _graph_output_names(gm)): + if name is not None: + output.name = name + + # To match the output target with correct input for input mutations + # need to find the old to new placeholder map + old_new_placeholder_map = { + spec.arg.name: new_placeholders[i].name + for i, spec in enumerate(ep.graph_signature.input_specs) + if not isinstance(spec.arg, ConstantArgument) + } + + input_specs = [ + InputSpec( + spec.kind, + update_arg(spec.arg, new_placeholders[i]), + spec.target, + spec.persistent, + ) + for i, spec in enumerate(ep.graph_signature.input_specs) + ] + + output_specs = [] + + # handle buffer & input mutations; these appear before loss output & gradients + # (1) ep.graph_signature.input_specs tells us types of inputs + # (2) graph_signature.user_inputs tells us node input names in order + # (3) graph_signature.user_inputs_to_mutate tells us buffer & input mutations + # map (3) -> (2) for input order, -> (1) for input type + user_inputs_index = {name: i for i, name in enumerate(graph_signature.user_inputs)} + mutation_names = list(graph_signature.user_inputs_to_mutate.keys()) + assert mutation_names == [node.name for node in new_outputs[: len(mutation_names)]] + for output_name, 
input_name in graph_signature.user_inputs_to_mutate.items(): + i = user_inputs_index[input_name] + input_spec = ep.graph_signature.input_specs[i] + assert input_spec.kind in (InputKind.USER_INPUT, InputKind.BUFFER) + output_kind = ( + OutputKind.BUFFER_MUTATION + if input_spec.kind == InputKind.BUFFER + else OutputKind.USER_INPUT_MUTATION + ) + target = ( + input_spec.target + if input_spec.kind == InputKind.BUFFER + else input_spec.arg.name + ) + output_specs.append( + OutputSpec( + kind=output_kind, + arg=TensorArgument(name=output_name), + target=target, + ) + ) + + # handle actual user outputs + for i, spec in enumerate(ep.graph_signature.output_specs): + output_specs.append( + OutputSpec( + OutputKind.LOSS_OUTPUT if i == joint_loss_index else spec.kind, + update_arg(spec.arg, new_outputs[len(mutation_names) + i]), + old_new_placeholder_map.get(spec.target, spec.target), + ) + ) + + if joint_loss_index is not None: + assert graph_signature.backward_signature is not None + gradients = graph_signature.backward_signature.gradients_to_user_inputs + assert len(graph_signature.user_inputs) == len(ep.graph_signature.input_specs) + specs = { + graph_signature.user_inputs[i]: spec + for i, spec in enumerate(ep.graph_signature.input_specs) + if isinstance(spec.arg, TensorArgument) + } + for node in new_outputs[len(output_specs) :]: + source = gradients[node.name] + spec = specs[source] # type: ignore[index] + if spec.kind == InputKind.PARAMETER: + kind = OutputKind.GRADIENT_TO_PARAMETER + target = spec.target + elif spec.kind == InputKind.USER_INPUT: + kind = OutputKind.GRADIENT_TO_USER_INPUT + target = source + else: + raise AssertionError(f"Unknown input kind: {spec.kind}") + output_specs.append( + OutputSpec( + kind, + TensorArgument(name=node.name), + target, + ) + ) + + assert len(new_placeholders) == len(old_placeholders) + + new_graph_signature = ExportGraphSignature( + input_specs=input_specs, output_specs=output_specs + ) + # NOTE: aot_export adds symint 
metadata for placeholders with int + # values; since these become specialized, we replace such metadata with + # the original values. + # Also, set the param/buffer metadata back to the placeholders. + for old_node, new_node in zip(old_placeholders, new_placeholders): + if not isinstance(old_node.meta["val"], torch.Tensor): + new_node.meta["val"] = old_node.meta["val"] + + if ( + new_node.target in new_graph_signature.inputs_to_parameters + or new_node.target in new_graph_signature.inputs_to_buffers + ): + for k, v in old_node.meta.items(): + new_node.meta[k] = v + return gm, new_graph_signature, ep.state_dict + + +def _remove_unnecessary_copy_op_pass( + gm: torch.fx.GraphModule, new_graph_signature: ExportGraphSignature +) -> tuple[torch.fx.GraphModule, ExportGraphSignature]: + """ + Removes redundant copy_ node that was introduced due to mutated buffer. + """ + with gm._set_replace_hook(new_graph_signature.get_replace_hook()): + for node in gm.graph.nodes: + if node.op == "output": + args, _ = pytree.tree_flatten(node.args) + for out in args: + if isinstance(out, torch.fx.Node) and ( + out.name in new_graph_signature.buffers_to_mutate + or out.name in new_graph_signature.parameters_to_mutate + ): + if ( + out.op == "call_function" + and out.target is torch.ops.aten.copy.default + ): + out.replace_all_uses_with(out.args[1]) # type: ignore[arg-type] + gm.graph.erase_node(out) + gm.recompile() + return gm, new_graph_signature + + +def _common_getitem_elimination_pass( + gm: torch.fx.GraphModule, graph_signature, module_call_graph +): + with gm._set_replace_hook(graph_signature.get_replace_hook()): + for module in gm.modules(): + if not isinstance(module, torch.fx.GraphModule): + continue + + node_id: dict[torch.fx.Node, str] = {} + getitems: dict[str, torch.fx.Node] = {} + for node in list(module.graph.nodes): + if node.op == "call_function" and node.target is operator.getitem: + source, idx = node.args + new_id = f"{node_id[source]}.{idx}" + if new_id in getitems: 
+ node.replace_all_uses_with(getitems[new_id]) + for entry in module_call_graph: + if entry.signature is not None: + entry.signature.replace_all_uses_with( + node, getitems[new_id] + ) + module.graph.erase_node(node) + else: + getitems[new_id] = node + node_id[node] = new_id + else: + node_id[node] = node.name + + +def _get_updated_module_call_graph( + old_gm: torch.fx.GraphModule, + old_graph_signature: ExportGraphSignature, + gm: torch.fx.GraphModule, + graph_signature: ExportGraphSignature, + old_module_call_graph: list[ModuleCallEntry], +): + new_module_call_graph = copy.deepcopy(old_module_call_graph) + + old_nodes = {node.name: node for node in old_gm.graph.nodes} + + old_graph_params_buffers = { + **old_graph_signature.inputs_to_parameters, + **old_graph_signature.inputs_to_buffers, + } + new_graph_params_buffers = { + **graph_signature.inputs_to_parameters, + **graph_signature.inputs_to_buffers, + } + + # use node-level provenance metadata to create a map + # from old node names to new node names + provenance: dict[str, str] = {} + + user_input_counter = 0 + old_user_input_names = [ + node.target for node in old_gm.graph.nodes if node.op == "placeholder" + ] + old_user_input_names = list( + filter( + lambda x: x not in old_graph_params_buffers + and x not in old_graph_signature.input_tokens, + old_user_input_names, + ) + ) + new_user_input_names = [ + node.target for node in gm.graph.nodes if node.op == "placeholder" + ] + + for node in gm.graph.nodes: + if history := node.meta.get("from_node", []): + provenance[history[-1].name] = node.name + + # For params and buffers, we might have applied parameterizaiton rule + # so that the names might have changed. But for user inputs, we know we + # must preserve the old name. 
+ elif node.op == "placeholder": + if not ( + node.target in new_graph_params_buffers + or node.target in graph_signature.input_tokens + ): + if node.target in new_user_input_names: + assert isinstance(node.name, str) + old_name = old_user_input_names[user_input_counter] + assert isinstance(old_name, str) + provenance[old_name] = node.name + user_input_counter += 1 + + # For all the parameters and buffers, we first see + # if they are result of parametrizations and if they + # are, we log them and error later + old_param_to_desugared = defaultdict(list) + for name, target in new_graph_params_buffers.items(): + # if the parameters are not parametrized, the naming won't change. + if not target.startswith("parametrizations."): + # If we are in strict mode, we can't just reuse the param names + if name in old_graph_params_buffers: + provenance[name] = name + else: + old_target = ".".join(target.split(".")[1:-1]) + old_param_to_desugared[old_target].append(name) + + # map old names to new names in module call signatures + for entry in new_module_call_graph: + signature = entry.signature + if signature is None: + continue + for x in [*signature.inputs, *signature.outputs]: + # We noticed that submodule is taking subclass as input. we can't + # preserve signature here. + if x.name in old_param_to_desugared: + raise ValueError( + f"It looks like {x.name} is a tensor subclass. " + f"Preserving submodule that takes subclass parameter is not supported" + f" in inference IR because we desugar them, resulting in more tensors" + ) + + if x.name in provenance: + x.name = provenance[x.name] + + # This can happen when aten.to is called at graph boundaries. + # Basically aten.to at post-dispatch level can either be copy + # or alias. In the alias case, we will no-op it so it will + # disappear from the graph. If we detect such case, we should + # reuse the input to aten.to as the new input to the submodule. 
+ # Technically this can happen for other maybe aliasing ops, + # but aten.to is probably the most common one. + elif x.name in old_nodes: + old_node = old_nodes[x.name] + if old_node.op == "call_function" and old_node.target in [ + torch.ops.aten.to.dtype_layout, + torch.ops.aten.to.device, + torch.ops.aten.to.dtype, + ]: + old_target = old_node.args[0].name + if old_target not in provenance: + raise ValueError( + f"It looks like {old_target} is a tensor subclass. " + f"Preserving submodule that takes subclass parameter is not supported" + f" in inference IR because we desugar them, resulting in more tensors" + ) + + x.name = provenance[old_target] + + return new_module_call_graph + + +def _decompose_exported_program( + ep, + *, + cia_to_decomp: dict[torch._ops.OperatorBase, Callable], + python_decomp_table: dict[torch._ops.OperatorBase, Callable], + joint_loss_index: int | None, + decompose_custom_triton_ops: bool, +): + ( + gm, + new_graph_signature, + state_dict, + ) = _decompose_and_get_gm_with_new_signature_constants( + ep, + cia_to_decomp=cia_to_decomp, + python_decomp_table=python_decomp_table, + joint_loss_index=joint_loss_index, + decompose_custom_triton_ops=decompose_custom_triton_ops, + ) + + # The signatures of ep.module_call_graph refer to input / output nodes of + # the original graph module. However, the new graph module may have + # new nodes due to decompositions. So we need to update these signatures + # in the decomposed exported program's module_call_graph. + new_module_call_graph = _get_updated_module_call_graph( + ep.graph_module, + ep.graph_signature, + gm, + new_graph_signature, + ep.module_call_graph, + ) + + # TODO unfortunately preserving graph-level metadata is not + # working well with aot_export. So we manually copy it. + # (The node-level meta is addressed above.) 
+ gm.meta.update(ep.graph_module.meta) + + new_range_constraints = _get_updated_range_constraints( + gm, + ep.range_constraints, + ) + + exported_program = ExportedProgram( + root=gm, + graph=gm.graph, + graph_signature=new_graph_signature, + state_dict=state_dict, + range_constraints=new_range_constraints, + module_call_graph=new_module_call_graph, + example_inputs=ep.example_inputs, + constants=ep.constants, + ) + return exported_program + + +class ExportedProgram: + """ + Package of a program from :func:`export`. It contains + an :class:`torch.fx.Graph` that represents Tensor computation, a state_dict containing + tensor values of all lifted parameters and buffers, and various metadata. + + You can call an ExportedProgram like the original callable traced by + :func:`export` with the same calling convention. + + To perform transformations on the graph, use ``.module`` property to access + an :class:`torch.fx.GraphModule`. You can then use + `FX transformation `_ + to rewrite the graph. Afterwards, you can simply use :func:`export` + again to construct a correct ExportedProgram. 
+ """ + + _graph_module: torch.fx.GraphModule + """The underlying GraphModule containing the exported computation graph.""" + + _graph_signature: ExportGraphSignature + """The signature containing input/output specifications for the graph.""" + + _state_dict: dict[str, Any] + """Dictionary containing parameter and buffer values from the original module.""" + + _range_constraints: "dict[sympy.Symbol, ValueRanges]" + """Symbolic shape constraints for dynamic shapes in the graph.""" + + _module_call_graph: list[ModuleCallEntry] + """Call graph information tracking module hierarchy and signatures.""" + + _example_inputs: tuple[tuple[Any, ...], dict[str, Any]] | None + """Example inputs used during export, stored as (args, kwargs) tuple.""" + + _constants: dict[str, _ConstantAttributeType] + """Dictionary of constant values used in the graph.""" + + _verifiers: list[type[Verifier]] + """List of verifier classes used to validate the exported program.""" + + _guards_code: list[str] + + def __init__( + self, + root: torch.nn.Module | dict[str, Any], + graph: torch.fx.Graph, + graph_signature: ExportGraphSignature, + state_dict: dict[str, torch.Tensor | torch.nn.Parameter], + range_constraints: "dict[sympy.Symbol, Any]", + module_call_graph: list[ModuleCallEntry], + example_inputs: tuple[tuple[Any, ...], dict[str, Any]] | None = None, + constants: dict[str, _ConstantAttributeType] | None = None, + *, + verifiers: list[type[Verifier]] | None = None, + ): + # Remove codegen related things from the graph. It should just be a flat graph. 
+ graph._codegen = torch.fx.graph.CodeGen() + self._graph_module = _create_graph_module_for_export(root, graph) + if isinstance(root, torch.fx.GraphModule): + self._graph_module.meta.update(root.meta) + + _common_getitem_elimination_pass( + self._graph_module, graph_signature, module_call_graph + ) + self._graph_signature: ExportGraphSignature = graph_signature + self._state_dict: dict[str, Any] = state_dict + self._range_constraints: dict[sympy.Symbol, ValueRanges] = range_constraints + assert module_call_graph is not None + self._module_call_graph: list[ModuleCallEntry] = module_call_graph + self._example_inputs = example_inputs + + self._constants = constants or {} + + verifiers = verifiers or [Verifier] + assert all(issubclass(v, Verifier) for v in verifiers) + self._verifiers = verifiers + # Validate should be always the last step of the constructor. + self.validate() + + self._guards_code = _convert_guards_to_code(self._graph_module) + + @property + @compatibility(is_backward_compatible=False) + def graph_module(self): + return self._graph_module + + @graph_module.setter + @compatibility(is_backward_compatible=False) + def graph_module(self, value): + raise RuntimeError("Unable to set ExportedProgram's graph_module attribute.") + + @property + @compatibility(is_backward_compatible=False) + def graph(self): + return self.graph_module.graph + + @graph.setter + @compatibility(is_backward_compatible=False) + def graph(self, value): + raise RuntimeError("Unable to set ExportedProgram's graph attribute.") + + @property + @compatibility(is_backward_compatible=False) + def graph_signature(self): + return self._graph_signature + + @graph_signature.setter + @compatibility(is_backward_compatible=False) + def graph_signature(self, value): + raise RuntimeError("Unable to set ExportedProgram's graph_signature attribute.") + + @property + @compatibility(is_backward_compatible=False) + def state_dict(self): + return self._state_dict + + @state_dict.setter + 
@compatibility(is_backward_compatible=False) + def state_dict(self, value): + raise RuntimeError("Unable to set ExportedProgram's state_dict attribute.") + + @compatibility(is_backward_compatible=False) + def parameters(self) -> Iterator[torch.nn.Parameter]: + """ + Returns an iterator over original module's parameters. + """ + for _, param in self.named_parameters(): + yield param + + @compatibility(is_backward_compatible=False) + def named_parameters(self) -> Iterator[tuple[str, torch.nn.Parameter]]: + """ + Returns an iterator over original module parameters, yielding + both the name of the parameter as well as the parameter itself. + """ + for param_name in self.graph_signature.parameters: + yield param_name, self.state_dict[param_name] + + @compatibility(is_backward_compatible=False) + def buffers(self) -> Iterator[torch.Tensor]: + """ + Returns an iterator over original module buffers. + """ + for _, buf in self.named_buffers(): + yield buf + + @compatibility(is_backward_compatible=False) + def named_buffers(self) -> Iterator[tuple[str, torch.Tensor]]: + """ + Returns an iterator over original module buffers, yielding + both the name of the buffer as well as the buffer itself. + """ + non_persistent_buffers = set(self.graph_signature.non_persistent_buffers) + for buffer_name in self.graph_signature.buffers: + if buffer_name in non_persistent_buffers: + yield buffer_name, self.constants[buffer_name] + else: + yield buffer_name, self.state_dict[buffer_name] + + @property + @compatibility(is_backward_compatible=False) + def range_constraints(self): + return self._range_constraints + + @range_constraints.setter + @compatibility(is_backward_compatible=False) + def range_constraints(self, value): + raise RuntimeError( + "Unable to set ExportedProgram's range_constraints attribute." 
+ ) + + @property + @compatibility(is_backward_compatible=False) + def module_call_graph(self): + return self._module_call_graph + + @module_call_graph.setter + @compatibility(is_backward_compatible=False) + def module_call_graph(self, value): + raise RuntimeError( + "Unable to set ExportedProgram's module_call_graph attribute." + ) + + @property + @compatibility(is_backward_compatible=False) + def example_inputs(self): + return self._example_inputs + + @example_inputs.setter + @compatibility(is_backward_compatible=False) + def example_inputs(self, value): + # This is allowed + + if value is None: + self._example_inputs = value + return + + if not ( + isinstance(value, tuple) + and len(value) == 2 + and isinstance(value[0], tuple) + and isinstance(value[1], dict) + ): + raise ValueError( + "Example inputs should be a tuple containing example arguments (as " + "a tuple), and example kwargs (as a dictionary)." + ) + + args, kwargs = value + from ._unlift import _check_inputs_match + + _check_inputs_match(args, kwargs, self.call_spec.in_spec) + + self._example_inputs = value + + @property + @compatibility(is_backward_compatible=False) + def call_spec(self): + class CallSpec(NamedTuple): + in_spec: pytree.TreeSpec | None + out_spec: pytree.TreeSpec | None + + if len(self.module_call_graph) == 0: + return CallSpec(in_spec=None, out_spec=None) + assert self.module_call_graph[0].fqn == "" + return CallSpec( + in_spec=self.module_call_graph[0].signature.in_spec, + out_spec=self.module_call_graph[0].signature.out_spec, + ) + + @call_spec.setter + @compatibility(is_backward_compatible=False) + def call_spec(self, value): + raise RuntimeError("Unable to set ExportedProgram's call_spec attribute.") + + @property + @compatibility(is_backward_compatible=False) + def verifier(self) -> Any: + return self._verifiers[0] + + @verifier.setter + @compatibility(is_backward_compatible=False) + def verifier(self, value): + raise RuntimeError("Unable to set ExportedProgram's verifier 
attribute.") + + @property + @compatibility(is_backward_compatible=False) + def dialect(self) -> str: + assert self._verifiers is not None + return self._verifiers[0].dialect + + @dialect.setter + @compatibility(is_backward_compatible=False) + def dialect(self, value): + raise RuntimeError("Unable to set ExportedProgram's dialect attribute.") + + @property + @compatibility(is_backward_compatible=False) + def verifiers(self): + return self._verifiers + + @verifiers.setter + @compatibility(is_backward_compatible=False) + def verifiers(self, value): + raise RuntimeError("Unable to set ExportedProgram's verifiers attribute.") + + @property + @compatibility(is_backward_compatible=False) + def tensor_constants(self): + return self._constants + + @tensor_constants.setter + @compatibility(is_backward_compatible=False) + def tensor_constants(self, value): + raise RuntimeError( + "Unable to set ExportedProgram's tensor_constants attribute." + ) + + @property + @compatibility(is_backward_compatible=False) + def constants(self): + return self._constants + + @constants.setter + @compatibility(is_backward_compatible=False) + def constants(self, value): + raise RuntimeError("Unable to set ExportedProgram's constants attribute.") + + def _get_flat_args_with_check(self, args, kwargs): + """Flatten args, kwargs using pytree, then, check specs. 
+ + Args: + args: List[Any] original args passed to __call__ + kwargs: Dict[str, Any] original kwargs passed to __call + + Returns: + A tuple of (flat_args, received_spec) + flat_args is flattened args / kwargs + received_spec is the pytree spec produced while flattening the + tuple (args, kwargs) + """ + in_spec = self.call_spec.in_spec + if in_spec is not None: + kwargs = reorder_kwargs(kwargs, in_spec) + flat_args_with_path, received_spec = pytree.tree_flatten_with_path( + (args, kwargs) + ) + self._check_input_constraints(flat_args_with_path) + flat_args = tuple(x[1] for x in flat_args_with_path) + return flat_args, received_spec + + def _graph_module_flat_inputs(self, args: Any, kwargs: Any) -> Any: + """Transform args, kwargs of __call__ to args for graph_module. + + self.graph_module takes stuff from state dict as inputs. + The invariant is for ep: ExportedProgram is + ep(args, kwargs) == + ep.postprocess(ep.graph_module(ep.graph_module_flat_inputs(args, kwargs))) + """ + + in_spec = self.call_spec.in_spec + flat_args, received_spec = self._get_flat_args_with_check(args, kwargs) + if in_spec is not None and not is_equivalent( + received_spec, in_spec, _fx_collection_equivalence_fn + ): + raise ValueError( + "Trying to flatten user inputs with exported input tree spec: \n" + f"{in_spec}\n" + "but actually got inputs with tree spec of: \n" + f"{received_spec}" + ) + + additional_inputs = [] + for input_ in self.graph_signature.input_specs: + if input_.kind == InputKind.USER_INPUT: + continue + elif input_.kind in ( + InputKind.PARAMETER, + InputKind.BUFFER, + ): + if input_.persistent is False: + # This is a non-persistent buffer, grab it from our + # constants instead of the state dict. 
+ additional_inputs.append(self.constants[input_.target]) + else: + additional_inputs.append(self.state_dict[input_.target]) + elif input_.kind in ( + InputKind.CONSTANT_TENSOR, + InputKind.CUSTOM_OBJ, + ): + additional_inputs.append(self.constants[input_.target]) + additional_inputs = tuple(additional_inputs) + + # NOTE: calling convention is first params, then buffers, then args as user supplied them. + # See: torch/_functorch/aot_autograd.py#L1034 + return additional_inputs + flat_args + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + raise RuntimeError( + "Unable to call ExportedProgram directly. " + "You should use `exported_program.module()` instead." + ) + + def __str__(self) -> str: + graph_module = self.graph_module.print_readable( + print_output=False, colored=False + ).replace("\n", "\n ") + graph_signature = str(self.graph_signature).replace("\n", "\n ") + string = ( + "ExportedProgram:\n" + f" {graph_module}\n" + f"Graph signature: {graph_signature}\n" + f"Range constraints: {self.range_constraints}\n" + ) + return string + + def module(self, check_guards=True) -> torch.fx.GraphModule: + """ + Returns a self contained GraphModule with all the parameters/buffers inlined. + + - When `check_guards=True` (default), a `_guards_fn` submodule is generated + and a call to a `_guards_fn` submodule is inserted right after placeholders + in the graph. This module checks guards on inputs. + - When `check_guards=False`, a subset of these checks are performed by a + forward pre-hook on the graph module. No `_guards_fn` submodule is generated. 
+ + """ + from ._unlift import _unlift_exported_program_lifted_states + + module = _unlift_exported_program_lifted_states(self, check_guards=check_guards) + + def _train(self, mode: bool = True): + raise NotImplementedError("Calling train() is not supported yet.") + + def _eval(self, mode: bool = True): + raise NotImplementedError("Calling eval() is not supported yet.") + + module.train = types.MethodType(_train, module) # type: ignore[method-assign] + module.eval = types.MethodType(_eval, module) # type: ignore[method-assign] + return module + + def _num_lifted_params_buffers(self): + return next( + ( + i + for i, s in enumerate(self._graph_signature.input_specs) + if s.kind == InputKind.USER_INPUT + ), + len(self._graph_signature.input_specs), + ) + + @_disable_prexisiting_fake_mode + def run_decompositions( + self, + decomp_table: dict[torch._ops.OperatorBase, Callable] | None = None, + decompose_custom_triton_ops: bool = False, + ) -> "ExportedProgram": + """ + Run a set of decompositions on the exported program and returns a new + exported program. By default we will run the Core ATen decompositions to + get operators in the + `Core ATen Operator Set `_. + + For now, we do not decompose joint graphs. + + Args: + decomp_table: + An optional argument that specifies decomp behaviour for Aten ops + (1) If None, we decompose to core aten decompositions + (2) If empty, we don't decompose any operator + + + Some examples: + + If you don't want to decompose anything + + .. code-block:: python + + ep = torch.export.export(model, ...) + ep = ep.run_decompositions(decomp_table={}) + + If you want to get a core aten operator set except for certain operator, you can do following: + + .. code-block:: python + + ep = torch.export.export(model, ...) 
+ decomp_table = torch.export.default_decompositions() + decomp_table[your_op] = your_custom_decomp + ep = ep.run_decompositions(decomp_table=decomp_table) + """ + _decomp_table = ( + default_decompositions() if decomp_table is None else dict(decomp_table) + ) + + if isinstance(_decomp_table, CustomDecompTable): + _decomp_table = _decomp_table.materialize() + + # Note [Separating decomp_table into CIA decomps and non-CIA decomps] + # At this point, we have a decomp_table that contains decomp behaviour for + # both CIA and post-autograd ops. + # We need to separate the op into two categories: + # 1. CIA op: These are the ops that we want to override + # CompositeImplicitAutograd decomp for. For them, we need to use _override_composite_implicit_decomp + # context manager to plumb it through AOTDispatcher + # 2. Non-CIA op: These ops are only relevant after AOTDIspatcher runs, so just + # checking if they are statically functional is enough. + # For joint IR case tho, we need to use the old path because we can't register + # custom decomps this way because we can't use context manager as it installs + # autograd_error node. + ( + cia_to_decomp, + python_decomp_table, + ) = _split_decomp_table_to_cia_and_python_decomp(_decomp_table) + + return _decompose_exported_program( + self, + cia_to_decomp=cia_to_decomp, + python_decomp_table=python_decomp_table, + joint_loss_index=None, + decompose_custom_triton_ops=decompose_custom_triton_ops, + ) + + def _transform_do_not_use(self, *passes: PassType) -> "ExportedProgram": + pm = PassManager(list(passes)) + # Since we abstractly run the passes, we need to disable backend decomp here + # again. 
+ from torch.export._trace import _ignore_backend_decomps + + with _ignore_backend_decomps(): + res = pm(self.graph_module) + transformed_gm = res.graph_module if res is not None else self.graph_module + assert transformed_gm is not None + + # pyrefly: ignore [missing-attribute] + if transformed_gm is self.graph_module and not res.modified: + return self + + # TODO(zhxchen17) Remove this. + def _get_updated_graph_signature( + old_signature: ExportGraphSignature, + new_gm: torch.fx.GraphModule, + ) -> ExportGraphSignature: + """ + Update the graph signature's user_input/user_outputs. + """ + new_input_specs = [] + for i, node in enumerate(new_gm.graph.nodes): + if node.op != "placeholder": + break + + assert i < len(old_signature.input_specs), ( + "Number of inputs changed after transformation" + ) + old_input_spec = old_signature.input_specs[i] + arg = ( + old_input_spec.arg + if isinstance( + old_input_spec.arg, (ConstantArgument, CustomObjArgument) + ) + else type(old_input_spec.arg)(node.name) + ) + new_input_specs.append( + InputSpec( + old_input_spec.kind, + arg, + old_input_spec.target, + old_input_spec.persistent, + ) + ) + + output_node = list(new_gm.graph.nodes)[-1] + assert output_node.op == "output" + + new_output_specs = [] + for i, node in enumerate(output_node.args[0]): + assert i < len(old_signature.output_specs), ( + "Number of outputs changed after transformation" + ) + old_output_spec = old_signature.output_specs[i] + arg = ( + old_output_spec.arg + if isinstance( + old_output_spec.arg, (ConstantArgument, CustomObjArgument) + ) + else type(old_output_spec.arg)(node.name) + ) + new_output_specs.append( + OutputSpec(old_output_spec.kind, arg, old_output_spec.target) + ) + + new_signature = ExportGraphSignature( + input_specs=new_input_specs, output_specs=new_output_specs + ) + return new_signature + + transformed_ep = ExportedProgram( + root=transformed_gm, + graph=transformed_gm.graph, + graph_signature=_get_updated_graph_signature( + 
self.graph_signature, transformed_gm + ), + state_dict=self.state_dict, + range_constraints=_get_updated_range_constraints( + transformed_gm, + self.range_constraints, + ), + module_call_graph=copy.deepcopy(self._module_call_graph), + example_inputs=self.example_inputs, + constants=self.constants, + verifiers=self.verifiers, + ) + transformed_ep.graph_module.meta.update(self.graph_module.meta) + # pyrefly: ignore [missing-attribute] + transformed_ep.graph_module.meta.update(res.graph_module.meta) + return transformed_ep + + def _check_input_constraints(self, flat_args_with_path): + from torch._export.utils import _check_input_constraints_for_graph + + placeholders = [p for p in self.graph.nodes if p.op == "placeholder"] + input_placeholders = [ + p + for p, s in zip(placeholders, self.graph_signature.input_specs) + if s.kind == InputKind.USER_INPUT + ] + _check_input_constraints_for_graph( + input_placeholders, flat_args_with_path, self.range_constraints + ) + + @compatibility(is_backward_compatible=False) + def validate(self): + self._validate() + + # TODO: remove this + @final + def _validate(self): + assert len(self.verifiers) > 0, ( + "ExportedProgram must have at least one verifier." + ) + for v in self.verifiers: + v().check(self) + + # TODO(zhxchen17) Formalize this. 
+ def _update( + self, + graph_module, + graph_signature, + *, + state_dict=None, + constants=None, + verifiers=None, + ) -> "ExportedProgram": + return ExportedProgram( + root=graph_module, + graph=graph_module.graph, + graph_signature=graph_signature, + state_dict=state_dict if state_dict is not None else self.state_dict, + range_constraints=copy.deepcopy(self.range_constraints), + module_call_graph=copy.deepcopy(self._module_call_graph), + example_inputs=self.example_inputs, + constants=constants if constants is not None else self.constants, + verifiers=verifiers if verifiers is not None else self.verifiers, + ) + + +def _get_shape_env(gm): + vals = [ + node.meta["val"] + for node in gm.graph.nodes + if node.meta.get("val", None) is not None + ] + from torch._guards import detect_fake_mode + + fake_mode = detect_fake_mode(vals) + if fake_mode is not None: + return fake_mode.shape_env + for v in vals: + if isinstance(v, torch.SymInt): + return v.node.shape_env + + +def _get_updated_range_constraints( + gm: torch.fx.GraphModule, + old_range_constraints: "dict[sympy.Symbol, Any] | None" = None, +) -> "dict[sympy.Symbol, Any]": + assert old_range_constraints is not None + + shape_env = _get_shape_env(gm) + if shape_env is None: + return {} + + range_constraints = copy.copy(old_range_constraints) + range_constraints = { + k: v for k, v in range_constraints.items() if k not in shape_env.replacements + } + # Only when we have an unbacked symint, and it's used as constructor inputs, + # runtime_var_to_range will make a difference compated to var_to_range. + # e.g. 
[2, oo) -> [0, oo) + for k, v in shape_env.var_to_range.items(): + if k not in shape_env.replacements and k not in range_constraints: + range_constraints[k] = v + return range_constraints + + +def _create_graph_module_for_export(root, graph): + try: + gm = torch.fx.GraphModule(root, graph) + except SyntaxError: + # If custom objects stored in memory are being used in the graph, + # the generated python code will result in a syntax error on the custom + # object, since it is unable to parse the in-memory object. However + # we can still run the graph eagerly through torch.fx.Interpreter, + # so we will bypass this error. + warnings.warn( + "Unable to execute the generated python source code from " + "the graph. The graph module will no longer be directly callable, " + "but you can still run the ExportedProgram, and if needed, you can " + "run the graph module eagerly using torch.fx.Interpreter.", + stacklevel=2, + ) + gm = torch.fx.GraphModule(root, torch.fx.Graph()) + gm._graph = graph + + return gm + + +def _convert_guards_to_code(graph_module): + shape_env = _get_shape_env(graph_module) + if shape_env is None: + return [] + + local_vars = { + var + for var, sources in shape_env.var_to_sources.items() + if all( + not isinstance(source, torch._dynamo.source.ConstantSource) + for source in sources + ) + } + py_printer = torch.fx.experimental.symbolic_shapes.ShapeGuardPythonPrinter( + shape_env.var_to_sources, lambda s: s.name, shape_env.var_to_sources + ) + ret = [ + py_printer.doprint(guard.expr) + for guard in shape_env.guards + if guard.expr.free_symbols.issubset(local_vars) + ] + # TODO Figure out how to resolve guards containing weight sizes. + # This is not a big deal as _guards_code is mostly empty today. 
+ return [guard for guard in ret if "L['self']" not in guard] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/graph_signature.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/graph_signature.py new file mode 100644 index 0000000000000000000000000000000000000000..5311b7beb47ef318e5f5f646fe70b43327672270 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/graph_signature.py @@ -0,0 +1,729 @@ +# mypy: allow-untyped-defs +import dataclasses +from collections.abc import Collection, Mapping +from enum import auto, Enum +from typing import TYPE_CHECKING, Union + +from torch._library.fake_class_registry import FakeScriptObject +from torch._library.opaque_object import get_opaque_type_name, is_opaque_type +from torch._subclasses.fake_tensor import is_fake + + +if TYPE_CHECKING: + import torch + from torch._functorch._aot_autograd.schemas import GraphSignature + +__all__ = [ + "ConstantArgument", + "CustomObjArgument", + "ExportBackwardSignature", + "ExportGraphSignature", + "InputKind", + "InputSpec", + "OutputKind", + "OutputSpec", + "SymIntArgument", + "SymFloatArgument", + "SymBoolArgument", + "TensorArgument", +] + + +@dataclasses.dataclass +class TensorArgument: + name: str + + +@dataclasses.dataclass +class TokenArgument: + name: str + + +@dataclasses.dataclass +class SymIntArgument: + name: str + + +@dataclasses.dataclass +class SymFloatArgument: + name: str + + +@dataclasses.dataclass +class SymBoolArgument: + name: str + + +@dataclasses.dataclass +class CustomObjArgument: + name: str + class_fqn: str + fake_val: FakeScriptObject | None = None + + +@dataclasses.dataclass +class ConstantArgument: + name: str + value: int | float | bool | str | None + + +ArgumentSpec = Union[ + TensorArgument, + SymIntArgument, + SymFloatArgument, + SymBoolArgument, + ConstantArgument, + CustomObjArgument, + TokenArgument, +] + + +class InputKind(Enum): + USER_INPUT = auto() + 
PARAMETER = auto() + BUFFER = auto() + CONSTANT_TENSOR = auto() + CUSTOM_OBJ = auto() + TOKEN = auto() + + +@dataclasses.dataclass +class InputSpec: + kind: InputKind + arg: ArgumentSpec + target: str | None + persistent: bool | None = None + + def __post_init__(self): + if self.kind == InputKind.BUFFER: + assert self.persistent is not None, ( + "Failed to specify persistent flag on BUFFER." + ) + assert isinstance( + self.arg, + ( + TensorArgument, + SymIntArgument, + SymFloatArgument, + SymBoolArgument, + ConstantArgument, + CustomObjArgument, + TokenArgument, + ), + ), f"got {type(self.arg)}" + + def __str__(self): + target = "" if self.target is None else f" target='{self.target}'" + persistent = "" if self.persistent is None else f" persistent={self.persistent}" + return f"{str(self.arg.name)}: {str(self.kind.name)}{target}{persistent}" + + +class OutputKind(Enum): + USER_OUTPUT = auto() + LOSS_OUTPUT = auto() + BUFFER_MUTATION = auto() + PARAMETER_MUTATION = auto() + GRADIENT_TO_PARAMETER = auto() + GRADIENT_TO_USER_INPUT = auto() + USER_INPUT_MUTATION = auto() + TOKEN = auto() + + +@dataclasses.dataclass +class OutputSpec: + kind: OutputKind + arg: ArgumentSpec + target: str | None + + def __post_init__(self): + assert isinstance( + self.arg, + ( + TensorArgument, + SymIntArgument, + SymFloatArgument, + SymBoolArgument, + ConstantArgument, + TokenArgument, + CustomObjArgument, + ), + ), self.arg + + def __str__(self): + target = "" if self.target is None else f" target='{self.target}'" + return f"{str(self.arg.name)}: {str(self.kind.name)}{target}" + + +@dataclasses.dataclass +class ExportBackwardSignature: + gradients_to_parameters: dict[str, str] + gradients_to_user_inputs: dict[str, str] + loss_output: str + + +@dataclasses.dataclass +class ExportGraphSignature: + """ + :class:`ExportGraphSignature` models the input/output signature of Export Graph, + which is a fx.Graph with stronger invariants guarantees. 
+ + Export Graph is functional and does not access "states" like parameters + or buffers within the graph via ``getattr`` nodes. Instead, :func:`export` + guarantees that parameters, buffers, and constant tensors are lifted out of + the graph as inputs. Similarly, any mutations to buffers are not included + in the graph either, instead the updated values of mutated buffers are + modeled as additional outputs of Export Graph. + + The ordering of all inputs and outputs are:: + + Inputs = [*parameters_buffers_constant_tensors, *flattened_user_inputs] + Outputs = [*mutated_inputs, *flattened_user_outputs] + + e.g. If following module is exported:: + + class CustomModule(nn.Module): + def __init__(self) -> None: + super(CustomModule, self).__init__() + + # Define a parameter + self.my_parameter = nn.Parameter(torch.tensor(2.0)) + + # Define two buffers + self.register_buffer("my_buffer1", torch.tensor(3.0)) + self.register_buffer("my_buffer2", torch.tensor(4.0)) + + def forward(self, x1, x2): + # Use the parameter, buffers, and both inputs in the forward method + output = ( + x1 + self.my_parameter + ) * self.my_buffer1 + x2 * self.my_buffer2 + + # Mutate one of the buffers (e.g., increment it by 1) + self.my_buffer2.add_(1.0) # In-place addition + + return output + + + mod = CustomModule() + ep = torch.export.export(mod, (torch.tensor(1.0), torch.tensor(2.0))) + + Resulting Graph is non-functional:: + + graph(): + %p_my_parameter : [num_users=1] = placeholder[target=p_my_parameter] + %b_my_buffer1 : [num_users=1] = placeholder[target=b_my_buffer1] + %b_my_buffer2 : [num_users=2] = placeholder[target=b_my_buffer2] + %x1 : [num_users=1] = placeholder[target=x1] + %x2 : [num_users=1] = placeholder[target=x2] + %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x1, %p_my_parameter), kwargs = {}) + %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, %b_my_buffer1), kwargs = {}) + %mul_1 : [num_users=1] = 
call_function[target=torch.ops.aten.mul.Tensor](args = (%x2, %b_my_buffer2), kwargs = {}) + %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul, %mul_1), kwargs = {}) + %add_ : [num_users=0] = call_function[target=torch.ops.aten.add_.Tensor](args = (%b_my_buffer2, 1.0), kwargs = {}) + return (add_1,) + + Resulting ExportGraphSignature of the non-functional Graph would be:: + + # inputs + p_my_parameter: PARAMETER target='my_parameter' + b_my_buffer1: BUFFER target='my_buffer1' persistent=True + b_my_buffer2: BUFFER target='my_buffer2' persistent=True + x1: USER_INPUT + x2: USER_INPUT + + # outputs + add_1: USER_OUTPUT + + To get a functional Graph, you can use :func:`run_decompositions`:: + + mod = CustomModule() + ep = torch.export.export(mod, (torch.tensor(1.0), torch.tensor(2.0))) + ep = ep.run_decompositions() + + Resulting Graph is functional:: + + graph(): + %p_my_parameter : [num_users=1] = placeholder[target=p_my_parameter] + %b_my_buffer1 : [num_users=1] = placeholder[target=b_my_buffer1] + %b_my_buffer2 : [num_users=2] = placeholder[target=b_my_buffer2] + %x1 : [num_users=1] = placeholder[target=x1] + %x2 : [num_users=1] = placeholder[target=x2] + %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x1, %p_my_parameter), kwargs = {}) + %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, %b_my_buffer1), kwargs = {}) + %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%x2, %b_my_buffer2), kwargs = {}) + %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul, %mul_1), kwargs = {}) + %add_2 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%b_my_buffer2, 1.0), kwargs = {}) + return (add_2, add_1) + + Resulting ExportGraphSignature of the functional Graph would be:: + + # inputs + p_my_parameter: PARAMETER target='my_parameter' + b_my_buffer1: BUFFER target='my_buffer1' 
persistent=True + b_my_buffer2: BUFFER target='my_buffer2' persistent=True + x1: USER_INPUT + x2: USER_INPUT + + # outputs + add_2: BUFFER_MUTATION target='my_buffer2' + add_1: USER_OUTPUT + + """ + + input_specs: list[InputSpec] + output_specs: list[OutputSpec] + + # A list of parameters uniquely identified by mangled fully qualified name + @property + def parameters(self) -> Collection[str]: + return tuple( + s.target + for s in self.input_specs + if s.kind == InputKind.PARAMETER + if isinstance(s.target, str) + ) + + # A list of buffers uniquely identified by mangled fully qualified name + @property + def buffers(self) -> Collection[str]: + return tuple( + s.target + for s in self.input_specs + if s.kind == InputKind.BUFFER + if isinstance(s.target, str) + ) + + @property + def non_persistent_buffers(self) -> Collection[str]: + return tuple( + s.target + for s in self.input_specs + if s.kind == InputKind.BUFFER + if s.persistent is False + if isinstance(s.target, str) + ) + + # A list of lifted constant tensors + @property + def lifted_tensor_constants(self) -> Collection[str]: + return tuple( + s.target + for s in self.input_specs + if s.kind == InputKind.CONSTANT_TENSOR + if isinstance(s.target, str) + ) + + @property + def lifted_custom_objs(self) -> Collection[str]: + return tuple( + s.target + for s in self.input_specs + if s.kind == InputKind.CUSTOM_OBJ + if isinstance(s.target, str) + ) + + # Graph node names of pytree-flattened inputs of original program + @property + def user_inputs(self) -> Collection[int | float | bool | None | str]: + user_inputs: list[int | float | bool | None | str] = [] + for s in self.input_specs: + if s.kind != InputKind.USER_INPUT: + continue + + if isinstance( + s.arg, + ( + TensorArgument, + SymIntArgument, + SymFloatArgument, + SymBoolArgument, + CustomObjArgument, + ), + ): + user_inputs.append(s.arg.name) + elif isinstance(s.arg, ConstantArgument): + user_inputs.append(s.arg.value) + else: + raise RuntimeError(f"{s.arg} is 
not a valid user inputs") + return tuple(user_inputs) + + # Graph node names of pytree-flattened outputs of original program + # For joint-graph purposes, will include the loss output. + @property + def user_outputs(self) -> Collection[int | float | bool | None | str]: + user_outputs: list[int | float | bool | None | str] = [] + for s in self.output_specs: + if s.kind not in [ + OutputKind.USER_OUTPUT, + OutputKind.LOSS_OUTPUT, + ]: + continue + + if isinstance( + s.arg, + (TensorArgument, SymIntArgument, SymFloatArgument, SymBoolArgument), + ): + user_outputs.append(s.arg.name) + elif isinstance(s.arg, ConstantArgument): + user_outputs.append(s.arg.value) + elif isinstance(s.arg, CustomObjArgument): + user_outputs.append(s.arg.name) + else: + raise RuntimeError(f"{s.arg} is not a valid user output") + return tuple(user_outputs) + + # A dictionary mapping graph input node names to parameters. If a graph input + # name is found in this dictionary, it is guaranteed to be a lifted parameter. + @property + def inputs_to_parameters(self) -> Mapping[str, str]: + return _immutable_dict( + (s.arg.name, s.target) + for s in self.input_specs + if s.kind == InputKind.PARAMETER + and isinstance(s.arg, TensorArgument) + and isinstance(s.target, str) + ) + + # A dictionary mapping graph input node names to buffers. If a graph input + # name is found in this dictionary, it is guaranteed to be a lifted buffer. + @property + def inputs_to_buffers(self) -> Mapping[str, str]: + return _immutable_dict( + (s.arg.name, s.target) # type: ignore[union-attr, misc] + for s in self.input_specs + if s.kind == InputKind.BUFFER + and isinstance(s.arg, TensorArgument) + and isinstance(s.target, str) + ) + + # A dictionary mapping graph output node names to buffers that are mutated in the + # original program. Buffers that are not mutated will not be found in this dictionary. 
+ @property + def buffers_to_mutate(self) -> Mapping[str, str]: + return _immutable_dict( + (s.arg.name, s.target) + for s in self.output_specs + if s.kind == OutputKind.BUFFER_MUTATION + and isinstance(s.arg, TensorArgument) + and isinstance(s.target, str) + ) + + @property + def parameters_to_mutate(self) -> Mapping[str, str]: + return _immutable_dict( + (s.arg.name, s.target) + for s in self.output_specs + if s.kind == OutputKind.PARAMETER_MUTATION + and isinstance(s.arg, TensorArgument) + and isinstance(s.target, str) + ) + + @property + def user_inputs_to_mutate(self) -> Mapping[str, str]: + return _immutable_dict( + (s.arg.name, s.target) + for s in self.output_specs + if s.kind == OutputKind.USER_INPUT_MUTATION + and isinstance(s.arg, TensorArgument) + and isinstance(s.target, str) + ) + + # A dictionary mapping graph input node names to lifted tensor constants. + @property + def inputs_to_lifted_tensor_constants(self) -> Mapping[str, str]: + return _immutable_dict( + (s.arg.name, s.target) + for s in self.input_specs + if s.kind == InputKind.CONSTANT_TENSOR + and isinstance(s.arg, TensorArgument) + and isinstance(s.target, str) + ) + + @property + def inputs_to_lifted_custom_objs(self) -> Mapping[str, str]: + return _immutable_dict( + (s.arg.name, s.target) + for s in self.input_specs + if s.kind == InputKind.CUSTOM_OBJ + and isinstance(s.arg, CustomObjArgument) + and isinstance(s.target, str) + ) + + @property + def backward_signature(self) -> ExportBackwardSignature | None: + loss_output = None + gradients_to_parameters: dict[str, str] = {} + gradients_to_user_inputs: dict[str, str] = {} + for spec in self.output_specs: + if spec.kind == OutputKind.LOSS_OUTPUT: + assert loss_output is None + assert isinstance(spec.arg, TensorArgument) + loss_output = spec.arg.name + elif spec.kind == OutputKind.GRADIENT_TO_PARAMETER: + assert isinstance(spec.target, str) + assert isinstance(spec.arg, TensorArgument) + gradients_to_parameters[spec.arg.name] = spec.target + 
elif spec.kind == OutputKind.GRADIENT_TO_USER_INPUT: + assert isinstance(spec.target, str) + assert isinstance(spec.arg, TensorArgument) + gradients_to_user_inputs[spec.arg.name] = spec.target + + if loss_output is None: + return None + + return ExportBackwardSignature( + loss_output=loss_output, + gradients_to_parameters=gradients_to_parameters, + gradients_to_user_inputs=gradients_to_user_inputs, + ) + + # Map from assertion dependency token index to assertion dep token output + # name in output. The shape of output after aot_autograd will be like: + # (updated_inputs, user_outputs, dep_token). + @property + def assertion_dep_token(self) -> Mapping[int, str] | None: + return None + + @property + def input_tokens(self) -> Collection[str]: + input_tokens = [] + for s in self.input_specs: + if s.kind == InputKind.TOKEN: + assert isinstance(s.arg, TokenArgument) + input_tokens.append(s.arg.name) + return tuple(input_tokens) + + @property + def output_tokens(self) -> Collection[str]: + output_tokens = [] + for s in self.output_specs: + if s.kind == OutputKind.TOKEN: + assert isinstance(s.arg, TokenArgument) + output_tokens.append(s.arg.name) + return tuple(output_tokens) + + def __post_init__(self) -> None: + assertion_dep_token = self.assertion_dep_token + if assertion_dep_token is None: + return + assert len(assertion_dep_token) == 1 + assertion_dep_token_index = next(iter(assertion_dep_token.keys())) + assert ( + len(self.user_outputs) + len(self.buffers_to_mutate) + == assertion_dep_token_index + ) + + def replace_all_uses(self, old: str, new: str): + """ + Replace all uses of the old name with new name in the signature. 
+ """ + assert isinstance(old, str) + assert isinstance(new, str) + arg_types = ( + TensorArgument, + SymIntArgument, + SymFloatArgument, + SymBoolArgument, + CustomObjArgument, + TokenArgument, + ) + for o in self.output_specs: + if isinstance(o.arg, arg_types): + if o.arg.name == old: + o.arg.name = new + for i in self.input_specs: + if isinstance(i.arg, arg_types): + if i.arg.name == old: + i.arg.name = new + + def get_replace_hook(self, replace_inputs=False): + def _(old, new, user): + if user.op == "output": + self.replace_all_uses(old.name, new) + if replace_inputs and old.op == "placeholder": + self.replace_all_uses(old.name, new) + + return _ + + def __str__(self): + input_specs = "\n".join(str(s) for s in self.input_specs) + output_specs = "\n".join(str(s) for s in self.output_specs) + return f"\n# inputs\n{input_specs}\n\n# outputs\n{output_specs}\n" + + +def _immutable_dict(items): + """ + Creates a mapping where items cannot be added, deleted, or updated. + NOTE: The immutability is shallow (like tuple is an immutable collection). 
+ """ + from types import MappingProxyType + + return MappingProxyType(dict(items)) + + +def _make_argument_spec(node, token_names) -> ArgumentSpec: + from torch import ScriptObject, SymBool, SymFloat, SymInt + from torch._library.fake_class_registry import FakeScriptObject + + if isinstance(node, (int, bool, float, type(None), str)): + # For const outputs we just directly return this + return ConstantArgument(name="", value=node) + + assert "val" in node.meta, ( + f"{node} is not a constant or a node with a 'val' metadata field" + ) + val = node.meta["val"] + if node.name in token_names: + return TokenArgument(name=node.name) + elif is_fake(val): + return TensorArgument(name=node.name) + elif isinstance(val, SymInt): + return SymIntArgument(name=node.name) + elif isinstance(val, SymFloat): + return SymFloatArgument(name=node.name) + elif isinstance(val, SymBool): + return SymBoolArgument(name=node.name) + elif isinstance(val, ScriptObject): + return CustomObjArgument(name=node.name, class_fqn=val._type().qualified_name()) # type: ignore[attr-defined] + elif isinstance(val, FakeScriptObject): + return CustomObjArgument( + name=node.name, class_fqn=val.script_class_name, fake_val=val + ) + elif is_opaque_type(type(val)): + return CustomObjArgument( + name=node.name, class_fqn=get_opaque_type_name(type(val)), fake_val=val + ) + elif isinstance(val, (int, bool, str, float, type(None))): + return ConstantArgument(name=node.name, value=val) + else: + raise AssertionError( + f"Encountered an unsupported object of type {type(val)} " + f"while writing the metadata for exported program" + ) + + +def _convert_to_export_graph_signature( + graph_signature: "GraphSignature", + gm: "torch.fx.GraphModule", + non_persistent_buffers: set[str], +) -> "ExportGraphSignature": + from torch.utils import _pytree as pytree + + is_joint = graph_signature.backward_signature is not None + + # unpack objects + user_inputs = set(graph_signature.user_inputs) + inputs_to_parameters = 
graph_signature.inputs_to_parameters + inputs_to_buffers = graph_signature.inputs_to_buffers + user_outputs = set(graph_signature.user_outputs) + buffer_mutations = graph_signature.buffers_to_mutate + parameter_mutations = graph_signature.parameters_to_mutate + user_input_mutations = graph_signature.user_inputs_to_mutate + grad_params = ( + graph_signature.backward_signature.gradients_to_parameter # type: ignore[union-attr] + if is_joint + else {} + ) + grad_user_inputs = ( + graph_signature.backward_signature.gradients_to_user_inputs # type: ignore[union-attr] + if is_joint + else {} + ) + loss_output = ( + graph_signature.backward_signature.loss_output # type: ignore[union-attr] + if is_joint + else None + ) + input_tokens = graph_signature.input_tokens + output_tokens = graph_signature.output_tokens + + inputs = [ + _make_argument_spec(node, input_tokens) + for node in gm.graph.nodes + if node.op == "placeholder" + ] + outputs = [ + _make_argument_spec(node, output_tokens) + for node in pytree.tree_leaves(next(iter(reversed(gm.graph.nodes))).args) + ] + + def to_input_spec(inp: ArgumentSpec) -> InputSpec: + if isinstance(inp, TokenArgument): + return InputSpec(kind=InputKind.TOKEN, arg=inp, target=None) + + if not isinstance(inp, TensorArgument): + return InputSpec(kind=InputKind.USER_INPUT, arg=inp, target=None) + name = inp.name + if name in user_inputs: + return InputSpec(kind=InputKind.USER_INPUT, arg=inp, target=None) + elif name in inputs_to_parameters: + return InputSpec( + kind=InputKind.PARAMETER, + arg=inp, + target=inputs_to_parameters[name], # type: ignore[index] + ) + elif name in inputs_to_buffers: + return InputSpec( + kind=InputKind.BUFFER, + arg=inp, + target=inputs_to_buffers[name], # type: ignore[index] + persistent=(inputs_to_buffers[name] not in non_persistent_buffers), # type: ignore[index] + ) + else: + raise AssertionError(f"Unknown tensor input kind: {name}") + + def to_output_spec(idx: int, o: ArgumentSpec) -> OutputSpec: + if 
isinstance(o, TokenArgument): + return OutputSpec(kind=OutputKind.TOKEN, arg=o, target=None) + + if not isinstance(o, TensorArgument): + return OutputSpec(kind=OutputKind.USER_OUTPUT, arg=o, target=None) + name = o.name + if idx < len(buffer_mutations) + len(parameter_mutations) + len( + user_input_mutations + ) + len(output_tokens): + if name in buffer_mutations: + return OutputSpec( + kind=OutputKind.BUFFER_MUTATION, + arg=o, + target=buffer_mutations[name], # type: ignore[index] + ) + elif name in parameter_mutations: + return OutputSpec( + kind=OutputKind.PARAMETER_MUTATION, + arg=o, + target=parameter_mutations[name], # type: ignore[index] + ) + elif name in user_input_mutations: + return OutputSpec( + kind=OutputKind.USER_INPUT_MUTATION, + arg=o, + target=user_input_mutations[name], # type: ignore[index] + ) + else: + raise AssertionError(f"Unknown tensor mutation kind: {name}") + else: + if name in user_outputs: + return OutputSpec(kind=OutputKind.USER_OUTPUT, arg=o, target=None) + + elif name in grad_params: + return OutputSpec( + kind=OutputKind.GRADIENT_TO_PARAMETER, + arg=o, + target=grad_params[name], + ) + elif name in grad_user_inputs: + return OutputSpec( + kind=OutputKind.GRADIENT_TO_USER_INPUT, + arg=o, + target=grad_user_inputs[name], + ) + elif name == loss_output: + return OutputSpec(kind=OutputKind.LOSS_OUTPUT, arg=o, target=None) + + else: + raise AssertionError(f"Unknown tensor output kind: {name}") + + input_specs = [to_input_spec(inp) for inp in inputs] + output_specs = [to_output_spec(idx, o) for idx, o in enumerate(outputs)] + return ExportGraphSignature(input_specs=input_specs, output_specs=output_specs) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/unflatten.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/unflatten.py new file mode 100644 index 0000000000000000000000000000000000000000..680b0a907512971da77e505818989035929d20aa --- /dev/null +++ 
b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/export/unflatten.py @@ -0,0 +1,1803 @@ +# mypy: allow-untyped-defs +import abc +import copy +import logging +import operator +import re +from collections import defaultdict +from collections.abc import Callable +from contextlib import contextmanager +from copy import deepcopy +from dataclasses import dataclass +from enum import Enum +from typing import Any, cast + +import torch +import torch.fx._pytree as fx_pytree +import torch.utils._pytree as pytree +from torch._library.fake_class_registry import FakeScriptObject +from torch.export import ExportedProgram +from torch.export._tree_utils import reorder_kwargs +from torch.export.exported_program import ( + ConstantArgument, + ExportGraphSignature, + InputKind, + ModuleCallSignature, + SymBoolArgument, + SymFloatArgument, + SymIntArgument, + TensorArgument, +) +from torch.fx._symbolic_trace import is_fx_symbolic_tracing +from torch.fx.graph_module import _get_attr, _get_attr_via_attr_list, _print_readable +from torch.utils._pytree import GetAttrKey, SequenceKey + +from ._remove_effect_tokens_pass import _remove_effect_tokens + + +log = logging.getLogger(__name__) + + +__all__ = [ + "FlatArgsAdapter", + "InterpreterModule", + "InterpreterModuleDispatcher", + "UnflattenedModule", + "unflatten", +] + + +class _AttrKind(Enum): + PARAMETER = "parameter" + BUFFER = "buffer" + CONSTANT = "constant" + MODULE = "module" + + +@dataclass(frozen=True) +class _TensorID: + """Custom tensor identifier containing storage, stride, and size information.""" + + untyped_storage: torch.UntypedStorage + stride: tuple + size: tuple + storage_offset: int + + +RUN_WITH_INTERPRETER = True + + +@contextmanager +def _disable_interpreter(): + global RUN_WITH_INTERPRETER + old_flag = RUN_WITH_INTERPRETER + RUN_WITH_INTERPRETER = False + try: + yield + finally: + RUN_WITH_INTERPRETER = old_flag + + +# Assign attribute 'from_obj' to the qualified name 'target' on 'to_module +# This 
installs empty Modules where none exist yet if they are subpaths of target +def _assign_attr( + from_obj: torch.Tensor | torch.ScriptObject | torch.nn.Module, + to_module: torch.nn.Module, + target: str, + attr_kind: _AttrKind, + persistent: bool = True, +): + *prefix, field = target.split(".") + # We need to generate all submodules of `to_module` that are at `prefix` and + # variants of `prefix` that differ only by call name. All of these submodules + # will then be assigned `from_obj` at `field` so that they can share this attribute. + # For example, if target is foo.bar.f, foo has another call name foo@1, + # and bar has other call names bar@1, bar@2, then we will assign f to + # foo.bar, foo.bar@1, foo.bar@2, foo@1.bar, foo@1.bar@1, foo@1.bar@2. + to_modules = {to_module} + for item in prefix: + ts: set[torch.nn.Module] = set() + for to_module in to_modules: + if not hasattr(to_module, item): + setattr(to_module, item, torch.nn.Module()) + ts.update( + t_call # type: ignore[misc] + for k, t_call in to_module._modules.items() + if _is_call_name(k, item) + ) + to_modules = ts + + for to_module in to_modules: + if attr_kind == _AttrKind.PARAMETER: + assert isinstance(from_obj, torch.nn.Parameter) + to_module.register_parameter(field, from_obj) + elif attr_kind == _AttrKind.BUFFER: + assert isinstance(from_obj, torch.Tensor) + to_module.register_buffer(field, from_obj, persistent=persistent) + elif attr_kind == _AttrKind.CONSTANT: + assert not isinstance(from_obj, FakeScriptObject), ( + "FakeScriptObject should only exist during tracing." 
+ ) + assert isinstance( + from_obj, + ( + torch.Tensor, + torch.ScriptObject, + ), + ) + setattr(to_module, field, from_obj) + elif attr_kind == _AttrKind.MODULE: + assert isinstance(from_obj, torch.nn.Module) + setattr(to_module, field, from_obj) + + +class _SubmoduleBase: + _ty: str | None + + def type_name(self) -> str | None: + """ + Subclass of this class - InterpreterModule, InterpreterModuleDispatcher, represents + corresponding model in eager model. To get this type information for those modules + in eager model we need to use this method. + """ + return self._ty + + +class InterpreterModule(_SubmoduleBase, torch.nn.Module): + """A module that uses torch.fx.Interpreter to execute instead of the usual + codegen that GraphModule uses. This provides better stack trace information + and makes it easier to debug execution. + """ + + graph_module: torch.fx.GraphModule | None + + def __init__( + self, + graph: torch.fx.Graph, + ty: str | None = None, + ): + super().__init__() + self.graph = graph + self._ty = ty + self.graph.owning_module = self # type: ignore[assignment] + self._run_with_interpreter = RUN_WITH_INTERPRETER + + def forward(self, *args, **kwargs): + assert self.graph_module is not None, "Didn't finalize this InterpreterModule" + if not is_fx_symbolic_tracing() and ( + torch.compiler.is_dynamo_compiling() or not self._run_with_interpreter + ): + # Dynamo cannot trace through torch.fx.Interpreter, so fall back to + # GraphModule codegen in this instance. + # Patch the codegened forward to run with this InterpreterModule, + # so attribute accesses, etc. are on this module instead. + return type(self.graph_module).forward(self, *args, **kwargs) + else: + if kwargs: + # Handle **kwargs. FX only natively supports positional + # arguments (through placeholders). So in order to pass in + # kwargs, we must correspond the names of the placeholders with + # the keys in the kwarg dict. 
+ arg_list = list(args) + kwarg_names = self.arg_names[len(arg_list) :] + arg_list.extend( + kwargs[kwarg_name] + for kwarg_name in kwarg_names + if kwarg_name in kwargs + ) + + # Assert that the kwargs passed in exactly match the positional + # arguments specified by the GraphModule. This should be + # guaranteed by the unflattening process. + assert len(kwarg_names) == len(kwargs) + assert len(arg_list) == len(self.arg_names) + args = tuple(arg_list) + + return torch.fx.Interpreter(self, graph=self.graph).run( + *args, enable_io_processing=False + ) + + def finalize(self): + # We need to "finalize" because GraphModule populates its own state_dict + # based on the get_attrs observed in the graph. So we need to fully + # construct the graph and call _sink_params before generating this + # GraphModule. + + # need to set `graph_module` directly on the dict to avoid it getting + # registered as a submodule. + self.__dict__["graph_module"] = torch.fx.GraphModule(self, self.graph) + self.graph.lint() + + # Cache arg names for kwarg handling (see forward()) + self.arg_names = [] + for node in self.graph.nodes: + if node.op == "placeholder": + self.arg_names.append(node.target) + + def print_readable( + self, + print_output=True, + include_stride=False, + include_device=False, + colored=False, + ): + return _print_readable( + self, + "InterpreterModule", + print_output, + include_stride, + include_device, + colored, + ) + + +class InterpreterModuleDispatcher(_SubmoduleBase, torch.nn.Module): + """ + A module that carries a sequence of InterpreterModules corresponding to + a sequence of calls of that module. Each call to the module dispatches + to the next InterpreterModule, and wraps back around after the last. 
+ """ + + def __init__(self, attrs: set[str], call_modules: list[InterpreterModule]): + super().__init__() + assert call_modules + self._modules = call_modules[0]._modules + for accessor in attrs: + setattr(self, accessor, getattr(call_modules[0], accessor)) + self._ty = call_modules[0]._ty + self._call_modules = call_modules + self._num_calls = 0 + + def forward(self, *args, **kwargs): + call_module = self._call_modules[self._num_calls] + self._num_calls = (self._num_calls + 1) % len(self._call_modules) + try: + return call_module(*args, **kwargs) + except Exception: + self._num_calls = 0 + raise + + def call_modules(self): + return self._call_modules + + def print_readable( + self, + print_output=True, + include_stride=False, + include_device=False, + colored=False, + ): + outputs = [ + mod.print_readable( + print_output, + include_stride, + include_device, + colored, + ) + for mod in self._call_modules + ] + return "\n".join(outputs) + + +class FlatArgsAdapter(abc.ABC): + """ + Adapts input arguments with ``input_spec`` to align ``target_spec``. + """ + + @abc.abstractmethod + def adapt( + self, + target_spec: pytree.TreeSpec, + input_spec: pytree.TreeSpec, + input_args: list[Any], + metadata: dict[str, Any] | None = None, + obj: Any | None = None, + ) -> list[Any]: + """NOTE: This adapter may mutate given ``input_args_with_path``.""" + ... 
+ + def get_flat_arg_paths(self) -> list[str]: + """Returns a list of paths that are used to access the flat args.""" + return [] + + +class UnflattenedModule(_SubmoduleBase, torch.nn.Module): + def __init__( + self, + export_module: ExportedProgram, + flat_args_adapter: FlatArgsAdapter | None = None, + ): + super().__init__() + if export_module.graph_signature.backward_signature is not None: + raise ValueError("Unflattening on JointExportModule NYI") + + def _id(obj): + """Returns _TensorID dataclass for tensors, otherwise id().""" + if isinstance(obj, torch.Tensor): + return _TensorID( + untyped_storage=obj.untyped_storage(), + stride=obj.stride(), + size=obj.size(), + storage_offset=obj.storage_offset(), # type: ignore[arg-type] + ) + return id(obj) + + fqn_list = [entry.fqn for entry in export_module.module_call_graph] + assert fqn_list[0] == "" + export_graph = deepcopy(export_module.graph) + self.graph_signature = deepcopy(export_module.graph_signature) + self.graph = torch.fx.Graph() + self.graph.owning_module = self # type: ignore[assignment] + self.module_call_graph = deepcopy(export_module.module_call_graph) + self.flat_args_adapter = flat_args_adapter + + self.meta = export_module.graph_module.meta + self.meta["unflattened_module"] = self + + # Flag to indicate whether args have been adapted. 
+ self.adapted = False + self._run_with_interpreter = RUN_WITH_INTERPRETER + + _inplace_buffer_and_input_mutations(export_graph, self.graph_signature) + _fix_nn_module_stacks(export_graph) + self._ty = _root_module_type(export_graph) + + self.ivals = _IVals() + # for any intermediate value of a mutation that is read, track the mutation + seen_modules, seen_attrs = _outline_submodules(export_graph, self) + # for each read intermediate value of a mutation, find where it was created, + # and perform the mutation + self.ivals.update(seen_modules.values()) + # move attributes that correspond to graph arguments for HOPs + # from exported program to unflattened submodules + _copy_graph_attrs(export_module._graph_module, self, seen_attrs) + + self.range_constraints = export_module.range_constraints + self.equality_constraints: list = [] + + # aliasing/unused param or buffer issues: + # in strict-mode export, dynamo export will deduplicate aliased tensors, + # and ignore unused tensors. For aliasing, this causes issues when some aliases + # are unused, and we're unable to match the placeholder node to the correct FQN. + # This leads to the graph signature potentially having the wrong target FQN, + # and downstream issues where parameters are assigned to the wrong target attribute, + # mismatching the relevant placeholder node in the unflattened module. + # To resolve this we restore (_assign_attr) all aliased/unused tensors in + # the state_dict as module attributes, but only keep the used tensors in the + # graph's forward pass (_sink_params). 
+ state_dict = export_module.state_dict + assigned_params: set[str] = set() # tracking unused params + id_to_param: dict[ + int | _TensorID, torch.nn.Parameter + ] = {} # handling weight-sharing + for name in self.graph_signature.parameters: # this loop adds used params + param = state_dict[name] + if _id(param) not in id_to_param: + id_to_param[_id(param)] = torch.nn.Parameter( + param.clone(), requires_grad=param.requires_grad + ) + + _assign_attr( + id_to_param[_id(param)], + self, + name, + attr_kind=_AttrKind.PARAMETER, + ) + assigned_params.add(name) + + non_persistent_buffers = set(self.graph_signature.non_persistent_buffers) + assigned_buffers: set[str] = set() # tracking unused buffers + id_to_buffer: dict[int | _TensorID, tuple[torch.nn.Parameter, bool]] = {} + for name in self.graph_signature.buffers: # this loop adds used buffers + if name in non_persistent_buffers: + persistent = False + buffer = export_module.constants[name] + else: + persistent = True + buffer = state_dict[name] + + if _id(buffer) not in id_to_buffer: + id_to_buffer[_id(buffer)] = (buffer.clone(), persistent) + + _assign_attr( + id_to_buffer[_id(buffer)][0], + self, + name, + attr_kind=_AttrKind.BUFFER, + persistent=persistent, + ) + assigned_buffers.add(name) + + # restore aliased/unused params and buffers + # these appear in state dict but not graph signature + for name, tensor in state_dict.items(): + if name in assigned_params or name in assigned_buffers: # already assigned + continue + + is_buffer = False + if _id(tensor) in id_to_buffer or not isinstance( + tensor, torch.nn.Parameter + ): # aliased buffer + is_buffer = True + + if is_buffer: + if ( + _id(tensor) not in id_to_buffer + ): # this is completely unused (not weight-sharing) + id_to_buffer[_id(tensor)] = ( + tensor, + True, + ) # assign to respect original model + _assign_attr( + id_to_buffer[_id(tensor)][0], + self, + name, + attr_kind=_AttrKind.BUFFER, + persistent=True, + ) + else: + if _id(tensor) not in 
id_to_param: # this is unused + id_to_param[_id(tensor)] = tensor + _assign_attr( + id_to_param[_id(tensor)], + self, + name, + attr_kind=_AttrKind.PARAMETER, + ) + + # use id map so we don't double-clone aliased constants + id_to_const: dict[int | _TensorID, torch.Tensor | torch._C.ScriptObject] = {} + for fqn, constant in export_module.constants.items(): + if _id(constant) not in id_to_const: + if isinstance(constant, torch.Tensor): + constant = constant.clone() + id_to_const[_id(constant)] = constant + _constant = id_to_const[_id(constant)] + _assign_attr( + _constant, + self, + fqn, + attr_kind=_AttrKind.CONSTANT, + ) + + # This is to handle parameters/buffers that point to the same tensor + # object id -> list of (node_name, target_name) + consts_map: dict[int | _TensorID, list[tuple[str, str]]] = defaultdict(list) + consts_targets: set[str] = set() + + def add_to_consts_map(obj_id, node_name, target_name): + name_list = consts_map[obj_id] + name_list.append((node_name, target_name)) + + # track aliased/unused params, buffers + # prefer using untyped_storage() over id() when it's available + added_params_buffers: set[str] = set() + for s in self.graph_signature.input_specs: + if s.kind == InputKind.PARAMETER or ( + s.kind == InputKind.BUFFER and s.persistent + ): + assert hasattr(s.arg, "name") + assert isinstance(s.target, str) + add_to_consts_map( + _id(export_module.state_dict[s.target]), + s.arg.name, + s.target, + ) + consts_targets.add(s.target) + added_params_buffers.add(s.target) + elif ( + s.kind == InputKind.BUFFER + and not s.persistent + or s.kind == InputKind.CONSTANT_TENSOR + or s.kind == InputKind.CUSTOM_OBJ + ): + assert hasattr(s.arg, "name") + assert isinstance(s.target, str) + add_to_consts_map( + _id(export_module.constants[s.target]), + s.arg.name, + s.target, + ) + consts_targets.add(s.target) + + # add constants that are aliased and don't appear in graph signature + for const_name, const in export_module.constants.items(): + if 
const_name not in consts_targets: + const_id = _id(const) + assert const_id in consts_map + ph_name, _ = consts_map[const_id][0] + add_to_consts_map(const_id, ph_name, const_name) + added_params_buffers.add(s.target) + + # add aliased/unused params and buffers that don't appear in graph signature + for fqn, tensor in export_module.state_dict.items(): + if fqn not in added_params_buffers: + tensor_id = _id(tensor) + if tensor_id not in consts_map: + # completely unused (no weight-sharing), ignore. + # this weight doesn't appear in graph module, + # so won't cause FQN assignment issues + continue + ph_name, _ = consts_map[tensor_id][0] + add_to_consts_map(tensor_id, ph_name, fqn) + + # node name -> list of possible targets + inputs_to_state: dict[str, list[str]] = {} + for node_target in consts_map.values(): + targets = [t[1] for t in node_target] + for n, _ in node_target: + inputs_to_state[n] = targets + + _sink_params(self, inputs_to_state, []) + + redirected_call_indices = _deduplicate_modules(seen_modules.values()) + fqn_list = [fqn for fqn in fqn_list if fqn not in redirected_call_indices] + + self._dispatch_modules(redirected_call_indices, consts_targets) + fqn_list = [fqn for fqn in fqn_list if "@" not in fqn] + + # Cache so we don't have to compute this every time. + # NOTE: this needs to be kept in sync with the placeholders in + # self.graph, but currently we have no way to guarantee that. + self.input_placeholders = [ + node for node in self.graph.nodes if node.op == "placeholder" + ] + self.check_input_constraints = True + # TODO(zhxchen17) We can register modules ahead of time instead of reorder later. + fqn_order = {fqn: i for i, fqn in enumerate(fqn_list)} + # In the case of legacy IR, we might be missing some modules from metadata. 
+ for name, _ in self.named_modules(remove_duplicate=False): + if name not in fqn_order: + fqn_order[name] = len(fqn_order) + _reorder_submodules(self, fqn_order) + self.graph.lint() + self.finalize() + + def _print_graph(self): + for fqn, mod in self.named_modules(): + print(fqn + ":") + if hasattr(mod, "graph") and isinstance(mod.graph, torch.fx.Graph): + print(mod.graph) + + def _adapt_flat_args(self, flat_args, in_spec, input): + signature = self.module_call_graph[0].signature + if in_spec == signature.in_spec: + return flat_args + + if self.flat_args_adapter is None: + raise TypeError( + "There is no flat args adapter specified. " + "Are you sure you are calling this with the right arguments? " + ) + else: + flat_args = self.flat_args_adapter.adapt( + target_spec=signature.in_spec, + input_spec=in_spec, + input_args=flat_args, + metadata=self.meta, + obj=input, + ) + + if len(flat_args) != signature.in_spec.num_leaves: + raise TypeError( + f"Flat args adaption failed, number of args mismatch " + f"Adatped: {len(flat_args)} \n" + f"Exported module: {signature.in_spec.num_leaves}" + ) + return flat_args + + def process_forward_inputs(self, *args, **kwargs): + signature = self.module_call_graph[0].signature + + reordered_kwargs = kwargs + if kwargs: + reordered_kwargs = reorder_kwargs(kwargs, signature.in_spec) + + flat_args_with_path, in_spec = pytree.tree_flatten_with_path( + (args, reordered_kwargs) + ) + flat_args = [x[1] for x in flat_args_with_path] + + if is_fx_symbolic_tracing(): + return flat_args + + if in_spec != signature.in_spec: + if not self.adapted: + print( + "Input treespec does not match with exported module's: \n" + f"Input treespec: {in_spec}. 
", + f"Exported module treespec: {signature.in_spec}", + ) + print("Adapting flat arg to match exported module's treespec") + flat_args = self._adapt_flat_args(flat_args, in_spec, args) + self.adapted = True + + if self.check_input_constraints: + # Import here to avoid an unfortunate circular dependency. + # TODO(suo): untangle this. + from torch._export.utils import _check_input_constraints_for_graph + + if self.adapted is True: + flat_arg_paths = ( + self.flat_args_adapter.get_flat_arg_paths() + if self.flat_args_adapter + else [] + ) + assert not flat_arg_paths or len(flat_arg_paths) == len(flat_args) + new_flat_args_with_path = [ # type: ignore[var-annotated] + ( + ( + SequenceKey(idx=idx), + GetAttrKey( + name=flat_arg_paths[idx] + if flat_arg_paths + else "" + ), + ), + arg, + ) + for idx, arg in enumerate(flat_args) + ] + else: + new_flat_args_with_path = flat_args_with_path # type: ignore[assignment] + + _check_input_constraints_for_graph( + self.input_placeholders, new_flat_args_with_path, self.range_constraints + ) + + return flat_args + + def forward(self, *args, **kwargs): + flat_args = self.process_forward_inputs(*args, **kwargs) + signature = self.module_call_graph[0].signature + + if is_fx_symbolic_tracing(): + return_val = torch.fx.Interpreter(self, graph=self.graph).run( + *flat_args, enable_io_processing=False + ) + # For scalar return value, fx.Graph wraps in a tuple + if isinstance(return_val, tuple) and len(return_val) == 1: + return return_val[0] + return return_val + + if torch.compiler.is_dynamo_compiling() or not self._run_with_interpreter: + tree_out = type(self.graph_module).forward(self, *flat_args) # type: ignore[union-attr] + else: + tree_out = torch.fx.Interpreter(self, graph=self.graph).run( + *flat_args, enable_io_processing=False + ) + return pytree.tree_unflatten(tree_out, signature.out_spec) + + def finalize(self): + self.__dict__["graph_module"] = torch.fx.GraphModule(self, self.graph) + self.graph.lint() + + def 
_dispatch_modules(self, redirected_call_indices, consts_targets): + """For a module whose call signatures are preserved, replace + multiple modules corresponding to multiple calls to that module + with a single dispatcher module that tracks which module to call. + """ + + # for each fqn whose module call signature is preserved, + # map that fqn to a list of called modules + called_modules = defaultdict(list) + for entry in self.module_call_graph: + if entry.fqn and entry.signature: + # some modules were removed and their fqns redirected to other + # fqns during deduplication + fqn = entry.fqn + mod = _get_attr(self, redirected_call_indices.get(fqn, fqn)) + base, idx = fqn.split("@") if "@" in fqn else [fqn, "0"] + called_modules[base].append((int(idx), mod)) + + attrs_map = defaultdict(set) + for target in consts_targets: + if "." in target: + orig_fqn, name = target.rsplit(".", 1) + attrs_map[orig_fqn].add(name) + else: + attrs_map[""].add(target) + + # replace multiple call modules with a single dispatcher module + for orig_fqn, indexed_call_modules in called_modules.items(): + call_modules = [mod for _, mod in sorted(indexed_call_modules)] + if len(call_modules) > 1: + for i in range(len(call_modules)): + fqn = _call_name(orig_fqn, i + 1) + if fqn not in redirected_call_indices: + *prefix, name = fqn.split(".") + _get_attr_via_attr_list(self, prefix)._modules.pop(name) + self.set_submodule( + orig_fqn, + InterpreterModuleDispatcher(attrs_map[orig_fqn], call_modules), + ) + + # elide call indices in call modules because they are + # tracked automatically inside the dispatcher module + def elide_call_indices(prefix, graph): + for node in graph.nodes: + if node.op == "call_module": + fqn = node.target.split("@")[0] + path = f"{prefix}.{fqn}" if prefix else fqn + if path in called_modules: + node.target = fqn + + for fqn, mod in self.named_modules(remove_duplicate=False): + if hasattr(mod, "graph"): + elide_call_indices(fqn, mod.graph) + elif hasattr(mod, 
"_call_modules"): + for mod_ in mod._call_modules: + assert hasattr(mod_, "graph") + elide_call_indices(fqn, mod_.graph) + + def print_readable( + self, + print_output=True, + include_stride=False, + include_device=False, + colored=False, + ): + return _print_readable( + self, + "UnflattenedModule", + print_output, + include_stride, + include_device, + colored, + ) + + +def unflatten( + module: ExportedProgram, flat_args_adapter: FlatArgsAdapter | None = None +) -> UnflattenedModule: + """Unflatten an ExportedProgram, producing a module with the same module + hierarchy as the original eager module. This can be useful if you are trying + to use :mod:`torch.export` with another system that expects a module + hierarchy instead of the flat graph that :mod:`torch.export` usually produces. + + .. note:: The args/kwargs of unflattened modules will not necessarily match + the eager module, so doing a module swap (e.g. :code:`self.submod = + new_mod`) will not necessarily work. If you need to swap a module out, you + need to set the :code:`preserve_module_call_signature` parameter of + :func:`torch.export.export`. + + Args: + module (ExportedProgram): The ExportedProgram to unflatten. + flat_args_adapter (Optional[FlatArgsAdapter]): Adapt flat args if input TreeSpec does not match with exported module's. + + Returns: + An instance of :class:`UnflattenedModule`, which has the same module + hierarchy as the original eager module pre-export. + """ + module = _remove_effect_tokens(module) + m = UnflattenedModule(module, flat_args_adapter) + + # Disable process_forward_inputs as the adapter has many + # non-dynamo-traceable behavior. 
+ m.process_forward_inputs = torch._dynamo.disable( # type: ignore[method-assign] + m.process_forward_inputs, + reason="do not trace into preprocessing the inputs", + recursive=True, + ) + + return m + + +def _inplace_buffer_and_input_mutations( + graph: torch.fx.Graph, + graph_signature: ExportGraphSignature, +) -> None: + """Transform buffer and input mutations from their functionalized form + into copy_ nodes in the graph. + + Functionalization represents a buffer mutation by passing the buffer as + an input and output. For example, consider the eager code: + def forward(self, x): + self.buffer += x + return x * x + + This corresponds to a graph that looks like: + def forward(self, buffer, x): + mutated_buffer = aten.add(buffer, x) + mul = aten.mul(x, x) + return (mutated_buffer, mul) + + We want to inplace this into something that looks like the original + eager code: + def forward(self, buffer, x): + mutated_buffer = aten.add(buffer, x) + buffer.copy_(mutated_buffer) + mul = aten.mul(x, x) + return (mul,) + + Input mutations are handled similarly. + """ + output_node = next(iter(reversed(graph.nodes))) + assert output_node.op == "output" and len(output_node.args) == 1 + return_args = output_node.args[0] + + input_name_to_node = { + node.name: node for node in graph.nodes if node.op == "placeholder" + } + mutation_name_to_input_name = {} + + # Collect mutated buffers. + buffer_fqn_to_input_name = { + buffer_fqn: k for k, buffer_fqn in graph_signature.inputs_to_buffers.items() + } + mutation_name_to_input_name = { + k: buffer_fqn_to_input_name[buffer_fqn] + for k, buffer_fqn in graph_signature.buffers_to_mutate.items() + } + # Collect mutated user inputs. 
+ mutation_name_to_input_name.update(graph_signature.user_inputs_to_mutate) + + num_mutations = len(mutation_name_to_input_name) + + for mutation in return_args[:num_mutations]: + input_name = mutation_name_to_input_name[mutation.name] + input_node = input_name_to_node[input_name] + + with graph.inserting_after(mutation): + # Create a copy_ node that inplaces the mutation. + new_node = graph.create_node( + "call_function", torch.ops.aten.copy_.default, (input_node, mutation) + ) + for k, v in mutation.meta.items(): + new_node.meta[k] = v + # Replace all uses of the previously functional mutation with + # our copy_ node. + mutation.replace_all_uses_with(new_node, lambda x: x is not new_node) + + # Remove the mutated buffer / input from the graph outputs, since we don't + # need to thread it through anymore. + user_outputs = tuple(return_args[num_mutations:]) + output_node.args = ((user_outputs),) + + +def _root_module_type(graph: torch.fx.Graph) -> str | None: + for node in graph.nodes: + if "nn_module_stack" not in node.meta: + continue + + for path, ty in node.meta["nn_module_stack"].values(): + if not path: + return ty + return None + + +def _fix_nn_module_stacks(graph): + # For each nn module stack in the graph, check if the fqns in it represent a stack: + # 1. Each fqn must be a prefix of the next fqn. + # 2. If not, remove the entries starting from the next fqn, emitting a warning. 
+ for node in graph.nodes: + if "nn_module_stack" not in node.meta: + continue + + nn_module_stack = node.meta["nn_module_stack"] + fqns = [ + fqn.split("@")[0] if "@" in fqn else fqn + for fqn, _t in nn_module_stack.values() + ] + + # Check if each FQN is a prefix of the next one + prev_fqn, *next_fqns = fqns + num_valid_indices = 1 # root FQN + for curr_fqn in next_fqns: + # Check if the previous FQN is a prefix of the current one + if _is_prefix(prev_fqn, curr_fqn): + num_valid_indices += 1 + prev_fqn = curr_fqn + else: + # Found a non-prefix FQN, stop here + break + + # If we need to remove entries, create a new stack with only valid entries + if num_valid_indices < len(nn_module_stack): + log.warning( + "nn_module_stack fqns %s at node %s do not form a stack! dropping last %d entries", + fqns, + node, + len(nn_module_stack) - num_valid_indices, + ) + node.meta["nn_module_stack"] = dict( + list(nn_module_stack.items())[:num_valid_indices] + ) + + +def _is_prefix(candidate, target): + """Check whether `candidate` is a prefix of `target`.""" + return len(candidate) < len(target) and target[: len(candidate)] == candidate + + +def _compute_accessor(parent_fqn: str, child_fqn: str) -> str: + if parent_fqn == "": + # Handle the root module correctly. + return child_fqn + + parent_split = parent_fqn.split(".") + child_split = child_fqn.split(".") + + # TODO: support skip connection by inlining the child module. + if child_split[: len(parent_split)] != parent_split: + raise RuntimeError( + f"Child module '{child_fqn}' is not a descendant of parent module '{parent_fqn}'." + "This is currently unsupported." + "Please try to make child module attach to parent module directly." 
+ ) + return ".".join(child_split[len(parent_split) :]) + + +def _check_graph_equivalence(x: torch.nn.Module, y: torch.nn.Module): + def graph_dump(graph: torch.fx.Graph) -> str: + ret = [] + nodes_idx: dict[int, int] = {} + + def arg_dump(arg) -> str: + if isinstance(arg, torch.fx.Node): + return "%" + str(nodes_idx[id(arg)]) + return str(arg) + + for i, node in enumerate(graph.nodes): + args_dump = [str(arg) for arg in pytree.tree_map(arg_dump, node.args)] + args_dump += [ + f"{key}={value}" + for key, value in pytree.tree_map(arg_dump, node.kwargs).items() + ] + target = node.target if node.op in ("call_function", "get_attr") else "" + # pyrefly: ignore [bad-argument-type] + ret.append(f"{i}: {node.op}[{target}]({', '.join(args_dump)})") + nodes_idx[id(node)] = i + return "\n".join(ret) + + assert isinstance(x.graph, torch.fx.Graph) + assert isinstance(y.graph, torch.fx.Graph) + return graph_dump(x.graph) == graph_dump(y.graph) + + +def _add_spec(gm: torch.nn.Module, spec) -> str: + i = 0 + while hasattr(gm, f"_spec_{i}"): + i += 1 + name = f"_spec_{i}" + setattr(gm, name, spec) + return name + + +def _generate_flatten(gm: torch.fx.GraphModule, node) -> torch.fx.Node: + flatten = gm.graph.call_function(pytree.tree_flatten, (node,)) + getitem_0 = gm.graph.call_function(operator.getitem, (flatten, 0)) + return getitem_0 + + +def _generate_flatten_spec( + gm: torch.fx.GraphModule | InterpreterModule | UnflattenedModule, node, spec +) -> torch.fx.Node: + name = _add_spec(gm, spec) + spec_node = gm.graph.get_attr(name) + return gm.graph.call_function(fx_pytree.tree_flatten_spec, (node, spec_node)) + + +def _generate_unflatten( + gm: torch.fx.GraphModule | InterpreterModule | UnflattenedModule, nodes, spec +) -> torch.fx.Node: + name = _add_spec(gm, spec) + spec_node = gm.graph.get_attr(name) + return gm.graph.call_function(pytree.tree_unflatten, (nodes, spec_node)) + + +def _get_submodule(mod: torch.nn.Module, target: str): + *prefix, field = target.split(".") + + 
for item in prefix: + submod = getattr(mod, item, None) + + if submod is None: + return None + + if not isinstance(submod, torch.nn.Module): + return None + + mod = submod + + return getattr(mod, field, None) + + +def _add_submodule( + mod: torch.nn.Module, + target: str, + module_to_add: torch.nn.Module, + create_module: Callable[[str], torch.nn.Module] | None = None, +): + *prefix, field = target.split(".") + + for i, item in enumerate(prefix): + submod = getattr(mod, item, None) + + if submod is None: + if create_module is not None: + submod = create_module(".".join(prefix[: i + 1])) + else: + submod = torch.nn.Module() + setattr(mod, item, submod) + + if not isinstance(submod, torch.nn.Module): + return False + + mod = submod + + mod.add_module(field, module_to_add) + + +def _call_name(base: str, n: int) -> str: + # Given n >= 0, generate call names to a submodule `base` of the form + # `base`, `base@1`, `base@2`, etc. + return base if n == 1 else f"{base}@{n - 1}" + + +def _is_call_name(call_name: str, base: str) -> bool: + # Recognize when call_name = _call_name(base, n) for some n >= 0. 
+ return re.match(re.escape(base) + r"(@\d+)?$", call_name) is not None + + +class _ModuleFrame: + def __init__( + self, + flat_graph: torch.fx.Graph, + nodes: tuple[torch.fx.Node, ...], + seen_nodes, + seen_modules, + seen_attrs, + created_modules, + parent, + module_stack: list[tuple[str, str | None, int]], + module_id, + module_call_graph: dict[str, ModuleCallSignature], + module: torch.fx.GraphModule | UnflattenedModule | None = None, + ): + self.flat_graph = flat_graph + self.nodes = nodes + self.seen_nodes = seen_nodes + self.seen_modules = seen_modules + self.seen_attrs = seen_attrs + self.created_modules = created_modules + self.parent = parent + self.module_stack = module_stack + self.module_id = module_id + + self.module_call_graph = module_call_graph + self.verbose = False + + self.fqn, ty, num_calls = self.module_stack[-1] + # generate call name for self.fqn + self.child_fqn = _call_name(self.fqn, num_calls + 1) + + self.module: torch.fx.GraphModule | UnflattenedModule | InterpreterModule + if module is not None: + self.module = module + self.ivals = module.ivals if hasattr(module, "ivals") else {} # type: ignore[var-annotated] + else: + self.module = self.created_modules.get( + self.fqn, + InterpreterModule(torch.fx.Graph(), ty=ty), + ) + self.ivals = parent.ivals + + self.graph = self.module.graph + + # Mapping of nodes in the flat graph to nodes in this graph. 
+ self.node_map: dict[torch.fx.Node, torch.fx.Node] = {} + self.node_to_placeholder = {} + + self.parent_call_module: torch.fx.Node | None = None + if parent is not None: + accessor = _compute_accessor(parent.fqn, self.child_fqn) + + def create_module(fqn): + path = f"{parent.fqn}.{fqn}" if parent.fqn else fqn + if path in self.created_modules: + return self.created_modules[path] + submod = InterpreterModule(torch.fx.Graph(), ty=ty) + self.created_modules[path] = submod + return submod + + _add_submodule(parent.module, accessor, self.module, create_module) + self.parent_call_module = parent.graph.call_module(accessor) + if self.seen_modules[self.module_id]: + base_module_frame = self.seen_modules[self.module_id][0] + self.module._modules = base_module_frame.module._modules + self.seen_modules[self.module_id].append( + _SubmoduleEntry( + parent_fqn=self.parent.fqn, + parent_module=self.parent.module, + parent_call_module=self.parent_call_module, + fqn=self.fqn, + call_idx=num_calls + 1, + module=self.module, + ) + ) + + signature = module_call_graph.get(self.child_fqn) + if signature is not None and self.parent is not None: + assert signature.in_spec.num_children == 2 + assert signature.in_spec.type is tuple + args_spec, kwargs_spec = signature.in_spec.children() + assert args_spec.type is tuple + assert kwargs_spec.type is dict + + with self.graph.inserting_after(None): + arg_nodes = [ + self.graph.placeholder(f"_positional_arg_{idx}") + for idx in range(args_spec.num_children) + ] + kwarg_nodes = {} + for name in kwargs_spec.context: + kwarg_nodes[name] = self.graph.placeholder(name) + flat_args = _generate_flatten_spec( + self.module, + (tuple(arg_nodes), kwarg_nodes), + signature.in_spec, + ) + for idx, arg in enumerate(signature.inputs): + flat_arg_node = self.graph.create_node( + op="call_function", + target=operator.getitem, + args=(flat_args, idx), + name=( + arg.name + if not isinstance(arg, ConstantArgument) + else f"_constant_{idx}" + ), + ) + if 
isinstance(arg, ConstantArgument): + continue + + if arg.name in self.seen_nodes: + flat_arg_node.meta = copy.copy(self.seen_nodes[arg.name].meta) + self.node_to_placeholder[self.seen_nodes[arg.name]] = ( + flat_arg_node + ) + + with self.parent.graph.inserting_before(self.parent_call_module): + input_nodes: list[torch.fx.Node | None] = [] + for input in signature.inputs: + if isinstance(input, ConstantArgument): + input_nodes.append(input.value) # type: ignore[arg-type] + elif input.name not in self.seen_nodes: + input_nodes.append(None) + else: + assert isinstance( + input, + ( + TensorArgument, + SymIntArgument, + SymBoolArgument, + SymFloatArgument, + ), + ) + input_nodes.append( + self.parent.remap_input(self.seen_nodes[input.name]) + ) + + inputs_node = _generate_unflatten( + self.parent.module, + input_nodes, + signature.in_spec, + ) + + args_node = self.parent.graph.call_function( + operator.getitem, (inputs_node, 0) + ) + kwargs_node = self.parent.graph.call_function( + operator.getitem, (inputs_node, 1) + ) + arg_nodes = [ + self.parent.graph.call_function(operator.getitem, (args_node, i)) + for i in range(args_spec.num_children) + ] + kwarg_nodes = { + k: self.parent.graph.call_function( + operator.getitem, (kwargs_node, k) + ) + for k in kwargs_spec.context + } + assert self.parent_call_module is not None + # pyrefly: ignore [bad-assignment] + self.parent_call_module.args = tuple(arg_nodes) + self.parent_call_module.kwargs = kwarg_nodes # type: ignore[assignment] + + def add_placeholder(self, x): + assert self.fqn != "", f"Cannot add placeholder {x} to root module" + assert x.graph is self.flat_graph + # x is not in subgraph, create a new placeholder for subgraph + with self.graph.inserting_before(None): + placeholder_node = self.graph.placeholder(x.name, type_expr=x.type) + # copy all meta fields, even if some fields might be irrelevant for + # the placeholder node + placeholder_node.meta = copy.copy(x.meta) + self.node_to_placeholder[x] = 
placeholder_node + + def copy_sym_call_function(self, x): + # This only exists because we deduplicate sym_size nodes in the flat export graph, + # and if preserve_module_call_signature is set, we may not be able to pass sym_size + # nodes, or their downstream users, as inputs to submodule calls. + # To avoid this we copy these call_function nodes with sym_type results. + # This should however only be done for sym_type nodes - call_function nodes on tensors + # should not be deduplicated in the first place. + args = pytree.tree_map_only(torch.fx.Node, self.remap_input, x.args) + kwargs = pytree.tree_map_only(torch.fx.Node, self.remap_input, x.kwargs) + node = self.graph.call_function(x.target, args, kwargs) + node.meta = copy.copy(x.meta) + self.node_map[x] = node + return node + + def remap_input(self, x): + assert x.graph is self.flat_graph + if x in self.node_map: + return self.node_map[x] + self.print(f"remap_input({x})") + if x in self.node_to_placeholder: + return self.node_to_placeholder[x] + elif ( + x.op == "placeholder" or self.module_call_graph.get(self.fqn) is None + # allow placeholder creation if we are not preserving module call signature + ): + self.add_placeholder(x) + if self.parent_call_module is not None: + # Important to *prepend* the output to match how we are + # inserting placeholder nodes. 
+ with self.parent.graph.inserting_before(self.parent_call_module): + self.parent_call_module.insert_arg(0, self.parent.remap_input(x)) + return self.node_to_placeholder[x] + elif x.op == "call_function" and ( + x.target + in ( + torch.ops.aten.sym_size.int, + torch.ops.aten.item.default, + torch.ops.aten.unbind.int, + torch.ops.aten.sum.dim_IntList, + torch.ops.aten.view.default, + torch.ops.aten.diff.default, + ) + or (hasattr(x.target, "__module__") and x.target.__module__ == "_operator") + ): + # export deduplicates sym_size nodes, and may need to re-copy them + # if module call signature needs to be preserved + self.copy_sym_call_function(x) + return self.node_map[x] + elif self.module_call_graph.get(self.fqn) is not None: + # x is reading the intermediate value of a mutation, so record it; + # later we will find where it was created and perform the update + return self.ivals.read(self, x) # type: ignore[operator, union-attr] + else: + raise RuntimeError( + f"Could not run remap_input() on op type: {x.op} for node {x}" + ) + + def uplift_common_custom_metadata(self) -> None: + # Copy custom metadata if all nodes have same custom metadata + custom_meta = None + for node in self.node_map.values(): + curr_meta = node.meta.get("custom", {}) + if custom_meta is None: + # first node + custom_meta = curr_meta + continue + + if curr_meta != custom_meta: + custom_meta = {} + break + + if custom_meta: + # Lift common custom metadata to parent node and clear children node's custom metadata + assert self.parent_call_module is not None + self.parent_call_module.meta["custom"] = custom_meta + for node in self.node_map.values(): + del node.meta["custom"] + + def finalize_outputs(self): + self.created_modules.pop(self.fqn, None) + + orig_outputs = [] + + signature = self.module_call_graph.get(self.child_fqn) + if signature is not None and self.parent is not None: + for output in signature.outputs: + if isinstance( + output, + ( + TensorArgument, + SymIntArgument, + 
SymBoolArgument, + SymFloatArgument, + ConstantArgument, + ), + ): + if output.name in self.seen_nodes: + orig_outputs.append(self.seen_nodes[output.name]) + else: + orig_outputs.append(None) + else: + raise RuntimeError( + f"Unsupported data type for output node: {output}" + ) + + def get_actual_output_node(output): + if output is None: + return None + + seen_node = self.seen_nodes[output.name] + if seen_node in self.node_map: + return self.node_map[seen_node] + elif seen_node in self.node_to_placeholder: + return self.node_to_placeholder[seen_node] + else: + raise RuntimeError( + f"Could not find output node {output}. Graph: {self.graph}" + ) + + tree_out_node = _generate_unflatten( + self.module, + tuple(get_actual_output_node(output) for output in orig_outputs), + signature.out_spec, + ) + parent_out: torch.fx.Node | None = _generate_flatten_spec( + self.parent.module, self.parent_call_module, signature.out_spec + ) + graph_outputs: torch.fx.Node | list[torch.fx.Node] = tree_out_node + else: + graph_outputs = [] + # Iterate through nodes we have copied into self.graph. 
+ for orig_node in self.node_map: + for user_node in orig_node.users: + if user_node.name not in self.seen_nodes: + # external user node, need to expose as an output + orig_outputs.append(orig_node) + graph_outputs.append(self.node_map[orig_node]) + break + + parent_out = self.parent_call_module + if len(graph_outputs) == 1: + graph_outputs = graph_outputs[0] + + assert isinstance(graph_outputs, (list, torch.fx.Node)) + + self.graph.output(graph_outputs) + + # Rewrite outputs in parent module + if parent_out is None: + return + + parent_out.meta["val"] = ( + graph_outputs.meta.get("val") + if isinstance(graph_outputs, torch.fx.Node) + else [o.meta.get("val") for o in graph_outputs] + ) + self.uplift_common_custom_metadata() + + if len(orig_outputs) == 1 and signature is None: + self.parent.node_map[orig_outputs[0]] = parent_out + else: + for i, orig_output in enumerate(orig_outputs): + if orig_output is None: + continue + # Use Proxy to record getitem access. + proxy_out = torch.fx.Proxy(parent_out)[i].node # type: ignore[index] + proxy_out.meta["val"] = orig_output.meta.get("val") + self.parent.node_map[orig_output] = proxy_out + + def copy_node(self, node): + self.print("copying", node.format_node()) + self.node_map[node] = self.graph.node_copy(node, self.remap_input) + self.seen_nodes[node.name] = node + + def run_outer(self): + for i, node in enumerate(self.flat_graph.nodes): + self.print(i, node.meta.get("nn_module_stack"), node.format_node()) + + # Copy all graph inputs + node_idx: int = 0 + node = self.nodes[node_idx] + while node.op == "placeholder": + self.copy_node(node) + node_idx += 1 + node = self.nodes[node_idx] + + self.run_from(node_idx) + + # Copy graph outputs + for node in self.flat_graph.nodes: + if node.op == "output": + self.copy_node(node) + + def print(self, *args, **kwargs): + if self.verbose: + # pyrefly: ignore [not-iterable] + print(*args, **kwargs) + + def run_from(self, node_idx): + module_idx = 0 + # Walk through the graph, building 
up a new graph with the right submodules + while node_idx < len(self.nodes): + node = self.nodes[node_idx] + assert node.op != "placeholder" + + self.print() + self.print("STEP", node_idx, node.format_node()) + self.print(self.module_stack) + depth = len(self.module_stack) + if node.op == "output": + if depth == 1: + # We want the output node of the original graph to be handled + # specially by the outermost stack frame (in run_outer). So + # skip finalization here. + return node_idx + + # We've reached the end of the graph. Wrap up all the existing stack frames. + self.finalize_outputs() + return node_idx + + if len(node.meta.get("nn_module_stack", {})) == 0: + raise RuntimeError(f"Unable to find nn_module_stack for node {node}") + + nn_module_stack = node.meta["nn_module_stack"] + from torch._export.passes._node_metadata_hook import ( + _EMPTY_NN_MODULE_STACK_KEY, + ) + + if ( + len(nn_module_stack) == 1 + and _EMPTY_NN_MODULE_STACK_KEY in nn_module_stack + ): + # Empty case from the node_metadata_hook + node_module_stack = self.module_stack + else: + node_module_stack = [ + ( + path, + ty if path else None, + int(k.split("@")[-1]) if "@" in k else 0, + ) + for k, (path, ty) in node.meta["nn_module_stack"].items() + ] + + if node_module_stack[:depth] != self.module_stack: + # This means that the current module is done executing and the + # current node is the beginning of a new module. + # + # In this case, we should finalize this module and return without + # incrementing the node counter. + self.finalize_outputs() + self.print("outlining", self.fqn) + self.print(self.graph) + return node_idx + + assert node_module_stack is not None + + if _is_prefix(self.module_stack, node_module_stack): + # This means that the current node represents the execution of a new + # module. + next_module = node_module_stack[depth] + self.print("Creating new stack frame for", next_module) + # Run a nested version of module outliner from the current node + # counter. 
Once it is complete, continue from that point. + next_module_key = list(node.meta["nn_module_stack"].keys())[depth] + node_idx = _ModuleFrame( + self.flat_graph, + self.nodes, + self.seen_nodes, + self.seen_modules, + self.seen_attrs, + self.created_modules, + self, + self.module_stack + [next_module], + next_module_key.split("@")[0], + self.module_call_graph, + ).run_from(node_idx) + module_idx += 1 + continue + + # The only remaining possibility is that we are in the right stack + # frame. Copy the node into this frame's graph and increment the node counter. + assert node_module_stack == self.module_stack + + if node.op == "get_attr": + # this must be a graph argument for a HOP + self.seen_attrs[self.child_fqn].add(node.target) + + self.copy_node(node) + # pyrefly: ignore [unsupported-operation] + node_idx += 1 + + +@dataclass +class _SubmoduleEntry: + parent_fqn: str + parent_module: torch.nn.Module + parent_call_module: torch.fx.Node + fqn: str + call_idx: int + module: torch.nn.Module + + +def _outline_submodules(orig_graph: torch.fx.Graph, root_module: UnflattenedModule): + seen_nodes: dict[str, torch.fx.Node] = {} + seen_modules: dict[int, list[_SubmoduleEntry]] = defaultdict(list) + seen_attrs: dict[str, set[str]] = defaultdict(set) + created_modules: dict[str, torch.nn.Module] = {} + _ModuleFrame( + orig_graph, + tuple(orig_graph.nodes), + seen_nodes, + seen_modules, + seen_attrs, + created_modules, + None, + [("", None, 0)], + "", + { + entry.fqn: entry.signature + for entry in root_module.module_call_graph + if entry.signature + }, + module=root_module, + ).run_outer() + return seen_modules, seen_attrs + + +def _reorder_submodules( + parent: torch.nn.Module, fqn_order: dict[str, int], prefix: str = "" +): + # TODO Can be optimized by adding submodules ahead of time. 
+ if prefix == "": + for fqn in list(fqn_order.keys())[1:]: + if _get_submodule(parent, fqn) is None: + _add_submodule(parent, fqn, torch.nn.Module()) + + children = [] + for name, child in list(parent._modules.items()): + if child is None: + continue + fqn = prefix + name + _reorder_submodules(child, fqn_order, prefix=fqn.split("@")[0] + ".") + delattr(parent, name) + children.append((fqn_order[fqn], name, child)) + children.sort(key=operator.itemgetter(0)) + for _, name, child in children: + parent.register_module(name, child) + + +class _IVals: + """ + Collect the intermediate values of mutations in a graph. + + Example: in the following graph, suppose that buf_in and buf_out + are the input and output values of a buffer. + + buf_in = placeholder() + ... + ival1 = f0(buf_in, ...) # inside self.n0(...) + ... + ival2 = f1(ival1, ...) # inside self.n1(...) + ... + buf_out = f2(ival2, ...) # inside self.n2(...) + return buf_out, ... + + Here ival1 and ival2 are intermediate values created inside + calls to n0 and n1 respectively, and used inside calls to + n1 and n2 respectively. + """ + + def __init__(self): + # for each fqn, set of node names corresponding to intermediate values + self.node_names_by_fqn = defaultdict(set) + + def _is_mutable(self, target): + if isinstance(target, torch._ops.OpOverload): + return target._schema.is_mutable + return False + + def read(self, mf, node): + """ + Read state corresponding to a given intermediate value. 
+ """ + # we can assume that the node must be from a mutation + assert node.op == "call_function" + b = self._is_mutable(node.target) + print("Checking mutability", node.target, b) + if not b: + # so the mutation was functionalized; + # we will apply the original mutation later (see below) + fqn, _ = next(reversed(node.meta["nn_module_stack"].values())) + self.node_names_by_fqn[fqn].add(node.name) + return mf.remap_input(node.args[0]) + + def update(self, partitions): + """ + Update states corresponding to intermediate values that were read. + """ + for shared_submodules in partitions: + for entry in shared_submodules: + graph = entry.module.graph + node_names = self.node_names_by_fqn[entry.fqn] + nodes = [n for n in graph.nodes if n.name in node_names] + for node in nodes: + # so node must be from a functionalized mutation; + # we perform the original mutation now + with graph.inserting_after(node): + new_node = graph.create_node( + "call_function", + torch.ops.aten.copy_.default, + (node.args[0], node), + ) + new_node.meta = copy.copy(node.meta) + + +def _copy_graph_attrs( + gm: torch.fx.GraphModule, + root_module: UnflattenedModule, + seen_attrs: dict[str, set[str]], +): + for child_fqn, names in seen_attrs.items(): + module = _get_attr(root_module, child_fqn) if child_fqn else root_module + for name in names: + val = getattr(gm, name) + setattr(module, name, val) + + +def _deduplicate_modules(partitions): + redirected_call_indices = {} + for shared_submodules in partitions: + for i, entry in enumerate(shared_submodules): + child_fqn = _call_name(entry.fqn, entry.call_idx) + target = _compute_accessor(entry.parent_fqn, child_fqn) + deduplicated = False + # Iterate over all previously seen modules, and deduplicate if possible + for seen in shared_submodules[:i]: + if _check_graph_equivalence(seen.module, entry.module): + parent = entry.parent_module + # Since graphs are equivalent, we can deduplicate. + # There are two cases. 
+ if seen.fqn == entry.fqn: + # Case 1: The current module has the same fqn as the seen module. + # In this case we have generated a call name that can be optimized away. + # So we remove the current module from the hierarchy and replace + # the current call name with the seen call name in the parent graph. + *prefix, name = target.split(".") + _get_attr_via_attr_list(parent, prefix)._modules.pop(name) + seen_child_fqn = _call_name(seen.fqn, seen.call_idx) + seen_target = _compute_accessor( + entry.parent_fqn, seen_child_fqn + ) + entry.parent_call_module.target = seen_target + redirected_call_indices[child_fqn] = seen_child_fqn + break + elif not deduplicated: + # Case 2: The current module has a different fqn than the seen module. + # In this case we replace the current module with the seen module. + # There should be nothing pointing to the current module any more, + # so it can be garbage collected. + # NOTE: We *do not* replace the current call name with the seen call name + # in the parent graph, because this will lose information on which fqn + # was actually called. However, it is possible that the current call name + # will be optimized away when we find another seen module with the same fqn, + # so we do not break out of the loop yet. + parent.set_submodule(target, seen.module) + deduplicated = True + + return redirected_call_indices + + +def _sink_params( + module: torch.nn.Module, + inputs_to_state: dict[str, list[str]], + scope: list[str], + module_id_to_inputs_removed: dict[int, set[str]] | None = None, +): + """Sink params, buffers, and constants from graph inputs into get_attr nodes. + + Exported modules are purely functional, so they pass their parameters and + buffers in as inputs to the graph. + + To replicate eager's semantics, we need to get them from the module state + via get_attr instead. + + module: GraphModule, potentially containing nested submodules. + inputs_to_state: mapping graph input names to the corresponding key in the state_dict. 
+ scope: tracks where we are in the module hierarchy, so that we can emit the + right `getattr(self, "foo.bar")` calls, etc. + module_id_to_inputs_removed: records inputs removed by child modules, mapping + the module object id to the list of placeholder node names in the child module + that were removed. + """ + if module_id_to_inputs_removed is None: + module_id_to_inputs_removed = defaultdict(set) + + if id(module) in module_id_to_inputs_removed: + return {id(module): module_id_to_inputs_removed[id(module)]} + + # We need to use _modules here instead of named_children(), because we + # explicitly want duplicate modules to show up in the traversal. + for name, submodule in module._modules.items(): + submod_id_to_inputs_removed = _sink_params( + cast("torch.nn.Module", submodule), + inputs_to_state, + scope + [name], + module_id_to_inputs_removed, + ) + for k, v in submod_id_to_inputs_removed.items(): + module_id_to_inputs_removed[k].update(v) + + graph = getattr(module, "graph", None) + if graph is None or len(graph.nodes) == 0: + # Not all modules have graphs defined, if they are empty modules with no operations (like ParameterList) + return module_id_to_inputs_removed + + assert isinstance(graph, torch.fx.Graph) + + inputs = list(filter(lambda n: n.op == "placeholder", graph.nodes)) + the_last_input = None if len(inputs) == 0 else inputs[-1] + + # Also remove from call_module nodes + call_module_nodes = filter(lambda n: n.op == "call_module", graph.nodes) + for node in call_module_nodes: + submodule = _get_attr(module, node.target) + # remove placeholder from call_module node arguments, only if we've + # erased the placeholder node in the corresponding _sink_params() call + if submodule is not None and id(submodule) in module_id_to_inputs_removed: + node.args = tuple( + filter( + lambda n: n.name not in module_id_to_inputs_removed[id(submodule)], + node.args, + ) + ) + + # Filter out inputs_to_state corresponding to current scope. 
+ inputs_to_state_of_scope: dict[torch.fx.Node, list[str]] = {} + for node in inputs: + if node.name not in inputs_to_state: + continue + + state_name = None + for sn in inputs_to_state[node.name]: + sn_split = sn.split(".") + if sn_split[: len(scope)] == [x.split("@")[0] for x in scope]: + state_name = sn_split + break + + # If there's a mismatch between scope name and state name, then + # there must be multiple scopes pointing to the same state name, + # meaning some modules are shared. In such case, we can simply skip + # updating the current node because another later iteration will + # take care of this input node when the unique match between scope + # and state name occurs. To make sure this always happen, we should + # enforce the invariant that no placeholder node in the unflattened + # graph appears in inputs_to_state dict, which means all the extra + # input nodes have been handled. + if state_name is None: + continue + + inputs_to_state_of_scope[node] = state_name + + # Record name of remove inputs for return purpose. 
+ inputs_removed: set[str] = set() + + for node, state_name in inputs_to_state_of_scope.items(): + if len(node.users) > 0: + attr_path = state_name[len(scope) :] + state_attr = _get_attr_via_attr_list(module, attr_path) + assert isinstance(state_attr, (torch.Tensor, torch.ScriptObject)) + + # Make sure the newly created get_attr node is placed after the last placeholder node + with graph.inserting_after(the_last_input): + new_node = graph.create_node("get_attr", ".".join(attr_path)) + + node.replace_all_uses_with(new_node, propagate_meta=True) + + graph.erase_node(node) + inputs_removed.add(node.name) + + if isinstance(module, InterpreterModule): + module.finalize() + + return {id(module): inputs_removed} diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/__init__.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c048b4fdd8f8940d46ff75a265af3bc7587255c7 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/__init__.py @@ -0,0 +1,116 @@ +r''' +FX is a toolkit for developers to use to transform ``nn.Module`` +instances. FX consists of three main components: a **symbolic tracer,** +an **intermediate representation**, and **Python code generation**. 
A +demonstration of these components in action: + +:: + + import torch + + + # Simple module for demonstration + class MyModule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.param = torch.nn.Parameter(torch.rand(3, 4)) + self.linear = torch.nn.Linear(4, 5) + + def forward(self, x): + return self.linear(x + self.param).clamp(min=0.0, max=1.0) + + + module = MyModule() + + from torch.fx import symbolic_trace + + # Symbolic tracing frontend - captures the semantics of the module + symbolic_traced: torch.fx.GraphModule = symbolic_trace(module) + + # High-level intermediate representation (IR) - Graph representation + print(symbolic_traced.graph) + """ + graph(): + %x : [num_users=1] = placeholder[target=x] + %param : [num_users=1] = get_attr[target=param] + %add : [num_users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {}) + %linear : [num_users=1] = call_module[target=linear](args = (%add,), kwargs = {}) + %clamp : [num_users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0}) + return clamp + """ + + # Code generation - valid Python code + print(symbolic_traced.code) + """ + def forward(self, x): + param = self.param + add = x + param; x = param = None + linear = self.linear(add); add = None + clamp = linear.clamp(min = 0.0, max = 1.0); linear = None + return clamp + """ + +The **symbolic tracer** performs "symbolic execution" of the Python +code. It feeds fake values, called Proxies, through the code. Operations +on these Proxies are recorded. More information about symbolic tracing +can be found in the :func:`symbolic_trace` and :class:`Tracer` +documentation. + +The **intermediate representation** is the container for the operations +that were recorded during symbolic tracing. It consists of a list of +Nodes that represent function inputs, callsites (to functions, methods, +or :class:`torch.nn.Module` instances), and return values. 
More information +about the IR can be found in the documentation for :class:`Graph`. The +IR is the format on which transformations are applied. + +**Python code generation** is what makes FX a Python-to-Python (or +Module-to-Module) transformation toolkit. For each Graph IR, we can +create valid Python code matching the Graph's semantics. This +functionality is wrapped up in :class:`GraphModule`, which is a +:class:`torch.nn.Module` instance that holds a :class:`Graph` as well as a +``forward`` method generated from the Graph. + +Taken together, this pipeline of components (symbolic tracing -> +intermediate representation -> transforms -> Python code generation) +constitutes the Python-to-Python transformation pipeline of FX. In +addition, these components can be used separately. For example, +symbolic tracing can be used in isolation to capture a form of +the code for analysis (and not transformation) purposes. Code +generation can be used for programmatically generating models, for +example from a config file. There are many uses for FX! + +Several example transformations can be found at the +`examples `__ +repository. 
+''' + +from torch.fx import immutable_collections +from torch.fx._symbolic_trace import ( # noqa: F401 + PH, + ProxyableClassMeta, + symbolic_trace, + Tracer, + wrap, +) +from torch.fx.graph import CodeGen, Graph # noqa: F401 +from torch.fx.graph_module import GraphModule +from torch.fx.interpreter import Interpreter, Transformer +from torch.fx.node import has_side_effect, map_arg, Node +from torch.fx.proxy import Proxy +from torch.fx.subgraph_rewriter import replace_pattern + + +__all__ = [ + "symbolic_trace", + "Tracer", + "wrap", + "Graph", + "GraphModule", + "Interpreter", + "Transformer", + "Node", + "Proxy", + "replace_pattern", + "has_side_effect", + "map_arg", +] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/_compatibility.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/_compatibility.py new file mode 100644 index 0000000000000000000000000000000000000000..c07dd1b51bc05a7b1288efdd331a9bec4926845f --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/_compatibility.py @@ -0,0 +1,40 @@ +import textwrap +from collections.abc import Callable +from typing import Any, TypeVar + + +_BACK_COMPAT_OBJECTS: dict[Any, None] = {} +_MARKED_WITH_COMPATIBILITY: dict[Any, None] = {} + + +_T = TypeVar("_T") + + +def compatibility(is_backward_compatible: bool) -> Callable[[_T], _T]: + if is_backward_compatible: + + def mark_back_compat(fn: _T) -> _T: + docstring = textwrap.dedent(getattr(fn, "__doc__", None) or "") + docstring += """ +.. note:: + Backwards-compatibility for this API is guaranteed. +""" + fn.__doc__ = docstring + _BACK_COMPAT_OBJECTS.setdefault(fn) + _MARKED_WITH_COMPATIBILITY.setdefault(fn) + return fn + + return mark_back_compat + else: + + def mark_not_back_compat(fn: _T) -> _T: + docstring = textwrap.dedent(getattr(fn, "__doc__", None) or "") + docstring += """ +.. warning:: + This API is experimental and is *NOT* backward-compatible. 
+""" + fn.__doc__ = docstring + _MARKED_WITH_COMPATIBILITY.setdefault(fn) + return fn + + return mark_not_back_compat diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/_graph_pickler.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/_graph_pickler.py new file mode 100644 index 0000000000000000000000000000000000000000..eb6465680570b3300407cb5297f3c0cba6420276 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/_graph_pickler.py @@ -0,0 +1,647 @@ +import dataclasses +import importlib +import io +import pickle +from abc import abstractmethod +from collections.abc import Callable +from typing import Any, NewType, Optional, TypeVar, Union +from typing_extensions import override, Self + +import torch +import torch.utils._pytree as pytree +from torch._guards import TracingContext +from torch._inductor.standalone_compile import AOTCompiledArtifact +from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode, Tensor +from torch._subclasses.meta_utils import ( + MetaConverter, + MetaTensorDesc, + MetaTensorDescriber, +) +from torch.fx.experimental.sym_node import SymNode +from torch.fx.experimental.symbolic_shapes import ShapeEnv +from torch.utils._mode_utils import no_dispatch + + +_SymNodeT = TypeVar("_SymNodeT", torch.SymInt, torch.SymFloat) + + +def _ops_filter_safe(name: str) -> bool: + """ + An ops filter which allows pickle-safe ops. Pickle-safe ops are built-in + ones where it will be possible to unpickle on any machine which has PyTorch. + """ + # TODO: This list is pretty pessimistic right now. What's the full list? + return name.startswith( + ( + "torch.ops.aten", + "torch.ops.fbgemm", + ) + ) + + +@dataclasses.dataclass +class Options: + # A filter for which ops will cause the pickler to raise a + # BypassFxGraphCache exception. If None then all ops are allowed. 
+ ops_filter: Optional[Callable[[str], bool]] = _ops_filter_safe + + +class GraphPickler(pickle.Pickler): + """ + GraphPickler is a Pickler which helps pickling fx graph - in particular + GraphModule. + """ + + def __init__(self, file: io.BytesIO, options: Optional[Options] = None) -> None: + super().__init__(file) + self.options = options or Options() + + # This abomination is so we can pass external decoding state to the + # unpickler functions. We serialize _unpickle_state as a persistent + # external item and when we deserialize it we return the common state + # object. + self._unpickle_state = _UnpickleStateToken(object()) + + # This is used to describe tensors. It needs to be common across the + # pickle so that duplicates and views are properly handled. + self._meta_tensor_describer = MetaTensorDescriber(copy_data=False) + + @override + # pyrefly: ignore [bad-override] + def reducer_override( + self, obj: object + ) -> tuple[Callable[..., Any], tuple[Any, ...]]: + # This function is supposed to return either NotImplemented (meaning to + # do the default pickle behavior) or a pair of (unpickle callable, data + # to pass to unpickle). + + # We could instead teach individual classes how to pickle themselves but + # that has a few problems: + # + # 1. If we have some special needs (maybe for this use-case we don't + # want to fully serialize every field) then we're adding private + # details to a public interface. + # + # 2. If we need to have some common shared data (such as a + # FakeTensorMode) which is passed to each value it's harder to + # support. + + # These are the types that need special handling. See the individual + # *PickleData classes for details on pickling that particular type. 
+ if isinstance(obj, FakeTensor): + return _TensorPickleData.reduce_helper(self, obj) + elif isinstance(obj, torch.fx.GraphModule): + return _GraphModulePickleData.reduce_helper(self, obj) + elif isinstance(obj, (torch._ops.OperatorBase, torch._ops.OpOverloadPacket)): + return _OpPickleData.reduce_helper(self, obj) + elif isinstance(obj, ShapeEnv): + return _ShapeEnvPickleData.reduce_helper(self, obj) + elif isinstance(obj, torch.SymInt): + return _SymNodePickleData.reduce_helper(self, obj) + elif isinstance(obj, torch._guards.TracingContext): + return _TracingContextPickleData.reduce_helper(self, obj) + else: + # We should never get a raw Node! + assert not isinstance(obj, torch.fx.Node) + if reduce := _TorchNumpyPickleData.reduce_helper(self, obj): + return reduce + + # returning `NotImplemented` causes pickle to revert to the default + # behavior for this object. + return NotImplemented + + @override + def persistent_id(self, obj: object) -> Optional[str]: + if obj is self._unpickle_state: + return "unpickle_state" + else: + return None + + @classmethod + def dumps(cls, obj: object, options: Optional[Options] = None) -> bytes: + """ + Pickle an object. + """ + with io.BytesIO() as stream: + pickler = cls(stream, options) + pickler.dump(obj) + return stream.getvalue() + + @staticmethod + def loads(data: bytes, fake_mode: FakeTensorMode) -> object: + """ + Unpickle an object. + """ + state = _UnpickleState(fake_mode) + with io.BytesIO(data) as stream: + unpickler = _GraphUnpickler(stream, state) + return unpickler.load() + + +class _UnpickleState: + def __init__(self, fake_mode: FakeTensorMode) -> None: + self.fake_mode = fake_mode + self.meta_converter: MetaConverter[FakeTensor] = MetaConverter() + + +# This token is passed when pickling to indicate that we want to use the +# unpickler's _UnpickleState as a parameter in that position. 
+_UnpickleStateToken = NewType("_UnpickleStateToken", object) + + +class _GraphUnpickler(pickle.Unpickler): + def __init__(self, stream: io.BytesIO, unpickle_state: _UnpickleState) -> None: + super().__init__(stream) + self._unpickle_state = unpickle_state + + @override + def persistent_load(self, pid: object) -> object: + if pid == "unpickle_state": + return self._unpickle_state + else: + raise pickle.UnpicklingError("Invalid persistent ID") + + +class _ShapeEnvPickleData: + data: dict[str, object] + + @classmethod + def reduce_helper( + cls, pickler: GraphPickler, obj: ShapeEnv + ) -> tuple[ + Callable[[Self, _UnpickleState], ShapeEnv], tuple[Self, _UnpickleStateToken] + ]: + return cls.unpickle, (cls(obj), pickler._unpickle_state) + + def __init__(self, env: ShapeEnv) -> None: + # In theory pickle should recognize that a given ShapeEnv was already + # pickled and reuse the resulting _ShapeEnvPickleData (so two objects + # pointing at the same ShapeEnv get the same ShapeEnv out). + assert not env._translation_validation_enabled + self.data = env.__dict__.copy() + del self.data["tracked_fakes"] + del self.data["fake_tensor_cache"] + + def unpickle(self, unpickle_state: _UnpickleState) -> ShapeEnv: + # Fill in the existing ShapeEnv rather than creating a new one + assert unpickle_state.fake_mode + assert unpickle_state.fake_mode.shape_env + + for k, v in self.data.items(): + setattr(unpickle_state.fake_mode.shape_env, k, v) + + return unpickle_state.fake_mode.shape_env + + +class _SymNodePickleData: + @classmethod + def reduce_helper( + cls, + pickler: GraphPickler, + obj: _SymNodeT, + ) -> tuple[ + Callable[[Self, _UnpickleState], _SymNodeT], tuple[Self, _UnpickleStateToken] + ]: + args = (cls(obj.node), pickler._unpickle_state) + if isinstance(obj, torch.SymInt): + # pyrefly: ignore [bad-return] + return _SymNodePickleData.unpickle_sym_int, args + else: + raise NotImplementedError(f"Unhandled SymNode type {type(obj)}") + + def __init__(self, node: SymNode) -> 
None: + self.expr = node._expr + self.shape_env = node.shape_env + self.pytype = node.pytype + self.hint = node._hint + + def _to_sym_node(self) -> SymNode: + assert self.shape_env is not None + return SymNode(self.expr, self.shape_env, self.pytype, self.hint) + + def unpickle_sym_int(self, unpickle_state: _UnpickleState) -> torch.SymInt: + return torch.SymInt(self._to_sym_node()) + + +class _TensorPickleData: + metadata: MetaTensorDesc[FakeTensor] + + @classmethod + def reduce_helper( + cls, pickler: GraphPickler, obj: FakeTensor + ) -> tuple[ + Callable[[Self, _UnpickleState], FakeTensor], tuple[Self, _UnpickleStateToken] + ]: + return cls.unpickle, ( + cls(pickler._meta_tensor_describer, obj), + pickler._unpickle_state, + ) + + def __init__(self, describer: MetaTensorDescriber, t: Tensor) -> None: + # THINGS TO WORRY ABOUT: + # 1. Need to make sure that two tensors with the same id end up with the + # same id on the other side of the wire. + + metadata = describer.describe_tensor(t) + + # view_func is fine if it's either None or a _FakeTensorViewFunc. A + # custom one (which is basically a lambda) can't be serialized. + assert not metadata.view_func or isinstance( + metadata.view_func, torch._subclasses.meta_utils._FakeTensorViewFunc + ) + self.metadata = dataclasses.replace(metadata, fake_mode=None) + + # Some debugging/verification + for k in MetaTensorDesc._UNSERIALIZABLE: + if k in ("fake_mode", "view_func"): + continue + assert getattr(self.metadata, k) is None, ( + f"not None: {k}: {getattr(self.metadata, k)}" + ) + + def unpickle(self, unpickle_state: _UnpickleState) -> FakeTensor: + # TODO: make common w/ _output_from_cache_entry() in fake_tensor.py? 
+ metadata = dataclasses.replace( + self.metadata, + fake_mode=unpickle_state.fake_mode, + ) + + # also need to set the fake_mode on the base of a tensor if it's a view + if metadata.is_view and metadata.base is not None: + new_base = dataclasses.replace( + metadata.base, + fake_mode=unpickle_state.fake_mode, + ) + metadata = dataclasses.replace(metadata, base=new_base) + + def with_fake( + make_meta_t: Callable[[], torch.Tensor], device: Union[torch.device, str] + ) -> FakeTensor: + with no_dispatch(): + return FakeTensor( + unpickle_state.fake_mode, + make_meta_t(), + # pyrefly: ignore [bad-argument-type] + device, + ) + + return unpickle_state.meta_converter.meta_tensor( + metadata, + unpickle_state.fake_mode.shape_env, + with_fake, + None, + None, + ) + + +class _TorchNumpyPickleData: + @classmethod + def reduce_helper( + cls, pickler: GraphPickler, obj: object + ) -> Optional[ + tuple[ + Callable[[Self, _UnpickleState], object], tuple[Self, _UnpickleStateToken] + ] + ]: + if data := cls.from_object(obj): + return (cls.unpickle, (data, pickler._unpickle_state)) + else: + return None + + def __init__(self, mod: str, name: str) -> None: + self.mod = mod + self.name = name + + def unpickle(self, unpickle_state: _UnpickleState) -> Callable[..., object]: + np = getattr(importlib.import_module(self.mod), self.name) + return torch._dynamo.variables.misc.get_np_to_tnp_map()[np] + + @classmethod + def from_object(cls, tnp: object) -> Optional[Self]: + if not callable(tnp): + return None + + tnp_to_np = torch._dynamo.variables.misc.get_tnp_to_np_map() + try: + if not (np := tnp_to_np.get(tnp)): + return None + except TypeError: + return None + + if not (mod := getattr(np, "__module__", None)): + mod = "numpy" + + if not (name := getattr(np, "__name__", None)): + return None + + # pyrefly: ignore [unbound-name] + assert np == getattr(importlib.import_module(mod), name) + # pyrefly: ignore [unbound-name] + return cls(mod, name) + + +class _GraphModulePickleData: + 
@classmethod + def reduce_helper( + cls, pickler: GraphPickler, obj: torch.fx.GraphModule + ) -> tuple[ + Callable[[Self, _UnpickleState], torch.fx.GraphModule], + tuple[Self, _UnpickleStateToken], + ]: + return cls.unpickle, ( + cls(obj, pickler.options), + pickler._unpickle_state, + ) + + def __init__(self, gm: torch.fx.GraphModule, options: Options) -> None: + # Need to do this to ensure the code is created for later pickling. + if isinstance(gm, torch.fx._lazy_graph_module._LazyGraphModule): + _python_code = gm._real_recompile() + else: + _python_code = gm.recompile() + self.gm_dict = gm.__dict__.copy() + del self.gm_dict["_graph"] + self.graph = _GraphPickleData(gm._graph, options) + + def unpickle(self, unpickle_state: _UnpickleState) -> torch.fx.GraphModule: + gm = torch.fx.GraphModule.__new__(torch.fx.GraphModule) + gm.__dict__ = self.gm_dict + gm._graph = self.graph.unpickle(gm, unpickle_state) + return gm + + +class _NodePickleData: + def __init__( + self, + node: torch.fx.Node, + mapping: dict[torch.fx.Node, "_NodePickleData"], + options: Options, + ) -> None: + self.args = pytree.tree_map_only(torch.fx.Node, lambda n: mapping[n], node.args) + self.kwargs = pytree.tree_map_only( + torch.fx.Node, lambda n: mapping[n], node.kwargs + ) + # -- self.graph = node.graph + self.name = node.name + self.op = node.op + self.target = _OpPickleData.pickle(node.target, options) + # self.input_nodes = node._input_nodes + # self.users = node.users + self.type = node.type + # self.sort_key = node._sort_key + # self.repr_fn = node._repr_fn + # self.meta = node.meta + self.meta = node.meta + + def unpickle( + self, + graph: torch.fx.Graph, + mapping: dict["_NodePickleData", torch.fx.Node], + unpickle_state: _UnpickleState, + ) -> torch.fx.Node: + args = pytree.tree_map_only(_NodePickleData, lambda n: mapping[n], self.args) + kwargs = pytree.tree_map_only( + _NodePickleData, lambda n: mapping[n], self.kwargs + ) + target = self.target.unpickle(unpickle_state) + assert 
callable(target) or isinstance(target, str) + node = graph.create_node(self.op, target, args, kwargs, self.name, self.type) + node.meta = self.meta + return node + + +class _OpPickleData: + @classmethod + def reduce_helper( + cls, pickler: GraphPickler, op: object + ) -> tuple[Callable[[_UnpickleState], object], tuple[_UnpickleStateToken]]: + result = cls.pickle(op, pickler.options) + return (result.unpickle, (pickler._unpickle_state,)) + + @classmethod + def pickle(cls, op: object, options: Options) -> "_OpPickleData": + if isinstance(op, str): + return _OpStrPickleData(op) + + if isinstance(getattr(op, "__wrapped__", None), AOTCompiledArtifact): + assert hasattr(op, "__wrapped__") + artifact = op.__wrapped__ + assert isinstance(artifact, AOTCompiledArtifact) + return _OpPrecompiledPickleData(artifact) + + name = torch.fx.Node._pretty_print_target(op) + + if isinstance(op, torch._ops.OpOverload): + return cls._pickle_op(name, _OpOverloadPickleData, options) + elif isinstance(op, torch._ops.OpOverloadPacket): + return cls._pickle_op(name, _OpOverloadPacketPickleData, options) + elif name.startswith(_OpFunctionPickleData.SUPPORTED_ROOTS): + root, detail = name.split(".", 1) + return _OpFunctionPickleData(root, detail) + else: + # TODO: raise a BypassFxGraphCache so we will just bypass this one... 
+ raise NotImplementedError(f"TARGET: {type(op)} {op} {name}") + + @staticmethod + def _pickle_op( + name: str, + datacls: Union[ + type["_OpOverloadPickleData"], type["_OpOverloadPacketPickleData"] + ], + options: Options, + ) -> "_OpPickleData": + if (ops_filter := options.ops_filter) and not ops_filter(name): + from torch._inductor.codecache import BypassFxGraphCache + + raise BypassFxGraphCache(f"Unable to pickle non-standard op: {name}") + return datacls(name) + + @abstractmethod + def unpickle(self, unpickle_state: _UnpickleState) -> object: + pass + + @classmethod + def _lookup_global_by_name(cls, name: str) -> object: + """ + Like `globals()[name]` but supports dotted names. + """ + if "." in name: + mod, rest = name.split(".", 1) + root = globals()[mod] + return cls._getattr_by_name(root, rest) + else: + return globals()[name] + + @staticmethod + def _getattr_by_name(root: object, name: str) -> object: + """ + Like `getattr(root, name)` but supports dotted names. + """ + while "." 
in name: + mod, name = name.split(".", 1) + root = getattr(root, mod) + return getattr(root, name) + + +class _OpStrPickleData(_OpPickleData): + def __init__(self, name: str) -> None: + self.name = name + + def unpickle(self, unpickle_state: _UnpickleState) -> str: + return self.name + + +class _OpOverloadPickleData(_OpPickleData): + def __init__(self, name: str) -> None: + self.name = name + + def unpickle(self, unpickle_state: _UnpickleState) -> torch._ops.OpOverload: + obj = self._lookup_global_by_name(self.name) + assert isinstance(obj, torch._ops.OpOverload) + return obj + + +class _OpOverloadPacketPickleData(_OpPickleData): + def __init__(self, name: str) -> None: + self.name = name + + def unpickle(self, unpickle_state: _UnpickleState) -> torch._ops.OpOverloadPacket: + obj = self._lookup_global_by_name(self.name) + assert isinstance(obj, torch._ops.OpOverloadPacket) + return obj + + +class _OpPrecompiledPickleData(_OpPickleData): + def __init__(self, artifact: AOTCompiledArtifact) -> None: + self.contents = artifact.serialize() + + def unpickle(self, unpickle_state: _UnpickleState) -> object: + precompiled_artifact = AOTCompiledArtifact.deserialize(self.contents) + import functools + + @functools.wraps(precompiled_artifact) + def wrapped(*args: Any) -> Any: + return precompiled_artifact(*args) + + return wrapped + + +class _OpFunctionPickleData(_OpPickleData): + """ + Supports pickling a set of standard/common functions + These must be prefixed with the full namespace in order to properly + be pickled (i.e `einops.rearrange` and not `from einops import rearrange`) + """ + + # Static variable listing supported root names + SUPPORTED_ROOTS = ("builtins.", "math.", "torch.", "operator.", "einops.") + + def __init__(self, root: str, name: str) -> None: + self.root = root + self.name = name + + def unpickle(self, unpickle_state: _UnpickleState) -> object: + if self.root == "builtins": + return __builtins__.get(self.name) # type: ignore[attr-defined] + elif 
self.root == "math": + import math + + return self._getattr_by_name(math, self.name) + elif self.root == "torch": + return self._getattr_by_name(torch, self.name) + elif self.root == "operator": + import operator + + return self._getattr_by_name(operator, self.name) + elif self.root == "einops": + import einops + + return self._getattr_by_name(einops, self.name) + else: + raise NotImplementedError + + +class _GraphPickleData: + def __init__(self, graph: torch.fx.Graph, options: Options) -> None: + self.tracer_cls = graph._tracer_cls + self.tracer_extras = graph._tracer_extras + + nodes: dict[torch.fx.Node, _NodePickleData] = {} + for node in graph.nodes: + nodes[node] = _NodePickleData(node, nodes, options) + self.nodes = tuple(nodes.values()) + + # Unpickled variables: + # self._used_names = graph._used_names + # -- self._insert = self._root.prepend + # self._len = graph._len + # self._graph_namespace = graph._graph_namespace + # self._owning_module = graph._owning_module + # self._codegen = graph._codegen + # self._co_fields: Dict[str, Any] = graph._co_fields + # -- self._find_nodes_lookup_table = _FindNodesLookupTable() + + def unpickle( + self, gm: torch.fx.GraphModule, unpickle_state: _UnpickleState + ) -> torch.fx.Graph: + graph = torch.fx.Graph(gm, self.tracer_cls, self.tracer_extras) + + nodes: dict[_NodePickleData, torch.fx.Node] = {} + for nd in self.nodes: + nodes[nd] = nd.unpickle(graph, nodes, unpickle_state) + + return graph + + +class _TracingContextPickleData: + @classmethod + def reduce_helper( + cls, pickler: GraphPickler, obj: torch._guards.TracingContext + ) -> tuple[ + Callable[[Self, _UnpickleState], torch._guards.TracingContext], + tuple[Self, _UnpickleStateToken], + ]: + return ( + cls.unpickle, + ( + cls(obj), + pickler._unpickle_state, + ), + ) + + def __init__(self, context: TracingContext) -> None: + # TODO: Do we really need all of this? 
+ self.module_context = context.module_context + self.frame_summary_stack = context.frame_summary_stack + self.loc_in_frame = context.loc_in_frame + self.aot_graph_name = context.aot_graph_name + self.params_flat = context.params_flat + self.params_flat_unwrap_subclasses = context.params_flat_unwrap_subclasses + self.params_unwrapped_to_flat_index = context.params_unwrapped_to_flat_index + self.output_strides = context.output_strides + self.force_unspec_int_unbacked_size_like = ( + context.force_unspec_int_unbacked_size_like + ) + # Not saved (because it's difficult and maybe not needed?): + # self.fw_metadata = context.fw_metadata + # self.guards_context = None + # self.global_context = None + # self.fake_mode = None + # self.fakify_first_call = None + # self.hop_dispatch_set_cache = None + # self.tensor_to_context = context.tensor_to_context + + def unpickle(self, unpickle_state: _UnpickleState) -> TracingContext: + context = TracingContext(unpickle_state.fake_mode) + context.module_context = self.module_context + context.frame_summary_stack = self.frame_summary_stack + context.loc_in_frame = self.loc_in_frame + context.aot_graph_name = self.aot_graph_name + context.params_flat = self.params_flat + context.params_flat_unwrap_subclasses = self.params_flat_unwrap_subclasses + context.params_unwrapped_to_flat_index = self.params_unwrapped_to_flat_index + context.output_strides = self.output_strides + context.force_unspec_int_unbacked_size_like = ( + self.force_unspec_int_unbacked_size_like + ) + return context diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/_lazy_graph_module.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/_lazy_graph_module.py new file mode 100644 index 0000000000000000000000000000000000000000..83ce51fddd0405213ce86e95bafb2c61503e262b --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/_lazy_graph_module.py @@ -0,0 +1,180 @@ +# mypy: allow-untyped-defs 
+from contextlib import contextmanager + +from torch.fx.graph_module import ( + _format_import_block, + GraphModule, + reduce_graph_module, + reduce_package_graph_module, +) +from torch.package import PackageExporter, sys_importer + +from ._compatibility import compatibility + + +_use_lazy_graph_module_flag = False +_force_skip_lazy_graph_module_flag = False + + +@compatibility(is_backward_compatible=False) +@contextmanager +def _force_skip_lazy_graph_module(): + """ + Skip using lazy graph module disregarding the setting of _use_lazy_graph_module. + Use to skip _LazyGraphModule when testing inductor torchscript related backend. + + torch.jit.script a _LazyGraphModule results in following error: + https://gist.github.com/shunting314/5143654c8084aed84ecd19b818258a69 + """ + try: + global _force_skip_lazy_graph_module_flag + prior = _force_skip_lazy_graph_module_flag + _force_skip_lazy_graph_module_flag = True + yield + finally: + _force_skip_lazy_graph_module_flag = prior + + +@compatibility(is_backward_compatible=False) +@contextmanager +def _use_lazy_graph_module(should_use: bool): + try: + global _use_lazy_graph_module_flag + prior = _use_lazy_graph_module_flag + _use_lazy_graph_module_flag = ( + should_use and not _force_skip_lazy_graph_module_flag + ) + yield + finally: + _use_lazy_graph_module_flag = prior + + +@compatibility(is_backward_compatible=False) +def _get_graph_module_cls(): + return _LazyGraphModule if _use_lazy_graph_module_flag else GraphModule + + +def _make_graph_module(*args, graph_module_cls=None, **kwargs): + if graph_module_cls is None: + graph_module_cls = _get_graph_module_cls() + + return graph_module_cls(*args, **kwargs) + + +@compatibility(is_backward_compatible=False) +class _LazyGraphModule(GraphModule): + """ + The main difference between _LazyGraphModule and GraphModule is how recompile happens. + GraphModule will do a 'recompile' call to generate python code and the forward method when it's + constructed. 
Later on if the graph get updated, recompile method can be called again to refresh + the saved python code and forward method. + + However in some cases especially in inductor, the recompilation can be a waste since we never + check the python code for the graph module or call its forward method. A few more concreate + examples regarding pattern matching fx passes in inductor: + 1. some passes will update the graph to be compiled and then call recompile on the GraphModule. + 2. some passes will trace small pattern function to search it in the graph being compiled and + replace the match with the traced graph of a replacement function. The pattern graph and + replacement graph are quite small but there are large amount of them. Doing GraphModule.recompile + for them in GraphModule.__init__ is also a waste of time. + + However simply skip calling GraphModule.recompile in these scenarios is also dangeruous. + People may want to check the python code or call the GraphModule's forward method for debugging purposes. + + The way _LazyGraphModule solves it is, we override the recompile method to just mark the + need for recompilation but does not do the actual recompilation. Later on if people really + access the compiled python code or call the GraphModule's forward method, we do the real + recompilation. 
+ """ + + @classmethod + def from_graphmodule(cls, gm: GraphModule): + if isinstance(gm, _LazyGraphModule): + return gm + else: + return _LazyGraphModule(gm, gm.graph) + + @staticmethod + def force_recompile(gm): + """ + Sometimes we need force a recompile as a workaround + - we want to do the real recompilation before symbolic_trace to avoid error: + https://gist.github.com/shunting314/75549c2e82ae07ac1139c94a3583d259 + """ + if isinstance(gm, _LazyGraphModule): + gm.real_recompile() + + def real_recompile(self): + if self._needs_recompile(): + self._real_recompile() + + @classmethod + def _needs_recompile(cls): + return cls.forward is cls._lazy_forward + + def _lazy_forward(self, *args, **kwargs): + # Call self.real_recompile() rather than self._real_recompile() here. + # The _lazy_forward method may be saved and call repeatedly. + # Calling self.real_recompile can make sure we skip recompilation if + # we have already done so. + self.real_recompile() + assert not self._needs_recompile() + + # call `__call__` rather than 'forward' since recompilation may + # install a wrapper for `__call__` to provide a customized error + # message. + return self(*args, **kwargs) + + forward = _lazy_forward + + def __reduce_package__(self, exporter: PackageExporter): + """ + Follow GraphModule.__reduce__ but call 'self._real_recompile' rather + than 'self.recompile' since for a _LazyGraphModule, self.recompile just + mark the need of recompilation and does not return the PythonCode object. 
+ """ + python_code = self._real_recompile() + dict_without_graph = self.__dict__.copy() + dict_without_graph["_graphmodule_cls_name"] = self.__class__.__name__ + del dict_without_graph["_graph"] + + generated_module_name = f"fx-generated._{exporter.get_unique_id()}" + import_block = _format_import_block(python_code.globals, exporter.importer) + module_code = import_block + self.code + exporter.save_source_string(generated_module_name, module_code) + return ( + reduce_package_graph_module, + (dict_without_graph, generated_module_name), + ) + + def __reduce__(self): + """ + Follow GraphModule.__reduce__ but call 'self._real_recompile' rather + than 'self.recompile' since for a _LazyGraphModule, self.recompile just + mark the need of recompilation and does not return the PythonCode object. + """ + python_code = self._real_recompile() + dict_without_graph = self.__dict__.copy() + import_block = _format_import_block(python_code.globals, sys_importer) + del dict_without_graph["_graph"] + return (reduce_graph_module, (dict_without_graph, import_block)) + + def _real_recompile(self): + return super().recompile() + + @classmethod + def recompile(cls): + cls.forward = cls._lazy_forward + + @property + def code(self) -> str: + self.real_recompile() + return super().code + + def __str__(self) -> str: + """ + str(GraphModule) will access the _code attribute. Make sure recompile + happens so _code attribute is available. 
+ """ + self.real_recompile() + return super().__str__() diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/_pytree.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/_pytree.py new file mode 100644 index 0000000000000000000000000000000000000000..bfb62f871eb6f5ba7788fd4920e292dda95738c6 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/_pytree.py @@ -0,0 +1,114 @@ +from collections import namedtuple +from collections.abc import Callable +from typing import Any, Optional, TypeVar +from typing_extensions import NamedTuple + +import torch.return_types +from torch.utils._pytree import PyTree, tree_flatten, TreeSpec + + +FlattenFuncSpec = Callable[[PyTree, TreeSpec], list] +FlattenFuncExactMatchSpec = Callable[[PyTree, TreeSpec], bool] + +SUPPORTED_NODES: dict[type[Any], FlattenFuncSpec] = {} +SUPPORTED_NODES_EXACT_MATCH: dict[type[Any], Optional[FlattenFuncExactMatchSpec]] = {} + +_T = TypeVar("_T") +_K = TypeVar("_K") +_V = TypeVar("_V") + + +def register_pytree_flatten_spec( + cls: type[Any], + flatten_fn_spec: FlattenFuncSpec, + flatten_fn_exact_match_spec: Optional[FlattenFuncExactMatchSpec] = None, +) -> None: + SUPPORTED_NODES[cls] = flatten_fn_spec + SUPPORTED_NODES_EXACT_MATCH[cls] = flatten_fn_exact_match_spec + + +def _deregister_pytree_flatten_spec( + cls: type[Any], +) -> None: + del SUPPORTED_NODES[cls] + del SUPPORTED_NODES_EXACT_MATCH[cls] + + +def tree_flatten_spec( + pytree: PyTree, + spec: TreeSpec, +) -> list[Any]: + if spec.is_leaf(): + return [pytree] + # I guess these exist for BC, FC reasons. + # In general, we should be able to directly + # use pytree tree flattener to flatten them, + # as export serializes the pytree separately. + # Will remove it in follow up PR. 
+ if spec.type in SUPPORTED_NODES: + flatten_fn_spec = SUPPORTED_NODES[spec.type] + child_pytrees = flatten_fn_spec(pytree, spec) + result = [] + for child, child_spec in zip(child_pytrees, spec.children()): + flat = tree_flatten_spec(child, child_spec) + result += flat + return result + flat_result, real_spec = tree_flatten(pytree) + if spec != real_spec: + raise RuntimeError( + f"Real spec {real_spec} of object {pytree} is different from expected spec {spec}. " + f"Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml" + ) + return flat_result + + +def _dict_flatten_spec(d: dict[_K, _V], spec: TreeSpec) -> list[_V]: + return [d[k] for k in spec.context] + + +def _list_flatten_spec(d: list[_T], spec: TreeSpec) -> list[_T]: + return [d[i] for i in range(spec.num_children)] + + +def _tuple_flatten_spec(d: tuple[_T, ...], spec: TreeSpec) -> list[_T]: + return [d[i] for i in range(spec.num_children)] + + +def _namedtuple_flatten_spec(d: NamedTuple, spec: TreeSpec) -> list[Any]: + return [d[i] for i in range(spec.num_children)] + + +def _dict_flatten_spec_exact_match(d: dict[_K, _V], spec: TreeSpec) -> bool: + return len(d) == spec.num_children + + +def _list_flatten_spec_exact_match(d: list[_T], spec: TreeSpec) -> bool: + return len(d) == spec.num_children + + +def _tuple_flatten_spec_exact_match(d: tuple[_T, ...], spec: TreeSpec) -> bool: + return len(d) == spec.num_children + + +def _namedtuple_flatten_spec_exact_match(d: NamedTuple, spec: TreeSpec) -> bool: + return len(d) == spec.num_children + + +register_pytree_flatten_spec(dict, _dict_flatten_spec, _dict_flatten_spec_exact_match) +register_pytree_flatten_spec(list, _list_flatten_spec, _list_flatten_spec_exact_match) +register_pytree_flatten_spec( + tuple, + _tuple_flatten_spec, + _tuple_flatten_spec_exact_match, +) +for return_type in torch.return_types.all_return_types: + register_pytree_flatten_spec( + return_type, + _tuple_flatten_spec, + 
_tuple_flatten_spec_exact_match, + ) +register_pytree_flatten_spec( + namedtuple, # type: ignore[arg-type] + _namedtuple_flatten_spec, + _namedtuple_flatten_spec_exact_match, +) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py new file mode 100644 index 0000000000000000000000000000000000000000..dd3482f3e04602ef800222e6a6f88727c681705d --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py @@ -0,0 +1,1368 @@ +# mypy: allow-untyped-defs +import builtins +import collections +import contextlib +import copy +import functools +import inspect +import logging +import math +import os +import warnings +from collections.abc import Callable +from itertools import chain +from types import CodeType, FunctionType, ModuleType +from typing import Any, get_args, NamedTuple, Optional, TypeAlias, Union + +import torch +import torch.utils._pytree as pytree +from torch._C import ScriptObject # type: ignore[attr-defined] +from torch._library.fake_class_registry import FakeScriptObject +from torch._library.opaque_object import is_opaque_reference_type, is_opaque_type + +from ._compatibility import compatibility +from ._lazy_graph_module import _make_graph_module +from .graph import _PyTreeCodeGen, _PyTreeInfo, Graph +from .graph_module import GraphModule +from .node import Argument, base_types, map_aggregate +from .proxy import ParameterProxy, Proxy, Scope, ScopeContextManager, TracerBase + + +log = logging.getLogger(__name__) + +HAS_VARSTUFF = inspect.CO_VARARGS | inspect.CO_VARKEYWORDS + +# These need to run in global scope to handle nested calls correctly +_orig_module_call: Callable = torch.nn.Module.__call__ +_orig_module_getattr: Callable = torch.nn.Module.__getattr__ + +_proxyable_classes: dict[type, None] = {} + +_is_fx_tracing_flag = False + +_ConstantAttributeType: TypeAlias = Union[ + 
torch.Tensor, torch.ScriptObject, FakeScriptObject, pytree.TreeSpec +] + +_constant_attribute_types = get_args(_ConstantAttributeType) + + +# We only want to print this once to avoid flooding logs +@functools.lru_cache +def is_fx_tracing_warning(): + log.warning( + "is_fx_tracing will return true for both fx.symbolic_trace and " + "torch.export. Please use " + "is_fx_tracing_symbolic_tracing() for specifically fx.symbolic_trace " + "or torch.compiler.is_compiling() for specifically torch.export/compile." + ) + + +def is_fx_tracing(): + is_fx_tracing_warning() + return _is_fx_tracing_flag + + +def is_fx_symbolic_tracing(): + return _is_fx_tracing_flag and not torch.compiler.is_compiling() + + +@compatibility(is_backward_compatible=True) +class ProxyableClassMeta(type): + """ + ProxyableClassMeta allows you to make construction of a given Python class + symbolically traceable. For example:: + + import torch + import torch.fx + + + class TensorPair(metaclass=torch.fx.ProxyableClassMeta): + def __init__(self, left, right): + self.left, self.right = left, right + + def add(self, other): + l = self.left + other.left + r = self.right + other.right + return TensorPair(l, r) + + def mul(self, other): + l = self.left * other.left + r = self.right * other.right + return TensorPair(l, r) + + + def use_tensor_pair_ctor(x: TensorPair, y: torch.Tensor): + s = x.add(TensorPair(y, y)) + return s.mul(x) + + + x = TensorPair(torch.randn(5, 3), torch.randn(5, 3)) + y = torch.randn(5, 3) + ref_out = use_tensor_pair_ctor(x, y) + + traced = torch.fx.symbolic_trace(use_tensor_pair_ctor) + print(traced.code) + ''' + def forward(self, x : __main___TensorPair, y : torch.Tensor): + tensor_pair = __main___TensorPair(y, y); y = None + add = x.add(tensor_pair); tensor_pair = None + mul = add.mul(x); add = x = None + return mul + ''' + + From this example, we can see that construction of a class (``TensorPair``) + defined with ``ProxyableClassMeta`` as metaclass can be recorded in symbolic + 
tracing. + """ + + def __init__(cls, name, bases, attrs): + _proxyable_classes.setdefault(cls) + super().__init__(name, bases, attrs) + + def __call__(cls, *args, **kwargs): + instance = cls.__new__(cls) # type: ignore[call-overload] + + if not is_fx_tracing(): + cls.__init__(instance, *args, **kwargs) # type: ignore[misc] + return instance + + found_proxies = [] + + def check_proxy(a): + if isinstance(a, Proxy): + found_proxies.append(a) + + map_aggregate(args, check_proxy) + map_aggregate(kwargs, check_proxy) + + if len(found_proxies) != 0: + tracer = found_proxies[0].tracer + return tracer.create_proxy("call_function", cls, args, kwargs) + else: + cls.__init__(instance, *args, **kwargs) # type: ignore[misc] + return instance + + +def _patch_function(fn: FunctionType, nargs: int) -> FunctionType: + co = fn.__code__ + co_flags = co.co_flags & ~HAS_VARSTUFF + co_args: tuple + if hasattr(co, "co_qualname"): + # Python-3.11+ code signature + co_args = ( + nargs, + 0, + 0, + co.co_nlocals, + co.co_stacksize, + co_flags, + co.co_code, + co.co_consts, + co.co_names, + co.co_varnames, + co.co_filename, + co.co_name, + co.co_qualname, # type: ignore[attr-defined] + co.co_firstlineno, + co.co_linetable, + co.co_exceptiontable, # type: ignore[attr-defined] + co.co_freevars, + co.co_cellvars, + ) + elif hasattr(co, "co_posonlyargcount"): + co_args = ( + nargs, + 0, + 0, + co.co_nlocals, + co.co_stacksize, + co_flags, + co.co_code, + co.co_consts, + co.co_names, + co.co_varnames, + co.co_filename, + co.co_name, + co.co_firstlineno, + co.co_lnotab, + co.co_freevars, + co.co_cellvars, + ) + else: + co_args = ( + nargs, + 0, + co.co_nlocals, + co.co_stacksize, + co_flags, + co.co_code, + co.co_consts, + co.co_names, + co.co_varnames, + co.co_filename, + co.co_name, + co.co_firstlineno, + co.co_lnotab, + co.co_freevars, + co.co_cellvars, + ) + new_code = CodeType(*co_args) # type: ignore[arg-type] + return FunctionType( + new_code, fn.__globals__, fn.__name__, fn.__defaults__, 
fn.__closure__ + ) + + # we need to insert placeholder nodes for *args and **kwargs + # we can't call this function normally, otherwise it would try to unpack them + # instead, let's make python think that args and kwargs are normal variables + + +@compatibility(is_backward_compatible=False) +class PHBase: + """ + Object representing an input placeholder to `concrete_args` + """ + + def __repr__(self): + return "PH" + + +PH = PHBase() + + +@compatibility(is_backward_compatible=False) +class PHWithMeta(PHBase): + """ + Object representing an input placeholder to `concrete_args` + """ + + def __init__(self, ph_key: Optional[str] = None): + super().__init__() + + # Provide a hey for user to identify placeholder node during analysis + self.ph_key = ph_key + + +def _transfer_attrs(fr, to): + for attr_name in dir(fr): + attr_val = getattr(fr, attr_name) + if ( + not callable(attr_val) + and not attr_name.startswith("__") + and not hasattr(to, attr_name) + ): + setattr(to, attr_name, attr_val) + + +@compatibility(is_backward_compatible=True) +class Tracer(TracerBase): + # Reference: https://github.com/pytorch/pytorch/issues/54354 + # The first line of this docstring overrides the one Sphinx generates for the + # documentation. We need it so that Sphinx doesn't leak `math`s path from the + # build environment (e.g. ` None: + # This method's signature is overridden by the first line of this class' + # docstring. If this method's signature is modified, the signature that + # overrides it also should be modified accordingly. + + """ + Construct a Tracer object. + + Args: + + autowrap_modules (Tuple[ModuleType]): defaults to `(math, )`, + Python modules whose functions should be wrapped automatically + without needing to use fx.wrap(). Backward-compatibility for + this parameter is guaranteed. + + autowrap_functions (Tuple[Callable, ...]): defaults to `()`, + Python functions that should be wrapped automatically without + needing to use fx.wrap(). 
Backward compatibility for this + parameter is guaranteed. + + param_shapes_constant (bool): When this flag is set, calls to shape, + size and a few other shape like attributes of a module's parameter + will be evaluated directly, rather than returning a new Proxy value + for an attribute access. Backward compatibility for this parameter + is guaranteed. + """ + + super().__init__() + + # Functions we will eagerly wrap when we see them while tracing + # this captures both `math.sqrt()` and `from math import sqrt` automatically + self._autowrap_function_ids: set[int] = { + id(value) + for name, value in chain.from_iterable( + m.__dict__.items() for m in autowrap_modules + ) + if not name.startswith("_") and callable(value) + } + self._autowrap_function_ids.update({id(f) for f in autowrap_functions}) + + # Python modules to apply autowrap to at the start, in addition to + # modules we see while tracing + self._autowrap_search: list[ModuleType] = list(autowrap_modules) + self.param_shapes_constant = param_shapes_constant + + self.submodule_paths: Optional[dict[torch.nn.Module, str]] = None + self.root_module_name: str = "" + # Maps the containing module's name to the operator name + self.scope = Scope("", None) + # Records the module call stack + self.module_stack = collections.OrderedDict() + self.num_calls: dict[str, int] = {} + # Mapping of node name to module scope + self.node_name_to_scope: dict[str, tuple[str, type]] = {} + + _qualname_counter: dict[str, int] = collections.defaultdict(int) + + @compatibility(is_backward_compatible=True) + def get_fresh_qualname(self, prefix: str) -> str: + """ + Gets a fresh name for a prefix and returns it. This function ensures + that it will not clash with an existing attribute on the graph. + """ + # The idea here is that if the module doesn't have this prefix at all we + # should reset the counter to start from the beginning + # It's a ... 
little bit hacky (doesn't cover all cases) but the precise + # naming of the prefixes isn't a correctness issue, just a niceness + # issue + qualname = f"{prefix}0" + if not hasattr(self.root, qualname): + self._qualname_counter[prefix] = 0 + return qualname + + i = self._qualname_counter[prefix] + while True: + qualname = f"{prefix}{i}" + i += 1 + if not hasattr(self.root, qualname): + break + self._qualname_counter[prefix] = i + + return qualname + + @compatibility(is_backward_compatible=True) + def create_arg(self, a: Any) -> "Argument": + """ + A method to specify the behavior of tracing when preparing values to + be used as arguments to nodes in the ``Graph``. + + By default, the behavior includes: + + #. Iterate through collection types (e.g. tuple, list, dict) and recursively + call ``create_args`` on the elements. + #. Given a Proxy object, return a reference to the underlying IR ``Node`` + #. Given a non-Proxy Tensor object, emit IR for various cases: + + * For a Parameter, emit a ``get_attr`` node referring to that Parameter + * For a non-Parameter Tensor, store the Tensor away in a special + attribute referring to that attribute. + + This method can be overridden to support more types. + + Args: + + a (Any): The value to be emitted as an ``Argument`` in the ``Graph``. + + + Returns: + + The value ``a`` converted into the appropriate ``Argument`` + """ + # The base tracer is used to construct Graphs when there is no associated + # module hierarchy, so it can never create parameter references. + # The default tracer adds the ability to refer to parameters when + # tracing modules. 
+ if isinstance(a, torch.nn.Parameter): + for n, p in self.root.named_parameters(): + if a is p: + return self.create_node("get_attr", n, (), {}) + raise NameError("parameter is not a member of this module") + elif isinstance(a, torch.Tensor): + for n_, p_ in self.root.named_buffers(): + if a is p_: + return self.create_node("get_attr", n_, (), {}) + elif isinstance(a, torch.nn.Module): + for n_, p_ in self.root.named_modules(): + if a is p_: + return self.create_node("get_attr", n_, (), {}) + # For NamedTuple instances that appear literally as args, we emit + # a node to construct the NamedTuple and use that Node as the argument. + if isinstance(a, tuple) and hasattr(a, "_fields"): + args = tuple(self.create_arg(elem) for elem in a) + return self.create_node("call_function", a.__class__, args, {}) + + # Tensors do not have a reliable string repr() from which they can be + # constructed (and we probably don't want to rely on that, either), so + # for any constant Tensor values we encounter, first search for if they + # are an attribute of some module in the module hierarchy. If so, emit + # a get_attr to retrieve that tensor. Otherwise, we'll store away the + # tensor value into a special attribute on the Module s.t. we can + # retrieve it with a get_attr. + if isinstance(a, _constant_attribute_types) or ( + is_opaque_reference_type(type(a)) + ): + qualname: Optional[str] = self.tensor_attrs.get( + a + ) # pyrefly: ignore[no-matching-overload] + + # Tensor was not found in the Module hierarchy, stow it away in a + # special attribute and set the qualname to refer to that + if not qualname: + if isinstance(a, torch.Tensor): + base_name = "_tensor_constant" + elif isinstance(a, (FakeScriptObject, ScriptObject)): + base_name = "_torchbind_obj" + elif isinstance(a, pytree.TreeSpec): + base_name = "_tree_spec_constant" + elif is_opaque_type(type(a)): + base_name = "_opaque_obj" + else: + raise RuntimeError( + f"cannot create constant arg for {a} of type {type(a)}." 
+ ) + qualname = self.get_fresh_qualname(base_name) + assert isinstance(qualname, str) + self.tensor_attrs[a] = ( # pyrefly: ignore[unsupported-operation] + qualname + ) + setattr(self.root, qualname, a) + + return self.create_node("get_attr", qualname, (), {}) + + if type(a) in _proxyable_classes: + # This is an instance of a proxyable class for which we did not + # witness its construction. Intern this as a constant attribute + + # TODO: binary search + qualname = self.get_fresh_qualname(f"_{a.__class__.__name__}_constant_") + assert isinstance(qualname, str) + setattr(self.root, qualname, a) + + return self.create_node("get_attr", qualname, (), {}) + return super().create_arg(a) + + @compatibility(is_backward_compatible=True) + def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool: + """ + A method to specify whether a given ``nn.Module`` is a "leaf" module. + + Leaf modules are the atomic units that appear in + the IR, referenced by ``call_module`` calls. By default, + Modules in the PyTorch standard library namespace (torch.nn) + are leaf modules. All other modules are traced through and + their constituent ops are recorded, unless specified otherwise + via this parameter. + + Args: + + m (Module): The module being queried about + module_qualified_name (str): The path to root of this module. For example, + if you have a module hierarchy where submodule ``foo`` contains + submodule ``bar``, which contains submodule ``baz``, that module will + appear with the qualified name ``foo.bar.baz`` here. + """ + return ( + m.__module__.startswith("torch.nn") + or m.__module__.startswith("torch.ao.nn") + ) and not isinstance(m, torch.nn.Sequential) + + @compatibility(is_backward_compatible=True) + def path_of_module(self, mod: torch.nn.Module) -> str: + """ + Helper method to find the qualified name of ``mod`` in the Module hierarchy + of ``root``. 
For example, if ``root`` has a submodule named ``foo``, which has + a submodule named ``bar``, passing ``bar`` into this function will return + the string "foo.bar". + + Args: + + mod (str): The ``Module`` to retrieve the qualified name for. + """ + # Prefer the O(1) algorithm + if self.submodule_paths: + path = self.submodule_paths.get(mod) + if path is None: + raise NameError("module is not installed as a submodule") + assert isinstance(path, str) + return path + # O(N^2) fallback in the case that we didn't store the submodule + # paths. + else: + for n, p in self.root.named_modules(): + if mod is p: + return n + raise NameError("module is not installed as a submodule") + + @compatibility(is_backward_compatible=True) + def call_module( + self, + m: torch.nn.Module, + forward: Callable[..., Any], + args: tuple[Any, ...], + kwargs: dict[str, Any], + ) -> Any: + """ + Method that specifies the behavior of this ``Tracer`` when it encounters + a call to an ``nn.Module`` instance. + + By default, the behavior is to check if the called module is a leaf module + via ``is_leaf_module``. If it is, emit a ``call_module`` node referring to + ``m`` in the ``Graph``. Otherwise, call the ``Module`` normally, tracing through + the operations in its ``forward`` function. + + This method can be overridden to--for example--create nested traced + GraphModules, or any other behavior you would want while tracing across + ``Module`` boundaries. + + Args: + + m (Module): The module for which a call is being emitted + forward (Callable): The forward() method of the ``Module`` to be invoked + args (Tuple): args of the module callsite + kwargs (Dict): kwargs of the module callsite + + Return: + + The return value from the Module call. In the case that a ``call_module`` + node was emitted, this is a ``Proxy`` value. Otherwise, it is whatever + value was returned from the ``Module`` invocation. 
+ """ + module_qualified_name = self.path_of_module(m) + with ScopeContextManager( + self.scope, Scope(module_qualified_name, type(m)) + ) as _scope: + # module_stack is an ordered dict so writing then deleting the + # entry is equivalent to push/pop on a list + num_calls = self.num_calls.get(module_qualified_name, 0) + module_key = ( + f"{_scope.module_path}@{num_calls}" + if num_calls > 0 + else _scope.module_path + ) + self.module_stack[module_key] = (module_qualified_name, _scope.module_type) + self.num_calls[module_qualified_name] = num_calls + 1 + if not self.is_leaf_module(m, module_qualified_name): + ret_val = forward(*args, **kwargs) + else: + ret_val = self.create_proxy( + "call_module", module_qualified_name, args, kwargs + ) + key, _ = self.module_stack.popitem(last=True) + assert key == module_key, f" Unexpected key {key}" + + return ret_val + + @compatibility(is_backward_compatible=False) + def getattr(self, attr: str, attr_val: Any, parameter_proxy_cache: dict[str, Any]): + """ + Method that specifies the behavior of this ``Tracer`` when we call getattr + on a call to an ``nn.Module`` instance. + + By default, the behavior is to return a proxy value for the attribute. It + also stores the proxy value in the ``parameter_proxy_cache``, so that future + calls will reuse the proxy rather than creating a new one. + + This method can be overridden to --for example-- not return proxies when + querying parameters. + + Args: + + attr (str): The name of the attribute being queried + attr_val (Any): The value of the attribute + parameter_proxy_cache (Dict[str, Any]): A cache of attr names to proxies + + Return: + + The return value from the getattr call. 
+ """ + + def maybe_get_proxy_for_attr( + attr_val, collection_to_search, parameter_proxy_cache + ): + for n, p in collection_to_search: + if attr_val is p: + if n not in parameter_proxy_cache: + kwargs = {} + if ( + "proxy_factory_fn" + in inspect.signature(self.create_proxy).parameters + ): + kwargs["proxy_factory_fn"] = ( + # pyrefly: ignore [unsupported-operation] + None + if not self.param_shapes_constant + else lambda node: ParameterProxy( + self, node, n, attr_val + ) + ) + val_proxy = self.create_proxy("get_attr", n, (), {}, **kwargs) # type: ignore[arg-type] + parameter_proxy_cache[n] = val_proxy + return parameter_proxy_cache[n] + return None + + if isinstance(attr_val, torch.nn.Parameter): + maybe_parameter_proxy = maybe_get_proxy_for_attr( + attr_val, self.root.named_parameters(), parameter_proxy_cache + ) + if maybe_parameter_proxy is not None: + return maybe_parameter_proxy + + if self.proxy_buffer_attributes and isinstance(attr_val, torch.Tensor): + maybe_buffer_proxy = maybe_get_proxy_for_attr( + attr_val, self.root.named_buffers(), parameter_proxy_cache + ) + if maybe_buffer_proxy is not None: + return maybe_buffer_proxy + + return attr_val + + # This method will be refactored + @compatibility(is_backward_compatible=False) + def create_args_for_root(self, root_fn, is_module, concrete_args=None): + """ + Create ``placeholder`` nodes corresponding to the signature of the ``root`` + Module. This method introspects root's signature and emits those + nodes accordingly, also supporting ``*args`` and ``**kwargs``. + """ + # In some cases, a function or method has been decorated with a wrapper + # defined via ``functools.wraps``. In this case, the outer code object + # will likely not contain the actual parameters we care about, so unwrap + # the function to get to the innermost callable. 
+ fn_for_analysis = inspect.unwrap(root_fn) + co = fn_for_analysis.__code__ + total_args = co.co_argcount + co.co_kwonlyargcount + orig_args = list(co.co_varnames) + names_iter = iter(co.co_varnames) + args: list[Any] = [] + skip_arg_idx = 0 + if is_module: + if total_args == 0: + raise RuntimeError( + "``self`` argument cannot be part of *args expansion!" + ) + skip_arg_idx = 1 + next(names_iter) # skip self + args.append(self.root) + + sig = inspect.signature(fn_for_analysis) + + # This covers the very specific case where we are passing in flat + # concrete_args as a tuple, but our traced fn takes (*args, **kwargs). + # In this case, just take the concrete_args and pass them through. + name_idx = 0 + if ( + isinstance(concrete_args, tuple) + and len(concrete_args) > 0 + and (co.co_flags & HAS_VARSTUFF) + and total_args == 1 + ): + for concrete_arg in concrete_args: + out = self.create_proxy("placeholder", f"input_{name_idx}", (), {}) + if isinstance(concrete_arg, PHBase): + if concrete_arg != PH: + # Transfer attrs in the case where you're using a placeholder other + # than the singleton PH (PH has no attributes to transfer). + # Proxies were created out of the placeholders. + # Transfer any metadata (put on the placeholders in the form of + # attributes set by the user) from the placeholder to the + # underlying nodes (the proxy is unwrapped by the user, but + # the metadata should hold). 
+ _transfer_attrs(fr=concrete_arg, to=out.node) + args.append(out) + name_idx += 1 + return root_fn, args + + arg_names = [next(names_iter) for idx in range(skip_arg_idx, total_args)] + if isinstance(concrete_args, tuple): + if len(arg_names) != len(concrete_args): + raise RuntimeError( + f"Tracing expected {len(arg_names)} arguments but got {len(concrete_args)} concrete arguments" + ) + concrete_args = dict(zip(arg_names, concrete_args)) + + def proxy_placeholder(name): + return self._proxy_placeholder(name, concrete_args, sig, fn_for_analysis) + + args.extend(proxy_placeholder(names) for names in arg_names) + + if co.co_kwonlyargcount > 0 or co.co_flags & HAS_VARSTUFF: + # TODO: type annotations for *args and **kwargs + if co.co_flags & inspect.CO_VARARGS: + args.append(proxy_placeholder("*" + next(names_iter))) + if co.co_flags & inspect.CO_VARKEYWORDS: + args.append(proxy_placeholder("**" + next(names_iter))) + root_fn = _patch_function(root_fn, len(args)) + + flat_args, in_spec = pytree.tree_flatten(tuple(args)) + if not all(child.is_leaf() for child in in_spec.children()): + # In the case that we have pytree-flattened inputs in + # `concrete_args`, generate a flattening wrapper around the + # original root function and return that. 
+ self.graph._codegen = _PyTreeCodeGen( # type: ignore[has-type] + _PyTreeInfo(orig_args[:total_args], in_spec, None) + ) + + def flatten_fn(*args): + tree_args = pytree.tree_unflatten(list(args), in_spec) + tree_out = root_fn(*tree_args) + out_args, out_spec = pytree.tree_flatten(tree_out) + assert isinstance(self.graph._codegen, _PyTreeCodeGen) # type: ignore[has-type] + self.graph._codegen.pytree_info = ( + self.graph._codegen.pytree_info._replace(out_spec=out_spec) + ) + return out_args + + return flatten_fn, flat_args + return root_fn, args + + @compatibility(is_backward_compatible=True) + def trace( + self, + root: Union[torch.nn.Module, Callable[..., Any]], + concrete_args: Optional[dict[str, Any]] = None, + ) -> Graph: + """ + Trace ``root`` and return the corresponding FX ``Graph`` representation. ``root`` + can either be an ``nn.Module`` instance or a Python callable. + + Note that after this call, ``self.root`` may be different from the ``root`` passed + in here. For example, when a free function is passed to ``trace()``, we will + create an ``nn.Module`` instance to use as the root and add embedded constants + to. + + + Args: + + root (Union[Module, Callable]): Either a ``Module`` or a function to be + traced through. Backwards-compatibility for this parameter is + guaranteed. + concrete_args (Optional[Dict[str, any]]): Concrete arguments that should + not be treated as Proxies. This parameter is experimental and + its backwards-compatibility is *NOT* guaranteed. + + Returns: + + A ``Graph`` representing the semantics of the passed-in ``root``. + """ + global _is_fx_tracing_flag + old_is_fx_tracing_flag = _is_fx_tracing_flag + _is_fx_tracing_flag = True + try: + if isinstance(root, torch.nn.Module): + # do real recompilation for _LazyGraphModule before retracing since the trace + # method can not trace the _lazy_forward method. Got error: + # https://gist.github.com/shunting314/75549c2e82ae07ac1139c94a3583d259 + # without this. 
+ from torch.fx._lazy_graph_module import _LazyGraphModule + + _LazyGraphModule.force_recompile(root) + + self.root = root + + assert hasattr(type(root), self.traced_func_name), ( + f"traced_func_name={self.traced_func_name} doesn't exist in {type(root).__name__}" + ) + + fn = getattr(type(root), self.traced_func_name) + self.root_module_name = root._get_name() + self.submodule_paths = {mod: name for name, mod in root.named_modules()} + else: + self.root = torch.nn.Module() + fn = root + + tracer_cls: Optional[type[Tracer]] = getattr(self, "__class__", None) + self.graph = Graph(tracer_cls=tracer_cls) + if hasattr(fn, "__code__"): + code = fn.__code__ + self.graph._co_fields = { + "co_name": code.co_name, + "co_filename": code.co_filename, + "co_firstlineno": code.co_firstlineno, + } + + # When we encounter a Tensor value that's not a parameter, we look if it + # is some other attribute on the model. Construct a dict mapping Tensor + # values to the qualified name here for efficiency. This is used downstream + # in create_arg + self.tensor_attrs: dict[ + _ConstantAttributeType, + str, + ] = {} + + def collect_tensor_attrs(m: torch.nn.Module, prefix_atoms: list[str]): + for k, v in m.__dict__.items(): + if isinstance(v, _constant_attribute_types): + self.tensor_attrs[v] = ".".join(prefix_atoms + [k]) + for k, v in m.named_children(): + collect_tensor_attrs(v, prefix_atoms + [k]) + + collect_tensor_attrs(self.root, []) + + assert isinstance(fn, FunctionType) + + fn_globals = fn.__globals__ # run before it gets patched + fn, args = self.create_args_for_root( + fn, isinstance(root, torch.nn.Module), concrete_args + ) + + parameter_proxy_cache: dict[ + str, Proxy + ] = {} # Reduce number of get_attr calls + + # Method dispatch on parameters is not recorded unless it's directly used. + # Thus, we need to insert a proxy when __getattr__ requests a parameter. 
+ @functools.wraps(_orig_module_getattr) + def module_getattr_wrapper(mod, attr): + attr_val = _orig_module_getattr(mod, attr) + return self.getattr(attr, attr_val, parameter_proxy_cache) + + @functools.wraps(_orig_module_call) + def module_call_wrapper(mod, *args, **kwargs): + def forward(*args, **kwargs): + return _orig_module_call(mod, *args, **kwargs) + + _autowrap_check( + patcher, # type: ignore[has-type] + getattr(getattr(mod, "forward", mod), "__globals__", {}), + self._autowrap_function_ids, + ) + return self.call_module(mod, forward, args, kwargs) + + with _new_patcher() as patcher: + # allow duplicate patches to support the case of nested calls + patcher.patch_method( + torch.nn.Module, + "__getattr__", + module_getattr_wrapper, + deduplicate=False, + ) + patcher.patch_method( + torch.nn.Module, + "__call__", + module_call_wrapper, + deduplicate=False, + ) + _patch_wrapped_functions(patcher) + _autowrap_check(patcher, fn_globals, self._autowrap_function_ids) + for module in self._autowrap_search: + _autowrap_check( + patcher, module.__dict__, self._autowrap_function_ids + ) + ann = inspect.get_annotations(inspect.unwrap(fn)) + self.create_node( + "output", + "output", + (self.create_arg(fn(*args)),), + {}, + type_expr=ann.get("return", None), + ) + + self.submodule_paths = None + except RuntimeError as e: + if e.args and isinstance(e.args[0], str) and "data-dependent" in e.args[0]: + partial_fx_graph = self.graph.python_code( + root_module="self", + verbose=True, + ).src + e.partial_fx_graph = partial_fx_graph # type: ignore[attr-defined] + raise + + raise + finally: + _is_fx_tracing_flag = old_is_fx_tracing_flag + return self.graph + + def __deepcopy__(self, memo): + # _autowrap_search contains modules, which cannot be deepcopied. 
+ new_tracer = Tracer.__new__(Tracer) + + for k, v in self.__dict__.items(): + if k == "_autowrap_search": + new_obj = copy.copy(v) + else: + new_obj = copy.deepcopy(v, memo) + + new_tracer.__dict__[k] = new_obj + + return new_tracer + + def _proxy_placeholder(self, name, concrete_args, sig, fn_for_analysis): + if concrete_args is not None and name in concrete_args: + cnt = 0 + + def replace_ph(x): + nonlocal cnt + cnt += 1 + param = sig.parameters[name] + default: tuple[Any, ...] = ( + () if param.default is inspect.Parameter.empty else (param.default,) + ) + out = self.create_proxy( + "placeholder", f"{name}_{str(cnt)}", default, {} + ) + if isinstance(x, PHBase): + if x != PH: + # Transfer attrs in the case where you're using a placeholder other + # than the singleton PH (PH has no attributes to transfer). + # Proxies were created out of the placeholders. + # Transfer any metadata (put on the placeholders in the form of + # attributes set by the user) from the placeholder to the + # underlying nodes (the proxy is unwrapped by the user, but + # the metadata should hold). + _transfer_attrs(fr=x, to=out.node) + + return out + # Union[int, bool] == bool in Python <= 3.6 + if ( + type(x) is bool + or type(x) in base_types + and type(x) is not torch.Tensor + ): + torch._assert( + out == x, + f"{name} has been specialized to have value {x} but got another value", + ) + elif x is None: + args = ( + out, + f"{name} has been specialized to have value None but got another value", + ) + self.create_proxy("call_function", _assert_is_none, args, {}) + else: + warnings.warn( + f"Was not able to add assertion to guarantee correct input {name} to " + f"specialized function. It is up to the user to make sure that your inputs match the " + f"inputs you specialized the function with." + ) + + return x + + return pytree.tree_map(replace_ph, concrete_args[name]) + if name[0] == "*": + default: tuple[Any, ...] 
= () + else: + param = sig.parameters[name] + default = ( # type: ignore[assignment] + () if param.default is inspect.Parameter.empty else (param.default,) + ) + return self.create_proxy( + "placeholder", + name, + default, + {}, + type_expr=fn_for_analysis.__annotations__.get(name, None), + ) + + +# Dictionary of (id(globals dict), function name) => globals_dict to patch for +# the purposes of the wrap() API. +# We key by the globals dict id and function name to ensure we're wrapping a given +# function only once. +_wrapped_fns_to_patch: dict[tuple[int, str], dict] = {} + +# List of methods on classes to wrap (class type, function name) +# this currently only works for Tensor.* methods that aren't traced properly +_wrapped_methods_to_patch: list[tuple[type, str]] = [] + +if os.environ.get("FX_PATCH_GETITEM") == "1": + # This change is needed to trace models like PositionalEmbedding from BERT: + # https://github.com/pytorch/benchmark/blob/master/torchbenchmark/models/BERT_pytorch/bert_pytorch/model/embedding/position.py + # but causes issues in quantization documented here: + # https://github.com/pytorch/pytorch/issues/50710 + # once that is fixed we can make this the default behavior. + _wrapped_methods_to_patch.append((torch.Tensor, "__getitem__")) + + +def _find_proxy(*objects_to_search): + """ + Recursively search a data structure for a Proxy() and return it, + return None if not found. + """ + proxy = None + + def find_proxy(x): + nonlocal proxy + if isinstance(x, Proxy): + proxy = x + + map_aggregate(objects_to_search, find_proxy) + return proxy + + +def _create_wrapped_func(orig_fn): + @functools.wraps(orig_fn) + def wrapped(*args, **kwargs): + """ + Given an closed-over ``orig_function`` to invoke, search the args and kwargs for + a Proxy object. If there is one, emit a ``call_function`` node to preserve the + call to this leaf function directly. Otherwise, just return the results of + this function call, as this function is not being traced. 
+ """ + proxy = _find_proxy(args, kwargs) + if proxy is not None: + return_proxy = proxy.tracer.create_proxy( + "call_function", orig_fn, args, kwargs + ) + return_proxy.node.meta["is_wrapped"] = True + return return_proxy + return orig_fn(*args, **kwargs) + + return wrapped + + +def _create_wrapped_method(cls, name): + orig_fn = getattr(cls, name) + + @functools.wraps(orig_fn) + def wrapped(*args, **kwargs): + """ + Search the args and kwargs for a Proxy object. If there is one, + emit a ``call_method`` node to preserve the call to this method + directly. Otherwise, just return the results of this function + call, as this function is not being traced. + """ + proxy = _find_proxy(args, kwargs) + if proxy is not None: + return proxy.tracer.create_proxy("call_method", name, args, kwargs) + return orig_fn(*args, **kwargs) + + return wrapped + + +class _PatchedFn(NamedTuple): + frame_dict: Any + fn_name: str + orig_fn: Any + new_fn: Any + + def revert(self): + raise NotImplementedError + + def patch(self): + raise NotImplementedError + + +class _PatchedFnSetItem(_PatchedFn): + def revert(self): + self.frame_dict[self.fn_name] = self.orig_fn + + def patch(self): + self.frame_dict[self.fn_name] = self.new_fn + + +class _PatchedFnDel(_PatchedFn): + def revert(self): + del self.frame_dict[self.fn_name] + + def patch(self): + self.frame_dict[self.fn_name] = self.new_fn + + +class _PatchedFnSetAttr(_PatchedFn): + def revert(self): + setattr(self.frame_dict, self.fn_name, self.orig_fn) + + def patch(self): + setattr(self.frame_dict, self.fn_name, self.new_fn) + + +class _Patcher: + def __init__(self) -> None: + super().__init__() + self.patches_made: list[_PatchedFn] = [] + self.visited: set[int] = set() + + def patch( + self, + frame_dict: dict[str, Any], + name: str, + new_fn: Callable, + deduplicate: bool = True, + ): + """ + Replace frame_dict[name] with new_fn until we exit the context manager. 
+ """ + new_fn.__fx_already_patched = deduplicate # type: ignore[attr-defined] + if name not in frame_dict and hasattr(builtins, name): + self.patches_made.append(_PatchedFnDel(frame_dict, name, None, new_fn)) + self.patches_made[-1].patch() + elif getattr(frame_dict[name], "__fx_already_patched", False): + return # already patched, no need to do it again + else: + self.patches_made.append( + _PatchedFnSetItem(frame_dict, name, frame_dict[name], new_fn) + ) + self.patches_made[-1].patch() + + def patch_method( + self, cls: type, name: str, new_fn: Callable, deduplicate: bool = True + ): + """ + Replace object_or_dict.name with new_fn until we exit the context manager. + """ + new_fn.__fx_already_patched = deduplicate # type: ignore[attr-defined] + orig_fn = getattr(cls, name) + if getattr(orig_fn, "__fx_already_patched", False): + return # already patched, no need to do it again + self.patches_made.append(_PatchedFnSetAttr(cls, name, orig_fn, new_fn)) + self.patches_made[-1].patch() + + def visit_once(self, thing: Any): + """Return True on the first call to with thing, otherwise false""" + idx = id(thing) + if idx in self.visited: + return False + self.visited.add(idx) + return True + + def revert_all_patches(self): + """ + Remove all the stored patcheds. It doesn't modify patches_made. + """ + for patch in self.patches_made: + patch.revert() + return self.patches_made + + def reapply_all_patches(self): + """ + Patch all the stored patcheds. It doesn't modify patches_made. 
+ """ + for patch in self.patches_made: + patch.patch() + return self.patches_made + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """ + Undo all the changes made via self.patch() and self.patch_method() + """ + while self.patches_made: + # unpatch in reverse order to handle duplicates correctly + self.patches_made.pop().revert() + self.visited.clear() + + +CURRENT_PATCHER: Optional[_Patcher] = None + + +@contextlib.contextmanager +def _new_patcher(): + global CURRENT_PATCHER + prior_patcher = CURRENT_PATCHER + try: + CURRENT_PATCHER = _Patcher() + yield CURRENT_PATCHER + finally: + # Clear all the patches made by when using current patcher. + assert CURRENT_PATCHER is not None + CURRENT_PATCHER.revert_all_patches() + CURRENT_PATCHER = prior_patcher + + +@contextlib.contextmanager +def _maybe_revert_all_patches(): + current_patcher = CURRENT_PATCHER + patches_made = None + patches_removed = None + try: + if current_patcher is not None: + patches_removed = current_patcher.revert_all_patches() + yield + finally: + if current_patcher is not None: + patches_made = current_patcher.reapply_all_patches() + assert patches_made == patches_removed, ( + "CURRENT_PATCHER was changed during a revert_all_patches" + ) + + +def _patch_wrapped_functions(patcher: _Patcher): + """ + Go through ``_wrapped_fn_patch_table`` and, for each frame object, wrap + the listed global functions in the `_create_wrapped_func` wrapper. 
+ """ + for (_, name), frame_dict in _wrapped_fns_to_patch.copy().items(): + if name not in frame_dict and hasattr(builtins, name): + orig_fn = getattr(builtins, name) + else: + orig_fn = frame_dict[name] + patcher.patch(frame_dict, name, _create_wrapped_func(orig_fn)) + + for cls, name in _wrapped_methods_to_patch: + patcher.patch_method(cls, name, _create_wrapped_method(cls, name)) + + +def _autowrap_check( + patcher: _Patcher, frame_dict: dict[str, Any], function_ids: set[int] +): + """ + Some methods, like `math.sqrt` are common enough we want to automatically wrap them as we see them. + This method searches a scope for them and patches them if found. + """ + if patcher.visit_once(frame_dict): + for name, value in frame_dict.items(): + if ( + not name.startswith("_") + and callable(value) + and id(value) in function_ids + ): + patcher.patch(frame_dict, name, _create_wrapped_func(value)) + + +@compatibility(is_backward_compatible=True) +def wrap(fn_or_name: Union[str, Callable]): + """ + This function can be called at module-level scope to register fn_or_name as a "leaf function". + A "leaf function" will be preserved as a CallFunction node in the FX trace instead of being + traced through:: + + # foo/bar/baz.py + def my_custom_function(x, y): + return x * x + y * y + + + torch.fx.wrap("my_custom_function") + + + def fn_to_be_traced(x, y): + # When symbolic tracing, the below call to my_custom_function will be inserted into + # the graph rather than tracing it. + return my_custom_function(x, y) + + This function can also equivalently be used as a decorator:: + + # foo/bar/baz.py + @torch.fx.wrap + def my_custom_function(x, y): + return x * x + y * y + + A wrapped function can be thought of a "leaf function", analogous to the concept of + "leaf modules", that is, they are functions that are left as calls in the FX trace + rather than traced through. 
+ + Args: + + fn_or_name (Union[str, Callable]): The function or name of the global function to insert into the + graph when it's called + """ + if not callable(fn_or_name) and not isinstance(fn_or_name, str): + raise RuntimeError( + "Unsupported type for global function! Must be either a callable or " + "string name" + ) + + if callable(fn_or_name): + assert not isinstance(fn_or_name, str) # to make mypy happy + fn_name = fn_or_name.__name__ + else: + assert isinstance(fn_or_name, str), ( + "fn_or_name must be a global function or string name" + ) + fn_name = fn_or_name + + currentframe = inspect.currentframe() + assert currentframe is not None + f = currentframe.f_back + assert f is not None + if f.f_code.co_name != "": + raise NotImplementedError("wrap must be called at the top level of a module") + + # consider implementing Callable version of this via _autowrap_function_ids / _autowrap_search + # semantics would be slightly different, but would add support `from x import wrapped_function` + _wrapped_fns_to_patch[(id(f.f_globals), fn_name)] = f.f_globals + return fn_or_name + + +@compatibility(is_backward_compatible=True) +def symbolic_trace( + root: Union[torch.nn.Module, Callable[..., Any]], + concrete_args: Optional[dict[str, Any]] = None, +) -> GraphModule: + """ + Symbolic tracing API + + Given an ``nn.Module`` or function instance ``root``, this function will return a ``GraphModule`` + constructed by recording operations seen while tracing through ``root``. + + ``concrete_args`` allows you to partially specialize your function, whether it's to remove control flow or data structures. + + For example:: + + def f(a, b): + if b == True: + return a + else: + return a * 2 + + FX can typically not trace through this due to the presence of control + flow. 
However, we can use `concrete_args` to specialize on the value of + `b` to trace through this:: + + f = fx.symbolic_trace(f, concrete_args={"b": False}) + assert f(3, False) == 6 + + Note that although you can still pass in different values of `b`, they will be ignored. + + We can also use `concrete_args` to eliminate data-structure handling from + our function. This will use pytrees to flatten your input. To avoid + overspecializing, pass in `fx.PH` for values that shouldn't be + specialized. For example:: + + def f(x): + out = 0 + for v in x.values(): + out += v + return out + + + f = fx.symbolic_trace( + f, concrete_args={"x": {"a": fx.PH, "b": fx.PH, "c": fx.PH}} + ) + assert f({"a": 1, "b": 2, "c": 4}) == 7 + + + Args: + root (Union[torch.nn.Module, Callable]): Module or function to be traced and converted + into a Graph representation. + concrete_args (Optional[Dict[str, any]]): Inputs to be partially specialized + + Returns: + GraphModule: a Module created from the recorded operations from ``root``. + """ + tracer = Tracer() + graph = tracer.trace(root, concrete_args) + name = ( + root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__ + ) + return _make_graph_module(tracer.root, graph, name) + + +@wrap +def _assert_is_none(value, msg): + assert value is None, msg diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/_utils.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..25f1c51171734704c6c2ccea3a739de629ff5262 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/_utils.py @@ -0,0 +1,67 @@ +# mypy: allow-untyped-defs +import sys +from typing import Optional + +import torch +from torch._logging import LazyString + + +def lazy_format_graph_code(name, gm, maybe_id=None, **kwargs): + """ + Returns a LazyString that formats the graph code. 
+ """ + + def format_name(): + if maybe_id is not None: + return f"{name} {maybe_id}" + else: + return name + + if "print_output" not in kwargs: + kwargs["print_output"] = False + + if "colored" in kwargs: + try: + if not sys.stdout.isatty(): + kwargs["colored"] = False + except AttributeError: + kwargs["colored"] = False + + return LazyString( + lambda: _format_graph_code( + f"===== {format_name()} =====\n", + gm.forward.__code__.co_filename, + gm.print_readable(**kwargs), + ) + ) + + +def _format_graph_code(name, filename, graph_str): + """ + Returns a string that formats the graph code. + """ + return f"TRACED GRAPH\n {name} {filename} {graph_str}\n" + + +def first_call_function_nn_module_stack(graph: torch.fx.Graph) -> Optional[dict]: + """ + Returns the nn_module_stack of the first call_function node. + """ + for node in graph.nodes: + if node.op == "call_function" and "nn_module_stack" in node.meta: + return node.meta["nn_module_stack"] + return None + + +def get_node_context(node, num_nodes=2) -> str: + """ + Returns a string of the last num_nodes nodes in the graph. + """ + node_contexts = [] + cur = node + for _ in range(num_nodes): + node_contexts.append(cur.format_node()) + if cur.op == "root": + break + cur = cur.prev + return "\n".join(node_contexts[::-1]) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/annotate.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/annotate.py new file mode 100644 index 0000000000000000000000000000000000000000..b3c5056066251df51542cd187652c5111749516f --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/annotate.py @@ -0,0 +1,36 @@ +# mypy: allow-untyped-defs +from torch.fx.proxy import Proxy + +from ._compatibility import compatibility + + +@compatibility(is_backward_compatible=False) +def annotate(val, type): + """ + Annotates a Proxy object with a given type. 
+ + This function annotates a val with a given type if a type of the val is a torch.fx.Proxy object + Args: + val (object): An object to be annotated if its type is torch.fx.Proxy. + type (object): A type to be assigned to a given proxy object as val. + Returns: + The given val. + Raises: + RuntimeError: If a val already has a type in its node. + """ + if isinstance(val, Proxy): + if val.node.type: + raise RuntimeError( + f"Tried to annotate a value that already had a type on it!" + f" Existing type is {val.node.type} " + f"and new type is {type}. " + f"This could happen if you tried to annotate a function parameter " + f"value (in which case you should use the type slot " + f"on the function signature) or you called " + f"annotate on the same value twice" + ) + else: + val.node.type = type + return val + else: + return val diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/config.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/config.py new file mode 100644 index 0000000000000000000000000000000000000000..db06176c43e13c1fecd2e2e89a4b8371ca8d3bc5 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/config.py @@ -0,0 +1,6 @@ +# Whether to disable showing progress on compilation passes +# Need to add a new config otherwise will get a circular import if dynamo config is imported here +disable_progress = True + +# If True this also shows the node names in each pass, for small models this is great but larger models it's quite noisy +verbose_progress = False diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/graph.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/graph.py new file mode 100644 index 0000000000000000000000000000000000000000..03cba403350becac369ca01ce17a43bbd184a455 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/graph.py @@ -0,0 +1,2316 @@ +# mypy: allow-untyped-defs +import 
builtins +import contextlib +import copy +import enum +import functools +import inspect +import keyword +import logging +import math +import os +import pprint +import re +import types +import typing +import warnings +from collections import defaultdict +from collections.abc import Callable, Iterable, Iterator +from contextlib import contextmanager +from dataclasses import dataclass +from typing import Any, Literal, NamedTuple, Optional, TYPE_CHECKING + +import torch +import torch.utils._pytree as pytree +from torch._C import _fx_map_arg as map_arg, _NodeIter +from torch._library.opaque_object import is_opaque_value_type +from torch.utils._dtype_abbrs import dtype_abbrs + +from . import _pytree as fx_pytree +from ._compatibility import compatibility +from .immutable_collections import immutable_dict +from .node import _get_qualified_name, _type_repr, Argument, Node, Target + + +log = logging.getLogger(__name__) + +__all__ = ["PythonCode", "CodeGen", "Graph"] + +if TYPE_CHECKING: + from ._symbolic_trace import Tracer # noqa: F401 + from .graph_module import GraphModule # noqa: F401 + + +# Mapping of builtins to their `typing` equivalent. +# (PEP585: See D68459095 test plan) +_origin_type_map = { + list: typing.List, # noqa: UP006 + dict: typing.Dict, # noqa: UP006 + set: typing.Set, # noqa: UP006 + frozenset: typing.FrozenSet, # noqa: UP006 + tuple: typing.Tuple, # noqa: UP006 +} + +_legal_ops = dict.fromkeys( + ["call_function", "call_method", "get_attr", "call_module", "placeholder", "output"] +) + + +# Signature for functions thattransforms the body (`list[str]`) of the +# generated code +TransformCodeFunc = Callable[[list[str]], list[str]] + + +class _CustomBuiltin(NamedTuple): + """Additional objs that we add to every graph's globals. + + The repr() for some standard library objects is not valid Python code without + an import. For common objects of this sort, we bundle them in the globals of + every FX graph. 
+ """ + + # How to import this object from the standard library. + import_str: str + # The actual object, produced from that import string. + obj: Any + + +# Combined dict of disallowed variable names so we can check with one lookup +_illegal_names = {k: object() for k in keyword.kwlist} +_illegal_names.update(builtins.__dict__) # can't shadow a builtin name + +_custom_builtins: dict[str, _CustomBuiltin] = {} + + +def _register_custom_builtin(name: str, import_str: str, obj: Any): + _custom_builtins[name] = _CustomBuiltin(import_str, obj) + _illegal_names[name] = obj + + +_register_custom_builtin("inf", "from math import inf", math.inf) +_register_custom_builtin("nan", "from math import nan", math.nan) +_register_custom_builtin("NoneType", "NoneType = type(None)", type(None)) +_register_custom_builtin("torch", "import torch", torch) +_register_custom_builtin("device", "from torch import device", torch.device) +_register_custom_builtin("fx_pytree", "import torch.fx._pytree as fx_pytree", fx_pytree) +_register_custom_builtin("pytree", "import torch.utils._pytree as pytree", pytree) + + +def _is_magic(x: str) -> bool: + return x.startswith("__") and x.endswith("__") + + +def _snake_case(s: str) -> str: + """ + Transforms the given string ``s`` to a Python-style variable name + + Examples: + ``mod.snake_case`` -> ``mod.snake_case`` + ``mod.pascalCase``-> ``mod.pascal_case`` + ``mod.ALL_CAPS`` -> ``mod.all_caps`` + """ + return _snake_case_sub(s).lower() + + +# Replace occurrences where a lowercase letter is followed by an uppercase letter +_snake_case_sub = functools.partial(re.compile(r"(?<=[a-z])([A-Z])").sub, r"_\1") + +# Find chars that can't be in a Python identifier +_illegal_char_regex = re.compile("[^0-9a-zA-Z_]+") + +# Combined check for variable names: +# 1) Checks name is not empty +# 2) Checks first character is not a digit +# 3) Checks name has no illegal characters (_illegal_char_regex) +# 3) Splits off the number suffix (if present) +_name_regex = 
re.compile(r"^([a-zA-Z_][0-9a-zA-Z_]*?)(?:_(\d+))?$") + +# starts with torch but does not start with torch._dynamo. or torch._inductor. +_torch_but_not_dynamo = re.compile( + r"^torch(?:\.(?!_dynamo\.|_inductor\.)[^.]+)*$" +).fullmatch + + +def _is_from_torch(obj: Any) -> bool: + module_name = getattr(obj, "__module__", None) + if module_name is not None: + return _torch_but_not_dynamo(module_name) is not None + + name = getattr(obj, "__name__", None) + # exclude torch because torch.torch.torch.torch works. idk mang + if name is not None and name != "torch": + for guess in [torch, torch.nn.functional]: + if getattr(guess, name, None) is obj: + return True + + return False + + +class _Namespace: + """A context for associating names uniquely with objects. + + The following invariants are enforced: + - Each object gets a single name. + - Each name is unique within a given namespace. + - Names generated do not shadow builtins, unless the object is indeed that builtin. + """ + + def __init__(self): + self._obj_to_name: dict[Any, str] = {} + self._used_names: set[str] = set() + self._base_count: dict[str, int] = {} + + def create_name(self, candidate: str, obj: Optional[Any]) -> str: + """Create a unique name. + + Arguments: + candidate: used as the basis for the unique name, relevant to the user. + obj: If not None, an object that will be associated with the unique name. 
+ """ + if obj is not None and obj in self._obj_to_name: + return self._obj_to_name[obj] + + # optimistically check if candidate is already a valid name + match = _name_regex.match(candidate) + if match is None: + # delete all characters that are illegal in a Python identifier + candidate = _illegal_char_regex.sub("_", candidate) + + if not candidate: + candidate = "_unnamed" + + if candidate[0].isdigit(): + candidate = f"_{candidate}" + + match = _name_regex.match(candidate) + assert match is not None + + base, num = match.group(1, 2) + if num is None or candidate in self._used_names: + num = self._base_count.get(candidate, 0) + if _illegal_names.get(candidate, obj) is not obj: + num += 1 + candidate = f"{base}_{num}" + # assume illegal names don't end in _\d so no need to check again + else: + num = int(num) + + while candidate in self._used_names: + num += 1 + candidate = f"{base}_{num}" + + self._used_names.add(candidate) + self._base_count[base] = num + if obj is not None: + self._obj_to_name[obj] = candidate + return candidate + + def associate_name_with_obj(self, name: str, obj: Any): + """Associate a unique name with an object. + + Neither `name` nor `obj` should be associated already. + """ + maybe_existing = self._obj_to_name.setdefault(obj, name) + assert maybe_existing is name, "obj is already associated" + + def _rename_object(self, obj: Any, name: str): + assert obj in self._obj_to_name + self._obj_to_name[obj] = name + self._used_names.add(name) + + +@compatibility(is_backward_compatible=True) +@dataclass +class PythonCode: + """ + Represents all the information necessary to exec or save a graph as Python code. + """ + + # Python source code for the forward function definition. + src: str + # Values in global scope during execution of `src_def`. + globals: dict[str, Any] + # Optional mapping from the forward function's line number to + # node index. Line number starts at the prologue (i.e. forward()). 
+ _lineno_map: Optional[dict[int, Optional[int]]] + # The line number of prologue in fn_code + _prologue_start: int = 0 + + +def _format_target(base: str, target: str) -> str: + elems = target.split(".") + r = base + for e in elems: + if not e.isidentifier(): + r = f'getattr({r}, "{e}")' + else: + r = f"{r}.{e}" + return r + + +class _InsertPoint: + def __init__(self, graph, new_insert): + self.graph = graph + self.orig_insert, graph._insert = graph._insert, new_insert + + def __enter__(self): + pass + + def __exit__(self, type, value, tb): + self.graph._insert = self.orig_insert + + +class _node_list: + def __init__(self, graph: "Graph", direction: Literal["_prev", "_next"] = "_next"): + assert direction in ("_next", "_prev") + self.graph = graph + self.direction = direction + + def __len__(self): + return self.graph._len + + def __iter__(self): + return _NodeIter(self.graph._root, self.direction == "_prev") + + def __reversed__(self): + return _node_list(self.graph, "_next" if self.direction == "_prev" else "_prev") + + +class _PyTreeInfo(NamedTuple): + """ + Contains extra info stored when we're using Pytrees + """ + + orig_args: list[str] + in_spec: pytree.TreeSpec + out_spec: Optional[pytree.TreeSpec] + + +@dataclass(frozen=True) +class _ParsedStackTrace: + """ + Represents the top-most frame of a parsed stack trace + """ + + file: str + lineno: str + name: str + code: str + + def get_summary_str(self): + return f"File: {self.file}:{self.lineno} in {self.name}, code: {self.code}" + + +# get File:lineno code from stack_trace +def _parse_stack_trace( + stack_trace: str, filter_fn: Optional[Callable[[str, str, str], bool]] = None +): + if stack_trace is None: + return None + pattern = re.compile(r"^File \"(.+)\", line (\d+), in (.+)$") + lines = stack_trace.strip().split("\n") + # stacktrace should have innermost frame last, so we + # iterate backwards to find the first line that starts + # with 'File ' + for idx in range(len(lines) - 2, -1, -1): + line = 
lines[idx].strip() + matches = pattern.match(line) + if matches: + file = matches.group(1) + lineno = matches.group(2) + name = matches.group(3) + # next line should be the code + code = lines[idx + 1].strip() + if filter_fn and not filter_fn(file, name, code): + continue + return _ParsedStackTrace(file, lineno, name, code) + return None + + +@compatibility(is_backward_compatible=False) +class CodeGen: + # This is an override hook so we can customize the SymNode printer. + _sym_repr: Callable[["torch.types.PySymType"], str] = lambda x: repr(x) + + def __init__(self): + self._body_transformer: Optional[TransformCodeFunc] = None + self._func_name: str = "forward" + + def _format_multiline_args(self, args: list[str]) -> str: + """Helper to format function arguments in expanded multiline format.""" + return "".join(self._format_single_arg(arg) for arg in args) + + def _format_single_arg(self, arg: str) -> str: + """Helper to format a single argument with optional comment.""" + if "#" in arg: + arg_part, comment_part = arg.split("#", 1) + return f" {arg_part.rstrip()}, # {comment_part.lstrip()}\n" + else: + return f" {arg},\n" + + def _get_delimiters(self, container) -> tuple[str, str]: + """Helper to get opening and closing delimiters for containers.""" + return ("(", ")") if isinstance(container, tuple) else ("[", "]") + + def _format_multiline_container(self, items, descs=None, prefix="") -> str: + """Helper to format containers (lists/tuples) in multiline format.""" + ldelim, rdelim = self._get_delimiters(items) + desc_trailers = self._get_desc_trailers(items, descs) + + return ( + f"{prefix}{ldelim}\n" + + "".join( + f" {item},{trailer}\n" for item, trailer in zip(items, desc_trailers) + ) + + f"{rdelim}" + ) + + def _get_desc_trailers(self, items, descs): + """Helper to generate description trailers for items.""" + if descs is None: + return [""] * len(items) + return [f" # {desc}" for desc in descs] + + def _call_method_with_signature_check(self, method, *args, 
**kwargs): + """Helper to call a method with optional parameters based on signature.""" + sig = inspect.signature(method) + # Filter kwargs to only include parameters that exist in the method signature + filtered_kwargs = {k: v for k, v in kwargs.items() if k in sig.parameters} + return method(*args, **filtered_kwargs) + + def gen_fn_def( + self, + free_vars: list[str], + maybe_return_annotation: str, + *, + expanded_def: bool = False, + ) -> str: + """ + Given the free variables and a return annotation, generates the beginning of the FX function. + By default, `gen_fn_def(['a', 'b'], '') == 'def {self._func_name}(a, b):'` + """ + # If the original function didn't have self as its first argument, we + # would have added it. + if len(free_vars) == 0 or free_vars[0] != "self": + free_vars.insert(0, "self") + + if expanded_def: + args_formatted = self._format_multiline_args(free_vars) + return ( + f"def {self._func_name}(\n{args_formatted}){maybe_return_annotation}:" + ) + else: + return f"def {self._func_name}({', '.join(free_vars)}){maybe_return_annotation}:" + + def generate_output( + self, output_args: Argument, *, descs: Optional[Any] = None + ) -> str: + """ + Given the output arguments, generates the return statement of the FX function. + Note: The returned statement should not be indented. + """ + if descs is not None and isinstance(output_args, (list, tuple)): + return self._format_multiline_container(output_args, descs, "return ") + else: + return f"return {repr(output_args)}" + + def process_inputs(self, *args: Any) -> Any: + """ + Transforms the inputs so that the graph can take them as arguments, as + non-default codegen may result in the inputs to the function being + different from the inputs to the graph. 
+ + If the graph was directly runnable, this invariant should hold true + `f.graph.process_outputs(f.graph(*f.graph.process_inputs(*inputs))) == f(*inputs)` + """ + return args + + def process_outputs(self, outputs: Any) -> Any: + """ + Transforms the outputs of the graph to be identical to the codegen. + + See ``process_inputs`` for more details. + """ + return outputs + + def additional_globals(self) -> list[tuple[str, Any]]: + """ + If your codegen uses extra global values, add tuples of (identifier,reference to the value) here. + For example, return ['List', typing.List] if you need ``List`` in the global context. + """ + return [] + + def _gen_python_code( + self, + nodes, + root_module: str, + namespace: _Namespace, + *, + verbose: bool = False, + include_stride: bool = False, + include_device: bool = False, + colored: bool = False, + # Render each argument on its own line + expanded_def: bool = False, + record_func: bool = False, + ) -> PythonCode: + free_vars: list[str] = [] + body: list[str] = [] + globals_: dict[str, Any] = {} + wrapped_fns: dict[str, None] = {} + + # Wrap string in list to pass by reference + maybe_return_annotation: list[str] = [""] + include_stride = include_stride or ( + os.environ.get("FX_GRAPH_SHOW_STRIDE", "0") == "1" + ) + include_device = include_device or ( + os.environ.get("FX_GRAPH_SHOW_DEVICE", "0") == "1" + ) + include_meta = os.environ.get("FX_GRAPH_SHOW_META", "0") == "1" + + def add_global(name_hint: str, obj: Any): + """Add an obj to be tracked as a global. + + We call this for names that reference objects external to the + Graph, like functions or types. + + Returns: the global name that should be used to reference 'obj' in generated source. + """ + if ( + _is_from_torch(obj) and obj != torch.device + ): # to support registering torch.device + # HACK: workaround for how torch custom ops are registered. We + # can't import them like normal modules so they must retain their + # fully qualified name. 
+ return _get_qualified_name(obj) + + # normalize the name hint to get a proper identifier + global_name = namespace.create_name(name_hint, obj) + + if global_name in globals_: + assert globals_[global_name] == obj + return global_name + globals_[global_name] = obj + return global_name + + # Pre-fill the globals table with registered builtins. + for name, (_, obj) in _custom_builtins.items(): + add_global(name, obj) + + def type_repr(o: Any): + if o == (): + # Empty tuple is used for empty tuple type annotation Tuple[()] + return "()" + + typename = _type_repr(o) + if isinstance(o, types.UnionType) and "|" in typename: + # str | int + args = [type_repr(arg) for arg in o.__args__] + return "|".join(args) + + if origin_type := getattr(o, "__origin__", None): + # list[...], typing.List[...], TensorType[...] + + if isinstance(o, typing._GenericAlias): # type: ignore[attr-defined] + # This is a generic pre-PEP585 type, e.g. typing.List[torch.Tensor] + origin_type = _origin_type_map.get(origin_type, origin_type) + + origin_typename = add_global(_type_repr(origin_type), origin_type) + + if hasattr(o, "__args__") and o.__args__: + args = [type_repr(arg) for arg in o.__args__] + return f"{origin_typename}[{','.join(args)}]" + else: + return origin_typename + + # Common case: this is a regular module name like 'foo.bar.baz' + return add_global(typename, o) + + if colored: + red = _color_fns["red"] + dim_green = _color_fns["dim_green"] + dim = _color_fns["dim"] + dim_blue = _color_fns["dim_blue"] + blue = _color_fns["blue"] + else: + red = _identity + dim_green = _identity + dim = _identity + dim_blue = _identity + blue = _identity + + def _get_repr(arg: Any) -> str: + if isinstance(arg, Node): # first because common + return repr(arg) + elif isinstance(arg, tuple) and hasattr(arg, "_fields"): + # Handle NamedTuples (if it has `_fields`) via add_global. 
+ qualified_name = _get_qualified_name(type(arg)) + global_name = add_global(qualified_name, type(arg)) + return f"{global_name}{repr(tuple(arg))}" + elif isinstance( + arg, (torch._ops.OpOverload, torch._ops.HigherOrderOperator) + ): + qualified_name = _get_qualified_name(arg) + global_name = add_global(qualified_name, arg) + return f"{global_name}" + elif isinstance(arg, enum.Enum): + cls = arg.__class__ + clsname = add_global(cls.__name__, cls) + return f"{clsname}.{arg.name}" + elif isinstance(arg, torch.Tensor): + size = list(arg.size()) + dtype = str(arg.dtype).split(".")[-1] + return f"torch.Tensor(size={size}, dtype={dtype})" + elif isinstance(arg, tuple): + if len(arg) == 1: + return f"({_get_repr(arg[0])},)" + else: + return "(" + ", ".join(_get_repr(a) for a in arg) + ")" + elif isinstance(arg, list): + return "[" + ", ".join(_get_repr(a) for a in arg) + "]" + elif isinstance(arg, slice): + return f"slice({_get_repr(arg.start)}, {_get_repr(arg.stop)}, {_get_repr(arg.step)})" + elif is_opaque_value_type(type(arg)): + arg_type = type(arg) + add_global(arg_type.__name__, arg_type) + return repr(arg) + else: + return blue(repr(arg)) + + def _format_args( + args: tuple[Argument, ...], kwargs: dict[str, Argument] + ) -> str: + res = [_get_repr(a) for a in args] + res.extend([f"{k} = {_get_repr(v)}" for k, v in kwargs.items()]) + return ", ".join(res) + + # Run through reverse nodes and record the first instance of a use + # of a given node. 
This represents the *last* use of the node in the + # execution order of the program, which we will use to free unused + # values + node_to_last_use: dict[Node, Node] = {} + user_to_last_uses: dict[Node, list[Node]] = {} + + def register_last_uses(n: Node, user: Node): + if n not in node_to_last_use: + node_to_last_use[n] = user + user_to_last_uses.setdefault(user, []).append(n) + + for node in reversed(nodes): + for input_node in node._input_nodes: + register_last_uses(input_node, node) + + def delete_unused_values(user: Node): + """ + Delete values after their last use. This ensures that values that are + not used in the remainder of the code are freed and the memory usage + of the code is optimal. + """ + if user.op == "placeholder": + return + if user.op == "output": + body.append("\n") + return + nodes_to_delete = user_to_last_uses.get(user, []) + + if len(user.users.keys()) == 0: + # This node is not used by any others. however it's also not + # removed by DCE since side-effect. We want to free it's outputs + # right after its execution done to save memory. + nodes_to_delete.append(user) + + if len(nodes_to_delete): + to_delete_str = " = ".join( + [repr(n) for n in nodes_to_delete] + ["None"] + ) + body.append(f"; {dim(to_delete_str)}\n") + else: + body.append("\n") + + prev_summary_str = None + + def append_stacktrace_summary(node: Node): + """ + Append a summary of the stacktrace to the generated code. This is + useful for debugging. 
+ """ + nonlocal prev_summary_str + + if node.op not in {"placeholder", "output"}: + annotation_str = "" + annotation = node.meta.get("custom", {}) + if annotation: + annotation_str = f" Annotation: {annotation}" + + stack_trace_str = "No stacktrace found for following nodes" + if stack_trace := node.stack_trace: + if parsed_stack_trace := _parse_stack_trace(stack_trace): + stack_trace_str = parsed_stack_trace.get_summary_str() + + maybe_recompute_info = "" + if hasattr(node, "meta") and node.meta: + # recompute tags are generated by torch.compile and put in the joint graph. + # These tags are load bearing enough that we want them to show up by default + # in tlparse, when you run torch.compile. + recompute = node.meta.get("recompute", None) + ac_graph_id = node.meta.get("ac_graph_id", None) + + if recompute is not None and ac_graph_id is not None: + maybe_recompute_info = f" # ac_graph_id: {str(ac_graph_id)} - {str(recompute.name)}" + elif recompute is not None: + maybe_recompute_info = f" # recompute: {str(recompute.name)}" + elif ac_graph_id is not None: + maybe_recompute_info = f" # ac_graph_id: {str(ac_graph_id)}" + + summary_str = f"\n{dim(f'#{annotation_str}{maybe_recompute_info} {stack_trace_str}')}\n" + + if summary_str != prev_summary_str: + prev_summary_str = summary_str + body.append(summary_str) + + def stringify_shape(shape: Iterable) -> str: + return f"[{', '.join([str(x) for x in shape])}]" + + def emit_node(node: Node): + maybe_type_annotation = ( + "" if node.type is None else f" : {type_repr(node.type)}" + ) + maybe_comment = "" + + if verbose: + # override annotation with more detailed information + try: + from torch.distributed.tensor._api import DTensor, DTensorSpec + + dtensorspec_format_shard_order_str = ( + DTensorSpec.format_shard_order_str + ) + except ModuleNotFoundError: + DTensor = None # type: ignore[assignment,misc] + dtensorspec_format_shard_order_str = None + from torch.fx.experimental.proxy_tensor import py_sym_types + from 
torch.fx.passes.shape_prop import TensorMetadata + + meta_val = node.meta.get( + "val", + node.meta.get("tensor_meta", node.meta.get("example_value", None)), + ) + + def _tensor_annotation(t: torch.Tensor) -> str: + stride = stringify_shape(t.stride()) if include_stride else "" + device = f"{t.device}" if include_device else "" + return ( + f"{red(dtype_abbrs[t.dtype])}" + f"{blue(stringify_shape(t.shape))}" + f"{dim_blue(stride)}" + f"{dim_green(device)}" + ) + + # use string as annotation, to make it valid python code + if isinstance(meta_val, torch.Tensor) and meta_val.layout not in ( + torch.sparse_csc, + torch.sparse_csr, + ): + # Fake tensors cause tests to wobble, so do not custom print them. + is_plain = type(meta_val) is torch.Tensor or isinstance( + meta_val, torch._subclasses.FakeTensor + ) + core = _tensor_annotation(meta_val) + if is_plain: + maybe_type_annotation = f': "{core}"' + elif type(meta_val) is DTensor: + assert dtensorspec_format_shard_order_str is not None + dtensor_meta = dtensorspec_format_shard_order_str( + meta_val._spec.placements, # type: ignore[attr-defined] + meta_val._spec.shard_order, # type: ignore[attr-defined] + ) + cls = meta_val.__class__.__name__ + maybe_type_annotation = ( + f': "{cls}({core}, {dim_green(dtensor_meta)})"' + ) + else: + cls = meta_val.__class__.__name__ + maybe_type_annotation = f': "{cls}({core})"' + + elif isinstance(meta_val, py_sym_types): + val_str = CodeGen._sym_repr(meta_val) + maybe_type_annotation = f': "Sym({val_str})"' + + elif isinstance(meta_val, TensorMetadata): + maybe_type_annotation = f': "{dtype_abbrs[meta_val.dtype]}{stringify_shape(meta_val.shape)}"' + + desc = None + if expanded_def: + desc = node.meta.get("desc", None) + if desc is not None and node.op == "placeholder": + maybe_comment += f" # {desc}" + # output is handled specially + + if include_meta and hasattr(node, "meta") and node.meta: + body.append('"""\n') + for k, v in node.meta.items(): + # use str over repr since repr is 
susceptible to sympy + # errors such as "cannot determine truth value of Relational" + # Pretty print the high-level dict with str() for values + body.append( + f"{k}: {pprint.pformat(str(v), width=80, compact=True)}\n" + ) + body.append('"""\n') + + if node.op == "placeholder": + assert isinstance(node.target, str) + maybe_default_arg = ( + "" if not node.args else f" = {_get_repr(node.args[0])}" + ) + free_vars.append( + f"{node.target}{maybe_type_annotation}{maybe_default_arg}{maybe_comment}" + ) + raw_name = node.target.replace("*", "") + if raw_name != repr(node): + body.append(f"{repr(node)} = {raw_name}\n") + return + elif node.op == "call_method": + assert isinstance(node.target, str) + body.append( + f"{repr(node)}{maybe_type_annotation} = {_format_target(_get_repr(node.args[0]), node.target)}" + f"({_format_args(node.args[1:], node.kwargs)})" + ) + return + elif node.op == "call_function": + assert callable(node.target) + # pretty print operators + if ( + getattr(node.target, "__module__", "") == "_operator" + and node.target.__name__ in magic_methods + ): + assert isinstance(node.args, tuple) + body.append( + f"{repr(node)}{maybe_type_annotation} = " + f"{magic_methods[node.target.__name__].format(*(_get_repr(a) for a in node.args))}" + ) + return + + # pretty print inplace operators; required for jit.script to work properly + # not currently supported in normal FX graphs, but generated by torchdynamo + if ( + getattr(node.target, "__module__", "") == "_operator" + and node.target.__name__ in inplace_methods + ): + body.append( + f"{inplace_methods[node.target.__name__].format(*(_get_repr(a) for a in node.args))}; " + f"{repr(node)}{maybe_type_annotation} = {_get_repr(node.args[0])}" + ) + return + + qualified_name = _get_qualified_name(node.target) + global_name = add_global(qualified_name, node.target) + # special case for getattr: node.args could be 2-argument or 3-argument + # 2-argument: attribute access; 3-argument: fall through to attrib function 
call with default value + if ( + global_name == "getattr" + and isinstance(node.args, tuple) + and isinstance(node.args[1], str) + and node.args[1].isidentifier() + and len(node.args) == 2 + ): + body.append( + f"{repr(node)}{maybe_type_annotation} = {_format_target(_get_repr(node.args[0]), node.args[1])}" + ) + return + body.append( + f"{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})" + ) + if node.meta.get("is_wrapped", False): + wrapped_fns.setdefault(global_name) + return + elif node.op == "call_module": + assert isinstance(node.target, str) + body.append( + f"{repr(node)}{maybe_type_annotation} = " + f"{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})" + ) + return + elif node.op == "get_attr": + assert isinstance(node.target, str) + body.append( + f"{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}" + ) + return + elif node.op == "output": + if node.type is not None: + maybe_return_annotation[0] = f" -> {type_repr(node.type)}" + body.append( + self._call_method_with_signature_check( + self.generate_output, + node.args[0], + descs=desc if expanded_def else None, + ) + ) + return + raise NotImplementedError(f"node: {node.op} {node.target}") + + if record_func: + body.append( + "_rf = torch._C._profiler._RecordFunctionFast('## ENTER_GRAPH_PLACEHOLDER_KEY ##'); _rf.__enter__()\n" + ) + for i, node in enumerate(nodes): + # NOTE: emit_node does not emit a string with newline. 
It depends + # on delete_unused_values to append one + if verbose: + append_stacktrace_summary(node) + # emit a counter comment to keep track of + # node index, which will be deleted later + # after going through _body_transformer + body.append(f"# COUNTER: {i}\n") + do_record = record_func and node.op in ( + "call_function", + "call_method", + "call_module", + ) + if do_record: + # The double hash ## convention is used by post-processing to find the fx markers + body.append( + f"_rf_{node.name} = torch._C._profiler._RecordFunctionFast('## {i} ##'); _rf_{node.name}.__enter__()\n" + ) + emit_node(node) + delete_unused_values(node) + if do_record: + body.append(f"_rf_{node.name}.__exit__(None, None, None)\n") + if record_func: + body.append("_rf.__exit__(None, None, None)\n") + + if len(body) == 0: + # If the Graph has no non-placeholder nodes, no lines for the body + # have been emitted. To continue to have valid Python code, emit a + # single pass statement + body.append("pass\n") + + if len(wrapped_fns) > 0: + wrap_name = add_global("wrap", torch.fx.wrap) + wrap_stmts = "\n".join([f'{wrap_name}("{name}")' for name in wrapped_fns]) + else: + wrap_stmts = "" + + if self._body_transformer: + body = self._body_transformer(body) + + for name, value in self.additional_globals(): + add_global(name, value) + + prologue = self._call_method_with_signature_check( + self.gen_fn_def, + free_vars, + maybe_return_annotation[0], + expanded_def=expanded_def, + ) + + # remove counter and generate lineno to node index mapping + lineno_map: dict[int, Optional[int]] = {} + prologue_len = prologue.count("\n") + 1 + new_lines: list[str] = [] + cur_idx = None + for line in "".join(body).split("\n"): + counter = _counter_regexp.search(line) + if counter is not None: + cur_idx = int(counter.group(1)) + else: + lineno_map[len(new_lines) + prologue_len] = cur_idx + new_lines.append(line) + + code = "\n".join(new_lines).lstrip("\n") + code = "\n".join(" " + line for line in code.split("\n")) 
+ + fn_code = f""" +{wrap_stmts} + +{prologue} +{code}""" + # The +4 accounts for the empty lines before prologue in fn_code + prologue_start = wrap_stmts.count("\n") + 4 + return PythonCode( + fn_code, + globals_, + _lineno_map=lineno_map, + _prologue_start=prologue_start, + ) + + +# Ideally, we'd like to refactor all of the pytree logic into this codegen +# class. Unfortunately, there are 3 areas we currently need extra logic in FX. +# 1. In the initial symbolic trace, the pytree logic is tied up with `concrete_args`. +# 2. In the FX graph, we need to access 2 attributes - in_spec and out_spec. +# Since we can't access .graph within the FX forward, we need to copy the attribute to the module. +# 3. We currently can't register the pytree imports with `add_global` - not sure why. +class _BoxedCodeGen(CodeGen): + """ + CodeGen subclass that generates code using the "boxed" calling convention. + + The boxed calling convention takes a single list argument and clears it + after extracting the arguments, which allows for early deallocation of + input tensors. + """ + + def gen_fn_def( + self, free_vars, maybe_return_annotation, *, expanded_def: bool = False + ): + """ + Generate function definition for boxed calling convention. + + Instead of taking individual arguments, the generated function takes + a single 'args_list' parameter, extracts placeholder values from it, + and clears the list. + """ + # Generate the function signature with args_list parameter + fn_def = f"def {self._func_name}(self, args_list){maybe_return_annotation}:" + + if free_vars: + # This is horribly manual but we don't get the "raw" free vars + # without a bigger refactor. 
+ placeholder_vars = [ + v.split(":")[0].split("=")[0].strip() for v in free_vars if v != "self" + ] + + if placeholder_vars: + fn_def += "\n args_iter = iter(args_list)" + for var in placeholder_vars: + fn_def += f"\n {var} = next(args_iter)" + fn_def += "\n args_list.clear()" + + return fn_def + + +class _PyTreeCodeGen(CodeGen): + def __init__(self, pytree_info: _PyTreeInfo): + super().__init__() + self.pytree_info: _PyTreeInfo = pytree_info + + def process_inputs(self, *inputs: Any) -> Any: + flat_args = pytree.arg_tree_leaves(*inputs) + return flat_args + + def process_outputs(self, out: Any) -> Any: + if self.pytree_info is None or self.pytree_info.out_spec is None: + return out + if not isinstance(out, (list, tuple)): + out = [out] + assert self.pytree_info.out_spec is not None + return pytree.tree_unflatten(out, self.pytree_info.out_spec) + + def _format_annotations(self, free_vars: list[str], expanded_def: bool) -> str: + """Helper to format annotations for variables in pytree codegen.""" + if not free_vars: + return "" + + has_annotation = [x for x in free_vars if ":" in x] + if not has_annotation: + return "" + + if expanded_def: + return "\n " + "\n ".join(has_annotation) + else: + return "\n " + "".join(x + "; " for x in has_annotation) + "\n" + + def gen_var_bindings(self, fn_args, free_vars, expanded_def) -> str: + in_spec = self.pytree_info.in_spec + # when kwargs is present, in_spec is tuple(args, kwargs) + has_args_kwargs_tuple = ( + in_spec.type is tuple + and in_spec.num_children == 2 + and in_spec.child(0).type is tuple + and in_spec.child(1).type is dict + ) + fn_kwargs = "{}" + fn_signature = f"[{', '.join(fn_args)}], self._in_spec" + if has_args_kwargs_tuple: + count_args = in_spec.child(0).num_children + fn_args = self.pytree_info.orig_args[:count_args] + fn_kwargs = ( + "{" + + ", ".join( + f"'{k}':{v}" + for k, v in zip( + in_spec.child(1).context, + self.pytree_info.orig_args[count_args:], + ) + ) + + "}" + ) + fn_signature = f"([{', 
'.join(fn_args)}], {fn_kwargs}), self._in_spec" + + # in Python, `var1: annotation1, var2: annotation2 = function_call()` is invalid. + # we need to split it to two lines: + # one for annotation: `var1: annotation1; var2: annotation2;` (note the semicolon) + # one for code: `var1, var2, = function_call()` + without_annotation = [x.split(":")[0].split("#")[0] for x in free_vars] + bindings = self._format_annotations(free_vars, expanded_def) + bindings += f""" + {", ".join(without_annotation)}, = fx_pytree.tree_flatten_spec({fn_signature})""" + return bindings + + def gen_fn_def( + self, free_vars, maybe_return_annotation, *, expanded_def: bool = False + ): + # Given a user function/model: + # myargs = [myargs0, myargs1] + # mykwargs = {'mykwargs0': ..., 'mykwargs1': ...} + # def forward(self, mypos, *myargs, mykey=None, **mykwargs): + # + # The generated code flattens all keywords into positional arguments for `forward()` + # e.g forward(self, mypos, myargs0, myargs1, mykey, mykwargs0, mykwargs1): + # + # Within `forward`, `tree_flatten_spec``still parses args and kwargs separately + # e.g. tree_flatten_spec(([mypos, myargs0, myargs1], + # {'mykey':mykey, 'mykwargs0':mykwargs0, 'mykwargs1':mykwargs1}), + # self._in_spec) + # + # If the user function/model does not have keywords, the dict is suppressed from tree_flatten_spec + # e.g. 
tree_flatten_spec([mypos, myargs0, myargs1]), self._in_spec) + if self.pytree_info is None: + return super().gen_fn_def( + free_vars, maybe_return_annotation, expanded_def=expanded_def + ) + + fn_args = self.pytree_info.orig_args + has_orig_self = (fn_args[0] == "self") if len(fn_args) > 0 else False + if has_orig_self: + free_vars.insert(0, "self") + fn_definition = super().gen_fn_def( + fn_args[:], maybe_return_annotation, expanded_def=expanded_def + ) + + if len(free_vars) > 0: # pytree has placeholders in it + fn_definition += self.gen_var_bindings(fn_args, free_vars, expanded_def) + return fn_definition + + def generate_output(self, output_args, *, descs: Optional[Any] = None): + if self.pytree_info and self.pytree_info.out_spec: + if descs is not None and isinstance(output_args, (list, tuple)): + return ( + self._format_multiline_container( + output_args, descs, "return pytree.tree_unflatten(" + ) + + ", self._out_spec)" + ) + else: + return ( + f"return pytree.tree_unflatten({repr(output_args)}, self._out_spec)" + ) + else: + return super().generate_output(output_args, descs=descs) + + +class _ExportCodeGen(_PyTreeCodeGen): + def __init__( + self, + pytree_info: _PyTreeInfo, + in_shuffle_graph: "GraphModule", + out_shuffle_graph: "GraphModule", + tree_leaf_names: list[str], + root: Optional[torch.nn.Module], + ): + super().__init__(pytree_info) + self.in_shuffle_graph = in_shuffle_graph + self.out_shuffle_graph = out_shuffle_graph + self.tree_leaf_names = tree_leaf_names + self.root = root + + def process_inputs(self, *inputs: Any) -> Any: + flat_args = super().process_inputs(*inputs) + if self.root is not None: + flat_args = (self.root, *flat_args) + self.flat_args = flat_args + return self.in_shuffle_graph(*flat_args) + + def process_outputs(self, out: Any) -> Any: + flat_outs = self.out_shuffle_graph(*self.flat_args, *out) + del self.flat_args + ret = super().process_outputs(flat_outs) + return ret + + def gen_fn_def(self, *args, **kwargs) -> str: + 
fn_def = super().gen_fn_def(*args, **kwargs) + return fn_def + + def gen_var_bindings(self, fn_args, free_vars, expanded_def) -> str: + without_annotation = [x.split(":")[0].split("#")[0] for x in free_vars] + fn_signature: str = f"{', '.join(fn_args)}" + if self.root is not None: + fn_signature = f"self, {fn_signature}" + return f""" + {", ".join(self.tree_leaf_names)}, = pytree.tree_leaves(({fn_signature},)) + {", ".join(without_annotation)}, = self._in_shuffle_graph({", ".join(self.tree_leaf_names)})""" + + def generate_output(self, output_args, *args, **kwargs) -> str: + output = f"self._out_shuffle_graph({', '.join(self.tree_leaf_names)}, {', '.join([str(a) for a in output_args])})" + return f"return pytree.tree_unflatten({output}, self._out_spec)" + + +class _FindNodesLookupTable: + """ + Side table for the graph for the purpose of doing fast queries + """ + + def __init__(self): + self.table: dict[tuple[str, Optional[Target]], dict[Node, None]] = defaultdict( + dict + ) + + def _key(self, node) -> tuple[str, Optional[Target]]: + return (node.op, node.target if node.op == "call_function" else None) + + def __contains__(self, node) -> bool: + return node in self.table[self._key(node)] + + def insert(self, node: Node) -> None: + self.table[self._key(node)][node] = None + + def remove(self, node: Node) -> None: + self.table[self._key(node)].pop(node) + + def find_nodes(self, *, op: str, target: Optional["Target"] = None): + if op == "call_function": + assert target is not None + return [*self.table[(op, target)].keys()] + + if target is None: + return [*self.table[(op, None)].keys()] + + # op is call_method, get_attr, call_module + return [node for node in self.table[(op, None)] if node.target == target] + + +@compatibility(is_backward_compatible=True) +class Graph: + """ + ``Graph`` is the main data structure used in the FX Intermediate Representation. + It consists of a series of ``Node`` s, each representing callsites (or other + syntactic constructs). 
The list of ``Node`` s, taken together, constitute a + valid Python function. + + For example, the following code + + .. code-block:: python + + import torch + import torch.fx + + + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.param = torch.nn.Parameter(torch.rand(3, 4)) + self.linear = torch.nn.Linear(4, 5) + + def forward(self, x): + return torch.topk( + torch.sum(self.linear(x + self.linear.weight).relu(), dim=-1), 3 + ) + + + m = MyModule() + gm = torch.fx.symbolic_trace(m) + + Will produce the following Graph:: + + print(gm.graph) + + .. code-block:: text + + graph(x): + %linear_weight : [num_users=1] = self.linear.weight + %add_1 : [num_users=1] = call_function[target=operator.add](args = (%x, %linear_weight), kwargs = {}) + %linear_1 : [num_users=1] = call_module[target=linear](args = (%add_1,), kwargs = {}) + %relu_1 : [num_users=1] = call_method[target=relu](args = (%linear_1,), kwargs = {}) + %sum_1 : [num_users=1] = call_function[target=torch.sum](args = (%relu_1,), kwargs = {dim: -1}) + %topk_1 : [num_users=1] = call_function[target=torch.topk](args = (%sum_1, 3), kwargs = {}) + return topk_1 + + For the semantics of operations represented in the ``Graph``, please see :class:`Node`. + """ + + @compatibility(is_backward_compatible=True) + def __init__( + self, + owning_module: Optional["GraphModule"] = None, + tracer_cls: Optional[type["Tracer"]] = None, + tracer_extras: Optional[dict[str, Any]] = None, + ): + """ + Construct an empty Graph. 
+ """ + self._root: Node = Node(self, "", "root", "", (), {}) + self._used_names: dict[str, int] = {} # base name -> number + self._insert = self._root.prepend + self._len = 0 + self._graph_namespace = _Namespace() + self._owning_module = owning_module + self._tracer_cls = tracer_cls + self._tracer_extras = tracer_extras + self._codegen = CodeGen() + self._co_fields: dict[str, Any] = {} + self._find_nodes_lookup_table = _FindNodesLookupTable() + + @property + def owning_module(self): + return self._owning_module + + @owning_module.setter + def owning_module(self, mod: Optional["GraphModule"]): + self._owning_module = mod + + @property + def nodes(self) -> _node_list: + """ + Get the list of Nodes that constitute this Graph. + + Note that this ``Node`` list representation is a doubly-linked list. Mutations + during iteration (e.g. delete a Node, add a Node) are safe. + + Returns: + + A doubly-linked list of Nodes. Note that ``reversed`` can be called on + this list to switch iteration order. + """ + return _node_list(self) + + @compatibility(is_backward_compatible=False) + def output_node(self) -> Node: + output_node = next(iter(reversed(self.nodes))) + assert output_node.op == "output" + return output_node + + @compatibility(is_backward_compatible=False) + def find_nodes( + self, *, op: str, target: Optional["Target"] = None, sort: bool = True + ): + """ + Allows for fast query of nodes + + Args: + + op (str): the name of the operation + + target (Optional[Target]): the target of the node. For call_function, + the target is required. For other ops, the target is optional. + + sort (bool): whether to return nodes in the order they appear on + on the graph. + + Returns: + + Iterable of nodes with the requested op and target. 
+ """ + node_list = self._find_nodes_lookup_table.find_nodes(op=op, target=target) + if sort: + return sorted(node_list) + return node_list + + @compatibility(is_backward_compatible=True) + def graph_copy( + self, g: "Graph", val_map: dict[Node, Node], return_output_node=False + ) -> "Optional[Argument]": + """ + Copy all nodes from a given graph into ``self``. + + Args: + + g (Graph): The source graph from which to copy Nodes. + + val_map (Dict[Node, Node]): a dictionary that will be populated with a mapping + from nodes in ``g`` to nodes in ``self``. Note that ``val_map`` can be passed + in with values in it already to override copying of certain values. + + Returns: + + The value in ``self`` that is now equivalent to the output value in ``g``, + if ``g`` had an ``output`` node. ``None`` otherwise. + """ + for node in g.nodes: + if node in val_map: + continue + if node.op == "output": + rv = map_arg(node.args[0], lambda n: val_map[n]) + return rv if not return_output_node else (rv, node) + val_map[node] = self.node_copy(node, lambda n: val_map[n]) + return None + + def __deepcopy__(self, memo=None) -> "Graph": + """ + Explicitly implement __deepcopy__ to prevent excessive recursion depth + from the default implementation. This uses graph_copy to copy the nodes + in an iterative way, rather than recursive. It also populates the + memoization table to prevent unnecessary copies (e.g. references to + nodes or other parts of the Graph from a custom GraphModule implementation. 
+ """ + memo = memo if memo else {} + g = Graph(tracer_cls=self._tracer_cls) + output_vals = g.graph_copy(self, val_map=memo, return_output_node=True) + g._codegen = copy.deepcopy(self._codegen) + if output_vals is not None: + assert isinstance(output_vals, tuple) + output_val, old_output_node = output_vals + new_output_node = g.output( + output_val, type_expr=getattr(old_output_node, "type", None) + ) + new_output_node.meta = copy.copy(old_output_node.meta) + return g + + @compatibility(is_backward_compatible=True) + def create_node( + self, + op: str, + target: "Target", + args: Optional[tuple["Argument", ...]] = None, + kwargs: Optional[dict[str, "Argument"]] = None, + name: Optional[str] = None, + type_expr: Optional[Any] = None, + ) -> Node: + """ + Create a ``Node`` and add it to the ``Graph`` at the current insert-point. + Note that the current insert-point can be set via :meth:`Graph.inserting_before` + and :meth:`Graph.inserting_after`. + + Args: + op (str): the opcode for this Node. One of 'call_function', 'call_method', 'get_attr', + 'call_module', 'placeholder', or 'output'. The semantics of these opcodes are + described in the ``Graph`` docstring. + + args (Optional[Tuple[Argument, ...]]): is a tuple of arguments to this node. + + kwargs (Optional[Dict[str, Argument]]): the kwargs of this Node + + name (Optional[str]): an optional string name for the ``Node``. + This will influence the name of the value assigned to in the + Python generated code. + + type_expr (Optional[Any]): an optional type annotation representing the + Python type the output of this node will have. + + Returns: + + The newly-created and inserted node. 
+ """ + # `target in _legal_ops` is checked in Node.__init__ + if not args: + args = () + else: + assert isinstance(args, tuple), "args must be a tuple" + if not kwargs: + kwargs = immutable_dict() + else: + assert isinstance(kwargs, dict), "kwargs must be a dict" + + candidate = name if name is not None else self._target_to_str(target) + name = self._graph_namespace.create_name(candidate, None) + n = Node(self, name, op, target, args, kwargs, type_expr) + + if ( + self.owning_module is not None + and getattr(self.owning_module, "_create_node_hooks", None) is not None + ): + for f in self.owning_module._create_node_hooks: + f(n) + + self._graph_namespace.associate_name_with_obj(name, n) + + self._insert(n) + self._find_nodes_lookup_table.insert(n) + self._len += 1 + return n + + @compatibility(is_backward_compatible=False) + def process_inputs(self, *args): + """ + Processes args so that they can be passed to the FX graph. + """ + return self._codegen.process_inputs(*args) + + @compatibility(is_backward_compatible=False) + def process_outputs(self, out): + return self._codegen.process_outputs(out) + + @compatibility(is_backward_compatible=True) + def erase_node(self, to_erase: Node) -> None: + """ + Erases a ``Node`` from the ``Graph``. Throws an exception if + there are still users of that node in the ``Graph``. + + Args: + + to_erase (Node): The ``Node`` to erase from the ``Graph``. + """ + if len(to_erase.users) > 0: + raise RuntimeError( + f"Tried to erase Node {to_erase} but it still had {len(to_erase.users)} " + f"users in the graph: {to_erase.users}!" 
+ ) + if to_erase.graph != self: + raise RuntimeError(f"Attempting to remove {to_erase} from wrong graph!") + if to_erase._erased: + warnings.warn(f"erase_node({to_erase}) on an already erased node") + return + + if ( + self.owning_module is not None + and getattr(self.owning_module, "_erase_node_hooks", None) is not None + ): + for f in self.owning_module._erase_node_hooks: + f(to_erase) + + self._find_nodes_lookup_table.remove(to_erase) + # pyrefly: ignore [missing-attribute] + to_erase._remove_from_list() + to_erase._erased = True # iterators may retain handles to erased nodes + self._len -= 1 + + # Null out this Node's argument nodes so that the Nodes referred to + # can update their ``users`` accordingly + to_erase._update_args_kwargs( + map_arg(to_erase._args, lambda n: None), + map_arg(to_erase._kwargs, lambda n: None), + ) + + @compatibility(is_backward_compatible=True) + def inserting_before(self, n: Optional[Node] = None): + """Set the point at which create_node and companion methods will insert into the graph. + When used within a 'with' statement, this will temporary set the insert point and + then restore it when the with statement exits:: + + with g.inserting_before(n): + ... # inserting before node n + ... # insert point restored to what it was previously + g.inserting_before(n) # set the insert point permanently + + Args: + + n (Optional[Node]): The node before which to insert. If None this will insert before + the beginning of the entire graph. + + Returns: + A resource manager that will restore the insert point on ``__exit__``. + """ + if n is None: + return self.inserting_after(self._root) + assert n.graph == self, "Node to insert before is not in graph." + return _InsertPoint(self, n.prepend) + + @compatibility(is_backward_compatible=True) + def inserting_after(self, n: Optional[Node] = None): + """Set the point at which create_node and companion methods will insert into the graph. 
+ When used within a 'with' statement, this will temporary set the insert point and + then restore it when the with statement exits:: + + with g.inserting_after(n): + ... # inserting after node n + ... # insert point restored to what it was previously + g.inserting_after(n) # set the insert point permanently + + Args: + + n (Optional[Node]): The node before which to insert. If None this will insert after + the beginning of the entire graph. + + Returns: + A resource manager that will restore the insert point on ``__exit__``. + """ + if n is None: + return self.inserting_before(self._root) + assert n.graph == self, "Node to insert after is not in graph." + return _InsertPoint(self, n.append) + + @compatibility(is_backward_compatible=True) + def placeholder( + self, + name: str, + type_expr: Optional[Any] = None, + default_value: Any = inspect.Signature.empty, + ) -> Node: + """ + Insert a ``placeholder`` node into the Graph. A ``placeholder`` represents + a function input. + + Args: + + name (str): A name for the input value. This corresponds to the name + of the positional argument to the function this ``Graph`` represents. + + type_expr (Optional[Any]): an optional type annotation representing the + Python type the output of this node will have. This is needed in some + cases for proper code generation (e.g. when the function is used + subsequently in TorchScript compilation). + + default_value (Any): The default value this function argument should take + on. NOTE: to allow for `None` as a default value, `inspect.Signature.empty` + should be passed as this argument to specify that the parameter does _not_ + have a default value. + + .. note:: + The same insertion point and type expression rules apply for this method + as ``Graph.create_node``. 
+ """ + args = () if default_value is inspect.Signature.empty else (default_value,) + return self.create_node("placeholder", name, args=args, type_expr=type_expr) + + @compatibility(is_backward_compatible=True) + def get_attr(self, qualified_name: str, type_expr: Optional[Any] = None) -> Node: + """ + Insert a ``get_attr`` node into the Graph. A ``get_attr`` ``Node`` represents the + fetch of an attribute from the ``Module`` hierarchy. + + Args: + + qualified_name (str): the fully-qualified name of the attribute to be retrieved. + For example, if the traced Module has a submodule named ``foo``, which has a + submodule named ``bar``, which has an attribute named ``baz``, the qualified + name ``foo.bar.baz`` should be passed as ``qualified_name``. + + type_expr (Optional[Any]): an optional type annotation representing the + Python type the output of this node will have. + + + Returns: + + The newly-created and inserted ``get_attr`` node. + + .. note:: + The same insertion point and type expression rules apply for this method + as ``Graph.create_node``. + """ + + def _get_attr_reference_exists( + mod: torch.nn.Module, qualified_name: str + ) -> bool: + module_path, _, name = qualified_name.rpartition(".") + + try: + submod: torch.nn.Module = mod.get_submodule(module_path) + except AttributeError: + warnings.warn(f"Failed to fetch module {module_path}!") + return False + + if not hasattr(submod, name): + return False + + res = getattr(submod, name) + + if ( + not isinstance(res, torch.nn.Module) + and not isinstance(res, torch.nn.Parameter) + and name not in submod._buffers + ): + return False + + return True + + if self.owning_module and not _get_attr_reference_exists( + self.owning_module, qualified_name + ): + warnings.warn( + "Attempted to insert a get_attr Node with no " + "underlying reference in the owning " + "GraphModule! 
Call " + "GraphModule.add_submodule to add the " + "necessary submodule, " + "GraphModule.add_parameter to add the " + "necessary Parameter, or " + "nn.Module.register_buffer to add the " + "necessary buffer", + stacklevel=2, + ) + return self.create_node("get_attr", qualified_name, type_expr=type_expr) + + @compatibility(is_backward_compatible=True) + def call_module( + self, + module_name: str, + args: Optional[tuple["Argument", ...]] = None, + kwargs: Optional[dict[str, "Argument"]] = None, + type_expr: Optional[Any] = None, + ) -> Node: + """ + Insert a ``call_module`` ``Node`` into the ``Graph``. A ``call_module`` node + represents a call to the forward() function of a ``Module`` in the ``Module`` + hierarchy. + + Args: + + module_name (str): The qualified name of the ``Module`` in the ``Module`` + hierarchy to be called. For example, if the traced ``Module`` has a + submodule named ``foo``, which has a submodule named ``bar``, the + qualified name ``foo.bar`` should be passed as ``module_name`` to + call that module. + + args (Optional[Tuple[Argument, ...]]): The positional arguments to be passed + to the called method. Note that this should *not* include a ``self`` argument. + + kwargs (Optional[Dict[str, Argument]]): The keyword arguments to be passed + to the called method + + type_expr (Optional[Any]): an optional type annotation representing the + Python type the output of this node will have. + + Returns: + + The newly-created and inserted ``call_module`` node. + + .. note:: + The same insertion point and type expression rules apply for this method + as :meth:`Graph.create_node`. + """ + if self.owning_module and self.owning_module.get_submodule(module_name) is None: + warnings.warn( + "Attempted to insert a call_module Node with " + "no underlying reference in the owning " + "GraphModule! 
Call " + "GraphModule.add_submodule to add the " + "necessary submodule" + ) + return self.create_node( + "call_module", module_name, args, kwargs, type_expr=type_expr + ) + + @compatibility(is_backward_compatible=True) + def call_method( + self, + method_name: str, + args: Optional[tuple["Argument", ...]] = None, + kwargs: Optional[dict[str, "Argument"]] = None, + type_expr: Optional[Any] = None, + ) -> Node: + """ + Insert a ``call_method`` ``Node`` into the ``Graph``. A ``call_method`` node + represents a call to a given method on the 0th element of ``args``. + + Args: + + method_name (str): The name of the method to apply to the self argument. + For example, if args[0] is a ``Node`` representing a ``Tensor``, + then to call ``relu()`` on that ``Tensor``, pass ``relu`` to ``method_name``. + + args (Optional[Tuple[Argument, ...]]): The positional arguments to be passed + to the called method. Note that this *should* include a ``self`` argument. + + kwargs (Optional[Dict[str, Argument]]): The keyword arguments to be passed + to the called method + + type_expr (Optional[Any]): an optional type annotation representing the + Python type the output of this node will have. + + Returns: + + The newly created and inserted ``call_method`` node. + + .. note:: + The same insertion point and type expression rules apply for this method + as :meth:`Graph.create_node`. + """ + return self.create_node( + "call_method", method_name, args, kwargs, type_expr=type_expr + ) + + @compatibility(is_backward_compatible=True) + def call_function( + self, + the_function: Callable[..., Any], + args: Optional[tuple["Argument", ...]] = None, + kwargs: Optional[dict[str, "Argument"]] = None, + type_expr: Optional[Any] = None, + name: Optional[str] = None, + ) -> Node: + """ + Insert a ``call_function`` ``Node`` into the ``Graph``. A ``call_function`` node + represents a call to a Python callable, specified by ``the_function``. 
+ + Args: + + the_function (Callable[..., Any]): The function to be called. Can be any PyTorch + operator, Python function, or member of the ``builtins`` or ``operator`` + namespaces. + + args (Optional[Tuple[Argument, ...]]): The positional arguments to be passed + to the called function. + + kwargs (Optional[Dict[str, Argument]]): The keyword arguments to be passed + to the called function + + type_expr (Optional[Any]): an optional type annotation representing the + Python type the output of this node will have. + + name (Optional[str]): The name of the node. If not specified, set to None + + Returns: + + The newly created and inserted ``call_function`` node. + + .. note:: + The same insertion point and type expression rules apply for this method + as :meth:`Graph.create_node`. + """ + return self.create_node( + "call_function", the_function, args, kwargs, name=name, type_expr=type_expr + ) + + @compatibility(is_backward_compatible=True) + def node_copy( + self, node: Node, arg_transform: Callable[[Node], "Argument"] = lambda x: x + ) -> Node: + """ + Copy a node from one graph into another. ``arg_transform`` needs to transform arguments from + the graph of node to the graph of self. Example:: + + # Copying all the nodes in `g` into `new_graph` + g: torch.fx.Graph = ... + new_graph = torch.fx.graph() + value_remap = {} + for node in g.nodes: + value_remap[node] = new_graph.node_copy(node, lambda n: value_remap[n]) + + Args: + + node (Node): The node to copy into ``self``. + + arg_transform (Callable[[Node], Argument]): A function that transforms + ``Node`` arguments in node's ``args`` and ``kwargs`` into the + equivalent argument in ``self``. In the simplest case, this should + retrieve a value out of a table mapping Nodes in the original + graph to ``self``. 
+ """ + args = map_arg(node.args, arg_transform) + kwargs = map_arg(node.kwargs, arg_transform) + assert isinstance(args, tuple) + assert isinstance(kwargs, dict) + result_node = self.create_node( + node.op, node.target, args, kwargs, node.name, node.type + ) + result_node.meta = copy.copy(node.meta) + return result_node + + @compatibility(is_backward_compatible=True) + def output(self, result: "Argument", type_expr: Optional[Any] = None): + """ + Insert an ``output`` ``Node`` into the ``Graph``. An ``output`` node represents + a ``return`` statement in Python code. ``result`` is the value that should + be returned. + + Args: + + result (Argument): The value to be returned. + + type_expr (Optional[Any]): an optional type annotation representing the + Python type the output of this node will have. + + .. note:: + + The same insertion point and type expression rules apply for this method + as ``Graph.create_node``. + """ + return self.create_node( + op="output", target="output", args=(result,), type_expr=type_expr + ) + + def _target_to_str(self, target: Optional[Target]) -> str: + if callable(target): + op = target.__name__ + else: + assert isinstance(target, str) + op = target + if _is_magic(op): + op = op[2:-2] + op = _snake_case(op) + return op + + @compatibility(is_backward_compatible=True) + def python_code( + self, + root_module: str, + *, + verbose: bool = False, + include_stride: bool = False, + include_device: bool = False, + colored: bool = False, + expanded_def: bool = False, + record_func: bool = False, + ) -> PythonCode: + """ + Turn this ``Graph`` into valid Python code. + + Args: + + root_module (str): The name of the root module on which to look-up + qualified name targets. This is usually 'self'. + + Returns: + + A PythonCode object, consisting of two fields: + src: the Python source code representing the object + globals: a dictionary of global names in `src` -> the objects that they reference. 
+ """ + # NOTE: [Graph Namespaces] + # + # There are two types of symbols in generated Python source code: + # locals and globals. + # Locals are locally defined by the output of a node in the Graph. + # Globals are references to external objects, like functions or types. + # + # When generating Python code, we need to make sure to name things + # appropriately. In particular: + # - All names should be unique, to avoid weird shadowing bugs. + # - These names need to be consistent, e.g. a object should always be + # referenced by the same name. + # + # To do this, we create a new namespace just for this source. All names + # that get printed must come from this namespace. + # + # Why can't we reuse node.name? Because it was generated within the + # namespace `self._graph_namespace`. In order to provide uniqueness + # over both locals (node.name) *and* globals, we create a completely + # new namespace to put all identifiers in. + namespace = _Namespace() + + # Override Node's repr to generate a valid name within our namespace. + # Since repr() is designed to produce a valid Python expression, it + # makes sense to reuse it. This way, it's easy to print something like + # Tuple[Node, Node] by simply calling repr() on it. Node's __repr__ is + # implemented cooperatively to allow this. 
+ def node_repr(n: Node): + return namespace.create_name(n.name, n) + + @contextmanager + def override_node_repr(graph: Graph): + orig_repr_fns = {} + for node in graph.nodes: + orig_repr_fns[node] = node._repr_fn + node._repr_fn = node_repr + try: + yield None + finally: + # restore the original repr functions + for node in graph.nodes: + node._repr_fn = orig_repr_fns[node] + + with override_node_repr(self): + return self._python_code( + root_module, + namespace, + verbose=verbose, + include_stride=include_stride, + include_device=include_device, + colored=colored, + expanded_def=expanded_def, + record_func=record_func, + ) + + def _python_code( + self, + root_module: str, + namespace: _Namespace, + *, + verbose: bool = False, + include_stride: bool = False, + include_device: bool = False, + colored: bool = False, + expanded_def: bool = False, + record_func: bool = False, + ) -> PythonCode: + return self._codegen._gen_python_code( + self.nodes, + root_module, + namespace, + verbose=verbose, + include_stride=include_stride, + include_device=include_device, + colored=colored, + expanded_def=expanded_def, + record_func=record_func, + ) + + def __str__(self) -> str: + """ + Return a human-readable (not machine-readable) string representation + of this Graph + """ + placeholder_names: list[str] = [] + # This is a one-element array just so ``format_node`` can modify the closed + # over value + maybe_return_typename: list[str] = [""] + + node_strs = [node.format_node(placeholder_names) for node in self.nodes] + param_str = ", ".join(placeholder_names) + s = f"graph({param_str}){maybe_return_typename[0]}:" + for node_str in node_strs: + if node_str: + s += "\n " + node_str + return s + + @compatibility(is_backward_compatible=True) + def print_tabular(self): + """ + Prints the intermediate representation of the graph in tabular + format. Note that this API requires the ``tabulate`` module to be + installed. 
+ """ + try: + from tabulate import tabulate + except ImportError: + print( + "`print_tabular` relies on the library `tabulate`, " + "which could not be found on this machine. Run `pip " + "install tabulate` to install the library." + ) + raise + + node_specs = [[n.op, n.name, n.target, n.args, n.kwargs] for n in self.nodes] + print( + tabulate(node_specs, headers=["opcode", "name", "target", "args", "kwargs"]) + ) + + @compatibility(is_backward_compatible=True) + def lint(self): + """ + Runs various checks on this Graph to make sure it is well-formed. In + particular: + - Checks Nodes have correct ownership (owned by this graph) + - Checks Nodes appear in topological order + - If this Graph has an owning GraphModule, checks that targets + exist in that GraphModule + """ + + # Check topo order + def check_arg(arg: Node, n: Optional[Node] = None) -> None: + context_str = f" of Node '{n}' " if n else " " + if arg.graph is not self: + raise RuntimeError( + f"Argument '{arg}'{context_str}does not belong to this Graph, " + f"but was used as an argument! If you are copying nodes from another graph, make " + f"sure to use ``arg_transform`` on node_copy() to remap values\n{self}" + ) + if arg not in seen_values: + raise RuntimeError( + f"Argument '{arg}'{context_str}was used before it has been " + f"defined! 
Please check that Nodes in the graph are topologically ordered\n{self}" + ) + + seen_names: set[str] = set() + seen_values: set[Node] = set() + for node in self.nodes: + if node.op not in _legal_ops: + raise RuntimeError(f"Node {node} had unknown opcode {node.op}!") + if node.graph is not self: + raise RuntimeError(f"Node '{node}' does not belong to this Graph!") + if node not in self._find_nodes_lookup_table: + raise RuntimeError(f"Node '{node}' is not added to the side table") + for arg in node._input_nodes: + check_arg(arg, node) + seen_values.add(node) + + if node.name in seen_names: + raise RuntimeError(f"Node redefined name {node.name}!") + seen_names.add(node.name) + + # Check targets are legit + if self.owning_module: + for node in self.nodes: + if node.op == "call_function": + if not callable(node.target): + raise ValueError( + f"Node {node} target {node.target} has type {torch.typename(node.target)} but " + "a Callable is expected" + ) + else: + if not isinstance(node.target, str): + raise ValueError( + f"Node {node} target {node.target} has type {torch.typename(node.target)} but " + "a str is expected" + ) + if node.op in ["get_attr", "call_module"]: + # pyrefly: ignore [missing-attribute] + target_atoms = node.target.split(".") + m_itr = self.owning_module + for i, atom in enumerate(target_atoms): + new_m_itr = getattr(m_itr, atom, None) + seen_qualname = ".".join(target_atoms[:i]) + if new_m_itr is None: + raise RuntimeError( + f"Node {node} target {node.target} references nonexistent attribute " + f"{atom} of {seen_qualname}" + ) + if node.op == "call_module" and not isinstance( + new_m_itr, torch.nn.Module + ): + raise RuntimeError( + f"Node {node} target {node.target} {atom} of {seen_qualname} does " + "not reference an nn.Module" + ) + + m_itr = new_m_itr + + @compatibility(is_backward_compatible=True) + def eliminate_dead_code( + self, is_impure_node: Optional[Callable[[Node], bool]] = None + ) -> bool: + """ + Remove all dead code from the graph, 
based on each node's number of + users, and whether the nodes have any side effects. The graph must be + topologically sorted before calling. + + Args: + is_impure_node (Optional[Callable[[Node], bool]]): A function that returns + whether a node is impure. If this is None, then the default behavior is to + use Node.is_impure. + + Returns: + bool: Whether the graph was changed as a result of the pass. + + Example: + + Before dead code is eliminated, `a` from `a = x + 1` below has no users + and thus can be eliminated from the graph without having an effect. + + .. code-block:: python + + def forward(self, x): + a = x + 1 + return x + self.attr_1 + + After dead code is eliminated, `a = x + 1` has been removed, and the rest + of `forward` remains. + + .. code-block:: python + + def forward(self, x): + return x + self.attr_1 + + .. warning:: + + Dead code elimination has some heuristics to avoid removing + side-effectful nodes (see Node.is_impure) but in general coverage + is very bad, so you should assume that this method is not sound + to call unless you know that your FX graph consists entirely + of functional operations or you supply your own custom + function for detecting side-effectful nodes. + """ + from torch.utils._ordered_set import OrderedSet + + # Lint the graph first to make sure its topologically sorted, otherwise + # DCE below will not behave as expected. + self.lint() + + impure_random = True + if torch._guards.TracingContext.try_get(): + impure_random = torch._inductor.config.fallback_random + + def has_side_effect(node): + if is_impure_node is not None: + return is_impure_node(node) + return node.is_impure(impure_random) + + # Reverse iterate so that when we remove a node, any nodes used as an + # input to that node have an updated user count that no longer reflects + # the removed node. 
+ removed_nodes = set() + for node in reversed(self.nodes): + if not has_side_effect(node) and len(node.users) == 0: + self.erase_node(node) + removed_nodes.add(node.name) + + changed = len(removed_nodes) > 0 + if changed: + log.info("The following nodes were dead code eliminated: %s", removed_nodes) + + # Call DCE on the subgraphs + if self.owning_module is not None: + subgraph_names = OrderedSet( + x.target for x in self.find_nodes(op="get_attr") + ) + for child_name, child_module in self.owning_module.named_children(): + # Sometimes an owning_module can have unused children. Skip them + # by checking them from get_attr node targets. + if child_name in subgraph_names and isinstance( + child_module, torch.fx.GraphModule + ): + changed |= child_module.graph.eliminate_dead_code() + child_module.recompile() + + return changed + + @compatibility(is_backward_compatible=False) + def set_codegen(self, codegen: CodeGen): + self._codegen = codegen + + @compatibility(is_backward_compatible=False) + def on_generate_code( + self, + make_transformer: Callable[[Optional[TransformCodeFunc]], TransformCodeFunc], + ): + """Register a transformer function when python code is generated + + Args: + make_transformer (Callable[[Optional[TransformCodeFunc]], TransformCodeFunc]): + a function that returns a code transformer to be registered. + This function is called by `on_generate_code` to obtain the + code transformer. + + This function is also given as its input the currently + registered code transformer (or None if nothing is registered), + in case it is not desirable to overwrite it. This is useful to + chain code transformers together. + + Returns: + a context manager that when used in a `with` statement, to automatically + restore the previously registered code transformer. + + Example: + + .. code-block:: python + + + gm: fx.GraphModule = ... + + + # This is a code transformer we want to register. 
This code + # transformer prepends a pdb import and trace statement at the very + # beginning of the generated torch.fx code to allow for manual + # debugging with the PDB library. + def insert_pdb(body): + return ["import pdb; pdb.set_trace()\\n", *body] + + + # Registers `insert_pdb`, and overwrites the current registered + # code transformer (given by `_` to the lambda): + gm.graph.on_generate_code(lambda _: insert_pdb) + + # Or alternatively, registers a code transformer which first + # runs `body` through existing registered transformer, then + # through `insert_pdb`: + gm.graph.on_generate_code( + lambda current_trans: ( + lambda body: insert_pdb( + current_trans(body) if current_trans else body + ) + ) + ) + + gm.recompile() + gm(*inputs) # drops into pdb + + + This function can also be used as a context manager, with the benefit to + automatically restores the previously registered code transformer: + + .. code-block:: python + + # ... continue from previous example + + with gm.graph.on_generate_code(lambda _: insert_pdb): + # do more stuff with `gm`... + gm.recompile() + gm(*inputs) # drops into pdb + + # now previous code transformer is restored (but `gm`'s code with pdb + # remains - that means you can run `gm` with pdb here too, until you + # run next `recompile()`). 
+ """ + on_gen_code_old = self._codegen._body_transformer + self._codegen._body_transformer = make_transformer(on_gen_code_old) + + @contextlib.contextmanager + def on_generate_code_context_manager(): + try: + yield + finally: + self._codegen._body_transformer = on_gen_code_old + + return on_generate_code_context_manager() + + def _clear_nodes(self) -> None: + for node in reversed(self.nodes): + node.meta.clear() + self.erase_node(node) + + +@contextmanager +def _override_sym_repr( + override: Callable[["torch.types.PySymType"], str], +) -> Iterator[None]: + tmp = CodeGen._sym_repr + try: + CodeGen._sym_repr = override + yield + finally: + CodeGen._sym_repr = tmp + + +def _identity(x): + return x + + +def _make_color_fn(code): + def f(s): + reset = "\033[0m" + return f"{code}{s}{reset}" + + return f + + +_color_codes = { + "yellow": "\033[33m", + "cyan": "\033[36m", + "green": "\033[32m", + "blue": "\033[34m", + "red": "\033[31m", + "dim": "\033[2m", + "dim_blue": "\033[2m\033[34m", + "dim_green": "\033[2m\033[32m", +} +_color_fns = {k: _make_color_fn(v) for k, v in _color_codes.items()} +_counter_regexp = re.compile(r"# COUNTER: (\d+)") + + +reflectable_magic_methods = { + "add": "{} + {}", + "sub": "{} - {}", + "mul": "{} * {}", + "floordiv": "{} // {}", + "truediv": "{} / {}", + "div": "{} / {}", + "mod": "{} % {}", + "pow": "{} ** {}", + "lshift": "{} << {}", + "rshift": "{} >> {}", + "and_": "{} & {}", + "or_": "{} | {}", + "xor": "{} ^ {}", + "getitem": "{}[{}]", + "matmul": "{} @ {}", +} + +magic_methods = { + "eq": "{} == {}", + "ne": "{} != {}", + "lt": "{} < {}", + "gt": "{} > {}", + "le": "{} <= {}", + "ge": "{} >= {}", + "pos": "+{}", + "neg": "-{}", + "invert": "~{}", + **reflectable_magic_methods, +} + +inplace_methods = { + "iadd": "{} += {}", + "iand": "{} &= {}", + "ifloordiv": "{} //= {}", + "ilshift": "{} <<= {}", + "imod": "{} %= {}", + "imul": "{} *= {}", + "imatmul": "{} @= {}", + "ior": "{} |= {}", + "ipow": "{} **= {}", + "irshift": "{} >>= 
{}", + "isub": "{} -= {}", + "itruediv": "{} /= {}", + "ixor": "{} ^= {}", + "setitem": "{}[{}] = {}", +} diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/graph_module.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/graph_module.py new file mode 100644 index 0000000000000000000000000000000000000000..ab33d7bf321c9ba4e41eed45732c91e38c545593 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/graph_module.py @@ -0,0 +1,1183 @@ +# mypy: allow-untyped-defs +import base64 +import contextlib +import copy +import hashlib +import itertools +import linecache +import os +import sys +import traceback +import warnings +from collections.abc import Callable +from pathlib import Path +from typing import Any, Optional, Union + +import torch +import torch.nn as nn +import torch.overrides +from torch.nn.modules.module import _addindent +from torch.package import Importer, PackageExporter, PackageImporter, sys_importer + +from ._compatibility import compatibility +from .experimental import _config as fx_experimental_config +from .graph import ( + _BoxedCodeGen, + _custom_builtins, + _is_from_torch, + _override_sym_repr, + _PyTreeCodeGen, + Graph, + PythonCode, +) + + +__all__ = [ + "reduce_graph_module", + "reduce_package_graph_module", + "GraphModule", +] + +_USER_PRESERVED_ATTRIBUTES_KEY = "_user_preserved_attributes" +FX_GRAPH_MODULE_FILE_PREFIX = "fx_generated_" + + +# Normal exec loses the source code, however we can work with +# the linecache module to recover it. +# Using _exec_with_source will add it to our local cache +# and then tools like TorchScript will be able to get source info. +class _EvalCacheLoader: + def __init__(self): + self.eval_cache = {} + self.next_id = 0 + + def cache(self, src: str, globals: dict[str, Any], co_fields=None): + """Store the source in a private cache, and add a lazy entry in linecache + that allows the source to be retrieved by 'filename'. 
+ + Args: + src (str): The module source to cache + globals (dict): The module globals + + Returns: + str: The cache key (and dummy filename) generated for src. + """ + + key = self._get_key() + if co_fields: + if "co_filename" in co_fields: + # If only co_filename is provided, use it directly as the key + if "co_firstlineno" not in co_fields or "co_name" not in co_fields: + key = co_fields["co_filename"] + else: + # Full co_fields with all three components + key += f" from {co_fields['co_filename']}:{co_fields['co_firstlineno']} in {co_fields['co_name']}" + self.eval_cache[key] = src + + # Don't mutate globals so that this loader is only used + # to populate linecache, and doesn't interact with other modules + # that might check `__loader__` + globals_copy = globals.copy() + globals_copy["__file__"] = key + globals_copy["__name__"] = key + globals_copy["__loader__"] = self + linecache.lazycache(key, globals_copy) + + return key + + # Part of the loader protocol (PEP 302) + # linecache will use this method when trying to find source code + def get_source(self, module_name) -> Optional[str]: + if module_name in self.eval_cache: + return self.eval_cache[module_name] + return None + + def _get_key(self): + key = f".{self.next_id}" + self.next_id += 1 + return key + + +_loader = _EvalCacheLoader() + + +def _exec_with_source(src: str, globals: dict[str, Any], co_fields=None): + key = _loader.cache(src, globals, co_fields) + exec(compile(src, key, "exec"), globals) + + +def _forward_from_src(src: str, globals: dict[str, Any], co_fields=None): + return _method_from_src( + method_name="forward", src=src, globals=globals, co_fields=co_fields + ) + + +def _method_from_src( + method_name: str, src: str, globals: dict[str, Any], co_fields=None +) -> Callable: + # avoid mutating the passed in dict + globals_copy = globals.copy() + _exec_with_source(src, globals_copy, co_fields) + fn = globals_copy[method_name] + del globals_copy[method_name] + return fn + + +def 
_format_import_statement(name: str, obj: Any, importer: Importer) -> str: + if name in _custom_builtins: + return _custom_builtins[name].import_str + if _is_from_torch(name): + return "import torch" + module_name, attr_name = importer.get_name(obj) + return f"from {module_name} import {attr_name} as {name}" + + +def _format_import_block(globals: dict[str, Any], importer: Importer): + import_strs: set[str] = { + _format_import_statement(name, obj, importer) for name, obj in globals.items() + } + # Sort the imports so we have a stable import block that allows us to + # hash the graph module and get a consistent key for use in a cache. + return "\n".join(sorted(import_strs)) + + +@compatibility(is_backward_compatible=True) +def reduce_graph_module(body: dict[Any, Any], import_block: str) -> torch.nn.Module: + # BC: attribute name was changed from `code` to `_code` to facilitate + # making `code` into a property and adding a docstring to it + fn_src = body.get("_code") or body["code"] + forward = _forward_from_src(import_block + fn_src, {}) + return _deserialize_graph_module(forward, body) + + +@compatibility(is_backward_compatible=True) +def reduce_package_graph_module( + importer: PackageImporter, body: dict[Any, Any], generated_module_name: str +) -> torch.nn.Module: + forward = importer.import_module(generated_module_name).forward + return _deserialize_graph_module(forward, body) + + +# We create a dummy class here because symbolic_trace pulls the forward() +# function off of the class, rather than the instance. This class is used +# in _deserialize_graph_module() below. +class _CodeOnlyModule(torch.nn.Module): + def __init__(self, body): + super().__init__() + self.__dict__ = body + + +def _deserialize_graph_module( + forward, body: dict[Any, Any], graph_module_cls=None +) -> torch.nn.Module: + """ + Deserialize a GraphModule given the dictionary of the original module, + using the code to reconstruct the graph. 
We delete the actual graph before + saving the dictionary so that changes to the in-memory graph format do not + get serialized. + """ + + # Try to retrieve the forward source in a backward-compatible way + _CodeOnlyModule.forward = forward + + tracer_cls = body.get("_tracer_cls") + if tracer_cls is None: + from ._symbolic_trace import Tracer + + tracer_cls = Tracer + + graphmodule_cls_name = body.get("_graphmodule_cls_name", "GraphModule") + + # This is a workaround for a mypy linter issue related to + # passing base class as an argument - https://github.com/python/mypy/issues/5865. + cls_tracer: Any = tracer_cls + + class KeepModules(cls_tracer): + # we shouldn't trace into any of the submodules, + # because they were not traced in the original GraphModule + def is_leaf_module(self, _: torch.nn.Module, __: str) -> bool: + return True + + com = _CodeOnlyModule(body) + + tracer_extras = body.get("_tracer_extras", {}) + graph = KeepModules().trace(com, **tracer_extras) + + # Recover node.meta["stack_trace"] after re-tracing + node_meta_stack_trace = body.get("_graphmodule_graph_node_meta_stack_trace") + if node_meta_stack_trace is not None: + del body["_graphmodule_graph_node_meta_stack_trace"] + for node in graph.nodes: + if node_meta_stack_trace.get(node.name, None) is not None: + node.meta["stack_trace"] = node_meta_stack_trace[node.name] + + # Manually set Tracer class on the reconstructed Graph, to avoid + # referencing the private local subclass KeepModules. + graph._tracer_cls = tracer_cls + from ._lazy_graph_module import _make_graph_module + + gm = _make_graph_module( + com, graph, class_name=graphmodule_cls_name, graph_module_cls=graph_module_cls + ) + + # The GraphModule constructor only retains attributes referenced by the graph. + # In this case, our goal is return a GraphModule as close to identical as the one + # put into the package. If any additional attributes were present in body, + # we should keep them. 
+ for k, v in body.items(): + if not hasattr(gm, k): + setattr(gm, k, v) + return gm + + +# copy an attribute value with qualified name 'target' from 'from_module' to 'to_module' +# This installs empty Modules where none exist yet if they are subpaths of target +def _copy_attr(from_module: torch.nn.Module, to_module: torch.nn.Module, target: str): + *prefix, field = target.split(".") + for item in prefix: + f = getattr(from_module, item) + t = getattr(to_module, item, None) + if f is t: + # we have already installed one of its parents + # (e.g. target = root.linear.weight, but we have already installed root.linear) + # once we install a parent, we no longer need to copy the children + # since all the needed properties will already be present + return + + if t is None: + t = torch.nn.Module() + setattr(to_module, item, t) + from_module, to_module = f, t + + orig = getattr(from_module, field) + # If it is a tensor and not a parameter attribute of a module, it should be a named buffer. + # So, we register it as a named buffer in the target module. + if isinstance(orig, torch.Tensor) and not isinstance(orig, torch.nn.Parameter): + to_module.register_buffer(field, orig) + else: + setattr(to_module, field, orig) + + +# Assign attribute 'from_obj' to the qualified name 'target' on 'to_module +# This installs empty Modules where none exist yet if they are subpaths of target +def _assign_attr(from_obj: Any, to_module: torch.nn.Module, target: str): + *prefix, field = target.split(".") + for item in prefix: + t = getattr(to_module, item, None) + + if t is None: + t = torch.nn.Module() + setattr(to_module, item, t) + to_module = t + + # If it is a tensor and not a parameter attribute of a module, it should be a named buffer. + # So, we register it as a named buffer in the target module. 
+ if isinstance(from_obj, torch.Tensor) and not isinstance( + from_obj, torch.nn.Parameter + ): + to_module.register_buffer(field, from_obj) + else: + setattr(to_module, field, from_obj) + + +# Recursively look up target from a graph module. +def _get_attr(model: torch.nn.Module, attr_name: str): + return _get_attr_via_attr_list(model, attr_name.split(".")) + + +def _del_attr(model: torch.nn.Module, attr_name: str): + attr_names = attr_name.split(".") + t = _get_attr_via_attr_list(model, attr_names[:-1]) + return delattr(t, attr_names[-1]) + + +def _get_attr_via_attr_list(model: torch.nn.Module, attr_list: list[str]): + if len(attr_list) == 0: + return model + *prefix, field = attr_list + t = model + for item in prefix: + t = getattr(t, item, None) # type: ignore[assignment] + assert t is not None + + return getattr(t, field) + + +def _has_attr(model: torch.nn.Module, attr_name: str): + *prefix, field = attr_name.split(".") + t = model + for item in prefix: + t = hasattr(t, item) # type: ignore[assignment] + if t is False: + return False + + return hasattr(t, field) + + +def _print_readable( + module, + module_name, + print_output=True, + include_stride=False, + include_device=False, + colored=False, + expanded_def=False, +): + graph = module.graph + assert graph is not None and isinstance(graph, torch.fx.Graph), ( + "print_readable must be used on a module with a graph" + ) + + verbose_python_code = graph.python_code( + root_module="self", + verbose=True, + include_stride=include_stride, + include_device=include_device, + colored=colored, + expanded_def=expanded_def, + ) + module_code = verbose_python_code.src + module_code = module_code.lstrip("\n") + module_code = f"class {module_name}(torch.nn.Module):\n" + module_code + module_code = _addindent(module_code, 4) + + submodule_code_list = [""] + for submodule_name, submodule in module.named_children(): + if hasattr(submodule, "graph"): + submodule_code_list.append( + _print_readable( + submodule, + 
submodule_name, + print_output=False, + include_stride=include_stride, + include_device=include_device, + colored=colored, + ) + ) + submodule_code = "\n".join(submodule_code_list) + submodule_code = _addindent(submodule_code, 4) + + output = module_code + submodule_code + if print_output: + print(module_code + submodule_code) + return output + + +def _metadata_hash(code: str, node_metadata: dict) -> str: + """ + Create a content-addressed hash from code and metadata. + + Args: + code: The source code string + lineno_map: Mapping from line numbers to node indices + node_metadata: Metadata for each node + + Returns: + A 51-character base32-encoded hash + """ + import json + + # Create a deterministic string representation of all components + # We use JSON to ensure consistent serialization + hash_data = { + "code": code, + "node_metadata": node_metadata, + } + hashing_str = json.dumps(hash_data).encode("utf-8") + + # [:51] to strip off the "Q====" suffix common to every hash value. + return ( + base64.b32encode(hashlib.sha256(hashing_str).digest())[:51] + .decode("utf-8") + .lower() + ) + + +class _WrappedCall: + def __init__(self, cls, cls_call): + self.cls = cls + self.cls_call = cls_call + + # Previously, if an error occurred when valid + # symbolically-traced code was run with an invalid input, the + # user would see the source of the error as coming from + # `File "`, where N is some number. We use + # this function to generate a more informative error message. 
We + # return the traceback itself, a message explaining that the + # error occurred in a traced Module's generated forward + # function, and five lines of context surrounding the faulty + # line + @staticmethod + def _generate_error_message(frame_summary: traceback.FrameSummary) -> str: + # auxiliary variables (for readability) + err_lineno = frame_summary.lineno + assert err_lineno is not None + line = frame_summary.line + assert line is not None + err_line_len = len(line) + all_src_lines = linecache.getlines(frame_summary.filename) + + # constituent substrings of the error message + tb_repr = torch._dynamo.disable( + traceback.format_exc, + reason="do not trace into traceback.format_exc when generating error message", + )() + custom_msg = ( + "Call using an FX-traced Module, " + f"line {err_lineno} of the traced Module's " + "generated forward function:" + ) + before_err = "".join(all_src_lines[err_lineno - 2 : err_lineno]) + marker = "~" * err_line_len + "~~~ <--- HERE" + err_and_after_err = "\n".join(all_src_lines[err_lineno : err_lineno + 2]) + + # joined message + return "\n".join([tb_repr, custom_msg, before_err, marker, err_and_after_err]) + + def __call__(self, obj, *args, **kwargs): + try: + if self.cls_call is not None: + return self.cls_call(obj, *args, **kwargs) + else: + return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc] + except Exception as e: + assert e.__traceback__ + topmost_framesummary: traceback.FrameSummary = ( + traceback.StackSummary.extract(traceback.walk_tb(e.__traceback__))[-1] + ) + if "eval_with_key" in topmost_framesummary.filename: + print( + _WrappedCall._generate_error_message(topmost_framesummary), + file=sys.stderr, + ) + raise e.with_traceback(None) # noqa: B904 + else: + raise e + + +@compatibility(is_backward_compatible=True) +class GraphModule(torch.nn.Module): + """ + GraphModule is an nn.Module generated from an fx.Graph. 
Graphmodule has a + ``graph`` attribute, as well as ``code`` and ``forward`` attributes generated + from that ``graph``. + + .. warning:: + + When ``graph`` is reassigned, ``code`` and ``forward`` will be automatically + regenerated. However, if you edit the contents of the ``graph`` without reassigning + the ``graph`` attribute itself, you must call ``recompile()`` to update the generated + code. + """ + + def __new__(cls: "type[GraphModule]", *args, **kwargs): + # each instance of a graph module needs its own forward method + # so create a new singleton class for each instance. + # it is a subclass of the user-defined class, the only difference + # is an extra layer to install the forward method + + # address issue described at https://github.com/pytorch/pytorch/issues/63883 + # in other words, traverse class hierarchy to fix the redundant class definition problem + for t in cls.__mro__: + c = t.__qualname__.split(".")[-1] + if c != "GraphModuleImpl": + cls = t + break + + class GraphModuleImpl(cls): # type: ignore[misc, valid-type] + pass + + return super().__new__(GraphModuleImpl) + + @compatibility(is_backward_compatible=True) + def __init__( + self, + root: Union[torch.nn.Module, dict[str, Any]], + graph: Graph, + class_name: str = "GraphModule", + ): + """ + Construct a GraphModule. + + Args: + + root (Union[torch.nn.Module, Dict[str, Any]): + ``root`` can either be an nn.Module instance or a Dict mapping strings to any attribute type. + In the case that ``root`` is a Module, any references to Module-based objects (via qualified + name) in the Graph's Nodes' ``target`` field will be copied over from the respective place + within ``root``'s Module hierarchy into the GraphModule's module hierarchy. + In the case that ``root`` is a dict, the qualified name found in a Node's ``target`` will be + looked up directly in the dict's keys. The object mapped to by the Dict will be copied + over into the appropriate place within the GraphModule's module hierarchy. 
+ + graph (Graph): ``graph`` contains the nodes this GraphModule should use for code generation + + class_name (str): ``name`` denotes the name of this GraphModule for debugging purposes. If it's unset, all + error messages will report as originating from ``GraphModule``. It may be helpful to set this + to ``root``'s original name or a name that makes sense within the context of your transform. + """ + super().__init__() + self.__class__.__name__ = class_name + if isinstance(root, torch.nn.Module): + if hasattr(root, "training"): + self.training = root.training + + # When we pickle/unpickle graph module, we don't want to drop any module or attributes. + if isinstance(root, _CodeOnlyModule): + for k, _ in root.named_children(): + _copy_attr(root, self, k) + + for k, _ in root.named_buffers(): + _copy_attr(root, self, k) + + for k, _ in root.named_parameters(): + _copy_attr(root, self, k) + + for node in graph.nodes: + if node.op in ["get_attr", "call_module"]: + assert isinstance(node.target, str) + _copy_attr(root, self, node.target) + elif isinstance(root, dict): + targets_to_copy = [] + for node in graph.nodes: + if node.op in ["get_attr", "call_module"]: + assert isinstance(node.target, str) + if node.target not in root: + raise RuntimeError( + "Node " + + str(node) + + " referenced target " + + node.target + + " but that target was not provided in ``root``!" + ) + targets_to_copy.append(node.target) + # Sort targets in ascending order of the # of atoms. + # This will ensure that less deeply nested attributes are assigned + # before more deeply nested attributes. For example, foo.bar + # will be assigned before foo.bar.baz. 
Otherwise, we might assign + # the user-provided ``foo.bar`` and wipe out the previously-assigned + # ``foo.bar.baz`` + targets_to_copy.sort(key=lambda t: t.count(".")) + for target_to_copy in targets_to_copy: + _assign_attr(root[target_to_copy], self, target_to_copy) + else: + raise RuntimeError("Unsupported type " + str(root) + " passed for root!") + + self.graph = graph + + # Store the Tracer class responsible for creating a Graph separately as part of the + # GraphModule state, except when the Tracer is defined in a local namespace. + # Locally defined Tracers are not pickleable. This is needed because torch.package will + # serialize a GraphModule without retaining the Graph, and needs to use the correct Tracer + # to re-create the Graph during deserialization. + self._tracer_cls = None + if ( + self.graph._tracer_cls + and "" not in self.graph._tracer_cls.__qualname__ + ): + # pyrefly: ignore [bad-assignment] + self._tracer_cls = self.graph._tracer_cls + + self._tracer_extras = {} + if self.graph._tracer_extras: + self._tracer_extras = self.graph._tracer_extras + + # Dictionary to store metadata + self.meta: dict[str, Any] = {} + self._replace_hooks: list[Callable] = [] + self._create_node_hooks: list[Callable] = [] + self._erase_node_hooks: list[Callable] = [] + # Used to remove hooks from deepcopied graph modules within a context manager. + self._deepcopy_hooks: list[Callable] = [] + self.shape_env = None # optional not always set even when dynamic shapes exist. + + # TorchScript breaks trying to compile the graph setter because of the + # continued string literal. 
Issue here: https://github.com/pytorch/pytorch/issues/44842 + # + # Shouldn't be an issue since these methods shouldn't be used in TorchScript anyway + __jit_unused_properties__ = ["graph", "_boxed_call"] + + @property + def _boxed_call(self) -> bool: + return isinstance(self._graph._codegen, _BoxedCodeGen) + + @property + def graph(self) -> Graph: + """ + Return the ``Graph`` underlying this ``GraphModule`` + """ + return self._graph + + @graph.setter + def graph(self, g: Graph) -> None: + """ + Set the underlying ``Graph`` for this ``GraphModule``. This will internally + recompile the ``GraphModule`` so that the generated ``forward()`` function + corresponds to ``g`` + """ + assert isinstance(g, Graph), f"Expected a Graph instance, but got {type(g)}" + self._graph = g + g.owning_module = self + self.recompile() + + @compatibility(is_backward_compatible=False) + def to_folder(self, folder: Union[str, os.PathLike], module_name: str = "FxModule"): + """Dumps out module to ``folder`` with ``module_name`` so that it can be + imported with ``from import `` + + Args: + + folder (Union[str, os.PathLike]): The folder to write the code out to + + module_name (str): Top-level name to use for the ``Module`` while + writing out the code + """ + folder = Path(folder) + Path(folder).mkdir(exist_ok=True) + torch.save(self.state_dict(), folder / "state_dict.pt") + tab = " " * 4 + custom_builtins = "\n".join([v.import_str for v in _custom_builtins.values()]) + model_str = f""" +import torch +{custom_builtins} + +from torch.nn import * +class {module_name}(torch.nn.Module): + def __init__(self): + super().__init__() +""" + + def _gen_model_repr(module_name: str, module: torch.nn.Module) -> Optional[str]: + safe_reprs = [ + nn.Linear, + nn.Conv1d, + nn.Conv2d, + nn.Conv3d, + nn.BatchNorm1d, + nn.BatchNorm2d, + nn.BatchNorm3d, + ] + if type(module) in safe_reprs: + return f"{module.__repr__()}" + else: + return None + + blobified_modules = [] + for module_name, module in 
self.named_children(): + module_str = _gen_model_repr(module_name, module) + if module_str is None: + module_file = folder / f"{module_name}.pt" + torch.save(module, module_file) + blobified_modules.append(module_name) + module_repr = module.__repr__().replace("\r", " ").replace("\n", " ") + # weights_only=False as this is legacy code that saves the model + module_str = ( + f"torch.load(r'{module_file}', weights_only=False) # {module_repr}" + ) + model_str += f"{tab * 2}self.{module_name} = {module_str}\n" + + for buffer_name, buffer in self._buffers.items(): + if buffer is None: + continue + model_str += f"{tab * 2}self.register_buffer('{buffer_name}', torch.empty({list(buffer.shape)}, dtype={buffer.dtype}))\n" # noqa: B950 + + for param_name, param in self._parameters.items(): + if param is None: + continue + model_str += f"{tab * 2}self.{param_name} = torch.nn.Parameter(torch.empty({list(param.shape)}, dtype={param.dtype}))\n" # noqa: B950 + + model_str += ( + f"{tab * 2}self.load_state_dict(torch.load(r'{folder}/state_dict.pt'))\n" + ) + model_str += f"{_addindent(self.code, 4)}\n" + + module_file = folder / "module.py" + module_file.write_text(model_str) + + init_file = folder / "__init__.py" + init_file.write_text("from .module import *") + + if len(blobified_modules) > 0: + warnings.warn( + "Was not able to save the following children modules as reprs -" + f"saved as pickled files instead: {blobified_modules}" + ) + + @compatibility(is_backward_compatible=True) + def add_submodule(self, target: str, m: torch.nn.Module) -> bool: + """ + Adds the given submodule to ``self``. + + This installs empty Modules where none exist yet if they are + subpaths of ``target``. + + Args: + target: The fully-qualified string name of the new submodule + (See example in ``nn.Module.get_submodule`` for how to + specify a fully-qualified string.) 
+ m: The submodule itself; the actual object we want to + install in the current Module + + Return: + bool: Whether or not the submodule could be inserted. For + this method to return True, each object in the chain + denoted by ``target`` must either a) not exist yet, + or b) reference an ``nn.Module`` (not a parameter or + other attribute) + """ + *prefix, field = target.split(".") + mod: torch.nn.Module = self + + for item in prefix: + submod = getattr(mod, item, None) + + if submod is None: + submod = torch.nn.Module() + setattr(mod, item, submod) + + if not isinstance(submod, torch.nn.Module): + return False + + mod = submod + + mod.add_module(field, m) + return True + + @compatibility(is_backward_compatible=True) + def delete_submodule(self, target: str) -> bool: + """ + Deletes the given submodule from ``self``. + + The module will not be deleted if ``target`` is not a valid + target. + + Args: + target: The fully-qualified string name of the new submodule + (See example in ``nn.Module.get_submodule`` for how to + specify a fully-qualified string.) + + Returns: + bool: Whether or not the target string referenced a + submodule we want to delete. A return value of ``False`` + means that the ``target`` was not a valid reference to + a submodule. + """ + atoms = target.split(".") + path, target_submod = atoms[:-1], atoms[-1] + mod: torch.nn.Module = self + + # Get the parent module + for item in path: + if not hasattr(mod, item): + return False + + mod = getattr(mod, item) + + if not isinstance(mod, torch.nn.Module): + return False + + if not hasattr(mod, target_submod): + return False + + if not isinstance(getattr(mod, target_submod), torch.nn.Module): + return False + + delattr(mod, target_submod) + return True + + @compatibility(is_backward_compatible=True) + def delete_all_unused_submodules(self) -> None: + """ + Deletes all unused submodules from ``self``. + + A Module is considered "used" if any one of the following is + true: + 1. 
It has children that are used + 2. Its forward is called directly via a ``call_module`` node + 3. It has a non-Module attribute that is used from a + ``get_attr`` node + + This method can be called to clean up an ``nn.Module`` without + manually calling ``delete_submodule`` on each unused submodule. + """ + used: list[str] = [] + + for node in self.graph.nodes: + if node.op == "call_module" or node.op == "get_attr": + # A list of strings representing the different parts + # of the path. For example, `foo.bar.baz` gives us + # ["foo", "bar", "baz"] + fullpath = node.target.split(".") + + # If we're looking at multiple parts of a path, join + # join them with a dot. Otherwise, return that single + # element without doing anything to it. + def join_fn(x: str, y: str) -> str: + return ".".join([x, y] if y else [x]) + + # Progressively collect all the names of intermediate + # modules. For example, if we have the target + # `foo.bar.baz`, we'll add `foo`, `foo.bar`, and + # `foo.bar.baz` to the list. + used.extend(itertools.accumulate(fullpath, join_fn)) + + # For a `call_module` node, also register all recursive submodules + # as used + if node.op == "call_module": + try: + submod = self.get_submodule(node.target) + + for submod_name, _ in submod.named_modules(): + if submod_name != "": + used.append(".".join([node.target, submod_name])) + except AttributeError: + # Node referenced nonexistent submodule, don't need to + # worry about GCing anything + pass + + to_delete = [name for name, _ in self.named_modules() if name not in used] + + for name in to_delete: + self.delete_submodule(name) + + @property + def code(self) -> str: + """ + Return the Python code generated from the ``Graph`` underlying this + ``GraphModule``. + """ + if not hasattr(self, "_code"): + raise RuntimeError( + "Code has not been generated! 
Please report a bug to PyTorch" + ) + return self._code + + @compatibility(is_backward_compatible=True) + def recompile(self) -> PythonCode: + """ + Recompile this GraphModule from its ``graph`` attribute. This should be + called after editing the contained ``graph``, otherwise the generated + code of this ``GraphModule`` will be out of date. + """ + # Do not import anything inside recompile, it might slow down the + # function and cause perf regression. Import outside of the method instead. + if isinstance(self._graph._codegen, _PyTreeCodeGen): + self._in_spec = self._graph._codegen.pytree_info.in_spec + self._out_spec = self._graph._codegen.pytree_info.out_spec + + python_code = self._graph.python_code( + root_module="self", + record_func=fx_experimental_config.enrich_profiler_metadata, + ) + self._code = python_code.src + self._lineno_map = python_code._lineno_map + self._prologue_start = python_code._prologue_start + + cls = type(self) + co_fields = self._graph._co_fields if hasattr(self._graph, "_co_fields") else {} + + if fx_experimental_config.enrich_profiler_metadata: + # Generate metadata and register for profiler augmentation + node_metadata: dict[int, dict[str, Any]] = {} + for i, node in enumerate(self._graph.nodes): + node_metadata[i] = { + "name": node.name, + "op": node.op, + "target": str(node.target), + "stack_trace": node.meta.get("stack_trace", None), + } + + # Generate a content-addressed filename based on hash of code and metadata + # This ensures the same code+metadata always generates the same filename + hash_value = _metadata_hash(self._code, node_metadata) + file_stem = f"{FX_GRAPH_MODULE_FILE_PREFIX}_{hash_value}" + filename = f"{file_stem}.py" + + # Only include co_filename to use it directly as the cache key + co_fields = { + "co_filename": filename, + } + + # Store metadata in global in-memory registry + metadata = { + "lineno_map": python_code._lineno_map, + "prologue_start": python_code._prologue_start, + "node_metadata": 
node_metadata, + } + + # Register metadata in the global registry + from torch.fx.traceback import _register_fx_metadata + + _register_fx_metadata(filename, metadata) + + # Replace the placeholder in generated code with actual filename + # The double hash ## convention is used by post-processing to find the fx markers + self._code = self._code.replace( + "torch._C._profiler._RecordFunctionFast('## ENTER_GRAPH_PLACEHOLDER_KEY ##')", + f"torch._C._profiler._RecordFunctionFast('## {filename} ##')", + ) + + cls.forward = _forward_from_src(self._code, python_code.globals, co_fields) + + # Determine whether this class explicitly defines a __call__ implementation + # to wrap. If it does, save it in order to have wrapped_call invoke it. + # If it does not, wrapped_call can use a dynamic call to super() instead. + # In most cases, super().__call__ should be torch.nn.Module.__call__. + # We do not want to hold a reference to Module.__call__ here; doing so will + # bypass patching of torch.nn.Module.__call__ done while symbolic tracing. + cls_call = cls.__call__ if "__call__" in vars(cls) else None + + if "_wrapped_call" not in vars(cls): + cls._wrapped_call = _WrappedCall(cls, cls_call) # type: ignore[attr-defined] + + self._recompile_submodules() + + def call_wrapped(self, *args, **kwargs): + return self._wrapped_call(self, *args, **kwargs) + + cls.__call__ = call_wrapped # type: ignore[method-assign] + + return python_code + + def _recompile_submodules(self) -> list[tuple[str, PythonCode]]: + """ + Recompile all submodules of this graph module, returning their respective PythonCodes + in a similar format to named_children() + """ + results: list[tuple[str, PythonCode]] = [] + for name, mod in self.named_children(): + if isinstance(mod, GraphModule): + results.append((name, mod.recompile())) + return results + + # Passing Tracer as argument allows subclasses extending fx.GraphModule + # define their own Tracer (extending fx.Tracer). 
+ + def __reduce_package__(self, exporter: PackageExporter): + dict_without_graph = self.__dict__.copy() + dict_without_graph["_graphmodule_cls_name"] = self.__class__.__name__ + del dict_without_graph["_graph"] + + # Store node.meta["stack_trace"] so we can recover them after re-tracing during deserialization + node_meta_stack_trace = { + node.name: node.meta["stack_trace"] + for node in self.graph.nodes + if "stack_trace" in node.meta + } + dict_without_graph["_graphmodule_graph_node_meta_stack_trace"] = ( + node_meta_stack_trace + ) + + generated_module_name = f"fx-generated._{exporter.get_unique_id()}" + python_code = self.recompile() + import_block = _format_import_block(python_code.globals, exporter.importer) + module_code = import_block + self.code + exporter.save_source_string(generated_module_name, module_code) + return ( + reduce_package_graph_module, + (dict_without_graph, generated_module_name), + ) + + def __reduce__(self): + """ + Serialization of GraphModule. We serialize only the generated code, not + the underlying ``Graph``. This is because ``Graph`` does not have on-disk + backward-compatibility guarantees, whereas Python source code does. 
+ On the deserialization side, we symbolically trace through the generated + code to regenerate the underlying ``Graph`` + """ + dict_without_graph = self.__dict__.copy() + + python_code = self.recompile() + import_block = _format_import_block(python_code.globals, sys_importer) + del dict_without_graph["_graph"] + return (reduce_graph_module, (dict_without_graph, import_block)) + + def _deepcopy_init(self): + return GraphModule.__init__ + + # because __reduce__ is defined for serialization, + # we need to define deepcopy otherwise it will call __reduce__ + # and cause symbolic tracing to occur every time we try to copy the object + def __deepcopy__(self, memo): + res = type(self).__new__(type(self)) + memo[id(self)] = res + fake_mod = _CodeOnlyModule(copy.deepcopy(self.__dict__, memo)) + self._deepcopy_init()(res, fake_mod, fake_mod.__dict__["_graph"]) + # hooks are lost during `GraphModule.__init__`, so we need to copy over + # them explicitly, note right now we are only copying state_dict related + # hooks, to reduce bc-related issues, we can copy forward/backward related + # hooks in the future as well if needed + extra_preserved_attrs = [ + "_state_dict_hooks", + "_load_state_dict_pre_hooks", + "_load_state_dict_post_hooks", + "_replace_hooks", + "_create_node_hooks", + "_erase_node_hooks", + "_deepcopy_hooks", + ] + for attr in extra_preserved_attrs: + if attr in self.__dict__: + setattr(res, attr, copy.deepcopy(self.__dict__[attr], memo)) + res.meta = copy.deepcopy(getattr(self, "meta", {}), memo) + if _USER_PRESERVED_ATTRIBUTES_KEY in res.meta: + for attr_name, attr in res.meta[_USER_PRESERVED_ATTRIBUTES_KEY].items(): + setattr(res, attr_name, attr) + if hasattr(self, "_deepcopy_hooks"): + for hook in self._deepcopy_hooks: + hook(res) + return res + + def __copy__(self): + from ._lazy_graph_module import _make_graph_module + + res = _make_graph_module(self, self.graph) + res.meta = getattr(self, "meta", {}) + return res + + 
@compatibility(is_backward_compatible=False) + def print_readable( + self, + print_output=True, + include_stride=False, + include_device=False, + colored=False, + *, + # If `fast_sympy_print` is True then we use a sympy printer which is faster + # but may result in less-readable output. + fast_sympy_print: bool = False, + expanded_def: bool = False, + ): + """ + Return the Python code generated for current GraphModule and its children GraphModules + """ + ctx_mgr = contextlib.ExitStack() + with ctx_mgr: + if fast_sympy_print: + from torch._inductor.utils import sympy_str + + def fast_repr(expr: torch.types.PySymType) -> str: + return sympy_str(expr.node.expr) + + ctx_mgr.enter_context(_override_sym_repr(fast_repr)) + + r = _print_readable( + self, + self._get_name(), + print_output, + include_stride, + include_device, + colored, + expanded_def, + ) + return r + + def __str__(self) -> str: + orig_str = super().__str__() + print_readable_reminder = ( + "# To see more debug info, please use `graph_module.print_readable()`" + ) + return "\n".join([orig_str, self._code, print_readable_reminder]) + + def _replicate_for_data_parallel(self): + new_gm = self.__copy__() + new_gm._is_replica = True + return new_gm + + @contextlib.contextmanager + def _set_replace_hook(self, f): + """ + Takes a callable which will be called every time when we replace a node + to a new node, or change the node's name. Callable takes three arguments: + the old node we're changing, and NAME of the new node, followed by the + user node which consumes the old node to be replaced. + """ + assert callable(f), "Replace hook must be a callable." + self._register_replace_node_hook(f) + try: + yield + finally: + self._unregister_replace_node_hook(f) + + def _register_replace_node_hook(self, f): + """ + Takes a callable which will be called every time when we replace a node + to a new node, or change the node's name. 
Callable takes three arguments: + the old node we're changing, and NAME of the new node, followed by the + user node which consumes the old node to be replaced. + """ + assert callable(f), "create_node hook must be a callable." + self._replace_hooks.append(f) + + def _unregister_replace_node_hook(self, f): + """ + Takes a callable which was previously registered to be called every time when we replace a node. + This function will unregister that callable so it is no longer invoked on node replacement. + """ + assert callable(f), "create_node hook must be a callable." + self._replace_hooks.remove(f) + + def _register_create_node_hook(self, f): + """ + Takes a callable which will be called after we create a new node. The + callable takes the newly created node as input and returns None. + """ + assert callable(f), "create_node hook must be a callable." + self._create_node_hooks.append(f) + + def _unregister_create_node_hook(self, f): + """ + Takes a callable which was previously registered to be called after we create a node. + This function will unregister that callable so it is no longer invoked on node creation. + """ + assert callable(f), "create_node hook must be a callable." + self._create_node_hooks.remove(f) + + def _register_erase_node_hook(self, f): + """ + Takes a callable which will be called after we erase a node. The + callable takes the node that is being erased as input and returns None. + """ + assert callable(f), "erase_node hook must be a callable." + self._erase_node_hooks.append(f) + + def _unregister_erase_node_hook(self, f): + """ + Takes a callable which was previously registered to be called after we erase a node. + This function will unregister that callable so it is no longer invoked on node erasure. + """ + assert callable(f), "erase_node hook must be a callable." + self._erase_node_hooks.remove(f) + + def _register_deepcopy_hook(self, f): + """ + Takes a callable which will be called when we deepcopy this graph module. 
The + callable takes the resulting deepcopied graph module. + """ + assert callable(f), "deepcopy hook must be a callable." + self._deepcopy_hooks.append(f) + + def _unregister_deepcopy_hook(self, f): + """ + Takes a callable which was previously registered to be called after deepcopy. + This function will unregister that callable so it is no longer invoked on deepcopy. + """ + assert callable(f), "deepcopy hook must be a callable." + self._deepcopy_hooks.remove(f) + + +# workarounds for issues in __torch_function__ + +# WAR for __torch_function__ not handling tensor lists, +# fix is in https://github.com/pytorch/pytorch/pull/34725 +# orig_cat = torch.cat +# def patched_cat(*args, **kwargs): +# tensors = args[0] +# for t in tensors: +# if isinstance(t, Proxy): +# return t.__torch_function__(patched_cat, (), args, kwargs) +# return orig_cat(*args, **kwargs) +# patched_cat.__module__ = 'torch' +# patched_cat.__name__ = 'cat' +# torch.cat = patched_cat diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/immutable_collections.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/immutable_collections.py new file mode 100644 index 0000000000000000000000000000000000000000..6c6204d520bc66af3e6c161b0254b9d81012c287 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/immutable_collections.py @@ -0,0 +1,122 @@ +from collections.abc import Iterable +from typing import Any, NoReturn, TypeVar +from typing_extensions import Self + +from torch.utils._pytree import ( + _dict_flatten, + _dict_flatten_with_keys, + _dict_unflatten, + _list_flatten, + _list_flatten_with_keys, + _list_unflatten, + Context, + register_pytree_node, +) + +from ._compatibility import compatibility + + +__all__ = ["immutable_list", "immutable_dict"] + + +_help_mutation = """ +If you are attempting to modify the kwargs or args of a torch.fx.Node object, +instead create a new copy of it and assign the copy to the node: + + 
new_args = ... # copy and mutate args + node.args = new_args +""".strip() + + +_T = TypeVar("_T") +_KT = TypeVar("_KT") +_VT = TypeVar("_VT") + + +def _no_mutation(self: Any, *args: Any, **kwargs: Any) -> NoReturn: + raise TypeError( + f"{type(self).__name__!r} object does not support mutation. {_help_mutation}", + ) + + +@compatibility(is_backward_compatible=True) +class immutable_list(list[_T]): + """An immutable version of :class:`list`.""" + + __delitem__ = _no_mutation + __iadd__ = _no_mutation + __imul__ = _no_mutation + __setitem__ = _no_mutation + append = _no_mutation + clear = _no_mutation + extend = _no_mutation + insert = _no_mutation + pop = _no_mutation + remove = _no_mutation + reverse = _no_mutation + sort = _no_mutation + + def __hash__(self) -> int: # type: ignore[override] + return hash(tuple(self)) + + def __reduce__(self) -> tuple[type[Self], tuple[tuple[_T, ...]]]: + return (type(self), (tuple(self),)) + + +@compatibility(is_backward_compatible=True) +class immutable_dict(dict[_KT, _VT]): + """An immutable version of :class:`dict`.""" + + __delitem__ = _no_mutation + __ior__ = _no_mutation + __setitem__ = _no_mutation + clear = _no_mutation + pop = _no_mutation + popitem = _no_mutation + setdefault = _no_mutation + update = _no_mutation # type: ignore[assignment] + + def __hash__(self) -> int: # type: ignore[override] + return hash(frozenset(self.items())) + + def __reduce__(self) -> tuple[type[Self], tuple[tuple[tuple[_KT, _VT], ...]]]: + return (type(self), (tuple(self.items()),)) + + +# Register immutable collections for PyTree operations +def _immutable_list_flatten(d: immutable_list[_T]) -> tuple[list[_T], Context]: + return _list_flatten(d) + + +def _immutable_list_unflatten( + values: Iterable[_T], + context: Context, +) -> immutable_list[_T]: + return immutable_list(_list_unflatten(values, context)) + + +def _immutable_dict_flatten(d: immutable_dict[Any, _VT]) -> tuple[list[_VT], Context]: + return _dict_flatten(d) + + +def 
_immutable_dict_unflatten( + values: Iterable[_VT], + context: Context, +) -> immutable_dict[Any, _VT]: + return immutable_dict(_dict_unflatten(values, context)) + + +register_pytree_node( + immutable_list, + _immutable_list_flatten, + _immutable_list_unflatten, + serialized_type_name="torch.fx.immutable_collections.immutable_list", + flatten_with_keys_fn=_list_flatten_with_keys, +) +register_pytree_node( + immutable_dict, + _immutable_dict_flatten, + _immutable_dict_unflatten, + serialized_type_name="torch.fx.immutable_collections.immutable_dict", + flatten_with_keys_fn=_dict_flatten_with_keys, +) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/interpreter.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/interpreter.py new file mode 100644 index 0000000000000000000000000000000000000000..5b40e8a66147f410e03e349560571a3da0859f19 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/interpreter.py @@ -0,0 +1,656 @@ +# mypy: allow-untyped-defs +import inspect +import logging +from contextlib import contextmanager +from typing import Any, Optional, TYPE_CHECKING, Union + +import torch +import torch.fx.traceback as fx_traceback +from torch._logging import LazyString, trace_structured +from torch.hub import tqdm + +from . import config +from ._compatibility import compatibility +from ._lazy_graph_module import _make_graph_module +from ._symbolic_trace import Tracer +from .graph import Graph +from .graph_module import GraphModule +from .node import Argument, map_aggregate, map_arg, Node, Target +from .proxy import Proxy + + +if TYPE_CHECKING: + from collections.abc import Iterator + +log = logging.getLogger(__name__) + +__all__ = ["Interpreter", "Transformer"] + + +def _format_fx_node(n): + """ + Format a torch.fx.Node into a human-readable string for debug logging. + + Args: + n (torch.fx.Node): The FX node being executed. 
+ + Returns: + str: A formatted string describing the node operation, including its + name, target, positional arguments, and keyword arguments. + """ + module_prefix = getattr(n.target, "__module__", "") + module_prefix = f"{module_prefix}." if module_prefix else "" + + # Handle positional and keyword arguments + args = ", ".join(map(str, n.args)) + kwargs = ", ".join(f"{k}={v}" for k, v in n.kwargs.items()) + joined = ", ".join(filter(None, [args, kwargs])) + + return ( + f"{n.name} = {module_prefix}{getattr(n.target, '__name__', n.target)}({joined})" + ) + + +@compatibility(is_backward_compatible=True) +class Interpreter: + """ + An Interpreter executes an FX graph Node-by-Node. This pattern + can be useful for many things, including writing code + transformations as well as analysis passes. + + Methods in the Interpreter class can be overridden to customize + the behavior of execution. The map of overridable methods + in terms of call hierarchy:: + + run() + +-- run_node + +-- placeholder() + +-- get_attr() + +-- call_function() + +-- call_method() + +-- call_module() + +-- output() + + Example: + + Suppose we want to swap all instances of ``torch.neg`` with + ``torch.sigmoid`` and vice versa (including their ``Tensor`` + method equivalents). 
We could subclass Interpreter like so:: + + class NegSigmSwapInterpreter(Interpreter): + def call_function( + self, target: Target, args: Tuple, kwargs: Dict + ) -> Any: + if target is torch.sigmoid: + return torch.neg(*args, **kwargs) + return super().call_function(target, args, kwargs) + + def call_method(self, target: Target, args: Tuple, kwargs: Dict) -> Any: + if target == "neg": + call_self, *args_tail = args + return call_self.sigmoid(*args_tail, **kwargs) + return super().call_method(target, args, kwargs) + + + def fn(x): + return torch.sigmoid(x).neg() + + + gm = torch.fx.symbolic_trace(fn) + input = torch.randn(3, 4) + result = NegSigmSwapInterpreter(gm).run(input) + torch.testing.assert_close(result, torch.neg(input).sigmoid()) + + Args: + module (torch.nn.Module): The module to be executed + garbage_collect_values (bool): Whether to delete values after their last + use within the Module's execution. This ensures optimal memory usage during + execution. This can be disabled to, for example, examine all of the intermediate + values in the execution by looking at the ``Interpreter.env`` attribute. + graph (Optional[Graph]): If passed, the interpreter will execute this + graph instead of `module.graph`, using the provided `module` + argument to satisfy any requests for state. + """ + + @compatibility(is_backward_compatible=True) + def __init__( + self, + module: torch.nn.Module, + garbage_collect_values: bool = True, + graph: Optional[Graph] = None, + ): + self.module = module + self.submodules = dict(self.module.named_modules()) + if graph is not None: + self.graph = graph + else: + self.graph = self.module.graph # type: ignore[assignment] + self.env: dict[Node, Any] = {} + self.name = "Interpreter" + self.garbage_collect_values = garbage_collect_values + self.extra_traceback = True + + if self.garbage_collect_values: + # Run through reverse nodes and record the first instance of a use + # of a given node. 
This represents the *last* use of the node in the + # execution order of the program, which we will use to free unused + # values + node_to_last_use: dict[Node, Node] = {} + self.user_to_last_uses: dict[Node, list[Node]] = {} + + def register_last_uses(n: Node, user: Node): + if n not in node_to_last_use: + node_to_last_use[n] = user + self.user_to_last_uses.setdefault(user, []).append(n) + + for node in reversed(self.graph.nodes): + for n in node._input_nodes: + register_last_uses(n, node) + + @compatibility(is_backward_compatible=True) + def run( + self, + *args, + initial_env: Optional[dict[Node, Any]] = None, + enable_io_processing: bool = True, + ) -> Any: + """ + Run `module` via interpretation and return the result. + + Args: + *args: The arguments to the Module to run, in positional order + initial_env (Optional[Dict[Node, Any]]): An optional starting environment for execution. + This is a dict mapping `Node` to any value. This can be used, for example, to + pre-populate results for certain `Nodes` so as to do only partial evaluation within + the interpreter. + enable_io_processing (bool): If true, we process the inputs and outputs with graph's process_inputs and + process_outputs function first before using them. + + Returns: + Any: The value returned from executing the Module + """ + self.env = initial_env if initial_env is not None else {} + + # Positional function args are consumed left-to-right by + # `placeholder` nodes. Use an iterator to keep track of + # position and extract those values. + if enable_io_processing: + args = self.graph.process_inputs(*args) + self.args_iter: Iterator[Any] = iter(args) + pbar = tqdm( + total=len(self.graph.nodes), + desc=f"{self.name}: {str(list(self.graph.nodes)) if config.verbose_progress else ''}", + initial=0, + position=0, + leave=True, + disable=config.disable_progress, + delay=0, + ) + + for node in self.graph.nodes: + pbar.update(1) + if node in self.env: + # Short circuit if we have this value. 
This could + # be used, for example, for partial evaluation + # where the caller has pre-populated `env` with + # values for a subset of the program. + continue + + try: + self.env[node] = self.run_node(node) + except Exception as e: + if self.extra_traceback: + msg = f"While executing {node.format_node()}" + msg = f"{e.args[0]}\n\n{msg}" if e.args else str(msg) + msg += f"\nOriginal traceback:\n{node.stack_trace}" + if ( + isinstance(self.module, GraphModule) + and self.module.graph is not None + and isinstance(self.module.graph, torch.fx.Graph) + ): + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "fx_interpreter_error", + "encoding": "string", + }, + payload_fn=lambda: ( + f"{msg}\nGraphModule: " + f"{self.module.print_readable(print_output=False, include_stride=True)}" # type: ignore[operator] + ), + ) + + msg += "\nUse tlparse to see full graph. " + msg += "(https://github.com/pytorch/tlparse?tab=readme-ov-file#tlparse-parse-structured-pt2-logs)" + e.args = (msg,) + e.args[1:] + if isinstance(e, KeyError): + raise RuntimeError(*e.args) from e + raise + + if self.garbage_collect_values: + for to_delete in self.user_to_last_uses.get(node, []): + del self.env[to_delete] + + if node.op == "output": + output_val = self.env[node] + return ( + self.graph.process_outputs(output_val) + if enable_io_processing + else output_val + ) + + @compatibility(is_backward_compatible=True) + def boxed_run(self, args_list): + """ + Run `module` via interpretation and return the result. This uses the "boxed" + calling convention, where you pass a list of arguments, which will be cleared + by the interpreter. This ensures that input tensors are promptly deallocated. 
+ """ + # Collect placeholder nodes first + placeholder_nodes = [n for n in self.graph.nodes if n.op == "placeholder"] + + # Check argument count + if len(args_list) != len(placeholder_nodes): + detail = ( + "extra arguments" + if len(args_list) > len(placeholder_nodes) + else "missing arguments" + ) + raise RuntimeError( + f"Interpreter.boxed_run expected {len(placeholder_nodes)} arguments for placeholders " + f"but received {len(args_list)} ({detail})" + ) + + # Assign arguments to placeholders + env = dict(zip(placeholder_nodes, args_list)) + args_list.clear() + return self.run(initial_env=env) + + @contextmanager + def _set_current_node(self, node): + with fx_traceback.set_current_meta( + node, f"Interpreter_{self.__class__.__name__}" + ): + yield + + @compatibility(is_backward_compatible=True) + def run_node(self, n: Node) -> Any: + """ + Run a specific node ``n`` and return the result. + Calls into placeholder, get_attr, call_function, + call_method, call_module, or output depending + on ``node.op`` + + Args: + n (Node): The Node to execute + + Returns: + Any: The result of executing ``n`` + """ + log.debug("run_node %s", LazyString(lambda: _format_fx_node(n))) + with self._set_current_node(n): + args, kwargs = self.fetch_args_kwargs_from_env(n) + assert isinstance(args, tuple) + assert isinstance(kwargs, dict) + return getattr(self, n.op)(n.target, args, kwargs) + + # Main Node running APIs + @compatibility(is_backward_compatible=True) + def placeholder( + self, target: "Target", args: tuple[Argument, ...], kwargs: dict[str, Any] + ) -> Any: + """ + Execute a ``placeholder`` node. Note that this is stateful: + ``Interpreter`` maintains an internal iterator over + arguments passed to ``run`` and this method returns + next() on that iterator. + + Args: + target (Target): The call target for this node. 
See + `Node `__ for + details on semantics + args (Tuple): Tuple of positional args for this invocation + kwargs (Dict): Dict of keyword arguments for this invocation + + Returns: + Any: The argument value that was retrieved. + """ + assert isinstance(target, str) + if target.startswith("*"): + # For a starred parameter e.g. `*args`, retrieve all + # remaining values from the args list. + return list(self.args_iter) + else: + try: + return next(self.args_iter) + except StopIteration as si: + if len(args) > 0: + return args[0] + else: + raise RuntimeError( + f"Expected positional argument for parameter {target}, but one was not passed in!" + ) from si + + @compatibility(is_backward_compatible=True) + def get_attr( + self, target: "Target", args: tuple[Argument, ...], kwargs: dict[str, Any] + ) -> Any: + """ + Execute a ``get_attr`` node. Will retrieve an attribute + value from the ``Module`` hierarchy of ``self.module``. + + Args: + target (Target): The call target for this node. See + `Node `__ for + details on semantics + args (Tuple): Tuple of positional args for this invocation + kwargs (Dict): Dict of keyword arguments for this invocation + + Return: + Any: The value of the attribute that was retrieved + """ + assert isinstance(target, str) + return self.fetch_attr(target) + + @compatibility(is_backward_compatible=True) + def call_function( + self, target: "Target", args: tuple[Argument, ...], kwargs: dict[str, Any] + ) -> Any: + """ + Execute a ``call_function`` node and return the result. + + Args: + target (Target): The call target for this node. 
See + `Node `__ for + details on semantics + args (Tuple): Tuple of positional args for this invocation + kwargs (Dict): Dict of keyword arguments for this invocation + + Return + Any: The value returned by the function invocation + """ + assert not isinstance(target, str) + + # Execute the function and return the result + return target(*args, **kwargs) + + @compatibility(is_backward_compatible=True) + def call_method( + self, target: "Target", args: tuple[Argument, ...], kwargs: dict[str, Any] + ) -> Any: + """ + Execute a ``call_method`` node and return the result. + + Args: + target (Target): The call target for this node. See + `Node `__ for + details on semantics + args (Tuple): Tuple of positional args for this invocation + kwargs (Dict): Dict of keyword arguments for this invocation + + Return + Any: The value returned by the method invocation + """ + # args[0] is the `self` object for this method call + self_obj, *args_tail = args + + # Execute the method and return the result + assert isinstance(target, str) + return getattr(self_obj, target)(*args_tail, **kwargs) + + @compatibility(is_backward_compatible=True) + def call_module( + self, target: "Target", args: tuple[Argument, ...], kwargs: dict[str, Any] + ) -> Any: + """ + Execute a ``call_module`` node and return the result. + + Args: + target (Target): The call target for this node. 
See + `Node `__ for + details on semantics + args (Tuple): Tuple of positional args for this invocation + kwargs (Dict): Dict of keyword arguments for this invocation + + Return + Any: The value returned by the module invocation + """ + # Retrieve executed args and kwargs values from the environment + + # Execute the method and return the result + assert isinstance(target, str) + submod = self.fetch_attr(target) + + return submod(*args, **kwargs) + + @compatibility(is_backward_compatible=True) + def output( + self, target: "Target", args: tuple[Argument, ...], kwargs: dict[str, Any] + ) -> Any: + """ + Execute an ``output`` node. This really just retrieves + the value referenced by the ``output`` node and returns it. + + Args: + target (Target): The call target for this node. See + `Node `__ for + details on semantics + args (Tuple): Tuple of positional args for this invocation + kwargs (Dict): Dict of keyword arguments for this invocation + + Return: + Any: The return value referenced by the output node + """ + return args[0] + + # Helper methods + @compatibility(is_backward_compatible=True) + def fetch_attr(self, target: str): + """ + Fetch an attribute from the ``Module`` hierarchy of ``self.module``. + + Args: + target (str): The fully-qualified name of the attribute to fetch + + Return: + Any: The value of the attribute. + """ + target_atoms = target.split(".") + attr_itr = self.module + for i, atom in enumerate(target_atoms): + if not hasattr(attr_itr, atom): + raise RuntimeError( + f"Node referenced nonexistent target {'.'.join(target_atoms[: i + 1])}" + ) + attr_itr = getattr(attr_itr, atom) + return attr_itr + + @compatibility(is_backward_compatible=True) + def fetch_args_kwargs_from_env(self, n: Node) -> tuple[tuple, dict]: + """ + Fetch the concrete values of ``args`` and ``kwargs`` of node ``n`` + from the current execution environment. + + Args: + n (Node): The node for which ``args`` and ``kwargs`` should be fetched. 
+ + Return: + Tuple[Tuple, Dict]: ``args`` and ``kwargs`` with concrete values for ``n``. + """ + args = self.map_nodes_to_values(n.args, n) + assert isinstance(args, tuple) + kwargs = self.map_nodes_to_values(n.kwargs, n) + assert isinstance(kwargs, dict) + return args, kwargs + + @compatibility(is_backward_compatible=True) + def map_nodes_to_values(self, args: Argument, n: Node) -> Argument: + """ + Recursively descend through ``args`` and look up the concrete value + for each ``Node`` in the current execution environment. + + Args: + args (Argument): Data structure within which to look up concrete values + + n (Node): Node to which ``args`` belongs. This is only used for error reporting. + """ + + def load_arg(n_arg: Node) -> Any: + if n_arg not in self.env: + raise RuntimeError( + f"Node {n} referenced nonexistent value {n_arg}! Run Graph.lint() " + f"to diagnose such issues" + ) + return self.env[n_arg] + + return map_arg(args, load_arg) + + +@compatibility(is_backward_compatible=True) +class Transformer(Interpreter): + """ + ``Transformer`` is a special type of interpreter that produces a + new ``Module``. It exposes a ``transform()`` method that returns + the transformed ``Module``. ``Transformer`` does not require + arguments to run, as ``Interpreter`` does. ``Transformer`` works + entirely symbolically. + + Example: + + Suppose we want to swap all instances of ``torch.neg`` with + ``torch.sigmoid`` and vice versa (including their ``Tensor`` + method equivalents). 
We could subclass ``Transformer`` like so:: + + class NegSigmSwapXformer(Transformer): + def call_function( + self, + target: "Target", + args: Tuple[Argument, ...], + kwargs: Dict[str, Any], + ) -> Any: + if target is torch.sigmoid: + return torch.neg(*args, **kwargs) + return super().call_function(target, args, kwargs) + + def call_method( + self, + target: "Target", + args: Tuple[Argument, ...], + kwargs: Dict[str, Any], + ) -> Any: + if target == "neg": + call_self, *args_tail = args + return call_self.sigmoid(*args_tail, **kwargs) + return super().call_method(target, args, kwargs) + + + def fn(x): + return torch.sigmoid(x).neg() + + + gm = torch.fx.symbolic_trace(fn) + + transformed: torch.nn.Module = NegSigmSwapXformer(gm).transform() + input = torch.randn(3, 4) + torch.testing.assert_close(transformed(input), torch.neg(input).sigmoid()) + + Args: + module (GraphModule): The ``Module`` to be transformed. + """ + + @compatibility(is_backward_compatible=True) + def __init__(self, module): + super().__init__(module) + self.new_graph = Graph() + self.new_graph.set_codegen(module.graph._codegen) + + class TransformerTracer(Tracer): + def __init__(self, graph: Graph): + super().__init__() + self.graph = graph + self.tensor_attrs: dict[torch.Tensor, str] = {} # type: ignore[assignment] + + def is_leaf_module(self, _, __) -> bool: + return True + + self.tracer = TransformerTracer(self.new_graph) + self.tracer.root = module + + @compatibility(is_backward_compatible=True) + def placeholder( + self, target: "Target", args: tuple[Argument, ...], kwargs: dict[str, Any] + ) -> Proxy: + """ + Execute a ``placeholder`` node. In ``Transformer``, this is + overridden to insert a new ``placeholder`` into the output + graph. + + Args: + target (Target): The call target for this node. 
See + `Node `__ for + details on semantics + args (Tuple): Tuple of positional args for this invocation + kwargs (Dict): Dict of keyword arguments for this invocation + """ + assert isinstance(target, str) + default_value = next(iter(args)) if args else inspect.Signature.empty + return Proxy( + self.new_graph.placeholder(target, default_value=default_value), self.tracer + ) + + @compatibility(is_backward_compatible=True) + def get_attr( + self, target: "Target", args: tuple[Argument, ...], kwargs: dict[str, Any] + ) -> Proxy: + """ + Execute a ``get_attr`` node. In ``Transformer``, this is + overridden to insert a new ``get_attr`` node into the output + graph. + + Args: + target (Target): The call target for this node. See + `Node `__ for + details on semantics + args (Tuple): Tuple of positional args for this invocation + kwargs (Dict): Dict of keyword arguments for this invocation + """ + assert isinstance(target, str) + return self.tracer.create_proxy("get_attr", target, args, kwargs) + + @compatibility(is_backward_compatible=True) + def call_module( + self, target: "Target", args: tuple[Argument, ...], kwargs: dict[str, Any] + ) -> Any: + # Override so that the leaf module policy from `self.tracer` is respected. + assert isinstance(target, str) + submod = self.fetch_attr(target) + return self.tracer.call_module(submod, submod.forward, args, kwargs) + + @compatibility(is_backward_compatible=True) + def call_function( + self, target: "Target", args: tuple[Argument, ...], kwargs: dict[str, Any] + ) -> Any: + # Override so that functions that were wrapped are still wrapped. + return self.tracer.create_proxy("call_function", target, args, kwargs) + + @compatibility(is_backward_compatible=True) + def transform(self) -> GraphModule: + """ + Transform ``self.module`` and return the transformed + ``GraphModule``. 
+ """ + with fx_traceback.preserve_node_meta(): + result = super().run(enable_io_processing=False) + if result is not None: + + def strip_proxy(a: Union[Argument, Proxy]) -> Any: + return a.node if isinstance(a, Proxy) else a + + new_output_node = self.new_graph.output(map_aggregate(result, strip_proxy)) + # also preserve the metadata from the old output node, if it exists + old_output_node = list(self.graph.nodes)[-1] + assert old_output_node.op == "output" + for k, v in old_output_node.meta.items(): + new_output_node.meta[k] = v + + return _make_graph_module(self.module, self.new_graph) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/node.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/node.py new file mode 100644 index 0000000000000000000000000000000000000000..4af5ed9d82fe202265e4c3dc17665c91439197de --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/node.py @@ -0,0 +1,893 @@ +# Nodes represent a definition of a value in our graph of operators. 
+import builtins +import inspect +import logging +import operator +import types +import typing +from collections.abc import Callable, Iterable, Mapping, Sequence +from typing import Any, Optional, TYPE_CHECKING, TypeAlias, Union +from typing_extensions import ParamSpec, TypeVar + +import torch +from torch._C import _fx_map_aggregate, _fx_map_arg, _NodeBase +from torch.fx.operator_schemas import ( + ArgsKwargsPair, + normalize_function, + normalize_module, +) +from torch.utils._dtype_abbrs import dtype_abbrs + +from .._ops import ops as _ops +from ._compatibility import compatibility + + +if TYPE_CHECKING: + from .graph import Graph + +__all__ = ["Node", "map_arg", "map_aggregate", "has_side_effect"] + +log = logging.getLogger(__name__) + +BaseArgumentTypes = Union[ + str, + int, + float, + bool, + complex, + torch.dtype, + torch.Tensor, + torch.device, + torch.memory_format, + torch.layout, + torch._ops.OpOverload, + torch.SymInt, + torch.SymBool, + torch.SymFloat, +] +base_types = typing.get_args(BaseArgumentTypes) + +Target: TypeAlias = Union[Callable[..., Any], str] + +Argument = Optional[ + Union[ + tuple["Argument", ...], + Sequence["Argument"], + Mapping[str, "Argument"], + slice, # Slice[Argument, Argument, Argument], but slice is not a templated type in typing + range, + "Node", + BaseArgumentTypes, + ] +] +# pyrefly: ignore [invalid-annotation] +ArgumentT = TypeVar("ArgumentT", bound=Argument) +_P = ParamSpec("_P") +_R = TypeVar("_R") + +_legal_ops = dict.fromkeys( + [ + "placeholder", + "call_method", + "call_module", + "call_function", + "get_attr", + "output", + "root", + ] +) + +# Dynamo is unable to trace global set[Callable].__contains__. +# See https://github.com/pytorch/pytorch/issues/145761. Since we only have +# a handful of ops so switch to list of callables. 
+_side_effectful_need_to_be_preserved_pre_dispatch: list[Callable[..., Any]] = [ + torch._C._set_grad_enabled, + torch.amp._enter_autocast, + torch.amp._exit_autocast, +] + +# TODO: Either refactor this into 2 functions 1 dce for functional graphs and 1 dce for all graphs, +# or add logic to correctly mark all inplace ops as side effectful. +# +# NOTE: For new operators, please do not add to this set! +# Instead, consider using the effects system via +# torch.library._register_effectful_op() for operators. +# +# This _side_effectful_functions set is only for: +# - Legacy functions that aren't operators (e.g., profiler ops, asserts) +# - Things that cannot be marked via the normal effects system +_side_effectful_functions: set[Callable[..., Any]] = { + torch._assert, + torch._assert_async, + _ops.aten._assert_async.msg, + _ops.aten._assert_scalar.default, + _ops.aten._assert_tensor_metadata.default, + _ops.aten.sym_constrain_range.default, + _ops.aten.sym_constrain_range_for_size.default, + _ops.profiler._record_function_enter, + _ops.profiler._record_function_enter_new, + _ops.profiler._record_function_exit, + _ops.inductor.accumulate_grad_.default, + operator.setitem, + *_side_effectful_need_to_be_preserved_pre_dispatch, +} + +if hasattr(_ops.inductor, "resize_storage_bytes_"): + _side_effectful_functions.add(_ops.inductor.resize_storage_bytes_.default) + + +@compatibility(is_backward_compatible=False) +def has_side_effect(fn: Callable[_P, _R]) -> Callable[_P, _R]: + """ + Registers a function to not be dead code eliminated by + fx.graph.eliminate_dead_code + + NOTE: For new operators, please do not add to this set! + Instead, consider using the effects system via + torch.library._register_effectful_op() for operators. 
+ + This _side_effectful_functions set is only for: + - Legacy functions that aren't operators (e.g., profiler ops, asserts) + - Things that cannot be marked via the normal effects system + """ + _side_effectful_functions.add(fn) + return fn + + +# this is fixed on master, WAR for 1.5 +def _find_module_of_method(orig_method: Callable[..., Any]) -> str: + name = orig_method.__name__ + module = orig_method.__module__ + if module is not None: + return module + for guess in [torch, torch.nn.functional]: + if getattr(guess, name, None) is orig_method: + return guess.__name__ + raise RuntimeError(f"cannot find module for {orig_method}") + + +# Borrowed from CPython typing module +# https://github.com/python/cpython/blob/f90dc36c15d7fee0efaf6d39e97be0bdf2683e93/Lib/typing.py#L156 +def _type_repr(obj: object) -> str: + """Return the repr() of an object, special-casing types (internal helper). + If obj is a type, we return a shorter version than the default + type.__repr__, based on the module and qualified name, which is + typically enough to uniquely identify a type. For everything + else, we fall back on repr(obj). + """ + # Extension: If we don't ignore GenericAlias then `list[int]` will print + # simply "list". + if isinstance(obj, type) and not isinstance(obj, types.GenericAlias): + if obj.__module__ == "builtins": + return obj.__qualname__ + return f"{obj.__module__}.{obj.__qualname__}" + if obj is ...: + return "..." 
+ if isinstance(obj, types.FunctionType): + return obj.__name__ + return repr(obj) + + +def _get_qualified_name(func: Callable[..., Any]) -> str: + # things like getattr just appear in builtins + if getattr(builtins, func.__name__, None) is func: + return func.__name__ + # torch.Tensor.{fn} + if ( + isinstance(func, (types.MethodDescriptorType, types.WrapperDescriptorType)) + and func is getattr(torch.Tensor, func.__name__, None) + ) or ( + func.__module__ == torch._tensor.__name__ + and func.__qualname__ == f"Tensor.{func.__name__}" + ): + return f"torch.Tensor.{func.__name__}" + name = func.__name__ + if name == "": + # For lambdas, try to get their defining name in the module + try: + name = inspect.getsource(func).split("=")[0].strip() + except Exception as e: + raise RuntimeError("Unable to represent lambda") from e + module = _find_module_of_method(func) + module = module.replace( + "torch._ops", "torch.ops" + ) # WAR for bug in how torch.ops assigns module + # Fixup segment_reduce mismatch + if module == "torch" and name == "segment_reduce": + name = "_" + name + if module == "torch.nn.functional" and name in ("_ScalingType", "_SwizzleType"): + name = name.removeprefix("_") + return f"{module}.{name}" + + +def _format_arg(arg: object, max_list_len: float = float("inf")) -> str: + if hasattr(arg, "_custom_fx_repr_fn"): + return arg._custom_fx_repr_fn() + elif isinstance(arg, list): + items = ", ".join( + _format_arg(a) for idx, a in enumerate(arg) if idx < max_list_len + ) + maybe_len = ( + "" if len(arg) < max_list_len + 1 else f", ...[total_len={len(arg)}]" + ) + return f"[{items}{maybe_len}]" + elif isinstance(arg, tuple): + items = ", ".join( + _format_arg(a) for idx, a in enumerate(arg) if idx < max_list_len + ) + maybe_len = ( + "" if len(arg) < max_list_len + 1 else f", ...[total_len={len(arg)}]" + ) + maybe_comma = "," if len(arg) == 1 else "" + return f"({items}{maybe_comma}{maybe_len})" + elif isinstance(arg, dict): + items_str = ", ".join(f"{k}: 
{_format_arg(v)}" for k, v in arg.items()) + return f"{{{items_str}}}" + + if isinstance(arg, Node): + return "%" + str(arg) + else: + return str(arg) + + +@compatibility(is_backward_compatible=True) +class Node(_NodeBase): + """ + ``Node`` is the data structure that represents individual operations within + a ``Graph``. For the most part, Nodes represent callsites to various entities, + such as operators, methods, and Modules (some exceptions include nodes that + specify function inputs and outputs). Each ``Node`` has a function specified + by its ``op`` property. The ``Node`` semantics for each value of ``op`` are as follows: + + - ``placeholder`` represents a function input. The ``name`` attribute specifies the name this value will take on. + ``target`` is similarly the name of the argument. ``args`` holds either: 1) nothing, or 2) a single argument + denoting the default parameter of the function input. ``kwargs`` is don't-care. Placeholders correspond to + the function parameters (e.g. ``x``) in the graph printout. + - ``get_attr`` retrieves a parameter from the module hierarchy. ``name`` is similarly the name the result of the + fetch is assigned to. ``target`` is the fully-qualified name of the parameter's position in the module hierarchy. + ``args`` and ``kwargs`` are don't-care + - ``call_function`` applies a free function to some values. ``name`` is similarly the name of the value to assign + to. ``target`` is the function to be applied. ``args`` and ``kwargs`` represent the arguments to the function, + following the Python calling convention + - ``call_module`` applies a module in the module hierarchy's ``forward()`` method to given arguments. ``name`` is + as previous. ``target`` is the fully-qualified name of the module in the module hierarchy to call. + ``args`` and ``kwargs`` represent the arguments to invoke the module on, *excluding the self argument*. + - ``call_method`` calls a method on a value. ``name`` is as similar. 
``target`` is the string name of the method + to apply to the ``self`` argument. ``args`` and ``kwargs`` represent the arguments to invoke the module on, + *including the self argument* + - ``output`` contains the output of the traced function in its ``args[0]`` attribute. This corresponds to the "return" statement + in the Graph printout. + """ + + _args: tuple["Argument", ...] + _kwargs: dict[str, "Argument"] + graph: "Graph" + # unique name of value being created + name: str + # the kind of operation = placeholder|call_method|call_module|call_function|get_attr + op: str + # for method/module/function, the name of the method/module/function/attr + # being invoked, e.g add, layer1, or torch.add + target: "Target" + # All `Node`-valued inputs. Key is the Node, value is don't-care. + # The public API for this is `all_input_nodes`, this private attribute + # should not be accessed directly. + _input_nodes: dict["Node", None] + # All of the nodes that use the value produced by this Node + # Note one user may correspond to several uses, e.g. the node for ``x + x`` + # would appear once here, but represents two uses. + # Is a dict to act as an "ordered set". Keys are significant, value dont-care + users: dict["Node", None] + # Type expression representing the output value of this node. + # This should contain the same class of Type objects that would appear + # as type annotations for function inputs/outputs. + # + # For placeholder nodes, this value will be used to type-annotate the + # generated function parameters. + # For the return node, this value will be used to type-annotate the + # generated function return type. (Note this is a special case. ``return`` + # does not produce a value, it's more of a notation. Thus, this value + # describes the type of args[0] in the ``return`` node. 
+ type: Optional[Any] + _sort_key: Any + # If set, use this fn to print this node + _repr_fn: Optional[Callable[["Node"], str]] + # Dictionary to store metadata passes need to do their + # transformations. This metadata is preserved across node copies + meta: dict[str, Any] + + @compatibility(is_backward_compatible=True) + def __init__( + self, + graph: "Graph", + name: str, + op: str, + target: "Target", + args: tuple["Argument", ...], + kwargs: dict[str, "Argument"], + return_type: Optional[Any] = None, + ) -> None: + """ + Instantiate an instance of ``Node``. Note: most often, you want to use the + Graph APIs, i.e. ``Graph.call_module``, ``Graph.call_method``, etc. rather + than instantiating a ``Node`` directly. + + Args: + graph (Graph): The ``Graph`` to which this ``Node`` should belong. + + name (str): The name to which the output of this ``Node`` should be assigned + + op (str): The opcode for this ``Node``. Can be one of 'placeholder', + 'call_method', 'call_module', 'call_function', 'get_attr', + 'output' + + target ('Target'): The target this op should call. See the broader + ``Node`` docstring for more details. + + args (Tuple['Argument']): The args to be passed to ``target`` + + kwargs (Dict[str, 'Argument']): The kwargs to be passed to ``target`` + + return_type (Optional[Any]): The python type expression representing the + type of the output of this node. This field can be used for + annotation of values in the generated code or for other types + of analyses. 
+ """ + if op == "call_function": + if not callable(target): + raise ValueError( + f"Node [graph = {graph}, name = '{name}'] target {target} has type {torch.typename(target)} " + "but a Callable is expected" + ) + else: + assert op in _legal_ops + if not isinstance(target, str): + raise ValueError( + f"Node [graph = {graph}, name = '{name}'] target {target} has type {torch.typename(target)} " + "but a str is expected" + ) + super().__init__(graph, name, op, target, return_type) + self._update_args_kwargs(args, kwargs) + + def __getstate__(self) -> dict[str, Any]: + return { + **self.__dict__, + "graph": self.graph, + "name": self.name, + "op": self.op, + "target": self.target, + "type": self.target, + "_sort_key": self._sort_key, + "_args": self._args, + "_kwargs": self._kwargs, + "_erased": self._erased, + "_prev": self._prev, + "_next": self._next, + "_input_nodes": self._input_nodes, + "users": self.users, + "_repr_fn": self._repr_fn, + "meta": self.meta, + } + + def __setstate__(self, state: dict[str, Any]) -> None: + for k, v in state.items(): + setattr(self, k, v) + + @property + def next(self) -> "Node": + """ + Returns the next ``Node`` in the linked list of Nodes. + + Returns: + + The next ``Node`` in the linked list of Nodes. + """ + return self._next + + @property + def prev(self) -> "Node": + """ + Returns the previous ``Node`` in the linked list of Nodes. + + Returns: + + The previous ``Node`` in the linked list of Nodes. + """ + return self._prev + + @compatibility(is_backward_compatible=True) + def prepend(self, x: "Node") -> None: + """ + Insert x before this node in the list of nodes in the graph. Example:: + + Before: p -> self + bx -> x -> ax + After: p -> x -> self + bx -> ax + + Args: + x (Node): The node to put before this node. Must be a member of the same graph. 
+ """ + # pyrefly: ignore [missing-attribute] + self._prepend(x) + + @compatibility(is_backward_compatible=True) + def append(self, x: "Node") -> None: + """ + Insert ``x`` after this node in the list of nodes in the graph. + Equivalent to ``self.next.prepend(x)`` + + Args: + x (Node): The node to put after this node. Must be a member of the same graph. + """ + # pyrefly: ignore [missing-attribute] + self._next._prepend(x) + + @property + def args(self) -> tuple[Argument, ...]: + """ + The tuple of arguments to this ``Node``. The interpretation of arguments + depends on the node's opcode. See the :class:`Node` docstring for more + information. + + Assignment to this property is allowed. All accounting of uses and users + is updated automatically on assignment. + """ + return self._args + + @args.setter + def args(self, a: tuple[Argument, ...]) -> None: + """ + Set the tuple of arguments to this Node. The interpretation of arguments + depends on the node's opcode. See the ``fx.Graph`` docstring for more + information. + """ + # DO NOT CALL `_update_args_kwargs` directly. The correct way to + # set `args` is via direct assignment, i.e. `node.args = new_args` + self._update_args_kwargs(a, self._kwargs) + + @property + def kwargs(self) -> dict[str, Argument]: + """ + The dict of keyword arguments to this ``Node``. The interpretation of arguments + depends on the node's opcode. See the :class:`Node` docstring for more + information. + + Assignment to this property is allowed. All accounting of uses and users + is updated automatically on assignment. + """ + return self._kwargs + + @kwargs.setter + def kwargs(self, k: dict[str, Argument]) -> None: + """ + Set the dict of kwargs to this Node. The interpretation of arguments + depends on the node's opcode. See the ``fx.Graph`` docstring for more + information. + """ + # DO NOT CALL `_update_args_kwargs` directly. The correct way to + # set `args` is via direct assignment, i.e. 
`node.kwargs = new_kwargs` + self._update_args_kwargs(self._args, k) + + @property + def all_input_nodes(self) -> list["Node"]: + """ + Return all Nodes that are inputs to this Node. This is equivalent to + iterating over ``args`` and ``kwargs`` and only collecting the values that + are Nodes. + + Returns: + + List of ``Nodes`` that appear in the ``args`` and ``kwargs`` of this + ``Node``, in that order. + """ + return list(self._input_nodes.keys()) + + @compatibility(is_backward_compatible=True) + def update_arg(self, idx: int, arg: Argument) -> None: + """ + Update an existing positional argument to contain the new value + ``arg``. After calling, ``self.args[idx] == arg``. + + Args: + + idx (int): The index into ``self.args`` of the element to update + arg (Argument): The new argument value to write into ``args`` + """ + args = list(self.args) + args[idx] = arg + self.args = tuple(args) + + @compatibility(is_backward_compatible=True) + def insert_arg(self, idx: int, arg: Argument) -> None: + """ + Insert an positional argument to the argument list with given index. + + Args: + + idx (int): The index of the element in ``self.args`` to be inserted before. + arg (Argument): The new argument value to insert into ``args`` + """ + assert 0 <= idx <= len(self.args), ( + "insert_args index must be between 0 and len(self.args)" + ) + args_left = self.args[:idx] + args_right = self.args[idx:] + + self._args = args_left + (arg,) + args_right + + _new_input_nodes: dict[Node, None] = {} + _fx_map_arg(arg, _new_input_nodes.setdefault) + + for new_use in _new_input_nodes: + if new_use not in self._input_nodes: + self._input_nodes.setdefault(new_use) + new_use.users.setdefault(self) + + @compatibility(is_backward_compatible=True) + def update_kwarg(self, key: str, arg: Argument) -> None: + """ + Update an existing keyword argument to contain the new value + ``arg``. After calling, ``self.kwargs[key] == arg``. 
+ + Args: + + key (str): The key in ``self.kwargs`` of the element to update + arg (Argument): The new argument value to write into ``kwargs`` + """ + self.kwargs = {**self.kwargs, key: arg} + + @property + def stack_trace(self) -> Optional[str]: + """ + Return the Python stack trace that was recorded during tracing, if any. + When traced with fx.Tracer, this property is usually populated by + `Tracer.create_proxy`. To record stack traces during tracing for debug purposes, + set `record_stack_traces = True` on the `Tracer` instance. + When traced with dynamo, this property will be populated by default by + `OutputGraph.create_proxy`. + + stack_trace would have the innermost frame at the end of the string. + """ + return self.meta.get("stack_trace", None) + + @stack_trace.setter + def stack_trace(self, trace: Optional[str]) -> None: + self.meta["stack_trace"] = trace + + def __repr__(self) -> str: + if self._repr_fn: + return self._repr_fn(self) + return self.name + + @staticmethod + def _pretty_print_target(target: object) -> str: + """ + Make target printouts more user-friendly. + 1) builtins will be printed as `builtins.xyz` + 2) operators will be printed as `operator.xyz` + 3) other callables will be printed with qualified name, e.g. torch.add + """ + if isinstance(target, str): + return target + if hasattr(target, "__module__"): + name = getattr(target, "__name__", None) + if name is None: + # Just to be defensive, if we don't have `__name__`, get the + # qualname. Not sure if this happens for any members of `operator` + # or `builtins`. This fallback path is not as good, since e.g. + # things in `operator` have `_operator` as their __module__. 
+ # TODO: THIS IS BROKEN: _get_qualified_name calls `__name__` + return _get_qualified_name(target) # type: ignore[arg-type] + if target.__module__ == "builtins": + return f"builtins.{name}" + elif target.__module__ == "_operator": + return f"operator.{name}" + return _get_qualified_name(target) # type: ignore[arg-type] + + @compatibility(is_backward_compatible=True) + def format_node( + self, + placeholder_names: Optional[list[str]] = None, + maybe_return_typename: Optional[list[str]] = None, + *, + include_tensor_metadata: bool = False, + ) -> Optional[str]: + """ + Return a descriptive string representation of ``self``. + + This method can be used with no arguments as a debugging + utility. + + This function is also used internally in the ``__str__`` method + of ``Graph``. Together, the strings in ``placeholder_names`` + and ``maybe_return_typename`` make up the signature of the + autogenerated ``forward`` function in this Graph's surrounding + GraphModule. ``placeholder_names`` and ``maybe_return_typename`` + should not be used otherwise. + + Args: + placeholder_names: A list that will store formatted strings + representing the placeholders in the generated + ``forward`` function. Internal use only. + maybe_return_typename: A single-element list that will store + a formatted string representing the output of the + generated ``forward`` function. Internal use only. + include_tensor_metadata: Whether to include tensor metadata + + Returns: + str: If 1) we're using ``format_node`` as an internal helper + in the ``__str__`` method of ``Graph``, and 2) ``self`` + is a placeholder Node, return ``None``. Otherwise, + return a descriptive string representation of the + current Node. 
+ """ + if self.op == "placeholder": + assert isinstance(self.target, str) + arg_str = self.target + arg_str += arg_str + f": {_type_repr(self.type)}" if self.type else "" + if placeholder_names: + placeholder_names.append(arg_str) + return None + maybe_typename = f"{_type_repr(self.type)} " if self.type else "" + default_val = "(default=" + str(self.args[0]) + ")" if self.args else "" + return f"%{self.name} : {maybe_typename}[num_users={len(self.users)}] = {self.op}[target={self.target}]{default_val}" + elif self.op == "get_attr": + maybe_typename = ( + f"{_type_repr(self.type)} " if self.type is not None else "" + ) + return ( + f"%{self.name} : {maybe_typename}[num_users={len(self.users)}] = " + f"{self.op}[target={self._pretty_print_target(self.target)}]" + ) + elif self.op == "output": + if self.type and maybe_return_typename: + maybe_return_typename[0] = f" -> {_type_repr(self.type)}" + return f"return {self.args[0]}" + else: + + def stringify_shape(shape: Iterable) -> str: + return f"[{', '.join([str(x) for x in shape])}]" + + meta_val = self.meta.get( + "val", + self.meta.get("tensor_meta", self.meta.get("example_value", None)), + ) + type_annotation = "" + if ( + include_tensor_metadata + and isinstance(meta_val, torch.Tensor) + and meta_val.layout + not in ( + torch.sparse_csc, + torch.sparse_csr, + ) + ): + stride_annotation = f"{stringify_shape(meta_val.stride())}" + device_annotation = f"{meta_val.device}" + type_annotation = ( + f'Tensor "{dtype_abbrs[meta_val.dtype]}{stringify_shape(meta_val.shape)}' + f'{stride_annotation}{device_annotation}"' + ) + else: + type_annotation = ( + f"{_type_repr(self.type)} " if self.type is not None else "" + ) + return ( + f"%{self.name} : {type_annotation}[num_users={len(self.users)}] = " + f"{self.op}[target={self._pretty_print_target(self.target)}](" + f"args = {_format_arg(self.args)}, kwargs = {_format_arg(self.kwargs)})" + ) + + @compatibility(is_backward_compatible=True) + def replace_all_uses_with( + self, + 
replace_with: "Node", + delete_user_cb: Optional[Callable[["Node"], bool]] = None, + *, + propagate_meta: bool = False, + ) -> list["Node"]: + """ + Replace all uses of ``self`` in the Graph with the Node ``replace_with``. + + Args: + + replace_with (Node): The node to replace all uses of ``self`` with. + delete_user_cb (Callable): Callback that is called to determine + whether a given user of the self node should be removed. + propagate_meta (bool): Whether or not to copy all properties + on the .meta field of the original node onto the replacement node. + For safety, this is only valid to do if the replacement node + doesn't already have an existing .meta field. + + Returns: + + The list of Nodes on which this change was made. + """ + if propagate_meta: + assert len(replace_with.meta) == 0, ( + "Called node.replace_all_uses_with(replace_with, propagate_meta=True), " + "but replace_with already has .meta keys" + ) + for k, v in self.meta.items(): + replace_with.meta[k] = v + to_process = [*self.users] + replace_hooks = getattr(self.graph.owning_module, "_replace_hooks", None) + result = [] + for use_node in to_process: + if delete_user_cb is not None and not delete_user_cb(use_node): + continue + result.append(use_node) + if replace_hooks: + for replace_hook in replace_hooks: + replace_hook(old=self, new=replace_with.name, user=use_node) + # pyrefly: ignore [missing-attribute] + use_node._replace_input_with(self, replace_with) # type: ignore[attr-defined] + return result + + @compatibility(is_backward_compatible=False) + def is_impure(self, impure_random: bool = True) -> bool: + """ + Returns whether this op is impure, i.e. if its op is a placeholder or + output, or if a call_function or call_module which is impure. + + Args: + impure_random (bool): Whether to treat rand op as impure. + + Returns: + + bool: If the op is impure or not. 
+ """ + # Placeholders and outputs are always impure for DCE purposes + if self.op in {"placeholder", "output"}: + return True + + # Check if an impure module. + if self.op == "call_module": + assert self.graph.owning_module is not None, ( + "self.graph.owning_module not set for purity check" + ) + target_mod = self.graph.owning_module.get_submodule(self.target) + assert target_mod is not None, ( + f"Did not find expected submodule target {self.target}" + ) + # NOTE: here we can end up considering GraphModule submodules pure, + # even if they contain impure ops. It may not be safe to change + # because this function is used by graph.eliminate_dead_code, + # and some users depend on current elimination behavior. + return getattr(target_mod, "_is_impure", False) + + # For call_function, delegate to the unified has_side_effects function + if self.op == "call_function": + from torch._library.utils import is_impure + + return is_impure( + self.target, # pyrefly: ignore[bad-argument-type] + args=self.args, + kwargs=self.kwargs, + impure_random=impure_random, + ) + + return False + + @compatibility(is_backward_compatible=False) + def normalized_arguments( + self, + root: torch.nn.Module, + arg_types: Optional[tuple[Any]] = None, + kwarg_types: Optional[dict[str, Any]] = None, + normalize_to_only_use_kwargs: bool = False, + ) -> Optional[ArgsKwargsPair]: + """ + Returns normalized arguments to Python targets. This means that + `args/kwargs` will be matched up to the module/functional's + signature and return exclusively kwargs in positional order + if `normalize_to_only_use_kwargs` is true. + Also populates default values. Does not support positional-only + parameters or varargs parameters. + + Supports module calls. + + May require `arg_types` and `kwarg_types` in order to disambiguate overloads. + + Args: + root (torch.nn.Module): Module upon which to resolve module targets. 
+ arg_types (Optional[Tuple[Any]]): Tuple of arg types for the args + kwarg_types (Optional[Dict[str, Any]]): Dict of arg types for the kwargs + normalize_to_only_use_kwargs (bool): Whether to normalize to only use kwargs. + + Returns: + + Returns NamedTuple ArgsKwargsPair, or `None` if not successful. + """ + if self.op == "call_function": + assert callable(self.target) + return normalize_function( + self.target, + self.args, # type: ignore[arg-type] + self.kwargs, + arg_types, + kwarg_types, + normalize_to_only_use_kwargs=normalize_to_only_use_kwargs, + ) + elif self.op == "call_module": + assert isinstance(self.target, str) + return normalize_module( + root, + self.target, + self.args, # type: ignore[arg-type] + self.kwargs, + normalize_to_only_use_kwargs=normalize_to_only_use_kwargs, + ) + + return None + + @compatibility(is_backward_compatible=True) + def replace_input_with(self, old_input: "Node", new_input: "Node") -> None: + """ + Loop through input nodes of ``self``, and replace all instances of + ``old_input`` with ``new_input``. + + Args: + + old_input (Node): The old input node to be replaced. + new_input (Node): The new input node to replace ``old_input``. 
+ """ + + m = self.graph.owning_module + if getattr(m, "_replace_hooks", None): + for replace_hook in m._replace_hooks: + replace_hook(old=old_input, new=new_input.name, user=self) + + # pyrefly: ignore [missing-attribute] + self._replace_input_with(old_input, new_input) # type: ignore[attr-defined] + + def _rename(self, candidate: str) -> None: + if candidate == self.name: + return + name = self.graph._graph_namespace.create_name(candidate, None) + self.name = name + self.graph._graph_namespace._rename_object(self, name) + + def __setattr__(self, name: str, value: Any) -> None: + if name == "name" and hasattr(self, "name"): + m = self.graph.owning_module + if getattr(m, "_replace_hooks", None): + assert isinstance(value, str) + for user in self.users: + for replace_hook in m._replace_hooks: + replace_hook(old=self, new=value, user=user) + update = False + if ( + hasattr(self, name) + and hasattr(self.graph, "_find_nodes_lookup_table") + and self in self.graph._find_nodes_lookup_table + ): + update = True + self.graph._find_nodes_lookup_table.remove(self) + object.__setattr__(self, name, value) + if update: + self.graph._find_nodes_lookup_table.insert(self) + + +@compatibility(is_backward_compatible=True) +def map_arg(a: ArgumentT, fn: Callable[[Node], Argument]) -> ArgumentT: + """ + Apply fn recursively to each Node appearing in arg. + + arg may be a list, tuple, slice, or dict with string keys: the return value will + have the same type and structure. + """ + assert callable(fn), "torch.fx.map_arg(a, fn): fn must be a callable" + return _fx_map_arg(a, fn) + + +@compatibility(is_backward_compatible=True) +def map_aggregate(a: ArgumentT, fn: Callable[[Argument], Argument]) -> ArgumentT: + """ + Apply fn recursively to each object appearing in arg. + + arg may be a list, tuple, slice, or dict with string keys: the return value will + have the same type and structure. 
+ """ + return _fx_map_aggregate(a, fn) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/operator_schemas.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/operator_schemas.py new file mode 100644 index 0000000000000000000000000000000000000000..397d4c5996ee9024ecebf2e306d45a4d27b36c7f --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/operator_schemas.py @@ -0,0 +1,570 @@ +# mypy: allow-untyped-defs +import enum +import inspect +import numbers +import types +import typing +import warnings +from collections.abc import Callable +from typing import Any, cast, NamedTuple, Optional, TYPE_CHECKING + +import torch +from torch._jit_internal import boolean_dispatched +from torch._ops import OpOverload, OpOverloadPacket + +from ._compatibility import compatibility + + +if TYPE_CHECKING: + from .node import Argument + +__all__ = [ + "ArgsKwargsPair", + "check_for_mutable_operation", + "get_signature_for_torch_op", + "create_type_hint", + "type_matches", + "normalize_function", + "normalize_module", +] + + +@compatibility(is_backward_compatible=False) +class ArgsKwargsPair(NamedTuple): + """ + Simple named tuple for wrapping args/kwargs pairs. + """ + + args: tuple[Any, ...] 
+ kwargs: dict[str, Any] + + +_manual_overrides: dict[Callable, list[inspect.Signature]] = {} + + +def _nonzero_schemas(): + signatures = [] + + def nonzero(self): + pass + + signatures.append(inspect.signature(nonzero)) + + def nonzero(self, *, as_tuple: bool): # type: ignore[no-redef] + pass + + signatures.append(inspect.signature(nonzero)) + + return signatures + + +_manual_overrides[torch.nonzero] = _nonzero_schemas() + + +class _FakeGlobalNamespace: + def __getattr__(self, name): + if name == "torch": + return torch + raise RuntimeError("Expected a torch namespace lookup") + + +_type_eval_globals = { + "Tensor": torch.Tensor, + "Device": torch.device, + "Layout": torch.layout, + "number": numbers.Number, + "Future": torch.jit.Future, + "AnyEnumType": enum.Enum, + "QScheme": torch.qscheme, + "__torch__": _FakeGlobalNamespace(), + "NoneType": type(None), + "Storage": torch.UntypedStorage, + "t": typing.TypeVar("t"), + "PyObject": Any, +} +for k in dir(typing): + _type_eval_globals[k] = getattr(typing, k) + + +def _torchscript_type_to_python_type(ts_type: "torch._C.JitType") -> Any: + """ + Convert a TorchScript type to a Python type (including subtypes) via + eval'ing the annotation_str. _type_eval_globals sets up expressions + like "List" and "Future" to map to actual types (typing.List and jit.Future) + """ + return eval(ts_type.annotation_str, _type_eval_globals) + + +def _torchscript_schema_to_signature_impl( + ts_schema: torch._C.FunctionSchema, +) -> inspect.Signature: + from inspect import Parameter + + parameters: list[Parameter] = [] + for arg in ts_schema.arguments: + arg_type = _torchscript_type_to_python_type(arg.type) + default = arg.default_value if arg.has_default_value() else Parameter.empty + # TODO: Figure out if this is safe. It seems like when generating the type signatures for + # PythonArgParser, we emit signatures with `input` instead of `self` as the first tensor + # argument name. 
Downstream, if someone converts that positional argument to a keyword + # argument, the name mismatch will break things, so here we're going to normalize the + # name to "input" + name = arg.name if arg.name != "self" else "input" + kind = ( + Parameter.KEYWORD_ONLY + if arg.kwarg_only + else Parameter.POSITIONAL_OR_KEYWORD + ) + # "from" is a keyword therefore it must be a POSITIONAL_ONLY argument + if name == "from": + assert kind == Parameter.POSITIONAL_OR_KEYWORD + # ParameterKind type is internal implementation detail to inspec package + # which makes it hard to do type annotation + kind = Parameter.POSITIONAL_ONLY # type: ignore[assignment] + # This renders all previous arguments to positional only + + for idx, p in enumerate(parameters): + assert p.kind == Parameter.POSITIONAL_OR_KEYWORD + parameters[idx] = Parameter( + name=p.name, + kind=Parameter.POSITIONAL_ONLY, + default=p.default, + annotation=p.annotation, + ) + + parameters.append( + Parameter(name=name, kind=kind, default=default, annotation=arg_type) + ) + return_types = [ + _torchscript_type_to_python_type(ret.type) for ret in ts_schema.returns + ] + if len(return_types) == 0: + return_type = None + elif len(return_types) == 1: + return_type = return_types[0] + else: + return_type = tuple(return_types) + + return inspect.Signature(parameters, return_annotation=return_type) + + +_SCHEMA_TO_SIGNATURE_CACHE: dict[tuple[str, str], inspect.Signature] = {} + + +def _torchscript_schema_to_signature( + ts_schema: torch._C.FunctionSchema, +) -> inspect.Signature: + # Cached as it's called in the hot path of FakeTensor dispatch + cache_key = ts_schema.name, ts_schema.overload_name + cache_val = _SCHEMA_TO_SIGNATURE_CACHE.get(cache_key) + if cache_val is not None: + return cache_val + + res = _torchscript_schema_to_signature_impl(ts_schema) + _SCHEMA_TO_SIGNATURE_CACHE[cache_key] = res + return res + + +@compatibility(is_backward_compatible=False) +def check_for_mutable_operation( + target: Callable, args: 
tuple["Argument", ...], kwargs: dict[str, "Argument"] +): + signatures, schemas = get_signature_for_torch_op(target, return_schemas=True) + + if signatures and schemas: + matched_schemas = [] + + # Iterate through all of the schema until we find one that matches + # If one matches, populate `new_args_and_kwargs` with the new args/kwargs + # values. If none matches, `new_args_and_kwargs` will be None + for candidate_signature, schema in zip(signatures, schemas): + try: + candidate_signature.bind(*args, **kwargs) + matched_schemas.append((candidate_signature, schema)) + except TypeError: + continue + + def throw_if_mutable(schema): + if schema.is_mutable: + raise RuntimeError( + f"Tried to trace mutable operation {schema}. FX only supports functional " + f"code, so operations that mutate operands in-place (e.g. via `out` arguments) " + f"are not supported" + ) + + if len(matched_schemas) == 0: + # Did not match any schema. Cannot check for mutation + pass + elif len(matched_schemas) == 1: + # Matched exactly one schema, unambiguous + _, schema_to_check = matched_schemas[0] + throw_if_mutable(schema_to_check) + else: + # Ambiguous schema match. Since mutability checking is best effort, + # do nothing. + pass + + +@compatibility(is_backward_compatible=False) +def get_signature_for_torch_op(op: Callable, return_schemas: bool = False): + """ + Given an operator on the `torch` namespace, return a list of `inspect.Signature` + objects corresponding to the overloads of that op.. May return `None` if a signature + could not be retrieved. + + Args: + op (Callable): An operator on the `torch` namespace to look up a signature for + + Returns: + Optional[List[inspect.Signature]]: A list of signatures for the overloads of this + operator, or None if the operator signatures could not be retrieved. 
If + return_schemas=True, returns a tuple containing the optional Python signatures + and the optional TorchScript Function signature + """ + if isinstance(op, OpOverload): + schemas = [op._schema] + elif isinstance(op, OpOverloadPacket): + schemas = [getattr(op, overload)._schema for overload in op.overloads()] + else: + override = _manual_overrides.get(op) + if override: + return (override, None) if return_schemas else None + + aten_fn = torch.jit._builtins._find_builtin(op) + + if aten_fn is None: + return (None, None) if return_schemas else None + schemas = torch._C._jit_get_schemas_for_operator(aten_fn) + + signatures = [_torchscript_schema_to_signature(schema) for schema in schemas] + return (signatures, schemas) if return_schemas else signatures + + +@compatibility(is_backward_compatible=False) +def create_type_hint(x): + """ + Produces a type hint for the given argument. + + The :func:`create_type_hint` looks for a type hint compatible with the input argument `x`. + + If `x` is a `list` or `tuple`, it looks for an object in the list whose type is a superclass + of the rest, and uses that as `base_type` for the `List` or `Tuple` to be returned. + If no such object is found, it defaults to `List[Any]`. + + If `x` is neither a `list` nor a `tuple`, it returns `x`. + """ + try: + if isinstance(x, (list, tuple)): + # todo(chilli): Figure out the right way for mypy to handle this + if isinstance(x, list): + + def ret_type(x): + return list[x] # type: ignore[valid-type] + + else: + + def ret_type(x): + return tuple[x, ...] # type: ignore[valid-type] + + if len(x) == 0: + return ret_type(Any) + base_type = x[0] + for t in x: + if issubclass(t, base_type): + continue + elif issubclass(base_type, t): + base_type = t + else: + return ret_type(Any) + return ret_type(base_type) + except Exception: + # We tried to create a type hint for list but failed. 
+ warnings.warn( + f"We were not able to successfully create type hint from the type {x}" + ) + return x + + +@compatibility(is_backward_compatible=False) +def type_matches(signature_type: Any, argument_type: Any): + sig_origin_type = getattr(signature_type, "__origin__", signature_type) + + if signature_type is argument_type: + return True + + # Union types in signature. Given type needs to match one of the + # contained types in the Union + if sig_origin_type is typing.Union and signature_type != argument_type: + sig_contained = signature_type.__args__ + return any(type_matches(c, argument_type) for c in sig_contained) + + if getattr(signature_type, "__origin__", None) is list: + sig_el_type = signature_type.__args__[0] + + # int can be promoted to list[int] + if argument_type is int and sig_el_type is int: + return True + + if not inspect.isclass(sig_el_type): + warnings.warn( + f"Does not support nested parametric types, got {signature_type}. Please file a bug." + ) + return False + if getattr(argument_type, "__origin__", None) is list: + return issubclass(argument_type.__args__[0], sig_el_type) + + def is_homogeneous_tuple(t): + if getattr(t, "__origin__", None) is not tuple: + return False + contained = t.__args__ + if t.__args__ == ((),): # Tuple[()].__args__ == ((),) for some reason + return True + return all((c is Ellipsis) or issubclass(c, sig_el_type) for c in contained) + + # Tuple[T] is accepted for List[T] parameters + return is_homogeneous_tuple(argument_type) + + # Dtype is an int in schemas + if signature_type is int and argument_type is torch.dtype: + return True + + if signature_type is numbers.Number and argument_type in {int, float}: + return True + if inspect.isclass(argument_type) and inspect.isclass(signature_type): + return issubclass(argument_type, signature_type) + + return False + + +@compatibility(is_backward_compatible=False) +def normalize_function( + target: Callable, + args: tuple[Any, ...], + kwargs: Optional[dict[str, Any]] = 
None, + arg_types: Optional[tuple[Any]] = None, + kwarg_types: Optional[dict[str, Any]] = None, + normalize_to_only_use_kwargs: bool = False, +) -> Optional[ArgsKwargsPair]: + """ + Returns normalized arguments to PyTorch functions. This means that + `args/kwargs` will be matched up to the functional's + signature and return exclusively kwargs in positional order if + `normalize_to_only_use_kwargs` is True. + Also populates default values. Does not support positional-only + parameters or varargs parameters (*args, **kwargs). Does not support modules. + + May require `arg_types` and `kwarg_types` in order to disambiguate overloads. + + Args: + target (Callable): Function that we are normalizing + args (Tuple[Any]): Tuple of args to the function + kwargs (Optional[Dict[str, Any]]): Dict of kwargs to the function + arg_types (Optional[Tuple[Any]]): Tuple of arg types for the args + kwarg_types (Optional[Dict[str, Any]]): Dict of arg types for the kwargs + normalize_to_only_use_kwargs (bool): Whether to normalize to only use kwargs. + + Returns: + + Returns normalized_args_and_kwargs, or `None` if not successful. + """ + if kwargs is None: + kwargs = {} + new_args_and_kwargs = None + if ( + not isinstance(target, types.BuiltinFunctionType) + and not (isinstance(target, (OpOverloadPacket, OpOverload))) + and hasattr(target, "_op") + ): + # ExecuTorch's EdgeOpOverload are a wrapper around PyTorch's OpOverload, + # so we can unwrap it here to get its schema + # Can't import EdgeOpOverload directly because of a circular dependency, + # so checking for "_op" existing is the next best thing. + target = target._op + + # Repeat the condition after checking for the inner _op field. 
+ if not isinstance(target, types.BuiltinFunctionType) and not ( + isinstance(target, (OpOverloadPacket, OpOverload)) + ): + target_for_analysis = target + if target in boolean_dispatched: + # HACK: `boolean_dispatch` as used in `torch.nn.functional` makes it so that we have + # a 2-way dispatch based on a boolean value. Here we check that the `true` and `false` + # branches of the dispatch have exactly the same signature. If they do, use the `true` + # branch signature for analysis. Otherwise, leave this un-normalized + assert not isinstance(target, str) + dispatched = boolean_dispatched[target] + if_true, if_false = dispatched["if_true"], dispatched["if_false"] + if ( + inspect.signature(if_true).parameters + != inspect.signature(if_false).parameters + ): + return None + target_for_analysis = if_true + + assert callable(target_for_analysis) + sig = inspect.signature(inspect.unwrap(target_for_analysis)) + new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs( + sig, args, kwargs, normalize_to_only_use_kwargs + ) + else: + assert callable(target) + torch_op_schemas = get_signature_for_torch_op(target) + matched_schemas = [] + if torch_op_schemas: + # Iterate through all of the schema until we find one that matches + # If one matches, populate `new_args_and_kwargs` with the new args/kwargs + # values. If none matches, `new_args_and_kwargs` will be None + for candidate_signature in torch_op_schemas: + try: + candidate_signature.bind(*args, **kwargs) + matched_schemas.append(candidate_signature) + except TypeError: + continue + + if len(matched_schemas) == 0: + # Did not match any schema. 
Cannot normalize + pass + elif len(matched_schemas) == 1: + # Matched exactly one schema, unambiguous + new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs( + matched_schemas[0], args, kwargs, normalize_to_only_use_kwargs + ) + else: + if arg_types is not None or kwarg_types is not None: + arg_types = arg_types if arg_types else cast(tuple[Any], ()) + kwarg_types = kwarg_types if kwarg_types else {} + for candidate_signature in torch_op_schemas: + sig_matches = True + try: + bound_types = candidate_signature.bind( + *arg_types, **kwarg_types + ) + for arg_name, arg_type in bound_types.arguments.items(): + param = candidate_signature.parameters[arg_name] + sig_matches = sig_matches and type_matches( + param.annotation, arg_type + ) + except TypeError: + sig_matches = False + if sig_matches: + new_args_and_kwargs = ( + _args_kwargs_to_normalized_args_kwargs( + candidate_signature, + args, + kwargs, + normalize_to_only_use_kwargs, + ) + ) + break + else: + # Matched more than one schema. In this situation, the caller must provide the types of + # the arguments of the overload they expect. + schema_printouts = "\n".join( + str(schema) for schema in matched_schemas + ) + raise RuntimeError( + f"Tried to normalize arguments to {torch.typename(target)} but " + f"the schema match was ambiguous! Please provide argument types to " + f"the normalize_arguments() call. Available schemas:\n{schema_printouts}" + ) + + return new_args_and_kwargs + + +@compatibility(is_backward_compatible=False) +def normalize_module( + root: torch.nn.Module, + target: str, + args: tuple[Any], + kwargs: Optional[dict[str, Any]] = None, + normalize_to_only_use_kwargs: bool = False, +) -> Optional[ArgsKwargsPair]: + """ + Returns normalized arguments to PyTorch modules. This means that + `args/kwargs` will be matched up to the functional's + signature and return exclusively kwargs in positional order if + `normalize_to_only_use_kwargs` is True. + Also populates default values. 
Does not support positional-only + parameters or varargs parameters (*args, **kwargs). + + Args: + root (nn.Module): root module upon which we query modules + target (Callable): Function that we are normalizing + args (Tuple[Any]): Tuple of args to the function + kwargs (Optional[Dict[str, Any]]): Dict of kwargs to the function + normalize_to_only_use_kwargs (bool): Whether to normalize to only use kwargs. + + Returns: + + Returns normalized_args_and_kwargs, or `None` if not successful. + """ + try: + submod = root.get_submodule(target) + except AttributeError as e: + raise RuntimeError( + f"Tried to normalize node with target {target} but root did not " + f"have that target!" + ) from e + if hasattr(submod.__class__, "__name__"): + classname = submod.__class__.__name__ + if getattr(torch.nn, classname, None) == submod.__class__: + sig = inspect.signature(inspect.unwrap(submod.forward)) + if kwargs is None: + kwargs = {} + new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs( + sig, args, kwargs, normalize_to_only_use_kwargs + ) + return new_args_and_kwargs + return None + + +def _args_kwargs_to_normalized_args_kwargs( + sig: inspect.Signature, + args: tuple[Any, ...], + kwargs: dict[str, Any], + normalize_to_only_use_kwargs: bool, +) -> Optional[ArgsKwargsPair]: + """ + Given a call target, args, and kwargs, return the arguments normalized into + an ArgsKwargsPair, or None if the type signature is not supported by + this normalization. + + Args: + + sig (inspect.Signature): Signature object for the target + args (Tuple): Arguments that appear at the callsite for `target` + kwargs (Dict): Keyword arguments that appear at the callsite for `target` + normalize_to_only_use_kwargs (bool): Whether to normalize to only use kwargs. + + Returns: + + Optional[ArgsKwargsPair]: Normalized args and kwargs for `target`, or `None` if + this target is not supported. 
+ """ + + # Don't currently support positional-only + # or varargs (*args, **kwargs) signatures + supported_parameter_types = { + inspect.Parameter.POSITIONAL_OR_KEYWORD, + inspect.Parameter.KEYWORD_ONLY, + } + if any(p.kind not in supported_parameter_types for p in sig.parameters.values()): + # Add an exception for one signature, which is common for random/uniform, i.e.: + # Tensor(a!) self, float from=0, float to=1, *, Generator? generator=None + # `from` is Python keyword and as such functions with that signature should have + # positional-only args, but at the same time they could be dispatched as kwargs + if list(sig.parameters.keys()) != ["input", "from", "to", "generator"]: + return None + + bound_args = sig.bind(*args, **kwargs) + bound_args.apply_defaults() + + new_kwargs: dict[str, Any] = {} + new_args: list[Any] = [] + for i, param in enumerate(sig.parameters): + if not normalize_to_only_use_kwargs and i < len(args): + new_args.append(bound_args.arguments[param]) + else: + new_kwargs[param] = bound_args.arguments[param] + + return ArgsKwargsPair(tuple(new_args), new_kwargs) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/proxy.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/proxy.py new file mode 100644 index 0000000000000000000000000000000000000000..c38c31c4d211278484edd53fc176effb65945aef --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/proxy.py @@ -0,0 +1,850 @@ +# mypy: ignore-errors + +import collections +import copy +import dis +import enum +import inspect +import logging +import operator +import sys +import traceback +from collections import OrderedDict +from collections.abc import Callable, Iterator +from dataclasses import fields, is_dataclass +from typing import Any, Optional + +import torch +import torch.fx.traceback as fx_traceback +from torch._C import _fx_map_aggregate as map_aggregate, _fx_map_arg as map_arg +from torch._library.opaque_object 
import is_opaque_value_type +from torch._logging import getArtifactLogger +from torch.utils._traceback import CapturedTraceback + +from ._compatibility import compatibility +from .graph import Graph, magic_methods, reflectable_magic_methods +from .immutable_collections import immutable_dict, immutable_list +from .node import Argument, base_types, Node, Target +from .operator_schemas import check_for_mutable_operation + + +__all__ = [ + "TracerBase", + "GraphAppendingTracer", + "TraceError", + "Proxy", + "MetaProxy", + "Attribute", + "ParameterProxy", + "Scope", + "ScopeContextManager", +] + + +log = logging.getLogger(__name__) +annotation_log = getArtifactLogger(__name__, "annotation") + + +@compatibility(is_backward_compatible=False) +class Scope: + """Scope object that records the module path and the module type + of a module. Scope is used to track the information of the module + that contains a Node in a Graph of GraphModule. For example:: + + class Sub(torch.nn.Module): + def forward(self, x): + # This will be a call_method Node in GraphModule, + # scope for this would be (module_path="sub", module_type=Sub) + return x.transpose(1, 2) + + + class M(torch.nn.Module): + def __init__(self) -> None: + self.sub = Sub() + + def forward(self, x): + # This will be a call_method Node as well, + # scope for this would be (module_path="", None) + x = x.transpose(1, 2) + x = self.sub(x) + return x + + """ + + def __init__(self, module_path: str, module_type: Any): + super().__init__() + self.module_path = module_path + self.module_type = module_type + + +@compatibility(is_backward_compatible=False) +class ScopeContextManager: + """A context manager to track the Scope of Node during symbolic tracing. + When entering a forward function of a Module, we'll update the scope information of + the current module, and when we exit, we'll restore the previous scope information. 
+ """ + + def __init__( + self, + scope: Scope, + current_scope: Scope, + ): + super().__init__() + # Keep a copy of prev scope to restore on exit + self._prev_scope = copy.copy(scope) + # Update scope to current scope + scope.module_path = current_scope.module_path + scope.module_type = current_scope.module_type + # Save a reference so we can restore it + self._scope = scope + + def __enter__(self): + return self._scope + + def __exit__(self, *args): + self._scope.module_path = self._prev_scope.module_path + self._scope.module_type = self._prev_scope.module_type + return + + +_COPY_META_FIELDS = [ + "nn_module_stack", + "torch_fn", + "source_fn_stack", + "original_aten", + "recompute", + "ac_graph_id", + "has_backward_hook", + "from_node", + "quantization_tag", # TODO deprecated + "_numeric_debug_handle", # TODO deprecated + "custom", + "partitioner_tag", +] + + +@compatibility(is_backward_compatible=True) +class TracerBase: + graph: Graph + record_stack_traces: bool = False + # When record_stack_traces is True, only reocrd stack traces + # with forward function names. + # This helps when we want stack trace back to model code + _record_forward_stack_traces_only: bool = False + # Feature flag for mutable schema checking + # Enableby default in 1.12 + check_mutable_operations: bool = False + # Feature flag for assert tracing + trace_asserts: bool = False + # Feature flag for proxying accesses to buffer values + proxy_buffer_attributes: bool = False + + # Name of the function to be traced. 
It will only be used when + # ``root`` is an instance of ``nn.Module`` + traced_func_name: str = "forward" + + # Maps the containing module's name to the operator name + scope: Scope + + # Records the module call stack + module_stack: OrderedDict[str, tuple[str, Any]] + + # Mapping of node name to module scope + node_name_to_scope: dict[str, tuple[str, type]] + + @compatibility(is_backward_compatible=True) + def create_node( + self, + kind: str, + target: Target, + args: tuple[Argument, ...], + kwargs: dict[str, Argument], + name: Optional[str] = None, + type_expr: Optional[Any] = None, + ) -> Node: + """ + Inserts a graph node given target, args, kwargs, and name. + + This method can be overridden to do extra checking, validation, or + modification of values used in node creation. For example, one might + want to disallow in-place operations from being recorded. + """ + + if kind == "call_function" and self.check_mutable_operations: + check_for_mutable_operation(target, args, kwargs) + + node = self.graph.create_node(kind, target, args, kwargs, name, type_expr) + # TODO node_name_to_scope will be depreciated in favor of + # node.meta['nn_module_stack'] + self.node_name_to_scope[node.name] = ( + self.scope.module_path, + self.scope.module_type, + ) + + # Optionally set stack trace on the created Node for debugging purposes + if fx_traceback.has_preserved_node_meta(): + current_meta: dict[str, Any] = fx_traceback.get_current_meta() + + stack_trace = current_meta.get("stack_trace") + if stack_trace: + node.stack_trace = stack_trace + + if fx_traceback.GRADIENT_ACC_SPECIAL_STACK in stack_trace: + node.meta["is_gradient_acc"] = True + + # Explicitly set the stack_trace, nn_module_stack and source_fn on the node.meta + # If other meta fields are needed, they can be added here + for field in _COPY_META_FIELDS: + if field in current_meta: + node.meta[field] = copy.copy(current_meta[field]) + + # Here we decrement to account for the sequence_nr having + # just been 
incremented while tracing this lowered aten op. + new_seq_nr = torch.autograd._get_sequence_nr() - 1 + # The sequence_nr increments every time a new autograd Node + # is created. During the FWD pass we store the sequence_nr + # corresponding to the last autograd Node created on this fx + # node's meta. A single aten op can create multiple autograd + # nodes as is the case with in-place foreach ops. During the + # BWD pass we retrieve the sequence_nr stored on the current + # executing autograd Node. See NOTE [ Sequence Number ]. + if current_meta.get("in_grad_fn", 0) > 0: + annotation_log.debug("seq_nr from current_meta") + new_seq_nr = current_meta["grad_fn_seq_nr"][-1] + + # See Note [Functionalization View Replay Annotation] + # Overriding some node meta with the original node meta of the + # regenerated node. + replay_node: Node = fx_traceback.get_current_replay_node() + if replay_node is not None: + node.meta["is_functional_regenerated"] = True + if "seq_nr" in replay_node.meta: + annotation_log.debug("seq_nr from replay_node") + new_seq_nr = replay_node.meta["seq_nr"] + if "custom" in replay_node.meta: + node.meta["custom"] = replay_node.meta.get("custom") + if "stack_trace" in replay_node.meta: + node.stack_trace = replay_node.meta.get("stack_trace") + + annotation_log.debug("Assigning new_seq_nr %s to %s", new_seq_nr, node.name) + node.meta["seq_nr"] = new_seq_nr + + elif self.module_stack: + node.meta["nn_module_stack"] = copy.copy(self.module_stack) + + if self.record_stack_traces and not node.stack_trace: + user_stack_summary = CapturedTraceback.extract().summary() + if user_stack_summary: + user_stack_summary = self._filter_traceback_frames(user_stack_summary) + if user_stack_summary: + node.stack_trace = "".join(user_stack_summary.format()).strip() + + log.debug("create_node %s", node) + return node + + def _filter_traceback_frames( + self, user_stack_summary: traceback.StackSummary + ) -> traceback.StackSummary: + # This method can be overridden to 
customize the frame filtering logic + # for the recorded stack trace + user_frames = [] + if self._record_forward_stack_traces_only: + user_frames = [ + frame + for frame in user_stack_summary + if ( + frame.name == "forward" + or frame.filename.endswith("torch/__init__.py") + ) + ] + else: + first_forward = -1 + for i, frame in enumerate(user_stack_summary): + if frame.name == "forward": + user_frames = user_stack_summary[i:] + first_forward = i + break + + # Not having a "forward" call in the stacktrace implies the + # stacktrace will probably be irrelevant + if first_forward == -1: + user_frames = [] + + from torch.fx.experimental.symbolic_shapes import uninteresting_files + + user_frames = [ + frame + for frame in user_frames + if frame.filename not in uninteresting_files() + ] + + return traceback.StackSummary.from_list(user_frames) + + @compatibility(is_backward_compatible=True) + def proxy(self, node: Node) -> "Proxy": + return Proxy(node, self) + + @compatibility(is_backward_compatible=True) + def create_proxy( + self, + kind: str, + target: Target, + args: tuple[Any, ...], + kwargs: dict[str, Any], + name: Optional[str] = None, + type_expr: Optional[Any] = None, + # fix noqa when updating bc tests + proxy_factory_fn: Callable[[Node], "Proxy"] = None, # noqa: RUF013 + ): + """ + Create a Node from the given arguments, then return the Node + wrapped in a Proxy object. + + If kind = 'placeholder', then we're creating a Node that + represents the parameter of a function. If we need to encode + a default parameter, we use the ``args`` tuple. ``args`` is + otherwise empty for ``placeholder`` Nodes. 
+ """ + + args_ = self.create_arg(args) + kwargs_ = self.create_arg(kwargs) + assert isinstance(args_, tuple) + assert isinstance(kwargs_, dict) + + node = self.create_node(kind, target, args_, kwargs_, name, type_expr) + + if not proxy_factory_fn: + proxy = self.proxy(node) + else: + proxy = proxy_factory_fn(node) + + return proxy + + def _find_user_frame(self): + """ + Find the Python stack frame executing the user code during + symbolic tracing. + """ + # We have to do a little dance here. Basically, walk up the callstack and + # record the first frame not in the pytorch source. This is the frame executing + # the user code during tracing. + frame = inspect.currentframe() + + pt_files = [ + "torch/fx/proxy.py", + "torch/fx/_symbolic_trace.py", + "torch/fx/experimental/proxy_tensor.py", + "torch/_ops.py", + "torch/_tensor.py", + "torch/utils/_python_dispatch.py", + "torch/_prims_common/wrappers.py", + "torch/_refs/__init__.py", + "torch/_refs/nn/functional/__init__.py", + "torch/utils/_stats.py", + ] + while frame: + frame = frame.f_back + if frame and all( + not frame.f_code.co_filename.endswith(file) for file in pt_files + ): + break + + if not frame: + return None + + return frame + + @compatibility(is_backward_compatible=True) + def create_arg(self, a: Any) -> Argument: + """ + A method that lowers the objects seen as arguments during symbolic evaluation + into Argument types that can be stored in IR. + + Can be override to support more trace-specific types. + """ + # IMPORTANT: Are you here because you are trying to proxy a new type into + # the graph? Please Please Please contact someone on the PyTorch Compiler team; + # the considerations are subtle. + # + # 1) When you add a new type, all of the downstream consumers and pass writers + # need to handle the new type. torch.fx is intended to be easy to write + # passes for, so we will push back against new types. + # 2) In torch.compile's IR, there are only specific operations that go + # into the graph. 
In particular, Tensor operations should go into the graph, + # but non-Tensor operations shouldn't. What that means is that constructors + # for new types *SHOULD NOT* become nodes in the FX graph. + handler = _create_arg_bypass.get(type(a)) + if handler is not None: + # this is just a performance optimization and can be removed if needed + # for common types, we have a fast path to avoid isinstance() overhead + # this doesn't remove the checks below since we need to handle subclasses + return handler(self, a) + + if isinstance(a, Proxy): + return a.node # most common arg type goes first + elif hasattr(a, "__fx_create_arg__"): + return a.__fx_create_arg__(self) + # aggregates + elif isinstance(a, tuple): + if hasattr(a, "_fields"): + # NamedTuple constructors don't seem to like getting a generator + # expression as an argument to their constructor, so build this + # intermediate tuple and unpack it into the NamedTuple constructor + args = [self.create_arg(elem) for elem in a] + return type(a)(*args) # type: ignore[arg-type] + return type(a)([self.create_arg(elem) for elem in a]) + elif isinstance(a, list): + return [self.create_arg(elem) for elem in a] + elif isinstance(a, dict): + return _create_arg_dict(self, a) + elif isinstance(a, slice): + return slice( + self.create_arg(a.start), + self.create_arg(a.stop), + self.create_arg(a.step), + ) + + elif isinstance(a, range): + return range( + self.create_arg(a.start), + self.create_arg(a.stop), + self.create_arg(a.step), + ) + + elif isinstance(a, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)): + return a + + elif is_opaque_value_type(type(a)): + return a + + elif is_dataclass(a): + kwargs = { + field.name: self.create_arg(getattr(a, field.name)) + for field in fields(a) + } + return self.create_node("call_function", a.__class__, (), kwargs) + + elif isinstance(a, (*base_types, enum.Enum)) or a is None or a is ...: + return a + + raise NotImplementedError(f"argument of type: {type(a)}") + + 
@compatibility(is_backward_compatible=True) + def to_bool(self, obj: "Proxy") -> bool: + """Called when a proxy object is being converted to a boolean, such as + when used in control flow. Normally we don't know what to do because + we don't know the value of the proxy, but a custom tracer can attach more + information to the graph node using create_node and can choose to return a value. + """ + raise TraceError( + "symbolically traced variables cannot be used as inputs to control flow" + ) + + @compatibility(is_backward_compatible=True) + def iter(self, obj: "Proxy") -> Iterator: + """Called when a proxy object is being iterated over, such as + when used in control flow. Normally we don't know what to do because + we don't know the value of the proxy, but a custom tracer can attach more + information to the graph node using create_node and can choose to return an iterator. + """ + raise TraceError( + "Proxy object cannot be iterated. This can be " + "attempted when the Proxy is used in a loop or" + " as a *args or **kwargs function argument. " + "See the torch.fx docs on pytorch.org for a " + "more detailed explanation of what types of " + "control flow can be traced, and check out the" + " Proxy docstring for help troubleshooting " + "Proxy iteration errors" + ) + + @compatibility(is_backward_compatible=True) + def keys(self, obj: "Proxy") -> Any: + """Called when a proxy object is has the keys() method called. + This is what happens when ** is called on a proxy. This should return an + iterator it ** is suppose to work in your custom tracer. + """ + return Attribute(obj, "keys")() + + +# used in Proxy object when just appending to the graph while not tracing. 
+@compatibility(is_backward_compatible=True) +class GraphAppendingTracer(TracerBase): + def __init__(self, graph: Graph): + super().__init__() + self.graph = graph + self.scope = Scope("", None) + self.module_stack = collections.OrderedDict() + self.node_name_to_scope = {} + + +@compatibility(is_backward_compatible=False) +def assert_fn(x): + assert x + + +@compatibility(is_backward_compatible=True) +class TraceError(ValueError): + pass + + +@compatibility(is_backward_compatible=True) +class Proxy: + """ + ``Proxy`` objects are ``Node`` wrappers that flow through the + program during symbolic tracing and record all the operations + (``torch`` function calls, method calls, operators) that they touch + into the growing FX Graph. + + If you're doing graph transforms, you can wrap your own ``Proxy`` + method around a raw ``Node`` so that you can use the overloaded + operators to add additional things to a ``Graph``. + + ``Proxy`` objects cannot be iterated. In other words, the symbolic + tracer will throw an error if a ``Proxy`` is used in a loop or as + an ``*args``/``**kwargs`` function argument. + + There are two main ways around this: + 1. Factor out the untraceable logic into a top-level function and + use ``fx.wrap`` on it. + 2. If the control flow is static (i.e. 
the loop trip count is + based on some hyperparameter), the code can be kept in its original + position and refactored into something like:: + + for i in range(self.some_hyperparameter): + indexed_item = proxied_value[i] + + For a more detailed description into the Proxy internals, check out + the "Proxy" section in `torch/fx/README.md` + """ + + @compatibility(is_backward_compatible=True) + def __init__(self, node: Node, tracer: "Optional[TracerBase]" = None): + if tracer is None: + # This allows you to create a Proxy object around a raw Node + tracer = GraphAppendingTracer(node.graph) + self.tracer = tracer + self.node = node + + def __repr__(self) -> str: + return f"Proxy({self.node.name})" + + def __getattr__(self, k) -> "Attribute": + # note: not added to the graph yet, if this is a method call + # we peephole optimize to the method invocation + return Attribute(self, k) + + def __getstate__(self) -> dict: + return self.__dict__ + + def __deepcopy__(self, memo) -> dict: + # We have to explicitly override this method, because otherwise deepcopy + # will go to __getattr__(self, "__deepcopy__") and return a + # Attribute(__deepcopy__), and may go into an infinite loop in some cases. + import copy + + new_dict = {} + for k, v in self.__dict__.items(): + try: + new_obj = copy.deepcopy(v, memo) + except Exception: + log.warning( + "Shallow copy %s of Proxy because it cannot be deepcopied. " + "Proxy is created for node %s", + k, + self.node.name, + ) + new_obj = copy.copy(v) + new_dict[k] = new_obj + assert "node" in new_dict + assert "tracer" in new_dict + new_proxy = Proxy(new_dict["node"], new_dict["tracer"]) + for k, v in new_dict.items(): + new_proxy.__dict__[k] = v + return new_proxy + + def __setstate__(self, d): + # This is called when being unpickled/loaded. 
+ self.__dict__ = d + + def __call__(self, *args, **kwargs) -> "Proxy": + return self.tracer.create_proxy( + "call_method", "__call__", (self,) + args, kwargs + ) + + def __iter__(self) -> Iterator["Proxy"]: + frame = inspect.currentframe() + assert frame is not None + calling_frame = frame.f_back + assert calling_frame is not None + inst_list = list(dis.get_instructions(calling_frame.f_code)) + if sys.version_info >= (3, 11): + from bisect import bisect_left + + inst_idx = bisect_left( + inst_list, calling_frame.f_lasti, key=lambda x: x.offset + ) + else: + inst_idx = calling_frame.f_lasti // 2 + inst = inst_list[inst_idx] + if inst.opname == "UNPACK_SEQUENCE": + return (self[i] for i in range(inst.argval)) # type: ignore[index] + + return self.tracer.iter(self) + + def __abs__(self): + return self.tracer.create_proxy("call_function", operator.abs, (self,), {}) + + def __bool__(self) -> bool: + if self.tracer.trace_asserts: + # check if this boolean is used in an assertion, bytecode pattern for assertions + # is pretty stable for Python 3.7--3.9 + frame = inspect.currentframe() + assert frame is not None + calling_frame = frame.f_back + assert calling_frame is not None + insts = list(dis.get_instructions(calling_frame.f_code)) + if sys.version_info >= (3, 11): + from bisect import bisect_left + + cur = bisect_left(insts, calling_frame.f_lasti, key=lambda x: x.offset) + else: + cur = calling_frame.f_lasti // 2 + inst = insts[cur] + + if inst.opname == "POP_JUMP_IF_TRUE": + first = insts[cur + 1] + assert inst.arg is not None + last = insts[inst.arg // 2 - 1] + starts_with_assert = ( + first.opname == "LOAD_GLOBAL" + and first.argval == "AssertionError" + or first.opname == "LOAD_ASSERTION_ERROR" + ) + if starts_with_assert and last.opname == "RAISE_VARARGS": + self.tracer.create_proxy("call_function", assert_fn, (self,), {}) + return True + + return self.tracer.to_bool(self) + + @compatibility(is_backward_compatible=True) + def keys(self): + return 
self.tracer.keys(self) + + def __len__(self): + raise RuntimeError( + "'len' is not supported in symbolic tracing by default. If you want " + "this call to be recorded, please call torch.fx.wrap('len') at " + "module scope" + ) + + @classmethod + def __torch_function__(cls, orig_method, types, args=None, kwargs=None): + args = args if args else () + kwargs = kwargs if kwargs else {} + + tracers: dict[Any, None] = {} + + def find_tracer(a): + if isinstance(a, cls): + tracers[a.tracer] = None + + map_aggregate(args, find_tracer) + map_aggregate(kwargs, find_tracer) + + if len(tracers) > 1: + raise RuntimeError( + f"Found multiple different tracers {list(tracers.keys())} while " + f"trying to trace operations {orig_method}" + ) + tracer = next(iter(tracers.keys())) + + if isinstance(orig_method, torch._C.ScriptMethod): + args = (orig_method.owner,) + args + return tracer.create_proxy("call_method", orig_method.name, args, kwargs) + if torch.overrides.is_tensor_method_or_property(orig_method): + return tracer.create_proxy( + "call_method", orig_method.__name__, args, kwargs + ) + else: + if isinstance(orig_method, torch._ops.HigherOrderOperator): + # TODO: Define how to symbolically trace HigherOrderOperators + raise RuntimeError("Unable to symbolically trace HigherOrderOperators") + return tracer.create_proxy( + "call_function", + orig_method, + args, + kwargs, + name=tracer.graph._target_to_str(orig_method.__name__), + ) + + +@compatibility(is_backward_compatible=False) +class MetaProxy(Proxy): + """ + A Proxy subclass that propagates metadata (meta['val']) during graph tracing. 
+ """ + + def __init__( + self, node: Node, tracer: "Optional[TracerBase]" = None, fake_mode=None + ): + super().__init__(node, tracer) + self.fake_mode = fake_mode + + def __repr__(self) -> str: + return f"MetaProxy({self.node.name})" + + @classmethod + def __torch_function__(cls, orig_method, types, args=None, kwargs=None): + args = args if args else () + kwargs = kwargs if kwargs else {} + + meta_proxy = None + for arg in args: + if isinstance(arg, MetaProxy): + meta_proxy = arg + break + + assert meta_proxy is not None, ( + "No MetaProxy found in arguments, but one is expected." + ) + + proxy = super().__torch_function__(orig_method, types, args, kwargs) + with meta_proxy.fake_mode: + proxy.node.meta["val"] = orig_method( + *[a.node.meta["val"] if isinstance(a, Proxy) else a for a in args], + **kwargs, + ) + return MetaProxy(proxy.node, proxy.tracer, meta_proxy.fake_mode) + + +@compatibility(is_backward_compatible=True) +class Attribute(Proxy): + @compatibility(is_backward_compatible=True) + def __init__(self, root: Proxy, attr: str): + self.root = root + self.attr = attr + self.tracer = root.tracer + self._node: Optional[Node] = None + + @property + def node(self): + # the node for attributes is added lazily, since most will just be method calls + # which do not rely on the getitem call + if self._node is None: + self._node = self.tracer.create_proxy( + "call_function", getattr, (self.root, self.attr), {} + ).node + return self._node + + def __call__(self, *args, **kwargs): + return self.tracer.create_proxy( + "call_method", self.attr, (self.root,) + args, kwargs + ) + + +@compatibility(is_backward_compatible=False) +class ParameterProxy(Proxy): + """ + A special proxy which lets "shape", "size", "dim", and a few other + attribute accesses pass through to the underlying module parameter object, + so that conditional tests on these attributes will not throw exception during tracing + """ + + def __init__(self, tracer: TracerBase, node: Node, name, param): + 
super().__init__(node, tracer) + assert isinstance(param, torch.nn.Parameter) + self.param = param + self.name = name + + def __repr__(self) -> str: + return f"ParameterProxy({self.name})" + + @property + def shape(self): + return self.param.shape + + def size(self): + return self.param.size() + + def dim(self): + return self.param.dim() + + @property + def ndim(self): + return self.param.ndim + + def numel(self): + return self.param.numel() + + def nelement(self): + return self.param.nelement() + + +for method in magic_methods: + + def _scope(method): + def impl(*args, **kwargs): + tracer = args[0].tracer + target = getattr(operator, method) + return tracer.create_proxy("call_function", target, args, kwargs) + + impl.__name__ = method + as_magic = f"__{method.strip('_')}__" + setattr(Proxy, as_magic, impl) + + _scope(method) + + +def _define_reflectable(orig_method_name): + method_name = f"__r{orig_method_name.strip('_')}__" + + def impl(self, rhs): + target = getattr(operator, orig_method_name) + return self.tracer.create_proxy("call_function", target, (rhs, self), {}) + + impl.__name__ = method_name + impl.__qualname__ = method_name + setattr(Proxy, method_name, impl) + + +for orig_method_name in reflectable_magic_methods: + _define_reflectable(orig_method_name) + + +def _no_nodes_error(arg): + raise RuntimeError( + "Keys for dictionaries used as an argument cannot contain a " + f"Node. Got key: {arg}" + ) + + +def _create_arg_dict(self, a): + r = {} + for k, v in a.items(): + if not isinstance(k, str): + # Check for invalid dict keys. We do not want a Proxy to appear + # anywhere within the key. 
Since keys can be collection types, + # we iterate through the key with map_arg + k = self.create_arg(k) + map_arg(k, _no_nodes_error) + r[k] = self.create_arg(v) + return r + + +_create_arg_bypass = { + t: lambda self, a: a + for t in [ + *base_types, + type(None), + type(...), + torch._ops.OpOverload, + torch._ops.HigherOrderOperator, + ] +} +_create_arg_bypass[Proxy] = lambda self, a: a.node +_create_arg_bypass[tuple] = lambda self, a: tuple(self.create_arg(elem) for elem in a) +_create_arg_bypass[list] = lambda self, a: [self.create_arg(elem) for elem in a] +_create_arg_bypass[dict] = _create_arg_dict +_create_arg_bypass[immutable_list] = _create_arg_bypass[list] +_create_arg_bypass[immutable_dict] = _create_arg_bypass[dict] diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/subgraph_rewriter.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/subgraph_rewriter.py new file mode 100644 index 0000000000000000000000000000000000000000..2253da19d36427d4059926b71471744a841c62eb --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/subgraph_rewriter.py @@ -0,0 +1,440 @@ +import copy +from collections.abc import Callable +from dataclasses import dataclass +from typing import Any, NamedTuple, Optional, TYPE_CHECKING, Union + +import torch + +from ._compatibility import compatibility +from ._symbolic_trace import symbolic_trace +from .graph import Graph +from .graph_module import GraphModule +from .node import Node + + +if TYPE_CHECKING: + from .passes.utils.matcher_with_name_node_map_utils import InternalMatch + +__all__ = [ + "Match", + "replace_pattern", + "replace_pattern_with_filters", + "ReplacedPatterns", +] + + +@compatibility(is_backward_compatible=True) +class Match(NamedTuple): + # Node from which the match was found + anchor: Node + # Maps nodes in the pattern subgraph to nodes in the larger graph + nodes_map: dict[Node, Node] + + +@compatibility(is_backward_compatible=False) 
+@dataclass +class ReplacedPatterns: + # Node from which the match was found + anchor: Node + # Maps nodes in the pattern subgraph to nodes in the larger graph + nodes_map: dict[Node, Node] + # List of nodes that were added into the graph + replacements: list[Node] + + +def _replace_attributes(gm: GraphModule, replacement: torch.nn.Module) -> None: + gm.delete_all_unused_submodules() + + if isinstance(replacement, GraphModule): + replacement.graph.lint() + + def try_get_attr(gm: torch.nn.Module, target: str) -> Optional[Any]: + module_path, _, attr_name = target.rpartition(".") + try: + mod: torch.nn.Module = gm.get_submodule(module_path) + except AttributeError: + return None + attr = getattr(mod, attr_name, None) + return attr + + for node in gm.graph.nodes: + if node.op == "call_module" or node.op == "get_attr": + gm_attr = try_get_attr(gm, node.target) + replacement_attr = try_get_attr(replacement, node.target) + + # CASE 1: This target already exists as an attribute in our + # result GraphModule. Whether or not it exists in + # `replacement`, the existing submodule takes precedence. + if gm_attr is not None: + continue + + # CASE 2: The target exists as an attribute in `replacement` + # only, so we need to copy it over. 
+ elif replacement_attr is not None: + new_attr = copy.deepcopy(replacement_attr) + if isinstance(replacement_attr, torch.nn.Module): + gm.add_submodule(node.target, new_attr) + else: + setattr(gm, node.target, new_attr) + + # CASE 3: The target doesn't exist as an attribute in `gm` + # or `replacement` + else: + raise RuntimeError( + 'Attempted to create a "', + node.op, + '" node during subgraph rewriting ' + f"with target {node.target}, but " + "the referenced attribute does not " + "exist in the replacement GraphModule", + ) + + gm.graph.lint() + + +@compatibility(is_backward_compatible=True) +def replace_pattern( + gm: GraphModule, + pattern: Union[Callable, GraphModule], + replacement: Union[Callable, GraphModule], +) -> list[Match]: + """ + Matches all possible non-overlapping sets of operators and their + data dependencies (``pattern``) in the Graph of a GraphModule + (``gm``), then replaces each of these matched subgraphs with another + subgraph (``replacement``). + + Args: + ``gm``: The GraphModule that wraps the Graph to operate on + ``pattern``: The subgraph to match in ``gm`` for replacement + ``replacement``: The subgraph to replace ``pattern`` with + + Returns: + List[Match]: A list of ``Match`` objects representing the places + in the original graph that ``pattern`` was matched to. The list + is empty if there are no matches. ``Match`` is defined as: + + .. code-block:: python + + class Match(NamedTuple): + # Node from which the match was found + anchor: Node + # Maps nodes in the pattern subgraph to nodes in the larger graph + nodes_map: Dict[Node, Node] + + Examples: + + .. 
code-block:: python + + import torch + from torch.fx import symbolic_trace, subgraph_rewriter + + + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, x, w1, w2): + m1 = torch.cat([w1, w2]).sum() + m2 = torch.cat([w1, w2]).sum() + return x + torch.max(m1) + torch.max(m2) + + + def pattern(w1, w2): + return torch.cat([w1, w2]) + + + def replacement(w1, w2): + return torch.stack([w1, w2]) + + + traced_module = symbolic_trace(M()) + + subgraph_rewriter.replace_pattern(traced_module, pattern, replacement) + + The above code will first match ``pattern`` in the ``forward`` + method of ``traced_module``. Pattern-matching is done based on + use-def relationships, not node names. For example, if you had + ``p = torch.cat([a, b])`` in ``pattern``, you could match + ``m = torch.cat([a, b])`` in the original ``forward`` function, + despite the variable names being different (``p`` vs ``m``). + + The ``return`` statement in ``pattern`` is matched based on its + value only; it may or may not match to the ``return`` statement in + the larger graph. In other words, the pattern doesn't have to extend + to the end of the larger graph. + + When the pattern is matched, it will be removed from the larger + function and replaced by ``replacement``. If there are multiple + matches for ``pattern`` in the larger function, each non-overlapping + match will be replaced. In the case of a match overlap, the first + found match in the set of overlapping matches will be replaced. + ("First" here being defined as the first in a topological ordering + of the Nodes' use-def relationships. In most cases, the first Node + is the parameter that appears directly after ``self``, while the + last Node is whatever the function returns.) + + One important thing to note is that the parameters of the + ``pattern`` Callable must be used in the Callable itself, + and the parameters of the ``replacement`` Callable must match + the pattern. 
The first rule is why, in the above code block, the + ``forward`` function has parameters ``x, w1, w2``, but the + ``pattern`` function only has parameters ``w1, w2``. ``pattern`` + doesn't use ``x``, so it shouldn't specify ``x`` as a parameter. + As an example of the second rule, consider replacing + + .. code-block:: python + + def pattern(x, y): + return torch.neg(x) + torch.relu(y) + + with + + .. code-block:: python + + def replacement(x, y): + return torch.relu(x) + + In this case, ``replacement`` needs the same number of parameters + as ``pattern`` (both ``x`` and ``y``), even though the parameter + ``y`` isn't used in ``replacement``. + + After calling ``subgraph_rewriter.replace_pattern``, the generated + Python code looks like this: + + .. code-block:: python + + def forward(self, x, w1, w2): + stack_1 = torch.stack([w1, w2]) + sum_1 = stack_1.sum() + stack_2 = torch.stack([w1, w2]) + sum_2 = stack_2.sum() + max_1 = torch.max(sum_1) + add_1 = x + max_1 + max_2 = torch.max(sum_2) + add_2 = add_1 + max_2 + return add_2 + """ + match_and_replacements = _replace_pattern(gm, pattern, replacement) + return [ + Match(anchor=m.anchor, nodes_map=m.nodes_map) for m in match_and_replacements + ] + + +# Experimental API, not backward compatible +@compatibility(is_backward_compatible=False) +def replace_pattern_with_filters( + gm: GraphModule, + pattern: Union[Callable, Graph, GraphModule], + replacement: Union[Callable, Graph, GraphModule, None] = None, + match_filters: Optional[ + list[Callable[["InternalMatch", Graph, Graph], bool]] + ] = None, + ignore_literals: bool = False, + # Placed at the end to avoid breaking backward compatibility + replacement_callback: Optional[ + Callable[["InternalMatch", Graph, Graph], Graph] + ] = None, + node_name_match: str = "", +) -> list[ReplacedPatterns]: + """ + See replace_pattern for documentation. This function is an overload with an additional match_filter argument. 
+ + Args: + ``match_filters``: A list of functions that take in + (match: InternalMatch, original_graph: Graph, pattern_graph: Graph) and return a boolean indicating + whether the match satisfies the condition. + See matcher_utils.py for definition of InternalMatch. + ``replacement_callback``: A function that takes in a match and returns a + Graph to be used as the replacement. This allows you to construct a + replacement graph based on the match. + ``replacement_callback``: Node name to match. If not empty, it will try to match the node name. + """ + + return _replace_pattern( + gm, + pattern, + replacement, + match_filters, + ignore_literals, + replacement_callback, + node_name_match, + ) + + +def _replace_pattern( + gm: GraphModule, + pattern: Union[Callable, Graph, GraphModule], + replacement: Union[Callable, Graph, GraphModule, None] = None, + match_filters: Optional[ + list[Callable[["InternalMatch", Graph, Graph], bool]] + ] = None, + ignore_literals: bool = False, + # Placed at the end to avoid breaking backward compatibility + replacement_callback: Optional[ + Callable[["InternalMatch", Graph, Graph], Graph] + ] = None, + node_name_match: str = "", +) -> list[ReplacedPatterns]: + from torch.fx.passes.utils.matcher_utils import InternalMatch, SubgraphMatcher + + if match_filters is None: + match_filters = [] + + # Get the graphs for `gm`, `pattern`, `replacement` + original_graph: Graph = gm.graph + + if isinstance(pattern, GraphModule): + pattern_graph = pattern.graph + elif isinstance(pattern, Graph): + pattern_graph = pattern + else: + pattern_graph = symbolic_trace(pattern).graph # type: ignore[arg-type] + + matcher = SubgraphMatcher( + pattern_graph, + match_output=False, + match_placeholder=False, + remove_overlapping_matches=True, + ignore_literals=ignore_literals, + ) + _matches: list[InternalMatch] = matcher.match( + original_graph, node_name_match=node_name_match + ) + + # Filter out matches that don't match the filter + _matches = [ + m + for m 
in _matches + if all( + match_filter(m, original_graph, pattern_graph) + for match_filter in match_filters + ) + ] + + if isinstance(replacement, GraphModule): + common_replacement_graph = replacement.graph + elif isinstance(replacement, Graph): + common_replacement_graph = replacement + elif callable(replacement): + common_replacement_graph = symbolic_trace(replacement).graph + else: + assert replacement_callback is not None, ( + "Must provide either a replacement GraphModule or a replacement callback" + ) + common_replacement_graph = None # type: ignore[assignment] + + # As we progressively replace nodes, we'll need to keep track of how the match results should change + match_changed_node: dict[Node, Node] = {} + + match_and_replacements = [] + for match in _matches: + if replacement_callback is not None: + replacement_graph = replacement_callback( + match, original_graph, pattern_graph + ) + else: + assert common_replacement_graph is not None, ( + "Must provide either a replacement GraphModule or a replacement callback" + ) + replacement_graph = common_replacement_graph + replacement_placeholders = [ + n for n in replacement_graph.nodes if n.op == "placeholder" + ] + + # Build connecting between replacement graph's input and original graph input producer node + + # Initialize `val_map` with mappings from placeholder nodes in + # `replacement` to their corresponding node in `original_graph` + assert len(match.placeholder_nodes) == len(replacement_placeholders) + val_map: dict[Node, Node] = {} + for rn, gn in zip(replacement_placeholders, match.placeholder_nodes): + if isinstance(gn, Node): + val_map[rn] = match_changed_node.get(gn, gn) + if gn != val_map[rn]: + # Update match.placeholder_nodes and match.nodes_map with the node that replaced gn + gn_ind = match.placeholder_nodes.index(gn) + match.placeholder_nodes[gn_ind] = match_changed_node[gn] + map_key = list(match.nodes_map.keys())[ + list(match.nodes_map.values()).index(gn) + ] + match.nodes_map[map_key] = 
match_changed_node[gn] + else: + val_map[rn] = gn + + # Copy the replacement graph over + user_nodes: set[Node] = set() + for n in match.returning_nodes: + user_nodes.update(n.users) + + first_user_node = None + if len(user_nodes) == 0: + first_user_node = None + elif len(user_nodes) == 1: + first_user_node = next(iter(user_nodes)) + else: + # If there are multiple user nodes, we need to find the first user node + # in the current execution order of the `original_graph` + for n in original_graph.nodes: + if n in user_nodes: + first_user_node = n + break + + first_next_node = None + if first_user_node is None: + # no users, so we insert the replacement graph before the first next + # node of returning nodes + next_node = None + for n in reversed(original_graph.nodes): + if n in match.returning_nodes: + first_next_node = next_node + break + else: + next_node = n + insert_point = ( + first_user_node if first_user_node is not None else first_next_node + ) + assert insert_point is not None, "The insert point can't be None" + with original_graph.inserting_before(insert_point): + copied_returning_nodes = original_graph.graph_copy( + replacement_graph, val_map + ) + + if isinstance(copied_returning_nodes, Node): + copied_returning_nodes = (copied_returning_nodes,) + + # Get a list of nodes that have been replaced into the graph + replacement_nodes: list[Node] = [ + v for v in val_map.values() if v not in match.placeholder_nodes + ] + + # Hook the output Node of the replacement subgraph in to the + # original Graph at the correct location + assert len(match.returning_nodes) == len(copied_returning_nodes) # type: ignore[arg-type] + for gn, copied_node in zip(match.returning_nodes, copied_returning_nodes): # type: ignore[arg-type] + gn.replace_all_uses_with(copied_node) + match_changed_node[gn] = copied_node + # Remove the original nodes + for node in reversed(pattern_graph.nodes): + if node.op != "placeholder" and node.op != "output": + gn = match.nodes_map[node] + 
gm.graph.erase_node(gn) + + match_and_replacements.append( + ReplacedPatterns( + anchor=match.anchors[0], + nodes_map=match.nodes_map, + replacements=replacement_nodes, + ) + ) + + # Update the passed-in GraphModule to reflect the new state of + # `original_graph` + gm.recompile() + + # If `replacement` was an nn.Module, we'll need to make sure that + # all the submodules have been copied over correctly + if isinstance(replacement, torch.nn.Module): + _replace_attributes(gm, replacement) + + return match_and_replacements diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/tensor_type.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/tensor_type.py new file mode 100644 index 0000000000000000000000000000000000000000..4f375e461ef288341c5cd22e1e1ec2a851680b4c --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/tensor_type.py @@ -0,0 +1,111 @@ +# mypy: allow-untyped-defs +from torch.fx.experimental.unification import Var # type: ignore[attr-defined] + +from ._compatibility import compatibility + + +@compatibility(is_backward_compatible=False) +class TensorType: + """ + TensorType defines a type for tensors, which consists of a list of dimensions. + Example: + class M(torch.nn.Module): + def forward(self, x:TensorType((1,2,3, Dyn)), y:TensorType((1,2,3, Dyn))): + return torch.add(x, y) + """ + + def __init__(self, dim): + self.__origin__ = TensorType + self.__args__ = dim + + def __repr__(self): + return f"TensorType[{self.__args__}]" + + def __eq__(self, other): + if isinstance(other, self.__class__): + return list(self.__args__) == list(other.__args__) + else: + return False + + @staticmethod + def __class_getitem__(*args): + if len(args) == 1 and isinstance(args[0], tuple): + args = args[0] + return TensorType(tuple(args)) + + +class _DynType: + """ + _DynType defines a type which stands for the absence of type information. 
+ """ + + def __init__(self) -> None: + self.__name__ = "_DynType" + + def __eq__(self, other): + return isinstance(other, self.__class__) + + def __str__(self): + return "Dyn" + + def __repr__(self): + return "Dyn" + + +Dyn = _DynType() + + +@compatibility(is_backward_compatible=False) +def is_consistent(t1, t2): + """ + A binary relation denoted by ~ that determines if t1 is consistent with t2. + The relation is reflexive, symmetric but not transitive. + returns True if t1 and t2 are consistent and False otherwise. + Example: + Dyn ~ TensorType((1,2,3)) + int ~ Dyn + int ~ int + TensorType((1,Dyn,3)) ~ TensorType((1,2,3)) + """ + + if t1 == t2: + return True + + if t1 == Dyn or t2 == Dyn or isinstance(t1, Var) or isinstance(t2, Var): + return True + + if isinstance(t1, TensorType) and isinstance(t2, TensorType): + return len(t1.__args__) == len(t2.__args__) and all( + is_consistent(elem1, elem2) + for elem1, elem2 in zip(t1.__args__, t2.__args__) + ) + else: + return False + + +@compatibility(is_backward_compatible=False) +def is_more_precise(t1, t2): + """ + A binary relation denoted by <= that determines if t1 is more precise than t2. + The relation is reflexive and transitive. + returns True if t1 is more precise than t2 and False otherwise. 
+ Example: + Dyn >= TensorType((1,2,3)) + int >= Dyn + int >= int + TensorType((1,Dyn,3)) <= TensorType((1,2,3)) + """ + if t1 == t2: + return True + + if isinstance(t2, _DynType): + return True + + if isinstance(t1, TensorType) and isinstance(t2, TensorType): + return len(t1.__args__) == len(t2.__args__) and all( + is_more_precise(elem1, elem2) + for elem1, elem2 in zip(t1.__args__, t2.__args__) + ) + + else: + return False diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/traceback.py b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/traceback.py new file mode 100644 index 0000000000000000000000000000000000000000..b78ef313f24f5ff67391103468109685c4103ce7 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/fx/traceback.py @@ -0,0 +1,502 @@ +# mypy: allow-untyped-defs +import copy +import logging +import traceback +from contextlib import contextmanager +from enum import Enum +from typing import Any, Optional, Union + +from torch._utils_internal import signpost_event + +from ._compatibility import compatibility +from .graph import Graph +from .graph_module import GraphModule +from .node import Node + + +log = logging.getLogger(__name__) + +__all__ = [ + "annotate", + "annotate_fn", + "preserve_node_meta", + "has_preserved_node_meta", + "set_stack_trace", + "set_grad_fn_seq_nr", + "reset_grad_fn_seq_nr", + "format_stack", + "set_current_meta", + "get_current_meta", + "NodeSource", + "NodeSourceAction", + "get_graph_provenance_json", + "set_current_replay_node", + "get_current_replay_node", +] + +current_meta: dict[str, Any] = {} +current_replay_node: Optional[Node] = None +should_preserve_node_meta = False + +GRADIENT_ACC_SPECIAL_STACK = ( + "Gradient addition node due to multiple use of tensor around:" +) +# ============================================================================= +# FX Metadata Registry for Memory Profiler +# 
============================================================================= +# Global in-memory registry for FX metadata +# Maps module_name -> metadata dict containing lineno_map and node_metadata +_FX_METADATA_REGISTRY: dict[str, dict[str, Any]] = {} + + +def _register_fx_metadata(module_name: str, metadata: dict[str, Any]) -> None: + """ + Register FX metadata in the global in-memory registry. + + This is called automatically during graph module compilation to store metadata + for later use by memory profiler augmentation. + + Args: + module_name: The module identifier (content-addressed filename) + metadata: Metadata dict containing lineno_map, node_metadata, and source_code + """ + # TODO: add logging to tlparse + _FX_METADATA_REGISTRY[module_name] = metadata + + +@compatibility(is_backward_compatible=False) +class NodeSourceAction(Enum): + CREATE = "create" + REPLACE = "replace" + + +@compatibility(is_backward_compatible=False) +class NodeSource: + """ + NodeSource is a data structure that contains the provenance information of a node. + If node `a` is created from node `b`, then `a.meta["from_node"]` may contain NodeSource(b). 
+ """ + + class NodeInfo: + def __init__(self, name: str, target: str, graph_id: int): + self.name = name + self.target = target + self.graph_id = graph_id + + pass_name: str + action: list["NodeSourceAction"] + from_node: list["NodeSource"] + node_info: Optional["NodeInfo"] + _dict: Optional[dict[str, Any]] + _action_string: Optional[str] + + def __init__( + self, + node: Optional[Node], + pass_name: str = "", + action: Optional[Union["NodeSourceAction", list["NodeSourceAction"]]] = None, + ): + self.pass_name = pass_name + + if action is None: + action = [] + elif not isinstance(action, list): + action = [action] + for a in action: + assert isinstance(a, NodeSourceAction) + self.action = action + if node: + self.node_info = self.NodeInfo( + name=node.name, target=str(node.target), graph_id=id(node.graph) + ) + self.from_node = ( + copy.deepcopy(node.meta["from_node"]) + if "from_node" in node.meta + else [] + ) + else: + self.node_info = None + self.from_node = [] + + # cache the action string and dict representation for performance. 
+ self._action_string: Optional[str] = None + self._dict: Optional[dict[str, Any]] = None + + @property + def name(self) -> str: + return self.node_info.name if self.node_info else "" + + @property + def target(self) -> str: + return self.node_info.target if self.node_info else "" + + @property + def graph_id(self) -> int: + return self.node_info.graph_id if self.node_info else -1 + + def __repr__(self): + return self.print_readable() + + def _get_action_string(self): + if self._action_string is None: + self._action_string = "+".join([a.name.lower() for a in self.action]) + return self._action_string + + def print_readable(self, indent=0): + if indent > 9: + return "" + result = "" + action_string = self._get_action_string() + result += ( + " " * indent * 4 + + f"(name={self.name}, pass_name={self.pass_name}, action={action_string}, graph_id={self.graph_id})\n" + ) + for item in self.from_node: + result += item.print_readable(indent + 1) + return result + + def to_dict(self) -> dict: + if self._dict is None: + # Convert the object to a dictionary + action_string = self._get_action_string() + self._dict = { + "name": self.name, + "target": self.target, + "graph_id": self.graph_id, + "pass_name": self.pass_name, + "action": action_string, + "from_node": [node.to_dict() for node in self.from_node], + } + + assert self._dict is not None + return self._dict + + def __eq__(self, other: object): + if not isinstance(other, NodeSource): + return False + return self.to_dict() == other.to_dict() + + def __hash__(self): + # Create a hash based on the dictionary representation + # We need to convert the dict to a hashable form + def _make_hashable(obj): + if isinstance(obj, dict): + return tuple(sorted((k, _make_hashable(v)) for k, v in obj.items())) + elif isinstance(obj, list): + return tuple(_make_hashable(item) for item in obj) + else: + return obj + + return hash(_make_hashable(self.to_dict())) + + @classmethod + def _from_dict(cls, d: Optional[dict]) -> 
Optional["NodeSource"]: + """ + Recursively deserialize from_node metadata from dictionary data. + It is used to deserialize the from_node field from serialized metadata. + Please use constructor NodeSource(node, ...) to create a NodeSource object. + """ + if d is None: + return None + + assert isinstance(d, dict), f"Expected a dict, got {type(d)}" + + # Create a NodeSource object directly without going through the constructor + # to avoid issues with graph ID and node creation + node_source = NodeSource.__new__(NodeSource) + + # Reset the cached properties + node_source._action_string = None + node_source._dict = None + + # Set the basic attributes + node_source.pass_name = d.get("pass_name", "") + + # Parse action string back to NodeSourceAction enum list + action_str = d.get("action", "") + actions = [] + if action_str: + for action_name in action_str.split("+"): + if action_name.upper() == "CREATE": + actions.append(NodeSourceAction.CREATE) + elif action_name.upper() == "REPLACE": + actions.append(NodeSourceAction.REPLACE) + node_source.action = actions + + # Create the NodeInfo object directly + if "name" in d and "target" in d and "graph_id" in d: + node_info = NodeSource.NodeInfo( + d.get("name", ""), d.get("target", ""), d.get("graph_id", -1) + ) + node_source.node_info = node_info + else: + node_source.node_info = None + + # Recursively deserialize nested from_node + if d.get("from_node", None) is not None: + node_source.from_node = [ + result + for fn in d.get("from_node", []) + if (result := cls._from_dict(fn)) is not None + ] + else: + node_source.from_node = [] + return node_source + + +@compatibility(is_backward_compatible=False) +@contextmanager +def preserve_node_meta(enable=True): + global should_preserve_node_meta + global current_meta + saved_should_preserve_node_meta = should_preserve_node_meta + # Shallow copy is OK since fields of current_meta are not mutated + saved_current_meta = current_meta.copy() + try: + should_preserve_node_meta = 
enable + yield + finally: + should_preserve_node_meta = saved_should_preserve_node_meta + current_meta = saved_current_meta + + +@compatibility(is_backward_compatible=False) +def set_stack_trace(stack: list[str]): + global current_meta + + if should_preserve_node_meta and stack: + current_meta["stack_trace"] = "".join(stack) + + +@compatibility(is_backward_compatible=False) +@contextmanager +def annotate(annotation_dict: dict): + """ + Temporarily adds custom annotations to the current tracing context. + The fx_node produced from this tracing context will have the + custom annotations in node.metadata["custom"] field. + + This context manager allows you to insert arbitrary metadata into the PT2 + tracing system by updating the global `current_meta["custom"]` dictionary. + The annotations are automatically reverted after the context exits. + + Gradient accumulation nodes will not be annotated. + + This is intended for advanced users who need to attach additional metadata to the fx nodes + (e.g., for debugging, analysis, or external tooling) during export tracing. + + Note: + This API is **not backward compatible** and may evolve in future releases. + + Note: + This API is not compatible with fx.symbolic_trace or jit.trace. It's intended + to be used with PT2 family of tracers, e.g. torch.export and dynamo. + + Args: + annotation_dict (dict): A dictionary of custom key-value pairs to inject + into the FX trace metadata. + + Example: + After exiting the context, custom annotations are removed. + + >>> with annotate({"source": "custom_pass", "tag": 42}): + ... 
pass # Your computation here + """ + + global current_meta + + has_custom = "custom" in current_meta + old_custom = copy.copy(current_meta.get("custom", {})) + + try: + if not has_custom: + current_meta["custom"] = {} + + # Update with all key-value pairs from the input dict + current_meta["custom"].update(annotation_dict) + yield + finally: + if has_custom: + # Restore the original custom dict + current_meta["custom"] = old_custom + else: + del current_meta["custom"] + + +@compatibility(is_backward_compatible=False) +def annotate_fn(annotation_dict: dict): + """ + A decorator that wraps a function with the annotate context manager. + Use this when you want to annotate an entire function instead of a specific code block. + + Note: + This API is **not backward compatible** and may evolve in future releases. + + Note: + This API is not compatible with fx.symbolic_trace or jit.trace. It's intended + to be used with PT2 family of tracers, e.g. torch.export and dynamo. + + Args: + annotation_dict (dict): A dictionary of custom key-value pairs to inject + into the FX trace metadata for all operations in the function. + + Example: + All operations in my_function will have {"pp_stage": 1} in their metadata. + + >>> @annotate_fn({"pp_stage": 1}) + ... def my_function(x): + ... 
return x + 1 + """ + from functools import wraps + + def decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + with annotate(annotation_dict): + return func(*args, **kwargs) + + return wrapper + + return decorator + + +@compatibility(is_backward_compatible=False) +def set_grad_fn_seq_nr(seq_nr): + global current_meta + + if should_preserve_node_meta: + # The seq_nr is captured by eager mode in the grad_fn during forward + current_meta["grad_fn_seq_nr"] = current_meta.get("grad_fn_seq_nr", []) + [ + seq_nr + ] + current_meta["in_grad_fn"] = current_meta.get("in_grad_fn", 0) + 1 + + +@compatibility(is_backward_compatible=False) +def reset_grad_fn_seq_nr(): + # NB: reset state properly, this would be helpful towards supporting + # reentrant autograd if we actually wanted to do that. + global current_meta + if should_preserve_node_meta: + current_level = current_meta.get("in_grad_fn", 0) + assert current_level > 0 + if current_level == 1: + del current_meta["in_grad_fn"] + del current_meta["grad_fn_seq_nr"] + else: + current_meta["in_grad_fn"] = current_level - 1 + current_meta["grad_fn_seq_nr"] = current_meta["grad_fn_seq_nr"][:-1] + + +@compatibility(is_backward_compatible=False) +def format_stack() -> list[str]: + if should_preserve_node_meta: + return [current_meta.get("stack_trace", "")] + else: + # fallback to traceback.format_stack() + return traceback.format_list(traceback.extract_stack()[:-1]) + + +@compatibility(is_backward_compatible=False) +def has_preserved_node_meta() -> bool: + return should_preserve_node_meta + + +@compatibility(is_backward_compatible=False) +@contextmanager +def set_current_meta(node, pass_name=""): + global current_meta + if should_preserve_node_meta and node.meta: + saved_meta = current_meta + try: + current_meta = node.meta.copy() + + # Update the "from_node" field in current_meta for provenance tracking. + # Instead of appending, overwrite the "from_node" field because current_meta + # will be assigned to the new node. 
The new NodeSource(node, ...) will + # include the information from the previous current_meta["from_node"]. + current_meta["from_node"] = [ + NodeSource(node, pass_name, NodeSourceAction.CREATE) + ] + yield + finally: + current_meta = saved_meta + else: + yield + + +@compatibility(is_backward_compatible=False) +def get_current_meta() -> dict[str, Any]: + return current_meta + + +@compatibility(is_backward_compatible=False) +@contextmanager +def set_current_replay_node(node): + """ + Set the currently replay node. If `current_replay_node` is not None, + then we're re-generating the `current_replay_node` in FunctionalTensorMode. + """ + # See [Note] annotation for more details. + global current_replay_node + saved_current_replay_node = current_replay_node + try: + current_replay_node = node + yield + finally: + current_replay_node = saved_current_replay_node + + +@compatibility(is_backward_compatible=False) +def get_current_replay_node(): + """ + Get the currently replay node + """ + return current_replay_node + + +@compatibility(is_backward_compatible=False) +def get_graph_provenance_json(graph: Graph) -> dict[str, Any]: + """ + Given an fx.Graph, return a json that contains the provenance information of each node. 
+ """ + try: + provenance_tracking_json = {} + for node in graph.nodes: + if node.op == "call_function": + provenance_tracking_json[node.name] = ( + [source.to_dict() for source in node.meta["from_node"]] + if "from_node" in node.meta + else [] + ) + return provenance_tracking_json + except Exception as e: + # Since this is just debugging, it should never interfere with regular + # program execution, so we use this try-except to guard against any error + signpost_event( + "inductor", + "provenance_tracking_error", + { + "function": "get_graph_provenance_json", + "error_msg": str(e), + "stack_trace": traceback.format_exc(), + }, + ) + return {} + + +def _get_custom_metadata(gm: GraphModule) -> str: + assert isinstance(gm, GraphModule) + + def helper(gm: GraphModule): + custom_metadata = [] + for node in gm.graph.nodes: + if hasattr(node, "meta") and node.meta.get("custom", None): + custom_metadata.append((node.op, node.name, node.meta["custom"])) + if node.op == "get_attr" and isinstance( + getattr(gm, node.target), GraphModule + ): + custom_metadata.append(helper(getattr(gm, node.target))) + return custom_metadata + + return "\n".join(str(x) for x in helper(gm)) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/ATen.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/ATen.h new file mode 100644 index 0000000000000000000000000000000000000000..13b8d6a9aaa7109955133ac50a5b5a9af9deb113 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/ATen.h @@ -0,0 +1,42 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#if !defined(_MSC_VER) && __cplusplus < 201703L +#error C++17 or later compatible compiler is required to use ATen. 
+#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// TODO: try to remove this +// There is some back story, see https://github.com/pytorch/pytorch/issues/48684 +#include + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/AccumulateType.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/AccumulateType.h new file mode 100644 index 0000000000000000000000000000000000000000..f4196e6e845ff1060c999823fcf179a3137e8480 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/AccumulateType.h @@ -0,0 +1,178 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// Defines the accumulation type for a scalar type. +// Example: +// using accscalar_t = acc_type; +// +// Accumulation types are an important concept in numeric computing +// because you frequently want to perform intermediate computations +// at a higher precision than the input and output precision, to avoid +// compounding internal rounding errors. Accumulation is the most +// well-known intermediate computation (it is of great importance for +// sum reduction and matrix multiply, for example), but in PyTorch +// acc_type ends up getting used for all sorts of other intermediate +// computations, so it perhaps would be more accurately (ahem) called an +// "accurate" type. 
acc_type is especially important for reduced +// precision operations like float16 and bfloat16, where relatively +// benign looking inputs can easily end up overflowing/underflowing. +// +// acc_type is parametrized by whether or not you are running on CUDA +// or not, because on CUDA double precision operations are expensive +// and so by default, we don't actually want to use double as an +// acc_type on CUDA. A lot of things are typed out below, but +// basically, the table is generated by a few rules: +// +// If bool: +// Use 'bool' as acc_type. +// If floating point: +// If CUDA, use 'float' as acc_type (unless scalar_t is double), +// otherwise (CPU) use 'double' +// If integral: +// Use 'int64_t' as acc_type +// +// You're not forced to use this template; if you happen to know +// something specific about your use case, you can specify your own +// desired behavior. This template, however, will give you a reasonable +// default that will work for all dtypes supported in PyTorch. 
+ +#if defined(__CUDACC__) +#include +#include +#elif defined(__HIPCC__) +#include +#include +#endif + +namespace at { + +template +struct AccumulateTypeDevice {}; + +template +struct AccumulateType {}; + +template +struct AccumulateType { + using type = typename AccumulateTypeDevice::type; +}; + +template +struct AccumulateType { + using type = typename AccumulateTypeDevice::type; +}; + +template +using acc_type_device = typename AccumulateTypeDevice::type; + +template +using acc_type = typename AccumulateType::type; + +#define ACC_TYPE(t, acc_t, device_type) \ + template <> \ + struct AccumulateTypeDevice { \ + using type = acc_t; \ + }; +#define MPS_ACC_TYPE(t, acc_t) ACC_TYPE(t, acc_t, c10::DeviceType::MPS) +#define XPU_ACC_TYPE(t, acc_t) ACC_TYPE(t, acc_t, c10::DeviceType::XPU) +#define CUDA_ACC_TYPE(t, acc_t) ACC_TYPE(t, acc_t, c10::DeviceType::CUDA) +#define CPU_ACC_TYPE(t, acc_t) ACC_TYPE(t, acc_t, c10::DeviceType::CPU) + +MPS_ACC_TYPE(BFloat16, float) +MPS_ACC_TYPE(Half, float) +MPS_ACC_TYPE(Float8_e5m2, float) +MPS_ACC_TYPE(Float8_e4m3fn, float) +MPS_ACC_TYPE(Float8_e5m2fnuz, float) +MPS_ACC_TYPE(Float8_e4m3fnuz, float) +MPS_ACC_TYPE(float, float) +MPS_ACC_TYPE(double, float) +MPS_ACC_TYPE(int8_t, int64_t) +MPS_ACC_TYPE(uint8_t, int64_t) +MPS_ACC_TYPE(char, int64_t) +MPS_ACC_TYPE(int16_t, int64_t) +MPS_ACC_TYPE(int32_t, int64_t) +MPS_ACC_TYPE(int64_t, int64_t) +MPS_ACC_TYPE(bool, bool) +MPS_ACC_TYPE(c10::complex, c10::complex) +MPS_ACC_TYPE(c10::complex, c10::complex) +MPS_ACC_TYPE(c10::complex, c10::complex) + +XPU_ACC_TYPE(BFloat16, float) +XPU_ACC_TYPE(Half, float) +XPU_ACC_TYPE(Float8_e5m2, float) +XPU_ACC_TYPE(Float8_e4m3fn, float) +XPU_ACC_TYPE(Float8_e5m2fnuz, float) +XPU_ACC_TYPE(Float8_e4m3fnuz, float) +XPU_ACC_TYPE(float, float) +XPU_ACC_TYPE(double, double) +XPU_ACC_TYPE(int8_t, int64_t) +XPU_ACC_TYPE(uint8_t, int64_t) +XPU_ACC_TYPE(char, int64_t) +XPU_ACC_TYPE(int16_t, int64_t) +XPU_ACC_TYPE(int32_t, int64_t) +XPU_ACC_TYPE(int64_t, int64_t) 
+XPU_ACC_TYPE(bool, bool) +XPU_ACC_TYPE(c10::complex, c10::complex) +XPU_ACC_TYPE(c10::complex, c10::complex) +XPU_ACC_TYPE(c10::complex, c10::complex) + +#if defined(__CUDACC__) || defined(__HIPCC__) +CUDA_ACC_TYPE(half, float) +#endif +CUDA_ACC_TYPE(BFloat16, float) +CUDA_ACC_TYPE(Half, float) +CUDA_ACC_TYPE(Float8_e5m2, float) +CUDA_ACC_TYPE(Float8_e4m3fn, float) +CUDA_ACC_TYPE(Float8_e5m2fnuz, float) +CUDA_ACC_TYPE(Float8_e4m3fnuz, float) +CUDA_ACC_TYPE(float, float) +CUDA_ACC_TYPE(double, double) +CUDA_ACC_TYPE(int8_t, int64_t) +CUDA_ACC_TYPE(uint8_t, int64_t) +CUDA_ACC_TYPE(char, int64_t) +CUDA_ACC_TYPE(int16_t, int64_t) +CUDA_ACC_TYPE(int32_t, int64_t) +CUDA_ACC_TYPE(int64_t, int64_t) +CUDA_ACC_TYPE(bool, bool) +CUDA_ACC_TYPE(c10::complex, c10::complex) +CUDA_ACC_TYPE(c10::complex, c10::complex) +CUDA_ACC_TYPE(c10::complex, c10::complex) + +CPU_ACC_TYPE(BFloat16, float) +CPU_ACC_TYPE(Half, float) +CPU_ACC_TYPE(Float8_e5m2, float) +CPU_ACC_TYPE(Float8_e4m3fn, float) +CPU_ACC_TYPE(Float8_e5m2fnuz, float) +CPU_ACC_TYPE(Float8_e4m3fnuz, float) +CPU_ACC_TYPE(float, double) +CPU_ACC_TYPE(double, double) +CPU_ACC_TYPE(int8_t, int64_t) +CPU_ACC_TYPE(uint8_t, int64_t) +CPU_ACC_TYPE(char, int64_t) +CPU_ACC_TYPE(int16_t, int64_t) +CPU_ACC_TYPE(int32_t, int64_t) +CPU_ACC_TYPE(int64_t, int64_t) +CPU_ACC_TYPE(bool, bool) +CPU_ACC_TYPE(c10::complex, c10::complex) +CPU_ACC_TYPE(c10::complex, c10::complex) +CPU_ACC_TYPE(c10::complex, c10::complex) + +TORCH_API c10::ScalarType toAccumulateType( + c10::ScalarType type, + c10::DeviceType device); +TORCH_API c10::ScalarType toAccumulateType(c10::ScalarType type, bool is_cuda); + +} // namespace at + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." 
+#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/ArrayRef.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/ArrayRef.h new file mode 100644 index 0000000000000000000000000000000000000000..33ab9ae6e70b7faca0063bce3dfc25c8bcec4494 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/ArrayRef.h @@ -0,0 +1,7 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +#include + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/Backend.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/Backend.h new file mode 100644 index 0000000000000000000000000000000000000000..9b517c3fbf80f0792cf2cdad3df5122800006c2e --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/Backend.h @@ -0,0 +1,7 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +#include + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." 
+#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/Backtrace.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/Backtrace.h new file mode 100644 index 0000000000000000000000000000000000000000..bc2bcc208684f9b2e3e3ab12918a212bb8bb09a0 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/Backtrace.h @@ -0,0 +1,7 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +#include + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/BlasBackend.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/BlasBackend.h new file mode 100644 index 0000000000000000000000000000000000000000..892c6d9042b334de6a6aa50e363ce2101e6aaea0 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/BlasBackend.h @@ -0,0 +1,51 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include + +#include +#include + +namespace at { + +enum class BlasBackend : int8_t { Default, Cublas, Cublaslt, Ck }; + +inline std::string BlasBackendToString(at::BlasBackend backend) { + switch (backend) { + case BlasBackend::Default: + return "at::BlasBackend::Default"; + case BlasBackend::Cublas: + return "at::BlasBackend::Cublas"; + case BlasBackend::Cublaslt: + return "at::BlasBackend::Cublaslt"; + case BlasBackend::Ck: + return "at::BlasBackend::Ck"; + default: + TORCH_CHECK(false, "Unknown blas backend"); + } +} + +inline std::ostream& operator<<(std::ostream& stream, at::BlasBackend backend) { + return stream << BlasBackendToString(backend); +} + +namespace blas { + +enum 
class ScalingType : std::uint8_t { + TensorWise, // fp32 scales + RowWise, // fp32 scales + BlockWise1x16, // fp8_e4m3fn scales + BlockWise1x32, // fp8_e8m0fnu scales + BlockWise1x128, // fp32 scales + BlockWise128x128, // fp32 scales +}; + +enum class SwizzleType : std::uint8_t { NO_SWIZZLE = 0, SWIZZLE_32_4_4 = 1 }; + +} // namespace blas + +} // namespace at + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/CPUApplyUtils.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/CPUApplyUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..08991c86a1a66ceafc918881f9faf6b71e04c982 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/CPUApplyUtils.h @@ -0,0 +1,356 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace at { + +/* + * The basic strategy for apply is as follows: + * + * 1. Starting with the outermost index, loop until we reach a dimension where + * the data is no longer contiguous, i.e. the stride at that dimension is not + * equal to the size of the tensor defined by the outer dimensions. Let's call + * this outer (contiguous) tensor A. Note that if the Tensor is contiguous, then + * A is equal to the entire Tensor. Let's call the inner tensor B. + * + * 2. We loop through the indices in B, starting at its outermost dimension. For + * example, if B is a 2x2 matrix, then we do: + * + * B[0][0] + * B[0][1] + * B[1][0] + * B[1][1] + * + * We set the offset into the underlying storage as (storageOffset + stride_B * + * index_B), i.e. basically we compute the offset into the storage as we would + * normally for a Tensor. 
But because we are guaranteed the subsequent data is + * contiguous in memory, we can simply loop for sizeof(A) iterations and perform + * the operation, without having to follow the order described by the strides of + * A. + * + * 3. As an optimization, we merge dimensions of A that are contiguous in + * memory. For example, if A is a 3x3x3x3 tensor narrowed from a 3x3x4x3 tensor, + * then the first two dimensions can be merged for the purposes of APPLY, + * reducing the number of nested loops. + */ + +inline Tensor sort_strides(Tensor& tensor_) { + IntArrayRef strides = tensor_.strides(); + std::vector indices; + indices.reserve(tensor_.ndimension()); + for (const auto i : c10::irange(tensor_.ndimension())) { + indices.push_back(i); + } + std::sort(indices.begin(), indices.end(), [&strides](int64_t i1, int64_t i2) { + return strides[i1] > strides[i2]; + }); + Tensor tensor = tensor_.permute(indices); + return tensor; +} + +template +struct strided_tensor_iter_fixed { + public: + T* data_ = NULL; + int64_t dim_ = 0; + + // NOLINTNEXTLINE(*array*) + int64_t counter_[N] = {0}; + // NOLINTNEXTLINE(*array*) + int64_t sizes_[N] = {0}; + // NOLINTNEXTLINE(*array*) + int64_t strides_[N] = {0}; + + strided_tensor_iter_fixed(strided_tensor_iter_fixed const&) = delete; + strided_tensor_iter_fixed& operator=(strided_tensor_iter_fixed const& x) = + delete; + strided_tensor_iter_fixed(strided_tensor_iter_fixed&&) noexcept = default; + strided_tensor_iter_fixed& operator=(strided_tensor_iter_fixed&& x) noexcept = + default; + ~strided_tensor_iter_fixed() noexcept = default; + strided_tensor_iter_fixed( + Tensor& tensor, + [[maybe_unused]] bool sort_strides = false) + : data_(tensor.data_ptr()) { + std::memset(counter_, 0, sizeof(int64_t) * N); + if (tensor.dim() > 0) { + std::memcpy( + sizes_, tensor.sizes().data(), tensor.dim() * sizeof(int64_t)); + std::memcpy( + strides_, tensor.strides().data(), tensor.dim() * sizeof(int64_t)); + } + dim_ = std::get<1>(collapse_dims(sizes_, 
strides_, tensor.ndimension())); + } +}; + +template +struct strided_tensor_iter { + private: + public: + T* data_ = NULL; + int64_t dim_; + + std::vector counter_; + std::vector sizes_; + std::vector strides_; + + strided_tensor_iter(strided_tensor_iter const&) = delete; + strided_tensor_iter& operator=(strided_tensor_iter const& x) = delete; + strided_tensor_iter(strided_tensor_iter&&) noexcept = default; + strided_tensor_iter& operator=(strided_tensor_iter&&) noexcept = default; + ~strided_tensor_iter() noexcept = default; + strided_tensor_iter(Tensor& tensor) + : data_(tensor.data_ptr()), + dim_(tensor.ndimension()), + counter_(dim_, 0), + sizes_(tensor.sizes().vec()), + strides_(tensor.strides().vec()) { + dim_ = std::get<1>(collapse_dims(sizes_.data(), strides_.data(), dim_)); + } +}; + +inline bool _all_equal_numel(at::ArrayRef tensors) { + if (tensors.empty()) + return true; + int64_t all_numel = tensors[0].numel(); + for (const auto i : c10::irange(1, tensors.size())) { + if (tensors[i].numel() != all_numel) + return false; + } + return true; +} + +inline std::string _all_equal_numel_error(at::ArrayRef tensors) { + std::ostringstream oss; + oss << "inconsistent tensor size, expected "; + for (size_t i = 0; i < tensors.size() - 1; i++) { + oss << tensors[i].sizes() << ", "; + } + oss << "and " << tensors[tensors.size() - 1].sizes() + << " to have the same number of elements, but got "; + for (size_t i = 0; i < tensors.size() - 1; i++) { + oss << tensors[i].numel() << ", "; + } + oss << "and " << tensors[tensors.size() - 1].numel() + << " elements respectively"; + return oss.str(); +} + +inline bool _apply_preamble(ArrayRef tensors) { + checkDeviceType("CPU_tensor_apply", tensors, kCPU); + checkLayout("CPU_tensor_apply", tensors, kStrided); + TORCH_CHECK(_all_equal_numel(tensors), _all_equal_numel_error(tensors)); + // An empty tensor has no elements + for (auto& t : tensors) + if (t.numel() == 0) + return false; + return true; +} + +inline int64_t 
_max_dim_tensors(ArrayRef tensors) { + int64_t dim = 0; + for (auto& t : tensors) + dim = std::max(dim, t.ndimension()); + return dim; +} + +inline void iterate(int64_t /*size*/) {} + +template +inline void iterate(int64_t size, Arg& iter, Args&... iter_tail) { + iter.counter_[iter.dim_ - 1] += size; + iter.data_ = iter.data_ + size * iter.strides_[iter.dim_ - 1]; + iterate(size, iter_tail...); +} + +inline bool iterate_continue() { + return true; +} + +template +inline bool iterate_continue(Arg& iter, Args&... iter_tail) { + return iter.counter_[iter.dim_ - 1] < iter.sizes_[iter.dim_ - 1] && + iterate_continue(iter_tail...); +} + +inline int64_t max_iterate_size() { + return std::numeric_limits::max(); +} + +template +inline int64_t max_iterate_size(Arg& iter, Args&... iter_tail) { + return std::min( + (iter.sizes_[iter.dim_ - 1] - iter.counter_[iter.dim_ - 1]), + max_iterate_size(iter_tail...)); +} + +inline void iterate_overflow() {} + +template +inline void iterate_overflow(Arg& iter, Args&... iter_tail) { + if (iter.counter_[iter.dim_ - 1] == iter.sizes_[iter.dim_ - 1]) { + for (int64_t i = iter.dim_ - 1; i > 0; i--) { + if (iter.counter_[i] == iter.sizes_[i]) { + iter.counter_[i] = 0; + iter.counter_[i - 1]++; + iter.data_ = iter.data_ - (iter.sizes_[i] * iter.strides_[i]) + + iter.strides_[i - 1]; + } + } + } + iterate_overflow(iter_tail...); +} + +inline void forward(int64_t /*offset*/) {} + +template +inline void forward(int64_t offset, Arg& iter, Args&... iter_tail) { + int64_t multi = offset; + for (int64_t i = iter.dim_ - 1; i >= 0; i--) { + int64_t inc = multi % iter.sizes_[i]; + multi = multi / iter.sizes_[i]; + iter.data_ = iter.data_ + inc * iter.strides_[i]; + iter.counter_[i] += inc; + } + forward(offset, iter_tail...); +} + +inline int64_t max_dim() { + return 0; +} + +template +inline int64_t max_dim(Arg& iter, Args&... 
iter_tail) { + return std::max(iter.dim_, max_dim(iter_tail...)); +} + +inline void apply_op() {} + +template +inline void apply_op( + int64_t numel, + int64_t offset, + const Op& op, + Args... iters) { + // For 0-dim tensors + if (numel == 1 && max_dim(iters...) == 0) { + op(*iters.data_...); + return; + } + if (offset > 0) + forward(offset, iters...); + // Splitting this into chunks helps the compiler create faster assembly + for (int64_t i = 0; i < numel;) { + for (; iterate_continue(iters...) && i < numel;) { + op(*iters.data_...); + iterate(1, iters...); + i++; + } + iterate_overflow(iters...); + } +} + +/* + Apply a pointwise operator to sequence of tensors + + The calling convention for op is a function/functor that takes the same + number of pointers of type scalar as the number of given tensors. For example, + to compute a = b * c, op would be of the form: + [](scalar* a_val, const scalar* b_val, const scalar* c_val) { a_val[0] = + b_val[0] * c_val[0]; }; +*/ + +template +inline void CPU_tensor_apply2(Tensor tensor1, Tensor tensor2, const Op op) { + if (!_apply_preamble({tensor1, tensor2})) + return; + if (_max_dim_tensors({tensor1, tensor2}) <= 8) { + apply_op( + tensor1.numel(), + 0, + op, + strided_tensor_iter_fixed(tensor1), + strided_tensor_iter_fixed(tensor2)); + } else { + apply_op( + tensor1.numel(), + 0, + op, + strided_tensor_iter(tensor1), + strided_tensor_iter(tensor2)); + } +} + +template +inline void CPU_tensor_apply3( + Tensor tensor1, + Tensor tensor2, + Tensor tensor3, + const Op op) { + if (!_apply_preamble({tensor1, tensor2, tensor3})) + return; + if (_max_dim_tensors({tensor1, tensor2, tensor3}) <= 8) { + apply_op( + tensor1.numel(), + 0, + op, + strided_tensor_iter_fixed(tensor1), + strided_tensor_iter_fixed(tensor2), + strided_tensor_iter_fixed(tensor3)); + } else { + apply_op( + tensor1.numel(), + 0, + op, + strided_tensor_iter(tensor1), + strided_tensor_iter(tensor2), + strided_tensor_iter(tensor3)); + } +} + +template < + typename 
scalar1, + typename scalar2, + typename scalar3, + typename scalar4, + typename Op> +inline void CPU_tensor_apply4( + Tensor tensor1, + Tensor tensor2, + Tensor tensor3, + Tensor tensor4, + const Op op) { + if (!_apply_preamble({tensor1, tensor2, tensor3, tensor4})) + return; + if (_max_dim_tensors({tensor1, tensor2, tensor3, tensor4}) <= 8) { + apply_op( + tensor1.numel(), + 0, + op, + strided_tensor_iter_fixed(tensor1), + strided_tensor_iter_fixed(tensor2), + strided_tensor_iter_fixed(tensor3), + strided_tensor_iter_fixed(tensor4)); + } else { + apply_op( + tensor1.numel(), + 0, + op, + strided_tensor_iter(tensor1), + strided_tensor_iter(tensor2), + strided_tensor_iter(tensor3), + strided_tensor_iter(tensor4)); + } +} + +} // namespace at + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/CPUFixedAllocator.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/CPUFixedAllocator.h new file mode 100644 index 0000000000000000000000000000000000000000..c92a67876cc131cde31c8720cef9301ff4ac8433 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/CPUFixedAllocator.h @@ -0,0 +1,38 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include + +// This file creates a fake allocator that just throws exceptions if +// it is actually used. 
+ +// state passed to the allocator is the std::function called +// when the blob is release by ATen + +namespace at { + +static void* cpu_fixed_malloc(void*, ptrdiff_t) { + TORCH_CHECK(false, "attempting to resize a tensor view of an external blob"); +} + +static void* cpu_fixed_realloc(void*, void*, ptrdiff_t) { + TORCH_CHECK(false, "attempting to resize a tensor view of an external blob"); +} + +static void cpu_fixed_free(void* state, void* allocation) { + auto on_release = static_cast*>(state); + (*on_release)(allocation); + delete on_release; +} + +static Allocator CPU_fixed_allocator = { + cpu_fixed_malloc, + cpu_fixed_realloc, + cpu_fixed_free}; + +} // namespace at + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/CPUFunctions.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/CPUFunctions.h new file mode 100644 index 0000000000000000000000000000000000000000..e9480061a6dbaf472fceb48b698bab464562c704 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/CPUFunctions.h @@ -0,0 +1,34 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#include + +// TODO Undo all logic introduced for Note [Avoiding Include Cycles In Static Dispatch] +// Code introduced to avoid cyclic dependency in static dispatch is no longer +// needed as static dispatch logic is moved from TensorBody.h, which caused cycles in the first place, +// to Operators.cpp for supporting multiple backends with multiple kernels. +// +// Note [Avoiding Include Cycles In Static Dispatch] +// In order to avoid #include cycles in the static dispatch build, we've carefully split out +// the static function definition files into {DispatchKey}Functions.h and {DispatchKey}Functions_inl.h. 
+// +// Without this split, the include cycle looks like TensorBody.h -> CPUFunctions.h -> TensorBody.h. +// - TensorBody.h #includes CPUFunctions.h in the static dispatch build, because the tensor methods +// all need to call into the fastpath C++ API defined in CPUFunctions.h. The methods are also all +// directly inlined into TensorBody.h. +// - CPUFunctions.h #includes TensorBody.h because it contains function declarations for the entire C++ API, +// which include functions that have defaultable std::optional arguments. +// That requires knowing the full Tensor class definition. +// +// We break the cycle by doing the following: +// - Split out CPUFunction.h into two files: CPUFunctions.h and CPUFunctions_inl.h +// - CPUFunction.h is a dummy file that just includes the Tensor class and includes CPUFunctions_inl., +// - CPUFunctions_inl.h includes everything else +// - (only in the static dispatch build) TensorBody.h makes sure to finish defining the Tensor class, +// and then it includes CPUFunctions_inl.h. +// - All other files that want the cpu fastpath functions can include CPUFunctions.h directly. +// - This also means that static dispatch build, CPUFunctions.h only needs to +// #include TensorBody.h, and it will automatically bring in CPUFunctions_inl.h. +#include + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." 
+#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/CPUFunctions_inl.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/CPUFunctions_inl.h new file mode 100644 index 0000000000000000000000000000000000000000..db94b7405387333f38ef0183cdc0b8b075a6b921 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/CPUFunctions_inl.h @@ -0,0 +1,549 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunctions_inl.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +#if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS) +#error This change adds a dependency on all pytorch operators, meaning the \ + file will need to be re-compiled every time an operator is changed or added. \ + Consider including a specific operator from \ + . \ + See NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS]. 
+#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include 
+#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include 
+#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." 
+#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/CPUGeneratorImpl.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/CPUGeneratorImpl.h new file mode 100644 index 0000000000000000000000000000000000000000..78fb733e15a3c4cd6774124a7beeeee62136b85c --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/CPUGeneratorImpl.h @@ -0,0 +1,54 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include + +namespace at { + +struct TORCH_API CPUGeneratorImpl : public c10::GeneratorImpl { + // Constructors + CPUGeneratorImpl(uint64_t seed_in = default_rng_seed_val); + ~CPUGeneratorImpl() override = default; + + // CPUGeneratorImpl methods + std::shared_ptr clone() const; + void set_current_seed(uint64_t seed) override; + void set_offset(uint64_t offset) override; + uint64_t get_offset() const override; + uint64_t current_seed() const override; + uint64_t seed() override; + void set_state(const c10::TensorImpl& new_state) override; + c10::intrusive_ptr get_state() const override; + static c10::DeviceType device_type(); + uint32_t random(); + uint64_t random64(); + std::optional next_float_normal_sample(); + std::optional next_double_normal_sample(); + void set_next_float_normal_sample(std::optional randn); + void set_next_double_normal_sample(std::optional randn); + at::mt19937 engine(); + void set_engine(at::mt19937 engine); + + private: + CPUGeneratorImpl* clone_impl() const override; + at::mt19937 engine_; + std::optional next_float_normal_sample_; + std::optional next_double_normal_sample_; +}; + +namespace detail { + +TORCH_API const Generator& getDefaultCPUGenerator(); +TORCH_API Generator +createCPUGenerator(uint64_t seed_val = default_rng_seed_val); + +} // namespace detail + +} // namespace at + +#else +#error "This file 
should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/CUDAFunctions.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/CUDAFunctions.h new file mode 100644 index 0000000000000000000000000000000000000000..50fd3e96cb64401b46838b5a03a18ee51a9cfee4 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/CUDAFunctions.h @@ -0,0 +1,34 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#include + +// TODO Undo all logic introduced for Note [Avoiding Include Cycles In Static Dispatch] +// Code introduced to avoid cyclic dependency in static dispatch is no longer +// needed as static dispatch logic is moved from TensorBody.h, which caused cycles in the first place, +// to Operators.cpp for supporting multiple backends with multiple kernels. +// +// Note [Avoiding Include Cycles In Static Dispatch] +// In order to avoid #include cycles in the static dispatch build, we've carefully split out +// the static function definition files into {DispatchKey}Functions.h and {DispatchKey}Functions_inl.h. +// +// Without this split, the include cycle looks like TensorBody.h -> CPUFunctions.h -> TensorBody.h. +// - TensorBody.h #includes CPUFunctions.h in the static dispatch build, because the tensor methods +// all need to call into the fastpath C++ API defined in CPUFunctions.h. The methods are also all +// directly inlined into TensorBody.h. +// - CPUFunctions.h #includes TensorBody.h because it contains function declarations for the entire C++ API, +// which include functions that have defaultable std::optional arguments. +// That requires knowing the full Tensor class definition. 
+// +// We break the cycle by doing the following: +// - Split out CPUFunction.h into two files: CPUFunctions.h and CPUFunctions_inl.h +// - CPUFunction.h is a dummy file that just includes the Tensor class and includes CPUFunctions_inl., +// - CPUFunctions_inl.h includes everything else +// - (only in the static dispatch build) TensorBody.h makes sure to finish defining the Tensor class, +// and then it includes CPUFunctions_inl.h. +// - All other files that want the cpu fastpath functions can include CPUFunctions.h directly. +// - This also means that static dispatch build, CPUFunctions.h only needs to +// #include TensorBody.h, and it will automatically bring in CPUFunctions_inl.h. +#include + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/CUDAFunctions_inl.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/CUDAFunctions_inl.h new file mode 100644 index 0000000000000000000000000000000000000000..2da2f806a9edeb396a93a709a6e652a47247be9d --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/CUDAFunctions_inl.h @@ -0,0 +1,641 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunctions_inl.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +#if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS) +#error This change adds a dependency on all pytorch operators, meaning the \ + file will need to be re-compiled every time an operator is changed or added. \ + Consider including a specific operator from \ + . 
\ + See NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS]. +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include 
+#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include 
+#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include 
+#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/CachedTensorUtils.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/CachedTensorUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..81635eb9af5b1e02c7c43c26fd007da579654f6f --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/CachedTensorUtils.h @@ -0,0 +1,29 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include + +namespace at::caching { + +// Some systems (just cudagraphs currently) will persist a static tensor output +// whose TensorImpl does not change across iterations. For these tensors caching +// dtype conversions is invalid. Additionally, there will be an extra reference +// count to these cached tensors that would prevent buffer inplacing and other +// checks on tensor uniqueness. If we are not using these systems the enabled +// flag will be false and we will avoid the hash lookup. + +TORCH_API bool is_cached_tensor(const at::Tensor& t); +TORCH_API void add_cached_tensor(const at::Tensor& t); +TORCH_API void remove_cached_tensor(const at::Tensor& t); +TORCH_API void set_cached_tensors_enabled(bool enable); + +// For gradient buffer stealing we will adjust the use count of tensors +// which are persisted by cudagraphs, just as we need to adjust reference +// count of tensors with hooks. 
+TORCH_API size_t adjusted_use_count(const at::Tensor& t); + +} // namespace at::caching + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/CollapseDims.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/CollapseDims.h new file mode 100644 index 0000000000000000000000000000000000000000..38d5aa0fedf843921d8ce4e7a6dae65aee836408 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/CollapseDims.h @@ -0,0 +1,99 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#include +#include + +namespace at { + +/* +[collapse dims] Updates sizes, and strides to reflect a "collapse" of +the info, possibly excluding the optional excludeDim. A "collapsed" version +of the info is the fewest dims that order the tensor's elements in the same +way as the original info. If excludeDim is specified, the collapse is the +fewest dims that order the tensor's elements as the original and preserve the +excluded dimension, unless the tensor collapses to a point. + +This function returns a pair of values. + +1) The (new) index of the preserved dimension if excludeDim is +specified. 0 if the tensor is collapsed to a point. -1 +otherwise. + +2) The new number of dimensions. +*/ +template +inline std::pair collapse_dims( + T* sizes, + T* strides, + int64_t dims, + const int excludeDim = -1) { + TORCH_CHECK( + excludeDim >= -1 && excludeDim < dims, + "expected excluded dim between -1 and dims - 1"); + + int64_t stopDim = (excludeDim == -1) ? 
dims : excludeDim; + int64_t newIndex = -1; + int64_t oldIndex = 0; + int64_t remappedExcludedDim = -1; + + while (oldIndex < dims) { + // Finds a dimension to collapse into + for (; oldIndex < stopDim; ++oldIndex) { + if (sizes[oldIndex] == 1) { + continue; + } + + ++newIndex; + sizes[newIndex] = sizes[oldIndex]; + strides[newIndex] = strides[oldIndex]; + ++oldIndex; + break; + } + + // Collapses dims + for (; oldIndex < stopDim; ++oldIndex) { + if (sizes[oldIndex] == 1) { + continue; + } + + if (strides[newIndex] == sizes[oldIndex] * strides[oldIndex]) { + sizes[newIndex] *= sizes[oldIndex]; + strides[newIndex] = strides[oldIndex]; + } else { + ++newIndex; + sizes[newIndex] = sizes[oldIndex]; + strides[newIndex] = strides[oldIndex]; + } + } + + // Handles excludeDim being set (oldIndex == excludeDim) + if (oldIndex != dims) { + // Preserves excluded dimension + ++newIndex; + sizes[newIndex] = sizes[oldIndex]; + strides[newIndex] = strides[oldIndex]; + remappedExcludedDim = newIndex; + + // Restarts iteration after excludeDim + ++oldIndex; + stopDim = dims; + } + } + + // Handles special case of all dims size 1 + if (newIndex == -1 || (newIndex == 0 && sizes[0] == 1)) { + dims = 1; + sizes[0] = 1; + strides[0] = 1; + + return std::pair(0, 1); + } + + dims = newIndex + 1; + return std::pair(remappedExcludedDim, dims); +} + +} // namespace at + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." 
+#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/CompositeExplicitAutogradFunctions.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/CompositeExplicitAutogradFunctions.h new file mode 100644 index 0000000000000000000000000000000000000000..f05513a7e16dd4cca6ab92be7a01e8ec16f2106c --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/CompositeExplicitAutogradFunctions.h @@ -0,0 +1,34 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#include + +// TODO Undo all logic introduced for Note [Avoiding Include Cycles In Static Dispatch] +// Code introduced to avoid cyclic dependency in static dispatch is no longer +// needed as static dispatch logic is moved from TensorBody.h, which caused cycles in the first place, +// to Operators.cpp for supporting multiple backends with multiple kernels. +// +// Note [Avoiding Include Cycles In Static Dispatch] +// In order to avoid #include cycles in the static dispatch build, we've carefully split out +// the static function definition files into {DispatchKey}Functions.h and {DispatchKey}Functions_inl.h. +// +// Without this split, the include cycle looks like TensorBody.h -> CPUFunctions.h -> TensorBody.h. +// - TensorBody.h #includes CPUFunctions.h in the static dispatch build, because the tensor methods +// all need to call into the fastpath C++ API defined in CPUFunctions.h. The methods are also all +// directly inlined into TensorBody.h. +// - CPUFunctions.h #includes TensorBody.h because it contains function declarations for the entire C++ API, +// which include functions that have defaultable std::optional arguments. +// That requires knowing the full Tensor class definition. 
+// +// We break the cycle by doing the following: +// - Split out CPUFunction.h into two files: CPUFunctions.h and CPUFunctions_inl.h +// - CPUFunction.h is a dummy file that just includes the Tensor class and includes CPUFunctions_inl., +// - CPUFunctions_inl.h includes everything else +// - (only in the static dispatch build) TensorBody.h makes sure to finish defining the Tensor class, +// and then it includes CPUFunctions_inl.h. +// - All other files that want the cpu fastpath functions can include CPUFunctions.h directly. +// - This also means that static dispatch build, CPUFunctions.h only needs to +// #include TensorBody.h, and it will automatically bring in CPUFunctions_inl.h. +#include + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/CompositeExplicitAutogradFunctions_inl.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/CompositeExplicitAutogradFunctions_inl.h new file mode 100644 index 0000000000000000000000000000000000000000..2904daa74e7293c9ab495e0bb55d21dc2017637c --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/CompositeExplicitAutogradFunctions_inl.h @@ -0,0 +1,565 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunctions_inl.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +#if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS) +#error This change adds a dependency on all pytorch operators, meaning the \ + file will need to be re-compiled every time an operator is changed or added. 
\ + Consider including a specific operator from \ + . \ + See NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS]. +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include 
+#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include 
+#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." 
+#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/CompositeExplicitAutogradNonFunctionalFunctions.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/CompositeExplicitAutogradNonFunctionalFunctions.h new file mode 100644 index 0000000000000000000000000000000000000000..c268c6443a8be7346008abefdc7d8ce37f9e25e8 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/CompositeExplicitAutogradNonFunctionalFunctions.h @@ -0,0 +1,34 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#include + +// TODO Undo all logic introduced for Note [Avoiding Include Cycles In Static Dispatch] +// Code introduced to avoid cyclic dependency in static dispatch is no longer +// needed as static dispatch logic is moved from TensorBody.h, which caused cycles in the first place, +// to Operators.cpp for supporting multiple backends with multiple kernels. +// +// Note [Avoiding Include Cycles In Static Dispatch] +// In order to avoid #include cycles in the static dispatch build, we've carefully split out +// the static function definition files into {DispatchKey}Functions.h and {DispatchKey}Functions_inl.h. +// +// Without this split, the include cycle looks like TensorBody.h -> CPUFunctions.h -> TensorBody.h. +// - TensorBody.h #includes CPUFunctions.h in the static dispatch build, because the tensor methods +// all need to call into the fastpath C++ API defined in CPUFunctions.h. The methods are also all +// directly inlined into TensorBody.h. +// - CPUFunctions.h #includes TensorBody.h because it contains function declarations for the entire C++ API, +// which include functions that have defaultable std::optional arguments. +// That requires knowing the full Tensor class definition. 
+// +// We break the cycle by doing the following: +// - Split out CPUFunctions.h into two files: CPUFunctions.h and CPUFunctions_inl.h +// - CPUFunctions.h is a dummy file that just includes the Tensor class and includes CPUFunctions_inl.h, +// - CPUFunctions_inl.h includes everything else +// - (only in the static dispatch build) TensorBody.h makes sure to finish defining the Tensor class, +// and then it includes CPUFunctions_inl.h. +// - All other files that want the cpu fastpath functions can include CPUFunctions.h directly. +// - This also means that in the static dispatch build, CPUFunctions.h only needs to +// #include TensorBody.h, and it will automatically bring in CPUFunctions_inl.h. +#include + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/CompositeExplicitAutogradNonFunctionalFunctions_inl.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/CompositeExplicitAutogradNonFunctionalFunctions_inl.h new file mode 100644 index 0000000000000000000000000000000000000000..2871a076e34bda961f26fffb94846f370d8b2990 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/CompositeExplicitAutogradNonFunctionalFunctions_inl.h @@ -0,0 +1,329 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunctions_inl.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +#if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS) +#error This change adds a dependency on all pytorch operators, meaning the \ + file will need to be re-compiled every time an 
operator is changed or added. \ + Consider including a specific operator from \ + . \ + See NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS]. +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include 
+#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." 
+#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/CompositeImplicitAutogradFunctions.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/CompositeImplicitAutogradFunctions.h new file mode 100644 index 0000000000000000000000000000000000000000..2e156c650a6e891e11608a0fd86f506526001f07 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/CompositeImplicitAutogradFunctions.h @@ -0,0 +1,34 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#include + +// TODO Undo all logic introduced for Note [Avoiding Include Cycles In Static Dispatch] +// Code introduced to avoid cyclic dependency in static dispatch is no longer +// needed as static dispatch logic is moved from TensorBody.h, which caused cycles in the first place, +// to Operators.cpp for supporting multiple backends with multiple kernels. +// +// Note [Avoiding Include Cycles In Static Dispatch] +// In order to avoid #include cycles in the static dispatch build, we've carefully split out +// the static function definition files into {DispatchKey}Functions.h and {DispatchKey}Functions_inl.h. +// +// Without this split, the include cycle looks like TensorBody.h -> CPUFunctions.h -> TensorBody.h. +// - TensorBody.h #includes CPUFunctions.h in the static dispatch build, because the tensor methods +// all need to call into the fastpath C++ API defined in CPUFunctions.h. The methods are also all +// directly inlined into TensorBody.h. +// - CPUFunctions.h #includes TensorBody.h because it contains function declarations for the entire C++ API, +// which include functions that have defaultable std::optional arguments. +// That requires knowing the full Tensor class definition. 
+// +// We break the cycle by doing the following: +// - Split out CPUFunctions.h into two files: CPUFunctions.h and CPUFunctions_inl.h +// - CPUFunctions.h is a dummy file that just includes the Tensor class and includes CPUFunctions_inl.h, +// - CPUFunctions_inl.h includes everything else +// - (only in the static dispatch build) TensorBody.h makes sure to finish defining the Tensor class, +// and then it includes CPUFunctions_inl.h. +// - All other files that want the cpu fastpath functions can include CPUFunctions.h directly. +// - This also means that in the static dispatch build, CPUFunctions.h only needs to +// #include TensorBody.h, and it will automatically bring in CPUFunctions_inl.h. +#include + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/CompositeImplicitAutogradFunctions_inl.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/CompositeImplicitAutogradFunctions_inl.h new file mode 100644 index 0000000000000000000000000000000000000000..6dfa83e585932ef4aba6b52f690e7fe481c35995 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/CompositeImplicitAutogradFunctions_inl.h @@ -0,0 +1,508 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunctions_inl.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +#if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS) +#error This change adds a dependency on all pytorch operators, meaning the \ + file will need to be re-compiled every time an operator is changed or added. 
\ + Consider including a specific operator from \ + . \ + See NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS]. +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include 
+#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include 
+#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." 
+#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/CompositeImplicitAutogradNestedTensorFunctions.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/CompositeImplicitAutogradNestedTensorFunctions.h new file mode 100644 index 0000000000000000000000000000000000000000..19574a46f4ce9e947d561af9261b9044b4fd2581 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/CompositeImplicitAutogradNestedTensorFunctions.h @@ -0,0 +1,34 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#include + +// TODO Undo all logic introduced for Note [Avoiding Include Cycles In Static Dispatch] +// Code introduced to avoid cyclic dependency in static dispatch is no longer +// needed as static dispatch logic is moved from TensorBody.h, which caused cycles in the first place, +// to Operators.cpp for supporting multiple backends with multiple kernels. +// +// Note [Avoiding Include Cycles In Static Dispatch] +// In order to avoid #include cycles in the static dispatch build, we've carefully split out +// the static function definition files into {DispatchKey}Functions.h and {DispatchKey}Functions_inl.h. +// +// Without this split, the include cycle looks like TensorBody.h -> CPUFunctions.h -> TensorBody.h. +// - TensorBody.h #includes CPUFunctions.h in the static dispatch build, because the tensor methods +// all need to call into the fastpath C++ API defined in CPUFunctions.h. The methods are also all +// directly inlined into TensorBody.h. +// - CPUFunctions.h #includes TensorBody.h because it contains function declarations for the entire C++ API, +// which include functions that have defaultable std::optional arguments. +// That requires knowing the full Tensor class definition. 
+// +// We break the cycle by doing the following: +// - Split out CPUFunctions.h into two files: CPUFunctions.h and CPUFunctions_inl.h +// - CPUFunctions.h is a dummy file that just includes the Tensor class and includes CPUFunctions_inl.h, +// - CPUFunctions_inl.h includes everything else +// - (only in the static dispatch build) TensorBody.h makes sure to finish defining the Tensor class, +// and then it includes CPUFunctions_inl.h. +// - All other files that want the cpu fastpath functions can include CPUFunctions.h directly. +// - This also means that in the static dispatch build, CPUFunctions.h only needs to +// #include TensorBody.h, and it will automatically bring in CPUFunctions_inl.h. +#include + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/CompositeImplicitAutogradNestedTensorFunctions_inl.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/CompositeImplicitAutogradNestedTensorFunctions_inl.h new file mode 100644 index 0000000000000000000000000000000000000000..4ae28be49138cab876fdf908ff49311e808216c9 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/CompositeImplicitAutogradNestedTensorFunctions_inl.h @@ -0,0 +1,30 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunctions_inl.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +#if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS) +#error This change adds a dependency on all pytorch operators, meaning the \ + file will need to be re-compiled every time an 
operator is changed or added. \ + Consider including a specific operator from \ + . \ + See NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS]. +#endif + +#include +#include +#include +#include + + + + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/Config.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/Config.h new file mode 100644 index 0000000000000000000000000000000000000000..1a5c8e5ade1e78d5ee0ac04243e935639ae3eb59 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/Config.h @@ -0,0 +1,28 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +// Test these using #if AT_MKL_ENABLED(), not #ifdef, so that it's +// obvious if you forgot to include Config.h +// c.f. https://stackoverflow.com/questions/33759787/generating-an-error-if-checked-boolean-macro-is-not-defined +// +// DO NOT put the macros for CUDA libraries in this file; they belong in cuda/CUDAConfig.h + +#define AT_MKLDNN_ENABLED() 1 +#define AT_MKLDNN_ACL_ENABLED() 0 +#define AT_MKL_ENABLED() 1 +#define AT_MKL_SEQUENTIAL() 0 +#define AT_POCKETFFT_ENABLED() 0 +#define AT_NNPACK_ENABLED() 1 +#define CAFFE2_STATIC_LINK_CUDA() 0 +#define AT_BUILD_WITH_BLAS() 1 +#define AT_BUILD_WITH_LAPACK() 1 +#define AT_PARALLEL_OPENMP 1 +#define AT_PARALLEL_NATIVE 0 +#define AT_BLAS_F2C() 0 +#define AT_BLAS_USE_CBLAS_DOT() 0 +#define AT_KLEIDIAI_ENABLED() 0 +#define AT_USE_EIGEN_SPARSE() 0 + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." 
+#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/Context.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/Context.h new file mode 100644 index 0000000000000000000000000000000000000000..dda851524c700ed74336183ad5cfad660a6cde17 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/Context.h @@ -0,0 +1,712 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace at { + +class Tensor; + +enum class TORCH_API Float32MatmulPrecision { HIGHEST, HIGH, MEDIUM }; + +enum class CuBLASReductionOption : uint8_t { + AllowReducedPrecisionWithSplitK = 0, + DisallowReducedPrecisionAllowSplitK = 1, + DisallowReducedPrecisionDisallowSplitK = 2, +}; +enum class TORCH_API Float32Backend { GENERIC, CUDA, MKLDNN }; +enum class TORCH_API Float32Op { ALL, CONV, RNN, MATMUL }; +enum class TORCH_API Float32Precision { NONE, IEEE, TF32, BF16 }; + +TORCH_API Float32Backend str2backend(const std::string& name); +TORCH_API Float32Op str2op(const std::string& name); +TORCH_API Float32Precision str2precision(const std::string& name); +TORCH_API std::string precision2str(Float32Precision prec); + +class TORCH_API Context { + public: + Context(); + + const Generator& defaultGenerator(Device device) { + c10::DeviceType device_type = device.type(); + lazyInitDevice(device_type); + + if (device_type == at::kCPU) { + return at::detail::getDefaultCPUGenerator(); + } else { + return getAcceleratorHooksInterface(device_type) + 
.getDefaultGenerator(device.index()); + } + } + + const AcceleratorHooksInterface& getAcceleratorHooksInterface( + std::optional opt_device_type = std::nullopt) { + if (!opt_device_type.has_value()) { + opt_device_type = at::getAccelerator(true); + } + if (opt_device_type == at::kCUDA) { + return at::detail::getCUDAHooks(); + } else if (opt_device_type == at::kXPU) { + return at::detail::getXPUHooks(); + } else if (opt_device_type == at::kMPS) { + return at::detail::getMPSHooks(); + } else if (opt_device_type == at::kPrivateUse1) { + return at::detail::getPrivateUse1Hooks(); + } else if (opt_device_type == at::kMTIA) { + return at::detail::getMTIAHooks(); + } else if (opt_device_type == at::kHIP) { + return at::detail::getHIPHooks(); + } else if (opt_device_type == at::kHPU) { + return at::detail::getHPUHooks(); + } else if (opt_device_type == at::kXLA) { + return at::detail::getXLAHooks(); + } else { + TORCH_CHECK( + false, + opt_device_type.has_value() + ? c10::DeviceTypeName(opt_device_type.value()) + : "None", + " device type not an accelerator."); + } + } + + Device getDeviceFromPtr(void* data, c10::DeviceType device_type) { + lazyInitDevice(device_type); + + if (device_type == at::kCPU) { + return c10::DeviceType::CPU; + } else { + return getAcceleratorHooksInterface(device_type).getDeviceFromPtr(data); + } + } + + bool isPinnedPtr( + const void* data, + std::optional device_type = std::nullopt) { + auto opt_device_type = + device_type.has_value() ? 
device_type : at::getAccelerator(); + if (!opt_device_type.has_value() || // there is no accelerator + !at::isAccelerator( + opt_device_type.value())) { // passed device not an accelerator + return false; + } + if (!init_[static_cast(opt_device_type.value())].test_once()) { + // If the device is not initialized, no pointer can be pinned for it + return false; + } + return getAcceleratorHooksInterface(opt_device_type).isPinnedPtr(data); + } + + Allocator* getPinnedMemoryAllocator( + std::optional device_type = std::nullopt) { + auto opt_device_type = + device_type.has_value() ? device_type : at::getAccelerator(); + if (opt_device_type) { + lazyInitDevice(opt_device_type.value()); + } + return getAcceleratorHooksInterface(device_type).getPinnedMemoryAllocator(); + } + + void lazyInitDevice(c10::DeviceType device_type) { + if (device_type != at::kCPU) { + c10::call_once(init_[static_cast(device_type)], [&] { + getAcceleratorHooksInterface(device_type).init(); + }); + } + } + + static bool hasOpenMP(); + static bool hasMKL(); + static bool hasKleidiAI(); + static bool hasLAPACK(); + static bool hasMKLDNN(); + static bool ckSupported(); + static bool hasEigenSparse(); + static bool hasMAGMA() { + return detail::getCUDAHooks().hasMAGMA(); + } + static bool hasCUDA() { + return detail::getCUDAHooks().hasCUDA(); + } + static bool hasMTIA() { + return detail::getMTIAHooks().hasMTIA(); + } + static bool hasCUDART() { + return detail::getCUDAHooks().hasCUDART(); + } + static long versionCUDART() { + return detail::getCUDAHooks().versionCUDART(); + } + static bool hasCuDNN() { + return detail::getCUDAHooks().hasCuDNN(); + } + static long versionCuDNN() { + return detail::getCUDAHooks().versionCuDNN(); + } + static long versionRuntimeCuDNN() { + return detail::getCUDAHooks().versionRuntimeCuDNN(); + } + static long versionCuDNNFrontend() { + return detail::getCUDAHooks().versionCuDNNFrontend(); + } + static bool hasCuSOLVER() { + return detail::getCUDAHooks().hasCuSOLVER(); + } 
+ static bool hasCuBLASLt() { + return detail::getCUDAHooks().hasCuBLASLt(); + } + static bool hasROCM() { + return detail::getCUDAHooks().hasROCM(); + } + static bool hasCKSDPA() { + return detail::getCUDAHooks().hasCKSDPA(); + } + static bool hasCKGEMM() { + return detail::getCUDAHooks().hasCKGEMM(); + } + static bool hasHIP() { + return detail::getHIPHooks().hasHIP(); + } + static bool hasMPS() { + return detail::getMPSHooks().hasMPS(); + } + static bool hasIPU() { + return c10::impl::hasDeviceGuardImpl(c10::DeviceType::IPU); + } + static bool hasXLA() { + return detail::getXLAHooks().hasXLA(); + } + static bool hasXPU() { + return detail::getXPUHooks().hasXPU(); + } + static bool hasLazy() { + return c10::impl::hasDeviceGuardImpl(c10::DeviceType::Lazy); + } + static bool hasMAIA() { + return c10::impl::hasDeviceGuardImpl(c10::DeviceType::MAIA); + } + static bool hasHPU() { + return detail::getHPUHooks().hasHPU(); + } + + static const at::cuda::NVRTC& getNVRTC() { + return detail::getCUDAHooks().nvrtc(); + } + + static bool setFlushDenormal(bool on); + + // NB: This method is *purely* whether or not a user requested + // that CuDNN was enabled, it doesn't actually say anything about + // whether or not CuDNN is actually usable. 
Use cudnn_is_acceptable + // to test this instead + bool userEnabledCuDNN() const; + void setUserEnabledCuDNN(bool e); + bool userEnabledMkldnn() const; + void setUserEnabledMkldnn(bool e); + bool benchmarkCuDNN() const; + void setBenchmarkCuDNN(bool /*b*/); + int benchmarkLimitCuDNN() const; + void setBenchmarkLimitCuDNN(int /*b*/); + bool immediateMiopen() const; + void setImmediateMiopen(bool /*b*/); + bool deterministicCuDNN() const; + void setDeterministicCuDNN(bool /*b*/); + bool deterministicMkldnn() const; + void setDeterministicMkldnn(bool /*b*/); + bool userEnabledNNPACK() const; + void setUserEnabledNNPACK(bool e); + + // Note [Disabling Fused SDP Kernels] + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + // Flash and Memory Efficient SDP kernels are enabled by default. + // However, they can be disabled by setting + // at::globalContext().setUserEnabledFlashSDP(false) flag. + // This is useful for debugging purposes. For example, if you want to + // compare the performance of the flash SDP kernels with the unfused + // kernel, you can disable the flash SDP kernels. By disabling + // the math SDP kernel, you can force your code to use flash kernels. + // The math SDP kernel can be disabled by setting + // at::globalContext().setUserEnabledMathSDP(false) flag. 
+ void setSDPPriorityOrder(const std::vector& order); + std::array sDPPriorityOrder(); + + void setSDPUseFlash(bool /*e*/); + bool userEnabledFlashSDP() const; + + void setSDPUseMemEfficient(bool /*e*/); + bool userEnabledMemEfficientSDP() const; + + void setSDPUseMath(bool /*e*/); + bool userEnabledMathSDP() const; + + void setSDPUseCuDNN(bool /*e*/); + bool userEnabledCuDNNSDP() const; + + void setAllowFP16BF16ReductionMathSDP(bool /*e*/); + bool allowFP16BF16ReductionMathSDP() const; + + void setSDPUseOverrideable(bool /*e*/); + bool userEnabledOverrideableSDP() const; + + at::LinalgBackend linalgPreferredBackend() const; + void setLinalgPreferredBackend(at::LinalgBackend /*b*/); + + at::BlasBackend blasPreferredBackend(); + void setBlasPreferredBackend(at::BlasBackend /*b*/); + + at::ROCmFABackend getROCmFAPreferredBackend(); + void setROCmFAPreferredBackend(at::ROCmFABackend /*b*/); + + // Note [Enabling Deterministic Operations] + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + // Operations in PyTorch that normally act nondeterministically, but have an + // alternate deterministic implementation, should satisfy the following + // requirements: + // + // * Include this comment: "See Note [Enabling Deterministic Operations]" + // + // * Check the value of `at::globalContext().deterministicAlgorithms()` to + // toggle + // between nondeterministic and deterministic implementations. 
+ // + // * Have an entry in the list of PyTorch operations that toggle between + // nondeterministic + // and deterministic implementations, in the docstring of + // `use_deterministic_algorithms()` in torch/__init__.py + // + // `example_func()` below shows an example of toggling between + // nondeterministic and deterministic implementations: + // + // void example_func() { + // // See Note [Enabling Deterministic Operations] + // if (at::globalContext().deterministicAlgorithms()) { + // example_func_deterministic(); + // } else { + // example_func_nondeterministic(); + // } + // } + + bool deterministicAlgorithms() const; + bool deterministicAlgorithmsWarnOnly() const; + void setDeterministicAlgorithms(bool /*b*/, bool /*warn_only*/); + bool deterministicFillUninitializedMemory() const; + void setDeterministicFillUninitializedMemory(bool /*b*/); + + // Note [Writing Nondeterministic Operations] + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + // Operations in PyTorch that act nondeterministically and do not have an + // alternate deterministic implementation should satisfy the following + // requirements: + // + // * Include this comment: "See Note [Writing Nondeterministic Operations]" + // + // * Include a comment explaining why the operation is nondeterministic. + // + // * Throw an error when `Context::deterministicAlgorithms()` is true. Most + // of the time, this should be accomplished by calling + // `at::globalContext().alertNotDeterministic()`. + // + // * Have an entry in the list of nondeterministic PyTorch operations in the + // docstring of `use_deterministic_algorithms()` in torch/__init__.py + // + // * Have a test function in `test/test_torch.py` whose name begins with + // `test_nondeterministic_alert_`. Alternatively, if CuBLAS workspace + // configuration is the reason for nondeterminism, the operation should be + // included in the `test_cublas_config_nondeterministic_alert` test. 
Any new + // tests should ideally follow a pattern similar to the existing ones. + // + // `example_func()` below shows an example of the comments and error-throwing + // code for a nondeterministic operation: + // + // void example_func() { + // // See Note [Writing Nondeterministic Operations] + // // Nondeterministic because (state the reason here) + // at::globalContext().alertNotDeterministic("example_func"); + // ... + // } + + // Throws an error if `Context::deterministicAlgorithms()` is true + static void alertNotDeterministic(std::string_view const& caller); + + void setFloat32MatmulPrecision(const std::string& s); + void setFloat32Precision( + Float32Backend backend, + Float32Op op, + Float32Precision p); + bool allowTF32CuDNN(std::optional op = std::nullopt) const; + void setAllowTF32CuDNN(bool /*b*/); + bool allowTF32OneDNN() const; + void setAllowTF32OneDNN(bool /*b*/); + bool allowTF32CuBLAS() const; + void setAllowTF32CuBLAS(bool /*b*/); + Float32MatmulPrecision float32MatmulPrecision() const; + Float32Precision float32Precision(Float32Backend backend, Float32Op op) const; + CuBLASReductionOption allowFP16ReductionCuBLAS() const; + void setAllowFP16ReductionCuBLAS( + bool allow_reduced_precision, + bool allow_splitk = true); + CuBLASReductionOption allowBF16ReductionCuBLAS() const; + void setAllowBF16ReductionCuBLAS( + bool allow_reduced_precision, + bool allow_splitk = true); + bool allowFP16AccumulationCuBLAS() const; + void setAllowFP16AccumulationCuBLAS(bool /*b*/); + bool rocmAllowGroupGemmCk() const; + + // Matmuls can use a so-called "persistent" kernel which launches one CUDA + // block for each SM on the GPU, and each block then iterates over multiple + // output tiles. This allows software pipelining to hide the begin/end + // latencies (e.g., epilogue), especially when only one tile fits per SM. 
+ // However, if some SMs are busy (e.g., with a background NCCL kernel), the + // matmul's blocks will be scheduled in two waves and, in the absence of some + // smart load balancing, the kernel will take twice as long. This flag allows + // to make matmuls target only a subset of the SMs, so they can fully schedule + // even next to a comms kernel, and only be a few percent slower. + std::optional _SMCarveout_EXPERIMENTAL() const; + void _setSMCarveout_EXPERIMENTAL(std::optional /*c*/); + + at::QEngine qEngine() const; + void setQEngine(at::QEngine e); + static const std::vector& supportedQEngines(); + static bool isXNNPACKAvailable(); + void setCheckSparseTensorInvariants(bool e); + bool checkSparseTensorInvariants() const; + // This method is used to release the original weight after pre-packing. + // It should be called once before loading/running the model. + // NB: By default it is set to true for mobile builds. + void setReleaseWeightsWhenPrepacking(bool e); + bool releaseWeightsWhenPrepacking() const; + + void setDisplayVmapFallbackWarnings(bool enabled); + bool areVmapFallbackWarningsEnabled() const; + + void setWarnOnAccumulateGradStreamMismatch(bool enabled); + bool warnOnAccumulateGradStreamMismatch() const; + + bool isDefaultMobileCPUAllocatorSet(); + void setDefaultMobileCPUAllocator(); + void unsetDefaultMobileCPUAllocator(); + bool allowFP16ReductionCPU() const; + void setAllowFP16ReductionCPU(bool /*b*/); + + // Preserved for BC + void lazyInitCUDA() { + TORCH_WARN_DEPRECATION( + "lazyInitCUDA is deprecated. Please use lazyInitDevice(at::kCUDA) instead.") + lazyInitDevice(at::kCUDA); + } + void lazyInitHIP() { + TORCH_WARN_DEPRECATION( + "lazyInitHIP is deprecated. Please use lazyInitDevice(at::kHIP) instead.") + lazyInitDevice(at::kHIP); + } + void lazyInitXPU() { + TORCH_WARN_DEPRECATION( + "lazyInitXPU is deprecated. 
Please use lazyInitDevice(at::kXPU) instead.") + lazyInitDevice(at::kXPU); + } + void lazyInitMTIA() { + TORCH_WARN_DEPRECATION( + "lazyInitMTIA is deprecated. Please use lazyInitDevice(at::kMTIA) instead.") + lazyInitDevice(at::kMTIA); + } + void lazyInitPrivateUse1() { + TORCH_WARN_DEPRECATION( + "lazyInitPrivateUse1 is deprecated. Please use lazyInitDevice(at::kPrivateUse1) instead.") + lazyInitDevice(at::kPrivateUse1); + } + + private: + std::array init_; + bool enabled_cudnn = true; + bool deterministic_cudnn = false; + bool deterministic_mkldnn = false; + bool _deterministic_algorithms = false; + bool _deterministic_algorithms_warn_only = false; + bool _deterministic_fill_uninitialized_memory = true; + std::array sdp_priority_order = { + at::SDPBackend::flash_attention, + at::SDPBackend::efficient_attention, + at::SDPBackend::math, + at::SDPBackend::cudnn_attention, + at::SDPBackend::overrideable}; + bool enabled_flashSDP = true; + bool enabled_mem_efficientSDP = true; + bool enabled_mathSDP = true; + bool enabled_cudnnSDP = true; + bool enabled_overrideable = true; + bool allow_fp16_bf16_reduction_mathSDP = false; + bool benchmark_cudnn = false; + bool immediate_miopen = false; + Float32MatmulPrecision float32_matmul_precision = + c10::utils::check_env("TORCH_ALLOW_TF32_CUBLAS_OVERRIDE") == true + ? 
at::Float32MatmulPrecision::HIGH + : at::Float32MatmulPrecision::HIGHEST; + int benchmark_limit_cudnn = 10; + bool allow_tf32_cudnn = true; + CuBLASReductionOption allow_fp16_reduction_cublas = + CuBLASReductionOption::AllowReducedPrecisionWithSplitK; + CuBLASReductionOption allow_bf16_reduction_cublas = + CuBLASReductionOption::AllowReducedPrecisionWithSplitK; + bool allow_fp16_accumulation_cublas = false; + std::optional sm_carveout = std::nullopt; + bool enabled_mkldnn = true; + bool allow_tf32_onednn = false; + bool enabled_nnpack = true; + at::LinalgBackend linalg_preferred_backend = + (c10::utils::check_env("TORCH_LINALG_PREFER_CUSOLVER") == true || + c10::utils::check_env("TORCH_LINALG_PREFER_HIPSOLVER") == true) // alias + ? at::LinalgBackend::Cusolver + : at::LinalgBackend::Default; + at::BlasBackend blas_preferred_backend = + (c10::utils::check_env("TORCH_BLAS_PREFER_CUBLASLT") == true || + c10::utils::check_env("TORCH_BLAS_PREFER_HIPBLASLT") == true) // alias + ? at::BlasBackend::Cublaslt + : at::BlasBackend::Default; + at::ROCmFABackend rocm_fa_preferred_backend = + c10::utils::check_env("TORCH_ROCM_FA_PREFER_CK") == true + ? 
at::ROCmFABackend::Ck + : at::ROCmFABackend::Default; +#ifdef C10_MOBILE + bool release_original_weights = true; +#else + bool release_original_weights = false; +#endif + bool display_vmap_fallback_warnings_ = false; + bool warn_on_accumulate_grad_stream_mismatch_ = true; + std::atomic quantized_engine = at::QEngine::NoQEngine; + bool enable_sparse_tensor_invariant_checks = false; + bool allow_fp16_reduction_cpu = false; + + using Key = std::pair; + std::unordered_map> fp32_precision = { + {{Float32Backend::GENERIC, Float32Op::ALL}, Float32Precision::NONE}, + {{Float32Backend::MKLDNN, Float32Op::ALL}, Float32Precision::NONE}, + {{Float32Backend::MKLDNN, Float32Op::CONV}, Float32Precision::NONE}, + {{Float32Backend::MKLDNN, Float32Op::RNN}, Float32Precision::NONE}, + {{Float32Backend::MKLDNN, Float32Op::MATMUL}, Float32Precision::NONE}, + {{Float32Backend::CUDA, Float32Op::ALL}, Float32Precision::NONE}, + {{Float32Backend::CUDA, Float32Op::CONV}, Float32Precision::TF32}, + {{Float32Backend::CUDA, Float32Op::RNN}, Float32Precision::TF32}, + {{Float32Backend::CUDA, Float32Op::MATMUL}, + float32_matmul_precision == at::Float32MatmulPrecision::HIGHEST + ? 
Float32Precision::NONE + : Float32Precision::TF32}, + }; + + Allocator* prev_allocator_ptr_{nullptr}; +}; + +TORCH_API Context& globalContext(); + +inline void init() { + globalContext(); +} + +TORCH_API Allocator* getCPUAllocator(); + +inline DeprecatedTypeProperties& getDeprecatedTypeProperties( + Backend p, + ScalarType s) { + return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties( + p, s); +} + +inline DeprecatedTypeProperties& CPU(ScalarType s) { + return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties( + Backend::CPU, s); +} + +inline DeprecatedTypeProperties& CUDA(ScalarType s) { + return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties( + Backend::CUDA, s); +} + +inline DeprecatedTypeProperties& HIP(ScalarType s) { + return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties( + Backend::HIP, s); +} + +inline DeprecatedTypeProperties& MPS(ScalarType s) { + return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties( + Backend::MPS, s); +} + +inline bool hasCUDA() { + return globalContext().hasCUDA(); +} + +inline bool hasMTIA() { + return globalContext().hasMTIA(); +} + +inline bool hasHIP() { + return globalContext().hasHIP(); +} + +inline bool hasIPU() { + return globalContext().hasIPU(); +} + +inline bool hasXLA() { + return globalContext().hasXLA(); +} + +inline bool hasMPS() { + return globalContext().hasMPS(); +} + +inline bool hasMAIA() { + return globalContext().hasMAIA(); +} + +inline bool hasXPU() { + return globalContext().hasXPU(); +} + +inline bool hasHPU() { + return globalContext().hasHPU(); +} + +// Despite its name, this function returns the number of *CUDA* GPUs. +inline size_t getNumGPUs() { + // WARNING: DO NOT ADD LOGIC TO HANDLE OTHER DEVICE TYPES TO THIS + // FUNCTION. 
If you are interested in interrogating the number of + // devices for a specific device type, add that function to the + // relevant library (e.g., similar to at::cuda::device_count()) + if (hasCUDA() && hasHIP()) { + TORCH_CHECK( + false, + "Enabling both CUDA and HIP in ATen is not supported, as HIP masquerades " + "to be CUDA (e.g., when you say CUDA, on a HIP build of ATen, this actually " + "means HIP. Rebuild PyTorch with one or the other disabled."); + } else if (hasCUDA()) { + return detail::getCUDAHooks().deviceCount(); + } else if (hasHIP()) { + return detail::getHIPHooks().getNumGPUs(); + } else { + return 0; + } +} + +inline bool hasOpenMP() { + return globalContext().hasOpenMP(); +} + +inline bool hasMKL() { + return globalContext().hasMKL(); +} + +inline bool hasKleidiAI() { + return globalContext().hasKleidiAI(); +} + +inline bool hasLAPACK() { + return globalContext().hasLAPACK(); +} + +inline bool hasEigenSparse() { + return globalContext().hasEigenSparse(); +} + +inline bool hasMAGMA() { + return globalContext().hasMAGMA(); +} + +inline bool hasMKLDNN() { + return globalContext().hasMKLDNN(); +} + +inline void manual_seed(uint64_t seed) { + { + auto gen = globalContext().defaultGenerator(c10::DeviceType::CPU); + // See Note [Acquire lock when using random generators] + std::lock_guard lock(gen.mutex()); + gen.set_current_seed(seed); + } + + const auto opt_device_type = at::getAccelerator(); + if (!opt_device_type.has_value()) { + return; + } + const auto num_gpus = globalContext() + .getAcceleratorHooksInterface(opt_device_type) + .deviceCount(); + for (const auto i : c10::irange(num_gpus)) { + auto gen = globalContext().defaultGenerator( + Device(opt_device_type.value(), static_cast(i))); + { + // See Note [Acquire lock when using random generators] + std::lock_guard lock(gen.mutex()); + gen.set_current_seed(seed); + } + } +} + +// When the global flag `allow_tf32` is set to true, cuBLAS handles are +// automatically configured to use math mode 
CUBLAS_TF32_TENSOR_OP_MATH. +// For some operators, such as addmv, TF32 offers no performance improvement +// but causes precision loss. To help this case, this class implements +// a RAII guard that can be used to quickly disable TF32 within its scope. +// +// Usage: +// NoTF32Guard disable_tf32; +struct TORCH_API NoTF32Guard { + NoTF32Guard(); + NoTF32Guard(NoTF32Guard&& other) = delete; + NoTF32Guard(const NoTF32Guard&) = delete; + NoTF32Guard& operator=(const NoTF32Guard&) = delete; + NoTF32Guard& operator=(NoTF32Guard&&) = delete; + ~NoTF32Guard(); + static bool should_disable_tf32(); + + private: + bool changed = false; +}; + +struct TORCH_API ROCmBackwardPassGuard { + ROCmBackwardPassGuard(); + ROCmBackwardPassGuard(ROCmBackwardPassGuard&& other) = delete; + ROCmBackwardPassGuard(const ROCmBackwardPassGuard&) = delete; + ROCmBackwardPassGuard& operator=(const ROCmBackwardPassGuard&) = delete; + ROCmBackwardPassGuard& operator=(ROCmBackwardPassGuard&&) = delete; + ~ROCmBackwardPassGuard(); + static bool is_backward_pass(); +}; +} // namespace at + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." 
+#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/DLConvertor.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/DLConvertor.h new file mode 100644 index 0000000000000000000000000000000000000000..95f9ca90b9927897344e46863f77136d19aa5e3d --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/DLConvertor.h @@ -0,0 +1,81 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include + +// this converter will: +// 1) take a Tensor object and wrap it in the DLPack tensor +// 2) take a dlpack tensor and convert it to the ATen Tensor + +namespace at { + +TORCH_API ScalarType toScalarType(const DLDataType& dtype); +TORCH_API DLManagedTensor* toDLPack(const Tensor& src); +TORCH_API struct DLManagedTensorVersioned* toDLPackVersioned(const Tensor& src); +TORCH_API void toDLPackNonOwning(const Tensor& src, DLTensor* out); +TORCH_API Tensor +fromDLPack(DLManagedTensor* src, std::function deleter = {}); +TORCH_API Tensor fromDLPackVersioned( + DLManagedTensorVersioned* src, + std::function deleter = {}); +TORCH_API DLDataType getDLDataType(const Tensor& t); +TORCH_API DLDevice getDLContext(const Tensor& tensor, const int64_t& device_id); + +// Copies the Tensor if there's a device mismatch or copy is forced. +// This should be used before actually creating the DLPack capsule. +TORCH_API Tensor maybeCopyTensor( + const Tensor& data, + std::optional optional_dl_device, + std::optional copy); + +// Converts the given at::Device into a DLDevice. +TORCH_API DLDevice torchDeviceToDLDevice(at::Device device); + +// Converts the DLDevice to an ATen device. 
+TORCH_API Device dlDeviceToTorchDevice( + DLDeviceType type, + c10::DeviceIndex index, + void* data = nullptr); + +// This trait class is used for retrieving different attributes, such as the +// PyCapsule names and conversion functions for both DLPack tensor classes: +// `DLManagedTensor` and `DLManagedTensorVersioned`. +// +// Each specialization should contain the following 2 traits: +// - `capsule`: actual name of the capsule +// - `used`: name of the capsule after using it +// - `toDLPack`: function for converting a tensor into a DLPack capsule +// - `fromDLPack`: function for creating a tensor from a DLPack capsule +// +// While `toDLPack` is the directly exposed to Python, `fromDLPack` is not. +// Although it contains the core implementation, it lacks the required book +// keeping logic contained in its caller `tensor_fromDLPack`. +// +// That said, `fromDLPack` is used directly in a few DLPack tests that live +// inside ATen (no Python available). +template +struct DLPackTraits {}; + +template <> +struct DLPackTraits { + inline static constexpr const char* capsule = "dltensor"; + inline static constexpr const char* used = "used_dltensor"; + inline static auto toDLPack = at::toDLPack; + inline static auto fromDLPack = at::fromDLPack; +}; + +template <> +struct DLPackTraits { + inline static constexpr const char* capsule = "dltensor_versioned"; + inline static constexpr const char* used = "used_dltensor_versioned"; + inline static auto toDLPack = at::toDLPackVersioned; + inline static auto fromDLPack = at::fromDLPackVersioned; +}; + +} // namespace at + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." 
+#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/DTensorState.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/DTensorState.h new file mode 100644 index 0000000000000000000000000000000000000000..f2449f6c8129d4c3e1b0341af1f324e8fcfdb7ff --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/DTensorState.h @@ -0,0 +1,39 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include + +namespace at { + +TORCH_API bool get_dtensor_allow_implicit_replication(); +TORCH_API void set_dtensor_allow_implicit_replication(bool enabled); + +struct DTensorAllowImplicitReplication { + DTensorAllowImplicitReplication() + : prev_dtensor_allow_implicit_replication_( + get_dtensor_allow_implicit_replication()) { + set_dtensor_allow_implicit_replication(true); + } + + DTensorAllowImplicitReplication(const DTensorAllowImplicitReplication&) = + delete; + DTensorAllowImplicitReplication& operator=( + const DTensorAllowImplicitReplication&) = delete; + DTensorAllowImplicitReplication(DTensorAllowImplicitReplication&&) = delete; + DTensorAllowImplicitReplication& operator=( + DTensorAllowImplicitReplication&&) = delete; + + ~DTensorAllowImplicitReplication() { + set_dtensor_allow_implicit_replication( + prev_dtensor_allow_implicit_replication_); + } + + private: + bool prev_dtensor_allow_implicit_replication_; +}; + +} // namespace at + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." 
+#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/Device.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/Device.h new file mode 100644 index 0000000000000000000000000000000000000000..484129812ba1d31d406e35106839fe5d35edcfed --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/Device.h @@ -0,0 +1,7 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +#include + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/DeviceAccelerator.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/DeviceAccelerator.h new file mode 100644 index 0000000000000000000000000000000000000000..5cea143f8707d958aea007f7d33e782480be89c3 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/DeviceAccelerator.h @@ -0,0 +1,118 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include + +#include +#include + +namespace at::accelerator { + +// Note [Accelerator Concept] +// This file defines the top level Accelerator concept for PyTorch. +// A device is an accelerator per the definition here if: +// - It is mutually exclusive with all other accelerators +// - It performs asynchronous compute via a Stream/Event system +// - It provides a set of common APIs as defined by AcceleratorHooksInterface +// +// As of today, accelerator devices are (in no particular order): +// CUDA, MTIA, XPU, HIP, MPS, PrivateUse1 + +// Ensures that only one accelerator is available (at +// compile time if possible) and return it. 
+// When checked is true, the returned optional always has a value. +TORCH_API std::optional getAccelerator(bool checked = false); + +// Check if the given device type is an accelerator. +TORCH_API bool isAccelerator(c10::DeviceType device_type); + +// Check if the given device type is an accelerator, not the excluded ones. +template < + typename... T, + typename = std::enable_if_t<(std::is_same_v && ...)>> +inline bool isAcceleratorExcluded( + c10::DeviceType device_type, + c10::DeviceType first_excluded, + T... rest_excluded) { + if constexpr (sizeof...(rest_excluded) > 0) { + return device_type != first_excluded && + isAcceleratorExcluded(device_type, rest_excluded...); + } else { + return device_type != first_excluded && isAccelerator(device_type); + } +} + +// Return the number of the device available. Note that this is *REQUIRED* to +// not raise any exception. +TORCH_API c10::DeviceIndex deviceCount(); + +// Set the current device index to the given device index. +TORCH_API void setDeviceIndex(c10::DeviceIndex device_index); + +// Get the current device index. +TORCH_API c10::DeviceIndex getDeviceIndex(); + +// Set the current stream to a given stream. Note that this API doesn't change +// the current device index. +TORCH_API void setCurrentStream(c10::Stream stream); + +// Get the current stream of the given device index. +TORCH_API c10::Stream getCurrentStream(c10::DeviceIndex device_index); + +// Wait (by blocking the calling thread) until all the work previously enqueued +// on the given device index has been completed. +TORCH_API void synchronizeDevice(c10::DeviceIndex device_index); + +// Set the current device index to the given device_index and return the +// original device index that was active before the change. +TORCH_API c10::DeviceIndex exchangeDevice(c10::DeviceIndex device_index); + +// Set the current device index to the given device_index. Avoid creating a new +// context if the context for device_index is not initialized. 
Return the +// original device index that was active before the change. +TORCH_API c10::DeviceIndex maybeExchangeDevice(c10::DeviceIndex device_index); + +// Get the device capability of the given device index. +TORCH_API c10::DeviceCapability getDeviceCapability( + c10::DeviceIndex device_index); + +TORCH_API inline void emptyCache() { + const auto device_type = getAccelerator(true).value(); + at::getDeviceAllocator(device_type)->emptyCache(); +} + +TORCH_API inline at::CachingDeviceAllocator::DeviceStats getDeviceStats( + c10::DeviceIndex device_index) { + const auto device_type = getAccelerator(true).value(); + return at::getDeviceAllocator(device_type)->getDeviceStats(device_index); +} + +TORCH_API inline void resetAccumulatedStats(c10::DeviceIndex device_index) { + const auto device_type = getAccelerator(true).value(); + at::getDeviceAllocator(device_type)->resetAccumulatedStats(device_index); +} + +TORCH_API inline void resetPeakStats(c10::DeviceIndex device_index) { + const auto device_type = getAccelerator(true).value(); + at::getDeviceAllocator(device_type)->resetPeakStats(device_index); +} + +TORCH_API inline std::pair getMemoryInfo( + c10::DeviceIndex device_index) { + const auto device_type = getAccelerator(true).value(); + return at::getDeviceAllocator(device_type)->getMemoryInfo(device_index); +} +} // namespace at::accelerator + +namespace at { +// Keep BC only +using at::accelerator::getAccelerator; +using at::accelerator::isAccelerator; +} // namespace at + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." 
+#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/DeviceGuard.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/DeviceGuard.h new file mode 100644 index 0000000000000000000000000000000000000000..2e54ef3bb0bf111db1a731f077dc142b0b7feaab --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/DeviceGuard.h @@ -0,0 +1,46 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include // TensorList whyyyyy + +namespace at { + +// Are you here because you're wondering why DeviceGuard(tensor) no +// longer works? For code organization reasons, we have temporarily(?) +// removed this constructor from DeviceGuard. The new way to +// spell it is: +// +// OptionalDeviceGuard guard(device_of(tensor)); + +/// Return the Device of a Tensor, if the Tensor is defined. +inline std::optional device_of(const Tensor& t) { + if (t.defined()) { + return t.device(); + } else { + return std::nullopt; + } +} + +inline std::optional device_of(const std::optional& t) { + return t.has_value() ? device_of(t.value()) : std::nullopt; +} + +/// Return the Device of a TensorList, if the list is non-empty and +/// the first Tensor is defined. (This function implicitly assumes +/// that all tensors in the list have the same device.) +inline std::optional device_of(ITensorListRef t) { + if (!t.empty()) { + return device_of(t.front()); + } else { + return std::nullopt; + } +} + +} // namespace at + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." 
+#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/DimVector.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/DimVector.h new file mode 100644 index 0000000000000000000000000000000000000000..fe267f8a808c7749c768cbdb9545eaf10c909982 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/DimVector.h @@ -0,0 +1,7 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +#include + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/Dimname.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/Dimname.h new file mode 100644 index 0000000000000000000000000000000000000000..91749c111a67ed9c1a2f728debd9080f4f60c071 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/Dimname.h @@ -0,0 +1,6 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#include + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." 
+#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/Dispatch.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/Dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..25fec1aba233db5b5b666ec88db2fa44168d3446 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/Dispatch.h @@ -0,0 +1,790 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#ifdef __CUDACC__ +#include // For CUDA_VERSION +#endif + +#ifdef TEMPLATE_SELECTIVE_BUILD +#include +#else +namespace at { +/** + * The method should_include_kernel_dtype() returns true/false + * based on whether the switching code for a specific dtype should be + * included based on build time constants generated from tracing model + * execution. This method will be implemented via code-generation and + * included in this file when code-gen is ready. + */ +inline constexpr bool should_include_kernel_dtype( + const char* /*kernel_tag_str*/, + at::ScalarType /*scalar_type*/ +) { + return true; +} +} // namespace at +#endif + +/** + * In the Facebook internal build (using BUCK), this macro is enabled by + * passing in -c pt.enable_record_kernel_dtype=1 when building the tracer + * binary. 
+ */ +#if defined ENABLE_RECORD_KERNEL_FUNCTION_DTYPE +namespace at::detail { +TORCH_API void record_kernel_function_dtype(std::string name); +} // namespace at::detail + +#define RECORD_KERNEL_FUNCTION_DTYPE(NAME, enum_type) \ + at::detail::record_kernel_function_dtype( \ + std::string(NAME) + "$" + toString(enum_type)); +#else +#define RECORD_KERNEL_FUNCTION_DTYPE(NAME, enum_type) +#endif + +#define AT_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type) \ + do { \ + if constexpr (!at::should_include_kernel_dtype( \ + at_dispatch_name, enum_type)) { \ + TORCH_CHECK( \ + false, \ + "dtype '", \ + toString(enum_type), \ + "' not selected for kernel tag ", \ + at_dispatch_name); \ + } \ + } while (0) + +#define AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, HINT, ...) \ + THO_PRIVATE_CASE_TYPE_USING_HINT_TMPL( \ + AT_PRIVATE_CHECK_SELECTIVE_BUILD, enum_type, HINT, __VA_ARGS__) + +#define AT_DISPATCH_CASE(enum_type, ...) \ + AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, scalar_t, __VA_ARGS__) + +#define AT_DISPATCH_CASE_QINT(enum_type, scalar_type, ...) \ + case enum_type: { \ + AT_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type); \ + using scalar_t = scalar_type; \ + using underlying_t [[maybe_unused]] = typename scalar_t::underlying; \ + [[maybe_unused]] const auto& SCALAR_TYPE = enum_type; \ + [[maybe_unused]] const auto& UNDERLYING_TYPE = toUnderlying(enum_type); \ + return __VA_ARGS__(); \ + } + +#define AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \ + enum_type, scalar_type, bitwidth, qmin, qmax, ...) 
\ + case enum_type: { \ + AT_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type); \ + using scalar_t = scalar_type; \ + using underlying_t [[maybe_unused]] = typename scalar_t::underlying; \ + [[maybe_unused]] const auto& SCALAR_TYPE = enum_type; \ + [[maybe_unused]] const auto& UNDERLYING_TYPE = toUnderlying(enum_type); \ + [[maybe_unused]] int bit_width = bitwidth; \ + [[maybe_unused]] int64_t quant_min = qmin; \ + [[maybe_unused]] int64_t quant_max = qmax; \ + return __VA_ARGS__(); \ + } + +// The AT_DISPATCH_* family of macros provides the ability to +// conveniently generate specializations of a kernel over all of the +// dtypes we care about in PyTorch. We call it "dispatch" because +// we are "dispatching" to the correct, dtype-specific kernel. +// +// A standard usage looks like: +// +// AT_DISPATCH_ALL_TYPES(self.scalar_type(), "op_name", [&] { +// // Your code here, with 'scalar_t' now defined to +// // be the dtype in question +// }); +// +// There are many variations of this macro, so it's important to +// understand exactly /which/ dtypes you want to get instantiated, as +// well as what the "default" set is. +// +// The default set of dtypes that are instantiated (e.g., by +// AT_DISPATCH_ALL_TYPES) are floating point types (float, double), +// and integral types (int32_t, int64_t, int16_t, int8_t, uint8_t), +// but NOT booleans (bool), half-precision floats (Half) or +// complex number (c10::complex, c10::complex). +// This "cut" is somewhat historical (the default types are the +// ones that TH historically supported), but it also reflects the +// fact that the non-default types are "poorly" behaved (booleans +// are NOT integers mod 2, half precision operations ~essentially +// don't exist on CPU, complex numbers are an experimental application). +// +// Here are the questions you should generally ask to decide which +// dispatch you want: +// +// 1. Is this an integral or floating point specific operation? 
+// (If so, you'll want one of the FLOATING or INTEGRAL macros.) +// +// 2. Should half be supported? (If you're on CPU, the answer is almost +// definitely no. If you do want support, use one of the AND_HALF +// macros) +// +// Much rarer situations: +// +// 3. Should bool be supported? (You often have to write your kernel +// differently if arithmetic operations are involved.) If so, +// Use AT_DISPATCH_ALL_TYPES_AND along with ScalarType::Bool +// +// 4. Should complex be supported? The answer is almost always no, +// unless you are working on "generic" code that should work on +// all dtypes. +// +// Parameters: +// ----------- +// +// 1. The NAME argument is a "tag" that is used to trace and then +// conditionally compile fragments of the case statements such +// that the kernel functions are specialized only for the dtypes +// that are needed. The NAME parameter *must* be a build time +// const char* (can't be std::string, etc...) +// +// Please ensure that the NAME is unique for every implementation +// or you run the risk of over-including code for the kernel +// functions. There is no risk of missing out on any code, so +// it's mostly a risk of a Type-2 error, and not a Type-1 error. +// +// Switch-like syntax: +// ------------------- +// There is also a switch-case like syntax which is useful if a kernel +// needs to be specialized for particular scalar types +// +// AT_DISPATCH_SWITCH(self.scalar_type(), "op_name", +// AT_DISPATCH_CASE_INTEGRAL_TYPES([&] { +// op_integral(iter); +// }) +// AT_DISPATCH_CASE_FLOATING_TYPES([&] { +// op_floating(iter); +// }) +// AT_DISPATCH_CASE(kBool, [&] { +// op_bool(iter); +// }) +// ); +// +// For each AT_DISPATCH_FOO macro, there is a corresponding +// AT_DISPATCH_CASE_FOO macro which can be used inside of an +// AT_DISPATCH_SWITCH block. + +// NB: the the_type variable is not used, but we have kept it for +// backwards compatibility. 
It's probably not used by anyone though; +// but we're just being safe (and it doesn't hurt.) Note we must +// use it to shut up warnings about unused store. + +#define AT_DISPATCH_SWITCH(TYPE, NAME, ...) \ + THO_DISPATCH_SWITCH_TMPL( \ + RECORD_KERNEL_FUNCTION_DTYPE, \ + TORCH_CHECK_NOT_IMPLEMENTED, \ + TYPE, \ + NAME, \ + __VA_ARGS__) + +#define AT_DISPATCH_CASE_FLOATING_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Double, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) + +#define AT_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) + +#define AT_DISPATCH_CASE_FLOATING_TYPES_AND_HALF(...) \ + AT_DISPATCH_CASE(at::ScalarType::Double, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) + +#define AT_DISPATCH_FLOATING_TYPES_AND_HALF(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, NAME, AT_DISPATCH_CASE_FLOATING_TYPES_AND_HALF(__VA_ARGS__)) + +#define AT_DISPATCH_CASE_REDUCED_FLOATING_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) + +#define AT_DISPATCH_REDUCED_FLOATING_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, NAME, AT_DISPATCH_CASE_REDUCED_FLOATING_TYPES(__VA_ARGS__)) + +#define AT_DISPATCH_CASE_FLOATING_TYPES_AND(SCALARTYPE, ...) \ + AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__) + +#define AT_DISPATCH_FLOATING_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, \ + NAME, \ + AT_DISPATCH_CASE_FLOATING_TYPES_AND(SCALARTYPE, __VA_ARGS__)) + +#define AT_DISPATCH_CASE_FLOATING_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, ...) 
\ + AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) + +#define AT_DISPATCH_FLOATING_TYPES_AND2( \ + SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, \ + NAME, \ + AT_DISPATCH_CASE_FLOATING_TYPES_AND2( \ + SCALARTYPE1, SCALARTYPE2, __VA_ARGS__)) + +#define AT_DISPATCH_CASE_FLOATING_TYPES_AND3( \ + SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, ...) \ + AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) + +#define AT_DISPATCH_FLOATING_TYPES_AND3( \ + SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, \ + NAME, \ + AT_DISPATCH_CASE_FLOATING_TYPES_AND3( \ + SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, __VA_ARGS__)) + +#define AT_DISPATCH_CASE_FLOATING_TYPES_AND4( \ + SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, ...) \ + AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__) + +#define AT_DISPATCH_CASE_FLOATING_TYPES_AND5( \ + SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, SCALARTYPE5, ...) \ + AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE5, __VA_ARGS__) + +#define AT_DISPATCH_FLOATING_TYPES_AND4( \ + SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, TYPE, NAME, ...) 
\ + AT_DISPATCH_SWITCH( \ + TYPE, \ + NAME, \ + AT_DISPATCH_CASE_FLOATING_TYPES_AND4( \ + SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, __VA_ARGS__)) + +#define AT_DISPATCH_FLOATING_TYPES_AND5( \ + SCALARTYPE1, \ + SCALARTYPE2, \ + SCALARTYPE3, \ + SCALARTYPE4, \ + SCALARTYPE5, \ + TYPE, \ + NAME, \ + ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, \ + NAME, \ + AT_DISPATCH_CASE_FLOATING_TYPES_AND5( \ + SCALARTYPE1, \ + SCALARTYPE2, \ + SCALARTYPE3, \ + SCALARTYPE4, \ + SCALARTYPE5, \ + __VA_ARGS__)) + +#define AT_DISPATCH_CASE_COMPLEX_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::ComplexDouble, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::ComplexFloat, __VA_ARGS__) + +#define AT_DISPATCH_COMPLEX_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_COMPLEX_TYPES(__VA_ARGS__)) + +#define AT_DISPATCH_CASE_COMPLEX_TYPES_AND(SCALARTYPE, ...) \ + AT_DISPATCH_CASE_COMPLEX_TYPES(__VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__) + +#define AT_DISPATCH_COMPLEX_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, NAME, AT_DISPATCH_CASE_COMPLEX_TYPES_AND(SCALARTYPE, __VA_ARGS__)) + +#define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(...) \ + AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__) \ + AT_DISPATCH_CASE_COMPLEX_TYPES(__VA_ARGS__) + +#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, NAME, AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__)) + +#define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND1(SCALARTYPE, ...) \ + AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__) + +#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1( \ + SCALARTYPE, TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, \ + NAME, \ + AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND1( \ + SCALARTYPE, __VA_ARGS__)) + +#define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND2( \ + SCALARTYPE1, SCALARTYPE2, ...) 
\ + AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) + +#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2( \ + SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, \ + NAME, \ + AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND2( \ + SCALARTYPE1, SCALARTYPE2, __VA_ARGS__)) + +#define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND3( \ + SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, ...) \ + AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) + +#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND3( \ + SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, \ + NAME, \ + AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND3( \ + SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, __VA_ARGS__)) + +#define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND4( \ + SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, ...) \ + AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__) + +#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND4( \ + SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, \ + NAME, \ + AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND4( \ + SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, __VA_ARGS__)) + +#define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND5( \ + SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, SCALARTYPE5, ...) 
\ + AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE5, __VA_ARGS__) + +#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND5( \ + SCALARTYPE1, \ + SCALARTYPE2, \ + SCALARTYPE3, \ + SCALARTYPE4, \ + SCALARTYPE5, \ + TYPE, \ + NAME, \ + ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, \ + NAME, \ + AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND5( \ + SCALARTYPE1, \ + SCALARTYPE2, \ + SCALARTYPE3, \ + SCALARTYPE4, \ + SCALARTYPE5, \ + __VA_ARGS__)) + +#define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND6( \ + SCALARTYPE1, \ + SCALARTYPE2, \ + SCALARTYPE3, \ + SCALARTYPE4, \ + SCALARTYPE5, \ + SCALARTYPE6, \ + ...) \ + AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE5, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE6, __VA_ARGS__) + +#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND6( \ + SCALARTYPE1, \ + SCALARTYPE2, \ + SCALARTYPE3, \ + SCALARTYPE4, \ + SCALARTYPE5, \ + SCALARTYPE6, \ + TYPE, \ + NAME, \ + ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, \ + NAME, \ + AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND6( \ + SCALARTYPE1, \ + SCALARTYPE2, \ + SCALARTYPE3, \ + SCALARTYPE4, \ + SCALARTYPE5, \ + SCALARTYPE6, \ + __VA_ARGS__)) + +#define AT_DISPATCH_CASE_INTEGRAL_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) + +#define AT_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) 
\ + AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__)) + +#define AT_DISPATCH_CASE_INTEGRAL_TYPES_AND(SCALARTYPE, ...) \ + AT_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__) + +#define AT_DISPATCH_INTEGRAL_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, \ + NAME, \ + AT_DISPATCH_CASE_INTEGRAL_TYPES_AND(SCALARTYPE, __VA_ARGS__)) + +#define AT_DISPATCH_CASE_ALL_TYPES(...) \ + AT_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__) \ + AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__) + +#define AT_DISPATCH_ALL_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__)) + +#define AT_DISPATCH_CASE_QINT_TYPES(...) \ + AT_DISPATCH_CASE_QINT(at::kQInt8, at::qint8, __VA_ARGS__) \ + AT_DISPATCH_CASE_QINT(at::kQUInt8, at::quint8, __VA_ARGS__) \ + AT_DISPATCH_CASE_QINT(at::kQInt32, at::qint32, __VA_ARGS__) + +#define AT_DISPATCH_QINT_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_QINT_TYPES(__VA_ARGS__)) + +#define AT_DISPATCH_CASE_QINT_TYPES_AND(SCALARTYPE, ...) \ + AT_DISPATCH_CASE_QINT_TYPES(__VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__) + +#define AT_DISPATCH_QINT_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, NAME, AT_DISPATCH_CASE_QINT_TYPES_AND(SCALARTYPE, __VA_ARGS__)) + +#define AT_DISPATCH_CASE_QINT_BYTE_TYPES(...) \ + AT_DISPATCH_CASE_QINT(at::kQInt8, at::qint8, __VA_ARGS__) \ + AT_DISPATCH_CASE_QINT(at::kQUInt8, at::quint8, __VA_ARGS__) + +#define AT_DISPATCH_QINT_BYTE_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_QINT_BYTE_TYPES(__VA_ARGS__)) + +#define AT_DISPATCH_CASE_QINT_AND_SUB_BYTE_TYPES(...) 
\ + AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \ + at::kQInt8, at::qint8, CHAR_BIT, SCHAR_MIN, SCHAR_MAX, __VA_ARGS__) \ + AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \ + at::kQUInt8, at::quint8, CHAR_BIT, 0, UCHAR_MAX, __VA_ARGS__) \ + AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \ + at::kQInt32, \ + at::qint32, \ + CHAR_BIT * sizeof(int), \ + INT_MIN, \ + INT_MAX, \ + __VA_ARGS__) \ + AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \ + at::kQUInt4x2, at::quint4x2, 4, 0, 15, __VA_ARGS__) \ + AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \ + at::kQUInt2x4, at::quint2x4, 2, 0, 3, __VA_ARGS__) + +#define AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, NAME, AT_DISPATCH_CASE_QINT_AND_SUB_BYTE_TYPES(__VA_ARGS__)) + +#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(...) \ + AT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__) \ + AT_DISPATCH_CASE_COMPLEX_TYPES(__VA_ARGS__) + +#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, NAME, AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__)) + +#define AT_DISPATCH_CASE_ALL_TYPES_AND(SCALARTYPE, ...) \ + AT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__) + +#define AT_DISPATCH_ALL_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, NAME, AT_DISPATCH_CASE_ALL_TYPES_AND(SCALARTYPE, __VA_ARGS__)) + +#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND(SCALARTYPE, ...) \ + AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__) + +#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(SCALARTYPE, TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, \ + NAME, \ + AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND(SCALARTYPE, __VA_ARGS__)) + +#define AT_DISPATCH_CASE_ALL_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, ...) \ + AT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) + +#define AT_DISPATCH_ALL_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...) 
\ + AT_DISPATCH_SWITCH( \ + TYPE, \ + NAME, \ + AT_DISPATCH_CASE_ALL_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, __VA_ARGS__)) + +#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND2( \ + SCALARTYPE1, SCALARTYPE2, ...) \ + AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) + +#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2( \ + SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, \ + NAME, \ + AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND2( \ + SCALARTYPE1, SCALARTYPE2, __VA_ARGS__)) + +#define AT_DISPATCH_CASE_ALL_TYPES_AND3( \ + SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, ...) \ + AT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) + +#define AT_DISPATCH_ALL_TYPES_AND3( \ + SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, \ + NAME, \ + AT_DISPATCH_CASE_ALL_TYPES_AND3( \ + SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, __VA_ARGS__)) + +#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND3( \ + SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, ...) \ + AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) + +#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3( \ + SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, \ + NAME, \ + AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND3( \ + SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, __VA_ARGS__)) + +#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND4( \ + SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, ...) 
\ + AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__) + +#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( \ + SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, \ + NAME, \ + AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND4( \ + SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, __VA_ARGS__)) + +#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND5( \ + SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, SCALARTYPE5, ...) \ + AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE5, __VA_ARGS__) + +#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND5( \ + SCALARTYPE1, \ + SCALARTYPE2, \ + SCALARTYPE3, \ + SCALARTYPE4, \ + SCALARTYPE5, \ + TYPE, \ + NAME, \ + ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, \ + NAME, \ + AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND5( \ + SCALARTYPE1, \ + SCALARTYPE2, \ + SCALARTYPE3, \ + SCALARTYPE4, \ + SCALARTYPE5, \ + __VA_ARGS__)) + +#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND6( \ + SCALARTYPE1, \ + SCALARTYPE2, \ + SCALARTYPE3, \ + SCALARTYPE4, \ + SCALARTYPE5, \ + SCALARTYPE6, \ + ...) \ + AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE5, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE6, __VA_ARGS__) + +#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND6( \ + SCALARTYPE1, \ + SCALARTYPE2, \ + SCALARTYPE3, \ + SCALARTYPE4, \ + SCALARTYPE5, \ + SCALARTYPE6, \ + TYPE, \ + NAME, \ + ...) 
\ + AT_DISPATCH_SWITCH( \ + TYPE, \ + NAME, \ + AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND6( \ + SCALARTYPE1, \ + SCALARTYPE2, \ + SCALARTYPE3, \ + SCALARTYPE4, \ + SCALARTYPE5, \ + SCALARTYPE6, \ + __VA_ARGS__)) + +#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND7( \ + SCALARTYPE1, \ + SCALARTYPE2, \ + SCALARTYPE3, \ + SCALARTYPE4, \ + SCALARTYPE5, \ + SCALARTYPE6, \ + SCALARTYPE7, \ + ...) \ + AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE5, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE6, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE7, __VA_ARGS__) + +#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND7( \ + SCALARTYPE1, \ + SCALARTYPE2, \ + SCALARTYPE3, \ + SCALARTYPE4, \ + SCALARTYPE5, \ + SCALARTYPE6, \ + SCALARTYPE7, \ + TYPE, \ + NAME, \ + ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, \ + NAME, \ + AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND7( \ + SCALARTYPE1, \ + SCALARTYPE2, \ + SCALARTYPE3, \ + SCALARTYPE4, \ + SCALARTYPE5, \ + SCALARTYPE6, \ + SCALARTYPE7, \ + __VA_ARGS__)) + +#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND8( \ + SCALARTYPE1, \ + SCALARTYPE2, \ + SCALARTYPE3, \ + SCALARTYPE4, \ + SCALARTYPE5, \ + SCALARTYPE6, \ + SCALARTYPE7, \ + SCALARTYPE8, \ + ...) 
\ + AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE5, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE6, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE7, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE8, __VA_ARGS__) + +#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND8( \ + SCALARTYPE1, \ + SCALARTYPE2, \ + SCALARTYPE3, \ + SCALARTYPE4, \ + SCALARTYPE5, \ + SCALARTYPE6, \ + SCALARTYPE7, \ + SCALARTYPE8, \ + TYPE, \ + NAME, \ + ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, \ + NAME, \ + AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND8( \ + SCALARTYPE1, \ + SCALARTYPE2, \ + SCALARTYPE3, \ + SCALARTYPE4, \ + SCALARTYPE5, \ + SCALARTYPE6, \ + SCALARTYPE7, \ + SCALARTYPE8, \ + __VA_ARGS__)) + +#define AT_DISPATCH_CASE_BIT_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Bits1x8, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Bits2x4, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Bits4x2, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Bits8, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Bits16, __VA_ARGS__) + +#define AT_DISPATCH_BIT_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_BIT_TYPES(__VA_ARGS__)) + +#define AT_DISPATCH_INDEX_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, \ + NAME, \ + AT_PRIVATE_CASE_TYPE_USING_HINT( \ + at::ScalarType::Int, index_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE_USING_HINT( \ + at::ScalarType::Long, index_t, __VA_ARGS__)) + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." 
+#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/Dispatch_v2.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/Dispatch_v2.h new file mode 100644 index 0000000000000000000000000000000000000000..35cfb4f653c9ea162cdc1289649d319db453d0a4 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/Dispatch_v2.h @@ -0,0 +1,182 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include + +// Get AT_DISPATCH_SWITCH and AT_DISPATCH_CASE: +#include + +// This is a new implementation of the AT_DISPATCH macro family from +// ATen/Dispatch.h +// +// The intended usage is: +// +// ScalarType scalar_type; +// +// AT_DISPATCH_V2( +// scalar_type, +// "debug string", +// AT_WRAP([&] { +// ... code to specialize with scalar_t ... +// }), +// kHalf, +// AT_EXPAND(AT_ALL_TYPES), +// ... as many types arguments as needed ... 
+// ) +// +// For example, given an old style: +// +// AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2( +// kComplexHalf, +// kHalf, +// self.scalar_type(), +// "_local_scalar_dense_cpu", +// [&] { +// scalar_t value = *self.data_ptr(); +// r = Scalar(value); +// } +// ) +// +// You now write: +// +// AT_DISPATCH_V2( +// self.scalar_type(), +// "_local_scalar_dense_cpu", +// AT_WRAP([&] { +// scalar_t value = *self.data_ptr(); +// r = Scalar(value); +// }), +// AT_EXPAND(AT_ALL_TYPES), +// AT_EXPAND(AT_COMPLEX_TYPES), +// kComplexHalf, +// kHalf, +// ) +// +// Notably, it sports the following improvements: +// +// - It is not necessary to specify the arity (e.g., +// AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND{2,3,4,...}) +// when using the macro +// +// - It is not necessary to specify each dtype individually; if +// there is a set of related dtypes and you want to dispatch +// over all of them, you can simply say, e.g., AT_EXPAND(AT_INTEGRAL_TYPES) +// in your argument list. +// +// However, you must remember to wrap the payload body in AT_WRAP, or commas +// inside your lambda will be improperly handled. Furthermore, if you add more +// entries to ScalarType than can be supported by this macro, it will fail +// with an obscure error (due to attempting to concatenate AT_AP with +// something that is not a number). +// +// The implementation strategy is to use the count arguments trick +// (e.g., as described in https://stackoverflow.com/a/2124385/23845) +// to discover how many dtypes have been passed, and then dispatch to a +// hand-written macro for each arity that applies as many DISPATCH_CASE as +// necessary. The hand-written macros can be regenerated for other arities +// with the script below. +// +// There is some delicacy in the implementation in controlling when +// macro expansion occurs, mediated with AT_EXPAND and AT_GUARD. I mostly +// relied on GPT4 to help me get it right. + +// See documentation above +#define AT_DISPATCH_V2(TYPE, NAME, BODY, ...)
\ + THO_DISPATCH_V2_TMPL( \ + AT_DISPATCH_SWITCH, \ + AT_DISPATCH_CASE, \ + TYPE, \ + NAME, \ + AT_WRAP(BODY), \ + __VA_ARGS__) + +// Unused helper macros, kept for BC: +#define AT_AP_VAR(N, T, ...) \ + AT_EXPAND(AT_CONCAT(AT_AP, AT_NUM_ARGS(__VA_ARGS__))(AT_WRAP(N), __VA_ARGS__)) + +// Ensure we never have too many scalar types for the expansion here to +// support. To bump this, you must regenerate the macros below. +static_assert(static_cast(c10::ScalarType::NumOptions) < 60); + +// Python code to regenerate generate code below: +#if 0 + +num_args = 60 + +for i in range(1, num_args+1): + args = ', '.join(f'_{i}' for i in range(1, i+1)) + cases = ' '.join([f'AT_DISPATCH_CASE(_{j}, N)' for j in range(1, i+1)]) + print(f'#define AT_AP{i}(N, {args}) {cases}') + +#endif + +// Begin generated code +// clang-format off + +#define AT_AP1(N, _1) AT_DISPATCH_CASE(_1, N) +#define AT_AP2(N, _1, _2) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) +#define AT_AP3(N, _1, _2, _3) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) +#define AT_AP4(N, _1, _2, _3, _4) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) +#define AT_AP5(N, _1, _2, _3, _4, _5) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) +#define AT_AP6(N, _1, _2, _3, _4, _5, _6) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) +#define AT_AP7(N, _1, _2, _3, _4, _5, _6, _7) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) +#define AT_AP8(N, _1, _2, _3, _4, _5, _6, _7, _8) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) +#define 
AT_AP9(N, _1, _2, _3, _4, _5, _6, _7, _8, _9) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) +#define AT_AP10(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) +#define AT_AP11(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) +#define AT_AP12(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) +#define AT_AP13(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) +#define AT_AP14(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) 
AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) +#define AT_AP15(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) +#define AT_AP16(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) +#define AT_AP17(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) +#define AT_AP18(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) 
AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) +#define AT_AP19(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) +#define AT_AP20(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) +#define AT_AP21(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) +#define AT_AP22(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, 
_22) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) +#define AT_AP23(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) +#define AT_AP24(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) +#define AT_AP25(N, _1, _2, _3, _4, _5, 
_6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) +#define AT_AP26(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) +#define AT_AP27(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) 
AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) +#define AT_AP28(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) +#define AT_AP29(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) 
AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) +#define AT_AP30(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) +#define AT_AP31(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) +#define AT_AP32(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, 
_22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) +#define AT_AP33(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) +#define AT_AP34(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, 
_31, _32, _33, _34) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) +#define AT_AP35(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) +#define AT_AP36(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, 
_16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) +#define AT_AP37(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) 
AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) +#define AT_AP38(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) +#define AT_AP39(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) 
AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) +#define AT_AP40(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) +#define AT_AP41(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) 
AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) +#define AT_AP42(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) 
AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) +#define AT_AP43(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) +#define AT_AP44(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) 
AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N) +#define AT_AP45(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, 
N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N) AT_DISPATCH_CASE(_45, N) +#define AT_AP46(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N) AT_DISPATCH_CASE(_45, N) AT_DISPATCH_CASE(_46, N) +#define AT_AP47(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) 
AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N) AT_DISPATCH_CASE(_45, N) AT_DISPATCH_CASE(_46, N) AT_DISPATCH_CASE(_47, N) +#define AT_AP48(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) 
AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N) AT_DISPATCH_CASE(_45, N) AT_DISPATCH_CASE(_46, N) AT_DISPATCH_CASE(_47, N) AT_DISPATCH_CASE(_48, N) +#define AT_AP49(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N) AT_DISPATCH_CASE(_45, N) AT_DISPATCH_CASE(_46, N) AT_DISPATCH_CASE(_47, N) AT_DISPATCH_CASE(_48, N) AT_DISPATCH_CASE(_49, N) +#define AT_AP50(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, 
_24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N) AT_DISPATCH_CASE(_45, N) AT_DISPATCH_CASE(_46, N) AT_DISPATCH_CASE(_47, N) AT_DISPATCH_CASE(_48, N) AT_DISPATCH_CASE(_49, N) AT_DISPATCH_CASE(_50, N) +#define AT_AP51(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) 
AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N) AT_DISPATCH_CASE(_45, N) AT_DISPATCH_CASE(_46, N) AT_DISPATCH_CASE(_47, N) AT_DISPATCH_CASE(_48, N) AT_DISPATCH_CASE(_49, N) AT_DISPATCH_CASE(_50, N) AT_DISPATCH_CASE(_51, N) +#define AT_AP52(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) 
AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N) AT_DISPATCH_CASE(_45, N) AT_DISPATCH_CASE(_46, N) AT_DISPATCH_CASE(_47, N) AT_DISPATCH_CASE(_48, N) AT_DISPATCH_CASE(_49, N) AT_DISPATCH_CASE(_50, N) AT_DISPATCH_CASE(_51, N) AT_DISPATCH_CASE(_52, N) +#define AT_AP53(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N) AT_DISPATCH_CASE(_45, N) AT_DISPATCH_CASE(_46, N) AT_DISPATCH_CASE(_47, N) AT_DISPATCH_CASE(_48, N) AT_DISPATCH_CASE(_49, N) 
AT_DISPATCH_CASE(_50, N) AT_DISPATCH_CASE(_51, N) AT_DISPATCH_CASE(_52, N) AT_DISPATCH_CASE(_53, N) +#define AT_AP54(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53, _54) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N) AT_DISPATCH_CASE(_45, N) AT_DISPATCH_CASE(_46, N) AT_DISPATCH_CASE(_47, N) AT_DISPATCH_CASE(_48, N) AT_DISPATCH_CASE(_49, N) AT_DISPATCH_CASE(_50, N) AT_DISPATCH_CASE(_51, N) AT_DISPATCH_CASE(_52, N) AT_DISPATCH_CASE(_53, N) AT_DISPATCH_CASE(_54, N) +#define AT_AP55(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53, 
_54, _55) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N) AT_DISPATCH_CASE(_45, N) AT_DISPATCH_CASE(_46, N) AT_DISPATCH_CASE(_47, N) AT_DISPATCH_CASE(_48, N) AT_DISPATCH_CASE(_49, N) AT_DISPATCH_CASE(_50, N) AT_DISPATCH_CASE(_51, N) AT_DISPATCH_CASE(_52, N) AT_DISPATCH_CASE(_53, N) AT_DISPATCH_CASE(_54, N) AT_DISPATCH_CASE(_55, N) +#define AT_AP56(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53, _54, _55, _56) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) 
AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N) AT_DISPATCH_CASE(_45, N) AT_DISPATCH_CASE(_46, N) AT_DISPATCH_CASE(_47, N) AT_DISPATCH_CASE(_48, N) AT_DISPATCH_CASE(_49, N) AT_DISPATCH_CASE(_50, N) AT_DISPATCH_CASE(_51, N) AT_DISPATCH_CASE(_52, N) AT_DISPATCH_CASE(_53, N) AT_DISPATCH_CASE(_54, N) AT_DISPATCH_CASE(_55, N) AT_DISPATCH_CASE(_56, N) +#define AT_AP57(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53, _54, _55, _56, _57) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) 
AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N) AT_DISPATCH_CASE(_45, N) AT_DISPATCH_CASE(_46, N) AT_DISPATCH_CASE(_47, N) AT_DISPATCH_CASE(_48, N) AT_DISPATCH_CASE(_49, N) AT_DISPATCH_CASE(_50, N) AT_DISPATCH_CASE(_51, N) AT_DISPATCH_CASE(_52, N) AT_DISPATCH_CASE(_53, N) AT_DISPATCH_CASE(_54, N) AT_DISPATCH_CASE(_55, N) AT_DISPATCH_CASE(_56, N) AT_DISPATCH_CASE(_57, N) +#define AT_AP58(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53, _54, _55, _56, _57, _58) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) 
AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N) AT_DISPATCH_CASE(_45, N) AT_DISPATCH_CASE(_46, N) AT_DISPATCH_CASE(_47, N) AT_DISPATCH_CASE(_48, N) AT_DISPATCH_CASE(_49, N) AT_DISPATCH_CASE(_50, N) AT_DISPATCH_CASE(_51, N) AT_DISPATCH_CASE(_52, N) AT_DISPATCH_CASE(_53, N) AT_DISPATCH_CASE(_54, N) AT_DISPATCH_CASE(_55, N) AT_DISPATCH_CASE(_56, N) AT_DISPATCH_CASE(_57, N) AT_DISPATCH_CASE(_58, N) +#define AT_AP59(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53, _54, _55, _56, _57, _58, _59) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N) AT_DISPATCH_CASE(_45, N) AT_DISPATCH_CASE(_46, N) 
AT_DISPATCH_CASE(_47, N) AT_DISPATCH_CASE(_48, N) AT_DISPATCH_CASE(_49, N) AT_DISPATCH_CASE(_50, N) AT_DISPATCH_CASE(_51, N) AT_DISPATCH_CASE(_52, N) AT_DISPATCH_CASE(_53, N) AT_DISPATCH_CASE(_54, N) AT_DISPATCH_CASE(_55, N) AT_DISPATCH_CASE(_56, N) AT_DISPATCH_CASE(_57, N) AT_DISPATCH_CASE(_58, N) AT_DISPATCH_CASE(_59, N) +#define AT_AP60(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53, _54, _55, _56, _57, _58, _59, _60) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N) AT_DISPATCH_CASE(_45, N) AT_DISPATCH_CASE(_46, N) AT_DISPATCH_CASE(_47, N) AT_DISPATCH_CASE(_48, N) AT_DISPATCH_CASE(_49, N) AT_DISPATCH_CASE(_50, N) AT_DISPATCH_CASE(_51, N) AT_DISPATCH_CASE(_52, N) AT_DISPATCH_CASE(_53, N) AT_DISPATCH_CASE(_54, N) AT_DISPATCH_CASE(_55, 
N) AT_DISPATCH_CASE(_56, N) AT_DISPATCH_CASE(_57, N) AT_DISPATCH_CASE(_58, N) AT_DISPATCH_CASE(_59, N) AT_DISPATCH_CASE(_60, N) + +// End generated code +// clang-format on + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/DynamicLibrary.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/DynamicLibrary.h new file mode 100644 index 0000000000000000000000000000000000000000..a86720c3249192f97947e085f77fd27710cb2a2a --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/DynamicLibrary.h @@ -0,0 +1,41 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include + +namespace c10 { + +class DynamicLibraryError : public Error { + using Error::Error; +}; + +} // namespace c10 + +namespace at { + +struct DynamicLibrary { + AT_DISALLOW_COPY_AND_ASSIGN(DynamicLibrary); + DynamicLibrary(DynamicLibrary&& other) = delete; + DynamicLibrary& operator=(DynamicLibrary&&) = delete; + + TORCH_API DynamicLibrary( + const char* name, + const char* alt_name = nullptr, + bool leak_handle = false); + + TORCH_API void* sym(const char* name); + + TORCH_API ~DynamicLibrary(); + + private: + bool leak_handle; + void* handle = nullptr; +}; + +} // namespace at + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." 
+#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/EmptyTensor.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/EmptyTensor.h new file mode 100644 index 0000000000000000000000000000000000000000..155b54409c9ec10b6fcf9f2af7c185d31839bda5 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/EmptyTensor.h @@ -0,0 +1,171 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +#include + +namespace at::detail { + +inline void check_size_nonnegative(ArrayRef size) { + for (const auto& x : size) { + TORCH_CHECK( + x >= 0, + "Trying to create tensor with negative dimension ", + x, + ": ", + size); + } +} + +inline void check_size_nonnegative(ArrayRef size) { + for (const auto& x : size) { + TORCH_SYM_CHECK( + x.sym_ge(0), + "Trying to create tensor with negative dimension ", + x, + ": ", + size); + } +} + +TORCH_API size_t computeStorageNbytesContiguous( + IntArrayRef sizes, + size_t itemsize, + size_t storage_offset = 0); +TORCH_API SymInt computeStorageNbytesContiguous( + SymIntArrayRef sizes, + const SymInt& itemsize, + const SymInt& storage_offset = 0); +TORCH_API size_t computeStorageNbytes( + IntArrayRef sizes, + IntArrayRef strides, + size_t itemsize, + size_t storage_offset = 0); +TORCH_API SymInt computeStorageNbytes( + SymIntArrayRef sizes, + SymIntArrayRef strides, + const SymInt& itemsize, + const SymInt& storage_offset = 0); + +TORCH_API TensorBase empty_generic( + IntArrayRef size, + c10::Allocator* allocator, + c10::DispatchKeySet ks, + ScalarType scalar_type, + std::optional memory_format_opt); + +TORCH_API TensorBase empty_generic_symint( + SymIntArrayRef size, + c10::Allocator* allocator, + c10::DispatchKeySet ks, + ScalarType scalar_type, + std::optional memory_format_opt); + +TORCH_API TensorBase empty_strided_generic( + IntArrayRef size, 
+ IntArrayRef stride, + c10::Allocator* allocator, + c10::DispatchKeySet ks, + ScalarType scalar_type); + +TORCH_API TensorBase empty_strided_symint_generic( + SymIntArrayRef size, + SymIntArrayRef stride, + c10::Allocator* allocator, + c10::DispatchKeySet ks, + ScalarType scalar_type); + +TORCH_API TensorBase empty_cpu( + IntArrayRef size, + ScalarType dtype, + bool pin_memory = false, + std::optional memory_format_opt = std::nullopt); + +TORCH_API TensorBase empty_cpu( + IntArrayRef size, + std::optional dtype_opt, + std::optional layout_opt, + std::optional device_opt, + std::optional pin_memory_opt, + std::optional memory_format_opt); + +TORCH_API TensorBase empty_cpu(IntArrayRef size, const TensorOptions& options); + +TORCH_API TensorBase empty_strided_cpu( + IntArrayRef size, + IntArrayRef stride, + ScalarType dtype, + bool pin_memory = false); + +TORCH_API TensorBase empty_strided_cpu( + IntArrayRef size, + IntArrayRef stride, + std::optional dtype_opt, + std::optional layout_opt, + std::optional device_opt, + std::optional pin_memory_opt); + +TORCH_API TensorBase empty_strided_cpu( + IntArrayRef size, + IntArrayRef stride, + const TensorOptions& options); + +TORCH_API TensorBase empty_meta( + IntArrayRef size, + ScalarType dtype, + std::optional memory_format_opt = std::nullopt); + +TORCH_API TensorBase empty_meta( + IntArrayRef size, + std::optional dtype_opt, + std::optional layout_opt, + std::optional device_opt, + std::optional pin_memory_opt, + std::optional memory_format_opt); + +TORCH_API TensorBase empty_symint_meta( + SymIntArrayRef size, + std::optional dtype_opt, + std::optional layout_opt, + std::optional device_opt, + std::optional pin_memory_opt, + std::optional memory_format_opt); + +TORCH_API TensorBase empty_meta(IntArrayRef size, const TensorOptions& options); + +TORCH_API TensorBase +empty_strided_meta(IntArrayRef size, IntArrayRef stride, ScalarType dtype); + +TORCH_API TensorBase empty_strided_meta( + IntArrayRef size, + IntArrayRef 
stride, + std::optional dtype_opt, + std::optional layout_opt, + std::optional device_opt, + std::optional pin_memory_opt); + +TORCH_API TensorBase empty_strided_meta( + IntArrayRef size, + IntArrayRef stride, + const TensorOptions& options); + +TORCH_API TensorBase empty_strided_symint_meta( + SymIntArrayRef size, + SymIntArrayRef stride, + ScalarType dtype); + +TORCH_API TensorBase empty_strided_symint_meta( + SymIntArrayRef size, + SymIntArrayRef stride, + std::optional dtype_opt, + std::optional layout_opt, + std::optional device_opt); + +TORCH_API TensorBase empty_strided_symint_meta( + SymIntArrayRef size, + SymIntArrayRef stride, + const TensorOptions& options); + +} // namespace at::detail + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/ExpandBase.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/ExpandBase.h new file mode 100644 index 0000000000000000000000000000000000000000..2d223ca906ef7cf868e57350fa47e775693bee57 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/ExpandBase.h @@ -0,0 +1,35 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#include + +// Broadcasting utilities for working with TensorBase +namespace at { +namespace internal { +TORCH_API TensorBase expand_slow_path(const TensorBase& self, IntArrayRef size); +} // namespace internal + +inline c10::MaybeOwned expand_size( + const TensorBase& self, + IntArrayRef size) { + if (size.equals(self.sizes())) { + return c10::MaybeOwned::borrowed(self); + } + return c10::MaybeOwned::owned( + at::internal::expand_slow_path(self, size)); +} +c10::MaybeOwned expand_size(TensorBase&& self, IntArrayRef size) = + delete; + +inline c10::MaybeOwned expand_inplace( + const 
TensorBase& tensor, + const TensorBase& to_expand) { + return expand_size(to_expand, tensor.sizes()); +} +c10::MaybeOwned expand_inplace( + const TensorBase& tensor, + TensorBase&& to_expand) = delete; + +} // namespace at + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/ExpandUtils.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/ExpandUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..ccadded2b30c08843f38983754aa5ad3b9d2d347 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/ExpandUtils.h @@ -0,0 +1,540 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#include +#endif + +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace at { + +TORCH_API std::vector infer_size(IntArrayRef a, IntArrayRef b); +TORCH_API std::vector infer_size_symint( + SymIntArrayRef a, + SymIntArrayRef b); +TORCH_API DimVector infer_size_dimvector(IntArrayRef a, IntArrayRef b); +TORCH_API SymDimVector +infer_size_symdimvector(SymIntArrayRef a, SymIntArrayRef b); + +// Named type instead of a pair/tuple so that we can be sure to +// construct the vectors in place and get NRVO. 
+template +struct InferExpandGeometryResult { + Container sizes; + Container strides; + explicit InferExpandGeometryResult(size_t ndim) + : sizes(ndim), strides(ndim) {} + explicit InferExpandGeometryResult(IntArrayRef sizes_, size_t ndim) + : sizes(sizes_.begin(), sizes_.end()), strides(ndim) {} +}; + +TORCH_API std::tuple, std::vector> +inferExpandGeometry( + IntArrayRef tensor_sizes, + IntArrayRef tensor_strides, + IntArrayRef sizes); + +TORCH_API InferExpandGeometryResult inferExpandGeometry_dimvector( + IntArrayRef tensor_sizes, + IntArrayRef tensor_strides, + IntArrayRef sizes); + +TORCH_API std::vector infer_dense_strides( + IntArrayRef tensor_sizes, + IntArrayRef tensor_strides); + +// True if input shapes are expandable +// NOTE: infer_size did a similar check, please keep them sync if change is +// needed +inline bool are_expandable(IntArrayRef shape1, IntArrayRef shape2) { + size_t ndim1 = shape1.size(); + size_t ndim2 = shape2.size(); + size_t ndim = ndim1 < ndim2 ? ndim1 : ndim2; + + for (int64_t i = static_cast(ndim) - 1; i >= 0; --i) { + if (shape1[--ndim1] == shape2[--ndim2] || shape1[ndim1] == 1 || + shape2[ndim2] == 1) { + continue; + } + return false; + } + return true; +} + +// avoid copy-construction of Tensor by using a reference_wrapper. +inline void check_defined( + std::initializer_list> tensors, + const char* api_name) { + for (auto& t : tensors) { + if (!t.get().defined()) { + TORCH_CHECK(false, api_name, "(...) called with an undefined Tensor"); + } + } +} + +// NOTE [ ExpandUtils Borrowing ] +// +// Functions in ExpandUtils return `c10::MaybeOwned` because +// expansion may not actually be needed, in which case we can improve +// efficiency by returning +// `c10::MaybeOwned::borrowed(to_expand)`. However, this means +// that you need to be careful: the returned `c10::MaybeOwned` +// must not outlive the original `Tensor` object that `to_expand` +// referred to! 
The deleted rvalue reference overloads of these +// functions help with this by preventing trivial use of a temporary +// resulting from a function call, but it is still possible to make a +// mistake. + +inline c10::MaybeOwned expand_inplace( + const Tensor& tensor, + const Tensor& to_expand) { + if (tensor.sym_sizes().equals(to_expand.sym_sizes())) { + return c10::MaybeOwned::borrowed(to_expand); + } + return c10::MaybeOwned::owned( + to_expand.expand_symint(tensor.sym_sizes())); +} + +inline c10::MaybeOwned expand_inplace( + const Tensor& tensor, + Tensor&& to_expand) = delete; + +inline c10::MaybeOwned expand_inplace( + const Tensor& tensor, + const Tensor& to_expand, + const char* api_name) { + check_defined({tensor, to_expand}, api_name); + return expand_inplace(tensor, to_expand); +} + +inline c10::MaybeOwned expand_inplace( + const Tensor& tensor, + Tensor&& to_expand, + const char* api_name) = delete; + +inline std::tuple, c10::MaybeOwned> +expand_inplace( + const Tensor& tensor, + const Tensor& to_expand1, + const Tensor& to_expand2) { + if (tensor.sizes().equals(to_expand1.sizes()) && + tensor.sizes().equals((to_expand2.sizes()))) { + return std::make_tuple( + c10::MaybeOwned::borrowed(to_expand1), + c10::MaybeOwned::borrowed(to_expand2)); + } + + return std::make_tuple( + c10::MaybeOwned::owned(to_expand1.expand(tensor.sizes())), + c10::MaybeOwned::owned(to_expand2.expand(tensor.sizes()))); +} + +inline std::tuple, c10::MaybeOwned> +expand_inplace( + const Tensor& tensor, + Tensor&& to_expand1, + const Tensor& to_expand2) = delete; +inline std::tuple, c10::MaybeOwned> +expand_inplace( + const Tensor& tensor, + const Tensor& to_expand1, + Tensor&& to_expand2) = delete; +inline std::tuple, c10::MaybeOwned> +expand_inplace(const Tensor& tensor, Tensor&& to_expand1, Tensor&& to_expand2) = + delete; + +inline std::tuple, c10::MaybeOwned> +expand_inplace( + const Tensor& tensor, + const Tensor& to_expand1, + const Tensor& to_expand2, + const char* api_name) { 
+ check_defined({tensor, to_expand1, to_expand2}, api_name); + return expand_inplace(tensor, to_expand1, to_expand2); +} + +inline std::tuple, c10::MaybeOwned> +expand_inplace( + const Tensor& tensor, + Tensor&& to_expand1, + const Tensor& to_expand2, + const char* api_name) = delete; +inline std::tuple, c10::MaybeOwned> +expand_inplace( + const Tensor& tensor, + const Tensor& to_expand1, + Tensor&& to_expand2, + const char* api_name) = delete; +inline std::tuple, c10::MaybeOwned> +expand_inplace( + const Tensor& tensor, + Tensor&& to_expand1, + Tensor&& to_expand2, + const char* api_name) = delete; + +// See NOTE [ ExpandUtils Borrowing ] above for `MaybeOwned` explanation. +inline std::tuple, c10::MaybeOwned> +expand_outplace(const Tensor& to_expand1, const Tensor& to_expand2) { + auto s1 = to_expand1.sym_sizes(); + auto s2 = to_expand2.sym_sizes(); + if (s1.equals(s2)) { + return std::make_tuple( + c10::MaybeOwned::borrowed(to_expand1), + c10::MaybeOwned::borrowed(to_expand2)); + } + + auto expanded_size = infer_size_symdimvector(s1, s2); + return std::make_tuple( + c10::MaybeOwned::owned(to_expand1.expand_symint(expanded_size)), + c10::MaybeOwned::owned(to_expand2.expand_symint(expanded_size))); +} + +inline std::tuple, c10::MaybeOwned> +expand_outplace(Tensor&& to_expand1, const Tensor& to_expand2) = delete; +inline std::tuple, c10::MaybeOwned> +expand_outplace(const Tensor& to_expand1, Tensor&& to_expand2) = delete; +inline std::tuple, c10::MaybeOwned> +expand_outplace(Tensor&& to_expand1, Tensor&& to_expand2) = delete; + +inline std::tuple, c10::MaybeOwned> +expand_outplace( + const Tensor& to_expand1, + const Tensor& to_expand2, + const char* api_name) { + check_defined({to_expand1, to_expand2}, api_name); + return expand_outplace(to_expand1, to_expand2); +} + +inline std::tuple, c10::MaybeOwned> +expand_outplace( + Tensor&& to_expand1, + const Tensor& to_expand2, + const char* api_name) = delete; +inline std::tuple, c10::MaybeOwned> +expand_outplace( + 
const Tensor& to_expand1, + Tensor&& to_expand2, + const char* api_name) = delete; +inline std::tuple, c10::MaybeOwned> +expand_outplace( + Tensor&& to_expand1, + Tensor&& to_expand2, + const char* api_name) = delete; + +inline std::tuple< + c10::MaybeOwned, + c10::MaybeOwned, + c10::MaybeOwned> +expand_outplace( + const Tensor& to_expand1, + const Tensor& to_expand2, + const Tensor& to_expand3) { + if (to_expand1.sizes().equals(to_expand2.sizes()) && + to_expand1.sizes().equals(to_expand3.sizes())) { + return std::make_tuple( + c10::MaybeOwned::borrowed(to_expand1), + c10::MaybeOwned::borrowed(to_expand2), + c10::MaybeOwned::borrowed(to_expand3)); + } + + auto expanded_size12 = + infer_size_dimvector(to_expand1.sizes(), to_expand2.sizes()); + auto expanded_size = + infer_size_dimvector(expanded_size12, to_expand3.sizes()); + return std::make_tuple( + c10::MaybeOwned::owned(to_expand1.expand(expanded_size)), + c10::MaybeOwned::owned(to_expand2.expand(expanded_size)), + c10::MaybeOwned::owned(to_expand3.expand(expanded_size))); +} + +inline std::tuple< + c10::MaybeOwned, + c10::MaybeOwned, + c10::MaybeOwned> +expand_outplace( + Tensor&& to_expand1, + const Tensor& to_expand2, + const Tensor& to_expand3) = delete; +inline std::tuple< + c10::MaybeOwned, + c10::MaybeOwned, + c10::MaybeOwned> +expand_outplace( + const Tensor& to_expand1, + Tensor&& to_expand2, + const Tensor& to_expand3) = delete; +inline std::tuple< + c10::MaybeOwned, + c10::MaybeOwned, + c10::MaybeOwned> +expand_outplace( + Tensor&& to_expand1, + Tensor&& to_expand2, + const Tensor& to_expand3) = delete; +inline std::tuple< + c10::MaybeOwned, + c10::MaybeOwned, + c10::MaybeOwned> +expand_outplace( + const Tensor& to_expand1, + const Tensor& to_expand2, + Tensor&& to_expand3) = delete; +inline std::tuple< + c10::MaybeOwned, + c10::MaybeOwned, + c10::MaybeOwned> +expand_outplace( + Tensor&& to_expand1, + const Tensor& to_expand2, + Tensor&& to_expand3) = delete; +inline std::tuple< + c10::MaybeOwned, + 
c10::MaybeOwned, + c10::MaybeOwned> +expand_outplace( + const Tensor& to_expand1, + Tensor&& to_expand2, + Tensor&& to_expand3) = delete; +inline std::tuple< + c10::MaybeOwned, + c10::MaybeOwned, + c10::MaybeOwned> +expand_outplace(Tensor&& to_expand1, Tensor&& to_expand2, Tensor&& to_expand3) = + delete; + +inline std::tuple< + c10::MaybeOwned, + c10::MaybeOwned, + c10::MaybeOwned> +expand_outplace( + const Tensor& to_expand1, + const Tensor& to_expand2, + const Tensor& to_expand3, + const char* api_name) { + check_defined({to_expand1, to_expand2, to_expand3}, api_name); + return expand_outplace(to_expand1, to_expand2, to_expand3); +} + +inline std::tuple< + c10::MaybeOwned, + c10::MaybeOwned, + c10::MaybeOwned> +expand_outplace( + Tensor&& to_expand1, + const Tensor& to_expand2, + const Tensor& to_expand3, + const char* api_name) = delete; +inline std::tuple< + c10::MaybeOwned, + c10::MaybeOwned, + c10::MaybeOwned> +expand_outplace( + const Tensor& to_expand1, + Tensor&& to_expand2, + const Tensor& to_expand3, + const char* api_name) = delete; +inline std::tuple< + c10::MaybeOwned, + c10::MaybeOwned, + c10::MaybeOwned> +expand_outplace( + Tensor&& to_expand1, + Tensor&& to_expand2, + const Tensor& to_expand3, + const char* api_name) = delete; +inline std::tuple< + c10::MaybeOwned, + c10::MaybeOwned, + c10::MaybeOwned> +expand_outplace( + const Tensor& to_expand1, + const Tensor& to_expand2, + Tensor&& to_expand3, + const char* api_name) = delete; +inline std::tuple< + c10::MaybeOwned, + c10::MaybeOwned, + c10::MaybeOwned> +expand_outplace( + Tensor&& to_expand1, + const Tensor& to_expand2, + Tensor&& to_expand3, + const char* api_name) = delete; +inline std::tuple< + c10::MaybeOwned, + c10::MaybeOwned, + c10::MaybeOwned> +expand_outplace( + const Tensor& to_expand1, + Tensor&& to_expand2, + Tensor&& to_expand3, + const char* api_name) = delete; +inline std::tuple< + c10::MaybeOwned, + c10::MaybeOwned, + c10::MaybeOwned> +expand_outplace( + Tensor&& to_expand1, + 
Tensor&& to_expand2, + Tensor&& to_expand3, + const char* api_name) = delete; + +inline c10::MaybeOwned expand_size( + const Tensor& to_expand, + IntArrayRef sizes) { + if (to_expand.sizes().equals(sizes)) { + return c10::MaybeOwned::borrowed(to_expand); + } + + return c10::MaybeOwned::owned(to_expand.expand(sizes)); +} + +inline c10::MaybeOwned expand_size( + Tensor&& to_expand, + IntArrayRef sizes) = delete; + +inline c10::MaybeOwned expand_size( + const Tensor& to_expand, + IntArrayRef sizes, + const char* api_name) { + check_defined({to_expand}, api_name); + return expand_size(to_expand, sizes); +} + +inline c10::MaybeOwned expand_size( + Tensor&& to_expand, + IntArrayRef sizes, + const char* api_name) = delete; + +inline std::vector expand_outplace(TensorList to_expand) { + // expands a list of Tensors; ignores undefined (null) tensors + bool first = true; + SymDimVector sizes; + for (const auto i : c10::irange(to_expand.size())) { + if (!to_expand[i].defined()) { + continue; + } else if (first) { + sizes = to_expand[i].sym_sizes(); + first = false; + } else { + sizes = infer_size_symdimvector(sizes, to_expand[i].sym_sizes()); + } + } + + std::vector result(to_expand.size()); + for (const auto i : c10::irange(to_expand.size())) { + if (!to_expand[i].defined()) { + continue; + } else if (to_expand[i].sym_sizes().equals(sizes)) { + result[i] = to_expand[i]; + } else { + result[i] = to_expand[i].expand_symint(sizes); + } + } + return result; +} + +template +inline Tensor _sum_to( + Tensor tensor, + const c10::ArrayRef shape, + bool always_return_non_view = false) { + if (shape.size() == 0) { + return tensor.sum(); + } + + auto sizes = at::symint::sizes(tensor); + c10::SmallVector reduce_dims; + const int64_t leading_dims = sizes.size() - shape.size(); + for (const auto i : c10::irange(leading_dims)) { + reduce_dims.push_back(i); + } + for (int64_t i = leading_dims; i < static_cast(sizes.size()); ++i) { + if (TORCH_GUARD_OR_FALSE(sym_eq(shape[i - leading_dims], 
1)) && + TORCH_GUARD_OR_TRUE(sym_ne(sizes[i], 1))) { + reduce_dims.push_back(i); + } else { + // if we assume no reduction due to unbacked we ensure that at runtime. + TORCH_MAYBE_SYM_CHECK( + sym_eq(shape[i - leading_dims], sizes[i]), + "non-reduction path was assumed due to unbacked symbols expected those two sizes to be the same:", + shape[i - leading_dims], + ", ", + sizes[i]) + } + } + + if (!reduce_dims.empty()) { + tensor = tensor.sum(reduce_dims, /*keepdim=*/true); + } + + if (always_return_non_view) { + // This is only actually used by the functionalization pass. + // We want to be able to guarantee that this function doesn't return a view + // of the input. + return leading_dims > 0 ? at::symint::view_copy(tensor, shape) + : tensor.clone(); + } else { + return leading_dims > 0 ? at::symint::view(tensor, shape) : tensor; + } +} + +inline Tensor sum_to( + Tensor tensor, + const c10::SymIntArrayRef shape, + bool always_return_non_view = false) { + return _sum_to(std::move(tensor), shape, always_return_non_view); +} + +// Sums `tensor` repeatedly to produce a tensor of shape `shape`. 
+// Precondition: is_expandable_to(shape, tensor.sizes()) must be true +inline Tensor sum_to( + Tensor tensor, + const IntArrayRef shape, + bool always_return_non_view = false) { + return _sum_to(std::move(tensor), shape, always_return_non_view); +} + +inline bool is_expandable_to( + SymIntArrayRef shape, + c10::SymIntArrayRef desired) { + size_t ndim = shape.size(); + size_t target_dim = desired.size(); + if (ndim > target_dim) { + return false; + } + for (const auto i : c10::irange(ndim)) { + const auto& size = shape[ndim - i - 1]; + const auto& target = desired[target_dim - i - 1]; + if (size != target && size != 1) { + return false; + } + } + return true; +} + +inline bool is_expandable_to(IntArrayRef shape, IntArrayRef desired) { + auto sym_shape = c10::SymIntArrayRef( + reinterpret_cast(shape.data()), shape.size()); + auto sym_desired = c10::SymIntArrayRef( + reinterpret_cast(desired.data()), desired.size()); + return is_expandable_to(sym_shape, sym_desired); +} + +} // namespace at + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/Formatting.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/Formatting.h new file mode 100644 index 0000000000000000000000000000000000000000..446f03d859315a91f8e3eb16c74057e3b633cdea --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/Formatting.h @@ -0,0 +1,6 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#include + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." 
+#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/FuncTorchTLS.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/FuncTorchTLS.h new file mode 100644 index 0000000000000000000000000000000000000000..7b6e133b0730748cd5923ded1b440e5616c9dbe3 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/FuncTorchTLS.h @@ -0,0 +1,51 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include + +namespace at::functorch { + +// NOTE [functorch TLS in pytorch/pytorch] +// +// functorch lives out-of-tree. However, it has some TLS that needs to be +// propagated. The solution for that is we store a pointer to the TLS +// inside pytorch/pytorch and extend FuncTorchTLSBase inside functorch to +// include whatever functorch needs. +// +// We need to store a pointer due to the indirection: +// inside functorch, we will create a subclass of FunctorchTLSBase called +// FuncTorchTLSImpl that actually contains metadata, like the DynamicLayerStack. +// FuncTorchTLSBase doesn't have any metadata because it hasn't been defined +// yet. +// +// Here in pytorch/pytorch, we will pass around FuncTorchTLSBase*, but inside +// functorch, we will assign a FuncTorchTLSImpl* to the FunctorchTLSBase*. +// We can't directly pass around FunctorchTLSBase (without a pointer) because +// FuncTorchTLSImpl does not fit inside a FuncTorchTLSBase by virtue of having +// more elements. 
+struct TORCH_API FuncTorchTLSBase { + virtual ~FuncTorchTLSBase() = default; + virtual std::unique_ptr deepcopy() const = 0; + + virtual int64_t checkSupportsSingleLevelAutogradFunction() const = 0; + virtual void checkSupportsCppAutogradFunction() const = 0; + virtual void checkSupportsInplaceRequiresGrad() const = 0; + virtual void checkSupportsRetainGrad() const = 0; +}; + +// returns deepcopy of the functorch tls +TORCH_API std::unique_ptr getCopyOfFuncTorchTLS(); + +// sets the functorch tls. always does a deep copy. +TORCH_API void setFuncTorchTLS( + const std::shared_ptr& state); + +// get a mutable reference to the functorch tls +TORCH_API std::unique_ptr& functorchTLSAccessor(); + +} // namespace at::functorch + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/FunctionalStorageImpl.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/FunctionalStorageImpl.h new file mode 100644 index 0000000000000000000000000000000000000000..1ff595ec015cc1575985b406c1ee93018ae9d21c --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/FunctionalStorageImpl.h @@ -0,0 +1,274 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include + +#include + +namespace at::functionalization { + +// See Note [Functionalization Pass In Core] + +enum class InverseReturnMode { + /// Specifies that functional inverses should always return a view. + AlwaysView, + /// Specifies that functional inverses should always return a non-view / copy. + NeverView, + /// Specifies that functional inverses should return a view unless a (copying) + /// scatter + /// inverse exists, in which case that will be used instead. 
+ /// This avoids as_strided() calls that can be difficult for subclasses to + /// handle. + ViewOrScatterInverse, +}; + +#define FUNCTIONALIZATION_VIEWMETA_NAME(TYPE) \ + static const char* name() { \ + return #TYPE; \ + } + +#define FUNCTIONALIZATION_VIEWMETA_SERIALIZABLE_TUPLE(...) \ + using SerializableTuple = std::tuple<__VA_ARGS__> + +// ViewMeta is a class used by the functionalization pass to navigate between +// a base tensor and a view tensor. +// For example, if I call `b = a.view1(...)` +// the functionalization pass will generate and store a ViewMeta specialization +// for `view1` operation on b that looks like: +// +// struct TORCH_API view1_ViewMeta : public ViewMeta { +// FUNCTIONALIZATION_VIEWMETA_NAME(view1_ViewMeta); +// FUNCTIONALIZATION_VIEWMETA_SERIALIZABLE_TUPLE( +// bool /* reapply_views */, +// const std::vector&); +// +// view1_ViewMeta(const SerializableTuple& tpl) +// : view1_ViewMeta(std::get<0>(tpl), std::get<1>(tpl)) {} +// +// view1_ViewMeta(bool reapply_views, const std::vector& size) +// : ViewMeta(/*has_symbolic_inputs=*/false), +// reapply_views(reapply_views), +// size(size) {} +// +// Tensor forward(const Tensor& base) override { +// return base.view1(...); +// } +// +// Tensor reverse(const Tensor& base, const Tensor& mutated_view) override { +// return at::functionalization::impl::view1_inverse(base, mutated_view, +// ...); +// } +// +// SerializableTuple to_serializable_tuple() { +// return std::make_tuple(reapply_views, size); +// } +// +// bool reapply_views; +// std::vector size; +// }; +// +// The forward function describes how to replay view1 on a tensor. +// +// The reverse function describes how, given a tensor that is already a view, +// how to get the corresponding base tensor. See Note [Functionalization Pass: +// View Inverses] for details. +// +// `SerializedTuple` is a typedef that defines an `std::tuple<...>` type +// representing the `ViewMeta` instance state. 
Methods that take in/return such +// a type are used for supporting pickle serialization. +struct ViewMeta { + ViewMeta( + bool has_symbolic_inputs, + bool is_multi_output = false, + bool is_as_strided = false, + int64_t out_idx = 0) + : out_index(out_idx), + is_multi_output(is_multi_output), + is_as_strided(is_as_strided), + has_symbolic_inputs(has_symbolic_inputs) {} + + virtual ~ViewMeta() = default; + + virtual Tensor forward(const Tensor& base) = 0; + virtual Tensor reverse(const Tensor& base, const Tensor& mutated_view) = 0; + + // See Note [out_idx in ViewMeta] + int64_t out_index; + + // Tells us if this is a multi-output view + bool is_multi_output; + + bool is_as_strided; + + // Tells us if this view operation has any symbolic inputs + bool has_symbolic_inputs; + + // Returns a new ViewMeta with the same forward/reverse + // functions, but a new out index. + // + // This method should be implemented by those `ViewMeta` that have more than + // one output. + virtual std::shared_ptr to_out_index(int64_t out_index) { + TORCH_CHECK_NOT_IMPLEMENTED( + false, + "ViewMeta::to_out_index not implemented. ", + "Likely because there's only one output."); + } +}; + +// FunctionalStorageImpl is a subclass of StorageImpl used by the +// functionalization pass. It has no underlying data (similar to meta storage). +// It also knows how to reflect mutations to tensors in the absence of a valid +// data pointer. +// +// A storage represents the state shared by (potentially multiple) views of the +// same tensor. For example, in the following code: +// +// b = a.view1(...) +// c = b.view2(...) +// b.add_(1) +// --> storage.add_update(b, {view1_meta}) +// +// The call to add_(1) will result in a call to alias.add_update(b, +// {view1_meta}), queueing up the mutation from b onto the alias. Later, suppose +// c is used in an expression (e.g. you try to print c, or pass it to an +// operator). Doing so will involve "syncing" c. 
First we apply any pending +// updates to the alias, and then we regenerate c by replaying its views off of +// the updated alias. E.g: +// +// print(str(c)) +// --> c.sync_() +// --> alias.apply_updates() // after this, the alias will be updated to +// reflect the mutation to b +struct TORCH_API FunctionalStorageImpl : public c10::StorageImpl { + public: + struct Update { + // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) + const at::Tensor new_val; + // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) + const std::vector> view_metas; + }; + + explicit FunctionalStorageImpl(const Tensor& value); + + void add_update( + const Tensor& updated_val, + const std::vector>& view_metas); + bool apply_updates(); + const Tensor& base() { + return base_; + } + size_t generation() const { + return generation_; + } + void freeze() { + frozen_ = true; + } + + c10::SymInt get_storage_size(bool before) { + if (before) { + return original_storage_size_; + } else { + return curr_storage_size_; + } + } + + ~FunctionalStorageImpl() override = default; + + uint64_t mutation_counter() { + return mutation_counter_; + } + void mark_mutation() { + mutation_counter_++; + } + void mark_mutation_during_no_grad_or_inference_mode() { + mutation_counter_during_no_grad_or_inference_mode_++; + } + void mark_mutation_hidden_from_autograd() { + mutation_counter_hidden_from_autograd_++; + } + + bool are_all_mutations_under_no_grad_or_inference_mode() const { + auto non_autograd_mutations = + mutation_counter_during_no_grad_or_inference_mode_ + + mutation_counter_hidden_from_autograd_; + // The <= is because both counters will technically be incremented, if we + // perform e.g. 
a triton kernel mutation under no_grad + return mutation_counter_ <= non_autograd_mutations; + } + + bool are_all_mutations_hidden_from_autograd() const { + // mutations under no_grad / inference_mode are technically not hidden from + // autograd - they change the version counter + return mutation_counter_ <= mutation_counter_hidden_from_autograd_; + } + + void mark_inductor_storage_resize(c10::SymInt new_size) { + inductor_storage_resized_ = true; + curr_storage_size_ = std::move(new_size); + inductor_storage_resized_counter_++; + } + + bool was_inductor_storage_resized() { + return inductor_storage_resized_; + } + + uint64_t inductor_storage_resized_counter() { + return inductor_storage_resized_counter_; + } + + private: + // NB: base_ should always point to a tensor BELOW the current + // functionalization layer. This is mainly to avoid reference cycles. e.g. + // given `b = a.view(...)` Both a.storage_ and b.storage_ are a + // FunctionStorageImpl containing an Walualias, with contains a Tensor + // `base_`. In this case (where a and b are FunctionalTensorWrapper's), base_ + // should point not to a, but to a's unwrapped value, a.value_` See Note + // [Functionalization: Walualias Removal] for a diagram that shows this + // visually. + at::Tensor base_; + std::vector updates_; + // generation_ gets incremented every time a mutation is queued onto the + // alias. It is used to determine if a given tensor is "up to date", or if it + // needs to be regenerated from the alias. + size_t generation_ = 0; + // If frozen, no more mutations are allowed on this storage. Once frozen, a + // storage cannot be unfrozen. + bool frozen_ = false; + + // These mutation counters are bumped on the storage + // whenever a FunctionalTensorWrapper experiences a mutation. + // When the mutation is under no_grad, or comes from a triton kernel, we also + // bump the corresponding during_no_grad or hidden_from_autograd counters. 
Why + // do we need to detect these two situations separately from "normal" input + // mutations? (1) "normal" input mutations can mutate autograd metadata like + // .grad_fn, + // in which case they need to be replayed outside of the compiled graph + // (2) "no_grad" input mutations are generally safe to keep in the graph (and + // compile), + // but they bump the tensor's VC, so we need to mark_dirty() on the inputs + // in torch.compile + // (3) mutations that are fully hidden from autograd (e.g. from a triton + // kernel) + // do not mutate any autograd state, and be fully kept in the graph + // When we detect that an input was mutated, we need to be able to tell if: + // (1) all of the mutations were from triton kernels + // (2) all of the mutations were under no_grad + uint64_t mutation_counter_during_no_grad_or_inference_mode_ = 0; + uint64_t mutation_counter_ = 0; + uint64_t mutation_counter_hidden_from_autograd_ = 0; + + // Used to tell if: + // (1) There were any storage resizes on a graph input + // (2) The original/curr storage size tell us if these resizes result in a nop + bool inductor_storage_resized_ = false; + uint64_t inductor_storage_resized_counter_ = 0; + c10::SymInt original_storage_size_; + c10::SymInt curr_storage_size_; +}; + +} // namespace at::functionalization + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." 
+#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/FunctionalTensorWrapper.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/FunctionalTensorWrapper.h new file mode 100644 index 0000000000000000000000000000000000000000..5c65cace41f0720913d23b009d3dbd5aa8c838f6 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/FunctionalTensorWrapper.h @@ -0,0 +1,476 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace at { + +// Note [Functionalization Pass In Core] +// The Functionalization pass is used to remove aliasing from a pytorch program. +// +// This is useful for backends that don't support aliasing, like XLA and Vulkan. +// It's also necessary in order to remove mutation from a program, which is +// needed in Functorch. +// +// Consider this program: +// a = torch.ones(...) +// b = a.view(...) +// b.add_(1) +// +// In this program, b is meant to alias with a due to the use of view(). At the +// end of the program, both a and b are full of 2's. However, backends that +// don't support aliasing aren't able to correctly implement the view() +// operator. Instead, they can opt into the Functionalization pass, which will +// sit between the user and the backend, and provide the necessary aliasing +// logic. +// +// The functionalization pass will turn the above program into a slightly +// different program that has the same semantics, transparently to the user, +// that backends like XLA/Vulkan are able to implement a = torch.ones(...) b = +// a.view_copy(...) # view() replaced with view_copy(). Backends like +// XLA/Vulkan can implement this! 
b.add_(1) a.add_(1) # Our functionalization +// pass machinery knows that a and b are aliased - it applies b's mutation to a +// too. +// +// So, how does the functionalization pass keep track of which tensors are +// aliased? The pass works by wrapping EVERY tensor in the program inside of a +// FunctionalTensorWrapper, which knows about its alias'd tensors. +// +// See Note [Functionalization: Alias Removal] for details on the aliasing +// machinery. See Note [Functionalization: Mutation Removal] for details on +// mutation removal. +struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl { + explicit FunctionalTensorWrapper(const Tensor& value); + // Additional constructor to create a FunctionalTensorWrapper directly from an + // underlying tensor that was created from a view. For example, the code b = + // a.view1() will generate a constructor call to FunctionalTensorWrapper(b, a, + // view1_meta) + explicit FunctionalTensorWrapper( + const Tensor& view_value, + const FunctionalTensorWrapper* base, + const std::shared_ptr& meta); + + // Get the underlying, actual tensor, that doesn't know anything about + // functionalization. + const Tensor& value() const { + return value_; + } + // The concept of "level" is only ever important to functorch; it's exposed + // here as more of a hook for functorch to use. + int64_t level() const { + return level_; + } + void set_level(int64_t level) { + level_ = level; + } + bool has_metadata_mutation() const { + return has_metadata_mutation_; + } + uint64_t mutation_counter() const { + return functional_storage_impl()->mutation_counter(); + } + void mark_mutation() { + functional_storage_impl()->mark_mutation(); + } + // Denotes a mutation that's hidden from autograd, + // e.g. 
for the purposes of passing a tensor to a triton kernel + void mark_mutation_hidden_from_autograd() { + functional_storage_impl()->mark_mutation_hidden_from_autograd(); + } + void mark_mutation_during_no_grad_or_inference_mode() { + functional_storage_impl()->mark_mutation_during_no_grad_or_inference_mode(); + } + // Are all the mutations happening to the tensor hidden from autograd + bool are_all_mutations_hidden_from_autograd() const { + return functional_storage_impl()->are_all_mutations_hidden_from_autograd(); + } + // Did all mutations happen under no_grad or inference_mode + // (We also need to ignore mutations fully hidden from autograd here) + bool are_all_mutations_under_no_grad_or_inference_mode() const { + return functional_storage_impl() + ->are_all_mutations_under_no_grad_or_inference_mode(); + } + + void maybe_mark_symbolic(functionalization::ViewMeta* meta) { + is_symbolic_ = is_symbolic_ | meta->has_symbolic_inputs; + } + + bool is_symbolic() const { + return is_symbolic_; + } + + // Retrieves the ViewMeta sequence of this tensor. + const std::vector>& view_metas() + const; + + // Sync's the underlying tensor with its alias, if it's out of date. This + // involves two steps: 1) Apply any pending updates/mutations to the alias 2) + // Replay the views (if any) to regenerate the current tensor off of the + // updated alias. + void sync_(); + // Performs step (1) of the sync. This is its own public API because it's + // needed by view_inplace ops like transpose_. See Note [Functionalization + // Pass - Inplace View Ops] + void regenerate_from_base(); + // Performs step (2) of the sync. This is its own public API because it's + // needed by functorch. functorch wants to make sure that all input tensors to + // a functionalized program have been properly synced so it can properly + // propagate mutations to inputs. It can't just call sync_(), because the + // FunctionalTensorWrapper will look like it has no aliases and sync_ will be + // a noop. 
We use the reference count on storage_ to determine if the wrapper + // is aliased, and by the time functorch is ready to propagate updates to + // inputs, any intermediate views of the input created by the program will + // have been deallocated. This function also returns whether or not the base + // actually had any updates to apply. + bool apply_updates(); + // Takes the current state of value_ and snapshots it, sending it as a pending + // update to the alias. + void commit_update(); + // When any tensor is mutated, the tensor increments its alias's "generation". + // Separately, each tensor maintains its own "generation" counter, which is + // used to determine if it's up-to-date with its alias. The act of syncing a + // tensor will set a tensor's generation equal to its alias's generation. + bool is_up_to_date() const; + // Freezes the storage of this tensor, preventing subsequent mutations + void freeze_storage() const; + // Every FunctionalTensorWrapper contains a vector objects + // describing the series of view ops that ran to generate the current tensor + // from the base tensor. This method is used by inplace-view ops like + // transpose_. It appends a ViewMeta to the existing stack, and refreshes the + // tensor by replaying the views off of the alias. + void mutate_view_meta( + const std::shared_ptr& meta); + + // Custom implementation of self.set_(src) + void set__impl(const FunctionalTensorWrapper* other); + + // Custom implementation of resize_storage_bytes_(self, new_size) + void storage_resize_(const c10::SymInt& new_size); + + // Returns whether the current tensor's data was ever mutated + bool has_data_mutation(); + // + // Returns whether the current FunctionalTensorWrapper + // experienced a set_() call. 
+ bool was_storage_changed() { + return was_storage_changed_; + } + + void mark_storage_changed() { + was_storage_changed_ = true; + storage_changed_counter_++; + } + + uint64_t storage_changed_counter() { + return storage_changed_counter_; + } + + // A FunctionalTensor is considered a base if its not a view of another + // tensor. + bool isBaseTensor() const { + return view_metas_.empty(); + } + + c10::SymInt get_storage_size(bool before) { + return functional_storage_impl()->get_storage_size(before); + } + + // Returns whether the FunctionalTensor experienced an + // untyped_storage().resize_() call + bool was_inductor_storage_resized() { + return functional_storage_impl()->was_inductor_storage_resized(); + } + + bool inductor_storage_resized_counter() { + return functional_storage_impl()->inductor_storage_resized_counter(); + } + // The functionalization pass can be used to remove mutations. + // It does so by replacing any mutation op with it's corresponding + // out-of-place op, followed by a call to replace_(). e.g: + // + // a.add_(1) + // + // will turn into: + // + // tmp = a.add(1) + // a.replace_(tmp) + // + // replace_() swaps out the wrapped tensor, value_, with tmp. + void replace_(const Tensor& other, bool from_lazy_regenerate = false); + + bool is_multi_output_view() { + return is_multi_output_view_; + } + + // See Note[resize_() in functionalization pass] + void maybe_replace_storage(const Tensor& other); + + // Replaces the storage with a new functional storage, + // and clears the view_metas_ stack. + // WARNING: Calling this function will sever the aliasing relationship between + // the current FunctionalTensorWrapper and any of its outstanding aliases. + // Please only call if you know what you're doing. 
+ void _unsafe_reset_storage(); + + c10::intrusive_ptr shallow_copy_and_detach( + const c10::VariableVersion& version_counter, + bool allow_tensor_metadata_change) const override; + + c10::intrusive_ptr shallow_copy_and_detach( + c10::VariableVersion&& version_counter, + bool allow_tensor_metadata_change) const override; + + ~FunctionalTensorWrapper() override = default; + + // FunctionalTensorWrapper overrides all custom size/stride function, + // so that if the inner tensor has a custom implementation + // we make sure to call that implementation. + at::IntArrayRef sizes_custom() const override; + at::IntArrayRef strides_custom() const override; + int64_t dim_custom() const override; + int64_t numel_custom() const override; + c10::SymBool sym_is_contiguous_custom( + at::MemoryFormat memory_format) const override; + c10::SymIntArrayRef sym_sizes_custom() const override; + c10::SymInt sym_size_custom(int64_t d) const override; + c10::SymIntArrayRef sym_strides_custom() const override; + c10::SymInt sym_storage_offset_custom() const override; + c10::Device device_custom() const override; + c10::Layout layout_impl() const override; + + private: + const char* tensorimpl_type_name() const override; + void set_constructor_metadata(); + functionalization::FunctionalStorageImpl* functional_storage_impl() const; + + // This is used to re-implement shallow_copy_and_detach for + // FunctionalTensorWrapper. The implementation is identical, but we just need + // to return a subclass instead of a plain TensorImpl. + // TODO: maybe it's possible to arrange for that to happen automatically + // without an override here? 
+ template + c10::intrusive_ptr shallow_copy_and_detach_core( + VariableVersion&& version_counter, + bool allow_tensor_metadata_change) const; + + void shallow_copy_from(const c10::intrusive_ptr& impl) override; + void copy_tensor_metadata_and_refresh( + const FunctionalTensorWrapper* src_impl, + FunctionalTensorWrapper* dest_impl, + const c10::VariableVersion& version_counter, + bool allow_tensor_metadata_change) const; + + // Note that value is not taken by reference: internally, the wrapper will + // change the value tensor that it points to over time. + Tensor value_; + int64_t level_{}; + // These two counters are used for identifying + // whether all the mutations on a given tensor are hidden from autograd or + // not. If we have an input mutation that is hidden from autograd, then once + // we convert the input mutation to a copy_() we know it will be safe to hide + // the copy_() from autograd as well. + bool has_metadata_mutation_ = false; + bool is_multi_output_view_ = false; + // Did the tensor experience a set_() call. + bool was_storage_changed_ = false; + uint64_t storage_changed_counter_ = 0; + // Did the tensor experience any view operation with symbolic int. + bool is_symbolic_ = false; + + size_t generation_ = 0; + std::vector> view_metas_; + + protected: + static void copy_tensor_metadata( + const FunctionalTensorWrapper* src_impl, + FunctionalTensorWrapper* dest_impl, + const c10::VariableVersion& version_counter, + bool allow_tensor_metadata_change); +}; + +// Utility functions for the functionalization pass. 
+ +namespace functionalization { +namespace impl { + +inline FunctionalTensorWrapper* unsafeGetFunctionalWrapper( + const Tensor& tensor) { + auto functional_impl = + static_cast(tensor.unsafeGetTensorImpl()); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(functional_impl != nullptr); + return functional_impl; +} + +TORCH_API bool isBaseTensor(const at::Tensor& tensor); + +TORCH_API bool isFunctionalTensor(const at::Tensor& tensor); +TORCH_API bool isFunctionalTensor(const std::optional& t); +TORCH_API bool isFunctionalTensor( + const c10::List>& t_list); +TORCH_API bool isFunctionalTensor(ITensorListRef list); + +TORCH_API Tensor to_functional_tensor(const Tensor& tensor); +TORCH_API std::optional to_functional_tensor( + const std::optional& tensor); +TORCH_API c10::List> to_functional_tensor( + const c10::List>& t_list); +TORCH_API std::vector to_functional_tensor(ITensorListRef t_list); + +TORCH_API void freeze_functional_tensor(const Tensor& tensor); + +TORCH_API Tensor +from_functional_tensor(const Tensor& tensor, bool assert_functional = true); +TORCH_API std::optional from_functional_tensor( + const std::optional& t, + bool assert_functional = true); +TORCH_API c10::List> from_functional_tensor( + const c10::List>& t_list); +TORCH_API std::vector from_functional_tensor(ITensorListRef t_list); + +TORCH_API void sync(const at::Tensor& t); +TORCH_API void sync(const std::optional& t); +TORCH_API void sync(const c10::List>& t_list); +TORCH_API void sync(ITensorListRef t_list); + +TORCH_API void replace_(const Tensor& functional_tensor, const Tensor& other); +TORCH_API void replace_( + const ITensorListRef functional_tensor, + ITensorListRef other); + +TORCH_API void commit_update(const Tensor& functional_tensor); +TORCH_API void commit_update(ITensorListRef functional_tensor); + +TORCH_API void unsafe_reset_storage(const Tensor& functional_tensor); + +TORCH_API void mark_mutation_hidden_from_autograd( + const Tensor& functional_tensor); + +TORCH_API bool 
are_all_mutations_hidden_from_autograd( + const Tensor& functional_tensor); + +TORCH_API bool are_all_mutations_under_no_grad_or_inference_mode( + const Tensor& functional_tensor); + +// These two methods are XLA-specific logic and are no-ops +// for the normal functionalization flow. +TORCH_API void propagate_xla_data( + const Tensor& functional_tensor, + const Tensor& other); +TORCH_API void propagate_xla_data( + const ITensorListRef functional_tensor, + ITensorListRef other); + +TORCH_API void propagate_xla_data_direct( + const Tensor& tensor, + const Tensor& other); +TORCH_API void propagate_xla_data_direct( + const ITensorListRef tensor, + ITensorListRef other); + +Tensor create_functional_tensor_with_view_meta( + const Tensor& view_to_wrap, + const Tensor& base, + const std::shared_ptr& meta, + int64_t out_idx = 0); +std::vector create_functional_tensor_with_view_meta( + ITensorListRef view_to_wrap, + const Tensor& base, + const std::shared_ptr& meta); + +void mutate_view_meta( + const Tensor& self, + const std::shared_ptr& meta); + +TORCH_API Tensor apply_view_meta_sequence( + const Tensor& base, + const std::vector>& sequence); + +void set_sizes_strides_offset(const Tensor& out, const Tensor& meta_out); +void set_sizes_strides_offset( + const std::vector& outs, + const std::vector& meta_outs); + +// ~~~~~ TLS used in functionalization ~~~~~ + +TORCH_API bool getFunctionalizationReapplyViewsTLS(); +TORCH_API void setFunctionalizationReapplyViewsTLS(bool reapply_views); + +class TORCH_API FunctionalizationReapplyViewsGuard { + public: + FunctionalizationReapplyViewsGuard(bool reapply_views) + : prev_(getFunctionalizationReapplyViewsTLS()) { + setFunctionalizationReapplyViewsTLS(reapply_views); + } + + ~FunctionalizationReapplyViewsGuard() { + setFunctionalizationReapplyViewsTLS(prev_); + } + + FunctionalizationReapplyViewsGuard( + const FunctionalizationReapplyViewsGuard&) = delete; + FunctionalizationReapplyViewsGuard operator=( + const 
FunctionalizationReapplyViewsGuard&) = delete; + FunctionalizationReapplyViewsGuard(FunctionalizationReapplyViewsGuard&&) = + delete; + FunctionalizationReapplyViewsGuard operator=( + FunctionalizationReapplyViewsGuard&&) = delete; + + private: + bool prev_; +}; + +} // namespace impl + +// Helper function to call an out-of-place composite aten kernel that may use +// mutations / views internally, and functionalize them. +TORCH_API void functionalize_op_helper( + const c10::OperatorHandle& op, + torch::jit::Stack* stack); + +template +struct _functionalize_aten_op final {}; + +template +struct _functionalize_aten_op final { + static ReturnType call( + typename c10::maybe_keep_symint::type... args) { + using FuncType = ReturnType( + typename c10::maybe_keep_symint::type...); + auto op = c10::Dispatcher::singleton() + .findSchemaOrThrow( + (const char*)Op::name, (const char*)Op::overload_name) + .typed(); + + return c10::impl::BoxedKernelWrapper::call( + c10::BoxedKernel::makeFromFunction(), + op, + // BoxedKernelWrapper knows to ignore this keyset argument, + // because functionalize_op_helper doesn't take in a DispatchKeySet + c10::DispatchKeySet(), + args...); + } +}; + +template +using functionalize_aten_op = + _functionalize_aten_op; + +template +using functionalize_aten_op_symint = + _functionalize_aten_op; + +} // namespace functionalization +} // namespace at + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." 
+#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/FunctionalizeFallbackKernel.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/FunctionalizeFallbackKernel.h new file mode 100644 index 0000000000000000000000000000000000000000..d5e533beed3b10fb8b92ecd58feab5c659df2d5e --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/FunctionalizeFallbackKernel.h @@ -0,0 +1,63 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include + +namespace at::functionalization { + +// `ViewMeta` implementation for `resize_` operation. +struct TORCH_API resize__ViewMeta : public ViewMeta { + FUNCTIONALIZATION_VIEWMETA_NAME(resize__ViewMeta) + FUNCTIONALIZATION_VIEWMETA_SERIALIZABLE_TUPLE( + bool /* reapply_views */, + const std::vector&); + + resize__ViewMeta(const SerializableTuple& tpl) + : resize__ViewMeta(std::get<0>(tpl), std::get<1>(tpl)) {} + + resize__ViewMeta(bool reapply_views, const std::vector& size) + : ViewMeta(/*has_symbolic_inputs=*/false), + reapply_views(reapply_views), + size(size) {} + + Tensor forward(const Tensor& base) override; + Tensor reverse(const Tensor& base, const Tensor& mutated_view) override; + + SerializableTuple to_serializable_tuple() { + return std::make_tuple(reapply_views, size); + } + + bool reapply_views; + std::vector size; +}; + +// `ViewMeta` implementation for `_unsafe_view` operation. 
+struct TORCH_API _unsafe_view_ViewMeta : public ViewMeta { + FUNCTIONALIZATION_VIEWMETA_NAME(_unsafe_view_ViewMeta) + FUNCTIONALIZATION_VIEWMETA_SERIALIZABLE_TUPLE( + bool /* has_symbolic_inputs */, + const std::vector&); + + _unsafe_view_ViewMeta(const SerializableTuple& tpl) + : _unsafe_view_ViewMeta(std::get<0>(tpl), std::get<1>(tpl)) {} + + _unsafe_view_ViewMeta( + bool has_symbolic_inputs, + const std::vector& size) + : ViewMeta(has_symbolic_inputs), size(size) {} + + Tensor forward(const Tensor& base) override; + Tensor reverse(const Tensor& base, const Tensor& mutated_view) override; + + SerializableTuple to_serializable_tuple() { + return std::make_tuple(has_symbolic_inputs, size); + } + + std::vector size; +}; + +} // namespace at::functionalization + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/Functions.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/Functions.h new file mode 100644 index 0000000000000000000000000000000000000000..cffda238e71032ff373a8ccec0524ca9a13604a8 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/Functions.h @@ -0,0 +1,1476 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +// @generated by torchgen/gen.py from Functions.h + +#ifdef TORCH_ASSERT_NO_OPERATORS +#error This change adds a dependency on native_functions.yaml, \ + meaning the file will need to be re-compiled every time an operator \ + is changed or added. Consider if your change would be better placed in \ + another file, or if a more specific header might achieve the same goal. \ + See NOTE: [Tensor vs. 
TensorBase] +#endif + +#if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS) +#error This change adds a dependency on all pytorch operators, meaning the \ + file will need to be re-compiled every time an operator is changed or added. \ + Consider including a specific operator from and \ + see NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS]. +#endif + +// NOTE: [TORCH_ASSERT_ONLY_METHOD_OPERATORS] +// +// In ATen, certain generated headers files include the definitions of +// every single operator in PyTorch. Unfortunately this means every +// time an operator signature is updated or changed in +// native_functions.yaml, you (and every other PyTorch developer) need +// to recompile every source file that includes any of these headers. +// +// To break up these header dependencies, and improve incremental +// build times for all PyTorch developers. These headers are split +// into per-operator headers in the `ATen/ops` folder. This limits +// incremental builds to only changes to methods of `Tensor`, or files +// that use the specific operator being changed. With `at::sum` as an +// example, you should include +// +// // instead of ATen/Functions.h +// // instead of ATen/NativeFunctions.h +// // instead of ATen/Operators.h +// // instead of ATen/CPUFunctions.h +// +// However, even if you're careful to use this in your own code. +// `Functions.h` might be included indirectly through another header +// without you realising. To avoid this, you can add +// +// #define TORCH_ASSERT_ONLY_METHOD_OPERATORS +// +// to the top of your source file. This way any time the non-specific +// headers are included, the compiler will error out. +// +// Also, be aware that `ops` are not available in all build +// configurations (namely fb-internal) so you must guard these +// includes with `#ifdef AT_PER_OPERATOR_HEADERS`. e.g. 
+// +// #ifndef AT_PER_OPERATOR_HEADERS +// #include +// #else +// #include +// #endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include 
+#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include 
+#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include 
+#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include 
+#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include 
+#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include 
+#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { + + + +// Special C++ only overloads for std()-like functions (See gh-40287) +// These are needed because int -> bool conversion takes precedence over int -> IntArrayRef +// So, for example std(0) would select the std(unbiased=False) overload +inline Tensor var(const Tensor& self, int dim) { + return at::var(self, IntArrayRef{dim}); +} +inline std::tuple var_mean(const Tensor& self, int dim) { + return at::var_mean(self, 
IntArrayRef{dim}); +} +inline Tensor std(const Tensor& self, int dim) { + return at::std(self, IntArrayRef{dim}); +} +inline std::tuple std_mean(const Tensor& self, int dim) { + return at::std_mean(self, IntArrayRef{dim}); +} + +inline int64_t numel(const Tensor& tensor) { + return tensor.numel(); +} + +inline int64_t size(const Tensor& tensor, int64_t dim) { + return tensor.size(dim); +} + +inline int64_t stride(const Tensor& tensor, int64_t dim) { + return tensor.stride(dim); +} + +inline bool is_complex(const Tensor& tensor) { + return tensor.is_complex(); +} + +inline bool is_floating_point(const Tensor& tensor) { + return tensor.is_floating_point(); +} + +inline bool is_signed(const Tensor& tensor) { + return tensor.is_signed(); +} + +inline bool is_inference(const Tensor& tensor) { + return tensor.is_inference(); +} + +inline bool _is_zerotensor(const Tensor& tensor) { + return tensor._is_zerotensor(); +} + +inline bool is_conj(const Tensor& tensor) { + return tensor.is_conj(); +} + +inline Tensor conj(const Tensor& tensor) { + return tensor.conj(); +} + +inline bool is_neg(const Tensor& tensor) { + return tensor.is_neg(); +} + +} + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/Generator.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/Generator.h new file mode 100644 index 0000000000000000000000000000000000000000..5fad06d7668f88deb252e211d0c2a89d76769a8c --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/Generator.h @@ -0,0 +1,7 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +#include + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." 
+#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/InferSize.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/InferSize.h new file mode 100644 index 0000000000000000000000000000000000000000..33ca7e6b14d6428f40056afe289b062b42207f96 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/InferSize.h @@ -0,0 +1,133 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { + +// Infers the size of a dim with size -1, if it exists. Also checks that new +// shape is compatible with the number of elements. +// +// templated to handle std::vector and DimVector use cases, see +// below +// +template +inline void infer_size_impl( + InputArrayRef shape, + NumelType numel, + ResultVec& res) { + NumelType newsize = 1; + // N.B. this is an index, not a sym dim! + std::optional infer_dim; + for (int64_t dim = 0, ndim = shape.size(); dim != ndim; dim++) { + if (TORCH_GUARD_OR_FALSE(sym_eq(shape[dim], -1))) { + TORCH_CHECK(!infer_dim, "only one dimension can be inferred"); + infer_dim = dim; + } else { + // in case of unbacked shape[dim] we assume it's not -1 and add a runtime + // assertion. + TORCH_MAYBE_SYM_CHECK( + sym_gt(shape[dim], -1), + "invalid shape dimension ", + shape[dim], + " at index ", + dim, + " of shape ", + shape); + newsize *= shape[dim]; + } + } + + if (infer_dim) { + // numel is the product of known sizes, it has to be divisible by newsize. + // and newsize should be positive unless newsize == numel (we throw + // different) error message in that case. + if constexpr (std::is_same_v) { + auto v = newsize.maybe_as_int(); + if (v and *v == 0) { + // Avoid div by 0 when sym_eq(numel % newsize, 0) is constructed! 
+ // which may happen when newsize is not a symbol! if its a symbol + // division won't happen anyway during compile. + TORCH_MAYBE_SYM_CHECK( + numel == newsize, + "shape '", + shape, + "' is invalid for input of size ", + numel); + } else { + auto cond = sym_gt(newsize, 0) + .sym_and(sym_eq(numel % newsize, 0)) + .sym_or(sym_eq(numel, newsize)); + TORCH_MAYBE_SYM_CHECK( + cond, "shape '", shape, "' is invalid for input of size ", numel); + } + + } else { + TORCH_CHECK( + (newsize > 0 && (numel % newsize == 0)) || numel == newsize, + "shape '", + shape, + "' is invalid for input of size ", + numel); + } + + // We have a degree of freedom here to select the dimension size; follow + // NumPy semantics and just bail. However, a nice error message is needed + // because users often use `view` as a way to flatten & unflatten + // dimensions and will otherwise be confused why + // empty_tensor.view( 0, 0) + // works yet + // empty_tensor.view(-1, 0) + // doesn't. + TORCH_MAYBE_SYM_CHECK( + newsize != 0, + "cannot reshape tensor of 0 elements into shape ", + shape, + " because the unspecified dimension size -1 can be any " + "value and is ambiguous"); + + res[*infer_dim] = numel / newsize; + return; + } + + TORCH_MAYBE_SYM_CHECK( + sym_eq(numel, newsize), + "shape '", + shape, + "' is invalid for input of size ", + numel); +} + +inline std::vector infer_size(IntArrayRef shape, int64_t numel) { + auto res = shape.vec(); + infer_size_impl(shape, numel, res); + return res; +} + +inline at::DimVector infer_size_dv(IntArrayRef shape, int64_t numel) { + auto res = at::DimVector(shape); + infer_size_impl(shape, numel, res); + return res; +} + +inline at::SymDimVector infer_size_dv( + c10::SymIntArrayRef shape, + c10::SymInt numel) { + auto res = at::SymDimVector(shape); + infer_size_impl( + shape, std::move(numel), res); + return res; +} + +} // namespace at + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." 
+#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/InitialTensorOptions.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/InitialTensorOptions.h new file mode 100644 index 0000000000000000000000000000000000000000..333dcf3ad83dd27fdeedbfc9a328cbc7eeea7780 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/InitialTensorOptions.h @@ -0,0 +1,20 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include + +namespace at { + +// Represents the initial TensorOptions, before the "defaults" are ever changed. +// This is designed to be used in library code, where the explicit devices, +// dtypes, etc. are known. NOTE: this is not a stable API. +inline TensorOptions initialTensorOptions() { + return TensorOptions(kCPU).dtype(kFloat).layout(kStrided).requires_grad( + false); +} + +} // namespace at + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/Layout.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/Layout.h new file mode 100644 index 0000000000000000000000000000000000000000..e781763d973b63be695b5349f8c07ad15e9f939d --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/Layout.h @@ -0,0 +1,7 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +#include + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." 
+#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/LegacyBatchedFallback.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/LegacyBatchedFallback.h new file mode 100644 index 0000000000000000000000000000000000000000..759e191316a9749aac3d28378e38f2a7d069cead --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/LegacyBatchedFallback.h @@ -0,0 +1,30 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +#include +#include +#include + +namespace at { + +// If an operator doesn't have a batching rule implemented then we fallback +// to this implementation. The fallback only works on out-of-place operators +// that return only tensors with new memory. (e.g., no in-place operators, no +// view operations). +// +// The fallback effectively takes all of the BatchedTensors in `stack`, slices +// them, and runs `op` on all of the corresponding slices to produce slices +// of the outputs. The output slices then get `torch.stack`ed to create the +// final returns. +// +// The performance of the fallback is not very good because it introduces an +// extra copy from stacking the sliced outputs. Because of this, we prefer to +// write batching rules for operators whenever possible. +void batchedTensorForLoopFallback( + const c10::OperatorHandle& op, + torch::jit::Stack* stack); + +} // namespace at + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." 
+#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/LegacyBatchedTensorImpl.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/LegacyBatchedTensorImpl.h new file mode 100644 index 0000000000000000000000000000000000000000..ddeb243ed4744ce2d87fd1194c5a5a7f33d9577c --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/LegacyBatchedTensorImpl.h @@ -0,0 +1,166 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include + +#include +#include +#include + +namespace at { + +// We assume this in a few other places in the codebase, +// but there isn't a centralized definition. +constexpr int64_t kVmapMaxTensorDims = 64; + +// The valid vmap levels range from [0, 64). This effectively means that we +// support a maximum of 64 nested vmaps. +constexpr int64_t kVmapNumLevels = 64; + +// Store this number of elements of BatchDims on the stack. Most people will +// probably use <= 5 nested vmaps, but adjust this number as necessary. +constexpr int64_t kBatchDimsStackSize = 5; + +// a BatchDim represents a "private" dimension on a Tensor created inside of +// vmap. It is a (level, dim) tuple, with the `dim` indicating which dimension +// is being vmap'ed over and the `level` being an identifier for which vmap +// said dimension was created inside. The `dim` corresponds to a "physical +// dim" - it is a dimension index on the underlying physical tensor that is +// being vmapped over. 
+struct BatchDim { + BatchDim(int64_t level, int64_t dim) : dim_(dim), level_(level) {} + int64_t dim() const { + return dim_; + } + int64_t level() const { + return level_; + } + + private: + int64_t dim_; + int64_t level_; +}; + +using BatchDims = SmallVector; +using BatchDimsRef = ArrayRef; + +// A BatchedTensorImpl holds an underlying Tensor and a list of BatchDim +// NB: We use the term "BatchedTensor" to mean a Tensor that is backed with a +// BatchedTensorImpl. +// +// The batch dimensions are treated as being "private"; they are not +// user-visible. For example, in the following Tensor, +// bt = BatchedTensorImpl(ones(2, 3, 5, 7), [(lvl=1, dim=0), (lvl=2, dim=1)]) +// dimensions 0 and 1 are batch dimensions. +// +// bt.sizes() returns (5, 7); bt.sum(0) performs a reduction over the (public) +// dim 0, which is equivalent to dim 3 in the underlying ones(2, 3, 5, 7) +// tensor. +struct TORCH_API BatchedTensorImpl : public c10::TensorImpl { + explicit BatchedTensorImpl(Tensor value, BatchDims bdims); + + // Returns a reference to BatchDims that represent which dimensions of this + // tensor are private. + BatchDimsRef bdims() const { + return bdims_; + } + + // BatchedTensorImpl wraps a Tensor + const Tensor& value() const { + return value_; + } + + // Given a public dimension index, return the dimension index in the + // underlying value() tensor. For example, if we have + // bt = BatchedTensorImpl(ones(2, 3, 5, 7), [(lvl=1, dim=0), (lvl=2, + // dim=2)]) + // bt.actualDim(0) -> 1 + // bt.actualDim(1) -> 3 + // bt.actualDim(2) -> Error + int64_t actualDim(int64_t dim, bool wrap_dim = true) const; + + // We have to override this because we opted into CustomStrides + IntArrayRef strides_custom() const override; + // Override a bunch of methods inherited from TensorImpl to return error + // messages. 
+ c10::SymBool sym_is_contiguous_custom( + at::MemoryFormat memory_format) const override; + void set_size(int64_t dim, int64_t new_size) override; + void set_stride(int64_t dim, int64_t new_stride) override; + void set_storage_offset(int64_t storage_offset) override; +#ifdef DEBUG + bool has_storage() const override; +#endif + + private: + // see NOTE: [BatchedTensorImpl levels invariant] + void checkInvariants() const; + const char* tensorimpl_type_name() const override; + + Tensor value_; + + // Note: [BatchedTensorImpl levels invariant] + // There is an invariant that the BatchDims must be stored in increasing + // `level` order. That is, for i < j, bdims_[i].level must be less than + // bdims_[j].level. + BatchDims bdims_; +}; + +// NB: We use the term "BatchedTensor" to mean a Tensor that is backed with a +// BatchedTensorImpl. +inline bool isBatchedTensor(const Tensor& tensor) { + return tensor.unsafeGetTensorImpl()->key_set().has(DispatchKey::Batched); +} + +// It is unsafe to call this on a Tensor that is not backed by a +// BatchedTensorImpl. Please use `maybeGetBatchedImpl` whenever possible. +inline BatchedTensorImpl* unsafeGetBatchedImpl(const Tensor& tensor) { + return static_cast(tensor.unsafeGetTensorImpl()); +} + +inline BatchedTensorImpl* maybeGetBatchedImpl(const Tensor& tensor) { + if (!isBatchedTensor(tensor)) { + return nullptr; + } + return unsafeGetBatchedImpl(tensor); +} + +// Returns a bitset. If bit i is set, then that means dim i is a batchdim. 
+inline std::bitset createBatchDimBitset( + BatchDimsRef bdims) { + std::bitset is_bdim; + for (const auto& bdim : bdims) { + is_bdim.set(bdim.dim()); + } + return is_bdim; +} + +// Creates a bitset for all of the levels present in `bdims` +inline std::bitset createVmapLevelsBitset(BatchDimsRef bdims) { + std::bitset result; + for (const auto& bdim : bdims) { + result.set(bdim.level()); + } + return result; +} + +inline std::ostream& operator<<(std::ostream& out, const BatchDim& bdim) { + out << "(lvl=" << bdim.level() << ", dim=" << bdim.dim() << ')'; + return out; +} + +// Use this to construct a BatchedTensor from a regular Tensor +TORCH_API Tensor makeBatched(Tensor tensor, BatchDims bdims); + +// Adds a batch dim to `tensor`, returning a BatchedTensor +TORCH_API Tensor addBatchDim(Tensor tensor, int64_t level, int64_t dim); + +// Checks if an inplace operation on self and other is "vmap compatible". +// See NOTE: [vmap-incompatible in-place operations] for the definition of this. +TORCH_API bool inplaceIsVmapCompatible(const Tensor& self, const Tensor& other); + +} // namespace at + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/LegacyVmapMode.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/LegacyVmapMode.h new file mode 100644 index 0000000000000000000000000000000000000000..3f7588904df5992b9aefd14da3348d1149d06b82 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/LegacyVmapMode.h @@ -0,0 +1,31 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include + +namespace at::impl { + +// VmapMode contains a thread local count of how many nested vmaps +// we are currently inside. That number is known as the `vmap level`. 
+// VmapMode is used in the implementation of the Python `torch.vmap` API. +// +// NOTE: this is NOT the c++ api for torch.vmap. That doesn't exist yet. + +struct TORCH_API VmapMode { + // Returns the vmap level, aka the count of how many nested vmaps we're in. + static int64_t current_vmap_level(); + + // Increment the count of nested vmaps. If this causes the vmap level to be + // greater than 0, then it enables DispatchKey::VmapMode on all tensors. + static int64_t increment_nesting(); + + // Decrements the count of nested vmaps. If this causes the vmap level to be + // equal to 0, then it disables DispatchKey::VmapMode on all tensors. + static int64_t decrement_nesting(); +}; + +} // namespace at::impl + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/LegacyVmapTransforms.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/LegacyVmapTransforms.h new file mode 100644 index 0000000000000000000000000000000000000000..0b301d8c8b2f0fe1b7e067b984dec052b0b7b097 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/LegacyVmapTransforms.h @@ -0,0 +1,188 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include + +namespace at { + +// This file contains abstractions used for transforming *logical* vmap +// arguments into *physical* arguments. (Keep reading for definitions of these +// terms). + +// NOTE: [Logical vs physical args] +// Consider the following vmap. 
+// vmap(vmap(func, in_dims=(2,)), in_dims=(0,))(torch.ones(2, 3, 4)) +// This would produce a BatchedTensor wrapping a Tensor of size [2, 3, 4], +// with batch dims 0 and 2: +// BatchedTensor(ones(2, 3, 4), bdims=[(lvl=1,dim=0),(lvl=2,dim=2)]) +// +// We say the *logical* view of the tensor has size [3] -- tensors inside +// `func` appear to have size [3]. +// However, the *physical* underlying tensor (the one passed to vmap) has size +// [2, 3, 4]. +// +// This notion of logical vs physical also extends to non-tensor arguments. +// Consider the previous tensor; let's assume the user called +// `torch.sum(tensor, dim=0)` inside of `func`. Then the logical +// dimension they are reducing over is dim 0 but the physical dim is dim 1 +// (the first non-batch dimension) + +// Forward declared; see NOTE: [What is a VmapPhysicalView?] +struct VmapPhysicalView; + +// Most PyTorch operators take 4 or fewer inputs. +constexpr int64_t kVmapTransformStaticInputSize = 4; +using VmapPhysicalViewVec = + SmallVector; + +// Pytorch generally advertises good performance for <= 5 dims. +// (see ATen/core/DimVector.h). We add a few extra dims (~3) for vmap +// dimensions to get 8. Adjust this number as necessary +constexpr int64_t kVmapStaticDimVecSize = 8; +using VmapDimVector = SmallVector; +using VmapSymDimVector = SmallVector; + +// NOTE: [What is an VmapTransform?] +// An *VmapTransform* converts logical views of tensors to physical views. +// +// Batching rules use VmapTransforms to convert logical arguments to +// physical arguments, then call one or more at:: operator that handles the +// physical arguments, and then converts the physical result back to a logical +// argument. + +// VmapTransform for operators that take tensors with multiple batch dims. 
+// Given one or more logical views on Tensors, `logicalToPhysical` +// permutes all of the batch dims to the front of the tensor, aligns +// and expands the batch dims to match each other (according to their `level`), +// and returns a VmapPhysicalView on the tensor(s). +struct TORCH_API MultiBatchVmapTransform { + static VmapPhysicalView logicalToPhysical(const Tensor& logical_tensor); + static VmapPhysicalViewVec logicalToPhysical(ITensorListRef logical_tensors); +}; + +// VmapTransform for operators that broadcast all inputs. +// Given some logical views on Tensors, `logicalToPhysical`: +// - permutes all of the batch dims to the front of the tensors +// - aligns all the batch dims to the collective levels of all of the tensors. +// If a tensor does not have a batch dim for a vmap level, then it receives +// a size-one dimension for said level. +// - aligns the non-batch dims to have the same dimensionality, adding extra +// size-1 dimensions in between the batch dimensions and the non-batch +// dimensions so that the batch dimensions are lined up from the right. +// +// For example: given inputs of size (B, 2) and (B, 3, 2) where B is the batch +// dimension, BroadcastingVmapTransform returns VmapPhysicalViews that wrap +// tensors of size (B, 1, 2) and (B, 3, 2). +// +// Given inputs of size (B, 2) and (2,), BroadcastingVmapTransform returns +// VmapPhysicalViews wrapping tensors of size (B, 2) and (1, 2). We don't +// actually *need* to return a tensor of size (1, 2) for the second tensor +// because the broadcasting operation takes care of that for us, but we do +// it anyways to keep things simple. +struct TORCH_API BroadcastingVmapTransform { + static VmapPhysicalViewVec logicalToPhysical(TensorList logical_tensors); +}; + +// Forward declared, if you're reading this file head to toe, don't worry about +// it yet. +struct VmapPhysicalToLogicalMap; + +// NOTE: [What is a VmapPhysicalView?] +// VmapPhysicalView represents a physical view on a Tensor. 
+// +// One can use it to further convert logical dimension indices, logical shapes, +// and more to their physical variants, or convert a new (physical) tensor into +// a logical BatchedTensor. (TODO(rzou): some of these are not yet implemented). +// +// VmapPhysicalView stores a physical tensor with all of its batch dimensions at +// the front and some levels that correspond to said batch dimensions. +// +// The levels bitset specifies which vmap levels correspond to the batch +// dimensions at the front of the tensor. In particular, the number of set bits +// corresponds to the number of batch dimensions on `tensor` and the rightmost +// bit of `levels` specifies the maximum number of nested vmaps we are in at +// this point in time. +// For example, given: +// physical_view = VmapPhysicalView(tensor=ones(2, 3, 4, 5, 6), levels={1, 3}) +// +// Rightmost bit of `levels` is 3 indicating the number of nested vmaps less +// than or equal to 3. +// bitset: 010100 +// ^ +// | +// levels: 012345 +struct TORCH_API VmapPhysicalView { + VmapPhysicalView(Tensor&& tensor, std::bitset levels) + : levels_(levels), tensor_(std::move(tensor)) { + TORCH_INTERNAL_ASSERT(!isBatchedTensor(tensor_)); + } + + Tensor& tensor() { + return tensor_; + } + const Tensor& tensor() const { + return tensor_; + } + + // Maps logical dim indices to physical dim indices. Also does dim wrapping. + // + // For example, given: + // physical_view = VmapPhysicalView(tensor=ones(2, 3, 4, 5), levels={1, 3}) + // + // Then physical_view.getPhysicalDims({0, 1}) returns {2, 3}. + // This is because the size of levels tell us that the first two dimensions + // of `tensor_` are batch dimensions, so a logical dim of `n` is actually + // a physical dim of `n + 2`. + VmapDimVector getPhysicalDims(OptionalIntArrayRef logical_dims) const; + int64_t getPhysicalDim(int64_t logical_dim) const; + + // Returns a VmapPhysicalToLogicalMap object. 
This can be used for + // mapping a physical tensor to a new logical tensor (BatchedTensor) + VmapPhysicalToLogicalMap getPhysicalToLogicalMap() const; + + // Maps a logical shape to a physical shape by prepending the batch + // sizes to the logical shape. + VmapDimVector getPhysicalShape(IntArrayRef logical_shape) const; + + int64_t numBatchDims() const; + + private: + int64_t numLogicalDims() const; + + std::bitset levels_; + Tensor tensor_; +}; + +// Convenience struct used for mapping a physical tensor (a non-BatchedTensor) +// to a logical one (BatchedTensor). It holds some levels that are used to do +// the mapping and assumes that the batch dimensions in the physical tensor all +// occur at the front of the tensor. +struct TORCH_API VmapPhysicalToLogicalMap { + VmapPhysicalToLogicalMap(std::bitset levels) + : levels_(levels) {} + + // Maps a physical tensor to a new logical tensor (BatchedTensor). + // Assumes that all of the "batch dimensions" are at the front + // of the physical tensor. For example, given: + // - x = rank-4 Tensor with size 2, 3, 5, 7 + // - levels = (2, 4) + // Returns: + // - BatchedTensor(x, bdims=[(dim=0,lvl=2), (dim=1, lvl=4)]) + Tensor apply(const Tensor& physical_tensor) const; + + // Given a vector of physical tensors, + // 1. maps each tensor to a new logical tensor. Assumes that all of the + // "batch dimensions" are at the front of the physical tensors. + // 2. stores the new logical tensors back into the passed-in vector. This is + // to avoid additional dynamic allocations. + void applyInplace(std::vector& physical_tensors) const; + + std::bitset levels_; +}; + +} // namespace at + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." 
+#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/LinalgBackend.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/LinalgBackend.h new file mode 100644 index 0000000000000000000000000000000000000000..87acc1e26194926f5da71211ab12c1c9767cef82 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/LinalgBackend.h @@ -0,0 +1,36 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include + +#include +#include + +namespace at { + +enum class LinalgBackend : int8_t { Default, Cusolver, Magma }; + +inline std::string LinalgBackendToString(at::LinalgBackend backend) { + switch (backend) { + case LinalgBackend::Default: + return "at::LinalgBackend::Default"; + case LinalgBackend::Cusolver: + return "at::LinalgBackend::Cusolver"; + case LinalgBackend::Magma: + return "at::LinalgBackend::Magma"; + default: + TORCH_CHECK(false, "Unknown linalg backend"); + } +} + +inline std::ostream& operator<<( + std::ostream& stream, + at::LinalgBackend backend) { + return stream << LinalgBackendToString(backend); +} + +} // namespace at + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." 
+#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/MapAllocator.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/MapAllocator.h new file mode 100644 index 0000000000000000000000000000000000000000..c603a6a33fcc68864ce8c3ab2bbc97ec741071be --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/MapAllocator.h @@ -0,0 +1,152 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include + +namespace at { + +enum MappedAllocatorModes { + ALLOCATOR_MAPPED_SHARED = 1, + ALLOCATOR_MAPPED_SHAREDMEM = 2, + ALLOCATOR_MAPPED_EXCLUSIVE = 4, + ALLOCATOR_MAPPED_NOCREATE = 8, + ALLOCATOR_MAPPED_KEEPFD = 16, + ALLOCATOR_MAPPED_FROMFD = 32, + ALLOCATOR_MAPPED_UNLINK = 64 +}; + +// Sentinel value/type to help distinguish the file descriptor constructor from +// the non-file descriptor constructor +enum WithFd { WITH_FD }; + +TORCH_API std::string NewProcessWideShmHandle(); + +class TORCH_API MapAllocator { + public: + MapAllocator(std::string_view filename, int flags, size_t size); + MapAllocator( + WithFd /*unused*/, + std::string_view filename, + int fd, + int flags, + size_t size); + MapAllocator(const MapAllocator&) = delete; + MapAllocator& operator=(const MapAllocator&) = delete; + MapAllocator(MapAllocator&&) = delete; + MapAllocator& operator=(MapAllocator&&) = delete; + + const char* filename() const { + return filename_.c_str(); + } + int fd() const { +#ifdef _WIN32 + TORCH_CHECK(false, "MapAllocator::fd() is unsupported on Windows"); +#else + return fd_; +#endif + } + ptrdiff_t size() const { + return size_; + } + // Return a pointer to the actual data for this allocator + // (in the case of the refcounted allocator, this is offset + // from the base pointer.) 
+ virtual void* data() const { + return base_ptr_; + } + + int flags() const { + return flags_; + } + + static MapAllocator* fromDataPtr(const at::DataPtr& /*dptr*/); + static at::DataPtr makeDataPtr( + std::string_view filename, + int flags, + size_t size, + size_t* actual_size_out); + static at::DataPtr makeDataPtr( + WithFd /*unused*/, + const char* filename, + int fd, + int flags, + size_t size, + size_t* actual_size_out); + + // Closes the data. Helps us avoid destructor shenanigans + virtual void close(); + + // This is very dangerous. You have to redefine this destructor for each + // subclass + virtual ~MapAllocator(); + + protected: + bool closed_ = false; + std::string filename_; + int flags_ = 0; + ptrdiff_t size_; /* mapped size */ +#ifdef _WIN32 + void* handle_; + void* event_; + std::string eventname_; +#else + int fd_ = -1; +#endif + void* base_ptr_ = nullptr; +}; + +// Base-from-member idiom +struct TORCH_API RefcountedMapAllocatorArgCheck { + RefcountedMapAllocatorArgCheck(int flags); +}; + +class TORCH_API RefcountedMapAllocator : private RefcountedMapAllocatorArgCheck, + public MapAllocator { + public: + RefcountedMapAllocator(const char* filename, int flags, size_t size); + RefcountedMapAllocator( + WithFd /*unused*/, + const char* filename, + int fd, + int flags, + size_t size); + + static RefcountedMapAllocator* fromDataPtr(const at::DataPtr& /*dptr*/); + RefcountedMapAllocator(const RefcountedMapAllocator&) = delete; + RefcountedMapAllocator(RefcountedMapAllocator&&) = delete; + RefcountedMapAllocator& operator=(const RefcountedMapAllocator&) = delete; + RefcountedMapAllocator& operator=(RefcountedMapAllocator&&) = delete; + static at::DataPtr makeDataPtr( + const char* filename, + int flags, + size_t size, + size_t* actual_size_out); + static at::DataPtr makeDataPtr( + WithFd /*unused*/, + const char* filename, + int fd, + int flags, + size_t size, + size_t* actual_size_out); + + void* data() const override; + + void incref(); + int 
decref(); + void close() override; + + ~RefcountedMapAllocator() override { + RefcountedMapAllocator::close(); + } + + protected: + void checkFlags(); + void initializeAlloc(); +}; + +} // namespace at + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/MatrixRef.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/MatrixRef.h new file mode 100644 index 0000000000000000000000000000000000000000..c0f63cc2d4ee11abdc4ebcb50f2765d6525d5953 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/MatrixRef.h @@ -0,0 +1,114 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +#include +#include + +namespace at { +/// MatrixRef - Like an ArrayRef, but with an extra recorded strides so that +/// we can easily view it as a multidimensional array. +/// +/// Like ArrayRef, this class does not own the underlying data, it is expected +/// to be used in situations where the data resides in some other buffer. +/// +/// This is intended to be trivially copyable, so it should be passed by +/// value. +/// +/// For now, 2D only (so the copies are actually cheap, without having +/// to write a SmallVector class) and contiguous only (so we can +/// return non-strided ArrayRef on index). +/// +/// P.S. dimension 0 indexes rows, dimension 1 indexes columns +template +class MatrixRef { + public: + typedef size_t size_type; + + private: + /// Underlying ArrayRef + ArrayRef arr; + + /// Stride of dim 0 (outer dimension) + size_type stride0; + + // Stride of dim 1 is assumed to be 1 + + public: + /// Construct an empty Matrixref. + /*implicit*/ MatrixRef() : arr(nullptr), stride0(0) {} + + /// Construct an MatrixRef from an ArrayRef and outer stride. 
+ /*implicit*/ MatrixRef(ArrayRef arr, size_type stride0) + : arr(arr), stride0(stride0) { + TORCH_CHECK( + arr.size() % stride0 == 0, + "MatrixRef: ArrayRef size ", + arr.size(), + " not divisible by stride ", + stride0) + } + + /// @} + /// @name Simple Operations + /// @{ + + /// empty - Check if the matrix is empty. + bool empty() const { + return arr.empty(); + } + + const T* data() const { + return arr.data(); + } + + /// size - Get size a dimension + size_t size(size_t dim) const { + if (dim == 0) { + return arr.size() / stride0; + } else if (dim == 1) { + return stride0; + } else { + TORCH_CHECK( + 0, "MatrixRef: out of bounds dimension ", dim, "; expected 0 or 1"); + } + } + + size_t numel() const { + return arr.size(); + } + + /// equals - Check for element-wise equality. + bool equals(MatrixRef RHS) const { + return stride0 == RHS.stride0 && arr.equals(RHS.arr); + } + + /// @} + /// @name Operator Overloads + /// @{ + ArrayRef operator[](size_t Index) const { + return arr.slice(Index * stride0, stride0); + } + + /// Disallow accidental assignment from a temporary. + /// + /// The declaration here is extra complicated so that "arrayRef = {}" + /// continues to select the move assignment operator. + template + // NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward) + std::enable_if_t, MatrixRef>& operator=( + // NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward) + U&& Temporary) = delete; + + /// Disallow accidental assignment from a temporary. + /// + /// The declaration here is extra complicated so that "arrayRef = {}" + /// continues to select the move assignment operator. + template + std::enable_if_t, MatrixRef>& operator=( + std::initializer_list) = delete; +}; + +} // end namespace at + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." 
+#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/MemoryOverlap.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/MemoryOverlap.h new file mode 100644 index 0000000000000000000000000000000000000000..e090c8091d03a5e45c9991fa4a532be11784ea29 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/MemoryOverlap.h @@ -0,0 +1,47 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include + +namespace c10 { +struct TensorImpl; +} + +namespace at { +class TensorBase; + +// MemOverlap: Whether or not there is memory overlap +// +// No: Absolutely no memory overlap +// Yes: Absolutely yes memory overlap +// TooHard: There might be memory overlap, but it was too expensive to compute. +// +// NB: Please update the python test for these if you renumber them. +enum class MemOverlap { No, Yes, TooHard }; + +enum class MemOverlapStatus { Full, Partial, No, TooHard }; + +TORCH_API MemOverlap has_internal_overlap(const TensorBase& t); +TORCH_API MemOverlap has_internal_overlap(c10::TensorImpl* t); + +TORCH_API void assert_no_internal_overlap(const TensorBase& t); +TORCH_API void assert_no_internal_overlap(c10::TensorImpl* t); + +TORCH_API MemOverlapStatus +get_overlap_status(const TensorBase& a, const TensorBase& b); +TORCH_API MemOverlapStatus +get_overlap_status(const c10::TensorImpl* a, const c10::TensorImpl* b); + +TORCH_API void assert_no_partial_overlap( + const TensorBase& a, + const TensorBase& b); +void assert_no_partial_overlap(c10::TensorImpl* a, c10::TensorImpl* b); + +TORCH_API void assert_no_overlap(const TensorBase& a, const TensorBase& b); +TORCH_API void assert_no_overlap(c10::TensorImpl* a, c10::TensorImpl* b); + +} // namespace at + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is 
defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/MetaFunctions.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/MetaFunctions.h new file mode 100644 index 0000000000000000000000000000000000000000..ecdf355908f3c5f4f0d74c43f0b9a401561d2689 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/MetaFunctions.h @@ -0,0 +1,34 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#include + +// TODO Undo all logic introduced for Note [Avoiding Include Cycles In Static Dispatch] +// Code introduced to avoid cyclic dependency in static dispatch is no longer +// needed as static dispatch logic is moved from TensorBody.h, which caused cycles in the first place, +// to Operators.cpp for supporting multiple backends with multiple kernels. +// +// Note [Avoiding Include Cycles In Static Dispatch] +// In order to avoid #include cycles in the static dispatch build, we've carefully split out +// the static function definition files into {DispatchKey}Functions.h and {DispatchKey}Functions_inl.h. +// +// Without this split, the include cycle looks like TensorBody.h -> CPUFunctions.h -> TensorBody.h. +// - TensorBody.h #includes CPUFunctions.h in the static dispatch build, because the tensor methods +// all need to call into the fastpath C++ API defined in CPUFunctions.h. The methods are also all +// directly inlined into TensorBody.h. +// - CPUFunctions.h #includes TensorBody.h because it contains function declarations for the entire C++ API, +// which include functions that have defaultable std::optional arguments. +// That requires knowing the full Tensor class definition. 
+// +// We break the cycle by doing the following: +// - Split out CPUFunction.h into two files: CPUFunctions.h and CPUFunctions_inl.h +// - CPUFunction.h is a dummy file that just includes the Tensor class and includes CPUFunctions_inl., +// - CPUFunctions_inl.h includes everything else +// - (only in the static dispatch build) TensorBody.h makes sure to finish defining the Tensor class, +// and then it includes CPUFunctions_inl.h. +// - All other files that want the cpu fastpath functions can include CPUFunctions.h directly. +// - This also means that static dispatch build, CPUFunctions.h only needs to +// #include TensorBody.h, and it will automatically bring in CPUFunctions_inl.h. +#include + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/MetaFunctions_inl.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/MetaFunctions_inl.h new file mode 100644 index 0000000000000000000000000000000000000000..979a82b2c1c3179184a30faa3e1e2844cb59d658 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/MetaFunctions_inl.h @@ -0,0 +1,332 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunctions_inl.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +#if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS) +#error This change adds a dependency on all pytorch operators, meaning the \ + file will need to be re-compiled every time an operator is changed or added. \ + Consider including a specific operator from \ + . 
\ + See NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS]. +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include 
+#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." 
+#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/MethodOperators.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/MethodOperators.h new file mode 100644 index 0000000000000000000000000000000000000000..fd1f397d49d356616dabaf0eac470581be450aae --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/MethodOperators.h @@ -0,0 +1,449 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +// @generated by torchgen/gen.py from MethodOperators.h + +#ifdef TORCH_ASSERT_NO_OPERATORS +#error This change adds a dependency on native_functions.yaml, \ + meaning the file will need to be re-compiled every time an operator \ + is changed or added. Consider if your change would be better placed in \ + another file, or if a more specific header might achieve the same goal. \ + See NOTE: [Tensor vs. TensorBase] +#endif + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. 
+#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include 
+#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include 
+#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace _ops { + +} // namespace _ops +} // namespace at + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/NamedTensor.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/NamedTensor.h new file mode 100644 index 0000000000000000000000000000000000000000..c558768f703970bfb5c4930e3a0522b323589b53 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/NamedTensor.h @@ -0,0 +1,6 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#include + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." 
+#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/NamedTensorUtils.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/NamedTensorUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..743c1827ac9c35fe4a15760d6fd4858b6ebda169 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/NamedTensorUtils.h @@ -0,0 +1,217 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +#include +#include +#include + +#include +#include + +namespace at { + +using NameVector = SmallVector; + +inline bool has_names(const ITensorListRef& tensors) { + return std::any_of(tensors.begin(), tensors.end(), [](const Tensor& t) { + return t.has_names(); + }); +} + +// Converts dim to an positional index. Errors if `dim` cannot be used to +// refer to any dimension of tensor. +TORCH_API int64_t dimname_to_position(const Tensor& tensor, Dimname dim); +TORCH_API std::vector dimnames_to_positions( + const Tensor& tensor, + DimnameList dims); + +// Unifies two DimnameList to produce a third. This is useful for implementing +// the named inference rule for binary broadcasting operations like add. +// +// There are three main constraints: +// 1) Check matching: Names must match positionally from the right. +// 2) Check misaligned: If a name `n` is in `names`, then it must appear at +// the same index from the right in other. +// 3) The output names are obtained by unifying the names individually from the +// right. +TORCH_API std::vector unify_from_right( + DimnameList names, + DimnameList other, + const char* action = "broadcast"); + +[[noreturn]] inline void reportNYIDimnameOverload(const char* op_name) { + TORCH_CHECK( + false, + op_name, + ": You passed a dimname (string) to this op in place of a dimension " + "index but it does not yet support this behavior. 
Please pass a dimension " + "index to work around this."); +} + +// [NOTE] Writing name inference rules +// +// Operators that support named tensors are either composed of operations that +// support named tensors or implement some name inference rule. An op that +// implements its own name inference rule generally looks like the following: +// +// Tensor op(...) { +// perform_shape_checks(...); +// # (1) +// auto maybe_outnames = compute_outnames(...); +// auto result = [&]() { +// NoNamesGuard guard; +// return op_impl(...); +// }(); +// # (2) +// propagate_names_if_nonempty(result, maybe_outnames); +// +// Each op has (1) a compute outnames step and (2) a propagate names step. +// +// compute_outnames is responsible for checking that input names match and +// determining what the output names should be. It returns either: +// - {} (if the inputs tensors are all unnamed) +// - non-empty outnames. +// +// propagate_names_if_nonempty propagates the outnames if they exist to the +// result tensors. +// +// The {} case is an optimization; if the user does not use named tensors they +// pay no perf cost for it. + +namespace namedinference { + +const Tensor& propagate_names_if_present_and_nonempty( + const Tensor& result, + std::optional maybe_names, + bool validate_names = false); +// Propagates `names` to `result` if `names` is not empty. +// `names` can be empty; see [NOTE] Writing name inference rules +// If `names` is not empty, `names.size()` should equal `result.dim()`. +// When in doubt, use this overload instead of the others. +TORCH_API const Tensor& propagate_names_if_nonempty( + const Tensor& result, + DimnameList maybe_names, + bool validate_names = false); + +// Propagates `names` to `result`. Only use this if we are certain that there +// are names to propagate (that names is not empty). +TORCH_API const Tensor& propagate_names( + const Tensor& result, + DimnameList names, + bool validate_names = false); + +// Propagates all names from src to result. 
+TORCH_API void propagate_names(const Tensor& result, const Tensor& src); + +// Propagates all names except for those at the excluded_idxs. +TORCH_API void propagate_names_except( + const Tensor& result, + const Tensor& src, + IntArrayRef excluded_idxs); + +// Used for reduction ops that have a `keepdim` arg. +TORCH_API void propagate_names_for_reduction( + const Tensor& result, + const Tensor& src, + IntArrayRef excluded_idxs, + bool keepdim); + +TORCH_API void propagate_names_for_expand( + const Tensor& result, + const Tensor& self); + +TORCH_API std::vector compute_cat_outnames( + const MaterializedITensorListRef& tensors); + +TORCH_API std::vector compute_broadcast_outnames( + const Tensor& self, + const Tensor& other); + +TORCH_API std::vector broadcast_to_outnames( + const Tensor& tensor, + const Tensor& reference_tensor, + const char* op_name); + +TORCH_API std::vector compute_matmul_outnames( + const Tensor& self, + const Tensor& other); + +TORCH_API std::vector compute_cdist_outnames( + const Tensor& self, + const Tensor& other); + +TORCH_API std::vector compute_bmm_outnames( + const Tensor& result, + const Tensor& self, + const Tensor& other); + +TORCH_API std::vector compute_squeeze_outnames(const Tensor& tensor); +TORCH_API std::vector compute_squeeze_outnames( + const Tensor& tensor, + std::bitset dims); + +std::vector compute_diagonal_outnames( + const Tensor& tensor, + int64_t dim1, + int64_t dim2); + +// TensorImpl* overloads for Legacy TH/THC code. Use these sparingly. 
+ +TORCH_API TensorImpl* propagate_names_if_nonempty( + TensorImpl* result, + DimnameList maybe_names, + bool validate_names = false); + +TORCH_API TensorImpl* propagate_names( + TensorImpl* result, + DimnameList names, + bool validate_names = false); + +TORCH_API void propagate_names(TensorImpl* result, /*const */ TensorImpl* src); + +inline void propagate_names( + const TensorBase& result, + DimnameList names, + bool validate_names = false) { + propagate_names(result.unsafeGetTensorImpl(), names, validate_names); +} + +inline void propagate_names_if_nonempty( + const TensorBase& result, + DimnameList names, + bool validate_names = false) { + propagate_names_if_nonempty( + result.unsafeGetTensorImpl(), names, validate_names); +} + +inline void propagate_names(const TensorBase& result, const TensorBase& src) { + propagate_names(result.unsafeGetTensorImpl(), src.unsafeGetTensorImpl()); +} + +// result = m1 @ m2 + bias +TORCH_API std::vector propagate_names_for_addmm( + const Tensor& m1, + const Tensor& m2, + const Tensor& bias); + +TORCH_API std::vector propagate_names_for_addmv( + const Tensor& mat, + const Tensor& vec, + const Tensor& bias); + +TORCH_API void check_names_for_dot(TensorImpl* vec1, TensorImpl* vec2); + +TORCH_API std::vector compute_baddbmm_outnames( + const Tensor& result, + const Tensor& self, + const Tensor& other, + const Tensor& bias); + +TORCH_API bool are_names_equal(TensorImpl* self, TensorImpl* other); + +} // namespace namedinference + +} // namespace at + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." 
+#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/NativeFunctions.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/NativeFunctions.h new file mode 100644 index 0000000000000000000000000000000000000000..25657711150f36d74c9d8d695ed69b54f943831a --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/NativeFunctions.h @@ -0,0 +1,1366 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +// @generated by torchgen/gen.py from NativeFunctions.h + +#ifdef TORCH_ASSERT_NO_OPERATORS +#error This change adds a dependency on native_functions.yaml, \ + meaning the file will need to be re-compiled every time an operator \ + is changed or added. Consider if your change would be better placed in \ + another file, or if a more specific header might achieve the same goal. \ + See NOTE: [Tensor vs. TensorBase] +#endif + +#if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS) +#error This change adds a dependency on all pytorch operators, meaning the \ + file will need to be re-compiled every time an operator is changed or added. \ + Consider including a specific operator from \ + and see NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS]. 
+#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include 
+#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include 
+#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include 
+#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include 
+#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include 
+#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include 
+#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." 
+#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/NativeMetaFunctions.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/NativeMetaFunctions.h new file mode 100644 index 0000000000000000000000000000000000000000..493b85c86d88b12a34d4b1a89998d8cc19fe69b3 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/NativeMetaFunctions.h @@ -0,0 +1,1352 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +// @generated by torchgen/gen.py from NativeMetaFunctions.h + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include 
+#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include 
+#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include 
+#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include 
+#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include 
+#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include 
+#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include 
+#include +#include + +namespace at { + +namespace meta { + + + +} // namespace meta +} // namespace at + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/NestedTensorImpl.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/NestedTensorImpl.h new file mode 100644 index 0000000000000000000000000000000000000000..3555257bccbdb7b8b730f2b78e300ffecf7dc9c2 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/NestedTensorImpl.h @@ -0,0 +1,292 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at::native { +struct NestedTensorImpl; +inline bool nested_tensor_impl_is_contiguous(const NestedTensorImpl* nt); +int64_t get_numel_from_nested_size_tensor(const at::Tensor& tensor); +at::Tensor construct_nested_strides(const at::Tensor& nested_size); +at::Tensor construct_offsets(const at::Tensor& nested_size); + +struct TORCH_API NestedTensorImpl : public c10::TensorImpl { + explicit NestedTensorImpl( + Storage storage, + c10::DispatchKeySet key_set, + const caffe2::TypeMeta data_type, + at::Tensor nested_sizes, + at::Tensor nested_strides, + at::Tensor storage_offsets); + + explicit NestedTensorImpl( + const at::Tensor& buffer, + at::Tensor nested_sizes, + at::Tensor nested_strides, + at::Tensor storage_offsets); + // assume contiguous, `nested_strides` and `offsets` + // can be inferred from `nested_sizes` + explicit NestedTensorImpl( + const at::Tensor& buffer, + const at::Tensor& nested_sizes); + + // This constructor is used creating view tensors from nested tensors + explicit NestedTensorImpl( + c10::TensorImpl::ImplType impl_type, 
+ const at::Tensor& base_tensor, + at::Tensor nested_sizes, + at::Tensor nested_strides, + at::Tensor storage_offsets); + + // TODO: don't expose private implementation details like this; in + // particular, resizing this tensor will mess up our dim() and + // callers cannot fix it. + const Tensor& get_nested_sizes() const { + return nested_sizes_; + } + // TODO: don't expose private implementation details like this + const Tensor& get_nested_strides() const { + return nested_strides_; + } + const Tensor& get_storage_offsets() const { + return storage_offsets_; + } + // Returns nullopt if the ith dimension is irregular. The ith dimension + // of a NestedTensor is regular if the unbound tensors match in + // size at the (i-1)th dimension. + std::optional opt_size(int64_t d) const; + + int64_t size(int64_t d) const { + std::optional optional_size = this->opt_size(d); + TORCH_CHECK( + optional_size.has_value(), + "Given dimension ", + d, + " is irregular and does not have a size."); + return *optional_size; + } + /** + * Return a view of the nested tensor as a 1 dimensional contiguous tensor. + * + * The buffer tensor created by this function shares the same storage_impl as + * the original nested tensor, and therefore can be seen as a view. + * + * @return A newly constructed view tensor + */ + at::Tensor get_buffer() const { + TORCH_CHECK( + nested_tensor_impl_is_contiguous(this), + "NestedTensor must be contiguous to get buffer."); + return get_unsafe_storage_as_tensor(); + } + /** + * If possible use get_buffer() instead. This function returns the storage + * as a tensor directly, which is not safe to use in general. If using this + * function, The caller must ensure to account for nested_sizes, + * nested_strides and storage_offsets. 
+ * + * @return A newly constructed view tensor + */ + at::Tensor get_unsafe_storage_as_tensor() const { + auto buffer_key_set_ = generate_buffer_key_set(); + const auto buffer_size = get_buffer_size(); + auto buffer_tensor_impl = c10::make_intrusive( + c10::TensorImpl::VIEW, Storage(storage_), buffer_key_set_, data_type_); + buffer_tensor_impl->set_sizes_contiguous( + c10::makeArrayRef(static_cast(buffer_size))); + return Tensor(buffer_tensor_impl); + } + + size_t get_buffer_size() const { + return storage_.nbytes() / data_type_.itemsize(); + } + + protected: + const char* tensorimpl_type_name() const override; + + // TODO: numel_custom and is_contiguous_custom can be profitably overridden + // with real implementations + int64_t numel_custom() const override; + c10::SymInt sym_numel_custom() const override; + c10::SymBool sym_is_contiguous_custom( + MemoryFormat /*memory_format*/) const override; + int64_t size_custom(int64_t d) const override { + return this->size(d); + } + c10::SymInt sym_size_custom(int64_t d) const override { + return c10::SymInt{this->size(d)}; + } + IntArrayRef sizes_custom() const override; + c10::SymIntArrayRef sym_sizes_custom() const override; + IntArrayRef strides_custom() const override; + c10::SymIntArrayRef sym_strides_custom() const override; + + // this one is real + int64_t dim_custom() const override; + + c10::intrusive_ptr shallow_copy_and_detach( + const c10::VariableVersion& version_counter, + bool allow_tensor_metadata_change) const override; + + c10::intrusive_ptr shallow_copy_and_detach( + c10::VariableVersion&& version_counter, + bool allow_tensor_metadata_change) const override; + + void shallow_copy_from(const c10::intrusive_ptr& impl) override { + copy_tensor_metadata( + /*src_impl=*/impl.get(), + /*dest_impl=*/this, + /*version_counter=*/version_counter(), + /*allow_tensor_metadata_change=*/allow_tensor_metadata_change()); + } + + private: + // Must be called after any changes to our dim() to sync the state + // to 
TensorImpl. + void refresh_dim(); + + // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) + const at::Tensor nested_sizes_, nested_strides_; + // The starting positions of the underlying tensors in contiguous buffer + // i.e. the buffer memory offsets to get the underlying tensors + // The reason to keep this metadata is that, without strong enough constraint + // it cannot be derived from `nested_sizes_` + // and `nested_strides_`: + // 1. when buffer has blanks, e.g. [tensor1, blank, tensor2] + // this can happen e.g. after slicing a nested tensor + // 2. when multiple tensors share a same memory + // 3. when the nesting ordering is changed, e.g. [tensor1, tensor3, tensor2] + // Some strong enough constraints are: + // 1. every underlying tensor is contiguous in memory + // && nesting in ascending order + // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) + const at::Tensor storage_offsets_; + // NOTE: -1 here means the size is missing + // Optional to allow it to be computed lazily from nested. + // TODO: maybe we can remove this metadata since + // we can compute it from `nested_sizes_` + mutable std::optional> opt_sizes_; + + template + c10::intrusive_ptr shallow_copy_and_detach_core( + VariableVersion&& version_counter, + bool allow_tensor_metadata_change) const; + + /** + * Generates a non-nested key_set from a nested tensor. 
+ * + * For many nested tensor kernel implementations a buffer tensor + * is generated and redispatched to a non-nested kernel this function + * generates the key set used by that buffer tensor + * + * @return Appropriate key set for non-nested tensor + */ + inline c10::DispatchKeySet generate_buffer_key_set() const { + auto buffer_key_set = this->key_set(); + const bool Autograd = buffer_key_set.has_any(c10::autograd_dispatch_keyset); + // Remove nested tensor specific keys + buffer_key_set = buffer_key_set - + c10::DispatchKeySet{ + c10::DispatchKey::NestedTensor, + c10::DispatchKey::AutogradNestedTensor}; + + // Add dense tensor specific keys + buffer_key_set = + buffer_key_set | c10::DispatchKeySet{c10::DispatchKey::Dense}; + buffer_key_set = Autograd + ? c10::DispatchKeySet{c10::DispatchKey::Autograd} | buffer_key_set + : buffer_key_set; + + return buffer_key_set; + } +}; + +inline NestedTensorImpl* get_nested_tensor_impl_or_null( + const at::Tensor& tensor) { + if (tensor.is_nested()) { + return static_cast(tensor.unsafeGetTensorImpl()); + } + return nullptr; +} + +inline NestedTensorImpl* get_nested_tensor_impl(const at::Tensor& tensor) { + TORCH_CHECK( + tensor.is_nested(), "get_nested_tensor_impl requires a NestedTensor."); + return static_cast(tensor.unsafeGetTensorImpl()); +} + +inline bool nested_tensor_impl_is_contiguous(const NestedTensorImpl* nt) { + int64_t ntensors = nt->size(0); + if (ntensors == 0) { + return true; + } + const Tensor &sizemat = nt->get_nested_sizes(), + &stridemat = nt->get_nested_strides(); + const int64_t* offsets_ptr = + nt->get_storage_offsets().const_data_ptr(); + int64_t orig_dim = sizemat.size(1); + // nesting scalars + if (orig_dim == 0) { + // each scalar must be contiguous + // if there is blank memory between underlying scalars + for (int64_t i = 0; i < ntensors; i++) { + if (offsets_ptr[i] != i) { + return false; + } + } + } + // nesting tensors + else { + // if any underlying tensor is non-contiguous + const int64_t 
*sizemat_ptr = sizemat.const_data_ptr(), + *stridemat_ptr = stridemat.const_data_ptr(); + for (int64_t i = 0; i < ntensors; i++) { + if (stridemat_ptr[orig_dim - 1] != 1) { + return false; + } + int64_t product = sizemat_ptr[orig_dim - 1]; + for (int64_t j = orig_dim - 2; j >= 0; j--) { + if (stridemat_ptr[j] != product) { + return false; + } + product *= sizemat_ptr[j]; + } + sizemat_ptr += orig_dim; + stridemat_ptr += orig_dim; + } + // if there is blank memory between underlying tensors + if (offsets_ptr[0] != 0) { + return false; + } + sizemat_ptr = sizemat.const_data_ptr(); + stridemat_ptr = stridemat.const_data_ptr(); + for (int64_t i = 1; i < ntensors; i++) { + if (offsets_ptr[i] != + offsets_ptr[i - 1] + *sizemat_ptr * *stridemat_ptr) { + return false; + } + sizemat_ptr += orig_dim; + stridemat_ptr += orig_dim; + } + } + // everything is fine + return true; +} + +inline const at::Tensor& get_nested_sizes(const at::Tensor& tensor) { + return get_nested_tensor_impl(tensor)->get_nested_sizes(); +} + +} // namespace at::native + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." 
+#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/NumericUtils.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/NumericUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..452b4ef17e1a40ad2dc121f957478a24a878222b --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/NumericUtils.h @@ -0,0 +1,208 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#ifdef __HIPCC__ +#include +#endif + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace at { + +// std::isnan isn't performant to use on integral types; it will +// (uselessly) convert to floating point and then do the test. +// This function is. + +template , int> = 0> +inline C10_HOST_DEVICE bool _isnan(T /*val*/) { + return false; +} + +template , int> = 0> +inline C10_HOST_DEVICE bool _isnan(T val) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return ::isnan(val); +#else + return std::isnan(val); +#endif +} + +template ::value, int> = 0> +inline C10_HOST_DEVICE bool _isnan(T val) { + return std::isnan(val.real()) || std::isnan(val.imag()); +} + +template , int> = 0> +inline C10_HOST_DEVICE bool _isnan(T val) { + return at::_isnan(static_cast(val)); +} + +template < + typename T, + std::enable_if_t, int> = 0> +inline C10_HOST_DEVICE bool _isnan(at::BFloat16 val) { + return at::_isnan(static_cast(val)); +} + +inline C10_HOST_DEVICE bool _isnan(at::BFloat16 val) { + return at::_isnan(static_cast(val)); +} + +template < + typename T, + std::enable_if_t, int> = 0> +inline C10_HOST_DEVICE bool _isnan(T val) { + return val.isnan(); +} + +template < + typename T, + std::enable_if_t, int> = 0> +inline C10_HOST_DEVICE bool _isnan(T val) { + return val.isnan(); +} + +template < + typename T, + std::enable_if_t, int> 
= 0> +inline C10_HOST_DEVICE bool _isnan(T val) { + return val.isnan(); +} + +template < + typename T, + std::enable_if_t, int> = 0> +inline C10_HOST_DEVICE bool _isnan(T val) { + return val.isnan(); +} + +// std::isinf isn't performant to use on integral types; it will +// (uselessly) convert to floating point and then do the test. +// This function is. + +template , int> = 0> +inline C10_HOST_DEVICE bool _isinf(T /*val*/) { + return false; +} + +template , int> = 0> +inline C10_HOST_DEVICE bool _isinf(T val) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return ::isinf(val); +#else + return std::isinf(val); +#endif +} + +inline C10_HOST_DEVICE bool _isinf(at::Half val) { + return at::_isinf(static_cast(val)); +} + +inline C10_HOST_DEVICE bool _isinf(at::BFloat16 val) { + return at::_isinf(static_cast(val)); +} + +inline C10_HOST_DEVICE bool _isinf(at::Float8_e5m2 val) { + return val.isinf(); +} + +inline C10_HOST_DEVICE bool _isinf(at::Float8_e4m3fn val [[maybe_unused]]) { + return false; +} + +inline C10_HOST_DEVICE bool _isinf(at::Float8_e5m2fnuz val [[maybe_unused]]) { + return false; +} + +inline C10_HOST_DEVICE bool _isinf(at::Float8_e4m3fnuz val [[maybe_unused]]) { + return false; +} + +template +C10_HOST_DEVICE inline T exp(T x) { + static_assert( + !std::is_same_v, + "this template must be used with float or less precise type"); +#if defined(__CUDA_ARCH__) || defined(__HIP_ARCH__) + // use __expf fast approximation for peak bandwidth + return __expf(x); +#else + return ::exp(x); +#endif +} + +template <> +C10_HOST_DEVICE inline double exp(double x) { + return ::exp(x); +} + +template +C10_HOST_DEVICE inline T log(T x) { + static_assert( + !std::is_same_v, + "this template must be used with float or less precise type"); +#if defined(__CUDA_ARCH__) || defined(__HIP_ARCH__) + // use __logf fast approximation for peak bandwidth + return __logf(x); +#else + return ::log(x); +#endif +} + +template <> +C10_HOST_DEVICE inline double log(double x) { + return 
::log(x); +} + +template +C10_HOST_DEVICE inline T log1p(T x) { + static_assert( + !std::is_same_v, + "this template must be used with float or less precise type"); +#if defined(__CUDA_ARCH__) || defined(__HIP_ARCH__) + // use __logf fast approximation for peak bandwidth + // NOTE: There is no __log1pf so unfortunately we lose precision. + return __logf(1.0f + x); +#else + return ::log1p(x); +#endif +} + +template <> +C10_HOST_DEVICE inline double log1p(double x) { + return ::log1p(x); +} + +template +C10_HOST_DEVICE inline T tan(T x) { + static_assert( + !std::is_same_v, + "this template must be used with float or less precise type"); +#if defined(__CUDA_ARCH__) || defined(__HIP_ARCH__) + // use __tanf fast approximation for peak bandwidth + return __tanf(x); +#else + return ::tan(x); +#endif +} + +template <> +C10_HOST_DEVICE inline double tan(double x) { + return ::tan(x); +} + +} // namespace at + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/OpMathType.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/OpMathType.h new file mode 100644 index 0000000000000000000000000000000000000000..1817c09ab454abb87b4f21cb7bd6657a130133ad --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/OpMathType.h @@ -0,0 +1,78 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { + +// For FP16 or BFloat16 inputs, ops should perform internal math in FP32. 
+template +struct OpMathType { + using type = scalar_t; +}; +template <> +struct OpMathType { + using type = float; +}; +template <> +struct OpMathType { + using type = float; +}; +template <> +struct OpMathType { + using type = float; +}; +template <> +struct OpMathType { + using type = float; +}; +template <> +struct OpMathType { + using type = float; +}; +template <> +struct OpMathType { + using type = float; +}; +template <> +struct OpMathType { + using type = float; +}; +template <> +struct OpMathType> { + using type = c10::complex; +}; + +template +using opmath_type = typename OpMathType::type; + +namespace { + +inline c10::ScalarType toOpMathType(const c10::ScalarType type) { + switch (type) { +#define DEFINE_CASE(scalar_t, TypeNum) \ + case ScalarType::TypeNum: \ + return CppTypeToScalarType>::value; + + AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_CASE) +#undef DEFINE_CASE + + default: + TORCH_INTERNAL_ASSERT(false, "Unrecognized ScalarType: ", type); + } +} + +} // namespace + +} // namespace at + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/OpaqueTensorImpl.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/OpaqueTensorImpl.h new file mode 100644 index 0000000000000000000000000000000000000000..329caeda99e3ca9fdca29de23552e13e25d9fcf8 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/OpaqueTensorImpl.h @@ -0,0 +1,211 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include + +namespace at { + +// An "Opaque" TensorImpl -- there are no strides and (for now) +// even data() is not supported (thus no pointer arithmetic). 
+ +// NOTE: We could allow data() in the future, but would have to ensure pointer +// arithmetic code is properly guarded. +// +// NOTE: This does not support resize_ (and other metadata-changing ops) because +// of `shallow_copy_and_detach`. We would need to define an interface to +// "shallow copy" in order to add support. + +template +struct TORCH_API OpaqueTensorImpl : public TensorImpl { + // public constructor for now... + OpaqueTensorImpl( + at::DispatchKeySet key_set, + const caffe2::TypeMeta data_type, + c10::Device device, + OpaqueHandle opaque_handle, + c10::IntArrayRef sizes, + bool is_non_overlapping_and_dense = true) + : TensorImpl(key_set, data_type, device), + opaque_handle_(std::move(opaque_handle)) { + constructor_impl(sizes, is_non_overlapping_and_dense); + } + + OpaqueTensorImpl( + TensorImpl::ImplType impl_type, + c10::Storage&& storage, + at::DispatchKeySet key_set, + const caffe2::TypeMeta data_type, + OpaqueHandle opaque_handle, + c10::IntArrayRef sizes, + bool is_non_overlapping_and_dense = true) + : TensorImpl(impl_type, std::move(storage), key_set, data_type), + opaque_handle_(std::move(opaque_handle)) { + constructor_impl(sizes, is_non_overlapping_and_dense); + } + + // Destructor doesn't call release_resources because it's + // unnecessary; don't forget to change that if needed! 
+ void release_resources() override { + TensorImpl::release_resources(); + opaque_handle_ = {}; + } + + void set_size(int64_t dim, int64_t new_size) override { + TORCH_CHECK(false, "opaque tensors do not have set_size"); + } + + void set_stride(int64_t dim, int64_t new_stride) override { + TORCH_CHECK(false, "opaque tensors do not have set_stride"); + } + + void set_storage_offset(int64_t storage_offset) override { + TORCH_CHECK(false, "opaque tensors do not have set_storage_offset"); + } + +#ifdef DEBUG + bool has_storage() const override { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + !storage_, "OpaqueTensorImpl assumes that storage_ is never set"); + return false; + } +#endif + + /** + * Return a TensorImpl that is a shallow-copy of this TensorImpl. + * + * For usage of `version_counter` and `allow_tensor_metadata_change`, + * see NOTE [ TensorImpl Shallow-Copying ]. + */ + c10::intrusive_ptr shallow_copy_and_detach( + const c10::VariableVersion& version_counter, + bool allow_tensor_metadata_change) const override { + auto impl = c10::make_intrusive>( + key_set(), + dtype(), + device(), + opaque_handle_, + sizes_and_strides_.sizes_arrayref()); + copy_tensor_metadata( + /*src_opaque_impl=*/this, + /*dest_opaque_impl=*/impl.get(), + /*version_counter=*/version_counter, + /*allow_tensor_metadata_change=*/allow_tensor_metadata_change); + impl->refresh_numel(); + return impl; + } + + /** + * Return a TensorImpl that is a shallow-copy of this TensorImpl. + * + * For usage of `version_counter` and `allow_tensor_metadata_change`, + * see NOTE [ TensorImpl Shallow-Copying ]. 
+ */ + c10::intrusive_ptr shallow_copy_and_detach( + c10::VariableVersion&& version_counter, + bool allow_tensor_metadata_change) const override { + auto impl = c10::make_intrusive>( + key_set(), + dtype(), + device(), + opaque_handle_, + sizes_and_strides_.sizes_arrayref()); + copy_tensor_metadata( + /*src_opaque_impl=*/this, + /*dest_opaque_impl=*/impl.get(), + /*version_counter=*/std::move(version_counter), + /*allow_tensor_metadata_change=*/allow_tensor_metadata_change); + impl->refresh_numel(); + return impl; + } + + /** + * Shallow-copies data from another TensorImpl into this TensorImpl. + * + * For why this function doesn't check this TensorImpl's + * `allow_tensor_metadata_change_`, see NOTE [ TensorImpl Shallow-Copying ]. + */ + void shallow_copy_from(const c10::intrusive_ptr& impl) override { + AT_ASSERT(has_compatible_shallow_copy_type(impl->key_set())); + auto opaque_impl = + static_cast*>(impl.get()); + copy_tensor_metadata( + /*src_impl=*/opaque_impl, + /*dest_impl=*/this, + /*version_counter=*/version_counter(), + /*allow_tensor_metadata_change=*/allow_tensor_metadata_change()); + refresh_numel(); + } + + const OpaqueHandle& opaque_handle() const { + return opaque_handle_; + } + + OpaqueHandle& unsafe_opaque_handle() { + return opaque_handle_; + } + + protected: + /** + * Copy the tensor metadata fields (e.g. sizes / strides / storage pointer / + * storage_offset) from one TensorImpl to another TensorImpl. + * + * For usage of `version_counter` and `allow_tensor_metadata_change`, see NOTE + * [ TensorImpl Shallow-Copying ]. + */ + static void copy_tensor_metadata( + const OpaqueTensorImpl* src_opaque_impl, + OpaqueTensorImpl* dest_opaque_impl, + const c10::VariableVersion& version_counter, + bool allow_tensor_metadata_change) { + TensorImpl::copy_tensor_metadata( + src_opaque_impl, + dest_opaque_impl, + version_counter, + allow_tensor_metadata_change); + + // OpaqueTensorImpl-specific fields. 
+ dest_opaque_impl->opaque_handle_ = src_opaque_impl->opaque_handle_; + } + + static void copy_tensor_metadata( + const OpaqueTensorImpl* src_opaque_impl, + OpaqueTensorImpl* dest_opaque_impl, + c10::VariableVersion&& version_counter, + bool allow_tensor_metadata_change) { + TensorImpl::copy_tensor_metadata( + src_opaque_impl, + dest_opaque_impl, + std::move(version_counter), + allow_tensor_metadata_change); + + // OpaqueTensorImpl-specific fields. + dest_opaque_impl->opaque_handle_ = src_opaque_impl->opaque_handle_; + } + + private: + const char* tensorimpl_type_name() const override { + return "OpaqueTensorImpl"; + } + + void constructor_impl( + c10::IntArrayRef sizes, + bool is_non_overlapping_and_dense) { + set_storage_access_should_throw(); + set_custom_sizes_strides(SizesStridesPolicy::CustomStrides); + sizes_and_strides_.set_sizes(sizes); + refresh_numel(); + // NOLINTNEXTLINE(cppcoreguidelines-prefer-member-initializer) + is_non_overlapping_and_dense_ = is_non_overlapping_and_dense; + } + + OpaqueHandle opaque_handle_; +}; + +} // namespace at + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." 
+#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/Operators.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/Operators.h new file mode 100644 index 0000000000000000000000000000000000000000..2f631376f3f53bfcf9e4bc72fb4ee64974ee8d06 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/Operators.h @@ -0,0 +1,1407 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +// @generated by torchgen/gen.py from Operators.h + +#ifdef TORCH_ASSERT_NO_OPERATORS +#error This change adds a dependency on native_functions.yaml, \ + meaning the file will need to be re-compiled every time an operator \ + is changed or added. Consider if your change would be better placed in \ + another file, or if a more specific header might achieve the same goal. \ + See NOTE: [Tensor vs. TensorBase] +#endif + +#if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS) +#error This change adds a dependency on all pytorch operators, meaning the \ + file will need to be re-compiled every time an operator is changed or added. \ + Consider including a specific operator from \ + and see NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS]. 
+#endif + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include 
+#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include 
+#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include 
+#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include 
+#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include 
+#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include 
+#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// Extension writers: do you write wrapper functions? Are you frustrated with +// resolving overloads of operators? Are you frustrated with dealing with +// pointer-to-methods and resolving overloads of pointer-to-methods?? Look no +// further, this is the utility for you. +// +// Given an operator schema: aten::op.overload(... +// +// Use ATEN_FN2(op, overload) to get a *function* version of the operator +// that is guaranteed to not be overloaded. This means that you can safely +// decltype(&ATEN_FN2(op, overload)) it. NB: the 2 means this macro takes 2 args. 
+// +// Given an operator schema without an overload name: aten::op(... +// +// Use ATEN_FN(op) to get an unambiguous *function* version of the operator. +// +// There is some interesting behavior for out= operations. +// ATEN_FN2(sin, out) gives a function that is *faithful* to the schema; +// that is, the order of arguments is exactly what it looks like in the schema. + +#define ATEN_FN2(op_name, overload) at::_ops::op_name##_##overload::call +#define ATEN_FN(op_name) at::_ops::op_name::call + +// Separately, ATEN_OP(op) and ATEN_OP2(op, overload) define a class containing compile-time +// metadata about a given aten operator. +// Notable data on the class includes: +// - ATEN_OP2(add, Tensor)::name // returns the string name: "add" +// - ATEN_OP2(add, Tensor)::overload_name // returns the string overload name: "Tensor" +// - ATEN_OP2(add, Tensor)::schema // returns the C++ schema type: at::Tensor (const at::Tensor &, const at::Tensor &, const at::Scalar &) +// - ATEN_OP2(add, Tensor)::schema_str // returns the string jit type: "add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor" + +#define ATEN_OP2(op_name, overload) at::_ops::op_name##_##overload +#define ATEN_OP(op_name) at::_ops::op_name + +// WARNING: Please do not call any of the ops in the _ops namespace directly. +// Use the ATEN_FN macros. We do not guarantee stability of the naming +// scheme for the functions in at::_ops + +// See Note [The ATen Operators API] for details of the at::_ops namespace + +namespace at { +namespace _ops { + +} // namespace _ops +} // namespace at + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." 
+#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/PTThreadPool.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/PTThreadPool.h new file mode 100644 index 0000000000000000000000000000000000000000..416d72dd8e6fcaab45d6b0715b86c8375937da4a --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/PTThreadPool.h @@ -0,0 +1,22 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include + +namespace at { + +class TORCH_API PTThreadPool : public c10::ThreadPool { + public: + explicit PTThreadPool(int pool_size, int numa_node_id = -1) + : c10::ThreadPool(pool_size, numa_node_id, []() { + c10::setThreadName("PTThreadPool"); + at::init_num_threads(); + }) {} +}; + +} // namespace at + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/PadNd.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/PadNd.h new file mode 100644 index 0000000000000000000000000000000000000000..e11341d5cec1fe84d92712613d0241fdb7243815 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/PadNd.h @@ -0,0 +1,17 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +namespace at { + +enum class padding_mode { + reflect, + replicate, + circular, + constant, +}; + +} // namespace at + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." 
+#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/Parallel-inl.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/Parallel-inl.h new file mode 100644 index 0000000000000000000000000000000000000000..d944db83e5ff34b0c0f2bbabb76ec3b7d54b4da7 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/Parallel-inl.h @@ -0,0 +1,98 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include + +namespace at { + +template +inline void parallel_for( + const int64_t begin, + const int64_t end, + const int64_t grain_size, + const F& f) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(grain_size >= 0); + if (begin >= end) { + return; + } + +#ifdef INTRA_OP_PARALLEL + at::internal::lazy_init_num_threads(); + const auto numiter = end - begin; + const bool use_parallel = + (numiter > grain_size && numiter > 1 && !at::in_parallel_region() && + at::get_num_threads() > 1); + if (!use_parallel) { + internal::ThreadIdGuard tid_guard(0); + c10::ParallelGuard guard(true); + f(begin, end); + return; + } + + internal::invoke_parallel( + begin, end, grain_size, [&](int64_t begin, int64_t end) { + c10::ParallelGuard guard(true); + f(begin, end); + }); +#else + internal::ThreadIdGuard tid_guard(0); + c10::ParallelGuard guard(true); + f(begin, end); +#endif +} + +template +inline scalar_t parallel_reduce( + const int64_t begin, + const int64_t end, + const int64_t grain_size, + const scalar_t ident, + const F& f, + const SF& sf) { + TORCH_CHECK(grain_size >= 0); + if (begin >= end) { + return ident; + } + +#ifdef INTRA_OP_PARALLEL + at::internal::lazy_init_num_threads(); + const auto max_threads = at::get_num_threads(); + const bool use_parallel = + ((end - begin) > grain_size && !at::in_parallel_region() && + max_threads > 1); + if (!use_parallel) { + 
internal::ThreadIdGuard tid_guard(0); + c10::ParallelGuard guard(true); + return f(begin, end, ident); + } + + c10::SmallVector results(max_threads, ident); + internal::invoke_parallel( + begin, + end, + grain_size, + [&](const int64_t my_begin, const int64_t my_end) { + const auto tid = at::get_thread_num(); + c10::ParallelGuard guard(true); + results[tid] = f(my_begin, my_end, ident); + }); + + scalar_t result = ident; + for (auto partial_result : results) { + result = sf(result, partial_result); + } + return result; +#else + internal::ThreadIdGuard tid_guard(0); + c10::ParallelGuard guard(true); + return f(begin, end, ident); +#endif +} + +} // namespace at + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/Parallel.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/Parallel.h new file mode 100644 index 0000000000000000000000000000000000000000..83e227411a2d77544aabae385e7d4729e75a52a1 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/Parallel.h @@ -0,0 +1,163 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +#include +#include +#include +#include + +namespace at { + +inline int64_t divup(int64_t x, int64_t y) { + return (x + y - 1) / y; +} + +// Called during new thread initialization +TORCH_API void init_num_threads(); + +// Sets the number of threads to be used in parallel region +TORCH_API void set_num_threads(int /*nthreads*/); + +// Returns the maximum number of threads that may be used in a parallel region +TORCH_API int get_num_threads(); + +// Returns the current thread number (starting from 0) +// in the current parallel region, or 0 in the sequential region +TORCH_API int get_thread_num(); + +// Checks whether the 
code runs in parallel region +TORCH_API bool in_parallel_region(); + +namespace internal { + +// Initialise num_threads lazily at first parallel call +inline void lazy_init_num_threads() { + thread_local bool init = false; + if (C10_UNLIKELY(!init)) { + at::init_num_threads(); + init = true; + } +} + +TORCH_API void set_thread_num(int /*id*/); + +class TORCH_API ThreadIdGuard { + public: + ThreadIdGuard(int new_id) : old_id_(at::get_thread_num()) { + set_thread_num(new_id); + } + + ~ThreadIdGuard() { + set_thread_num(old_id_); + } + + private: + int old_id_; +}; + +} // namespace internal + +/* +parallel_for + +begin: index at which to start applying user function + +end: index at which to stop applying user function + +grain_size: number of elements per chunk. impacts the degree of parallelization + +f: user function applied in parallel to the chunks, signature: + void f(int64_t begin, int64_t end) + +Warning: parallel_for does NOT copy thread local +states from the current thread to the worker threads. +This means for example that Tensor operations CANNOT be used in the +body of your function, only data pointers. +*/ +template +inline void parallel_for( + const int64_t begin, + const int64_t end, + const int64_t grain_size, + const F& f); + +/* +parallel_reduce + +begin: index at which to start applying reduction + +end: index at which to stop applying reduction + +grain_size: number of elements per chunk. impacts number of elements in +intermediate results tensor and degree of parallelization. + +ident: identity for binary combination function sf. sf(ident, x) needs to return +x. + +f: function for reduction over a chunk. f needs to be of signature scalar_t +f(int64_t partial_begin, int64_t partial_end, scalar_t identify) + +sf: function to combine two partial results. sf needs to be of signature +scalar_t sf(scalar_t x, scalar_t y) + +For example, you might have a tensor of 10000 entries and want to sum together +all the elements. 
Parallel_reduce with a grain_size of 2500 will then allocate +an intermediate result tensor with 4 elements. Then it will execute the function +"f" you provide and pass the beginning and end index of these chunks, so +0-2499, 2500-4999, etc. and the combination identity. It will then write out +the result from each of these chunks into the intermediate result tensor. After +that it'll reduce the partial results from each chunk into a single number using +the combination function sf and the identity ident. For a total summation this +would be "+" and 0 respectively. This is similar to tbb's approach [1], where +you need to provide a function to accumulate a subrange, a function to combine +two partial results and an identity. + +Warning: parallel_reduce does NOT copy thread local +states from the current thread to the worker threads. +This means for example that Tensor operations CANNOT be used in the +body of your function, only data pointers. + +[1] https://software.intel.com/en-us/node/506154 +*/ +template +inline scalar_t parallel_reduce( + const int64_t begin, + const int64_t end, + const int64_t grain_size, + const scalar_t ident, + const F& f, + const SF& sf); + +// Returns a detailed string describing parallelization settings +TORCH_API std::string get_parallel_info(); + +// Sets number of threads used for inter-op parallelism +TORCH_API void set_num_interop_threads(int /*nthreads*/); + +// Returns the number of threads used for inter-op parallelism +TORCH_API size_t get_num_interop_threads(); + +// Launches inter-op parallel task +TORCH_API void launch(std::function func); +namespace internal { +void launch_no_thread_state(std::function fn); +} // namespace internal + +// Launches intra-op parallel task +TORCH_API void intraop_launch(const std::function& func); + +// Returns number of intra-op threads used by default +TORCH_API int intraop_default_num_threads(); + +} // namespace at + +#if AT_PARALLEL_OPENMP +#include // IWYU pragma: keep +#elif 
AT_PARALLEL_NATIVE +#include // IWYU pragma: keep +#endif + +#include // IWYU pragma: keep + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/ParallelFuture.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/ParallelFuture.h new file mode 100644 index 0000000000000000000000000000000000000000..c0f3f434d127c2ae2169e1a047ec9c933b2e56a6 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/ParallelFuture.h @@ -0,0 +1,18 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include + +namespace at { + +// Launches intra-op parallel task, returns a future +TORCH_API c10::intrusive_ptr intraop_launch_future( + const std::function& func); + +} // namespace at + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." 
+#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/ParallelNative.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/ParallelNative.h new file mode 100644 index 0000000000000000000000000000000000000000..f1dbd84bafb8275d8b2d579a06e5f565c4aa5ca9 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/ParallelNative.h @@ -0,0 +1,20 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include + +#define INTRA_OP_PARALLEL + +namespace at::internal { + +TORCH_API void invoke_parallel( + const int64_t begin, + const int64_t end, + const int64_t grain_size, + const std::function& f); + +} // namespace at::internal + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/ParallelOpenMP.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/ParallelOpenMP.h new file mode 100644 index 0000000000000000000000000000000000000000..d5cb3134f09841f70f7054c4f9671cb1f03f8488 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/ParallelOpenMP.h @@ -0,0 +1,59 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include + +#ifdef _OPENMP +#define INTRA_OP_PARALLEL + +#include +#endif + +#ifdef _OPENMP +namespace at::internal { +template +inline void invoke_parallel( + int64_t begin, + int64_t end, + int64_t grain_size, + const F& f) { + std::atomic_flag err_flag = ATOMIC_FLAG_INIT; + std::exception_ptr eptr; + +#pragma omp parallel + { + // choose number of tasks based on grain size and number of threads + // can't 
use num_threads clause due to bugs in GOMP's thread pool (See + // #32008) + int64_t num_threads = omp_get_num_threads(); + if (grain_size > 0) { + num_threads = std::min(num_threads, divup((end - begin), grain_size)); + } + + int64_t tid = omp_get_thread_num(); + int64_t chunk_size = divup((end - begin), num_threads); + int64_t begin_tid = begin + tid * chunk_size; + if (begin_tid < end) { + try { + internal::ThreadIdGuard tid_guard(tid); + f(begin_tid, std::min(end, chunk_size + begin_tid)); + } catch (...) { + if (!err_flag.test_and_set()) { + eptr = std::current_exception(); + } + } + } + } + if (eptr) { + std::rethrow_exception(eptr); + } +} +} // namespace at::internal +#endif // _OPENMP + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/PythonTorchFunctionTLS.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/PythonTorchFunctionTLS.h new file mode 100644 index 0000000000000000000000000000000000000000..e0cb73ec391fc2d10d82dfacf224ba64d034ecec --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/PythonTorchFunctionTLS.h @@ -0,0 +1,42 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include + +namespace at::impl { + +enum TorchFunctionDisabledState { ENABLED, SUBCLASSES_DISABLED, ALL_DISABLED }; + +struct TORCH_API PythonTorchFunctionTLS { + static void set_disabled_state(TorchFunctionDisabledState disabled_state_); + static TorchFunctionDisabledState get_disabled_state(); + + static void push_onto_stack(std::shared_ptr mode); + static const std::shared_ptr pop_stack(); + static const std::shared_ptr& get_stack_at(int64_t idx); + static int64_t stack_len(); + + static const PythonTorchFunctionTLS& get_state(); + 
static void set_state(const PythonTorchFunctionTLS& state); + + private: + // The mode TLS is split into + // - disabled_state, which says which part of torch function are disabled + // - stack_, which is a vector of modes representing the stack of user + // defined modes + TorchFunctionDisabledState disabled_state_ = + TorchFunctionDisabledState::ENABLED; + std::vector> stack_; + friend TORCH_API bool torch_function_mode_enabled(); +}; + +TORCH_API bool torch_function_mode_enabled(); + +TORCH_API bool torch_function_all_disabled(); + +} // namespace at::impl + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/ROCmFABackend.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/ROCmFABackend.h new file mode 100644 index 0000000000000000000000000000000000000000..e88dbe5614dd02c0cbdbc6f9fc3e635611787fca --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/ROCmFABackend.h @@ -0,0 +1,36 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include + +#include +#include + +namespace at { + +enum class ROCmFABackend : int8_t { Default, AOTriton, Ck }; + +inline std::string ROCmFABackendToString(at::ROCmFABackend backend) { + switch (backend) { + case ROCmFABackend::Default: + return "at::ROCmFABackend::Default"; + case ROCmFABackend::AOTriton: + return "at::ROCmFABackend::AOTriton"; + case ROCmFABackend::Ck: + return "at::ROCmFABackend::Ck"; + default: + TORCH_CHECK(false, "Unknown ROCm flash attention backend") + } +} + +inline std::ostream& operator<<( + std::ostream& stream, + at::ROCmFABackend backend) { + return stream << ROCmFABackendToString(backend); +} + +} // namespace at + +#else +#error "This file should not be included when 
either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/RedispatchFunctions.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/RedispatchFunctions.h new file mode 100644 index 0000000000000000000000000000000000000000..ca3fe1b24d14760aba858042c8161c2cf1e66bbe --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/RedispatchFunctions.h @@ -0,0 +1,25616 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +// @generated by torchgen/gen.py from RedispatchFunctions.h + +#ifdef TORCH_ASSERT_ONLY_METHOD_OPERATORS +#error This change adds a dependency on all pytorch operators, meaning the \ + file will need to be re-compiled every time an operator is changed or added. \ + Consider using the at::_ops::{name}::redispatch() interface by including \ + the specific operator from +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { + +namespace redispatch { + + // aten::_cast_Byte(Tensor self, bool non_blocking=False) -> Tensor + inline at::Tensor _cast_Byte(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool non_blocking=false) { + return at::_ops::_cast_Byte::redispatch(dispatchKeySet, self, non_blocking); + } + + // aten::_cast_Char(Tensor self, bool non_blocking=False) -> Tensor + inline at::Tensor _cast_Char(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool non_blocking=false) { + return at::_ops::_cast_Char::redispatch(dispatchKeySet, self, non_blocking); + } + + // aten::_cast_Double(Tensor self, bool non_blocking=False) -> Tensor + inline at::Tensor _cast_Double(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool non_blocking=false) { + return 
at::_ops::_cast_Double::redispatch(dispatchKeySet, self, non_blocking); + } + + // aten::_cast_Float(Tensor self, bool non_blocking=False) -> Tensor + inline at::Tensor _cast_Float(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool non_blocking=false) { + return at::_ops::_cast_Float::redispatch(dispatchKeySet, self, non_blocking); + } + + // aten::_cast_Int(Tensor self, bool non_blocking=False) -> Tensor + inline at::Tensor _cast_Int(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool non_blocking=false) { + return at::_ops::_cast_Int::redispatch(dispatchKeySet, self, non_blocking); + } + + // aten::_cast_Long(Tensor self, bool non_blocking=False) -> Tensor + inline at::Tensor _cast_Long(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool non_blocking=false) { + return at::_ops::_cast_Long::redispatch(dispatchKeySet, self, non_blocking); + } + + // aten::_cast_Short(Tensor self, bool non_blocking=False) -> Tensor + inline at::Tensor _cast_Short(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool non_blocking=false) { + return at::_ops::_cast_Short::redispatch(dispatchKeySet, self, non_blocking); + } + + // aten::_cast_Half(Tensor self, bool non_blocking=False) -> Tensor + inline at::Tensor _cast_Half(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool non_blocking=false) { + return at::_ops::_cast_Half::redispatch(dispatchKeySet, self, non_blocking); + } + + // aten::_backward(Tensor self, Tensor[] inputs, Tensor? gradient=None, bool? retain_graph=None, bool create_graph=False) -> () + inline void __dispatch__backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::TensorList inputs, const ::std::optional & gradient={}, ::std::optional retain_graph=::std::nullopt, bool create_graph=false) { + return at::_ops::_backward::redispatch(dispatchKeySet, self, inputs, gradient, retain_graph, create_graph); + } + + // aten::set_data(Tensor(a!) 
self, Tensor new_data) -> () + inline void __dispatch_set_data(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & new_data) { + return at::_ops::set_data::redispatch(dispatchKeySet, self, new_data); + } + + // aten::data(Tensor self) -> Tensor + inline at::Tensor __dispatch_data(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::data::redispatch(dispatchKeySet, self); + } + + // aten::is_leaf(Tensor self) -> bool + inline bool __dispatch_is_leaf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::is_leaf::redispatch(dispatchKeySet, self); + } + + // aten::output_nr(Tensor self) -> int + inline int64_t __dispatch_output_nr(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::output_nr::redispatch(dispatchKeySet, self); + } + + // aten::_version(Tensor self) -> int + inline int64_t __dispatch__version(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::_version::redispatch(dispatchKeySet, self); + } + + // aten::requires_grad_(Tensor(a!) self, bool requires_grad=True) -> Tensor(a!) + inline at::Tensor & __dispatch_requires_grad_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, bool requires_grad=true) { + return at::_ops::requires_grad_::redispatch(dispatchKeySet, self, requires_grad); + } + + // aten::retain_grad(Tensor(a!) 
self) -> () + inline void __dispatch_retain_grad(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::retain_grad::redispatch(dispatchKeySet, self); + } + + // aten::retains_grad(Tensor self) -> bool + inline bool __dispatch_retains_grad(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::retains_grad::redispatch(dispatchKeySet, self); + } + + // aten::_fw_primal(Tensor(a) self, int level) -> Tensor(a) + inline at::Tensor _fw_primal(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t level) { + return at::_ops::_fw_primal::redispatch(dispatchKeySet, self, level); + } + + // aten::_make_dual(Tensor(a) primal, Tensor tangent, int level) -> Tensor(a) + inline at::Tensor _make_dual(c10::DispatchKeySet dispatchKeySet, const at::Tensor & primal, const at::Tensor & tangent, int64_t level) { + return at::_ops::_make_dual::redispatch(dispatchKeySet, primal, tangent, level); + } + + // aten::_unpack_dual(Tensor(a) dual, int level) -> (Tensor(a) primal, Tensor tangent) + inline ::std::tuple _unpack_dual(c10::DispatchKeySet dispatchKeySet, const at::Tensor & dual, int64_t level) { + return at::_ops::_unpack_dual::redispatch(dispatchKeySet, dual, level); + } + + // aten::_new_zeros_with_same_feature_meta(Tensor self, Tensor other, *, int self_num_batch_dims=0) -> Tensor + inline at::Tensor _new_zeros_with_same_feature_meta(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, int64_t self_num_batch_dims=0) { + return at::_ops::_new_zeros_with_same_feature_meta::redispatch(dispatchKeySet, self, other, self_num_batch_dims); + } + + // aten::_has_same_storage_numel(Tensor self, Tensor other) -> bool + inline bool _has_same_storage_numel(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::_has_same_storage_numel::redispatch(dispatchKeySet, self, other); + } + + // aten::rename_(Tensor(a!) self, Dimname[]? names) -> Tensor(a!) 
+ inline at::Tensor & rename_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, ::std::optional names) { + return at::_ops::rename_::redispatch(dispatchKeySet, self, names); + } + + // aten::rename(Tensor(a) self, Dimname[]? names) -> Tensor(a) + inline at::Tensor rename(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional names) { + return at::_ops::rename::redispatch(dispatchKeySet, self, names); + } + + // aten::align_to(Tensor(a) self, Dimname[] names) -> Tensor(a) + inline at::Tensor align_to(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::DimnameList names) { + return at::_ops::align_to::redispatch(dispatchKeySet, self, names); + } + + // aten::align_to.ellipsis_idx(Tensor(a) self, Dimname[] order, int ellipsis_idx) -> Tensor(a) + inline at::Tensor align_to(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::DimnameList order, int64_t ellipsis_idx) { + return at::_ops::align_to_ellipsis_idx::redispatch(dispatchKeySet, self, order, ellipsis_idx); + } + + // aten::align_as(Tensor self, Tensor other) -> Tensor + inline at::Tensor align_as(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::align_as::redispatch(dispatchKeySet, self, other); + } + + // aten::align_tensors(Tensor[] tensors) -> Tensor[] + inline ::std::vector align_tensors(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors) { + return at::_ops::align_tensors::redispatch(dispatchKeySet, tensors); + } + + // aten::_assert_async(Tensor self) -> () + inline void _assert_async(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::_assert_async::redispatch(dispatchKeySet, self); + } + + // aten::_assert_async.msg(Tensor self, str assert_msg) -> () + inline void _assert_async(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::string_view assert_msg) { + return at::_ops::_assert_async_msg::redispatch(dispatchKeySet, self, assert_msg); + } + + 
// aten::_assert_scalar(Scalar self, str assert_msg) -> () + inline void _assert_scalar(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, c10::string_view assert_msg) { + return at::_ops::_assert_scalar::redispatch(dispatchKeySet, self, assert_msg); + } + + // aten::_functional_assert_scalar(Scalar self, str assert_msg, Tensor dep_token) -> Tensor + inline at::Tensor _functional_assert_scalar(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, c10::string_view assert_msg, const at::Tensor & dep_token) { + return at::_ops::_functional_assert_scalar::redispatch(dispatchKeySet, self, assert_msg, dep_token); + } + + // aten::_functional_assert_async.msg(Tensor self, str assert_msg, Tensor dep_token) -> Tensor + inline at::Tensor _functional_assert_async(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::string_view assert_msg, const at::Tensor & dep_token) { + return at::_ops::_functional_assert_async_msg::redispatch(dispatchKeySet, self, assert_msg, dep_token); + } + + // aten::_assert_tensor_metadata(Tensor a, SymInt[]? size=None, SymInt[]? stride=None, ScalarType? dtype=None, *, Device? device=None, Layout? layout=None) -> () + inline void _assert_tensor_metadata(c10::DispatchKeySet dispatchKeySet, const at::Tensor & a, at::OptionalIntArrayRef size=::std::nullopt, at::OptionalIntArrayRef stride=::std::nullopt, ::std::optional dtype=::std::nullopt, ::std::optional device=::std::nullopt, ::std::optional layout=::std::nullopt) { + return at::_ops::_assert_tensor_metadata::redispatch(dispatchKeySet, a, size.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*size)) : ::std::nullopt, stride.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*stride)) : ::std::nullopt, dtype, device, layout); + } + + // aten::_assert_tensor_metadata(Tensor a, SymInt[]? size=None, SymInt[]? stride=None, ScalarType? dtype=None, *, Device? device=None, Layout? 
layout=None) -> () + inline void _assert_tensor_metadata_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & a, at::OptionalSymIntArrayRef size=::std::nullopt, at::OptionalSymIntArrayRef stride=::std::nullopt, ::std::optional dtype=::std::nullopt, ::std::optional device=::std::nullopt, ::std::optional layout=::std::nullopt) { + return at::_ops::_assert_tensor_metadata::redispatch(dispatchKeySet, a, size, stride, dtype, device, layout); + } + + // aten::_print(str s) -> () + inline void _print(c10::DispatchKeySet dispatchKeySet, c10::string_view s) { + return at::_ops::_print::redispatch(dispatchKeySet, s); + } + + // aten::sym_constrain_range(Scalar size, *, int? min=None, int? max=None) -> () + inline void sym_constrain_range(c10::DispatchKeySet dispatchKeySet, const at::Scalar & size, ::std::optional min=::std::nullopt, ::std::optional max=::std::nullopt) { + return at::_ops::sym_constrain_range::redispatch(dispatchKeySet, size, min, max); + } + + // aten::sym_constrain_range_for_size(Scalar size, *, int? min=None, int? max=None) -> () + inline void sym_constrain_range_for_size(c10::DispatchKeySet dispatchKeySet, const at::Scalar & size, ::std::optional min=::std::nullopt, ::std::optional max=::std::nullopt) { + return at::_ops::sym_constrain_range_for_size::redispatch(dispatchKeySet, size, min, max); + } + + // aten::_functional_sym_constrain_range(Scalar size, int? min, int? max, Tensor dep_token) -> Tensor + inline at::Tensor _functional_sym_constrain_range(c10::DispatchKeySet dispatchKeySet, const at::Scalar & size, ::std::optional min, ::std::optional max, const at::Tensor & dep_token) { + return at::_ops::_functional_sym_constrain_range::redispatch(dispatchKeySet, size, min, max, dep_token); + } + + // aten::_functional_sym_constrain_range_for_size(Scalar size, int? min, int? 
max, Tensor dep_token) -> Tensor + inline at::Tensor _functional_sym_constrain_range_for_size(c10::DispatchKeySet dispatchKeySet, const at::Scalar & size, ::std::optional min, ::std::optional max, const at::Tensor & dep_token) { + return at::_ops::_functional_sym_constrain_range_for_size::redispatch(dispatchKeySet, size, min, max, dep_token); + } + + // aten::_make_dep_token(*, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + inline at::Tensor _make_dep_token(c10::DispatchKeySet dispatchKeySet, at::TensorOptions options={}, ::std::optional memory_format=::std::nullopt) { + return at::_ops::_make_dep_token::redispatch(dispatchKeySet, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format)); + } + + // aten::_make_dep_token(*, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? 
memory_format=None) -> Tensor + inline at::Tensor _make_dep_token(c10::DispatchKeySet dispatchKeySet, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format) { + return at::_ops::_make_dep_token::redispatch(dispatchKeySet, dtype, layout, device, pin_memory, memory_format); + } + + // aten::refine_names(Tensor(a) self, Dimname[] names) -> Tensor(a) + inline at::Tensor refine_names(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::DimnameList names) { + return at::_ops::refine_names::redispatch(dispatchKeySet, self, names); + } + + // aten::_use_cudnn_ctc_loss(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank) -> bool + inline bool _use_cudnn_ctc_loss(c10::DispatchKeySet dispatchKeySet, const at::Tensor & log_probs, const at::Tensor & targets, at::IntArrayRef input_lengths, at::IntArrayRef target_lengths, int64_t blank) { + return at::_ops::_use_cudnn_ctc_loss::redispatch(dispatchKeySet, log_probs, targets, input_lengths, target_lengths, blank); + } + + // aten::_use_cudnn_ctc_loss.Tensor(Tensor log_probs, Tensor targets, Tensor input_lengths, Tensor target_lengths, int blank) -> bool + inline bool _use_cudnn_ctc_loss(c10::DispatchKeySet dispatchKeySet, const at::Tensor & log_probs, const at::Tensor & targets, const at::Tensor & input_lengths, const at::Tensor & target_lengths, int64_t blank) { + return at::_ops::_use_cudnn_ctc_loss_Tensor::redispatch(dispatchKeySet, log_probs, targets, input_lengths, target_lengths, blank); + } + + // aten::_cudnn_ctc_loss(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank, bool deterministic, bool zero_infinity) -> (Tensor, Tensor) + inline ::std::tuple _cudnn_ctc_loss(c10::DispatchKeySet dispatchKeySet, const at::Tensor & log_probs, const at::Tensor & targets, at::IntArrayRef input_lengths, at::IntArrayRef target_lengths, int64_t blank, bool deterministic, bool 
zero_infinity) { + return at::_ops::_cudnn_ctc_loss::redispatch(dispatchKeySet, log_probs, targets, input_lengths, target_lengths, blank, deterministic, zero_infinity); + } + + // aten::_cudnn_ctc_loss.Tensor(Tensor log_probs, Tensor targets, Tensor input_lengths, Tensor target_lengths, int blank, bool deterministic, bool zero_infinity) -> (Tensor, Tensor) + inline ::std::tuple _cudnn_ctc_loss(c10::DispatchKeySet dispatchKeySet, const at::Tensor & log_probs, const at::Tensor & targets, const at::Tensor & input_lengths, const at::Tensor & target_lengths, int64_t blank, bool deterministic, bool zero_infinity) { + return at::_ops::_cudnn_ctc_loss_Tensor::redispatch(dispatchKeySet, log_probs, targets, input_lengths, target_lengths, blank, deterministic, zero_infinity); + } + + // aten::_use_cudnn_rnn_flatten_weight() -> bool + inline bool _use_cudnn_rnn_flatten_weight(c10::DispatchKeySet dispatchKeySet) { + return at::_ops::_use_cudnn_rnn_flatten_weight::redispatch(dispatchKeySet); + } + + // aten::_cudnn_rnn_flatten_weight(Tensor[] weight_arr, int weight_stride0, SymInt input_size, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, bool bidirectional) -> Tensor + inline at::Tensor _cudnn_rnn_flatten_weight(c10::DispatchKeySet dispatchKeySet, at::TensorList weight_arr, int64_t weight_stride0, int64_t input_size, int64_t mode, int64_t hidden_size, int64_t proj_size, int64_t num_layers, bool batch_first, bool bidirectional) { + return at::_ops::_cudnn_rnn_flatten_weight::redispatch(dispatchKeySet, weight_arr, weight_stride0, input_size, mode, hidden_size, proj_size, num_layers, batch_first, bidirectional); + } + + // aten::_cudnn_rnn_flatten_weight(Tensor[] weight_arr, int weight_stride0, SymInt input_size, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, bool bidirectional) -> Tensor + inline at::Tensor _cudnn_rnn_flatten_weight_symint(c10::DispatchKeySet dispatchKeySet, at::TensorList weight_arr, int64_t 
weight_stride0, c10::SymInt input_size, int64_t mode, c10::SymInt hidden_size, c10::SymInt proj_size, int64_t num_layers, bool batch_first, bool bidirectional) { + return at::_ops::_cudnn_rnn_flatten_weight::redispatch(dispatchKeySet, weight_arr, weight_stride0, input_size, mode, hidden_size, proj_size, num_layers, batch_first, bidirectional); + } + + // aten::_cudnn_rnn(Tensor input, Tensor[] weight, int weight_stride0, Tensor? weight_buf, Tensor hx, Tensor? cx, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, SymInt[] batch_sizes, Tensor? dropout_state) -> (Tensor, Tensor, Tensor, Tensor, Tensor) + inline ::std::tuple _cudnn_rnn(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::TensorList weight, int64_t weight_stride0, const ::std::optional & weight_buf, const at::Tensor & hx, const ::std::optional & cx, int64_t mode, int64_t hidden_size, int64_t proj_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, at::IntArrayRef batch_sizes, const ::std::optional & dropout_state) { + return at::_ops::_cudnn_rnn::redispatch(dispatchKeySet, input, weight, weight_stride0, weight_buf, hx, cx, mode, hidden_size, proj_size, num_layers, batch_first, dropout, train, bidirectional, c10::fromIntArrayRefSlow(batch_sizes), dropout_state); + } + + // aten::_cudnn_rnn(Tensor input, Tensor[] weight, int weight_stride0, Tensor? weight_buf, Tensor hx, Tensor? cx, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, SymInt[] batch_sizes, Tensor? 
dropout_state) -> (Tensor, Tensor, Tensor, Tensor, Tensor) + inline ::std::tuple _cudnn_rnn_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::TensorList weight, int64_t weight_stride0, const ::std::optional & weight_buf, const at::Tensor & hx, const ::std::optional & cx, int64_t mode, c10::SymInt hidden_size, c10::SymInt proj_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, c10::SymIntArrayRef batch_sizes, const ::std::optional & dropout_state) { + return at::_ops::_cudnn_rnn::redispatch(dispatchKeySet, input, weight, weight_stride0, weight_buf, hx, cx, mode, hidden_size, proj_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state); + } + + // aten::_cudnn_rnn_backward(Tensor input, Tensor[] weight, int weight_stride0, Tensor weight_buf, Tensor hx, Tensor? cx, Tensor output, Tensor? grad_output, Tensor? grad_hy, Tensor? grad_cy, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, SymInt[] batch_sizes, Tensor? 
dropout_state, Tensor reserve, bool[4] output_mask) -> (Tensor, Tensor, Tensor, Tensor[]) + inline ::std::tuple> _cudnn_rnn_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::TensorList weight, int64_t weight_stride0, const at::Tensor & weight_buf, const at::Tensor & hx, const ::std::optional & cx, const at::Tensor & output, const ::std::optional & grad_output, const ::std::optional & grad_hy, const ::std::optional & grad_cy, int64_t mode, int64_t hidden_size, int64_t proj_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, at::IntArrayRef batch_sizes, const ::std::optional & dropout_state, const at::Tensor & reserve, ::std::array output_mask) { + return at::_ops::_cudnn_rnn_backward::redispatch(dispatchKeySet, input, weight, weight_stride0, weight_buf, hx, cx, output, grad_output, grad_hy, grad_cy, mode, hidden_size, proj_size, num_layers, batch_first, dropout, train, bidirectional, c10::fromIntArrayRefSlow(batch_sizes), dropout_state, reserve, output_mask); + } + + // aten::_cudnn_rnn_backward(Tensor input, Tensor[] weight, int weight_stride0, Tensor weight_buf, Tensor hx, Tensor? cx, Tensor output, Tensor? grad_output, Tensor? grad_hy, Tensor? grad_cy, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, SymInt[] batch_sizes, Tensor? 
dropout_state, Tensor reserve, bool[4] output_mask) -> (Tensor, Tensor, Tensor, Tensor[]) + inline ::std::tuple> _cudnn_rnn_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::TensorList weight, int64_t weight_stride0, const at::Tensor & weight_buf, const at::Tensor & hx, const ::std::optional & cx, const at::Tensor & output, const ::std::optional & grad_output, const ::std::optional & grad_hy, const ::std::optional & grad_cy, int64_t mode, c10::SymInt hidden_size, c10::SymInt proj_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, c10::SymIntArrayRef batch_sizes, const ::std::optional & dropout_state, const at::Tensor & reserve, ::std::array output_mask) { + return at::_ops::_cudnn_rnn_backward::redispatch(dispatchKeySet, input, weight, weight_stride0, weight_buf, hx, cx, output, grad_output, grad_hy, grad_cy, mode, hidden_size, proj_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state, reserve, output_mask); + } + + // aten::_cudnn_init_dropout_state(float dropout, bool train, int dropout_seed, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor + inline at::Tensor _cudnn_init_dropout_state(c10::DispatchKeySet dispatchKeySet, double dropout, bool train, int64_t dropout_seed, at::TensorOptions options) { + return at::_ops::_cudnn_init_dropout_state::redispatch(dispatchKeySet, dropout, train, dropout_seed, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::_cudnn_init_dropout_state(float dropout, bool train, int dropout_seed, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=False) -> Tensor + inline at::Tensor _cudnn_init_dropout_state(c10::DispatchKeySet dispatchKeySet, double dropout, bool train, int64_t dropout_seed, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::_cudnn_init_dropout_state::redispatch(dispatchKeySet, dropout, train, dropout_seed, dtype, layout, device, pin_memory); + } + + // aten::_debug_has_internal_overlap(Tensor self) -> int + inline int64_t _debug_has_internal_overlap(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::_debug_has_internal_overlap::redispatch(dispatchKeySet, self); + } + + // aten::_fused_dropout(Tensor self, float p, Generator? generator=None) -> (Tensor, Tensor) + inline ::std::tuple _fused_dropout(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double p, ::std::optional generator=::std::nullopt) { + return at::_ops::_fused_dropout::redispatch(dispatchKeySet, self, p, generator); + } + + // aten::_masked_scale(Tensor self, Tensor mask, float scale) -> Tensor + inline at::Tensor _masked_scale(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mask, double scale) { + return at::_ops::_masked_scale::redispatch(dispatchKeySet, self, mask, scale); + } + + // aten::native_dropout(Tensor input, float p, bool? 
train) -> (Tensor, Tensor) + inline ::std::tuple native_dropout(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, double p, ::std::optional train) { + return at::_ops::native_dropout::redispatch(dispatchKeySet, input, p, train); + } + + // aten::native_dropout_backward(Tensor grad_output, Tensor mask, float scale) -> Tensor + inline at::Tensor native_dropout_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & mask, double scale) { + return at::_ops::native_dropout_backward::redispatch(dispatchKeySet, grad_output, mask, scale); + } + + // aten::_sobol_engine_draw(Tensor quasi, int n, Tensor sobolstate, int dimension, int num_generated, ScalarType? dtype) -> (Tensor, Tensor) + inline ::std::tuple _sobol_engine_draw(c10::DispatchKeySet dispatchKeySet, const at::Tensor & quasi, int64_t n, const at::Tensor & sobolstate, int64_t dimension, int64_t num_generated, ::std::optional dtype) { + return at::_ops::_sobol_engine_draw::redispatch(dispatchKeySet, quasi, n, sobolstate, dimension, num_generated, dtype); + } + + // aten::_sobol_engine_ff_(Tensor(a!) self, int n, Tensor sobolstate, int dimension, int num_generated) -> Tensor(a!) + inline at::Tensor & _sobol_engine_ff_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, int64_t n, const at::Tensor & sobolstate, int64_t dimension, int64_t num_generated) { + return at::_ops::_sobol_engine_ff_::redispatch(dispatchKeySet, self, n, sobolstate, dimension, num_generated); + } + + // aten::_sobol_engine_scramble_(Tensor(a!) self, Tensor ltm, int dimension) -> Tensor(a!) + inline at::Tensor & _sobol_engine_scramble_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & ltm, int64_t dimension) { + return at::_ops::_sobol_engine_scramble_::redispatch(dispatchKeySet, self, ltm, dimension); + } + + // aten::_sobol_engine_initialize_state_(Tensor(a!) self, int dimension) -> Tensor(a!) 
+ inline at::Tensor & _sobol_engine_initialize_state_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, int64_t dimension) { + return at::_ops::_sobol_engine_initialize_state_::redispatch(dispatchKeySet, self, dimension); + } + + // aten::_reshape_from_tensor(Tensor self, Tensor shape) -> Tensor + inline at::Tensor _reshape_from_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & shape) { + return at::_ops::_reshape_from_tensor::redispatch(dispatchKeySet, self, shape); + } + + // aten::_shape_as_tensor(Tensor self) -> Tensor + inline at::Tensor _shape_as_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::_shape_as_tensor::redispatch(dispatchKeySet, self); + } + + // aten::dropout(Tensor input, float p, bool train) -> Tensor + inline at::Tensor dropout(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, double p, bool train) { + return at::_ops::dropout::redispatch(dispatchKeySet, input, p, train); + } + + // aten::dropout_(Tensor(a!) self, float p, bool train) -> Tensor(a!) + inline at::Tensor & dropout_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, double p, bool train) { + return at::_ops::dropout_::redispatch(dispatchKeySet, self, p, train); + } + + // aten::feature_dropout(Tensor input, float p, bool train) -> Tensor + inline at::Tensor feature_dropout(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, double p, bool train) { + return at::_ops::feature_dropout::redispatch(dispatchKeySet, input, p, train); + } + + // aten::feature_dropout_(Tensor(a!) self, float p, bool train) -> Tensor(a!) 
+ inline at::Tensor & feature_dropout_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, double p, bool train) { + return at::_ops::feature_dropout_::redispatch(dispatchKeySet, self, p, train); + } + + // aten::alpha_dropout(Tensor input, float p, bool train) -> Tensor + inline at::Tensor alpha_dropout(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, double p, bool train) { + return at::_ops::alpha_dropout::redispatch(dispatchKeySet, input, p, train); + } + + // aten::alpha_dropout_(Tensor(a!) self, float p, bool train) -> Tensor(a!) + inline at::Tensor & alpha_dropout_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, double p, bool train) { + return at::_ops::alpha_dropout_::redispatch(dispatchKeySet, self, p, train); + } + + // aten::feature_alpha_dropout(Tensor input, float p, bool train) -> Tensor + inline at::Tensor feature_alpha_dropout(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, double p, bool train) { + return at::_ops::feature_alpha_dropout::redispatch(dispatchKeySet, input, p, train); + } + + // aten::feature_alpha_dropout_(Tensor(a!) self, float p, bool train) -> Tensor(a!) + inline at::Tensor & feature_alpha_dropout_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, double p, bool train) { + return at::_ops::feature_alpha_dropout_::redispatch(dispatchKeySet, self, p, train); + } + + // aten::abs(Tensor self) -> Tensor + inline at::Tensor abs(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::abs::redispatch(dispatchKeySet, self); + } + + // aten::abs_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & abs_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::abs_::redispatch(dispatchKeySet, self); + } + + // aten::abs.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & abs_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::abs_out::redispatch(dispatchKeySet, self, out); + } + + // aten::abs.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & abs_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::abs_out::redispatch(dispatchKeySet, self, out); + } + + // aten::absolute(Tensor self) -> Tensor + inline at::Tensor absolute(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::absolute::redispatch(dispatchKeySet, self); + } + + // aten::absolute_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & absolute_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::absolute_::redispatch(dispatchKeySet, self); + } + + // aten::absolute.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & absolute_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::absolute_out::redispatch(dispatchKeySet, self, out); + } + + // aten::absolute.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & absolute_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::absolute_out::redispatch(dispatchKeySet, self, out); + } + + // aten::angle(Tensor self) -> Tensor + inline at::Tensor angle(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::angle::redispatch(dispatchKeySet, self); + } + + // aten::angle.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & angle_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::angle_out::redispatch(dispatchKeySet, self, out); + } + + // aten::angle.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & angle_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::angle_out::redispatch(dispatchKeySet, self, out); + } + + // aten::view_as_real(Tensor(a) self) -> Tensor(a) + inline at::Tensor view_as_real(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::view_as_real::redispatch(dispatchKeySet, self); + } + + // aten::view_as_complex(Tensor(a) self) -> Tensor(a) + inline at::Tensor view_as_complex(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::view_as_complex::redispatch(dispatchKeySet, self); + } + + // aten::sgn(Tensor self) -> Tensor + inline at::Tensor sgn(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::sgn::redispatch(dispatchKeySet, self); + } + + // aten::sgn_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & sgn_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::sgn_::redispatch(dispatchKeySet, self); + } + + // aten::sgn.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & sgn_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::sgn_out::redispatch(dispatchKeySet, self, out); + } + + // aten::sgn.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & sgn_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::sgn_out::redispatch(dispatchKeySet, self, out); + } + + // aten::chalf(Tensor self, *, MemoryFormat? 
memory_format=None) -> Tensor + inline at::Tensor chalf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional memory_format=::std::nullopt) { + return at::_ops::chalf::redispatch(dispatchKeySet, self, memory_format); + } + + // aten::real(Tensor(a) self) -> Tensor(a) + inline at::Tensor real(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::real::redispatch(dispatchKeySet, self); + } + + // aten::imag(Tensor(a) self) -> Tensor(a) + inline at::Tensor imag(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::imag::redispatch(dispatchKeySet, self); + } + + // aten::_conj(Tensor(a) self) -> Tensor(a) + inline at::Tensor _conj(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::_conj::redispatch(dispatchKeySet, self); + } + + // aten::conj(Tensor(a) self) -> Tensor(a) + inline at::Tensor __dispatch_conj(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::conj::redispatch(dispatchKeySet, self); + } + + // aten::_conj_physical(Tensor self) -> Tensor + inline at::Tensor _conj_physical(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::_conj_physical::redispatch(dispatchKeySet, self); + } + + // aten::conj_physical(Tensor self) -> Tensor + inline at::Tensor conj_physical(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::conj_physical::redispatch(dispatchKeySet, self); + } + + // aten::conj_physical.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & conj_physical_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::conj_physical_out::redispatch(dispatchKeySet, self, out); + } + + // aten::conj_physical.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & conj_physical_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::conj_physical_out::redispatch(dispatchKeySet, self, out); + } + + // aten::conj_physical_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & conj_physical_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::conj_physical_::redispatch(dispatchKeySet, self); + } + + // aten::resolve_conj(Tensor(a) self) -> Tensor(a) + inline at::Tensor resolve_conj(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::resolve_conj::redispatch(dispatchKeySet, self); + } + + // aten::resolve_neg(Tensor(a) self) -> Tensor(a) + inline at::Tensor resolve_neg(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::resolve_neg::redispatch(dispatchKeySet, self); + } + + // aten::_neg_view(Tensor(a) self) -> Tensor(a) + inline at::Tensor _neg_view(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::_neg_view::redispatch(dispatchKeySet, self); + } + + // aten::acos(Tensor self) -> Tensor + inline at::Tensor acos(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::acos::redispatch(dispatchKeySet, self); + } + + // aten::acos_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & acos_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::acos_::redispatch(dispatchKeySet, self); + } + + // aten::acos.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & acos_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::acos_out::redispatch(dispatchKeySet, self, out); + } + + // aten::acos.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & acos_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::acos_out::redispatch(dispatchKeySet, self, out); + } + + // aten::arccos(Tensor self) -> Tensor + inline at::Tensor arccos(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::arccos::redispatch(dispatchKeySet, self); + } + + // aten::arccos_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & arccos_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::arccos_::redispatch(dispatchKeySet, self); + } + + // aten::arccos.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & arccos_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::arccos_out::redispatch(dispatchKeySet, self, out); + } + + // aten::arccos.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & arccos_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::arccos_out::redispatch(dispatchKeySet, self, out); + } + + // aten::avg_pool1d(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, bool ceil_mode=False, bool count_include_pad=True) -> Tensor + inline at::Tensor avg_pool1d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, bool ceil_mode=false, bool count_include_pad=true) { + return at::_ops::avg_pool1d::redispatch(dispatchKeySet, self, kernel_size, stride, padding, ceil_mode, count_include_pad); + } + + // aten::adaptive_avg_pool1d(Tensor self, int[1] output_size) -> Tensor + inline at::Tensor adaptive_avg_pool1d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size) { + return at::_ops::adaptive_avg_pool1d::redispatch(dispatchKeySet, self, output_size); + } + + // aten::adaptive_max_pool1d(Tensor self, int[1] output_size) -> (Tensor, 
Tensor) + inline ::std::tuple adaptive_max_pool1d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size) { + return at::_ops::adaptive_max_pool1d::redispatch(dispatchKeySet, self, output_size); + } + + // aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor + inline at::Tensor add(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha=1) { + return at::_ops::add_Tensor::redispatch(dispatchKeySet, self, other, alpha); + } + + // aten::add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!) + inline at::Tensor & add_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha=1) { + return at::_ops::add__Tensor::redispatch(dispatchKeySet, self, other, alpha); + } + + // aten::add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & add_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha=1) { + return at::_ops::add_out::redispatch(dispatchKeySet, self, other, alpha, out); + } + + // aten::add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & add_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha, at::Tensor & out) { + return at::_ops::add_out::redispatch(dispatchKeySet, self, other, alpha, out); + } + + // aten::_add_relu.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor + inline at::Tensor _add_relu(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha=1) { + return at::_ops::_add_relu_Tensor::redispatch(dispatchKeySet, self, other, alpha); + } + + // aten::_add_relu_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!) 
+ inline at::Tensor & _add_relu_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha=1) { + return at::_ops::_add_relu__Tensor::redispatch(dispatchKeySet, self, other, alpha); + } + + // aten::_add_relu.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _add_relu_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha=1) { + return at::_ops::_add_relu_out::redispatch(dispatchKeySet, self, other, alpha, out); + } + + // aten::_add_relu.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _add_relu_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha, at::Tensor & out) { + return at::_ops::_add_relu_out::redispatch(dispatchKeySet, self, other, alpha, out); + } + + // aten::_add_relu.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor + inline at::Tensor _add_relu(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha=1) { + return at::_ops::_add_relu_Scalar::redispatch(dispatchKeySet, self, other, alpha); + } + + // aten::_add_relu_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!) + inline at::Tensor & _add_relu_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha=1) { + return at::_ops::_add_relu__Scalar::redispatch(dispatchKeySet, self, other, alpha); + } + + // aten::add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor + inline at::Tensor add(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha=1) { + return at::_ops::add_Scalar::redispatch(dispatchKeySet, self, other, alpha); + } + + // aten::add_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!) 
+ inline at::Tensor & add_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha=1) { + return at::_ops::add__Scalar::redispatch(dispatchKeySet, self, other, alpha); + } + + // aten::addmv(Tensor self, Tensor mat, Tensor vec, *, Scalar beta=1, Scalar alpha=1) -> Tensor + inline at::Tensor addmv(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat, const at::Tensor & vec, const at::Scalar & beta=1, const at::Scalar & alpha=1) { + return at::_ops::addmv::redispatch(dispatchKeySet, self, mat, vec, beta, alpha); + } + + // aten::addmv_(Tensor(a!) self, Tensor mat, Tensor vec, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!) + inline at::Tensor & addmv_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & mat, const at::Tensor & vec, const at::Scalar & beta=1, const at::Scalar & alpha=1) { + return at::_ops::addmv_::redispatch(dispatchKeySet, self, mat, vec, beta, alpha); + } + + // aten::addmv.out(Tensor self, Tensor mat, Tensor vec, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & addmv_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & mat, const at::Tensor & vec, const at::Scalar & beta=1, const at::Scalar & alpha=1) { + return at::_ops::addmv_out::redispatch(dispatchKeySet, self, mat, vec, beta, alpha, out); + } + + // aten::addmv.out(Tensor self, Tensor mat, Tensor vec, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & addmv_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat, const at::Tensor & vec, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out) { + return at::_ops::addmv_out::redispatch(dispatchKeySet, self, mat, vec, beta, alpha, out); + } + + // aten::addr(Tensor self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1) -> Tensor + inline at::Tensor addr(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & vec1, const at::Tensor & vec2, const at::Scalar & beta=1, const at::Scalar & alpha=1) { + return at::_ops::addr::redispatch(dispatchKeySet, self, vec1, vec2, beta, alpha); + } + + // aten::addr_(Tensor(a!) self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!) + inline at::Tensor & addr_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & vec1, const at::Tensor & vec2, const at::Scalar & beta=1, const at::Scalar & alpha=1) { + return at::_ops::addr_::redispatch(dispatchKeySet, self, vec1, vec2, beta, alpha); + } + + // aten::addr.out(Tensor self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & addr_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & vec1, const at::Tensor & vec2, const at::Scalar & beta=1, const at::Scalar & alpha=1) { + return at::_ops::addr_out::redispatch(dispatchKeySet, self, vec1, vec2, beta, alpha, out); + } + + // aten::addr.out(Tensor self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & addr_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & vec1, const at::Tensor & vec2, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out) { + return at::_ops::addr_out::redispatch(dispatchKeySet, self, vec1, vec2, beta, alpha, out); + } + + // aten::affine_grid_generator(Tensor theta, SymInt[] size, bool align_corners) -> Tensor + inline at::Tensor affine_grid_generator(c10::DispatchKeySet dispatchKeySet, const at::Tensor & theta, at::IntArrayRef size, bool align_corners) { + return at::_ops::affine_grid_generator::redispatch(dispatchKeySet, theta, c10::fromIntArrayRefSlow(size), align_corners); + } + + // aten::affine_grid_generator(Tensor theta, SymInt[] size, bool align_corners) -> Tensor + inline at::Tensor affine_grid_generator_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & theta, c10::SymIntArrayRef size, bool align_corners) { + return at::_ops::affine_grid_generator::redispatch(dispatchKeySet, theta, size, align_corners); + } + + // aten::affine_grid_generator_backward(Tensor grad, SymInt[] size, bool align_corners) -> Tensor + inline at::Tensor affine_grid_generator_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, at::IntArrayRef size, bool align_corners) { + return at::_ops::affine_grid_generator_backward::redispatch(dispatchKeySet, grad, c10::fromIntArrayRefSlow(size), align_corners); + } + + // aten::affine_grid_generator_backward(Tensor grad, SymInt[] size, bool align_corners) -> Tensor + inline at::Tensor affine_grid_generator_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, c10::SymIntArrayRef size, bool align_corners) { + return at::_ops::affine_grid_generator_backward::redispatch(dispatchKeySet, grad, size, align_corners); + } + + // aten::_is_all_true(Tensor self) -> Tensor + inline at::Tensor _is_all_true(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return 
at::_ops::_is_all_true::redispatch(dispatchKeySet, self); + } + + // aten::_is_any_true(Tensor self) -> Tensor + inline at::Tensor _is_any_true(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::_is_any_true::redispatch(dispatchKeySet, self); + } + + // aten::_test_check_tensor(Tensor self) -> Tensor + inline at::Tensor _test_check_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::_test_check_tensor::redispatch(dispatchKeySet, self); + } + + // aten::_test_functorch_fallback(Tensor self, Tensor other) -> Tensor + inline at::Tensor _test_functorch_fallback(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::_test_functorch_fallback::redispatch(dispatchKeySet, self, other); + } + + // aten::all.dim(Tensor self, int dim, bool keepdim=False) -> Tensor + inline at::Tensor all(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool keepdim=false) { + return at::_ops::all_dim::redispatch(dispatchKeySet, self, dim, keepdim); + } + + // aten::all.dims(Tensor self, int[]? dim=None, bool keepdim=False) -> Tensor + inline at::Tensor all(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim=false) { + return at::_ops::all_dims::redispatch(dispatchKeySet, self, dim, keepdim); + } + + // aten::all.out(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & all_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim, bool keepdim=false) { + return at::_ops::all_out::redispatch(dispatchKeySet, self, dim, keepdim, out); + } + + // aten::all.out(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & all_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool keepdim, at::Tensor & out) { + return at::_ops::all_out::redispatch(dispatchKeySet, self, dim, keepdim, out); + } + + // aten::all.dims_out(Tensor self, int[]? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & all_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim=false) { + return at::_ops::all_dims_out::redispatch(dispatchKeySet, self, dim, keepdim, out); + } + + // aten::all.dims_out(Tensor self, int[]? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & all_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim, at::Tensor & out) { + return at::_ops::all_dims_out::redispatch(dispatchKeySet, self, dim, keepdim, out); + } + + // aten::all.dimname(Tensor self, Dimname dim, bool keepdim=False) -> Tensor + inline at::Tensor all(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, bool keepdim=false) { + return at::_ops::all_dimname::redispatch(dispatchKeySet, self, dim, keepdim); + } + + // aten::all.dimname_out(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & all_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::Dimname dim, bool keepdim=false) { + return at::_ops::all_dimname_out::redispatch(dispatchKeySet, self, dim, keepdim, out); + } + + // aten::all.dimname_out(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & all_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, bool keepdim, at::Tensor & out) { + return at::_ops::all_dimname_out::redispatch(dispatchKeySet, self, dim, keepdim, out); + } + + // aten::allclose(Tensor self, Tensor other, float rtol=1e-05, float atol=1e-08, bool equal_nan=False) -> bool + inline bool allclose(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, double rtol=1e-05, double atol=1e-08, bool equal_nan=false) { + return at::_ops::allclose::redispatch(dispatchKeySet, self, other, rtol, atol, equal_nan); + } + + // aten::any.dim(Tensor self, int dim, bool keepdim=False) -> Tensor + inline at::Tensor any(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool keepdim=false) { + return at::_ops::any_dim::redispatch(dispatchKeySet, self, dim, keepdim); + } + + // aten::any.dims(Tensor self, int[]? dim=None, bool keepdim=False) -> Tensor + inline at::Tensor any(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim=false) { + return at::_ops::any_dims::redispatch(dispatchKeySet, self, dim, keepdim); + } + + // aten::any.out(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & any_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim, bool keepdim=false) { + return at::_ops::any_out::redispatch(dispatchKeySet, self, dim, keepdim, out); + } + + // aten::any.out(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & any_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool keepdim, at::Tensor & out) { + return at::_ops::any_out::redispatch(dispatchKeySet, self, dim, keepdim, out); + } + + // aten::any.dims_out(Tensor self, int[]? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & any_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim=false) { + return at::_ops::any_dims_out::redispatch(dispatchKeySet, self, dim, keepdim, out); + } + + // aten::any.dims_out(Tensor self, int[]? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & any_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim, at::Tensor & out) { + return at::_ops::any_dims_out::redispatch(dispatchKeySet, self, dim, keepdim, out); + } + + // aten::any.dimname(Tensor self, Dimname dim, bool keepdim=False) -> Tensor + inline at::Tensor any(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, bool keepdim=false) { + return at::_ops::any_dimname::redispatch(dispatchKeySet, self, dim, keepdim); + } + + // aten::any.dimname_out(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & any_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::Dimname dim, bool keepdim=false) { + return at::_ops::any_dimname_out::redispatch(dispatchKeySet, self, dim, keepdim, out); + } + + // aten::any.dimname_out(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & any_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, bool keepdim, at::Tensor & out) { + return at::_ops::any_dimname_out::redispatch(dispatchKeySet, self, dim, keepdim, out); + } + + // aten::arange(Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor + inline at::Tensor arange(c10::DispatchKeySet dispatchKeySet, const at::Scalar & end, at::TensorOptions options={}) { + return at::_ops::arange::redispatch(dispatchKeySet, end, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::arange(Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor arange(c10::DispatchKeySet dispatchKeySet, const at::Scalar & end, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::arange::redispatch(dispatchKeySet, end, dtype, layout, device, pin_memory); + } + + // aten::arange.start(Scalar start, Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor arange(c10::DispatchKeySet dispatchKeySet, const at::Scalar & start, const at::Scalar & end, at::TensorOptions options={}) { + return at::_ops::arange_start::redispatch(dispatchKeySet, start, end, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::arange.start(Scalar start, Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor arange(c10::DispatchKeySet dispatchKeySet, const at::Scalar & start, const at::Scalar & end, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::arange_start::redispatch(dispatchKeySet, start, end, dtype, layout, device, pin_memory); + } + + // aten::arange.start_step(Scalar start, Scalar end, Scalar step=1, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor + inline at::Tensor arange(c10::DispatchKeySet dispatchKeySet, const at::Scalar & start, const at::Scalar & end, const at::Scalar & step, at::TensorOptions options={}) { + return at::_ops::arange_start_step::redispatch(dispatchKeySet, start, end, step, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::arange.start_step(Scalar start, Scalar end, Scalar step=1, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor arange(c10::DispatchKeySet dispatchKeySet, const at::Scalar & start, const at::Scalar & end, const at::Scalar & step, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::arange_start_step::redispatch(dispatchKeySet, start, end, step, dtype, layout, device, pin_memory); + } + + // aten::arange.out(Scalar end, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & arange_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & end) { + return at::_ops::arange_out::redispatch(dispatchKeySet, end, out); + } + + // aten::arange.out(Scalar end, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & arange_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & end, at::Tensor & out) { + return at::_ops::arange_out::redispatch(dispatchKeySet, end, out); + } + + // aten::arange.start_out(Scalar start, Scalar end, Scalar step=1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & arange_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & start, const at::Scalar & end, const at::Scalar & step) { + return at::_ops::arange_start_out::redispatch(dispatchKeySet, start, end, step, out); + } + + // aten::arange.start_out(Scalar start, Scalar end, Scalar step=1, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & arange_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & start, const at::Scalar & end, const at::Scalar & step, at::Tensor & out) { + return at::_ops::arange_start_out::redispatch(dispatchKeySet, start, end, step, out); + } + + // aten::_dim_arange(Tensor like, int dim) -> Tensor + inline at::Tensor _dim_arange(c10::DispatchKeySet dispatchKeySet, const at::Tensor & like, int64_t dim) { + return at::_ops::_dim_arange::redispatch(dispatchKeySet, like, dim); + } + + // aten::argmax(Tensor self, int? dim=None, bool keepdim=False) -> Tensor + inline at::Tensor argmax(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional dim=::std::nullopt, bool keepdim=false) { + return at::_ops::argmax::redispatch(dispatchKeySet, self, dim, keepdim); + } + + // aten::argmax.out(Tensor self, int? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & argmax_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, ::std::optional dim=::std::nullopt, bool keepdim=false) { + return at::_ops::argmax_out::redispatch(dispatchKeySet, self, dim, keepdim, out); + } + + // aten::argmax.out(Tensor self, int? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & argmax_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional dim, bool keepdim, at::Tensor & out) { + return at::_ops::argmax_out::redispatch(dispatchKeySet, self, dim, keepdim, out); + } + + // aten::argmin(Tensor self, int? dim=None, bool keepdim=False) -> Tensor + inline at::Tensor argmin(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional dim=::std::nullopt, bool keepdim=false) { + return at::_ops::argmin::redispatch(dispatchKeySet, self, dim, keepdim); + } + + // aten::argmin.out(Tensor self, int? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & argmin_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, ::std::optional dim=::std::nullopt, bool keepdim=false) { + return at::_ops::argmin_out::redispatch(dispatchKeySet, self, dim, keepdim, out); + } + + // aten::argmin.out(Tensor self, int? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & argmin_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional dim, bool keepdim, at::Tensor & out) { + return at::_ops::argmin_out::redispatch(dispatchKeySet, self, dim, keepdim, out); + } + + // aten::acosh(Tensor self) -> Tensor + inline at::Tensor acosh(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::acosh::redispatch(dispatchKeySet, self); + } + + // aten::acosh_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & acosh_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::acosh_::redispatch(dispatchKeySet, self); + } + + // aten::acosh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & acosh_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::acosh_out::redispatch(dispatchKeySet, self, out); + } + + // aten::acosh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & acosh_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::acosh_out::redispatch(dispatchKeySet, self, out); + } + + // aten::arccosh(Tensor self) -> Tensor + inline at::Tensor arccosh(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::arccosh::redispatch(dispatchKeySet, self); + } + + // aten::arccosh_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & arccosh_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::arccosh_::redispatch(dispatchKeySet, self); + } + + // aten::arccosh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & arccosh_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::arccosh_out::redispatch(dispatchKeySet, self, out); + } + + // aten::arccosh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & arccosh_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::arccosh_out::redispatch(dispatchKeySet, self, out); + } + + // aten::asinh(Tensor self) -> Tensor + inline at::Tensor asinh(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::asinh::redispatch(dispatchKeySet, self); + } + + // aten::asinh_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & asinh_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::asinh_::redispatch(dispatchKeySet, self); + } + + // aten::asinh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & asinh_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::asinh_out::redispatch(dispatchKeySet, self, out); + } + + // aten::asinh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & asinh_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::asinh_out::redispatch(dispatchKeySet, self, out); + } + + // aten::arcsinh(Tensor self) -> Tensor + inline at::Tensor arcsinh(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::arcsinh::redispatch(dispatchKeySet, self); + } + + // aten::arcsinh_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & arcsinh_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::arcsinh_::redispatch(dispatchKeySet, self); + } + + // aten::arcsinh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & arcsinh_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::arcsinh_out::redispatch(dispatchKeySet, self, out); + } + + // aten::arcsinh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & arcsinh_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::arcsinh_out::redispatch(dispatchKeySet, self, out); + } + + // aten::atanh(Tensor self) -> Tensor + inline at::Tensor atanh(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::atanh::redispatch(dispatchKeySet, self); + } + + // aten::atanh_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & atanh_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::atanh_::redispatch(dispatchKeySet, self); + } + + // aten::atanh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & atanh_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::atanh_out::redispatch(dispatchKeySet, self, out); + } + + // aten::atanh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & atanh_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::atanh_out::redispatch(dispatchKeySet, self, out); + } + + // aten::arctanh(Tensor self) -> Tensor + inline at::Tensor arctanh(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::arctanh::redispatch(dispatchKeySet, self); + } + + // aten::arctanh_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & arctanh_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::arctanh_::redispatch(dispatchKeySet, self); + } + + // aten::arctanh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & arctanh_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::arctanh_out::redispatch(dispatchKeySet, self, out); + } + + // aten::arctanh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & arctanh_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::arctanh_out::redispatch(dispatchKeySet, self, out); + } + + // aten::as_strided(Tensor(a) self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor(a) + inline at::Tensor as_strided(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, at::IntArrayRef stride, ::std::optional storage_offset=::std::nullopt) { + return at::_ops::as_strided::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), storage_offset.has_value() ? ::std::make_optional(c10::SymInt(*storage_offset)) : ::std::nullopt); + } + + // aten::as_strided(Tensor(a) self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor(a) + inline at::Tensor as_strided_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset=::std::nullopt) { + return at::_ops::as_strided::redispatch(dispatchKeySet, self, size, stride, storage_offset); + } + + // aten::as_strided_(Tensor(a!) self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor(a!) + inline const at::Tensor & as_strided_(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, at::IntArrayRef stride, ::std::optional storage_offset=::std::nullopt) { + return at::_ops::as_strided_::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), storage_offset.has_value() ? ::std::make_optional(c10::SymInt(*storage_offset)) : ::std::nullopt); + } + + // aten::as_strided_(Tensor(a!) 
self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor(a!) + inline const at::Tensor & as_strided__symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset=::std::nullopt) { + return at::_ops::as_strided_::redispatch(dispatchKeySet, self, size, stride, storage_offset); + } + + // aten::asin(Tensor self) -> Tensor + inline at::Tensor asin(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::asin::redispatch(dispatchKeySet, self); + } + + // aten::asin_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & asin_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::asin_::redispatch(dispatchKeySet, self); + } + + // aten::asin.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & asin_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::asin_out::redispatch(dispatchKeySet, self, out); + } + + // aten::asin.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & asin_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::asin_out::redispatch(dispatchKeySet, self, out); + } + + // aten::arcsin(Tensor self) -> Tensor + inline at::Tensor arcsin(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::arcsin::redispatch(dispatchKeySet, self); + } + + // aten::arcsin_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & arcsin_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::arcsin_::redispatch(dispatchKeySet, self); + } + + // aten::arcsin.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & arcsin_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::arcsin_out::redispatch(dispatchKeySet, self, out); + } + + // aten::arcsin.out(Tensor self, *, Tensor(a!) 
out) -> Tensor(a!) + inline at::Tensor & arcsin_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::arcsin_out::redispatch(dispatchKeySet, self, out); + } + + // aten::atan(Tensor self) -> Tensor + inline at::Tensor atan(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::atan::redispatch(dispatchKeySet, self); + } + + // aten::atan_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & atan_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::atan_::redispatch(dispatchKeySet, self); + } + + // aten::atan.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & atan_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::atan_out::redispatch(dispatchKeySet, self, out); + } + + // aten::atan.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & atan_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::atan_out::redispatch(dispatchKeySet, self, out); + } + + // aten::arctan(Tensor self) -> Tensor + inline at::Tensor arctan(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::arctan::redispatch(dispatchKeySet, self); + } + + // aten::arctan_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & arctan_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::arctan_::redispatch(dispatchKeySet, self); + } + + // aten::arctan.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & arctan_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::arctan_out::redispatch(dispatchKeySet, self, out); + } + + // aten::arctan.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & arctan_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::arctan_out::redispatch(dispatchKeySet, self, out); + } + + // aten::atleast_1d(Tensor self) -> Tensor + inline at::Tensor atleast_1d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::atleast_1d::redispatch(dispatchKeySet, self); + } + + // aten::atleast_1d.Sequence(Tensor[] tensors) -> Tensor[] + inline ::std::vector atleast_1d(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors) { + return at::_ops::atleast_1d_Sequence::redispatch(dispatchKeySet, tensors); + } + + // aten::atleast_2d(Tensor self) -> Tensor + inline at::Tensor atleast_2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::atleast_2d::redispatch(dispatchKeySet, self); + } + + // aten::atleast_2d.Sequence(Tensor[] tensors) -> Tensor[] + inline ::std::vector atleast_2d(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors) { + return at::_ops::atleast_2d_Sequence::redispatch(dispatchKeySet, tensors); + } + + // aten::atleast_3d(Tensor self) -> Tensor + inline at::Tensor atleast_3d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::atleast_3d::redispatch(dispatchKeySet, self); + } + + // aten::atleast_3d.Sequence(Tensor[] tensors) -> Tensor[] + inline ::std::vector atleast_3d(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors) { + return at::_ops::atleast_3d_Sequence::redispatch(dispatchKeySet, tensors); + } + + // aten::baddbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor + inline at::Tensor baddbmm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta=1, const at::Scalar & alpha=1) { + return at::_ops::baddbmm::redispatch(dispatchKeySet, self, batch1, batch2, beta, alpha); + } + + // aten::baddbmm_(Tensor(a!) 
self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!) + inline at::Tensor & baddbmm_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta=1, const at::Scalar & alpha=1) { + return at::_ops::baddbmm_::redispatch(dispatchKeySet, self, batch1, batch2, beta, alpha); + } + + // aten::baddbmm.out(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & baddbmm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta=1, const at::Scalar & alpha=1) { + return at::_ops::baddbmm_out::redispatch(dispatchKeySet, self, batch1, batch2, beta, alpha, out); + } + + // aten::baddbmm.out(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & baddbmm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out) { + return at::_ops::baddbmm_out::redispatch(dispatchKeySet, self, batch1, batch2, beta, alpha, out); + } + + // aten::baddbmm.dtype(Tensor self, Tensor batch1, Tensor batch2, ScalarType out_dtype, *, Scalar beta=1, Scalar alpha=1) -> Tensor + inline at::Tensor baddbmm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, at::ScalarType out_dtype, const at::Scalar & beta=1, const at::Scalar & alpha=1) { + return at::_ops::baddbmm_dtype::redispatch(dispatchKeySet, self, batch1, batch2, out_dtype, beta, alpha); + } + + // aten::baddbmm.dtype_out(Tensor self, Tensor batch1, Tensor batch2, ScalarType out_dtype, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & baddbmm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, at::ScalarType out_dtype, const at::Scalar & beta=1, const at::Scalar & alpha=1) { + return at::_ops::baddbmm_dtype_out::redispatch(dispatchKeySet, self, batch1, batch2, out_dtype, beta, alpha, out); + } + + // aten::baddbmm.dtype_out(Tensor self, Tensor batch1, Tensor batch2, ScalarType out_dtype, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & baddbmm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, at::ScalarType out_dtype, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out) { + return at::_ops::baddbmm_dtype_out::redispatch(dispatchKeySet, self, batch1, batch2, out_dtype, beta, alpha, out); + } + + // aten::bartlett_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor bartlett_window(c10::DispatchKeySet dispatchKeySet, int64_t window_length, at::TensorOptions options={}) { + return at::_ops::bartlett_window::redispatch(dispatchKeySet, window_length, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::bartlett_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor bartlett_window(c10::DispatchKeySet dispatchKeySet, int64_t window_length, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::bartlett_window::redispatch(dispatchKeySet, window_length, dtype, layout, device, pin_memory); + } + + // aten::bartlett_window.periodic(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor + inline at::Tensor bartlett_window(c10::DispatchKeySet dispatchKeySet, int64_t window_length, bool periodic, at::TensorOptions options={}) { + return at::_ops::bartlett_window_periodic::redispatch(dispatchKeySet, window_length, periodic, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::bartlett_window.periodic(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor bartlett_window(c10::DispatchKeySet dispatchKeySet, int64_t window_length, bool periodic, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::bartlett_window_periodic::redispatch(dispatchKeySet, window_length, periodic, dtype, layout, device, pin_memory); + } + + // aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor + inline at::Tensor batch_norm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const ::std::optional & running_mean, const ::std::optional & running_var, bool training, double momentum, double eps, bool cudnn_enabled) { + return at::_ops::batch_norm::redispatch(dispatchKeySet, input, weight, bias, running_mean, running_var, training, momentum, eps, cudnn_enabled); + } + + // aten::quantized_batch_norm(Tensor input, Tensor? weight, Tensor? 
bias, Tensor mean, Tensor var, float eps, float output_scale, int output_zero_point) -> Tensor + inline at::Tensor quantized_batch_norm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const at::Tensor & mean, const at::Tensor & var, double eps, double output_scale, int64_t output_zero_point) { + return at::_ops::quantized_batch_norm::redispatch(dispatchKeySet, input, weight, bias, mean, var, eps, output_scale, output_zero_point); + } + + // aten::_batch_norm_impl_index(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> (Tensor, Tensor, Tensor, Tensor, int) + inline ::std::tuple _batch_norm_impl_index(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const ::std::optional & running_mean, const ::std::optional & running_var, bool training, double momentum, double eps, bool cudnn_enabled) { + return at::_ops::_batch_norm_impl_index::redispatch(dispatchKeySet, input, weight, bias, running_mean, running_var, training, momentum, eps, cudnn_enabled); + } + + // aten::_batch_norm_impl_index_backward(int impl_index, Tensor input, Tensor grad_output, Tensor? weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? 
save_var_transform, bool train, float eps, bool[3] output_mask, Tensor reservedSpace) -> (Tensor, Tensor, Tensor) + inline ::std::tuple _batch_norm_impl_index_backward(c10::DispatchKeySet dispatchKeySet, int64_t impl_index, const at::Tensor & input, const at::Tensor & grad_output, const ::std::optional & weight, const ::std::optional & running_mean, const ::std::optional & running_var, const ::std::optional & save_mean, const ::std::optional & save_var_transform, bool train, double eps, ::std::array output_mask, const at::Tensor & reservedSpace) { + return at::_ops::_batch_norm_impl_index_backward::redispatch(dispatchKeySet, impl_index, input, grad_output, weight, running_mean, running_var, save_mean, save_var_transform, train, eps, output_mask, reservedSpace); + } + + // aten::bernoulli(Tensor self, *, Generator? generator=None) -> Tensor + inline at::Tensor bernoulli(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional generator=::std::nullopt) { + return at::_ops::bernoulli::redispatch(dispatchKeySet, self, generator); + } + + // aten::bernoulli.out(Tensor self, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bernoulli_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, ::std::optional generator=::std::nullopt) { + return at::_ops::bernoulli_out::redispatch(dispatchKeySet, self, generator, out); + } + + // aten::bernoulli.out(Tensor self, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bernoulli_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional generator, at::Tensor & out) { + return at::_ops::bernoulli_out::redispatch(dispatchKeySet, self, generator, out); + } + + // aten::bernoulli_.Tensor(Tensor(a!) self, Tensor p, *, Generator? generator=None) -> Tensor(a!) 
+ inline at::Tensor & bernoulli_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & p, ::std::optional generator=::std::nullopt) { + return at::_ops::bernoulli__Tensor::redispatch(dispatchKeySet, self, p, generator); + } + + // aten::bernoulli_.float(Tensor(a!) self, float p=0.5, *, Generator? generator=None) -> Tensor(a!) + inline at::Tensor & bernoulli_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, double p=0.5, ::std::optional generator=::std::nullopt) { + return at::_ops::bernoulli__float::redispatch(dispatchKeySet, self, p, generator); + } + + // aten::bernoulli.p(Tensor self, float p, *, Generator? generator=None) -> Tensor + inline at::Tensor bernoulli(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double p, ::std::optional generator=::std::nullopt) { + return at::_ops::bernoulli_p::redispatch(dispatchKeySet, self, p, generator); + } + + // aten::bilinear(Tensor input1, Tensor input2, Tensor weight, Tensor? bias=None) -> Tensor + inline at::Tensor bilinear(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input1, const at::Tensor & input2, const at::Tensor & weight, const ::std::optional & bias={}) { + return at::_ops::bilinear::redispatch(dispatchKeySet, input1, input2, weight, bias); + } + + // aten::binary_cross_entropy(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean) -> Tensor + inline at::Tensor binary_cross_entropy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight={}, int64_t reduction=at::Reduction::Mean) { + return at::_ops::binary_cross_entropy::redispatch(dispatchKeySet, self, target, weight, reduction); + } + + // aten::binary_cross_entropy.out(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & binary_cross_entropy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight={}, int64_t reduction=at::Reduction::Mean) { + return at::_ops::binary_cross_entropy_out::redispatch(dispatchKeySet, self, target, weight, reduction, out); + } + + // aten::binary_cross_entropy.out(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & binary_cross_entropy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, at::Tensor & out) { + return at::_ops::binary_cross_entropy_out::redispatch(dispatchKeySet, self, target, weight, reduction, out); + } + + // aten::binary_cross_entropy_backward(Tensor grad_output, Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean) -> Tensor + inline at::Tensor binary_cross_entropy_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight={}, int64_t reduction=at::Reduction::Mean) { + return at::_ops::binary_cross_entropy_backward::redispatch(dispatchKeySet, grad_output, self, target, weight, reduction); + } + + // aten::binary_cross_entropy_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, *, Tensor(a!) grad_input) -> Tensor(a!) 
+ inline at::Tensor & binary_cross_entropy_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight={}, int64_t reduction=at::Reduction::Mean) { + return at::_ops::binary_cross_entropy_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, target, weight, reduction, grad_input); + } + + // aten::binary_cross_entropy_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & binary_cross_entropy_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, at::Tensor & grad_input) { + return at::_ops::binary_cross_entropy_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, target, weight, reduction, grad_input); + } + + // aten::binary_cross_entropy_with_logits(Tensor self, Tensor target, Tensor? weight=None, Tensor? pos_weight=None, int reduction=Mean) -> Tensor + inline at::Tensor binary_cross_entropy_with_logits(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight={}, const ::std::optional & pos_weight={}, int64_t reduction=at::Reduction::Mean) { + return at::_ops::binary_cross_entropy_with_logits::redispatch(dispatchKeySet, self, target, weight, pos_weight, reduction); + } + + // aten::bincount(Tensor self, Tensor? weights=None, SymInt minlength=0) -> Tensor + inline at::Tensor bincount(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const ::std::optional & weights={}, int64_t minlength=0) { + return at::_ops::bincount::redispatch(dispatchKeySet, self, weights, minlength); + } + + // aten::bincount(Tensor self, Tensor? 
weights=None, SymInt minlength=0) -> Tensor + inline at::Tensor bincount_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const ::std::optional & weights={}, c10::SymInt minlength=0) { + return at::_ops::bincount::redispatch(dispatchKeySet, self, weights, minlength); + } + + // aten::bitwise_not(Tensor self) -> Tensor + inline at::Tensor bitwise_not(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::bitwise_not::redispatch(dispatchKeySet, self); + } + + // aten::bitwise_not_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & bitwise_not_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::bitwise_not_::redispatch(dispatchKeySet, self); + } + + // aten::bitwise_not.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bitwise_not_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::bitwise_not_out::redispatch(dispatchKeySet, self, out); + } + + // aten::bitwise_not.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bitwise_not_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::bitwise_not_out::redispatch(dispatchKeySet, self, out); + } + + // aten::copysign.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & copysign_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::copysign_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::copysign.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & copysign_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::copysign_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::copysign.Tensor(Tensor self, Tensor other) -> Tensor + inline at::Tensor copysign(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::copysign_Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::copysign_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + inline at::Tensor & copysign_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) { + return at::_ops::copysign__Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::copysign.Scalar(Tensor self, Scalar other) -> Tensor + inline at::Tensor copysign(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::copysign_Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::copysign_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + inline at::Tensor & copysign_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) { + return at::_ops::copysign__Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::copysign.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & copysign_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::copysign_Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::copysign.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & copysign_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, at::Tensor & out) { + return at::_ops::copysign_Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::_lazy_clone(Tensor self) -> Tensor + inline at::Tensor _lazy_clone(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::_lazy_clone::redispatch(dispatchKeySet, self); + } + + // aten::logical_not(Tensor self) -> Tensor + inline at::Tensor logical_not(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::logical_not::redispatch(dispatchKeySet, self); + } + + // aten::logical_not_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & logical_not_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::logical_not_::redispatch(dispatchKeySet, self); + } + + // aten::logical_not.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & logical_not_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::logical_not_out::redispatch(dispatchKeySet, self, out); + } + + // aten::logical_not.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & logical_not_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::logical_not_out::redispatch(dispatchKeySet, self, out); + } + + // aten::logical_xor(Tensor self, Tensor other) -> Tensor + inline at::Tensor logical_xor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::logical_xor::redispatch(dispatchKeySet, self, other); + } + + // aten::logical_xor_(Tensor(a!) self, Tensor other) -> Tensor(a!) 
+ inline at::Tensor & logical_xor_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) { + return at::_ops::logical_xor_::redispatch(dispatchKeySet, self, other); + } + + // aten::logical_xor.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & logical_xor_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::logical_xor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::logical_xor.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & logical_xor_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::logical_xor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::logical_and(Tensor self, Tensor other) -> Tensor + inline at::Tensor logical_and(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::logical_and::redispatch(dispatchKeySet, self, other); + } + + // aten::logical_and_(Tensor(a!) self, Tensor other) -> Tensor(a!) + inline at::Tensor & logical_and_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) { + return at::_ops::logical_and_::redispatch(dispatchKeySet, self, other); + } + + // aten::logical_and.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & logical_and_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::logical_and_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::logical_and.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & logical_and_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::logical_and_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::logical_or(Tensor self, Tensor other) -> Tensor + inline at::Tensor logical_or(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::logical_or::redispatch(dispatchKeySet, self, other); + } + + // aten::logical_or_(Tensor(a!) self, Tensor other) -> Tensor(a!) + inline at::Tensor & logical_or_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) { + return at::_ops::logical_or_::redispatch(dispatchKeySet, self, other); + } + + // aten::logical_or.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & logical_or_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::logical_or_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::logical_or.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & logical_or_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::logical_or_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::blackman_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor blackman_window(c10::DispatchKeySet dispatchKeySet, int64_t window_length, at::TensorOptions options={}) { + return at::_ops::blackman_window::redispatch(dispatchKeySet, window_length, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::blackman_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor + inline at::Tensor blackman_window(c10::DispatchKeySet dispatchKeySet, int64_t window_length, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::blackman_window::redispatch(dispatchKeySet, window_length, dtype, layout, device, pin_memory); + } + + // aten::blackman_window.periodic(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor blackman_window(c10::DispatchKeySet dispatchKeySet, int64_t window_length, bool periodic, at::TensorOptions options={}) { + return at::_ops::blackman_window_periodic::redispatch(dispatchKeySet, window_length, periodic, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::blackman_window.periodic(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor blackman_window(c10::DispatchKeySet dispatchKeySet, int64_t window_length, bool periodic, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::blackman_window_periodic::redispatch(dispatchKeySet, window_length, periodic, dtype, layout, device, pin_memory); + } + + // aten::bmm(Tensor self, Tensor mat2) -> Tensor + inline at::Tensor bmm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat2) { + return at::_ops::bmm::redispatch(dispatchKeySet, self, mat2); + } + + // aten::bmm.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bmm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & mat2) { + return at::_ops::bmm_out::redispatch(dispatchKeySet, self, mat2, out); + } + + // aten::bmm.out(Tensor self, Tensor mat2, *, Tensor(a!) 
out) -> Tensor(a!) + inline at::Tensor & bmm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat2, at::Tensor & out) { + return at::_ops::bmm_out::redispatch(dispatchKeySet, self, mat2, out); + } + + // aten::bmm.dtype(Tensor self, Tensor mat2, ScalarType out_dtype) -> Tensor + inline at::Tensor bmm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat2, at::ScalarType out_dtype) { + return at::_ops::bmm_dtype::redispatch(dispatchKeySet, self, mat2, out_dtype); + } + + // aten::bmm.dtype_out(Tensor self, Tensor mat2, ScalarType out_dtype, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bmm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & mat2, at::ScalarType out_dtype) { + return at::_ops::bmm_dtype_out::redispatch(dispatchKeySet, self, mat2, out_dtype, out); + } + + // aten::bmm.dtype_out(Tensor self, Tensor mat2, ScalarType out_dtype, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & bmm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat2, at::ScalarType out_dtype, at::Tensor & out) { + return at::_ops::bmm_dtype_out::redispatch(dispatchKeySet, self, mat2, out_dtype, out); + } + + // aten::broadcast_tensors(Tensor[] tensors) -> Tensor[] + inline ::std::vector broadcast_tensors(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors) { + return at::_ops::broadcast_tensors::redispatch(dispatchKeySet, tensors); + } + + // aten::broadcast_to(Tensor(a) self, SymInt[] size) -> Tensor(a) + inline at::Tensor broadcast_to(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size) { + return at::_ops::broadcast_to::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size)); + } + + // aten::broadcast_to(Tensor(a) self, SymInt[] size) -> Tensor(a) + inline at::Tensor broadcast_to_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size) { + return at::_ops::broadcast_to::redispatch(dispatchKeySet, self, size); + } + + // aten::_sparse_broadcast_to(Tensor(a) self, int[] size) -> Tensor(a) + inline at::Tensor _sparse_broadcast_to(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size) { + return at::_ops::_sparse_broadcast_to::redispatch(dispatchKeySet, self, size); + } + + // aten::cat(Tensor[] tensors, int dim=0) -> Tensor + inline at::Tensor cat(c10::DispatchKeySet dispatchKeySet, const at::ITensorListRef & tensors, int64_t dim=0) { + return at::_ops::cat::redispatch(dispatchKeySet, tensors, dim); + } + + // aten::cat.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & cat_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::ITensorListRef & tensors, int64_t dim=0) { + return at::_ops::cat_out::redispatch(dispatchKeySet, tensors, dim, out); + } + + // aten::cat.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & cat_outf(c10::DispatchKeySet dispatchKeySet, const at::ITensorListRef & tensors, int64_t dim, at::Tensor & out) { + return at::_ops::cat_out::redispatch(dispatchKeySet, tensors, dim, out); + } + + // aten::cat.names(Tensor[] tensors, Dimname dim) -> Tensor + inline at::Tensor cat(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors, at::Dimname dim) { + return at::_ops::cat_names::redispatch(dispatchKeySet, tensors, dim); + } + + // aten::cat.names_out(Tensor[] tensors, Dimname dim, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & cat_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::TensorList tensors, at::Dimname dim) { + return at::_ops::cat_names_out::redispatch(dispatchKeySet, tensors, dim, out); + } + + // aten::cat.names_out(Tensor[] tensors, Dimname dim, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & cat_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors, at::Dimname dim, at::Tensor & out) { + return at::_ops::cat_names_out::redispatch(dispatchKeySet, tensors, dim, out); + } + + // aten::concat(Tensor[] tensors, int dim=0) -> Tensor + inline at::Tensor concat(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors, int64_t dim=0) { + return at::_ops::concat::redispatch(dispatchKeySet, tensors, dim); + } + + // aten::concat.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & concat_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::TensorList tensors, int64_t dim=0) { + return at::_ops::concat_out::redispatch(dispatchKeySet, tensors, dim, out); + } + + // aten::concat.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & concat_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors, int64_t dim, at::Tensor & out) { + return at::_ops::concat_out::redispatch(dispatchKeySet, tensors, dim, out); + } + + // aten::concat.names(Tensor[] tensors, Dimname dim) -> Tensor + inline at::Tensor concat(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors, at::Dimname dim) { + return at::_ops::concat_names::redispatch(dispatchKeySet, tensors, dim); + } + + // aten::concat.names_out(Tensor[] tensors, Dimname dim, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & concat_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::TensorList tensors, at::Dimname dim) { + return at::_ops::concat_names_out::redispatch(dispatchKeySet, tensors, dim, out); + } + + // aten::concat.names_out(Tensor[] tensors, Dimname dim, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & concat_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors, at::Dimname dim, at::Tensor & out) { + return at::_ops::concat_names_out::redispatch(dispatchKeySet, tensors, dim, out); + } + + // aten::concatenate(Tensor[] tensors, int dim=0) -> Tensor + inline at::Tensor concatenate(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors, int64_t dim=0) { + return at::_ops::concatenate::redispatch(dispatchKeySet, tensors, dim); + } + + // aten::concatenate.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & concatenate_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::TensorList tensors, int64_t dim=0) { + return at::_ops::concatenate_out::redispatch(dispatchKeySet, tensors, dim, out); + } + + // aten::concatenate.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & concatenate_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors, int64_t dim, at::Tensor & out) { + return at::_ops::concatenate_out::redispatch(dispatchKeySet, tensors, dim, out); + } + + // aten::concatenate.names(Tensor[] tensors, Dimname dim) -> Tensor + inline at::Tensor concatenate(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors, at::Dimname dim) { + return at::_ops::concatenate_names::redispatch(dispatchKeySet, tensors, dim); + } + + // aten::concatenate.names_out(Tensor[] tensors, Dimname dim, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & concatenate_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::TensorList tensors, at::Dimname dim) { + return at::_ops::concatenate_names_out::redispatch(dispatchKeySet, tensors, dim, out); + } + + // aten::concatenate.names_out(Tensor[] tensors, Dimname dim, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & concatenate_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors, at::Dimname dim, at::Tensor & out) { + return at::_ops::concatenate_names_out::redispatch(dispatchKeySet, tensors, dim, out); + } + + // aten::block_diag(Tensor[] tensors) -> Tensor + inline at::Tensor block_diag(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors) { + return at::_ops::block_diag::redispatch(dispatchKeySet, tensors); + } + + // aten::ceil(Tensor self) -> Tensor + inline at::Tensor ceil(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::ceil::redispatch(dispatchKeySet, self); + } + + // aten::ceil_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & ceil_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::ceil_::redispatch(dispatchKeySet, self); + } + + // aten::ceil.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & ceil_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::ceil_out::redispatch(dispatchKeySet, self, out); + } + + // aten::ceil.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & ceil_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::ceil_out::redispatch(dispatchKeySet, self, out); + } + + // aten::chain_matmul(Tensor[] matrices) -> Tensor + inline at::Tensor chain_matmul(c10::DispatchKeySet dispatchKeySet, at::TensorList matrices) { + return at::_ops::chain_matmul::redispatch(dispatchKeySet, matrices); + } + + // aten::chain_matmul.out(Tensor[] matrices, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & chain_matmul_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::TensorList matrices) { + return at::_ops::chain_matmul_out::redispatch(dispatchKeySet, matrices, out); + } + + // aten::chain_matmul.out(Tensor[] matrices, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & chain_matmul_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList matrices, at::Tensor & out) { + return at::_ops::chain_matmul_out::redispatch(dispatchKeySet, matrices, out); + } + + // aten::unsafe_chunk(Tensor self, int chunks, int dim=0) -> Tensor[] + inline ::std::vector unsafe_chunk(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t chunks, int64_t dim=0) { + return at::_ops::unsafe_chunk::redispatch(dispatchKeySet, self, chunks, dim); + } + + // aten::chunk(Tensor(a -> *) self, int chunks, int dim=0) -> Tensor(a)[] + inline ::std::vector chunk(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t chunks, int64_t dim=0) { + return at::_ops::chunk::redispatch(dispatchKeySet, self, chunks, dim); + } + + // aten::tensor_split.sections(Tensor(a -> *) self, SymInt sections, int dim=0) -> Tensor(a)[] + inline ::std::vector tensor_split(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t sections, int64_t dim=0) { + return at::_ops::tensor_split_sections::redispatch(dispatchKeySet, self, sections, dim); + } + + // aten::tensor_split.sections(Tensor(a -> *) self, SymInt sections, int dim=0) -> Tensor(a)[] + inline ::std::vector tensor_split_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymInt sections, int64_t dim=0) { + return at::_ops::tensor_split_sections::redispatch(dispatchKeySet, self, sections, dim); + } + + // aten::tensor_split.indices(Tensor(a -> *) self, SymInt[] indices, int dim=0) -> Tensor(a)[] + inline ::std::vector tensor_split(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef indices, int64_t dim=0) { + return at::_ops::tensor_split_indices::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(indices), dim); + } + + // aten::tensor_split.indices(Tensor(a -> *) self, SymInt[] indices, int dim=0) -> Tensor(a)[] + inline ::std::vector tensor_split_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, 
c10::SymIntArrayRef indices, int64_t dim=0) { + return at::_ops::tensor_split_indices::redispatch(dispatchKeySet, self, indices, dim); + } + + // aten::tensor_split.tensor_indices_or_sections(Tensor(a -> *) self, Tensor tensor_indices_or_sections, int dim=0) -> Tensor(a)[] + inline ::std::vector tensor_split(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & tensor_indices_or_sections, int64_t dim=0) { + return at::_ops::tensor_split_tensor_indices_or_sections::redispatch(dispatchKeySet, self, tensor_indices_or_sections, dim); + } + + // aten::clamp(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor + inline at::Tensor clamp(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const ::std::optional & min, const ::std::optional & max=::std::nullopt) { + return at::_ops::clamp::redispatch(dispatchKeySet, self, min, max); + } + + // aten::clamp.Tensor(Tensor self, Tensor? min=None, Tensor? max=None) -> Tensor + inline at::Tensor clamp(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const ::std::optional & min={}, const ::std::optional & max={}) { + return at::_ops::clamp_Tensor::redispatch(dispatchKeySet, self, min, max); + } + + // aten::clamp_(Tensor(a!) self, Scalar? min=None, Scalar? max=None) -> Tensor(a!) + inline at::Tensor & clamp_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const ::std::optional & min, const ::std::optional & max=::std::nullopt) { + return at::_ops::clamp_::redispatch(dispatchKeySet, self, min, max); + } + + // aten::clamp_.Tensor(Tensor(a!) self, Tensor? min=None, Tensor? max=None) -> Tensor(a!) + inline at::Tensor & clamp_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const ::std::optional & min={}, const ::std::optional & max={}) { + return at::_ops::clamp__Tensor::redispatch(dispatchKeySet, self, min, max); + } + + // aten::clamp.out(Tensor self, Scalar? min=None, Scalar? max=None, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & clamp_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const ::std::optional & min, const ::std::optional & max=::std::nullopt) { + return at::_ops::clamp_out::redispatch(dispatchKeySet, self, min, max, out); + } + + // aten::clamp.out(Tensor self, Scalar? min=None, Scalar? max=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & clamp_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const ::std::optional & min, const ::std::optional & max, at::Tensor & out) { + return at::_ops::clamp_out::redispatch(dispatchKeySet, self, min, max, out); + } + + // aten::clamp.Tensor_out(Tensor self, Tensor? min=None, Tensor? max=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & clamp_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const ::std::optional & min={}, const ::std::optional & max={}) { + return at::_ops::clamp_Tensor_out::redispatch(dispatchKeySet, self, min, max, out); + } + + // aten::clamp.Tensor_out(Tensor self, Tensor? min=None, Tensor? max=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & clamp_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const ::std::optional & min, const ::std::optional & max, at::Tensor & out) { + return at::_ops::clamp_Tensor_out::redispatch(dispatchKeySet, self, min, max, out); + } + + // aten::clamp_max(Tensor self, Scalar max) -> Tensor + inline at::Tensor clamp_max(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & max) { + return at::_ops::clamp_max::redispatch(dispatchKeySet, self, max); + } + + // aten::clamp_max.Tensor(Tensor self, Tensor max) -> Tensor + inline at::Tensor clamp_max(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & max) { + return at::_ops::clamp_max_Tensor::redispatch(dispatchKeySet, self, max); + } + + // aten::clamp_max_(Tensor(a!) self, Scalar max) -> Tensor(a!) 
+ inline at::Tensor & clamp_max_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & max) { + return at::_ops::clamp_max_::redispatch(dispatchKeySet, self, max); + } + + // aten::clamp_max_.Tensor(Tensor(a!) self, Tensor max) -> Tensor(a!) + inline at::Tensor & clamp_max_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & max) { + return at::_ops::clamp_max__Tensor::redispatch(dispatchKeySet, self, max); + } + + // aten::clamp_max.out(Tensor self, Scalar max, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & clamp_max_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & max) { + return at::_ops::clamp_max_out::redispatch(dispatchKeySet, self, max, out); + } + + // aten::clamp_max.out(Tensor self, Scalar max, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & clamp_max_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & max, at::Tensor & out) { + return at::_ops::clamp_max_out::redispatch(dispatchKeySet, self, max, out); + } + + // aten::clamp_max.Tensor_out(Tensor self, Tensor max, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & clamp_max_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & max) { + return at::_ops::clamp_max_Tensor_out::redispatch(dispatchKeySet, self, max, out); + } + + // aten::clamp_max.Tensor_out(Tensor self, Tensor max, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & clamp_max_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & max, at::Tensor & out) { + return at::_ops::clamp_max_Tensor_out::redispatch(dispatchKeySet, self, max, out); + } + + // aten::clamp_min(Tensor self, Scalar min) -> Tensor + inline at::Tensor clamp_min(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & min) { + return at::_ops::clamp_min::redispatch(dispatchKeySet, self, min); + } + + // aten::clamp_min.Tensor(Tensor self, Tensor min) -> Tensor + inline at::Tensor clamp_min(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & min) { + return at::_ops::clamp_min_Tensor::redispatch(dispatchKeySet, self, min); + } + + // aten::clamp_min_(Tensor(a!) self, Scalar min) -> Tensor(a!) + inline at::Tensor & clamp_min_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & min) { + return at::_ops::clamp_min_::redispatch(dispatchKeySet, self, min); + } + + // aten::clamp_min_.Tensor(Tensor(a!) self, Tensor min) -> Tensor(a!) + inline at::Tensor & clamp_min_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & min) { + return at::_ops::clamp_min__Tensor::redispatch(dispatchKeySet, self, min); + } + + // aten::clamp_min.out(Tensor self, Scalar min, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & clamp_min_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & min) { + return at::_ops::clamp_min_out::redispatch(dispatchKeySet, self, min, out); + } + + // aten::clamp_min.out(Tensor self, Scalar min, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & clamp_min_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & min, at::Tensor & out) { + return at::_ops::clamp_min_out::redispatch(dispatchKeySet, self, min, out); + } + + // aten::clamp_min.Tensor_out(Tensor self, Tensor min, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & clamp_min_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & min) { + return at::_ops::clamp_min_Tensor_out::redispatch(dispatchKeySet, self, min, out); + } + + // aten::clamp_min.Tensor_out(Tensor self, Tensor min, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & clamp_min_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & min, at::Tensor & out) { + return at::_ops::clamp_min_Tensor_out::redispatch(dispatchKeySet, self, min, out); + } + + // aten::clip(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor + inline at::Tensor clip(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const ::std::optional & min, const ::std::optional & max=::std::nullopt) { + return at::_ops::clip::redispatch(dispatchKeySet, self, min, max); + } + + // aten::clip.Tensor(Tensor self, Tensor? min=None, Tensor? max=None) -> Tensor + inline at::Tensor clip(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const ::std::optional & min={}, const ::std::optional & max={}) { + return at::_ops::clip_Tensor::redispatch(dispatchKeySet, self, min, max); + } + + // aten::clip_(Tensor(a!) self, Scalar? min=None, Scalar? max=None) -> Tensor(a!) + inline at::Tensor & clip_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const ::std::optional & min, const ::std::optional & max=::std::nullopt) { + return at::_ops::clip_::redispatch(dispatchKeySet, self, min, max); + } + + // aten::clip_.Tensor(Tensor(a!) self, Tensor? min=None, Tensor? max=None) -> Tensor(a!) + inline at::Tensor & clip_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const ::std::optional & min={}, const ::std::optional & max={}) { + return at::_ops::clip__Tensor::redispatch(dispatchKeySet, self, min, max); + } + + // aten::clip.out(Tensor self, Scalar? min=None, Scalar? max=None, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & clip_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const ::std::optional & min, const ::std::optional & max=::std::nullopt) { + return at::_ops::clip_out::redispatch(dispatchKeySet, self, min, max, out); + } + + // aten::clip.out(Tensor self, Scalar? min=None, Scalar? max=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & clip_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const ::std::optional & min, const ::std::optional & max, at::Tensor & out) { + return at::_ops::clip_out::redispatch(dispatchKeySet, self, min, max, out); + } + + // aten::clip.Tensor_out(Tensor self, Tensor? min=None, Tensor? max=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & clip_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const ::std::optional & min={}, const ::std::optional & max={}) { + return at::_ops::clip_Tensor_out::redispatch(dispatchKeySet, self, min, max, out); + } + + // aten::clip.Tensor_out(Tensor self, Tensor? min=None, Tensor? max=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & clip_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const ::std::optional & min, const ::std::optional & max, at::Tensor & out) { + return at::_ops::clip_Tensor_out::redispatch(dispatchKeySet, self, min, max, out); + } + + // aten::cudnn_is_acceptable(Tensor self) -> bool + inline bool cudnn_is_acceptable(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::cudnn_is_acceptable::redispatch(dispatchKeySet, self); + } + + // aten::complex(Tensor real, Tensor imag) -> Tensor + inline at::Tensor complex(c10::DispatchKeySet dispatchKeySet, const at::Tensor & real, const at::Tensor & imag) { + return at::_ops::complex::redispatch(dispatchKeySet, real, imag); + } + + // aten::complex.out(Tensor real, Tensor imag, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & complex_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & real, const at::Tensor & imag) { + return at::_ops::complex_out::redispatch(dispatchKeySet, real, imag, out); + } + + // aten::complex.out(Tensor real, Tensor imag, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & complex_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & real, const at::Tensor & imag, at::Tensor & out) { + return at::_ops::complex_out::redispatch(dispatchKeySet, real, imag, out); + } + + // aten::polar(Tensor abs, Tensor angle) -> Tensor + inline at::Tensor polar(c10::DispatchKeySet dispatchKeySet, const at::Tensor & abs, const at::Tensor & angle) { + return at::_ops::polar::redispatch(dispatchKeySet, abs, angle); + } + + // aten::polar.out(Tensor abs, Tensor angle, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & polar_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & abs, const at::Tensor & angle) { + return at::_ops::polar_out::redispatch(dispatchKeySet, abs, angle, out); + } + + // aten::polar.out(Tensor abs, Tensor angle, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & polar_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & abs, const at::Tensor & angle, at::Tensor & out) { + return at::_ops::polar_out::redispatch(dispatchKeySet, abs, angle, out); + } + + // aten::constant_pad_nd(Tensor self, SymInt[] pad, Scalar value=0) -> Tensor + inline at::Tensor constant_pad_nd(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef pad, const at::Scalar & value=0) { + return at::_ops::constant_pad_nd::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(pad), value); + } + + // aten::constant_pad_nd(Tensor self, SymInt[] pad, Scalar value=0) -> Tensor + inline at::Tensor constant_pad_nd_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef pad, const at::Scalar & value=0) { + return at::_ops::constant_pad_nd::redispatch(dispatchKeySet, self, pad, value); + } + + // aten::contiguous(Tensor(a) self, *, MemoryFormat memory_format=contiguous_format) -> Tensor(a) + inline at::Tensor __dispatch_contiguous(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::MemoryFormat memory_format=c10::MemoryFormat::Contiguous) { + return at::_ops::contiguous::redispatch(dispatchKeySet, self, memory_format); + } + + // aten::convolution(Tensor input, Tensor weight, Tensor? 
bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups) -> Tensor + inline at::Tensor convolution(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups) { + return at::_ops::convolution::redispatch(dispatchKeySet, input, weight, bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), transposed, c10::fromIntArrayRefSlow(output_padding), groups); + } + + // aten::convolution(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups) -> Tensor + inline at::Tensor convolution_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups) { + return at::_ops::convolution::redispatch(dispatchKeySet, input, weight, bias, stride, padding, dilation, transposed, output_padding, groups); + } + + // aten::convolution_backward(Tensor grad_output, Tensor input, Tensor weight, SymInt[]? 
bias_sizes, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor) + inline ::std::tuple convolution_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & weight, at::OptionalIntArrayRef bias_sizes, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups, ::std::array output_mask) { + return at::_ops::convolution_backward::redispatch(dispatchKeySet, grad_output, input, weight, bias_sizes.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*bias_sizes)) : ::std::nullopt, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), transposed, c10::fromIntArrayRefSlow(output_padding), groups, output_mask); + } + + // aten::convolution_backward(Tensor grad_output, Tensor input, Tensor weight, SymInt[]? bias_sizes, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor) + inline ::std::tuple convolution_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & weight, at::OptionalSymIntArrayRef bias_sizes, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups, ::std::array output_mask) { + return at::_ops::convolution_backward::redispatch(dispatchKeySet, grad_output, input, weight, bias_sizes, stride, padding, dilation, transposed, output_padding, groups, output_mask); + } + + // aten::convolution_overrideable(Tensor input, Tensor weight, Tensor? 
bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups) -> Tensor + inline at::Tensor convolution_overrideable(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups) { + return at::_ops::convolution_overrideable::redispatch(dispatchKeySet, input, weight, bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), transposed, c10::fromIntArrayRefSlow(output_padding), groups); + } + + // aten::convolution_overrideable(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups) -> Tensor + inline at::Tensor convolution_overrideable_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups) { + return at::_ops::convolution_overrideable::redispatch(dispatchKeySet, input, weight, bias, stride, padding, dilation, transposed, output_padding, groups); + } + + // aten::convolution_backward_overrideable(Tensor grad_output, Tensor input, Tensor weight, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias) + inline ::std::tuple convolution_backward_overrideable(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & weight, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups, 
::std::array output_mask) { + return at::_ops::convolution_backward_overrideable::redispatch(dispatchKeySet, grad_output, input, weight, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), transposed, c10::fromIntArrayRefSlow(output_padding), groups, output_mask); + } + + // aten::convolution_backward_overrideable(Tensor grad_output, Tensor input, Tensor weight, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias) + inline ::std::tuple convolution_backward_overrideable_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & weight, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups, ::std::array output_mask) { + return at::_ops::convolution_backward_overrideable::redispatch(dispatchKeySet, grad_output, input, weight, stride, padding, dilation, transposed, output_padding, groups, output_mask); + } + + // aten::_convolution(Tensor input, Tensor weight, Tensor? 
bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) -> Tensor + inline at::Tensor _convolution(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) { + return at::_ops::_convolution::redispatch(dispatchKeySet, input, weight, bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), transposed, c10::fromIntArrayRefSlow(output_padding), groups, benchmark, deterministic, cudnn_enabled, allow_tf32); + } + + // aten::_convolution(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) -> Tensor + inline at::Tensor _convolution_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) { + return at::_ops::_convolution::redispatch(dispatchKeySet, input, weight, bias, stride, padding, dilation, transposed, output_padding, groups, benchmark, deterministic, cudnn_enabled, allow_tf32); + } + + // aten::_convolution.deprecated(Tensor input, Tensor weight, Tensor? 
bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, int[] output_padding, SymInt groups, bool benchmark, bool deterministic, bool cudnn_enabled) -> Tensor + inline at::Tensor _convolution(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups, bool benchmark, bool deterministic, bool cudnn_enabled) { + return at::_ops::_convolution_deprecated::redispatch(dispatchKeySet, input, weight, bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), transposed, output_padding, groups, benchmark, deterministic, cudnn_enabled); + } + + // aten::_convolution.deprecated(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, int[] output_padding, SymInt groups, bool benchmark, bool deterministic, bool cudnn_enabled) -> Tensor + inline at::Tensor _convolution_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, c10::SymInt groups, bool benchmark, bool deterministic, bool cudnn_enabled) { + return at::_ops::_convolution_deprecated::redispatch(dispatchKeySet, input, weight, bias, stride, padding, dilation, transposed, output_padding, groups, benchmark, deterministic, cudnn_enabled); + } + + // aten::_convolution_mode(Tensor input, Tensor weight, Tensor? 
bias, SymInt[] stride, str padding, SymInt[] dilation, SymInt groups) -> Tensor + inline at::Tensor _convolution_mode(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, at::IntArrayRef stride, c10::string_view padding, at::IntArrayRef dilation, int64_t groups) { + return at::_ops::_convolution_mode::redispatch(dispatchKeySet, input, weight, bias, c10::fromIntArrayRefSlow(stride), padding, c10::fromIntArrayRefSlow(dilation), groups); + } + + // aten::_convolution_mode(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, str padding, SymInt[] dilation, SymInt groups) -> Tensor + inline at::Tensor _convolution_mode_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::string_view padding, c10::SymIntArrayRef dilation, c10::SymInt groups) { + return at::_ops::_convolution_mode::redispatch(dispatchKeySet, input, weight, bias, stride, padding, dilation, groups); + } + + // aten::_convolution_double_backward(Tensor? ggI, Tensor? ggW, Tensor? 
ggb, Tensor gO, Tensor weight, Tensor self, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor) + inline ::std::tuple _convolution_double_backward(c10::DispatchKeySet dispatchKeySet, const ::std::optional & ggI, const ::std::optional & ggW, const ::std::optional & ggb, const at::Tensor & gO, const at::Tensor & weight, const at::Tensor & self, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups, ::std::array output_mask) { + return at::_ops::_convolution_double_backward::redispatch(dispatchKeySet, ggI, ggW, ggb, gO, weight, self, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), transposed, c10::fromIntArrayRefSlow(output_padding), groups, output_mask); + } + + // aten::_convolution_double_backward(Tensor? ggI, Tensor? ggW, Tensor? ggb, Tensor gO, Tensor weight, Tensor self, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor) + inline ::std::tuple _convolution_double_backward_symint(c10::DispatchKeySet dispatchKeySet, const ::std::optional & ggI, const ::std::optional & ggW, const ::std::optional & ggb, const at::Tensor & gO, const at::Tensor & weight, const at::Tensor & self, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups, ::std::array output_mask) { + return at::_ops::_convolution_double_backward::redispatch(dispatchKeySet, ggI, ggW, ggb, gO, weight, self, stride, padding, dilation, transposed, output_padding, groups, output_mask); + } + + // aten::conv1d(Tensor input, Tensor weight, Tensor? 
bias=None, SymInt[1] stride=1, SymInt[1] padding=0, SymInt[1] dilation=1, SymInt groups=1) -> Tensor + inline at::Tensor conv1d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias={}, at::IntArrayRef stride=1, at::IntArrayRef padding=0, at::IntArrayRef dilation=1, int64_t groups=1) { + return at::_ops::conv1d::redispatch(dispatchKeySet, input, weight, bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), groups); + } + + // aten::conv1d(Tensor input, Tensor weight, Tensor? bias=None, SymInt[1] stride=1, SymInt[1] padding=0, SymInt[1] dilation=1, SymInt groups=1) -> Tensor + inline at::Tensor conv1d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias={}, c10::SymIntArrayRef stride=c10::SymInt(1), c10::SymIntArrayRef padding=c10::SymInt(0), c10::SymIntArrayRef dilation=c10::SymInt(1), c10::SymInt groups=1) { + return at::_ops::conv1d::redispatch(dispatchKeySet, input, weight, bias, stride, padding, dilation, groups); + } + + // aten::conv2d(Tensor input, Tensor weight, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] dilation=1, SymInt groups=1) -> Tensor + inline at::Tensor conv2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias={}, at::IntArrayRef stride=1, at::IntArrayRef padding=0, at::IntArrayRef dilation=1, int64_t groups=1) { + return at::_ops::conv2d::redispatch(dispatchKeySet, input, weight, bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), groups); + } + + // aten::conv2d(Tensor input, Tensor weight, Tensor? 
bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] dilation=1, SymInt groups=1) -> Tensor + inline at::Tensor conv2d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias={}, c10::SymIntArrayRef stride=c10::SymInt(1), c10::SymIntArrayRef padding=c10::SymInt(0), c10::SymIntArrayRef dilation=c10::SymInt(1), c10::SymInt groups=1) { + return at::_ops::conv2d::redispatch(dispatchKeySet, input, weight, bias, stride, padding, dilation, groups); + } + + // aten::conv3d(Tensor input, Tensor weight, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] dilation=1, SymInt groups=1) -> Tensor + inline at::Tensor conv3d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias={}, at::IntArrayRef stride=1, at::IntArrayRef padding=0, at::IntArrayRef dilation=1, int64_t groups=1) { + return at::_ops::conv3d::redispatch(dispatchKeySet, input, weight, bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), groups); + } + + // aten::conv3d(Tensor input, Tensor weight, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] dilation=1, SymInt groups=1) -> Tensor + inline at::Tensor conv3d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias={}, c10::SymIntArrayRef stride=c10::SymInt(1), c10::SymIntArrayRef padding=c10::SymInt(0), c10::SymIntArrayRef dilation=c10::SymInt(1), c10::SymInt groups=1) { + return at::_ops::conv3d::redispatch(dispatchKeySet, input, weight, bias, stride, padding, dilation, groups); + } + + // aten::conv1d.padding(Tensor input, Tensor weight, Tensor? 
bias=None, SymInt[1] stride=1, str padding="valid", SymInt[1] dilation=1, SymInt groups=1) -> Tensor + inline at::Tensor conv1d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, at::IntArrayRef stride, c10::string_view padding, at::IntArrayRef dilation=1, int64_t groups=1) { + return at::_ops::conv1d_padding::redispatch(dispatchKeySet, input, weight, bias, c10::fromIntArrayRefSlow(stride), padding, c10::fromIntArrayRefSlow(dilation), groups); + } + + // aten::conv1d.padding(Tensor input, Tensor weight, Tensor? bias=None, SymInt[1] stride=1, str padding="valid", SymInt[1] dilation=1, SymInt groups=1) -> Tensor + inline at::Tensor conv1d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::string_view padding, c10::SymIntArrayRef dilation=c10::SymInt(1), c10::SymInt groups=1) { + return at::_ops::conv1d_padding::redispatch(dispatchKeySet, input, weight, bias, stride, padding, dilation, groups); + } + + // aten::conv2d.padding(Tensor input, Tensor weight, Tensor? bias=None, SymInt[2] stride=1, str padding="valid", SymInt[2] dilation=1, SymInt groups=1) -> Tensor + inline at::Tensor conv2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, at::IntArrayRef stride, c10::string_view padding, at::IntArrayRef dilation=1, int64_t groups=1) { + return at::_ops::conv2d_padding::redispatch(dispatchKeySet, input, weight, bias, c10::fromIntArrayRefSlow(stride), padding, c10::fromIntArrayRefSlow(dilation), groups); + } + + // aten::conv2d.padding(Tensor input, Tensor weight, Tensor? 
bias=None, SymInt[2] stride=1, str padding="valid", SymInt[2] dilation=1, SymInt groups=1) -> Tensor + inline at::Tensor conv2d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::string_view padding, c10::SymIntArrayRef dilation=c10::SymInt(1), c10::SymInt groups=1) { + return at::_ops::conv2d_padding::redispatch(dispatchKeySet, input, weight, bias, stride, padding, dilation, groups); + } + + // aten::conv3d.padding(Tensor input, Tensor weight, Tensor? bias=None, SymInt[3] stride=1, str padding="valid", SymInt[3] dilation=1, SymInt groups=1) -> Tensor + inline at::Tensor conv3d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, at::IntArrayRef stride, c10::string_view padding, at::IntArrayRef dilation=1, int64_t groups=1) { + return at::_ops::conv3d_padding::redispatch(dispatchKeySet, input, weight, bias, c10::fromIntArrayRefSlow(stride), padding, c10::fromIntArrayRefSlow(dilation), groups); + } + + // aten::conv3d.padding(Tensor input, Tensor weight, Tensor? 
bias=None, SymInt[3] stride=1, str padding="valid", SymInt[3] dilation=1, SymInt groups=1) -> Tensor + inline at::Tensor conv3d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::string_view padding, c10::SymIntArrayRef dilation=c10::SymInt(1), c10::SymInt groups=1) { + return at::_ops::conv3d_padding::redispatch(dispatchKeySet, input, weight, bias, stride, padding, dilation, groups); + } + + // aten::conv_tbc(Tensor self, Tensor weight, Tensor bias, int pad=0) -> Tensor + inline at::Tensor conv_tbc(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const at::Tensor & bias, int64_t pad=0) { + return at::_ops::conv_tbc::redispatch(dispatchKeySet, self, weight, bias, pad); + } + + // aten::conv_tbc_backward(Tensor self, Tensor input, Tensor weight, Tensor bias, int pad) -> (Tensor, Tensor, Tensor) + inline ::std::tuple conv_tbc_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & input, const at::Tensor & weight, const at::Tensor & bias, int64_t pad) { + return at::_ops::conv_tbc_backward::redispatch(dispatchKeySet, self, input, weight, bias, pad); + } + + // aten::conv_transpose1d(Tensor input, Tensor weight, Tensor? 
bias=None, SymInt[1] stride=1, SymInt[1] padding=0, SymInt[1] output_padding=0, SymInt groups=1, SymInt[1] dilation=1) -> Tensor + inline at::Tensor conv_transpose1d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias={}, at::IntArrayRef stride=1, at::IntArrayRef padding=0, at::IntArrayRef output_padding=0, int64_t groups=1, at::IntArrayRef dilation=1) { + return at::_ops::conv_transpose1d::redispatch(dispatchKeySet, input, weight, bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(output_padding), groups, c10::fromIntArrayRefSlow(dilation)); + } + + // aten::conv_transpose1d(Tensor input, Tensor weight, Tensor? bias=None, SymInt[1] stride=1, SymInt[1] padding=0, SymInt[1] output_padding=0, SymInt groups=1, SymInt[1] dilation=1) -> Tensor + inline at::Tensor conv_transpose1d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias={}, c10::SymIntArrayRef stride=c10::SymInt(1), c10::SymIntArrayRef padding=c10::SymInt(0), c10::SymIntArrayRef output_padding=c10::SymInt(0), c10::SymInt groups=1, c10::SymIntArrayRef dilation=c10::SymInt(1)) { + return at::_ops::conv_transpose1d::redispatch(dispatchKeySet, input, weight, bias, stride, padding, output_padding, groups, dilation); + } + + // aten::conv_transpose2d.input(Tensor input, Tensor weight, Tensor? 
bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] output_padding=0, SymInt groups=1, SymInt[2] dilation=1) -> Tensor + inline at::Tensor conv_transpose2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias={}, at::IntArrayRef stride=1, at::IntArrayRef padding=0, at::IntArrayRef output_padding=0, int64_t groups=1, at::IntArrayRef dilation=1) { + return at::_ops::conv_transpose2d_input::redispatch(dispatchKeySet, input, weight, bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(output_padding), groups, c10::fromIntArrayRefSlow(dilation)); + } + + // aten::conv_transpose2d.input(Tensor input, Tensor weight, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] output_padding=0, SymInt groups=1, SymInt[2] dilation=1) -> Tensor + inline at::Tensor conv_transpose2d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias={}, c10::SymIntArrayRef stride=c10::SymInt(1), c10::SymIntArrayRef padding=c10::SymInt(0), c10::SymIntArrayRef output_padding=c10::SymInt(0), c10::SymInt groups=1, c10::SymIntArrayRef dilation=c10::SymInt(1)) { + return at::_ops::conv_transpose2d_input::redispatch(dispatchKeySet, input, weight, bias, stride, padding, output_padding, groups, dilation); + } + + // aten::conv_transpose3d.input(Tensor input, Tensor weight, Tensor? 
bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] output_padding=0, SymInt groups=1, SymInt[3] dilation=1) -> Tensor + inline at::Tensor conv_transpose3d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias={}, at::IntArrayRef stride=1, at::IntArrayRef padding=0, at::IntArrayRef output_padding=0, int64_t groups=1, at::IntArrayRef dilation=1) { + return at::_ops::conv_transpose3d_input::redispatch(dispatchKeySet, input, weight, bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(output_padding), groups, c10::fromIntArrayRefSlow(dilation)); + } + + // aten::conv_transpose3d.input(Tensor input, Tensor weight, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] output_padding=0, SymInt groups=1, SymInt[3] dilation=1) -> Tensor + inline at::Tensor conv_transpose3d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias={}, c10::SymIntArrayRef stride=c10::SymInt(1), c10::SymIntArrayRef padding=c10::SymInt(0), c10::SymIntArrayRef output_padding=c10::SymInt(0), c10::SymInt groups=1, c10::SymIntArrayRef dilation=c10::SymInt(1)) { + return at::_ops::conv_transpose3d_input::redispatch(dispatchKeySet, input, weight, bias, stride, padding, output_padding, groups, dilation); + } + + // aten::copy(Tensor self, Tensor src, bool non_blocking=False) -> Tensor + inline at::Tensor copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & src, bool non_blocking=false) { + return at::_ops::copy::redispatch(dispatchKeySet, self, src, non_blocking); + } + + // aten::copy_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> Tensor(a!) 
+ inline at::Tensor & copy_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & src, bool non_blocking=false) { + return at::_ops::copy_::redispatch(dispatchKeySet, self, src, non_blocking); + } + + // aten::_copy_from(Tensor self, Tensor dst, bool non_blocking=False) -> Tensor + inline at::Tensor _copy_from(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & dst, bool non_blocking=false) { + return at::_ops::_copy_from::redispatch(dispatchKeySet, self, dst, non_blocking); + } + + // aten::_copy_from_and_resize(Tensor self, Tensor dst) -> Tensor + inline at::Tensor _copy_from_and_resize(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & dst) { + return at::_ops::_copy_from_and_resize::redispatch(dispatchKeySet, self, dst); + } + + // aten::cos(Tensor self) -> Tensor + inline at::Tensor cos(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::cos::redispatch(dispatchKeySet, self); + } + + // aten::cos_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & cos_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::cos_::redispatch(dispatchKeySet, self); + } + + // aten::cos.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & cos_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::cos_out::redispatch(dispatchKeySet, self, out); + } + + // aten::cos.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & cos_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::cos_out::redispatch(dispatchKeySet, self, out); + } + + // aten::cosh(Tensor self) -> Tensor + inline at::Tensor cosh(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::cosh::redispatch(dispatchKeySet, self); + } + + // aten::cosh_(Tensor(a!) self) -> Tensor(a!) 
+ inline at::Tensor & cosh_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::cosh_::redispatch(dispatchKeySet, self); + } + + // aten::cosh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & cosh_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::cosh_out::redispatch(dispatchKeySet, self, out); + } + + // aten::cosh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & cosh_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::cosh_out::redispatch(dispatchKeySet, self, out); + } + + // aten::cosine_embedding_loss(Tensor input1, Tensor input2, Tensor target, float margin=0.0, int reduction=Mean) -> Tensor + inline at::Tensor cosine_embedding_loss(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input1, const at::Tensor & input2, const at::Tensor & target, double margin=0.0, int64_t reduction=at::Reduction::Mean) { + return at::_ops::cosine_embedding_loss::redispatch(dispatchKeySet, input1, input2, target, margin, reduction); + } + + // aten::count_nonzero.dim_IntList(Tensor self, int[] dim) -> Tensor + inline at::Tensor count_nonzero(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim) { + return at::_ops::count_nonzero_dim_IntList::redispatch(dispatchKeySet, self, dim); + } + + // aten::count_nonzero(Tensor self, int? dim=None) -> Tensor + inline at::Tensor count_nonzero(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional dim=::std::nullopt) { + return at::_ops::count_nonzero::redispatch(dispatchKeySet, self, dim); + } + + // aten::cov(Tensor self, *, int correction=1, Tensor? fweights=None, Tensor? 
aweights=None) -> Tensor + inline at::Tensor cov(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t correction=1, const ::std::optional & fweights={}, const ::std::optional & aweights={}) { + return at::_ops::cov::redispatch(dispatchKeySet, self, correction, fweights, aweights); + } + + // aten::corrcoef(Tensor self) -> Tensor + inline at::Tensor corrcoef(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::corrcoef::redispatch(dispatchKeySet, self); + } + + // aten::cudnn_affine_grid_generator(Tensor theta, int N, int C, int H, int W) -> Tensor grid + inline at::Tensor cudnn_affine_grid_generator(c10::DispatchKeySet dispatchKeySet, const at::Tensor & theta, int64_t N, int64_t C, int64_t H, int64_t W) { + return at::_ops::cudnn_affine_grid_generator::redispatch(dispatchKeySet, theta, N, C, H, W); + } + + // aten::cudnn_affine_grid_generator_backward(Tensor grad, int N, int C, int H, int W) -> Tensor grad_theta + inline at::Tensor cudnn_affine_grid_generator_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, int64_t N, int64_t C, int64_t H, int64_t W) { + return at::_ops::cudnn_affine_grid_generator_backward::redispatch(dispatchKeySet, grad, N, C, H, W); + } + + // aten::cudnn_batch_norm(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon) -> (Tensor, Tensor, Tensor, Tensor) + inline ::std::tuple cudnn_batch_norm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, const ::std::optional & running_mean, const ::std::optional & running_var, bool training, double exponential_average_factor, double epsilon) { + return at::_ops::cudnn_batch_norm::redispatch(dispatchKeySet, input, weight, bias, running_mean, running_var, training, exponential_average_factor, epsilon); + } + + // aten::cudnn_batch_norm.out(Tensor input, Tensor weight, Tensor? 
bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!)) + inline ::std::tuple cudnn_batch_norm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, const ::std::optional & running_mean, const ::std::optional & running_var, bool training, double exponential_average_factor, double epsilon) { + return at::_ops::cudnn_batch_norm_out::redispatch(dispatchKeySet, input, weight, bias, running_mean, running_var, training, exponential_average_factor, epsilon, out0, out1, out2, out3); + } + + // aten::cudnn_batch_norm.out(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!)) + inline ::std::tuple cudnn_batch_norm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, const ::std::optional & running_mean, const ::std::optional & running_var, bool training, double exponential_average_factor, double epsilon, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3) { + return at::_ops::cudnn_batch_norm_out::redispatch(dispatchKeySet, input, weight, bias, running_mean, running_var, training, exponential_average_factor, epsilon, out0, out1, out2, out3); + } + + // aten::cudnn_batch_norm_backward(Tensor input, Tensor grad_output, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? 
save_var, float epsilon, Tensor reserveSpace) -> (Tensor, Tensor, Tensor) + inline ::std::tuple cudnn_batch_norm_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & grad_output, const at::Tensor & weight, const ::std::optional & running_mean, const ::std::optional & running_var, const ::std::optional & save_mean, const ::std::optional & save_var, double epsilon, const at::Tensor & reserveSpace) { + return at::_ops::cudnn_batch_norm_backward::redispatch(dispatchKeySet, input, grad_output, weight, running_mean, running_var, save_mean, save_var, epsilon, reserveSpace); + } + + // aten::cudnn_convolution(Tensor self, Tensor weight, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, bool allow_tf32) -> Tensor + inline at::Tensor cudnn_convolution(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, bool benchmark, bool deterministic, bool allow_tf32) { + return at::_ops::cudnn_convolution::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, benchmark, deterministic, allow_tf32); + } + + // aten::cudnn_convolution(Tensor self, Tensor weight, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, bool allow_tf32) -> Tensor + inline at::Tensor cudnn_convolution_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, bool benchmark, bool deterministic, bool allow_tf32) { + return at::_ops::cudnn_convolution::redispatch(dispatchKeySet, self, weight, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32); + } + + // aten::cudnn_convolution.out(Tensor 
self, Tensor weight, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, bool allow_tf32, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & cudnn_convolution_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, bool benchmark, bool deterministic, bool allow_tf32) { + return at::_ops::cudnn_convolution_out::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, benchmark, deterministic, allow_tf32, out); + } + + // aten::cudnn_convolution.out(Tensor self, Tensor weight, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, bool allow_tf32, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & cudnn_convolution_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, bool benchmark, bool deterministic, bool allow_tf32, at::Tensor & out) { + return at::_ops::cudnn_convolution_out::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, benchmark, deterministic, allow_tf32, out); + } + + // aten::cudnn_convolution.out(Tensor self, Tensor weight, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, bool allow_tf32, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & cudnn_convolution_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, bool benchmark, bool deterministic, bool allow_tf32) { + return at::_ops::cudnn_convolution_out::redispatch(dispatchKeySet, self, weight, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32, out); + } + + // aten::cudnn_convolution.out(Tensor self, Tensor weight, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, bool allow_tf32, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & cudnn_convolution_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, bool benchmark, bool deterministic, bool allow_tf32, at::Tensor & out) { + return at::_ops::cudnn_convolution_out::redispatch(dispatchKeySet, self, weight, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32, out); + } + + // aten::cudnn_convolution_transpose(Tensor self, Tensor weight, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, bool allow_tf32) -> Tensor + inline at::Tensor cudnn_convolution_transpose(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef padding, at::IntArrayRef output_padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, bool benchmark, bool deterministic, bool allow_tf32) { + return at::_ops::cudnn_convolution_transpose::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(output_padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, benchmark, deterministic, 
allow_tf32); + } + + // aten::cudnn_convolution_transpose(Tensor self, Tensor weight, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, bool allow_tf32) -> Tensor + inline at::Tensor cudnn_convolution_transpose_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, bool benchmark, bool deterministic, bool allow_tf32) { + return at::_ops::cudnn_convolution_transpose::redispatch(dispatchKeySet, self, weight, padding, output_padding, stride, dilation, groups, benchmark, deterministic, allow_tf32); + } + + // aten::_mps_convolution_transpose(Tensor self, Tensor weight, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups) -> Tensor + inline at::Tensor _mps_convolution_transpose(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef padding, at::IntArrayRef output_padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups) { + return at::_ops::_mps_convolution_transpose::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(output_padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups); + } + + // aten::_mps_convolution_transpose(Tensor self, Tensor weight, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups) -> Tensor + inline at::Tensor _mps_convolution_transpose_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups) { + return at::_ops::_mps_convolution_transpose::redispatch(dispatchKeySet, self, weight, 
padding, output_padding, stride, dilation, groups); + } + + // aten::mps_convolution_transpose_backward(Tensor self, Tensor grad_output, Tensor weight, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool[2] output_mask) -> (Tensor, Tensor) + inline ::std::tuple mps_convolution_transpose_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & grad_output, const at::Tensor & weight, at::IntArrayRef padding, at::IntArrayRef output_padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, ::std::array output_mask) { + return at::_ops::mps_convolution_transpose_backward::redispatch(dispatchKeySet, self, grad_output, weight, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(output_padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, output_mask); + } + + // aten::mps_convolution_transpose_backward(Tensor self, Tensor grad_output, Tensor weight, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool[2] output_mask) -> (Tensor, Tensor) + inline ::std::tuple mps_convolution_transpose_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & grad_output, const at::Tensor & weight, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, ::std::array output_mask) { + return at::_ops::mps_convolution_transpose_backward::redispatch(dispatchKeySet, self, grad_output, weight, padding, output_padding, stride, dilation, groups, output_mask); + } + + // aten::cudnn_convolution_relu(Tensor self, Tensor weight, Tensor? 
bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, SymInt groups) -> Tensor + inline at::Tensor cudnn_convolution_relu(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, int64_t groups) { + return at::_ops::cudnn_convolution_relu::redispatch(dispatchKeySet, self, weight, bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), groups); + } + + // aten::cudnn_convolution_relu(Tensor self, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, SymInt groups) -> Tensor + inline at::Tensor cudnn_convolution_relu_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, c10::SymInt groups) { + return at::_ops::cudnn_convolution_relu::redispatch(dispatchKeySet, self, weight, bias, stride, padding, dilation, groups); + } + + // aten::cudnn_convolution_add_relu(Tensor self, Tensor weight, Tensor z, Scalar? alpha, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, SymInt groups) -> Tensor + inline at::Tensor cudnn_convolution_add_relu(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const at::Tensor & z, const ::std::optional & alpha, const ::std::optional & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, int64_t groups) { + return at::_ops::cudnn_convolution_add_relu::redispatch(dispatchKeySet, self, weight, z, alpha, bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), groups); + } + + // aten::cudnn_convolution_add_relu(Tensor self, Tensor weight, Tensor z, Scalar? alpha, Tensor? 
bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, SymInt groups) -> Tensor + inline at::Tensor cudnn_convolution_add_relu_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const at::Tensor & z, const ::std::optional & alpha, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, c10::SymInt groups) { + return at::_ops::cudnn_convolution_add_relu::redispatch(dispatchKeySet, self, weight, z, alpha, bias, stride, padding, dilation, groups); + } + + // aten::cudnn_grid_sampler(Tensor self, Tensor grid) -> Tensor output + inline at::Tensor cudnn_grid_sampler(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & grid) { + return at::_ops::cudnn_grid_sampler::redispatch(dispatchKeySet, self, grid); + } + + // aten::cudnn_grid_sampler_backward(Tensor self, Tensor grid, Tensor grad_output) -> (Tensor grad_self, Tensor grad_grid) + inline ::std::tuple cudnn_grid_sampler_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & grid, const at::Tensor & grad_output) { + return at::_ops::cudnn_grid_sampler_backward::redispatch(dispatchKeySet, self, grid, grad_output); + } + + // aten::cummax(Tensor self, int dim) -> (Tensor values, Tensor indices) + inline ::std::tuple cummax(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim) { + return at::_ops::cummax::redispatch(dispatchKeySet, self, dim); + } + + // aten::cummax.out(Tensor self, int dim, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + inline ::std::tuple cummax_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & values, at::Tensor & indices, const at::Tensor & self, int64_t dim) { + return at::_ops::cummax_out::redispatch(dispatchKeySet, self, dim, values, indices); + } + + // aten::cummax.out(Tensor self, int dim, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) 
indices) + inline ::std::tuple cummax_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, at::Tensor & values, at::Tensor & indices) { + return at::_ops::cummax_out::redispatch(dispatchKeySet, self, dim, values, indices); + } + + // aten::cummax.dimname(Tensor self, Dimname dim) -> (Tensor values, Tensor indices) + inline ::std::tuple cummax(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim) { + return at::_ops::cummax_dimname::redispatch(dispatchKeySet, self, dim); + } + + // aten::cummax.dimname_out(Tensor self, Dimname dim, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + inline ::std::tuple cummax_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & values, at::Tensor & indices, const at::Tensor & self, at::Dimname dim) { + return at::_ops::cummax_dimname_out::redispatch(dispatchKeySet, self, dim, values, indices); + } + + // aten::cummax.dimname_out(Tensor self, Dimname dim, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + inline ::std::tuple cummax_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, at::Tensor & values, at::Tensor & indices) { + return at::_ops::cummax_dimname_out::redispatch(dispatchKeySet, self, dim, values, indices); + } + + // aten::_cummax_helper(Tensor self, Tensor(a!) values, Tensor(b!) indices, int dim) -> () + inline void _cummax_helper(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & values, at::Tensor & indices, int64_t dim) { + return at::_ops::_cummax_helper::redispatch(dispatchKeySet, self, values, indices, dim); + } + + // aten::cummin(Tensor self, int dim) -> (Tensor values, Tensor indices) + inline ::std::tuple cummin(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim) { + return at::_ops::cummin::redispatch(dispatchKeySet, self, dim); + } + + // aten::cummin.out(Tensor self, int dim, *, Tensor(a!) values, Tensor(b!) 
indices) -> (Tensor(a!) values, Tensor(b!) indices) + inline ::std::tuple cummin_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & values, at::Tensor & indices, const at::Tensor & self, int64_t dim) { + return at::_ops::cummin_out::redispatch(dispatchKeySet, self, dim, values, indices); + } + + // aten::cummin.out(Tensor self, int dim, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + inline ::std::tuple cummin_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, at::Tensor & values, at::Tensor & indices) { + return at::_ops::cummin_out::redispatch(dispatchKeySet, self, dim, values, indices); + } + + // aten::cummin.dimname(Tensor self, Dimname dim) -> (Tensor values, Tensor indices) + inline ::std::tuple cummin(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim) { + return at::_ops::cummin_dimname::redispatch(dispatchKeySet, self, dim); + } + + // aten::cummin.dimname_out(Tensor self, Dimname dim, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + inline ::std::tuple cummin_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & values, at::Tensor & indices, const at::Tensor & self, at::Dimname dim) { + return at::_ops::cummin_dimname_out::redispatch(dispatchKeySet, self, dim, values, indices); + } + + // aten::cummin.dimname_out(Tensor self, Dimname dim, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + inline ::std::tuple cummin_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, at::Tensor & values, at::Tensor & indices) { + return at::_ops::cummin_dimname_out::redispatch(dispatchKeySet, self, dim, values, indices); + } + + // aten::_cummin_helper(Tensor self, Tensor(a!) values, Tensor(b!) 
indices, int dim) -> () + inline void _cummin_helper(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & values, at::Tensor & indices, int64_t dim) { + return at::_ops::_cummin_helper::redispatch(dispatchKeySet, self, values, indices, dim); + } + + // aten::cummaxmin_backward(Tensor grad, Tensor input, Tensor indices, int dim) -> Tensor + inline at::Tensor cummaxmin_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & input, const at::Tensor & indices, int64_t dim) { + return at::_ops::cummaxmin_backward::redispatch(dispatchKeySet, grad, input, indices, dim); + } + + // aten::cumprod(Tensor self, int dim, *, ScalarType? dtype=None) -> Tensor + inline at::Tensor cumprod(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, ::std::optional dtype=::std::nullopt) { + return at::_ops::cumprod::redispatch(dispatchKeySet, self, dim, dtype); + } + + // aten::cumprod_(Tensor(a!) self, int dim, *, ScalarType? dtype=None) -> Tensor(a!) + inline at::Tensor & cumprod_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, int64_t dim, ::std::optional dtype=::std::nullopt) { + return at::_ops::cumprod_::redispatch(dispatchKeySet, self, dim, dtype); + } + + // aten::cumprod.out(Tensor self, int dim, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & cumprod_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim, ::std::optional dtype=::std::nullopt) { + return at::_ops::cumprod_out::redispatch(dispatchKeySet, self, dim, dtype, out); + } + + // aten::cumprod.out(Tensor self, int dim, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & cumprod_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, ::std::optional dtype, at::Tensor & out) { + return at::_ops::cumprod_out::redispatch(dispatchKeySet, self, dim, dtype, out); + } + + // aten::cumprod.dimname(Tensor self, Dimname dim, *, ScalarType? dtype=None) -> Tensor + inline at::Tensor cumprod(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, ::std::optional dtype=::std::nullopt) { + return at::_ops::cumprod_dimname::redispatch(dispatchKeySet, self, dim, dtype); + } + + // aten::cumprod_.dimname(Tensor(a!) self, Dimname dim, *, ScalarType? dtype=None) -> Tensor(a!) + inline at::Tensor & cumprod_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, at::Dimname dim, ::std::optional dtype=::std::nullopt) { + return at::_ops::cumprod__dimname::redispatch(dispatchKeySet, self, dim, dtype); + } + + // aten::cumprod.dimname_out(Tensor self, Dimname dim, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & cumprod_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::Dimname dim, ::std::optional dtype=::std::nullopt) { + return at::_ops::cumprod_dimname_out::redispatch(dispatchKeySet, self, dim, dtype, out); + } + + // aten::cumprod.dimname_out(Tensor self, Dimname dim, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & cumprod_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, ::std::optional dtype, at::Tensor & out) { + return at::_ops::cumprod_dimname_out::redispatch(dispatchKeySet, self, dim, dtype, out); + } + + // aten::cumprod_backward(Tensor grad, Tensor input, int dim, Tensor output) -> Tensor + inline at::Tensor cumprod_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & input, int64_t dim, const at::Tensor & output) { + return at::_ops::cumprod_backward::redispatch(dispatchKeySet, grad, input, dim, output); + } + + // aten::cumsum(Tensor self, int dim, *, ScalarType? dtype=None) -> Tensor + inline at::Tensor cumsum(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, ::std::optional dtype=::std::nullopt) { + return at::_ops::cumsum::redispatch(dispatchKeySet, self, dim, dtype); + } + + // aten::cumsum_(Tensor(a!) self, int dim, *, ScalarType? dtype=None) -> Tensor(a!) + inline at::Tensor & cumsum_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, int64_t dim, ::std::optional dtype=::std::nullopt) { + return at::_ops::cumsum_::redispatch(dispatchKeySet, self, dim, dtype); + } + + // aten::cumsum.out(Tensor self, int dim, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & cumsum_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim, ::std::optional dtype=::std::nullopt) { + return at::_ops::cumsum_out::redispatch(dispatchKeySet, self, dim, dtype, out); + } + + // aten::cumsum.out(Tensor self, int dim, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & cumsum_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, ::std::optional dtype, at::Tensor & out) { + return at::_ops::cumsum_out::redispatch(dispatchKeySet, self, dim, dtype, out); + } + + // aten::cumsum.dimname(Tensor self, Dimname dim, *, ScalarType? 
dtype=None) -> Tensor + inline at::Tensor cumsum(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, ::std::optional dtype=::std::nullopt) { + return at::_ops::cumsum_dimname::redispatch(dispatchKeySet, self, dim, dtype); + } + + // aten::cumsum_.dimname(Tensor(a!) self, Dimname dim, *, ScalarType? dtype=None) -> Tensor(a!) + inline at::Tensor & cumsum_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, at::Dimname dim, ::std::optional dtype=::std::nullopt) { + return at::_ops::cumsum__dimname::redispatch(dispatchKeySet, self, dim, dtype); + } + + // aten::cumsum.dimname_out(Tensor self, Dimname dim, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & cumsum_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::Dimname dim, ::std::optional dtype=::std::nullopt) { + return at::_ops::cumsum_dimname_out::redispatch(dispatchKeySet, self, dim, dtype, out); + } + + // aten::cumsum.dimname_out(Tensor self, Dimname dim, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & cumsum_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, ::std::optional dtype, at::Tensor & out) { + return at::_ops::cumsum_dimname_out::redispatch(dispatchKeySet, self, dim, dtype, out); + } + + // aten::cumulative_trapezoid.x(Tensor y, Tensor x, *, int dim=-1) -> Tensor + inline at::Tensor cumulative_trapezoid(c10::DispatchKeySet dispatchKeySet, const at::Tensor & y, const at::Tensor & x, int64_t dim=-1) { + return at::_ops::cumulative_trapezoid_x::redispatch(dispatchKeySet, y, x, dim); + } + + // aten::cumulative_trapezoid.dx(Tensor y, *, Scalar dx=1, int dim=-1) -> Tensor + inline at::Tensor cumulative_trapezoid(c10::DispatchKeySet dispatchKeySet, const at::Tensor & y, const at::Scalar & dx=1, int64_t dim=-1) { + return at::_ops::cumulative_trapezoid_dx::redispatch(dispatchKeySet, y, dx, dim); + } + + // aten::ctc_loss.IntList(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank=0, int reduction=Mean, bool zero_infinity=False) -> Tensor + inline at::Tensor ctc_loss(c10::DispatchKeySet dispatchKeySet, const at::Tensor & log_probs, const at::Tensor & targets, at::IntArrayRef input_lengths, at::IntArrayRef target_lengths, int64_t blank=0, int64_t reduction=at::Reduction::Mean, bool zero_infinity=false) { + return at::_ops::ctc_loss_IntList::redispatch(dispatchKeySet, log_probs, targets, input_lengths, target_lengths, blank, reduction, zero_infinity); + } + + // aten::ctc_loss.Tensor(Tensor log_probs, Tensor targets, Tensor input_lengths, Tensor target_lengths, int blank=0, int reduction=Mean, bool zero_infinity=False) -> Tensor + inline at::Tensor ctc_loss(c10::DispatchKeySet dispatchKeySet, const at::Tensor & log_probs, const at::Tensor & targets, const at::Tensor & input_lengths, const at::Tensor & target_lengths, int64_t blank=0, int64_t reduction=at::Reduction::Mean, bool zero_infinity=false) { + return at::_ops::ctc_loss_Tensor::redispatch(dispatchKeySet, log_probs, 
targets, input_lengths, target_lengths, blank, reduction, zero_infinity); + } + + // aten::_ctc_loss(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank=0, bool zero_infinity=False) -> (Tensor, Tensor) + inline ::std::tuple _ctc_loss(c10::DispatchKeySet dispatchKeySet, const at::Tensor & log_probs, const at::Tensor & targets, at::IntArrayRef input_lengths, at::IntArrayRef target_lengths, int64_t blank=0, bool zero_infinity=false) { + return at::_ops::_ctc_loss::redispatch(dispatchKeySet, log_probs, targets, input_lengths, target_lengths, blank, zero_infinity); + } + + // aten::_ctc_loss.Tensor(Tensor log_probs, Tensor targets, Tensor input_lengths, Tensor target_lengths, int blank=0, bool zero_infinity=False) -> (Tensor, Tensor) + inline ::std::tuple _ctc_loss(c10::DispatchKeySet dispatchKeySet, const at::Tensor & log_probs, const at::Tensor & targets, const at::Tensor & input_lengths, const at::Tensor & target_lengths, int64_t blank=0, bool zero_infinity=false) { + return at::_ops::_ctc_loss_Tensor::redispatch(dispatchKeySet, log_probs, targets, input_lengths, target_lengths, blank, zero_infinity); + } + + // aten::_ctc_loss_backward(Tensor grad, Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, Tensor neg_log_likelihood, Tensor log_alpha, int blank, bool zero_infinity=False) -> Tensor + inline at::Tensor _ctc_loss_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & log_probs, const at::Tensor & targets, at::IntArrayRef input_lengths, at::IntArrayRef target_lengths, const at::Tensor & neg_log_likelihood, const at::Tensor & log_alpha, int64_t blank, bool zero_infinity=false) { + return at::_ops::_ctc_loss_backward::redispatch(dispatchKeySet, grad, log_probs, targets, input_lengths, target_lengths, neg_log_likelihood, log_alpha, blank, zero_infinity); + } + + // aten::_ctc_loss_backward.Tensor(Tensor grad, Tensor log_probs, Tensor targets, Tensor input_lengths, Tensor 
target_lengths, Tensor neg_log_likelihood, Tensor log_alpha, int blank, bool zero_infinity=False) -> Tensor + inline at::Tensor _ctc_loss_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & log_probs, const at::Tensor & targets, const at::Tensor & input_lengths, const at::Tensor & target_lengths, const at::Tensor & neg_log_likelihood, const at::Tensor & log_alpha, int64_t blank, bool zero_infinity=false) { + return at::_ops::_ctc_loss_backward_Tensor::redispatch(dispatchKeySet, grad, log_probs, targets, input_lengths, target_lengths, neg_log_likelihood, log_alpha, blank, zero_infinity); + } + + // aten::diag_embed(Tensor self, int offset=0, int dim1=-2, int dim2=-1) -> Tensor + inline at::Tensor diag_embed(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t offset=0, int64_t dim1=-2, int64_t dim2=-1) { + return at::_ops::diag_embed::redispatch(dispatchKeySet, self, offset, dim1, dim2); + } + + // aten::diagflat(Tensor self, int offset=0) -> Tensor + inline at::Tensor diagflat(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t offset=0) { + return at::_ops::diagflat::redispatch(dispatchKeySet, self, offset); + } + + // aten::diagonal(Tensor(a) self, int offset=0, int dim1=0, int dim2=1) -> Tensor(a) + inline at::Tensor diagonal(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t offset=0, int64_t dim1=0, int64_t dim2=1) { + return at::_ops::diagonal::redispatch(dispatchKeySet, self, offset, dim1, dim2); + } + + // aten::linalg_diagonal(Tensor(a) A, *, int offset=0, int dim1=-2, int dim2=-1) -> Tensor(a) + inline at::Tensor linalg_diagonal(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, int64_t offset=0, int64_t dim1=-2, int64_t dim2=-1) { + return at::_ops::linalg_diagonal::redispatch(dispatchKeySet, A, offset, dim1, dim2); + } + + // aten::diagonal.Dimname(Tensor(a) self, *, Dimname outdim, Dimname dim1, Dimname dim2, int offset=0) -> Tensor(a) + inline at::Tensor 
diagonal(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname outdim, at::Dimname dim1, at::Dimname dim2, int64_t offset=0) { + return at::_ops::diagonal_Dimname::redispatch(dispatchKeySet, self, outdim, dim1, dim2, offset); + } + + // aten::diagonal_backward(Tensor grad_output, SymInt[] input_sizes, int offset, int dim1, int dim2) -> Tensor + inline at::Tensor diagonal_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef input_sizes, int64_t offset, int64_t dim1, int64_t dim2) { + return at::_ops::diagonal_backward::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(input_sizes), offset, dim1, dim2); + } + + // aten::diagonal_backward(Tensor grad_output, SymInt[] input_sizes, int offset, int dim1, int dim2) -> Tensor + inline at::Tensor diagonal_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef input_sizes, int64_t offset, int64_t dim1, int64_t dim2) { + return at::_ops::diagonal_backward::redispatch(dispatchKeySet, grad_output, input_sizes, offset, dim1, dim2); + } + + // aten::fill_diagonal_(Tensor(a!) self, Scalar fill_value, bool wrap=False) -> Tensor(a!) + inline at::Tensor & fill_diagonal_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & fill_value, bool wrap=false) { + return at::_ops::fill_diagonal_::redispatch(dispatchKeySet, self, fill_value, wrap); + } + + // aten::diff(Tensor self, int n=1, int dim=-1, Tensor? prepend=None, Tensor? append=None) -> Tensor + inline at::Tensor diff(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t n=1, int64_t dim=-1, const ::std::optional & prepend={}, const ::std::optional & append={}) { + return at::_ops::diff::redispatch(dispatchKeySet, self, n, dim, prepend, append); + } + + // aten::diff.out(Tensor self, int n=1, int dim=-1, Tensor? prepend=None, Tensor? append=None, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & diff_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t n=1, int64_t dim=-1, const ::std::optional & prepend={}, const ::std::optional & append={}) { + return at::_ops::diff_out::redispatch(dispatchKeySet, self, n, dim, prepend, append, out); + } + + // aten::diff.out(Tensor self, int n=1, int dim=-1, Tensor? prepend=None, Tensor? append=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & diff_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t n, int64_t dim, const ::std::optional & prepend, const ::std::optional & append, at::Tensor & out) { + return at::_ops::diff_out::redispatch(dispatchKeySet, self, n, dim, prepend, append, out); + } + + // aten::gradient.scalarint(Tensor self, *, Scalar? spacing=None, int? dim=None, int edge_order=1) -> Tensor[] + inline ::std::vector gradient(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const ::std::optional & spacing=::std::nullopt, ::std::optional dim=::std::nullopt, int64_t edge_order=1) { + return at::_ops::gradient_scalarint::redispatch(dispatchKeySet, self, spacing, dim, edge_order); + } + + // aten::gradient.scalararray(Tensor self, *, Scalar spacing, int[] dim, int edge_order=1) -> Tensor[] + inline ::std::vector gradient(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & spacing, at::IntArrayRef dim, int64_t edge_order=1) { + return at::_ops::gradient_scalararray::redispatch(dispatchKeySet, self, spacing, dim, edge_order); + } + + // aten::gradient.array(Tensor self, *, int[] dim, int edge_order=1) -> Tensor[] + inline ::std::vector gradient(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim, int64_t edge_order=1) { + return at::_ops::gradient_array::redispatch(dispatchKeySet, self, dim, edge_order); + } + + // aten::gradient.scalarrayint(Tensor self, *, Scalar[] spacing, int? 
dim=None, int edge_order=1) -> Tensor[] + inline ::std::vector gradient(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::ArrayRef spacing, ::std::optional dim=::std::nullopt, int64_t edge_order=1) { + return at::_ops::gradient_scalarrayint::redispatch(dispatchKeySet, self, spacing, dim, edge_order); + } + + // aten::gradient.scalarrayarray(Tensor self, *, Scalar[] spacing, int[] dim, int edge_order=1) -> Tensor[] + inline ::std::vector gradient(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::ArrayRef spacing, at::IntArrayRef dim, int64_t edge_order=1) { + return at::_ops::gradient_scalarrayarray::redispatch(dispatchKeySet, self, spacing, dim, edge_order); + } + + // aten::gradient.tensorarrayint(Tensor self, *, Tensor[] spacing, int? dim=None, int edge_order=1) -> Tensor[] + inline ::std::vector gradient(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::TensorList spacing, ::std::optional dim=::std::nullopt, int64_t edge_order=1) { + return at::_ops::gradient_tensorarrayint::redispatch(dispatchKeySet, self, spacing, dim, edge_order); + } + + // aten::gradient.tensorarray(Tensor self, *, Tensor[] spacing, int[] dim, int edge_order=1) -> Tensor[] + inline ::std::vector gradient(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::TensorList spacing, at::IntArrayRef dim, int64_t edge_order=1) { + return at::_ops::gradient_tensorarray::redispatch(dispatchKeySet, self, spacing, dim, edge_order); + } + + // aten::div.Tensor(Tensor self, Tensor other) -> Tensor + inline at::Tensor div(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::div_Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::div_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) 
+ inline at::Tensor & div_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) { + return at::_ops::div__Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::div.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & div_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::div_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::div.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & div_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::div_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::div.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> Tensor + inline at::Tensor div(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, ::std::optional rounding_mode) { + return at::_ops::div_Tensor_mode::redispatch(dispatchKeySet, self, other, rounding_mode); + } + + // aten::div_.Tensor_mode(Tensor(a!) self, Tensor other, *, str? rounding_mode) -> Tensor(a!) + inline at::Tensor & div_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other, ::std::optional rounding_mode) { + return at::_ops::div__Tensor_mode::redispatch(dispatchKeySet, self, other, rounding_mode); + } + + // aten::div.out_mode(Tensor self, Tensor other, *, str? rounding_mode, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & div_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other, ::std::optional rounding_mode) { + return at::_ops::div_out_mode::redispatch(dispatchKeySet, self, other, rounding_mode, out); + } + + // aten::div.out_mode(Tensor self, Tensor other, *, str? rounding_mode, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & div_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, ::std::optional rounding_mode, at::Tensor & out) { + return at::_ops::div_out_mode::redispatch(dispatchKeySet, self, other, rounding_mode, out); + } + + // aten::div.Scalar(Tensor self, Scalar other) -> Tensor + inline at::Tensor div(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::div_Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::div_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + inline at::Tensor & div_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) { + return at::_ops::div__Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::div.Scalar_mode(Tensor self, Scalar other, *, str? rounding_mode) -> Tensor + inline at::Tensor div(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, ::std::optional rounding_mode) { + return at::_ops::div_Scalar_mode::redispatch(dispatchKeySet, self, other, rounding_mode); + } + + // aten::div_.Scalar_mode(Tensor(a!) self, Scalar other, *, str? rounding_mode) -> Tensor(a!) + inline at::Tensor & div_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other, ::std::optional rounding_mode) { + return at::_ops::div__Scalar_mode::redispatch(dispatchKeySet, self, other, rounding_mode); + } + + // aten::divide.Tensor(Tensor self, Tensor other) -> Tensor + inline at::Tensor divide(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::divide_Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::divide_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) 
+ inline at::Tensor & divide_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) { + return at::_ops::divide__Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::divide.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & divide_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::divide_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::divide.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & divide_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::divide_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::divide.Scalar(Tensor self, Scalar other) -> Tensor + inline at::Tensor divide(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::divide_Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::divide_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + inline at::Tensor & divide_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) { + return at::_ops::divide__Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::divide.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> Tensor + inline at::Tensor divide(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, ::std::optional rounding_mode) { + return at::_ops::divide_Tensor_mode::redispatch(dispatchKeySet, self, other, rounding_mode); + } + + // aten::divide_.Tensor_mode(Tensor(a!) self, Tensor other, *, str? rounding_mode) -> Tensor(a!) 
+ inline at::Tensor & divide_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other, ::std::optional rounding_mode) { + return at::_ops::divide__Tensor_mode::redispatch(dispatchKeySet, self, other, rounding_mode); + } + + // aten::divide.out_mode(Tensor self, Tensor other, *, str? rounding_mode, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & divide_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other, ::std::optional rounding_mode) { + return at::_ops::divide_out_mode::redispatch(dispatchKeySet, self, other, rounding_mode, out); + } + + // aten::divide.out_mode(Tensor self, Tensor other, *, str? rounding_mode, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & divide_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, ::std::optional rounding_mode, at::Tensor & out) { + return at::_ops::divide_out_mode::redispatch(dispatchKeySet, self, other, rounding_mode, out); + } + + // aten::divide.Scalar_mode(Tensor self, Scalar other, *, str? rounding_mode) -> Tensor + inline at::Tensor divide(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, ::std::optional rounding_mode) { + return at::_ops::divide_Scalar_mode::redispatch(dispatchKeySet, self, other, rounding_mode); + } + + // aten::divide_.Scalar_mode(Tensor(a!) self, Scalar other, *, str? rounding_mode) -> Tensor(a!) 
+ inline at::Tensor & divide_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other, ::std::optional rounding_mode) { + return at::_ops::divide__Scalar_mode::redispatch(dispatchKeySet, self, other, rounding_mode); + } + + // aten::true_divide.Tensor(Tensor self, Tensor other) -> Tensor + inline at::Tensor true_divide(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::true_divide_Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::true_divide_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + inline at::Tensor & true_divide_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) { + return at::_ops::true_divide__Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::true_divide.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & true_divide_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::true_divide_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::true_divide.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & true_divide_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::true_divide_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::true_divide.Scalar(Tensor self, Scalar other) -> Tensor + inline at::Tensor true_divide(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::true_divide_Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::true_divide_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) 
+ inline at::Tensor & true_divide_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) { + return at::_ops::true_divide__Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::dot(Tensor self, Tensor tensor) -> Tensor + inline at::Tensor dot(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & tensor) { + return at::_ops::dot::redispatch(dispatchKeySet, self, tensor); + } + + // aten::dot.out(Tensor self, Tensor tensor, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & dot_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & tensor) { + return at::_ops::dot_out::redispatch(dispatchKeySet, self, tensor, out); + } + + // aten::dot.out(Tensor self, Tensor tensor, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & dot_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & tensor, at::Tensor & out) { + return at::_ops::dot_out::redispatch(dispatchKeySet, self, tensor, out); + } + + // aten::vdot(Tensor self, Tensor other) -> Tensor + inline at::Tensor vdot(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::vdot::redispatch(dispatchKeySet, self, other); + } + + // aten::vdot.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & vdot_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::vdot_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::vdot.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & vdot_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::vdot_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::einsum(str equation, Tensor[] tensors, *, int[]? 
path=None) -> Tensor + inline at::Tensor einsum(c10::DispatchKeySet dispatchKeySet, c10::string_view equation, at::TensorList tensors, at::OptionalIntArrayRef path=::std::nullopt) { + return at::_ops::einsum::redispatch(dispatchKeySet, equation, tensors, path); + } + + // aten::embedding(Tensor weight, Tensor indices, SymInt padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -> Tensor + inline at::Tensor embedding(c10::DispatchKeySet dispatchKeySet, const at::Tensor & weight, const at::Tensor & indices, int64_t padding_idx=-1, bool scale_grad_by_freq=false, bool sparse=false) { + return at::_ops::embedding::redispatch(dispatchKeySet, weight, indices, padding_idx, scale_grad_by_freq, sparse); + } + + // aten::embedding(Tensor weight, Tensor indices, SymInt padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -> Tensor + inline at::Tensor embedding_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & weight, const at::Tensor & indices, c10::SymInt padding_idx=-1, bool scale_grad_by_freq=false, bool sparse=false) { + return at::_ops::embedding::redispatch(dispatchKeySet, weight, indices, padding_idx, scale_grad_by_freq, sparse); + } + + // aten::embedding_backward(Tensor grad, Tensor indices, SymInt num_weights, SymInt padding_idx, bool scale_grad_by_freq, bool sparse) -> Tensor + inline at::Tensor embedding_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & indices, int64_t num_weights, int64_t padding_idx, bool scale_grad_by_freq, bool sparse) { + return at::_ops::embedding_backward::redispatch(dispatchKeySet, grad, indices, num_weights, padding_idx, scale_grad_by_freq, sparse); + } + + // aten::embedding_backward(Tensor grad, Tensor indices, SymInt num_weights, SymInt padding_idx, bool scale_grad_by_freq, bool sparse) -> Tensor + inline at::Tensor embedding_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & indices, c10::SymInt num_weights, 
c10::SymInt padding_idx, bool scale_grad_by_freq, bool sparse) { + return at::_ops::embedding_backward::redispatch(dispatchKeySet, grad, indices, num_weights, padding_idx, scale_grad_by_freq, sparse); + } + + // aten::embedding_dense_backward(Tensor grad_output, Tensor indices, SymInt num_weights, SymInt padding_idx, bool scale_grad_by_freq) -> Tensor + inline at::Tensor embedding_dense_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & indices, int64_t num_weights, int64_t padding_idx, bool scale_grad_by_freq) { + return at::_ops::embedding_dense_backward::redispatch(dispatchKeySet, grad_output, indices, num_weights, padding_idx, scale_grad_by_freq); + } + + // aten::embedding_dense_backward(Tensor grad_output, Tensor indices, SymInt num_weights, SymInt padding_idx, bool scale_grad_by_freq) -> Tensor + inline at::Tensor embedding_dense_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & indices, c10::SymInt num_weights, c10::SymInt padding_idx, bool scale_grad_by_freq) { + return at::_ops::embedding_dense_backward::redispatch(dispatchKeySet, grad_output, indices, num_weights, padding_idx, scale_grad_by_freq); + } + + // aten::embedding_renorm_(Tensor(a!) self, Tensor indices, float max_norm, float norm_type) -> Tensor(a!) 
+ inline at::Tensor & embedding_renorm_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & indices, double max_norm, double norm_type) { + return at::_ops::embedding_renorm_::redispatch(dispatchKeySet, self, indices, max_norm, norm_type); + } + + // aten::embedding_sparse_backward(Tensor grad, Tensor indices, int num_weights, int padding_idx, bool scale_grad_by_freq) -> Tensor + inline at::Tensor embedding_sparse_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & indices, int64_t num_weights, int64_t padding_idx, bool scale_grad_by_freq) { + return at::_ops::embedding_sparse_backward::redispatch(dispatchKeySet, grad, indices, num_weights, padding_idx, scale_grad_by_freq); + } + + // aten::_embedding_bag_forward_only(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False, int padding_idx=-1) -> (Tensor, Tensor, Tensor, Tensor) + inline ::std::tuple _embedding_bag_forward_only(c10::DispatchKeySet dispatchKeySet, const at::Tensor & weight, const at::Tensor & indices, const at::Tensor & offsets, bool scale_grad_by_freq=false, int64_t mode=0, bool sparse=false, const ::std::optional & per_sample_weights={}, bool include_last_offset=false, int64_t padding_idx=-1) { + return at::_ops::_embedding_bag_forward_only::redispatch(dispatchKeySet, weight, indices, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights, include_last_offset, padding_idx); + } + + // aten::_rowwise_prune(Tensor weight, Tensor mask, ScalarType compressed_indices_dtype) -> (Tensor, Tensor) + inline ::std::tuple _rowwise_prune(c10::DispatchKeySet dispatchKeySet, const at::Tensor & weight, const at::Tensor & mask, at::ScalarType compressed_indices_dtype) { + return at::_ops::_rowwise_prune::redispatch(dispatchKeySet, weight, mask, compressed_indices_dtype); + } + + // aten::row_stack(Tensor[] tensors) -> Tensor + inline 
at::Tensor row_stack(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors) { + return at::_ops::row_stack::redispatch(dispatchKeySet, tensors); + } + + // aten::row_stack.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & row_stack_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::TensorList tensors) { + return at::_ops::row_stack_out::redispatch(dispatchKeySet, tensors, out); + } + + // aten::row_stack.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & row_stack_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors, at::Tensor & out) { + return at::_ops::row_stack_out::redispatch(dispatchKeySet, tensors, out); + } + + // aten::embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False) -> (Tensor, Tensor, Tensor, Tensor) + inline ::std::tuple embedding_bag(c10::DispatchKeySet dispatchKeySet, const at::Tensor & weight, const at::Tensor & indices, const at::Tensor & offsets, bool scale_grad_by_freq=false, int64_t mode=0, bool sparse=false, const ::std::optional & per_sample_weights={}, bool include_last_offset=false) { + return at::_ops::embedding_bag::redispatch(dispatchKeySet, weight, indices, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights, include_last_offset); + } + + // aten::embedding_bag.padding_idx(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq, int mode, bool sparse, Tensor? per_sample_weights, bool include_last_offset, int? 
padding_idx) -> (Tensor, Tensor, Tensor, Tensor) + inline ::std::tuple embedding_bag(c10::DispatchKeySet dispatchKeySet, const at::Tensor & weight, const at::Tensor & indices, const at::Tensor & offsets, bool scale_grad_by_freq, int64_t mode, bool sparse, const ::std::optional & per_sample_weights, bool include_last_offset, ::std::optional padding_idx) { + return at::_ops::embedding_bag_padding_idx::redispatch(dispatchKeySet, weight, indices, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights, include_last_offset, padding_idx); + } + + // aten::_embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False, int padding_idx=-1) -> (Tensor, Tensor, Tensor, Tensor) + inline ::std::tuple _embedding_bag(c10::DispatchKeySet dispatchKeySet, const at::Tensor & weight, const at::Tensor & indices, const at::Tensor & offsets, bool scale_grad_by_freq=false, int64_t mode=0, bool sparse=false, const ::std::optional & per_sample_weights={}, bool include_last_offset=false, int64_t padding_idx=-1) { + return at::_ops::_embedding_bag::redispatch(dispatchKeySet, weight, indices, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights, include_last_offset, padding_idx); + } + + // aten::_embedding_bag_backward(Tensor grad, Tensor indices, Tensor offsets, Tensor offset2bag, Tensor bag_size, Tensor maximum_indices, SymInt num_weights, bool scale_grad_by_freq, int mode, bool sparse, Tensor? 
per_sample_weights, int padding_idx=-1) -> Tensor + inline at::Tensor _embedding_bag_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & indices, const at::Tensor & offsets, const at::Tensor & offset2bag, const at::Tensor & bag_size, const at::Tensor & maximum_indices, int64_t num_weights, bool scale_grad_by_freq, int64_t mode, bool sparse, const ::std::optional & per_sample_weights, int64_t padding_idx=-1) { + return at::_ops::_embedding_bag_backward::redispatch(dispatchKeySet, grad, indices, offsets, offset2bag, bag_size, maximum_indices, num_weights, scale_grad_by_freq, mode, sparse, per_sample_weights, padding_idx); + } + + // aten::_embedding_bag_backward(Tensor grad, Tensor indices, Tensor offsets, Tensor offset2bag, Tensor bag_size, Tensor maximum_indices, SymInt num_weights, bool scale_grad_by_freq, int mode, bool sparse, Tensor? per_sample_weights, int padding_idx=-1) -> Tensor + inline at::Tensor _embedding_bag_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & indices, const at::Tensor & offsets, const at::Tensor & offset2bag, const at::Tensor & bag_size, const at::Tensor & maximum_indices, c10::SymInt num_weights, bool scale_grad_by_freq, int64_t mode, bool sparse, const ::std::optional & per_sample_weights, int64_t padding_idx=-1) { + return at::_ops::_embedding_bag_backward::redispatch(dispatchKeySet, grad, indices, offsets, offset2bag, bag_size, maximum_indices, num_weights, scale_grad_by_freq, mode, sparse, per_sample_weights, padding_idx); + } + + // aten::_embedding_bag_sparse_backward(Tensor grad, Tensor indices, Tensor offsets, Tensor offset2bag, Tensor bag_size, SymInt num_weights, bool scale_grad_by_freq, int mode, Tensor? 
per_sample_weights, int padding_idx=-1) -> Tensor + inline at::Tensor _embedding_bag_sparse_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & indices, const at::Tensor & offsets, const at::Tensor & offset2bag, const at::Tensor & bag_size, int64_t num_weights, bool scale_grad_by_freq, int64_t mode, const ::std::optional & per_sample_weights, int64_t padding_idx=-1) { + return at::_ops::_embedding_bag_sparse_backward::redispatch(dispatchKeySet, grad, indices, offsets, offset2bag, bag_size, num_weights, scale_grad_by_freq, mode, per_sample_weights, padding_idx); + } + + // aten::_embedding_bag_sparse_backward(Tensor grad, Tensor indices, Tensor offsets, Tensor offset2bag, Tensor bag_size, SymInt num_weights, bool scale_grad_by_freq, int mode, Tensor? per_sample_weights, int padding_idx=-1) -> Tensor + inline at::Tensor _embedding_bag_sparse_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & indices, const at::Tensor & offsets, const at::Tensor & offset2bag, const at::Tensor & bag_size, c10::SymInt num_weights, bool scale_grad_by_freq, int64_t mode, const ::std::optional & per_sample_weights, int64_t padding_idx=-1) { + return at::_ops::_embedding_bag_sparse_backward::redispatch(dispatchKeySet, grad, indices, offsets, offset2bag, bag_size, num_weights, scale_grad_by_freq, mode, per_sample_weights, padding_idx); + } + + // aten::_embedding_bag_dense_backward(Tensor grad, Tensor indices, Tensor offset2bag, Tensor bag_size, Tensor maximum_indices, SymInt num_weights, bool scale_grad_by_freq, int mode, Tensor? 
per_sample_weights, int padding_idx=-1) -> Tensor + inline at::Tensor _embedding_bag_dense_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & indices, const at::Tensor & offset2bag, const at::Tensor & bag_size, const at::Tensor & maximum_indices, int64_t num_weights, bool scale_grad_by_freq, int64_t mode, const ::std::optional & per_sample_weights, int64_t padding_idx=-1) { + return at::_ops::_embedding_bag_dense_backward::redispatch(dispatchKeySet, grad, indices, offset2bag, bag_size, maximum_indices, num_weights, scale_grad_by_freq, mode, per_sample_weights, padding_idx); + } + + // aten::_embedding_bag_dense_backward(Tensor grad, Tensor indices, Tensor offset2bag, Tensor bag_size, Tensor maximum_indices, SymInt num_weights, bool scale_grad_by_freq, int mode, Tensor? per_sample_weights, int padding_idx=-1) -> Tensor + inline at::Tensor _embedding_bag_dense_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & indices, const at::Tensor & offset2bag, const at::Tensor & bag_size, const at::Tensor & maximum_indices, c10::SymInt num_weights, bool scale_grad_by_freq, int64_t mode, const ::std::optional & per_sample_weights, int64_t padding_idx=-1) { + return at::_ops::_embedding_bag_dense_backward::redispatch(dispatchKeySet, grad, indices, offset2bag, bag_size, maximum_indices, num_weights, scale_grad_by_freq, mode, per_sample_weights, padding_idx); + } + + // aten::_embedding_bag_per_sample_weights_backward(Tensor grad, Tensor weight, Tensor indices, Tensor offsets, Tensor offset2bag, int mode, int padding_idx=-1) -> Tensor + inline at::Tensor _embedding_bag_per_sample_weights_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & weight, const at::Tensor & indices, const at::Tensor & offsets, const at::Tensor & offset2bag, int64_t mode, int64_t padding_idx=-1) { + return at::_ops::_embedding_bag_per_sample_weights_backward::redispatch(dispatchKeySet, 
grad, weight, indices, offsets, offset2bag, mode, padding_idx); + } + + // aten::empty.names(int[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + inline at::Tensor empty(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, ::std::optional names, at::TensorOptions options={}, ::std::optional memory_format=::std::nullopt) { + return at::_ops::empty_names::redispatch(dispatchKeySet, size, names, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format)); + } + + // aten::empty.names(int[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + inline at::Tensor empty(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, ::std::optional names, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format) { + return at::_ops::empty_names::redispatch(dispatchKeySet, size, names, dtype, layout, device, pin_memory, memory_format); + } + + // aten::empty.memory_format(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? 
memory_format=None) -> Tensor + inline at::Tensor empty(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, at::TensorOptions options={}, ::std::optional memory_format=::std::nullopt) { + return at::_ops::empty_memory_format::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format)); + } + + // aten::empty.memory_format(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + inline at::Tensor empty(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format) { + return at::_ops::empty_memory_format::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), dtype, layout, device, pin_memory, memory_format); + } + + // aten::empty.memory_format(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + inline at::Tensor empty_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, at::TensorOptions options={}, ::std::optional memory_format=::std::nullopt) { + return at::_ops::empty_memory_format::redispatch(dispatchKeySet, size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format)); + } + + // aten::empty.memory_format(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? 
memory_format=None) -> Tensor + inline at::Tensor empty_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format) { + return at::_ops::empty_memory_format::redispatch(dispatchKeySet, size, dtype, layout, device, pin_memory, memory_format); + } + + // aten::empty_permuted(SymInt[] size, int[] physical_layout, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor empty_permuted(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, at::IntArrayRef physical_layout, at::TensorOptions options={}) { + return at::_ops::empty_permuted::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), physical_layout, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::empty_permuted(SymInt[] size, int[] physical_layout, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor empty_permuted(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, at::IntArrayRef physical_layout, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::empty_permuted::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), physical_layout, dtype, layout, device, pin_memory); + } + + // aten::empty_permuted(SymInt[] size, int[] physical_layout, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor + inline at::Tensor empty_permuted_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, at::IntArrayRef physical_layout, at::TensorOptions options={}) { + return at::_ops::empty_permuted::redispatch(dispatchKeySet, size, physical_layout, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::empty_permuted(SymInt[] size, int[] physical_layout, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor empty_permuted_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, at::IntArrayRef physical_layout, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::empty_permuted::redispatch(dispatchKeySet, size, physical_layout, dtype, layout, device, pin_memory); + } + + // aten::new_empty(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor new_empty(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, at::TensorOptions options={}) { + return at::_ops::new_empty::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::new_empty(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor + inline at::Tensor new_empty(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::new_empty::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), dtype, layout, device, pin_memory); + } + + // aten::new_empty(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor new_empty_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, at::TensorOptions options={}) { + return at::_ops::new_empty::redispatch(dispatchKeySet, self, size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::new_empty(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor new_empty_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::new_empty::redispatch(dispatchKeySet, self, size, dtype, layout, device, pin_memory); + } + + // aten::new_empty_strided(Tensor self, SymInt[] size, SymInt[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor + inline at::Tensor new_empty_strided(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, at::IntArrayRef stride, at::TensorOptions options={}) { + return at::_ops::new_empty_strided::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::new_empty_strided(Tensor self, SymInt[] size, SymInt[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor new_empty_strided(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, at::IntArrayRef stride, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::new_empty_strided::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), dtype, layout, device, pin_memory); + } + + // aten::new_empty_strided(Tensor self, SymInt[] size, SymInt[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor new_empty_strided_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, at::TensorOptions options={}) { + return at::_ops::new_empty_strided::redispatch(dispatchKeySet, self, size, stride, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::new_empty_strided(Tensor self, SymInt[] size, SymInt[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor + inline at::Tensor new_empty_strided_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::new_empty_strided::redispatch(dispatchKeySet, self, size, stride, dtype, layout, device, pin_memory); + } + + // aten::new_full(Tensor self, SymInt[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor new_full(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, const at::Scalar & fill_value, at::TensorOptions options={}) { + return at::_ops::new_full::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), fill_value, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::new_full(Tensor self, SymInt[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor new_full(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, const at::Scalar & fill_value, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::new_full::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), fill_value, dtype, layout, device, pin_memory); + } + + // aten::new_full(Tensor self, SymInt[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor + inline at::Tensor new_full_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, const at::Scalar & fill_value, at::TensorOptions options={}) { + return at::_ops::new_full::redispatch(dispatchKeySet, self, size, fill_value, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::new_full(Tensor self, SymInt[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor new_full_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, const at::Scalar & fill_value, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::new_full::redispatch(dispatchKeySet, self, size, fill_value, dtype, layout, device, pin_memory); + } + + // aten::new_zeros(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor new_zeros(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, at::TensorOptions options={}) { + return at::_ops::new_zeros::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::new_zeros(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor + inline at::Tensor new_zeros(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::new_zeros::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), dtype, layout, device, pin_memory); + } + + // aten::new_zeros(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor new_zeros_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, at::TensorOptions options={}) { + return at::_ops::new_zeros::redispatch(dispatchKeySet, self, size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::new_zeros(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor new_zeros_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::new_zeros::redispatch(dispatchKeySet, self, size, dtype, layout, device, pin_memory); + } + + // aten::new_ones(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor new_ones(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, at::TensorOptions options={}) { + return at::_ops::new_ones::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::new_ones(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? 
device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor new_ones(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::new_ones::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), dtype, layout, device, pin_memory); + } + + // aten::new_ones(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor new_ones_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, at::TensorOptions options={}) { + return at::_ops::new_ones::redispatch(dispatchKeySet, self, size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::new_ones(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor new_ones_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::new_ones::redispatch(dispatchKeySet, self, size, dtype, layout, device, pin_memory); + } + + // aten::_empty_affine_quantized(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, float scale=1, int zero_point=0, MemoryFormat? 
memory_format=contiguous_format) -> Tensor + inline at::Tensor _empty_affine_quantized(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, at::TensorOptions options={}, double scale=1, int64_t zero_point=0, ::std::optional memory_format=c10::MemoryFormat::Contiguous) { + return at::_ops::_empty_affine_quantized::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), scale, zero_point, c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format)); + } + + // aten::_empty_affine_quantized(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, float scale=1, int zero_point=0, MemoryFormat? memory_format=contiguous_format) -> Tensor + inline at::Tensor _empty_affine_quantized(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, double scale, int64_t zero_point, ::std::optional memory_format) { + return at::_ops::_empty_affine_quantized::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), dtype, layout, device, pin_memory, scale, zero_point, memory_format); + } + + // aten::_empty_affine_quantized(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, float scale=1, int zero_point=0, MemoryFormat? 
memory_format=contiguous_format) -> Tensor + inline at::Tensor _empty_affine_quantized_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, at::TensorOptions options={}, double scale=1, int64_t zero_point=0, ::std::optional memory_format=c10::MemoryFormat::Contiguous) { + return at::_ops::_empty_affine_quantized::redispatch(dispatchKeySet, size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), scale, zero_point, c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format)); + } + + // aten::_empty_affine_quantized(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, float scale=1, int zero_point=0, MemoryFormat? memory_format=contiguous_format) -> Tensor + inline at::Tensor _empty_affine_quantized_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, double scale, int64_t zero_point, ::std::optional memory_format) { + return at::_ops::_empty_affine_quantized::redispatch(dispatchKeySet, size, dtype, layout, device, pin_memory, scale, zero_point, memory_format); + } + + // aten::_empty_per_channel_affine_quantized(SymInt[] size, *, Tensor scales, Tensor zero_points, int axis, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? 
memory_format=contiguous_format) -> Tensor + inline at::Tensor _empty_per_channel_affine_quantized(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, at::TensorOptions options={}, ::std::optional memory_format=c10::MemoryFormat::Contiguous) { + return at::_ops::_empty_per_channel_affine_quantized::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), scales, zero_points, axis, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format)); + } + + // aten::_empty_per_channel_affine_quantized(SymInt[] size, *, Tensor scales, Tensor zero_points, int axis, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=contiguous_format) -> Tensor + inline at::Tensor _empty_per_channel_affine_quantized(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format) { + return at::_ops::_empty_per_channel_affine_quantized::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), scales, zero_points, axis, dtype, layout, device, pin_memory, memory_format); + } + + // aten::_empty_per_channel_affine_quantized(SymInt[] size, *, Tensor scales, Tensor zero_points, int axis, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? 
memory_format=contiguous_format) -> Tensor + inline at::Tensor _empty_per_channel_affine_quantized_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, at::TensorOptions options={}, ::std::optional memory_format=c10::MemoryFormat::Contiguous) { + return at::_ops::_empty_per_channel_affine_quantized::redispatch(dispatchKeySet, size, scales, zero_points, axis, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format)); + } + + // aten::_empty_per_channel_affine_quantized(SymInt[] size, *, Tensor scales, Tensor zero_points, int axis, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=contiguous_format) -> Tensor + inline at::Tensor _empty_per_channel_affine_quantized_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format) { + return at::_ops::_empty_per_channel_affine_quantized::redispatch(dispatchKeySet, size, scales, zero_points, axis, dtype, layout, device, pin_memory, memory_format); + } + + // aten::resize_(Tensor(a!) self, SymInt[] size, *, MemoryFormat? memory_format=None) -> Tensor(a!) + inline const at::Tensor & resize_(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, ::std::optional memory_format=::std::nullopt) { + return at::_ops::resize_::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), memory_format); + } + + // aten::resize_(Tensor(a!) self, SymInt[] size, *, MemoryFormat? memory_format=None) -> Tensor(a!) 
+ inline const at::Tensor & resize__symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, ::std::optional memory_format=::std::nullopt) { + return at::_ops::resize_::redispatch(dispatchKeySet, self, size, memory_format); + } + + // aten::_resize_output_(Tensor(a!) self, SymInt[] size, Device device) -> Tensor(a!) + inline const at::Tensor & _resize_output_(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, at::Device device) { + return at::_ops::_resize_output_::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), device); + } + + // aten::_resize_output_(Tensor(a!) self, SymInt[] size, Device device) -> Tensor(a!) + inline const at::Tensor & _resize_output__symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, at::Device device) { + return at::_ops::_resize_output_::redispatch(dispatchKeySet, self, size, device); + } + + // aten::empty_quantized(int[] size, Tensor qtensor, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + inline at::Tensor empty_quantized(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, const at::Tensor & qtensor, at::TensorOptions options={}, ::std::optional memory_format=::std::nullopt) { + return at::_ops::empty_quantized::redispatch(dispatchKeySet, size, qtensor, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format)); + } + + // aten::empty_quantized(int[] size, Tensor qtensor, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? 
memory_format=None) -> Tensor + inline at::Tensor empty_quantized(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, const at::Tensor & qtensor, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format) { + return at::_ops::empty_quantized::redispatch(dispatchKeySet, size, qtensor, dtype, layout, device, pin_memory, memory_format); + } + + // aten::empty.out(SymInt[] size, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & empty_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::IntArrayRef size, ::std::optional memory_format=::std::nullopt) { + return at::_ops::empty_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), memory_format, out); + } + + // aten::empty.out(SymInt[] size, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & empty_outf(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, ::std::optional memory_format, at::Tensor & out) { + return at::_ops::empty_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), memory_format, out); + } + + // aten::empty.out(SymInt[] size, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & empty_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, c10::SymIntArrayRef size, ::std::optional memory_format=::std::nullopt) { + return at::_ops::empty_out::redispatch(dispatchKeySet, size, memory_format, out); + } + + // aten::empty.out(SymInt[] size, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & empty_symint_outf(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, ::std::optional memory_format, at::Tensor & out) { + return at::_ops::empty_out::redispatch(dispatchKeySet, size, memory_format, out); + } + + // aten::empty_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + inline at::Tensor empty_like(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::TensorOptions options={}, ::std::optional memory_format=::std::nullopt) { + return at::_ops::empty_like::redispatch(dispatchKeySet, self, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format)); + } + + // aten::empty_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + inline at::Tensor empty_like(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format) { + return at::_ops::empty_like::redispatch(dispatchKeySet, self, dtype, layout, device, pin_memory, memory_format); + } + + // aten::empty_strided(SymInt[] size, SymInt[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor empty_strided(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, at::IntArrayRef stride, at::TensorOptions options={}) { + return at::_ops::empty_strided::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::empty_strided(SymInt[] size, SymInt[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor + inline at::Tensor empty_strided(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, at::IntArrayRef stride, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::empty_strided::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), dtype, layout, device, pin_memory); + } + + // aten::empty_strided(SymInt[] size, SymInt[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor empty_strided_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, at::TensorOptions options={}) { + return at::_ops::empty_strided::redispatch(dispatchKeySet, size, stride, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::empty_strided(SymInt[] size, SymInt[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor empty_strided_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::empty_strided::redispatch(dispatchKeySet, size, stride, dtype, layout, device, pin_memory); + } + + // aten::erf(Tensor self) -> Tensor + inline at::Tensor erf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::erf::redispatch(dispatchKeySet, self); + } + + // aten::erf_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & erf_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::erf_::redispatch(dispatchKeySet, self); + } + + // aten::erf.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & erf_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::erf_out::redispatch(dispatchKeySet, self, out); + } + + // aten::erf.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & erf_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::erf_out::redispatch(dispatchKeySet, self, out); + } + + // aten::erfc(Tensor self) -> Tensor + inline at::Tensor erfc(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::erfc::redispatch(dispatchKeySet, self); + } + + // aten::erfc_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & erfc_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::erfc_::redispatch(dispatchKeySet, self); + } + + // aten::erfc.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & erfc_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::erfc_out::redispatch(dispatchKeySet, self, out); + } + + // aten::erfc.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & erfc_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::erfc_out::redispatch(dispatchKeySet, self, out); + } + + // aten::exp(Tensor self) -> Tensor + inline at::Tensor exp(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::exp::redispatch(dispatchKeySet, self); + } + + // aten::exp_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & exp_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::exp_::redispatch(dispatchKeySet, self); + } + + // aten::exp.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & exp_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::exp_out::redispatch(dispatchKeySet, self, out); + } + + // aten::exp.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & exp_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::exp_out::redispatch(dispatchKeySet, self, out); + } + + // aten::exp2(Tensor self) -> Tensor + inline at::Tensor exp2(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::exp2::redispatch(dispatchKeySet, self); + } + + // aten::exp2_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & exp2_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::exp2_::redispatch(dispatchKeySet, self); + } + + // aten::exp2.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & exp2_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::exp2_out::redispatch(dispatchKeySet, self, out); + } + + // aten::exp2.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & exp2_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::exp2_out::redispatch(dispatchKeySet, self, out); + } + + // aten::expm1(Tensor self) -> Tensor + inline at::Tensor expm1(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::expm1::redispatch(dispatchKeySet, self); + } + + // aten::expm1_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & expm1_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::expm1_::redispatch(dispatchKeySet, self); + } + + // aten::expm1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & expm1_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::expm1_out::redispatch(dispatchKeySet, self, out); + } + + // aten::expm1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & expm1_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::expm1_out::redispatch(dispatchKeySet, self, out); + } + + // aten::expand(Tensor(a) self, SymInt[] size, *, bool implicit=False) -> Tensor(a) + inline at::Tensor expand(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, bool implicit=false) { + return at::_ops::expand::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), implicit); + } + + // aten::expand(Tensor(a) self, SymInt[] size, *, bool implicit=False) -> Tensor(a) + inline at::Tensor expand_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, bool implicit=false) { + return at::_ops::expand::redispatch(dispatchKeySet, self, size, implicit); + } + + // aten::expand_as(Tensor(a) self, Tensor other) -> Tensor(a) + inline at::Tensor expand_as(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::expand_as::redispatch(dispatchKeySet, self, other); + } + + // aten::eye(SymInt n, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor eye(c10::DispatchKeySet dispatchKeySet, int64_t n, at::TensorOptions options={}) { + return at::_ops::eye::redispatch(dispatchKeySet, n, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::eye(SymInt n, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor + inline at::Tensor eye(c10::DispatchKeySet dispatchKeySet, int64_t n, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::eye::redispatch(dispatchKeySet, n, dtype, layout, device, pin_memory); + } + + // aten::eye(SymInt n, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor eye_symint(c10::DispatchKeySet dispatchKeySet, c10::SymInt n, at::TensorOptions options={}) { + return at::_ops::eye::redispatch(dispatchKeySet, n, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::eye(SymInt n, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor eye_symint(c10::DispatchKeySet dispatchKeySet, c10::SymInt n, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::eye::redispatch(dispatchKeySet, n, dtype, layout, device, pin_memory); + } + + // aten::eye.m(SymInt n, SymInt m, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor eye(c10::DispatchKeySet dispatchKeySet, int64_t n, int64_t m, at::TensorOptions options={}) { + return at::_ops::eye_m::redispatch(dispatchKeySet, n, m, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::eye.m(SymInt n, SymInt m, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor + inline at::Tensor eye(c10::DispatchKeySet dispatchKeySet, int64_t n, int64_t m, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::eye_m::redispatch(dispatchKeySet, n, m, dtype, layout, device, pin_memory); + } + + // aten::eye.m(SymInt n, SymInt m, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor eye_symint(c10::DispatchKeySet dispatchKeySet, c10::SymInt n, c10::SymInt m, at::TensorOptions options={}) { + return at::_ops::eye_m::redispatch(dispatchKeySet, n, m, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::eye.m(SymInt n, SymInt m, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor eye_symint(c10::DispatchKeySet dispatchKeySet, c10::SymInt n, c10::SymInt m, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::eye_m::redispatch(dispatchKeySet, n, m, dtype, layout, device, pin_memory); + } + + // aten::eye.out(SymInt n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & eye_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t n) { + return at::_ops::eye_out::redispatch(dispatchKeySet, n, out); + } + + // aten::eye.out(SymInt n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & eye_outf(c10::DispatchKeySet dispatchKeySet, int64_t n, at::Tensor & out) { + return at::_ops::eye_out::redispatch(dispatchKeySet, n, out); + } + + // aten::eye.out(SymInt n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & eye_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, c10::SymInt n) { + return at::_ops::eye_out::redispatch(dispatchKeySet, n, out); + } + + // aten::eye.out(SymInt n, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & eye_symint_outf(c10::DispatchKeySet dispatchKeySet, c10::SymInt n, at::Tensor & out) { + return at::_ops::eye_out::redispatch(dispatchKeySet, n, out); + } + + // aten::eye.m_out(SymInt n, SymInt m, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & eye_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t n, int64_t m) { + return at::_ops::eye_m_out::redispatch(dispatchKeySet, n, m, out); + } + + // aten::eye.m_out(SymInt n, SymInt m, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & eye_outf(c10::DispatchKeySet dispatchKeySet, int64_t n, int64_t m, at::Tensor & out) { + return at::_ops::eye_m_out::redispatch(dispatchKeySet, n, m, out); + } + + // aten::eye.m_out(SymInt n, SymInt m, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & eye_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, c10::SymInt n, c10::SymInt m) { + return at::_ops::eye_m_out::redispatch(dispatchKeySet, n, m, out); + } + + // aten::eye.m_out(SymInt n, SymInt m, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & eye_symint_outf(c10::DispatchKeySet dispatchKeySet, c10::SymInt n, c10::SymInt m, at::Tensor & out) { + return at::_ops::eye_m_out::redispatch(dispatchKeySet, n, m, out); + } + + // aten::flatten.using_ints(Tensor(a) self, int start_dim=0, int end_dim=-1) -> Tensor(a) + inline at::Tensor flatten(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t start_dim=0, int64_t end_dim=-1) { + return at::_ops::flatten_using_ints::redispatch(dispatchKeySet, self, start_dim, end_dim); + } + + // aten::flatten.named_out_dim(Tensor(a) self, int start_dim, int end_dim, Dimname out_dim) -> Tensor(a) + inline at::Tensor flatten(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t start_dim, int64_t end_dim, at::Dimname out_dim) { + return at::_ops::flatten_named_out_dim::redispatch(dispatchKeySet, self, start_dim, end_dim, out_dim); + } + + // aten::flatten.using_names(Tensor(a) self, Dimname start_dim, Dimname end_dim, Dimname out_dim) -> Tensor(a) + inline at::Tensor flatten(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname start_dim, at::Dimname end_dim, at::Dimname out_dim) { + return at::_ops::flatten_using_names::redispatch(dispatchKeySet, self, start_dim, end_dim, out_dim); + } + + // aten::flatten.DimnameList(Tensor(a) self, Dimname[] dims, Dimname out_dim) -> Tensor(a) + inline at::Tensor flatten(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::DimnameList dims, at::Dimname out_dim) { + return at::_ops::flatten_DimnameList::redispatch(dispatchKeySet, self, dims, out_dim); + } + + // aten::unflatten.int(Tensor(a) self, int dim, SymInt[] sizes) -> Tensor(a) + inline at::Tensor unflatten(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, at::IntArrayRef sizes) { + return at::_ops::unflatten_int::redispatch(dispatchKeySet, self, dim, c10::fromIntArrayRefSlow(sizes)); + } + + // aten::unflatten.int(Tensor(a) self, int dim, SymInt[] sizes) -> Tensor(a) + inline 
at::Tensor unflatten_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, c10::SymIntArrayRef sizes) { + return at::_ops::unflatten_int::redispatch(dispatchKeySet, self, dim, sizes); + } + + // aten::unflatten.Dimname(Tensor(a) self, Dimname dim, SymInt[] sizes, Dimname[] names) -> Tensor(a) + inline at::Tensor unflatten(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, at::IntArrayRef sizes, at::DimnameList names) { + return at::_ops::unflatten_Dimname::redispatch(dispatchKeySet, self, dim, c10::fromIntArrayRefSlow(sizes), names); + } + + // aten::unflatten.Dimname(Tensor(a) self, Dimname dim, SymInt[] sizes, Dimname[] names) -> Tensor(a) + inline at::Tensor unflatten_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, c10::SymIntArrayRef sizes, at::DimnameList names) { + return at::_ops::unflatten_Dimname::redispatch(dispatchKeySet, self, dim, sizes, names); + } + + // aten::fill.Scalar(Tensor self, Scalar value) -> Tensor + inline at::Tensor fill(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & value) { + return at::_ops::fill_Scalar::redispatch(dispatchKeySet, self, value); + } + + // aten::fill.Tensor(Tensor self, Tensor value) -> Tensor + inline at::Tensor fill(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & value) { + return at::_ops::fill_Tensor::redispatch(dispatchKeySet, self, value); + } + + // aten::fill_.Scalar(Tensor(a!) self, Scalar value) -> Tensor(a!) + inline at::Tensor & fill_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & value) { + return at::_ops::fill__Scalar::redispatch(dispatchKeySet, self, value); + } + + // aten::fill_.Tensor(Tensor(a!) self, Tensor value) -> Tensor(a!) 
+ inline at::Tensor & fill_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & value) { + return at::_ops::fill__Tensor::redispatch(dispatchKeySet, self, value); + } + + // aten::floor(Tensor self) -> Tensor + inline at::Tensor floor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::floor::redispatch(dispatchKeySet, self); + } + + // aten::floor_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & floor_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::floor_::redispatch(dispatchKeySet, self); + } + + // aten::floor.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & floor_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::floor_out::redispatch(dispatchKeySet, self, out); + } + + // aten::floor.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & floor_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::floor_out::redispatch(dispatchKeySet, self, out); + } + + // aten::floor_divide(Tensor self, Tensor other) -> Tensor + inline at::Tensor floor_divide(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::floor_divide::redispatch(dispatchKeySet, self, other); + } + + // aten::floor_divide_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + inline at::Tensor & floor_divide_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) { + return at::_ops::floor_divide__Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::floor_divide.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & floor_divide_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::floor_divide_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::floor_divide.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & floor_divide_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::floor_divide_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::floor_divide.Scalar(Tensor self, Scalar other) -> Tensor + inline at::Tensor floor_divide(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::floor_divide_Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::floor_divide_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + inline at::Tensor & floor_divide_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) { + return at::_ops::floor_divide__Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::frac(Tensor self) -> Tensor + inline at::Tensor frac(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::frac::redispatch(dispatchKeySet, self); + } + + // aten::frac_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & frac_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::frac_::redispatch(dispatchKeySet, self); + } + + // aten::frac.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & frac_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::frac_out::redispatch(dispatchKeySet, self, out); + } + + // aten::frac.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & frac_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::frac_out::redispatch(dispatchKeySet, self, out); + } + + // aten::full.names(int[] size, Scalar fill_value, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor full(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, const at::Scalar & fill_value, ::std::optional names, at::TensorOptions options={}) { + return at::_ops::full_names::redispatch(dispatchKeySet, size, fill_value, names, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::full.names(int[] size, Scalar fill_value, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor full(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, const at::Scalar & fill_value, ::std::optional names, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::full_names::redispatch(dispatchKeySet, size, fill_value, names, dtype, layout, device, pin_memory); + } + + // aten::full(SymInt[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor full(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, const at::Scalar & fill_value, at::TensorOptions options={}) { + return at::_ops::full::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), fill_value, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::full(SymInt[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor + inline at::Tensor full(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, const at::Scalar & fill_value, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::full::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), fill_value, dtype, layout, device, pin_memory); + } + + // aten::full(SymInt[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor full_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, const at::Scalar & fill_value, at::TensorOptions options={}) { + return at::_ops::full::redispatch(dispatchKeySet, size, fill_value, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::full(SymInt[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor full_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, const at::Scalar & fill_value, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::full::redispatch(dispatchKeySet, size, fill_value, dtype, layout, device, pin_memory); + } + + // aten::full.out(SymInt[] size, Scalar fill_value, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & full_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::IntArrayRef size, const at::Scalar & fill_value) { + return at::_ops::full_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), fill_value, out); + } + + // aten::full.out(SymInt[] size, Scalar fill_value, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & full_outf(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, const at::Scalar & fill_value, at::Tensor & out) { + return at::_ops::full_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), fill_value, out); + } + + // aten::full.out(SymInt[] size, Scalar fill_value, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & full_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, c10::SymIntArrayRef size, const at::Scalar & fill_value) { + return at::_ops::full_out::redispatch(dispatchKeySet, size, fill_value, out); + } + + // aten::full.out(SymInt[] size, Scalar fill_value, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & full_symint_outf(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, const at::Scalar & fill_value, at::Tensor & out) { + return at::_ops::full_out::redispatch(dispatchKeySet, size, fill_value, out); + } + + // aten::full_like(Tensor self, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + inline at::Tensor full_like(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & fill_value, at::TensorOptions options={}, ::std::optional memory_format=::std::nullopt) { + return at::_ops::full_like::redispatch(dispatchKeySet, self, fill_value, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format)); + } + + // aten::full_like(Tensor self, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? 
memory_format=None) -> Tensor + inline at::Tensor full_like(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & fill_value, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format) { + return at::_ops::full_like::redispatch(dispatchKeySet, self, fill_value, dtype, layout, device, pin_memory, memory_format); + } + + // aten::from_file(str filename, bool? shared=None, int? size=0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor from_file(c10::DispatchKeySet dispatchKeySet, c10::string_view filename, ::std::optional shared=::std::nullopt, ::std::optional size=0, at::TensorOptions options={}) { + return at::_ops::from_file::redispatch(dispatchKeySet, filename, shared, size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::from_file(str filename, bool? shared=None, int? size=0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor from_file(c10::DispatchKeySet dispatchKeySet, c10::string_view filename, ::std::optional shared, ::std::optional size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::from_file::redispatch(dispatchKeySet, filename, shared, size, dtype, layout, device, pin_memory); + } + + // aten::gcd.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & gcd_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::gcd_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::gcd.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & gcd_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::gcd_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::gcd(Tensor self, Tensor other) -> Tensor + inline at::Tensor gcd(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::gcd::redispatch(dispatchKeySet, self, other); + } + + // aten::gcd_(Tensor(a!) self, Tensor other) -> Tensor(a!) + inline at::Tensor & gcd_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) { + return at::_ops::gcd_::redispatch(dispatchKeySet, self, other); + } + + // aten::lcm.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & lcm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::lcm_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::lcm.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & lcm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::lcm_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::lcm(Tensor self, Tensor other) -> Tensor + inline at::Tensor lcm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::lcm::redispatch(dispatchKeySet, self, other); + } + + // aten::lcm_(Tensor(a!) self, Tensor other) -> Tensor(a!) 
+ inline at::Tensor & lcm_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) { + return at::_ops::lcm_::redispatch(dispatchKeySet, self, other); + } + + // aten::grid_sampler(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners) -> Tensor + inline at::Tensor grid_sampler(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners) { + return at::_ops::grid_sampler::redispatch(dispatchKeySet, input, grid, interpolation_mode, padding_mode, align_corners); + } + + // aten::grid_sampler_2d(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners) -> Tensor + inline at::Tensor grid_sampler_2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners) { + return at::_ops::grid_sampler_2d::redispatch(dispatchKeySet, input, grid, interpolation_mode, padding_mode, align_corners); + } + + // aten::grid_sampler_2d_backward(Tensor grad_output, Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners, bool[2] output_mask) -> (Tensor, Tensor) + inline ::std::tuple grid_sampler_2d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners, ::std::array output_mask) { + return at::_ops::grid_sampler_2d_backward::redispatch(dispatchKeySet, grad_output, input, grid, interpolation_mode, padding_mode, align_corners, output_mask); + } + + // aten::_grid_sampler_2d_cpu_fallback(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners) -> Tensor + inline at::Tensor _grid_sampler_2d_cpu_fallback(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & grid, int64_t 
interpolation_mode, int64_t padding_mode, bool align_corners) { + return at::_ops::_grid_sampler_2d_cpu_fallback::redispatch(dispatchKeySet, input, grid, interpolation_mode, padding_mode, align_corners); + } + + // aten::_grid_sampler_2d_cpu_fallback_backward(Tensor grad_output, Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners) -> (Tensor, Tensor) + inline ::std::tuple _grid_sampler_2d_cpu_fallback_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners) { + return at::_ops::_grid_sampler_2d_cpu_fallback_backward::redispatch(dispatchKeySet, grad_output, input, grid, interpolation_mode, padding_mode, align_corners); + } + + // aten::grid_sampler_3d(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners) -> Tensor + inline at::Tensor grid_sampler_3d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners) { + return at::_ops::grid_sampler_3d::redispatch(dispatchKeySet, input, grid, interpolation_mode, padding_mode, align_corners); + } + + // aten::grid_sampler_3d_backward(Tensor grad_output, Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners, bool[2] output_mask) -> (Tensor, Tensor) + inline ::std::tuple grid_sampler_3d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners, ::std::array output_mask) { + return at::_ops::grid_sampler_3d_backward::redispatch(dispatchKeySet, grad_output, input, grid, interpolation_mode, padding_mode, align_corners, output_mask); + } + + // aten::hann_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor + inline at::Tensor hann_window(c10::DispatchKeySet dispatchKeySet, int64_t window_length, at::TensorOptions options={}) { + return at::_ops::hann_window::redispatch(dispatchKeySet, window_length, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::hann_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor hann_window(c10::DispatchKeySet dispatchKeySet, int64_t window_length, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::hann_window::redispatch(dispatchKeySet, window_length, dtype, layout, device, pin_memory); + } + + // aten::hann_window.periodic(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor hann_window(c10::DispatchKeySet dispatchKeySet, int64_t window_length, bool periodic, at::TensorOptions options={}) { + return at::_ops::hann_window_periodic::redispatch(dispatchKeySet, window_length, periodic, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::hann_window.periodic(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor hann_window(c10::DispatchKeySet dispatchKeySet, int64_t window_length, bool periodic, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::hann_window_periodic::redispatch(dispatchKeySet, window_length, periodic, dtype, layout, device, pin_memory); + } + + // aten::hamming_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor + inline at::Tensor hamming_window(c10::DispatchKeySet dispatchKeySet, int64_t window_length, at::TensorOptions options={}) { + return at::_ops::hamming_window::redispatch(dispatchKeySet, window_length, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::hamming_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor hamming_window(c10::DispatchKeySet dispatchKeySet, int64_t window_length, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::hamming_window::redispatch(dispatchKeySet, window_length, dtype, layout, device, pin_memory); + } + + // aten::hamming_window.periodic(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor hamming_window(c10::DispatchKeySet dispatchKeySet, int64_t window_length, bool periodic, at::TensorOptions options={}) { + return at::_ops::hamming_window_periodic::redispatch(dispatchKeySet, window_length, periodic, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::hamming_window.periodic(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor hamming_window(c10::DispatchKeySet dispatchKeySet, int64_t window_length, bool periodic, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::hamming_window_periodic::redispatch(dispatchKeySet, window_length, periodic, dtype, layout, device, pin_memory); + } + + // aten::hamming_window.periodic_alpha(int window_length, bool periodic, float alpha, *, ScalarType? 
dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor hamming_window(c10::DispatchKeySet dispatchKeySet, int64_t window_length, bool periodic, double alpha, at::TensorOptions options={}) { + return at::_ops::hamming_window_periodic_alpha::redispatch(dispatchKeySet, window_length, periodic, alpha, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::hamming_window.periodic_alpha(int window_length, bool periodic, float alpha, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor hamming_window(c10::DispatchKeySet dispatchKeySet, int64_t window_length, bool periodic, double alpha, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::hamming_window_periodic_alpha::redispatch(dispatchKeySet, window_length, periodic, alpha, dtype, layout, device, pin_memory); + } + + // aten::hamming_window.periodic_alpha_beta(int window_length, bool periodic, float alpha, float beta, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor hamming_window(c10::DispatchKeySet dispatchKeySet, int64_t window_length, bool periodic, double alpha, double beta, at::TensorOptions options={}) { + return at::_ops::hamming_window_periodic_alpha_beta::redispatch(dispatchKeySet, window_length, periodic, alpha, beta, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::hamming_window.periodic_alpha_beta(int window_length, bool periodic, float alpha, float beta, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor + inline at::Tensor hamming_window(c10::DispatchKeySet dispatchKeySet, int64_t window_length, bool periodic, double alpha, double beta, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::hamming_window_periodic_alpha_beta::redispatch(dispatchKeySet, window_length, periodic, alpha, beta, dtype, layout, device, pin_memory); + } + + // aten::kaiser_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor kaiser_window(c10::DispatchKeySet dispatchKeySet, int64_t window_length, at::TensorOptions options={}) { + return at::_ops::kaiser_window::redispatch(dispatchKeySet, window_length, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::kaiser_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor kaiser_window(c10::DispatchKeySet dispatchKeySet, int64_t window_length, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::kaiser_window::redispatch(dispatchKeySet, window_length, dtype, layout, device, pin_memory); + } + + // aten::kaiser_window.periodic(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor kaiser_window(c10::DispatchKeySet dispatchKeySet, int64_t window_length, bool periodic, at::TensorOptions options={}) { + return at::_ops::kaiser_window_periodic::redispatch(dispatchKeySet, window_length, periodic, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::kaiser_window.periodic(int window_length, bool periodic, *, ScalarType? 
dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor kaiser_window(c10::DispatchKeySet dispatchKeySet, int64_t window_length, bool periodic, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::kaiser_window_periodic::redispatch(dispatchKeySet, window_length, periodic, dtype, layout, device, pin_memory); + } + + // aten::kaiser_window.beta(int window_length, bool periodic, float beta, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor kaiser_window(c10::DispatchKeySet dispatchKeySet, int64_t window_length, bool periodic, double beta, at::TensorOptions options={}) { + return at::_ops::kaiser_window_beta::redispatch(dispatchKeySet, window_length, periodic, beta, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::kaiser_window.beta(int window_length, bool periodic, float beta, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor kaiser_window(c10::DispatchKeySet dispatchKeySet, int64_t window_length, bool periodic, double beta, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::kaiser_window_beta::redispatch(dispatchKeySet, window_length, periodic, beta, dtype, layout, device, pin_memory); + } + + // aten::hinge_embedding_loss(Tensor self, Tensor target, float margin=1.0, int reduction=Mean) -> Tensor + inline at::Tensor hinge_embedding_loss(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, double margin=1.0, int64_t reduction=at::Reduction::Mean) { + return at::_ops::hinge_embedding_loss::redispatch(dispatchKeySet, self, target, margin, reduction); + } + + // aten::group_norm(Tensor input, int num_groups, Tensor? 
weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enabled=True) -> Tensor + inline at::Tensor group_norm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, int64_t num_groups, const ::std::optional & weight={}, const ::std::optional & bias={}, double eps=1e-05, bool cudnn_enabled=true) { + return at::_ops::group_norm::redispatch(dispatchKeySet, input, num_groups, weight, bias, eps, cudnn_enabled); + } + + // aten::native_group_norm(Tensor input, Tensor? weight, Tensor? bias, SymInt N, SymInt C, SymInt HxW, int group, float eps) -> (Tensor, Tensor, Tensor) + inline ::std::tuple native_group_norm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, int64_t N, int64_t C, int64_t HxW, int64_t group, double eps) { + return at::_ops::native_group_norm::redispatch(dispatchKeySet, input, weight, bias, N, C, HxW, group, eps); + } + + // aten::native_group_norm(Tensor input, Tensor? weight, Tensor? bias, SymInt N, SymInt C, SymInt HxW, int group, float eps) -> (Tensor, Tensor, Tensor) + inline ::std::tuple native_group_norm_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, c10::SymInt N, c10::SymInt C, c10::SymInt HxW, int64_t group, double eps) { + return at::_ops::native_group_norm::redispatch(dispatchKeySet, input, weight, bias, N, C, HxW, group, eps); + } + + // aten::native_group_norm_backward(Tensor grad_out, Tensor input, Tensor mean, Tensor rstd, Tensor? 
weight, SymInt N, SymInt C, SymInt HxW, int group, bool[3] output_mask) -> (Tensor, Tensor, Tensor) + inline ::std::tuple native_group_norm_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & rstd, const ::std::optional & weight, int64_t N, int64_t C, int64_t HxW, int64_t group, ::std::array output_mask) { + return at::_ops::native_group_norm_backward::redispatch(dispatchKeySet, grad_out, input, mean, rstd, weight, N, C, HxW, group, output_mask); + } + + // aten::native_group_norm_backward(Tensor grad_out, Tensor input, Tensor mean, Tensor rstd, Tensor? weight, SymInt N, SymInt C, SymInt HxW, int group, bool[3] output_mask) -> (Tensor, Tensor, Tensor) + inline ::std::tuple native_group_norm_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & rstd, const ::std::optional & weight, c10::SymInt N, c10::SymInt C, c10::SymInt HxW, int64_t group, ::std::array output_mask) { + return at::_ops::native_group_norm_backward::redispatch(dispatchKeySet, grad_out, input, mean, rstd, weight, N, C, HxW, group, output_mask); + } + + // aten::_fft_r2c(Tensor self, int[] dim, int normalization, bool onesided) -> Tensor + inline at::Tensor _fft_r2c(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim, int64_t normalization, bool onesided) { + return at::_ops::_fft_r2c::redispatch(dispatchKeySet, self, dim, normalization, onesided); + } + + // aten::_fft_r2c.out(Tensor self, int[] dim, int normalization, bool onesided, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & _fft_r2c_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef dim, int64_t normalization, bool onesided) { + return at::_ops::_fft_r2c_out::redispatch(dispatchKeySet, self, dim, normalization, onesided, out); + } + + // aten::_fft_r2c.out(Tensor self, int[] dim, int normalization, bool onesided, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _fft_r2c_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim, int64_t normalization, bool onesided, at::Tensor & out) { + return at::_ops::_fft_r2c_out::redispatch(dispatchKeySet, self, dim, normalization, onesided, out); + } + + // aten::_fft_c2r(Tensor self, int[] dim, int normalization, SymInt last_dim_size) -> Tensor + inline at::Tensor _fft_c2r(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim, int64_t normalization, int64_t last_dim_size) { + return at::_ops::_fft_c2r::redispatch(dispatchKeySet, self, dim, normalization, last_dim_size); + } + + // aten::_fft_c2r(Tensor self, int[] dim, int normalization, SymInt last_dim_size) -> Tensor + inline at::Tensor _fft_c2r_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim, int64_t normalization, c10::SymInt last_dim_size) { + return at::_ops::_fft_c2r::redispatch(dispatchKeySet, self, dim, normalization, last_dim_size); + } + + // aten::_fft_c2r.out(Tensor self, int[] dim, int normalization, SymInt last_dim_size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _fft_c2r_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef dim, int64_t normalization, int64_t last_dim_size) { + return at::_ops::_fft_c2r_out::redispatch(dispatchKeySet, self, dim, normalization, last_dim_size, out); + } + + // aten::_fft_c2r.out(Tensor self, int[] dim, int normalization, SymInt last_dim_size, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & _fft_c2r_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim, int64_t normalization, int64_t last_dim_size, at::Tensor & out) { + return at::_ops::_fft_c2r_out::redispatch(dispatchKeySet, self, dim, normalization, last_dim_size, out); + } + + // aten::_fft_c2r.out(Tensor self, int[] dim, int normalization, SymInt last_dim_size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _fft_c2r_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef dim, int64_t normalization, c10::SymInt last_dim_size) { + return at::_ops::_fft_c2r_out::redispatch(dispatchKeySet, self, dim, normalization, last_dim_size, out); + } + + // aten::_fft_c2r.out(Tensor self, int[] dim, int normalization, SymInt last_dim_size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _fft_c2r_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim, int64_t normalization, c10::SymInt last_dim_size, at::Tensor & out) { + return at::_ops::_fft_c2r_out::redispatch(dispatchKeySet, self, dim, normalization, last_dim_size, out); + } + + // aten::_fft_c2c(Tensor self, SymInt[] dim, int normalization, bool forward) -> Tensor + inline at::Tensor _fft_c2c(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim, int64_t normalization, bool forward) { + return at::_ops::_fft_c2c::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(dim), normalization, forward); + } + + // aten::_fft_c2c(Tensor self, SymInt[] dim, int normalization, bool forward) -> Tensor + inline at::Tensor _fft_c2c_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef dim, int64_t normalization, bool forward) { + return at::_ops::_fft_c2c::redispatch(dispatchKeySet, self, dim, normalization, forward); + } + + // aten::_fft_c2c.out(Tensor self, SymInt[] dim, int normalization, bool forward, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & _fft_c2c_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef dim, int64_t normalization, bool forward) { + return at::_ops::_fft_c2c_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(dim), normalization, forward, out); + } + + // aten::_fft_c2c.out(Tensor self, SymInt[] dim, int normalization, bool forward, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _fft_c2c_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim, int64_t normalization, bool forward, at::Tensor & out) { + return at::_ops::_fft_c2c_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(dim), normalization, forward, out); + } + + // aten::_fft_c2c.out(Tensor self, SymInt[] dim, int normalization, bool forward, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _fft_c2c_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef dim, int64_t normalization, bool forward) { + return at::_ops::_fft_c2c_out::redispatch(dispatchKeySet, self, dim, normalization, forward, out); + } + + // aten::_fft_c2c.out(Tensor self, SymInt[] dim, int normalization, bool forward, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & _fft_c2c_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef dim, int64_t normalization, bool forward, at::Tensor & out) { + return at::_ops::_fft_c2c_out::redispatch(dispatchKeySet, self, dim, normalization, forward, out); + } + + // aten::_validate_compressed_sparse_indices(bool is_crow, Tensor compressed_idx, Tensor plain_idx, int cdim, int dim, int nnz) -> () + inline void _validate_compressed_sparse_indices(c10::DispatchKeySet dispatchKeySet, bool is_crow, const at::Tensor & compressed_idx, const at::Tensor & plain_idx, int64_t cdim, int64_t dim, int64_t nnz) { + return at::_ops::_validate_compressed_sparse_indices::redispatch(dispatchKeySet, is_crow, compressed_idx, plain_idx, cdim, dim, nnz); + } + + // aten::_cufft_get_plan_cache_size(DeviceIndex device_index) -> int + inline int64_t _cufft_get_plan_cache_size(c10::DispatchKeySet dispatchKeySet, at::DeviceIndex device_index) { + return at::_ops::_cufft_get_plan_cache_size::redispatch(dispatchKeySet, device_index); + } + + // aten::_cufft_get_plan_cache_max_size(DeviceIndex device_index) -> int + inline int64_t _cufft_get_plan_cache_max_size(c10::DispatchKeySet dispatchKeySet, at::DeviceIndex device_index) { + return at::_ops::_cufft_get_plan_cache_max_size::redispatch(dispatchKeySet, device_index); + } + + // aten::_cufft_set_plan_cache_max_size(DeviceIndex device_index, int max_size) -> () + inline void _cufft_set_plan_cache_max_size(c10::DispatchKeySet dispatchKeySet, at::DeviceIndex device_index, int64_t max_size) { + return at::_ops::_cufft_set_plan_cache_max_size::redispatch(dispatchKeySet, device_index, max_size); + } + + // aten::_cufft_clear_plan_cache(DeviceIndex device_index) -> () + inline void _cufft_clear_plan_cache(c10::DispatchKeySet dispatchKeySet, at::DeviceIndex device_index) { + return at::_ops::_cufft_clear_plan_cache::redispatch(dispatchKeySet, device_index); + } + + // aten::index.Tensor(Tensor self, Tensor?[] indices) -> 
Tensor + inline at::Tensor index(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const c10::List<::std::optional> & indices) { + return at::_ops::index_Tensor::redispatch(dispatchKeySet, self, indices); + } + + // aten::index.Tensor_out(Tensor self, Tensor?[] indices, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & index_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const c10::List<::std::optional> & indices) { + return at::_ops::index_Tensor_out::redispatch(dispatchKeySet, self, indices, out); + } + + // aten::index.Tensor_out(Tensor self, Tensor?[] indices, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & index_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const c10::List<::std::optional> & indices, at::Tensor & out) { + return at::_ops::index_Tensor_out::redispatch(dispatchKeySet, self, indices, out); + } + + // aten::_unsafe_index.Tensor(Tensor self, Tensor?[] indices) -> Tensor + inline at::Tensor _unsafe_index(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const c10::List<::std::optional> & indices) { + return at::_ops::_unsafe_index_Tensor::redispatch(dispatchKeySet, self, indices); + } + + // aten::_unsafe_masked_index(Tensor self, Tensor mask, Tensor?[] indices, Scalar fill) -> Tensor + inline at::Tensor _unsafe_masked_index(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mask, const c10::List<::std::optional> & indices, const at::Scalar & fill) { + return at::_ops::_unsafe_masked_index::redispatch(dispatchKeySet, self, mask, indices, fill); + } + + // aten::_unsafe_masked_index_put_accumulate(Tensor self, Tensor mask, Tensor?[] indices, Tensor values) -> Tensor + inline at::Tensor _unsafe_masked_index_put_accumulate(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mask, const c10::List<::std::optional> & indices, const at::Tensor & values) { + return 
at::_ops::_unsafe_masked_index_put_accumulate::redispatch(dispatchKeySet, self, mask, indices, values); + } + + // aten::index_copy.out(Tensor self, int dim, Tensor index, Tensor source, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & index_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & source) { + return at::_ops::index_copy_out::redispatch(dispatchKeySet, self, dim, index, source, out); + } + + // aten::index_copy.out(Tensor self, int dim, Tensor index, Tensor source, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & index_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & source, at::Tensor & out) { + return at::_ops::index_copy_out::redispatch(dispatchKeySet, self, dim, index, source, out); + } + + // aten::index_copy_(Tensor(a!) self, int dim, Tensor index, Tensor source) -> Tensor(a!) + inline at::Tensor & index_copy_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & source) { + return at::_ops::index_copy_::redispatch(dispatchKeySet, self, dim, index, source); + } + + // aten::index_copy(Tensor self, int dim, Tensor index, Tensor source) -> Tensor + inline at::Tensor index_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & source) { + return at::_ops::index_copy::redispatch(dispatchKeySet, self, dim, index, source); + } + + // aten::index_copy_.dimname(Tensor(a!) self, Dimname dim, Tensor index, Tensor source) -> Tensor(a!) 
+ inline at::Tensor & index_copy_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, at::Dimname dim, const at::Tensor & index, const at::Tensor & source) { + return at::_ops::index_copy__dimname::redispatch(dispatchKeySet, self, dim, index, source); + } + + // aten::index_copy.dimname(Tensor self, Dimname dim, Tensor index, Tensor source) -> Tensor + inline at::Tensor index_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, const at::Tensor & index, const at::Tensor & source) { + return at::_ops::index_copy_dimname::redispatch(dispatchKeySet, self, dim, index, source); + } + + // aten::index_put_(Tensor(a!) self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor(a!) + inline at::Tensor & index_put_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const c10::List<::std::optional> & indices, const at::Tensor & values, bool accumulate=false) { + return at::_ops::index_put_::redispatch(dispatchKeySet, self, indices, values, accumulate); + } + + // aten::index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor + inline at::Tensor index_put(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const c10::List<::std::optional> & indices, const at::Tensor & values, bool accumulate=false) { + return at::_ops::index_put::redispatch(dispatchKeySet, self, indices, values, accumulate); + } + + // aten::_unsafe_index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor + inline at::Tensor _unsafe_index_put(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const c10::List<::std::optional> & indices, const at::Tensor & values, bool accumulate=false) { + return at::_ops::_unsafe_index_put::redispatch(dispatchKeySet, self, indices, values, accumulate); + } + + // aten::_index_put_impl_(Tensor(a!) self, Tensor?[] indices, Tensor values, bool accumulate=False, bool unsafe=False) -> Tensor(a!) 
+ inline at::Tensor & _index_put_impl_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const c10::List<::std::optional> & indices, const at::Tensor & values, bool accumulate=false, bool unsafe=false) { + return at::_ops::_index_put_impl_::redispatch(dispatchKeySet, self, indices, values, accumulate, unsafe); + } + + // aten::instance_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool use_input_stats, float momentum, float eps, bool cudnn_enabled) -> Tensor + inline at::Tensor instance_norm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const ::std::optional & running_mean, const ::std::optional & running_var, bool use_input_stats, double momentum, double eps, bool cudnn_enabled) { + return at::_ops::instance_norm::redispatch(dispatchKeySet, input, weight, bias, running_mean, running_var, use_input_stats, momentum, eps, cudnn_enabled); + } + + // aten::isclose(Tensor self, Tensor other, float rtol=1e-05, float atol=1e-08, bool equal_nan=False) -> Tensor + inline at::Tensor isclose(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, double rtol=1e-05, double atol=1e-08, bool equal_nan=false) { + return at::_ops::isclose::redispatch(dispatchKeySet, self, other, rtol, atol, equal_nan); + } + + // aten::isin.Tensor_Tensor_out(Tensor elements, Tensor test_elements, *, bool assume_unique=False, bool invert=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & isin_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & elements, const at::Tensor & test_elements, bool assume_unique=false, bool invert=false) { + return at::_ops::isin_Tensor_Tensor_out::redispatch(dispatchKeySet, elements, test_elements, assume_unique, invert, out); + } + + // aten::isin.Tensor_Tensor_out(Tensor elements, Tensor test_elements, *, bool assume_unique=False, bool invert=False, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & isin_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & elements, const at::Tensor & test_elements, bool assume_unique, bool invert, at::Tensor & out) { + return at::_ops::isin_Tensor_Tensor_out::redispatch(dispatchKeySet, elements, test_elements, assume_unique, invert, out); + } + + // aten::isin.Tensor_Tensor(Tensor elements, Tensor test_elements, *, bool assume_unique=False, bool invert=False) -> Tensor + inline at::Tensor isin(c10::DispatchKeySet dispatchKeySet, const at::Tensor & elements, const at::Tensor & test_elements, bool assume_unique=false, bool invert=false) { + return at::_ops::isin_Tensor_Tensor::redispatch(dispatchKeySet, elements, test_elements, assume_unique, invert); + } + + // aten::isin.Tensor_Scalar_out(Tensor elements, Scalar test_element, *, bool assume_unique=False, bool invert=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & isin_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & elements, const at::Scalar & test_element, bool assume_unique=false, bool invert=false) { + return at::_ops::isin_Tensor_Scalar_out::redispatch(dispatchKeySet, elements, test_element, assume_unique, invert, out); + } + + // aten::isin.Tensor_Scalar_out(Tensor elements, Scalar test_element, *, bool assume_unique=False, bool invert=False, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & isin_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & elements, const at::Scalar & test_element, bool assume_unique, bool invert, at::Tensor & out) { + return at::_ops::isin_Tensor_Scalar_out::redispatch(dispatchKeySet, elements, test_element, assume_unique, invert, out); + } + + // aten::isin.Tensor_Scalar(Tensor elements, Scalar test_element, *, bool assume_unique=False, bool invert=False) -> Tensor + inline at::Tensor isin(c10::DispatchKeySet dispatchKeySet, const at::Tensor & elements, const at::Scalar & test_element, bool assume_unique=false, bool invert=false) { + return at::_ops::isin_Tensor_Scalar::redispatch(dispatchKeySet, elements, test_element, assume_unique, invert); + } + + // aten::isin.Scalar_Tensor_out(Scalar element, Tensor test_elements, *, bool assume_unique=False, bool invert=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & isin_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & element, const at::Tensor & test_elements, bool assume_unique=false, bool invert=false) { + return at::_ops::isin_Scalar_Tensor_out::redispatch(dispatchKeySet, element, test_elements, assume_unique, invert, out); + } + + // aten::isin.Scalar_Tensor_out(Scalar element, Tensor test_elements, *, bool assume_unique=False, bool invert=False, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & isin_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & element, const at::Tensor & test_elements, bool assume_unique, bool invert, at::Tensor & out) { + return at::_ops::isin_Scalar_Tensor_out::redispatch(dispatchKeySet, element, test_elements, assume_unique, invert, out); + } + + // aten::isin.Scalar_Tensor(Scalar element, Tensor test_elements, *, bool assume_unique=False, bool invert=False) -> Tensor + inline at::Tensor isin(c10::DispatchKeySet dispatchKeySet, const at::Scalar & element, const at::Tensor & test_elements, bool assume_unique=false, bool invert=false) { + return at::_ops::isin_Scalar_Tensor::redispatch(dispatchKeySet, element, test_elements, assume_unique, invert); + } + + // aten::isnan(Tensor self) -> Tensor + inline at::Tensor isnan(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::isnan::redispatch(dispatchKeySet, self); + } + + // aten::is_distributed(Tensor self) -> bool + inline bool is_distributed(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::is_distributed::redispatch(dispatchKeySet, self); + } + + // aten::is_floating_point(Tensor self) -> bool + inline bool __dispatch_is_floating_point(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::is_floating_point::redispatch(dispatchKeySet, self); + } + + // aten::is_complex(Tensor self) -> bool + inline bool __dispatch_is_complex(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::is_complex::redispatch(dispatchKeySet, self); + } + + // aten::is_conj(Tensor self) -> bool + inline bool __dispatch_is_conj(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::is_conj::redispatch(dispatchKeySet, self); + } + + // aten::_is_zerotensor(Tensor self) -> bool + inline bool __dispatch__is_zerotensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::_is_zerotensor::redispatch(dispatchKeySet, 
self); + } + + // aten::is_neg(Tensor self) -> bool + inline bool __dispatch_is_neg(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::is_neg::redispatch(dispatchKeySet, self); + } + + // aten::isreal(Tensor self) -> Tensor + inline at::Tensor isreal(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::isreal::redispatch(dispatchKeySet, self); + } + + // aten::is_nonzero(Tensor self) -> bool + inline bool is_nonzero(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::is_nonzero::redispatch(dispatchKeySet, self); + } + + // aten::is_same_size(Tensor self, Tensor other) -> bool + inline bool is_same_size(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::is_same_size::redispatch(dispatchKeySet, self, other); + } + + // aten::is_signed(Tensor self) -> bool + inline bool __dispatch_is_signed(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::is_signed::redispatch(dispatchKeySet, self); + } + + // aten::is_inference(Tensor self) -> bool + inline bool __dispatch_is_inference(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::is_inference::redispatch(dispatchKeySet, self); + } + + // aten::kl_div(Tensor self, Tensor target, int reduction=Mean, *, bool log_target=False) -> Tensor + inline at::Tensor kl_div(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, int64_t reduction=at::Reduction::Mean, bool log_target=false) { + return at::_ops::kl_div::redispatch(dispatchKeySet, self, target, reduction, log_target); + } + + // aten::kron(Tensor self, Tensor other) -> Tensor + inline at::Tensor kron(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::kron::redispatch(dispatchKeySet, self, other); + } + + // aten::kron.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & kron_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::kron_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::kron.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & kron_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::kron_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::kthvalue(Tensor self, SymInt k, int dim=-1, bool keepdim=False) -> (Tensor values, Tensor indices) + inline ::std::tuple kthvalue(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t k, int64_t dim=-1, bool keepdim=false) { + return at::_ops::kthvalue::redispatch(dispatchKeySet, self, k, dim, keepdim); + } + + // aten::kthvalue(Tensor self, SymInt k, int dim=-1, bool keepdim=False) -> (Tensor values, Tensor indices) + inline ::std::tuple kthvalue_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymInt k, int64_t dim=-1, bool keepdim=false) { + return at::_ops::kthvalue::redispatch(dispatchKeySet, self, k, dim, keepdim); + } + + // aten::kthvalue.values(Tensor self, SymInt k, int dim=-1, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + inline ::std::tuple kthvalue_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & values, at::Tensor & indices, const at::Tensor & self, int64_t k, int64_t dim=-1, bool keepdim=false) { + return at::_ops::kthvalue_values::redispatch(dispatchKeySet, self, k, dim, keepdim, values, indices); + } + + // aten::kthvalue.values(Tensor self, SymInt k, int dim=-1, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) 
indices) + inline ::std::tuple kthvalue_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t k, int64_t dim, bool keepdim, at::Tensor & values, at::Tensor & indices) { + return at::_ops::kthvalue_values::redispatch(dispatchKeySet, self, k, dim, keepdim, values, indices); + } + + // aten::kthvalue.values(Tensor self, SymInt k, int dim=-1, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + inline ::std::tuple kthvalue_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & values, at::Tensor & indices, const at::Tensor & self, c10::SymInt k, int64_t dim=-1, bool keepdim=false) { + return at::_ops::kthvalue_values::redispatch(dispatchKeySet, self, k, dim, keepdim, values, indices); + } + + // aten::kthvalue.values(Tensor self, SymInt k, int dim=-1, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + inline ::std::tuple kthvalue_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymInt k, int64_t dim, bool keepdim, at::Tensor & values, at::Tensor & indices) { + return at::_ops::kthvalue_values::redispatch(dispatchKeySet, self, k, dim, keepdim, values, indices); + } + + // aten::kthvalue.dimname(Tensor self, SymInt k, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices) + inline ::std::tuple kthvalue(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t k, at::Dimname dim, bool keepdim=false) { + return at::_ops::kthvalue_dimname::redispatch(dispatchKeySet, self, k, dim, keepdim); + } + + // aten::kthvalue.dimname(Tensor self, SymInt k, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices) + inline ::std::tuple kthvalue_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymInt k, at::Dimname dim, bool keepdim=false) { + return at::_ops::kthvalue_dimname::redispatch(dispatchKeySet, self, k, dim, keepdim); + } + + // 
aten::kthvalue.dimname_out(Tensor self, SymInt k, Dimname dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + inline ::std::tuple kthvalue_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & values, at::Tensor & indices, const at::Tensor & self, int64_t k, at::Dimname dim, bool keepdim=false) { + return at::_ops::kthvalue_dimname_out::redispatch(dispatchKeySet, self, k, dim, keepdim, values, indices); + } + + // aten::kthvalue.dimname_out(Tensor self, SymInt k, Dimname dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + inline ::std::tuple kthvalue_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t k, at::Dimname dim, bool keepdim, at::Tensor & values, at::Tensor & indices) { + return at::_ops::kthvalue_dimname_out::redispatch(dispatchKeySet, self, k, dim, keepdim, values, indices); + } + + // aten::kthvalue.dimname_out(Tensor self, SymInt k, Dimname dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + inline ::std::tuple kthvalue_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & values, at::Tensor & indices, const at::Tensor & self, c10::SymInt k, at::Dimname dim, bool keepdim=false) { + return at::_ops::kthvalue_dimname_out::redispatch(dispatchKeySet, self, k, dim, keepdim, values, indices); + } + + // aten::kthvalue.dimname_out(Tensor self, SymInt k, Dimname dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + inline ::std::tuple kthvalue_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymInt k, at::Dimname dim, bool keepdim, at::Tensor & values, at::Tensor & indices) { + return at::_ops::kthvalue_dimname_out::redispatch(dispatchKeySet, self, k, dim, keepdim, values, indices); + } + + // aten::layer_norm(Tensor input, SymInt[] normalized_shape, Tensor? 
weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor + inline at::Tensor layer_norm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::IntArrayRef normalized_shape, const ::std::optional & weight={}, const ::std::optional & bias={}, double eps=1e-05, bool cudnn_enable=true) { + return at::_ops::layer_norm::redispatch(dispatchKeySet, input, c10::fromIntArrayRefSlow(normalized_shape), weight, bias, eps, cudnn_enable); + } + + // aten::layer_norm(Tensor input, SymInt[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor + inline at::Tensor layer_norm_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, c10::SymIntArrayRef normalized_shape, const ::std::optional & weight={}, const ::std::optional & bias={}, double eps=1e-05, bool cudnn_enable=true) { + return at::_ops::layer_norm::redispatch(dispatchKeySet, input, normalized_shape, weight, bias, eps, cudnn_enable); + } + + // aten::native_layer_norm(Tensor input, SymInt[] normalized_shape, Tensor? weight, Tensor? bias, float eps) -> (Tensor, Tensor, Tensor) + inline ::std::tuple native_layer_norm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::IntArrayRef normalized_shape, const ::std::optional & weight, const ::std::optional & bias, double eps) { + return at::_ops::native_layer_norm::redispatch(dispatchKeySet, input, c10::fromIntArrayRefSlow(normalized_shape), weight, bias, eps); + } + + // aten::native_layer_norm(Tensor input, SymInt[] normalized_shape, Tensor? weight, Tensor? 
bias, float eps) -> (Tensor, Tensor, Tensor) + inline ::std::tuple native_layer_norm_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, c10::SymIntArrayRef normalized_shape, const ::std::optional & weight, const ::std::optional & bias, double eps) { + return at::_ops::native_layer_norm::redispatch(dispatchKeySet, input, normalized_shape, weight, bias, eps); + } + + // aten::native_layer_norm_backward(Tensor grad_out, Tensor input, SymInt[] normalized_shape, Tensor mean, Tensor rstd, Tensor? weight, Tensor? bias, bool[3] output_mask) -> (Tensor, Tensor, Tensor) + inline ::std::tuple native_layer_norm_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out, const at::Tensor & input, at::IntArrayRef normalized_shape, const at::Tensor & mean, const at::Tensor & rstd, const ::std::optional & weight, const ::std::optional & bias, ::std::array output_mask) { + return at::_ops::native_layer_norm_backward::redispatch(dispatchKeySet, grad_out, input, c10::fromIntArrayRefSlow(normalized_shape), mean, rstd, weight, bias, output_mask); + } + + // aten::native_layer_norm_backward(Tensor grad_out, Tensor input, SymInt[] normalized_shape, Tensor mean, Tensor rstd, Tensor? weight, Tensor? bias, bool[3] output_mask) -> (Tensor, Tensor, Tensor) + inline ::std::tuple native_layer_norm_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out, const at::Tensor & input, c10::SymIntArrayRef normalized_shape, const at::Tensor & mean, const at::Tensor & rstd, const ::std::optional & weight, const ::std::optional & bias, ::std::array output_mask) { + return at::_ops::native_layer_norm_backward::redispatch(dispatchKeySet, grad_out, input, normalized_shape, mean, rstd, weight, bias, output_mask); + } + + // aten::rms_norm(Tensor input, SymInt[] normalized_shape, Tensor? weight=None, float? 
eps=None) -> Tensor + inline at::Tensor rms_norm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::IntArrayRef normalized_shape, const ::std::optional & weight={}, ::std::optional eps=::std::nullopt) { + return at::_ops::rms_norm::redispatch(dispatchKeySet, input, c10::fromIntArrayRefSlow(normalized_shape), weight, eps); + } + + // aten::rms_norm(Tensor input, SymInt[] normalized_shape, Tensor? weight=None, float? eps=None) -> Tensor + inline at::Tensor rms_norm_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, c10::SymIntArrayRef normalized_shape, const ::std::optional & weight={}, ::std::optional eps=::std::nullopt) { + return at::_ops::rms_norm::redispatch(dispatchKeySet, input, normalized_shape, weight, eps); + } + + // aten::_fused_rms_norm(Tensor input, int[] normalized_shape, Tensor? weight, float? eps) -> (Tensor, Tensor) + inline ::std::tuple _fused_rms_norm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::IntArrayRef normalized_shape, const ::std::optional & weight, ::std::optional eps) { + return at::_ops::_fused_rms_norm::redispatch(dispatchKeySet, input, normalized_shape, weight, eps); + } + + // aten::_fused_rms_norm_backward(Tensor grad_out, Tensor input, int[] normalized_shape, Tensor rstd, Tensor? weight, bool[2] output_mask) -> (Tensor, Tensor) + inline ::std::tuple _fused_rms_norm_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out, const at::Tensor & input, at::IntArrayRef normalized_shape, const at::Tensor & rstd, const ::std::optional & weight, ::std::array output_mask) { + return at::_ops::_fused_rms_norm_backward::redispatch(dispatchKeySet, grad_out, input, normalized_shape, rstd, weight, output_mask); + } + + // aten::nan_to_num(Tensor self, float? nan=None, float? posinf=None, float? 
neginf=None) -> Tensor + inline at::Tensor nan_to_num(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional nan=::std::nullopt, ::std::optional posinf=::std::nullopt, ::std::optional neginf=::std::nullopt) { + return at::_ops::nan_to_num::redispatch(dispatchKeySet, self, nan, posinf, neginf); + } + + // aten::nan_to_num_(Tensor(a!) self, float? nan=None, float? posinf=None, float? neginf=None) -> Tensor(a!) + inline at::Tensor & nan_to_num_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, ::std::optional nan=::std::nullopt, ::std::optional posinf=::std::nullopt, ::std::optional neginf=::std::nullopt) { + return at::_ops::nan_to_num_::redispatch(dispatchKeySet, self, nan, posinf, neginf); + } + + // aten::nan_to_num.out(Tensor self, float? nan=None, float? posinf=None, float? neginf=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & nan_to_num_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, ::std::optional nan=::std::nullopt, ::std::optional posinf=::std::nullopt, ::std::optional neginf=::std::nullopt) { + return at::_ops::nan_to_num_out::redispatch(dispatchKeySet, self, nan, posinf, neginf, out); + } + + // aten::nan_to_num.out(Tensor self, float? nan=None, float? posinf=None, float? neginf=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & nan_to_num_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional nan, ::std::optional posinf, ::std::optional neginf, at::Tensor & out) { + return at::_ops::nan_to_num_out::redispatch(dispatchKeySet, self, nan, posinf, neginf, out); + } + + // aten::linear(Tensor input, Tensor weight, Tensor? 
bias=None) -> Tensor + inline at::Tensor linear(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias={}) { + return at::_ops::linear::redispatch(dispatchKeySet, input, weight, bias); + } + + // aten::linear_backward(Tensor self, Tensor grad_output, Tensor weight, bool[3] output_mask) -> (Tensor, Tensor, Tensor) + inline ::std::tuple linear_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & grad_output, const at::Tensor & weight, ::std::array output_mask) { + return at::_ops::linear_backward::redispatch(dispatchKeySet, self, grad_output, weight, output_mask); + } + + // aten::linear.out(Tensor input, Tensor weight, Tensor? bias=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linear_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias={}) { + return at::_ops::linear_out::redispatch(dispatchKeySet, input, weight, bias, out); + } + + // aten::linear.out(Tensor input, Tensor weight, Tensor? bias=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linear_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, at::Tensor & out) { + return at::_ops::linear_out::redispatch(dispatchKeySet, input, weight, bias, out); + } + + // aten::mkldnn_linear(Tensor self, Tensor weight, Tensor? 
bias=None) -> Tensor + inline at::Tensor mkldnn_linear(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias={}) { + return at::_ops::mkldnn_linear::redispatch(dispatchKeySet, self, weight, bias); + } + + // aten::mkldnn_linear_backward_input(int[] input_size, Tensor grad_output, Tensor weight) -> Tensor + inline at::Tensor mkldnn_linear_backward_input(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef input_size, const at::Tensor & grad_output, const at::Tensor & weight) { + return at::_ops::mkldnn_linear_backward_input::redispatch(dispatchKeySet, input_size, grad_output, weight); + } + + // aten::mkldnn_linear_backward_weights(Tensor grad_output, Tensor input, Tensor weight, bool bias_defined) -> (Tensor, Tensor) + inline ::std::tuple mkldnn_linear_backward_weights(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & weight, bool bias_defined) { + return at::_ops::mkldnn_linear_backward_weights::redispatch(dispatchKeySet, grad_output, input, weight, bias_defined); + } + + // aten::mkldnn_linear_backward(Tensor self, Tensor grad_output, Tensor weight, bool[3] output_mask) -> (Tensor, Tensor, Tensor) + inline ::std::tuple mkldnn_linear_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & grad_output, const at::Tensor & weight, ::std::array output_mask) { + return at::_ops::mkldnn_linear_backward::redispatch(dispatchKeySet, self, grad_output, weight, output_mask); + } + + // aten::_cslt_compress(Tensor input) -> Tensor + inline at::Tensor _cslt_compress(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input) { + return at::_ops::_cslt_compress::redispatch(dispatchKeySet, input); + } + + // aten::_cslt_sparse_mm(Tensor compressed_A, Tensor dense_B, Tensor? bias=None, Tensor? alpha=None, ScalarType? 
out_dtype=None, bool transpose_result=False, int alg_id=0, int split_k=1, int split_k_mode=-1) -> Tensor + inline at::Tensor _cslt_sparse_mm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & compressed_A, const at::Tensor & dense_B, const ::std::optional & bias={}, const ::std::optional & alpha={}, ::std::optional out_dtype=::std::nullopt, bool transpose_result=false, int64_t alg_id=0, int64_t split_k=1, int64_t split_k_mode=-1) { + return at::_ops::_cslt_sparse_mm::redispatch(dispatchKeySet, compressed_A, dense_B, bias, alpha, out_dtype, transpose_result, alg_id, split_k, split_k_mode); + } + + // aten::_cslt_sparse_mm_search(Tensor compressed_A, Tensor dense_B, Tensor? bias=None, Tensor? alpha=None, ScalarType? out_dtype=None, bool transpose_result=False) -> int + inline int64_t _cslt_sparse_mm_search(c10::DispatchKeySet dispatchKeySet, const at::Tensor & compressed_A, const at::Tensor & dense_B, const ::std::optional & bias={}, const ::std::optional & alpha={}, ::std::optional out_dtype=::std::nullopt, bool transpose_result=false) { + return at::_ops::_cslt_sparse_mm_search::redispatch(dispatchKeySet, compressed_A, dense_B, bias, alpha, out_dtype, transpose_result); + } + + // aten::_sparse_semi_structured_tile(Tensor input, str algorithm="", bool use_cutlass=True) -> (Tensor, Tensor, Tensor, Tensor, Tensor) + inline ::std::tuple _sparse_semi_structured_tile(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, c10::string_view algorithm="", bool use_cutlass=true) { + return at::_ops::_sparse_semi_structured_tile::redispatch(dispatchKeySet, input, algorithm, use_cutlass); + } + + // aten::_sparse_semi_structured_apply(Tensor input, Tensor thread_masks) -> (Tensor, Tensor) + inline ::std::tuple _sparse_semi_structured_apply(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & thread_masks) { + return at::_ops::_sparse_semi_structured_apply::redispatch(dispatchKeySet, input, thread_masks); + } + + // 
aten::_sparse_semi_structured_apply_dense(Tensor input, Tensor thread_masks) -> Tensor + inline at::Tensor _sparse_semi_structured_apply_dense(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & thread_masks) { + return at::_ops::_sparse_semi_structured_apply_dense::redispatch(dispatchKeySet, input, thread_masks); + } + + // aten::_sparse_semi_structured_linear(Tensor input, Tensor weight, Tensor meta, *, Tensor? bias=None, str? activation=None, ScalarType? out_dtype=None) -> Tensor + inline at::Tensor _sparse_semi_structured_linear(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const at::Tensor & meta, const ::std::optional & bias={}, ::std::optional activation=::std::nullopt, ::std::optional out_dtype=::std::nullopt) { + return at::_ops::_sparse_semi_structured_linear::redispatch(dispatchKeySet, input, weight, meta, bias, activation, out_dtype); + } + + // aten::_sparse_semi_structured_mm(Tensor mat1, Tensor mat1_meta, Tensor mat2, *, ScalarType? out_dtype=None) -> Tensor + inline at::Tensor _sparse_semi_structured_mm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & mat1, const at::Tensor & mat1_meta, const at::Tensor & mat2, ::std::optional out_dtype=::std::nullopt) { + return at::_ops::_sparse_semi_structured_mm::redispatch(dispatchKeySet, mat1, mat1_meta, mat2, out_dtype); + } + + // aten::_sparse_semi_structured_addmm(Tensor input, Tensor mat1, Tensor mat1_meta, Tensor mat2, *, Scalar alpha=1, Scalar beta=1, ScalarType? 
out_dtype=None) -> Tensor + inline at::Tensor _sparse_semi_structured_addmm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & mat1, const at::Tensor & mat1_meta, const at::Tensor & mat2, const at::Scalar & alpha=1, const at::Scalar & beta=1, ::std::optional out_dtype=::std::nullopt) { + return at::_ops::_sparse_semi_structured_addmm::redispatch(dispatchKeySet, input, mat1, mat1_meta, mat2, alpha, beta, out_dtype); + } + + // aten::_mixed_dtypes_linear(Tensor input, Tensor weight, Tensor scale, *, Tensor? bias=None, str? activation=None) -> Tensor + inline at::Tensor _mixed_dtypes_linear(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const at::Tensor & scale, const ::std::optional & bias={}, ::std::optional activation=::std::nullopt) { + return at::_ops::_mixed_dtypes_linear::redispatch(dispatchKeySet, input, weight, scale, bias, activation); + } + + // aten::fbgemm_linear_int8_weight_fp32_activation(Tensor input, Tensor weight, Tensor packed, Tensor col_offsets, Scalar weight_scale, Scalar weight_zero_point, Tensor bias) -> Tensor + inline at::Tensor fbgemm_linear_int8_weight_fp32_activation(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const at::Tensor & packed, const at::Tensor & col_offsets, const at::Scalar & weight_scale, const at::Scalar & weight_zero_point, const at::Tensor & bias) { + return at::_ops::fbgemm_linear_int8_weight_fp32_activation::redispatch(dispatchKeySet, input, weight, packed, col_offsets, weight_scale, weight_zero_point, bias); + } + + // aten::fbgemm_linear_int8_weight(Tensor input, Tensor weight, Tensor packed, Tensor col_offsets, Scalar weight_scale, Scalar weight_zero_point, Tensor bias) -> Tensor + inline at::Tensor fbgemm_linear_int8_weight(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const at::Tensor & packed, const at::Tensor & col_offsets, const at::Scalar & weight_scale, 
const at::Scalar & weight_zero_point, const at::Tensor & bias) { + return at::_ops::fbgemm_linear_int8_weight::redispatch(dispatchKeySet, input, weight, packed, col_offsets, weight_scale, weight_zero_point, bias); + } + + // aten::fbgemm_linear_quantize_weight(Tensor input) -> (Tensor, Tensor, float, int) + inline ::std::tuple fbgemm_linear_quantize_weight(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input) { + return at::_ops::fbgemm_linear_quantize_weight::redispatch(dispatchKeySet, input); + } + + // aten::fbgemm_pack_gemm_matrix_fp16(Tensor input) -> Tensor + inline at::Tensor fbgemm_pack_gemm_matrix_fp16(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input) { + return at::_ops::fbgemm_pack_gemm_matrix_fp16::redispatch(dispatchKeySet, input); + } + + // aten::_wrapped_linear_prepack(Tensor weight, Tensor weight_scale, Tensor weight_zero_point, Tensor bias) -> Tensor + inline at::Tensor _wrapped_linear_prepack(c10::DispatchKeySet dispatchKeySet, const at::Tensor & weight, const at::Tensor & weight_scale, const at::Tensor & weight_zero_point, const at::Tensor & bias) { + return at::_ops::_wrapped_linear_prepack::redispatch(dispatchKeySet, weight, weight_scale, weight_zero_point, bias); + } + + // aten::_wrapped_quantized_linear_prepacked(Tensor input, Tensor input_scale, Tensor input_zero_point, Tensor packed_weight, Tensor output_scale, Tensor output_zero_point, int out_channel) -> Tensor + inline at::Tensor _wrapped_quantized_linear_prepacked(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & input_scale, const at::Tensor & input_zero_point, const at::Tensor & packed_weight, const at::Tensor & output_scale, const at::Tensor & output_zero_point, int64_t out_channel) { + return at::_ops::_wrapped_quantized_linear_prepacked::redispatch(dispatchKeySet, input, input_scale, input_zero_point, packed_weight, output_scale, output_zero_point, out_channel); + } + + // aten::fbgemm_linear_fp16_weight_fp32_activation(Tensor 
input, Tensor packed_weight, Tensor? bias) -> Tensor + inline at::Tensor fbgemm_linear_fp16_weight_fp32_activation(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & packed_weight, const ::std::optional & bias) { + return at::_ops::fbgemm_linear_fp16_weight_fp32_activation::redispatch(dispatchKeySet, input, packed_weight, bias); + } + + // aten::fbgemm_linear_fp16_weight_fp32_activation.out(Tensor input, Tensor packed_weight, Tensor? bias, Tensor(a!) output) -> Tensor + inline at::Tensor fbgemm_linear_fp16_weight_fp32_activation(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & packed_weight, const ::std::optional & bias, at::Tensor & output) { + return at::_ops::fbgemm_linear_fp16_weight_fp32_activation_out::redispatch(dispatchKeySet, input, packed_weight, bias, output); + } + + // aten::fbgemm_linear_fp16_weight(Tensor input, Tensor packed_weight, Tensor bias) -> Tensor + inline at::Tensor fbgemm_linear_fp16_weight(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & packed_weight, const at::Tensor & bias) { + return at::_ops::fbgemm_linear_fp16_weight::redispatch(dispatchKeySet, input, packed_weight, bias); + } + + // aten::fbgemm_linear_fp16_weight.out(Tensor input, Tensor packed_weight, Tensor bias, Tensor(a!) 
output) -> Tensor + inline at::Tensor fbgemm_linear_fp16_weight(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & packed_weight, const at::Tensor & bias, at::Tensor & output) { + return at::_ops::fbgemm_linear_fp16_weight_out::redispatch(dispatchKeySet, input, packed_weight, bias, output); + } + + // aten::fbgemm_pack_quantized_matrix(Tensor input) -> Tensor + inline at::Tensor fbgemm_pack_quantized_matrix(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input) { + return at::_ops::fbgemm_pack_quantized_matrix::redispatch(dispatchKeySet, input); + } + + // aten::fbgemm_pack_quantized_matrix.KN(Tensor input, int K, int N) -> Tensor + inline at::Tensor fbgemm_pack_quantized_matrix(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, int64_t K, int64_t N) { + return at::_ops::fbgemm_pack_quantized_matrix_KN::redispatch(dispatchKeySet, input, K, N); + } + + // aten::ldexp.Tensor(Tensor self, Tensor other) -> Tensor + inline at::Tensor ldexp(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::ldexp_Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::ldexp_(Tensor(a!) self, Tensor other) -> Tensor(a!) + inline at::Tensor & ldexp_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) { + return at::_ops::ldexp_::redispatch(dispatchKeySet, self, other); + } + + // aten::ldexp.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & ldexp_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::ldexp_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::ldexp.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & ldexp_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::ldexp_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::linspace(Scalar start, Scalar end, int steps, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor linspace(c10::DispatchKeySet dispatchKeySet, const at::Scalar & start, const at::Scalar & end, int64_t steps, at::TensorOptions options={}) { + return at::_ops::linspace::redispatch(dispatchKeySet, start, end, steps, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::linspace(Scalar start, Scalar end, int steps, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor linspace(c10::DispatchKeySet dispatchKeySet, const at::Scalar & start, const at::Scalar & end, int64_t steps, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::linspace::redispatch(dispatchKeySet, start, end, steps, dtype, layout, device, pin_memory); + } + + // aten::linspace.Tensor_Tensor(Tensor start, Tensor end, int steps, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor linspace(c10::DispatchKeySet dispatchKeySet, const at::Tensor & start, const at::Tensor & end, int64_t steps, at::TensorOptions options={}) { + return at::_ops::linspace_Tensor_Tensor::redispatch(dispatchKeySet, start, end, steps, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::linspace.Tensor_Tensor(Tensor start, Tensor end, int steps, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor + inline at::Tensor linspace(c10::DispatchKeySet dispatchKeySet, const at::Tensor & start, const at::Tensor & end, int64_t steps, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::linspace_Tensor_Tensor::redispatch(dispatchKeySet, start, end, steps, dtype, layout, device, pin_memory); + } + + // aten::linspace.Tensor_Scalar(Tensor start, Scalar end, int steps, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor linspace(c10::DispatchKeySet dispatchKeySet, const at::Tensor & start, const at::Scalar & end, int64_t steps, at::TensorOptions options={}) { + return at::_ops::linspace_Tensor_Scalar::redispatch(dispatchKeySet, start, end, steps, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::linspace.Tensor_Scalar(Tensor start, Scalar end, int steps, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor linspace(c10::DispatchKeySet dispatchKeySet, const at::Tensor & start, const at::Scalar & end, int64_t steps, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::linspace_Tensor_Scalar::redispatch(dispatchKeySet, start, end, steps, dtype, layout, device, pin_memory); + } + + // aten::linspace.Scalar_Tensor(Scalar start, Tensor end, int steps, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor + inline at::Tensor linspace(c10::DispatchKeySet dispatchKeySet, const at::Scalar & start, const at::Tensor & end, int64_t steps, at::TensorOptions options={}) { + return at::_ops::linspace_Scalar_Tensor::redispatch(dispatchKeySet, start, end, steps, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::linspace.Scalar_Tensor(Scalar start, Tensor end, int steps, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor linspace(c10::DispatchKeySet dispatchKeySet, const at::Scalar & start, const at::Tensor & end, int64_t steps, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::linspace_Scalar_Tensor::redispatch(dispatchKeySet, start, end, steps, dtype, layout, device, pin_memory); + } + + // aten::linspace.out(Scalar start, Scalar end, int steps, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linspace_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & start, const at::Scalar & end, int64_t steps) { + return at::_ops::linspace_out::redispatch(dispatchKeySet, start, end, steps, out); + } + + // aten::linspace.out(Scalar start, Scalar end, int steps, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linspace_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & start, const at::Scalar & end, int64_t steps, at::Tensor & out) { + return at::_ops::linspace_out::redispatch(dispatchKeySet, start, end, steps, out); + } + + // aten::linspace.Tensor_Tensor_out(Tensor start, Tensor end, int steps, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & linspace_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & start, const at::Tensor & end, int64_t steps) { + return at::_ops::linspace_Tensor_Tensor_out::redispatch(dispatchKeySet, start, end, steps, out); + } + + // aten::linspace.Tensor_Tensor_out(Tensor start, Tensor end, int steps, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linspace_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & start, const at::Tensor & end, int64_t steps, at::Tensor & out) { + return at::_ops::linspace_Tensor_Tensor_out::redispatch(dispatchKeySet, start, end, steps, out); + } + + // aten::linspace.Tensor_Scalar_out(Tensor start, Scalar end, int steps, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linspace_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & start, const at::Scalar & end, int64_t steps) { + return at::_ops::linspace_Tensor_Scalar_out::redispatch(dispatchKeySet, start, end, steps, out); + } + + // aten::linspace.Tensor_Scalar_out(Tensor start, Scalar end, int steps, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linspace_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & start, const at::Scalar & end, int64_t steps, at::Tensor & out) { + return at::_ops::linspace_Tensor_Scalar_out::redispatch(dispatchKeySet, start, end, steps, out); + } + + // aten::linspace.Scalar_Tensor_out(Scalar start, Tensor end, int steps, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linspace_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & start, const at::Tensor & end, int64_t steps) { + return at::_ops::linspace_Scalar_Tensor_out::redispatch(dispatchKeySet, start, end, steps, out); + } + + // aten::linspace.Scalar_Tensor_out(Scalar start, Tensor end, int steps, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & linspace_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & start, const at::Tensor & end, int64_t steps, at::Tensor & out) { + return at::_ops::linspace_Scalar_Tensor_out::redispatch(dispatchKeySet, start, end, steps, out); + } + + // aten::log(Tensor self) -> Tensor + inline at::Tensor log(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::log::redispatch(dispatchKeySet, self); + } + + // aten::log_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & log_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::log_::redispatch(dispatchKeySet, self); + } + + // aten::log.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & log_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::log_out::redispatch(dispatchKeySet, self, out); + } + + // aten::log.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & log_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::log_out::redispatch(dispatchKeySet, self, out); + } + + // aten::log10(Tensor self) -> Tensor + inline at::Tensor log10(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::log10::redispatch(dispatchKeySet, self); + } + + // aten::log10_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & log10_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::log10_::redispatch(dispatchKeySet, self); + } + + // aten::log10.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & log10_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::log10_out::redispatch(dispatchKeySet, self, out); + } + + // aten::log10.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & log10_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::log10_out::redispatch(dispatchKeySet, self, out); + } + + // aten::log1p(Tensor self) -> Tensor + inline at::Tensor log1p(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::log1p::redispatch(dispatchKeySet, self); + } + + // aten::log1p_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & log1p_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::log1p_::redispatch(dispatchKeySet, self); + } + + // aten::log1p.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & log1p_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::log1p_out::redispatch(dispatchKeySet, self, out); + } + + // aten::log1p.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & log1p_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::log1p_out::redispatch(dispatchKeySet, self, out); + } + + // aten::log2(Tensor self) -> Tensor + inline at::Tensor log2(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::log2::redispatch(dispatchKeySet, self); + } + + // aten::log2_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & log2_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::log2_::redispatch(dispatchKeySet, self); + } + + // aten::log2.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & log2_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::log2_out::redispatch(dispatchKeySet, self, out); + } + + // aten::log2.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & log2_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::log2_out::redispatch(dispatchKeySet, self, out); + } + + // aten::logaddexp.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & logaddexp_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::logaddexp_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::logaddexp.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & logaddexp_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::logaddexp_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::logaddexp(Tensor self, Tensor other) -> Tensor + inline at::Tensor logaddexp(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::logaddexp::redispatch(dispatchKeySet, self, other); + } + + // aten::logaddexp2.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & logaddexp2_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::logaddexp2_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::logaddexp2.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & logaddexp2_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::logaddexp2_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::logaddexp2(Tensor self, Tensor other) -> Tensor + inline at::Tensor logaddexp2(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::logaddexp2::redispatch(dispatchKeySet, self, other); + } + + // aten::xlogy.Tensor(Tensor self, Tensor other) -> Tensor + inline at::Tensor xlogy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::xlogy_Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::xlogy.Scalar_Self(Scalar self, Tensor other) -> Tensor + inline at::Tensor xlogy(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, const at::Tensor & other) { + return at::_ops::xlogy_Scalar_Self::redispatch(dispatchKeySet, self, other); + } + + // aten::xlogy.Scalar_Other(Tensor self, Scalar other) -> Tensor + inline at::Tensor xlogy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::xlogy_Scalar_Other::redispatch(dispatchKeySet, self, other); + } + + // aten::xlogy_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + inline at::Tensor & xlogy_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) { + return at::_ops::xlogy__Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::xlogy_.Scalar_Other(Tensor(a!) self, Scalar other) -> Tensor(a!) + inline at::Tensor & xlogy_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) { + return at::_ops::xlogy__Scalar_Other::redispatch(dispatchKeySet, self, other); + } + + // aten::xlogy.OutTensor(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & xlogy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::xlogy_OutTensor::redispatch(dispatchKeySet, self, other, out); + } + + // aten::xlogy.OutTensor(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & xlogy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::xlogy_OutTensor::redispatch(dispatchKeySet, self, other, out); + } + + // aten::xlogy.OutScalar_Self(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & xlogy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & self, const at::Tensor & other) { + return at::_ops::xlogy_OutScalar_Self::redispatch(dispatchKeySet, self, other, out); + } + + // aten::xlogy.OutScalar_Self(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & xlogy_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::xlogy_OutScalar_Self::redispatch(dispatchKeySet, self, other, out); + } + + // aten::xlogy.OutScalar_Other(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & xlogy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::xlogy_OutScalar_Other::redispatch(dispatchKeySet, self, other, out); + } + + // aten::xlogy.OutScalar_Other(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & xlogy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, at::Tensor & out) { + return at::_ops::xlogy_OutScalar_Other::redispatch(dispatchKeySet, self, other, out); + } + + // aten::logspace(Scalar start, Scalar end, int steps, float base=10.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor + inline at::Tensor logspace(c10::DispatchKeySet dispatchKeySet, const at::Scalar & start, const at::Scalar & end, int64_t steps, double base=10.0, at::TensorOptions options={}) { + return at::_ops::logspace::redispatch(dispatchKeySet, start, end, steps, base, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::logspace(Scalar start, Scalar end, int steps, float base=10.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor logspace(c10::DispatchKeySet dispatchKeySet, const at::Scalar & start, const at::Scalar & end, int64_t steps, double base, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::logspace::redispatch(dispatchKeySet, start, end, steps, base, dtype, layout, device, pin_memory); + } + + // aten::logspace.Tensor_Tensor(Tensor start, Tensor end, int steps, float base=10.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor logspace(c10::DispatchKeySet dispatchKeySet, const at::Tensor & start, const at::Tensor & end, int64_t steps, double base=10.0, at::TensorOptions options={}) { + return at::_ops::logspace_Tensor_Tensor::redispatch(dispatchKeySet, start, end, steps, base, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::logspace.Tensor_Tensor(Tensor start, Tensor end, int steps, float base=10.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor + inline at::Tensor logspace(c10::DispatchKeySet dispatchKeySet, const at::Tensor & start, const at::Tensor & end, int64_t steps, double base, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::logspace_Tensor_Tensor::redispatch(dispatchKeySet, start, end, steps, base, dtype, layout, device, pin_memory); + } + + // aten::logspace.Tensor_Scalar(Tensor start, Scalar end, int steps, float base=10.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor logspace(c10::DispatchKeySet dispatchKeySet, const at::Tensor & start, const at::Scalar & end, int64_t steps, double base=10.0, at::TensorOptions options={}) { + return at::_ops::logspace_Tensor_Scalar::redispatch(dispatchKeySet, start, end, steps, base, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::logspace.Tensor_Scalar(Tensor start, Scalar end, int steps, float base=10.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor logspace(c10::DispatchKeySet dispatchKeySet, const at::Tensor & start, const at::Scalar & end, int64_t steps, double base, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::logspace_Tensor_Scalar::redispatch(dispatchKeySet, start, end, steps, base, dtype, layout, device, pin_memory); + } + + // aten::logspace.Scalar_Tensor(Scalar start, Tensor end, int steps, float base=10.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor + inline at::Tensor logspace(c10::DispatchKeySet dispatchKeySet, const at::Scalar & start, const at::Tensor & end, int64_t steps, double base=10.0, at::TensorOptions options={}) { + return at::_ops::logspace_Scalar_Tensor::redispatch(dispatchKeySet, start, end, steps, base, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::logspace.Scalar_Tensor(Scalar start, Tensor end, int steps, float base=10.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor logspace(c10::DispatchKeySet dispatchKeySet, const at::Scalar & start, const at::Tensor & end, int64_t steps, double base, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::logspace_Scalar_Tensor::redispatch(dispatchKeySet, start, end, steps, base, dtype, layout, device, pin_memory); + } + + // aten::logspace.out(Scalar start, Scalar end, int steps, float base=10.0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & logspace_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & start, const at::Scalar & end, int64_t steps, double base=10.0) { + return at::_ops::logspace_out::redispatch(dispatchKeySet, start, end, steps, base, out); + } + + // aten::logspace.out(Scalar start, Scalar end, int steps, float base=10.0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & logspace_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & start, const at::Scalar & end, int64_t steps, double base, at::Tensor & out) { + return at::_ops::logspace_out::redispatch(dispatchKeySet, start, end, steps, base, out); + } + + // aten::logspace.Tensor_Tensor_out(Tensor start, Tensor end, int steps, float base=10.0, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & logspace_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & start, const at::Tensor & end, int64_t steps, double base=10.0) { + return at::_ops::logspace_Tensor_Tensor_out::redispatch(dispatchKeySet, start, end, steps, base, out); + } + + // aten::logspace.Tensor_Tensor_out(Tensor start, Tensor end, int steps, float base=10.0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & logspace_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & start, const at::Tensor & end, int64_t steps, double base, at::Tensor & out) { + return at::_ops::logspace_Tensor_Tensor_out::redispatch(dispatchKeySet, start, end, steps, base, out); + } + + // aten::logspace.Tensor_Scalar_out(Tensor start, Scalar end, int steps, float base=10.0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & logspace_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & start, const at::Scalar & end, int64_t steps, double base=10.0) { + return at::_ops::logspace_Tensor_Scalar_out::redispatch(dispatchKeySet, start, end, steps, base, out); + } + + // aten::logspace.Tensor_Scalar_out(Tensor start, Scalar end, int steps, float base=10.0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & logspace_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & start, const at::Scalar & end, int64_t steps, double base, at::Tensor & out) { + return at::_ops::logspace_Tensor_Scalar_out::redispatch(dispatchKeySet, start, end, steps, base, out); + } + + // aten::logspace.Scalar_Tensor_out(Scalar start, Tensor end, int steps, float base=10.0, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & logspace_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & start, const at::Tensor & end, int64_t steps, double base=10.0) { + return at::_ops::logspace_Scalar_Tensor_out::redispatch(dispatchKeySet, start, end, steps, base, out); + } + + // aten::logspace.Scalar_Tensor_out(Scalar start, Tensor end, int steps, float base=10.0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & logspace_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & start, const at::Tensor & end, int64_t steps, double base, at::Tensor & out) { + return at::_ops::logspace_Scalar_Tensor_out::redispatch(dispatchKeySet, start, end, steps, base, out); + } + + // aten::log_softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor + inline at::Tensor log_softmax(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, ::std::optional dtype=::std::nullopt) { + return at::_ops::log_softmax_int::redispatch(dispatchKeySet, self, dim, dtype); + } + + // aten::log_softmax.int_out(Tensor self, int dim, ScalarType? dtype=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & log_softmax_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim, ::std::optional dtype=::std::nullopt) { + return at::_ops::log_softmax_int_out::redispatch(dispatchKeySet, self, dim, dtype, out); + } + + // aten::log_softmax.int_out(Tensor self, int dim, ScalarType? dtype=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & log_softmax_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, ::std::optional dtype, at::Tensor & out) { + return at::_ops::log_softmax_int_out::redispatch(dispatchKeySet, self, dim, dtype, out); + } + + // aten::log_softmax.Dimname(Tensor self, Dimname dim, *, ScalarType? 
dtype=None) -> Tensor + inline at::Tensor log_softmax(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, ::std::optional dtype=::std::nullopt) { + return at::_ops::log_softmax_Dimname::redispatch(dispatchKeySet, self, dim, dtype); + } + + // aten::_log_softmax(Tensor self, int dim, bool half_to_float) -> Tensor + inline at::Tensor _log_softmax(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool half_to_float) { + return at::_ops::_log_softmax::redispatch(dispatchKeySet, self, dim, half_to_float); + } + + // aten::_log_softmax.out(Tensor self, int dim, bool half_to_float, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _log_softmax_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim, bool half_to_float) { + return at::_ops::_log_softmax_out::redispatch(dispatchKeySet, self, dim, half_to_float, out); + } + + // aten::_log_softmax.out(Tensor self, int dim, bool half_to_float, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _log_softmax_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool half_to_float, at::Tensor & out) { + return at::_ops::_log_softmax_out::redispatch(dispatchKeySet, self, dim, half_to_float, out); + } + + // aten::_log_softmax_backward_data(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype) -> Tensor + inline at::Tensor _log_softmax_backward_data(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & output, int64_t dim, at::ScalarType input_dtype) { + return at::_ops::_log_softmax_backward_data::redispatch(dispatchKeySet, grad_output, output, dim, input_dtype); + } + + // aten::_log_softmax_backward_data.out(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & _log_softmax_backward_data_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad_output, const at::Tensor & output, int64_t dim, at::ScalarType input_dtype) { + return at::_ops::_log_softmax_backward_data_out::redispatch(dispatchKeySet, grad_output, output, dim, input_dtype, out); + } + + // aten::_log_softmax_backward_data.out(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _log_softmax_backward_data_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & output, int64_t dim, at::ScalarType input_dtype, at::Tensor & out) { + return at::_ops::_log_softmax_backward_data_out::redispatch(dispatchKeySet, grad_output, output, dim, input_dtype, out); + } + + // aten::_logcumsumexp(Tensor self, int dim) -> Tensor + inline at::Tensor _logcumsumexp(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim) { + return at::_ops::_logcumsumexp::redispatch(dispatchKeySet, self, dim); + } + + // aten::_logcumsumexp.out(Tensor self, int dim, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _logcumsumexp_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim) { + return at::_ops::_logcumsumexp_out::redispatch(dispatchKeySet, self, dim, out); + } + + // aten::_logcumsumexp.out(Tensor self, int dim, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _logcumsumexp_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, at::Tensor & out) { + return at::_ops::_logcumsumexp_out::redispatch(dispatchKeySet, self, dim, out); + } + + // aten::logcumsumexp(Tensor self, int dim) -> Tensor + inline at::Tensor logcumsumexp(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim) { + return at::_ops::logcumsumexp::redispatch(dispatchKeySet, self, dim); + } + + // aten::logcumsumexp.out(Tensor self, int dim, *, Tensor(a!) 
out) -> Tensor(a!) + inline at::Tensor & logcumsumexp_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim) { + return at::_ops::logcumsumexp_out::redispatch(dispatchKeySet, self, dim, out); + } + + // aten::logcumsumexp.out(Tensor self, int dim, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & logcumsumexp_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, at::Tensor & out) { + return at::_ops::logcumsumexp_out::redispatch(dispatchKeySet, self, dim, out); + } + + // aten::logcumsumexp.dimname(Tensor self, Dimname dim) -> Tensor + inline at::Tensor logcumsumexp(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim) { + return at::_ops::logcumsumexp_dimname::redispatch(dispatchKeySet, self, dim); + } + + // aten::logcumsumexp.dimname_out(Tensor self, Dimname dim, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & logcumsumexp_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::Dimname dim) { + return at::_ops::logcumsumexp_dimname_out::redispatch(dispatchKeySet, self, dim, out); + } + + // aten::logcumsumexp.dimname_out(Tensor self, Dimname dim, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & logcumsumexp_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, at::Tensor & out) { + return at::_ops::logcumsumexp_dimname_out::redispatch(dispatchKeySet, self, dim, out); + } + + // aten::logsumexp(Tensor self, int[1] dim, bool keepdim=False) -> Tensor + inline at::Tensor logsumexp(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim, bool keepdim=false) { + return at::_ops::logsumexp::redispatch(dispatchKeySet, self, dim, keepdim); + } + + // aten::logsumexp.out(Tensor self, int[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & logsumexp_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef dim, bool keepdim=false) { + return at::_ops::logsumexp_out::redispatch(dispatchKeySet, self, dim, keepdim, out); + } + + // aten::logsumexp.out(Tensor self, int[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & logsumexp_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim, bool keepdim, at::Tensor & out) { + return at::_ops::logsumexp_out::redispatch(dispatchKeySet, self, dim, keepdim, out); + } + + // aten::logsumexp.names(Tensor self, Dimname[1] dim, bool keepdim=False) -> Tensor + inline at::Tensor logsumexp(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::DimnameList dim, bool keepdim=false) { + return at::_ops::logsumexp_names::redispatch(dispatchKeySet, self, dim, keepdim); + } + + // aten::logsumexp.names_out(Tensor self, Dimname[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & logsumexp_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::DimnameList dim, bool keepdim=false) { + return at::_ops::logsumexp_names_out::redispatch(dispatchKeySet, self, dim, keepdim, out); + } + + // aten::logsumexp.names_out(Tensor self, Dimname[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & logsumexp_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::DimnameList dim, bool keepdim, at::Tensor & out) { + return at::_ops::logsumexp_names_out::redispatch(dispatchKeySet, self, dim, keepdim, out); + } + + // aten::margin_ranking_loss(Tensor input1, Tensor input2, Tensor target, float margin=0.0, int reduction=Mean) -> Tensor + inline at::Tensor margin_ranking_loss(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input1, const at::Tensor & input2, const at::Tensor & target, double margin=0.0, int64_t reduction=at::Reduction::Mean) { + return at::_ops::margin_ranking_loss::redispatch(dispatchKeySet, input1, input2, target, margin, reduction); + } + + // aten::matmul(Tensor self, Tensor other) -> Tensor + inline at::Tensor matmul(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::matmul::redispatch(dispatchKeySet, self, other); + } + + // aten::matmul_backward(Tensor grad, Tensor self, Tensor other, bool[2] mask) -> (Tensor, Tensor) + inline ::std::tuple matmul_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & self, const at::Tensor & other, ::std::array mask) { + return at::_ops::matmul_backward::redispatch(dispatchKeySet, grad, self, other, mask); + } + + // aten::matmul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & matmul_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::matmul_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::matmul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & matmul_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::matmul_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::matrix_power(Tensor self, int n) -> Tensor + inline at::Tensor matrix_power(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t n) { + return at::_ops::matrix_power::redispatch(dispatchKeySet, self, n); + } + + // aten::matrix_power.out(Tensor self, int n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & matrix_power_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t n) { + return at::_ops::matrix_power_out::redispatch(dispatchKeySet, self, n, out); + } + + // aten::matrix_power.out(Tensor self, int n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & matrix_power_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t n, at::Tensor & out) { + return at::_ops::matrix_power_out::redispatch(dispatchKeySet, self, n, out); + } + + // aten::matrix_exp(Tensor self) -> Tensor + inline at::Tensor matrix_exp(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::matrix_exp::redispatch(dispatchKeySet, self); + } + + // aten::matrix_exp_backward(Tensor self, Tensor grad) -> Tensor + inline at::Tensor matrix_exp_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & grad) { + return at::_ops::matrix_exp_backward::redispatch(dispatchKeySet, self, grad); + } + + // aten::_aminmax(Tensor self) -> (Tensor, Tensor) + inline ::std::tuple _aminmax(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::_aminmax::redispatch(dispatchKeySet, self); + } + + // aten::_aminmax.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor, Tensor) + inline ::std::tuple _aminmax(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool keepdim=false) { + return 
at::_ops::_aminmax_dim::redispatch(dispatchKeySet, self, dim, keepdim); + } + + // aten::aminmax(Tensor self, *, int? dim=None, bool keepdim=False) -> (Tensor min, Tensor max) + inline ::std::tuple aminmax(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional dim=::std::nullopt, bool keepdim=false) { + return at::_ops::aminmax::redispatch(dispatchKeySet, self, dim, keepdim); + } + + // aten::aminmax.out(Tensor self, *, int? dim=None, bool keepdim=False, Tensor(a!) min, Tensor(b!) max) -> (Tensor(a!) min, Tensor(b!) max) + inline ::std::tuple aminmax_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & min, at::Tensor & max, const at::Tensor & self, ::std::optional dim=::std::nullopt, bool keepdim=false) { + return at::_ops::aminmax_out::redispatch(dispatchKeySet, self, dim, keepdim, min, max); + } + + // aten::aminmax.out(Tensor self, *, int? dim=None, bool keepdim=False, Tensor(a!) min, Tensor(b!) max) -> (Tensor(a!) min, Tensor(b!) max) + inline ::std::tuple aminmax_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional dim, bool keepdim, at::Tensor & min, at::Tensor & max) { + return at::_ops::aminmax_out::redispatch(dispatchKeySet, self, dim, keepdim, min, max); + } + + // aten::_compute_linear_combination(Tensor input, Tensor coefficients) -> Tensor + inline at::Tensor _compute_linear_combination(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & coefficients) { + return at::_ops::_compute_linear_combination::redispatch(dispatchKeySet, input, coefficients); + } + + // aten::_compute_linear_combination.out(Tensor input, Tensor coefficients, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & _compute_linear_combination_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & input, const at::Tensor & coefficients) { + return at::_ops::_compute_linear_combination_out::redispatch(dispatchKeySet, input, coefficients, out); + } + + // aten::_compute_linear_combination.out(Tensor input, Tensor coefficients, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _compute_linear_combination_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & coefficients, at::Tensor & out) { + return at::_ops::_compute_linear_combination_out::redispatch(dispatchKeySet, input, coefficients, out); + } + + // aten::max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices) + inline ::std::tuple max(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool keepdim=false) { + return at::_ops::max_dim::redispatch(dispatchKeySet, self, dim, keepdim); + } + + // aten::max.dim_max(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) max, Tensor(b!) max_values) -> (Tensor(a!) values, Tensor(b!) indices) + inline ::std::tuple max_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & max, at::Tensor & max_values, const at::Tensor & self, int64_t dim, bool keepdim=false) { + return at::_ops::max_dim_max::redispatch(dispatchKeySet, self, dim, keepdim, max, max_values); + } + + // aten::max.dim_max(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) max, Tensor(b!) max_values) -> (Tensor(a!) values, Tensor(b!) 
indices) + inline ::std::tuple max_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool keepdim, at::Tensor & max, at::Tensor & max_values) { + return at::_ops::max_dim_max::redispatch(dispatchKeySet, self, dim, keepdim, max, max_values); + } + + // aten::max.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices) + inline ::std::tuple max(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, bool keepdim=false) { + return at::_ops::max_names_dim::redispatch(dispatchKeySet, self, dim, keepdim); + } + + // aten::max.names_dim_max(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) max, Tensor(b!) max_values) -> (Tensor(a!) values, Tensor(b!) indices) + inline ::std::tuple max_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & max, at::Tensor & max_values, const at::Tensor & self, at::Dimname dim, bool keepdim=false) { + return at::_ops::max_names_dim_max::redispatch(dispatchKeySet, self, dim, keepdim, max, max_values); + } + + // aten::max.names_dim_max(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) max, Tensor(b!) max_values) -> (Tensor(a!) values, Tensor(b!) 
indices) + inline ::std::tuple max_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, bool keepdim, at::Tensor & max, at::Tensor & max_values) { + return at::_ops::max_names_dim_max::redispatch(dispatchKeySet, self, dim, keepdim, max, max_values); + } + + // aten::value_selecting_reduction_backward(Tensor grad, int dim, Tensor indices, SymInt[] sizes, bool keepdim) -> Tensor + inline at::Tensor value_selecting_reduction_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, int64_t dim, const at::Tensor & indices, at::IntArrayRef sizes, bool keepdim) { + return at::_ops::value_selecting_reduction_backward::redispatch(dispatchKeySet, grad, dim, indices, c10::fromIntArrayRefSlow(sizes), keepdim); + } + + // aten::value_selecting_reduction_backward(Tensor grad, int dim, Tensor indices, SymInt[] sizes, bool keepdim) -> Tensor + inline at::Tensor value_selecting_reduction_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, int64_t dim, const at::Tensor & indices, c10::SymIntArrayRef sizes, bool keepdim) { + return at::_ops::value_selecting_reduction_backward::redispatch(dispatchKeySet, grad, dim, indices, sizes, keepdim); + } + + // aten::amax(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor + inline at::Tensor amax(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim={}, bool keepdim=false) { + return at::_ops::amax::redispatch(dispatchKeySet, self, dim, keepdim); + } + + // aten::amax.out(Tensor self, int[1] dim=[], bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & amax_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef dim={}, bool keepdim=false) { + return at::_ops::amax_out::redispatch(dispatchKeySet, self, dim, keepdim, out); + } + + // aten::amax.out(Tensor self, int[1] dim=[], bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & amax_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim, bool keepdim, at::Tensor & out) { + return at::_ops::amax_out::redispatch(dispatchKeySet, self, dim, keepdim, out); + } + + // aten::max_pool1d_with_indices(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, int[1] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor) + inline ::std::tuple max_pool1d_with_indices(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, at::IntArrayRef dilation=1, bool ceil_mode=false) { + return at::_ops::max_pool1d_with_indices::redispatch(dispatchKeySet, self, kernel_size, stride, padding, dilation, ceil_mode); + } + + // aten::max_pool1d(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, int[1] dilation=1, bool ceil_mode=False) -> Tensor + inline at::Tensor max_pool1d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, at::IntArrayRef dilation=1, bool ceil_mode=false) { + return at::_ops::max_pool1d::redispatch(dispatchKeySet, self, kernel_size, stride, padding, dilation, ceil_mode); + } + + // aten::max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor + inline at::Tensor max_pool2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, at::IntArrayRef dilation=1, bool ceil_mode=false) { + return at::_ops::max_pool2d::redispatch(dispatchKeySet, self, kernel_size, stride, padding, dilation, ceil_mode); + } + + // aten::max_pool2d_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor + inline at::Tensor max_pool2d_backward(c10::DispatchKeySet 
dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, at::IntArrayRef dilation=1, bool ceil_mode=false) { + return at::_ops::max_pool2d_backward::redispatch(dispatchKeySet, grad_output, self, kernel_size, stride, padding, dilation, ceil_mode); + } + + // aten::mkldnn_max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor + inline at::Tensor mkldnn_max_pool2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, at::IntArrayRef dilation=1, bool ceil_mode=false) { + return at::_ops::mkldnn_max_pool2d::redispatch(dispatchKeySet, self, kernel_size, stride, padding, dilation, ceil_mode); + } + + // aten::mkldnn_max_pool2d_backward(Tensor grad_output, Tensor output, Tensor input, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor + inline at::Tensor mkldnn_max_pool2d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & output, const at::Tensor & input, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, at::IntArrayRef dilation=1, bool ceil_mode=false) { + return at::_ops::mkldnn_max_pool2d_backward::redispatch(dispatchKeySet, grad_output, output, input, kernel_size, stride, padding, dilation, ceil_mode); + } + + // aten::mkldnn_max_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False) -> Tensor + inline at::Tensor mkldnn_max_pool3d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, at::IntArrayRef dilation=1, bool ceil_mode=false) { + return at::_ops::mkldnn_max_pool3d::redispatch(dispatchKeySet, self, kernel_size, stride, 
padding, dilation, ceil_mode); + } + + // aten::mkldnn_max_pool3d_backward(Tensor grad_output, Tensor output, Tensor input, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False) -> Tensor + inline at::Tensor mkldnn_max_pool3d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & output, const at::Tensor & input, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, at::IntArrayRef dilation=1, bool ceil_mode=false) { + return at::_ops::mkldnn_max_pool3d_backward::redispatch(dispatchKeySet, grad_output, output, input, kernel_size, stride, padding, dilation, ceil_mode); + } + + // aten::quantized_max_pool1d(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, int[1] dilation=1, bool ceil_mode=False) -> Tensor + inline at::Tensor quantized_max_pool1d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, at::IntArrayRef dilation=1, bool ceil_mode=false) { + return at::_ops::quantized_max_pool1d::redispatch(dispatchKeySet, self, kernel_size, stride, padding, dilation, ceil_mode); + } + + // aten::quantized_max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor + inline at::Tensor quantized_max_pool2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, at::IntArrayRef dilation=1, bool ceil_mode=false) { + return at::_ops::quantized_max_pool2d::redispatch(dispatchKeySet, self, kernel_size, stride, padding, dilation, ceil_mode); + } + + // aten::quantized_max_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False) -> Tensor + inline at::Tensor quantized_max_pool3d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, 
at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, at::IntArrayRef dilation=1, bool ceil_mode=false) { + return at::_ops::quantized_max_pool3d::redispatch(dispatchKeySet, self, kernel_size, stride, padding, dilation, ceil_mode); + } + + // aten::max_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False) -> Tensor + inline at::Tensor max_pool3d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, at::IntArrayRef dilation=1, bool ceil_mode=false) { + return at::_ops::max_pool3d::redispatch(dispatchKeySet, self, kernel_size, stride, padding, dilation, ceil_mode); + } + + // aten::mean(Tensor self, *, ScalarType? dtype=None) -> Tensor + inline at::Tensor mean(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional dtype=::std::nullopt) { + return at::_ops::mean::redispatch(dispatchKeySet, self, dtype); + } + + // aten::mean.dtype_out(Tensor self, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & mean_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, ::std::optional dtype=::std::nullopt) { + return at::_ops::mean_dtype_out::redispatch(dispatchKeySet, self, dtype, out); + } + + // aten::mean.dtype_out(Tensor self, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & mean_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional dtype, at::Tensor & out) { + return at::_ops::mean_dtype_out::redispatch(dispatchKeySet, self, dtype, out); + } + + // aten::mean.dim(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? 
dtype=None) -> Tensor + inline at::Tensor mean(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim=false, ::std::optional dtype=::std::nullopt) { + return at::_ops::mean_dim::redispatch(dispatchKeySet, self, dim, keepdim, dtype); + } + + // aten::mean.out(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & mean_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim=false, ::std::optional dtype=::std::nullopt) { + return at::_ops::mean_out::redispatch(dispatchKeySet, self, dim, keepdim, dtype, out); + } + + // aten::mean.out(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & mean_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim, ::std::optional dtype, at::Tensor & out) { + return at::_ops::mean_out::redispatch(dispatchKeySet, self, dim, keepdim, dtype, out); + } + + // aten::mean.names_dim(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor + inline at::Tensor mean(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::DimnameList dim, bool keepdim=false, ::std::optional dtype=::std::nullopt) { + return at::_ops::mean_names_dim::redispatch(dispatchKeySet, self, dim, keepdim, dtype); + } + + // aten::mean.names_out(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & mean_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::DimnameList dim, bool keepdim=false, ::std::optional dtype=::std::nullopt) { + return at::_ops::mean_names_out::redispatch(dispatchKeySet, self, dim, keepdim, dtype, out); + } + + // aten::mean.names_out(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? 
dtype=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & mean_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::DimnameList dim, bool keepdim, ::std::optional dtype, at::Tensor & out) { + return at::_ops::mean_names_out::redispatch(dispatchKeySet, self, dim, keepdim, dtype, out); + } + + // aten::nanmean(Tensor self, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor + inline at::Tensor nanmean(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef dim=::std::nullopt, bool keepdim=false, ::std::optional dtype=::std::nullopt) { + return at::_ops::nanmean::redispatch(dispatchKeySet, self, dim, keepdim, dtype); + } + + // aten::nanmean.out(Tensor self, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & nanmean_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalIntArrayRef dim=::std::nullopt, bool keepdim=false, ::std::optional dtype=::std::nullopt) { + return at::_ops::nanmean_out::redispatch(dispatchKeySet, self, dim, keepdim, dtype, out); + } + + // aten::nanmean.out(Tensor self, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & nanmean_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim, ::std::optional dtype, at::Tensor & out) { + return at::_ops::nanmean_out::redispatch(dispatchKeySet, self, dim, keepdim, dtype, out); + } + + // aten::median(Tensor self) -> Tensor + inline at::Tensor median(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::median::redispatch(dispatchKeySet, self); + } + + // aten::median.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices) + inline ::std::tuple median(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool keepdim=false) { + return at::_ops::median_dim::redispatch(dispatchKeySet, self, dim, keepdim); + } + + // aten::median.dim_values(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + inline ::std::tuple median_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & values, at::Tensor & indices, const at::Tensor & self, int64_t dim, bool keepdim=false) { + return at::_ops::median_dim_values::redispatch(dispatchKeySet, self, dim, keepdim, values, indices); + } + + // aten::median.dim_values(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) 
indices) + inline ::std::tuple median_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool keepdim, at::Tensor & values, at::Tensor & indices) { + return at::_ops::median_dim_values::redispatch(dispatchKeySet, self, dim, keepdim, values, indices); + } + + // aten::median.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices) + inline ::std::tuple median(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, bool keepdim=false) { + return at::_ops::median_names_dim::redispatch(dispatchKeySet, self, dim, keepdim); + } + + // aten::median.names_dim_values(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + inline ::std::tuple median_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & values, at::Tensor & indices, const at::Tensor & self, at::Dimname dim, bool keepdim=false) { + return at::_ops::median_names_dim_values::redispatch(dispatchKeySet, self, dim, keepdim, values, indices); + } + + // aten::median.names_dim_values(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) 
indices) + inline ::std::tuple median_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, bool keepdim, at::Tensor & values, at::Tensor & indices) { + return at::_ops::median_names_dim_values::redispatch(dispatchKeySet, self, dim, keepdim, values, indices); + } + + // aten::nanmedian(Tensor self) -> Tensor + inline at::Tensor nanmedian(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::nanmedian::redispatch(dispatchKeySet, self); + } + + // aten::nanmedian.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices) + inline ::std::tuple nanmedian(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool keepdim=false) { + return at::_ops::nanmedian_dim::redispatch(dispatchKeySet, self, dim, keepdim); + } + + // aten::nanmedian.dim_values(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + inline ::std::tuple nanmedian_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & values, at::Tensor & indices, const at::Tensor & self, int64_t dim, bool keepdim=false) { + return at::_ops::nanmedian_dim_values::redispatch(dispatchKeySet, self, dim, keepdim, values, indices); + } + + // aten::nanmedian.dim_values(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) 
indices) + inline ::std::tuple nanmedian_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool keepdim, at::Tensor & values, at::Tensor & indices) { + return at::_ops::nanmedian_dim_values::redispatch(dispatchKeySet, self, dim, keepdim, values, indices); + } + + // aten::nanmedian.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices) + inline ::std::tuple nanmedian(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, bool keepdim=false) { + return at::_ops::nanmedian_names_dim::redispatch(dispatchKeySet, self, dim, keepdim); + } + + // aten::nanmedian.names_dim_values(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + inline ::std::tuple nanmedian_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & values, at::Tensor & indices, const at::Tensor & self, at::Dimname dim, bool keepdim=false) { + return at::_ops::nanmedian_names_dim_values::redispatch(dispatchKeySet, self, dim, keepdim, values, indices); + } + + // aten::nanmedian.names_dim_values(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + inline ::std::tuple nanmedian_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, bool keepdim, at::Tensor & values, at::Tensor & indices) { + return at::_ops::nanmedian_names_dim_values::redispatch(dispatchKeySet, self, dim, keepdim, values, indices); + } + + // aten::min.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices) + inline ::std::tuple min(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool keepdim=false) { + return at::_ops::min_dim::redispatch(dispatchKeySet, self, dim, keepdim); + } + + // aten::min.dim_min(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) min, Tensor(b!) min_indices) -> (Tensor(a!) values, Tensor(b!) 
indices) + inline ::std::tuple min_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & min, at::Tensor & min_indices, const at::Tensor & self, int64_t dim, bool keepdim=false) { + return at::_ops::min_dim_min::redispatch(dispatchKeySet, self, dim, keepdim, min, min_indices); + } + + // aten::min.dim_min(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) min, Tensor(b!) min_indices) -> (Tensor(a!) values, Tensor(b!) indices) + inline ::std::tuple min_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool keepdim, at::Tensor & min, at::Tensor & min_indices) { + return at::_ops::min_dim_min::redispatch(dispatchKeySet, self, dim, keepdim, min, min_indices); + } + + // aten::min.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices) + inline ::std::tuple min(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, bool keepdim=false) { + return at::_ops::min_names_dim::redispatch(dispatchKeySet, self, dim, keepdim); + } + + // aten::min.names_dim_min(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) min, Tensor(b!) min_indices) -> (Tensor(a!) values, Tensor(b!) indices) + inline ::std::tuple min_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & min, at::Tensor & min_indices, const at::Tensor & self, at::Dimname dim, bool keepdim=false) { + return at::_ops::min_names_dim_min::redispatch(dispatchKeySet, self, dim, keepdim, min, min_indices); + } + + // aten::min.names_dim_min(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) min, Tensor(b!) min_indices) -> (Tensor(a!) values, Tensor(b!) 
indices) + inline ::std::tuple min_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, bool keepdim, at::Tensor & min, at::Tensor & min_indices) { + return at::_ops::min_names_dim_min::redispatch(dispatchKeySet, self, dim, keepdim, min, min_indices); + } + + // aten::amin(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor + inline at::Tensor amin(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim={}, bool keepdim=false) { + return at::_ops::amin::redispatch(dispatchKeySet, self, dim, keepdim); + } + + // aten::amin.out(Tensor self, int[1] dim=[], bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & amin_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef dim={}, bool keepdim=false) { + return at::_ops::amin_out::redispatch(dispatchKeySet, self, dim, keepdim, out); + } + + // aten::amin.out(Tensor self, int[1] dim=[], bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & amin_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim, bool keepdim, at::Tensor & out) { + return at::_ops::amin_out::redispatch(dispatchKeySet, self, dim, keepdim, out); + } + + // aten::_mps_convolution(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups) -> Tensor + inline at::Tensor _mps_convolution(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, at::IntArrayRef padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups) { + return at::_ops::_mps_convolution::redispatch(dispatchKeySet, self, weight, bias, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups); + } + + // aten::_mps_convolution(Tensor self, Tensor weight, Tensor? 
bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups) -> Tensor + inline at::Tensor _mps_convolution_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups) { + return at::_ops::_mps_convolution::redispatch(dispatchKeySet, self, weight, bias, padding, stride, dilation, groups); + } + + // aten::mps_convolution_backward(Tensor self, Tensor grad_output, Tensor weight, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor) + inline ::std::tuple mps_convolution_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & grad_output, const at::Tensor & weight, at::IntArrayRef padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, ::std::array output_mask) { + return at::_ops::mps_convolution_backward::redispatch(dispatchKeySet, self, grad_output, weight, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, output_mask); + } + + // aten::mps_convolution_backward(Tensor self, Tensor grad_output, Tensor weight, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor) + inline ::std::tuple mps_convolution_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & grad_output, const at::Tensor & weight, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, ::std::array output_mask) { + return at::_ops::mps_convolution_backward::redispatch(dispatchKeySet, self, grad_output, weight, padding, stride, dilation, groups, output_mask); + } + + // aten::mkldnn_convolution(Tensor self, Tensor weight, Tensor? 
bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups) -> Tensor + inline at::Tensor mkldnn_convolution(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, at::IntArrayRef padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups) { + return at::_ops::mkldnn_convolution::redispatch(dispatchKeySet, self, weight, bias, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups); + } + + // aten::mkldnn_convolution(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups) -> Tensor + inline at::Tensor mkldnn_convolution_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups) { + return at::_ops::mkldnn_convolution::redispatch(dispatchKeySet, self, weight, bias, padding, stride, dilation, groups); + } + + // aten::mkldnn_rnn_layer(Tensor input, Tensor weight0, Tensor weight1, Tensor weight2, Tensor weight3, Tensor hx_, Tensor cx_, bool reverse, int[] batch_sizes, int mode, int hidden_size, int num_layers, bool has_biases, bool bidirectional, bool batch_first, bool train) -> (Tensor, Tensor, Tensor, Tensor) + inline ::std::tuple mkldnn_rnn_layer(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight0, const at::Tensor & weight1, const at::Tensor & weight2, const at::Tensor & weight3, const at::Tensor & hx_, const at::Tensor & cx_, bool reverse, at::IntArrayRef batch_sizes, int64_t mode, int64_t hidden_size, int64_t num_layers, bool has_biases, bool bidirectional, bool batch_first, bool train) { + return at::_ops::mkldnn_rnn_layer::redispatch(dispatchKeySet, input, weight0, weight1, weight2, weight3, hx_, cx_, reverse, batch_sizes, mode, hidden_size, num_layers, 
has_biases, bidirectional, batch_first, train); + } + + // aten::mkldnn_rnn_layer_backward(Tensor input, Tensor weight1, Tensor weight2, Tensor weight3, Tensor weight4, Tensor hx_, Tensor cx_tmp, Tensor output, Tensor hy_, Tensor cy_, Tensor? grad_output, Tensor? grad_hy, Tensor? grad_cy, bool reverse, int mode, int hidden_size, int num_layers, bool has_biases, bool train, bool bidirectional, int[] batch_sizes, bool batch_first, Tensor workspace) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor) + inline ::std::tuple mkldnn_rnn_layer_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight1, const at::Tensor & weight2, const at::Tensor & weight3, const at::Tensor & weight4, const at::Tensor & hx_, const at::Tensor & cx_tmp, const at::Tensor & output, const at::Tensor & hy_, const at::Tensor & cy_, const ::std::optional & grad_output, const ::std::optional & grad_hy, const ::std::optional & grad_cy, bool reverse, int64_t mode, int64_t hidden_size, int64_t num_layers, bool has_biases, bool train, bool bidirectional, at::IntArrayRef batch_sizes, bool batch_first, const at::Tensor & workspace) { + return at::_ops::mkldnn_rnn_layer_backward::redispatch(dispatchKeySet, input, weight1, weight2, weight3, weight4, hx_, cx_tmp, output, hy_, cy_, grad_output, grad_hy, grad_cy, reverse, mode, hidden_size, num_layers, has_biases, train, bidirectional, batch_sizes, batch_first, workspace); + } + + // aten::miopen_batch_norm(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? 
running_var, bool training, float exponential_average_factor, float epsilon) -> (Tensor, Tensor, Tensor) + inline ::std::tuple miopen_batch_norm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, const ::std::optional & running_mean, const ::std::optional & running_var, bool training, double exponential_average_factor, double epsilon) { + return at::_ops::miopen_batch_norm::redispatch(dispatchKeySet, input, weight, bias, running_mean, running_var, training, exponential_average_factor, epsilon); + } + + // aten::miopen_batch_norm_backward(Tensor input, Tensor grad_output, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, float epsilon) -> (Tensor, Tensor, Tensor) + inline ::std::tuple miopen_batch_norm_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & grad_output, const at::Tensor & weight, const ::std::optional & running_mean, const ::std::optional & running_var, const ::std::optional & save_mean, const ::std::optional & save_var, double epsilon) { + return at::_ops::miopen_batch_norm_backward::redispatch(dispatchKeySet, input, grad_output, weight, running_mean, running_var, save_mean, save_var, epsilon); + } + + // aten::miopen_convolution(Tensor self, Tensor weight, Tensor? 
bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic) -> Tensor + inline at::Tensor miopen_convolution(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, at::IntArrayRef padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, bool benchmark, bool deterministic) { + return at::_ops::miopen_convolution::redispatch(dispatchKeySet, self, weight, bias, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, benchmark, deterministic); + } + + // aten::miopen_convolution(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic) -> Tensor + inline at::Tensor miopen_convolution_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, bool benchmark, bool deterministic) { + return at::_ops::miopen_convolution::redispatch(dispatchKeySet, self, weight, bias, padding, stride, dilation, groups, benchmark, deterministic); + } + + // aten::miopen_convolution_transpose(Tensor self, Tensor weight, Tensor? 
bias, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic) -> Tensor + inline at::Tensor miopen_convolution_transpose(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, at::IntArrayRef padding, at::IntArrayRef output_padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, bool benchmark, bool deterministic) { + return at::_ops::miopen_convolution_transpose::redispatch(dispatchKeySet, self, weight, bias, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(output_padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, benchmark, deterministic); + } + + // aten::miopen_convolution_transpose(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic) -> Tensor + inline at::Tensor miopen_convolution_transpose_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, bool benchmark, bool deterministic) { + return at::_ops::miopen_convolution_transpose::redispatch(dispatchKeySet, self, weight, bias, padding, output_padding, stride, dilation, groups, benchmark, deterministic); + } + + // aten::miopen_depthwise_convolution(Tensor self, Tensor weight, Tensor? 
bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic) -> Tensor + inline at::Tensor miopen_depthwise_convolution(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, at::IntArrayRef padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, bool benchmark, bool deterministic) { + return at::_ops::miopen_depthwise_convolution::redispatch(dispatchKeySet, self, weight, bias, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, benchmark, deterministic); + } + + // aten::miopen_depthwise_convolution(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic) -> Tensor + inline at::Tensor miopen_depthwise_convolution_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, bool benchmark, bool deterministic) { + return at::_ops::miopen_depthwise_convolution::redispatch(dispatchKeySet, self, weight, bias, padding, stride, dilation, groups, benchmark, deterministic); + } + + // aten::miopen_convolution_relu(Tensor self, Tensor weight, Tensor? 
bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, SymInt groups) -> Tensor + inline at::Tensor miopen_convolution_relu(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, int64_t groups) { + return at::_ops::miopen_convolution_relu::redispatch(dispatchKeySet, self, weight, bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), groups); + } + + // aten::miopen_convolution_relu(Tensor self, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, SymInt groups) -> Tensor + inline at::Tensor miopen_convolution_relu_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, c10::SymInt groups) { + return at::_ops::miopen_convolution_relu::redispatch(dispatchKeySet, self, weight, bias, stride, padding, dilation, groups); + } + + // aten::miopen_convolution_add_relu(Tensor self, Tensor weight, Tensor z, Scalar? alpha, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, SymInt groups) -> Tensor + inline at::Tensor miopen_convolution_add_relu(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const at::Tensor & z, const ::std::optional & alpha, const ::std::optional & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, int64_t groups) { + return at::_ops::miopen_convolution_add_relu::redispatch(dispatchKeySet, self, weight, z, alpha, bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), groups); + } + + // aten::miopen_convolution_add_relu(Tensor self, Tensor weight, Tensor z, Scalar? alpha, Tensor? 
bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, SymInt groups) -> Tensor + inline at::Tensor miopen_convolution_add_relu_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const at::Tensor & z, const ::std::optional & alpha, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, c10::SymInt groups) { + return at::_ops::miopen_convolution_add_relu::redispatch(dispatchKeySet, self, weight, z, alpha, bias, stride, padding, dilation, groups); + } + + // aten::miopen_rnn(Tensor input, Tensor[] weight, int weight_stride0, Tensor hx, Tensor? cx, int mode, int hidden_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, int[] batch_sizes, Tensor? dropout_state) -> (Tensor, Tensor, Tensor, Tensor, Tensor) + inline ::std::tuple miopen_rnn(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::TensorList weight, int64_t weight_stride0, const at::Tensor & hx, const ::std::optional & cx, int64_t mode, int64_t hidden_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, at::IntArrayRef batch_sizes, const ::std::optional & dropout_state) { + return at::_ops::miopen_rnn::redispatch(dispatchKeySet, input, weight, weight_stride0, hx, cx, mode, hidden_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state); + } + + // aten::miopen_rnn_backward(Tensor input, Tensor[] weight, int weight_stride0, Tensor weight_buf, Tensor hx, Tensor? cx, Tensor output, Tensor? grad_output, Tensor? grad_hy, Tensor? grad_cy, int mode, int hidden_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, int[] batch_sizes, Tensor? 
dropout_state, Tensor reserve, bool[4] output_mask) -> (Tensor, Tensor, Tensor, Tensor[]) + inline ::std::tuple> miopen_rnn_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::TensorList weight, int64_t weight_stride0, const at::Tensor & weight_buf, const at::Tensor & hx, const ::std::optional & cx, const at::Tensor & output, const ::std::optional & grad_output, const ::std::optional & grad_hy, const ::std::optional & grad_cy, int64_t mode, int64_t hidden_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, at::IntArrayRef batch_sizes, const ::std::optional & dropout_state, const at::Tensor & reserve, ::std::array output_mask) { + return at::_ops::miopen_rnn_backward::redispatch(dispatchKeySet, input, weight, weight_stride0, weight_buf, hx, cx, output, grad_output, grad_hy, grad_cy, mode, hidden_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state, reserve, output_mask); + } + + // aten::mm(Tensor self, Tensor mat2) -> Tensor + inline at::Tensor mm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat2) { + return at::_ops::mm::redispatch(dispatchKeySet, self, mat2); + } + + // aten::mm.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & mm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & mat2) { + return at::_ops::mm_out::redispatch(dispatchKeySet, self, mat2, out); + } + + // aten::mm.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & mm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat2, at::Tensor & out) { + return at::_ops::mm_out::redispatch(dispatchKeySet, self, mat2, out); + } + + // aten::mm.dtype(Tensor self, Tensor mat2, ScalarType out_dtype) -> Tensor + inline at::Tensor mm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat2, at::ScalarType out_dtype) { + return at::_ops::mm_dtype::redispatch(dispatchKeySet, self, mat2, out_dtype); + } + + // aten::mm.dtype_out(Tensor self, Tensor mat2, ScalarType out_dtype, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & mm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & mat2, at::ScalarType out_dtype) { + return at::_ops::mm_dtype_out::redispatch(dispatchKeySet, self, mat2, out_dtype, out); + } + + // aten::mm.dtype_out(Tensor self, Tensor mat2, ScalarType out_dtype, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & mm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat2, at::ScalarType out_dtype, at::Tensor & out) { + return at::_ops::mm_dtype_out::redispatch(dispatchKeySet, self, mat2, out_dtype, out); + } + + // aten::_int_mm(Tensor self, Tensor mat2) -> Tensor + inline at::Tensor _int_mm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat2) { + return at::_ops::_int_mm::redispatch(dispatchKeySet, self, mat2); + } + + // aten::_int_mm.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _int_mm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & mat2) { + return at::_ops::_int_mm_out::redispatch(dispatchKeySet, self, mat2, out); + } + + // aten::_int_mm.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & _int_mm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat2, at::Tensor & out) { + return at::_ops::_int_mm_out::redispatch(dispatchKeySet, self, mat2, out); + } + + // aten::_convert_weight_to_int4pack(Tensor self, int innerKTiles) -> Tensor + inline at::Tensor _convert_weight_to_int4pack(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t innerKTiles) { + return at::_ops::_convert_weight_to_int4pack::redispatch(dispatchKeySet, self, innerKTiles); + } + + // aten::_weight_int4pack_mm(Tensor self, Tensor mat2, int qGroupSize, Tensor qScaleAndZeros) -> Tensor + inline at::Tensor _weight_int4pack_mm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat2, int64_t qGroupSize, const at::Tensor & qScaleAndZeros) { + return at::_ops::_weight_int4pack_mm::redispatch(dispatchKeySet, self, mat2, qGroupSize, qScaleAndZeros); + } + + // aten::_weight_int4pack_mm_with_scales_and_zeros(Tensor self, Tensor mat2, int qGroupSize, Tensor qScale, Tensor qZeros) -> Tensor + inline at::Tensor _weight_int4pack_mm_with_scales_and_zeros(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat2, int64_t qGroupSize, const at::Tensor & qScale, const at::Tensor & qZeros) { + return at::_ops::_weight_int4pack_mm_with_scales_and_zeros::redispatch(dispatchKeySet, self, mat2, qGroupSize, qScale, qZeros); + } + + // aten::_convert_weight_to_int4pack_for_cpu(Tensor self, int innerKTiles) -> Tensor + inline at::Tensor _convert_weight_to_int4pack_for_cpu(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t innerKTiles) { + return at::_ops::_convert_weight_to_int4pack_for_cpu::redispatch(dispatchKeySet, self, innerKTiles); + } + + // aten::_weight_int4pack_mm_for_cpu(Tensor self, Tensor mat2, int qGroupSize, Tensor qScaleAndZeros) -> Tensor + inline at::Tensor _weight_int4pack_mm_for_cpu(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const 
at::Tensor & mat2, int64_t qGroupSize, const at::Tensor & qScaleAndZeros) { + return at::_ops::_weight_int4pack_mm_for_cpu::redispatch(dispatchKeySet, self, mat2, qGroupSize, qScaleAndZeros); + } + + // aten::_dyn_quant_pack_4bit_weight(Tensor weights, Tensor scales_zeros, Tensor? bias, int block_size, int in_features, int out_features) -> Tensor + inline at::Tensor _dyn_quant_pack_4bit_weight(c10::DispatchKeySet dispatchKeySet, const at::Tensor & weights, const at::Tensor & scales_zeros, const ::std::optional & bias, int64_t block_size, int64_t in_features, int64_t out_features) { + return at::_ops::_dyn_quant_pack_4bit_weight::redispatch(dispatchKeySet, weights, scales_zeros, bias, block_size, in_features, out_features); + } + + // aten::_dyn_quant_matmul_4bit(Tensor inp, Tensor packed_weights, int block_size, int in_features, int out_features) -> Tensor + inline at::Tensor _dyn_quant_matmul_4bit(c10::DispatchKeySet dispatchKeySet, const at::Tensor & inp, const at::Tensor & packed_weights, int64_t block_size, int64_t in_features, int64_t out_features) { + return at::_ops::_dyn_quant_matmul_4bit::redispatch(dispatchKeySet, inp, packed_weights, block_size, in_features, out_features); + } + + // aten::_weight_int8pack_mm(Tensor self, Tensor mat2, Tensor scales) -> Tensor + inline at::Tensor _weight_int8pack_mm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat2, const at::Tensor & scales) { + return at::_ops::_weight_int8pack_mm::redispatch(dispatchKeySet, self, mat2, scales); + } + + // aten::_sparse_mm(Tensor sparse, Tensor dense) -> Tensor + inline at::Tensor _sparse_mm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & sparse, const at::Tensor & dense) { + return at::_ops::_sparse_mm::redispatch(dispatchKeySet, sparse, dense); + } + + // aten::_sparse_mm.reduce(Tensor sparse, Tensor dense, str reduce) -> Tensor + inline at::Tensor _sparse_mm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & sparse, const at::Tensor & 
dense, c10::string_view reduce) { + return at::_ops::_sparse_mm_reduce::redispatch(dispatchKeySet, sparse, dense, reduce); + } + + // aten::_sparse_sparse_matmul(Tensor self, Tensor other) -> Tensor + inline at::Tensor _sparse_sparse_matmul(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::_sparse_sparse_matmul::redispatch(dispatchKeySet, self, other); + } + + // aten::mode(Tensor self, int dim=-1, bool keepdim=False) -> (Tensor values, Tensor indices) + inline ::std::tuple mode(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim=-1, bool keepdim=false) { + return at::_ops::mode::redispatch(dispatchKeySet, self, dim, keepdim); + } + + // aten::mode.values(Tensor self, int dim=-1, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + inline ::std::tuple mode_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & values, at::Tensor & indices, const at::Tensor & self, int64_t dim=-1, bool keepdim=false) { + return at::_ops::mode_values::redispatch(dispatchKeySet, self, dim, keepdim, values, indices); + } + + // aten::mode.values(Tensor self, int dim=-1, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + inline ::std::tuple mode_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool keepdim, at::Tensor & values, at::Tensor & indices) { + return at::_ops::mode_values::redispatch(dispatchKeySet, self, dim, keepdim, values, indices); + } + + // aten::mode.dimname(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices) + inline ::std::tuple mode(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, bool keepdim=false) { + return at::_ops::mode_dimname::redispatch(dispatchKeySet, self, dim, keepdim); + } + + // aten::mode.dimname_out(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) 
indices) -> (Tensor(a!) values, Tensor(b!) indices) + inline ::std::tuple mode_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & values, at::Tensor & indices, const at::Tensor & self, at::Dimname dim, bool keepdim=false) { + return at::_ops::mode_dimname_out::redispatch(dispatchKeySet, self, dim, keepdim, values, indices); + } + + // aten::mode.dimname_out(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + inline ::std::tuple mode_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, bool keepdim, at::Tensor & values, at::Tensor & indices) { + return at::_ops::mode_dimname_out::redispatch(dispatchKeySet, self, dim, keepdim, values, indices); + } + + // aten::mul.Tensor(Tensor self, Tensor other) -> Tensor + inline at::Tensor mul(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::mul_Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::mul_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + inline at::Tensor & mul_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) { + return at::_ops::mul__Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::mul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & mul_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::mul_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::mul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & mul_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::mul_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::mul.Scalar(Tensor self, Scalar other) -> Tensor + inline at::Tensor mul(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::mul_Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::mul_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + inline at::Tensor & mul_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) { + return at::_ops::mul__Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::multiply.Tensor(Tensor self, Tensor other) -> Tensor + inline at::Tensor multiply(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::multiply_Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::multiply_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + inline at::Tensor & multiply_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) { + return at::_ops::multiply__Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::multiply.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & multiply_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::multiply_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::multiply.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & multiply_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::multiply_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::multiply.Scalar(Tensor self, Scalar other) -> Tensor + inline at::Tensor multiply(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::multiply_Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::multiply_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + inline at::Tensor & multiply_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) { + return at::_ops::multiply__Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::mv(Tensor self, Tensor vec) -> Tensor + inline at::Tensor mv(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & vec) { + return at::_ops::mv::redispatch(dispatchKeySet, self, vec); + } + + // aten::mv.out(Tensor self, Tensor vec, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & mv_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & vec) { + return at::_ops::mv_out::redispatch(dispatchKeySet, self, vec, out); + } + + // aten::mv.out(Tensor self, Tensor vec, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & mv_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & vec, at::Tensor & out) { + return at::_ops::mv_out::redispatch(dispatchKeySet, self, vec, out); + } + + // aten::mvlgamma.out(Tensor self, int p, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & mvlgamma_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t p) { + return at::_ops::mvlgamma_out::redispatch(dispatchKeySet, self, p, out); + } + + // aten::mvlgamma.out(Tensor self, int p, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & mvlgamma_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t p, at::Tensor & out) { + return at::_ops::mvlgamma_out::redispatch(dispatchKeySet, self, p, out); + } + + // aten::mvlgamma(Tensor self, int p) -> Tensor + inline at::Tensor mvlgamma(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t p) { + return at::_ops::mvlgamma::redispatch(dispatchKeySet, self, p); + } + + // aten::mvlgamma_(Tensor(a!) self, int p) -> Tensor(a!) + inline at::Tensor & mvlgamma_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, int64_t p) { + return at::_ops::mvlgamma_::redispatch(dispatchKeySet, self, p); + } + + // aten::narrow_copy(Tensor self, int dim, SymInt start, SymInt length) -> Tensor + inline at::Tensor narrow_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, int64_t start, int64_t length) { + return at::_ops::narrow_copy::redispatch(dispatchKeySet, self, dim, start, length); + } + + // aten::narrow_copy(Tensor self, int dim, SymInt start, SymInt length) -> Tensor + inline at::Tensor narrow_copy_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, c10::SymInt start, c10::SymInt length) { + return at::_ops::narrow_copy::redispatch(dispatchKeySet, self, dim, start, length); + } + + // aten::narrow_copy.out(Tensor self, int dim, SymInt start, SymInt length, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & narrow_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim, int64_t start, int64_t length) { + return at::_ops::narrow_copy_out::redispatch(dispatchKeySet, self, dim, start, length, out); + } + + // aten::narrow_copy.out(Tensor self, int dim, SymInt start, SymInt length, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & narrow_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, int64_t start, int64_t length, at::Tensor & out) { + return at::_ops::narrow_copy_out::redispatch(dispatchKeySet, self, dim, start, length, out); + } + + // aten::narrow_copy.out(Tensor self, int dim, SymInt start, SymInt length, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & narrow_copy_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim, c10::SymInt start, c10::SymInt length) { + return at::_ops::narrow_copy_out::redispatch(dispatchKeySet, self, dim, start, length, out); + } + + // aten::narrow_copy.out(Tensor self, int dim, SymInt start, SymInt length, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & narrow_copy_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, c10::SymInt start, c10::SymInt length, at::Tensor & out) { + return at::_ops::narrow_copy_out::redispatch(dispatchKeySet, self, dim, start, length, out); + } + + // aten::narrow(Tensor(a) self, int dim, SymInt start, SymInt length) -> Tensor(a) + inline at::Tensor narrow(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, int64_t start, int64_t length) { + return at::_ops::narrow::redispatch(dispatchKeySet, self, dim, start, length); + } + + // aten::narrow(Tensor(a) self, int dim, SymInt start, SymInt length) -> Tensor(a) + inline at::Tensor narrow_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, c10::SymInt start, c10::SymInt length) { + return at::_ops::narrow::redispatch(dispatchKeySet, self, dim, start, length); + } + + // aten::narrow.Tensor(Tensor(a) self, int dim, Tensor start, SymInt length) -> Tensor(a) + inline at::Tensor narrow(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, const at::Tensor & start, int64_t length) { + return at::_ops::narrow_Tensor::redispatch(dispatchKeySet, self, dim, start, length); + 
} + + // aten::narrow.Tensor(Tensor(a) self, int dim, Tensor start, SymInt length) -> Tensor(a) + inline at::Tensor narrow_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, const at::Tensor & start, c10::SymInt length) { + return at::_ops::narrow_Tensor::redispatch(dispatchKeySet, self, dim, start, length); + } + + // aten::native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor) + inline ::std::tuple native_batch_norm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const ::std::optional & running_mean, const ::std::optional & running_var, bool training, double momentum, double eps) { + return at::_ops::native_batch_norm::redispatch(dispatchKeySet, input, weight, bias, running_mean, running_var, training, momentum, eps); + } + + // aten::native_batch_norm.out(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, *, Tensor(a!) out, Tensor(b!) save_mean, Tensor(c!) save_invstd) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple native_batch_norm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::Tensor & save_mean, at::Tensor & save_invstd, const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const ::std::optional & running_mean, const ::std::optional & running_var, bool training, double momentum, double eps) { + return at::_ops::native_batch_norm_out::redispatch(dispatchKeySet, input, weight, bias, running_mean, running_var, training, momentum, eps, out, save_mean, save_invstd); + } + + // aten::native_batch_norm.out(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, *, Tensor(a!) out, Tensor(b!) save_mean, Tensor(c!) 
save_invstd) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple native_batch_norm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const ::std::optional & running_mean, const ::std::optional & running_var, bool training, double momentum, double eps, at::Tensor & out, at::Tensor & save_mean, at::Tensor & save_invstd) { + return at::_ops::native_batch_norm_out::redispatch(dispatchKeySet, input, weight, bias, running_mean, running_var, training, momentum, eps, out, save_mean, save_invstd); + } + + // aten::_native_batch_norm_legit(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor) + inline ::std::tuple _native_batch_norm_legit(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, at::Tensor & running_mean, at::Tensor & running_var, bool training, double momentum, double eps) { + return at::_ops::_native_batch_norm_legit::redispatch(dispatchKeySet, input, weight, bias, running_mean, running_var, training, momentum, eps); + } + + // aten::_native_batch_norm_legit_no_training(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor) + inline ::std::tuple _native_batch_norm_legit_no_training(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const at::Tensor & running_mean, const at::Tensor & running_var, double momentum, double eps) { + return at::_ops::_native_batch_norm_legit_no_training::redispatch(dispatchKeySet, input, weight, bias, running_mean, running_var, momentum, eps); + } + + // aten::_native_batch_norm_legit.out(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) 
running_var, bool training, float momentum, float eps, *, Tensor(d!) out, Tensor(e!) save_mean, Tensor(f!) save_invstd) -> (Tensor(d!), Tensor(e!), Tensor(f!)) + inline ::std::tuple _native_batch_norm_legit_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::Tensor & save_mean, at::Tensor & save_invstd, const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, at::Tensor & running_mean, at::Tensor & running_var, bool training, double momentum, double eps) { + return at::_ops::_native_batch_norm_legit_out::redispatch(dispatchKeySet, input, weight, bias, running_mean, running_var, training, momentum, eps, out, save_mean, save_invstd); + } + + // aten::_native_batch_norm_legit.out(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, bool training, float momentum, float eps, *, Tensor(d!) out, Tensor(e!) save_mean, Tensor(f!) save_invstd) -> (Tensor(d!), Tensor(e!), Tensor(f!)) + inline ::std::tuple _native_batch_norm_legit_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, at::Tensor & running_mean, at::Tensor & running_var, bool training, double momentum, double eps, at::Tensor & out, at::Tensor & save_mean, at::Tensor & save_invstd) { + return at::_ops::_native_batch_norm_legit_out::redispatch(dispatchKeySet, input, weight, bias, running_mean, running_var, training, momentum, eps, out, save_mean, save_invstd); + } + + // aten::_native_batch_norm_legit.no_stats(Tensor input, Tensor? weight, Tensor? 
bias, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor) + inline ::std::tuple _native_batch_norm_legit(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, bool training, double momentum, double eps) { + return at::_ops::_native_batch_norm_legit_no_stats::redispatch(dispatchKeySet, input, weight, bias, training, momentum, eps); + } + + // aten::_native_batch_norm_legit.no_stats_out(Tensor input, Tensor? weight, Tensor? bias, bool training, float momentum, float eps, *, Tensor(a!) out, Tensor(b!) save_mean, Tensor(c!) save_invstd) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple _native_batch_norm_legit_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::Tensor & save_mean, at::Tensor & save_invstd, const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, bool training, double momentum, double eps) { + return at::_ops::_native_batch_norm_legit_no_stats_out::redispatch(dispatchKeySet, input, weight, bias, training, momentum, eps, out, save_mean, save_invstd); + } + + // aten::_native_batch_norm_legit.no_stats_out(Tensor input, Tensor? weight, Tensor? bias, bool training, float momentum, float eps, *, Tensor(a!) out, Tensor(b!) save_mean, Tensor(c!) 
save_invstd) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple _native_batch_norm_legit_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, bool training, double momentum, double eps, at::Tensor & out, at::Tensor & save_mean, at::Tensor & save_invstd) { + return at::_ops::_native_batch_norm_legit_no_stats_out::redispatch(dispatchKeySet, input, weight, bias, training, momentum, eps, out, save_mean, save_invstd); + } + + // aten::batch_norm_stats(Tensor input, float eps) -> (Tensor, Tensor) + inline ::std::tuple batch_norm_stats(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, double eps) { + return at::_ops::batch_norm_stats::redispatch(dispatchKeySet, input, eps); + } + + // aten::batch_norm_elemt(Tensor input, Tensor? weight, Tensor? bias, Tensor mean, Tensor invstd, float eps) -> Tensor + inline at::Tensor batch_norm_elemt(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const at::Tensor & mean, const at::Tensor & invstd, double eps) { + return at::_ops::batch_norm_elemt::redispatch(dispatchKeySet, input, weight, bias, mean, invstd, eps); + } + + // aten::batch_norm_elemt.out(Tensor input, Tensor? weight, Tensor? bias, Tensor mean, Tensor invstd, float eps, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & batch_norm_elemt_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const at::Tensor & mean, const at::Tensor & invstd, double eps) { + return at::_ops::batch_norm_elemt_out::redispatch(dispatchKeySet, input, weight, bias, mean, invstd, eps, out); + } + + // aten::batch_norm_elemt.out(Tensor input, Tensor? weight, Tensor? bias, Tensor mean, Tensor invstd, float eps, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & batch_norm_elemt_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const at::Tensor & mean, const at::Tensor & invstd, double eps, at::Tensor & out) { + return at::_ops::batch_norm_elemt_out::redispatch(dispatchKeySet, input, weight, bias, mean, invstd, eps, out); + } + + // aten::batch_norm_gather_stats(Tensor input, Tensor mean, Tensor invstd, Tensor? running_mean, Tensor? running_var, float momentum, float eps, int count) -> (Tensor, Tensor) + inline ::std::tuple batch_norm_gather_stats(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & running_mean, const ::std::optional & running_var, double momentum, double eps, int64_t count) { + return at::_ops::batch_norm_gather_stats::redispatch(dispatchKeySet, input, mean, invstd, running_mean, running_var, momentum, eps, count); + } + + // aten::batch_norm_gather_stats_with_counts(Tensor input, Tensor mean, Tensor invstd, Tensor? running_mean, Tensor? running_var, float momentum, float eps, Tensor counts) -> (Tensor, Tensor) + inline ::std::tuple batch_norm_gather_stats_with_counts(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & running_mean, const ::std::optional & running_var, double momentum, double eps, const at::Tensor & counts) { + return at::_ops::batch_norm_gather_stats_with_counts::redispatch(dispatchKeySet, input, mean, invstd, running_mean, running_var, momentum, eps, counts); + } + + // aten::native_batch_norm_backward(Tensor grad_out, Tensor input, Tensor? weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? 
save_invstd, bool train, float eps, bool[3] output_mask) -> (Tensor, Tensor, Tensor) + inline ::std::tuple native_batch_norm_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out, const at::Tensor & input, const ::std::optional & weight, const ::std::optional & running_mean, const ::std::optional & running_var, const ::std::optional & save_mean, const ::std::optional & save_invstd, bool train, double eps, ::std::array output_mask) { + return at::_ops::native_batch_norm_backward::redispatch(dispatchKeySet, grad_out, input, weight, running_mean, running_var, save_mean, save_invstd, train, eps, output_mask); + } + + // aten::batch_norm_backward_reduce(Tensor grad_out, Tensor input, Tensor mean, Tensor invstd, Tensor? weight, bool input_g, bool weight_g, bool bias_g) -> (Tensor, Tensor, Tensor, Tensor) + inline ::std::tuple batch_norm_backward_reduce(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & weight, bool input_g, bool weight_g, bool bias_g) { + return at::_ops::batch_norm_backward_reduce::redispatch(dispatchKeySet, grad_out, input, mean, invstd, weight, input_g, weight_g, bias_g); + } + + // aten::batch_norm_backward_elemt(Tensor grad_out, Tensor input, Tensor mean, Tensor invstd, Tensor? weight, Tensor sum_dy, Tensor sum_dy_xmu, Tensor count) -> Tensor + inline at::Tensor batch_norm_backward_elemt(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & weight, const at::Tensor & sum_dy, const at::Tensor & sum_dy_xmu, const at::Tensor & count) { + return at::_ops::batch_norm_backward_elemt::redispatch(dispatchKeySet, grad_out, input, mean, invstd, weight, sum_dy, sum_dy_xmu, count); + } + + // aten::batch_norm_update_stats(Tensor input, Tensor? running_mean, Tensor? 
running_var, float momentum) -> (Tensor, Tensor) + inline ::std::tuple batch_norm_update_stats(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const ::std::optional & running_mean, const ::std::optional & running_var, double momentum) { + return at::_ops::batch_norm_update_stats::redispatch(dispatchKeySet, input, running_mean, running_var, momentum); + } + + // aten::is_vulkan_available() -> bool + inline bool is_vulkan_available(c10::DispatchKeySet dispatchKeySet) { + return at::_ops::is_vulkan_available::redispatch(dispatchKeySet); + } + + // aten::_nnpack_available() -> bool + inline bool _nnpack_available(c10::DispatchKeySet dispatchKeySet) { + return at::_ops::_nnpack_available::redispatch(dispatchKeySet); + } + + // aten::_nnpack_spatial_convolution(Tensor input, Tensor weight, Tensor? bias, SymInt[2] padding, SymInt[2] stride=1) -> Tensor + inline at::Tensor _nnpack_spatial_convolution(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, at::IntArrayRef padding, at::IntArrayRef stride=1) { + return at::_ops::_nnpack_spatial_convolution::redispatch(dispatchKeySet, input, weight, bias, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(stride)); + } + + // aten::_nnpack_spatial_convolution(Tensor input, Tensor weight, Tensor? bias, SymInt[2] padding, SymInt[2] stride=1) -> Tensor + inline at::Tensor _nnpack_spatial_convolution_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride=c10::SymInt(1)) { + return at::_ops::_nnpack_spatial_convolution::redispatch(dispatchKeySet, input, weight, bias, padding, stride); + } + + // aten::ones.names(int[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor + inline at::Tensor ones(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, ::std::optional names, at::TensorOptions options={}) { + return at::_ops::ones_names::redispatch(dispatchKeySet, size, names, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::ones.names(int[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor ones(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, ::std::optional names, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::ones_names::redispatch(dispatchKeySet, size, names, dtype, layout, device, pin_memory); + } + + // aten::ones(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor ones(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, at::TensorOptions options={}) { + return at::_ops::ones::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::ones(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor ones(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::ones::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), dtype, layout, device, pin_memory); + } + + // aten::ones(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor + inline at::Tensor ones_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, at::TensorOptions options={}) { + return at::_ops::ones::redispatch(dispatchKeySet, size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::ones(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor ones_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::ones::redispatch(dispatchKeySet, size, dtype, layout, device, pin_memory); + } + + // aten::ones.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & ones_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::IntArrayRef size) { + return at::_ops::ones_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), out); + } + + // aten::ones.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & ones_outf(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, at::Tensor & out) { + return at::_ops::ones_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), out); + } + + // aten::ones.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & ones_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, c10::SymIntArrayRef size) { + return at::_ops::ones_out::redispatch(dispatchKeySet, size, out); + } + + // aten::ones.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & ones_symint_outf(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, at::Tensor & out) { + return at::_ops::ones_out::redispatch(dispatchKeySet, size, out); + } + + // aten::ones_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? 
memory_format=None) -> Tensor + inline at::Tensor ones_like(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::TensorOptions options={}, ::std::optional memory_format=::std::nullopt) { + return at::_ops::ones_like::redispatch(dispatchKeySet, self, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format)); + } + + // aten::ones_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + inline at::Tensor ones_like(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format) { + return at::_ops::ones_like::redispatch(dispatchKeySet, self, dtype, layout, device, pin_memory, memory_format); + } + + // aten::pairwise_distance(Tensor x1, Tensor x2, float p=2, float eps=1e-06, bool keepdim=False) -> Tensor + inline at::Tensor pairwise_distance(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x1, const at::Tensor & x2, double p=2, double eps=1e-06, bool keepdim=false) { + return at::_ops::pairwise_distance::redispatch(dispatchKeySet, x1, x2, p, eps, keepdim); + } + + // aten::cdist(Tensor x1, Tensor x2, float p=2, int? 
compute_mode=None) -> Tensor + inline at::Tensor cdist(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x1, const at::Tensor & x2, double p=2, ::std::optional compute_mode=::std::nullopt) { + return at::_ops::cdist::redispatch(dispatchKeySet, x1, x2, p, compute_mode); + } + + // aten::_euclidean_dist(Tensor x1, Tensor x2) -> Tensor + inline at::Tensor _euclidean_dist(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x1, const at::Tensor & x2) { + return at::_ops::_euclidean_dist::redispatch(dispatchKeySet, x1, x2); + } + + // aten::_cdist_forward(Tensor x1, Tensor x2, float p, int? compute_mode) -> Tensor + inline at::Tensor _cdist_forward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x1, const at::Tensor & x2, double p, ::std::optional compute_mode) { + return at::_ops::_cdist_forward::redispatch(dispatchKeySet, x1, x2, p, compute_mode); + } + + // aten::_cdist_backward(Tensor grad, Tensor x1, Tensor x2, float p, Tensor cdist) -> Tensor + inline at::Tensor _cdist_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & x1, const at::Tensor & x2, double p, const at::Tensor & cdist) { + return at::_ops::_cdist_backward::redispatch(dispatchKeySet, grad, x1, x2, p, cdist); + } + + // aten::pdist(Tensor self, float p=2) -> Tensor + inline at::Tensor pdist(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double p=2) { + return at::_ops::pdist::redispatch(dispatchKeySet, self, p); + } + + // aten::_pdist_forward(Tensor self, float p=2) -> Tensor + inline at::Tensor _pdist_forward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double p=2) { + return at::_ops::_pdist_forward::redispatch(dispatchKeySet, self, p); + } + + // aten::_pdist_backward(Tensor grad, Tensor self, float p, Tensor pdist) -> Tensor + inline at::Tensor _pdist_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & self, double p, const at::Tensor & pdist) { + return 
at::_ops::_pdist_backward::redispatch(dispatchKeySet, grad, self, p, pdist); + } + + // aten::cosine_similarity(Tensor x1, Tensor x2, int dim=1, float eps=1e-08) -> Tensor + inline at::Tensor cosine_similarity(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x1, const at::Tensor & x2, int64_t dim=1, double eps=1e-08) { + return at::_ops::cosine_similarity::redispatch(dispatchKeySet, x1, x2, dim, eps); + } + + // aten::permute(Tensor(a) self, int[] dims) -> Tensor(a) + inline at::Tensor permute(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dims) { + return at::_ops::permute::redispatch(dispatchKeySet, self, dims); + } + + // aten::movedim.intlist(Tensor(a) self, int[] source, int[] destination) -> Tensor(a) + inline at::Tensor movedim(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef source, at::IntArrayRef destination) { + return at::_ops::movedim_intlist::redispatch(dispatchKeySet, self, source, destination); + } + + // aten::movedim.int(Tensor(a) self, int source, int destination) -> Tensor(a) + inline at::Tensor movedim(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t source, int64_t destination) { + return at::_ops::movedim_int::redispatch(dispatchKeySet, self, source, destination); + } + + // aten::moveaxis.intlist(Tensor(a) self, int[] source, int[] destination) -> Tensor(a) + inline at::Tensor moveaxis(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef source, at::IntArrayRef destination) { + return at::_ops::moveaxis_intlist::redispatch(dispatchKeySet, self, source, destination); + } + + // aten::moveaxis.int(Tensor(a) self, int source, int destination) -> Tensor(a) + inline at::Tensor moveaxis(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t source, int64_t destination) { + return at::_ops::moveaxis_int::redispatch(dispatchKeySet, self, source, destination); + } + + // aten::numpy_T(Tensor(a) self) -> Tensor(a) + inline 
at::Tensor numpy_T(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::numpy_T::redispatch(dispatchKeySet, self); + } + + // aten::matrix_H(Tensor(a) self) -> Tensor(a) + inline at::Tensor matrix_H(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::matrix_H::redispatch(dispatchKeySet, self); + } + + // aten::mT(Tensor(a) self) -> Tensor(a) + inline at::Tensor mT(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::mT::redispatch(dispatchKeySet, self); + } + + // aten::mH(Tensor(a) self) -> Tensor(a) + inline at::Tensor mH(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::mH::redispatch(dispatchKeySet, self); + } + + // aten::adjoint(Tensor(a) self) -> Tensor(a) + inline at::Tensor adjoint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::adjoint::redispatch(dispatchKeySet, self); + } + + // aten::pixel_shuffle(Tensor self, int upscale_factor) -> Tensor + inline at::Tensor pixel_shuffle(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t upscale_factor) { + return at::_ops::pixel_shuffle::redispatch(dispatchKeySet, self, upscale_factor); + } + + // aten::pixel_unshuffle(Tensor self, int downscale_factor) -> Tensor + inline at::Tensor pixel_unshuffle(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t downscale_factor) { + return at::_ops::pixel_unshuffle::redispatch(dispatchKeySet, self, downscale_factor); + } + + // aten::channel_shuffle(Tensor self, SymInt groups) -> Tensor + inline at::Tensor channel_shuffle(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t groups) { + return at::_ops::channel_shuffle::redispatch(dispatchKeySet, self, groups); + } + + // aten::channel_shuffle(Tensor self, SymInt groups) -> Tensor + inline at::Tensor channel_shuffle_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymInt groups) { + return 
at::_ops::channel_shuffle::redispatch(dispatchKeySet, self, groups); + } + + // aten::native_channel_shuffle(Tensor self, SymInt groups) -> Tensor + inline at::Tensor native_channel_shuffle(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t groups) { + return at::_ops::native_channel_shuffle::redispatch(dispatchKeySet, self, groups); + } + + // aten::native_channel_shuffle(Tensor self, SymInt groups) -> Tensor + inline at::Tensor native_channel_shuffle_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymInt groups) { + return at::_ops::native_channel_shuffle::redispatch(dispatchKeySet, self, groups); + } + + // aten::is_pinned(Tensor self, Device? device=None) -> bool + inline bool is_pinned(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional device=::std::nullopt) { + return at::_ops::is_pinned::redispatch(dispatchKeySet, self, device); + } + + // aten::pin_memory(Tensor(a) self, Device? device=None) -> Tensor(a) + inline at::Tensor pin_memory(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional device=::std::nullopt) { + return at::_ops::pin_memory::redispatch(dispatchKeySet, self, device); + } + + // aten::_pin_memory(Tensor self, Device? 
device=None) -> Tensor + inline at::Tensor _pin_memory(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional device=::std::nullopt) { + return at::_ops::_pin_memory::redispatch(dispatchKeySet, self, device); + } + + // aten::pinverse(Tensor self, float rcond=1e-15) -> Tensor + inline at::Tensor pinverse(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double rcond=1e-15) { + return at::_ops::pinverse::redispatch(dispatchKeySet, self, rcond); + } + + // aten::poisson_nll_loss(Tensor input, Tensor target, bool log_input, bool full, float eps, int reduction) -> Tensor + inline at::Tensor poisson_nll_loss(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & target, bool log_input, bool full, double eps, int64_t reduction) { + return at::_ops::poisson_nll_loss::redispatch(dispatchKeySet, input, target, log_input, full, eps, reduction); + } + + // aten::rad2deg(Tensor self) -> Tensor + inline at::Tensor rad2deg(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::rad2deg::redispatch(dispatchKeySet, self); + } + + // aten::rad2deg_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & rad2deg_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::rad2deg_::redispatch(dispatchKeySet, self); + } + + // aten::rad2deg.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & rad2deg_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::rad2deg_out::redispatch(dispatchKeySet, self, out); + } + + // aten::rad2deg.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & rad2deg_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::rad2deg_out::redispatch(dispatchKeySet, self, out); + } + + // aten::deg2rad(Tensor self) -> Tensor + inline at::Tensor deg2rad(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::deg2rad::redispatch(dispatchKeySet, self); + } + + // aten::deg2rad_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & deg2rad_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::deg2rad_::redispatch(dispatchKeySet, self); + } + + // aten::deg2rad.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & deg2rad_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::deg2rad_out::redispatch(dispatchKeySet, self, out); + } + + // aten::deg2rad.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & deg2rad_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::deg2rad_out::redispatch(dispatchKeySet, self, out); + } + + // aten::scalar_tensor(Scalar s, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor scalar_tensor(c10::DispatchKeySet dispatchKeySet, const at::Scalar & s, at::TensorOptions options={}) { + return at::_ops::scalar_tensor::redispatch(dispatchKeySet, s, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::scalar_tensor(Scalar s, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor + inline at::Tensor scalar_tensor(c10::DispatchKeySet dispatchKeySet, const at::Scalar & s, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::scalar_tensor::redispatch(dispatchKeySet, s, dtype, layout, device, pin_memory); + } + + // aten::rand.names(SymInt[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor rand(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, ::std::optional names, at::TensorOptions options={}) { + return at::_ops::rand_names::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), names, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::rand.names(SymInt[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor rand(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, ::std::optional names, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::rand_names::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), names, dtype, layout, device, pin_memory); + } + + // aten::rand.names(SymInt[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor rand_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, ::std::optional names, at::TensorOptions options={}) { + return at::_ops::rand_names::redispatch(dispatchKeySet, size, names, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::rand.names(SymInt[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor + inline at::Tensor rand_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, ::std::optional names, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::rand_names::redispatch(dispatchKeySet, size, names, dtype, layout, device, pin_memory); + } + + // aten::rand.generator_with_names(SymInt[] size, *, Generator? generator, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor rand(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, ::std::optional generator, ::std::optional names, at::TensorOptions options={}) { + return at::_ops::rand_generator_with_names::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), generator, names, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::rand.generator_with_names(SymInt[] size, *, Generator? generator, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor rand(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, ::std::optional generator, ::std::optional names, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::rand_generator_with_names::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), generator, names, dtype, layout, device, pin_memory); + } + + // aten::rand.generator_with_names(SymInt[] size, *, Generator? generator, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor + inline at::Tensor rand_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, ::std::optional generator, ::std::optional names, at::TensorOptions options={}) { + return at::_ops::rand_generator_with_names::redispatch(dispatchKeySet, size, generator, names, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::rand.generator_with_names(SymInt[] size, *, Generator? generator, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor rand_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, ::std::optional generator, ::std::optional names, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::rand_generator_with_names::redispatch(dispatchKeySet, size, generator, names, dtype, layout, device, pin_memory); + } + + // aten::rand(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor rand(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, at::TensorOptions options={}) { + return at::_ops::rand::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::rand(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor rand(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::rand::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), dtype, layout, device, pin_memory); + } + + // aten::rand(SymInt[] size, *, ScalarType? dtype=None, Layout? 
layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor rand_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, at::TensorOptions options={}) { + return at::_ops::rand::redispatch(dispatchKeySet, size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::rand(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor rand_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::rand::redispatch(dispatchKeySet, size, dtype, layout, device, pin_memory); + } + + // aten::rand.generator(SymInt[] size, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor rand(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, ::std::optional generator, at::TensorOptions options={}) { + return at::_ops::rand_generator::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), generator, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::rand.generator(SymInt[] size, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor rand(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, ::std::optional generator, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::rand_generator::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), generator, dtype, layout, device, pin_memory); + } + + // aten::rand.generator(SymInt[] size, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? 
device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor rand_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, ::std::optional generator, at::TensorOptions options={}) { + return at::_ops::rand_generator::redispatch(dispatchKeySet, size, generator, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::rand.generator(SymInt[] size, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor rand_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, ::std::optional generator, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::rand_generator::redispatch(dispatchKeySet, size, generator, dtype, layout, device, pin_memory); + } + + // aten::rand.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & rand_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::IntArrayRef size) { + return at::_ops::rand_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), out); + } + + // aten::rand.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & rand_outf(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, at::Tensor & out) { + return at::_ops::rand_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), out); + } + + // aten::rand.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & rand_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, c10::SymIntArrayRef size) { + return at::_ops::rand_out::redispatch(dispatchKeySet, size, out); + } + + // aten::rand.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & rand_symint_outf(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, at::Tensor & out) { + return at::_ops::rand_out::redispatch(dispatchKeySet, size, out); + } + + // aten::rand.generator_out(SymInt[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & rand_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::IntArrayRef size, ::std::optional generator) { + return at::_ops::rand_generator_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), generator, out); + } + + // aten::rand.generator_out(SymInt[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & rand_outf(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, ::std::optional generator, at::Tensor & out) { + return at::_ops::rand_generator_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), generator, out); + } + + // aten::rand.generator_out(SymInt[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & rand_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, c10::SymIntArrayRef size, ::std::optional generator) { + return at::_ops::rand_generator_out::redispatch(dispatchKeySet, size, generator, out); + } + + // aten::rand.generator_out(SymInt[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & rand_symint_outf(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, ::std::optional generator, at::Tensor & out) { + return at::_ops::rand_generator_out::redispatch(dispatchKeySet, size, generator, out); + } + + // aten::rand_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? 
memory_format=None) -> Tensor + inline at::Tensor rand_like(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::TensorOptions options={}, ::std::optional memory_format=::std::nullopt) { + return at::_ops::rand_like::redispatch(dispatchKeySet, self, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format)); + } + + // aten::rand_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + inline at::Tensor rand_like(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format) { + return at::_ops::rand_like::redispatch(dispatchKeySet, self, dtype, layout, device, pin_memory, memory_format); + } + + // aten::rand_like.generator(Tensor self, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + inline at::Tensor rand_like(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional generator, at::TensorOptions options={}, ::std::optional memory_format=::std::nullopt) { + return at::_ops::rand_like_generator::redispatch(dispatchKeySet, self, generator, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format)); + } + + // aten::rand_like.generator(Tensor self, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? 
memory_format=None) -> Tensor + inline at::Tensor rand_like(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional generator, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format) { + return at::_ops::rand_like_generator::redispatch(dispatchKeySet, self, generator, dtype, layout, device, pin_memory, memory_format); + } + + // aten::randint(SymInt high, SymInt[] size, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor randint(c10::DispatchKeySet dispatchKeySet, int64_t high, at::IntArrayRef size, at::TensorOptions options=at::kLong) { + return at::_ops::randint::redispatch(dispatchKeySet, high, c10::fromIntArrayRefSlow(size), c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::randint(SymInt high, SymInt[] size, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor randint(c10::DispatchKeySet dispatchKeySet, int64_t high, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::randint::redispatch(dispatchKeySet, high, c10::fromIntArrayRefSlow(size), dtype, layout, device, pin_memory); + } + + // aten::randint(SymInt high, SymInt[] size, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor randint_symint(c10::DispatchKeySet dispatchKeySet, c10::SymInt high, c10::SymIntArrayRef size, at::TensorOptions options=at::kLong) { + return at::_ops::randint::redispatch(dispatchKeySet, high, size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::randint(SymInt high, SymInt[] size, *, ScalarType? 
dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor randint_symint(c10::DispatchKeySet dispatchKeySet, c10::SymInt high, c10::SymIntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::randint::redispatch(dispatchKeySet, high, size, dtype, layout, device, pin_memory); + } + + // aten::randint.generator(SymInt high, SymInt[] size, *, Generator? generator, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor randint(c10::DispatchKeySet dispatchKeySet, int64_t high, at::IntArrayRef size, ::std::optional generator, at::TensorOptions options=at::kLong) { + return at::_ops::randint_generator::redispatch(dispatchKeySet, high, c10::fromIntArrayRefSlow(size), generator, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::randint.generator(SymInt high, SymInt[] size, *, Generator? generator, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor randint(c10::DispatchKeySet dispatchKeySet, int64_t high, at::IntArrayRef size, ::std::optional generator, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::randint_generator::redispatch(dispatchKeySet, high, c10::fromIntArrayRefSlow(size), generator, dtype, layout, device, pin_memory); + } + + // aten::randint.generator(SymInt high, SymInt[] size, *, Generator? generator, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor + inline at::Tensor randint_symint(c10::DispatchKeySet dispatchKeySet, c10::SymInt high, c10::SymIntArrayRef size, ::std::optional generator, at::TensorOptions options=at::kLong) { + return at::_ops::randint_generator::redispatch(dispatchKeySet, high, size, generator, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::randint.generator(SymInt high, SymInt[] size, *, Generator? generator, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor randint_symint(c10::DispatchKeySet dispatchKeySet, c10::SymInt high, c10::SymIntArrayRef size, ::std::optional generator, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::randint_generator::redispatch(dispatchKeySet, high, size, generator, dtype, layout, device, pin_memory); + } + + // aten::randint.low(SymInt low, SymInt high, SymInt[] size, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor randint(c10::DispatchKeySet dispatchKeySet, int64_t low, int64_t high, at::IntArrayRef size, at::TensorOptions options=at::kLong) { + return at::_ops::randint_low::redispatch(dispatchKeySet, low, high, c10::fromIntArrayRefSlow(size), c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::randint.low(SymInt low, SymInt high, SymInt[] size, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor + inline at::Tensor randint(c10::DispatchKeySet dispatchKeySet, int64_t low, int64_t high, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::randint_low::redispatch(dispatchKeySet, low, high, c10::fromIntArrayRefSlow(size), dtype, layout, device, pin_memory); + } + + // aten::randint.low(SymInt low, SymInt high, SymInt[] size, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor randint_symint(c10::DispatchKeySet dispatchKeySet, c10::SymInt low, c10::SymInt high, c10::SymIntArrayRef size, at::TensorOptions options=at::kLong) { + return at::_ops::randint_low::redispatch(dispatchKeySet, low, high, size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::randint.low(SymInt low, SymInt high, SymInt[] size, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor randint_symint(c10::DispatchKeySet dispatchKeySet, c10::SymInt low, c10::SymInt high, c10::SymIntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::randint_low::redispatch(dispatchKeySet, low, high, size, dtype, layout, device, pin_memory); + } + + // aten::randint.low_generator(SymInt low, SymInt high, SymInt[] size, *, Generator? generator, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor + inline at::Tensor randint(c10::DispatchKeySet dispatchKeySet, int64_t low, int64_t high, at::IntArrayRef size, ::std::optional generator, at::TensorOptions options=at::kLong) { + return at::_ops::randint_low_generator::redispatch(dispatchKeySet, low, high, c10::fromIntArrayRefSlow(size), generator, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::randint.low_generator(SymInt low, SymInt high, SymInt[] size, *, Generator? generator, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor randint(c10::DispatchKeySet dispatchKeySet, int64_t low, int64_t high, at::IntArrayRef size, ::std::optional generator, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::randint_low_generator::redispatch(dispatchKeySet, low, high, c10::fromIntArrayRefSlow(size), generator, dtype, layout, device, pin_memory); + } + + // aten::randint.low_generator(SymInt low, SymInt high, SymInt[] size, *, Generator? generator, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor randint_symint(c10::DispatchKeySet dispatchKeySet, c10::SymInt low, c10::SymInt high, c10::SymIntArrayRef size, ::std::optional generator, at::TensorOptions options=at::kLong) { + return at::_ops::randint_low_generator::redispatch(dispatchKeySet, low, high, size, generator, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::randint.low_generator(SymInt low, SymInt high, SymInt[] size, *, Generator? generator, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor + inline at::Tensor randint_symint(c10::DispatchKeySet dispatchKeySet, c10::SymInt low, c10::SymInt high, c10::SymIntArrayRef size, ::std::optional generator, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::randint_low_generator::redispatch(dispatchKeySet, low, high, size, generator, dtype, layout, device, pin_memory); + } + + // aten::randint.out(SymInt high, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & randint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t high, at::IntArrayRef size) { + return at::_ops::randint_out::redispatch(dispatchKeySet, high, c10::fromIntArrayRefSlow(size), out); + } + + // aten::randint.out(SymInt high, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & randint_outf(c10::DispatchKeySet dispatchKeySet, int64_t high, at::IntArrayRef size, at::Tensor & out) { + return at::_ops::randint_out::redispatch(dispatchKeySet, high, c10::fromIntArrayRefSlow(size), out); + } + + // aten::randint.out(SymInt high, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & randint_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, c10::SymInt high, c10::SymIntArrayRef size) { + return at::_ops::randint_out::redispatch(dispatchKeySet, high, size, out); + } + + // aten::randint.out(SymInt high, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & randint_symint_outf(c10::DispatchKeySet dispatchKeySet, c10::SymInt high, c10::SymIntArrayRef size, at::Tensor & out) { + return at::_ops::randint_out::redispatch(dispatchKeySet, high, size, out); + } + + // aten::randint.generator_out(SymInt high, SymInt[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & randint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t high, at::IntArrayRef size, ::std::optional generator) { + return at::_ops::randint_generator_out::redispatch(dispatchKeySet, high, c10::fromIntArrayRefSlow(size), generator, out); + } + + // aten::randint.generator_out(SymInt high, SymInt[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & randint_outf(c10::DispatchKeySet dispatchKeySet, int64_t high, at::IntArrayRef size, ::std::optional generator, at::Tensor & out) { + return at::_ops::randint_generator_out::redispatch(dispatchKeySet, high, c10::fromIntArrayRefSlow(size), generator, out); + } + + // aten::randint.generator_out(SymInt high, SymInt[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & randint_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, c10::SymInt high, c10::SymIntArrayRef size, ::std::optional generator) { + return at::_ops::randint_generator_out::redispatch(dispatchKeySet, high, size, generator, out); + } + + // aten::randint.generator_out(SymInt high, SymInt[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & randint_symint_outf(c10::DispatchKeySet dispatchKeySet, c10::SymInt high, c10::SymIntArrayRef size, ::std::optional generator, at::Tensor & out) { + return at::_ops::randint_generator_out::redispatch(dispatchKeySet, high, size, generator, out); + } + + // aten::randint.low_out(SymInt low, SymInt high, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & randint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t low, int64_t high, at::IntArrayRef size) { + return at::_ops::randint_low_out::redispatch(dispatchKeySet, low, high, c10::fromIntArrayRefSlow(size), out); + } + + // aten::randint.low_out(SymInt low, SymInt high, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & randint_outf(c10::DispatchKeySet dispatchKeySet, int64_t low, int64_t high, at::IntArrayRef size, at::Tensor & out) { + return at::_ops::randint_low_out::redispatch(dispatchKeySet, low, high, c10::fromIntArrayRefSlow(size), out); + } + + // aten::randint.low_out(SymInt low, SymInt high, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & randint_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, c10::SymInt low, c10::SymInt high, c10::SymIntArrayRef size) { + return at::_ops::randint_low_out::redispatch(dispatchKeySet, low, high, size, out); + } + + // aten::randint.low_out(SymInt low, SymInt high, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & randint_symint_outf(c10::DispatchKeySet dispatchKeySet, c10::SymInt low, c10::SymInt high, c10::SymIntArrayRef size, at::Tensor & out) { + return at::_ops::randint_low_out::redispatch(dispatchKeySet, low, high, size, out); + } + + // aten::randint.low_generator_out(SymInt low, SymInt high, SymInt[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & randint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t low, int64_t high, at::IntArrayRef size, ::std::optional generator) { + return at::_ops::randint_low_generator_out::redispatch(dispatchKeySet, low, high, c10::fromIntArrayRefSlow(size), generator, out); + } + + // aten::randint.low_generator_out(SymInt low, SymInt high, SymInt[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & randint_outf(c10::DispatchKeySet dispatchKeySet, int64_t low, int64_t high, at::IntArrayRef size, ::std::optional generator, at::Tensor & out) { + return at::_ops::randint_low_generator_out::redispatch(dispatchKeySet, low, high, c10::fromIntArrayRefSlow(size), generator, out); + } + + // aten::randint.low_generator_out(SymInt low, SymInt high, SymInt[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & randint_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, c10::SymInt low, c10::SymInt high, c10::SymIntArrayRef size, ::std::optional generator) { + return at::_ops::randint_low_generator_out::redispatch(dispatchKeySet, low, high, size, generator, out); + } + + // aten::randint.low_generator_out(SymInt low, SymInt high, SymInt[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & randint_symint_outf(c10::DispatchKeySet dispatchKeySet, c10::SymInt low, c10::SymInt high, c10::SymIntArrayRef size, ::std::optional generator, at::Tensor & out) { + return at::_ops::randint_low_generator_out::redispatch(dispatchKeySet, low, high, size, generator, out); + } + + // aten::randint_like(Tensor self, SymInt high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + inline at::Tensor randint_like(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t high, at::TensorOptions options={}, ::std::optional memory_format=::std::nullopt) { + return at::_ops::randint_like::redispatch(dispatchKeySet, self, high, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format)); + } + + // aten::randint_like(Tensor self, SymInt high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? 
memory_format=None) -> Tensor + inline at::Tensor randint_like(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t high, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format) { + return at::_ops::randint_like::redispatch(dispatchKeySet, self, high, dtype, layout, device, pin_memory, memory_format); + } + + // aten::randint_like(Tensor self, SymInt high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + inline at::Tensor randint_like_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymInt high, at::TensorOptions options={}, ::std::optional memory_format=::std::nullopt) { + return at::_ops::randint_like::redispatch(dispatchKeySet, self, high, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format)); + } + + // aten::randint_like(Tensor self, SymInt high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + inline at::Tensor randint_like_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymInt high, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format) { + return at::_ops::randint_like::redispatch(dispatchKeySet, self, high, dtype, layout, device, pin_memory, memory_format); + } + + // aten::randint_like.generator(Tensor self, SymInt high, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? 
memory_format=None) -> Tensor + inline at::Tensor randint_like(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t high, ::std::optional generator, at::TensorOptions options={}, ::std::optional memory_format=::std::nullopt) { + return at::_ops::randint_like_generator::redispatch(dispatchKeySet, self, high, generator, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format)); + } + + // aten::randint_like.generator(Tensor self, SymInt high, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + inline at::Tensor randint_like(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t high, ::std::optional generator, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format) { + return at::_ops::randint_like_generator::redispatch(dispatchKeySet, self, high, generator, dtype, layout, device, pin_memory, memory_format); + } + + // aten::randint_like.generator(Tensor self, SymInt high, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? 
memory_format=None) -> Tensor + inline at::Tensor randint_like_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymInt high, ::std::optional generator, at::TensorOptions options={}, ::std::optional memory_format=::std::nullopt) { + return at::_ops::randint_like_generator::redispatch(dispatchKeySet, self, high, generator, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format)); + } + + // aten::randint_like.generator(Tensor self, SymInt high, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + inline at::Tensor randint_like_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymInt high, ::std::optional generator, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format) { + return at::_ops::randint_like_generator::redispatch(dispatchKeySet, self, high, generator, dtype, layout, device, pin_memory, memory_format); + } + + // aten::randint_like.Tensor(Tensor self, Tensor high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + inline at::Tensor randint_like(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & high, at::TensorOptions options={}, ::std::optional memory_format=::std::nullopt) { + return at::_ops::randint_like_Tensor::redispatch(dispatchKeySet, self, high, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format)); + } + + // aten::randint_like.Tensor(Tensor self, Tensor high, *, ScalarType? dtype=None, Layout? 
layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + inline at::Tensor randint_like(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & high, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format) { + return at::_ops::randint_like_Tensor::redispatch(dispatchKeySet, self, high, dtype, layout, device, pin_memory, memory_format); + } + + // aten::randint_like.Tensor_generator(Tensor self, Tensor high, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + inline at::Tensor randint_like(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & high, ::std::optional generator, at::TensorOptions options={}, ::std::optional memory_format=::std::nullopt) { + return at::_ops::randint_like_Tensor_generator::redispatch(dispatchKeySet, self, high, generator, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format)); + } + + // aten::randint_like.Tensor_generator(Tensor self, Tensor high, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + inline at::Tensor randint_like(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & high, ::std::optional generator, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format) { + return at::_ops::randint_like_Tensor_generator::redispatch(dispatchKeySet, self, high, generator, dtype, layout, device, pin_memory, memory_format); + } + + // aten::randint_like.low_dtype(Tensor self, SymInt low, SymInt high, *, ScalarType? 
dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + inline at::Tensor randint_like(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t low, int64_t high, at::TensorOptions options={}, ::std::optional memory_format=::std::nullopt) { + return at::_ops::randint_like_low_dtype::redispatch(dispatchKeySet, self, low, high, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format)); + } + + // aten::randint_like.low_dtype(Tensor self, SymInt low, SymInt high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + inline at::Tensor randint_like(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t low, int64_t high, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format) { + return at::_ops::randint_like_low_dtype::redispatch(dispatchKeySet, self, low, high, dtype, layout, device, pin_memory, memory_format); + } + + // aten::randint_like.low_dtype(Tensor self, SymInt low, SymInt high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? 
memory_format=None) -> Tensor + inline at::Tensor randint_like_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymInt low, c10::SymInt high, at::TensorOptions options={}, ::std::optional memory_format=::std::nullopt) { + return at::_ops::randint_like_low_dtype::redispatch(dispatchKeySet, self, low, high, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format)); + } + + // aten::randint_like.low_dtype(Tensor self, SymInt low, SymInt high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + inline at::Tensor randint_like_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymInt low, c10::SymInt high, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format) { + return at::_ops::randint_like_low_dtype::redispatch(dispatchKeySet, self, low, high, dtype, layout, device, pin_memory, memory_format); + } + + // aten::randint_like.low_generator_dtype(Tensor self, SymInt low, SymInt high, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? 
memory_format=None) -> Tensor + inline at::Tensor randint_like(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t low, int64_t high, ::std::optional generator, at::TensorOptions options={}, ::std::optional memory_format=::std::nullopt) { + return at::_ops::randint_like_low_generator_dtype::redispatch(dispatchKeySet, self, low, high, generator, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format)); + } + + // aten::randint_like.low_generator_dtype(Tensor self, SymInt low, SymInt high, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + inline at::Tensor randint_like(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t low, int64_t high, ::std::optional generator, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format) { + return at::_ops::randint_like_low_generator_dtype::redispatch(dispatchKeySet, self, low, high, generator, dtype, layout, device, pin_memory, memory_format); + } + + // aten::randint_like.low_generator_dtype(Tensor self, SymInt low, SymInt high, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? 
memory_format=None) -> Tensor + inline at::Tensor randint_like_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymInt low, c10::SymInt high, ::std::optional generator, at::TensorOptions options={}, ::std::optional memory_format=::std::nullopt) { + return at::_ops::randint_like_low_generator_dtype::redispatch(dispatchKeySet, self, low, high, generator, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format)); + } + + // aten::randint_like.low_generator_dtype(Tensor self, SymInt low, SymInt high, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + inline at::Tensor randint_like_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymInt low, c10::SymInt high, ::std::optional generator, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format) { + return at::_ops::randint_like_low_generator_dtype::redispatch(dispatchKeySet, self, low, high, generator, dtype, layout, device, pin_memory, memory_format); + } + + // aten::randn(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor randn(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, at::TensorOptions options={}) { + return at::_ops::randn::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::randn(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor + inline at::Tensor randn(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::randn::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), dtype, layout, device, pin_memory); + } + + // aten::randn(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor randn_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, at::TensorOptions options={}) { + return at::_ops::randn::redispatch(dispatchKeySet, size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::randn(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor randn_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::randn::redispatch(dispatchKeySet, size, dtype, layout, device, pin_memory); + } + + // aten::randn.generator(SymInt[] size, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor randn(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, ::std::optional generator, at::TensorOptions options={}) { + return at::_ops::randn_generator::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), generator, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::randn.generator(SymInt[] size, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor + inline at::Tensor randn(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, ::std::optional generator, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::randn_generator::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), generator, dtype, layout, device, pin_memory); + } + + // aten::randn.generator(SymInt[] size, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor randn_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, ::std::optional generator, at::TensorOptions options={}) { + return at::_ops::randn_generator::redispatch(dispatchKeySet, size, generator, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::randn.generator(SymInt[] size, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor randn_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, ::std::optional generator, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::randn_generator::redispatch(dispatchKeySet, size, generator, dtype, layout, device, pin_memory); + } + + // aten::randn.names(SymInt[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor + inline at::Tensor randn(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, ::std::optional names, at::TensorOptions options={}) { + return at::_ops::randn_names::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), names, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::randn.names(SymInt[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor randn(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, ::std::optional names, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::randn_names::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), names, dtype, layout, device, pin_memory); + } + + // aten::randn.names(SymInt[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor randn_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, ::std::optional names, at::TensorOptions options={}) { + return at::_ops::randn_names::redispatch(dispatchKeySet, size, names, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::randn.names(SymInt[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor randn_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, ::std::optional names, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::randn_names::redispatch(dispatchKeySet, size, names, dtype, layout, device, pin_memory); + } + + // aten::randn.generator_with_names(SymInt[] size, *, Generator? generator, Dimname[]? 
names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor randn(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, ::std::optional generator, ::std::optional names, at::TensorOptions options={}) { + return at::_ops::randn_generator_with_names::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), generator, names, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::randn.generator_with_names(SymInt[] size, *, Generator? generator, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor randn(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, ::std::optional generator, ::std::optional names, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::randn_generator_with_names::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), generator, names, dtype, layout, device, pin_memory); + } + + // aten::randn.generator_with_names(SymInt[] size, *, Generator? generator, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor randn_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, ::std::optional generator, ::std::optional names, at::TensorOptions options={}) { + return at::_ops::randn_generator_with_names::redispatch(dispatchKeySet, size, generator, names, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::randn.generator_with_names(SymInt[] size, *, Generator? generator, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor + inline at::Tensor randn_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, ::std::optional generator, ::std::optional names, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::randn_generator_with_names::redispatch(dispatchKeySet, size, generator, names, dtype, layout, device, pin_memory); + } + + // aten::randn.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & randn_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::IntArrayRef size) { + return at::_ops::randn_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), out); + } + + // aten::randn.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & randn_outf(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, at::Tensor & out) { + return at::_ops::randn_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), out); + } + + // aten::randn.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & randn_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, c10::SymIntArrayRef size) { + return at::_ops::randn_out::redispatch(dispatchKeySet, size, out); + } + + // aten::randn.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & randn_symint_outf(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, at::Tensor & out) { + return at::_ops::randn_out::redispatch(dispatchKeySet, size, out); + } + + // aten::randn.generator_out(SymInt[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & randn_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::IntArrayRef size, ::std::optional generator) { + return at::_ops::randn_generator_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), generator, out); + } + + // aten::randn.generator_out(SymInt[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & randn_outf(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, ::std::optional generator, at::Tensor & out) { + return at::_ops::randn_generator_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), generator, out); + } + + // aten::randn.generator_out(SymInt[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & randn_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, c10::SymIntArrayRef size, ::std::optional generator) { + return at::_ops::randn_generator_out::redispatch(dispatchKeySet, size, generator, out); + } + + // aten::randn.generator_out(SymInt[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & randn_symint_outf(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, ::std::optional generator, at::Tensor & out) { + return at::_ops::randn_generator_out::redispatch(dispatchKeySet, size, generator, out); + } + + // aten::randn_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + inline at::Tensor randn_like(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::TensorOptions options={}, ::std::optional memory_format=::std::nullopt) { + return at::_ops::randn_like::redispatch(dispatchKeySet, self, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format)); + } + + // aten::randn_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? 
memory_format=None) -> Tensor + inline at::Tensor randn_like(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format) { + return at::_ops::randn_like::redispatch(dispatchKeySet, self, dtype, layout, device, pin_memory, memory_format); + } + + // aten::randn_like.generator(Tensor self, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + inline at::Tensor randn_like(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional generator, at::TensorOptions options={}, ::std::optional memory_format=::std::nullopt) { + return at::_ops::randn_like_generator::redispatch(dispatchKeySet, self, generator, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format)); + } + + // aten::randn_like.generator(Tensor self, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + inline at::Tensor randn_like(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional generator, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format) { + return at::_ops::randn_like_generator::redispatch(dispatchKeySet, self, generator, dtype, layout, device, pin_memory, memory_format); + } + + // aten::randperm(SymInt n, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor + inline at::Tensor randperm(c10::DispatchKeySet dispatchKeySet, int64_t n, at::TensorOptions options=at::kLong) { + return at::_ops::randperm::redispatch(dispatchKeySet, n, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::randperm(SymInt n, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor randperm(c10::DispatchKeySet dispatchKeySet, int64_t n, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::randperm::redispatch(dispatchKeySet, n, dtype, layout, device, pin_memory); + } + + // aten::randperm(SymInt n, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor randperm_symint(c10::DispatchKeySet dispatchKeySet, c10::SymInt n, at::TensorOptions options=at::kLong) { + return at::_ops::randperm::redispatch(dispatchKeySet, n, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::randperm(SymInt n, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor randperm_symint(c10::DispatchKeySet dispatchKeySet, c10::SymInt n, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::randperm::redispatch(dispatchKeySet, n, dtype, layout, device, pin_memory); + } + + // aten::randperm.generator(SymInt n, *, Generator? generator, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor + inline at::Tensor randperm(c10::DispatchKeySet dispatchKeySet, int64_t n, ::std::optional generator, at::TensorOptions options=at::kLong) { + return at::_ops::randperm_generator::redispatch(dispatchKeySet, n, generator, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::randperm.generator(SymInt n, *, Generator? generator, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor randperm(c10::DispatchKeySet dispatchKeySet, int64_t n, ::std::optional generator, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::randperm_generator::redispatch(dispatchKeySet, n, generator, dtype, layout, device, pin_memory); + } + + // aten::randperm.generator(SymInt n, *, Generator? generator, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor randperm_symint(c10::DispatchKeySet dispatchKeySet, c10::SymInt n, ::std::optional generator, at::TensorOptions options=at::kLong) { + return at::_ops::randperm_generator::redispatch(dispatchKeySet, n, generator, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::randperm.generator(SymInt n, *, Generator? generator, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor randperm_symint(c10::DispatchKeySet dispatchKeySet, c10::SymInt n, ::std::optional generator, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::randperm_generator::redispatch(dispatchKeySet, n, generator, dtype, layout, device, pin_memory); + } + + // aten::randperm.out(SymInt n, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & randperm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t n) { + return at::_ops::randperm_out::redispatch(dispatchKeySet, n, out); + } + + // aten::randperm.out(SymInt n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & randperm_outf(c10::DispatchKeySet dispatchKeySet, int64_t n, at::Tensor & out) { + return at::_ops::randperm_out::redispatch(dispatchKeySet, n, out); + } + + // aten::randperm.out(SymInt n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & randperm_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, c10::SymInt n) { + return at::_ops::randperm_out::redispatch(dispatchKeySet, n, out); + } + + // aten::randperm.out(SymInt n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & randperm_symint_outf(c10::DispatchKeySet dispatchKeySet, c10::SymInt n, at::Tensor & out) { + return at::_ops::randperm_out::redispatch(dispatchKeySet, n, out); + } + + // aten::randperm.generator_out(SymInt n, *, Generator? generator, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & randperm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t n, ::std::optional generator) { + return at::_ops::randperm_generator_out::redispatch(dispatchKeySet, n, generator, out); + } + + // aten::randperm.generator_out(SymInt n, *, Generator? generator, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & randperm_outf(c10::DispatchKeySet dispatchKeySet, int64_t n, ::std::optional generator, at::Tensor & out) { + return at::_ops::randperm_generator_out::redispatch(dispatchKeySet, n, generator, out); + } + + // aten::randperm.generator_out(SymInt n, *, Generator? generator, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & randperm_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, c10::SymInt n, ::std::optional generator) { + return at::_ops::randperm_generator_out::redispatch(dispatchKeySet, n, generator, out); + } + + // aten::randperm.generator_out(SymInt n, *, Generator? generator, Tensor(a!) 
out) -> Tensor(a!) + inline at::Tensor & randperm_symint_outf(c10::DispatchKeySet dispatchKeySet, c10::SymInt n, ::std::optional generator, at::Tensor & out) { + return at::_ops::randperm_generator_out::redispatch(dispatchKeySet, n, generator, out); + } + + // aten::range.step(Scalar start, Scalar end, Scalar step=1, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor range(c10::DispatchKeySet dispatchKeySet, const at::Scalar & start, const at::Scalar & end, const at::Scalar & step=1, at::TensorOptions options={}) { + return at::_ops::range_step::redispatch(dispatchKeySet, start, end, step, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::range.step(Scalar start, Scalar end, Scalar step=1, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor range(c10::DispatchKeySet dispatchKeySet, const at::Scalar & start, const at::Scalar & end, const at::Scalar & step, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::range_step::redispatch(dispatchKeySet, start, end, step, dtype, layout, device, pin_memory); + } + + // aten::range(Scalar start, Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor range(c10::DispatchKeySet dispatchKeySet, const at::Scalar & start, const at::Scalar & end, at::TensorOptions options={}) { + return at::_ops::range::redispatch(dispatchKeySet, start, end, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::range(Scalar start, Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor + inline at::Tensor range(c10::DispatchKeySet dispatchKeySet, const at::Scalar & start, const at::Scalar & end, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::range::redispatch(dispatchKeySet, start, end, dtype, layout, device, pin_memory); + } + + // aten::range.out_(Scalar start, Scalar end, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & range_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & start, const at::Scalar & end) { + return at::_ops::range_out_::redispatch(dispatchKeySet, start, end, out); + } + + // aten::range.out_(Scalar start, Scalar end, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & range_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & start, const at::Scalar & end, at::Tensor & out) { + return at::_ops::range_out_::redispatch(dispatchKeySet, start, end, out); + } + + // aten::range.out(Scalar start, Scalar end, Scalar step=1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & range_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & start, const at::Scalar & end, const at::Scalar & step) { + return at::_ops::range_out::redispatch(dispatchKeySet, start, end, step, out); + } + + // aten::range.out(Scalar start, Scalar end, Scalar step=1, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & range_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & start, const at::Scalar & end, const at::Scalar & step, at::Tensor & out) { + return at::_ops::range_out::redispatch(dispatchKeySet, start, end, step, out); + } + + // aten::ravel(Tensor(a) self) -> Tensor(a) + inline at::Tensor ravel(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::ravel::redispatch(dispatchKeySet, self); + } + + // aten::reciprocal(Tensor self) -> Tensor + inline at::Tensor reciprocal(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::reciprocal::redispatch(dispatchKeySet, self); + } + + // aten::reciprocal_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & reciprocal_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::reciprocal_::redispatch(dispatchKeySet, self); + } + + // aten::reciprocal.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & reciprocal_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::reciprocal_out::redispatch(dispatchKeySet, self, out); + } + + // aten::reciprocal.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & reciprocal_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::reciprocal_out::redispatch(dispatchKeySet, self, out); + } + + // aten::neg(Tensor self) -> Tensor + inline at::Tensor neg(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::neg::redispatch(dispatchKeySet, self); + } + + // aten::neg_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & neg_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::neg_::redispatch(dispatchKeySet, self); + } + + // aten::neg.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & neg_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::neg_out::redispatch(dispatchKeySet, self, out); + } + + // aten::neg.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & neg_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::neg_out::redispatch(dispatchKeySet, self, out); + } + + // aten::negative(Tensor self) -> Tensor + inline at::Tensor negative(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::negative::redispatch(dispatchKeySet, self); + } + + // aten::negative_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & negative_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::negative_::redispatch(dispatchKeySet, self); + } + + // aten::negative.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & negative_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::negative_out::redispatch(dispatchKeySet, self, out); + } + + // aten::negative.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & negative_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::negative_out::redispatch(dispatchKeySet, self, out); + } + + // aten::repeat(Tensor self, SymInt[] repeats) -> Tensor + inline at::Tensor repeat(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef repeats) { + return at::_ops::repeat::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(repeats)); + } + + // aten::repeat(Tensor self, SymInt[] repeats) -> Tensor + inline at::Tensor repeat_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef repeats) { + return at::_ops::repeat::redispatch(dispatchKeySet, self, repeats); + } + + // aten::repeat_interleave.Tensor(Tensor repeats, *, SymInt? 
output_size=None) -> Tensor + inline at::Tensor repeat_interleave(c10::DispatchKeySet dispatchKeySet, const at::Tensor & repeats, ::std::optional output_size=::std::nullopt) { + return at::_ops::repeat_interleave_Tensor::redispatch(dispatchKeySet, repeats, output_size.has_value() ? ::std::make_optional(c10::SymInt(*output_size)) : ::std::nullopt); + } + + // aten::repeat_interleave.Tensor(Tensor repeats, *, SymInt? output_size=None) -> Tensor + inline at::Tensor repeat_interleave_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & repeats, ::std::optional output_size=::std::nullopt) { + return at::_ops::repeat_interleave_Tensor::redispatch(dispatchKeySet, repeats, output_size); + } + + // aten::repeat_interleave.self_Tensor(Tensor self, Tensor repeats, int? dim=None, *, SymInt? output_size=None) -> Tensor + inline at::Tensor repeat_interleave(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & repeats, ::std::optional dim=::std::nullopt, ::std::optional output_size=::std::nullopt) { + return at::_ops::repeat_interleave_self_Tensor::redispatch(dispatchKeySet, self, repeats, dim, output_size.has_value() ? ::std::make_optional(c10::SymInt(*output_size)) : ::std::nullopt); + } + + // aten::repeat_interleave.self_Tensor(Tensor self, Tensor repeats, int? dim=None, *, SymInt? output_size=None) -> Tensor + inline at::Tensor repeat_interleave_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & repeats, ::std::optional dim=::std::nullopt, ::std::optional output_size=::std::nullopt) { + return at::_ops::repeat_interleave_self_Tensor::redispatch(dispatchKeySet, self, repeats, dim, output_size); + } + + // aten::repeat_interleave.self_int(Tensor self, SymInt repeats, int? dim=None, *, SymInt? 
output_size=None) -> Tensor + inline at::Tensor repeat_interleave(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t repeats, ::std::optional dim=::std::nullopt, ::std::optional output_size=::std::nullopt) { + return at::_ops::repeat_interleave_self_int::redispatch(dispatchKeySet, self, repeats, dim, output_size.has_value() ? ::std::make_optional(c10::SymInt(*output_size)) : ::std::nullopt); + } + + // aten::repeat_interleave.self_int(Tensor self, SymInt repeats, int? dim=None, *, SymInt? output_size=None) -> Tensor + inline at::Tensor repeat_interleave_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymInt repeats, ::std::optional dim=::std::nullopt, ::std::optional output_size=::std::nullopt) { + return at::_ops::repeat_interleave_self_int::redispatch(dispatchKeySet, self, repeats, dim, output_size); + } + + // aten::reshape(Tensor(a) self, SymInt[] shape) -> Tensor(a) + inline at::Tensor reshape(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef shape) { + return at::_ops::reshape::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(shape)); + } + + // aten::reshape(Tensor(a) self, SymInt[] shape) -> Tensor(a) + inline at::Tensor reshape_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef shape) { + return at::_ops::reshape::redispatch(dispatchKeySet, self, shape); + } + + // aten::_reshape_copy(Tensor self, SymInt[] size) -> Tensor + inline at::Tensor _reshape_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size) { + return at::_ops::_reshape_copy::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size)); + } + + // aten::_reshape_copy(Tensor self, SymInt[] size) -> Tensor + inline at::Tensor _reshape_copy_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size) { + return at::_ops::_reshape_copy::redispatch(dispatchKeySet, self, size); + } + + // 
aten::_reshape_alias(Tensor(a) self, SymInt[] size, SymInt[] stride) -> Tensor(a) + inline at::Tensor _reshape_alias(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, at::IntArrayRef stride) { + return at::_ops::_reshape_alias::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride)); + } + + // aten::_reshape_alias(Tensor(a) self, SymInt[] size, SymInt[] stride) -> Tensor(a) + inline at::Tensor _reshape_alias_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride) { + return at::_ops::_reshape_alias::redispatch(dispatchKeySet, self, size, stride); + } + + // aten::_mkldnn_reshape(Tensor self, int[] shape) -> Tensor + inline at::Tensor _mkldnn_reshape(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef shape) { + return at::_ops::_mkldnn_reshape::redispatch(dispatchKeySet, self, shape); + } + + // aten::reshape_as(Tensor(a) self, Tensor other) -> Tensor(a) + inline at::Tensor reshape_as(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::reshape_as::redispatch(dispatchKeySet, self, other); + } + + // aten::round(Tensor self) -> Tensor + inline at::Tensor round(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::round::redispatch(dispatchKeySet, self); + } + + // aten::round_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & round_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::round_::redispatch(dispatchKeySet, self); + } + + // aten::round.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & round_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::round_out::redispatch(dispatchKeySet, self, out); + } + + // aten::round.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & round_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::round_out::redispatch(dispatchKeySet, self, out); + } + + // aten::round.decimals(Tensor self, *, int decimals) -> Tensor + inline at::Tensor round(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t decimals) { + return at::_ops::round_decimals::redispatch(dispatchKeySet, self, decimals); + } + + // aten::round_.decimals(Tensor(a!) self, *, int decimals) -> Tensor(a!) + inline at::Tensor & round_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, int64_t decimals) { + return at::_ops::round__decimals::redispatch(dispatchKeySet, self, decimals); + } + + // aten::round.decimals_out(Tensor self, *, int decimals, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & round_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t decimals) { + return at::_ops::round_decimals_out::redispatch(dispatchKeySet, self, decimals, out); + } + + // aten::round.decimals_out(Tensor self, *, int decimals, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & round_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t decimals, at::Tensor & out) { + return at::_ops::round_decimals_out::redispatch(dispatchKeySet, self, decimals, out); + } + + // aten::rrelu(Tensor self, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor + inline at::Tensor rrelu(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & lower=0.125, const at::Scalar & upper=0.3333333333333333, bool training=false, ::std::optional generator=::std::nullopt) { + return at::_ops::rrelu::redispatch(dispatchKeySet, self, lower, upper, training, generator); + } + + // aten::rrelu_(Tensor(a!) self, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor(a!) 
+ inline at::Tensor & rrelu_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & lower=0.125, const at::Scalar & upper=0.3333333333333333, bool training=false, ::std::optional generator=::std::nullopt) { + return at::_ops::rrelu_::redispatch(dispatchKeySet, self, lower, upper, training, generator); + } + + // aten::relu(Tensor self) -> Tensor + inline at::Tensor relu(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::relu::redispatch(dispatchKeySet, self); + } + + // aten::relu_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & relu_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::relu_::redispatch(dispatchKeySet, self); + } + + // aten::relu6(Tensor self) -> Tensor + inline at::Tensor relu6(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::relu6::redispatch(dispatchKeySet, self); + } + + // aten::relu6_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & relu6_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::relu6_::redispatch(dispatchKeySet, self); + } + + // aten::prelu(Tensor self, Tensor weight) -> Tensor + inline at::Tensor prelu(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight) { + return at::_ops::prelu::redispatch(dispatchKeySet, self, weight); + } + + // aten::_prelu_kernel(Tensor self, Tensor weight) -> Tensor + inline at::Tensor _prelu_kernel(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight) { + return at::_ops::_prelu_kernel::redispatch(dispatchKeySet, self, weight); + } + + // aten::_prelu_kernel_backward(Tensor grad_output, Tensor self, Tensor weight) -> (Tensor, Tensor) + inline ::std::tuple _prelu_kernel_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & weight) { + return at::_ops::_prelu_kernel_backward::redispatch(dispatchKeySet, grad_output, self, weight); + 
} + + // aten::gelu.out(Tensor self, *, str approximate='none', Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & gelu_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::string_view approximate="none") { + return at::_ops::gelu_out::redispatch(dispatchKeySet, self, approximate, out); + } + + // aten::gelu.out(Tensor self, *, str approximate='none', Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & gelu_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::string_view approximate, at::Tensor & out) { + return at::_ops::gelu_out::redispatch(dispatchKeySet, self, approximate, out); + } + + // aten::gelu_(Tensor(a!) self, *, str approximate='none') -> Tensor(a!) + inline at::Tensor & gelu_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, c10::string_view approximate="none") { + return at::_ops::gelu_::redispatch(dispatchKeySet, self, approximate); + } + + // aten::gelu(Tensor self, *, str approximate='none') -> Tensor + inline at::Tensor gelu(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::string_view approximate="none") { + return at::_ops::gelu::redispatch(dispatchKeySet, self, approximate); + } + + // aten::gelu_backward.grad_input(Tensor grad_output, Tensor self, *, str approximate='none', Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & gelu_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, c10::string_view approximate="none") { + return at::_ops::gelu_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, approximate, grad_input); + } + + // aten::gelu_backward.grad_input(Tensor grad_output, Tensor self, *, str approximate='none', Tensor(a!) grad_input) -> Tensor(a!) 
+ inline at::Tensor & gelu_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, c10::string_view approximate, at::Tensor & grad_input) { + return at::_ops::gelu_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, approximate, grad_input); + } + + // aten::gelu_backward(Tensor grad_output, Tensor self, *, str approximate='none') -> Tensor + inline at::Tensor gelu_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, c10::string_view approximate="none") { + return at::_ops::gelu_backward::redispatch(dispatchKeySet, grad_output, self, approximate); + } + + // aten::infinitely_differentiable_gelu_backward(Tensor grad, Tensor self) -> Tensor + inline at::Tensor infinitely_differentiable_gelu_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & self) { + return at::_ops::infinitely_differentiable_gelu_backward::redispatch(dispatchKeySet, grad, self); + } + + // aten::hardshrink.out(Tensor self, Scalar lambd=0.5, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & hardshrink_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & lambd=0.5) { + return at::_ops::hardshrink_out::redispatch(dispatchKeySet, self, lambd, out); + } + + // aten::hardshrink.out(Tensor self, Scalar lambd=0.5, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & hardshrink_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & lambd, at::Tensor & out) { + return at::_ops::hardshrink_out::redispatch(dispatchKeySet, self, lambd, out); + } + + // aten::hardshrink(Tensor self, Scalar lambd=0.5) -> Tensor + inline at::Tensor hardshrink(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & lambd=0.5) { + return at::_ops::hardshrink::redispatch(dispatchKeySet, self, lambd); + } + + // aten::hardshrink_backward.grad_input(Tensor grad_out, Tensor self, Scalar lambd, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & hardshrink_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_out, const at::Tensor & self, const at::Scalar & lambd) { + return at::_ops::hardshrink_backward_grad_input::redispatch(dispatchKeySet, grad_out, self, lambd, grad_input); + } + + // aten::hardshrink_backward.grad_input(Tensor grad_out, Tensor self, Scalar lambd, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & hardshrink_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out, const at::Tensor & self, const at::Scalar & lambd, at::Tensor & grad_input) { + return at::_ops::hardshrink_backward_grad_input::redispatch(dispatchKeySet, grad_out, self, lambd, grad_input); + } + + // aten::hardshrink_backward(Tensor grad_out, Tensor self, Scalar lambd) -> Tensor + inline at::Tensor hardshrink_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out, const at::Tensor & self, const at::Scalar & lambd) { + return at::_ops::hardshrink_backward::redispatch(dispatchKeySet, grad_out, self, lambd); + } + + // aten::rsqrt(Tensor self) -> Tensor + inline at::Tensor rsqrt(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::rsqrt::redispatch(dispatchKeySet, self); + } + + // aten::rsqrt_(Tensor(a!) self) -> Tensor(a!) 
+ inline at::Tensor & rsqrt_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::rsqrt_::redispatch(dispatchKeySet, self); + } + + // aten::rsqrt.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & rsqrt_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::rsqrt_out::redispatch(dispatchKeySet, self, out); + } + + // aten::rsqrt.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & rsqrt_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::rsqrt_out::redispatch(dispatchKeySet, self, out); + } + + // aten::select.Dimname(Tensor(a) self, Dimname dim, int index) -> Tensor(a) + inline at::Tensor select(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, int64_t index) { + return at::_ops::select_Dimname::redispatch(dispatchKeySet, self, dim, index); + } + + // aten::select.int(Tensor(a) self, int dim, SymInt index) -> Tensor(a) + inline at::Tensor select(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, int64_t index) { + return at::_ops::select_int::redispatch(dispatchKeySet, self, dim, index); + } + + // aten::select.int(Tensor(a) self, int dim, SymInt index) -> Tensor(a) + inline at::Tensor select_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, c10::SymInt index) { + return at::_ops::select_int::redispatch(dispatchKeySet, self, dim, index); + } + + // aten::select_backward(Tensor grad_output, SymInt[] input_sizes, int dim, SymInt index) -> Tensor + inline at::Tensor select_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef input_sizes, int64_t dim, int64_t index) { + return at::_ops::select_backward::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(input_sizes), dim, index); + } + + // aten::select_backward(Tensor grad_output, SymInt[] input_sizes, int dim, SymInt 
index) -> Tensor + inline at::Tensor select_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef input_sizes, int64_t dim, c10::SymInt index) { + return at::_ops::select_backward::redispatch(dispatchKeySet, grad_output, input_sizes, dim, index); + } + + // aten::_nested_select_backward(Tensor grad_output, Tensor self, int dim, SymInt index) -> Tensor + inline at::Tensor _nested_select_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, int64_t dim, int64_t index) { + return at::_ops::_nested_select_backward::redispatch(dispatchKeySet, grad_output, self, dim, index); + } + + // aten::_nested_select_backward(Tensor grad_output, Tensor self, int dim, SymInt index) -> Tensor + inline at::Tensor _nested_select_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, int64_t dim, c10::SymInt index) { + return at::_ops::_nested_select_backward::redispatch(dispatchKeySet, grad_output, self, dim, index); + } + + // aten::selu(Tensor self) -> Tensor + inline at::Tensor selu(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::selu::redispatch(dispatchKeySet, self); + } + + // aten::selu_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & selu_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::selu_::redispatch(dispatchKeySet, self); + } + + // aten::celu(Tensor self, Scalar alpha=1.0) -> Tensor + inline at::Tensor celu(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & alpha=1.0) { + return at::_ops::celu::redispatch(dispatchKeySet, self, alpha); + } + + // aten::celu_(Tensor(a!) self, Scalar alpha=1.0) -> Tensor(a!) 
+ inline at::Tensor & celu_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & alpha=1.0) { + return at::_ops::celu_::redispatch(dispatchKeySet, self, alpha); + } + + // aten::silu(Tensor self) -> Tensor + inline at::Tensor silu(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::silu::redispatch(dispatchKeySet, self); + } + + // aten::silu_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & silu_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::silu_::redispatch(dispatchKeySet, self); + } + + // aten::silu.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & silu_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::silu_out::redispatch(dispatchKeySet, self, out); + } + + // aten::silu.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & silu_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::silu_out::redispatch(dispatchKeySet, self, out); + } + + // aten::silu_backward.grad_input(Tensor grad_output, Tensor self, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & silu_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self) { + return at::_ops::silu_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, grad_input); + } + + // aten::silu_backward.grad_input(Tensor grad_output, Tensor self, *, Tensor(a!) grad_input) -> Tensor(a!) 
+ inline at::Tensor & silu_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::Tensor & grad_input) { + return at::_ops::silu_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, grad_input); + } + + // aten::silu_backward(Tensor grad_output, Tensor self) -> Tensor + inline at::Tensor silu_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self) { + return at::_ops::silu_backward::redispatch(dispatchKeySet, grad_output, self); + } + + // aten::mish(Tensor self) -> Tensor + inline at::Tensor mish(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::mish::redispatch(dispatchKeySet, self); + } + + // aten::mish_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & mish_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::mish_::redispatch(dispatchKeySet, self); + } + + // aten::mish.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & mish_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::mish_out::redispatch(dispatchKeySet, self, out); + } + + // aten::mish.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & mish_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::mish_out::redispatch(dispatchKeySet, self, out); + } + + // aten::mish_backward(Tensor grad_output, Tensor self) -> Tensor + inline at::Tensor mish_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self) { + return at::_ops::mish_backward::redispatch(dispatchKeySet, grad_output, self); + } + + // aten::sigmoid(Tensor self) -> Tensor + inline at::Tensor sigmoid(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::sigmoid::redispatch(dispatchKeySet, self); + } + + // aten::sigmoid_(Tensor(a!) self) -> Tensor(a!) 
+ inline at::Tensor & sigmoid_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::sigmoid_::redispatch(dispatchKeySet, self); + } + + // aten::sigmoid.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & sigmoid_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::sigmoid_out::redispatch(dispatchKeySet, self, out); + } + + // aten::sigmoid.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & sigmoid_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::sigmoid_out::redispatch(dispatchKeySet, self, out); + } + + // aten::logit(Tensor self, float? eps=None) -> Tensor + inline at::Tensor logit(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional eps=::std::nullopt) { + return at::_ops::logit::redispatch(dispatchKeySet, self, eps); + } + + // aten::logit_(Tensor(a!) self, float? eps=None) -> Tensor(a!) + inline at::Tensor & logit_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, ::std::optional eps=::std::nullopt) { + return at::_ops::logit_::redispatch(dispatchKeySet, self, eps); + } + + // aten::logit.out(Tensor self, float? eps=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & logit_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, ::std::optional eps=::std::nullopt) { + return at::_ops::logit_out::redispatch(dispatchKeySet, self, eps, out); + } + + // aten::logit.out(Tensor self, float? eps=None, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & logit_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional eps, at::Tensor & out) { + return at::_ops::logit_out::redispatch(dispatchKeySet, self, eps, out); + } + + // aten::sin(Tensor self) -> Tensor + inline at::Tensor sin(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::sin::redispatch(dispatchKeySet, self); + } + + // aten::sin_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & sin_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::sin_::redispatch(dispatchKeySet, self); + } + + // aten::sin.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & sin_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::sin_out::redispatch(dispatchKeySet, self, out); + } + + // aten::sin.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & sin_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::sin_out::redispatch(dispatchKeySet, self, out); + } + + // aten::sinc(Tensor self) -> Tensor + inline at::Tensor sinc(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::sinc::redispatch(dispatchKeySet, self); + } + + // aten::sinc_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & sinc_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::sinc_::redispatch(dispatchKeySet, self); + } + + // aten::sinc.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & sinc_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::sinc_out::redispatch(dispatchKeySet, self, out); + } + + // aten::sinc.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & sinc_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::sinc_out::redispatch(dispatchKeySet, self, out); + } + + // aten::sinh(Tensor self) -> Tensor + inline at::Tensor sinh(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::sinh::redispatch(dispatchKeySet, self); + } + + // aten::sinh_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & sinh_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::sinh_::redispatch(dispatchKeySet, self); + } + + // aten::sinh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & sinh_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::sinh_out::redispatch(dispatchKeySet, self, out); + } + + // aten::sinh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & sinh_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::sinh_out::redispatch(dispatchKeySet, self, out); + } + + // aten::detach(Tensor(a) self) -> Tensor(a) + inline at::Tensor detach(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::detach::redispatch(dispatchKeySet, self); + } + + // aten::detach_(Tensor(a!) self) -> Tensor(a!) 
+ inline at::Tensor & detach_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::detach_::redispatch(dispatchKeySet, self); + } + + // aten::size.int(Tensor self, int dim) -> int + inline int64_t __dispatch_size(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim) { + return at::_ops::size_int::redispatch(dispatchKeySet, self, dim); + } + + // aten::size.Dimname(Tensor self, Dimname dim) -> int + inline int64_t size(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim) { + return at::_ops::size_Dimname::redispatch(dispatchKeySet, self, dim); + } + + // aten::sym_size.int(Tensor self, int dim) -> SymInt + inline c10::SymInt __dispatch_sym_size(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim) { + return at::_ops::sym_size_int::redispatch(dispatchKeySet, self, dim); + } + + // aten::sym_is_contiguous(Tensor self, MemoryFormat memory_format=contiguous_format) -> SymBool + inline c10::SymBool __dispatch_sym_is_contiguous(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::MemoryFormat memory_format=c10::MemoryFormat::Contiguous) { + return at::_ops::sym_is_contiguous::redispatch(dispatchKeySet, self, memory_format); + } + + // aten::sym_numel(Tensor self) -> SymInt + inline c10::SymInt __dispatch_sym_numel(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::sym_numel::redispatch(dispatchKeySet, self); + } + + // aten::sym_storage_offset(Tensor self) -> SymInt + inline c10::SymInt __dispatch_sym_storage_offset(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::sym_storage_offset::redispatch(dispatchKeySet, self); + } + + // aten::slice.Tensor(Tensor(a) self, int dim=0, SymInt? start=None, SymInt? 
end=None, SymInt step=1) -> Tensor(a) + inline at::Tensor slice(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim=0, ::std::optional start=::std::nullopt, ::std::optional end=::std::nullopt, int64_t step=1) { + return at::_ops::slice_Tensor::redispatch(dispatchKeySet, self, dim, start.has_value() ? ::std::make_optional(c10::SymInt(*start)) : ::std::nullopt, end.has_value() ? ::std::make_optional(c10::SymInt(*end)) : ::std::nullopt, step); + } + + // aten::slice.Tensor(Tensor(a) self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor(a) + inline at::Tensor slice_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim=0, ::std::optional start=::std::nullopt, ::std::optional end=::std::nullopt, c10::SymInt step=1) { + return at::_ops::slice_Tensor::redispatch(dispatchKeySet, self, dim, start, end, step); + } + + // aten::slice_backward(Tensor grad_output, SymInt[] input_sizes, int dim, SymInt start, SymInt end, SymInt step) -> Tensor + inline at::Tensor slice_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef input_sizes, int64_t dim, int64_t start, int64_t end, int64_t step) { + return at::_ops::slice_backward::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(input_sizes), dim, start, end, step); + } + + // aten::slice_backward(Tensor grad_output, SymInt[] input_sizes, int dim, SymInt start, SymInt end, SymInt step) -> Tensor + inline at::Tensor slice_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef input_sizes, int64_t dim, c10::SymInt start, c10::SymInt end, c10::SymInt step) { + return at::_ops::slice_backward::redispatch(dispatchKeySet, grad_output, input_sizes, dim, start, end, step); + } + + // aten::slice_inverse(Tensor(a) self, Tensor src, int dim=0, SymInt? start=None, SymInt? 
end=None, SymInt step=1) -> Tensor(a) + inline at::Tensor slice_inverse(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & src, int64_t dim=0, ::std::optional start=::std::nullopt, ::std::optional end=::std::nullopt, int64_t step=1) { + return at::_ops::slice_inverse::redispatch(dispatchKeySet, self, src, dim, start.has_value() ? ::std::make_optional(c10::SymInt(*start)) : ::std::nullopt, end.has_value() ? ::std::make_optional(c10::SymInt(*end)) : ::std::nullopt, step); + } + + // aten::slice_inverse(Tensor(a) self, Tensor src, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor(a) + inline at::Tensor slice_inverse_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & src, int64_t dim=0, ::std::optional start=::std::nullopt, ::std::optional end=::std::nullopt, c10::SymInt step=1) { + return at::_ops::slice_inverse::redispatch(dispatchKeySet, self, src, dim, start, end, step); + } + + // aten::slice_scatter(Tensor self, Tensor src, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor + inline at::Tensor slice_scatter(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & src, int64_t dim=0, ::std::optional start=::std::nullopt, ::std::optional end=::std::nullopt, int64_t step=1) { + return at::_ops::slice_scatter::redispatch(dispatchKeySet, self, src, dim, start.has_value() ? ::std::make_optional(c10::SymInt(*start)) : ::std::nullopt, end.has_value() ? ::std::make_optional(c10::SymInt(*end)) : ::std::nullopt, step); + } + + // aten::slice_scatter(Tensor self, Tensor src, int dim=0, SymInt? start=None, SymInt? 
end=None, SymInt step=1) -> Tensor + inline at::Tensor slice_scatter_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & src, int64_t dim=0, ::std::optional start=::std::nullopt, ::std::optional end=::std::nullopt, c10::SymInt step=1) { + return at::_ops::slice_scatter::redispatch(dispatchKeySet, self, src, dim, start, end, step); + } + + // aten::select_scatter(Tensor self, Tensor src, int dim, SymInt index) -> Tensor + inline at::Tensor select_scatter(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & src, int64_t dim, int64_t index) { + return at::_ops::select_scatter::redispatch(dispatchKeySet, self, src, dim, index); + } + + // aten::select_scatter(Tensor self, Tensor src, int dim, SymInt index) -> Tensor + inline at::Tensor select_scatter_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & src, int64_t dim, c10::SymInt index) { + return at::_ops::select_scatter::redispatch(dispatchKeySet, self, src, dim, index); + } + + // aten::diagonal_scatter(Tensor self, Tensor src, int offset=0, int dim1=0, int dim2=1) -> Tensor + inline at::Tensor diagonal_scatter(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & src, int64_t offset=0, int64_t dim1=0, int64_t dim2=1) { + return at::_ops::diagonal_scatter::redispatch(dispatchKeySet, self, src, offset, dim1, dim2); + } + + // aten::as_strided_scatter(Tensor self, Tensor src, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor + inline at::Tensor as_strided_scatter(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & src, at::IntArrayRef size, at::IntArrayRef stride, ::std::optional storage_offset=::std::nullopt) { + return at::_ops::as_strided_scatter::redispatch(dispatchKeySet, self, src, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), storage_offset.has_value() ? 
::std::make_optional(c10::SymInt(*storage_offset)) : ::std::nullopt); + } + + // aten::as_strided_scatter(Tensor self, Tensor src, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor + inline at::Tensor as_strided_scatter_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & src, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset=::std::nullopt) { + return at::_ops::as_strided_scatter::redispatch(dispatchKeySet, self, src, size, stride, storage_offset); + } + + // aten::smm(Tensor self, Tensor mat2) -> Tensor + inline at::Tensor smm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat2) { + return at::_ops::smm::redispatch(dispatchKeySet, self, mat2); + } + + // aten::softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor + inline at::Tensor softmax(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, ::std::optional dtype=::std::nullopt) { + return at::_ops::softmax_int::redispatch(dispatchKeySet, self, dim, dtype); + } + + // aten::softmax.int_out(Tensor self, int dim, ScalarType? dtype=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & softmax_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim, ::std::optional dtype=::std::nullopt) { + return at::_ops::softmax_int_out::redispatch(dispatchKeySet, self, dim, dtype, out); + } + + // aten::softmax.int_out(Tensor self, int dim, ScalarType? dtype=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & softmax_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, ::std::optional dtype, at::Tensor & out) { + return at::_ops::softmax_int_out::redispatch(dispatchKeySet, self, dim, dtype, out); + } + + // aten::softmax.Dimname(Tensor self, Dimname dim, *, ScalarType? 
dtype=None) -> Tensor + inline at::Tensor softmax(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, ::std::optional dtype=::std::nullopt) { + return at::_ops::softmax_Dimname::redispatch(dispatchKeySet, self, dim, dtype); + } + + // aten::_softmax(Tensor self, int dim, bool half_to_float) -> Tensor + inline at::Tensor _softmax(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool half_to_float) { + return at::_ops::_softmax::redispatch(dispatchKeySet, self, dim, half_to_float); + } + + // aten::_softmax.out(Tensor self, int dim, bool half_to_float, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _softmax_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim, bool half_to_float) { + return at::_ops::_softmax_out::redispatch(dispatchKeySet, self, dim, half_to_float, out); + } + + // aten::_softmax.out(Tensor self, int dim, bool half_to_float, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _softmax_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool half_to_float, at::Tensor & out) { + return at::_ops::_softmax_out::redispatch(dispatchKeySet, self, dim, half_to_float, out); + } + + // aten::_softmax_backward_data(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype) -> Tensor + inline at::Tensor _softmax_backward_data(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & output, int64_t dim, at::ScalarType input_dtype) { + return at::_ops::_softmax_backward_data::redispatch(dispatchKeySet, grad_output, output, dim, input_dtype); + } + + // aten::_softmax_backward_data.out(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype, *, Tensor(a!) grad_input) -> Tensor(a!) 
+ inline at::Tensor & _softmax_backward_data_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & output, int64_t dim, at::ScalarType input_dtype) { + return at::_ops::_softmax_backward_data_out::redispatch(dispatchKeySet, grad_output, output, dim, input_dtype, grad_input); + } + + // aten::_softmax_backward_data.out(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & _softmax_backward_data_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & output, int64_t dim, at::ScalarType input_dtype, at::Tensor & grad_input) { + return at::_ops::_softmax_backward_data_out::redispatch(dispatchKeySet, grad_output, output, dim, input_dtype, grad_input); + } + + // aten::unsafe_split.Tensor(Tensor self, SymInt split_size, int dim=0) -> Tensor[] + inline ::std::vector unsafe_split(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t split_size, int64_t dim=0) { + return at::_ops::unsafe_split_Tensor::redispatch(dispatchKeySet, self, split_size, dim); + } + + // aten::unsafe_split.Tensor(Tensor self, SymInt split_size, int dim=0) -> Tensor[] + inline ::std::vector unsafe_split_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymInt split_size, int64_t dim=0) { + return at::_ops::unsafe_split_Tensor::redispatch(dispatchKeySet, self, split_size, dim); + } + + // aten::split.Tensor(Tensor(a -> *) self, SymInt split_size, int dim=0) -> Tensor(a)[] + inline ::std::vector split(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t split_size, int64_t dim=0) { + return at::_ops::split_Tensor::redispatch(dispatchKeySet, self, split_size, dim); + } + + // aten::split.Tensor(Tensor(a -> *) self, SymInt split_size, int dim=0) -> Tensor(a)[] + inline ::std::vector split_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymInt split_size, 
int64_t dim=0) { + return at::_ops::split_Tensor::redispatch(dispatchKeySet, self, split_size, dim); + } + + // aten::split.sizes(Tensor(a -> *) self, SymInt[] split_size, int dim=0) -> Tensor(a)[] + inline ::std::vector split(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef split_size, int64_t dim=0) { + return at::_ops::split_sizes::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(split_size), dim); + } + + // aten::split.sizes(Tensor(a -> *) self, SymInt[] split_size, int dim=0) -> Tensor(a)[] + inline ::std::vector split_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef split_size, int64_t dim=0) { + return at::_ops::split_sizes::redispatch(dispatchKeySet, self, split_size, dim); + } + + // aten::unsafe_split_with_sizes(Tensor self, SymInt[] split_sizes, int dim=0) -> Tensor[] + inline ::std::vector unsafe_split_with_sizes(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef split_sizes, int64_t dim=0) { + return at::_ops::unsafe_split_with_sizes::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(split_sizes), dim); + } + + // aten::unsafe_split_with_sizes(Tensor self, SymInt[] split_sizes, int dim=0) -> Tensor[] + inline ::std::vector unsafe_split_with_sizes_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef split_sizes, int64_t dim=0) { + return at::_ops::unsafe_split_with_sizes::redispatch(dispatchKeySet, self, split_sizes, dim); + } + + // aten::split_with_sizes(Tensor(a -> *) self, SymInt[] split_sizes, int dim=0) -> Tensor(a)[] + inline ::std::vector split_with_sizes(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef split_sizes, int64_t dim=0) { + return at::_ops::split_with_sizes::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(split_sizes), dim); + } + + // aten::split_with_sizes(Tensor(a -> *) self, SymInt[] split_sizes, int dim=0) -> Tensor(a)[] + inline ::std::vector 
split_with_sizes_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef split_sizes, int64_t dim=0) { + return at::_ops::split_with_sizes::redispatch(dispatchKeySet, self, split_sizes, dim); + } + + // aten::hsplit.int(Tensor(a -> *) self, int sections) -> Tensor(a)[] + inline ::std::vector hsplit(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t sections) { + return at::_ops::hsplit_int::redispatch(dispatchKeySet, self, sections); + } + + // aten::hsplit.array(Tensor(a -> *) self, int[] indices) -> Tensor(a)[] + inline ::std::vector hsplit(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef indices) { + return at::_ops::hsplit_array::redispatch(dispatchKeySet, self, indices); + } + + // aten::vsplit.int(Tensor(a -> *) self, int sections) -> Tensor(a)[] + inline ::std::vector vsplit(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t sections) { + return at::_ops::vsplit_int::redispatch(dispatchKeySet, self, sections); + } + + // aten::vsplit.array(Tensor(a -> *) self, int[] indices) -> Tensor(a)[] + inline ::std::vector vsplit(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef indices) { + return at::_ops::vsplit_array::redispatch(dispatchKeySet, self, indices); + } + + // aten::dsplit.int(Tensor(a -> *) self, int sections) -> Tensor(a)[] + inline ::std::vector dsplit(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t sections) { + return at::_ops::dsplit_int::redispatch(dispatchKeySet, self, sections); + } + + // aten::dsplit.array(Tensor(a -> *) self, int[] indices) -> Tensor(a)[] + inline ::std::vector dsplit(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef indices) { + return at::_ops::dsplit_array::redispatch(dispatchKeySet, self, indices); + } + + // aten::squeeze(Tensor(a) self) -> Tensor(a) + inline at::Tensor squeeze(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return 
at::_ops::squeeze::redispatch(dispatchKeySet, self); + } + + // aten::squeeze.dim(Tensor(a) self, int dim) -> Tensor(a) + inline at::Tensor squeeze(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim) { + return at::_ops::squeeze_dim::redispatch(dispatchKeySet, self, dim); + } + + // aten::squeeze.dimname(Tensor(a) self, Dimname dim) -> Tensor(a) + inline at::Tensor squeeze(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim) { + return at::_ops::squeeze_dimname::redispatch(dispatchKeySet, self, dim); + } + + // aten::squeeze.dims(Tensor(a) self, int[] dim) -> Tensor(a) + inline at::Tensor squeeze(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim) { + return at::_ops::squeeze_dims::redispatch(dispatchKeySet, self, dim); + } + + // aten::squeeze_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & squeeze_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::squeeze_::redispatch(dispatchKeySet, self); + } + + // aten::squeeze_.dim(Tensor(a!) self, int dim) -> Tensor(a!) + inline at::Tensor & squeeze_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, int64_t dim) { + return at::_ops::squeeze__dim::redispatch(dispatchKeySet, self, dim); + } + + // aten::squeeze_.dims(Tensor(a!) self, int[] dim) -> Tensor(a!) + inline at::Tensor & squeeze_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, at::IntArrayRef dim) { + return at::_ops::squeeze__dims::redispatch(dispatchKeySet, self, dim); + } + + // aten::squeeze_.dimname(Tensor(a!) self, Dimname dim) -> Tensor(a!) 
+ inline at::Tensor & squeeze_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, at::Dimname dim) { + return at::_ops::squeeze__dimname::redispatch(dispatchKeySet, self, dim); + } + + // aten::sspaddmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor + inline at::Tensor sspaddmm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta=1, const at::Scalar & alpha=1) { + return at::_ops::sspaddmm::redispatch(dispatchKeySet, self, mat1, mat2, beta, alpha); + } + + // aten::sspaddmm.out(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & sspaddmm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta=1, const at::Scalar & alpha=1) { + return at::_ops::sspaddmm_out::redispatch(dispatchKeySet, self, mat1, mat2, beta, alpha, out); + } + + // aten::sspaddmm.out(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & sspaddmm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out) { + return at::_ops::sspaddmm_out::redispatch(dispatchKeySet, self, mat1, mat2, beta, alpha, out); + } + + // aten::_chunk_cat(Tensor[] tensors, int dim, int num_chunks) -> Tensor + inline at::Tensor _chunk_cat(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors, int64_t dim, int64_t num_chunks) { + return at::_ops::_chunk_cat::redispatch(dispatchKeySet, tensors, dim, num_chunks); + } + + // aten::_chunk_cat.out(Tensor[] tensors, int dim, int num_chunks, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & _chunk_cat_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::TensorList tensors, int64_t dim, int64_t num_chunks) { + return at::_ops::_chunk_cat_out::redispatch(dispatchKeySet, tensors, dim, num_chunks, out); + } + + // aten::_chunk_cat.out(Tensor[] tensors, int dim, int num_chunks, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _chunk_cat_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors, int64_t dim, int64_t num_chunks, at::Tensor & out) { + return at::_ops::_chunk_cat_out::redispatch(dispatchKeySet, tensors, dim, num_chunks, out); + } + + // aten::stack(Tensor[] tensors, int dim=0) -> Tensor + inline at::Tensor stack(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors, int64_t dim=0) { + return at::_ops::stack::redispatch(dispatchKeySet, tensors, dim); + } + + // aten::stack.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & stack_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::TensorList tensors, int64_t dim=0) { + return at::_ops::stack_out::redispatch(dispatchKeySet, tensors, dim, out); + } + + // aten::stack.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & stack_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors, int64_t dim, at::Tensor & out) { + return at::_ops::stack_out::redispatch(dispatchKeySet, tensors, dim, out); + } + + // aten::_stack(Tensor[] tensors, int dim=0) -> Tensor + inline at::Tensor _stack(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors, int64_t dim=0) { + return at::_ops::_stack::redispatch(dispatchKeySet, tensors, dim); + } + + // aten::_stack.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & _stack_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::TensorList tensors, int64_t dim=0) { + return at::_ops::_stack_out::redispatch(dispatchKeySet, tensors, dim, out); + } + + // aten::_stack.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _stack_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors, int64_t dim, at::Tensor & out) { + return at::_ops::_stack_out::redispatch(dispatchKeySet, tensors, dim, out); + } + + // aten::hstack(Tensor[] tensors) -> Tensor + inline at::Tensor hstack(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors) { + return at::_ops::hstack::redispatch(dispatchKeySet, tensors); + } + + // aten::hstack.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & hstack_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::TensorList tensors) { + return at::_ops::hstack_out::redispatch(dispatchKeySet, tensors, out); + } + + // aten::hstack.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & hstack_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors, at::Tensor & out) { + return at::_ops::hstack_out::redispatch(dispatchKeySet, tensors, out); + } + + // aten::vstack(Tensor[] tensors) -> Tensor + inline at::Tensor vstack(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors) { + return at::_ops::vstack::redispatch(dispatchKeySet, tensors); + } + + // aten::vstack.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & vstack_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::TensorList tensors) { + return at::_ops::vstack_out::redispatch(dispatchKeySet, tensors, out); + } + + // aten::vstack.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & vstack_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors, at::Tensor & out) { + return at::_ops::vstack_out::redispatch(dispatchKeySet, tensors, out); + } + + // aten::dstack(Tensor[] tensors) -> Tensor + inline at::Tensor dstack(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors) { + return at::_ops::dstack::redispatch(dispatchKeySet, tensors); + } + + // aten::dstack.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & dstack_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::TensorList tensors) { + return at::_ops::dstack_out::redispatch(dispatchKeySet, tensors, out); + } + + // aten::dstack.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & dstack_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors, at::Tensor & out) { + return at::_ops::dstack_out::redispatch(dispatchKeySet, tensors, out); + } + + // aten::stft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool normalized=False, bool? onesided=None, bool? return_complex=None, bool? align_to_window=None) -> Tensor + inline at::Tensor stft(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t n_fft, ::std::optional hop_length, ::std::optional win_length, const ::std::optional & window, bool normalized, ::std::optional onesided=::std::nullopt, ::std::optional return_complex=::std::nullopt, ::std::optional align_to_window=::std::nullopt) { + return at::_ops::stft::redispatch(dispatchKeySet, self, n_fft, hop_length, win_length, window, normalized, onesided, return_complex, align_to_window); + } + + // aten::stft.center(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool center=True, str pad_mode="reflect", bool normalized=False, bool? onesided=None, bool? return_complex=None, bool? 
align_to_window=None) -> Tensor + inline at::Tensor stft(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t n_fft, ::std::optional hop_length=::std::nullopt, ::std::optional win_length=::std::nullopt, const ::std::optional & window={}, bool center=true, c10::string_view pad_mode="reflect", bool normalized=false, ::std::optional onesided=::std::nullopt, ::std::optional return_complex=::std::nullopt, ::std::optional align_to_window=::std::nullopt) { + return at::_ops::stft_center::redispatch(dispatchKeySet, self, n_fft, hop_length, win_length, window, center, pad_mode, normalized, onesided, return_complex, align_to_window); + } + + // aten::istft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool center=True, bool normalized=False, bool? onesided=None, int? length=None, bool return_complex=False) -> Tensor + inline at::Tensor istft(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t n_fft, ::std::optional hop_length=::std::nullopt, ::std::optional win_length=::std::nullopt, const ::std::optional & window={}, bool center=true, bool normalized=false, ::std::optional onesided=::std::nullopt, ::std::optional length=::std::nullopt, bool return_complex=false) { + return at::_ops::istft::redispatch(dispatchKeySet, self, n_fft, hop_length, win_length, window, center, normalized, onesided, length, return_complex); + } + + // aten::stride.int(Tensor self, int dim) -> int + inline int64_t __dispatch_stride(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim) { + return at::_ops::stride_int::redispatch(dispatchKeySet, self, dim); + } + + // aten::stride.Dimname(Tensor self, Dimname dim) -> int + inline int64_t stride(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim) { + return at::_ops::stride_Dimname::redispatch(dispatchKeySet, self, dim); + } + + // aten::sym_stride.int(Tensor self, int dim) -> SymInt + inline c10::SymInt 
__dispatch_sym_stride(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim) { + return at::_ops::sym_stride_int::redispatch(dispatchKeySet, self, dim); + } + + // aten::sum(Tensor self, *, ScalarType? dtype=None) -> Tensor + inline at::Tensor sum(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional dtype=::std::nullopt) { + return at::_ops::sum::redispatch(dispatchKeySet, self, dtype); + } + + // aten::sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor + inline at::Tensor sum(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim=false, ::std::optional dtype=::std::nullopt) { + return at::_ops::sum_dim_IntList::redispatch(dispatchKeySet, self, dim, keepdim, dtype); + } + + // aten::sum.dim_DimnameList(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor + inline at::Tensor sum(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::DimnameList dim, bool keepdim=false, ::std::optional dtype=::std::nullopt) { + return at::_ops::sum_dim_DimnameList::redispatch(dispatchKeySet, self, dim, keepdim, dtype); + } + + // aten::sum.IntList_out(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & sum_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim=false, ::std::optional dtype=::std::nullopt) { + return at::_ops::sum_IntList_out::redispatch(dispatchKeySet, self, dim, keepdim, dtype, out); + } + + // aten::sum.IntList_out(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & sum_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim, ::std::optional dtype, at::Tensor & out) { + return at::_ops::sum_IntList_out::redispatch(dispatchKeySet, self, dim, keepdim, dtype, out); + } + + // aten::sum.DimnameList_out(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & sum_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::DimnameList dim, bool keepdim=false, ::std::optional dtype=::std::nullopt) { + return at::_ops::sum_DimnameList_out::redispatch(dispatchKeySet, self, dim, keepdim, dtype, out); + } + + // aten::sum.DimnameList_out(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & sum_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::DimnameList dim, bool keepdim, ::std::optional dtype, at::Tensor & out) { + return at::_ops::sum_DimnameList_out::redispatch(dispatchKeySet, self, dim, keepdim, dtype, out); + } + + // aten::_nested_sum_backward(Tensor grad, Tensor self, int[1]? dim, bool keepdim=False) -> Tensor + inline at::Tensor _nested_sum_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim=false) { + return at::_ops::_nested_sum_backward::redispatch(dispatchKeySet, grad, self, dim, keepdim); + } + + // aten::nansum(Tensor self, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor + inline at::Tensor nansum(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef dim=::std::nullopt, bool keepdim=false, ::std::optional dtype=::std::nullopt) { + return at::_ops::nansum::redispatch(dispatchKeySet, self, dim, keepdim, dtype); + } + + // aten::nansum.out(Tensor self, int[1]? dim=None, bool keepdim=False, *, ScalarType? 
dtype=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & nansum_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalIntArrayRef dim=::std::nullopt, bool keepdim=false, ::std::optional dtype=::std::nullopt) { + return at::_ops::nansum_out::redispatch(dispatchKeySet, self, dim, keepdim, dtype, out); + } + + // aten::nansum.out(Tensor self, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & nansum_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim, ::std::optional dtype, at::Tensor & out) { + return at::_ops::nansum_out::redispatch(dispatchKeySet, self, dim, keepdim, dtype, out); + } + + // aten::hash_tensor(Tensor self, int[1] dim=[], *, bool keepdim=False, int mode=0) -> Tensor + inline at::Tensor hash_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim={}, bool keepdim=false, int64_t mode=0) { + return at::_ops::hash_tensor::redispatch(dispatchKeySet, self, dim, keepdim, mode); + } + + // aten::hash_tensor.out(Tensor self, int[1] dim=[], *, bool keepdim=False, int mode=0, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & hash_tensor_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef dim={}, bool keepdim=false, int64_t mode=0) { + return at::_ops::hash_tensor_out::redispatch(dispatchKeySet, self, dim, keepdim, mode, out); + } + + // aten::hash_tensor.out(Tensor self, int[1] dim=[], *, bool keepdim=False, int mode=0, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & hash_tensor_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim, bool keepdim, int64_t mode, at::Tensor & out) { + return at::_ops::hash_tensor_out::redispatch(dispatchKeySet, self, dim, keepdim, mode, out); + } + + // aten::sum_to_size(Tensor self, SymInt[] size) -> Tensor + inline at::Tensor sum_to_size(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size) { + return at::_ops::sum_to_size::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size)); + } + + // aten::sum_to_size(Tensor self, SymInt[] size) -> Tensor + inline at::Tensor sum_to_size_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size) { + return at::_ops::sum_to_size::redispatch(dispatchKeySet, self, size); + } + + // aten::sqrt(Tensor self) -> Tensor + inline at::Tensor sqrt(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::sqrt::redispatch(dispatchKeySet, self); + } + + // aten::sqrt_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & sqrt_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::sqrt_::redispatch(dispatchKeySet, self); + } + + // aten::sqrt.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & sqrt_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::sqrt_out::redispatch(dispatchKeySet, self, out); + } + + // aten::sqrt.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & sqrt_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::sqrt_out::redispatch(dispatchKeySet, self, out); + } + + // aten::square(Tensor self) -> Tensor + inline at::Tensor square(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::square::redispatch(dispatchKeySet, self); + } + + // aten::square_(Tensor(a!) self) -> Tensor(a!) 
+ inline at::Tensor & square_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::square_::redispatch(dispatchKeySet, self); + } + + // aten::square.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & square_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::square_out::redispatch(dispatchKeySet, self, out); + } + + // aten::square.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & square_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::square_out::redispatch(dispatchKeySet, self, out); + } + + // aten::std(Tensor self, bool unbiased=True) -> Tensor + inline at::Tensor std(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool unbiased) { + return at::_ops::std::redispatch(dispatchKeySet, self, unbiased); + } + + // aten::std.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> Tensor + inline at::Tensor std(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef dim, bool unbiased, bool keepdim=false) { + return at::_ops::std_dim::redispatch(dispatchKeySet, self, dim, unbiased, keepdim); + } + + // aten::std.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> Tensor + inline at::Tensor std(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef dim=::std::nullopt, const ::std::optional & correction=::std::nullopt, bool keepdim=false) { + return at::_ops::std_correction::redispatch(dispatchKeySet, self, dim, correction, keepdim); + } + + // aten::std_mean(Tensor self, bool unbiased=True) -> (Tensor, Tensor) + inline ::std::tuple std_mean(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool unbiased) { + return at::_ops::std_mean::redispatch(dispatchKeySet, self, unbiased); + } + + // aten::std_mean.dim(Tensor self, int[1]? 
dim, bool unbiased=True, bool keepdim=False) -> (Tensor, Tensor) + inline ::std::tuple std_mean(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef dim, bool unbiased, bool keepdim=false) { + return at::_ops::std_mean_dim::redispatch(dispatchKeySet, self, dim, unbiased, keepdim); + } + + // aten::std_mean.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> (Tensor, Tensor) + inline ::std::tuple std_mean(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef dim=::std::nullopt, const ::std::optional & correction=::std::nullopt, bool keepdim=false) { + return at::_ops::std_mean_correction::redispatch(dispatchKeySet, self, dim, correction, keepdim); + } + + // aten::std_mean.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False) -> (Tensor, Tensor) + inline ::std::tuple std_mean(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::DimnameList dim, bool unbiased, bool keepdim=false) { + return at::_ops::std_mean_names_dim::redispatch(dispatchKeySet, self, dim, unbiased, keepdim); + } + + // aten::std_mean.correction_names(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False) -> (Tensor, Tensor) + inline ::std::tuple std_mean(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::DimnameList dim, const ::std::optional & correction=::std::nullopt, bool keepdim=false) { + return at::_ops::std_mean_correction_names::redispatch(dispatchKeySet, self, dim, correction, keepdim); + } + + // aten::std.out(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & std_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalIntArrayRef dim, bool unbiased, bool keepdim=false) { + return at::_ops::std_out::redispatch(dispatchKeySet, self, dim, unbiased, keepdim, out); + } + + // aten::std.out(Tensor self, int[1]? 
dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & std_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef dim, bool unbiased, bool keepdim, at::Tensor & out) { + return at::_ops::std_out::redispatch(dispatchKeySet, self, dim, unbiased, keepdim, out); + } + + // aten::std.correction_out(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & std_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalIntArrayRef dim=::std::nullopt, const ::std::optional & correction=::std::nullopt, bool keepdim=false) { + return at::_ops::std_correction_out::redispatch(dispatchKeySet, self, dim, correction, keepdim, out); + } + + // aten::std.correction_out(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & std_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef dim, const ::std::optional & correction, bool keepdim, at::Tensor & out) { + return at::_ops::std_correction_out::redispatch(dispatchKeySet, self, dim, correction, keepdim, out); + } + + // aten::std.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False) -> Tensor + inline at::Tensor std(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::DimnameList dim, bool unbiased, bool keepdim=false) { + return at::_ops::std_names_dim::redispatch(dispatchKeySet, self, dim, unbiased, keepdim); + } + + // aten::std.names_out(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & std_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::DimnameList dim, bool unbiased, bool keepdim=false) { + return at::_ops::std_names_out::redispatch(dispatchKeySet, self, dim, unbiased, keepdim, out); + } + + // aten::std.names_out(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & std_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::DimnameList dim, bool unbiased, bool keepdim, at::Tensor & out) { + return at::_ops::std_names_out::redispatch(dispatchKeySet, self, dim, unbiased, keepdim, out); + } + + // aten::std.correction_names(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False) -> Tensor + inline at::Tensor std(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::DimnameList dim, const ::std::optional & correction=::std::nullopt, bool keepdim=false) { + return at::_ops::std_correction_names::redispatch(dispatchKeySet, self, dim, correction, keepdim); + } + + // aten::std.correction_names_out(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & std_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::DimnameList dim, const ::std::optional & correction=::std::nullopt, bool keepdim=false) { + return at::_ops::std_correction_names_out::redispatch(dispatchKeySet, self, dim, correction, keepdim, out); + } + + // aten::std.correction_names_out(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & std_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::DimnameList dim, const ::std::optional & correction, bool keepdim, at::Tensor & out) { + return at::_ops::std_correction_names_out::redispatch(dispatchKeySet, self, dim, correction, keepdim, out); + } + + // aten::prod(Tensor self, *, ScalarType? dtype=None) -> Tensor + inline at::Tensor prod(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional dtype=::std::nullopt) { + return at::_ops::prod::redispatch(dispatchKeySet, self, dtype); + } + + // aten::prod.dim_int(Tensor self, int dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor + inline at::Tensor prod(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool keepdim=false, ::std::optional dtype=::std::nullopt) { + return at::_ops::prod_dim_int::redispatch(dispatchKeySet, self, dim, keepdim, dtype); + } + + // aten::prod.int_out(Tensor self, int dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & prod_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim, bool keepdim=false, ::std::optional dtype=::std::nullopt) { + return at::_ops::prod_int_out::redispatch(dispatchKeySet, self, dim, keepdim, dtype, out); + } + + // aten::prod.int_out(Tensor self, int dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & prod_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool keepdim, ::std::optional dtype, at::Tensor & out) { + return at::_ops::prod_int_out::redispatch(dispatchKeySet, self, dim, keepdim, dtype, out); + } + + // aten::prod.dim_Dimname(Tensor self, Dimname dim, bool keepdim=False, *, ScalarType? 
dtype=None) -> Tensor + inline at::Tensor prod(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, bool keepdim=false, ::std::optional dtype=::std::nullopt) { + return at::_ops::prod_dim_Dimname::redispatch(dispatchKeySet, self, dim, keepdim, dtype); + } + + // aten::prod.Dimname_out(Tensor self, Dimname dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & prod_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::Dimname dim, bool keepdim=false, ::std::optional dtype=::std::nullopt) { + return at::_ops::prod_Dimname_out::redispatch(dispatchKeySet, self, dim, keepdim, dtype, out); + } + + // aten::prod.Dimname_out(Tensor self, Dimname dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & prod_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, bool keepdim, ::std::optional dtype, at::Tensor & out) { + return at::_ops::prod_Dimname_out::redispatch(dispatchKeySet, self, dim, keepdim, dtype, out); + } + + // aten::t(Tensor(a) self) -> Tensor(a) + inline at::Tensor t(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::t::redispatch(dispatchKeySet, self); + } + + // aten::t_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & t_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::t_::redispatch(dispatchKeySet, self); + } + + // aten::tan(Tensor self) -> Tensor + inline at::Tensor tan(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::tan::redispatch(dispatchKeySet, self); + } + + // aten::tan_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & tan_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::tan_::redispatch(dispatchKeySet, self); + } + + // aten::tan.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & tan_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::tan_out::redispatch(dispatchKeySet, self, out); + } + + // aten::tan.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & tan_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::tan_out::redispatch(dispatchKeySet, self, out); + } + + // aten::tanh(Tensor self) -> Tensor + inline at::Tensor tanh(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::tanh::redispatch(dispatchKeySet, self); + } + + // aten::tanh_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & tanh_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::tanh_::redispatch(dispatchKeySet, self); + } + + // aten::tanh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & tanh_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::tanh_out::redispatch(dispatchKeySet, self, out); + } + + // aten::tanh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & tanh_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::tanh_out::redispatch(dispatchKeySet, self, out); + } + + // aten::tensordot(Tensor self, Tensor other, int[] dims_self, int[] dims_other) -> Tensor + inline at::Tensor tensordot(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::IntArrayRef dims_self, at::IntArrayRef dims_other) { + return at::_ops::tensordot::redispatch(dispatchKeySet, self, other, dims_self, dims_other); + } + + // aten::tensordot.out(Tensor self, Tensor other, int[] dims_self, int[] dims_other, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & tensordot_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other, at::IntArrayRef dims_self, at::IntArrayRef dims_other) { + return at::_ops::tensordot_out::redispatch(dispatchKeySet, self, other, dims_self, dims_other, out); + } + + // aten::tensordot.out(Tensor self, Tensor other, int[] dims_self, int[] dims_other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & tensordot_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::IntArrayRef dims_self, at::IntArrayRef dims_other, at::Tensor & out) { + return at::_ops::tensordot_out::redispatch(dispatchKeySet, self, other, dims_self, dims_other, out); + } + + // aten::threshold(Tensor self, Scalar threshold, Scalar value) -> Tensor + inline at::Tensor threshold(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & threshold, const at::Scalar & value) { + return at::_ops::threshold::redispatch(dispatchKeySet, self, threshold, value); + } + + // aten::threshold_(Tensor(a!) self, Scalar threshold, Scalar value) -> Tensor(a!) + inline at::Tensor & threshold_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & threshold, const at::Scalar & value) { + return at::_ops::threshold_::redispatch(dispatchKeySet, self, threshold, value); + } + + // aten::threshold.out(Tensor self, Scalar threshold, Scalar value, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & threshold_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & threshold, const at::Scalar & value) { + return at::_ops::threshold_out::redispatch(dispatchKeySet, self, threshold, value, out); + } + + // aten::threshold.out(Tensor self, Scalar threshold, Scalar value, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & threshold_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & threshold, const at::Scalar & value, at::Tensor & out) { + return at::_ops::threshold_out::redispatch(dispatchKeySet, self, threshold, value, out); + } + + // aten::threshold_backward.grad_input(Tensor grad_output, Tensor self, Scalar threshold, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & threshold_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, const at::Scalar & threshold) { + return at::_ops::threshold_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, threshold, grad_input); + } + + // aten::threshold_backward.grad_input(Tensor grad_output, Tensor self, Scalar threshold, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & threshold_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Scalar & threshold, at::Tensor & grad_input) { + return at::_ops::threshold_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, threshold, grad_input); + } + + // aten::threshold_backward(Tensor grad_output, Tensor self, Scalar threshold) -> Tensor + inline at::Tensor threshold_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Scalar & threshold) { + return at::_ops::threshold_backward::redispatch(dispatchKeySet, grad_output, self, threshold); + } + + // aten::tile(Tensor self, SymInt[] dims) -> Tensor + inline at::Tensor tile(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dims) { + return at::_ops::tile::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(dims)); + } + + // aten::tile(Tensor self, SymInt[] dims) -> Tensor + inline at::Tensor tile_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef dims) { + return 
at::_ops::tile::redispatch(dispatchKeySet, self, dims); + } + + // aten::transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a) + inline at::Tensor transpose(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim0, int64_t dim1) { + return at::_ops::transpose_int::redispatch(dispatchKeySet, self, dim0, dim1); + } + + // aten::transpose.Dimname(Tensor(a) self, Dimname dim0, Dimname dim1) -> Tensor(a) + inline at::Tensor transpose(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim0, at::Dimname dim1) { + return at::_ops::transpose_Dimname::redispatch(dispatchKeySet, self, dim0, dim1); + } + + // aten::_mkldnn_transpose(Tensor self, int dim0, int dim1) -> Tensor + inline at::Tensor _mkldnn_transpose(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim0, int64_t dim1) { + return at::_ops::_mkldnn_transpose::redispatch(dispatchKeySet, self, dim0, dim1); + } + + // aten::transpose_(Tensor(a!) self, int dim0, int dim1) -> Tensor(a!) + inline at::Tensor & transpose_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, int64_t dim0, int64_t dim1) { + return at::_ops::transpose_::redispatch(dispatchKeySet, self, dim0, dim1); + } + + // aten::_mkldnn_transpose_(Tensor(a!) self, int dim0, int dim1) -> Tensor(a!) 
+ inline at::Tensor & _mkldnn_transpose_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, int64_t dim0, int64_t dim1) { + return at::_ops::_mkldnn_transpose_::redispatch(dispatchKeySet, self, dim0, dim1); + } + + // aten::one_hot(Tensor self, int num_classes=-1) -> Tensor + inline at::Tensor one_hot(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t num_classes=-1) { + return at::_ops::one_hot::redispatch(dispatchKeySet, self, num_classes); + } + + // aten::flip(Tensor self, int[] dims) -> Tensor + inline at::Tensor flip(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dims) { + return at::_ops::flip::redispatch(dispatchKeySet, self, dims); + } + + // aten::fliplr(Tensor self) -> Tensor + inline at::Tensor fliplr(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::fliplr::redispatch(dispatchKeySet, self); + } + + // aten::flipud(Tensor self) -> Tensor + inline at::Tensor flipud(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::flipud::redispatch(dispatchKeySet, self); + } + + // aten::roll(Tensor self, SymInt[1] shifts, int[1] dims=[]) -> Tensor + inline at::Tensor roll(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef shifts, at::IntArrayRef dims={}) { + return at::_ops::roll::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(shifts), dims); + } + + // aten::roll(Tensor self, SymInt[1] shifts, int[1] dims=[]) -> Tensor + inline at::Tensor roll_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef shifts, at::IntArrayRef dims={}) { + return at::_ops::roll::redispatch(dispatchKeySet, self, shifts, dims); + } + + // aten::rot90(Tensor self, int k=1, int[] dims=[0,1]) -> Tensor + inline at::Tensor rot90(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t k=1, at::IntArrayRef dims={0,1}) { + return at::_ops::rot90::redispatch(dispatchKeySet, self, k, dims); + } + + // 
aten::trapezoid.x(Tensor y, Tensor x, *, int dim=-1) -> Tensor + inline at::Tensor trapezoid(c10::DispatchKeySet dispatchKeySet, const at::Tensor & y, const at::Tensor & x, int64_t dim=-1) { + return at::_ops::trapezoid_x::redispatch(dispatchKeySet, y, x, dim); + } + + // aten::trapezoid.dx(Tensor y, *, Scalar dx=1, int dim=-1) -> Tensor + inline at::Tensor trapezoid(c10::DispatchKeySet dispatchKeySet, const at::Tensor & y, const at::Scalar & dx=1, int64_t dim=-1) { + return at::_ops::trapezoid_dx::redispatch(dispatchKeySet, y, dx, dim); + } + + // aten::trapz.x(Tensor y, Tensor x, *, int dim=-1) -> Tensor + inline at::Tensor trapz(c10::DispatchKeySet dispatchKeySet, const at::Tensor & y, const at::Tensor & x, int64_t dim=-1) { + return at::_ops::trapz_x::redispatch(dispatchKeySet, y, x, dim); + } + + // aten::trapz.dx(Tensor y, *, float dx=1, int dim=-1) -> Tensor + inline at::Tensor trapz(c10::DispatchKeySet dispatchKeySet, const at::Tensor & y, double dx=1, int64_t dim=-1) { + return at::_ops::trapz_dx::redispatch(dispatchKeySet, y, dx, dim); + } + + // aten::_transform_bias_rescale_qkv(Tensor qkv, Tensor qkv_bias, int num_heads) -> (Tensor, Tensor, Tensor) + inline ::std::tuple _transform_bias_rescale_qkv(c10::DispatchKeySet dispatchKeySet, const at::Tensor & qkv, const at::Tensor & qkv_bias, int64_t num_heads) { + return at::_ops::_transform_bias_rescale_qkv::redispatch(dispatchKeySet, qkv, qkv_bias, num_heads); + } + + // aten::_nested_tensor_from_mask(Tensor t, Tensor mask, bool mask_check=True) -> Tensor + inline at::Tensor _nested_tensor_from_mask(c10::DispatchKeySet dispatchKeySet, const at::Tensor & t, const at::Tensor & mask, bool mask_check=true) { + return at::_ops::_nested_tensor_from_mask::redispatch(dispatchKeySet, t, mask, mask_check); + } + + // aten::_nested_tensor_from_mask_left_aligned(Tensor t, Tensor mask) -> bool + inline bool _nested_tensor_from_mask_left_aligned(c10::DispatchKeySet dispatchKeySet, const at::Tensor & t, const at::Tensor & 
mask) { + return at::_ops::_nested_tensor_from_mask_left_aligned::redispatch(dispatchKeySet, t, mask); + } + + // aten::_nested_from_padded(Tensor padded, Tensor cpu_nested_shape_example, bool fuse_transform_0213=False) -> Tensor + inline at::Tensor _nested_from_padded(c10::DispatchKeySet dispatchKeySet, const at::Tensor & padded, const at::Tensor & cpu_nested_shape_example, bool fuse_transform_0213=false) { + return at::_ops::_nested_from_padded::redispatch(dispatchKeySet, padded, cpu_nested_shape_example, fuse_transform_0213); + } + + // aten::_nested_tensor_size(Tensor self) -> Tensor + inline at::Tensor _nested_tensor_size(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::_nested_tensor_size::redispatch(dispatchKeySet, self); + } + + // aten::_nested_tensor_strides(Tensor self) -> Tensor + inline at::Tensor _nested_tensor_strides(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::_nested_tensor_strides::redispatch(dispatchKeySet, self); + } + + // aten::_nested_tensor_storage_offsets(Tensor self) -> Tensor + inline at::Tensor _nested_tensor_storage_offsets(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::_nested_tensor_storage_offsets::redispatch(dispatchKeySet, self); + } + + // aten::_nested_from_padded_and_nested_example(Tensor padded, Tensor nt_example) -> Tensor + inline at::Tensor _nested_from_padded_and_nested_example(c10::DispatchKeySet dispatchKeySet, const at::Tensor & padded, const at::Tensor & nt_example) { + return at::_ops::_nested_from_padded_and_nested_example::redispatch(dispatchKeySet, padded, nt_example); + } + + // aten::_nested_view_from_buffer(Tensor(a) self, Tensor nested_size, Tensor nested_strides, Tensor offsets) -> Tensor(a) + inline at::Tensor _nested_view_from_buffer(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & nested_size, const at::Tensor & nested_strides, const at::Tensor & offsets) { + return 
at::_ops::_nested_view_from_buffer::redispatch(dispatchKeySet, self, nested_size, nested_strides, offsets); + } + + // aten::_nested_view_from_buffer_copy(Tensor self, Tensor nested_size, Tensor nested_strides, Tensor offsets) -> Tensor + inline at::Tensor _nested_view_from_buffer_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & nested_size, const at::Tensor & nested_strides, const at::Tensor & offsets) { + return at::_ops::_nested_view_from_buffer_copy::redispatch(dispatchKeySet, self, nested_size, nested_strides, offsets); + } + + // aten::_nested_view_from_jagged(Tensor(a) self, Tensor offsets, Tensor dummy, Tensor? lengths=None, int ragged_idx=1, Tensor? min_seqlen=None, Tensor? max_seqlen=None) -> Tensor(a) + inline at::Tensor _nested_view_from_jagged(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & offsets, const at::Tensor & dummy, const ::std::optional & lengths={}, int64_t ragged_idx=1, const ::std::optional & min_seqlen={}, const ::std::optional & max_seqlen={}) { + return at::_ops::_nested_view_from_jagged::redispatch(dispatchKeySet, self, offsets, dummy, lengths, ragged_idx, min_seqlen, max_seqlen); + } + + // aten::_nested_view_from_jagged_copy(Tensor self, Tensor offsets, Tensor dummy, Tensor? lengths=None, int ragged_idx=1, Tensor? min_seqlen=None, Tensor? 
max_seqlen=None) -> Tensor + inline at::Tensor _nested_view_from_jagged_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & offsets, const at::Tensor & dummy, const ::std::optional & lengths={}, int64_t ragged_idx=1, const ::std::optional & min_seqlen={}, const ::std::optional & max_seqlen={}) { + return at::_ops::_nested_view_from_jagged_copy::redispatch(dispatchKeySet, self, offsets, dummy, lengths, ragged_idx, min_seqlen, max_seqlen); + } + + // aten::_nested_get_values(Tensor(a) self) -> Tensor(a) + inline at::Tensor _nested_get_values(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::_nested_get_values::redispatch(dispatchKeySet, self); + } + + // aten::_nested_get_values_copy(Tensor self) -> Tensor + inline at::Tensor _nested_get_values_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::_nested_get_values_copy::redispatch(dispatchKeySet, self); + } + + // aten::_nested_get_offsets(Tensor self) -> Tensor + inline at::Tensor _nested_get_offsets(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::_nested_get_offsets::redispatch(dispatchKeySet, self); + } + + // aten::_nested_get_lengths(Tensor self) -> Tensor + inline at::Tensor _nested_get_lengths(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::_nested_get_lengths::redispatch(dispatchKeySet, self); + } + + // aten::_nested_get_ragged_idx(Tensor self) -> int + inline int64_t _nested_get_ragged_idx(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::_nested_get_ragged_idx::redispatch(dispatchKeySet, self); + } + + // aten::_nested_get_min_seqlen(Tensor self) -> Tensor + inline at::Tensor _nested_get_min_seqlen(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::_nested_get_min_seqlen::redispatch(dispatchKeySet, self); + } + + // aten::_nested_get_max_seqlen(Tensor self) -> Tensor + inline 
at::Tensor _nested_get_max_seqlen(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::_nested_get_max_seqlen::redispatch(dispatchKeySet, self); + } + + // aten::_nested_get_jagged_dummy(Tensor any) -> Tensor + inline at::Tensor _nested_get_jagged_dummy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & any) { + return at::_ops::_nested_get_jagged_dummy::redispatch(dispatchKeySet, any); + } + + // aten::_nested_compute_contiguous_strides_offsets(Tensor nested_size) -> (Tensor, Tensor) + inline ::std::tuple _nested_compute_contiguous_strides_offsets(c10::DispatchKeySet dispatchKeySet, const at::Tensor & nested_size) { + return at::_ops::_nested_compute_contiguous_strides_offsets::redispatch(dispatchKeySet, nested_size); + } + + // aten::_trilinear(Tensor i1, Tensor i2, Tensor i3, int[] expand1, int[] expand2, int[] expand3, int[] sumdim, int unroll_dim=1) -> Tensor + inline at::Tensor _trilinear(c10::DispatchKeySet dispatchKeySet, const at::Tensor & i1, const at::Tensor & i2, const at::Tensor & i3, at::IntArrayRef expand1, at::IntArrayRef expand2, at::IntArrayRef expand3, at::IntArrayRef sumdim, int64_t unroll_dim=1) { + return at::_ops::_trilinear::redispatch(dispatchKeySet, i1, i2, i3, expand1, expand2, expand3, sumdim, unroll_dim); + } + + // aten::triplet_margin_loss(Tensor anchor, Tensor positive, Tensor negative, float margin=1.0, float p=2, float eps=1e-06, bool swap=False, int reduction=Mean) -> Tensor + inline at::Tensor triplet_margin_loss(c10::DispatchKeySet dispatchKeySet, const at::Tensor & anchor, const at::Tensor & positive, const at::Tensor & negative, double margin=1.0, double p=2, double eps=1e-06, bool swap=false, int64_t reduction=at::Reduction::Mean) { + return at::_ops::triplet_margin_loss::redispatch(dispatchKeySet, anchor, positive, negative, margin, p, eps, swap, reduction); + } + + // aten::trunc(Tensor self) -> Tensor + inline at::Tensor trunc(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + 
return at::_ops::trunc::redispatch(dispatchKeySet, self); + } + + // aten::trunc_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & trunc_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::trunc_::redispatch(dispatchKeySet, self); + } + + // aten::trunc.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & trunc_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::trunc_out::redispatch(dispatchKeySet, self, out); + } + + // aten::trunc.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & trunc_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::trunc_out::redispatch(dispatchKeySet, self, out); + } + + // aten::fix(Tensor self) -> Tensor + inline at::Tensor fix(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::fix::redispatch(dispatchKeySet, self); + } + + // aten::fix_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & fix_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::fix_::redispatch(dispatchKeySet, self); + } + + // aten::fix.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fix_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::fix_out::redispatch(dispatchKeySet, self, out); + } + + // aten::fix.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & fix_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::fix_out::redispatch(dispatchKeySet, self, out); + } + + // aten::type_as(Tensor self, Tensor other) -> Tensor + inline at::Tensor type_as(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::type_as::redispatch(dispatchKeySet, self, other); + } + + // aten::_has_compatible_shallow_copy_type(Tensor self, Tensor from) -> bool + inline bool _has_compatible_shallow_copy_type(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & from) { + return at::_ops::_has_compatible_shallow_copy_type::redispatch(dispatchKeySet, self, from); + } + + // aten::_unique(Tensor self, bool sorted=True, bool return_inverse=False) -> (Tensor, Tensor) + inline ::std::tuple _unique(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool sorted=true, bool return_inverse=false) { + return at::_ops::_unique::redispatch(dispatchKeySet, self, sorted, return_inverse); + } + + // aten::unique_dim(Tensor self, int dim, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor) + inline ::std::tuple unique_dim(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool sorted=true, bool return_inverse=false, bool return_counts=false) { + return at::_ops::unique_dim::redispatch(dispatchKeySet, self, dim, sorted, return_inverse, return_counts); + } + + // aten::unique_consecutive(Tensor self, bool return_inverse=False, bool return_counts=False, int? 
dim=None) -> (Tensor, Tensor, Tensor) + inline ::std::tuple unique_consecutive(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool return_inverse=false, bool return_counts=false, ::std::optional dim=::std::nullopt) { + return at::_ops::unique_consecutive::redispatch(dispatchKeySet, self, return_inverse, return_counts, dim); + } + + // aten::unique_dim_consecutive(Tensor self, int dim, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor) + inline ::std::tuple unique_dim_consecutive(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool return_inverse=false, bool return_counts=false) { + return at::_ops::unique_dim_consecutive::redispatch(dispatchKeySet, self, dim, return_inverse, return_counts); + } + + // aten::_unique2(Tensor self, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor) + inline ::std::tuple _unique2(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool sorted=true, bool return_inverse=false, bool return_counts=false) { + return at::_ops::_unique2::redispatch(dispatchKeySet, self, sorted, return_inverse, return_counts); + } + + // aten::_unsafe_view(Tensor self, SymInt[] size) -> Tensor + inline at::Tensor _unsafe_view(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size) { + return at::_ops::_unsafe_view::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size)); + } + + // aten::_unsafe_view(Tensor self, SymInt[] size) -> Tensor + inline at::Tensor _unsafe_view_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size) { + return at::_ops::_unsafe_view::redispatch(dispatchKeySet, self, size); + } + + // aten::unsqueeze(Tensor(a) self, int dim) -> Tensor(a) + inline at::Tensor unsqueeze(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim) { + return at::_ops::unsqueeze::redispatch(dispatchKeySet, self, dim); + } + + // 
aten::unsqueeze_(Tensor(a!) self, int dim) -> Tensor(a!) + inline at::Tensor & unsqueeze_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, int64_t dim) { + return at::_ops::unsqueeze_::redispatch(dispatchKeySet, self, dim); + } + + // aten::vander(Tensor x, int? N=None, bool increasing=False) -> Tensor + inline at::Tensor vander(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, ::std::optional N=::std::nullopt, bool increasing=false) { + return at::_ops::vander::redispatch(dispatchKeySet, x, N, increasing); + } + + // aten::var(Tensor self, bool unbiased=True) -> Tensor + inline at::Tensor var(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool unbiased) { + return at::_ops::var::redispatch(dispatchKeySet, self, unbiased); + } + + // aten::var.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> Tensor + inline at::Tensor var(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef dim, bool unbiased, bool keepdim=false) { + return at::_ops::var_dim::redispatch(dispatchKeySet, self, dim, unbiased, keepdim); + } + + // aten::var.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> Tensor + inline at::Tensor var(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef dim=::std::nullopt, const ::std::optional & correction=::std::nullopt, bool keepdim=false) { + return at::_ops::var_correction::redispatch(dispatchKeySet, self, dim, correction, keepdim); + } + + // aten::var.out(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & var_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalIntArrayRef dim, bool unbiased, bool keepdim=false) { + return at::_ops::var_out::redispatch(dispatchKeySet, self, dim, unbiased, keepdim, out); + } + + // aten::var.out(Tensor self, int[1]? 
dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & var_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef dim, bool unbiased, bool keepdim, at::Tensor & out) { + return at::_ops::var_out::redispatch(dispatchKeySet, self, dim, unbiased, keepdim, out); + } + + // aten::var.correction_out(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & var_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalIntArrayRef dim=::std::nullopt, const ::std::optional & correction=::std::nullopt, bool keepdim=false) { + return at::_ops::var_correction_out::redispatch(dispatchKeySet, self, dim, correction, keepdim, out); + } + + // aten::var.correction_out(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & var_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef dim, const ::std::optional & correction, bool keepdim, at::Tensor & out) { + return at::_ops::var_correction_out::redispatch(dispatchKeySet, self, dim, correction, keepdim, out); + } + + // aten::var.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False) -> Tensor + inline at::Tensor var(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::DimnameList dim, bool unbiased, bool keepdim=false) { + return at::_ops::var_names_dim::redispatch(dispatchKeySet, self, dim, unbiased, keepdim); + } + + // aten::var.names_out(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & var_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::DimnameList dim, bool unbiased, bool keepdim=false) { + return at::_ops::var_names_out::redispatch(dispatchKeySet, self, dim, unbiased, keepdim, out); + } + + // aten::var.names_out(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & var_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::DimnameList dim, bool unbiased, bool keepdim, at::Tensor & out) { + return at::_ops::var_names_out::redispatch(dispatchKeySet, self, dim, unbiased, keepdim, out); + } + + // aten::var.correction_names(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False) -> Tensor + inline at::Tensor var(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::DimnameList dim, const ::std::optional & correction=::std::nullopt, bool keepdim=false) { + return at::_ops::var_correction_names::redispatch(dispatchKeySet, self, dim, correction, keepdim); + } + + // aten::var.correction_names_out(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & var_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::DimnameList dim, const ::std::optional & correction=::std::nullopt, bool keepdim=false) { + return at::_ops::var_correction_names_out::redispatch(dispatchKeySet, self, dim, correction, keepdim, out); + } + + // aten::var.correction_names_out(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & var_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::DimnameList dim, const ::std::optional & correction, bool keepdim, at::Tensor & out) { + return at::_ops::var_correction_names_out::redispatch(dispatchKeySet, self, dim, correction, keepdim, out); + } + + // aten::var_mean(Tensor self, bool unbiased=True) -> (Tensor, Tensor) + inline ::std::tuple var_mean(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool unbiased) { + return at::_ops::var_mean::redispatch(dispatchKeySet, self, unbiased); + } + + // aten::var_mean.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> (Tensor, Tensor) + inline ::std::tuple var_mean(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef dim, bool unbiased, bool keepdim=false) { + return at::_ops::var_mean_dim::redispatch(dispatchKeySet, self, dim, unbiased, keepdim); + } + + // aten::var_mean.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> (Tensor, Tensor) + inline ::std::tuple var_mean(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef dim=::std::nullopt, const ::std::optional & correction=::std::nullopt, bool keepdim=false) { + return at::_ops::var_mean_correction::redispatch(dispatchKeySet, self, dim, correction, keepdim); + } + + // aten::var_mean.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False) -> (Tensor, Tensor) + inline ::std::tuple var_mean(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::DimnameList dim, bool unbiased, bool keepdim=false) { + return at::_ops::var_mean_names_dim::redispatch(dispatchKeySet, self, dim, unbiased, keepdim); + } + + // aten::var_mean.correction_names(Tensor self, Dimname[1] dim, *, Scalar? 
correction=None, bool keepdim=False) -> (Tensor, Tensor) + inline ::std::tuple var_mean(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::DimnameList dim, const ::std::optional & correction=::std::nullopt, bool keepdim=false) { + return at::_ops::var_mean_correction_names::redispatch(dispatchKeySet, self, dim, correction, keepdim); + } + + // aten::view_as(Tensor(a) self, Tensor other) -> Tensor(a) + inline at::Tensor view_as(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::view_as::redispatch(dispatchKeySet, self, other); + } + + // aten::where.self(Tensor condition, Tensor self, Tensor other) -> Tensor + inline at::Tensor where(c10::DispatchKeySet dispatchKeySet, const at::Tensor & condition, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::where_self::redispatch(dispatchKeySet, condition, self, other); + } + + // aten::where.self_out(Tensor condition, Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & where_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & condition, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::where_self_out::redispatch(dispatchKeySet, condition, self, other, out); + } + + // aten::where.self_out(Tensor condition, Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & where_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & condition, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::where_self_out::redispatch(dispatchKeySet, condition, self, other, out); + } + + // aten::where.ScalarSelf(Tensor condition, Scalar self, Tensor other) -> Tensor + inline at::Tensor where(c10::DispatchKeySet dispatchKeySet, const at::Tensor & condition, const at::Scalar & self, const at::Tensor & other) { + return at::_ops::where_ScalarSelf::redispatch(dispatchKeySet, condition, self, other); + } + + // aten::where.ScalarOther(Tensor condition, Tensor self, Scalar other) -> Tensor + inline at::Tensor where(c10::DispatchKeySet dispatchKeySet, const at::Tensor & condition, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::where_ScalarOther::redispatch(dispatchKeySet, condition, self, other); + } + + // aten::where.Scalar(Tensor condition, Scalar self, Scalar other) -> Tensor + inline at::Tensor where(c10::DispatchKeySet dispatchKeySet, const at::Tensor & condition, const at::Scalar & self, const at::Scalar & other) { + return at::_ops::where_Scalar::redispatch(dispatchKeySet, condition, self, other); + } + + // aten::where(Tensor condition) -> Tensor[] + inline ::std::vector where(c10::DispatchKeySet dispatchKeySet, const at::Tensor & condition) { + return at::_ops::where::redispatch(dispatchKeySet, condition); + } + + // aten::norm_except_dim(Tensor v, int pow=2, int dim=0) -> Tensor + inline at::Tensor norm_except_dim(c10::DispatchKeySet dispatchKeySet, const at::Tensor & v, int64_t pow=2, int64_t dim=0) { + return at::_ops::norm_except_dim::redispatch(dispatchKeySet, v, pow, dim); + } + + // aten::_weight_norm(Tensor v, Tensor g, int dim=0) -> Tensor + inline at::Tensor _weight_norm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & v, const at::Tensor & g, int64_t dim=0) { + return at::_ops::_weight_norm::redispatch(dispatchKeySet, v, g, dim); + } + + 
// aten::_weight_norm_interface(Tensor v, Tensor g, int dim=0) -> (Tensor, Tensor) + inline ::std::tuple _weight_norm_interface(c10::DispatchKeySet dispatchKeySet, const at::Tensor & v, const at::Tensor & g, int64_t dim=0) { + return at::_ops::_weight_norm_interface::redispatch(dispatchKeySet, v, g, dim); + } + + // aten::_weight_norm_interface_backward(Tensor grad_w, Tensor saved_v, Tensor saved_g, Tensor saved_norms, int dim) -> (Tensor, Tensor) + inline ::std::tuple _weight_norm_interface_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_w, const at::Tensor & saved_v, const at::Tensor & saved_g, const at::Tensor & saved_norms, int64_t dim) { + return at::_ops::_weight_norm_interface_backward::redispatch(dispatchKeySet, grad_w, saved_v, saved_g, saved_norms, dim); + } + + // aten::_weight_norm_differentiable_backward(Tensor grad_w, Tensor saved_v, Tensor saved_g, Tensor saved_norms, int dim) -> (Tensor, Tensor) + inline ::std::tuple _weight_norm_differentiable_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_w, const at::Tensor & saved_v, const at::Tensor & saved_g, const at::Tensor & saved_norms, int64_t dim) { + return at::_ops::_weight_norm_differentiable_backward::redispatch(dispatchKeySet, grad_w, saved_v, saved_g, saved_norms, dim); + } + + // aten::zeros.names(int[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor zeros(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, ::std::optional names, at::TensorOptions options={}) { + return at::_ops::zeros_names::redispatch(dispatchKeySet, size, names, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::zeros.names(int[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor + inline at::Tensor zeros(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, ::std::optional names, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::zeros_names::redispatch(dispatchKeySet, size, names, dtype, layout, device, pin_memory); + } + + // aten::_efficientzerotensor(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor _efficientzerotensor(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, at::TensorOptions options={}) { + return at::_ops::_efficientzerotensor::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::_efficientzerotensor(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor _efficientzerotensor(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::_efficientzerotensor::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), dtype, layout, device, pin_memory); + } + + // aten::_efficientzerotensor(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor _efficientzerotensor_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, at::TensorOptions options={}) { + return at::_ops::_efficientzerotensor::redispatch(dispatchKeySet, size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::_efficientzerotensor(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor + inline at::Tensor _efficientzerotensor_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::_efficientzerotensor::redispatch(dispatchKeySet, size, dtype, layout, device, pin_memory); + } + + // aten::zeros(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor zeros(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, at::TensorOptions options={}) { + return at::_ops::zeros::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::zeros(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor zeros(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::zeros::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), dtype, layout, device, pin_memory); + } + + // aten::zeros(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor zeros_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, at::TensorOptions options={}) { + return at::_ops::zeros::redispatch(dispatchKeySet, size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::zeros(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor + inline at::Tensor zeros_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::zeros::redispatch(dispatchKeySet, size, dtype, layout, device, pin_memory); + } + + // aten::zeros.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & zeros_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::IntArrayRef size) { + return at::_ops::zeros_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), out); + } + + // aten::zeros.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & zeros_outf(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, at::Tensor & out) { + return at::_ops::zeros_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), out); + } + + // aten::zeros.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & zeros_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, c10::SymIntArrayRef size) { + return at::_ops::zeros_out::redispatch(dispatchKeySet, size, out); + } + + // aten::zeros.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & zeros_symint_outf(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, at::Tensor & out) { + return at::_ops::zeros_out::redispatch(dispatchKeySet, size, out); + } + + // aten::zeros_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? 
memory_format=None) -> Tensor + inline at::Tensor zeros_like(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::TensorOptions options={}, ::std::optional memory_format=::std::nullopt) { + return at::_ops::zeros_like::redispatch(dispatchKeySet, self, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format)); + } + + // aten::zeros_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + inline at::Tensor zeros_like(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format) { + return at::_ops::zeros_like::redispatch(dispatchKeySet, self, dtype, layout, device, pin_memory, memory_format); + } + + // aten::_standard_gamma_grad(Tensor self, Tensor output) -> Tensor + inline at::Tensor _standard_gamma_grad(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & output) { + return at::_ops::_standard_gamma_grad::redispatch(dispatchKeySet, self, output); + } + + // aten::_standard_gamma(Tensor self, Generator? generator=None) -> Tensor + inline at::Tensor _standard_gamma(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional generator=::std::nullopt) { + return at::_ops::_standard_gamma::redispatch(dispatchKeySet, self, generator); + } + + // aten::_dirichlet_grad(Tensor x, Tensor alpha, Tensor total) -> Tensor + inline at::Tensor _dirichlet_grad(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Tensor & alpha, const at::Tensor & total) { + return at::_ops::_dirichlet_grad::redispatch(dispatchKeySet, x, alpha, total); + } + + // aten::_sample_dirichlet(Tensor self, Generator? 
generator=None) -> Tensor + inline at::Tensor _sample_dirichlet(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional generator=::std::nullopt) { + return at::_ops::_sample_dirichlet::redispatch(dispatchKeySet, self, generator); + } + + // aten::poisson(Tensor self, Generator? generator=None) -> Tensor + inline at::Tensor poisson(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional generator=::std::nullopt) { + return at::_ops::poisson::redispatch(dispatchKeySet, self, generator); + } + + // aten::binomial(Tensor count, Tensor prob, Generator? generator=None) -> Tensor + inline at::Tensor binomial(c10::DispatchKeySet dispatchKeySet, const at::Tensor & count, const at::Tensor & prob, ::std::optional generator=::std::nullopt) { + return at::_ops::binomial::redispatch(dispatchKeySet, count, prob, generator); + } + + // aten::native_norm(Tensor self, Scalar p=2) -> Tensor + inline at::Tensor native_norm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & p=2) { + return at::_ops::native_norm::redispatch(dispatchKeySet, self, p); + } + + // aten::native_norm.ScalarOpt_dim_dtype(Tensor self, Scalar? p, int[1] dim, bool keepdim, ScalarType? dtype) -> Tensor + inline at::Tensor native_norm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const ::std::optional & p, at::IntArrayRef dim, bool keepdim, ::std::optional dtype) { + return at::_ops::native_norm_ScalarOpt_dim_dtype::redispatch(dispatchKeySet, self, p, dim, keepdim, dtype); + } + + // aten::_batch_norm_with_update(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) 
running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor, Tensor) + inline ::std::tuple _batch_norm_with_update(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, at::Tensor & running_mean, at::Tensor & running_var, double momentum, double eps) { + return at::_ops::_batch_norm_with_update::redispatch(dispatchKeySet, input, weight, bias, running_mean, running_var, momentum, eps); + } + + // aten::_batch_norm_with_update.out(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, float momentum, float eps, *, Tensor(d!) out, Tensor(e!) save_mean, Tensor(f!) save_invstd, Tensor(g!) reserve) -> (Tensor(d!), Tensor(e!), Tensor(f!), Tensor(g!)) + inline ::std::tuple _batch_norm_with_update_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::Tensor & save_mean, at::Tensor & save_invstd, at::Tensor & reserve, const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, at::Tensor & running_mean, at::Tensor & running_var, double momentum, double eps) { + return at::_ops::_batch_norm_with_update_out::redispatch(dispatchKeySet, input, weight, bias, running_mean, running_var, momentum, eps, out, save_mean, save_invstd, reserve); + } + + // aten::_batch_norm_with_update.out(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, float momentum, float eps, *, Tensor(d!) out, Tensor(e!) save_mean, Tensor(f!) save_invstd, Tensor(g!) 
reserve) -> (Tensor(d!), Tensor(e!), Tensor(f!), Tensor(g!)) + inline ::std::tuple _batch_norm_with_update_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, at::Tensor & running_mean, at::Tensor & running_var, double momentum, double eps, at::Tensor & out, at::Tensor & save_mean, at::Tensor & save_invstd, at::Tensor & reserve) { + return at::_ops::_batch_norm_with_update_out::redispatch(dispatchKeySet, input, weight, bias, running_mean, running_var, momentum, eps, out, save_mean, save_invstd, reserve); + } + + // aten::_batch_norm_no_update(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor, Tensor) + inline ::std::tuple _batch_norm_no_update(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const ::std::optional & running_mean, const ::std::optional & running_var, double momentum, double eps) { + return at::_ops::_batch_norm_no_update::redispatch(dispatchKeySet, input, weight, bias, running_mean, running_var, momentum, eps); + } + + // aten::batch_norm_backward(Tensor grad_out, Tensor input, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? 
save_var, bool update, float eps, bool[3] output_mask, Tensor reserve) -> (Tensor, Tensor, Tensor) + inline ::std::tuple batch_norm_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & running_mean, const ::std::optional & running_var, const ::std::optional & save_mean, const ::std::optional & save_var, bool update, double eps, ::std::array output_mask, const at::Tensor & reserve) { + return at::_ops::batch_norm_backward::redispatch(dispatchKeySet, grad_out, input, weight, running_mean, running_var, save_mean, save_var, update, eps, output_mask, reserve); + } + + // aten::_sparse_sum(Tensor self) -> Tensor + inline at::Tensor _sparse_sum(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::_sparse_sum::redispatch(dispatchKeySet, self); + } + + // aten::_sparse_sum.dtype(Tensor self, *, ScalarType dtype) -> Tensor + inline at::Tensor _sparse_sum(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::ScalarType dtype) { + return at::_ops::_sparse_sum_dtype::redispatch(dispatchKeySet, self, dtype); + } + + // aten::_sparse_sum.dim(Tensor self, int[1] dim) -> Tensor + inline at::Tensor _sparse_sum(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim) { + return at::_ops::_sparse_sum_dim::redispatch(dispatchKeySet, self, dim); + } + + // aten::_sparse_sum.dim_dtype(Tensor self, int[1] dim, *, ScalarType dtype) -> Tensor + inline at::Tensor _sparse_sum(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim, at::ScalarType dtype) { + return at::_ops::_sparse_sum_dim_dtype::redispatch(dispatchKeySet, self, dim, dtype); + } + + // aten::_sparse_sum_backward(Tensor grad, Tensor self, int[] dim) -> Tensor + inline at::Tensor _sparse_sum_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & self, at::IntArrayRef dim) { + return 
at::_ops::_sparse_sum_backward::redispatch(dispatchKeySet, grad, self, dim); + } + + // aten::_sparse_csr_sum.dim_dtype(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor + inline at::Tensor _sparse_csr_sum(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim, bool keepdim=false, ::std::optional dtype=::std::nullopt) { + return at::_ops::_sparse_csr_sum_dim_dtype::redispatch(dispatchKeySet, self, dim, keepdim, dtype); + } + + // aten::_sparse_csr_prod.dim_dtype(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor + inline at::Tensor _sparse_csr_prod(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim, bool keepdim=false, ::std::optional dtype=::std::nullopt) { + return at::_ops::_sparse_csr_prod_dim_dtype::redispatch(dispatchKeySet, self, dim, keepdim, dtype); + } + + // aten::_sparse_softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor + inline at::Tensor _sparse_softmax(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, ::std::optional dtype=::std::nullopt) { + return at::_ops::_sparse_softmax_int::redispatch(dispatchKeySet, self, dim, dtype); + } + + // aten::_sparse_softmax.Dimname(Tensor self, Dimname dim, *, ScalarType? 
dtype=None) -> Tensor + inline at::Tensor _sparse_softmax(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, ::std::optional dtype=::std::nullopt) { + return at::_ops::_sparse_softmax_Dimname::redispatch(dispatchKeySet, self, dim, dtype); + } + + // aten::_sparse_softmax(Tensor self, int dim, bool half_to_float) -> Tensor + inline at::Tensor _sparse_softmax(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool half_to_float) { + return at::_ops::_sparse_softmax::redispatch(dispatchKeySet, self, dim, half_to_float); + } + + // aten::_sparse_softmax_backward_data(Tensor grad_output, Tensor output, int dim, Tensor self) -> Tensor + inline at::Tensor _sparse_softmax_backward_data(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & output, int64_t dim, const at::Tensor & self) { + return at::_ops::_sparse_softmax_backward_data::redispatch(dispatchKeySet, grad_output, output, dim, self); + } + + // aten::_sparse_log_softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor + inline at::Tensor _sparse_log_softmax(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, ::std::optional dtype=::std::nullopt) { + return at::_ops::_sparse_log_softmax_int::redispatch(dispatchKeySet, self, dim, dtype); + } + + // aten::_sparse_log_softmax.Dimname(Tensor self, Dimname dim, *, ScalarType? 
dtype=None) -> Tensor + inline at::Tensor _sparse_log_softmax(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, ::std::optional dtype=::std::nullopt) { + return at::_ops::_sparse_log_softmax_Dimname::redispatch(dispatchKeySet, self, dim, dtype); + } + + // aten::_sparse_log_softmax(Tensor self, int dim, bool half_to_float) -> Tensor + inline at::Tensor _sparse_log_softmax(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool half_to_float) { + return at::_ops::_sparse_log_softmax::redispatch(dispatchKeySet, self, dim, half_to_float); + } + + // aten::_sparse_log_softmax_backward_data(Tensor grad_output, Tensor output, int dim, Tensor self) -> Tensor + inline at::Tensor _sparse_log_softmax_backward_data(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & output, int64_t dim, const at::Tensor & self) { + return at::_ops::_sparse_log_softmax_backward_data::redispatch(dispatchKeySet, grad_output, output, dim, self); + } + + // aten::_spdiags(Tensor diagonals, Tensor offsets, int[] shape, Layout? layout=None) -> Tensor + inline at::Tensor _spdiags(c10::DispatchKeySet dispatchKeySet, const at::Tensor & diagonals, const at::Tensor & offsets, at::IntArrayRef shape, ::std::optional layout=::std::nullopt) { + return at::_ops::_spdiags::redispatch(dispatchKeySet, diagonals, offsets, shape, layout); + } + + // aten::norm.ScalarOpt_dtype(Tensor self, Scalar? 
p, *, ScalarType dtype) -> Tensor + inline at::Tensor norm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const ::std::optional & p, at::ScalarType dtype) { + return at::_ops::norm_ScalarOpt_dtype::redispatch(dispatchKeySet, self, p, dtype); + } + + // aten::norm.Scalar(Tensor self, Scalar p=2) -> Tensor + inline at::Tensor norm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & p=2) { + return at::_ops::norm_Scalar::redispatch(dispatchKeySet, self, p); + } + + // aten::norm.ScalarOpt_dim_dtype(Tensor self, Scalar? p, int[1] dim, bool keepdim, *, ScalarType dtype) -> Tensor + inline at::Tensor norm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const ::std::optional & p, at::IntArrayRef dim, bool keepdim, at::ScalarType dtype) { + return at::_ops::norm_ScalarOpt_dim_dtype::redispatch(dispatchKeySet, self, p, dim, keepdim, dtype); + } + + // aten::norm.ScalarOpt_dim(Tensor self, Scalar? p, int[1] dim, bool keepdim=False) -> Tensor + inline at::Tensor norm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const ::std::optional & p, at::IntArrayRef dim, bool keepdim=false) { + return at::_ops::norm_ScalarOpt_dim::redispatch(dispatchKeySet, self, p, dim, keepdim); + } + + // aten::norm.dtype_out(Tensor self, Scalar? p, int[1] dim, bool keepdim, *, ScalarType dtype, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & norm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const ::std::optional & p, at::IntArrayRef dim, bool keepdim, at::ScalarType dtype) { + return at::_ops::norm_dtype_out::redispatch(dispatchKeySet, self, p, dim, keepdim, dtype, out); + } + + // aten::norm.dtype_out(Tensor self, Scalar? p, int[1] dim, bool keepdim, *, ScalarType dtype, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & norm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const ::std::optional & p, at::IntArrayRef dim, bool keepdim, at::ScalarType dtype, at::Tensor & out) { + return at::_ops::norm_dtype_out::redispatch(dispatchKeySet, self, p, dim, keepdim, dtype, out); + } + + // aten::norm.out(Tensor self, Scalar? p, int[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & norm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const ::std::optional & p, at::IntArrayRef dim, bool keepdim=false) { + return at::_ops::norm_out::redispatch(dispatchKeySet, self, p, dim, keepdim, out); + } + + // aten::norm.out(Tensor self, Scalar? p, int[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & norm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const ::std::optional & p, at::IntArrayRef dim, bool keepdim, at::Tensor & out) { + return at::_ops::norm_out::redispatch(dispatchKeySet, self, p, dim, keepdim, out); + } + + // aten::norm.names_ScalarOpt_dim_dtype(Tensor self, Scalar? p, Dimname[1] dim, bool keepdim, *, ScalarType dtype) -> Tensor + inline at::Tensor norm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const ::std::optional & p, at::DimnameList dim, bool keepdim, at::ScalarType dtype) { + return at::_ops::norm_names_ScalarOpt_dim_dtype::redispatch(dispatchKeySet, self, p, dim, keepdim, dtype); + } + + // aten::norm.names_ScalarOpt_dim(Tensor self, Scalar? p, Dimname[1] dim, bool keepdim=False) -> Tensor + inline at::Tensor norm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const ::std::optional & p, at::DimnameList dim, bool keepdim=false) { + return at::_ops::norm_names_ScalarOpt_dim::redispatch(dispatchKeySet, self, p, dim, keepdim); + } + + // aten::norm.names_dtype_out(Tensor self, Scalar? p, Dimname[1] dim, bool keepdim, *, ScalarType dtype, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & norm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const ::std::optional & p, at::DimnameList dim, bool keepdim, at::ScalarType dtype) { + return at::_ops::norm_names_dtype_out::redispatch(dispatchKeySet, self, p, dim, keepdim, dtype, out); + } + + // aten::norm.names_dtype_out(Tensor self, Scalar? p, Dimname[1] dim, bool keepdim, *, ScalarType dtype, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & norm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const ::std::optional & p, at::DimnameList dim, bool keepdim, at::ScalarType dtype, at::Tensor & out) { + return at::_ops::norm_names_dtype_out::redispatch(dispatchKeySet, self, p, dim, keepdim, dtype, out); + } + + // aten::norm.names_out(Tensor self, Scalar? p, Dimname[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & norm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const ::std::optional & p, at::DimnameList dim, bool keepdim=false) { + return at::_ops::norm_names_out::redispatch(dispatchKeySet, self, p, dim, keepdim, out); + } + + // aten::norm.names_out(Tensor self, Scalar? p, Dimname[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & norm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const ::std::optional & p, at::DimnameList dim, bool keepdim, at::Tensor & out) { + return at::_ops::norm_names_out::redispatch(dispatchKeySet, self, p, dim, keepdim, out); + } + + // aten::frexp.Tensor(Tensor self) -> (Tensor mantissa, Tensor exponent) + inline ::std::tuple frexp(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::frexp_Tensor::redispatch(dispatchKeySet, self); + } + + // aten::frexp.Tensor_out(Tensor self, *, Tensor(a!) mantissa, Tensor(b!) exponent) -> (Tensor(a!) mantissa, Tensor(b!) 
exponent) + inline ::std::tuple frexp_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & mantissa, at::Tensor & exponent, const at::Tensor & self) { + return at::_ops::frexp_Tensor_out::redispatch(dispatchKeySet, self, mantissa, exponent); + } + + // aten::frexp.Tensor_out(Tensor self, *, Tensor(a!) mantissa, Tensor(b!) exponent) -> (Tensor(a!) mantissa, Tensor(b!) exponent) + inline ::std::tuple frexp_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & mantissa, at::Tensor & exponent) { + return at::_ops::frexp_Tensor_out::redispatch(dispatchKeySet, self, mantissa, exponent); + } + + // aten::frobenius_norm.dim(Tensor self, int[1] dim, bool keepdim=False) -> Tensor + inline at::Tensor frobenius_norm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim, bool keepdim=false) { + return at::_ops::frobenius_norm_dim::redispatch(dispatchKeySet, self, dim, keepdim); + } + + // aten::frobenius_norm.out(Tensor self, int[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & frobenius_norm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef dim, bool keepdim=false) { + return at::_ops::frobenius_norm_out::redispatch(dispatchKeySet, self, dim, keepdim, out); + } + + // aten::frobenius_norm.out(Tensor self, int[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & frobenius_norm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim, bool keepdim, at::Tensor & out) { + return at::_ops::frobenius_norm_out::redispatch(dispatchKeySet, self, dim, keepdim, out); + } + + // aten::nuclear_norm(Tensor self, bool keepdim=False) -> Tensor + inline at::Tensor nuclear_norm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool keepdim=false) { + return at::_ops::nuclear_norm::redispatch(dispatchKeySet, self, keepdim); + } + + // aten::nuclear_norm.out(Tensor self, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & nuclear_norm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, bool keepdim=false) { + return at::_ops::nuclear_norm_out::redispatch(dispatchKeySet, self, keepdim, out); + } + + // aten::nuclear_norm.out(Tensor self, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & nuclear_norm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool keepdim, at::Tensor & out) { + return at::_ops::nuclear_norm_out::redispatch(dispatchKeySet, self, keepdim, out); + } + + // aten::nuclear_norm.dim(Tensor self, int[2] dim, bool keepdim=False) -> Tensor + inline at::Tensor nuclear_norm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim, bool keepdim=false) { + return at::_ops::nuclear_norm_dim::redispatch(dispatchKeySet, self, dim, keepdim); + } + + // aten::nuclear_norm.dim_out(Tensor self, int[2] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & nuclear_norm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef dim, bool keepdim=false) { + return at::_ops::nuclear_norm_dim_out::redispatch(dispatchKeySet, self, dim, keepdim, out); + } + + // aten::nuclear_norm.dim_out(Tensor self, int[2] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & nuclear_norm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim, bool keepdim, at::Tensor & out) { + return at::_ops::nuclear_norm_dim_out::redispatch(dispatchKeySet, self, dim, keepdim, out); + } + + // aten::clone(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor + inline at::Tensor clone(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional memory_format=::std::nullopt) { + return at::_ops::clone::redispatch(dispatchKeySet, self, memory_format); + } + + // aten::positive(Tensor(a) self) -> Tensor(a) + inline at::Tensor positive(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::positive::redispatch(dispatchKeySet, self); + } + + // aten::resize_as_(Tensor(a!) self, Tensor the_template, *, MemoryFormat? memory_format=None) -> Tensor(a!) + inline const at::Tensor & resize_as_(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & the_template, ::std::optional memory_format=::std::nullopt) { + return at::_ops::resize_as_::redispatch(dispatchKeySet, self, the_template, memory_format); + } + + // aten::resize_as_sparse_(Tensor(a!) self, Tensor the_template) -> Tensor(a!) + inline const at::Tensor & resize_as_sparse_(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & the_template) { + return at::_ops::resize_as_sparse_::redispatch(dispatchKeySet, self, the_template); + } + + // aten::zero_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & zero_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::zero_::redispatch(dispatchKeySet, self); + } + + // aten::sub.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & sub_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha=1) { + return at::_ops::sub_out::redispatch(dispatchKeySet, self, other, alpha, out); + } + + // aten::sub.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & sub_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha, at::Tensor & out) { + return at::_ops::sub_out::redispatch(dispatchKeySet, self, other, alpha, out); + } + + // aten::sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor + inline at::Tensor sub(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha=1) { + return at::_ops::sub_Tensor::redispatch(dispatchKeySet, self, other, alpha); + } + + // aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!) + inline at::Tensor & sub_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha=1) { + return at::_ops::sub__Tensor::redispatch(dispatchKeySet, self, other, alpha); + } + + // aten::sub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor + inline at::Tensor sub(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha=1) { + return at::_ops::sub_Scalar::redispatch(dispatchKeySet, self, other, alpha); + } + + // aten::sub_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!) + inline at::Tensor & sub_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha=1) { + return at::_ops::sub__Scalar::redispatch(dispatchKeySet, self, other, alpha); + } + + // aten::subtract.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & subtract_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha=1) { + return at::_ops::subtract_out::redispatch(dispatchKeySet, self, other, alpha, out); + } + + // aten::subtract.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & subtract_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha, at::Tensor & out) { + return at::_ops::subtract_out::redispatch(dispatchKeySet, self, other, alpha, out); + } + + // aten::subtract.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor + inline at::Tensor subtract(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha=1) { + return at::_ops::subtract_Tensor::redispatch(dispatchKeySet, self, other, alpha); + } + + // aten::subtract_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!) + inline at::Tensor & subtract_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha=1) { + return at::_ops::subtract__Tensor::redispatch(dispatchKeySet, self, other, alpha); + } + + // aten::subtract.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor + inline at::Tensor subtract(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha=1) { + return at::_ops::subtract_Scalar::redispatch(dispatchKeySet, self, other, alpha); + } + + // aten::subtract_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!) 
+ inline at::Tensor & subtract_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha=1) { + return at::_ops::subtract__Scalar::redispatch(dispatchKeySet, self, other, alpha); + } + + // aten::rsub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor + inline at::Tensor rsub(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha=1) { + return at::_ops::rsub_Tensor::redispatch(dispatchKeySet, self, other, alpha); + } + + // aten::heaviside.out(Tensor self, Tensor values, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & heaviside_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & values) { + return at::_ops::heaviside_out::redispatch(dispatchKeySet, self, values, out); + } + + // aten::heaviside.out(Tensor self, Tensor values, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & heaviside_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & values, at::Tensor & out) { + return at::_ops::heaviside_out::redispatch(dispatchKeySet, self, values, out); + } + + // aten::heaviside(Tensor self, Tensor values) -> Tensor + inline at::Tensor heaviside(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & values) { + return at::_ops::heaviside::redispatch(dispatchKeySet, self, values); + } + + // aten::heaviside_(Tensor(a!) self, Tensor values) -> Tensor(a!) 
+ inline at::Tensor & heaviside_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & values) { + return at::_ops::heaviside_::redispatch(dispatchKeySet, self, values); + } + + // aten::rsub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor + inline at::Tensor rsub(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha=1) { + return at::_ops::rsub_Scalar::redispatch(dispatchKeySet, self, other, alpha); + } + + // aten::_sparse_addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor + inline at::Tensor _sparse_addmm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta=1, const at::Scalar & alpha=1) { + return at::_ops::_sparse_addmm::redispatch(dispatchKeySet, self, mat1, mat2, beta, alpha); + } + + // aten::sparse_sampled_addmm.out(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & sparse_sampled_addmm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta=1, const at::Scalar & alpha=1) { + return at::_ops::sparse_sampled_addmm_out::redispatch(dispatchKeySet, self, mat1, mat2, beta, alpha, out); + } + + // aten::sparse_sampled_addmm.out(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & sparse_sampled_addmm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out) { + return at::_ops::sparse_sampled_addmm_out::redispatch(dispatchKeySet, self, mat1, mat2, beta, alpha, out); + } + + // aten::sparse_sampled_addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor + inline at::Tensor sparse_sampled_addmm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta=1, const at::Scalar & alpha=1) { + return at::_ops::sparse_sampled_addmm::redispatch(dispatchKeySet, self, mat1, mat2, beta, alpha); + } + + // aten::_sparse_mm_reduce_impl(Tensor self, Tensor other, str reduce) -> (Tensor, Tensor) + inline ::std::tuple _sparse_mm_reduce_impl(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, c10::string_view reduce) { + return at::_ops::_sparse_mm_reduce_impl::redispatch(dispatchKeySet, self, other, reduce); + } + + // aten::_sparse_mm_reduce_impl_backward(Tensor self, Tensor grad_out, Tensor weight, str reduce, Tensor arg_out, bool[2] output_mask) -> (Tensor, Tensor) + inline ::std::tuple _sparse_mm_reduce_impl_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & grad_out, const at::Tensor & weight, c10::string_view reduce, const at::Tensor & arg_out, ::std::array output_mask) { + return at::_ops::_sparse_mm_reduce_impl_backward::redispatch(dispatchKeySet, self, grad_out, weight, reduce, arg_out, output_mask); + } + + // aten::addmm.out(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & addmm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta=1, const at::Scalar & alpha=1) { + return at::_ops::addmm_out::redispatch(dispatchKeySet, self, mat1, mat2, beta, alpha, out); + } + + // aten::addmm.out(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & addmm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out) { + return at::_ops::addmm_out::redispatch(dispatchKeySet, self, mat1, mat2, beta, alpha, out); + } + + // aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor + inline at::Tensor addmm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta=1, const at::Scalar & alpha=1) { + return at::_ops::addmm::redispatch(dispatchKeySet, self, mat1, mat2, beta, alpha); + } + + // aten::addmm.dtype(Tensor self, Tensor mat1, Tensor mat2, ScalarType out_dtype, *, Scalar beta=1, Scalar alpha=1) -> Tensor + inline at::Tensor addmm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, at::ScalarType out_dtype, const at::Scalar & beta=1, const at::Scalar & alpha=1) { + return at::_ops::addmm_dtype::redispatch(dispatchKeySet, self, mat1, mat2, out_dtype, beta, alpha); + } + + // aten::addmm.dtype_out(Tensor self, Tensor mat1, Tensor mat2, ScalarType out_dtype, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & addmm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, at::ScalarType out_dtype, const at::Scalar & beta=1, const at::Scalar & alpha=1) { + return at::_ops::addmm_dtype_out::redispatch(dispatchKeySet, self, mat1, mat2, out_dtype, beta, alpha, out); + } + + // aten::addmm.dtype_out(Tensor self, Tensor mat1, Tensor mat2, ScalarType out_dtype, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & addmm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, at::ScalarType out_dtype, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out) { + return at::_ops::addmm_dtype_out::redispatch(dispatchKeySet, self, mat1, mat2, out_dtype, beta, alpha, out); + } + + // aten::addmm_(Tensor(a!) self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!) + inline at::Tensor & addmm_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta=1, const at::Scalar & alpha=1) { + return at::_ops::addmm_::redispatch(dispatchKeySet, self, mat1, mat2, beta, alpha); + } + + // aten::_addmm_activation.out(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, bool use_gelu=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _addmm_activation_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta=1, const at::Scalar & alpha=1, bool use_gelu=false) { + return at::_ops::_addmm_activation_out::redispatch(dispatchKeySet, self, mat1, mat2, beta, alpha, use_gelu, out); + } + + // aten::_addmm_activation.out(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, bool use_gelu=False, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & _addmm_activation_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta, const at::Scalar & alpha, bool use_gelu, at::Tensor & out) { + return at::_ops::_addmm_activation_out::redispatch(dispatchKeySet, self, mat1, mat2, beta, alpha, use_gelu, out); + } + + // aten::_addmm_activation(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, bool use_gelu=False) -> Tensor + inline at::Tensor _addmm_activation(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta=1, const at::Scalar & alpha=1, bool use_gelu=false) { + return at::_ops::_addmm_activation::redispatch(dispatchKeySet, self, mat1, mat2, beta, alpha, use_gelu); + } + + // aten::_scaled_mm(Tensor self, Tensor mat2, Tensor scale_a, Tensor scale_b, Tensor? bias=None, Tensor? scale_result=None, ScalarType? out_dtype=None, bool use_fast_accum=False) -> Tensor + inline at::Tensor _scaled_mm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat2, const at::Tensor & scale_a, const at::Tensor & scale_b, const ::std::optional & bias={}, const ::std::optional & scale_result={}, ::std::optional out_dtype=::std::nullopt, bool use_fast_accum=false) { + return at::_ops::_scaled_mm::redispatch(dispatchKeySet, self, mat2, scale_a, scale_b, bias, scale_result, out_dtype, use_fast_accum); + } + + // aten::_scaled_mm.out(Tensor self, Tensor mat2, Tensor scale_a, Tensor scale_b, Tensor? bias=None, Tensor? scale_result=None, ScalarType? out_dtype=None, bool use_fast_accum=False, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & _scaled_mm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & mat2, const at::Tensor & scale_a, const at::Tensor & scale_b, const ::std::optional & bias={}, const ::std::optional & scale_result={}, ::std::optional out_dtype=::std::nullopt, bool use_fast_accum=false) { + return at::_ops::_scaled_mm_out::redispatch(dispatchKeySet, self, mat2, scale_a, scale_b, bias, scale_result, out_dtype, use_fast_accum, out); + } + + // aten::_scaled_mm.out(Tensor self, Tensor mat2, Tensor scale_a, Tensor scale_b, Tensor? bias=None, Tensor? scale_result=None, ScalarType? out_dtype=None, bool use_fast_accum=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _scaled_mm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat2, const at::Tensor & scale_a, const at::Tensor & scale_b, const ::std::optional & bias, const ::std::optional & scale_result, ::std::optional out_dtype, bool use_fast_accum, at::Tensor & out) { + return at::_ops::_scaled_mm_out::redispatch(dispatchKeySet, self, mat2, scale_a, scale_b, bias, scale_result, out_dtype, use_fast_accum, out); + } + + // aten::_scaled_mm_v2(Tensor self, Tensor mat2, Tensor[] scale_a, int[] recipe_a, int[] swizzle_a, Tensor[] scale_b, int[] recipe_b, int[] swizzle_b, Tensor? bias, ScalarType? 
out_dtype, int[] contraction_dim=[], bool use_fast_accum=False) -> Tensor + inline at::Tensor _scaled_mm_v2(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat2, at::TensorList scale_a, at::IntArrayRef recipe_a, at::IntArrayRef swizzle_a, at::TensorList scale_b, at::IntArrayRef recipe_b, at::IntArrayRef swizzle_b, const ::std::optional & bias, ::std::optional out_dtype, at::IntArrayRef contraction_dim={}, bool use_fast_accum=false) { + return at::_ops::_scaled_mm_v2::redispatch(dispatchKeySet, self, mat2, scale_a, recipe_a, swizzle_a, scale_b, recipe_b, swizzle_b, bias, out_dtype, contraction_dim, use_fast_accum); + } + + // aten::_scaled_mm_v2.out(Tensor self, Tensor mat2, Tensor[] scale_a, int[] recipe_a, int[] swizzle_a, Tensor[] scale_b, int[] recipe_b, int[] swizzle_b, Tensor? bias, ScalarType? out_dtype, int[] contraction_dim=[], bool use_fast_accum=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _scaled_mm_v2_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & mat2, at::TensorList scale_a, at::IntArrayRef recipe_a, at::IntArrayRef swizzle_a, at::TensorList scale_b, at::IntArrayRef recipe_b, at::IntArrayRef swizzle_b, const ::std::optional & bias, ::std::optional out_dtype, at::IntArrayRef contraction_dim={}, bool use_fast_accum=false) { + return at::_ops::_scaled_mm_v2_out::redispatch(dispatchKeySet, self, mat2, scale_a, recipe_a, swizzle_a, scale_b, recipe_b, swizzle_b, bias, out_dtype, contraction_dim, use_fast_accum, out); + } + + // aten::_scaled_mm_v2.out(Tensor self, Tensor mat2, Tensor[] scale_a, int[] recipe_a, int[] swizzle_a, Tensor[] scale_b, int[] recipe_b, int[] swizzle_b, Tensor? bias, ScalarType? out_dtype, int[] contraction_dim=[], bool use_fast_accum=False, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & _scaled_mm_v2_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat2, at::TensorList scale_a, at::IntArrayRef recipe_a, at::IntArrayRef swizzle_a, at::TensorList scale_b, at::IntArrayRef recipe_b, at::IntArrayRef swizzle_b, const ::std::optional & bias, ::std::optional out_dtype, at::IntArrayRef contraction_dim, bool use_fast_accum, at::Tensor & out) { + return at::_ops::_scaled_mm_v2_out::redispatch(dispatchKeySet, self, mat2, scale_a, recipe_a, swizzle_a, scale_b, recipe_b, swizzle_b, bias, out_dtype, contraction_dim, use_fast_accum, out); + } + + // aten::_scaled_grouped_mm(Tensor self, Tensor mat2, Tensor scale_a, Tensor scale_b, Tensor? offs=None, Tensor? bias=None, Tensor? scale_result=None, ScalarType? out_dtype=None, bool use_fast_accum=False) -> Tensor + inline at::Tensor _scaled_grouped_mm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat2, const at::Tensor & scale_a, const at::Tensor & scale_b, const ::std::optional & offs={}, const ::std::optional & bias={}, const ::std::optional & scale_result={}, ::std::optional out_dtype=::std::nullopt, bool use_fast_accum=false) { + return at::_ops::_scaled_grouped_mm::redispatch(dispatchKeySet, self, mat2, scale_a, scale_b, offs, bias, scale_result, out_dtype, use_fast_accum); + } + + // aten::_scaled_grouped_mm_v2(Tensor self, Tensor mat2, Tensor[] scale_a, int[] recipe_a, int[] swizzle_a, Tensor[] scale_b, int[] recipe_b, int[] swizzle_b, Tensor? offs=None, Tensor? bias=None, ScalarType? 
out_dtype=None, int[] contraction_dim=[], bool use_fast_accum=False) -> Tensor + inline at::Tensor _scaled_grouped_mm_v2(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat2, at::TensorList scale_a, at::IntArrayRef recipe_a, at::IntArrayRef swizzle_a, at::TensorList scale_b, at::IntArrayRef recipe_b, at::IntArrayRef swizzle_b, const ::std::optional & offs={}, const ::std::optional & bias={}, ::std::optional out_dtype=::std::nullopt, at::IntArrayRef contraction_dim={}, bool use_fast_accum=false) { + return at::_ops::_scaled_grouped_mm_v2::redispatch(dispatchKeySet, self, mat2, scale_a, recipe_a, swizzle_a, scale_b, recipe_b, swizzle_b, offs, bias, out_dtype, contraction_dim, use_fast_accum); + } + + // aten::_grouped_mm(Tensor self, Tensor mat2, Tensor? offs=None, Tensor? bias=None, ScalarType? out_dtype=None) -> Tensor + inline at::Tensor _grouped_mm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat2, const ::std::optional & offs={}, const ::std::optional & bias={}, ::std::optional out_dtype=::std::nullopt) { + return at::_ops::_grouped_mm::redispatch(dispatchKeySet, self, mat2, offs, bias, out_dtype); + } + + // aten::_sparse_compressed_tensor_with_dims(int nnz, int dense_dim, int[] size, int[] blocksize, ScalarType index_dtype, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=False) -> Tensor + inline at::Tensor _sparse_compressed_tensor_with_dims(c10::DispatchKeySet dispatchKeySet, int64_t nnz, int64_t dense_dim, at::IntArrayRef size, at::IntArrayRef blocksize, at::ScalarType index_dtype, at::TensorOptions options) { + return at::_ops::_sparse_compressed_tensor_with_dims::redispatch(dispatchKeySet, nnz, dense_dim, size, blocksize, index_dtype, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::_sparse_compressed_tensor_with_dims(int nnz, int dense_dim, int[] size, int[] blocksize, ScalarType index_dtype, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor + inline at::Tensor _sparse_compressed_tensor_with_dims(c10::DispatchKeySet dispatchKeySet, int64_t nnz, int64_t dense_dim, at::IntArrayRef size, at::IntArrayRef blocksize, at::ScalarType index_dtype, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::_sparse_compressed_tensor_with_dims::redispatch(dispatchKeySet, nnz, dense_dim, size, blocksize, index_dtype, dtype, layout, device, pin_memory); + } + + // aten::sparse_compressed_tensor.comp_plain_value_size(Tensor compressed_indices, Tensor plain_indices, Tensor values, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=False) -> Tensor + inline at::Tensor sparse_compressed_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & compressed_indices, const at::Tensor & plain_indices, const at::Tensor & values, at::IntArrayRef size, at::TensorOptions options) { + return at::_ops::sparse_compressed_tensor_comp_plain_value_size::redispatch(dispatchKeySet, compressed_indices, plain_indices, values, c10::fromIntArrayRefSlow(size), c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::sparse_compressed_tensor.comp_plain_value_size(Tensor compressed_indices, Tensor plain_indices, Tensor values, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor + inline at::Tensor sparse_compressed_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & compressed_indices, const at::Tensor & plain_indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::sparse_compressed_tensor_comp_plain_value_size::redispatch(dispatchKeySet, compressed_indices, plain_indices, values, c10::fromIntArrayRefSlow(size), dtype, layout, device, pin_memory); + } + + // aten::sparse_compressed_tensor.comp_plain_value_size(Tensor compressed_indices, Tensor plain_indices, Tensor values, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=False) -> Tensor + inline at::Tensor sparse_compressed_tensor_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & compressed_indices, const at::Tensor & plain_indices, const at::Tensor & values, c10::SymIntArrayRef size, at::TensorOptions options) { + return at::_ops::sparse_compressed_tensor_comp_plain_value_size::redispatch(dispatchKeySet, compressed_indices, plain_indices, values, size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::sparse_compressed_tensor.comp_plain_value_size(Tensor compressed_indices, Tensor plain_indices, Tensor values, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor + inline at::Tensor sparse_compressed_tensor_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & compressed_indices, const at::Tensor & plain_indices, const at::Tensor & values, c10::SymIntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::sparse_compressed_tensor_comp_plain_value_size::redispatch(dispatchKeySet, compressed_indices, plain_indices, values, size, dtype, layout, device, pin_memory); + } + + // aten::sparse_csr_tensor.crow_col_value_size(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=False) -> Tensor + inline at::Tensor sparse_csr_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & crow_indices, const at::Tensor & col_indices, const at::Tensor & values, at::IntArrayRef size, at::TensorOptions options) { + return at::_ops::sparse_csr_tensor_crow_col_value_size::redispatch(dispatchKeySet, crow_indices, col_indices, values, size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::sparse_csr_tensor.crow_col_value_size(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor + inline at::Tensor sparse_csr_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & crow_indices, const at::Tensor & col_indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::sparse_csr_tensor_crow_col_value_size::redispatch(dispatchKeySet, crow_indices, col_indices, values, size, dtype, layout, device, pin_memory); + } + + // aten::sparse_csc_tensor.ccol_row_value_size(Tensor ccol_indices, Tensor row_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=False) -> Tensor + inline at::Tensor sparse_csc_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & ccol_indices, const at::Tensor & row_indices, const at::Tensor & values, at::IntArrayRef size, at::TensorOptions options) { + return at::_ops::sparse_csc_tensor_ccol_row_value_size::redispatch(dispatchKeySet, ccol_indices, row_indices, values, size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::sparse_csc_tensor.ccol_row_value_size(Tensor ccol_indices, Tensor row_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor + inline at::Tensor sparse_csc_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & ccol_indices, const at::Tensor & row_indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::sparse_csc_tensor_ccol_row_value_size::redispatch(dispatchKeySet, ccol_indices, row_indices, values, size, dtype, layout, device, pin_memory); + } + + // aten::sparse_bsr_tensor.crow_col_value_size(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=False) -> Tensor + inline at::Tensor sparse_bsr_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & crow_indices, const at::Tensor & col_indices, const at::Tensor & values, at::IntArrayRef size, at::TensorOptions options) { + return at::_ops::sparse_bsr_tensor_crow_col_value_size::redispatch(dispatchKeySet, crow_indices, col_indices, values, size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::sparse_bsr_tensor.crow_col_value_size(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor + inline at::Tensor sparse_bsr_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & crow_indices, const at::Tensor & col_indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::sparse_bsr_tensor_crow_col_value_size::redispatch(dispatchKeySet, crow_indices, col_indices, values, size, dtype, layout, device, pin_memory); + } + + // aten::sparse_bsc_tensor.ccol_row_value_size(Tensor ccol_indices, Tensor row_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=False) -> Tensor + inline at::Tensor sparse_bsc_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & ccol_indices, const at::Tensor & row_indices, const at::Tensor & values, at::IntArrayRef size, at::TensorOptions options) { + return at::_ops::sparse_bsc_tensor_ccol_row_value_size::redispatch(dispatchKeySet, ccol_indices, row_indices, values, size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::sparse_bsc_tensor.ccol_row_value_size(Tensor ccol_indices, Tensor row_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor + inline at::Tensor sparse_bsc_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & ccol_indices, const at::Tensor & row_indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::sparse_bsc_tensor_ccol_row_value_size::redispatch(dispatchKeySet, ccol_indices, row_indices, values, size, dtype, layout, device, pin_memory); + } + + // aten::sparse_compressed_tensor.comp_plain_value(Tensor compressed_indices, Tensor plain_indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=False) -> Tensor + inline at::Tensor sparse_compressed_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & compressed_indices, const at::Tensor & plain_indices, const at::Tensor & values, at::TensorOptions options) { + return at::_ops::sparse_compressed_tensor_comp_plain_value::redispatch(dispatchKeySet, compressed_indices, plain_indices, values, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::sparse_compressed_tensor.comp_plain_value(Tensor compressed_indices, Tensor plain_indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor + inline at::Tensor sparse_compressed_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & compressed_indices, const at::Tensor & plain_indices, const at::Tensor & values, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::sparse_compressed_tensor_comp_plain_value::redispatch(dispatchKeySet, compressed_indices, plain_indices, values, dtype, layout, device, pin_memory); + } + + // aten::sparse_csr_tensor.crow_col_value(Tensor crow_indices, Tensor col_indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor + inline at::Tensor sparse_csr_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & crow_indices, const at::Tensor & col_indices, const at::Tensor & values, at::TensorOptions options) { + return at::_ops::sparse_csr_tensor_crow_col_value::redispatch(dispatchKeySet, crow_indices, col_indices, values, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::sparse_csr_tensor.crow_col_value(Tensor crow_indices, Tensor col_indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=False) -> Tensor + inline at::Tensor sparse_csr_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & crow_indices, const at::Tensor & col_indices, const at::Tensor & values, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::sparse_csr_tensor_crow_col_value::redispatch(dispatchKeySet, crow_indices, col_indices, values, dtype, layout, device, pin_memory); + } + + // aten::sparse_csc_tensor.ccol_row_value(Tensor ccol_indices, Tensor row_indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor + inline at::Tensor sparse_csc_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & ccol_indices, const at::Tensor & row_indices, const at::Tensor & values, at::TensorOptions options) { + return at::_ops::sparse_csc_tensor_ccol_row_value::redispatch(dispatchKeySet, ccol_indices, row_indices, values, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::sparse_csc_tensor.ccol_row_value(Tensor ccol_indices, Tensor row_indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor + inline at::Tensor sparse_csc_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & ccol_indices, const at::Tensor & row_indices, const at::Tensor & values, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::sparse_csc_tensor_ccol_row_value::redispatch(dispatchKeySet, ccol_indices, row_indices, values, dtype, layout, device, pin_memory); + } + + // aten::sparse_bsr_tensor.crow_col_value(Tensor crow_indices, Tensor col_indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=False) -> Tensor + inline at::Tensor sparse_bsr_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & crow_indices, const at::Tensor & col_indices, const at::Tensor & values, at::TensorOptions options) { + return at::_ops::sparse_bsr_tensor_crow_col_value::redispatch(dispatchKeySet, crow_indices, col_indices, values, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::sparse_bsr_tensor.crow_col_value(Tensor crow_indices, Tensor col_indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor + inline at::Tensor sparse_bsr_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & crow_indices, const at::Tensor & col_indices, const at::Tensor & values, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::sparse_bsr_tensor_crow_col_value::redispatch(dispatchKeySet, crow_indices, col_indices, values, dtype, layout, device, pin_memory); + } + + // aten::sparse_bsc_tensor.ccol_row_value(Tensor ccol_indices, Tensor row_indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor + inline at::Tensor sparse_bsc_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & ccol_indices, const at::Tensor & row_indices, const at::Tensor & values, at::TensorOptions options) { + return at::_ops::sparse_bsc_tensor_ccol_row_value::redispatch(dispatchKeySet, ccol_indices, row_indices, values, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::sparse_bsc_tensor.ccol_row_value(Tensor ccol_indices, Tensor row_indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=False) -> Tensor + inline at::Tensor sparse_bsc_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & ccol_indices, const at::Tensor & row_indices, const at::Tensor & values, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::sparse_bsc_tensor_ccol_row_value::redispatch(dispatchKeySet, ccol_indices, row_indices, values, dtype, layout, device, pin_memory); + } + + // aten::_sparse_compressed_tensor_unsafe(Tensor compressed_indices, Tensor plain_indices, Tensor values, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor _sparse_compressed_tensor_unsafe(c10::DispatchKeySet dispatchKeySet, const at::Tensor & compressed_indices, const at::Tensor & plain_indices, const at::Tensor & values, at::IntArrayRef size, at::TensorOptions options={}) { + return at::_ops::_sparse_compressed_tensor_unsafe::redispatch(dispatchKeySet, compressed_indices, plain_indices, values, c10::fromIntArrayRefSlow(size), c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::_sparse_compressed_tensor_unsafe(Tensor compressed_indices, Tensor plain_indices, Tensor values, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor + inline at::Tensor _sparse_compressed_tensor_unsafe(c10::DispatchKeySet dispatchKeySet, const at::Tensor & compressed_indices, const at::Tensor & plain_indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::_sparse_compressed_tensor_unsafe::redispatch(dispatchKeySet, compressed_indices, plain_indices, values, c10::fromIntArrayRefSlow(size), dtype, layout, device, pin_memory); + } + + // aten::_sparse_compressed_tensor_unsafe(Tensor compressed_indices, Tensor plain_indices, Tensor values, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor _sparse_compressed_tensor_unsafe_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & compressed_indices, const at::Tensor & plain_indices, const at::Tensor & values, c10::SymIntArrayRef size, at::TensorOptions options={}) { + return at::_ops::_sparse_compressed_tensor_unsafe::redispatch(dispatchKeySet, compressed_indices, plain_indices, values, size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::_sparse_compressed_tensor_unsafe(Tensor compressed_indices, Tensor plain_indices, Tensor values, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor + inline at::Tensor _sparse_compressed_tensor_unsafe_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & compressed_indices, const at::Tensor & plain_indices, const at::Tensor & values, c10::SymIntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::_sparse_compressed_tensor_unsafe::redispatch(dispatchKeySet, compressed_indices, plain_indices, values, size, dtype, layout, device, pin_memory); + } + + // aten::_sparse_csr_tensor_unsafe(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor _sparse_csr_tensor_unsafe(c10::DispatchKeySet dispatchKeySet, const at::Tensor & crow_indices, const at::Tensor & col_indices, const at::Tensor & values, at::IntArrayRef size, at::TensorOptions options={}) { + return at::_ops::_sparse_csr_tensor_unsafe::redispatch(dispatchKeySet, crow_indices, col_indices, values, size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::_sparse_csr_tensor_unsafe(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor _sparse_csr_tensor_unsafe(c10::DispatchKeySet dispatchKeySet, const at::Tensor & crow_indices, const at::Tensor & col_indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::_sparse_csr_tensor_unsafe::redispatch(dispatchKeySet, crow_indices, col_indices, values, size, dtype, layout, device, pin_memory); + } + + // aten::_sparse_csc_tensor_unsafe(Tensor ccol_indices, Tensor row_indices, Tensor values, int[] size, *, ScalarType? 
dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor _sparse_csc_tensor_unsafe(c10::DispatchKeySet dispatchKeySet, const at::Tensor & ccol_indices, const at::Tensor & row_indices, const at::Tensor & values, at::IntArrayRef size, at::TensorOptions options={}) { + return at::_ops::_sparse_csc_tensor_unsafe::redispatch(dispatchKeySet, ccol_indices, row_indices, values, size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::_sparse_csc_tensor_unsafe(Tensor ccol_indices, Tensor row_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor _sparse_csc_tensor_unsafe(c10::DispatchKeySet dispatchKeySet, const at::Tensor & ccol_indices, const at::Tensor & row_indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::_sparse_csc_tensor_unsafe::redispatch(dispatchKeySet, ccol_indices, row_indices, values, size, dtype, layout, device, pin_memory); + } + + // aten::_sparse_bsr_tensor_unsafe(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor + inline at::Tensor _sparse_bsr_tensor_unsafe(c10::DispatchKeySet dispatchKeySet, const at::Tensor & crow_indices, const at::Tensor & col_indices, const at::Tensor & values, at::IntArrayRef size, at::TensorOptions options={}) { + return at::_ops::_sparse_bsr_tensor_unsafe::redispatch(dispatchKeySet, crow_indices, col_indices, values, size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::_sparse_bsr_tensor_unsafe(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor _sparse_bsr_tensor_unsafe(c10::DispatchKeySet dispatchKeySet, const at::Tensor & crow_indices, const at::Tensor & col_indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::_sparse_bsr_tensor_unsafe::redispatch(dispatchKeySet, crow_indices, col_indices, values, size, dtype, layout, device, pin_memory); + } + + // aten::_sparse_bsc_tensor_unsafe(Tensor ccol_indices, Tensor row_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor _sparse_bsc_tensor_unsafe(c10::DispatchKeySet dispatchKeySet, const at::Tensor & ccol_indices, const at::Tensor & row_indices, const at::Tensor & values, at::IntArrayRef size, at::TensorOptions options={}) { + return at::_ops::_sparse_bsc_tensor_unsafe::redispatch(dispatchKeySet, ccol_indices, row_indices, values, size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::_sparse_bsc_tensor_unsafe(Tensor ccol_indices, Tensor row_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? 
layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor _sparse_bsc_tensor_unsafe(c10::DispatchKeySet dispatchKeySet, const at::Tensor & ccol_indices, const at::Tensor & row_indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::_sparse_bsc_tensor_unsafe::redispatch(dispatchKeySet, ccol_indices, row_indices, values, size, dtype, layout, device, pin_memory); + } + + // aten::sparse_coo_tensor.size(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor + inline at::Tensor sparse_coo_tensor(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, at::TensorOptions options) { + return at::_ops::sparse_coo_tensor_size::redispatch(dispatchKeySet, size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::sparse_coo_tensor.size(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor + inline at::Tensor sparse_coo_tensor(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::sparse_coo_tensor_size::redispatch(dispatchKeySet, size, dtype, layout, device, pin_memory); + } + + // aten::sparse_coo_tensor.indices(Tensor indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool? 
is_coalesced=None) -> Tensor + inline at::Tensor sparse_coo_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & indices, const at::Tensor & values, at::TensorOptions options={}, ::std::optional is_coalesced=::std::nullopt) { + return at::_ops::sparse_coo_tensor_indices::redispatch(dispatchKeySet, indices, values, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), is_coalesced); + } + + // aten::sparse_coo_tensor.indices(Tensor indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool? is_coalesced=None) -> Tensor + inline at::Tensor sparse_coo_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & indices, const at::Tensor & values, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional is_coalesced) { + return at::_ops::sparse_coo_tensor_indices::redispatch(dispatchKeySet, indices, values, dtype, layout, device, pin_memory, is_coalesced); + } + + // aten::sparse_coo_tensor.indices_size(Tensor indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool? is_coalesced=None) -> Tensor + inline at::Tensor sparse_coo_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & indices, const at::Tensor & values, at::IntArrayRef size, at::TensorOptions options={}, ::std::optional is_coalesced=::std::nullopt) { + return at::_ops::sparse_coo_tensor_indices_size::redispatch(dispatchKeySet, indices, values, size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), is_coalesced); + } + + // aten::sparse_coo_tensor.indices_size(Tensor indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool? 
is_coalesced=None) -> Tensor + inline at::Tensor sparse_coo_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional is_coalesced) { + return at::_ops::sparse_coo_tensor_indices_size::redispatch(dispatchKeySet, indices, values, size, dtype, layout, device, pin_memory, is_coalesced); + } + + // aten::_sparse_coo_tensor_unsafe(Tensor indices, Tensor values, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool? is_coalesced=None) -> Tensor + inline at::Tensor _sparse_coo_tensor_unsafe(c10::DispatchKeySet dispatchKeySet, const at::Tensor & indices, const at::Tensor & values, at::IntArrayRef size, at::TensorOptions options={}, ::std::optional is_coalesced=::std::nullopt) { + return at::_ops::_sparse_coo_tensor_unsafe::redispatch(dispatchKeySet, indices, values, c10::fromIntArrayRefSlow(size), c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), is_coalesced); + } + + // aten::_sparse_coo_tensor_unsafe(Tensor indices, Tensor values, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool? is_coalesced=None) -> Tensor + inline at::Tensor _sparse_coo_tensor_unsafe(c10::DispatchKeySet dispatchKeySet, const at::Tensor & indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional is_coalesced) { + return at::_ops::_sparse_coo_tensor_unsafe::redispatch(dispatchKeySet, indices, values, c10::fromIntArrayRefSlow(size), dtype, layout, device, pin_memory, is_coalesced); + } + + // aten::_sparse_coo_tensor_unsafe(Tensor indices, Tensor values, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? 
device=None, bool? pin_memory=None, bool? is_coalesced=None) -> Tensor + inline at::Tensor _sparse_coo_tensor_unsafe_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & indices, const at::Tensor & values, c10::SymIntArrayRef size, at::TensorOptions options={}, ::std::optional is_coalesced=::std::nullopt) { + return at::_ops::_sparse_coo_tensor_unsafe::redispatch(dispatchKeySet, indices, values, size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), is_coalesced); + } + + // aten::_sparse_coo_tensor_unsafe(Tensor indices, Tensor values, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool? is_coalesced=None) -> Tensor + inline at::Tensor _sparse_coo_tensor_unsafe_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & indices, const at::Tensor & values, c10::SymIntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional is_coalesced) { + return at::_ops::_sparse_coo_tensor_unsafe::redispatch(dispatchKeySet, indices, values, size, dtype, layout, device, pin_memory, is_coalesced); + } + + // aten::_validate_sparse_coo_tensor_args(Tensor indices, Tensor values, int[] size, bool? is_coalesced=None, bool? check_pinning=None) -> () + inline void _validate_sparse_coo_tensor_args(c10::DispatchKeySet dispatchKeySet, const at::Tensor & indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional is_coalesced=::std::nullopt, ::std::optional check_pinning=::std::nullopt) { + return at::_ops::_validate_sparse_coo_tensor_args::redispatch(dispatchKeySet, indices, values, size, is_coalesced, check_pinning); + } + + // aten::_validate_sparse_compressed_tensor_args(Tensor compressed_indices, Tensor plain_indices, Tensor values, int[] size, Layout layout, bool? 
check_pinning=None) -> () + inline void _validate_sparse_compressed_tensor_args(c10::DispatchKeySet dispatchKeySet, const at::Tensor & compressed_indices, const at::Tensor & plain_indices, const at::Tensor & values, at::IntArrayRef size, at::Layout layout, ::std::optional check_pinning=::std::nullopt) { + return at::_ops::_validate_sparse_compressed_tensor_args::redispatch(dispatchKeySet, compressed_indices, plain_indices, values, size, layout, check_pinning); + } + + // aten::_validate_sparse_csr_tensor_args(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size, bool? check_pinning=None) -> () + inline void _validate_sparse_csr_tensor_args(c10::DispatchKeySet dispatchKeySet, const at::Tensor & crow_indices, const at::Tensor & col_indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional check_pinning=::std::nullopt) { + return at::_ops::_validate_sparse_csr_tensor_args::redispatch(dispatchKeySet, crow_indices, col_indices, values, size, check_pinning); + } + + // aten::_validate_sparse_csc_tensor_args(Tensor ccol_indices, Tensor row_indices, Tensor values, int[] size, bool? check_pinning=None) -> () + inline void _validate_sparse_csc_tensor_args(c10::DispatchKeySet dispatchKeySet, const at::Tensor & ccol_indices, const at::Tensor & row_indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional check_pinning=::std::nullopt) { + return at::_ops::_validate_sparse_csc_tensor_args::redispatch(dispatchKeySet, ccol_indices, row_indices, values, size, check_pinning); + } + + // aten::_validate_sparse_bsr_tensor_args(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size, bool? 
check_pinning=None) -> () + inline void _validate_sparse_bsr_tensor_args(c10::DispatchKeySet dispatchKeySet, const at::Tensor & crow_indices, const at::Tensor & col_indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional check_pinning=::std::nullopt) { + return at::_ops::_validate_sparse_bsr_tensor_args::redispatch(dispatchKeySet, crow_indices, col_indices, values, size, check_pinning); + } + + // aten::_validate_sparse_bsc_tensor_args(Tensor ccol_indices, Tensor row_indices, Tensor values, int[] size, bool? check_pinning=None) -> () + inline void _validate_sparse_bsc_tensor_args(c10::DispatchKeySet dispatchKeySet, const at::Tensor & ccol_indices, const at::Tensor & row_indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional check_pinning=::std::nullopt) { + return at::_ops::_validate_sparse_bsc_tensor_args::redispatch(dispatchKeySet, ccol_indices, row_indices, values, size, check_pinning); + } + + // aten::_sparse_coo_tensor_with_dims(int sparse_dim, int dense_dim, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor + inline at::Tensor _sparse_coo_tensor_with_dims(c10::DispatchKeySet dispatchKeySet, int64_t sparse_dim, int64_t dense_dim, at::IntArrayRef size, at::TensorOptions options) { + return at::_ops::_sparse_coo_tensor_with_dims::redispatch(dispatchKeySet, sparse_dim, dense_dim, size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::_sparse_coo_tensor_with_dims(int sparse_dim, int dense_dim, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=False) -> Tensor + inline at::Tensor _sparse_coo_tensor_with_dims(c10::DispatchKeySet dispatchKeySet, int64_t sparse_dim, int64_t dense_dim, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::_sparse_coo_tensor_with_dims::redispatch(dispatchKeySet, sparse_dim, dense_dim, size, dtype, layout, device, pin_memory); + } + + // aten::_sparse_coo_tensor_with_dims_and_tensors(int sparse_dim, int dense_dim, SymInt[] size, Tensor indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False, bool? is_coalesced=None) -> Tensor + inline at::Tensor _sparse_coo_tensor_with_dims_and_tensors(c10::DispatchKeySet dispatchKeySet, int64_t sparse_dim, int64_t dense_dim, at::IntArrayRef size, const at::Tensor & indices, const at::Tensor & values, at::TensorOptions options, ::std::optional is_coalesced=::std::nullopt) { + return at::_ops::_sparse_coo_tensor_with_dims_and_tensors::redispatch(dispatchKeySet, sparse_dim, dense_dim, c10::fromIntArrayRefSlow(size), indices, values, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), is_coalesced); + } + + // aten::_sparse_coo_tensor_with_dims_and_tensors(int sparse_dim, int dense_dim, SymInt[] size, Tensor indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False, bool? 
is_coalesced=None) -> Tensor + inline at::Tensor _sparse_coo_tensor_with_dims_and_tensors(c10::DispatchKeySet dispatchKeySet, int64_t sparse_dim, int64_t dense_dim, at::IntArrayRef size, const at::Tensor & indices, const at::Tensor & values, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional is_coalesced) { + return at::_ops::_sparse_coo_tensor_with_dims_and_tensors::redispatch(dispatchKeySet, sparse_dim, dense_dim, c10::fromIntArrayRefSlow(size), indices, values, dtype, layout, device, pin_memory, is_coalesced); + } + + // aten::_sparse_coo_tensor_with_dims_and_tensors(int sparse_dim, int dense_dim, SymInt[] size, Tensor indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False, bool? is_coalesced=None) -> Tensor + inline at::Tensor _sparse_coo_tensor_with_dims_and_tensors_symint(c10::DispatchKeySet dispatchKeySet, int64_t sparse_dim, int64_t dense_dim, c10::SymIntArrayRef size, const at::Tensor & indices, const at::Tensor & values, at::TensorOptions options, ::std::optional is_coalesced=::std::nullopt) { + return at::_ops::_sparse_coo_tensor_with_dims_and_tensors::redispatch(dispatchKeySet, sparse_dim, dense_dim, size, indices, values, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), is_coalesced); + } + + // aten::_sparse_coo_tensor_with_dims_and_tensors(int sparse_dim, int dense_dim, SymInt[] size, Tensor indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False, bool? 
is_coalesced=None) -> Tensor + inline at::Tensor _sparse_coo_tensor_with_dims_and_tensors_symint(c10::DispatchKeySet dispatchKeySet, int64_t sparse_dim, int64_t dense_dim, c10::SymIntArrayRef size, const at::Tensor & indices, const at::Tensor & values, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional is_coalesced) { + return at::_ops::_sparse_coo_tensor_with_dims_and_tensors::redispatch(dispatchKeySet, sparse_dim, dense_dim, size, indices, values, dtype, layout, device, pin_memory, is_coalesced); + } + + // aten::sparse_resize_(Tensor(a!) self, int[] size, int sparse_dim, int dense_dim) -> Tensor(a!) + inline const at::Tensor & sparse_resize_(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, int64_t sparse_dim, int64_t dense_dim) { + return at::_ops::sparse_resize_::redispatch(dispatchKeySet, self, size, sparse_dim, dense_dim); + } + + // aten::sparse_resize_and_clear_(Tensor(a!) self, int[] size, int sparse_dim, int dense_dim) -> Tensor(a!) 
+ inline const at::Tensor & sparse_resize_and_clear_(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, int64_t sparse_dim, int64_t dense_dim) { + return at::_ops::sparse_resize_and_clear_::redispatch(dispatchKeySet, self, size, sparse_dim, dense_dim); + } + + // aten::sparse_mask(Tensor self, Tensor mask) -> Tensor + inline at::Tensor sparse_mask(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mask) { + return at::_ops::sparse_mask::redispatch(dispatchKeySet, self, mask); + } + + // aten::_sparse_mask_projection(Tensor self, Tensor mask, bool accumulate_matches=False) -> Tensor + inline at::Tensor _sparse_mask_projection(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mask, bool accumulate_matches=false) { + return at::_ops::_sparse_mask_projection::redispatch(dispatchKeySet, self, mask, accumulate_matches); + } + + // aten::_to_cpu(Tensor[] tensors) -> Tensor[] + inline ::std::vector _to_cpu(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors) { + return at::_ops::_to_cpu::redispatch(dispatchKeySet, tensors); + } + + // aten::to_dense(Tensor self, ScalarType? dtype=None, *, bool? masked_grad=None) -> Tensor + inline at::Tensor to_dense(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional dtype=::std::nullopt, ::std::optional masked_grad=::std::nullopt) { + return at::_ops::to_dense::redispatch(dispatchKeySet, self, dtype, masked_grad); + } + + // aten::_to_dense(Tensor self, ScalarType? dtype=None, bool? masked_grad=None) -> Tensor + inline at::Tensor _to_dense(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional dtype=::std::nullopt, ::std::optional masked_grad=::std::nullopt) { + return at::_ops::_to_dense::redispatch(dispatchKeySet, self, dtype, masked_grad); + } + + // aten::to_dense_backward(Tensor grad, Tensor input, bool? 
masked_grad=None) -> Tensor + inline at::Tensor to_dense_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & input, ::std::optional masked_grad=::std::nullopt) { + return at::_ops::to_dense_backward::redispatch(dispatchKeySet, grad, input, masked_grad); + } + + // aten::sparse_dim(Tensor self) -> int + inline int64_t sparse_dim(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::sparse_dim::redispatch(dispatchKeySet, self); + } + + // aten::_dimI(Tensor self) -> int + inline int64_t _dimI(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::_dimI::redispatch(dispatchKeySet, self); + } + + // aten::dense_dim(Tensor self) -> int + inline int64_t dense_dim(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::dense_dim::redispatch(dispatchKeySet, self); + } + + // aten::_dimV(Tensor self) -> int + inline int64_t _dimV(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::_dimV::redispatch(dispatchKeySet, self); + } + + // aten::_nnz(Tensor self) -> int + inline int64_t _nnz(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::_nnz::redispatch(dispatchKeySet, self); + } + + // aten::coalesce(Tensor(a) self) -> Tensor(a) + inline at::Tensor coalesce(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::coalesce::redispatch(dispatchKeySet, self); + } + + // aten::_coalesce(Tensor self) -> Tensor + inline at::Tensor _coalesce(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::_coalesce::redispatch(dispatchKeySet, self); + } + + // aten::is_coalesced(Tensor self) -> bool + inline bool is_coalesced(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::is_coalesced::redispatch(dispatchKeySet, self); + } + + // aten::_indices(Tensor(a) self) -> Tensor(a) + inline at::Tensor _indices(c10::DispatchKeySet 
dispatchKeySet, const at::Tensor & self) { + return at::_ops::_indices::redispatch(dispatchKeySet, self); + } + + // aten::_values(Tensor(a) self) -> Tensor(a) + inline at::Tensor _values(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::_values::redispatch(dispatchKeySet, self); + } + + // aten::_coalesced_(Tensor(a!) self, bool coalesced) -> Tensor(a!) + inline at::Tensor & _coalesced_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, bool coalesced) { + return at::_ops::_coalesced_::redispatch(dispatchKeySet, self, coalesced); + } + + // aten::indices(Tensor(a) self) -> Tensor(a) + inline at::Tensor indices(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::indices::redispatch(dispatchKeySet, self); + } + + // aten::values(Tensor(a) self) -> Tensor(a) + inline at::Tensor values(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::values::redispatch(dispatchKeySet, self); + } + + // aten::crow_indices(Tensor(a) self) -> Tensor(a) + inline at::Tensor crow_indices(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::crow_indices::redispatch(dispatchKeySet, self); + } + + // aten::col_indices(Tensor(a) self) -> Tensor(a) + inline at::Tensor col_indices(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::col_indices::redispatch(dispatchKeySet, self); + } + + // aten::ccol_indices(Tensor(a) self) -> Tensor(a) + inline at::Tensor ccol_indices(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::ccol_indices::redispatch(dispatchKeySet, self); + } + + // aten::row_indices(Tensor(a) self) -> Tensor(a) + inline at::Tensor row_indices(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::row_indices::redispatch(dispatchKeySet, self); + } + + // aten::hspmm.out(Tensor mat1, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & hspmm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & mat1, const at::Tensor & mat2) { + return at::_ops::hspmm_out::redispatch(dispatchKeySet, mat1, mat2, out); + } + + // aten::hspmm.out(Tensor mat1, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & hspmm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & mat1, const at::Tensor & mat2, at::Tensor & out) { + return at::_ops::hspmm_out::redispatch(dispatchKeySet, mat1, mat2, out); + } + + // aten::hspmm(Tensor mat1, Tensor mat2) -> Tensor + inline at::Tensor hspmm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & mat1, const at::Tensor & mat2) { + return at::_ops::hspmm::redispatch(dispatchKeySet, mat1, mat2); + } + + // aten::copy_sparse_to_sparse_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> Tensor(a!) + inline at::Tensor & copy_sparse_to_sparse_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & src, bool non_blocking=false) { + return at::_ops::copy_sparse_to_sparse_::redispatch(dispatchKeySet, self, src, non_blocking); + } + + // aten::unbind.int(Tensor(a -> *) self, int dim=0) -> Tensor(a)[] + inline ::std::vector unbind(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim=0) { + return at::_ops::unbind_int::redispatch(dispatchKeySet, self, dim); + } + + // aten::unbind.Dimname(Tensor(a -> *) self, Dimname dim) -> Tensor(a)[] + inline ::std::vector unbind(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim) { + return at::_ops::unbind_Dimname::redispatch(dispatchKeySet, self, dim); + } + + // aten::to_sparse.sparse_dim(Tensor self, int sparse_dim) -> Tensor + inline at::Tensor to_sparse(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t sparse_dim) { + return at::_ops::to_sparse_sparse_dim::redispatch(dispatchKeySet, self, sparse_dim); + } + + // aten::_to_sparse.sparse_dim(Tensor self, int sparse_dim) -> Tensor + inline at::Tensor 
_to_sparse(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t sparse_dim) { + return at::_ops::_to_sparse_sparse_dim::redispatch(dispatchKeySet, self, sparse_dim); + } + + // aten::to_sparse(Tensor self, *, Layout? layout=None, int[2]? blocksize=None, int? dense_dim=None) -> Tensor + inline at::Tensor to_sparse(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional layout=::std::nullopt, at::OptionalIntArrayRef blocksize=::std::nullopt, ::std::optional dense_dim=::std::nullopt) { + return at::_ops::to_sparse::redispatch(dispatchKeySet, self, layout, blocksize, dense_dim); + } + + // aten::_to_sparse(Tensor self, *, Layout? layout=None, int[2]? blocksize=None, int? dense_dim=None) -> Tensor + inline at::Tensor _to_sparse(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional layout=::std::nullopt, at::OptionalIntArrayRef blocksize=::std::nullopt, ::std::optional dense_dim=::std::nullopt) { + return at::_ops::_to_sparse::redispatch(dispatchKeySet, self, layout, blocksize, dense_dim); + } + + // aten::to_sparse_csr(Tensor self, int? dense_dim=None) -> Tensor + inline at::Tensor to_sparse_csr(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional dense_dim=::std::nullopt) { + return at::_ops::to_sparse_csr::redispatch(dispatchKeySet, self, dense_dim); + } + + // aten::_to_sparse_csr(Tensor self, int? dense_dim=None) -> Tensor + inline at::Tensor _to_sparse_csr(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional dense_dim=::std::nullopt) { + return at::_ops::_to_sparse_csr::redispatch(dispatchKeySet, self, dense_dim); + } + + // aten::to_sparse_csc(Tensor self, int? dense_dim=None) -> Tensor + inline at::Tensor to_sparse_csc(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional dense_dim=::std::nullopt) { + return at::_ops::to_sparse_csc::redispatch(dispatchKeySet, self, dense_dim); + } + + // aten::_to_sparse_csc(Tensor self, int? 
dense_dim=None) -> Tensor + inline at::Tensor _to_sparse_csc(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional dense_dim=::std::nullopt) { + return at::_ops::_to_sparse_csc::redispatch(dispatchKeySet, self, dense_dim); + } + + // aten::to_sparse_bsr(Tensor self, int[2] blocksize, int? dense_dim=None) -> Tensor + inline at::Tensor to_sparse_bsr(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef blocksize, ::std::optional dense_dim=::std::nullopt) { + return at::_ops::to_sparse_bsr::redispatch(dispatchKeySet, self, blocksize, dense_dim); + } + + // aten::_to_sparse_bsr(Tensor self, int[2] blocksize, int? dense_dim=None) -> Tensor + inline at::Tensor _to_sparse_bsr(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef blocksize, ::std::optional dense_dim=::std::nullopt) { + return at::_ops::_to_sparse_bsr::redispatch(dispatchKeySet, self, blocksize, dense_dim); + } + + // aten::to_sparse_bsc(Tensor self, int[2] blocksize, int? dense_dim=None) -> Tensor + inline at::Tensor to_sparse_bsc(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef blocksize, ::std::optional dense_dim=::std::nullopt) { + return at::_ops::to_sparse_bsc::redispatch(dispatchKeySet, self, blocksize, dense_dim); + } + + // aten::_to_sparse_bsc(Tensor self, int[2] blocksize, int? dense_dim=None) -> Tensor + inline at::Tensor _to_sparse_bsc(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef blocksize, ::std::optional dense_dim=::std::nullopt) { + return at::_ops::_to_sparse_bsc::redispatch(dispatchKeySet, self, blocksize, dense_dim); + } + + // aten::_to_sparse_semi_structured(Tensor dense) -> (Tensor, Tensor) + inline ::std::tuple _to_sparse_semi_structured(c10::DispatchKeySet dispatchKeySet, const at::Tensor & dense) { + return at::_ops::_to_sparse_semi_structured::redispatch(dispatchKeySet, dense); + } + + // aten::to_mkldnn(Tensor self, ScalarType? 
dtype=None) -> Tensor + inline at::Tensor to_mkldnn(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional dtype=::std::nullopt) { + return at::_ops::to_mkldnn::redispatch(dispatchKeySet, self, dtype); + } + + // aten::mkldnn_reorder_conv2d_weight(Tensor self, SymInt[2] padding=0, SymInt[2] stride=1, SymInt[2] dilation=1, SymInt groups=1, SymInt[]? input_size=None) -> Tensor + inline at::Tensor mkldnn_reorder_conv2d_weight(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef padding=0, at::IntArrayRef stride=1, at::IntArrayRef dilation=1, int64_t groups=1, at::OptionalIntArrayRef input_size=::std::nullopt) { + return at::_ops::mkldnn_reorder_conv2d_weight::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, input_size.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*input_size)) : ::std::nullopt); + } + + // aten::mkldnn_reorder_conv2d_weight(Tensor self, SymInt[2] padding=0, SymInt[2] stride=1, SymInt[2] dilation=1, SymInt groups=1, SymInt[]? input_size=None) -> Tensor + inline at::Tensor mkldnn_reorder_conv2d_weight_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef padding=c10::SymInt(0), c10::SymIntArrayRef stride=c10::SymInt(1), c10::SymIntArrayRef dilation=c10::SymInt(1), c10::SymInt groups=1, at::OptionalSymIntArrayRef input_size=::std::nullopt) { + return at::_ops::mkldnn_reorder_conv2d_weight::redispatch(dispatchKeySet, self, padding, stride, dilation, groups, input_size); + } + + // aten::mkldnn_reorder_conv3d_weight(Tensor self, SymInt[3] padding=0, SymInt[3] stride=1, SymInt[3] dilation=1, SymInt groups=1, SymInt[]? 
input_size=None) -> Tensor + inline at::Tensor mkldnn_reorder_conv3d_weight(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef padding=0, at::IntArrayRef stride=1, at::IntArrayRef dilation=1, int64_t groups=1, at::OptionalIntArrayRef input_size=::std::nullopt) { + return at::_ops::mkldnn_reorder_conv3d_weight::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, input_size.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*input_size)) : ::std::nullopt); + } + + // aten::mkldnn_reorder_conv3d_weight(Tensor self, SymInt[3] padding=0, SymInt[3] stride=1, SymInt[3] dilation=1, SymInt groups=1, SymInt[]? input_size=None) -> Tensor + inline at::Tensor mkldnn_reorder_conv3d_weight_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef padding=c10::SymInt(0), c10::SymIntArrayRef stride=c10::SymInt(1), c10::SymIntArrayRef dilation=c10::SymInt(1), c10::SymInt groups=1, at::OptionalSymIntArrayRef input_size=::std::nullopt) { + return at::_ops::mkldnn_reorder_conv3d_weight::redispatch(dispatchKeySet, self, padding, stride, dilation, groups, input_size); + } + + // aten::to_mkldnn_backward(Tensor grad, Tensor input) -> Tensor + inline at::Tensor to_mkldnn_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & input) { + return at::_ops::to_mkldnn_backward::redispatch(dispatchKeySet, grad, input); + } + + // aten::quantize_per_tensor_dynamic(Tensor self, ScalarType dtype, bool reduce_range) -> Tensor + inline at::Tensor quantize_per_tensor_dynamic(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::ScalarType dtype, bool reduce_range) { + return at::_ops::quantize_per_tensor_dynamic::redispatch(dispatchKeySet, self, dtype, reduce_range); + } + + // aten::quantize_per_tensor(Tensor self, float scale, int zero_point, ScalarType dtype) -> Tensor + inline at::Tensor 
quantize_per_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double scale, int64_t zero_point, at::ScalarType dtype) { + return at::_ops::quantize_per_tensor::redispatch(dispatchKeySet, self, scale, zero_point, dtype); + } + + // aten::quantize_per_tensor.tensor_qparams(Tensor self, Tensor scale, Tensor zero_point, ScalarType dtype) -> Tensor + inline at::Tensor quantize_per_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, at::ScalarType dtype) { + return at::_ops::quantize_per_tensor_tensor_qparams::redispatch(dispatchKeySet, self, scale, zero_point, dtype); + } + + // aten::quantize_per_tensor.tensors(Tensor[] tensors, Tensor scales, Tensor zero_points, ScalarType dtype) -> Tensor[] + inline ::std::vector quantize_per_tensor(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors, const at::Tensor & scales, const at::Tensor & zero_points, at::ScalarType dtype) { + return at::_ops::quantize_per_tensor_tensors::redispatch(dispatchKeySet, tensors, scales, zero_points, dtype); + } + + // aten::quantize_per_channel(Tensor self, Tensor scales, Tensor zero_points, int axis, ScalarType dtype) -> Tensor + inline at::Tensor quantize_per_channel(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, at::ScalarType dtype) { + return at::_ops::quantize_per_channel::redispatch(dispatchKeySet, self, scales, zero_points, axis, dtype); + } + + // aten::dequantize.self(Tensor self) -> Tensor + inline at::Tensor dequantize(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::dequantize_self::redispatch(dispatchKeySet, self); + } + + // aten::dequantize.tensors(Tensor[] tensors) -> Tensor[] + inline ::std::vector dequantize(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors) { + return at::_ops::dequantize_tensors::redispatch(dispatchKeySet, tensors); + } + + // 
aten::q_scale(Tensor self) -> float + inline double q_scale(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::q_scale::redispatch(dispatchKeySet, self); + } + + // aten::q_zero_point(Tensor self) -> int + inline int64_t q_zero_point(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::q_zero_point::redispatch(dispatchKeySet, self); + } + + // aten::q_per_channel_scales(Tensor self) -> Tensor + inline at::Tensor q_per_channel_scales(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::q_per_channel_scales::redispatch(dispatchKeySet, self); + } + + // aten::q_per_channel_zero_points(Tensor self) -> Tensor + inline at::Tensor q_per_channel_zero_points(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::q_per_channel_zero_points::redispatch(dispatchKeySet, self); + } + + // aten::q_per_channel_axis(Tensor self) -> int + inline int64_t q_per_channel_axis(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::q_per_channel_axis::redispatch(dispatchKeySet, self); + } + + // aten::int_repr(Tensor self) -> Tensor + inline at::Tensor int_repr(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::int_repr::redispatch(dispatchKeySet, self); + } + + // aten::_make_per_tensor_quantized_tensor(Tensor self, float scale, int zero_point) -> Tensor + inline at::Tensor _make_per_tensor_quantized_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double scale, int64_t zero_point) { + return at::_ops::_make_per_tensor_quantized_tensor::redispatch(dispatchKeySet, self, scale, zero_point); + } + + // aten::_make_per_channel_quantized_tensor(Tensor self, Tensor scale, Tensor zero_point, int axis) -> Tensor + inline at::Tensor _make_per_channel_quantized_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, int64_t axis) { + return 
at::_ops::_make_per_channel_quantized_tensor::redispatch(dispatchKeySet, self, scale, zero_point, axis); + } + + // aten::qscheme(Tensor self) -> QScheme + inline at::QScheme qscheme(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::qscheme::redispatch(dispatchKeySet, self); + } + + // aten::fake_quantize_per_tensor_affine(Tensor self, float scale, int zero_point, int quant_min, int quant_max) -> Tensor + inline at::Tensor fake_quantize_per_tensor_affine(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double scale, int64_t zero_point, int64_t quant_min, int64_t quant_max) { + return at::_ops::fake_quantize_per_tensor_affine::redispatch(dispatchKeySet, self, scale, zero_point, quant_min, quant_max); + } + + // aten::fake_quantize_per_tensor_affine.tensor_qparams(Tensor self, Tensor scale, Tensor zero_point, int quant_min, int quant_max) -> Tensor + inline at::Tensor fake_quantize_per_tensor_affine(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, int64_t quant_min, int64_t quant_max) { + return at::_ops::fake_quantize_per_tensor_affine_tensor_qparams::redispatch(dispatchKeySet, self, scale, zero_point, quant_min, quant_max); + } + + // aten::fake_quantize_per_tensor_affine_cachemask(Tensor self, float scale, int zero_point, int quant_min, int quant_max) -> (Tensor output, Tensor mask) + inline ::std::tuple fake_quantize_per_tensor_affine_cachemask(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double scale, int64_t zero_point, int64_t quant_min, int64_t quant_max) { + return at::_ops::fake_quantize_per_tensor_affine_cachemask::redispatch(dispatchKeySet, self, scale, zero_point, quant_min, quant_max); + } + + // aten::_fake_quantize_per_tensor_affine_cachemask_tensor_qparams(Tensor self, Tensor scale, Tensor zero_point, Tensor fake_quant_enabled, int quant_min, int quant_max) -> (Tensor output, Tensor mask) + inline ::std::tuple 
_fake_quantize_per_tensor_affine_cachemask_tensor_qparams(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, const at::Tensor & fake_quant_enabled, int64_t quant_min, int64_t quant_max) { + return at::_ops::_fake_quantize_per_tensor_affine_cachemask_tensor_qparams::redispatch(dispatchKeySet, self, scale, zero_point, fake_quant_enabled, quant_min, quant_max); + } + + // aten::fake_quantize_per_tensor_affine_cachemask_backward(Tensor grad, Tensor mask) -> Tensor + inline at::Tensor fake_quantize_per_tensor_affine_cachemask_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & mask) { + return at::_ops::fake_quantize_per_tensor_affine_cachemask_backward::redispatch(dispatchKeySet, grad, mask); + } + + // aten::_fake_quantize_learnable_per_tensor_affine(Tensor self, Tensor scale, Tensor zero_point, int quant_min, int quant_max, float grad_factor=1.0) -> Tensor + inline at::Tensor _fake_quantize_learnable_per_tensor_affine(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, int64_t quant_min, int64_t quant_max, double grad_factor=1.0) { + return at::_ops::_fake_quantize_learnable_per_tensor_affine::redispatch(dispatchKeySet, self, scale, zero_point, quant_min, quant_max, grad_factor); + } + + // aten::_fake_quantize_learnable_per_tensor_affine_backward(Tensor grad, Tensor self, Tensor scale, Tensor zero_point, int quant_min, int quant_max, float grad_factor=1.0) -> (Tensor, Tensor, Tensor) + inline ::std::tuple _fake_quantize_learnable_per_tensor_affine_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, int64_t quant_min, int64_t quant_max, double grad_factor=1.0) { + return at::_ops::_fake_quantize_learnable_per_tensor_affine_backward::redispatch(dispatchKeySet, grad, self, scale, zero_point, quant_min, 
quant_max, grad_factor); + } + + // aten::fake_quantize_per_channel_affine(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max) -> Tensor + inline at::Tensor fake_quantize_per_channel_affine(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, int64_t axis, int64_t quant_min, int64_t quant_max) { + return at::_ops::fake_quantize_per_channel_affine::redispatch(dispatchKeySet, self, scale, zero_point, axis, quant_min, quant_max); + } + + // aten::fake_quantize_per_channel_affine_cachemask(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max) -> (Tensor output, Tensor mask) + inline ::std::tuple fake_quantize_per_channel_affine_cachemask(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, int64_t axis, int64_t quant_min, int64_t quant_max) { + return at::_ops::fake_quantize_per_channel_affine_cachemask::redispatch(dispatchKeySet, self, scale, zero_point, axis, quant_min, quant_max); + } + + // aten::fake_quantize_per_channel_affine_cachemask_backward(Tensor grad, Tensor mask) -> Tensor + inline at::Tensor fake_quantize_per_channel_affine_cachemask_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & mask) { + return at::_ops::fake_quantize_per_channel_affine_cachemask_backward::redispatch(dispatchKeySet, grad, mask); + } + + // aten::_fake_quantize_learnable_per_channel_affine(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max, float grad_factor=1.0) -> Tensor + inline at::Tensor _fake_quantize_learnable_per_channel_affine(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, int64_t axis, int64_t quant_min, int64_t quant_max, double grad_factor=1.0) { + return at::_ops::_fake_quantize_learnable_per_channel_affine::redispatch(dispatchKeySet, 
self, scale, zero_point, axis, quant_min, quant_max, grad_factor); + } + + // aten::_fake_quantize_learnable_per_channel_affine_backward(Tensor grad, Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max, float grad_factor=1.0) -> (Tensor, Tensor, Tensor) + inline ::std::tuple _fake_quantize_learnable_per_channel_affine_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, int64_t axis, int64_t quant_min, int64_t quant_max, double grad_factor=1.0) { + return at::_ops::_fake_quantize_learnable_per_channel_affine_backward::redispatch(dispatchKeySet, grad, self, scale, zero_point, axis, quant_min, quant_max, grad_factor); + } + + // aten::fused_moving_avg_obs_fake_quant(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor(a!) running_min, Tensor(b!) running_max, Tensor(c!) scale, Tensor(d!) zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> Tensor + inline at::Tensor fused_moving_avg_obs_fake_quant(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & observer_on, const at::Tensor & fake_quant_on, at::Tensor & running_min, at::Tensor & running_max, at::Tensor & scale, at::Tensor & zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, bool per_row_fake_quant=false, bool symmetric_quant=false) { + return at::_ops::fused_moving_avg_obs_fake_quant::redispatch(dispatchKeySet, self, observer_on, fake_quant_on, running_min, running_max, scale, zero_point, averaging_const, quant_min, quant_max, ch_axis, per_row_fake_quant, symmetric_quant); + } + + // aten::_fused_moving_avg_obs_fq_helper(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor(a!) running_min, Tensor(b!) running_max, Tensor(c!) scale, Tensor(d!) 
zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor output, Tensor mask) + inline ::std::tuple _fused_moving_avg_obs_fq_helper(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & observer_on, const at::Tensor & fake_quant_on, at::Tensor & running_min, at::Tensor & running_max, at::Tensor & scale, at::Tensor & zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, bool per_row_fake_quant=false, bool symmetric_quant=false) { + return at::_ops::_fused_moving_avg_obs_fq_helper::redispatch(dispatchKeySet, self, observer_on, fake_quant_on, running_min, running_max, scale, zero_point, averaging_const, quant_min, quant_max, ch_axis, per_row_fake_quant, symmetric_quant); + } + + // aten::_choose_qparams_per_tensor(Tensor self, bool reduce_range=False) -> (float, int) + inline ::std::tuple _choose_qparams_per_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool reduce_range=false) { + return at::_ops::_choose_qparams_per_tensor::redispatch(dispatchKeySet, self, reduce_range); + } + + // aten::_saturate_weight_to_fp16(Tensor weight) -> Tensor + inline at::Tensor _saturate_weight_to_fp16(c10::DispatchKeySet dispatchKeySet, const at::Tensor & weight) { + return at::_ops::_saturate_weight_to_fp16::redispatch(dispatchKeySet, weight); + } + + // aten::choose_qparams_optimized(Tensor input, int numel, int n_bins, float ratio, int bit_width) -> (Tensor, Tensor) + inline ::std::tuple choose_qparams_optimized(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, int64_t numel, int64_t n_bins, double ratio, int64_t bit_width) { + return at::_ops::choose_qparams_optimized::redispatch(dispatchKeySet, input, numel, n_bins, ratio, bit_width); + } + + // aten::_autocast_to_reduced_precision(Tensor(a) self, bool cuda_enabled, bool cpu_enabled, ScalarType cuda_dtype, ScalarType cpu_dtype) -> Tensor(a) + 
inline at::Tensor _autocast_to_reduced_precision(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool cuda_enabled, bool cpu_enabled, at::ScalarType cuda_dtype, at::ScalarType cpu_dtype) { + return at::_ops::_autocast_to_reduced_precision::redispatch(dispatchKeySet, self, cuda_enabled, cpu_enabled, cuda_dtype, cpu_dtype); + } + + // aten::_autocast_to_full_precision(Tensor(a) self, bool cuda_enabled, bool cpu_enabled) -> Tensor(a) + inline at::Tensor _autocast_to_full_precision(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool cuda_enabled, bool cpu_enabled) { + return at::_ops::_autocast_to_full_precision::redispatch(dispatchKeySet, self, cuda_enabled, cpu_enabled); + } + + // aten::_to_copy(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, MemoryFormat? memory_format=None) -> Tensor + inline at::Tensor _to_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::TensorOptions options={}, bool non_blocking=false, ::std::optional memory_format=::std::nullopt) { + return at::_ops::_to_copy::redispatch(dispatchKeySet, self, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), non_blocking, c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format)); + } + + // aten::_to_copy(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, MemoryFormat? 
memory_format=None) -> Tensor + inline at::Tensor _to_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, bool non_blocking, ::std::optional memory_format) { + return at::_ops::_to_copy::redispatch(dispatchKeySet, self, dtype, layout, device, pin_memory, non_blocking, memory_format); + } + + // aten::to.dtype_layout(Tensor(a) self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a) + inline at::Tensor to(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::TensorOptions options={}, bool non_blocking=false, bool copy=false, ::std::optional memory_format=::std::nullopt) { + return at::_ops::to_dtype_layout::redispatch(dispatchKeySet, self, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), non_blocking, copy, c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format)); + } + + // aten::to.dtype_layout(Tensor(a) self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a) + inline at::Tensor to(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, bool non_blocking, bool copy, ::std::optional memory_format) { + return at::_ops::to_dtype_layout::redispatch(dispatchKeySet, self, dtype, layout, device, pin_memory, non_blocking, copy, memory_format); + } + + // aten::to.device(Tensor(a) self, Device device, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? 
memory_format=None) -> Tensor(a) + inline at::Tensor to(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Device device, at::ScalarType dtype, bool non_blocking=false, bool copy=false, ::std::optional memory_format=::std::nullopt) { + return at::_ops::to_device::redispatch(dispatchKeySet, self, device, dtype, non_blocking, copy, memory_format); + } + + // aten::to.dtype(Tensor(a) self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a) + inline at::Tensor to(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::ScalarType dtype, bool non_blocking=false, bool copy=false, ::std::optional memory_format=::std::nullopt) { + return at::_ops::to_dtype::redispatch(dispatchKeySet, self, dtype, non_blocking, copy, memory_format); + } + + // aten::to.other(Tensor(a) self, Tensor other, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a) + inline at::Tensor to(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, bool non_blocking=false, bool copy=false, ::std::optional memory_format=::std::nullopt) { + return at::_ops::to_other::redispatch(dispatchKeySet, self, other, non_blocking, copy, memory_format); + } + + // aten::meshgrid(Tensor[] tensors) -> Tensor[] + inline ::std::vector meshgrid(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors) { + return at::_ops::meshgrid::redispatch(dispatchKeySet, tensors); + } + + // aten::meshgrid.indexing(Tensor[] tensors, *, str indexing) -> Tensor[] + inline ::std::vector meshgrid(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors, c10::string_view indexing) { + return at::_ops::meshgrid_indexing::redispatch(dispatchKeySet, tensors, indexing); + } + + // aten::cartesian_prod(Tensor[] tensors) -> Tensor + inline at::Tensor cartesian_prod(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors) { + return at::_ops::cartesian_prod::redispatch(dispatchKeySet, tensors); + } 
+ + // aten::combinations(Tensor self, int r=2, bool with_replacement=False) -> Tensor + inline at::Tensor combinations(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t r=2, bool with_replacement=false) { + return at::_ops::combinations::redispatch(dispatchKeySet, self, r, with_replacement); + } + + // aten::item(Tensor self) -> Scalar + inline at::Scalar item(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::item::redispatch(dispatchKeySet, self); + } + + // aten::result_type.Tensor(Tensor tensor, Tensor other) -> ScalarType + inline at::ScalarType result_type(c10::DispatchKeySet dispatchKeySet, const at::Tensor & tensor, const at::Tensor & other) { + return at::_ops::result_type_Tensor::redispatch(dispatchKeySet, tensor, other); + } + + // aten::result_type.Scalar(Tensor tensor, Scalar other) -> ScalarType + inline at::ScalarType result_type(c10::DispatchKeySet dispatchKeySet, const at::Tensor & tensor, const at::Scalar & other) { + return at::_ops::result_type_Scalar::redispatch(dispatchKeySet, tensor, other); + } + + // aten::result_type.Scalar_Tensor(Scalar scalar, Tensor tensor) -> ScalarType + inline at::ScalarType result_type(c10::DispatchKeySet dispatchKeySet, const at::Scalar & scalar, const at::Tensor & tensor) { + return at::_ops::result_type_Scalar_Tensor::redispatch(dispatchKeySet, scalar, tensor); + } + + // aten::result_type.Scalar_Scalar(Scalar scalar1, Scalar scalar2) -> ScalarType + inline at::ScalarType result_type(c10::DispatchKeySet dispatchKeySet, const at::Scalar & scalar1, const at::Scalar & scalar2) { + return at::_ops::result_type_Scalar_Scalar::redispatch(dispatchKeySet, scalar1, scalar2); + } + + // aten::can_cast(ScalarType from_, ScalarType to) -> bool + inline bool can_cast(c10::DispatchKeySet dispatchKeySet, at::ScalarType from_, at::ScalarType to) { + return at::_ops::can_cast::redispatch(dispatchKeySet, from_, to); + } + + // aten::promote_types(ScalarType type1, ScalarType 
type2) -> ScalarType + inline at::ScalarType promote_types(c10::DispatchKeySet dispatchKeySet, at::ScalarType type1, at::ScalarType type2) { + return at::_ops::promote_types::redispatch(dispatchKeySet, type1, type2); + } + + // aten::_local_scalar_dense(Tensor self) -> Scalar + inline at::Scalar _local_scalar_dense(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::_local_scalar_dense::redispatch(dispatchKeySet, self); + } + + // aten::_lstm_mps(Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor) + inline ::std::tuple _lstm_mps(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::TensorList hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional, bool batch_first) { + return at::_ops::_lstm_mps::redispatch(dispatchKeySet, input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first); + } + + // aten::lstm_mps_backward(Tensor? grad_y, Tensor? grad_hy, Tensor? 
grad_cy, Tensor z_state, Tensor cell_state_fwd, Tensor input, Tensor layersOutputs, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor[], Tensor[]) + inline ::std::tuple,::std::vector> lstm_mps_backward(c10::DispatchKeySet dispatchKeySet, const ::std::optional & grad_y, const ::std::optional & grad_hy, const ::std::optional & grad_cy, const at::Tensor & z_state, const at::Tensor & cell_state_fwd, const at::Tensor & input, const at::Tensor & layersOutputs, at::TensorList hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional, bool batch_first) { + return at::_ops::lstm_mps_backward::redispatch(dispatchKeySet, grad_y, grad_hy, grad_cy, z_state, cell_state_fwd, input, layersOutputs, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first); + } + + // aten::_thnn_fused_lstm_cell(Tensor input_gates, Tensor hidden_gates, Tensor cx, Tensor? input_bias=None, Tensor? hidden_bias=None) -> (Tensor, Tensor, Tensor) + inline ::std::tuple _thnn_fused_lstm_cell(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input_gates, const at::Tensor & hidden_gates, const at::Tensor & cx, const ::std::optional & input_bias={}, const ::std::optional & hidden_bias={}) { + return at::_ops::_thnn_fused_lstm_cell::redispatch(dispatchKeySet, input_gates, hidden_gates, cx, input_bias, hidden_bias); + } + + // aten::_thnn_fused_lstm_cell_backward_impl(Tensor? grad_hy, Tensor? 
grad_cy, Tensor cx, Tensor cy, Tensor workspace, bool has_bias) -> (Tensor, Tensor, Tensor) + inline ::std::tuple _thnn_fused_lstm_cell_backward_impl(c10::DispatchKeySet dispatchKeySet, const ::std::optional & grad_hy, const ::std::optional & grad_cy, const at::Tensor & cx, const at::Tensor & cy, const at::Tensor & workspace, bool has_bias) { + return at::_ops::_thnn_fused_lstm_cell_backward_impl::redispatch(dispatchKeySet, grad_hy, grad_cy, cx, cy, workspace, has_bias); + } + + // aten::_thnn_fused_lstm_cell_backward(Tensor? grad_hy, Tensor? grad_cy, Tensor cx, Tensor cy, Tensor workspace, bool has_bias) -> (Tensor, Tensor, Tensor, Tensor, Tensor) + inline ::std::tuple _thnn_fused_lstm_cell_backward(c10::DispatchKeySet dispatchKeySet, const ::std::optional & grad_hy, const ::std::optional & grad_cy, const at::Tensor & cx, const at::Tensor & cy, const at::Tensor & workspace, bool has_bias) { + return at::_ops::_thnn_fused_lstm_cell_backward::redispatch(dispatchKeySet, grad_hy, grad_cy, cx, cy, workspace, has_bias); + } + + // aten::_thnn_differentiable_lstm_cell_backward(Tensor? grad_hy, Tensor? grad_cy, Tensor input_gates, Tensor hidden_gates, Tensor? input_bias, Tensor? hidden_bias, Tensor cx, Tensor cy) -> (Tensor, Tensor, Tensor, Tensor, Tensor) + inline ::std::tuple _thnn_differentiable_lstm_cell_backward(c10::DispatchKeySet dispatchKeySet, const ::std::optional & grad_hy, const ::std::optional & grad_cy, const at::Tensor & input_gates, const at::Tensor & hidden_gates, const ::std::optional & input_bias, const ::std::optional & hidden_bias, const at::Tensor & cx, const at::Tensor & cy) { + return at::_ops::_thnn_differentiable_lstm_cell_backward::redispatch(dispatchKeySet, grad_hy, grad_cy, input_gates, hidden_gates, input_bias, hidden_bias, cx, cy); + } + + // aten::_thnn_fused_gru_cell(Tensor input_gates, Tensor hidden_gates, Tensor hx, Tensor? input_bias=None, Tensor? 
hidden_bias=None) -> (Tensor, Tensor) + inline ::std::tuple _thnn_fused_gru_cell(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input_gates, const at::Tensor & hidden_gates, const at::Tensor & hx, const ::std::optional & input_bias={}, const ::std::optional & hidden_bias={}) { + return at::_ops::_thnn_fused_gru_cell::redispatch(dispatchKeySet, input_gates, hidden_gates, hx, input_bias, hidden_bias); + } + + // aten::_thnn_fused_gru_cell_backward(Tensor grad_hy, Tensor workspace, bool has_bias) -> (Tensor, Tensor, Tensor, Tensor, Tensor) + inline ::std::tuple _thnn_fused_gru_cell_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_hy, const at::Tensor & workspace, bool has_bias) { + return at::_ops::_thnn_fused_gru_cell_backward::redispatch(dispatchKeySet, grad_hy, workspace, has_bias); + } + + // aten::_thnn_differentiable_gru_cell_backward(Tensor grad_hy, Tensor input_gates, Tensor hidden_gates, Tensor hx, Tensor? input_bias, Tensor? hidden_bias) -> (Tensor, Tensor, Tensor, Tensor, Tensor) + inline ::std::tuple _thnn_differentiable_gru_cell_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_hy, const at::Tensor & input_gates, const at::Tensor & hidden_gates, const at::Tensor & hx, const ::std::optional & input_bias, const ::std::optional & hidden_bias) { + return at::_ops::_thnn_differentiable_gru_cell_backward::redispatch(dispatchKeySet, grad_hy, input_gates, hidden_gates, hx, input_bias, hidden_bias); + } + + // aten::lstm.input(Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor, Tensor) + inline ::std::tuple lstm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::TensorList hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional, bool batch_first) { + return at::_ops::lstm_input::redispatch(dispatchKeySet, input, hx, params, has_biases, 
num_layers, dropout, train, bidirectional, batch_first); + } + + // aten::lstm.data(Tensor data, Tensor batch_sizes, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional) -> (Tensor, Tensor, Tensor) + inline ::std::tuple lstm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & data, const at::Tensor & batch_sizes, at::TensorList hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional) { + return at::_ops::lstm_data::redispatch(dispatchKeySet, data, batch_sizes, hx, params, has_biases, num_layers, dropout, train, bidirectional); + } + + // aten::gru.input(Tensor input, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor) + inline ::std::tuple gru(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional, bool batch_first) { + return at::_ops::gru_input::redispatch(dispatchKeySet, input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first); + } + + // aten::gru.data(Tensor data, Tensor batch_sizes, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional) -> (Tensor, Tensor) + inline ::std::tuple gru(c10::DispatchKeySet dispatchKeySet, const at::Tensor & data, const at::Tensor & batch_sizes, const at::Tensor & hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional) { + return at::_ops::gru_data::redispatch(dispatchKeySet, data, batch_sizes, hx, params, has_biases, num_layers, dropout, train, bidirectional); + } + + // aten::rnn_tanh.input(Tensor input, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor) + 
inline ::std::tuple rnn_tanh(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional, bool batch_first) { + return at::_ops::rnn_tanh_input::redispatch(dispatchKeySet, input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first); + } + + // aten::rnn_tanh.data(Tensor data, Tensor batch_sizes, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional) -> (Tensor, Tensor) + inline ::std::tuple rnn_tanh(c10::DispatchKeySet dispatchKeySet, const at::Tensor & data, const at::Tensor & batch_sizes, const at::Tensor & hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional) { + return at::_ops::rnn_tanh_data::redispatch(dispatchKeySet, data, batch_sizes, hx, params, has_biases, num_layers, dropout, train, bidirectional); + } + + // aten::rnn_relu.input(Tensor input, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor) + inline ::std::tuple rnn_relu(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional, bool batch_first) { + return at::_ops::rnn_relu_input::redispatch(dispatchKeySet, input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first); + } + + // aten::rnn_relu.data(Tensor data, Tensor batch_sizes, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional) -> (Tensor, Tensor) + inline ::std::tuple rnn_relu(c10::DispatchKeySet dispatchKeySet, const at::Tensor & data, const at::Tensor & batch_sizes, const at::Tensor & hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool 
train, bool bidirectional) { + return at::_ops::rnn_relu_data::redispatch(dispatchKeySet, data, batch_sizes, hx, params, has_biases, num_layers, dropout, train, bidirectional); + } + + // aten::lstm_cell(Tensor input, Tensor[] hx, Tensor w_ih, Tensor w_hh, Tensor? b_ih=None, Tensor? b_hh=None) -> (Tensor, Tensor) + inline ::std::tuple lstm_cell(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::TensorList hx, const at::Tensor & w_ih, const at::Tensor & w_hh, const ::std::optional & b_ih={}, const ::std::optional & b_hh={}) { + return at::_ops::lstm_cell::redispatch(dispatchKeySet, input, hx, w_ih, w_hh, b_ih, b_hh); + } + + // aten::gru_cell(Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor? b_ih=None, Tensor? b_hh=None) -> Tensor + inline at::Tensor gru_cell(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & hx, const at::Tensor & w_ih, const at::Tensor & w_hh, const ::std::optional & b_ih={}, const ::std::optional & b_hh={}) { + return at::_ops::gru_cell::redispatch(dispatchKeySet, input, hx, w_ih, w_hh, b_ih, b_hh); + } + + // aten::rnn_tanh_cell(Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor? b_ih=None, Tensor? b_hh=None) -> Tensor + inline at::Tensor rnn_tanh_cell(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & hx, const at::Tensor & w_ih, const at::Tensor & w_hh, const ::std::optional & b_ih={}, const ::std::optional & b_hh={}) { + return at::_ops::rnn_tanh_cell::redispatch(dispatchKeySet, input, hx, w_ih, w_hh, b_ih, b_hh); + } + + // aten::rnn_relu_cell(Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor? b_ih=None, Tensor? 
b_hh=None) -> Tensor + inline at::Tensor rnn_relu_cell(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & hx, const at::Tensor & w_ih, const at::Tensor & w_hh, const ::std::optional & b_ih={}, const ::std::optional & b_hh={}) { + return at::_ops::rnn_relu_cell::redispatch(dispatchKeySet, input, hx, w_ih, w_hh, b_ih, b_hh); + } + + // aten::quantized_lstm_cell(Tensor input, Tensor[] hx, Tensor w_ih, Tensor w_hh, Tensor b_ih, Tensor b_hh, Tensor packed_ih, Tensor packed_hh, Tensor col_offsets_ih, Tensor col_offsets_hh, Scalar scale_ih, Scalar scale_hh, Scalar zero_point_ih, Scalar zero_point_hh) -> (Tensor, Tensor) + inline ::std::tuple quantized_lstm_cell(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::TensorList hx, const at::Tensor & w_ih, const at::Tensor & w_hh, const at::Tensor & b_ih, const at::Tensor & b_hh, const at::Tensor & packed_ih, const at::Tensor & packed_hh, const at::Tensor & col_offsets_ih, const at::Tensor & col_offsets_hh, const at::Scalar & scale_ih, const at::Scalar & scale_hh, const at::Scalar & zero_point_ih, const at::Scalar & zero_point_hh) { + return at::_ops::quantized_lstm_cell::redispatch(dispatchKeySet, input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih, col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh); + } + + // aten::quantized_gru_cell(Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor b_ih, Tensor b_hh, Tensor packed_ih, Tensor packed_hh, Tensor col_offsets_ih, Tensor col_offsets_hh, Scalar scale_ih, Scalar scale_hh, Scalar zero_point_ih, Scalar zero_point_hh) -> Tensor + inline at::Tensor quantized_gru_cell(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & hx, const at::Tensor & w_ih, const at::Tensor & w_hh, const at::Tensor & b_ih, const at::Tensor & b_hh, const at::Tensor & packed_ih, const at::Tensor & packed_hh, const at::Tensor & col_offsets_ih, const at::Tensor & col_offsets_hh, const at::Scalar & 
scale_ih, const at::Scalar & scale_hh, const at::Scalar & zero_point_ih, const at::Scalar & zero_point_hh) { + return at::_ops::quantized_gru_cell::redispatch(dispatchKeySet, input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih, col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh); + } + + // aten::quantized_rnn_relu_cell(Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor b_ih, Tensor b_hh, Tensor packed_ih, Tensor packed_hh, Tensor col_offsets_ih, Tensor col_offsets_hh, Scalar scale_ih, Scalar scale_hh, Scalar zero_point_ih, Scalar zero_point_hh) -> Tensor + inline at::Tensor quantized_rnn_relu_cell(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & hx, const at::Tensor & w_ih, const at::Tensor & w_hh, const at::Tensor & b_ih, const at::Tensor & b_hh, const at::Tensor & packed_ih, const at::Tensor & packed_hh, const at::Tensor & col_offsets_ih, const at::Tensor & col_offsets_hh, const at::Scalar & scale_ih, const at::Scalar & scale_hh, const at::Scalar & zero_point_ih, const at::Scalar & zero_point_hh) { + return at::_ops::quantized_rnn_relu_cell::redispatch(dispatchKeySet, input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih, col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh); + } + + // aten::quantized_rnn_tanh_cell(Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor b_ih, Tensor b_hh, Tensor packed_ih, Tensor packed_hh, Tensor col_offsets_ih, Tensor col_offsets_hh, Scalar scale_ih, Scalar scale_hh, Scalar zero_point_ih, Scalar zero_point_hh) -> Tensor + inline at::Tensor quantized_rnn_tanh_cell(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & hx, const at::Tensor & w_ih, const at::Tensor & w_hh, const at::Tensor & b_ih, const at::Tensor & b_hh, const at::Tensor & packed_ih, const at::Tensor & packed_hh, const at::Tensor & col_offsets_ih, const at::Tensor & col_offsets_hh, const at::Scalar & scale_ih, const at::Scalar & 
scale_hh, const at::Scalar & zero_point_ih, const at::Scalar & zero_point_hh) { + return at::_ops::quantized_rnn_tanh_cell::redispatch(dispatchKeySet, input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih, col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh); + } + + // aten::_pack_padded_sequence(Tensor input, Tensor lengths, bool batch_first) -> (Tensor, Tensor) + inline ::std::tuple _pack_padded_sequence(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & lengths, bool batch_first) { + return at::_ops::_pack_padded_sequence::redispatch(dispatchKeySet, input, lengths, batch_first); + } + + // aten::_pack_padded_sequence_backward(Tensor grad, SymInt[] input_size, Tensor batch_sizes, bool batch_first) -> Tensor + inline at::Tensor _pack_padded_sequence_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, at::IntArrayRef input_size, const at::Tensor & batch_sizes, bool batch_first) { + return at::_ops::_pack_padded_sequence_backward::redispatch(dispatchKeySet, grad, c10::fromIntArrayRefSlow(input_size), batch_sizes, batch_first); + } + + // aten::_pack_padded_sequence_backward(Tensor grad, SymInt[] input_size, Tensor batch_sizes, bool batch_first) -> Tensor + inline at::Tensor _pack_padded_sequence_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, c10::SymIntArrayRef input_size, const at::Tensor & batch_sizes, bool batch_first) { + return at::_ops::_pack_padded_sequence_backward::redispatch(dispatchKeySet, grad, input_size, batch_sizes, batch_first); + } + + // aten::_pad_packed_sequence(Tensor data, Tensor batch_sizes, bool batch_first, Scalar padding_value, int total_length) -> (Tensor, Tensor) + inline ::std::tuple _pad_packed_sequence(c10::DispatchKeySet dispatchKeySet, const at::Tensor & data, const at::Tensor & batch_sizes, bool batch_first, const at::Scalar & padding_value, int64_t total_length) { + return 
at::_ops::_pad_packed_sequence::redispatch(dispatchKeySet, data, batch_sizes, batch_first, padding_value, total_length); + } + + // aten::set_.source_Storage(Tensor(a!) self, Storage source) -> Tensor(a!) + inline at::Tensor & set_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, at::Storage source) { + return at::_ops::set__source_Storage::redispatch(dispatchKeySet, self, source); + } + + // aten::set_.source_Storage_storage_offset(Tensor(a!) self, Storage source, SymInt storage_offset, SymInt[] size, SymInt[] stride=[]) -> Tensor(a!) + inline at::Tensor & set_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, at::Storage source, int64_t storage_offset, at::IntArrayRef size, at::IntArrayRef stride={}) { + return at::_ops::set__source_Storage_storage_offset::redispatch(dispatchKeySet, self, source, storage_offset, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride)); + } + + // aten::set_.source_Storage_storage_offset(Tensor(a!) self, Storage source, SymInt storage_offset, SymInt[] size, SymInt[] stride=[]) -> Tensor(a!) + inline at::Tensor & set__symint(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, at::Storage source, c10::SymInt storage_offset, c10::SymIntArrayRef size, c10::SymIntArrayRef stride={}) { + return at::_ops::set__source_Storage_storage_offset::redispatch(dispatchKeySet, self, source, storage_offset, size, stride); + } + + // aten::set_.source_Tensor_storage_offset(Tensor(a!) self, Tensor source, SymInt storage_offset, SymInt[] size, SymInt[] stride=[]) -> Tensor(a!) + inline at::Tensor & set_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & source, int64_t storage_offset, at::IntArrayRef size, at::IntArrayRef stride={}) { + return at::_ops::set__source_Tensor_storage_offset::redispatch(dispatchKeySet, self, source, storage_offset, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride)); + } + + // aten::set_.source_Tensor_storage_offset(Tensor(a!) 
self, Tensor source, SymInt storage_offset, SymInt[] size, SymInt[] stride=[]) -> Tensor(a!) + inline at::Tensor & set__symint(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & source, c10::SymInt storage_offset, c10::SymIntArrayRef size, c10::SymIntArrayRef stride={}) { + return at::_ops::set__source_Tensor_storage_offset::redispatch(dispatchKeySet, self, source, storage_offset, size, stride); + } + + // aten::set_.source_Tensor(Tensor(a!) self, Tensor source) -> Tensor(a!) + inline at::Tensor & set_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & source) { + return at::_ops::set__source_Tensor::redispatch(dispatchKeySet, self, source); + } + + // aten::set_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & set_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::set_::redispatch(dispatchKeySet, self); + } + + // aten::lift(Tensor self) -> Tensor + inline at::Tensor lift(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::lift::redispatch(dispatchKeySet, self); + } + + // aten::lift_fresh(Tensor(a) self) -> Tensor(a) + inline at::Tensor lift_fresh(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::lift_fresh::redispatch(dispatchKeySet, self); + } + + // aten::lift_fresh_copy(Tensor self) -> Tensor + inline at::Tensor lift_fresh_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::lift_fresh_copy::redispatch(dispatchKeySet, self); + } + + // aten::is_set_to(Tensor self, Tensor tensor) -> bool + inline bool is_set_to(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & tensor) { + return at::_ops::is_set_to::redispatch(dispatchKeySet, self, tensor); + } + + // aten::masked_fill_.Scalar(Tensor(a!) self, Tensor mask, Scalar value) -> Tensor(a!) 
+ inline at::Tensor & masked_fill_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & mask, const at::Scalar & value) { + return at::_ops::masked_fill__Scalar::redispatch(dispatchKeySet, self, mask, value); + } + + // aten::masked_fill.Scalar(Tensor self, Tensor mask, Scalar value) -> Tensor + inline at::Tensor masked_fill(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mask, const at::Scalar & value) { + return at::_ops::masked_fill_Scalar::redispatch(dispatchKeySet, self, mask, value); + } + + // aten::masked_fill_.Tensor(Tensor(a!) self, Tensor mask, Tensor value) -> Tensor(a!) + inline at::Tensor & masked_fill_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & mask, const at::Tensor & value) { + return at::_ops::masked_fill__Tensor::redispatch(dispatchKeySet, self, mask, value); + } + + // aten::masked_fill.Tensor(Tensor self, Tensor mask, Tensor value) -> Tensor + inline at::Tensor masked_fill(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mask, const at::Tensor & value) { + return at::_ops::masked_fill_Tensor::redispatch(dispatchKeySet, self, mask, value); + } + + // aten::masked_scatter_(Tensor(a!) self, Tensor mask, Tensor source) -> Tensor(a!) 
+ inline at::Tensor & masked_scatter_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & mask, const at::Tensor & source) { + return at::_ops::masked_scatter_::redispatch(dispatchKeySet, self, mask, source); + } + + // aten::masked_scatter(Tensor self, Tensor mask, Tensor source) -> Tensor + inline at::Tensor masked_scatter(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mask, const at::Tensor & source) { + return at::_ops::masked_scatter::redispatch(dispatchKeySet, self, mask, source); + } + + // aten::masked_scatter_backward(Tensor grad_output, Tensor mask, SymInt[] sizes) -> Tensor + inline at::Tensor masked_scatter_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & mask, at::IntArrayRef sizes) { + return at::_ops::masked_scatter_backward::redispatch(dispatchKeySet, grad_output, mask, c10::fromIntArrayRefSlow(sizes)); + } + + // aten::masked_scatter_backward(Tensor grad_output, Tensor mask, SymInt[] sizes) -> Tensor + inline at::Tensor masked_scatter_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & mask, c10::SymIntArrayRef sizes) { + return at::_ops::masked_scatter_backward::redispatch(dispatchKeySet, grad_output, mask, sizes); + } + + // aten::_masked_softmax(Tensor self, Tensor mask, int? dim=None, int? mask_type=None) -> Tensor + inline at::Tensor _masked_softmax(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mask, ::std::optional dim=::std::nullopt, ::std::optional mask_type=::std::nullopt) { + return at::_ops::_masked_softmax::redispatch(dispatchKeySet, self, mask, dim, mask_type); + } + + // aten::_masked_softmax_backward(Tensor grad_output, Tensor output, Tensor mask, int? 
dim=None) -> Tensor + inline at::Tensor _masked_softmax_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & output, const at::Tensor & mask, ::std::optional dim=::std::nullopt) { + return at::_ops::_masked_softmax_backward::redispatch(dispatchKeySet, grad_output, output, mask, dim); + } + + // aten::view(Tensor(a) self, SymInt[] size) -> Tensor(a) + inline at::Tensor view(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size) { + return at::_ops::view::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size)); + } + + // aten::view(Tensor(a) self, SymInt[] size) -> Tensor(a) + inline at::Tensor view_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size) { + return at::_ops::view::redispatch(dispatchKeySet, self, size); + } + + // aten::view.dtype(Tensor(a) self, ScalarType dtype) -> Tensor(a) + inline at::Tensor view(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::ScalarType dtype) { + return at::_ops::view_dtype::redispatch(dispatchKeySet, self, dtype); + } + + // aten::put_(Tensor(a!) self, Tensor index, Tensor source, bool accumulate=False) -> Tensor(a!) + inline at::Tensor & put_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & index, const at::Tensor & source, bool accumulate=false) { + return at::_ops::put_::redispatch(dispatchKeySet, self, index, source, accumulate); + } + + // aten::put(Tensor self, Tensor index, Tensor source, bool accumulate=False) -> Tensor + inline at::Tensor put(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & index, const at::Tensor & source, bool accumulate=false) { + return at::_ops::put::redispatch(dispatchKeySet, self, index, source, accumulate); + } + + // aten::index_add.out(Tensor self, int dim, Tensor index, Tensor source, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & index_add_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & source, const at::Scalar & alpha=1) { + return at::_ops::index_add_out::redispatch(dispatchKeySet, self, dim, index, source, alpha, out); + } + + // aten::index_add.out(Tensor self, int dim, Tensor index, Tensor source, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & index_add_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & source, const at::Scalar & alpha, at::Tensor & out) { + return at::_ops::index_add_out::redispatch(dispatchKeySet, self, dim, index, source, alpha, out); + } + + // aten::index_add_(Tensor(a!) self, int dim, Tensor index, Tensor source, *, Scalar alpha=1) -> Tensor(a!) + inline at::Tensor & index_add_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & source, const at::Scalar & alpha=1) { + return at::_ops::index_add_::redispatch(dispatchKeySet, self, dim, index, source, alpha); + } + + // aten::index_add(Tensor self, int dim, Tensor index, Tensor source, *, Scalar alpha=1) -> Tensor + inline at::Tensor index_add(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & source, const at::Scalar & alpha=1) { + return at::_ops::index_add::redispatch(dispatchKeySet, self, dim, index, source, alpha); + } + + // aten::index_add.dimname(Tensor self, Dimname dim, Tensor index, Tensor source, *, Scalar alpha=1) -> Tensor + inline at::Tensor index_add(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, const at::Tensor & index, const at::Tensor & source, const at::Scalar & alpha=1) { + return at::_ops::index_add_dimname::redispatch(dispatchKeySet, self, dim, index, source, alpha); + } + + // aten::index_reduce.out(Tensor self, int dim, 
Tensor index, Tensor source, str reduce, *, bool include_self=True, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & index_reduce_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & source, c10::string_view reduce, bool include_self=true) { + return at::_ops::index_reduce_out::redispatch(dispatchKeySet, self, dim, index, source, reduce, include_self, out); + } + + // aten::index_reduce.out(Tensor self, int dim, Tensor index, Tensor source, str reduce, *, bool include_self=True, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & index_reduce_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & source, c10::string_view reduce, bool include_self, at::Tensor & out) { + return at::_ops::index_reduce_out::redispatch(dispatchKeySet, self, dim, index, source, reduce, include_self, out); + } + + // aten::index_reduce_(Tensor(a!) self, int dim, Tensor index, Tensor source, str reduce, *, bool include_self=True) -> Tensor(a!) + inline at::Tensor & index_reduce_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & source, c10::string_view reduce, bool include_self=true) { + return at::_ops::index_reduce_::redispatch(dispatchKeySet, self, dim, index, source, reduce, include_self); + } + + // aten::index_reduce(Tensor self, int dim, Tensor index, Tensor source, str reduce, *, bool include_self=True) -> Tensor + inline at::Tensor index_reduce(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & source, c10::string_view reduce, bool include_self=true) { + return at::_ops::index_reduce::redispatch(dispatchKeySet, self, dim, index, source, reduce, include_self); + } + + // aten::index_fill_.int_Scalar(Tensor(a!) self, int dim, Tensor index, Scalar value) -> Tensor(a!) 
+ inline at::Tensor & index_fill_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Scalar & value) { + return at::_ops::index_fill__int_Scalar::redispatch(dispatchKeySet, self, dim, index, value); + } + + // aten::index_fill.int_Scalar(Tensor self, int dim, Tensor index, Scalar value) -> Tensor + inline at::Tensor index_fill(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Scalar & value) { + return at::_ops::index_fill_int_Scalar::redispatch(dispatchKeySet, self, dim, index, value); + } + + // aten::index_fill_.int_Tensor(Tensor(a!) self, int dim, Tensor index, Tensor value) -> Tensor(a!) + inline at::Tensor & index_fill_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & value) { + return at::_ops::index_fill__int_Tensor::redispatch(dispatchKeySet, self, dim, index, value); + } + + // aten::index_fill.int_Tensor(Tensor self, int dim, Tensor index, Tensor value) -> Tensor + inline at::Tensor index_fill(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & value) { + return at::_ops::index_fill_int_Tensor::redispatch(dispatchKeySet, self, dim, index, value); + } + + // aten::index_fill_.Dimname_Scalar(Tensor(a!) self, Dimname dim, Tensor index, Scalar value) -> Tensor(a!) + inline at::Tensor & index_fill_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, at::Dimname dim, const at::Tensor & index, const at::Scalar & value) { + return at::_ops::index_fill__Dimname_Scalar::redispatch(dispatchKeySet, self, dim, index, value); + } + + // aten::index_fill_.Dimname_Tensor(Tensor(a!) self, Dimname dim, Tensor index, Tensor value) -> Tensor(a!) 
+ inline at::Tensor & index_fill_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, at::Dimname dim, const at::Tensor & index, const at::Tensor & value) { + return at::_ops::index_fill__Dimname_Tensor::redispatch(dispatchKeySet, self, dim, index, value); + } + + // aten::index_fill.Dimname_Scalar(Tensor self, Dimname dim, Tensor index, Scalar value) -> Tensor + inline at::Tensor index_fill(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, const at::Tensor & index, const at::Scalar & value) { + return at::_ops::index_fill_Dimname_Scalar::redispatch(dispatchKeySet, self, dim, index, value); + } + + // aten::index_fill.Dimname_Tensor(Tensor self, Dimname dim, Tensor index, Tensor value) -> Tensor + inline at::Tensor index_fill(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, const at::Tensor & index, const at::Tensor & value) { + return at::_ops::index_fill_Dimname_Tensor::redispatch(dispatchKeySet, self, dim, index, value); + } + + // aten::scatter.src(Tensor self, int dim, Tensor index, Tensor src) -> Tensor + inline at::Tensor scatter(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src) { + return at::_ops::scatter_src::redispatch(dispatchKeySet, self, dim, index, src); + } + + // aten::scatter_.src(Tensor(a!) self, int dim, Tensor index, Tensor src) -> Tensor(a!) + inline at::Tensor & scatter_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src) { + return at::_ops::scatter__src::redispatch(dispatchKeySet, self, dim, index, src); + } + + // aten::scatter.src_out(Tensor self, int dim, Tensor index, Tensor src, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & scatter_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src) { + return at::_ops::scatter_src_out::redispatch(dispatchKeySet, self, dim, index, src, out); + } + + // aten::scatter.src_out(Tensor self, int dim, Tensor index, Tensor src, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & scatter_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src, at::Tensor & out) { + return at::_ops::scatter_src_out::redispatch(dispatchKeySet, self, dim, index, src, out); + } + + // aten::scatter.value(Tensor self, int dim, Tensor index, Scalar value) -> Tensor + inline at::Tensor scatter(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Scalar & value) { + return at::_ops::scatter_value::redispatch(dispatchKeySet, self, dim, index, value); + } + + // aten::scatter_.value(Tensor(a!) self, int dim, Tensor index, Scalar value) -> Tensor(a!) + inline at::Tensor & scatter_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Scalar & value) { + return at::_ops::scatter__value::redispatch(dispatchKeySet, self, dim, index, value); + } + + // aten::scatter.value_out(Tensor self, int dim, Tensor index, Scalar value, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & scatter_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Scalar & value) { + return at::_ops::scatter_value_out::redispatch(dispatchKeySet, self, dim, index, value, out); + } + + // aten::scatter.value_out(Tensor self, int dim, Tensor index, Scalar value, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & scatter_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Scalar & value, at::Tensor & out) { + return at::_ops::scatter_value_out::redispatch(dispatchKeySet, self, dim, index, value, out); + } + + // aten::scatter.reduce(Tensor self, int dim, Tensor index, Tensor src, *, str reduce) -> Tensor + inline at::Tensor scatter(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src, c10::string_view reduce) { + return at::_ops::scatter_reduce::redispatch(dispatchKeySet, self, dim, index, src, reduce); + } + + // aten::scatter_.reduce(Tensor(a!) self, int dim, Tensor index, Tensor src, *, str reduce) -> Tensor(a!) + inline at::Tensor & scatter_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src, c10::string_view reduce) { + return at::_ops::scatter__reduce::redispatch(dispatchKeySet, self, dim, index, src, reduce); + } + + // aten::scatter.reduce_out(Tensor self, int dim, Tensor index, Tensor src, *, str reduce, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & scatter_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src, c10::string_view reduce) { + return at::_ops::scatter_reduce_out::redispatch(dispatchKeySet, self, dim, index, src, reduce, out); + } + + // aten::scatter.reduce_out(Tensor self, int dim, Tensor index, Tensor src, *, str reduce, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & scatter_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src, c10::string_view reduce, at::Tensor & out) { + return at::_ops::scatter_reduce_out::redispatch(dispatchKeySet, self, dim, index, src, reduce, out); + } + + // aten::scatter.value_reduce(Tensor self, int dim, Tensor index, Scalar value, *, str reduce) -> Tensor + inline at::Tensor scatter(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Scalar & value, c10::string_view reduce) { + return at::_ops::scatter_value_reduce::redispatch(dispatchKeySet, self, dim, index, value, reduce); + } + + // aten::scatter_.value_reduce(Tensor(a!) self, int dim, Tensor index, Scalar value, *, str reduce) -> Tensor(a!) + inline at::Tensor & scatter_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Scalar & value, c10::string_view reduce) { + return at::_ops::scatter__value_reduce::redispatch(dispatchKeySet, self, dim, index, value, reduce); + } + + // aten::scatter.value_reduce_out(Tensor self, int dim, Tensor index, Scalar value, *, str reduce, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & scatter_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Scalar & value, c10::string_view reduce) { + return at::_ops::scatter_value_reduce_out::redispatch(dispatchKeySet, self, dim, index, value, reduce, out); + } + + // aten::scatter.value_reduce_out(Tensor self, int dim, Tensor index, Scalar value, *, str reduce, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & scatter_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Scalar & value, c10::string_view reduce, at::Tensor & out) { + return at::_ops::scatter_value_reduce_out::redispatch(dispatchKeySet, self, dim, index, value, reduce, out); + } + + // aten::scatter.dimname_src(Tensor self, Dimname dim, Tensor index, Tensor src) -> Tensor + inline at::Tensor scatter(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, const at::Tensor & index, const at::Tensor & src) { + return at::_ops::scatter_dimname_src::redispatch(dispatchKeySet, self, dim, index, src); + } + + // aten::scatter.dimname_value(Tensor self, Dimname dim, Tensor index, Scalar value) -> Tensor + inline at::Tensor scatter(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, const at::Tensor & index, const at::Scalar & value) { + return at::_ops::scatter_dimname_value::redispatch(dispatchKeySet, self, dim, index, value); + } + + // aten::scatter_add(Tensor self, int dim, Tensor index, Tensor src) -> Tensor + inline at::Tensor scatter_add(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src) { + return at::_ops::scatter_add::redispatch(dispatchKeySet, self, dim, index, src); + } + + // aten::scatter_add_(Tensor(a!) self, int dim, Tensor index, Tensor src) -> Tensor(a!) + inline at::Tensor & scatter_add_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src) { + return at::_ops::scatter_add_::redispatch(dispatchKeySet, self, dim, index, src); + } + + // aten::scatter_add.out(Tensor self, int dim, Tensor index, Tensor src, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & scatter_add_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src) { + return at::_ops::scatter_add_out::redispatch(dispatchKeySet, self, dim, index, src, out); + } + + // aten::scatter_add.out(Tensor self, int dim, Tensor index, Tensor src, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & scatter_add_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src, at::Tensor & out) { + return at::_ops::scatter_add_out::redispatch(dispatchKeySet, self, dim, index, src, out); + } + + // aten::scatter_add.dimname(Tensor self, Dimname dim, Tensor index, Tensor src) -> Tensor + inline at::Tensor scatter_add(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, const at::Tensor & index, const at::Tensor & src) { + return at::_ops::scatter_add_dimname::redispatch(dispatchKeySet, self, dim, index, src); + } + + // aten::scatter_reduce.two(Tensor self, int dim, Tensor index, Tensor src, str reduce, *, bool include_self=True) -> Tensor + inline at::Tensor scatter_reduce(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src, c10::string_view reduce, bool include_self=true) { + return at::_ops::scatter_reduce_two::redispatch(dispatchKeySet, self, dim, index, src, reduce, include_self); + } + + // aten::scatter_reduce_.two(Tensor(a!) self, int dim, Tensor index, Tensor src, str reduce, *, bool include_self=True) -> Tensor(a!) 
+ inline at::Tensor & scatter_reduce_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src, c10::string_view reduce, bool include_self=true) { + return at::_ops::scatter_reduce__two::redispatch(dispatchKeySet, self, dim, index, src, reduce, include_self); + } + + // aten::scatter_reduce.two_out(Tensor self, int dim, Tensor index, Tensor src, str reduce, *, bool include_self=True, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & scatter_reduce_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src, c10::string_view reduce, bool include_self=true) { + return at::_ops::scatter_reduce_two_out::redispatch(dispatchKeySet, self, dim, index, src, reduce, include_self, out); + } + + // aten::scatter_reduce.two_out(Tensor self, int dim, Tensor index, Tensor src, str reduce, *, bool include_self=True, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & scatter_reduce_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src, c10::string_view reduce, bool include_self, at::Tensor & out) { + return at::_ops::scatter_reduce_two_out::redispatch(dispatchKeySet, self, dim, index, src, reduce, include_self, out); + } + + // aten::eq_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + inline at::Tensor & eq_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) { + return at::_ops::eq__Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::eq_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + inline at::Tensor & eq_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) { + return at::_ops::eq__Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::bitwise_and.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & bitwise_and_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::bitwise_and_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::bitwise_and.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bitwise_and_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::bitwise_and_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::bitwise_and.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bitwise_and_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::bitwise_and_Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::bitwise_and.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & bitwise_and_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, at::Tensor & out) { + return at::_ops::bitwise_and_Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::bitwise_and.Scalar(Tensor self, Scalar other) -> Tensor + inline at::Tensor bitwise_and(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::bitwise_and_Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::bitwise_and.Scalar_Tensor(Scalar self, Tensor other) -> Tensor + inline at::Tensor bitwise_and(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, const at::Tensor & other) { + return at::_ops::bitwise_and_Scalar_Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::bitwise_and.Tensor(Tensor self, Tensor other) -> Tensor + inline at::Tensor bitwise_and(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::bitwise_and_Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::bitwise_and_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + inline at::Tensor & bitwise_and_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) { + return at::_ops::bitwise_and__Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::bitwise_and_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) 
+ inline at::Tensor & bitwise_and_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) { + return at::_ops::bitwise_and__Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::__and__.Scalar(Tensor self, Scalar other) -> Tensor + inline at::Tensor __and__(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::__and___Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::__and__.Tensor(Tensor self, Tensor other) -> Tensor + inline at::Tensor __and__(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::__and___Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::__iand__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + inline at::Tensor & __iand__(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) { + return at::_ops::__iand___Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::__iand__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + inline at::Tensor & __iand__(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) { + return at::_ops::__iand___Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::bitwise_or.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bitwise_or_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::bitwise_or_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::bitwise_or.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & bitwise_or_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::bitwise_or_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::bitwise_or.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bitwise_or_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::bitwise_or_Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::bitwise_or.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bitwise_or_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, at::Tensor & out) { + return at::_ops::bitwise_or_Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::bitwise_or.Scalar(Tensor self, Scalar other) -> Tensor + inline at::Tensor bitwise_or(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::bitwise_or_Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::bitwise_or.Scalar_Tensor(Scalar self, Tensor other) -> Tensor + inline at::Tensor bitwise_or(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, const at::Tensor & other) { + return at::_ops::bitwise_or_Scalar_Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::bitwise_or.Tensor(Tensor self, Tensor other) -> Tensor + inline at::Tensor bitwise_or(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::bitwise_or_Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::bitwise_or_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) 
+ inline at::Tensor & bitwise_or_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) { + return at::_ops::bitwise_or__Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::bitwise_or_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + inline at::Tensor & bitwise_or_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) { + return at::_ops::bitwise_or__Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::__or__.Scalar(Tensor self, Scalar other) -> Tensor + inline at::Tensor __or__(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::__or___Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::__or__.Tensor(Tensor self, Tensor other) -> Tensor + inline at::Tensor __or__(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::__or___Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::__ior__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + inline at::Tensor & __ior__(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) { + return at::_ops::__ior___Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::__ior__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + inline at::Tensor & __ior__(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) { + return at::_ops::__ior___Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::bitwise_xor.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bitwise_xor_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::bitwise_xor_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::bitwise_xor.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & bitwise_xor_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::bitwise_xor_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::bitwise_xor.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bitwise_xor_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::bitwise_xor_Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::bitwise_xor.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bitwise_xor_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, at::Tensor & out) { + return at::_ops::bitwise_xor_Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::bitwise_xor.Scalar(Tensor self, Scalar other) -> Tensor + inline at::Tensor bitwise_xor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::bitwise_xor_Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::bitwise_xor.Scalar_Tensor(Scalar self, Tensor other) -> Tensor + inline at::Tensor bitwise_xor(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, const at::Tensor & other) { + return at::_ops::bitwise_xor_Scalar_Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::bitwise_xor.Tensor(Tensor self, Tensor other) -> Tensor + inline at::Tensor bitwise_xor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::bitwise_xor_Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::bitwise_xor_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) 
+ inline at::Tensor & bitwise_xor_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) { + return at::_ops::bitwise_xor__Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::bitwise_xor_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + inline at::Tensor & bitwise_xor_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) { + return at::_ops::bitwise_xor__Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::__xor__.Scalar(Tensor self, Scalar other) -> Tensor + inline at::Tensor __xor__(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::__xor___Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::__xor__.Tensor(Tensor self, Tensor other) -> Tensor + inline at::Tensor __xor__(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::__xor___Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::__ixor__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + inline at::Tensor & __ixor__(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) { + return at::_ops::__ixor___Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::__ixor__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) 
+ inline at::Tensor & __ixor__(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) { + return at::_ops::__ixor___Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::__lshift__.Scalar(Tensor self, Scalar other) -> Tensor + inline at::Tensor __lshift__(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::__lshift___Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::__lshift__.Tensor(Tensor self, Tensor other) -> Tensor + inline at::Tensor __lshift__(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::__lshift___Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::__ilshift__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + inline at::Tensor & __ilshift__(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) { + return at::_ops::__ilshift___Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::__ilshift__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + inline at::Tensor & __ilshift__(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) { + return at::_ops::__ilshift___Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::bitwise_left_shift.Tensor(Tensor self, Tensor other) -> Tensor + inline at::Tensor bitwise_left_shift(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::bitwise_left_shift_Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::bitwise_left_shift_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + inline at::Tensor & bitwise_left_shift_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) { + return at::_ops::bitwise_left_shift__Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::bitwise_left_shift.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & bitwise_left_shift_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::bitwise_left_shift_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::bitwise_left_shift.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bitwise_left_shift_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::bitwise_left_shift_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::bitwise_left_shift.Tensor_Scalar(Tensor self, Scalar other) -> Tensor + inline at::Tensor bitwise_left_shift(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::bitwise_left_shift_Tensor_Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::bitwise_left_shift_.Tensor_Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + inline at::Tensor & bitwise_left_shift_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) { + return at::_ops::bitwise_left_shift__Tensor_Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::bitwise_left_shift.Tensor_Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bitwise_left_shift_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::bitwise_left_shift_Tensor_Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::bitwise_left_shift.Tensor_Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & bitwise_left_shift_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, at::Tensor & out) { + return at::_ops::bitwise_left_shift_Tensor_Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::bitwise_left_shift.Scalar_Tensor(Scalar self, Tensor other) -> Tensor + inline at::Tensor bitwise_left_shift(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, const at::Tensor & other) { + return at::_ops::bitwise_left_shift_Scalar_Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::__rshift__.Scalar(Tensor self, Scalar other) -> Tensor + inline at::Tensor __rshift__(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::__rshift___Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::__rshift__.Tensor(Tensor self, Tensor other) -> Tensor + inline at::Tensor __rshift__(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::__rshift___Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::__irshift__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + inline at::Tensor & __irshift__(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) { + return at::_ops::__irshift___Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::__irshift__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) 
+ inline at::Tensor & __irshift__(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) { + return at::_ops::__irshift___Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::bitwise_right_shift.Tensor(Tensor self, Tensor other) -> Tensor + inline at::Tensor bitwise_right_shift(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::bitwise_right_shift_Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::bitwise_right_shift_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + inline at::Tensor & bitwise_right_shift_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) { + return at::_ops::bitwise_right_shift__Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::bitwise_right_shift.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bitwise_right_shift_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::bitwise_right_shift_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::bitwise_right_shift.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bitwise_right_shift_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::bitwise_right_shift_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::bitwise_right_shift.Tensor_Scalar(Tensor self, Scalar other) -> Tensor + inline at::Tensor bitwise_right_shift(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::bitwise_right_shift_Tensor_Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::bitwise_right_shift_.Tensor_Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) 
+ inline at::Tensor & bitwise_right_shift_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) { + return at::_ops::bitwise_right_shift__Tensor_Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::bitwise_right_shift.Tensor_Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bitwise_right_shift_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::bitwise_right_shift_Tensor_Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::bitwise_right_shift.Tensor_Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bitwise_right_shift_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, at::Tensor & out) { + return at::_ops::bitwise_right_shift_Tensor_Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::bitwise_right_shift.Scalar_Tensor(Scalar self, Tensor other) -> Tensor + inline at::Tensor bitwise_right_shift(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, const at::Tensor & other) { + return at::_ops::bitwise_right_shift_Scalar_Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::tril_(Tensor(a!) self, SymInt diagonal=0) -> Tensor(a!) + inline at::Tensor & tril_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, int64_t diagonal=0) { + return at::_ops::tril_::redispatch(dispatchKeySet, self, diagonal); + } + + // aten::tril_(Tensor(a!) self, SymInt diagonal=0) -> Tensor(a!) + inline at::Tensor & tril__symint(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, c10::SymInt diagonal=0) { + return at::_ops::tril_::redispatch(dispatchKeySet, self, diagonal); + } + + // aten::triu_(Tensor(a!) self, SymInt diagonal=0) -> Tensor(a!) 
+ inline at::Tensor & triu_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, int64_t diagonal=0) { + return at::_ops::triu_::redispatch(dispatchKeySet, self, diagonal); + } + + // aten::triu_(Tensor(a!) self, SymInt diagonal=0) -> Tensor(a!) + inline at::Tensor & triu__symint(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, c10::SymInt diagonal=0) { + return at::_ops::triu_::redispatch(dispatchKeySet, self, diagonal); + } + + // aten::digamma_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & digamma_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::digamma_::redispatch(dispatchKeySet, self); + } + + // aten::lerp_.Scalar(Tensor(a!) self, Tensor end, Scalar weight) -> Tensor(a!) + inline at::Tensor & lerp_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & end, const at::Scalar & weight) { + return at::_ops::lerp__Scalar::redispatch(dispatchKeySet, self, end, weight); + } + + // aten::lerp_.Tensor(Tensor(a!) self, Tensor end, Tensor weight) -> Tensor(a!) + inline at::Tensor & lerp_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & end, const at::Tensor & weight) { + return at::_ops::lerp__Tensor::redispatch(dispatchKeySet, self, end, weight); + } + + // aten::addbmm_(Tensor(a!) self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!) + inline at::Tensor & addbmm_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta=1, const at::Scalar & alpha=1) { + return at::_ops::addbmm_::redispatch(dispatchKeySet, self, batch1, batch2, beta, alpha); + } + + // aten::addbmm.out(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & addbmm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta=1, const at::Scalar & alpha=1) { + return at::_ops::addbmm_out::redispatch(dispatchKeySet, self, batch1, batch2, beta, alpha, out); + } + + // aten::addbmm.out(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & addbmm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out) { + return at::_ops::addbmm_out::redispatch(dispatchKeySet, self, batch1, batch2, beta, alpha, out); + } + + // aten::addbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor + inline at::Tensor addbmm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta=1, const at::Scalar & alpha=1) { + return at::_ops::addbmm::redispatch(dispatchKeySet, self, batch1, batch2, beta, alpha); + } + + // aten::random_.from(Tensor(a!) self, int from, int? to, *, Generator? generator=None) -> Tensor(a!) + inline at::Tensor & random_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, int64_t from, ::std::optional to, ::std::optional generator=::std::nullopt) { + return at::_ops::random__from::redispatch(dispatchKeySet, self, from, to, generator); + } + + // aten::random_.to(Tensor(a!) self, int to, *, Generator? generator=None) -> Tensor(a!) + inline at::Tensor & random_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, int64_t to, ::std::optional generator=::std::nullopt) { + return at::_ops::random__to::redispatch(dispatchKeySet, self, to, generator); + } + + // aten::random_(Tensor(a!) self, *, Generator? generator=None) -> Tensor(a!) 
+ inline at::Tensor & random_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, ::std::optional generator=::std::nullopt) { + return at::_ops::random_::redispatch(dispatchKeySet, self, generator); + } + + // aten::uniform_(Tensor(a!) self, float from=0, float to=1, *, Generator? generator=None) -> Tensor(a!) + inline at::Tensor & uniform_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, double from=0, double to=1, ::std::optional generator=::std::nullopt) { + return at::_ops::uniform_::redispatch(dispatchKeySet, self, from, to, generator); + } + + // aten::cauchy_(Tensor(a!) self, float median=0, float sigma=1, *, Generator? generator=None) -> Tensor(a!) + inline at::Tensor & cauchy_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, double median=0, double sigma=1, ::std::optional generator=::std::nullopt) { + return at::_ops::cauchy_::redispatch(dispatchKeySet, self, median, sigma, generator); + } + + // aten::log_normal_(Tensor(a!) self, float mean=1, float std=2, *, Generator? generator=None) -> Tensor(a!) + inline at::Tensor & log_normal_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, double mean=1, double std=2, ::std::optional generator=::std::nullopt) { + return at::_ops::log_normal_::redispatch(dispatchKeySet, self, mean, std, generator); + } + + // aten::exponential_(Tensor(a!) self, float lambd=1, *, Generator? generator=None) -> Tensor(a!) + inline at::Tensor & exponential_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, double lambd=1, ::std::optional generator=::std::nullopt) { + return at::_ops::exponential_::redispatch(dispatchKeySet, self, lambd, generator); + } + + // aten::geometric_(Tensor(a!) self, float p, *, Generator? generator=None) -> Tensor(a!) 
+ inline at::Tensor & geometric_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, double p, ::std::optional generator=::std::nullopt) { + return at::_ops::geometric_::redispatch(dispatchKeySet, self, p, generator); + } + + // aten::diag.out(Tensor self, int diagonal=0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & diag_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t diagonal=0) { + return at::_ops::diag_out::redispatch(dispatchKeySet, self, diagonal, out); + } + + // aten::diag.out(Tensor self, int diagonal=0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & diag_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t diagonal, at::Tensor & out) { + return at::_ops::diag_out::redispatch(dispatchKeySet, self, diagonal, out); + } + + // aten::diag(Tensor self, int diagonal=0) -> Tensor + inline at::Tensor diag(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t diagonal=0) { + return at::_ops::diag::redispatch(dispatchKeySet, self, diagonal); + } + + // aten::cross.out(Tensor self, Tensor other, int? dim=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & cross_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other, ::std::optional dim=::std::nullopt) { + return at::_ops::cross_out::redispatch(dispatchKeySet, self, other, dim, out); + } + + // aten::cross.out(Tensor self, Tensor other, int? dim=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & cross_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, ::std::optional dim, at::Tensor & out) { + return at::_ops::cross_out::redispatch(dispatchKeySet, self, other, dim, out); + } + + // aten::cross(Tensor self, Tensor other, int? 
dim=None) -> Tensor + inline at::Tensor cross(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, ::std::optional dim=::std::nullopt) { + return at::_ops::cross::redispatch(dispatchKeySet, self, other, dim); + } + + // aten::triu.out(Tensor self, SymInt diagonal=0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & triu_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t diagonal=0) { + return at::_ops::triu_out::redispatch(dispatchKeySet, self, diagonal, out); + } + + // aten::triu.out(Tensor self, SymInt diagonal=0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & triu_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t diagonal, at::Tensor & out) { + return at::_ops::triu_out::redispatch(dispatchKeySet, self, diagonal, out); + } + + // aten::triu.out(Tensor self, SymInt diagonal=0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & triu_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymInt diagonal=0) { + return at::_ops::triu_out::redispatch(dispatchKeySet, self, diagonal, out); + } + + // aten::triu.out(Tensor self, SymInt diagonal=0, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & triu_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymInt diagonal, at::Tensor & out) { + return at::_ops::triu_out::redispatch(dispatchKeySet, self, diagonal, out); + } + + // aten::triu(Tensor self, SymInt diagonal=0) -> Tensor + inline at::Tensor triu(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t diagonal=0) { + return at::_ops::triu::redispatch(dispatchKeySet, self, diagonal); + } + + // aten::triu(Tensor self, SymInt diagonal=0) -> Tensor + inline at::Tensor triu_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymInt diagonal=0) { + return at::_ops::triu::redispatch(dispatchKeySet, self, diagonal); + } + + // aten::tril.out(Tensor self, SymInt diagonal=0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & tril_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t diagonal=0) { + return at::_ops::tril_out::redispatch(dispatchKeySet, self, diagonal, out); + } + + // aten::tril.out(Tensor self, SymInt diagonal=0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & tril_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t diagonal, at::Tensor & out) { + return at::_ops::tril_out::redispatch(dispatchKeySet, self, diagonal, out); + } + + // aten::tril.out(Tensor self, SymInt diagonal=0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & tril_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymInt diagonal=0) { + return at::_ops::tril_out::redispatch(dispatchKeySet, self, diagonal, out); + } + + // aten::tril.out(Tensor self, SymInt diagonal=0, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & tril_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymInt diagonal, at::Tensor & out) { + return at::_ops::tril_out::redispatch(dispatchKeySet, self, diagonal, out); + } + + // aten::tril(Tensor self, SymInt diagonal=0) -> Tensor + inline at::Tensor tril(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t diagonal=0) { + return at::_ops::tril::redispatch(dispatchKeySet, self, diagonal); + } + + // aten::tril(Tensor self, SymInt diagonal=0) -> Tensor + inline at::Tensor tril_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymInt diagonal=0) { + return at::_ops::tril::redispatch(dispatchKeySet, self, diagonal); + } + + // aten::tril_indices(int row, int col, int offset=0, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor tril_indices(c10::DispatchKeySet dispatchKeySet, int64_t row, int64_t col, int64_t offset=0, at::TensorOptions options=at::kLong) { + return at::_ops::tril_indices::redispatch(dispatchKeySet, row, col, offset, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::tril_indices(int row, int col, int offset=0, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor tril_indices(c10::DispatchKeySet dispatchKeySet, int64_t row, int64_t col, int64_t offset, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::tril_indices::redispatch(dispatchKeySet, row, col, offset, dtype, layout, device, pin_memory); + } + + // aten::triu_indices(int row, int col, int offset=0, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor + inline at::Tensor triu_indices(c10::DispatchKeySet dispatchKeySet, int64_t row, int64_t col, int64_t offset=0, at::TensorOptions options=at::kLong) { + return at::_ops::triu_indices::redispatch(dispatchKeySet, row, col, offset, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::triu_indices(int row, int col, int offset=0, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor triu_indices(c10::DispatchKeySet dispatchKeySet, int64_t row, int64_t col, int64_t offset, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::triu_indices::redispatch(dispatchKeySet, row, col, offset, dtype, layout, device, pin_memory); + } + + // aten::trace(Tensor self) -> Tensor + inline at::Tensor trace(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::trace::redispatch(dispatchKeySet, self); + } + + // aten::trace_backward(Tensor grad, SymInt[] sizes) -> Tensor + inline at::Tensor trace_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, at::IntArrayRef sizes) { + return at::_ops::trace_backward::redispatch(dispatchKeySet, grad, c10::fromIntArrayRefSlow(sizes)); + } + + // aten::trace_backward(Tensor grad, SymInt[] sizes) -> Tensor + inline at::Tensor trace_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, c10::SymIntArrayRef sizes) { + return at::_ops::trace_backward::redispatch(dispatchKeySet, grad, sizes); + } + + // aten::ne.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & ne_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::ne_Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::ne.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & ne_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, at::Tensor & out) { + return at::_ops::ne_Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::ne.Scalar(Tensor self, Scalar other) -> Tensor + inline at::Tensor ne(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::ne_Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::ne.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & ne_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::ne_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::ne.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & ne_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::ne_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::ne.Tensor(Tensor self, Tensor other) -> Tensor + inline at::Tensor ne(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::ne_Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::ne_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + inline at::Tensor & ne_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) { + return at::_ops::ne__Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::ne_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) 
+ inline at::Tensor & ne_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) { + return at::_ops::ne__Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::not_equal.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & not_equal_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::not_equal_Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::not_equal.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & not_equal_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, at::Tensor & out) { + return at::_ops::not_equal_Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::not_equal.Scalar(Tensor self, Scalar other) -> Tensor + inline at::Tensor not_equal(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::not_equal_Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::not_equal.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & not_equal_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::not_equal_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::not_equal.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & not_equal_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::not_equal_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::not_equal.Tensor(Tensor self, Tensor other) -> Tensor + inline at::Tensor not_equal(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::not_equal_Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::not_equal_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + inline at::Tensor & not_equal_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) { + return at::_ops::not_equal__Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::not_equal_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + inline at::Tensor & not_equal_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) { + return at::_ops::not_equal__Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::eq.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & eq_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::eq_Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::eq.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & eq_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, at::Tensor & out) { + return at::_ops::eq_Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::eq.Scalar(Tensor self, Scalar other) -> Tensor + inline at::Tensor eq(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::eq_Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::eq.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) 
out) -> Tensor(a!) + inline at::Tensor & eq_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::eq_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::eq.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & eq_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::eq_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::eq.Tensor(Tensor self, Tensor other) -> Tensor + inline at::Tensor eq(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::eq_Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::ge.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & ge_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::ge_Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::ge.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & ge_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, at::Tensor & out) { + return at::_ops::ge_Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::ge.Scalar(Tensor self, Scalar other) -> Tensor + inline at::Tensor ge(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::ge_Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::ge.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & ge_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::ge_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::ge.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & ge_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::ge_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::ge.Tensor(Tensor self, Tensor other) -> Tensor + inline at::Tensor ge(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::ge_Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::ge_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + inline at::Tensor & ge_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) { + return at::_ops::ge__Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::ge_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + inline at::Tensor & ge_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) { + return at::_ops::ge__Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::greater_equal.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & greater_equal_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::greater_equal_Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::greater_equal.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & greater_equal_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, at::Tensor & out) { + return at::_ops::greater_equal_Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::greater_equal.Scalar(Tensor self, Scalar other) -> Tensor + inline at::Tensor greater_equal(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::greater_equal_Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::greater_equal.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & greater_equal_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::greater_equal_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::greater_equal.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & greater_equal_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::greater_equal_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::greater_equal.Tensor(Tensor self, Tensor other) -> Tensor + inline at::Tensor greater_equal(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::greater_equal_Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::greater_equal_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + inline at::Tensor & greater_equal_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) { + return at::_ops::greater_equal__Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::greater_equal_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) 
+ inline at::Tensor & greater_equal_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) { + return at::_ops::greater_equal__Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::le.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & le_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::le_Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::le.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & le_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, at::Tensor & out) { + return at::_ops::le_Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::le.Scalar(Tensor self, Scalar other) -> Tensor + inline at::Tensor le(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::le_Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::le.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & le_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::le_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::le.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & le_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::le_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::le.Tensor(Tensor self, Tensor other) -> Tensor + inline at::Tensor le(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::le_Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::le_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) 
+ inline at::Tensor & le_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) { + return at::_ops::le__Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::le_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + inline at::Tensor & le_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) { + return at::_ops::le__Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::less_equal.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & less_equal_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::less_equal_Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::less_equal.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & less_equal_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, at::Tensor & out) { + return at::_ops::less_equal_Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::less_equal.Scalar(Tensor self, Scalar other) -> Tensor + inline at::Tensor less_equal(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::less_equal_Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::less_equal.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & less_equal_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::less_equal_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::less_equal.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & less_equal_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::less_equal_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::less_equal.Tensor(Tensor self, Tensor other) -> Tensor + inline at::Tensor less_equal(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::less_equal_Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::less_equal_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + inline at::Tensor & less_equal_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) { + return at::_ops::less_equal__Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::less_equal_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + inline at::Tensor & less_equal_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) { + return at::_ops::less_equal__Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::gt.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & gt_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::gt_Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::gt.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & gt_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, at::Tensor & out) { + return at::_ops::gt_Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::gt.Scalar(Tensor self, Scalar other) -> Tensor + inline at::Tensor gt(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::gt_Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::gt.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) 
out) -> Tensor(a!) + inline at::Tensor & gt_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::gt_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::gt.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & gt_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::gt_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::gt.Tensor(Tensor self, Tensor other) -> Tensor + inline at::Tensor gt(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::gt_Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::gt_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + inline at::Tensor & gt_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) { + return at::_ops::gt__Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::gt_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + inline at::Tensor & gt_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) { + return at::_ops::gt__Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::greater.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & greater_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::greater_Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::greater.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & greater_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, at::Tensor & out) { + return at::_ops::greater_Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::greater.Scalar(Tensor self, Scalar other) -> Tensor + inline at::Tensor greater(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::greater_Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::greater.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & greater_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::greater_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::greater.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & greater_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::greater_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::greater.Tensor(Tensor self, Tensor other) -> Tensor + inline at::Tensor greater(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::greater_Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::greater_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + inline at::Tensor & greater_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) { + return at::_ops::greater__Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::greater_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) 
+ inline at::Tensor & greater_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) { + return at::_ops::greater__Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::lt.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & lt_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::lt_Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::lt.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & lt_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, at::Tensor & out) { + return at::_ops::lt_Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::lt.Scalar(Tensor self, Scalar other) -> Tensor + inline at::Tensor lt(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::lt_Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::lt.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & lt_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::lt_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::lt.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & lt_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::lt_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::lt.Tensor(Tensor self, Tensor other) -> Tensor + inline at::Tensor lt(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::lt_Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::lt_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) 
+ inline at::Tensor & lt_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) { + return at::_ops::lt__Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::lt_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + inline at::Tensor & lt_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) { + return at::_ops::lt__Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::less.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & less_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::less_Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::less.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & less_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, at::Tensor & out) { + return at::_ops::less_Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::less.Scalar(Tensor self, Scalar other) -> Tensor + inline at::Tensor less(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::less_Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::less.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & less_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::less_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::less.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & less_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::less_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::less.Tensor(Tensor self, Tensor other) -> Tensor + inline at::Tensor less(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::less_Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::less_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + inline at::Tensor & less_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) { + return at::_ops::less__Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::less_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + inline at::Tensor & less_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) { + return at::_ops::less__Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::take.out(Tensor self, Tensor index, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & take_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & index) { + return at::_ops::take_out::redispatch(dispatchKeySet, self, index, out); + } + + // aten::take.out(Tensor self, Tensor index, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & take_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & index, at::Tensor & out) { + return at::_ops::take_out::redispatch(dispatchKeySet, self, index, out); + } + + // aten::take(Tensor self, Tensor index) -> Tensor + inline at::Tensor take(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & index) { + return at::_ops::take::redispatch(dispatchKeySet, self, index); + } + + // aten::take_along_dim.out(Tensor self, Tensor indices, int? dim=None, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & take_along_dim_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & indices, ::std::optional dim=::std::nullopt) { + return at::_ops::take_along_dim_out::redispatch(dispatchKeySet, self, indices, dim, out); + } + + // aten::take_along_dim.out(Tensor self, Tensor indices, int? dim=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & take_along_dim_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & indices, ::std::optional dim, at::Tensor & out) { + return at::_ops::take_along_dim_out::redispatch(dispatchKeySet, self, indices, dim, out); + } + + // aten::take_along_dim(Tensor self, Tensor indices, int? dim=None) -> Tensor + inline at::Tensor take_along_dim(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & indices, ::std::optional dim=::std::nullopt) { + return at::_ops::take_along_dim::redispatch(dispatchKeySet, self, indices, dim); + } + + // aten::index_select.out(Tensor self, int dim, Tensor index, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & index_select_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim, const at::Tensor & index) { + return at::_ops::index_select_out::redispatch(dispatchKeySet, self, dim, index, out); + } + + // aten::index_select.out(Tensor self, int dim, Tensor index, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & index_select_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, const at::Tensor & index, at::Tensor & out) { + return at::_ops::index_select_out::redispatch(dispatchKeySet, self, dim, index, out); + } + + // aten::index_select(Tensor self, int dim, Tensor index) -> Tensor + inline at::Tensor index_select(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, const at::Tensor & index) { + return at::_ops::index_select::redispatch(dispatchKeySet, self, dim, index); + } + + // aten::index_select.dimname_out(Tensor self, Dimname dim, Tensor index, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & index_select_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::Dimname dim, const at::Tensor & index) { + return at::_ops::index_select_dimname_out::redispatch(dispatchKeySet, self, dim, index, out); + } + + // aten::index_select.dimname_out(Tensor self, Dimname dim, Tensor index, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & index_select_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, const at::Tensor & index, at::Tensor & out) { + return at::_ops::index_select_dimname_out::redispatch(dispatchKeySet, self, dim, index, out); + } + + // aten::index_select.dimname(Tensor self, Dimname dim, Tensor index) -> Tensor + inline at::Tensor index_select(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, const at::Tensor & index) { + return at::_ops::index_select_dimname::redispatch(dispatchKeySet, self, dim, index); + } + + // aten::index_select_backward(Tensor grad, SymInt[] self_sizes, int dim, Tensor index) -> Tensor + inline at::Tensor index_select_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, at::IntArrayRef self_sizes, int64_t dim, const at::Tensor & index) { + return at::_ops::index_select_backward::redispatch(dispatchKeySet, grad, c10::fromIntArrayRefSlow(self_sizes), dim, index); + } + + // aten::index_select_backward(Tensor grad, SymInt[] self_sizes, int dim, Tensor index) -> Tensor + inline at::Tensor index_select_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, c10::SymIntArrayRef self_sizes, int64_t dim, const at::Tensor & index) { + return at::_ops::index_select_backward::redispatch(dispatchKeySet, grad, self_sizes, dim, index); + } + + // aten::masked_select.out(Tensor self, Tensor mask, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & masked_select_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & mask) { + return at::_ops::masked_select_out::redispatch(dispatchKeySet, self, mask, out); + } + + // aten::masked_select.out(Tensor self, Tensor mask, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & masked_select_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mask, at::Tensor & out) { + return at::_ops::masked_select_out::redispatch(dispatchKeySet, self, mask, out); + } + + // aten::masked_select(Tensor self, Tensor mask) -> Tensor + inline at::Tensor masked_select(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mask) { + return at::_ops::masked_select::redispatch(dispatchKeySet, self, mask); + } + + // aten::masked_select_backward(Tensor grad, Tensor input, Tensor mask) -> Tensor + inline at::Tensor masked_select_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & input, const at::Tensor & mask) { + return at::_ops::masked_select_backward::redispatch(dispatchKeySet, grad, input, mask); + } + + // aten::nonzero.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & nonzero_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::nonzero_out::redispatch(dispatchKeySet, self, out); + } + + // aten::nonzero.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & nonzero_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::nonzero_out::redispatch(dispatchKeySet, self, out); + } + + // aten::nonzero(Tensor self) -> Tensor + inline at::Tensor nonzero(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::nonzero::redispatch(dispatchKeySet, self); + } + + // aten::nonzero_static.out(Tensor self, *, SymInt size, int fill_value=-1, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & nonzero_static_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t size, int64_t fill_value=-1) { + return at::_ops::nonzero_static_out::redispatch(dispatchKeySet, self, size, fill_value, out); + } + + // aten::nonzero_static.out(Tensor self, *, SymInt size, int fill_value=-1, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & nonzero_static_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t size, int64_t fill_value, at::Tensor & out) { + return at::_ops::nonzero_static_out::redispatch(dispatchKeySet, self, size, fill_value, out); + } + + // aten::nonzero_static.out(Tensor self, *, SymInt size, int fill_value=-1, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & nonzero_static_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymInt size, int64_t fill_value=-1) { + return at::_ops::nonzero_static_out::redispatch(dispatchKeySet, self, size, fill_value, out); + } + + // aten::nonzero_static.out(Tensor self, *, SymInt size, int fill_value=-1, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & nonzero_static_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymInt size, int64_t fill_value, at::Tensor & out) { + return at::_ops::nonzero_static_out::redispatch(dispatchKeySet, self, size, fill_value, out); + } + + // aten::nonzero_static(Tensor self, *, SymInt size, int fill_value=-1) -> Tensor + inline at::Tensor nonzero_static(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t size, int64_t fill_value=-1) { + return at::_ops::nonzero_static::redispatch(dispatchKeySet, self, size, fill_value); + } + + // aten::nonzero_static(Tensor self, *, SymInt size, int fill_value=-1) -> Tensor + inline at::Tensor nonzero_static_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymInt size, int64_t fill_value=-1) { + return at::_ops::nonzero_static::redispatch(dispatchKeySet, self, size, fill_value); + } + + // aten::nonzero_numpy(Tensor self) -> Tensor[] + inline ::std::vector nonzero_numpy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::nonzero_numpy::redispatch(dispatchKeySet, self); + } + + // aten::argwhere(Tensor self) -> Tensor + inline at::Tensor argwhere(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::argwhere::redispatch(dispatchKeySet, self); + } + + // aten::gather.out(Tensor self, int dim, Tensor index, *, bool sparse_grad=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & gather_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim, const at::Tensor & index, bool sparse_grad=false) { + return at::_ops::gather_out::redispatch(dispatchKeySet, self, dim, index, sparse_grad, out); + } + + // aten::gather.out(Tensor self, int dim, Tensor index, *, bool sparse_grad=False, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & gather_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, const at::Tensor & index, bool sparse_grad, at::Tensor & out) { + return at::_ops::gather_out::redispatch(dispatchKeySet, self, dim, index, sparse_grad, out); + } + + // aten::gather(Tensor self, int dim, Tensor index, *, bool sparse_grad=False) -> Tensor + inline at::Tensor gather(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, const at::Tensor & index, bool sparse_grad=false) { + return at::_ops::gather::redispatch(dispatchKeySet, self, dim, index, sparse_grad); + } + + // aten::gather_backward(Tensor grad, Tensor self, int dim, Tensor index, bool sparse_grad) -> Tensor + inline at::Tensor gather_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & self, int64_t dim, const at::Tensor & index, bool sparse_grad) { + return at::_ops::gather_backward::redispatch(dispatchKeySet, grad, self, dim, index, sparse_grad); + } + + // aten::gather.dimname_out(Tensor self, Dimname dim, Tensor index, *, bool sparse_grad=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & gather_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::Dimname dim, const at::Tensor & index, bool sparse_grad=false) { + return at::_ops::gather_dimname_out::redispatch(dispatchKeySet, self, dim, index, sparse_grad, out); + } + + // aten::gather.dimname_out(Tensor self, Dimname dim, Tensor index, *, bool sparse_grad=False, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & gather_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, const at::Tensor & index, bool sparse_grad, at::Tensor & out) { + return at::_ops::gather_dimname_out::redispatch(dispatchKeySet, self, dim, index, sparse_grad, out); + } + + // aten::gather.dimname(Tensor self, Dimname dim, Tensor index, *, bool sparse_grad=False) -> Tensor + inline at::Tensor gather(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, const at::Tensor & index, bool sparse_grad=false) { + return at::_ops::gather_dimname::redispatch(dispatchKeySet, self, dim, index, sparse_grad); + } + + // aten::_gather_sparse_backward(Tensor self, int dim, Tensor index, Tensor grad) -> Tensor + inline at::Tensor _gather_sparse_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & grad) { + return at::_ops::_gather_sparse_backward::redispatch(dispatchKeySet, self, dim, index, grad); + } + + // aten::addcmul.out(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & addcmul_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value=1) { + return at::_ops::addcmul_out::redispatch(dispatchKeySet, self, tensor1, tensor2, value, out); + } + + // aten::addcmul.out(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & addcmul_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value, at::Tensor & out) { + return at::_ops::addcmul_out::redispatch(dispatchKeySet, self, tensor1, tensor2, value, out); + } + + // aten::addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor + inline at::Tensor addcmul(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value=1) { + return at::_ops::addcmul::redispatch(dispatchKeySet, self, tensor1, tensor2, value); + } + + // aten::addcmul_(Tensor(a!) self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor(a!) + inline at::Tensor & addcmul_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value=1) { + return at::_ops::addcmul_::redispatch(dispatchKeySet, self, tensor1, tensor2, value); + } + + // aten::addcdiv.out(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & addcdiv_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value=1) { + return at::_ops::addcdiv_out::redispatch(dispatchKeySet, self, tensor1, tensor2, value, out); + } + + // aten::addcdiv.out(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & addcdiv_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value, at::Tensor & out) { + return at::_ops::addcdiv_out::redispatch(dispatchKeySet, self, tensor1, tensor2, value, out); + } + + // aten::addcdiv(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor + inline at::Tensor addcdiv(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value=1) { + return at::_ops::addcdiv::redispatch(dispatchKeySet, self, tensor1, tensor2, value); + } + + // aten::addcdiv_(Tensor(a!) self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor(a!) + inline at::Tensor & addcdiv_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value=1) { + return at::_ops::addcdiv_::redispatch(dispatchKeySet, self, tensor1, tensor2, value); + } + + // aten::cross_entropy_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, float label_smoothing=0.0) -> Tensor + inline at::Tensor cross_entropy_loss(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight={}, int64_t reduction=at::Reduction::Mean, int64_t ignore_index=-100, double label_smoothing=0.0) { + return at::_ops::cross_entropy_loss::redispatch(dispatchKeySet, self, target, weight, reduction, ignore_index, label_smoothing); + } + + // aten::cross_entropy_loss(Tensor self, Tensor target, Tensor? 
weight=None, int reduction=Mean, SymInt ignore_index=-100, float label_smoothing=0.0) -> Tensor + inline at::Tensor cross_entropy_loss_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight={}, int64_t reduction=at::Reduction::Mean, c10::SymInt ignore_index=-100, double label_smoothing=0.0) { + return at::_ops::cross_entropy_loss::redispatch(dispatchKeySet, self, target, weight, reduction, ignore_index, label_smoothing); + } + + // aten::triangular_solve.X(Tensor self, Tensor A, bool upper=True, bool transpose=False, bool unitriangular=False, *, Tensor(a!) X, Tensor(b!) M) -> (Tensor(a!) solution, Tensor(b!) cloned_coefficient) + inline ::std::tuple triangular_solve_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & X, at::Tensor & M, const at::Tensor & self, const at::Tensor & A, bool upper=true, bool transpose=false, bool unitriangular=false) { + return at::_ops::triangular_solve_X::redispatch(dispatchKeySet, self, A, upper, transpose, unitriangular, X, M); + } + + // aten::triangular_solve.X(Tensor self, Tensor A, bool upper=True, bool transpose=False, bool unitriangular=False, *, Tensor(a!) X, Tensor(b!) M) -> (Tensor(a!) solution, Tensor(b!) 
cloned_coefficient) + inline ::std::tuple triangular_solve_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & A, bool upper, bool transpose, bool unitriangular, at::Tensor & X, at::Tensor & M) { + return at::_ops::triangular_solve_X::redispatch(dispatchKeySet, self, A, upper, transpose, unitriangular, X, M); + } + + // aten::triangular_solve(Tensor self, Tensor A, bool upper=True, bool transpose=False, bool unitriangular=False) -> (Tensor solution, Tensor cloned_coefficient) + inline ::std::tuple triangular_solve(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & A, bool upper=true, bool transpose=false, bool unitriangular=false) { + return at::_ops::triangular_solve::redispatch(dispatchKeySet, self, A, upper, transpose, unitriangular); + } + + // aten::_linalg_check_errors(Tensor info, str api_name, *, bool is_matrix) -> () + inline void _linalg_check_errors(c10::DispatchKeySet dispatchKeySet, const at::Tensor & info, c10::string_view api_name, bool is_matrix) { + return at::_ops::_linalg_check_errors::redispatch(dispatchKeySet, info, api_name, is_matrix); + } + + // aten::linalg_solve_triangular.out(Tensor self, Tensor B, *, bool upper, bool left=True, bool unitriangular=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_solve_triangular_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & B, bool upper, bool left=true, bool unitriangular=false) { + return at::_ops::linalg_solve_triangular_out::redispatch(dispatchKeySet, self, B, upper, left, unitriangular, out); + } + + // aten::linalg_solve_triangular.out(Tensor self, Tensor B, *, bool upper, bool left=True, bool unitriangular=False, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & linalg_solve_triangular_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & B, bool upper, bool left, bool unitriangular, at::Tensor & out) { + return at::_ops::linalg_solve_triangular_out::redispatch(dispatchKeySet, self, B, upper, left, unitriangular, out); + } + + // aten::linalg_solve_triangular(Tensor self, Tensor B, *, bool upper, bool left=True, bool unitriangular=False) -> Tensor + inline at::Tensor linalg_solve_triangular(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & B, bool upper, bool left=true, bool unitriangular=false) { + return at::_ops::linalg_solve_triangular::redispatch(dispatchKeySet, self, B, upper, left, unitriangular); + } + + // aten::linalg_vander(Tensor x, *, SymInt? N=None) -> Tensor + inline at::Tensor linalg_vander(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, ::std::optional N=::std::nullopt) { + return at::_ops::linalg_vander::redispatch(dispatchKeySet, x, N.has_value() ? ::std::make_optional(c10::SymInt(*N)) : ::std::nullopt); + } + + // aten::linalg_vander(Tensor x, *, SymInt? N=None) -> Tensor + inline at::Tensor linalg_vander_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, ::std::optional N=::std::nullopt) { + return at::_ops::linalg_vander::redispatch(dispatchKeySet, x, N); + } + + // aten::svd.U(Tensor self, bool some=True, bool compute_uv=True, *, Tensor(a!) U, Tensor(b!) S, Tensor(c!) V) -> (Tensor(a!) U, Tensor(b!) S, Tensor(c!) V) + inline ::std::tuple svd_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & U, at::Tensor & S, at::Tensor & V, const at::Tensor & self, bool some=true, bool compute_uv=true) { + return at::_ops::svd_U::redispatch(dispatchKeySet, self, some, compute_uv, U, S, V); + } + + // aten::svd.U(Tensor self, bool some=True, bool compute_uv=True, *, Tensor(a!) U, Tensor(b!) S, Tensor(c!) V) -> (Tensor(a!) U, Tensor(b!) S, Tensor(c!) 
V) + inline ::std::tuple svd_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool some, bool compute_uv, at::Tensor & U, at::Tensor & S, at::Tensor & V) { + return at::_ops::svd_U::redispatch(dispatchKeySet, self, some, compute_uv, U, S, V); + } + + // aten::svd(Tensor self, bool some=True, bool compute_uv=True) -> (Tensor U, Tensor S, Tensor V) + inline ::std::tuple svd(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool some=true, bool compute_uv=true) { + return at::_ops::svd::redispatch(dispatchKeySet, self, some, compute_uv); + } + + // aten::swapaxes(Tensor(a) self, int axis0, int axis1) -> Tensor(a) + inline at::Tensor swapaxes(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t axis0, int64_t axis1) { + return at::_ops::swapaxes::redispatch(dispatchKeySet, self, axis0, axis1); + } + + // aten::swapaxes_(Tensor(a!) self, int axis0, int axis1) -> Tensor(a!) + inline at::Tensor & swapaxes_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, int64_t axis0, int64_t axis1) { + return at::_ops::swapaxes_::redispatch(dispatchKeySet, self, axis0, axis1); + } + + // aten::swapdims(Tensor(a) self, int dim0, int dim1) -> Tensor(a) + inline at::Tensor swapdims(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim0, int64_t dim1) { + return at::_ops::swapdims::redispatch(dispatchKeySet, self, dim0, dim1); + } + + // aten::swapdims_(Tensor(a!) self, int dim0, int dim1) -> Tensor(a!) + inline at::Tensor & swapdims_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, int64_t dim0, int64_t dim1) { + return at::_ops::swapdims_::redispatch(dispatchKeySet, self, dim0, dim1); + } + + // aten::cholesky.out(Tensor self, bool upper=False, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & cholesky_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, bool upper=false) { + return at::_ops::cholesky_out::redispatch(dispatchKeySet, self, upper, out); + } + + // aten::cholesky.out(Tensor self, bool upper=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & cholesky_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool upper, at::Tensor & out) { + return at::_ops::cholesky_out::redispatch(dispatchKeySet, self, upper, out); + } + + // aten::cholesky(Tensor self, bool upper=False) -> Tensor + inline at::Tensor cholesky(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool upper=false) { + return at::_ops::cholesky::redispatch(dispatchKeySet, self, upper); + } + + // aten::cholesky_solve.out(Tensor self, Tensor input2, bool upper=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & cholesky_solve_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & input2, bool upper=false) { + return at::_ops::cholesky_solve_out::redispatch(dispatchKeySet, self, input2, upper, out); + } + + // aten::cholesky_solve.out(Tensor self, Tensor input2, bool upper=False, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & cholesky_solve_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & input2, bool upper, at::Tensor & out) { + return at::_ops::cholesky_solve_out::redispatch(dispatchKeySet, self, input2, upper, out); + } + + // aten::cholesky_solve(Tensor self, Tensor input2, bool upper=False) -> Tensor + inline at::Tensor cholesky_solve(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & input2, bool upper=false) { + return at::_ops::cholesky_solve::redispatch(dispatchKeySet, self, input2, upper); + } + + // aten::_cholesky_solve_helper(Tensor self, Tensor A, bool upper) -> Tensor + inline at::Tensor _cholesky_solve_helper(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & A, bool upper) { + return at::_ops::_cholesky_solve_helper::redispatch(dispatchKeySet, self, A, upper); + } + + // aten::cholesky_inverse(Tensor self, bool upper=False) -> Tensor + inline at::Tensor cholesky_inverse(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool upper=false) { + return at::_ops::cholesky_inverse::redispatch(dispatchKeySet, self, upper); + } + + // aten::cholesky_inverse.out(Tensor self, bool upper=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & cholesky_inverse_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, bool upper=false) { + return at::_ops::cholesky_inverse_out::redispatch(dispatchKeySet, self, upper, out); + } + + // aten::cholesky_inverse.out(Tensor self, bool upper=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & cholesky_inverse_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool upper, at::Tensor & out) { + return at::_ops::cholesky_inverse_out::redispatch(dispatchKeySet, self, upper, out); + } + + // aten::qr.Q(Tensor self, bool some=True, *, Tensor(a!) Q, Tensor(b!) R) -> (Tensor(a!) Q, Tensor(b!) 
R) + inline ::std::tuple qr_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & Q, at::Tensor & R, const at::Tensor & self, bool some=true) { + return at::_ops::qr_Q::redispatch(dispatchKeySet, self, some, Q, R); + } + + // aten::qr.Q(Tensor self, bool some=True, *, Tensor(a!) Q, Tensor(b!) R) -> (Tensor(a!) Q, Tensor(b!) R) + inline ::std::tuple qr_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool some, at::Tensor & Q, at::Tensor & R) { + return at::_ops::qr_Q::redispatch(dispatchKeySet, self, some, Q, R); + } + + // aten::qr(Tensor self, bool some=True) -> (Tensor Q, Tensor R) + inline ::std::tuple qr(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool some=true) { + return at::_ops::qr::redispatch(dispatchKeySet, self, some); + } + + // aten::geqrf.a(Tensor self, *, Tensor(a!) a, Tensor(b!) tau) -> (Tensor(a!) a, Tensor(b!) tau) + inline ::std::tuple geqrf_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & a, at::Tensor & tau, const at::Tensor & self) { + return at::_ops::geqrf_a::redispatch(dispatchKeySet, self, a, tau); + } + + // aten::geqrf.a(Tensor self, *, Tensor(a!) a, Tensor(b!) tau) -> (Tensor(a!) a, Tensor(b!) tau) + inline ::std::tuple geqrf_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & a, at::Tensor & tau) { + return at::_ops::geqrf_a::redispatch(dispatchKeySet, self, a, tau); + } + + // aten::geqrf(Tensor self) -> (Tensor a, Tensor tau) + inline ::std::tuple geqrf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::geqrf::redispatch(dispatchKeySet, self); + } + + // aten::orgqr(Tensor self, Tensor input2) -> Tensor + inline at::Tensor orgqr(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & input2) { + return at::_ops::orgqr::redispatch(dispatchKeySet, self, input2); + } + + // aten::orgqr.out(Tensor self, Tensor input2, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & orgqr_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & input2) { + return at::_ops::orgqr_out::redispatch(dispatchKeySet, self, input2, out); + } + + // aten::orgqr.out(Tensor self, Tensor input2, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & orgqr_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & input2, at::Tensor & out) { + return at::_ops::orgqr_out::redispatch(dispatchKeySet, self, input2, out); + } + + // aten::ormqr.out(Tensor self, Tensor input2, Tensor input3, bool left=True, bool transpose=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & ormqr_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & input2, const at::Tensor & input3, bool left=true, bool transpose=false) { + return at::_ops::ormqr_out::redispatch(dispatchKeySet, self, input2, input3, left, transpose, out); + } + + // aten::ormqr.out(Tensor self, Tensor input2, Tensor input3, bool left=True, bool transpose=False, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & ormqr_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & input2, const at::Tensor & input3, bool left, bool transpose, at::Tensor & out) { + return at::_ops::ormqr_out::redispatch(dispatchKeySet, self, input2, input3, left, transpose, out); + } + + // aten::ormqr(Tensor self, Tensor input2, Tensor input3, bool left=True, bool transpose=False) -> Tensor + inline at::Tensor ormqr(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & input2, const at::Tensor & input3, bool left=true, bool transpose=false) { + return at::_ops::ormqr::redispatch(dispatchKeySet, self, input2, input3, left, transpose); + } + + // aten::_lu_with_info(Tensor self, bool pivot=True, bool check_errors=True) -> (Tensor LU, Tensor pivots, Tensor info) + inline ::std::tuple _lu_with_info(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool pivot=true, bool check_errors=true) { + return at::_ops::_lu_with_info::redispatch(dispatchKeySet, self, pivot, check_errors); + } + + // aten::lu_solve.out(Tensor self, Tensor LU_data, Tensor LU_pivots, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & lu_solve_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & LU_data, const at::Tensor & LU_pivots) { + return at::_ops::lu_solve_out::redispatch(dispatchKeySet, self, LU_data, LU_pivots, out); + } + + // aten::lu_solve.out(Tensor self, Tensor LU_data, Tensor LU_pivots, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & lu_solve_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & LU_data, const at::Tensor & LU_pivots, at::Tensor & out) { + return at::_ops::lu_solve_out::redispatch(dispatchKeySet, self, LU_data, LU_pivots, out); + } + + // aten::lu_solve(Tensor self, Tensor LU_data, Tensor LU_pivots) -> Tensor + inline at::Tensor lu_solve(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & LU_data, const at::Tensor & LU_pivots) { + return at::_ops::lu_solve::redispatch(dispatchKeySet, self, LU_data, LU_pivots); + } + + // aten::lu_unpack(Tensor LU_data, Tensor LU_pivots, bool unpack_data=True, bool unpack_pivots=True) -> (Tensor P, Tensor L, Tensor U) + inline ::std::tuple lu_unpack(c10::DispatchKeySet dispatchKeySet, const at::Tensor & LU_data, const at::Tensor & LU_pivots, bool unpack_data=true, bool unpack_pivots=true) { + return at::_ops::lu_unpack::redispatch(dispatchKeySet, LU_data, LU_pivots, unpack_data, unpack_pivots); + } + + // aten::lu_unpack.out(Tensor LU_data, Tensor LU_pivots, bool unpack_data=True, bool unpack_pivots=True, *, Tensor(a!) P, Tensor(b!) L, Tensor(c!) U) -> (Tensor(a!) P, Tensor(b!) L, Tensor(c!) U) + inline ::std::tuple lu_unpack_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & P, at::Tensor & L, at::Tensor & U, const at::Tensor & LU_data, const at::Tensor & LU_pivots, bool unpack_data=true, bool unpack_pivots=true) { + return at::_ops::lu_unpack_out::redispatch(dispatchKeySet, LU_data, LU_pivots, unpack_data, unpack_pivots, P, L, U); + } + + // aten::lu_unpack.out(Tensor LU_data, Tensor LU_pivots, bool unpack_data=True, bool unpack_pivots=True, *, Tensor(a!) P, Tensor(b!) L, Tensor(c!) U) -> (Tensor(a!) P, Tensor(b!) L, Tensor(c!) 
U) + inline ::std::tuple lu_unpack_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & LU_data, const at::Tensor & LU_pivots, bool unpack_data, bool unpack_pivots, at::Tensor & P, at::Tensor & L, at::Tensor & U) { + return at::_ops::lu_unpack_out::redispatch(dispatchKeySet, LU_data, LU_pivots, unpack_data, unpack_pivots, P, L, U); + } + + // aten::multinomial.out(Tensor self, SymInt num_samples, bool replacement=False, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & multinomial_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t num_samples, bool replacement=false, ::std::optional generator=::std::nullopt) { + return at::_ops::multinomial_out::redispatch(dispatchKeySet, self, num_samples, replacement, generator, out); + } + + // aten::multinomial.out(Tensor self, SymInt num_samples, bool replacement=False, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & multinomial_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t num_samples, bool replacement, ::std::optional generator, at::Tensor & out) { + return at::_ops::multinomial_out::redispatch(dispatchKeySet, self, num_samples, replacement, generator, out); + } + + // aten::multinomial.out(Tensor self, SymInt num_samples, bool replacement=False, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & multinomial_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymInt num_samples, bool replacement=false, ::std::optional generator=::std::nullopt) { + return at::_ops::multinomial_out::redispatch(dispatchKeySet, self, num_samples, replacement, generator, out); + } + + // aten::multinomial.out(Tensor self, SymInt num_samples, bool replacement=False, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & multinomial_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymInt num_samples, bool replacement, ::std::optional generator, at::Tensor & out) { + return at::_ops::multinomial_out::redispatch(dispatchKeySet, self, num_samples, replacement, generator, out); + } + + // aten::multinomial(Tensor self, SymInt num_samples, bool replacement=False, *, Generator? generator=None) -> Tensor + inline at::Tensor multinomial(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t num_samples, bool replacement=false, ::std::optional generator=::std::nullopt) { + return at::_ops::multinomial::redispatch(dispatchKeySet, self, num_samples, replacement, generator); + } + + // aten::multinomial(Tensor self, SymInt num_samples, bool replacement=False, *, Generator? generator=None) -> Tensor + inline at::Tensor multinomial_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymInt num_samples, bool replacement=false, ::std::optional generator=::std::nullopt) { + return at::_ops::multinomial::redispatch(dispatchKeySet, self, num_samples, replacement, generator); + } + + // aten::lgamma.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & lgamma_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::lgamma_out::redispatch(dispatchKeySet, self, out); + } + + // aten::lgamma.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & lgamma_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::lgamma_out::redispatch(dispatchKeySet, self, out); + } + + // aten::lgamma_(Tensor(a!) self) -> Tensor(a!) 
+ inline at::Tensor & lgamma_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::lgamma_::redispatch(dispatchKeySet, self); + } + + // aten::lgamma(Tensor self) -> Tensor + inline at::Tensor lgamma(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::lgamma::redispatch(dispatchKeySet, self); + } + + // aten::digamma.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & digamma_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::digamma_out::redispatch(dispatchKeySet, self, out); + } + + // aten::digamma.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & digamma_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::digamma_out::redispatch(dispatchKeySet, self, out); + } + + // aten::digamma(Tensor self) -> Tensor + inline at::Tensor digamma(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::digamma::redispatch(dispatchKeySet, self); + } + + // aten::polygamma.out(int n, Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & polygamma_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t n, const at::Tensor & self) { + return at::_ops::polygamma_out::redispatch(dispatchKeySet, n, self, out); + } + + // aten::polygamma.out(int n, Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & polygamma_outf(c10::DispatchKeySet dispatchKeySet, int64_t n, const at::Tensor & self, at::Tensor & out) { + return at::_ops::polygamma_out::redispatch(dispatchKeySet, n, self, out); + } + + // aten::polygamma(int n, Tensor self) -> Tensor + inline at::Tensor polygamma(c10::DispatchKeySet dispatchKeySet, int64_t n, const at::Tensor & self) { + return at::_ops::polygamma::redispatch(dispatchKeySet, n, self); + } + + // aten::polygamma_(Tensor(a!) self, int n) -> Tensor(a!) 
+ inline at::Tensor & polygamma_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, int64_t n) { + return at::_ops::polygamma_::redispatch(dispatchKeySet, self, n); + } + + // aten::erfinv(Tensor self) -> Tensor + inline at::Tensor erfinv(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::erfinv::redispatch(dispatchKeySet, self); + } + + // aten::erfinv_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & erfinv_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::erfinv_::redispatch(dispatchKeySet, self); + } + + // aten::erfinv.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & erfinv_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::erfinv_out::redispatch(dispatchKeySet, self, out); + } + + // aten::erfinv.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & erfinv_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::erfinv_out::redispatch(dispatchKeySet, self, out); + } + + // aten::i0(Tensor self) -> Tensor + inline at::Tensor i0(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::i0::redispatch(dispatchKeySet, self); + } + + // aten::i0_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & i0_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::i0_::redispatch(dispatchKeySet, self); + } + + // aten::i0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & i0_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::i0_out::redispatch(dispatchKeySet, self, out); + } + + // aten::i0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & i0_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::i0_out::redispatch(dispatchKeySet, self, out); + } + + // aten::sign(Tensor self) -> Tensor + inline at::Tensor sign(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::sign::redispatch(dispatchKeySet, self); + } + + // aten::sign_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & sign_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::sign_::redispatch(dispatchKeySet, self); + } + + // aten::sign.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & sign_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::sign_out::redispatch(dispatchKeySet, self, out); + } + + // aten::sign.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & sign_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::sign_out::redispatch(dispatchKeySet, self, out); + } + + // aten::signbit(Tensor self) -> Tensor + inline at::Tensor signbit(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::signbit::redispatch(dispatchKeySet, self); + } + + // aten::signbit.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & signbit_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::signbit_out::redispatch(dispatchKeySet, self, out); + } + + // aten::signbit.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & signbit_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::signbit_out::redispatch(dispatchKeySet, self, out); + } + + // aten::dist(Tensor self, Tensor other, Scalar p=2) -> Tensor + inline at::Tensor dist(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, const at::Scalar & p=2) { + return at::_ops::dist::redispatch(dispatchKeySet, self, other, p); + } + + // aten::atan2.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & atan2_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::atan2_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::atan2.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & atan2_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::atan2_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::atan2_(Tensor(a!) self, Tensor other) -> Tensor(a!) + inline at::Tensor & atan2_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) { + return at::_ops::atan2_::redispatch(dispatchKeySet, self, other); + } + + // aten::atan2(Tensor self, Tensor other) -> Tensor + inline at::Tensor atan2(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::atan2::redispatch(dispatchKeySet, self, other); + } + + // aten::arctan2(Tensor self, Tensor other) -> Tensor + inline at::Tensor arctan2(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::arctan2::redispatch(dispatchKeySet, self, other); + } + + // aten::arctan2.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & arctan2_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::arctan2_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::arctan2.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & arctan2_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::arctan2_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::arctan2_(Tensor(a!) self, Tensor other) -> Tensor(a!) + inline at::Tensor & arctan2_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) { + return at::_ops::arctan2_::redispatch(dispatchKeySet, self, other); + } + + // aten::lerp.Scalar_out(Tensor self, Tensor end, Scalar weight, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & lerp_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & end, const at::Scalar & weight) { + return at::_ops::lerp_Scalar_out::redispatch(dispatchKeySet, self, end, weight, out); + } + + // aten::lerp.Scalar_out(Tensor self, Tensor end, Scalar weight, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & lerp_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & end, const at::Scalar & weight, at::Tensor & out) { + return at::_ops::lerp_Scalar_out::redispatch(dispatchKeySet, self, end, weight, out); + } + + // aten::lerp.Tensor_out(Tensor self, Tensor end, Tensor weight, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & lerp_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & end, const at::Tensor & weight) { + return at::_ops::lerp_Tensor_out::redispatch(dispatchKeySet, self, end, weight, out); + } + + // aten::lerp.Tensor_out(Tensor self, Tensor end, Tensor weight, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & lerp_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & end, const at::Tensor & weight, at::Tensor & out) { + return at::_ops::lerp_Tensor_out::redispatch(dispatchKeySet, self, end, weight, out); + } + + // aten::lerp.Scalar(Tensor self, Tensor end, Scalar weight) -> Tensor + inline at::Tensor lerp(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & end, const at::Scalar & weight) { + return at::_ops::lerp_Scalar::redispatch(dispatchKeySet, self, end, weight); + } + + // aten::lerp.Tensor(Tensor self, Tensor end, Tensor weight) -> Tensor + inline at::Tensor lerp(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & end, const at::Tensor & weight) { + return at::_ops::lerp_Tensor::redispatch(dispatchKeySet, self, end, weight); + } + + // aten::histc.out(Tensor self, int bins=100, Scalar min=0, Scalar max=0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & histc_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t bins=100, const at::Scalar & min=0, const at::Scalar & max=0) { + return at::_ops::histc_out::redispatch(dispatchKeySet, self, bins, min, max, out); + } + + // aten::histc.out(Tensor self, int bins=100, Scalar min=0, Scalar max=0, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & histc_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t bins, const at::Scalar & min, const at::Scalar & max, at::Tensor & out) { + return at::_ops::histc_out::redispatch(dispatchKeySet, self, bins, min, max, out); + } + + // aten::histc(Tensor self, int bins=100, Scalar min=0, Scalar max=0) -> Tensor + inline at::Tensor histc(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t bins=100, const at::Scalar & min=0, const at::Scalar & max=0) { + return at::_ops::histc::redispatch(dispatchKeySet, self, bins, min, max); + } + + // aten::histogram.bins_tensor_out(Tensor self, Tensor bins, *, Tensor? weight=None, bool density=False, Tensor(a!) hist, Tensor(b!) bin_edges) -> (Tensor(a!) hist, Tensor(b!) bin_edges) + inline ::std::tuple histogram_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & hist, at::Tensor & bin_edges, const at::Tensor & self, const at::Tensor & bins, const ::std::optional & weight={}, bool density=false) { + return at::_ops::histogram_bins_tensor_out::redispatch(dispatchKeySet, self, bins, weight, density, hist, bin_edges); + } + + // aten::histogram.bins_tensor_out(Tensor self, Tensor bins, *, Tensor? weight=None, bool density=False, Tensor(a!) hist, Tensor(b!) bin_edges) -> (Tensor(a!) hist, Tensor(b!) bin_edges) + inline ::std::tuple histogram_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & bins, const ::std::optional & weight, bool density, at::Tensor & hist, at::Tensor & bin_edges) { + return at::_ops::histogram_bins_tensor_out::redispatch(dispatchKeySet, self, bins, weight, density, hist, bin_edges); + } + + // aten::histogram.bins_tensor(Tensor self, Tensor bins, *, Tensor? 
weight=None, bool density=False) -> (Tensor hist, Tensor bin_edges) + inline ::std::tuple histogram(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & bins, const ::std::optional & weight={}, bool density=false) { + return at::_ops::histogram_bins_tensor::redispatch(dispatchKeySet, self, bins, weight, density); + } + + // aten::histogram.bin_ct_out(Tensor self, int bins=100, *, float[]? range=None, Tensor? weight=None, bool density=False, Tensor(a!) hist, Tensor(b!) bin_edges) -> (Tensor(a!) hist, Tensor(b!) bin_edges) + inline ::std::tuple histogram_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & hist, at::Tensor & bin_edges, const at::Tensor & self, int64_t bins=100, ::std::optional> range=::std::nullopt, const ::std::optional & weight={}, bool density=false) { + return at::_ops::histogram_bin_ct_out::redispatch(dispatchKeySet, self, bins, range, weight, density, hist, bin_edges); + } + + // aten::histogram.bin_ct_out(Tensor self, int bins=100, *, float[]? range=None, Tensor? weight=None, bool density=False, Tensor(a!) hist, Tensor(b!) bin_edges) -> (Tensor(a!) hist, Tensor(b!) bin_edges) + inline ::std::tuple histogram_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t bins, ::std::optional> range, const ::std::optional & weight, bool density, at::Tensor & hist, at::Tensor & bin_edges) { + return at::_ops::histogram_bin_ct_out::redispatch(dispatchKeySet, self, bins, range, weight, density, hist, bin_edges); + } + + // aten::histogram.bin_ct(Tensor self, int bins=100, *, float[]? range=None, Tensor? 
weight=None, bool density=False) -> (Tensor hist, Tensor bin_edges) + inline ::std::tuple histogram(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t bins=100, ::std::optional> range=::std::nullopt, const ::std::optional & weight={}, bool density=false) { + return at::_ops::histogram_bin_ct::redispatch(dispatchKeySet, self, bins, range, weight, density); + } + + // aten::_histogramdd_bin_edges(Tensor self, int[] bins, *, float[]? range=None, Tensor? weight=None, bool density=False) -> Tensor[] + inline ::std::vector _histogramdd_bin_edges(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef bins, ::std::optional> range=::std::nullopt, const ::std::optional & weight={}, bool density=false) { + return at::_ops::_histogramdd_bin_edges::redispatch(dispatchKeySet, self, bins, range, weight, density); + } + + // aten::_histogramdd_from_bin_cts(Tensor self, int[] bins, *, float[]? range=None, Tensor? weight=None, bool density=False) -> Tensor + inline at::Tensor _histogramdd_from_bin_cts(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef bins, ::std::optional> range=::std::nullopt, const ::std::optional & weight={}, bool density=false) { + return at::_ops::_histogramdd_from_bin_cts::redispatch(dispatchKeySet, self, bins, range, weight, density); + } + + // aten::_histogramdd_from_bin_tensors(Tensor self, Tensor[] bins, *, Tensor? weight=None, bool density=False) -> Tensor + inline at::Tensor _histogramdd_from_bin_tensors(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::TensorList bins, const ::std::optional & weight={}, bool density=false) { + return at::_ops::_histogramdd_from_bin_tensors::redispatch(dispatchKeySet, self, bins, weight, density); + } + + // aten::histogramdd(Tensor self, int[] bins, float[]? range=None, Tensor? 
weight=None, bool density=False) -> (Tensor hist, Tensor[] bin_edges) + inline ::std::tuple> histogramdd(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef bins, ::std::optional> range=::std::nullopt, const ::std::optional & weight={}, bool density=false) { + return at::_ops::histogramdd::redispatch(dispatchKeySet, self, bins, range, weight, density); + } + + // aten::histogramdd.int_bins(Tensor self, int bins, float[]? range=None, Tensor? weight=None, bool density=False) -> (Tensor hist, Tensor[] bin_edges) + inline ::std::tuple> histogramdd(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t bins, ::std::optional> range=::std::nullopt, const ::std::optional & weight={}, bool density=false) { + return at::_ops::histogramdd_int_bins::redispatch(dispatchKeySet, self, bins, range, weight, density); + } + + // aten::histogramdd.TensorList_bins(Tensor self, Tensor[] bins, float[]? range=None, Tensor? weight=None, bool density=False) -> (Tensor hist, Tensor[] bin_edges) + inline ::std::tuple> histogramdd(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::TensorList bins, ::std::optional> range=::std::nullopt, const ::std::optional & weight={}, bool density=false) { + return at::_ops::histogramdd_TensorList_bins::redispatch(dispatchKeySet, self, bins, range, weight, density); + } + + // aten::fmod.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fmod_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::fmod_Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::fmod.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & fmod_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, at::Tensor & out) { + return at::_ops::fmod_Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::fmod.Scalar(Tensor self, Scalar other) -> Tensor + inline at::Tensor fmod(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::fmod_Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::fmod_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + inline at::Tensor & fmod_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) { + return at::_ops::fmod__Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::fmod.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fmod_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::fmod_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::fmod.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fmod_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::fmod_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::fmod.Tensor(Tensor self, Tensor other) -> Tensor + inline at::Tensor fmod(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::fmod_Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::fmod_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + inline at::Tensor & fmod_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) { + return at::_ops::fmod__Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::hypot.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & hypot_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::hypot_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::hypot.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & hypot_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::hypot_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::hypot(Tensor self, Tensor other) -> Tensor + inline at::Tensor hypot(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::hypot::redispatch(dispatchKeySet, self, other); + } + + // aten::hypot_(Tensor(a!) self, Tensor other) -> Tensor(a!) + inline at::Tensor & hypot_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) { + return at::_ops::hypot_::redispatch(dispatchKeySet, self, other); + } + + // aten::igamma.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & igamma_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::igamma_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::igamma.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & igamma_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::igamma_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::igamma(Tensor self, Tensor other) -> Tensor + inline at::Tensor igamma(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::igamma::redispatch(dispatchKeySet, self, other); + } + + // aten::igamma_(Tensor(a!) self, Tensor other) -> Tensor(a!) 
+ inline at::Tensor & igamma_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) { + return at::_ops::igamma_::redispatch(dispatchKeySet, self, other); + } + + // aten::igammac.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & igammac_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::igammac_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::igammac.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & igammac_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::igammac_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::igammac(Tensor self, Tensor other) -> Tensor + inline at::Tensor igammac(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::igammac::redispatch(dispatchKeySet, self, other); + } + + // aten::igammac_(Tensor(a!) self, Tensor other) -> Tensor(a!) + inline at::Tensor & igammac_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) { + return at::_ops::igammac_::redispatch(dispatchKeySet, self, other); + } + + // aten::nextafter.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & nextafter_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::nextafter_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::nextafter.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & nextafter_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::nextafter_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::nextafter(Tensor self, Tensor other) -> Tensor + inline at::Tensor nextafter(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::nextafter::redispatch(dispatchKeySet, self, other); + } + + // aten::nextafter_(Tensor(a!) self, Tensor other) -> Tensor(a!) + inline at::Tensor & nextafter_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) { + return at::_ops::nextafter_::redispatch(dispatchKeySet, self, other); + } + + // aten::remainder.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & remainder_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::remainder_Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::remainder.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & remainder_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, at::Tensor & out) { + return at::_ops::remainder_Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::remainder.Scalar(Tensor self, Scalar other) -> Tensor + inline at::Tensor remainder(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::remainder_Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::remainder_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) 
+ inline at::Tensor & remainder_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) { + return at::_ops::remainder__Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::remainder.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & remainder_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::remainder_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::remainder.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & remainder_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::remainder_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::remainder.Tensor(Tensor self, Tensor other) -> Tensor + inline at::Tensor remainder(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::remainder_Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::remainder_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + inline at::Tensor & remainder_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) { + return at::_ops::remainder__Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::remainder.Scalar_Tensor(Scalar self, Tensor other) -> Tensor + inline at::Tensor remainder(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, const at::Tensor & other) { + return at::_ops::remainder_Scalar_Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::min(Tensor self) -> Tensor + inline at::Tensor min(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::min::redispatch(dispatchKeySet, self); + } + + // aten::min.unary_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & min_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::min_unary_out::redispatch(dispatchKeySet, self, out); + } + + // aten::min.unary_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & min_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::min_unary_out::redispatch(dispatchKeySet, self, out); + } + + // aten::fmin(Tensor self, Tensor other) -> Tensor + inline at::Tensor fmin(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::fmin::redispatch(dispatchKeySet, self, other); + } + + // aten::fmin.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fmin_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::fmin_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::fmin.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fmin_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::fmin_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::max(Tensor self) -> Tensor + inline at::Tensor max(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::max::redispatch(dispatchKeySet, self); + } + + // aten::fmax(Tensor self, Tensor other) -> Tensor + inline at::Tensor fmax(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::fmax::redispatch(dispatchKeySet, self, other); + } + + // aten::fmax.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & fmax_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::fmax_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::fmax.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fmax_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::fmax_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::maximum(Tensor self, Tensor other) -> Tensor + inline at::Tensor maximum(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::maximum::redispatch(dispatchKeySet, self, other); + } + + // aten::maximum.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & maximum_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::maximum_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::maximum.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & maximum_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::maximum_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::max.other(Tensor self, Tensor other) -> Tensor + inline at::Tensor max(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::max_other::redispatch(dispatchKeySet, self, other); + } + + // aten::max.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & max_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::max_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::max.out(Tensor self, Tensor other, *, Tensor(a!) 
out) -> Tensor(a!) + inline at::Tensor & max_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::max_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::max.unary_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & max_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::max_unary_out::redispatch(dispatchKeySet, self, out); + } + + // aten::max.unary_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & max_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::max_unary_out::redispatch(dispatchKeySet, self, out); + } + + // aten::minimum(Tensor self, Tensor other) -> Tensor + inline at::Tensor minimum(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::minimum::redispatch(dispatchKeySet, self, other); + } + + // aten::minimum.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & minimum_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::minimum_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::minimum.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & minimum_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::minimum_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::min.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & min_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::min_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::min.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & min_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::min_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::min.other(Tensor self, Tensor other) -> Tensor + inline at::Tensor min(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::min_other::redispatch(dispatchKeySet, self, other); + } + + // aten::quantile(Tensor self, Tensor q, int? dim=None, bool keepdim=False, *, str interpolation='linear') -> Tensor + inline at::Tensor quantile(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & q, ::std::optional dim=::std::nullopt, bool keepdim=false, c10::string_view interpolation="linear") { + return at::_ops::quantile::redispatch(dispatchKeySet, self, q, dim, keepdim, interpolation); + } + + // aten::quantile.out(Tensor self, Tensor q, int? dim=None, bool keepdim=False, *, str interpolation='linear', Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & quantile_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & q, ::std::optional dim=::std::nullopt, bool keepdim=false, c10::string_view interpolation="linear") { + return at::_ops::quantile_out::redispatch(dispatchKeySet, self, q, dim, keepdim, interpolation, out); + } + + // aten::quantile.out(Tensor self, Tensor q, int? dim=None, bool keepdim=False, *, str interpolation='linear', Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & quantile_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & q, ::std::optional dim, bool keepdim, c10::string_view interpolation, at::Tensor & out) { + return at::_ops::quantile_out::redispatch(dispatchKeySet, self, q, dim, keepdim, interpolation, out); + } + + // aten::quantile.scalar(Tensor self, float q, int? 
dim=None, bool keepdim=False, *, str interpolation='linear') -> Tensor + inline at::Tensor quantile(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double q, ::std::optional dim=::std::nullopt, bool keepdim=false, c10::string_view interpolation="linear") { + return at::_ops::quantile_scalar::redispatch(dispatchKeySet, self, q, dim, keepdim, interpolation); + } + + // aten::quantile.scalar_out(Tensor self, float q, int? dim=None, bool keepdim=False, *, str interpolation='linear', Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & quantile_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, double q, ::std::optional dim=::std::nullopt, bool keepdim=false, c10::string_view interpolation="linear") { + return at::_ops::quantile_scalar_out::redispatch(dispatchKeySet, self, q, dim, keepdim, interpolation, out); + } + + // aten::quantile.scalar_out(Tensor self, float q, int? dim=None, bool keepdim=False, *, str interpolation='linear', Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & quantile_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double q, ::std::optional dim, bool keepdim, c10::string_view interpolation, at::Tensor & out) { + return at::_ops::quantile_scalar_out::redispatch(dispatchKeySet, self, q, dim, keepdim, interpolation, out); + } + + // aten::nanquantile(Tensor self, Tensor q, int? dim=None, bool keepdim=False, *, str interpolation='linear') -> Tensor + inline at::Tensor nanquantile(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & q, ::std::optional dim=::std::nullopt, bool keepdim=false, c10::string_view interpolation="linear") { + return at::_ops::nanquantile::redispatch(dispatchKeySet, self, q, dim, keepdim, interpolation); + } + + // aten::nanquantile.out(Tensor self, Tensor q, int? dim=None, bool keepdim=False, *, str interpolation='linear', Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & nanquantile_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & q, ::std::optional dim=::std::nullopt, bool keepdim=false, c10::string_view interpolation="linear") { + return at::_ops::nanquantile_out::redispatch(dispatchKeySet, self, q, dim, keepdim, interpolation, out); + } + + // aten::nanquantile.out(Tensor self, Tensor q, int? dim=None, bool keepdim=False, *, str interpolation='linear', Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & nanquantile_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & q, ::std::optional dim, bool keepdim, c10::string_view interpolation, at::Tensor & out) { + return at::_ops::nanquantile_out::redispatch(dispatchKeySet, self, q, dim, keepdim, interpolation, out); + } + + // aten::nanquantile.scalar(Tensor self, float q, int? dim=None, bool keepdim=False, *, str interpolation='linear') -> Tensor + inline at::Tensor nanquantile(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double q, ::std::optional dim=::std::nullopt, bool keepdim=false, c10::string_view interpolation="linear") { + return at::_ops::nanquantile_scalar::redispatch(dispatchKeySet, self, q, dim, keepdim, interpolation); + } + + // aten::nanquantile.scalar_out(Tensor self, float q, int? dim=None, bool keepdim=False, *, str interpolation='linear', Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & nanquantile_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, double q, ::std::optional dim=::std::nullopt, bool keepdim=false, c10::string_view interpolation="linear") { + return at::_ops::nanquantile_scalar_out::redispatch(dispatchKeySet, self, q, dim, keepdim, interpolation, out); + } + + // aten::nanquantile.scalar_out(Tensor self, float q, int? dim=None, bool keepdim=False, *, str interpolation='linear', Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & nanquantile_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double q, ::std::optional dim, bool keepdim, c10::string_view interpolation, at::Tensor & out) { + return at::_ops::nanquantile_scalar_out::redispatch(dispatchKeySet, self, q, dim, keepdim, interpolation, out); + } + + // aten::sort.values(Tensor self, int dim=-1, bool descending=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + inline ::std::tuple sort_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & values, at::Tensor & indices, const at::Tensor & self, int64_t dim=-1, bool descending=false) { + return at::_ops::sort_values::redispatch(dispatchKeySet, self, dim, descending, values, indices); + } + + // aten::sort.values(Tensor self, int dim=-1, bool descending=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + inline ::std::tuple sort_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool descending, at::Tensor & values, at::Tensor & indices) { + return at::_ops::sort_values::redispatch(dispatchKeySet, self, dim, descending, values, indices); + } + + // aten::sort.values_stable(Tensor self, *, bool? stable, int dim=-1, bool descending=False, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + inline ::std::tuple sort_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & values, at::Tensor & indices, const at::Tensor & self, ::std::optional stable, int64_t dim=-1, bool descending=false) { + return at::_ops::sort_values_stable::redispatch(dispatchKeySet, self, stable, dim, descending, values, indices); + } + + // aten::sort.values_stable(Tensor self, *, bool? stable, int dim=-1, bool descending=False, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) 
indices) + inline ::std::tuple sort_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional stable, int64_t dim, bool descending, at::Tensor & values, at::Tensor & indices) { + return at::_ops::sort_values_stable::redispatch(dispatchKeySet, self, stable, dim, descending, values, indices); + } + + // aten::sort(Tensor self, int dim=-1, bool descending=False) -> (Tensor values, Tensor indices) + inline ::std::tuple sort(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim=-1, bool descending=false) { + return at::_ops::sort::redispatch(dispatchKeySet, self, dim, descending); + } + + // aten::sort.stable(Tensor self, *, bool? stable, int dim=-1, bool descending=False) -> (Tensor values, Tensor indices) + inline ::std::tuple sort(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional stable, int64_t dim=-1, bool descending=false) { + return at::_ops::sort_stable::redispatch(dispatchKeySet, self, stable, dim, descending); + } + + // aten::sort.dimname_values(Tensor self, Dimname dim, bool descending=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + inline ::std::tuple sort_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & values, at::Tensor & indices, const at::Tensor & self, at::Dimname dim, bool descending=false) { + return at::_ops::sort_dimname_values::redispatch(dispatchKeySet, self, dim, descending, values, indices); + } + + // aten::sort.dimname_values(Tensor self, Dimname dim, bool descending=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + inline ::std::tuple sort_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, bool descending, at::Tensor & values, at::Tensor & indices) { + return at::_ops::sort_dimname_values::redispatch(dispatchKeySet, self, dim, descending, values, indices); + } + + // aten::sort.dimname_values_stable(Tensor self, *, bool? 
stable, Dimname dim, bool descending=False, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + inline ::std::tuple sort_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & values, at::Tensor & indices, const at::Tensor & self, ::std::optional stable, at::Dimname dim, bool descending=false) { + return at::_ops::sort_dimname_values_stable::redispatch(dispatchKeySet, self, stable, dim, descending, values, indices); + } + + // aten::sort.dimname_values_stable(Tensor self, *, bool? stable, Dimname dim, bool descending=False, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + inline ::std::tuple sort_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional stable, at::Dimname dim, bool descending, at::Tensor & values, at::Tensor & indices) { + return at::_ops::sort_dimname_values_stable::redispatch(dispatchKeySet, self, stable, dim, descending, values, indices); + } + + // aten::sort.dimname(Tensor self, Dimname dim, bool descending=False) -> (Tensor values, Tensor indices) + inline ::std::tuple sort(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, bool descending=false) { + return at::_ops::sort_dimname::redispatch(dispatchKeySet, self, dim, descending); + } + + // aten::sort.dimname_stable(Tensor self, *, bool? stable, Dimname dim, bool descending=False) -> (Tensor values, Tensor indices) + inline ::std::tuple sort(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional stable, at::Dimname dim, bool descending=false) { + return at::_ops::sort_dimname_stable::redispatch(dispatchKeySet, self, stable, dim, descending); + } + + // aten::msort.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & msort_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::msort_out::redispatch(dispatchKeySet, self, out); + } + + // aten::msort.out(Tensor self, *, Tensor(a!) 
out) -> Tensor(a!) + inline at::Tensor & msort_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::msort_out::redispatch(dispatchKeySet, self, out); + } + + // aten::msort(Tensor self) -> Tensor + inline at::Tensor msort(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::msort::redispatch(dispatchKeySet, self); + } + + // aten::argsort(Tensor self, int dim=-1, bool descending=False) -> Tensor + inline at::Tensor argsort(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim=-1, bool descending=false) { + return at::_ops::argsort::redispatch(dispatchKeySet, self, dim, descending); + } + + // aten::argsort.stable(Tensor self, *, bool stable, int dim=-1, bool descending=False) -> Tensor + inline at::Tensor argsort(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool stable, int64_t dim=-1, bool descending=false) { + return at::_ops::argsort_stable::redispatch(dispatchKeySet, self, stable, dim, descending); + } + + // aten::argsort.stable_out(Tensor self, *, bool stable, int dim=-1, bool descending=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & argsort_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, bool stable, int64_t dim=-1, bool descending=false) { + return at::_ops::argsort_stable_out::redispatch(dispatchKeySet, self, stable, dim, descending, out); + } + + // aten::argsort.stable_out(Tensor self, *, bool stable, int dim=-1, bool descending=False, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & argsort_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool stable, int64_t dim, bool descending, at::Tensor & out) { + return at::_ops::argsort_stable_out::redispatch(dispatchKeySet, self, stable, dim, descending, out); + } + + // aten::argsort.dimname(Tensor self, Dimname dim, bool descending=False) -> Tensor + inline at::Tensor argsort(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, bool descending=false) { + return at::_ops::argsort_dimname::redispatch(dispatchKeySet, self, dim, descending); + } + + // aten::topk.values(Tensor self, SymInt k, int dim=-1, bool largest=True, bool sorted=True, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + inline ::std::tuple topk_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & values, at::Tensor & indices, const at::Tensor & self, int64_t k, int64_t dim=-1, bool largest=true, bool sorted=true) { + return at::_ops::topk_values::redispatch(dispatchKeySet, self, k, dim, largest, sorted, values, indices); + } + + // aten::topk.values(Tensor self, SymInt k, int dim=-1, bool largest=True, bool sorted=True, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + inline ::std::tuple topk_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t k, int64_t dim, bool largest, bool sorted, at::Tensor & values, at::Tensor & indices) { + return at::_ops::topk_values::redispatch(dispatchKeySet, self, k, dim, largest, sorted, values, indices); + } + + // aten::topk.values(Tensor self, SymInt k, int dim=-1, bool largest=True, bool sorted=True, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) 
indices) + inline ::std::tuple topk_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & values, at::Tensor & indices, const at::Tensor & self, c10::SymInt k, int64_t dim=-1, bool largest=true, bool sorted=true) { + return at::_ops::topk_values::redispatch(dispatchKeySet, self, k, dim, largest, sorted, values, indices); + } + + // aten::topk.values(Tensor self, SymInt k, int dim=-1, bool largest=True, bool sorted=True, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + inline ::std::tuple topk_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymInt k, int64_t dim, bool largest, bool sorted, at::Tensor & values, at::Tensor & indices) { + return at::_ops::topk_values::redispatch(dispatchKeySet, self, k, dim, largest, sorted, values, indices); + } + + // aten::topk(Tensor self, SymInt k, int dim=-1, bool largest=True, bool sorted=True) -> (Tensor values, Tensor indices) + inline ::std::tuple topk(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t k, int64_t dim=-1, bool largest=true, bool sorted=true) { + return at::_ops::topk::redispatch(dispatchKeySet, self, k, dim, largest, sorted); + } + + // aten::topk(Tensor self, SymInt k, int dim=-1, bool largest=True, bool sorted=True) -> (Tensor values, Tensor indices) + inline ::std::tuple topk_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymInt k, int64_t dim=-1, bool largest=true, bool sorted=true) { + return at::_ops::topk::redispatch(dispatchKeySet, self, k, dim, largest, sorted); + } + + // aten::all(Tensor self) -> Tensor + inline at::Tensor all(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::all::redispatch(dispatchKeySet, self); + } + + // aten::all.all_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & all_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::all_all_out::redispatch(dispatchKeySet, self, out); + } + + // aten::all.all_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & all_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::all_all_out::redispatch(dispatchKeySet, self, out); + } + + // aten::any(Tensor self) -> Tensor + inline at::Tensor any(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::any::redispatch(dispatchKeySet, self); + } + + // aten::any.all_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & any_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::any_all_out::redispatch(dispatchKeySet, self, out); + } + + // aten::any.all_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & any_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::any_all_out::redispatch(dispatchKeySet, self, out); + } + + // aten::renorm.out(Tensor self, Scalar p, int dim, Scalar maxnorm, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & renorm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & p, int64_t dim, const at::Scalar & maxnorm) { + return at::_ops::renorm_out::redispatch(dispatchKeySet, self, p, dim, maxnorm, out); + } + + // aten::renorm.out(Tensor self, Scalar p, int dim, Scalar maxnorm, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & renorm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & p, int64_t dim, const at::Scalar & maxnorm, at::Tensor & out) { + return at::_ops::renorm_out::redispatch(dispatchKeySet, self, p, dim, maxnorm, out); + } + + // aten::renorm(Tensor self, Scalar p, int dim, Scalar maxnorm) -> Tensor + inline at::Tensor renorm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & p, int64_t dim, const at::Scalar & maxnorm) { + return at::_ops::renorm::redispatch(dispatchKeySet, self, p, dim, maxnorm); + } + + // aten::renorm_(Tensor(a!) self, Scalar p, int dim, Scalar maxnorm) -> Tensor(a!) + inline at::Tensor & renorm_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & p, int64_t dim, const at::Scalar & maxnorm) { + return at::_ops::renorm_::redispatch(dispatchKeySet, self, p, dim, maxnorm); + } + + // aten::unfold(Tensor(a) self, int dimension, int size, int step) -> Tensor(a) + inline at::Tensor unfold(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dimension, int64_t size, int64_t step) { + return at::_ops::unfold::redispatch(dispatchKeySet, self, dimension, size, step); + } + + // aten::unfold_backward(Tensor grad_in, SymInt[] input_sizes, int dim, int size, int step) -> Tensor + inline at::Tensor unfold_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_in, at::IntArrayRef input_sizes, int64_t dim, int64_t size, int64_t step) { + return at::_ops::unfold_backward::redispatch(dispatchKeySet, grad_in, c10::fromIntArrayRefSlow(input_sizes), dim, size, step); + } + + // aten::unfold_backward(Tensor grad_in, SymInt[] input_sizes, int dim, int size, int step) -> Tensor + inline at::Tensor unfold_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_in, c10::SymIntArrayRef input_sizes, int64_t dim, int64_t size, int64_t step) { + return at::_ops::unfold_backward::redispatch(dispatchKeySet, grad_in, input_sizes, 
dim, size, step); + } + + // aten::equal(Tensor self, Tensor other) -> bool + inline bool equal(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::equal::redispatch(dispatchKeySet, self, other); + } + + // aten::pow.Tensor_Tensor_out(Tensor self, Tensor exponent, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & pow_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & exponent) { + return at::_ops::pow_Tensor_Tensor_out::redispatch(dispatchKeySet, self, exponent, out); + } + + // aten::pow.Tensor_Tensor_out(Tensor self, Tensor exponent, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & pow_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & exponent, at::Tensor & out) { + return at::_ops::pow_Tensor_Tensor_out::redispatch(dispatchKeySet, self, exponent, out); + } + + // aten::pow.Tensor_Tensor(Tensor self, Tensor exponent) -> Tensor + inline at::Tensor pow(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & exponent) { + return at::_ops::pow_Tensor_Tensor::redispatch(dispatchKeySet, self, exponent); + } + + // aten::pow.Scalar_out(Scalar self, Tensor exponent, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & pow_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & self, const at::Tensor & exponent) { + return at::_ops::pow_Scalar_out::redispatch(dispatchKeySet, self, exponent, out); + } + + // aten::pow.Scalar_out(Scalar self, Tensor exponent, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & pow_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, const at::Tensor & exponent, at::Tensor & out) { + return at::_ops::pow_Scalar_out::redispatch(dispatchKeySet, self, exponent, out); + } + + // aten::pow.Scalar(Scalar self, Tensor exponent) -> Tensor + inline at::Tensor pow(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, const at::Tensor & exponent) { + return at::_ops::pow_Scalar::redispatch(dispatchKeySet, self, exponent); + } + + // aten::pow.Tensor_Scalar_out(Tensor self, Scalar exponent, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & pow_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & exponent) { + return at::_ops::pow_Tensor_Scalar_out::redispatch(dispatchKeySet, self, exponent, out); + } + + // aten::pow.Tensor_Scalar_out(Tensor self, Scalar exponent, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & pow_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & exponent, at::Tensor & out) { + return at::_ops::pow_Tensor_Scalar_out::redispatch(dispatchKeySet, self, exponent, out); + } + + // aten::pow.Tensor_Scalar(Tensor self, Scalar exponent) -> Tensor + inline at::Tensor pow(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & exponent) { + return at::_ops::pow_Tensor_Scalar::redispatch(dispatchKeySet, self, exponent); + } + + // aten::pow_.Scalar(Tensor(a!) self, Scalar exponent) -> Tensor(a!) + inline at::Tensor & pow_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & exponent) { + return at::_ops::pow__Scalar::redispatch(dispatchKeySet, self, exponent); + } + + // aten::pow_.Tensor(Tensor(a!) self, Tensor exponent) -> Tensor(a!) 
+ inline at::Tensor & pow_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & exponent) { + return at::_ops::pow__Tensor::redispatch(dispatchKeySet, self, exponent); + } + + // aten::float_power.Tensor_Tensor_out(Tensor self, Tensor exponent, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & float_power_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & exponent) { + return at::_ops::float_power_Tensor_Tensor_out::redispatch(dispatchKeySet, self, exponent, out); + } + + // aten::float_power.Tensor_Tensor_out(Tensor self, Tensor exponent, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & float_power_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & exponent, at::Tensor & out) { + return at::_ops::float_power_Tensor_Tensor_out::redispatch(dispatchKeySet, self, exponent, out); + } + + // aten::float_power.Tensor_Tensor(Tensor self, Tensor exponent) -> Tensor + inline at::Tensor float_power(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & exponent) { + return at::_ops::float_power_Tensor_Tensor::redispatch(dispatchKeySet, self, exponent); + } + + // aten::float_power.Scalar_out(Scalar self, Tensor exponent, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & float_power_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & self, const at::Tensor & exponent) { + return at::_ops::float_power_Scalar_out::redispatch(dispatchKeySet, self, exponent, out); + } + + // aten::float_power.Scalar_out(Scalar self, Tensor exponent, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & float_power_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, const at::Tensor & exponent, at::Tensor & out) { + return at::_ops::float_power_Scalar_out::redispatch(dispatchKeySet, self, exponent, out); + } + + // aten::float_power.Scalar(Scalar self, Tensor exponent) -> Tensor + inline at::Tensor float_power(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, const at::Tensor & exponent) { + return at::_ops::float_power_Scalar::redispatch(dispatchKeySet, self, exponent); + } + + // aten::float_power.Tensor_Scalar_out(Tensor self, Scalar exponent, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & float_power_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & exponent) { + return at::_ops::float_power_Tensor_Scalar_out::redispatch(dispatchKeySet, self, exponent, out); + } + + // aten::float_power.Tensor_Scalar_out(Tensor self, Scalar exponent, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & float_power_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & exponent, at::Tensor & out) { + return at::_ops::float_power_Tensor_Scalar_out::redispatch(dispatchKeySet, self, exponent, out); + } + + // aten::float_power.Tensor_Scalar(Tensor self, Scalar exponent) -> Tensor + inline at::Tensor float_power(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & exponent) { + return at::_ops::float_power_Tensor_Scalar::redispatch(dispatchKeySet, self, exponent); + } + + // aten::float_power_.Scalar(Tensor(a!) self, Scalar exponent) -> Tensor(a!) + inline at::Tensor & float_power_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & exponent) { + return at::_ops::float_power__Scalar::redispatch(dispatchKeySet, self, exponent); + } + + // aten::float_power_.Tensor(Tensor(a!) self, Tensor exponent) -> Tensor(a!) 
+ inline at::Tensor & float_power_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & exponent) { + return at::_ops::float_power__Tensor::redispatch(dispatchKeySet, self, exponent); + } + + // aten::normal_(Tensor(a!) self, float mean=0, float std=1, *, Generator? generator=None) -> Tensor(a!) + inline at::Tensor & normal_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, double mean=0, double std=1, ::std::optional generator=::std::nullopt) { + return at::_ops::normal_::redispatch(dispatchKeySet, self, mean, std, generator); + } + + // aten::normal_functional(Tensor self, float mean=0, float std=1, *, Generator? generator=None) -> Tensor + inline at::Tensor normal_functional(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double mean=0, double std=1, ::std::optional generator=::std::nullopt) { + return at::_ops::normal_functional::redispatch(dispatchKeySet, self, mean, std, generator); + } + + // aten::normal.Tensor_float_out(Tensor mean, float std=1, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & normal_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & mean, double std=1, ::std::optional generator=::std::nullopt) { + return at::_ops::normal_Tensor_float_out::redispatch(dispatchKeySet, mean, std, generator, out); + } + + // aten::normal.Tensor_float_out(Tensor mean, float std=1, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & normal_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & mean, double std, ::std::optional generator, at::Tensor & out) { + return at::_ops::normal_Tensor_float_out::redispatch(dispatchKeySet, mean, std, generator, out); + } + + // aten::normal.Tensor_float(Tensor mean, float std=1, *, Generator? 
generator=None) -> Tensor + inline at::Tensor normal(c10::DispatchKeySet dispatchKeySet, const at::Tensor & mean, double std=1, ::std::optional generator=::std::nullopt) { + return at::_ops::normal_Tensor_float::redispatch(dispatchKeySet, mean, std, generator); + } + + // aten::normal.float_Tensor_out(float mean, Tensor std, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & normal_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, double mean, const at::Tensor & std, ::std::optional generator=::std::nullopt) { + return at::_ops::normal_float_Tensor_out::redispatch(dispatchKeySet, mean, std, generator, out); + } + + // aten::normal.float_Tensor_out(float mean, Tensor std, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & normal_outf(c10::DispatchKeySet dispatchKeySet, double mean, const at::Tensor & std, ::std::optional generator, at::Tensor & out) { + return at::_ops::normal_float_Tensor_out::redispatch(dispatchKeySet, mean, std, generator, out); + } + + // aten::normal.float_Tensor(float mean, Tensor std, *, Generator? generator=None) -> Tensor + inline at::Tensor normal(c10::DispatchKeySet dispatchKeySet, double mean, const at::Tensor & std, ::std::optional generator=::std::nullopt) { + return at::_ops::normal_float_Tensor::redispatch(dispatchKeySet, mean, std, generator); + } + + // aten::normal.Tensor_Tensor_out(Tensor mean, Tensor std, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & normal_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & mean, const at::Tensor & std, ::std::optional generator=::std::nullopt) { + return at::_ops::normal_Tensor_Tensor_out::redispatch(dispatchKeySet, mean, std, generator, out); + } + + // aten::normal.Tensor_Tensor_out(Tensor mean, Tensor std, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & normal_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & mean, const at::Tensor & std, ::std::optional generator, at::Tensor & out) { + return at::_ops::normal_Tensor_Tensor_out::redispatch(dispatchKeySet, mean, std, generator, out); + } + + // aten::normal.Tensor_Tensor(Tensor mean, Tensor std, *, Generator? generator=None) -> Tensor + inline at::Tensor normal(c10::DispatchKeySet dispatchKeySet, const at::Tensor & mean, const at::Tensor & std, ::std::optional generator=::std::nullopt) { + return at::_ops::normal_Tensor_Tensor::redispatch(dispatchKeySet, mean, std, generator); + } + + // aten::normal.float_float(float mean, float std, SymInt[] size, *, Generator? generator=None, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor normal(c10::DispatchKeySet dispatchKeySet, double mean, double std, at::IntArrayRef size, ::std::optional generator=::std::nullopt, at::TensorOptions options={}) { + return at::_ops::normal_float_float::redispatch(dispatchKeySet, mean, std, c10::fromIntArrayRefSlow(size), generator, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::normal.float_float(float mean, float std, SymInt[] size, *, Generator? generator=None, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor normal(c10::DispatchKeySet dispatchKeySet, double mean, double std, at::IntArrayRef size, ::std::optional generator, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::normal_float_float::redispatch(dispatchKeySet, mean, std, c10::fromIntArrayRefSlow(size), generator, dtype, layout, device, pin_memory); + } + + // aten::normal.float_float(float mean, float std, SymInt[] size, *, Generator? generator=None, ScalarType? dtype=None, Layout? 
layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor normal_symint(c10::DispatchKeySet dispatchKeySet, double mean, double std, c10::SymIntArrayRef size, ::std::optional generator=::std::nullopt, at::TensorOptions options={}) { + return at::_ops::normal_float_float::redispatch(dispatchKeySet, mean, std, size, generator, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::normal.float_float(float mean, float std, SymInt[] size, *, Generator? generator=None, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor normal_symint(c10::DispatchKeySet dispatchKeySet, double mean, double std, c10::SymIntArrayRef size, ::std::optional generator, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::normal_float_float::redispatch(dispatchKeySet, mean, std, size, generator, dtype, layout, device, pin_memory); + } + + // aten::normal.float_float_out(float mean, float std, SymInt[] size, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & normal_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, double mean, double std, at::IntArrayRef size, ::std::optional generator=::std::nullopt) { + return at::_ops::normal_float_float_out::redispatch(dispatchKeySet, mean, std, c10::fromIntArrayRefSlow(size), generator, out); + } + + // aten::normal.float_float_out(float mean, float std, SymInt[] size, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & normal_outf(c10::DispatchKeySet dispatchKeySet, double mean, double std, at::IntArrayRef size, ::std::optional generator, at::Tensor & out) { + return at::_ops::normal_float_float_out::redispatch(dispatchKeySet, mean, std, c10::fromIntArrayRefSlow(size), generator, out); + } + + // aten::normal.float_float_out(float mean, float std, SymInt[] size, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & normal_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, double mean, double std, c10::SymIntArrayRef size, ::std::optional generator=::std::nullopt) { + return at::_ops::normal_float_float_out::redispatch(dispatchKeySet, mean, std, size, generator, out); + } + + // aten::normal.float_float_out(float mean, float std, SymInt[] size, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & normal_symint_outf(c10::DispatchKeySet dispatchKeySet, double mean, double std, c10::SymIntArrayRef size, ::std::optional generator, at::Tensor & out) { + return at::_ops::normal_float_float_out::redispatch(dispatchKeySet, mean, std, size, generator, out); + } + + // aten::alias(Tensor(a) self) -> Tensor(a) + inline at::Tensor alias(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::alias::redispatch(dispatchKeySet, self); + } + + // aten::_amp_foreach_non_finite_check_and_unscale_(Tensor(a!)[] self, Tensor(b!) found_inf, Tensor inv_scale) -> () + inline void _amp_foreach_non_finite_check_and_unscale_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::Tensor & found_inf, const at::Tensor & inv_scale) { + return at::_ops::_amp_foreach_non_finite_check_and_unscale_::redispatch(dispatchKeySet, self, found_inf, inv_scale); + } + + // aten::_amp_update_scale_(Tensor(a!) self, Tensor(b!) growth_tracker, Tensor found_inf, float scale_growth_factor, float scale_backoff_factor, int growth_interval) -> Tensor(a!) 
+ inline at::Tensor & _amp_update_scale_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, at::Tensor & growth_tracker, const at::Tensor & found_inf, double scale_growth_factor, double scale_backoff_factor, int64_t growth_interval) { + return at::_ops::_amp_update_scale_::redispatch(dispatchKeySet, self, growth_tracker, found_inf, scale_growth_factor, scale_backoff_factor, growth_interval); + } + + // aten::_foreach_add.Scalar(Tensor[] self, Scalar scalar) -> Tensor[] + inline ::std::vector _foreach_add(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Scalar & scalar) { + return at::_ops::_foreach_add_Scalar::redispatch(dispatchKeySet, self, scalar); + } + + // aten::_foreach_add_.Scalar(Tensor(a!)[] self, Scalar scalar) -> () + inline void _foreach_add_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Scalar & scalar) { + return at::_ops::_foreach_add__Scalar::redispatch(dispatchKeySet, self, scalar); + } + + // aten::_foreach_add.List(Tensor[] self, Tensor[] other, *, Scalar alpha=1) -> Tensor[] + inline ::std::vector _foreach_add(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList other, const at::Scalar & alpha=1) { + return at::_ops::_foreach_add_List::redispatch(dispatchKeySet, self, other, alpha); + } + + // aten::_foreach_add_.List(Tensor(a!)[] self, Tensor[] other, *, Scalar alpha=1) -> () + inline void _foreach_add_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList other, const at::Scalar & alpha=1) { + return at::_ops::_foreach_add__List::redispatch(dispatchKeySet, self, other, alpha); + } + + // aten::_foreach_add.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[] + inline ::std::vector _foreach_add(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::ArrayRef scalars) { + return at::_ops::_foreach_add_ScalarList::redispatch(dispatchKeySet, self, scalars); + } + + // aten::_foreach_add_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> () + inline void 
_foreach_add_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::ArrayRef scalars) { + return at::_ops::_foreach_add__ScalarList::redispatch(dispatchKeySet, self, scalars); + } + + // aten::_foreach_add.Tensor(Tensor[] self, Tensor other, *, Scalar alpha=1) -> Tensor[] + inline ::std::vector _foreach_add(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Tensor & other, const at::Scalar & alpha=1) { + return at::_ops::_foreach_add_Tensor::redispatch(dispatchKeySet, self, other, alpha); + } + + // aten::_foreach_add_.Tensor(Tensor(a!)[] self, Tensor other, *, Scalar alpha=1) -> () + inline void _foreach_add_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Tensor & other, const at::Scalar & alpha=1) { + return at::_ops::_foreach_add__Tensor::redispatch(dispatchKeySet, self, other, alpha); + } + + // aten::_foreach_sub.Scalar(Tensor[] self, Scalar scalar) -> Tensor[] + inline ::std::vector _foreach_sub(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Scalar & scalar) { + return at::_ops::_foreach_sub_Scalar::redispatch(dispatchKeySet, self, scalar); + } + + // aten::_foreach_sub_.Scalar(Tensor(a!)[] self, Scalar scalar) -> () + inline void _foreach_sub_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Scalar & scalar) { + return at::_ops::_foreach_sub__Scalar::redispatch(dispatchKeySet, self, scalar); + } + + // aten::_foreach_sub.List(Tensor[] self, Tensor[] other, *, Scalar alpha=1) -> Tensor[] + inline ::std::vector _foreach_sub(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList other, const at::Scalar & alpha=1) { + return at::_ops::_foreach_sub_List::redispatch(dispatchKeySet, self, other, alpha); + } + + // aten::_foreach_sub_.List(Tensor(a!)[] self, Tensor[] other, *, Scalar alpha=1) -> () + inline void _foreach_sub_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList other, const at::Scalar & alpha=1) { + return 
at::_ops::_foreach_sub__List::redispatch(dispatchKeySet, self, other, alpha); + } + + // aten::_foreach_sub.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[] + inline ::std::vector _foreach_sub(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::ArrayRef scalars) { + return at::_ops::_foreach_sub_ScalarList::redispatch(dispatchKeySet, self, scalars); + } + + // aten::_foreach_sub_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> () + inline void _foreach_sub_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::ArrayRef scalars) { + return at::_ops::_foreach_sub__ScalarList::redispatch(dispatchKeySet, self, scalars); + } + + // aten::_foreach_mul.Scalar(Tensor[] self, Scalar scalar) -> Tensor[] + inline ::std::vector _foreach_mul(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Scalar & scalar) { + return at::_ops::_foreach_mul_Scalar::redispatch(dispatchKeySet, self, scalar); + } + + // aten::_foreach_mul_.Scalar(Tensor(a!)[] self, Scalar scalar) -> () + inline void _foreach_mul_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Scalar & scalar) { + return at::_ops::_foreach_mul__Scalar::redispatch(dispatchKeySet, self, scalar); + } + + // aten::_foreach_mul.List(Tensor[] self, Tensor[] other) -> Tensor[] + inline ::std::vector _foreach_mul(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList other) { + return at::_ops::_foreach_mul_List::redispatch(dispatchKeySet, self, other); + } + + // aten::_foreach_mul_.List(Tensor(a!)[] self, Tensor[] other) -> () + inline void _foreach_mul_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList other) { + return at::_ops::_foreach_mul__List::redispatch(dispatchKeySet, self, other); + } + + // aten::_foreach_mul.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[] + inline ::std::vector _foreach_mul(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::ArrayRef scalars) { + return 
at::_ops::_foreach_mul_ScalarList::redispatch(dispatchKeySet, self, scalars); + } + + // aten::_foreach_mul_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> () + inline void _foreach_mul_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::ArrayRef scalars) { + return at::_ops::_foreach_mul__ScalarList::redispatch(dispatchKeySet, self, scalars); + } + + // aten::_foreach_mul.Tensor(Tensor[] self, Tensor other) -> Tensor[] + inline ::std::vector _foreach_mul(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Tensor & other) { + return at::_ops::_foreach_mul_Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::_foreach_mul_.Tensor(Tensor(a!)[] self, Tensor other) -> () + inline void _foreach_mul_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Tensor & other) { + return at::_ops::_foreach_mul__Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::_foreach_div.Scalar(Tensor[] self, Scalar scalar) -> Tensor[] + inline ::std::vector _foreach_div(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Scalar & scalar) { + return at::_ops::_foreach_div_Scalar::redispatch(dispatchKeySet, self, scalar); + } + + // aten::_foreach_div_.Scalar(Tensor(a!)[] self, Scalar scalar) -> () + inline void _foreach_div_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Scalar & scalar) { + return at::_ops::_foreach_div__Scalar::redispatch(dispatchKeySet, self, scalar); + } + + // aten::_foreach_div.List(Tensor[] self, Tensor[] other) -> Tensor[] + inline ::std::vector _foreach_div(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList other) { + return at::_ops::_foreach_div_List::redispatch(dispatchKeySet, self, other); + } + + // aten::_foreach_div_.List(Tensor(a!)[] self, Tensor[] other) -> () + inline void _foreach_div_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList other) { + return at::_ops::_foreach_div__List::redispatch(dispatchKeySet, 
self, other); + } + + // aten::_foreach_div.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[] + inline ::std::vector _foreach_div(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::ArrayRef scalars) { + return at::_ops::_foreach_div_ScalarList::redispatch(dispatchKeySet, self, scalars); + } + + // aten::_foreach_div_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> () + inline void _foreach_div_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::ArrayRef scalars) { + return at::_ops::_foreach_div__ScalarList::redispatch(dispatchKeySet, self, scalars); + } + + // aten::_foreach_div.Tensor(Tensor[] self, Tensor other) -> Tensor[] + inline ::std::vector _foreach_div(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Tensor & other) { + return at::_ops::_foreach_div_Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::_foreach_div_.Tensor(Tensor(a!)[] self, Tensor other) -> () + inline void _foreach_div_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Tensor & other) { + return at::_ops::_foreach_div__Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::_foreach_clamp_max.Scalar(Tensor[] self, Scalar scalar) -> Tensor[] + inline ::std::vector _foreach_clamp_max(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Scalar & scalar) { + return at::_ops::_foreach_clamp_max_Scalar::redispatch(dispatchKeySet, self, scalar); + } + + // aten::_foreach_clamp_max_.Scalar(Tensor(a!)[] self, Scalar scalar) -> () + inline void _foreach_clamp_max_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Scalar & scalar) { + return at::_ops::_foreach_clamp_max__Scalar::redispatch(dispatchKeySet, self, scalar); + } + + // aten::_foreach_clamp_max.List(Tensor[] self, Tensor[] other) -> Tensor[] + inline ::std::vector _foreach_clamp_max(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList other) { + return 
at::_ops::_foreach_clamp_max_List::redispatch(dispatchKeySet, self, other); + } + + // aten::_foreach_clamp_max_.List(Tensor(a!)[] self, Tensor[] other) -> () + inline void _foreach_clamp_max_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList other) { + return at::_ops::_foreach_clamp_max__List::redispatch(dispatchKeySet, self, other); + } + + // aten::_foreach_clamp_max.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[] + inline ::std::vector _foreach_clamp_max(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::ArrayRef scalars) { + return at::_ops::_foreach_clamp_max_ScalarList::redispatch(dispatchKeySet, self, scalars); + } + + // aten::_foreach_clamp_max_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> () + inline void _foreach_clamp_max_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::ArrayRef scalars) { + return at::_ops::_foreach_clamp_max__ScalarList::redispatch(dispatchKeySet, self, scalars); + } + + // aten::_foreach_clamp_min.Scalar(Tensor[] self, Scalar scalar) -> Tensor[] + inline ::std::vector _foreach_clamp_min(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Scalar & scalar) { + return at::_ops::_foreach_clamp_min_Scalar::redispatch(dispatchKeySet, self, scalar); + } + + // aten::_foreach_clamp_min_.Scalar(Tensor(a!)[] self, Scalar scalar) -> () + inline void _foreach_clamp_min_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Scalar & scalar) { + return at::_ops::_foreach_clamp_min__Scalar::redispatch(dispatchKeySet, self, scalar); + } + + // aten::_foreach_clamp_min.List(Tensor[] self, Tensor[] other) -> Tensor[] + inline ::std::vector _foreach_clamp_min(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList other) { + return at::_ops::_foreach_clamp_min_List::redispatch(dispatchKeySet, self, other); + } + + // aten::_foreach_clamp_min_.List(Tensor(a!)[] self, Tensor[] other) -> () + inline void _foreach_clamp_min_(c10::DispatchKeySet 
dispatchKeySet, at::TensorList self, at::TensorList other) { + return at::_ops::_foreach_clamp_min__List::redispatch(dispatchKeySet, self, other); + } + + // aten::_foreach_clamp_min.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[] + inline ::std::vector _foreach_clamp_min(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::ArrayRef scalars) { + return at::_ops::_foreach_clamp_min_ScalarList::redispatch(dispatchKeySet, self, scalars); + } + + // aten::_foreach_clamp_min_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> () + inline void _foreach_clamp_min_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::ArrayRef scalars) { + return at::_ops::_foreach_clamp_min__ScalarList::redispatch(dispatchKeySet, self, scalars); + } + + // aten::_foreach_maximum.Scalar(Tensor[] self, Scalar scalar) -> Tensor[] + inline ::std::vector _foreach_maximum(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Scalar & scalar) { + return at::_ops::_foreach_maximum_Scalar::redispatch(dispatchKeySet, self, scalar); + } + + // aten::_foreach_maximum_.Scalar(Tensor(a!)[] self, Scalar scalar) -> () + inline void _foreach_maximum_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Scalar & scalar) { + return at::_ops::_foreach_maximum__Scalar::redispatch(dispatchKeySet, self, scalar); + } + + // aten::_foreach_maximum.List(Tensor[] self, Tensor[] other) -> Tensor[] + inline ::std::vector _foreach_maximum(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList other) { + return at::_ops::_foreach_maximum_List::redispatch(dispatchKeySet, self, other); + } + + // aten::_foreach_maximum_.List(Tensor(a!)[] self, Tensor[] other) -> () + inline void _foreach_maximum_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList other) { + return at::_ops::_foreach_maximum__List::redispatch(dispatchKeySet, self, other); + } + + // aten::_foreach_maximum.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[] + inline 
::std::vector _foreach_maximum(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::ArrayRef scalars) { + return at::_ops::_foreach_maximum_ScalarList::redispatch(dispatchKeySet, self, scalars); + } + + // aten::_foreach_maximum_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> () + inline void _foreach_maximum_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::ArrayRef scalars) { + return at::_ops::_foreach_maximum__ScalarList::redispatch(dispatchKeySet, self, scalars); + } + + // aten::_foreach_minimum.Scalar(Tensor[] self, Scalar scalar) -> Tensor[] + inline ::std::vector _foreach_minimum(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Scalar & scalar) { + return at::_ops::_foreach_minimum_Scalar::redispatch(dispatchKeySet, self, scalar); + } + + // aten::_foreach_minimum_.Scalar(Tensor(a!)[] self, Scalar scalar) -> () + inline void _foreach_minimum_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Scalar & scalar) { + return at::_ops::_foreach_minimum__Scalar::redispatch(dispatchKeySet, self, scalar); + } + + // aten::_foreach_minimum.List(Tensor[] self, Tensor[] other) -> Tensor[] + inline ::std::vector _foreach_minimum(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList other) { + return at::_ops::_foreach_minimum_List::redispatch(dispatchKeySet, self, other); + } + + // aten::_foreach_minimum_.List(Tensor(a!)[] self, Tensor[] other) -> () + inline void _foreach_minimum_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList other) { + return at::_ops::_foreach_minimum__List::redispatch(dispatchKeySet, self, other); + } + + // aten::_foreach_minimum.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[] + inline ::std::vector _foreach_minimum(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::ArrayRef scalars) { + return at::_ops::_foreach_minimum_ScalarList::redispatch(dispatchKeySet, self, scalars); + } + + // 
aten::_foreach_minimum_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> () + inline void _foreach_minimum_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::ArrayRef scalars) { + return at::_ops::_foreach_minimum__ScalarList::redispatch(dispatchKeySet, self, scalars); + } + + // aten::_foreach_addcdiv.Scalar(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> Tensor[] + inline ::std::vector _foreach_addcdiv(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Scalar & value=1) { + return at::_ops::_foreach_addcdiv_Scalar::redispatch(dispatchKeySet, self, tensor1, tensor2, value); + } + + // aten::_foreach_addcdiv.ScalarList(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> Tensor[] + inline ::std::vector _foreach_addcdiv(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, at::ArrayRef scalars) { + return at::_ops::_foreach_addcdiv_ScalarList::redispatch(dispatchKeySet, self, tensor1, tensor2, scalars); + } + + // aten::_foreach_addcdiv.Tensor(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Tensor scalars) -> Tensor[] + inline ::std::vector _foreach_addcdiv(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Tensor & scalars) { + return at::_ops::_foreach_addcdiv_Tensor::redispatch(dispatchKeySet, self, tensor1, tensor2, scalars); + } + + // aten::_foreach_addcdiv_.Scalar(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> () + inline void _foreach_addcdiv_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Scalar & value=1) { + return at::_ops::_foreach_addcdiv__Scalar::redispatch(dispatchKeySet, self, tensor1, tensor2, value); + } + + // aten::_foreach_addcdiv_.ScalarList(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> () 
+ inline void _foreach_addcdiv_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, at::ArrayRef scalars) { + return at::_ops::_foreach_addcdiv__ScalarList::redispatch(dispatchKeySet, self, tensor1, tensor2, scalars); + } + + // aten::_foreach_addcdiv_.Tensor(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Tensor scalars) -> () + inline void _foreach_addcdiv_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Tensor & scalars) { + return at::_ops::_foreach_addcdiv__Tensor::redispatch(dispatchKeySet, self, tensor1, tensor2, scalars); + } + + // aten::_foreach_addcmul.Scalar(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> Tensor[] + inline ::std::vector _foreach_addcmul(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Scalar & value=1) { + return at::_ops::_foreach_addcmul_Scalar::redispatch(dispatchKeySet, self, tensor1, tensor2, value); + } + + // aten::_foreach_addcmul.ScalarList(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> Tensor[] + inline ::std::vector _foreach_addcmul(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, at::ArrayRef scalars) { + return at::_ops::_foreach_addcmul_ScalarList::redispatch(dispatchKeySet, self, tensor1, tensor2, scalars); + } + + // aten::_foreach_addcmul.Tensor(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Tensor scalars) -> Tensor[] + inline ::std::vector _foreach_addcmul(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Tensor & scalars) { + return at::_ops::_foreach_addcmul_Tensor::redispatch(dispatchKeySet, self, tensor1, tensor2, scalars); + } + + // aten::_foreach_addcmul_.Scalar(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> () + inline void 
_foreach_addcmul_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Scalar & value=1) { + return at::_ops::_foreach_addcmul__Scalar::redispatch(dispatchKeySet, self, tensor1, tensor2, value); + } + + // aten::_foreach_addcmul_.ScalarList(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> () + inline void _foreach_addcmul_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, at::ArrayRef scalars) { + return at::_ops::_foreach_addcmul__ScalarList::redispatch(dispatchKeySet, self, tensor1, tensor2, scalars); + } + + // aten::_foreach_addcmul_.Tensor(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Tensor scalars) -> () + inline void _foreach_addcmul_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Tensor & scalars) { + return at::_ops::_foreach_addcmul__Tensor::redispatch(dispatchKeySet, self, tensor1, tensor2, scalars); + } + + // aten::_foreach_abs(Tensor[] self) -> Tensor[] + inline ::std::vector _foreach_abs(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_abs::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_abs_(Tensor(a!)[] self) -> () + inline void _foreach_abs_(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_abs_::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_acos(Tensor[] self) -> Tensor[] + inline ::std::vector _foreach_acos(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_acos::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_acos_(Tensor(a!)[] self) -> () + inline void _foreach_acos_(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_acos_::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_asin(Tensor[] self) -> Tensor[] + inline ::std::vector 
_foreach_asin(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_asin::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_asin_(Tensor(a!)[] self) -> () + inline void _foreach_asin_(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_asin_::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_atan(Tensor[] self) -> Tensor[] + inline ::std::vector _foreach_atan(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_atan::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_atan_(Tensor(a!)[] self) -> () + inline void _foreach_atan_(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_atan_::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_ceil(Tensor[] self) -> Tensor[] + inline ::std::vector _foreach_ceil(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_ceil::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_ceil_(Tensor(a!)[] self) -> () + inline void _foreach_ceil_(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_ceil_::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_cos(Tensor[] self) -> Tensor[] + inline ::std::vector _foreach_cos(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_cos::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_cos_(Tensor(a!)[] self) -> () + inline void _foreach_cos_(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_cos_::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_cosh(Tensor[] self) -> Tensor[] + inline ::std::vector _foreach_cosh(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_cosh::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_cosh_(Tensor(a!)[] self) -> () + inline void _foreach_cosh_(c10::DispatchKeySet 
dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_cosh_::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_erf(Tensor[] self) -> Tensor[] + inline ::std::vector _foreach_erf(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_erf::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_erf_(Tensor(a!)[] self) -> () + inline void _foreach_erf_(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_erf_::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_erfc(Tensor[] self) -> Tensor[] + inline ::std::vector _foreach_erfc(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_erfc::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_erfc_(Tensor(a!)[] self) -> () + inline void _foreach_erfc_(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_erfc_::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_exp(Tensor[] self) -> Tensor[] + inline ::std::vector _foreach_exp(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_exp::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_exp_(Tensor(a!)[] self) -> () + inline void _foreach_exp_(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_exp_::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_expm1(Tensor[] self) -> Tensor[] + inline ::std::vector _foreach_expm1(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_expm1::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_expm1_(Tensor(a!)[] self) -> () + inline void _foreach_expm1_(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_expm1_::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_floor(Tensor[] self) -> Tensor[] + inline ::std::vector _foreach_floor(c10::DispatchKeySet dispatchKeySet, at::TensorList self) 
{ + return at::_ops::_foreach_floor::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_floor_(Tensor(a!)[] self) -> () + inline void _foreach_floor_(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_floor_::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_frac(Tensor[] self) -> Tensor[] + inline ::std::vector _foreach_frac(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_frac::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_frac_(Tensor(a!)[] self) -> () + inline void _foreach_frac_(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_frac_::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_lerp.List(Tensor[] self, Tensor[] tensors1, Tensor[] weights) -> Tensor[] + inline ::std::vector _foreach_lerp(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList tensors1, at::TensorList weights) { + return at::_ops::_foreach_lerp_List::redispatch(dispatchKeySet, self, tensors1, weights); + } + + // aten::_foreach_lerp_.List(Tensor(a!)[] self, Tensor[] tensors1, Tensor[] weights) -> () + inline void _foreach_lerp_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList tensors1, at::TensorList weights) { + return at::_ops::_foreach_lerp__List::redispatch(dispatchKeySet, self, tensors1, weights); + } + + // aten::_foreach_lerp.Scalar(Tensor[] self, Tensor[] tensors1, Scalar weight) -> Tensor[] + inline ::std::vector _foreach_lerp(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList tensors1, const at::Scalar & weight) { + return at::_ops::_foreach_lerp_Scalar::redispatch(dispatchKeySet, self, tensors1, weight); + } + + // aten::_foreach_lerp_.Scalar(Tensor(a!)[] self, Tensor[] tensors1, Scalar weight) -> () + inline void _foreach_lerp_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList tensors1, const at::Scalar & weight) { + return 
at::_ops::_foreach_lerp__Scalar::redispatch(dispatchKeySet, self, tensors1, weight); + } + + // aten::_foreach_lerp.ScalarList(Tensor[] self, Tensor[] tensors1, Scalar[] weight) -> Tensor[] + inline ::std::vector _foreach_lerp(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList tensors1, at::ArrayRef weight) { + return at::_ops::_foreach_lerp_ScalarList::redispatch(dispatchKeySet, self, tensors1, weight); + } + + // aten::_foreach_lerp_.ScalarList(Tensor(a!)[] self, Tensor[] tensors1, Scalar[] weight) -> () + inline void _foreach_lerp_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList tensors1, at::ArrayRef weight) { + return at::_ops::_foreach_lerp__ScalarList::redispatch(dispatchKeySet, self, tensors1, weight); + } + + // aten::_foreach_lgamma(Tensor[] self) -> Tensor[] + inline ::std::vector _foreach_lgamma(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_lgamma::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_lgamma_(Tensor(a!)[] self) -> () + inline void _foreach_lgamma_(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_lgamma_::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_log(Tensor[] self) -> Tensor[] + inline ::std::vector _foreach_log(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_log::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_log_(Tensor(a!)[] self) -> () + inline void _foreach_log_(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_log_::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_log10(Tensor[] self) -> Tensor[] + inline ::std::vector _foreach_log10(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_log10::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_log10_(Tensor(a!)[] self) -> () + inline void _foreach_log10_(c10::DispatchKeySet dispatchKeySet, 
at::TensorList self) { + return at::_ops::_foreach_log10_::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_log1p(Tensor[] self) -> Tensor[] + inline ::std::vector _foreach_log1p(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_log1p::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_log1p_(Tensor(a!)[] self) -> () + inline void _foreach_log1p_(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_log1p_::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_log2(Tensor[] self) -> Tensor[] + inline ::std::vector _foreach_log2(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_log2::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_log2_(Tensor(a!)[] self) -> () + inline void _foreach_log2_(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_log2_::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_max(Tensor[] self) -> Tensor[] + inline ::std::vector _foreach_max(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_max::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_neg(Tensor[] self) -> Tensor[] + inline ::std::vector _foreach_neg(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_neg::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_neg_(Tensor(a!)[] self) -> () + inline void _foreach_neg_(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_neg_::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_norm.Scalar(Tensor[] self, Scalar ord=2, ScalarType? 
dtype=None) -> Tensor[] + inline ::std::vector _foreach_norm(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Scalar & ord=2, ::std::optional dtype=::std::nullopt) { + return at::_ops::_foreach_norm_Scalar::redispatch(dispatchKeySet, self, ord, dtype); + } + + // aten::_foreach_pow.List(Tensor[] self, Tensor[] exponent) -> Tensor[] + inline ::std::vector _foreach_pow(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList exponent) { + return at::_ops::_foreach_pow_List::redispatch(dispatchKeySet, self, exponent); + } + + // aten::_foreach_pow.Scalar(Tensor[] self, Scalar exponent) -> Tensor[] + inline ::std::vector _foreach_pow(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Scalar & exponent) { + return at::_ops::_foreach_pow_Scalar::redispatch(dispatchKeySet, self, exponent); + } + + // aten::_foreach_pow.ScalarList(Tensor[] self, Scalar[] exponent) -> Tensor[] + inline ::std::vector _foreach_pow(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::ArrayRef exponent) { + return at::_ops::_foreach_pow_ScalarList::redispatch(dispatchKeySet, self, exponent); + } + + // aten::_foreach_pow.ScalarAndTensor(Scalar self, Tensor[] exponent) -> Tensor[] + inline ::std::vector _foreach_pow(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, at::TensorList exponent) { + return at::_ops::_foreach_pow_ScalarAndTensor::redispatch(dispatchKeySet, self, exponent); + } + + // aten::_foreach_pow_.List(Tensor(a!)[] self, Tensor[] exponent) -> () + inline void _foreach_pow_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList exponent) { + return at::_ops::_foreach_pow__List::redispatch(dispatchKeySet, self, exponent); + } + + // aten::_foreach_pow_.Scalar(Tensor(a!)[] self, Scalar exponent) -> () + inline void _foreach_pow_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Scalar & exponent) { + return at::_ops::_foreach_pow__Scalar::redispatch(dispatchKeySet, self, exponent); 
+ } + + // aten::_foreach_pow_.ScalarList(Tensor(a!)[] self, Scalar[] exponent) -> () + inline void _foreach_pow_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::ArrayRef exponent) { + return at::_ops::_foreach_pow__ScalarList::redispatch(dispatchKeySet, self, exponent); + } + + // aten::_foreach_reciprocal(Tensor[] self) -> Tensor[] + inline ::std::vector _foreach_reciprocal(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_reciprocal::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_reciprocal_(Tensor(a!)[] self) -> () + inline void _foreach_reciprocal_(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_reciprocal_::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_round(Tensor[] self) -> Tensor[] + inline ::std::vector _foreach_round(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_round::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_round_(Tensor(a!)[] self) -> () + inline void _foreach_round_(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_round_::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_rsqrt(Tensor[] self) -> Tensor[] + inline ::std::vector _foreach_rsqrt(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_rsqrt::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_rsqrt_(Tensor(a!)[] self) -> () + inline void _foreach_rsqrt_(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_rsqrt_::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_sigmoid(Tensor[] self) -> Tensor[] + inline ::std::vector _foreach_sigmoid(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_sigmoid::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_sigmoid_(Tensor(a!)[] self) -> () + inline void _foreach_sigmoid_(c10::DispatchKeySet dispatchKeySet, 
at::TensorList self) { + return at::_ops::_foreach_sigmoid_::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_sign(Tensor[] self) -> Tensor[] + inline ::std::vector _foreach_sign(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_sign::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_sign_(Tensor(a!)[] self) -> () + inline void _foreach_sign_(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_sign_::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_sin(Tensor[] self) -> Tensor[] + inline ::std::vector _foreach_sin(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_sin::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_sin_(Tensor(a!)[] self) -> () + inline void _foreach_sin_(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_sin_::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_sinh(Tensor[] self) -> Tensor[] + inline ::std::vector _foreach_sinh(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_sinh::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_sinh_(Tensor(a!)[] self) -> () + inline void _foreach_sinh_(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_sinh_::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_sqrt(Tensor[] self) -> Tensor[] + inline ::std::vector _foreach_sqrt(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_sqrt::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_sqrt_(Tensor(a!)[] self) -> () + inline void _foreach_sqrt_(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_sqrt_::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_tan(Tensor[] self) -> Tensor[] + inline ::std::vector _foreach_tan(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return 
at::_ops::_foreach_tan::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_tan_(Tensor(a!)[] self) -> () + inline void _foreach_tan_(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_tan_::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_tanh(Tensor[] self) -> Tensor[] + inline ::std::vector _foreach_tanh(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_tanh::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_tanh_(Tensor(a!)[] self) -> () + inline void _foreach_tanh_(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_tanh_::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_trunc(Tensor[] self) -> Tensor[] + inline ::std::vector _foreach_trunc(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_trunc::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_trunc_(Tensor(a!)[] self) -> () + inline void _foreach_trunc_(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_trunc_::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_zero_(Tensor(a!)[] self) -> () + inline void _foreach_zero_(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_zero_::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_copy_(Tensor(a!)[] self, Tensor[] src, bool non_blocking=False) -> () + inline void _foreach_copy_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList src, bool non_blocking=false) { + return at::_ops::_foreach_copy_::redispatch(dispatchKeySet, self, src, non_blocking); + } + + // aten::_foreach_copy(Tensor[] self, Tensor[] src, bool non_blocking=False) -> Tensor[] self_out + inline ::std::vector _foreach_copy(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList src, bool non_blocking=false) { + return at::_ops::_foreach_copy::redispatch(dispatchKeySet, self, src, 
non_blocking); + } + + // aten::bucketize.Tensor(Tensor self, Tensor boundaries, *, bool out_int32=False, bool right=False) -> Tensor + inline at::Tensor bucketize(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & boundaries, bool out_int32=false, bool right=false) { + return at::_ops::bucketize_Tensor::redispatch(dispatchKeySet, self, boundaries, out_int32, right); + } + + // aten::bucketize.Tensor_out(Tensor self, Tensor boundaries, *, bool out_int32=False, bool right=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bucketize_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & boundaries, bool out_int32=false, bool right=false) { + return at::_ops::bucketize_Tensor_out::redispatch(dispatchKeySet, self, boundaries, out_int32, right, out); + } + + // aten::bucketize.Tensor_out(Tensor self, Tensor boundaries, *, bool out_int32=False, bool right=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bucketize_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & boundaries, bool out_int32, bool right, at::Tensor & out) { + return at::_ops::bucketize_Tensor_out::redispatch(dispatchKeySet, self, boundaries, out_int32, right, out); + } + + // aten::bucketize.Scalar(Scalar self, Tensor boundaries, *, bool out_int32=False, bool right=False) -> Tensor + inline at::Tensor bucketize(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, const at::Tensor & boundaries, bool out_int32=false, bool right=false) { + return at::_ops::bucketize_Scalar::redispatch(dispatchKeySet, self, boundaries, out_int32, right); + } + + // aten::searchsorted.Tensor(Tensor sorted_sequence, Tensor self, *, bool out_int32=False, bool right=False, str? side=None, Tensor? 
sorter=None) -> Tensor + inline at::Tensor searchsorted(c10::DispatchKeySet dispatchKeySet, const at::Tensor & sorted_sequence, const at::Tensor & self, bool out_int32=false, bool right=false, ::std::optional side=::std::nullopt, const ::std::optional & sorter={}) { + return at::_ops::searchsorted_Tensor::redispatch(dispatchKeySet, sorted_sequence, self, out_int32, right, side, sorter); + } + + // aten::searchsorted.Tensor_out(Tensor sorted_sequence, Tensor self, *, bool out_int32=False, bool right=False, str? side=None, Tensor? sorter=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & searchsorted_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & sorted_sequence, const at::Tensor & self, bool out_int32=false, bool right=false, ::std::optional side=::std::nullopt, const ::std::optional & sorter={}) { + return at::_ops::searchsorted_Tensor_out::redispatch(dispatchKeySet, sorted_sequence, self, out_int32, right, side, sorter, out); + } + + // aten::searchsorted.Tensor_out(Tensor sorted_sequence, Tensor self, *, bool out_int32=False, bool right=False, str? side=None, Tensor? sorter=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & searchsorted_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & sorted_sequence, const at::Tensor & self, bool out_int32, bool right, ::std::optional side, const ::std::optional & sorter, at::Tensor & out) { + return at::_ops::searchsorted_Tensor_out::redispatch(dispatchKeySet, sorted_sequence, self, out_int32, right, side, sorter, out); + } + + // aten::searchsorted.Scalar(Tensor sorted_sequence, Scalar self, *, bool out_int32=False, bool right=False, str? side=None, Tensor? 
sorter=None) -> Tensor + inline at::Tensor searchsorted(c10::DispatchKeySet dispatchKeySet, const at::Tensor & sorted_sequence, const at::Scalar & self, bool out_int32=false, bool right=false, ::std::optional side=::std::nullopt, const ::std::optional & sorter={}) { + return at::_ops::searchsorted_Scalar::redispatch(dispatchKeySet, sorted_sequence, self, out_int32, right, side, sorter); + } + + // aten::searchsorted.Scalar_out(Tensor sorted_sequence, Scalar self, *, bool out_int32=False, bool right=False, str? side=None, Tensor? sorter=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & searchsorted_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & sorted_sequence, const at::Scalar & self, bool out_int32=false, bool right=false, ::std::optional side=::std::nullopt, const ::std::optional & sorter={}) { + return at::_ops::searchsorted_Scalar_out::redispatch(dispatchKeySet, sorted_sequence, self, out_int32, right, side, sorter, out); + } + + // aten::searchsorted.Scalar_out(Tensor sorted_sequence, Scalar self, *, bool out_int32=False, bool right=False, str? side=None, Tensor? sorter=None, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & searchsorted_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & sorted_sequence, const at::Scalar & self, bool out_int32, bool right, ::std::optional side, const ::std::optional & sorter, at::Tensor & out) { + return at::_ops::searchsorted_Scalar_out::redispatch(dispatchKeySet, sorted_sequence, self, out_int32, right, side, sorter, out); + } + + // aten::_convert_indices_from_coo_to_csr(Tensor self, int size, *, bool out_int32=False) -> Tensor + inline at::Tensor _convert_indices_from_coo_to_csr(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t size, bool out_int32=false) { + return at::_ops::_convert_indices_from_coo_to_csr::redispatch(dispatchKeySet, self, size, out_int32); + } + + // aten::_convert_indices_from_coo_to_csr.out(Tensor self, int size, *, bool out_int32=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _convert_indices_from_coo_to_csr_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t size, bool out_int32=false) { + return at::_ops::_convert_indices_from_coo_to_csr_out::redispatch(dispatchKeySet, self, size, out_int32, out); + } + + // aten::_convert_indices_from_coo_to_csr.out(Tensor self, int size, *, bool out_int32=False, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & _convert_indices_from_coo_to_csr_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t size, bool out_int32, at::Tensor & out) { + return at::_ops::_convert_indices_from_coo_to_csr_out::redispatch(dispatchKeySet, self, size, out_int32, out); + } + + // aten::_convert_indices_from_csr_to_coo(Tensor crow_indices, Tensor col_indices, *, bool out_int32=False, bool transpose=False) -> Tensor + inline at::Tensor _convert_indices_from_csr_to_coo(c10::DispatchKeySet dispatchKeySet, const at::Tensor & crow_indices, const at::Tensor & col_indices, bool out_int32=false, bool transpose=false) { + return at::_ops::_convert_indices_from_csr_to_coo::redispatch(dispatchKeySet, crow_indices, col_indices, out_int32, transpose); + } + + // aten::_convert_indices_from_csr_to_coo.out(Tensor crow_indices, Tensor col_indices, *, bool out_int32=False, bool transpose=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _convert_indices_from_csr_to_coo_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & crow_indices, const at::Tensor & col_indices, bool out_int32=false, bool transpose=false) { + return at::_ops::_convert_indices_from_csr_to_coo_out::redispatch(dispatchKeySet, crow_indices, col_indices, out_int32, transpose, out); + } + + // aten::_convert_indices_from_csr_to_coo.out(Tensor crow_indices, Tensor col_indices, *, bool out_int32=False, bool transpose=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _convert_indices_from_csr_to_coo_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & crow_indices, const at::Tensor & col_indices, bool out_int32, bool transpose, at::Tensor & out) { + return at::_ops::_convert_indices_from_csr_to_coo_out::redispatch(dispatchKeySet, crow_indices, col_indices, out_int32, transpose, out); + } + + // aten::mse_loss.out(Tensor self, Tensor target, int reduction=Mean, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & mse_loss_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & target, int64_t reduction=at::Reduction::Mean) { + return at::_ops::mse_loss_out::redispatch(dispatchKeySet, self, target, reduction, out); + } + + // aten::mse_loss.out(Tensor self, Tensor target, int reduction=Mean, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & mse_loss_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, int64_t reduction, at::Tensor & out) { + return at::_ops::mse_loss_out::redispatch(dispatchKeySet, self, target, reduction, out); + } + + // aten::mse_loss(Tensor self, Tensor target, int reduction=Mean) -> Tensor + inline at::Tensor mse_loss(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, int64_t reduction=at::Reduction::Mean) { + return at::_ops::mse_loss::redispatch(dispatchKeySet, self, target, reduction); + } + + // aten::mse_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, int reduction, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & mse_loss_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, int64_t reduction) { + return at::_ops::mse_loss_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, target, reduction, grad_input); + } + + // aten::mse_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, int reduction, *, Tensor(a!) grad_input) -> Tensor(a!) 
+ inline at::Tensor & mse_loss_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, int64_t reduction, at::Tensor & grad_input) { + return at::_ops::mse_loss_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, target, reduction, grad_input); + } + + // aten::mse_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction) -> Tensor + inline at::Tensor mse_loss_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, int64_t reduction) { + return at::_ops::mse_loss_backward::redispatch(dispatchKeySet, grad_output, self, target, reduction); + } + + // aten::l1_loss(Tensor self, Tensor target, int reduction=Mean) -> Tensor + inline at::Tensor l1_loss(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, int64_t reduction=at::Reduction::Mean) { + return at::_ops::l1_loss::redispatch(dispatchKeySet, self, target, reduction); + } + + // aten::multi_margin_loss.out(Tensor self, Tensor target, Scalar p=1, Scalar margin=1, Tensor? weight=None, int reduction=Mean, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & multi_margin_loss_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & target, const at::Scalar & p=1, const at::Scalar & margin=1, const ::std::optional & weight={}, int64_t reduction=at::Reduction::Mean) { + return at::_ops::multi_margin_loss_out::redispatch(dispatchKeySet, self, target, p, margin, weight, reduction, out); + } + + // aten::multi_margin_loss.out(Tensor self, Tensor target, Scalar p=1, Scalar margin=1, Tensor? weight=None, int reduction=Mean, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & multi_margin_loss_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, const at::Scalar & p, const at::Scalar & margin, const ::std::optional & weight, int64_t reduction, at::Tensor & out) { + return at::_ops::multi_margin_loss_out::redispatch(dispatchKeySet, self, target, p, margin, weight, reduction, out); + } + + // aten::multi_margin_loss(Tensor self, Tensor target, Scalar p=1, Scalar margin=1, Tensor? weight=None, int reduction=Mean) -> Tensor + inline at::Tensor multi_margin_loss(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, const at::Scalar & p=1, const at::Scalar & margin=1, const ::std::optional & weight={}, int64_t reduction=at::Reduction::Mean) { + return at::_ops::multi_margin_loss::redispatch(dispatchKeySet, self, target, p, margin, weight, reduction); + } + + // aten::multi_margin_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, Scalar p, Scalar margin, Tensor? weight=None, int reduction=Mean, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & multi_margin_loss_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const at::Scalar & p, const at::Scalar & margin, const ::std::optional & weight={}, int64_t reduction=at::Reduction::Mean) { + return at::_ops::multi_margin_loss_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, target, p, margin, weight, reduction, grad_input); + } + + // aten::multi_margin_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, Scalar p, Scalar margin, Tensor? weight=None, int reduction=Mean, *, Tensor(a!) grad_input) -> Tensor(a!) 
+ inline at::Tensor & multi_margin_loss_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const at::Scalar & p, const at::Scalar & margin, const ::std::optional & weight, int64_t reduction, at::Tensor & grad_input) { + return at::_ops::multi_margin_loss_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, target, p, margin, weight, reduction, grad_input); + } + + // aten::multi_margin_loss_backward(Tensor grad_output, Tensor self, Tensor target, Scalar p, Scalar margin, Tensor? weight=None, int reduction=Mean) -> Tensor + inline at::Tensor multi_margin_loss_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const at::Scalar & p, const at::Scalar & margin, const ::std::optional & weight={}, int64_t reduction=at::Reduction::Mean) { + return at::_ops::multi_margin_loss_backward::redispatch(dispatchKeySet, grad_output, self, target, p, margin, weight, reduction); + } + + // aten::multilabel_margin_loss.out(Tensor self, Tensor target, int reduction=Mean, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & multilabel_margin_loss_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & target, int64_t reduction=at::Reduction::Mean) { + return at::_ops::multilabel_margin_loss_out::redispatch(dispatchKeySet, self, target, reduction, out); + } + + // aten::multilabel_margin_loss.out(Tensor self, Tensor target, int reduction=Mean, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & multilabel_margin_loss_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, int64_t reduction, at::Tensor & out) { + return at::_ops::multilabel_margin_loss_out::redispatch(dispatchKeySet, self, target, reduction, out); + } + + // aten::multilabel_margin_loss(Tensor self, Tensor target, int reduction=Mean) -> Tensor + inline at::Tensor multilabel_margin_loss(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, int64_t reduction=at::Reduction::Mean) { + return at::_ops::multilabel_margin_loss::redispatch(dispatchKeySet, self, target, reduction); + } + + // aten::multilabel_margin_loss_forward.output(Tensor self, Tensor target, int reduction, *, Tensor(a!) output, Tensor(b!) is_target) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple multilabel_margin_loss_forward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & output, at::Tensor & is_target, const at::Tensor & self, const at::Tensor & target, int64_t reduction) { + return at::_ops::multilabel_margin_loss_forward_output::redispatch(dispatchKeySet, self, target, reduction, output, is_target); + } + + // aten::multilabel_margin_loss_forward.output(Tensor self, Tensor target, int reduction, *, Tensor(a!) output, Tensor(b!) 
is_target) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple multilabel_margin_loss_forward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, int64_t reduction, at::Tensor & output, at::Tensor & is_target) { + return at::_ops::multilabel_margin_loss_forward_output::redispatch(dispatchKeySet, self, target, reduction, output, is_target); + } + + // aten::multilabel_margin_loss_forward(Tensor self, Tensor target, int reduction) -> (Tensor output, Tensor is_target) + inline ::std::tuple multilabel_margin_loss_forward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, int64_t reduction) { + return at::_ops::multilabel_margin_loss_forward::redispatch(dispatchKeySet, self, target, reduction); + } + + // aten::multilabel_margin_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, int reduction, Tensor is_target, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & multilabel_margin_loss_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, int64_t reduction, const at::Tensor & is_target) { + return at::_ops::multilabel_margin_loss_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, target, reduction, is_target, grad_input); + } + + // aten::multilabel_margin_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, int reduction, Tensor is_target, *, Tensor(a!) grad_input) -> Tensor(a!) 
+ inline at::Tensor & multilabel_margin_loss_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, int64_t reduction, const at::Tensor & is_target, at::Tensor & grad_input) { + return at::_ops::multilabel_margin_loss_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, target, reduction, is_target, grad_input); + } + + // aten::multilabel_margin_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction, Tensor is_target) -> Tensor + inline at::Tensor multilabel_margin_loss_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, int64_t reduction, const at::Tensor & is_target) { + return at::_ops::multilabel_margin_loss_backward::redispatch(dispatchKeySet, grad_output, self, target, reduction, is_target); + } + + // aten::nll_loss.out(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & nll_loss_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight={}, int64_t reduction=at::Reduction::Mean, int64_t ignore_index=-100) { + return at::_ops::nll_loss_out::redispatch(dispatchKeySet, self, target, weight, reduction, ignore_index, out); + } + + // aten::nll_loss.out(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & nll_loss_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, int64_t ignore_index, at::Tensor & out) { + return at::_ops::nll_loss_out::redispatch(dispatchKeySet, self, target, weight, reduction, ignore_index, out); + } + + // aten::nll_loss.out(Tensor self, Tensor target, Tensor? 
weight=None, int reduction=Mean, SymInt ignore_index=-100, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & nll_loss_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight={}, int64_t reduction=at::Reduction::Mean, c10::SymInt ignore_index=-100) { + return at::_ops::nll_loss_out::redispatch(dispatchKeySet, self, target, weight, reduction, ignore_index, out); + } + + // aten::nll_loss.out(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & nll_loss_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, c10::SymInt ignore_index, at::Tensor & out) { + return at::_ops::nll_loss_out::redispatch(dispatchKeySet, self, target, weight, reduction, ignore_index, out); + } + + // aten::nll_loss_nd(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100) -> Tensor + inline at::Tensor nll_loss_nd(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight={}, int64_t reduction=at::Reduction::Mean, int64_t ignore_index=-100) { + return at::_ops::nll_loss_nd::redispatch(dispatchKeySet, self, target, weight, reduction, ignore_index); + } + + // aten::nll_loss_nd(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100) -> Tensor + inline at::Tensor nll_loss_nd_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight={}, int64_t reduction=at::Reduction::Mean, c10::SymInt ignore_index=-100) { + return at::_ops::nll_loss_nd::redispatch(dispatchKeySet, self, target, weight, reduction, ignore_index); + } + + // aten::nll_loss(Tensor self, Tensor target, Tensor? 
weight=None, int reduction=Mean, SymInt ignore_index=-100) -> Tensor + inline at::Tensor nll_loss(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight={}, int64_t reduction=at::Reduction::Mean, int64_t ignore_index=-100) { + return at::_ops::nll_loss::redispatch(dispatchKeySet, self, target, weight, reduction, ignore_index); + } + + // aten::nll_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100) -> Tensor + inline at::Tensor nll_loss_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight={}, int64_t reduction=at::Reduction::Mean, c10::SymInt ignore_index=-100) { + return at::_ops::nll_loss::redispatch(dispatchKeySet, self, target, weight, reduction, ignore_index); + } + + // aten::nll_loss_forward.output(Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, *, Tensor(a!) output, Tensor(b!) total_weight) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple nll_loss_forward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & output, at::Tensor & total_weight, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, int64_t ignore_index) { + return at::_ops::nll_loss_forward_output::redispatch(dispatchKeySet, self, target, weight, reduction, ignore_index, output, total_weight); + } + + // aten::nll_loss_forward.output(Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, *, Tensor(a!) output, Tensor(b!) 
total_weight) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple nll_loss_forward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, int64_t ignore_index, at::Tensor & output, at::Tensor & total_weight) { + return at::_ops::nll_loss_forward_output::redispatch(dispatchKeySet, self, target, weight, reduction, ignore_index, output, total_weight); + } + + // aten::nll_loss_forward.output(Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, *, Tensor(a!) output, Tensor(b!) total_weight) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple nll_loss_forward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & output, at::Tensor & total_weight, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, c10::SymInt ignore_index) { + return at::_ops::nll_loss_forward_output::redispatch(dispatchKeySet, self, target, weight, reduction, ignore_index, output, total_weight); + } + + // aten::nll_loss_forward.output(Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, *, Tensor(a!) output, Tensor(b!) total_weight) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple nll_loss_forward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, c10::SymInt ignore_index, at::Tensor & output, at::Tensor & total_weight) { + return at::_ops::nll_loss_forward_output::redispatch(dispatchKeySet, self, target, weight, reduction, ignore_index, output, total_weight); + } + + // aten::nll_loss_forward(Tensor self, Tensor target, Tensor? 
weight, int reduction, SymInt ignore_index) -> (Tensor output, Tensor total_weight) + inline ::std::tuple nll_loss_forward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, int64_t ignore_index) { + return at::_ops::nll_loss_forward::redispatch(dispatchKeySet, self, target, weight, reduction, ignore_index); + } + + // aten::nll_loss_forward(Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index) -> (Tensor output, Tensor total_weight) + inline ::std::tuple nll_loss_forward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, c10::SymInt ignore_index) { + return at::_ops::nll_loss_forward::redispatch(dispatchKeySet, self, target, weight, reduction, ignore_index); + } + + // aten::nll_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, Tensor total_weight, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & nll_loss_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, int64_t ignore_index, const at::Tensor & total_weight) { + return at::_ops::nll_loss_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, target, weight, reduction, ignore_index, total_weight, grad_input); + } + + // aten::nll_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, Tensor total_weight, *, Tensor(a!) grad_input) -> Tensor(a!) 
+ inline at::Tensor & nll_loss_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, int64_t ignore_index, const at::Tensor & total_weight, at::Tensor & grad_input) { + return at::_ops::nll_loss_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, target, weight, reduction, ignore_index, total_weight, grad_input); + } + + // aten::nll_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, Tensor total_weight, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & nll_loss_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, c10::SymInt ignore_index, const at::Tensor & total_weight) { + return at::_ops::nll_loss_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, target, weight, reduction, ignore_index, total_weight, grad_input); + } + + // aten::nll_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, Tensor total_weight, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & nll_loss_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, c10::SymInt ignore_index, const at::Tensor & total_weight, at::Tensor & grad_input) { + return at::_ops::nll_loss_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, target, weight, reduction, ignore_index, total_weight, grad_input); + } + + // aten::nll_loss_backward(Tensor grad_output, Tensor self, Tensor target, Tensor? 
weight, int reduction, SymInt ignore_index, Tensor total_weight) -> Tensor + inline at::Tensor nll_loss_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, int64_t ignore_index, const at::Tensor & total_weight) { + return at::_ops::nll_loss_backward::redispatch(dispatchKeySet, grad_output, self, target, weight, reduction, ignore_index, total_weight); + } + + // aten::nll_loss_backward(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, Tensor total_weight) -> Tensor + inline at::Tensor nll_loss_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, c10::SymInt ignore_index, const at::Tensor & total_weight) { + return at::_ops::nll_loss_backward::redispatch(dispatchKeySet, grad_output, self, target, weight, reduction, ignore_index, total_weight); + } + + // aten::nll_loss2d.out(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & nll_loss2d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight={}, int64_t reduction=at::Reduction::Mean, int64_t ignore_index=-100) { + return at::_ops::nll_loss2d_out::redispatch(dispatchKeySet, self, target, weight, reduction, ignore_index, out); + } + + // aten::nll_loss2d.out(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & nll_loss2d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, int64_t ignore_index, at::Tensor & out) { + return at::_ops::nll_loss2d_out::redispatch(dispatchKeySet, self, target, weight, reduction, ignore_index, out); + } + + // aten::nll_loss2d.out(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & nll_loss2d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight={}, int64_t reduction=at::Reduction::Mean, c10::SymInt ignore_index=-100) { + return at::_ops::nll_loss2d_out::redispatch(dispatchKeySet, self, target, weight, reduction, ignore_index, out); + } + + // aten::nll_loss2d.out(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & nll_loss2d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, c10::SymInt ignore_index, at::Tensor & out) { + return at::_ops::nll_loss2d_out::redispatch(dispatchKeySet, self, target, weight, reduction, ignore_index, out); + } + + // aten::nll_loss2d(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100) -> Tensor + inline at::Tensor nll_loss2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight={}, int64_t reduction=at::Reduction::Mean, int64_t ignore_index=-100) { + return at::_ops::nll_loss2d::redispatch(dispatchKeySet, self, target, weight, reduction, ignore_index); + } + + // aten::nll_loss2d(Tensor self, Tensor target, Tensor? 
weight=None, int reduction=Mean, SymInt ignore_index=-100) -> Tensor + inline at::Tensor nll_loss2d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight={}, int64_t reduction=at::Reduction::Mean, c10::SymInt ignore_index=-100) { + return at::_ops::nll_loss2d::redispatch(dispatchKeySet, self, target, weight, reduction, ignore_index); + } + + // aten::nll_loss2d_forward.output(Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, *, Tensor(a!) output, Tensor(b!) total_weight) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple nll_loss2d_forward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & output, at::Tensor & total_weight, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, int64_t ignore_index) { + return at::_ops::nll_loss2d_forward_output::redispatch(dispatchKeySet, self, target, weight, reduction, ignore_index, output, total_weight); + } + + // aten::nll_loss2d_forward.output(Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, *, Tensor(a!) output, Tensor(b!) total_weight) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple nll_loss2d_forward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, int64_t ignore_index, at::Tensor & output, at::Tensor & total_weight) { + return at::_ops::nll_loss2d_forward_output::redispatch(dispatchKeySet, self, target, weight, reduction, ignore_index, output, total_weight); + } + + // aten::nll_loss2d_forward.output(Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, *, Tensor(a!) output, Tensor(b!) 
total_weight) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple nll_loss2d_forward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & output, at::Tensor & total_weight, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, c10::SymInt ignore_index) { + return at::_ops::nll_loss2d_forward_output::redispatch(dispatchKeySet, self, target, weight, reduction, ignore_index, output, total_weight); + } + + // aten::nll_loss2d_forward.output(Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, *, Tensor(a!) output, Tensor(b!) total_weight) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple nll_loss2d_forward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, c10::SymInt ignore_index, at::Tensor & output, at::Tensor & total_weight) { + return at::_ops::nll_loss2d_forward_output::redispatch(dispatchKeySet, self, target, weight, reduction, ignore_index, output, total_weight); + } + + // aten::nll_loss2d_forward(Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index) -> (Tensor output, Tensor total_weight) + inline ::std::tuple nll_loss2d_forward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, int64_t ignore_index) { + return at::_ops::nll_loss2d_forward::redispatch(dispatchKeySet, self, target, weight, reduction, ignore_index); + } + + // aten::nll_loss2d_forward(Tensor self, Tensor target, Tensor? 
weight, int reduction, SymInt ignore_index) -> (Tensor output, Tensor total_weight) + inline ::std::tuple nll_loss2d_forward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, c10::SymInt ignore_index) { + return at::_ops::nll_loss2d_forward::redispatch(dispatchKeySet, self, target, weight, reduction, ignore_index); + } + + // aten::nll_loss2d_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, Tensor total_weight, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & nll_loss2d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, int64_t ignore_index, const at::Tensor & total_weight) { + return at::_ops::nll_loss2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, target, weight, reduction, ignore_index, total_weight, grad_input); + } + + // aten::nll_loss2d_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, Tensor total_weight, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & nll_loss2d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, int64_t ignore_index, const at::Tensor & total_weight, at::Tensor & grad_input) { + return at::_ops::nll_loss2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, target, weight, reduction, ignore_index, total_weight, grad_input); + } + + // aten::nll_loss2d_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, Tensor total_weight, *, Tensor(a!) grad_input) -> Tensor(a!) 
+ inline at::Tensor & nll_loss2d_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, c10::SymInt ignore_index, const at::Tensor & total_weight) { + return at::_ops::nll_loss2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, target, weight, reduction, ignore_index, total_weight, grad_input); + } + + // aten::nll_loss2d_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, Tensor total_weight, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & nll_loss2d_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, c10::SymInt ignore_index, const at::Tensor & total_weight, at::Tensor & grad_input) { + return at::_ops::nll_loss2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, target, weight, reduction, ignore_index, total_weight, grad_input); + } + + // aten::nll_loss2d_backward(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, Tensor total_weight) -> Tensor + inline at::Tensor nll_loss2d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, int64_t ignore_index, const at::Tensor & total_weight) { + return at::_ops::nll_loss2d_backward::redispatch(dispatchKeySet, grad_output, self, target, weight, reduction, ignore_index, total_weight); + } + + // aten::nll_loss2d_backward(Tensor grad_output, Tensor self, Tensor target, Tensor? 
weight, int reduction, SymInt ignore_index, Tensor total_weight) -> Tensor + inline at::Tensor nll_loss2d_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, c10::SymInt ignore_index, const at::Tensor & total_weight) { + return at::_ops::nll_loss2d_backward::redispatch(dispatchKeySet, grad_output, self, target, weight, reduction, ignore_index, total_weight); + } + + // aten::smooth_l1_loss.out(Tensor self, Tensor target, int reduction=Mean, float beta=1.0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & smooth_l1_loss_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & target, int64_t reduction=at::Reduction::Mean, double beta=1.0) { + return at::_ops::smooth_l1_loss_out::redispatch(dispatchKeySet, self, target, reduction, beta, out); + } + + // aten::smooth_l1_loss.out(Tensor self, Tensor target, int reduction=Mean, float beta=1.0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & smooth_l1_loss_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, int64_t reduction, double beta, at::Tensor & out) { + return at::_ops::smooth_l1_loss_out::redispatch(dispatchKeySet, self, target, reduction, beta, out); + } + + // aten::smooth_l1_loss(Tensor self, Tensor target, int reduction=Mean, float beta=1.0) -> Tensor + inline at::Tensor smooth_l1_loss(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, int64_t reduction=at::Reduction::Mean, double beta=1.0) { + return at::_ops::smooth_l1_loss::redispatch(dispatchKeySet, self, target, reduction, beta); + } + + // aten::smooth_l1_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, int reduction, float beta, *, Tensor(a!) grad_input) -> Tensor(a!) 
+ inline at::Tensor & smooth_l1_loss_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, int64_t reduction, double beta) { + return at::_ops::smooth_l1_loss_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, target, reduction, beta, grad_input); + } + + // aten::smooth_l1_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, int reduction, float beta, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & smooth_l1_loss_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, int64_t reduction, double beta, at::Tensor & grad_input) { + return at::_ops::smooth_l1_loss_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, target, reduction, beta, grad_input); + } + + // aten::smooth_l1_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction, float beta) -> Tensor + inline at::Tensor smooth_l1_loss_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, int64_t reduction, double beta) { + return at::_ops::smooth_l1_loss_backward::redispatch(dispatchKeySet, grad_output, self, target, reduction, beta); + } + + // aten::huber_loss.out(Tensor self, Tensor target, int reduction=Mean, float delta=1.0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & huber_loss_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & target, int64_t reduction=at::Reduction::Mean, double delta=1.0) { + return at::_ops::huber_loss_out::redispatch(dispatchKeySet, self, target, reduction, delta, out); + } + + // aten::huber_loss.out(Tensor self, Tensor target, int reduction=Mean, float delta=1.0, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & huber_loss_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, int64_t reduction, double delta, at::Tensor & out) { + return at::_ops::huber_loss_out::redispatch(dispatchKeySet, self, target, reduction, delta, out); + } + + // aten::huber_loss(Tensor self, Tensor target, int reduction=Mean, float delta=1.0) -> Tensor + inline at::Tensor huber_loss(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, int64_t reduction=at::Reduction::Mean, double delta=1.0) { + return at::_ops::huber_loss::redispatch(dispatchKeySet, self, target, reduction, delta); + } + + // aten::huber_loss_backward.out(Tensor grad_output, Tensor self, Tensor target, int reduction, float delta, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & huber_loss_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, int64_t reduction, double delta) { + return at::_ops::huber_loss_backward_out::redispatch(dispatchKeySet, grad_output, self, target, reduction, delta, grad_input); + } + + // aten::huber_loss_backward.out(Tensor grad_output, Tensor self, Tensor target, int reduction, float delta, *, Tensor(a!) grad_input) -> Tensor(a!) 
+ inline at::Tensor & huber_loss_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, int64_t reduction, double delta, at::Tensor & grad_input) { + return at::_ops::huber_loss_backward_out::redispatch(dispatchKeySet, grad_output, self, target, reduction, delta, grad_input); + } + + // aten::huber_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction, float delta) -> Tensor + inline at::Tensor huber_loss_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, int64_t reduction, double delta) { + return at::_ops::huber_loss_backward::redispatch(dispatchKeySet, grad_output, self, target, reduction, delta); + } + + // aten::soft_margin_loss.out(Tensor self, Tensor target, int reduction=Mean, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & soft_margin_loss_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & target, int64_t reduction=at::Reduction::Mean) { + return at::_ops::soft_margin_loss_out::redispatch(dispatchKeySet, self, target, reduction, out); + } + + // aten::soft_margin_loss.out(Tensor self, Tensor target, int reduction=Mean, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & soft_margin_loss_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, int64_t reduction, at::Tensor & out) { + return at::_ops::soft_margin_loss_out::redispatch(dispatchKeySet, self, target, reduction, out); + } + + // aten::soft_margin_loss(Tensor self, Tensor target, int reduction=Mean) -> Tensor + inline at::Tensor soft_margin_loss(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, int64_t reduction=at::Reduction::Mean) { + return at::_ops::soft_margin_loss::redispatch(dispatchKeySet, self, target, reduction); + } + + // aten::soft_margin_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, int reduction, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & soft_margin_loss_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, int64_t reduction) { + return at::_ops::soft_margin_loss_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, target, reduction, grad_input); + } + + // aten::soft_margin_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, int reduction, *, Tensor(a!) grad_input) -> Tensor(a!) 
+ inline at::Tensor & soft_margin_loss_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, int64_t reduction, at::Tensor & grad_input) { + return at::_ops::soft_margin_loss_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, target, reduction, grad_input); + } + + // aten::soft_margin_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction) -> Tensor + inline at::Tensor soft_margin_loss_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, int64_t reduction) { + return at::_ops::soft_margin_loss_backward::redispatch(dispatchKeySet, grad_output, self, target, reduction); + } + + // aten::elu.out(Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & elu_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & alpha=1, const at::Scalar & scale=1, const at::Scalar & input_scale=1) { + return at::_ops::elu_out::redispatch(dispatchKeySet, self, alpha, scale, input_scale, out); + } + + // aten::elu.out(Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & elu_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & alpha, const at::Scalar & scale, const at::Scalar & input_scale, at::Tensor & out) { + return at::_ops::elu_out::redispatch(dispatchKeySet, self, alpha, scale, input_scale, out); + } + + // aten::elu(Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> Tensor + inline at::Tensor elu(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & alpha=1, const at::Scalar & scale=1, const at::Scalar & input_scale=1) { + return at::_ops::elu::redispatch(dispatchKeySet, self, alpha, scale, input_scale); + } + + // aten::elu_backward.grad_input(Tensor grad_output, Scalar alpha, Scalar scale, Scalar input_scale, bool is_result, Tensor self_or_result, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & elu_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Scalar & alpha, const at::Scalar & scale, const at::Scalar & input_scale, bool is_result, const at::Tensor & self_or_result) { + return at::_ops::elu_backward_grad_input::redispatch(dispatchKeySet, grad_output, alpha, scale, input_scale, is_result, self_or_result, grad_input); + } + + // aten::elu_backward.grad_input(Tensor grad_output, Scalar alpha, Scalar scale, Scalar input_scale, bool is_result, Tensor self_or_result, *, Tensor(a!) grad_input) -> Tensor(a!) 
+ inline at::Tensor & elu_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Scalar & alpha, const at::Scalar & scale, const at::Scalar & input_scale, bool is_result, const at::Tensor & self_or_result, at::Tensor & grad_input) { + return at::_ops::elu_backward_grad_input::redispatch(dispatchKeySet, grad_output, alpha, scale, input_scale, is_result, self_or_result, grad_input); + } + + // aten::elu_backward(Tensor grad_output, Scalar alpha, Scalar scale, Scalar input_scale, bool is_result, Tensor self_or_result) -> Tensor + inline at::Tensor elu_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Scalar & alpha, const at::Scalar & scale, const at::Scalar & input_scale, bool is_result, const at::Tensor & self_or_result) { + return at::_ops::elu_backward::redispatch(dispatchKeySet, grad_output, alpha, scale, input_scale, is_result, self_or_result); + } + + // aten::elu_(Tensor(a!) self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> Tensor(a!) + inline at::Tensor & elu_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & alpha=1, const at::Scalar & scale=1, const at::Scalar & input_scale=1) { + return at::_ops::elu_::redispatch(dispatchKeySet, self, alpha, scale, input_scale); + } + + // aten::glu.out(Tensor self, int dim=-1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & glu_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim=-1) { + return at::_ops::glu_out::redispatch(dispatchKeySet, self, dim, out); + } + + // aten::glu.out(Tensor self, int dim=-1, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & glu_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, at::Tensor & out) { + return at::_ops::glu_out::redispatch(dispatchKeySet, self, dim, out); + } + + // aten::glu(Tensor self, int dim=-1) -> Tensor + inline at::Tensor glu(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim=-1) { + return at::_ops::glu::redispatch(dispatchKeySet, self, dim); + } + + // aten::glu_backward.grad_input(Tensor grad_output, Tensor self, int dim, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & glu_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, int64_t dim) { + return at::_ops::glu_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, dim, grad_input); + } + + // aten::glu_backward.grad_input(Tensor grad_output, Tensor self, int dim, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & glu_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, int64_t dim, at::Tensor & grad_input) { + return at::_ops::glu_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, dim, grad_input); + } + + // aten::glu_backward(Tensor grad_output, Tensor self, int dim) -> Tensor + inline at::Tensor glu_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, int64_t dim) { + return at::_ops::glu_backward::redispatch(dispatchKeySet, grad_output, self, dim); + } + + // aten::glu_jvp(Tensor glu, Tensor x, Tensor dx, int dim) -> Tensor + inline at::Tensor glu_jvp(c10::DispatchKeySet dispatchKeySet, const at::Tensor & glu, const at::Tensor & x, const at::Tensor & dx, int64_t dim) { + return at::_ops::glu_jvp::redispatch(dispatchKeySet, glu, x, dx, dim); + } + + // aten::glu_backward_jvp(Tensor grad_x, Tensor grad_glu, Tensor x, Tensor dgrad_glu, Tensor dx, int dim) -> Tensor + inline at::Tensor 
glu_backward_jvp(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_x, const at::Tensor & grad_glu, const at::Tensor & x, const at::Tensor & dgrad_glu, const at::Tensor & dx, int64_t dim) { + return at::_ops::glu_backward_jvp::redispatch(dispatchKeySet, grad_x, grad_glu, x, dgrad_glu, dx, dim); + } + + // aten::hardsigmoid.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & hardsigmoid_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::hardsigmoid_out::redispatch(dispatchKeySet, self, out); + } + + // aten::hardsigmoid.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & hardsigmoid_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::hardsigmoid_out::redispatch(dispatchKeySet, self, out); + } + + // aten::hardsigmoid(Tensor self) -> Tensor + inline at::Tensor hardsigmoid(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::hardsigmoid::redispatch(dispatchKeySet, self); + } + + // aten::hardsigmoid_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & hardsigmoid_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::hardsigmoid_::redispatch(dispatchKeySet, self); + } + + // aten::hardsigmoid_backward.grad_input(Tensor grad_output, Tensor self, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & hardsigmoid_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self) { + return at::_ops::hardsigmoid_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, grad_input); + } + + // aten::hardsigmoid_backward.grad_input(Tensor grad_output, Tensor self, *, Tensor(a!) grad_input) -> Tensor(a!) 
+ inline at::Tensor & hardsigmoid_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::Tensor & grad_input) { + return at::_ops::hardsigmoid_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, grad_input); + } + + // aten::hardsigmoid_backward(Tensor grad_output, Tensor self) -> Tensor + inline at::Tensor hardsigmoid_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self) { + return at::_ops::hardsigmoid_backward::redispatch(dispatchKeySet, grad_output, self); + } + + // aten::hardtanh.out(Tensor self, Scalar min_val=-1, Scalar max_val=1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & hardtanh_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & min_val=-1, const at::Scalar & max_val=1) { + return at::_ops::hardtanh_out::redispatch(dispatchKeySet, self, min_val, max_val, out); + } + + // aten::hardtanh.out(Tensor self, Scalar min_val=-1, Scalar max_val=1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & hardtanh_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & min_val, const at::Scalar & max_val, at::Tensor & out) { + return at::_ops::hardtanh_out::redispatch(dispatchKeySet, self, min_val, max_val, out); + } + + // aten::hardtanh(Tensor self, Scalar min_val=-1, Scalar max_val=1) -> Tensor + inline at::Tensor hardtanh(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & min_val=-1, const at::Scalar & max_val=1) { + return at::_ops::hardtanh::redispatch(dispatchKeySet, self, min_val, max_val); + } + + // aten::hardtanh_backward.grad_input(Tensor grad_output, Tensor self, Scalar min_val, Scalar max_val, *, Tensor(a!) grad_input) -> Tensor(a!) 
+ inline at::Tensor & hardtanh_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, const at::Scalar & min_val, const at::Scalar & max_val) { + return at::_ops::hardtanh_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, min_val, max_val, grad_input); + } + + // aten::hardtanh_backward.grad_input(Tensor grad_output, Tensor self, Scalar min_val, Scalar max_val, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & hardtanh_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Scalar & min_val, const at::Scalar & max_val, at::Tensor & grad_input) { + return at::_ops::hardtanh_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, min_val, max_val, grad_input); + } + + // aten::hardtanh_backward(Tensor grad_output, Tensor self, Scalar min_val, Scalar max_val) -> Tensor + inline at::Tensor hardtanh_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Scalar & min_val, const at::Scalar & max_val) { + return at::_ops::hardtanh_backward::redispatch(dispatchKeySet, grad_output, self, min_val, max_val); + } + + // aten::hardtanh_(Tensor(a!) self, Scalar min_val=-1, Scalar max_val=1) -> Tensor(a!) + inline at::Tensor & hardtanh_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & min_val=-1, const at::Scalar & max_val=1) { + return at::_ops::hardtanh_::redispatch(dispatchKeySet, self, min_val, max_val); + } + + // aten::hardswish.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & hardswish_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::hardswish_out::redispatch(dispatchKeySet, self, out); + } + + // aten::hardswish.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & hardswish_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::hardswish_out::redispatch(dispatchKeySet, self, out); + } + + // aten::hardswish(Tensor self) -> Tensor + inline at::Tensor hardswish(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::hardswish::redispatch(dispatchKeySet, self); + } + + // aten::hardswish_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & hardswish_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::hardswish_::redispatch(dispatchKeySet, self); + } + + // aten::hardswish_backward(Tensor grad_output, Tensor self) -> Tensor + inline at::Tensor hardswish_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self) { + return at::_ops::hardswish_backward::redispatch(dispatchKeySet, grad_output, self); + } + + // aten::leaky_relu.out(Tensor self, Scalar negative_slope=0.01, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & leaky_relu_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & negative_slope=0.01) { + return at::_ops::leaky_relu_out::redispatch(dispatchKeySet, self, negative_slope, out); + } + + // aten::leaky_relu.out(Tensor self, Scalar negative_slope=0.01, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & leaky_relu_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & negative_slope, at::Tensor & out) { + return at::_ops::leaky_relu_out::redispatch(dispatchKeySet, self, negative_slope, out); + } + + // aten::leaky_relu(Tensor self, Scalar negative_slope=0.01) -> Tensor + inline at::Tensor leaky_relu(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & negative_slope=0.01) { + return at::_ops::leaky_relu::redispatch(dispatchKeySet, self, negative_slope); + } + + // aten::leaky_relu_backward.grad_input(Tensor grad_output, Tensor self, Scalar negative_slope, bool self_is_result, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & leaky_relu_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, const at::Scalar & negative_slope, bool self_is_result) { + return at::_ops::leaky_relu_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, negative_slope, self_is_result, grad_input); + } + + // aten::leaky_relu_backward.grad_input(Tensor grad_output, Tensor self, Scalar negative_slope, bool self_is_result, *, Tensor(a!) grad_input) -> Tensor(a!) 
+ inline at::Tensor & leaky_relu_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Scalar & negative_slope, bool self_is_result, at::Tensor & grad_input) { + return at::_ops::leaky_relu_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, negative_slope, self_is_result, grad_input); + } + + // aten::leaky_relu_backward(Tensor grad_output, Tensor self, Scalar negative_slope, bool self_is_result) -> Tensor + inline at::Tensor leaky_relu_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Scalar & negative_slope, bool self_is_result) { + return at::_ops::leaky_relu_backward::redispatch(dispatchKeySet, grad_output, self, negative_slope, self_is_result); + } + + // aten::leaky_relu_(Tensor(a!) self, Scalar negative_slope=0.01) -> Tensor(a!) + inline at::Tensor & leaky_relu_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & negative_slope=0.01) { + return at::_ops::leaky_relu_::redispatch(dispatchKeySet, self, negative_slope); + } + + // aten::log_sigmoid.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & log_sigmoid_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::log_sigmoid_out::redispatch(dispatchKeySet, self, out); + } + + // aten::log_sigmoid.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & log_sigmoid_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::log_sigmoid_out::redispatch(dispatchKeySet, self, out); + } + + // aten::log_sigmoid(Tensor self) -> Tensor + inline at::Tensor log_sigmoid(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::log_sigmoid::redispatch(dispatchKeySet, self); + } + + // aten::log_sigmoid_forward.output(Tensor self, *, Tensor(a!) output, Tensor(b!) 
buffer) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple log_sigmoid_forward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & output, at::Tensor & buffer, const at::Tensor & self) { + return at::_ops::log_sigmoid_forward_output::redispatch(dispatchKeySet, self, output, buffer); + } + + // aten::log_sigmoid_forward.output(Tensor self, *, Tensor(a!) output, Tensor(b!) buffer) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple log_sigmoid_forward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & output, at::Tensor & buffer) { + return at::_ops::log_sigmoid_forward_output::redispatch(dispatchKeySet, self, output, buffer); + } + + // aten::log_sigmoid_forward(Tensor self) -> (Tensor output, Tensor buffer) + inline ::std::tuple log_sigmoid_forward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::log_sigmoid_forward::redispatch(dispatchKeySet, self); + } + + // aten::log_sigmoid_backward.grad_input(Tensor grad_output, Tensor self, Tensor buffer, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & log_sigmoid_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & buffer) { + return at::_ops::log_sigmoid_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, buffer, grad_input); + } + + // aten::log_sigmoid_backward.grad_input(Tensor grad_output, Tensor self, Tensor buffer, *, Tensor(a!) grad_input) -> Tensor(a!) 
+ inline at::Tensor & log_sigmoid_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & buffer, at::Tensor & grad_input) { + return at::_ops::log_sigmoid_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, buffer, grad_input); + } + + // aten::log_sigmoid_backward(Tensor grad_output, Tensor self, Tensor buffer) -> Tensor + inline at::Tensor log_sigmoid_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & buffer) { + return at::_ops::log_sigmoid_backward::redispatch(dispatchKeySet, grad_output, self, buffer); + } + + // aten::rrelu_with_noise.out(Tensor self, Tensor(b!) noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & rrelu_with_noise_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::Tensor & noise, const at::Scalar & lower=0.125, const at::Scalar & upper=0.3333333333333333, bool training=false, ::std::optional generator=::std::nullopt) { + return at::_ops::rrelu_with_noise_out::redispatch(dispatchKeySet, self, noise, lower, upper, training, generator, out); + } + + // aten::rrelu_with_noise.out(Tensor self, Tensor(b!) noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & rrelu_with_noise_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & noise, const at::Scalar & lower, const at::Scalar & upper, bool training, ::std::optional generator, at::Tensor & out) { + return at::_ops::rrelu_with_noise_out::redispatch(dispatchKeySet, self, noise, lower, upper, training, generator, out); + } + + // aten::rrelu_with_noise(Tensor self, Tensor(b!) noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? 
generator=None) -> Tensor + inline at::Tensor rrelu_with_noise(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & noise, const at::Scalar & lower=0.125, const at::Scalar & upper=0.3333333333333333, bool training=false, ::std::optional generator=::std::nullopt) { + return at::_ops::rrelu_with_noise::redispatch(dispatchKeySet, self, noise, lower, upper, training, generator); + } + + // aten::rrelu_with_noise_backward(Tensor grad_output, Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, bool self_is_result) -> Tensor + inline at::Tensor rrelu_with_noise_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & noise, const at::Scalar & lower, const at::Scalar & upper, bool training, bool self_is_result) { + return at::_ops::rrelu_with_noise_backward::redispatch(dispatchKeySet, grad_output, self, noise, lower, upper, training, self_is_result); + } + + // aten::rrelu_with_noise_(Tensor(a!) self, Tensor(b!) noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor(a!) + inline at::Tensor & rrelu_with_noise_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, at::Tensor & noise, const at::Scalar & lower=0.125, const at::Scalar & upper=0.3333333333333333, bool training=false, ::std::optional generator=::std::nullopt) { + return at::_ops::rrelu_with_noise_::redispatch(dispatchKeySet, self, noise, lower, upper, training, generator); + } + + // aten::softplus.out(Tensor self, Scalar beta=1, Scalar threshold=20, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & softplus_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & beta=1, const at::Scalar & threshold=20) { + return at::_ops::softplus_out::redispatch(dispatchKeySet, self, beta, threshold, out); + } + + // aten::softplus.out(Tensor self, Scalar beta=1, Scalar threshold=20, *, Tensor(a!) 
out) -> Tensor(a!) + inline at::Tensor & softplus_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & beta, const at::Scalar & threshold, at::Tensor & out) { + return at::_ops::softplus_out::redispatch(dispatchKeySet, self, beta, threshold, out); + } + + // aten::softplus(Tensor self, Scalar beta=1, Scalar threshold=20) -> Tensor + inline at::Tensor softplus(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & beta=1, const at::Scalar & threshold=20) { + return at::_ops::softplus::redispatch(dispatchKeySet, self, beta, threshold); + } + + // aten::softplus_backward.grad_input(Tensor grad_output, Tensor self, Scalar beta, Scalar threshold, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & softplus_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, const at::Scalar & beta, const at::Scalar & threshold) { + return at::_ops::softplus_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, beta, threshold, grad_input); + } + + // aten::softplus_backward.grad_input(Tensor grad_output, Tensor self, Scalar beta, Scalar threshold, *, Tensor(a!) grad_input) -> Tensor(a!) 
+ inline at::Tensor & softplus_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Scalar & beta, const at::Scalar & threshold, at::Tensor & grad_input) { + return at::_ops::softplus_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, beta, threshold, grad_input); + } + + // aten::softplus_backward(Tensor grad_output, Tensor self, Scalar beta, Scalar threshold) -> Tensor + inline at::Tensor softplus_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Scalar & beta, const at::Scalar & threshold) { + return at::_ops::softplus_backward::redispatch(dispatchKeySet, grad_output, self, beta, threshold); + } + + // aten::softshrink.out(Tensor self, Scalar lambd=0.5, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & softshrink_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & lambd=0.5) { + return at::_ops::softshrink_out::redispatch(dispatchKeySet, self, lambd, out); + } + + // aten::softshrink.out(Tensor self, Scalar lambd=0.5, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & softshrink_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & lambd, at::Tensor & out) { + return at::_ops::softshrink_out::redispatch(dispatchKeySet, self, lambd, out); + } + + // aten::softshrink(Tensor self, Scalar lambd=0.5) -> Tensor + inline at::Tensor softshrink(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & lambd=0.5) { + return at::_ops::softshrink::redispatch(dispatchKeySet, self, lambd); + } + + // aten::softshrink_backward.grad_input(Tensor grad_output, Tensor self, Scalar lambd, *, Tensor(a!) grad_input) -> Tensor(a!) 
+ inline at::Tensor & softshrink_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, const at::Scalar & lambd) { + return at::_ops::softshrink_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, lambd, grad_input); + } + + // aten::softshrink_backward.grad_input(Tensor grad_output, Tensor self, Scalar lambd, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & softshrink_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Scalar & lambd, at::Tensor & grad_input) { + return at::_ops::softshrink_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, lambd, grad_input); + } + + // aten::softshrink_backward(Tensor grad_output, Tensor self, Scalar lambd) -> Tensor + inline at::Tensor softshrink_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Scalar & lambd) { + return at::_ops::softshrink_backward::redispatch(dispatchKeySet, grad_output, self, lambd); + } + + // aten::adaptive_avg_pool2d.out(Tensor self, SymInt[2] output_size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & adaptive_avg_pool2d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef output_size) { + return at::_ops::adaptive_avg_pool2d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), out); + } + + // aten::adaptive_avg_pool2d.out(Tensor self, SymInt[2] output_size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & adaptive_avg_pool2d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out) { + return at::_ops::adaptive_avg_pool2d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), out); + } + + // aten::adaptive_avg_pool2d.out(Tensor self, SymInt[2] output_size, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & adaptive_avg_pool2d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef output_size) { + return at::_ops::adaptive_avg_pool2d_out::redispatch(dispatchKeySet, self, output_size, out); + } + + // aten::adaptive_avg_pool2d.out(Tensor self, SymInt[2] output_size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & adaptive_avg_pool2d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, at::Tensor & out) { + return at::_ops::adaptive_avg_pool2d_out::redispatch(dispatchKeySet, self, output_size, out); + } + + // aten::adaptive_avg_pool2d(Tensor self, SymInt[2] output_size) -> Tensor + inline at::Tensor adaptive_avg_pool2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size) { + return at::_ops::adaptive_avg_pool2d::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size)); + } + + // aten::adaptive_avg_pool2d(Tensor self, SymInt[2] output_size) -> Tensor + inline at::Tensor adaptive_avg_pool2d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size) { + return at::_ops::adaptive_avg_pool2d::redispatch(dispatchKeySet, self, output_size); + } + + // aten::mkldnn_adaptive_avg_pool2d(Tensor self, int[2] output_size) -> Tensor + inline at::Tensor mkldnn_adaptive_avg_pool2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size) { + return at::_ops::mkldnn_adaptive_avg_pool2d::redispatch(dispatchKeySet, self, output_size); + } + + // aten::mkldnn_adaptive_avg_pool2d.out(Tensor self, int[2] output_size, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & mkldnn_adaptive_avg_pool2d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef output_size) { + return at::_ops::mkldnn_adaptive_avg_pool2d_out::redispatch(dispatchKeySet, self, output_size, out); + } + + // aten::mkldnn_adaptive_avg_pool2d.out(Tensor self, int[2] output_size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & mkldnn_adaptive_avg_pool2d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out) { + return at::_ops::mkldnn_adaptive_avg_pool2d_out::redispatch(dispatchKeySet, self, output_size, out); + } + + // aten::mkldnn_adaptive_avg_pool2d_backward(Tensor grad_output, Tensor self) -> Tensor + inline at::Tensor mkldnn_adaptive_avg_pool2d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self) { + return at::_ops::mkldnn_adaptive_avg_pool2d_backward::redispatch(dispatchKeySet, grad_output, self); + } + + // aten::_adaptive_avg_pool2d(Tensor self, SymInt[2] output_size) -> Tensor + inline at::Tensor _adaptive_avg_pool2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size) { + return at::_ops::_adaptive_avg_pool2d::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size)); + } + + // aten::_adaptive_avg_pool2d(Tensor self, SymInt[2] output_size) -> Tensor + inline at::Tensor _adaptive_avg_pool2d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size) { + return at::_ops::_adaptive_avg_pool2d::redispatch(dispatchKeySet, self, output_size); + } + + // aten::_adaptive_avg_pool2d_backward(Tensor grad_output, Tensor self) -> Tensor + inline at::Tensor _adaptive_avg_pool2d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self) { + return at::_ops::_adaptive_avg_pool2d_backward::redispatch(dispatchKeySet, grad_output, self); + } + + // 
aten::adaptive_avg_pool3d.out(Tensor self, SymInt[3] output_size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & adaptive_avg_pool3d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef output_size) { + return at::_ops::adaptive_avg_pool3d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), out); + } + + // aten::adaptive_avg_pool3d.out(Tensor self, SymInt[3] output_size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & adaptive_avg_pool3d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out) { + return at::_ops::adaptive_avg_pool3d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), out); + } + + // aten::adaptive_avg_pool3d.out(Tensor self, SymInt[3] output_size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & adaptive_avg_pool3d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef output_size) { + return at::_ops::adaptive_avg_pool3d_out::redispatch(dispatchKeySet, self, output_size, out); + } + + // aten::adaptive_avg_pool3d.out(Tensor self, SymInt[3] output_size, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & adaptive_avg_pool3d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, at::Tensor & out) { + return at::_ops::adaptive_avg_pool3d_out::redispatch(dispatchKeySet, self, output_size, out); + } + + // aten::adaptive_avg_pool3d(Tensor self, SymInt[3] output_size) -> Tensor + inline at::Tensor adaptive_avg_pool3d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size) { + return at::_ops::adaptive_avg_pool3d::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size)); + } + + // aten::adaptive_avg_pool3d(Tensor self, SymInt[3] output_size) -> Tensor + inline at::Tensor adaptive_avg_pool3d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size) { + return at::_ops::adaptive_avg_pool3d::redispatch(dispatchKeySet, self, output_size); + } + + // aten::_adaptive_avg_pool3d(Tensor self, SymInt[3] output_size) -> Tensor + inline at::Tensor _adaptive_avg_pool3d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size) { + return at::_ops::_adaptive_avg_pool3d::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size)); + } + + // aten::_adaptive_avg_pool3d(Tensor self, SymInt[3] output_size) -> Tensor + inline at::Tensor _adaptive_avg_pool3d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size) { + return at::_ops::_adaptive_avg_pool3d::redispatch(dispatchKeySet, self, output_size); + } + + // aten::adaptive_avg_pool3d_backward.grad_input(Tensor grad_output, Tensor self, *, Tensor(a!) grad_input) -> Tensor(a!) 
+ inline at::Tensor & adaptive_avg_pool3d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self) { + return at::_ops::adaptive_avg_pool3d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, grad_input); + } + + // aten::adaptive_avg_pool3d_backward.grad_input(Tensor grad_output, Tensor self, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & adaptive_avg_pool3d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::Tensor & grad_input) { + return at::_ops::adaptive_avg_pool3d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, grad_input); + } + + // aten::_adaptive_avg_pool3d_backward(Tensor grad_output, Tensor self) -> Tensor + inline at::Tensor _adaptive_avg_pool3d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self) { + return at::_ops::_adaptive_avg_pool3d_backward::redispatch(dispatchKeySet, grad_output, self); + } + + // aten::adaptive_max_pool2d.out(Tensor self, int[2] output_size, *, Tensor(a!) out, Tensor(b!) indices) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple adaptive_max_pool2d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::Tensor & indices, const at::Tensor & self, at::IntArrayRef output_size) { + return at::_ops::adaptive_max_pool2d_out::redispatch(dispatchKeySet, self, output_size, out, indices); + } + + // aten::adaptive_max_pool2d.out(Tensor self, int[2] output_size, *, Tensor(a!) out, Tensor(b!) 
indices) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple adaptive_max_pool2d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out, at::Tensor & indices) { + return at::_ops::adaptive_max_pool2d_out::redispatch(dispatchKeySet, self, output_size, out, indices); + } + + // aten::adaptive_max_pool2d(Tensor self, int[2] output_size) -> (Tensor, Tensor) + inline ::std::tuple adaptive_max_pool2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size) { + return at::_ops::adaptive_max_pool2d::redispatch(dispatchKeySet, self, output_size); + } + + // aten::adaptive_max_pool2d_backward.grad_input(Tensor grad_output, Tensor self, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & adaptive_max_pool2d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices) { + return at::_ops::adaptive_max_pool2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, indices, grad_input); + } + + // aten::adaptive_max_pool2d_backward.grad_input(Tensor grad_output, Tensor self, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!) 
+ inline at::Tensor & adaptive_max_pool2d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices, at::Tensor & grad_input) { + return at::_ops::adaptive_max_pool2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, indices, grad_input); + } + + // aten::adaptive_max_pool2d_backward(Tensor grad_output, Tensor self, Tensor indices) -> Tensor + inline at::Tensor adaptive_max_pool2d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices) { + return at::_ops::adaptive_max_pool2d_backward::redispatch(dispatchKeySet, grad_output, self, indices); + } + + // aten::adaptive_max_pool3d.out(Tensor self, int[3] output_size, *, Tensor(a!) out, Tensor(b!) indices) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple adaptive_max_pool3d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::Tensor & indices, const at::Tensor & self, at::IntArrayRef output_size) { + return at::_ops::adaptive_max_pool3d_out::redispatch(dispatchKeySet, self, output_size, out, indices); + } + + // aten::adaptive_max_pool3d.out(Tensor self, int[3] output_size, *, Tensor(a!) out, Tensor(b!) indices) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple adaptive_max_pool3d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out, at::Tensor & indices) { + return at::_ops::adaptive_max_pool3d_out::redispatch(dispatchKeySet, self, output_size, out, indices); + } + + // aten::adaptive_max_pool3d(Tensor self, int[3] output_size) -> (Tensor, Tensor) + inline ::std::tuple adaptive_max_pool3d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size) { + return at::_ops::adaptive_max_pool3d::redispatch(dispatchKeySet, self, output_size); + } + + // aten::adaptive_max_pool3d_backward.grad_input(Tensor grad_output, Tensor self, Tensor indices, *, Tensor(a!) 
grad_input) -> Tensor(a!) + inline at::Tensor & adaptive_max_pool3d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices) { + return at::_ops::adaptive_max_pool3d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, indices, grad_input); + } + + // aten::adaptive_max_pool3d_backward.grad_input(Tensor grad_output, Tensor self, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & adaptive_max_pool3d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices, at::Tensor & grad_input) { + return at::_ops::adaptive_max_pool3d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, indices, grad_input); + } + + // aten::adaptive_max_pool3d_backward(Tensor grad_output, Tensor self, Tensor indices) -> Tensor + inline at::Tensor adaptive_max_pool3d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices) { + return at::_ops::adaptive_max_pool3d_backward::redispatch(dispatchKeySet, grad_output, self, indices); + } + + // aten::avg_pool2d.out(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & avg_pool2d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, bool ceil_mode=false, bool count_include_pad=true, ::std::optional divisor_override=::std::nullopt) { + return at::_ops::avg_pool2d_out::redispatch(dispatchKeySet, self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override, out); + } + + // aten::avg_pool2d.out(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & avg_pool2d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override, at::Tensor & out) { + return at::_ops::avg_pool2d_out::redispatch(dispatchKeySet, self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override, out); + } + + // aten::avg_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None) -> Tensor + inline at::Tensor avg_pool2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, bool ceil_mode=false, bool count_include_pad=true, ::std::optional divisor_override=::std::nullopt) { + return at::_ops::avg_pool2d::redispatch(dispatchKeySet, self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override); + } + + // aten::avg_pool2d_backward.grad_input(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride, int[2] padding, bool ceil_mode, bool count_include_pad, int? divisor_override, *, Tensor(a!) grad_input) -> Tensor(a!) 
+ inline at::Tensor & avg_pool2d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override) { + return at::_ops::avg_pool2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override, grad_input); + } + + // aten::avg_pool2d_backward.grad_input(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride, int[2] padding, bool ceil_mode, bool count_include_pad, int? divisor_override, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & avg_pool2d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override, at::Tensor & grad_input) { + return at::_ops::avg_pool2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override, grad_input); + } + + // aten::avg_pool2d_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride, int[2] padding, bool ceil_mode, bool count_include_pad, int? 
divisor_override) -> Tensor + inline at::Tensor avg_pool2d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override) { + return at::_ops::avg_pool2d_backward::redispatch(dispatchKeySet, grad_output, self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override); + } + + // aten::avg_pool3d.out(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & avg_pool3d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, bool ceil_mode=false, bool count_include_pad=true, ::std::optional divisor_override=::std::nullopt) { + return at::_ops::avg_pool3d_out::redispatch(dispatchKeySet, self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override, out); + } + + // aten::avg_pool3d.out(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & avg_pool3d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override, at::Tensor & out) { + return at::_ops::avg_pool3d_out::redispatch(dispatchKeySet, self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override, out); + } + + // aten::avg_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? 
divisor_override=None) -> Tensor + inline at::Tensor avg_pool3d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, bool ceil_mode=false, bool count_include_pad=true, ::std::optional divisor_override=::std::nullopt) { + return at::_ops::avg_pool3d::redispatch(dispatchKeySet, self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override); + } + + // aten::avg_pool3d_backward.grad_input(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] stride, int[3] padding, bool ceil_mode, bool count_include_pad, int? divisor_override, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & avg_pool3d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override) { + return at::_ops::avg_pool3d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override, grad_input); + } + + // aten::avg_pool3d_backward.grad_input(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] stride, int[3] padding, bool ceil_mode, bool count_include_pad, int? divisor_override, *, Tensor(a!) grad_input) -> Tensor(a!) 
+ inline at::Tensor & avg_pool3d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override, at::Tensor & grad_input) { + return at::_ops::avg_pool3d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override, grad_input); + } + + // aten::avg_pool3d_backward(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] stride, int[3] padding, bool ceil_mode, bool count_include_pad, int? divisor_override) -> Tensor + inline at::Tensor avg_pool3d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override) { + return at::_ops::avg_pool3d_backward::redispatch(dispatchKeySet, grad_output, self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override); + } + + // aten::fractional_max_pool2d.output(Tensor self, int[2] kernel_size, int[2] output_size, Tensor random_samples, *, Tensor(a!) output, Tensor(b!) indices) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple fractional_max_pool2d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & output, at::Tensor & indices, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef output_size, const at::Tensor & random_samples) { + return at::_ops::fractional_max_pool2d_output::redispatch(dispatchKeySet, self, kernel_size, output_size, random_samples, output, indices); + } + + // aten::fractional_max_pool2d.output(Tensor self, int[2] kernel_size, int[2] output_size, Tensor random_samples, *, Tensor(a!) output, Tensor(b!) 
indices) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple fractional_max_pool2d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef output_size, const at::Tensor & random_samples, at::Tensor & output, at::Tensor & indices) { + return at::_ops::fractional_max_pool2d_output::redispatch(dispatchKeySet, self, kernel_size, output_size, random_samples, output, indices); + } + + // aten::fractional_max_pool2d(Tensor self, int[2] kernel_size, int[2] output_size, Tensor random_samples) -> (Tensor, Tensor) + inline ::std::tuple fractional_max_pool2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef output_size, const at::Tensor & random_samples) { + return at::_ops::fractional_max_pool2d::redispatch(dispatchKeySet, self, kernel_size, output_size, random_samples); + } + + // aten::fractional_max_pool2d_backward.grad_input(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] output_size, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & fractional_max_pool2d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef output_size, const at::Tensor & indices) { + return at::_ops::fractional_max_pool2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, kernel_size, output_size, indices, grad_input); + } + + // aten::fractional_max_pool2d_backward.grad_input(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] output_size, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!) 
+ inline at::Tensor & fractional_max_pool2d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef output_size, const at::Tensor & indices, at::Tensor & grad_input) { + return at::_ops::fractional_max_pool2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, kernel_size, output_size, indices, grad_input); + } + + // aten::fractional_max_pool2d_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] output_size, Tensor indices) -> Tensor + inline at::Tensor fractional_max_pool2d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef output_size, const at::Tensor & indices) { + return at::_ops::fractional_max_pool2d_backward::redispatch(dispatchKeySet, grad_output, self, kernel_size, output_size, indices); + } + + // aten::fractional_max_pool3d.output(Tensor self, int[3] kernel_size, int[3] output_size, Tensor random_samples, *, Tensor(a!) output, Tensor(b!) indices) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple fractional_max_pool3d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & output, at::Tensor & indices, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef output_size, const at::Tensor & random_samples) { + return at::_ops::fractional_max_pool3d_output::redispatch(dispatchKeySet, self, kernel_size, output_size, random_samples, output, indices); + } + + // aten::fractional_max_pool3d.output(Tensor self, int[3] kernel_size, int[3] output_size, Tensor random_samples, *, Tensor(a!) output, Tensor(b!) 
indices) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple fractional_max_pool3d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef output_size, const at::Tensor & random_samples, at::Tensor & output, at::Tensor & indices) { + return at::_ops::fractional_max_pool3d_output::redispatch(dispatchKeySet, self, kernel_size, output_size, random_samples, output, indices); + } + + // aten::fractional_max_pool3d(Tensor self, int[3] kernel_size, int[3] output_size, Tensor random_samples) -> (Tensor, Tensor) + inline ::std::tuple fractional_max_pool3d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef output_size, const at::Tensor & random_samples) { + return at::_ops::fractional_max_pool3d::redispatch(dispatchKeySet, self, kernel_size, output_size, random_samples); + } + + // aten::fractional_max_pool3d_backward.grad_input(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] output_size, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & fractional_max_pool3d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef output_size, const at::Tensor & indices) { + return at::_ops::fractional_max_pool3d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, kernel_size, output_size, indices, grad_input); + } + + // aten::fractional_max_pool3d_backward.grad_input(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] output_size, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!) 
+ inline at::Tensor & fractional_max_pool3d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef output_size, const at::Tensor & indices, at::Tensor & grad_input) { + return at::_ops::fractional_max_pool3d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, kernel_size, output_size, indices, grad_input); + } + + // aten::fractional_max_pool3d_backward(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] output_size, Tensor indices) -> Tensor + inline at::Tensor fractional_max_pool3d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef output_size, const at::Tensor & indices) { + return at::_ops::fractional_max_pool3d_backward::redispatch(dispatchKeySet, grad_output, self, kernel_size, output_size, indices); + } + + // aten::max_pool2d_with_indices.out(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False, *, Tensor(a!) out, Tensor(b!) indices) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple max_pool2d_with_indices_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::Tensor & indices, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, at::IntArrayRef dilation=1, bool ceil_mode=false) { + return at::_ops::max_pool2d_with_indices_out::redispatch(dispatchKeySet, self, kernel_size, stride, padding, dilation, ceil_mode, out, indices); + } + + // aten::max_pool2d_with_indices.out(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False, *, Tensor(a!) out, Tensor(b!) 
indices) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple max_pool2d_with_indices_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, at::Tensor & out, at::Tensor & indices) { + return at::_ops::max_pool2d_with_indices_out::redispatch(dispatchKeySet, self, kernel_size, stride, padding, dilation, ceil_mode, out, indices); + } + + // aten::max_pool2d_with_indices(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor) + inline ::std::tuple max_pool2d_with_indices(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, at::IntArrayRef dilation=1, bool ceil_mode=false) { + return at::_ops::max_pool2d_with_indices::redispatch(dispatchKeySet, self, kernel_size, stride, padding, dilation, ceil_mode); + } + + // aten::max_pool2d_with_indices_backward.grad_input(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride, int[2] padding, int[2] dilation, bool ceil_mode, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & max_pool2d_with_indices_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, const at::Tensor & indices) { + return at::_ops::max_pool2d_with_indices_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, kernel_size, stride, padding, dilation, ceil_mode, indices, grad_input); + } + + // aten::max_pool2d_with_indices_backward.grad_input(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride, int[2] padding, int[2] dilation, bool ceil_mode, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!) 
+ inline at::Tensor & max_pool2d_with_indices_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, const at::Tensor & indices, at::Tensor & grad_input) { + return at::_ops::max_pool2d_with_indices_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, kernel_size, stride, padding, dilation, ceil_mode, indices, grad_input); + } + + // aten::max_pool2d_with_indices_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride, int[2] padding, int[2] dilation, bool ceil_mode, Tensor indices) -> Tensor + inline at::Tensor max_pool2d_with_indices_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, const at::Tensor & indices) { + return at::_ops::max_pool2d_with_indices_backward::redispatch(dispatchKeySet, grad_output, self, kernel_size, stride, padding, dilation, ceil_mode, indices); + } + + // aten::max_pool3d_with_indices.out(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False, *, Tensor(a!) out, Tensor(b!) indices) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple max_pool3d_with_indices_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::Tensor & indices, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, at::IntArrayRef dilation=1, bool ceil_mode=false) { + return at::_ops::max_pool3d_with_indices_out::redispatch(dispatchKeySet, self, kernel_size, stride, padding, dilation, ceil_mode, out, indices); + } + + // aten::max_pool3d_with_indices.out(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False, *, Tensor(a!) out, Tensor(b!) 
indices) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple max_pool3d_with_indices_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, at::Tensor & out, at::Tensor & indices) { + return at::_ops::max_pool3d_with_indices_out::redispatch(dispatchKeySet, self, kernel_size, stride, padding, dilation, ceil_mode, out, indices); + } + + // aten::max_pool3d_with_indices(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor) + inline ::std::tuple max_pool3d_with_indices(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, at::IntArrayRef dilation=1, bool ceil_mode=false) { + return at::_ops::max_pool3d_with_indices::redispatch(dispatchKeySet, self, kernel_size, stride, padding, dilation, ceil_mode); + } + + // aten::max_pool3d_with_indices_backward.grad_input(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] stride, int[3] padding, int[3] dilation, bool ceil_mode, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & max_pool3d_with_indices_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, const at::Tensor & indices) { + return at::_ops::max_pool3d_with_indices_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, kernel_size, stride, padding, dilation, ceil_mode, indices, grad_input); + } + + // aten::max_pool3d_with_indices_backward.grad_input(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] stride, int[3] padding, int[3] dilation, bool ceil_mode, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!) 
+ inline at::Tensor & max_pool3d_with_indices_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, const at::Tensor & indices, at::Tensor & grad_input) { + return at::_ops::max_pool3d_with_indices_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, kernel_size, stride, padding, dilation, ceil_mode, indices, grad_input); + } + + // aten::max_pool3d_with_indices_backward(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] stride, int[3] padding, int[3] dilation, bool ceil_mode, Tensor indices) -> Tensor + inline at::Tensor max_pool3d_with_indices_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, const at::Tensor & indices) { + return at::_ops::max_pool3d_with_indices_backward::redispatch(dispatchKeySet, grad_output, self, kernel_size, stride, padding, dilation, ceil_mode, indices); + } + + // aten::max_unpool2d.out(Tensor self, Tensor indices, SymInt[2] output_size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & max_unpool2d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & indices, at::IntArrayRef output_size) { + return at::_ops::max_unpool2d_out::redispatch(dispatchKeySet, self, indices, c10::fromIntArrayRefSlow(output_size), out); + } + + // aten::max_unpool2d.out(Tensor self, Tensor indices, SymInt[2] output_size, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & max_unpool2d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & indices, at::IntArrayRef output_size, at::Tensor & out) { + return at::_ops::max_unpool2d_out::redispatch(dispatchKeySet, self, indices, c10::fromIntArrayRefSlow(output_size), out); + } + + // aten::max_unpool2d.out(Tensor self, Tensor indices, SymInt[2] output_size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & max_unpool2d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & indices, c10::SymIntArrayRef output_size) { + return at::_ops::max_unpool2d_out::redispatch(dispatchKeySet, self, indices, output_size, out); + } + + // aten::max_unpool2d.out(Tensor self, Tensor indices, SymInt[2] output_size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & max_unpool2d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & indices, c10::SymIntArrayRef output_size, at::Tensor & out) { + return at::_ops::max_unpool2d_out::redispatch(dispatchKeySet, self, indices, output_size, out); + } + + // aten::max_unpool2d(Tensor self, Tensor indices, SymInt[2] output_size) -> Tensor + inline at::Tensor max_unpool2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & indices, at::IntArrayRef output_size) { + return at::_ops::max_unpool2d::redispatch(dispatchKeySet, self, indices, c10::fromIntArrayRefSlow(output_size)); + } + + // aten::max_unpool2d(Tensor self, Tensor indices, SymInt[2] output_size) -> Tensor + inline at::Tensor max_unpool2d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & indices, c10::SymIntArrayRef output_size) { + return at::_ops::max_unpool2d::redispatch(dispatchKeySet, self, indices, output_size); + } + + // aten::max_unpool3d.out(Tensor self, Tensor indices, SymInt[3] output_size, int[3] stride, int[3] padding, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & max_unpool3d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & indices, at::IntArrayRef output_size, at::IntArrayRef stride, at::IntArrayRef padding) { + return at::_ops::max_unpool3d_out::redispatch(dispatchKeySet, self, indices, c10::fromIntArrayRefSlow(output_size), stride, padding, out); + } + + // aten::max_unpool3d.out(Tensor self, Tensor indices, SymInt[3] output_size, int[3] stride, int[3] padding, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & max_unpool3d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & indices, at::IntArrayRef output_size, at::IntArrayRef stride, at::IntArrayRef padding, at::Tensor & out) { + return at::_ops::max_unpool3d_out::redispatch(dispatchKeySet, self, indices, c10::fromIntArrayRefSlow(output_size), stride, padding, out); + } + + // aten::max_unpool3d.out(Tensor self, Tensor indices, SymInt[3] output_size, int[3] stride, int[3] padding, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & max_unpool3d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & indices, c10::SymIntArrayRef output_size, at::IntArrayRef stride, at::IntArrayRef padding) { + return at::_ops::max_unpool3d_out::redispatch(dispatchKeySet, self, indices, output_size, stride, padding, out); + } + + // aten::max_unpool3d.out(Tensor self, Tensor indices, SymInt[3] output_size, int[3] stride, int[3] padding, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & max_unpool3d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & indices, c10::SymIntArrayRef output_size, at::IntArrayRef stride, at::IntArrayRef padding, at::Tensor & out) { + return at::_ops::max_unpool3d_out::redispatch(dispatchKeySet, self, indices, output_size, stride, padding, out); + } + + // aten::max_unpool3d(Tensor self, Tensor indices, SymInt[3] output_size, int[3] stride, int[3] padding) -> Tensor + inline at::Tensor max_unpool3d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & indices, at::IntArrayRef output_size, at::IntArrayRef stride, at::IntArrayRef padding) { + return at::_ops::max_unpool3d::redispatch(dispatchKeySet, self, indices, c10::fromIntArrayRefSlow(output_size), stride, padding); + } + + // aten::max_unpool3d(Tensor self, Tensor indices, SymInt[3] output_size, int[3] stride, int[3] padding) -> Tensor + inline at::Tensor max_unpool3d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & indices, c10::SymIntArrayRef output_size, at::IntArrayRef stride, at::IntArrayRef padding) { + return at::_ops::max_unpool3d::redispatch(dispatchKeySet, self, indices, output_size, stride, padding); + } + + // aten::reflection_pad1d.out(Tensor self, SymInt[2] padding, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & reflection_pad1d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef padding) { + return at::_ops::reflection_pad1d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(padding), out); + } + + // aten::reflection_pad1d.out(Tensor self, SymInt[2] padding, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & reflection_pad1d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef padding, at::Tensor & out) { + return at::_ops::reflection_pad1d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(padding), out); + } + + // aten::reflection_pad1d.out(Tensor self, SymInt[2] padding, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & reflection_pad1d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef padding) { + return at::_ops::reflection_pad1d_out::redispatch(dispatchKeySet, self, padding, out); + } + + // aten::reflection_pad1d.out(Tensor self, SymInt[2] padding, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & reflection_pad1d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef padding, at::Tensor & out) { + return at::_ops::reflection_pad1d_out::redispatch(dispatchKeySet, self, padding, out); + } + + // aten::reflection_pad1d(Tensor self, SymInt[2] padding) -> Tensor + inline at::Tensor reflection_pad1d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef padding) { + return at::_ops::reflection_pad1d::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(padding)); + } + + // aten::reflection_pad1d(Tensor self, SymInt[2] padding) -> Tensor + inline at::Tensor reflection_pad1d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef padding) { + return at::_ops::reflection_pad1d::redispatch(dispatchKeySet, self, padding); + } + + // aten::reflection_pad1d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[2] padding, *, Tensor(a!) grad_input) -> Tensor(a!) 
+ inline at::Tensor & reflection_pad1d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef padding) { + return at::_ops::reflection_pad1d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, c10::fromIntArrayRefSlow(padding), grad_input); + } + + // aten::reflection_pad1d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[2] padding, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & reflection_pad1d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef padding, at::Tensor & grad_input) { + return at::_ops::reflection_pad1d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, c10::fromIntArrayRefSlow(padding), grad_input); + } + + // aten::reflection_pad1d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[2] padding, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & reflection_pad1d_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding) { + return at::_ops::reflection_pad1d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, padding, grad_input); + } + + // aten::reflection_pad1d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[2] padding, *, Tensor(a!) grad_input) -> Tensor(a!) 
+ inline at::Tensor & reflection_pad1d_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding, at::Tensor & grad_input) { + return at::_ops::reflection_pad1d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, padding, grad_input); + } + + // aten::reflection_pad1d_backward(Tensor grad_output, Tensor self, SymInt[2] padding) -> Tensor + inline at::Tensor reflection_pad1d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef padding) { + return at::_ops::reflection_pad1d_backward::redispatch(dispatchKeySet, grad_output, self, c10::fromIntArrayRefSlow(padding)); + } + + // aten::reflection_pad1d_backward(Tensor grad_output, Tensor self, SymInt[2] padding) -> Tensor + inline at::Tensor reflection_pad1d_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding) { + return at::_ops::reflection_pad1d_backward::redispatch(dispatchKeySet, grad_output, self, padding); + } + + // aten::reflection_pad2d.out(Tensor self, SymInt[4] padding, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & reflection_pad2d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef padding) { + return at::_ops::reflection_pad2d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(padding), out); + } + + // aten::reflection_pad2d.out(Tensor self, SymInt[4] padding, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & reflection_pad2d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef padding, at::Tensor & out) { + return at::_ops::reflection_pad2d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(padding), out); + } + + // aten::reflection_pad2d.out(Tensor self, SymInt[4] padding, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & reflection_pad2d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef padding) { + return at::_ops::reflection_pad2d_out::redispatch(dispatchKeySet, self, padding, out); + } + + // aten::reflection_pad2d.out(Tensor self, SymInt[4] padding, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & reflection_pad2d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef padding, at::Tensor & out) { + return at::_ops::reflection_pad2d_out::redispatch(dispatchKeySet, self, padding, out); + } + + // aten::reflection_pad2d(Tensor self, SymInt[4] padding) -> Tensor + inline at::Tensor reflection_pad2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef padding) { + return at::_ops::reflection_pad2d::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(padding)); + } + + // aten::reflection_pad2d(Tensor self, SymInt[4] padding) -> Tensor + inline at::Tensor reflection_pad2d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef padding) { + return at::_ops::reflection_pad2d::redispatch(dispatchKeySet, self, padding); + } + + // aten::reflection_pad2d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[4] padding, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & reflection_pad2d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef padding) { + return at::_ops::reflection_pad2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, c10::fromIntArrayRefSlow(padding), grad_input); + } + + // aten::reflection_pad2d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[4] padding, *, Tensor(a!) grad_input) -> Tensor(a!) 
+ inline at::Tensor & reflection_pad2d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef padding, at::Tensor & grad_input) { + return at::_ops::reflection_pad2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, c10::fromIntArrayRefSlow(padding), grad_input); + } + + // aten::reflection_pad2d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[4] padding, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & reflection_pad2d_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding) { + return at::_ops::reflection_pad2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, padding, grad_input); + } + + // aten::reflection_pad2d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[4] padding, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & reflection_pad2d_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding, at::Tensor & grad_input) { + return at::_ops::reflection_pad2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, padding, grad_input); + } + + // aten::reflection_pad2d_backward(Tensor grad_output, Tensor self, SymInt[4] padding) -> Tensor + inline at::Tensor reflection_pad2d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef padding) { + return at::_ops::reflection_pad2d_backward::redispatch(dispatchKeySet, grad_output, self, c10::fromIntArrayRefSlow(padding)); + } + + // aten::reflection_pad2d_backward(Tensor grad_output, Tensor self, SymInt[4] padding) -> Tensor + inline at::Tensor reflection_pad2d_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding) { + return 
at::_ops::reflection_pad2d_backward::redispatch(dispatchKeySet, grad_output, self, padding); + } + + // aten::reflection_pad3d.out(Tensor self, SymInt[6] padding, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & reflection_pad3d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef padding) { + return at::_ops::reflection_pad3d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(padding), out); + } + + // aten::reflection_pad3d.out(Tensor self, SymInt[6] padding, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & reflection_pad3d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef padding, at::Tensor & out) { + return at::_ops::reflection_pad3d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(padding), out); + } + + // aten::reflection_pad3d.out(Tensor self, SymInt[6] padding, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & reflection_pad3d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef padding) { + return at::_ops::reflection_pad3d_out::redispatch(dispatchKeySet, self, padding, out); + } + + // aten::reflection_pad3d.out(Tensor self, SymInt[6] padding, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & reflection_pad3d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef padding, at::Tensor & out) { + return at::_ops::reflection_pad3d_out::redispatch(dispatchKeySet, self, padding, out); + } + + // aten::reflection_pad3d(Tensor self, SymInt[6] padding) -> Tensor + inline at::Tensor reflection_pad3d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef padding) { + return at::_ops::reflection_pad3d::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(padding)); + } + + // aten::reflection_pad3d(Tensor self, SymInt[6] padding) -> Tensor + inline at::Tensor reflection_pad3d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef padding) { + return at::_ops::reflection_pad3d::redispatch(dispatchKeySet, self, padding); + } + + // aten::reflection_pad3d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[6] padding, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & reflection_pad3d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef padding) { + return at::_ops::reflection_pad3d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, c10::fromIntArrayRefSlow(padding), grad_input); + } + + // aten::reflection_pad3d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[6] padding, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & reflection_pad3d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef padding, at::Tensor & grad_input) { + return at::_ops::reflection_pad3d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, c10::fromIntArrayRefSlow(padding), grad_input); + } + + // aten::reflection_pad3d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[6] padding, *, Tensor(a!) grad_input) -> Tensor(a!) 
+ inline at::Tensor & reflection_pad3d_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding) { + return at::_ops::reflection_pad3d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, padding, grad_input); + } + + // aten::reflection_pad3d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[6] padding, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & reflection_pad3d_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding, at::Tensor & grad_input) { + return at::_ops::reflection_pad3d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, padding, grad_input); + } + + // aten::reflection_pad3d_backward(Tensor grad_output, Tensor self, SymInt[6] padding) -> Tensor + inline at::Tensor reflection_pad3d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef padding) { + return at::_ops::reflection_pad3d_backward::redispatch(dispatchKeySet, grad_output, self, c10::fromIntArrayRefSlow(padding)); + } + + // aten::reflection_pad3d_backward(Tensor grad_output, Tensor self, SymInt[6] padding) -> Tensor + inline at::Tensor reflection_pad3d_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding) { + return at::_ops::reflection_pad3d_backward::redispatch(dispatchKeySet, grad_output, self, padding); + } + + // aten::replication_pad1d.out(Tensor self, SymInt[2] padding, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & replication_pad1d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef padding) { + return at::_ops::replication_pad1d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(padding), out); + } + + // aten::replication_pad1d.out(Tensor self, SymInt[2] padding, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & replication_pad1d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef padding, at::Tensor & out) { + return at::_ops::replication_pad1d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(padding), out); + } + + // aten::replication_pad1d.out(Tensor self, SymInt[2] padding, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & replication_pad1d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef padding) { + return at::_ops::replication_pad1d_out::redispatch(dispatchKeySet, self, padding, out); + } + + // aten::replication_pad1d.out(Tensor self, SymInt[2] padding, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & replication_pad1d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef padding, at::Tensor & out) { + return at::_ops::replication_pad1d_out::redispatch(dispatchKeySet, self, padding, out); + } + + // aten::replication_pad1d(Tensor self, SymInt[2] padding) -> Tensor + inline at::Tensor replication_pad1d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef padding) { + return at::_ops::replication_pad1d::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(padding)); + } + + // aten::replication_pad1d(Tensor self, SymInt[2] padding) -> Tensor + inline at::Tensor replication_pad1d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef padding) { + return at::_ops::replication_pad1d::redispatch(dispatchKeySet, self, padding); + } + + // aten::replication_pad1d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[2] padding, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & replication_pad1d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef padding) { + return at::_ops::replication_pad1d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, c10::fromIntArrayRefSlow(padding), grad_input); + } + + // aten::replication_pad1d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[2] padding, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & replication_pad1d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef padding, at::Tensor & grad_input) { + return at::_ops::replication_pad1d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, c10::fromIntArrayRefSlow(padding), grad_input); + } + + // aten::replication_pad1d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[2] padding, *, Tensor(a!) grad_input) -> Tensor(a!) 
+ inline at::Tensor & replication_pad1d_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding) { + return at::_ops::replication_pad1d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, padding, grad_input); + } + + // aten::replication_pad1d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[2] padding, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & replication_pad1d_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding, at::Tensor & grad_input) { + return at::_ops::replication_pad1d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, padding, grad_input); + } + + // aten::replication_pad1d_backward(Tensor grad_output, Tensor self, SymInt[2] padding) -> Tensor + inline at::Tensor replication_pad1d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef padding) { + return at::_ops::replication_pad1d_backward::redispatch(dispatchKeySet, grad_output, self, c10::fromIntArrayRefSlow(padding)); + } + + // aten::replication_pad1d_backward(Tensor grad_output, Tensor self, SymInt[2] padding) -> Tensor + inline at::Tensor replication_pad1d_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding) { + return at::_ops::replication_pad1d_backward::redispatch(dispatchKeySet, grad_output, self, padding); + } + + // aten::replication_pad2d.out(Tensor self, SymInt[4] padding, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & replication_pad2d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef padding) { + return at::_ops::replication_pad2d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(padding), out); + } + + // aten::replication_pad2d.out(Tensor self, SymInt[4] padding, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & replication_pad2d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef padding, at::Tensor & out) { + return at::_ops::replication_pad2d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(padding), out); + } + + // aten::replication_pad2d.out(Tensor self, SymInt[4] padding, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & replication_pad2d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef padding) { + return at::_ops::replication_pad2d_out::redispatch(dispatchKeySet, self, padding, out); + } + + // aten::replication_pad2d.out(Tensor self, SymInt[4] padding, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & replication_pad2d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef padding, at::Tensor & out) { + return at::_ops::replication_pad2d_out::redispatch(dispatchKeySet, self, padding, out); + } + + // aten::replication_pad2d(Tensor self, SymInt[4] padding) -> Tensor + inline at::Tensor replication_pad2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef padding) { + return at::_ops::replication_pad2d::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(padding)); + } + + // aten::replication_pad2d(Tensor self, SymInt[4] padding) -> Tensor + inline at::Tensor replication_pad2d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef padding) { + return at::_ops::replication_pad2d::redispatch(dispatchKeySet, self, padding); + } + + // aten::replication_pad2d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[4] padding, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & replication_pad2d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef padding) { + return at::_ops::replication_pad2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, c10::fromIntArrayRefSlow(padding), grad_input); + } + + // aten::replication_pad2d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[4] padding, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & replication_pad2d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef padding, at::Tensor & grad_input) { + return at::_ops::replication_pad2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, c10::fromIntArrayRefSlow(padding), grad_input); + } + + // aten::replication_pad2d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[4] padding, *, Tensor(a!) grad_input) -> Tensor(a!) 
+ inline at::Tensor & replication_pad2d_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding) { + return at::_ops::replication_pad2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, padding, grad_input); + } + + // aten::replication_pad2d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[4] padding, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & replication_pad2d_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding, at::Tensor & grad_input) { + return at::_ops::replication_pad2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, padding, grad_input); + } + + // aten::replication_pad2d_backward(Tensor grad_output, Tensor self, SymInt[4] padding) -> Tensor + inline at::Tensor replication_pad2d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef padding) { + return at::_ops::replication_pad2d_backward::redispatch(dispatchKeySet, grad_output, self, c10::fromIntArrayRefSlow(padding)); + } + + // aten::replication_pad2d_backward(Tensor grad_output, Tensor self, SymInt[4] padding) -> Tensor + inline at::Tensor replication_pad2d_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding) { + return at::_ops::replication_pad2d_backward::redispatch(dispatchKeySet, grad_output, self, padding); + } + + // aten::replication_pad3d.out(Tensor self, SymInt[6] padding, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & replication_pad3d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef padding) { + return at::_ops::replication_pad3d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(padding), out); + } + + // aten::replication_pad3d.out(Tensor self, SymInt[6] padding, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & replication_pad3d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef padding, at::Tensor & out) { + return at::_ops::replication_pad3d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(padding), out); + } + + // aten::replication_pad3d.out(Tensor self, SymInt[6] padding, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & replication_pad3d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef padding) { + return at::_ops::replication_pad3d_out::redispatch(dispatchKeySet, self, padding, out); + } + + // aten::replication_pad3d.out(Tensor self, SymInt[6] padding, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & replication_pad3d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef padding, at::Tensor & out) { + return at::_ops::replication_pad3d_out::redispatch(dispatchKeySet, self, padding, out); + } + + // aten::replication_pad3d(Tensor self, SymInt[6] padding) -> Tensor + inline at::Tensor replication_pad3d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef padding) { + return at::_ops::replication_pad3d::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(padding)); + } + + // aten::replication_pad3d(Tensor self, SymInt[6] padding) -> Tensor + inline at::Tensor replication_pad3d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef padding) { + return at::_ops::replication_pad3d::redispatch(dispatchKeySet, self, padding); + } + + // aten::replication_pad3d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[6] padding, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & replication_pad3d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef padding) { + return at::_ops::replication_pad3d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, c10::fromIntArrayRefSlow(padding), grad_input); + } + + // aten::replication_pad3d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[6] padding, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & replication_pad3d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef padding, at::Tensor & grad_input) { + return at::_ops::replication_pad3d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, c10::fromIntArrayRefSlow(padding), grad_input); + } + + // aten::replication_pad3d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[6] padding, *, Tensor(a!) grad_input) -> Tensor(a!) 
+ inline at::Tensor & replication_pad3d_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding) { + return at::_ops::replication_pad3d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, padding, grad_input); + } + + // aten::replication_pad3d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[6] padding, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & replication_pad3d_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding, at::Tensor & grad_input) { + return at::_ops::replication_pad3d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, padding, grad_input); + } + + // aten::replication_pad3d_backward(Tensor grad_output, Tensor self, SymInt[6] padding) -> Tensor + inline at::Tensor replication_pad3d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef padding) { + return at::_ops::replication_pad3d_backward::redispatch(dispatchKeySet, grad_output, self, c10::fromIntArrayRefSlow(padding)); + } + + // aten::replication_pad3d_backward(Tensor grad_output, Tensor self, SymInt[6] padding) -> Tensor + inline at::Tensor replication_pad3d_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding) { + return at::_ops::replication_pad3d_backward::redispatch(dispatchKeySet, grad_output, self, padding); + } + + // aten::_pad_circular(Tensor self, SymInt[] pad) -> Tensor + inline at::Tensor _pad_circular(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef pad) { + return at::_ops::_pad_circular::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(pad)); + } + + // aten::_pad_circular(Tensor self, SymInt[] pad) -> Tensor + inline at::Tensor 
_pad_circular_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef pad) { + return at::_ops::_pad_circular::redispatch(dispatchKeySet, self, pad); + } + + // aten::_pad_enum(Tensor self, SymInt[] pad, int mode, float? value=None) -> Tensor + inline at::Tensor _pad_enum(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef pad, int64_t mode, ::std::optional value=::std::nullopt) { + return at::_ops::_pad_enum::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(pad), mode, value); + } + + // aten::_pad_enum(Tensor self, SymInt[] pad, int mode, float? value=None) -> Tensor + inline at::Tensor _pad_enum_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef pad, int64_t mode, ::std::optional value=::std::nullopt) { + return at::_ops::_pad_enum::redispatch(dispatchKeySet, self, pad, mode, value); + } + + // aten::pad(Tensor self, SymInt[] pad, str mode="constant", float? value=None) -> Tensor + inline at::Tensor pad(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef pad, c10::string_view mode="constant", ::std::optional value=::std::nullopt) { + return at::_ops::pad::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(pad), mode, value); + } + + // aten::pad(Tensor self, SymInt[] pad, str mode="constant", float? value=None) -> Tensor + inline at::Tensor pad_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef pad, c10::string_view mode="constant", ::std::optional value=::std::nullopt) { + return at::_ops::pad::redispatch(dispatchKeySet, self, pad, mode, value); + } + + // aten::upsample_linear1d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? 
scale_factors) -> Tensor + inline at::Tensor upsample_linear1d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::OptionalIntArrayRef output_size, bool align_corners, ::std::optional> scale_factors) { + return at::_ops::upsample_linear1d_vec::redispatch(dispatchKeySet, input, output_size.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*output_size)) : ::std::nullopt, align_corners, scale_factors); + } + + // aten::upsample_linear1d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor + inline at::Tensor upsample_linear1d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::OptionalSymIntArrayRef output_size, bool align_corners, ::std::optional> scale_factors) { + return at::_ops::upsample_linear1d_vec::redispatch(dispatchKeySet, input, output_size, align_corners, scale_factors); + } + + // aten::upsample_bilinear2d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor + inline at::Tensor upsample_bilinear2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::OptionalIntArrayRef output_size, bool align_corners, ::std::optional> scale_factors) { + return at::_ops::upsample_bilinear2d_vec::redispatch(dispatchKeySet, input, output_size.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*output_size)) : ::std::nullopt, align_corners, scale_factors); + } + + // aten::upsample_bilinear2d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor + inline at::Tensor upsample_bilinear2d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::OptionalSymIntArrayRef output_size, bool align_corners, ::std::optional> scale_factors) { + return at::_ops::upsample_bilinear2d_vec::redispatch(dispatchKeySet, input, output_size, align_corners, scale_factors); + } + + // aten::_upsample_bilinear2d_aa.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? 
scale_factors) -> Tensor + inline at::Tensor _upsample_bilinear2d_aa(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::OptionalIntArrayRef output_size, bool align_corners, ::std::optional> scale_factors) { + return at::_ops::_upsample_bilinear2d_aa_vec::redispatch(dispatchKeySet, input, output_size.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*output_size)) : ::std::nullopt, align_corners, scale_factors); + } + + // aten::_upsample_bilinear2d_aa.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor + inline at::Tensor _upsample_bilinear2d_aa_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::OptionalSymIntArrayRef output_size, bool align_corners, ::std::optional> scale_factors) { + return at::_ops::_upsample_bilinear2d_aa_vec::redispatch(dispatchKeySet, input, output_size, align_corners, scale_factors); + } + + // aten::upsample_trilinear3d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor + inline at::Tensor upsample_trilinear3d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::OptionalIntArrayRef output_size, bool align_corners, ::std::optional> scale_factors) { + return at::_ops::upsample_trilinear3d_vec::redispatch(dispatchKeySet, input, output_size.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*output_size)) : ::std::nullopt, align_corners, scale_factors); + } + + // aten::upsample_trilinear3d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor + inline at::Tensor upsample_trilinear3d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::OptionalSymIntArrayRef output_size, bool align_corners, ::std::optional> scale_factors) { + return at::_ops::upsample_trilinear3d_vec::redispatch(dispatchKeySet, input, output_size, align_corners, scale_factors); + } + + // aten::upsample_bicubic2d.vec(Tensor input, SymInt[]? 
output_size, bool align_corners, float[]? scale_factors) -> Tensor + inline at::Tensor upsample_bicubic2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::OptionalIntArrayRef output_size, bool align_corners, ::std::optional> scale_factors) { + return at::_ops::upsample_bicubic2d_vec::redispatch(dispatchKeySet, input, output_size.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*output_size)) : ::std::nullopt, align_corners, scale_factors); + } + + // aten::upsample_bicubic2d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor + inline at::Tensor upsample_bicubic2d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::OptionalSymIntArrayRef output_size, bool align_corners, ::std::optional> scale_factors) { + return at::_ops::upsample_bicubic2d_vec::redispatch(dispatchKeySet, input, output_size, align_corners, scale_factors); + } + + // aten::_upsample_bicubic2d_aa.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor + inline at::Tensor _upsample_bicubic2d_aa(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::OptionalIntArrayRef output_size, bool align_corners, ::std::optional> scale_factors) { + return at::_ops::_upsample_bicubic2d_aa_vec::redispatch(dispatchKeySet, input, output_size.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*output_size)) : ::std::nullopt, align_corners, scale_factors); + } + + // aten::_upsample_bicubic2d_aa.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? 
scale_factors) -> Tensor + inline at::Tensor _upsample_bicubic2d_aa_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::OptionalSymIntArrayRef output_size, bool align_corners, ::std::optional> scale_factors) { + return at::_ops::_upsample_bicubic2d_aa_vec::redispatch(dispatchKeySet, input, output_size, align_corners, scale_factors); + } + + // aten::upsample_nearest1d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor + inline at::Tensor upsample_nearest1d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::OptionalIntArrayRef output_size, ::std::optional> scale_factors) { + return at::_ops::upsample_nearest1d_vec::redispatch(dispatchKeySet, input, output_size.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*output_size)) : ::std::nullopt, scale_factors); + } + + // aten::upsample_nearest1d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor + inline at::Tensor upsample_nearest1d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::OptionalSymIntArrayRef output_size, ::std::optional> scale_factors) { + return at::_ops::upsample_nearest1d_vec::redispatch(dispatchKeySet, input, output_size, scale_factors); + } + + // aten::_upsample_nearest_exact1d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor + inline at::Tensor _upsample_nearest_exact1d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::OptionalIntArrayRef output_size, ::std::optional> scale_factors) { + return at::_ops::_upsample_nearest_exact1d_vec::redispatch(dispatchKeySet, input, output_size.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*output_size)) : ::std::nullopt, scale_factors); + } + + // aten::_upsample_nearest_exact1d.vec(Tensor input, SymInt[]? output_size, float[]? 
scale_factors) -> Tensor + inline at::Tensor _upsample_nearest_exact1d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::OptionalSymIntArrayRef output_size, ::std::optional> scale_factors) { + return at::_ops::_upsample_nearest_exact1d_vec::redispatch(dispatchKeySet, input, output_size, scale_factors); + } + + // aten::upsample_nearest2d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor + inline at::Tensor upsample_nearest2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::OptionalIntArrayRef output_size, ::std::optional> scale_factors) { + return at::_ops::upsample_nearest2d_vec::redispatch(dispatchKeySet, input, output_size.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*output_size)) : ::std::nullopt, scale_factors); + } + + // aten::upsample_nearest2d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor + inline at::Tensor upsample_nearest2d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::OptionalSymIntArrayRef output_size, ::std::optional> scale_factors) { + return at::_ops::upsample_nearest2d_vec::redispatch(dispatchKeySet, input, output_size, scale_factors); + } + + // aten::_upsample_nearest_exact2d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor + inline at::Tensor _upsample_nearest_exact2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::OptionalIntArrayRef output_size, ::std::optional> scale_factors) { + return at::_ops::_upsample_nearest_exact2d_vec::redispatch(dispatchKeySet, input, output_size.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*output_size)) : ::std::nullopt, scale_factors); + } + + // aten::_upsample_nearest_exact2d.vec(Tensor input, SymInt[]? output_size, float[]? 
scale_factors) -> Tensor + inline at::Tensor _upsample_nearest_exact2d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::OptionalSymIntArrayRef output_size, ::std::optional> scale_factors) { + return at::_ops::_upsample_nearest_exact2d_vec::redispatch(dispatchKeySet, input, output_size, scale_factors); + } + + // aten::upsample_nearest3d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor + inline at::Tensor upsample_nearest3d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::OptionalIntArrayRef output_size, ::std::optional> scale_factors) { + return at::_ops::upsample_nearest3d_vec::redispatch(dispatchKeySet, input, output_size.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*output_size)) : ::std::nullopt, scale_factors); + } + + // aten::upsample_nearest3d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor + inline at::Tensor upsample_nearest3d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::OptionalSymIntArrayRef output_size, ::std::optional> scale_factors) { + return at::_ops::upsample_nearest3d_vec::redispatch(dispatchKeySet, input, output_size, scale_factors); + } + + // aten::_upsample_nearest_exact3d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor + inline at::Tensor _upsample_nearest_exact3d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::OptionalIntArrayRef output_size, ::std::optional> scale_factors) { + return at::_ops::_upsample_nearest_exact3d_vec::redispatch(dispatchKeySet, input, output_size.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*output_size)) : ::std::nullopt, scale_factors); + } + + // aten::_upsample_nearest_exact3d.vec(Tensor input, SymInt[]? output_size, float[]? 
scale_factors) -> Tensor + inline at::Tensor _upsample_nearest_exact3d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::OptionalSymIntArrayRef output_size, ::std::optional> scale_factors) { + return at::_ops::_upsample_nearest_exact3d_vec::redispatch(dispatchKeySet, input, output_size, scale_factors); + } + + // aten::upsample_linear1d.out(Tensor self, SymInt[1] output_size, bool align_corners, float? scales=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & upsample_linear1d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef output_size, bool align_corners, ::std::optional scales=::std::nullopt) { + return at::_ops::upsample_linear1d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), align_corners, scales, out); + } + + // aten::upsample_linear1d.out(Tensor self, SymInt[1] output_size, bool align_corners, float? scales=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & upsample_linear1d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, bool align_corners, ::std::optional scales, at::Tensor & out) { + return at::_ops::upsample_linear1d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), align_corners, scales, out); + } + + // aten::upsample_linear1d.out(Tensor self, SymInt[1] output_size, bool align_corners, float? scales=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & upsample_linear1d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, ::std::optional scales=::std::nullopt) { + return at::_ops::upsample_linear1d_out::redispatch(dispatchKeySet, self, output_size, align_corners, scales, out); + } + + // aten::upsample_linear1d.out(Tensor self, SymInt[1] output_size, bool align_corners, float? scales=None, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & upsample_linear1d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, ::std::optional scales, at::Tensor & out) { + return at::_ops::upsample_linear1d_out::redispatch(dispatchKeySet, self, output_size, align_corners, scales, out); + } + + // aten::upsample_linear1d(Tensor self, SymInt[1] output_size, bool align_corners, float? scales=None) -> Tensor + inline at::Tensor upsample_linear1d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, bool align_corners, ::std::optional scales=::std::nullopt) { + return at::_ops::upsample_linear1d::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), align_corners, scales); + } + + // aten::upsample_linear1d(Tensor self, SymInt[1] output_size, bool align_corners, float? scales=None) -> Tensor + inline at::Tensor upsample_linear1d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, ::std::optional scales=::std::nullopt) { + return at::_ops::upsample_linear1d::redispatch(dispatchKeySet, self, output_size, align_corners, scales); + } + + // aten::upsample_linear1d_backward.grad_input(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, bool align_corners, float? scales=None, *, Tensor(a!) grad_input) -> Tensor(a!) 
+ inline at::Tensor & upsample_linear1d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, bool align_corners, ::std::optional scales=::std::nullopt) { + return at::_ops::upsample_linear1d_backward_grad_input::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), align_corners, scales, grad_input); + } + + // aten::upsample_linear1d_backward.grad_input(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, bool align_corners, float? scales=None, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & upsample_linear1d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, bool align_corners, ::std::optional scales, at::Tensor & grad_input) { + return at::_ops::upsample_linear1d_backward_grad_input::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), align_corners, scales, grad_input); + } + + // aten::upsample_linear1d_backward.grad_input(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, bool align_corners, float? scales=None, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & upsample_linear1d_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, ::std::optional scales=::std::nullopt) { + return at::_ops::upsample_linear1d_backward_grad_input::redispatch(dispatchKeySet, grad_output, output_size, input_size, align_corners, scales, grad_input); + } + + // aten::upsample_linear1d_backward.grad_input(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, bool align_corners, float? scales=None, *, Tensor(a!) grad_input) -> Tensor(a!) 
+ inline at::Tensor & upsample_linear1d_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, ::std::optional scales, at::Tensor & grad_input) { + return at::_ops::upsample_linear1d_backward_grad_input::redispatch(dispatchKeySet, grad_output, output_size, input_size, align_corners, scales, grad_input); + } + + // aten::upsample_linear1d_backward(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, bool align_corners, float? scales=None) -> Tensor + inline at::Tensor upsample_linear1d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, bool align_corners, ::std::optional scales=::std::nullopt) { + return at::_ops::upsample_linear1d_backward::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), align_corners, scales); + } + + // aten::upsample_linear1d_backward(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, bool align_corners, float? scales=None) -> Tensor + inline at::Tensor upsample_linear1d_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, ::std::optional scales=::std::nullopt) { + return at::_ops::upsample_linear1d_backward::redispatch(dispatchKeySet, grad_output, output_size, input_size, align_corners, scales); + } + + // aten::upsample_bilinear2d.out(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & upsample_bilinear2d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef output_size, bool align_corners, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::upsample_bilinear2d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), align_corners, scales_h, scales_w, out); + } + + // aten::upsample_bilinear2d.out(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & upsample_bilinear2d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, bool align_corners, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & out) { + return at::_ops::upsample_bilinear2d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), align_corners, scales_h, scales_w, out); + } + + // aten::upsample_bilinear2d.out(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & upsample_bilinear2d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::upsample_bilinear2d_out::redispatch(dispatchKeySet, self, output_size, align_corners, scales_h, scales_w, out); + } + + // aten::upsample_bilinear2d.out(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & upsample_bilinear2d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & out) { + return at::_ops::upsample_bilinear2d_out::redispatch(dispatchKeySet, self, output_size, align_corners, scales_h, scales_w, out); + } + + // aten::upsample_bilinear2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor + inline at::Tensor upsample_bilinear2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, bool align_corners, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::upsample_bilinear2d::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), align_corners, scales_h, scales_w); + } + + // aten::upsample_bilinear2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor + inline at::Tensor upsample_bilinear2d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::upsample_bilinear2d::redispatch(dispatchKeySet, self, output_size, align_corners, scales_h, scales_w); + } + + // aten::upsample_bilinear2d_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) 
+ inline at::Tensor & upsample_bilinear2d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, bool align_corners, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::upsample_bilinear2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), align_corners, scales_h, scales_w, grad_input); + } + + // aten::upsample_bilinear2d_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & upsample_bilinear2d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, bool align_corners, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & grad_input) { + return at::_ops::upsample_bilinear2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), align_corners, scales_h, scales_w, grad_input); + } + + // aten::upsample_bilinear2d_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) 
+ inline at::Tensor & upsample_bilinear2d_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::upsample_bilinear2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, output_size, input_size, align_corners, scales_h, scales_w, grad_input); + } + + // aten::upsample_bilinear2d_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & upsample_bilinear2d_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & grad_input) { + return at::_ops::upsample_bilinear2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, output_size, input_size, align_corners, scales_h, scales_w, grad_input); + } + + // aten::upsample_bilinear2d_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor + inline at::Tensor upsample_bilinear2d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, bool align_corners, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::upsample_bilinear2d_backward::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), align_corners, scales_h, scales_w); + } + + // aten::upsample_bilinear2d_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? 
scales_h=None, float? scales_w=None) -> Tensor + inline at::Tensor upsample_bilinear2d_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::upsample_bilinear2d_backward::redispatch(dispatchKeySet, grad_output, output_size, input_size, align_corners, scales_h, scales_w); + } + + // aten::_upsample_bilinear2d_aa.out(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _upsample_bilinear2d_aa_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef output_size, bool align_corners, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::_upsample_bilinear2d_aa_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), align_corners, scales_h, scales_w, out); + } + + // aten::_upsample_bilinear2d_aa.out(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _upsample_bilinear2d_aa_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, bool align_corners, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & out) { + return at::_ops::_upsample_bilinear2d_aa_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), align_corners, scales_h, scales_w, out); + } + + // aten::_upsample_bilinear2d_aa.out(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & _upsample_bilinear2d_aa_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::_upsample_bilinear2d_aa_out::redispatch(dispatchKeySet, self, output_size, align_corners, scales_h, scales_w, out); + } + + // aten::_upsample_bilinear2d_aa.out(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _upsample_bilinear2d_aa_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & out) { + return at::_ops::_upsample_bilinear2d_aa_out::redispatch(dispatchKeySet, self, output_size, align_corners, scales_h, scales_w, out); + } + + // aten::_upsample_bilinear2d_aa(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor + inline at::Tensor _upsample_bilinear2d_aa(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, bool align_corners, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::_upsample_bilinear2d_aa::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), align_corners, scales_h, scales_w); + } + + // aten::_upsample_bilinear2d_aa(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? 
scales_w=None) -> Tensor + inline at::Tensor _upsample_bilinear2d_aa_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::_upsample_bilinear2d_aa::redispatch(dispatchKeySet, self, output_size, align_corners, scales_h, scales_w); + } + + // aten::_upsample_bilinear2d_aa_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & _upsample_bilinear2d_aa_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, bool align_corners, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::_upsample_bilinear2d_aa_backward_grad_input::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), align_corners, scales_h, scales_w, grad_input); + } + + // aten::_upsample_bilinear2d_aa_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) 
+ inline at::Tensor & _upsample_bilinear2d_aa_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, bool align_corners, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & grad_input) { + return at::_ops::_upsample_bilinear2d_aa_backward_grad_input::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), align_corners, scales_h, scales_w, grad_input); + } + + // aten::_upsample_bilinear2d_aa_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & _upsample_bilinear2d_aa_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::_upsample_bilinear2d_aa_backward_grad_input::redispatch(dispatchKeySet, grad_output, output_size, input_size, align_corners, scales_h, scales_w, grad_input); + } + + // aten::_upsample_bilinear2d_aa_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) 
+ inline at::Tensor & _upsample_bilinear2d_aa_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & grad_input) { + return at::_ops::_upsample_bilinear2d_aa_backward_grad_input::redispatch(dispatchKeySet, grad_output, output_size, input_size, align_corners, scales_h, scales_w, grad_input); + } + + // aten::_upsample_bilinear2d_aa_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor + inline at::Tensor _upsample_bilinear2d_aa_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, bool align_corners, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::_upsample_bilinear2d_aa_backward::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), align_corners, scales_h, scales_w); + } + + // aten::_upsample_bilinear2d_aa_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor + inline at::Tensor _upsample_bilinear2d_aa_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::_upsample_bilinear2d_aa_backward::redispatch(dispatchKeySet, grad_output, output_size, input_size, align_corners, scales_h, scales_w); + } + + // aten::upsample_bicubic2d.out(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & upsample_bicubic2d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef output_size, bool align_corners, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::upsample_bicubic2d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), align_corners, scales_h, scales_w, out); + } + + // aten::upsample_bicubic2d.out(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & upsample_bicubic2d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, bool align_corners, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & out) { + return at::_ops::upsample_bicubic2d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), align_corners, scales_h, scales_w, out); + } + + // aten::upsample_bicubic2d.out(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & upsample_bicubic2d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::upsample_bicubic2d_out::redispatch(dispatchKeySet, self, output_size, align_corners, scales_h, scales_w, out); + } + + // aten::upsample_bicubic2d.out(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & upsample_bicubic2d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & out) { + return at::_ops::upsample_bicubic2d_out::redispatch(dispatchKeySet, self, output_size, align_corners, scales_h, scales_w, out); + } + + // aten::upsample_bicubic2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor + inline at::Tensor upsample_bicubic2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, bool align_corners, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::upsample_bicubic2d::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), align_corners, scales_h, scales_w); + } + + // aten::upsample_bicubic2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor + inline at::Tensor upsample_bicubic2d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::upsample_bicubic2d::redispatch(dispatchKeySet, self, output_size, align_corners, scales_h, scales_w); + } + + // aten::upsample_bicubic2d_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) 
+ inline at::Tensor & upsample_bicubic2d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, bool align_corners, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::upsample_bicubic2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), align_corners, scales_h, scales_w, grad_input); + } + + // aten::upsample_bicubic2d_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & upsample_bicubic2d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, bool align_corners, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & grad_input) { + return at::_ops::upsample_bicubic2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), align_corners, scales_h, scales_w, grad_input); + } + + // aten::upsample_bicubic2d_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) 
+ inline at::Tensor & upsample_bicubic2d_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::upsample_bicubic2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, output_size, input_size, align_corners, scales_h, scales_w, grad_input); + } + + // aten::upsample_bicubic2d_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & upsample_bicubic2d_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & grad_input) { + return at::_ops::upsample_bicubic2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, output_size, input_size, align_corners, scales_h, scales_w, grad_input); + } + + // aten::upsample_bicubic2d_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor + inline at::Tensor upsample_bicubic2d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, bool align_corners, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::upsample_bicubic2d_backward::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), align_corners, scales_h, scales_w); + } + + // aten::upsample_bicubic2d_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? 
scales_h=None, float? scales_w=None) -> Tensor + inline at::Tensor upsample_bicubic2d_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::upsample_bicubic2d_backward::redispatch(dispatchKeySet, grad_output, output_size, input_size, align_corners, scales_h, scales_w); + } + + // aten::_upsample_bicubic2d_aa.out(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _upsample_bicubic2d_aa_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef output_size, bool align_corners, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::_upsample_bicubic2d_aa_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), align_corners, scales_h, scales_w, out); + } + + // aten::_upsample_bicubic2d_aa.out(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _upsample_bicubic2d_aa_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, bool align_corners, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & out) { + return at::_ops::_upsample_bicubic2d_aa_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), align_corners, scales_h, scales_w, out); + } + + // aten::_upsample_bicubic2d_aa.out(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & _upsample_bicubic2d_aa_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::_upsample_bicubic2d_aa_out::redispatch(dispatchKeySet, self, output_size, align_corners, scales_h, scales_w, out); + } + + // aten::_upsample_bicubic2d_aa.out(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _upsample_bicubic2d_aa_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & out) { + return at::_ops::_upsample_bicubic2d_aa_out::redispatch(dispatchKeySet, self, output_size, align_corners, scales_h, scales_w, out); + } + + // aten::_upsample_bicubic2d_aa(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor + inline at::Tensor _upsample_bicubic2d_aa(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, bool align_corners, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::_upsample_bicubic2d_aa::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), align_corners, scales_h, scales_w); + } + + // aten::_upsample_bicubic2d_aa(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? 
scales_w=None) -> Tensor + inline at::Tensor _upsample_bicubic2d_aa_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::_upsample_bicubic2d_aa::redispatch(dispatchKeySet, self, output_size, align_corners, scales_h, scales_w); + } + + // aten::_upsample_bicubic2d_aa_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & _upsample_bicubic2d_aa_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, bool align_corners, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::_upsample_bicubic2d_aa_backward_grad_input::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), align_corners, scales_h, scales_w, grad_input); + } + + // aten::_upsample_bicubic2d_aa_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) 
+ inline at::Tensor & _upsample_bicubic2d_aa_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, bool align_corners, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & grad_input) { + return at::_ops::_upsample_bicubic2d_aa_backward_grad_input::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), align_corners, scales_h, scales_w, grad_input); + } + + // aten::_upsample_bicubic2d_aa_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & _upsample_bicubic2d_aa_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::_upsample_bicubic2d_aa_backward_grad_input::redispatch(dispatchKeySet, grad_output, output_size, input_size, align_corners, scales_h, scales_w, grad_input); + } + + // aten::_upsample_bicubic2d_aa_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) 
+ inline at::Tensor & _upsample_bicubic2d_aa_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & grad_input) { + return at::_ops::_upsample_bicubic2d_aa_backward_grad_input::redispatch(dispatchKeySet, grad_output, output_size, input_size, align_corners, scales_h, scales_w, grad_input); + } + + // aten::_upsample_bicubic2d_aa_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor + inline at::Tensor _upsample_bicubic2d_aa_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, bool align_corners, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::_upsample_bicubic2d_aa_backward::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), align_corners, scales_h, scales_w); + } + + // aten::_upsample_bicubic2d_aa_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor + inline at::Tensor _upsample_bicubic2d_aa_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::_upsample_bicubic2d_aa_backward::redispatch(dispatchKeySet, grad_output, output_size, input_size, align_corners, scales_h, scales_w); + } + + // aten::upsample_trilinear3d.out(Tensor self, SymInt[3] output_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & upsample_trilinear3d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef output_size, bool align_corners, ::std::optional scales_d=::std::nullopt, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::upsample_trilinear3d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), align_corners, scales_d, scales_h, scales_w, out); + } + + // aten::upsample_trilinear3d.out(Tensor self, SymInt[3] output_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & upsample_trilinear3d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, bool align_corners, ::std::optional scales_d, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & out) { + return at::_ops::upsample_trilinear3d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), align_corners, scales_d, scales_h, scales_w, out); + } + + // aten::upsample_trilinear3d.out(Tensor self, SymInt[3] output_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & upsample_trilinear3d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, ::std::optional scales_d=::std::nullopt, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::upsample_trilinear3d_out::redispatch(dispatchKeySet, self, output_size, align_corners, scales_d, scales_h, scales_w, out); + } + + // aten::upsample_trilinear3d.out(Tensor self, SymInt[3] output_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & upsample_trilinear3d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, ::std::optional scales_d, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & out) { + return at::_ops::upsample_trilinear3d_out::redispatch(dispatchKeySet, self, output_size, align_corners, scales_d, scales_h, scales_w, out); + } + + // aten::upsample_trilinear3d(Tensor self, SymInt[3] output_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor + inline at::Tensor upsample_trilinear3d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, bool align_corners, ::std::optional scales_d=::std::nullopt, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::upsample_trilinear3d::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), align_corners, scales_d, scales_h, scales_w); + } + + // aten::upsample_trilinear3d(Tensor self, SymInt[3] output_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor + inline at::Tensor upsample_trilinear3d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, ::std::optional scales_d=::std::nullopt, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::upsample_trilinear3d::redispatch(dispatchKeySet, self, output_size, align_corners, scales_d, scales_h, scales_w); + } + + // aten::upsample_trilinear3d_backward.grad_input(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) 
+ inline at::Tensor & upsample_trilinear3d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, bool align_corners, ::std::optional scales_d=::std::nullopt, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::upsample_trilinear3d_backward_grad_input::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), align_corners, scales_d, scales_h, scales_w, grad_input); + } + + // aten::upsample_trilinear3d_backward.grad_input(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & upsample_trilinear3d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, bool align_corners, ::std::optional scales_d, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & grad_input) { + return at::_ops::upsample_trilinear3d_backward_grad_input::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), align_corners, scales_d, scales_h, scales_w, grad_input); + } + + // aten::upsample_trilinear3d_backward.grad_input(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) 
+ inline at::Tensor & upsample_trilinear3d_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, ::std::optional scales_d=::std::nullopt, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::upsample_trilinear3d_backward_grad_input::redispatch(dispatchKeySet, grad_output, output_size, input_size, align_corners, scales_d, scales_h, scales_w, grad_input); + } + + // aten::upsample_trilinear3d_backward.grad_input(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & upsample_trilinear3d_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, ::std::optional scales_d, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & grad_input) { + return at::_ops::upsample_trilinear3d_backward_grad_input::redispatch(dispatchKeySet, grad_output, output_size, input_size, align_corners, scales_d, scales_h, scales_w, grad_input); + } + + // aten::upsample_trilinear3d_backward(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? 
scales_w=None) -> Tensor + inline at::Tensor upsample_trilinear3d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, bool align_corners, ::std::optional scales_d=::std::nullopt, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::upsample_trilinear3d_backward::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), align_corners, scales_d, scales_h, scales_w); + } + + // aten::upsample_trilinear3d_backward(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor + inline at::Tensor upsample_trilinear3d_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, ::std::optional scales_d=::std::nullopt, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::upsample_trilinear3d_backward::redispatch(dispatchKeySet, grad_output, output_size, input_size, align_corners, scales_d, scales_h, scales_w); + } + + // aten::upsample_nearest1d.out(Tensor self, SymInt[1] output_size, float? scales=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & upsample_nearest1d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef output_size, ::std::optional scales=::std::nullopt) { + return at::_ops::upsample_nearest1d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), scales, out); + } + + // aten::upsample_nearest1d.out(Tensor self, SymInt[1] output_size, float? scales=None, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & upsample_nearest1d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, ::std::optional scales, at::Tensor & out) { + return at::_ops::upsample_nearest1d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), scales, out); + } + + // aten::upsample_nearest1d.out(Tensor self, SymInt[1] output_size, float? scales=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & upsample_nearest1d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef output_size, ::std::optional scales=::std::nullopt) { + return at::_ops::upsample_nearest1d_out::redispatch(dispatchKeySet, self, output_size, scales, out); + } + + // aten::upsample_nearest1d.out(Tensor self, SymInt[1] output_size, float? scales=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & upsample_nearest1d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, ::std::optional scales, at::Tensor & out) { + return at::_ops::upsample_nearest1d_out::redispatch(dispatchKeySet, self, output_size, scales, out); + } + + // aten::_upsample_nearest_exact1d.out(Tensor self, SymInt[1] output_size, float? scales=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _upsample_nearest_exact1d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef output_size, ::std::optional scales=::std::nullopt) { + return at::_ops::_upsample_nearest_exact1d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), scales, out); + } + + // aten::_upsample_nearest_exact1d.out(Tensor self, SymInt[1] output_size, float? scales=None, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & _upsample_nearest_exact1d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, ::std::optional scales, at::Tensor & out) { + return at::_ops::_upsample_nearest_exact1d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), scales, out); + } + + // aten::_upsample_nearest_exact1d.out(Tensor self, SymInt[1] output_size, float? scales=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _upsample_nearest_exact1d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef output_size, ::std::optional scales=::std::nullopt) { + return at::_ops::_upsample_nearest_exact1d_out::redispatch(dispatchKeySet, self, output_size, scales, out); + } + + // aten::_upsample_nearest_exact1d.out(Tensor self, SymInt[1] output_size, float? scales=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _upsample_nearest_exact1d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, ::std::optional scales, at::Tensor & out) { + return at::_ops::_upsample_nearest_exact1d_out::redispatch(dispatchKeySet, self, output_size, scales, out); + } + + // aten::upsample_nearest1d(Tensor self, SymInt[1] output_size, float? scales=None) -> Tensor + inline at::Tensor upsample_nearest1d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, ::std::optional scales=::std::nullopt) { + return at::_ops::upsample_nearest1d::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), scales); + } + + // aten::upsample_nearest1d(Tensor self, SymInt[1] output_size, float? 
scales=None) -> Tensor + inline at::Tensor upsample_nearest1d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, ::std::optional scales=::std::nullopt) { + return at::_ops::upsample_nearest1d::redispatch(dispatchKeySet, self, output_size, scales); + } + + // aten::_upsample_nearest_exact1d(Tensor self, SymInt[1] output_size, float? scales=None) -> Tensor + inline at::Tensor _upsample_nearest_exact1d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, ::std::optional scales=::std::nullopt) { + return at::_ops::_upsample_nearest_exact1d::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), scales); + } + + // aten::_upsample_nearest_exact1d(Tensor self, SymInt[1] output_size, float? scales=None) -> Tensor + inline at::Tensor _upsample_nearest_exact1d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, ::std::optional scales=::std::nullopt) { + return at::_ops::_upsample_nearest_exact1d::redispatch(dispatchKeySet, self, output_size, scales); + } + + // aten::upsample_nearest1d_backward.grad_input(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, float? scales=None, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & upsample_nearest1d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, ::std::optional scales=::std::nullopt) { + return at::_ops::upsample_nearest1d_backward_grad_input::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), scales, grad_input); + } + + // aten::upsample_nearest1d_backward.grad_input(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, float? scales=None, *, Tensor(a!) grad_input) -> Tensor(a!) 
+ inline at::Tensor & upsample_nearest1d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, ::std::optional scales, at::Tensor & grad_input) { + return at::_ops::upsample_nearest1d_backward_grad_input::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), scales, grad_input); + } + + // aten::upsample_nearest1d_backward.grad_input(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, float? scales=None, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & upsample_nearest1d_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, ::std::optional scales=::std::nullopt) { + return at::_ops::upsample_nearest1d_backward_grad_input::redispatch(dispatchKeySet, grad_output, output_size, input_size, scales, grad_input); + } + + // aten::upsample_nearest1d_backward.grad_input(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, float? scales=None, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & upsample_nearest1d_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, ::std::optional scales, at::Tensor & grad_input) { + return at::_ops::upsample_nearest1d_backward_grad_input::redispatch(dispatchKeySet, grad_output, output_size, input_size, scales, grad_input); + } + + // aten::_upsample_nearest_exact1d_backward.grad_input(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, float? scales=None, *, Tensor(a!) grad_input) -> Tensor(a!) 
+ inline at::Tensor & _upsample_nearest_exact1d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, ::std::optional scales=::std::nullopt) { + return at::_ops::_upsample_nearest_exact1d_backward_grad_input::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), scales, grad_input); + } + + // aten::_upsample_nearest_exact1d_backward.grad_input(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, float? scales=None, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & _upsample_nearest_exact1d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, ::std::optional scales, at::Tensor & grad_input) { + return at::_ops::_upsample_nearest_exact1d_backward_grad_input::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), scales, grad_input); + } + + // aten::_upsample_nearest_exact1d_backward.grad_input(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, float? scales=None, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & _upsample_nearest_exact1d_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, ::std::optional scales=::std::nullopt) { + return at::_ops::_upsample_nearest_exact1d_backward_grad_input::redispatch(dispatchKeySet, grad_output, output_size, input_size, scales, grad_input); + } + + // aten::_upsample_nearest_exact1d_backward.grad_input(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, float? scales=None, *, Tensor(a!) grad_input) -> Tensor(a!) 
+ inline at::Tensor & _upsample_nearest_exact1d_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, ::std::optional scales, at::Tensor & grad_input) { + return at::_ops::_upsample_nearest_exact1d_backward_grad_input::redispatch(dispatchKeySet, grad_output, output_size, input_size, scales, grad_input); + } + + // aten::upsample_nearest1d_backward(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, float? scales=None) -> Tensor + inline at::Tensor upsample_nearest1d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, ::std::optional scales=::std::nullopt) { + return at::_ops::upsample_nearest1d_backward::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), scales); + } + + // aten::upsample_nearest1d_backward(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, float? scales=None) -> Tensor + inline at::Tensor upsample_nearest1d_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, ::std::optional scales=::std::nullopt) { + return at::_ops::upsample_nearest1d_backward::redispatch(dispatchKeySet, grad_output, output_size, input_size, scales); + } + + // aten::_upsample_nearest_exact1d_backward(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, float? 
scales=None) -> Tensor + inline at::Tensor _upsample_nearest_exact1d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, ::std::optional scales=::std::nullopt) { + return at::_ops::_upsample_nearest_exact1d_backward::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), scales); + } + + // aten::_upsample_nearest_exact1d_backward(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, float? scales=None) -> Tensor + inline at::Tensor _upsample_nearest_exact1d_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, ::std::optional scales=::std::nullopt) { + return at::_ops::_upsample_nearest_exact1d_backward::redispatch(dispatchKeySet, grad_output, output_size, input_size, scales); + } + + // aten::upsample_nearest2d.out(Tensor self, SymInt[2] output_size, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & upsample_nearest2d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef output_size, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::upsample_nearest2d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), scales_h, scales_w, out); + } + + // aten::upsample_nearest2d.out(Tensor self, SymInt[2] output_size, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & upsample_nearest2d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & out) { + return at::_ops::upsample_nearest2d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), scales_h, scales_w, out); + } + + // aten::upsample_nearest2d.out(Tensor self, SymInt[2] output_size, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & upsample_nearest2d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef output_size, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::upsample_nearest2d_out::redispatch(dispatchKeySet, self, output_size, scales_h, scales_w, out); + } + + // aten::upsample_nearest2d.out(Tensor self, SymInt[2] output_size, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & upsample_nearest2d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & out) { + return at::_ops::upsample_nearest2d_out::redispatch(dispatchKeySet, self, output_size, scales_h, scales_w, out); + } + + // aten::_upsample_nearest_exact2d.out(Tensor self, SymInt[2] output_size, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _upsample_nearest_exact2d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef output_size, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::_upsample_nearest_exact2d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), scales_h, scales_w, out); + } + + // aten::_upsample_nearest_exact2d.out(Tensor self, SymInt[2] output_size, float? 
scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _upsample_nearest_exact2d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & out) { + return at::_ops::_upsample_nearest_exact2d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), scales_h, scales_w, out); + } + + // aten::_upsample_nearest_exact2d.out(Tensor self, SymInt[2] output_size, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _upsample_nearest_exact2d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef output_size, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::_upsample_nearest_exact2d_out::redispatch(dispatchKeySet, self, output_size, scales_h, scales_w, out); + } + + // aten::_upsample_nearest_exact2d.out(Tensor self, SymInt[2] output_size, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _upsample_nearest_exact2d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & out) { + return at::_ops::_upsample_nearest_exact2d_out::redispatch(dispatchKeySet, self, output_size, scales_h, scales_w, out); + } + + // aten::upsample_nearest2d(Tensor self, SymInt[2] output_size, float? scales_h=None, float? 
scales_w=None) -> Tensor + inline at::Tensor upsample_nearest2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::upsample_nearest2d::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), scales_h, scales_w); + } + + // aten::upsample_nearest2d(Tensor self, SymInt[2] output_size, float? scales_h=None, float? scales_w=None) -> Tensor + inline at::Tensor upsample_nearest2d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::upsample_nearest2d::redispatch(dispatchKeySet, self, output_size, scales_h, scales_w); + } + + // aten::_upsample_nearest_exact2d(Tensor self, SymInt[2] output_size, float? scales_h=None, float? scales_w=None) -> Tensor + inline at::Tensor _upsample_nearest_exact2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::_upsample_nearest_exact2d::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), scales_h, scales_w); + } + + // aten::_upsample_nearest_exact2d(Tensor self, SymInt[2] output_size, float? scales_h=None, float? scales_w=None) -> Tensor + inline at::Tensor _upsample_nearest_exact2d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::_upsample_nearest_exact2d::redispatch(dispatchKeySet, self, output_size, scales_h, scales_w); + } + + // aten::upsample_nearest2d_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) 
+ inline at::Tensor & upsample_nearest2d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::upsample_nearest2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), scales_h, scales_w, grad_input); + } + + // aten::upsample_nearest2d_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & upsample_nearest2d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & grad_input) { + return at::_ops::upsample_nearest2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), scales_h, scales_w, grad_input); + } + + // aten::upsample_nearest2d_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & upsample_nearest2d_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::upsample_nearest2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, output_size, input_size, scales_h, scales_w, grad_input); + } + + // aten::upsample_nearest2d_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, float? scales_h=None, float? 
scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & upsample_nearest2d_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & grad_input) { + return at::_ops::upsample_nearest2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, output_size, input_size, scales_h, scales_w, grad_input); + } + + // aten::_upsample_nearest_exact2d_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & _upsample_nearest_exact2d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::_upsample_nearest_exact2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), scales_h, scales_w, grad_input); + } + + // aten::_upsample_nearest_exact2d_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) 
+ inline at::Tensor & _upsample_nearest_exact2d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & grad_input) { + return at::_ops::_upsample_nearest_exact2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), scales_h, scales_w, grad_input); + } + + // aten::_upsample_nearest_exact2d_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & _upsample_nearest_exact2d_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::_upsample_nearest_exact2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, output_size, input_size, scales_h, scales_w, grad_input); + } + + // aten::_upsample_nearest_exact2d_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & _upsample_nearest_exact2d_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & grad_input) { + return at::_ops::_upsample_nearest_exact2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, output_size, input_size, scales_h, scales_w, grad_input); + } + + // aten::upsample_nearest2d_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, float? scales_h=None, float? 
scales_w=None) -> Tensor + inline at::Tensor upsample_nearest2d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::upsample_nearest2d_backward::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), scales_h, scales_w); + } + + // aten::upsample_nearest2d_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, float? scales_h=None, float? scales_w=None) -> Tensor + inline at::Tensor upsample_nearest2d_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::upsample_nearest2d_backward::redispatch(dispatchKeySet, grad_output, output_size, input_size, scales_h, scales_w); + } + + // aten::_upsample_nearest_exact2d_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, float? scales_h=None, float? scales_w=None) -> Tensor + inline at::Tensor _upsample_nearest_exact2d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::_upsample_nearest_exact2d_backward::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), scales_h, scales_w); + } + + // aten::_upsample_nearest_exact2d_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, float? scales_h=None, float? 
scales_w=None) -> Tensor + inline at::Tensor _upsample_nearest_exact2d_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::_upsample_nearest_exact2d_backward::redispatch(dispatchKeySet, grad_output, output_size, input_size, scales_h, scales_w); + } + + // aten::upsample_nearest3d.out(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & upsample_nearest3d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef output_size, ::std::optional scales_d=::std::nullopt, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::upsample_nearest3d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), scales_d, scales_h, scales_w, out); + } + + // aten::upsample_nearest3d.out(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & upsample_nearest3d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, ::std::optional scales_d, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & out) { + return at::_ops::upsample_nearest3d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), scales_d, scales_h, scales_w, out); + } + + // aten::upsample_nearest3d.out(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & upsample_nearest3d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef output_size, ::std::optional scales_d=::std::nullopt, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::upsample_nearest3d_out::redispatch(dispatchKeySet, self, output_size, scales_d, scales_h, scales_w, out); + } + + // aten::upsample_nearest3d.out(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & upsample_nearest3d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, ::std::optional scales_d, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & out) { + return at::_ops::upsample_nearest3d_out::redispatch(dispatchKeySet, self, output_size, scales_d, scales_h, scales_w, out); + } + + // aten::_upsample_nearest_exact3d.out(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _upsample_nearest_exact3d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef output_size, ::std::optional scales_d=::std::nullopt, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::_upsample_nearest_exact3d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), scales_d, scales_h, scales_w, out); + } + + // aten::_upsample_nearest_exact3d.out(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & _upsample_nearest_exact3d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, ::std::optional scales_d, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & out) { + return at::_ops::_upsample_nearest_exact3d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), scales_d, scales_h, scales_w, out); + } + + // aten::_upsample_nearest_exact3d.out(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _upsample_nearest_exact3d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef output_size, ::std::optional scales_d=::std::nullopt, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::_upsample_nearest_exact3d_out::redispatch(dispatchKeySet, self, output_size, scales_d, scales_h, scales_w, out); + } + + // aten::_upsample_nearest_exact3d.out(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _upsample_nearest_exact3d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, ::std::optional scales_d, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & out) { + return at::_ops::_upsample_nearest_exact3d_out::redispatch(dispatchKeySet, self, output_size, scales_d, scales_h, scales_w, out); + } + + // aten::upsample_nearest3d(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? 
scales_w=None) -> Tensor + inline at::Tensor upsample_nearest3d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, ::std::optional scales_d=::std::nullopt, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::upsample_nearest3d::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), scales_d, scales_h, scales_w); + } + + // aten::upsample_nearest3d(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor + inline at::Tensor upsample_nearest3d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, ::std::optional scales_d=::std::nullopt, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::upsample_nearest3d::redispatch(dispatchKeySet, self, output_size, scales_d, scales_h, scales_w); + } + + // aten::_upsample_nearest_exact3d(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor + inline at::Tensor _upsample_nearest_exact3d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, ::std::optional scales_d=::std::nullopt, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::_upsample_nearest_exact3d::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), scales_d, scales_h, scales_w); + } + + // aten::_upsample_nearest_exact3d(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? 
scales_w=None) -> Tensor + inline at::Tensor _upsample_nearest_exact3d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, ::std::optional scales_d=::std::nullopt, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::_upsample_nearest_exact3d::redispatch(dispatchKeySet, self, output_size, scales_d, scales_h, scales_w); + } + + // aten::upsample_nearest3d_backward.grad_input(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & upsample_nearest3d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, ::std::optional scales_d=::std::nullopt, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::upsample_nearest3d_backward_grad_input::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), scales_d, scales_h, scales_w, grad_input); + } + + // aten::upsample_nearest3d_backward.grad_input(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) 
+ inline at::Tensor & upsample_nearest3d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, ::std::optional scales_d, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & grad_input) { + return at::_ops::upsample_nearest3d_backward_grad_input::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), scales_d, scales_h, scales_w, grad_input); + } + + // aten::upsample_nearest3d_backward.grad_input(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & upsample_nearest3d_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, ::std::optional scales_d=::std::nullopt, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::upsample_nearest3d_backward_grad_input::redispatch(dispatchKeySet, grad_output, output_size, input_size, scales_d, scales_h, scales_w, grad_input); + } + + // aten::upsample_nearest3d_backward.grad_input(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) 
+ inline at::Tensor & upsample_nearest3d_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, ::std::optional scales_d, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & grad_input) { + return at::_ops::upsample_nearest3d_backward_grad_input::redispatch(dispatchKeySet, grad_output, output_size, input_size, scales_d, scales_h, scales_w, grad_input); + } + + // aten::_upsample_nearest_exact3d_backward.grad_input(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & _upsample_nearest_exact3d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, ::std::optional scales_d=::std::nullopt, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::_upsample_nearest_exact3d_backward_grad_input::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), scales_d, scales_h, scales_w, grad_input); + } + + // aten::_upsample_nearest_exact3d_backward.grad_input(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) 
+ inline at::Tensor & _upsample_nearest_exact3d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, ::std::optional scales_d, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & grad_input) { + return at::_ops::_upsample_nearest_exact3d_backward_grad_input::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), scales_d, scales_h, scales_w, grad_input); + } + + // aten::_upsample_nearest_exact3d_backward.grad_input(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & _upsample_nearest_exact3d_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, ::std::optional scales_d=::std::nullopt, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::_upsample_nearest_exact3d_backward_grad_input::redispatch(dispatchKeySet, grad_output, output_size, input_size, scales_d, scales_h, scales_w, grad_input); + } + + // aten::_upsample_nearest_exact3d_backward.grad_input(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) 
+ inline at::Tensor & _upsample_nearest_exact3d_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, ::std::optional scales_d, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & grad_input) { + return at::_ops::_upsample_nearest_exact3d_backward_grad_input::redispatch(dispatchKeySet, grad_output, output_size, input_size, scales_d, scales_h, scales_w, grad_input); + } + + // aten::upsample_nearest3d_backward(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor + inline at::Tensor upsample_nearest3d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, ::std::optional scales_d=::std::nullopt, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::upsample_nearest3d_backward::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), scales_d, scales_h, scales_w); + } + + // aten::upsample_nearest3d_backward(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor + inline at::Tensor upsample_nearest3d_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, ::std::optional scales_d=::std::nullopt, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::upsample_nearest3d_backward::redispatch(dispatchKeySet, grad_output, output_size, input_size, scales_d, scales_h, scales_w); + } + + // aten::_upsample_nearest_exact3d_backward(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, float? scales_d=None, float? scales_h=None, float? 
scales_w=None) -> Tensor + inline at::Tensor _upsample_nearest_exact3d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, ::std::optional scales_d=::std::nullopt, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::_upsample_nearest_exact3d_backward::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), scales_d, scales_h, scales_w); + } + + // aten::_upsample_nearest_exact3d_backward(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor + inline at::Tensor _upsample_nearest_exact3d_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, ::std::optional scales_d=::std::nullopt, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::_upsample_nearest_exact3d_backward::redispatch(dispatchKeySet, grad_output, output_size, input_size, scales_d, scales_h, scales_w); + } + + // aten::sigmoid_backward.grad_input(Tensor grad_output, Tensor output, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & sigmoid_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & output) { + return at::_ops::sigmoid_backward_grad_input::redispatch(dispatchKeySet, grad_output, output, grad_input); + } + + // aten::sigmoid_backward.grad_input(Tensor grad_output, Tensor output, *, Tensor(a!) grad_input) -> Tensor(a!) 
+ inline at::Tensor & sigmoid_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & output, at::Tensor & grad_input) { + return at::_ops::sigmoid_backward_grad_input::redispatch(dispatchKeySet, grad_output, output, grad_input); + } + + // aten::sigmoid_backward(Tensor grad_output, Tensor output) -> Tensor + inline at::Tensor sigmoid_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & output) { + return at::_ops::sigmoid_backward::redispatch(dispatchKeySet, grad_output, output); + } + + // aten::logit_backward.grad_input(Tensor grad_output, Tensor self, float? eps=None, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & logit_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, ::std::optional eps=::std::nullopt) { + return at::_ops::logit_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, eps, grad_input); + } + + // aten::logit_backward.grad_input(Tensor grad_output, Tensor self, float? eps=None, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & logit_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, ::std::optional eps, at::Tensor & grad_input) { + return at::_ops::logit_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, eps, grad_input); + } + + // aten::logit_backward(Tensor grad_output, Tensor self, float? eps=None) -> Tensor + inline at::Tensor logit_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, ::std::optional eps=::std::nullopt) { + return at::_ops::logit_backward::redispatch(dispatchKeySet, grad_output, self, eps); + } + + // aten::tanh_backward.grad_input(Tensor grad_output, Tensor output, *, Tensor(a!) grad_input) -> Tensor(a!) 
+ inline at::Tensor & tanh_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & output) { + return at::_ops::tanh_backward_grad_input::redispatch(dispatchKeySet, grad_output, output, grad_input); + } + + // aten::tanh_backward.grad_input(Tensor grad_output, Tensor output, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & tanh_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & output, at::Tensor & grad_input) { + return at::_ops::tanh_backward_grad_input::redispatch(dispatchKeySet, grad_output, output, grad_input); + } + + // aten::tanh_backward(Tensor grad_output, Tensor output) -> Tensor + inline at::Tensor tanh_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & output) { + return at::_ops::tanh_backward::redispatch(dispatchKeySet, grad_output, output); + } + + // aten::slow_conv_transpose2d.out(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] output_padding=0, SymInt[2] dilation=1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & slow_conv_transpose2d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const ::std::optional & bias={}, at::IntArrayRef stride=1, at::IntArrayRef padding=0, at::IntArrayRef output_padding=0, at::IntArrayRef dilation=1) { + return at::_ops::slow_conv_transpose2d_out::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(output_padding), c10::fromIntArrayRefSlow(dilation), out); + } + + // aten::slow_conv_transpose2d.out(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? 
bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] output_padding=0, SymInt[2] dilation=1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & slow_conv_transpose2d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const ::std::optional & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef output_padding, at::IntArrayRef dilation, at::Tensor & out) { + return at::_ops::slow_conv_transpose2d_out::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(output_padding), c10::fromIntArrayRefSlow(dilation), out); + } + + // aten::slow_conv_transpose2d.out(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] output_padding=0, SymInt[2] dilation=1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & slow_conv_transpose2d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias={}, c10::SymIntArrayRef stride=c10::SymInt(1), c10::SymIntArrayRef padding=c10::SymInt(0), c10::SymIntArrayRef output_padding=c10::SymInt(0), c10::SymIntArrayRef dilation=c10::SymInt(1)) { + return at::_ops::slow_conv_transpose2d_out::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding, output_padding, dilation, out); + } + + // aten::slow_conv_transpose2d.out(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] output_padding=0, SymInt[2] dilation=1, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & slow_conv_transpose2d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef dilation, at::Tensor & out) { + return at::_ops::slow_conv_transpose2d_out::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding, output_padding, dilation, out); + } + + // aten::slow_conv_transpose2d(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] output_padding=0, SymInt[2] dilation=1) -> Tensor + inline at::Tensor slow_conv_transpose2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const ::std::optional & bias={}, at::IntArrayRef stride=1, at::IntArrayRef padding=0, at::IntArrayRef output_padding=0, at::IntArrayRef dilation=1) { + return at::_ops::slow_conv_transpose2d::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(output_padding), c10::fromIntArrayRefSlow(dilation)); + } + + // aten::slow_conv_transpose2d(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? 
bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] output_padding=0, SymInt[2] dilation=1) -> Tensor + inline at::Tensor slow_conv_transpose2d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias={}, c10::SymIntArrayRef stride=c10::SymInt(1), c10::SymIntArrayRef padding=c10::SymInt(0), c10::SymIntArrayRef output_padding=c10::SymInt(0), c10::SymIntArrayRef dilation=c10::SymInt(1)) { + return at::_ops::slow_conv_transpose2d::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding, output_padding, dilation); + } + + // aten::slow_conv_transpose3d.out(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] output_padding=0, SymInt[3] dilation=1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & slow_conv_transpose3d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const ::std::optional & bias={}, at::IntArrayRef stride=1, at::IntArrayRef padding=0, at::IntArrayRef output_padding=0, at::IntArrayRef dilation=1) { + return at::_ops::slow_conv_transpose3d_out::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(output_padding), c10::fromIntArrayRefSlow(dilation), out); + } + + // aten::slow_conv_transpose3d.out(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] output_padding=0, SymInt[3] dilation=1, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & slow_conv_transpose3d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const ::std::optional & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef output_padding, at::IntArrayRef dilation, at::Tensor & out) { + return at::_ops::slow_conv_transpose3d_out::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(output_padding), c10::fromIntArrayRefSlow(dilation), out); + } + + // aten::slow_conv_transpose3d.out(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] output_padding=0, SymInt[3] dilation=1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & slow_conv_transpose3d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias={}, c10::SymIntArrayRef stride=c10::SymInt(1), c10::SymIntArrayRef padding=c10::SymInt(0), c10::SymIntArrayRef output_padding=c10::SymInt(0), c10::SymIntArrayRef dilation=c10::SymInt(1)) { + return at::_ops::slow_conv_transpose3d_out::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding, output_padding, dilation, out); + } + + // aten::slow_conv_transpose3d.out(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] output_padding=0, SymInt[3] dilation=1, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & slow_conv_transpose3d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef dilation, at::Tensor & out) { + return at::_ops::slow_conv_transpose3d_out::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding, output_padding, dilation, out); + } + + // aten::slow_conv_transpose3d(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] output_padding=0, SymInt[3] dilation=1) -> Tensor + inline at::Tensor slow_conv_transpose3d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const ::std::optional & bias={}, at::IntArrayRef stride=1, at::IntArrayRef padding=0, at::IntArrayRef output_padding=0, at::IntArrayRef dilation=1) { + return at::_ops::slow_conv_transpose3d::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(output_padding), c10::fromIntArrayRefSlow(dilation)); + } + + // aten::slow_conv_transpose3d(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? 
bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] output_padding=0, SymInt[3] dilation=1) -> Tensor + inline at::Tensor slow_conv_transpose3d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias={}, c10::SymIntArrayRef stride=c10::SymInt(1), c10::SymIntArrayRef padding=c10::SymInt(0), c10::SymIntArrayRef output_padding=c10::SymInt(0), c10::SymIntArrayRef dilation=c10::SymInt(1)) { + return at::_ops::slow_conv_transpose3d::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding, output_padding, dilation); + } + + // aten::thnn_conv2d.out(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & thnn_conv2d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const ::std::optional & bias={}, at::IntArrayRef stride=1, at::IntArrayRef padding=0) { + return at::_ops::thnn_conv2d_out::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), out); + } + + // aten::thnn_conv2d.out(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & thnn_conv2d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const ::std::optional & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::Tensor & out) { + return at::_ops::thnn_conv2d_out::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), out); + } + + // aten::thnn_conv2d.out(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & thnn_conv2d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias={}, c10::SymIntArrayRef stride=c10::SymInt(1), c10::SymIntArrayRef padding=c10::SymInt(0)) { + return at::_ops::thnn_conv2d_out::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding, out); + } + + // aten::thnn_conv2d.out(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & thnn_conv2d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, at::Tensor & out) { + return at::_ops::thnn_conv2d_out::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding, out); + } + + // aten::thnn_conv2d(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? 
bias=None, SymInt[2] stride=1, SymInt[2] padding=0) -> Tensor + inline at::Tensor thnn_conv2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const ::std::optional & bias={}, at::IntArrayRef stride=1, at::IntArrayRef padding=0) { + return at::_ops::thnn_conv2d::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding)); + } + + // aten::thnn_conv2d(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0) -> Tensor + inline at::Tensor thnn_conv2d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias={}, c10::SymIntArrayRef stride=c10::SymInt(1), c10::SymIntArrayRef padding=c10::SymInt(0)) { + return at::_ops::thnn_conv2d::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding); + } + + // aten::_slow_conv2d_forward.output(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias, SymInt[2] stride, SymInt[2] padding, *, Tensor(a!) output) -> Tensor(a!) + inline at::Tensor & _slow_conv2d_forward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & output, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const ::std::optional & bias, at::IntArrayRef stride, at::IntArrayRef padding) { + return at::_ops::_slow_conv2d_forward_output::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), output); + } + + // aten::_slow_conv2d_forward.output(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias, SymInt[2] stride, SymInt[2] padding, *, Tensor(a!) output) -> Tensor(a!) 
+ inline at::Tensor & _slow_conv2d_forward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const ::std::optional & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::Tensor & output) { + return at::_ops::_slow_conv2d_forward_output::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), output); + } + + // aten::_slow_conv2d_forward.output(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias, SymInt[2] stride, SymInt[2] padding, *, Tensor(a!) output) -> Tensor(a!) + inline at::Tensor & _slow_conv2d_forward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & output, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding) { + return at::_ops::_slow_conv2d_forward_output::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding, output); + } + + // aten::_slow_conv2d_forward.output(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias, SymInt[2] stride, SymInt[2] padding, *, Tensor(a!) output) -> Tensor(a!) + inline at::Tensor & _slow_conv2d_forward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, at::Tensor & output) { + return at::_ops::_slow_conv2d_forward_output::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding, output); + } + + // aten::_slow_conv2d_forward(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? 
bias, SymInt[2] stride, SymInt[2] padding) -> Tensor + inline at::Tensor _slow_conv2d_forward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const ::std::optional & bias, at::IntArrayRef stride, at::IntArrayRef padding) { + return at::_ops::_slow_conv2d_forward::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding)); + } + + // aten::_slow_conv2d_forward(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias, SymInt[2] stride, SymInt[2] padding) -> Tensor + inline at::Tensor _slow_conv2d_forward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding) { + return at::_ops::_slow_conv2d_forward::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding); + } + + // aten::_slow_conv2d_backward.grad_input(Tensor grad_output, Tensor self, Tensor weight, SymInt[2] kernel_size, SymInt[2] stride, SymInt[2] padding, *, Tensor(a!) grad_input, Tensor(b!) grad_weight, Tensor(c!) 
grad_bias) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple _slow_conv2d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, at::Tensor & grad_weight, at::Tensor & grad_bias, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding) { + return at::_ops::_slow_conv2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, weight, c10::fromIntArrayRefSlow(kernel_size), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), grad_input, grad_weight, grad_bias); + } + + // aten::_slow_conv2d_backward.grad_input(Tensor grad_output, Tensor self, Tensor weight, SymInt[2] kernel_size, SymInt[2] stride, SymInt[2] padding, *, Tensor(a!) grad_input, Tensor(b!) grad_weight, Tensor(c!) grad_bias) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple _slow_conv2d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::Tensor & grad_input, at::Tensor & grad_weight, at::Tensor & grad_bias) { + return at::_ops::_slow_conv2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, weight, c10::fromIntArrayRefSlow(kernel_size), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), grad_input, grad_weight, grad_bias); + } + + // aten::_slow_conv2d_backward.grad_input(Tensor grad_output, Tensor self, Tensor weight, SymInt[2] kernel_size, SymInt[2] stride, SymInt[2] padding, *, Tensor(a!) grad_input, Tensor(b!) grad_weight, Tensor(c!) 
grad_bias) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple _slow_conv2d_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, at::Tensor & grad_weight, at::Tensor & grad_bias, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding) { + return at::_ops::_slow_conv2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, weight, kernel_size, stride, padding, grad_input, grad_weight, grad_bias); + } + + // aten::_slow_conv2d_backward.grad_input(Tensor grad_output, Tensor self, Tensor weight, SymInt[2] kernel_size, SymInt[2] stride, SymInt[2] padding, *, Tensor(a!) grad_input, Tensor(b!) grad_weight, Tensor(c!) grad_bias) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple _slow_conv2d_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, at::Tensor & grad_input, at::Tensor & grad_weight, at::Tensor & grad_bias) { + return at::_ops::_slow_conv2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, weight, kernel_size, stride, padding, grad_input, grad_weight, grad_bias); + } + + // aten::_slow_conv2d_backward.output_mask(Tensor grad_output, Tensor self, Tensor weight, SymInt[2] kernel_size, SymInt[2] stride, SymInt[2] padding, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias) + inline ::std::tuple _slow_conv2d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, ::std::array output_mask) { + return at::_ops::_slow_conv2d_backward_output_mask::redispatch(dispatchKeySet, grad_output, self, weight, 
c10::fromIntArrayRefSlow(kernel_size), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), output_mask); + } + + // aten::_slow_conv2d_backward.output_mask(Tensor grad_output, Tensor self, Tensor weight, SymInt[2] kernel_size, SymInt[2] stride, SymInt[2] padding, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias) + inline ::std::tuple _slow_conv2d_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, ::std::array output_mask) { + return at::_ops::_slow_conv2d_backward_output_mask::redispatch(dispatchKeySet, grad_output, self, weight, kernel_size, stride, padding, output_mask); + } + + // aten::_conv_depthwise2d.out(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias, SymInt[2] stride, SymInt[2] padding, SymInt[2] dilation, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _conv_depthwise2d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const ::std::optional & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation) { + return at::_ops::_conv_depthwise2d_out::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), out); + } + + // aten::_conv_depthwise2d.out(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias, SymInt[2] stride, SymInt[2] padding, SymInt[2] dilation, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & _conv_depthwise2d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const ::std::optional & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, at::Tensor & out) { + return at::_ops::_conv_depthwise2d_out::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), out); + } + + // aten::_conv_depthwise2d.out(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias, SymInt[2] stride, SymInt[2] padding, SymInt[2] dilation, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _conv_depthwise2d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation) { + return at::_ops::_conv_depthwise2d_out::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding, dilation, out); + } + + // aten::_conv_depthwise2d.out(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias, SymInt[2] stride, SymInt[2] padding, SymInt[2] dilation, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _conv_depthwise2d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, at::Tensor & out) { + return at::_ops::_conv_depthwise2d_out::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding, dilation, out); + } + + // aten::_conv_depthwise2d(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? 
bias, SymInt[2] stride, SymInt[2] padding, SymInt[2] dilation) -> Tensor + inline at::Tensor _conv_depthwise2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const ::std::optional & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation) { + return at::_ops::_conv_depthwise2d::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation)); + } + + // aten::_conv_depthwise2d(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias, SymInt[2] stride, SymInt[2] padding, SymInt[2] dilation) -> Tensor + inline at::Tensor _conv_depthwise2d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation) { + return at::_ops::_conv_depthwise2d::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding, dilation); + } + + // aten::conv_depthwise3d(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias, SymInt[3] stride, SymInt[3] padding, SymInt[3] dilation) -> Tensor + inline at::Tensor conv_depthwise3d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const ::std::optional & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation) { + return at::_ops::conv_depthwise3d::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation)); + } + + // aten::conv_depthwise3d(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? 
bias, SymInt[3] stride, SymInt[3] padding, SymInt[3] dilation) -> Tensor + inline at::Tensor conv_depthwise3d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation) { + return at::_ops::conv_depthwise3d::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding, dilation); + } + + // aten::slow_conv3d.out(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & slow_conv3d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const ::std::optional & bias={}, at::IntArrayRef stride=1, at::IntArrayRef padding=0) { + return at::_ops::slow_conv3d_out::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), out); + } + + // aten::slow_conv3d.out(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & slow_conv3d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const ::std::optional & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::Tensor & out) { + return at::_ops::slow_conv3d_out::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), out); + } + + // aten::slow_conv3d.out(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & slow_conv3d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias={}, c10::SymIntArrayRef stride=c10::SymInt(1), c10::SymIntArrayRef padding=c10::SymInt(0)) { + return at::_ops::slow_conv3d_out::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding, out); + } + + // aten::slow_conv3d.out(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & slow_conv3d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, at::Tensor & out) { + return at::_ops::slow_conv3d_out::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding, out); + } + + // aten::slow_conv3d(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0) -> Tensor + inline at::Tensor slow_conv3d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const ::std::optional & bias={}, at::IntArrayRef stride=1, at::IntArrayRef padding=0) { + return at::_ops::slow_conv3d::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding)); + } + + // aten::slow_conv3d(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? 
bias=None, SymInt[3] stride=1, SymInt[3] padding=0) -> Tensor + inline at::Tensor slow_conv3d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias={}, c10::SymIntArrayRef stride=c10::SymInt(1), c10::SymIntArrayRef padding=c10::SymInt(0)) { + return at::_ops::slow_conv3d::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding); + } + + // aten::slow_conv3d_forward.output(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias, SymInt[3] stride, SymInt[3] padding, *, Tensor(a!) output) -> Tensor(a!) + inline at::Tensor & slow_conv3d_forward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & output, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const ::std::optional & bias, at::IntArrayRef stride, at::IntArrayRef padding) { + return at::_ops::slow_conv3d_forward_output::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), output); + } + + // aten::slow_conv3d_forward.output(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias, SymInt[3] stride, SymInt[3] padding, *, Tensor(a!) output) -> Tensor(a!) + inline at::Tensor & slow_conv3d_forward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const ::std::optional & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::Tensor & output) { + return at::_ops::slow_conv3d_forward_output::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), output); + } + + // aten::slow_conv3d_forward.output(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias, SymInt[3] stride, SymInt[3] padding, *, Tensor(a!) output) -> Tensor(a!) 
+ inline at::Tensor & slow_conv3d_forward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & output, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding) { + return at::_ops::slow_conv3d_forward_output::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding, output); + } + + // aten::slow_conv3d_forward.output(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias, SymInt[3] stride, SymInt[3] padding, *, Tensor(a!) output) -> Tensor(a!) + inline at::Tensor & slow_conv3d_forward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, at::Tensor & output) { + return at::_ops::slow_conv3d_forward_output::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding, output); + } + + // aten::slow_conv3d_forward(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias, SymInt[3] stride, SymInt[3] padding) -> Tensor + inline at::Tensor slow_conv3d_forward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const ::std::optional & bias, at::IntArrayRef stride, at::IntArrayRef padding) { + return at::_ops::slow_conv3d_forward::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding)); + } + + // aten::slow_conv3d_forward(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? 
bias, SymInt[3] stride, SymInt[3] padding) -> Tensor + inline at::Tensor slow_conv3d_forward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding) { + return at::_ops::slow_conv3d_forward::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding); + } + + // aten::slow_conv_dilated2d(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] dilation=1) -> Tensor + inline at::Tensor slow_conv_dilated2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const ::std::optional & bias={}, at::IntArrayRef stride=1, at::IntArrayRef padding=0, at::IntArrayRef dilation=1) { + return at::_ops::slow_conv_dilated2d::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation)); + } + + // aten::slow_conv_dilated2d(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] dilation=1) -> Tensor + inline at::Tensor slow_conv_dilated2d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias={}, c10::SymIntArrayRef stride=c10::SymInt(1), c10::SymIntArrayRef padding=c10::SymInt(0), c10::SymIntArrayRef dilation=c10::SymInt(1)) { + return at::_ops::slow_conv_dilated2d::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding, dilation); + } + + // aten::slow_conv_dilated3d(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? 
bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] dilation=1) -> Tensor + inline at::Tensor slow_conv_dilated3d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const ::std::optional & bias={}, at::IntArrayRef stride=1, at::IntArrayRef padding=0, at::IntArrayRef dilation=1) { + return at::_ops::slow_conv_dilated3d::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation)); + } + + // aten::slow_conv_dilated3d(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] dilation=1) -> Tensor + inline at::Tensor slow_conv_dilated3d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias={}, c10::SymIntArrayRef stride=c10::SymInt(1), c10::SymIntArrayRef padding=c10::SymInt(0), c10::SymIntArrayRef dilation=c10::SymInt(1)) { + return at::_ops::slow_conv_dilated3d::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding, dilation); + } + + // aten::col2im.out(Tensor self, SymInt[2] output_size, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & col2im_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef output_size, at::IntArrayRef kernel_size, at::IntArrayRef dilation, at::IntArrayRef padding, at::IntArrayRef stride) { + return at::_ops::col2im_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), kernel_size, dilation, padding, stride, out); + } + + // aten::col2im.out(Tensor self, SymInt[2] output_size, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & col2im_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, at::IntArrayRef kernel_size, at::IntArrayRef dilation, at::IntArrayRef padding, at::IntArrayRef stride, at::Tensor & out) { + return at::_ops::col2im_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), kernel_size, dilation, padding, stride, out); + } + + // aten::col2im.out(Tensor self, SymInt[2] output_size, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & col2im_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef output_size, at::IntArrayRef kernel_size, at::IntArrayRef dilation, at::IntArrayRef padding, at::IntArrayRef stride) { + return at::_ops::col2im_out::redispatch(dispatchKeySet, self, output_size, kernel_size, dilation, padding, stride, out); + } + + // aten::col2im.out(Tensor self, SymInt[2] output_size, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & col2im_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, at::IntArrayRef kernel_size, at::IntArrayRef dilation, at::IntArrayRef padding, at::IntArrayRef stride, at::Tensor & out) { + return at::_ops::col2im_out::redispatch(dispatchKeySet, self, output_size, kernel_size, dilation, padding, stride, out); + } + + // aten::col2im(Tensor self, SymInt[2] output_size, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride) -> Tensor + inline at::Tensor col2im(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, at::IntArrayRef kernel_size, at::IntArrayRef dilation, at::IntArrayRef padding, at::IntArrayRef stride) { + return at::_ops::col2im::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), kernel_size, dilation, padding, stride); + } + + // aten::col2im(Tensor self, SymInt[2] output_size, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride) -> Tensor + inline at::Tensor col2im_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, at::IntArrayRef kernel_size, at::IntArrayRef dilation, at::IntArrayRef padding, at::IntArrayRef stride) { + return at::_ops::col2im::redispatch(dispatchKeySet, self, output_size, kernel_size, dilation, padding, stride); + } + + // aten::column_stack(Tensor[] tensors) -> Tensor + inline at::Tensor column_stack(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors) { + return at::_ops::column_stack::redispatch(dispatchKeySet, tensors); + } + + // aten::column_stack.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & column_stack_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::TensorList tensors) { + return at::_ops::column_stack_out::redispatch(dispatchKeySet, tensors, out); + } + + // aten::column_stack.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & column_stack_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors, at::Tensor & out) { + return at::_ops::column_stack_out::redispatch(dispatchKeySet, tensors, out); + } + + // aten::im2col.out(Tensor self, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & im2col_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef dilation, at::IntArrayRef padding, at::IntArrayRef stride) { + return at::_ops::im2col_out::redispatch(dispatchKeySet, self, kernel_size, dilation, padding, stride, out); + } + + // aten::im2col.out(Tensor self, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & im2col_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef dilation, at::IntArrayRef padding, at::IntArrayRef stride, at::Tensor & out) { + return at::_ops::im2col_out::redispatch(dispatchKeySet, self, kernel_size, dilation, padding, stride, out); + } + + // aten::im2col(Tensor self, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride) -> Tensor + inline at::Tensor im2col(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef dilation, at::IntArrayRef padding, at::IntArrayRef stride) { + return at::_ops::im2col::redispatch(dispatchKeySet, self, kernel_size, dilation, padding, stride); + } + + // aten::isfinite(Tensor self) -> Tensor + inline at::Tensor isfinite(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::isfinite::redispatch(dispatchKeySet, self); + } + + // aten::isinf(Tensor self) -> Tensor + inline at::Tensor isinf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::isinf::redispatch(dispatchKeySet, self); + } + + // aten::record_stream(Tensor(a!) 
self, Stream s) -> () + inline void record_stream(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, at::Stream s) { + return at::_ops::record_stream::redispatch(dispatchKeySet, self, s); + } + + // aten::isposinf(Tensor self) -> Tensor + inline at::Tensor isposinf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::isposinf::redispatch(dispatchKeySet, self); + } + + // aten::isposinf.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & isposinf_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::isposinf_out::redispatch(dispatchKeySet, self, out); + } + + // aten::isposinf.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & isposinf_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::isposinf_out::redispatch(dispatchKeySet, self, out); + } + + // aten::isneginf(Tensor self) -> Tensor + inline at::Tensor isneginf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::isneginf::redispatch(dispatchKeySet, self); + } + + // aten::isneginf.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & isneginf_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::isneginf_out::redispatch(dispatchKeySet, self, out); + } + + // aten::isneginf.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & isneginf_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::isneginf_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_add_batch_dim(Tensor self, int batch_dim, int level) -> Tensor + inline at::Tensor _add_batch_dim(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t batch_dim, int64_t level) { + return at::_ops::_add_batch_dim::redispatch(dispatchKeySet, self, batch_dim, level); + } + + // aten::_remove_batch_dim(Tensor self, int level, SymInt batch_size, int out_dim) -> Tensor + inline at::Tensor _remove_batch_dim(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t level, int64_t batch_size, int64_t out_dim) { + return at::_ops::_remove_batch_dim::redispatch(dispatchKeySet, self, level, batch_size, out_dim); + } + + // aten::_remove_batch_dim(Tensor self, int level, SymInt batch_size, int out_dim) -> Tensor + inline at::Tensor _remove_batch_dim_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t level, c10::SymInt batch_size, int64_t out_dim) { + return at::_ops::_remove_batch_dim::redispatch(dispatchKeySet, self, level, batch_size, out_dim); + } + + // aten::special_entr(Tensor self) -> Tensor + inline at::Tensor special_entr(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::special_entr::redispatch(dispatchKeySet, self); + } + + // aten::special_entr.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_entr_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::special_entr_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_entr.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & special_entr_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::special_entr_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_ndtri(Tensor self) -> Tensor + inline at::Tensor special_ndtri(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::special_ndtri::redispatch(dispatchKeySet, self); + } + + // aten::special_ndtri.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_ndtri_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::special_ndtri_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_ndtri.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_ndtri_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::special_ndtri_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_log_ndtr(Tensor self) -> Tensor + inline at::Tensor special_log_ndtr(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::special_log_ndtr::redispatch(dispatchKeySet, self); + } + + // aten::special_log_ndtr.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_log_ndtr_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::special_log_ndtr_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_log_ndtr.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & special_log_ndtr_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::special_log_ndtr_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_expm1(Tensor self) -> Tensor + inline at::Tensor special_expm1(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::special_expm1::redispatch(dispatchKeySet, self); + } + + // aten::special_expm1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_expm1_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::special_expm1_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_expm1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_expm1_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::special_expm1_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_exp2(Tensor self) -> Tensor + inline at::Tensor special_exp2(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::special_exp2::redispatch(dispatchKeySet, self); + } + + // aten::special_exp2.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_exp2_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::special_exp2_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_exp2.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & special_exp2_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::special_exp2_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_psi(Tensor self) -> Tensor + inline at::Tensor special_psi(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::special_psi::redispatch(dispatchKeySet, self); + } + + // aten::special_psi.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_psi_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::special_psi_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_psi.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_psi_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::special_psi_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_digamma(Tensor self) -> Tensor + inline at::Tensor special_digamma(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::special_digamma::redispatch(dispatchKeySet, self); + } + + // aten::special_digamma.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_digamma_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::special_digamma_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_digamma.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & special_digamma_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::special_digamma_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_gammaln(Tensor self) -> Tensor + inline at::Tensor special_gammaln(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::special_gammaln::redispatch(dispatchKeySet, self); + } + + // aten::special_gammaln.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_gammaln_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::special_gammaln_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_gammaln.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_gammaln_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::special_gammaln_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_erf(Tensor self) -> Tensor + inline at::Tensor special_erf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::special_erf::redispatch(dispatchKeySet, self); + } + + // aten::special_erf.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_erf_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::special_erf_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_erf.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & special_erf_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::special_erf_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_erfc(Tensor self) -> Tensor + inline at::Tensor special_erfc(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::special_erfc::redispatch(dispatchKeySet, self); + } + + // aten::special_erfc.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_erfc_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::special_erfc_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_erfc.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_erfc_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::special_erfc_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_erfcx(Tensor self) -> Tensor + inline at::Tensor special_erfcx(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::special_erfcx::redispatch(dispatchKeySet, self); + } + + // aten::special_erfcx.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_erfcx_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::special_erfcx_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_erfcx.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & special_erfcx_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::special_erfcx_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_erfinv(Tensor self) -> Tensor + inline at::Tensor special_erfinv(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::special_erfinv::redispatch(dispatchKeySet, self); + } + + // aten::special_erfinv.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_erfinv_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::special_erfinv_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_erfinv.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_erfinv_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::special_erfinv_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_ndtr(Tensor self) -> Tensor + inline at::Tensor special_ndtr(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::special_ndtr::redispatch(dispatchKeySet, self); + } + + // aten::special_ndtr.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_ndtr_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::special_ndtr_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_ndtr.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & special_ndtr_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::special_ndtr_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_xlog1py(Tensor self, Tensor other) -> Tensor + inline at::Tensor special_xlog1py(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::special_xlog1py::redispatch(dispatchKeySet, self, other); + } + + // aten::special_xlog1py.self_scalar(Scalar self, Tensor other) -> Tensor + inline at::Tensor special_xlog1py(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, const at::Tensor & other) { + return at::_ops::special_xlog1py_self_scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::special_xlog1py.other_scalar(Tensor self, Scalar other) -> Tensor + inline at::Tensor special_xlog1py(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::special_xlog1py_other_scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::special_xlog1py.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_xlog1py_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::special_xlog1py_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::special_xlog1py.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_xlog1py_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::special_xlog1py_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::special_xlog1py.self_scalar_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & special_xlog1py_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & self, const at::Tensor & other) { + return at::_ops::special_xlog1py_self_scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::special_xlog1py.self_scalar_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_xlog1py_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::special_xlog1py_self_scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::special_xlog1py.other_scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_xlog1py_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::special_xlog1py_other_scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::special_xlog1py.other_scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & special_xlog1py_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, at::Tensor & out) { + return at::_ops::special_xlog1py_other_scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::special_xlogy(Tensor self, Tensor other) -> Tensor + inline at::Tensor special_xlogy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::special_xlogy::redispatch(dispatchKeySet, self, other); + } + + // aten::special_xlogy.self_scalar(Scalar self, Tensor other) -> Tensor + inline at::Tensor special_xlogy(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, const at::Tensor & other) { + return at::_ops::special_xlogy_self_scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::special_xlogy.other_scalar(Tensor self, Scalar other) -> Tensor + inline at::Tensor special_xlogy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::special_xlogy_other_scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::special_xlogy.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_xlogy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::special_xlogy_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::special_xlogy.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_xlogy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::special_xlogy_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::special_xlogy.self_scalar_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & special_xlogy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & self, const at::Tensor & other) { + return at::_ops::special_xlogy_self_scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::special_xlogy.self_scalar_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_xlogy_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::special_xlogy_self_scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::special_xlogy.other_scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_xlogy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::special_xlogy_other_scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::special_xlogy.other_scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & special_xlogy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, at::Tensor & out) { + return at::_ops::special_xlogy_other_scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::special_zeta(Tensor self, Tensor other) -> Tensor + inline at::Tensor special_zeta(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::special_zeta::redispatch(dispatchKeySet, self, other); + } + + // aten::special_zeta.self_scalar(Scalar self, Tensor other) -> Tensor + inline at::Tensor special_zeta(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, const at::Tensor & other) { + return at::_ops::special_zeta_self_scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::special_zeta.other_scalar(Tensor self, Scalar other) -> Tensor + inline at::Tensor special_zeta(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::special_zeta_other_scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::special_zeta.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_zeta_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::special_zeta_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::special_zeta.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_zeta_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::special_zeta_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::special_zeta.self_scalar_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & special_zeta_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & self, const at::Tensor & other) { + return at::_ops::special_zeta_self_scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::special_zeta.self_scalar_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_zeta_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::special_zeta_self_scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::special_zeta.other_scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_zeta_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::special_zeta_other_scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::special_zeta.other_scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_zeta_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, at::Tensor & out) { + return at::_ops::special_zeta_other_scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::special_i0(Tensor self) -> Tensor + inline at::Tensor special_i0(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::special_i0::redispatch(dispatchKeySet, self); + } + + // aten::special_i0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_i0_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::special_i0_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_i0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & special_i0_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::special_i0_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_i0e(Tensor self) -> Tensor + inline at::Tensor special_i0e(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::special_i0e::redispatch(dispatchKeySet, self); + } + + // aten::special_i0e.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_i0e_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::special_i0e_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_i0e.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_i0e_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::special_i0e_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_i1(Tensor self) -> Tensor + inline at::Tensor special_i1(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::special_i1::redispatch(dispatchKeySet, self); + } + + // aten::special_i1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_i1_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::special_i1_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_i1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_i1_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::special_i1_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_i1e(Tensor self) -> Tensor + inline at::Tensor special_i1e(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::special_i1e::redispatch(dispatchKeySet, self); + } + + // aten::special_i1e.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & special_i1e_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::special_i1e_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_i1e.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_i1e_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::special_i1e_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_logit(Tensor self, float? eps=None) -> Tensor + inline at::Tensor special_logit(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional eps=::std::nullopt) { + return at::_ops::special_logit::redispatch(dispatchKeySet, self, eps); + } + + // aten::special_logit.out(Tensor self, float? eps=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_logit_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, ::std::optional eps=::std::nullopt) { + return at::_ops::special_logit_out::redispatch(dispatchKeySet, self, eps, out); + } + + // aten::special_logit.out(Tensor self, float? eps=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_logit_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional eps, at::Tensor & out) { + return at::_ops::special_logit_out::redispatch(dispatchKeySet, self, eps, out); + } + + // aten::special_polygamma(int n, Tensor self) -> Tensor + inline at::Tensor special_polygamma(c10::DispatchKeySet dispatchKeySet, int64_t n, const at::Tensor & self) { + return at::_ops::special_polygamma::redispatch(dispatchKeySet, n, self); + } + + // aten::special_polygamma.out(int n, Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & special_polygamma_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t n, const at::Tensor & self) { + return at::_ops::special_polygamma_out::redispatch(dispatchKeySet, n, self, out); + } + + // aten::special_polygamma.out(int n, Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_polygamma_outf(c10::DispatchKeySet dispatchKeySet, int64_t n, const at::Tensor & self, at::Tensor & out) { + return at::_ops::special_polygamma_out::redispatch(dispatchKeySet, n, self, out); + } + + // aten::special_logsumexp(Tensor self, int[1] dim, bool keepdim=False) -> Tensor + inline at::Tensor special_logsumexp(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim, bool keepdim=false) { + return at::_ops::special_logsumexp::redispatch(dispatchKeySet, self, dim, keepdim); + } + + // aten::special_logsumexp.out(Tensor self, int[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_logsumexp_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef dim, bool keepdim=false) { + return at::_ops::special_logsumexp_out::redispatch(dispatchKeySet, self, dim, keepdim, out); + } + + // aten::special_logsumexp.out(Tensor self, int[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_logsumexp_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim, bool keepdim, at::Tensor & out) { + return at::_ops::special_logsumexp_out::redispatch(dispatchKeySet, self, dim, keepdim, out); + } + + // aten::special_expit(Tensor self) -> Tensor + inline at::Tensor special_expit(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::special_expit::redispatch(dispatchKeySet, self); + } + + // aten::special_expit.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & special_expit_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::special_expit_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_expit.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_expit_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::special_expit_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_sinc(Tensor self) -> Tensor + inline at::Tensor special_sinc(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::special_sinc::redispatch(dispatchKeySet, self); + } + + // aten::special_sinc.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_sinc_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::special_sinc_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_sinc.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_sinc_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::special_sinc_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_round(Tensor self, *, int decimals=0) -> Tensor + inline at::Tensor special_round(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t decimals=0) { + return at::_ops::special_round::redispatch(dispatchKeySet, self, decimals); + } + + // aten::special_round.out(Tensor self, *, int decimals=0, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_round_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t decimals=0) { + return at::_ops::special_round_out::redispatch(dispatchKeySet, self, decimals, out); + } + + // aten::special_round.out(Tensor self, *, int decimals=0, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & special_round_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t decimals, at::Tensor & out) { + return at::_ops::special_round_out::redispatch(dispatchKeySet, self, decimals, out); + } + + // aten::special_log1p(Tensor self) -> Tensor + inline at::Tensor special_log1p(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::special_log1p::redispatch(dispatchKeySet, self); + } + + // aten::special_log1p.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_log1p_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::special_log1p_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_log1p.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_log1p_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::special_log1p_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_log_softmax(Tensor self, int dim, *, ScalarType? dtype=None) -> Tensor + inline at::Tensor special_log_softmax(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, ::std::optional dtype=::std::nullopt) { + return at::_ops::special_log_softmax::redispatch(dispatchKeySet, self, dim, dtype); + } + + // aten::special_gammainc.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_gammainc_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::special_gammainc_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::special_gammainc.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & special_gammainc_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::special_gammainc_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::special_gammainc(Tensor self, Tensor other) -> Tensor + inline at::Tensor special_gammainc(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::special_gammainc::redispatch(dispatchKeySet, self, other); + } + + // aten::special_gammaincc.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_gammaincc_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::special_gammaincc_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::special_gammaincc.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_gammaincc_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::special_gammaincc_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::special_gammaincc(Tensor self, Tensor other) -> Tensor + inline at::Tensor special_gammaincc(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::special_gammaincc::redispatch(dispatchKeySet, self, other); + } + + // aten::special_multigammaln(Tensor self, int p) -> Tensor + inline at::Tensor special_multigammaln(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t p) { + return at::_ops::special_multigammaln::redispatch(dispatchKeySet, self, p); + } + + // aten::special_multigammaln.out(Tensor self, int p, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & special_multigammaln_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t p) { + return at::_ops::special_multigammaln_out::redispatch(dispatchKeySet, self, p, out); + } + + // aten::special_multigammaln.out(Tensor self, int p, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_multigammaln_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t p, at::Tensor & out) { + return at::_ops::special_multigammaln_out::redispatch(dispatchKeySet, self, p, out); + } + + // aten::special_softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor + inline at::Tensor special_softmax(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, ::std::optional dtype=::std::nullopt) { + return at::_ops::special_softmax::redispatch(dispatchKeySet, self, dim, dtype); + } + + // aten::fft_fft(Tensor self, SymInt? n=None, int dim=-1, str? norm=None) -> Tensor + inline at::Tensor fft_fft(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional n=::std::nullopt, int64_t dim=-1, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_fft::redispatch(dispatchKeySet, self, n.has_value() ? ::std::make_optional(c10::SymInt(*n)) : ::std::nullopt, dim, norm); + } + + // aten::fft_fft(Tensor self, SymInt? n=None, int dim=-1, str? norm=None) -> Tensor + inline at::Tensor fft_fft_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional n=::std::nullopt, int64_t dim=-1, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_fft::redispatch(dispatchKeySet, self, n, dim, norm); + } + + // aten::fft_fft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & fft_fft_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, ::std::optional n=::std::nullopt, int64_t dim=-1, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_fft_out::redispatch(dispatchKeySet, self, n.has_value() ? ::std::make_optional(c10::SymInt(*n)) : ::std::nullopt, dim, norm, out); + } + + // aten::fft_fft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_fft_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional n, int64_t dim, ::std::optional norm, at::Tensor & out) { + return at::_ops::fft_fft_out::redispatch(dispatchKeySet, self, n.has_value() ? ::std::make_optional(c10::SymInt(*n)) : ::std::nullopt, dim, norm, out); + } + + // aten::fft_fft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_fft_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, ::std::optional n=::std::nullopt, int64_t dim=-1, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_fft_out::redispatch(dispatchKeySet, self, n, dim, norm, out); + } + + // aten::fft_fft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_fft_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional n, int64_t dim, ::std::optional norm, at::Tensor & out) { + return at::_ops::fft_fft_out::redispatch(dispatchKeySet, self, n, dim, norm, out); + } + + // aten::fft_ifft(Tensor self, SymInt? n=None, int dim=-1, str? norm=None) -> Tensor + inline at::Tensor fft_ifft(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional n=::std::nullopt, int64_t dim=-1, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_ifft::redispatch(dispatchKeySet, self, n.has_value() ? 
::std::make_optional(c10::SymInt(*n)) : ::std::nullopt, dim, norm); + } + + // aten::fft_ifft(Tensor self, SymInt? n=None, int dim=-1, str? norm=None) -> Tensor + inline at::Tensor fft_ifft_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional n=::std::nullopt, int64_t dim=-1, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_ifft::redispatch(dispatchKeySet, self, n, dim, norm); + } + + // aten::fft_ifft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_ifft_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, ::std::optional n=::std::nullopt, int64_t dim=-1, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_ifft_out::redispatch(dispatchKeySet, self, n.has_value() ? ::std::make_optional(c10::SymInt(*n)) : ::std::nullopt, dim, norm, out); + } + + // aten::fft_ifft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_ifft_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional n, int64_t dim, ::std::optional norm, at::Tensor & out) { + return at::_ops::fft_ifft_out::redispatch(dispatchKeySet, self, n.has_value() ? ::std::make_optional(c10::SymInt(*n)) : ::std::nullopt, dim, norm, out); + } + + // aten::fft_ifft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_ifft_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, ::std::optional n=::std::nullopt, int64_t dim=-1, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_ifft_out::redispatch(dispatchKeySet, self, n, dim, norm, out); + } + + // aten::fft_ifft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & fft_ifft_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional n, int64_t dim, ::std::optional norm, at::Tensor & out) { + return at::_ops::fft_ifft_out::redispatch(dispatchKeySet, self, n, dim, norm, out); + } + + // aten::fft_rfft(Tensor self, SymInt? n=None, int dim=-1, str? norm=None) -> Tensor + inline at::Tensor fft_rfft(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional n=::std::nullopt, int64_t dim=-1, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_rfft::redispatch(dispatchKeySet, self, n.has_value() ? ::std::make_optional(c10::SymInt(*n)) : ::std::nullopt, dim, norm); + } + + // aten::fft_rfft(Tensor self, SymInt? n=None, int dim=-1, str? norm=None) -> Tensor + inline at::Tensor fft_rfft_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional n=::std::nullopt, int64_t dim=-1, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_rfft::redispatch(dispatchKeySet, self, n, dim, norm); + } + + // aten::fft_rfft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_rfft_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, ::std::optional n=::std::nullopt, int64_t dim=-1, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_rfft_out::redispatch(dispatchKeySet, self, n.has_value() ? ::std::make_optional(c10::SymInt(*n)) : ::std::nullopt, dim, norm, out); + } + + // aten::fft_rfft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_rfft_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional n, int64_t dim, ::std::optional norm, at::Tensor & out) { + return at::_ops::fft_rfft_out::redispatch(dispatchKeySet, self, n.has_value() ? 
::std::make_optional(c10::SymInt(*n)) : ::std::nullopt, dim, norm, out); + } + + // aten::fft_rfft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_rfft_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, ::std::optional n=::std::nullopt, int64_t dim=-1, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_rfft_out::redispatch(dispatchKeySet, self, n, dim, norm, out); + } + + // aten::fft_rfft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_rfft_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional n, int64_t dim, ::std::optional norm, at::Tensor & out) { + return at::_ops::fft_rfft_out::redispatch(dispatchKeySet, self, n, dim, norm, out); + } + + // aten::fft_irfft(Tensor self, SymInt? n=None, int dim=-1, str? norm=None) -> Tensor + inline at::Tensor fft_irfft(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional n=::std::nullopt, int64_t dim=-1, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_irfft::redispatch(dispatchKeySet, self, n.has_value() ? ::std::make_optional(c10::SymInt(*n)) : ::std::nullopt, dim, norm); + } + + // aten::fft_irfft(Tensor self, SymInt? n=None, int dim=-1, str? norm=None) -> Tensor + inline at::Tensor fft_irfft_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional n=::std::nullopt, int64_t dim=-1, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_irfft::redispatch(dispatchKeySet, self, n, dim, norm); + } + + // aten::fft_irfft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & fft_irfft_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, ::std::optional n=::std::nullopt, int64_t dim=-1, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_irfft_out::redispatch(dispatchKeySet, self, n.has_value() ? ::std::make_optional(c10::SymInt(*n)) : ::std::nullopt, dim, norm, out); + } + + // aten::fft_irfft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_irfft_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional n, int64_t dim, ::std::optional norm, at::Tensor & out) { + return at::_ops::fft_irfft_out::redispatch(dispatchKeySet, self, n.has_value() ? ::std::make_optional(c10::SymInt(*n)) : ::std::nullopt, dim, norm, out); + } + + // aten::fft_irfft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_irfft_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, ::std::optional n=::std::nullopt, int64_t dim=-1, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_irfft_out::redispatch(dispatchKeySet, self, n, dim, norm, out); + } + + // aten::fft_irfft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_irfft_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional n, int64_t dim, ::std::optional norm, at::Tensor & out) { + return at::_ops::fft_irfft_out::redispatch(dispatchKeySet, self, n, dim, norm, out); + } + + // aten::fft_hfft(Tensor self, SymInt? n=None, int dim=-1, str? norm=None) -> Tensor + inline at::Tensor fft_hfft(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional n=::std::nullopt, int64_t dim=-1, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_hfft::redispatch(dispatchKeySet, self, n.has_value() ? 
::std::make_optional(c10::SymInt(*n)) : ::std::nullopt, dim, norm); + } + + // aten::fft_hfft(Tensor self, SymInt? n=None, int dim=-1, str? norm=None) -> Tensor + inline at::Tensor fft_hfft_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional n=::std::nullopt, int64_t dim=-1, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_hfft::redispatch(dispatchKeySet, self, n, dim, norm); + } + + // aten::fft_hfft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_hfft_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, ::std::optional n=::std::nullopt, int64_t dim=-1, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_hfft_out::redispatch(dispatchKeySet, self, n.has_value() ? ::std::make_optional(c10::SymInt(*n)) : ::std::nullopt, dim, norm, out); + } + + // aten::fft_hfft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_hfft_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional n, int64_t dim, ::std::optional norm, at::Tensor & out) { + return at::_ops::fft_hfft_out::redispatch(dispatchKeySet, self, n.has_value() ? ::std::make_optional(c10::SymInt(*n)) : ::std::nullopt, dim, norm, out); + } + + // aten::fft_hfft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_hfft_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, ::std::optional n=::std::nullopt, int64_t dim=-1, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_hfft_out::redispatch(dispatchKeySet, self, n, dim, norm, out); + } + + // aten::fft_hfft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & fft_hfft_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional n, int64_t dim, ::std::optional norm, at::Tensor & out) { + return at::_ops::fft_hfft_out::redispatch(dispatchKeySet, self, n, dim, norm, out); + } + + // aten::fft_ihfft(Tensor self, SymInt? n=None, int dim=-1, str? norm=None) -> Tensor + inline at::Tensor fft_ihfft(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional n=::std::nullopt, int64_t dim=-1, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_ihfft::redispatch(dispatchKeySet, self, n.has_value() ? ::std::make_optional(c10::SymInt(*n)) : ::std::nullopt, dim, norm); + } + + // aten::fft_ihfft(Tensor self, SymInt? n=None, int dim=-1, str? norm=None) -> Tensor + inline at::Tensor fft_ihfft_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional n=::std::nullopt, int64_t dim=-1, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_ihfft::redispatch(dispatchKeySet, self, n, dim, norm); + } + + // aten::fft_ihfft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_ihfft_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, ::std::optional n=::std::nullopt, int64_t dim=-1, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_ihfft_out::redispatch(dispatchKeySet, self, n.has_value() ? ::std::make_optional(c10::SymInt(*n)) : ::std::nullopt, dim, norm, out); + } + + // aten::fft_ihfft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_ihfft_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional n, int64_t dim, ::std::optional norm, at::Tensor & out) { + return at::_ops::fft_ihfft_out::redispatch(dispatchKeySet, self, n.has_value() ? 
::std::make_optional(c10::SymInt(*n)) : ::std::nullopt, dim, norm, out); + } + + // aten::fft_ihfft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_ihfft_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, ::std::optional n=::std::nullopt, int64_t dim=-1, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_ihfft_out::redispatch(dispatchKeySet, self, n, dim, norm, out); + } + + // aten::fft_ihfft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_ihfft_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional n, int64_t dim, ::std::optional norm, at::Tensor & out) { + return at::_ops::fft_ihfft_out::redispatch(dispatchKeySet, self, n, dim, norm, out); + } + + // aten::fft_fft2(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None) -> Tensor + inline at::Tensor fft_fft2(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef s=::std::nullopt, at::IntArrayRef dim={-2,-1}, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_fft2::redispatch(dispatchKeySet, self, s.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*s)) : ::std::nullopt, dim, norm); + } + + // aten::fft_fft2(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None) -> Tensor + inline at::Tensor fft_fft2_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalSymIntArrayRef s=::std::nullopt, at::IntArrayRef dim={-2,-1}, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_fft2::redispatch(dispatchKeySet, self, s, dim, norm); + } + + // aten::fft_fft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & fft_fft2_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalIntArrayRef s=::std::nullopt, at::IntArrayRef dim={-2,-1}, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_fft2_out::redispatch(dispatchKeySet, self, s.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*s)) : ::std::nullopt, dim, norm, out); + } + + // aten::fft_fft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_fft2_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef s, at::IntArrayRef dim, ::std::optional norm, at::Tensor & out) { + return at::_ops::fft_fft2_out::redispatch(dispatchKeySet, self, s.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*s)) : ::std::nullopt, dim, norm, out); + } + + // aten::fft_fft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_fft2_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalSymIntArrayRef s=::std::nullopt, at::IntArrayRef dim={-2,-1}, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_fft2_out::redispatch(dispatchKeySet, self, s, dim, norm, out); + } + + // aten::fft_fft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_fft2_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalSymIntArrayRef s, at::IntArrayRef dim, ::std::optional norm, at::Tensor & out) { + return at::_ops::fft_fft2_out::redispatch(dispatchKeySet, self, s, dim, norm, out); + } + + // aten::fft_ifft2(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? 
norm=None) -> Tensor + inline at::Tensor fft_ifft2(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef s=::std::nullopt, at::IntArrayRef dim={-2,-1}, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_ifft2::redispatch(dispatchKeySet, self, s.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*s)) : ::std::nullopt, dim, norm); + } + + // aten::fft_ifft2(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None) -> Tensor + inline at::Tensor fft_ifft2_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalSymIntArrayRef s=::std::nullopt, at::IntArrayRef dim={-2,-1}, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_ifft2::redispatch(dispatchKeySet, self, s, dim, norm); + } + + // aten::fft_ifft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_ifft2_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalIntArrayRef s=::std::nullopt, at::IntArrayRef dim={-2,-1}, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_ifft2_out::redispatch(dispatchKeySet, self, s.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*s)) : ::std::nullopt, dim, norm, out); + } + + // aten::fft_ifft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_ifft2_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef s, at::IntArrayRef dim, ::std::optional norm, at::Tensor & out) { + return at::_ops::fft_ifft2_out::redispatch(dispatchKeySet, self, s.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*s)) : ::std::nullopt, dim, norm, out); + } + + // aten::fft_ifft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & fft_ifft2_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalSymIntArrayRef s=::std::nullopt, at::IntArrayRef dim={-2,-1}, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_ifft2_out::redispatch(dispatchKeySet, self, s, dim, norm, out); + } + + // aten::fft_ifft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_ifft2_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalSymIntArrayRef s, at::IntArrayRef dim, ::std::optional norm, at::Tensor & out) { + return at::_ops::fft_ifft2_out::redispatch(dispatchKeySet, self, s, dim, norm, out); + } + + // aten::fft_rfft2(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None) -> Tensor + inline at::Tensor fft_rfft2(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef s=::std::nullopt, at::IntArrayRef dim={-2,-1}, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_rfft2::redispatch(dispatchKeySet, self, s.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*s)) : ::std::nullopt, dim, norm); + } + + // aten::fft_rfft2(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None) -> Tensor + inline at::Tensor fft_rfft2_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalSymIntArrayRef s=::std::nullopt, at::IntArrayRef dim={-2,-1}, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_rfft2::redispatch(dispatchKeySet, self, s, dim, norm); + } + + // aten::fft_rfft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & fft_rfft2_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalIntArrayRef s=::std::nullopt, at::IntArrayRef dim={-2,-1}, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_rfft2_out::redispatch(dispatchKeySet, self, s.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*s)) : ::std::nullopt, dim, norm, out); + } + + // aten::fft_rfft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_rfft2_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef s, at::IntArrayRef dim, ::std::optional norm, at::Tensor & out) { + return at::_ops::fft_rfft2_out::redispatch(dispatchKeySet, self, s.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*s)) : ::std::nullopt, dim, norm, out); + } + + // aten::fft_rfft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_rfft2_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalSymIntArrayRef s=::std::nullopt, at::IntArrayRef dim={-2,-1}, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_rfft2_out::redispatch(dispatchKeySet, self, s, dim, norm, out); + } + + // aten::fft_rfft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_rfft2_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalSymIntArrayRef s, at::IntArrayRef dim, ::std::optional norm, at::Tensor & out) { + return at::_ops::fft_rfft2_out::redispatch(dispatchKeySet, self, s, dim, norm, out); + } + + // aten::fft_irfft2(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? 
norm=None) -> Tensor + inline at::Tensor fft_irfft2(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef s=::std::nullopt, at::IntArrayRef dim={-2,-1}, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_irfft2::redispatch(dispatchKeySet, self, s.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*s)) : ::std::nullopt, dim, norm); + } + + // aten::fft_irfft2(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None) -> Tensor + inline at::Tensor fft_irfft2_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalSymIntArrayRef s=::std::nullopt, at::IntArrayRef dim={-2,-1}, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_irfft2::redispatch(dispatchKeySet, self, s, dim, norm); + } + + // aten::fft_irfft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_irfft2_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalIntArrayRef s=::std::nullopt, at::IntArrayRef dim={-2,-1}, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_irfft2_out::redispatch(dispatchKeySet, self, s.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*s)) : ::std::nullopt, dim, norm, out); + } + + // aten::fft_irfft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_irfft2_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef s, at::IntArrayRef dim, ::std::optional norm, at::Tensor & out) { + return at::_ops::fft_irfft2_out::redispatch(dispatchKeySet, self, s.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*s)) : ::std::nullopt, dim, norm, out); + } + + // aten::fft_irfft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & fft_irfft2_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalSymIntArrayRef s=::std::nullopt, at::IntArrayRef dim={-2,-1}, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_irfft2_out::redispatch(dispatchKeySet, self, s, dim, norm, out); + } + + // aten::fft_irfft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_irfft2_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalSymIntArrayRef s, at::IntArrayRef dim, ::std::optional norm, at::Tensor & out) { + return at::_ops::fft_irfft2_out::redispatch(dispatchKeySet, self, s, dim, norm, out); + } + + // aten::fft_hfft2(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None) -> Tensor + inline at::Tensor fft_hfft2(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef s=::std::nullopt, at::IntArrayRef dim={-2,-1}, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_hfft2::redispatch(dispatchKeySet, self, s.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*s)) : ::std::nullopt, dim, norm); + } + + // aten::fft_hfft2(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None) -> Tensor + inline at::Tensor fft_hfft2_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalSymIntArrayRef s=::std::nullopt, at::IntArrayRef dim={-2,-1}, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_hfft2::redispatch(dispatchKeySet, self, s, dim, norm); + } + + // aten::fft_hfft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & fft_hfft2_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalIntArrayRef s=::std::nullopt, at::IntArrayRef dim={-2,-1}, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_hfft2_out::redispatch(dispatchKeySet, self, s.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*s)) : ::std::nullopt, dim, norm, out); + } + + // aten::fft_hfft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_hfft2_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef s, at::IntArrayRef dim, ::std::optional norm, at::Tensor & out) { + return at::_ops::fft_hfft2_out::redispatch(dispatchKeySet, self, s.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*s)) : ::std::nullopt, dim, norm, out); + } + + // aten::fft_hfft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_hfft2_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalSymIntArrayRef s=::std::nullopt, at::IntArrayRef dim={-2,-1}, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_hfft2_out::redispatch(dispatchKeySet, self, s, dim, norm, out); + } + + // aten::fft_hfft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_hfft2_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalSymIntArrayRef s, at::IntArrayRef dim, ::std::optional norm, at::Tensor & out) { + return at::_ops::fft_hfft2_out::redispatch(dispatchKeySet, self, s, dim, norm, out); + } + + // aten::fft_ihfft2(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? 
norm=None) -> Tensor + inline at::Tensor fft_ihfft2(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef s=::std::nullopt, at::IntArrayRef dim={-2,-1}, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_ihfft2::redispatch(dispatchKeySet, self, s.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*s)) : ::std::nullopt, dim, norm); + } + + // aten::fft_ihfft2(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None) -> Tensor + inline at::Tensor fft_ihfft2_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalSymIntArrayRef s=::std::nullopt, at::IntArrayRef dim={-2,-1}, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_ihfft2::redispatch(dispatchKeySet, self, s, dim, norm); + } + + // aten::fft_ihfft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_ihfft2_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalIntArrayRef s=::std::nullopt, at::IntArrayRef dim={-2,-1}, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_ihfft2_out::redispatch(dispatchKeySet, self, s.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*s)) : ::std::nullopt, dim, norm, out); + } + + // aten::fft_ihfft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_ihfft2_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef s, at::IntArrayRef dim, ::std::optional norm, at::Tensor & out) { + return at::_ops::fft_ihfft2_out::redispatch(dispatchKeySet, self, s.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*s)) : ::std::nullopt, dim, norm, out); + } + + // aten::fft_ihfft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & fft_ihfft2_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalSymIntArrayRef s=::std::nullopt, at::IntArrayRef dim={-2,-1}, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_ihfft2_out::redispatch(dispatchKeySet, self, s, dim, norm, out); + } + + // aten::fft_ihfft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_ihfft2_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalSymIntArrayRef s, at::IntArrayRef dim, ::std::optional norm, at::Tensor & out) { + return at::_ops::fft_ihfft2_out::redispatch(dispatchKeySet, self, s, dim, norm, out); + } + + // aten::fft_fftn(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor + inline at::Tensor fft_fftn(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef s=::std::nullopt, at::OptionalIntArrayRef dim=::std::nullopt, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_fftn::redispatch(dispatchKeySet, self, s.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*s)) : ::std::nullopt, dim, norm); + } + + // aten::fft_fftn(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor + inline at::Tensor fft_fftn_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalSymIntArrayRef s=::std::nullopt, at::OptionalIntArrayRef dim=::std::nullopt, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_fftn::redispatch(dispatchKeySet, self, s, dim, norm); + } + + // aten::fft_fftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & fft_fftn_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalIntArrayRef s=::std::nullopt, at::OptionalIntArrayRef dim=::std::nullopt, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_fftn_out::redispatch(dispatchKeySet, self, s.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*s)) : ::std::nullopt, dim, norm, out); + } + + // aten::fft_fftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_fftn_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef s, at::OptionalIntArrayRef dim, ::std::optional norm, at::Tensor & out) { + return at::_ops::fft_fftn_out::redispatch(dispatchKeySet, self, s.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*s)) : ::std::nullopt, dim, norm, out); + } + + // aten::fft_fftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_fftn_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalSymIntArrayRef s=::std::nullopt, at::OptionalIntArrayRef dim=::std::nullopt, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_fftn_out::redispatch(dispatchKeySet, self, s, dim, norm, out); + } + + // aten::fft_fftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_fftn_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalSymIntArrayRef s, at::OptionalIntArrayRef dim, ::std::optional norm, at::Tensor & out) { + return at::_ops::fft_fftn_out::redispatch(dispatchKeySet, self, s, dim, norm, out); + } + + // aten::fft_ifftn(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? 
norm=None) -> Tensor + inline at::Tensor fft_ifftn(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef s=::std::nullopt, at::OptionalIntArrayRef dim=::std::nullopt, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_ifftn::redispatch(dispatchKeySet, self, s.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*s)) : ::std::nullopt, dim, norm); + } + + // aten::fft_ifftn(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor + inline at::Tensor fft_ifftn_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalSymIntArrayRef s=::std::nullopt, at::OptionalIntArrayRef dim=::std::nullopt, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_ifftn::redispatch(dispatchKeySet, self, s, dim, norm); + } + + // aten::fft_ifftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_ifftn_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalIntArrayRef s=::std::nullopt, at::OptionalIntArrayRef dim=::std::nullopt, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_ifftn_out::redispatch(dispatchKeySet, self, s.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*s)) : ::std::nullopt, dim, norm, out); + } + + // aten::fft_ifftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_ifftn_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef s, at::OptionalIntArrayRef dim, ::std::optional norm, at::Tensor & out) { + return at::_ops::fft_ifftn_out::redispatch(dispatchKeySet, self, s.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*s)) : ::std::nullopt, dim, norm, out); + } + + // aten::fft_ifftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & fft_ifftn_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalSymIntArrayRef s=::std::nullopt, at::OptionalIntArrayRef dim=::std::nullopt, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_ifftn_out::redispatch(dispatchKeySet, self, s, dim, norm, out); + } + + // aten::fft_ifftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_ifftn_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalSymIntArrayRef s, at::OptionalIntArrayRef dim, ::std::optional norm, at::Tensor & out) { + return at::_ops::fft_ifftn_out::redispatch(dispatchKeySet, self, s, dim, norm, out); + } + + // aten::fft_rfftn(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor + inline at::Tensor fft_rfftn(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef s=::std::nullopt, at::OptionalIntArrayRef dim=::std::nullopt, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_rfftn::redispatch(dispatchKeySet, self, s.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*s)) : ::std::nullopt, dim, norm); + } + + // aten::fft_rfftn(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor + inline at::Tensor fft_rfftn_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalSymIntArrayRef s=::std::nullopt, at::OptionalIntArrayRef dim=::std::nullopt, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_rfftn::redispatch(dispatchKeySet, self, s, dim, norm); + } + + // aten::fft_rfftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & fft_rfftn_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalIntArrayRef s=::std::nullopt, at::OptionalIntArrayRef dim=::std::nullopt, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_rfftn_out::redispatch(dispatchKeySet, self, s.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*s)) : ::std::nullopt, dim, norm, out); + } + + // aten::fft_rfftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_rfftn_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef s, at::OptionalIntArrayRef dim, ::std::optional norm, at::Tensor & out) { + return at::_ops::fft_rfftn_out::redispatch(dispatchKeySet, self, s.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*s)) : ::std::nullopt, dim, norm, out); + } + + // aten::fft_rfftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_rfftn_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalSymIntArrayRef s=::std::nullopt, at::OptionalIntArrayRef dim=::std::nullopt, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_rfftn_out::redispatch(dispatchKeySet, self, s, dim, norm, out); + } + + // aten::fft_rfftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_rfftn_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalSymIntArrayRef s, at::OptionalIntArrayRef dim, ::std::optional norm, at::Tensor & out) { + return at::_ops::fft_rfftn_out::redispatch(dispatchKeySet, self, s, dim, norm, out); + } + + // aten::fft_irfftn(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? 
norm=None) -> Tensor + inline at::Tensor fft_irfftn(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef s=::std::nullopt, at::OptionalIntArrayRef dim=::std::nullopt, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_irfftn::redispatch(dispatchKeySet, self, s.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*s)) : ::std::nullopt, dim, norm); + } + + // aten::fft_irfftn(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor + inline at::Tensor fft_irfftn_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalSymIntArrayRef s=::std::nullopt, at::OptionalIntArrayRef dim=::std::nullopt, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_irfftn::redispatch(dispatchKeySet, self, s, dim, norm); + } + + // aten::fft_irfftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_irfftn_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalIntArrayRef s=::std::nullopt, at::OptionalIntArrayRef dim=::std::nullopt, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_irfftn_out::redispatch(dispatchKeySet, self, s.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*s)) : ::std::nullopt, dim, norm, out); + } + + // aten::fft_irfftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_irfftn_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef s, at::OptionalIntArrayRef dim, ::std::optional norm, at::Tensor & out) { + return at::_ops::fft_irfftn_out::redispatch(dispatchKeySet, self, s.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*s)) : ::std::nullopt, dim, norm, out); + } + + // aten::fft_irfftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & fft_irfftn_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalSymIntArrayRef s=::std::nullopt, at::OptionalIntArrayRef dim=::std::nullopt, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_irfftn_out::redispatch(dispatchKeySet, self, s, dim, norm, out); + } + + // aten::fft_irfftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_irfftn_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalSymIntArrayRef s, at::OptionalIntArrayRef dim, ::std::optional norm, at::Tensor & out) { + return at::_ops::fft_irfftn_out::redispatch(dispatchKeySet, self, s, dim, norm, out); + } + + // aten::fft_hfftn(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor + inline at::Tensor fft_hfftn(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef s=::std::nullopt, at::OptionalIntArrayRef dim=::std::nullopt, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_hfftn::redispatch(dispatchKeySet, self, s.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*s)) : ::std::nullopt, dim, norm); + } + + // aten::fft_hfftn(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor + inline at::Tensor fft_hfftn_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalSymIntArrayRef s=::std::nullopt, at::OptionalIntArrayRef dim=::std::nullopt, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_hfftn::redispatch(dispatchKeySet, self, s, dim, norm); + } + + // aten::fft_hfftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & fft_hfftn_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalIntArrayRef s=::std::nullopt, at::OptionalIntArrayRef dim=::std::nullopt, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_hfftn_out::redispatch(dispatchKeySet, self, s.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*s)) : ::std::nullopt, dim, norm, out); + } + + // aten::fft_hfftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_hfftn_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef s, at::OptionalIntArrayRef dim, ::std::optional norm, at::Tensor & out) { + return at::_ops::fft_hfftn_out::redispatch(dispatchKeySet, self, s.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*s)) : ::std::nullopt, dim, norm, out); + } + + // aten::fft_hfftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_hfftn_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalSymIntArrayRef s=::std::nullopt, at::OptionalIntArrayRef dim=::std::nullopt, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_hfftn_out::redispatch(dispatchKeySet, self, s, dim, norm, out); + } + + // aten::fft_hfftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_hfftn_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalSymIntArrayRef s, at::OptionalIntArrayRef dim, ::std::optional norm, at::Tensor & out) { + return at::_ops::fft_hfftn_out::redispatch(dispatchKeySet, self, s, dim, norm, out); + } + + // aten::fft_ihfftn(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? 
norm=None) -> Tensor + inline at::Tensor fft_ihfftn(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef s=::std::nullopt, at::OptionalIntArrayRef dim=::std::nullopt, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_ihfftn::redispatch(dispatchKeySet, self, s.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*s)) : ::std::nullopt, dim, norm); + } + + // aten::fft_ihfftn(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor + inline at::Tensor fft_ihfftn_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalSymIntArrayRef s=::std::nullopt, at::OptionalIntArrayRef dim=::std::nullopt, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_ihfftn::redispatch(dispatchKeySet, self, s, dim, norm); + } + + // aten::fft_ihfftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_ihfftn_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalIntArrayRef s=::std::nullopt, at::OptionalIntArrayRef dim=::std::nullopt, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_ihfftn_out::redispatch(dispatchKeySet, self, s.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*s)) : ::std::nullopt, dim, norm, out); + } + + // aten::fft_ihfftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_ihfftn_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef s, at::OptionalIntArrayRef dim, ::std::optional norm, at::Tensor & out) { + return at::_ops::fft_ihfftn_out::redispatch(dispatchKeySet, self, s.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*s)) : ::std::nullopt, dim, norm, out); + } + + // aten::fft_ihfftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & fft_ihfftn_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalSymIntArrayRef s=::std::nullopt, at::OptionalIntArrayRef dim=::std::nullopt, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_ihfftn_out::redispatch(dispatchKeySet, self, s, dim, norm, out); + } + + // aten::fft_ihfftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_ihfftn_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalSymIntArrayRef s, at::OptionalIntArrayRef dim, ::std::optional norm, at::Tensor & out) { + return at::_ops::fft_ihfftn_out::redispatch(dispatchKeySet, self, s, dim, norm, out); + } + + // aten::fft_fftfreq(int n, float d=1.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor fft_fftfreq(c10::DispatchKeySet dispatchKeySet, int64_t n, double d=1.0, at::TensorOptions options={}) { + return at::_ops::fft_fftfreq::redispatch(dispatchKeySet, n, d, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::fft_fftfreq(int n, float d=1.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor fft_fftfreq(c10::DispatchKeySet dispatchKeySet, int64_t n, double d, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::fft_fftfreq::redispatch(dispatchKeySet, n, d, dtype, layout, device, pin_memory); + } + + // aten::fft_fftfreq.out(int n, float d=1.0, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & fft_fftfreq_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t n, double d=1.0) { + return at::_ops::fft_fftfreq_out::redispatch(dispatchKeySet, n, d, out); + } + + // aten::fft_fftfreq.out(int n, float d=1.0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_fftfreq_outf(c10::DispatchKeySet dispatchKeySet, int64_t n, double d, at::Tensor & out) { + return at::_ops::fft_fftfreq_out::redispatch(dispatchKeySet, n, d, out); + } + + // aten::fft_rfftfreq(int n, float d=1.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor fft_rfftfreq(c10::DispatchKeySet dispatchKeySet, int64_t n, double d=1.0, at::TensorOptions options={}) { + return at::_ops::fft_rfftfreq::redispatch(dispatchKeySet, n, d, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::fft_rfftfreq(int n, float d=1.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor fft_rfftfreq(c10::DispatchKeySet dispatchKeySet, int64_t n, double d, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::fft_rfftfreq::redispatch(dispatchKeySet, n, d, dtype, layout, device, pin_memory); + } + + // aten::fft_rfftfreq.out(int n, float d=1.0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_rfftfreq_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t n, double d=1.0) { + return at::_ops::fft_rfftfreq_out::redispatch(dispatchKeySet, n, d, out); + } + + // aten::fft_rfftfreq.out(int n, float d=1.0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_rfftfreq_outf(c10::DispatchKeySet dispatchKeySet, int64_t n, double d, at::Tensor & out) { + return at::_ops::fft_rfftfreq_out::redispatch(dispatchKeySet, n, d, out); + } + + // aten::fft_fftshift(Tensor self, int[1]? 
dim=None) -> Tensor + inline at::Tensor fft_fftshift(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef dim=::std::nullopt) { + return at::_ops::fft_fftshift::redispatch(dispatchKeySet, self, dim); + } + + // aten::fft_ifftshift(Tensor self, int[1]? dim=None) -> Tensor + inline at::Tensor fft_ifftshift(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef dim=::std::nullopt) { + return at::_ops::fft_ifftshift::redispatch(dispatchKeySet, self, dim); + } + + // aten::linalg_cholesky_ex(Tensor self, *, bool upper=False, bool check_errors=False) -> (Tensor L, Tensor info) + inline ::std::tuple linalg_cholesky_ex(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool upper=false, bool check_errors=false) { + return at::_ops::linalg_cholesky_ex::redispatch(dispatchKeySet, self, upper, check_errors); + } + + // aten::linalg_cholesky_ex.L(Tensor self, *, bool upper=False, bool check_errors=False, Tensor(a!) L, Tensor(b!) info) -> (Tensor(a!) L, Tensor(b!) info) + inline ::std::tuple linalg_cholesky_ex_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & L, at::Tensor & info, const at::Tensor & self, bool upper=false, bool check_errors=false) { + return at::_ops::linalg_cholesky_ex_L::redispatch(dispatchKeySet, self, upper, check_errors, L, info); + } + + // aten::linalg_cholesky_ex.L(Tensor self, *, bool upper=False, bool check_errors=False, Tensor(a!) L, Tensor(b!) info) -> (Tensor(a!) L, Tensor(b!) 
info) + inline ::std::tuple linalg_cholesky_ex_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool upper, bool check_errors, at::Tensor & L, at::Tensor & info) { + return at::_ops::linalg_cholesky_ex_L::redispatch(dispatchKeySet, self, upper, check_errors, L, info); + } + + // aten::linalg_cholesky(Tensor self, *, bool upper=False) -> Tensor + inline at::Tensor linalg_cholesky(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool upper=false) { + return at::_ops::linalg_cholesky::redispatch(dispatchKeySet, self, upper); + } + + // aten::linalg_cholesky.out(Tensor self, *, bool upper=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_cholesky_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, bool upper=false) { + return at::_ops::linalg_cholesky_out::redispatch(dispatchKeySet, self, upper, out); + } + + // aten::linalg_cholesky.out(Tensor self, *, bool upper=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_cholesky_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool upper, at::Tensor & out) { + return at::_ops::linalg_cholesky_out::redispatch(dispatchKeySet, self, upper, out); + } + + // aten::linalg_cross(Tensor self, Tensor other, *, int dim=-1) -> Tensor + inline at::Tensor linalg_cross(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, int64_t dim=-1) { + return at::_ops::linalg_cross::redispatch(dispatchKeySet, self, other, dim); + } + + // aten::linalg_cross.out(Tensor self, Tensor other, *, int dim=-1, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_cross_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other, int64_t dim=-1) { + return at::_ops::linalg_cross_out::redispatch(dispatchKeySet, self, other, dim, out); + } + + // aten::linalg_cross.out(Tensor self, Tensor other, *, int dim=-1, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & linalg_cross_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, int64_t dim, at::Tensor & out) { + return at::_ops::linalg_cross_out::redispatch(dispatchKeySet, self, other, dim, out); + } + + // aten::linalg_lu_factor(Tensor A, *, bool pivot=True) -> (Tensor LU, Tensor pivots) + inline ::std::tuple linalg_lu_factor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, bool pivot=true) { + return at::_ops::linalg_lu_factor::redispatch(dispatchKeySet, A, pivot); + } + + // aten::linalg_lu_factor.out(Tensor A, *, bool pivot=True, Tensor(a!) LU, Tensor(b!) pivots) -> (Tensor(a!) LU, Tensor(b!) pivots) + inline ::std::tuple linalg_lu_factor_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & LU, at::Tensor & pivots, const at::Tensor & A, bool pivot=true) { + return at::_ops::linalg_lu_factor_out::redispatch(dispatchKeySet, A, pivot, LU, pivots); + } + + // aten::linalg_lu_factor.out(Tensor A, *, bool pivot=True, Tensor(a!) LU, Tensor(b!) pivots) -> (Tensor(a!) LU, Tensor(b!) pivots) + inline ::std::tuple linalg_lu_factor_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, bool pivot, at::Tensor & LU, at::Tensor & pivots) { + return at::_ops::linalg_lu_factor_out::redispatch(dispatchKeySet, A, pivot, LU, pivots); + } + + // aten::linalg_lu_factor_ex(Tensor A, *, bool pivot=True, bool check_errors=False) -> (Tensor LU, Tensor pivots, Tensor info) + inline ::std::tuple linalg_lu_factor_ex(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, bool pivot=true, bool check_errors=false) { + return at::_ops::linalg_lu_factor_ex::redispatch(dispatchKeySet, A, pivot, check_errors); + } + + // aten::linalg_lu_factor_ex.out(Tensor A, *, bool pivot=True, bool check_errors=False, Tensor(a!) LU, Tensor(b!) pivots, Tensor(c!) info) -> (Tensor(a!) LU, Tensor(b!) pivots, Tensor(c!) 
info) + inline ::std::tuple linalg_lu_factor_ex_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & LU, at::Tensor & pivots, at::Tensor & info, const at::Tensor & A, bool pivot=true, bool check_errors=false) { + return at::_ops::linalg_lu_factor_ex_out::redispatch(dispatchKeySet, A, pivot, check_errors, LU, pivots, info); + } + + // aten::linalg_lu_factor_ex.out(Tensor A, *, bool pivot=True, bool check_errors=False, Tensor(a!) LU, Tensor(b!) pivots, Tensor(c!) info) -> (Tensor(a!) LU, Tensor(b!) pivots, Tensor(c!) info) + inline ::std::tuple linalg_lu_factor_ex_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, bool pivot, bool check_errors, at::Tensor & LU, at::Tensor & pivots, at::Tensor & info) { + return at::_ops::linalg_lu_factor_ex_out::redispatch(dispatchKeySet, A, pivot, check_errors, LU, pivots, info); + } + + // aten::linalg_lu(Tensor A, *, bool pivot=True) -> (Tensor P, Tensor L, Tensor U) + inline ::std::tuple linalg_lu(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, bool pivot=true) { + return at::_ops::linalg_lu::redispatch(dispatchKeySet, A, pivot); + } + + // aten::linalg_lu.out(Tensor A, *, bool pivot=True, Tensor(a!) P, Tensor(b!) L, Tensor(c!) U) -> (Tensor(a!) P, Tensor(b!) L, Tensor(c!) U) + inline ::std::tuple linalg_lu_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & P, at::Tensor & L, at::Tensor & U, const at::Tensor & A, bool pivot=true) { + return at::_ops::linalg_lu_out::redispatch(dispatchKeySet, A, pivot, P, L, U); + } + + // aten::linalg_lu.out(Tensor A, *, bool pivot=True, Tensor(a!) P, Tensor(b!) L, Tensor(c!) U) -> (Tensor(a!) P, Tensor(b!) L, Tensor(c!) 
U) + inline ::std::tuple linalg_lu_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, bool pivot, at::Tensor & P, at::Tensor & L, at::Tensor & U) { + return at::_ops::linalg_lu_out::redispatch(dispatchKeySet, A, pivot, P, L, U); + } + + // aten::linalg_lu_solve(Tensor LU, Tensor pivots, Tensor B, *, bool left=True, bool adjoint=False) -> Tensor + inline at::Tensor linalg_lu_solve(c10::DispatchKeySet dispatchKeySet, const at::Tensor & LU, const at::Tensor & pivots, const at::Tensor & B, bool left=true, bool adjoint=false) { + return at::_ops::linalg_lu_solve::redispatch(dispatchKeySet, LU, pivots, B, left, adjoint); + } + + // aten::linalg_lu_solve.out(Tensor LU, Tensor pivots, Tensor B, *, bool left=True, bool adjoint=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_lu_solve_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & LU, const at::Tensor & pivots, const at::Tensor & B, bool left=true, bool adjoint=false) { + return at::_ops::linalg_lu_solve_out::redispatch(dispatchKeySet, LU, pivots, B, left, adjoint, out); + } + + // aten::linalg_lu_solve.out(Tensor LU, Tensor pivots, Tensor B, *, bool left=True, bool adjoint=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_lu_solve_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & LU, const at::Tensor & pivots, const at::Tensor & B, bool left, bool adjoint, at::Tensor & out) { + return at::_ops::linalg_lu_solve_out::redispatch(dispatchKeySet, LU, pivots, B, left, adjoint, out); + } + + // aten::_linalg_det(Tensor A) -> (Tensor result, Tensor LU, Tensor pivots) + inline ::std::tuple _linalg_det(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A) { + return at::_ops::_linalg_det::redispatch(dispatchKeySet, A); + } + + // aten::_linalg_det.result(Tensor A, *, Tensor(a!) result, Tensor(b!) LU, Tensor(c!) pivots) -> (Tensor(a!) result, Tensor(b!) LU, Tensor(c!) 
pivots) + inline ::std::tuple _linalg_det_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & result, at::Tensor & LU, at::Tensor & pivots, const at::Tensor & A) { + return at::_ops::_linalg_det_result::redispatch(dispatchKeySet, A, result, LU, pivots); + } + + // aten::_linalg_det.result(Tensor A, *, Tensor(a!) result, Tensor(b!) LU, Tensor(c!) pivots) -> (Tensor(a!) result, Tensor(b!) LU, Tensor(c!) pivots) + inline ::std::tuple _linalg_det_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, at::Tensor & result, at::Tensor & LU, at::Tensor & pivots) { + return at::_ops::_linalg_det_result::redispatch(dispatchKeySet, A, result, LU, pivots); + } + + // aten::linalg_det(Tensor A) -> Tensor + inline at::Tensor linalg_det(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A) { + return at::_ops::linalg_det::redispatch(dispatchKeySet, A); + } + + // aten::linalg_det.out(Tensor A, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_det_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & A) { + return at::_ops::linalg_det_out::redispatch(dispatchKeySet, A, out); + } + + // aten::linalg_det.out(Tensor A, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & linalg_det_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, at::Tensor & out) { + return at::_ops::linalg_det_out::redispatch(dispatchKeySet, A, out); + } + + // aten::det(Tensor self) -> Tensor + inline at::Tensor det(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::det::redispatch(dispatchKeySet, self); + } + + // aten::linalg_ldl_factor_ex(Tensor self, *, bool hermitian=False, bool check_errors=False) -> (Tensor LD, Tensor pivots, Tensor info) + inline ::std::tuple linalg_ldl_factor_ex(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool hermitian=false, bool check_errors=false) { + return at::_ops::linalg_ldl_factor_ex::redispatch(dispatchKeySet, self, hermitian, check_errors); + } + + // aten::linalg_ldl_factor_ex.out(Tensor self, *, bool hermitian=False, bool check_errors=False, Tensor(a!) LD, Tensor(b!) pivots, Tensor(c!) info) -> (Tensor(a!) LD, Tensor(b!) pivots, Tensor(c!) info) + inline ::std::tuple linalg_ldl_factor_ex_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & LD, at::Tensor & pivots, at::Tensor & info, const at::Tensor & self, bool hermitian=false, bool check_errors=false) { + return at::_ops::linalg_ldl_factor_ex_out::redispatch(dispatchKeySet, self, hermitian, check_errors, LD, pivots, info); + } + + // aten::linalg_ldl_factor_ex.out(Tensor self, *, bool hermitian=False, bool check_errors=False, Tensor(a!) LD, Tensor(b!) pivots, Tensor(c!) info) -> (Tensor(a!) LD, Tensor(b!) pivots, Tensor(c!) 
info) + inline ::std::tuple linalg_ldl_factor_ex_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool hermitian, bool check_errors, at::Tensor & LD, at::Tensor & pivots, at::Tensor & info) { + return at::_ops::linalg_ldl_factor_ex_out::redispatch(dispatchKeySet, self, hermitian, check_errors, LD, pivots, info); + } + + // aten::linalg_ldl_factor(Tensor self, *, bool hermitian=False) -> (Tensor LD, Tensor pivots) + inline ::std::tuple linalg_ldl_factor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool hermitian=false) { + return at::_ops::linalg_ldl_factor::redispatch(dispatchKeySet, self, hermitian); + } + + // aten::linalg_ldl_factor.out(Tensor self, *, bool hermitian=False, Tensor(a!) LD, Tensor(b!) pivots) -> (Tensor(a!) LD, Tensor(b!) pivots) + inline ::std::tuple linalg_ldl_factor_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & LD, at::Tensor & pivots, const at::Tensor & self, bool hermitian=false) { + return at::_ops::linalg_ldl_factor_out::redispatch(dispatchKeySet, self, hermitian, LD, pivots); + } + + // aten::linalg_ldl_factor.out(Tensor self, *, bool hermitian=False, Tensor(a!) LD, Tensor(b!) pivots) -> (Tensor(a!) LD, Tensor(b!) pivots) + inline ::std::tuple linalg_ldl_factor_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool hermitian, at::Tensor & LD, at::Tensor & pivots) { + return at::_ops::linalg_ldl_factor_out::redispatch(dispatchKeySet, self, hermitian, LD, pivots); + } + + // aten::linalg_ldl_solve(Tensor LD, Tensor pivots, Tensor B, *, bool hermitian=False) -> Tensor + inline at::Tensor linalg_ldl_solve(c10::DispatchKeySet dispatchKeySet, const at::Tensor & LD, const at::Tensor & pivots, const at::Tensor & B, bool hermitian=false) { + return at::_ops::linalg_ldl_solve::redispatch(dispatchKeySet, LD, pivots, B, hermitian); + } + + // aten::linalg_ldl_solve.out(Tensor LD, Tensor pivots, Tensor B, *, bool hermitian=False, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & linalg_ldl_solve_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & LD, const at::Tensor & pivots, const at::Tensor & B, bool hermitian=false) { + return at::_ops::linalg_ldl_solve_out::redispatch(dispatchKeySet, LD, pivots, B, hermitian, out); + } + + // aten::linalg_ldl_solve.out(Tensor LD, Tensor pivots, Tensor B, *, bool hermitian=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_ldl_solve_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & LD, const at::Tensor & pivots, const at::Tensor & B, bool hermitian, at::Tensor & out) { + return at::_ops::linalg_ldl_solve_out::redispatch(dispatchKeySet, LD, pivots, B, hermitian, out); + } + + // aten::linalg_lstsq(Tensor self, Tensor b, float? rcond=None, *, str? driver=None) -> (Tensor solution, Tensor residuals, Tensor rank, Tensor singular_values) + inline ::std::tuple linalg_lstsq(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & b, ::std::optional rcond=::std::nullopt, ::std::optional driver=::std::nullopt) { + return at::_ops::linalg_lstsq::redispatch(dispatchKeySet, self, b, rcond, driver); + } + + // aten::linalg_lstsq.out(Tensor self, Tensor b, float? rcond=None, *, str? driver=None, Tensor(a!) solution, Tensor(b!) residuals, Tensor(c!) rank, Tensor(d!) singular_values) -> (Tensor(a!) solution, Tensor(b!) residuals, Tensor(c!) rank, Tensor(d!) singular_values) + inline ::std::tuple linalg_lstsq_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & solution, at::Tensor & residuals, at::Tensor & rank, at::Tensor & singular_values, const at::Tensor & self, const at::Tensor & b, ::std::optional rcond=::std::nullopt, ::std::optional driver=::std::nullopt) { + return at::_ops::linalg_lstsq_out::redispatch(dispatchKeySet, self, b, rcond, driver, solution, residuals, rank, singular_values); + } + + // aten::linalg_lstsq.out(Tensor self, Tensor b, float? rcond=None, *, str? driver=None, Tensor(a!) solution, Tensor(b!) 
residuals, Tensor(c!) rank, Tensor(d!) singular_values) -> (Tensor(a!) solution, Tensor(b!) residuals, Tensor(c!) rank, Tensor(d!) singular_values) + inline ::std::tuple linalg_lstsq_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & b, ::std::optional rcond, ::std::optional driver, at::Tensor & solution, at::Tensor & residuals, at::Tensor & rank, at::Tensor & singular_values) { + return at::_ops::linalg_lstsq_out::redispatch(dispatchKeySet, self, b, rcond, driver, solution, residuals, rank, singular_values); + } + + // aten::linalg_matmul(Tensor self, Tensor other) -> Tensor + inline at::Tensor linalg_matmul(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::linalg_matmul::redispatch(dispatchKeySet, self, other); + } + + // aten::linalg_matmul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_matmul_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::linalg_matmul_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::linalg_matmul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_matmul_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::linalg_matmul_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::linalg_vecdot(Tensor x, Tensor y, *, int dim=-1) -> Tensor + inline at::Tensor linalg_vecdot(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Tensor & y, int64_t dim=-1) { + return at::_ops::linalg_vecdot::redispatch(dispatchKeySet, x, y, dim); + } + + // aten::linalg_vecdot.out(Tensor x, Tensor y, *, int dim=-1, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & linalg_vecdot_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x, const at::Tensor & y, int64_t dim=-1) { + return at::_ops::linalg_vecdot_out::redispatch(dispatchKeySet, x, y, dim, out); + } + + // aten::linalg_vecdot.out(Tensor x, Tensor y, *, int dim=-1, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_vecdot_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Tensor & y, int64_t dim, at::Tensor & out) { + return at::_ops::linalg_vecdot_out::redispatch(dispatchKeySet, x, y, dim, out); + } + + // aten::linalg_matrix_exp(Tensor self) -> Tensor + inline at::Tensor linalg_matrix_exp(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::linalg_matrix_exp::redispatch(dispatchKeySet, self); + } + + // aten::_linalg_slogdet(Tensor A) -> (Tensor sign, Tensor logabsdet, Tensor LU, Tensor pivots) + inline ::std::tuple _linalg_slogdet(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A) { + return at::_ops::_linalg_slogdet::redispatch(dispatchKeySet, A); + } + + // aten::_linalg_slogdet.sign(Tensor A, *, Tensor(a!) sign, Tensor(b!) logabsdet, Tensor(c!) LU, Tensor(d!) pivots) -> (Tensor(a!) sign, Tensor(b!) logabsdet, Tensor(c!) LU, Tensor(d!) pivots) + inline ::std::tuple _linalg_slogdet_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & sign, at::Tensor & logabsdet, at::Tensor & LU, at::Tensor & pivots, const at::Tensor & A) { + return at::_ops::_linalg_slogdet_sign::redispatch(dispatchKeySet, A, sign, logabsdet, LU, pivots); + } + + // aten::_linalg_slogdet.sign(Tensor A, *, Tensor(a!) sign, Tensor(b!) logabsdet, Tensor(c!) LU, Tensor(d!) pivots) -> (Tensor(a!) sign, Tensor(b!) logabsdet, Tensor(c!) LU, Tensor(d!) 
pivots) + inline ::std::tuple _linalg_slogdet_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, at::Tensor & sign, at::Tensor & logabsdet, at::Tensor & LU, at::Tensor & pivots) { + return at::_ops::_linalg_slogdet_sign::redispatch(dispatchKeySet, A, sign, logabsdet, LU, pivots); + } + + // aten::linalg_slogdet(Tensor A) -> (Tensor sign, Tensor logabsdet) + inline ::std::tuple linalg_slogdet(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A) { + return at::_ops::linalg_slogdet::redispatch(dispatchKeySet, A); + } + + // aten::linalg_slogdet.out(Tensor A, *, Tensor(a!) sign, Tensor(b!) logabsdet) -> (Tensor(a!) sign, Tensor(b!) logabsdet) + inline ::std::tuple linalg_slogdet_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & sign, at::Tensor & logabsdet, const at::Tensor & A) { + return at::_ops::linalg_slogdet_out::redispatch(dispatchKeySet, A, sign, logabsdet); + } + + // aten::linalg_slogdet.out(Tensor A, *, Tensor(a!) sign, Tensor(b!) logabsdet) -> (Tensor(a!) sign, Tensor(b!) logabsdet) + inline ::std::tuple linalg_slogdet_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, at::Tensor & sign, at::Tensor & logabsdet) { + return at::_ops::linalg_slogdet_out::redispatch(dispatchKeySet, A, sign, logabsdet); + } + + // aten::slogdet(Tensor self) -> (Tensor sign, Tensor logabsdet) + inline ::std::tuple slogdet(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::slogdet::redispatch(dispatchKeySet, self); + } + + // aten::slogdet.out(Tensor self, *, Tensor(a!) sign, Tensor(b!) logabsdet) -> (Tensor(a!) sign, Tensor(b!) logabsdet) + inline ::std::tuple slogdet_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & sign, at::Tensor & logabsdet, const at::Tensor & self) { + return at::_ops::slogdet_out::redispatch(dispatchKeySet, self, sign, logabsdet); + } + + // aten::slogdet.out(Tensor self, *, Tensor(a!) sign, Tensor(b!) logabsdet) -> (Tensor(a!) sign, Tensor(b!) 
logabsdet) + inline ::std::tuple slogdet_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & sign, at::Tensor & logabsdet) { + return at::_ops::slogdet_out::redispatch(dispatchKeySet, self, sign, logabsdet); + } + + // aten::logdet(Tensor self) -> Tensor + inline at::Tensor logdet(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::logdet::redispatch(dispatchKeySet, self); + } + + // aten::linalg_eig(Tensor self) -> (Tensor eigenvalues, Tensor eigenvectors) + inline ::std::tuple linalg_eig(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::linalg_eig::redispatch(dispatchKeySet, self); + } + + // aten::linalg_eig.out(Tensor self, *, Tensor(a!) eigenvalues, Tensor(b!) eigenvectors) -> (Tensor(a!) eigenvalues, Tensor(b!) eigenvectors) + inline ::std::tuple linalg_eig_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & eigenvalues, at::Tensor & eigenvectors, const at::Tensor & self) { + return at::_ops::linalg_eig_out::redispatch(dispatchKeySet, self, eigenvalues, eigenvectors); + } + + // aten::linalg_eig.out(Tensor self, *, Tensor(a!) eigenvalues, Tensor(b!) eigenvectors) -> (Tensor(a!) eigenvalues, Tensor(b!) eigenvectors) + inline ::std::tuple linalg_eig_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & eigenvalues, at::Tensor & eigenvectors) { + return at::_ops::linalg_eig_out::redispatch(dispatchKeySet, self, eigenvalues, eigenvectors); + } + + // aten::_linalg_eigvals(Tensor self) -> Tensor + inline at::Tensor _linalg_eigvals(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::_linalg_eigvals::redispatch(dispatchKeySet, self); + } + + // aten::linalg_eigvals(Tensor self) -> Tensor + inline at::Tensor linalg_eigvals(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::linalg_eigvals::redispatch(dispatchKeySet, self); + } + + // aten::linalg_eigvals.out(Tensor self, *, Tensor(a!) 
out) -> Tensor(a!) + inline at::Tensor & linalg_eigvals_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::linalg_eigvals_out::redispatch(dispatchKeySet, self, out); + } + + // aten::linalg_eigvals.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_eigvals_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::linalg_eigvals_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_linalg_eigh(Tensor A, str UPLO="L", bool compute_v=True) -> (Tensor eigenvalues, Tensor eigenvectors) + inline ::std::tuple _linalg_eigh(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, c10::string_view UPLO="L", bool compute_v=true) { + return at::_ops::_linalg_eigh::redispatch(dispatchKeySet, A, UPLO, compute_v); + } + + // aten::_linalg_eigh.eigenvalues(Tensor A, str UPLO="L", bool compute_v=True, *, Tensor(a!) eigenvalues, Tensor(b!) eigenvectors) -> (Tensor(a!) eigenvalues, Tensor(b!) eigenvectors) + inline ::std::tuple _linalg_eigh_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & eigenvalues, at::Tensor & eigenvectors, const at::Tensor & A, c10::string_view UPLO="L", bool compute_v=true) { + return at::_ops::_linalg_eigh_eigenvalues::redispatch(dispatchKeySet, A, UPLO, compute_v, eigenvalues, eigenvectors); + } + + // aten::_linalg_eigh.eigenvalues(Tensor A, str UPLO="L", bool compute_v=True, *, Tensor(a!) eigenvalues, Tensor(b!) eigenvectors) -> (Tensor(a!) eigenvalues, Tensor(b!) 
eigenvectors) + inline ::std::tuple _linalg_eigh_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, c10::string_view UPLO, bool compute_v, at::Tensor & eigenvalues, at::Tensor & eigenvectors) { + return at::_ops::_linalg_eigh_eigenvalues::redispatch(dispatchKeySet, A, UPLO, compute_v, eigenvalues, eigenvectors); + } + + // aten::linalg_eigh(Tensor self, str UPLO="L") -> (Tensor eigenvalues, Tensor eigenvectors) + inline ::std::tuple linalg_eigh(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::string_view UPLO="L") { + return at::_ops::linalg_eigh::redispatch(dispatchKeySet, self, UPLO); + } + + // aten::linalg_eigh.eigvals(Tensor self, str UPLO="L", *, Tensor(a!) eigvals, Tensor(b!) eigvecs) -> (Tensor(a!) eigenvalues, Tensor(b!) eigenvectors) + inline ::std::tuple linalg_eigh_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & eigvals, at::Tensor & eigvecs, const at::Tensor & self, c10::string_view UPLO="L") { + return at::_ops::linalg_eigh_eigvals::redispatch(dispatchKeySet, self, UPLO, eigvals, eigvecs); + } + + // aten::linalg_eigh.eigvals(Tensor self, str UPLO="L", *, Tensor(a!) eigvals, Tensor(b!) eigvecs) -> (Tensor(a!) eigenvalues, Tensor(b!) eigenvectors) + inline ::std::tuple linalg_eigh_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::string_view UPLO, at::Tensor & eigvals, at::Tensor & eigvecs) { + return at::_ops::linalg_eigh_eigvals::redispatch(dispatchKeySet, self, UPLO, eigvals, eigvecs); + } + + // aten::linalg_eigvalsh(Tensor self, str UPLO="L") -> Tensor + inline at::Tensor linalg_eigvalsh(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::string_view UPLO="L") { + return at::_ops::linalg_eigvalsh::redispatch(dispatchKeySet, self, UPLO); + } + + // aten::linalg_eigvalsh.out(Tensor self, str UPLO="L", *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & linalg_eigvalsh_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::string_view UPLO="L") { + return at::_ops::linalg_eigvalsh_out::redispatch(dispatchKeySet, self, UPLO, out); + } + + // aten::linalg_eigvalsh.out(Tensor self, str UPLO="L", *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_eigvalsh_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::string_view UPLO, at::Tensor & out) { + return at::_ops::linalg_eigvalsh_out::redispatch(dispatchKeySet, self, UPLO, out); + } + + // aten::linalg_householder_product(Tensor input, Tensor tau) -> Tensor + inline at::Tensor linalg_householder_product(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & tau) { + return at::_ops::linalg_householder_product::redispatch(dispatchKeySet, input, tau); + } + + // aten::linalg_householder_product.out(Tensor input, Tensor tau, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_householder_product_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & input, const at::Tensor & tau) { + return at::_ops::linalg_householder_product_out::redispatch(dispatchKeySet, input, tau, out); + } + + // aten::linalg_householder_product.out(Tensor input, Tensor tau, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_householder_product_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & tau, at::Tensor & out) { + return at::_ops::linalg_householder_product_out::redispatch(dispatchKeySet, input, tau, out); + } + + // aten::linalg_inv_ex(Tensor A, *, bool check_errors=False) -> (Tensor inverse, Tensor info) + inline ::std::tuple linalg_inv_ex(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, bool check_errors=false) { + return at::_ops::linalg_inv_ex::redispatch(dispatchKeySet, A, check_errors); + } + + // aten::linalg_inv_ex.inverse(Tensor A, *, bool check_errors=False, Tensor(a!) 
inverse, Tensor(b!) info) -> (Tensor(a!) inverse, Tensor(b!) info) + inline ::std::tuple linalg_inv_ex_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & inverse, at::Tensor & info, const at::Tensor & A, bool check_errors=false) { + return at::_ops::linalg_inv_ex_inverse::redispatch(dispatchKeySet, A, check_errors, inverse, info); + } + + // aten::linalg_inv_ex.inverse(Tensor A, *, bool check_errors=False, Tensor(a!) inverse, Tensor(b!) info) -> (Tensor(a!) inverse, Tensor(b!) info) + inline ::std::tuple linalg_inv_ex_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, bool check_errors, at::Tensor & inverse, at::Tensor & info) { + return at::_ops::linalg_inv_ex_inverse::redispatch(dispatchKeySet, A, check_errors, inverse, info); + } + + // aten::linalg_inv(Tensor A) -> Tensor + inline at::Tensor linalg_inv(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A) { + return at::_ops::linalg_inv::redispatch(dispatchKeySet, A); + } + + // aten::linalg_inv.out(Tensor A, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_inv_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & A) { + return at::_ops::linalg_inv_out::redispatch(dispatchKeySet, A, out); + } + + // aten::linalg_inv.out(Tensor A, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_inv_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, at::Tensor & out) { + return at::_ops::linalg_inv_out::redispatch(dispatchKeySet, A, out); + } + + // aten::inverse(Tensor self) -> Tensor + inline at::Tensor inverse(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::inverse::redispatch(dispatchKeySet, self); + } + + // aten::inverse.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & inverse_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::inverse_out::redispatch(dispatchKeySet, self, out); + } + + // aten::inverse.out(Tensor self, *, Tensor(a!) 
out) -> Tensor(a!) + inline at::Tensor & inverse_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::inverse_out::redispatch(dispatchKeySet, self, out); + } + + // aten::inner(Tensor self, Tensor other) -> Tensor + inline at::Tensor inner(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::inner::redispatch(dispatchKeySet, self, other); + } + + // aten::inner.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & inner_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::inner_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::inner.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & inner_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::inner_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::outer(Tensor self, Tensor vec2) -> Tensor + inline at::Tensor outer(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & vec2) { + return at::_ops::outer::redispatch(dispatchKeySet, self, vec2); + } + + // aten::outer.out(Tensor self, Tensor vec2, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & outer_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & vec2) { + return at::_ops::outer_out::redispatch(dispatchKeySet, self, vec2, out); + } + + // aten::outer.out(Tensor self, Tensor vec2, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & outer_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & vec2, at::Tensor & out) { + return at::_ops::outer_out::redispatch(dispatchKeySet, self, vec2, out); + } + + // aten::ger(Tensor self, Tensor vec2) -> Tensor + inline at::Tensor ger(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & vec2) { + return at::_ops::ger::redispatch(dispatchKeySet, self, vec2); + } + + // aten::ger.out(Tensor self, Tensor vec2, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & ger_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & vec2) { + return at::_ops::ger_out::redispatch(dispatchKeySet, self, vec2, out); + } + + // aten::ger.out(Tensor self, Tensor vec2, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & ger_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & vec2, at::Tensor & out) { + return at::_ops::ger_out::redispatch(dispatchKeySet, self, vec2, out); + } + + // aten::linalg_norm(Tensor self, Scalar? ord=None, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor + inline at::Tensor linalg_norm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const ::std::optional & ord=::std::nullopt, at::OptionalIntArrayRef dim=::std::nullopt, bool keepdim=false, ::std::optional dtype=::std::nullopt) { + return at::_ops::linalg_norm::redispatch(dispatchKeySet, self, ord, dim, keepdim, dtype); + } + + // aten::linalg_norm.ord_str(Tensor self, str ord, int[1]? dim=None, bool keepdim=False, *, ScalarType? 
dtype=None) -> Tensor + inline at::Tensor linalg_norm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::string_view ord, at::OptionalIntArrayRef dim=::std::nullopt, bool keepdim=false, ::std::optional dtype=::std::nullopt) { + return at::_ops::linalg_norm_ord_str::redispatch(dispatchKeySet, self, ord, dim, keepdim, dtype); + } + + // aten::linalg_norm.out(Tensor self, Scalar? ord=None, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_norm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const ::std::optional & ord=::std::nullopt, at::OptionalIntArrayRef dim=::std::nullopt, bool keepdim=false, ::std::optional dtype=::std::nullopt) { + return at::_ops::linalg_norm_out::redispatch(dispatchKeySet, self, ord, dim, keepdim, dtype, out); + } + + // aten::linalg_norm.out(Tensor self, Scalar? ord=None, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_norm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const ::std::optional & ord, at::OptionalIntArrayRef dim, bool keepdim, ::std::optional dtype, at::Tensor & out) { + return at::_ops::linalg_norm_out::redispatch(dispatchKeySet, self, ord, dim, keepdim, dtype, out); + } + + // aten::linalg_norm.ord_str_out(Tensor self, str ord, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_norm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::string_view ord, at::OptionalIntArrayRef dim=::std::nullopt, bool keepdim=false, ::std::optional dtype=::std::nullopt) { + return at::_ops::linalg_norm_ord_str_out::redispatch(dispatchKeySet, self, ord, dim, keepdim, dtype, out); + } + + // aten::linalg_norm.ord_str_out(Tensor self, str ord, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) 
out) -> Tensor(a!) + inline at::Tensor & linalg_norm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::string_view ord, at::OptionalIntArrayRef dim, bool keepdim, ::std::optional dtype, at::Tensor & out) { + return at::_ops::linalg_norm_ord_str_out::redispatch(dispatchKeySet, self, ord, dim, keepdim, dtype, out); + } + + // aten::linalg_vector_norm(Tensor self, Scalar ord=2, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor + inline at::Tensor linalg_vector_norm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & ord=2, at::OptionalIntArrayRef dim=::std::nullopt, bool keepdim=false, ::std::optional dtype=::std::nullopt) { + return at::_ops::linalg_vector_norm::redispatch(dispatchKeySet, self, ord, dim, keepdim, dtype); + } + + // aten::linalg_vector_norm.out(Tensor self, Scalar ord=2, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_vector_norm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & ord=2, at::OptionalIntArrayRef dim=::std::nullopt, bool keepdim=false, ::std::optional dtype=::std::nullopt) { + return at::_ops::linalg_vector_norm_out::redispatch(dispatchKeySet, self, ord, dim, keepdim, dtype, out); + } + + // aten::linalg_vector_norm.out(Tensor self, Scalar ord=2, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_vector_norm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & ord, at::OptionalIntArrayRef dim, bool keepdim, ::std::optional dtype, at::Tensor & out) { + return at::_ops::linalg_vector_norm_out::redispatch(dispatchKeySet, self, ord, dim, keepdim, dtype, out); + } + + // aten::linalg_matrix_norm(Tensor self, Scalar ord, int[] dim=[-2,-1], bool keepdim=False, *, ScalarType? 
dtype=None) -> Tensor + inline at::Tensor linalg_matrix_norm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & ord, at::IntArrayRef dim={-2,-1}, bool keepdim=false, ::std::optional dtype=::std::nullopt) { + return at::_ops::linalg_matrix_norm::redispatch(dispatchKeySet, self, ord, dim, keepdim, dtype); + } + + // aten::linalg_matrix_norm.out(Tensor self, Scalar ord, int[] dim=[-2,-1], bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_matrix_norm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & ord, at::IntArrayRef dim={-2,-1}, bool keepdim=false, ::std::optional dtype=::std::nullopt) { + return at::_ops::linalg_matrix_norm_out::redispatch(dispatchKeySet, self, ord, dim, keepdim, dtype, out); + } + + // aten::linalg_matrix_norm.out(Tensor self, Scalar ord, int[] dim=[-2,-1], bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_matrix_norm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & ord, at::IntArrayRef dim, bool keepdim, ::std::optional dtype, at::Tensor & out) { + return at::_ops::linalg_matrix_norm_out::redispatch(dispatchKeySet, self, ord, dim, keepdim, dtype, out); + } + + // aten::linalg_matrix_norm.str_ord(Tensor self, str ord='fro', int[] dim=[-2,-1], bool keepdim=False, *, ScalarType? dtype=None) -> Tensor + inline at::Tensor linalg_matrix_norm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::string_view ord="fro", at::IntArrayRef dim={-2,-1}, bool keepdim=false, ::std::optional dtype=::std::nullopt) { + return at::_ops::linalg_matrix_norm_str_ord::redispatch(dispatchKeySet, self, ord, dim, keepdim, dtype); + } + + // aten::linalg_matrix_norm.str_ord_out(Tensor self, str ord='fro', int[] dim=[-2,-1], bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & linalg_matrix_norm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::string_view ord="fro", at::IntArrayRef dim={-2,-1}, bool keepdim=false, ::std::optional dtype=::std::nullopt) { + return at::_ops::linalg_matrix_norm_str_ord_out::redispatch(dispatchKeySet, self, ord, dim, keepdim, dtype, out); + } + + // aten::linalg_matrix_norm.str_ord_out(Tensor self, str ord='fro', int[] dim=[-2,-1], bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_matrix_norm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::string_view ord, at::IntArrayRef dim, bool keepdim, ::std::optional dtype, at::Tensor & out) { + return at::_ops::linalg_matrix_norm_str_ord_out::redispatch(dispatchKeySet, self, ord, dim, keepdim, dtype, out); + } + + // aten::_linalg_svd(Tensor A, bool full_matrices=False, bool compute_uv=True, *, str? driver=None) -> (Tensor U, Tensor S, Tensor Vh) + inline ::std::tuple _linalg_svd(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, bool full_matrices=false, bool compute_uv=true, ::std::optional driver=::std::nullopt) { + return at::_ops::_linalg_svd::redispatch(dispatchKeySet, A, full_matrices, compute_uv, driver); + } + + // aten::_linalg_svd.U(Tensor A, bool full_matrices=False, bool compute_uv=True, *, str? driver=None, Tensor(a!) U, Tensor(b!) S, Tensor(c!) Vh) -> (Tensor(a!) U, Tensor(b!) S, Tensor(c!) Vh) + inline ::std::tuple _linalg_svd_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & U, at::Tensor & S, at::Tensor & Vh, const at::Tensor & A, bool full_matrices=false, bool compute_uv=true, ::std::optional driver=::std::nullopt) { + return at::_ops::_linalg_svd_U::redispatch(dispatchKeySet, A, full_matrices, compute_uv, driver, U, S, Vh); + } + + // aten::_linalg_svd.U(Tensor A, bool full_matrices=False, bool compute_uv=True, *, str? driver=None, Tensor(a!) U, Tensor(b!) S, Tensor(c!) Vh) -> (Tensor(a!) U, Tensor(b!) 
S, Tensor(c!) Vh) + inline ::std::tuple _linalg_svd_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, bool full_matrices, bool compute_uv, ::std::optional driver, at::Tensor & U, at::Tensor & S, at::Tensor & Vh) { + return at::_ops::_linalg_svd_U::redispatch(dispatchKeySet, A, full_matrices, compute_uv, driver, U, S, Vh); + } + + // aten::linalg_svd(Tensor A, bool full_matrices=True, *, str? driver=None) -> (Tensor U, Tensor S, Tensor Vh) + inline ::std::tuple linalg_svd(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, bool full_matrices=true, ::std::optional driver=::std::nullopt) { + return at::_ops::linalg_svd::redispatch(dispatchKeySet, A, full_matrices, driver); + } + + // aten::linalg_svd.U(Tensor A, bool full_matrices=True, *, str? driver=None, Tensor(a!) U, Tensor(b!) S, Tensor(c!) Vh) -> (Tensor(a!) U, Tensor(b!) S, Tensor(c!) Vh) + inline ::std::tuple linalg_svd_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & U, at::Tensor & S, at::Tensor & Vh, const at::Tensor & A, bool full_matrices=true, ::std::optional driver=::std::nullopt) { + return at::_ops::linalg_svd_U::redispatch(dispatchKeySet, A, full_matrices, driver, U, S, Vh); + } + + // aten::linalg_svd.U(Tensor A, bool full_matrices=True, *, str? driver=None, Tensor(a!) U, Tensor(b!) S, Tensor(c!) Vh) -> (Tensor(a!) U, Tensor(b!) S, Tensor(c!) Vh) + inline ::std::tuple linalg_svd_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, bool full_matrices, ::std::optional driver, at::Tensor & U, at::Tensor & S, at::Tensor & Vh) { + return at::_ops::linalg_svd_U::redispatch(dispatchKeySet, A, full_matrices, driver, U, S, Vh); + } + + // aten::linalg_svdvals(Tensor A, *, str? driver=None) -> Tensor + inline at::Tensor linalg_svdvals(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, ::std::optional driver=::std::nullopt) { + return at::_ops::linalg_svdvals::redispatch(dispatchKeySet, A, driver); + } + + // aten::linalg_svdvals.out(Tensor A, *, str? 
driver=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_svdvals_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & A, ::std::optional driver=::std::nullopt) { + return at::_ops::linalg_svdvals_out::redispatch(dispatchKeySet, A, driver, out); + } + + // aten::linalg_svdvals.out(Tensor A, *, str? driver=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_svdvals_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, ::std::optional driver, at::Tensor & out) { + return at::_ops::linalg_svdvals_out::redispatch(dispatchKeySet, A, driver, out); + } + + // aten::linalg_cond(Tensor self, Scalar? p=None) -> Tensor + inline at::Tensor linalg_cond(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const ::std::optional & p=::std::nullopt) { + return at::_ops::linalg_cond::redispatch(dispatchKeySet, self, p); + } + + // aten::linalg_cond.out(Tensor self, Scalar? p=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_cond_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const ::std::optional & p=::std::nullopt) { + return at::_ops::linalg_cond_out::redispatch(dispatchKeySet, self, p, out); + } + + // aten::linalg_cond.out(Tensor self, Scalar? p=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_cond_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const ::std::optional & p, at::Tensor & out) { + return at::_ops::linalg_cond_out::redispatch(dispatchKeySet, self, p, out); + } + + // aten::linalg_cond.p_str(Tensor self, str p) -> Tensor + inline at::Tensor linalg_cond(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::string_view p) { + return at::_ops::linalg_cond_p_str::redispatch(dispatchKeySet, self, p); + } + + // aten::linalg_cond.p_str_out(Tensor self, str p, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & linalg_cond_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::string_view p) { + return at::_ops::linalg_cond_p_str_out::redispatch(dispatchKeySet, self, p, out); + } + + // aten::linalg_cond.p_str_out(Tensor self, str p, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_cond_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::string_view p, at::Tensor & out) { + return at::_ops::linalg_cond_p_str_out::redispatch(dispatchKeySet, self, p, out); + } + + // aten::linalg_pinv.atol_rtol_tensor(Tensor self, *, Tensor? atol=None, Tensor? rtol=None, bool hermitian=False) -> Tensor + inline at::Tensor linalg_pinv(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const ::std::optional & atol={}, const ::std::optional & rtol={}, bool hermitian=false) { + return at::_ops::linalg_pinv_atol_rtol_tensor::redispatch(dispatchKeySet, self, atol, rtol, hermitian); + } + + // aten::linalg_pinv.atol_rtol_tensor_out(Tensor self, *, Tensor? atol=None, Tensor? rtol=None, bool hermitian=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_pinv_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const ::std::optional & atol={}, const ::std::optional & rtol={}, bool hermitian=false) { + return at::_ops::linalg_pinv_atol_rtol_tensor_out::redispatch(dispatchKeySet, self, atol, rtol, hermitian, out); + } + + // aten::linalg_pinv.atol_rtol_tensor_out(Tensor self, *, Tensor? atol=None, Tensor? rtol=None, bool hermitian=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_pinv_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const ::std::optional & atol, const ::std::optional & rtol, bool hermitian, at::Tensor & out) { + return at::_ops::linalg_pinv_atol_rtol_tensor_out::redispatch(dispatchKeySet, self, atol, rtol, hermitian, out); + } + + // aten::linalg_pinv.atol_rtol_float(Tensor self, *, float? atol=None, float? 
rtol=None, bool hermitian=False) -> Tensor + inline at::Tensor linalg_pinv(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional atol, ::std::optional rtol, bool hermitian=false) { + return at::_ops::linalg_pinv_atol_rtol_float::redispatch(dispatchKeySet, self, atol, rtol, hermitian); + } + + // aten::linalg_pinv.atol_rtol_float_out(Tensor self, *, float? atol=None, float? rtol=None, bool hermitian=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_pinv_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, ::std::optional atol, ::std::optional rtol, bool hermitian=false) { + return at::_ops::linalg_pinv_atol_rtol_float_out::redispatch(dispatchKeySet, self, atol, rtol, hermitian, out); + } + + // aten::linalg_pinv.atol_rtol_float_out(Tensor self, *, float? atol=None, float? rtol=None, bool hermitian=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_pinv_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional atol, ::std::optional rtol, bool hermitian, at::Tensor & out) { + return at::_ops::linalg_pinv_atol_rtol_float_out::redispatch(dispatchKeySet, self, atol, rtol, hermitian, out); + } + + // aten::linalg_pinv(Tensor self, float rcond, bool hermitian=False) -> Tensor + inline at::Tensor linalg_pinv(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double rcond, bool hermitian=false) { + return at::_ops::linalg_pinv::redispatch(dispatchKeySet, self, rcond, hermitian); + } + + // aten::linalg_pinv.rcond_tensor(Tensor self, Tensor rcond, bool hermitian=False) -> Tensor + inline at::Tensor linalg_pinv(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & rcond, bool hermitian=false) { + return at::_ops::linalg_pinv_rcond_tensor::redispatch(dispatchKeySet, self, rcond, hermitian); + } + + // aten::linalg_pinv.out(Tensor self, float rcond, bool hermitian=False, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & linalg_pinv_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, double rcond, bool hermitian=false) { + return at::_ops::linalg_pinv_out::redispatch(dispatchKeySet, self, rcond, hermitian, out); + } + + // aten::linalg_pinv.out(Tensor self, float rcond, bool hermitian=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_pinv_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double rcond, bool hermitian, at::Tensor & out) { + return at::_ops::linalg_pinv_out::redispatch(dispatchKeySet, self, rcond, hermitian, out); + } + + // aten::linalg_pinv.out_rcond_tensor(Tensor self, Tensor rcond, bool hermitian=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_pinv_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & rcond, bool hermitian=false) { + return at::_ops::linalg_pinv_out_rcond_tensor::redispatch(dispatchKeySet, self, rcond, hermitian, out); + } + + // aten::linalg_pinv.out_rcond_tensor(Tensor self, Tensor rcond, bool hermitian=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_pinv_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & rcond, bool hermitian, at::Tensor & out) { + return at::_ops::linalg_pinv_out_rcond_tensor::redispatch(dispatchKeySet, self, rcond, hermitian, out); + } + + // aten::_linalg_solve_ex(Tensor A, Tensor B, *, bool left=True, bool check_errors=False) -> (Tensor result, Tensor LU, Tensor pivots, Tensor info) + inline ::std::tuple _linalg_solve_ex(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, const at::Tensor & B, bool left=true, bool check_errors=false) { + return at::_ops::_linalg_solve_ex::redispatch(dispatchKeySet, A, B, left, check_errors); + } + + // aten::_linalg_solve_ex.result(Tensor A, Tensor B, *, bool left=True, bool check_errors=False, Tensor(a!) result, Tensor(b!) LU, Tensor(c!) pivots, Tensor(d!) 
info) -> (Tensor(a!) result, Tensor(b!) LU, Tensor(c!) pivots, Tensor(d!) info) + inline ::std::tuple _linalg_solve_ex_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & result, at::Tensor & LU, at::Tensor & pivots, at::Tensor & info, const at::Tensor & A, const at::Tensor & B, bool left=true, bool check_errors=false) { + return at::_ops::_linalg_solve_ex_result::redispatch(dispatchKeySet, A, B, left, check_errors, result, LU, pivots, info); + } + + // aten::_linalg_solve_ex.result(Tensor A, Tensor B, *, bool left=True, bool check_errors=False, Tensor(a!) result, Tensor(b!) LU, Tensor(c!) pivots, Tensor(d!) info) -> (Tensor(a!) result, Tensor(b!) LU, Tensor(c!) pivots, Tensor(d!) info) + inline ::std::tuple _linalg_solve_ex_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, const at::Tensor & B, bool left, bool check_errors, at::Tensor & result, at::Tensor & LU, at::Tensor & pivots, at::Tensor & info) { + return at::_ops::_linalg_solve_ex_result::redispatch(dispatchKeySet, A, B, left, check_errors, result, LU, pivots, info); + } + + // aten::linalg_solve_ex(Tensor A, Tensor B, *, bool left=True, bool check_errors=False) -> (Tensor result, Tensor info) + inline ::std::tuple linalg_solve_ex(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, const at::Tensor & B, bool left=true, bool check_errors=false) { + return at::_ops::linalg_solve_ex::redispatch(dispatchKeySet, A, B, left, check_errors); + } + + // aten::linalg_solve_ex.out(Tensor A, Tensor B, *, bool left=True, bool check_errors=False, Tensor(a!) result, Tensor(b!) info) -> (Tensor(a!) result, Tensor(b!) 
info) + inline ::std::tuple linalg_solve_ex_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & result, at::Tensor & info, const at::Tensor & A, const at::Tensor & B, bool left=true, bool check_errors=false) { + return at::_ops::linalg_solve_ex_out::redispatch(dispatchKeySet, A, B, left, check_errors, result, info); + } + + // aten::linalg_solve_ex.out(Tensor A, Tensor B, *, bool left=True, bool check_errors=False, Tensor(a!) result, Tensor(b!) info) -> (Tensor(a!) result, Tensor(b!) info) + inline ::std::tuple linalg_solve_ex_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, const at::Tensor & B, bool left, bool check_errors, at::Tensor & result, at::Tensor & info) { + return at::_ops::linalg_solve_ex_out::redispatch(dispatchKeySet, A, B, left, check_errors, result, info); + } + + // aten::linalg_solve(Tensor A, Tensor B, *, bool left=True) -> Tensor + inline at::Tensor linalg_solve(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, const at::Tensor & B, bool left=true) { + return at::_ops::linalg_solve::redispatch(dispatchKeySet, A, B, left); + } + + // aten::_spsolve(Tensor A, Tensor B, *, bool left=True) -> Tensor + inline at::Tensor _spsolve(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, const at::Tensor & B, bool left=true) { + return at::_ops::_spsolve::redispatch(dispatchKeySet, A, B, left); + } + + // aten::linalg_solve.out(Tensor A, Tensor B, *, bool left=True, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_solve_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & A, const at::Tensor & B, bool left=true) { + return at::_ops::linalg_solve_out::redispatch(dispatchKeySet, A, B, left, out); + } + + // aten::linalg_solve.out(Tensor A, Tensor B, *, bool left=True, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & linalg_solve_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, const at::Tensor & B, bool left, at::Tensor & out) { + return at::_ops::linalg_solve_out::redispatch(dispatchKeySet, A, B, left, out); + } + + // aten::linalg_tensorinv(Tensor self, int ind=2) -> Tensor + inline at::Tensor linalg_tensorinv(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t ind=2) { + return at::_ops::linalg_tensorinv::redispatch(dispatchKeySet, self, ind); + } + + // aten::linalg_tensorinv.out(Tensor self, int ind=2, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_tensorinv_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t ind=2) { + return at::_ops::linalg_tensorinv_out::redispatch(dispatchKeySet, self, ind, out); + } + + // aten::linalg_tensorinv.out(Tensor self, int ind=2, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_tensorinv_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t ind, at::Tensor & out) { + return at::_ops::linalg_tensorinv_out::redispatch(dispatchKeySet, self, ind, out); + } + + // aten::linalg_tensorsolve(Tensor self, Tensor other, int[]? dims=None) -> Tensor + inline at::Tensor linalg_tensorsolve(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::OptionalIntArrayRef dims=::std::nullopt) { + return at::_ops::linalg_tensorsolve::redispatch(dispatchKeySet, self, other, dims); + } + + // aten::linalg_tensorsolve.out(Tensor self, Tensor other, int[]? dims=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_tensorsolve_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other, at::OptionalIntArrayRef dims=::std::nullopt) { + return at::_ops::linalg_tensorsolve_out::redispatch(dispatchKeySet, self, other, dims, out); + } + + // aten::linalg_tensorsolve.out(Tensor self, Tensor other, int[]? dims=None, *, Tensor(a!) 
out) -> Tensor(a!) + inline at::Tensor & linalg_tensorsolve_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::OptionalIntArrayRef dims, at::Tensor & out) { + return at::_ops::linalg_tensorsolve_out::redispatch(dispatchKeySet, self, other, dims, out); + } + + // aten::linalg_qr(Tensor A, str mode='reduced') -> (Tensor Q, Tensor R) + inline ::std::tuple linalg_qr(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, c10::string_view mode="reduced") { + return at::_ops::linalg_qr::redispatch(dispatchKeySet, A, mode); + } + + // aten::linalg_qr.out(Tensor A, str mode='reduced', *, Tensor(a!) Q, Tensor(b!) R) -> (Tensor(a!) Q, Tensor(b!) R) + inline ::std::tuple linalg_qr_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & Q, at::Tensor & R, const at::Tensor & A, c10::string_view mode="reduced") { + return at::_ops::linalg_qr_out::redispatch(dispatchKeySet, A, mode, Q, R); + } + + // aten::linalg_qr.out(Tensor A, str mode='reduced', *, Tensor(a!) Q, Tensor(b!) R) -> (Tensor(a!) Q, Tensor(b!) R) + inline ::std::tuple linalg_qr_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, c10::string_view mode, at::Tensor & Q, at::Tensor & R) { + return at::_ops::linalg_qr_out::redispatch(dispatchKeySet, A, mode, Q, R); + } + + // aten::linalg_matrix_power(Tensor self, int n) -> Tensor + inline at::Tensor linalg_matrix_power(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t n) { + return at::_ops::linalg_matrix_power::redispatch(dispatchKeySet, self, n); + } + + // aten::linalg_matrix_power.out(Tensor self, int n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_matrix_power_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t n) { + return at::_ops::linalg_matrix_power_out::redispatch(dispatchKeySet, self, n, out); + } + + // aten::linalg_matrix_power.out(Tensor self, int n, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & linalg_matrix_power_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t n, at::Tensor & out) { + return at::_ops::linalg_matrix_power_out::redispatch(dispatchKeySet, self, n, out); + } + + // aten::linalg_matrix_rank.atol_rtol_tensor(Tensor input, *, Tensor? atol=None, Tensor? rtol=None, bool hermitian=False) -> Tensor + inline at::Tensor linalg_matrix_rank(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const ::std::optional & atol={}, const ::std::optional & rtol={}, bool hermitian=false) { + return at::_ops::linalg_matrix_rank_atol_rtol_tensor::redispatch(dispatchKeySet, input, atol, rtol, hermitian); + } + + // aten::linalg_matrix_rank.atol_rtol_tensor_out(Tensor input, *, Tensor? atol=None, Tensor? rtol=None, bool hermitian=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_matrix_rank_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & input, const ::std::optional & atol={}, const ::std::optional & rtol={}, bool hermitian=false) { + return at::_ops::linalg_matrix_rank_atol_rtol_tensor_out::redispatch(dispatchKeySet, input, atol, rtol, hermitian, out); + } + + // aten::linalg_matrix_rank.atol_rtol_tensor_out(Tensor input, *, Tensor? atol=None, Tensor? rtol=None, bool hermitian=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_matrix_rank_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const ::std::optional & atol, const ::std::optional & rtol, bool hermitian, at::Tensor & out) { + return at::_ops::linalg_matrix_rank_atol_rtol_tensor_out::redispatch(dispatchKeySet, input, atol, rtol, hermitian, out); + } + + // aten::linalg_matrix_rank.atol_rtol_float(Tensor self, *, float? atol=None, float? 
rtol=None, bool hermitian=False) -> Tensor + inline at::Tensor linalg_matrix_rank(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional atol, ::std::optional rtol, bool hermitian=false) { + return at::_ops::linalg_matrix_rank_atol_rtol_float::redispatch(dispatchKeySet, self, atol, rtol, hermitian); + } + + // aten::linalg_matrix_rank.atol_rtol_float_out(Tensor self, *, float? atol=None, float? rtol=None, bool hermitian=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_matrix_rank_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, ::std::optional atol, ::std::optional rtol, bool hermitian=false) { + return at::_ops::linalg_matrix_rank_atol_rtol_float_out::redispatch(dispatchKeySet, self, atol, rtol, hermitian, out); + } + + // aten::linalg_matrix_rank.atol_rtol_float_out(Tensor self, *, float? atol=None, float? rtol=None, bool hermitian=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_matrix_rank_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional atol, ::std::optional rtol, bool hermitian, at::Tensor & out) { + return at::_ops::linalg_matrix_rank_atol_rtol_float_out::redispatch(dispatchKeySet, self, atol, rtol, hermitian, out); + } + + // aten::linalg_matrix_rank(Tensor self, float tol, bool hermitian=False) -> Tensor + inline at::Tensor linalg_matrix_rank(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double tol, bool hermitian=false) { + return at::_ops::linalg_matrix_rank::redispatch(dispatchKeySet, self, tol, hermitian); + } + + // aten::linalg_matrix_rank.out(Tensor self, float tol, bool hermitian=False, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & linalg_matrix_rank_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, double tol, bool hermitian=false) { + return at::_ops::linalg_matrix_rank_out::redispatch(dispatchKeySet, self, tol, hermitian, out); + } + + // aten::linalg_matrix_rank.out(Tensor self, float tol, bool hermitian=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_matrix_rank_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double tol, bool hermitian, at::Tensor & out) { + return at::_ops::linalg_matrix_rank_out::redispatch(dispatchKeySet, self, tol, hermitian, out); + } + + // aten::linalg_matrix_rank.tol_tensor(Tensor input, Tensor tol, bool hermitian=False) -> Tensor + inline at::Tensor linalg_matrix_rank(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & tol, bool hermitian=false) { + return at::_ops::linalg_matrix_rank_tol_tensor::redispatch(dispatchKeySet, input, tol, hermitian); + } + + // aten::linalg_matrix_rank.out_tol_tensor(Tensor input, Tensor tol, bool hermitian=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_matrix_rank_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & input, const at::Tensor & tol, bool hermitian=false) { + return at::_ops::linalg_matrix_rank_out_tol_tensor::redispatch(dispatchKeySet, input, tol, hermitian, out); + } + + // aten::linalg_matrix_rank.out_tol_tensor(Tensor input, Tensor tol, bool hermitian=False, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & linalg_matrix_rank_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & tol, bool hermitian, at::Tensor & out) { + return at::_ops::linalg_matrix_rank_out_tol_tensor::redispatch(dispatchKeySet, input, tol, hermitian, out); + } + + // aten::linalg_multi_dot(Tensor[] tensors) -> Tensor + inline at::Tensor linalg_multi_dot(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors) { + return at::_ops::linalg_multi_dot::redispatch(dispatchKeySet, tensors); + } + + // aten::linalg_multi_dot.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_multi_dot_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::TensorList tensors) { + return at::_ops::linalg_multi_dot_out::redispatch(dispatchKeySet, tensors, out); + } + + // aten::linalg_multi_dot.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_multi_dot_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors, at::Tensor & out) { + return at::_ops::linalg_multi_dot_out::redispatch(dispatchKeySet, tensors, out); + } + + // aten::nested_to_padded_tensor(Tensor self, float padding, int[]? 
output_size=None) -> Tensor + inline at::Tensor nested_to_padded_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double padding, at::OptionalIntArrayRef output_size=::std::nullopt) { + return at::_ops::nested_to_padded_tensor::redispatch(dispatchKeySet, self, padding, output_size); + } + + // aten::_test_serialization_subcmul(Tensor self, Tensor other, Scalar alpha=1) -> Tensor + inline at::Tensor _test_serialization_subcmul(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha=1) { + return at::_ops::_test_serialization_subcmul::redispatch(dispatchKeySet, self, other, alpha); + } + + // aten::_test_parallel_materialize(Tensor self, int num_parallel, bool skip_first=False) -> Tensor + inline at::Tensor _test_parallel_materialize(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t num_parallel, bool skip_first=false) { + return at::_ops::_test_parallel_materialize::redispatch(dispatchKeySet, self, num_parallel, skip_first); + } + + // aten::_test_optional_intlist(Tensor values, int[]? addends) -> Tensor + inline at::Tensor _test_optional_intlist(c10::DispatchKeySet dispatchKeySet, const at::Tensor & values, at::OptionalIntArrayRef addends) { + return at::_ops::_test_optional_intlist::redispatch(dispatchKeySet, values, addends); + } + + // aten::_test_optional_filled_intlist(Tensor values, int[2]? addends) -> Tensor + inline at::Tensor _test_optional_filled_intlist(c10::DispatchKeySet dispatchKeySet, const at::Tensor & values, at::OptionalIntArrayRef addends) { + return at::_ops::_test_optional_filled_intlist::redispatch(dispatchKeySet, values, addends); + } + + // aten::_test_optional_floatlist(Tensor values, float[]? 
addends) -> Tensor + inline at::Tensor _test_optional_floatlist(c10::DispatchKeySet dispatchKeySet, const at::Tensor & values, ::std::optional> addends) { + return at::_ops::_test_optional_floatlist::redispatch(dispatchKeySet, values, addends); + } + + // aten::_test_string_default(Tensor dummy, str a="\"'\\", str b='"\'\\') -> Tensor + inline at::Tensor _test_string_default(c10::DispatchKeySet dispatchKeySet, const at::Tensor & dummy, c10::string_view a="\"'\\", c10::string_view b="\"'\\") { + return at::_ops::_test_string_default::redispatch(dispatchKeySet, dummy, a, b); + } + + // aten::_test_ambiguous_defaults.a(Tensor dummy, int a=1, int b=1) -> Tensor + inline at::Tensor _test_ambiguous_defaults(c10::DispatchKeySet dispatchKeySet, const at::Tensor & dummy, int64_t a=1, int64_t b=1) { + return at::_ops::_test_ambiguous_defaults_a::redispatch(dispatchKeySet, dummy, a, b); + } + + // aten::_test_ambiguous_defaults.b(Tensor dummy, int a=2, str b="2") -> Tensor + inline at::Tensor _test_ambiguous_defaults(c10::DispatchKeySet dispatchKeySet, const at::Tensor & dummy, int64_t a, c10::string_view b) { + return at::_ops::_test_ambiguous_defaults_b::redispatch(dispatchKeySet, dummy, a, b); + } + + // aten::_test_warn_in_autograd(Tensor self) -> Tensor + inline at::Tensor _test_warn_in_autograd(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::_test_warn_in_autograd::redispatch(dispatchKeySet, self); + } + + // aten::_test_autograd_multiple_dispatch.fullcoverage(Tensor self) -> Tensor + inline at::Tensor _test_autograd_multiple_dispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::_test_autograd_multiple_dispatch_fullcoverage::redispatch(dispatchKeySet, self); + } + + // aten::_test_autograd_multiple_dispatch.ntonly(Tensor self, bool b) -> Tensor + inline at::Tensor _test_autograd_multiple_dispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool b) { + return 
at::_ops::_test_autograd_multiple_dispatch_ntonly::redispatch(dispatchKeySet, self, b); + } + + // aten::_test_autograd_multiple_dispatch_view(Tensor(a) self) -> Tensor(a) + inline at::Tensor _test_autograd_multiple_dispatch_view(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::_test_autograd_multiple_dispatch_view::redispatch(dispatchKeySet, self); + } + + // aten::_test_autograd_multiple_dispatch_view_copy(Tensor self) -> Tensor + inline at::Tensor _test_autograd_multiple_dispatch_view_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::_test_autograd_multiple_dispatch_view_copy::redispatch(dispatchKeySet, self); + } + + // aten::segment_reduce(Tensor data, str reduce, *, Tensor? lengths=None, Tensor? indices=None, Tensor? offsets=None, int axis=0, bool unsafe=False, Scalar? initial=None) -> Tensor + inline at::Tensor segment_reduce(c10::DispatchKeySet dispatchKeySet, const at::Tensor & data, c10::string_view reduce, const ::std::optional & lengths={}, const ::std::optional & indices={}, const ::std::optional & offsets={}, int64_t axis=0, bool unsafe=false, const ::std::optional & initial=::std::nullopt) { + return at::_ops::segment_reduce::redispatch(dispatchKeySet, data, reduce, lengths, indices, offsets, axis, unsafe, initial); + } + + // aten::_segment_reduce_backward(Tensor grad, Tensor output, Tensor data, str reduce, *, Tensor? lengths=None, Tensor? offsets=None, int axis=0, Scalar? 
initial=None) -> Tensor + inline at::Tensor _segment_reduce_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & output, const at::Tensor & data, c10::string_view reduce, const ::std::optional & lengths={}, const ::std::optional & offsets={}, int64_t axis=0, const ::std::optional & initial=::std::nullopt) { + return at::_ops::_segment_reduce_backward::redispatch(dispatchKeySet, grad, output, data, reduce, lengths, offsets, axis, initial); + } + + // aten::pad_sequence(Tensor[] sequences, bool batch_first=False, float padding_value=0.0, str padding_side="right") -> Tensor + inline at::Tensor pad_sequence(c10::DispatchKeySet dispatchKeySet, at::TensorList sequences, bool batch_first=false, double padding_value=0.0, c10::string_view padding_side="right") { + return at::_ops::pad_sequence::redispatch(dispatchKeySet, sequences, batch_first, padding_value, padding_side); + } + + // aten::flatten_dense_tensors(Tensor[] tensors) -> Tensor + inline at::Tensor flatten_dense_tensors(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors) { + return at::_ops::flatten_dense_tensors::redispatch(dispatchKeySet, tensors); + } + + // aten::unflatten_dense_tensors(Tensor flat, Tensor[] tensors) -> Tensor[] + inline ::std::vector unflatten_dense_tensors(c10::DispatchKeySet dispatchKeySet, const at::Tensor & flat, at::TensorList tensors) { + return at::_ops::unflatten_dense_tensors::redispatch(dispatchKeySet, flat, tensors); + } + + // aten::_nested_tensor_from_tensor_list(Tensor[] list, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor + inline at::Tensor _nested_tensor_from_tensor_list(c10::DispatchKeySet dispatchKeySet, at::TensorList list, ::std::optional dtype=::std::nullopt, ::std::optional layout=::std::nullopt, ::std::optional device=::std::nullopt, ::std::optional pin_memory=::std::nullopt) { + return at::_ops::_nested_tensor_from_tensor_list::redispatch(dispatchKeySet, list, dtype, layout, device, pin_memory); + } + + // aten::_fw_primal_copy(Tensor self, int level) -> Tensor + inline at::Tensor _fw_primal_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t level) { + return at::_ops::_fw_primal_copy::redispatch(dispatchKeySet, self, level); + } + + // aten::_make_dual_copy(Tensor primal, Tensor tangent, int level) -> Tensor + inline at::Tensor _make_dual_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & primal, const at::Tensor & tangent, int64_t level) { + return at::_ops::_make_dual_copy::redispatch(dispatchKeySet, primal, tangent, level); + } + + // aten::view_as_real_copy(Tensor self) -> Tensor + inline at::Tensor view_as_real_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::view_as_real_copy::redispatch(dispatchKeySet, self); + } + + // aten::view_as_complex_copy(Tensor self) -> Tensor + inline at::Tensor view_as_complex_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::view_as_complex_copy::redispatch(dispatchKeySet, self); + } + + // aten::_conj_copy(Tensor self) -> Tensor + inline at::Tensor _conj_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::_conj_copy::redispatch(dispatchKeySet, self); + } + + // aten::_neg_view_copy(Tensor self) -> Tensor + inline at::Tensor _neg_view_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::_neg_view_copy::redispatch(dispatchKeySet, self); + } + + // aten::as_strided_copy(Tensor self, SymInt[] size, SymInt[] stride, SymInt? 
storage_offset=None) -> Tensor + inline at::Tensor as_strided_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, at::IntArrayRef stride, ::std::optional storage_offset=::std::nullopt) { + return at::_ops::as_strided_copy::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), storage_offset.has_value() ? ::std::make_optional(c10::SymInt(*storage_offset)) : ::std::nullopt); + } + + // aten::as_strided_copy(Tensor self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor + inline at::Tensor as_strided_copy_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset=::std::nullopt) { + return at::_ops::as_strided_copy::redispatch(dispatchKeySet, self, size, stride, storage_offset); + } + + // aten::_sparse_broadcast_to_copy(Tensor self, int[] size) -> Tensor + inline at::Tensor _sparse_broadcast_to_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size) { + return at::_ops::_sparse_broadcast_to_copy::redispatch(dispatchKeySet, self, size); + } + + // aten::diagonal_copy(Tensor self, int offset=0, int dim1=0, int dim2=1) -> Tensor + inline at::Tensor diagonal_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t offset=0, int64_t dim1=0, int64_t dim2=1) { + return at::_ops::diagonal_copy::redispatch(dispatchKeySet, self, offset, dim1, dim2); + } + + // aten::expand_copy(Tensor self, SymInt[] size, *, bool implicit=False) -> Tensor + inline at::Tensor expand_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, bool implicit=false) { + return at::_ops::expand_copy::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), implicit); + } + + // aten::expand_copy(Tensor self, SymInt[] size, *, bool implicit=False) -> Tensor + inline at::Tensor expand_copy_symint(c10::DispatchKeySet 
dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, bool implicit=false) { + return at::_ops::expand_copy::redispatch(dispatchKeySet, self, size, implicit); + } + + // aten::permute_copy(Tensor self, int[] dims) -> Tensor + inline at::Tensor permute_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dims) { + return at::_ops::permute_copy::redispatch(dispatchKeySet, self, dims); + } + + // aten::_reshape_alias_copy(Tensor self, SymInt[] size, SymInt[] stride) -> Tensor + inline at::Tensor _reshape_alias_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, at::IntArrayRef stride) { + return at::_ops::_reshape_alias_copy::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride)); + } + + // aten::_reshape_alias_copy(Tensor self, SymInt[] size, SymInt[] stride) -> Tensor + inline at::Tensor _reshape_alias_copy_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride) { + return at::_ops::_reshape_alias_copy::redispatch(dispatchKeySet, self, size, stride); + } + + // aten::select_copy.int(Tensor self, int dim, SymInt index) -> Tensor + inline at::Tensor select_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, int64_t index) { + return at::_ops::select_copy_int::redispatch(dispatchKeySet, self, dim, index); + } + + // aten::select_copy.int(Tensor self, int dim, SymInt index) -> Tensor + inline at::Tensor select_copy_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, c10::SymInt index) { + return at::_ops::select_copy_int::redispatch(dispatchKeySet, self, dim, index); + } + + // aten::detach_copy(Tensor self) -> Tensor + inline at::Tensor detach_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::detach_copy::redispatch(dispatchKeySet, self); + } + + // aten::slice_copy.Tensor(Tensor self, int 
dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor + inline at::Tensor slice_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim=0, ::std::optional start=::std::nullopt, ::std::optional end=::std::nullopt, int64_t step=1) { + return at::_ops::slice_copy_Tensor::redispatch(dispatchKeySet, self, dim, start.has_value() ? ::std::make_optional(c10::SymInt(*start)) : ::std::nullopt, end.has_value() ? ::std::make_optional(c10::SymInt(*end)) : ::std::nullopt, step); + } + + // aten::slice_copy.Tensor(Tensor self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor + inline at::Tensor slice_copy_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim=0, ::std::optional start=::std::nullopt, ::std::optional end=::std::nullopt, c10::SymInt step=1) { + return at::_ops::slice_copy_Tensor::redispatch(dispatchKeySet, self, dim, start, end, step); + } + + // aten::split_copy.Tensor(Tensor self, SymInt split_size, int dim=0) -> Tensor[] + inline ::std::vector split_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t split_size, int64_t dim=0) { + return at::_ops::split_copy_Tensor::redispatch(dispatchKeySet, self, split_size, dim); + } + + // aten::split_copy.Tensor(Tensor self, SymInt split_size, int dim=0) -> Tensor[] + inline ::std::vector split_copy_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymInt split_size, int64_t dim=0) { + return at::_ops::split_copy_Tensor::redispatch(dispatchKeySet, self, split_size, dim); + } + + // aten::split_with_sizes_copy(Tensor self, SymInt[] split_sizes, int dim=0) -> Tensor[] + inline ::std::vector split_with_sizes_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef split_sizes, int64_t dim=0) { + return at::_ops::split_with_sizes_copy::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(split_sizes), dim); + } + + // aten::split_with_sizes_copy(Tensor self, SymInt[] 
split_sizes, int dim=0) -> Tensor[] + inline ::std::vector split_with_sizes_copy_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef split_sizes, int64_t dim=0) { + return at::_ops::split_with_sizes_copy::redispatch(dispatchKeySet, self, split_sizes, dim); + } + + // aten::squeeze_copy(Tensor self) -> Tensor + inline at::Tensor squeeze_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::squeeze_copy::redispatch(dispatchKeySet, self); + } + + // aten::squeeze_copy.dim(Tensor self, int dim) -> Tensor + inline at::Tensor squeeze_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim) { + return at::_ops::squeeze_copy_dim::redispatch(dispatchKeySet, self, dim); + } + + // aten::squeeze_copy.dims(Tensor self, int[] dim) -> Tensor + inline at::Tensor squeeze_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim) { + return at::_ops::squeeze_copy_dims::redispatch(dispatchKeySet, self, dim); + } + + // aten::t_copy(Tensor self) -> Tensor + inline at::Tensor t_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::t_copy::redispatch(dispatchKeySet, self); + } + + // aten::transpose_copy.int(Tensor self, int dim0, int dim1) -> Tensor + inline at::Tensor transpose_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim0, int64_t dim1) { + return at::_ops::transpose_copy_int::redispatch(dispatchKeySet, self, dim0, dim1); + } + + // aten::unsqueeze_copy(Tensor self, int dim) -> Tensor + inline at::Tensor unsqueeze_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim) { + return at::_ops::unsqueeze_copy::redispatch(dispatchKeySet, self, dim); + } + + // aten::_indices_copy(Tensor self) -> Tensor + inline at::Tensor _indices_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::_indices_copy::redispatch(dispatchKeySet, self); + } + + // 
aten::_values_copy(Tensor self) -> Tensor + inline at::Tensor _values_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::_values_copy::redispatch(dispatchKeySet, self); + } + + // aten::indices_copy(Tensor self) -> Tensor + inline at::Tensor indices_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::indices_copy::redispatch(dispatchKeySet, self); + } + + // aten::values_copy(Tensor self) -> Tensor + inline at::Tensor values_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::values_copy::redispatch(dispatchKeySet, self); + } + + // aten::crow_indices_copy(Tensor self) -> Tensor + inline at::Tensor crow_indices_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::crow_indices_copy::redispatch(dispatchKeySet, self); + } + + // aten::col_indices_copy(Tensor self) -> Tensor + inline at::Tensor col_indices_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::col_indices_copy::redispatch(dispatchKeySet, self); + } + + // aten::ccol_indices_copy(Tensor self) -> Tensor + inline at::Tensor ccol_indices_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::ccol_indices_copy::redispatch(dispatchKeySet, self); + } + + // aten::row_indices_copy(Tensor self) -> Tensor + inline at::Tensor row_indices_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::row_indices_copy::redispatch(dispatchKeySet, self); + } + + // aten::unbind_copy.int(Tensor self, int dim=0) -> Tensor[] + inline ::std::vector unbind_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim=0) { + return at::_ops::unbind_copy_int::redispatch(dispatchKeySet, self, dim); + } + + // aten::unbind_copy.int_out(Tensor self, int dim=0, *, Tensor(a!)[] out) -> () + inline void unbind_copy_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, const at::Tensor 
& self, int64_t dim=0) { + return at::_ops::unbind_copy_int_out::redispatch(dispatchKeySet, self, dim, out); + } + + // aten::unbind_copy.int_out(Tensor self, int dim=0, *, Tensor(a!)[] out) -> () + inline void unbind_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, at::TensorList out) { + return at::_ops::unbind_copy_int_out::redispatch(dispatchKeySet, self, dim, out); + } + + // aten::split_copy.Tensor_out(Tensor self, SymInt split_size, int dim=0, *, Tensor(a!)[] out) -> () + inline void split_copy_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, const at::Tensor & self, int64_t split_size, int64_t dim=0) { + return at::_ops::split_copy_Tensor_out::redispatch(dispatchKeySet, self, split_size, dim, out); + } + + // aten::split_copy.Tensor_out(Tensor self, SymInt split_size, int dim=0, *, Tensor(a!)[] out) -> () + inline void split_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t split_size, int64_t dim, at::TensorList out) { + return at::_ops::split_copy_Tensor_out::redispatch(dispatchKeySet, self, split_size, dim, out); + } + + // aten::split_copy.Tensor_out(Tensor self, SymInt split_size, int dim=0, *, Tensor(a!)[] out) -> () + inline void split_copy_symint_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, const at::Tensor & self, c10::SymInt split_size, int64_t dim=0) { + return at::_ops::split_copy_Tensor_out::redispatch(dispatchKeySet, self, split_size, dim, out); + } + + // aten::split_copy.Tensor_out(Tensor self, SymInt split_size, int dim=0, *, Tensor(a!)[] out) -> () + inline void split_copy_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymInt split_size, int64_t dim, at::TensorList out) { + return at::_ops::split_copy_Tensor_out::redispatch(dispatchKeySet, self, split_size, dim, out); + } + + // aten::split_with_sizes_copy.out(Tensor self, SymInt[] split_sizes, int dim=0, *, Tensor(a!)[] out) -> () + inline void 
split_with_sizes_copy_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, const at::Tensor & self, at::IntArrayRef split_sizes, int64_t dim=0) { + return at::_ops::split_with_sizes_copy_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(split_sizes), dim, out); + } + + // aten::split_with_sizes_copy.out(Tensor self, SymInt[] split_sizes, int dim=0, *, Tensor(a!)[] out) -> () + inline void split_with_sizes_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef split_sizes, int64_t dim, at::TensorList out) { + return at::_ops::split_with_sizes_copy_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(split_sizes), dim, out); + } + + // aten::split_with_sizes_copy.out(Tensor self, SymInt[] split_sizes, int dim=0, *, Tensor(a!)[] out) -> () + inline void split_with_sizes_copy_symint_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, const at::Tensor & self, c10::SymIntArrayRef split_sizes, int64_t dim=0) { + return at::_ops::split_with_sizes_copy_out::redispatch(dispatchKeySet, self, split_sizes, dim, out); + } + + // aten::split_with_sizes_copy.out(Tensor self, SymInt[] split_sizes, int dim=0, *, Tensor(a!)[] out) -> () + inline void split_with_sizes_copy_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef split_sizes, int64_t dim, at::TensorList out) { + return at::_ops::split_with_sizes_copy_out::redispatch(dispatchKeySet, self, split_sizes, dim, out); + } + + // aten::view_copy(Tensor self, SymInt[] size) -> Tensor + inline at::Tensor view_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size) { + return at::_ops::view_copy::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size)); + } + + // aten::view_copy(Tensor self, SymInt[] size) -> Tensor + inline at::Tensor view_copy_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size) { + return 
at::_ops::view_copy::redispatch(dispatchKeySet, self, size); + } + + // aten::view_copy.dtype(Tensor self, ScalarType dtype) -> Tensor + inline at::Tensor view_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::ScalarType dtype) { + return at::_ops::view_copy_dtype::redispatch(dispatchKeySet, self, dtype); + } + + // aten::unfold_copy(Tensor self, int dimension, int size, int step) -> Tensor + inline at::Tensor unfold_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dimension, int64_t size, int64_t step) { + return at::_ops::unfold_copy::redispatch(dispatchKeySet, self, dimension, size, step); + } + + // aten::alias_copy(Tensor self) -> Tensor + inline at::Tensor alias_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::alias_copy::redispatch(dispatchKeySet, self); + } + + // aten::to_padded_tensor(Tensor self, float padding, SymInt[]? output_size=None) -> Tensor + inline at::Tensor to_padded_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double padding, at::OptionalIntArrayRef output_size=::std::nullopt) { + return at::_ops::to_padded_tensor::redispatch(dispatchKeySet, self, padding, output_size.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*output_size)) : ::std::nullopt); + } + + // aten::to_padded_tensor(Tensor self, float padding, SymInt[]? 
output_size=None) -> Tensor + inline at::Tensor to_padded_tensor_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double padding, at::OptionalSymIntArrayRef output_size=::std::nullopt) { + return at::_ops::to_padded_tensor::redispatch(dispatchKeySet, self, padding, output_size); + } + + // aten::_jagged_to_padded_dense_forward(Tensor values, Tensor[] offsets, SymInt[] max_lengths, float padding_value=0.0) -> Tensor + inline at::Tensor _jagged_to_padded_dense_forward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & values, at::TensorList offsets, at::IntArrayRef max_lengths, double padding_value=0.0) { + return at::_ops::_jagged_to_padded_dense_forward::redispatch(dispatchKeySet, values, offsets, c10::fromIntArrayRefSlow(max_lengths), padding_value); + } + + // aten::_jagged_to_padded_dense_forward(Tensor values, Tensor[] offsets, SymInt[] max_lengths, float padding_value=0.0) -> Tensor + inline at::Tensor _jagged_to_padded_dense_forward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & values, at::TensorList offsets, c10::SymIntArrayRef max_lengths, double padding_value=0.0) { + return at::_ops::_jagged_to_padded_dense_forward::redispatch(dispatchKeySet, values, offsets, max_lengths, padding_value); + } + + // aten::_padded_dense_to_jagged_forward(Tensor dense, Tensor[] offsets, SymInt? total_L=None) -> Tensor + inline at::Tensor _padded_dense_to_jagged_forward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & dense, at::TensorList offsets, ::std::optional total_L=::std::nullopt) { + return at::_ops::_padded_dense_to_jagged_forward::redispatch(dispatchKeySet, dense, offsets, total_L.has_value() ? ::std::make_optional(c10::SymInt(*total_L)) : ::std::nullopt); + } + + // aten::_padded_dense_to_jagged_forward(Tensor dense, Tensor[] offsets, SymInt? 
total_L=None) -> Tensor + inline at::Tensor _padded_dense_to_jagged_forward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & dense, at::TensorList offsets, ::std::optional total_L=::std::nullopt) { + return at::_ops::_padded_dense_to_jagged_forward::redispatch(dispatchKeySet, dense, offsets, total_L); + } + + // aten::_nested_from_padded_tensor(Tensor padded, Tensor offsets, Tensor dummy, int ragged_idx=1, Tensor? min_seqlen=None, Tensor? max_seqlen=None, SymInt? sum_S=None) -> Tensor + inline at::Tensor _nested_from_padded_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & padded, const at::Tensor & offsets, const at::Tensor & dummy, int64_t ragged_idx=1, const ::std::optional & min_seqlen={}, const ::std::optional & max_seqlen={}, ::std::optional sum_S=::std::nullopt) { + return at::_ops::_nested_from_padded_tensor::redispatch(dispatchKeySet, padded, offsets, dummy, ragged_idx, min_seqlen, max_seqlen, sum_S.has_value() ? ::std::make_optional(c10::SymInt(*sum_S)) : ::std::nullopt); + } + + // aten::_nested_from_padded_tensor(Tensor padded, Tensor offsets, Tensor dummy, int ragged_idx=1, Tensor? min_seqlen=None, Tensor? max_seqlen=None, SymInt? 
sum_S=None) -> Tensor + inline at::Tensor _nested_from_padded_tensor_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & padded, const at::Tensor & offsets, const at::Tensor & dummy, int64_t ragged_idx=1, const ::std::optional & min_seqlen={}, const ::std::optional & max_seqlen={}, ::std::optional sum_S=::std::nullopt) { + return at::_ops::_nested_from_padded_tensor::redispatch(dispatchKeySet, padded, offsets, dummy, ragged_idx, min_seqlen, max_seqlen, sum_S); + } + + // aten::_nested_tensor_softmax_with_shape(Tensor self, Tensor query) -> Tensor + inline at::Tensor _nested_tensor_softmax_with_shape(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & query) { + return at::_ops::_nested_tensor_softmax_with_shape::redispatch(dispatchKeySet, self, query); + } + + // aten::_safe_softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor + inline at::Tensor _safe_softmax(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, ::std::optional dtype=::std::nullopt) { + return at::_ops::_safe_softmax::redispatch(dispatchKeySet, self, dim, dtype); + } + + // aten::_transformer_encoder_layer_fwd(Tensor src, int embed_dim, int num_heads, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, bool use_gelu, bool norm_first, float eps, Tensor norm_weight_1, Tensor norm_bias_1, Tensor norm_weight_2, Tensor norm_bias_2, Tensor ffn_weight_1, Tensor ffn_bias_1, Tensor ffn_weight_2, Tensor ffn_bias_2, Tensor? mask=None, int? 
mask_type=None) -> Tensor + inline at::Tensor _transformer_encoder_layer_fwd(c10::DispatchKeySet dispatchKeySet, const at::Tensor & src, int64_t embed_dim, int64_t num_heads, const at::Tensor & qkv_weight, const at::Tensor & qkv_bias, const at::Tensor & proj_weight, const at::Tensor & proj_bias, bool use_gelu, bool norm_first, double eps, const at::Tensor & norm_weight_1, const at::Tensor & norm_bias_1, const at::Tensor & norm_weight_2, const at::Tensor & norm_bias_2, const at::Tensor & ffn_weight_1, const at::Tensor & ffn_bias_1, const at::Tensor & ffn_weight_2, const at::Tensor & ffn_bias_2, const ::std::optional & mask={}, ::std::optional mask_type=::std::nullopt) { + return at::_ops::_transformer_encoder_layer_fwd::redispatch(dispatchKeySet, src, embed_dim, num_heads, qkv_weight, qkv_bias, proj_weight, proj_bias, use_gelu, norm_first, eps, norm_weight_1, norm_bias_1, norm_weight_2, norm_bias_2, ffn_weight_1, ffn_bias_1, ffn_weight_2, ffn_bias_2, mask, mask_type); + } + + // aten::_native_multi_head_attention(Tensor query, Tensor key, Tensor value, int embed_dim, int num_head, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, Tensor? mask=None, bool need_weights=True, bool average_attn_weights=True, int? 
mask_type=None) -> (Tensor, Tensor) + inline ::std::tuple _native_multi_head_attention(c10::DispatchKeySet dispatchKeySet, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, int64_t embed_dim, int64_t num_head, const at::Tensor & qkv_weight, const at::Tensor & qkv_bias, const at::Tensor & proj_weight, const at::Tensor & proj_bias, const ::std::optional & mask={}, bool need_weights=true, bool average_attn_weights=true, ::std::optional mask_type=::std::nullopt) { + return at::_ops::_native_multi_head_attention::redispatch(dispatchKeySet, query, key, value, embed_dim, num_head, qkv_weight, qkv_bias, proj_weight, proj_bias, mask, need_weights, average_attn_weights, mask_type); + } + + // aten::scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, *, float? scale=None, bool enable_gqa=False) -> Tensor + inline at::Tensor scaled_dot_product_attention(c10::DispatchKeySet dispatchKeySet, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const ::std::optional & attn_mask={}, double dropout_p=0.0, bool is_causal=false, ::std::optional scale=::std::nullopt, bool enable_gqa=false) { + return at::_ops::scaled_dot_product_attention::redispatch(dispatchKeySet, query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa); + } + + // aten::_fused_sdp_choice(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, *, float? 
scale=None, bool enable_gqa=False) -> int + inline int64_t _fused_sdp_choice(c10::DispatchKeySet dispatchKeySet, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const ::std::optional & attn_mask={}, double dropout_p=0.0, bool is_causal=false, ::std::optional scale=::std::nullopt, bool enable_gqa=false) { + return at::_ops::_fused_sdp_choice::redispatch(dispatchKeySet, query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa); + } + + // aten::_scaled_dot_product_attention_math(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, Tensor? dropout_mask=None, *, float? scale=None, bool enable_gqa=False) -> (Tensor, Tensor) + inline ::std::tuple _scaled_dot_product_attention_math(c10::DispatchKeySet dispatchKeySet, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const ::std::optional & attn_mask={}, double dropout_p=0.0, bool is_causal=false, const ::std::optional & dropout_mask={}, ::std::optional scale=::std::nullopt, bool enable_gqa=false) { + return at::_ops::_scaled_dot_product_attention_math::redispatch(dispatchKeySet, query, key, value, attn_mask, dropout_p, is_causal, dropout_mask, scale, enable_gqa); + } + + // aten::_scaled_dot_product_attention_math_for_mps(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, Tensor? dropout_mask=None, *, float? 
scale=None) -> (Tensor, Tensor) + inline ::std::tuple _scaled_dot_product_attention_math_for_mps(c10::DispatchKeySet dispatchKeySet, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const ::std::optional & attn_mask={}, double dropout_p=0.0, bool is_causal=false, const ::std::optional & dropout_mask={}, ::std::optional scale=::std::nullopt) { + return at::_ops::_scaled_dot_product_attention_math_for_mps::redispatch(dispatchKeySet, query, key, value, attn_mask, dropout_p, is_causal, dropout_mask, scale); + } + + // aten::_scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor rng_state, Tensor unused, Tensor debug_attn_mask) + inline ::std::tuple _scaled_dot_product_flash_attention(c10::DispatchKeySet dispatchKeySet, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, double dropout_p=0.0, bool is_causal=false, bool return_debug_mask=false, ::std::optional scale=::std::nullopt) { + return at::_ops::_scaled_dot_product_flash_attention::redispatch(dispatchKeySet, query, key, value, dropout_p, is_causal, return_debug_mask, scale); + } + + // aten::_scaled_dot_product_flash_attention_for_cpu(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, *, Tensor? attn_mask=None, float? 
scale=None) -> (Tensor output, Tensor logsumexp) + inline ::std::tuple _scaled_dot_product_flash_attention_for_cpu(c10::DispatchKeySet dispatchKeySet, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, double dropout_p=0.0, bool is_causal=false, const ::std::optional & attn_mask={}, ::std::optional scale=::std::nullopt) { + return at::_ops::_scaled_dot_product_flash_attention_for_cpu::redispatch(dispatchKeySet, query, key, value, dropout_p, is_causal, attn_mask, scale); + } + + // aten::_scaled_dot_product_fused_attention_overrideable(Tensor query, Tensor key, Tensor value, Tensor? attn_bias=None, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask) + inline ::std::tuple _scaled_dot_product_fused_attention_overrideable(c10::DispatchKeySet dispatchKeySet, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const ::std::optional & attn_bias={}, double dropout_p=0.0, bool is_causal=false, bool return_debug_mask=false, ::std::optional scale=::std::nullopt) { + return at::_ops::_scaled_dot_product_fused_attention_overrideable::redispatch(dispatchKeySet, query, key, value, attn_bias, dropout_p, is_causal, return_debug_mask, scale); + } + + // aten::_scaled_dot_product_flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, Tensor philox_seed, Tensor philox_offset, *, float? 
scale=None) -> (Tensor grad_query, Tensor grad_key, Tensor grad_value) + inline ::std::tuple _scaled_dot_product_flash_attention_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const at::Tensor & out, const at::Tensor & logsumexp, const at::Tensor & cum_seq_q, const at::Tensor & cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, bool is_causal, const at::Tensor & philox_seed, const at::Tensor & philox_offset, ::std::optional scale=::std::nullopt) { + return at::_ops::_scaled_dot_product_flash_attention_backward::redispatch(dispatchKeySet, grad_out, query, key, value, out, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, philox_seed, philox_offset, scale); + } + + // aten::_scaled_dot_product_flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, Tensor philox_seed, Tensor philox_offset, *, float? 
scale=None) -> (Tensor grad_query, Tensor grad_key, Tensor grad_value) + inline ::std::tuple _scaled_dot_product_flash_attention_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const at::Tensor & out, const at::Tensor & logsumexp, const at::Tensor & cum_seq_q, const at::Tensor & cum_seq_k, c10::SymInt max_q, c10::SymInt max_k, double dropout_p, bool is_causal, const at::Tensor & philox_seed, const at::Tensor & philox_offset, ::std::optional scale=::std::nullopt) { + return at::_ops::_scaled_dot_product_flash_attention_backward::redispatch(dispatchKeySet, grad_out, query, key, value, out, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, philox_seed, philox_offset, scale); + } + + // aten::_scaled_dot_product_flash_attention_for_cpu_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, float dropout_p, bool is_causal, *, Tensor? attn_mask=None, float? 
scale=None) -> (Tensor grad_query, Tensor grad_key, Tensor grad_value) + inline ::std::tuple _scaled_dot_product_flash_attention_for_cpu_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const at::Tensor & out, const at::Tensor & logsumexp, double dropout_p, bool is_causal, const ::std::optional & attn_mask={}, ::std::optional scale=::std::nullopt) { + return at::_ops::_scaled_dot_product_flash_attention_for_cpu_backward::redispatch(dispatchKeySet, grad_out, query, key, value, out, logsumexp, dropout_p, is_causal, attn_mask, scale); + } + + // aten::_scaled_dot_product_fused_attention_overrideable_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor attn_bias, bool[4] grad_input_mask, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, Tensor philox_seed, Tensor philox_offset, *, float? scale=None) -> (Tensor grad_query, Tensor grad_key, Tensor grad_value, Tensor grad_attn_bias) + inline ::std::tuple _scaled_dot_product_fused_attention_overrideable_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const at::Tensor & attn_bias, ::std::array grad_input_mask, const at::Tensor & out, const at::Tensor & logsumexp, const at::Tensor & cum_seq_q, const at::Tensor & cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, bool is_causal, const at::Tensor & philox_seed, const at::Tensor & philox_offset, ::std::optional scale=::std::nullopt) { + return at::_ops::_scaled_dot_product_fused_attention_overrideable_backward::redispatch(dispatchKeySet, grad_out, query, key, value, attn_bias, grad_input_mask, out, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, philox_seed, philox_offset, scale); + } + + // 
aten::_scaled_dot_product_fused_attention_overrideable_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor attn_bias, bool[4] grad_input_mask, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, Tensor philox_seed, Tensor philox_offset, *, float? scale=None) -> (Tensor grad_query, Tensor grad_key, Tensor grad_value, Tensor grad_attn_bias) + inline ::std::tuple _scaled_dot_product_fused_attention_overrideable_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const at::Tensor & attn_bias, ::std::array grad_input_mask, const at::Tensor & out, const at::Tensor & logsumexp, const at::Tensor & cum_seq_q, const at::Tensor & cum_seq_k, c10::SymInt max_q, c10::SymInt max_k, double dropout_p, bool is_causal, const at::Tensor & philox_seed, const at::Tensor & philox_offset, ::std::optional scale=::std::nullopt) { + return at::_ops::_scaled_dot_product_fused_attention_overrideable_backward::redispatch(dispatchKeySet, grad_out, query, key, value, attn_bias, grad_input_mask, out, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, philox_seed, philox_offset, scale); + } + + // aten::_scaled_dot_product_efficient_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, *, float? 
scale=None) -> (Tensor output, Tensor log_sumexp, Tensor philox_seed, Tensor philox_offset) + inline ::std::tuple _scaled_dot_product_efficient_attention(c10::DispatchKeySet dispatchKeySet, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const ::std::optional & attn_bias, bool compute_log_sumexp, double dropout_p=0.0, bool is_causal=false, ::std::optional scale=::std::nullopt) { + return at::_ops::_scaled_dot_product_efficient_attention::redispatch(dispatchKeySet, query, key, value, attn_bias, compute_log_sumexp, dropout_p, is_causal, scale); + } + + // aten::_scaled_dot_product_efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, Tensor attn_bias, Tensor out, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, float dropout_p, bool[4] grad_input_mask, bool is_causal=False, *, float? scale=None) -> (Tensor, Tensor, Tensor, Tensor) + inline ::std::tuple _scaled_dot_product_efficient_attention_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out_, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const at::Tensor & attn_bias, const at::Tensor & out, const at::Tensor & logsumexp, const at::Tensor & philox_seed, const at::Tensor & philox_offset, double dropout_p, ::std::array grad_input_mask, bool is_causal=false, ::std::optional scale=::std::nullopt) { + return at::_ops::_scaled_dot_product_efficient_attention_backward::redispatch(dispatchKeySet, grad_out_, query, key, value, attn_bias, out, logsumexp, philox_seed, philox_offset, dropout_p, grad_input_mask, is_causal, scale); + } + + // aten::_scaled_dot_product_cudnn_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? 
scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask) + inline ::std::tuple _scaled_dot_product_cudnn_attention(c10::DispatchKeySet dispatchKeySet, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const ::std::optional & attn_bias, bool compute_log_sumexp, double dropout_p=0.0, bool is_causal=false, bool return_debug_mask=false, ::std::optional scale=::std::nullopt) { + return at::_ops::_scaled_dot_product_cudnn_attention::redispatch(dispatchKeySet, query, key, value, attn_bias, compute_log_sumexp, dropout_p, is_causal, return_debug_mask, scale); + } + + // aten::_scaled_dot_product_cudnn_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor attn_bias, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, *, float? 
scale=None) -> (Tensor, Tensor, Tensor) + inline ::std::tuple _scaled_dot_product_cudnn_attention_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const at::Tensor & out, const at::Tensor & logsumexp, const at::Tensor & philox_seed, const at::Tensor & philox_offset, const at::Tensor & attn_bias, const at::Tensor & cum_seq_q, const at::Tensor & cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, bool is_causal, ::std::optional scale=::std::nullopt) { + return at::_ops::_scaled_dot_product_cudnn_attention_backward::redispatch(dispatchKeySet, grad_out, query, key, value, out, logsumexp, philox_seed, philox_offset, attn_bias, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, scale); + } + + // aten::_scaled_dot_product_cudnn_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor attn_bias, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, *, float? 
scale=None) -> (Tensor, Tensor, Tensor) + inline ::std::tuple _scaled_dot_product_cudnn_attention_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const at::Tensor & out, const at::Tensor & logsumexp, const at::Tensor & philox_seed, const at::Tensor & philox_offset, const at::Tensor & attn_bias, const at::Tensor & cum_seq_q, const at::Tensor & cum_seq_k, c10::SymInt max_q, c10::SymInt max_k, double dropout_p, bool is_causal, ::std::optional scale=::std::nullopt) { + return at::_ops::_scaled_dot_product_cudnn_attention_backward::redispatch(dispatchKeySet, grad_out, query, key, value, out, logsumexp, philox_seed, philox_offset, attn_bias, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, scale); + } + + // aten::_flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None, SymInt? window_size_left=None, SymInt? window_size_right=None, Tensor? seqused_k=None, Tensor? 
alibi_slopes=None) -> (Tensor output, Tensor softmax_logsumexp, Tensor rng_state, Tensor unused, Tensor debug_attn_mask) + inline ::std::tuple _flash_attention_forward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const ::std::optional & cum_seq_q, const ::std::optional & cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, bool is_causal, bool return_debug_mask, ::std::optional scale=::std::nullopt, ::std::optional window_size_left=::std::nullopt, ::std::optional window_size_right=::std::nullopt, const ::std::optional & seqused_k={}, const ::std::optional & alibi_slopes={}) { + return at::_ops::_flash_attention_forward::redispatch(dispatchKeySet, query, key, value, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, return_debug_mask, scale, window_size_left.has_value() ? ::std::make_optional(c10::SymInt(*window_size_left)) : ::std::nullopt, window_size_right.has_value() ? ::std::make_optional(c10::SymInt(*window_size_right)) : ::std::nullopt, seqused_k, alibi_slopes); + } + + // aten::_flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None, SymInt? window_size_left=None, SymInt? window_size_right=None, Tensor? seqused_k=None, Tensor? 
alibi_slopes=None) -> (Tensor output, Tensor softmax_logsumexp, Tensor rng_state, Tensor unused, Tensor debug_attn_mask) + inline ::std::tuple _flash_attention_forward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const ::std::optional & cum_seq_q, const ::std::optional & cum_seq_k, c10::SymInt max_q, c10::SymInt max_k, double dropout_p, bool is_causal, bool return_debug_mask, ::std::optional scale=::std::nullopt, ::std::optional window_size_left=::std::nullopt, ::std::optional window_size_right=::std::nullopt, const ::std::optional & seqused_k={}, const ::std::optional & alibi_slopes={}) { + return at::_ops::_flash_attention_forward::redispatch(dispatchKeySet, query, key, value, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, return_debug_mask, scale, window_size_left, window_size_right, seqused_k, alibi_slopes); + } + + // aten::_flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, Tensor rng_state, Tensor unused, *, float? scale=None, SymInt? window_size_left=None, SymInt? 
window_size_right=None) -> (Tensor, Tensor, Tensor) + inline ::std::tuple _flash_attention_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const at::Tensor & out, const at::Tensor & logsumexp, const at::Tensor & cum_seq_q, const at::Tensor & cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, bool is_causal, const at::Tensor & rng_state, const at::Tensor & unused, ::std::optional scale=::std::nullopt, ::std::optional window_size_left=::std::nullopt, ::std::optional window_size_right=::std::nullopt) { + return at::_ops::_flash_attention_backward::redispatch(dispatchKeySet, grad_out, query, key, value, out, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, rng_state, unused, scale, window_size_left.has_value() ? ::std::make_optional(c10::SymInt(*window_size_left)) : ::std::nullopt, window_size_right.has_value() ? ::std::make_optional(c10::SymInt(*window_size_right)) : ::std::nullopt); + } + + // aten::_flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, Tensor rng_state, Tensor unused, *, float? scale=None, SymInt? window_size_left=None, SymInt? 
window_size_right=None) -> (Tensor, Tensor, Tensor) + inline ::std::tuple _flash_attention_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const at::Tensor & out, const at::Tensor & logsumexp, const at::Tensor & cum_seq_q, const at::Tensor & cum_seq_k, c10::SymInt max_q, c10::SymInt max_k, double dropout_p, bool is_causal, const at::Tensor & rng_state, const at::Tensor & unused, ::std::optional scale=::std::nullopt, ::std::optional window_size_left=::std::nullopt, ::std::optional window_size_right=::std::nullopt) { + return at::_ops::_flash_attention_backward::redispatch(dispatchKeySet, grad_out, query, key, value, out, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, rng_state, unused, scale, window_size_left, window_size_right); + } + + // aten::_efficient_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? bias, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, SymInt? max_seqlen_q, SymInt? max_seqlen_k, float dropout_p, int custom_mask_type, bool compute_log_sumexp=False, *, float? scale=None, Tensor? seqlen_k=None, int? 
window_size=None) -> (Tensor output, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, SymInt max_seqlen_batch_q, SymInt max_seqlen_batch_k) + inline ::std::tuple _efficient_attention_forward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const ::std::optional & bias, const ::std::optional & cu_seqlens_q, const ::std::optional & cu_seqlens_k, ::std::optional max_seqlen_q, ::std::optional max_seqlen_k, double dropout_p, int64_t custom_mask_type, bool compute_log_sumexp=false, ::std::optional scale=::std::nullopt, const ::std::optional & seqlen_k={}, ::std::optional window_size=::std::nullopt) { + return at::_ops::_efficient_attention_forward::redispatch(dispatchKeySet, query, key, value, bias, cu_seqlens_q, cu_seqlens_k, max_seqlen_q.has_value() ? ::std::make_optional(c10::SymInt(*max_seqlen_q)) : ::std::nullopt, max_seqlen_k.has_value() ? ::std::make_optional(c10::SymInt(*max_seqlen_k)) : ::std::nullopt, dropout_p, custom_mask_type, compute_log_sumexp, scale, seqlen_k, window_size); + } + + // aten::_efficient_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? bias, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, SymInt? max_seqlen_q, SymInt? max_seqlen_k, float dropout_p, int custom_mask_type, bool compute_log_sumexp=False, *, float? scale=None, Tensor? seqlen_k=None, int? 
window_size=None) -> (Tensor output, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, SymInt max_seqlen_batch_q, SymInt max_seqlen_batch_k) + inline ::std::tuple _efficient_attention_forward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const ::std::optional & bias, const ::std::optional & cu_seqlens_q, const ::std::optional & cu_seqlens_k, ::std::optional max_seqlen_q, ::std::optional max_seqlen_k, double dropout_p, int64_t custom_mask_type, bool compute_log_sumexp=false, ::std::optional scale=::std::nullopt, const ::std::optional & seqlen_k={}, ::std::optional window_size=::std::nullopt) { + return at::_ops::_efficient_attention_forward::redispatch(dispatchKeySet, query, key, value, bias, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, custom_mask_type, compute_log_sumexp, scale, seqlen_k, window_size); + } + + // aten::_efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, Tensor? bias, Tensor out, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, SymInt max_seqlen_q, SymInt max_seqlen_k, Tensor logsumexp, float dropout_p, Tensor philox_seed, Tensor philox_offset, int custom_mask_type, bool bias_requires_grad, *, float? scale=None, int? num_splits_key=None, int? 
window_size=None, bool shared_storage_dqdkdv=False) -> (Tensor, Tensor, Tensor, Tensor) + inline ::std::tuple _efficient_attention_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out_, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const ::std::optional & bias, const at::Tensor & out, const ::std::optional & cu_seqlens_q, const ::std::optional & cu_seqlens_k, int64_t max_seqlen_q, int64_t max_seqlen_k, const at::Tensor & logsumexp, double dropout_p, const at::Tensor & philox_seed, const at::Tensor & philox_offset, int64_t custom_mask_type, bool bias_requires_grad, ::std::optional scale=::std::nullopt, ::std::optional num_splits_key=::std::nullopt, ::std::optional window_size=::std::nullopt, bool shared_storage_dqdkdv=false) { + return at::_ops::_efficient_attention_backward::redispatch(dispatchKeySet, grad_out_, query, key, value, bias, out, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, logsumexp, dropout_p, philox_seed, philox_offset, custom_mask_type, bias_requires_grad, scale, num_splits_key, window_size, shared_storage_dqdkdv); + } + + // aten::_efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, Tensor? bias, Tensor out, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, SymInt max_seqlen_q, SymInt max_seqlen_k, Tensor logsumexp, float dropout_p, Tensor philox_seed, Tensor philox_offset, int custom_mask_type, bool bias_requires_grad, *, float? scale=None, int? num_splits_key=None, int? 
window_size=None, bool shared_storage_dqdkdv=False) -> (Tensor, Tensor, Tensor, Tensor) + inline ::std::tuple _efficient_attention_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out_, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const ::std::optional & bias, const at::Tensor & out, const ::std::optional & cu_seqlens_q, const ::std::optional & cu_seqlens_k, c10::SymInt max_seqlen_q, c10::SymInt max_seqlen_k, const at::Tensor & logsumexp, double dropout_p, const at::Tensor & philox_seed, const at::Tensor & philox_offset, int64_t custom_mask_type, bool bias_requires_grad, ::std::optional scale=::std::nullopt, ::std::optional num_splits_key=::std::nullopt, ::std::optional window_size=::std::nullopt, bool shared_storage_dqdkdv=false) { + return at::_ops::_efficient_attention_backward::redispatch(dispatchKeySet, grad_out_, query, key, value, bias, out, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, logsumexp, dropout_p, philox_seed, philox_offset, custom_mask_type, bias_requires_grad, scale, num_splits_key, window_size, shared_storage_dqdkdv); + } + + // aten::_cudnn_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? 
scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask) + inline ::std::tuple _cudnn_attention_forward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const ::std::optional & attn_bias, const ::std::optional & cum_seq_q, const ::std::optional & cum_seq_k, int64_t max_q, int64_t max_k, bool compute_log_sumexp, double dropout_p=0.0, bool is_causal=false, bool return_debug_mask=false, ::std::optional scale=::std::nullopt) { + return at::_ops::_cudnn_attention_forward::redispatch(dispatchKeySet, query, key, value, attn_bias, cum_seq_q, cum_seq_k, max_q, max_k, compute_log_sumexp, dropout_p, is_causal, return_debug_mask, scale); + } + + // aten::_cudnn_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? 
scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask) + inline ::std::tuple _cudnn_attention_forward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const ::std::optional & attn_bias, const ::std::optional & cum_seq_q, const ::std::optional & cum_seq_k, c10::SymInt max_q, c10::SymInt max_k, bool compute_log_sumexp, double dropout_p=0.0, bool is_causal=false, bool return_debug_mask=false, ::std::optional scale=::std::nullopt) { + return at::_ops::_cudnn_attention_forward::redispatch(dispatchKeySet, query, key, value, attn_bias, cum_seq_q, cum_seq_k, max_q, max_k, compute_log_sumexp, dropout_p, is_causal, return_debug_mask, scale); + } + + // aten::_cudnn_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor attn_bias, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, *, float? 
scale=None) -> (Tensor, Tensor, Tensor) + inline ::std::tuple _cudnn_attention_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const at::Tensor & out, const at::Tensor & logsumexp, const at::Tensor & philox_seed, const at::Tensor & philox_offset, const at::Tensor & attn_bias, const at::Tensor & cum_seq_q, const at::Tensor & cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, bool is_causal, ::std::optional scale=::std::nullopt) { + return at::_ops::_cudnn_attention_backward::redispatch(dispatchKeySet, grad_out, query, key, value, out, logsumexp, philox_seed, philox_offset, attn_bias, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, scale); + } + + // aten::_cudnn_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor attn_bias, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, *, float? 
scale=None) -> (Tensor, Tensor, Tensor) + inline ::std::tuple _cudnn_attention_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const at::Tensor & out, const at::Tensor & logsumexp, const at::Tensor & philox_seed, const at::Tensor & philox_offset, const at::Tensor & attn_bias, const at::Tensor & cum_seq_q, const at::Tensor & cum_seq_k, c10::SymInt max_q, c10::SymInt max_k, double dropout_p, bool is_causal, ::std::optional scale=::std::nullopt) { + return at::_ops::_cudnn_attention_backward::redispatch(dispatchKeySet, grad_out, query, key, value, out, logsumexp, philox_seed, philox_offset, attn_bias, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, scale); + } + + // aten::_triton_scaled_dot_attention(Tensor q, Tensor k, Tensor v, float dropout_p=0.0) -> Tensor + inline at::Tensor _triton_scaled_dot_attention(c10::DispatchKeySet dispatchKeySet, const at::Tensor & q, const at::Tensor & k, const at::Tensor & v, double dropout_p=0.0) { + return at::_ops::_triton_scaled_dot_attention::redispatch(dispatchKeySet, q, k, v, dropout_p); + } + + // aten::_fill_mem_eff_dropout_mask_(Tensor(a!) self, float dropout_p, int seed, int offset) -> Tensor(a!) + inline at::Tensor & _fill_mem_eff_dropout_mask_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, double dropout_p, int64_t seed, int64_t offset) { + return at::_ops::_fill_mem_eff_dropout_mask_::redispatch(dispatchKeySet, self, dropout_p, seed, offset); + } + + // aten::_triton_multi_head_attention(Tensor query, Tensor key, Tensor value, int embed_dim, int num_head, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, Tensor? 
mask=None) -> Tensor + inline at::Tensor _triton_multi_head_attention(c10::DispatchKeySet dispatchKeySet, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, int64_t embed_dim, int64_t num_head, const at::Tensor & qkv_weight, const at::Tensor & qkv_bias, const at::Tensor & proj_weight, const at::Tensor & proj_bias, const ::std::optional & mask={}) { + return at::_ops::_triton_multi_head_attention::redispatch(dispatchKeySet, query, key, value, embed_dim, num_head, qkv_weight, qkv_bias, proj_weight, proj_bias, mask); + } + + // aten::special_airy_ai(Tensor x) -> Tensor + inline at::Tensor special_airy_ai(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x) { + return at::_ops::special_airy_ai::redispatch(dispatchKeySet, x); + } + + // aten::special_airy_ai.out(Tensor x, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_airy_ai_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x) { + return at::_ops::special_airy_ai_out::redispatch(dispatchKeySet, x, out); + } + + // aten::special_airy_ai.out(Tensor x, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_airy_ai_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, at::Tensor & out) { + return at::_ops::special_airy_ai_out::redispatch(dispatchKeySet, x, out); + } + + // aten::special_bessel_j0(Tensor self) -> Tensor + inline at::Tensor special_bessel_j0(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::special_bessel_j0::redispatch(dispatchKeySet, self); + } + + // aten::special_bessel_j0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_bessel_j0_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::special_bessel_j0_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_bessel_j0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & special_bessel_j0_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::special_bessel_j0_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_bessel_j1(Tensor self) -> Tensor + inline at::Tensor special_bessel_j1(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::special_bessel_j1::redispatch(dispatchKeySet, self); + } + + // aten::special_bessel_j1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_bessel_j1_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::special_bessel_j1_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_bessel_j1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_bessel_j1_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::special_bessel_j1_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_bessel_y0(Tensor self) -> Tensor + inline at::Tensor special_bessel_y0(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::special_bessel_y0::redispatch(dispatchKeySet, self); + } + + // aten::special_bessel_y0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_bessel_y0_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::special_bessel_y0_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_bessel_y0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & special_bessel_y0_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::special_bessel_y0_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_bessel_y1(Tensor self) -> Tensor + inline at::Tensor special_bessel_y1(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::special_bessel_y1::redispatch(dispatchKeySet, self); + } + + // aten::special_bessel_y1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_bessel_y1_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::special_bessel_y1_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_bessel_y1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_bessel_y1_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::special_bessel_y1_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_chebyshev_polynomial_t(Tensor x, Tensor n) -> Tensor + inline at::Tensor special_chebyshev_polynomial_t(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Tensor & n) { + return at::_ops::special_chebyshev_polynomial_t::redispatch(dispatchKeySet, x, n); + } + + // aten::special_chebyshev_polynomial_t.x_scalar(Scalar x, Tensor n) -> Tensor + inline at::Tensor special_chebyshev_polynomial_t(c10::DispatchKeySet dispatchKeySet, const at::Scalar & x, const at::Tensor & n) { + return at::_ops::special_chebyshev_polynomial_t_x_scalar::redispatch(dispatchKeySet, x, n); + } + + // aten::special_chebyshev_polynomial_t.n_scalar(Tensor x, Scalar n) -> Tensor + inline at::Tensor special_chebyshev_polynomial_t(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Scalar & n) { + return at::_ops::special_chebyshev_polynomial_t_n_scalar::redispatch(dispatchKeySet, x, n); + } + + // 
aten::special_chebyshev_polynomial_t.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_chebyshev_polynomial_t_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x, const at::Tensor & n) { + return at::_ops::special_chebyshev_polynomial_t_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_chebyshev_polynomial_t.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_chebyshev_polynomial_t_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Tensor & n, at::Tensor & out) { + return at::_ops::special_chebyshev_polynomial_t_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_chebyshev_polynomial_t.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_chebyshev_polynomial_t_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & x, const at::Tensor & n) { + return at::_ops::special_chebyshev_polynomial_t_x_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_chebyshev_polynomial_t.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_chebyshev_polynomial_t_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & x, const at::Tensor & n, at::Tensor & out) { + return at::_ops::special_chebyshev_polynomial_t_x_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_chebyshev_polynomial_t.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_chebyshev_polynomial_t_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x, const at::Scalar & n) { + return at::_ops::special_chebyshev_polynomial_t_n_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_chebyshev_polynomial_t.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & special_chebyshev_polynomial_t_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Scalar & n, at::Tensor & out) { + return at::_ops::special_chebyshev_polynomial_t_n_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_chebyshev_polynomial_u(Tensor x, Tensor n) -> Tensor + inline at::Tensor special_chebyshev_polynomial_u(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Tensor & n) { + return at::_ops::special_chebyshev_polynomial_u::redispatch(dispatchKeySet, x, n); + } + + // aten::special_chebyshev_polynomial_u.x_scalar(Scalar x, Tensor n) -> Tensor + inline at::Tensor special_chebyshev_polynomial_u(c10::DispatchKeySet dispatchKeySet, const at::Scalar & x, const at::Tensor & n) { + return at::_ops::special_chebyshev_polynomial_u_x_scalar::redispatch(dispatchKeySet, x, n); + } + + // aten::special_chebyshev_polynomial_u.n_scalar(Tensor x, Scalar n) -> Tensor + inline at::Tensor special_chebyshev_polynomial_u(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Scalar & n) { + return at::_ops::special_chebyshev_polynomial_u_n_scalar::redispatch(dispatchKeySet, x, n); + } + + // aten::special_chebyshev_polynomial_u.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_chebyshev_polynomial_u_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x, const at::Tensor & n) { + return at::_ops::special_chebyshev_polynomial_u_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_chebyshev_polynomial_u.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_chebyshev_polynomial_u_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Tensor & n, at::Tensor & out) { + return at::_ops::special_chebyshev_polynomial_u_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_chebyshev_polynomial_u.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) 
out) -> Tensor(a!) + inline at::Tensor & special_chebyshev_polynomial_u_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & x, const at::Tensor & n) { + return at::_ops::special_chebyshev_polynomial_u_x_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_chebyshev_polynomial_u.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_chebyshev_polynomial_u_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & x, const at::Tensor & n, at::Tensor & out) { + return at::_ops::special_chebyshev_polynomial_u_x_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_chebyshev_polynomial_u.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_chebyshev_polynomial_u_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x, const at::Scalar & n) { + return at::_ops::special_chebyshev_polynomial_u_n_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_chebyshev_polynomial_u.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & special_chebyshev_polynomial_u_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Scalar & n, at::Tensor & out) { + return at::_ops::special_chebyshev_polynomial_u_n_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_chebyshev_polynomial_v(Tensor x, Tensor n) -> Tensor + inline at::Tensor special_chebyshev_polynomial_v(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Tensor & n) { + return at::_ops::special_chebyshev_polynomial_v::redispatch(dispatchKeySet, x, n); + } + + // aten::special_chebyshev_polynomial_v.x_scalar(Scalar x, Tensor n) -> Tensor + inline at::Tensor special_chebyshev_polynomial_v(c10::DispatchKeySet dispatchKeySet, const at::Scalar & x, const at::Tensor & n) { + return at::_ops::special_chebyshev_polynomial_v_x_scalar::redispatch(dispatchKeySet, x, n); + } + + // aten::special_chebyshev_polynomial_v.n_scalar(Tensor x, Scalar n) -> Tensor + inline at::Tensor special_chebyshev_polynomial_v(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Scalar & n) { + return at::_ops::special_chebyshev_polynomial_v_n_scalar::redispatch(dispatchKeySet, x, n); + } + + // aten::special_chebyshev_polynomial_v.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_chebyshev_polynomial_v_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x, const at::Tensor & n) { + return at::_ops::special_chebyshev_polynomial_v_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_chebyshev_polynomial_v.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_chebyshev_polynomial_v_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Tensor & n, at::Tensor & out) { + return at::_ops::special_chebyshev_polynomial_v_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_chebyshev_polynomial_v.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) 
out) -> Tensor(a!) + inline at::Tensor & special_chebyshev_polynomial_v_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & x, const at::Tensor & n) { + return at::_ops::special_chebyshev_polynomial_v_x_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_chebyshev_polynomial_v.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_chebyshev_polynomial_v_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & x, const at::Tensor & n, at::Tensor & out) { + return at::_ops::special_chebyshev_polynomial_v_x_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_chebyshev_polynomial_v.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_chebyshev_polynomial_v_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x, const at::Scalar & n) { + return at::_ops::special_chebyshev_polynomial_v_n_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_chebyshev_polynomial_v.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & special_chebyshev_polynomial_v_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Scalar & n, at::Tensor & out) { + return at::_ops::special_chebyshev_polynomial_v_n_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_chebyshev_polynomial_w(Tensor x, Tensor n) -> Tensor + inline at::Tensor special_chebyshev_polynomial_w(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Tensor & n) { + return at::_ops::special_chebyshev_polynomial_w::redispatch(dispatchKeySet, x, n); + } + + // aten::special_chebyshev_polynomial_w.x_scalar(Scalar x, Tensor n) -> Tensor + inline at::Tensor special_chebyshev_polynomial_w(c10::DispatchKeySet dispatchKeySet, const at::Scalar & x, const at::Tensor & n) { + return at::_ops::special_chebyshev_polynomial_w_x_scalar::redispatch(dispatchKeySet, x, n); + } + + // aten::special_chebyshev_polynomial_w.n_scalar(Tensor x, Scalar n) -> Tensor + inline at::Tensor special_chebyshev_polynomial_w(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Scalar & n) { + return at::_ops::special_chebyshev_polynomial_w_n_scalar::redispatch(dispatchKeySet, x, n); + } + + // aten::special_chebyshev_polynomial_w.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_chebyshev_polynomial_w_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x, const at::Tensor & n) { + return at::_ops::special_chebyshev_polynomial_w_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_chebyshev_polynomial_w.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_chebyshev_polynomial_w_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Tensor & n, at::Tensor & out) { + return at::_ops::special_chebyshev_polynomial_w_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_chebyshev_polynomial_w.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) 
out) -> Tensor(a!) + inline at::Tensor & special_chebyshev_polynomial_w_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & x, const at::Tensor & n) { + return at::_ops::special_chebyshev_polynomial_w_x_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_chebyshev_polynomial_w.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_chebyshev_polynomial_w_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & x, const at::Tensor & n, at::Tensor & out) { + return at::_ops::special_chebyshev_polynomial_w_x_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_chebyshev_polynomial_w.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_chebyshev_polynomial_w_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x, const at::Scalar & n) { + return at::_ops::special_chebyshev_polynomial_w_n_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_chebyshev_polynomial_w.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & special_chebyshev_polynomial_w_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Scalar & n, at::Tensor & out) { + return at::_ops::special_chebyshev_polynomial_w_n_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_hermite_polynomial_h(Tensor x, Tensor n) -> Tensor + inline at::Tensor special_hermite_polynomial_h(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Tensor & n) { + return at::_ops::special_hermite_polynomial_h::redispatch(dispatchKeySet, x, n); + } + + // aten::special_hermite_polynomial_h.x_scalar(Scalar x, Tensor n) -> Tensor + inline at::Tensor special_hermite_polynomial_h(c10::DispatchKeySet dispatchKeySet, const at::Scalar & x, const at::Tensor & n) { + return at::_ops::special_hermite_polynomial_h_x_scalar::redispatch(dispatchKeySet, x, n); + } + + // aten::special_hermite_polynomial_h.n_scalar(Tensor x, Scalar n) -> Tensor + inline at::Tensor special_hermite_polynomial_h(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Scalar & n) { + return at::_ops::special_hermite_polynomial_h_n_scalar::redispatch(dispatchKeySet, x, n); + } + + // aten::special_hermite_polynomial_h.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_hermite_polynomial_h_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x, const at::Tensor & n) { + return at::_ops::special_hermite_polynomial_h_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_hermite_polynomial_h.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_hermite_polynomial_h_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Tensor & n, at::Tensor & out) { + return at::_ops::special_hermite_polynomial_h_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_hermite_polynomial_h.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & special_hermite_polynomial_h_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & x, const at::Tensor & n) { + return at::_ops::special_hermite_polynomial_h_x_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_hermite_polynomial_h.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_hermite_polynomial_h_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & x, const at::Tensor & n, at::Tensor & out) { + return at::_ops::special_hermite_polynomial_h_x_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_hermite_polynomial_h.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_hermite_polynomial_h_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x, const at::Scalar & n) { + return at::_ops::special_hermite_polynomial_h_n_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_hermite_polynomial_h.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & special_hermite_polynomial_h_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Scalar & n, at::Tensor & out) { + return at::_ops::special_hermite_polynomial_h_n_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_hermite_polynomial_he(Tensor x, Tensor n) -> Tensor + inline at::Tensor special_hermite_polynomial_he(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Tensor & n) { + return at::_ops::special_hermite_polynomial_he::redispatch(dispatchKeySet, x, n); + } + + // aten::special_hermite_polynomial_he.x_scalar(Scalar x, Tensor n) -> Tensor + inline at::Tensor special_hermite_polynomial_he(c10::DispatchKeySet dispatchKeySet, const at::Scalar & x, const at::Tensor & n) { + return at::_ops::special_hermite_polynomial_he_x_scalar::redispatch(dispatchKeySet, x, n); + } + + // aten::special_hermite_polynomial_he.n_scalar(Tensor x, Scalar n) -> Tensor + inline at::Tensor special_hermite_polynomial_he(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Scalar & n) { + return at::_ops::special_hermite_polynomial_he_n_scalar::redispatch(dispatchKeySet, x, n); + } + + // aten::special_hermite_polynomial_he.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_hermite_polynomial_he_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x, const at::Tensor & n) { + return at::_ops::special_hermite_polynomial_he_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_hermite_polynomial_he.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_hermite_polynomial_he_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Tensor & n, at::Tensor & out) { + return at::_ops::special_hermite_polynomial_he_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_hermite_polynomial_he.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & special_hermite_polynomial_he_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & x, const at::Tensor & n) { + return at::_ops::special_hermite_polynomial_he_x_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_hermite_polynomial_he.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_hermite_polynomial_he_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & x, const at::Tensor & n, at::Tensor & out) { + return at::_ops::special_hermite_polynomial_he_x_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_hermite_polynomial_he.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_hermite_polynomial_he_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x, const at::Scalar & n) { + return at::_ops::special_hermite_polynomial_he_n_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_hermite_polynomial_he.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & special_hermite_polynomial_he_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Scalar & n, at::Tensor & out) { + return at::_ops::special_hermite_polynomial_he_n_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_laguerre_polynomial_l(Tensor x, Tensor n) -> Tensor + inline at::Tensor special_laguerre_polynomial_l(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Tensor & n) { + return at::_ops::special_laguerre_polynomial_l::redispatch(dispatchKeySet, x, n); + } + + // aten::special_laguerre_polynomial_l.x_scalar(Scalar x, Tensor n) -> Tensor + inline at::Tensor special_laguerre_polynomial_l(c10::DispatchKeySet dispatchKeySet, const at::Scalar & x, const at::Tensor & n) { + return at::_ops::special_laguerre_polynomial_l_x_scalar::redispatch(dispatchKeySet, x, n); + } + + // aten::special_laguerre_polynomial_l.n_scalar(Tensor x, Scalar n) -> Tensor + inline at::Tensor special_laguerre_polynomial_l(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Scalar & n) { + return at::_ops::special_laguerre_polynomial_l_n_scalar::redispatch(dispatchKeySet, x, n); + } + + // aten::special_laguerre_polynomial_l.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_laguerre_polynomial_l_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x, const at::Tensor & n) { + return at::_ops::special_laguerre_polynomial_l_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_laguerre_polynomial_l.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_laguerre_polynomial_l_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Tensor & n, at::Tensor & out) { + return at::_ops::special_laguerre_polynomial_l_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_laguerre_polynomial_l.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & special_laguerre_polynomial_l_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & x, const at::Tensor & n) { + return at::_ops::special_laguerre_polynomial_l_x_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_laguerre_polynomial_l.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_laguerre_polynomial_l_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & x, const at::Tensor & n, at::Tensor & out) { + return at::_ops::special_laguerre_polynomial_l_x_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_laguerre_polynomial_l.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_laguerre_polynomial_l_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x, const at::Scalar & n) { + return at::_ops::special_laguerre_polynomial_l_n_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_laguerre_polynomial_l.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & special_laguerre_polynomial_l_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Scalar & n, at::Tensor & out) { + return at::_ops::special_laguerre_polynomial_l_n_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_legendre_polynomial_p(Tensor x, Tensor n) -> Tensor + inline at::Tensor special_legendre_polynomial_p(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Tensor & n) { + return at::_ops::special_legendre_polynomial_p::redispatch(dispatchKeySet, x, n); + } + + // aten::special_legendre_polynomial_p.x_scalar(Scalar x, Tensor n) -> Tensor + inline at::Tensor special_legendre_polynomial_p(c10::DispatchKeySet dispatchKeySet, const at::Scalar & x, const at::Tensor & n) { + return at::_ops::special_legendre_polynomial_p_x_scalar::redispatch(dispatchKeySet, x, n); + } + + // aten::special_legendre_polynomial_p.n_scalar(Tensor x, Scalar n) -> Tensor + inline at::Tensor special_legendre_polynomial_p(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Scalar & n) { + return at::_ops::special_legendre_polynomial_p_n_scalar::redispatch(dispatchKeySet, x, n); + } + + // aten::special_legendre_polynomial_p.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_legendre_polynomial_p_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x, const at::Tensor & n) { + return at::_ops::special_legendre_polynomial_p_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_legendre_polynomial_p.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_legendre_polynomial_p_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Tensor & n, at::Tensor & out) { + return at::_ops::special_legendre_polynomial_p_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_legendre_polynomial_p.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & special_legendre_polynomial_p_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & x, const at::Tensor & n) { + return at::_ops::special_legendre_polynomial_p_x_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_legendre_polynomial_p.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_legendre_polynomial_p_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & x, const at::Tensor & n, at::Tensor & out) { + return at::_ops::special_legendre_polynomial_p_x_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_legendre_polynomial_p.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_legendre_polynomial_p_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x, const at::Scalar & n) { + return at::_ops::special_legendre_polynomial_p_n_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_legendre_polynomial_p.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_legendre_polynomial_p_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Scalar & n, at::Tensor & out) { + return at::_ops::special_legendre_polynomial_p_n_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_modified_bessel_i0(Tensor self) -> Tensor + inline at::Tensor special_modified_bessel_i0(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::special_modified_bessel_i0::redispatch(dispatchKeySet, self); + } + + // aten::special_modified_bessel_i0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & special_modified_bessel_i0_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::special_modified_bessel_i0_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_modified_bessel_i0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_modified_bessel_i0_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::special_modified_bessel_i0_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_modified_bessel_i1(Tensor self) -> Tensor + inline at::Tensor special_modified_bessel_i1(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::special_modified_bessel_i1::redispatch(dispatchKeySet, self); + } + + // aten::special_modified_bessel_i1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_modified_bessel_i1_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::special_modified_bessel_i1_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_modified_bessel_i1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_modified_bessel_i1_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::special_modified_bessel_i1_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_modified_bessel_k0(Tensor self) -> Tensor + inline at::Tensor special_modified_bessel_k0(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::special_modified_bessel_k0::redispatch(dispatchKeySet, self); + } + + // aten::special_modified_bessel_k0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & special_modified_bessel_k0_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::special_modified_bessel_k0_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_modified_bessel_k0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_modified_bessel_k0_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::special_modified_bessel_k0_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_modified_bessel_k1(Tensor self) -> Tensor + inline at::Tensor special_modified_bessel_k1(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::special_modified_bessel_k1::redispatch(dispatchKeySet, self); + } + + // aten::special_modified_bessel_k1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_modified_bessel_k1_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::special_modified_bessel_k1_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_modified_bessel_k1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_modified_bessel_k1_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::special_modified_bessel_k1_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_scaled_modified_bessel_k0(Tensor x) -> Tensor + inline at::Tensor special_scaled_modified_bessel_k0(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x) { + return at::_ops::special_scaled_modified_bessel_k0::redispatch(dispatchKeySet, x); + } + + // aten::special_scaled_modified_bessel_k0.out(Tensor x, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & special_scaled_modified_bessel_k0_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x) { + return at::_ops::special_scaled_modified_bessel_k0_out::redispatch(dispatchKeySet, x, out); + } + + // aten::special_scaled_modified_bessel_k0.out(Tensor x, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_scaled_modified_bessel_k0_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, at::Tensor & out) { + return at::_ops::special_scaled_modified_bessel_k0_out::redispatch(dispatchKeySet, x, out); + } + + // aten::special_scaled_modified_bessel_k1(Tensor x) -> Tensor + inline at::Tensor special_scaled_modified_bessel_k1(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x) { + return at::_ops::special_scaled_modified_bessel_k1::redispatch(dispatchKeySet, x); + } + + // aten::special_scaled_modified_bessel_k1.out(Tensor x, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_scaled_modified_bessel_k1_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x) { + return at::_ops::special_scaled_modified_bessel_k1_out::redispatch(dispatchKeySet, x, out); + } + + // aten::special_scaled_modified_bessel_k1.out(Tensor x, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & special_scaled_modified_bessel_k1_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, at::Tensor & out) { + return at::_ops::special_scaled_modified_bessel_k1_out::redispatch(dispatchKeySet, x, out); + } + + // aten::special_shifted_chebyshev_polynomial_t(Tensor x, Tensor n) -> Tensor + inline at::Tensor special_shifted_chebyshev_polynomial_t(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Tensor & n) { + return at::_ops::special_shifted_chebyshev_polynomial_t::redispatch(dispatchKeySet, x, n); + } + + // aten::special_shifted_chebyshev_polynomial_t.x_scalar(Scalar x, Tensor n) -> Tensor + inline at::Tensor special_shifted_chebyshev_polynomial_t(c10::DispatchKeySet dispatchKeySet, const at::Scalar & x, const at::Tensor & n) { + return at::_ops::special_shifted_chebyshev_polynomial_t_x_scalar::redispatch(dispatchKeySet, x, n); + } + + // aten::special_shifted_chebyshev_polynomial_t.n_scalar(Tensor x, Scalar n) -> Tensor + inline at::Tensor special_shifted_chebyshev_polynomial_t(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Scalar & n) { + return at::_ops::special_shifted_chebyshev_polynomial_t_n_scalar::redispatch(dispatchKeySet, x, n); + } + + // aten::special_shifted_chebyshev_polynomial_t.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_shifted_chebyshev_polynomial_t_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x, const at::Tensor & n) { + return at::_ops::special_shifted_chebyshev_polynomial_t_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_shifted_chebyshev_polynomial_t.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & special_shifted_chebyshev_polynomial_t_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Tensor & n, at::Tensor & out) { + return at::_ops::special_shifted_chebyshev_polynomial_t_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_shifted_chebyshev_polynomial_t.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_shifted_chebyshev_polynomial_t_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & x, const at::Tensor & n) { + return at::_ops::special_shifted_chebyshev_polynomial_t_x_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_shifted_chebyshev_polynomial_t.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_shifted_chebyshev_polynomial_t_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & x, const at::Tensor & n, at::Tensor & out) { + return at::_ops::special_shifted_chebyshev_polynomial_t_x_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_shifted_chebyshev_polynomial_t.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_shifted_chebyshev_polynomial_t_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x, const at::Scalar & n) { + return at::_ops::special_shifted_chebyshev_polynomial_t_n_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_shifted_chebyshev_polynomial_t.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & special_shifted_chebyshev_polynomial_t_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Scalar & n, at::Tensor & out) { + return at::_ops::special_shifted_chebyshev_polynomial_t_n_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_shifted_chebyshev_polynomial_u(Tensor x, Tensor n) -> Tensor + inline at::Tensor special_shifted_chebyshev_polynomial_u(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Tensor & n) { + return at::_ops::special_shifted_chebyshev_polynomial_u::redispatch(dispatchKeySet, x, n); + } + + // aten::special_shifted_chebyshev_polynomial_u.x_scalar(Scalar x, Tensor n) -> Tensor + inline at::Tensor special_shifted_chebyshev_polynomial_u(c10::DispatchKeySet dispatchKeySet, const at::Scalar & x, const at::Tensor & n) { + return at::_ops::special_shifted_chebyshev_polynomial_u_x_scalar::redispatch(dispatchKeySet, x, n); + } + + // aten::special_shifted_chebyshev_polynomial_u.n_scalar(Tensor x, Scalar n) -> Tensor + inline at::Tensor special_shifted_chebyshev_polynomial_u(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Scalar & n) { + return at::_ops::special_shifted_chebyshev_polynomial_u_n_scalar::redispatch(dispatchKeySet, x, n); + } + + // aten::special_shifted_chebyshev_polynomial_u.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_shifted_chebyshev_polynomial_u_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x, const at::Tensor & n) { + return at::_ops::special_shifted_chebyshev_polynomial_u_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_shifted_chebyshev_polynomial_u.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & special_shifted_chebyshev_polynomial_u_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Tensor & n, at::Tensor & out) { + return at::_ops::special_shifted_chebyshev_polynomial_u_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_shifted_chebyshev_polynomial_u.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_shifted_chebyshev_polynomial_u_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & x, const at::Tensor & n) { + return at::_ops::special_shifted_chebyshev_polynomial_u_x_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_shifted_chebyshev_polynomial_u.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_shifted_chebyshev_polynomial_u_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & x, const at::Tensor & n, at::Tensor & out) { + return at::_ops::special_shifted_chebyshev_polynomial_u_x_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_shifted_chebyshev_polynomial_u.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_shifted_chebyshev_polynomial_u_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x, const at::Scalar & n) { + return at::_ops::special_shifted_chebyshev_polynomial_u_n_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_shifted_chebyshev_polynomial_u.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & special_shifted_chebyshev_polynomial_u_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Scalar & n, at::Tensor & out) { + return at::_ops::special_shifted_chebyshev_polynomial_u_n_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_shifted_chebyshev_polynomial_v(Tensor x, Tensor n) -> Tensor + inline at::Tensor special_shifted_chebyshev_polynomial_v(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Tensor & n) { + return at::_ops::special_shifted_chebyshev_polynomial_v::redispatch(dispatchKeySet, x, n); + } + + // aten::special_shifted_chebyshev_polynomial_v.x_scalar(Scalar x, Tensor n) -> Tensor + inline at::Tensor special_shifted_chebyshev_polynomial_v(c10::DispatchKeySet dispatchKeySet, const at::Scalar & x, const at::Tensor & n) { + return at::_ops::special_shifted_chebyshev_polynomial_v_x_scalar::redispatch(dispatchKeySet, x, n); + } + + // aten::special_shifted_chebyshev_polynomial_v.n_scalar(Tensor x, Scalar n) -> Tensor + inline at::Tensor special_shifted_chebyshev_polynomial_v(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Scalar & n) { + return at::_ops::special_shifted_chebyshev_polynomial_v_n_scalar::redispatch(dispatchKeySet, x, n); + } + + // aten::special_shifted_chebyshev_polynomial_v.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_shifted_chebyshev_polynomial_v_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x, const at::Tensor & n) { + return at::_ops::special_shifted_chebyshev_polynomial_v_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_shifted_chebyshev_polynomial_v.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & special_shifted_chebyshev_polynomial_v_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Tensor & n, at::Tensor & out) { + return at::_ops::special_shifted_chebyshev_polynomial_v_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_shifted_chebyshev_polynomial_v.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_shifted_chebyshev_polynomial_v_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & x, const at::Tensor & n) { + return at::_ops::special_shifted_chebyshev_polynomial_v_x_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_shifted_chebyshev_polynomial_v.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_shifted_chebyshev_polynomial_v_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & x, const at::Tensor & n, at::Tensor & out) { + return at::_ops::special_shifted_chebyshev_polynomial_v_x_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_shifted_chebyshev_polynomial_v.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_shifted_chebyshev_polynomial_v_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x, const at::Scalar & n) { + return at::_ops::special_shifted_chebyshev_polynomial_v_n_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_shifted_chebyshev_polynomial_v.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & special_shifted_chebyshev_polynomial_v_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Scalar & n, at::Tensor & out) { + return at::_ops::special_shifted_chebyshev_polynomial_v_n_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_shifted_chebyshev_polynomial_w(Tensor x, Tensor n) -> Tensor + inline at::Tensor special_shifted_chebyshev_polynomial_w(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Tensor & n) { + return at::_ops::special_shifted_chebyshev_polynomial_w::redispatch(dispatchKeySet, x, n); + } + + // aten::special_shifted_chebyshev_polynomial_w.x_scalar(Scalar x, Tensor n) -> Tensor + inline at::Tensor special_shifted_chebyshev_polynomial_w(c10::DispatchKeySet dispatchKeySet, const at::Scalar & x, const at::Tensor & n) { + return at::_ops::special_shifted_chebyshev_polynomial_w_x_scalar::redispatch(dispatchKeySet, x, n); + } + + // aten::special_shifted_chebyshev_polynomial_w.n_scalar(Tensor x, Scalar n) -> Tensor + inline at::Tensor special_shifted_chebyshev_polynomial_w(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Scalar & n) { + return at::_ops::special_shifted_chebyshev_polynomial_w_n_scalar::redispatch(dispatchKeySet, x, n); + } + + // aten::special_shifted_chebyshev_polynomial_w.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_shifted_chebyshev_polynomial_w_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x, const at::Tensor & n) { + return at::_ops::special_shifted_chebyshev_polynomial_w_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_shifted_chebyshev_polynomial_w.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & special_shifted_chebyshev_polynomial_w_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Tensor & n, at::Tensor & out) { + return at::_ops::special_shifted_chebyshev_polynomial_w_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_shifted_chebyshev_polynomial_w.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_shifted_chebyshev_polynomial_w_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & x, const at::Tensor & n) { + return at::_ops::special_shifted_chebyshev_polynomial_w_x_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_shifted_chebyshev_polynomial_w.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_shifted_chebyshev_polynomial_w_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & x, const at::Tensor & n, at::Tensor & out) { + return at::_ops::special_shifted_chebyshev_polynomial_w_x_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_shifted_chebyshev_polynomial_w.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_shifted_chebyshev_polynomial_w_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x, const at::Scalar & n) { + return at::_ops::special_shifted_chebyshev_polynomial_w_n_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_shifted_chebyshev_polynomial_w.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & special_shifted_chebyshev_polynomial_w_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Scalar & n, at::Tensor & out) { + return at::_ops::special_shifted_chebyshev_polynomial_w_n_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_spherical_bessel_j0(Tensor x) -> Tensor + inline at::Tensor special_spherical_bessel_j0(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x) { + return at::_ops::special_spherical_bessel_j0::redispatch(dispatchKeySet, x); + } + + // aten::special_spherical_bessel_j0.out(Tensor x, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_spherical_bessel_j0_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x) { + return at::_ops::special_spherical_bessel_j0_out::redispatch(dispatchKeySet, x, out); + } + + // aten::special_spherical_bessel_j0.out(Tensor x, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_spherical_bessel_j0_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, at::Tensor & out) { + return at::_ops::special_spherical_bessel_j0_out::redispatch(dispatchKeySet, x, out); + } + + // aten::_foobar(Tensor self, bool arg1=True, bool arg2=True, *, bool arg3=True) -> Tensor + inline at::Tensor _foobar(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool arg1=true, bool arg2=true, bool arg3=true) { + return at::_ops::_foobar::redispatch(dispatchKeySet, self, arg1, arg2, arg3); + } + + // aten::_fused_adam_(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, float lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? 
found_inf=None) -> () + inline void _fused_adam_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, double lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const ::std::optional & grad_scale={}, const ::std::optional & found_inf={}) { + return at::_ops::_fused_adam_::redispatch(dispatchKeySet, self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale, found_inf); + } + + // aten::_fused_adam_.tensor_lr(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, Tensor lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> () + inline void _fused_adam_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, const at::Tensor & lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const ::std::optional & grad_scale={}, const ::std::optional & found_inf={}) { + return at::_ops::_fused_adam__tensor_lr::redispatch(dispatchKeySet, self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale, found_inf); + } + + // aten::_fused_adamw_(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, float lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? 
found_inf=None) -> () + inline void _fused_adamw_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, double lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const ::std::optional & grad_scale={}, const ::std::optional & found_inf={}) { + return at::_ops::_fused_adamw_::redispatch(dispatchKeySet, self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale, found_inf); + } + + // aten::_fused_adamw_.tensor_lr(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, Tensor lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> () + inline void _fused_adamw_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, const at::Tensor & lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const ::std::optional & grad_scale={}, const ::std::optional & found_inf={}) { + return at::_ops::_fused_adamw__tensor_lr::redispatch(dispatchKeySet, self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale, found_inf); + } + + // aten::_fused_sgd_(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] momentum_buffer_list, *, float weight_decay, float momentum, float lr, float dampening, bool nesterov, bool maximize, bool is_first_step, Tensor? grad_scale=None, Tensor? 
found_inf=None) -> () + inline void _fused_sgd_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList grads, at::TensorList momentum_buffer_list, double weight_decay, double momentum, double lr, double dampening, bool nesterov, bool maximize, bool is_first_step, const ::std::optional & grad_scale={}, const ::std::optional & found_inf={}) { + return at::_ops::_fused_sgd_::redispatch(dispatchKeySet, self, grads, momentum_buffer_list, weight_decay, momentum, lr, dampening, nesterov, maximize, is_first_step, grad_scale, found_inf); + } + + // aten::_fused_sgd_.tensor_lr(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] momentum_buffer_list, *, float weight_decay, float momentum, Tensor lr, float dampening, bool nesterov, bool maximize, bool is_first_step, Tensor? grad_scale=None, Tensor? found_inf=None) -> () + inline void _fused_sgd_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList grads, at::TensorList momentum_buffer_list, double weight_decay, double momentum, const at::Tensor & lr, double dampening, bool nesterov, bool maximize, bool is_first_step, const ::std::optional & grad_scale={}, const ::std::optional & found_inf={}) { + return at::_ops::_fused_sgd__tensor_lr::redispatch(dispatchKeySet, self, grads, momentum_buffer_list, weight_decay, momentum, lr, dampening, nesterov, maximize, is_first_step, grad_scale, found_inf); + } + + // aten::_fused_adagrad_(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] state_sums, Tensor(d!)[] state_steps, *, float lr, float lr_decay, float weight_decay, float eps, bool maximize, Tensor? grad_scale=None, Tensor? 
found_inf=None) -> () + inline void _fused_adagrad_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList grads, at::TensorList state_sums, at::TensorList state_steps, double lr, double lr_decay, double weight_decay, double eps, bool maximize, const ::std::optional & grad_scale={}, const ::std::optional & found_inf={}) { + return at::_ops::_fused_adagrad_::redispatch(dispatchKeySet, self, grads, state_sums, state_steps, lr, lr_decay, weight_decay, eps, maximize, grad_scale, found_inf); + } + + // aten::_fused_adagrad_.tensor_lr(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] state_sums, Tensor[] state_steps, *, Tensor lr, float lr_decay, float weight_decay, float eps, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> () + inline void _fused_adagrad_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList grads, at::TensorList state_sums, at::TensorList state_steps, const at::Tensor & lr, double lr_decay, double weight_decay, double eps, bool maximize, const ::std::optional & grad_scale={}, const ::std::optional & found_inf={}) { + return at::_ops::_fused_adagrad__tensor_lr::redispatch(dispatchKeySet, self, grads, state_sums, state_steps, lr, lr_decay, weight_decay, eps, maximize, grad_scale, found_inf); + } + + // aten::_propagate_xla_data(Tensor input, Tensor output) -> () + inline void _propagate_xla_data(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & output) { + return at::_ops::_propagate_xla_data::redispatch(dispatchKeySet, input, output); + } + + // aten::_new_zeros_with_same_feature_meta.out(Tensor self, Tensor other, *, int self_num_batch_dims=0, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & _new_zeros_with_same_feature_meta_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other, int64_t self_num_batch_dims=0) { + return at::_ops::_new_zeros_with_same_feature_meta_out::redispatch(dispatchKeySet, self, other, self_num_batch_dims, out); + } + + // aten::_new_zeros_with_same_feature_meta.out(Tensor self, Tensor other, *, int self_num_batch_dims=0, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _new_zeros_with_same_feature_meta_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, int64_t self_num_batch_dims, at::Tensor & out) { + return at::_ops::_new_zeros_with_same_feature_meta_out::redispatch(dispatchKeySet, self, other, self_num_batch_dims, out); + } + + // aten::_cudnn_ctc_loss.out(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank, bool deterministic, bool zero_infinity, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple _cudnn_ctc_loss_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & log_probs, const at::Tensor & targets, at::IntArrayRef input_lengths, at::IntArrayRef target_lengths, int64_t blank, bool deterministic, bool zero_infinity) { + return at::_ops::_cudnn_ctc_loss_out::redispatch(dispatchKeySet, log_probs, targets, input_lengths, target_lengths, blank, deterministic, zero_infinity, out0, out1); + } + + // aten::_cudnn_ctc_loss.out(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank, bool deterministic, bool zero_infinity, *, Tensor(a!) out0, Tensor(b!) 
out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple _cudnn_ctc_loss_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & log_probs, const at::Tensor & targets, at::IntArrayRef input_lengths, at::IntArrayRef target_lengths, int64_t blank, bool deterministic, bool zero_infinity, at::Tensor & out0, at::Tensor & out1) { + return at::_ops::_cudnn_ctc_loss_out::redispatch(dispatchKeySet, log_probs, targets, input_lengths, target_lengths, blank, deterministic, zero_infinity, out0, out1); + } + + // aten::_cudnn_rnn_flatten_weight.out(Tensor[] weight_arr, int weight_stride0, SymInt input_size, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, bool bidirectional, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _cudnn_rnn_flatten_weight_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::TensorList weight_arr, int64_t weight_stride0, int64_t input_size, int64_t mode, int64_t hidden_size, int64_t proj_size, int64_t num_layers, bool batch_first, bool bidirectional) { + return at::_ops::_cudnn_rnn_flatten_weight_out::redispatch(dispatchKeySet, weight_arr, weight_stride0, input_size, mode, hidden_size, proj_size, num_layers, batch_first, bidirectional, out); + } + + // aten::_cudnn_rnn_flatten_weight.out(Tensor[] weight_arr, int weight_stride0, SymInt input_size, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, bool bidirectional, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & _cudnn_rnn_flatten_weight_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList weight_arr, int64_t weight_stride0, int64_t input_size, int64_t mode, int64_t hidden_size, int64_t proj_size, int64_t num_layers, bool batch_first, bool bidirectional, at::Tensor & out) { + return at::_ops::_cudnn_rnn_flatten_weight_out::redispatch(dispatchKeySet, weight_arr, weight_stride0, input_size, mode, hidden_size, proj_size, num_layers, batch_first, bidirectional, out); + } + + // aten::_cudnn_rnn_flatten_weight.out(Tensor[] weight_arr, int weight_stride0, SymInt input_size, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, bool bidirectional, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _cudnn_rnn_flatten_weight_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::TensorList weight_arr, int64_t weight_stride0, c10::SymInt input_size, int64_t mode, c10::SymInt hidden_size, c10::SymInt proj_size, int64_t num_layers, bool batch_first, bool bidirectional) { + return at::_ops::_cudnn_rnn_flatten_weight_out::redispatch(dispatchKeySet, weight_arr, weight_stride0, input_size, mode, hidden_size, proj_size, num_layers, batch_first, bidirectional, out); + } + + // aten::_cudnn_rnn_flatten_weight.out(Tensor[] weight_arr, int weight_stride0, SymInt input_size, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, bool bidirectional, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & _cudnn_rnn_flatten_weight_symint_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList weight_arr, int64_t weight_stride0, c10::SymInt input_size, int64_t mode, c10::SymInt hidden_size, c10::SymInt proj_size, int64_t num_layers, bool batch_first, bool bidirectional, at::Tensor & out) { + return at::_ops::_cudnn_rnn_flatten_weight_out::redispatch(dispatchKeySet, weight_arr, weight_stride0, input_size, mode, hidden_size, proj_size, num_layers, batch_first, bidirectional, out); + } + + // aten::_cudnn_rnn.out(Tensor input, Tensor[] weight, int weight_stride0, Tensor? weight_buf, Tensor hx, Tensor? cx, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, SymInt[] batch_sizes, Tensor? dropout_state, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3, Tensor(e!) out4) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!), Tensor(e!)) + inline ::std::tuple _cudnn_rnn_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3, at::Tensor & out4, const at::Tensor & input, at::TensorList weight, int64_t weight_stride0, const ::std::optional & weight_buf, const at::Tensor & hx, const ::std::optional & cx, int64_t mode, int64_t hidden_size, int64_t proj_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, at::IntArrayRef batch_sizes, const ::std::optional & dropout_state) { + return at::_ops::_cudnn_rnn_out::redispatch(dispatchKeySet, input, weight, weight_stride0, weight_buf, hx, cx, mode, hidden_size, proj_size, num_layers, batch_first, dropout, train, bidirectional, c10::fromIntArrayRefSlow(batch_sizes), dropout_state, out0, out1, out2, out3, out4); + } + + // aten::_cudnn_rnn.out(Tensor input, Tensor[] weight, int weight_stride0, Tensor? weight_buf, Tensor hx, Tensor? 
cx, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, SymInt[] batch_sizes, Tensor? dropout_state, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3, Tensor(e!) out4) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!), Tensor(e!)) + inline ::std::tuple _cudnn_rnn_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::TensorList weight, int64_t weight_stride0, const ::std::optional & weight_buf, const at::Tensor & hx, const ::std::optional & cx, int64_t mode, int64_t hidden_size, int64_t proj_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, at::IntArrayRef batch_sizes, const ::std::optional & dropout_state, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3, at::Tensor & out4) { + return at::_ops::_cudnn_rnn_out::redispatch(dispatchKeySet, input, weight, weight_stride0, weight_buf, hx, cx, mode, hidden_size, proj_size, num_layers, batch_first, dropout, train, bidirectional, c10::fromIntArrayRefSlow(batch_sizes), dropout_state, out0, out1, out2, out3, out4); + } + + // aten::_cudnn_rnn.out(Tensor input, Tensor[] weight, int weight_stride0, Tensor? weight_buf, Tensor hx, Tensor? cx, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, SymInt[] batch_sizes, Tensor? dropout_state, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3, Tensor(e!) 
out4) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!), Tensor(e!)) + inline ::std::tuple _cudnn_rnn_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3, at::Tensor & out4, const at::Tensor & input, at::TensorList weight, int64_t weight_stride0, const ::std::optional & weight_buf, const at::Tensor & hx, const ::std::optional & cx, int64_t mode, c10::SymInt hidden_size, c10::SymInt proj_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, c10::SymIntArrayRef batch_sizes, const ::std::optional & dropout_state) { + return at::_ops::_cudnn_rnn_out::redispatch(dispatchKeySet, input, weight, weight_stride0, weight_buf, hx, cx, mode, hidden_size, proj_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state, out0, out1, out2, out3, out4); + } + + // aten::_cudnn_rnn.out(Tensor input, Tensor[] weight, int weight_stride0, Tensor? weight_buf, Tensor hx, Tensor? cx, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, SymInt[] batch_sizes, Tensor? dropout_state, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3, Tensor(e!) 
out4) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!), Tensor(e!)) + inline ::std::tuple _cudnn_rnn_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::TensorList weight, int64_t weight_stride0, const ::std::optional & weight_buf, const at::Tensor & hx, const ::std::optional & cx, int64_t mode, c10::SymInt hidden_size, c10::SymInt proj_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, c10::SymIntArrayRef batch_sizes, const ::std::optional & dropout_state, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3, at::Tensor & out4) { + return at::_ops::_cudnn_rnn_out::redispatch(dispatchKeySet, input, weight, weight_stride0, weight_buf, hx, cx, mode, hidden_size, proj_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state, out0, out1, out2, out3, out4); + } + + // aten::_cudnn_rnn_backward.out(Tensor input, Tensor[] weight, int weight_stride0, Tensor weight_buf, Tensor hx, Tensor? cx, Tensor output, Tensor? grad_output, Tensor? grad_hy, Tensor? grad_cy, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, SymInt[] batch_sizes, Tensor? dropout_state, Tensor reserve, bool[4] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) 
out2, Tensor(d!)[] out3) -> () + inline void _cudnn_rnn_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::TensorList out3, const at::Tensor & input, at::TensorList weight, int64_t weight_stride0, const at::Tensor & weight_buf, const at::Tensor & hx, const ::std::optional & cx, const at::Tensor & output, const ::std::optional & grad_output, const ::std::optional & grad_hy, const ::std::optional & grad_cy, int64_t mode, int64_t hidden_size, int64_t proj_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, at::IntArrayRef batch_sizes, const ::std::optional & dropout_state, const at::Tensor & reserve, ::std::array output_mask) { + return at::_ops::_cudnn_rnn_backward_out::redispatch(dispatchKeySet, input, weight, weight_stride0, weight_buf, hx, cx, output, grad_output, grad_hy, grad_cy, mode, hidden_size, proj_size, num_layers, batch_first, dropout, train, bidirectional, c10::fromIntArrayRefSlow(batch_sizes), dropout_state, reserve, output_mask, out0, out1, out2, out3); + } + + // aten::_cudnn_rnn_backward.out(Tensor input, Tensor[] weight, int weight_stride0, Tensor weight_buf, Tensor hx, Tensor? cx, Tensor output, Tensor? grad_output, Tensor? grad_hy, Tensor? grad_cy, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, SymInt[] batch_sizes, Tensor? dropout_state, Tensor reserve, bool[4] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) 
out2, Tensor(d!)[] out3) -> () + inline void _cudnn_rnn_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::TensorList weight, int64_t weight_stride0, const at::Tensor & weight_buf, const at::Tensor & hx, const ::std::optional & cx, const at::Tensor & output, const ::std::optional & grad_output, const ::std::optional & grad_hy, const ::std::optional & grad_cy, int64_t mode, int64_t hidden_size, int64_t proj_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, at::IntArrayRef batch_sizes, const ::std::optional & dropout_state, const at::Tensor & reserve, ::std::array output_mask, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::TensorList out3) { + return at::_ops::_cudnn_rnn_backward_out::redispatch(dispatchKeySet, input, weight, weight_stride0, weight_buf, hx, cx, output, grad_output, grad_hy, grad_cy, mode, hidden_size, proj_size, num_layers, batch_first, dropout, train, bidirectional, c10::fromIntArrayRefSlow(batch_sizes), dropout_state, reserve, output_mask, out0, out1, out2, out3); + } + + // aten::_cudnn_rnn_backward.out(Tensor input, Tensor[] weight, int weight_stride0, Tensor weight_buf, Tensor hx, Tensor? cx, Tensor output, Tensor? grad_output, Tensor? grad_hy, Tensor? grad_cy, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, SymInt[] batch_sizes, Tensor? dropout_state, Tensor reserve, bool[4] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) 
out2, Tensor(d!)[] out3) -> () + inline void _cudnn_rnn_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::TensorList out3, const at::Tensor & input, at::TensorList weight, int64_t weight_stride0, const at::Tensor & weight_buf, const at::Tensor & hx, const ::std::optional & cx, const at::Tensor & output, const ::std::optional & grad_output, const ::std::optional & grad_hy, const ::std::optional & grad_cy, int64_t mode, c10::SymInt hidden_size, c10::SymInt proj_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, c10::SymIntArrayRef batch_sizes, const ::std::optional & dropout_state, const at::Tensor & reserve, ::std::array output_mask) { + return at::_ops::_cudnn_rnn_backward_out::redispatch(dispatchKeySet, input, weight, weight_stride0, weight_buf, hx, cx, output, grad_output, grad_hy, grad_cy, mode, hidden_size, proj_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state, reserve, output_mask, out0, out1, out2, out3); + } + + // aten::_cudnn_rnn_backward.out(Tensor input, Tensor[] weight, int weight_stride0, Tensor weight_buf, Tensor hx, Tensor? cx, Tensor output, Tensor? grad_output, Tensor? grad_hy, Tensor? grad_cy, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, SymInt[] batch_sizes, Tensor? dropout_state, Tensor reserve, bool[4] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) 
out2, Tensor(d!)[] out3) -> () + inline void _cudnn_rnn_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::TensorList weight, int64_t weight_stride0, const at::Tensor & weight_buf, const at::Tensor & hx, const ::std::optional & cx, const at::Tensor & output, const ::std::optional & grad_output, const ::std::optional & grad_hy, const ::std::optional & grad_cy, int64_t mode, c10::SymInt hidden_size, c10::SymInt proj_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, c10::SymIntArrayRef batch_sizes, const ::std::optional & dropout_state, const at::Tensor & reserve, ::std::array output_mask, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::TensorList out3) { + return at::_ops::_cudnn_rnn_backward_out::redispatch(dispatchKeySet, input, weight, weight_stride0, weight_buf, hx, cx, output, grad_output, grad_hy, grad_cy, mode, hidden_size, proj_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state, reserve, output_mask, out0, out1, out2, out3); + } + + // aten::_cudnn_init_dropout_state.out(float dropout, bool train, int dropout_seed, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _cudnn_init_dropout_state_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, double dropout, bool train, int64_t dropout_seed) { + return at::_ops::_cudnn_init_dropout_state_out::redispatch(dispatchKeySet, dropout, train, dropout_seed, out); + } + + // aten::_cudnn_init_dropout_state.out(float dropout, bool train, int dropout_seed, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _cudnn_init_dropout_state_outf(c10::DispatchKeySet dispatchKeySet, double dropout, bool train, int64_t dropout_seed, at::Tensor & out) { + return at::_ops::_cudnn_init_dropout_state_out::redispatch(dispatchKeySet, dropout, train, dropout_seed, out); + } + + // aten::_fused_dropout.out(Tensor self, float p, Generator? generator=None, *, Tensor(a!) out0, Tensor(b!) 
out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple _fused_dropout_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & self, double p, ::std::optional generator=::std::nullopt) { + return at::_ops::_fused_dropout_out::redispatch(dispatchKeySet, self, p, generator, out0, out1); + } + + // aten::_fused_dropout.out(Tensor self, float p, Generator? generator=None, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple _fused_dropout_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double p, ::std::optional generator, at::Tensor & out0, at::Tensor & out1) { + return at::_ops::_fused_dropout_out::redispatch(dispatchKeySet, self, p, generator, out0, out1); + } + + // aten::_masked_scale.out(Tensor self, Tensor mask, float scale, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _masked_scale_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & mask, double scale) { + return at::_ops::_masked_scale_out::redispatch(dispatchKeySet, self, mask, scale, out); + } + + // aten::_masked_scale.out(Tensor self, Tensor mask, float scale, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _masked_scale_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mask, double scale, at::Tensor & out) { + return at::_ops::_masked_scale_out::redispatch(dispatchKeySet, self, mask, scale, out); + } + + // aten::native_dropout.out(Tensor input, float p, bool? train, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple native_dropout_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & input, double p, ::std::optional train) { + return at::_ops::native_dropout_out::redispatch(dispatchKeySet, input, p, train, out0, out1); + } + + // aten::native_dropout.out(Tensor input, float p, bool? train, *, Tensor(a!) out0, Tensor(b!) 
out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple native_dropout_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, double p, ::std::optional train, at::Tensor & out0, at::Tensor & out1) { + return at::_ops::native_dropout_out::redispatch(dispatchKeySet, input, p, train, out0, out1); + } + + // aten::native_dropout_backward.out(Tensor grad_output, Tensor mask, float scale, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & native_dropout_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad_output, const at::Tensor & mask, double scale) { + return at::_ops::native_dropout_backward_out::redispatch(dispatchKeySet, grad_output, mask, scale, out); + } + + // aten::native_dropout_backward.out(Tensor grad_output, Tensor mask, float scale, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & native_dropout_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & mask, double scale, at::Tensor & out) { + return at::_ops::native_dropout_backward_out::redispatch(dispatchKeySet, grad_output, mask, scale, out); + } + + // aten::_conj_physical.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _conj_physical_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::_conj_physical_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_conj_physical.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _conj_physical_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::_conj_physical_out::redispatch(dispatchKeySet, self, out); + } + + // aten::avg_pool1d.out(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, bool ceil_mode=False, bool count_include_pad=True, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & avg_pool1d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, bool ceil_mode=false, bool count_include_pad=true) { + return at::_ops::avg_pool1d_out::redispatch(dispatchKeySet, self, kernel_size, stride, padding, ceil_mode, count_include_pad, out); + } + + // aten::avg_pool1d.out(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, bool ceil_mode=False, bool count_include_pad=True, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & avg_pool1d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, at::Tensor & out) { + return at::_ops::avg_pool1d_out::redispatch(dispatchKeySet, self, kernel_size, stride, padding, ceil_mode, count_include_pad, out); + } + + // aten::adaptive_avg_pool1d.out(Tensor self, int[1] output_size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & adaptive_avg_pool1d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef output_size) { + return at::_ops::adaptive_avg_pool1d_out::redispatch(dispatchKeySet, self, output_size, out); + } + + // aten::adaptive_avg_pool1d.out(Tensor self, int[1] output_size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & adaptive_avg_pool1d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out) { + return at::_ops::adaptive_avg_pool1d_out::redispatch(dispatchKeySet, self, output_size, out); + } + + // aten::_add_relu.Scalar_out(Tensor self, Scalar other, Scalar alpha=1, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & _add_relu_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha=1) { + return at::_ops::_add_relu_Scalar_out::redispatch(dispatchKeySet, self, other, alpha, out); + } + + // aten::_add_relu.Scalar_out(Tensor self, Scalar other, Scalar alpha=1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _add_relu_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha, at::Tensor & out) { + return at::_ops::_add_relu_Scalar_out::redispatch(dispatchKeySet, self, other, alpha, out); + } + + // aten::add.Scalar_out(Tensor self, Scalar other, Scalar alpha=1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & add_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha=1) { + return at::_ops::add_Scalar_out::redispatch(dispatchKeySet, self, other, alpha, out); + } + + // aten::add.Scalar_out(Tensor self, Scalar other, Scalar alpha=1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & add_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha, at::Tensor & out) { + return at::_ops::add_Scalar_out::redispatch(dispatchKeySet, self, other, alpha, out); + } + + // aten::affine_grid_generator.out(Tensor theta, SymInt[] size, bool align_corners, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & affine_grid_generator_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & theta, at::IntArrayRef size, bool align_corners) { + return at::_ops::affine_grid_generator_out::redispatch(dispatchKeySet, theta, c10::fromIntArrayRefSlow(size), align_corners, out); + } + + // aten::affine_grid_generator.out(Tensor theta, SymInt[] size, bool align_corners, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & affine_grid_generator_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & theta, at::IntArrayRef size, bool align_corners, at::Tensor & out) { + return at::_ops::affine_grid_generator_out::redispatch(dispatchKeySet, theta, c10::fromIntArrayRefSlow(size), align_corners, out); + } + + // aten::affine_grid_generator.out(Tensor theta, SymInt[] size, bool align_corners, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & affine_grid_generator_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & theta, c10::SymIntArrayRef size, bool align_corners) { + return at::_ops::affine_grid_generator_out::redispatch(dispatchKeySet, theta, size, align_corners, out); + } + + // aten::affine_grid_generator.out(Tensor theta, SymInt[] size, bool align_corners, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & affine_grid_generator_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & theta, c10::SymIntArrayRef size, bool align_corners, at::Tensor & out) { + return at::_ops::affine_grid_generator_out::redispatch(dispatchKeySet, theta, size, align_corners, out); + } + + // aten::_test_functorch_fallback.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _test_functorch_fallback_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::_test_functorch_fallback_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::_test_functorch_fallback.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _test_functorch_fallback_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::_test_functorch_fallback_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::bartlett_window.out(int window_length, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & bartlett_window_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t window_length) { + return at::_ops::bartlett_window_out::redispatch(dispatchKeySet, window_length, out); + } + + // aten::bartlett_window.out(int window_length, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bartlett_window_outf(c10::DispatchKeySet dispatchKeySet, int64_t window_length, at::Tensor & out) { + return at::_ops::bartlett_window_out::redispatch(dispatchKeySet, window_length, out); + } + + // aten::bartlett_window.periodic_out(int window_length, bool periodic, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bartlett_window_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t window_length, bool periodic) { + return at::_ops::bartlett_window_periodic_out::redispatch(dispatchKeySet, window_length, periodic, out); + } + + // aten::bartlett_window.periodic_out(int window_length, bool periodic, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bartlett_window_outf(c10::DispatchKeySet dispatchKeySet, int64_t window_length, bool periodic, at::Tensor & out) { + return at::_ops::bartlett_window_periodic_out::redispatch(dispatchKeySet, window_length, periodic, out); + } + + // aten::quantized_batch_norm.out(Tensor input, Tensor? weight, Tensor? bias, Tensor mean, Tensor var, float eps, float output_scale, int output_zero_point, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & quantized_batch_norm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const at::Tensor & mean, const at::Tensor & var, double eps, double output_scale, int64_t output_zero_point) { + return at::_ops::quantized_batch_norm_out::redispatch(dispatchKeySet, input, weight, bias, mean, var, eps, output_scale, output_zero_point, out); + } + + // aten::quantized_batch_norm.out(Tensor input, Tensor? weight, Tensor? 
bias, Tensor mean, Tensor var, float eps, float output_scale, int output_zero_point, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & quantized_batch_norm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const at::Tensor & mean, const at::Tensor & var, double eps, double output_scale, int64_t output_zero_point, at::Tensor & out) { + return at::_ops::quantized_batch_norm_out::redispatch(dispatchKeySet, input, weight, bias, mean, var, eps, output_scale, output_zero_point, out); + } + + // aten::bernoulli.Tensor_out(Tensor self, Tensor p, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bernoulli_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & p, ::std::optional generator=::std::nullopt) { + return at::_ops::bernoulli_Tensor_out::redispatch(dispatchKeySet, self, p, generator, out); + } + + // aten::bernoulli.Tensor_out(Tensor self, Tensor p, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bernoulli_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & p, ::std::optional generator, at::Tensor & out) { + return at::_ops::bernoulli_Tensor_out::redispatch(dispatchKeySet, self, p, generator, out); + } + + // aten::bernoulli.Tensor(Tensor self, Tensor p, *, Generator? generator=None) -> Tensor + inline at::Tensor bernoulli(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & p, ::std::optional generator=::std::nullopt) { + return at::_ops::bernoulli_Tensor::redispatch(dispatchKeySet, self, p, generator); + } + + // aten::bernoulli.float_out(Tensor self, float p=0.5, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & bernoulli_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, double p=0.5, ::std::optional generator=::std::nullopt) { + return at::_ops::bernoulli_float_out::redispatch(dispatchKeySet, self, p, generator, out); + } + + // aten::bernoulli.float_out(Tensor self, float p=0.5, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bernoulli_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double p, ::std::optional generator, at::Tensor & out) { + return at::_ops::bernoulli_float_out::redispatch(dispatchKeySet, self, p, generator, out); + } + + // aten::binary_cross_entropy_with_logits.out(Tensor self, Tensor target, Tensor? weight=None, Tensor? pos_weight=None, int reduction=Mean, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & binary_cross_entropy_with_logits_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight={}, const ::std::optional & pos_weight={}, int64_t reduction=at::Reduction::Mean) { + return at::_ops::binary_cross_entropy_with_logits_out::redispatch(dispatchKeySet, self, target, weight, pos_weight, reduction, out); + } + + // aten::binary_cross_entropy_with_logits.out(Tensor self, Tensor target, Tensor? weight=None, Tensor? pos_weight=None, int reduction=Mean, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & binary_cross_entropy_with_logits_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, const ::std::optional & pos_weight, int64_t reduction, at::Tensor & out) { + return at::_ops::binary_cross_entropy_with_logits_out::redispatch(dispatchKeySet, self, target, weight, pos_weight, reduction, out); + } + + // aten::bincount.out(Tensor self, Tensor? weights=None, SymInt minlength=0, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & bincount_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const ::std::optional & weights={}, int64_t minlength=0) { + return at::_ops::bincount_out::redispatch(dispatchKeySet, self, weights, minlength, out); + } + + // aten::bincount.out(Tensor self, Tensor? weights=None, SymInt minlength=0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bincount_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const ::std::optional & weights, int64_t minlength, at::Tensor & out) { + return at::_ops::bincount_out::redispatch(dispatchKeySet, self, weights, minlength, out); + } + + // aten::bincount.out(Tensor self, Tensor? weights=None, SymInt minlength=0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bincount_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const ::std::optional & weights={}, c10::SymInt minlength=0) { + return at::_ops::bincount_out::redispatch(dispatchKeySet, self, weights, minlength, out); + } + + // aten::bincount.out(Tensor self, Tensor? weights=None, SymInt minlength=0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bincount_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const ::std::optional & weights, c10::SymInt minlength, at::Tensor & out) { + return at::_ops::bincount_out::redispatch(dispatchKeySet, self, weights, minlength, out); + } + + // aten::blackman_window.out(int window_length, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & blackman_window_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t window_length) { + return at::_ops::blackman_window_out::redispatch(dispatchKeySet, window_length, out); + } + + // aten::blackman_window.out(int window_length, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & blackman_window_outf(c10::DispatchKeySet dispatchKeySet, int64_t window_length, at::Tensor & out) { + return at::_ops::blackman_window_out::redispatch(dispatchKeySet, window_length, out); + } + + // aten::blackman_window.periodic_out(int window_length, bool periodic, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & blackman_window_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t window_length, bool periodic) { + return at::_ops::blackman_window_periodic_out::redispatch(dispatchKeySet, window_length, periodic, out); + } + + // aten::blackman_window.periodic_out(int window_length, bool periodic, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & blackman_window_outf(c10::DispatchKeySet dispatchKeySet, int64_t window_length, bool periodic, at::Tensor & out) { + return at::_ops::blackman_window_periodic_out::redispatch(dispatchKeySet, window_length, periodic, out); + } + + // aten::block_diag.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & block_diag_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::TensorList tensors) { + return at::_ops::block_diag_out::redispatch(dispatchKeySet, tensors, out); + } + + // aten::block_diag.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & block_diag_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors, at::Tensor & out) { + return at::_ops::block_diag_out::redispatch(dispatchKeySet, tensors, out); + } + + // aten::constant_pad_nd.out(Tensor self, SymInt[] pad, Scalar value=0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & constant_pad_nd_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef pad, const at::Scalar & value=0) { + return at::_ops::constant_pad_nd_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(pad), value, out); + } + + // aten::constant_pad_nd.out(Tensor self, SymInt[] pad, Scalar value=0, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & constant_pad_nd_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef pad, const at::Scalar & value, at::Tensor & out) { + return at::_ops::constant_pad_nd_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(pad), value, out); + } + + // aten::constant_pad_nd.out(Tensor self, SymInt[] pad, Scalar value=0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & constant_pad_nd_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef pad, const at::Scalar & value=0) { + return at::_ops::constant_pad_nd_out::redispatch(dispatchKeySet, self, pad, value, out); + } + + // aten::constant_pad_nd.out(Tensor self, SymInt[] pad, Scalar value=0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & constant_pad_nd_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef pad, const at::Scalar & value, at::Tensor & out) { + return at::_ops::constant_pad_nd_out::redispatch(dispatchKeySet, self, pad, value, out); + } + + // aten::convolution.out(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & convolution_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups) { + return at::_ops::convolution_out::redispatch(dispatchKeySet, input, weight, bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), transposed, c10::fromIntArrayRefSlow(output_padding), groups, out); + } + + // aten::convolution.out(Tensor input, Tensor weight, Tensor? 
bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & convolution_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups, at::Tensor & out) { + return at::_ops::convolution_out::redispatch(dispatchKeySet, input, weight, bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), transposed, c10::fromIntArrayRefSlow(output_padding), groups, out); + } + + // aten::convolution.out(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & convolution_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups) { + return at::_ops::convolution_out::redispatch(dispatchKeySet, input, weight, bias, stride, padding, dilation, transposed, output_padding, groups, out); + } + + // aten::convolution.out(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & convolution_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups, at::Tensor & out) { + return at::_ops::convolution_out::redispatch(dispatchKeySet, input, weight, bias, stride, padding, dilation, transposed, output_padding, groups, out); + } + + // aten::convolution_backward.out(Tensor grad_output, Tensor input, Tensor weight, SymInt[]? bias_sizes, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple convolution_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & weight, at::OptionalIntArrayRef bias_sizes, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups, ::std::array output_mask) { + return at::_ops::convolution_backward_out::redispatch(dispatchKeySet, grad_output, input, weight, bias_sizes.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*bias_sizes)) : ::std::nullopt, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), transposed, c10::fromIntArrayRefSlow(output_padding), groups, output_mask, out0, out1, out2); + } + + // aten::convolution_backward.out(Tensor grad_output, Tensor input, Tensor weight, SymInt[]? bias_sizes, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) 
out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple convolution_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & weight, at::OptionalIntArrayRef bias_sizes, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups, ::std::array output_mask, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) { + return at::_ops::convolution_backward_out::redispatch(dispatchKeySet, grad_output, input, weight, bias_sizes.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*bias_sizes)) : ::std::nullopt, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), transposed, c10::fromIntArrayRefSlow(output_padding), groups, output_mask, out0, out1, out2); + } + + // aten::convolution_backward.out(Tensor grad_output, Tensor input, Tensor weight, SymInt[]? bias_sizes, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple convolution_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & weight, at::OptionalSymIntArrayRef bias_sizes, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups, ::std::array output_mask) { + return at::_ops::convolution_backward_out::redispatch(dispatchKeySet, grad_output, input, weight, bias_sizes, stride, padding, dilation, transposed, output_padding, groups, output_mask, out0, out1, out2); + } + + // aten::convolution_backward.out(Tensor grad_output, Tensor input, Tensor weight, SymInt[]? 
bias_sizes, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple convolution_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & weight, at::OptionalSymIntArrayRef bias_sizes, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups, ::std::array output_mask, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) { + return at::_ops::convolution_backward_out::redispatch(dispatchKeySet, grad_output, input, weight, bias_sizes, stride, padding, dilation, transposed, output_padding, groups, output_mask, out0, out1, out2); + } + + // aten::convolution_overrideable.out(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & convolution_overrideable_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups) { + return at::_ops::convolution_overrideable_out::redispatch(dispatchKeySet, input, weight, bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), transposed, c10::fromIntArrayRefSlow(output_padding), groups, out); + } + + // aten::convolution_overrideable.out(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & convolution_overrideable_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups, at::Tensor & out) { + return at::_ops::convolution_overrideable_out::redispatch(dispatchKeySet, input, weight, bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), transposed, c10::fromIntArrayRefSlow(output_padding), groups, out); + } + + // aten::convolution_overrideable.out(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & convolution_overrideable_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups) { + return at::_ops::convolution_overrideable_out::redispatch(dispatchKeySet, input, weight, bias, stride, padding, dilation, transposed, output_padding, groups, out); + } + + // aten::convolution_overrideable.out(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & convolution_overrideable_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups, at::Tensor & out) { + return at::_ops::convolution_overrideable_out::redispatch(dispatchKeySet, input, weight, bias, stride, padding, dilation, transposed, output_padding, groups, out); + } + + // aten::convolution_backward_overrideable.out(Tensor grad_output, Tensor input, Tensor weight, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple convolution_backward_overrideable_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & weight, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups, ::std::array output_mask) { + return at::_ops::convolution_backward_overrideable_out::redispatch(dispatchKeySet, grad_output, input, weight, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), transposed, c10::fromIntArrayRefSlow(output_padding), groups, output_mask, out0, out1, out2); + } + + // aten::convolution_backward_overrideable.out(Tensor grad_output, Tensor input, Tensor weight, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) 
out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple convolution_backward_overrideable_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & weight, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups, ::std::array output_mask, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) { + return at::_ops::convolution_backward_overrideable_out::redispatch(dispatchKeySet, grad_output, input, weight, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), transposed, c10::fromIntArrayRefSlow(output_padding), groups, output_mask, out0, out1, out2); + } + + // aten::convolution_backward_overrideable.out(Tensor grad_output, Tensor input, Tensor weight, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple convolution_backward_overrideable_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & weight, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups, ::std::array output_mask) { + return at::_ops::convolution_backward_overrideable_out::redispatch(dispatchKeySet, grad_output, input, weight, stride, padding, dilation, transposed, output_padding, groups, output_mask, out0, out1, out2); + } + + // aten::convolution_backward_overrideable.out(Tensor grad_output, Tensor input, Tensor weight, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool[3] output_mask, *, Tensor(a!) 
out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple convolution_backward_overrideable_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & weight, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups, ::std::array output_mask, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) { + return at::_ops::convolution_backward_overrideable_out::redispatch(dispatchKeySet, grad_output, input, weight, stride, padding, dilation, transposed, output_padding, groups, output_mask, out0, out1, out2); + } + + // aten::_convolution.out(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _convolution_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) { + return at::_ops::_convolution_out::redispatch(dispatchKeySet, input, weight, bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), transposed, c10::fromIntArrayRefSlow(output_padding), groups, benchmark, deterministic, cudnn_enabled, allow_tf32, out); + } + + // aten::_convolution.out(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32, *, Tensor(a!) 
out) -> Tensor(a!) + inline at::Tensor & _convolution_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32, at::Tensor & out) { + return at::_ops::_convolution_out::redispatch(dispatchKeySet, input, weight, bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), transposed, c10::fromIntArrayRefSlow(output_padding), groups, benchmark, deterministic, cudnn_enabled, allow_tf32, out); + } + + // aten::_convolution.out(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _convolution_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) { + return at::_ops::_convolution_out::redispatch(dispatchKeySet, input, weight, bias, stride, padding, dilation, transposed, output_padding, groups, benchmark, deterministic, cudnn_enabled, allow_tf32, out); + } + + // aten::_convolution.out(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & _convolution_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32, at::Tensor & out) { + return at::_ops::_convolution_out::redispatch(dispatchKeySet, input, weight, bias, stride, padding, dilation, transposed, output_padding, groups, benchmark, deterministic, cudnn_enabled, allow_tf32, out); + } + + // aten::conv_tbc.out(Tensor self, Tensor weight, Tensor bias, int pad=0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & conv_tbc_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, const at::Tensor & bias, int64_t pad=0) { + return at::_ops::conv_tbc_out::redispatch(dispatchKeySet, self, weight, bias, pad, out); + } + + // aten::conv_tbc.out(Tensor self, Tensor weight, Tensor bias, int pad=0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & conv_tbc_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const at::Tensor & bias, int64_t pad, at::Tensor & out) { + return at::_ops::conv_tbc_out::redispatch(dispatchKeySet, self, weight, bias, pad, out); + } + + // aten::copy.out(Tensor self, Tensor src, bool non_blocking=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & src, bool non_blocking=false) { + return at::_ops::copy_out::redispatch(dispatchKeySet, self, src, non_blocking, out); + } + + // aten::copy.out(Tensor self, Tensor src, bool non_blocking=False, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & src, bool non_blocking, at::Tensor & out) { + return at::_ops::copy_out::redispatch(dispatchKeySet, self, src, non_blocking, out); + } + + // aten::_copy_from.out(Tensor self, Tensor dst, bool non_blocking=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _copy_from_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & dst, bool non_blocking=false) { + return at::_ops::_copy_from_out::redispatch(dispatchKeySet, self, dst, non_blocking, out); + } + + // aten::_copy_from.out(Tensor self, Tensor dst, bool non_blocking=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _copy_from_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & dst, bool non_blocking, at::Tensor & out) { + return at::_ops::_copy_from_out::redispatch(dispatchKeySet, self, dst, non_blocking, out); + } + + // aten::_copy_from_and_resize.out(Tensor self, Tensor dst, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _copy_from_and_resize_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & dst) { + return at::_ops::_copy_from_and_resize_out::redispatch(dispatchKeySet, self, dst, out); + } + + // aten::_copy_from_and_resize.out(Tensor self, Tensor dst, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _copy_from_and_resize_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & dst, at::Tensor & out) { + return at::_ops::_copy_from_and_resize_out::redispatch(dispatchKeySet, self, dst, out); + } + + // aten::count_nonzero.dim_IntList_out(Tensor self, int[] dim, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & count_nonzero_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef dim) { + return at::_ops::count_nonzero_dim_IntList_out::redispatch(dispatchKeySet, self, dim, out); + } + + // aten::count_nonzero.dim_IntList_out(Tensor self, int[] dim, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & count_nonzero_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim, at::Tensor & out) { + return at::_ops::count_nonzero_dim_IntList_out::redispatch(dispatchKeySet, self, dim, out); + } + + // aten::count_nonzero.out(Tensor self, int? dim=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & count_nonzero_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, ::std::optional dim=::std::nullopt) { + return at::_ops::count_nonzero_out::redispatch(dispatchKeySet, self, dim, out); + } + + // aten::count_nonzero.out(Tensor self, int? dim=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & count_nonzero_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional dim, at::Tensor & out) { + return at::_ops::count_nonzero_out::redispatch(dispatchKeySet, self, dim, out); + } + + // aten::cudnn_affine_grid_generator.out(Tensor theta, int N, int C, int H, int W, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & cudnn_affine_grid_generator_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & theta, int64_t N, int64_t C, int64_t H, int64_t W) { + return at::_ops::cudnn_affine_grid_generator_out::redispatch(dispatchKeySet, theta, N, C, H, W, out); + } + + // aten::cudnn_affine_grid_generator.out(Tensor theta, int N, int C, int H, int W, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & cudnn_affine_grid_generator_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & theta, int64_t N, int64_t C, int64_t H, int64_t W, at::Tensor & out) { + return at::_ops::cudnn_affine_grid_generator_out::redispatch(dispatchKeySet, theta, N, C, H, W, out); + } + + // aten::cudnn_affine_grid_generator_backward.out(Tensor grad, int N, int C, int H, int W, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & cudnn_affine_grid_generator_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad, int64_t N, int64_t C, int64_t H, int64_t W) { + return at::_ops::cudnn_affine_grid_generator_backward_out::redispatch(dispatchKeySet, grad, N, C, H, W, out); + } + + // aten::cudnn_affine_grid_generator_backward.out(Tensor grad, int N, int C, int H, int W, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & cudnn_affine_grid_generator_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, int64_t N, int64_t C, int64_t H, int64_t W, at::Tensor & out) { + return at::_ops::cudnn_affine_grid_generator_backward_out::redispatch(dispatchKeySet, grad, N, C, H, W, out); + } + + // aten::cudnn_batch_norm_backward.out(Tensor input, Tensor grad_output, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, float epsilon, Tensor reserveSpace, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) 
out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple cudnn_batch_norm_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & input, const at::Tensor & grad_output, const at::Tensor & weight, const ::std::optional & running_mean, const ::std::optional & running_var, const ::std::optional & save_mean, const ::std::optional & save_var, double epsilon, const at::Tensor & reserveSpace) { + return at::_ops::cudnn_batch_norm_backward_out::redispatch(dispatchKeySet, input, grad_output, weight, running_mean, running_var, save_mean, save_var, epsilon, reserveSpace, out0, out1, out2); + } + + // aten::cudnn_batch_norm_backward.out(Tensor input, Tensor grad_output, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, float epsilon, Tensor reserveSpace, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple cudnn_batch_norm_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & grad_output, const at::Tensor & weight, const ::std::optional & running_mean, const ::std::optional & running_var, const ::std::optional & save_mean, const ::std::optional & save_var, double epsilon, const at::Tensor & reserveSpace, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) { + return at::_ops::cudnn_batch_norm_backward_out::redispatch(dispatchKeySet, input, grad_output, weight, running_mean, running_var, save_mean, save_var, epsilon, reserveSpace, out0, out1, out2); + } + + // aten::cudnn_convolution_transpose.out(Tensor self, Tensor weight, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, bool allow_tf32, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & cudnn_convolution_transpose_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef padding, at::IntArrayRef output_padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, bool benchmark, bool deterministic, bool allow_tf32) { + return at::_ops::cudnn_convolution_transpose_out::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(output_padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, benchmark, deterministic, allow_tf32, out); + } + + // aten::cudnn_convolution_transpose.out(Tensor self, Tensor weight, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, bool allow_tf32, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & cudnn_convolution_transpose_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef padding, at::IntArrayRef output_padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, bool benchmark, bool deterministic, bool allow_tf32, at::Tensor & out) { + return at::_ops::cudnn_convolution_transpose_out::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(output_padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, benchmark, deterministic, allow_tf32, out); + } + + // aten::cudnn_convolution_transpose.out(Tensor self, Tensor weight, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, bool allow_tf32, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & cudnn_convolution_transpose_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, bool benchmark, bool deterministic, bool allow_tf32) { + return at::_ops::cudnn_convolution_transpose_out::redispatch(dispatchKeySet, self, weight, padding, output_padding, stride, dilation, groups, benchmark, deterministic, allow_tf32, out); + } + + // aten::cudnn_convolution_transpose.out(Tensor self, Tensor weight, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, bool allow_tf32, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & cudnn_convolution_transpose_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, bool benchmark, bool deterministic, bool allow_tf32, at::Tensor & out) { + return at::_ops::cudnn_convolution_transpose_out::redispatch(dispatchKeySet, self, weight, padding, output_padding, stride, dilation, groups, benchmark, deterministic, allow_tf32, out); + } + + // aten::_mps_convolution_transpose.out(Tensor self, Tensor weight, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & _mps_convolution_transpose_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef padding, at::IntArrayRef output_padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups) { + return at::_ops::_mps_convolution_transpose_out::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(output_padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, out); + } + + // aten::_mps_convolution_transpose.out(Tensor self, Tensor weight, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _mps_convolution_transpose_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef padding, at::IntArrayRef output_padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, at::Tensor & out) { + return at::_ops::_mps_convolution_transpose_out::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(output_padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, out); + } + + // aten::_mps_convolution_transpose.out(Tensor self, Tensor weight, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & _mps_convolution_transpose_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups) { + return at::_ops::_mps_convolution_transpose_out::redispatch(dispatchKeySet, self, weight, padding, output_padding, stride, dilation, groups, out); + } + + // aten::_mps_convolution_transpose.out(Tensor self, Tensor weight, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _mps_convolution_transpose_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, at::Tensor & out) { + return at::_ops::_mps_convolution_transpose_out::redispatch(dispatchKeySet, self, weight, padding, output_padding, stride, dilation, groups, out); + } + + // aten::mps_convolution_transpose_backward.out(Tensor self, Tensor grad_output, Tensor weight, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool[2] output_mask, *, Tensor(a!) out0, Tensor(b!) 
out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple mps_convolution_transpose_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & self, const at::Tensor & grad_output, const at::Tensor & weight, at::IntArrayRef padding, at::IntArrayRef output_padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, ::std::array output_mask) { + return at::_ops::mps_convolution_transpose_backward_out::redispatch(dispatchKeySet, self, grad_output, weight, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(output_padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, output_mask, out0, out1); + } + + // aten::mps_convolution_transpose_backward.out(Tensor self, Tensor grad_output, Tensor weight, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool[2] output_mask, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple mps_convolution_transpose_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & grad_output, const at::Tensor & weight, at::IntArrayRef padding, at::IntArrayRef output_padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, ::std::array output_mask, at::Tensor & out0, at::Tensor & out1) { + return at::_ops::mps_convolution_transpose_backward_out::redispatch(dispatchKeySet, self, grad_output, weight, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(output_padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, output_mask, out0, out1); + } + + // aten::mps_convolution_transpose_backward.out(Tensor self, Tensor grad_output, Tensor weight, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool[2] output_mask, *, Tensor(a!) out0, Tensor(b!) 
out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple mps_convolution_transpose_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & self, const at::Tensor & grad_output, const at::Tensor & weight, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, ::std::array output_mask) { + return at::_ops::mps_convolution_transpose_backward_out::redispatch(dispatchKeySet, self, grad_output, weight, padding, output_padding, stride, dilation, groups, output_mask, out0, out1); + } + + // aten::mps_convolution_transpose_backward.out(Tensor self, Tensor grad_output, Tensor weight, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool[2] output_mask, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple mps_convolution_transpose_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & grad_output, const at::Tensor & weight, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, ::std::array output_mask, at::Tensor & out0, at::Tensor & out1) { + return at::_ops::mps_convolution_transpose_backward_out::redispatch(dispatchKeySet, self, grad_output, weight, padding, output_padding, stride, dilation, groups, output_mask, out0, out1); + } + + // aten::cudnn_convolution_relu.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, SymInt groups, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & cudnn_convolution_relu_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, int64_t groups) { + return at::_ops::cudnn_convolution_relu_out::redispatch(dispatchKeySet, self, weight, bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), groups, out); + } + + // aten::cudnn_convolution_relu.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, SymInt groups, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & cudnn_convolution_relu_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, int64_t groups, at::Tensor & out) { + return at::_ops::cudnn_convolution_relu_out::redispatch(dispatchKeySet, self, weight, bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), groups, out); + } + + // aten::cudnn_convolution_relu.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, SymInt groups, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & cudnn_convolution_relu_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, c10::SymInt groups) { + return at::_ops::cudnn_convolution_relu_out::redispatch(dispatchKeySet, self, weight, bias, stride, padding, dilation, groups, out); + } + + // aten::cudnn_convolution_relu.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, SymInt groups, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & cudnn_convolution_relu_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, c10::SymInt groups, at::Tensor & out) { + return at::_ops::cudnn_convolution_relu_out::redispatch(dispatchKeySet, self, weight, bias, stride, padding, dilation, groups, out); + } + + // aten::cudnn_convolution_add_relu.out(Tensor self, Tensor weight, Tensor z, Scalar? alpha, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, SymInt groups, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & cudnn_convolution_add_relu_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, const at::Tensor & z, const ::std::optional & alpha, const ::std::optional & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, int64_t groups) { + return at::_ops::cudnn_convolution_add_relu_out::redispatch(dispatchKeySet, self, weight, z, alpha, bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), groups, out); + } + + // aten::cudnn_convolution_add_relu.out(Tensor self, Tensor weight, Tensor z, Scalar? alpha, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, SymInt groups, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & cudnn_convolution_add_relu_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const at::Tensor & z, const ::std::optional & alpha, const ::std::optional & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, int64_t groups, at::Tensor & out) { + return at::_ops::cudnn_convolution_add_relu_out::redispatch(dispatchKeySet, self, weight, z, alpha, bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), groups, out); + } + + // aten::cudnn_convolution_add_relu.out(Tensor self, Tensor weight, Tensor z, Scalar? alpha, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, SymInt groups, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & cudnn_convolution_add_relu_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, const at::Tensor & z, const ::std::optional & alpha, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, c10::SymInt groups) { + return at::_ops::cudnn_convolution_add_relu_out::redispatch(dispatchKeySet, self, weight, z, alpha, bias, stride, padding, dilation, groups, out); + } + + // aten::cudnn_convolution_add_relu.out(Tensor self, Tensor weight, Tensor z, Scalar? alpha, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, SymInt groups, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & cudnn_convolution_add_relu_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const at::Tensor & z, const ::std::optional & alpha, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, c10::SymInt groups, at::Tensor & out) { + return at::_ops::cudnn_convolution_add_relu_out::redispatch(dispatchKeySet, self, weight, z, alpha, bias, stride, padding, dilation, groups, out); + } + + // aten::cudnn_grid_sampler.out(Tensor self, Tensor grid, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & cudnn_grid_sampler_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & grid) { + return at::_ops::cudnn_grid_sampler_out::redispatch(dispatchKeySet, self, grid, out); + } + + // aten::cudnn_grid_sampler.out(Tensor self, Tensor grid, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & cudnn_grid_sampler_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & grid, at::Tensor & out) { + return at::_ops::cudnn_grid_sampler_out::redispatch(dispatchKeySet, self, grid, out); + } + + // aten::cudnn_grid_sampler_backward.out(Tensor self, Tensor grid, Tensor grad_output, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple cudnn_grid_sampler_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & self, const at::Tensor & grid, const at::Tensor & grad_output) { + return at::_ops::cudnn_grid_sampler_backward_out::redispatch(dispatchKeySet, self, grid, grad_output, out0, out1); + } + + // aten::cudnn_grid_sampler_backward.out(Tensor self, Tensor grid, Tensor grad_output, *, Tensor(a!) out0, Tensor(b!) 
out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple cudnn_grid_sampler_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & grid, const at::Tensor & grad_output, at::Tensor & out0, at::Tensor & out1) { + return at::_ops::cudnn_grid_sampler_backward_out::redispatch(dispatchKeySet, self, grid, grad_output, out0, out1); + } + + // aten::_ctc_loss.out(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank=0, bool zero_infinity=False, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple _ctc_loss_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & log_probs, const at::Tensor & targets, at::IntArrayRef input_lengths, at::IntArrayRef target_lengths, int64_t blank=0, bool zero_infinity=false) { + return at::_ops::_ctc_loss_out::redispatch(dispatchKeySet, log_probs, targets, input_lengths, target_lengths, blank, zero_infinity, out0, out1); + } + + // aten::_ctc_loss.out(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank=0, bool zero_infinity=False, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple _ctc_loss_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & log_probs, const at::Tensor & targets, at::IntArrayRef input_lengths, at::IntArrayRef target_lengths, int64_t blank, bool zero_infinity, at::Tensor & out0, at::Tensor & out1) { + return at::_ops::_ctc_loss_out::redispatch(dispatchKeySet, log_probs, targets, input_lengths, target_lengths, blank, zero_infinity, out0, out1); + } + + // aten::_ctc_loss.Tensor_out(Tensor log_probs, Tensor targets, Tensor input_lengths, Tensor target_lengths, int blank=0, bool zero_infinity=False, *, Tensor(a!) out0, Tensor(b!) 
out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple _ctc_loss_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & log_probs, const at::Tensor & targets, const at::Tensor & input_lengths, const at::Tensor & target_lengths, int64_t blank=0, bool zero_infinity=false) { + return at::_ops::_ctc_loss_Tensor_out::redispatch(dispatchKeySet, log_probs, targets, input_lengths, target_lengths, blank, zero_infinity, out0, out1); + } + + // aten::_ctc_loss.Tensor_out(Tensor log_probs, Tensor targets, Tensor input_lengths, Tensor target_lengths, int blank=0, bool zero_infinity=False, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple _ctc_loss_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & log_probs, const at::Tensor & targets, const at::Tensor & input_lengths, const at::Tensor & target_lengths, int64_t blank, bool zero_infinity, at::Tensor & out0, at::Tensor & out1) { + return at::_ops::_ctc_loss_Tensor_out::redispatch(dispatchKeySet, log_probs, targets, input_lengths, target_lengths, blank, zero_infinity, out0, out1); + } + + // aten::_ctc_loss_backward.out(Tensor grad, Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, Tensor neg_log_likelihood, Tensor log_alpha, int blank, bool zero_infinity=False, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & _ctc_loss_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad, const at::Tensor & log_probs, const at::Tensor & targets, at::IntArrayRef input_lengths, at::IntArrayRef target_lengths, const at::Tensor & neg_log_likelihood, const at::Tensor & log_alpha, int64_t blank, bool zero_infinity=false) { + return at::_ops::_ctc_loss_backward_out::redispatch(dispatchKeySet, grad, log_probs, targets, input_lengths, target_lengths, neg_log_likelihood, log_alpha, blank, zero_infinity, out); + } + + // aten::_ctc_loss_backward.out(Tensor grad, Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, Tensor neg_log_likelihood, Tensor log_alpha, int blank, bool zero_infinity=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _ctc_loss_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & log_probs, const at::Tensor & targets, at::IntArrayRef input_lengths, at::IntArrayRef target_lengths, const at::Tensor & neg_log_likelihood, const at::Tensor & log_alpha, int64_t blank, bool zero_infinity, at::Tensor & out) { + return at::_ops::_ctc_loss_backward_out::redispatch(dispatchKeySet, grad, log_probs, targets, input_lengths, target_lengths, neg_log_likelihood, log_alpha, blank, zero_infinity, out); + } + + // aten::diag_embed.out(Tensor self, int offset=0, int dim1=-2, int dim2=-1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & diag_embed_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t offset=0, int64_t dim1=-2, int64_t dim2=-1) { + return at::_ops::diag_embed_out::redispatch(dispatchKeySet, self, offset, dim1, dim2, out); + } + + // aten::diag_embed.out(Tensor self, int offset=0, int dim1=-2, int dim2=-1, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & diag_embed_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t offset, int64_t dim1, int64_t dim2, at::Tensor & out) { + return at::_ops::diag_embed_out::redispatch(dispatchKeySet, self, offset, dim1, dim2, out); + } + + // aten::diagonal_backward.out(Tensor grad_output, SymInt[] input_sizes, int offset, int dim1, int dim2, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & diagonal_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad_output, at::IntArrayRef input_sizes, int64_t offset, int64_t dim1, int64_t dim2) { + return at::_ops::diagonal_backward_out::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(input_sizes), offset, dim1, dim2, out); + } + + // aten::diagonal_backward.out(Tensor grad_output, SymInt[] input_sizes, int offset, int dim1, int dim2, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & diagonal_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef input_sizes, int64_t offset, int64_t dim1, int64_t dim2, at::Tensor & out) { + return at::_ops::diagonal_backward_out::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(input_sizes), offset, dim1, dim2, out); + } + + // aten::diagonal_backward.out(Tensor grad_output, SymInt[] input_sizes, int offset, int dim1, int dim2, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & diagonal_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad_output, c10::SymIntArrayRef input_sizes, int64_t offset, int64_t dim1, int64_t dim2) { + return at::_ops::diagonal_backward_out::redispatch(dispatchKeySet, grad_output, input_sizes, offset, dim1, dim2, out); + } + + // aten::diagonal_backward.out(Tensor grad_output, SymInt[] input_sizes, int offset, int dim1, int dim2, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & diagonal_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef input_sizes, int64_t offset, int64_t dim1, int64_t dim2, at::Tensor & out) { + return at::_ops::diagonal_backward_out::redispatch(dispatchKeySet, grad_output, input_sizes, offset, dim1, dim2, out); + } + + // aten::div.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & div_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::div_Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::div.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & div_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, at::Tensor & out) { + return at::_ops::div_Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::div.Scalar_mode_out(Tensor self, Scalar other, *, str? rounding_mode, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & div_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other, ::std::optional rounding_mode) { + return at::_ops::div_Scalar_mode_out::redispatch(dispatchKeySet, self, other, rounding_mode, out); + } + + // aten::div.Scalar_mode_out(Tensor self, Scalar other, *, str? rounding_mode, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & div_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, ::std::optional rounding_mode, at::Tensor & out) { + return at::_ops::div_Scalar_mode_out::redispatch(dispatchKeySet, self, other, rounding_mode, out); + } + + // aten::embedding.out(Tensor weight, Tensor indices, SymInt padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & embedding_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & weight, const at::Tensor & indices, int64_t padding_idx=-1, bool scale_grad_by_freq=false, bool sparse=false) { + return at::_ops::embedding_out::redispatch(dispatchKeySet, weight, indices, padding_idx, scale_grad_by_freq, sparse, out); + } + + // aten::embedding.out(Tensor weight, Tensor indices, SymInt padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & embedding_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & weight, const at::Tensor & indices, int64_t padding_idx, bool scale_grad_by_freq, bool sparse, at::Tensor & out) { + return at::_ops::embedding_out::redispatch(dispatchKeySet, weight, indices, padding_idx, scale_grad_by_freq, sparse, out); + } + + // aten::embedding.out(Tensor weight, Tensor indices, SymInt padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & embedding_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & weight, const at::Tensor & indices, c10::SymInt padding_idx=-1, bool scale_grad_by_freq=false, bool sparse=false) { + return at::_ops::embedding_out::redispatch(dispatchKeySet, weight, indices, padding_idx, scale_grad_by_freq, sparse, out); + } + + // aten::embedding.out(Tensor weight, Tensor indices, SymInt padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & embedding_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & weight, const at::Tensor & indices, c10::SymInt padding_idx, bool scale_grad_by_freq, bool sparse, at::Tensor & out) { + return at::_ops::embedding_out::redispatch(dispatchKeySet, weight, indices, padding_idx, scale_grad_by_freq, sparse, out); + } + + // aten::embedding_dense_backward.out(Tensor grad_output, Tensor indices, SymInt num_weights, SymInt padding_idx, bool scale_grad_by_freq, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & embedding_dense_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad_output, const at::Tensor & indices, int64_t num_weights, int64_t padding_idx, bool scale_grad_by_freq) { + return at::_ops::embedding_dense_backward_out::redispatch(dispatchKeySet, grad_output, indices, num_weights, padding_idx, scale_grad_by_freq, out); + } + + // aten::embedding_dense_backward.out(Tensor grad_output, Tensor indices, SymInt num_weights, SymInt padding_idx, bool scale_grad_by_freq, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & embedding_dense_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & indices, int64_t num_weights, int64_t padding_idx, bool scale_grad_by_freq, at::Tensor & out) { + return at::_ops::embedding_dense_backward_out::redispatch(dispatchKeySet, grad_output, indices, num_weights, padding_idx, scale_grad_by_freq, out); + } + + // aten::embedding_dense_backward.out(Tensor grad_output, Tensor indices, SymInt num_weights, SymInt padding_idx, bool scale_grad_by_freq, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & embedding_dense_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad_output, const at::Tensor & indices, c10::SymInt num_weights, c10::SymInt padding_idx, bool scale_grad_by_freq) { + return at::_ops::embedding_dense_backward_out::redispatch(dispatchKeySet, grad_output, indices, num_weights, padding_idx, scale_grad_by_freq, out); + } + + // aten::embedding_dense_backward.out(Tensor grad_output, Tensor indices, SymInt num_weights, SymInt padding_idx, bool scale_grad_by_freq, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & embedding_dense_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & indices, c10::SymInt num_weights, c10::SymInt padding_idx, bool scale_grad_by_freq, at::Tensor & out) { + return at::_ops::embedding_dense_backward_out::redispatch(dispatchKeySet, grad_output, indices, num_weights, padding_idx, scale_grad_by_freq, out); + } + + // aten::embedding_renorm.out(Tensor self, Tensor indices, float max_norm, float norm_type, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & embedding_renorm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & indices, double max_norm, double norm_type) { + return at::_ops::embedding_renorm_out::redispatch(dispatchKeySet, self, indices, max_norm, norm_type, out); + } + + // aten::embedding_renorm.out(Tensor self, Tensor indices, float max_norm, float norm_type, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & embedding_renorm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & indices, double max_norm, double norm_type, at::Tensor & out) { + return at::_ops::embedding_renorm_out::redispatch(dispatchKeySet, self, indices, max_norm, norm_type, out); + } + + // aten::embedding_renorm(Tensor self, Tensor indices, float max_norm, float norm_type) -> Tensor + inline at::Tensor embedding_renorm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & indices, double max_norm, double norm_type) { + return at::_ops::embedding_renorm::redispatch(dispatchKeySet, self, indices, max_norm, norm_type); + } + + // aten::_embedding_bag_forward_only.out(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False, int padding_idx=-1, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!)) + inline ::std::tuple _embedding_bag_forward_only_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3, const at::Tensor & weight, const at::Tensor & indices, const at::Tensor & offsets, bool scale_grad_by_freq=false, int64_t mode=0, bool sparse=false, const ::std::optional & per_sample_weights={}, bool include_last_offset=false, int64_t padding_idx=-1) { + return at::_ops::_embedding_bag_forward_only_out::redispatch(dispatchKeySet, weight, indices, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights, include_last_offset, padding_idx, out0, out1, out2, out3); + } + + // aten::_embedding_bag_forward_only.out(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False, int padding_idx=-1, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) 
out3) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!)) + inline ::std::tuple _embedding_bag_forward_only_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & weight, const at::Tensor & indices, const at::Tensor & offsets, bool scale_grad_by_freq, int64_t mode, bool sparse, const ::std::optional & per_sample_weights, bool include_last_offset, int64_t padding_idx, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3) { + return at::_ops::_embedding_bag_forward_only_out::redispatch(dispatchKeySet, weight, indices, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights, include_last_offset, padding_idx, out0, out1, out2, out3); + } + + // aten::_embedding_bag.out(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False, int padding_idx=-1, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!)) + inline ::std::tuple _embedding_bag_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3, const at::Tensor & weight, const at::Tensor & indices, const at::Tensor & offsets, bool scale_grad_by_freq=false, int64_t mode=0, bool sparse=false, const ::std::optional & per_sample_weights={}, bool include_last_offset=false, int64_t padding_idx=-1) { + return at::_ops::_embedding_bag_out::redispatch(dispatchKeySet, weight, indices, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights, include_last_offset, padding_idx, out0, out1, out2, out3); + } + + // aten::_embedding_bag.out(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False, int padding_idx=-1, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) 
out3) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!)) + inline ::std::tuple _embedding_bag_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & weight, const at::Tensor & indices, const at::Tensor & offsets, bool scale_grad_by_freq, int64_t mode, bool sparse, const ::std::optional & per_sample_weights, bool include_last_offset, int64_t padding_idx, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3) { + return at::_ops::_embedding_bag_out::redispatch(dispatchKeySet, weight, indices, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights, include_last_offset, padding_idx, out0, out1, out2, out3); + } + + // aten::_embedding_bag_dense_backward.out(Tensor grad, Tensor indices, Tensor offset2bag, Tensor bag_size, Tensor maximum_indices, SymInt num_weights, bool scale_grad_by_freq, int mode, Tensor? per_sample_weights, int padding_idx=-1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _embedding_bag_dense_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad, const at::Tensor & indices, const at::Tensor & offset2bag, const at::Tensor & bag_size, const at::Tensor & maximum_indices, int64_t num_weights, bool scale_grad_by_freq, int64_t mode, const ::std::optional & per_sample_weights, int64_t padding_idx=-1) { + return at::_ops::_embedding_bag_dense_backward_out::redispatch(dispatchKeySet, grad, indices, offset2bag, bag_size, maximum_indices, num_weights, scale_grad_by_freq, mode, per_sample_weights, padding_idx, out); + } + + // aten::_embedding_bag_dense_backward.out(Tensor grad, Tensor indices, Tensor offset2bag, Tensor bag_size, Tensor maximum_indices, SymInt num_weights, bool scale_grad_by_freq, int mode, Tensor? per_sample_weights, int padding_idx=-1, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & _embedding_bag_dense_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & indices, const at::Tensor & offset2bag, const at::Tensor & bag_size, const at::Tensor & maximum_indices, int64_t num_weights, bool scale_grad_by_freq, int64_t mode, const ::std::optional & per_sample_weights, int64_t padding_idx, at::Tensor & out) { + return at::_ops::_embedding_bag_dense_backward_out::redispatch(dispatchKeySet, grad, indices, offset2bag, bag_size, maximum_indices, num_weights, scale_grad_by_freq, mode, per_sample_weights, padding_idx, out); + } + + // aten::_embedding_bag_dense_backward.out(Tensor grad, Tensor indices, Tensor offset2bag, Tensor bag_size, Tensor maximum_indices, SymInt num_weights, bool scale_grad_by_freq, int mode, Tensor? per_sample_weights, int padding_idx=-1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _embedding_bag_dense_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad, const at::Tensor & indices, const at::Tensor & offset2bag, const at::Tensor & bag_size, const at::Tensor & maximum_indices, c10::SymInt num_weights, bool scale_grad_by_freq, int64_t mode, const ::std::optional & per_sample_weights, int64_t padding_idx=-1) { + return at::_ops::_embedding_bag_dense_backward_out::redispatch(dispatchKeySet, grad, indices, offset2bag, bag_size, maximum_indices, num_weights, scale_grad_by_freq, mode, per_sample_weights, padding_idx, out); + } + + // aten::_embedding_bag_dense_backward.out(Tensor grad, Tensor indices, Tensor offset2bag, Tensor bag_size, Tensor maximum_indices, SymInt num_weights, bool scale_grad_by_freq, int mode, Tensor? per_sample_weights, int padding_idx=-1, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & _embedding_bag_dense_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & indices, const at::Tensor & offset2bag, const at::Tensor & bag_size, const at::Tensor & maximum_indices, c10::SymInt num_weights, bool scale_grad_by_freq, int64_t mode, const ::std::optional & per_sample_weights, int64_t padding_idx, at::Tensor & out) { + return at::_ops::_embedding_bag_dense_backward_out::redispatch(dispatchKeySet, grad, indices, offset2bag, bag_size, maximum_indices, num_weights, scale_grad_by_freq, mode, per_sample_weights, padding_idx, out); + } + + // aten::_embedding_bag_per_sample_weights_backward.out(Tensor grad, Tensor weight, Tensor indices, Tensor offsets, Tensor offset2bag, int mode, int padding_idx=-1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _embedding_bag_per_sample_weights_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad, const at::Tensor & weight, const at::Tensor & indices, const at::Tensor & offsets, const at::Tensor & offset2bag, int64_t mode, int64_t padding_idx=-1) { + return at::_ops::_embedding_bag_per_sample_weights_backward_out::redispatch(dispatchKeySet, grad, weight, indices, offsets, offset2bag, mode, padding_idx, out); + } + + // aten::_embedding_bag_per_sample_weights_backward.out(Tensor grad, Tensor weight, Tensor indices, Tensor offsets, Tensor offset2bag, int mode, int padding_idx=-1, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & _embedding_bag_per_sample_weights_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & weight, const at::Tensor & indices, const at::Tensor & offsets, const at::Tensor & offset2bag, int64_t mode, int64_t padding_idx, at::Tensor & out) { + return at::_ops::_embedding_bag_per_sample_weights_backward_out::redispatch(dispatchKeySet, grad, weight, indices, offsets, offset2bag, mode, padding_idx, out); + } + + // aten::empty.names_out(int[] size, *, Dimname[]? names, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & empty_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::IntArrayRef size, ::std::optional names, ::std::optional memory_format=::std::nullopt) { + return at::_ops::empty_names_out::redispatch(dispatchKeySet, size, names, memory_format, out); + } + + // aten::empty.names_out(int[] size, *, Dimname[]? names, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & empty_outf(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, ::std::optional names, ::std::optional memory_format, at::Tensor & out) { + return at::_ops::empty_names_out::redispatch(dispatchKeySet, size, names, memory_format, out); + } + + // aten::empty_permuted.out(SymInt[] size, int[] physical_layout, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & empty_permuted_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::IntArrayRef size, at::IntArrayRef physical_layout) { + return at::_ops::empty_permuted_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), physical_layout, out); + } + + // aten::empty_permuted.out(SymInt[] size, int[] physical_layout, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & empty_permuted_outf(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, at::IntArrayRef physical_layout, at::Tensor & out) { + return at::_ops::empty_permuted_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), physical_layout, out); + } + + // aten::empty_permuted.out(SymInt[] size, int[] physical_layout, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & empty_permuted_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, c10::SymIntArrayRef size, at::IntArrayRef physical_layout) { + return at::_ops::empty_permuted_out::redispatch(dispatchKeySet, size, physical_layout, out); + } + + // aten::empty_permuted.out(SymInt[] size, int[] physical_layout, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & empty_permuted_symint_outf(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, at::IntArrayRef physical_layout, at::Tensor & out) { + return at::_ops::empty_permuted_out::redispatch(dispatchKeySet, size, physical_layout, out); + } + + // aten::new_empty.out(Tensor self, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & new_empty_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef size) { + return at::_ops::new_empty_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), out); + } + + // aten::new_empty.out(Tensor self, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & new_empty_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, at::Tensor & out) { + return at::_ops::new_empty_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), out); + } + + // aten::new_empty.out(Tensor self, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & new_empty_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef size) { + return at::_ops::new_empty_out::redispatch(dispatchKeySet, self, size, out); + } + + // aten::new_empty.out(Tensor self, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & new_empty_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, at::Tensor & out) { + return at::_ops::new_empty_out::redispatch(dispatchKeySet, self, size, out); + } + + // aten::new_empty_strided.out(Tensor self, SymInt[] size, SymInt[] stride, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & new_empty_strided_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef size, at::IntArrayRef stride) { + return at::_ops::new_empty_strided_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), out); + } + + // aten::new_empty_strided.out(Tensor self, SymInt[] size, SymInt[] stride, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & new_empty_strided_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, at::IntArrayRef stride, at::Tensor & out) { + return at::_ops::new_empty_strided_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), out); + } + + // aten::new_empty_strided.out(Tensor self, SymInt[] size, SymInt[] stride, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & new_empty_strided_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride) { + return at::_ops::new_empty_strided_out::redispatch(dispatchKeySet, self, size, stride, out); + } + + // aten::new_empty_strided.out(Tensor self, SymInt[] size, SymInt[] stride, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & new_empty_strided_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, at::Tensor & out) { + return at::_ops::new_empty_strided_out::redispatch(dispatchKeySet, self, size, stride, out); + } + + // aten::new_full.out(Tensor self, SymInt[] size, Scalar fill_value, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & new_full_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef size, const at::Scalar & fill_value) { + return at::_ops::new_full_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), fill_value, out); + } + + // aten::new_full.out(Tensor self, SymInt[] size, Scalar fill_value, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & new_full_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, const at::Scalar & fill_value, at::Tensor & out) { + return at::_ops::new_full_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), fill_value, out); + } + + // aten::new_full.out(Tensor self, SymInt[] size, Scalar fill_value, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & new_full_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef size, const at::Scalar & fill_value) { + return at::_ops::new_full_out::redispatch(dispatchKeySet, self, size, fill_value, out); + } + + // aten::new_full.out(Tensor self, SymInt[] size, Scalar fill_value, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & new_full_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, const at::Scalar & fill_value, at::Tensor & out) { + return at::_ops::new_full_out::redispatch(dispatchKeySet, self, size, fill_value, out); + } + + // aten::new_zeros.out(Tensor self, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & new_zeros_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef size) { + return at::_ops::new_zeros_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), out); + } + + // aten::new_zeros.out(Tensor self, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & new_zeros_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, at::Tensor & out) { + return at::_ops::new_zeros_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), out); + } + + // aten::new_zeros.out(Tensor self, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & new_zeros_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef size) { + return at::_ops::new_zeros_out::redispatch(dispatchKeySet, self, size, out); + } + + // aten::new_zeros.out(Tensor self, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & new_zeros_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, at::Tensor & out) { + return at::_ops::new_zeros_out::redispatch(dispatchKeySet, self, size, out); + } + + // aten::new_ones.out(Tensor self, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & new_ones_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef size) { + return at::_ops::new_ones_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), out); + } + + // aten::new_ones.out(Tensor self, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & new_ones_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, at::Tensor & out) { + return at::_ops::new_ones_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), out); + } + + // aten::new_ones.out(Tensor self, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & new_ones_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef size) { + return at::_ops::new_ones_out::redispatch(dispatchKeySet, self, size, out); + } + + // aten::new_ones.out(Tensor self, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & new_ones_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, at::Tensor & out) { + return at::_ops::new_ones_out::redispatch(dispatchKeySet, self, size, out); + } + + // aten::_empty_affine_quantized.out(SymInt[] size, *, float scale=1, int zero_point=0, MemoryFormat? memory_format=contiguous_format, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _empty_affine_quantized_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::IntArrayRef size, double scale=1, int64_t zero_point=0, ::std::optional memory_format=c10::MemoryFormat::Contiguous) { + return at::_ops::_empty_affine_quantized_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), scale, zero_point, memory_format, out); + } + + // aten::_empty_affine_quantized.out(SymInt[] size, *, float scale=1, int zero_point=0, MemoryFormat? memory_format=contiguous_format, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _empty_affine_quantized_outf(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, double scale, int64_t zero_point, ::std::optional memory_format, at::Tensor & out) { + return at::_ops::_empty_affine_quantized_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), scale, zero_point, memory_format, out); + } + + // aten::_empty_affine_quantized.out(SymInt[] size, *, float scale=1, int zero_point=0, MemoryFormat? memory_format=contiguous_format, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & _empty_affine_quantized_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, c10::SymIntArrayRef size, double scale=1, int64_t zero_point=0, ::std::optional memory_format=c10::MemoryFormat::Contiguous) { + return at::_ops::_empty_affine_quantized_out::redispatch(dispatchKeySet, size, scale, zero_point, memory_format, out); + } + + // aten::_empty_affine_quantized.out(SymInt[] size, *, float scale=1, int zero_point=0, MemoryFormat? memory_format=contiguous_format, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _empty_affine_quantized_symint_outf(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, double scale, int64_t zero_point, ::std::optional memory_format, at::Tensor & out) { + return at::_ops::_empty_affine_quantized_out::redispatch(dispatchKeySet, size, scale, zero_point, memory_format, out); + } + + // aten::_empty_per_channel_affine_quantized.out(SymInt[] size, *, Tensor scales, Tensor zero_points, int axis, MemoryFormat? memory_format=contiguous_format, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _empty_per_channel_affine_quantized_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::IntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, ::std::optional memory_format=c10::MemoryFormat::Contiguous) { + return at::_ops::_empty_per_channel_affine_quantized_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), scales, zero_points, axis, memory_format, out); + } + + // aten::_empty_per_channel_affine_quantized.out(SymInt[] size, *, Tensor scales, Tensor zero_points, int axis, MemoryFormat? memory_format=contiguous_format, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & _empty_per_channel_affine_quantized_outf(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, ::std::optional memory_format, at::Tensor & out) { + return at::_ops::_empty_per_channel_affine_quantized_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), scales, zero_points, axis, memory_format, out); + } + + // aten::_empty_per_channel_affine_quantized.out(SymInt[] size, *, Tensor scales, Tensor zero_points, int axis, MemoryFormat? memory_format=contiguous_format, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _empty_per_channel_affine_quantized_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, c10::SymIntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, ::std::optional memory_format=c10::MemoryFormat::Contiguous) { + return at::_ops::_empty_per_channel_affine_quantized_out::redispatch(dispatchKeySet, size, scales, zero_points, axis, memory_format, out); + } + + // aten::_empty_per_channel_affine_quantized.out(SymInt[] size, *, Tensor scales, Tensor zero_points, int axis, MemoryFormat? memory_format=contiguous_format, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _empty_per_channel_affine_quantized_symint_outf(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, ::std::optional memory_format, at::Tensor & out) { + return at::_ops::_empty_per_channel_affine_quantized_out::redispatch(dispatchKeySet, size, scales, zero_points, axis, memory_format, out); + } + + // aten::resize.out(Tensor self, SymInt[] size, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!) 
+ inline const at::Tensor & resize_out(c10::DispatchKeySet dispatchKeySet, const at::Tensor & out, const at::Tensor & self, at::IntArrayRef size, ::std::optional memory_format=::std::nullopt) { + return at::_ops::resize_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), memory_format, out); + } + + // aten::resize.out(Tensor self, SymInt[] size, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!) + inline const at::Tensor & resize_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, ::std::optional memory_format, const at::Tensor & out) { + return at::_ops::resize_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), memory_format, out); + } + + // aten::resize.out(Tensor self, SymInt[] size, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!) + inline const at::Tensor & resize_symint_out(c10::DispatchKeySet dispatchKeySet, const at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef size, ::std::optional memory_format=::std::nullopt) { + return at::_ops::resize_out::redispatch(dispatchKeySet, self, size, memory_format, out); + } + + // aten::resize.out(Tensor self, SymInt[] size, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!) + inline const at::Tensor & resize_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, ::std::optional memory_format, const at::Tensor & out) { + return at::_ops::resize_out::redispatch(dispatchKeySet, self, size, memory_format, out); + } + + // aten::resize(Tensor self, SymInt[] size, *, MemoryFormat? memory_format=None) -> Tensor + inline at::Tensor resize(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, ::std::optional memory_format=::std::nullopt) { + return at::_ops::resize::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), memory_format); + } + + // aten::resize(Tensor self, SymInt[] size, *, MemoryFormat? 
memory_format=None) -> Tensor + inline at::Tensor resize_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, ::std::optional memory_format=::std::nullopt) { + return at::_ops::resize::redispatch(dispatchKeySet, self, size, memory_format); + } + + // aten::_resize_output.out(Tensor self, SymInt[] size, Device device, *, Tensor(a!) out) -> Tensor(a!) + inline const at::Tensor & _resize_output_out(c10::DispatchKeySet dispatchKeySet, const at::Tensor & out, const at::Tensor & self, at::IntArrayRef size, at::Device device) { + return at::_ops::_resize_output_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), device, out); + } + + // aten::_resize_output.out(Tensor self, SymInt[] size, Device device, *, Tensor(a!) out) -> Tensor(a!) + inline const at::Tensor & _resize_output_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, at::Device device, const at::Tensor & out) { + return at::_ops::_resize_output_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), device, out); + } + + // aten::_resize_output.out(Tensor self, SymInt[] size, Device device, *, Tensor(a!) out) -> Tensor(a!) + inline const at::Tensor & _resize_output_symint_out(c10::DispatchKeySet dispatchKeySet, const at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef size, at::Device device) { + return at::_ops::_resize_output_out::redispatch(dispatchKeySet, self, size, device, out); + } + + // aten::_resize_output.out(Tensor self, SymInt[] size, Device device, *, Tensor(a!) out) -> Tensor(a!) 
+ inline const at::Tensor & _resize_output_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, at::Device device, const at::Tensor & out) { + return at::_ops::_resize_output_out::redispatch(dispatchKeySet, self, size, device, out); + } + + // aten::_resize_output(Tensor self, SymInt[] size, Device device) -> Tensor + inline at::Tensor _resize_output(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, at::Device device) { + return at::_ops::_resize_output::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), device); + } + + // aten::_resize_output(Tensor self, SymInt[] size, Device device) -> Tensor + inline at::Tensor _resize_output_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, at::Device device) { + return at::_ops::_resize_output::redispatch(dispatchKeySet, self, size, device); + } + + // aten::empty_quantized.out(int[] size, Tensor qtensor, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & empty_quantized_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::IntArrayRef size, const at::Tensor & qtensor, ::std::optional memory_format=::std::nullopt) { + return at::_ops::empty_quantized_out::redispatch(dispatchKeySet, size, qtensor, memory_format, out); + } + + // aten::empty_quantized.out(int[] size, Tensor qtensor, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & empty_quantized_outf(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, const at::Tensor & qtensor, ::std::optional memory_format, at::Tensor & out) { + return at::_ops::empty_quantized_out::redispatch(dispatchKeySet, size, qtensor, memory_format, out); + } + + // aten::empty_like.out(Tensor self, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & empty_like_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, ::std::optional memory_format=::std::nullopt) { + return at::_ops::empty_like_out::redispatch(dispatchKeySet, self, memory_format, out); + } + + // aten::empty_like.out(Tensor self, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & empty_like_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional memory_format, at::Tensor & out) { + return at::_ops::empty_like_out::redispatch(dispatchKeySet, self, memory_format, out); + } + + // aten::empty_strided.out(SymInt[] size, SymInt[] stride, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & empty_strided_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::IntArrayRef size, at::IntArrayRef stride) { + return at::_ops::empty_strided_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), out); + } + + // aten::empty_strided.out(SymInt[] size, SymInt[] stride, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & empty_strided_outf(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, at::IntArrayRef stride, at::Tensor & out) { + return at::_ops::empty_strided_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), out); + } + + // aten::empty_strided.out(SymInt[] size, SymInt[] stride, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & empty_strided_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, c10::SymIntArrayRef size, c10::SymIntArrayRef stride) { + return at::_ops::empty_strided_out::redispatch(dispatchKeySet, size, stride, out); + } + + // aten::empty_strided.out(SymInt[] size, SymInt[] stride, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & empty_strided_symint_outf(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, at::Tensor & out) { + return at::_ops::empty_strided_out::redispatch(dispatchKeySet, size, stride, out); + } + + // aten::fill.Scalar_out(Tensor self, Scalar value, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fill_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & value) { + return at::_ops::fill_Scalar_out::redispatch(dispatchKeySet, self, value, out); + } + + // aten::fill.Scalar_out(Tensor self, Scalar value, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fill_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & value, at::Tensor & out) { + return at::_ops::fill_Scalar_out::redispatch(dispatchKeySet, self, value, out); + } + + // aten::fill.Tensor_out(Tensor self, Tensor value, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fill_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & value) { + return at::_ops::fill_Tensor_out::redispatch(dispatchKeySet, self, value, out); + } + + // aten::fill.Tensor_out(Tensor self, Tensor value, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fill_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & value, at::Tensor & out) { + return at::_ops::fill_Tensor_out::redispatch(dispatchKeySet, self, value, out); + } + + // aten::floor_divide.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & floor_divide_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::floor_divide_Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::floor_divide.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & floor_divide_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, at::Tensor & out) { + return at::_ops::floor_divide_Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::full.names_out(int[] size, Scalar fill_value, *, Dimname[]? names, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & full_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::IntArrayRef size, const at::Scalar & fill_value, ::std::optional names) { + return at::_ops::full_names_out::redispatch(dispatchKeySet, size, fill_value, names, out); + } + + // aten::full.names_out(int[] size, Scalar fill_value, *, Dimname[]? names, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & full_outf(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, const at::Scalar & fill_value, ::std::optional names, at::Tensor & out) { + return at::_ops::full_names_out::redispatch(dispatchKeySet, size, fill_value, names, out); + } + + // aten::full_like.out(Tensor self, Scalar fill_value, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & full_like_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & fill_value, ::std::optional memory_format=::std::nullopt) { + return at::_ops::full_like_out::redispatch(dispatchKeySet, self, fill_value, memory_format, out); + } + + // aten::full_like.out(Tensor self, Scalar fill_value, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & full_like_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & fill_value, ::std::optional memory_format, at::Tensor & out) { + return at::_ops::full_like_out::redispatch(dispatchKeySet, self, fill_value, memory_format, out); + } + + // aten::from_file.out(str filename, bool? shared=None, int? size=0, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & from_file_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, c10::string_view filename, ::std::optional shared=::std::nullopt, ::std::optional size=0) { + return at::_ops::from_file_out::redispatch(dispatchKeySet, filename, shared, size, out); + } + + // aten::from_file.out(str filename, bool? shared=None, int? size=0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & from_file_outf(c10::DispatchKeySet dispatchKeySet, c10::string_view filename, ::std::optional shared, ::std::optional size, at::Tensor & out) { + return at::_ops::from_file_out::redispatch(dispatchKeySet, filename, shared, size, out); + } + + // aten::grid_sampler_2d.out(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & grid_sampler_2d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners) { + return at::_ops::grid_sampler_2d_out::redispatch(dispatchKeySet, input, grid, interpolation_mode, padding_mode, align_corners, out); + } + + // aten::grid_sampler_2d.out(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & grid_sampler_2d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners, at::Tensor & out) { + return at::_ops::grid_sampler_2d_out::redispatch(dispatchKeySet, input, grid, interpolation_mode, padding_mode, align_corners, out); + } + + // aten::grid_sampler_2d_backward.out(Tensor grad_output, Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners, bool[2] output_mask, *, Tensor(a!) out0, Tensor(b!) 
out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple grid_sampler_2d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners, ::std::array output_mask) { + return at::_ops::grid_sampler_2d_backward_out::redispatch(dispatchKeySet, grad_output, input, grid, interpolation_mode, padding_mode, align_corners, output_mask, out0, out1); + } + + // aten::grid_sampler_2d_backward.out(Tensor grad_output, Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners, bool[2] output_mask, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple grid_sampler_2d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners, ::std::array output_mask, at::Tensor & out0, at::Tensor & out1) { + return at::_ops::grid_sampler_2d_backward_out::redispatch(dispatchKeySet, grad_output, input, grid, interpolation_mode, padding_mode, align_corners, output_mask, out0, out1); + } + + // aten::_grid_sampler_2d_cpu_fallback.out(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _grid_sampler_2d_cpu_fallback_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners) { + return at::_ops::_grid_sampler_2d_cpu_fallback_out::redispatch(dispatchKeySet, input, grid, interpolation_mode, padding_mode, align_corners, out); + } + + // aten::_grid_sampler_2d_cpu_fallback.out(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & _grid_sampler_2d_cpu_fallback_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners, at::Tensor & out) { + return at::_ops::_grid_sampler_2d_cpu_fallback_out::redispatch(dispatchKeySet, input, grid, interpolation_mode, padding_mode, align_corners, out); + } + + // aten::grid_sampler_3d.out(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & grid_sampler_3d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners) { + return at::_ops::grid_sampler_3d_out::redispatch(dispatchKeySet, input, grid, interpolation_mode, padding_mode, align_corners, out); + } + + // aten::grid_sampler_3d.out(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & grid_sampler_3d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners, at::Tensor & out) { + return at::_ops::grid_sampler_3d_out::redispatch(dispatchKeySet, input, grid, interpolation_mode, padding_mode, align_corners, out); + } + + // aten::grid_sampler_3d_backward.out(Tensor grad_output, Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners, bool[2] output_mask, *, Tensor(a!) out0, Tensor(b!) 
out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple grid_sampler_3d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners, ::std::array output_mask) { + return at::_ops::grid_sampler_3d_backward_out::redispatch(dispatchKeySet, grad_output, input, grid, interpolation_mode, padding_mode, align_corners, output_mask, out0, out1); + } + + // aten::grid_sampler_3d_backward.out(Tensor grad_output, Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners, bool[2] output_mask, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple grid_sampler_3d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners, ::std::array output_mask, at::Tensor & out0, at::Tensor & out1) { + return at::_ops::grid_sampler_3d_backward_out::redispatch(dispatchKeySet, grad_output, input, grid, interpolation_mode, padding_mode, align_corners, output_mask, out0, out1); + } + + // aten::hann_window.out(int window_length, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & hann_window_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t window_length) { + return at::_ops::hann_window_out::redispatch(dispatchKeySet, window_length, out); + } + + // aten::hann_window.out(int window_length, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & hann_window_outf(c10::DispatchKeySet dispatchKeySet, int64_t window_length, at::Tensor & out) { + return at::_ops::hann_window_out::redispatch(dispatchKeySet, window_length, out); + } + + // aten::hann_window.periodic_out(int window_length, bool periodic, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & hann_window_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t window_length, bool periodic) { + return at::_ops::hann_window_periodic_out::redispatch(dispatchKeySet, window_length, periodic, out); + } + + // aten::hann_window.periodic_out(int window_length, bool periodic, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & hann_window_outf(c10::DispatchKeySet dispatchKeySet, int64_t window_length, bool periodic, at::Tensor & out) { + return at::_ops::hann_window_periodic_out::redispatch(dispatchKeySet, window_length, periodic, out); + } + + // aten::hamming_window.out(int window_length, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & hamming_window_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t window_length) { + return at::_ops::hamming_window_out::redispatch(dispatchKeySet, window_length, out); + } + + // aten::hamming_window.out(int window_length, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & hamming_window_outf(c10::DispatchKeySet dispatchKeySet, int64_t window_length, at::Tensor & out) { + return at::_ops::hamming_window_out::redispatch(dispatchKeySet, window_length, out); + } + + // aten::hamming_window.periodic_out(int window_length, bool periodic, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & hamming_window_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t window_length, bool periodic) { + return at::_ops::hamming_window_periodic_out::redispatch(dispatchKeySet, window_length, periodic, out); + } + + // aten::hamming_window.periodic_out(int window_length, bool periodic, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & hamming_window_outf(c10::DispatchKeySet dispatchKeySet, int64_t window_length, bool periodic, at::Tensor & out) { + return at::_ops::hamming_window_periodic_out::redispatch(dispatchKeySet, window_length, periodic, out); + } + + // aten::hamming_window.periodic_alpha_out(int window_length, bool periodic, float alpha, *, Tensor(a!) 
out) -> Tensor(a!) + inline at::Tensor & hamming_window_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t window_length, bool periodic, double alpha) { + return at::_ops::hamming_window_periodic_alpha_out::redispatch(dispatchKeySet, window_length, periodic, alpha, out); + } + + // aten::hamming_window.periodic_alpha_out(int window_length, bool periodic, float alpha, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & hamming_window_outf(c10::DispatchKeySet dispatchKeySet, int64_t window_length, bool periodic, double alpha, at::Tensor & out) { + return at::_ops::hamming_window_periodic_alpha_out::redispatch(dispatchKeySet, window_length, periodic, alpha, out); + } + + // aten::hamming_window.periodic_alpha_beta_out(int window_length, bool periodic, float alpha, float beta, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & hamming_window_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t window_length, bool periodic, double alpha, double beta) { + return at::_ops::hamming_window_periodic_alpha_beta_out::redispatch(dispatchKeySet, window_length, periodic, alpha, beta, out); + } + + // aten::hamming_window.periodic_alpha_beta_out(int window_length, bool periodic, float alpha, float beta, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & hamming_window_outf(c10::DispatchKeySet dispatchKeySet, int64_t window_length, bool periodic, double alpha, double beta, at::Tensor & out) { + return at::_ops::hamming_window_periodic_alpha_beta_out::redispatch(dispatchKeySet, window_length, periodic, alpha, beta, out); + } + + // aten::kaiser_window.out(int window_length, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & kaiser_window_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t window_length) { + return at::_ops::kaiser_window_out::redispatch(dispatchKeySet, window_length, out); + } + + // aten::kaiser_window.out(int window_length, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & kaiser_window_outf(c10::DispatchKeySet dispatchKeySet, int64_t window_length, at::Tensor & out) { + return at::_ops::kaiser_window_out::redispatch(dispatchKeySet, window_length, out); + } + + // aten::kaiser_window.periodic_out(int window_length, bool periodic, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & kaiser_window_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t window_length, bool periodic) { + return at::_ops::kaiser_window_periodic_out::redispatch(dispatchKeySet, window_length, periodic, out); + } + + // aten::kaiser_window.periodic_out(int window_length, bool periodic, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & kaiser_window_outf(c10::DispatchKeySet dispatchKeySet, int64_t window_length, bool periodic, at::Tensor & out) { + return at::_ops::kaiser_window_periodic_out::redispatch(dispatchKeySet, window_length, periodic, out); + } + + // aten::kaiser_window.beta_out(int window_length, bool periodic, float beta, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & kaiser_window_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t window_length, bool periodic, double beta) { + return at::_ops::kaiser_window_beta_out::redispatch(dispatchKeySet, window_length, periodic, beta, out); + } + + // aten::kaiser_window.beta_out(int window_length, bool periodic, float beta, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & kaiser_window_outf(c10::DispatchKeySet dispatchKeySet, int64_t window_length, bool periodic, double beta, at::Tensor & out) { + return at::_ops::kaiser_window_beta_out::redispatch(dispatchKeySet, window_length, periodic, beta, out); + } + + // aten::native_group_norm.out(Tensor input, Tensor? weight, Tensor? bias, SymInt N, SymInt C, SymInt HxW, int group, float eps, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) 
out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple native_group_norm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, int64_t N, int64_t C, int64_t HxW, int64_t group, double eps) { + return at::_ops::native_group_norm_out::redispatch(dispatchKeySet, input, weight, bias, N, C, HxW, group, eps, out0, out1, out2); + } + + // aten::native_group_norm.out(Tensor input, Tensor? weight, Tensor? bias, SymInt N, SymInt C, SymInt HxW, int group, float eps, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple native_group_norm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, int64_t N, int64_t C, int64_t HxW, int64_t group, double eps, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) { + return at::_ops::native_group_norm_out::redispatch(dispatchKeySet, input, weight, bias, N, C, HxW, group, eps, out0, out1, out2); + } + + // aten::native_group_norm.out(Tensor input, Tensor? weight, Tensor? bias, SymInt N, SymInt C, SymInt HxW, int group, float eps, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple native_group_norm_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, c10::SymInt N, c10::SymInt C, c10::SymInt HxW, int64_t group, double eps) { + return at::_ops::native_group_norm_out::redispatch(dispatchKeySet, input, weight, bias, N, C, HxW, group, eps, out0, out1, out2); + } + + // aten::native_group_norm.out(Tensor input, Tensor? weight, Tensor? bias, SymInt N, SymInt C, SymInt HxW, int group, float eps, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) 
out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple native_group_norm_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, c10::SymInt N, c10::SymInt C, c10::SymInt HxW, int64_t group, double eps, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) { + return at::_ops::native_group_norm_out::redispatch(dispatchKeySet, input, weight, bias, N, C, HxW, group, eps, out0, out1, out2); + } + + // aten::native_group_norm_backward.out(Tensor grad_out, Tensor input, Tensor mean, Tensor rstd, Tensor? weight, SymInt N, SymInt C, SymInt HxW, int group, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple native_group_norm_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & rstd, const ::std::optional & weight, int64_t N, int64_t C, int64_t HxW, int64_t group, ::std::array output_mask) { + return at::_ops::native_group_norm_backward_out::redispatch(dispatchKeySet, grad_out, input, mean, rstd, weight, N, C, HxW, group, output_mask, out0, out1, out2); + } + + // aten::native_group_norm_backward.out(Tensor grad_out, Tensor input, Tensor mean, Tensor rstd, Tensor? weight, SymInt N, SymInt C, SymInt HxW, int group, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) 
out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple native_group_norm_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & rstd, const ::std::optional & weight, int64_t N, int64_t C, int64_t HxW, int64_t group, ::std::array output_mask, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) { + return at::_ops::native_group_norm_backward_out::redispatch(dispatchKeySet, grad_out, input, mean, rstd, weight, N, C, HxW, group, output_mask, out0, out1, out2); + } + + // aten::native_group_norm_backward.out(Tensor grad_out, Tensor input, Tensor mean, Tensor rstd, Tensor? weight, SymInt N, SymInt C, SymInt HxW, int group, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple native_group_norm_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & rstd, const ::std::optional & weight, c10::SymInt N, c10::SymInt C, c10::SymInt HxW, int64_t group, ::std::array output_mask) { + return at::_ops::native_group_norm_backward_out::redispatch(dispatchKeySet, grad_out, input, mean, rstd, weight, N, C, HxW, group, output_mask, out0, out1, out2); + } + + // aten::native_group_norm_backward.out(Tensor grad_out, Tensor input, Tensor mean, Tensor rstd, Tensor? weight, SymInt N, SymInt C, SymInt HxW, int group, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) 
out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple native_group_norm_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & rstd, const ::std::optional & weight, c10::SymInt N, c10::SymInt C, c10::SymInt HxW, int64_t group, ::std::array output_mask, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) { + return at::_ops::native_group_norm_backward_out::redispatch(dispatchKeySet, grad_out, input, mean, rstd, weight, N, C, HxW, group, output_mask, out0, out1, out2); + } + + // aten::index_put.out(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & index_put_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const c10::List<::std::optional> & indices, const at::Tensor & values, bool accumulate=false) { + return at::_ops::index_put_out::redispatch(dispatchKeySet, self, indices, values, accumulate, out); + } + + // aten::index_put.out(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & index_put_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const c10::List<::std::optional> & indices, const at::Tensor & values, bool accumulate, at::Tensor & out) { + return at::_ops::index_put_out::redispatch(dispatchKeySet, self, indices, values, accumulate, out); + } + + // aten::_index_put_impl.out(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False, bool unsafe=False, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & _index_put_impl_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const c10::List<::std::optional> & indices, const at::Tensor & values, bool accumulate=false, bool unsafe=false) { + return at::_ops::_index_put_impl_out::redispatch(dispatchKeySet, self, indices, values, accumulate, unsafe, out); + } + + // aten::_index_put_impl.out(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False, bool unsafe=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _index_put_impl_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const c10::List<::std::optional> & indices, const at::Tensor & values, bool accumulate, bool unsafe, at::Tensor & out) { + return at::_ops::_index_put_impl_out::redispatch(dispatchKeySet, self, indices, values, accumulate, unsafe, out); + } + + // aten::_index_put_impl(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False, bool unsafe=False) -> Tensor + inline at::Tensor _index_put_impl(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const c10::List<::std::optional> & indices, const at::Tensor & values, bool accumulate=false, bool unsafe=false) { + return at::_ops::_index_put_impl::redispatch(dispatchKeySet, self, indices, values, accumulate, unsafe); + } + + // aten::isnan.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & isnan_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::isnan_out::redispatch(dispatchKeySet, self, out); + } + + // aten::isnan.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & isnan_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::isnan_out::redispatch(dispatchKeySet, self, out); + } + + // aten::native_layer_norm.out(Tensor input, SymInt[] normalized_shape, Tensor? weight, Tensor? bias, float eps, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) 
out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple native_layer_norm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & input, at::IntArrayRef normalized_shape, const ::std::optional & weight, const ::std::optional & bias, double eps) { + return at::_ops::native_layer_norm_out::redispatch(dispatchKeySet, input, c10::fromIntArrayRefSlow(normalized_shape), weight, bias, eps, out0, out1, out2); + } + + // aten::native_layer_norm.out(Tensor input, SymInt[] normalized_shape, Tensor? weight, Tensor? bias, float eps, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple native_layer_norm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::IntArrayRef normalized_shape, const ::std::optional & weight, const ::std::optional & bias, double eps, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) { + return at::_ops::native_layer_norm_out::redispatch(dispatchKeySet, input, c10::fromIntArrayRefSlow(normalized_shape), weight, bias, eps, out0, out1, out2); + } + + // aten::native_layer_norm.out(Tensor input, SymInt[] normalized_shape, Tensor? weight, Tensor? bias, float eps, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple native_layer_norm_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & input, c10::SymIntArrayRef normalized_shape, const ::std::optional & weight, const ::std::optional & bias, double eps) { + return at::_ops::native_layer_norm_out::redispatch(dispatchKeySet, input, normalized_shape, weight, bias, eps, out0, out1, out2); + } + + // aten::native_layer_norm.out(Tensor input, SymInt[] normalized_shape, Tensor? weight, Tensor? bias, float eps, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) 
out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple native_layer_norm_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, c10::SymIntArrayRef normalized_shape, const ::std::optional & weight, const ::std::optional & bias, double eps, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) { + return at::_ops::native_layer_norm_out::redispatch(dispatchKeySet, input, normalized_shape, weight, bias, eps, out0, out1, out2); + } + + // aten::native_layer_norm_backward.out(Tensor grad_out, Tensor input, SymInt[] normalized_shape, Tensor mean, Tensor rstd, Tensor? weight, Tensor? bias, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple native_layer_norm_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & grad_out, const at::Tensor & input, at::IntArrayRef normalized_shape, const at::Tensor & mean, const at::Tensor & rstd, const ::std::optional & weight, const ::std::optional & bias, ::std::array output_mask) { + return at::_ops::native_layer_norm_backward_out::redispatch(dispatchKeySet, grad_out, input, c10::fromIntArrayRefSlow(normalized_shape), mean, rstd, weight, bias, output_mask, out0, out1, out2); + } + + // aten::native_layer_norm_backward.out(Tensor grad_out, Tensor input, SymInt[] normalized_shape, Tensor mean, Tensor rstd, Tensor? weight, Tensor? bias, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) 
out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple native_layer_norm_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out, const at::Tensor & input, at::IntArrayRef normalized_shape, const at::Tensor & mean, const at::Tensor & rstd, const ::std::optional & weight, const ::std::optional & bias, ::std::array output_mask, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) { + return at::_ops::native_layer_norm_backward_out::redispatch(dispatchKeySet, grad_out, input, c10::fromIntArrayRefSlow(normalized_shape), mean, rstd, weight, bias, output_mask, out0, out1, out2); + } + + // aten::native_layer_norm_backward.out(Tensor grad_out, Tensor input, SymInt[] normalized_shape, Tensor mean, Tensor rstd, Tensor? weight, Tensor? bias, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple native_layer_norm_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & grad_out, const at::Tensor & input, c10::SymIntArrayRef normalized_shape, const at::Tensor & mean, const at::Tensor & rstd, const ::std::optional & weight, const ::std::optional & bias, ::std::array output_mask) { + return at::_ops::native_layer_norm_backward_out::redispatch(dispatchKeySet, grad_out, input, normalized_shape, mean, rstd, weight, bias, output_mask, out0, out1, out2); + } + + // aten::native_layer_norm_backward.out(Tensor grad_out, Tensor input, SymInt[] normalized_shape, Tensor mean, Tensor rstd, Tensor? weight, Tensor? bias, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) 
out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple native_layer_norm_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out, const at::Tensor & input, c10::SymIntArrayRef normalized_shape, const at::Tensor & mean, const at::Tensor & rstd, const ::std::optional & weight, const ::std::optional & bias, ::std::array output_mask, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) { + return at::_ops::native_layer_norm_backward_out::redispatch(dispatchKeySet, grad_out, input, normalized_shape, mean, rstd, weight, bias, output_mask, out0, out1, out2); + } + + // aten::linear_backward.out(Tensor self, Tensor grad_output, Tensor weight, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple linear_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & self, const at::Tensor & grad_output, const at::Tensor & weight, ::std::array output_mask) { + return at::_ops::linear_backward_out::redispatch(dispatchKeySet, self, grad_output, weight, output_mask, out0, out1, out2); + } + + // aten::linear_backward.out(Tensor self, Tensor grad_output, Tensor weight, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple linear_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & grad_output, const at::Tensor & weight, ::std::array output_mask, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) { + return at::_ops::linear_backward_out::redispatch(dispatchKeySet, self, grad_output, weight, output_mask, out0, out1, out2); + } + + // aten::mkldnn_linear.out(Tensor self, Tensor weight, Tensor? bias=None, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & mkldnn_linear_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias={}) { + return at::_ops::mkldnn_linear_out::redispatch(dispatchKeySet, self, weight, bias, out); + } + + // aten::mkldnn_linear.out(Tensor self, Tensor weight, Tensor? bias=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & mkldnn_linear_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, at::Tensor & out) { + return at::_ops::mkldnn_linear_out::redispatch(dispatchKeySet, self, weight, bias, out); + } + + // aten::mkldnn_linear_backward_input.out(int[] input_size, Tensor grad_output, Tensor weight, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & mkldnn_linear_backward_input_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::IntArrayRef input_size, const at::Tensor & grad_output, const at::Tensor & weight) { + return at::_ops::mkldnn_linear_backward_input_out::redispatch(dispatchKeySet, input_size, grad_output, weight, out); + } + + // aten::mkldnn_linear_backward_input.out(int[] input_size, Tensor grad_output, Tensor weight, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & mkldnn_linear_backward_input_outf(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef input_size, const at::Tensor & grad_output, const at::Tensor & weight, at::Tensor & out) { + return at::_ops::mkldnn_linear_backward_input_out::redispatch(dispatchKeySet, input_size, grad_output, weight, out); + } + + // aten::mkldnn_linear_backward_weights.out(Tensor grad_output, Tensor input, Tensor weight, bool bias_defined, *, Tensor(a!) out0, Tensor(b!) 
out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple mkldnn_linear_backward_weights_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & weight, bool bias_defined) { + return at::_ops::mkldnn_linear_backward_weights_out::redispatch(dispatchKeySet, grad_output, input, weight, bias_defined, out0, out1); + } + + // aten::mkldnn_linear_backward_weights.out(Tensor grad_output, Tensor input, Tensor weight, bool bias_defined, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple mkldnn_linear_backward_weights_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & weight, bool bias_defined, at::Tensor & out0, at::Tensor & out1) { + return at::_ops::mkldnn_linear_backward_weights_out::redispatch(dispatchKeySet, grad_output, input, weight, bias_defined, out0, out1); + } + + // aten::mkldnn_linear_backward.out(Tensor self, Tensor grad_output, Tensor weight, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple mkldnn_linear_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & self, const at::Tensor & grad_output, const at::Tensor & weight, ::std::array output_mask) { + return at::_ops::mkldnn_linear_backward_out::redispatch(dispatchKeySet, self, grad_output, weight, output_mask, out0, out1, out2); + } + + // aten::mkldnn_linear_backward.out(Tensor self, Tensor grad_output, Tensor weight, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) 
out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple mkldnn_linear_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & grad_output, const at::Tensor & weight, ::std::array output_mask, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) { + return at::_ops::mkldnn_linear_backward_out::redispatch(dispatchKeySet, self, grad_output, weight, output_mask, out0, out1, out2); + } + + // aten::matmul_backward.out(Tensor grad, Tensor self, Tensor other, bool[2] mask, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple matmul_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & grad, const at::Tensor & self, const at::Tensor & other, ::std::array mask) { + return at::_ops::matmul_backward_out::redispatch(dispatchKeySet, grad, self, other, mask, out0, out1); + } + + // aten::matmul_backward.out(Tensor grad, Tensor self, Tensor other, bool[2] mask, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple matmul_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & self, const at::Tensor & other, ::std::array mask, at::Tensor & out0, at::Tensor & out1) { + return at::_ops::matmul_backward_out::redispatch(dispatchKeySet, grad, self, other, mask, out0, out1); + } + + // aten::_aminmax.out(Tensor self, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple _aminmax_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & self) { + return at::_ops::_aminmax_out::redispatch(dispatchKeySet, self, out0, out1); + } + + // aten::_aminmax.out(Tensor self, *, Tensor(a!) out0, Tensor(b!) 
out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple _aminmax_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out0, at::Tensor & out1) { + return at::_ops::_aminmax_out::redispatch(dispatchKeySet, self, out0, out1); + } + + // aten::_aminmax.dim_out(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple _aminmax_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & self, int64_t dim, bool keepdim=false) { + return at::_ops::_aminmax_dim_out::redispatch(dispatchKeySet, self, dim, keepdim, out0, out1); + } + + // aten::_aminmax.dim_out(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple _aminmax_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool keepdim, at::Tensor & out0, at::Tensor & out1) { + return at::_ops::_aminmax_dim_out::redispatch(dispatchKeySet, self, dim, keepdim, out0, out1); + } + + // aten::max_pool2d_backward.out(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & max_pool2d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, at::IntArrayRef dilation=1, bool ceil_mode=false) { + return at::_ops::max_pool2d_backward_out::redispatch(dispatchKeySet, grad_output, self, kernel_size, stride, padding, dilation, ceil_mode, out); + } + + // aten::max_pool2d_backward.out(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & max_pool2d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, at::Tensor & out) { + return at::_ops::max_pool2d_backward_out::redispatch(dispatchKeySet, grad_output, self, kernel_size, stride, padding, dilation, ceil_mode, out); + } + + // aten::mkldnn_max_pool2d.out(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & mkldnn_max_pool2d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, at::IntArrayRef dilation=1, bool ceil_mode=false) { + return at::_ops::mkldnn_max_pool2d_out::redispatch(dispatchKeySet, self, kernel_size, stride, padding, dilation, ceil_mode, out); + } + + // aten::mkldnn_max_pool2d.out(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & mkldnn_max_pool2d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, at::Tensor & out) { + return at::_ops::mkldnn_max_pool2d_out::redispatch(dispatchKeySet, self, kernel_size, stride, padding, dilation, ceil_mode, out); + } + + // aten::mkldnn_max_pool2d_backward.out(Tensor grad_output, Tensor output, Tensor input, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & mkldnn_max_pool2d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad_output, const at::Tensor & output, const at::Tensor & input, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, at::IntArrayRef dilation=1, bool ceil_mode=false) { + return at::_ops::mkldnn_max_pool2d_backward_out::redispatch(dispatchKeySet, grad_output, output, input, kernel_size, stride, padding, dilation, ceil_mode, out); + } + + // aten::mkldnn_max_pool2d_backward.out(Tensor grad_output, Tensor output, Tensor input, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & mkldnn_max_pool2d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & output, const at::Tensor & input, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, at::Tensor & out) { + return at::_ops::mkldnn_max_pool2d_backward_out::redispatch(dispatchKeySet, grad_output, output, input, kernel_size, stride, padding, dilation, ceil_mode, out); + } + + // aten::mkldnn_max_pool3d.out(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & mkldnn_max_pool3d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, at::IntArrayRef dilation=1, bool ceil_mode=false) { + return at::_ops::mkldnn_max_pool3d_out::redispatch(dispatchKeySet, self, kernel_size, stride, padding, dilation, ceil_mode, out); + } + + // aten::mkldnn_max_pool3d.out(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & mkldnn_max_pool3d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, at::Tensor & out) { + return at::_ops::mkldnn_max_pool3d_out::redispatch(dispatchKeySet, self, kernel_size, stride, padding, dilation, ceil_mode, out); + } + + // aten::mkldnn_max_pool3d_backward.out(Tensor grad_output, Tensor output, Tensor input, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & mkldnn_max_pool3d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad_output, const at::Tensor & output, const at::Tensor & input, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, at::IntArrayRef dilation=1, bool ceil_mode=false) { + return at::_ops::mkldnn_max_pool3d_backward_out::redispatch(dispatchKeySet, grad_output, output, input, kernel_size, stride, padding, dilation, ceil_mode, out); + } + + // aten::mkldnn_max_pool3d_backward.out(Tensor grad_output, Tensor output, Tensor input, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & mkldnn_max_pool3d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & output, const at::Tensor & input, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, at::Tensor & out) { + return at::_ops::mkldnn_max_pool3d_backward_out::redispatch(dispatchKeySet, grad_output, output, input, kernel_size, stride, padding, dilation, ceil_mode, out); + } + + // aten::quantized_max_pool1d.out(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, int[1] dilation=1, bool ceil_mode=False, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & quantized_max_pool1d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, at::IntArrayRef dilation=1, bool ceil_mode=false) { + return at::_ops::quantized_max_pool1d_out::redispatch(dispatchKeySet, self, kernel_size, stride, padding, dilation, ceil_mode, out); + } + + // aten::quantized_max_pool1d.out(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, int[1] dilation=1, bool ceil_mode=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & quantized_max_pool1d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, at::Tensor & out) { + return at::_ops::quantized_max_pool1d_out::redispatch(dispatchKeySet, self, kernel_size, stride, padding, dilation, ceil_mode, out); + } + + // aten::quantized_max_pool2d.out(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & quantized_max_pool2d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, at::IntArrayRef dilation=1, bool ceil_mode=false) { + return at::_ops::quantized_max_pool2d_out::redispatch(dispatchKeySet, self, kernel_size, stride, padding, dilation, ceil_mode, out); + } + + // aten::quantized_max_pool2d.out(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & quantized_max_pool2d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, at::Tensor & out) { + return at::_ops::quantized_max_pool2d_out::redispatch(dispatchKeySet, self, kernel_size, stride, padding, dilation, ceil_mode, out); + } + + // aten::quantized_max_pool3d.out(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & quantized_max_pool3d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, at::IntArrayRef dilation=1, bool ceil_mode=false) { + return at::_ops::quantized_max_pool3d_out::redispatch(dispatchKeySet, self, kernel_size, stride, padding, dilation, ceil_mode, out); + } + + // aten::quantized_max_pool3d.out(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & quantized_max_pool3d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, at::Tensor & out) { + return at::_ops::quantized_max_pool3d_out::redispatch(dispatchKeySet, self, kernel_size, stride, padding, dilation, ceil_mode, out); + } + + // aten::median.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & median_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::median_out::redispatch(dispatchKeySet, self, out); + } + + // aten::median.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & median_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::median_out::redispatch(dispatchKeySet, self, out); + } + + // aten::nanmedian.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & nanmedian_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::nanmedian_out::redispatch(dispatchKeySet, self, out); + } + + // aten::nanmedian.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & nanmedian_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::nanmedian_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_mps_convolution.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _mps_convolution_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, at::IntArrayRef padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups) { + return at::_ops::_mps_convolution_out::redispatch(dispatchKeySet, self, weight, bias, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, out); + } + + // aten::_mps_convolution.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & _mps_convolution_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, at::IntArrayRef padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, at::Tensor & out) { + return at::_ops::_mps_convolution_out::redispatch(dispatchKeySet, self, weight, bias, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, out); + } + + // aten::_mps_convolution.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _mps_convolution_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups) { + return at::_ops::_mps_convolution_out::redispatch(dispatchKeySet, self, weight, bias, padding, stride, dilation, groups, out); + } + + // aten::_mps_convolution.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _mps_convolution_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, at::Tensor & out) { + return at::_ops::_mps_convolution_out::redispatch(dispatchKeySet, self, weight, bias, padding, stride, dilation, groups, out); + } + + // aten::mps_convolution_backward.out(Tensor self, Tensor grad_output, Tensor weight, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) 
out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple mps_convolution_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & self, const at::Tensor & grad_output, const at::Tensor & weight, at::IntArrayRef padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, ::std::array output_mask) { + return at::_ops::mps_convolution_backward_out::redispatch(dispatchKeySet, self, grad_output, weight, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, output_mask, out0, out1, out2); + } + + // aten::mps_convolution_backward.out(Tensor self, Tensor grad_output, Tensor weight, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple mps_convolution_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & grad_output, const at::Tensor & weight, at::IntArrayRef padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, ::std::array output_mask, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) { + return at::_ops::mps_convolution_backward_out::redispatch(dispatchKeySet, self, grad_output, weight, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, output_mask, out0, out1, out2); + } + + // aten::mps_convolution_backward.out(Tensor self, Tensor grad_output, Tensor weight, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) 
out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple mps_convolution_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & self, const at::Tensor & grad_output, const at::Tensor & weight, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, ::std::array output_mask) { + return at::_ops::mps_convolution_backward_out::redispatch(dispatchKeySet, self, grad_output, weight, padding, stride, dilation, groups, output_mask, out0, out1, out2); + } + + // aten::mps_convolution_backward.out(Tensor self, Tensor grad_output, Tensor weight, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple mps_convolution_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & grad_output, const at::Tensor & weight, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, ::std::array output_mask, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) { + return at::_ops::mps_convolution_backward_out::redispatch(dispatchKeySet, self, grad_output, weight, padding, stride, dilation, groups, output_mask, out0, out1, out2); + } + + // aten::mkldnn_convolution.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & mkldnn_convolution_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, at::IntArrayRef padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups) { + return at::_ops::mkldnn_convolution_out::redispatch(dispatchKeySet, self, weight, bias, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, out); + } + + // aten::mkldnn_convolution.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & mkldnn_convolution_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, at::IntArrayRef padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, at::Tensor & out) { + return at::_ops::mkldnn_convolution_out::redispatch(dispatchKeySet, self, weight, bias, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, out); + } + + // aten::mkldnn_convolution.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & mkldnn_convolution_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups) { + return at::_ops::mkldnn_convolution_out::redispatch(dispatchKeySet, self, weight, bias, padding, stride, dilation, groups, out); + } + + // aten::mkldnn_convolution.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & mkldnn_convolution_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, at::Tensor & out) { + return at::_ops::mkldnn_convolution_out::redispatch(dispatchKeySet, self, weight, bias, padding, stride, dilation, groups, out); + } + + // aten::mkldnn_rnn_layer.out(Tensor input, Tensor weight0, Tensor weight1, Tensor weight2, Tensor weight3, Tensor hx_, Tensor cx_, bool reverse, int[] batch_sizes, int mode, int hidden_size, int num_layers, bool has_biases, bool bidirectional, bool batch_first, bool train, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!)) + inline ::std::tuple mkldnn_rnn_layer_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3, const at::Tensor & input, const at::Tensor & weight0, const at::Tensor & weight1, const at::Tensor & weight2, const at::Tensor & weight3, const at::Tensor & hx_, const at::Tensor & cx_, bool reverse, at::IntArrayRef batch_sizes, int64_t mode, int64_t hidden_size, int64_t num_layers, bool has_biases, bool bidirectional, bool batch_first, bool train) { + return at::_ops::mkldnn_rnn_layer_out::redispatch(dispatchKeySet, input, weight0, weight1, weight2, weight3, hx_, cx_, reverse, batch_sizes, mode, hidden_size, num_layers, has_biases, bidirectional, batch_first, train, out0, out1, out2, out3); + } + + // aten::mkldnn_rnn_layer.out(Tensor input, Tensor weight0, Tensor weight1, Tensor weight2, Tensor weight3, Tensor hx_, Tensor cx_, bool reverse, int[] batch_sizes, int mode, int hidden_size, int num_layers, bool has_biases, bool bidirectional, bool batch_first, bool train, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) 
out3) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!)) + inline ::std::tuple mkldnn_rnn_layer_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight0, const at::Tensor & weight1, const at::Tensor & weight2, const at::Tensor & weight3, const at::Tensor & hx_, const at::Tensor & cx_, bool reverse, at::IntArrayRef batch_sizes, int64_t mode, int64_t hidden_size, int64_t num_layers, bool has_biases, bool bidirectional, bool batch_first, bool train, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3) { + return at::_ops::mkldnn_rnn_layer_out::redispatch(dispatchKeySet, input, weight0, weight1, weight2, weight3, hx_, cx_, reverse, batch_sizes, mode, hidden_size, num_layers, has_biases, bidirectional, batch_first, train, out0, out1, out2, out3); + } + + // aten::mkldnn_rnn_layer_backward.out(Tensor input, Tensor weight1, Tensor weight2, Tensor weight3, Tensor weight4, Tensor hx_, Tensor cx_tmp, Tensor output, Tensor hy_, Tensor cy_, Tensor? grad_output, Tensor? grad_hy, Tensor? grad_cy, bool reverse, int mode, int hidden_size, int num_layers, bool has_biases, bool train, bool bidirectional, int[] batch_sizes, bool batch_first, Tensor workspace, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3, Tensor(e!) out4, Tensor(f!) out5, Tensor(g!) 
out6) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!), Tensor(e!), Tensor(f!), Tensor(g!)) + inline ::std::tuple mkldnn_rnn_layer_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3, at::Tensor & out4, at::Tensor & out5, at::Tensor & out6, const at::Tensor & input, const at::Tensor & weight1, const at::Tensor & weight2, const at::Tensor & weight3, const at::Tensor & weight4, const at::Tensor & hx_, const at::Tensor & cx_tmp, const at::Tensor & output, const at::Tensor & hy_, const at::Tensor & cy_, const ::std::optional & grad_output, const ::std::optional & grad_hy, const ::std::optional & grad_cy, bool reverse, int64_t mode, int64_t hidden_size, int64_t num_layers, bool has_biases, bool train, bool bidirectional, at::IntArrayRef batch_sizes, bool batch_first, const at::Tensor & workspace) { + return at::_ops::mkldnn_rnn_layer_backward_out::redispatch(dispatchKeySet, input, weight1, weight2, weight3, weight4, hx_, cx_tmp, output, hy_, cy_, grad_output, grad_hy, grad_cy, reverse, mode, hidden_size, num_layers, has_biases, train, bidirectional, batch_sizes, batch_first, workspace, out0, out1, out2, out3, out4, out5, out6); + } + + // aten::mkldnn_rnn_layer_backward.out(Tensor input, Tensor weight1, Tensor weight2, Tensor weight3, Tensor weight4, Tensor hx_, Tensor cx_tmp, Tensor output, Tensor hy_, Tensor cy_, Tensor? grad_output, Tensor? grad_hy, Tensor? grad_cy, bool reverse, int mode, int hidden_size, int num_layers, bool has_biases, bool train, bool bidirectional, int[] batch_sizes, bool batch_first, Tensor workspace, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3, Tensor(e!) out4, Tensor(f!) out5, Tensor(g!) 
out6) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!), Tensor(e!), Tensor(f!), Tensor(g!)) + inline ::std::tuple mkldnn_rnn_layer_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight1, const at::Tensor & weight2, const at::Tensor & weight3, const at::Tensor & weight4, const at::Tensor & hx_, const at::Tensor & cx_tmp, const at::Tensor & output, const at::Tensor & hy_, const at::Tensor & cy_, const ::std::optional & grad_output, const ::std::optional & grad_hy, const ::std::optional & grad_cy, bool reverse, int64_t mode, int64_t hidden_size, int64_t num_layers, bool has_biases, bool train, bool bidirectional, at::IntArrayRef batch_sizes, bool batch_first, const at::Tensor & workspace, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3, at::Tensor & out4, at::Tensor & out5, at::Tensor & out6) { + return at::_ops::mkldnn_rnn_layer_backward_out::redispatch(dispatchKeySet, input, weight1, weight2, weight3, weight4, hx_, cx_tmp, output, hy_, cy_, grad_output, grad_hy, grad_cy, reverse, mode, hidden_size, num_layers, has_biases, train, bidirectional, batch_sizes, batch_first, workspace, out0, out1, out2, out3, out4, out5, out6); + } + + // aten::miopen_batch_norm.out(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) 
out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple miopen_batch_norm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, const ::std::optional & running_mean, const ::std::optional & running_var, bool training, double exponential_average_factor, double epsilon) { + return at::_ops::miopen_batch_norm_out::redispatch(dispatchKeySet, input, weight, bias, running_mean, running_var, training, exponential_average_factor, epsilon, out0, out1, out2); + } + + // aten::miopen_batch_norm.out(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple miopen_batch_norm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, const ::std::optional & running_mean, const ::std::optional & running_var, bool training, double exponential_average_factor, double epsilon, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) { + return at::_ops::miopen_batch_norm_out::redispatch(dispatchKeySet, input, weight, bias, running_mean, running_var, training, exponential_average_factor, epsilon, out0, out1, out2); + } + + // aten::miopen_batch_norm_backward.out(Tensor input, Tensor grad_output, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, float epsilon, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) 
out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple miopen_batch_norm_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & input, const at::Tensor & grad_output, const at::Tensor & weight, const ::std::optional & running_mean, const ::std::optional & running_var, const ::std::optional & save_mean, const ::std::optional & save_var, double epsilon) { + return at::_ops::miopen_batch_norm_backward_out::redispatch(dispatchKeySet, input, grad_output, weight, running_mean, running_var, save_mean, save_var, epsilon, out0, out1, out2); + } + + // aten::miopen_batch_norm_backward.out(Tensor input, Tensor grad_output, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, float epsilon, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple miopen_batch_norm_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & grad_output, const at::Tensor & weight, const ::std::optional & running_mean, const ::std::optional & running_var, const ::std::optional & save_mean, const ::std::optional & save_var, double epsilon, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) { + return at::_ops::miopen_batch_norm_backward_out::redispatch(dispatchKeySet, input, grad_output, weight, running_mean, running_var, save_mean, save_var, epsilon, out0, out1, out2); + } + + // aten::miopen_convolution.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & miopen_convolution_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, at::IntArrayRef padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, bool benchmark, bool deterministic) { + return at::_ops::miopen_convolution_out::redispatch(dispatchKeySet, self, weight, bias, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, benchmark, deterministic, out); + } + + // aten::miopen_convolution.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & miopen_convolution_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, at::IntArrayRef padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, bool benchmark, bool deterministic, at::Tensor & out) { + return at::_ops::miopen_convolution_out::redispatch(dispatchKeySet, self, weight, bias, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, benchmark, deterministic, out); + } + + // aten::miopen_convolution.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & miopen_convolution_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, bool benchmark, bool deterministic) { + return at::_ops::miopen_convolution_out::redispatch(dispatchKeySet, self, weight, bias, padding, stride, dilation, groups, benchmark, deterministic, out); + } + + // aten::miopen_convolution.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & miopen_convolution_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, bool benchmark, bool deterministic, at::Tensor & out) { + return at::_ops::miopen_convolution_out::redispatch(dispatchKeySet, self, weight, bias, padding, stride, dilation, groups, benchmark, deterministic, out); + } + + // aten::miopen_convolution_transpose.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & miopen_convolution_transpose_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, at::IntArrayRef padding, at::IntArrayRef output_padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, bool benchmark, bool deterministic) { + return at::_ops::miopen_convolution_transpose_out::redispatch(dispatchKeySet, self, weight, bias, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(output_padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, benchmark, deterministic, out); + } + + // aten::miopen_convolution_transpose.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & miopen_convolution_transpose_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, at::IntArrayRef padding, at::IntArrayRef output_padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, bool benchmark, bool deterministic, at::Tensor & out) { + return at::_ops::miopen_convolution_transpose_out::redispatch(dispatchKeySet, self, weight, bias, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(output_padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, benchmark, deterministic, out); + } + + // aten::miopen_convolution_transpose.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & miopen_convolution_transpose_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, bool benchmark, bool deterministic) { + return at::_ops::miopen_convolution_transpose_out::redispatch(dispatchKeySet, self, weight, bias, padding, output_padding, stride, dilation, groups, benchmark, deterministic, out); + } + + // aten::miopen_convolution_transpose.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & miopen_convolution_transpose_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, bool benchmark, bool deterministic, at::Tensor & out) { + return at::_ops::miopen_convolution_transpose_out::redispatch(dispatchKeySet, self, weight, bias, padding, output_padding, stride, dilation, groups, benchmark, deterministic, out); + } + + // aten::miopen_depthwise_convolution.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & miopen_depthwise_convolution_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, at::IntArrayRef padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, bool benchmark, bool deterministic) { + return at::_ops::miopen_depthwise_convolution_out::redispatch(dispatchKeySet, self, weight, bias, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, benchmark, deterministic, out); + } + + // aten::miopen_depthwise_convolution.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & miopen_depthwise_convolution_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, at::IntArrayRef padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, bool benchmark, bool deterministic, at::Tensor & out) { + return at::_ops::miopen_depthwise_convolution_out::redispatch(dispatchKeySet, self, weight, bias, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, benchmark, deterministic, out); + } + + // aten::miopen_depthwise_convolution.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & miopen_depthwise_convolution_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, bool benchmark, bool deterministic) { + return at::_ops::miopen_depthwise_convolution_out::redispatch(dispatchKeySet, self, weight, bias, padding, stride, dilation, groups, benchmark, deterministic, out); + } + + // aten::miopen_depthwise_convolution.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & miopen_depthwise_convolution_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, bool benchmark, bool deterministic, at::Tensor & out) { + return at::_ops::miopen_depthwise_convolution_out::redispatch(dispatchKeySet, self, weight, bias, padding, stride, dilation, groups, benchmark, deterministic, out); + } + + // aten::miopen_rnn.out(Tensor input, Tensor[] weight, int weight_stride0, Tensor hx, Tensor? cx, int mode, int hidden_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, int[] batch_sizes, Tensor? dropout_state, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3, Tensor(e!) 
out4) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!), Tensor(e!)) + inline ::std::tuple miopen_rnn_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3, at::Tensor & out4, const at::Tensor & input, at::TensorList weight, int64_t weight_stride0, const at::Tensor & hx, const ::std::optional & cx, int64_t mode, int64_t hidden_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, at::IntArrayRef batch_sizes, const ::std::optional & dropout_state) { + return at::_ops::miopen_rnn_out::redispatch(dispatchKeySet, input, weight, weight_stride0, hx, cx, mode, hidden_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state, out0, out1, out2, out3, out4); + } + + // aten::miopen_rnn.out(Tensor input, Tensor[] weight, int weight_stride0, Tensor hx, Tensor? cx, int mode, int hidden_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, int[] batch_sizes, Tensor? dropout_state, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3, Tensor(e!) 
out4) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!), Tensor(e!)) + inline ::std::tuple miopen_rnn_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::TensorList weight, int64_t weight_stride0, const at::Tensor & hx, const ::std::optional & cx, int64_t mode, int64_t hidden_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, at::IntArrayRef batch_sizes, const ::std::optional & dropout_state, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3, at::Tensor & out4) { + return at::_ops::miopen_rnn_out::redispatch(dispatchKeySet, input, weight, weight_stride0, hx, cx, mode, hidden_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state, out0, out1, out2, out3, out4); + } + + // aten::miopen_rnn_backward.out(Tensor input, Tensor[] weight, int weight_stride0, Tensor weight_buf, Tensor hx, Tensor? cx, Tensor output, Tensor? grad_output, Tensor? grad_hy, Tensor? grad_cy, int mode, int hidden_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, int[] batch_sizes, Tensor? dropout_state, Tensor reserve, bool[4] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) 
out2, Tensor(d!)[] out3) -> () + inline void miopen_rnn_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::TensorList out3, const at::Tensor & input, at::TensorList weight, int64_t weight_stride0, const at::Tensor & weight_buf, const at::Tensor & hx, const ::std::optional & cx, const at::Tensor & output, const ::std::optional & grad_output, const ::std::optional & grad_hy, const ::std::optional & grad_cy, int64_t mode, int64_t hidden_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, at::IntArrayRef batch_sizes, const ::std::optional & dropout_state, const at::Tensor & reserve, ::std::array output_mask) { + return at::_ops::miopen_rnn_backward_out::redispatch(dispatchKeySet, input, weight, weight_stride0, weight_buf, hx, cx, output, grad_output, grad_hy, grad_cy, mode, hidden_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state, reserve, output_mask, out0, out1, out2, out3); + } + + // aten::miopen_rnn_backward.out(Tensor input, Tensor[] weight, int weight_stride0, Tensor weight_buf, Tensor hx, Tensor? cx, Tensor output, Tensor? grad_output, Tensor? grad_hy, Tensor? grad_cy, int mode, int hidden_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, int[] batch_sizes, Tensor? dropout_state, Tensor reserve, bool[4] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) 
out2, Tensor(d!)[] out3) -> () + inline void miopen_rnn_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::TensorList weight, int64_t weight_stride0, const at::Tensor & weight_buf, const at::Tensor & hx, const ::std::optional & cx, const at::Tensor & output, const ::std::optional & grad_output, const ::std::optional & grad_hy, const ::std::optional & grad_cy, int64_t mode, int64_t hidden_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, at::IntArrayRef batch_sizes, const ::std::optional & dropout_state, const at::Tensor & reserve, ::std::array output_mask, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::TensorList out3) { + return at::_ops::miopen_rnn_backward_out::redispatch(dispatchKeySet, input, weight, weight_stride0, weight_buf, hx, cx, output, grad_output, grad_hy, grad_cy, mode, hidden_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state, reserve, output_mask, out0, out1, out2, out3); + } + + // aten::_sparse_sparse_matmul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _sparse_sparse_matmul_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::_sparse_sparse_matmul_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::_sparse_sparse_matmul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _sparse_sparse_matmul_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::_sparse_sparse_matmul_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::mul.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & mul_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::mul_Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::mul.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & mul_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, at::Tensor & out) { + return at::_ops::mul_Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::_native_batch_norm_legit_functional(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor, Tensor running_mean_out, Tensor running_var_out) + inline ::std::tuple _native_batch_norm_legit_functional(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const at::Tensor & running_mean, const at::Tensor & running_var, bool training, double momentum, double eps) { + return at::_ops::_native_batch_norm_legit_functional::redispatch(dispatchKeySet, input, weight, bias, running_mean, running_var, training, momentum, eps); + } + + // aten::_native_batch_norm_legit_no_training.out(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, float momentum, float eps, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) 
out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple _native_batch_norm_legit_no_training_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const at::Tensor & running_mean, const at::Tensor & running_var, double momentum, double eps) { + return at::_ops::_native_batch_norm_legit_no_training_out::redispatch(dispatchKeySet, input, weight, bias, running_mean, running_var, momentum, eps, out0, out1, out2); + } + + // aten::_native_batch_norm_legit_no_training.out(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, float momentum, float eps, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple _native_batch_norm_legit_no_training_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const at::Tensor & running_mean, const at::Tensor & running_var, double momentum, double eps, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) { + return at::_ops::_native_batch_norm_legit_no_training_out::redispatch(dispatchKeySet, input, weight, bias, running_mean, running_var, momentum, eps, out0, out1, out2); + } + + // aten::batch_norm_stats.out(Tensor input, float eps, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple batch_norm_stats_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & input, double eps) { + return at::_ops::batch_norm_stats_out::redispatch(dispatchKeySet, input, eps, out0, out1); + } + + // aten::batch_norm_stats.out(Tensor input, float eps, *, Tensor(a!) out0, Tensor(b!) 
out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple batch_norm_stats_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, double eps, at::Tensor & out0, at::Tensor & out1) { + return at::_ops::batch_norm_stats_out::redispatch(dispatchKeySet, input, eps, out0, out1); + } + + // aten::batch_norm_gather_stats.out(Tensor input, Tensor mean, Tensor invstd, Tensor? running_mean, Tensor? running_var, float momentum, float eps, int count, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple batch_norm_gather_stats_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & running_mean, const ::std::optional & running_var, double momentum, double eps, int64_t count) { + return at::_ops::batch_norm_gather_stats_out::redispatch(dispatchKeySet, input, mean, invstd, running_mean, running_var, momentum, eps, count, out0, out1); + } + + // aten::batch_norm_gather_stats.out(Tensor input, Tensor mean, Tensor invstd, Tensor? running_mean, Tensor? running_var, float momentum, float eps, int count, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple batch_norm_gather_stats_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & running_mean, const ::std::optional & running_var, double momentum, double eps, int64_t count, at::Tensor & out0, at::Tensor & out1) { + return at::_ops::batch_norm_gather_stats_out::redispatch(dispatchKeySet, input, mean, invstd, running_mean, running_var, momentum, eps, count, out0, out1); + } + + // aten::batch_norm_gather_stats_with_counts.out(Tensor input, Tensor mean, Tensor invstd, Tensor? running_mean, Tensor? running_var, float momentum, float eps, Tensor counts, *, Tensor(a!) out0, Tensor(b!) 
out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple batch_norm_gather_stats_with_counts_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & running_mean, const ::std::optional & running_var, double momentum, double eps, const at::Tensor & counts) { + return at::_ops::batch_norm_gather_stats_with_counts_out::redispatch(dispatchKeySet, input, mean, invstd, running_mean, running_var, momentum, eps, counts, out0, out1); + } + + // aten::batch_norm_gather_stats_with_counts.out(Tensor input, Tensor mean, Tensor invstd, Tensor? running_mean, Tensor? running_var, float momentum, float eps, Tensor counts, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple batch_norm_gather_stats_with_counts_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & running_mean, const ::std::optional & running_var, double momentum, double eps, const at::Tensor & counts, at::Tensor & out0, at::Tensor & out1) { + return at::_ops::batch_norm_gather_stats_with_counts_out::redispatch(dispatchKeySet, input, mean, invstd, running_mean, running_var, momentum, eps, counts, out0, out1); + } + + // aten::native_batch_norm_backward.out(Tensor grad_out, Tensor input, Tensor? weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_invstd, bool train, float eps, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) 
out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple native_batch_norm_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & grad_out, const at::Tensor & input, const ::std::optional & weight, const ::std::optional & running_mean, const ::std::optional & running_var, const ::std::optional & save_mean, const ::std::optional & save_invstd, bool train, double eps, ::std::array output_mask) { + return at::_ops::native_batch_norm_backward_out::redispatch(dispatchKeySet, grad_out, input, weight, running_mean, running_var, save_mean, save_invstd, train, eps, output_mask, out0, out1, out2); + } + + // aten::native_batch_norm_backward.out(Tensor grad_out, Tensor input, Tensor? weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_invstd, bool train, float eps, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple native_batch_norm_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out, const at::Tensor & input, const ::std::optional & weight, const ::std::optional & running_mean, const ::std::optional & running_var, const ::std::optional & save_mean, const ::std::optional & save_invstd, bool train, double eps, ::std::array output_mask, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) { + return at::_ops::native_batch_norm_backward_out::redispatch(dispatchKeySet, grad_out, input, weight, running_mean, running_var, save_mean, save_invstd, train, eps, output_mask, out0, out1, out2); + } + + // aten::batch_norm_backward_reduce.out(Tensor grad_out, Tensor input, Tensor mean, Tensor invstd, Tensor? weight, bool input_g, bool weight_g, bool bias_g, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) 
out3) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!)) + inline ::std::tuple batch_norm_backward_reduce_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3, const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & weight, bool input_g, bool weight_g, bool bias_g) { + return at::_ops::batch_norm_backward_reduce_out::redispatch(dispatchKeySet, grad_out, input, mean, invstd, weight, input_g, weight_g, bias_g, out0, out1, out2, out3); + } + + // aten::batch_norm_backward_reduce.out(Tensor grad_out, Tensor input, Tensor mean, Tensor invstd, Tensor? weight, bool input_g, bool weight_g, bool bias_g, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!)) + inline ::std::tuple batch_norm_backward_reduce_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & weight, bool input_g, bool weight_g, bool bias_g, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3) { + return at::_ops::batch_norm_backward_reduce_out::redispatch(dispatchKeySet, grad_out, input, mean, invstd, weight, input_g, weight_g, bias_g, out0, out1, out2, out3); + } + + // aten::batch_norm_backward_elemt.out(Tensor grad_out, Tensor input, Tensor mean, Tensor invstd, Tensor? weight, Tensor sum_dy, Tensor sum_dy_xmu, Tensor count, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & batch_norm_backward_elemt_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & weight, const at::Tensor & sum_dy, const at::Tensor & sum_dy_xmu, const at::Tensor & count) { + return at::_ops::batch_norm_backward_elemt_out::redispatch(dispatchKeySet, grad_out, input, mean, invstd, weight, sum_dy, sum_dy_xmu, count, out); + } + + // aten::batch_norm_backward_elemt.out(Tensor grad_out, Tensor input, Tensor mean, Tensor invstd, Tensor? weight, Tensor sum_dy, Tensor sum_dy_xmu, Tensor count, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & batch_norm_backward_elemt_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & weight, const at::Tensor & sum_dy, const at::Tensor & sum_dy_xmu, const at::Tensor & count, at::Tensor & out) { + return at::_ops::batch_norm_backward_elemt_out::redispatch(dispatchKeySet, grad_out, input, mean, invstd, weight, sum_dy, sum_dy_xmu, count, out); + } + + // aten::batch_norm_update_stats.out(Tensor input, Tensor? running_mean, Tensor? running_var, float momentum, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple batch_norm_update_stats_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & input, const ::std::optional & running_mean, const ::std::optional & running_var, double momentum) { + return at::_ops::batch_norm_update_stats_out::redispatch(dispatchKeySet, input, running_mean, running_var, momentum, out0, out1); + } + + // aten::batch_norm_update_stats.out(Tensor input, Tensor? running_mean, Tensor? running_var, float momentum, *, Tensor(a!) out0, Tensor(b!) 
out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple batch_norm_update_stats_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const ::std::optional & running_mean, const ::std::optional & running_var, double momentum, at::Tensor & out0, at::Tensor & out1) { + return at::_ops::batch_norm_update_stats_out::redispatch(dispatchKeySet, input, running_mean, running_var, momentum, out0, out1); + } + + // aten::_nnpack_spatial_convolution.out(Tensor input, Tensor weight, Tensor? bias, SymInt[2] padding, SymInt[2] stride=1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _nnpack_spatial_convolution_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, at::IntArrayRef padding, at::IntArrayRef stride=1) { + return at::_ops::_nnpack_spatial_convolution_out::redispatch(dispatchKeySet, input, weight, bias, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(stride), out); + } + + // aten::_nnpack_spatial_convolution.out(Tensor input, Tensor weight, Tensor? bias, SymInt[2] padding, SymInt[2] stride=1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _nnpack_spatial_convolution_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, at::IntArrayRef padding, at::IntArrayRef stride, at::Tensor & out) { + return at::_ops::_nnpack_spatial_convolution_out::redispatch(dispatchKeySet, input, weight, bias, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(stride), out); + } + + // aten::_nnpack_spatial_convolution.out(Tensor input, Tensor weight, Tensor? bias, SymInt[2] padding, SymInt[2] stride=1, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & _nnpack_spatial_convolution_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride=c10::SymInt(1)) { + return at::_ops::_nnpack_spatial_convolution_out::redispatch(dispatchKeySet, input, weight, bias, padding, stride, out); + } + + // aten::_nnpack_spatial_convolution.out(Tensor input, Tensor weight, Tensor? bias, SymInt[2] padding, SymInt[2] stride=1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _nnpack_spatial_convolution_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, at::Tensor & out) { + return at::_ops::_nnpack_spatial_convolution_out::redispatch(dispatchKeySet, input, weight, bias, padding, stride, out); + } + + // aten::ones.names_out(int[] size, *, Dimname[]? names, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & ones_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::IntArrayRef size, ::std::optional names) { + return at::_ops::ones_names_out::redispatch(dispatchKeySet, size, names, out); + } + + // aten::ones.names_out(int[] size, *, Dimname[]? names, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & ones_outf(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, ::std::optional names, at::Tensor & out) { + return at::_ops::ones_names_out::redispatch(dispatchKeySet, size, names, out); + } + + // aten::ones_like.out(Tensor self, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & ones_like_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, ::std::optional memory_format=::std::nullopt) { + return at::_ops::ones_like_out::redispatch(dispatchKeySet, self, memory_format, out); + } + + // aten::ones_like.out(Tensor self, *, MemoryFormat? 
memory_format=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & ones_like_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional memory_format, at::Tensor & out) { + return at::_ops::ones_like_out::redispatch(dispatchKeySet, self, memory_format, out); + } + + // aten::_euclidean_dist.out(Tensor x1, Tensor x2, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _euclidean_dist_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x1, const at::Tensor & x2) { + return at::_ops::_euclidean_dist_out::redispatch(dispatchKeySet, x1, x2, out); + } + + // aten::_euclidean_dist.out(Tensor x1, Tensor x2, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _euclidean_dist_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x1, const at::Tensor & x2, at::Tensor & out) { + return at::_ops::_euclidean_dist_out::redispatch(dispatchKeySet, x1, x2, out); + } + + // aten::_cdist_forward.out(Tensor x1, Tensor x2, float p, int? compute_mode, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _cdist_forward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x1, const at::Tensor & x2, double p, ::std::optional compute_mode) { + return at::_ops::_cdist_forward_out::redispatch(dispatchKeySet, x1, x2, p, compute_mode, out); + } + + // aten::_cdist_forward.out(Tensor x1, Tensor x2, float p, int? compute_mode, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _cdist_forward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x1, const at::Tensor & x2, double p, ::std::optional compute_mode, at::Tensor & out) { + return at::_ops::_cdist_forward_out::redispatch(dispatchKeySet, x1, x2, p, compute_mode, out); + } + + // aten::_cdist_backward.out(Tensor grad, Tensor x1, Tensor x2, float p, Tensor cdist, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & _cdist_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad, const at::Tensor & x1, const at::Tensor & x2, double p, const at::Tensor & cdist) { + return at::_ops::_cdist_backward_out::redispatch(dispatchKeySet, grad, x1, x2, p, cdist, out); + } + + // aten::_cdist_backward.out(Tensor grad, Tensor x1, Tensor x2, float p, Tensor cdist, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _cdist_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & x1, const at::Tensor & x2, double p, const at::Tensor & cdist, at::Tensor & out) { + return at::_ops::_cdist_backward_out::redispatch(dispatchKeySet, grad, x1, x2, p, cdist, out); + } + + // aten::_pdist_forward.out(Tensor self, float p=2, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _pdist_forward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, double p=2) { + return at::_ops::_pdist_forward_out::redispatch(dispatchKeySet, self, p, out); + } + + // aten::_pdist_forward.out(Tensor self, float p=2, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _pdist_forward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double p, at::Tensor & out) { + return at::_ops::_pdist_forward_out::redispatch(dispatchKeySet, self, p, out); + } + + // aten::_pdist_backward.out(Tensor grad, Tensor self, float p, Tensor pdist, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _pdist_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad, const at::Tensor & self, double p, const at::Tensor & pdist) { + return at::_ops::_pdist_backward_out::redispatch(dispatchKeySet, grad, self, p, pdist, out); + } + + // aten::_pdist_backward.out(Tensor grad, Tensor self, float p, Tensor pdist, *, Tensor(a!) out) -> Tensor(a!) 
+    // NOTE: each wrapper below forwards to the matching at::_ops::<op>::redispatch
+    // with the caller's dispatchKeySet and returns its result. `*_out` overloads
+    // take `out` first, `*_outf` overloads take it last; both pass it as the
+    // final redispatch argument.
+    inline at::Tensor & _pdist_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & self, double p, const at::Tensor & pdist, at::Tensor & out) {
+        return at::_ops::_pdist_backward_out::redispatch(dispatchKeySet, grad, self, p, pdist, out);
+    }
+
+    // aten::pixel_shuffle.out(Tensor self, int upscale_factor, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & pixel_shuffle_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t upscale_factor) {
+        return at::_ops::pixel_shuffle_out::redispatch(dispatchKeySet, self, upscale_factor, out);
+    }
+
+    // aten::pixel_shuffle.out(Tensor self, int upscale_factor, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & pixel_shuffle_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t upscale_factor, at::Tensor & out) {
+        return at::_ops::pixel_shuffle_out::redispatch(dispatchKeySet, self, upscale_factor, out);
+    }
+
+    // aten::pixel_unshuffle.out(Tensor self, int downscale_factor, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & pixel_unshuffle_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t downscale_factor) {
+        return at::_ops::pixel_unshuffle_out::redispatch(dispatchKeySet, self, downscale_factor, out);
+    }
+
+    // aten::pixel_unshuffle.out(Tensor self, int downscale_factor, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & pixel_unshuffle_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t downscale_factor, at::Tensor & out) {
+        return at::_ops::pixel_unshuffle_out::redispatch(dispatchKeySet, self, downscale_factor, out);
+    }
+
+    // aten::channel_shuffle.out(Tensor self, SymInt groups, *, Tensor(a!) out) -> Tensor(a!)
+    // NOTE: each wrapper below forwards to the matching at::_ops::<op>::redispatch
+    // with the caller's dispatchKeySet and returns its result. `*_out` overloads
+    // take `out` first, `*_outf` overloads take it last; both pass it as the
+    // final redispatch argument. The int64_t and `*_symint_*` overloads of an op
+    // forward to the same redispatch target.
+    inline at::Tensor & channel_shuffle_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t groups) {
+        return at::_ops::channel_shuffle_out::redispatch(dispatchKeySet, self, groups, out);
+    }
+
+    // aten::channel_shuffle.out(Tensor self, SymInt groups, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & channel_shuffle_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t groups, at::Tensor & out) {
+        return at::_ops::channel_shuffle_out::redispatch(dispatchKeySet, self, groups, out);
+    }
+
+    // aten::channel_shuffle.out(Tensor self, SymInt groups, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & channel_shuffle_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymInt groups) {
+        return at::_ops::channel_shuffle_out::redispatch(dispatchKeySet, self, groups, out);
+    }
+
+    // aten::channel_shuffle.out(Tensor self, SymInt groups, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & channel_shuffle_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymInt groups, at::Tensor & out) {
+        return at::_ops::channel_shuffle_out::redispatch(dispatchKeySet, self, groups, out);
+    }
+
+    // aten::_pin_memory.out(Tensor self, Device? device=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _pin_memory_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, ::std::optional<at::Device> device=::std::nullopt) {
+        return at::_ops::_pin_memory_out::redispatch(dispatchKeySet, self, device, out);
+    }
+
+    // aten::_pin_memory.out(Tensor self, Device? device=None, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _pin_memory_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional<at::Device> device, at::Tensor & out) {
+        return at::_ops::_pin_memory_out::redispatch(dispatchKeySet, self, device, out);
+    }
+
+    // aten::scalar_tensor.out(Scalar s, *, Tensor(a!) out) -> Tensor(a!)
+    // NOTE: each wrapper below forwards to the matching at::_ops::<op>::redispatch
+    // with the caller's dispatchKeySet and returns its result. `*_out` overloads
+    // take `out` first, `*_outf` overloads take it last; both pass it as the
+    // final redispatch argument. Non-symint overloads convert their IntArrayRef
+    // argument with c10::fromIntArrayRefSlow; `*_symint_*` overloads forward
+    // their SymIntArrayRef argument as-is.
+    inline at::Tensor & scalar_tensor_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & s) {
+        return at::_ops::scalar_tensor_out::redispatch(dispatchKeySet, s, out);
+    }
+
+    // aten::scalar_tensor.out(Scalar s, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & scalar_tensor_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & s, at::Tensor & out) {
+        return at::_ops::scalar_tensor_out::redispatch(dispatchKeySet, s, out);
+    }
+
+    // aten::rand.names_out(SymInt[] size, *, Dimname[]? names, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & rand_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::IntArrayRef size, ::std::optional<at::DimnameList> names) {
+        return at::_ops::rand_names_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), names, out);
+    }
+
+    // aten::rand.names_out(SymInt[] size, *, Dimname[]? names, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & rand_outf(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, ::std::optional<at::DimnameList> names, at::Tensor & out) {
+        return at::_ops::rand_names_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), names, out);
+    }
+
+    // aten::rand.names_out(SymInt[] size, *, Dimname[]? names, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & rand_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, c10::SymIntArrayRef size, ::std::optional<at::DimnameList> names) {
+        return at::_ops::rand_names_out::redispatch(dispatchKeySet, size, names, out);
+    }
+
+    // aten::rand.names_out(SymInt[] size, *, Dimname[]? names, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & rand_symint_outf(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, ::std::optional<at::DimnameList> names, at::Tensor & out) {
+        return at::_ops::rand_names_out::redispatch(dispatchKeySet, size, names, out);
+    }
+
+    // aten::rand.generator_with_names_out(SymInt[] size, *, Generator? generator, Dimname[]? names, Tensor(a!) out) -> Tensor(a!)
+    // NOTE: each wrapper below forwards to the matching at::_ops::<op>::redispatch
+    // with the caller's dispatchKeySet and returns its result. `*_out` overloads
+    // take `out` first, `*_outf` overloads take it last; both pass it as the
+    // final redispatch argument. Non-symint overloads convert their IntArrayRef
+    // argument with c10::fromIntArrayRefSlow; `*_symint_*` overloads forward
+    // their SymInt / SymIntArrayRef arguments as-is.
+    inline at::Tensor & rand_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::IntArrayRef size, ::std::optional<at::Generator> generator, ::std::optional<at::DimnameList> names) {
+        return at::_ops::rand_generator_with_names_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), generator, names, out);
+    }
+
+    // aten::rand.generator_with_names_out(SymInt[] size, *, Generator? generator, Dimname[]? names, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & rand_outf(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, ::std::optional<at::Generator> generator, ::std::optional<at::DimnameList> names, at::Tensor & out) {
+        return at::_ops::rand_generator_with_names_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), generator, names, out);
+    }
+
+    // aten::rand.generator_with_names_out(SymInt[] size, *, Generator? generator, Dimname[]? names, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & rand_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, c10::SymIntArrayRef size, ::std::optional<at::Generator> generator, ::std::optional<at::DimnameList> names) {
+        return at::_ops::rand_generator_with_names_out::redispatch(dispatchKeySet, size, generator, names, out);
+    }
+
+    // aten::rand.generator_with_names_out(SymInt[] size, *, Generator? generator, Dimname[]? names, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & rand_symint_outf(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, ::std::optional<at::Generator> generator, ::std::optional<at::DimnameList> names, at::Tensor & out) {
+        return at::_ops::rand_generator_with_names_out::redispatch(dispatchKeySet, size, generator, names, out);
+    }
+
+    // aten::rand_like.out(Tensor self, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & rand_like_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, ::std::optional<at::MemoryFormat> memory_format=::std::nullopt) {
+        return at::_ops::rand_like_out::redispatch(dispatchKeySet, self, memory_format, out);
+    }
+
+    // aten::rand_like.out(Tensor self, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & rand_like_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional<at::MemoryFormat> memory_format, at::Tensor & out) {
+        return at::_ops::rand_like_out::redispatch(dispatchKeySet, self, memory_format, out);
+    }
+
+    // aten::rand_like.generator_out(Tensor self, *, Generator? generator, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & rand_like_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, ::std::optional<at::Generator> generator, ::std::optional<at::MemoryFormat> memory_format=::std::nullopt) {
+        return at::_ops::rand_like_generator_out::redispatch(dispatchKeySet, self, generator, memory_format, out);
+    }
+
+    // aten::rand_like.generator_out(Tensor self, *, Generator? generator, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & rand_like_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional<at::Generator> generator, ::std::optional<at::MemoryFormat> memory_format, at::Tensor & out) {
+        return at::_ops::rand_like_generator_out::redispatch(dispatchKeySet, self, generator, memory_format, out);
+    }
+
+    // aten::randint_like.out(Tensor self, SymInt high, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & randint_like_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t high, ::std::optional<at::MemoryFormat> memory_format=::std::nullopt) {
+        return at::_ops::randint_like_out::redispatch(dispatchKeySet, self, high, memory_format, out);
+    }
+
+    // aten::randint_like.out(Tensor self, SymInt high, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & randint_like_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t high, ::std::optional<at::MemoryFormat> memory_format, at::Tensor & out) {
+        return at::_ops::randint_like_out::redispatch(dispatchKeySet, self, high, memory_format, out);
+    }
+
+    // aten::randint_like.out(Tensor self, SymInt high, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & randint_like_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymInt high, ::std::optional<at::MemoryFormat> memory_format=::std::nullopt) {
+        return at::_ops::randint_like_out::redispatch(dispatchKeySet, self, high, memory_format, out);
+    }
+
+    // aten::randint_like.out(Tensor self, SymInt high, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & randint_like_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymInt high, ::std::optional<at::MemoryFormat> memory_format, at::Tensor & out) {
+        return at::_ops::randint_like_out::redispatch(dispatchKeySet, self, high, memory_format, out);
+    }
+
+    // aten::randint_like.generator_out(Tensor self, SymInt high, *, Generator? generator, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & randint_like_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t high, ::std::optional<at::Generator> generator, ::std::optional<at::MemoryFormat> memory_format=::std::nullopt) {
+        return at::_ops::randint_like_generator_out::redispatch(dispatchKeySet, self, high, generator, memory_format, out);
+    }
+
+    // aten::randint_like.generator_out(Tensor self, SymInt high, *, Generator? generator, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & randint_like_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t high, ::std::optional<at::Generator> generator, ::std::optional<at::MemoryFormat> memory_format, at::Tensor & out) {
+        return at::_ops::randint_like_generator_out::redispatch(dispatchKeySet, self, high, generator, memory_format, out);
+    }
+
+    // aten::randint_like.generator_out(Tensor self, SymInt high, *, Generator? generator, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)
+    // NOTE: each wrapper below forwards to the matching at::_ops::<op>::redispatch
+    // with the caller's dispatchKeySet and returns its result. `*_out` overloads
+    // take `out` first, `*_outf` overloads take it last; both pass it as the
+    // final redispatch argument.
+    inline at::Tensor & randint_like_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymInt high, ::std::optional<at::Generator> generator, ::std::optional<at::MemoryFormat> memory_format=::std::nullopt) {
+        return at::_ops::randint_like_generator_out::redispatch(dispatchKeySet, self, high, generator, memory_format, out);
+    }
+
+    // aten::randint_like.generator_out(Tensor self, SymInt high, *, Generator? generator, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & randint_like_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymInt high, ::std::optional<at::Generator> generator, ::std::optional<at::MemoryFormat> memory_format, at::Tensor & out) {
+        return at::_ops::randint_like_generator_out::redispatch(dispatchKeySet, self, high, generator, memory_format, out);
+    }
+
+    // aten::randint_like.Tensor_out(Tensor self, Tensor high, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & randint_like_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & high, ::std::optional<at::MemoryFormat> memory_format=::std::nullopt) {
+        return at::_ops::randint_like_Tensor_out::redispatch(dispatchKeySet, self, high, memory_format, out);
+    }
+
+    // aten::randint_like.Tensor_out(Tensor self, Tensor high, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & randint_like_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & high, ::std::optional<at::MemoryFormat> memory_format, at::Tensor & out) {
+        return at::_ops::randint_like_Tensor_out::redispatch(dispatchKeySet, self, high, memory_format, out);
+    }
+
+    // aten::randint_like.Tensor_generator_out(Tensor self, Tensor high, *, Generator? generator, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)
+    // NOTE: each wrapper below forwards to the matching at::_ops::<op>::redispatch
+    // with the caller's dispatchKeySet and returns its result. `*_out` overloads
+    // take `out` first, `*_outf` overloads take it last; both pass it as the
+    // final redispatch argument.
+    inline at::Tensor & randint_like_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & high, ::std::optional<at::Generator> generator, ::std::optional<at::MemoryFormat> memory_format=::std::nullopt) {
+        return at::_ops::randint_like_Tensor_generator_out::redispatch(dispatchKeySet, self, high, generator, memory_format, out);
+    }
+
+    // aten::randint_like.Tensor_generator_out(Tensor self, Tensor high, *, Generator? generator, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & randint_like_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & high, ::std::optional<at::Generator> generator, ::std::optional<at::MemoryFormat> memory_format, at::Tensor & out) {
+        return at::_ops::randint_like_Tensor_generator_out::redispatch(dispatchKeySet, self, high, generator, memory_format, out);
+    }
+
+    // aten::randint_like.low_dtype_out(Tensor self, SymInt low, SymInt high, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & randint_like_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t low, int64_t high, ::std::optional<at::MemoryFormat> memory_format=::std::nullopt) {
+        return at::_ops::randint_like_low_dtype_out::redispatch(dispatchKeySet, self, low, high, memory_format, out);
+    }
+
+    // aten::randint_like.low_dtype_out(Tensor self, SymInt low, SymInt high, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & randint_like_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t low, int64_t high, ::std::optional<at::MemoryFormat> memory_format, at::Tensor & out) {
+        return at::_ops::randint_like_low_dtype_out::redispatch(dispatchKeySet, self, low, high, memory_format, out);
+    }
+
+    // aten::randint_like.low_dtype_out(Tensor self, SymInt low, SymInt high, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)
+    // NOTE: each wrapper below forwards to the matching at::_ops::<op>::redispatch
+    // with the caller's dispatchKeySet and returns its result. `*_out` overloads
+    // take `out` first, `*_outf` overloads take it last; both pass it as the
+    // final redispatch argument. Non-symint overloads convert their IntArrayRef
+    // argument with c10::fromIntArrayRefSlow; `*_symint_*` overloads forward
+    // their SymInt / SymIntArrayRef arguments as-is.
+    inline at::Tensor & randint_like_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymInt low, c10::SymInt high, ::std::optional<at::MemoryFormat> memory_format=::std::nullopt) {
+        return at::_ops::randint_like_low_dtype_out::redispatch(dispatchKeySet, self, low, high, memory_format, out);
+    }
+
+    // aten::randint_like.low_dtype_out(Tensor self, SymInt low, SymInt high, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & randint_like_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymInt low, c10::SymInt high, ::std::optional<at::MemoryFormat> memory_format, at::Tensor & out) {
+        return at::_ops::randint_like_low_dtype_out::redispatch(dispatchKeySet, self, low, high, memory_format, out);
+    }
+
+    // aten::randint_like.low_generator_dtype_out(Tensor self, SymInt low, SymInt high, *, Generator? generator, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & randint_like_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t low, int64_t high, ::std::optional<at::Generator> generator, ::std::optional<at::MemoryFormat> memory_format=::std::nullopt) {
+        return at::_ops::randint_like_low_generator_dtype_out::redispatch(dispatchKeySet, self, low, high, generator, memory_format, out);
+    }
+
+    // aten::randint_like.low_generator_dtype_out(Tensor self, SymInt low, SymInt high, *, Generator? generator, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & randint_like_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t low, int64_t high, ::std::optional<at::Generator> generator, ::std::optional<at::MemoryFormat> memory_format, at::Tensor & out) {
+        return at::_ops::randint_like_low_generator_dtype_out::redispatch(dispatchKeySet, self, low, high, generator, memory_format, out);
+    }
+
+    // aten::randint_like.low_generator_dtype_out(Tensor self, SymInt low, SymInt high, *, Generator? generator, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & randint_like_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymInt low, c10::SymInt high, ::std::optional<at::Generator> generator, ::std::optional<at::MemoryFormat> memory_format=::std::nullopt) {
+        return at::_ops::randint_like_low_generator_dtype_out::redispatch(dispatchKeySet, self, low, high, generator, memory_format, out);
+    }
+
+    // aten::randint_like.low_generator_dtype_out(Tensor self, SymInt low, SymInt high, *, Generator? generator, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & randint_like_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymInt low, c10::SymInt high, ::std::optional<at::Generator> generator, ::std::optional<at::MemoryFormat> memory_format, at::Tensor & out) {
+        return at::_ops::randint_like_low_generator_dtype_out::redispatch(dispatchKeySet, self, low, high, generator, memory_format, out);
+    }
+
+    // aten::randn.names_out(SymInt[] size, *, Dimname[]? names, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & randn_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::IntArrayRef size, ::std::optional<at::DimnameList> names) {
+        return at::_ops::randn_names_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), names, out);
+    }
+
+    // aten::randn.names_out(SymInt[] size, *, Dimname[]? names, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & randn_outf(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, ::std::optional<at::DimnameList> names, at::Tensor & out) {
+        return at::_ops::randn_names_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), names, out);
+    }
+
+    // aten::randn.names_out(SymInt[] size, *, Dimname[]? names, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & randn_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, c10::SymIntArrayRef size, ::std::optional<at::DimnameList> names) {
+        return at::_ops::randn_names_out::redispatch(dispatchKeySet, size, names, out);
+    }
+
+    // aten::randn.names_out(SymInt[] size, *, Dimname[]? names, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & randn_symint_outf(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, ::std::optional<at::DimnameList> names, at::Tensor & out) {
+        return at::_ops::randn_names_out::redispatch(dispatchKeySet, size, names, out);
+    }
+
+    // aten::randn.generator_with_names_out(SymInt[] size, *, Generator? generator, Dimname[]? names, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & randn_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::IntArrayRef size, ::std::optional<at::Generator> generator, ::std::optional<at::DimnameList> names) {
+        return at::_ops::randn_generator_with_names_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), generator, names, out);
+    }
+
+    // aten::randn.generator_with_names_out(SymInt[] size, *, Generator? generator, Dimname[]? names, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & randn_outf(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, ::std::optional<at::Generator> generator, ::std::optional<at::DimnameList> names, at::Tensor & out) {
+        return at::_ops::randn_generator_with_names_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), generator, names, out);
+    }
+
+    // aten::randn.generator_with_names_out(SymInt[] size, *, Generator? generator, Dimname[]? names, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & randn_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, c10::SymIntArrayRef size, ::std::optional<at::Generator> generator, ::std::optional<at::DimnameList> names) {
+        return at::_ops::randn_generator_with_names_out::redispatch(dispatchKeySet, size, generator, names, out);
+    }
+
+    // aten::randn.generator_with_names_out(SymInt[] size, *, Generator? generator, Dimname[]? names, Tensor(a!) out) -> Tensor(a!)
+    // NOTE: each wrapper below forwards to the matching at::_ops::<op>::redispatch
+    // with the caller's dispatchKeySet and returns its result. `*_out` overloads
+    // take `out` first, `*_outf` overloads take it last; both pass it as the
+    // final redispatch argument.
+    inline at::Tensor & randn_symint_outf(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, ::std::optional<at::Generator> generator, ::std::optional<at::DimnameList> names, at::Tensor & out) {
+        return at::_ops::randn_generator_with_names_out::redispatch(dispatchKeySet, size, generator, names, out);
+    }
+
+    // aten::randn_like.out(Tensor self, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & randn_like_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, ::std::optional<at::MemoryFormat> memory_format=::std::nullopt) {
+        return at::_ops::randn_like_out::redispatch(dispatchKeySet, self, memory_format, out);
+    }
+
+    // aten::randn_like.out(Tensor self, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & randn_like_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional<at::MemoryFormat> memory_format, at::Tensor & out) {
+        return at::_ops::randn_like_out::redispatch(dispatchKeySet, self, memory_format, out);
+    }
+
+    // aten::randn_like.generator_out(Tensor self, *, Generator? generator, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & randn_like_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, ::std::optional<at::Generator> generator, ::std::optional<at::MemoryFormat> memory_format=::std::nullopt) {
+        return at::_ops::randn_like_generator_out::redispatch(dispatchKeySet, self, generator, memory_format, out);
+    }
+
+    // aten::randn_like.generator_out(Tensor self, *, Generator? generator, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & randn_like_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional<at::Generator> generator, ::std::optional<at::MemoryFormat> memory_format, at::Tensor & out) {
+        return at::_ops::randn_like_generator_out::redispatch(dispatchKeySet, self, generator, memory_format, out);
+    }
+
+    // aten::repeat.out(Tensor self, SymInt[] repeats, *, Tensor(a!) out) -> Tensor(a!)
+    // NOTE: each wrapper below forwards to the matching at::_ops::<op>::redispatch
+    // with the caller's dispatchKeySet and returns its result. `*_out` overloads
+    // take `out` first, `*_outf` overloads take it last; both pass it as the
+    // final redispatch argument. Non-symint overloads convert IntArrayRef
+    // arguments with c10::fromIntArrayRefSlow and re-wrap optional<int64_t>
+    // values as optional<c10::SymInt>; `*_symint_*` overloads forward as-is.
+    inline at::Tensor & repeat_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef repeats) {
+        return at::_ops::repeat_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(repeats), out);
+    }
+
+    // aten::repeat.out(Tensor self, SymInt[] repeats, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & repeat_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef repeats, at::Tensor & out) {
+        return at::_ops::repeat_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(repeats), out);
+    }
+
+    // aten::repeat.out(Tensor self, SymInt[] repeats, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & repeat_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef repeats) {
+        return at::_ops::repeat_out::redispatch(dispatchKeySet, self, repeats, out);
+    }
+
+    // aten::repeat.out(Tensor self, SymInt[] repeats, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & repeat_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef repeats, at::Tensor & out) {
+        return at::_ops::repeat_out::redispatch(dispatchKeySet, self, repeats, out);
+    }
+
+    // aten::repeat_interleave.Tensor_out(Tensor repeats, *, SymInt? output_size=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & repeat_interleave_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & repeats, ::std::optional<int64_t> output_size=::std::nullopt) {
+        return at::_ops::repeat_interleave_Tensor_out::redispatch(dispatchKeySet, repeats, output_size.has_value() ? ::std::make_optional(c10::SymInt(*output_size)) : ::std::nullopt, out);
+    }
+
+    // aten::repeat_interleave.Tensor_out(Tensor repeats, *, SymInt? output_size=None, Tensor(a!) out) -> Tensor(a!)
+    // NOTE: each wrapper below forwards to the matching at::_ops::<op>::redispatch
+    // with the caller's dispatchKeySet and returns its result. `*_out` overloads
+    // take `out` first, `*_outf` overloads take it last; both pass it as the
+    // final redispatch argument. The non-symint repeat_interleave overload
+    // re-wraps its optional<int64_t> as optional<c10::SymInt>; the `*_symint_*`
+    // overloads forward their optional<c10::SymInt> as-is.
+    inline at::Tensor & repeat_interleave_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & repeats, ::std::optional<int64_t> output_size, at::Tensor & out) {
+        return at::_ops::repeat_interleave_Tensor_out::redispatch(dispatchKeySet, repeats, output_size.has_value() ? ::std::make_optional(c10::SymInt(*output_size)) : ::std::nullopt, out);
+    }
+
+    // aten::repeat_interleave.Tensor_out(Tensor repeats, *, SymInt? output_size=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & repeat_interleave_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & repeats, ::std::optional<c10::SymInt> output_size=::std::nullopt) {
+        return at::_ops::repeat_interleave_Tensor_out::redispatch(dispatchKeySet, repeats, output_size, out);
+    }
+
+    // aten::repeat_interleave.Tensor_out(Tensor repeats, *, SymInt? output_size=None, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & repeat_interleave_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & repeats, ::std::optional<c10::SymInt> output_size, at::Tensor & out) {
+        return at::_ops::repeat_interleave_Tensor_out::redispatch(dispatchKeySet, repeats, output_size, out);
+    }
+
+    // aten::_mkldnn_reshape.out(Tensor self, int[] shape, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _mkldnn_reshape_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef shape) {
+        return at::_ops::_mkldnn_reshape_out::redispatch(dispatchKeySet, self, shape, out);
+    }
+
+    // aten::_mkldnn_reshape.out(Tensor self, int[] shape, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & _mkldnn_reshape_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef shape, at::Tensor & out) {
+        return at::_ops::_mkldnn_reshape_out::redispatch(dispatchKeySet, self, shape, out);
+    }
+
+    // aten::relu.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    // NOTE: each wrapper below forwards to the matching at::_ops::<op>::redispatch
+    // with the caller's dispatchKeySet and returns its result. `*_out` overloads
+    // take `out` first, `*_outf` overloads take it last; both pass it as the
+    // final redispatch argument. Non-symint overloads convert their IntArrayRef
+    // argument with c10::fromIntArrayRefSlow; `*_symint_*` overloads forward
+    // their SymInt / SymIntArrayRef arguments as-is.
+    inline at::Tensor & relu_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) {
+        return at::_ops::relu_out::redispatch(dispatchKeySet, self, out);
+    }
+
+    // aten::relu.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & relu_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) {
+        return at::_ops::relu_out::redispatch(dispatchKeySet, self, out);
+    }
+
+    // aten::select_backward.out(Tensor grad_output, SymInt[] input_sizes, int dim, SymInt index, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & select_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad_output, at::IntArrayRef input_sizes, int64_t dim, int64_t index) {
+        return at::_ops::select_backward_out::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(input_sizes), dim, index, out);
+    }
+
+    // aten::select_backward.out(Tensor grad_output, SymInt[] input_sizes, int dim, SymInt index, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & select_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef input_sizes, int64_t dim, int64_t index, at::Tensor & out) {
+        return at::_ops::select_backward_out::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(input_sizes), dim, index, out);
+    }
+
+    // aten::select_backward.out(Tensor grad_output, SymInt[] input_sizes, int dim, SymInt index, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & select_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad_output, c10::SymIntArrayRef input_sizes, int64_t dim, c10::SymInt index) {
+        return at::_ops::select_backward_out::redispatch(dispatchKeySet, grad_output, input_sizes, dim, index, out);
+    }
+
+    // aten::select_backward.out(Tensor grad_output, SymInt[] input_sizes, int dim, SymInt index, *, Tensor(a!) out) -> Tensor(a!)
+    // NOTE: each wrapper below forwards to the matching at::_ops::<op>::redispatch
+    // with the caller's dispatchKeySet and returns its result. `*_out` overloads
+    // take `out` first, `*_outf` overloads take it last; both pass it as the
+    // final redispatch argument. Non-symint overloads convert their IntArrayRef
+    // argument with c10::fromIntArrayRefSlow; `*_symint_*` overloads forward
+    // their SymInt / SymIntArrayRef arguments as-is.
+    inline at::Tensor & select_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef input_sizes, int64_t dim, c10::SymInt index, at::Tensor & out) {
+        return at::_ops::select_backward_out::redispatch(dispatchKeySet, grad_output, input_sizes, dim, index, out);
+    }
+
+    // aten::celu.out(Tensor self, Scalar alpha=1.0, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & celu_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & alpha=1.0) {
+        return at::_ops::celu_out::redispatch(dispatchKeySet, self, alpha, out);
+    }
+
+    // aten::celu.out(Tensor self, Scalar alpha=1.0, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & celu_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & alpha, at::Tensor & out) {
+        return at::_ops::celu_out::redispatch(dispatchKeySet, self, alpha, out);
+    }
+
+    // aten::slice_backward.out(Tensor grad_output, SymInt[] input_sizes, int dim, SymInt start, SymInt end, SymInt step, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & slice_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad_output, at::IntArrayRef input_sizes, int64_t dim, int64_t start, int64_t end, int64_t step) {
+        return at::_ops::slice_backward_out::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(input_sizes), dim, start, end, step, out);
+    }
+
+    // aten::slice_backward.out(Tensor grad_output, SymInt[] input_sizes, int dim, SymInt start, SymInt end, SymInt step, *, Tensor(a!) out) -> Tensor(a!)
+    // NOTE: each wrapper below forwards to the matching at::_ops::<op>::redispatch
+    // with the caller's dispatchKeySet and returns its result. `*_out` overloads
+    // take `out` first, `*_outf` overloads take it last; both pass it as the
+    // final redispatch argument. Non-symint overloads convert IntArrayRef
+    // arguments with c10::fromIntArrayRefSlow and re-wrap optional<int64_t>
+    // values as optional<c10::SymInt>; `*_symint_*` overloads forward as-is.
+    inline at::Tensor & slice_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef input_sizes, int64_t dim, int64_t start, int64_t end, int64_t step, at::Tensor & out) {
+        return at::_ops::slice_backward_out::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(input_sizes), dim, start, end, step, out);
+    }
+
+    // aten::slice_backward.out(Tensor grad_output, SymInt[] input_sizes, int dim, SymInt start, SymInt end, SymInt step, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & slice_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad_output, c10::SymIntArrayRef input_sizes, int64_t dim, c10::SymInt start, c10::SymInt end, c10::SymInt step) {
+        return at::_ops::slice_backward_out::redispatch(dispatchKeySet, grad_output, input_sizes, dim, start, end, step, out);
+    }
+
+    // aten::slice_backward.out(Tensor grad_output, SymInt[] input_sizes, int dim, SymInt start, SymInt end, SymInt step, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & slice_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef input_sizes, int64_t dim, c10::SymInt start, c10::SymInt end, c10::SymInt step, at::Tensor & out) {
+        return at::_ops::slice_backward_out::redispatch(dispatchKeySet, grad_output, input_sizes, dim, start, end, step, out);
+    }
+
+    // aten::slice_scatter.out(Tensor self, Tensor src, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & slice_scatter_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & src, int64_t dim=0, ::std::optional<int64_t> start=::std::nullopt, ::std::optional<int64_t> end=::std::nullopt, int64_t step=1) {
+        return at::_ops::slice_scatter_out::redispatch(dispatchKeySet, self, src, dim, start.has_value() ? ::std::make_optional(c10::SymInt(*start)) : ::std::nullopt, end.has_value() ? ::std::make_optional(c10::SymInt(*end)) : ::std::nullopt, step, out);
+    }
+
+    // aten::slice_scatter.out(Tensor self, Tensor src, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & slice_scatter_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & src, int64_t dim, ::std::optional<int64_t> start, ::std::optional<int64_t> end, int64_t step, at::Tensor & out) {
+        return at::_ops::slice_scatter_out::redispatch(dispatchKeySet, self, src, dim, start.has_value() ? ::std::make_optional(c10::SymInt(*start)) : ::std::nullopt, end.has_value() ? ::std::make_optional(c10::SymInt(*end)) : ::std::nullopt, step, out);
+    }
+
+    // aten::slice_scatter.out(Tensor self, Tensor src, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & slice_scatter_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & src, int64_t dim=0, ::std::optional<c10::SymInt> start=::std::nullopt, ::std::optional<c10::SymInt> end=::std::nullopt, c10::SymInt step=1) {
+        return at::_ops::slice_scatter_out::redispatch(dispatchKeySet, self, src, dim, start, end, step, out);
+    }
+
+    // aten::slice_scatter.out(Tensor self, Tensor src, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1, *, Tensor(a!) out) -> Tensor(a!)
+    inline at::Tensor & slice_scatter_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & src, int64_t dim, ::std::optional<c10::SymInt> start, ::std::optional<c10::SymInt> end, c10::SymInt step, at::Tensor & out) {
+        return at::_ops::slice_scatter_out::redispatch(dispatchKeySet, self, src, dim, start, end, step, out);
+    }
+
+    // aten::select_scatter.out(Tensor self, Tensor src, int dim, SymInt index, *, Tensor(a!) out) -> Tensor(a!)
+ inline at::Tensor & select_scatter_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & src, int64_t dim, int64_t index) { + return at::_ops::select_scatter_out::redispatch(dispatchKeySet, self, src, dim, index, out); + } + + // aten::select_scatter.out(Tensor self, Tensor src, int dim, SymInt index, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & select_scatter_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & src, int64_t dim, int64_t index, at::Tensor & out) { + return at::_ops::select_scatter_out::redispatch(dispatchKeySet, self, src, dim, index, out); + } + + // aten::select_scatter.out(Tensor self, Tensor src, int dim, SymInt index, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & select_scatter_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & src, int64_t dim, c10::SymInt index) { + return at::_ops::select_scatter_out::redispatch(dispatchKeySet, self, src, dim, index, out); + } + + // aten::select_scatter.out(Tensor self, Tensor src, int dim, SymInt index, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & select_scatter_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & src, int64_t dim, c10::SymInt index, at::Tensor & out) { + return at::_ops::select_scatter_out::redispatch(dispatchKeySet, self, src, dim, index, out); + } + + // aten::diagonal_scatter.out(Tensor self, Tensor src, int offset=0, int dim1=0, int dim2=1, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & diagonal_scatter_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & src, int64_t offset=0, int64_t dim1=0, int64_t dim2=1) { + return at::_ops::diagonal_scatter_out::redispatch(dispatchKeySet, self, src, offset, dim1, dim2, out); + } + + // aten::diagonal_scatter.out(Tensor self, Tensor src, int offset=0, int dim1=0, int dim2=1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & diagonal_scatter_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & src, int64_t offset, int64_t dim1, int64_t dim2, at::Tensor & out) { + return at::_ops::diagonal_scatter_out::redispatch(dispatchKeySet, self, src, offset, dim1, dim2, out); + } + + // aten::as_strided_scatter.out(Tensor self, Tensor src, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & as_strided_scatter_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & src, at::IntArrayRef size, at::IntArrayRef stride, ::std::optional storage_offset=::std::nullopt) { + return at::_ops::as_strided_scatter_out::redispatch(dispatchKeySet, self, src, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), storage_offset.has_value() ? ::std::make_optional(c10::SymInt(*storage_offset)) : ::std::nullopt, out); + } + + // aten::as_strided_scatter.out(Tensor self, Tensor src, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & as_strided_scatter_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & src, at::IntArrayRef size, at::IntArrayRef stride, ::std::optional storage_offset, at::Tensor & out) { + return at::_ops::as_strided_scatter_out::redispatch(dispatchKeySet, self, src, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), storage_offset.has_value() ? 
::std::make_optional(c10::SymInt(*storage_offset)) : ::std::nullopt, out); + } + + // aten::as_strided_scatter.out(Tensor self, Tensor src, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & as_strided_scatter_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & src, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset=::std::nullopt) { + return at::_ops::as_strided_scatter_out::redispatch(dispatchKeySet, self, src, size, stride, storage_offset, out); + } + + // aten::as_strided_scatter.out(Tensor self, Tensor src, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & as_strided_scatter_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & src, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset, at::Tensor & out) { + return at::_ops::as_strided_scatter_out::redispatch(dispatchKeySet, self, src, size, stride, storage_offset, out); + } + + // aten::unsafe_split.Tensor_out(Tensor self, SymInt split_size, int dim=0, *, Tensor(a!)[] out) -> () + inline void unsafe_split_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, const at::Tensor & self, int64_t split_size, int64_t dim=0) { + return at::_ops::unsafe_split_Tensor_out::redispatch(dispatchKeySet, self, split_size, dim, out); + } + + // aten::unsafe_split.Tensor_out(Tensor self, SymInt split_size, int dim=0, *, Tensor(a!)[] out) -> () + inline void unsafe_split_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t split_size, int64_t dim, at::TensorList out) { + return at::_ops::unsafe_split_Tensor_out::redispatch(dispatchKeySet, self, split_size, dim, out); + } + + // aten::unsafe_split.Tensor_out(Tensor self, SymInt split_size, int dim=0, *, Tensor(a!)[] out) -> () + inline void 
unsafe_split_symint_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, const at::Tensor & self, c10::SymInt split_size, int64_t dim=0) { + return at::_ops::unsafe_split_Tensor_out::redispatch(dispatchKeySet, self, split_size, dim, out); + } + + // aten::unsafe_split.Tensor_out(Tensor self, SymInt split_size, int dim=0, *, Tensor(a!)[] out) -> () + inline void unsafe_split_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymInt split_size, int64_t dim, at::TensorList out) { + return at::_ops::unsafe_split_Tensor_out::redispatch(dispatchKeySet, self, split_size, dim, out); + } + + // aten::unsafe_split_with_sizes.out(Tensor self, SymInt[] split_sizes, int dim=0, *, Tensor(a!)[] out) -> () + inline void unsafe_split_with_sizes_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, const at::Tensor & self, at::IntArrayRef split_sizes, int64_t dim=0) { + return at::_ops::unsafe_split_with_sizes_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(split_sizes), dim, out); + } + + // aten::unsafe_split_with_sizes.out(Tensor self, SymInt[] split_sizes, int dim=0, *, Tensor(a!)[] out) -> () + inline void unsafe_split_with_sizes_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef split_sizes, int64_t dim, at::TensorList out) { + return at::_ops::unsafe_split_with_sizes_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(split_sizes), dim, out); + } + + // aten::unsafe_split_with_sizes.out(Tensor self, SymInt[] split_sizes, int dim=0, *, Tensor(a!)[] out) -> () + inline void unsafe_split_with_sizes_symint_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, const at::Tensor & self, c10::SymIntArrayRef split_sizes, int64_t dim=0) { + return at::_ops::unsafe_split_with_sizes_out::redispatch(dispatchKeySet, self, split_sizes, dim, out); + } + + // aten::unsafe_split_with_sizes.out(Tensor self, SymInt[] split_sizes, int dim=0, *, Tensor(a!)[] out) -> () + inline void 
unsafe_split_with_sizes_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef split_sizes, int64_t dim, at::TensorList out) { + return at::_ops::unsafe_split_with_sizes_out::redispatch(dispatchKeySet, self, split_sizes, dim, out); + } + + // aten::sum.out(Tensor self, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & sum_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, ::std::optional dtype=::std::nullopt) { + return at::_ops::sum_out::redispatch(dispatchKeySet, self, dtype, out); + } + + // aten::sum.out(Tensor self, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & sum_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional dtype, at::Tensor & out) { + return at::_ops::sum_out::redispatch(dispatchKeySet, self, dtype, out); + } + + // aten::std_mean.correction_out(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple std_mean_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & self, at::OptionalIntArrayRef dim=::std::nullopt, const ::std::optional & correction=::std::nullopt, bool keepdim=false) { + return at::_ops::std_mean_correction_out::redispatch(dispatchKeySet, self, dim, correction, keepdim, out0, out1); + } + + // aten::std_mean.correction_out(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out0, Tensor(b!) 
out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple std_mean_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef dim, const ::std::optional & correction, bool keepdim, at::Tensor & out0, at::Tensor & out1) { + return at::_ops::std_mean_correction_out::redispatch(dispatchKeySet, self, dim, correction, keepdim, out0, out1); + } + + // aten::prod.out(Tensor self, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & prod_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, ::std::optional dtype=::std::nullopt) { + return at::_ops::prod_out::redispatch(dispatchKeySet, self, dtype, out); + } + + // aten::prod.out(Tensor self, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & prod_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional dtype, at::Tensor & out) { + return at::_ops::prod_out::redispatch(dispatchKeySet, self, dtype, out); + } + + // aten::_mkldnn_transpose.out(Tensor self, int dim0, int dim1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _mkldnn_transpose_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim0, int64_t dim1) { + return at::_ops::_mkldnn_transpose_out::redispatch(dispatchKeySet, self, dim0, dim1, out); + } + + // aten::_mkldnn_transpose.out(Tensor self, int dim0, int dim1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _mkldnn_transpose_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim0, int64_t dim1, at::Tensor & out) { + return at::_ops::_mkldnn_transpose_out::redispatch(dispatchKeySet, self, dim0, dim1, out); + } + + // aten::flip.out(Tensor self, int[] dims, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & flip_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef dims) { + return at::_ops::flip_out::redispatch(dispatchKeySet, self, dims, out); + } + + // aten::flip.out(Tensor self, int[] dims, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & flip_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dims, at::Tensor & out) { + return at::_ops::flip_out::redispatch(dispatchKeySet, self, dims, out); + } + + // aten::roll.out(Tensor self, SymInt[1] shifts, int[1] dims=[], *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & roll_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef shifts, at::IntArrayRef dims={}) { + return at::_ops::roll_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(shifts), dims, out); + } + + // aten::roll.out(Tensor self, SymInt[1] shifts, int[1] dims=[], *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & roll_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef shifts, at::IntArrayRef dims, at::Tensor & out) { + return at::_ops::roll_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(shifts), dims, out); + } + + // aten::roll.out(Tensor self, SymInt[1] shifts, int[1] dims=[], *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & roll_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef shifts, at::IntArrayRef dims={}) { + return at::_ops::roll_out::redispatch(dispatchKeySet, self, shifts, dims, out); + } + + // aten::roll.out(Tensor self, SymInt[1] shifts, int[1] dims=[], *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & roll_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef shifts, at::IntArrayRef dims, at::Tensor & out) { + return at::_ops::roll_out::redispatch(dispatchKeySet, self, shifts, dims, out); + } + + // aten::rot90.out(Tensor self, int k=1, int[] dims=[0,1], *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & rot90_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t k=1, at::IntArrayRef dims={0,1}) { + return at::_ops::rot90_out::redispatch(dispatchKeySet, self, k, dims, out); + } + + // aten::rot90.out(Tensor self, int k=1, int[] dims=[0,1], *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & rot90_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t k, at::IntArrayRef dims, at::Tensor & out) { + return at::_ops::rot90_out::redispatch(dispatchKeySet, self, k, dims, out); + } + + // aten::_transform_bias_rescale_qkv.out(Tensor qkv, Tensor qkv_bias, int num_heads, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple _transform_bias_rescale_qkv_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & qkv, const at::Tensor & qkv_bias, int64_t num_heads) { + return at::_ops::_transform_bias_rescale_qkv_out::redispatch(dispatchKeySet, qkv, qkv_bias, num_heads, out0, out1, out2); + } + + // aten::_transform_bias_rescale_qkv.out(Tensor qkv, Tensor qkv_bias, int num_heads, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) 
out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple _transform_bias_rescale_qkv_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & qkv, const at::Tensor & qkv_bias, int64_t num_heads, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) { + return at::_ops::_transform_bias_rescale_qkv_out::redispatch(dispatchKeySet, qkv, qkv_bias, num_heads, out0, out1, out2); + } + + // aten::_nested_tensor_from_mask.out(Tensor t, Tensor mask, bool mask_check=True, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _nested_tensor_from_mask_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & t, const at::Tensor & mask, bool mask_check=true) { + return at::_ops::_nested_tensor_from_mask_out::redispatch(dispatchKeySet, t, mask, mask_check, out); + } + + // aten::_nested_tensor_from_mask.out(Tensor t, Tensor mask, bool mask_check=True, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _nested_tensor_from_mask_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & t, const at::Tensor & mask, bool mask_check, at::Tensor & out) { + return at::_ops::_nested_tensor_from_mask_out::redispatch(dispatchKeySet, t, mask, mask_check, out); + } + + // aten::_nested_from_padded.out(Tensor padded, Tensor cpu_nested_shape_example, bool fuse_transform_0213=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _nested_from_padded_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & padded, const at::Tensor & cpu_nested_shape_example, bool fuse_transform_0213=false) { + return at::_ops::_nested_from_padded_out::redispatch(dispatchKeySet, padded, cpu_nested_shape_example, fuse_transform_0213, out); + } + + // aten::_nested_from_padded.out(Tensor padded, Tensor cpu_nested_shape_example, bool fuse_transform_0213=False, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & _nested_from_padded_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & padded, const at::Tensor & cpu_nested_shape_example, bool fuse_transform_0213, at::Tensor & out) { + return at::_ops::_nested_from_padded_out::redispatch(dispatchKeySet, padded, cpu_nested_shape_example, fuse_transform_0213, out); + } + + // aten::_nested_tensor_size.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _nested_tensor_size_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::_nested_tensor_size_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_nested_tensor_size.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _nested_tensor_size_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::_nested_tensor_size_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_nested_tensor_strides.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _nested_tensor_strides_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::_nested_tensor_strides_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_nested_tensor_strides.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _nested_tensor_strides_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::_nested_tensor_strides_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_nested_tensor_storage_offsets.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _nested_tensor_storage_offsets_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::_nested_tensor_storage_offsets_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_nested_tensor_storage_offsets.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & _nested_tensor_storage_offsets_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::_nested_tensor_storage_offsets_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_nested_from_padded_and_nested_example.out(Tensor padded, Tensor nt_example, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _nested_from_padded_and_nested_example_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & padded, const at::Tensor & nt_example) { + return at::_ops::_nested_from_padded_and_nested_example_out::redispatch(dispatchKeySet, padded, nt_example, out); + } + + // aten::_nested_from_padded_and_nested_example.out(Tensor padded, Tensor nt_example, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _nested_from_padded_and_nested_example_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & padded, const at::Tensor & nt_example, at::Tensor & out) { + return at::_ops::_nested_from_padded_and_nested_example_out::redispatch(dispatchKeySet, padded, nt_example, out); + } + + // aten::_nested_view_from_buffer_copy.out(Tensor self, Tensor nested_size, Tensor nested_strides, Tensor offsets, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _nested_view_from_buffer_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & nested_size, const at::Tensor & nested_strides, const at::Tensor & offsets) { + return at::_ops::_nested_view_from_buffer_copy_out::redispatch(dispatchKeySet, self, nested_size, nested_strides, offsets, out); + } + + // aten::_nested_view_from_buffer_copy.out(Tensor self, Tensor nested_size, Tensor nested_strides, Tensor offsets, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & _nested_view_from_buffer_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & nested_size, const at::Tensor & nested_strides, const at::Tensor & offsets, at::Tensor & out) { + return at::_ops::_nested_view_from_buffer_copy_out::redispatch(dispatchKeySet, self, nested_size, nested_strides, offsets, out); + } + + // aten::_nested_view_from_jagged_copy.out(Tensor self, Tensor offsets, Tensor dummy, Tensor? lengths=None, int ragged_idx=1, Tensor? min_seqlen=None, Tensor? max_seqlen=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _nested_view_from_jagged_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & offsets, const at::Tensor & dummy, const ::std::optional & lengths={}, int64_t ragged_idx=1, const ::std::optional & min_seqlen={}, const ::std::optional & max_seqlen={}) { + return at::_ops::_nested_view_from_jagged_copy_out::redispatch(dispatchKeySet, self, offsets, dummy, lengths, ragged_idx, min_seqlen, max_seqlen, out); + } + + // aten::_nested_view_from_jagged_copy.out(Tensor self, Tensor offsets, Tensor dummy, Tensor? lengths=None, int ragged_idx=1, Tensor? min_seqlen=None, Tensor? max_seqlen=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _nested_view_from_jagged_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & offsets, const at::Tensor & dummy, const ::std::optional & lengths, int64_t ragged_idx, const ::std::optional & min_seqlen, const ::std::optional & max_seqlen, at::Tensor & out) { + return at::_ops::_nested_view_from_jagged_copy_out::redispatch(dispatchKeySet, self, offsets, dummy, lengths, ragged_idx, min_seqlen, max_seqlen, out); + } + + // aten::_nested_get_values_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & _nested_get_values_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::_nested_get_values_copy_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_nested_get_values_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _nested_get_values_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::_nested_get_values_copy_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_trilinear.out(Tensor i1, Tensor i2, Tensor i3, int[] expand1, int[] expand2, int[] expand3, int[] sumdim, int unroll_dim=1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _trilinear_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & i1, const at::Tensor & i2, const at::Tensor & i3, at::IntArrayRef expand1, at::IntArrayRef expand2, at::IntArrayRef expand3, at::IntArrayRef sumdim, int64_t unroll_dim=1) { + return at::_ops::_trilinear_out::redispatch(dispatchKeySet, i1, i2, i3, expand1, expand2, expand3, sumdim, unroll_dim, out); + } + + // aten::_trilinear.out(Tensor i1, Tensor i2, Tensor i3, int[] expand1, int[] expand2, int[] expand3, int[] sumdim, int unroll_dim=1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _trilinear_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & i1, const at::Tensor & i2, const at::Tensor & i3, at::IntArrayRef expand1, at::IntArrayRef expand2, at::IntArrayRef expand3, at::IntArrayRef sumdim, int64_t unroll_dim, at::Tensor & out) { + return at::_ops::_trilinear_out::redispatch(dispatchKeySet, i1, i2, i3, expand1, expand2, expand3, sumdim, unroll_dim, out); + } + + // aten::_unique.out(Tensor self, bool sorted=True, bool return_inverse=False, *, Tensor(a!) out0, Tensor(b!) 
out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple _unique_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & self, bool sorted=true, bool return_inverse=false) { + return at::_ops::_unique_out::redispatch(dispatchKeySet, self, sorted, return_inverse, out0, out1); + } + + // aten::_unique.out(Tensor self, bool sorted=True, bool return_inverse=False, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple _unique_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool sorted, bool return_inverse, at::Tensor & out0, at::Tensor & out1) { + return at::_ops::_unique_out::redispatch(dispatchKeySet, self, sorted, return_inverse, out0, out1); + } + + // aten::unique_dim.out(Tensor self, int dim, bool sorted=True, bool return_inverse=False, bool return_counts=False, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple unique_dim_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & self, int64_t dim, bool sorted=true, bool return_inverse=false, bool return_counts=false) { + return at::_ops::unique_dim_out::redispatch(dispatchKeySet, self, dim, sorted, return_inverse, return_counts, out0, out1, out2); + } + + // aten::unique_dim.out(Tensor self, int dim, bool sorted=True, bool return_inverse=False, bool return_counts=False, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) 
out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple unique_dim_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool sorted, bool return_inverse, bool return_counts, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) { + return at::_ops::unique_dim_out::redispatch(dispatchKeySet, self, dim, sorted, return_inverse, return_counts, out0, out1, out2); + } + + // aten::unique_consecutive.out(Tensor self, bool return_inverse=False, bool return_counts=False, int? dim=None, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple unique_consecutive_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & self, bool return_inverse=false, bool return_counts=false, ::std::optional dim=::std::nullopt) { + return at::_ops::unique_consecutive_out::redispatch(dispatchKeySet, self, return_inverse, return_counts, dim, out0, out1, out2); + } + + // aten::unique_consecutive.out(Tensor self, bool return_inverse=False, bool return_counts=False, int? dim=None, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple unique_consecutive_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool return_inverse, bool return_counts, ::std::optional dim, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) { + return at::_ops::unique_consecutive_out::redispatch(dispatchKeySet, self, return_inverse, return_counts, dim, out0, out1, out2); + } + + // aten::unique_dim_consecutive.out(Tensor self, int dim, bool return_inverse=False, bool return_counts=False, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) 
out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple unique_dim_consecutive_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & self, int64_t dim, bool return_inverse=false, bool return_counts=false) { + return at::_ops::unique_dim_consecutive_out::redispatch(dispatchKeySet, self, dim, return_inverse, return_counts, out0, out1, out2); + } + + // aten::unique_dim_consecutive.out(Tensor self, int dim, bool return_inverse=False, bool return_counts=False, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple unique_dim_consecutive_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool return_inverse, bool return_counts, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) { + return at::_ops::unique_dim_consecutive_out::redispatch(dispatchKeySet, self, dim, return_inverse, return_counts, out0, out1, out2); + } + + // aten::_unique2.out(Tensor self, bool sorted=True, bool return_inverse=False, bool return_counts=False, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple _unique2_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & self, bool sorted=true, bool return_inverse=false, bool return_counts=false) { + return at::_ops::_unique2_out::redispatch(dispatchKeySet, self, sorted, return_inverse, return_counts, out0, out1, out2); + } + + // aten::_unique2.out(Tensor self, bool sorted=True, bool return_inverse=False, bool return_counts=False, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) 
out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple _unique2_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool sorted, bool return_inverse, bool return_counts, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) { + return at::_ops::_unique2_out::redispatch(dispatchKeySet, self, sorted, return_inverse, return_counts, out0, out1, out2); + } + + // aten::_unsafe_view.out(Tensor self, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _unsafe_view_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef size) { + return at::_ops::_unsafe_view_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), out); + } + + // aten::_unsafe_view.out(Tensor self, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _unsafe_view_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, at::Tensor & out) { + return at::_ops::_unsafe_view_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), out); + } + + // aten::_unsafe_view.out(Tensor self, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _unsafe_view_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef size) { + return at::_ops::_unsafe_view_out::redispatch(dispatchKeySet, self, size, out); + } + + // aten::_unsafe_view.out(Tensor self, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _unsafe_view_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, at::Tensor & out) { + return at::_ops::_unsafe_view_out::redispatch(dispatchKeySet, self, size, out); + } + + // aten::var_mean.correction_out(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out0, Tensor(b!) 
out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple var_mean_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & self, at::OptionalIntArrayRef dim=::std::nullopt, const ::std::optional & correction=::std::nullopt, bool keepdim=false) { + return at::_ops::var_mean_correction_out::redispatch(dispatchKeySet, self, dim, correction, keepdim, out0, out1); + } + + // aten::var_mean.correction_out(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple var_mean_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef dim, const ::std::optional & correction, bool keepdim, at::Tensor & out0, at::Tensor & out1) { + return at::_ops::var_mean_correction_out::redispatch(dispatchKeySet, self, dim, correction, keepdim, out0, out1); + } + + // aten::_weight_norm_interface.out(Tensor v, Tensor g, int dim=0, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple _weight_norm_interface_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & v, const at::Tensor & g, int64_t dim=0) { + return at::_ops::_weight_norm_interface_out::redispatch(dispatchKeySet, v, g, dim, out0, out1); + } + + // aten::_weight_norm_interface.out(Tensor v, Tensor g, int dim=0, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple _weight_norm_interface_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & v, const at::Tensor & g, int64_t dim, at::Tensor & out0, at::Tensor & out1) { + return at::_ops::_weight_norm_interface_out::redispatch(dispatchKeySet, v, g, dim, out0, out1); + } + + // aten::_weight_norm_interface_backward.out(Tensor grad_w, Tensor saved_v, Tensor saved_g, Tensor saved_norms, int dim, *, Tensor(a!) out0, Tensor(b!) 
out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple _weight_norm_interface_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & grad_w, const at::Tensor & saved_v, const at::Tensor & saved_g, const at::Tensor & saved_norms, int64_t dim) { + return at::_ops::_weight_norm_interface_backward_out::redispatch(dispatchKeySet, grad_w, saved_v, saved_g, saved_norms, dim, out0, out1); + } + + // aten::_weight_norm_interface_backward.out(Tensor grad_w, Tensor saved_v, Tensor saved_g, Tensor saved_norms, int dim, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple _weight_norm_interface_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_w, const at::Tensor & saved_v, const at::Tensor & saved_g, const at::Tensor & saved_norms, int64_t dim, at::Tensor & out0, at::Tensor & out1) { + return at::_ops::_weight_norm_interface_backward_out::redispatch(dispatchKeySet, grad_w, saved_v, saved_g, saved_norms, dim, out0, out1); + } + + // aten::zeros.names_out(int[] size, *, Dimname[]? names, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & zeros_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::IntArrayRef size, ::std::optional names) { + return at::_ops::zeros_names_out::redispatch(dispatchKeySet, size, names, out); + } + + // aten::zeros.names_out(int[] size, *, Dimname[]? names, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & zeros_outf(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, ::std::optional names, at::Tensor & out) { + return at::_ops::zeros_names_out::redispatch(dispatchKeySet, size, names, out); + } + + // aten::_efficientzerotensor.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & _efficientzerotensor_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::IntArrayRef size) { + return at::_ops::_efficientzerotensor_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), out); + } + + // aten::_efficientzerotensor.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _efficientzerotensor_outf(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, at::Tensor & out) { + return at::_ops::_efficientzerotensor_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), out); + } + + // aten::_efficientzerotensor.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _efficientzerotensor_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, c10::SymIntArrayRef size) { + return at::_ops::_efficientzerotensor_out::redispatch(dispatchKeySet, size, out); + } + + // aten::_efficientzerotensor.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _efficientzerotensor_symint_outf(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, at::Tensor & out) { + return at::_ops::_efficientzerotensor_out::redispatch(dispatchKeySet, size, out); + } + + // aten::zeros_like.out(Tensor self, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & zeros_like_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, ::std::optional memory_format=::std::nullopt) { + return at::_ops::zeros_like_out::redispatch(dispatchKeySet, self, memory_format, out); + } + + // aten::zeros_like.out(Tensor self, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & zeros_like_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional memory_format, at::Tensor & out) { + return at::_ops::zeros_like_out::redispatch(dispatchKeySet, self, memory_format, out); + } + + // aten::_standard_gamma_grad.out(Tensor self, Tensor output, *, Tensor(a!) 
out) -> Tensor(a!) + inline at::Tensor & _standard_gamma_grad_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & output) { + return at::_ops::_standard_gamma_grad_out::redispatch(dispatchKeySet, self, output, out); + } + + // aten::_standard_gamma_grad.out(Tensor self, Tensor output, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _standard_gamma_grad_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & output, at::Tensor & out) { + return at::_ops::_standard_gamma_grad_out::redispatch(dispatchKeySet, self, output, out); + } + + // aten::_standard_gamma.out(Tensor self, Generator? generator=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _standard_gamma_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, ::std::optional generator=::std::nullopt) { + return at::_ops::_standard_gamma_out::redispatch(dispatchKeySet, self, generator, out); + } + + // aten::_standard_gamma.out(Tensor self, Generator? generator=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _standard_gamma_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional generator, at::Tensor & out) { + return at::_ops::_standard_gamma_out::redispatch(dispatchKeySet, self, generator, out); + } + + // aten::_dirichlet_grad.out(Tensor x, Tensor alpha, Tensor total, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _dirichlet_grad_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x, const at::Tensor & alpha, const at::Tensor & total) { + return at::_ops::_dirichlet_grad_out::redispatch(dispatchKeySet, x, alpha, total, out); + } + + // aten::_dirichlet_grad.out(Tensor x, Tensor alpha, Tensor total, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & _dirichlet_grad_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Tensor & alpha, const at::Tensor & total, at::Tensor & out) { + return at::_ops::_dirichlet_grad_out::redispatch(dispatchKeySet, x, alpha, total, out); + } + + // aten::_sample_dirichlet.out(Tensor self, Generator? generator=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _sample_dirichlet_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, ::std::optional generator=::std::nullopt) { + return at::_ops::_sample_dirichlet_out::redispatch(dispatchKeySet, self, generator, out); + } + + // aten::_sample_dirichlet.out(Tensor self, Generator? generator=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _sample_dirichlet_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional generator, at::Tensor & out) { + return at::_ops::_sample_dirichlet_out::redispatch(dispatchKeySet, self, generator, out); + } + + // aten::poisson.out(Tensor self, Generator? generator=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & poisson_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, ::std::optional generator=::std::nullopt) { + return at::_ops::poisson_out::redispatch(dispatchKeySet, self, generator, out); + } + + // aten::poisson.out(Tensor self, Generator? generator=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & poisson_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional generator, at::Tensor & out) { + return at::_ops::poisson_out::redispatch(dispatchKeySet, self, generator, out); + } + + // aten::binomial.out(Tensor count, Tensor prob, Generator? generator=None, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & binomial_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & count, const at::Tensor & prob, ::std::optional generator=::std::nullopt) { + return at::_ops::binomial_out::redispatch(dispatchKeySet, count, prob, generator, out); + } + + // aten::binomial.out(Tensor count, Tensor prob, Generator? generator=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & binomial_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & count, const at::Tensor & prob, ::std::optional generator, at::Tensor & out) { + return at::_ops::binomial_out::redispatch(dispatchKeySet, count, prob, generator, out); + } + + // aten::native_norm.out(Tensor self, Scalar p=2, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & native_norm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & p=2) { + return at::_ops::native_norm_out::redispatch(dispatchKeySet, self, p, out); + } + + // aten::native_norm.out(Tensor self, Scalar p=2, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & native_norm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & p, at::Tensor & out) { + return at::_ops::native_norm_out::redispatch(dispatchKeySet, self, p, out); + } + + // aten::native_norm.ScalarOpt_dim_dtype_out(Tensor self, Scalar? p, int[1] dim, bool keepdim, ScalarType? dtype, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & native_norm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const ::std::optional & p, at::IntArrayRef dim, bool keepdim, ::std::optional dtype) { + return at::_ops::native_norm_ScalarOpt_dim_dtype_out::redispatch(dispatchKeySet, self, p, dim, keepdim, dtype, out); + } + + // aten::native_norm.ScalarOpt_dim_dtype_out(Tensor self, Scalar? p, int[1] dim, bool keepdim, ScalarType? dtype, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & native_norm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const ::std::optional & p, at::IntArrayRef dim, bool keepdim, ::std::optional dtype, at::Tensor & out) { + return at::_ops::native_norm_ScalarOpt_dim_dtype_out::redispatch(dispatchKeySet, self, p, dim, keepdim, dtype, out); + } + + // aten::_batch_norm_with_update_functional(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor, Tensor, Tensor running_mean_out, Tensor running_var_out) + inline ::std::tuple _batch_norm_with_update_functional(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const at::Tensor & running_mean, const at::Tensor & running_var, double momentum, double eps) { + return at::_ops::_batch_norm_with_update_functional::redispatch(dispatchKeySet, input, weight, bias, running_mean, running_var, momentum, eps); + } + + // aten::_batch_norm_no_update.out(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, float momentum, float eps, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!)) + inline ::std::tuple _batch_norm_no_update_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3, const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const ::std::optional & running_mean, const ::std::optional & running_var, double momentum, double eps) { + return at::_ops::_batch_norm_no_update_out::redispatch(dispatchKeySet, input, weight, bias, running_mean, running_var, momentum, eps, out0, out1, out2, out3); + } + + // aten::_batch_norm_no_update.out(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, float momentum, float eps, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) 
out2, Tensor(d!) out3) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!)) + inline ::std::tuple _batch_norm_no_update_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const ::std::optional & running_mean, const ::std::optional & running_var, double momentum, double eps, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3) { + return at::_ops::_batch_norm_no_update_out::redispatch(dispatchKeySet, input, weight, bias, running_mean, running_var, momentum, eps, out0, out1, out2, out3); + } + + // aten::_sparse_sum.dim_out(Tensor self, int[1] dim, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _sparse_sum_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef dim) { + return at::_ops::_sparse_sum_dim_out::redispatch(dispatchKeySet, self, dim, out); + } + + // aten::_sparse_sum.dim_out(Tensor self, int[1] dim, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _sparse_sum_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim, at::Tensor & out) { + return at::_ops::_sparse_sum_dim_out::redispatch(dispatchKeySet, self, dim, out); + } + + // aten::_sparse_sum_backward.out(Tensor grad, Tensor self, int[] dim, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _sparse_sum_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad, const at::Tensor & self, at::IntArrayRef dim) { + return at::_ops::_sparse_sum_backward_out::redispatch(dispatchKeySet, grad, self, dim, out); + } + + // aten::_sparse_sum_backward.out(Tensor grad, Tensor self, int[] dim, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & _sparse_sum_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & self, at::IntArrayRef dim, at::Tensor & out) { + return at::_ops::_sparse_sum_backward_out::redispatch(dispatchKeySet, grad, self, dim, out); + } + + // aten::_sparse_csr_sum.dim_dtype_out(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _sparse_csr_sum_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef dim, bool keepdim=false, ::std::optional dtype=::std::nullopt) { + return at::_ops::_sparse_csr_sum_dim_dtype_out::redispatch(dispatchKeySet, self, dim, keepdim, dtype, out); + } + + // aten::_sparse_csr_sum.dim_dtype_out(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _sparse_csr_sum_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim, bool keepdim, ::std::optional dtype, at::Tensor & out) { + return at::_ops::_sparse_csr_sum_dim_dtype_out::redispatch(dispatchKeySet, self, dim, keepdim, dtype, out); + } + + // aten::_sparse_csr_prod.dim_dtype_out(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _sparse_csr_prod_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef dim, bool keepdim=false, ::std::optional dtype=::std::nullopt) { + return at::_ops::_sparse_csr_prod_dim_dtype_out::redispatch(dispatchKeySet, self, dim, keepdim, dtype, out); + } + + // aten::_sparse_csr_prod.dim_dtype_out(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & _sparse_csr_prod_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim, bool keepdim, ::std::optional dtype, at::Tensor & out) { + return at::_ops::_sparse_csr_prod_dim_dtype_out::redispatch(dispatchKeySet, self, dim, keepdim, dtype, out); + } + + // aten::_sparse_softmax.out(Tensor self, int dim, bool half_to_float, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _sparse_softmax_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim, bool half_to_float) { + return at::_ops::_sparse_softmax_out::redispatch(dispatchKeySet, self, dim, half_to_float, out); + } + + // aten::_sparse_softmax.out(Tensor self, int dim, bool half_to_float, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _sparse_softmax_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool half_to_float, at::Tensor & out) { + return at::_ops::_sparse_softmax_out::redispatch(dispatchKeySet, self, dim, half_to_float, out); + } + + // aten::_sparse_softmax_backward_data.out(Tensor grad_output, Tensor output, int dim, Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _sparse_softmax_backward_data_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad_output, const at::Tensor & output, int64_t dim, const at::Tensor & self) { + return at::_ops::_sparse_softmax_backward_data_out::redispatch(dispatchKeySet, grad_output, output, dim, self, out); + } + + // aten::_sparse_softmax_backward_data.out(Tensor grad_output, Tensor output, int dim, Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & _sparse_softmax_backward_data_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & output, int64_t dim, const at::Tensor & self, at::Tensor & out) { + return at::_ops::_sparse_softmax_backward_data_out::redispatch(dispatchKeySet, grad_output, output, dim, self, out); + } + + // aten::_sparse_log_softmax.out(Tensor self, int dim, bool half_to_float, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _sparse_log_softmax_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim, bool half_to_float) { + return at::_ops::_sparse_log_softmax_out::redispatch(dispatchKeySet, self, dim, half_to_float, out); + } + + // aten::_sparse_log_softmax.out(Tensor self, int dim, bool half_to_float, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _sparse_log_softmax_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool half_to_float, at::Tensor & out) { + return at::_ops::_sparse_log_softmax_out::redispatch(dispatchKeySet, self, dim, half_to_float, out); + } + + // aten::_sparse_log_softmax_backward_data.out(Tensor grad_output, Tensor output, int dim, Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _sparse_log_softmax_backward_data_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad_output, const at::Tensor & output, int64_t dim, const at::Tensor & self) { + return at::_ops::_sparse_log_softmax_backward_data_out::redispatch(dispatchKeySet, grad_output, output, dim, self, out); + } + + // aten::_sparse_log_softmax_backward_data.out(Tensor grad_output, Tensor output, int dim, Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & _sparse_log_softmax_backward_data_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & output, int64_t dim, const at::Tensor & self, at::Tensor & out) { + return at::_ops::_sparse_log_softmax_backward_data_out::redispatch(dispatchKeySet, grad_output, output, dim, self, out); + } + + // aten::_spdiags.out(Tensor diagonals, Tensor offsets, int[] shape, Layout? layout=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _spdiags_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & diagonals, const at::Tensor & offsets, at::IntArrayRef shape, ::std::optional layout=::std::nullopt) { + return at::_ops::_spdiags_out::redispatch(dispatchKeySet, diagonals, offsets, shape, layout, out); + } + + // aten::_spdiags.out(Tensor diagonals, Tensor offsets, int[] shape, Layout? layout=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _spdiags_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & diagonals, const at::Tensor & offsets, at::IntArrayRef shape, ::std::optional layout, at::Tensor & out) { + return at::_ops::_spdiags_out::redispatch(dispatchKeySet, diagonals, offsets, shape, layout, out); + } + + // aten::norm.ScalarOpt_dtype_out(Tensor self, Scalar? p, *, ScalarType dtype, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & norm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const ::std::optional & p, at::ScalarType dtype) { + return at::_ops::norm_ScalarOpt_dtype_out::redispatch(dispatchKeySet, self, p, dtype, out); + } + + // aten::norm.ScalarOpt_dtype_out(Tensor self, Scalar? p, *, ScalarType dtype, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & norm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const ::std::optional & p, at::ScalarType dtype, at::Tensor & out) { + return at::_ops::norm_ScalarOpt_dtype_out::redispatch(dispatchKeySet, self, p, dtype, out); + } + + // aten::norm.Scalar_out(Tensor self, Scalar p=2, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & norm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & p=2) { + return at::_ops::norm_Scalar_out::redispatch(dispatchKeySet, self, p, out); + } + + // aten::norm.Scalar_out(Tensor self, Scalar p=2, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & norm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & p, at::Tensor & out) { + return at::_ops::norm_Scalar_out::redispatch(dispatchKeySet, self, p, out); + } + + // aten::clone.out(Tensor self, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & clone_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, ::std::optional memory_format=::std::nullopt) { + return at::_ops::clone_out::redispatch(dispatchKeySet, self, memory_format, out); + } + + // aten::clone.out(Tensor self, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & clone_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional memory_format, at::Tensor & out) { + return at::_ops::clone_out::redispatch(dispatchKeySet, self, memory_format, out); + } + + // aten::resize_as.out(Tensor self, Tensor the_template, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!) 
+ inline const at::Tensor & resize_as_out(c10::DispatchKeySet dispatchKeySet, const at::Tensor & out, const at::Tensor & self, const at::Tensor & the_template, ::std::optional memory_format=::std::nullopt) { + return at::_ops::resize_as_out::redispatch(dispatchKeySet, self, the_template, memory_format, out); + } + + // aten::resize_as.out(Tensor self, Tensor the_template, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!) + inline const at::Tensor & resize_as_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & the_template, ::std::optional memory_format, const at::Tensor & out) { + return at::_ops::resize_as_out::redispatch(dispatchKeySet, self, the_template, memory_format, out); + } + + // aten::resize_as(Tensor self, Tensor the_template, *, MemoryFormat? memory_format=None) -> Tensor + inline at::Tensor resize_as(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & the_template, ::std::optional memory_format=::std::nullopt) { + return at::_ops::resize_as::redispatch(dispatchKeySet, self, the_template, memory_format); + } + + // aten::resize_as_sparse.out(Tensor self, Tensor the_template, *, Tensor(a!) out) -> Tensor(a!) + inline const at::Tensor & resize_as_sparse_out(c10::DispatchKeySet dispatchKeySet, const at::Tensor & out, const at::Tensor & self, const at::Tensor & the_template) { + return at::_ops::resize_as_sparse_out::redispatch(dispatchKeySet, self, the_template, out); + } + + // aten::resize_as_sparse.out(Tensor self, Tensor the_template, *, Tensor(a!) out) -> Tensor(a!) 
+ inline const at::Tensor & resize_as_sparse_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & the_template, const at::Tensor & out) { + return at::_ops::resize_as_sparse_out::redispatch(dispatchKeySet, self, the_template, out); + } + + // aten::resize_as_sparse(Tensor self, Tensor the_template) -> Tensor + inline at::Tensor resize_as_sparse(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & the_template) { + return at::_ops::resize_as_sparse::redispatch(dispatchKeySet, self, the_template); + } + + // aten::zero.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & zero_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::zero_out::redispatch(dispatchKeySet, self, out); + } + + // aten::zero.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & zero_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::zero_out::redispatch(dispatchKeySet, self, out); + } + + // aten::zero(Tensor self) -> Tensor + inline at::Tensor zero(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::zero::redispatch(dispatchKeySet, self); + } + + // aten::sub.Scalar_out(Tensor self, Scalar other, Scalar alpha=1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & sub_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha=1) { + return at::_ops::sub_Scalar_out::redispatch(dispatchKeySet, self, other, alpha, out); + } + + // aten::sub.Scalar_out(Tensor self, Scalar other, Scalar alpha=1, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & sub_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha, at::Tensor & out) { + return at::_ops::sub_Scalar_out::redispatch(dispatchKeySet, self, other, alpha, out); + } + + // aten::rsub.Tensor_out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & rsub_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha=1) { + return at::_ops::rsub_Tensor_out::redispatch(dispatchKeySet, self, other, alpha, out); + } + + // aten::rsub.Tensor_out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & rsub_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha, at::Tensor & out) { + return at::_ops::rsub_Tensor_out::redispatch(dispatchKeySet, self, other, alpha, out); + } + + // aten::rsub.Scalar_out(Tensor self, Scalar other, Scalar alpha=1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & rsub_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha=1) { + return at::_ops::rsub_Scalar_out::redispatch(dispatchKeySet, self, other, alpha, out); + } + + // aten::rsub.Scalar_out(Tensor self, Scalar other, Scalar alpha=1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & rsub_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha, at::Tensor & out) { + return at::_ops::rsub_Scalar_out::redispatch(dispatchKeySet, self, other, alpha, out); + } + + // aten::_sparse_addmm.out(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & _sparse_addmm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta=1, const at::Scalar & alpha=1) { + return at::_ops::_sparse_addmm_out::redispatch(dispatchKeySet, self, mat1, mat2, beta, alpha, out); + } + + // aten::_sparse_addmm.out(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _sparse_addmm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out) { + return at::_ops::_sparse_addmm_out::redispatch(dispatchKeySet, self, mat1, mat2, beta, alpha, out); + } + + // aten::sparse_coo_tensor.size_out(int[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & sparse_coo_tensor_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::IntArrayRef size) { + return at::_ops::sparse_coo_tensor_size_out::redispatch(dispatchKeySet, size, out); + } + + // aten::sparse_coo_tensor.size_out(int[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & sparse_coo_tensor_outf(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, at::Tensor & out) { + return at::_ops::sparse_coo_tensor_size_out::redispatch(dispatchKeySet, size, out); + } + + // aten::_sparse_coo_tensor_with_dims.out(int sparse_dim, int dense_dim, int[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _sparse_coo_tensor_with_dims_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t sparse_dim, int64_t dense_dim, at::IntArrayRef size) { + return at::_ops::_sparse_coo_tensor_with_dims_out::redispatch(dispatchKeySet, sparse_dim, dense_dim, size, out); + } + + // aten::_sparse_coo_tensor_with_dims.out(int sparse_dim, int dense_dim, int[] size, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & _sparse_coo_tensor_with_dims_outf(c10::DispatchKeySet dispatchKeySet, int64_t sparse_dim, int64_t dense_dim, at::IntArrayRef size, at::Tensor & out) { + return at::_ops::_sparse_coo_tensor_with_dims_out::redispatch(dispatchKeySet, sparse_dim, dense_dim, size, out); + } + + // aten::_sparse_coo_tensor_with_dims_and_tensors.out(int sparse_dim, int dense_dim, SymInt[] size, Tensor indices, Tensor values, *, bool? is_coalesced=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _sparse_coo_tensor_with_dims_and_tensors_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t sparse_dim, int64_t dense_dim, at::IntArrayRef size, const at::Tensor & indices, const at::Tensor & values, ::std::optional is_coalesced=::std::nullopt) { + return at::_ops::_sparse_coo_tensor_with_dims_and_tensors_out::redispatch(dispatchKeySet, sparse_dim, dense_dim, c10::fromIntArrayRefSlow(size), indices, values, is_coalesced, out); + } + + // aten::_sparse_coo_tensor_with_dims_and_tensors.out(int sparse_dim, int dense_dim, SymInt[] size, Tensor indices, Tensor values, *, bool? is_coalesced=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _sparse_coo_tensor_with_dims_and_tensors_outf(c10::DispatchKeySet dispatchKeySet, int64_t sparse_dim, int64_t dense_dim, at::IntArrayRef size, const at::Tensor & indices, const at::Tensor & values, ::std::optional is_coalesced, at::Tensor & out) { + return at::_ops::_sparse_coo_tensor_with_dims_and_tensors_out::redispatch(dispatchKeySet, sparse_dim, dense_dim, c10::fromIntArrayRefSlow(size), indices, values, is_coalesced, out); + } + + // aten::_sparse_coo_tensor_with_dims_and_tensors.out(int sparse_dim, int dense_dim, SymInt[] size, Tensor indices, Tensor values, *, bool? is_coalesced=None, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & _sparse_coo_tensor_with_dims_and_tensors_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t sparse_dim, int64_t dense_dim, c10::SymIntArrayRef size, const at::Tensor & indices, const at::Tensor & values, ::std::optional is_coalesced=::std::nullopt) { + return at::_ops::_sparse_coo_tensor_with_dims_and_tensors_out::redispatch(dispatchKeySet, sparse_dim, dense_dim, size, indices, values, is_coalesced, out); + } + + // aten::_sparse_coo_tensor_with_dims_and_tensors.out(int sparse_dim, int dense_dim, SymInt[] size, Tensor indices, Tensor values, *, bool? is_coalesced=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _sparse_coo_tensor_with_dims_and_tensors_symint_outf(c10::DispatchKeySet dispatchKeySet, int64_t sparse_dim, int64_t dense_dim, c10::SymIntArrayRef size, const at::Tensor & indices, const at::Tensor & values, ::std::optional is_coalesced, at::Tensor & out) { + return at::_ops::_sparse_coo_tensor_with_dims_and_tensors_out::redispatch(dispatchKeySet, sparse_dim, dense_dim, size, indices, values, is_coalesced, out); + } + + // aten::sparse_resize.out(Tensor self, int[] size, int sparse_dim, int dense_dim, *, Tensor(a!) out) -> Tensor(a!) + inline const at::Tensor & sparse_resize_out(c10::DispatchKeySet dispatchKeySet, const at::Tensor & out, const at::Tensor & self, at::IntArrayRef size, int64_t sparse_dim, int64_t dense_dim) { + return at::_ops::sparse_resize_out::redispatch(dispatchKeySet, self, size, sparse_dim, dense_dim, out); + } + + // aten::sparse_resize.out(Tensor self, int[] size, int sparse_dim, int dense_dim, *, Tensor(a!) out) -> Tensor(a!) 
+ inline const at::Tensor & sparse_resize_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, int64_t sparse_dim, int64_t dense_dim, const at::Tensor & out) { + return at::_ops::sparse_resize_out::redispatch(dispatchKeySet, self, size, sparse_dim, dense_dim, out); + } + + // aten::sparse_resize(Tensor self, int[] size, int sparse_dim, int dense_dim) -> Tensor + inline at::Tensor sparse_resize(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, int64_t sparse_dim, int64_t dense_dim) { + return at::_ops::sparse_resize::redispatch(dispatchKeySet, self, size, sparse_dim, dense_dim); + } + + // aten::sparse_resize_and_clear.out(Tensor self, int[] size, int sparse_dim, int dense_dim, *, Tensor(a!) out) -> Tensor(a!) + inline const at::Tensor & sparse_resize_and_clear_out(c10::DispatchKeySet dispatchKeySet, const at::Tensor & out, const at::Tensor & self, at::IntArrayRef size, int64_t sparse_dim, int64_t dense_dim) { + return at::_ops::sparse_resize_and_clear_out::redispatch(dispatchKeySet, self, size, sparse_dim, dense_dim, out); + } + + // aten::sparse_resize_and_clear.out(Tensor self, int[] size, int sparse_dim, int dense_dim, *, Tensor(a!) out) -> Tensor(a!) 
+ inline const at::Tensor & sparse_resize_and_clear_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, int64_t sparse_dim, int64_t dense_dim, const at::Tensor & out) { + return at::_ops::sparse_resize_and_clear_out::redispatch(dispatchKeySet, self, size, sparse_dim, dense_dim, out); + } + + // aten::sparse_resize_and_clear(Tensor self, int[] size, int sparse_dim, int dense_dim) -> Tensor + inline at::Tensor sparse_resize_and_clear(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, int64_t sparse_dim, int64_t dense_dim) { + return at::_ops::sparse_resize_and_clear::redispatch(dispatchKeySet, self, size, sparse_dim, dense_dim); + } + + // aten::sparse_mask.out(Tensor self, Tensor mask, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & sparse_mask_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & mask) { + return at::_ops::sparse_mask_out::redispatch(dispatchKeySet, self, mask, out); + } + + // aten::sparse_mask.out(Tensor self, Tensor mask, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & sparse_mask_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mask, at::Tensor & out) { + return at::_ops::sparse_mask_out::redispatch(dispatchKeySet, self, mask, out); + } + + // aten::_sparse_mask_projection.out(Tensor self, Tensor mask, bool accumulate_matches=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _sparse_mask_projection_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & mask, bool accumulate_matches=false) { + return at::_ops::_sparse_mask_projection_out::redispatch(dispatchKeySet, self, mask, accumulate_matches, out); + } + + // aten::_sparse_mask_projection.out(Tensor self, Tensor mask, bool accumulate_matches=False, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & _sparse_mask_projection_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mask, bool accumulate_matches, at::Tensor & out) { + return at::_ops::_sparse_mask_projection_out::redispatch(dispatchKeySet, self, mask, accumulate_matches, out); + } + + // aten::_to_dense.out(Tensor self, ScalarType? dtype=None, bool? masked_grad=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _to_dense_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, ::std::optional dtype=::std::nullopt, ::std::optional masked_grad=::std::nullopt) { + return at::_ops::_to_dense_out::redispatch(dispatchKeySet, self, dtype, masked_grad, out); + } + + // aten::_to_dense.out(Tensor self, ScalarType? dtype=None, bool? masked_grad=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _to_dense_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional dtype, ::std::optional masked_grad, at::Tensor & out) { + return at::_ops::_to_dense_out::redispatch(dispatchKeySet, self, dtype, masked_grad, out); + } + + // aten::_coalesce.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _coalesce_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::_coalesce_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_coalesce.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _coalesce_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::_coalesce_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_coalesced.out(Tensor self, bool coalesced, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & _coalesced_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, bool coalesced) { + return at::_ops::_coalesced_out::redispatch(dispatchKeySet, self, coalesced, out); + } + + // aten::_coalesced.out(Tensor self, bool coalesced, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _coalesced_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool coalesced, at::Tensor & out) { + return at::_ops::_coalesced_out::redispatch(dispatchKeySet, self, coalesced, out); + } + + // aten::_coalesced(Tensor self, bool coalesced) -> Tensor + inline at::Tensor _coalesced(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool coalesced) { + return at::_ops::_coalesced::redispatch(dispatchKeySet, self, coalesced); + } + + // aten::copy_sparse_to_sparse.out(Tensor self, Tensor src, bool non_blocking=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & copy_sparse_to_sparse_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & src, bool non_blocking=false) { + return at::_ops::copy_sparse_to_sparse_out::redispatch(dispatchKeySet, self, src, non_blocking, out); + } + + // aten::copy_sparse_to_sparse.out(Tensor self, Tensor src, bool non_blocking=False, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & copy_sparse_to_sparse_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & src, bool non_blocking, at::Tensor & out) { + return at::_ops::copy_sparse_to_sparse_out::redispatch(dispatchKeySet, self, src, non_blocking, out); + } + + // aten::copy_sparse_to_sparse(Tensor self, Tensor src, bool non_blocking=False) -> Tensor + inline at::Tensor copy_sparse_to_sparse(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & src, bool non_blocking=false) { + return at::_ops::copy_sparse_to_sparse::redispatch(dispatchKeySet, self, src, non_blocking); + } + + // aten::_to_sparse.sparse_dim_out(Tensor self, int sparse_dim, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _to_sparse_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t sparse_dim) { + return at::_ops::_to_sparse_sparse_dim_out::redispatch(dispatchKeySet, self, sparse_dim, out); + } + + // aten::_to_sparse.sparse_dim_out(Tensor self, int sparse_dim, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _to_sparse_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t sparse_dim, at::Tensor & out) { + return at::_ops::_to_sparse_sparse_dim_out::redispatch(dispatchKeySet, self, sparse_dim, out); + } + + // aten::_to_sparse.out(Tensor self, *, Layout? layout=None, int[2]? blocksize=None, int? dense_dim=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _to_sparse_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, ::std::optional layout=::std::nullopt, at::OptionalIntArrayRef blocksize=::std::nullopt, ::std::optional dense_dim=::std::nullopt) { + return at::_ops::_to_sparse_out::redispatch(dispatchKeySet, self, layout, blocksize, dense_dim, out); + } + + // aten::_to_sparse.out(Tensor self, *, Layout? layout=None, int[2]? blocksize=None, int? dense_dim=None, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & _to_sparse_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional layout, at::OptionalIntArrayRef blocksize, ::std::optional dense_dim, at::Tensor & out) { + return at::_ops::_to_sparse_out::redispatch(dispatchKeySet, self, layout, blocksize, dense_dim, out); + } + + // aten::_to_sparse_csr.out(Tensor self, int? dense_dim=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _to_sparse_csr_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, ::std::optional dense_dim=::std::nullopt) { + return at::_ops::_to_sparse_csr_out::redispatch(dispatchKeySet, self, dense_dim, out); + } + + // aten::_to_sparse_csr.out(Tensor self, int? dense_dim=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _to_sparse_csr_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional dense_dim, at::Tensor & out) { + return at::_ops::_to_sparse_csr_out::redispatch(dispatchKeySet, self, dense_dim, out); + } + + // aten::_to_sparse_csc.out(Tensor self, int? dense_dim=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _to_sparse_csc_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, ::std::optional dense_dim=::std::nullopt) { + return at::_ops::_to_sparse_csc_out::redispatch(dispatchKeySet, self, dense_dim, out); + } + + // aten::_to_sparse_csc.out(Tensor self, int? dense_dim=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _to_sparse_csc_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional dense_dim, at::Tensor & out) { + return at::_ops::_to_sparse_csc_out::redispatch(dispatchKeySet, self, dense_dim, out); + } + + // aten::_to_sparse_bsr.out(Tensor self, int[2] blocksize, int? dense_dim=None, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & _to_sparse_bsr_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef blocksize, ::std::optional dense_dim=::std::nullopt) { + return at::_ops::_to_sparse_bsr_out::redispatch(dispatchKeySet, self, blocksize, dense_dim, out); + } + + // aten::_to_sparse_bsr.out(Tensor self, int[2] blocksize, int? dense_dim=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _to_sparse_bsr_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef blocksize, ::std::optional dense_dim, at::Tensor & out) { + return at::_ops::_to_sparse_bsr_out::redispatch(dispatchKeySet, self, blocksize, dense_dim, out); + } + + // aten::_to_sparse_bsc.out(Tensor self, int[2] blocksize, int? dense_dim=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _to_sparse_bsc_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef blocksize, ::std::optional dense_dim=::std::nullopt) { + return at::_ops::_to_sparse_bsc_out::redispatch(dispatchKeySet, self, blocksize, dense_dim, out); + } + + // aten::_to_sparse_bsc.out(Tensor self, int[2] blocksize, int? dense_dim=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _to_sparse_bsc_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef blocksize, ::std::optional dense_dim, at::Tensor & out) { + return at::_ops::_to_sparse_bsc_out::redispatch(dispatchKeySet, self, blocksize, dense_dim, out); + } + + // aten::to_mkldnn.out(Tensor self, ScalarType? dtype=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & to_mkldnn_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, ::std::optional dtype=::std::nullopt) { + return at::_ops::to_mkldnn_out::redispatch(dispatchKeySet, self, dtype, out); + } + + // aten::to_mkldnn.out(Tensor self, ScalarType? dtype=None, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & to_mkldnn_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional dtype, at::Tensor & out) { + return at::_ops::to_mkldnn_out::redispatch(dispatchKeySet, self, dtype, out); + } + + // aten::mkldnn_reorder_conv2d_weight.out(Tensor self, SymInt[2] padding=0, SymInt[2] stride=1, SymInt[2] dilation=1, SymInt groups=1, SymInt[]? input_size=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & mkldnn_reorder_conv2d_weight_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef padding=0, at::IntArrayRef stride=1, at::IntArrayRef dilation=1, int64_t groups=1, at::OptionalIntArrayRef input_size=::std::nullopt) { + return at::_ops::mkldnn_reorder_conv2d_weight_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, input_size.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*input_size)) : ::std::nullopt, out); + } + + // aten::mkldnn_reorder_conv2d_weight.out(Tensor self, SymInt[2] padding=0, SymInt[2] stride=1, SymInt[2] dilation=1, SymInt groups=1, SymInt[]? input_size=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & mkldnn_reorder_conv2d_weight_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, at::OptionalIntArrayRef input_size, at::Tensor & out) { + return at::_ops::mkldnn_reorder_conv2d_weight_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, input_size.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*input_size)) : ::std::nullopt, out); + } + + // aten::mkldnn_reorder_conv2d_weight.out(Tensor self, SymInt[2] padding=0, SymInt[2] stride=1, SymInt[2] dilation=1, SymInt groups=1, SymInt[]? input_size=None, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & mkldnn_reorder_conv2d_weight_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef padding=c10::SymInt(0), c10::SymIntArrayRef stride=c10::SymInt(1), c10::SymIntArrayRef dilation=c10::SymInt(1), c10::SymInt groups=1, at::OptionalSymIntArrayRef input_size=::std::nullopt) { + return at::_ops::mkldnn_reorder_conv2d_weight_out::redispatch(dispatchKeySet, self, padding, stride, dilation, groups, input_size, out); + } + + // aten::mkldnn_reorder_conv2d_weight.out(Tensor self, SymInt[2] padding=0, SymInt[2] stride=1, SymInt[2] dilation=1, SymInt groups=1, SymInt[]? input_size=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & mkldnn_reorder_conv2d_weight_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, at::OptionalSymIntArrayRef input_size, at::Tensor & out) { + return at::_ops::mkldnn_reorder_conv2d_weight_out::redispatch(dispatchKeySet, self, padding, stride, dilation, groups, input_size, out); + } + + // aten::mkldnn_reorder_conv3d_weight.out(Tensor self, SymInt[3] padding=0, SymInt[3] stride=1, SymInt[3] dilation=1, SymInt groups=1, SymInt[]? input_size=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & mkldnn_reorder_conv3d_weight_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef padding=0, at::IntArrayRef stride=1, at::IntArrayRef dilation=1, int64_t groups=1, at::OptionalIntArrayRef input_size=::std::nullopt) { + return at::_ops::mkldnn_reorder_conv3d_weight_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, input_size.has_value() ? 
::std::make_optional(c10::fromIntArrayRefSlow(*input_size)) : ::std::nullopt, out); + } + + // aten::mkldnn_reorder_conv3d_weight.out(Tensor self, SymInt[3] padding=0, SymInt[3] stride=1, SymInt[3] dilation=1, SymInt groups=1, SymInt[]? input_size=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & mkldnn_reorder_conv3d_weight_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, at::OptionalIntArrayRef input_size, at::Tensor & out) { + return at::_ops::mkldnn_reorder_conv3d_weight_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, input_size.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*input_size)) : ::std::nullopt, out); + } + + // aten::mkldnn_reorder_conv3d_weight.out(Tensor self, SymInt[3] padding=0, SymInt[3] stride=1, SymInt[3] dilation=1, SymInt groups=1, SymInt[]? input_size=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & mkldnn_reorder_conv3d_weight_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef padding=c10::SymInt(0), c10::SymIntArrayRef stride=c10::SymInt(1), c10::SymIntArrayRef dilation=c10::SymInt(1), c10::SymInt groups=1, at::OptionalSymIntArrayRef input_size=::std::nullopt) { + return at::_ops::mkldnn_reorder_conv3d_weight_out::redispatch(dispatchKeySet, self, padding, stride, dilation, groups, input_size, out); + } + + // aten::mkldnn_reorder_conv3d_weight.out(Tensor self, SymInt[3] padding=0, SymInt[3] stride=1, SymInt[3] dilation=1, SymInt groups=1, SymInt[]? input_size=None, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & mkldnn_reorder_conv3d_weight_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, at::OptionalSymIntArrayRef input_size, at::Tensor & out) { + return at::_ops::mkldnn_reorder_conv3d_weight_out::redispatch(dispatchKeySet, self, padding, stride, dilation, groups, input_size, out); + } + + // aten::quantize_per_tensor_dynamic.out(Tensor self, ScalarType dtype, bool reduce_range, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & quantize_per_tensor_dynamic_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::ScalarType dtype, bool reduce_range) { + return at::_ops::quantize_per_tensor_dynamic_out::redispatch(dispatchKeySet, self, dtype, reduce_range, out); + } + + // aten::quantize_per_tensor_dynamic.out(Tensor self, ScalarType dtype, bool reduce_range, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & quantize_per_tensor_dynamic_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::ScalarType dtype, bool reduce_range, at::Tensor & out) { + return at::_ops::quantize_per_tensor_dynamic_out::redispatch(dispatchKeySet, self, dtype, reduce_range, out); + } + + // aten::quantize_per_tensor.out(Tensor self, float scale, int zero_point, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & quantize_per_tensor_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, double scale, int64_t zero_point, at::ScalarType dtype) { + return at::_ops::quantize_per_tensor_out::redispatch(dispatchKeySet, self, scale, zero_point, dtype, out); + } + + // aten::quantize_per_tensor.out(Tensor self, float scale, int zero_point, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & quantize_per_tensor_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double scale, int64_t zero_point, at::ScalarType dtype, at::Tensor & out) { + return at::_ops::quantize_per_tensor_out::redispatch(dispatchKeySet, self, scale, zero_point, dtype, out); + } + + // aten::quantize_per_tensor.tensor_qparams_out(Tensor self, Tensor scale, Tensor zero_point, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & quantize_per_tensor_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, at::ScalarType dtype) { + return at::_ops::quantize_per_tensor_tensor_qparams_out::redispatch(dispatchKeySet, self, scale, zero_point, dtype, out); + } + + // aten::quantize_per_tensor.tensor_qparams_out(Tensor self, Tensor scale, Tensor zero_point, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & quantize_per_tensor_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, at::ScalarType dtype, at::Tensor & out) { + return at::_ops::quantize_per_tensor_tensor_qparams_out::redispatch(dispatchKeySet, self, scale, zero_point, dtype, out); + } + + // aten::quantize_per_tensor.tensors_out(Tensor[] tensors, Tensor scales, Tensor zero_points, ScalarType dtype, *, Tensor(a!)[] out) -> () + inline void quantize_per_tensor_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList tensors, const at::Tensor & scales, const at::Tensor & zero_points, at::ScalarType dtype) { + return at::_ops::quantize_per_tensor_tensors_out::redispatch(dispatchKeySet, tensors, scales, zero_points, dtype, out); + } + + // aten::quantize_per_tensor.tensors_out(Tensor[] tensors, Tensor scales, Tensor zero_points, ScalarType dtype, *, Tensor(a!)[] out) -> () + inline void quantize_per_tensor_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors, const at::Tensor & 
scales, const at::Tensor & zero_points, at::ScalarType dtype, at::TensorList out) { + return at::_ops::quantize_per_tensor_tensors_out::redispatch(dispatchKeySet, tensors, scales, zero_points, dtype, out); + } + + // aten::quantize_per_channel.out(Tensor self, Tensor scales, Tensor zero_points, int axis, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & quantize_per_channel_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, at::ScalarType dtype) { + return at::_ops::quantize_per_channel_out::redispatch(dispatchKeySet, self, scales, zero_points, axis, dtype, out); + } + + // aten::quantize_per_channel.out(Tensor self, Tensor scales, Tensor zero_points, int axis, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & quantize_per_channel_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, at::ScalarType dtype, at::Tensor & out) { + return at::_ops::quantize_per_channel_out::redispatch(dispatchKeySet, self, scales, zero_points, axis, dtype, out); + } + + // aten::dequantize.self_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & dequantize_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::dequantize_self_out::redispatch(dispatchKeySet, self, out); + } + + // aten::dequantize.self_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & dequantize_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::dequantize_self_out::redispatch(dispatchKeySet, self, out); + } + + // aten::dequantize.tensors_out(Tensor[] tensors, *, Tensor(a!)[] out) -> () + inline void dequantize_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList tensors) { + return at::_ops::dequantize_tensors_out::redispatch(dispatchKeySet, tensors, out); + } + + // aten::dequantize.tensors_out(Tensor[] tensors, *, Tensor(a!)[] out) -> () + inline void dequantize_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors, at::TensorList out) { + return at::_ops::dequantize_tensors_out::redispatch(dispatchKeySet, tensors, out); + } + + // aten::q_per_channel_scales.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & q_per_channel_scales_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::q_per_channel_scales_out::redispatch(dispatchKeySet, self, out); + } + + // aten::q_per_channel_scales.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & q_per_channel_scales_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::q_per_channel_scales_out::redispatch(dispatchKeySet, self, out); + } + + // aten::q_per_channel_zero_points.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & q_per_channel_zero_points_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::q_per_channel_zero_points_out::redispatch(dispatchKeySet, self, out); + } + + // aten::q_per_channel_zero_points.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & q_per_channel_zero_points_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::q_per_channel_zero_points_out::redispatch(dispatchKeySet, self, out); + } + + // aten::int_repr.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & int_repr_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::int_repr_out::redispatch(dispatchKeySet, self, out); + } + + // aten::int_repr.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & int_repr_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::int_repr_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_make_per_tensor_quantized_tensor.out(Tensor self, float scale, int zero_point, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _make_per_tensor_quantized_tensor_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, double scale, int64_t zero_point) { + return at::_ops::_make_per_tensor_quantized_tensor_out::redispatch(dispatchKeySet, self, scale, zero_point, out); + } + + // aten::_make_per_tensor_quantized_tensor.out(Tensor self, float scale, int zero_point, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _make_per_tensor_quantized_tensor_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double scale, int64_t zero_point, at::Tensor & out) { + return at::_ops::_make_per_tensor_quantized_tensor_out::redispatch(dispatchKeySet, self, scale, zero_point, out); + } + + // aten::_make_per_channel_quantized_tensor.out(Tensor self, Tensor scale, Tensor zero_point, int axis, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & _make_per_channel_quantized_tensor_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, int64_t axis) { + return at::_ops::_make_per_channel_quantized_tensor_out::redispatch(dispatchKeySet, self, scale, zero_point, axis, out); + } + + // aten::_make_per_channel_quantized_tensor.out(Tensor self, Tensor scale, Tensor zero_point, int axis, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _make_per_channel_quantized_tensor_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, int64_t axis, at::Tensor & out) { + return at::_ops::_make_per_channel_quantized_tensor_out::redispatch(dispatchKeySet, self, scale, zero_point, axis, out); + } + + // aten::fake_quantize_per_tensor_affine_cachemask.out(Tensor self, float scale, int zero_point, int quant_min, int quant_max, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple fake_quantize_per_tensor_affine_cachemask_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & self, double scale, int64_t zero_point, int64_t quant_min, int64_t quant_max) { + return at::_ops::fake_quantize_per_tensor_affine_cachemask_out::redispatch(dispatchKeySet, self, scale, zero_point, quant_min, quant_max, out0, out1); + } + + // aten::fake_quantize_per_tensor_affine_cachemask.out(Tensor self, float scale, int zero_point, int quant_min, int quant_max, *, Tensor(a!) out0, Tensor(b!) 
out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple fake_quantize_per_tensor_affine_cachemask_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double scale, int64_t zero_point, int64_t quant_min, int64_t quant_max, at::Tensor & out0, at::Tensor & out1) { + return at::_ops::fake_quantize_per_tensor_affine_cachemask_out::redispatch(dispatchKeySet, self, scale, zero_point, quant_min, quant_max, out0, out1); + } + + // aten::_fake_quantize_per_tensor_affine_cachemask_tensor_qparams.out(Tensor self, Tensor scale, Tensor zero_point, Tensor fake_quant_enabled, int quant_min, int quant_max, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple _fake_quantize_per_tensor_affine_cachemask_tensor_qparams_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, const at::Tensor & fake_quant_enabled, int64_t quant_min, int64_t quant_max) { + return at::_ops::_fake_quantize_per_tensor_affine_cachemask_tensor_qparams_out::redispatch(dispatchKeySet, self, scale, zero_point, fake_quant_enabled, quant_min, quant_max, out0, out1); + } + + // aten::_fake_quantize_per_tensor_affine_cachemask_tensor_qparams.out(Tensor self, Tensor scale, Tensor zero_point, Tensor fake_quant_enabled, int quant_min, int quant_max, *, Tensor(a!) out0, Tensor(b!) 
out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple _fake_quantize_per_tensor_affine_cachemask_tensor_qparams_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, const at::Tensor & fake_quant_enabled, int64_t quant_min, int64_t quant_max, at::Tensor & out0, at::Tensor & out1) { + return at::_ops::_fake_quantize_per_tensor_affine_cachemask_tensor_qparams_out::redispatch(dispatchKeySet, self, scale, zero_point, fake_quant_enabled, quant_min, quant_max, out0, out1); + } + + // aten::_fake_quantize_learnable_per_tensor_affine.out(Tensor self, Tensor scale, Tensor zero_point, int quant_min, int quant_max, float grad_factor=1.0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _fake_quantize_learnable_per_tensor_affine_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, int64_t quant_min, int64_t quant_max, double grad_factor=1.0) { + return at::_ops::_fake_quantize_learnable_per_tensor_affine_out::redispatch(dispatchKeySet, self, scale, zero_point, quant_min, quant_max, grad_factor, out); + } + + // aten::_fake_quantize_learnable_per_tensor_affine.out(Tensor self, Tensor scale, Tensor zero_point, int quant_min, int quant_max, float grad_factor=1.0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _fake_quantize_learnable_per_tensor_affine_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, int64_t quant_min, int64_t quant_max, double grad_factor, at::Tensor & out) { + return at::_ops::_fake_quantize_learnable_per_tensor_affine_out::redispatch(dispatchKeySet, self, scale, zero_point, quant_min, quant_max, grad_factor, out); + } + + // aten::fake_quantize_per_channel_affine_cachemask.out(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max, *, Tensor(a!) out0, Tensor(b!) 
out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple fake_quantize_per_channel_affine_cachemask_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, int64_t axis, int64_t quant_min, int64_t quant_max) { + return at::_ops::fake_quantize_per_channel_affine_cachemask_out::redispatch(dispatchKeySet, self, scale, zero_point, axis, quant_min, quant_max, out0, out1); + } + + // aten::fake_quantize_per_channel_affine_cachemask.out(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple fake_quantize_per_channel_affine_cachemask_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, int64_t axis, int64_t quant_min, int64_t quant_max, at::Tensor & out0, at::Tensor & out1) { + return at::_ops::fake_quantize_per_channel_affine_cachemask_out::redispatch(dispatchKeySet, self, scale, zero_point, axis, quant_min, quant_max, out0, out1); + } + + // aten::_fake_quantize_learnable_per_channel_affine.out(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max, float grad_factor=1.0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _fake_quantize_learnable_per_channel_affine_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, int64_t axis, int64_t quant_min, int64_t quant_max, double grad_factor=1.0) { + return at::_ops::_fake_quantize_learnable_per_channel_affine_out::redispatch(dispatchKeySet, self, scale, zero_point, axis, quant_min, quant_max, grad_factor, out); + } + + // aten::_fake_quantize_learnable_per_channel_affine.out(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max, float grad_factor=1.0, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & _fake_quantize_learnable_per_channel_affine_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, int64_t axis, int64_t quant_min, int64_t quant_max, double grad_factor, at::Tensor & out) { + return at::_ops::_fake_quantize_learnable_per_channel_affine_out::redispatch(dispatchKeySet, self, scale, zero_point, axis, quant_min, quant_max, grad_factor, out); + } + + // aten::_fused_moving_avg_obs_fq_helper.out(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor(a!) running_min, Tensor(b!) running_max, Tensor(c!) scale, Tensor(d!) zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False, *, Tensor(e!) out0, Tensor(f!) out1) -> (Tensor(e!), Tensor(f!)) + inline ::std::tuple _fused_moving_avg_obs_fq_helper_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & self, const at::Tensor & observer_on, const at::Tensor & fake_quant_on, at::Tensor & running_min, at::Tensor & running_max, at::Tensor & scale, at::Tensor & zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, bool per_row_fake_quant=false, bool symmetric_quant=false) { + return at::_ops::_fused_moving_avg_obs_fq_helper_out::redispatch(dispatchKeySet, self, observer_on, fake_quant_on, running_min, running_max, scale, zero_point, averaging_const, quant_min, quant_max, ch_axis, per_row_fake_quant, symmetric_quant, out0, out1); + } + + // aten::_fused_moving_avg_obs_fq_helper.out(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor(a!) running_min, Tensor(b!) running_max, Tensor(c!) scale, Tensor(d!) zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False, *, Tensor(e!) out0, Tensor(f!) 
out1) -> (Tensor(e!), Tensor(f!)) + inline ::std::tuple _fused_moving_avg_obs_fq_helper_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & observer_on, const at::Tensor & fake_quant_on, at::Tensor & running_min, at::Tensor & running_max, at::Tensor & scale, at::Tensor & zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, bool per_row_fake_quant, bool symmetric_quant, at::Tensor & out0, at::Tensor & out1) { + return at::_ops::_fused_moving_avg_obs_fq_helper_out::redispatch(dispatchKeySet, self, observer_on, fake_quant_on, running_min, running_max, scale, zero_point, averaging_const, quant_min, quant_max, ch_axis, per_row_fake_quant, symmetric_quant, out0, out1); + } + + // aten::_fused_moving_avg_obs_fq_helper_functional(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor running_min, Tensor running_max, Tensor scale, Tensor zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor output, Tensor mask, Tensor running_min_out, Tensor running_max_out, Tensor scale_out, Tensor zero_point_out) + inline ::std::tuple _fused_moving_avg_obs_fq_helper_functional(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & observer_on, const at::Tensor & fake_quant_on, const at::Tensor & running_min, const at::Tensor & running_max, const at::Tensor & scale, const at::Tensor & zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, bool per_row_fake_quant=false, bool symmetric_quant=false) { + return at::_ops::_fused_moving_avg_obs_fq_helper_functional::redispatch(dispatchKeySet, self, observer_on, fake_quant_on, running_min, running_max, scale, zero_point, averaging_const, quant_min, quant_max, ch_axis, per_row_fake_quant, symmetric_quant); + } + + // aten::_to_copy.out(Tensor self, *, bool non_blocking=False, MemoryFormat? 
memory_format=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _to_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, bool non_blocking=false, ::std::optional memory_format=::std::nullopt) { + return at::_ops::_to_copy_out::redispatch(dispatchKeySet, self, non_blocking, memory_format, out); + } + + // aten::_to_copy.out(Tensor self, *, bool non_blocking=False, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _to_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool non_blocking, ::std::optional memory_format, at::Tensor & out) { + return at::_ops::_to_copy_out::redispatch(dispatchKeySet, self, non_blocking, memory_format, out); + } + + // aten::_lstm_mps.out(Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3, Tensor(e!) out4, Tensor(f!) out5) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!), Tensor(e!), Tensor(f!)) + inline ::std::tuple _lstm_mps_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3, at::Tensor & out4, at::Tensor & out5, const at::Tensor & input, at::TensorList hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional, bool batch_first) { + return at::_ops::_lstm_mps_out::redispatch(dispatchKeySet, input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first, out0, out1, out2, out3, out4, out5); + } + + // aten::_lstm_mps.out(Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3, Tensor(e!) out4, Tensor(f!) 
out5) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!), Tensor(e!), Tensor(f!)) + inline ::std::tuple _lstm_mps_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::TensorList hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional, bool batch_first, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3, at::Tensor & out4, at::Tensor & out5) { + return at::_ops::_lstm_mps_out::redispatch(dispatchKeySet, input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first, out0, out1, out2, out3, out4, out5); + } + + // aten::lstm_mps_backward.out(Tensor? grad_y, Tensor? grad_hy, Tensor? grad_cy, Tensor z_state, Tensor cell_state_fwd, Tensor input, Tensor layersOutputs, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first, *, Tensor(a!) out0, Tensor(b!)[] out1, Tensor(c!)[] out2) -> () + inline void lstm_mps_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::TensorList out1, at::TensorList out2, const ::std::optional & grad_y, const ::std::optional & grad_hy, const ::std::optional & grad_cy, const at::Tensor & z_state, const at::Tensor & cell_state_fwd, const at::Tensor & input, const at::Tensor & layersOutputs, at::TensorList hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional, bool batch_first) { + return at::_ops::lstm_mps_backward_out::redispatch(dispatchKeySet, grad_y, grad_hy, grad_cy, z_state, cell_state_fwd, input, layersOutputs, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first, out0, out1, out2); + } + + // aten::lstm_mps_backward.out(Tensor? grad_y, Tensor? grad_hy, Tensor? 
grad_cy, Tensor z_state, Tensor cell_state_fwd, Tensor input, Tensor layersOutputs, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first, *, Tensor(a!) out0, Tensor(b!)[] out1, Tensor(c!)[] out2) -> () + inline void lstm_mps_backward_outf(c10::DispatchKeySet dispatchKeySet, const ::std::optional & grad_y, const ::std::optional & grad_hy, const ::std::optional & grad_cy, const at::Tensor & z_state, const at::Tensor & cell_state_fwd, const at::Tensor & input, const at::Tensor & layersOutputs, at::TensorList hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional, bool batch_first, at::Tensor & out0, at::TensorList out1, at::TensorList out2) { + return at::_ops::lstm_mps_backward_out::redispatch(dispatchKeySet, grad_y, grad_hy, grad_cy, z_state, cell_state_fwd, input, layersOutputs, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first, out0, out1, out2); + } + + // aten::_thnn_fused_lstm_cell.out(Tensor input_gates, Tensor hidden_gates, Tensor cx, Tensor? input_bias=None, Tensor? hidden_bias=None, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple _thnn_fused_lstm_cell_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & input_gates, const at::Tensor & hidden_gates, const at::Tensor & cx, const ::std::optional & input_bias={}, const ::std::optional & hidden_bias={}) { + return at::_ops::_thnn_fused_lstm_cell_out::redispatch(dispatchKeySet, input_gates, hidden_gates, cx, input_bias, hidden_bias, out0, out1, out2); + } + + // aten::_thnn_fused_lstm_cell.out(Tensor input_gates, Tensor hidden_gates, Tensor cx, Tensor? input_bias=None, Tensor? hidden_bias=None, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) 
out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple _thnn_fused_lstm_cell_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input_gates, const at::Tensor & hidden_gates, const at::Tensor & cx, const ::std::optional & input_bias, const ::std::optional & hidden_bias, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) { + return at::_ops::_thnn_fused_lstm_cell_out::redispatch(dispatchKeySet, input_gates, hidden_gates, cx, input_bias, hidden_bias, out0, out1, out2); + } + + // aten::_thnn_fused_lstm_cell_backward_impl.out(Tensor? grad_hy, Tensor? grad_cy, Tensor cx, Tensor cy, Tensor workspace, bool has_bias, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple _thnn_fused_lstm_cell_backward_impl_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const ::std::optional & grad_hy, const ::std::optional & grad_cy, const at::Tensor & cx, const at::Tensor & cy, const at::Tensor & workspace, bool has_bias) { + return at::_ops::_thnn_fused_lstm_cell_backward_impl_out::redispatch(dispatchKeySet, grad_hy, grad_cy, cx, cy, workspace, has_bias, out0, out1, out2); + } + + // aten::_thnn_fused_lstm_cell_backward_impl.out(Tensor? grad_hy, Tensor? grad_cy, Tensor cx, Tensor cy, Tensor workspace, bool has_bias, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) 
out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple _thnn_fused_lstm_cell_backward_impl_outf(c10::DispatchKeySet dispatchKeySet, const ::std::optional & grad_hy, const ::std::optional & grad_cy, const at::Tensor & cx, const at::Tensor & cy, const at::Tensor & workspace, bool has_bias, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) { + return at::_ops::_thnn_fused_lstm_cell_backward_impl_out::redispatch(dispatchKeySet, grad_hy, grad_cy, cx, cy, workspace, has_bias, out0, out1, out2); + } + + // aten::_thnn_fused_gru_cell.out(Tensor input_gates, Tensor hidden_gates, Tensor hx, Tensor? input_bias=None, Tensor? hidden_bias=None, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple _thnn_fused_gru_cell_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & input_gates, const at::Tensor & hidden_gates, const at::Tensor & hx, const ::std::optional & input_bias={}, const ::std::optional & hidden_bias={}) { + return at::_ops::_thnn_fused_gru_cell_out::redispatch(dispatchKeySet, input_gates, hidden_gates, hx, input_bias, hidden_bias, out0, out1); + } + + // aten::_thnn_fused_gru_cell.out(Tensor input_gates, Tensor hidden_gates, Tensor hx, Tensor? input_bias=None, Tensor? hidden_bias=None, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple _thnn_fused_gru_cell_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input_gates, const at::Tensor & hidden_gates, const at::Tensor & hx, const ::std::optional & input_bias, const ::std::optional & hidden_bias, at::Tensor & out0, at::Tensor & out1) { + return at::_ops::_thnn_fused_gru_cell_out::redispatch(dispatchKeySet, input_gates, hidden_gates, hx, input_bias, hidden_bias, out0, out1); + } + + // aten::_thnn_fused_gru_cell_backward.out(Tensor grad_hy, Tensor workspace, bool has_bias, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3, Tensor(e!) 
out4) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!), Tensor(e!)) + inline ::std::tuple _thnn_fused_gru_cell_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3, at::Tensor & out4, const at::Tensor & grad_hy, const at::Tensor & workspace, bool has_bias) { + return at::_ops::_thnn_fused_gru_cell_backward_out::redispatch(dispatchKeySet, grad_hy, workspace, has_bias, out0, out1, out2, out3, out4); + } + + // aten::_thnn_fused_gru_cell_backward.out(Tensor grad_hy, Tensor workspace, bool has_bias, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3, Tensor(e!) out4) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!), Tensor(e!)) + inline ::std::tuple _thnn_fused_gru_cell_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_hy, const at::Tensor & workspace, bool has_bias, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3, at::Tensor & out4) { + return at::_ops::_thnn_fused_gru_cell_backward_out::redispatch(dispatchKeySet, grad_hy, workspace, has_bias, out0, out1, out2, out3, out4); + } + + // aten::_pack_padded_sequence.out(Tensor input, Tensor lengths, bool batch_first, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple _pack_padded_sequence_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & input, const at::Tensor & lengths, bool batch_first) { + return at::_ops::_pack_padded_sequence_out::redispatch(dispatchKeySet, input, lengths, batch_first, out0, out1); + } + + // aten::_pack_padded_sequence.out(Tensor input, Tensor lengths, bool batch_first, *, Tensor(a!) out0, Tensor(b!) 
out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple _pack_padded_sequence_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & lengths, bool batch_first, at::Tensor & out0, at::Tensor & out1) { + return at::_ops::_pack_padded_sequence_out::redispatch(dispatchKeySet, input, lengths, batch_first, out0, out1); + } + + // aten::set.source_Storage_out(Tensor self, Storage source, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & set_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::Storage source) { + return at::_ops::set_source_Storage_out::redispatch(dispatchKeySet, self, source, out); + } + + // aten::set.source_Storage_out(Tensor self, Storage source, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & set_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Storage source, at::Tensor & out) { + return at::_ops::set_source_Storage_out::redispatch(dispatchKeySet, self, source, out); + } + + // aten::set.source_Storage(Tensor self, Storage source) -> Tensor + inline at::Tensor set(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Storage source) { + return at::_ops::set_source_Storage::redispatch(dispatchKeySet, self, source); + } + + // aten::set.source_Storage_storage_offset_out(Tensor self, Storage source, SymInt storage_offset, SymInt[] size, SymInt[] stride=[], *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & set_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::Storage source, int64_t storage_offset, at::IntArrayRef size, at::IntArrayRef stride={}) { + return at::_ops::set_source_Storage_storage_offset_out::redispatch(dispatchKeySet, self, source, storage_offset, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), out); + } + + // aten::set.source_Storage_storage_offset_out(Tensor self, Storage source, SymInt storage_offset, SymInt[] size, SymInt[] stride=[], *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & set_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Storage source, int64_t storage_offset, at::IntArrayRef size, at::IntArrayRef stride, at::Tensor & out) { + return at::_ops::set_source_Storage_storage_offset_out::redispatch(dispatchKeySet, self, source, storage_offset, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), out); + } + + // aten::set.source_Storage_storage_offset_out(Tensor self, Storage source, SymInt storage_offset, SymInt[] size, SymInt[] stride=[], *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & set_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::Storage source, c10::SymInt storage_offset, c10::SymIntArrayRef size, c10::SymIntArrayRef stride={}) { + return at::_ops::set_source_Storage_storage_offset_out::redispatch(dispatchKeySet, self, source, storage_offset, size, stride, out); + } + + // aten::set.source_Storage_storage_offset_out(Tensor self, Storage source, SymInt storage_offset, SymInt[] size, SymInt[] stride=[], *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & set_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Storage source, c10::SymInt storage_offset, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, at::Tensor & out) { + return at::_ops::set_source_Storage_storage_offset_out::redispatch(dispatchKeySet, self, source, storage_offset, size, stride, out); + } + + // aten::set.source_Storage_storage_offset(Tensor self, Storage source, SymInt storage_offset, SymInt[] size, SymInt[] stride=[]) -> Tensor + inline at::Tensor set(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Storage source, int64_t storage_offset, at::IntArrayRef size, at::IntArrayRef stride={}) { + return at::_ops::set_source_Storage_storage_offset::redispatch(dispatchKeySet, self, source, storage_offset, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride)); + } + + // aten::set.source_Storage_storage_offset(Tensor self, Storage source, SymInt storage_offset, SymInt[] size, SymInt[] stride=[]) -> Tensor + inline at::Tensor set_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Storage source, c10::SymInt storage_offset, c10::SymIntArrayRef size, c10::SymIntArrayRef stride={}) { + return at::_ops::set_source_Storage_storage_offset::redispatch(dispatchKeySet, self, source, storage_offset, size, stride); + } + + // aten::set.source_Tensor_out(Tensor self, Tensor source, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & set_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & source) { + return at::_ops::set_source_Tensor_out::redispatch(dispatchKeySet, self, source, out); + } + + // aten::set.source_Tensor_out(Tensor self, Tensor source, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & set_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & source, at::Tensor & out) { + return at::_ops::set_source_Tensor_out::redispatch(dispatchKeySet, self, source, out); + } + + // aten::set.source_Tensor(Tensor self, Tensor source) -> Tensor + inline at::Tensor set(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & source) { + return at::_ops::set_source_Tensor::redispatch(dispatchKeySet, self, source); + } + + // aten::set.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & set_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::set_out::redispatch(dispatchKeySet, self, out); + } + + // aten::set.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & set_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::set_out::redispatch(dispatchKeySet, self, out); + } + + // aten::set(Tensor self) -> Tensor + inline at::Tensor set(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::set::redispatch(dispatchKeySet, self); + } + + // aten::lift.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & lift_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::lift_out::redispatch(dispatchKeySet, self, out); + } + + // aten::lift.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & lift_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::lift_out::redispatch(dispatchKeySet, self, out); + } + + // aten::lift_fresh_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & lift_fresh_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::lift_fresh_copy_out::redispatch(dispatchKeySet, self, out); + } + + // aten::lift_fresh_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & lift_fresh_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::lift_fresh_copy_out::redispatch(dispatchKeySet, self, out); + } + + // aten::masked_fill.Scalar_out(Tensor self, Tensor mask, Scalar value, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & masked_fill_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & mask, const at::Scalar & value) { + return at::_ops::masked_fill_Scalar_out::redispatch(dispatchKeySet, self, mask, value, out); + } + + // aten::masked_fill.Scalar_out(Tensor self, Tensor mask, Scalar value, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & masked_fill_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mask, const at::Scalar & value, at::Tensor & out) { + return at::_ops::masked_fill_Scalar_out::redispatch(dispatchKeySet, self, mask, value, out); + } + + // aten::masked_fill.Tensor_out(Tensor self, Tensor mask, Tensor value, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & masked_fill_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & mask, const at::Tensor & value) { + return at::_ops::masked_fill_Tensor_out::redispatch(dispatchKeySet, self, mask, value, out); + } + + // aten::masked_fill.Tensor_out(Tensor self, Tensor mask, Tensor value, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & masked_fill_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mask, const at::Tensor & value, at::Tensor & out) { + return at::_ops::masked_fill_Tensor_out::redispatch(dispatchKeySet, self, mask, value, out); + } + + // aten::masked_scatter.out(Tensor self, Tensor mask, Tensor source, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & masked_scatter_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & mask, const at::Tensor & source) { + return at::_ops::masked_scatter_out::redispatch(dispatchKeySet, self, mask, source, out); + } + + // aten::masked_scatter.out(Tensor self, Tensor mask, Tensor source, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & masked_scatter_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mask, const at::Tensor & source, at::Tensor & out) { + return at::_ops::masked_scatter_out::redispatch(dispatchKeySet, self, mask, source, out); + } + + // aten::_masked_softmax.out(Tensor self, Tensor mask, int? dim=None, int? mask_type=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _masked_softmax_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & mask, ::std::optional dim=::std::nullopt, ::std::optional mask_type=::std::nullopt) { + return at::_ops::_masked_softmax_out::redispatch(dispatchKeySet, self, mask, dim, mask_type, out); + } + + // aten::_masked_softmax.out(Tensor self, Tensor mask, int? dim=None, int? mask_type=None, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & _masked_softmax_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mask, ::std::optional dim, ::std::optional mask_type, at::Tensor & out) { + return at::_ops::_masked_softmax_out::redispatch(dispatchKeySet, self, mask, dim, mask_type, out); + } + + // aten::_masked_softmax_backward.out(Tensor grad_output, Tensor output, Tensor mask, int? dim=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _masked_softmax_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad_output, const at::Tensor & output, const at::Tensor & mask, ::std::optional dim=::std::nullopt) { + return at::_ops::_masked_softmax_backward_out::redispatch(dispatchKeySet, grad_output, output, mask, dim, out); + } + + // aten::_masked_softmax_backward.out(Tensor grad_output, Tensor output, Tensor mask, int? dim=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _masked_softmax_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & output, const at::Tensor & mask, ::std::optional dim, at::Tensor & out) { + return at::_ops::_masked_softmax_backward_out::redispatch(dispatchKeySet, grad_output, output, mask, dim, out); + } + + // aten::put.out(Tensor self, Tensor index, Tensor source, bool accumulate=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & put_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & index, const at::Tensor & source, bool accumulate=false) { + return at::_ops::put_out::redispatch(dispatchKeySet, self, index, source, accumulate, out); + } + + // aten::put.out(Tensor self, Tensor index, Tensor source, bool accumulate=False, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & put_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & index, const at::Tensor & source, bool accumulate, at::Tensor & out) { + return at::_ops::put_out::redispatch(dispatchKeySet, self, index, source, accumulate, out); + } + + // aten::index_fill.int_Scalar_out(Tensor self, int dim, Tensor index, Scalar value, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & index_fill_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Scalar & value) { + return at::_ops::index_fill_int_Scalar_out::redispatch(dispatchKeySet, self, dim, index, value, out); + } + + // aten::index_fill.int_Scalar_out(Tensor self, int dim, Tensor index, Scalar value, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & index_fill_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Scalar & value, at::Tensor & out) { + return at::_ops::index_fill_int_Scalar_out::redispatch(dispatchKeySet, self, dim, index, value, out); + } + + // aten::index_fill.int_Tensor_out(Tensor self, int dim, Tensor index, Tensor value, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & index_fill_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & value) { + return at::_ops::index_fill_int_Tensor_out::redispatch(dispatchKeySet, self, dim, index, value, out); + } + + // aten::index_fill.int_Tensor_out(Tensor self, int dim, Tensor index, Tensor value, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & index_fill_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & value, at::Tensor & out) { + return at::_ops::index_fill_int_Tensor_out::redispatch(dispatchKeySet, self, dim, index, value, out); + } + + // aten::bitwise_and.Scalar_Tensor_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bitwise_and_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & self, const at::Tensor & other) { + return at::_ops::bitwise_and_Scalar_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::bitwise_and.Scalar_Tensor_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bitwise_and_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::bitwise_and_Scalar_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::bitwise_or.Scalar_Tensor_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bitwise_or_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & self, const at::Tensor & other) { + return at::_ops::bitwise_or_Scalar_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::bitwise_or.Scalar_Tensor_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bitwise_or_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::bitwise_or_Scalar_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::bitwise_xor.Scalar_Tensor_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & bitwise_xor_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & self, const at::Tensor & other) { + return at::_ops::bitwise_xor_Scalar_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::bitwise_xor.Scalar_Tensor_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bitwise_xor_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::bitwise_xor_Scalar_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::__lshift__.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & __lshift___out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::__lshift___Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::__lshift__.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & __lshift___outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, at::Tensor & out) { + return at::_ops::__lshift___Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::__lshift__.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & __lshift___out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::__lshift___Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::__lshift__.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & __lshift___outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::__lshift___Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::bitwise_left_shift.Scalar_Tensor_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bitwise_left_shift_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & self, const at::Tensor & other) { + return at::_ops::bitwise_left_shift_Scalar_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::bitwise_left_shift.Scalar_Tensor_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bitwise_left_shift_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::bitwise_left_shift_Scalar_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::__rshift__.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & __rshift___out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::__rshift___Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::__rshift__.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & __rshift___outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, at::Tensor & out) { + return at::_ops::__rshift___Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::__rshift__.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & __rshift___out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::__rshift___Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::__rshift__.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & __rshift___outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::__rshift___Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::bitwise_right_shift.Scalar_Tensor_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bitwise_right_shift_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & self, const at::Tensor & other) { + return at::_ops::bitwise_right_shift_Scalar_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::bitwise_right_shift.Scalar_Tensor_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bitwise_right_shift_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::bitwise_right_shift_Scalar_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::random.from_out(Tensor self, int from, int? to, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & random_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t from, ::std::optional to, ::std::optional generator=::std::nullopt) { + return at::_ops::random_from_out::redispatch(dispatchKeySet, self, from, to, generator, out); + } + + // aten::random.from_out(Tensor self, int from, int? to, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & random_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t from, ::std::optional to, ::std::optional generator, at::Tensor & out) { + return at::_ops::random_from_out::redispatch(dispatchKeySet, self, from, to, generator, out); + } + + // aten::random.from(Tensor self, int from, int? to, *, Generator? generator=None) -> Tensor + inline at::Tensor random(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t from, ::std::optional to, ::std::optional generator=::std::nullopt) { + return at::_ops::random_from::redispatch(dispatchKeySet, self, from, to, generator); + } + + // aten::random.to_out(Tensor self, int to, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & random_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t to, ::std::optional generator=::std::nullopt) { + return at::_ops::random_to_out::redispatch(dispatchKeySet, self, to, generator, out); + } + + // aten::random.to_out(Tensor self, int to, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & random_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t to, ::std::optional generator, at::Tensor & out) { + return at::_ops::random_to_out::redispatch(dispatchKeySet, self, to, generator, out); + } + + // aten::random.to(Tensor self, int to, *, Generator? generator=None) -> Tensor + inline at::Tensor random(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t to, ::std::optional generator=::std::nullopt) { + return at::_ops::random_to::redispatch(dispatchKeySet, self, to, generator); + } + + // aten::random.out(Tensor self, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & random_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, ::std::optional generator=::std::nullopt) { + return at::_ops::random_out::redispatch(dispatchKeySet, self, generator, out); + } + + // aten::random.out(Tensor self, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & random_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional generator, at::Tensor & out) { + return at::_ops::random_out::redispatch(dispatchKeySet, self, generator, out); + } + + // aten::random(Tensor self, *, Generator? generator=None) -> Tensor + inline at::Tensor random(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional generator=::std::nullopt) { + return at::_ops::random::redispatch(dispatchKeySet, self, generator); + } + + // aten::uniform.out(Tensor self, float from=0, float to=1, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & uniform_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, double from=0, double to=1, ::std::optional generator=::std::nullopt) { + return at::_ops::uniform_out::redispatch(dispatchKeySet, self, from, to, generator, out); + } + + // aten::uniform.out(Tensor self, float from=0, float to=1, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & uniform_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double from, double to, ::std::optional generator, at::Tensor & out) { + return at::_ops::uniform_out::redispatch(dispatchKeySet, self, from, to, generator, out); + } + + // aten::uniform(Tensor self, float from=0, float to=1, *, Generator? 
generator=None) -> Tensor + inline at::Tensor uniform(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double from=0, double to=1, ::std::optional generator=::std::nullopt) { + return at::_ops::uniform::redispatch(dispatchKeySet, self, from, to, generator); + } + + // aten::cauchy.out(Tensor self, float median=0, float sigma=1, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & cauchy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, double median=0, double sigma=1, ::std::optional generator=::std::nullopt) { + return at::_ops::cauchy_out::redispatch(dispatchKeySet, self, median, sigma, generator, out); + } + + // aten::cauchy.out(Tensor self, float median=0, float sigma=1, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & cauchy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double median, double sigma, ::std::optional generator, at::Tensor & out) { + return at::_ops::cauchy_out::redispatch(dispatchKeySet, self, median, sigma, generator, out); + } + + // aten::cauchy(Tensor self, float median=0, float sigma=1, *, Generator? generator=None) -> Tensor + inline at::Tensor cauchy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double median=0, double sigma=1, ::std::optional generator=::std::nullopt) { + return at::_ops::cauchy::redispatch(dispatchKeySet, self, median, sigma, generator); + } + + // aten::log_normal.out(Tensor self, float mean=1, float std=2, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & log_normal_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, double mean=1, double std=2, ::std::optional generator=::std::nullopt) { + return at::_ops::log_normal_out::redispatch(dispatchKeySet, self, mean, std, generator, out); + } + + // aten::log_normal.out(Tensor self, float mean=1, float std=2, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & log_normal_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double mean, double std, ::std::optional generator, at::Tensor & out) { + return at::_ops::log_normal_out::redispatch(dispatchKeySet, self, mean, std, generator, out); + } + + // aten::log_normal(Tensor self, float mean=1, float std=2, *, Generator? generator=None) -> Tensor + inline at::Tensor log_normal(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double mean=1, double std=2, ::std::optional generator=::std::nullopt) { + return at::_ops::log_normal::redispatch(dispatchKeySet, self, mean, std, generator); + } + + // aten::exponential.out(Tensor self, float lambd=1, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & exponential_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, double lambd=1, ::std::optional generator=::std::nullopt) { + return at::_ops::exponential_out::redispatch(dispatchKeySet, self, lambd, generator, out); + } + + // aten::exponential.out(Tensor self, float lambd=1, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & exponential_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double lambd, ::std::optional generator, at::Tensor & out) { + return at::_ops::exponential_out::redispatch(dispatchKeySet, self, lambd, generator, out); + } + + // aten::exponential(Tensor self, float lambd=1, *, Generator? generator=None) -> Tensor + inline at::Tensor exponential(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double lambd=1, ::std::optional generator=::std::nullopt) { + return at::_ops::exponential::redispatch(dispatchKeySet, self, lambd, generator); + } + + // aten::geometric.out(Tensor self, float p, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & geometric_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, double p, ::std::optional generator=::std::nullopt) { + return at::_ops::geometric_out::redispatch(dispatchKeySet, self, p, generator, out); + } + + // aten::geometric.out(Tensor self, float p, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & geometric_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double p, ::std::optional generator, at::Tensor & out) { + return at::_ops::geometric_out::redispatch(dispatchKeySet, self, p, generator, out); + } + + // aten::geometric(Tensor self, float p, *, Generator? generator=None) -> Tensor + inline at::Tensor geometric(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double p, ::std::optional generator=::std::nullopt) { + return at::_ops::geometric::redispatch(dispatchKeySet, self, p, generator); + } + + // aten::tril_indices.out(int row, int col, int offset=0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & tril_indices_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t row, int64_t col, int64_t offset=0) { + return at::_ops::tril_indices_out::redispatch(dispatchKeySet, row, col, offset, out); + } + + // aten::tril_indices.out(int row, int col, int offset=0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & tril_indices_outf(c10::DispatchKeySet dispatchKeySet, int64_t row, int64_t col, int64_t offset, at::Tensor & out) { + return at::_ops::tril_indices_out::redispatch(dispatchKeySet, row, col, offset, out); + } + + // aten::triu_indices.out(int row, int col, int offset=0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & triu_indices_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t row, int64_t col, int64_t offset=0) { + return at::_ops::triu_indices_out::redispatch(dispatchKeySet, row, col, offset, out); + } + + // aten::triu_indices.out(int row, int col, int offset=0, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & triu_indices_outf(c10::DispatchKeySet dispatchKeySet, int64_t row, int64_t col, int64_t offset, at::Tensor & out) { + return at::_ops::triu_indices_out::redispatch(dispatchKeySet, row, col, offset, out); + } + + // aten::trace.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & trace_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::trace_out::redispatch(dispatchKeySet, self, out); + } + + // aten::trace.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & trace_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::trace_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_cholesky_solve_helper.out(Tensor self, Tensor A, bool upper, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _cholesky_solve_helper_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & A, bool upper) { + return at::_ops::_cholesky_solve_helper_out::redispatch(dispatchKeySet, self, A, upper, out); + } + + // aten::_cholesky_solve_helper.out(Tensor self, Tensor A, bool upper, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _cholesky_solve_helper_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & A, bool upper, at::Tensor & out) { + return at::_ops::_cholesky_solve_helper_out::redispatch(dispatchKeySet, self, A, upper, out); + } + + // aten::dist.out(Tensor self, Tensor other, Scalar p=2, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & dist_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other, const at::Scalar & p=2) { + return at::_ops::dist_out::redispatch(dispatchKeySet, self, other, p, out); + } + + // aten::dist.out(Tensor self, Tensor other, Scalar p=2, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & dist_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, const at::Scalar & p, at::Tensor & out) { + return at::_ops::dist_out::redispatch(dispatchKeySet, self, other, p, out); + } + + // aten::_histogramdd_bin_edges.out(Tensor self, int[] bins, *, float[]? range=None, Tensor? weight=None, bool density=False, Tensor(a!)[] out) -> () + inline void _histogramdd_bin_edges_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, const at::Tensor & self, at::IntArrayRef bins, ::std::optional> range=::std::nullopt, const ::std::optional & weight={}, bool density=false) { + return at::_ops::_histogramdd_bin_edges_out::redispatch(dispatchKeySet, self, bins, range, weight, density, out); + } + + // aten::_histogramdd_bin_edges.out(Tensor self, int[] bins, *, float[]? range=None, Tensor? weight=None, bool density=False, Tensor(a!)[] out) -> () + inline void _histogramdd_bin_edges_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef bins, ::std::optional> range, const ::std::optional & weight, bool density, at::TensorList out) { + return at::_ops::_histogramdd_bin_edges_out::redispatch(dispatchKeySet, self, bins, range, weight, density, out); + } + + // aten::_histogramdd_from_bin_cts.out(Tensor self, int[] bins, *, float[]? range=None, Tensor? weight=None, bool density=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _histogramdd_from_bin_cts_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef bins, ::std::optional> range=::std::nullopt, const ::std::optional & weight={}, bool density=false) { + return at::_ops::_histogramdd_from_bin_cts_out::redispatch(dispatchKeySet, self, bins, range, weight, density, out); + } + + // aten::_histogramdd_from_bin_cts.out(Tensor self, int[] bins, *, float[]? range=None, Tensor? weight=None, bool density=False, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & _histogramdd_from_bin_cts_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef bins, ::std::optional> range, const ::std::optional & weight, bool density, at::Tensor & out) { + return at::_ops::_histogramdd_from_bin_cts_out::redispatch(dispatchKeySet, self, bins, range, weight, density, out); + } + + // aten::_histogramdd_from_bin_tensors.out(Tensor self, Tensor[] bins, *, Tensor? weight=None, bool density=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _histogramdd_from_bin_tensors_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::TensorList bins, const ::std::optional & weight={}, bool density=false) { + return at::_ops::_histogramdd_from_bin_tensors_out::redispatch(dispatchKeySet, self, bins, weight, density, out); + } + + // aten::_histogramdd_from_bin_tensors.out(Tensor self, Tensor[] bins, *, Tensor? weight=None, bool density=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _histogramdd_from_bin_tensors_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::TensorList bins, const ::std::optional & weight, bool density, at::Tensor & out) { + return at::_ops::_histogramdd_from_bin_tensors_out::redispatch(dispatchKeySet, self, bins, weight, density, out); + } + + // aten::remainder.Scalar_Tensor_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & remainder_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & self, const at::Tensor & other) { + return at::_ops::remainder_Scalar_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::remainder.Scalar_Tensor_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & remainder_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::remainder_Scalar_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::unfold_backward.out(Tensor grad_in, SymInt[] input_sizes, int dim, int size, int step, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & unfold_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad_in, at::IntArrayRef input_sizes, int64_t dim, int64_t size, int64_t step) { + return at::_ops::unfold_backward_out::redispatch(dispatchKeySet, grad_in, c10::fromIntArrayRefSlow(input_sizes), dim, size, step, out); + } + + // aten::unfold_backward.out(Tensor grad_in, SymInt[] input_sizes, int dim, int size, int step, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & unfold_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_in, at::IntArrayRef input_sizes, int64_t dim, int64_t size, int64_t step, at::Tensor & out) { + return at::_ops::unfold_backward_out::redispatch(dispatchKeySet, grad_in, c10::fromIntArrayRefSlow(input_sizes), dim, size, step, out); + } + + // aten::unfold_backward.out(Tensor grad_in, SymInt[] input_sizes, int dim, int size, int step, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & unfold_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad_in, c10::SymIntArrayRef input_sizes, int64_t dim, int64_t size, int64_t step) { + return at::_ops::unfold_backward_out::redispatch(dispatchKeySet, grad_in, input_sizes, dim, size, step, out); + } + + // aten::unfold_backward.out(Tensor grad_in, SymInt[] input_sizes, int dim, int size, int step, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & unfold_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_in, c10::SymIntArrayRef input_sizes, int64_t dim, int64_t size, int64_t step, at::Tensor & out) { + return at::_ops::unfold_backward_out::redispatch(dispatchKeySet, grad_in, input_sizes, dim, size, step, out); + } + + // aten::normal.out(Tensor self, float mean=0, float std=1, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & normal_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, double mean=0, double std=1, ::std::optional generator=::std::nullopt) { + return at::_ops::normal_out::redispatch(dispatchKeySet, self, mean, std, generator, out); + } + + // aten::normal.out(Tensor self, float mean=0, float std=1, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & normal_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double mean, double std, ::std::optional generator, at::Tensor & out) { + return at::_ops::normal_out::redispatch(dispatchKeySet, self, mean, std, generator, out); + } + + // aten::_amp_foreach_non_finite_check_and_unscale.out(Tensor[] self, Tensor(b!) found_inf, Tensor inv_scale, *, Tensor(a!)[] out) -> () + inline void _amp_foreach_non_finite_check_and_unscale_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::Tensor & found_inf, const at::Tensor & inv_scale) { + return at::_ops::_amp_foreach_non_finite_check_and_unscale_out::redispatch(dispatchKeySet, self, found_inf, inv_scale, out); + } + + // aten::_amp_foreach_non_finite_check_and_unscale.out(Tensor[] self, Tensor(b!) 
found_inf, Tensor inv_scale, *, Tensor(a!)[] out) -> () + inline void _amp_foreach_non_finite_check_and_unscale_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::Tensor & found_inf, const at::Tensor & inv_scale, at::TensorList out) { + return at::_ops::_amp_foreach_non_finite_check_and_unscale_out::redispatch(dispatchKeySet, self, found_inf, inv_scale, out); + } + + // aten::_amp_foreach_non_finite_check_and_unscale(Tensor[] self, Tensor found_inf, Tensor inv_scale) -> (Tensor[] self_out, Tensor found_inf_out) + inline ::std::tuple<::std::vector,at::Tensor> _amp_foreach_non_finite_check_and_unscale(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Tensor & found_inf, const at::Tensor & inv_scale) { + return at::_ops::_amp_foreach_non_finite_check_and_unscale::redispatch(dispatchKeySet, self, found_inf, inv_scale); + } + + // aten::_amp_update_scale.out(Tensor self, Tensor(b!) growth_tracker, Tensor found_inf, float scale_growth_factor, float scale_backoff_factor, int growth_interval, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _amp_update_scale_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::Tensor & growth_tracker, const at::Tensor & found_inf, double scale_growth_factor, double scale_backoff_factor, int64_t growth_interval) { + return at::_ops::_amp_update_scale_out::redispatch(dispatchKeySet, self, growth_tracker, found_inf, scale_growth_factor, scale_backoff_factor, growth_interval, out); + } + + // aten::_amp_update_scale.out(Tensor self, Tensor(b!) growth_tracker, Tensor found_inf, float scale_growth_factor, float scale_backoff_factor, int growth_interval, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & _amp_update_scale_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & growth_tracker, const at::Tensor & found_inf, double scale_growth_factor, double scale_backoff_factor, int64_t growth_interval, at::Tensor & out) { + return at::_ops::_amp_update_scale_out::redispatch(dispatchKeySet, self, growth_tracker, found_inf, scale_growth_factor, scale_backoff_factor, growth_interval, out); + } + + // aten::_amp_update_scale(Tensor self, Tensor growth_tracker, Tensor found_inf, float scale_growth_factor, float scale_backoff_factor, int growth_interval) -> (Tensor, Tensor growth_tracker_out) + inline ::std::tuple _amp_update_scale(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & growth_tracker, const at::Tensor & found_inf, double scale_growth_factor, double scale_backoff_factor, int64_t growth_interval) { + return at::_ops::_amp_update_scale::redispatch(dispatchKeySet, self, growth_tracker, found_inf, scale_growth_factor, scale_backoff_factor, growth_interval); + } + + // aten::_foreach_add.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> () + inline void _foreach_add_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, const at::Scalar & scalar) { + return at::_ops::_foreach_add_Scalar_out::redispatch(dispatchKeySet, self, scalar, out); + } + + // aten::_foreach_add.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> () + inline void _foreach_add_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Scalar & scalar, at::TensorList out) { + return at::_ops::_foreach_add_Scalar_out::redispatch(dispatchKeySet, self, scalar, out); + } + + // aten::_foreach_add.List_out(Tensor[] self, Tensor[] other, *, Scalar alpha=1, Tensor(a!)[] out) -> () + inline void _foreach_add_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::TensorList other, const at::Scalar & alpha=1) { + return 
at::_ops::_foreach_add_List_out::redispatch(dispatchKeySet, self, other, alpha, out); + } + + // aten::_foreach_add.List_out(Tensor[] self, Tensor[] other, *, Scalar alpha=1, Tensor(a!)[] out) -> () + inline void _foreach_add_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList other, const at::Scalar & alpha, at::TensorList out) { + return at::_ops::_foreach_add_List_out::redispatch(dispatchKeySet, self, other, alpha, out); + } + + // aten::_foreach_add.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) -> () + inline void _foreach_add_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::ArrayRef scalars) { + return at::_ops::_foreach_add_ScalarList_out::redispatch(dispatchKeySet, self, scalars, out); + } + + // aten::_foreach_add.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) -> () + inline void _foreach_add_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::ArrayRef scalars, at::TensorList out) { + return at::_ops::_foreach_add_ScalarList_out::redispatch(dispatchKeySet, self, scalars, out); + } + + // aten::_foreach_add.Tensor_out(Tensor[] self, Tensor other, *, Scalar alpha=1, Tensor(a!)[] out) -> () + inline void _foreach_add_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, const at::Tensor & other, const at::Scalar & alpha=1) { + return at::_ops::_foreach_add_Tensor_out::redispatch(dispatchKeySet, self, other, alpha, out); + } + + // aten::_foreach_add.Tensor_out(Tensor[] self, Tensor other, *, Scalar alpha=1, Tensor(a!)[] out) -> () + inline void _foreach_add_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Tensor & other, const at::Scalar & alpha, at::TensorList out) { + return at::_ops::_foreach_add_Tensor_out::redispatch(dispatchKeySet, self, other, alpha, out); + } + + // aten::_foreach_sub.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> () + inline void 
_foreach_sub_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, const at::Scalar & scalar) { + return at::_ops::_foreach_sub_Scalar_out::redispatch(dispatchKeySet, self, scalar, out); + } + + // aten::_foreach_sub.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> () + inline void _foreach_sub_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Scalar & scalar, at::TensorList out) { + return at::_ops::_foreach_sub_Scalar_out::redispatch(dispatchKeySet, self, scalar, out); + } + + // aten::_foreach_sub.List_out(Tensor[] self, Tensor[] other, *, Scalar alpha=1, Tensor(a!)[] out) -> () + inline void _foreach_sub_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::TensorList other, const at::Scalar & alpha=1) { + return at::_ops::_foreach_sub_List_out::redispatch(dispatchKeySet, self, other, alpha, out); + } + + // aten::_foreach_sub.List_out(Tensor[] self, Tensor[] other, *, Scalar alpha=1, Tensor(a!)[] out) -> () + inline void _foreach_sub_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList other, const at::Scalar & alpha, at::TensorList out) { + return at::_ops::_foreach_sub_List_out::redispatch(dispatchKeySet, self, other, alpha, out); + } + + // aten::_foreach_sub.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) -> () + inline void _foreach_sub_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::ArrayRef scalars) { + return at::_ops::_foreach_sub_ScalarList_out::redispatch(dispatchKeySet, self, scalars, out); + } + + // aten::_foreach_sub.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) -> () + inline void _foreach_sub_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::ArrayRef scalars, at::TensorList out) { + return at::_ops::_foreach_sub_ScalarList_out::redispatch(dispatchKeySet, self, scalars, out); + } + + // aten::_foreach_mul.Scalar_out(Tensor[] self, 
Scalar scalar, *, Tensor(a!)[] out) -> () + inline void _foreach_mul_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, const at::Scalar & scalar) { + return at::_ops::_foreach_mul_Scalar_out::redispatch(dispatchKeySet, self, scalar, out); + } + + // aten::_foreach_mul.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> () + inline void _foreach_mul_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Scalar & scalar, at::TensorList out) { + return at::_ops::_foreach_mul_Scalar_out::redispatch(dispatchKeySet, self, scalar, out); + } + + // aten::_foreach_mul.List_out(Tensor[] self, Tensor[] other, *, Tensor(a!)[] out) -> () + inline void _foreach_mul_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::TensorList other) { + return at::_ops::_foreach_mul_List_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::_foreach_mul.List_out(Tensor[] self, Tensor[] other, *, Tensor(a!)[] out) -> () + inline void _foreach_mul_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList other, at::TensorList out) { + return at::_ops::_foreach_mul_List_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::_foreach_mul.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) -> () + inline void _foreach_mul_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::ArrayRef scalars) { + return at::_ops::_foreach_mul_ScalarList_out::redispatch(dispatchKeySet, self, scalars, out); + } + + // aten::_foreach_mul.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) -> () + inline void _foreach_mul_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::ArrayRef scalars, at::TensorList out) { + return at::_ops::_foreach_mul_ScalarList_out::redispatch(dispatchKeySet, self, scalars, out); + } + + // aten::_foreach_mul.Tensor_out(Tensor[] self, Tensor other, *, Tensor(a!)[] out) -> () + 
inline void _foreach_mul_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, const at::Tensor & other) { + return at::_ops::_foreach_mul_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::_foreach_mul.Tensor_out(Tensor[] self, Tensor other, *, Tensor(a!)[] out) -> () + inline void _foreach_mul_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Tensor & other, at::TensorList out) { + return at::_ops::_foreach_mul_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::_foreach_div.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> () + inline void _foreach_div_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, const at::Scalar & scalar) { + return at::_ops::_foreach_div_Scalar_out::redispatch(dispatchKeySet, self, scalar, out); + } + + // aten::_foreach_div.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> () + inline void _foreach_div_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Scalar & scalar, at::TensorList out) { + return at::_ops::_foreach_div_Scalar_out::redispatch(dispatchKeySet, self, scalar, out); + } + + // aten::_foreach_div.List_out(Tensor[] self, Tensor[] other, *, Tensor(a!)[] out) -> () + inline void _foreach_div_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::TensorList other) { + return at::_ops::_foreach_div_List_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::_foreach_div.List_out(Tensor[] self, Tensor[] other, *, Tensor(a!)[] out) -> () + inline void _foreach_div_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList other, at::TensorList out) { + return at::_ops::_foreach_div_List_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::_foreach_div.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) -> () + inline void _foreach_div_out(c10::DispatchKeySet 
dispatchKeySet, at::TensorList out, at::TensorList self, at::ArrayRef scalars) { + return at::_ops::_foreach_div_ScalarList_out::redispatch(dispatchKeySet, self, scalars, out); + } + + // aten::_foreach_div.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) -> () + inline void _foreach_div_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::ArrayRef scalars, at::TensorList out) { + return at::_ops::_foreach_div_ScalarList_out::redispatch(dispatchKeySet, self, scalars, out); + } + + // aten::_foreach_div.Tensor_out(Tensor[] self, Tensor other, *, Tensor(a!)[] out) -> () + inline void _foreach_div_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, const at::Tensor & other) { + return at::_ops::_foreach_div_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::_foreach_div.Tensor_out(Tensor[] self, Tensor other, *, Tensor(a!)[] out) -> () + inline void _foreach_div_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Tensor & other, at::TensorList out) { + return at::_ops::_foreach_div_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::_foreach_clamp_max.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> () + inline void _foreach_clamp_max_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, const at::Scalar & scalar) { + return at::_ops::_foreach_clamp_max_Scalar_out::redispatch(dispatchKeySet, self, scalar, out); + } + + // aten::_foreach_clamp_max.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> () + inline void _foreach_clamp_max_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Scalar & scalar, at::TensorList out) { + return at::_ops::_foreach_clamp_max_Scalar_out::redispatch(dispatchKeySet, self, scalar, out); + } + + // aten::_foreach_clamp_max.List_out(Tensor[] self, Tensor[] other, *, Tensor(a!)[] out) -> () + inline void 
_foreach_clamp_max_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::TensorList other) { + return at::_ops::_foreach_clamp_max_List_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::_foreach_clamp_max.List_out(Tensor[] self, Tensor[] other, *, Tensor(a!)[] out) -> () + inline void _foreach_clamp_max_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList other, at::TensorList out) { + return at::_ops::_foreach_clamp_max_List_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::_foreach_clamp_max.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) -> () + inline void _foreach_clamp_max_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::ArrayRef scalars) { + return at::_ops::_foreach_clamp_max_ScalarList_out::redispatch(dispatchKeySet, self, scalars, out); + } + + // aten::_foreach_clamp_max.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) -> () + inline void _foreach_clamp_max_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::ArrayRef scalars, at::TensorList out) { + return at::_ops::_foreach_clamp_max_ScalarList_out::redispatch(dispatchKeySet, self, scalars, out); + } + + // aten::_foreach_clamp_min.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> () + inline void _foreach_clamp_min_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, const at::Scalar & scalar) { + return at::_ops::_foreach_clamp_min_Scalar_out::redispatch(dispatchKeySet, self, scalar, out); + } + + // aten::_foreach_clamp_min.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> () + inline void _foreach_clamp_min_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Scalar & scalar, at::TensorList out) { + return at::_ops::_foreach_clamp_min_Scalar_out::redispatch(dispatchKeySet, self, scalar, out); + } + + // aten::_foreach_clamp_min.List_out(Tensor[] 
self, Tensor[] other, *, Tensor(a!)[] out) -> () + inline void _foreach_clamp_min_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::TensorList other) { + return at::_ops::_foreach_clamp_min_List_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::_foreach_clamp_min.List_out(Tensor[] self, Tensor[] other, *, Tensor(a!)[] out) -> () + inline void _foreach_clamp_min_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList other, at::TensorList out) { + return at::_ops::_foreach_clamp_min_List_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::_foreach_clamp_min.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) -> () + inline void _foreach_clamp_min_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::ArrayRef scalars) { + return at::_ops::_foreach_clamp_min_ScalarList_out::redispatch(dispatchKeySet, self, scalars, out); + } + + // aten::_foreach_clamp_min.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) -> () + inline void _foreach_clamp_min_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::ArrayRef scalars, at::TensorList out) { + return at::_ops::_foreach_clamp_min_ScalarList_out::redispatch(dispatchKeySet, self, scalars, out); + } + + // aten::_foreach_maximum.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> () + inline void _foreach_maximum_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, const at::Scalar & scalar) { + return at::_ops::_foreach_maximum_Scalar_out::redispatch(dispatchKeySet, self, scalar, out); + } + + // aten::_foreach_maximum.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> () + inline void _foreach_maximum_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Scalar & scalar, at::TensorList out) { + return at::_ops::_foreach_maximum_Scalar_out::redispatch(dispatchKeySet, self, scalar, out); + 
} + + // aten::_foreach_maximum.List_out(Tensor[] self, Tensor[] other, *, Tensor(a!)[] out) -> () + inline void _foreach_maximum_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::TensorList other) { + return at::_ops::_foreach_maximum_List_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::_foreach_maximum.List_out(Tensor[] self, Tensor[] other, *, Tensor(a!)[] out) -> () + inline void _foreach_maximum_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList other, at::TensorList out) { + return at::_ops::_foreach_maximum_List_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::_foreach_maximum.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) -> () + inline void _foreach_maximum_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::ArrayRef scalars) { + return at::_ops::_foreach_maximum_ScalarList_out::redispatch(dispatchKeySet, self, scalars, out); + } + + // aten::_foreach_maximum.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) -> () + inline void _foreach_maximum_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::ArrayRef scalars, at::TensorList out) { + return at::_ops::_foreach_maximum_ScalarList_out::redispatch(dispatchKeySet, self, scalars, out); + } + + // aten::_foreach_minimum.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> () + inline void _foreach_minimum_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, const at::Scalar & scalar) { + return at::_ops::_foreach_minimum_Scalar_out::redispatch(dispatchKeySet, self, scalar, out); + } + + // aten::_foreach_minimum.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> () + inline void _foreach_minimum_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Scalar & scalar, at::TensorList out) { + return 
at::_ops::_foreach_minimum_Scalar_out::redispatch(dispatchKeySet, self, scalar, out); + } + + // aten::_foreach_minimum.List_out(Tensor[] self, Tensor[] other, *, Tensor(a!)[] out) -> () + inline void _foreach_minimum_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::TensorList other) { + return at::_ops::_foreach_minimum_List_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::_foreach_minimum.List_out(Tensor[] self, Tensor[] other, *, Tensor(a!)[] out) -> () + inline void _foreach_minimum_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList other, at::TensorList out) { + return at::_ops::_foreach_minimum_List_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::_foreach_minimum.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) -> () + inline void _foreach_minimum_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::ArrayRef scalars) { + return at::_ops::_foreach_minimum_ScalarList_out::redispatch(dispatchKeySet, self, scalars, out); + } + + // aten::_foreach_minimum.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) -> () + inline void _foreach_minimum_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::ArrayRef scalars, at::TensorList out) { + return at::_ops::_foreach_minimum_ScalarList_out::redispatch(dispatchKeySet, self, scalars, out); + } + + // aten::_foreach_addcdiv.Scalar_out(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1, *, Tensor(a!)[] out) -> () + inline void _foreach_addcdiv_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Scalar & value=1) { + return at::_ops::_foreach_addcdiv_Scalar_out::redispatch(dispatchKeySet, self, tensor1, tensor2, value, out); + } + + // aten::_foreach_addcdiv.Scalar_out(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1, *, 
Tensor(a!)[] out) -> () + inline void _foreach_addcdiv_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Scalar & value, at::TensorList out) { + return at::_ops::_foreach_addcdiv_Scalar_out::redispatch(dispatchKeySet, self, tensor1, tensor2, value, out); + } + + // aten::_foreach_addcdiv.ScalarList_out(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars, *, Tensor(a!)[] out) -> () + inline void _foreach_addcdiv_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, at::ArrayRef scalars) { + return at::_ops::_foreach_addcdiv_ScalarList_out::redispatch(dispatchKeySet, self, tensor1, tensor2, scalars, out); + } + + // aten::_foreach_addcdiv.ScalarList_out(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars, *, Tensor(a!)[] out) -> () + inline void _foreach_addcdiv_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, at::ArrayRef scalars, at::TensorList out) { + return at::_ops::_foreach_addcdiv_ScalarList_out::redispatch(dispatchKeySet, self, tensor1, tensor2, scalars, out); + } + + // aten::_foreach_addcdiv.Tensor_out(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Tensor scalars, *, Tensor(a!)[] out) -> () + inline void _foreach_addcdiv_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Tensor & scalars) { + return at::_ops::_foreach_addcdiv_Tensor_out::redispatch(dispatchKeySet, self, tensor1, tensor2, scalars, out); + } + + // aten::_foreach_addcdiv.Tensor_out(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Tensor scalars, *, Tensor(a!)[] out) -> () + inline void _foreach_addcdiv_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Tensor & scalars, at::TensorList out) { + 
return at::_ops::_foreach_addcdiv_Tensor_out::redispatch(dispatchKeySet, self, tensor1, tensor2, scalars, out); + } + + // aten::_foreach_addcmul.Scalar_out(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1, *, Tensor(a!)[] out) -> () + inline void _foreach_addcmul_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Scalar & value=1) { + return at::_ops::_foreach_addcmul_Scalar_out::redispatch(dispatchKeySet, self, tensor1, tensor2, value, out); + } + + // aten::_foreach_addcmul.Scalar_out(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1, *, Tensor(a!)[] out) -> () + inline void _foreach_addcmul_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Scalar & value, at::TensorList out) { + return at::_ops::_foreach_addcmul_Scalar_out::redispatch(dispatchKeySet, self, tensor1, tensor2, value, out); + } + + // aten::_foreach_addcmul.ScalarList_out(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars, *, Tensor(a!)[] out) -> () + inline void _foreach_addcmul_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, at::ArrayRef scalars) { + return at::_ops::_foreach_addcmul_ScalarList_out::redispatch(dispatchKeySet, self, tensor1, tensor2, scalars, out); + } + + // aten::_foreach_addcmul.ScalarList_out(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars, *, Tensor(a!)[] out) -> () + inline void _foreach_addcmul_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, at::ArrayRef scalars, at::TensorList out) { + return at::_ops::_foreach_addcmul_ScalarList_out::redispatch(dispatchKeySet, self, tensor1, tensor2, scalars, out); + } + + // aten::_foreach_addcmul.Tensor_out(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Tensor scalars, 
*, Tensor(a!)[] out) -> () + inline void _foreach_addcmul_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Tensor & scalars) { + return at::_ops::_foreach_addcmul_Tensor_out::redispatch(dispatchKeySet, self, tensor1, tensor2, scalars, out); + } + + // aten::_foreach_addcmul.Tensor_out(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Tensor scalars, *, Tensor(a!)[] out) -> () + inline void _foreach_addcmul_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Tensor & scalars, at::TensorList out) { + return at::_ops::_foreach_addcmul_Tensor_out::redispatch(dispatchKeySet, self, tensor1, tensor2, scalars, out); + } + + // aten::_foreach_abs.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_abs_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) { + return at::_ops::_foreach_abs_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_abs.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_abs_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) { + return at::_ops::_foreach_abs_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_acos.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_acos_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) { + return at::_ops::_foreach_acos_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_acos.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_acos_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) { + return at::_ops::_foreach_acos_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_asin.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_asin_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, 
at::TensorList self) { + return at::_ops::_foreach_asin_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_asin.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_asin_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) { + return at::_ops::_foreach_asin_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_atan.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_atan_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) { + return at::_ops::_foreach_atan_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_atan.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_atan_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) { + return at::_ops::_foreach_atan_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_ceil.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_ceil_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) { + return at::_ops::_foreach_ceil_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_ceil.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_ceil_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) { + return at::_ops::_foreach_ceil_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_cos.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_cos_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) { + return at::_ops::_foreach_cos_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_cos.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_cos_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) { + return at::_ops::_foreach_cos_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_cosh.out(Tensor[] 
self, *, Tensor(a!)[] out) -> () + inline void _foreach_cosh_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) { + return at::_ops::_foreach_cosh_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_cosh.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_cosh_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) { + return at::_ops::_foreach_cosh_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_erf.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_erf_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) { + return at::_ops::_foreach_erf_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_erf.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_erf_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) { + return at::_ops::_foreach_erf_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_erfc.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_erfc_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) { + return at::_ops::_foreach_erfc_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_erfc.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_erfc_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) { + return at::_ops::_foreach_erfc_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_exp.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_exp_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) { + return at::_ops::_foreach_exp_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_exp.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_exp_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) { + 
return at::_ops::_foreach_exp_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_expm1.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_expm1_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) { + return at::_ops::_foreach_expm1_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_expm1.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_expm1_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) { + return at::_ops::_foreach_expm1_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_floor.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_floor_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) { + return at::_ops::_foreach_floor_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_floor.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_floor_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) { + return at::_ops::_foreach_floor_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_frac.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_frac_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) { + return at::_ops::_foreach_frac_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_frac.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_frac_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) { + return at::_ops::_foreach_frac_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_lerp.List_out(Tensor[] self, Tensor[] tensors1, Tensor[] weights, *, Tensor(a!)[] out) -> () + inline void _foreach_lerp_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::TensorList tensors1, at::TensorList weights) { + return 
at::_ops::_foreach_lerp_List_out::redispatch(dispatchKeySet, self, tensors1, weights, out); + } + + // aten::_foreach_lerp.List_out(Tensor[] self, Tensor[] tensors1, Tensor[] weights, *, Tensor(a!)[] out) -> () + inline void _foreach_lerp_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList tensors1, at::TensorList weights, at::TensorList out) { + return at::_ops::_foreach_lerp_List_out::redispatch(dispatchKeySet, self, tensors1, weights, out); + } + + // aten::_foreach_lerp.Scalar_out(Tensor[] self, Tensor[] tensors1, Scalar weight, *, Tensor(a!)[] out) -> () + inline void _foreach_lerp_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::TensorList tensors1, const at::Scalar & weight) { + return at::_ops::_foreach_lerp_Scalar_out::redispatch(dispatchKeySet, self, tensors1, weight, out); + } + + // aten::_foreach_lerp.Scalar_out(Tensor[] self, Tensor[] tensors1, Scalar weight, *, Tensor(a!)[] out) -> () + inline void _foreach_lerp_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList tensors1, const at::Scalar & weight, at::TensorList out) { + return at::_ops::_foreach_lerp_Scalar_out::redispatch(dispatchKeySet, self, tensors1, weight, out); + } + + // aten::_foreach_lerp.ScalarList_out(Tensor[] self, Tensor[] tensors1, Scalar[] weight, *, Tensor(a!)[] out) -> () + inline void _foreach_lerp_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::TensorList tensors1, at::ArrayRef weight) { + return at::_ops::_foreach_lerp_ScalarList_out::redispatch(dispatchKeySet, self, tensors1, weight, out); + } + + // aten::_foreach_lerp.ScalarList_out(Tensor[] self, Tensor[] tensors1, Scalar[] weight, *, Tensor(a!)[] out) -> () + inline void _foreach_lerp_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList tensors1, at::ArrayRef weight, at::TensorList out) { + return at::_ops::_foreach_lerp_ScalarList_out::redispatch(dispatchKeySet, self, 
tensors1, weight, out); + } + + // aten::_foreach_lgamma.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_lgamma_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) { + return at::_ops::_foreach_lgamma_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_lgamma.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_lgamma_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) { + return at::_ops::_foreach_lgamma_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_log.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_log_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) { + return at::_ops::_foreach_log_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_log.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_log_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) { + return at::_ops::_foreach_log_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_log10.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_log10_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) { + return at::_ops::_foreach_log10_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_log10.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_log10_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) { + return at::_ops::_foreach_log10_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_log1p.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_log1p_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) { + return at::_ops::_foreach_log1p_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_log1p.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void 
_foreach_log1p_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) { + return at::_ops::_foreach_log1p_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_log2.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_log2_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) { + return at::_ops::_foreach_log2_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_log2.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_log2_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) { + return at::_ops::_foreach_log2_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_max.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_max_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) { + return at::_ops::_foreach_max_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_max.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_max_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) { + return at::_ops::_foreach_max_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_neg.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_neg_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) { + return at::_ops::_foreach_neg_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_neg.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_neg_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) { + return at::_ops::_foreach_neg_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_norm.Scalar_out(Tensor[] self, Scalar ord=2, ScalarType? 
dtype=None, *, Tensor(a!)[] out) -> () + inline void _foreach_norm_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, const at::Scalar & ord=2, ::std::optional dtype=::std::nullopt) { + return at::_ops::_foreach_norm_Scalar_out::redispatch(dispatchKeySet, self, ord, dtype, out); + } + + // aten::_foreach_norm.Scalar_out(Tensor[] self, Scalar ord=2, ScalarType? dtype=None, *, Tensor(a!)[] out) -> () + inline void _foreach_norm_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Scalar & ord, ::std::optional dtype, at::TensorList out) { + return at::_ops::_foreach_norm_Scalar_out::redispatch(dispatchKeySet, self, ord, dtype, out); + } + + // aten::_foreach_pow.List_out(Tensor[] self, Tensor[] exponent, *, Tensor(a!)[] out) -> () + inline void _foreach_pow_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::TensorList exponent) { + return at::_ops::_foreach_pow_List_out::redispatch(dispatchKeySet, self, exponent, out); + } + + // aten::_foreach_pow.List_out(Tensor[] self, Tensor[] exponent, *, Tensor(a!)[] out) -> () + inline void _foreach_pow_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList exponent, at::TensorList out) { + return at::_ops::_foreach_pow_List_out::redispatch(dispatchKeySet, self, exponent, out); + } + + // aten::_foreach_pow.Scalar_out(Tensor[] self, Scalar exponent, *, Tensor(a!)[] out) -> () + inline void _foreach_pow_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, const at::Scalar & exponent) { + return at::_ops::_foreach_pow_Scalar_out::redispatch(dispatchKeySet, self, exponent, out); + } + + // aten::_foreach_pow.Scalar_out(Tensor[] self, Scalar exponent, *, Tensor(a!)[] out) -> () + inline void _foreach_pow_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Scalar & exponent, at::TensorList out) { + return at::_ops::_foreach_pow_Scalar_out::redispatch(dispatchKeySet, self, exponent, 
out); + } + + // aten::_foreach_pow.ScalarList_out(Tensor[] self, Scalar[] exponent, *, Tensor(a!)[] out) -> () + inline void _foreach_pow_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::ArrayRef exponent) { + return at::_ops::_foreach_pow_ScalarList_out::redispatch(dispatchKeySet, self, exponent, out); + } + + // aten::_foreach_pow.ScalarList_out(Tensor[] self, Scalar[] exponent, *, Tensor(a!)[] out) -> () + inline void _foreach_pow_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::ArrayRef exponent, at::TensorList out) { + return at::_ops::_foreach_pow_ScalarList_out::redispatch(dispatchKeySet, self, exponent, out); + } + + // aten::_foreach_reciprocal.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_reciprocal_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) { + return at::_ops::_foreach_reciprocal_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_reciprocal.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_reciprocal_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) { + return at::_ops::_foreach_reciprocal_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_round.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_round_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) { + return at::_ops::_foreach_round_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_round.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_round_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) { + return at::_ops::_foreach_round_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_rsqrt.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_rsqrt_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) { + return 
at::_ops::_foreach_rsqrt_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_rsqrt.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_rsqrt_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) { + return at::_ops::_foreach_rsqrt_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_sigmoid.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_sigmoid_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) { + return at::_ops::_foreach_sigmoid_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_sigmoid.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_sigmoid_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) { + return at::_ops::_foreach_sigmoid_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_sign.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_sign_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) { + return at::_ops::_foreach_sign_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_sign.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_sign_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) { + return at::_ops::_foreach_sign_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_sin.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_sin_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) { + return at::_ops::_foreach_sin_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_sin.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_sin_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) { + return at::_ops::_foreach_sin_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_sinh.out(Tensor[] self, *, 
Tensor(a!)[] out) -> () + inline void _foreach_sinh_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) { + return at::_ops::_foreach_sinh_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_sinh.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_sinh_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) { + return at::_ops::_foreach_sinh_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_sqrt.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_sqrt_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) { + return at::_ops::_foreach_sqrt_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_sqrt.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_sqrt_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) { + return at::_ops::_foreach_sqrt_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_tan.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_tan_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) { + return at::_ops::_foreach_tan_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_tan.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_tan_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) { + return at::_ops::_foreach_tan_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_tanh.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_tanh_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) { + return at::_ops::_foreach_tanh_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_tanh.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_tanh_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) { + return 
at::_ops::_foreach_tanh_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_trunc.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_trunc_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) { + return at::_ops::_foreach_trunc_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_trunc.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_trunc_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) { + return at::_ops::_foreach_trunc_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_zero.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_zero_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) { + return at::_ops::_foreach_zero_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_zero.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_zero_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) { + return at::_ops::_foreach_zero_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_zero(Tensor[] self) -> Tensor[] self_out + inline ::std::vector _foreach_zero(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_zero::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_copy.out(Tensor[] self, Tensor[] src, bool non_blocking=False, *, Tensor(a!)[] out) -> () + inline void _foreach_copy_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::TensorList src, bool non_blocking=false) { + return at::_ops::_foreach_copy_out::redispatch(dispatchKeySet, self, src, non_blocking, out); + } + + // aten::_foreach_copy.out(Tensor[] self, Tensor[] src, bool non_blocking=False, *, Tensor(a!)[] out) -> () + inline void _foreach_copy_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList src, bool non_blocking, at::TensorList out) 
{ + return at::_ops::_foreach_copy_out::redispatch(dispatchKeySet, self, src, non_blocking, out); + } + + // aten::bucketize.Scalar_out(Scalar self, Tensor boundaries, *, bool out_int32=False, bool right=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bucketize_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & self, const at::Tensor & boundaries, bool out_int32=false, bool right=false) { + return at::_ops::bucketize_Scalar_out::redispatch(dispatchKeySet, self, boundaries, out_int32, right, out); + } + + // aten::bucketize.Scalar_out(Scalar self, Tensor boundaries, *, bool out_int32=False, bool right=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bucketize_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, const at::Tensor & boundaries, bool out_int32, bool right, at::Tensor & out) { + return at::_ops::bucketize_Scalar_out::redispatch(dispatchKeySet, self, boundaries, out_int32, right, out); + } + + // aten::glu_jvp.out(Tensor glu, Tensor x, Tensor dx, int dim, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & glu_jvp_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & glu, const at::Tensor & x, const at::Tensor & dx, int64_t dim) { + return at::_ops::glu_jvp_out::redispatch(dispatchKeySet, glu, x, dx, dim, out); + } + + // aten::glu_jvp.out(Tensor glu, Tensor x, Tensor dx, int dim, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & glu_jvp_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & glu, const at::Tensor & x, const at::Tensor & dx, int64_t dim, at::Tensor & out) { + return at::_ops::glu_jvp_out::redispatch(dispatchKeySet, glu, x, dx, dim, out); + } + + // aten::glu_backward_jvp.out(Tensor grad_x, Tensor grad_glu, Tensor x, Tensor dgrad_glu, Tensor dx, int dim, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & glu_backward_jvp_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad_x, const at::Tensor & grad_glu, const at::Tensor & x, const at::Tensor & dgrad_glu, const at::Tensor & dx, int64_t dim) { + return at::_ops::glu_backward_jvp_out::redispatch(dispatchKeySet, grad_x, grad_glu, x, dgrad_glu, dx, dim, out); + } + + // aten::glu_backward_jvp.out(Tensor grad_x, Tensor grad_glu, Tensor x, Tensor dgrad_glu, Tensor dx, int dim, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & glu_backward_jvp_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_x, const at::Tensor & grad_glu, const at::Tensor & x, const at::Tensor & dgrad_glu, const at::Tensor & dx, int64_t dim, at::Tensor & out) { + return at::_ops::glu_backward_jvp_out::redispatch(dispatchKeySet, grad_x, grad_glu, x, dgrad_glu, dx, dim, out); + } + + // aten::hardswish_backward.out(Tensor grad_output, Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & hardswish_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad_output, const at::Tensor & self) { + return at::_ops::hardswish_backward_out::redispatch(dispatchKeySet, grad_output, self, out); + } + + // aten::hardswish_backward.out(Tensor grad_output, Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & hardswish_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::Tensor & out) { + return at::_ops::hardswish_backward_out::redispatch(dispatchKeySet, grad_output, self, out); + } + + // aten::rrelu_with_noise_functional(Tensor self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? 
generator=None) -> (Tensor, Tensor noise_out) + inline ::std::tuple rrelu_with_noise_functional(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & noise, const at::Scalar & lower=0.125, const at::Scalar & upper=0.3333333333333333, bool training=false, ::std::optional generator=::std::nullopt) { + return at::_ops::rrelu_with_noise_functional::redispatch(dispatchKeySet, self, noise, lower, upper, training, generator); + } + + // aten::rrelu_with_noise_backward.out(Tensor grad_output, Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, bool self_is_result, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & rrelu_with_noise_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & noise, const at::Scalar & lower, const at::Scalar & upper, bool training, bool self_is_result) { + return at::_ops::rrelu_with_noise_backward_out::redispatch(dispatchKeySet, grad_output, self, noise, lower, upper, training, self_is_result, out); + } + + // aten::rrelu_with_noise_backward.out(Tensor grad_output, Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, bool self_is_result, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & rrelu_with_noise_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & noise, const at::Scalar & lower, const at::Scalar & upper, bool training, bool self_is_result, at::Tensor & out) { + return at::_ops::rrelu_with_noise_backward_out::redispatch(dispatchKeySet, grad_output, self, noise, lower, upper, training, self_is_result, out); + } + + // aten::mkldnn_adaptive_avg_pool2d_backward.out(Tensor grad_output, Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & mkldnn_adaptive_avg_pool2d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad_output, const at::Tensor & self) { + return at::_ops::mkldnn_adaptive_avg_pool2d_backward_out::redispatch(dispatchKeySet, grad_output, self, out); + } + + // aten::mkldnn_adaptive_avg_pool2d_backward.out(Tensor grad_output, Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & mkldnn_adaptive_avg_pool2d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::Tensor & out) { + return at::_ops::mkldnn_adaptive_avg_pool2d_backward_out::redispatch(dispatchKeySet, grad_output, self, out); + } + + // aten::_adaptive_avg_pool2d.out(Tensor self, SymInt[2] output_size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _adaptive_avg_pool2d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef output_size) { + return at::_ops::_adaptive_avg_pool2d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), out); + } + + // aten::_adaptive_avg_pool2d.out(Tensor self, SymInt[2] output_size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _adaptive_avg_pool2d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out) { + return at::_ops::_adaptive_avg_pool2d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), out); + } + + // aten::_adaptive_avg_pool2d.out(Tensor self, SymInt[2] output_size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _adaptive_avg_pool2d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef output_size) { + return at::_ops::_adaptive_avg_pool2d_out::redispatch(dispatchKeySet, self, output_size, out); + } + + // aten::_adaptive_avg_pool2d.out(Tensor self, SymInt[2] output_size, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & _adaptive_avg_pool2d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, at::Tensor & out) { + return at::_ops::_adaptive_avg_pool2d_out::redispatch(dispatchKeySet, self, output_size, out); + } + + // aten::_adaptive_avg_pool2d_backward.out(Tensor grad_output, Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _adaptive_avg_pool2d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad_output, const at::Tensor & self) { + return at::_ops::_adaptive_avg_pool2d_backward_out::redispatch(dispatchKeySet, grad_output, self, out); + } + + // aten::_adaptive_avg_pool2d_backward.out(Tensor grad_output, Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _adaptive_avg_pool2d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::Tensor & out) { + return at::_ops::_adaptive_avg_pool2d_backward_out::redispatch(dispatchKeySet, grad_output, self, out); + } + + // aten::_adaptive_avg_pool3d.out(Tensor self, SymInt[3] output_size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _adaptive_avg_pool3d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef output_size) { + return at::_ops::_adaptive_avg_pool3d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), out); + } + + // aten::_adaptive_avg_pool3d.out(Tensor self, SymInt[3] output_size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _adaptive_avg_pool3d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out) { + return at::_ops::_adaptive_avg_pool3d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), out); + } + + // aten::_adaptive_avg_pool3d.out(Tensor self, SymInt[3] output_size, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & _adaptive_avg_pool3d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef output_size) { + return at::_ops::_adaptive_avg_pool3d_out::redispatch(dispatchKeySet, self, output_size, out); + } + + // aten::_adaptive_avg_pool3d.out(Tensor self, SymInt[3] output_size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _adaptive_avg_pool3d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, at::Tensor & out) { + return at::_ops::_adaptive_avg_pool3d_out::redispatch(dispatchKeySet, self, output_size, out); + } + + // aten::_adaptive_avg_pool3d_backward.out(Tensor grad_output, Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _adaptive_avg_pool3d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad_output, const at::Tensor & self) { + return at::_ops::_adaptive_avg_pool3d_backward_out::redispatch(dispatchKeySet, grad_output, self, out); + } + + // aten::_adaptive_avg_pool3d_backward.out(Tensor grad_output, Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _adaptive_avg_pool3d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::Tensor & out) { + return at::_ops::_adaptive_avg_pool3d_backward_out::redispatch(dispatchKeySet, grad_output, self, out); + } + + // aten::upsample_bilinear2d.vec_out(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & upsample_bilinear2d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & input, at::OptionalIntArrayRef output_size, bool align_corners, ::std::optional> scale_factors) { + return at::_ops::upsample_bilinear2d_vec_out::redispatch(dispatchKeySet, input, output_size.has_value() ? 
::std::make_optional(c10::fromIntArrayRefSlow(*output_size)) : ::std::nullopt, align_corners, scale_factors, out); + } + + // aten::upsample_bilinear2d.vec_out(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & upsample_bilinear2d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::OptionalIntArrayRef output_size, bool align_corners, ::std::optional> scale_factors, at::Tensor & out) { + return at::_ops::upsample_bilinear2d_vec_out::redispatch(dispatchKeySet, input, output_size.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*output_size)) : ::std::nullopt, align_corners, scale_factors, out); + } + + // aten::upsample_bilinear2d.vec_out(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & upsample_bilinear2d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & input, at::OptionalSymIntArrayRef output_size, bool align_corners, ::std::optional> scale_factors) { + return at::_ops::upsample_bilinear2d_vec_out::redispatch(dispatchKeySet, input, output_size, align_corners, scale_factors, out); + } + + // aten::upsample_bilinear2d.vec_out(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & upsample_bilinear2d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::OptionalSymIntArrayRef output_size, bool align_corners, ::std::optional> scale_factors, at::Tensor & out) { + return at::_ops::upsample_bilinear2d_vec_out::redispatch(dispatchKeySet, input, output_size, align_corners, scale_factors, out); + } + + // aten::upsample_nearest2d.vec_out(Tensor input, SymInt[]? output_size, float[]? scale_factors, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & upsample_nearest2d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & input, at::OptionalIntArrayRef output_size, ::std::optional> scale_factors) { + return at::_ops::upsample_nearest2d_vec_out::redispatch(dispatchKeySet, input, output_size.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*output_size)) : ::std::nullopt, scale_factors, out); + } + + // aten::upsample_nearest2d.vec_out(Tensor input, SymInt[]? output_size, float[]? scale_factors, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & upsample_nearest2d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::OptionalIntArrayRef output_size, ::std::optional> scale_factors, at::Tensor & out) { + return at::_ops::upsample_nearest2d_vec_out::redispatch(dispatchKeySet, input, output_size.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*output_size)) : ::std::nullopt, scale_factors, out); + } + + // aten::upsample_nearest2d.vec_out(Tensor input, SymInt[]? output_size, float[]? scale_factors, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & upsample_nearest2d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & input, at::OptionalSymIntArrayRef output_size, ::std::optional> scale_factors) { + return at::_ops::upsample_nearest2d_vec_out::redispatch(dispatchKeySet, input, output_size, scale_factors, out); + } + + // aten::upsample_nearest2d.vec_out(Tensor input, SymInt[]? output_size, float[]? scale_factors, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & upsample_nearest2d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::OptionalSymIntArrayRef output_size, ::std::optional> scale_factors, at::Tensor & out) { + return at::_ops::upsample_nearest2d_vec_out::redispatch(dispatchKeySet, input, output_size, scale_factors, out); + } + + // aten::_slow_conv2d_backward.output_mask_out(Tensor grad_output, Tensor self, Tensor weight, SymInt[2] kernel_size, SymInt[2] stride, SymInt[2] padding, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple _slow_conv2d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, ::std::array output_mask) { + return at::_ops::_slow_conv2d_backward_output_mask_out::redispatch(dispatchKeySet, grad_output, self, weight, c10::fromIntArrayRefSlow(kernel_size), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), output_mask, out0, out1, out2); + } + + // aten::_slow_conv2d_backward.output_mask_out(Tensor grad_output, Tensor self, Tensor weight, SymInt[2] kernel_size, SymInt[2] stride, SymInt[2] padding, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) 
out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple _slow_conv2d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, ::std::array output_mask, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) { + return at::_ops::_slow_conv2d_backward_output_mask_out::redispatch(dispatchKeySet, grad_output, self, weight, c10::fromIntArrayRefSlow(kernel_size), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), output_mask, out0, out1, out2); + } + + // aten::_slow_conv2d_backward.output_mask_out(Tensor grad_output, Tensor self, Tensor weight, SymInt[2] kernel_size, SymInt[2] stride, SymInt[2] padding, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple _slow_conv2d_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, ::std::array output_mask) { + return at::_ops::_slow_conv2d_backward_output_mask_out::redispatch(dispatchKeySet, grad_output, self, weight, kernel_size, stride, padding, output_mask, out0, out1, out2); + } + + // aten::_slow_conv2d_backward.output_mask_out(Tensor grad_output, Tensor self, Tensor weight, SymInt[2] kernel_size, SymInt[2] stride, SymInt[2] padding, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) 
out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple _slow_conv2d_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, ::std::array output_mask, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) { + return at::_ops::_slow_conv2d_backward_output_mask_out::redispatch(dispatchKeySet, grad_output, self, weight, kernel_size, stride, padding, output_mask, out0, out1, out2); + } + + // aten::conv_depthwise3d.out(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias, SymInt[3] stride, SymInt[3] padding, SymInt[3] dilation, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & conv_depthwise3d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const ::std::optional & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation) { + return at::_ops::conv_depthwise3d_out::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), out); + } + + // aten::conv_depthwise3d.out(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias, SymInt[3] stride, SymInt[3] padding, SymInt[3] dilation, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & conv_depthwise3d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const ::std::optional & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, at::Tensor & out) { + return at::_ops::conv_depthwise3d_out::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), out); + } + + // aten::conv_depthwise3d.out(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias, SymInt[3] stride, SymInt[3] padding, SymInt[3] dilation, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & conv_depthwise3d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation) { + return at::_ops::conv_depthwise3d_out::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding, dilation, out); + } + + // aten::conv_depthwise3d.out(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias, SymInt[3] stride, SymInt[3] padding, SymInt[3] dilation, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & conv_depthwise3d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, at::Tensor & out) { + return at::_ops::conv_depthwise3d_out::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding, dilation, out); + } + + // aten::slow_conv_dilated2d.out(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] dilation=1, *, Tensor(a!) 
out) -> Tensor(a!) + inline at::Tensor & slow_conv_dilated2d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const ::std::optional & bias={}, at::IntArrayRef stride=1, at::IntArrayRef padding=0, at::IntArrayRef dilation=1) { + return at::_ops::slow_conv_dilated2d_out::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), out); + } + + // aten::slow_conv_dilated2d.out(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] dilation=1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & slow_conv_dilated2d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const ::std::optional & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, at::Tensor & out) { + return at::_ops::slow_conv_dilated2d_out::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), out); + } + + // aten::slow_conv_dilated2d.out(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] dilation=1, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & slow_conv_dilated2d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias={}, c10::SymIntArrayRef stride=c10::SymInt(1), c10::SymIntArrayRef padding=c10::SymInt(0), c10::SymIntArrayRef dilation=c10::SymInt(1)) { + return at::_ops::slow_conv_dilated2d_out::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding, dilation, out); + } + + // aten::slow_conv_dilated2d.out(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] dilation=1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & slow_conv_dilated2d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, at::Tensor & out) { + return at::_ops::slow_conv_dilated2d_out::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding, dilation, out); + } + + // aten::slow_conv_dilated3d.out(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] dilation=1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & slow_conv_dilated3d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const ::std::optional & bias={}, at::IntArrayRef stride=1, at::IntArrayRef padding=0, at::IntArrayRef dilation=1) { + return at::_ops::slow_conv_dilated3d_out::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), out); + } + + // aten::slow_conv_dilated3d.out(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? 
bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] dilation=1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & slow_conv_dilated3d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const ::std::optional & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, at::Tensor & out) { + return at::_ops::slow_conv_dilated3d_out::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), out); + } + + // aten::slow_conv_dilated3d.out(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] dilation=1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & slow_conv_dilated3d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias={}, c10::SymIntArrayRef stride=c10::SymInt(1), c10::SymIntArrayRef padding=c10::SymInt(0), c10::SymIntArrayRef dilation=c10::SymInt(1)) { + return at::_ops::slow_conv_dilated3d_out::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding, dilation, out); + } + + // aten::slow_conv_dilated3d.out(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] dilation=1, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & slow_conv_dilated3d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, at::Tensor & out) { + return at::_ops::slow_conv_dilated3d_out::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding, dilation, out); + } + + // aten::isinf.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & isinf_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::isinf_out::redispatch(dispatchKeySet, self, out); + } + + // aten::isinf.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & isinf_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::isinf_out::redispatch(dispatchKeySet, self, out); + } + + // aten::linalg_matrix_exp.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_matrix_exp_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::linalg_matrix_exp_out::redispatch(dispatchKeySet, self, out); + } + + // aten::linalg_matrix_exp.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_matrix_exp_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::linalg_matrix_exp_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_test_optional_intlist.out(Tensor values, int[]? addends, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _test_optional_intlist_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & values, at::OptionalIntArrayRef addends) { + return at::_ops::_test_optional_intlist_out::redispatch(dispatchKeySet, values, addends, out); + } + + // aten::_test_optional_intlist.out(Tensor values, int[]? 
addends, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _test_optional_intlist_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & values, at::OptionalIntArrayRef addends, at::Tensor & out) { + return at::_ops::_test_optional_intlist_out::redispatch(dispatchKeySet, values, addends, out); + } + + // aten::_test_optional_filled_intlist.out(Tensor values, int[2]? addends, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _test_optional_filled_intlist_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & values, at::OptionalIntArrayRef addends) { + return at::_ops::_test_optional_filled_intlist_out::redispatch(dispatchKeySet, values, addends, out); + } + + // aten::_test_optional_filled_intlist.out(Tensor values, int[2]? addends, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _test_optional_filled_intlist_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & values, at::OptionalIntArrayRef addends, at::Tensor & out) { + return at::_ops::_test_optional_filled_intlist_out::redispatch(dispatchKeySet, values, addends, out); + } + + // aten::_test_optional_floatlist.out(Tensor values, float[]? addends, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _test_optional_floatlist_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & values, ::std::optional> addends) { + return at::_ops::_test_optional_floatlist_out::redispatch(dispatchKeySet, values, addends, out); + } + + // aten::_test_optional_floatlist.out(Tensor values, float[]? addends, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _test_optional_floatlist_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & values, ::std::optional> addends, at::Tensor & out) { + return at::_ops::_test_optional_floatlist_out::redispatch(dispatchKeySet, values, addends, out); + } + + // aten::_test_warn_in_autograd.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & _test_warn_in_autograd_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::_test_warn_in_autograd_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_test_warn_in_autograd.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _test_warn_in_autograd_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::_test_warn_in_autograd_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_test_autograd_multiple_dispatch.fullcoverage_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _test_autograd_multiple_dispatch_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::_test_autograd_multiple_dispatch_fullcoverage_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_test_autograd_multiple_dispatch.fullcoverage_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _test_autograd_multiple_dispatch_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::_test_autograd_multiple_dispatch_fullcoverage_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_test_autograd_multiple_dispatch_view_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _test_autograd_multiple_dispatch_view_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::_test_autograd_multiple_dispatch_view_copy_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_test_autograd_multiple_dispatch_view_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & _test_autograd_multiple_dispatch_view_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::_test_autograd_multiple_dispatch_view_copy_out::redispatch(dispatchKeySet, self, out); + } + + // aten::segment_reduce.out(Tensor data, str reduce, *, Tensor? lengths=None, Tensor? indices=None, Tensor? offsets=None, int axis=0, bool unsafe=False, Scalar? initial=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & segment_reduce_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & data, c10::string_view reduce, const ::std::optional & lengths={}, const ::std::optional & indices={}, const ::std::optional & offsets={}, int64_t axis=0, bool unsafe=false, const ::std::optional & initial=::std::nullopt) { + return at::_ops::segment_reduce_out::redispatch(dispatchKeySet, data, reduce, lengths, indices, offsets, axis, unsafe, initial, out); + } + + // aten::segment_reduce.out(Tensor data, str reduce, *, Tensor? lengths=None, Tensor? indices=None, Tensor? offsets=None, int axis=0, bool unsafe=False, Scalar? initial=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & segment_reduce_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & data, c10::string_view reduce, const ::std::optional & lengths, const ::std::optional & indices, const ::std::optional & offsets, int64_t axis, bool unsafe, const ::std::optional & initial, at::Tensor & out) { + return at::_ops::segment_reduce_out::redispatch(dispatchKeySet, data, reduce, lengths, indices, offsets, axis, unsafe, initial, out); + } + + // aten::_segment_reduce_backward.out(Tensor grad, Tensor output, Tensor data, str reduce, *, Tensor? lengths=None, Tensor? offsets=None, int axis=0, Scalar? initial=None, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & _segment_reduce_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad, const at::Tensor & output, const at::Tensor & data, c10::string_view reduce, const ::std::optional & lengths={}, const ::std::optional & offsets={}, int64_t axis=0, const ::std::optional & initial=::std::nullopt) { + return at::_ops::_segment_reduce_backward_out::redispatch(dispatchKeySet, grad, output, data, reduce, lengths, offsets, axis, initial, out); + } + + // aten::_segment_reduce_backward.out(Tensor grad, Tensor output, Tensor data, str reduce, *, Tensor? lengths=None, Tensor? offsets=None, int axis=0, Scalar? initial=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _segment_reduce_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & output, const at::Tensor & data, c10::string_view reduce, const ::std::optional & lengths, const ::std::optional & offsets, int64_t axis, const ::std::optional & initial, at::Tensor & out) { + return at::_ops::_segment_reduce_backward_out::redispatch(dispatchKeySet, grad, output, data, reduce, lengths, offsets, axis, initial, out); + } + + // aten::_nested_tensor_from_tensor_list.out(Tensor[] list, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _nested_tensor_from_tensor_list_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::TensorList list, ::std::optional dtype=::std::nullopt, ::std::optional layout=::std::nullopt, ::std::optional device=::std::nullopt, ::std::optional pin_memory=::std::nullopt) { + return at::_ops::_nested_tensor_from_tensor_list_out::redispatch(dispatchKeySet, list, dtype, layout, device, pin_memory, out); + } + + // aten::_nested_tensor_from_tensor_list.out(Tensor[] list, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & _nested_tensor_from_tensor_list_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList list, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, at::Tensor & out) { + return at::_ops::_nested_tensor_from_tensor_list_out::redispatch(dispatchKeySet, list, dtype, layout, device, pin_memory, out); + } + + // aten::_fw_primal_copy.out(Tensor self, int level, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _fw_primal_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t level) { + return at::_ops::_fw_primal_copy_out::redispatch(dispatchKeySet, self, level, out); + } + + // aten::_fw_primal_copy.out(Tensor self, int level, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _fw_primal_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t level, at::Tensor & out) { + return at::_ops::_fw_primal_copy_out::redispatch(dispatchKeySet, self, level, out); + } + + // aten::_make_dual_copy.out(Tensor primal, Tensor tangent, int level, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _make_dual_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & primal, const at::Tensor & tangent, int64_t level) { + return at::_ops::_make_dual_copy_out::redispatch(dispatchKeySet, primal, tangent, level, out); + } + + // aten::_make_dual_copy.out(Tensor primal, Tensor tangent, int level, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _make_dual_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & primal, const at::Tensor & tangent, int64_t level, at::Tensor & out) { + return at::_ops::_make_dual_copy_out::redispatch(dispatchKeySet, primal, tangent, level, out); + } + + // aten::view_as_real_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & view_as_real_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::view_as_real_copy_out::redispatch(dispatchKeySet, self, out); + } + + // aten::view_as_real_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & view_as_real_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::view_as_real_copy_out::redispatch(dispatchKeySet, self, out); + } + + // aten::view_as_complex_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & view_as_complex_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::view_as_complex_copy_out::redispatch(dispatchKeySet, self, out); + } + + // aten::view_as_complex_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & view_as_complex_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::view_as_complex_copy_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_conj_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _conj_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::_conj_copy_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_conj_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _conj_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::_conj_copy_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_neg_view_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _neg_view_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::_neg_view_copy_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_neg_view_copy.out(Tensor self, *, Tensor(a!) 
out) -> Tensor(a!) + inline at::Tensor & _neg_view_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::_neg_view_copy_out::redispatch(dispatchKeySet, self, out); + } + + // aten::as_strided_copy.out(Tensor self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & as_strided_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef size, at::IntArrayRef stride, ::std::optional storage_offset=::std::nullopt) { + return at::_ops::as_strided_copy_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), storage_offset.has_value() ? ::std::make_optional(c10::SymInt(*storage_offset)) : ::std::nullopt, out); + } + + // aten::as_strided_copy.out(Tensor self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & as_strided_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, at::IntArrayRef stride, ::std::optional storage_offset, at::Tensor & out) { + return at::_ops::as_strided_copy_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), storage_offset.has_value() ? ::std::make_optional(c10::SymInt(*storage_offset)) : ::std::nullopt, out); + } + + // aten::as_strided_copy.out(Tensor self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & as_strided_copy_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset=::std::nullopt) { + return at::_ops::as_strided_copy_out::redispatch(dispatchKeySet, self, size, stride, storage_offset, out); + } + + // aten::as_strided_copy.out(Tensor self, SymInt[] size, SymInt[] stride, SymInt? 
storage_offset=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & as_strided_copy_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset, at::Tensor & out) { + return at::_ops::as_strided_copy_out::redispatch(dispatchKeySet, self, size, stride, storage_offset, out); + } + + // aten::_sparse_broadcast_to_copy.out(Tensor self, int[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _sparse_broadcast_to_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef size) { + return at::_ops::_sparse_broadcast_to_copy_out::redispatch(dispatchKeySet, self, size, out); + } + + // aten::_sparse_broadcast_to_copy.out(Tensor self, int[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _sparse_broadcast_to_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, at::Tensor & out) { + return at::_ops::_sparse_broadcast_to_copy_out::redispatch(dispatchKeySet, self, size, out); + } + + // aten::diagonal_copy.out(Tensor self, int offset=0, int dim1=0, int dim2=1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & diagonal_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t offset=0, int64_t dim1=0, int64_t dim2=1) { + return at::_ops::diagonal_copy_out::redispatch(dispatchKeySet, self, offset, dim1, dim2, out); + } + + // aten::diagonal_copy.out(Tensor self, int offset=0, int dim1=0, int dim2=1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & diagonal_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t offset, int64_t dim1, int64_t dim2, at::Tensor & out) { + return at::_ops::diagonal_copy_out::redispatch(dispatchKeySet, self, offset, dim1, dim2, out); + } + + // aten::expand_copy.out(Tensor self, SymInt[] size, *, bool implicit=False, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & expand_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef size, bool implicit=false) { + return at::_ops::expand_copy_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), implicit, out); + } + + // aten::expand_copy.out(Tensor self, SymInt[] size, *, bool implicit=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & expand_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, bool implicit, at::Tensor & out) { + return at::_ops::expand_copy_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), implicit, out); + } + + // aten::expand_copy.out(Tensor self, SymInt[] size, *, bool implicit=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & expand_copy_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef size, bool implicit=false) { + return at::_ops::expand_copy_out::redispatch(dispatchKeySet, self, size, implicit, out); + } + + // aten::expand_copy.out(Tensor self, SymInt[] size, *, bool implicit=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & expand_copy_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, bool implicit, at::Tensor & out) { + return at::_ops::expand_copy_out::redispatch(dispatchKeySet, self, size, implicit, out); + } + + // aten::permute_copy.out(Tensor self, int[] dims, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & permute_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef dims) { + return at::_ops::permute_copy_out::redispatch(dispatchKeySet, self, dims, out); + } + + // aten::permute_copy.out(Tensor self, int[] dims, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & permute_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dims, at::Tensor & out) { + return at::_ops::permute_copy_out::redispatch(dispatchKeySet, self, dims, out); + } + + // aten::_reshape_alias_copy.out(Tensor self, SymInt[] size, SymInt[] stride, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _reshape_alias_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef size, at::IntArrayRef stride) { + return at::_ops::_reshape_alias_copy_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), out); + } + + // aten::_reshape_alias_copy.out(Tensor self, SymInt[] size, SymInt[] stride, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _reshape_alias_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, at::IntArrayRef stride, at::Tensor & out) { + return at::_ops::_reshape_alias_copy_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), out); + } + + // aten::_reshape_alias_copy.out(Tensor self, SymInt[] size, SymInt[] stride, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _reshape_alias_copy_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride) { + return at::_ops::_reshape_alias_copy_out::redispatch(dispatchKeySet, self, size, stride, out); + } + + // aten::_reshape_alias_copy.out(Tensor self, SymInt[] size, SymInt[] stride, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & _reshape_alias_copy_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, at::Tensor & out) { + return at::_ops::_reshape_alias_copy_out::redispatch(dispatchKeySet, self, size, stride, out); + } + + // aten::select_copy.int_out(Tensor self, int dim, SymInt index, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & select_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim, int64_t index) { + return at::_ops::select_copy_int_out::redispatch(dispatchKeySet, self, dim, index, out); + } + + // aten::select_copy.int_out(Tensor self, int dim, SymInt index, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & select_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, int64_t index, at::Tensor & out) { + return at::_ops::select_copy_int_out::redispatch(dispatchKeySet, self, dim, index, out); + } + + // aten::select_copy.int_out(Tensor self, int dim, SymInt index, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & select_copy_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim, c10::SymInt index) { + return at::_ops::select_copy_int_out::redispatch(dispatchKeySet, self, dim, index, out); + } + + // aten::select_copy.int_out(Tensor self, int dim, SymInt index, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & select_copy_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, c10::SymInt index, at::Tensor & out) { + return at::_ops::select_copy_int_out::redispatch(dispatchKeySet, self, dim, index, out); + } + + // aten::detach_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & detach_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::detach_copy_out::redispatch(dispatchKeySet, self, out); + } + + // aten::detach_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & detach_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::detach_copy_out::redispatch(dispatchKeySet, self, out); + } + + // aten::slice_copy.Tensor_out(Tensor self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & slice_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim=0, ::std::optional start=::std::nullopt, ::std::optional end=::std::nullopt, int64_t step=1) { + return at::_ops::slice_copy_Tensor_out::redispatch(dispatchKeySet, self, dim, start.has_value() ? ::std::make_optional(c10::SymInt(*start)) : ::std::nullopt, end.has_value() ? ::std::make_optional(c10::SymInt(*end)) : ::std::nullopt, step, out); + } + + // aten::slice_copy.Tensor_out(Tensor self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & slice_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, ::std::optional start, ::std::optional end, int64_t step, at::Tensor & out) { + return at::_ops::slice_copy_Tensor_out::redispatch(dispatchKeySet, self, dim, start.has_value() ? ::std::make_optional(c10::SymInt(*start)) : ::std::nullopt, end.has_value() ? ::std::make_optional(c10::SymInt(*end)) : ::std::nullopt, step, out); + } + + // aten::slice_copy.Tensor_out(Tensor self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & slice_copy_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim=0, ::std::optional start=::std::nullopt, ::std::optional end=::std::nullopt, c10::SymInt step=1) { + return at::_ops::slice_copy_Tensor_out::redispatch(dispatchKeySet, self, dim, start, end, step, out); + } + + // aten::slice_copy.Tensor_out(Tensor self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & slice_copy_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, ::std::optional start, ::std::optional end, c10::SymInt step, at::Tensor & out) { + return at::_ops::slice_copy_Tensor_out::redispatch(dispatchKeySet, self, dim, start, end, step, out); + } + + // aten::squeeze_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & squeeze_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::squeeze_copy_out::redispatch(dispatchKeySet, self, out); + } + + // aten::squeeze_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & squeeze_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::squeeze_copy_out::redispatch(dispatchKeySet, self, out); + } + + // aten::squeeze_copy.dim_out(Tensor self, int dim, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & squeeze_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim) { + return at::_ops::squeeze_copy_dim_out::redispatch(dispatchKeySet, self, dim, out); + } + + // aten::squeeze_copy.dim_out(Tensor self, int dim, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & squeeze_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, at::Tensor & out) { + return at::_ops::squeeze_copy_dim_out::redispatch(dispatchKeySet, self, dim, out); + } + + // aten::squeeze_copy.dims_out(Tensor self, int[] dim, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & squeeze_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef dim) { + return at::_ops::squeeze_copy_dims_out::redispatch(dispatchKeySet, self, dim, out); + } + + // aten::squeeze_copy.dims_out(Tensor self, int[] dim, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & squeeze_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim, at::Tensor & out) { + return at::_ops::squeeze_copy_dims_out::redispatch(dispatchKeySet, self, dim, out); + } + + // aten::t_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & t_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::t_copy_out::redispatch(dispatchKeySet, self, out); + } + + // aten::t_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & t_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::t_copy_out::redispatch(dispatchKeySet, self, out); + } + + // aten::transpose_copy.int_out(Tensor self, int dim0, int dim1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & transpose_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim0, int64_t dim1) { + return at::_ops::transpose_copy_int_out::redispatch(dispatchKeySet, self, dim0, dim1, out); + } + + // aten::transpose_copy.int_out(Tensor self, int dim0, int dim1, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & transpose_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim0, int64_t dim1, at::Tensor & out) { + return at::_ops::transpose_copy_int_out::redispatch(dispatchKeySet, self, dim0, dim1, out); + } + + // aten::unsqueeze_copy.out(Tensor self, int dim, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & unsqueeze_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim) { + return at::_ops::unsqueeze_copy_out::redispatch(dispatchKeySet, self, dim, out); + } + + // aten::unsqueeze_copy.out(Tensor self, int dim, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & unsqueeze_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, at::Tensor & out) { + return at::_ops::unsqueeze_copy_out::redispatch(dispatchKeySet, self, dim, out); + } + + // aten::_indices_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _indices_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::_indices_copy_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_indices_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _indices_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::_indices_copy_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_values_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _values_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::_values_copy_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_values_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & _values_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::_values_copy_out::redispatch(dispatchKeySet, self, out); + } + + // aten::indices_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & indices_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::indices_copy_out::redispatch(dispatchKeySet, self, out); + } + + // aten::indices_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & indices_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::indices_copy_out::redispatch(dispatchKeySet, self, out); + } + + // aten::values_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & values_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::values_copy_out::redispatch(dispatchKeySet, self, out); + } + + // aten::values_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & values_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::values_copy_out::redispatch(dispatchKeySet, self, out); + } + + // aten::crow_indices_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & crow_indices_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::crow_indices_copy_out::redispatch(dispatchKeySet, self, out); + } + + // aten::crow_indices_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & crow_indices_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::crow_indices_copy_out::redispatch(dispatchKeySet, self, out); + } + + // aten::col_indices_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & col_indices_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::col_indices_copy_out::redispatch(dispatchKeySet, self, out); + } + + // aten::col_indices_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & col_indices_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::col_indices_copy_out::redispatch(dispatchKeySet, self, out); + } + + // aten::ccol_indices_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & ccol_indices_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::ccol_indices_copy_out::redispatch(dispatchKeySet, self, out); + } + + // aten::ccol_indices_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & ccol_indices_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::ccol_indices_copy_out::redispatch(dispatchKeySet, self, out); + } + + // aten::row_indices_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & row_indices_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::row_indices_copy_out::redispatch(dispatchKeySet, self, out); + } + + // aten::row_indices_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & row_indices_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::row_indices_copy_out::redispatch(dispatchKeySet, self, out); + } + + // aten::view_copy.out(Tensor self, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & view_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef size) { + return at::_ops::view_copy_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), out); + } + + // aten::view_copy.out(Tensor self, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & view_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, at::Tensor & out) { + return at::_ops::view_copy_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), out); + } + + // aten::view_copy.out(Tensor self, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & view_copy_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef size) { + return at::_ops::view_copy_out::redispatch(dispatchKeySet, self, size, out); + } + + // aten::view_copy.out(Tensor self, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & view_copy_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, at::Tensor & out) { + return at::_ops::view_copy_out::redispatch(dispatchKeySet, self, size, out); + } + + // aten::view_copy.dtype_out(Tensor self, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & view_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::ScalarType dtype) { + return at::_ops::view_copy_dtype_out::redispatch(dispatchKeySet, self, dtype, out); + } + + // aten::view_copy.dtype_out(Tensor self, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & view_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::ScalarType dtype, at::Tensor & out) { + return at::_ops::view_copy_dtype_out::redispatch(dispatchKeySet, self, dtype, out); + } + + // aten::unfold_copy.out(Tensor self, int dimension, int size, int step, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & unfold_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dimension, int64_t size, int64_t step) { + return at::_ops::unfold_copy_out::redispatch(dispatchKeySet, self, dimension, size, step, out); + } + + // aten::unfold_copy.out(Tensor self, int dimension, int size, int step, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & unfold_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dimension, int64_t size, int64_t step, at::Tensor & out) { + return at::_ops::unfold_copy_out::redispatch(dispatchKeySet, self, dimension, size, step, out); + } + + // aten::alias_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & alias_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::alias_copy_out::redispatch(dispatchKeySet, self, out); + } + + // aten::alias_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & alias_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::alias_copy_out::redispatch(dispatchKeySet, self, out); + } + + // aten::to_padded_tensor.out(Tensor self, float padding, SymInt[]? output_size=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & to_padded_tensor_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, double padding, at::OptionalIntArrayRef output_size=::std::nullopt) { + return at::_ops::to_padded_tensor_out::redispatch(dispatchKeySet, self, padding, output_size.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*output_size)) : ::std::nullopt, out); + } + + // aten::to_padded_tensor.out(Tensor self, float padding, SymInt[]? output_size=None, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & to_padded_tensor_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double padding, at::OptionalIntArrayRef output_size, at::Tensor & out) { + return at::_ops::to_padded_tensor_out::redispatch(dispatchKeySet, self, padding, output_size.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*output_size)) : ::std::nullopt, out); + } + + // aten::to_padded_tensor.out(Tensor self, float padding, SymInt[]? output_size=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & to_padded_tensor_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, double padding, at::OptionalSymIntArrayRef output_size=::std::nullopt) { + return at::_ops::to_padded_tensor_out::redispatch(dispatchKeySet, self, padding, output_size, out); + } + + // aten::to_padded_tensor.out(Tensor self, float padding, SymInt[]? output_size=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & to_padded_tensor_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double padding, at::OptionalSymIntArrayRef output_size, at::Tensor & out) { + return at::_ops::to_padded_tensor_out::redispatch(dispatchKeySet, self, padding, output_size, out); + } + + // aten::_transformer_encoder_layer_fwd.out(Tensor src, int embed_dim, int num_heads, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, bool use_gelu, bool norm_first, float eps, Tensor norm_weight_1, Tensor norm_bias_1, Tensor norm_weight_2, Tensor norm_bias_2, Tensor ffn_weight_1, Tensor ffn_bias_1, Tensor ffn_weight_2, Tensor ffn_bias_2, Tensor? mask=None, int? mask_type=None, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & _transformer_encoder_layer_fwd_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & src, int64_t embed_dim, int64_t num_heads, const at::Tensor & qkv_weight, const at::Tensor & qkv_bias, const at::Tensor & proj_weight, const at::Tensor & proj_bias, bool use_gelu, bool norm_first, double eps, const at::Tensor & norm_weight_1, const at::Tensor & norm_bias_1, const at::Tensor & norm_weight_2, const at::Tensor & norm_bias_2, const at::Tensor & ffn_weight_1, const at::Tensor & ffn_bias_1, const at::Tensor & ffn_weight_2, const at::Tensor & ffn_bias_2, const ::std::optional & mask={}, ::std::optional mask_type=::std::nullopt) { + return at::_ops::_transformer_encoder_layer_fwd_out::redispatch(dispatchKeySet, src, embed_dim, num_heads, qkv_weight, qkv_bias, proj_weight, proj_bias, use_gelu, norm_first, eps, norm_weight_1, norm_bias_1, norm_weight_2, norm_bias_2, ffn_weight_1, ffn_bias_1, ffn_weight_2, ffn_bias_2, mask, mask_type, out); + } + + // aten::_transformer_encoder_layer_fwd.out(Tensor src, int embed_dim, int num_heads, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, bool use_gelu, bool norm_first, float eps, Tensor norm_weight_1, Tensor norm_bias_1, Tensor norm_weight_2, Tensor norm_bias_2, Tensor ffn_weight_1, Tensor ffn_bias_1, Tensor ffn_weight_2, Tensor ffn_bias_2, Tensor? mask=None, int? mask_type=None, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & _transformer_encoder_layer_fwd_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & src, int64_t embed_dim, int64_t num_heads, const at::Tensor & qkv_weight, const at::Tensor & qkv_bias, const at::Tensor & proj_weight, const at::Tensor & proj_bias, bool use_gelu, bool norm_first, double eps, const at::Tensor & norm_weight_1, const at::Tensor & norm_bias_1, const at::Tensor & norm_weight_2, const at::Tensor & norm_bias_2, const at::Tensor & ffn_weight_1, const at::Tensor & ffn_bias_1, const at::Tensor & ffn_weight_2, const at::Tensor & ffn_bias_2, const ::std::optional & mask, ::std::optional mask_type, at::Tensor & out) { + return at::_ops::_transformer_encoder_layer_fwd_out::redispatch(dispatchKeySet, src, embed_dim, num_heads, qkv_weight, qkv_bias, proj_weight, proj_bias, use_gelu, norm_first, eps, norm_weight_1, norm_bias_1, norm_weight_2, norm_bias_2, ffn_weight_1, ffn_bias_1, ffn_weight_2, ffn_bias_2, mask, mask_type, out); + } + + // aten::_native_multi_head_attention.out(Tensor query, Tensor key, Tensor value, int embed_dim, int num_head, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, Tensor? mask=None, bool need_weights=True, bool average_attn_weights=True, int? mask_type=None, *, Tensor(a!) out0, Tensor(b!) 
out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple _native_multi_head_attention_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, int64_t embed_dim, int64_t num_head, const at::Tensor & qkv_weight, const at::Tensor & qkv_bias, const at::Tensor & proj_weight, const at::Tensor & proj_bias, const ::std::optional & mask={}, bool need_weights=true, bool average_attn_weights=true, ::std::optional mask_type=::std::nullopt) { + return at::_ops::_native_multi_head_attention_out::redispatch(dispatchKeySet, query, key, value, embed_dim, num_head, qkv_weight, qkv_bias, proj_weight, proj_bias, mask, need_weights, average_attn_weights, mask_type, out0, out1); + } + + // aten::_native_multi_head_attention.out(Tensor query, Tensor key, Tensor value, int embed_dim, int num_head, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, Tensor? mask=None, bool need_weights=True, bool average_attn_weights=True, int? mask_type=None, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple _native_multi_head_attention_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, int64_t embed_dim, int64_t num_head, const at::Tensor & qkv_weight, const at::Tensor & qkv_bias, const at::Tensor & proj_weight, const at::Tensor & proj_bias, const ::std::optional & mask, bool need_weights, bool average_attn_weights, ::std::optional mask_type, at::Tensor & out0, at::Tensor & out1) { + return at::_ops::_native_multi_head_attention_out::redispatch(dispatchKeySet, query, key, value, embed_dim, num_head, qkv_weight, qkv_bias, proj_weight, proj_bias, mask, need_weights, average_attn_weights, mask_type, out0, out1); + } + + // aten::_triton_scaled_dot_attention.out(Tensor q, Tensor k, Tensor v, float dropout_p=0.0, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & _triton_scaled_dot_attention_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & q, const at::Tensor & k, const at::Tensor & v, double dropout_p=0.0) { + return at::_ops::_triton_scaled_dot_attention_out::redispatch(dispatchKeySet, q, k, v, dropout_p, out); + } + + // aten::_triton_scaled_dot_attention.out(Tensor q, Tensor k, Tensor v, float dropout_p=0.0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _triton_scaled_dot_attention_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & q, const at::Tensor & k, const at::Tensor & v, double dropout_p, at::Tensor & out) { + return at::_ops::_triton_scaled_dot_attention_out::redispatch(dispatchKeySet, q, k, v, dropout_p, out); + } + + // aten::_triton_multi_head_attention.out(Tensor query, Tensor key, Tensor value, int embed_dim, int num_head, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, Tensor? mask=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _triton_multi_head_attention_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, int64_t embed_dim, int64_t num_head, const at::Tensor & qkv_weight, const at::Tensor & qkv_bias, const at::Tensor & proj_weight, const at::Tensor & proj_bias, const ::std::optional & mask={}) { + return at::_ops::_triton_multi_head_attention_out::redispatch(dispatchKeySet, query, key, value, embed_dim, num_head, qkv_weight, qkv_bias, proj_weight, proj_bias, mask, out); + } + + // aten::_triton_multi_head_attention.out(Tensor query, Tensor key, Tensor value, int embed_dim, int num_head, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, Tensor? mask=None, *, Tensor(a!) out) -> Tensor(a!) 
+ inline at::Tensor & _triton_multi_head_attention_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, int64_t embed_dim, int64_t num_head, const at::Tensor & qkv_weight, const at::Tensor & qkv_bias, const at::Tensor & proj_weight, const at::Tensor & proj_bias, const ::std::optional & mask, at::Tensor & out) { + return at::_ops::_triton_multi_head_attention_out::redispatch(dispatchKeySet, query, key, value, embed_dim, num_head, qkv_weight, qkv_bias, proj_weight, proj_bias, mask, out); + } + + // aten::_foobar.out(Tensor self, bool arg1=True, bool arg2=True, *, bool arg3=True, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _foobar_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, bool arg1=true, bool arg2=true, bool arg3=true) { + return at::_ops::_foobar_out::redispatch(dispatchKeySet, self, arg1, arg2, arg3, out); + } + + // aten::_foobar.out(Tensor self, bool arg1=True, bool arg2=True, *, bool arg3=True, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _foobar_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool arg1, bool arg2, bool arg3, at::Tensor & out) { + return at::_ops::_foobar_out::redispatch(dispatchKeySet, self, arg1, arg2, arg3, out); + } + + // aten::_fused_adam.out(Tensor[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, float lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? 
found_inf=None, Tensor(a!)[] out) -> () + inline void _fused_adam_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, double lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const ::std::optional & grad_scale={}, const ::std::optional & found_inf={}) { + return at::_ops::_fused_adam_out::redispatch(dispatchKeySet, self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale, found_inf, out); + } + + // aten::_fused_adam.out(Tensor[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, float lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None, Tensor(a!)[] out) -> () + inline void _fused_adam_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, double lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const ::std::optional & grad_scale, const ::std::optional & found_inf, at::TensorList out) { + return at::_ops::_fused_adam_out::redispatch(dispatchKeySet, self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale, found_inf, out); + } + + // aten::_fused_adam(Tensor[] self, Tensor[] grads, Tensor[] exp_avgs, Tensor[] exp_avg_sqs, Tensor[] max_exp_avg_sqs, Tensor[] state_steps, *, float lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? 
found_inf=None) -> (Tensor[] self_out, Tensor[] grads_out, Tensor[] exp_avgs_out, Tensor[] exp_avg_sqs_out, Tensor[] max_exp_avg_sqs_out) + inline ::std::tuple<::std::vector,::std::vector,::std::vector,::std::vector,::std::vector> _fused_adam(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, double lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const ::std::optional & grad_scale={}, const ::std::optional & found_inf={}) { + return at::_ops::_fused_adam::redispatch(dispatchKeySet, self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale, found_inf); + } + + // aten::_fused_adam.tensor_lr_out(Tensor[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, Tensor lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? 
found_inf=None, Tensor(a!)[] out) -> () + inline void _fused_adam_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, const at::Tensor & lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const ::std::optional & grad_scale={}, const ::std::optional & found_inf={}) { + return at::_ops::_fused_adam_tensor_lr_out::redispatch(dispatchKeySet, self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale, found_inf, out); + } + + // aten::_fused_adam.tensor_lr_out(Tensor[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, Tensor lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None, Tensor(a!)[] out) -> () + inline void _fused_adam_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, const at::Tensor & lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const ::std::optional & grad_scale, const ::std::optional & found_inf, at::TensorList out) { + return at::_ops::_fused_adam_tensor_lr_out::redispatch(dispatchKeySet, self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale, found_inf, out); + } + + // aten::_fused_adam.tensor_lr(Tensor[] self, Tensor[] grads, Tensor[] exp_avgs, Tensor[] exp_avg_sqs, Tensor[] max_exp_avg_sqs, Tensor[] state_steps, *, Tensor lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? 
found_inf=None) -> (Tensor[] self_out, Tensor[] grads_out, Tensor[] exp_avgs_out, Tensor[] exp_avg_sqs_out, Tensor[] max_exp_avg_sqs_out) + inline ::std::tuple<::std::vector,::std::vector,::std::vector,::std::vector,::std::vector> _fused_adam(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, const at::Tensor & lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const ::std::optional & grad_scale={}, const ::std::optional & found_inf={}) { + return at::_ops::_fused_adam_tensor_lr::redispatch(dispatchKeySet, self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale, found_inf); + } + + // aten::_fused_adamw.out(Tensor[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, float lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? 
found_inf=None, Tensor(a!)[] out) -> () + inline void _fused_adamw_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, double lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const ::std::optional & grad_scale={}, const ::std::optional & found_inf={}) { + return at::_ops::_fused_adamw_out::redispatch(dispatchKeySet, self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale, found_inf, out); + } + + // aten::_fused_adamw.out(Tensor[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, float lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None, Tensor(a!)[] out) -> () + inline void _fused_adamw_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, double lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const ::std::optional & grad_scale, const ::std::optional & found_inf, at::TensorList out) { + return at::_ops::_fused_adamw_out::redispatch(dispatchKeySet, self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale, found_inf, out); + } + + // aten::_fused_adamw(Tensor[] self, Tensor[] grads, Tensor[] exp_avgs, Tensor[] exp_avg_sqs, Tensor[] max_exp_avg_sqs, Tensor[] state_steps, *, float lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? 
found_inf=None) -> (Tensor[] self_out, Tensor[] grads_out, Tensor[] exp_avgs_out, Tensor[] exp_avg_sqs_out, Tensor[] max_exp_avg_sqs_out) + inline ::std::tuple<::std::vector,::std::vector,::std::vector,::std::vector,::std::vector> _fused_adamw(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, double lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const ::std::optional & grad_scale={}, const ::std::optional & found_inf={}) { + return at::_ops::_fused_adamw::redispatch(dispatchKeySet, self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale, found_inf); + } + + // aten::_fused_adamw.tensor_lr_out(Tensor[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, Tensor lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? 
found_inf=None, Tensor(a!)[] out) -> () + inline void _fused_adamw_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, const at::Tensor & lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const ::std::optional & grad_scale={}, const ::std::optional & found_inf={}) { + return at::_ops::_fused_adamw_tensor_lr_out::redispatch(dispatchKeySet, self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale, found_inf, out); + } + + // aten::_fused_adamw.tensor_lr_out(Tensor[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, Tensor lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None, Tensor(a!)[] out) -> () + inline void _fused_adamw_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, const at::Tensor & lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const ::std::optional & grad_scale, const ::std::optional & found_inf, at::TensorList out) { + return at::_ops::_fused_adamw_tensor_lr_out::redispatch(dispatchKeySet, self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale, found_inf, out); + } + + // aten::_fused_adamw.tensor_lr(Tensor[] self, Tensor[] grads, Tensor[] exp_avgs, Tensor[] exp_avg_sqs, Tensor[] max_exp_avg_sqs, Tensor[] state_steps, *, Tensor lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? 
grad_scale=None, Tensor? found_inf=None) -> (Tensor[] self_out, Tensor[] grads_out, Tensor[] exp_avgs_out, Tensor[] exp_avg_sqs_out, Tensor[] max_exp_avg_sqs_out) + inline ::std::tuple<::std::vector,::std::vector,::std::vector,::std::vector,::std::vector> _fused_adamw(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, const at::Tensor & lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const ::std::optional & grad_scale={}, const ::std::optional & found_inf={}) { + return at::_ops::_fused_adamw_tensor_lr::redispatch(dispatchKeySet, self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale, found_inf); + } + + // aten::_fused_sgd.out(Tensor[] self, Tensor(b!)[] grads, Tensor(c!)[] momentum_buffer_list, *, float weight_decay, float momentum, float lr, float dampening, bool nesterov, bool maximize, bool is_first_step, Tensor? grad_scale=None, Tensor? found_inf=None, Tensor(a!)[] out) -> () + inline void _fused_sgd_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::TensorList grads, at::TensorList momentum_buffer_list, double weight_decay, double momentum, double lr, double dampening, bool nesterov, bool maximize, bool is_first_step, const ::std::optional & grad_scale={}, const ::std::optional & found_inf={}) { + return at::_ops::_fused_sgd_out::redispatch(dispatchKeySet, self, grads, momentum_buffer_list, weight_decay, momentum, lr, dampening, nesterov, maximize, is_first_step, grad_scale, found_inf, out); + } + + // aten::_fused_sgd.out(Tensor[] self, Tensor(b!)[] grads, Tensor(c!)[] momentum_buffer_list, *, float weight_decay, float momentum, float lr, float dampening, bool nesterov, bool maximize, bool is_first_step, Tensor? grad_scale=None, Tensor? 
found_inf=None, Tensor(a!)[] out) -> () + inline void _fused_sgd_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList grads, at::TensorList momentum_buffer_list, double weight_decay, double momentum, double lr, double dampening, bool nesterov, bool maximize, bool is_first_step, const ::std::optional & grad_scale, const ::std::optional & found_inf, at::TensorList out) { + return at::_ops::_fused_sgd_out::redispatch(dispatchKeySet, self, grads, momentum_buffer_list, weight_decay, momentum, lr, dampening, nesterov, maximize, is_first_step, grad_scale, found_inf, out); + } + + // aten::_fused_sgd(Tensor[] self, Tensor[] grads, Tensor[] momentum_buffer_list, *, float weight_decay, float momentum, float lr, float dampening, bool nesterov, bool maximize, bool is_first_step, Tensor? grad_scale=None, Tensor? found_inf=None) -> (Tensor[] self_out, Tensor[] grads_out, Tensor[] momentum_buffer_list_out) + inline ::std::tuple<::std::vector,::std::vector,::std::vector> _fused_sgd(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList grads, at::TensorList momentum_buffer_list, double weight_decay, double momentum, double lr, double dampening, bool nesterov, bool maximize, bool is_first_step, const ::std::optional & grad_scale={}, const ::std::optional & found_inf={}) { + return at::_ops::_fused_sgd::redispatch(dispatchKeySet, self, grads, momentum_buffer_list, weight_decay, momentum, lr, dampening, nesterov, maximize, is_first_step, grad_scale, found_inf); + } + + // aten::_fused_sgd.tensor_lr_out(Tensor[] self, Tensor(b!)[] grads, Tensor(c!)[] momentum_buffer_list, *, float weight_decay, float momentum, Tensor lr, float dampening, bool nesterov, bool maximize, bool is_first_step, Tensor? grad_scale=None, Tensor? 
found_inf=None, Tensor(a!)[] out) -> () + inline void _fused_sgd_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::TensorList grads, at::TensorList momentum_buffer_list, double weight_decay, double momentum, const at::Tensor & lr, double dampening, bool nesterov, bool maximize, bool is_first_step, const ::std::optional & grad_scale={}, const ::std::optional & found_inf={}) { + return at::_ops::_fused_sgd_tensor_lr_out::redispatch(dispatchKeySet, self, grads, momentum_buffer_list, weight_decay, momentum, lr, dampening, nesterov, maximize, is_first_step, grad_scale, found_inf, out); + } + + // aten::_fused_sgd.tensor_lr_out(Tensor[] self, Tensor(b!)[] grads, Tensor(c!)[] momentum_buffer_list, *, float weight_decay, float momentum, Tensor lr, float dampening, bool nesterov, bool maximize, bool is_first_step, Tensor? grad_scale=None, Tensor? found_inf=None, Tensor(a!)[] out) -> () + inline void _fused_sgd_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList grads, at::TensorList momentum_buffer_list, double weight_decay, double momentum, const at::Tensor & lr, double dampening, bool nesterov, bool maximize, bool is_first_step, const ::std::optional & grad_scale, const ::std::optional & found_inf, at::TensorList out) { + return at::_ops::_fused_sgd_tensor_lr_out::redispatch(dispatchKeySet, self, grads, momentum_buffer_list, weight_decay, momentum, lr, dampening, nesterov, maximize, is_first_step, grad_scale, found_inf, out); + } + + // aten::_fused_sgd.tensor_lr(Tensor[] self, Tensor[] grads, Tensor[] momentum_buffer_list, *, float weight_decay, float momentum, Tensor lr, float dampening, bool nesterov, bool maximize, bool is_first_step, Tensor? grad_scale=None, Tensor? 
found_inf=None) -> (Tensor[] self_out, Tensor[] grads_out, Tensor[] momentum_buffer_list_out) + inline ::std::tuple<::std::vector,::std::vector,::std::vector> _fused_sgd(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList grads, at::TensorList momentum_buffer_list, double weight_decay, double momentum, const at::Tensor & lr, double dampening, bool nesterov, bool maximize, bool is_first_step, const ::std::optional & grad_scale={}, const ::std::optional & found_inf={}) { + return at::_ops::_fused_sgd_tensor_lr::redispatch(dispatchKeySet, self, grads, momentum_buffer_list, weight_decay, momentum, lr, dampening, nesterov, maximize, is_first_step, grad_scale, found_inf); + } + + // aten::_fused_adagrad.out(Tensor[] self, Tensor(b!)[] grads, Tensor(c!)[] state_sums, Tensor(d!)[] state_steps, *, float lr, float lr_decay, float weight_decay, float eps, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None, Tensor(a!)[] out) -> () + inline void _fused_adagrad_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::TensorList grads, at::TensorList state_sums, at::TensorList state_steps, double lr, double lr_decay, double weight_decay, double eps, bool maximize, const ::std::optional & grad_scale={}, const ::std::optional & found_inf={}) { + return at::_ops::_fused_adagrad_out::redispatch(dispatchKeySet, self, grads, state_sums, state_steps, lr, lr_decay, weight_decay, eps, maximize, grad_scale, found_inf, out); + } + + // aten::_fused_adagrad.out(Tensor[] self, Tensor(b!)[] grads, Tensor(c!)[] state_sums, Tensor(d!)[] state_steps, *, float lr, float lr_decay, float weight_decay, float eps, bool maximize, Tensor? grad_scale=None, Tensor? 
found_inf=None, Tensor(a!)[] out) -> () + inline void _fused_adagrad_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList grads, at::TensorList state_sums, at::TensorList state_steps, double lr, double lr_decay, double weight_decay, double eps, bool maximize, const ::std::optional & grad_scale, const ::std::optional & found_inf, at::TensorList out) { + return at::_ops::_fused_adagrad_out::redispatch(dispatchKeySet, self, grads, state_sums, state_steps, lr, lr_decay, weight_decay, eps, maximize, grad_scale, found_inf, out); + } + + // aten::_fused_adagrad(Tensor[] self, Tensor[] grads, Tensor[] state_sums, Tensor[] state_steps, *, float lr, float lr_decay, float weight_decay, float eps, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> (Tensor[] self_out, Tensor[] grads_out, Tensor[] state_sums_out, Tensor[] state_steps_out) + inline ::std::tuple<::std::vector,::std::vector,::std::vector,::std::vector> _fused_adagrad(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList grads, at::TensorList state_sums, at::TensorList state_steps, double lr, double lr_decay, double weight_decay, double eps, bool maximize, const ::std::optional & grad_scale={}, const ::std::optional & found_inf={}) { + return at::_ops::_fused_adagrad::redispatch(dispatchKeySet, self, grads, state_sums, state_steps, lr, lr_decay, weight_decay, eps, maximize, grad_scale, found_inf); + } + + // aten::_fused_adagrad.tensor_lr_out(Tensor[] self, Tensor(b!)[] grads, Tensor(c!)[] state_sums, Tensor[] state_steps, *, Tensor lr, float lr_decay, float weight_decay, float eps, bool maximize, Tensor? grad_scale=None, Tensor? 
found_inf=None, Tensor(a!)[] out) -> () + inline void _fused_adagrad_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::TensorList grads, at::TensorList state_sums, at::TensorList state_steps, const at::Tensor & lr, double lr_decay, double weight_decay, double eps, bool maximize, const ::std::optional & grad_scale={}, const ::std::optional & found_inf={}) { + return at::_ops::_fused_adagrad_tensor_lr_out::redispatch(dispatchKeySet, self, grads, state_sums, state_steps, lr, lr_decay, weight_decay, eps, maximize, grad_scale, found_inf, out); + } + + // aten::_fused_adagrad.tensor_lr_out(Tensor[] self, Tensor(b!)[] grads, Tensor(c!)[] state_sums, Tensor[] state_steps, *, Tensor lr, float lr_decay, float weight_decay, float eps, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None, Tensor(a!)[] out) -> () + inline void _fused_adagrad_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList grads, at::TensorList state_sums, at::TensorList state_steps, const at::Tensor & lr, double lr_decay, double weight_decay, double eps, bool maximize, const ::std::optional & grad_scale, const ::std::optional & found_inf, at::TensorList out) { + return at::_ops::_fused_adagrad_tensor_lr_out::redispatch(dispatchKeySet, self, grads, state_sums, state_steps, lr, lr_decay, weight_decay, eps, maximize, grad_scale, found_inf, out); + } + + // aten::_fused_adagrad.tensor_lr(Tensor[] self, Tensor[] grads, Tensor[] state_sums, Tensor[] state_steps, *, Tensor lr, float lr_decay, float weight_decay, float eps, bool maximize, Tensor? grad_scale=None, Tensor? 
found_inf=None) -> (Tensor[] self_out, Tensor[] grads_out, Tensor[] state_sums_out) + inline ::std::tuple<::std::vector,::std::vector,::std::vector> _fused_adagrad(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList grads, at::TensorList state_sums, at::TensorList state_steps, const at::Tensor & lr, double lr_decay, double weight_decay, double eps, bool maximize, const ::std::optional & grad_scale={}, const ::std::optional & found_inf={}) { + return at::_ops::_fused_adagrad_tensor_lr::redispatch(dispatchKeySet, self, grads, state_sums, state_steps, lr, lr_decay, weight_decay, eps, maximize, grad_scale, found_inf); + } +} // namespace redispatch + +} + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/RegistrationDeclarations.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/RegistrationDeclarations.h new file mode 100644 index 0000000000000000000000000000000000000000..7b28a0980762a45b777bd832511f08b9f0707b4f --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/RegistrationDeclarations.h @@ -0,0 +1,3192 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// This file contains all native_functions that can be registered to +// and the schema string that they should be registered with + +at::Tensor _cast_Byte(const at::Tensor & self, bool non_blocking); // {"schema": "aten::_cast_Byte(Tensor self, bool non_blocking=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _cast_Char(const at::Tensor & self, bool non_blocking); // {"schema": "aten::_cast_Char(Tensor self, bool non_blocking=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _cast_Double(const at::Tensor & self, bool non_blocking); 
// {"schema": "aten::_cast_Double(Tensor self, bool non_blocking=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _cast_Float(const at::Tensor & self, bool non_blocking); // {"schema": "aten::_cast_Float(Tensor self, bool non_blocking=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _cast_Int(const at::Tensor & self, bool non_blocking); // {"schema": "aten::_cast_Int(Tensor self, bool non_blocking=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _cast_Long(const at::Tensor & self, bool non_blocking); // {"schema": "aten::_cast_Long(Tensor self, bool non_blocking=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _cast_Short(const at::Tensor & self, bool non_blocking); // {"schema": "aten::_cast_Short(Tensor self, bool non_blocking=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _cast_Half(const at::Tensor & self, bool non_blocking); // {"schema": "aten::_cast_Half(Tensor self, bool non_blocking=False) -> Tensor", "dispatch": "False", "default": "True"} +void _backward(const at::Tensor & self, at::TensorList inputs, const ::std::optional & gradient, ::std::optional retain_graph, bool create_graph); // {"schema": "aten::_backward(Tensor self, Tensor[] inputs, Tensor? gradient=None, bool? retain_graph=None, bool create_graph=False) -> ()", "dispatch": "False", "default": "True"} +void set_data(at::Tensor & self, const at::Tensor & new_data); // {"schema": "aten::set_data(Tensor(a!) 
self, Tensor new_data) -> ()", "dispatch": "False", "default": "True"} +at::Tensor data(const at::Tensor & self); // {"schema": "aten::data(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +bool is_leaf(const at::Tensor & self); // {"schema": "aten::is_leaf(Tensor self) -> bool", "dispatch": "False", "default": "True"} +int64_t output_nr(const at::Tensor & self); // {"schema": "aten::output_nr(Tensor self) -> int", "dispatch": "False", "default": "True"} +int64_t _version(const at::Tensor & self); // {"schema": "aten::_version(Tensor self) -> int", "dispatch": "False", "default": "True"} +at::Tensor & requires_grad_(at::Tensor & self, bool requires_grad); // {"schema": "aten::requires_grad_(Tensor(a!) self, bool requires_grad=True) -> Tensor(a!)", "dispatch": "False", "default": "True"} +void retain_grad(at::Tensor & self); // {"schema": "aten::retain_grad(Tensor(a!) self) -> ()", "dispatch": "False", "default": "True"} +bool retains_grad(const at::Tensor & self); // {"schema": "aten::retains_grad(Tensor self) -> bool", "dispatch": "False", "default": "True"} +at::Tensor _fw_primal(const at::Tensor & self, int64_t level); // {"schema": "aten::_fw_primal(Tensor(a) self, int level) -> Tensor(a)", "dispatch": "True", "default": "True"} +at::Tensor _make_dual(const at::Tensor & primal, const at::Tensor & tangent, int64_t level); // {"schema": "aten::_make_dual(Tensor(a) primal, Tensor tangent, int level) -> Tensor(a)", "dispatch": "True", "default": "True"} +::std::tuple _unpack_dual(const at::Tensor & dual, int64_t level); // {"schema": "aten::_unpack_dual(Tensor(a) dual, int level) -> (Tensor(a) primal, Tensor tangent)", "dispatch": "False", "default": "True"} +at::Tensor _new_zeros_with_same_feature_meta(const at::Tensor & self, const at::Tensor & other, int64_t self_num_batch_dims); // {"schema": "aten::_new_zeros_with_same_feature_meta(Tensor self, Tensor other, *, int self_num_batch_dims=0) -> Tensor", "dispatch": "True", "default": "True"} +bool 
_has_same_storage_numel(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::_has_same_storage_numel(Tensor self, Tensor other) -> bool", "dispatch": "True", "default": "True"} +at::Tensor & rename_(at::Tensor & self, ::std::optional names); // {"schema": "aten::rename_(Tensor(a!) self, Dimname[]? names) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor rename(const at::Tensor & self, ::std::optional names); // {"schema": "aten::rename(Tensor(a) self, Dimname[]? names) -> Tensor(a)", "dispatch": "False", "default": "True"} +at::Tensor align_to(const at::Tensor & self, at::DimnameList names); // {"schema": "aten::align_to(Tensor(a) self, Dimname[] names) -> Tensor(a)", "dispatch": "False", "default": "True"} +at::Tensor align_to(const at::Tensor & self, at::DimnameList order, int64_t ellipsis_idx); // {"schema": "aten::align_to.ellipsis_idx(Tensor(a) self, Dimname[] order, int ellipsis_idx) -> Tensor(a)", "dispatch": "False", "default": "True"} +at::Tensor align_as(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::align_as(Tensor self, Tensor other) -> Tensor", "dispatch": "False", "default": "True"} +::std::vector align_tensors(at::TensorList tensors); // {"schema": "aten::align_tensors(Tensor[] tensors) -> Tensor[]", "dispatch": "False", "default": "True"} +void _assert_async(const at::Tensor & self); // {"schema": "aten::_assert_async(Tensor self) -> ()", "dispatch": "True", "default": "False"} +void _assert_async(const at::Tensor & self, c10::string_view assert_msg); // {"schema": "aten::_assert_async.msg(Tensor self, str assert_msg) -> ()", "dispatch": "True", "default": "False"} +void _assert_scalar(const at::Scalar & self, c10::string_view assert_msg); // {"schema": "aten::_assert_scalar(Scalar self, str assert_msg) -> ()", "dispatch": "True", "default": "True"} +at::Tensor _functional_assert_scalar(const at::Scalar & self, c10::string_view assert_msg, const at::Tensor & dep_token); // {"schema": 
"aten::_functional_assert_scalar(Scalar self, str assert_msg, Tensor dep_token) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor _functional_assert_async(const at::Tensor & self, c10::string_view assert_msg, const at::Tensor & dep_token); // {"schema": "aten::_functional_assert_async.msg(Tensor self, str assert_msg, Tensor dep_token) -> Tensor", "dispatch": "True", "default": "False"} +void _assert_tensor_metadata(const at::Tensor & a, at::OptionalSymIntArrayRef size, at::OptionalSymIntArrayRef stride, ::std::optional dtype, ::std::optional device, ::std::optional layout); // {"schema": "aten::_assert_tensor_metadata(Tensor a, SymInt[]? size=None, SymInt[]? stride=None, ScalarType? dtype=None, *, Device? device=None, Layout? layout=None) -> ()", "dispatch": "True", "default": "True"} +void _print(c10::string_view s); // {"schema": "aten::_print(str s) -> ()", "dispatch": "True", "default": "True"} +void sym_constrain_range(const at::Scalar & size, ::std::optional min, ::std::optional max); // {"schema": "aten::sym_constrain_range(Scalar size, *, int? min=None, int? max=None) -> ()", "dispatch": "True", "default": "True"} +void sym_constrain_range_for_size(const at::Scalar & size, ::std::optional min, ::std::optional max); // {"schema": "aten::sym_constrain_range_for_size(Scalar size, *, int? min=None, int? max=None) -> ()", "dispatch": "True", "default": "True"} +at::Tensor _functional_sym_constrain_range(const at::Scalar & size, ::std::optional min, ::std::optional max, const at::Tensor & dep_token); // {"schema": "aten::_functional_sym_constrain_range(Scalar size, int? min, int? max, Tensor dep_token) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor _functional_sym_constrain_range_for_size(const at::Scalar & size, ::std::optional min, ::std::optional max, const at::Tensor & dep_token); // {"schema": "aten::_functional_sym_constrain_range_for_size(Scalar size, int? min, int? 
max, Tensor dep_token) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor _make_dep_token(::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format); // {"schema": "aten::_make_dep_token(*, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor refine_names(const at::Tensor & self, at::DimnameList names); // {"schema": "aten::refine_names(Tensor(a) self, Dimname[] names) -> Tensor(a)", "dispatch": "False", "default": "True"} +bool _use_cudnn_ctc_loss(const at::Tensor & log_probs, const at::Tensor & targets, at::IntArrayRef input_lengths, at::IntArrayRef target_lengths, int64_t blank); // {"schema": "aten::_use_cudnn_ctc_loss(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank) -> bool", "dispatch": "True", "default": "False"} +bool _use_cudnn_ctc_loss(const at::Tensor & log_probs, const at::Tensor & targets, const at::Tensor & input_lengths, const at::Tensor & target_lengths, int64_t blank); // {"schema": "aten::_use_cudnn_ctc_loss.Tensor(Tensor log_probs, Tensor targets, Tensor input_lengths, Tensor target_lengths, int blank) -> bool", "dispatch": "True", "default": "False"} +::std::tuple _cudnn_ctc_loss(const at::Tensor & log_probs, const at::Tensor & targets, at::IntArrayRef input_lengths, at::IntArrayRef target_lengths, int64_t blank, bool deterministic, bool zero_infinity); // {"schema": "aten::_cudnn_ctc_loss(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank, bool deterministic, bool zero_infinity) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"} +::std::tuple _cudnn_ctc_loss(const at::Tensor & log_probs, const at::Tensor & targets, const at::Tensor & input_lengths, const at::Tensor & target_lengths, int64_t blank, bool deterministic, bool zero_infinity); // 
{"schema": "aten::_cudnn_ctc_loss.Tensor(Tensor log_probs, Tensor targets, Tensor input_lengths, Tensor target_lengths, int blank, bool deterministic, bool zero_infinity) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"} +bool _use_cudnn_rnn_flatten_weight(); // {"schema": "aten::_use_cudnn_rnn_flatten_weight() -> bool", "dispatch": "False", "default": "True"} +at::Tensor _cudnn_rnn_flatten_weight(at::TensorList weight_arr, int64_t weight_stride0, c10::SymInt input_size, int64_t mode, c10::SymInt hidden_size, c10::SymInt proj_size, int64_t num_layers, bool batch_first, bool bidirectional); // {"schema": "aten::_cudnn_rnn_flatten_weight(Tensor[] weight_arr, int weight_stride0, SymInt input_size, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, bool bidirectional) -> Tensor", "dispatch": "True", "default": "False"} +::std::tuple _cudnn_rnn(const at::Tensor & input, at::TensorList weight, int64_t weight_stride0, const ::std::optional & weight_buf, const at::Tensor & hx, const ::std::optional & cx, int64_t mode, c10::SymInt hidden_size, c10::SymInt proj_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, c10::SymIntArrayRef batch_sizes, const ::std::optional & dropout_state); // {"schema": "aten::_cudnn_rnn(Tensor input, Tensor[] weight, int weight_stride0, Tensor? weight_buf, Tensor hx, Tensor? cx, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, SymInt[] batch_sizes, Tensor? 
dropout_state) -> (Tensor, Tensor, Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"} +::std::tuple> _cudnn_rnn_backward(const at::Tensor & input, at::TensorList weight, int64_t weight_stride0, const at::Tensor & weight_buf, const at::Tensor & hx, const ::std::optional & cx, const at::Tensor & output, const ::std::optional & grad_output, const ::std::optional & grad_hy, const ::std::optional & grad_cy, int64_t mode, c10::SymInt hidden_size, c10::SymInt proj_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, c10::SymIntArrayRef batch_sizes, const ::std::optional & dropout_state, const at::Tensor & reserve, ::std::array output_mask); // {"schema": "aten::_cudnn_rnn_backward(Tensor input, Tensor[] weight, int weight_stride0, Tensor weight_buf, Tensor hx, Tensor? cx, Tensor output, Tensor? grad_output, Tensor? grad_hy, Tensor? grad_cy, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, SymInt[] batch_sizes, Tensor? dropout_state, Tensor reserve, bool[4] output_mask) -> (Tensor, Tensor, Tensor, Tensor[])", "dispatch": "True", "default": "False"} +at::Tensor _cudnn_init_dropout_state(double dropout, bool train, int64_t dropout_seed, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::_cudnn_init_dropout_state(float dropout, bool train, int dropout_seed, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor", "dispatch": "True", "default": "False"} +int64_t _debug_has_internal_overlap(const at::Tensor & self); // {"schema": "aten::_debug_has_internal_overlap(Tensor self) -> int", "dispatch": "False", "default": "True"} +::std::tuple _fused_dropout(const at::Tensor & self, double p, ::std::optional generator); // {"schema": "aten::_fused_dropout(Tensor self, float p, Generator? 
generator=None) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"} +at::Tensor _masked_scale(const at::Tensor & self, const at::Tensor & mask, double scale); // {"schema": "aten::_masked_scale(Tensor self, Tensor mask, float scale) -> Tensor", "dispatch": "True", "default": "False"} +::std::tuple native_dropout(const at::Tensor & input, double p, ::std::optional train); // {"schema": "aten::native_dropout(Tensor input, float p, bool? train) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"} +at::Tensor native_dropout_backward(const at::Tensor & grad_output, const at::Tensor & mask, double scale); // {"schema": "aten::native_dropout_backward(Tensor grad_output, Tensor mask, float scale) -> Tensor", "dispatch": "True", "default": "False"} +::std::tuple _sobol_engine_draw(const at::Tensor & quasi, int64_t n, const at::Tensor & sobolstate, int64_t dimension, int64_t num_generated, ::std::optional dtype); // {"schema": "aten::_sobol_engine_draw(Tensor quasi, int n, Tensor sobolstate, int dimension, int num_generated, ScalarType? dtype) -> (Tensor, Tensor)", "dispatch": "False", "default": "True"} +at::Tensor & _sobol_engine_ff_(at::Tensor & self, int64_t n, const at::Tensor & sobolstate, int64_t dimension, int64_t num_generated); // {"schema": "aten::_sobol_engine_ff_(Tensor(a!) self, int n, Tensor sobolstate, int dimension, int num_generated) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & _sobol_engine_scramble_(at::Tensor & self, const at::Tensor & ltm, int64_t dimension); // {"schema": "aten::_sobol_engine_scramble_(Tensor(a!) self, Tensor ltm, int dimension) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & _sobol_engine_initialize_state_(at::Tensor & self, int64_t dimension); // {"schema": "aten::_sobol_engine_initialize_state_(Tensor(a!) 
self, int dimension) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor _reshape_from_tensor(const at::Tensor & self, const at::Tensor & shape); // {"schema": "aten::_reshape_from_tensor(Tensor self, Tensor shape) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _shape_as_tensor(const at::Tensor & self); // {"schema": "aten::_shape_as_tensor(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor dropout(const at::Tensor & input, double p, bool train); // {"schema": "aten::dropout(Tensor input, float p, bool train) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & dropout_(at::Tensor & self, double p, bool train); // {"schema": "aten::dropout_(Tensor(a!) self, float p, bool train) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor feature_dropout(const at::Tensor & input, double p, bool train); // {"schema": "aten::feature_dropout(Tensor input, float p, bool train) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & feature_dropout_(at::Tensor & self, double p, bool train); // {"schema": "aten::feature_dropout_(Tensor(a!) self, float p, bool train) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor alpha_dropout(const at::Tensor & input, double p, bool train); // {"schema": "aten::alpha_dropout(Tensor input, float p, bool train) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & alpha_dropout_(at::Tensor & self, double p, bool train); // {"schema": "aten::alpha_dropout_(Tensor(a!) self, float p, bool train) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor feature_alpha_dropout(const at::Tensor & input, double p, bool train); // {"schema": "aten::feature_alpha_dropout(Tensor input, float p, bool train) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & feature_alpha_dropout_(at::Tensor & self, double p, bool train); // {"schema": "aten::feature_alpha_dropout_(Tensor(a!) 
self, float p, bool train) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor abs(const at::Tensor & self); // {"schema": "aten::abs(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & abs_(at::Tensor & self); // {"schema": "aten::abs_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & abs_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::abs.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor absolute(const at::Tensor & self); // {"schema": "aten::absolute(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & absolute_(at::Tensor & self); // {"schema": "aten::absolute_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & absolute_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::absolute.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor angle(const at::Tensor & self); // {"schema": "aten::angle(Tensor self) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & angle_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::angle.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor view_as_real(const at::Tensor & self); // {"schema": "aten::view_as_real(Tensor(a) self) -> Tensor(a)", "dispatch": "True", "default": "False"} +at::Tensor view_as_complex(const at::Tensor & self); // {"schema": "aten::view_as_complex(Tensor(a) self) -> Tensor(a)", "dispatch": "True", "default": "False"} +at::Tensor sgn(const at::Tensor & self); // {"schema": "aten::sgn(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & sgn_(at::Tensor & self); // {"schema": "aten::sgn_(Tensor(a!) 
self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & sgn_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::sgn.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor chalf(const at::Tensor & self, ::std::optional memory_format); // {"schema": "aten::chalf(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor real(const at::Tensor & self); // {"schema": "aten::real(Tensor(a) self) -> Tensor(a)", "dispatch": "False", "default": "True"} +at::Tensor imag(const at::Tensor & self); // {"schema": "aten::imag(Tensor(a) self) -> Tensor(a)", "dispatch": "False", "default": "True"} +at::Tensor _conj(const at::Tensor & self); // {"schema": "aten::_conj(Tensor(a) self) -> Tensor(a)", "dispatch": "True", "default": "True"} +at::Tensor conj(const at::Tensor & self); // {"schema": "aten::conj(Tensor(a) self) -> Tensor(a)", "dispatch": "False", "default": "True"} +at::Tensor _conj_physical(const at::Tensor & self); // {"schema": "aten::_conj_physical(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor conj_physical(const at::Tensor & self); // {"schema": "aten::conj_physical(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & conj_physical_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::conj_physical.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & conj_physical_(at::Tensor & self); // {"schema": "aten::conj_physical_(Tensor(a!) 
self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor resolve_conj(const at::Tensor & self); // {"schema": "aten::resolve_conj(Tensor(a) self) -> Tensor(a)", "dispatch": "False", "default": "True"} +at::Tensor resolve_neg(const at::Tensor & self); // {"schema": "aten::resolve_neg(Tensor(a) self) -> Tensor(a)", "dispatch": "False", "default": "True"} +at::Tensor _neg_view(const at::Tensor & self); // {"schema": "aten::_neg_view(Tensor(a) self) -> Tensor(a)", "dispatch": "True", "default": "True"} +at::Tensor acos(const at::Tensor & self); // {"schema": "aten::acos(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & acos_(at::Tensor & self); // {"schema": "aten::acos_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & acos_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::acos.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor arccos(const at::Tensor & self); // {"schema": "aten::arccos(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & arccos_(at::Tensor & self); // {"schema": "aten::arccos_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & arccos_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::arccos.out(Tensor self, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor avg_pool1d(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad); // {"schema": "aten::avg_pool1d(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, bool ceil_mode=False, bool count_include_pad=True) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor adaptive_avg_pool1d(const at::Tensor & self, at::IntArrayRef output_size); // {"schema": "aten::adaptive_avg_pool1d(Tensor self, int[1] output_size) -> Tensor", "dispatch": "False", "default": "True"} +::std::tuple adaptive_max_pool1d(const at::Tensor & self, at::IntArrayRef output_size); // {"schema": "aten::adaptive_max_pool1d(Tensor self, int[1] output_size) -> (Tensor, Tensor)", "dispatch": "False", "default": "True"} +at::Tensor add(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha); // {"schema": "aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & add_(at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha); // {"schema": "aten::add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & add_out(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha, at::Tensor & out); // {"schema": "aten::add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor _add_relu(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha); // {"schema": "aten::_add_relu.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & _add_relu_(at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha); // {"schema": "aten::_add_relu_.Tensor(Tensor(a!) 
self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & _add_relu_out(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha, at::Tensor & out); // {"schema": "aten::_add_relu.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor _add_relu(const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha); // {"schema": "aten::_add_relu.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & _add_relu_(at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha); // {"schema": "aten::_add_relu_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor add(const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha); // {"schema": "aten::add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & add_(at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha); // {"schema": "aten::add_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor addmv(const at::Tensor & self, const at::Tensor & mat, const at::Tensor & vec, const at::Scalar & beta, const at::Scalar & alpha); // {"schema": "aten::addmv(Tensor self, Tensor mat, Tensor vec, *, Scalar beta=1, Scalar alpha=1) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & addmv_(at::Tensor & self, const at::Tensor & mat, const at::Tensor & vec, const at::Scalar & beta, const at::Scalar & alpha); // {"schema": "aten::addmv_(Tensor(a!) 
self, Tensor mat, Tensor vec, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & addmv_out(const at::Tensor & self, const at::Tensor & mat, const at::Tensor & vec, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out); // {"schema": "aten::addmv.out(Tensor self, Tensor mat, Tensor vec, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor addr(const at::Tensor & self, const at::Tensor & vec1, const at::Tensor & vec2, const at::Scalar & beta, const at::Scalar & alpha); // {"schema": "aten::addr(Tensor self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & addr_(at::Tensor & self, const at::Tensor & vec1, const at::Tensor & vec2, const at::Scalar & beta, const at::Scalar & alpha); // {"schema": "aten::addr_(Tensor(a!) self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & addr_out(const at::Tensor & self, const at::Tensor & vec1, const at::Tensor & vec2, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out); // {"schema": "aten::addr.out(Tensor self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor affine_grid_generator(const at::Tensor & theta, c10::SymIntArrayRef size, bool align_corners); // {"schema": "aten::affine_grid_generator(Tensor theta, SymInt[] size, bool align_corners) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor affine_grid_generator_backward(const at::Tensor & grad, c10::SymIntArrayRef size, bool align_corners); // {"schema": "aten::affine_grid_generator_backward(Tensor grad, SymInt[] size, bool align_corners) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _is_all_true(const at::Tensor & self); // {"schema": "aten::_is_all_true(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor _is_any_true(const at::Tensor & self); // {"schema": "aten::_is_any_true(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor _test_check_tensor(const at::Tensor & self); // {"schema": "aten::_test_check_tensor(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _test_functorch_fallback(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::_test_functorch_fallback(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor all(const at::Tensor & self, int64_t dim, bool keepdim); // {"schema": "aten::all.dim(Tensor self, int dim, bool keepdim=False) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor all(const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim); // {"schema": "aten::all.dims(Tensor self, int[]? dim=None, bool keepdim=False) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & all_out(const at::Tensor & self, int64_t dim, bool keepdim, at::Tensor & out); // {"schema": "aten::all.out(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & all_out(const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim, at::Tensor & out); // {"schema": "aten::all.dims_out(Tensor self, int[]? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor all(const at::Tensor & self, at::Dimname dim, bool keepdim); // {"schema": "aten::all.dimname(Tensor self, Dimname dim, bool keepdim=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & all_out(const at::Tensor & self, at::Dimname dim, bool keepdim, at::Tensor & out); // {"schema": "aten::all.dimname_out(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +bool allclose(const at::Tensor & self, const at::Tensor & other, double rtol, double atol, bool equal_nan); // {"schema": "aten::allclose(Tensor self, Tensor other, float rtol=1e-05, float atol=1e-08, bool equal_nan=False) -> bool", "dispatch": "True", "default": "True"} +at::Tensor any(const at::Tensor & self, int64_t dim, bool keepdim); // {"schema": "aten::any.dim(Tensor self, int dim, bool keepdim=False) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor any(const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim); // {"schema": "aten::any.dims(Tensor self, int[]? dim=None, bool keepdim=False) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & any_out(const at::Tensor & self, int64_t dim, bool keepdim, at::Tensor & out); // {"schema": "aten::any.out(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & any_out(const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim, at::Tensor & out); // {"schema": "aten::any.dims_out(Tensor self, int[]? dim=None, bool keepdim=False, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor any(const at::Tensor & self, at::Dimname dim, bool keepdim); // {"schema": "aten::any.dimname(Tensor self, Dimname dim, bool keepdim=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & any_out(const at::Tensor & self, at::Dimname dim, bool keepdim, at::Tensor & out); // {"schema": "aten::any.dimname_out(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor arange(const at::Scalar & end, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::arange(Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor arange(const at::Scalar & start, const at::Scalar & end, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::arange.start(Scalar start, Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor arange(const at::Scalar & start, const at::Scalar & end, const at::Scalar & step, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::arange.start_step(Scalar start, Scalar end, Scalar step=1, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & arange_out(const at::Scalar & end, at::Tensor & out); // {"schema": "aten::arange.out(Scalar end, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & arange_out(const at::Scalar & start, const at::Scalar & end, const at::Scalar & step, at::Tensor & out); // {"schema": "aten::arange.start_out(Scalar start, Scalar end, Scalar step=1, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor _dim_arange(const at::Tensor & like, int64_t dim); // {"schema": "aten::_dim_arange(Tensor like, int dim) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor argmax(const at::Tensor & self, ::std::optional dim, bool keepdim); // {"schema": "aten::argmax(Tensor self, int? dim=None, bool keepdim=False) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & argmax_out(const at::Tensor & self, ::std::optional dim, bool keepdim, at::Tensor & out); // {"schema": "aten::argmax.out(Tensor self, int? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor argmin(const at::Tensor & self, ::std::optional dim, bool keepdim); // {"schema": "aten::argmin(Tensor self, int? dim=None, bool keepdim=False) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & argmin_out(const at::Tensor & self, ::std::optional dim, bool keepdim, at::Tensor & out); // {"schema": "aten::argmin.out(Tensor self, int? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor acosh(const at::Tensor & self); // {"schema": "aten::acosh(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & acosh_(at::Tensor & self); // {"schema": "aten::acosh_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & acosh_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::acosh.out(Tensor self, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor arccosh(const at::Tensor & self); // {"schema": "aten::arccosh(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & arccosh_(at::Tensor & self); // {"schema": "aten::arccosh_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & arccosh_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::arccosh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor asinh(const at::Tensor & self); // {"schema": "aten::asinh(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & asinh_(at::Tensor & self); // {"schema": "aten::asinh_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & asinh_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::asinh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor arcsinh(const at::Tensor & self); // {"schema": "aten::arcsinh(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & arcsinh_(at::Tensor & self); // {"schema": "aten::arcsinh_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & arcsinh_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::arcsinh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor atanh(const at::Tensor & self); // {"schema": "aten::atanh(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & atanh_(at::Tensor & self); // {"schema": "aten::atanh_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & atanh_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::atanh.out(Tensor self, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor arctanh(const at::Tensor & self); // {"schema": "aten::arctanh(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & arctanh_(at::Tensor & self); // {"schema": "aten::arctanh_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & arctanh_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::arctanh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor as_strided(const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset); // {"schema": "aten::as_strided(Tensor(a) self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor(a)", "dispatch": "True", "default": "False"} +const at::Tensor & as_strided_(const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset); // {"schema": "aten::as_strided_(Tensor(a!) self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor asin(const at::Tensor & self); // {"schema": "aten::asin(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & asin_(at::Tensor & self); // {"schema": "aten::asin_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & asin_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::asin.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor arcsin(const at::Tensor & self); // {"schema": "aten::arcsin(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & arcsin_(at::Tensor & self); // {"schema": "aten::arcsin_(Tensor(a!) 
self) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & arcsin_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::arcsin.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor atan(const at::Tensor & self); // {"schema": "aten::atan(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & atan_(at::Tensor & self); // {"schema": "aten::atan_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & atan_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::atan.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor arctan(const at::Tensor & self); // {"schema": "aten::arctan(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & arctan_(at::Tensor & self); // {"schema": "aten::arctan_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & arctan_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::arctan.out(Tensor self, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor atleast_1d(const at::Tensor & self); // {"schema": "aten::atleast_1d(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +::std::vector atleast_1d(at::TensorList tensors); // {"schema": "aten::atleast_1d.Sequence(Tensor[] tensors) -> Tensor[]", "dispatch": "False", "default": "True"} +at::Tensor atleast_2d(const at::Tensor & self); // {"schema": "aten::atleast_2d(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +::std::vector atleast_2d(at::TensorList tensors); // {"schema": "aten::atleast_2d.Sequence(Tensor[] tensors) -> Tensor[]", "dispatch": "False", "default": "True"} +at::Tensor atleast_3d(const at::Tensor & self); // {"schema": "aten::atleast_3d(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +::std::vector atleast_3d(at::TensorList tensors); // {"schema": "aten::atleast_3d.Sequence(Tensor[] tensors) -> Tensor[]", "dispatch": "False", "default": "True"} +at::Tensor baddbmm(const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta, const at::Scalar & alpha); // {"schema": "aten::baddbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & baddbmm_(at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta, const at::Scalar & alpha); // {"schema": "aten::baddbmm_(Tensor(a!) self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & baddbmm_out(const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out); // {"schema": "aten::baddbmm.out(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor baddbmm(const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, at::ScalarType out_dtype, const at::Scalar & beta, const at::Scalar & alpha); // {"schema": "aten::baddbmm.dtype(Tensor self, Tensor batch1, Tensor batch2, ScalarType out_dtype, *, Scalar beta=1, Scalar alpha=1) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & baddbmm_out(const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, at::ScalarType out_dtype, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out); // {"schema": "aten::baddbmm.dtype_out(Tensor self, Tensor batch1, Tensor batch2, ScalarType out_dtype, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor bartlett_window(int64_t window_length, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::bartlett_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor bartlett_window(int64_t window_length, bool periodic, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::bartlett_window.periodic(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor batch_norm(const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const ::std::optional & running_mean, const ::std::optional & running_var, bool training, double momentum, double eps, bool cudnn_enabled); // {"schema": "aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? 
running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor quantized_batch_norm(const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const at::Tensor & mean, const at::Tensor & var, double eps, double output_scale, int64_t output_zero_point); // {"schema": "aten::quantized_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor mean, Tensor var, float eps, float output_scale, int output_zero_point) -> Tensor", "dispatch": "True", "default": "False"} +::std::tuple _batch_norm_impl_index(const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const ::std::optional & running_mean, const ::std::optional & running_var, bool training, double momentum, double eps, bool cudnn_enabled); // {"schema": "aten::_batch_norm_impl_index(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> (Tensor, Tensor, Tensor, Tensor, int)", "dispatch": "False", "default": "True"} +::std::tuple _batch_norm_impl_index_backward(int64_t impl_index, const at::Tensor & input, const at::Tensor & grad_output, const ::std::optional & weight, const ::std::optional & running_mean, const ::std::optional & running_var, const ::std::optional & save_mean, const ::std::optional & save_var_transform, bool train, double eps, ::std::array output_mask, const at::Tensor & reservedSpace); // {"schema": "aten::_batch_norm_impl_index_backward(int impl_index, Tensor input, Tensor grad_output, Tensor? weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var_transform, bool train, float eps, bool[3] output_mask, Tensor reservedSpace) -> (Tensor, Tensor, Tensor)", "dispatch": "False", "default": "True"} +at::Tensor bernoulli(const at::Tensor & self, ::std::optional generator); // {"schema": "aten::bernoulli(Tensor self, *, Generator? 
generator=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & bernoulli_out(const at::Tensor & self, ::std::optional generator, at::Tensor & out); // {"schema": "aten::bernoulli.out(Tensor self, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & bernoulli_(at::Tensor & self, const at::Tensor & p, ::std::optional generator); // {"schema": "aten::bernoulli_.Tensor(Tensor(a!) self, Tensor p, *, Generator? generator=None) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & bernoulli_(at::Tensor & self, double p, ::std::optional generator); // {"schema": "aten::bernoulli_.float(Tensor(a!) self, float p=0.5, *, Generator? generator=None) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor bernoulli(const at::Tensor & self, double p, ::std::optional generator); // {"schema": "aten::bernoulli.p(Tensor self, float p, *, Generator? generator=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor bilinear(const at::Tensor & input1, const at::Tensor & input2, const at::Tensor & weight, const ::std::optional & bias); // {"schema": "aten::bilinear(Tensor input1, Tensor input2, Tensor weight, Tensor? bias=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor binary_cross_entropy(const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction); // {"schema": "aten::binary_cross_entropy(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & binary_cross_entropy_out(const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, at::Tensor & out); // {"schema": "aten::binary_cross_entropy.out(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor binary_cross_entropy_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction); // {"schema": "aten::binary_cross_entropy_backward(Tensor grad_output, Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & binary_cross_entropy_backward_out(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, at::Tensor & grad_input); // {"schema": "aten::binary_cross_entropy_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor binary_cross_entropy_with_logits(const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, const ::std::optional & pos_weight, int64_t reduction); // {"schema": "aten::binary_cross_entropy_with_logits(Tensor self, Tensor target, Tensor? weight=None, Tensor? pos_weight=None, int reduction=Mean) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor bincount(const at::Tensor & self, const ::std::optional & weights, c10::SymInt minlength); // {"schema": "aten::bincount(Tensor self, Tensor? weights=None, SymInt minlength=0) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor bitwise_not(const at::Tensor & self); // {"schema": "aten::bitwise_not(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & bitwise_not_(at::Tensor & self); // {"schema": "aten::bitwise_not_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & bitwise_not_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::bitwise_not.out(Tensor self, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & copysign_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::copysign.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor copysign(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::copysign.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & copysign_(at::Tensor & self, const at::Tensor & other); // {"schema": "aten::copysign_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor copysign(const at::Tensor & self, const at::Scalar & other); // {"schema": "aten::copysign.Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & copysign_(at::Tensor & self, const at::Scalar & other); // {"schema": "aten::copysign_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & copysign_out(const at::Tensor & self, const at::Scalar & other, at::Tensor & out); // {"schema": "aten::copysign.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor _lazy_clone(const at::Tensor & self); // {"schema": "aten::_lazy_clone(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor logical_not(const at::Tensor & self); // {"schema": "aten::logical_not(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & logical_not_(at::Tensor & self); // {"schema": "aten::logical_not_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & logical_not_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::logical_not.out(Tensor self, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor logical_xor(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::logical_xor(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & logical_xor_(at::Tensor & self, const at::Tensor & other); // {"schema": "aten::logical_xor_(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & logical_xor_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::logical_xor.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor logical_and(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::logical_and(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & logical_and_(at::Tensor & self, const at::Tensor & other); // {"schema": "aten::logical_and_(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & logical_and_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::logical_and.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor logical_or(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::logical_or(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & logical_or_(at::Tensor & self, const at::Tensor & other); // {"schema": "aten::logical_or_(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & logical_or_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::logical_or.out(Tensor self, Tensor other, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor blackman_window(int64_t window_length, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::blackman_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor blackman_window(int64_t window_length, bool periodic, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::blackman_window.periodic(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor bmm(const at::Tensor & self, const at::Tensor & mat2); // {"schema": "aten::bmm(Tensor self, Tensor mat2) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & bmm_out(const at::Tensor & self, const at::Tensor & mat2, at::Tensor & out); // {"schema": "aten::bmm.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor bmm(const at::Tensor & self, const at::Tensor & mat2, at::ScalarType out_dtype); // {"schema": "aten::bmm.dtype(Tensor self, Tensor mat2, ScalarType out_dtype) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & bmm_out(const at::Tensor & self, const at::Tensor & mat2, at::ScalarType out_dtype, at::Tensor & out); // {"schema": "aten::bmm.dtype_out(Tensor self, Tensor mat2, ScalarType out_dtype, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +::std::vector broadcast_tensors(at::TensorList tensors); // {"schema": "aten::broadcast_tensors(Tensor[] tensors) -> Tensor[]", "dispatch": "False", "default": "True"} +at::Tensor broadcast_to(const at::Tensor & self, c10::SymIntArrayRef size); // {"schema": "aten::broadcast_to(Tensor(a) self, SymInt[] size) -> Tensor(a)", "dispatch": "False", "default": "True"} +at::Tensor _sparse_broadcast_to(const at::Tensor & self, at::IntArrayRef size); // {"schema": "aten::_sparse_broadcast_to(Tensor(a) self, int[] size) -> Tensor(a)", "dispatch": "True", "default": "False"} +at::Tensor cat(const at::ITensorListRef & tensors, int64_t dim); // {"schema": "aten::cat(Tensor[] tensors, int dim=0) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & cat_out(const at::ITensorListRef & tensors, int64_t dim, at::Tensor & out); // {"schema": "aten::cat.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor cat(at::TensorList tensors, at::Dimname dim); // {"schema": "aten::cat.names(Tensor[] tensors, Dimname dim) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & cat_out(at::TensorList tensors, at::Dimname dim, at::Tensor & out); // {"schema": "aten::cat.names_out(Tensor[] tensors, Dimname dim, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor concat(at::TensorList tensors, int64_t dim); // {"schema": "aten::concat(Tensor[] tensors, int dim=0) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & concat_out(at::TensorList tensors, int64_t dim, at::Tensor & out); // {"schema": "aten::concat.out(Tensor[] tensors, int dim=0, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor concat(at::TensorList tensors, at::Dimname dim); // {"schema": "aten::concat.names(Tensor[] tensors, Dimname dim) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & concat_out(at::TensorList tensors, at::Dimname dim, at::Tensor & out); // {"schema": "aten::concat.names_out(Tensor[] tensors, Dimname dim, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor concatenate(at::TensorList tensors, int64_t dim); // {"schema": "aten::concatenate(Tensor[] tensors, int dim=0) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & concatenate_out(at::TensorList tensors, int64_t dim, at::Tensor & out); // {"schema": "aten::concatenate.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor concatenate(at::TensorList tensors, at::Dimname dim); // {"schema": "aten::concatenate.names(Tensor[] tensors, Dimname dim) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & concatenate_out(at::TensorList tensors, at::Dimname dim, at::Tensor & out); // {"schema": "aten::concatenate.names_out(Tensor[] tensors, Dimname dim, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor block_diag(at::TensorList tensors); // {"schema": "aten::block_diag(Tensor[] tensors) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor ceil(const at::Tensor & self); // {"schema": "aten::ceil(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & ceil_(at::Tensor & self); // {"schema": "aten::ceil_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & ceil_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::ceil.out(Tensor self, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor chain_matmul(at::TensorList matrices); // {"schema": "aten::chain_matmul(Tensor[] matrices) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & chain_matmul_out(at::TensorList matrices, at::Tensor & out); // {"schema": "aten::chain_matmul.out(Tensor[] matrices, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +::std::vector unsafe_chunk(const at::Tensor & self, int64_t chunks, int64_t dim); // {"schema": "aten::unsafe_chunk(Tensor self, int chunks, int dim=0) -> Tensor[]", "dispatch": "False", "default": "True"} +::std::vector chunk(const at::Tensor & self, int64_t chunks, int64_t dim); // {"schema": "aten::chunk(Tensor(a -> *) self, int chunks, int dim=0) -> Tensor(a)[]", "dispatch": "True", "default": "True"} +::std::vector tensor_split(const at::Tensor & self, c10::SymInt sections, int64_t dim); // {"schema": "aten::tensor_split.sections(Tensor(a -> *) self, SymInt sections, int dim=0) -> Tensor(a)[]", "dispatch": "False", "default": "True"} +::std::vector tensor_split(const at::Tensor & self, c10::SymIntArrayRef indices, int64_t dim); // {"schema": "aten::tensor_split.indices(Tensor(a -> *) self, SymInt[] indices, int dim=0) -> Tensor(a)[]", "dispatch": "False", "default": "True"} +::std::vector tensor_split(const at::Tensor & self, const at::Tensor & tensor_indices_or_sections, int64_t dim); // {"schema": "aten::tensor_split.tensor_indices_or_sections(Tensor(a -> *) self, Tensor tensor_indices_or_sections, int dim=0) -> Tensor(a)[]", "dispatch": "False", "default": "True"} +at::Tensor clamp(const at::Tensor & self, const ::std::optional & min, const ::std::optional & max); // {"schema": "aten::clamp(Tensor self, Scalar? min=None, Scalar? 
max=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor clamp(const at::Tensor & self, const ::std::optional & min, const ::std::optional & max); // {"schema": "aten::clamp.Tensor(Tensor self, Tensor? min=None, Tensor? max=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & clamp_(at::Tensor & self, const ::std::optional & min, const ::std::optional & max); // {"schema": "aten::clamp_(Tensor(a!) self, Scalar? min=None, Scalar? max=None) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & clamp_(at::Tensor & self, const ::std::optional & min, const ::std::optional & max); // {"schema": "aten::clamp_.Tensor(Tensor(a!) self, Tensor? min=None, Tensor? max=None) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & clamp_out(const at::Tensor & self, const ::std::optional & min, const ::std::optional & max, at::Tensor & out); // {"schema": "aten::clamp.out(Tensor self, Scalar? min=None, Scalar? max=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & clamp_out(const at::Tensor & self, const ::std::optional & min, const ::std::optional & max, at::Tensor & out); // {"schema": "aten::clamp.Tensor_out(Tensor self, Tensor? min=None, Tensor? max=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor clamp_max(const at::Tensor & self, const at::Scalar & max); // {"schema": "aten::clamp_max(Tensor self, Scalar max) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor clamp_max(const at::Tensor & self, const at::Tensor & max); // {"schema": "aten::clamp_max.Tensor(Tensor self, Tensor max) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & clamp_max_(at::Tensor & self, const at::Scalar & max); // {"schema": "aten::clamp_max_(Tensor(a!) 
self, Scalar max) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & clamp_max_(at::Tensor & self, const at::Tensor & max); // {"schema": "aten::clamp_max_.Tensor(Tensor(a!) self, Tensor max) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & clamp_max_out(const at::Tensor & self, const at::Scalar & max, at::Tensor & out); // {"schema": "aten::clamp_max.out(Tensor self, Scalar max, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & clamp_max_out(const at::Tensor & self, const at::Tensor & max, at::Tensor & out); // {"schema": "aten::clamp_max.Tensor_out(Tensor self, Tensor max, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor clamp_min(const at::Tensor & self, const at::Scalar & min); // {"schema": "aten::clamp_min(Tensor self, Scalar min) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor clamp_min(const at::Tensor & self, const at::Tensor & min); // {"schema": "aten::clamp_min.Tensor(Tensor self, Tensor min) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & clamp_min_(at::Tensor & self, const at::Scalar & min); // {"schema": "aten::clamp_min_(Tensor(a!) self, Scalar min) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & clamp_min_(at::Tensor & self, const at::Tensor & min); // {"schema": "aten::clamp_min_.Tensor(Tensor(a!) self, Tensor min) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & clamp_min_out(const at::Tensor & self, const at::Scalar & min, at::Tensor & out); // {"schema": "aten::clamp_min.out(Tensor self, Scalar min, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & clamp_min_out(const at::Tensor & self, const at::Tensor & min, at::Tensor & out); // {"schema": "aten::clamp_min.Tensor_out(Tensor self, Tensor min, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor clip(const at::Tensor & self, const ::std::optional & min, const ::std::optional & max); // {"schema": "aten::clip(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor clip(const at::Tensor & self, const ::std::optional & min, const ::std::optional & max); // {"schema": "aten::clip.Tensor(Tensor self, Tensor? min=None, Tensor? max=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & clip_(at::Tensor & self, const ::std::optional & min, const ::std::optional & max); // {"schema": "aten::clip_(Tensor(a!) self, Scalar? min=None, Scalar? max=None) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & clip_(at::Tensor & self, const ::std::optional & min, const ::std::optional & max); // {"schema": "aten::clip_.Tensor(Tensor(a!) self, Tensor? min=None, Tensor? max=None) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & clip_out(const at::Tensor & self, const ::std::optional & min, const ::std::optional & max, at::Tensor & out); // {"schema": "aten::clip.out(Tensor self, Scalar? min=None, Scalar? max=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & clip_out(const at::Tensor & self, const ::std::optional & min, const ::std::optional & max, at::Tensor & out); // {"schema": "aten::clip.Tensor_out(Tensor self, Tensor? min=None, Tensor? max=None, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +bool cudnn_is_acceptable(const at::Tensor & self); // {"schema": "aten::cudnn_is_acceptable(Tensor self) -> bool", "dispatch": "False", "default": "True"} +at::Tensor complex(const at::Tensor & real, const at::Tensor & imag); // {"schema": "aten::complex(Tensor real, Tensor imag) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & complex_out(const at::Tensor & real, const at::Tensor & imag, at::Tensor & out); // {"schema": "aten::complex.out(Tensor real, Tensor imag, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor polar(const at::Tensor & abs, const at::Tensor & angle); // {"schema": "aten::polar(Tensor abs, Tensor angle) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & polar_out(const at::Tensor & abs, const at::Tensor & angle, at::Tensor & out); // {"schema": "aten::polar.out(Tensor abs, Tensor angle, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor constant_pad_nd(const at::Tensor & self, c10::SymIntArrayRef pad, const at::Scalar & value); // {"schema": "aten::constant_pad_nd(Tensor self, SymInt[] pad, Scalar value=0) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor contiguous(const at::Tensor & self, at::MemoryFormat memory_format); // {"schema": "aten::contiguous(Tensor(a) self, *, MemoryFormat memory_format=contiguous_format) -> Tensor(a)", "dispatch": "False", "default": "True"} +at::Tensor convolution(const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups); // {"schema": "aten::convolution(Tensor input, Tensor weight, Tensor? 
bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups) -> Tensor", "dispatch": "True", "default": "True"} +::std::tuple convolution_backward(const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & weight, at::OptionalSymIntArrayRef bias_sizes, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups, ::std::array output_mask); // {"schema": "aten::convolution_backward(Tensor grad_output, Tensor input, Tensor weight, SymInt[]? bias_sizes, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor)", "dispatch": "True", "default": "True"} +at::Tensor convolution_overrideable(const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups); // {"schema": "aten::convolution_overrideable(Tensor input, Tensor weight, Tensor? 
bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups) -> Tensor", "dispatch": "True", "default": "True"} +::std::tuple convolution_backward_overrideable(const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & weight, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups, ::std::array output_mask); // {"schema": "aten::convolution_backward_overrideable(Tensor grad_output, Tensor input, Tensor weight, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias)", "dispatch": "True", "default": "True"} +at::Tensor _convolution(const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32); // {"schema": "aten::_convolution(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor _convolution(const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, c10::SymInt groups, bool benchmark, bool deterministic, bool cudnn_enabled); // {"schema": "aten::_convolution.deprecated(Tensor input, Tensor weight, Tensor? 
bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, int[] output_padding, SymInt groups, bool benchmark, bool deterministic, bool cudnn_enabled) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _convolution_mode(const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::string_view padding, c10::SymIntArrayRef dilation, c10::SymInt groups); // {"schema": "aten::_convolution_mode(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, str padding, SymInt[] dilation, SymInt groups) -> Tensor", "dispatch": "False", "default": "True"} +::std::tuple _convolution_double_backward(const ::std::optional & ggI, const ::std::optional & ggW, const ::std::optional & ggb, const at::Tensor & gO, const at::Tensor & weight, const at::Tensor & self, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups, ::std::array output_mask); // {"schema": "aten::_convolution_double_backward(Tensor? ggI, Tensor? ggW, Tensor? ggb, Tensor gO, Tensor weight, Tensor self, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor)", "dispatch": "False", "default": "True"} +at::Tensor conv1d(const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, c10::SymInt groups); // {"schema": "aten::conv1d(Tensor input, Tensor weight, Tensor? 
bias=None, SymInt[1] stride=1, SymInt[1] padding=0, SymInt[1] dilation=1, SymInt groups=1) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor conv2d(const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, c10::SymInt groups); // {"schema": "aten::conv2d(Tensor input, Tensor weight, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] dilation=1, SymInt groups=1) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor conv3d(const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, c10::SymInt groups); // {"schema": "aten::conv3d(Tensor input, Tensor weight, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] dilation=1, SymInt groups=1) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor conv1d(const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::string_view padding, c10::SymIntArrayRef dilation, c10::SymInt groups); // {"schema": "aten::conv1d.padding(Tensor input, Tensor weight, Tensor? bias=None, SymInt[1] stride=1, str padding=\"valid\", SymInt[1] dilation=1, SymInt groups=1) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor conv2d(const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::string_view padding, c10::SymIntArrayRef dilation, c10::SymInt groups); // {"schema": "aten::conv2d.padding(Tensor input, Tensor weight, Tensor? 
bias=None, SymInt[2] stride=1, str padding=\"valid\", SymInt[2] dilation=1, SymInt groups=1) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor conv3d(const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::string_view padding, c10::SymIntArrayRef dilation, c10::SymInt groups); // {"schema": "aten::conv3d.padding(Tensor input, Tensor weight, Tensor? bias=None, SymInt[3] stride=1, str padding=\"valid\", SymInt[3] dilation=1, SymInt groups=1) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor conv_tbc(const at::Tensor & self, const at::Tensor & weight, const at::Tensor & bias, int64_t pad); // {"schema": "aten::conv_tbc(Tensor self, Tensor weight, Tensor bias, int pad=0) -> Tensor", "dispatch": "True", "default": "True"} +::std::tuple conv_tbc_backward(const at::Tensor & self, const at::Tensor & input, const at::Tensor & weight, const at::Tensor & bias, int64_t pad); // {"schema": "aten::conv_tbc_backward(Tensor self, Tensor input, Tensor weight, Tensor bias, int pad) -> (Tensor, Tensor, Tensor)", "dispatch": "False", "default": "True"} +at::Tensor conv_transpose1d(const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymInt groups, c10::SymIntArrayRef dilation); // {"schema": "aten::conv_transpose1d(Tensor input, Tensor weight, Tensor? bias=None, SymInt[1] stride=1, SymInt[1] padding=0, SymInt[1] output_padding=0, SymInt groups=1, SymInt[1] dilation=1) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor conv_transpose2d(const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymInt groups, c10::SymIntArrayRef dilation); // {"schema": "aten::conv_transpose2d.input(Tensor input, Tensor weight, Tensor? 
bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] output_padding=0, SymInt groups=1, SymInt[2] dilation=1) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor conv_transpose3d(const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymInt groups, c10::SymIntArrayRef dilation); // {"schema": "aten::conv_transpose3d.input(Tensor input, Tensor weight, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] output_padding=0, SymInt groups=1, SymInt[3] dilation=1) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor copy(const at::Tensor & self, const at::Tensor & src, bool non_blocking); // {"schema": "aten::copy(Tensor self, Tensor src, bool non_blocking=False) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & copy_(at::Tensor & self, const at::Tensor & src, bool non_blocking); // {"schema": "aten::copy_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor _copy_from(const at::Tensor & self, const at::Tensor & dst, bool non_blocking); // {"schema": "aten::_copy_from(Tensor self, Tensor dst, bool non_blocking=False) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _copy_from_and_resize(const at::Tensor & self, const at::Tensor & dst); // {"schema": "aten::_copy_from_and_resize(Tensor self, Tensor dst) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor cos(const at::Tensor & self); // {"schema": "aten::cos(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & cos_(at::Tensor & self); // {"schema": "aten::cos_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & cos_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::cos.out(Tensor self, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor cosh(const at::Tensor & self); // {"schema": "aten::cosh(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & cosh_(at::Tensor & self); // {"schema": "aten::cosh_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & cosh_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::cosh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor cosine_embedding_loss(const at::Tensor & input1, const at::Tensor & input2, const at::Tensor & target, double margin, int64_t reduction); // {"schema": "aten::cosine_embedding_loss(Tensor input1, Tensor input2, Tensor target, float margin=0.0, int reduction=Mean) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor count_nonzero(const at::Tensor & self, at::IntArrayRef dim); // {"schema": "aten::count_nonzero.dim_IntList(Tensor self, int[] dim) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor count_nonzero(const at::Tensor & self, ::std::optional dim); // {"schema": "aten::count_nonzero(Tensor self, int? dim=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor cov(const at::Tensor & self, int64_t correction, const ::std::optional & fweights, const ::std::optional & aweights); // {"schema": "aten::cov(Tensor self, *, int correction=1, Tensor? fweights=None, Tensor? 
aweights=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor corrcoef(const at::Tensor & self); // {"schema": "aten::corrcoef(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor cudnn_affine_grid_generator(const at::Tensor & theta, int64_t N, int64_t C, int64_t H, int64_t W); // {"schema": "aten::cudnn_affine_grid_generator(Tensor theta, int N, int C, int H, int W) -> Tensor grid", "dispatch": "True", "default": "False"} +at::Tensor cudnn_affine_grid_generator_backward(const at::Tensor & grad, int64_t N, int64_t C, int64_t H, int64_t W); // {"schema": "aten::cudnn_affine_grid_generator_backward(Tensor grad, int N, int C, int H, int W) -> Tensor grad_theta", "dispatch": "True", "default": "False"} +::std::tuple cudnn_batch_norm(const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, const ::std::optional & running_mean, const ::std::optional & running_var, bool training, double exponential_average_factor, double epsilon); // {"schema": "aten::cudnn_batch_norm(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon) -> (Tensor, Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"} +::std::tuple cudnn_batch_norm_out(const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, const ::std::optional & running_mean, const ::std::optional & running_var, bool training, double exponential_average_factor, double epsilon, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3); // {"schema": "aten::cudnn_batch_norm.out(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) 
out3) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!))", "dispatch": "True", "default": "False"} +::std::tuple cudnn_batch_norm_backward(const at::Tensor & input, const at::Tensor & grad_output, const at::Tensor & weight, const ::std::optional & running_mean, const ::std::optional & running_var, const ::std::optional & save_mean, const ::std::optional & save_var, double epsilon, const at::Tensor & reserveSpace); // {"schema": "aten::cudnn_batch_norm_backward(Tensor input, Tensor grad_output, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, float epsilon, Tensor reserveSpace) -> (Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"} +at::Tensor cudnn_convolution(const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, bool benchmark, bool deterministic, bool allow_tf32); // {"schema": "aten::cudnn_convolution(Tensor self, Tensor weight, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, bool allow_tf32) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & cudnn_convolution_out(const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, bool benchmark, bool deterministic, bool allow_tf32, at::Tensor & out); // {"schema": "aten::cudnn_convolution.out(Tensor self, Tensor weight, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, bool allow_tf32, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor cudnn_convolution_transpose(const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, bool benchmark, bool deterministic, bool allow_tf32); // {"schema": "aten::cudnn_convolution_transpose(Tensor self, Tensor weight, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, bool allow_tf32) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _mps_convolution_transpose(const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups); // {"schema": "aten::_mps_convolution_transpose(Tensor self, Tensor weight, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups) -> Tensor", "dispatch": "True", "default": "False"} +::std::tuple mps_convolution_transpose_backward(const at::Tensor & self, const at::Tensor & grad_output, const at::Tensor & weight, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, ::std::array output_mask); // {"schema": "aten::mps_convolution_transpose_backward(Tensor self, Tensor grad_output, Tensor weight, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool[2] output_mask) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"} +at::Tensor cudnn_convolution_relu(const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, c10::SymInt groups); // {"schema": "aten::cudnn_convolution_relu(Tensor self, Tensor weight, Tensor? 
bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, SymInt groups) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor cudnn_convolution_add_relu(const at::Tensor & self, const at::Tensor & weight, const at::Tensor & z, const ::std::optional & alpha, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, c10::SymInt groups); // {"schema": "aten::cudnn_convolution_add_relu(Tensor self, Tensor weight, Tensor z, Scalar? alpha, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, SymInt groups) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor cudnn_grid_sampler(const at::Tensor & self, const at::Tensor & grid); // {"schema": "aten::cudnn_grid_sampler(Tensor self, Tensor grid) -> Tensor output", "dispatch": "True", "default": "False"} +::std::tuple cudnn_grid_sampler_backward(const at::Tensor & self, const at::Tensor & grid, const at::Tensor & grad_output); // {"schema": "aten::cudnn_grid_sampler_backward(Tensor self, Tensor grid, Tensor grad_output) -> (Tensor grad_self, Tensor grad_grid)", "dispatch": "True", "default": "False"} +::std::tuple cummax(const at::Tensor & self, int64_t dim); // {"schema": "aten::cummax(Tensor self, int dim) -> (Tensor values, Tensor indices)", "dispatch": "True", "default": "True"} +::std::tuple cummax_out(const at::Tensor & self, int64_t dim, at::Tensor & values, at::Tensor & indices); // {"schema": "aten::cummax.out(Tensor self, int dim, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) 
indices)", "dispatch": "True", "default": "True"} +::std::tuple cummax(const at::Tensor & self, at::Dimname dim); // {"schema": "aten::cummax.dimname(Tensor self, Dimname dim) -> (Tensor values, Tensor indices)", "dispatch": "False", "default": "True"} +::std::tuple cummax_out(const at::Tensor & self, at::Dimname dim, at::Tensor & values, at::Tensor & indices); // {"schema": "aten::cummax.dimname_out(Tensor self, Dimname dim, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)", "dispatch": "False", "default": "True"} +void _cummax_helper(const at::Tensor & self, at::Tensor & values, at::Tensor & indices, int64_t dim); // {"schema": "aten::_cummax_helper(Tensor self, Tensor(a!) values, Tensor(b!) indices, int dim) -> ()", "dispatch": "True", "default": "False"} +::std::tuple cummin(const at::Tensor & self, int64_t dim); // {"schema": "aten::cummin(Tensor self, int dim) -> (Tensor values, Tensor indices)", "dispatch": "True", "default": "True"} +::std::tuple cummin_out(const at::Tensor & self, int64_t dim, at::Tensor & values, at::Tensor & indices); // {"schema": "aten::cummin.out(Tensor self, int dim, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)", "dispatch": "True", "default": "True"} +::std::tuple cummin(const at::Tensor & self, at::Dimname dim); // {"schema": "aten::cummin.dimname(Tensor self, Dimname dim) -> (Tensor values, Tensor indices)", "dispatch": "False", "default": "True"} +::std::tuple cummin_out(const at::Tensor & self, at::Dimname dim, at::Tensor & values, at::Tensor & indices); // {"schema": "aten::cummin.dimname_out(Tensor self, Dimname dim, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)", "dispatch": "False", "default": "True"} +void _cummin_helper(const at::Tensor & self, at::Tensor & values, at::Tensor & indices, int64_t dim); // {"schema": "aten::_cummin_helper(Tensor self, Tensor(a!) values, Tensor(b!) 
indices, int dim) -> ()", "dispatch": "True", "default": "False"} +at::Tensor cummaxmin_backward(const at::Tensor & grad, const at::Tensor & input, const at::Tensor & indices, int64_t dim); // {"schema": "aten::cummaxmin_backward(Tensor grad, Tensor input, Tensor indices, int dim) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor cumprod(const at::Tensor & self, int64_t dim, ::std::optional dtype); // {"schema": "aten::cumprod(Tensor self, int dim, *, ScalarType? dtype=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & cumprod_(at::Tensor & self, int64_t dim, ::std::optional dtype); // {"schema": "aten::cumprod_(Tensor(a!) self, int dim, *, ScalarType? dtype=None) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & cumprod_out(const at::Tensor & self, int64_t dim, ::std::optional dtype, at::Tensor & out); // {"schema": "aten::cumprod.out(Tensor self, int dim, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor cumprod(const at::Tensor & self, at::Dimname dim, ::std::optional dtype); // {"schema": "aten::cumprod.dimname(Tensor self, Dimname dim, *, ScalarType? dtype=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & cumprod_(at::Tensor & self, at::Dimname dim, ::std::optional dtype); // {"schema": "aten::cumprod_.dimname(Tensor(a!) self, Dimname dim, *, ScalarType? dtype=None) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & cumprod_out(const at::Tensor & self, at::Dimname dim, ::std::optional dtype, at::Tensor & out); // {"schema": "aten::cumprod.dimname_out(Tensor self, Dimname dim, *, ScalarType? dtype=None, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor cumprod_backward(const at::Tensor & grad, const at::Tensor & input, int64_t dim, const at::Tensor & output); // {"schema": "aten::cumprod_backward(Tensor grad, Tensor input, int dim, Tensor output) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor cumsum(const at::Tensor & self, int64_t dim, ::std::optional dtype); // {"schema": "aten::cumsum(Tensor self, int dim, *, ScalarType? dtype=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & cumsum_(at::Tensor & self, int64_t dim, ::std::optional dtype); // {"schema": "aten::cumsum_(Tensor(a!) self, int dim, *, ScalarType? dtype=None) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & cumsum_out(const at::Tensor & self, int64_t dim, ::std::optional dtype, at::Tensor & out); // {"schema": "aten::cumsum.out(Tensor self, int dim, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor cumsum(const at::Tensor & self, at::Dimname dim, ::std::optional dtype); // {"schema": "aten::cumsum.dimname(Tensor self, Dimname dim, *, ScalarType? dtype=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & cumsum_(at::Tensor & self, at::Dimname dim, ::std::optional dtype); // {"schema": "aten::cumsum_.dimname(Tensor(a!) self, Dimname dim, *, ScalarType? dtype=None) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & cumsum_out(const at::Tensor & self, at::Dimname dim, ::std::optional dtype, at::Tensor & out); // {"schema": "aten::cumsum.dimname_out(Tensor self, Dimname dim, *, ScalarType? dtype=None, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor cumulative_trapezoid(const at::Tensor & y, const at::Tensor & x, int64_t dim); // {"schema": "aten::cumulative_trapezoid.x(Tensor y, Tensor x, *, int dim=-1) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor cumulative_trapezoid(const at::Tensor & y, const at::Scalar & dx, int64_t dim); // {"schema": "aten::cumulative_trapezoid.dx(Tensor y, *, Scalar dx=1, int dim=-1) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor ctc_loss(const at::Tensor & log_probs, const at::Tensor & targets, at::IntArrayRef input_lengths, at::IntArrayRef target_lengths, int64_t blank, int64_t reduction, bool zero_infinity); // {"schema": "aten::ctc_loss.IntList(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank=0, int reduction=Mean, bool zero_infinity=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor ctc_loss(const at::Tensor & log_probs, const at::Tensor & targets, const at::Tensor & input_lengths, const at::Tensor & target_lengths, int64_t blank, int64_t reduction, bool zero_infinity); // {"schema": "aten::ctc_loss.Tensor(Tensor log_probs, Tensor targets, Tensor input_lengths, Tensor target_lengths, int blank=0, int reduction=Mean, bool zero_infinity=False) -> Tensor", "dispatch": "False", "default": "True"} +::std::tuple _ctc_loss(const at::Tensor & log_probs, const at::Tensor & targets, at::IntArrayRef input_lengths, at::IntArrayRef target_lengths, int64_t blank, bool zero_infinity); // {"schema": "aten::_ctc_loss(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank=0, bool zero_infinity=False) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"} +::std::tuple _ctc_loss(const at::Tensor & log_probs, const at::Tensor & targets, const at::Tensor & input_lengths, const at::Tensor & target_lengths, int64_t blank, bool zero_infinity); // {"schema": "aten::_ctc_loss.Tensor(Tensor log_probs, 
Tensor targets, Tensor input_lengths, Tensor target_lengths, int blank=0, bool zero_infinity=False) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"} +at::Tensor _ctc_loss_backward(const at::Tensor & grad, const at::Tensor & log_probs, const at::Tensor & targets, at::IntArrayRef input_lengths, at::IntArrayRef target_lengths, const at::Tensor & neg_log_likelihood, const at::Tensor & log_alpha, int64_t blank, bool zero_infinity); // {"schema": "aten::_ctc_loss_backward(Tensor grad, Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, Tensor neg_log_likelihood, Tensor log_alpha, int blank, bool zero_infinity=False) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _ctc_loss_backward(const at::Tensor & grad, const at::Tensor & log_probs, const at::Tensor & targets, const at::Tensor & input_lengths, const at::Tensor & target_lengths, const at::Tensor & neg_log_likelihood, const at::Tensor & log_alpha, int64_t blank, bool zero_infinity); // {"schema": "aten::_ctc_loss_backward.Tensor(Tensor grad, Tensor log_probs, Tensor targets, Tensor input_lengths, Tensor target_lengths, Tensor neg_log_likelihood, Tensor log_alpha, int blank, bool zero_infinity=False) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor diag_embed(const at::Tensor & self, int64_t offset, int64_t dim1, int64_t dim2); // {"schema": "aten::diag_embed(Tensor self, int offset=0, int dim1=-2, int dim2=-1) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor diagflat(const at::Tensor & self, int64_t offset); // {"schema": "aten::diagflat(Tensor self, int offset=0) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor diagonal(const at::Tensor & self, int64_t offset, int64_t dim1, int64_t dim2); // {"schema": "aten::diagonal(Tensor(a) self, int offset=0, int dim1=0, int dim2=1) -> Tensor(a)", "dispatch": "True", "default": "True"} +at::Tensor linalg_diagonal(const at::Tensor & A, int64_t offset, int64_t dim1, int64_t dim2); // 
{"schema": "aten::linalg_diagonal(Tensor(a) A, *, int offset=0, int dim1=-2, int dim2=-1) -> Tensor(a)", "dispatch": "False", "default": "True"} +at::Tensor diagonal(const at::Tensor & self, at::Dimname outdim, at::Dimname dim1, at::Dimname dim2, int64_t offset); // {"schema": "aten::diagonal.Dimname(Tensor(a) self, *, Dimname outdim, Dimname dim1, Dimname dim2, int offset=0) -> Tensor(a)", "dispatch": "False", "default": "True"} +at::Tensor diagonal_backward(const at::Tensor & grad_output, c10::SymIntArrayRef input_sizes, int64_t offset, int64_t dim1, int64_t dim2); // {"schema": "aten::diagonal_backward(Tensor grad_output, SymInt[] input_sizes, int offset, int dim1, int dim2) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & fill_diagonal_(at::Tensor & self, const at::Scalar & fill_value, bool wrap); // {"schema": "aten::fill_diagonal_(Tensor(a!) self, Scalar fill_value, bool wrap=False) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor diff(const at::Tensor & self, int64_t n, int64_t dim, const ::std::optional & prepend, const ::std::optional & append); // {"schema": "aten::diff(Tensor self, int n=1, int dim=-1, Tensor? prepend=None, Tensor? append=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & diff_out(const at::Tensor & self, int64_t n, int64_t dim, const ::std::optional & prepend, const ::std::optional & append, at::Tensor & out); // {"schema": "aten::diff.out(Tensor self, int n=1, int dim=-1, Tensor? prepend=None, Tensor? append=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +::std::vector gradient(const at::Tensor & self, const ::std::optional & spacing, ::std::optional dim, int64_t edge_order); // {"schema": "aten::gradient.scalarint(Tensor self, *, Scalar? spacing=None, int? 
dim=None, int edge_order=1) -> Tensor[]", "dispatch": "False", "default": "True"} +::std::vector gradient(const at::Tensor & self, const at::Scalar & spacing, at::IntArrayRef dim, int64_t edge_order); // {"schema": "aten::gradient.scalararray(Tensor self, *, Scalar spacing, int[] dim, int edge_order=1) -> Tensor[]", "dispatch": "False", "default": "True"} +::std::vector gradient(const at::Tensor & self, at::IntArrayRef dim, int64_t edge_order); // {"schema": "aten::gradient.array(Tensor self, *, int[] dim, int edge_order=1) -> Tensor[]", "dispatch": "False", "default": "True"} +::std::vector gradient(const at::Tensor & self, at::ArrayRef spacing, ::std::optional dim, int64_t edge_order); // {"schema": "aten::gradient.scalarrayint(Tensor self, *, Scalar[] spacing, int? dim=None, int edge_order=1) -> Tensor[]", "dispatch": "False", "default": "True"} +::std::vector gradient(const at::Tensor & self, at::ArrayRef spacing, at::IntArrayRef dim, int64_t edge_order); // {"schema": "aten::gradient.scalarrayarray(Tensor self, *, Scalar[] spacing, int[] dim, int edge_order=1) -> Tensor[]", "dispatch": "False", "default": "True"} +::std::vector gradient(const at::Tensor & self, at::TensorList spacing, ::std::optional dim, int64_t edge_order); // {"schema": "aten::gradient.tensorarrayint(Tensor self, *, Tensor[] spacing, int? dim=None, int edge_order=1) -> Tensor[]", "dispatch": "False", "default": "True"} +::std::vector gradient(const at::Tensor & self, at::TensorList spacing, at::IntArrayRef dim, int64_t edge_order); // {"schema": "aten::gradient.tensorarray(Tensor self, *, Tensor[] spacing, int[] dim, int edge_order=1) -> Tensor[]", "dispatch": "False", "default": "True"} +at::Tensor div(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::div.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & div_(at::Tensor & self, const at::Tensor & other); // {"schema": "aten::div_.Tensor(Tensor(a!) 
self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & div_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::div.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor div(const at::Tensor & self, const at::Tensor & other, ::std::optional rounding_mode); // {"schema": "aten::div.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & div_(at::Tensor & self, const at::Tensor & other, ::std::optional rounding_mode); // {"schema": "aten::div_.Tensor_mode(Tensor(a!) self, Tensor other, *, str? rounding_mode) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & div_out(const at::Tensor & self, const at::Tensor & other, ::std::optional rounding_mode, at::Tensor & out); // {"schema": "aten::div.out_mode(Tensor self, Tensor other, *, str? rounding_mode, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor div(const at::Tensor & self, const at::Scalar & other); // {"schema": "aten::div.Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & div_(at::Tensor & self, const at::Scalar & other); // {"schema": "aten::div_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor div(const at::Tensor & self, const at::Scalar & other, ::std::optional rounding_mode); // {"schema": "aten::div.Scalar_mode(Tensor self, Scalar other, *, str? rounding_mode) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & div_(at::Tensor & self, const at::Scalar & other, ::std::optional rounding_mode); // {"schema": "aten::div_.Scalar_mode(Tensor(a!) self, Scalar other, *, str? 
rounding_mode) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor divide(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::divide.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & divide_(at::Tensor & self, const at::Tensor & other); // {"schema": "aten::divide_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & divide_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::divide.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor divide(const at::Tensor & self, const at::Scalar & other); // {"schema": "aten::divide.Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & divide_(at::Tensor & self, const at::Scalar & other); // {"schema": "aten::divide_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor divide(const at::Tensor & self, const at::Tensor & other, ::std::optional rounding_mode); // {"schema": "aten::divide.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & divide_(at::Tensor & self, const at::Tensor & other, ::std::optional rounding_mode); // {"schema": "aten::divide_.Tensor_mode(Tensor(a!) self, Tensor other, *, str? rounding_mode) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & divide_out(const at::Tensor & self, const at::Tensor & other, ::std::optional rounding_mode, at::Tensor & out); // {"schema": "aten::divide.out_mode(Tensor self, Tensor other, *, str? rounding_mode, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor divide(const at::Tensor & self, const at::Scalar & other, ::std::optional rounding_mode); // {"schema": "aten::divide.Scalar_mode(Tensor self, Scalar other, *, str? rounding_mode) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & divide_(at::Tensor & self, const at::Scalar & other, ::std::optional rounding_mode); // {"schema": "aten::divide_.Scalar_mode(Tensor(a!) self, Scalar other, *, str? rounding_mode) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor true_divide(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::true_divide.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & true_divide_(at::Tensor & self, const at::Tensor & other); // {"schema": "aten::true_divide_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & true_divide_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::true_divide.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor true_divide(const at::Tensor & self, const at::Scalar & other); // {"schema": "aten::true_divide.Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & true_divide_(at::Tensor & self, const at::Scalar & other); // {"schema": "aten::true_divide_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor dot(const at::Tensor & self, const at::Tensor & tensor); // {"schema": "aten::dot(Tensor self, Tensor tensor) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & dot_out(const at::Tensor & self, const at::Tensor & tensor, at::Tensor & out); // {"schema": "aten::dot.out(Tensor self, Tensor tensor, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor vdot(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::vdot(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & vdot_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::vdot.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor einsum(c10::string_view equation, at::TensorList tensors, at::OptionalIntArrayRef path); // {"schema": "aten::einsum(str equation, Tensor[] tensors, *, int[]? path=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor embedding(const at::Tensor & weight, const at::Tensor & indices, c10::SymInt padding_idx, bool scale_grad_by_freq, bool sparse); // {"schema": "aten::embedding(Tensor weight, Tensor indices, SymInt padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor embedding_backward(const at::Tensor & grad, const at::Tensor & indices, c10::SymInt num_weights, c10::SymInt padding_idx, bool scale_grad_by_freq, bool sparse); // {"schema": "aten::embedding_backward(Tensor grad, Tensor indices, SymInt num_weights, SymInt padding_idx, bool scale_grad_by_freq, bool sparse) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor embedding_dense_backward(const at::Tensor & grad_output, const at::Tensor & indices, c10::SymInt num_weights, c10::SymInt padding_idx, bool scale_grad_by_freq); // {"schema": "aten::embedding_dense_backward(Tensor grad_output, Tensor indices, SymInt num_weights, SymInt padding_idx, bool scale_grad_by_freq) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & embedding_renorm_(at::Tensor & self, const at::Tensor & indices, double max_norm, double norm_type); // {"schema": "aten::embedding_renorm_(Tensor(a!) 
self, Tensor indices, float max_norm, float norm_type) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor embedding_sparse_backward(const at::Tensor & grad, const at::Tensor & indices, int64_t num_weights, int64_t padding_idx, bool scale_grad_by_freq); // {"schema": "aten::embedding_sparse_backward(Tensor grad, Tensor indices, int num_weights, int padding_idx, bool scale_grad_by_freq) -> Tensor", "dispatch": "False", "default": "True"} +::std::tuple _embedding_bag_forward_only(const at::Tensor & weight, const at::Tensor & indices, const at::Tensor & offsets, bool scale_grad_by_freq, int64_t mode, bool sparse, const ::std::optional & per_sample_weights, bool include_last_offset, int64_t padding_idx); // {"schema": "aten::_embedding_bag_forward_only(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False, int padding_idx=-1) -> (Tensor, Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"} +::std::tuple _rowwise_prune(const at::Tensor & weight, const at::Tensor & mask, at::ScalarType compressed_indices_dtype); // {"schema": "aten::_rowwise_prune(Tensor weight, Tensor mask, ScalarType compressed_indices_dtype) -> (Tensor, Tensor)", "dispatch": "False", "default": "True"} +at::Tensor row_stack(at::TensorList tensors); // {"schema": "aten::row_stack(Tensor[] tensors) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & row_stack_out(at::TensorList tensors, at::Tensor & out); // {"schema": "aten::row_stack.out(Tensor[] tensors, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +::std::tuple embedding_bag(const at::Tensor & weight, const at::Tensor & indices, const at::Tensor & offsets, bool scale_grad_by_freq, int64_t mode, bool sparse, const ::std::optional & per_sample_weights, bool include_last_offset); // {"schema": "aten::embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False) -> (Tensor, Tensor, Tensor, Tensor)", "dispatch": "False", "default": "True"} +::std::tuple embedding_bag(const at::Tensor & weight, const at::Tensor & indices, const at::Tensor & offsets, bool scale_grad_by_freq, int64_t mode, bool sparse, const ::std::optional & per_sample_weights, bool include_last_offset, ::std::optional padding_idx); // {"schema": "aten::embedding_bag.padding_idx(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq, int mode, bool sparse, Tensor? per_sample_weights, bool include_last_offset, int? padding_idx) -> (Tensor, Tensor, Tensor, Tensor)", "dispatch": "False", "default": "True"} +::std::tuple _embedding_bag(const at::Tensor & weight, const at::Tensor & indices, const at::Tensor & offsets, bool scale_grad_by_freq, int64_t mode, bool sparse, const ::std::optional & per_sample_weights, bool include_last_offset, int64_t padding_idx); // {"schema": "aten::_embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? 
per_sample_weights=None, bool include_last_offset=False, int padding_idx=-1) -> (Tensor, Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"} +at::Tensor _embedding_bag_backward(const at::Tensor & grad, const at::Tensor & indices, const at::Tensor & offsets, const at::Tensor & offset2bag, const at::Tensor & bag_size, const at::Tensor & maximum_indices, c10::SymInt num_weights, bool scale_grad_by_freq, int64_t mode, bool sparse, const ::std::optional & per_sample_weights, int64_t padding_idx); // {"schema": "aten::_embedding_bag_backward(Tensor grad, Tensor indices, Tensor offsets, Tensor offset2bag, Tensor bag_size, Tensor maximum_indices, SymInt num_weights, bool scale_grad_by_freq, int mode, bool sparse, Tensor? per_sample_weights, int padding_idx=-1) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _embedding_bag_sparse_backward(const at::Tensor & grad, const at::Tensor & indices, const at::Tensor & offsets, const at::Tensor & offset2bag, const at::Tensor & bag_size, c10::SymInt num_weights, bool scale_grad_by_freq, int64_t mode, const ::std::optional & per_sample_weights, int64_t padding_idx); // {"schema": "aten::_embedding_bag_sparse_backward(Tensor grad, Tensor indices, Tensor offsets, Tensor offset2bag, Tensor bag_size, SymInt num_weights, bool scale_grad_by_freq, int mode, Tensor? per_sample_weights, int padding_idx=-1) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _embedding_bag_dense_backward(const at::Tensor & grad, const at::Tensor & indices, const at::Tensor & offset2bag, const at::Tensor & bag_size, const at::Tensor & maximum_indices, c10::SymInt num_weights, bool scale_grad_by_freq, int64_t mode, const ::std::optional & per_sample_weights, int64_t padding_idx); // {"schema": "aten::_embedding_bag_dense_backward(Tensor grad, Tensor indices, Tensor offset2bag, Tensor bag_size, Tensor maximum_indices, SymInt num_weights, bool scale_grad_by_freq, int mode, Tensor? 
per_sample_weights, int padding_idx=-1) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _embedding_bag_per_sample_weights_backward(const at::Tensor & grad, const at::Tensor & weight, const at::Tensor & indices, const at::Tensor & offsets, const at::Tensor & offset2bag, int64_t mode, int64_t padding_idx); // {"schema": "aten::_embedding_bag_per_sample_weights_backward(Tensor grad, Tensor weight, Tensor indices, Tensor offsets, Tensor offset2bag, int mode, int padding_idx=-1) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor empty(at::IntArrayRef size, ::std::optional names, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format); // {"schema": "aten::empty.names(int[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor empty(c10::SymIntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format); // {"schema": "aten::empty.memory_format(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor empty_permuted(c10::SymIntArrayRef size, at::IntArrayRef physical_layout, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::empty_permuted(SymInt[] size, int[] physical_layout, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor new_empty(const at::Tensor & self, c10::SymIntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::new_empty(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor new_empty_strided(const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::new_empty_strided(Tensor self, SymInt[] size, SymInt[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor new_full(const at::Tensor & self, c10::SymIntArrayRef size, const at::Scalar & fill_value, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::new_full(Tensor self, SymInt[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor new_zeros(const at::Tensor & self, c10::SymIntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::new_zeros(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor new_ones(const at::Tensor & self, c10::SymIntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::new_ones(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor _empty_affine_quantized(c10::SymIntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, double scale, int64_t zero_point, ::std::optional memory_format); // {"schema": "aten::_empty_affine_quantized(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, float scale=1, int zero_point=0, MemoryFormat? memory_format=contiguous_format) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _empty_per_channel_affine_quantized(c10::SymIntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format); // {"schema": "aten::_empty_per_channel_affine_quantized(SymInt[] size, *, Tensor scales, Tensor zero_points, int axis, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=contiguous_format) -> Tensor", "dispatch": "True", "default": "False"} +const at::Tensor & resize_(const at::Tensor & self, c10::SymIntArrayRef size, ::std::optional memory_format); // {"schema": "aten::resize_(Tensor(a!) self, SymInt[] size, *, MemoryFormat? memory_format=None) -> Tensor(a!)", "dispatch": "True", "default": "False"} +const at::Tensor & _resize_output_(const at::Tensor & self, c10::SymIntArrayRef size, at::Device device); // {"schema": "aten::_resize_output_(Tensor(a!) self, SymInt[] size, Device device) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor empty_quantized(at::IntArrayRef size, const at::Tensor & qtensor, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format); // {"schema": "aten::empty_quantized(int[] size, Tensor qtensor, *, ScalarType? dtype=None, Layout? 
layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & empty_out(c10::SymIntArrayRef size, ::std::optional memory_format, at::Tensor & out); // {"schema": "aten::empty.out(SymInt[] size, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor empty_like(const at::Tensor & self, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format); // {"schema": "aten::empty_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor empty_strided(c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::empty_strided(SymInt[] size, SymInt[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor erf(const at::Tensor & self); // {"schema": "aten::erf(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & erf_(at::Tensor & self); // {"schema": "aten::erf_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & erf_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::erf.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor erfc(const at::Tensor & self); // {"schema": "aten::erfc(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & erfc_(at::Tensor & self); // {"schema": "aten::erfc_(Tensor(a!) 
self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & erfc_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::erfc.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor exp(const at::Tensor & self); // {"schema": "aten::exp(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & exp_(at::Tensor & self); // {"schema": "aten::exp_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & exp_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::exp.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor exp2(const at::Tensor & self); // {"schema": "aten::exp2(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & exp2_(at::Tensor & self); // {"schema": "aten::exp2_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & exp2_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::exp2.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor expm1(const at::Tensor & self); // {"schema": "aten::expm1(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & expm1_(at::Tensor & self); // {"schema": "aten::expm1_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & expm1_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::expm1.out(Tensor self, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor expand(const at::Tensor & self, c10::SymIntArrayRef size, bool implicit); // {"schema": "aten::expand(Tensor(a) self, SymInt[] size, *, bool implicit=False) -> Tensor(a)", "dispatch": "True", "default": "True"} +at::Tensor expand_as(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::expand_as(Tensor(a) self, Tensor other) -> Tensor(a)", "dispatch": "False", "default": "True"} +at::Tensor eye(c10::SymInt n, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::eye(SymInt n, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor eye(c10::SymInt n, c10::SymInt m, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::eye.m(SymInt n, SymInt m, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & eye_out(c10::SymInt n, at::Tensor & out); // {"schema": "aten::eye.out(SymInt n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & eye_out(c10::SymInt n, c10::SymInt m, at::Tensor & out); // {"schema": "aten::eye.m_out(SymInt n, SymInt m, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor flatten(const at::Tensor & self, int64_t start_dim, int64_t end_dim); // {"schema": "aten::flatten.using_ints(Tensor(a) self, int start_dim=0, int end_dim=-1) -> Tensor(a)", "dispatch": "False", "default": "True"} +at::Tensor flatten(const at::Tensor & self, int64_t start_dim, int64_t end_dim, at::Dimname out_dim); // {"schema": "aten::flatten.named_out_dim(Tensor(a) self, int start_dim, int end_dim, Dimname out_dim) -> Tensor(a)", "dispatch": "False", "default": "True"} +at::Tensor flatten(const at::Tensor & self, at::Dimname start_dim, at::Dimname end_dim, at::Dimname out_dim); // {"schema": "aten::flatten.using_names(Tensor(a) self, Dimname start_dim, Dimname end_dim, Dimname out_dim) -> Tensor(a)", "dispatch": "False", "default": "True"} +at::Tensor flatten(const at::Tensor & self, at::DimnameList dims, at::Dimname out_dim); // {"schema": "aten::flatten.DimnameList(Tensor(a) self, Dimname[] dims, Dimname out_dim) -> Tensor(a)", "dispatch": "False", "default": "True"} +at::Tensor unflatten(const at::Tensor & self, int64_t dim, c10::SymIntArrayRef sizes); // {"schema": "aten::unflatten.int(Tensor(a) self, int dim, SymInt[] sizes) -> Tensor(a)", "dispatch": "False", "default": "True"} +at::Tensor unflatten(const at::Tensor & self, at::Dimname dim, c10::SymIntArrayRef sizes, at::DimnameList names); // {"schema": "aten::unflatten.Dimname(Tensor(a) self, Dimname dim, SymInt[] sizes, Dimname[] names) -> Tensor(a)", "dispatch": "False", "default": "True"} +at::Tensor fill(const at::Tensor & self, const at::Scalar & value); // {"schema": "aten::fill.Scalar(Tensor self, Scalar value) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor fill(const at::Tensor & self, const at::Tensor & value); // {"schema": "aten::fill.Tensor(Tensor self, Tensor value) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & fill_(at::Tensor & self, const at::Scalar & value); // {"schema": 
"aten::fill_.Scalar(Tensor(a!) self, Scalar value) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & fill_(at::Tensor & self, const at::Tensor & value); // {"schema": "aten::fill_.Tensor(Tensor(a!) self, Tensor value) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor floor(const at::Tensor & self); // {"schema": "aten::floor(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & floor_(at::Tensor & self); // {"schema": "aten::floor_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & floor_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::floor.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor floor_divide(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::floor_divide(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & floor_divide_(at::Tensor & self, const at::Tensor & other); // {"schema": "aten::floor_divide_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & floor_divide_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::floor_divide.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor floor_divide(const at::Tensor & self, const at::Scalar & other); // {"schema": "aten::floor_divide.Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & floor_divide_(at::Tensor & self, const at::Scalar & other); // {"schema": "aten::floor_divide_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor frac(const at::Tensor & self); // {"schema": "aten::frac(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & frac_(at::Tensor & self); // {"schema": "aten::frac_(Tensor(a!) 
self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & frac_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::frac.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor full(at::IntArrayRef size, const at::Scalar & fill_value, ::std::optional names, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::full.names(int[] size, Scalar fill_value, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor full(c10::SymIntArrayRef size, const at::Scalar & fill_value, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::full(SymInt[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & full_out(c10::SymIntArrayRef size, const at::Scalar & fill_value, at::Tensor & out); // {"schema": "aten::full.out(SymInt[] size, Scalar fill_value, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor full_like(const at::Tensor & self, const at::Scalar & fill_value, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format); // {"schema": "aten::full_like(Tensor self, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor from_file(c10::string_view filename, ::std::optional shared, ::std::optional size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::from_file(str filename, bool? shared=None, int? 
size=0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & gcd_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::gcd.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor gcd(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::gcd(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & gcd_(at::Tensor & self, const at::Tensor & other); // {"schema": "aten::gcd_(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & lcm_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::lcm.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor lcm(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::lcm(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & lcm_(at::Tensor & self, const at::Tensor & other); // {"schema": "aten::lcm_(Tensor(a!) 
self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor grid_sampler(const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners); // {"schema": "aten::grid_sampler(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor grid_sampler_2d(const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners); // {"schema": "aten::grid_sampler_2d(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners) -> Tensor", "dispatch": "True", "default": "False"} +::std::tuple grid_sampler_2d_backward(const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners, ::std::array output_mask); // {"schema": "aten::grid_sampler_2d_backward(Tensor grad_output, Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners, bool[2] output_mask) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"} +at::Tensor _grid_sampler_2d_cpu_fallback(const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners); // {"schema": "aten::_grid_sampler_2d_cpu_fallback(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners) -> Tensor", "dispatch": "True", "default": "True"} +::std::tuple _grid_sampler_2d_cpu_fallback_backward(const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners); // {"schema": "aten::_grid_sampler_2d_cpu_fallback_backward(Tensor grad_output, Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners) -> (Tensor, Tensor)", "dispatch": "False", "default": "True"} +at::Tensor 
grid_sampler_3d(const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners); // {"schema": "aten::grid_sampler_3d(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners) -> Tensor", "dispatch": "True", "default": "False"} +::std::tuple grid_sampler_3d_backward(const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners, ::std::array output_mask); // {"schema": "aten::grid_sampler_3d_backward(Tensor grad_output, Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners, bool[2] output_mask) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"} +at::Tensor hann_window(int64_t window_length, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::hann_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor hann_window(int64_t window_length, bool periodic, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::hann_window.periodic(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor hamming_window(int64_t window_length, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::hamming_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor hamming_window(int64_t window_length, bool periodic, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::hamming_window.periodic(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor hamming_window(int64_t window_length, bool periodic, double alpha, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::hamming_window.periodic_alpha(int window_length, bool periodic, float alpha, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor hamming_window(int64_t window_length, bool periodic, double alpha, double beta, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::hamming_window.periodic_alpha_beta(int window_length, bool periodic, float alpha, float beta, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor kaiser_window(int64_t window_length, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::kaiser_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor kaiser_window(int64_t window_length, bool periodic, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::kaiser_window.periodic(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor kaiser_window(int64_t window_length, bool periodic, double beta, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::kaiser_window.beta(int window_length, bool periodic, float beta, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor hinge_embedding_loss(const at::Tensor & self, const at::Tensor & target, double margin, int64_t reduction); // {"schema": "aten::hinge_embedding_loss(Tensor self, Tensor target, float margin=1.0, int reduction=Mean) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor group_norm(const at::Tensor & input, int64_t num_groups, const ::std::optional & weight, const ::std::optional & bias, double eps, bool cudnn_enabled); // {"schema": "aten::group_norm(Tensor input, int num_groups, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enabled=True) -> Tensor", "dispatch": "False", "default": "True"} +::std::tuple native_group_norm(const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, c10::SymInt N, c10::SymInt C, c10::SymInt HxW, int64_t group, double eps); // {"schema": "aten::native_group_norm(Tensor input, Tensor? weight, Tensor? bias, SymInt N, SymInt C, SymInt HxW, int group, float eps) -> (Tensor, Tensor, Tensor)", "dispatch": "True", "default": "True"} +::std::tuple native_group_norm_backward(const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & rstd, const ::std::optional & weight, c10::SymInt N, c10::SymInt C, c10::SymInt HxW, int64_t group, ::std::array output_mask); // {"schema": "aten::native_group_norm_backward(Tensor grad_out, Tensor input, Tensor mean, Tensor rstd, Tensor? 
weight, SymInt N, SymInt C, SymInt HxW, int group, bool[3] output_mask) -> (Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"} +at::Tensor _fft_r2c(const at::Tensor & self, at::IntArrayRef dim, int64_t normalization, bool onesided); // {"schema": "aten::_fft_r2c(Tensor self, int[] dim, int normalization, bool onesided) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & _fft_r2c_out(const at::Tensor & self, at::IntArrayRef dim, int64_t normalization, bool onesided, at::Tensor & out); // {"schema": "aten::_fft_r2c.out(Tensor self, int[] dim, int normalization, bool onesided, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor _fft_c2r(const at::Tensor & self, at::IntArrayRef dim, int64_t normalization, c10::SymInt last_dim_size); // {"schema": "aten::_fft_c2r(Tensor self, int[] dim, int normalization, SymInt last_dim_size) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & _fft_c2r_out(const at::Tensor & self, at::IntArrayRef dim, int64_t normalization, c10::SymInt last_dim_size, at::Tensor & out); // {"schema": "aten::_fft_c2r.out(Tensor self, int[] dim, int normalization, SymInt last_dim_size, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor _fft_c2c(const at::Tensor & self, c10::SymIntArrayRef dim, int64_t normalization, bool forward); // {"schema": "aten::_fft_c2c(Tensor self, SymInt[] dim, int normalization, bool forward) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & _fft_c2c_out(const at::Tensor & self, c10::SymIntArrayRef dim, int64_t normalization, bool forward, at::Tensor & out); // {"schema": "aten::_fft_c2c.out(Tensor self, SymInt[] dim, int normalization, bool forward, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +void _validate_compressed_sparse_indices(bool is_crow, const at::Tensor & compressed_idx, const at::Tensor & plain_idx, int64_t cdim, int64_t dim, int64_t nnz); // {"schema": "aten::_validate_compressed_sparse_indices(bool is_crow, Tensor compressed_idx, Tensor plain_idx, int cdim, int dim, int nnz) -> ()", "dispatch": "True", "default": "False"} +int64_t _cufft_get_plan_cache_size(at::DeviceIndex device_index); // {"schema": "aten::_cufft_get_plan_cache_size(DeviceIndex device_index) -> int", "dispatch": "False", "default": "True"} +int64_t _cufft_get_plan_cache_max_size(at::DeviceIndex device_index); // {"schema": "aten::_cufft_get_plan_cache_max_size(DeviceIndex device_index) -> int", "dispatch": "False", "default": "True"} +void _cufft_set_plan_cache_max_size(at::DeviceIndex device_index, int64_t max_size); // {"schema": "aten::_cufft_set_plan_cache_max_size(DeviceIndex device_index, int max_size) -> ()", "dispatch": "False", "default": "True"} +void _cufft_clear_plan_cache(at::DeviceIndex device_index); // {"schema": "aten::_cufft_clear_plan_cache(DeviceIndex device_index) -> ()", "dispatch": "False", "default": "True"} +at::Tensor index(const at::Tensor & self, const c10::List<::std::optional> & indices); // {"schema": "aten::index.Tensor(Tensor self, Tensor?[] indices) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & index_out(const at::Tensor & self, const c10::List<::std::optional> & indices, at::Tensor & out); // {"schema": "aten::index.Tensor_out(Tensor self, Tensor?[] indices, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor _unsafe_index(const at::Tensor & self, const c10::List<::std::optional> & indices); // {"schema": "aten::_unsafe_index.Tensor(Tensor self, Tensor?[] indices) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor _unsafe_masked_index(const at::Tensor & self, const at::Tensor & mask, const c10::List<::std::optional> & indices, const at::Scalar & fill); // {"schema": "aten::_unsafe_masked_index(Tensor self, Tensor mask, Tensor?[] indices, Scalar fill) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor _unsafe_masked_index_put_accumulate(const at::Tensor & self, const at::Tensor & mask, const c10::List<::std::optional> & indices, const at::Tensor & values); // {"schema": "aten::_unsafe_masked_index_put_accumulate(Tensor self, Tensor mask, Tensor?[] indices, Tensor values) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & index_copy_out(const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & source, at::Tensor & out); // {"schema": "aten::index_copy.out(Tensor self, int dim, Tensor index, Tensor source, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & index_copy_(at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & source); // {"schema": "aten::index_copy_(Tensor(a!) self, int dim, Tensor index, Tensor source) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor index_copy(const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & source); // {"schema": "aten::index_copy(Tensor self, int dim, Tensor index, Tensor source) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & index_copy_(at::Tensor & self, at::Dimname dim, const at::Tensor & index, const at::Tensor & source); // {"schema": "aten::index_copy_.dimname(Tensor(a!) 
self, Dimname dim, Tensor index, Tensor source) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor index_copy(const at::Tensor & self, at::Dimname dim, const at::Tensor & index, const at::Tensor & source); // {"schema": "aten::index_copy.dimname(Tensor self, Dimname dim, Tensor index, Tensor source) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & index_put_(at::Tensor & self, const c10::List<::std::optional> & indices, const at::Tensor & values, bool accumulate); // {"schema": "aten::index_put_(Tensor(a!) self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor index_put(const at::Tensor & self, const c10::List<::std::optional> & indices, const at::Tensor & values, bool accumulate); // {"schema": "aten::index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor _unsafe_index_put(const at::Tensor & self, const c10::List<::std::optional> & indices, const at::Tensor & values, bool accumulate); // {"schema": "aten::_unsafe_index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & _index_put_impl_(at::Tensor & self, const c10::List<::std::optional> & indices, const at::Tensor & values, bool accumulate, bool unsafe); // {"schema": "aten::_index_put_impl_(Tensor(a!) self, Tensor?[] indices, Tensor values, bool accumulate=False, bool unsafe=False) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor instance_norm(const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const ::std::optional & running_mean, const ::std::optional & running_var, bool use_input_stats, double momentum, double eps, bool cudnn_enabled); // {"schema": "aten::instance_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? 
running_var, bool use_input_stats, float momentum, float eps, bool cudnn_enabled) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor isclose(const at::Tensor & self, const at::Tensor & other, double rtol, double atol, bool equal_nan); // {"schema": "aten::isclose(Tensor self, Tensor other, float rtol=1e-05, float atol=1e-08, bool equal_nan=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & isin_out(const at::Tensor & elements, const at::Tensor & test_elements, bool assume_unique, bool invert, at::Tensor & out); // {"schema": "aten::isin.Tensor_Tensor_out(Tensor elements, Tensor test_elements, *, bool assume_unique=False, bool invert=False, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor isin(const at::Tensor & elements, const at::Tensor & test_elements, bool assume_unique, bool invert); // {"schema": "aten::isin.Tensor_Tensor(Tensor elements, Tensor test_elements, *, bool assume_unique=False, bool invert=False) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & isin_out(const at::Tensor & elements, const at::Scalar & test_element, bool assume_unique, bool invert, at::Tensor & out); // {"schema": "aten::isin.Tensor_Scalar_out(Tensor elements, Scalar test_element, *, bool assume_unique=False, bool invert=False, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor isin(const at::Tensor & elements, const at::Scalar & test_element, bool assume_unique, bool invert); // {"schema": "aten::isin.Tensor_Scalar(Tensor elements, Scalar test_element, *, bool assume_unique=False, bool invert=False) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & isin_out(const at::Scalar & element, const at::Tensor & test_elements, bool assume_unique, bool invert, at::Tensor & out); // {"schema": "aten::isin.Scalar_Tensor_out(Scalar element, Tensor test_elements, *, bool assume_unique=False, bool invert=False, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor isin(const at::Scalar & element, const at::Tensor & test_elements, bool assume_unique, bool invert); // {"schema": "aten::isin.Scalar_Tensor(Scalar element, Tensor test_elements, *, bool assume_unique=False, bool invert=False) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor isnan(const at::Tensor & self); // {"schema": "aten::isnan(Tensor self) -> Tensor", "dispatch": "True", "default": "False"} +bool is_distributed(const at::Tensor & self); // {"schema": "aten::is_distributed(Tensor self) -> bool", "dispatch": "False", "default": "True"} +bool is_floating_point(const at::Tensor & self); // {"schema": "aten::is_floating_point(Tensor self) -> bool", "dispatch": "False", "default": "True"} +bool is_complex(const at::Tensor & self); // {"schema": "aten::is_complex(Tensor self) -> bool", "dispatch": "False", "default": "True"} +bool is_conj(const at::Tensor & self); // {"schema": "aten::is_conj(Tensor self) -> bool", "dispatch": "False", "default": "True"} +bool _is_zerotensor(const at::Tensor & self); // {"schema": "aten::_is_zerotensor(Tensor self) -> bool", "dispatch": "False", "default": "True"} +bool is_neg(const at::Tensor & self); // {"schema": "aten::is_neg(Tensor self) -> bool", "dispatch": "False", "default": "True"} +at::Tensor isreal(const at::Tensor & self); // {"schema": "aten::isreal(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +bool is_nonzero(const at::Tensor & self); // {"schema": "aten::is_nonzero(Tensor self) -> bool", "dispatch": "False", "default": "True"} +bool is_same_size(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::is_same_size(Tensor self, Tensor other) -> bool", "dispatch": "True", "default": "True"} +bool is_signed(const at::Tensor & self); // {"schema": "aten::is_signed(Tensor self) -> bool", "dispatch": "False", "default": "True"} +bool is_inference(const at::Tensor & self); // {"schema": 
"aten::is_inference(Tensor self) -> bool", "dispatch": "False", "default": "True"} +at::Tensor kl_div(const at::Tensor & self, const at::Tensor & target, int64_t reduction, bool log_target); // {"schema": "aten::kl_div(Tensor self, Tensor target, int reduction=Mean, *, bool log_target=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor kron(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::kron(Tensor self, Tensor other) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & kron_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::kron.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +::std::tuple kthvalue(const at::Tensor & self, c10::SymInt k, int64_t dim, bool keepdim); // {"schema": "aten::kthvalue(Tensor self, SymInt k, int dim=-1, bool keepdim=False) -> (Tensor values, Tensor indices)", "dispatch": "True", "default": "True"} +::std::tuple kthvalue_out(const at::Tensor & self, c10::SymInt k, int64_t dim, bool keepdim, at::Tensor & values, at::Tensor & indices); // {"schema": "aten::kthvalue.values(Tensor self, SymInt k, int dim=-1, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)", "dispatch": "True", "default": "False"} +::std::tuple kthvalue(const at::Tensor & self, c10::SymInt k, at::Dimname dim, bool keepdim); // {"schema": "aten::kthvalue.dimname(Tensor self, SymInt k, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices)", "dispatch": "False", "default": "True"} +::std::tuple kthvalue_out(const at::Tensor & self, c10::SymInt k, at::Dimname dim, bool keepdim, at::Tensor & values, at::Tensor & indices); // {"schema": "aten::kthvalue.dimname_out(Tensor self, SymInt k, Dimname dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) 
indices)", "dispatch": "False", "default": "True"} +at::Tensor layer_norm(const at::Tensor & input, c10::SymIntArrayRef normalized_shape, const ::std::optional & weight, const ::std::optional & bias, double eps, bool cudnn_enable); // {"schema": "aten::layer_norm(Tensor input, SymInt[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor", "dispatch": "False", "default": "True"} +::std::tuple native_layer_norm(const at::Tensor & input, c10::SymIntArrayRef normalized_shape, const ::std::optional & weight, const ::std::optional & bias, double eps); // {"schema": "aten::native_layer_norm(Tensor input, SymInt[] normalized_shape, Tensor? weight, Tensor? bias, float eps) -> (Tensor, Tensor, Tensor)", "dispatch": "True", "default": "True"} +::std::tuple native_layer_norm_backward(const at::Tensor & grad_out, const at::Tensor & input, c10::SymIntArrayRef normalized_shape, const at::Tensor & mean, const at::Tensor & rstd, const ::std::optional & weight, const ::std::optional & bias, ::std::array output_mask); // {"schema": "aten::native_layer_norm_backward(Tensor grad_out, Tensor input, SymInt[] normalized_shape, Tensor mean, Tensor rstd, Tensor? weight, Tensor? bias, bool[3] output_mask) -> (Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"} +at::Tensor rms_norm(const at::Tensor & input, c10::SymIntArrayRef normalized_shape, const ::std::optional & weight, ::std::optional eps); // {"schema": "aten::rms_norm(Tensor input, SymInt[] normalized_shape, Tensor? weight=None, float? eps=None) -> Tensor", "dispatch": "False", "default": "True"} +::std::tuple _fused_rms_norm(const at::Tensor & input, at::IntArrayRef normalized_shape, const ::std::optional & weight, ::std::optional eps); // {"schema": "aten::_fused_rms_norm(Tensor input, int[] normalized_shape, Tensor? weight, float? 
eps) -> (Tensor, Tensor)", "dispatch": "True", "default": "True"} +::std::tuple _fused_rms_norm_backward(const at::Tensor & grad_out, const at::Tensor & input, at::IntArrayRef normalized_shape, const at::Tensor & rstd, const ::std::optional & weight, ::std::array output_mask); // {"schema": "aten::_fused_rms_norm_backward(Tensor grad_out, Tensor input, int[] normalized_shape, Tensor rstd, Tensor? weight, bool[2] output_mask) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"} +at::Tensor nan_to_num(const at::Tensor & self, ::std::optional nan, ::std::optional posinf, ::std::optional neginf); // {"schema": "aten::nan_to_num(Tensor self, float? nan=None, float? posinf=None, float? neginf=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & nan_to_num_(at::Tensor & self, ::std::optional nan, ::std::optional posinf, ::std::optional neginf); // {"schema": "aten::nan_to_num_(Tensor(a!) self, float? nan=None, float? posinf=None, float? neginf=None) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & nan_to_num_out(const at::Tensor & self, ::std::optional nan, ::std::optional posinf, ::std::optional neginf, at::Tensor & out); // {"schema": "aten::nan_to_num.out(Tensor self, float? nan=None, float? posinf=None, float? neginf=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor linear(const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias); // {"schema": "aten::linear(Tensor input, Tensor weight, Tensor? 
bias=None) -> Tensor", "dispatch": "True", "default": "True"} +::std::tuple linear_backward(const at::Tensor & self, const at::Tensor & grad_output, const at::Tensor & weight, ::std::array output_mask); // {"schema": "aten::linear_backward(Tensor self, Tensor grad_output, Tensor weight, bool[3] output_mask) -> (Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"} +at::Tensor & linear_out(const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, at::Tensor & out); // {"schema": "aten::linear.out(Tensor input, Tensor weight, Tensor? bias=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor mkldnn_linear(const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias); // {"schema": "aten::mkldnn_linear(Tensor self, Tensor weight, Tensor? bias=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor mkldnn_linear_backward_input(at::IntArrayRef input_size, const at::Tensor & grad_output, const at::Tensor & weight); // {"schema": "aten::mkldnn_linear_backward_input(int[] input_size, Tensor grad_output, Tensor weight) -> Tensor", "dispatch": "True", "default": "False"} +::std::tuple mkldnn_linear_backward_weights(const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & weight, bool bias_defined); // {"schema": "aten::mkldnn_linear_backward_weights(Tensor grad_output, Tensor input, Tensor weight, bool bias_defined) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"} +::std::tuple mkldnn_linear_backward(const at::Tensor & self, const at::Tensor & grad_output, const at::Tensor & weight, ::std::array output_mask); // {"schema": "aten::mkldnn_linear_backward(Tensor self, Tensor grad_output, Tensor weight, bool[3] output_mask) -> (Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"} +at::Tensor _cslt_compress(const at::Tensor & input); // {"schema": "aten::_cslt_compress(Tensor input) -> Tensor", "dispatch": "True", "default": 
"False"} +at::Tensor _cslt_sparse_mm(const at::Tensor & compressed_A, const at::Tensor & dense_B, const ::std::optional & bias, const ::std::optional & alpha, ::std::optional out_dtype, bool transpose_result, int64_t alg_id, int64_t split_k, int64_t split_k_mode); // {"schema": "aten::_cslt_sparse_mm(Tensor compressed_A, Tensor dense_B, Tensor? bias=None, Tensor? alpha=None, ScalarType? out_dtype=None, bool transpose_result=False, int alg_id=0, int split_k=1, int split_k_mode=-1) -> Tensor", "dispatch": "True", "default": "False"} +int64_t _cslt_sparse_mm_search(const at::Tensor & compressed_A, const at::Tensor & dense_B, const ::std::optional & bias, const ::std::optional & alpha, ::std::optional out_dtype, bool transpose_result); // {"schema": "aten::_cslt_sparse_mm_search(Tensor compressed_A, Tensor dense_B, Tensor? bias=None, Tensor? alpha=None, ScalarType? out_dtype=None, bool transpose_result=False) -> int", "dispatch": "True", "default": "False"} +::std::tuple _sparse_semi_structured_tile(const at::Tensor & input, c10::string_view algorithm, bool use_cutlass); // {"schema": "aten::_sparse_semi_structured_tile(Tensor input, str algorithm=\"\", bool use_cutlass=True) -> (Tensor, Tensor, Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"} +::std::tuple _sparse_semi_structured_apply(const at::Tensor & input, const at::Tensor & thread_masks); // {"schema": "aten::_sparse_semi_structured_apply(Tensor input, Tensor thread_masks) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"} +at::Tensor _sparse_semi_structured_apply_dense(const at::Tensor & input, const at::Tensor & thread_masks); // {"schema": "aten::_sparse_semi_structured_apply_dense(Tensor input, Tensor thread_masks) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _sparse_semi_structured_linear(const at::Tensor & input, const at::Tensor & weight, const at::Tensor & meta, const ::std::optional & bias, ::std::optional activation, ::std::optional out_dtype); // 
{"schema": "aten::_sparse_semi_structured_linear(Tensor input, Tensor weight, Tensor meta, *, Tensor? bias=None, str? activation=None, ScalarType? out_dtype=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _sparse_semi_structured_mm(const at::Tensor & mat1, const at::Tensor & mat1_meta, const at::Tensor & mat2, ::std::optional out_dtype); // {"schema": "aten::_sparse_semi_structured_mm(Tensor mat1, Tensor mat1_meta, Tensor mat2, *, ScalarType? out_dtype=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _sparse_semi_structured_addmm(const at::Tensor & input, const at::Tensor & mat1, const at::Tensor & mat1_meta, const at::Tensor & mat2, const at::Scalar & alpha, const at::Scalar & beta, ::std::optional out_dtype); // {"schema": "aten::_sparse_semi_structured_addmm(Tensor input, Tensor mat1, Tensor mat1_meta, Tensor mat2, *, Scalar alpha=1, Scalar beta=1, ScalarType? out_dtype=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _mixed_dtypes_linear(const at::Tensor & input, const at::Tensor & weight, const at::Tensor & scale, const ::std::optional & bias, ::std::optional activation); // {"schema": "aten::_mixed_dtypes_linear(Tensor input, Tensor weight, Tensor scale, *, Tensor? bias=None, str? 
activation=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor fbgemm_linear_int8_weight_fp32_activation(const at::Tensor & input, const at::Tensor & weight, const at::Tensor & packed, const at::Tensor & col_offsets, const at::Scalar & weight_scale, const at::Scalar & weight_zero_point, const at::Tensor & bias); // {"schema": "aten::fbgemm_linear_int8_weight_fp32_activation(Tensor input, Tensor weight, Tensor packed, Tensor col_offsets, Scalar weight_scale, Scalar weight_zero_point, Tensor bias) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor fbgemm_linear_int8_weight(const at::Tensor & input, const at::Tensor & weight, const at::Tensor & packed, const at::Tensor & col_offsets, const at::Scalar & weight_scale, const at::Scalar & weight_zero_point, const at::Tensor & bias); // {"schema": "aten::fbgemm_linear_int8_weight(Tensor input, Tensor weight, Tensor packed, Tensor col_offsets, Scalar weight_scale, Scalar weight_zero_point, Tensor bias) -> Tensor", "dispatch": "False", "default": "True"} +::std::tuple fbgemm_linear_quantize_weight(const at::Tensor & input); // {"schema": "aten::fbgemm_linear_quantize_weight(Tensor input) -> (Tensor, Tensor, float, int)", "dispatch": "False", "default": "True"} +at::Tensor fbgemm_pack_gemm_matrix_fp16(const at::Tensor & input); // {"schema": "aten::fbgemm_pack_gemm_matrix_fp16(Tensor input) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _wrapped_linear_prepack(const at::Tensor & weight, const at::Tensor & weight_scale, const at::Tensor & weight_zero_point, const at::Tensor & bias); // {"schema": "aten::_wrapped_linear_prepack(Tensor weight, Tensor weight_scale, Tensor weight_zero_point, Tensor bias) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _wrapped_quantized_linear_prepacked(const at::Tensor & input, const at::Tensor & input_scale, const at::Tensor & input_zero_point, const at::Tensor & packed_weight, const at::Tensor & output_scale, const at::Tensor & 
output_zero_point, int64_t out_channel); // {"schema": "aten::_wrapped_quantized_linear_prepacked(Tensor input, Tensor input_scale, Tensor input_zero_point, Tensor packed_weight, Tensor output_scale, Tensor output_zero_point, int out_channel) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor fbgemm_linear_fp16_weight_fp32_activation(const at::Tensor & input, const at::Tensor & packed_weight, const ::std::optional & bias); // {"schema": "aten::fbgemm_linear_fp16_weight_fp32_activation(Tensor input, Tensor packed_weight, Tensor? bias) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor fbgemm_linear_fp16_weight_fp32_activation(const at::Tensor & input, const at::Tensor & packed_weight, const ::std::optional & bias, at::Tensor & output); // {"schema": "aten::fbgemm_linear_fp16_weight_fp32_activation.out(Tensor input, Tensor packed_weight, Tensor? bias, Tensor(a!) output) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor fbgemm_linear_fp16_weight(const at::Tensor & input, const at::Tensor & packed_weight, const at::Tensor & bias); // {"schema": "aten::fbgemm_linear_fp16_weight(Tensor input, Tensor packed_weight, Tensor bias) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor fbgemm_linear_fp16_weight(const at::Tensor & input, const at::Tensor & packed_weight, const at::Tensor & bias, at::Tensor & output); // {"schema": "aten::fbgemm_linear_fp16_weight.out(Tensor input, Tensor packed_weight, Tensor bias, Tensor(a!) 
output) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor fbgemm_pack_quantized_matrix(const at::Tensor & input); // {"schema": "aten::fbgemm_pack_quantized_matrix(Tensor input) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor fbgemm_pack_quantized_matrix(const at::Tensor & input, int64_t K, int64_t N); // {"schema": "aten::fbgemm_pack_quantized_matrix.KN(Tensor input, int K, int N) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor ldexp(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::ldexp.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & ldexp_(at::Tensor & self, const at::Tensor & other); // {"schema": "aten::ldexp_(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & ldexp_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::ldexp.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor linspace(const at::Scalar & start, const at::Scalar & end, int64_t steps, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::linspace(Scalar start, Scalar end, int steps, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor linspace(const at::Tensor & start, const at::Tensor & end, int64_t steps, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::linspace.Tensor_Tensor(Tensor start, Tensor end, int steps, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor linspace(const at::Tensor & start, const at::Scalar & end, int64_t steps, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::linspace.Tensor_Scalar(Tensor start, Scalar end, int steps, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor linspace(const at::Scalar & start, const at::Tensor & end, int64_t steps, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::linspace.Scalar_Tensor(Scalar start, Tensor end, int steps, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & linspace_out(const at::Scalar & start, const at::Scalar & end, int64_t steps, at::Tensor & out); // {"schema": "aten::linspace.out(Scalar start, Scalar end, int steps, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & linspace_out(const at::Tensor & start, const at::Tensor & end, int64_t steps, at::Tensor & out); // {"schema": "aten::linspace.Tensor_Tensor_out(Tensor start, Tensor end, int steps, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & linspace_out(const at::Tensor & start, const at::Scalar & end, int64_t steps, at::Tensor & out); // {"schema": "aten::linspace.Tensor_Scalar_out(Tensor start, Scalar end, int steps, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & linspace_out(const at::Scalar & start, const at::Tensor & end, int64_t steps, at::Tensor & out); // {"schema": "aten::linspace.Scalar_Tensor_out(Scalar start, Tensor end, int steps, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor log(const at::Tensor & self); // {"schema": "aten::log(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & log_(at::Tensor & self); // {"schema": "aten::log_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & log_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::log.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor log10(const at::Tensor & self); // {"schema": "aten::log10(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & log10_(at::Tensor & self); // {"schema": "aten::log10_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & log10_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::log10.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor log1p(const at::Tensor & self); // {"schema": "aten::log1p(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & log1p_(at::Tensor & self); // {"schema": "aten::log1p_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & log1p_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::log1p.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor log2(const at::Tensor & self); // {"schema": "aten::log2(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & log2_(at::Tensor & self); // {"schema": "aten::log2_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & log2_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::log2.out(Tensor self, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & logaddexp_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::logaddexp.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor logaddexp(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::logaddexp(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & logaddexp2_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::logaddexp2.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor logaddexp2(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::logaddexp2(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor xlogy(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::xlogy.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor xlogy(const at::Scalar & self, const at::Tensor & other); // {"schema": "aten::xlogy.Scalar_Self(Scalar self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor xlogy(const at::Tensor & self, const at::Scalar & other); // {"schema": "aten::xlogy.Scalar_Other(Tensor self, Scalar other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & xlogy_(at::Tensor & self, const at::Tensor & other); // {"schema": "aten::xlogy_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & xlogy_(at::Tensor & self, const at::Scalar & other); // {"schema": "aten::xlogy_.Scalar_Other(Tensor(a!) 
self, Scalar other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & xlogy_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::xlogy.OutTensor(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & xlogy_out(const at::Scalar & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::xlogy.OutScalar_Self(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & xlogy_out(const at::Tensor & self, const at::Scalar & other, at::Tensor & out); // {"schema": "aten::xlogy.OutScalar_Other(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor logspace(const at::Scalar & start, const at::Scalar & end, int64_t steps, double base, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::logspace(Scalar start, Scalar end, int steps, float base=10.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor logspace(const at::Tensor & start, const at::Tensor & end, int64_t steps, double base, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::logspace.Tensor_Tensor(Tensor start, Tensor end, int steps, float base=10.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor logspace(const at::Tensor & start, const at::Scalar & end, int64_t steps, double base, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::logspace.Tensor_Scalar(Tensor start, Scalar end, int steps, float base=10.0, *, ScalarType? dtype=None, Layout? 
layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor logspace(const at::Scalar & start, const at::Tensor & end, int64_t steps, double base, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::logspace.Scalar_Tensor(Scalar start, Tensor end, int steps, float base=10.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & logspace_out(const at::Scalar & start, const at::Scalar & end, int64_t steps, double base, at::Tensor & out); // {"schema": "aten::logspace.out(Scalar start, Scalar end, int steps, float base=10.0, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & logspace_out(const at::Tensor & start, const at::Tensor & end, int64_t steps, double base, at::Tensor & out); // {"schema": "aten::logspace.Tensor_Tensor_out(Tensor start, Tensor end, int steps, float base=10.0, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & logspace_out(const at::Tensor & start, const at::Scalar & end, int64_t steps, double base, at::Tensor & out); // {"schema": "aten::logspace.Tensor_Scalar_out(Tensor start, Scalar end, int steps, float base=10.0, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & logspace_out(const at::Scalar & start, const at::Tensor & end, int64_t steps, double base, at::Tensor & out); // {"schema": "aten::logspace.Scalar_Tensor_out(Scalar start, Tensor end, int steps, float base=10.0, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor log_softmax(const at::Tensor & self, int64_t dim, ::std::optional dtype); // {"schema": "aten::log_softmax.int(Tensor self, int dim, ScalarType? 
dtype=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & log_softmax_out(const at::Tensor & self, int64_t dim, ::std::optional dtype, at::Tensor & out); // {"schema": "aten::log_softmax.int_out(Tensor self, int dim, ScalarType? dtype=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor log_softmax(const at::Tensor & self, at::Dimname dim, ::std::optional dtype); // {"schema": "aten::log_softmax.Dimname(Tensor self, Dimname dim, *, ScalarType? dtype=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _log_softmax(const at::Tensor & self, int64_t dim, bool half_to_float); // {"schema": "aten::_log_softmax(Tensor self, int dim, bool half_to_float) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & _log_softmax_out(const at::Tensor & self, int64_t dim, bool half_to_float, at::Tensor & out); // {"schema": "aten::_log_softmax.out(Tensor self, int dim, bool half_to_float, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor _log_softmax_backward_data(const at::Tensor & grad_output, const at::Tensor & output, int64_t dim, at::ScalarType input_dtype); // {"schema": "aten::_log_softmax_backward_data(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & _log_softmax_backward_data_out(const at::Tensor & grad_output, const at::Tensor & output, int64_t dim, at::ScalarType input_dtype, at::Tensor & out); // {"schema": "aten::_log_softmax_backward_data.out(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor _logcumsumexp(const at::Tensor & self, int64_t dim); // {"schema": "aten::_logcumsumexp(Tensor self, int dim) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & _logcumsumexp_out(const at::Tensor & self, int64_t dim, at::Tensor & out); // {"schema": "aten::_logcumsumexp.out(Tensor self, int dim, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor logcumsumexp(const at::Tensor & self, int64_t dim); // {"schema": "aten::logcumsumexp(Tensor self, int dim) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & logcumsumexp_out(const at::Tensor & self, int64_t dim, at::Tensor & out); // {"schema": "aten::logcumsumexp.out(Tensor self, int dim, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor logcumsumexp(const at::Tensor & self, at::Dimname dim); // {"schema": "aten::logcumsumexp.dimname(Tensor self, Dimname dim) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & logcumsumexp_out(const at::Tensor & self, at::Dimname dim, at::Tensor & out); // {"schema": "aten::logcumsumexp.dimname_out(Tensor self, Dimname dim, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor logsumexp(const at::Tensor & self, at::IntArrayRef dim, bool keepdim); // {"schema": "aten::logsumexp(Tensor self, int[1] dim, bool keepdim=False) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & logsumexp_out(const at::Tensor & self, at::IntArrayRef dim, bool keepdim, at::Tensor & out); // {"schema": "aten::logsumexp.out(Tensor self, int[1] dim, bool keepdim=False, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor logsumexp(const at::Tensor & self, at::DimnameList dim, bool keepdim); // {"schema": "aten::logsumexp.names(Tensor self, Dimname[1] dim, bool keepdim=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & logsumexp_out(const at::Tensor & self, at::DimnameList dim, bool keepdim, at::Tensor & out); // {"schema": "aten::logsumexp.names_out(Tensor self, Dimname[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor margin_ranking_loss(const at::Tensor & input1, const at::Tensor & input2, const at::Tensor & target, double margin, int64_t reduction); // {"schema": "aten::margin_ranking_loss(Tensor input1, Tensor input2, Tensor target, float margin=0.0, int reduction=Mean) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor matmul(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::matmul(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"} +::std::tuple matmul_backward(const at::Tensor & grad, const at::Tensor & self, const at::Tensor & other, ::std::array mask); // {"schema": "aten::matmul_backward(Tensor grad, Tensor self, Tensor other, bool[2] mask) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"} +at::Tensor & matmul_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::matmul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor matrix_power(const at::Tensor & self, int64_t n); // {"schema": "aten::matrix_power(Tensor self, int n) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & matrix_power_out(const at::Tensor & self, int64_t n, at::Tensor & out); // {"schema": "aten::matrix_power.out(Tensor self, int n, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor matrix_exp(const at::Tensor & self); // {"schema": "aten::matrix_exp(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor matrix_exp_backward(const at::Tensor & self, const at::Tensor & grad); // {"schema": "aten::matrix_exp_backward(Tensor self, Tensor grad) -> Tensor", "dispatch": "False", "default": "True"} +::std::tuple _aminmax(const at::Tensor & self); // {"schema": "aten::_aminmax(Tensor self) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"} +::std::tuple _aminmax(const at::Tensor & self, int64_t dim, bool keepdim); // {"schema": "aten::_aminmax.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"} +::std::tuple aminmax(const at::Tensor & self, ::std::optional dim, bool keepdim); // {"schema": "aten::aminmax(Tensor self, *, int? dim=None, bool keepdim=False) -> (Tensor min, Tensor max)", "dispatch": "True", "default": "True"} +::std::tuple aminmax_out(const at::Tensor & self, ::std::optional dim, bool keepdim, at::Tensor & min, at::Tensor & max); // {"schema": "aten::aminmax.out(Tensor self, *, int? dim=None, bool keepdim=False, Tensor(a!) min, Tensor(b!) max) -> (Tensor(a!) min, Tensor(b!) max)", "dispatch": "True", "default": "False"} +at::Tensor _compute_linear_combination(const at::Tensor & input, const at::Tensor & coefficients); // {"schema": "aten::_compute_linear_combination(Tensor input, Tensor coefficients) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & _compute_linear_combination_out(const at::Tensor & input, const at::Tensor & coefficients, at::Tensor & out); // {"schema": "aten::_compute_linear_combination.out(Tensor input, Tensor coefficients, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +::std::tuple max(const at::Tensor & self, int64_t dim, bool keepdim); // {"schema": "aten::max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)", "dispatch": "True", "default": "True"} +::std::tuple max_out(const at::Tensor & self, int64_t dim, bool keepdim, at::Tensor & max, at::Tensor & max_values); // {"schema": "aten::max.dim_max(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) max, Tensor(b!) max_values) -> (Tensor(a!) values, Tensor(b!) indices)", "dispatch": "True", "default": "False"} +::std::tuple max(const at::Tensor & self, at::Dimname dim, bool keepdim); // {"schema": "aten::max.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices)", "dispatch": "False", "default": "True"} +::std::tuple max_out(const at::Tensor & self, at::Dimname dim, bool keepdim, at::Tensor & max, at::Tensor & max_values); // {"schema": "aten::max.names_dim_max(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) max, Tensor(b!) max_values) -> (Tensor(a!) values, Tensor(b!) indices)", "dispatch": "False", "default": "True"} +at::Tensor value_selecting_reduction_backward(const at::Tensor & grad, int64_t dim, const at::Tensor & indices, c10::SymIntArrayRef sizes, bool keepdim); // {"schema": "aten::value_selecting_reduction_backward(Tensor grad, int dim, Tensor indices, SymInt[] sizes, bool keepdim) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor amax(const at::Tensor & self, at::IntArrayRef dim, bool keepdim); // {"schema": "aten::amax(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & amax_out(const at::Tensor & self, at::IntArrayRef dim, bool keepdim, at::Tensor & out); // {"schema": "aten::amax.out(Tensor self, int[1] dim=[], bool keepdim=False, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +::std::tuple max_pool1d_with_indices(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode); // {"schema": "aten::max_pool1d_with_indices(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, int[1] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)", "dispatch": "False", "default": "True"} +at::Tensor max_pool1d(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode); // {"schema": "aten::max_pool1d(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, int[1] dilation=1, bool ceil_mode=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor max_pool2d(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode); // {"schema": "aten::max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor max_pool2d_backward(const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode); // {"schema": "aten::max_pool2d_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor mkldnn_max_pool2d(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode); // {"schema": "aten::mkldnn_max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor", "dispatch": "True", "default": "False"} 
+at::Tensor mkldnn_max_pool2d_backward(const at::Tensor & grad_output, const at::Tensor & output, const at::Tensor & input, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode); // {"schema": "aten::mkldnn_max_pool2d_backward(Tensor grad_output, Tensor output, Tensor input, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor mkldnn_max_pool3d(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode); // {"schema": "aten::mkldnn_max_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor mkldnn_max_pool3d_backward(const at::Tensor & grad_output, const at::Tensor & output, const at::Tensor & input, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode); // {"schema": "aten::mkldnn_max_pool3d_backward(Tensor grad_output, Tensor output, Tensor input, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor quantized_max_pool1d(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode); // {"schema": "aten::quantized_max_pool1d(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, int[1] dilation=1, bool ceil_mode=False) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor quantized_max_pool2d(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode); // {"schema": "aten::quantized_max_pool2d(Tensor self, int[2] kernel_size, 
int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor quantized_max_pool3d(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode); // {"schema": "aten::quantized_max_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor max_pool3d(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode); // {"schema": "aten::max_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor mean(const at::Tensor & self, ::std::optional dtype); // {"schema": "aten::mean(Tensor self, *, ScalarType? dtype=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & mean_out(const at::Tensor & self, ::std::optional dtype, at::Tensor & out); // {"schema": "aten::mean.dtype_out(Tensor self, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor mean(const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim, ::std::optional dtype); // {"schema": "aten::mean.dim(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & mean_out(const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim, ::std::optional dtype, at::Tensor & out); // {"schema": "aten::mean.out(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor mean(const at::Tensor & self, at::DimnameList dim, bool keepdim, ::std::optional dtype); // {"schema": "aten::mean.names_dim(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & mean_out(const at::Tensor & self, at::DimnameList dim, bool keepdim, ::std::optional dtype, at::Tensor & out); // {"schema": "aten::mean.names_out(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor nanmean(const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim, ::std::optional dtype); // {"schema": "aten::nanmean(Tensor self, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & nanmean_out(const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim, ::std::optional dtype, at::Tensor & out); // {"schema": "aten::nanmean.out(Tensor self, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor median(const at::Tensor & self); // {"schema": "aten::median(Tensor self) -> Tensor", "dispatch": "True", "default": "False"} +::std::tuple median(const at::Tensor & self, int64_t dim, bool keepdim); // {"schema": "aten::median.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)", "dispatch": "True", "default": "True"} +::std::tuple median_out(const at::Tensor & self, int64_t dim, bool keepdim, at::Tensor & values, at::Tensor & indices); // {"schema": "aten::median.dim_values(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) 
indices)", "dispatch": "True", "default": "False"} +::std::tuple median(const at::Tensor & self, at::Dimname dim, bool keepdim); // {"schema": "aten::median.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices)", "dispatch": "False", "default": "True"} +::std::tuple median_out(const at::Tensor & self, at::Dimname dim, bool keepdim, at::Tensor & values, at::Tensor & indices); // {"schema": "aten::median.names_dim_values(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)", "dispatch": "False", "default": "True"} +at::Tensor nanmedian(const at::Tensor & self); // {"schema": "aten::nanmedian(Tensor self) -> Tensor", "dispatch": "True", "default": "False"} +::std::tuple nanmedian(const at::Tensor & self, int64_t dim, bool keepdim); // {"schema": "aten::nanmedian.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)", "dispatch": "True", "default": "True"} +::std::tuple nanmedian_out(const at::Tensor & self, int64_t dim, bool keepdim, at::Tensor & values, at::Tensor & indices); // {"schema": "aten::nanmedian.dim_values(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)", "dispatch": "True", "default": "False"} +::std::tuple nanmedian(const at::Tensor & self, at::Dimname dim, bool keepdim); // {"schema": "aten::nanmedian.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices)", "dispatch": "False", "default": "True"} +::std::tuple nanmedian_out(const at::Tensor & self, at::Dimname dim, bool keepdim, at::Tensor & values, at::Tensor & indices); // {"schema": "aten::nanmedian.names_dim_values(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) 
indices)", "dispatch": "False", "default": "True"} +::std::tuple min(const at::Tensor & self, int64_t dim, bool keepdim); // {"schema": "aten::min.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)", "dispatch": "True", "default": "True"} +::std::tuple min_out(const at::Tensor & self, int64_t dim, bool keepdim, at::Tensor & min, at::Tensor & min_indices); // {"schema": "aten::min.dim_min(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) min, Tensor(b!) min_indices) -> (Tensor(a!) values, Tensor(b!) indices)", "dispatch": "True", "default": "False"} +::std::tuple min(const at::Tensor & self, at::Dimname dim, bool keepdim); // {"schema": "aten::min.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices)", "dispatch": "False", "default": "True"} +::std::tuple min_out(const at::Tensor & self, at::Dimname dim, bool keepdim, at::Tensor & min, at::Tensor & min_indices); // {"schema": "aten::min.names_dim_min(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) min, Tensor(b!) min_indices) -> (Tensor(a!) values, Tensor(b!) indices)", "dispatch": "False", "default": "True"} +at::Tensor amin(const at::Tensor & self, at::IntArrayRef dim, bool keepdim); // {"schema": "aten::amin(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & amin_out(const at::Tensor & self, at::IntArrayRef dim, bool keepdim, at::Tensor & out); // {"schema": "aten::amin.out(Tensor self, int[1] dim=[], bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor _mps_convolution(const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups); // {"schema": "aten::_mps_convolution(Tensor self, Tensor weight, Tensor? 
bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups) -> Tensor", "dispatch": "True", "default": "False"} +::std::tuple mps_convolution_backward(const at::Tensor & self, const at::Tensor & grad_output, const at::Tensor & weight, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, ::std::array output_mask); // {"schema": "aten::mps_convolution_backward(Tensor self, Tensor grad_output, Tensor weight, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"} +at::Tensor mkldnn_convolution(const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups); // {"schema": "aten::mkldnn_convolution(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups) -> Tensor", "dispatch": "True", "default": "True"} +::std::tuple mkldnn_rnn_layer(const at::Tensor & input, const at::Tensor & weight0, const at::Tensor & weight1, const at::Tensor & weight2, const at::Tensor & weight3, const at::Tensor & hx_, const at::Tensor & cx_, bool reverse, at::IntArrayRef batch_sizes, int64_t mode, int64_t hidden_size, int64_t num_layers, bool has_biases, bool bidirectional, bool batch_first, bool train); // {"schema": "aten::mkldnn_rnn_layer(Tensor input, Tensor weight0, Tensor weight1, Tensor weight2, Tensor weight3, Tensor hx_, Tensor cx_, bool reverse, int[] batch_sizes, int mode, int hidden_size, int num_layers, bool has_biases, bool bidirectional, bool batch_first, bool train) -> (Tensor, Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"} +::std::tuple mkldnn_rnn_layer_backward(const at::Tensor & input, const at::Tensor & weight1, const at::Tensor & weight2, const at::Tensor & weight3, const at::Tensor & weight4, const at::Tensor 
& hx_, const at::Tensor & cx_tmp, const at::Tensor & output, const at::Tensor & hy_, const at::Tensor & cy_, const ::std::optional & grad_output, const ::std::optional & grad_hy, const ::std::optional & grad_cy, bool reverse, int64_t mode, int64_t hidden_size, int64_t num_layers, bool has_biases, bool train, bool bidirectional, at::IntArrayRef batch_sizes, bool batch_first, const at::Tensor & workspace); // {"schema": "aten::mkldnn_rnn_layer_backward(Tensor input, Tensor weight1, Tensor weight2, Tensor weight3, Tensor weight4, Tensor hx_, Tensor cx_tmp, Tensor output, Tensor hy_, Tensor cy_, Tensor? grad_output, Tensor? grad_hy, Tensor? grad_cy, bool reverse, int mode, int hidden_size, int num_layers, bool has_biases, bool train, bool bidirectional, int[] batch_sizes, bool batch_first, Tensor workspace) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"} +::std::tuple miopen_batch_norm(const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, const ::std::optional & running_mean, const ::std::optional & running_var, bool training, double exponential_average_factor, double epsilon); // {"schema": "aten::miopen_batch_norm(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon) -> (Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"} +::std::tuple miopen_batch_norm_backward(const at::Tensor & input, const at::Tensor & grad_output, const at::Tensor & weight, const ::std::optional & running_mean, const ::std::optional & running_var, const ::std::optional & save_mean, const ::std::optional & save_var, double epsilon); // {"schema": "aten::miopen_batch_norm_backward(Tensor input, Tensor grad_output, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? 
save_var, float epsilon) -> (Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"} +at::Tensor miopen_convolution(const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, bool benchmark, bool deterministic); // {"schema": "aten::miopen_convolution(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor miopen_convolution_transpose(const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, bool benchmark, bool deterministic); // {"schema": "aten::miopen_convolution_transpose(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor miopen_depthwise_convolution(const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, bool benchmark, bool deterministic); // {"schema": "aten::miopen_depthwise_convolution(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor miopen_convolution_relu(const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, c10::SymInt groups); // {"schema": "aten::miopen_convolution_relu(Tensor self, Tensor weight, Tensor? 
bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, SymInt groups) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor miopen_convolution_add_relu(const at::Tensor & self, const at::Tensor & weight, const at::Tensor & z, const ::std::optional & alpha, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, c10::SymInt groups); // {"schema": "aten::miopen_convolution_add_relu(Tensor self, Tensor weight, Tensor z, Scalar? alpha, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, SymInt groups) -> Tensor", "dispatch": "True", "default": "False"} +::std::tuple miopen_rnn(const at::Tensor & input, at::TensorList weight, int64_t weight_stride0, const at::Tensor & hx, const ::std::optional & cx, int64_t mode, int64_t hidden_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, at::IntArrayRef batch_sizes, const ::std::optional & dropout_state); // {"schema": "aten::miopen_rnn(Tensor input, Tensor[] weight, int weight_stride0, Tensor hx, Tensor? cx, int mode, int hidden_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, int[] batch_sizes, Tensor? 
dropout_state) -> (Tensor, Tensor, Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"} +::std::tuple> miopen_rnn_backward(const at::Tensor & input, at::TensorList weight, int64_t weight_stride0, const at::Tensor & weight_buf, const at::Tensor & hx, const ::std::optional & cx, const at::Tensor & output, const ::std::optional & grad_output, const ::std::optional & grad_hy, const ::std::optional & grad_cy, int64_t mode, int64_t hidden_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, at::IntArrayRef batch_sizes, const ::std::optional & dropout_state, const at::Tensor & reserve, ::std::array output_mask); // {"schema": "aten::miopen_rnn_backward(Tensor input, Tensor[] weight, int weight_stride0, Tensor weight_buf, Tensor hx, Tensor? cx, Tensor output, Tensor? grad_output, Tensor? grad_hy, Tensor? grad_cy, int mode, int hidden_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, int[] batch_sizes, Tensor? dropout_state, Tensor reserve, bool[4] output_mask) -> (Tensor, Tensor, Tensor, Tensor[])", "dispatch": "True", "default": "False"} +at::Tensor mm(const at::Tensor & self, const at::Tensor & mat2); // {"schema": "aten::mm(Tensor self, Tensor mat2) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & mm_out(const at::Tensor & self, const at::Tensor & mat2, at::Tensor & out); // {"schema": "aten::mm.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor mm(const at::Tensor & self, const at::Tensor & mat2, at::ScalarType out_dtype); // {"schema": "aten::mm.dtype(Tensor self, Tensor mat2, ScalarType out_dtype) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & mm_out(const at::Tensor & self, const at::Tensor & mat2, at::ScalarType out_dtype, at::Tensor & out); // {"schema": "aten::mm.dtype_out(Tensor self, Tensor mat2, ScalarType out_dtype, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor _int_mm(const at::Tensor & self, const at::Tensor & mat2); // {"schema": "aten::_int_mm(Tensor self, Tensor mat2) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & _int_mm_out(const at::Tensor & self, const at::Tensor & mat2, at::Tensor & out); // {"schema": "aten::_int_mm.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor _convert_weight_to_int4pack(const at::Tensor & self, int64_t innerKTiles); // {"schema": "aten::_convert_weight_to_int4pack(Tensor self, int innerKTiles) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _weight_int4pack_mm(const at::Tensor & self, const at::Tensor & mat2, int64_t qGroupSize, const at::Tensor & qScaleAndZeros); // {"schema": "aten::_weight_int4pack_mm(Tensor self, Tensor mat2, int qGroupSize, Tensor qScaleAndZeros) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _weight_int4pack_mm_with_scales_and_zeros(const at::Tensor & self, const at::Tensor & mat2, int64_t qGroupSize, const at::Tensor & qScale, const at::Tensor & qZeros); // {"schema": "aten::_weight_int4pack_mm_with_scales_and_zeros(Tensor self, Tensor mat2, int qGroupSize, Tensor qScale, Tensor qZeros) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _convert_weight_to_int4pack_for_cpu(const at::Tensor & self, int64_t innerKTiles); // {"schema": "aten::_convert_weight_to_int4pack_for_cpu(Tensor self, int innerKTiles) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _weight_int4pack_mm_for_cpu(const at::Tensor & self, const at::Tensor & mat2, int64_t qGroupSize, const at::Tensor & qScaleAndZeros); // {"schema": "aten::_weight_int4pack_mm_for_cpu(Tensor self, Tensor mat2, int qGroupSize, Tensor qScaleAndZeros) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _dyn_quant_pack_4bit_weight(const at::Tensor & weights, const at::Tensor & scales_zeros, 
const ::std::optional & bias, int64_t block_size, int64_t in_features, int64_t out_features); // {"schema": "aten::_dyn_quant_pack_4bit_weight(Tensor weights, Tensor scales_zeros, Tensor? bias, int block_size, int in_features, int out_features) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _dyn_quant_matmul_4bit(const at::Tensor & inp, const at::Tensor & packed_weights, int64_t block_size, int64_t in_features, int64_t out_features); // {"schema": "aten::_dyn_quant_matmul_4bit(Tensor inp, Tensor packed_weights, int block_size, int in_features, int out_features) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _weight_int8pack_mm(const at::Tensor & self, const at::Tensor & mat2, const at::Tensor & scales); // {"schema": "aten::_weight_int8pack_mm(Tensor self, Tensor mat2, Tensor scales) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _sparse_mm(const at::Tensor & sparse, const at::Tensor & dense); // {"schema": "aten::_sparse_mm(Tensor sparse, Tensor dense) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _sparse_mm(const at::Tensor & sparse, const at::Tensor & dense, c10::string_view reduce); // {"schema": "aten::_sparse_mm.reduce(Tensor sparse, Tensor dense, str reduce) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _sparse_sparse_matmul(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::_sparse_sparse_matmul(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "False"} +::std::tuple mode(const at::Tensor & self, int64_t dim, bool keepdim); // {"schema": "aten::mode(Tensor self, int dim=-1, bool keepdim=False) -> (Tensor values, Tensor indices)", "dispatch": "True", "default": "False"} +::std::tuple mode_out(const at::Tensor & self, int64_t dim, bool keepdim, at::Tensor & values, at::Tensor & indices); // {"schema": "aten::mode.values(Tensor self, int dim=-1, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) 
values, Tensor(b!) indices)", "dispatch": "True", "default": "True"} +::std::tuple mode(const at::Tensor & self, at::Dimname dim, bool keepdim); // {"schema": "aten::mode.dimname(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices)", "dispatch": "False", "default": "True"} +::std::tuple mode_out(const at::Tensor & self, at::Dimname dim, bool keepdim, at::Tensor & values, at::Tensor & indices); // {"schema": "aten::mode.dimname_out(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)", "dispatch": "False", "default": "True"} +at::Tensor mul(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::mul.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & mul_(at::Tensor & self, const at::Tensor & other); // {"schema": "aten::mul_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & mul_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::mul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor mul(const at::Tensor & self, const at::Scalar & other); // {"schema": "aten::mul.Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & mul_(at::Tensor & self, const at::Scalar & other); // {"schema": "aten::mul_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor multiply(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::multiply.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & multiply_(at::Tensor & self, const at::Tensor & other); // {"schema": "aten::multiply_.Tensor(Tensor(a!) 
self, Tensor other) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & multiply_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::multiply.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor multiply(const at::Tensor & self, const at::Scalar & other); // {"schema": "aten::multiply.Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & multiply_(at::Tensor & self, const at::Scalar & other); // {"schema": "aten::multiply_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor mv(const at::Tensor & self, const at::Tensor & vec); // {"schema": "aten::mv(Tensor self, Tensor vec) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & mv_out(const at::Tensor & self, const at::Tensor & vec, at::Tensor & out); // {"schema": "aten::mv.out(Tensor self, Tensor vec, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & mvlgamma_out(const at::Tensor & self, int64_t p, at::Tensor & out); // {"schema": "aten::mvlgamma.out(Tensor self, int p, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor mvlgamma(const at::Tensor & self, int64_t p); // {"schema": "aten::mvlgamma(Tensor self, int p) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & mvlgamma_(at::Tensor & self, int64_t p); // {"schema": "aten::mvlgamma_(Tensor(a!) 
self, int p) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor narrow_copy(const at::Tensor & self, int64_t dim, c10::SymInt start, c10::SymInt length); // {"schema": "aten::narrow_copy(Tensor self, int dim, SymInt start, SymInt length) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & narrow_copy_out(const at::Tensor & self, int64_t dim, c10::SymInt start, c10::SymInt length, at::Tensor & out); // {"schema": "aten::narrow_copy.out(Tensor self, int dim, SymInt start, SymInt length, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor narrow(const at::Tensor & self, int64_t dim, c10::SymInt start, c10::SymInt length); // {"schema": "aten::narrow(Tensor(a) self, int dim, SymInt start, SymInt length) -> Tensor(a)", "dispatch": "True", "default": "True"} +at::Tensor narrow(const at::Tensor & self, int64_t dim, const at::Tensor & start, c10::SymInt length); // {"schema": "aten::narrow.Tensor(Tensor(a) self, int dim, Tensor start, SymInt length) -> Tensor(a)", "dispatch": "False", "default": "True"} +::std::tuple native_batch_norm(const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const ::std::optional & running_mean, const ::std::optional & running_var, bool training, double momentum, double eps); // {"schema": "aten::native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"} +::std::tuple native_batch_norm_out(const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const ::std::optional & running_mean, const ::std::optional & running_var, bool training, double momentum, double eps, at::Tensor & out, at::Tensor & save_mean, at::Tensor & save_invstd); // {"schema": "aten::native_batch_norm.out(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? 
running_var, bool training, float momentum, float eps, *, Tensor(a!) out, Tensor(b!) save_mean, Tensor(c!) save_invstd) -> (Tensor(a!), Tensor(b!), Tensor(c!))", "dispatch": "True", "default": "False"} +::std::tuple _native_batch_norm_legit(const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, at::Tensor & running_mean, at::Tensor & running_var, bool training, double momentum, double eps); // {"schema": "aten::_native_batch_norm_legit(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"} +::std::tuple _native_batch_norm_legit_no_training(const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const at::Tensor & running_mean, const at::Tensor & running_var, double momentum, double eps); // {"schema": "aten::_native_batch_norm_legit_no_training(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor)", "dispatch": "True", "default": "True"} +::std::tuple _native_batch_norm_legit_out(const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, at::Tensor & running_mean, at::Tensor & running_var, bool training, double momentum, double eps, at::Tensor & out, at::Tensor & save_mean, at::Tensor & save_invstd); // {"schema": "aten::_native_batch_norm_legit.out(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, bool training, float momentum, float eps, *, Tensor(d!) out, Tensor(e!) save_mean, Tensor(f!) 
save_invstd) -> (Tensor(d!), Tensor(e!), Tensor(f!))", "dispatch": "True", "default": "False"} +::std::tuple _native_batch_norm_legit(const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, bool training, double momentum, double eps); // {"schema": "aten::_native_batch_norm_legit.no_stats(Tensor input, Tensor? weight, Tensor? bias, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"} +::std::tuple _native_batch_norm_legit_out(const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, bool training, double momentum, double eps, at::Tensor & out, at::Tensor & save_mean, at::Tensor & save_invstd); // {"schema": "aten::_native_batch_norm_legit.no_stats_out(Tensor input, Tensor? weight, Tensor? bias, bool training, float momentum, float eps, *, Tensor(a!) out, Tensor(b!) save_mean, Tensor(c!) save_invstd) -> (Tensor(a!), Tensor(b!), Tensor(c!))", "dispatch": "True", "default": "False"} +::std::tuple batch_norm_stats(const at::Tensor & input, double eps); // {"schema": "aten::batch_norm_stats(Tensor input, float eps) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"} +at::Tensor batch_norm_elemt(const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const at::Tensor & mean, const at::Tensor & invstd, double eps); // {"schema": "aten::batch_norm_elemt(Tensor input, Tensor? weight, Tensor? bias, Tensor mean, Tensor invstd, float eps) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & batch_norm_elemt_out(const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const at::Tensor & mean, const at::Tensor & invstd, double eps, at::Tensor & out); // {"schema": "aten::batch_norm_elemt.out(Tensor input, Tensor? weight, Tensor? bias, Tensor mean, Tensor invstd, float eps, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +::std::tuple batch_norm_gather_stats(const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & running_mean, const ::std::optional & running_var, double momentum, double eps, int64_t count); // {"schema": "aten::batch_norm_gather_stats(Tensor input, Tensor mean, Tensor invstd, Tensor? running_mean, Tensor? running_var, float momentum, float eps, int count) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"} +::std::tuple batch_norm_gather_stats_with_counts(const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & running_mean, const ::std::optional & running_var, double momentum, double eps, const at::Tensor & counts); // {"schema": "aten::batch_norm_gather_stats_with_counts(Tensor input, Tensor mean, Tensor invstd, Tensor? running_mean, Tensor? running_var, float momentum, float eps, Tensor counts) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"} +::std::tuple native_batch_norm_backward(const at::Tensor & grad_out, const at::Tensor & input, const ::std::optional & weight, const ::std::optional & running_mean, const ::std::optional & running_var, const ::std::optional & save_mean, const ::std::optional & save_invstd, bool train, double eps, ::std::array output_mask); // {"schema": "aten::native_batch_norm_backward(Tensor grad_out, Tensor input, Tensor? weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_invstd, bool train, float eps, bool[3] output_mask) -> (Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"} +::std::tuple batch_norm_backward_reduce(const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & weight, bool input_g, bool weight_g, bool bias_g); // {"schema": "aten::batch_norm_backward_reduce(Tensor grad_out, Tensor input, Tensor mean, Tensor invstd, Tensor? 
weight, bool input_g, bool weight_g, bool bias_g) -> (Tensor, Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"} +at::Tensor batch_norm_backward_elemt(const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & weight, const at::Tensor & sum_dy, const at::Tensor & sum_dy_xmu, const at::Tensor & count); // {"schema": "aten::batch_norm_backward_elemt(Tensor grad_out, Tensor input, Tensor mean, Tensor invstd, Tensor? weight, Tensor sum_dy, Tensor sum_dy_xmu, Tensor count) -> Tensor", "dispatch": "True", "default": "False"} +::std::tuple batch_norm_update_stats(const at::Tensor & input, const ::std::optional & running_mean, const ::std::optional & running_var, double momentum); // {"schema": "aten::batch_norm_update_stats(Tensor input, Tensor? running_mean, Tensor? running_var, float momentum) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"} +bool is_vulkan_available(); // {"schema": "aten::is_vulkan_available() -> bool", "dispatch": "False", "default": "True"} +bool _nnpack_available(); // {"schema": "aten::_nnpack_available() -> bool", "dispatch": "False", "default": "True"} +at::Tensor _nnpack_spatial_convolution(const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride); // {"schema": "aten::_nnpack_spatial_convolution(Tensor input, Tensor weight, Tensor? bias, SymInt[2] padding, SymInt[2] stride=1) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor ones(at::IntArrayRef size, ::std::optional names, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::ones.names(int[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor ones(c10::SymIntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::ones(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & ones_out(c10::SymIntArrayRef size, at::Tensor & out); // {"schema": "aten::ones.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor ones_like(const at::Tensor & self, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format); // {"schema": "aten::ones_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor pairwise_distance(const at::Tensor & x1, const at::Tensor & x2, double p, double eps, bool keepdim); // {"schema": "aten::pairwise_distance(Tensor x1, Tensor x2, float p=2, float eps=1e-06, bool keepdim=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor cdist(const at::Tensor & x1, const at::Tensor & x2, double p, ::std::optional compute_mode); // {"schema": "aten::cdist(Tensor x1, Tensor x2, float p=2, int? compute_mode=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _euclidean_dist(const at::Tensor & x1, const at::Tensor & x2); // {"schema": "aten::_euclidean_dist(Tensor x1, Tensor x2) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor _cdist_forward(const at::Tensor & x1, const at::Tensor & x2, double p, ::std::optional compute_mode); // {"schema": "aten::_cdist_forward(Tensor x1, Tensor x2, float p, int? 
compute_mode) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _cdist_backward(const at::Tensor & grad, const at::Tensor & x1, const at::Tensor & x2, double p, const at::Tensor & cdist); // {"schema": "aten::_cdist_backward(Tensor grad, Tensor x1, Tensor x2, float p, Tensor cdist) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor pdist(const at::Tensor & self, double p); // {"schema": "aten::pdist(Tensor self, float p=2) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _pdist_forward(const at::Tensor & self, double p); // {"schema": "aten::_pdist_forward(Tensor self, float p=2) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _pdist_backward(const at::Tensor & grad, const at::Tensor & self, double p, const at::Tensor & pdist); // {"schema": "aten::_pdist_backward(Tensor grad, Tensor self, float p, Tensor pdist) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor cosine_similarity(const at::Tensor & x1, const at::Tensor & x2, int64_t dim, double eps); // {"schema": "aten::cosine_similarity(Tensor x1, Tensor x2, int dim=1, float eps=1e-08) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor permute(const at::Tensor & self, at::IntArrayRef dims); // {"schema": "aten::permute(Tensor(a) self, int[] dims) -> Tensor(a)", "dispatch": "True", "default": "True"} +at::Tensor movedim(const at::Tensor & self, at::IntArrayRef source, at::IntArrayRef destination); // {"schema": "aten::movedim.intlist(Tensor(a) self, int[] source, int[] destination) -> Tensor(a)", "dispatch": "False", "default": "True"} +at::Tensor movedim(const at::Tensor & self, int64_t source, int64_t destination); // {"schema": "aten::movedim.int(Tensor(a) self, int source, int destination) -> Tensor(a)", "dispatch": "False", "default": "True"} +at::Tensor moveaxis(const at::Tensor & self, at::IntArrayRef source, at::IntArrayRef destination); // {"schema": "aten::moveaxis.intlist(Tensor(a) self, int[] source, int[] destination) 
-> Tensor(a)", "dispatch": "False", "default": "True"} +at::Tensor moveaxis(const at::Tensor & self, int64_t source, int64_t destination); // {"schema": "aten::moveaxis.int(Tensor(a) self, int source, int destination) -> Tensor(a)", "dispatch": "False", "default": "True"} +at::Tensor numpy_T(const at::Tensor & self); // {"schema": "aten::numpy_T(Tensor(a) self) -> Tensor(a)", "dispatch": "False", "default": "True"} +at::Tensor matrix_H(const at::Tensor & self); // {"schema": "aten::matrix_H(Tensor(a) self) -> Tensor(a)", "dispatch": "False", "default": "True"} +at::Tensor mT(const at::Tensor & self); // {"schema": "aten::mT(Tensor(a) self) -> Tensor(a)", "dispatch": "False", "default": "True"} +at::Tensor mH(const at::Tensor & self); // {"schema": "aten::mH(Tensor(a) self) -> Tensor(a)", "dispatch": "False", "default": "True"} +at::Tensor adjoint(const at::Tensor & self); // {"schema": "aten::adjoint(Tensor(a) self) -> Tensor(a)", "dispatch": "False", "default": "True"} +at::Tensor pixel_shuffle(const at::Tensor & self, int64_t upscale_factor); // {"schema": "aten::pixel_shuffle(Tensor self, int upscale_factor) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor pixel_unshuffle(const at::Tensor & self, int64_t downscale_factor); // {"schema": "aten::pixel_unshuffle(Tensor self, int downscale_factor) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor channel_shuffle(const at::Tensor & self, c10::SymInt groups); // {"schema": "aten::channel_shuffle(Tensor self, SymInt groups) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor native_channel_shuffle(const at::Tensor & self, c10::SymInt groups); // {"schema": "aten::native_channel_shuffle(Tensor self, SymInt groups) -> Tensor", "dispatch": "True", "default": "True"} +bool is_pinned(const at::Tensor & self, ::std::optional device); // {"schema": "aten::is_pinned(Tensor self, Device? 
device=None) -> bool", "dispatch": "True", "default": "True"} +at::Tensor pin_memory(const at::Tensor & self, ::std::optional device); // {"schema": "aten::pin_memory(Tensor(a) self, Device? device=None) -> Tensor(a)", "dispatch": "False", "default": "True"} +at::Tensor _pin_memory(const at::Tensor & self, ::std::optional device); // {"schema": "aten::_pin_memory(Tensor self, Device? device=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor pinverse(const at::Tensor & self, double rcond); // {"schema": "aten::pinverse(Tensor self, float rcond=1e-15) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor poisson_nll_loss(const at::Tensor & input, const at::Tensor & target, bool log_input, bool full, double eps, int64_t reduction); // {"schema": "aten::poisson_nll_loss(Tensor input, Tensor target, bool log_input, bool full, float eps, int reduction) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor rad2deg(const at::Tensor & self); // {"schema": "aten::rad2deg(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & rad2deg_(at::Tensor & self); // {"schema": "aten::rad2deg_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & rad2deg_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::rad2deg.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor deg2rad(const at::Tensor & self); // {"schema": "aten::deg2rad(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & deg2rad_(at::Tensor & self); // {"schema": "aten::deg2rad_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & deg2rad_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::deg2rad.out(Tensor self, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor scalar_tensor(const at::Scalar & s, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::scalar_tensor(Scalar s, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor rand(c10::SymIntArrayRef size, ::std::optional names, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::rand.names(SymInt[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor rand(c10::SymIntArrayRef size, ::std::optional generator, ::std::optional names, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::rand.generator_with_names(SymInt[] size, *, Generator? generator, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor rand(c10::SymIntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::rand(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor rand(c10::SymIntArrayRef size, ::std::optional generator, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::rand.generator(SymInt[] size, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & rand_out(c10::SymIntArrayRef size, at::Tensor & out); // {"schema": "aten::rand.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & rand_out(c10::SymIntArrayRef size, ::std::optional generator, at::Tensor & out); // {"schema": "aten::rand.generator_out(SymInt[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor rand_like(const at::Tensor & self, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format); // {"schema": "aten::rand_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor rand_like(const at::Tensor & self, ::std::optional generator, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format); // {"schema": "aten::rand_like.generator(Tensor self, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor randint(c10::SymInt high, c10::SymIntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::randint(SymInt high, SymInt[] size, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor randint(c10::SymInt high, c10::SymIntArrayRef size, ::std::optional generator, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::randint.generator(SymInt high, SymInt[] size, *, Generator? 
generator, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor randint(c10::SymInt low, c10::SymInt high, c10::SymIntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::randint.low(SymInt low, SymInt high, SymInt[] size, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor randint(c10::SymInt low, c10::SymInt high, c10::SymIntArrayRef size, ::std::optional generator, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::randint.low_generator(SymInt low, SymInt high, SymInt[] size, *, Generator? generator, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & randint_out(c10::SymInt high, c10::SymIntArrayRef size, at::Tensor & out); // {"schema": "aten::randint.out(SymInt high, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & randint_out(c10::SymInt high, c10::SymIntArrayRef size, ::std::optional generator, at::Tensor & out); // {"schema": "aten::randint.generator_out(SymInt high, SymInt[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & randint_out(c10::SymInt low, c10::SymInt high, c10::SymIntArrayRef size, at::Tensor & out); // {"schema": "aten::randint.low_out(SymInt low, SymInt high, SymInt[] size, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & randint_out(c10::SymInt low, c10::SymInt high, c10::SymIntArrayRef size, ::std::optional generator, at::Tensor & out); // {"schema": "aten::randint.low_generator_out(SymInt low, SymInt high, SymInt[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor randint_like(const at::Tensor & self, c10::SymInt high, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format); // {"schema": "aten::randint_like(Tensor self, SymInt high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor randint_like(const at::Tensor & self, c10::SymInt high, ::std::optional generator, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format); // {"schema": "aten::randint_like.generator(Tensor self, SymInt high, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor randint_like(const at::Tensor & self, const at::Tensor & high, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format); // {"schema": "aten::randint_like.Tensor(Tensor self, Tensor high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? 
memory_format=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor randint_like(const at::Tensor & self, const at::Tensor & high, ::std::optional generator, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format); // {"schema": "aten::randint_like.Tensor_generator(Tensor self, Tensor high, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor randint_like(const at::Tensor & self, c10::SymInt low, c10::SymInt high, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format); // {"schema": "aten::randint_like.low_dtype(Tensor self, SymInt low, SymInt high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor randint_like(const at::Tensor & self, c10::SymInt low, c10::SymInt high, ::std::optional generator, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format); // {"schema": "aten::randint_like.low_generator_dtype(Tensor self, SymInt low, SymInt high, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor randn(c10::SymIntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::randn(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor randn(c10::SymIntArrayRef size, ::std::optional generator, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::randn.generator(SymInt[] size, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor randn(c10::SymIntArrayRef size, ::std::optional names, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::randn.names(SymInt[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor randn(c10::SymIntArrayRef size, ::std::optional generator, ::std::optional names, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::randn.generator_with_names(SymInt[] size, *, Generator? generator, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & randn_out(c10::SymIntArrayRef size, at::Tensor & out); // {"schema": "aten::randn.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & randn_out(c10::SymIntArrayRef size, ::std::optional generator, at::Tensor & out); // {"schema": "aten::randn.generator_out(SymInt[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor randn_like(const at::Tensor & self, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format); // {"schema": "aten::randn_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? 
device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor randn_like(const at::Tensor & self, ::std::optional generator, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format); // {"schema": "aten::randn_like.generator(Tensor self, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor randperm(c10::SymInt n, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::randperm(SymInt n, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor randperm(c10::SymInt n, ::std::optional generator, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::randperm.generator(SymInt n, *, Generator? generator, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & randperm_out(c10::SymInt n, at::Tensor & out); // {"schema": "aten::randperm.out(SymInt n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & randperm_out(c10::SymInt n, ::std::optional generator, at::Tensor & out); // {"schema": "aten::randperm.generator_out(SymInt n, *, Generator? generator, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor range(const at::Scalar & start, const at::Scalar & end, const at::Scalar & step, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::range.step(Scalar start, Scalar end, Scalar step=1, *, ScalarType? dtype=None, Layout? 
layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor range(const at::Scalar & start, const at::Scalar & end, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::range(Scalar start, Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & range_out(const at::Scalar & start, const at::Scalar & end, at::Tensor & out); // {"schema": "aten::range.out_(Scalar start, Scalar end, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & range_out(const at::Scalar & start, const at::Scalar & end, const at::Scalar & step, at::Tensor & out); // {"schema": "aten::range.out(Scalar start, Scalar end, Scalar step=1, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor ravel(const at::Tensor & self); // {"schema": "aten::ravel(Tensor(a) self) -> Tensor(a)", "dispatch": "False", "default": "True"} +at::Tensor reciprocal(const at::Tensor & self); // {"schema": "aten::reciprocal(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & reciprocal_(at::Tensor & self); // {"schema": "aten::reciprocal_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & reciprocal_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::reciprocal.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor neg(const at::Tensor & self); // {"schema": "aten::neg(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & neg_(at::Tensor & self); // {"schema": "aten::neg_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & neg_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::neg.out(Tensor self, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor negative(const at::Tensor & self); // {"schema": "aten::negative(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & negative_(at::Tensor & self); // {"schema": "aten::negative_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & negative_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::negative.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor repeat(const at::Tensor & self, c10::SymIntArrayRef repeats); // {"schema": "aten::repeat(Tensor self, SymInt[] repeats) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor repeat_interleave(const at::Tensor & repeats, ::std::optional output_size); // {"schema": "aten::repeat_interleave.Tensor(Tensor repeats, *, SymInt? output_size=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor repeat_interleave(const at::Tensor & self, const at::Tensor & repeats, ::std::optional dim, ::std::optional output_size); // {"schema": "aten::repeat_interleave.self_Tensor(Tensor self, Tensor repeats, int? dim=None, *, SymInt? output_size=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor repeat_interleave(const at::Tensor & self, c10::SymInt repeats, ::std::optional dim, ::std::optional output_size); // {"schema": "aten::repeat_interleave.self_int(Tensor self, SymInt repeats, int? dim=None, *, SymInt? 
output_size=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor reshape(const at::Tensor & self, c10::SymIntArrayRef shape); // {"schema": "aten::reshape(Tensor(a) self, SymInt[] shape) -> Tensor(a)", "dispatch": "False", "default": "True"} +at::Tensor _reshape_copy(const at::Tensor & self, c10::SymIntArrayRef size); // {"schema": "aten::_reshape_copy(Tensor self, SymInt[] size) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor _reshape_alias(const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride); // {"schema": "aten::_reshape_alias(Tensor(a) self, SymInt[] size, SymInt[] stride) -> Tensor(a)", "dispatch": "True", "default": "False"} +at::Tensor _mkldnn_reshape(const at::Tensor & self, at::IntArrayRef shape); // {"schema": "aten::_mkldnn_reshape(Tensor self, int[] shape) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor reshape_as(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::reshape_as(Tensor(a) self, Tensor other) -> Tensor(a)", "dispatch": "False", "default": "True"} +at::Tensor round(const at::Tensor & self); // {"schema": "aten::round(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & round_(at::Tensor & self); // {"schema": "aten::round_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & round_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::round.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor round(const at::Tensor & self, int64_t decimals); // {"schema": "aten::round.decimals(Tensor self, *, int decimals) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & round_(at::Tensor & self, int64_t decimals); // {"schema": "aten::round_.decimals(Tensor(a!) 
self, *, int decimals) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & round_out(const at::Tensor & self, int64_t decimals, at::Tensor & out); // {"schema": "aten::round.decimals_out(Tensor self, *, int decimals, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor rrelu(const at::Tensor & self, const at::Scalar & lower, const at::Scalar & upper, bool training, ::std::optional generator); // {"schema": "aten::rrelu(Tensor self, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & rrelu_(at::Tensor & self, const at::Scalar & lower, const at::Scalar & upper, bool training, ::std::optional generator); // {"schema": "aten::rrelu_(Tensor(a!) self, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor relu(const at::Tensor & self); // {"schema": "aten::relu(Tensor self) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & relu_(at::Tensor & self); // {"schema": "aten::relu_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor relu6(const at::Tensor & self); // {"schema": "aten::relu6(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & relu6_(at::Tensor & self); // {"schema": "aten::relu6_(Tensor(a!) 
self) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor prelu(const at::Tensor & self, const at::Tensor & weight); // {"schema": "aten::prelu(Tensor self, Tensor weight) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _prelu_kernel(const at::Tensor & self, const at::Tensor & weight); // {"schema": "aten::_prelu_kernel(Tensor self, Tensor weight) -> Tensor", "dispatch": "True", "default": "False"} +::std::tuple _prelu_kernel_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & weight); // {"schema": "aten::_prelu_kernel_backward(Tensor grad_output, Tensor self, Tensor weight) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"} +at::Tensor & gelu_out(const at::Tensor & self, c10::string_view approximate, at::Tensor & out); // {"schema": "aten::gelu.out(Tensor self, *, str approximate='none', Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & gelu_(at::Tensor & self, c10::string_view approximate); // {"schema": "aten::gelu_(Tensor(a!) self, *, str approximate='none') -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor gelu(const at::Tensor & self, c10::string_view approximate); // {"schema": "aten::gelu(Tensor self, *, str approximate='none') -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & gelu_backward_out(const at::Tensor & grad_output, const at::Tensor & self, c10::string_view approximate, at::Tensor & grad_input); // {"schema": "aten::gelu_backward.grad_input(Tensor grad_output, Tensor self, *, str approximate='none', Tensor(a!) 
grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor gelu_backward(const at::Tensor & grad_output, const at::Tensor & self, c10::string_view approximate); // {"schema": "aten::gelu_backward(Tensor grad_output, Tensor self, *, str approximate='none') -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor infinitely_differentiable_gelu_backward(const at::Tensor & grad, const at::Tensor & self); // {"schema": "aten::infinitely_differentiable_gelu_backward(Tensor grad, Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & hardshrink_out(const at::Tensor & self, const at::Scalar & lambd, at::Tensor & out); // {"schema": "aten::hardshrink.out(Tensor self, Scalar lambd=0.5, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor hardshrink(const at::Tensor & self, const at::Scalar & lambd); // {"schema": "aten::hardshrink(Tensor self, Scalar lambd=0.5) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & hardshrink_backward_out(const at::Tensor & grad_out, const at::Tensor & self, const at::Scalar & lambd, at::Tensor & grad_input); // {"schema": "aten::hardshrink_backward.grad_input(Tensor grad_out, Tensor self, Scalar lambd, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor hardshrink_backward(const at::Tensor & grad_out, const at::Tensor & self, const at::Scalar & lambd); // {"schema": "aten::hardshrink_backward(Tensor grad_out, Tensor self, Scalar lambd) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor rsqrt(const at::Tensor & self); // {"schema": "aten::rsqrt(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & rsqrt_(at::Tensor & self); // {"schema": "aten::rsqrt_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & rsqrt_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::rsqrt.out(Tensor self, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor select(const at::Tensor & self, at::Dimname dim, int64_t index); // {"schema": "aten::select.Dimname(Tensor(a) self, Dimname dim, int index) -> Tensor(a)", "dispatch": "False", "default": "True"} +at::Tensor select(const at::Tensor & self, int64_t dim, c10::SymInt index); // {"schema": "aten::select.int(Tensor(a) self, int dim, SymInt index) -> Tensor(a)", "dispatch": "True", "default": "True"} +at::Tensor select_backward(const at::Tensor & grad_output, c10::SymIntArrayRef input_sizes, int64_t dim, c10::SymInt index); // {"schema": "aten::select_backward(Tensor grad_output, SymInt[] input_sizes, int dim, SymInt index) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor _nested_select_backward(const at::Tensor & grad_output, const at::Tensor & self, int64_t dim, c10::SymInt index); // {"schema": "aten::_nested_select_backward(Tensor grad_output, Tensor self, int dim, SymInt index) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor selu(const at::Tensor & self); // {"schema": "aten::selu(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & selu_(at::Tensor & self); // {"schema": "aten::selu_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor celu(const at::Tensor & self, const at::Scalar & alpha); // {"schema": "aten::celu(Tensor self, Scalar alpha=1.0) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & celu_(at::Tensor & self, const at::Scalar & alpha); // {"schema": "aten::celu_(Tensor(a!) self, Scalar alpha=1.0) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor silu(const at::Tensor & self); // {"schema": "aten::silu(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & silu_(at::Tensor & self); // {"schema": "aten::silu_(Tensor(a!) 
self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & silu_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::silu.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & silu_backward_out(const at::Tensor & grad_output, const at::Tensor & self, at::Tensor & grad_input); // {"schema": "aten::silu_backward.grad_input(Tensor grad_output, Tensor self, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor silu_backward(const at::Tensor & grad_output, const at::Tensor & self); // {"schema": "aten::silu_backward(Tensor grad_output, Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor mish(const at::Tensor & self); // {"schema": "aten::mish(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & mish_(at::Tensor & self); // {"schema": "aten::mish_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & mish_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::mish.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor mish_backward(const at::Tensor & grad_output, const at::Tensor & self); // {"schema": "aten::mish_backward(Tensor grad_output, Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor sigmoid(const at::Tensor & self); // {"schema": "aten::sigmoid(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & sigmoid_(at::Tensor & self); // {"schema": "aten::sigmoid_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & sigmoid_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::sigmoid.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor logit(const at::Tensor & self, ::std::optional eps); // {"schema": "aten::logit(Tensor self, float? 
eps=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & logit_(at::Tensor & self, ::std::optional eps); // {"schema": "aten::logit_(Tensor(a!) self, float? eps=None) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & logit_out(const at::Tensor & self, ::std::optional eps, at::Tensor & out); // {"schema": "aten::logit.out(Tensor self, float? eps=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor sin(const at::Tensor & self); // {"schema": "aten::sin(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & sin_(at::Tensor & self); // {"schema": "aten::sin_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & sin_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::sin.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor sinc(const at::Tensor & self); // {"schema": "aten::sinc(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & sinc_(at::Tensor & self); // {"schema": "aten::sinc_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & sinc_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::sinc.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor sinh(const at::Tensor & self); // {"schema": "aten::sinh(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & sinh_(at::Tensor & self); // {"schema": "aten::sinh_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & sinh_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::sinh.out(Tensor self, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor detach(const at::Tensor & self); // {"schema": "aten::detach(Tensor(a) self) -> Tensor(a)", "dispatch": "True", "default": "True"} +at::Tensor & detach_(at::Tensor & self); // {"schema": "aten::detach_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +int64_t size(const at::Tensor & self, int64_t dim); // {"schema": "aten::size.int(Tensor self, int dim) -> int", "dispatch": "False", "default": "True"} +int64_t size(const at::Tensor & self, at::Dimname dim); // {"schema": "aten::size.Dimname(Tensor self, Dimname dim) -> int", "dispatch": "False", "default": "True"} +c10::SymInt sym_size(const at::Tensor & self, int64_t dim); // {"schema": "aten::sym_size.int(Tensor self, int dim) -> SymInt", "dispatch": "False", "default": "True"} +c10::SymBool sym_is_contiguous(const at::Tensor & self, at::MemoryFormat memory_format); // {"schema": "aten::sym_is_contiguous(Tensor self, MemoryFormat memory_format=contiguous_format) -> SymBool", "dispatch": "False", "default": "True"} +c10::SymInt sym_numel(const at::Tensor & self); // {"schema": "aten::sym_numel(Tensor self) -> SymInt", "dispatch": "False", "default": "True"} +c10::SymInt sym_storage_offset(const at::Tensor & self); // {"schema": "aten::sym_storage_offset(Tensor self) -> SymInt", "dispatch": "False", "default": "True"} +at::Tensor slice(const at::Tensor & self, int64_t dim, ::std::optional start, ::std::optional end, c10::SymInt step); // {"schema": "aten::slice.Tensor(Tensor(a) self, int dim=0, SymInt? start=None, SymInt? 
end=None, SymInt step=1) -> Tensor(a)", "dispatch": "True", "default": "True"} +at::Tensor slice_backward(const at::Tensor & grad_output, c10::SymIntArrayRef input_sizes, int64_t dim, c10::SymInt start, c10::SymInt end, c10::SymInt step); // {"schema": "aten::slice_backward(Tensor grad_output, SymInt[] input_sizes, int dim, SymInt start, SymInt end, SymInt step) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor slice_inverse(const at::Tensor & self, const at::Tensor & src, int64_t dim, ::std::optional start, ::std::optional end, c10::SymInt step); // {"schema": "aten::slice_inverse(Tensor(a) self, Tensor src, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor(a)", "dispatch": "True", "default": "True"} +at::Tensor slice_scatter(const at::Tensor & self, const at::Tensor & src, int64_t dim, ::std::optional start, ::std::optional end, c10::SymInt step); // {"schema": "aten::slice_scatter(Tensor self, Tensor src, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor select_scatter(const at::Tensor & self, const at::Tensor & src, int64_t dim, c10::SymInt index); // {"schema": "aten::select_scatter(Tensor self, Tensor src, int dim, SymInt index) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor diagonal_scatter(const at::Tensor & self, const at::Tensor & src, int64_t offset, int64_t dim1, int64_t dim2); // {"schema": "aten::diagonal_scatter(Tensor self, Tensor src, int offset=0, int dim1=0, int dim2=1) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor as_strided_scatter(const at::Tensor & self, const at::Tensor & src, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset); // {"schema": "aten::as_strided_scatter(Tensor self, Tensor src, SymInt[] size, SymInt[] stride, SymInt? 
storage_offset=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor smm(const at::Tensor & self, const at::Tensor & mat2); // {"schema": "aten::smm(Tensor self, Tensor mat2) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor softmax(const at::Tensor & self, int64_t dim, ::std::optional dtype); // {"schema": "aten::softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & softmax_out(const at::Tensor & self, int64_t dim, ::std::optional dtype, at::Tensor & out); // {"schema": "aten::softmax.int_out(Tensor self, int dim, ScalarType? dtype=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor softmax(const at::Tensor & self, at::Dimname dim, ::std::optional dtype); // {"schema": "aten::softmax.Dimname(Tensor self, Dimname dim, *, ScalarType? dtype=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _softmax(const at::Tensor & self, int64_t dim, bool half_to_float); // {"schema": "aten::_softmax(Tensor self, int dim, bool half_to_float) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & _softmax_out(const at::Tensor & self, int64_t dim, bool half_to_float, at::Tensor & out); // {"schema": "aten::_softmax.out(Tensor self, int dim, bool half_to_float, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor _softmax_backward_data(const at::Tensor & grad_output, const at::Tensor & output, int64_t dim, at::ScalarType input_dtype); // {"schema": "aten::_softmax_backward_data(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & _softmax_backward_data_out(const at::Tensor & grad_output, const at::Tensor & output, int64_t dim, at::ScalarType input_dtype, at::Tensor & grad_input); // {"schema": "aten::_softmax_backward_data.out(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype, *, Tensor(a!) 
grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +::std::vector unsafe_split(const at::Tensor & self, c10::SymInt split_size, int64_t dim); // {"schema": "aten::unsafe_split.Tensor(Tensor self, SymInt split_size, int dim=0) -> Tensor[]", "dispatch": "True", "default": "True"} +::std::vector split(const at::Tensor & self, c10::SymInt split_size, int64_t dim); // {"schema": "aten::split.Tensor(Tensor(a -> *) self, SymInt split_size, int dim=0) -> Tensor(a)[]", "dispatch": "True", "default": "True"} +::std::vector split(const at::Tensor & self, c10::SymIntArrayRef split_size, int64_t dim); // {"schema": "aten::split.sizes(Tensor(a -> *) self, SymInt[] split_size, int dim=0) -> Tensor(a)[]", "dispatch": "False", "default": "True"} +::std::vector unsafe_split_with_sizes(const at::Tensor & self, c10::SymIntArrayRef split_sizes, int64_t dim); // {"schema": "aten::unsafe_split_with_sizes(Tensor self, SymInt[] split_sizes, int dim=0) -> Tensor[]", "dispatch": "True", "default": "True"} +::std::vector split_with_sizes(const at::Tensor & self, c10::SymIntArrayRef split_sizes, int64_t dim); // {"schema": "aten::split_with_sizes(Tensor(a -> *) self, SymInt[] split_sizes, int dim=0) -> Tensor(a)[]", "dispatch": "True", "default": "True"} +::std::vector hsplit(const at::Tensor & self, int64_t sections); // {"schema": "aten::hsplit.int(Tensor(a -> *) self, int sections) -> Tensor(a)[]", "dispatch": "False", "default": "True"} +::std::vector hsplit(const at::Tensor & self, at::IntArrayRef indices); // {"schema": "aten::hsplit.array(Tensor(a -> *) self, int[] indices) -> Tensor(a)[]", "dispatch": "False", "default": "True"} +::std::vector vsplit(const at::Tensor & self, int64_t sections); // {"schema": "aten::vsplit.int(Tensor(a -> *) self, int sections) -> Tensor(a)[]", "dispatch": "False", "default": "True"} +::std::vector vsplit(const at::Tensor & self, at::IntArrayRef indices); // {"schema": "aten::vsplit.array(Tensor(a -> *) self, int[] indices) -> 
Tensor(a)[]", "dispatch": "False", "default": "True"} +::std::vector dsplit(const at::Tensor & self, int64_t sections); // {"schema": "aten::dsplit.int(Tensor(a -> *) self, int sections) -> Tensor(a)[]", "dispatch": "False", "default": "True"} +::std::vector dsplit(const at::Tensor & self, at::IntArrayRef indices); // {"schema": "aten::dsplit.array(Tensor(a -> *) self, int[] indices) -> Tensor(a)[]", "dispatch": "False", "default": "True"} +at::Tensor squeeze(const at::Tensor & self); // {"schema": "aten::squeeze(Tensor(a) self) -> Tensor(a)", "dispatch": "True", "default": "True"} +at::Tensor squeeze(const at::Tensor & self, int64_t dim); // {"schema": "aten::squeeze.dim(Tensor(a) self, int dim) -> Tensor(a)", "dispatch": "True", "default": "True"} +at::Tensor squeeze(const at::Tensor & self, at::Dimname dim); // {"schema": "aten::squeeze.dimname(Tensor(a) self, Dimname dim) -> Tensor(a)", "dispatch": "False", "default": "True"} +at::Tensor squeeze(const at::Tensor & self, at::IntArrayRef dim); // {"schema": "aten::squeeze.dims(Tensor(a) self, int[] dim) -> Tensor(a)", "dispatch": "True", "default": "True"} +at::Tensor & squeeze_(at::Tensor & self); // {"schema": "aten::squeeze_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & squeeze_(at::Tensor & self, int64_t dim); // {"schema": "aten::squeeze_.dim(Tensor(a!) self, int dim) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & squeeze_(at::Tensor & self, at::IntArrayRef dim); // {"schema": "aten::squeeze_.dims(Tensor(a!) self, int[] dim) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & squeeze_(at::Tensor & self, at::Dimname dim); // {"schema": "aten::squeeze_.dimname(Tensor(a!) 
self, Dimname dim) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor sspaddmm(const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta, const at::Scalar & alpha); // {"schema": "aten::sspaddmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & sspaddmm_out(const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out); // {"schema": "aten::sspaddmm.out(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor _chunk_cat(at::TensorList tensors, int64_t dim, int64_t num_chunks); // {"schema": "aten::_chunk_cat(Tensor[] tensors, int dim, int num_chunks) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & _chunk_cat_out(at::TensorList tensors, int64_t dim, int64_t num_chunks, at::Tensor & out); // {"schema": "aten::_chunk_cat.out(Tensor[] tensors, int dim, int num_chunks, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor stack(at::TensorList tensors, int64_t dim); // {"schema": "aten::stack(Tensor[] tensors, int dim=0) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & stack_out(at::TensorList tensors, int64_t dim, at::Tensor & out); // {"schema": "aten::stack.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor _stack(at::TensorList tensors, int64_t dim); // {"schema": "aten::_stack(Tensor[] tensors, int dim=0) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & _stack_out(at::TensorList tensors, int64_t dim, at::Tensor & out); // {"schema": "aten::_stack.out(Tensor[] tensors, int dim=0, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor hstack(at::TensorList tensors); // {"schema": "aten::hstack(Tensor[] tensors) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & hstack_out(at::TensorList tensors, at::Tensor & out); // {"schema": "aten::hstack.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor vstack(at::TensorList tensors); // {"schema": "aten::vstack(Tensor[] tensors) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & vstack_out(at::TensorList tensors, at::Tensor & out); // {"schema": "aten::vstack.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor dstack(at::TensorList tensors); // {"schema": "aten::dstack(Tensor[] tensors) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & dstack_out(at::TensorList tensors, at::Tensor & out); // {"schema": "aten::dstack.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor stft(const at::Tensor & self, int64_t n_fft, ::std::optional hop_length, ::std::optional win_length, const ::std::optional & window, bool normalized, ::std::optional onesided, ::std::optional return_complex, ::std::optional align_to_window); // {"schema": "aten::stft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool normalized=False, bool? onesided=None, bool? return_complex=None, bool? align_to_window=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor stft(const at::Tensor & self, int64_t n_fft, ::std::optional hop_length, ::std::optional win_length, const ::std::optional & window, bool center, c10::string_view pad_mode, bool normalized, ::std::optional onesided, ::std::optional return_complex, ::std::optional align_to_window); // {"schema": "aten::stft.center(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? 
window=None, bool center=True, str pad_mode=\"reflect\", bool normalized=False, bool? onesided=None, bool? return_complex=None, bool? align_to_window=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor istft(const at::Tensor & self, int64_t n_fft, ::std::optional hop_length, ::std::optional win_length, const ::std::optional & window, bool center, bool normalized, ::std::optional onesided, ::std::optional length, bool return_complex); // {"schema": "aten::istft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool center=True, bool normalized=False, bool? onesided=None, int? length=None, bool return_complex=False) -> Tensor", "dispatch": "False", "default": "True"} +int64_t stride(const at::Tensor & self, int64_t dim); // {"schema": "aten::stride.int(Tensor self, int dim) -> int", "dispatch": "False", "default": "True"} +int64_t stride(const at::Tensor & self, at::Dimname dim); // {"schema": "aten::stride.Dimname(Tensor self, Dimname dim) -> int", "dispatch": "False", "default": "True"} +c10::SymInt sym_stride(const at::Tensor & self, int64_t dim); // {"schema": "aten::sym_stride.int(Tensor self, int dim) -> SymInt", "dispatch": "False", "default": "True"} +at::Tensor sum(const at::Tensor & self, ::std::optional dtype); // {"schema": "aten::sum(Tensor self, *, ScalarType? dtype=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor sum(const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim, ::std::optional dtype); // {"schema": "aten::sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor sum(const at::Tensor & self, at::DimnameList dim, bool keepdim, ::std::optional dtype); // {"schema": "aten::sum.dim_DimnameList(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? 
dtype=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & sum_out(const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim, ::std::optional dtype, at::Tensor & out); // {"schema": "aten::sum.IntList_out(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & sum_out(const at::Tensor & self, at::DimnameList dim, bool keepdim, ::std::optional dtype, at::Tensor & out); // {"schema": "aten::sum.DimnameList_out(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor _nested_sum_backward(const at::Tensor & grad, const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim); // {"schema": "aten::_nested_sum_backward(Tensor grad, Tensor self, int[1]? dim, bool keepdim=False) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor nansum(const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim, ::std::optional dtype); // {"schema": "aten::nansum(Tensor self, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & nansum_out(const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim, ::std::optional dtype, at::Tensor & out); // {"schema": "aten::nansum.out(Tensor self, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor hash_tensor(const at::Tensor & self, at::IntArrayRef dim, bool keepdim, int64_t mode); // {"schema": "aten::hash_tensor(Tensor self, int[1] dim=[], *, bool keepdim=False, int mode=0) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & hash_tensor_out(const at::Tensor & self, at::IntArrayRef dim, bool keepdim, int64_t mode, at::Tensor & out); // {"schema": "aten::hash_tensor.out(Tensor self, int[1] dim=[], *, bool keepdim=False, int mode=0, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor sum_to_size(const at::Tensor & self, c10::SymIntArrayRef size); // {"schema": "aten::sum_to_size(Tensor self, SymInt[] size) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor sqrt(const at::Tensor & self); // {"schema": "aten::sqrt(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & sqrt_(at::Tensor & self); // {"schema": "aten::sqrt_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & sqrt_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::sqrt.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor square(const at::Tensor & self); // {"schema": "aten::square(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & square_(at::Tensor & self); // {"schema": "aten::square_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & square_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::square.out(Tensor self, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor std(const at::Tensor & self, bool unbiased); // {"schema": "aten::std(Tensor self, bool unbiased=True) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor std(const at::Tensor & self, at::OptionalIntArrayRef dim, bool unbiased, bool keepdim); // {"schema": "aten::std.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor std(const at::Tensor & self, at::OptionalIntArrayRef dim, const ::std::optional & correction, bool keepdim); // {"schema": "aten::std.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> Tensor", "dispatch": "True", "default": "False"} +::std::tuple std_mean(const at::Tensor & self, bool unbiased); // {"schema": "aten::std_mean(Tensor self, bool unbiased=True) -> (Tensor, Tensor)", "dispatch": "False", "default": "True"} +::std::tuple std_mean(const at::Tensor & self, at::OptionalIntArrayRef dim, bool unbiased, bool keepdim); // {"schema": "aten::std_mean.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> (Tensor, Tensor)", "dispatch": "False", "default": "True"} +::std::tuple std_mean(const at::Tensor & self, at::OptionalIntArrayRef dim, const ::std::optional & correction, bool keepdim); // {"schema": "aten::std_mean.correction(Tensor self, int[1]? dim=None, *, Scalar? 
correction=None, bool keepdim=False) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"} +::std::tuple std_mean(const at::Tensor & self, at::DimnameList dim, bool unbiased, bool keepdim); // {"schema": "aten::std_mean.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False) -> (Tensor, Tensor)", "dispatch": "False", "default": "True"} +::std::tuple std_mean(const at::Tensor & self, at::DimnameList dim, const ::std::optional & correction, bool keepdim); // {"schema": "aten::std_mean.correction_names(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False) -> (Tensor, Tensor)", "dispatch": "False", "default": "True"} +at::Tensor & std_out(const at::Tensor & self, at::OptionalIntArrayRef dim, bool unbiased, bool keepdim, at::Tensor & out); // {"schema": "aten::std.out(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & std_out(const at::Tensor & self, at::OptionalIntArrayRef dim, const ::std::optional & correction, bool keepdim, at::Tensor & out); // {"schema": "aten::std.correction_out(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor std(const at::Tensor & self, at::DimnameList dim, bool unbiased, bool keepdim); // {"schema": "aten::std.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & std_out(const at::Tensor & self, at::DimnameList dim, bool unbiased, bool keepdim, at::Tensor & out); // {"schema": "aten::std.names_out(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor std(const at::Tensor & self, at::DimnameList dim, const ::std::optional & correction, bool keepdim); // {"schema": "aten::std.correction_names(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & std_out(const at::Tensor & self, at::DimnameList dim, const ::std::optional & correction, bool keepdim, at::Tensor & out); // {"schema": "aten::std.correction_names_out(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor prod(const at::Tensor & self, ::std::optional dtype); // {"schema": "aten::prod(Tensor self, *, ScalarType? dtype=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor prod(const at::Tensor & self, int64_t dim, bool keepdim, ::std::optional dtype); // {"schema": "aten::prod.dim_int(Tensor self, int dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & prod_out(const at::Tensor & self, int64_t dim, bool keepdim, ::std::optional dtype, at::Tensor & out); // {"schema": "aten::prod.int_out(Tensor self, int dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor prod(const at::Tensor & self, at::Dimname dim, bool keepdim, ::std::optional dtype); // {"schema": "aten::prod.dim_Dimname(Tensor self, Dimname dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & prod_out(const at::Tensor & self, at::Dimname dim, bool keepdim, ::std::optional dtype, at::Tensor & out); // {"schema": "aten::prod.Dimname_out(Tensor self, Dimname dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor t(const at::Tensor & self); // {"schema": "aten::t(Tensor(a) self) -> Tensor(a)", "dispatch": "True", "default": "True"} +at::Tensor & t_(at::Tensor & self); // {"schema": "aten::t_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor tan(const at::Tensor & self); // {"schema": "aten::tan(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & tan_(at::Tensor & self); // {"schema": "aten::tan_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & tan_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::tan.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor tanh(const at::Tensor & self); // {"schema": "aten::tanh(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & tanh_(at::Tensor & self); // {"schema": "aten::tanh_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & tanh_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::tanh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor tensordot(const at::Tensor & self, const at::Tensor & other, at::IntArrayRef dims_self, at::IntArrayRef dims_other); // {"schema": "aten::tensordot(Tensor self, Tensor other, int[] dims_self, int[] dims_other) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & tensordot_out(const at::Tensor & self, const at::Tensor & other, at::IntArrayRef dims_self, at::IntArrayRef dims_other, at::Tensor & out); // {"schema": "aten::tensordot.out(Tensor self, Tensor other, int[] dims_self, int[] dims_other, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor threshold(const at::Tensor & self, const at::Scalar & threshold, const at::Scalar & value); // {"schema": "aten::threshold(Tensor self, Scalar threshold, Scalar value) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & threshold_(at::Tensor & self, const at::Scalar & threshold, const at::Scalar & value); // {"schema": "aten::threshold_(Tensor(a!) self, Scalar threshold, Scalar value) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & threshold_out(const at::Tensor & self, const at::Scalar & threshold, const at::Scalar & value, at::Tensor & out); // {"schema": "aten::threshold.out(Tensor self, Scalar threshold, Scalar value, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & threshold_backward_out(const at::Tensor & grad_output, const at::Tensor & self, const at::Scalar & threshold, at::Tensor & grad_input); // {"schema": "aten::threshold_backward.grad_input(Tensor grad_output, Tensor self, Scalar threshold, *, Tensor(a!) 
grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor threshold_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Scalar & threshold); // {"schema": "aten::threshold_backward(Tensor grad_output, Tensor self, Scalar threshold) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor tile(const at::Tensor & self, c10::SymIntArrayRef dims); // {"schema": "aten::tile(Tensor self, SymInt[] dims) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor transpose(const at::Tensor & self, int64_t dim0, int64_t dim1); // {"schema": "aten::transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a)", "dispatch": "True", "default": "True"} +at::Tensor transpose(const at::Tensor & self, at::Dimname dim0, at::Dimname dim1); // {"schema": "aten::transpose.Dimname(Tensor(a) self, Dimname dim0, Dimname dim1) -> Tensor(a)", "dispatch": "False", "default": "True"} +at::Tensor _mkldnn_transpose(const at::Tensor & self, int64_t dim0, int64_t dim1); // {"schema": "aten::_mkldnn_transpose(Tensor self, int dim0, int dim1) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & transpose_(at::Tensor & self, int64_t dim0, int64_t dim1); // {"schema": "aten::transpose_(Tensor(a!) self, int dim0, int dim1) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _mkldnn_transpose_(at::Tensor & self, int64_t dim0, int64_t dim1); // {"schema": "aten::_mkldnn_transpose_(Tensor(a!) 
self, int dim0, int dim1) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor one_hot(const at::Tensor & self, int64_t num_classes); // {"schema": "aten::one_hot(Tensor self, int num_classes=-1) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor flip(const at::Tensor & self, at::IntArrayRef dims); // {"schema": "aten::flip(Tensor self, int[] dims) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor fliplr(const at::Tensor & self); // {"schema": "aten::fliplr(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor flipud(const at::Tensor & self); // {"schema": "aten::flipud(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor roll(const at::Tensor & self, c10::SymIntArrayRef shifts, at::IntArrayRef dims); // {"schema": "aten::roll(Tensor self, SymInt[1] shifts, int[1] dims=[]) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor rot90(const at::Tensor & self, int64_t k, at::IntArrayRef dims); // {"schema": "aten::rot90(Tensor self, int k=1, int[] dims=[0,1]) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor trapezoid(const at::Tensor & y, const at::Tensor & x, int64_t dim); // {"schema": "aten::trapezoid.x(Tensor y, Tensor x, *, int dim=-1) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor trapezoid(const at::Tensor & y, const at::Scalar & dx, int64_t dim); // {"schema": "aten::trapezoid.dx(Tensor y, *, Scalar dx=1, int dim=-1) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor trapz(const at::Tensor & y, const at::Tensor & x, int64_t dim); // {"schema": "aten::trapz.x(Tensor y, Tensor x, *, int dim=-1) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor trapz(const at::Tensor & y, double dx, int64_t dim); // {"schema": "aten::trapz.dx(Tensor y, *, float dx=1, int dim=-1) -> Tensor", "dispatch": "False", "default": "True"} +::std::tuple _transform_bias_rescale_qkv(const at::Tensor & qkv, const at::Tensor & qkv_bias, 
int64_t num_heads); // {"schema": "aten::_transform_bias_rescale_qkv(Tensor qkv, Tensor qkv_bias, int num_heads) -> (Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"} +at::Tensor _nested_tensor_from_mask(const at::Tensor & t, const at::Tensor & mask, bool mask_check); // {"schema": "aten::_nested_tensor_from_mask(Tensor t, Tensor mask, bool mask_check=True) -> Tensor", "dispatch": "True", "default": "False"} +bool _nested_tensor_from_mask_left_aligned(const at::Tensor & t, const at::Tensor & mask); // {"schema": "aten::_nested_tensor_from_mask_left_aligned(Tensor t, Tensor mask) -> bool", "dispatch": "True", "default": "False"} +at::Tensor _nested_from_padded(const at::Tensor & padded, const at::Tensor & cpu_nested_shape_example, bool fuse_transform_0213); // {"schema": "aten::_nested_from_padded(Tensor padded, Tensor cpu_nested_shape_example, bool fuse_transform_0213=False) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _nested_tensor_size(const at::Tensor & self); // {"schema": "aten::_nested_tensor_size(Tensor self) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _nested_tensor_strides(const at::Tensor & self); // {"schema": "aten::_nested_tensor_strides(Tensor self) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _nested_tensor_storage_offsets(const at::Tensor & self); // {"schema": "aten::_nested_tensor_storage_offsets(Tensor self) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _nested_from_padded_and_nested_example(const at::Tensor & padded, const at::Tensor & nt_example); // {"schema": "aten::_nested_from_padded_and_nested_example(Tensor padded, Tensor nt_example) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _nested_view_from_buffer(const at::Tensor & self, const at::Tensor & nested_size, const at::Tensor & nested_strides, const at::Tensor & offsets); // {"schema": "aten::_nested_view_from_buffer(Tensor(a) self, Tensor nested_size, Tensor nested_strides, 
Tensor offsets) -> Tensor(a)", "dispatch": "True", "default": "False"} +at::Tensor _nested_view_from_buffer_copy(const at::Tensor & self, const at::Tensor & nested_size, const at::Tensor & nested_strides, const at::Tensor & offsets); // {"schema": "aten::_nested_view_from_buffer_copy(Tensor self, Tensor nested_size, Tensor nested_strides, Tensor offsets) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor _nested_view_from_jagged(const at::Tensor & self, const at::Tensor & offsets, const at::Tensor & dummy, const ::std::optional & lengths, int64_t ragged_idx, const ::std::optional & min_seqlen, const ::std::optional & max_seqlen); // {"schema": "aten::_nested_view_from_jagged(Tensor(a) self, Tensor offsets, Tensor dummy, Tensor? lengths=None, int ragged_idx=1, Tensor? min_seqlen=None, Tensor? max_seqlen=None) -> Tensor(a)", "dispatch": "True", "default": "False"} +at::Tensor _nested_view_from_jagged_copy(const at::Tensor & self, const at::Tensor & offsets, const at::Tensor & dummy, const ::std::optional & lengths, int64_t ragged_idx, const ::std::optional & min_seqlen, const ::std::optional & max_seqlen); // {"schema": "aten::_nested_view_from_jagged_copy(Tensor self, Tensor offsets, Tensor dummy, Tensor? lengths=None, int ragged_idx=1, Tensor? min_seqlen=None, Tensor? 
max_seqlen=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor _nested_get_values(const at::Tensor & self); // {"schema": "aten::_nested_get_values(Tensor(a) self) -> Tensor(a)", "dispatch": "True", "default": "False"} +at::Tensor _nested_get_values_copy(const at::Tensor & self); // {"schema": "aten::_nested_get_values_copy(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor _nested_get_offsets(const at::Tensor & self); // {"schema": "aten::_nested_get_offsets(Tensor self) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _nested_get_lengths(const at::Tensor & self); // {"schema": "aten::_nested_get_lengths(Tensor self) -> Tensor", "dispatch": "True", "default": "False"} +int64_t _nested_get_ragged_idx(const at::Tensor & self); // {"schema": "aten::_nested_get_ragged_idx(Tensor self) -> int", "dispatch": "True", "default": "False"} +at::Tensor _nested_get_min_seqlen(const at::Tensor & self); // {"schema": "aten::_nested_get_min_seqlen(Tensor self) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _nested_get_max_seqlen(const at::Tensor & self); // {"schema": "aten::_nested_get_max_seqlen(Tensor self) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _nested_get_jagged_dummy(const at::Tensor & any); // {"schema": "aten::_nested_get_jagged_dummy(Tensor any) -> Tensor", "dispatch": "True", "default": "False"} +::std::tuple _nested_compute_contiguous_strides_offsets(const at::Tensor & nested_size); // {"schema": "aten::_nested_compute_contiguous_strides_offsets(Tensor nested_size) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"} +at::Tensor _trilinear(const at::Tensor & i1, const at::Tensor & i2, const at::Tensor & i3, at::IntArrayRef expand1, at::IntArrayRef expand2, at::IntArrayRef expand3, at::IntArrayRef sumdim, int64_t unroll_dim); // {"schema": "aten::_trilinear(Tensor i1, Tensor i2, Tensor i3, int[] expand1, int[] expand2, int[] expand3, int[] sumdim, int 
unroll_dim=1) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor triplet_margin_loss(const at::Tensor & anchor, const at::Tensor & positive, const at::Tensor & negative, double margin, double p, double eps, bool swap, int64_t reduction); // {"schema": "aten::triplet_margin_loss(Tensor anchor, Tensor positive, Tensor negative, float margin=1.0, float p=2, float eps=1e-06, bool swap=False, int reduction=Mean) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor trunc(const at::Tensor & self); // {"schema": "aten::trunc(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & trunc_(at::Tensor & self); // {"schema": "aten::trunc_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & trunc_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::trunc.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor fix(const at::Tensor & self); // {"schema": "aten::fix(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & fix_(at::Tensor & self); // {"schema": "aten::fix_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & fix_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::fix.out(Tensor self, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor type_as(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::type_as(Tensor self, Tensor other) -> Tensor", "dispatch": "False", "default": "True"} +bool _has_compatible_shallow_copy_type(const at::Tensor & self, const at::Tensor & from); // {"schema": "aten::_has_compatible_shallow_copy_type(Tensor self, Tensor from) -> bool", "dispatch": "False", "default": "True"} +::std::tuple _unique(const at::Tensor & self, bool sorted, bool return_inverse); // {"schema": "aten::_unique(Tensor self, bool sorted=True, bool return_inverse=False) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"} +::std::tuple unique_dim(const at::Tensor & self, int64_t dim, bool sorted, bool return_inverse, bool return_counts); // {"schema": "aten::unique_dim(Tensor self, int dim, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"} +::std::tuple unique_consecutive(const at::Tensor & self, bool return_inverse, bool return_counts, ::std::optional dim); // {"schema": "aten::unique_consecutive(Tensor self, bool return_inverse=False, bool return_counts=False, int? 
dim=None) -> (Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"} +::std::tuple unique_dim_consecutive(const at::Tensor & self, int64_t dim, bool return_inverse, bool return_counts); // {"schema": "aten::unique_dim_consecutive(Tensor self, int dim, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"} +::std::tuple _unique2(const at::Tensor & self, bool sorted, bool return_inverse, bool return_counts); // {"schema": "aten::_unique2(Tensor self, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"} +at::Tensor _unsafe_view(const at::Tensor & self, c10::SymIntArrayRef size); // {"schema": "aten::_unsafe_view(Tensor self, SymInt[] size) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor unsqueeze(const at::Tensor & self, int64_t dim); // {"schema": "aten::unsqueeze(Tensor(a) self, int dim) -> Tensor(a)", "dispatch": "True", "default": "True"} +at::Tensor & unsqueeze_(at::Tensor & self, int64_t dim); // {"schema": "aten::unsqueeze_(Tensor(a!) self, int dim) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor vander(const at::Tensor & x, ::std::optional N, bool increasing); // {"schema": "aten::vander(Tensor x, int? N=None, bool increasing=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor var(const at::Tensor & self, bool unbiased); // {"schema": "aten::var(Tensor self, bool unbiased=True) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor var(const at::Tensor & self, at::OptionalIntArrayRef dim, bool unbiased, bool keepdim); // {"schema": "aten::var.dim(Tensor self, int[1]? 
dim, bool unbiased=True, bool keepdim=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor var(const at::Tensor & self, at::OptionalIntArrayRef dim, const ::std::optional & correction, bool keepdim); // {"schema": "aten::var.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & var_out(const at::Tensor & self, at::OptionalIntArrayRef dim, bool unbiased, bool keepdim, at::Tensor & out); // {"schema": "aten::var.out(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & var_out(const at::Tensor & self, at::OptionalIntArrayRef dim, const ::std::optional & correction, bool keepdim, at::Tensor & out); // {"schema": "aten::var.correction_out(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor var(const at::Tensor & self, at::DimnameList dim, bool unbiased, bool keepdim); // {"schema": "aten::var.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & var_out(const at::Tensor & self, at::DimnameList dim, bool unbiased, bool keepdim, at::Tensor & out); // {"schema": "aten::var.names_out(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor var(const at::Tensor & self, at::DimnameList dim, const ::std::optional & correction, bool keepdim); // {"schema": "aten::var.correction_names(Tensor self, Dimname[1] dim, *, Scalar? 
correction=None, bool keepdim=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & var_out(const at::Tensor & self, at::DimnameList dim, const ::std::optional & correction, bool keepdim, at::Tensor & out); // {"schema": "aten::var.correction_names_out(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +::std::tuple var_mean(const at::Tensor & self, bool unbiased); // {"schema": "aten::var_mean(Tensor self, bool unbiased=True) -> (Tensor, Tensor)", "dispatch": "False", "default": "True"} +::std::tuple var_mean(const at::Tensor & self, at::OptionalIntArrayRef dim, bool unbiased, bool keepdim); // {"schema": "aten::var_mean.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> (Tensor, Tensor)", "dispatch": "False", "default": "True"} +::std::tuple var_mean(const at::Tensor & self, at::OptionalIntArrayRef dim, const ::std::optional & correction, bool keepdim); // {"schema": "aten::var_mean.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"} +::std::tuple var_mean(const at::Tensor & self, at::DimnameList dim, bool unbiased, bool keepdim); // {"schema": "aten::var_mean.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False) -> (Tensor, Tensor)", "dispatch": "False", "default": "True"} +::std::tuple var_mean(const at::Tensor & self, at::DimnameList dim, const ::std::optional & correction, bool keepdim); // {"schema": "aten::var_mean.correction_names(Tensor self, Dimname[1] dim, *, Scalar? 
correction=None, bool keepdim=False) -> (Tensor, Tensor)", "dispatch": "False", "default": "True"} +at::Tensor view_as(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::view_as(Tensor(a) self, Tensor other) -> Tensor(a)", "dispatch": "False", "default": "True"} +at::Tensor where(const at::Tensor & condition, const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::where.self(Tensor condition, Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & where_out(const at::Tensor & condition, const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::where.self_out(Tensor condition, Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor where(const at::Tensor & condition, const at::Scalar & self, const at::Tensor & other); // {"schema": "aten::where.ScalarSelf(Tensor condition, Scalar self, Tensor other) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor where(const at::Tensor & condition, const at::Tensor & self, const at::Scalar & other); // {"schema": "aten::where.ScalarOther(Tensor condition, Tensor self, Scalar other) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor where(const at::Tensor & condition, const at::Scalar & self, const at::Scalar & other); // {"schema": "aten::where.Scalar(Tensor condition, Scalar self, Scalar other) -> Tensor", "dispatch": "False", "default": "True"} +::std::vector where(const at::Tensor & condition); // {"schema": "aten::where(Tensor condition) -> Tensor[]", "dispatch": "False", "default": "True"} +at::Tensor norm_except_dim(const at::Tensor & v, int64_t pow, int64_t dim); // {"schema": "aten::norm_except_dim(Tensor v, int pow=2, int dim=0) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _weight_norm(const at::Tensor & v, const at::Tensor & g, int64_t dim); // {"schema": "aten::_weight_norm(Tensor v, Tensor g, int dim=0) -> 
Tensor", "dispatch": "False", "default": "True"} +::std::tuple _weight_norm_interface(const at::Tensor & v, const at::Tensor & g, int64_t dim); // {"schema": "aten::_weight_norm_interface(Tensor v, Tensor g, int dim=0) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"} +::std::tuple _weight_norm_interface_backward(const at::Tensor & grad_w, const at::Tensor & saved_v, const at::Tensor & saved_g, const at::Tensor & saved_norms, int64_t dim); // {"schema": "aten::_weight_norm_interface_backward(Tensor grad_w, Tensor saved_v, Tensor saved_g, Tensor saved_norms, int dim) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"} +::std::tuple _weight_norm_differentiable_backward(const at::Tensor & grad_w, const at::Tensor & saved_v, const at::Tensor & saved_g, const at::Tensor & saved_norms, int64_t dim); // {"schema": "aten::_weight_norm_differentiable_backward(Tensor grad_w, Tensor saved_v, Tensor saved_g, Tensor saved_norms, int dim) -> (Tensor, Tensor)", "dispatch": "False", "default": "True"} +at::Tensor zeros(at::IntArrayRef size, ::std::optional names, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::zeros.names(int[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor _efficientzerotensor(c10::SymIntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::_efficientzerotensor(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor zeros(c10::SymIntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::zeros(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? 
device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & zeros_out(c10::SymIntArrayRef size, at::Tensor & out); // {"schema": "aten::zeros.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor zeros_like(const at::Tensor & self, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format); // {"schema": "aten::zeros_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor _standard_gamma_grad(const at::Tensor & self, const at::Tensor & output); // {"schema": "aten::_standard_gamma_grad(Tensor self, Tensor output) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _standard_gamma(const at::Tensor & self, ::std::optional generator); // {"schema": "aten::_standard_gamma(Tensor self, Generator? generator=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _dirichlet_grad(const at::Tensor & x, const at::Tensor & alpha, const at::Tensor & total); // {"schema": "aten::_dirichlet_grad(Tensor x, Tensor alpha, Tensor total) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _sample_dirichlet(const at::Tensor & self, ::std::optional generator); // {"schema": "aten::_sample_dirichlet(Tensor self, Generator? generator=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor poisson(const at::Tensor & self, ::std::optional generator); // {"schema": "aten::poisson(Tensor self, Generator? generator=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor binomial(const at::Tensor & count, const at::Tensor & prob, ::std::optional generator); // {"schema": "aten::binomial(Tensor count, Tensor prob, Generator? 
generator=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor native_norm(const at::Tensor & self, const at::Scalar & p); // {"schema": "aten::native_norm(Tensor self, Scalar p=2) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor native_norm(const at::Tensor & self, const ::std::optional & p, at::IntArrayRef dim, bool keepdim, ::std::optional dtype); // {"schema": "aten::native_norm.ScalarOpt_dim_dtype(Tensor self, Scalar? p, int[1] dim, bool keepdim, ScalarType? dtype) -> Tensor", "dispatch": "True", "default": "False"} +::std::tuple _batch_norm_with_update(const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, at::Tensor & running_mean, at::Tensor & running_var, double momentum, double eps); // {"schema": "aten::_batch_norm_with_update(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"} +::std::tuple _batch_norm_with_update_out(const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, at::Tensor & running_mean, at::Tensor & running_var, double momentum, double eps, at::Tensor & out, at::Tensor & save_mean, at::Tensor & save_invstd, at::Tensor & reserve); // {"schema": "aten::_batch_norm_with_update.out(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, float momentum, float eps, *, Tensor(d!) out, Tensor(e!) save_mean, Tensor(f!) save_invstd, Tensor(g!) reserve) -> (Tensor(d!), Tensor(e!), Tensor(f!), Tensor(g!))", "dispatch": "True", "default": "False"} +::std::tuple _batch_norm_no_update(const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const ::std::optional & running_mean, const ::std::optional & running_var, double momentum, double eps); // {"schema": "aten::_batch_norm_no_update(Tensor input, Tensor? weight, Tensor? bias, Tensor? 
running_mean, Tensor? running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor, Tensor)", "dispatch": "True", "default": "True"} +::std::tuple batch_norm_backward(const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & running_mean, const ::std::optional & running_var, const ::std::optional & save_mean, const ::std::optional & save_var, bool update, double eps, ::std::array output_mask, const at::Tensor & reserve); // {"schema": "aten::batch_norm_backward(Tensor grad_out, Tensor input, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, bool update, float eps, bool[3] output_mask, Tensor reserve) -> (Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"} +at::Tensor _sparse_sum(const at::Tensor & self); // {"schema": "aten::_sparse_sum(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _sparse_sum(const at::Tensor & self, at::ScalarType dtype); // {"schema": "aten::_sparse_sum.dtype(Tensor self, *, ScalarType dtype) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _sparse_sum(const at::Tensor & self, at::IntArrayRef dim); // {"schema": "aten::_sparse_sum.dim(Tensor self, int[1] dim) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor _sparse_sum(const at::Tensor & self, at::IntArrayRef dim, at::ScalarType dtype); // {"schema": "aten::_sparse_sum.dim_dtype(Tensor self, int[1] dim, *, ScalarType dtype) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _sparse_sum_backward(const at::Tensor & grad, const at::Tensor & self, at::IntArrayRef dim); // {"schema": "aten::_sparse_sum_backward(Tensor grad, Tensor self, int[] dim) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _sparse_csr_sum(const at::Tensor & self, at::IntArrayRef dim, bool keepdim, ::std::optional dtype); // {"schema": "aten::_sparse_csr_sum.dim_dtype(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? 
dtype=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _sparse_csr_prod(const at::Tensor & self, at::IntArrayRef dim, bool keepdim, ::std::optional dtype); // {"schema": "aten::_sparse_csr_prod.dim_dtype(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _sparse_softmax(const at::Tensor & self, int64_t dim, ::std::optional dtype); // {"schema": "aten::_sparse_softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _sparse_softmax(const at::Tensor & self, at::Dimname dim, ::std::optional dtype); // {"schema": "aten::_sparse_softmax.Dimname(Tensor self, Dimname dim, *, ScalarType? dtype=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _sparse_softmax(const at::Tensor & self, int64_t dim, bool half_to_float); // {"schema": "aten::_sparse_softmax(Tensor self, int dim, bool half_to_float) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _sparse_softmax_backward_data(const at::Tensor & grad_output, const at::Tensor & output, int64_t dim, const at::Tensor & self); // {"schema": "aten::_sparse_softmax_backward_data(Tensor grad_output, Tensor output, int dim, Tensor self) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _sparse_log_softmax(const at::Tensor & self, int64_t dim, ::std::optional dtype); // {"schema": "aten::_sparse_log_softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _sparse_log_softmax(const at::Tensor & self, at::Dimname dim, ::std::optional dtype); // {"schema": "aten::_sparse_log_softmax.Dimname(Tensor self, Dimname dim, *, ScalarType? 
dtype=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _sparse_log_softmax(const at::Tensor & self, int64_t dim, bool half_to_float); // {"schema": "aten::_sparse_log_softmax(Tensor self, int dim, bool half_to_float) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _sparse_log_softmax_backward_data(const at::Tensor & grad_output, const at::Tensor & output, int64_t dim, const at::Tensor & self); // {"schema": "aten::_sparse_log_softmax_backward_data(Tensor grad_output, Tensor output, int dim, Tensor self) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _spdiags(const at::Tensor & diagonals, const at::Tensor & offsets, at::IntArrayRef shape, ::std::optional layout); // {"schema": "aten::_spdiags(Tensor diagonals, Tensor offsets, int[] shape, Layout? layout=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor norm(const at::Tensor & self, const ::std::optional & p, at::ScalarType dtype); // {"schema": "aten::norm.ScalarOpt_dtype(Tensor self, Scalar? p, *, ScalarType dtype) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor norm(const at::Tensor & self, const at::Scalar & p); // {"schema": "aten::norm.Scalar(Tensor self, Scalar p=2) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor norm(const at::Tensor & self, const ::std::optional & p, at::IntArrayRef dim, bool keepdim, at::ScalarType dtype); // {"schema": "aten::norm.ScalarOpt_dim_dtype(Tensor self, Scalar? p, int[1] dim, bool keepdim, *, ScalarType dtype) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor norm(const at::Tensor & self, const ::std::optional & p, at::IntArrayRef dim, bool keepdim); // {"schema": "aten::norm.ScalarOpt_dim(Tensor self, Scalar? 
p, int[1] dim, bool keepdim=False) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & norm_out(const at::Tensor & self, const ::std::optional & p, at::IntArrayRef dim, bool keepdim, at::ScalarType dtype, at::Tensor & out); // {"schema": "aten::norm.dtype_out(Tensor self, Scalar? p, int[1] dim, bool keepdim, *, ScalarType dtype, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & norm_out(const at::Tensor & self, const ::std::optional & p, at::IntArrayRef dim, bool keepdim, at::Tensor & out); // {"schema": "aten::norm.out(Tensor self, Scalar? p, int[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor norm(const at::Tensor & self, const ::std::optional & p, at::DimnameList dim, bool keepdim, at::ScalarType dtype); // {"schema": "aten::norm.names_ScalarOpt_dim_dtype(Tensor self, Scalar? p, Dimname[1] dim, bool keepdim, *, ScalarType dtype) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor norm(const at::Tensor & self, const ::std::optional & p, at::DimnameList dim, bool keepdim); // {"schema": "aten::norm.names_ScalarOpt_dim(Tensor self, Scalar? p, Dimname[1] dim, bool keepdim=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & norm_out(const at::Tensor & self, const ::std::optional & p, at::DimnameList dim, bool keepdim, at::ScalarType dtype, at::Tensor & out); // {"schema": "aten::norm.names_dtype_out(Tensor self, Scalar? p, Dimname[1] dim, bool keepdim, *, ScalarType dtype, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & norm_out(const at::Tensor & self, const ::std::optional & p, at::DimnameList dim, bool keepdim, at::Tensor & out); // {"schema": "aten::norm.names_out(Tensor self, Scalar? p, Dimname[1] dim, bool keepdim=False, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +::std::tuple frexp(const at::Tensor & self); // {"schema": "aten::frexp.Tensor(Tensor self) -> (Tensor mantissa, Tensor exponent)", "dispatch": "True", "default": "True"} +::std::tuple frexp_out(const at::Tensor & self, at::Tensor & mantissa, at::Tensor & exponent); // {"schema": "aten::frexp.Tensor_out(Tensor self, *, Tensor(a!) mantissa, Tensor(b!) exponent) -> (Tensor(a!) mantissa, Tensor(b!) exponent)", "dispatch": "True", "default": "False"} +at::Tensor frobenius_norm(const at::Tensor & self, at::IntArrayRef dim, bool keepdim); // {"schema": "aten::frobenius_norm.dim(Tensor self, int[1] dim, bool keepdim=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & frobenius_norm_out(const at::Tensor & self, at::IntArrayRef dim, bool keepdim, at::Tensor & out); // {"schema": "aten::frobenius_norm.out(Tensor self, int[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor nuclear_norm(const at::Tensor & self, bool keepdim); // {"schema": "aten::nuclear_norm(Tensor self, bool keepdim=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & nuclear_norm_out(const at::Tensor & self, bool keepdim, at::Tensor & out); // {"schema": "aten::nuclear_norm.out(Tensor self, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor nuclear_norm(const at::Tensor & self, at::IntArrayRef dim, bool keepdim); // {"schema": "aten::nuclear_norm.dim(Tensor self, int[2] dim, bool keepdim=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & nuclear_norm_out(const at::Tensor & self, at::IntArrayRef dim, bool keepdim, at::Tensor & out); // {"schema": "aten::nuclear_norm.dim_out(Tensor self, int[2] dim, bool keepdim=False, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor clone(const at::Tensor & self, ::std::optional memory_format); // {"schema": "aten::clone(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor positive(const at::Tensor & self); // {"schema": "aten::positive(Tensor(a) self) -> Tensor(a)", "dispatch": "False", "default": "True"} +const at::Tensor & resize_as_(const at::Tensor & self, const at::Tensor & the_template, ::std::optional memory_format); // {"schema": "aten::resize_as_(Tensor(a!) self, Tensor the_template, *, MemoryFormat? memory_format=None) -> Tensor(a!)", "dispatch": "True", "default": "True"} +const at::Tensor & resize_as_sparse_(const at::Tensor & self, const at::Tensor & the_template); // {"schema": "aten::resize_as_sparse_(Tensor(a!) self, Tensor the_template) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & zero_(at::Tensor & self); // {"schema": "aten::zero_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & sub_out(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha, at::Tensor & out); // {"schema": "aten::sub.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor sub(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha); // {"schema": "aten::sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & sub_(at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha); // {"schema": "aten::sub_.Tensor(Tensor(a!) 
self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor sub(const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha); // {"schema": "aten::sub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & sub_(at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha); // {"schema": "aten::sub_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & subtract_out(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha, at::Tensor & out); // {"schema": "aten::subtract.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor subtract(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha); // {"schema": "aten::subtract.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & subtract_(at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha); // {"schema": "aten::subtract_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor subtract(const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha); // {"schema": "aten::subtract.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & subtract_(at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha); // {"schema": "aten::subtract_.Scalar(Tensor(a!) 
self, Scalar other, Scalar alpha=1) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor rsub(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha); // {"schema": "aten::rsub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & heaviside_out(const at::Tensor & self, const at::Tensor & values, at::Tensor & out); // {"schema": "aten::heaviside.out(Tensor self, Tensor values, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor heaviside(const at::Tensor & self, const at::Tensor & values); // {"schema": "aten::heaviside(Tensor self, Tensor values) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & heaviside_(at::Tensor & self, const at::Tensor & values); // {"schema": "aten::heaviside_(Tensor(a!) self, Tensor values) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor rsub(const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha); // {"schema": "aten::rsub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor _sparse_addmm(const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta, const at::Scalar & alpha); // {"schema": "aten::_sparse_addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & sparse_sampled_addmm_out(const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out); // {"schema": "aten::sparse_sampled_addmm.out(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor sparse_sampled_addmm(const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta, const at::Scalar & alpha); // {"schema": "aten::sparse_sampled_addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor", "dispatch": "True", "default": "False"} +::std::tuple _sparse_mm_reduce_impl(const at::Tensor & self, const at::Tensor & other, c10::string_view reduce); // {"schema": "aten::_sparse_mm_reduce_impl(Tensor self, Tensor other, str reduce) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"} +::std::tuple _sparse_mm_reduce_impl_backward(const at::Tensor & self, const at::Tensor & grad_out, const at::Tensor & weight, c10::string_view reduce, const at::Tensor & arg_out, ::std::array output_mask); // {"schema": "aten::_sparse_mm_reduce_impl_backward(Tensor self, Tensor grad_out, Tensor weight, str reduce, Tensor arg_out, bool[2] output_mask) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"} +at::Tensor & addmm_out(const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out); // {"schema": "aten::addmm.out(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor addmm(const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta, const at::Scalar & alpha); // {"schema": "aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor addmm(const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, at::ScalarType out_dtype, const at::Scalar & beta, const at::Scalar & alpha); // {"schema": "aten::addmm.dtype(Tensor self, Tensor mat1, Tensor mat2, ScalarType out_dtype, *, Scalar beta=1, Scalar alpha=1) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & addmm_out(const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, at::ScalarType out_dtype, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out); // {"schema": "aten::addmm.dtype_out(Tensor self, Tensor mat1, Tensor mat2, ScalarType out_dtype, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & addmm_(at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta, const at::Scalar & alpha); // {"schema": "aten::addmm_(Tensor(a!) self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _addmm_activation_out(const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta, const at::Scalar & alpha, bool use_gelu, at::Tensor & out); // {"schema": "aten::_addmm_activation.out(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, bool use_gelu=False, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor _addmm_activation(const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta, const at::Scalar & alpha, bool use_gelu); // {"schema": "aten::_addmm_activation(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, bool use_gelu=False) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor _scaled_mm(const at::Tensor & self, const at::Tensor & mat2, const at::Tensor & scale_a, const at::Tensor & scale_b, const ::std::optional & bias, const ::std::optional & scale_result, ::std::optional out_dtype, bool use_fast_accum); // {"schema": "aten::_scaled_mm(Tensor self, Tensor mat2, Tensor scale_a, Tensor scale_b, Tensor? bias=None, Tensor? scale_result=None, ScalarType? out_dtype=None, bool use_fast_accum=False) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & _scaled_mm_out(const at::Tensor & self, const at::Tensor & mat2, const at::Tensor & scale_a, const at::Tensor & scale_b, const ::std::optional & bias, const ::std::optional & scale_result, ::std::optional out_dtype, bool use_fast_accum, at::Tensor & out); // {"schema": "aten::_scaled_mm.out(Tensor self, Tensor mat2, Tensor scale_a, Tensor scale_b, Tensor? bias=None, Tensor? scale_result=None, ScalarType? out_dtype=None, bool use_fast_accum=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor _scaled_mm_v2(const at::Tensor & self, const at::Tensor & mat2, at::TensorList scale_a, at::IntArrayRef recipe_a, at::IntArrayRef swizzle_a, at::TensorList scale_b, at::IntArrayRef recipe_b, at::IntArrayRef swizzle_b, const ::std::optional & bias, ::std::optional out_dtype, at::IntArrayRef contraction_dim, bool use_fast_accum); // {"schema": "aten::_scaled_mm_v2(Tensor self, Tensor mat2, Tensor[] scale_a, int[] recipe_a, int[] swizzle_a, Tensor[] scale_b, int[] recipe_b, int[] swizzle_b, Tensor? bias, ScalarType? 
out_dtype, int[] contraction_dim=[], bool use_fast_accum=False) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & _scaled_mm_v2_out(const at::Tensor & self, const at::Tensor & mat2, at::TensorList scale_a, at::IntArrayRef recipe_a, at::IntArrayRef swizzle_a, at::TensorList scale_b, at::IntArrayRef recipe_b, at::IntArrayRef swizzle_b, const ::std::optional & bias, ::std::optional out_dtype, at::IntArrayRef contraction_dim, bool use_fast_accum, at::Tensor & out); // {"schema": "aten::_scaled_mm_v2.out(Tensor self, Tensor mat2, Tensor[] scale_a, int[] recipe_a, int[] swizzle_a, Tensor[] scale_b, int[] recipe_b, int[] swizzle_b, Tensor? bias, ScalarType? out_dtype, int[] contraction_dim=[], bool use_fast_accum=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor _scaled_grouped_mm(const at::Tensor & self, const at::Tensor & mat2, const at::Tensor & scale_a, const at::Tensor & scale_b, const ::std::optional & offs, const ::std::optional & bias, const ::std::optional & scale_result, ::std::optional out_dtype, bool use_fast_accum); // {"schema": "aten::_scaled_grouped_mm(Tensor self, Tensor mat2, Tensor scale_a, Tensor scale_b, Tensor? offs=None, Tensor? bias=None, Tensor? scale_result=None, ScalarType? out_dtype=None, bool use_fast_accum=False) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _scaled_grouped_mm_v2(const at::Tensor & self, const at::Tensor & mat2, at::TensorList scale_a, at::IntArrayRef recipe_a, at::IntArrayRef swizzle_a, at::TensorList scale_b, at::IntArrayRef recipe_b, at::IntArrayRef swizzle_b, const ::std::optional & offs, const ::std::optional & bias, ::std::optional out_dtype, at::IntArrayRef contraction_dim, bool use_fast_accum); // {"schema": "aten::_scaled_grouped_mm_v2(Tensor self, Tensor mat2, Tensor[] scale_a, int[] recipe_a, int[] swizzle_a, Tensor[] scale_b, int[] recipe_b, int[] swizzle_b, Tensor? offs=None, Tensor? bias=None, ScalarType? 
out_dtype=None, int[] contraction_dim=[], bool use_fast_accum=False) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _grouped_mm(const at::Tensor & self, const at::Tensor & mat2, const ::std::optional & offs, const ::std::optional & bias, ::std::optional out_dtype); // {"schema": "aten::_grouped_mm(Tensor self, Tensor mat2, Tensor? offs=None, Tensor? bias=None, ScalarType? out_dtype=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor _sparse_compressed_tensor_with_dims(int64_t nnz, int64_t dense_dim, at::IntArrayRef size, at::IntArrayRef blocksize, at::ScalarType index_dtype, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::_sparse_compressed_tensor_with_dims(int nnz, int dense_dim, int[] size, int[] blocksize, ScalarType index_dtype, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor sparse_compressed_tensor(const at::Tensor & compressed_indices, const at::Tensor & plain_indices, const at::Tensor & values, c10::SymIntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::sparse_compressed_tensor.comp_plain_value_size(Tensor compressed_indices, Tensor plain_indices, Tensor values, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor sparse_csr_tensor(const at::Tensor & crow_indices, const at::Tensor & col_indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::sparse_csr_tensor.crow_col_value_size(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor sparse_csc_tensor(const at::Tensor & ccol_indices, const at::Tensor & row_indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::sparse_csc_tensor.ccol_row_value_size(Tensor ccol_indices, Tensor row_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor sparse_bsr_tensor(const at::Tensor & crow_indices, const at::Tensor & col_indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::sparse_bsr_tensor.crow_col_value_size(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor sparse_bsc_tensor(const at::Tensor & ccol_indices, const at::Tensor & row_indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::sparse_bsc_tensor.ccol_row_value_size(Tensor ccol_indices, Tensor row_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor sparse_compressed_tensor(const at::Tensor & compressed_indices, const at::Tensor & plain_indices, const at::Tensor & values, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::sparse_compressed_tensor.comp_plain_value(Tensor compressed_indices, Tensor plain_indices, Tensor values, *, ScalarType? 
dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor sparse_csr_tensor(const at::Tensor & crow_indices, const at::Tensor & col_indices, const at::Tensor & values, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::sparse_csr_tensor.crow_col_value(Tensor crow_indices, Tensor col_indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor sparse_csc_tensor(const at::Tensor & ccol_indices, const at::Tensor & row_indices, const at::Tensor & values, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::sparse_csc_tensor.ccol_row_value(Tensor ccol_indices, Tensor row_indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor sparse_bsr_tensor(const at::Tensor & crow_indices, const at::Tensor & col_indices, const at::Tensor & values, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::sparse_bsr_tensor.crow_col_value(Tensor crow_indices, Tensor col_indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor sparse_bsc_tensor(const at::Tensor & ccol_indices, const at::Tensor & row_indices, const at::Tensor & values, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::sparse_bsc_tensor.ccol_row_value(Tensor ccol_indices, Tensor row_indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _sparse_compressed_tensor_unsafe(const at::Tensor & compressed_indices, const at::Tensor & plain_indices, const at::Tensor & values, c10::SymIntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::_sparse_compressed_tensor_unsafe(Tensor compressed_indices, Tensor plain_indices, Tensor values, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _sparse_csr_tensor_unsafe(const at::Tensor & crow_indices, const at::Tensor & col_indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::_sparse_csr_tensor_unsafe(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _sparse_csc_tensor_unsafe(const at::Tensor & ccol_indices, const at::Tensor & row_indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::_sparse_csc_tensor_unsafe(Tensor ccol_indices, Tensor row_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _sparse_bsr_tensor_unsafe(const at::Tensor & crow_indices, const at::Tensor & col_indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::_sparse_bsr_tensor_unsafe(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _sparse_bsc_tensor_unsafe(const at::Tensor & ccol_indices, const at::Tensor & row_indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::_sparse_bsc_tensor_unsafe(Tensor ccol_indices, Tensor row_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor sparse_coo_tensor(at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::sparse_coo_tensor.size(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor sparse_coo_tensor(const at::Tensor & indices, const at::Tensor & values, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional is_coalesced); // {"schema": "aten::sparse_coo_tensor.indices(Tensor indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool? 
is_coalesced=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor sparse_coo_tensor(const at::Tensor & indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional is_coalesced); // {"schema": "aten::sparse_coo_tensor.indices_size(Tensor indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool? is_coalesced=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _sparse_coo_tensor_unsafe(const at::Tensor & indices, const at::Tensor & values, c10::SymIntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional is_coalesced); // {"schema": "aten::_sparse_coo_tensor_unsafe(Tensor indices, Tensor values, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool? is_coalesced=None) -> Tensor", "dispatch": "False", "default": "True"} +void _validate_sparse_coo_tensor_args(const at::Tensor & indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional is_coalesced, ::std::optional check_pinning); // {"schema": "aten::_validate_sparse_coo_tensor_args(Tensor indices, Tensor values, int[] size, bool? is_coalesced=None, bool? check_pinning=None) -> ()", "dispatch": "False", "default": "True"} +void _validate_sparse_compressed_tensor_args(const at::Tensor & compressed_indices, const at::Tensor & plain_indices, const at::Tensor & values, at::IntArrayRef size, at::Layout layout, ::std::optional check_pinning); // {"schema": "aten::_validate_sparse_compressed_tensor_args(Tensor compressed_indices, Tensor plain_indices, Tensor values, int[] size, Layout layout, bool? 
check_pinning=None) -> ()", "dispatch": "False", "default": "True"} +void _validate_sparse_csr_tensor_args(const at::Tensor & crow_indices, const at::Tensor & col_indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional check_pinning); // {"schema": "aten::_validate_sparse_csr_tensor_args(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size, bool? check_pinning=None) -> ()", "dispatch": "False", "default": "True"} +void _validate_sparse_csc_tensor_args(const at::Tensor & ccol_indices, const at::Tensor & row_indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional check_pinning); // {"schema": "aten::_validate_sparse_csc_tensor_args(Tensor ccol_indices, Tensor row_indices, Tensor values, int[] size, bool? check_pinning=None) -> ()", "dispatch": "False", "default": "True"} +void _validate_sparse_bsr_tensor_args(const at::Tensor & crow_indices, const at::Tensor & col_indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional check_pinning); // {"schema": "aten::_validate_sparse_bsr_tensor_args(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size, bool? check_pinning=None) -> ()", "dispatch": "False", "default": "True"} +void _validate_sparse_bsc_tensor_args(const at::Tensor & ccol_indices, const at::Tensor & row_indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional check_pinning); // {"schema": "aten::_validate_sparse_bsc_tensor_args(Tensor ccol_indices, Tensor row_indices, Tensor values, int[] size, bool? check_pinning=None) -> ()", "dispatch": "False", "default": "True"} +at::Tensor _sparse_coo_tensor_with_dims(int64_t sparse_dim, int64_t dense_dim, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::_sparse_coo_tensor_with_dims(int sparse_dim, int dense_dim, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=False) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _sparse_coo_tensor_with_dims_and_tensors(int64_t sparse_dim, int64_t dense_dim, c10::SymIntArrayRef size, const at::Tensor & indices, const at::Tensor & values, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional is_coalesced); // {"schema": "aten::_sparse_coo_tensor_with_dims_and_tensors(int sparse_dim, int dense_dim, SymInt[] size, Tensor indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False, bool? is_coalesced=None) -> Tensor", "dispatch": "True", "default": "False"} +const at::Tensor & sparse_resize_(const at::Tensor & self, at::IntArrayRef size, int64_t sparse_dim, int64_t dense_dim); // {"schema": "aten::sparse_resize_(Tensor(a!) self, int[] size, int sparse_dim, int dense_dim) -> Tensor(a!)", "dispatch": "True", "default": "False"} +const at::Tensor & sparse_resize_and_clear_(const at::Tensor & self, at::IntArrayRef size, int64_t sparse_dim, int64_t dense_dim); // {"schema": "aten::sparse_resize_and_clear_(Tensor(a!) 
self, int[] size, int sparse_dim, int dense_dim) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor sparse_mask(const at::Tensor & self, const at::Tensor & mask); // {"schema": "aten::sparse_mask(Tensor self, Tensor mask) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _sparse_mask_projection(const at::Tensor & self, const at::Tensor & mask, bool accumulate_matches); // {"schema": "aten::_sparse_mask_projection(Tensor self, Tensor mask, bool accumulate_matches=False) -> Tensor", "dispatch": "True", "default": "False"} +::std::vector _to_cpu(at::TensorList tensors); // {"schema": "aten::_to_cpu(Tensor[] tensors) -> Tensor[]", "dispatch": "False", "default": "True"} +at::Tensor to_dense(const at::Tensor & self, ::std::optional dtype, ::std::optional masked_grad); // {"schema": "aten::to_dense(Tensor self, ScalarType? dtype=None, *, bool? masked_grad=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _to_dense(const at::Tensor & self, ::std::optional dtype, ::std::optional masked_grad); // {"schema": "aten::_to_dense(Tensor self, ScalarType? dtype=None, bool? masked_grad=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor to_dense_backward(const at::Tensor & grad, const at::Tensor & input, ::std::optional masked_grad); // {"schema": "aten::to_dense_backward(Tensor grad, Tensor input, bool? 
masked_grad=None) -> Tensor", "dispatch": "False", "default": "True"} +int64_t sparse_dim(const at::Tensor & self); // {"schema": "aten::sparse_dim(Tensor self) -> int", "dispatch": "True", "default": "True"} +int64_t _dimI(const at::Tensor & self); // {"schema": "aten::_dimI(Tensor self) -> int", "dispatch": "True", "default": "False"} +int64_t dense_dim(const at::Tensor & self); // {"schema": "aten::dense_dim(Tensor self) -> int", "dispatch": "True", "default": "True"} +int64_t _dimV(const at::Tensor & self); // {"schema": "aten::_dimV(Tensor self) -> int", "dispatch": "True", "default": "False"} +int64_t _nnz(const at::Tensor & self); // {"schema": "aten::_nnz(Tensor self) -> int", "dispatch": "True", "default": "False"} +at::Tensor coalesce(const at::Tensor & self); // {"schema": "aten::coalesce(Tensor(a) self) -> Tensor(a)", "dispatch": "False", "default": "True"} +at::Tensor _coalesce(const at::Tensor & self); // {"schema": "aten::_coalesce(Tensor self) -> Tensor", "dispatch": "True", "default": "False"} +bool is_coalesced(const at::Tensor & self); // {"schema": "aten::is_coalesced(Tensor self) -> bool", "dispatch": "True", "default": "True"} +at::Tensor _indices(const at::Tensor & self); // {"schema": "aten::_indices(Tensor(a) self) -> Tensor(a)", "dispatch": "True", "default": "False"} +at::Tensor _values(const at::Tensor & self); // {"schema": "aten::_values(Tensor(a) self) -> Tensor(a)", "dispatch": "True", "default": "False"} +at::Tensor & _coalesced_(at::Tensor & self, bool coalesced); // {"schema": "aten::_coalesced_(Tensor(a!) 
self, bool coalesced) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor indices(const at::Tensor & self); // {"schema": "aten::indices(Tensor(a) self) -> Tensor(a)", "dispatch": "True", "default": "True"} +at::Tensor values(const at::Tensor & self); // {"schema": "aten::values(Tensor(a) self) -> Tensor(a)", "dispatch": "True", "default": "True"} +at::Tensor crow_indices(const at::Tensor & self); // {"schema": "aten::crow_indices(Tensor(a) self) -> Tensor(a)", "dispatch": "True", "default": "True"} +at::Tensor col_indices(const at::Tensor & self); // {"schema": "aten::col_indices(Tensor(a) self) -> Tensor(a)", "dispatch": "True", "default": "True"} +at::Tensor ccol_indices(const at::Tensor & self); // {"schema": "aten::ccol_indices(Tensor(a) self) -> Tensor(a)", "dispatch": "True", "default": "True"} +at::Tensor row_indices(const at::Tensor & self); // {"schema": "aten::row_indices(Tensor(a) self) -> Tensor(a)", "dispatch": "True", "default": "True"} +at::Tensor & hspmm_out(const at::Tensor & mat1, const at::Tensor & mat2, at::Tensor & out); // {"schema": "aten::hspmm.out(Tensor mat1, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor hspmm(const at::Tensor & mat1, const at::Tensor & mat2); // {"schema": "aten::hspmm(Tensor mat1, Tensor mat2) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & copy_sparse_to_sparse_(at::Tensor & self, const at::Tensor & src, bool non_blocking); // {"schema": "aten::copy_sparse_to_sparse_(Tensor(a!) 
self, Tensor src, bool non_blocking=False) -> Tensor(a!)", "dispatch": "True", "default": "False"} +::std::vector unbind(const at::Tensor & self, int64_t dim); // {"schema": "aten::unbind.int(Tensor(a -> *) self, int dim=0) -> Tensor(a)[]", "dispatch": "True", "default": "True"} +::std::vector unbind(const at::Tensor & self, at::Dimname dim); // {"schema": "aten::unbind.Dimname(Tensor(a -> *) self, Dimname dim) -> Tensor(a)[]", "dispatch": "False", "default": "True"} +at::Tensor to_sparse(const at::Tensor & self, int64_t sparse_dim); // {"schema": "aten::to_sparse.sparse_dim(Tensor self, int sparse_dim) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _to_sparse(const at::Tensor & self, int64_t sparse_dim); // {"schema": "aten::_to_sparse.sparse_dim(Tensor self, int sparse_dim) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor to_sparse(const at::Tensor & self, ::std::optional layout, at::OptionalIntArrayRef blocksize, ::std::optional dense_dim); // {"schema": "aten::to_sparse(Tensor self, *, Layout? layout=None, int[2]? blocksize=None, int? dense_dim=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _to_sparse(const at::Tensor & self, ::std::optional layout, at::OptionalIntArrayRef blocksize, ::std::optional dense_dim); // {"schema": "aten::_to_sparse(Tensor self, *, Layout? layout=None, int[2]? blocksize=None, int? dense_dim=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor to_sparse_csr(const at::Tensor & self, ::std::optional dense_dim); // {"schema": "aten::to_sparse_csr(Tensor self, int? dense_dim=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _to_sparse_csr(const at::Tensor & self, ::std::optional dense_dim); // {"schema": "aten::_to_sparse_csr(Tensor self, int? dense_dim=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor to_sparse_csc(const at::Tensor & self, ::std::optional dense_dim); // {"schema": "aten::to_sparse_csc(Tensor self, int? 
dense_dim=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _to_sparse_csc(const at::Tensor & self, ::std::optional dense_dim); // {"schema": "aten::_to_sparse_csc(Tensor self, int? dense_dim=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor to_sparse_bsr(const at::Tensor & self, at::IntArrayRef blocksize, ::std::optional dense_dim); // {"schema": "aten::to_sparse_bsr(Tensor self, int[2] blocksize, int? dense_dim=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _to_sparse_bsr(const at::Tensor & self, at::IntArrayRef blocksize, ::std::optional dense_dim); // {"schema": "aten::_to_sparse_bsr(Tensor self, int[2] blocksize, int? dense_dim=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor to_sparse_bsc(const at::Tensor & self, at::IntArrayRef blocksize, ::std::optional dense_dim); // {"schema": "aten::to_sparse_bsc(Tensor self, int[2] blocksize, int? dense_dim=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _to_sparse_bsc(const at::Tensor & self, at::IntArrayRef blocksize, ::std::optional dense_dim); // {"schema": "aten::_to_sparse_bsc(Tensor self, int[2] blocksize, int? dense_dim=None) -> Tensor", "dispatch": "True", "default": "False"} +::std::tuple _to_sparse_semi_structured(const at::Tensor & dense); // {"schema": "aten::_to_sparse_semi_structured(Tensor dense) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"} +at::Tensor to_mkldnn(const at::Tensor & self, ::std::optional dtype); // {"schema": "aten::to_mkldnn(Tensor self, ScalarType? 
dtype=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor mkldnn_reorder_conv2d_weight(const at::Tensor & self, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, at::OptionalSymIntArrayRef input_size); // {"schema": "aten::mkldnn_reorder_conv2d_weight(Tensor self, SymInt[2] padding=0, SymInt[2] stride=1, SymInt[2] dilation=1, SymInt groups=1, SymInt[]? input_size=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor mkldnn_reorder_conv3d_weight(const at::Tensor & self, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, at::OptionalSymIntArrayRef input_size); // {"schema": "aten::mkldnn_reorder_conv3d_weight(Tensor self, SymInt[3] padding=0, SymInt[3] stride=1, SymInt[3] dilation=1, SymInt groups=1, SymInt[]? input_size=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor to_mkldnn_backward(const at::Tensor & grad, const at::Tensor & input); // {"schema": "aten::to_mkldnn_backward(Tensor grad, Tensor input) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor quantize_per_tensor_dynamic(const at::Tensor & self, at::ScalarType dtype, bool reduce_range); // {"schema": "aten::quantize_per_tensor_dynamic(Tensor self, ScalarType dtype, bool reduce_range) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor quantize_per_tensor(const at::Tensor & self, double scale, int64_t zero_point, at::ScalarType dtype); // {"schema": "aten::quantize_per_tensor(Tensor self, float scale, int zero_point, ScalarType dtype) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor quantize_per_tensor(const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, at::ScalarType dtype); // {"schema": "aten::quantize_per_tensor.tensor_qparams(Tensor self, Tensor scale, Tensor zero_point, ScalarType dtype) -> Tensor", "dispatch": "True", "default": "False"} +::std::vector 
quantize_per_tensor(at::TensorList tensors, const at::Tensor & scales, const at::Tensor & zero_points, at::ScalarType dtype); // {"schema": "aten::quantize_per_tensor.tensors(Tensor[] tensors, Tensor scales, Tensor zero_points, ScalarType dtype) -> Tensor[]", "dispatch": "True", "default": "False"} +at::Tensor quantize_per_channel(const at::Tensor & self, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, at::ScalarType dtype); // {"schema": "aten::quantize_per_channel(Tensor self, Tensor scales, Tensor zero_points, int axis, ScalarType dtype) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor dequantize(const at::Tensor & self); // {"schema": "aten::dequantize.self(Tensor self) -> Tensor", "dispatch": "True", "default": "False"} +::std::vector dequantize(at::TensorList tensors); // {"schema": "aten::dequantize.tensors(Tensor[] tensors) -> Tensor[]", "dispatch": "True", "default": "False"} +double q_scale(const at::Tensor & self); // {"schema": "aten::q_scale(Tensor self) -> float", "dispatch": "True", "default": "False"} +int64_t q_zero_point(const at::Tensor & self); // {"schema": "aten::q_zero_point(Tensor self) -> int", "dispatch": "True", "default": "False"} +at::Tensor q_per_channel_scales(const at::Tensor & self); // {"schema": "aten::q_per_channel_scales(Tensor self) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor q_per_channel_zero_points(const at::Tensor & self); // {"schema": "aten::q_per_channel_zero_points(Tensor self) -> Tensor", "dispatch": "True", "default": "False"} +int64_t q_per_channel_axis(const at::Tensor & self); // {"schema": "aten::q_per_channel_axis(Tensor self) -> int", "dispatch": "True", "default": "False"} +at::Tensor int_repr(const at::Tensor & self); // {"schema": "aten::int_repr(Tensor self) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _make_per_tensor_quantized_tensor(const at::Tensor & self, double scale, int64_t zero_point); // {"schema": 
"aten::_make_per_tensor_quantized_tensor(Tensor self, float scale, int zero_point) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _make_per_channel_quantized_tensor(const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, int64_t axis); // {"schema": "aten::_make_per_channel_quantized_tensor(Tensor self, Tensor scale, Tensor zero_point, int axis) -> Tensor", "dispatch": "True", "default": "False"} +at::QScheme qscheme(const at::Tensor & self); // {"schema": "aten::qscheme(Tensor self) -> QScheme", "dispatch": "True", "default": "False"} +at::Tensor fake_quantize_per_tensor_affine(const at::Tensor & self, double scale, int64_t zero_point, int64_t quant_min, int64_t quant_max); // {"schema": "aten::fake_quantize_per_tensor_affine(Tensor self, float scale, int zero_point, int quant_min, int quant_max) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor fake_quantize_per_tensor_affine(const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, int64_t quant_min, int64_t quant_max); // {"schema": "aten::fake_quantize_per_tensor_affine.tensor_qparams(Tensor self, Tensor scale, Tensor zero_point, int quant_min, int quant_max) -> Tensor", "dispatch": "False", "default": "True"} +::std::tuple fake_quantize_per_tensor_affine_cachemask(const at::Tensor & self, double scale, int64_t zero_point, int64_t quant_min, int64_t quant_max); // {"schema": "aten::fake_quantize_per_tensor_affine_cachemask(Tensor self, float scale, int zero_point, int quant_min, int quant_max) -> (Tensor output, Tensor mask)", "dispatch": "True", "default": "False"} +::std::tuple _fake_quantize_per_tensor_affine_cachemask_tensor_qparams(const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, const at::Tensor & fake_quant_enabled, int64_t quant_min, int64_t quant_max); // {"schema": "aten::_fake_quantize_per_tensor_affine_cachemask_tensor_qparams(Tensor self, Tensor scale, Tensor zero_point, Tensor 
fake_quant_enabled, int quant_min, int quant_max) -> (Tensor output, Tensor mask)", "dispatch": "True", "default": "False"} +at::Tensor fake_quantize_per_tensor_affine_cachemask_backward(const at::Tensor & grad, const at::Tensor & mask); // {"schema": "aten::fake_quantize_per_tensor_affine_cachemask_backward(Tensor grad, Tensor mask) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _fake_quantize_learnable_per_tensor_affine(const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, int64_t quant_min, int64_t quant_max, double grad_factor); // {"schema": "aten::_fake_quantize_learnable_per_tensor_affine(Tensor self, Tensor scale, Tensor zero_point, int quant_min, int quant_max, float grad_factor=1.0) -> Tensor", "dispatch": "True", "default": "False"} +::std::tuple _fake_quantize_learnable_per_tensor_affine_backward(const at::Tensor & grad, const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, int64_t quant_min, int64_t quant_max, double grad_factor); // {"schema": "aten::_fake_quantize_learnable_per_tensor_affine_backward(Tensor grad, Tensor self, Tensor scale, Tensor zero_point, int quant_min, int quant_max, float grad_factor=1.0) -> (Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"} +at::Tensor fake_quantize_per_channel_affine(const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, int64_t axis, int64_t quant_min, int64_t quant_max); // {"schema": "aten::fake_quantize_per_channel_affine(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max) -> Tensor", "dispatch": "False", "default": "True"} +::std::tuple fake_quantize_per_channel_affine_cachemask(const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, int64_t axis, int64_t quant_min, int64_t quant_max); // {"schema": "aten::fake_quantize_per_channel_affine_cachemask(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max) 
-> (Tensor output, Tensor mask)", "dispatch": "True", "default": "False"} +at::Tensor fake_quantize_per_channel_affine_cachemask_backward(const at::Tensor & grad, const at::Tensor & mask); // {"schema": "aten::fake_quantize_per_channel_affine_cachemask_backward(Tensor grad, Tensor mask) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _fake_quantize_learnable_per_channel_affine(const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, int64_t axis, int64_t quant_min, int64_t quant_max, double grad_factor); // {"schema": "aten::_fake_quantize_learnable_per_channel_affine(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max, float grad_factor=1.0) -> Tensor", "dispatch": "True", "default": "False"} +::std::tuple _fake_quantize_learnable_per_channel_affine_backward(const at::Tensor & grad, const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, int64_t axis, int64_t quant_min, int64_t quant_max, double grad_factor); // {"schema": "aten::_fake_quantize_learnable_per_channel_affine_backward(Tensor grad, Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max, float grad_factor=1.0) -> (Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"} +at::Tensor fused_moving_avg_obs_fake_quant(const at::Tensor & self, const at::Tensor & observer_on, const at::Tensor & fake_quant_on, at::Tensor & running_min, at::Tensor & running_max, at::Tensor & scale, at::Tensor & zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, bool per_row_fake_quant, bool symmetric_quant); // {"schema": "aten::fused_moving_avg_obs_fake_quant(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor(a!) running_min, Tensor(b!) running_max, Tensor(c!) scale, Tensor(d!) 
zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> Tensor", "dispatch": "False", "default": "True"} +::std::tuple _fused_moving_avg_obs_fq_helper(const at::Tensor & self, const at::Tensor & observer_on, const at::Tensor & fake_quant_on, at::Tensor & running_min, at::Tensor & running_max, at::Tensor & scale, at::Tensor & zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, bool per_row_fake_quant, bool symmetric_quant); // {"schema": "aten::_fused_moving_avg_obs_fq_helper(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor(a!) running_min, Tensor(b!) running_max, Tensor(c!) scale, Tensor(d!) zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor output, Tensor mask)", "dispatch": "True", "default": "False"} +::std::tuple _choose_qparams_per_tensor(const at::Tensor & self, bool reduce_range); // {"schema": "aten::_choose_qparams_per_tensor(Tensor self, bool reduce_range=False) -> (float, int)", "dispatch": "False", "default": "True"} +at::Tensor _saturate_weight_to_fp16(const at::Tensor & weight); // {"schema": "aten::_saturate_weight_to_fp16(Tensor weight) -> Tensor", "dispatch": "False", "default": "True"} +::std::tuple choose_qparams_optimized(const at::Tensor & input, int64_t numel, int64_t n_bins, double ratio, int64_t bit_width); // {"schema": "aten::choose_qparams_optimized(Tensor input, int numel, int n_bins, float ratio, int bit_width) -> (Tensor, Tensor)", "dispatch": "False", "default": "True"} +at::Tensor _autocast_to_reduced_precision(const at::Tensor & self, bool cuda_enabled, bool cpu_enabled, at::ScalarType cuda_dtype, at::ScalarType cpu_dtype); // {"schema": "aten::_autocast_to_reduced_precision(Tensor(a) self, bool cuda_enabled, bool cpu_enabled, ScalarType cuda_dtype, ScalarType cpu_dtype) -> Tensor(a)", "dispatch": 
"False", "default": "True"} +at::Tensor _autocast_to_full_precision(const at::Tensor & self, bool cuda_enabled, bool cpu_enabled); // {"schema": "aten::_autocast_to_full_precision(Tensor(a) self, bool cuda_enabled, bool cpu_enabled) -> Tensor(a)", "dispatch": "False", "default": "True"} +at::Tensor _to_copy(const at::Tensor & self, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, bool non_blocking, ::std::optional memory_format); // {"schema": "aten::_to_copy(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, MemoryFormat? memory_format=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor to(const at::Tensor & self, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, bool non_blocking, bool copy, ::std::optional memory_format); // {"schema": "aten::to.dtype_layout(Tensor(a) self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a)", "dispatch": "False", "default": "True"} +at::Tensor to(const at::Tensor & self, at::Device device, at::ScalarType dtype, bool non_blocking, bool copy, ::std::optional memory_format); // {"schema": "aten::to.device(Tensor(a) self, Device device, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a)", "dispatch": "False", "default": "True"} +at::Tensor to(const at::Tensor & self, at::ScalarType dtype, bool non_blocking, bool copy, ::std::optional memory_format); // {"schema": "aten::to.dtype(Tensor(a) self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? 
memory_format=None) -> Tensor(a)", "dispatch": "False", "default": "True"} +at::Tensor to(const at::Tensor & self, const at::Tensor & other, bool non_blocking, bool copy, ::std::optional memory_format); // {"schema": "aten::to.other(Tensor(a) self, Tensor other, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a)", "dispatch": "False", "default": "True"} +::std::vector meshgrid(at::TensorList tensors); // {"schema": "aten::meshgrid(Tensor[] tensors) -> Tensor[]", "dispatch": "False", "default": "True"} +::std::vector meshgrid(at::TensorList tensors, c10::string_view indexing); // {"schema": "aten::meshgrid.indexing(Tensor[] tensors, *, str indexing) -> Tensor[]", "dispatch": "False", "default": "True"} +at::Tensor cartesian_prod(at::TensorList tensors); // {"schema": "aten::cartesian_prod(Tensor[] tensors) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor combinations(const at::Tensor & self, int64_t r, bool with_replacement); // {"schema": "aten::combinations(Tensor self, int r=2, bool with_replacement=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Scalar item(const at::Tensor & self); // {"schema": "aten::item(Tensor self) -> Scalar", "dispatch": "False", "default": "True"} +at::ScalarType result_type(const at::Tensor & tensor, const at::Tensor & other); // {"schema": "aten::result_type.Tensor(Tensor tensor, Tensor other) -> ScalarType", "dispatch": "False", "default": "True"} +at::ScalarType result_type(const at::Tensor & tensor, const at::Scalar & other); // {"schema": "aten::result_type.Scalar(Tensor tensor, Scalar other) -> ScalarType", "dispatch": "False", "default": "True"} +at::ScalarType result_type(const at::Scalar & scalar, const at::Tensor & tensor); // {"schema": "aten::result_type.Scalar_Tensor(Scalar scalar, Tensor tensor) -> ScalarType", "dispatch": "False", "default": "True"} +at::ScalarType result_type(const at::Scalar & scalar1, const at::Scalar & scalar2); // {"schema": 
"aten::result_type.Scalar_Scalar(Scalar scalar1, Scalar scalar2) -> ScalarType", "dispatch": "False", "default": "True"} +bool can_cast(at::ScalarType from_, at::ScalarType to); // {"schema": "aten::can_cast(ScalarType from_, ScalarType to) -> bool", "dispatch": "False", "default": "True"} +at::ScalarType promote_types(at::ScalarType type1, at::ScalarType type2); // {"schema": "aten::promote_types(ScalarType type1, ScalarType type2) -> ScalarType", "dispatch": "False", "default": "True"} +at::Scalar _local_scalar_dense(const at::Tensor & self); // {"schema": "aten::_local_scalar_dense(Tensor self) -> Scalar", "dispatch": "True", "default": "False"} +::std::tuple _lstm_mps(const at::Tensor & input, at::TensorList hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional, bool batch_first); // {"schema": "aten::_lstm_mps(Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"} +::std::tuple,::std::vector> lstm_mps_backward(const ::std::optional & grad_y, const ::std::optional & grad_hy, const ::std::optional & grad_cy, const at::Tensor & z_state, const at::Tensor & cell_state_fwd, const at::Tensor & input, const at::Tensor & layersOutputs, at::TensorList hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional, bool batch_first); // {"schema": "aten::lstm_mps_backward(Tensor? grad_y, Tensor? grad_hy, Tensor? 
grad_cy, Tensor z_state, Tensor cell_state_fwd, Tensor input, Tensor layersOutputs, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor[], Tensor[])", "dispatch": "True", "default": "False"} +::std::tuple _thnn_fused_lstm_cell(const at::Tensor & input_gates, const at::Tensor & hidden_gates, const at::Tensor & cx, const ::std::optional & input_bias, const ::std::optional & hidden_bias); // {"schema": "aten::_thnn_fused_lstm_cell(Tensor input_gates, Tensor hidden_gates, Tensor cx, Tensor? input_bias=None, Tensor? hidden_bias=None) -> (Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"} +::std::tuple _thnn_fused_lstm_cell_backward_impl(const ::std::optional & grad_hy, const ::std::optional & grad_cy, const at::Tensor & cx, const at::Tensor & cy, const at::Tensor & workspace, bool has_bias); // {"schema": "aten::_thnn_fused_lstm_cell_backward_impl(Tensor? grad_hy, Tensor? grad_cy, Tensor cx, Tensor cy, Tensor workspace, bool has_bias) -> (Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"} +::std::tuple _thnn_fused_lstm_cell_backward(const ::std::optional & grad_hy, const ::std::optional & grad_cy, const at::Tensor & cx, const at::Tensor & cy, const at::Tensor & workspace, bool has_bias); // {"schema": "aten::_thnn_fused_lstm_cell_backward(Tensor? grad_hy, Tensor? grad_cy, Tensor cx, Tensor cy, Tensor workspace, bool has_bias) -> (Tensor, Tensor, Tensor, Tensor, Tensor)", "dispatch": "False", "default": "True"} +::std::tuple _thnn_differentiable_lstm_cell_backward(const ::std::optional & grad_hy, const ::std::optional & grad_cy, const at::Tensor & input_gates, const at::Tensor & hidden_gates, const ::std::optional & input_bias, const ::std::optional & hidden_bias, const at::Tensor & cx, const at::Tensor & cy); // {"schema": "aten::_thnn_differentiable_lstm_cell_backward(Tensor? grad_hy, Tensor? 
grad_cy, Tensor input_gates, Tensor hidden_gates, Tensor? input_bias, Tensor? hidden_bias, Tensor cx, Tensor cy) -> (Tensor, Tensor, Tensor, Tensor, Tensor)", "dispatch": "False", "default": "True"} +::std::tuple _thnn_fused_gru_cell(const at::Tensor & input_gates, const at::Tensor & hidden_gates, const at::Tensor & hx, const ::std::optional & input_bias, const ::std::optional & hidden_bias); // {"schema": "aten::_thnn_fused_gru_cell(Tensor input_gates, Tensor hidden_gates, Tensor hx, Tensor? input_bias=None, Tensor? hidden_bias=None) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"} +::std::tuple _thnn_fused_gru_cell_backward(const at::Tensor & grad_hy, const at::Tensor & workspace, bool has_bias); // {"schema": "aten::_thnn_fused_gru_cell_backward(Tensor grad_hy, Tensor workspace, bool has_bias) -> (Tensor, Tensor, Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"} +::std::tuple _thnn_differentiable_gru_cell_backward(const at::Tensor & grad_hy, const at::Tensor & input_gates, const at::Tensor & hidden_gates, const at::Tensor & hx, const ::std::optional & input_bias, const ::std::optional & hidden_bias); // {"schema": "aten::_thnn_differentiable_gru_cell_backward(Tensor grad_hy, Tensor input_gates, Tensor hidden_gates, Tensor hx, Tensor? input_bias, Tensor? 
hidden_bias) -> (Tensor, Tensor, Tensor, Tensor, Tensor)", "dispatch": "False", "default": "True"} +::std::tuple lstm(const at::Tensor & input, at::TensorList hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional, bool batch_first); // {"schema": "aten::lstm.input(Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor, Tensor)", "dispatch": "False", "default": "True"} +::std::tuple lstm(const at::Tensor & data, const at::Tensor & batch_sizes, at::TensorList hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional); // {"schema": "aten::lstm.data(Tensor data, Tensor batch_sizes, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional) -> (Tensor, Tensor, Tensor)", "dispatch": "False", "default": "True"} +::std::tuple gru(const at::Tensor & input, const at::Tensor & hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional, bool batch_first); // {"schema": "aten::gru.input(Tensor input, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor)", "dispatch": "False", "default": "True"} +::std::tuple gru(const at::Tensor & data, const at::Tensor & batch_sizes, const at::Tensor & hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional); // {"schema": "aten::gru.data(Tensor data, Tensor batch_sizes, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional) -> (Tensor, Tensor)", "dispatch": "False", "default": "True"} +::std::tuple rnn_tanh(const at::Tensor & input, const at::Tensor & hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, 
bool train, bool bidirectional, bool batch_first); // {"schema": "aten::rnn_tanh.input(Tensor input, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor)", "dispatch": "False", "default": "True"} +::std::tuple rnn_tanh(const at::Tensor & data, const at::Tensor & batch_sizes, const at::Tensor & hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional); // {"schema": "aten::rnn_tanh.data(Tensor data, Tensor batch_sizes, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional) -> (Tensor, Tensor)", "dispatch": "False", "default": "True"} +::std::tuple rnn_relu(const at::Tensor & input, const at::Tensor & hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional, bool batch_first); // {"schema": "aten::rnn_relu.input(Tensor input, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor)", "dispatch": "False", "default": "True"} +::std::tuple rnn_relu(const at::Tensor & data, const at::Tensor & batch_sizes, const at::Tensor & hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional); // {"schema": "aten::rnn_relu.data(Tensor data, Tensor batch_sizes, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional) -> (Tensor, Tensor)", "dispatch": "False", "default": "True"} +::std::tuple lstm_cell(const at::Tensor & input, at::TensorList hx, const at::Tensor & w_ih, const at::Tensor & w_hh, const ::std::optional & b_ih, const ::std::optional & b_hh); // {"schema": "aten::lstm_cell(Tensor input, Tensor[] hx, Tensor w_ih, Tensor w_hh, Tensor? b_ih=None, Tensor? 
b_hh=None) -> (Tensor, Tensor)", "dispatch": "False", "default": "True"} +at::Tensor gru_cell(const at::Tensor & input, const at::Tensor & hx, const at::Tensor & w_ih, const at::Tensor & w_hh, const ::std::optional & b_ih, const ::std::optional & b_hh); // {"schema": "aten::gru_cell(Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor? b_ih=None, Tensor? b_hh=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor rnn_tanh_cell(const at::Tensor & input, const at::Tensor & hx, const at::Tensor & w_ih, const at::Tensor & w_hh, const ::std::optional & b_ih, const ::std::optional & b_hh); // {"schema": "aten::rnn_tanh_cell(Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor? b_ih=None, Tensor? b_hh=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor rnn_relu_cell(const at::Tensor & input, const at::Tensor & hx, const at::Tensor & w_ih, const at::Tensor & w_hh, const ::std::optional & b_ih, const ::std::optional & b_hh); // {"schema": "aten::rnn_relu_cell(Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor? b_ih=None, Tensor? 
b_hh=None) -> Tensor", "dispatch": "False", "default": "True"} +::std::tuple quantized_lstm_cell(const at::Tensor & input, at::TensorList hx, const at::Tensor & w_ih, const at::Tensor & w_hh, const at::Tensor & b_ih, const at::Tensor & b_hh, const at::Tensor & packed_ih, const at::Tensor & packed_hh, const at::Tensor & col_offsets_ih, const at::Tensor & col_offsets_hh, const at::Scalar & scale_ih, const at::Scalar & scale_hh, const at::Scalar & zero_point_ih, const at::Scalar & zero_point_hh); // {"schema": "aten::quantized_lstm_cell(Tensor input, Tensor[] hx, Tensor w_ih, Tensor w_hh, Tensor b_ih, Tensor b_hh, Tensor packed_ih, Tensor packed_hh, Tensor col_offsets_ih, Tensor col_offsets_hh, Scalar scale_ih, Scalar scale_hh, Scalar zero_point_ih, Scalar zero_point_hh) -> (Tensor, Tensor)", "dispatch": "False", "default": "True"} +at::Tensor quantized_gru_cell(const at::Tensor & input, const at::Tensor & hx, const at::Tensor & w_ih, const at::Tensor & w_hh, const at::Tensor & b_ih, const at::Tensor & b_hh, const at::Tensor & packed_ih, const at::Tensor & packed_hh, const at::Tensor & col_offsets_ih, const at::Tensor & col_offsets_hh, const at::Scalar & scale_ih, const at::Scalar & scale_hh, const at::Scalar & zero_point_ih, const at::Scalar & zero_point_hh); // {"schema": "aten::quantized_gru_cell(Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor b_ih, Tensor b_hh, Tensor packed_ih, Tensor packed_hh, Tensor col_offsets_ih, Tensor col_offsets_hh, Scalar scale_ih, Scalar scale_hh, Scalar zero_point_ih, Scalar zero_point_hh) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor quantized_rnn_relu_cell(const at::Tensor & input, const at::Tensor & hx, const at::Tensor & w_ih, const at::Tensor & w_hh, const at::Tensor & b_ih, const at::Tensor & b_hh, const at::Tensor & packed_ih, const at::Tensor & packed_hh, const at::Tensor & col_offsets_ih, const at::Tensor & col_offsets_hh, const at::Scalar & scale_ih, const at::Scalar & scale_hh, const at::Scalar & 
zero_point_ih, const at::Scalar & zero_point_hh); // {"schema": "aten::quantized_rnn_relu_cell(Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor b_ih, Tensor b_hh, Tensor packed_ih, Tensor packed_hh, Tensor col_offsets_ih, Tensor col_offsets_hh, Scalar scale_ih, Scalar scale_hh, Scalar zero_point_ih, Scalar zero_point_hh) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor quantized_rnn_tanh_cell(const at::Tensor & input, const at::Tensor & hx, const at::Tensor & w_ih, const at::Tensor & w_hh, const at::Tensor & b_ih, const at::Tensor & b_hh, const at::Tensor & packed_ih, const at::Tensor & packed_hh, const at::Tensor & col_offsets_ih, const at::Tensor & col_offsets_hh, const at::Scalar & scale_ih, const at::Scalar & scale_hh, const at::Scalar & zero_point_ih, const at::Scalar & zero_point_hh); // {"schema": "aten::quantized_rnn_tanh_cell(Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor b_ih, Tensor b_hh, Tensor packed_ih, Tensor packed_hh, Tensor col_offsets_ih, Tensor col_offsets_hh, Scalar scale_ih, Scalar scale_hh, Scalar zero_point_ih, Scalar zero_point_hh) -> Tensor", "dispatch": "False", "default": "True"} +::std::tuple _pack_padded_sequence(const at::Tensor & input, const at::Tensor & lengths, bool batch_first); // {"schema": "aten::_pack_padded_sequence(Tensor input, Tensor lengths, bool batch_first) -> (Tensor, Tensor)", "dispatch": "True", "default": "True"} +at::Tensor _pack_padded_sequence_backward(const at::Tensor & grad, c10::SymIntArrayRef input_size, const at::Tensor & batch_sizes, bool batch_first); // {"schema": "aten::_pack_padded_sequence_backward(Tensor grad, SymInt[] input_size, Tensor batch_sizes, bool batch_first) -> Tensor", "dispatch": "False", "default": "True"} +::std::tuple _pad_packed_sequence(const at::Tensor & data, const at::Tensor & batch_sizes, bool batch_first, const at::Scalar & padding_value, int64_t total_length); // {"schema": "aten::_pad_packed_sequence(Tensor data, Tensor batch_sizes, bool 
batch_first, Scalar padding_value, int total_length) -> (Tensor, Tensor)", "dispatch": "False", "default": "True"} +at::Tensor & set_(at::Tensor & self, at::Storage source); // {"schema": "aten::set_.source_Storage(Tensor(a!) self, Storage source) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & set_(at::Tensor & self, at::Storage source, c10::SymInt storage_offset, c10::SymIntArrayRef size, c10::SymIntArrayRef stride); // {"schema": "aten::set_.source_Storage_storage_offset(Tensor(a!) self, Storage source, SymInt storage_offset, SymInt[] size, SymInt[] stride=[]) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & set_(at::Tensor & self, const at::Tensor & source, c10::SymInt storage_offset, c10::SymIntArrayRef size, c10::SymIntArrayRef stride); // {"schema": "aten::set_.source_Tensor_storage_offset(Tensor(a!) self, Tensor source, SymInt storage_offset, SymInt[] size, SymInt[] stride=[]) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & set_(at::Tensor & self, const at::Tensor & source); // {"schema": "aten::set_.source_Tensor(Tensor(a!) self, Tensor source) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & set_(at::Tensor & self); // {"schema": "aten::set_(Tensor(a!) 
self) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor lift(const at::Tensor & self); // {"schema": "aten::lift(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor lift_fresh(const at::Tensor & self); // {"schema": "aten::lift_fresh(Tensor(a) self) -> Tensor(a)", "dispatch": "True", "default": "True"} +at::Tensor lift_fresh_copy(const at::Tensor & self); // {"schema": "aten::lift_fresh_copy(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +bool is_set_to(const at::Tensor & self, const at::Tensor & tensor); // {"schema": "aten::is_set_to(Tensor self, Tensor tensor) -> bool", "dispatch": "True", "default": "False"} +at::Tensor & masked_fill_(at::Tensor & self, const at::Tensor & mask, const at::Scalar & value); // {"schema": "aten::masked_fill_.Scalar(Tensor(a!) self, Tensor mask, Scalar value) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor masked_fill(const at::Tensor & self, const at::Tensor & mask, const at::Scalar & value); // {"schema": "aten::masked_fill.Scalar(Tensor self, Tensor mask, Scalar value) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & masked_fill_(at::Tensor & self, const at::Tensor & mask, const at::Tensor & value); // {"schema": "aten::masked_fill_.Tensor(Tensor(a!) self, Tensor mask, Tensor value) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor masked_fill(const at::Tensor & self, const at::Tensor & mask, const at::Tensor & value); // {"schema": "aten::masked_fill.Tensor(Tensor self, Tensor mask, Tensor value) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & masked_scatter_(at::Tensor & self, const at::Tensor & mask, const at::Tensor & source); // {"schema": "aten::masked_scatter_(Tensor(a!) 
self, Tensor mask, Tensor source) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor masked_scatter(const at::Tensor & self, const at::Tensor & mask, const at::Tensor & source); // {"schema": "aten::masked_scatter(Tensor self, Tensor mask, Tensor source) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor masked_scatter_backward(const at::Tensor & grad_output, const at::Tensor & mask, c10::SymIntArrayRef sizes); // {"schema": "aten::masked_scatter_backward(Tensor grad_output, Tensor mask, SymInt[] sizes) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor _masked_softmax(const at::Tensor & self, const at::Tensor & mask, ::std::optional dim, ::std::optional mask_type); // {"schema": "aten::_masked_softmax(Tensor self, Tensor mask, int? dim=None, int? mask_type=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _masked_softmax_backward(const at::Tensor & grad_output, const at::Tensor & output, const at::Tensor & mask, ::std::optional dim); // {"schema": "aten::_masked_softmax_backward(Tensor grad_output, Tensor output, Tensor mask, int? dim=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor view(const at::Tensor & self, c10::SymIntArrayRef size); // {"schema": "aten::view(Tensor(a) self, SymInt[] size) -> Tensor(a)", "dispatch": "True", "default": "False"} +at::Tensor view(const at::Tensor & self, at::ScalarType dtype); // {"schema": "aten::view.dtype(Tensor(a) self, ScalarType dtype) -> Tensor(a)", "dispatch": "True", "default": "True"} +at::Tensor & put_(at::Tensor & self, const at::Tensor & index, const at::Tensor & source, bool accumulate); // {"schema": "aten::put_(Tensor(a!) 
self, Tensor index, Tensor source, bool accumulate=False) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor put(const at::Tensor & self, const at::Tensor & index, const at::Tensor & source, bool accumulate); // {"schema": "aten::put(Tensor self, Tensor index, Tensor source, bool accumulate=False) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & index_add_out(const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & source, const at::Scalar & alpha, at::Tensor & out); // {"schema": "aten::index_add.out(Tensor self, int dim, Tensor index, Tensor source, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & index_add_(at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & source, const at::Scalar & alpha); // {"schema": "aten::index_add_(Tensor(a!) self, int dim, Tensor index, Tensor source, *, Scalar alpha=1) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor index_add(const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & source, const at::Scalar & alpha); // {"schema": "aten::index_add(Tensor self, int dim, Tensor index, Tensor source, *, Scalar alpha=1) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor index_add(const at::Tensor & self, at::Dimname dim, const at::Tensor & index, const at::Tensor & source, const at::Scalar & alpha); // {"schema": "aten::index_add.dimname(Tensor self, Dimname dim, Tensor index, Tensor source, *, Scalar alpha=1) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & index_reduce_out(const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & source, c10::string_view reduce, bool include_self, at::Tensor & out); // {"schema": "aten::index_reduce.out(Tensor self, int dim, Tensor index, Tensor source, str reduce, *, bool include_self=True, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & index_reduce_(at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & source, c10::string_view reduce, bool include_self); // {"schema": "aten::index_reduce_(Tensor(a!) self, int dim, Tensor index, Tensor source, str reduce, *, bool include_self=True) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor index_reduce(const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & source, c10::string_view reduce, bool include_self); // {"schema": "aten::index_reduce(Tensor self, int dim, Tensor index, Tensor source, str reduce, *, bool include_self=True) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & index_fill_(at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Scalar & value); // {"schema": "aten::index_fill_.int_Scalar(Tensor(a!) self, int dim, Tensor index, Scalar value) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor index_fill(const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Scalar & value); // {"schema": "aten::index_fill.int_Scalar(Tensor self, int dim, Tensor index, Scalar value) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & index_fill_(at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & value); // {"schema": "aten::index_fill_.int_Tensor(Tensor(a!) self, int dim, Tensor index, Tensor value) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor index_fill(const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & value); // {"schema": "aten::index_fill.int_Tensor(Tensor self, int dim, Tensor index, Tensor value) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & index_fill_(at::Tensor & self, at::Dimname dim, const at::Tensor & index, const at::Scalar & value); // {"schema": "aten::index_fill_.Dimname_Scalar(Tensor(a!) 
self, Dimname dim, Tensor index, Scalar value) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & index_fill_(at::Tensor & self, at::Dimname dim, const at::Tensor & index, const at::Tensor & value); // {"schema": "aten::index_fill_.Dimname_Tensor(Tensor(a!) self, Dimname dim, Tensor index, Tensor value) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor index_fill(const at::Tensor & self, at::Dimname dim, const at::Tensor & index, const at::Scalar & value); // {"schema": "aten::index_fill.Dimname_Scalar(Tensor self, Dimname dim, Tensor index, Scalar value) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor index_fill(const at::Tensor & self, at::Dimname dim, const at::Tensor & index, const at::Tensor & value); // {"schema": "aten::index_fill.Dimname_Tensor(Tensor self, Dimname dim, Tensor index, Tensor value) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor scatter(const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src); // {"schema": "aten::scatter.src(Tensor self, int dim, Tensor index, Tensor src) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & scatter_(at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src); // {"schema": "aten::scatter_.src(Tensor(a!) self, int dim, Tensor index, Tensor src) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & scatter_out(const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src, at::Tensor & out); // {"schema": "aten::scatter.src_out(Tensor self, int dim, Tensor index, Tensor src, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor scatter(const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Scalar & value); // {"schema": "aten::scatter.value(Tensor self, int dim, Tensor index, Scalar value) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & scatter_(at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Scalar & value); // {"schema": "aten::scatter_.value(Tensor(a!) self, int dim, Tensor index, Scalar value) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & scatter_out(const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Scalar & value, at::Tensor & out); // {"schema": "aten::scatter.value_out(Tensor self, int dim, Tensor index, Scalar value, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor scatter(const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src, c10::string_view reduce); // {"schema": "aten::scatter.reduce(Tensor self, int dim, Tensor index, Tensor src, *, str reduce) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & scatter_(at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src, c10::string_view reduce); // {"schema": "aten::scatter_.reduce(Tensor(a!) self, int dim, Tensor index, Tensor src, *, str reduce) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & scatter_out(const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src, c10::string_view reduce, at::Tensor & out); // {"schema": "aten::scatter.reduce_out(Tensor self, int dim, Tensor index, Tensor src, *, str reduce, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor scatter(const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Scalar & value, c10::string_view reduce); // {"schema": "aten::scatter.value_reduce(Tensor self, int dim, Tensor index, Scalar value, *, str reduce) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & scatter_(at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Scalar & value, c10::string_view reduce); // {"schema": "aten::scatter_.value_reduce(Tensor(a!) self, int dim, Tensor index, Scalar value, *, str reduce) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & scatter_out(const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Scalar & value, c10::string_view reduce, at::Tensor & out); // {"schema": "aten::scatter.value_reduce_out(Tensor self, int dim, Tensor index, Scalar value, *, str reduce, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor scatter(const at::Tensor & self, at::Dimname dim, const at::Tensor & index, const at::Tensor & src); // {"schema": "aten::scatter.dimname_src(Tensor self, Dimname dim, Tensor index, Tensor src) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor scatter(const at::Tensor & self, at::Dimname dim, const at::Tensor & index, const at::Scalar & value); // {"schema": "aten::scatter.dimname_value(Tensor self, Dimname dim, Tensor index, Scalar value) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor scatter_add(const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src); // {"schema": "aten::scatter_add(Tensor self, int dim, Tensor index, Tensor src) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & scatter_add_(at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src); // {"schema": "aten::scatter_add_(Tensor(a!) 
self, int dim, Tensor index, Tensor src) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & scatter_add_out(const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src, at::Tensor & out); // {"schema": "aten::scatter_add.out(Tensor self, int dim, Tensor index, Tensor src, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor scatter_add(const at::Tensor & self, at::Dimname dim, const at::Tensor & index, const at::Tensor & src); // {"schema": "aten::scatter_add.dimname(Tensor self, Dimname dim, Tensor index, Tensor src) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor scatter_reduce(const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src, c10::string_view reduce, bool include_self); // {"schema": "aten::scatter_reduce.two(Tensor self, int dim, Tensor index, Tensor src, str reduce, *, bool include_self=True) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & scatter_reduce_(at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src, c10::string_view reduce, bool include_self); // {"schema": "aten::scatter_reduce_.two(Tensor(a!) self, int dim, Tensor index, Tensor src, str reduce, *, bool include_self=True) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & scatter_reduce_out(const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src, c10::string_view reduce, bool include_self, at::Tensor & out); // {"schema": "aten::scatter_reduce.two_out(Tensor self, int dim, Tensor index, Tensor src, str reduce, *, bool include_self=True, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & eq_(at::Tensor & self, const at::Scalar & other); // {"schema": "aten::eq_.Scalar(Tensor(a!) 
self, Scalar other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & eq_(at::Tensor & self, const at::Tensor & other); // {"schema": "aten::eq_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & bitwise_and_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::bitwise_and.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & bitwise_and_out(const at::Tensor & self, const at::Scalar & other, at::Tensor & out); // {"schema": "aten::bitwise_and.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor bitwise_and(const at::Tensor & self, const at::Scalar & other); // {"schema": "aten::bitwise_and.Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor bitwise_and(const at::Scalar & self, const at::Tensor & other); // {"schema": "aten::bitwise_and.Scalar_Tensor(Scalar self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor bitwise_and(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::bitwise_and.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & bitwise_and_(at::Tensor & self, const at::Scalar & other); // {"schema": "aten::bitwise_and_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & bitwise_and_(at::Tensor & self, const at::Tensor & other); // {"schema": "aten::bitwise_and_.Tensor(Tensor(a!) 
self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor __and__(const at::Tensor & self, const at::Scalar & other); // {"schema": "aten::__and__.Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor __and__(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::__and__.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & __iand__(at::Tensor & self, const at::Scalar & other); // {"schema": "aten::__iand__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & __iand__(at::Tensor & self, const at::Tensor & other); // {"schema": "aten::__iand__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & bitwise_or_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::bitwise_or.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & bitwise_or_out(const at::Tensor & self, const at::Scalar & other, at::Tensor & out); // {"schema": "aten::bitwise_or.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor bitwise_or(const at::Tensor & self, const at::Scalar & other); // {"schema": "aten::bitwise_or.Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor bitwise_or(const at::Scalar & self, const at::Tensor & other); // {"schema": "aten::bitwise_or.Scalar_Tensor(Scalar self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor bitwise_or(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::bitwise_or.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & bitwise_or_(at::Tensor & self, const at::Scalar & other); // {"schema": "aten::bitwise_or_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & bitwise_or_(at::Tensor & self, const at::Tensor & other); // {"schema": "aten::bitwise_or_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor __or__(const at::Tensor & self, const at::Scalar & other); // {"schema": "aten::__or__.Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor __or__(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::__or__.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & __ior__(at::Tensor & self, const at::Scalar & other); // {"schema": "aten::__ior__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & __ior__(at::Tensor & self, const at::Tensor & other); // {"schema": "aten::__ior__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & bitwise_xor_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::bitwise_xor.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & bitwise_xor_out(const at::Tensor & self, const at::Scalar & other, at::Tensor & out); // {"schema": "aten::bitwise_xor.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor bitwise_xor(const at::Tensor & self, const at::Scalar & other); // {"schema": "aten::bitwise_xor.Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor bitwise_xor(const at::Scalar & self, const at::Tensor & other); // {"schema": "aten::bitwise_xor.Scalar_Tensor(Scalar self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor bitwise_xor(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::bitwise_xor.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & bitwise_xor_(at::Tensor & self, const at::Scalar & other); // {"schema": "aten::bitwise_xor_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & bitwise_xor_(at::Tensor & self, const at::Tensor & other); // {"schema": "aten::bitwise_xor_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor __xor__(const at::Tensor & self, const at::Scalar & other); // {"schema": "aten::__xor__.Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor __xor__(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::__xor__.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & __ixor__(at::Tensor & self, const at::Scalar & other); // {"schema": "aten::__ixor__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & __ixor__(at::Tensor & self, const at::Tensor & other); // {"schema": "aten::__ixor__.Tensor(Tensor(a!) 
self, Tensor other) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor __lshift__(const at::Tensor & self, const at::Scalar & other); // {"schema": "aten::__lshift__.Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor __lshift__(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::__lshift__.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & __ilshift__(at::Tensor & self, const at::Scalar & other); // {"schema": "aten::__ilshift__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & __ilshift__(at::Tensor & self, const at::Tensor & other); // {"schema": "aten::__ilshift__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor bitwise_left_shift(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::bitwise_left_shift.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & bitwise_left_shift_(at::Tensor & self, const at::Tensor & other); // {"schema": "aten::bitwise_left_shift_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & bitwise_left_shift_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::bitwise_left_shift.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor bitwise_left_shift(const at::Tensor & self, const at::Scalar & other); // {"schema": "aten::bitwise_left_shift.Tensor_Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & bitwise_left_shift_(at::Tensor & self, const at::Scalar & other); // {"schema": "aten::bitwise_left_shift_.Tensor_Scalar(Tensor(a!) 
self, Scalar other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & bitwise_left_shift_out(const at::Tensor & self, const at::Scalar & other, at::Tensor & out); // {"schema": "aten::bitwise_left_shift.Tensor_Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor bitwise_left_shift(const at::Scalar & self, const at::Tensor & other); // {"schema": "aten::bitwise_left_shift.Scalar_Tensor(Scalar self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor __rshift__(const at::Tensor & self, const at::Scalar & other); // {"schema": "aten::__rshift__.Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor __rshift__(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::__rshift__.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & __irshift__(at::Tensor & self, const at::Scalar & other); // {"schema": "aten::__irshift__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & __irshift__(at::Tensor & self, const at::Tensor & other); // {"schema": "aten::__irshift__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor bitwise_right_shift(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::bitwise_right_shift.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & bitwise_right_shift_(at::Tensor & self, const at::Tensor & other); // {"schema": "aten::bitwise_right_shift_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & bitwise_right_shift_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::bitwise_right_shift.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor bitwise_right_shift(const at::Tensor & self, const at::Scalar & other); // {"schema": "aten::bitwise_right_shift.Tensor_Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & bitwise_right_shift_(at::Tensor & self, const at::Scalar & other); // {"schema": "aten::bitwise_right_shift_.Tensor_Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & bitwise_right_shift_out(const at::Tensor & self, const at::Scalar & other, at::Tensor & out); // {"schema": "aten::bitwise_right_shift.Tensor_Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor bitwise_right_shift(const at::Scalar & self, const at::Tensor & other); // {"schema": "aten::bitwise_right_shift.Scalar_Tensor(Scalar self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & tril_(at::Tensor & self, c10::SymInt diagonal); // {"schema": "aten::tril_(Tensor(a!) self, SymInt diagonal=0) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & triu_(at::Tensor & self, c10::SymInt diagonal); // {"schema": "aten::triu_(Tensor(a!) self, SymInt diagonal=0) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & digamma_(at::Tensor & self); // {"schema": "aten::digamma_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & lerp_(at::Tensor & self, const at::Tensor & end, const at::Scalar & weight); // {"schema": "aten::lerp_.Scalar(Tensor(a!) self, Tensor end, Scalar weight) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & lerp_(at::Tensor & self, const at::Tensor & end, const at::Tensor & weight); // {"schema": "aten::lerp_.Tensor(Tensor(a!) 
self, Tensor end, Tensor weight) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & addbmm_(at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta, const at::Scalar & alpha); // {"schema": "aten::addbmm_(Tensor(a!) self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & addbmm_out(const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out); // {"schema": "aten::addbmm.out(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor addbmm(const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta, const at::Scalar & alpha); // {"schema": "aten::addbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & random_(at::Tensor & self, int64_t from, ::std::optional to, ::std::optional generator); // {"schema": "aten::random_.from(Tensor(a!) self, int from, int? to, *, Generator? generator=None) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & random_(at::Tensor & self, int64_t to, ::std::optional generator); // {"schema": "aten::random_.to(Tensor(a!) self, int to, *, Generator? generator=None) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & random_(at::Tensor & self, ::std::optional generator); // {"schema": "aten::random_(Tensor(a!) self, *, Generator? generator=None) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & uniform_(at::Tensor & self, double from, double to, ::std::optional generator); // {"schema": "aten::uniform_(Tensor(a!) self, float from=0, float to=1, *, Generator? 
generator=None) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & cauchy_(at::Tensor & self, double median, double sigma, ::std::optional generator); // {"schema": "aten::cauchy_(Tensor(a!) self, float median=0, float sigma=1, *, Generator? generator=None) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & log_normal_(at::Tensor & self, double mean, double std, ::std::optional generator); // {"schema": "aten::log_normal_(Tensor(a!) self, float mean=1, float std=2, *, Generator? generator=None) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & exponential_(at::Tensor & self, double lambd, ::std::optional generator); // {"schema": "aten::exponential_(Tensor(a!) self, float lambd=1, *, Generator? generator=None) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & geometric_(at::Tensor & self, double p, ::std::optional generator); // {"schema": "aten::geometric_(Tensor(a!) self, float p, *, Generator? generator=None) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & diag_out(const at::Tensor & self, int64_t diagonal, at::Tensor & out); // {"schema": "aten::diag.out(Tensor self, int diagonal=0, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor diag(const at::Tensor & self, int64_t diagonal); // {"schema": "aten::diag(Tensor self, int diagonal=0) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & cross_out(const at::Tensor & self, const at::Tensor & other, ::std::optional dim, at::Tensor & out); // {"schema": "aten::cross.out(Tensor self, Tensor other, int? dim=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor cross(const at::Tensor & self, const at::Tensor & other, ::std::optional dim); // {"schema": "aten::cross(Tensor self, Tensor other, int? 
dim=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & triu_out(const at::Tensor & self, c10::SymInt diagonal, at::Tensor & out); // {"schema": "aten::triu.out(Tensor self, SymInt diagonal=0, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor triu(const at::Tensor & self, c10::SymInt diagonal); // {"schema": "aten::triu(Tensor self, SymInt diagonal=0) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & tril_out(const at::Tensor & self, c10::SymInt diagonal, at::Tensor & out); // {"schema": "aten::tril.out(Tensor self, SymInt diagonal=0, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor tril(const at::Tensor & self, c10::SymInt diagonal); // {"schema": "aten::tril(Tensor self, SymInt diagonal=0) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor tril_indices(int64_t row, int64_t col, int64_t offset, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::tril_indices(int row, int col, int offset=0, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor triu_indices(int64_t row, int64_t col, int64_t offset, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::triu_indices(int row, int col, int offset=0, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor trace(const at::Tensor & self); // {"schema": "aten::trace(Tensor self) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor trace_backward(const at::Tensor & grad, c10::SymIntArrayRef sizes); // {"schema": "aten::trace_backward(Tensor grad, SymInt[] sizes) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & ne_out(const at::Tensor & self, const at::Scalar & other, at::Tensor & out); // {"schema": "aten::ne.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor ne(const at::Tensor & self, const at::Scalar & other); // {"schema": "aten::ne.Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & ne_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::ne.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor ne(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::ne.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & ne_(at::Tensor & self, const at::Scalar & other); // {"schema": "aten::ne_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & ne_(at::Tensor & self, const at::Tensor & other); // {"schema": "aten::ne_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & not_equal_out(const at::Tensor & self, const at::Scalar & other, at::Tensor & out); // {"schema": "aten::not_equal.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor not_equal(const at::Tensor & self, const at::Scalar & other); // {"schema": "aten::not_equal.Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & not_equal_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::not_equal.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor not_equal(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::not_equal.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & not_equal_(at::Tensor & self, const at::Scalar & other); // {"schema": "aten::not_equal_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & not_equal_(at::Tensor & self, const at::Tensor & other); // {"schema": "aten::not_equal_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & eq_out(const at::Tensor & self, const at::Scalar & other, at::Tensor & out); // {"schema": "aten::eq.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor eq(const at::Tensor & self, const at::Scalar & other); // {"schema": "aten::eq.Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & eq_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::eq.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor eq(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::eq.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & ge_out(const at::Tensor & self, const at::Scalar & other, at::Tensor & out); // {"schema": "aten::ge.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor ge(const at::Tensor & self, const at::Scalar & other); // {"schema": "aten::ge.Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & ge_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::ge.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor ge(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::ge.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & ge_(at::Tensor & self, const at::Scalar & other); // {"schema": "aten::ge_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & ge_(at::Tensor & self, const at::Tensor & other); // {"schema": "aten::ge_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & greater_equal_out(const at::Tensor & self, const at::Scalar & other, at::Tensor & out); // {"schema": "aten::greater_equal.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor greater_equal(const at::Tensor & self, const at::Scalar & other); // {"schema": "aten::greater_equal.Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & greater_equal_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::greater_equal.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor greater_equal(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::greater_equal.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & greater_equal_(at::Tensor & self, const at::Scalar & other); // {"schema": "aten::greater_equal_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & greater_equal_(at::Tensor & self, const at::Tensor & other); // {"schema": "aten::greater_equal_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & le_out(const at::Tensor & self, const at::Scalar & other, at::Tensor & out); // {"schema": "aten::le.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor le(const at::Tensor & self, const at::Scalar & other); // {"schema": "aten::le.Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & le_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::le.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor le(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::le.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & le_(at::Tensor & self, const at::Scalar & other); // {"schema": "aten::le_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & le_(at::Tensor & self, const at::Tensor & other); // {"schema": "aten::le_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & less_equal_out(const at::Tensor & self, const at::Scalar & other, at::Tensor & out); // {"schema": "aten::less_equal.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor less_equal(const at::Tensor & self, const at::Scalar & other); // {"schema": "aten::less_equal.Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & less_equal_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::less_equal.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor less_equal(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::less_equal.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & less_equal_(at::Tensor & self, const at::Scalar & other); // {"schema": "aten::less_equal_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & less_equal_(at::Tensor & self, const at::Tensor & other); // {"schema": "aten::less_equal_.Tensor(Tensor(a!) 
self, Tensor other) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & gt_out(const at::Tensor & self, const at::Scalar & other, at::Tensor & out); // {"schema": "aten::gt.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor gt(const at::Tensor & self, const at::Scalar & other); // {"schema": "aten::gt.Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & gt_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::gt.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor gt(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::gt.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & gt_(at::Tensor & self, const at::Scalar & other); // {"schema": "aten::gt_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & gt_(at::Tensor & self, const at::Tensor & other); // {"schema": "aten::gt_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & greater_out(const at::Tensor & self, const at::Scalar & other, at::Tensor & out); // {"schema": "aten::greater.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor greater(const at::Tensor & self, const at::Scalar & other); // {"schema": "aten::greater.Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & greater_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::greater.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor greater(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::greater.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & greater_(at::Tensor & self, const at::Scalar & other); // {"schema": "aten::greater_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & greater_(at::Tensor & self, const at::Tensor & other); // {"schema": "aten::greater_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & lt_out(const at::Tensor & self, const at::Scalar & other, at::Tensor & out); // {"schema": "aten::lt.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor lt(const at::Tensor & self, const at::Scalar & other); // {"schema": "aten::lt.Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & lt_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::lt.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor lt(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::lt.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & lt_(at::Tensor & self, const at::Scalar & other); // {"schema": "aten::lt_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & lt_(at::Tensor & self, const at::Tensor & other); // {"schema": "aten::lt_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & less_out(const at::Tensor & self, const at::Scalar & other, at::Tensor & out); // {"schema": "aten::less.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor less(const at::Tensor & self, const at::Scalar & other); // {"schema": "aten::less.Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & less_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::less.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor less(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::less.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & less_(at::Tensor & self, const at::Scalar & other); // {"schema": "aten::less_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & less_(at::Tensor & self, const at::Tensor & other); // {"schema": "aten::less_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & take_out(const at::Tensor & self, const at::Tensor & index, at::Tensor & out); // {"schema": "aten::take.out(Tensor self, Tensor index, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor take(const at::Tensor & self, const at::Tensor & index); // {"schema": "aten::take(Tensor self, Tensor index) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & take_along_dim_out(const at::Tensor & self, const at::Tensor & indices, ::std::optional dim, at::Tensor & out); // {"schema": "aten::take_along_dim.out(Tensor self, Tensor indices, int? dim=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor take_along_dim(const at::Tensor & self, const at::Tensor & indices, ::std::optional dim); // {"schema": "aten::take_along_dim(Tensor self, Tensor indices, int? 
dim=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & index_select_out(const at::Tensor & self, int64_t dim, const at::Tensor & index, at::Tensor & out); // {"schema": "aten::index_select.out(Tensor self, int dim, Tensor index, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor index_select(const at::Tensor & self, int64_t dim, const at::Tensor & index); // {"schema": "aten::index_select(Tensor self, int dim, Tensor index) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & index_select_out(const at::Tensor & self, at::Dimname dim, const at::Tensor & index, at::Tensor & out); // {"schema": "aten::index_select.dimname_out(Tensor self, Dimname dim, Tensor index, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor index_select(const at::Tensor & self, at::Dimname dim, const at::Tensor & index); // {"schema": "aten::index_select.dimname(Tensor self, Dimname dim, Tensor index) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor index_select_backward(const at::Tensor & grad, c10::SymIntArrayRef self_sizes, int64_t dim, const at::Tensor & index); // {"schema": "aten::index_select_backward(Tensor grad, SymInt[] self_sizes, int dim, Tensor index) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & masked_select_out(const at::Tensor & self, const at::Tensor & mask, at::Tensor & out); // {"schema": "aten::masked_select.out(Tensor self, Tensor mask, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor masked_select(const at::Tensor & self, const at::Tensor & mask); // {"schema": "aten::masked_select(Tensor self, Tensor mask) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor masked_select_backward(const at::Tensor & grad, const at::Tensor & input, const at::Tensor & mask); // {"schema": "aten::masked_select_backward(Tensor grad, Tensor input, Tensor mask) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & nonzero_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::nonzero.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor nonzero(const at::Tensor & self); // {"schema": "aten::nonzero(Tensor self) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & nonzero_static_out(const at::Tensor & self, c10::SymInt size, int64_t fill_value, at::Tensor & out); // {"schema": "aten::nonzero_static.out(Tensor self, *, SymInt size, int fill_value=-1, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor nonzero_static(const at::Tensor & self, c10::SymInt size, int64_t fill_value); // {"schema": "aten::nonzero_static(Tensor self, *, SymInt size, int fill_value=-1) -> Tensor", "dispatch": "True", "default": "False"} +::std::vector nonzero_numpy(const at::Tensor & self); // {"schema": "aten::nonzero_numpy(Tensor self) -> Tensor[]", "dispatch": "False", "default": "True"} +at::Tensor argwhere(const at::Tensor & self); // {"schema": "aten::argwhere(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & gather_out(const at::Tensor & self, int64_t dim, const at::Tensor & index, bool sparse_grad, at::Tensor & out); // {"schema": "aten::gather.out(Tensor self, int dim, Tensor index, *, bool sparse_grad=False, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor gather(const at::Tensor & self, int64_t dim, const at::Tensor & index, bool sparse_grad); // {"schema": "aten::gather(Tensor self, int dim, Tensor index, *, bool sparse_grad=False) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor gather_backward(const at::Tensor & grad, const at::Tensor & self, int64_t dim, const at::Tensor & index, bool sparse_grad); // {"schema": "aten::gather_backward(Tensor grad, Tensor self, int dim, Tensor index, bool sparse_grad) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & gather_out(const at::Tensor & self, at::Dimname dim, const at::Tensor & index, bool sparse_grad, at::Tensor & out); // {"schema": "aten::gather.dimname_out(Tensor self, Dimname dim, Tensor index, *, bool sparse_grad=False, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor gather(const at::Tensor & self, at::Dimname dim, const at::Tensor & index, bool sparse_grad); // {"schema": "aten::gather.dimname(Tensor self, Dimname dim, Tensor index, *, bool sparse_grad=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _gather_sparse_backward(const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & grad); // {"schema": "aten::_gather_sparse_backward(Tensor self, int dim, Tensor index, Tensor grad) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & addcmul_out(const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value, at::Tensor & out); // {"schema": "aten::addcmul.out(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor addcmul(const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value); // {"schema": "aten::addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & addcmul_(at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value); // {"schema": "aten::addcmul_(Tensor(a!) self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & addcdiv_out(const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value, at::Tensor & out); // {"schema": "aten::addcdiv.out(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor addcdiv(const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value); // {"schema": "aten::addcdiv(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & addcdiv_(at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value); // {"schema": "aten::addcdiv_(Tensor(a!) self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor cross_entropy_loss(const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, c10::SymInt ignore_index, double label_smoothing); // {"schema": "aten::cross_entropy_loss(Tensor self, Tensor target, Tensor? 
weight=None, int reduction=Mean, SymInt ignore_index=-100, float label_smoothing=0.0) -> Tensor", "dispatch": "False", "default": "True"} +::std::tuple triangular_solve_out(const at::Tensor & self, const at::Tensor & A, bool upper, bool transpose, bool unitriangular, at::Tensor & X, at::Tensor & M); // {"schema": "aten::triangular_solve.X(Tensor self, Tensor A, bool upper=True, bool transpose=False, bool unitriangular=False, *, Tensor(a!) X, Tensor(b!) M) -> (Tensor(a!) solution, Tensor(b!) cloned_coefficient)", "dispatch": "True", "default": "False"} +::std::tuple triangular_solve(const at::Tensor & self, const at::Tensor & A, bool upper, bool transpose, bool unitriangular); // {"schema": "aten::triangular_solve(Tensor self, Tensor A, bool upper=True, bool transpose=False, bool unitriangular=False) -> (Tensor solution, Tensor cloned_coefficient)", "dispatch": "True", "default": "True"} +void _linalg_check_errors(const at::Tensor & info, c10::string_view api_name, bool is_matrix); // {"schema": "aten::_linalg_check_errors(Tensor info, str api_name, *, bool is_matrix) -> ()", "dispatch": "True", "default": "True"} +at::Tensor & linalg_solve_triangular_out(const at::Tensor & self, const at::Tensor & B, bool upper, bool left, bool unitriangular, at::Tensor & out); // {"schema": "aten::linalg_solve_triangular.out(Tensor self, Tensor B, *, bool upper, bool left=True, bool unitriangular=False, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor linalg_solve_triangular(const at::Tensor & self, const at::Tensor & B, bool upper, bool left, bool unitriangular); // {"schema": "aten::linalg_solve_triangular(Tensor self, Tensor B, *, bool upper, bool left=True, bool unitriangular=False) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor linalg_vander(const at::Tensor & x, ::std::optional N); // {"schema": "aten::linalg_vander(Tensor x, *, SymInt? 
N=None) -> Tensor", "dispatch": "False", "default": "True"} +::std::tuple svd_out(const at::Tensor & self, bool some, bool compute_uv, at::Tensor & U, at::Tensor & S, at::Tensor & V); // {"schema": "aten::svd.U(Tensor self, bool some=True, bool compute_uv=True, *, Tensor(a!) U, Tensor(b!) S, Tensor(c!) V) -> (Tensor(a!) U, Tensor(b!) S, Tensor(c!) V)", "dispatch": "False", "default": "True"} +::std::tuple svd(const at::Tensor & self, bool some, bool compute_uv); // {"schema": "aten::svd(Tensor self, bool some=True, bool compute_uv=True) -> (Tensor U, Tensor S, Tensor V)", "dispatch": "False", "default": "True"} +at::Tensor swapaxes(const at::Tensor & self, int64_t axis0, int64_t axis1); // {"schema": "aten::swapaxes(Tensor(a) self, int axis0, int axis1) -> Tensor(a)", "dispatch": "False", "default": "True"} +at::Tensor & swapaxes_(at::Tensor & self, int64_t axis0, int64_t axis1); // {"schema": "aten::swapaxes_(Tensor(a!) self, int axis0, int axis1) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor swapdims(const at::Tensor & self, int64_t dim0, int64_t dim1); // {"schema": "aten::swapdims(Tensor(a) self, int dim0, int dim1) -> Tensor(a)", "dispatch": "False", "default": "True"} +at::Tensor & swapdims_(at::Tensor & self, int64_t dim0, int64_t dim1); // {"schema": "aten::swapdims_(Tensor(a!) self, int dim0, int dim1) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & cholesky_out(const at::Tensor & self, bool upper, at::Tensor & out); // {"schema": "aten::cholesky.out(Tensor self, bool upper=False, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor cholesky(const at::Tensor & self, bool upper); // {"schema": "aten::cholesky(Tensor self, bool upper=False) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & cholesky_solve_out(const at::Tensor & self, const at::Tensor & input2, bool upper, at::Tensor & out); // {"schema": "aten::cholesky_solve.out(Tensor self, Tensor input2, bool upper=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor cholesky_solve(const at::Tensor & self, const at::Tensor & input2, bool upper); // {"schema": "aten::cholesky_solve(Tensor self, Tensor input2, bool upper=False) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor _cholesky_solve_helper(const at::Tensor & self, const at::Tensor & A, bool upper); // {"schema": "aten::_cholesky_solve_helper(Tensor self, Tensor A, bool upper) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor cholesky_inverse(const at::Tensor & self, bool upper); // {"schema": "aten::cholesky_inverse(Tensor self, bool upper=False) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & cholesky_inverse_out(const at::Tensor & self, bool upper, at::Tensor & out); // {"schema": "aten::cholesky_inverse.out(Tensor self, bool upper=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +::std::tuple qr_out(const at::Tensor & self, bool some, at::Tensor & Q, at::Tensor & R); // {"schema": "aten::qr.Q(Tensor self, bool some=True, *, Tensor(a!) Q, Tensor(b!) R) -> (Tensor(a!) Q, Tensor(b!) R)", "dispatch": "False", "default": "True"} +::std::tuple qr(const at::Tensor & self, bool some); // {"schema": "aten::qr(Tensor self, bool some=True) -> (Tensor Q, Tensor R)", "dispatch": "False", "default": "True"} +::std::tuple geqrf_out(const at::Tensor & self, at::Tensor & a, at::Tensor & tau); // {"schema": "aten::geqrf.a(Tensor self, *, Tensor(a!) a, Tensor(b!) tau) -> (Tensor(a!) a, Tensor(b!) 
tau)", "dispatch": "True", "default": "False"} +::std::tuple geqrf(const at::Tensor & self); // {"schema": "aten::geqrf(Tensor self) -> (Tensor a, Tensor tau)", "dispatch": "True", "default": "False"} +at::Tensor orgqr(const at::Tensor & self, const at::Tensor & input2); // {"schema": "aten::orgqr(Tensor self, Tensor input2) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & orgqr_out(const at::Tensor & self, const at::Tensor & input2, at::Tensor & out); // {"schema": "aten::orgqr.out(Tensor self, Tensor input2, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & ormqr_out(const at::Tensor & self, const at::Tensor & input2, const at::Tensor & input3, bool left, bool transpose, at::Tensor & out); // {"schema": "aten::ormqr.out(Tensor self, Tensor input2, Tensor input3, bool left=True, bool transpose=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor ormqr(const at::Tensor & self, const at::Tensor & input2, const at::Tensor & input3, bool left, bool transpose); // {"schema": "aten::ormqr(Tensor self, Tensor input2, Tensor input3, bool left=True, bool transpose=False) -> Tensor", "dispatch": "True", "default": "False"} +::std::tuple _lu_with_info(const at::Tensor & self, bool pivot, bool check_errors); // {"schema": "aten::_lu_with_info(Tensor self, bool pivot=True, bool check_errors=True) -> (Tensor LU, Tensor pivots, Tensor info)", "dispatch": "False", "default": "True"} +at::Tensor & lu_solve_out(const at::Tensor & self, const at::Tensor & LU_data, const at::Tensor & LU_pivots, at::Tensor & out); // {"schema": "aten::lu_solve.out(Tensor self, Tensor LU_data, Tensor LU_pivots, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor lu_solve(const at::Tensor & self, const at::Tensor & LU_data, const at::Tensor & LU_pivots); // {"schema": "aten::lu_solve(Tensor self, Tensor LU_data, Tensor LU_pivots) -> Tensor", "dispatch": "False", "default": "True"} +::std::tuple lu_unpack(const at::Tensor & LU_data, const at::Tensor & LU_pivots, bool unpack_data, bool unpack_pivots); // {"schema": "aten::lu_unpack(Tensor LU_data, Tensor LU_pivots, bool unpack_data=True, bool unpack_pivots=True) -> (Tensor P, Tensor L, Tensor U)", "dispatch": "True", "default": "True"} +::std::tuple lu_unpack_out(const at::Tensor & LU_data, const at::Tensor & LU_pivots, bool unpack_data, bool unpack_pivots, at::Tensor & P, at::Tensor & L, at::Tensor & U); // {"schema": "aten::lu_unpack.out(Tensor LU_data, Tensor LU_pivots, bool unpack_data=True, bool unpack_pivots=True, *, Tensor(a!) P, Tensor(b!) L, Tensor(c!) U) -> (Tensor(a!) P, Tensor(b!) L, Tensor(c!) U)", "dispatch": "True", "default": "False"} +at::Tensor & multinomial_out(const at::Tensor & self, c10::SymInt num_samples, bool replacement, ::std::optional generator, at::Tensor & out); // {"schema": "aten::multinomial.out(Tensor self, SymInt num_samples, bool replacement=False, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor multinomial(const at::Tensor & self, c10::SymInt num_samples, bool replacement, ::std::optional generator); // {"schema": "aten::multinomial(Tensor self, SymInt num_samples, bool replacement=False, *, Generator? generator=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & lgamma_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::lgamma.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & lgamma_(at::Tensor & self); // {"schema": "aten::lgamma_(Tensor(a!) 
self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor lgamma(const at::Tensor & self); // {"schema": "aten::lgamma(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & digamma_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::digamma.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor digamma(const at::Tensor & self); // {"schema": "aten::digamma(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & polygamma_out(int64_t n, const at::Tensor & self, at::Tensor & out); // {"schema": "aten::polygamma.out(int n, Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor polygamma(int64_t n, const at::Tensor & self); // {"schema": "aten::polygamma(int n, Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & polygamma_(at::Tensor & self, int64_t n); // {"schema": "aten::polygamma_(Tensor(a!) self, int n) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor erfinv(const at::Tensor & self); // {"schema": "aten::erfinv(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & erfinv_(at::Tensor & self); // {"schema": "aten::erfinv_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & erfinv_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::erfinv.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor i0(const at::Tensor & self); // {"schema": "aten::i0(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & i0_(at::Tensor & self); // {"schema": "aten::i0_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & i0_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::i0.out(Tensor self, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor sign(const at::Tensor & self); // {"schema": "aten::sign(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & sign_(at::Tensor & self); // {"schema": "aten::sign_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & sign_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::sign.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor signbit(const at::Tensor & self); // {"schema": "aten::signbit(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & signbit_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::signbit.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor dist(const at::Tensor & self, const at::Tensor & other, const at::Scalar & p); // {"schema": "aten::dist(Tensor self, Tensor other, Scalar p=2) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & atan2_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::atan2.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & atan2_(at::Tensor & self, const at::Tensor & other); // {"schema": "aten::atan2_(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor atan2(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::atan2(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor arctan2(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::arctan2(Tensor self, Tensor other) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & arctan2_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::arctan2.out(Tensor self, Tensor other, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & arctan2_(at::Tensor & self, const at::Tensor & other); // {"schema": "aten::arctan2_(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & lerp_out(const at::Tensor & self, const at::Tensor & end, const at::Scalar & weight, at::Tensor & out); // {"schema": "aten::lerp.Scalar_out(Tensor self, Tensor end, Scalar weight, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & lerp_out(const at::Tensor & self, const at::Tensor & end, const at::Tensor & weight, at::Tensor & out); // {"schema": "aten::lerp.Tensor_out(Tensor self, Tensor end, Tensor weight, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor lerp(const at::Tensor & self, const at::Tensor & end, const at::Scalar & weight); // {"schema": "aten::lerp.Scalar(Tensor self, Tensor end, Scalar weight) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor lerp(const at::Tensor & self, const at::Tensor & end, const at::Tensor & weight); // {"schema": "aten::lerp.Tensor(Tensor self, Tensor end, Tensor weight) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & histc_out(const at::Tensor & self, int64_t bins, const at::Scalar & min, const at::Scalar & max, at::Tensor & out); // {"schema": "aten::histc.out(Tensor self, int bins=100, Scalar min=0, Scalar max=0, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor histc(const at::Tensor & self, int64_t bins, const at::Scalar & min, const at::Scalar & max); // {"schema": "aten::histc(Tensor self, int bins=100, Scalar min=0, Scalar max=0) -> Tensor", "dispatch": "True", "default": "False"} +::std::tuple histogram_out(const at::Tensor & self, const at::Tensor & bins, const ::std::optional & weight, bool density, at::Tensor & hist, at::Tensor & bin_edges); // {"schema": "aten::histogram.bins_tensor_out(Tensor self, Tensor bins, *, Tensor? weight=None, bool density=False, Tensor(a!) hist, Tensor(b!) bin_edges) -> (Tensor(a!) hist, Tensor(b!) bin_edges)", "dispatch": "True", "default": "False"} +::std::tuple histogram(const at::Tensor & self, const at::Tensor & bins, const ::std::optional & weight, bool density); // {"schema": "aten::histogram.bins_tensor(Tensor self, Tensor bins, *, Tensor? weight=None, bool density=False) -> (Tensor hist, Tensor bin_edges)", "dispatch": "True", "default": "False"} +::std::tuple histogram_out(const at::Tensor & self, int64_t bins, ::std::optional> range, const ::std::optional & weight, bool density, at::Tensor & hist, at::Tensor & bin_edges); // {"schema": "aten::histogram.bin_ct_out(Tensor self, int bins=100, *, float[]? range=None, Tensor? weight=None, bool density=False, Tensor(a!) hist, Tensor(b!) bin_edges) -> (Tensor(a!) hist, Tensor(b!) bin_edges)", "dispatch": "True", "default": "False"} +::std::tuple histogram(const at::Tensor & self, int64_t bins, ::std::optional> range, const ::std::optional & weight, bool density); // {"schema": "aten::histogram.bin_ct(Tensor self, int bins=100, *, float[]? range=None, Tensor? 
weight=None, bool density=False) -> (Tensor hist, Tensor bin_edges)", "dispatch": "True", "default": "False"} +::std::vector _histogramdd_bin_edges(const at::Tensor & self, at::IntArrayRef bins, ::std::optional> range, const ::std::optional & weight, bool density); // {"schema": "aten::_histogramdd_bin_edges(Tensor self, int[] bins, *, float[]? range=None, Tensor? weight=None, bool density=False) -> Tensor[]", "dispatch": "True", "default": "False"} +at::Tensor _histogramdd_from_bin_cts(const at::Tensor & self, at::IntArrayRef bins, ::std::optional> range, const ::std::optional & weight, bool density); // {"schema": "aten::_histogramdd_from_bin_cts(Tensor self, int[] bins, *, float[]? range=None, Tensor? weight=None, bool density=False) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _histogramdd_from_bin_tensors(const at::Tensor & self, at::TensorList bins, const ::std::optional & weight, bool density); // {"schema": "aten::_histogramdd_from_bin_tensors(Tensor self, Tensor[] bins, *, Tensor? weight=None, bool density=False) -> Tensor", "dispatch": "True", "default": "False"} +::std::tuple> histogramdd(const at::Tensor & self, at::IntArrayRef bins, ::std::optional> range, const ::std::optional & weight, bool density); // {"schema": "aten::histogramdd(Tensor self, int[] bins, float[]? range=None, Tensor? weight=None, bool density=False) -> (Tensor hist, Tensor[] bin_edges)", "dispatch": "False", "default": "True"} +::std::tuple> histogramdd(const at::Tensor & self, int64_t bins, ::std::optional> range, const ::std::optional & weight, bool density); // {"schema": "aten::histogramdd.int_bins(Tensor self, int bins, float[]? range=None, Tensor? 
weight=None, bool density=False) -> (Tensor hist, Tensor[] bin_edges)", "dispatch": "False", "default": "True"} +::std::tuple> histogramdd(const at::Tensor & self, at::TensorList bins, ::std::optional> range, const ::std::optional & weight, bool density); // {"schema": "aten::histogramdd.TensorList_bins(Tensor self, Tensor[] bins, float[]? range=None, Tensor? weight=None, bool density=False) -> (Tensor hist, Tensor[] bin_edges)", "dispatch": "False", "default": "True"} +at::Tensor & fmod_out(const at::Tensor & self, const at::Scalar & other, at::Tensor & out); // {"schema": "aten::fmod.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor fmod(const at::Tensor & self, const at::Scalar & other); // {"schema": "aten::fmod.Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & fmod_(at::Tensor & self, const at::Scalar & other); // {"schema": "aten::fmod_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & fmod_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::fmod.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor fmod(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::fmod.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & fmod_(at::Tensor & self, const at::Tensor & other); // {"schema": "aten::fmod_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & hypot_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::hypot.out(Tensor self, Tensor other, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor hypot(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::hypot(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & hypot_(at::Tensor & self, const at::Tensor & other); // {"schema": "aten::hypot_(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & igamma_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::igamma.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor igamma(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::igamma(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & igamma_(at::Tensor & self, const at::Tensor & other); // {"schema": "aten::igamma_(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & igammac_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::igammac.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor igammac(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::igammac(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & igammac_(at::Tensor & self, const at::Tensor & other); // {"schema": "aten::igammac_(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & nextafter_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::nextafter.out(Tensor self, Tensor other, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor nextafter(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::nextafter(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & nextafter_(at::Tensor & self, const at::Tensor & other); // {"schema": "aten::nextafter_(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & remainder_out(const at::Tensor & self, const at::Scalar & other, at::Tensor & out); // {"schema": "aten::remainder.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor remainder(const at::Tensor & self, const at::Scalar & other); // {"schema": "aten::remainder.Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & remainder_(at::Tensor & self, const at::Scalar & other); // {"schema": "aten::remainder_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & remainder_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::remainder.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor remainder(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::remainder.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & remainder_(at::Tensor & self, const at::Tensor & other); // {"schema": "aten::remainder_.Tensor(Tensor(a!) 
self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor remainder(const at::Scalar & self, const at::Tensor & other); // {"schema": "aten::remainder.Scalar_Tensor(Scalar self, Tensor other) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor min(const at::Tensor & self); // {"schema": "aten::min(Tensor self) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & min_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::min.unary_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor fmin(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::fmin(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & fmin_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::fmin.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor max(const at::Tensor & self); // {"schema": "aten::max(Tensor self) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor fmax(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::fmax(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & fmax_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::fmax.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor maximum(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::maximum(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & maximum_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::maximum.out(Tensor self, Tensor other, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor max(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::max.other(Tensor self, Tensor other) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & max_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::max.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & max_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::max.unary_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor minimum(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::minimum(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & minimum_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::minimum.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & min_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::min.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor min(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::min.other(Tensor self, Tensor other) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor quantile(const at::Tensor & self, const at::Tensor & q, ::std::optional dim, bool keepdim, c10::string_view interpolation); // {"schema": "aten::quantile(Tensor self, Tensor q, int? dim=None, bool keepdim=False, *, str interpolation='linear') -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & quantile_out(const at::Tensor & self, const at::Tensor & q, ::std::optional dim, bool keepdim, c10::string_view interpolation, at::Tensor & out); // {"schema": "aten::quantile.out(Tensor self, Tensor q, int? 
dim=None, bool keepdim=False, *, str interpolation='linear', Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor quantile(const at::Tensor & self, double q, ::std::optional dim, bool keepdim, c10::string_view interpolation); // {"schema": "aten::quantile.scalar(Tensor self, float q, int? dim=None, bool keepdim=False, *, str interpolation='linear') -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & quantile_out(const at::Tensor & self, double q, ::std::optional dim, bool keepdim, c10::string_view interpolation, at::Tensor & out); // {"schema": "aten::quantile.scalar_out(Tensor self, float q, int? dim=None, bool keepdim=False, *, str interpolation='linear', Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor nanquantile(const at::Tensor & self, const at::Tensor & q, ::std::optional dim, bool keepdim, c10::string_view interpolation); // {"schema": "aten::nanquantile(Tensor self, Tensor q, int? dim=None, bool keepdim=False, *, str interpolation='linear') -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & nanquantile_out(const at::Tensor & self, const at::Tensor & q, ::std::optional dim, bool keepdim, c10::string_view interpolation, at::Tensor & out); // {"schema": "aten::nanquantile.out(Tensor self, Tensor q, int? dim=None, bool keepdim=False, *, str interpolation='linear', Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor nanquantile(const at::Tensor & self, double q, ::std::optional dim, bool keepdim, c10::string_view interpolation); // {"schema": "aten::nanquantile.scalar(Tensor self, float q, int? dim=None, bool keepdim=False, *, str interpolation='linear') -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & nanquantile_out(const at::Tensor & self, double q, ::std::optional dim, bool keepdim, c10::string_view interpolation, at::Tensor & out); // {"schema": "aten::nanquantile.scalar_out(Tensor self, float q, int? 
dim=None, bool keepdim=False, *, str interpolation='linear', Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +::std::tuple sort_out(const at::Tensor & self, int64_t dim, bool descending, at::Tensor & values, at::Tensor & indices); // {"schema": "aten::sort.values(Tensor self, int dim=-1, bool descending=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)", "dispatch": "True", "default": "True"} +::std::tuple sort_out(const at::Tensor & self, ::std::optional stable, int64_t dim, bool descending, at::Tensor & values, at::Tensor & indices); // {"schema": "aten::sort.values_stable(Tensor self, *, bool? stable, int dim=-1, bool descending=False, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)", "dispatch": "True", "default": "False"} +::std::tuple sort(const at::Tensor & self, int64_t dim, bool descending); // {"schema": "aten::sort(Tensor self, int dim=-1, bool descending=False) -> (Tensor values, Tensor indices)", "dispatch": "True", "default": "True"} +::std::tuple sort(const at::Tensor & self, ::std::optional stable, int64_t dim, bool descending); // {"schema": "aten::sort.stable(Tensor self, *, bool? stable, int dim=-1, bool descending=False) -> (Tensor values, Tensor indices)", "dispatch": "True", "default": "True"} +::std::tuple sort_out(const at::Tensor & self, at::Dimname dim, bool descending, at::Tensor & values, at::Tensor & indices); // {"schema": "aten::sort.dimname_values(Tensor self, Dimname dim, bool descending=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)", "dispatch": "False", "default": "True"} +::std::tuple sort_out(const at::Tensor & self, ::std::optional stable, at::Dimname dim, bool descending, at::Tensor & values, at::Tensor & indices); // {"schema": "aten::sort.dimname_values_stable(Tensor self, *, bool? stable, Dimname dim, bool descending=False, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) 
values, Tensor(b!) indices)", "dispatch": "False", "default": "True"} +::std::tuple sort(const at::Tensor & self, at::Dimname dim, bool descending); // {"schema": "aten::sort.dimname(Tensor self, Dimname dim, bool descending=False) -> (Tensor values, Tensor indices)", "dispatch": "False", "default": "True"} +::std::tuple sort(const at::Tensor & self, ::std::optional stable, at::Dimname dim, bool descending); // {"schema": "aten::sort.dimname_stable(Tensor self, *, bool? stable, Dimname dim, bool descending=False) -> (Tensor values, Tensor indices)", "dispatch": "False", "default": "True"} +at::Tensor & msort_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::msort.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor msort(const at::Tensor & self); // {"schema": "aten::msort(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor argsort(const at::Tensor & self, int64_t dim, bool descending); // {"schema": "aten::argsort(Tensor self, int dim=-1, bool descending=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor argsort(const at::Tensor & self, bool stable, int64_t dim, bool descending); // {"schema": "aten::argsort.stable(Tensor self, *, bool stable, int dim=-1, bool descending=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & argsort_out(const at::Tensor & self, bool stable, int64_t dim, bool descending, at::Tensor & out); // {"schema": "aten::argsort.stable_out(Tensor self, *, bool stable, int dim=-1, bool descending=False, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor argsort(const at::Tensor & self, at::Dimname dim, bool descending); // {"schema": "aten::argsort.dimname(Tensor self, Dimname dim, bool descending=False) -> Tensor", "dispatch": "False", "default": "True"} +::std::tuple topk_out(const at::Tensor & self, c10::SymInt k, int64_t dim, bool largest, bool sorted, at::Tensor & values, at::Tensor & indices); // {"schema": "aten::topk.values(Tensor self, SymInt k, int dim=-1, bool largest=True, bool sorted=True, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)", "dispatch": "True", "default": "False"} +::std::tuple topk(const at::Tensor & self, c10::SymInt k, int64_t dim, bool largest, bool sorted); // {"schema": "aten::topk(Tensor self, SymInt k, int dim=-1, bool largest=True, bool sorted=True) -> (Tensor values, Tensor indices)", "dispatch": "True", "default": "True"} +at::Tensor all(const at::Tensor & self); // {"schema": "aten::all(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & all_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::all.all_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor any(const at::Tensor & self); // {"schema": "aten::any(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & any_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::any.all_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & renorm_out(const at::Tensor & self, const at::Scalar & p, int64_t dim, const at::Scalar & maxnorm, at::Tensor & out); // {"schema": "aten::renorm.out(Tensor self, Scalar p, int dim, Scalar maxnorm, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor renorm(const at::Tensor & self, const at::Scalar & p, int64_t dim, const at::Scalar & maxnorm); // {"schema": "aten::renorm(Tensor self, Scalar p, int dim, Scalar maxnorm) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & renorm_(at::Tensor & self, const at::Scalar & p, int64_t dim, const at::Scalar & maxnorm); // {"schema": "aten::renorm_(Tensor(a!) self, Scalar p, int dim, Scalar maxnorm) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor unfold(const at::Tensor & self, int64_t dimension, int64_t size, int64_t step); // {"schema": "aten::unfold(Tensor(a) self, int dimension, int size, int step) -> Tensor(a)", "dispatch": "True", "default": "False"} +at::Tensor unfold_backward(const at::Tensor & grad_in, c10::SymIntArrayRef input_sizes, int64_t dim, int64_t size, int64_t step); // {"schema": "aten::unfold_backward(Tensor grad_in, SymInt[] input_sizes, int dim, int size, int step) -> Tensor", "dispatch": "True", "default": "False"} +bool equal(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::equal(Tensor self, Tensor other) -> bool", "dispatch": "True", "default": "False"} +at::Tensor & pow_out(const at::Tensor & self, const at::Tensor & exponent, at::Tensor & out); // {"schema": "aten::pow.Tensor_Tensor_out(Tensor self, Tensor exponent, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor pow(const at::Tensor & self, const at::Tensor & exponent); // {"schema": "aten::pow.Tensor_Tensor(Tensor self, Tensor exponent) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & pow_out(const at::Scalar & self, const at::Tensor & exponent, at::Tensor & out); // {"schema": "aten::pow.Scalar_out(Scalar self, Tensor exponent, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor pow(const at::Scalar & self, const at::Tensor & exponent); // {"schema": "aten::pow.Scalar(Scalar self, Tensor exponent) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & pow_out(const at::Tensor & self, const at::Scalar & exponent, at::Tensor & out); // {"schema": "aten::pow.Tensor_Scalar_out(Tensor self, Scalar exponent, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor pow(const at::Tensor & self, const at::Scalar & exponent); // {"schema": "aten::pow.Tensor_Scalar(Tensor self, Scalar exponent) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & pow_(at::Tensor & self, const at::Scalar & exponent); // {"schema": "aten::pow_.Scalar(Tensor(a!) self, Scalar exponent) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & pow_(at::Tensor & self, const at::Tensor & exponent); // {"schema": "aten::pow_.Tensor(Tensor(a!) self, Tensor exponent) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & float_power_out(const at::Tensor & self, const at::Tensor & exponent, at::Tensor & out); // {"schema": "aten::float_power.Tensor_Tensor_out(Tensor self, Tensor exponent, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor float_power(const at::Tensor & self, const at::Tensor & exponent); // {"schema": "aten::float_power.Tensor_Tensor(Tensor self, Tensor exponent) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & float_power_out(const at::Scalar & self, const at::Tensor & exponent, at::Tensor & out); // {"schema": "aten::float_power.Scalar_out(Scalar self, Tensor exponent, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor float_power(const at::Scalar & self, const at::Tensor & exponent); // {"schema": "aten::float_power.Scalar(Scalar self, Tensor exponent) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & float_power_out(const at::Tensor & self, const at::Scalar & exponent, at::Tensor & out); // {"schema": "aten::float_power.Tensor_Scalar_out(Tensor self, Scalar exponent, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor float_power(const at::Tensor & self, const at::Scalar & exponent); // {"schema": "aten::float_power.Tensor_Scalar(Tensor self, Scalar exponent) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & float_power_(at::Tensor & self, const at::Scalar & exponent); // {"schema": "aten::float_power_.Scalar(Tensor(a!) self, Scalar exponent) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & float_power_(at::Tensor & self, const at::Tensor & exponent); // {"schema": "aten::float_power_.Tensor(Tensor(a!) self, Tensor exponent) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & normal_(at::Tensor & self, double mean, double std, ::std::optional generator); // {"schema": "aten::normal_(Tensor(a!) self, float mean=0, float std=1, *, Generator? generator=None) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor normal_functional(const at::Tensor & self, double mean, double std, ::std::optional generator); // {"schema": "aten::normal_functional(Tensor self, float mean=0, float std=1, *, Generator? generator=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & normal_out(const at::Tensor & mean, double std, ::std::optional generator, at::Tensor & out); // {"schema": "aten::normal.Tensor_float_out(Tensor mean, float std=1, *, Generator? generator=None, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor normal(const at::Tensor & mean, double std, ::std::optional generator); // {"schema": "aten::normal.Tensor_float(Tensor mean, float std=1, *, Generator? generator=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & normal_out(double mean, const at::Tensor & std, ::std::optional generator, at::Tensor & out); // {"schema": "aten::normal.float_Tensor_out(float mean, Tensor std, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor normal(double mean, const at::Tensor & std, ::std::optional generator); // {"schema": "aten::normal.float_Tensor(float mean, Tensor std, *, Generator? generator=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & normal_out(const at::Tensor & mean, const at::Tensor & std, ::std::optional generator, at::Tensor & out); // {"schema": "aten::normal.Tensor_Tensor_out(Tensor mean, Tensor std, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor normal(const at::Tensor & mean, const at::Tensor & std, ::std::optional generator); // {"schema": "aten::normal.Tensor_Tensor(Tensor mean, Tensor std, *, Generator? generator=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor normal(double mean, double std, c10::SymIntArrayRef size, ::std::optional generator, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::normal.float_float(float mean, float std, SymInt[] size, *, Generator? generator=None, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & normal_out(double mean, double std, c10::SymIntArrayRef size, ::std::optional generator, at::Tensor & out); // {"schema": "aten::normal.float_float_out(float mean, float std, SymInt[] size, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor alias(const at::Tensor & self); // {"schema": "aten::alias(Tensor(a) self) -> Tensor(a)", "dispatch": "True", "default": "True"} +void _amp_foreach_non_finite_check_and_unscale_(at::TensorList self, at::Tensor & found_inf, const at::Tensor & inv_scale); // {"schema": "aten::_amp_foreach_non_finite_check_and_unscale_(Tensor(a!)[] self, Tensor(b!) found_inf, Tensor inv_scale) -> ()", "dispatch": "True", "default": "False"} +at::Tensor & _amp_update_scale_(at::Tensor & self, at::Tensor & growth_tracker, const at::Tensor & found_inf, double scale_growth_factor, double scale_backoff_factor, int64_t growth_interval); // {"schema": "aten::_amp_update_scale_(Tensor(a!) self, Tensor(b!) 
growth_tracker, Tensor found_inf, float scale_growth_factor, float scale_backoff_factor, int growth_interval) -> Tensor(a!)", "dispatch": "True", "default": "False"} +::std::vector _foreach_add(at::TensorList self, const at::Scalar & scalar); // {"schema": "aten::_foreach_add.Scalar(Tensor[] self, Scalar scalar) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_add_(at::TensorList self, const at::Scalar & scalar); // {"schema": "aten::_foreach_add_.Scalar(Tensor(a!)[] self, Scalar scalar) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_add(at::TensorList self, at::TensorList other, const at::Scalar & alpha); // {"schema": "aten::_foreach_add.List(Tensor[] self, Tensor[] other, *, Scalar alpha=1) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_add_(at::TensorList self, at::TensorList other, const at::Scalar & alpha); // {"schema": "aten::_foreach_add_.List(Tensor(a!)[] self, Tensor[] other, *, Scalar alpha=1) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_add(at::TensorList self, at::ArrayRef scalars); // {"schema": "aten::_foreach_add.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_add_(at::TensorList self, at::ArrayRef scalars); // {"schema": "aten::_foreach_add_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_add(at::TensorList self, const at::Tensor & other, const at::Scalar & alpha); // {"schema": "aten::_foreach_add.Tensor(Tensor[] self, Tensor other, *, Scalar alpha=1) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_add_(at::TensorList self, const at::Tensor & other, const at::Scalar & alpha); // {"schema": "aten::_foreach_add_.Tensor(Tensor(a!)[] self, Tensor other, *, Scalar alpha=1) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_sub(at::TensorList self, const at::Scalar & scalar); // {"schema": 
"aten::_foreach_sub.Scalar(Tensor[] self, Scalar scalar) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_sub_(at::TensorList self, const at::Scalar & scalar); // {"schema": "aten::_foreach_sub_.Scalar(Tensor(a!)[] self, Scalar scalar) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_sub(at::TensorList self, at::TensorList other, const at::Scalar & alpha); // {"schema": "aten::_foreach_sub.List(Tensor[] self, Tensor[] other, *, Scalar alpha=1) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_sub_(at::TensorList self, at::TensorList other, const at::Scalar & alpha); // {"schema": "aten::_foreach_sub_.List(Tensor(a!)[] self, Tensor[] other, *, Scalar alpha=1) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_sub(at::TensorList self, at::ArrayRef scalars); // {"schema": "aten::_foreach_sub.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_sub_(at::TensorList self, at::ArrayRef scalars); // {"schema": "aten::_foreach_sub_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_mul(at::TensorList self, const at::Scalar & scalar); // {"schema": "aten::_foreach_mul.Scalar(Tensor[] self, Scalar scalar) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_mul_(at::TensorList self, const at::Scalar & scalar); // {"schema": "aten::_foreach_mul_.Scalar(Tensor(a!)[] self, Scalar scalar) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_mul(at::TensorList self, at::TensorList other); // {"schema": "aten::_foreach_mul.List(Tensor[] self, Tensor[] other) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_mul_(at::TensorList self, at::TensorList other); // {"schema": "aten::_foreach_mul_.List(Tensor(a!)[] self, Tensor[] other) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_mul(at::TensorList self, 
at::ArrayRef scalars); // {"schema": "aten::_foreach_mul.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_mul_(at::TensorList self, at::ArrayRef scalars); // {"schema": "aten::_foreach_mul_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_mul(at::TensorList self, const at::Tensor & other); // {"schema": "aten::_foreach_mul.Tensor(Tensor[] self, Tensor other) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_mul_(at::TensorList self, const at::Tensor & other); // {"schema": "aten::_foreach_mul_.Tensor(Tensor(a!)[] self, Tensor other) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_div(at::TensorList self, const at::Scalar & scalar); // {"schema": "aten::_foreach_div.Scalar(Tensor[] self, Scalar scalar) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_div_(at::TensorList self, const at::Scalar & scalar); // {"schema": "aten::_foreach_div_.Scalar(Tensor(a!)[] self, Scalar scalar) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_div(at::TensorList self, at::TensorList other); // {"schema": "aten::_foreach_div.List(Tensor[] self, Tensor[] other) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_div_(at::TensorList self, at::TensorList other); // {"schema": "aten::_foreach_div_.List(Tensor(a!)[] self, Tensor[] other) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_div(at::TensorList self, at::ArrayRef scalars); // {"schema": "aten::_foreach_div.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_div_(at::TensorList self, at::ArrayRef scalars); // {"schema": "aten::_foreach_div_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_div(at::TensorList self, const at::Tensor & other); // {"schema": 
"aten::_foreach_div.Tensor(Tensor[] self, Tensor other) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_div_(at::TensorList self, const at::Tensor & other); // {"schema": "aten::_foreach_div_.Tensor(Tensor(a!)[] self, Tensor other) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_clamp_max(at::TensorList self, const at::Scalar & scalar); // {"schema": "aten::_foreach_clamp_max.Scalar(Tensor[] self, Scalar scalar) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_clamp_max_(at::TensorList self, const at::Scalar & scalar); // {"schema": "aten::_foreach_clamp_max_.Scalar(Tensor(a!)[] self, Scalar scalar) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_clamp_max(at::TensorList self, at::TensorList other); // {"schema": "aten::_foreach_clamp_max.List(Tensor[] self, Tensor[] other) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_clamp_max_(at::TensorList self, at::TensorList other); // {"schema": "aten::_foreach_clamp_max_.List(Tensor(a!)[] self, Tensor[] other) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_clamp_max(at::TensorList self, at::ArrayRef scalars); // {"schema": "aten::_foreach_clamp_max.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_clamp_max_(at::TensorList self, at::ArrayRef scalars); // {"schema": "aten::_foreach_clamp_max_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_clamp_min(at::TensorList self, const at::Scalar & scalar); // {"schema": "aten::_foreach_clamp_min.Scalar(Tensor[] self, Scalar scalar) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_clamp_min_(at::TensorList self, const at::Scalar & scalar); // {"schema": "aten::_foreach_clamp_min_.Scalar(Tensor(a!)[] self, Scalar scalar) -> ()", "dispatch": "True", "default": "True"} +::std::vector 
_foreach_clamp_min(at::TensorList self, at::TensorList other); // {"schema": "aten::_foreach_clamp_min.List(Tensor[] self, Tensor[] other) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_clamp_min_(at::TensorList self, at::TensorList other); // {"schema": "aten::_foreach_clamp_min_.List(Tensor(a!)[] self, Tensor[] other) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_clamp_min(at::TensorList self, at::ArrayRef scalars); // {"schema": "aten::_foreach_clamp_min.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_clamp_min_(at::TensorList self, at::ArrayRef scalars); // {"schema": "aten::_foreach_clamp_min_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_maximum(at::TensorList self, const at::Scalar & scalar); // {"schema": "aten::_foreach_maximum.Scalar(Tensor[] self, Scalar scalar) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_maximum_(at::TensorList self, const at::Scalar & scalar); // {"schema": "aten::_foreach_maximum_.Scalar(Tensor(a!)[] self, Scalar scalar) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_maximum(at::TensorList self, at::TensorList other); // {"schema": "aten::_foreach_maximum.List(Tensor[] self, Tensor[] other) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_maximum_(at::TensorList self, at::TensorList other); // {"schema": "aten::_foreach_maximum_.List(Tensor(a!)[] self, Tensor[] other) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_maximum(at::TensorList self, at::ArrayRef scalars); // {"schema": "aten::_foreach_maximum.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_maximum_(at::TensorList self, at::ArrayRef scalars); // {"schema": "aten::_foreach_maximum_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> ()", 
"dispatch": "True", "default": "True"} +::std::vector _foreach_minimum(at::TensorList self, const at::Scalar & scalar); // {"schema": "aten::_foreach_minimum.Scalar(Tensor[] self, Scalar scalar) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_minimum_(at::TensorList self, const at::Scalar & scalar); // {"schema": "aten::_foreach_minimum_.Scalar(Tensor(a!)[] self, Scalar scalar) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_minimum(at::TensorList self, at::TensorList other); // {"schema": "aten::_foreach_minimum.List(Tensor[] self, Tensor[] other) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_minimum_(at::TensorList self, at::TensorList other); // {"schema": "aten::_foreach_minimum_.List(Tensor(a!)[] self, Tensor[] other) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_minimum(at::TensorList self, at::ArrayRef scalars); // {"schema": "aten::_foreach_minimum.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_minimum_(at::TensorList self, at::ArrayRef scalars); // {"schema": "aten::_foreach_minimum_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_addcdiv(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Scalar & value); // {"schema": "aten::_foreach_addcdiv.Scalar(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> Tensor[]", "dispatch": "True", "default": "True"} +::std::vector _foreach_addcdiv(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, at::ArrayRef scalars); // {"schema": "aten::_foreach_addcdiv.ScalarList(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> Tensor[]", "dispatch": "True", "default": "True"} +::std::vector _foreach_addcdiv(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Tensor & scalars); // {"schema": 
"aten::_foreach_addcdiv.Tensor(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Tensor scalars) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_addcdiv_(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Scalar & value); // {"schema": "aten::_foreach_addcdiv_.Scalar(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> ()", "dispatch": "True", "default": "True"} +void _foreach_addcdiv_(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, at::ArrayRef scalars); // {"schema": "aten::_foreach_addcdiv_.ScalarList(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> ()", "dispatch": "True", "default": "True"} +void _foreach_addcdiv_(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Tensor & scalars); // {"schema": "aten::_foreach_addcdiv_.Tensor(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Tensor scalars) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_addcmul(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Scalar & value); // {"schema": "aten::_foreach_addcmul.Scalar(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> Tensor[]", "dispatch": "True", "default": "True"} +::std::vector _foreach_addcmul(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, at::ArrayRef scalars); // {"schema": "aten::_foreach_addcmul.ScalarList(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> Tensor[]", "dispatch": "True", "default": "True"} +::std::vector _foreach_addcmul(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Tensor & scalars); // {"schema": "aten::_foreach_addcmul.Tensor(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Tensor scalars) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_addcmul_(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Scalar 
& value); // {"schema": "aten::_foreach_addcmul_.Scalar(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> ()", "dispatch": "True", "default": "True"} +void _foreach_addcmul_(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, at::ArrayRef scalars); // {"schema": "aten::_foreach_addcmul_.ScalarList(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> ()", "dispatch": "True", "default": "True"} +void _foreach_addcmul_(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Tensor & scalars); // {"schema": "aten::_foreach_addcmul_.Tensor(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Tensor scalars) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_abs(at::TensorList self); // {"schema": "aten::_foreach_abs(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_abs_(at::TensorList self); // {"schema": "aten::_foreach_abs_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_acos(at::TensorList self); // {"schema": "aten::_foreach_acos(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_acos_(at::TensorList self); // {"schema": "aten::_foreach_acos_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_asin(at::TensorList self); // {"schema": "aten::_foreach_asin(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_asin_(at::TensorList self); // {"schema": "aten::_foreach_asin_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_atan(at::TensorList self); // {"schema": "aten::_foreach_atan(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_atan_(at::TensorList self); // {"schema": "aten::_foreach_atan_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_ceil(at::TensorList self); // 
{"schema": "aten::_foreach_ceil(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_ceil_(at::TensorList self); // {"schema": "aten::_foreach_ceil_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_cos(at::TensorList self); // {"schema": "aten::_foreach_cos(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_cos_(at::TensorList self); // {"schema": "aten::_foreach_cos_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_cosh(at::TensorList self); // {"schema": "aten::_foreach_cosh(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_cosh_(at::TensorList self); // {"schema": "aten::_foreach_cosh_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_erf(at::TensorList self); // {"schema": "aten::_foreach_erf(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_erf_(at::TensorList self); // {"schema": "aten::_foreach_erf_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_erfc(at::TensorList self); // {"schema": "aten::_foreach_erfc(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_erfc_(at::TensorList self); // {"schema": "aten::_foreach_erfc_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_exp(at::TensorList self); // {"schema": "aten::_foreach_exp(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_exp_(at::TensorList self); // {"schema": "aten::_foreach_exp_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_expm1(at::TensorList self); // {"schema": "aten::_foreach_expm1(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_expm1_(at::TensorList self); // {"schema": "aten::_foreach_expm1_(Tensor(a!)[] self) -> ()", 
"dispatch": "True", "default": "True"} +::std::vector _foreach_floor(at::TensorList self); // {"schema": "aten::_foreach_floor(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_floor_(at::TensorList self); // {"schema": "aten::_foreach_floor_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_frac(at::TensorList self); // {"schema": "aten::_foreach_frac(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_frac_(at::TensorList self); // {"schema": "aten::_foreach_frac_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_lerp(at::TensorList self, at::TensorList tensors1, at::TensorList weights); // {"schema": "aten::_foreach_lerp.List(Tensor[] self, Tensor[] tensors1, Tensor[] weights) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_lerp_(at::TensorList self, at::TensorList tensors1, at::TensorList weights); // {"schema": "aten::_foreach_lerp_.List(Tensor(a!)[] self, Tensor[] tensors1, Tensor[] weights) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_lerp(at::TensorList self, at::TensorList tensors1, const at::Scalar & weight); // {"schema": "aten::_foreach_lerp.Scalar(Tensor[] self, Tensor[] tensors1, Scalar weight) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_lerp_(at::TensorList self, at::TensorList tensors1, const at::Scalar & weight); // {"schema": "aten::_foreach_lerp_.Scalar(Tensor(a!)[] self, Tensor[] tensors1, Scalar weight) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_lerp(at::TensorList self, at::TensorList tensors1, at::ArrayRef weight); // {"schema": "aten::_foreach_lerp.ScalarList(Tensor[] self, Tensor[] tensors1, Scalar[] weight) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_lerp_(at::TensorList self, at::TensorList tensors1, at::ArrayRef weight); // {"schema": 
"aten::_foreach_lerp_.ScalarList(Tensor(a!)[] self, Tensor[] tensors1, Scalar[] weight) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_lgamma(at::TensorList self); // {"schema": "aten::_foreach_lgamma(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_lgamma_(at::TensorList self); // {"schema": "aten::_foreach_lgamma_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_log(at::TensorList self); // {"schema": "aten::_foreach_log(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_log_(at::TensorList self); // {"schema": "aten::_foreach_log_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_log10(at::TensorList self); // {"schema": "aten::_foreach_log10(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_log10_(at::TensorList self); // {"schema": "aten::_foreach_log10_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_log1p(at::TensorList self); // {"schema": "aten::_foreach_log1p(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_log1p_(at::TensorList self); // {"schema": "aten::_foreach_log1p_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_log2(at::TensorList self); // {"schema": "aten::_foreach_log2(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_log2_(at::TensorList self); // {"schema": "aten::_foreach_log2_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_max(at::TensorList self); // {"schema": "aten::_foreach_max(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "True"} +::std::vector _foreach_neg(at::TensorList self); // {"schema": "aten::_foreach_neg(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_neg_(at::TensorList self); // 
{"schema": "aten::_foreach_neg_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_norm(at::TensorList self, const at::Scalar & ord, ::std::optional dtype); // {"schema": "aten::_foreach_norm.Scalar(Tensor[] self, Scalar ord=2, ScalarType? dtype=None) -> Tensor[]", "dispatch": "True", "default": "True"} +::std::vector _foreach_pow(at::TensorList self, at::TensorList exponent); // {"schema": "aten::_foreach_pow.List(Tensor[] self, Tensor[] exponent) -> Tensor[]", "dispatch": "True", "default": "True"} +::std::vector _foreach_pow(at::TensorList self, const at::Scalar & exponent); // {"schema": "aten::_foreach_pow.Scalar(Tensor[] self, Scalar exponent) -> Tensor[]", "dispatch": "True", "default": "True"} +::std::vector _foreach_pow(at::TensorList self, at::ArrayRef exponent); // {"schema": "aten::_foreach_pow.ScalarList(Tensor[] self, Scalar[] exponent) -> Tensor[]", "dispatch": "True", "default": "True"} +::std::vector _foreach_pow(const at::Scalar & self, at::TensorList exponent); // {"schema": "aten::_foreach_pow.ScalarAndTensor(Scalar self, Tensor[] exponent) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_pow_(at::TensorList self, at::TensorList exponent); // {"schema": "aten::_foreach_pow_.List(Tensor(a!)[] self, Tensor[] exponent) -> ()", "dispatch": "True", "default": "True"} +void _foreach_pow_(at::TensorList self, const at::Scalar & exponent); // {"schema": "aten::_foreach_pow_.Scalar(Tensor(a!)[] self, Scalar exponent) -> ()", "dispatch": "True", "default": "True"} +void _foreach_pow_(at::TensorList self, at::ArrayRef exponent); // {"schema": "aten::_foreach_pow_.ScalarList(Tensor(a!)[] self, Scalar[] exponent) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_reciprocal(at::TensorList self); // {"schema": "aten::_foreach_reciprocal(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_reciprocal_(at::TensorList self); // {"schema": 
"aten::_foreach_reciprocal_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_round(at::TensorList self); // {"schema": "aten::_foreach_round(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_round_(at::TensorList self); // {"schema": "aten::_foreach_round_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_rsqrt(at::TensorList self); // {"schema": "aten::_foreach_rsqrt(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_rsqrt_(at::TensorList self); // {"schema": "aten::_foreach_rsqrt_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_sigmoid(at::TensorList self); // {"schema": "aten::_foreach_sigmoid(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_sigmoid_(at::TensorList self); // {"schema": "aten::_foreach_sigmoid_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_sign(at::TensorList self); // {"schema": "aten::_foreach_sign(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_sign_(at::TensorList self); // {"schema": "aten::_foreach_sign_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_sin(at::TensorList self); // {"schema": "aten::_foreach_sin(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_sin_(at::TensorList self); // {"schema": "aten::_foreach_sin_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_sinh(at::TensorList self); // {"schema": "aten::_foreach_sinh(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_sinh_(at::TensorList self); // {"schema": "aten::_foreach_sinh_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_sqrt(at::TensorList self); // {"schema": "aten::_foreach_sqrt(Tensor[] 
self) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_sqrt_(at::TensorList self); // {"schema": "aten::_foreach_sqrt_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_tan(at::TensorList self); // {"schema": "aten::_foreach_tan(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_tan_(at::TensorList self); // {"schema": "aten::_foreach_tan_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_tanh(at::TensorList self); // {"schema": "aten::_foreach_tanh(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_tanh_(at::TensorList self); // {"schema": "aten::_foreach_tanh_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_trunc(at::TensorList self); // {"schema": "aten::_foreach_trunc(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_trunc_(at::TensorList self); // {"schema": "aten::_foreach_trunc_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "True"} +void _foreach_zero_(at::TensorList self); // {"schema": "aten::_foreach_zero_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "True"} +void _foreach_copy_(at::TensorList self, at::TensorList src, bool non_blocking); // {"schema": "aten::_foreach_copy_(Tensor(a!)[] self, Tensor[] src, bool non_blocking=False) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_copy(at::TensorList self, at::TensorList src, bool non_blocking); // {"schema": "aten::_foreach_copy(Tensor[] self, Tensor[] src, bool non_blocking=False) -> Tensor[] self_out", "dispatch": "True", "default": "True"} +at::Tensor bucketize(const at::Tensor & self, const at::Tensor & boundaries, bool out_int32, bool right); // {"schema": "aten::bucketize.Tensor(Tensor self, Tensor boundaries, *, bool out_int32=False, bool right=False) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & 
bucketize_out(const at::Tensor & self, const at::Tensor & boundaries, bool out_int32, bool right, at::Tensor & out); // {"schema": "aten::bucketize.Tensor_out(Tensor self, Tensor boundaries, *, bool out_int32=False, bool right=False, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor bucketize(const at::Scalar & self, const at::Tensor & boundaries, bool out_int32, bool right); // {"schema": "aten::bucketize.Scalar(Scalar self, Tensor boundaries, *, bool out_int32=False, bool right=False) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor searchsorted(const at::Tensor & sorted_sequence, const at::Tensor & self, bool out_int32, bool right, ::std::optional side, const ::std::optional & sorter); // {"schema": "aten::searchsorted.Tensor(Tensor sorted_sequence, Tensor self, *, bool out_int32=False, bool right=False, str? side=None, Tensor? sorter=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & searchsorted_out(const at::Tensor & sorted_sequence, const at::Tensor & self, bool out_int32, bool right, ::std::optional side, const ::std::optional & sorter, at::Tensor & out); // {"schema": "aten::searchsorted.Tensor_out(Tensor sorted_sequence, Tensor self, *, bool out_int32=False, bool right=False, str? side=None, Tensor? sorter=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor searchsorted(const at::Tensor & sorted_sequence, const at::Scalar & self, bool out_int32, bool right, ::std::optional side, const ::std::optional & sorter); // {"schema": "aten::searchsorted.Scalar(Tensor sorted_sequence, Scalar self, *, bool out_int32=False, bool right=False, str? side=None, Tensor? 
sorter=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & searchsorted_out(const at::Tensor & sorted_sequence, const at::Scalar & self, bool out_int32, bool right, ::std::optional side, const ::std::optional & sorter, at::Tensor & out); // {"schema": "aten::searchsorted.Scalar_out(Tensor sorted_sequence, Scalar self, *, bool out_int32=False, bool right=False, str? side=None, Tensor? sorter=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor _convert_indices_from_coo_to_csr(const at::Tensor & self, int64_t size, bool out_int32); // {"schema": "aten::_convert_indices_from_coo_to_csr(Tensor self, int size, *, bool out_int32=False) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & _convert_indices_from_coo_to_csr_out(const at::Tensor & self, int64_t size, bool out_int32, at::Tensor & out); // {"schema": "aten::_convert_indices_from_coo_to_csr.out(Tensor self, int size, *, bool out_int32=False, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor _convert_indices_from_csr_to_coo(const at::Tensor & crow_indices, const at::Tensor & col_indices, bool out_int32, bool transpose); // {"schema": "aten::_convert_indices_from_csr_to_coo(Tensor crow_indices, Tensor col_indices, *, bool out_int32=False, bool transpose=False) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & _convert_indices_from_csr_to_coo_out(const at::Tensor & crow_indices, const at::Tensor & col_indices, bool out_int32, bool transpose, at::Tensor & out); // {"schema": "aten::_convert_indices_from_csr_to_coo.out(Tensor crow_indices, Tensor col_indices, *, bool out_int32=False, bool transpose=False, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & mse_loss_out(const at::Tensor & self, const at::Tensor & target, int64_t reduction, at::Tensor & out); // {"schema": "aten::mse_loss.out(Tensor self, Tensor target, int reduction=Mean, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor mse_loss(const at::Tensor & self, const at::Tensor & target, int64_t reduction); // {"schema": "aten::mse_loss(Tensor self, Tensor target, int reduction=Mean) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & mse_loss_backward_out(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, int64_t reduction, at::Tensor & grad_input); // {"schema": "aten::mse_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, int reduction, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor mse_loss_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, int64_t reduction); // {"schema": "aten::mse_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor l1_loss(const at::Tensor & self, const at::Tensor & target, int64_t reduction); // {"schema": "aten::l1_loss(Tensor self, Tensor target, int reduction=Mean) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & multi_margin_loss_out(const at::Tensor & self, const at::Tensor & target, const at::Scalar & p, const at::Scalar & margin, const ::std::optional & weight, int64_t reduction, at::Tensor & out); // {"schema": "aten::multi_margin_loss.out(Tensor self, Tensor target, Scalar p=1, Scalar margin=1, Tensor? weight=None, int reduction=Mean, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor multi_margin_loss(const at::Tensor & self, const at::Tensor & target, const at::Scalar & p, const at::Scalar & margin, const ::std::optional & weight, int64_t reduction); // {"schema": "aten::multi_margin_loss(Tensor self, Tensor target, Scalar p=1, Scalar margin=1, Tensor? 
weight=None, int reduction=Mean) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & multi_margin_loss_backward_out(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const at::Scalar & p, const at::Scalar & margin, const ::std::optional & weight, int64_t reduction, at::Tensor & grad_input); // {"schema": "aten::multi_margin_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, Scalar p, Scalar margin, Tensor? weight=None, int reduction=Mean, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor multi_margin_loss_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const at::Scalar & p, const at::Scalar & margin, const ::std::optional & weight, int64_t reduction); // {"schema": "aten::multi_margin_loss_backward(Tensor grad_output, Tensor self, Tensor target, Scalar p, Scalar margin, Tensor? weight=None, int reduction=Mean) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & multilabel_margin_loss_out(const at::Tensor & self, const at::Tensor & target, int64_t reduction, at::Tensor & out); // {"schema": "aten::multilabel_margin_loss.out(Tensor self, Tensor target, int reduction=Mean, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor multilabel_margin_loss(const at::Tensor & self, const at::Tensor & target, int64_t reduction); // {"schema": "aten::multilabel_margin_loss(Tensor self, Tensor target, int reduction=Mean) -> Tensor", "dispatch": "False", "default": "True"} +::std::tuple multilabel_margin_loss_forward_out(const at::Tensor & self, const at::Tensor & target, int64_t reduction, at::Tensor & output, at::Tensor & is_target); // {"schema": "aten::multilabel_margin_loss_forward.output(Tensor self, Tensor target, int reduction, *, Tensor(a!) output, Tensor(b!) 
is_target) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "False"} +::std::tuple multilabel_margin_loss_forward(const at::Tensor & self, const at::Tensor & target, int64_t reduction); // {"schema": "aten::multilabel_margin_loss_forward(Tensor self, Tensor target, int reduction) -> (Tensor output, Tensor is_target)", "dispatch": "True", "default": "False"} +at::Tensor & multilabel_margin_loss_backward_out(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, int64_t reduction, const at::Tensor & is_target, at::Tensor & grad_input); // {"schema": "aten::multilabel_margin_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, int reduction, Tensor is_target, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor multilabel_margin_loss_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, int64_t reduction, const at::Tensor & is_target); // {"schema": "aten::multilabel_margin_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction, Tensor is_target) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & nll_loss_out(const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, c10::SymInt ignore_index, at::Tensor & out); // {"schema": "aten::nll_loss.out(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor nll_loss_nd(const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, c10::SymInt ignore_index); // {"schema": "aten::nll_loss_nd(Tensor self, Tensor target, Tensor? 
weight=None, int reduction=Mean, SymInt ignore_index=-100) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor nll_loss(const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, c10::SymInt ignore_index); // {"schema": "aten::nll_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100) -> Tensor", "dispatch": "False", "default": "True"} +::std::tuple nll_loss_forward_out(const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, c10::SymInt ignore_index, at::Tensor & output, at::Tensor & total_weight); // {"schema": "aten::nll_loss_forward.output(Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, *, Tensor(a!) output, Tensor(b!) total_weight) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "False"} +::std::tuple nll_loss_forward(const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, c10::SymInt ignore_index); // {"schema": "aten::nll_loss_forward(Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index) -> (Tensor output, Tensor total_weight)", "dispatch": "True", "default": "True"} +at::Tensor & nll_loss_backward_out(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, c10::SymInt ignore_index, const at::Tensor & total_weight, at::Tensor & grad_input); // {"schema": "aten::nll_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, Tensor total_weight, *, Tensor(a!) 
grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor nll_loss_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, c10::SymInt ignore_index, const at::Tensor & total_weight); // {"schema": "aten::nll_loss_backward(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, Tensor total_weight) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & nll_loss2d_out(const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, c10::SymInt ignore_index, at::Tensor & out); // {"schema": "aten::nll_loss2d.out(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor nll_loss2d(const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, c10::SymInt ignore_index); // {"schema": "aten::nll_loss2d(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100) -> Tensor", "dispatch": "False", "default": "True"} +::std::tuple nll_loss2d_forward_out(const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, c10::SymInt ignore_index, at::Tensor & output, at::Tensor & total_weight); // {"schema": "aten::nll_loss2d_forward.output(Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, *, Tensor(a!) output, Tensor(b!) total_weight) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "False"} +::std::tuple nll_loss2d_forward(const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, c10::SymInt ignore_index); // {"schema": "aten::nll_loss2d_forward(Tensor self, Tensor target, Tensor? 
weight, int reduction, SymInt ignore_index) -> (Tensor output, Tensor total_weight)", "dispatch": "True", "default": "False"} +at::Tensor & nll_loss2d_backward_out(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, c10::SymInt ignore_index, const at::Tensor & total_weight, at::Tensor & grad_input); // {"schema": "aten::nll_loss2d_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, Tensor total_weight, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor nll_loss2d_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, c10::SymInt ignore_index, const at::Tensor & total_weight); // {"schema": "aten::nll_loss2d_backward(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, Tensor total_weight) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & smooth_l1_loss_out(const at::Tensor & self, const at::Tensor & target, int64_t reduction, double beta, at::Tensor & out); // {"schema": "aten::smooth_l1_loss.out(Tensor self, Tensor target, int reduction=Mean, float beta=1.0, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor smooth_l1_loss(const at::Tensor & self, const at::Tensor & target, int64_t reduction, double beta); // {"schema": "aten::smooth_l1_loss(Tensor self, Tensor target, int reduction=Mean, float beta=1.0) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & smooth_l1_loss_backward_out(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, int64_t reduction, double beta, at::Tensor & grad_input); // {"schema": "aten::smooth_l1_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, int reduction, float beta, *, Tensor(a!) 
grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor smooth_l1_loss_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, int64_t reduction, double beta); // {"schema": "aten::smooth_l1_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction, float beta) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & huber_loss_out(const at::Tensor & self, const at::Tensor & target, int64_t reduction, double delta, at::Tensor & out); // {"schema": "aten::huber_loss.out(Tensor self, Tensor target, int reduction=Mean, float delta=1.0, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor huber_loss(const at::Tensor & self, const at::Tensor & target, int64_t reduction, double delta); // {"schema": "aten::huber_loss(Tensor self, Tensor target, int reduction=Mean, float delta=1.0) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & huber_loss_backward_out(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, int64_t reduction, double delta, at::Tensor & grad_input); // {"schema": "aten::huber_loss_backward.out(Tensor grad_output, Tensor self, Tensor target, int reduction, float delta, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor huber_loss_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, int64_t reduction, double delta); // {"schema": "aten::huber_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction, float delta) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & soft_margin_loss_out(const at::Tensor & self, const at::Tensor & target, int64_t reduction, at::Tensor & out); // {"schema": "aten::soft_margin_loss.out(Tensor self, Tensor target, int reduction=Mean, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor soft_margin_loss(const at::Tensor & self, const at::Tensor & target, int64_t reduction); // {"schema": "aten::soft_margin_loss(Tensor self, Tensor target, int reduction=Mean) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & soft_margin_loss_backward_out(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, int64_t reduction, at::Tensor & grad_input); // {"schema": "aten::soft_margin_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, int reduction, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor soft_margin_loss_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, int64_t reduction); // {"schema": "aten::soft_margin_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & elu_out(const at::Tensor & self, const at::Scalar & alpha, const at::Scalar & scale, const at::Scalar & input_scale, at::Tensor & out); // {"schema": "aten::elu.out(Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor elu(const at::Tensor & self, const at::Scalar & alpha, const at::Scalar & scale, const at::Scalar & input_scale); // {"schema": "aten::elu(Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & elu_backward_out(const at::Tensor & grad_output, const at::Scalar & alpha, const at::Scalar & scale, const at::Scalar & input_scale, bool is_result, const at::Tensor & self_or_result, at::Tensor & grad_input); // {"schema": "aten::elu_backward.grad_input(Tensor grad_output, Scalar alpha, Scalar scale, Scalar input_scale, bool is_result, Tensor self_or_result, *, Tensor(a!) 
grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor elu_backward(const at::Tensor & grad_output, const at::Scalar & alpha, const at::Scalar & scale, const at::Scalar & input_scale, bool is_result, const at::Tensor & self_or_result); // {"schema": "aten::elu_backward(Tensor grad_output, Scalar alpha, Scalar scale, Scalar input_scale, bool is_result, Tensor self_or_result) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & elu_(at::Tensor & self, const at::Scalar & alpha, const at::Scalar & scale, const at::Scalar & input_scale); // {"schema": "aten::elu_(Tensor(a!) self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & glu_out(const at::Tensor & self, int64_t dim, at::Tensor & out); // {"schema": "aten::glu.out(Tensor self, int dim=-1, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor glu(const at::Tensor & self, int64_t dim); // {"schema": "aten::glu(Tensor self, int dim=-1) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & glu_backward_out(const at::Tensor & grad_output, const at::Tensor & self, int64_t dim, at::Tensor & grad_input); // {"schema": "aten::glu_backward.grad_input(Tensor grad_output, Tensor self, int dim, *, Tensor(a!) 
grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor glu_backward(const at::Tensor & grad_output, const at::Tensor & self, int64_t dim); // {"schema": "aten::glu_backward(Tensor grad_output, Tensor self, int dim) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor glu_jvp(const at::Tensor & glu, const at::Tensor & x, const at::Tensor & dx, int64_t dim); // {"schema": "aten::glu_jvp(Tensor glu, Tensor x, Tensor dx, int dim) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor glu_backward_jvp(const at::Tensor & grad_x, const at::Tensor & grad_glu, const at::Tensor & x, const at::Tensor & dgrad_glu, const at::Tensor & dx, int64_t dim); // {"schema": "aten::glu_backward_jvp(Tensor grad_x, Tensor grad_glu, Tensor x, Tensor dgrad_glu, Tensor dx, int dim) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & hardsigmoid_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::hardsigmoid.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor hardsigmoid(const at::Tensor & self); // {"schema": "aten::hardsigmoid(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & hardsigmoid_(at::Tensor & self); // {"schema": "aten::hardsigmoid_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & hardsigmoid_backward_out(const at::Tensor & grad_output, const at::Tensor & self, at::Tensor & grad_input); // {"schema": "aten::hardsigmoid_backward.grad_input(Tensor grad_output, Tensor self, *, Tensor(a!) 
grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor hardsigmoid_backward(const at::Tensor & grad_output, const at::Tensor & self); // {"schema": "aten::hardsigmoid_backward(Tensor grad_output, Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & hardtanh_out(const at::Tensor & self, const at::Scalar & min_val, const at::Scalar & max_val, at::Tensor & out); // {"schema": "aten::hardtanh.out(Tensor self, Scalar min_val=-1, Scalar max_val=1, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor hardtanh(const at::Tensor & self, const at::Scalar & min_val, const at::Scalar & max_val); // {"schema": "aten::hardtanh(Tensor self, Scalar min_val=-1, Scalar max_val=1) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & hardtanh_backward_out(const at::Tensor & grad_output, const at::Tensor & self, const at::Scalar & min_val, const at::Scalar & max_val, at::Tensor & grad_input); // {"schema": "aten::hardtanh_backward.grad_input(Tensor grad_output, Tensor self, Scalar min_val, Scalar max_val, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor hardtanh_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Scalar & min_val, const at::Scalar & max_val); // {"schema": "aten::hardtanh_backward(Tensor grad_output, Tensor self, Scalar min_val, Scalar max_val) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & hardtanh_(at::Tensor & self, const at::Scalar & min_val, const at::Scalar & max_val); // {"schema": "aten::hardtanh_(Tensor(a!) self, Scalar min_val=-1, Scalar max_val=1) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & hardswish_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::hardswish.out(Tensor self, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor hardswish(const at::Tensor & self); // {"schema": "aten::hardswish(Tensor self) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & hardswish_(at::Tensor & self); // {"schema": "aten::hardswish_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor hardswish_backward(const at::Tensor & grad_output, const at::Tensor & self); // {"schema": "aten::hardswish_backward(Tensor grad_output, Tensor self) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & leaky_relu_out(const at::Tensor & self, const at::Scalar & negative_slope, at::Tensor & out); // {"schema": "aten::leaky_relu.out(Tensor self, Scalar negative_slope=0.01, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor leaky_relu(const at::Tensor & self, const at::Scalar & negative_slope); // {"schema": "aten::leaky_relu(Tensor self, Scalar negative_slope=0.01) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & leaky_relu_backward_out(const at::Tensor & grad_output, const at::Tensor & self, const at::Scalar & negative_slope, bool self_is_result, at::Tensor & grad_input); // {"schema": "aten::leaky_relu_backward.grad_input(Tensor grad_output, Tensor self, Scalar negative_slope, bool self_is_result, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor leaky_relu_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Scalar & negative_slope, bool self_is_result); // {"schema": "aten::leaky_relu_backward(Tensor grad_output, Tensor self, Scalar negative_slope, bool self_is_result) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & leaky_relu_(at::Tensor & self, const at::Scalar & negative_slope); // {"schema": "aten::leaky_relu_(Tensor(a!) 
self, Scalar negative_slope=0.01) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & log_sigmoid_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::log_sigmoid.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor log_sigmoid(const at::Tensor & self); // {"schema": "aten::log_sigmoid(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +::std::tuple log_sigmoid_forward_out(const at::Tensor & self, at::Tensor & output, at::Tensor & buffer); // {"schema": "aten::log_sigmoid_forward.output(Tensor self, *, Tensor(a!) output, Tensor(b!) buffer) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "False"} +::std::tuple log_sigmoid_forward(const at::Tensor & self); // {"schema": "aten::log_sigmoid_forward(Tensor self) -> (Tensor output, Tensor buffer)", "dispatch": "True", "default": "False"} +at::Tensor & log_sigmoid_backward_out(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & buffer, at::Tensor & grad_input); // {"schema": "aten::log_sigmoid_backward.grad_input(Tensor grad_output, Tensor self, Tensor buffer, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor log_sigmoid_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & buffer); // {"schema": "aten::log_sigmoid_backward(Tensor grad_output, Tensor self, Tensor buffer) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & rrelu_with_noise_out(const at::Tensor & self, at::Tensor & noise, const at::Scalar & lower, const at::Scalar & upper, bool training, ::std::optional generator, at::Tensor & out); // {"schema": "aten::rrelu_with_noise.out(Tensor self, Tensor(b!) noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor rrelu_with_noise(const at::Tensor & self, at::Tensor & noise, const at::Scalar & lower, const at::Scalar & upper, bool training, ::std::optional generator); // {"schema": "aten::rrelu_with_noise(Tensor self, Tensor(b!) noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor rrelu_with_noise_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & noise, const at::Scalar & lower, const at::Scalar & upper, bool training, bool self_is_result); // {"schema": "aten::rrelu_with_noise_backward(Tensor grad_output, Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, bool self_is_result) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & rrelu_with_noise_(at::Tensor & self, at::Tensor & noise, const at::Scalar & lower, const at::Scalar & upper, bool training, ::std::optional generator); // {"schema": "aten::rrelu_with_noise_(Tensor(a!) self, Tensor(b!) noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & softplus_out(const at::Tensor & self, const at::Scalar & beta, const at::Scalar & threshold, at::Tensor & out); // {"schema": "aten::softplus.out(Tensor self, Scalar beta=1, Scalar threshold=20, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor softplus(const at::Tensor & self, const at::Scalar & beta, const at::Scalar & threshold); // {"schema": "aten::softplus(Tensor self, Scalar beta=1, Scalar threshold=20) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & softplus_backward_out(const at::Tensor & grad_output, const at::Tensor & self, const at::Scalar & beta, const at::Scalar & threshold, at::Tensor & grad_input); // {"schema": "aten::softplus_backward.grad_input(Tensor grad_output, Tensor self, Scalar beta, Scalar threshold, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor softplus_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Scalar & beta, const at::Scalar & threshold); // {"schema": "aten::softplus_backward(Tensor grad_output, Tensor self, Scalar beta, Scalar threshold) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & softshrink_out(const at::Tensor & self, const at::Scalar & lambd, at::Tensor & out); // {"schema": "aten::softshrink.out(Tensor self, Scalar lambd=0.5, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor softshrink(const at::Tensor & self, const at::Scalar & lambd); // {"schema": "aten::softshrink(Tensor self, Scalar lambd=0.5) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & softshrink_backward_out(const at::Tensor & grad_output, const at::Tensor & self, const at::Scalar & lambd, at::Tensor & grad_input); // {"schema": "aten::softshrink_backward.grad_input(Tensor grad_output, Tensor self, Scalar lambd, *, Tensor(a!) 
grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor softshrink_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Scalar & lambd); // {"schema": "aten::softshrink_backward(Tensor grad_output, Tensor self, Scalar lambd) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & adaptive_avg_pool2d_out(const at::Tensor & self, c10::SymIntArrayRef output_size, at::Tensor & out); // {"schema": "aten::adaptive_avg_pool2d.out(Tensor self, SymInt[2] output_size, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor adaptive_avg_pool2d(const at::Tensor & self, c10::SymIntArrayRef output_size); // {"schema": "aten::adaptive_avg_pool2d(Tensor self, SymInt[2] output_size) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor mkldnn_adaptive_avg_pool2d(const at::Tensor & self, at::IntArrayRef output_size); // {"schema": "aten::mkldnn_adaptive_avg_pool2d(Tensor self, int[2] output_size) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & mkldnn_adaptive_avg_pool2d_out(const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out); // {"schema": "aten::mkldnn_adaptive_avg_pool2d.out(Tensor self, int[2] output_size, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor mkldnn_adaptive_avg_pool2d_backward(const at::Tensor & grad_output, const at::Tensor & self); // {"schema": "aten::mkldnn_adaptive_avg_pool2d_backward(Tensor grad_output, Tensor self) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _adaptive_avg_pool2d(const at::Tensor & self, c10::SymIntArrayRef output_size); // {"schema": "aten::_adaptive_avg_pool2d(Tensor self, SymInt[2] output_size) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _adaptive_avg_pool2d_backward(const at::Tensor & grad_output, const at::Tensor & self); // {"schema": "aten::_adaptive_avg_pool2d_backward(Tensor grad_output, Tensor self) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & adaptive_avg_pool3d_out(const at::Tensor & self, c10::SymIntArrayRef output_size, at::Tensor & out); // {"schema": "aten::adaptive_avg_pool3d.out(Tensor self, SymInt[3] output_size, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor adaptive_avg_pool3d(const at::Tensor & self, c10::SymIntArrayRef output_size); // {"schema": "aten::adaptive_avg_pool3d(Tensor self, SymInt[3] output_size) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _adaptive_avg_pool3d(const at::Tensor & self, c10::SymIntArrayRef output_size); // {"schema": "aten::_adaptive_avg_pool3d(Tensor self, SymInt[3] output_size) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & adaptive_avg_pool3d_backward_out(const at::Tensor & grad_output, const at::Tensor & self, at::Tensor & grad_input); // {"schema": "aten::adaptive_avg_pool3d_backward.grad_input(Tensor grad_output, Tensor self, *, Tensor(a!) 
grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor _adaptive_avg_pool3d_backward(const at::Tensor & grad_output, const at::Tensor & self); // {"schema": "aten::_adaptive_avg_pool3d_backward(Tensor grad_output, Tensor self) -> Tensor", "dispatch": "True", "default": "False"} +::std::tuple adaptive_max_pool2d_out(const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out, at::Tensor & indices); // {"schema": "aten::adaptive_max_pool2d.out(Tensor self, int[2] output_size, *, Tensor(a!) out, Tensor(b!) indices) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "False"} +::std::tuple adaptive_max_pool2d(const at::Tensor & self, at::IntArrayRef output_size); // {"schema": "aten::adaptive_max_pool2d(Tensor self, int[2] output_size) -> (Tensor, Tensor)", "dispatch": "True", "default": "True"} +at::Tensor & adaptive_max_pool2d_backward_out(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices, at::Tensor & grad_input); // {"schema": "aten::adaptive_max_pool2d_backward.grad_input(Tensor grad_output, Tensor self, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor adaptive_max_pool2d_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices); // {"schema": "aten::adaptive_max_pool2d_backward(Tensor grad_output, Tensor self, Tensor indices) -> Tensor", "dispatch": "True", "default": "True"} +::std::tuple adaptive_max_pool3d_out(const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out, at::Tensor & indices); // {"schema": "aten::adaptive_max_pool3d.out(Tensor self, int[3] output_size, *, Tensor(a!) out, Tensor(b!) 
indices) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "False"} +::std::tuple adaptive_max_pool3d(const at::Tensor & self, at::IntArrayRef output_size); // {"schema": "aten::adaptive_max_pool3d(Tensor self, int[3] output_size) -> (Tensor, Tensor)", "dispatch": "True", "default": "True"} +at::Tensor & adaptive_max_pool3d_backward_out(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices, at::Tensor & grad_input); // {"schema": "aten::adaptive_max_pool3d_backward.grad_input(Tensor grad_output, Tensor self, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor adaptive_max_pool3d_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices); // {"schema": "aten::adaptive_max_pool3d_backward(Tensor grad_output, Tensor self, Tensor indices) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & avg_pool2d_out(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override, at::Tensor & out); // {"schema": "aten::avg_pool2d.out(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor avg_pool2d(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override); // {"schema": "aten::avg_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? 
divisor_override=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & avg_pool2d_backward_out(const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override, at::Tensor & grad_input); // {"schema": "aten::avg_pool2d_backward.grad_input(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride, int[2] padding, bool ceil_mode, bool count_include_pad, int? divisor_override, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor avg_pool2d_backward(const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override); // {"schema": "aten::avg_pool2d_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride, int[2] padding, bool ceil_mode, bool count_include_pad, int? divisor_override) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & avg_pool3d_out(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override, at::Tensor & out); // {"schema": "aten::avg_pool3d.out(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor avg_pool3d(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override); // {"schema": "aten::avg_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? 
divisor_override=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & avg_pool3d_backward_out(const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override, at::Tensor & grad_input); // {"schema": "aten::avg_pool3d_backward.grad_input(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] stride, int[3] padding, bool ceil_mode, bool count_include_pad, int? divisor_override, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor avg_pool3d_backward(const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override); // {"schema": "aten::avg_pool3d_backward(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] stride, int[3] padding, bool ceil_mode, bool count_include_pad, int? divisor_override) -> Tensor", "dispatch": "True", "default": "True"} +::std::tuple fractional_max_pool2d_out(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef output_size, const at::Tensor & random_samples, at::Tensor & output, at::Tensor & indices); // {"schema": "aten::fractional_max_pool2d.output(Tensor self, int[2] kernel_size, int[2] output_size, Tensor random_samples, *, Tensor(a!) output, Tensor(b!) 
indices) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "False"} +::std::tuple fractional_max_pool2d(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef output_size, const at::Tensor & random_samples); // {"schema": "aten::fractional_max_pool2d(Tensor self, int[2] kernel_size, int[2] output_size, Tensor random_samples) -> (Tensor, Tensor)", "dispatch": "True", "default": "True"} +at::Tensor & fractional_max_pool2d_backward_out(const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef output_size, const at::Tensor & indices, at::Tensor & grad_input); // {"schema": "aten::fractional_max_pool2d_backward.grad_input(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] output_size, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor fractional_max_pool2d_backward(const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef output_size, const at::Tensor & indices); // {"schema": "aten::fractional_max_pool2d_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] output_size, Tensor indices) -> Tensor", "dispatch": "True", "default": "True"} +::std::tuple fractional_max_pool3d_out(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef output_size, const at::Tensor & random_samples, at::Tensor & output, at::Tensor & indices); // {"schema": "aten::fractional_max_pool3d.output(Tensor self, int[3] kernel_size, int[3] output_size, Tensor random_samples, *, Tensor(a!) output, Tensor(b!) 
indices) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "False"} +::std::tuple fractional_max_pool3d(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef output_size, const at::Tensor & random_samples); // {"schema": "aten::fractional_max_pool3d(Tensor self, int[3] kernel_size, int[3] output_size, Tensor random_samples) -> (Tensor, Tensor)", "dispatch": "True", "default": "True"} +at::Tensor & fractional_max_pool3d_backward_out(const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef output_size, const at::Tensor & indices, at::Tensor & grad_input); // {"schema": "aten::fractional_max_pool3d_backward.grad_input(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] output_size, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor fractional_max_pool3d_backward(const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef output_size, const at::Tensor & indices); // {"schema": "aten::fractional_max_pool3d_backward(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] output_size, Tensor indices) -> Tensor", "dispatch": "True", "default": "False"} +::std::tuple max_pool2d_with_indices_out(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, at::Tensor & out, at::Tensor & indices); // {"schema": "aten::max_pool2d_with_indices.out(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False, *, Tensor(a!) out, Tensor(b!) 
indices) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "False"} +::std::tuple max_pool2d_with_indices(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode); // {"schema": "aten::max_pool2d_with_indices(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)", "dispatch": "True", "default": "True"} +at::Tensor & max_pool2d_with_indices_backward_out(const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, const at::Tensor & indices, at::Tensor & grad_input); // {"schema": "aten::max_pool2d_with_indices_backward.grad_input(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride, int[2] padding, int[2] dilation, bool ceil_mode, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor max_pool2d_with_indices_backward(const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, const at::Tensor & indices); // {"schema": "aten::max_pool2d_with_indices_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride, int[2] padding, int[2] dilation, bool ceil_mode, Tensor indices) -> Tensor", "dispatch": "True", "default": "True"} +::std::tuple max_pool3d_with_indices_out(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, at::Tensor & out, at::Tensor & indices); // {"schema": "aten::max_pool3d_with_indices.out(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False, *, Tensor(a!) out, Tensor(b!) 
indices) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "False"} +::std::tuple max_pool3d_with_indices(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode); // {"schema": "aten::max_pool3d_with_indices(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"} +at::Tensor & max_pool3d_with_indices_backward_out(const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, const at::Tensor & indices, at::Tensor & grad_input); // {"schema": "aten::max_pool3d_with_indices_backward.grad_input(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] stride, int[3] padding, int[3] dilation, bool ceil_mode, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor max_pool3d_with_indices_backward(const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, const at::Tensor & indices); // {"schema": "aten::max_pool3d_with_indices_backward(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] stride, int[3] padding, int[3] dilation, bool ceil_mode, Tensor indices) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & max_unpool2d_out(const at::Tensor & self, const at::Tensor & indices, c10::SymIntArrayRef output_size, at::Tensor & out); // {"schema": "aten::max_unpool2d.out(Tensor self, Tensor indices, SymInt[2] output_size, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor max_unpool2d(const at::Tensor & self, const at::Tensor & indices, c10::SymIntArrayRef output_size); // {"schema": "aten::max_unpool2d(Tensor self, Tensor indices, SymInt[2] output_size) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & max_unpool3d_out(const at::Tensor & self, const at::Tensor & indices, c10::SymIntArrayRef output_size, at::IntArrayRef stride, at::IntArrayRef padding, at::Tensor & out); // {"schema": "aten::max_unpool3d.out(Tensor self, Tensor indices, SymInt[3] output_size, int[3] stride, int[3] padding, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor max_unpool3d(const at::Tensor & self, const at::Tensor & indices, c10::SymIntArrayRef output_size, at::IntArrayRef stride, at::IntArrayRef padding); // {"schema": "aten::max_unpool3d(Tensor self, Tensor indices, SymInt[3] output_size, int[3] stride, int[3] padding) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & reflection_pad1d_out(const at::Tensor & self, c10::SymIntArrayRef padding, at::Tensor & out); // {"schema": "aten::reflection_pad1d.out(Tensor self, SymInt[2] padding, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor reflection_pad1d(const at::Tensor & self, c10::SymIntArrayRef padding); // {"schema": "aten::reflection_pad1d(Tensor self, SymInt[2] padding) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & reflection_pad1d_backward_out(const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding, at::Tensor & grad_input); // {"schema": "aten::reflection_pad1d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[2] padding, *, Tensor(a!) 
grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor reflection_pad1d_backward(const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding); // {"schema": "aten::reflection_pad1d_backward(Tensor grad_output, Tensor self, SymInt[2] padding) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & reflection_pad2d_out(const at::Tensor & self, c10::SymIntArrayRef padding, at::Tensor & out); // {"schema": "aten::reflection_pad2d.out(Tensor self, SymInt[4] padding, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor reflection_pad2d(const at::Tensor & self, c10::SymIntArrayRef padding); // {"schema": "aten::reflection_pad2d(Tensor self, SymInt[4] padding) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & reflection_pad2d_backward_out(const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding, at::Tensor & grad_input); // {"schema": "aten::reflection_pad2d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[4] padding, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor reflection_pad2d_backward(const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding); // {"schema": "aten::reflection_pad2d_backward(Tensor grad_output, Tensor self, SymInt[4] padding) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & reflection_pad3d_out(const at::Tensor & self, c10::SymIntArrayRef padding, at::Tensor & out); // {"schema": "aten::reflection_pad3d.out(Tensor self, SymInt[6] padding, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor reflection_pad3d(const at::Tensor & self, c10::SymIntArrayRef padding); // {"schema": "aten::reflection_pad3d(Tensor self, SymInt[6] padding) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & reflection_pad3d_backward_out(const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding, at::Tensor & grad_input); // {"schema": "aten::reflection_pad3d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[6] padding, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor reflection_pad3d_backward(const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding); // {"schema": "aten::reflection_pad3d_backward(Tensor grad_output, Tensor self, SymInt[6] padding) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & replication_pad1d_out(const at::Tensor & self, c10::SymIntArrayRef padding, at::Tensor & out); // {"schema": "aten::replication_pad1d.out(Tensor self, SymInt[2] padding, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor replication_pad1d(const at::Tensor & self, c10::SymIntArrayRef padding); // {"schema": "aten::replication_pad1d(Tensor self, SymInt[2] padding) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & replication_pad1d_backward_out(const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding, at::Tensor & grad_input); // {"schema": "aten::replication_pad1d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[2] padding, *, Tensor(a!) 
grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor replication_pad1d_backward(const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding); // {"schema": "aten::replication_pad1d_backward(Tensor grad_output, Tensor self, SymInt[2] padding) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & replication_pad2d_out(const at::Tensor & self, c10::SymIntArrayRef padding, at::Tensor & out); // {"schema": "aten::replication_pad2d.out(Tensor self, SymInt[4] padding, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor replication_pad2d(const at::Tensor & self, c10::SymIntArrayRef padding); // {"schema": "aten::replication_pad2d(Tensor self, SymInt[4] padding) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & replication_pad2d_backward_out(const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding, at::Tensor & grad_input); // {"schema": "aten::replication_pad2d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[4] padding, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor replication_pad2d_backward(const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding); // {"schema": "aten::replication_pad2d_backward(Tensor grad_output, Tensor self, SymInt[4] padding) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & replication_pad3d_out(const at::Tensor & self, c10::SymIntArrayRef padding, at::Tensor & out); // {"schema": "aten::replication_pad3d.out(Tensor self, SymInt[6] padding, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor replication_pad3d(const at::Tensor & self, c10::SymIntArrayRef padding); // {"schema": "aten::replication_pad3d(Tensor self, SymInt[6] padding) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & replication_pad3d_backward_out(const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding, at::Tensor & grad_input); // {"schema": "aten::replication_pad3d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[6] padding, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor replication_pad3d_backward(const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding); // {"schema": "aten::replication_pad3d_backward(Tensor grad_output, Tensor self, SymInt[6] padding) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _pad_circular(const at::Tensor & self, c10::SymIntArrayRef pad); // {"schema": "aten::_pad_circular(Tensor self, SymInt[] pad) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _pad_enum(const at::Tensor & self, c10::SymIntArrayRef pad, int64_t mode, ::std::optional value); // {"schema": "aten::_pad_enum(Tensor self, SymInt[] pad, int mode, float? value=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor pad(const at::Tensor & self, c10::SymIntArrayRef pad, c10::string_view mode, ::std::optional value); // {"schema": "aten::pad(Tensor self, SymInt[] pad, str mode=\"constant\", float? value=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor upsample_linear1d(const at::Tensor & input, at::OptionalSymIntArrayRef output_size, bool align_corners, ::std::optional> scale_factors); // {"schema": "aten::upsample_linear1d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? 
scale_factors) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor upsample_bilinear2d(const at::Tensor & input, at::OptionalSymIntArrayRef output_size, bool align_corners, ::std::optional> scale_factors); // {"schema": "aten::upsample_bilinear2d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _upsample_bilinear2d_aa(const at::Tensor & input, at::OptionalSymIntArrayRef output_size, bool align_corners, ::std::optional> scale_factors); // {"schema": "aten::_upsample_bilinear2d_aa.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor upsample_trilinear3d(const at::Tensor & input, at::OptionalSymIntArrayRef output_size, bool align_corners, ::std::optional> scale_factors); // {"schema": "aten::upsample_trilinear3d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor upsample_bicubic2d(const at::Tensor & input, at::OptionalSymIntArrayRef output_size, bool align_corners, ::std::optional> scale_factors); // {"schema": "aten::upsample_bicubic2d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _upsample_bicubic2d_aa(const at::Tensor & input, at::OptionalSymIntArrayRef output_size, bool align_corners, ::std::optional> scale_factors); // {"schema": "aten::_upsample_bicubic2d_aa.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor upsample_nearest1d(const at::Tensor & input, at::OptionalSymIntArrayRef output_size, ::std::optional> scale_factors); // {"schema": "aten::upsample_nearest1d.vec(Tensor input, SymInt[]? output_size, float[]? 
scale_factors) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _upsample_nearest_exact1d(const at::Tensor & input, at::OptionalSymIntArrayRef output_size, ::std::optional> scale_factors); // {"schema": "aten::_upsample_nearest_exact1d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor upsample_nearest2d(const at::Tensor & input, at::OptionalSymIntArrayRef output_size, ::std::optional> scale_factors); // {"schema": "aten::upsample_nearest2d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _upsample_nearest_exact2d(const at::Tensor & input, at::OptionalSymIntArrayRef output_size, ::std::optional> scale_factors); // {"schema": "aten::_upsample_nearest_exact2d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor upsample_nearest3d(const at::Tensor & input, at::OptionalSymIntArrayRef output_size, ::std::optional> scale_factors); // {"schema": "aten::upsample_nearest3d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _upsample_nearest_exact3d(const at::Tensor & input, at::OptionalSymIntArrayRef output_size, ::std::optional> scale_factors); // {"schema": "aten::_upsample_nearest_exact3d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & upsample_linear1d_out(const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, ::std::optional scales, at::Tensor & out); // {"schema": "aten::upsample_linear1d.out(Tensor self, SymInt[1] output_size, bool align_corners, float? scales=None, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor upsample_linear1d(const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, ::std::optional scales); // {"schema": "aten::upsample_linear1d(Tensor self, SymInt[1] output_size, bool align_corners, float? scales=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & upsample_linear1d_backward_out(const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, ::std::optional scales, at::Tensor & grad_input); // {"schema": "aten::upsample_linear1d_backward.grad_input(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, bool align_corners, float? scales=None, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor upsample_linear1d_backward(const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, ::std::optional scales); // {"schema": "aten::upsample_linear1d_backward(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, bool align_corners, float? scales=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & upsample_bilinear2d_out(const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & out); // {"schema": "aten::upsample_bilinear2d.out(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor upsample_bilinear2d(const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, ::std::optional scales_h, ::std::optional scales_w); // {"schema": "aten::upsample_bilinear2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? 
scales_w=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & upsample_bilinear2d_backward_out(const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & grad_input); // {"schema": "aten::upsample_bilinear2d_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor upsample_bilinear2d_backward(const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, ::std::optional scales_h, ::std::optional scales_w); // {"schema": "aten::upsample_bilinear2d_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & _upsample_bilinear2d_aa_out(const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & out); // {"schema": "aten::_upsample_bilinear2d_aa.out(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor _upsample_bilinear2d_aa(const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, ::std::optional scales_h, ::std::optional scales_w); // {"schema": "aten::_upsample_bilinear2d_aa(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? 
scales_w=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & _upsample_bilinear2d_aa_backward_out(const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & grad_input); // {"schema": "aten::_upsample_bilinear2d_aa_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor _upsample_bilinear2d_aa_backward(const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, ::std::optional scales_h, ::std::optional scales_w); // {"schema": "aten::_upsample_bilinear2d_aa_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & upsample_bicubic2d_out(const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & out); // {"schema": "aten::upsample_bicubic2d.out(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor upsample_bicubic2d(const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, ::std::optional scales_h, ::std::optional scales_w); // {"schema": "aten::upsample_bicubic2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? 
scales_w=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & upsample_bicubic2d_backward_out(const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & grad_input); // {"schema": "aten::upsample_bicubic2d_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor upsample_bicubic2d_backward(const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, ::std::optional scales_h, ::std::optional scales_w); // {"schema": "aten::upsample_bicubic2d_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & _upsample_bicubic2d_aa_out(const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & out); // {"schema": "aten::_upsample_bicubic2d_aa.out(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor _upsample_bicubic2d_aa(const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, ::std::optional scales_h, ::std::optional scales_w); // {"schema": "aten::_upsample_bicubic2d_aa(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? 
scales_w=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & _upsample_bicubic2d_aa_backward_out(const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & grad_input); // {"schema": "aten::_upsample_bicubic2d_aa_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor _upsample_bicubic2d_aa_backward(const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, ::std::optional scales_h, ::std::optional scales_w); // {"schema": "aten::_upsample_bicubic2d_aa_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & upsample_trilinear3d_out(const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, ::std::optional scales_d, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & out); // {"schema": "aten::upsample_trilinear3d.out(Tensor self, SymInt[3] output_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor upsample_trilinear3d(const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, ::std::optional scales_d, ::std::optional scales_h, ::std::optional scales_w); // {"schema": "aten::upsample_trilinear3d(Tensor self, SymInt[3] output_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? 
scales_w=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & upsample_trilinear3d_backward_out(const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, ::std::optional scales_d, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & grad_input); // {"schema": "aten::upsample_trilinear3d_backward.grad_input(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor upsample_trilinear3d_backward(const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, ::std::optional scales_d, ::std::optional scales_h, ::std::optional scales_w); // {"schema": "aten::upsample_trilinear3d_backward(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & upsample_nearest1d_out(const at::Tensor & self, c10::SymIntArrayRef output_size, ::std::optional scales, at::Tensor & out); // {"schema": "aten::upsample_nearest1d.out(Tensor self, SymInt[1] output_size, float? scales=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & _upsample_nearest_exact1d_out(const at::Tensor & self, c10::SymIntArrayRef output_size, ::std::optional scales, at::Tensor & out); // {"schema": "aten::_upsample_nearest_exact1d.out(Tensor self, SymInt[1] output_size, float? scales=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor upsample_nearest1d(const at::Tensor & self, c10::SymIntArrayRef output_size, ::std::optional scales); // {"schema": "aten::upsample_nearest1d(Tensor self, SymInt[1] output_size, float? 
scales=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor _upsample_nearest_exact1d(const at::Tensor & self, c10::SymIntArrayRef output_size, ::std::optional scales); // {"schema": "aten::_upsample_nearest_exact1d(Tensor self, SymInt[1] output_size, float? scales=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & upsample_nearest1d_backward_out(const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, ::std::optional scales, at::Tensor & grad_input); // {"schema": "aten::upsample_nearest1d_backward.grad_input(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, float? scales=None, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & _upsample_nearest_exact1d_backward_out(const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, ::std::optional scales, at::Tensor & grad_input); // {"schema": "aten::_upsample_nearest_exact1d_backward.grad_input(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, float? scales=None, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor upsample_nearest1d_backward(const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, ::std::optional scales); // {"schema": "aten::upsample_nearest1d_backward(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, float? scales=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor _upsample_nearest_exact1d_backward(const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, ::std::optional scales); // {"schema": "aten::_upsample_nearest_exact1d_backward(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, float? 
scales=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & upsample_nearest2d_out(const at::Tensor & self, c10::SymIntArrayRef output_size, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & out); // {"schema": "aten::upsample_nearest2d.out(Tensor self, SymInt[2] output_size, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & _upsample_nearest_exact2d_out(const at::Tensor & self, c10::SymIntArrayRef output_size, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & out); // {"schema": "aten::_upsample_nearest_exact2d.out(Tensor self, SymInt[2] output_size, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor upsample_nearest2d(const at::Tensor & self, c10::SymIntArrayRef output_size, ::std::optional scales_h, ::std::optional scales_w); // {"schema": "aten::upsample_nearest2d(Tensor self, SymInt[2] output_size, float? scales_h=None, float? scales_w=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor _upsample_nearest_exact2d(const at::Tensor & self, c10::SymIntArrayRef output_size, ::std::optional scales_h, ::std::optional scales_w); // {"schema": "aten::_upsample_nearest_exact2d(Tensor self, SymInt[2] output_size, float? scales_h=None, float? scales_w=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & upsample_nearest2d_backward_out(const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & grad_input); // {"schema": "aten::upsample_nearest2d_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, float? scales_h=None, float? scales_w=None, *, Tensor(a!) 
grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & _upsample_nearest_exact2d_backward_out(const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & grad_input); // {"schema": "aten::_upsample_nearest_exact2d_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor upsample_nearest2d_backward(const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, ::std::optional scales_h, ::std::optional scales_w); // {"schema": "aten::upsample_nearest2d_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, float? scales_h=None, float? scales_w=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor _upsample_nearest_exact2d_backward(const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, ::std::optional scales_h, ::std::optional scales_w); // {"schema": "aten::_upsample_nearest_exact2d_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, float? scales_h=None, float? scales_w=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & upsample_nearest3d_out(const at::Tensor & self, c10::SymIntArrayRef output_size, ::std::optional scales_d, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & out); // {"schema": "aten::upsample_nearest3d.out(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & _upsample_nearest_exact3d_out(const at::Tensor & self, c10::SymIntArrayRef output_size, ::std::optional scales_d, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & out); // {"schema": "aten::_upsample_nearest_exact3d.out(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor upsample_nearest3d(const at::Tensor & self, c10::SymIntArrayRef output_size, ::std::optional scales_d, ::std::optional scales_h, ::std::optional scales_w); // {"schema": "aten::upsample_nearest3d(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor _upsample_nearest_exact3d(const at::Tensor & self, c10::SymIntArrayRef output_size, ::std::optional scales_d, ::std::optional scales_h, ::std::optional scales_w); // {"schema": "aten::_upsample_nearest_exact3d(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & upsample_nearest3d_backward_out(const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, ::std::optional scales_d, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & grad_input); // {"schema": "aten::upsample_nearest3d_backward.grad_input(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) 
grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & _upsample_nearest_exact3d_backward_out(const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, ::std::optional scales_d, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & grad_input); // {"schema": "aten::_upsample_nearest_exact3d_backward.grad_input(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor upsample_nearest3d_backward(const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, ::std::optional scales_d, ::std::optional scales_h, ::std::optional scales_w); // {"schema": "aten::upsample_nearest3d_backward(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor _upsample_nearest_exact3d_backward(const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, ::std::optional scales_d, ::std::optional scales_h, ::std::optional scales_w); // {"schema": "aten::_upsample_nearest_exact3d_backward(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & sigmoid_backward_out(const at::Tensor & grad_output, const at::Tensor & output, at::Tensor & grad_input); // {"schema": "aten::sigmoid_backward.grad_input(Tensor grad_output, Tensor output, *, Tensor(a!) 
grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor sigmoid_backward(const at::Tensor & grad_output, const at::Tensor & output); // {"schema": "aten::sigmoid_backward(Tensor grad_output, Tensor output) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & logit_backward_out(const at::Tensor & grad_output, const at::Tensor & self, ::std::optional eps, at::Tensor & grad_input); // {"schema": "aten::logit_backward.grad_input(Tensor grad_output, Tensor self, float? eps=None, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor logit_backward(const at::Tensor & grad_output, const at::Tensor & self, ::std::optional eps); // {"schema": "aten::logit_backward(Tensor grad_output, Tensor self, float? eps=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & tanh_backward_out(const at::Tensor & grad_output, const at::Tensor & output, at::Tensor & grad_input); // {"schema": "aten::tanh_backward.grad_input(Tensor grad_output, Tensor output, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor tanh_backward(const at::Tensor & grad_output, const at::Tensor & output); // {"schema": "aten::tanh_backward(Tensor grad_output, Tensor output) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & slow_conv_transpose2d_out(const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef dilation, at::Tensor & out); // {"schema": "aten::slow_conv_transpose2d.out(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] output_padding=0, SymInt[2] dilation=1, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor slow_conv_transpose2d(const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef dilation); // {"schema": "aten::slow_conv_transpose2d(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] output_padding=0, SymInt[2] dilation=1) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & slow_conv_transpose3d_out(const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef dilation, at::Tensor & out); // {"schema": "aten::slow_conv_transpose3d.out(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] output_padding=0, SymInt[3] dilation=1, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor slow_conv_transpose3d(const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef dilation); // {"schema": "aten::slow_conv_transpose3d(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? 
bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] output_padding=0, SymInt[3] dilation=1) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & thnn_conv2d_out(const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, at::Tensor & out); // {"schema": "aten::thnn_conv2d.out(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor thnn_conv2d(const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding); // {"schema": "aten::thnn_conv2d(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & _slow_conv2d_forward_out(const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, at::Tensor & output); // {"schema": "aten::_slow_conv2d_forward.output(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias, SymInt[2] stride, SymInt[2] padding, *, Tensor(a!) output) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor _slow_conv2d_forward(const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding); // {"schema": "aten::_slow_conv2d_forward(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? 
bias, SymInt[2] stride, SymInt[2] padding) -> Tensor", "dispatch": "True", "default": "False"} +::std::tuple _slow_conv2d_backward_out(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, at::Tensor & grad_input, at::Tensor & grad_weight, at::Tensor & grad_bias); // {"schema": "aten::_slow_conv2d_backward.grad_input(Tensor grad_output, Tensor self, Tensor weight, SymInt[2] kernel_size, SymInt[2] stride, SymInt[2] padding, *, Tensor(a!) grad_input, Tensor(b!) grad_weight, Tensor(c!) grad_bias) -> (Tensor(a!), Tensor(b!), Tensor(c!))", "dispatch": "True", "default": "False"} +::std::tuple _slow_conv2d_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, ::std::array output_mask); // {"schema": "aten::_slow_conv2d_backward.output_mask(Tensor grad_output, Tensor self, Tensor weight, SymInt[2] kernel_size, SymInt[2] stride, SymInt[2] padding, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias)", "dispatch": "True", "default": "False"} +at::Tensor & _conv_depthwise2d_out(const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, at::Tensor & out); // {"schema": "aten::_conv_depthwise2d.out(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias, SymInt[2] stride, SymInt[2] padding, SymInt[2] dilation, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor _conv_depthwise2d(const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation); // {"schema": "aten::_conv_depthwise2d(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias, SymInt[2] stride, SymInt[2] padding, SymInt[2] dilation) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor conv_depthwise3d(const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation); // {"schema": "aten::conv_depthwise3d(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias, SymInt[3] stride, SymInt[3] padding, SymInt[3] dilation) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & slow_conv3d_out(const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, at::Tensor & out); // {"schema": "aten::slow_conv3d.out(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor slow_conv3d(const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding); // {"schema": "aten::slow_conv3d(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? 
bias=None, SymInt[3] stride=1, SymInt[3] padding=0) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & slow_conv3d_forward_out(const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, at::Tensor & output); // {"schema": "aten::slow_conv3d_forward.output(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias, SymInt[3] stride, SymInt[3] padding, *, Tensor(a!) output) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor slow_conv3d_forward(const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding); // {"schema": "aten::slow_conv3d_forward(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias, SymInt[3] stride, SymInt[3] padding) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor slow_conv_dilated2d(const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation); // {"schema": "aten::slow_conv_dilated2d(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] dilation=1) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor slow_conv_dilated3d(const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation); // {"schema": "aten::slow_conv_dilated3d(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? 
bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] dilation=1) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & col2im_out(const at::Tensor & self, c10::SymIntArrayRef output_size, at::IntArrayRef kernel_size, at::IntArrayRef dilation, at::IntArrayRef padding, at::IntArrayRef stride, at::Tensor & out); // {"schema": "aten::col2im.out(Tensor self, SymInt[2] output_size, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor col2im(const at::Tensor & self, c10::SymIntArrayRef output_size, at::IntArrayRef kernel_size, at::IntArrayRef dilation, at::IntArrayRef padding, at::IntArrayRef stride); // {"schema": "aten::col2im(Tensor self, SymInt[2] output_size, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor column_stack(at::TensorList tensors); // {"schema": "aten::column_stack(Tensor[] tensors) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & column_stack_out(at::TensorList tensors, at::Tensor & out); // {"schema": "aten::column_stack.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & im2col_out(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef dilation, at::IntArrayRef padding, at::IntArrayRef stride, at::Tensor & out); // {"schema": "aten::im2col.out(Tensor self, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor im2col(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef dilation, at::IntArrayRef padding, at::IntArrayRef stride); // {"schema": "aten::im2col(Tensor self, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor isfinite(const at::Tensor & self); // {"schema": "aten::isfinite(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor isinf(const at::Tensor & self); // {"schema": "aten::isinf(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +void record_stream(at::Tensor & self, at::Stream s); // {"schema": "aten::record_stream(Tensor(a!) self, Stream s) -> ()", "dispatch": "True", "default": "False"} +at::Tensor isposinf(const at::Tensor & self); // {"schema": "aten::isposinf(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & isposinf_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::isposinf.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor isneginf(const at::Tensor & self); // {"schema": "aten::isneginf(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & isneginf_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::isneginf.out(Tensor self, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor _add_batch_dim(const at::Tensor & self, int64_t batch_dim, int64_t level); // {"schema": "aten::_add_batch_dim(Tensor self, int batch_dim, int level) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _remove_batch_dim(const at::Tensor & self, int64_t level, c10::SymInt batch_size, int64_t out_dim); // {"schema": "aten::_remove_batch_dim(Tensor self, int level, SymInt batch_size, int out_dim) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor special_entr(const at::Tensor & self); // {"schema": "aten::special_entr(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & special_entr_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::special_entr.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor special_ndtri(const at::Tensor & self); // {"schema": "aten::special_ndtri(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & special_ndtri_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::special_ndtri.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor special_log_ndtr(const at::Tensor & self); // {"schema": "aten::special_log_ndtr(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & special_log_ndtr_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::special_log_ndtr.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor special_expm1(const at::Tensor & self); // {"schema": "aten::special_expm1(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & special_expm1_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::special_expm1.out(Tensor self, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor special_exp2(const at::Tensor & self); // {"schema": "aten::special_exp2(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & special_exp2_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::special_exp2.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor special_psi(const at::Tensor & self); // {"schema": "aten::special_psi(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & special_psi_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::special_psi.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor special_digamma(const at::Tensor & self); // {"schema": "aten::special_digamma(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & special_digamma_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::special_digamma.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor special_gammaln(const at::Tensor & self); // {"schema": "aten::special_gammaln(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & special_gammaln_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::special_gammaln.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor special_erf(const at::Tensor & self); // {"schema": "aten::special_erf(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & special_erf_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::special_erf.out(Tensor self, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor special_erfc(const at::Tensor & self); // {"schema": "aten::special_erfc(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & special_erfc_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::special_erfc.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor special_erfcx(const at::Tensor & self); // {"schema": "aten::special_erfcx(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & special_erfcx_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::special_erfcx.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor special_erfinv(const at::Tensor & self); // {"schema": "aten::special_erfinv(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & special_erfinv_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::special_erfinv.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor special_ndtr(const at::Tensor & self); // {"schema": "aten::special_ndtr(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & special_ndtr_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::special_ndtr.out(Tensor self, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor special_xlog1py(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::special_xlog1py(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor special_xlog1py(const at::Scalar & self, const at::Tensor & other); // {"schema": "aten::special_xlog1py.self_scalar(Scalar self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor special_xlog1py(const at::Tensor & self, const at::Scalar & other); // {"schema": "aten::special_xlog1py.other_scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & special_xlog1py_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::special_xlog1py.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & special_xlog1py_out(const at::Scalar & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::special_xlog1py.self_scalar_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & special_xlog1py_out(const at::Tensor & self, const at::Scalar & other, at::Tensor & out); // {"schema": "aten::special_xlog1py.other_scalar_out(Tensor self, Scalar other, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor special_xlogy(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::special_xlogy(Tensor self, Tensor other) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor special_xlogy(const at::Scalar & self, const at::Tensor & other); // {"schema": "aten::special_xlogy.self_scalar(Scalar self, Tensor other) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor special_xlogy(const at::Tensor & self, const at::Scalar & other); // {"schema": "aten::special_xlogy.other_scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & special_xlogy_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::special_xlogy.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & special_xlogy_out(const at::Scalar & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::special_xlogy.self_scalar_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & special_xlogy_out(const at::Tensor & self, const at::Scalar & other, at::Tensor & out); // {"schema": "aten::special_xlogy.other_scalar_out(Tensor self, Scalar other, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor special_zeta(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::special_zeta(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor special_zeta(const at::Scalar & self, const at::Tensor & other); // {"schema": "aten::special_zeta.self_scalar(Scalar self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor special_zeta(const at::Tensor & self, const at::Scalar & other); // {"schema": "aten::special_zeta.other_scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & special_zeta_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::special_zeta.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & special_zeta_out(const at::Scalar & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::special_zeta.self_scalar_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & special_zeta_out(const at::Tensor & self, const at::Scalar & other, at::Tensor & out); // {"schema": "aten::special_zeta.other_scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor special_i0(const at::Tensor & self); // {"schema": "aten::special_i0(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & special_i0_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::special_i0.out(Tensor self, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor special_i0e(const at::Tensor & self); // {"schema": "aten::special_i0e(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & special_i0e_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::special_i0e.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor special_i1(const at::Tensor & self); // {"schema": "aten::special_i1(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & special_i1_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::special_i1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor special_i1e(const at::Tensor & self); // {"schema": "aten::special_i1e(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & special_i1e_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::special_i1e.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor special_logit(const at::Tensor & self, ::std::optional eps); // {"schema": "aten::special_logit(Tensor self, float? eps=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & special_logit_out(const at::Tensor & self, ::std::optional eps, at::Tensor & out); // {"schema": "aten::special_logit.out(Tensor self, float? eps=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor special_polygamma(int64_t n, const at::Tensor & self); // {"schema": "aten::special_polygamma(int n, Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & special_polygamma_out(int64_t n, const at::Tensor & self, at::Tensor & out); // {"schema": "aten::special_polygamma.out(int n, Tensor self, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor special_logsumexp(const at::Tensor & self, at::IntArrayRef dim, bool keepdim); // {"schema": "aten::special_logsumexp(Tensor self, int[1] dim, bool keepdim=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & special_logsumexp_out(const at::Tensor & self, at::IntArrayRef dim, bool keepdim, at::Tensor & out); // {"schema": "aten::special_logsumexp.out(Tensor self, int[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor special_expit(const at::Tensor & self); // {"schema": "aten::special_expit(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & special_expit_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::special_expit.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor special_sinc(const at::Tensor & self); // {"schema": "aten::special_sinc(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & special_sinc_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::special_sinc.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor special_round(const at::Tensor & self, int64_t decimals); // {"schema": "aten::special_round(Tensor self, *, int decimals=0) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & special_round_out(const at::Tensor & self, int64_t decimals, at::Tensor & out); // {"schema": "aten::special_round.out(Tensor self, *, int decimals=0, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor special_log1p(const at::Tensor & self); // {"schema": "aten::special_log1p(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & special_log1p_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::special_log1p.out(Tensor self, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor special_log_softmax(const at::Tensor & self, int64_t dim, ::std::optional dtype); // {"schema": "aten::special_log_softmax(Tensor self, int dim, *, ScalarType? dtype=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & special_gammainc_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::special_gammainc.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor special_gammainc(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::special_gammainc(Tensor self, Tensor other) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & special_gammaincc_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::special_gammaincc.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor special_gammaincc(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::special_gammaincc(Tensor self, Tensor other) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor special_multigammaln(const at::Tensor & self, int64_t p); // {"schema": "aten::special_multigammaln(Tensor self, int p) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & special_multigammaln_out(const at::Tensor & self, int64_t p, at::Tensor & out); // {"schema": "aten::special_multigammaln.out(Tensor self, int p, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor special_softmax(const at::Tensor & self, int64_t dim, ::std::optional dtype); // {"schema": "aten::special_softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor fft_fft(const at::Tensor & self, ::std::optional n, int64_t dim, ::std::optional norm); // {"schema": "aten::fft_fft(Tensor self, SymInt? 
n=None, int dim=-1, str? norm=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & fft_fft_out(const at::Tensor & self, ::std::optional n, int64_t dim, ::std::optional norm, at::Tensor & out); // {"schema": "aten::fft_fft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor fft_ifft(const at::Tensor & self, ::std::optional n, int64_t dim, ::std::optional norm); // {"schema": "aten::fft_ifft(Tensor self, SymInt? n=None, int dim=-1, str? norm=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & fft_ifft_out(const at::Tensor & self, ::std::optional n, int64_t dim, ::std::optional norm, at::Tensor & out); // {"schema": "aten::fft_ifft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor fft_rfft(const at::Tensor & self, ::std::optional n, int64_t dim, ::std::optional norm); // {"schema": "aten::fft_rfft(Tensor self, SymInt? n=None, int dim=-1, str? norm=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & fft_rfft_out(const at::Tensor & self, ::std::optional n, int64_t dim, ::std::optional norm, at::Tensor & out); // {"schema": "aten::fft_rfft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor fft_irfft(const at::Tensor & self, ::std::optional n, int64_t dim, ::std::optional norm); // {"schema": "aten::fft_irfft(Tensor self, SymInt? n=None, int dim=-1, str? norm=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & fft_irfft_out(const at::Tensor & self, ::std::optional n, int64_t dim, ::std::optional norm, at::Tensor & out); // {"schema": "aten::fft_irfft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor fft_hfft(const at::Tensor & self, ::std::optional n, int64_t dim, ::std::optional norm); // {"schema": "aten::fft_hfft(Tensor self, SymInt? n=None, int dim=-1, str? norm=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & fft_hfft_out(const at::Tensor & self, ::std::optional n, int64_t dim, ::std::optional norm, at::Tensor & out); // {"schema": "aten::fft_hfft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor fft_ihfft(const at::Tensor & self, ::std::optional n, int64_t dim, ::std::optional norm); // {"schema": "aten::fft_ihfft(Tensor self, SymInt? n=None, int dim=-1, str? norm=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & fft_ihfft_out(const at::Tensor & self, ::std::optional n, int64_t dim, ::std::optional norm, at::Tensor & out); // {"schema": "aten::fft_ihfft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor fft_fft2(const at::Tensor & self, at::OptionalSymIntArrayRef s, at::IntArrayRef dim, ::std::optional norm); // {"schema": "aten::fft_fft2(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & fft_fft2_out(const at::Tensor & self, at::OptionalSymIntArrayRef s, at::IntArrayRef dim, ::std::optional norm, at::Tensor & out); // {"schema": "aten::fft_fft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor fft_ifft2(const at::Tensor & self, at::OptionalSymIntArrayRef s, at::IntArrayRef dim, ::std::optional norm); // {"schema": "aten::fft_ifft2(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? 
norm=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & fft_ifft2_out(const at::Tensor & self, at::OptionalSymIntArrayRef s, at::IntArrayRef dim, ::std::optional norm, at::Tensor & out); // {"schema": "aten::fft_ifft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor fft_rfft2(const at::Tensor & self, at::OptionalSymIntArrayRef s, at::IntArrayRef dim, ::std::optional norm); // {"schema": "aten::fft_rfft2(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & fft_rfft2_out(const at::Tensor & self, at::OptionalSymIntArrayRef s, at::IntArrayRef dim, ::std::optional norm, at::Tensor & out); // {"schema": "aten::fft_rfft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor fft_irfft2(const at::Tensor & self, at::OptionalSymIntArrayRef s, at::IntArrayRef dim, ::std::optional norm); // {"schema": "aten::fft_irfft2(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & fft_irfft2_out(const at::Tensor & self, at::OptionalSymIntArrayRef s, at::IntArrayRef dim, ::std::optional norm, at::Tensor & out); // {"schema": "aten::fft_irfft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor fft_hfft2(const at::Tensor & self, at::OptionalSymIntArrayRef s, at::IntArrayRef dim, ::std::optional norm); // {"schema": "aten::fft_hfft2(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? 
norm=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & fft_hfft2_out(const at::Tensor & self, at::OptionalSymIntArrayRef s, at::IntArrayRef dim, ::std::optional norm, at::Tensor & out); // {"schema": "aten::fft_hfft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor fft_ihfft2(const at::Tensor & self, at::OptionalSymIntArrayRef s, at::IntArrayRef dim, ::std::optional norm); // {"schema": "aten::fft_ihfft2(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & fft_ihfft2_out(const at::Tensor & self, at::OptionalSymIntArrayRef s, at::IntArrayRef dim, ::std::optional norm, at::Tensor & out); // {"schema": "aten::fft_ihfft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor fft_fftn(const at::Tensor & self, at::OptionalSymIntArrayRef s, at::OptionalIntArrayRef dim, ::std::optional norm); // {"schema": "aten::fft_fftn(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & fft_fftn_out(const at::Tensor & self, at::OptionalSymIntArrayRef s, at::OptionalIntArrayRef dim, ::std::optional norm, at::Tensor & out); // {"schema": "aten::fft_fftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor fft_ifftn(const at::Tensor & self, at::OptionalSymIntArrayRef s, at::OptionalIntArrayRef dim, ::std::optional norm); // {"schema": "aten::fft_ifftn(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? 
norm=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & fft_ifftn_out(const at::Tensor & self, at::OptionalSymIntArrayRef s, at::OptionalIntArrayRef dim, ::std::optional norm, at::Tensor & out); // {"schema": "aten::fft_ifftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor fft_rfftn(const at::Tensor & self, at::OptionalSymIntArrayRef s, at::OptionalIntArrayRef dim, ::std::optional norm); // {"schema": "aten::fft_rfftn(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & fft_rfftn_out(const at::Tensor & self, at::OptionalSymIntArrayRef s, at::OptionalIntArrayRef dim, ::std::optional norm, at::Tensor & out); // {"schema": "aten::fft_rfftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor fft_irfftn(const at::Tensor & self, at::OptionalSymIntArrayRef s, at::OptionalIntArrayRef dim, ::std::optional norm); // {"schema": "aten::fft_irfftn(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & fft_irfftn_out(const at::Tensor & self, at::OptionalSymIntArrayRef s, at::OptionalIntArrayRef dim, ::std::optional norm, at::Tensor & out); // {"schema": "aten::fft_irfftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor fft_hfftn(const at::Tensor & self, at::OptionalSymIntArrayRef s, at::OptionalIntArrayRef dim, ::std::optional norm); // {"schema": "aten::fft_hfftn(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? 
norm=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & fft_hfftn_out(const at::Tensor & self, at::OptionalSymIntArrayRef s, at::OptionalIntArrayRef dim, ::std::optional norm, at::Tensor & out); // {"schema": "aten::fft_hfftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor fft_ihfftn(const at::Tensor & self, at::OptionalSymIntArrayRef s, at::OptionalIntArrayRef dim, ::std::optional norm); // {"schema": "aten::fft_ihfftn(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & fft_ihfftn_out(const at::Tensor & self, at::OptionalSymIntArrayRef s, at::OptionalIntArrayRef dim, ::std::optional norm, at::Tensor & out); // {"schema": "aten::fft_ihfftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor fft_fftfreq(int64_t n, double d, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::fft_fftfreq(int n, float d=1.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & fft_fftfreq_out(int64_t n, double d, at::Tensor & out); // {"schema": "aten::fft_fftfreq.out(int n, float d=1.0, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor fft_rfftfreq(int64_t n, double d, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::fft_rfftfreq(int n, float d=1.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & fft_rfftfreq_out(int64_t n, double d, at::Tensor & out); // {"schema": "aten::fft_rfftfreq.out(int n, float d=1.0, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor fft_fftshift(const at::Tensor & self, at::OptionalIntArrayRef dim); // {"schema": "aten::fft_fftshift(Tensor self, int[1]? dim=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor fft_ifftshift(const at::Tensor & self, at::OptionalIntArrayRef dim); // {"schema": "aten::fft_ifftshift(Tensor self, int[1]? dim=None) -> Tensor", "dispatch": "False", "default": "True"} +::std::tuple linalg_cholesky_ex(const at::Tensor & self, bool upper, bool check_errors); // {"schema": "aten::linalg_cholesky_ex(Tensor self, *, bool upper=False, bool check_errors=False) -> (Tensor L, Tensor info)", "dispatch": "True", "default": "True"} +::std::tuple linalg_cholesky_ex_out(const at::Tensor & self, bool upper, bool check_errors, at::Tensor & L, at::Tensor & info); // {"schema": "aten::linalg_cholesky_ex.L(Tensor self, *, bool upper=False, bool check_errors=False, Tensor(a!) L, Tensor(b!) info) -> (Tensor(a!) L, Tensor(b!) info)", "dispatch": "True", "default": "False"} +at::Tensor linalg_cholesky(const at::Tensor & self, bool upper); // {"schema": "aten::linalg_cholesky(Tensor self, *, bool upper=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & linalg_cholesky_out(const at::Tensor & self, bool upper, at::Tensor & out); // {"schema": "aten::linalg_cholesky.out(Tensor self, *, bool upper=False, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor linalg_cross(const at::Tensor & self, const at::Tensor & other, int64_t dim); // {"schema": "aten::linalg_cross(Tensor self, Tensor other, *, int dim=-1) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & linalg_cross_out(const at::Tensor & self, const at::Tensor & other, int64_t dim, at::Tensor & out); // {"schema": "aten::linalg_cross.out(Tensor self, Tensor other, *, int dim=-1, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +::std::tuple linalg_lu_factor(const at::Tensor & A, bool pivot); // {"schema": "aten::linalg_lu_factor(Tensor A, *, bool pivot=True) -> (Tensor LU, Tensor pivots)", "dispatch": "False", "default": "True"} +::std::tuple linalg_lu_factor_out(const at::Tensor & A, bool pivot, at::Tensor & LU, at::Tensor & pivots); // {"schema": "aten::linalg_lu_factor.out(Tensor A, *, bool pivot=True, Tensor(a!) LU, Tensor(b!) pivots) -> (Tensor(a!) LU, Tensor(b!) pivots)", "dispatch": "False", "default": "True"} +::std::tuple linalg_lu_factor_ex(const at::Tensor & A, bool pivot, bool check_errors); // {"schema": "aten::linalg_lu_factor_ex(Tensor A, *, bool pivot=True, bool check_errors=False) -> (Tensor LU, Tensor pivots, Tensor info)", "dispatch": "True", "default": "True"} +::std::tuple linalg_lu_factor_ex_out(const at::Tensor & A, bool pivot, bool check_errors, at::Tensor & LU, at::Tensor & pivots, at::Tensor & info); // {"schema": "aten::linalg_lu_factor_ex.out(Tensor A, *, bool pivot=True, bool check_errors=False, Tensor(a!) LU, Tensor(b!) pivots, Tensor(c!) info) -> (Tensor(a!) LU, Tensor(b!) pivots, Tensor(c!) 
info)", "dispatch": "True", "default": "False"} +::std::tuple linalg_lu(const at::Tensor & A, bool pivot); // {"schema": "aten::linalg_lu(Tensor A, *, bool pivot=True) -> (Tensor P, Tensor L, Tensor U)", "dispatch": "True", "default": "True"} +::std::tuple linalg_lu_out(const at::Tensor & A, bool pivot, at::Tensor & P, at::Tensor & L, at::Tensor & U); // {"schema": "aten::linalg_lu.out(Tensor A, *, bool pivot=True, Tensor(a!) P, Tensor(b!) L, Tensor(c!) U) -> (Tensor(a!) P, Tensor(b!) L, Tensor(c!) U)", "dispatch": "True", "default": "False"} +at::Tensor linalg_lu_solve(const at::Tensor & LU, const at::Tensor & pivots, const at::Tensor & B, bool left, bool adjoint); // {"schema": "aten::linalg_lu_solve(Tensor LU, Tensor pivots, Tensor B, *, bool left=True, bool adjoint=False) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & linalg_lu_solve_out(const at::Tensor & LU, const at::Tensor & pivots, const at::Tensor & B, bool left, bool adjoint, at::Tensor & out); // {"schema": "aten::linalg_lu_solve.out(Tensor LU, Tensor pivots, Tensor B, *, bool left=True, bool adjoint=False, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +::std::tuple _linalg_det(const at::Tensor & A); // {"schema": "aten::_linalg_det(Tensor A) -> (Tensor result, Tensor LU, Tensor pivots)", "dispatch": "True", "default": "True"} +::std::tuple _linalg_det_out(const at::Tensor & A, at::Tensor & result, at::Tensor & LU, at::Tensor & pivots); // {"schema": "aten::_linalg_det.result(Tensor A, *, Tensor(a!) result, Tensor(b!) LU, Tensor(c!) pivots) -> (Tensor(a!) result, Tensor(b!) LU, Tensor(c!) pivots)", "dispatch": "True", "default": "False"} +at::Tensor linalg_det(const at::Tensor & A); // {"schema": "aten::linalg_det(Tensor A) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & linalg_det_out(const at::Tensor & A, at::Tensor & out); // {"schema": "aten::linalg_det.out(Tensor A, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor det(const at::Tensor & self); // {"schema": "aten::det(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +::std::tuple linalg_ldl_factor_ex(const at::Tensor & self, bool hermitian, bool check_errors); // {"schema": "aten::linalg_ldl_factor_ex(Tensor self, *, bool hermitian=False, bool check_errors=False) -> (Tensor LD, Tensor pivots, Tensor info)", "dispatch": "True", "default": "True"} +::std::tuple linalg_ldl_factor_ex_out(const at::Tensor & self, bool hermitian, bool check_errors, at::Tensor & LD, at::Tensor & pivots, at::Tensor & info); // {"schema": "aten::linalg_ldl_factor_ex.out(Tensor self, *, bool hermitian=False, bool check_errors=False, Tensor(a!) LD, Tensor(b!) pivots, Tensor(c!) info) -> (Tensor(a!) LD, Tensor(b!) pivots, Tensor(c!) info)", "dispatch": "True", "default": "False"} +::std::tuple linalg_ldl_factor(const at::Tensor & self, bool hermitian); // {"schema": "aten::linalg_ldl_factor(Tensor self, *, bool hermitian=False) -> (Tensor LD, Tensor pivots)", "dispatch": "False", "default": "True"} +::std::tuple linalg_ldl_factor_out(const at::Tensor & self, bool hermitian, at::Tensor & LD, at::Tensor & pivots); // {"schema": "aten::linalg_ldl_factor.out(Tensor self, *, bool hermitian=False, Tensor(a!) LD, Tensor(b!) pivots) -> (Tensor(a!) LD, Tensor(b!) pivots)", "dispatch": "False", "default": "True"} +at::Tensor linalg_ldl_solve(const at::Tensor & LD, const at::Tensor & pivots, const at::Tensor & B, bool hermitian); // {"schema": "aten::linalg_ldl_solve(Tensor LD, Tensor pivots, Tensor B, *, bool hermitian=False) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & linalg_ldl_solve_out(const at::Tensor & LD, const at::Tensor & pivots, const at::Tensor & B, bool hermitian, at::Tensor & out); // {"schema": "aten::linalg_ldl_solve.out(Tensor LD, Tensor pivots, Tensor B, *, bool hermitian=False, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +::std::tuple linalg_lstsq(const at::Tensor & self, const at::Tensor & b, ::std::optional rcond, ::std::optional driver); // {"schema": "aten::linalg_lstsq(Tensor self, Tensor b, float? rcond=None, *, str? driver=None) -> (Tensor solution, Tensor residuals, Tensor rank, Tensor singular_values)", "dispatch": "True", "default": "True"} +::std::tuple linalg_lstsq_out(const at::Tensor & self, const at::Tensor & b, ::std::optional rcond, ::std::optional driver, at::Tensor & solution, at::Tensor & residuals, at::Tensor & rank, at::Tensor & singular_values); // {"schema": "aten::linalg_lstsq.out(Tensor self, Tensor b, float? rcond=None, *, str? driver=None, Tensor(a!) solution, Tensor(b!) residuals, Tensor(c!) rank, Tensor(d!) singular_values) -> (Tensor(a!) solution, Tensor(b!) residuals, Tensor(c!) rank, Tensor(d!) singular_values)", "dispatch": "True", "default": "False"} +at::Tensor linalg_matmul(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::linalg_matmul(Tensor self, Tensor other) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & linalg_matmul_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::linalg_matmul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor linalg_vecdot(const at::Tensor & x, const at::Tensor & y, int64_t dim); // {"schema": "aten::linalg_vecdot(Tensor x, Tensor y, *, int dim=-1) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & linalg_vecdot_out(const at::Tensor & x, const at::Tensor & y, int64_t dim, at::Tensor & out); // {"schema": "aten::linalg_vecdot.out(Tensor x, Tensor y, *, int dim=-1, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor linalg_matrix_exp(const at::Tensor & self); // {"schema": "aten::linalg_matrix_exp(Tensor self) -> Tensor", "dispatch": "True", "default": "False"} +::std::tuple _linalg_slogdet(const at::Tensor & A); // {"schema": "aten::_linalg_slogdet(Tensor A) -> (Tensor sign, Tensor logabsdet, Tensor LU, Tensor pivots)", "dispatch": "True", "default": "True"} +::std::tuple _linalg_slogdet_out(const at::Tensor & A, at::Tensor & sign, at::Tensor & logabsdet, at::Tensor & LU, at::Tensor & pivots); // {"schema": "aten::_linalg_slogdet.sign(Tensor A, *, Tensor(a!) sign, Tensor(b!) logabsdet, Tensor(c!) LU, Tensor(d!) pivots) -> (Tensor(a!) sign, Tensor(b!) logabsdet, Tensor(c!) LU, Tensor(d!) pivots)", "dispatch": "True", "default": "False"} +::std::tuple linalg_slogdet(const at::Tensor & A); // {"schema": "aten::linalg_slogdet(Tensor A) -> (Tensor sign, Tensor logabsdet)", "dispatch": "False", "default": "True"} +::std::tuple linalg_slogdet_out(const at::Tensor & A, at::Tensor & sign, at::Tensor & logabsdet); // {"schema": "aten::linalg_slogdet.out(Tensor A, *, Tensor(a!) sign, Tensor(b!) logabsdet) -> (Tensor(a!) sign, Tensor(b!) logabsdet)", "dispatch": "False", "default": "True"} +::std::tuple slogdet(const at::Tensor & self); // {"schema": "aten::slogdet(Tensor self) -> (Tensor sign, Tensor logabsdet)", "dispatch": "False", "default": "True"} +::std::tuple slogdet_out(const at::Tensor & self, at::Tensor & sign, at::Tensor & logabsdet); // {"schema": "aten::slogdet.out(Tensor self, *, Tensor(a!) sign, Tensor(b!) logabsdet) -> (Tensor(a!) sign, Tensor(b!) 
logabsdet)", "dispatch": "False", "default": "True"} +at::Tensor logdet(const at::Tensor & self); // {"schema": "aten::logdet(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +::std::tuple linalg_eig(const at::Tensor & self); // {"schema": "aten::linalg_eig(Tensor self) -> (Tensor eigenvalues, Tensor eigenvectors)", "dispatch": "True", "default": "False"} +::std::tuple linalg_eig_out(const at::Tensor & self, at::Tensor & eigenvalues, at::Tensor & eigenvectors); // {"schema": "aten::linalg_eig.out(Tensor self, *, Tensor(a!) eigenvalues, Tensor(b!) eigenvectors) -> (Tensor(a!) eigenvalues, Tensor(b!) eigenvectors)", "dispatch": "True", "default": "False"} +at::Tensor _linalg_eigvals(const at::Tensor & self); // {"schema": "aten::_linalg_eigvals(Tensor self) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor linalg_eigvals(const at::Tensor & self); // {"schema": "aten::linalg_eigvals(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & linalg_eigvals_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::linalg_eigvals.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +::std::tuple _linalg_eigh(const at::Tensor & A, c10::string_view UPLO, bool compute_v); // {"schema": "aten::_linalg_eigh(Tensor A, str UPLO=\"L\", bool compute_v=True) -> (Tensor eigenvalues, Tensor eigenvectors)", "dispatch": "True", "default": "True"} +::std::tuple _linalg_eigh_out(const at::Tensor & A, c10::string_view UPLO, bool compute_v, at::Tensor & eigenvalues, at::Tensor & eigenvectors); // {"schema": "aten::_linalg_eigh.eigenvalues(Tensor A, str UPLO=\"L\", bool compute_v=True, *, Tensor(a!) eigenvalues, Tensor(b!) eigenvectors) -> (Tensor(a!) eigenvalues, Tensor(b!) 
eigenvectors)", "dispatch": "True", "default": "False"} +::std::tuple linalg_eigh(const at::Tensor & self, c10::string_view UPLO); // {"schema": "aten::linalg_eigh(Tensor self, str UPLO=\"L\") -> (Tensor eigenvalues, Tensor eigenvectors)", "dispatch": "False", "default": "True"} +::std::tuple linalg_eigh_out(const at::Tensor & self, c10::string_view UPLO, at::Tensor & eigvals, at::Tensor & eigvecs); // {"schema": "aten::linalg_eigh.eigvals(Tensor self, str UPLO=\"L\", *, Tensor(a!) eigvals, Tensor(b!) eigvecs) -> (Tensor(a!) eigenvalues, Tensor(b!) eigenvectors)", "dispatch": "False", "default": "True"} +at::Tensor linalg_eigvalsh(const at::Tensor & self, c10::string_view UPLO); // {"schema": "aten::linalg_eigvalsh(Tensor self, str UPLO=\"L\") -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & linalg_eigvalsh_out(const at::Tensor & self, c10::string_view UPLO, at::Tensor & out); // {"schema": "aten::linalg_eigvalsh.out(Tensor self, str UPLO=\"L\", *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor linalg_householder_product(const at::Tensor & input, const at::Tensor & tau); // {"schema": "aten::linalg_householder_product(Tensor input, Tensor tau) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & linalg_householder_product_out(const at::Tensor & input, const at::Tensor & tau, at::Tensor & out); // {"schema": "aten::linalg_householder_product.out(Tensor input, Tensor tau, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +::std::tuple linalg_inv_ex(const at::Tensor & A, bool check_errors); // {"schema": "aten::linalg_inv_ex(Tensor A, *, bool check_errors=False) -> (Tensor inverse, Tensor info)", "dispatch": "True", "default": "True"} +::std::tuple linalg_inv_ex_out(const at::Tensor & A, bool check_errors, at::Tensor & inverse, at::Tensor & info); // {"schema": "aten::linalg_inv_ex.inverse(Tensor A, *, bool check_errors=False, Tensor(a!) inverse, Tensor(b!) info) -> (Tensor(a!) 
inverse, Tensor(b!) info)", "dispatch": "True", "default": "False"} +at::Tensor linalg_inv(const at::Tensor & A); // {"schema": "aten::linalg_inv(Tensor A) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & linalg_inv_out(const at::Tensor & A, at::Tensor & out); // {"schema": "aten::linalg_inv.out(Tensor A, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor inverse(const at::Tensor & self); // {"schema": "aten::inverse(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & inverse_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::inverse.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor inner(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::inner(Tensor self, Tensor other) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & inner_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::inner.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor outer(const at::Tensor & self, const at::Tensor & vec2); // {"schema": "aten::outer(Tensor self, Tensor vec2) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & outer_out(const at::Tensor & self, const at::Tensor & vec2, at::Tensor & out); // {"schema": "aten::outer.out(Tensor self, Tensor vec2, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor ger(const at::Tensor & self, const at::Tensor & vec2); // {"schema": "aten::ger(Tensor self, Tensor vec2) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & ger_out(const at::Tensor & self, const at::Tensor & vec2, at::Tensor & out); // {"schema": "aten::ger.out(Tensor self, Tensor vec2, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor linalg_norm(const at::Tensor & self, const ::std::optional & ord, at::OptionalIntArrayRef dim, bool keepdim, ::std::optional dtype); // {"schema": "aten::linalg_norm(Tensor self, Scalar? ord=None, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor linalg_norm(const at::Tensor & self, c10::string_view ord, at::OptionalIntArrayRef dim, bool keepdim, ::std::optional dtype); // {"schema": "aten::linalg_norm.ord_str(Tensor self, str ord, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & linalg_norm_out(const at::Tensor & self, const ::std::optional & ord, at::OptionalIntArrayRef dim, bool keepdim, ::std::optional dtype, at::Tensor & out); // {"schema": "aten::linalg_norm.out(Tensor self, Scalar? ord=None, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & linalg_norm_out(const at::Tensor & self, c10::string_view ord, at::OptionalIntArrayRef dim, bool keepdim, ::std::optional dtype, at::Tensor & out); // {"schema": "aten::linalg_norm.ord_str_out(Tensor self, str ord, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor linalg_vector_norm(const at::Tensor & self, const at::Scalar & ord, at::OptionalIntArrayRef dim, bool keepdim, ::std::optional dtype); // {"schema": "aten::linalg_vector_norm(Tensor self, Scalar ord=2, int[1]? dim=None, bool keepdim=False, *, ScalarType? 
dtype=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & linalg_vector_norm_out(const at::Tensor & self, const at::Scalar & ord, at::OptionalIntArrayRef dim, bool keepdim, ::std::optional dtype, at::Tensor & out); // {"schema": "aten::linalg_vector_norm.out(Tensor self, Scalar ord=2, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor linalg_matrix_norm(const at::Tensor & self, const at::Scalar & ord, at::IntArrayRef dim, bool keepdim, ::std::optional dtype); // {"schema": "aten::linalg_matrix_norm(Tensor self, Scalar ord, int[] dim=[-2,-1], bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & linalg_matrix_norm_out(const at::Tensor & self, const at::Scalar & ord, at::IntArrayRef dim, bool keepdim, ::std::optional dtype, at::Tensor & out); // {"schema": "aten::linalg_matrix_norm.out(Tensor self, Scalar ord, int[] dim=[-2,-1], bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor linalg_matrix_norm(const at::Tensor & self, c10::string_view ord, at::IntArrayRef dim, bool keepdim, ::std::optional dtype); // {"schema": "aten::linalg_matrix_norm.str_ord(Tensor self, str ord='fro', int[] dim=[-2,-1], bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & linalg_matrix_norm_out(const at::Tensor & self, c10::string_view ord, at::IntArrayRef dim, bool keepdim, ::std::optional dtype, at::Tensor & out); // {"schema": "aten::linalg_matrix_norm.str_ord_out(Tensor self, str ord='fro', int[] dim=[-2,-1], bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +::std::tuple _linalg_svd(const at::Tensor & A, bool full_matrices, bool compute_uv, ::std::optional driver); // {"schema": "aten::_linalg_svd(Tensor A, bool full_matrices=False, bool compute_uv=True, *, str? driver=None) -> (Tensor U, Tensor S, Tensor Vh)", "dispatch": "True", "default": "True"} +::std::tuple _linalg_svd_out(const at::Tensor & A, bool full_matrices, bool compute_uv, ::std::optional driver, at::Tensor & U, at::Tensor & S, at::Tensor & Vh); // {"schema": "aten::_linalg_svd.U(Tensor A, bool full_matrices=False, bool compute_uv=True, *, str? driver=None, Tensor(a!) U, Tensor(b!) S, Tensor(c!) Vh) -> (Tensor(a!) U, Tensor(b!) S, Tensor(c!) Vh)", "dispatch": "True", "default": "False"} +::std::tuple linalg_svd(const at::Tensor & A, bool full_matrices, ::std::optional driver); // {"schema": "aten::linalg_svd(Tensor A, bool full_matrices=True, *, str? driver=None) -> (Tensor U, Tensor S, Tensor Vh)", "dispatch": "False", "default": "True"} +::std::tuple linalg_svd_out(const at::Tensor & A, bool full_matrices, ::std::optional driver, at::Tensor & U, at::Tensor & S, at::Tensor & Vh); // {"schema": "aten::linalg_svd.U(Tensor A, bool full_matrices=True, *, str? driver=None, Tensor(a!) U, Tensor(b!) S, Tensor(c!) Vh) -> (Tensor(a!) U, Tensor(b!) S, Tensor(c!) Vh)", "dispatch": "False", "default": "True"} +at::Tensor linalg_svdvals(const at::Tensor & A, ::std::optional driver); // {"schema": "aten::linalg_svdvals(Tensor A, *, str? driver=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & linalg_svdvals_out(const at::Tensor & A, ::std::optional driver, at::Tensor & out); // {"schema": "aten::linalg_svdvals.out(Tensor A, *, str? driver=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor linalg_cond(const at::Tensor & self, const ::std::optional & p); // {"schema": "aten::linalg_cond(Tensor self, Scalar? 
p=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & linalg_cond_out(const at::Tensor & self, const ::std::optional & p, at::Tensor & out); // {"schema": "aten::linalg_cond.out(Tensor self, Scalar? p=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor linalg_cond(const at::Tensor & self, c10::string_view p); // {"schema": "aten::linalg_cond.p_str(Tensor self, str p) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & linalg_cond_out(const at::Tensor & self, c10::string_view p, at::Tensor & out); // {"schema": "aten::linalg_cond.p_str_out(Tensor self, str p, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor linalg_pinv(const at::Tensor & self, const ::std::optional & atol, const ::std::optional & rtol, bool hermitian); // {"schema": "aten::linalg_pinv.atol_rtol_tensor(Tensor self, *, Tensor? atol=None, Tensor? rtol=None, bool hermitian=False) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & linalg_pinv_out(const at::Tensor & self, const ::std::optional & atol, const ::std::optional & rtol, bool hermitian, at::Tensor & out); // {"schema": "aten::linalg_pinv.atol_rtol_tensor_out(Tensor self, *, Tensor? atol=None, Tensor? rtol=None, bool hermitian=False, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor linalg_pinv(const at::Tensor & self, ::std::optional atol, ::std::optional rtol, bool hermitian); // {"schema": "aten::linalg_pinv.atol_rtol_float(Tensor self, *, float? atol=None, float? rtol=None, bool hermitian=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & linalg_pinv_out(const at::Tensor & self, ::std::optional atol, ::std::optional rtol, bool hermitian, at::Tensor & out); // {"schema": "aten::linalg_pinv.atol_rtol_float_out(Tensor self, *, float? atol=None, float? rtol=None, bool hermitian=False, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor linalg_pinv(const at::Tensor & self, double rcond, bool hermitian); // {"schema": "aten::linalg_pinv(Tensor self, float rcond, bool hermitian=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor linalg_pinv(const at::Tensor & self, const at::Tensor & rcond, bool hermitian); // {"schema": "aten::linalg_pinv.rcond_tensor(Tensor self, Tensor rcond, bool hermitian=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & linalg_pinv_out(const at::Tensor & self, double rcond, bool hermitian, at::Tensor & out); // {"schema": "aten::linalg_pinv.out(Tensor self, float rcond, bool hermitian=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & linalg_pinv_out(const at::Tensor & self, const at::Tensor & rcond, bool hermitian, at::Tensor & out); // {"schema": "aten::linalg_pinv.out_rcond_tensor(Tensor self, Tensor rcond, bool hermitian=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +::std::tuple _linalg_solve_ex(const at::Tensor & A, const at::Tensor & B, bool left, bool check_errors); // {"schema": "aten::_linalg_solve_ex(Tensor A, Tensor B, *, bool left=True, bool check_errors=False) -> (Tensor result, Tensor LU, Tensor pivots, Tensor info)", "dispatch": "True", "default": "True"} +::std::tuple _linalg_solve_ex_out(const at::Tensor & A, const at::Tensor & B, bool left, bool check_errors, at::Tensor & result, at::Tensor & LU, at::Tensor & pivots, at::Tensor & info); // {"schema": "aten::_linalg_solve_ex.result(Tensor A, Tensor B, *, bool left=True, bool check_errors=False, Tensor(a!) result, Tensor(b!) LU, Tensor(c!) pivots, Tensor(d!) info) -> (Tensor(a!) result, Tensor(b!) LU, Tensor(c!) pivots, Tensor(d!) 
info)", "dispatch": "True", "default": "False"} +::std::tuple linalg_solve_ex(const at::Tensor & A, const at::Tensor & B, bool left, bool check_errors); // {"schema": "aten::linalg_solve_ex(Tensor A, Tensor B, *, bool left=True, bool check_errors=False) -> (Tensor result, Tensor info)", "dispatch": "False", "default": "True"} +::std::tuple linalg_solve_ex_out(const at::Tensor & A, const at::Tensor & B, bool left, bool check_errors, at::Tensor & result, at::Tensor & info); // {"schema": "aten::linalg_solve_ex.out(Tensor A, Tensor B, *, bool left=True, bool check_errors=False, Tensor(a!) result, Tensor(b!) info) -> (Tensor(a!) result, Tensor(b!) info)", "dispatch": "False", "default": "True"} +at::Tensor linalg_solve(const at::Tensor & A, const at::Tensor & B, bool left); // {"schema": "aten::linalg_solve(Tensor A, Tensor B, *, bool left=True) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _spsolve(const at::Tensor & A, const at::Tensor & B, bool left); // {"schema": "aten::_spsolve(Tensor A, Tensor B, *, bool left=True) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & linalg_solve_out(const at::Tensor & A, const at::Tensor & B, bool left, at::Tensor & out); // {"schema": "aten::linalg_solve.out(Tensor A, Tensor B, *, bool left=True, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor linalg_tensorinv(const at::Tensor & self, int64_t ind); // {"schema": "aten::linalg_tensorinv(Tensor self, int ind=2) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & linalg_tensorinv_out(const at::Tensor & self, int64_t ind, at::Tensor & out); // {"schema": "aten::linalg_tensorinv.out(Tensor self, int ind=2, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor linalg_tensorsolve(const at::Tensor & self, const at::Tensor & other, at::OptionalIntArrayRef dims); // {"schema": "aten::linalg_tensorsolve(Tensor self, Tensor other, int[]? 
dims=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & linalg_tensorsolve_out(const at::Tensor & self, const at::Tensor & other, at::OptionalIntArrayRef dims, at::Tensor & out); // {"schema": "aten::linalg_tensorsolve.out(Tensor self, Tensor other, int[]? dims=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +::std::tuple linalg_qr(const at::Tensor & A, c10::string_view mode); // {"schema": "aten::linalg_qr(Tensor A, str mode='reduced') -> (Tensor Q, Tensor R)", "dispatch": "True", "default": "True"} +::std::tuple linalg_qr_out(const at::Tensor & A, c10::string_view mode, at::Tensor & Q, at::Tensor & R); // {"schema": "aten::linalg_qr.out(Tensor A, str mode='reduced', *, Tensor(a!) Q, Tensor(b!) R) -> (Tensor(a!) Q, Tensor(b!) R)", "dispatch": "True", "default": "False"} +at::Tensor linalg_matrix_power(const at::Tensor & self, int64_t n); // {"schema": "aten::linalg_matrix_power(Tensor self, int n) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & linalg_matrix_power_out(const at::Tensor & self, int64_t n, at::Tensor & out); // {"schema": "aten::linalg_matrix_power.out(Tensor self, int n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor linalg_matrix_rank(const at::Tensor & input, const ::std::optional & atol, const ::std::optional & rtol, bool hermitian); // {"schema": "aten::linalg_matrix_rank.atol_rtol_tensor(Tensor input, *, Tensor? atol=None, Tensor? rtol=None, bool hermitian=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & linalg_matrix_rank_out(const at::Tensor & input, const ::std::optional & atol, const ::std::optional & rtol, bool hermitian, at::Tensor & out); // {"schema": "aten::linalg_matrix_rank.atol_rtol_tensor_out(Tensor input, *, Tensor? atol=None, Tensor? rtol=None, bool hermitian=False, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor linalg_matrix_rank(const at::Tensor & self, ::std::optional atol, ::std::optional rtol, bool hermitian); // {"schema": "aten::linalg_matrix_rank.atol_rtol_float(Tensor self, *, float? atol=None, float? rtol=None, bool hermitian=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & linalg_matrix_rank_out(const at::Tensor & self, ::std::optional atol, ::std::optional rtol, bool hermitian, at::Tensor & out); // {"schema": "aten::linalg_matrix_rank.atol_rtol_float_out(Tensor self, *, float? atol=None, float? rtol=None, bool hermitian=False, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor linalg_matrix_rank(const at::Tensor & self, double tol, bool hermitian); // {"schema": "aten::linalg_matrix_rank(Tensor self, float tol, bool hermitian=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & linalg_matrix_rank_out(const at::Tensor & self, double tol, bool hermitian, at::Tensor & out); // {"schema": "aten::linalg_matrix_rank.out(Tensor self, float tol, bool hermitian=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor linalg_matrix_rank(const at::Tensor & input, const at::Tensor & tol, bool hermitian); // {"schema": "aten::linalg_matrix_rank.tol_tensor(Tensor input, Tensor tol, bool hermitian=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & linalg_matrix_rank_out(const at::Tensor & input, const at::Tensor & tol, bool hermitian, at::Tensor & out); // {"schema": "aten::linalg_matrix_rank.out_tol_tensor(Tensor input, Tensor tol, bool hermitian=False, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor linalg_multi_dot(at::TensorList tensors); // {"schema": "aten::linalg_multi_dot(Tensor[] tensors) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & linalg_multi_dot_out(at::TensorList tensors, at::Tensor & out); // {"schema": "aten::linalg_multi_dot.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor nested_to_padded_tensor(const at::Tensor & self, double padding, at::OptionalIntArrayRef output_size); // {"schema": "aten::nested_to_padded_tensor(Tensor self, float padding, int[]? output_size=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _test_serialization_subcmul(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha); // {"schema": "aten::_test_serialization_subcmul(Tensor self, Tensor other, Scalar alpha=1) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _test_parallel_materialize(const at::Tensor & self, int64_t num_parallel, bool skip_first); // {"schema": "aten::_test_parallel_materialize(Tensor self, int num_parallel, bool skip_first=False) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor _test_optional_intlist(const at::Tensor & values, at::OptionalIntArrayRef addends); // {"schema": "aten::_test_optional_intlist(Tensor values, int[]? addends) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _test_optional_filled_intlist(const at::Tensor & values, at::OptionalIntArrayRef addends); // {"schema": "aten::_test_optional_filled_intlist(Tensor values, int[2]? addends) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _test_optional_floatlist(const at::Tensor & values, ::std::optional> addends); // {"schema": "aten::_test_optional_floatlist(Tensor values, float[]? 
addends) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _test_string_default(const at::Tensor & dummy, c10::string_view a, c10::string_view b); // {"schema": "aten::_test_string_default(Tensor dummy, str a=\"\\\"'\\\\\", str b='\"\\'\\\\') -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _test_ambiguous_defaults(const at::Tensor & dummy, int64_t a, int64_t b); // {"schema": "aten::_test_ambiguous_defaults.a(Tensor dummy, int a=1, int b=1) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _test_ambiguous_defaults(const at::Tensor & dummy, int64_t a, c10::string_view b); // {"schema": "aten::_test_ambiguous_defaults.b(Tensor dummy, int a=2, str b=\"2\") -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _test_warn_in_autograd(const at::Tensor & self); // {"schema": "aten::_test_warn_in_autograd(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor _test_autograd_multiple_dispatch(const at::Tensor & self); // {"schema": "aten::_test_autograd_multiple_dispatch.fullcoverage(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor _test_autograd_multiple_dispatch(const at::Tensor & self, bool b); // {"schema": "aten::_test_autograd_multiple_dispatch.ntonly(Tensor self, bool b) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor _test_autograd_multiple_dispatch_view(const at::Tensor & self); // {"schema": "aten::_test_autograd_multiple_dispatch_view(Tensor(a) self) -> Tensor(a)", "dispatch": "True", "default": "True"} +at::Tensor _test_autograd_multiple_dispatch_view_copy(const at::Tensor & self); // {"schema": "aten::_test_autograd_multiple_dispatch_view_copy(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor segment_reduce(const at::Tensor & data, c10::string_view reduce, const ::std::optional & lengths, const ::std::optional & indices, const ::std::optional & offsets, int64_t axis, bool unsafe, const ::std::optional & initial); // 
{"schema": "aten::segment_reduce(Tensor data, str reduce, *, Tensor? lengths=None, Tensor? indices=None, Tensor? offsets=None, int axis=0, bool unsafe=False, Scalar? initial=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _segment_reduce_backward(const at::Tensor & grad, const at::Tensor & output, const at::Tensor & data, c10::string_view reduce, const ::std::optional & lengths, const ::std::optional & offsets, int64_t axis, const ::std::optional & initial); // {"schema": "aten::_segment_reduce_backward(Tensor grad, Tensor output, Tensor data, str reduce, *, Tensor? lengths=None, Tensor? offsets=None, int axis=0, Scalar? initial=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor pad_sequence(at::TensorList sequences, bool batch_first, double padding_value, c10::string_view padding_side); // {"schema": "aten::pad_sequence(Tensor[] sequences, bool batch_first=False, float padding_value=0.0, str padding_side=\"right\") -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor flatten_dense_tensors(at::TensorList tensors); // {"schema": "aten::flatten_dense_tensors(Tensor[] tensors) -> Tensor", "dispatch": "False", "default": "True"} +::std::vector unflatten_dense_tensors(const at::Tensor & flat, at::TensorList tensors); // {"schema": "aten::unflatten_dense_tensors(Tensor flat, Tensor[] tensors) -> Tensor[]", "dispatch": "False", "default": "True"} +at::Tensor _nested_tensor_from_tensor_list(at::TensorList list, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::_nested_tensor_from_tensor_list(Tensor[] list, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor _fw_primal_copy(const at::Tensor & self, int64_t level); // {"schema": "aten::_fw_primal_copy(Tensor self, int level) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor _make_dual_copy(const at::Tensor & primal, const at::Tensor & tangent, int64_t level); // {"schema": "aten::_make_dual_copy(Tensor primal, Tensor tangent, int level) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor view_as_real_copy(const at::Tensor & self); // {"schema": "aten::view_as_real_copy(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor view_as_complex_copy(const at::Tensor & self); // {"schema": "aten::view_as_complex_copy(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor _conj_copy(const at::Tensor & self); // {"schema": "aten::_conj_copy(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor _neg_view_copy(const at::Tensor & self); // {"schema": "aten::_neg_view_copy(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor as_strided_copy(const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset); // {"schema": "aten::as_strided_copy(Tensor self, SymInt[] size, SymInt[] stride, SymInt? 
storage_offset=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor _sparse_broadcast_to_copy(const at::Tensor & self, at::IntArrayRef size); // {"schema": "aten::_sparse_broadcast_to_copy(Tensor self, int[] size) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor diagonal_copy(const at::Tensor & self, int64_t offset, int64_t dim1, int64_t dim2); // {"schema": "aten::diagonal_copy(Tensor self, int offset=0, int dim1=0, int dim2=1) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor expand_copy(const at::Tensor & self, c10::SymIntArrayRef size, bool implicit); // {"schema": "aten::expand_copy(Tensor self, SymInt[] size, *, bool implicit=False) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor permute_copy(const at::Tensor & self, at::IntArrayRef dims); // {"schema": "aten::permute_copy(Tensor self, int[] dims) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor _reshape_alias_copy(const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride); // {"schema": "aten::_reshape_alias_copy(Tensor self, SymInt[] size, SymInt[] stride) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor select_copy(const at::Tensor & self, int64_t dim, c10::SymInt index); // {"schema": "aten::select_copy.int(Tensor self, int dim, SymInt index) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor detach_copy(const at::Tensor & self); // {"schema": "aten::detach_copy(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor slice_copy(const at::Tensor & self, int64_t dim, ::std::optional start, ::std::optional end, c10::SymInt step); // {"schema": "aten::slice_copy.Tensor(Tensor self, int dim=0, SymInt? start=None, SymInt? 
end=None, SymInt step=1) -> Tensor", "dispatch": "True", "default": "True"} +::std::vector split_copy(const at::Tensor & self, c10::SymInt split_size, int64_t dim); // {"schema": "aten::split_copy.Tensor(Tensor self, SymInt split_size, int dim=0) -> Tensor[]", "dispatch": "True", "default": "True"} +::std::vector split_with_sizes_copy(const at::Tensor & self, c10::SymIntArrayRef split_sizes, int64_t dim); // {"schema": "aten::split_with_sizes_copy(Tensor self, SymInt[] split_sizes, int dim=0) -> Tensor[]", "dispatch": "True", "default": "True"} +at::Tensor squeeze_copy(const at::Tensor & self); // {"schema": "aten::squeeze_copy(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor squeeze_copy(const at::Tensor & self, int64_t dim); // {"schema": "aten::squeeze_copy.dim(Tensor self, int dim) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor squeeze_copy(const at::Tensor & self, at::IntArrayRef dim); // {"schema": "aten::squeeze_copy.dims(Tensor self, int[] dim) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor t_copy(const at::Tensor & self); // {"schema": "aten::t_copy(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor transpose_copy(const at::Tensor & self, int64_t dim0, int64_t dim1); // {"schema": "aten::transpose_copy.int(Tensor self, int dim0, int dim1) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor unsqueeze_copy(const at::Tensor & self, int64_t dim); // {"schema": "aten::unsqueeze_copy(Tensor self, int dim) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor _indices_copy(const at::Tensor & self); // {"schema": "aten::_indices_copy(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor _values_copy(const at::Tensor & self); // {"schema": "aten::_values_copy(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor indices_copy(const at::Tensor & self); // {"schema": "aten::indices_copy(Tensor self) -> Tensor", "dispatch": 
"True", "default": "True"} +at::Tensor values_copy(const at::Tensor & self); // {"schema": "aten::values_copy(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor crow_indices_copy(const at::Tensor & self); // {"schema": "aten::crow_indices_copy(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor col_indices_copy(const at::Tensor & self); // {"schema": "aten::col_indices_copy(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor ccol_indices_copy(const at::Tensor & self); // {"schema": "aten::ccol_indices_copy(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor row_indices_copy(const at::Tensor & self); // {"schema": "aten::row_indices_copy(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +::std::vector unbind_copy(const at::Tensor & self, int64_t dim); // {"schema": "aten::unbind_copy.int(Tensor self, int dim=0) -> Tensor[]", "dispatch": "True", "default": "True"} +void unbind_copy_out(const at::Tensor & self, int64_t dim, at::TensorList out); // {"schema": "aten::unbind_copy.int_out(Tensor self, int dim=0, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void split_copy_out(const at::Tensor & self, c10::SymInt split_size, int64_t dim, at::TensorList out); // {"schema": "aten::split_copy.Tensor_out(Tensor self, SymInt split_size, int dim=0, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void split_with_sizes_copy_out(const at::Tensor & self, c10::SymIntArrayRef split_sizes, int64_t dim, at::TensorList out); // {"schema": "aten::split_with_sizes_copy.out(Tensor self, SymInt[] split_sizes, int dim=0, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +at::Tensor view_copy(const at::Tensor & self, c10::SymIntArrayRef size); // {"schema": "aten::view_copy(Tensor self, SymInt[] size) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor view_copy(const at::Tensor & self, at::ScalarType dtype); // 
{"schema": "aten::view_copy.dtype(Tensor self, ScalarType dtype) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor unfold_copy(const at::Tensor & self, int64_t dimension, int64_t size, int64_t step); // {"schema": "aten::unfold_copy(Tensor self, int dimension, int size, int step) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor alias_copy(const at::Tensor & self); // {"schema": "aten::alias_copy(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor to_padded_tensor(const at::Tensor & self, double padding, at::OptionalSymIntArrayRef output_size); // {"schema": "aten::to_padded_tensor(Tensor self, float padding, SymInt[]? output_size=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _jagged_to_padded_dense_forward(const at::Tensor & values, at::TensorList offsets, c10::SymIntArrayRef max_lengths, double padding_value); // {"schema": "aten::_jagged_to_padded_dense_forward(Tensor values, Tensor[] offsets, SymInt[] max_lengths, float padding_value=0.0) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _padded_dense_to_jagged_forward(const at::Tensor & dense, at::TensorList offsets, ::std::optional total_L); // {"schema": "aten::_padded_dense_to_jagged_forward(Tensor dense, Tensor[] offsets, SymInt? total_L=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _nested_from_padded_tensor(const at::Tensor & padded, const at::Tensor & offsets, const at::Tensor & dummy, int64_t ragged_idx, const ::std::optional & min_seqlen, const ::std::optional & max_seqlen, ::std::optional sum_S); // {"schema": "aten::_nested_from_padded_tensor(Tensor padded, Tensor offsets, Tensor dummy, int ragged_idx=1, Tensor? min_seqlen=None, Tensor? max_seqlen=None, SymInt? 
sum_S=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _nested_tensor_softmax_with_shape(const at::Tensor & self, const at::Tensor & query); // {"schema": "aten::_nested_tensor_softmax_with_shape(Tensor self, Tensor query) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _safe_softmax(const at::Tensor & self, int64_t dim, ::std::optional dtype); // {"schema": "aten::_safe_softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor _transformer_encoder_layer_fwd(const at::Tensor & src, int64_t embed_dim, int64_t num_heads, const at::Tensor & qkv_weight, const at::Tensor & qkv_bias, const at::Tensor & proj_weight, const at::Tensor & proj_bias, bool use_gelu, bool norm_first, double eps, const at::Tensor & norm_weight_1, const at::Tensor & norm_bias_1, const at::Tensor & norm_weight_2, const at::Tensor & norm_bias_2, const at::Tensor & ffn_weight_1, const at::Tensor & ffn_bias_1, const at::Tensor & ffn_weight_2, const at::Tensor & ffn_bias_2, const ::std::optional & mask, ::std::optional mask_type); // {"schema": "aten::_transformer_encoder_layer_fwd(Tensor src, int embed_dim, int num_heads, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, bool use_gelu, bool norm_first, float eps, Tensor norm_weight_1, Tensor norm_bias_1, Tensor norm_weight_2, Tensor norm_bias_2, Tensor ffn_weight_1, Tensor ffn_bias_1, Tensor ffn_weight_2, Tensor ffn_bias_2, Tensor? mask=None, int? 
mask_type=None) -> Tensor", "dispatch": "True", "default": "False"} +::std::tuple _native_multi_head_attention(const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, int64_t embed_dim, int64_t num_head, const at::Tensor & qkv_weight, const at::Tensor & qkv_bias, const at::Tensor & proj_weight, const at::Tensor & proj_bias, const ::std::optional & mask, bool need_weights, bool average_attn_weights, ::std::optional mask_type); // {"schema": "aten::_native_multi_head_attention(Tensor query, Tensor key, Tensor value, int embed_dim, int num_head, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, Tensor? mask=None, bool need_weights=True, bool average_attn_weights=True, int? mask_type=None) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"} +at::Tensor scaled_dot_product_attention(const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const ::std::optional & attn_mask, double dropout_p, bool is_causal, ::std::optional scale, bool enable_gqa); // {"schema": "aten::scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, *, float? scale=None, bool enable_gqa=False) -> Tensor", "dispatch": "False", "default": "True"} +int64_t _fused_sdp_choice(const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const ::std::optional & attn_mask, double dropout_p, bool is_causal, ::std::optional scale, bool enable_gqa); // {"schema": "aten::_fused_sdp_choice(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, *, float? 
scale=None, bool enable_gqa=False) -> int", "dispatch": "True", "default": "False"} +::std::tuple _scaled_dot_product_attention_math(const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const ::std::optional & attn_mask, double dropout_p, bool is_causal, const ::std::optional & dropout_mask, ::std::optional scale, bool enable_gqa); // {"schema": "aten::_scaled_dot_product_attention_math(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, Tensor? dropout_mask=None, *, float? scale=None, bool enable_gqa=False) -> (Tensor, Tensor)", "dispatch": "False", "default": "True"} +::std::tuple _scaled_dot_product_attention_math_for_mps(const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const ::std::optional & attn_mask, double dropout_p, bool is_causal, const ::std::optional & dropout_mask, ::std::optional scale); // {"schema": "aten::_scaled_dot_product_attention_math_for_mps(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, Tensor? dropout_mask=None, *, float? scale=None) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"} +::std::tuple _scaled_dot_product_flash_attention(const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, double dropout_p, bool is_causal, bool return_debug_mask, ::std::optional scale); // {"schema": "aten::_scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? 
scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor rng_state, Tensor unused, Tensor debug_attn_mask)", "dispatch": "True", "default": "False"} +::std::tuple _scaled_dot_product_flash_attention_for_cpu(const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, double dropout_p, bool is_causal, const ::std::optional & attn_mask, ::std::optional scale); // {"schema": "aten::_scaled_dot_product_flash_attention_for_cpu(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, *, Tensor? attn_mask=None, float? scale=None) -> (Tensor output, Tensor logsumexp)", "dispatch": "True", "default": "False"} +::std::tuple _scaled_dot_product_fused_attention_overrideable(const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const ::std::optional & attn_bias, double dropout_p, bool is_causal, bool return_debug_mask, ::std::optional scale); // {"schema": "aten::_scaled_dot_product_fused_attention_overrideable(Tensor query, Tensor key, Tensor value, Tensor? attn_bias=None, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? 
scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)", "dispatch": "True", "default": "True"} +::std::tuple _scaled_dot_product_flash_attention_backward(const at::Tensor & grad_out, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const at::Tensor & out, const at::Tensor & logsumexp, const at::Tensor & cum_seq_q, const at::Tensor & cum_seq_k, c10::SymInt max_q, c10::SymInt max_k, double dropout_p, bool is_causal, const at::Tensor & philox_seed, const at::Tensor & philox_offset, ::std::optional scale); // {"schema": "aten::_scaled_dot_product_flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, Tensor philox_seed, Tensor philox_offset, *, float? scale=None) -> (Tensor grad_query, Tensor grad_key, Tensor grad_value)", "dispatch": "True", "default": "False"} +::std::tuple _scaled_dot_product_flash_attention_for_cpu_backward(const at::Tensor & grad_out, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const at::Tensor & out, const at::Tensor & logsumexp, double dropout_p, bool is_causal, const ::std::optional & attn_mask, ::std::optional scale); // {"schema": "aten::_scaled_dot_product_flash_attention_for_cpu_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, float dropout_p, bool is_causal, *, Tensor? attn_mask=None, float? 
scale=None) -> (Tensor grad_query, Tensor grad_key, Tensor grad_value)", "dispatch": "True", "default": "False"} +::std::tuple _scaled_dot_product_fused_attention_overrideable_backward(const at::Tensor & grad_out, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const at::Tensor & attn_bias, ::std::array grad_input_mask, const at::Tensor & out, const at::Tensor & logsumexp, const at::Tensor & cum_seq_q, const at::Tensor & cum_seq_k, c10::SymInt max_q, c10::SymInt max_k, double dropout_p, bool is_causal, const at::Tensor & philox_seed, const at::Tensor & philox_offset, ::std::optional scale); // {"schema": "aten::_scaled_dot_product_fused_attention_overrideable_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor attn_bias, bool[4] grad_input_mask, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, Tensor philox_seed, Tensor philox_offset, *, float? scale=None) -> (Tensor grad_query, Tensor grad_key, Tensor grad_value, Tensor grad_attn_bias)", "dispatch": "True", "default": "True"} +::std::tuple _scaled_dot_product_efficient_attention(const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const ::std::optional & attn_bias, bool compute_log_sumexp, double dropout_p, bool is_causal, ::std::optional scale); // {"schema": "aten::_scaled_dot_product_efficient_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, *, float? 
scale=None) -> (Tensor output, Tensor log_sumexp, Tensor philox_seed, Tensor philox_offset)", "dispatch": "True", "default": "False"} +::std::tuple _scaled_dot_product_efficient_attention_backward(const at::Tensor & grad_out_, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const at::Tensor & attn_bias, const at::Tensor & out, const at::Tensor & logsumexp, const at::Tensor & philox_seed, const at::Tensor & philox_offset, double dropout_p, ::std::array grad_input_mask, bool is_causal, ::std::optional scale); // {"schema": "aten::_scaled_dot_product_efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, Tensor attn_bias, Tensor out, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, float dropout_p, bool[4] grad_input_mask, bool is_causal=False, *, float? scale=None) -> (Tensor, Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"} +::std::tuple _scaled_dot_product_cudnn_attention(const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const ::std::optional & attn_bias, bool compute_log_sumexp, double dropout_p, bool is_causal, bool return_debug_mask, ::std::optional scale); // {"schema": "aten::_scaled_dot_product_cudnn_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? 
scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)", "dispatch": "True", "default": "False"} +::std::tuple _scaled_dot_product_cudnn_attention_backward(const at::Tensor & grad_out, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const at::Tensor & out, const at::Tensor & logsumexp, const at::Tensor & philox_seed, const at::Tensor & philox_offset, const at::Tensor & attn_bias, const at::Tensor & cum_seq_q, const at::Tensor & cum_seq_k, c10::SymInt max_q, c10::SymInt max_k, double dropout_p, bool is_causal, ::std::optional scale); // {"schema": "aten::_scaled_dot_product_cudnn_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor attn_bias, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, *, float? scale=None) -> (Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"} +::std::tuple _flash_attention_forward(const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const ::std::optional & cum_seq_q, const ::std::optional & cum_seq_k, c10::SymInt max_q, c10::SymInt max_k, double dropout_p, bool is_causal, bool return_debug_mask, ::std::optional scale, ::std::optional window_size_left, ::std::optional window_size_right, const ::std::optional & seqused_k, const ::std::optional & alibi_slopes); // {"schema": "aten::_flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None, SymInt? window_size_left=None, SymInt? window_size_right=None, Tensor? seqused_k=None, Tensor? 
alibi_slopes=None) -> (Tensor output, Tensor softmax_logsumexp, Tensor rng_state, Tensor unused, Tensor debug_attn_mask)", "dispatch": "True", "default": "False"} +::std::tuple _flash_attention_backward(const at::Tensor & grad_out, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const at::Tensor & out, const at::Tensor & logsumexp, const at::Tensor & cum_seq_q, const at::Tensor & cum_seq_k, c10::SymInt max_q, c10::SymInt max_k, double dropout_p, bool is_causal, const at::Tensor & rng_state, const at::Tensor & unused, ::std::optional scale, ::std::optional window_size_left, ::std::optional window_size_right); // {"schema": "aten::_flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, Tensor rng_state, Tensor unused, *, float? scale=None, SymInt? window_size_left=None, SymInt? window_size_right=None) -> (Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"} +::std::tuple _efficient_attention_forward(const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const ::std::optional & bias, const ::std::optional & cu_seqlens_q, const ::std::optional & cu_seqlens_k, ::std::optional max_seqlen_q, ::std::optional max_seqlen_k, double dropout_p, int64_t custom_mask_type, bool compute_log_sumexp, ::std::optional scale, const ::std::optional & seqlen_k, ::std::optional window_size); // {"schema": "aten::_efficient_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? bias, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, SymInt? max_seqlen_q, SymInt? max_seqlen_k, float dropout_p, int custom_mask_type, bool compute_log_sumexp=False, *, float? scale=None, Tensor? seqlen_k=None, int? 
window_size=None) -> (Tensor output, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, SymInt max_seqlen_batch_q, SymInt max_seqlen_batch_k)", "dispatch": "True", "default": "False"} +::std::tuple _efficient_attention_backward(const at::Tensor & grad_out_, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const ::std::optional & bias, const at::Tensor & out, const ::std::optional & cu_seqlens_q, const ::std::optional & cu_seqlens_k, c10::SymInt max_seqlen_q, c10::SymInt max_seqlen_k, const at::Tensor & logsumexp, double dropout_p, const at::Tensor & philox_seed, const at::Tensor & philox_offset, int64_t custom_mask_type, bool bias_requires_grad, ::std::optional scale, ::std::optional num_splits_key, ::std::optional window_size, bool shared_storage_dqdkdv); // {"schema": "aten::_efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, Tensor? bias, Tensor out, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, SymInt max_seqlen_q, SymInt max_seqlen_k, Tensor logsumexp, float dropout_p, Tensor philox_seed, Tensor philox_offset, int custom_mask_type, bool bias_requires_grad, *, float? scale=None, int? num_splits_key=None, int? window_size=None, bool shared_storage_dqdkdv=False) -> (Tensor, Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"} +::std::tuple _cudnn_attention_forward(const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const ::std::optional & attn_bias, const ::std::optional & cum_seq_q, const ::std::optional & cum_seq_k, c10::SymInt max_q, c10::SymInt max_k, bool compute_log_sumexp, double dropout_p, bool is_causal, bool return_debug_mask, ::std::optional scale); // {"schema": "aten::_cudnn_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? 
scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)", "dispatch": "True", "default": "False"} +::std::tuple _cudnn_attention_backward(const at::Tensor & grad_out, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const at::Tensor & out, const at::Tensor & logsumexp, const at::Tensor & philox_seed, const at::Tensor & philox_offset, const at::Tensor & attn_bias, const at::Tensor & cum_seq_q, const at::Tensor & cum_seq_k, c10::SymInt max_q, c10::SymInt max_k, double dropout_p, bool is_causal, ::std::optional scale); // {"schema": "aten::_cudnn_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor attn_bias, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, *, float? scale=None) -> (Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"} +at::Tensor _triton_scaled_dot_attention(const at::Tensor & q, const at::Tensor & k, const at::Tensor & v, double dropout_p); // {"schema": "aten::_triton_scaled_dot_attention(Tensor q, Tensor k, Tensor v, float dropout_p=0.0) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & _fill_mem_eff_dropout_mask_(at::Tensor & self, double dropout_p, int64_t seed, int64_t offset); // {"schema": "aten::_fill_mem_eff_dropout_mask_(Tensor(a!) 
self, float dropout_p, int seed, int offset) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor _triton_multi_head_attention(const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, int64_t embed_dim, int64_t num_head, const at::Tensor & qkv_weight, const at::Tensor & qkv_bias, const at::Tensor & proj_weight, const at::Tensor & proj_bias, const ::std::optional & mask); // {"schema": "aten::_triton_multi_head_attention(Tensor query, Tensor key, Tensor value, int embed_dim, int num_head, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, Tensor? mask=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor special_airy_ai(const at::Tensor & x); // {"schema": "aten::special_airy_ai(Tensor x) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & special_airy_ai_out(const at::Tensor & x, at::Tensor & out); // {"schema": "aten::special_airy_ai.out(Tensor x, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor special_bessel_j0(const at::Tensor & self); // {"schema": "aten::special_bessel_j0(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & special_bessel_j0_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::special_bessel_j0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor special_bessel_j1(const at::Tensor & self); // {"schema": "aten::special_bessel_j1(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & special_bessel_j1_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::special_bessel_j1.out(Tensor self, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor special_bessel_y0(const at::Tensor & self); // {"schema": "aten::special_bessel_y0(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & special_bessel_y0_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::special_bessel_y0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor special_bessel_y1(const at::Tensor & self); // {"schema": "aten::special_bessel_y1(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & special_bessel_y1_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::special_bessel_y1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor special_chebyshev_polynomial_t(const at::Tensor & x, const at::Tensor & n); // {"schema": "aten::special_chebyshev_polynomial_t(Tensor x, Tensor n) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor special_chebyshev_polynomial_t(const at::Scalar & x, const at::Tensor & n); // {"schema": "aten::special_chebyshev_polynomial_t.x_scalar(Scalar x, Tensor n) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor special_chebyshev_polynomial_t(const at::Tensor & x, const at::Scalar & n); // {"schema": "aten::special_chebyshev_polynomial_t.n_scalar(Tensor x, Scalar n) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & special_chebyshev_polynomial_t_out(const at::Tensor & x, const at::Tensor & n, at::Tensor & out); // {"schema": "aten::special_chebyshev_polynomial_t.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & special_chebyshev_polynomial_t_out(const at::Scalar & x, const at::Tensor & n, at::Tensor & out); // {"schema": "aten::special_chebyshev_polynomial_t.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & special_chebyshev_polynomial_t_out(const at::Tensor & x, const at::Scalar & n, at::Tensor & out); // {"schema": "aten::special_chebyshev_polynomial_t.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor special_chebyshev_polynomial_u(const at::Tensor & x, const at::Tensor & n); // {"schema": "aten::special_chebyshev_polynomial_u(Tensor x, Tensor n) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor special_chebyshev_polynomial_u(const at::Scalar & x, const at::Tensor & n); // {"schema": "aten::special_chebyshev_polynomial_u.x_scalar(Scalar x, Tensor n) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor special_chebyshev_polynomial_u(const at::Tensor & x, const at::Scalar & n); // {"schema": "aten::special_chebyshev_polynomial_u.n_scalar(Tensor x, Scalar n) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & special_chebyshev_polynomial_u_out(const at::Tensor & x, const at::Tensor & n, at::Tensor & out); // {"schema": "aten::special_chebyshev_polynomial_u.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & special_chebyshev_polynomial_u_out(const at::Scalar & x, const at::Tensor & n, at::Tensor & out); // {"schema": "aten::special_chebyshev_polynomial_u.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & special_chebyshev_polynomial_u_out(const at::Tensor & x, const at::Scalar & n, at::Tensor & out); // {"schema": "aten::special_chebyshev_polynomial_u.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor special_chebyshev_polynomial_v(const at::Tensor & x, const at::Tensor & n); // {"schema": "aten::special_chebyshev_polynomial_v(Tensor x, Tensor n) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor special_chebyshev_polynomial_v(const at::Scalar & x, const at::Tensor & n); // {"schema": "aten::special_chebyshev_polynomial_v.x_scalar(Scalar x, Tensor n) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor special_chebyshev_polynomial_v(const at::Tensor & x, const at::Scalar & n); // {"schema": "aten::special_chebyshev_polynomial_v.n_scalar(Tensor x, Scalar n) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & special_chebyshev_polynomial_v_out(const at::Tensor & x, const at::Tensor & n, at::Tensor & out); // {"schema": "aten::special_chebyshev_polynomial_v.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & special_chebyshev_polynomial_v_out(const at::Scalar & x, const at::Tensor & n, at::Tensor & out); // {"schema": "aten::special_chebyshev_polynomial_v.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & special_chebyshev_polynomial_v_out(const at::Tensor & x, const at::Scalar & n, at::Tensor & out); // {"schema": "aten::special_chebyshev_polynomial_v.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor special_chebyshev_polynomial_w(const at::Tensor & x, const at::Tensor & n); // {"schema": "aten::special_chebyshev_polynomial_w(Tensor x, Tensor n) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor special_chebyshev_polynomial_w(const at::Scalar & x, const at::Tensor & n); // {"schema": "aten::special_chebyshev_polynomial_w.x_scalar(Scalar x, Tensor n) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor special_chebyshev_polynomial_w(const at::Tensor & x, const at::Scalar & n); // {"schema": "aten::special_chebyshev_polynomial_w.n_scalar(Tensor x, Scalar n) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & special_chebyshev_polynomial_w_out(const at::Tensor & x, const at::Tensor & n, at::Tensor & out); // {"schema": "aten::special_chebyshev_polynomial_w.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & special_chebyshev_polynomial_w_out(const at::Scalar & x, const at::Tensor & n, at::Tensor & out); // {"schema": "aten::special_chebyshev_polynomial_w.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & special_chebyshev_polynomial_w_out(const at::Tensor & x, const at::Scalar & n, at::Tensor & out); // {"schema": "aten::special_chebyshev_polynomial_w.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor special_hermite_polynomial_h(const at::Tensor & x, const at::Tensor & n); // {"schema": "aten::special_hermite_polynomial_h(Tensor x, Tensor n) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor special_hermite_polynomial_h(const at::Scalar & x, const at::Tensor & n); // {"schema": "aten::special_hermite_polynomial_h.x_scalar(Scalar x, Tensor n) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor special_hermite_polynomial_h(const at::Tensor & x, const at::Scalar & n); // {"schema": "aten::special_hermite_polynomial_h.n_scalar(Tensor x, Scalar n) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & special_hermite_polynomial_h_out(const at::Tensor & x, const at::Tensor & n, at::Tensor & out); // {"schema": "aten::special_hermite_polynomial_h.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & special_hermite_polynomial_h_out(const at::Scalar & x, const at::Tensor & n, at::Tensor & out); // {"schema": "aten::special_hermite_polynomial_h.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & special_hermite_polynomial_h_out(const at::Tensor & x, const at::Scalar & n, at::Tensor & out); // {"schema": "aten::special_hermite_polynomial_h.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor special_hermite_polynomial_he(const at::Tensor & x, const at::Tensor & n); // {"schema": "aten::special_hermite_polynomial_he(Tensor x, Tensor n) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor special_hermite_polynomial_he(const at::Scalar & x, const at::Tensor & n); // {"schema": "aten::special_hermite_polynomial_he.x_scalar(Scalar x, Tensor n) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor special_hermite_polynomial_he(const at::Tensor & x, const at::Scalar & n); // {"schema": "aten::special_hermite_polynomial_he.n_scalar(Tensor x, Scalar n) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & special_hermite_polynomial_he_out(const at::Tensor & x, const at::Tensor & n, at::Tensor & out); // {"schema": "aten::special_hermite_polynomial_he.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & special_hermite_polynomial_he_out(const at::Scalar & x, const at::Tensor & n, at::Tensor & out); // {"schema": "aten::special_hermite_polynomial_he.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & special_hermite_polynomial_he_out(const at::Tensor & x, const at::Scalar & n, at::Tensor & out); // {"schema": "aten::special_hermite_polynomial_he.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor special_laguerre_polynomial_l(const at::Tensor & x, const at::Tensor & n); // {"schema": "aten::special_laguerre_polynomial_l(Tensor x, Tensor n) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor special_laguerre_polynomial_l(const at::Scalar & x, const at::Tensor & n); // {"schema": "aten::special_laguerre_polynomial_l.x_scalar(Scalar x, Tensor n) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor special_laguerre_polynomial_l(const at::Tensor & x, const at::Scalar & n); // {"schema": "aten::special_laguerre_polynomial_l.n_scalar(Tensor x, Scalar n) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & special_laguerre_polynomial_l_out(const at::Tensor & x, const at::Tensor & n, at::Tensor & out); // {"schema": "aten::special_laguerre_polynomial_l.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & special_laguerre_polynomial_l_out(const at::Scalar & x, const at::Tensor & n, at::Tensor & out); // {"schema": "aten::special_laguerre_polynomial_l.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & special_laguerre_polynomial_l_out(const at::Tensor & x, const at::Scalar & n, at::Tensor & out); // {"schema": "aten::special_laguerre_polynomial_l.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor special_legendre_polynomial_p(const at::Tensor & x, const at::Tensor & n); // {"schema": "aten::special_legendre_polynomial_p(Tensor x, Tensor n) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor special_legendre_polynomial_p(const at::Scalar & x, const at::Tensor & n); // {"schema": "aten::special_legendre_polynomial_p.x_scalar(Scalar x, Tensor n) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor special_legendre_polynomial_p(const at::Tensor & x, const at::Scalar & n); // {"schema": "aten::special_legendre_polynomial_p.n_scalar(Tensor x, Scalar n) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & special_legendre_polynomial_p_out(const at::Tensor & x, const at::Tensor & n, at::Tensor & out); // {"schema": "aten::special_legendre_polynomial_p.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & special_legendre_polynomial_p_out(const at::Scalar & x, const at::Tensor & n, at::Tensor & out); // {"schema": "aten::special_legendre_polynomial_p.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & special_legendre_polynomial_p_out(const at::Tensor & x, const at::Scalar & n, at::Tensor & out); // {"schema": "aten::special_legendre_polynomial_p.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor special_modified_bessel_i0(const at::Tensor & self); // {"schema": "aten::special_modified_bessel_i0(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & special_modified_bessel_i0_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::special_modified_bessel_i0.out(Tensor self, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor special_modified_bessel_i1(const at::Tensor & self); // {"schema": "aten::special_modified_bessel_i1(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & special_modified_bessel_i1_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::special_modified_bessel_i1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor special_modified_bessel_k0(const at::Tensor & self); // {"schema": "aten::special_modified_bessel_k0(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & special_modified_bessel_k0_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::special_modified_bessel_k0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor special_modified_bessel_k1(const at::Tensor & self); // {"schema": "aten::special_modified_bessel_k1(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & special_modified_bessel_k1_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::special_modified_bessel_k1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor special_scaled_modified_bessel_k0(const at::Tensor & x); // {"schema": "aten::special_scaled_modified_bessel_k0(Tensor x) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & special_scaled_modified_bessel_k0_out(const at::Tensor & x, at::Tensor & out); // {"schema": "aten::special_scaled_modified_bessel_k0.out(Tensor x, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor special_scaled_modified_bessel_k1(const at::Tensor & x); // {"schema": "aten::special_scaled_modified_bessel_k1(Tensor x) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & special_scaled_modified_bessel_k1_out(const at::Tensor & x, at::Tensor & out); // {"schema": "aten::special_scaled_modified_bessel_k1.out(Tensor x, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor special_shifted_chebyshev_polynomial_t(const at::Tensor & x, const at::Tensor & n); // {"schema": "aten::special_shifted_chebyshev_polynomial_t(Tensor x, Tensor n) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor special_shifted_chebyshev_polynomial_t(const at::Scalar & x, const at::Tensor & n); // {"schema": "aten::special_shifted_chebyshev_polynomial_t.x_scalar(Scalar x, Tensor n) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor special_shifted_chebyshev_polynomial_t(const at::Tensor & x, const at::Scalar & n); // {"schema": "aten::special_shifted_chebyshev_polynomial_t.n_scalar(Tensor x, Scalar n) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & special_shifted_chebyshev_polynomial_t_out(const at::Tensor & x, const at::Tensor & n, at::Tensor & out); // {"schema": "aten::special_shifted_chebyshev_polynomial_t.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & special_shifted_chebyshev_polynomial_t_out(const at::Scalar & x, const at::Tensor & n, at::Tensor & out); // {"schema": "aten::special_shifted_chebyshev_polynomial_t.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & special_shifted_chebyshev_polynomial_t_out(const at::Tensor & x, const at::Scalar & n, at::Tensor & out); // {"schema": "aten::special_shifted_chebyshev_polynomial_t.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor special_shifted_chebyshev_polynomial_u(const at::Tensor & x, const at::Tensor & n); // {"schema": "aten::special_shifted_chebyshev_polynomial_u(Tensor x, Tensor n) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor special_shifted_chebyshev_polynomial_u(const at::Scalar & x, const at::Tensor & n); // {"schema": "aten::special_shifted_chebyshev_polynomial_u.x_scalar(Scalar x, Tensor n) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor special_shifted_chebyshev_polynomial_u(const at::Tensor & x, const at::Scalar & n); // {"schema": "aten::special_shifted_chebyshev_polynomial_u.n_scalar(Tensor x, Scalar n) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & special_shifted_chebyshev_polynomial_u_out(const at::Tensor & x, const at::Tensor & n, at::Tensor & out); // {"schema": "aten::special_shifted_chebyshev_polynomial_u.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & special_shifted_chebyshev_polynomial_u_out(const at::Scalar & x, const at::Tensor & n, at::Tensor & out); // {"schema": "aten::special_shifted_chebyshev_polynomial_u.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & special_shifted_chebyshev_polynomial_u_out(const at::Tensor & x, const at::Scalar & n, at::Tensor & out); // {"schema": "aten::special_shifted_chebyshev_polynomial_u.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor special_shifted_chebyshev_polynomial_v(const at::Tensor & x, const at::Tensor & n); // {"schema": "aten::special_shifted_chebyshev_polynomial_v(Tensor x, Tensor n) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor special_shifted_chebyshev_polynomial_v(const at::Scalar & x, const at::Tensor & n); // {"schema": "aten::special_shifted_chebyshev_polynomial_v.x_scalar(Scalar x, Tensor n) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor special_shifted_chebyshev_polynomial_v(const at::Tensor & x, const at::Scalar & n); // {"schema": "aten::special_shifted_chebyshev_polynomial_v.n_scalar(Tensor x, Scalar n) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & special_shifted_chebyshev_polynomial_v_out(const at::Tensor & x, const at::Tensor & n, at::Tensor & out); // {"schema": "aten::special_shifted_chebyshev_polynomial_v.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & special_shifted_chebyshev_polynomial_v_out(const at::Scalar & x, const at::Tensor & n, at::Tensor & out); // {"schema": "aten::special_shifted_chebyshev_polynomial_v.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & special_shifted_chebyshev_polynomial_v_out(const at::Tensor & x, const at::Scalar & n, at::Tensor & out); // {"schema": "aten::special_shifted_chebyshev_polynomial_v.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor special_shifted_chebyshev_polynomial_w(const at::Tensor & x, const at::Tensor & n); // {"schema": "aten::special_shifted_chebyshev_polynomial_w(Tensor x, Tensor n) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor special_shifted_chebyshev_polynomial_w(const at::Scalar & x, const at::Tensor & n); // {"schema": "aten::special_shifted_chebyshev_polynomial_w.x_scalar(Scalar x, Tensor n) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor special_shifted_chebyshev_polynomial_w(const at::Tensor & x, const at::Scalar & n); // {"schema": "aten::special_shifted_chebyshev_polynomial_w.n_scalar(Tensor x, Scalar n) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & special_shifted_chebyshev_polynomial_w_out(const at::Tensor & x, const at::Tensor & n, at::Tensor & out); // {"schema": "aten::special_shifted_chebyshev_polynomial_w.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & special_shifted_chebyshev_polynomial_w_out(const at::Scalar & x, const at::Tensor & n, at::Tensor & out); // {"schema": "aten::special_shifted_chebyshev_polynomial_w.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & special_shifted_chebyshev_polynomial_w_out(const at::Tensor & x, const at::Scalar & n, at::Tensor & out); // {"schema": "aten::special_shifted_chebyshev_polynomial_w.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor special_spherical_bessel_j0(const at::Tensor & x); // {"schema": "aten::special_spherical_bessel_j0(Tensor x) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & special_spherical_bessel_j0_out(const at::Tensor & x, at::Tensor & out); // {"schema": "aten::special_spherical_bessel_j0.out(Tensor x, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor _foobar(const at::Tensor & self, bool arg1, bool arg2, bool arg3); // {"schema": "aten::_foobar(Tensor self, bool arg1=True, bool arg2=True, *, bool arg3=True) -> Tensor", "dispatch": "True", "default": "False"} +void _fused_adam_(at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, double lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const ::std::optional & grad_scale, const ::std::optional & found_inf); // {"schema": "aten::_fused_adam_(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, float lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> ()", "dispatch": "True", "default": "False"} +void _fused_adam_(at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, const at::Tensor & lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const ::std::optional & grad_scale, const ::std::optional & found_inf); // {"schema": "aten::_fused_adam_.tensor_lr(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, Tensor lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? 
found_inf=None) -> ()", "dispatch": "True", "default": "False"} +void _fused_adamw_(at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, double lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const ::std::optional & grad_scale, const ::std::optional & found_inf); // {"schema": "aten::_fused_adamw_(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, float lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> ()", "dispatch": "True", "default": "False"} +void _fused_adamw_(at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, const at::Tensor & lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const ::std::optional & grad_scale, const ::std::optional & found_inf); // {"schema": "aten::_fused_adamw_.tensor_lr(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, Tensor lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? 
found_inf=None) -> ()", "dispatch": "True", "default": "False"} +void _fused_sgd_(at::TensorList self, at::TensorList grads, at::TensorList momentum_buffer_list, double weight_decay, double momentum, double lr, double dampening, bool nesterov, bool maximize, bool is_first_step, const ::std::optional & grad_scale, const ::std::optional & found_inf); // {"schema": "aten::_fused_sgd_(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] momentum_buffer_list, *, float weight_decay, float momentum, float lr, float dampening, bool nesterov, bool maximize, bool is_first_step, Tensor? grad_scale=None, Tensor? found_inf=None) -> ()", "dispatch": "True", "default": "False"} +void _fused_sgd_(at::TensorList self, at::TensorList grads, at::TensorList momentum_buffer_list, double weight_decay, double momentum, const at::Tensor & lr, double dampening, bool nesterov, bool maximize, bool is_first_step, const ::std::optional & grad_scale, const ::std::optional & found_inf); // {"schema": "aten::_fused_sgd_.tensor_lr(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] momentum_buffer_list, *, float weight_decay, float momentum, Tensor lr, float dampening, bool nesterov, bool maximize, bool is_first_step, Tensor? grad_scale=None, Tensor? found_inf=None) -> ()", "dispatch": "True", "default": "False"} +void _fused_adagrad_(at::TensorList self, at::TensorList grads, at::TensorList state_sums, at::TensorList state_steps, double lr, double lr_decay, double weight_decay, double eps, bool maximize, const ::std::optional & grad_scale, const ::std::optional & found_inf); // {"schema": "aten::_fused_adagrad_(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] state_sums, Tensor(d!)[] state_steps, *, float lr, float lr_decay, float weight_decay, float eps, bool maximize, Tensor? grad_scale=None, Tensor? 
found_inf=None) -> ()", "dispatch": "True", "default": "False"} +void _fused_adagrad_(at::TensorList self, at::TensorList grads, at::TensorList state_sums, at::TensorList state_steps, const at::Tensor & lr, double lr_decay, double weight_decay, double eps, bool maximize, const ::std::optional & grad_scale, const ::std::optional & found_inf); // {"schema": "aten::_fused_adagrad_.tensor_lr(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] state_sums, Tensor[] state_steps, *, Tensor lr, float lr_decay, float weight_decay, float eps, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> ()", "dispatch": "True", "default": "False"} +void _propagate_xla_data(const at::Tensor & input, const at::Tensor & output); // {"schema": "aten::_propagate_xla_data(Tensor input, Tensor output) -> ()", "dispatch": "False", "default": "True"} +at::Tensor & _new_zeros_with_same_feature_meta_out(const at::Tensor & self, const at::Tensor & other, int64_t self_num_batch_dims, at::Tensor & out); // {"schema": "aten::_new_zeros_with_same_feature_meta.out(Tensor self, Tensor other, *, int self_num_batch_dims=0, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +::std::tuple _cudnn_ctc_loss_out(const at::Tensor & log_probs, const at::Tensor & targets, at::IntArrayRef input_lengths, at::IntArrayRef target_lengths, int64_t blank, bool deterministic, bool zero_infinity, at::Tensor & out0, at::Tensor & out1); // {"schema": "aten::_cudnn_ctc_loss.out(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank, bool deterministic, bool zero_infinity, *, Tensor(a!) out0, Tensor(b!) 
out1) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "True"} +at::Tensor & _cudnn_rnn_flatten_weight_out(at::TensorList weight_arr, int64_t weight_stride0, c10::SymInt input_size, int64_t mode, c10::SymInt hidden_size, c10::SymInt proj_size, int64_t num_layers, bool batch_first, bool bidirectional, at::Tensor & out); // {"schema": "aten::_cudnn_rnn_flatten_weight.out(Tensor[] weight_arr, int weight_stride0, SymInt input_size, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, bool bidirectional, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +::std::tuple _cudnn_rnn_out(const at::Tensor & input, at::TensorList weight, int64_t weight_stride0, const ::std::optional & weight_buf, const at::Tensor & hx, const ::std::optional & cx, int64_t mode, c10::SymInt hidden_size, c10::SymInt proj_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, c10::SymIntArrayRef batch_sizes, const ::std::optional & dropout_state, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3, at::Tensor & out4); // {"schema": "aten::_cudnn_rnn.out(Tensor input, Tensor[] weight, int weight_stride0, Tensor? weight_buf, Tensor hx, Tensor? cx, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, SymInt[] batch_sizes, Tensor? dropout_state, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3, Tensor(e!) 
out4) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!), Tensor(e!))", "dispatch": "True", "default": "True"} +void _cudnn_rnn_backward_out(const at::Tensor & input, at::TensorList weight, int64_t weight_stride0, const at::Tensor & weight_buf, const at::Tensor & hx, const ::std::optional & cx, const at::Tensor & output, const ::std::optional & grad_output, const ::std::optional & grad_hy, const ::std::optional & grad_cy, int64_t mode, c10::SymInt hidden_size, c10::SymInt proj_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, c10::SymIntArrayRef batch_sizes, const ::std::optional & dropout_state, const at::Tensor & reserve, ::std::array output_mask, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::TensorList out3); // {"schema": "aten::_cudnn_rnn_backward.out(Tensor input, Tensor[] weight, int weight_stride0, Tensor weight_buf, Tensor hx, Tensor? cx, Tensor output, Tensor? grad_output, Tensor? grad_hy, Tensor? grad_cy, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, SymInt[] batch_sizes, Tensor? dropout_state, Tensor reserve, bool[4] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!)[] out3) -> ()", "dispatch": "True", "default": "True"} +at::Tensor & _cudnn_init_dropout_state_out(double dropout, bool train, int64_t dropout_seed, at::Tensor & out); // {"schema": "aten::_cudnn_init_dropout_state.out(float dropout, bool train, int dropout_seed, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +::std::tuple _fused_dropout_out(const at::Tensor & self, double p, ::std::optional generator, at::Tensor & out0, at::Tensor & out1); // {"schema": "aten::_fused_dropout.out(Tensor self, float p, Generator? generator=None, *, Tensor(a!) out0, Tensor(b!) 
out1) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "True"} +at::Tensor & _masked_scale_out(const at::Tensor & self, const at::Tensor & mask, double scale, at::Tensor & out); // {"schema": "aten::_masked_scale.out(Tensor self, Tensor mask, float scale, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +::std::tuple native_dropout_out(const at::Tensor & input, double p, ::std::optional train, at::Tensor & out0, at::Tensor & out1); // {"schema": "aten::native_dropout.out(Tensor input, float p, bool? train, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "True"} +at::Tensor & native_dropout_backward_out(const at::Tensor & grad_output, const at::Tensor & mask, double scale, at::Tensor & out); // {"schema": "aten::native_dropout_backward.out(Tensor grad_output, Tensor mask, float scale, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _conj_physical_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::_conj_physical.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & avg_pool1d_out(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, at::Tensor & out); // {"schema": "aten::avg_pool1d.out(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, bool ceil_mode=False, bool count_include_pad=True, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & adaptive_avg_pool1d_out(const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out); // {"schema": "aten::adaptive_avg_pool1d.out(Tensor self, int[1] output_size, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _add_relu_out(const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha, at::Tensor & out); // {"schema": "aten::_add_relu.Scalar_out(Tensor self, Scalar other, Scalar alpha=1, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & add_out(const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha, at::Tensor & out); // {"schema": "aten::add.Scalar_out(Tensor self, Scalar other, Scalar alpha=1, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & affine_grid_generator_out(const at::Tensor & theta, c10::SymIntArrayRef size, bool align_corners, at::Tensor & out); // {"schema": "aten::affine_grid_generator.out(Tensor theta, SymInt[] size, bool align_corners, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _test_functorch_fallback_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::_test_functorch_fallback.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & bartlett_window_out(int64_t window_length, at::Tensor & out); // {"schema": "aten::bartlett_window.out(int window_length, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & bartlett_window_out(int64_t window_length, bool periodic, at::Tensor & out); // {"schema": "aten::bartlett_window.periodic_out(int window_length, bool periodic, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & quantized_batch_norm_out(const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const at::Tensor & mean, const at::Tensor & var, double eps, double output_scale, int64_t output_zero_point, at::Tensor & out); // {"schema": "aten::quantized_batch_norm.out(Tensor input, Tensor? weight, Tensor? 
bias, Tensor mean, Tensor var, float eps, float output_scale, int output_zero_point, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & bernoulli_out(const at::Tensor & self, const at::Tensor & p, ::std::optional generator, at::Tensor & out); // {"schema": "aten::bernoulli.Tensor_out(Tensor self, Tensor p, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor bernoulli(const at::Tensor & self, const at::Tensor & p, ::std::optional generator); // {"schema": "aten::bernoulli.Tensor(Tensor self, Tensor p, *, Generator? generator=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & bernoulli_out(const at::Tensor & self, double p, ::std::optional generator, at::Tensor & out); // {"schema": "aten::bernoulli.float_out(Tensor self, float p=0.5, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & binary_cross_entropy_with_logits_out(const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, const ::std::optional & pos_weight, int64_t reduction, at::Tensor & out); // {"schema": "aten::binary_cross_entropy_with_logits.out(Tensor self, Tensor target, Tensor? weight=None, Tensor? pos_weight=None, int reduction=Mean, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & bincount_out(const at::Tensor & self, const ::std::optional & weights, c10::SymInt minlength, at::Tensor & out); // {"schema": "aten::bincount.out(Tensor self, Tensor? weights=None, SymInt minlength=0, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & blackman_window_out(int64_t window_length, at::Tensor & out); // {"schema": "aten::blackman_window.out(int window_length, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & blackman_window_out(int64_t window_length, bool periodic, at::Tensor & out); // {"schema": "aten::blackman_window.periodic_out(int window_length, bool periodic, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & block_diag_out(at::TensorList tensors, at::Tensor & out); // {"schema": "aten::block_diag.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & constant_pad_nd_out(const at::Tensor & self, c10::SymIntArrayRef pad, const at::Scalar & value, at::Tensor & out); // {"schema": "aten::constant_pad_nd.out(Tensor self, SymInt[] pad, Scalar value=0, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & convolution_out(const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups, at::Tensor & out); // {"schema": "aten::convolution.out(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +::std::tuple convolution_backward_out(const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & weight, at::OptionalSymIntArrayRef bias_sizes, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups, ::std::array output_mask, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2); // {"schema": "aten::convolution_backward.out(Tensor grad_output, Tensor input, Tensor weight, SymInt[]? 
bias_sizes, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))", "dispatch": "True", "default": "True"} +at::Tensor & convolution_overrideable_out(const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups, at::Tensor & out); // {"schema": "aten::convolution_overrideable.out(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +::std::tuple convolution_backward_overrideable_out(const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & weight, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups, ::std::array output_mask, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2); // {"schema": "aten::convolution_backward_overrideable.out(Tensor grad_output, Tensor input, Tensor weight, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) 
out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))", "dispatch": "True", "default": "True"} +at::Tensor & _convolution_out(const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32, at::Tensor & out); // {"schema": "aten::_convolution.out(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & conv_tbc_out(const at::Tensor & self, const at::Tensor & weight, const at::Tensor & bias, int64_t pad, at::Tensor & out); // {"schema": "aten::conv_tbc.out(Tensor self, Tensor weight, Tensor bias, int pad=0, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & copy_out(const at::Tensor & self, const at::Tensor & src, bool non_blocking, at::Tensor & out); // {"schema": "aten::copy.out(Tensor self, Tensor src, bool non_blocking=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _copy_from_out(const at::Tensor & self, const at::Tensor & dst, bool non_blocking, at::Tensor & out); // {"schema": "aten::_copy_from.out(Tensor self, Tensor dst, bool non_blocking=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _copy_from_and_resize_out(const at::Tensor & self, const at::Tensor & dst, at::Tensor & out); // {"schema": "aten::_copy_from_and_resize.out(Tensor self, Tensor dst, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & count_nonzero_out(const at::Tensor & self, at::IntArrayRef dim, at::Tensor & out); // {"schema": "aten::count_nonzero.dim_IntList_out(Tensor self, int[] dim, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & count_nonzero_out(const at::Tensor & self, ::std::optional dim, at::Tensor & out); // {"schema": "aten::count_nonzero.out(Tensor self, int? dim=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & cudnn_affine_grid_generator_out(const at::Tensor & theta, int64_t N, int64_t C, int64_t H, int64_t W, at::Tensor & out); // {"schema": "aten::cudnn_affine_grid_generator.out(Tensor theta, int N, int C, int H, int W, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & cudnn_affine_grid_generator_backward_out(const at::Tensor & grad, int64_t N, int64_t C, int64_t H, int64_t W, at::Tensor & out); // {"schema": "aten::cudnn_affine_grid_generator_backward.out(Tensor grad, int N, int C, int H, int W, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +::std::tuple cudnn_batch_norm_backward_out(const at::Tensor & input, const at::Tensor & grad_output, const at::Tensor & weight, const ::std::optional & running_mean, const ::std::optional & running_var, const ::std::optional & save_mean, const ::std::optional & save_var, double epsilon, const at::Tensor & reserveSpace, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2); // {"schema": "aten::cudnn_batch_norm_backward.out(Tensor input, Tensor grad_output, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, float epsilon, Tensor reserveSpace, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) 
out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))", "dispatch": "True", "default": "True"} +at::Tensor & cudnn_convolution_transpose_out(const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, bool benchmark, bool deterministic, bool allow_tf32, at::Tensor & out); // {"schema": "aten::cudnn_convolution_transpose.out(Tensor self, Tensor weight, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, bool allow_tf32, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _mps_convolution_transpose_out(const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, at::Tensor & out); // {"schema": "aten::_mps_convolution_transpose.out(Tensor self, Tensor weight, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +::std::tuple mps_convolution_transpose_backward_out(const at::Tensor & self, const at::Tensor & grad_output, const at::Tensor & weight, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, ::std::array output_mask, at::Tensor & out0, at::Tensor & out1); // {"schema": "aten::mps_convolution_transpose_backward.out(Tensor self, Tensor grad_output, Tensor weight, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool[2] output_mask, *, Tensor(a!) out0, Tensor(b!) 
out1) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "True"} +at::Tensor & cudnn_convolution_relu_out(const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, c10::SymInt groups, at::Tensor & out); // {"schema": "aten::cudnn_convolution_relu.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, SymInt groups, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & cudnn_convolution_add_relu_out(const at::Tensor & self, const at::Tensor & weight, const at::Tensor & z, const ::std::optional & alpha, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, c10::SymInt groups, at::Tensor & out); // {"schema": "aten::cudnn_convolution_add_relu.out(Tensor self, Tensor weight, Tensor z, Scalar? alpha, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, SymInt groups, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & cudnn_grid_sampler_out(const at::Tensor & self, const at::Tensor & grid, at::Tensor & out); // {"schema": "aten::cudnn_grid_sampler.out(Tensor self, Tensor grid, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +::std::tuple cudnn_grid_sampler_backward_out(const at::Tensor & self, const at::Tensor & grid, const at::Tensor & grad_output, at::Tensor & out0, at::Tensor & out1); // {"schema": "aten::cudnn_grid_sampler_backward.out(Tensor self, Tensor grid, Tensor grad_output, *, Tensor(a!) out0, Tensor(b!) 
out1) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "True"} +::std::tuple _ctc_loss_out(const at::Tensor & log_probs, const at::Tensor & targets, at::IntArrayRef input_lengths, at::IntArrayRef target_lengths, int64_t blank, bool zero_infinity, at::Tensor & out0, at::Tensor & out1); // {"schema": "aten::_ctc_loss.out(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank=0, bool zero_infinity=False, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "True"} +::std::tuple _ctc_loss_out(const at::Tensor & log_probs, const at::Tensor & targets, const at::Tensor & input_lengths, const at::Tensor & target_lengths, int64_t blank, bool zero_infinity, at::Tensor & out0, at::Tensor & out1); // {"schema": "aten::_ctc_loss.Tensor_out(Tensor log_probs, Tensor targets, Tensor input_lengths, Tensor target_lengths, int blank=0, bool zero_infinity=False, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "True"} +at::Tensor & _ctc_loss_backward_out(const at::Tensor & grad, const at::Tensor & log_probs, const at::Tensor & targets, at::IntArrayRef input_lengths, at::IntArrayRef target_lengths, const at::Tensor & neg_log_likelihood, const at::Tensor & log_alpha, int64_t blank, bool zero_infinity, at::Tensor & out); // {"schema": "aten::_ctc_loss_backward.out(Tensor grad, Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, Tensor neg_log_likelihood, Tensor log_alpha, int blank, bool zero_infinity=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & diag_embed_out(const at::Tensor & self, int64_t offset, int64_t dim1, int64_t dim2, at::Tensor & out); // {"schema": "aten::diag_embed.out(Tensor self, int offset=0, int dim1=-2, int dim2=-1, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & diagonal_backward_out(const at::Tensor & grad_output, c10::SymIntArrayRef input_sizes, int64_t offset, int64_t dim1, int64_t dim2, at::Tensor & out); // {"schema": "aten::diagonal_backward.out(Tensor grad_output, SymInt[] input_sizes, int offset, int dim1, int dim2, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & div_out(const at::Tensor & self, const at::Scalar & other, at::Tensor & out); // {"schema": "aten::div.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & div_out(const at::Tensor & self, const at::Scalar & other, ::std::optional rounding_mode, at::Tensor & out); // {"schema": "aten::div.Scalar_mode_out(Tensor self, Scalar other, *, str? rounding_mode, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & embedding_out(const at::Tensor & weight, const at::Tensor & indices, c10::SymInt padding_idx, bool scale_grad_by_freq, bool sparse, at::Tensor & out); // {"schema": "aten::embedding.out(Tensor weight, Tensor indices, SymInt padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & embedding_dense_backward_out(const at::Tensor & grad_output, const at::Tensor & indices, c10::SymInt num_weights, c10::SymInt padding_idx, bool scale_grad_by_freq, at::Tensor & out); // {"schema": "aten::embedding_dense_backward.out(Tensor grad_output, Tensor indices, SymInt num_weights, SymInt padding_idx, bool scale_grad_by_freq, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & embedding_renorm_out(const at::Tensor & self, const at::Tensor & indices, double max_norm, double norm_type, at::Tensor & out); // {"schema": "aten::embedding_renorm.out(Tensor self, Tensor indices, float max_norm, float norm_type, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor embedding_renorm(const at::Tensor & self, const at::Tensor & indices, double max_norm, double norm_type); // {"schema": "aten::embedding_renorm(Tensor self, Tensor indices, float max_norm, float norm_type) -> Tensor", "dispatch": "True", "default": "True"} +::std::tuple _embedding_bag_forward_only_out(const at::Tensor & weight, const at::Tensor & indices, const at::Tensor & offsets, bool scale_grad_by_freq, int64_t mode, bool sparse, const ::std::optional & per_sample_weights, bool include_last_offset, int64_t padding_idx, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3); // {"schema": "aten::_embedding_bag_forward_only.out(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False, int padding_idx=-1, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!))", "dispatch": "True", "default": "True"} +::std::tuple _embedding_bag_out(const at::Tensor & weight, const at::Tensor & indices, const at::Tensor & offsets, bool scale_grad_by_freq, int64_t mode, bool sparse, const ::std::optional & per_sample_weights, bool include_last_offset, int64_t padding_idx, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3); // {"schema": "aten::_embedding_bag.out(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False, int padding_idx=-1, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) 
out3) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!))", "dispatch": "True", "default": "True"} +at::Tensor & _embedding_bag_dense_backward_out(const at::Tensor & grad, const at::Tensor & indices, const at::Tensor & offset2bag, const at::Tensor & bag_size, const at::Tensor & maximum_indices, c10::SymInt num_weights, bool scale_grad_by_freq, int64_t mode, const ::std::optional & per_sample_weights, int64_t padding_idx, at::Tensor & out); // {"schema": "aten::_embedding_bag_dense_backward.out(Tensor grad, Tensor indices, Tensor offset2bag, Tensor bag_size, Tensor maximum_indices, SymInt num_weights, bool scale_grad_by_freq, int mode, Tensor? per_sample_weights, int padding_idx=-1, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _embedding_bag_per_sample_weights_backward_out(const at::Tensor & grad, const at::Tensor & weight, const at::Tensor & indices, const at::Tensor & offsets, const at::Tensor & offset2bag, int64_t mode, int64_t padding_idx, at::Tensor & out); // {"schema": "aten::_embedding_bag_per_sample_weights_backward.out(Tensor grad, Tensor weight, Tensor indices, Tensor offsets, Tensor offset2bag, int mode, int padding_idx=-1, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & empty_out(at::IntArrayRef size, ::std::optional names, ::std::optional memory_format, at::Tensor & out); // {"schema": "aten::empty.names_out(int[] size, *, Dimname[]? names, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & empty_permuted_out(c10::SymIntArrayRef size, at::IntArrayRef physical_layout, at::Tensor & out); // {"schema": "aten::empty_permuted.out(SymInt[] size, int[] physical_layout, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & new_empty_out(const at::Tensor & self, c10::SymIntArrayRef size, at::Tensor & out); // {"schema": "aten::new_empty.out(Tensor self, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & new_empty_strided_out(const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, at::Tensor & out); // {"schema": "aten::new_empty_strided.out(Tensor self, SymInt[] size, SymInt[] stride, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & new_full_out(const at::Tensor & self, c10::SymIntArrayRef size, const at::Scalar & fill_value, at::Tensor & out); // {"schema": "aten::new_full.out(Tensor self, SymInt[] size, Scalar fill_value, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & new_zeros_out(const at::Tensor & self, c10::SymIntArrayRef size, at::Tensor & out); // {"schema": "aten::new_zeros.out(Tensor self, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & new_ones_out(const at::Tensor & self, c10::SymIntArrayRef size, at::Tensor & out); // {"schema": "aten::new_ones.out(Tensor self, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _empty_affine_quantized_out(c10::SymIntArrayRef size, double scale, int64_t zero_point, ::std::optional memory_format, at::Tensor & out); // {"schema": "aten::_empty_affine_quantized.out(SymInt[] size, *, float scale=1, int zero_point=0, MemoryFormat? memory_format=contiguous_format, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _empty_per_channel_affine_quantized_out(c10::SymIntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, ::std::optional memory_format, at::Tensor & out); // {"schema": "aten::_empty_per_channel_affine_quantized.out(SymInt[] size, *, Tensor scales, Tensor zero_points, int axis, MemoryFormat? memory_format=contiguous_format, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +const at::Tensor & resize_out(const at::Tensor & self, c10::SymIntArrayRef size, ::std::optional memory_format, const at::Tensor & out); // {"schema": "aten::resize.out(Tensor self, SymInt[] size, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor resize(const at::Tensor & self, c10::SymIntArrayRef size, ::std::optional memory_format); // {"schema": "aten::resize(Tensor self, SymInt[] size, *, MemoryFormat? memory_format=None) -> Tensor", "dispatch": "True", "default": "True"} +const at::Tensor & _resize_output_out(const at::Tensor & self, c10::SymIntArrayRef size, at::Device device, const at::Tensor & out); // {"schema": "aten::_resize_output.out(Tensor self, SymInt[] size, Device device, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor _resize_output(const at::Tensor & self, c10::SymIntArrayRef size, at::Device device); // {"schema": "aten::_resize_output(Tensor self, SymInt[] size, Device device) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & empty_quantized_out(at::IntArrayRef size, const at::Tensor & qtensor, ::std::optional memory_format, at::Tensor & out); // {"schema": "aten::empty_quantized.out(int[] size, Tensor qtensor, *, MemoryFormat? memory_format=None, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & empty_like_out(const at::Tensor & self, ::std::optional memory_format, at::Tensor & out); // {"schema": "aten::empty_like.out(Tensor self, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & empty_strided_out(c10::SymIntArrayRef size, c10::SymIntArrayRef stride, at::Tensor & out); // {"schema": "aten::empty_strided.out(SymInt[] size, SymInt[] stride, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & fill_out(const at::Tensor & self, const at::Scalar & value, at::Tensor & out); // {"schema": "aten::fill.Scalar_out(Tensor self, Scalar value, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & fill_out(const at::Tensor & self, const at::Tensor & value, at::Tensor & out); // {"schema": "aten::fill.Tensor_out(Tensor self, Tensor value, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & floor_divide_out(const at::Tensor & self, const at::Scalar & other, at::Tensor & out); // {"schema": "aten::floor_divide.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & full_out(at::IntArrayRef size, const at::Scalar & fill_value, ::std::optional names, at::Tensor & out); // {"schema": "aten::full.names_out(int[] size, Scalar fill_value, *, Dimname[]? names, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & full_like_out(const at::Tensor & self, const at::Scalar & fill_value, ::std::optional memory_format, at::Tensor & out); // {"schema": "aten::full_like.out(Tensor self, Scalar fill_value, *, MemoryFormat? memory_format=None, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & from_file_out(c10::string_view filename, ::std::optional shared, ::std::optional size, at::Tensor & out); // {"schema": "aten::from_file.out(str filename, bool? shared=None, int? size=0, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & grid_sampler_2d_out(const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners, at::Tensor & out); // {"schema": "aten::grid_sampler_2d.out(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +::std::tuple grid_sampler_2d_backward_out(const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners, ::std::array output_mask, at::Tensor & out0, at::Tensor & out1); // {"schema": "aten::grid_sampler_2d_backward.out(Tensor grad_output, Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners, bool[2] output_mask, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "True"} +at::Tensor & _grid_sampler_2d_cpu_fallback_out(const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners, at::Tensor & out); // {"schema": "aten::_grid_sampler_2d_cpu_fallback.out(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & grid_sampler_3d_out(const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners, at::Tensor & out); // {"schema": "aten::grid_sampler_3d.out(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +::std::tuple grid_sampler_3d_backward_out(const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners, ::std::array output_mask, at::Tensor & out0, at::Tensor & out1); // {"schema": "aten::grid_sampler_3d_backward.out(Tensor grad_output, Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners, bool[2] output_mask, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "True"} +at::Tensor & hann_window_out(int64_t window_length, at::Tensor & out); // {"schema": "aten::hann_window.out(int window_length, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & hann_window_out(int64_t window_length, bool periodic, at::Tensor & out); // {"schema": "aten::hann_window.periodic_out(int window_length, bool periodic, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & hamming_window_out(int64_t window_length, at::Tensor & out); // {"schema": "aten::hamming_window.out(int window_length, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & hamming_window_out(int64_t window_length, bool periodic, at::Tensor & out); // {"schema": "aten::hamming_window.periodic_out(int window_length, bool periodic, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & hamming_window_out(int64_t window_length, bool periodic, double alpha, at::Tensor & out); // {"schema": "aten::hamming_window.periodic_alpha_out(int window_length, bool periodic, float alpha, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & hamming_window_out(int64_t window_length, bool periodic, double alpha, double beta, at::Tensor & out); // {"schema": "aten::hamming_window.periodic_alpha_beta_out(int window_length, bool periodic, float alpha, float beta, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & kaiser_window_out(int64_t window_length, at::Tensor & out); // {"schema": "aten::kaiser_window.out(int window_length, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & kaiser_window_out(int64_t window_length, bool periodic, at::Tensor & out); // {"schema": "aten::kaiser_window.periodic_out(int window_length, bool periodic, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & kaiser_window_out(int64_t window_length, bool periodic, double beta, at::Tensor & out); // {"schema": "aten::kaiser_window.beta_out(int window_length, bool periodic, float beta, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +::std::tuple native_group_norm_out(const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, c10::SymInt N, c10::SymInt C, c10::SymInt HxW, int64_t group, double eps, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2); // {"schema": "aten::native_group_norm.out(Tensor input, Tensor? weight, Tensor? bias, SymInt N, SymInt C, SymInt HxW, int group, float eps, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) 
out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))", "dispatch": "True", "default": "True"} +::std::tuple native_group_norm_backward_out(const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & rstd, const ::std::optional & weight, c10::SymInt N, c10::SymInt C, c10::SymInt HxW, int64_t group, ::std::array output_mask, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2); // {"schema": "aten::native_group_norm_backward.out(Tensor grad_out, Tensor input, Tensor mean, Tensor rstd, Tensor? weight, SymInt N, SymInt C, SymInt HxW, int group, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))", "dispatch": "True", "default": "True"} +at::Tensor & index_put_out(const at::Tensor & self, const c10::List<::std::optional> & indices, const at::Tensor & values, bool accumulate, at::Tensor & out); // {"schema": "aten::index_put.out(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _index_put_impl_out(const at::Tensor & self, const c10::List<::std::optional> & indices, const at::Tensor & values, bool accumulate, bool unsafe, at::Tensor & out); // {"schema": "aten::_index_put_impl.out(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False, bool unsafe=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor _index_put_impl(const at::Tensor & self, const c10::List<::std::optional> & indices, const at::Tensor & values, bool accumulate, bool unsafe); // {"schema": "aten::_index_put_impl(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False, bool unsafe=False) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & isnan_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::isnan.out(Tensor self, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +::std::tuple native_layer_norm_out(const at::Tensor & input, c10::SymIntArrayRef normalized_shape, const ::std::optional & weight, const ::std::optional & bias, double eps, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2); // {"schema": "aten::native_layer_norm.out(Tensor input, SymInt[] normalized_shape, Tensor? weight, Tensor? bias, float eps, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))", "dispatch": "True", "default": "True"} +::std::tuple native_layer_norm_backward_out(const at::Tensor & grad_out, const at::Tensor & input, c10::SymIntArrayRef normalized_shape, const at::Tensor & mean, const at::Tensor & rstd, const ::std::optional & weight, const ::std::optional & bias, ::std::array output_mask, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2); // {"schema": "aten::native_layer_norm_backward.out(Tensor grad_out, Tensor input, SymInt[] normalized_shape, Tensor mean, Tensor rstd, Tensor? weight, Tensor? bias, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))", "dispatch": "True", "default": "True"} +::std::tuple linear_backward_out(const at::Tensor & self, const at::Tensor & grad_output, const at::Tensor & weight, ::std::array output_mask, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2); // {"schema": "aten::linear_backward.out(Tensor self, Tensor grad_output, Tensor weight, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))", "dispatch": "True", "default": "True"} +at::Tensor & mkldnn_linear_out(const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, at::Tensor & out); // {"schema": "aten::mkldnn_linear.out(Tensor self, Tensor weight, Tensor? bias=None, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & mkldnn_linear_backward_input_out(at::IntArrayRef input_size, const at::Tensor & grad_output, const at::Tensor & weight, at::Tensor & out); // {"schema": "aten::mkldnn_linear_backward_input.out(int[] input_size, Tensor grad_output, Tensor weight, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +::std::tuple mkldnn_linear_backward_weights_out(const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & weight, bool bias_defined, at::Tensor & out0, at::Tensor & out1); // {"schema": "aten::mkldnn_linear_backward_weights.out(Tensor grad_output, Tensor input, Tensor weight, bool bias_defined, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "True"} +::std::tuple mkldnn_linear_backward_out(const at::Tensor & self, const at::Tensor & grad_output, const at::Tensor & weight, ::std::array output_mask, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2); // {"schema": "aten::mkldnn_linear_backward.out(Tensor self, Tensor grad_output, Tensor weight, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))", "dispatch": "True", "default": "True"} +::std::tuple matmul_backward_out(const at::Tensor & grad, const at::Tensor & self, const at::Tensor & other, ::std::array mask, at::Tensor & out0, at::Tensor & out1); // {"schema": "aten::matmul_backward.out(Tensor grad, Tensor self, Tensor other, bool[2] mask, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "True"} +::std::tuple _aminmax_out(const at::Tensor & self, at::Tensor & out0, at::Tensor & out1); // {"schema": "aten::_aminmax.out(Tensor self, *, Tensor(a!) out0, Tensor(b!) 
out1) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "True"} +::std::tuple _aminmax_out(const at::Tensor & self, int64_t dim, bool keepdim, at::Tensor & out0, at::Tensor & out1); // {"schema": "aten::_aminmax.dim_out(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "True"} +at::Tensor & max_pool2d_backward_out(const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, at::Tensor & out); // {"schema": "aten::max_pool2d_backward.out(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & mkldnn_max_pool2d_out(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, at::Tensor & out); // {"schema": "aten::mkldnn_max_pool2d.out(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & mkldnn_max_pool2d_backward_out(const at::Tensor & grad_output, const at::Tensor & output, const at::Tensor & input, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, at::Tensor & out); // {"schema": "aten::mkldnn_max_pool2d_backward.out(Tensor grad_output, Tensor output, Tensor input, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & mkldnn_max_pool3d_out(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, at::Tensor & out); // {"schema": "aten::mkldnn_max_pool3d.out(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & mkldnn_max_pool3d_backward_out(const at::Tensor & grad_output, const at::Tensor & output, const at::Tensor & input, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, at::Tensor & out); // {"schema": "aten::mkldnn_max_pool3d_backward.out(Tensor grad_output, Tensor output, Tensor input, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & quantized_max_pool1d_out(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, at::Tensor & out); // {"schema": "aten::quantized_max_pool1d.out(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, int[1] dilation=1, bool ceil_mode=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & quantized_max_pool2d_out(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, at::Tensor & out); // {"schema": "aten::quantized_max_pool2d.out(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & quantized_max_pool3d_out(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, at::Tensor & out); // {"schema": "aten::quantized_max_pool3d.out(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & median_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::median.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & nanmedian_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::nanmedian.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _mps_convolution_out(const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, at::Tensor & out); // {"schema": "aten::_mps_convolution.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +::std::tuple mps_convolution_backward_out(const at::Tensor & self, const at::Tensor & grad_output, const at::Tensor & weight, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, ::std::array output_mask, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2); // {"schema": "aten::mps_convolution_backward.out(Tensor self, Tensor grad_output, Tensor weight, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) 
out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))", "dispatch": "True", "default": "True"} +at::Tensor & mkldnn_convolution_out(const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, at::Tensor & out); // {"schema": "aten::mkldnn_convolution.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +::std::tuple mkldnn_rnn_layer_out(const at::Tensor & input, const at::Tensor & weight0, const at::Tensor & weight1, const at::Tensor & weight2, const at::Tensor & weight3, const at::Tensor & hx_, const at::Tensor & cx_, bool reverse, at::IntArrayRef batch_sizes, int64_t mode, int64_t hidden_size, int64_t num_layers, bool has_biases, bool bidirectional, bool batch_first, bool train, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3); // {"schema": "aten::mkldnn_rnn_layer.out(Tensor input, Tensor weight0, Tensor weight1, Tensor weight2, Tensor weight3, Tensor hx_, Tensor cx_, bool reverse, int[] batch_sizes, int mode, int hidden_size, int num_layers, bool has_biases, bool bidirectional, bool batch_first, bool train, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) 
out3) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!))", "dispatch": "True", "default": "True"} +::std::tuple mkldnn_rnn_layer_backward_out(const at::Tensor & input, const at::Tensor & weight1, const at::Tensor & weight2, const at::Tensor & weight3, const at::Tensor & weight4, const at::Tensor & hx_, const at::Tensor & cx_tmp, const at::Tensor & output, const at::Tensor & hy_, const at::Tensor & cy_, const ::std::optional & grad_output, const ::std::optional & grad_hy, const ::std::optional & grad_cy, bool reverse, int64_t mode, int64_t hidden_size, int64_t num_layers, bool has_biases, bool train, bool bidirectional, at::IntArrayRef batch_sizes, bool batch_first, const at::Tensor & workspace, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3, at::Tensor & out4, at::Tensor & out5, at::Tensor & out6); // {"schema": "aten::mkldnn_rnn_layer_backward.out(Tensor input, Tensor weight1, Tensor weight2, Tensor weight3, Tensor weight4, Tensor hx_, Tensor cx_tmp, Tensor output, Tensor hy_, Tensor cy_, Tensor? grad_output, Tensor? grad_hy, Tensor? grad_cy, bool reverse, int mode, int hidden_size, int num_layers, bool has_biases, bool train, bool bidirectional, int[] batch_sizes, bool batch_first, Tensor workspace, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3, Tensor(e!) out4, Tensor(f!) out5, Tensor(g!) out6) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!), Tensor(e!), Tensor(f!), Tensor(g!))", "dispatch": "True", "default": "True"} +::std::tuple miopen_batch_norm_out(const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, const ::std::optional & running_mean, const ::std::optional & running_var, bool training, double exponential_average_factor, double epsilon, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2); // {"schema": "aten::miopen_batch_norm.out(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? 
running_var, bool training, float exponential_average_factor, float epsilon, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))", "dispatch": "True", "default": "True"} +::std::tuple miopen_batch_norm_backward_out(const at::Tensor & input, const at::Tensor & grad_output, const at::Tensor & weight, const ::std::optional & running_mean, const ::std::optional & running_var, const ::std::optional & save_mean, const ::std::optional & save_var, double epsilon, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2); // {"schema": "aten::miopen_batch_norm_backward.out(Tensor input, Tensor grad_output, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, float epsilon, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))", "dispatch": "True", "default": "True"} +at::Tensor & miopen_convolution_out(const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, bool benchmark, bool deterministic, at::Tensor & out); // {"schema": "aten::miopen_convolution.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & miopen_convolution_transpose_out(const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, bool benchmark, bool deterministic, at::Tensor & out); // {"schema": "aten::miopen_convolution_transpose.out(Tensor self, Tensor weight, Tensor? 
bias, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & miopen_depthwise_convolution_out(const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, bool benchmark, bool deterministic, at::Tensor & out); // {"schema": "aten::miopen_depthwise_convolution.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +::std::tuple miopen_rnn_out(const at::Tensor & input, at::TensorList weight, int64_t weight_stride0, const at::Tensor & hx, const ::std::optional & cx, int64_t mode, int64_t hidden_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, at::IntArrayRef batch_sizes, const ::std::optional & dropout_state, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3, at::Tensor & out4); // {"schema": "aten::miopen_rnn.out(Tensor input, Tensor[] weight, int weight_stride0, Tensor hx, Tensor? cx, int mode, int hidden_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, int[] batch_sizes, Tensor? dropout_state, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3, Tensor(e!) 
out4) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!), Tensor(e!))", "dispatch": "True", "default": "True"} +void miopen_rnn_backward_out(const at::Tensor & input, at::TensorList weight, int64_t weight_stride0, const at::Tensor & weight_buf, const at::Tensor & hx, const ::std::optional & cx, const at::Tensor & output, const ::std::optional & grad_output, const ::std::optional & grad_hy, const ::std::optional & grad_cy, int64_t mode, int64_t hidden_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, at::IntArrayRef batch_sizes, const ::std::optional & dropout_state, const at::Tensor & reserve, ::std::array output_mask, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::TensorList out3); // {"schema": "aten::miopen_rnn_backward.out(Tensor input, Tensor[] weight, int weight_stride0, Tensor weight_buf, Tensor hx, Tensor? cx, Tensor output, Tensor? grad_output, Tensor? grad_hy, Tensor? grad_cy, int mode, int hidden_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, int[] batch_sizes, Tensor? dropout_state, Tensor reserve, bool[4] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!)[] out3) -> ()", "dispatch": "True", "default": "True"} +at::Tensor & _sparse_sparse_matmul_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::_sparse_sparse_matmul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & mul_out(const at::Tensor & self, const at::Scalar & other, at::Tensor & out); // {"schema": "aten::mul.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +::std::tuple _native_batch_norm_legit_functional(const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const at::Tensor & running_mean, const at::Tensor & running_var, bool training, double momentum, double eps); // {"schema": "aten::_native_batch_norm_legit_functional(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor, Tensor running_mean_out, Tensor running_var_out)", "dispatch": "True", "default": "True"} +::std::tuple _native_batch_norm_legit_no_training_out(const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const at::Tensor & running_mean, const at::Tensor & running_var, double momentum, double eps, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2); // {"schema": "aten::_native_batch_norm_legit_no_training.out(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, float momentum, float eps, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))", "dispatch": "True", "default": "True"} +::std::tuple batch_norm_stats_out(const at::Tensor & input, double eps, at::Tensor & out0, at::Tensor & out1); // {"schema": "aten::batch_norm_stats.out(Tensor input, float eps, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "True"} +::std::tuple batch_norm_gather_stats_out(const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & running_mean, const ::std::optional & running_var, double momentum, double eps, int64_t count, at::Tensor & out0, at::Tensor & out1); // {"schema": "aten::batch_norm_gather_stats.out(Tensor input, Tensor mean, Tensor invstd, Tensor? running_mean, Tensor? running_var, float momentum, float eps, int count, *, Tensor(a!) out0, Tensor(b!) 
out1) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "True"} +::std::tuple batch_norm_gather_stats_with_counts_out(const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & running_mean, const ::std::optional & running_var, double momentum, double eps, const at::Tensor & counts, at::Tensor & out0, at::Tensor & out1); // {"schema": "aten::batch_norm_gather_stats_with_counts.out(Tensor input, Tensor mean, Tensor invstd, Tensor? running_mean, Tensor? running_var, float momentum, float eps, Tensor counts, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "True"} +::std::tuple native_batch_norm_backward_out(const at::Tensor & grad_out, const at::Tensor & input, const ::std::optional & weight, const ::std::optional & running_mean, const ::std::optional & running_var, const ::std::optional & save_mean, const ::std::optional & save_invstd, bool train, double eps, ::std::array output_mask, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2); // {"schema": "aten::native_batch_norm_backward.out(Tensor grad_out, Tensor input, Tensor? weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_invstd, bool train, float eps, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))", "dispatch": "True", "default": "True"} +::std::tuple batch_norm_backward_reduce_out(const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & weight, bool input_g, bool weight_g, bool bias_g, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3); // {"schema": "aten::batch_norm_backward_reduce.out(Tensor grad_out, Tensor input, Tensor mean, Tensor invstd, Tensor? weight, bool input_g, bool weight_g, bool bias_g, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) 
out3) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!))", "dispatch": "True", "default": "True"} +at::Tensor & batch_norm_backward_elemt_out(const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & weight, const at::Tensor & sum_dy, const at::Tensor & sum_dy_xmu, const at::Tensor & count, at::Tensor & out); // {"schema": "aten::batch_norm_backward_elemt.out(Tensor grad_out, Tensor input, Tensor mean, Tensor invstd, Tensor? weight, Tensor sum_dy, Tensor sum_dy_xmu, Tensor count, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +::std::tuple batch_norm_update_stats_out(const at::Tensor & input, const ::std::optional & running_mean, const ::std::optional & running_var, double momentum, at::Tensor & out0, at::Tensor & out1); // {"schema": "aten::batch_norm_update_stats.out(Tensor input, Tensor? running_mean, Tensor? running_var, float momentum, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "True"} +at::Tensor & _nnpack_spatial_convolution_out(const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, at::Tensor & out); // {"schema": "aten::_nnpack_spatial_convolution.out(Tensor input, Tensor weight, Tensor? bias, SymInt[2] padding, SymInt[2] stride=1, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & ones_out(at::IntArrayRef size, ::std::optional names, at::Tensor & out); // {"schema": "aten::ones.names_out(int[] size, *, Dimname[]? names, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & ones_like_out(const at::Tensor & self, ::std::optional memory_format, at::Tensor & out); // {"schema": "aten::ones_like.out(Tensor self, *, MemoryFormat? memory_format=None, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _euclidean_dist_out(const at::Tensor & x1, const at::Tensor & x2, at::Tensor & out); // {"schema": "aten::_euclidean_dist.out(Tensor x1, Tensor x2, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _cdist_forward_out(const at::Tensor & x1, const at::Tensor & x2, double p, ::std::optional compute_mode, at::Tensor & out); // {"schema": "aten::_cdist_forward.out(Tensor x1, Tensor x2, float p, int? compute_mode, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _cdist_backward_out(const at::Tensor & grad, const at::Tensor & x1, const at::Tensor & x2, double p, const at::Tensor & cdist, at::Tensor & out); // {"schema": "aten::_cdist_backward.out(Tensor grad, Tensor x1, Tensor x2, float p, Tensor cdist, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _pdist_forward_out(const at::Tensor & self, double p, at::Tensor & out); // {"schema": "aten::_pdist_forward.out(Tensor self, float p=2, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _pdist_backward_out(const at::Tensor & grad, const at::Tensor & self, double p, const at::Tensor & pdist, at::Tensor & out); // {"schema": "aten::_pdist_backward.out(Tensor grad, Tensor self, float p, Tensor pdist, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & pixel_shuffle_out(const at::Tensor & self, int64_t upscale_factor, at::Tensor & out); // {"schema": "aten::pixel_shuffle.out(Tensor self, int upscale_factor, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & pixel_unshuffle_out(const at::Tensor & self, int64_t downscale_factor, at::Tensor & out); // {"schema": "aten::pixel_unshuffle.out(Tensor self, int downscale_factor, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & channel_shuffle_out(const at::Tensor & self, c10::SymInt groups, at::Tensor & out); // {"schema": "aten::channel_shuffle.out(Tensor self, SymInt groups, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _pin_memory_out(const at::Tensor & self, ::std::optional device, at::Tensor & out); // {"schema": "aten::_pin_memory.out(Tensor self, Device? device=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & scalar_tensor_out(const at::Scalar & s, at::Tensor & out); // {"schema": "aten::scalar_tensor.out(Scalar s, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & rand_out(c10::SymIntArrayRef size, ::std::optional names, at::Tensor & out); // {"schema": "aten::rand.names_out(SymInt[] size, *, Dimname[]? names, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & rand_out(c10::SymIntArrayRef size, ::std::optional generator, ::std::optional names, at::Tensor & out); // {"schema": "aten::rand.generator_with_names_out(SymInt[] size, *, Generator? generator, Dimname[]? names, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & rand_like_out(const at::Tensor & self, ::std::optional memory_format, at::Tensor & out); // {"schema": "aten::rand_like.out(Tensor self, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & rand_like_out(const at::Tensor & self, ::std::optional generator, ::std::optional memory_format, at::Tensor & out); // {"schema": "aten::rand_like.generator_out(Tensor self, *, Generator? generator, MemoryFormat? memory_format=None, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & randint_like_out(const at::Tensor & self, c10::SymInt high, ::std::optional memory_format, at::Tensor & out); // {"schema": "aten::randint_like.out(Tensor self, SymInt high, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & randint_like_out(const at::Tensor & self, c10::SymInt high, ::std::optional generator, ::std::optional memory_format, at::Tensor & out); // {"schema": "aten::randint_like.generator_out(Tensor self, SymInt high, *, Generator? generator, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & randint_like_out(const at::Tensor & self, const at::Tensor & high, ::std::optional memory_format, at::Tensor & out); // {"schema": "aten::randint_like.Tensor_out(Tensor self, Tensor high, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & randint_like_out(const at::Tensor & self, const at::Tensor & high, ::std::optional generator, ::std::optional memory_format, at::Tensor & out); // {"schema": "aten::randint_like.Tensor_generator_out(Tensor self, Tensor high, *, Generator? generator, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & randint_like_out(const at::Tensor & self, c10::SymInt low, c10::SymInt high, ::std::optional memory_format, at::Tensor & out); // {"schema": "aten::randint_like.low_dtype_out(Tensor self, SymInt low, SymInt high, *, MemoryFormat? memory_format=None, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & randint_like_out(const at::Tensor & self, c10::SymInt low, c10::SymInt high, ::std::optional generator, ::std::optional memory_format, at::Tensor & out); // {"schema": "aten::randint_like.low_generator_dtype_out(Tensor self, SymInt low, SymInt high, *, Generator? generator, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & randn_out(c10::SymIntArrayRef size, ::std::optional names, at::Tensor & out); // {"schema": "aten::randn.names_out(SymInt[] size, *, Dimname[]? names, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & randn_out(c10::SymIntArrayRef size, ::std::optional generator, ::std::optional names, at::Tensor & out); // {"schema": "aten::randn.generator_with_names_out(SymInt[] size, *, Generator? generator, Dimname[]? names, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & randn_like_out(const at::Tensor & self, ::std::optional memory_format, at::Tensor & out); // {"schema": "aten::randn_like.out(Tensor self, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & randn_like_out(const at::Tensor & self, ::std::optional generator, ::std::optional memory_format, at::Tensor & out); // {"schema": "aten::randn_like.generator_out(Tensor self, *, Generator? generator, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & repeat_out(const at::Tensor & self, c10::SymIntArrayRef repeats, at::Tensor & out); // {"schema": "aten::repeat.out(Tensor self, SymInt[] repeats, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & repeat_interleave_out(const at::Tensor & repeats, ::std::optional output_size, at::Tensor & out); // {"schema": "aten::repeat_interleave.Tensor_out(Tensor repeats, *, SymInt? 
output_size=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _mkldnn_reshape_out(const at::Tensor & self, at::IntArrayRef shape, at::Tensor & out); // {"schema": "aten::_mkldnn_reshape.out(Tensor self, int[] shape, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & relu_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::relu.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & select_backward_out(const at::Tensor & grad_output, c10::SymIntArrayRef input_sizes, int64_t dim, c10::SymInt index, at::Tensor & out); // {"schema": "aten::select_backward.out(Tensor grad_output, SymInt[] input_sizes, int dim, SymInt index, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & celu_out(const at::Tensor & self, const at::Scalar & alpha, at::Tensor & out); // {"schema": "aten::celu.out(Tensor self, Scalar alpha=1.0, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & slice_backward_out(const at::Tensor & grad_output, c10::SymIntArrayRef input_sizes, int64_t dim, c10::SymInt start, c10::SymInt end, c10::SymInt step, at::Tensor & out); // {"schema": "aten::slice_backward.out(Tensor grad_output, SymInt[] input_sizes, int dim, SymInt start, SymInt end, SymInt step, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & slice_scatter_out(const at::Tensor & self, const at::Tensor & src, int64_t dim, ::std::optional start, ::std::optional end, c10::SymInt step, at::Tensor & out); // {"schema": "aten::slice_scatter.out(Tensor self, Tensor src, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & select_scatter_out(const at::Tensor & self, const at::Tensor & src, int64_t dim, c10::SymInt index, at::Tensor & out); // {"schema": "aten::select_scatter.out(Tensor self, Tensor src, int dim, SymInt index, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & diagonal_scatter_out(const at::Tensor & self, const at::Tensor & src, int64_t offset, int64_t dim1, int64_t dim2, at::Tensor & out); // {"schema": "aten::diagonal_scatter.out(Tensor self, Tensor src, int offset=0, int dim1=0, int dim2=1, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & as_strided_scatter_out(const at::Tensor & self, const at::Tensor & src, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset, at::Tensor & out); // {"schema": "aten::as_strided_scatter.out(Tensor self, Tensor src, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +void unsafe_split_out(const at::Tensor & self, c10::SymInt split_size, int64_t dim, at::TensorList out); // {"schema": "aten::unsafe_split.Tensor_out(Tensor self, SymInt split_size, int dim=0, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void unsafe_split_with_sizes_out(const at::Tensor & self, c10::SymIntArrayRef split_sizes, int64_t dim, at::TensorList out); // {"schema": "aten::unsafe_split_with_sizes.out(Tensor self, SymInt[] split_sizes, int dim=0, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +at::Tensor & sum_out(const at::Tensor & self, ::std::optional dtype, at::Tensor & out); // {"schema": "aten::sum.out(Tensor self, *, ScalarType? dtype=None, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +::std::tuple std_mean_out(const at::Tensor & self, at::OptionalIntArrayRef dim, const ::std::optional & correction, bool keepdim, at::Tensor & out0, at::Tensor & out1); // {"schema": "aten::std_mean.correction_out(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "True"} +at::Tensor & prod_out(const at::Tensor & self, ::std::optional dtype, at::Tensor & out); // {"schema": "aten::prod.out(Tensor self, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _mkldnn_transpose_out(const at::Tensor & self, int64_t dim0, int64_t dim1, at::Tensor & out); // {"schema": "aten::_mkldnn_transpose.out(Tensor self, int dim0, int dim1, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & flip_out(const at::Tensor & self, at::IntArrayRef dims, at::Tensor & out); // {"schema": "aten::flip.out(Tensor self, int[] dims, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & roll_out(const at::Tensor & self, c10::SymIntArrayRef shifts, at::IntArrayRef dims, at::Tensor & out); // {"schema": "aten::roll.out(Tensor self, SymInt[1] shifts, int[1] dims=[], *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & rot90_out(const at::Tensor & self, int64_t k, at::IntArrayRef dims, at::Tensor & out); // {"schema": "aten::rot90.out(Tensor self, int k=1, int[] dims=[0,1], *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +::std::tuple _transform_bias_rescale_qkv_out(const at::Tensor & qkv, const at::Tensor & qkv_bias, int64_t num_heads, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2); // {"schema": "aten::_transform_bias_rescale_qkv.out(Tensor qkv, Tensor qkv_bias, int num_heads, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) 
out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))", "dispatch": "True", "default": "True"} +at::Tensor & _nested_tensor_from_mask_out(const at::Tensor & t, const at::Tensor & mask, bool mask_check, at::Tensor & out); // {"schema": "aten::_nested_tensor_from_mask.out(Tensor t, Tensor mask, bool mask_check=True, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _nested_from_padded_out(const at::Tensor & padded, const at::Tensor & cpu_nested_shape_example, bool fuse_transform_0213, at::Tensor & out); // {"schema": "aten::_nested_from_padded.out(Tensor padded, Tensor cpu_nested_shape_example, bool fuse_transform_0213=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _nested_tensor_size_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::_nested_tensor_size.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _nested_tensor_strides_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::_nested_tensor_strides.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _nested_tensor_storage_offsets_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::_nested_tensor_storage_offsets.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _nested_from_padded_and_nested_example_out(const at::Tensor & padded, const at::Tensor & nt_example, at::Tensor & out); // {"schema": "aten::_nested_from_padded_and_nested_example.out(Tensor padded, Tensor nt_example, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _nested_view_from_buffer_copy_out(const at::Tensor & self, const at::Tensor & nested_size, const at::Tensor & nested_strides, const at::Tensor & offsets, at::Tensor & out); // {"schema": "aten::_nested_view_from_buffer_copy.out(Tensor self, Tensor nested_size, Tensor nested_strides, Tensor offsets, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _nested_view_from_jagged_copy_out(const at::Tensor & self, const at::Tensor & offsets, const at::Tensor & dummy, const ::std::optional & lengths, int64_t ragged_idx, const ::std::optional & min_seqlen, const ::std::optional & max_seqlen, at::Tensor & out); // {"schema": "aten::_nested_view_from_jagged_copy.out(Tensor self, Tensor offsets, Tensor dummy, Tensor? lengths=None, int ragged_idx=1, Tensor? min_seqlen=None, Tensor? max_seqlen=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _nested_get_values_copy_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::_nested_get_values_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _trilinear_out(const at::Tensor & i1, const at::Tensor & i2, const at::Tensor & i3, at::IntArrayRef expand1, at::IntArrayRef expand2, at::IntArrayRef expand3, at::IntArrayRef sumdim, int64_t unroll_dim, at::Tensor & out); // {"schema": "aten::_trilinear.out(Tensor i1, Tensor i2, Tensor i3, int[] expand1, int[] expand2, int[] expand3, int[] sumdim, int unroll_dim=1, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +::std::tuple _unique_out(const at::Tensor & self, bool sorted, bool return_inverse, at::Tensor & out0, at::Tensor & out1); // {"schema": "aten::_unique.out(Tensor self, bool sorted=True, bool return_inverse=False, *, Tensor(a!) out0, Tensor(b!) 
out1) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "True"} +::std::tuple unique_dim_out(const at::Tensor & self, int64_t dim, bool sorted, bool return_inverse, bool return_counts, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2); // {"schema": "aten::unique_dim.out(Tensor self, int dim, bool sorted=True, bool return_inverse=False, bool return_counts=False, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))", "dispatch": "True", "default": "True"} +::std::tuple unique_consecutive_out(const at::Tensor & self, bool return_inverse, bool return_counts, ::std::optional dim, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2); // {"schema": "aten::unique_consecutive.out(Tensor self, bool return_inverse=False, bool return_counts=False, int? dim=None, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))", "dispatch": "True", "default": "True"} +::std::tuple unique_dim_consecutive_out(const at::Tensor & self, int64_t dim, bool return_inverse, bool return_counts, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2); // {"schema": "aten::unique_dim_consecutive.out(Tensor self, int dim, bool return_inverse=False, bool return_counts=False, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))", "dispatch": "True", "default": "True"} +::std::tuple _unique2_out(const at::Tensor & self, bool sorted, bool return_inverse, bool return_counts, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2); // {"schema": "aten::_unique2.out(Tensor self, bool sorted=True, bool return_inverse=False, bool return_counts=False, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))", "dispatch": "True", "default": "True"} +at::Tensor & _unsafe_view_out(const at::Tensor & self, c10::SymIntArrayRef size, at::Tensor & out); // {"schema": "aten::_unsafe_view.out(Tensor self, SymInt[] size, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +::std::tuple var_mean_out(const at::Tensor & self, at::OptionalIntArrayRef dim, const ::std::optional & correction, bool keepdim, at::Tensor & out0, at::Tensor & out1); // {"schema": "aten::var_mean.correction_out(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "True"} +::std::tuple _weight_norm_interface_out(const at::Tensor & v, const at::Tensor & g, int64_t dim, at::Tensor & out0, at::Tensor & out1); // {"schema": "aten::_weight_norm_interface.out(Tensor v, Tensor g, int dim=0, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "True"} +::std::tuple _weight_norm_interface_backward_out(const at::Tensor & grad_w, const at::Tensor & saved_v, const at::Tensor & saved_g, const at::Tensor & saved_norms, int64_t dim, at::Tensor & out0, at::Tensor & out1); // {"schema": "aten::_weight_norm_interface_backward.out(Tensor grad_w, Tensor saved_v, Tensor saved_g, Tensor saved_norms, int dim, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "True"} +at::Tensor & zeros_out(at::IntArrayRef size, ::std::optional names, at::Tensor & out); // {"schema": "aten::zeros.names_out(int[] size, *, Dimname[]? names, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _efficientzerotensor_out(c10::SymIntArrayRef size, at::Tensor & out); // {"schema": "aten::_efficientzerotensor.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & zeros_like_out(const at::Tensor & self, ::std::optional memory_format, at::Tensor & out); // {"schema": "aten::zeros_like.out(Tensor self, *, MemoryFormat? memory_format=None, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _standard_gamma_grad_out(const at::Tensor & self, const at::Tensor & output, at::Tensor & out); // {"schema": "aten::_standard_gamma_grad.out(Tensor self, Tensor output, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _standard_gamma_out(const at::Tensor & self, ::std::optional generator, at::Tensor & out); // {"schema": "aten::_standard_gamma.out(Tensor self, Generator? generator=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _dirichlet_grad_out(const at::Tensor & x, const at::Tensor & alpha, const at::Tensor & total, at::Tensor & out); // {"schema": "aten::_dirichlet_grad.out(Tensor x, Tensor alpha, Tensor total, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _sample_dirichlet_out(const at::Tensor & self, ::std::optional generator, at::Tensor & out); // {"schema": "aten::_sample_dirichlet.out(Tensor self, Generator? generator=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & poisson_out(const at::Tensor & self, ::std::optional generator, at::Tensor & out); // {"schema": "aten::poisson.out(Tensor self, Generator? generator=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & binomial_out(const at::Tensor & count, const at::Tensor & prob, ::std::optional generator, at::Tensor & out); // {"schema": "aten::binomial.out(Tensor count, Tensor prob, Generator? generator=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & native_norm_out(const at::Tensor & self, const at::Scalar & p, at::Tensor & out); // {"schema": "aten::native_norm.out(Tensor self, Scalar p=2, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & native_norm_out(const at::Tensor & self, const ::std::optional & p, at::IntArrayRef dim, bool keepdim, ::std::optional dtype, at::Tensor & out); // {"schema": "aten::native_norm.ScalarOpt_dim_dtype_out(Tensor self, Scalar? p, int[1] dim, bool keepdim, ScalarType? dtype, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +::std::tuple _batch_norm_with_update_functional(const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const at::Tensor & running_mean, const at::Tensor & running_var, double momentum, double eps); // {"schema": "aten::_batch_norm_with_update_functional(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor, Tensor, Tensor running_mean_out, Tensor running_var_out)", "dispatch": "True", "default": "True"} +::std::tuple _batch_norm_no_update_out(const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const ::std::optional & running_mean, const ::std::optional & running_var, double momentum, double eps, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3); // {"schema": "aten::_batch_norm_no_update.out(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, float momentum, float eps, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!))", "dispatch": "True", "default": "True"} +at::Tensor & _sparse_sum_out(const at::Tensor & self, at::IntArrayRef dim, at::Tensor & out); // {"schema": "aten::_sparse_sum.dim_out(Tensor self, int[1] dim, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _sparse_sum_backward_out(const at::Tensor & grad, const at::Tensor & self, at::IntArrayRef dim, at::Tensor & out); // {"schema": "aten::_sparse_sum_backward.out(Tensor grad, Tensor self, int[] dim, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _sparse_csr_sum_out(const at::Tensor & self, at::IntArrayRef dim, bool keepdim, ::std::optional dtype, at::Tensor & out); // {"schema": "aten::_sparse_csr_sum.dim_dtype_out(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _sparse_csr_prod_out(const at::Tensor & self, at::IntArrayRef dim, bool keepdim, ::std::optional dtype, at::Tensor & out); // {"schema": "aten::_sparse_csr_prod.dim_dtype_out(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _sparse_softmax_out(const at::Tensor & self, int64_t dim, bool half_to_float, at::Tensor & out); // {"schema": "aten::_sparse_softmax.out(Tensor self, int dim, bool half_to_float, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _sparse_softmax_backward_data_out(const at::Tensor & grad_output, const at::Tensor & output, int64_t dim, const at::Tensor & self, at::Tensor & out); // {"schema": "aten::_sparse_softmax_backward_data.out(Tensor grad_output, Tensor output, int dim, Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _sparse_log_softmax_out(const at::Tensor & self, int64_t dim, bool half_to_float, at::Tensor & out); // {"schema": "aten::_sparse_log_softmax.out(Tensor self, int dim, bool half_to_float, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _sparse_log_softmax_backward_data_out(const at::Tensor & grad_output, const at::Tensor & output, int64_t dim, const at::Tensor & self, at::Tensor & out); // {"schema": "aten::_sparse_log_softmax_backward_data.out(Tensor grad_output, Tensor output, int dim, Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _spdiags_out(const at::Tensor & diagonals, const at::Tensor & offsets, at::IntArrayRef shape, ::std::optional layout, at::Tensor & out); // {"schema": "aten::_spdiags.out(Tensor diagonals, Tensor offsets, int[] shape, Layout? layout=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & norm_out(const at::Tensor & self, const ::std::optional & p, at::ScalarType dtype, at::Tensor & out); // {"schema": "aten::norm.ScalarOpt_dtype_out(Tensor self, Scalar? p, *, ScalarType dtype, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & norm_out(const at::Tensor & self, const at::Scalar & p, at::Tensor & out); // {"schema": "aten::norm.Scalar_out(Tensor self, Scalar p=2, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & clone_out(const at::Tensor & self, ::std::optional memory_format, at::Tensor & out); // {"schema": "aten::clone.out(Tensor self, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +const at::Tensor & resize_as_out(const at::Tensor & self, const at::Tensor & the_template, ::std::optional memory_format, const at::Tensor & out); // {"schema": "aten::resize_as.out(Tensor self, Tensor the_template, *, MemoryFormat? memory_format=None, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor resize_as(const at::Tensor & self, const at::Tensor & the_template, ::std::optional memory_format); // {"schema": "aten::resize_as(Tensor self, Tensor the_template, *, MemoryFormat? memory_format=None) -> Tensor", "dispatch": "True", "default": "True"} +const at::Tensor & resize_as_sparse_out(const at::Tensor & self, const at::Tensor & the_template, const at::Tensor & out); // {"schema": "aten::resize_as_sparse.out(Tensor self, Tensor the_template, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor resize_as_sparse(const at::Tensor & self, const at::Tensor & the_template); // {"schema": "aten::resize_as_sparse(Tensor self, Tensor the_template) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & zero_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::zero.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor zero(const at::Tensor & self); // {"schema": "aten::zero(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & sub_out(const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha, at::Tensor & out); // {"schema": "aten::sub.Scalar_out(Tensor self, Scalar other, Scalar alpha=1, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & rsub_out(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha, at::Tensor & out); // {"schema": "aten::rsub.Tensor_out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & rsub_out(const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha, at::Tensor & out); // {"schema": "aten::rsub.Scalar_out(Tensor self, Scalar other, Scalar alpha=1, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _sparse_addmm_out(const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out); // {"schema": "aten::_sparse_addmm.out(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & sparse_coo_tensor_out(at::IntArrayRef size, at::Tensor & out); // {"schema": "aten::sparse_coo_tensor.size_out(int[] size, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _sparse_coo_tensor_with_dims_out(int64_t sparse_dim, int64_t dense_dim, at::IntArrayRef size, at::Tensor & out); // {"schema": "aten::_sparse_coo_tensor_with_dims.out(int sparse_dim, int dense_dim, int[] size, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _sparse_coo_tensor_with_dims_and_tensors_out(int64_t sparse_dim, int64_t dense_dim, c10::SymIntArrayRef size, const at::Tensor & indices, const at::Tensor & values, ::std::optional is_coalesced, at::Tensor & out); // {"schema": "aten::_sparse_coo_tensor_with_dims_and_tensors.out(int sparse_dim, int dense_dim, SymInt[] size, Tensor indices, Tensor values, *, bool? is_coalesced=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +const at::Tensor & sparse_resize_out(const at::Tensor & self, at::IntArrayRef size, int64_t sparse_dim, int64_t dense_dim, const at::Tensor & out); // {"schema": "aten::sparse_resize.out(Tensor self, int[] size, int sparse_dim, int dense_dim, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor sparse_resize(const at::Tensor & self, at::IntArrayRef size, int64_t sparse_dim, int64_t dense_dim); // {"schema": "aten::sparse_resize(Tensor self, int[] size, int sparse_dim, int dense_dim) -> Tensor", "dispatch": "True", "default": "True"} +const at::Tensor & sparse_resize_and_clear_out(const at::Tensor & self, at::IntArrayRef size, int64_t sparse_dim, int64_t dense_dim, const at::Tensor & out); // {"schema": "aten::sparse_resize_and_clear.out(Tensor self, int[] size, int sparse_dim, int dense_dim, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor sparse_resize_and_clear(const at::Tensor & self, at::IntArrayRef size, int64_t sparse_dim, int64_t dense_dim); // {"schema": "aten::sparse_resize_and_clear(Tensor self, int[] size, int sparse_dim, int dense_dim) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & sparse_mask_out(const at::Tensor & self, const at::Tensor & mask, at::Tensor & out); // {"schema": "aten::sparse_mask.out(Tensor self, Tensor mask, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _sparse_mask_projection_out(const at::Tensor & self, const at::Tensor & mask, bool accumulate_matches, at::Tensor & out); // {"schema": "aten::_sparse_mask_projection.out(Tensor self, Tensor mask, bool accumulate_matches=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _to_dense_out(const at::Tensor & self, ::std::optional dtype, ::std::optional masked_grad, at::Tensor & out); // {"schema": "aten::_to_dense.out(Tensor self, ScalarType? dtype=None, bool? masked_grad=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _coalesce_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::_coalesce.out(Tensor self, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _coalesced_out(const at::Tensor & self, bool coalesced, at::Tensor & out); // {"schema": "aten::_coalesced.out(Tensor self, bool coalesced, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor _coalesced(const at::Tensor & self, bool coalesced); // {"schema": "aten::_coalesced(Tensor self, bool coalesced) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & copy_sparse_to_sparse_out(const at::Tensor & self, const at::Tensor & src, bool non_blocking, at::Tensor & out); // {"schema": "aten::copy_sparse_to_sparse.out(Tensor self, Tensor src, bool non_blocking=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor copy_sparse_to_sparse(const at::Tensor & self, const at::Tensor & src, bool non_blocking); // {"schema": "aten::copy_sparse_to_sparse(Tensor self, Tensor src, bool non_blocking=False) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & _to_sparse_out(const at::Tensor & self, int64_t sparse_dim, at::Tensor & out); // {"schema": "aten::_to_sparse.sparse_dim_out(Tensor self, int sparse_dim, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _to_sparse_out(const at::Tensor & self, ::std::optional layout, at::OptionalIntArrayRef blocksize, ::std::optional dense_dim, at::Tensor & out); // {"schema": "aten::_to_sparse.out(Tensor self, *, Layout? layout=None, int[2]? blocksize=None, int? dense_dim=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _to_sparse_csr_out(const at::Tensor & self, ::std::optional dense_dim, at::Tensor & out); // {"schema": "aten::_to_sparse_csr.out(Tensor self, int? dense_dim=None, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _to_sparse_csc_out(const at::Tensor & self, ::std::optional dense_dim, at::Tensor & out); // {"schema": "aten::_to_sparse_csc.out(Tensor self, int? dense_dim=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _to_sparse_bsr_out(const at::Tensor & self, at::IntArrayRef blocksize, ::std::optional dense_dim, at::Tensor & out); // {"schema": "aten::_to_sparse_bsr.out(Tensor self, int[2] blocksize, int? dense_dim=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _to_sparse_bsc_out(const at::Tensor & self, at::IntArrayRef blocksize, ::std::optional dense_dim, at::Tensor & out); // {"schema": "aten::_to_sparse_bsc.out(Tensor self, int[2] blocksize, int? dense_dim=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & to_mkldnn_out(const at::Tensor & self, ::std::optional dtype, at::Tensor & out); // {"schema": "aten::to_mkldnn.out(Tensor self, ScalarType? dtype=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & mkldnn_reorder_conv2d_weight_out(const at::Tensor & self, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, at::OptionalSymIntArrayRef input_size, at::Tensor & out); // {"schema": "aten::mkldnn_reorder_conv2d_weight.out(Tensor self, SymInt[2] padding=0, SymInt[2] stride=1, SymInt[2] dilation=1, SymInt groups=1, SymInt[]? input_size=None, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & mkldnn_reorder_conv3d_weight_out(const at::Tensor & self, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, at::OptionalSymIntArrayRef input_size, at::Tensor & out); // {"schema": "aten::mkldnn_reorder_conv3d_weight.out(Tensor self, SymInt[3] padding=0, SymInt[3] stride=1, SymInt[3] dilation=1, SymInt groups=1, SymInt[]? input_size=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & quantize_per_tensor_dynamic_out(const at::Tensor & self, at::ScalarType dtype, bool reduce_range, at::Tensor & out); // {"schema": "aten::quantize_per_tensor_dynamic.out(Tensor self, ScalarType dtype, bool reduce_range, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & quantize_per_tensor_out(const at::Tensor & self, double scale, int64_t zero_point, at::ScalarType dtype, at::Tensor & out); // {"schema": "aten::quantize_per_tensor.out(Tensor self, float scale, int zero_point, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & quantize_per_tensor_out(const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, at::ScalarType dtype, at::Tensor & out); // {"schema": "aten::quantize_per_tensor.tensor_qparams_out(Tensor self, Tensor scale, Tensor zero_point, ScalarType dtype, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +void quantize_per_tensor_out(at::TensorList tensors, const at::Tensor & scales, const at::Tensor & zero_points, at::ScalarType dtype, at::TensorList out); // {"schema": "aten::quantize_per_tensor.tensors_out(Tensor[] tensors, Tensor scales, Tensor zero_points, ScalarType dtype, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +at::Tensor & quantize_per_channel_out(const at::Tensor & self, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, at::ScalarType dtype, at::Tensor & out); // {"schema": "aten::quantize_per_channel.out(Tensor self, Tensor scales, Tensor zero_points, int axis, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & dequantize_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::dequantize.self_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +void dequantize_out(at::TensorList tensors, at::TensorList out); // {"schema": "aten::dequantize.tensors_out(Tensor[] tensors, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +at::Tensor & q_per_channel_scales_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::q_per_channel_scales.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & q_per_channel_zero_points_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::q_per_channel_zero_points.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & int_repr_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::int_repr.out(Tensor self, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _make_per_tensor_quantized_tensor_out(const at::Tensor & self, double scale, int64_t zero_point, at::Tensor & out); // {"schema": "aten::_make_per_tensor_quantized_tensor.out(Tensor self, float scale, int zero_point, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _make_per_channel_quantized_tensor_out(const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, int64_t axis, at::Tensor & out); // {"schema": "aten::_make_per_channel_quantized_tensor.out(Tensor self, Tensor scale, Tensor zero_point, int axis, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +::std::tuple fake_quantize_per_tensor_affine_cachemask_out(const at::Tensor & self, double scale, int64_t zero_point, int64_t quant_min, int64_t quant_max, at::Tensor & out0, at::Tensor & out1); // {"schema": "aten::fake_quantize_per_tensor_affine_cachemask.out(Tensor self, float scale, int zero_point, int quant_min, int quant_max, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "True"} +::std::tuple _fake_quantize_per_tensor_affine_cachemask_tensor_qparams_out(const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, const at::Tensor & fake_quant_enabled, int64_t quant_min, int64_t quant_max, at::Tensor & out0, at::Tensor & out1); // {"schema": "aten::_fake_quantize_per_tensor_affine_cachemask_tensor_qparams.out(Tensor self, Tensor scale, Tensor zero_point, Tensor fake_quant_enabled, int quant_min, int quant_max, *, Tensor(a!) out0, Tensor(b!) 
out1) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "True"} +at::Tensor & _fake_quantize_learnable_per_tensor_affine_out(const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, int64_t quant_min, int64_t quant_max, double grad_factor, at::Tensor & out); // {"schema": "aten::_fake_quantize_learnable_per_tensor_affine.out(Tensor self, Tensor scale, Tensor zero_point, int quant_min, int quant_max, float grad_factor=1.0, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +::std::tuple fake_quantize_per_channel_affine_cachemask_out(const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, int64_t axis, int64_t quant_min, int64_t quant_max, at::Tensor & out0, at::Tensor & out1); // {"schema": "aten::fake_quantize_per_channel_affine_cachemask.out(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "True"} +at::Tensor & _fake_quantize_learnable_per_channel_affine_out(const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, int64_t axis, int64_t quant_min, int64_t quant_max, double grad_factor, at::Tensor & out); // {"schema": "aten::_fake_quantize_learnable_per_channel_affine.out(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max, float grad_factor=1.0, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +::std::tuple _fused_moving_avg_obs_fq_helper_out(const at::Tensor & self, const at::Tensor & observer_on, const at::Tensor & fake_quant_on, at::Tensor & running_min, at::Tensor & running_max, at::Tensor & scale, at::Tensor & zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, bool per_row_fake_quant, bool symmetric_quant, at::Tensor & out0, at::Tensor & out1); // {"schema": "aten::_fused_moving_avg_obs_fq_helper.out(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor(a!) running_min, Tensor(b!) running_max, Tensor(c!) scale, Tensor(d!) zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False, *, Tensor(e!) out0, Tensor(f!) out1) -> (Tensor(e!), Tensor(f!))", "dispatch": "True", "default": "True"} +::std::tuple _fused_moving_avg_obs_fq_helper_functional(const at::Tensor & self, const at::Tensor & observer_on, const at::Tensor & fake_quant_on, const at::Tensor & running_min, const at::Tensor & running_max, const at::Tensor & scale, const at::Tensor & zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, bool per_row_fake_quant, bool symmetric_quant); // {"schema": "aten::_fused_moving_avg_obs_fq_helper_functional(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor running_min, Tensor running_max, Tensor scale, Tensor zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor output, Tensor mask, Tensor running_min_out, Tensor running_max_out, Tensor scale_out, Tensor zero_point_out)", "dispatch": "True", "default": "True"} +at::Tensor & _to_copy_out(const at::Tensor & self, bool non_blocking, ::std::optional memory_format, at::Tensor & out); // {"schema": "aten::_to_copy.out(Tensor self, *, bool non_blocking=False, MemoryFormat? 
memory_format=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +::std::tuple _lstm_mps_out(const at::Tensor & input, at::TensorList hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional, bool batch_first, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3, at::Tensor & out4, at::Tensor & out5); // {"schema": "aten::_lstm_mps.out(Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3, Tensor(e!) out4, Tensor(f!) out5) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!), Tensor(e!), Tensor(f!))", "dispatch": "True", "default": "True"} +void lstm_mps_backward_out(const ::std::optional & grad_y, const ::std::optional & grad_hy, const ::std::optional & grad_cy, const at::Tensor & z_state, const at::Tensor & cell_state_fwd, const at::Tensor & input, const at::Tensor & layersOutputs, at::TensorList hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional, bool batch_first, at::Tensor & out0, at::TensorList out1, at::TensorList out2); // {"schema": "aten::lstm_mps_backward.out(Tensor? grad_y, Tensor? grad_hy, Tensor? grad_cy, Tensor z_state, Tensor cell_state_fwd, Tensor input, Tensor layersOutputs, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first, *, Tensor(a!) 
out0, Tensor(b!)[] out1, Tensor(c!)[] out2) -> ()", "dispatch": "True", "default": "True"} +::std::tuple _thnn_fused_lstm_cell_out(const at::Tensor & input_gates, const at::Tensor & hidden_gates, const at::Tensor & cx, const ::std::optional & input_bias, const ::std::optional & hidden_bias, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2); // {"schema": "aten::_thnn_fused_lstm_cell.out(Tensor input_gates, Tensor hidden_gates, Tensor cx, Tensor? input_bias=None, Tensor? hidden_bias=None, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))", "dispatch": "True", "default": "True"} +::std::tuple _thnn_fused_lstm_cell_backward_impl_out(const ::std::optional & grad_hy, const ::std::optional & grad_cy, const at::Tensor & cx, const at::Tensor & cy, const at::Tensor & workspace, bool has_bias, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2); // {"schema": "aten::_thnn_fused_lstm_cell_backward_impl.out(Tensor? grad_hy, Tensor? grad_cy, Tensor cx, Tensor cy, Tensor workspace, bool has_bias, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))", "dispatch": "True", "default": "True"} +::std::tuple _thnn_fused_gru_cell_out(const at::Tensor & input_gates, const at::Tensor & hidden_gates, const at::Tensor & hx, const ::std::optional & input_bias, const ::std::optional & hidden_bias, at::Tensor & out0, at::Tensor & out1); // {"schema": "aten::_thnn_fused_gru_cell.out(Tensor input_gates, Tensor hidden_gates, Tensor hx, Tensor? input_bias=None, Tensor? hidden_bias=None, *, Tensor(a!) out0, Tensor(b!) 
out1) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "True"} +::std::tuple _thnn_fused_gru_cell_backward_out(const at::Tensor & grad_hy, const at::Tensor & workspace, bool has_bias, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3, at::Tensor & out4); // {"schema": "aten::_thnn_fused_gru_cell_backward.out(Tensor grad_hy, Tensor workspace, bool has_bias, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3, Tensor(e!) out4) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!), Tensor(e!))", "dispatch": "True", "default": "True"} +::std::tuple _pack_padded_sequence_out(const at::Tensor & input, const at::Tensor & lengths, bool batch_first, at::Tensor & out0, at::Tensor & out1); // {"schema": "aten::_pack_padded_sequence.out(Tensor input, Tensor lengths, bool batch_first, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "True"} +at::Tensor & set_out(const at::Tensor & self, at::Storage source, at::Tensor & out); // {"schema": "aten::set.source_Storage_out(Tensor self, Storage source, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor set(const at::Tensor & self, at::Storage source); // {"schema": "aten::set.source_Storage(Tensor self, Storage source) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & set_out(const at::Tensor & self, at::Storage source, c10::SymInt storage_offset, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, at::Tensor & out); // {"schema": "aten::set.source_Storage_storage_offset_out(Tensor self, Storage source, SymInt storage_offset, SymInt[] size, SymInt[] stride=[], *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor set(const at::Tensor & self, at::Storage source, c10::SymInt storage_offset, c10::SymIntArrayRef size, c10::SymIntArrayRef stride); // {"schema": "aten::set.source_Storage_storage_offset(Tensor self, Storage source, SymInt storage_offset, SymInt[] size, SymInt[] stride=[]) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & set_out(const at::Tensor & self, const at::Tensor & source, at::Tensor & out); // {"schema": "aten::set.source_Tensor_out(Tensor self, Tensor source, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor set(const at::Tensor & self, const at::Tensor & source); // {"schema": "aten::set.source_Tensor(Tensor self, Tensor source) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & set_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::set.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor set(const at::Tensor & self); // {"schema": "aten::set(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & lift_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::lift.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & lift_fresh_copy_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::lift_fresh_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & masked_fill_out(const at::Tensor & self, const at::Tensor & mask, const at::Scalar & value, at::Tensor & out); // {"schema": "aten::masked_fill.Scalar_out(Tensor self, Tensor mask, Scalar value, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & masked_fill_out(const at::Tensor & self, const at::Tensor & mask, const at::Tensor & value, at::Tensor & out); // {"schema": "aten::masked_fill.Tensor_out(Tensor self, Tensor mask, Tensor value, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & masked_scatter_out(const at::Tensor & self, const at::Tensor & mask, const at::Tensor & source, at::Tensor & out); // {"schema": "aten::masked_scatter.out(Tensor self, Tensor mask, Tensor source, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _masked_softmax_out(const at::Tensor & self, const at::Tensor & mask, ::std::optional dim, ::std::optional mask_type, at::Tensor & out); // {"schema": "aten::_masked_softmax.out(Tensor self, Tensor mask, int? dim=None, int? mask_type=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _masked_softmax_backward_out(const at::Tensor & grad_output, const at::Tensor & output, const at::Tensor & mask, ::std::optional dim, at::Tensor & out); // {"schema": "aten::_masked_softmax_backward.out(Tensor grad_output, Tensor output, Tensor mask, int? dim=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & put_out(const at::Tensor & self, const at::Tensor & index, const at::Tensor & source, bool accumulate, at::Tensor & out); // {"schema": "aten::put.out(Tensor self, Tensor index, Tensor source, bool accumulate=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & index_fill_out(const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Scalar & value, at::Tensor & out); // {"schema": "aten::index_fill.int_Scalar_out(Tensor self, int dim, Tensor index, Scalar value, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & index_fill_out(const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & value, at::Tensor & out); // {"schema": "aten::index_fill.int_Tensor_out(Tensor self, int dim, Tensor index, Tensor value, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & bitwise_and_out(const at::Scalar & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::bitwise_and.Scalar_Tensor_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & bitwise_or_out(const at::Scalar & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::bitwise_or.Scalar_Tensor_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & bitwise_xor_out(const at::Scalar & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::bitwise_xor.Scalar_Tensor_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & __lshift___out(const at::Tensor & self, const at::Scalar & other, at::Tensor & out); // {"schema": "aten::__lshift__.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & __lshift___out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::__lshift__.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & bitwise_left_shift_out(const at::Scalar & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::bitwise_left_shift.Scalar_Tensor_out(Scalar self, Tensor other, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & __rshift___out(const at::Tensor & self, const at::Scalar & other, at::Tensor & out); // {"schema": "aten::__rshift__.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & __rshift___out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::__rshift__.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & bitwise_right_shift_out(const at::Scalar & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::bitwise_right_shift.Scalar_Tensor_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & random_out(const at::Tensor & self, int64_t from, ::std::optional to, ::std::optional generator, at::Tensor & out); // {"schema": "aten::random.from_out(Tensor self, int from, int? to, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor random(const at::Tensor & self, int64_t from, ::std::optional to, ::std::optional generator); // {"schema": "aten::random.from(Tensor self, int from, int? to, *, Generator? generator=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & random_out(const at::Tensor & self, int64_t to, ::std::optional generator, at::Tensor & out); // {"schema": "aten::random.to_out(Tensor self, int to, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor random(const at::Tensor & self, int64_t to, ::std::optional generator); // {"schema": "aten::random.to(Tensor self, int to, *, Generator? 
generator=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & random_out(const at::Tensor & self, ::std::optional generator, at::Tensor & out); // {"schema": "aten::random.out(Tensor self, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor random(const at::Tensor & self, ::std::optional generator); // {"schema": "aten::random(Tensor self, *, Generator? generator=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & uniform_out(const at::Tensor & self, double from, double to, ::std::optional generator, at::Tensor & out); // {"schema": "aten::uniform.out(Tensor self, float from=0, float to=1, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor uniform(const at::Tensor & self, double from, double to, ::std::optional generator); // {"schema": "aten::uniform(Tensor self, float from=0, float to=1, *, Generator? generator=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & cauchy_out(const at::Tensor & self, double median, double sigma, ::std::optional generator, at::Tensor & out); // {"schema": "aten::cauchy.out(Tensor self, float median=0, float sigma=1, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor cauchy(const at::Tensor & self, double median, double sigma, ::std::optional generator); // {"schema": "aten::cauchy(Tensor self, float median=0, float sigma=1, *, Generator? generator=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & log_normal_out(const at::Tensor & self, double mean, double std, ::std::optional generator, at::Tensor & out); // {"schema": "aten::log_normal.out(Tensor self, float mean=1, float std=2, *, Generator? generator=None, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor log_normal(const at::Tensor & self, double mean, double std, ::std::optional generator); // {"schema": "aten::log_normal(Tensor self, float mean=1, float std=2, *, Generator? generator=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & exponential_out(const at::Tensor & self, double lambd, ::std::optional generator, at::Tensor & out); // {"schema": "aten::exponential.out(Tensor self, float lambd=1, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor exponential(const at::Tensor & self, double lambd, ::std::optional generator); // {"schema": "aten::exponential(Tensor self, float lambd=1, *, Generator? generator=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & geometric_out(const at::Tensor & self, double p, ::std::optional generator, at::Tensor & out); // {"schema": "aten::geometric.out(Tensor self, float p, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor geometric(const at::Tensor & self, double p, ::std::optional generator); // {"schema": "aten::geometric(Tensor self, float p, *, Generator? generator=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & tril_indices_out(int64_t row, int64_t col, int64_t offset, at::Tensor & out); // {"schema": "aten::tril_indices.out(int row, int col, int offset=0, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & triu_indices_out(int64_t row, int64_t col, int64_t offset, at::Tensor & out); // {"schema": "aten::triu_indices.out(int row, int col, int offset=0, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & trace_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::trace.out(Tensor self, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _cholesky_solve_helper_out(const at::Tensor & self, const at::Tensor & A, bool upper, at::Tensor & out); // {"schema": "aten::_cholesky_solve_helper.out(Tensor self, Tensor A, bool upper, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & dist_out(const at::Tensor & self, const at::Tensor & other, const at::Scalar & p, at::Tensor & out); // {"schema": "aten::dist.out(Tensor self, Tensor other, Scalar p=2, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +void _histogramdd_bin_edges_out(const at::Tensor & self, at::IntArrayRef bins, ::std::optional> range, const ::std::optional & weight, bool density, at::TensorList out); // {"schema": "aten::_histogramdd_bin_edges.out(Tensor self, int[] bins, *, float[]? range=None, Tensor? weight=None, bool density=False, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +at::Tensor & _histogramdd_from_bin_cts_out(const at::Tensor & self, at::IntArrayRef bins, ::std::optional> range, const ::std::optional & weight, bool density, at::Tensor & out); // {"schema": "aten::_histogramdd_from_bin_cts.out(Tensor self, int[] bins, *, float[]? range=None, Tensor? weight=None, bool density=False, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _histogramdd_from_bin_tensors_out(const at::Tensor & self, at::TensorList bins, const ::std::optional & weight, bool density, at::Tensor & out); // {"schema": "aten::_histogramdd_from_bin_tensors.out(Tensor self, Tensor[] bins, *, Tensor? weight=None, bool density=False, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & remainder_out(const at::Scalar & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::remainder.Scalar_Tensor_out(Scalar self, Tensor other, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & unfold_backward_out(const at::Tensor & grad_in, c10::SymIntArrayRef input_sizes, int64_t dim, int64_t size, int64_t step, at::Tensor & out); // {"schema": "aten::unfold_backward.out(Tensor grad_in, SymInt[] input_sizes, int dim, int size, int step, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & normal_out(const at::Tensor & self, double mean, double std, ::std::optional generator, at::Tensor & out); // {"schema": "aten::normal.out(Tensor self, float mean=0, float std=1, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +void _amp_foreach_non_finite_check_and_unscale_out(at::TensorList self, at::Tensor & found_inf, const at::Tensor & inv_scale, at::TensorList out); // {"schema": "aten::_amp_foreach_non_finite_check_and_unscale.out(Tensor[] self, Tensor(b!) found_inf, Tensor inv_scale, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +::std::tuple<::std::vector,at::Tensor> _amp_foreach_non_finite_check_and_unscale(at::TensorList self, const at::Tensor & found_inf, const at::Tensor & inv_scale); // {"schema": "aten::_amp_foreach_non_finite_check_and_unscale(Tensor[] self, Tensor found_inf, Tensor inv_scale) -> (Tensor[] self_out, Tensor found_inf_out)", "dispatch": "True", "default": "True"} +at::Tensor & _amp_update_scale_out(const at::Tensor & self, at::Tensor & growth_tracker, const at::Tensor & found_inf, double scale_growth_factor, double scale_backoff_factor, int64_t growth_interval, at::Tensor & out); // {"schema": "aten::_amp_update_scale.out(Tensor self, Tensor(b!) growth_tracker, Tensor found_inf, float scale_growth_factor, float scale_backoff_factor, int growth_interval, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +::std::tuple _amp_update_scale(const at::Tensor & self, const at::Tensor & growth_tracker, const at::Tensor & found_inf, double scale_growth_factor, double scale_backoff_factor, int64_t growth_interval); // {"schema": "aten::_amp_update_scale(Tensor self, Tensor growth_tracker, Tensor found_inf, float scale_growth_factor, float scale_backoff_factor, int growth_interval) -> (Tensor, Tensor growth_tracker_out)", "dispatch": "True", "default": "True"} +void _foreach_add_out(at::TensorList self, const at::Scalar & scalar, at::TensorList out); // {"schema": "aten::_foreach_add.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_add_out(at::TensorList self, at::TensorList other, const at::Scalar & alpha, at::TensorList out); // {"schema": "aten::_foreach_add.List_out(Tensor[] self, Tensor[] other, *, Scalar alpha=1, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_add_out(at::TensorList self, at::ArrayRef scalars, at::TensorList out); // {"schema": "aten::_foreach_add.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_add_out(at::TensorList self, const at::Tensor & other, const at::Scalar & alpha, at::TensorList out); // {"schema": "aten::_foreach_add.Tensor_out(Tensor[] self, Tensor other, *, Scalar alpha=1, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_sub_out(at::TensorList self, const at::Scalar & scalar, at::TensorList out); // {"schema": "aten::_foreach_sub.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_sub_out(at::TensorList self, at::TensorList other, const at::Scalar & alpha, at::TensorList out); // {"schema": "aten::_foreach_sub.List_out(Tensor[] self, Tensor[] other, *, Scalar alpha=1, Tensor(a!)[] out) -> ()", 
"dispatch": "True", "default": "True"} +void _foreach_sub_out(at::TensorList self, at::ArrayRef scalars, at::TensorList out); // {"schema": "aten::_foreach_sub.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_mul_out(at::TensorList self, const at::Scalar & scalar, at::TensorList out); // {"schema": "aten::_foreach_mul.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_mul_out(at::TensorList self, at::TensorList other, at::TensorList out); // {"schema": "aten::_foreach_mul.List_out(Tensor[] self, Tensor[] other, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_mul_out(at::TensorList self, at::ArrayRef scalars, at::TensorList out); // {"schema": "aten::_foreach_mul.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_mul_out(at::TensorList self, const at::Tensor & other, at::TensorList out); // {"schema": "aten::_foreach_mul.Tensor_out(Tensor[] self, Tensor other, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_div_out(at::TensorList self, const at::Scalar & scalar, at::TensorList out); // {"schema": "aten::_foreach_div.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_div_out(at::TensorList self, at::TensorList other, at::TensorList out); // {"schema": "aten::_foreach_div.List_out(Tensor[] self, Tensor[] other, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_div_out(at::TensorList self, at::ArrayRef scalars, at::TensorList out); // {"schema": "aten::_foreach_div.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_div_out(at::TensorList self, const at::Tensor & other, at::TensorList 
out); // {"schema": "aten::_foreach_div.Tensor_out(Tensor[] self, Tensor other, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_clamp_max_out(at::TensorList self, const at::Scalar & scalar, at::TensorList out); // {"schema": "aten::_foreach_clamp_max.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_clamp_max_out(at::TensorList self, at::TensorList other, at::TensorList out); // {"schema": "aten::_foreach_clamp_max.List_out(Tensor[] self, Tensor[] other, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_clamp_max_out(at::TensorList self, at::ArrayRef scalars, at::TensorList out); // {"schema": "aten::_foreach_clamp_max.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_clamp_min_out(at::TensorList self, const at::Scalar & scalar, at::TensorList out); // {"schema": "aten::_foreach_clamp_min.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_clamp_min_out(at::TensorList self, at::TensorList other, at::TensorList out); // {"schema": "aten::_foreach_clamp_min.List_out(Tensor[] self, Tensor[] other, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_clamp_min_out(at::TensorList self, at::ArrayRef scalars, at::TensorList out); // {"schema": "aten::_foreach_clamp_min.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_maximum_out(at::TensorList self, const at::Scalar & scalar, at::TensorList out); // {"schema": "aten::_foreach_maximum.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_maximum_out(at::TensorList self, at::TensorList other, at::TensorList out); // {"schema": 
"aten::_foreach_maximum.List_out(Tensor[] self, Tensor[] other, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_maximum_out(at::TensorList self, at::ArrayRef scalars, at::TensorList out); // {"schema": "aten::_foreach_maximum.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_minimum_out(at::TensorList self, const at::Scalar & scalar, at::TensorList out); // {"schema": "aten::_foreach_minimum.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_minimum_out(at::TensorList self, at::TensorList other, at::TensorList out); // {"schema": "aten::_foreach_minimum.List_out(Tensor[] self, Tensor[] other, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_minimum_out(at::TensorList self, at::ArrayRef scalars, at::TensorList out); // {"schema": "aten::_foreach_minimum.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_addcdiv_out(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Scalar & value, at::TensorList out); // {"schema": "aten::_foreach_addcdiv.Scalar_out(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_addcdiv_out(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, at::ArrayRef scalars, at::TensorList out); // {"schema": "aten::_foreach_addcdiv.ScalarList_out(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_addcdiv_out(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Tensor & scalars, at::TensorList out); // {"schema": "aten::_foreach_addcdiv.Tensor_out(Tensor[] self, Tensor[] tensor1, Tensor[] 
tensor2, Tensor scalars, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_addcmul_out(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Scalar & value, at::TensorList out); // {"schema": "aten::_foreach_addcmul.Scalar_out(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_addcmul_out(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, at::ArrayRef scalars, at::TensorList out); // {"schema": "aten::_foreach_addcmul.ScalarList_out(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_addcmul_out(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Tensor & scalars, at::TensorList out); // {"schema": "aten::_foreach_addcmul.Tensor_out(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Tensor scalars, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_abs_out(at::TensorList self, at::TensorList out); // {"schema": "aten::_foreach_abs.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_acos_out(at::TensorList self, at::TensorList out); // {"schema": "aten::_foreach_acos.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_asin_out(at::TensorList self, at::TensorList out); // {"schema": "aten::_foreach_asin.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_atan_out(at::TensorList self, at::TensorList out); // {"schema": "aten::_foreach_atan.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_ceil_out(at::TensorList self, at::TensorList out); // {"schema": "aten::_foreach_ceil.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", 
"default": "True"} +void _foreach_cos_out(at::TensorList self, at::TensorList out); // {"schema": "aten::_foreach_cos.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_cosh_out(at::TensorList self, at::TensorList out); // {"schema": "aten::_foreach_cosh.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_erf_out(at::TensorList self, at::TensorList out); // {"schema": "aten::_foreach_erf.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_erfc_out(at::TensorList self, at::TensorList out); // {"schema": "aten::_foreach_erfc.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_exp_out(at::TensorList self, at::TensorList out); // {"schema": "aten::_foreach_exp.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_expm1_out(at::TensorList self, at::TensorList out); // {"schema": "aten::_foreach_expm1.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_floor_out(at::TensorList self, at::TensorList out); // {"schema": "aten::_foreach_floor.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_frac_out(at::TensorList self, at::TensorList out); // {"schema": "aten::_foreach_frac.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_lerp_out(at::TensorList self, at::TensorList tensors1, at::TensorList weights, at::TensorList out); // {"schema": "aten::_foreach_lerp.List_out(Tensor[] self, Tensor[] tensors1, Tensor[] weights, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_lerp_out(at::TensorList self, at::TensorList tensors1, const at::Scalar & weight, at::TensorList out); // {"schema": "aten::_foreach_lerp.Scalar_out(Tensor[] self, Tensor[] tensors1, 
Scalar weight, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_lerp_out(at::TensorList self, at::TensorList tensors1, at::ArrayRef weight, at::TensorList out); // {"schema": "aten::_foreach_lerp.ScalarList_out(Tensor[] self, Tensor[] tensors1, Scalar[] weight, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_lgamma_out(at::TensorList self, at::TensorList out); // {"schema": "aten::_foreach_lgamma.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_log_out(at::TensorList self, at::TensorList out); // {"schema": "aten::_foreach_log.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_log10_out(at::TensorList self, at::TensorList out); // {"schema": "aten::_foreach_log10.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_log1p_out(at::TensorList self, at::TensorList out); // {"schema": "aten::_foreach_log1p.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_log2_out(at::TensorList self, at::TensorList out); // {"schema": "aten::_foreach_log2.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_max_out(at::TensorList self, at::TensorList out); // {"schema": "aten::_foreach_max.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_neg_out(at::TensorList self, at::TensorList out); // {"schema": "aten::_foreach_neg.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_norm_out(at::TensorList self, const at::Scalar & ord, ::std::optional dtype, at::TensorList out); // {"schema": "aten::_foreach_norm.Scalar_out(Tensor[] self, Scalar ord=2, ScalarType? 
dtype=None, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_pow_out(at::TensorList self, at::TensorList exponent, at::TensorList out); // {"schema": "aten::_foreach_pow.List_out(Tensor[] self, Tensor[] exponent, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_pow_out(at::TensorList self, const at::Scalar & exponent, at::TensorList out); // {"schema": "aten::_foreach_pow.Scalar_out(Tensor[] self, Scalar exponent, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_pow_out(at::TensorList self, at::ArrayRef exponent, at::TensorList out); // {"schema": "aten::_foreach_pow.ScalarList_out(Tensor[] self, Scalar[] exponent, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_reciprocal_out(at::TensorList self, at::TensorList out); // {"schema": "aten::_foreach_reciprocal.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_round_out(at::TensorList self, at::TensorList out); // {"schema": "aten::_foreach_round.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_rsqrt_out(at::TensorList self, at::TensorList out); // {"schema": "aten::_foreach_rsqrt.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_sigmoid_out(at::TensorList self, at::TensorList out); // {"schema": "aten::_foreach_sigmoid.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_sign_out(at::TensorList self, at::TensorList out); // {"schema": "aten::_foreach_sign.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_sin_out(at::TensorList self, at::TensorList out); // {"schema": "aten::_foreach_sin.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_sinh_out(at::TensorList self, 
at::TensorList out); // {"schema": "aten::_foreach_sinh.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_sqrt_out(at::TensorList self, at::TensorList out); // {"schema": "aten::_foreach_sqrt.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_tan_out(at::TensorList self, at::TensorList out); // {"schema": "aten::_foreach_tan.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_tanh_out(at::TensorList self, at::TensorList out); // {"schema": "aten::_foreach_tanh.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_trunc_out(at::TensorList self, at::TensorList out); // {"schema": "aten::_foreach_trunc.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_zero_out(at::TensorList self, at::TensorList out); // {"schema": "aten::_foreach_zero.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_zero(at::TensorList self); // {"schema": "aten::_foreach_zero(Tensor[] self) -> Tensor[] self_out", "dispatch": "True", "default": "True"} +void _foreach_copy_out(at::TensorList self, at::TensorList src, bool non_blocking, at::TensorList out); // {"schema": "aten::_foreach_copy.out(Tensor[] self, Tensor[] src, bool non_blocking=False, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +at::Tensor & bucketize_out(const at::Scalar & self, const at::Tensor & boundaries, bool out_int32, bool right, at::Tensor & out); // {"schema": "aten::bucketize.Scalar_out(Scalar self, Tensor boundaries, *, bool out_int32=False, bool right=False, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & glu_jvp_out(const at::Tensor & glu, const at::Tensor & x, const at::Tensor & dx, int64_t dim, at::Tensor & out); // {"schema": "aten::glu_jvp.out(Tensor glu, Tensor x, Tensor dx, int dim, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & glu_backward_jvp_out(const at::Tensor & grad_x, const at::Tensor & grad_glu, const at::Tensor & x, const at::Tensor & dgrad_glu, const at::Tensor & dx, int64_t dim, at::Tensor & out); // {"schema": "aten::glu_backward_jvp.out(Tensor grad_x, Tensor grad_glu, Tensor x, Tensor dgrad_glu, Tensor dx, int dim, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & hardswish_backward_out(const at::Tensor & grad_output, const at::Tensor & self, at::Tensor & out); // {"schema": "aten::hardswish_backward.out(Tensor grad_output, Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +::std::tuple rrelu_with_noise_functional(const at::Tensor & self, const at::Tensor & noise, const at::Scalar & lower, const at::Scalar & upper, bool training, ::std::optional generator); // {"schema": "aten::rrelu_with_noise_functional(Tensor self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> (Tensor, Tensor noise_out)", "dispatch": "True", "default": "True"} +at::Tensor & rrelu_with_noise_backward_out(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & noise, const at::Scalar & lower, const at::Scalar & upper, bool training, bool self_is_result, at::Tensor & out); // {"schema": "aten::rrelu_with_noise_backward.out(Tensor grad_output, Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, bool self_is_result, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & mkldnn_adaptive_avg_pool2d_backward_out(const at::Tensor & grad_output, const at::Tensor & self, at::Tensor & out); // {"schema": "aten::mkldnn_adaptive_avg_pool2d_backward.out(Tensor grad_output, Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _adaptive_avg_pool2d_out(const at::Tensor & self, c10::SymIntArrayRef output_size, at::Tensor & out); // {"schema": "aten::_adaptive_avg_pool2d.out(Tensor self, SymInt[2] output_size, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _adaptive_avg_pool2d_backward_out(const at::Tensor & grad_output, const at::Tensor & self, at::Tensor & out); // {"schema": "aten::_adaptive_avg_pool2d_backward.out(Tensor grad_output, Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _adaptive_avg_pool3d_out(const at::Tensor & self, c10::SymIntArrayRef output_size, at::Tensor & out); // {"schema": "aten::_adaptive_avg_pool3d.out(Tensor self, SymInt[3] output_size, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _adaptive_avg_pool3d_backward_out(const at::Tensor & grad_output, const at::Tensor & self, at::Tensor & out); // {"schema": "aten::_adaptive_avg_pool3d_backward.out(Tensor grad_output, Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & upsample_bilinear2d_out(const at::Tensor & input, at::OptionalSymIntArrayRef output_size, bool align_corners, ::std::optional> scale_factors, at::Tensor & out); // {"schema": "aten::upsample_bilinear2d.vec_out(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & upsample_nearest2d_out(const at::Tensor & input, at::OptionalSymIntArrayRef output_size, ::std::optional> scale_factors, at::Tensor & out); // {"schema": "aten::upsample_nearest2d.vec_out(Tensor input, SymInt[]? output_size, float[]? scale_factors, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +::std::tuple _slow_conv2d_backward_out(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, ::std::array output_mask, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2); // {"schema": "aten::_slow_conv2d_backward.output_mask_out(Tensor grad_output, Tensor self, Tensor weight, SymInt[2] kernel_size, SymInt[2] stride, SymInt[2] padding, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))", "dispatch": "True", "default": "True"} +at::Tensor & conv_depthwise3d_out(const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, at::Tensor & out); // {"schema": "aten::conv_depthwise3d.out(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias, SymInt[3] stride, SymInt[3] padding, SymInt[3] dilation, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & slow_conv_dilated2d_out(const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, at::Tensor & out); // {"schema": "aten::slow_conv_dilated2d.out(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] dilation=1, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & slow_conv_dilated3d_out(const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, at::Tensor & out); // {"schema": "aten::slow_conv_dilated3d.out(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] dilation=1, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & isinf_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::isinf.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & linalg_matrix_exp_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::linalg_matrix_exp.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _test_optional_intlist_out(const at::Tensor & values, at::OptionalIntArrayRef addends, at::Tensor & out); // {"schema": "aten::_test_optional_intlist.out(Tensor values, int[]? addends, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _test_optional_filled_intlist_out(const at::Tensor & values, at::OptionalIntArrayRef addends, at::Tensor & out); // {"schema": "aten::_test_optional_filled_intlist.out(Tensor values, int[2]? addends, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _test_optional_floatlist_out(const at::Tensor & values, ::std::optional> addends, at::Tensor & out); // {"schema": "aten::_test_optional_floatlist.out(Tensor values, float[]? addends, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _test_warn_in_autograd_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::_test_warn_in_autograd.out(Tensor self, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _test_autograd_multiple_dispatch_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::_test_autograd_multiple_dispatch.fullcoverage_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _test_autograd_multiple_dispatch_view_copy_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::_test_autograd_multiple_dispatch_view_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & segment_reduce_out(const at::Tensor & data, c10::string_view reduce, const ::std::optional & lengths, const ::std::optional & indices, const ::std::optional & offsets, int64_t axis, bool unsafe, const ::std::optional & initial, at::Tensor & out); // {"schema": "aten::segment_reduce.out(Tensor data, str reduce, *, Tensor? lengths=None, Tensor? indices=None, Tensor? offsets=None, int axis=0, bool unsafe=False, Scalar? initial=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _segment_reduce_backward_out(const at::Tensor & grad, const at::Tensor & output, const at::Tensor & data, c10::string_view reduce, const ::std::optional & lengths, const ::std::optional & offsets, int64_t axis, const ::std::optional & initial, at::Tensor & out); // {"schema": "aten::_segment_reduce_backward.out(Tensor grad, Tensor output, Tensor data, str reduce, *, Tensor? lengths=None, Tensor? offsets=None, int axis=0, Scalar? initial=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _nested_tensor_from_tensor_list_out(at::TensorList list, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, at::Tensor & out); // {"schema": "aten::_nested_tensor_from_tensor_list.out(Tensor[] list, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _fw_primal_copy_out(const at::Tensor & self, int64_t level, at::Tensor & out); // {"schema": "aten::_fw_primal_copy.out(Tensor self, int level, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _make_dual_copy_out(const at::Tensor & primal, const at::Tensor & tangent, int64_t level, at::Tensor & out); // {"schema": "aten::_make_dual_copy.out(Tensor primal, Tensor tangent, int level, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & view_as_real_copy_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::view_as_real_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & view_as_complex_copy_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::view_as_complex_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _conj_copy_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::_conj_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _neg_view_copy_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::_neg_view_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & as_strided_copy_out(const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset, at::Tensor & out); // {"schema": "aten::as_strided_copy.out(Tensor self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _sparse_broadcast_to_copy_out(const at::Tensor & self, at::IntArrayRef size, at::Tensor & out); // {"schema": "aten::_sparse_broadcast_to_copy.out(Tensor self, int[] size, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & diagonal_copy_out(const at::Tensor & self, int64_t offset, int64_t dim1, int64_t dim2, at::Tensor & out); // {"schema": "aten::diagonal_copy.out(Tensor self, int offset=0, int dim1=0, int dim2=1, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & expand_copy_out(const at::Tensor & self, c10::SymIntArrayRef size, bool implicit, at::Tensor & out); // {"schema": "aten::expand_copy.out(Tensor self, SymInt[] size, *, bool implicit=False, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & permute_copy_out(const at::Tensor & self, at::IntArrayRef dims, at::Tensor & out); // {"schema": "aten::permute_copy.out(Tensor self, int[] dims, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _reshape_alias_copy_out(const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, at::Tensor & out); // {"schema": "aten::_reshape_alias_copy.out(Tensor self, SymInt[] size, SymInt[] stride, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & select_copy_out(const at::Tensor & self, int64_t dim, c10::SymInt index, at::Tensor & out); // {"schema": "aten::select_copy.int_out(Tensor self, int dim, SymInt index, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & detach_copy_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::detach_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & slice_copy_out(const at::Tensor & self, int64_t dim, ::std::optional start, ::std::optional end, c10::SymInt step, at::Tensor & out); // {"schema": "aten::slice_copy.Tensor_out(Tensor self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & squeeze_copy_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::squeeze_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & squeeze_copy_out(const at::Tensor & self, int64_t dim, at::Tensor & out); // {"schema": "aten::squeeze_copy.dim_out(Tensor self, int dim, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & squeeze_copy_out(const at::Tensor & self, at::IntArrayRef dim, at::Tensor & out); // {"schema": "aten::squeeze_copy.dims_out(Tensor self, int[] dim, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & t_copy_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::t_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & transpose_copy_out(const at::Tensor & self, int64_t dim0, int64_t dim1, at::Tensor & out); // {"schema": "aten::transpose_copy.int_out(Tensor self, int dim0, int dim1, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & unsqueeze_copy_out(const at::Tensor & self, int64_t dim, at::Tensor & out); // {"schema": "aten::unsqueeze_copy.out(Tensor self, int dim, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _indices_copy_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::_indices_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _values_copy_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::_values_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & indices_copy_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::indices_copy.out(Tensor self, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & values_copy_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::values_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & crow_indices_copy_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::crow_indices_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & col_indices_copy_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::col_indices_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & ccol_indices_copy_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::ccol_indices_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & row_indices_copy_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::row_indices_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & view_copy_out(const at::Tensor & self, c10::SymIntArrayRef size, at::Tensor & out); // {"schema": "aten::view_copy.out(Tensor self, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & view_copy_out(const at::Tensor & self, at::ScalarType dtype, at::Tensor & out); // {"schema": "aten::view_copy.dtype_out(Tensor self, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & unfold_copy_out(const at::Tensor & self, int64_t dimension, int64_t size, int64_t step, at::Tensor & out); // {"schema": "aten::unfold_copy.out(Tensor self, int dimension, int size, int step, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & alias_copy_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::alias_copy.out(Tensor self, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & to_padded_tensor_out(const at::Tensor & self, double padding, at::OptionalSymIntArrayRef output_size, at::Tensor & out); // {"schema": "aten::to_padded_tensor.out(Tensor self, float padding, SymInt[]? output_size=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _transformer_encoder_layer_fwd_out(const at::Tensor & src, int64_t embed_dim, int64_t num_heads, const at::Tensor & qkv_weight, const at::Tensor & qkv_bias, const at::Tensor & proj_weight, const at::Tensor & proj_bias, bool use_gelu, bool norm_first, double eps, const at::Tensor & norm_weight_1, const at::Tensor & norm_bias_1, const at::Tensor & norm_weight_2, const at::Tensor & norm_bias_2, const at::Tensor & ffn_weight_1, const at::Tensor & ffn_bias_1, const at::Tensor & ffn_weight_2, const at::Tensor & ffn_bias_2, const ::std::optional & mask, ::std::optional mask_type, at::Tensor & out); // {"schema": "aten::_transformer_encoder_layer_fwd.out(Tensor src, int embed_dim, int num_heads, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, bool use_gelu, bool norm_first, float eps, Tensor norm_weight_1, Tensor norm_bias_1, Tensor norm_weight_2, Tensor norm_bias_2, Tensor ffn_weight_1, Tensor ffn_bias_1, Tensor ffn_weight_2, Tensor ffn_bias_2, Tensor? mask=None, int? mask_type=None, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +::std::tuple _native_multi_head_attention_out(const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, int64_t embed_dim, int64_t num_head, const at::Tensor & qkv_weight, const at::Tensor & qkv_bias, const at::Tensor & proj_weight, const at::Tensor & proj_bias, const ::std::optional & mask, bool need_weights, bool average_attn_weights, ::std::optional mask_type, at::Tensor & out0, at::Tensor & out1); // {"schema": "aten::_native_multi_head_attention.out(Tensor query, Tensor key, Tensor value, int embed_dim, int num_head, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, Tensor? mask=None, bool need_weights=True, bool average_attn_weights=True, int? mask_type=None, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "True"} +at::Tensor & _triton_scaled_dot_attention_out(const at::Tensor & q, const at::Tensor & k, const at::Tensor & v, double dropout_p, at::Tensor & out); // {"schema": "aten::_triton_scaled_dot_attention.out(Tensor q, Tensor k, Tensor v, float dropout_p=0.0, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _triton_multi_head_attention_out(const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, int64_t embed_dim, int64_t num_head, const at::Tensor & qkv_weight, const at::Tensor & qkv_bias, const at::Tensor & proj_weight, const at::Tensor & proj_bias, const ::std::optional & mask, at::Tensor & out); // {"schema": "aten::_triton_multi_head_attention.out(Tensor query, Tensor key, Tensor value, int embed_dim, int num_head, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, Tensor? mask=None, *, Tensor(a!) 
out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _foobar_out(const at::Tensor & self, bool arg1, bool arg2, bool arg3, at::Tensor & out); // {"schema": "aten::_foobar.out(Tensor self, bool arg1=True, bool arg2=True, *, bool arg3=True, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +void _fused_adam_out(at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, double lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const ::std::optional & grad_scale, const ::std::optional & found_inf, at::TensorList out); // {"schema": "aten::_fused_adam.out(Tensor[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, float lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +::std::tuple<::std::vector,::std::vector,::std::vector,::std::vector,::std::vector> _fused_adam(at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, double lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const ::std::optional & grad_scale, const ::std::optional & found_inf); // {"schema": "aten::_fused_adam(Tensor[] self, Tensor[] grads, Tensor[] exp_avgs, Tensor[] exp_avg_sqs, Tensor[] max_exp_avg_sqs, Tensor[] state_steps, *, float lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? 
found_inf=None) -> (Tensor[] self_out, Tensor[] grads_out, Tensor[] exp_avgs_out, Tensor[] exp_avg_sqs_out, Tensor[] max_exp_avg_sqs_out)", "dispatch": "True", "default": "True"} +void _fused_adam_out(at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, const at::Tensor & lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const ::std::optional & grad_scale, const ::std::optional & found_inf, at::TensorList out); // {"schema": "aten::_fused_adam.tensor_lr_out(Tensor[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, Tensor lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +::std::tuple<::std::vector,::std::vector,::std::vector,::std::vector,::std::vector> _fused_adam(at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, const at::Tensor & lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const ::std::optional & grad_scale, const ::std::optional & found_inf); // {"schema": "aten::_fused_adam.tensor_lr(Tensor[] self, Tensor[] grads, Tensor[] exp_avgs, Tensor[] exp_avg_sqs, Tensor[] max_exp_avg_sqs, Tensor[] state_steps, *, Tensor lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? 
found_inf=None) -> (Tensor[] self_out, Tensor[] grads_out, Tensor[] exp_avgs_out, Tensor[] exp_avg_sqs_out, Tensor[] max_exp_avg_sqs_out)", "dispatch": "True", "default": "True"} +void _fused_adamw_out(at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, double lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const ::std::optional & grad_scale, const ::std::optional & found_inf, at::TensorList out); // {"schema": "aten::_fused_adamw.out(Tensor[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, float lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +::std::tuple<::std::vector,::std::vector,::std::vector,::std::vector,::std::vector> _fused_adamw(at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, double lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const ::std::optional & grad_scale, const ::std::optional & found_inf); // {"schema": "aten::_fused_adamw(Tensor[] self, Tensor[] grads, Tensor[] exp_avgs, Tensor[] exp_avg_sqs, Tensor[] max_exp_avg_sqs, Tensor[] state_steps, *, float lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? 
found_inf=None) -> (Tensor[] self_out, Tensor[] grads_out, Tensor[] exp_avgs_out, Tensor[] exp_avg_sqs_out, Tensor[] max_exp_avg_sqs_out)", "dispatch": "True", "default": "True"} +void _fused_adamw_out(at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, const at::Tensor & lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const ::std::optional & grad_scale, const ::std::optional & found_inf, at::TensorList out); // {"schema": "aten::_fused_adamw.tensor_lr_out(Tensor[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, Tensor lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +::std::tuple<::std::vector,::std::vector,::std::vector,::std::vector,::std::vector> _fused_adamw(at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, const at::Tensor & lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const ::std::optional & grad_scale, const ::std::optional & found_inf); // {"schema": "aten::_fused_adamw.tensor_lr(Tensor[] self, Tensor[] grads, Tensor[] exp_avgs, Tensor[] exp_avg_sqs, Tensor[] max_exp_avg_sqs, Tensor[] state_steps, *, Tensor lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? 
found_inf=None) -> (Tensor[] self_out, Tensor[] grads_out, Tensor[] exp_avgs_out, Tensor[] exp_avg_sqs_out, Tensor[] max_exp_avg_sqs_out)", "dispatch": "True", "default": "True"} +void _fused_sgd_out(at::TensorList self, at::TensorList grads, at::TensorList momentum_buffer_list, double weight_decay, double momentum, double lr, double dampening, bool nesterov, bool maximize, bool is_first_step, const ::std::optional & grad_scale, const ::std::optional & found_inf, at::TensorList out); // {"schema": "aten::_fused_sgd.out(Tensor[] self, Tensor(b!)[] grads, Tensor(c!)[] momentum_buffer_list, *, float weight_decay, float momentum, float lr, float dampening, bool nesterov, bool maximize, bool is_first_step, Tensor? grad_scale=None, Tensor? found_inf=None, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +::std::tuple<::std::vector,::std::vector,::std::vector> _fused_sgd(at::TensorList self, at::TensorList grads, at::TensorList momentum_buffer_list, double weight_decay, double momentum, double lr, double dampening, bool nesterov, bool maximize, bool is_first_step, const ::std::optional & grad_scale, const ::std::optional & found_inf); // {"schema": "aten::_fused_sgd(Tensor[] self, Tensor[] grads, Tensor[] momentum_buffer_list, *, float weight_decay, float momentum, float lr, float dampening, bool nesterov, bool maximize, bool is_first_step, Tensor? grad_scale=None, Tensor? 
found_inf=None) -> (Tensor[] self_out, Tensor[] grads_out, Tensor[] momentum_buffer_list_out)", "dispatch": "True", "default": "True"} +void _fused_sgd_out(at::TensorList self, at::TensorList grads, at::TensorList momentum_buffer_list, double weight_decay, double momentum, const at::Tensor & lr, double dampening, bool nesterov, bool maximize, bool is_first_step, const ::std::optional & grad_scale, const ::std::optional & found_inf, at::TensorList out); // {"schema": "aten::_fused_sgd.tensor_lr_out(Tensor[] self, Tensor(b!)[] grads, Tensor(c!)[] momentum_buffer_list, *, float weight_decay, float momentum, Tensor lr, float dampening, bool nesterov, bool maximize, bool is_first_step, Tensor? grad_scale=None, Tensor? found_inf=None, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +::std::tuple<::std::vector,::std::vector,::std::vector> _fused_sgd(at::TensorList self, at::TensorList grads, at::TensorList momentum_buffer_list, double weight_decay, double momentum, const at::Tensor & lr, double dampening, bool nesterov, bool maximize, bool is_first_step, const ::std::optional & grad_scale, const ::std::optional & found_inf); // {"schema": "aten::_fused_sgd.tensor_lr(Tensor[] self, Tensor[] grads, Tensor[] momentum_buffer_list, *, float weight_decay, float momentum, Tensor lr, float dampening, bool nesterov, bool maximize, bool is_first_step, Tensor? grad_scale=None, Tensor? 
found_inf=None) -> (Tensor[] self_out, Tensor[] grads_out, Tensor[] momentum_buffer_list_out)", "dispatch": "True", "default": "True"} +void _fused_adagrad_out(at::TensorList self, at::TensorList grads, at::TensorList state_sums, at::TensorList state_steps, double lr, double lr_decay, double weight_decay, double eps, bool maximize, const ::std::optional & grad_scale, const ::std::optional & found_inf, at::TensorList out); // {"schema": "aten::_fused_adagrad.out(Tensor[] self, Tensor(b!)[] grads, Tensor(c!)[] state_sums, Tensor(d!)[] state_steps, *, float lr, float lr_decay, float weight_decay, float eps, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +::std::tuple<::std::vector,::std::vector,::std::vector,::std::vector> _fused_adagrad(at::TensorList self, at::TensorList grads, at::TensorList state_sums, at::TensorList state_steps, double lr, double lr_decay, double weight_decay, double eps, bool maximize, const ::std::optional & grad_scale, const ::std::optional & found_inf); // {"schema": "aten::_fused_adagrad(Tensor[] self, Tensor[] grads, Tensor[] state_sums, Tensor[] state_steps, *, float lr, float lr_decay, float weight_decay, float eps, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> (Tensor[] self_out, Tensor[] grads_out, Tensor[] state_sums_out, Tensor[] state_steps_out)", "dispatch": "True", "default": "True"} +void _fused_adagrad_out(at::TensorList self, at::TensorList grads, at::TensorList state_sums, at::TensorList state_steps, const at::Tensor & lr, double lr_decay, double weight_decay, double eps, bool maximize, const ::std::optional & grad_scale, const ::std::optional & found_inf, at::TensorList out); // {"schema": "aten::_fused_adagrad.tensor_lr_out(Tensor[] self, Tensor(b!)[] grads, Tensor(c!)[] state_sums, Tensor[] state_steps, *, Tensor lr, float lr_decay, float weight_decay, float eps, bool maximize, Tensor? grad_scale=None, Tensor? 
found_inf=None, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +::std::tuple<::std::vector,::std::vector,::std::vector> _fused_adagrad(at::TensorList self, at::TensorList grads, at::TensorList state_sums, at::TensorList state_steps, const at::Tensor & lr, double lr_decay, double weight_decay, double eps, bool maximize, const ::std::optional & grad_scale, const ::std::optional & found_inf); // {"schema": "aten::_fused_adagrad.tensor_lr(Tensor[] self, Tensor[] grads, Tensor[] state_sums, Tensor[] state_steps, *, Tensor lr, float lr_decay, float weight_decay, float eps, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> (Tensor[] self_out, Tensor[] grads_out, Tensor[] state_sums_out)", "dispatch": "True", "default": "True"} + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/SDPBackend.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/SDPBackend.h new file mode 100644 index 0000000000000000000000000000000000000000..a89231eddce95149baa87fd831c6a568fc5e9634 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/SDPBackend.h @@ -0,0 +1,21 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +#include + +namespace at { + +constexpr int32_t num_sdp_backends = 5; +enum class SDPBackend { + error = -1, + math = 0, + flash_attention = 1, + efficient_attention = 2, + cudnn_attention = 3, + overrideable = 4 +}; + +} // namespace at + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." 
+#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/SavedTensorHooks.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/SavedTensorHooks.h new file mode 100644 index 0000000000000000000000000000000000000000..ed3d5a4db001d6a32aefac0b2860dc6f1a3f6501 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/SavedTensorHooks.h @@ -0,0 +1,73 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include +#include +#include + +#include + +namespace at { + +namespace impl { + +struct TORCH_API SavedTensorDefaultHooksTLS { + // PyObject is defined in c10/util/python_stub.h + std::stack> stack; + + // See NOTE: [Disabling SavedTensorDefaultHooks] for context + // NOTE: [disabled_error_message invariant] + // disabled_error_message is nullopt IFF Saved Tensor hooks is enabled + // We did this for efficiency (so we didn't have to keep a separate bool + // around) + std::optional disabled_error_message; + + // See NOTE: [Deferring tensor pack/unpack hooks until runtime] + bool is_tracing = false; +}; + +} // namespace impl + +struct TORCH_API SavedTensorDefaultHooks { + static void push_hooks( + c10::SafePyObject pack_hook, + c10::SafePyObject unpack_hook); + static std::pair pop_hooks(); + static std::optional> + get_hooks(bool ignore_is_tracing = false); + static void lazy_initialize(); + + static const impl::SavedTensorDefaultHooksTLS& get_tls_state(); + static void set_tls_state(const impl::SavedTensorDefaultHooksTLS& tls); + + // NOTE: [Disabling SavedTensorDefaultHooks] + // A developer of a PyTorch feature may choose to disable SavedTensorDefault + // hooks, especially if their feature does not work with it. 
If they are + // disabled, then the following will raise an error: + // - Attempting to push_hooks + // - calling disable(message) with a non-zero stack (hooks) size + static void disable( + const std::string& error_message, + const bool fail_if_non_empty = true); + static void enable(); + static bool is_enabled(); + static const std::optional& get_disabled_error_message(); + + // NOTE: [Deferring tensor pack/unpack hooks until runtime] + // To preserve eager semantics of pack/unpack hooks firing only once per saved + // variable, Dynamo/AOTAutograd need to defer hook firing until runtime. Using + // disable() would loud error at trace time, and pushing a no-op hook would + // fail when the traced code is wrapped in a disable_saved_tensors_hooks ctx. + // To do so, we disable these hooks during tracing. See + // https://github.com/pytorch/pytorch/issues/113263. + static bool set_tracing(bool is_tracing); +}; + +} // namespace at + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/Scalar.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/Scalar.h new file mode 100644 index 0000000000000000000000000000000000000000..17a5006f54516ed1ff35efe96e4e644a1623514a --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/Scalar.h @@ -0,0 +1,8 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." 
+#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/ScalarOps.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/ScalarOps.h new file mode 100644 index 0000000000000000000000000000000000000000..ca30dc34312219290f847f98f99f106084b55815 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/ScalarOps.h @@ -0,0 +1,58 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#endif + +namespace at::detail { +// When filling a number to 1-element CPU tensor, we want to skip +// everything but manipulate data ptr directly. +// Ideally this fast pass should be implemented in TensorIterator, +// but we also want to skip compute_types which in not avoidable +// in TensorIterator for now. +Tensor& scalar_fill(Tensor& self, const Scalar& value); +TORCH_API Tensor scalar_tensor_static( + const Scalar& s, + std::optional dtype_opt, + std::optional device_opt); +} // namespace at::detail + +// This is in the c10 namespace because we use ADL to find the functions in it. +namespace c10 { + +// FIXME: this should be (and was) Scalar::toTensor, but there is currently no +// way to implement this without going through Derived Types (which are not part +// of core). +inline at::Tensor scalar_to_tensor( + const Scalar& s, + const Device device = at::kCPU) { + // This is the fast track we have for CPU scalar tensors. 
+ if (device == at::kCPU) { + return at::detail::scalar_tensor_static(s, s.type(), at::kCPU); + } + return at::scalar_tensor(s, at::device(device).dtype(s.type())); +} + +} // namespace c10 + +namespace at::native { + +inline Tensor wrapped_scalar_tensor( + const Scalar& scalar, + const Device device = at::kCPU) { + auto tensor = scalar_to_tensor(scalar, device); + tensor.unsafeGetTensorImpl()->set_wrapped_number(true); + return tensor; +} + +} // namespace at::native + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/ScalarType.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/ScalarType.h new file mode 100644 index 0000000000000000000000000000000000000000..ee3c08cfdeb0b07f045aa9246faeb8e2944292af --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/ScalarType.h @@ -0,0 +1,9 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +#include // for BC reasons +#include +#include + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." 
+#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/SequenceNumber.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/SequenceNumber.h new file mode 100644 index 0000000000000000000000000000000000000000..5f83eccf90933246eca27dd9ac411394b9cd345b --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/SequenceNumber.h @@ -0,0 +1,18 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include + +// A simple thread local enumeration, used to link forward and backward pass +// ops and is used by autograd and observers framework +namespace at::sequence_number { + +TORCH_API uint64_t peek(); +TORCH_API uint64_t get_and_increment(); + +} // namespace at::sequence_number + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/SmallVector.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/SmallVector.h new file mode 100644 index 0000000000000000000000000000000000000000..09c6929024169c6667833aa39969d70a6f7fa016 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/SmallVector.h @@ -0,0 +1,7 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +#include + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." 
+#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/SparseCsrTensorImpl.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/SparseCsrTensorImpl.h new file mode 100644 index 0000000000000000000000000000000000000000..fb5d633095ba5bc40cc77443f86505bede68ed8a --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/SparseCsrTensorImpl.h @@ -0,0 +1,212 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include +namespace at { + +// Struct implementing a sparse CSR tensor. It uses three 1-D tensors for +// denoting the data: `crow_indices_`, `col_indices_` and `values_`. +// The `crow_indices_` tensor is a integer tensor of shape `(size(0) + 1)` +// that represents the compressed row indices of the CSR tensor. The +// `col_indices_` tensor is an integer tensor of shape `(nnz())` +// that explicitly stores the column indices of each value of the sparse +// tensor. The `values_` tensor can be of any pytorch-supported data type +// and has shape `(nnz())`. +// +// Since the main advantage of the CSR format over the COO format is speed of +// computation, care must be taken to facilitate smooth interfacing of +// these data structures with optimized libraries such as MKL and MAGMA. +// Since the MKL interface for pytorch currently uses indexing with int32 +// type, it is important to make sure that the `crow_indices` and `col_indices` +// are of type int32 when calling MKL routines such as SPMM or SPMV. +// +// If not calling MKL, it should be alright to use 64 bit integer tensors +// for indexing. 
+struct TORCH_API SparseCsrTensorImpl : public TensorImpl { + Tensor crow_indices_; + Tensor col_indices_; + Tensor values_; + Layout layout_; + + public: + explicit SparseCsrTensorImpl( + at::DispatchKeySet /*key_set*/, + at::Device device, + Layout layout, + const caffe2::TypeMeta /*data_type*/); + + void resize_(int64_t nnz, IntArrayRef size); + void resize_and_clear_( + int64_t sparse_dim, + int64_t dense_dim, + IntArrayRef size); + void resize_as_sparse_compressed_tensor_(const Tensor& src); + void set_member_tensors( + const Tensor& crow_indices, + const Tensor& col_indices, + const Tensor& values, + c10::SymIntArrayRef size); + void set_member_tensors( + const Tensor& crow_indices, + const Tensor& col_indices, + const Tensor& values, + IntArrayRef size); + const Tensor& compressed_indices() const { + return crow_indices_; + } + const Tensor& plain_indices() const { + return col_indices_; + } + const Tensor& values() const { + return values_; + } + int64_t nnz() { + return col_indices_.size(-1); + } + + inline int64_t batch_dim() const noexcept { + return crow_indices_.dim() - 1; + } + + inline int64_t sparse_dim() const noexcept { + return 2; + } + + inline int64_t dense_dim() const noexcept { + return values_.dim() - batch_dim() - block_dim() - 1; + } + + private: + inline int64_t block_dim() const noexcept { + return (layout_ == kSparseBsr || layout_ == kSparseBsc ? 
2 : 0); + } + + protected: + IntArrayRef strides_custom() const override; + SymIntArrayRef sym_strides_custom() const override; + SymBool sym_is_contiguous_custom( + MemoryFormat /*memory_format*/) const override; + + public: + void set_size(int64_t dim, int64_t new_size) override; + void set_stride(int64_t dim, int64_t new_stride) override; + void set_storage_offset(int64_t storage_offset) override; + Layout layout_impl() const override { + return layout_; + } + void set_layout(Layout layout) { + switch (layout) { + case kSparseCsr: + case kSparseCsc: + case kSparseBsr: + case kSparseBsc: + layout_ = layout; + break; + default: + TORCH_CHECK(false, "unsupported layout ", layout); + } + } + + template + c10::intrusive_ptr shallow_copy_and_detach_core( + VariableVersion&& version_counter, + bool allow_tensor_metadata_change) const { + const auto mode_stack_len = c10::impl::TorchDispatchModeTLS::stack_len(); + c10::impl::PyInterpreter&& interpreter = nullptr; + if (mode_stack_len > 0 && + !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::Python)) { + const auto& cur_torch_dispatch_mode_state = + c10::impl::TorchDispatchModeTLS::get_stack_at(mode_stack_len - 1); + interpreter = cur_torch_dispatch_mode_state->pyinterpreter(); + } else if ( + key_set_.has(DispatchKey::Python) && + !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::Python)) { + interpreter = pyobj_slot_.load_pyobj_interpreter(); + } else { + // otherwise just copy the SparseTensorImpl and not the PyObject. 
+ auto impl = c10::make_intrusive( + key_set(), device(), layout_impl(), dtype()); + copy_tensor_metadata( + /*src_sparse_impl=*/this, + /*dest_sparse_impl=*/impl.get(), + /*version_counter=*/version_counter, + /*allow_tensor_metadata_change=*/allow_tensor_metadata_change); + impl->refresh_numel(); + return impl; + } + auto r = interpreter->detach(this); + r->set_version_counter(std::forward(version_counter)); + r->set_allow_tensor_metadata_change(allow_tensor_metadata_change); + return r; + } + + /** + * Return a TensorImpl that is a shallow-copy of this TensorImpl. + * + * For usage of `version_counter` and `allow_tensor_metadata_change`, + * see NOTE [ TensorImpl Shallow-Copying ]. + */ + c10::intrusive_ptr shallow_copy_and_detach( + const c10::VariableVersion& version_counter, + bool allow_tensor_metadata_change) const override { + return shallow_copy_and_detach_core( + version_counter, allow_tensor_metadata_change); + } + + /** + * Return a TensorImpl that is a shallow-copy of this TensorImpl. + * + * For usage of `version_counter` and `allow_tensor_metadata_change`, + * see NOTE [ TensorImpl Shallow-Copying ]. + */ + c10::intrusive_ptr shallow_copy_and_detach( + c10::VariableVersion&& version_counter, + bool allow_tensor_metadata_change) const override { + return shallow_copy_and_detach_core( + std::move(version_counter), allow_tensor_metadata_change); + } + + private: + explicit SparseCsrTensorImpl( + at::DispatchKeySet key_set, + const caffe2::TypeMeta data_type, + at::Tensor crow_indices, + at::Tensor col_indices, + at::Tensor values, + at::Layout layout); + + const char* tensorimpl_type_name() const override; + + /** + * Copy the tensor metadata fields (e.g. sizes / strides / storage pointer / + * storage_offset) from one TensorImpl to another TensorImpl. + * + * For usage of `version_counter` and `allow_tensor_metadata_change`, see NOTE + * [ TensorImpl Shallow-Copying ]. 
+ */ + static void copy_tensor_metadata( + const SparseCsrTensorImpl* src_sparse_impl, + SparseCsrTensorImpl* dest_sparse_impl, + c10::VariableVersion version_counter, + bool allow_tensor_metadata_change) { + TensorImpl::copy_tensor_metadata( + src_sparse_impl, + dest_sparse_impl, + std::move(version_counter), + allow_tensor_metadata_change); + + // Sparse-specific fields + dest_sparse_impl->crow_indices_ = src_sparse_impl->compressed_indices(); + dest_sparse_impl->col_indices_ = src_sparse_impl->plain_indices(); + dest_sparse_impl->values_ = src_sparse_impl->values(); + dest_sparse_impl->layout_ = src_sparse_impl->layout_impl(); + } +}; +} // namespace at + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/SparseCsrTensorUtils.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/SparseCsrTensorUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..01905dedce5f818a103475247a4330586b3f0d6c --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/SparseCsrTensorUtils.h @@ -0,0 +1,459 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#include +#else +#include +#include +#endif + +#define AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(LAYOUT, NAME, ...) 
\ + [&] { \ + const auto& the_layout = LAYOUT; \ + switch (the_layout) { \ + case kSparseCsr: \ + case kSparseCsc: \ + case kSparseBsr: \ + case kSparseBsc: \ + return __VA_ARGS__(); \ + default: \ + TORCH_CHECK( \ + false, \ + NAME, \ + " expected sparse compressed tensor layout but got ", \ + the_layout); \ + } \ + }() + +#define AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS( \ + LAYOUT, NAME, ROW_DIM_ACTION, COLUMN_DIM_ACTION) \ + [&]() { \ + const auto& the_layout = LAYOUT; \ + switch (the_layout) { \ + case kSparseCsr: \ + case kSparseBsr: \ + return (ROW_DIM_ACTION)(); \ + case kSparseCsc: \ + case kSparseBsc: \ + return (COLUMN_DIM_ACTION)(); \ + default: \ + TORCH_CHECK( \ + false, \ + NAME, \ + " expected sparse compressed tensor layout but got ", \ + the_layout); \ + } \ + }() + +#define AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS( \ + LAYOUT, NAME, NO_BLOCK_ACTION, BLOCK_ACTION) \ + [&]() { \ + const auto& the_layout = LAYOUT; \ + switch (the_layout) { \ + case kSparseCsr: \ + case kSparseCsc: \ + return (NO_BLOCK_ACTION)(); \ + case kSparseBsr: \ + case kSparseBsc: \ + return (BLOCK_ACTION)(); \ + default: \ + TORCH_CHECK( \ + false, \ + NAME, \ + " expected sparse compressed tensor layout but got ", \ + the_layout); \ + } \ + }() + +#define AT_DISPATCH_SPARSE_ROW_COMPRESSED_LAYOUTS( \ + LAYOUT, NAME, ROW_DIM_ACTION) \ + [&]() { \ + const auto& the_layout = LAYOUT; \ + switch (the_layout) { \ + case kSparseCsr: \ + case kSparseBsr: \ + return (ROW_DIM_ACTION)(); \ + default: \ + TORCH_CHECK( \ + false, \ + NAME, \ + " expected sparse row compressed tensor layout but got ", \ + the_layout); \ + } \ + }() + +#define AT_DISPATCH_SPARSE_COL_COMPRESSED_LAYOUTS( \ + LAYOUT, NAME, COL_DIM_ACTION) \ + [&]() { \ + const auto& the_layout = LAYOUT; \ + switch (the_layout) { \ + case kSparseCsc: \ + case kSparseBsc: \ + return (COL_DIM_ACTION)(); \ + default: \ + TORCH_CHECK( \ + false, \ + NAME, \ + " expected sparse column compressed tensor layout but got ", \ + 
the_layout); \ + } \ + }() + +#define AT_DISPATCH_SPARSE_COMPRESSED_NONBLOCK_LAYOUTS(LAYOUT, NAME, ACTION) \ + [&]() { \ + const auto& the_layout = LAYOUT; \ + switch (the_layout) { \ + case kSparseCsr: \ + case kSparseCsc: \ + return (ACTION)(); \ + default: \ + TORCH_CHECK( \ + false, \ + NAME, \ + " expected sparse compressed (non-block) tensor layout but got ", \ + the_layout); \ + } \ + }() + +#define AT_DISPATCH_SPARSE_COMPRESSED_BLOCK_LAYOUTS(LAYOUT, NAME, ACTION) \ + [&]() { \ + const auto& the_layout = LAYOUT; \ + switch (the_layout) { \ + case kSparseBsr: \ + case kSparseBsc: \ + return (ACTION)(); \ + default: \ + TORCH_CHECK( \ + false, \ + NAME, \ + " expected sparse compressed block tensor layout but got ", \ + the_layout); \ + } \ + }() + +#define AT_DISPATCH_SPARSE_VALUE_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, \ + NAME, \ + AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND4( \ + kComplexHalf, kHalf, kBool, kBFloat16, __VA_ARGS__)) + +namespace at::sparse_csr { + +// Implements RAII object to manage checking sparse tensor invariants: +class CheckSparseTensorInvariants { + bool old_state; + + public: + CheckSparseTensorInvariants(bool state) + : old_state(at::globalContext().checkSparseTensorInvariants()) { + at::globalContext().setCheckSparseTensorInvariants(state); + } + CheckSparseTensorInvariants(CheckSparseTensorInvariants&& other) = delete; + CheckSparseTensorInvariants(const CheckSparseTensorInvariants&) = delete; + CheckSparseTensorInvariants& operator=(const CheckSparseTensorInvariants&) = + delete; + CheckSparseTensorInvariants& operator=(CheckSparseTensorInvariants&&) = + delete; + + ~CheckSparseTensorInvariants() { + at::globalContext().setCheckSparseTensorInvariants(old_state); + } +}; + +using SparseCsrTensor = Tensor; + +inline bool is_sparse_compressed(const Layout& layout) { + switch (layout) { + case kSparseCsr: + case kSparseCsc: + case kSparseBsr: + case kSparseBsc: + return true; + default:; + } + return false; +} + 
+inline bool is_sparse_compressed(const Tensor& self) { + return is_sparse_compressed(self.layout()); +} + +inline SparseCsrTensorImpl* get_sparse_csr_impl(const SparseCsrTensor& self) { + AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS( + self.layout(), "get_sparse_csr_impl", [&] {}); + return static_cast(self.unsafeGetTensorImpl()); +} + +inline std::string layoutToString( + Layout layout, + bool upper = false, + bool lower = false) { + switch (layout) { + case kSparseCsr: + return (upper ? "CSR" : (lower ? "csr" : "Csr")); + case kSparseCsc: + return (upper ? "CSC" : (lower ? "csc" : "Csc")); + case kSparseBsr: + return (upper ? "BSR" : (lower ? "bsr" : "Bsr")); + case kSparseBsc: + return (upper ? "BSC" : (lower ? "bsc" : "Bsc")); + default: + TORCH_CHECK(false, "Not a sparse compressed layout:", layout); + return ""; + } +} + +inline bool isCompressedRow(Layout layout) { + return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS( + layout, "isCompressedRow", [&] { return true; }, [&] { return false; }); +} + +inline bool isCompressedColumn(Layout layout) { + return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS( + layout, + "isCompressedColumn", + [&] { return false; }, + [&] { return true; }); +} + +inline std::string compressedIndicesName(Layout layout) { + return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS( + layout, + "compressedIndicesName", + [&] { return "crow_indices"; }, + [&] { return "ccol_indices"; }); +} + +inline std::string plainIndicesName(Layout layout) { + return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS( + layout, + "plainIndicesName", + [&] { return "col_indices"; }, + [&] { return "row_indices"; }); +} + +inline std::string compressedDimName(Layout layout) { + switch (layout) { + case kSparseCsr: + return "row"; + case kSparseCsc: + return "column"; + case kSparseBsr: + return "row block"; + case kSparseBsc: + return "column block"; + default: + TORCH_CHECK(false, "Not a sparse compressed layout:", layout); + return ""; + } +} + +inline std::string 
plainDimName(Layout layout) { + switch (layout) { + case kSparseCsr: + return "column"; + case kSparseCsc: + return "row"; + case kSparseBsr: + return "column block"; + case kSparseBsc: + return "row block"; + default: + TORCH_CHECK(false, "Not a sparse compressed layout:", layout); + return ""; + } +} + +inline size_t rowDimension(Layout layout, IntArrayRef size) { + return size.size() - (isCompressedRow(layout) ? 2 : 1); +} + +inline size_t columnDimension(Layout layout, IntArrayRef size) { + return size.size() - (isCompressedColumn(layout) ? 2 : 1); +} + +inline size_t compressedDimension( + Layout layout, + IntArrayRef size, + size_t dense_ndim = 0) { + return size.size() - dense_ndim - (isCompressedRow(layout) ? 2 : 1); +} + +inline size_t plainDimension( + Layout layout, + IntArrayRef size, + size_t dense_ndim = 0) { + return size.size() - dense_ndim - (isCompressedRow(layout) ? 1 : 2); +} + +inline int64_t numBatchDimensions(Tensor const& self) { + return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS( + self.layout(), + "numBatchDimensions", + [&self] { return self.crow_indices().dim() - 1; }, + [&self] { return self.ccol_indices().dim() - 1; }); +} + +inline std::pair getCompressedPlainIndices(Tensor const& self) { + return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS( + self.layout(), + "getCompressedPlainIndices", + [&self] { + return std::make_pair(self.crow_indices(), self.col_indices()); + }, + [&self] { + return std::make_pair(self.ccol_indices(), self.row_indices()); + }); +} + +inline ScalarType getIndexDtype(Tensor const& self) { + switch (self.layout()) { + case kSparseCsr: + case kSparseBsr: + return self.crow_indices().scalar_type(); + case kSparseCsc: + case kSparseBsc: + return self.ccol_indices().scalar_type(); + case kSparse: + return self._indices().scalar_type(); + default: + return ScalarType::Long; + } +} + +inline Layout flip_compressed_layout(Layout layout) { + switch (layout) { + case kSparseCsr: + return kSparseCsc; + case kSparseCsc: + 
return kSparseCsr; + case kSparseBsr: + return kSparseBsc; + case kSparseBsc: + return kSparseBsr; + default: + TORCH_CHECK(false, "Not a sparse compressed layout:", layout); + return kSparseCsr; + } +} + +inline DimVector getBlockSize(Tensor const& self) { + int64_t n_batch = numBatchDimensions(self); + return at::DimVector(self.values().sizes().slice(n_batch + 1, 2)); +} + +inline at::OptionalArray getSymIntBlockSize(Tensor const& self) { + if (self.layout() == at::kSparseBsr || self.layout() == at::kSparseBsc) { + int64_t n_batch = numBatchDimensions(self); + return self.values().sym_sizes().slice(n_batch + 1, 2).vec(); + } else { + return {}; + } +} + +template +inline bool only_sparse_compressed_binary_op_trivial_cases( + const Tensor& self, + const Tensor& other, + const Scalar& alpha, + Tensor& out, + const binary_op_t& binary_op, + const binary_op_out_t& binary_op_out) { + // Only sparse compressed! Just like the name says :) + TORCH_INTERNAL_ASSERT(at::sparse_csr::is_sparse_compressed(self)); + TORCH_INTERNAL_ASSERT(at::sparse_csr::is_sparse_compressed(other)); + TORCH_INTERNAL_ASSERT(at::sparse_csr::is_sparse_compressed(out)); + + // Bypass BLAS if there are matches in (self, other, out) + if (self.is_same(out) && self.is_same(other)) { + binary_op_out(self.values(), other.values(), alpha); + return true; + } + if (self.is_same(other)) { + auto [compressed_indices, plain_indices] = + at::sparse_csr::getCompressedPlainIndices(self); + static_cast(out.unsafeGetTensorImpl()) + ->set_member_tensors( + compressed_indices, + plain_indices, + binary_op(self.values(), other.values(), alpha), + self.sizes()); + return true; + } + return false; +} + +inline bool only_sparse_compressed_add_trivial_cases( + const Tensor& self, + const Tensor& other, + const Scalar& alpha, + Tensor& out) { + return only_sparse_compressed_binary_op_trivial_cases( + self, + other, + alpha, + out, + [](const Tensor& v1, const Tensor& v2, const Scalar& alpha) { + return v1.add(v2, alpha); 
+ }, + [](const Tensor& v1, const Tensor& v2, const Scalar& alpha) { + return v1.add_(v2, alpha); + }); +} + +inline Tensor to_type(const Tensor& input, ScalarType dtype) { + auto [compressed_indices, plain_indices] = + at::sparse_csr::getCompressedPlainIndices(input); + return at::_sparse_compressed_tensor_unsafe( + compressed_indices, + plain_indices, + std::move(input.values()).to(dtype), + input.sizes(), + dtype, + input.layout(), + input.device(), + input.options().pinned_memory_opt()); +} + +template +inline std::tuple create_acc_buffer( + TensorOptions option, + ScalarType type, + int64_t nnz = -1) { + Tensor new_values, new_values_acc; + constexpr bool need_acc = !std::is_same_v; + bool is_integral = at::isIntegralType(type, /*includeBool=*/true); + if constexpr (need_acc) { + auto acc_dtype = CppTypeToScalarType::value; + new_values_acc = at::empty({}, option.dtype(acc_dtype)); + new_values = is_integral ? new_values_acc : at::empty({}, option); + } else { + new_values = new_values_acc = at::empty({}, option); + } + if (nnz != -1) { + return std::make_tuple( + new_values.resize_(nnz), new_values_acc.resize_(nnz)); + } else { + return std::make_tuple(new_values, new_values_acc); + } +} + +inline void copy_from_acc_buffer(Tensor& new_values, Tensor& new_values_acc) { + if (!new_values_acc.is_same(new_values)) { + new_values.copy_(new_values_acc); + } +} + +} // namespace at::sparse_csr + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." 
+#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/SparseTensorImpl.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/SparseTensorImpl.h new file mode 100644 index 0000000000000000000000000000000000000000..0243667051e43df38a137957edaa608f06210da2 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/SparseTensorImpl.h @@ -0,0 +1,428 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#include +#endif + +namespace at { +struct TORCH_API SparseTensorImpl : public TensorImpl { + // Stored in COO format, indices + values. + + // INVARIANTS: + // sparse_dim: range [0, len(shape)]; sparse_dim + dense_dim = len(shape) + // dense_dim : range [0, len(shape)]; sparse_dim + dense_dim = len(shape) + // _indices.shape: dimensionality: 2, shape: (sparse_dim, nnz) + // _values.shape: dimensionality: 1 + dense_dim. shape: (nnz, + // shape[sparse_dim:]) + + int64_t sparse_dim_ = 0; // number of sparse dimensions + int64_t dense_dim_ = 0; // number of dense dimensions + + Tensor indices_; // always a LongTensor + Tensor values_; + + // A sparse tensor is 'coalesced' if every index occurs at most once in + // the indices tensor, and the indices are in sorted order. (This means + // that it is very easy to convert a coalesced tensor to CSR format: you + // need only compute CSR format indices.) + // + // Most math operations can only be performed on coalesced sparse tensors, + // because many algorithms proceed by merging two sorted lists (of indices). + bool coalesced_ = false; + + // compute_numel with integer multiplication overflow check, see gh-57542 + void refresh_numel() { + TensorImpl::safe_refresh_numel(); + } + + public: + // Public for now... 
+ explicit SparseTensorImpl( + at::DispatchKeySet /*key_set*/, + const caffe2::TypeMeta /*data_type*/); + + void release_resources() override; + + int64_t nnz() const { + return values_.size(0); + } + + c10::SymInt sym_nnz() const { + return values_.sym_size(0); + } + int64_t sparse_dim() const { + return sparse_dim_; + } + int64_t dense_dim() const { + return dense_dim_; + } + bool coalesced() const { + return coalesced_; + } + Tensor indices() const { + return indices_; + } + Tensor values() const { + return values_; + } + + void set_size(int64_t dim, int64_t new_size) override; + void set_stride(int64_t dim, int64_t new_stride) override; + void set_storage_offset(int64_t storage_offset) override; + +#ifdef DEBUG + bool has_storage() const override; +#endif + + // WARNING: This function does NOT preserve invariants of sparse_dim/dense_dim + // with respect to indices and values + void raw_resize_(int64_t sparse_dim, int64_t dense_dim, IntArrayRef size) { + TORCH_CHECK( + allow_tensor_metadata_change(), + "raw_resize_ ", + err_msg_tensor_metadata_change_not_allowed); + TORCH_CHECK( + !has_symbolic_sizes_strides_, + "raw_resize_ called on tensor with symbolic shape") + set_sizes_and_strides(size, std::vector(size.size())); + sparse_dim_ = sparse_dim; + dense_dim_ = dense_dim; + refresh_numel(); + } + + // NOTE: This function preserves invariants of sparse_dim/dense_dim with + // respect to indices and values. + // + // NOTE: This function supports the following cases: + // 1. When we keep the number of dense dimensions unchanged, and NOT shrinking + // the size of any of the dense dimensions. + // 2. When we keep the number of sparse dimensions unchanged, and NOT + // shrinking the size of any of the sparse dimensions. + // 3. When the sparse tensor has zero nnz, in which case we are free to change + // the shapes of both its sparse and dense dimensions. + // + // This function DOESN'T support (and will throw an error) the following + // cases: + // 1. 
When we attempt to change the number of sparse dimensions on a non-empty + // sparse tensor (such an operation will invalidate the indices stored). + // 2. When we attempt to change the number of dense dimensions on a non-empty + // sparse tensor (such an operation will behave differently from an equivalent + // dense tensor's resize method, and for API consistency we don't support it). + // 3. When we attempt to shrink the size of any of the dense dimensions on a + // non-empty sparse tensor (such an operation will behave differently from an + // equivalent dense tensor's resize method, and for API consistency we don't + // support it). + // 4. When we attempt to shrink the size of any of the sparse dimensions on a + // non-empty sparse tensor (this could make some of the stored indices + // out-of-bound and thus unsafe). + template + void _resize_(int64_t sparse_dim, int64_t dense_dim, ArrayRef size) { + TORCH_CHECK( + allow_tensor_metadata_change(), + "resize_ ", + err_msg_tensor_metadata_change_not_allowed); + TORCH_CHECK( + !has_symbolic_sizes_strides_, + "resize_ called on tensor with symbolic shape") + TORCH_CHECK( + sparse_dim + dense_dim == static_cast(size.size()), + "'len(size) == sparse_dim + dense_dim' is not satisfied: len(size) = ", + size.size(), + ", sparse_dim = ", + sparse_dim, + ", dense_dim = ", + dense_dim); + if (nnz() > 0) { + [[maybe_unused]] auto constexpr alt_options_msg = + "You could try the following options:\n\ +1. If you need an empty sparse tensor of this size, call `x = torch.sparse_coo_tensor(size)`.\n\ +2. If you need to resize this tensor, you have the following options:\n\ + 1. For both sparse and dense dimensions, keep the number of them constant and the size of them non-shrinking, and then try the same call again.\n\ + 2. 
Or, create a new sparse tensor with the correct indices and values from this sparse tensor."; + + TORCH_CHECK( + sparse_dim == sparse_dim_, + "changing the number of sparse dimensions (from ", + sparse_dim_, + " to ", + sparse_dim, + ") on a non-empty sparse tensor is not supported.\n", + alt_options_msg); + + TORCH_CHECK( + dense_dim == dense_dim_, + "changing the number of dense dimensions (from ", + dense_dim_, + " to ", + dense_dim, + ") on a non-empty sparse tensor is not supported.\n", + alt_options_msg); + + bool shrinking_sparse_dims = false; + bool shrinking_dense_dim = false; + auto sparse_size_original = generic_sizes().slice(0, sparse_dim); + auto sparse_size_new = size.slice(0, sparse_dim); + for (const auto i : c10::irange(sparse_dim)) { + if (sparse_size_new[i] < sparse_size_original[i]) { + shrinking_sparse_dims = true; + break; + } + } + auto dense_size_original = generic_sizes().slice(sparse_dim); + auto dense_size_new = size.slice(sparse_dim); + for (const auto i : c10::irange(dense_dim)) { + if (dense_size_new[i] < dense_size_original[i]) { + shrinking_dense_dim = true; + break; + } + } + + TORCH_CHECK( + !shrinking_sparse_dims, + "shrinking the size of sparse dimensions (from ", + sparse_size_original, + " to ", + sparse_size_new, + ") on a non-empty sparse tensor is not supported.\n", + alt_options_msg); + + TORCH_CHECK( + !shrinking_dense_dim, + "shrinking the size of dense dimensions (from ", + dense_size_original, + " to ", + dense_size_new, + ") on a non-empty sparse tensor is not supported.\n", + alt_options_msg); + } + + auto sizes_and_strides = generic_sizes(); + const bool size_equals_sizes = std::equal( + size.begin(), + size.end(), + sizes_and_strides.begin(), + sizes_and_strides.end()); + if ((!size_equals_sizes) || (sparse_dim != sparse_dim_) || + (dense_dim != dense_dim_)) { + auto nnz = at::symint::sizes(values())[0]; + std::vector values_size = {nnz}; + auto dense_size = size.slice(sparse_dim); + values_size.insert( + 
values_size.end(), dense_size.begin(), dense_size.end()); + at::symint::resize_(values_, values_size); + at::symint::resize_(indices_, {T(sparse_dim), nnz}); + } + + if (!size_equals_sizes) { + set_sizes_and_strides(size, std::vector(size.size())); + } + sparse_dim_ = sparse_dim; + dense_dim_ = dense_dim; + refresh_numel(); + } + + void resize_(int64_t sparse_dim, int64_t dense_dim, ArrayRef size) { + _resize_(sparse_dim, dense_dim, size); + } + + void resize_( + int64_t sparse_dim, + int64_t dense_dim, + ArrayRef size) { + _resize_(sparse_dim, dense_dim, size); + } + + // NOTE: this function will resize the sparse tensor and also set `indices` + // and `values` to empty. + void resize_and_clear_( + int64_t sparse_dim, + int64_t dense_dim, + IntArrayRef size) { + TORCH_CHECK( + allow_tensor_metadata_change(), + "resize_and_clear_ ", + err_msg_tensor_metadata_change_not_allowed); + TORCH_CHECK( + !has_symbolic_sizes_strides_, + "resize_and_clear_ called on tensor with symbolic shape") + TORCH_CHECK( + sparse_dim + dense_dim == static_cast(size.size()), + "'len(size) == sparse_dim + dense_dim' is not satisfied: len(size) = ", + size.size(), + ", sparse_dim = ", + sparse_dim, + ", dense_dim = ", + dense_dim); + + set_sizes_and_strides(size, std::vector(size.size())); + sparse_dim_ = sparse_dim; + dense_dim_ = dense_dim; + + auto empty_indices = at::empty({sparse_dim, 0}, indices().options()); + std::vector values_size = {0}; + auto dense_size = sizes().slice(sparse_dim); + values_size.insert(values_size.end(), dense_size.begin(), dense_size.end()); + auto empty_values = at::empty(values_size, values().options()); + set_indices_and_values_unsafe(empty_indices, empty_values); + refresh_numel(); + } + + void set_coalesced(bool coalesced) { + TORCH_CHECK( + allow_tensor_metadata_change(), + "set_coalesced ", + err_msg_tensor_metadata_change_not_allowed); + coalesced_ = coalesced; + } + + // NOTE: this function is only used internally and not exposed to Python + // 
frontend + void set_nnz_and_narrow(int64_t new_nnz) { + TORCH_CHECK( + allow_tensor_metadata_change(), + "set_nnz_and_narrow ", + err_msg_tensor_metadata_change_not_allowed); + AT_ASSERT(new_nnz <= nnz()); + indices_ = indices_.narrow(1, 0, new_nnz); + values_ = values_.narrow(0, 0, new_nnz); + if (new_nnz < 2) { + coalesced_ = true; + } + } + + // Takes indices and values and directly puts them into the sparse tensor, no + // copy. NOTE: this function is unsafe because it doesn't check whether any + // indices are out of boundaries of `sizes`, so it should ONLY be used where + // we know that the indices are guaranteed to be within bounds. This used to + // be called THSTensor_(_move) NB: This used to be able to avoid a refcount + // bump, but I was too lazy to make it happen + void set_indices_and_values_unsafe( + const Tensor& indices, + const Tensor& values); + + template + c10::intrusive_ptr shallow_copy_and_detach_core( + VariableVersion&& version_counter, + bool allow_tensor_metadata_change) const { + const auto mode_stack_len = c10::impl::TorchDispatchModeTLS::stack_len(); + c10::impl::PyInterpreter&& interpreter = nullptr; + if (mode_stack_len > 0 && + !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::Python)) { + const auto& cur_torch_dispatch_mode_state = + c10::impl::TorchDispatchModeTLS::get_stack_at(mode_stack_len - 1); + interpreter = cur_torch_dispatch_mode_state->pyinterpreter(); + } else if ( + key_set_.has(DispatchKey::Python) && + !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::Python)) { + interpreter = pyobj_slot_.load_pyobj_interpreter(); + } else { + // otherwise just copy the SparseTensorImpl and not the PyObject. 
+ auto impl = c10::make_intrusive(key_set(), dtype()); + copy_tensor_metadata( + /*src_sparse_impl=*/this, + /*dest_sparse_impl=*/impl.get(), + /*version_counter=*/version_counter, + /*allow_tensor_metadata_change=*/allow_tensor_metadata_change); + impl->refresh_numel(); + return impl; + } + auto r = interpreter->detach(this); + r->set_version_counter(std::forward(version_counter)); + r->set_allow_tensor_metadata_change(allow_tensor_metadata_change); + return r; + } + + /** + * Return a TensorImpl that is a shallow-copy of this TensorImpl. + * + * For usage of `version_counter` and `allow_tensor_metadata_change`, + * see NOTE [ TensorImpl Shallow-Copying ]. + */ + c10::intrusive_ptr shallow_copy_and_detach( + const c10::VariableVersion& version_counter, + bool allow_tensor_metadata_change) const override { + return shallow_copy_and_detach_core( + version_counter, allow_tensor_metadata_change); + } + + /** + * Return a TensorImpl that is a shallow-copy of this TensorImpl. + * + * For usage of `version_counter` and `allow_tensor_metadata_change`, + * see NOTE [ TensorImpl Shallow-Copying ]. + */ + c10::intrusive_ptr shallow_copy_and_detach( + c10::VariableVersion&& version_counter, + bool allow_tensor_metadata_change) const override { + return shallow_copy_and_detach_core( + std::move(version_counter), allow_tensor_metadata_change); + } + + /** + * Shallow-copies data from another TensorImpl into this TensorImpl. + * + * For why this function doesn't check this TensorImpl's + * `allow_tensor_metadata_change_`, see NOTE [ TensorImpl Shallow-Copying ]. 
+ */ + void shallow_copy_from(const c10::intrusive_ptr& impl) override { + AT_ASSERT(has_compatible_shallow_copy_type(impl->key_set())); + auto sparse_impl = static_cast(impl.get()); + copy_tensor_metadata( + /*src_sparse_impl=*/sparse_impl, + /*dest_sparse_impl=*/this, + /*version_counter=*/version_counter(), + /*allow_tensor_metadata_change=*/allow_tensor_metadata_change()); + refresh_numel(); + } + + private: + explicit SparseTensorImpl( + at::DispatchKeySet /*key_set*/, + const caffe2::TypeMeta /*data_type*/, + at::Tensor indices, + at::Tensor values); + + /** + * Copy the tensor metadata fields (e.g. sizes / strides / storage pointer / + * storage_offset) from one TensorImpl to another TensorImpl. + * + * For usage of `version_counter` and `allow_tensor_metadata_change`, see NOTE + * [ TensorImpl Shallow-Copying ]. + */ + static void copy_tensor_metadata( + const SparseTensorImpl* src_sparse_impl, + SparseTensorImpl* dest_sparse_impl, + c10::VariableVersion version_counter, + bool allow_tensor_metadata_change) { + TensorImpl::copy_tensor_metadata( + src_sparse_impl, + dest_sparse_impl, + std::move(version_counter), + allow_tensor_metadata_change); + + // Sparse-specific fields + dest_sparse_impl->sparse_dim_ = src_sparse_impl->sparse_dim(); + dest_sparse_impl->dense_dim_ = src_sparse_impl->dense_dim(); + dest_sparse_impl->indices_ = src_sparse_impl->indices(); + dest_sparse_impl->values_ = src_sparse_impl->values(); + dest_sparse_impl->coalesced_ = src_sparse_impl->coalesced(); + } + + const char* tensorimpl_type_name() const override; +}; + +} // namespace at + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." 
+#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/Storage.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/Storage.h new file mode 100644 index 0000000000000000000000000000000000000000..366b0c17dc5db3a8c7adbca030ee8f8ac30d6e3e --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/Storage.h @@ -0,0 +1,7 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +#include + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/StorageUtils.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/StorageUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..9647f6411043c238b34f4be3bea9532f2448f69b --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/StorageUtils.h @@ -0,0 +1,54 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include + +namespace at { + +class TensorBase; + +// Here we define a series of utils to create/manipulate ATen backed +// c10 storage implementations. + +/** + * Create a new shared memory storage impl managed by file descriptor + * + * @param size size in bytes + */ +C10_EXPORT c10::intrusive_ptr new_shm_fd_storage(size_t size); + +/** + * Copy src to dst + * Caller must guarantee the validness of the storage objects + * during the entire copy process, esp. when it's async. + * + * This can probably live in c10 namespace later if needed, + * but for now keep it in at to keep implementation simple. 
+ * + * @param dst dst tensor + * @param src src tensor + * @param non_blocking (default false) whether this operation blocks caller + */ +C10_EXPORT void storage_copy( + c10::Storage& dst, + const c10::Storage& src, + bool non_blocking = false); + +/** + * In place change the storage to shm based. + * + * This is only applicable to CPU tensors not already shared. + * Otherwise, it's a no op to mirror the THP tensor behavior: + * https://pytorch.org/docs/stable/generated/torch.Tensor.share_memory_.html + * + * @param t a tensor + */ +C10_EXPORT void share_memory_(TensorBase& t); + +} // namespace at + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/Tensor.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/Tensor.h new file mode 100644 index 0000000000000000000000000000000000000000..5f9aa4c4648b9623282e1daf26645b147f3d02d1 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/Tensor.h @@ -0,0 +1,8 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." 
+#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/TensorAccessor.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/TensorAccessor.h new file mode 100644 index 0000000000000000000000000000000000000000..c4b966f6a421f60ca1a31f299e240a06d471f31c --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/TensorAccessor.h @@ -0,0 +1,7 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +#include + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/TensorGeometry.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/TensorGeometry.h new file mode 100644 index 0000000000000000000000000000000000000000..138df4286f1c1c4933fb7d774ba7a1dd49b723fa --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/TensorGeometry.h @@ -0,0 +1,159 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include + +namespace at { + +// Return if the tensor geometry represented by `sizes` and `strides` is +// contiguous Although we cache is_contiguous in tensor now, this is till useful +// because it allows checking if a particular geometry is contiguous without +// explicitly constructing a tensor, e.g., when you want to choose a kernel +// strategy based on whether a subgeometry is contiguous. 
+TORCH_API bool geometry_is_contiguous(IntArrayRef sizes, IntArrayRef strides); + +struct TORCH_API TensorGeometry { + TensorGeometry() = default; + + explicit TensorGeometry(c10::SymIntArrayRef sizes) + : sizes_(sizes.vec()), + strides_(sizes.size()), + has_symbolic_sizes_strides_( + !c10::asIntArrayRefSlowOpt(sizes).has_value()) { + int64_t dim = static_cast(sizes.size()); + c10::SymInt expected_stride = 1; + for (int64_t i = dim - 1; i >= 0; i--) { + strides_[i] = expected_stride; + expected_stride *= sizes_[i]; + } + numel_ = expected_stride; + } + + explicit TensorGeometry(const TensorBase& t) + : sizes_(t.sym_sizes().vec()), + strides_(t.sym_strides().vec()), + storage_offset_(t.sym_storage_offset()), + numel_(t.sym_numel()), + has_symbolic_sizes_strides_( + t.unsafeGetTensorImpl()->has_symbolic_sizes_strides()) {} + + explicit TensorGeometry( + std::vector sizes, + std::vector strides, + at::SymInt storage_offset) + : sizes_(std::move(sizes)), + strides_(std::move(strides)), + storage_offset_(std::move(storage_offset)) { + recompute(); + } + + // true if the tensor is contiguous + bool is_contiguous() const; + + int64_t dim() const { + return static_cast(sizes_.size()); + } + + int64_t size(int64_t dim) const { + TORCH_INTERNAL_ASSERT(!has_symbolic_sizes_strides_); + dim = c10::maybe_wrap_dim(dim, this->dim()); + return sizes_.at(static_cast(dim)).as_int_unchecked(); + } + c10::IntArrayRef sizes() const { + TORCH_INTERNAL_ASSERT(!has_symbolic_sizes_strides_); + return c10::asIntArrayRefUnchecked(sizes_); + } + int64_t stride(int64_t dim) const { + TORCH_INTERNAL_ASSERT(!has_symbolic_sizes_strides_); + dim = c10::maybe_wrap_dim(dim, this->dim()); + return strides_.at(static_cast(dim)).as_int_unchecked(); + } + c10::IntArrayRef strides() const { + TORCH_INTERNAL_ASSERT(!has_symbolic_sizes_strides_); + return c10::asIntArrayRefUnchecked(strides_); + } + int64_t storage_offset() const { + TORCH_INTERNAL_ASSERT(!has_symbolic_sizes_strides_); + return 
storage_offset_.as_int_unchecked(); + } + int64_t numel() const { + TORCH_INTERNAL_ASSERT(!has_symbolic_sizes_strides_); + return numel_.as_int_unchecked(); + } + + c10::SymInt sym_size(int64_t dim) const { + dim = c10::maybe_wrap_dim(dim, this->dim()); + return sizes_.at(static_cast(dim)); + } + c10::SymIntArrayRef sym_sizes() const { + return sizes_; + } + c10::SymInt sym_stride(int64_t dim) const { + dim = c10::maybe_wrap_dim(dim, this->dim()); + return strides_.at(static_cast(dim)); + } + c10::SymIntArrayRef sym_strides() const { + return strides_; + } + c10::SymInt sym_storage_offset() const { + return storage_offset_; + } + c10::SymInt sym_numel() const { + return numel_; + } + + TensorGeometry transpose(int64_t dim0, int64_t dim1) { + TensorGeometry r = *this; // copy + TORCH_CHECK( + dim0 < dim(), + "transpose: dim0=", + dim0, + " out of range (dim=", + dim(), + ")") + TORCH_CHECK( + dim1 < dim(), + "transpose: dim1=", + dim1, + " out of range (dim=", + dim(), + ")") + std::swap(r.sizes_[dim0], r.sizes_[dim1]); + std::swap(r.strides_[dim0], r.strides_[dim1]); + return r; + } + + std::vector& mutable_sizes() { + return sizes_; + } + std::vector& mutable_strides() { + return strides_; + } + c10::SymInt& mutable_storage_offset() { + return storage_offset_; + } + void recompute() { + // recalculate numel after a change + c10::SymInt numel = 1; + for (const auto& i : sizes_) { + numel = numel * i; + } + numel_ = std::move(numel); + has_symbolic_sizes_strides_ = + !c10::asIntArrayRefSlowOpt(sizes_).has_value(); + } + + private: + std::vector sizes_; + std::vector strides_; + c10::SymInt storage_offset_; + c10::SymInt numel_; + bool has_symbolic_sizes_strides_{false}; +}; + +} // namespace at + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." 
+#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/TensorIndexing.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/TensorIndexing.h new file mode 100644 index 0000000000000000000000000000000000000000..76d8f282920e9b283c24fee96c9d9f4bc987e6dd --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/TensorIndexing.h @@ -0,0 +1,772 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#endif + +#include + +#include + +namespace at::indexing { + +constexpr int64_t INDEX_MIN = c10::SymInt::min_representable_int(); +constexpr int64_t INDEX_MAX = -(INDEX_MIN + 1); + +enum class TensorIndexType { None, Ellipsis, SymInt, Boolean, Slice, Tensor }; + +constexpr std::nullopt_t None = std::nullopt; + +struct TORCH_API EllipsisIndexType final { + EllipsisIndexType() = default; +}; +TORCH_API extern const EllipsisIndexType Ellipsis; + +struct TORCH_API Slice final { + public: + Slice( + std::optional start_index = std::nullopt, + std::optional stop_index = std::nullopt, + std::optional step_index = std::nullopt) { + if (!step_index.has_value()) { + step_ = c10::SymInt(1); + } else { + step_ = std::move(step_index).value(); + } + + TORCH_CHECK_VALUE( + step_.sym_ne(0).expect_true(__FILE__, __LINE__), + "slice step cannot be zero"); + + if (!start_index.has_value()) { + start_ = c10::SymInt(step_ < 0 ? INDEX_MAX : 0); + } else { + start_ = std::move(start_index).value(); + } + + if (!stop_index.has_value()) { + stop_ = c10::SymInt(step_ < 0 ? 
INDEX_MIN : INDEX_MAX); + } else { + stop_ = std::move(stop_index).value(); + } + } + + inline c10::SymInt start() const { + return start_; + } + + inline c10::SymInt stop() const { + return stop_; + } + + inline c10::SymInt step() const { + return step_; + } + + private: + c10::SymInt start_; + c10::SymInt stop_; + c10::SymInt step_; +}; + +TORCH_API std::ostream& operator<<(std::ostream& stream, const Slice& slice); + +// `at::indexing::TensorIndex` is used for converting C++ tensor indices such as +// `{None, "...", Ellipsis, 0, true, Slice(1, None, 2), torch::tensor({1, 2})}` +// into its equivalent `std::vector`, so that further tensor +// indexing operations can be performed using the supplied indices. +// +// There is one-to-one correspondence between Python and C++ tensor index types: +// Python | C++ +// ----------------------------------------------------- +// `None` | `at::indexing::None` +// `Ellipsis` | `at::indexing::Ellipsis` +// `...` | `"..."` +// `123` | `123` +// `True` / `False` | `true` / `false` +// `:` | `Slice()` / `Slice(None, None)` +// `::` | `Slice()` / `Slice(None, None, None)` +// `1:` | `Slice(1, None)` +// `1::` | `Slice(1, None, None)` +// `:3` | `Slice(None, 3)` +// `:3:` | `Slice(None, 3, None)` +// `::2` | `Slice(None, None, 2)` +// `1:3` | `Slice(1, 3)` +// `1::2` | `Slice(1, None, 2)` +// `:3:2` | `Slice(None, 3, 2)` +// `1:3:2` | `Slice(1, 3, 2)` +// `torch.tensor([1, 2])`) | `torch::tensor({1, 2})` +struct TORCH_API TensorIndex final { + // Case 1: `at::indexing::None` + TensorIndex(std::nullopt_t /*unused*/) : type_(TensorIndexType::None) {} + + // Case 2: "..." 
/ `at::indexing::Ellipsis` + TensorIndex(at::indexing::EllipsisIndexType /*unused*/) + : type_(TensorIndexType::Ellipsis) {} + TensorIndex(const char* str) : TensorIndex(at::indexing::Ellipsis) { + TORCH_CHECK_VALUE( + strcmp(str, "...") == 0, + "Expected \"...\" to represent an ellipsis index, but got \"", + str, + "\""); + } + + // Case 3: (Sym) Integer value + TensorIndex(SymInt integer) + : integer_(std::move(integer)), type_(TensorIndexType::SymInt) {} + TensorIndex(int64_t integer) : TensorIndex(SymInt(integer)) {} + TensorIndex(int integer) : TensorIndex(SymInt(integer)) {} + + // Case 4: Boolean value + template >> + TensorIndex(T boolean) : boolean_(boolean), type_(TensorIndexType::Boolean) {} + + // Case 5: Slice represented in `at::indexing::Slice` form + TensorIndex(Slice slice) + : slice_(std::move(slice)), type_(TensorIndexType::Slice) {} + + // Case 6: Tensor value + TensorIndex(Tensor tensor) + : tensor_(std::move(tensor)), type_(TensorIndexType::Tensor) {} + + inline bool is_none() const { + return type_ == TensorIndexType::None; + } + + inline bool is_ellipsis() const { + return type_ == TensorIndexType::Ellipsis; + } + + inline bool is_integer() const { + return type_ == TensorIndexType::SymInt; + } + + inline SymInt integer() const { + return integer_; + } + + inline bool is_boolean() const { + return type_ == TensorIndexType::Boolean; + } + + inline bool boolean() const { + return boolean_; + } + + inline bool is_slice() const { + return type_ == TensorIndexType::Slice; + } + + inline const Slice& slice() const { + return slice_; + } + + inline bool is_tensor() const { + return type_ == TensorIndexType::Tensor; + } + + inline const Tensor& tensor() const { + return tensor_; + } + + private: + SymInt integer_ = 0; + bool boolean_ = false; + Slice slice_; + Tensor tensor_; + TensorIndexType type_; +}; + +TORCH_API std::ostream& operator<<( + std::ostream& stream, + const TensorIndex& tensor_index); +TORCH_API std::ostream& operator<<( + 
std::ostream& stream, + const std::vector& tensor_indices); + +namespace impl { +inline Tensor applySlice( + const Tensor& self, + int64_t dim, + c10::SymInt start, + c10::SymInt stop, + c10::SymInt step, + bool disable_slice_optimization, + const at::Device& self_device, + const std::optional& self_sizes) { + // TODO: implement negative step + TORCH_CHECK_VALUE( + step.sym_gt(0).expect_true(__FILE__, __LINE__), + "step must be greater than zero"); + + // See NOTE [nested tensor size for indexing] + if (self_sizes.has_value() && !self_sizes.value().empty()) { + // Skip this optimization if we are tracing, as the trace may be polymorphic + // over the shape of the `self` tensor, and we still want to record + // the slice. + SymInt length = (self_device == at::kCPU || self_device == at::kCUDA) + ? (*self_sizes)[dim] + : self.sym_size(dim); + if (!disable_slice_optimization && + TORCH_STATICALLY_KNOWN_TRUE(start.sym_eq(0)) && + TORCH_STATICALLY_KNOWN_TRUE(length.sym_le(stop)) && step == 1) { + return self; + } + } + return self.slice_symint( + dim, std::move(start), std::move(stop), std::move(step)); +} + +inline Tensor applySelect( + const Tensor& self, + int64_t dim, + SymInt index, + int64_t real_dim, + const at::Device& /*self_device*/, + const std::optional& self_sizes) { + // See NOTE [nested tensor size for indexing] + if (self_sizes.has_value()) { + auto maybe_index = index.maybe_as_int(); + if (maybe_index.has_value()) { + TORCH_CHECK_INDEX( + !(maybe_index.value() == 0 && dim == 0 && self_sizes->empty()), + "invalid index of a 0-dim tensor. ", + "Use `tensor.item()` in Python or `tensor.item()` in C++ to convert a 0-dim tensor to a number"); + } + + auto size = (*self_sizes)[dim]; + // Note: `size >= -index` is not equivalent to `size > -1 - index` if index + // is INT64_MIN For std::numeric_limits::min() result of unary + // minus is undefined by the standard but in practice is equal to self. 
On + // the other hand, indexing wrapping is valid for all negative int64_t + // values, as x[INT64_MIN] is the same as x[INT64_MAX] + TORCH_CHECK_INDEX( + size.sym_gt(-1 - index) + .sym_and(size.sym_gt(index)) + .expect_true(__FILE__, __LINE__), + "index ", + index, + " is out of bounds for dimension ", + real_dim, + " with size ", + size); + } + + // if the index is negative, do not normalize it because that would fix the + // index on the current tensor size in the tracer. aten::select also works on + // negative indices + return self.select_symint(dim, std::move(index)); +} + +inline Tensor boolToIndexingTensorCPUOrCUDA(const Tensor& self, bool value) { + // booleans add a dimension of size 1. true indexes this dimension as if 0:, + // false as empty. + if (value) { + return at::empty({1}, self.options().dtype(kLong)).fill_(0.); + } else { + return at::empty({0}, self.options().dtype(kLong)); + } +} + +inline Tensor boolToIndexingTensorNonNativeDeviceType( + const Tensor& self, + bool value) { + // booleans add a dimension of size 1. true indexes this dimension as if 0:, + // false as empty. 
+ if (value) { + return at::zeros({1}, self.options().dtype(kLong)); + } else { + return at::empty({0}, self.options().dtype(kLong)); + } +} + +inline Tensor boolToIndexingTensor( + const Tensor& self, + bool value, + const at::Device& self_device) { + if (self_device == at::kCPU || self_device == at::kCUDA) { + return boolToIndexingTensorCPUOrCUDA(self, value); + } else { + return boolToIndexingTensorNonNativeDeviceType(self, value); + } +} + +inline Tensor scalarToTensorNonNativeDeviceType( + const Scalar& v, + const TensorOptions& options) { + return at::scalar_tensor(v, options); +} + +inline void recordTensorIndex( + const Tensor& tensor, + std::vector& outIndices, + int64_t* dim_ptr) { + if (outIndices.empty()) { + outIndices.resize(*dim_ptr + 1); + outIndices[*dim_ptr] = tensor; + } else { + outIndices.push_back(tensor); + } + if (tensor.scalar_type() == kByte || tensor.scalar_type() == kBool) { + *dim_ptr += tensor.dim(); + } else { + *dim_ptr += 1; + } +} + +inline c10::List<::std::optional> typeConvertIndices( + const Tensor& /*self*/, + std::vector&& indices) { + c10::List<::std::optional> converted_inds; + converted_inds.reserve(indices.size()); + for (auto&& i : std::move(indices)) { + converted_inds.push_back(std::move(i)); + } + return converted_inds; +} + +// NOTE: Why do we mirror instead of replace the `count_specified_dimensions` +// function in torch/csrc/autograd/python_variable_indexing.cpp? It's because +// `count_specified_dimensions` is on the hot path of Python tensor multi-dim +// indexing (i.e. it's called by `applySlicing` which is called by +// `THPVariable_getitem` / `THPVariable_setitem` when handling indexing of more +// than one dimension). If we were to merge the Python/C++ +// `count_specified_dimensions` function, on the Python side we would have to +// construct a `std::vector` container to be consumed by the C++ +// `count_specified_dimensions` function, which adds 100s of nanoseconds +// overhead and is undesirable. 
+inline int64_t count_specified_dimensions( + const ArrayRef& indices) { + // Count the number of indexed dimensions (everything but ellipsis and None) + int64_t count = 0; + for (auto& obj : indices) { + if (obj.is_tensor()) { + auto& tensor = obj.tensor(); + if (tensor.scalar_type() == kByte || tensor.scalar_type() == kBool) { + count += tensor.dim(); + } else { + count++; + } + } else if (!obj.is_none() && !obj.is_ellipsis() && !obj.is_boolean()) { + count++; + } + } + return count; +} +} // namespace impl + +// NOTE: Many functions below are only for consumption from Python indexing +// implementation, they include: +// +// - `Tensor scalarToTensor(...)` +// - `IntArrayRef slicePrefix1sSize(...)` +// - `void copy_to(...)` +// - `Tensor handleDimInMultiDimIndexing(...)` +// - `Tensor dispatch_index(...)` +// - `Tensor dispatch_index_put_(...)` +// - `Tensor get_item(...)` +// - `void set_item(...)` +// +// The rest of the functions are in `at::indexing::impl` namespace, signifying +// that they shouldn't be used from Python indexing implementation. 
+inline Tensor scalarToTensor( + const Scalar& v, + const TensorOptions& options, + const at::Device& self_device) { + if (self_device == at::kCPU && !v.isSymbolic()) { + return at::detail::scalar_tensor_static( + v, + // NOLINTNEXTLINE(bugprone-unchecked-optional-access) + options.dtype_opt()->toScalarType(), + self_device); + } else { + return impl::scalarToTensorNonNativeDeviceType(v, options); + } +} + +// To match numpy semantics: +// As a special case for backwards compatibility, +// strip away unit dimensions from the left of 'src' +inline SymIntArrayRef slicePrefix1sSize(const SymIntArrayRef& sizes) { + size_t first_non1_src = sizes.size(); + for (const auto i : c10::irange(sizes.size())) { + // Unbacked SymInt has different behavior, but this is sound because + // failing to slice will only ever cause an error, not divergent + // behavior + if (!sizes[i].has_hint() || sizes[i] != 1) { + first_non1_src = i; + break; + } + } + + return sizes.slice(first_non1_src); +} + +inline void copy_to(const Tensor& dst, const Tensor& src) { + if (dst.sym_sizes().equals(src.sym_sizes())) { + // A shortcut to avoid generating hard-coded constant sizes during tracing. + // This is not a perfect solution: when src & dst have different shapes, + // constants will still appear. Users can workaround that case by + // dst[index..] = src.reshape(..) 
+ dst.copy_(src); + return; + } else if (src.dim() == 0 && src.device().type() == at::kCPU) { + dst.fill_(src); + return; + } + auto src_view = src.view_symint(slicePrefix1sSize(src.sym_sizes())); + c10::MaybeOwned b_src = expand_inplace(dst, src_view, "setitem"); + dst.copy_(*b_src); +} + +// See NOTE [ Setting `disable_slice_optimization` when calling C++ tensor +// indexing functions from Python ] +inline Tensor handleDimInMultiDimIndexing( + const Tensor& prev_dim_result, + const Tensor& original_tensor, + const TensorIndex& index, + int64_t* dim_ptr, + int64_t* specified_dims_ptr, + int64_t real_dim, + std::vector& outIndices, + bool disable_slice_optimization, + const at::Device& original_tensor_device, + const std::optional& prev_dim_result_sizes) { + if (index.is_integer()) { + return impl::applySelect( + prev_dim_result, + *dim_ptr, + index.integer(), + real_dim, + original_tensor_device, + prev_dim_result_sizes); + } else if (index.is_slice()) { + Tensor result = impl::applySlice( + prev_dim_result, + *dim_ptr, + index.slice().start(), + index.slice().stop(), + index.slice().step(), + /*disable_slice_optimization=*/disable_slice_optimization, + original_tensor_device, + prev_dim_result_sizes); + (*dim_ptr)++; + if (!outIndices.empty()) { + outIndices.resize(outIndices.size() + 1); + } + return result; + } else if (index.is_ellipsis()) { + auto ellipsis_ndims = original_tensor.dim() - *specified_dims_ptr; + (*dim_ptr) += ellipsis_ndims; + if (!outIndices.empty()) { + outIndices.resize(outIndices.size() + ellipsis_ndims); + } + return prev_dim_result; + } else if (index.is_none()) { + Tensor result = prev_dim_result.unsqueeze(*dim_ptr); + (*dim_ptr)++; + if (!outIndices.empty()) { + outIndices.resize(outIndices.size() + 1); + } + return result; + } else if (index.is_boolean()) { + Tensor result = prev_dim_result.unsqueeze(*dim_ptr); + impl::recordTensorIndex( + impl::boolToIndexingTensor( + result, index.boolean(), original_tensor_device), + outIndices, + 
dim_ptr); + return result; + } else if (index.is_tensor()) { + Tensor result = prev_dim_result; + const Tensor& tensor = index.tensor(); + auto scalar_type = tensor.scalar_type(); + if (tensor.dim() == 0 && + at::isIntegralType(scalar_type, /*includeBool=*/true)) { + if (scalar_type != at::kByte && scalar_type != at::kBool) { + result = impl::applySelect( + result, + *dim_ptr, + tensor.item(), + real_dim, + original_tensor_device, + prev_dim_result_sizes); + } else { + result = result.unsqueeze(*dim_ptr); + if (scalar_type == at::kBool) { + impl::recordTensorIndex( + impl::boolToIndexingTensor( + result, tensor.item() != 0, original_tensor_device), + outIndices, + dim_ptr); + } else { + impl::recordTensorIndex( + impl::boolToIndexingTensor( + result, tensor.item() != 0, original_tensor_device), + outIndices, + dim_ptr); + } + } + } else { + impl::recordTensorIndex(tensor, outIndices, dim_ptr); + } + return result; + } else { + TORCH_INTERNAL_ASSERT(false, "Invalid TensorIndex type"); + } +} + +namespace impl { +// This mirrors `applySlicing` in +// torch/csrc/autograd/python_variable_indexing.cpp +inline Tensor applySlicing( + const Tensor& self, + const ArrayRef& indices, + std::vector& outIndices, + bool disable_slice_optimization, + const at::Device& self_device, + const std::optional& self_sizes) { + int64_t dim = 0; + int64_t specified_dims = impl::count_specified_dimensions(indices); + + // See NOTE [nested tensor size for indexing] + if (self_sizes.has_value()) { + TORCH_CHECK_INDEX( + specified_dims <= (int64_t)self_sizes->size(), + "too many indices for tensor of dimension ", + (int)self_sizes->size()); + } + + Tensor result = self; + for (const auto i : c10::irange(indices.size())) { + auto& obj = indices[i]; + // See NOTE [nested tensor size for indexing] + std::optional result_sizes = result.is_nested() + ? 
std::optional(std::nullopt) + : std::optional(result.sym_sizes()); + result = handleDimInMultiDimIndexing( + /*prev_dim_result=*/result, + /*original_tensor=*/self, + /*index=*/obj, + /*dim_ptr=*/&dim, + /*specified_dims_ptr=*/&specified_dims, + /*real_dim=*/static_cast(i), + /*outIndices=*/outIndices, + /*disable_slice_optimization=*/disable_slice_optimization, + /*original_tensor_device=*/self_device, + /*prev_dim_result_sizes=*/result_sizes); + } + return result; +} +} // namespace impl + +inline Tensor dispatch_index( + const Tensor& self, + std::vector&& indices) { + // Remove trailing null elements from indices + while (!indices.empty() && !indices.back().defined()) { + indices.pop_back(); + } + return self.index(impl::typeConvertIndices(self, std::move(indices))); +} + +inline Tensor dispatch_index_put_( + Tensor& self, + std::vector&& indices, + const Tensor& value) { + // Remove trailing null elements from indices + while (!indices.empty() && !indices.back().defined()) { + indices.pop_back(); + } + return self.index_put_( + impl::typeConvertIndices(self, std::move(indices)), value); +} + +// NOTE [ Setting `disable_slice_optimization` when calling C++ tensor indexing +// functions from Python ] +// +// Question: When should we set `disable_slice_optimization` to `true` when +// calling C++ tensor indexing functions from Python indexing code? +// +// Answer: What "slice optimization" means: when we have a slicing expression +// like `x[0:5, 0]`, where the sliced tensor was of size 5 in dimension 0, we +// would skip dispatching the actual slice call as an optimization. However, +// here are the cases where we DON'T want this optimization: +// +// 1. When we are doing 1-D slicing (e.g. `tensor[:]`). +// Reason: we always return a shallow copy for expressions such as +// `tensor[:]` / `tensor[...]` / `tensor[:, :]`. 
(Note that for `tensor[:, +// :]`, we return an alias of `tensor` by doing the following: +// ``` +// Tensor sliced = impl::applySlicing(self, indices, tensorIndices, +// disable_slice_optimization, self_device, self_sizes); if +// (tensorIndices.empty()) { +// if (sliced.is_same(self)) { +// // ensure we return a shallow copy for things like x[...] +// sliced = at::alias(sliced); +// } +// return sliced; +// } +// ```) +// 2. When we are doing JIT tracing. +// Reason: JIT tracing needs the `self.slice(...)` call to properly trace the +// slice operation. + +// This mirrors `THPVariable_getitem` in +// torch/csrc/autograd/python_variable_indexing.cpp See NOTE [ Setting +// `disable_slice_optimization` when calling C++ tensor indexing functions from +// Python ] +inline Tensor get_item( + const Tensor& self, + const ArrayRef& indices, + bool disable_slice_optimization = false) { + at::Device self_device = self.device(); + // NOTE [nested tensor size for indexing] + // nested tensor does not have a size (yet) so for now we represent its size + // as null may need to be changed after we reach a better solution for nested + // tensor size + std::optional self_sizes = self.is_nested() + ? 
std::optional(std::nullopt) + : std::optional(self.sym_sizes()); + + // handle simple types: integers, slices, none, ellipsis, bool + if (indices.size() == 1) { + const TensorIndex& index = indices[0]; + if (index.is_integer()) { + return impl::applySelect( + self, 0, index.integer(), 0, self_device, self_sizes); + } else if (index.is_slice()) { + return impl::applySlice( + self, + 0, + index.slice().start(), + index.slice().stop(), + index.slice().step(), + /*disable_slice_optimization=*/true, + self_device, + self_sizes); + } else if (index.is_none()) { + return self.unsqueeze(0); + } else if (index.is_ellipsis()) { + return at::alias(self); + } else if (index.is_boolean()) { + Tensor result = self.unsqueeze(0); + return dispatch_index( + result, + std::vector{impl::boolToIndexingTensor( + result, index.boolean(), self_device)}); + } + } + + std::vector tensorIndices; + Tensor sliced = impl::applySlicing( + self, + indices, + tensorIndices, + disable_slice_optimization, + self_device, + self_sizes); + if (tensorIndices.empty()) { + if (sliced.is_same(self)) { + // ensure we return a shallow copy for things like x[...] 
+ sliced = at::alias(sliced); + } + return sliced; + } + + // indexing by tensors ("advanced" indexing) + return dispatch_index(sliced, std::move(tensorIndices)); +} + +// This mirrors `THPVariable_setitem` in +// torch/csrc/autograd/python_variable_indexing.cpp for "the assigned value is a +// Tensor" case See NOTE [ Setting `disable_slice_optimization` when calling C++ +// tensor indexing functions from Python ] +inline void set_item( + const Tensor& self, + const ArrayRef& indices, + const Tensor& value, + bool disable_slice_optimization = false) { + at::Device self_device = self.device(); + SymIntArrayRef self_sizes = self.sym_sizes(); + + // handle simple types: integers, slices, ellipsis, bool + if (indices.size() == 1) { + const TensorIndex& index = indices[0]; + if (index.is_boolean() && !index.boolean()) { + // do nothing for false (technically we should check the size, but we + // don't have real 0-sized shapes. + return; + } else if (index.is_ellipsis()) { + copy_to(self, value); + return; + } else if (index.is_none() || (index.is_boolean() && index.boolean())) { + copy_to(self.unsqueeze(0), value); + return; + } else if (index.is_integer()) { + copy_to( + impl::applySelect( + self, 0, index.integer(), 0, self_device, self_sizes), + value); + return; + } else if (index.is_slice()) { + copy_to( + impl::applySlice( + self, + 0, + index.slice().start(), + index.slice().stop(), + index.slice().step(), + /*disable_slice_optimization=*/disable_slice_optimization, + self_device, + self_sizes), + value); + return; + } + } + + std::vector tensorIndices; + Tensor sliced = impl::applySlicing( + self, + indices, + tensorIndices, + disable_slice_optimization, + self_device, + self_sizes); + if (tensorIndices.empty()) { + copy_to(sliced, value); + return; + } + + SymIntArrayRef valueSizes = value.sym_sizes(); + SymIntArrayRef slicedValueSizes = slicePrefix1sSize(valueSizes); + Tensor valuesSliced; + if (!valueSizes.equals(slicedValueSizes)) { + valuesSliced = 
value.view_symint(slicedValueSizes); + } else { + valuesSliced = value; + } + dispatch_index_put_(sliced, std::move(tensorIndices), valuesSliced); + return; +} + +} // namespace at::indexing + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/TensorIterator.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/TensorIterator.h new file mode 100644 index 0000000000000000000000000000000000000000..44fe79d3dbef21b91d60545f3542184fd008da64 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/TensorIterator.h @@ -0,0 +1,1039 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace at { +class Tensor; +class OptionalTensorRef; +using NameVector = SmallVector; +} // namespace at + +// TensorIterator is a helper class for element-wise operations, such as +// arithmetic, comparisons, and trigonometric functions. It handles +// broadcasting and type conversions of operands. +// +// This is inspired by NumPy's Array Iterator API (NpyIter). +// +// The files Loops.h and Loops.cuh provide functions to build kernels that +// use TensorIterator. 
+// +// Example: +// +// auto iter = TensorIteratorConfig() +// .add_output(output) +// .add_input(input) +// .build() +// +// [MyKernel.cpp / MyKernel.cu] +// cpu_kernel(iter, [](float a, float b) { +// return a + b; +// }); +// +// gpu_kernel(iter, []GPU_LAMBDA(float a, float b) -> float { +// return a + b; +// }); +// +// Note [Order of Construction] +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// When setting up the tensor iterator configuration, the output Tensors +// have to be added first via +// TensorIteratorConfig::add_owned_output(at::Tensor). After adding all outputs, +// the inputs can be added via +// TensorIteratorConfig::add_owned_input(at::Tensor). +// Adding another output after inputs have been added will rise an exception. +// +// Note [Common Dtype Computation] +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// Some operations have a natural notion of a "common dtype" or +// "computation dtype" where all inputs are cast to one dtype, the +// operation is performed, and then the results are cast to all outputs. +// +// TensorIterator infers a common dtype if all inputs have the same dtype, +// and it computes one using type promotion rules on its inputs if +// promote_inputs_to_common_dtype_ is true. Attempting to query +// a common dtype otherwise will throw an exception. +// +// Note that the outputs are not considered when computing a common dtype. + +namespace at { + +namespace internal { +// This parameter is heuristically chosen to determine the minimum number of +// work that warrants parallelism. For example, when summing an array, it is +// deemed inefficient to parallelise over arrays shorter than 32768. Further, +// no parallel algorithm (such as parallel_reduce) should split work into +// smaller than GRAIN_SIZE chunks. 
+constexpr int64_t GRAIN_SIZE = 32768; + +// Storage for a non-owning Tensor, without needing to include Tensor.h +class TORCH_API OpaqueOptionalTensorRef { + alignas(alignof(TensorBase)) std::array data_{}; + + public: + OpaqueOptionalTensorRef(); + OpaqueOptionalTensorRef(const OpaqueOptionalTensorRef&) = default; + OpaqueOptionalTensorRef& operator=(const OpaqueOptionalTensorRef&) = default; + OpaqueOptionalTensorRef(OpaqueOptionalTensorRef&&) noexcept = default; + OpaqueOptionalTensorRef& operator=(OpaqueOptionalTensorRef&&) noexcept = + default; + ~OpaqueOptionalTensorRef(); + + OptionalTensorRef* get() { + return reinterpret_cast(data_.data()); + } + const OptionalTensorRef* get() const { + return reinterpret_cast(data_.data()); + } + + OptionalTensorRef& operator*() { + return *get(); + } + const OptionalTensorRef& operator*() const { + return *get(); + } + OptionalTensorRef* operator->() { + return get(); + } + const OptionalTensorRef* operator->() const { + return get(); + } + + const Tensor& getTensor() const; +}; +} // namespace internal + +struct TORCH_API OperandInfo { + using StrideVector = SmallVector; + OperandInfo() = default; + C10_ALWAYS_INLINE explicit OperandInfo(c10::MaybeOwned&& t) { + if (t->defined()) { + device = t->device(); + target_dtype = t->scalar_type(); + current_dtype = target_dtype; + } + tensor(std::move(t)); + validate(); + } + + C10_ALWAYS_INLINE OperandInfo(const OperandInfo&) = default; + C10_ALWAYS_INLINE OperandInfo& operator=(const OperandInfo&) = default; + C10_ALWAYS_INLINE OperandInfo(OperandInfo&&) noexcept = default; + C10_ALWAYS_INLINE OperandInfo& operator=(OperandInfo&&) noexcept = default; + C10_ALWAYS_INLINE ~OperandInfo() = default; + + /// The data pointer. This may be different from tensor->data_ptr() if the + /// iterator is split. + void* data = nullptr; + + /// Stride after broadcasting. The stride is in bytes, not number of elements. 
+ StrideVector stride_bytes; + + /// The desired device and type for the operand. For inputs, this specifies + /// that the input should be converted to this type if necessary. For outputs, + /// this specifies which type to allocate. target_dtype and device are + /// initialized with the dtype and device of the tensor but during type + /// promotion target_dtype value can become different from tensor's dtype + /// also, during type promotion target_dtype and device can be set for an + /// undefined tensor so that tensor can be properly constructed later. + std::optional device = std::nullopt; + ScalarType target_dtype = ScalarType::Undefined; + // Caches dtype of the tensor, because scalar_type is an expensive operation + // If dtype of the tensor is changed (e.g. as a result of type promotion or in + // allocate_outputs), this + // value should be changed too. + ScalarType current_dtype = ScalarType::Undefined; + + bool is_device_defined() const { + return device.has_value(); + } + bool is_type_defined() const { + return target_dtype != ScalarType::Undefined; + } + TensorOptions options() const { + return TensorOptions(target_dtype).device(device); + } + + bool is_output = false; + + // will_resize is only for output tensor. + // 1) Functional call(like torch.add(self, other)): output tensor is + // undefined, and pytorch creates a new tensor by using common shape + // and computed stride in TensorIterator; + // 2) Inplace call(like torch.add_(self, other)): output tensor is same + // with input tensor, and can't to modify tensor's size and stride; + // 3) Op call with output(like torch.add(self, other, out = output)): + // output tensor is defined, but tensor shape maybe different with common + // shape. If tensor shape is not same with common shape, this output + // tensor will be resized by using common shape and computed stride in + // TensorIterator. Otherwise can't modify tensor's size and stride. 
+ bool will_resize = false; + + bool is_read_write = false; + + bool is_const = false; + + void validate() { + TORCH_CHECK( + !tensor_base_->defined() || tensor_base_->layout() == kStrided, + "unsupported tensor layout: ", + tensor_base_->layout()); + } + + /// The tensor operand. Note that the strides, data pointer, and + /// other attributes may differ due to dimension reordering and + /// coalescing. + const Tensor& tensor() const { + return tensor_storage_.getTensor(); + } + const TensorBase& tensor_base() const { + return *tensor_base_; + } + void tensor(c10::MaybeOwned&& tensor); + + // Save the original tensor operand in cases when an output is modified + // (e.g. if dtype is changed) + const Tensor& original_tensor() const { + return original_tensor_storage_.getTensor(); + } + const TensorBase& original_tensor_base() const { + return *original_tensor_base_; + } + + // Set tensor to a new value, and store the old tensor value in + // original_tensor Should only ever be called once for the lifetime of an + // operand + void exchange_tensor(c10::MaybeOwned&& new_tensor); + + // Move original_tensor back into tensor, exchange_tensor must have been + // called before + void restore_original_tensor(); + + private: + c10::MaybeOwned tensor_base_; + c10::MaybeOwned original_tensor_base_ = + c10::MaybeOwned::owned(std::in_place); + + // We store TensorBase visibly in the header to allow inline access. + // However, we sometimes need a genuine `const Tensor &` for the + // TensorIterator API. So, we also store a non-owning `Tensor` + // object in these `_storage_` variables. 
+ internal::OpaqueOptionalTensorRef tensor_storage_; + internal::OpaqueOptionalTensorRef original_tensor_storage_; +}; + +struct SplitUntil32Bit; + +enum class FastSetupType : uint8_t { + NONE, + CONTIGUOUS, + CHANNELS_LAST, + NON_OVERLAPPING_DENSE +}; + +class TensorIteratorConfig; +struct TensorIterator; + +struct TORCH_API TensorIteratorBase : public impl::MetaBase { + using DimMask = std::bitset<64>; + using PtrVector = SmallVector; + using StrideVector = SmallVector; + + void build(TensorIteratorConfig& /*config*/); + + // The inner-loop function operates on the fastest moving dimension. It + // implements element-wise operations in terms of 1-d strided tensors. + // + // Arguments: + // data: data pointers for each operand (length `ntensors`) + // strides: stride for each operand (length `ntensors`) + // size: size of inner loop + // + // The `size` often matches shape[0], but may be smaller due to + // parallelization of the inner loop. + using loop2d_t = c10::function_ref< + void(char** data, const int64_t* strides, int64_t size0, int64_t size1)>; + + using loop_subiter_t = c10::function_ref; + + void foreach_reduced_elt(loop_subiter_t loop, bool parallelize = true); + + int ndim() const { + return static_cast(shape_.size()); + } + IntArrayRef shape() const { + return shape_; + } + int64_t numel() const; + int ntensors() const { + return static_cast(operands_.size()); + } + int noutputs() const { + return num_outputs_; + } + int ninputs() const { + return ntensors() - noutputs(); + } + IntArrayRef view_offsets() const { + return view_offsets_; + } + + /// number of elements in the output operand. this is the same as numel() for + /// operations that are not reductions. 
+ int64_t num_output_elements() const; + + /// number of reduced dimensions in a reduction operation + int num_reduce_dims() const; + + /// 1-dimensional iteration and no buffering or type conversion + bool is_trivial_1d() const; + /// Reducible to 1-dimensional and all operands are contiguous + bool is_contiguous() const; + bool is_dim_reduced(int dim) const; + + /// Accessors for each operand + IntArrayRef strides(int64_t arg) const { + return operands_[arg].stride_bytes; + } + void* data_ptr(int64_t arg) const; + ScalarType dtype(int64_t arg = 0) const { + return operands_[arg].current_dtype; + } + ScalarType common_dtype() const { + TORCH_INTERNAL_ASSERT( + common_dtype_ != ScalarType::Undefined, + "Queried for invalid common dtype!"); + return common_dtype_; + } + ScalarType input_dtype(int64_t arg = 0) const { + return operands_[num_outputs_ + arg].current_dtype; + } + Device device(int64_t arg = 0) const { + // NOLINTNEXTLINE(bugprone-unchecked-optional-access) + return operands_[arg].device.value(); + } + c10::DeviceType device_type(int64_t arg = 0) const { + return device(arg).type(); + } + int64_t element_size(int64_t arg) const { + return static_cast(elementSize(dtype(arg))); + } + bool is_scalar(int64_t arg) const; + bool is_cpu_scalar(int64_t arg) const; + + const TensorBase& tensor_base(int64_t arg) const { + return operands_[arg].tensor_base(); + } + const Tensor& tensor(int64_t arg) const { + return operands_[arg].tensor(); + } + + const TensorBase& output_base(int64_t arg = 0) const { + AT_ASSERT(arg < num_outputs_); + return tensor_base(arg); + } + + const Tensor& output(int64_t arg = 0) const { + AT_ASSERT(arg < num_outputs_); + return tensor(arg); + } + + const TensorBase& input_base(int64_t arg = 0) const { + AT_ASSERT(arg >= 0 && arg < ntensors() - num_outputs_); + return tensor_base(num_outputs_ + arg); + } + const Tensor& input(int64_t arg = 0) const { + AT_ASSERT(arg >= 0 && arg < ntensors() - num_outputs_); + return tensor(num_outputs_ + 
arg); + } + + // Copies from temporary outputs back to the original outputs + // NOTE: only used on CPU + void cast_outputs(); + + /// Removes an operand from this iterator + void remove_operand(int64_t arg); + /// Shrinks an iterated dimension + void narrow(int dim, int64_t start, int64_t size); + /// Narrows every dim after and including `start_dim` to size one. + void select_all_keeping_dim(int start_dim, IntArrayRef starts); + /// Replaces the data pointer for the operand at index `arg`. + /// The new pointer should have the same sizes, strides and dtype as the + /// original + void unsafe_replace_operand(int64_t arg, void* data); + + /// Splits this TensorIterator into two iterators. Together they iterate over + /// the entire operation. Used by `with_32bit_indexing()`. + std::unique_ptr split(int dim); + + /// Returns the dimension with the largest extent: (size[dim]-1) * stride[dim] + int get_dim_to_split() const; + + template + T scalar_value(int64_t arg) { + auto& op = operands_[arg]; + return c10::fetch_and_cast(op.tensor_base().scalar_type(), op.data); + } + + /// Return scalar value from original_tensor_base if it is defined. When + /// common_dtype is Half, casting scalar input to common_dtype might overflow. + /// If the scalar is already given in the type of Half, then return scalar + /// value from tensor_base. 
+ template + T original_scalar_value(int64_t arg) { + auto& original_tensor_base = operands_[arg].original_tensor_base(); + if (original_tensor_base.defined()) { + TORCH_INTERNAL_ASSERT( + original_tensor_base.scalar_type() != common_dtype()); + return c10::fetch_and_cast( + original_tensor_base.scalar_type(), + original_tensor_base.const_data_ptr()); + } else { + return scalar_value(arg); + } + } + + private: + template + auto loop_2d_from_1d(const loop1d_t& loop) { + return + [loop, ntensor = ntensors()]( + char** base, const int64_t* strides, int64_t size0, int64_t size1) { + PtrVector data(base, base + ntensor); + const int64_t* outer_strides = &strides[ntensor]; + for (const auto i : c10::irange(size1)) { + if (i > 0) { + for (const auto arg : c10::irange(ntensor)) { + data[arg] += outer_strides[arg]; + } + } + loop(data.data(), strides, size0); + } + }; + } + + public: + template < + typename loop1d_t, + std::enable_if_t< + std::is_convertible_v< + loop1d_t, + c10::function_ref< + void(char**, const int64_t* strides, int64_t size)>>, + int> = 0> + void for_each(loop1d_t loop, int64_t grain_size = at::internal::GRAIN_SIZE) { + for_each(loop_2d_from_1d(loop), grain_size); + } + + void for_each(loop2d_t loop, int64_t grain_size = at::internal::GRAIN_SIZE); + + void parallel_reduce(loop2d_t loop); + + template < + typename loop1d_t, + std::enable_if_t< + std::is_convertible_v< + loop1d_t, + c10::function_ref< + void(char**, const int64_t* strides, int64_t size)>>, + int> = 0> + void serial_for_each(loop1d_t loop, Range range) { + serial_for_each(loop_2d_from_1d(loop), range); + } + + void serial_for_each(loop2d_t loop, Range range) const; + + /// Create a strides array for a Tensor with shape of this iterator. The + /// parameter `element_size` specifies the size of Tensor's data type in + /// bytes (e.g. `4` for `float`) + StrideVector compatible_stride(int64_t element_size) const; + + /// Inverts the re-ordering done by reorder_dimensions. 
This can only be + /// called *before* coalesce_dimensions() is called. + DimVector invert_perm(IntArrayRef input) const; + + /// Reapply same re-ordering as it is done by reorder_dimensions. This can + /// only be called *before* coalesce_dimensions() is called. + DimVector apply_perm_and_mul(IntArrayRef input, int mul) const; + + /// Helper functions for CPU iteration + StrideVector get_dim_strides(int dim) const; + StrideVector get_strides() const; + StrideVector get_inner_strides() const { + return get_dim_strides(0); + } + PtrVector get_base_ptrs() const; + + // Helper functions for advanced stride manipulations (e.g. torch.flip) + void _unsafe_set_arg_strides(const int64_t arg, IntArrayRef strides) { + operands_[arg].stride_bytes = strides; + } + void _unsafe_set_arg_data(const int64_t arg, void* data) { + operands_[arg].data = data; + } + + // Helper functions for custom device, custom device can get OperandInfo and + // NameVector in their side. + const OperandInfo& operand(int arg = 0) const { + return operands_[arg]; + } + OperandInfo& operand(int arg = 0) { + return operands_[arg]; + } + NameVector& get_dim_names() { + return names_; + } + const NameVector& get_dim_names() const { + return names_; + } + + /// true if the stride computation can use 32-bit arithmetic. Used by GPU + /// kernels + bool can_use_32bit_indexing() const; + + /// An "iterable" object that recursively splits this iterator into + /// sub-iterators that can use 32-bit indexing. + SplitUntil32Bit with_32bit_indexing() const; + + /// If the kernel should accumulate into the output. Only relevant for CUDA + /// reductions. + bool should_accumulate() const { + return accumulate_; + } + + /// Whether this iterator produces the actual output, + /// as opposed to something that will be accumulated further. Only relevant + /// for CUDA reductions. 
+ bool is_final_output() const { + return final_output_; + } + + bool has_contiguous_first_dim() const { + if (ndim() == 0) { + return true; + } + + int num_tensors = ntensors(); + for (const auto i : c10::irange(num_tensors)) { + if (strides(i)[0] != element_size(i)) { + return false; + } + } + return true; + } + + void set_output_raw_strided( + int64_t output_idx, + IntArrayRef sizes, + IntArrayRef strides, + TensorOptions options, + DimnameList names) override; + +#define TORCH_DISALLOW_TEMPORARIES_IMPL(methodname, maybestatic) \ + maybestatic void methodname( \ + TensorBase&& out, const TensorBase& a, const TensorBase& b) = delete; \ + maybestatic void methodname( \ + const TensorBase& out, TensorBase&& a, const TensorBase& b) = delete; \ + maybestatic void methodname( \ + const TensorBase& out, const TensorBase& a, TensorBase&& b) = delete; \ + maybestatic void methodname( \ + TensorBase&& out, TensorBase&& a, const TensorBase& b) = delete; \ + maybestatic void methodname( \ + TensorBase&& out, const TensorBase& a, TensorBase&& b) = delete; \ + maybestatic void methodname( \ + const TensorBase& out, TensorBase&& a, TensorBase&& b) = delete; \ + maybestatic void methodname( \ + TensorBase&& out, TensorBase&& a, TensorBase&& b) = delete; + +#define TORCH_DISALLOW_TEMPORARIES(methodname) \ + TORCH_DISALLOW_TEMPORARIES_IMPL(methodname, ) + + void build_binary_float_op( + const TensorBase& out, + const TensorBase& a, + const TensorBase& b); + void build_borrowing_binary_float_op( + const TensorBase& out, + const TensorBase& a, + const TensorBase& b); + TORCH_DISALLOW_TEMPORARIES(build_borrowing_binary_float_op) + void build_binary_op( + const TensorBase& out, + const TensorBase& a, + const TensorBase& b); + void build_borrowing_binary_op( + const TensorBase& out, + const TensorBase& a, + const TensorBase& b); + TORCH_DISALLOW_TEMPORARIES(build_borrowing_binary_op) + void build_unary_float_op(const TensorBase& out, const TensorBase& a); + void 
build_borrowing_unary_float_op( + const TensorBase& out, + const TensorBase& a); + TORCH_DISALLOW_TEMPORARIES(build_borrowing_unary_float_op) + void build_unary_op(const TensorBase& out, const TensorBase& a); + // Odd special case needed for pow. Has to borrow the output because + // it's a structured kernel, but the argument is potentially a copy. + void build_output_borrowing_argument_owning_unary_op( + const TensorBase& out, + const TensorBase& a); + void build_borrowing_unary_op(const TensorBase& out, const TensorBase& a); + TORCH_DISALLOW_TEMPORARIES(build_borrowing_unary_op) + void build_borrowing_unary_force_boolean_op( + const TensorBase& out, + const TensorBase& a); + TORCH_DISALLOW_TEMPORARIES(build_borrowing_unary_force_boolean_op) + void build_comparison_op( + const TensorBase& out, + const TensorBase& a, + const TensorBase& b); + void build_borrowing_comparison_op( + const TensorBase& out, + const TensorBase& a, + const TensorBase& b); + TORCH_DISALLOW_TEMPORARIES(build_borrowing_comparison_op) + // Another special case: we need to own the second argument for comparison + // ops. 
+ void build_borrowing_except_last_argument_comparison_op( + const TensorBase& out, + const TensorBase& a, + const TensorBase& b); + void build_ternary_op( + const TensorBase& out, + const TensorBase& a, + const TensorBase& b, + const TensorBase& c); + +#undef TORCH_DISALLOW_TEMPORARIES + protected: + // Mutable reference as it moves tensors out of TensorIteratorConfig + void populate_operands(TensorIteratorConfig& /*config*/); + void mark_outputs(); + void mark_resize_outputs(const TensorIteratorConfig& /*config*/); + void compute_mem_overlaps(const TensorIteratorConfig& /*config*/); + void compute_shape(const TensorIteratorConfig& /*config*/); + void compute_strides(const TensorIteratorConfig& /*config*/); + void reorder_dimensions(); + void permute_dimensions(IntArrayRef perm); + void compute_types(const TensorIteratorConfig& /*config*/); + ScalarType compute_common_dtype(); + void allocate_or_resize_outputs(); + bool fast_set_up(const TensorIteratorConfig& /*config*/); + FastSetupType compute_fast_setup_type(const TensorIteratorConfig& /*config*/); + void compute_names(const TensorIteratorConfig& /*config*/); + void propagate_names_to_outputs(); + void coalesce_dimensions(); + + protected: + /// Records the "computation" shape of the output tensor. The computation + /// shape is different from the regular shape in a few ways: + /// + /// - The shape may be permuted (via permute_dimensions) so that we + /// process the dimensions in the most computationally efficient order + /// (rather than the logical order given to us by the users.) + /// - The shape may have adjacent dimensions collapsed (via + /// coalesce_dimensions) so that we minimize the number of + /// dimensions we have to explicitly iterate over. For example, + /// a pointwise operation on a contiguous tensor "computationally" + /// consists of only a single dimension. 
+ /// + /// In other words, the computation shape is the output shape as it + /// actually matters for implementing the kernel, but not necessarily the + /// output shape that the user will see in the end. + /// + /// The lifecycle of mutations to shape_ in TensorIterator: + /// - declare_static_shape() sets an initial shape explicitly + /// provided by user, otherwise + /// - compute_shape() computes the true (non-computational) shape + /// specified by the user. + /// - reorder_dimensions() reorders dimensions to improve coalescing. + /// - coalesce_dimensions() then coalesces adjacent dimensions when + /// possible. + /// + /// The shape may also be further modified if we create sub-TensorIterators, + /// e.g., via narrow or select_all_keeping_dim. + DimVector shape_; + + /// Temporarily records the permutation computed by reorder_dimensions. + /// This permutation maps the computation output dimension (dim) to + /// the original true output dimension (perm_[dim]). It is used by + /// invert_perm to undo the permutation. After coalesce_dimensions is + /// called, the permutation is no longer valid (as, in general, there + /// is no permutation that will make computation dimensions to + /// output dimensions); methods that manipulate perm_ are obligated + /// to test that !has_coalesced_dimensions + DimVector perm_; + + /// Has coalesce_dimensions() (or any moral equivalent, e.g., fast_build()) + /// been called? This is SOLELY used to check validity of perm_. + bool has_coalesced_dimensions_ = false; + + /// Whether iteration must be fixed. This disables dimension permuting and + /// also changes how for_each divides work among threads. + bool enforce_linear_iteration_ = false; + + /// The index offsets into the original tensors for each dimension. + /// This is only non-zero when you narrow() a TensorIterator (e.g., + /// when you make sub-TensorIterators). + DimVector view_offsets_; + + /// The computed names of the output tensor. 
Computed by compute_names() + NameVector names_; + + /// The operands of the TensorIterator: both the inputs and outputs. The + /// outputs MUST come first in the operands_ list. There is always an + /// operand for each output of the TensorIterator, even if TensorIterator + /// will ultimately be responsible for allocating the output; in those + /// cases, tensor is simply undefined (and will be populated later + /// during build()). + /// + /// This list is initially populated prior to build(), but build() mutates + /// OperandInfo to populate more information. + SmallVector operands_; + + /// Number of outputs in operands_ (the length of the outputs prefix + /// in operands_). + int num_outputs_ = 0; + + /// Whether or not all operands have the same shape and are 1d+. Having all + /// the same shape affects whether or not the iterator is eligible for fast + /// setup. + bool all_ops_same_shape_ = false; + /// Whether or not all operands are 0d, this affects type promotion + bool all_ops_are_scalars_ = false; + + /// The "computation" dtype of TensorIterator, specifying what the dtype + /// we will do the internal computation in TensorIterator. Typically, + /// this matches the dtype of the output tensors, but not always! + ScalarType common_dtype_ = ScalarType::Undefined; + + /// This is currently defined as kCPU, or the device of the first non-CPU + /// tensor argument. See TensorIteratorBase::compute_types for details. 
+ Device common_device_ = kCPU; + + /// Set by split(), see should_accumulate() and is_final_output() + bool accumulate_ = false; + bool final_output_ = true; + + // From TensorIteratorConfig + bool is_reduction_ = false; + + /// Set by populate_operands(), says if we're handling meta tensors + bool is_meta_ = false; +}; + +struct TORCH_API TensorIterator final : public TensorIteratorBase { + TensorIterator() : TensorIteratorBase() {} + // Slicing is OK, TensorIterator guaranteed NOT to have any fields + TensorIterator(const TensorIteratorBase& iter) : TensorIteratorBase(iter) {} + +#define TORCH_DISALLOW_TEMPORARIES(methodname) \ + TORCH_DISALLOW_TEMPORARIES_IMPL(methodname, static) + + static TensorIterator binary_float_op( + TensorBase& out, + const TensorBase& a, + const TensorBase& b); + static TensorIterator binary_op( + TensorBase& out, + const TensorBase& a, + const TensorBase& b); + static TensorIterator borrowing_binary_op( + const TensorBase& out, + const TensorBase& a, + const TensorBase& b); + TORCH_DISALLOW_TEMPORARIES(borrowing_binary_op) + static TensorIterator comparison_op( + TensorBase& out, + const TensorBase& a, + const TensorBase& b); + static TensorIterator unary_op(TensorBase& out, const TensorBase& a); + static TensorIterator unary_float_op(TensorBase& out, const TensorBase& a); + static TensorIterator nullary_op(TensorBase& out); + static TensorIterator borrowing_nullary_op(const TensorBase& out); + static TensorIterator borrowing_nullary_op(TensorBase&& out) = delete; + static TensorIterator reduce_op(TensorBase& out, const TensorBase& a); + static TensorIterator reduce_op( + TensorBase& out1, + TensorBase& out2, + const TensorBase& a); +#undef TORCH_DISALLOW_TEMPORARIES +#undef TORCH_DISALLOW_TEMPORARIES_IMPL + + const Tensor& maybe_get_output(int64_t output_idx) override; + void set_output_raw_strided( + int64_t output_idx, + IntArrayRef sizes, + IntArrayRef strides, + TensorOptions options, + DimnameList names) override; +}; + +class 
TORCH_API TensorIteratorConfig final { + public: + friend struct TensorIteratorBase; + friend struct TensorIterator; + + TensorIteratorConfig() = default; + + C10_DISABLE_COPY_AND_ASSIGN(TensorIteratorConfig); + TensorIteratorConfig(TensorIteratorConfig&&) = default; + TensorIteratorConfig& operator=(TensorIteratorConfig&&) = default; + ~TensorIteratorConfig() = default; + + /// Construction + // Stores input/output Tensors without incrementing the reference count. + // Important: the outputs have to be added before the inputs. + TensorIteratorConfig& add_output(const TensorBase& output) { + return add_borrowed_output(output); + } + TensorIteratorConfig& add_input(const TensorBase& input) { + return add_borrowed_input(input); + } + TensorIteratorConfig& add_const_input(const TensorBase& input) { + return add_borrowed_const_input(input); + } + + // Borrowing from temporaries is unlikely to go well. + TensorIteratorConfig& add_output(TensorBase&& output) = delete; + TensorIteratorConfig& add_input(TensorBase&& input) = delete; + TensorIteratorConfig& add_const_input(TensorBase&& input) = delete; + + // Stores input/output Tensors while incrementing the reference count. + // Note that add_{in,out}put are nearly always what you + // want, and the exception (adding an unnamed temporary) won't + // compile. + TensorIteratorConfig& add_owned_output(const TensorBase& output); + TensorIteratorConfig& add_owned_input(const TensorBase& input); + TensorIteratorConfig& add_owned_const_input(const TensorBase& input); + + // Advanced API: stores input/output Tensors without incrementing + // the reference count. The caller must ensure that these Tensors + // live at least as long as this TensorIteratorConfig and any + // TensorIteratorBase built from this TensorIteratorConfig. + // Important: the outputs have to be added before the inputs. 
+ TensorIteratorConfig& add_borrowed_output(const TensorBase& output); + TensorIteratorConfig& add_borrowed_input(const TensorBase& input); + TensorIteratorConfig& add_borrowed_const_input(const TensorBase& input); + + // Borrowing from temporaries is unlikely to go well. + TensorIteratorConfig& add_borrowed_output(TensorBase&& output) = delete; + TensorIteratorConfig& add_borrowed_input(TensorBase&& input) = delete; + TensorIteratorConfig& add_borrowed_const_input(TensorBase&& input) = delete; + + // Sets the check_mem_overlap_ flag, which is true by default. + // If true, inputs are checked for partial overlap with the outputs and + // outputs are checked for internal overlap (e.g. broadcasted views). An error + // is raised if unacceptable overlap is detected. + // If you're migrating an existing operator to using TensorIterator, please + // consider if the previous implementation checked memory overlap. If it did + // not, and if the operator is idempotent (for example, Tensor.fill_(0)), then + // checking memory overlap is BC-breaking. Please don't check memory overlap + // in that case. + TensorIteratorConfig& set_check_mem_overlap(bool check_mem_overlap) { + check_mem_overlap_ = check_mem_overlap; + return *this; + } + + // Sets the check_all_same_dtype_ flag, which is true by default + // If true, checks that all inputs and defined outputs have the same dtype + // Setting either of promote_inputs_to_common_dtype_ + // or cast_common_dtype_to_outputs_ to true will set + // check_all_same_dtype_ to false. + TensorIteratorConfig& check_all_same_dtype(const bool _check_all_same_dtype) { + check_all_same_dtype_ = _check_all_same_dtype; + return *this; + } + + // Sets the check_all_same_device_ flag, which is true by default + // If true, all operands must be on the same device, with the possible + // exception of CPU scalars, which can be passed to some CUDA kernels + // as kernel arguments. 
+ TensorIteratorConfig& check_all_same_device( + const bool _check_all_same_device) { + check_all_same_device_ = _check_all_same_device; + return *this; + } + + // Sets the enforce_safe_casting_to_output_ flag, which is false by default + // If true, the iterator's "common dtype" must be computable + // (see the [Common Dtype Computation] note) and + // canCast(common dtype, output dtype) must be true for all outputs. + TensorIteratorConfig& enforce_safe_casting_to_output( + const bool _enforce_safe_casting_to_output) { + enforce_safe_casting_to_output_ = _enforce_safe_casting_to_output; + return *this; + } + + // Sets the enforce_linear_iteration_ flag, which is false by default. + // If true, iteration goes in the same order as a C-contiguous tensor + // is laid out in memory. i.e. last dimension iterates fastest. + // + // This iteration order can be less efficient and may even prevent + // vectorization. So only use if the correctness of your kernel depends on it. + TensorIteratorConfig& enforce_linear_iteration( + const bool _enforce_linear_iteration = true) { + enforce_linear_iteration_ = _enforce_linear_iteration; + return *this; + } + + // Sets the promote_inputs_to_common_dtype_ flag, which is false by default + // If true, the iterator's "common dtype" is always computed (see the + // [Common Dtype Computation] note) and, on the CPU, temporary copies of + // the inputs in the common dtype are passed as the actual inputs to + // the operation. + // Setting this flag to true sets check_all_same_dtype_ to false. + TensorIteratorConfig& promote_inputs_to_common_dtype( + const bool _promote_inputs_to_common_dtype) { + promote_inputs_to_common_dtype_ = _promote_inputs_to_common_dtype; + if (_promote_inputs_to_common_dtype) { + check_all_same_dtype_ = false; + } + return *this; + } + + // Sets the promote_integer_inputs_to_float_ flag, which is false by default + // NOTE: If set to true, the promote_inputs_to_common_dtype_ must also be + // true. 
If true, if the iterator's "common dtype" is an integral type + // (including bool) + // then it is changed to the default float scalar type. + TensorIteratorConfig& promote_integer_inputs_to_float( + const bool _promote_integer_inputs_to_float) { + promote_integer_inputs_to_float_ = _promote_integer_inputs_to_float; + TORCH_INTERNAL_ASSERT( + !promote_integer_inputs_to_float_ || promote_inputs_to_common_dtype_); + return *this; + } + + TensorIteratorConfig& is_reduction(const bool _is_reduction) { + is_reduction_ = _is_reduction; + return *this; + } + + TensorIteratorConfig& allow_cpu_scalars(const bool _allow_cpu_scalars) { + allow_cpu_scalars_ = _allow_cpu_scalars; + return *this; + } + + // Sets the cast_common_dtype_to_outputs_ flag, which is false by default + // If true, the iterator's "common dtype" must be computatable + // (see the [Common Dtype Computation] note) and, on the CPU, temporary + // copies of the outputs are passed as the actual output to the operation. + // These temporaries are then copied to the original outputs after + // the operation is performed (see cast_outputs()). + // Setting this flag to true sets check_all_same_dtype_ to false. + TensorIteratorConfig& cast_common_dtype_to_outputs( + const bool _cast_common_dtype_to_outputs) { + cast_common_dtype_to_outputs_ = _cast_common_dtype_to_outputs; + if (_cast_common_dtype_to_outputs) { + check_all_same_dtype_ = false; + } + return *this; + } + + TensorIteratorConfig& resize_outputs(bool resize_outputs) { + resize_outputs_ = resize_outputs; + return *this; + } + + // Bypass output dtype/device computation and fix the dtype/device as + // specified here. 
+ TensorIteratorConfig& declare_static_dtype_and_device( + ScalarType dtype, + Device device); + TensorIteratorConfig& declare_static_dtype(ScalarType dtype); + TensorIteratorConfig& declare_static_device(Device device); + TensorIteratorConfig& declare_static_shape(IntArrayRef shape); + TensorIteratorConfig& declare_static_shape( + IntArrayRef shape, + IntArrayRef squash_dims); + + // It would be better if this was && qualified, but this would be at the cost + // of a lot of boilerplate above + TensorIterator build() { + TensorIterator iter; + iter.build(*this); + return iter; + } + + private: + bool is_tensor_const(size_t idx); + + SmallVector, 4> tensors_; + int num_outputs_ = 0; + int num_inputs_ = 0; + + std::optional static_shape_ = std::nullopt; + std::optional static_dtype_ = std::nullopt; + std::optional static_device_ = std::nullopt; + bool check_mem_overlap_ = true; + bool allow_cpu_scalars_ = false; + bool is_reduction_ = false; + bool resize_outputs_ = true; + bool check_all_same_dtype_ = true; + bool check_all_same_device_ = true; + bool enforce_safe_casting_to_output_ = false; + bool enforce_linear_iteration_ = false; + bool promote_inputs_to_common_dtype_ = false; + bool promote_integer_inputs_to_float_ = false; + bool cast_common_dtype_to_outputs_ = false; + + SmallVector const_tensor_indices_; +}; + +/// A container-like struct that acts as if it contains splits of a +/// TensorIterator that can use 32-bit indexing. Taken together the splits cover +/// the original TensorIterator. +struct TORCH_API SplitUntil32Bit { + // NOLINTNEXTLINE(cppcoreguidelines-special-member-functions) + struct TORCH_API iterator { + iterator() = default; + iterator(const TensorIteratorBase& iter); + iterator(iterator&&) = default; + iterator& operator=(iterator&&) = default; + ~iterator() = default; + + // Guaranteed to be a TensorIterator proper! 
+ TensorIterator& operator*() const; + iterator& operator++(); + bool operator==(const iterator& other) const { + // two iterators are equal if they are the same object or they're both + // empty + return this == &other || (vec.empty() && other.vec.empty()); + } + // needed for C++11 range-based for loop + bool operator!=(const iterator& other) const { + return !(*this == other); + } + + /// stack of TensorIterators to be split + std::vector> vec; + }; + + SplitUntil32Bit(const TensorIteratorBase& iter) : iter(iter) {} + + iterator begin() const; + iterator end() const; + + private: + // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) + const TensorIteratorBase& iter; +}; + +} // namespace at + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/TensorIteratorInternal.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/TensorIteratorInternal.h new file mode 100644 index 0000000000000000000000000000000000000000..11134d2512053e9efb494dbce5e1bccf43ffaedb --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/TensorIteratorInternal.h @@ -0,0 +1,77 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +#include +#include +#include + +namespace at { + +struct DimCounter { + DimCounter(IntArrayRef shape, Range range); + + void increment(const std::array& step); + bool is_done() const; + std::array max_2d_step() const; + + IntArrayRef shape; + Range range; + c10::SmallBuffer values; + int64_t offset; +}; + +namespace internal { + +inline void get_data_ptrs( + char** ptrs, + ArrayRef base, + IntArrayRef strides, + IntArrayRef counter) { + const auto ntensors = base.size(); + const auto ndim = counter.size(); + 
std::copy(base.begin(), base.end(), ptrs); + for (const auto dim : c10::irange(ndim)) { + int64_t value = counter[dim]; + for (const auto arg : c10::irange(ntensors)) { + ptrs[arg] += value * strides[dim * ntensors + arg]; + } + } +} + +inline void serial_for_each( + IntArrayRef shape, + IntArrayRef strides, + char** base_ptrs, + size_t ntensors, + TensorIteratorBase::loop2d_t loop, + Range range) { + const auto ndim = shape.size(); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + strides.size() == ntensors * std::max(size_t{2}, ndim)); + + if (ndim <= 1) { + if (range.begin == 0) { + loop(base_ptrs, strides.data(), range.size(), 1); + } else { + c10::SmallBuffer ptrs(ntensors); + get_data_ptrs(ptrs.data(), {base_ptrs, ntensors}, strides, {range.begin}); + loop(ptrs.data(), strides.data(), range.size(), 1); + } + } else { + c10::SmallBuffer ptrs(ntensors); + auto counter = DimCounter(shape, range); + while (!counter.is_done()) { + get_data_ptrs( + ptrs.data(), {base_ptrs, ntensors}, strides, counter.values); + auto step = counter.max_2d_step(); + loop(ptrs.data(), strides.data(), step[0], step[1]); + counter.increment(step); + } + } +} + +} // namespace internal +} // namespace at + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." 
+#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/TensorMeta.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/TensorMeta.h new file mode 100644 index 0000000000000000000000000000000000000000..0d7c4b830ab3404ddf2433ecfe1d9ffcbab22108 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/TensorMeta.h @@ -0,0 +1,142 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include + +namespace at { + +class Tensor; + +namespace impl { + +// Use this to define the prototype for a meta function. There are two +// versions; one that takes one argument (just the operator name), or FUNC2 +// variant that takes two arguments (operator name and overload name). +// +// Example usage: +// +// TORCH_META_FUNC2(add, Tensor) ( +// const Tensor& self, const Tensor& other +// ) { +// ... compute sizes and options ... +// set_output(sizes, options); +// } +// +#define TORCH_META_FUNC(name) void structured_##name::meta +#define TORCH_META_FUNC2(name, overload) \ + void structured_##name##_##overload::meta + +// These are versions of TORCH_META_FUNC(2) that include a precompute_out struct +// as a return value. They should be used when the kernel in question has +// precomputed values declared in native_functions.yaml and the corresponding +// implementation should return an instance of the aforementioned struct. +#define TORCH_PRECOMPUTE_META_FUNC(name) \ + structured_##name::meta_return_ty structured_##name::meta +#define TORCH_PRECOMPUTE_META_FUNC2(name, overload) \ + structured_##name##_##overload::meta_return_ty \ + structured_##name##_##overload::meta + +// Use this to create a precompute struct in a meta function. 
+#define TORCH_PRECOMPUTE_STRUCT(name) structured_##name::precompute_out<> +#define TORCH_PRECOMPUTE_STRUCT2(name, overload) \ + structured_##name##_##overload::precompute_out<> + +// Use this to define the prototype for an implementation. This takes only +// one argument, which is the name of the dispatch key entry you're +// implementing. +// +// Example usage: +// +// TORCH_IMPL_FUNC(add_cpu) ( +// Tensor& result, const Tensor& self, const Tensor& other +// ) { +// ... do the actual implementation ... +// } +// +#define TORCH_IMPL_FUNC(name) void structured_##name::impl + +// Base class for all structured kernel classes. The set_output virtual +// method is varied depending whether or not the operator is +// functional/out/inplace, and could also be specialized for CPU/CUDA/etc +// (although presently it isn't). +// +// A notable subclass of this interface is TensorIteratorBase. +struct TORCH_API MetaBase { + MetaBase() = default; + MetaBase(const MetaBase&) = default; + MetaBase& operator=(const MetaBase&) = default; + MetaBase(MetaBase&&) noexcept = default; + MetaBase& operator=(MetaBase&&) noexcept = default; + virtual const Tensor& maybe_get_output(int64_t output_idx) = 0; + + // Note: [set_output_*] + // See: https://github.com/pytorch/pytorch/issues/69813 + // Whenever defining the output properties in the META function of a + // structured kernel (what was usually done with `set_output`), use one of + // these 3 variants, instead. In order to decide which variant to use, check + // the following decision tree: + // + // - Can the kernel you are going to implement support output tensors + // with arbitrary strides? + // | + // -- YES: `set_output_raw_strided` + // | + // -- NO: Should the output tensor strides be contiguous? + // | + // -- YES: `set_output_contiguous` + // | + // -- NO: `set_output_strided` + // + // Use this function whenever the kernel requires specific strides for the + // output. 
If `strides` does not match the given output strides, proxy outputs + // will be created and passed to the IMPL function. + virtual void set_output_strided( + int64_t output_idx [[maybe_unused]], + IntArrayRef sizes [[maybe_unused]], + IntArrayRef strides [[maybe_unused]], + TensorOptions options [[maybe_unused]], + DimnameList names [[maybe_unused]] = {}) { + TORCH_INTERNAL_ASSERT(false, "set_output_strided not implemented."); + } + + // Use this function whenever the kernel knows how to handle arbitrary strided + // outputs. This function has the same behavior as the old `set_output`: it + // will only re-stride if the given output was resized. + virtual void set_output_raw_strided( + int64_t output_idx [[maybe_unused]], + IntArrayRef sizes [[maybe_unused]], + IntArrayRef strides_hint [[maybe_unused]], + TensorOptions options [[maybe_unused]], + DimnameList names [[maybe_unused]] = {}) { + TORCH_INTERNAL_ASSERT(false, "set_output_strided not implemented."); + } + + // Use this function if the kernel requires contiguous strides. + // Alias for `set_output_strided`, but with contiguous strides. + void set_output_contiguous( + int64_t output_idx, + IntArrayRef sizes, + TensorOptions options, + DimnameList names = {}) { + auto strides = c10::contiguous_strides(sizes); + set_output_strided(output_idx, sizes, strides, options, names); + } + + // Returns a reference to an undefined tensor if there is no presupplied + // output + const Tensor& maybe_get_output() { + return maybe_get_output(0); + } + virtual ~MetaBase() = default; +}; + +} // namespace impl + +} // namespace at + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." 
+#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/TensorNames.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/TensorNames.h new file mode 100644 index 0000000000000000000000000000000000000000..707d016b8672057d028d313b8f51e722a413da53 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/TensorNames.h @@ -0,0 +1,80 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include + +namespace at::namedinference { + +// TensorName and TensorNames are wrappers around Dimname and DimnameList +// that contain helper functions to make writing name inference rules easier. +// +// A TensorName represents a Dimname associated with some DimnameList (from a +// Tensor). This encapsulates all the information that is needed to check if +// names *match* and to *unify* names. +// +// Definition: Two names in two tensors *match* if they are equal, or if at +// least one of them is a wildcard that can be *refined* to the other name. +// +// Definition: unify(name, other) fails if the names do not match. Otherwise, +// it returns the most refined of name and other. +// +// Here is an example of checking if two names match. +// tensor: Tensor[A, None] +// other: Tensor[A] +// +// Let's say we wish to check if tensor.names[-1] matches other.names[-1]. +// None (in tensor) cannot match A (in other) because if the None were refined +// to A, `tensor` would have duplicate names [A, A]. Therefore we need to check +// tensor.names [A, None] for the existence of A. +struct TORCH_API TensorName { + explicit TensorName(ArrayRef origin, int origin_idx) + : origin_(origin), + name_(origin[maybe_wrap_dim( + origin_idx, + static_cast(origin.size()))]), + origin_idx_(origin_idx) {} + + // op_name is only used for error reporting. 
+ const TensorName& unify(const TensorName& other, const char* op_name) const; + Dimname toDimname() const; + + private: + ArrayRef origin_; + Dimname name_; + int origin_idx_; // A named tensor can have at most 64 dims. + + TORCH_API friend std::ostream& operator<<( + std::ostream& out, + const TensorName& tensorname); +}; + +using TensorNameVec = SmallVector; + +struct TORCH_API TensorNames { + explicit TensorNames(ArrayRef names); + + // Create TensorNames from names[start:end]. Each individual TensorName stores + // `names`, NOT names[start:end], because the original tensor's names are + // `names`. + explicit TensorNames(ArrayRef names, int64_t start, int64_t end); + + // op_name is only used for error reporting. + TensorNames& unifyFromRightInplace( + const TensorNames& other, + const char* op_name = "unify"); + void checkUnique(const char* op_name) const; + + void append(TensorName name); + std::vector toDimnameVec() const; + + private: + explicit TensorNames(TensorNameVec&& names) : names_(std::move(names)) {} + + TensorNameVec names_; +}; + +} // namespace at::namedinference + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." 
+#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/TensorOperators.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/TensorOperators.h new file mode 100644 index 0000000000000000000000000000000000000000..57e84eb77e0529d16d6c70b91dc28e30b1406522 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/TensorOperators.h @@ -0,0 +1,56 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#endif + +namespace at { + +#define AT_FORALL_BINARY_OPS(_) \ + _(+, x.add(y), y.add(x)) \ + _(*, x.mul(y), y.mul(x)) \ + _(-, \ + x.sub(y), \ + ::at::empty_like(y, at::MemoryFormat::Preserve).fill_(x).sub_(y)) \ + _(/, \ + x.div(y), \ + ::at::empty_like(y, at::MemoryFormat::Preserve).fill_(x).div_(y)) \ + _(%, \ + x.remainder(y), \ + ::at::empty_like(y, at::MemoryFormat::Preserve).fill_(x).remainder_(y)) \ + _(&, x.bitwise_and(y), y.bitwise_and(x)) \ + _(|, x.bitwise_or(y), y.bitwise_or(x)) \ + _(^, x.bitwise_xor(y), y.bitwise_xor(x)) \ + _(<, x.lt(y), y.gt(x)) \ + _(<=, x.le(y), y.ge(x)) \ + _(>, x.gt(y), y.lt(x)) \ + _(>=, x.ge(y), y.le(x)) \ + _(==, x.eq(y), y.eq(x)) \ + _(!=, x.ne(y), y.ne(x)) + +#define DEFINE_OPERATOR(op, body, reverse_scalar_body) \ + inline Tensor operator op(const Tensor& x, const Tensor& y) { \ + return body; \ + } \ + inline Tensor operator op(const Tensor& x, const Scalar& y) { \ + return body; \ + } \ + inline Tensor operator op(const Scalar& x, const Tensor& y) { \ + return reverse_scalar_body; \ + } + +AT_FORALL_BINARY_OPS(DEFINE_OPERATOR) +#undef DEFINE_OPERATOR +#undef AT_FORALL_BINARY_OPS + +} // namespace at + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." 
+#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/TensorOptions.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/TensorOptions.h new file mode 100644 index 0000000000000000000000000000000000000000..6c67b9f53d6da55d97fd5571d79c2340a4098e19 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/TensorOptions.h @@ -0,0 +1,7 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +#include + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/TensorSubclassLikeUtils.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/TensorSubclassLikeUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..694681cc7ce4756a7ad4a83ca5b0dc3f0afdc414 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/TensorSubclassLikeUtils.h @@ -0,0 +1,93 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#endif + +namespace at { + +// Note [Tensor-subclass-like Tensors] +// Tensor-subclass-like is defined as: +// - a Tensor subclass (via __torch_dispatch__ in Python or extending +// TensorImpl in C++) +// - anything else that shares the same perils as Tensor subclasses. +// For example, many Tensor subclasses do not have storage and meta Tensors +// do not have storage either, so meta Tensors belong here. +// +// We should ensure that PyTorch internals supports Tensor-subclass-like +// objects. 
In particular, Tensor-subclass-like objects struggle with two +// classes of operations that are problematic for Tensor subclasses: +// 1. Because some Tensor subclasses do not have storage, .item() or +// .data_ptr() calls are not good. +// 2. Certain in-place operations can eliminate the typing of the Tensor +// subclass. For example: +// >>> torch.zeros(input.sizes(), grad.options()).diag().copy_(input) +// If input is a Tensor subclass, then the above ends up either erroring out +// or returning a regular non-Tensor-subclass Tensor! + +constexpr auto kFunctorchWrappedTensors = DispatchKeySet( + {DispatchKey::FuncTorchGradWrapper, + DispatchKey::FuncTorchBatched, + DispatchKey::Functionalize}); + +constexpr auto kTensorSubclassLike = + kFunctorchWrappedTensors | + DispatchKeySet( + {// WARNING: DO NOT put combined backend component + functionality keys + // here, you will incorrectly always match on the functionality key + // no matter the backend component + DispatchKey::Batched, + DispatchKey::Sparse, + DispatchKey::SparseCsr, + DispatchKey::Python}) | + DispatchKeySet(BackendComponent::MetaBit); + +inline bool isTensorSubclassLike(const Tensor& tensor) { + if (c10::impl::dispatch_mode_enabled()) + return true; + auto key_set = tensor.unsafeGetTensorImpl()->key_set(); + return !(key_set & kTensorSubclassLike).empty(); +} + +inline bool areAnyTensorSubclassLike(TensorList tensors) { + if (c10::impl::dispatch_mode_enabled()) + return true; + return std::any_of(tensors.begin(), tensors.end(), isTensorSubclassLike); +} + +inline bool areAnyOptionalTensorSubclassLike( + const c10::List>& tensors) { + if (c10::impl::dispatch_mode_enabled()) + return true; + return std::any_of( + tensors.begin(), + tensors.end(), + [](const std::optional& opt_tensor) { + return ( + opt_tensor.has_value() && isTensorSubclassLike(opt_tensor.value())); + }); +} + +// Helper function to deal testing truthfulness of a scalar tensor +// in a Composite Compliant manner. 
+// NOTE: This function expects a scalar tensor of boolean dtype. +// Eg. +// Non-Composite Compliant Pattern : (t == 0).all().item() +// Composite Compliant Pattern : is_salar_tensor_true((t == 0).all()) +inline bool is_scalar_tensor_true(const Tensor& t) { + TORCH_INTERNAL_ASSERT(t.dim() == 0) + TORCH_INTERNAL_ASSERT(t.scalar_type() == kBool) + return at::equal(t, t.new_ones({}, t.options())); +} + +} // namespace at + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/TensorUtils.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/TensorUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..27c4bc38f2add9e353a9026a0c2732ee548c3d7f --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/TensorUtils.h @@ -0,0 +1,195 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include +#include + +#include + +// These functions are NOT in Utils.h, because this file has a dep on Tensor.h + +#define TORCH_CHECK_TENSOR_ALL(cond, ...) \ + TORCH_CHECK((cond)._is_all_true().item(), __VA_ARGS__); + +namespace at { + +// The following are utility functions for checking that arguments +// make sense. These are particularly useful for native functions, +// which do NO argument checking by default. + +struct TORCH_API TensorArg { + // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) + const Tensor& tensor; + const char* name; + int pos; // 1-indexed + TensorArg(const Tensor& tensor, const char* name, int pos) + : tensor(tensor), name(name), pos(pos) {} + // Try to mitigate any possibility of dangling reference to temporaries. 
+ // NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved) + TensorArg(Tensor&& tensor, const char* name, int pos) = delete; + const Tensor* operator->() const { + return &tensor; + } + const Tensor& operator*() const { + return tensor; + } +}; + +struct TORCH_API TensorGeometryArg { + TensorGeometry tensor; + const char* name; + int pos; // 1-indexed + /* implicit */ TensorGeometryArg(TensorArg arg) + : tensor(TensorGeometry{arg.tensor}), name(arg.name), pos(arg.pos) {} + TensorGeometryArg(TensorGeometry tensor, const char* name, int pos) + : tensor(std::move(tensor)), name(name), pos(pos) {} + const TensorGeometry* operator->() const { + return &tensor; + } + const TensorGeometry& operator*() const { + return tensor; + } +}; + +// A string describing which function did checks on its input +// arguments. +// TODO: Consider generalizing this into a call stack. +using CheckedFrom = const char*; + +// The undefined convention: singular operators assume their arguments +// are defined, but functions which take multiple tensors will +// implicitly filter out undefined tensors (to make it easier to perform +// tests which should apply if the tensor is defined, and should not +// otherwise.) +// +// NB: This means that the n-ary operators take lists of TensorArg, +// not TensorGeometryArg, because the Tensor to TensorGeometry +// conversion will blow up if you have undefined tensors. 
+ +TORCH_API std::ostream& operator<<( + std::ostream& out, + const TensorGeometryArg& t); +TORCH_API void checkDim( + CheckedFrom c, + const Tensor& tensor, + const char* name, + int pos, // 1-indexed + int64_t dim); +TORCH_API void checkDim(CheckedFrom c, const TensorGeometryArg& t, int64_t dim); +// NB: this is an inclusive-exclusive range +TORCH_API void checkDimRange( + CheckedFrom c, + const TensorGeometryArg& t, + int64_t dim_start, + int64_t dim_end); +TORCH_API void checkSameDim( + CheckedFrom c, + const TensorGeometryArg& t1, + const TensorGeometryArg& t2); +TORCH_API void checkContiguous(CheckedFrom c, const TensorGeometryArg& t); +TORCH_API void checkAllContiguous(CheckedFrom c, at::ArrayRef ts); +TORCH_API void checkSize( + CheckedFrom c, + const TensorGeometryArg& t, + IntArrayRef sizes); +TORCH_API void checkSize_symint( + CheckedFrom c, + const TensorGeometryArg& t, + c10::SymIntArrayRef sizes); +TORCH_API void checkSize( + CheckedFrom c, + const TensorGeometryArg& t, + int64_t dim, + int64_t size); +TORCH_API void checkSize_symint( + CheckedFrom c, + const TensorGeometryArg& t, + int64_t dim, + const c10::SymInt& size); +TORCH_API void checkNumel( + CheckedFrom c, + const TensorGeometryArg& t, + int64_t numel); +TORCH_API void checkSameNumel( + CheckedFrom c, + const TensorArg& t1, + const TensorArg& t2); +TORCH_API void checkAllSameNumel(CheckedFrom c, ArrayRef tensors); +TORCH_API void checkScalarType(CheckedFrom c, const TensorArg& t, ScalarType s); +TORCH_API void checkScalarTypes( + CheckedFrom c, + const TensorArg& t, + at::ArrayRef l); +TORCH_API void checkSameGPU( + CheckedFrom c, + const TensorArg& t1, + const TensorArg& t2); +TORCH_API void checkAllSameGPU(CheckedFrom c, ArrayRef tensors); +TORCH_API void checkSameType( + CheckedFrom c, + const TensorArg& t1, + const TensorArg& t2); +TORCH_API void checkAllSameType(CheckedFrom c, ArrayRef tensors); +TORCH_API void checkSameSize( + CheckedFrom c, + const TensorArg& t1, + const TensorArg& 
t2); +TORCH_API void checkAllSameSize(CheckedFrom c, ArrayRef tensors); +TORCH_API void checkDefined(CheckedFrom c, const TensorArg& t); +TORCH_API void checkAllDefined(CheckedFrom c, at::ArrayRef t); + +// FixMe: does TensorArg slow things down? +TORCH_API void checkBackend( + CheckedFrom c, + at::ArrayRef t, + at::Backend backend); + +TORCH_API void checkDeviceType( + CheckedFrom c, + at::ArrayRef tensors, + at::DeviceType device_type); + +TORCH_API void checkLayout(CheckedFrom c, const Tensor& t, Layout layout); + +TORCH_API void checkLayout( + CheckedFrom c, + at::ArrayRef tensors, + at::Layout layout); + +// Methods for getting data_ptr if tensor is defined +TORCH_API void* maybe_data_ptr(const Tensor& tensor); +TORCH_API void* maybe_data_ptr(const TensorArg& tensor); + +TORCH_API void check_dim_size( + const Tensor& tensor, + int64_t dim, + int64_t dim_size, + int64_t size); + +namespace detail { +TORCH_API std::vector defaultStrides(IntArrayRef sizes); + +TORCH_API std::optional> computeStride( + IntArrayRef oldshape, + IntArrayRef oldstride, + IntArrayRef newshape); + +TORCH_API std::optional computeStride( + c10::SymIntArrayRef oldshape, + c10::SymIntArrayRef oldstride, + c10::SymIntArrayRef newshape); + +TORCH_API std::optional computeStride( + IntArrayRef oldshape, + IntArrayRef oldstride, + const DimVector& newshape); + +} // namespace detail +} // namespace at + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." 
+#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/ThreadLocalPythonObjects.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/ThreadLocalPythonObjects.h new file mode 100644 index 0000000000000000000000000000000000000000..c7ec22594ed1f4612ca4131c475ae47a1166c310 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/ThreadLocalPythonObjects.h @@ -0,0 +1,26 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include + +namespace at::impl { + +struct TORCH_API ThreadLocalPythonObjects { + static void set(const std::string& key, std::shared_ptr value); + static const std::shared_ptr& get(const std::string& key); + static bool contains(const std::string& key); + + static const ThreadLocalPythonObjects& get_state(); + static void set_state(ThreadLocalPythonObjects state); + + private: + std::unordered_map> obj_dict_; +}; + +} // namespace at::impl + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." 
+#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/ThreadLocalState.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/ThreadLocalState.h new file mode 100644 index 0000000000000000000000000000000000000000..abd3b361ac1f62d1bff3b2e731917b899c43c32d --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/ThreadLocalState.h @@ -0,0 +1,131 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace at { + +// Thread local state contains values that are preserved across +// thread boundaries (e.g. at::launch/JIT fork, autograd). +// Note at::parallel_for doesn't preserve TLS across thread boundaries. +class TORCH_API ThreadLocalState { + public: + // Saves the thread local variables' values and + // returns them as a ThreadLocalState + ThreadLocalState(); + + // set_grad_mode - force the value of the grad mode TLS in + // the current state object. This is used for example in the + // autograd engine. + void set_grad_mode(bool enabled); + + // set_multithreading_enabled - force the value of the multithreadinmaximum + // threads TLS in + // the current state object. This is used for example in the + // autograd engine. 
+ void set_multithreading_enabled(bool enabled); + + // Sets thread local variables in the current thread, + // according to the thread boundary specified + static void setThreadLocalState(const ThreadLocalState& state); + + private: + c10::impl::LocalDispatchKeySet dispatch_key_; + + // ThreadLocalDebugInfo does not change after being created + // with DebugInfoGuard + std::shared_ptr debug_info_; + + // RecordFunction TLS + RecordFunctionTLS rf_tls_; + + // TLS for out-of-tree functorch + // See NOTE [functorch TLS in pytorch/pytorch] for why this needs to be a + // pointer (spoiler alert: it's due to the indirection) + // This needs to be a shared_ptr instead of a unique_ptr because + // ThreadLocalState is copy-able and does indeed get copied. Maybe we can + // consider adding an explicit copy constructor for ThreadLocalState in the + // future but I didn't want to add one just for this. + std::shared_ptr functorch_tls_; + + // TLS for AutogradModes + AutogradState autograd_tls_; + + // TLS for enable_torch_dispatch_mode + c10::impl::TorchDispatchModeTLS torch_dispatch_mode_state_; + + // TLS for enable_python_dispatcher + c10::impl::PyInterpreter* python_dispatcher_state_; + + // TLS for __torch_function__ (mode and disable_torch_function) + at::impl::PythonTorchFunctionTLS python_torch_function_state_; + + // TLS for saved tensors default hooks + at::impl::SavedTensorDefaultHooksTLS saved_tensors_default_hooks_state_; + + bool functionalization_reapply_views_state_; + + bool dtensor_allow_implicit_replication_; + + // TLS for arbitrary python objects that is registered via hooks + at::impl::ThreadLocalPythonObjects saved_objects_; + +#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) && \ + !defined(BUILD_LITE_INTERPRETER) + // TLS for autocast dtypes + std::array + autocast_dtypes_{}; +#endif + + friend class ThreadLocalStateGuard; +}; + +// Guard to set and reset the thread local state +class TORCH_API ThreadLocalStateGuard { + public: + explicit 
ThreadLocalStateGuard(const ThreadLocalState& state) + : prev_state_(ThreadLocalState()) { + // set the given state across the thread boundary + ThreadLocalState::setThreadLocalState(state); + } + ThreadLocalStateGuard(ThreadLocalStateGuard&& other) = delete; + ThreadLocalStateGuard(const ThreadLocalStateGuard&) = delete; + ThreadLocalStateGuard& operator=(const ThreadLocalStateGuard&) = delete; + ThreadLocalStateGuard& operator=(ThreadLocalStateGuard&&) = delete; + + ~ThreadLocalStateGuard() { + // restore previously set variables + ThreadLocalState::setThreadLocalState(prev_state_); + } + + private: + // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) + const ThreadLocalState prev_state_; +}; + +template +auto wrapPropagateTLSState(T callback) { + return [tls_state = ThreadLocalState(), + callback = std::move(callback)](auto&&... args) { + ThreadLocalStateGuard g(tls_state); + // Propagate value returned by callback(). + return callback(std::forward(args)...); + }; +} + +} // namespace at + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." 
+#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/TracerMode.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/TracerMode.h new file mode 100644 index 0000000000000000000000000000000000000000..bcf580c847880623db566a4b11d7396d73cf8d42 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/TracerMode.h @@ -0,0 +1,137 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include + +// NOTE [Tracing Mode Switches] +// +// Historically, tracing function was controlled by two switches: +// +// - `AutoDispatchBelowADInplaceOrView` guard +// +// Tracing function used to be script-generated inside `VariableType_*.cpp` +// kernels, sharing the same `Autograd` dispatch key with autograd function. +// Therefore, before tracing function was moved out of VariableType, +// `AutoDispatchBelowADInplaceOrView` guard can also disable tracing as a +// side effect of disabling `Autograd` dispatching. +// +// - `setTracingState()` API in `torch/csrc/jit/frontend/tracer.h` +// +// It stores tracing data in a `TracingState` object in TLS. If the +// `TracingState` object in TLS is `null`, then tracing is paused. +// +// The `TracingState` object is created in `tracer::trace()` - the main +// entrance of tracing function. It's temporarily set to `null` inside +// generated VariableType (now TraceType) to bypass tracing for intermediate +// ops (ops being called by other ops). After the intermediate op call +// finishes it's set back to the original `TracingState` object. +// +// The `TracingState` object in TLS can also be read/written via its Python +// binding in `python_tracer.cpp`, and `get/setTracingState()` C++ APIs, +// which are also exposed as `TORCH_API`. 
+// +// Two new switches were introduced since tracing function was moved out of +// VariableType: +// +// - `tracer::impl::set_dispatch_enabled()` API +// +// Unlike the special `Autograd` dispatch key which is included in dispatch +// key set by default, `Tracer` dispatch key is off by default. The +// dispatching switch can be toggled via this new API. +// +// - `tracer::impl::NoTracerDispatchMode` guard +// +// It's used to cover the old semantics of `AutoDispatchBelowADInplaceOrView` +// after tracing was moved out of VariableType. +// +// Before tracing function was moved out of VariableType, tracing was enabled +// when the following conditions are satisfied: +// +// 1) `TracingState` object in TLS != null; +// - Either inside the execution scope of `tracer::trace()`, or +// - Eagerly called `setTracingState()` with non-null object. +// 2) Not inside `AutoDispatchBelowADInplaceOrView` scope; +// +// After: +// +// 1) `TracingState` object in TLS != null; +// 2) Has called `tracer::impl::set_dispatch_enabled(true)`; +// 3) Not inside `tracer::impl::NonDispatchGuard` scope; +// +// [TODOs] +// +// - `setTracingState()` v.s. `tracer::impl::set_dispatch_enabled()` +// +// Currently `set_dispatch_enabled()` is set/unset inside `setTracingState()` +// to keep the semantics exactly the same as before - it's confusing to keep +// both switches, though. We should consider simplifying/limiting the exposed +// `setTracingState()` Python/C++ APIs (and other APIs calling it) so that +// these two can be unified. +// +// - `AutoDispatchBelowADInplaceOrView` v.s. +// `tracer::impl::NoTracerDispatchMode` +// +// We don't need to always set both guards together to keep semantics +// unchanged. For the follow use cases of `AutoDispatchBelowADInplaceOrView` +// we don't need set the new tracer guard: +// +// * Script-generated VariableType kernels. 
The guard is not necessary as +// tracing is already disabled explicitly by `setTracingState(null)` in +// generated TraceType kernels - we could keep it as is or use the new guard +// instead. +// +// * Custom ops. Will be handled by fallback kernel for `Tracer`. +// +// * Functions that are not likely to be called in tracing context (no python +// binding / not an operator), e.g.: all mobile forward() wrappers, test +// binaries, and etc. +// +// * Where new threads are spawned, e.g.: ATen/native/ConvolutionMM2d.cpp. +// It's not necessary as tracing is off by default. +// +// For the rest of cases we might need have both: +// +// * Functions that might be reachable from eager mode python (especially +// factory methods), e.g.: +// `internal_new_from_data()` in `torch/csrc/utils/tensor_new.cpp`. +// Without the new guard it will add `aten::empty` to the traced graph. +// +// * Some manually maintained functions, e.g.: +// `torch/csrc/autograd/VariableTypeManual.cpp`. +// Set the new guard if it's not obvious whether `setTracingState(null)` +// has been called before it reaches the `AutoDispatchBelowADInplaceOrView` +// guard. +// +// We might need tweak the usage of the new guard to optimize/fix things. +// It should only affect the correctness of tracing function, because the +// guard is essentially no-op when the master `setTracingState()` switch is +// off. + +// TODO: move this from `at::` to `jit::torch::` after +// `aten/src/ATen/cpp_custom_type_hack.h` is removed. 
+ +namespace at::tracer::impl { + +inline bool is_dispatch_enabled() { + return c10::impl::tls_is_dispatch_key_included(at::DispatchKey::Tracer) && + !c10::impl::tls_is_dispatch_key_excluded(at::DispatchKey::Tracer); +} + +inline void set_dispatch_enabled(bool enabled) { + TORCH_INTERNAL_ASSERT( + !c10::impl::tls_is_dispatch_key_excluded(at::DispatchKey::Tracer), + "Cannot enable tracing within the scope of NoTracerDispatchMode!"); + c10::impl::tls_set_dispatch_key_included(at::DispatchKey::Tracer, enabled); +} + +struct NoTracerDispatchMode { + c10::impl::ExcludeDispatchKeyGuard guard_{at::DispatchKey::Tracer}; +}; + +} // namespace at::tracer::impl + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/TypeDefault.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/TypeDefault.h new file mode 100644 index 0000000000000000000000000000000000000000..53a835e224c4abe7ab2930260ff394dd0992501b --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/TypeDefault.h @@ -0,0 +1,31 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace c10 { +struct Storage; +} + +namespace at { + +class Tensor; +using TensorList = ArrayRef; + +class Context; +struct Generator; + +struct Quantizer; + +} // namespace at + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." 
+#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/Utils.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/Utils.h new file mode 100644 index 0000000000000000000000000000000000000000..afb2ecca3cd6395acdf5965e19d078cf7b2ffdd8 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/Utils.h @@ -0,0 +1,143 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#define AT_DISALLOW_COPY_AND_ASSIGN(TypeName) \ + TypeName(const TypeName&) = delete; \ + void operator=(const TypeName&) = delete + +namespace at { + +TORCH_API int _crash_if_asan(int /*arg*/); + +// Converts a TensorList (i.e. ArrayRef to vector of TensorImpl*) +// NB: This is ONLY used by legacy TH bindings, and ONLY used by cat. +// Once cat is ported entirely to ATen this can be deleted! 
+inline std::vector checked_dense_tensor_list_unwrap( + ArrayRef tensors, + const char* name, + int pos, + c10::DeviceType device_type, + ScalarType scalar_type) { + std::vector unwrapped; + unwrapped.reserve(tensors.size()); + for (const auto i : c10::irange(tensors.size())) { + const auto& expr = tensors[i]; + if (expr.layout() != Layout::Strided) { + TORCH_CHECK( + false, + "Expected dense tensor but got ", + expr.layout(), + " for sequence element ", + i, + " in sequence argument at position #", + pos, + " '", + name, + "'"); + } + if (expr.device().type() != device_type) { + TORCH_CHECK( + false, + "Expected object of device type ", + device_type, + " but got device type ", + expr.device().type(), + " for sequence element ", + i, + " in sequence argument at position #", + pos, + " '", + name, + "'"); + } + if (expr.scalar_type() != scalar_type) { + TORCH_CHECK( + false, + "Expected object of scalar type ", + scalar_type, + " but got scalar type ", + expr.scalar_type(), + " for sequence element ", + i, + " in sequence argument at position #", + pos, + " '", + name, + "'"); + } + unwrapped.emplace_back(expr.unsafeGetTensorImpl()); + } + return unwrapped; +} + +template +std::array check_intlist( + ArrayRef list, + const char* name, + int pos) { + if (list.empty()) { + // TODO: is this necessary? We used to treat nullptr-vs-not in IntList + // differently with strides as a way of faking optional. 
+ list = {}; + } + auto res = std::array(); + if (list.size() == 1 && N > 1) { + res.fill(list[0]); + return res; + } + if (list.size() != N) { + TORCH_CHECK( + false, + "Expected a list of ", + N, + " ints but got ", + list.size(), + " for argument #", + pos, + " '", + name, + "'"); + } + std::copy_n(list.begin(), N, res.begin()); + return res; +} + +using at::detail::check_size_nonnegative; + +namespace detail { + +template +TORCH_API Tensor tensor_cpu(ArrayRef values, const TensorOptions& options); + +template +TORCH_API Tensor +tensor_backend(ArrayRef values, const TensorOptions& options); + +template +TORCH_API Tensor +tensor_complex_cpu(ArrayRef values, const TensorOptions& options); + +template +TORCH_API Tensor +tensor_complex_backend(ArrayRef values, const TensorOptions& options); +} // namespace detail + +} // namespace at + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/Version.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/Version.h new file mode 100644 index 0000000000000000000000000000000000000000..5dafcebb9147bf5a1ce0380469f02cc9deadec05 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/Version.h @@ -0,0 +1,23 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#include + +namespace at { + +/// Returns a detailed string describing the configuration PyTorch. 
+TORCH_API std::string show_config(); + +TORCH_API std::string get_mkl_version(); + +TORCH_API std::string get_mkldnn_version(); + +TORCH_API std::string get_openmp_version(); + +TORCH_API std::string get_cxx_flags(); + +TORCH_API std::string get_cpu_capability(); + +} // namespace at + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/VmapGeneratedPlumbing.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/VmapGeneratedPlumbing.h new file mode 100644 index 0000000000000000000000000000000000000000..2609765f05c5d373950d86251dca86a446f22d9a --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/VmapGeneratedPlumbing.h @@ -0,0 +1,28271 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) + +#pragma once +#include +#include + +namespace at { namespace functorch { + +template +at::Tensor _cast_Byte_generated_plumbing(const at::Tensor & self, bool non_blocking) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_cast_Byte::call(self, non_blocking); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, non_blocking); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _cast_Char_generated_plumbing(const at::Tensor & self, bool non_blocking) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, 
"gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_cast_Char::call(self, non_blocking); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, non_blocking); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _cast_Double_generated_plumbing(const at::Tensor & self, bool non_blocking) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_cast_Double::call(self, non_blocking); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, non_blocking); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _cast_Float_generated_plumbing(const at::Tensor & self, bool non_blocking) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_cast_Float::call(self, non_blocking); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, non_blocking); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _cast_Int_generated_plumbing(const at::Tensor & self, bool non_blocking) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = 
maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_cast_Int::call(self, non_blocking); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, non_blocking); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _cast_Long_generated_plumbing(const at::Tensor & self, bool non_blocking) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_cast_Long::call(self, non_blocking); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, non_blocking); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _cast_Short_generated_plumbing(const at::Tensor & self, bool non_blocking) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_cast_Short::call(self, non_blocking); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, non_blocking); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _cast_Half_generated_plumbing(const at::Tensor & self, bool non_blocking) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, 
cur_level)) { + return at::_ops::_cast_Half::call(self, non_blocking); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, non_blocking); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _backward_generated_plumbing(const at::Tensor & self, at::TensorList inputs, const ::std::optional & gradient, ::std::optional retain_graph, bool create_graph) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(inputs, cur_level) && !isBatchedAtLevel(gradient, cur_level)) { + return at::_ops::_backward::call(self, inputs, gradient, retain_graph, create_graph); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + std::optional gradient_value; + std::optional gradient_bdim; + if (gradient) { + std::tie(gradient_value, gradient_bdim) = unwrapTensorAtLevel(gradient.value(), cur_level); + } + batch_rule(self_value, self_bdim, inputs, gradient_value, gradient_bdim, retain_graph, create_graph); +} +template +void set_data_generated_plumbing(at::Tensor & self, const at::Tensor & new_data) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(new_data, cur_level)) { + return at::_ops::set_data::call(self, new_data); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [new_data_value, new_data_bdim] = unwrapTensorAtLevel(new_data, cur_level); + batch_rule(self_value, self_bdim, new_data_value, new_data_bdim); +} +template 
+at::Tensor data_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::data::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & requires_grad__generated_plumbing(at::Tensor & self, bool requires_grad) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::requires_grad_::call(self, requires_grad); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, requires_grad); + return self; +} +template +void retain_grad_generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::retain_grad::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); +} +template +at::Tensor _fw_primal_generated_plumbing(const at::Tensor & self, int64_t level) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if 
(!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_fw_primal::call(self, level); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, level); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _make_dual_generated_plumbing(const at::Tensor & primal, const at::Tensor & tangent, int64_t level) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(primal, cur_level) && !isBatchedAtLevel(tangent, cur_level)) { + return at::_ops::_make_dual::call(primal, tangent, level); + } + auto [primal_value, primal_bdim] = unwrapTensorAtLevel(primal, cur_level); + auto [tangent_value, tangent_bdim] = unwrapTensorAtLevel(tangent, cur_level); + auto results = batch_rule(primal_value, primal_bdim, tangent_value, tangent_bdim, level); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple _unpack_dual_generated_plumbing(const at::Tensor & dual, int64_t level) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(dual, cur_level)) { + return at::_ops::_unpack_dual::call(dual, level); + } + auto [dual_value, dual_bdim] = unwrapTensorAtLevel(dual, cur_level); + auto results = batch_rule(dual_value, dual_bdim, level); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor _new_zeros_with_same_feature_meta_generated_plumbing(const at::Tensor & self, const at::Tensor & other, int64_t 
self_num_batch_dims) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::_new_zeros_with_same_feature_meta::call(self, other, self_num_batch_dims); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim, self_num_batch_dims); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor rename_generated_plumbing(const at::Tensor & self, ::std::optional names) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::rename::call(self, names); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, names); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor align_to_generated_plumbing(const at::Tensor & self, at::DimnameList names) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::align_to::call(self, names); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, names); + return makeBatched(std::get<0>(results), std::get<1>(results), 
cur_level); +} +template +at::Tensor align_to_ellipsis_idx_generated_plumbing(const at::Tensor & self, at::DimnameList order, int64_t ellipsis_idx) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::align_to_ellipsis_idx::call(self, order, ellipsis_idx); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, order, ellipsis_idx); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor align_as_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::align_as::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector align_tensors_generated_plumbing(at::TensorList tensors) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(tensors, cur_level)) { + return at::_ops::align_tensors::call(tensors); + } + + auto results = batch_rule(tensors); + return makeBatchedVector(std::get<0>(results), 
std::get<1>(results), cur_level); +} +template +void _assert_async_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_assert_async::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); +} +template +void _assert_async_msg_generated_plumbing(const at::Tensor & self, c10::string_view assert_msg) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_assert_async_msg::call(self, assert_msg); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, assert_msg); +} +template +at::Tensor _functional_assert_scalar_generated_plumbing(const at::Scalar & self, c10::string_view assert_msg, const at::Tensor & dep_token) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(dep_token, cur_level)) { + return at::_ops::_functional_assert_scalar::call(self, assert_msg, dep_token); + } + auto [dep_token_value, dep_token_bdim] = unwrapTensorAtLevel(dep_token, cur_level); + auto results = batch_rule(self, assert_msg, dep_token_value, dep_token_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _functional_assert_async_msg_generated_plumbing(const at::Tensor & self, 
c10::string_view assert_msg, const at::Tensor & dep_token) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(dep_token, cur_level)) { + return at::_ops::_functional_assert_async_msg::call(self, assert_msg, dep_token); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [dep_token_value, dep_token_bdim] = unwrapTensorAtLevel(dep_token, cur_level); + auto results = batch_rule(self_value, self_bdim, assert_msg, dep_token_value, dep_token_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _assert_tensor_metadata_generated_plumbing(const at::Tensor & a, at::OptionalSymIntArrayRef size, at::OptionalSymIntArrayRef stride, ::std::optional dtype, ::std::optional device, ::std::optional layout) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(a, cur_level)) { + return at::_ops::_assert_tensor_metadata::call(a, size, stride, dtype, device, layout); + } + auto [a_value, a_bdim] = unwrapTensorAtLevel(a, cur_level); + batch_rule(a_value, a_bdim, size, stride, dtype, device, layout); +} +template +at::Tensor _functional_sym_constrain_range_generated_plumbing(const at::Scalar & size, ::std::optional min, ::std::optional max, const at::Tensor & dep_token) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(dep_token, cur_level)) { + return 
at::_ops::_functional_sym_constrain_range::call(size, min, max, dep_token); + } + auto [dep_token_value, dep_token_bdim] = unwrapTensorAtLevel(dep_token, cur_level); + auto results = batch_rule(size, min, max, dep_token_value, dep_token_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _functional_sym_constrain_range_for_size_generated_plumbing(const at::Scalar & size, ::std::optional min, ::std::optional max, const at::Tensor & dep_token) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(dep_token, cur_level)) { + return at::_ops::_functional_sym_constrain_range_for_size::call(size, min, max, dep_token); + } + auto [dep_token_value, dep_token_bdim] = unwrapTensorAtLevel(dep_token, cur_level); + auto results = batch_rule(size, min, max, dep_token_value, dep_token_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor refine_names_generated_plumbing(const at::Tensor & self, at::DimnameList names) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::refine_names::call(self, names); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, names); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple _cudnn_ctc_loss_generated_plumbing(const at::Tensor & log_probs, const at::Tensor & targets, at::IntArrayRef input_lengths, at::IntArrayRef target_lengths, int64_t blank, bool deterministic, bool zero_infinity) { 
+ c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(log_probs, cur_level) && !isBatchedAtLevel(targets, cur_level)) { + return at::_ops::_cudnn_ctc_loss::call(log_probs, targets, input_lengths, target_lengths, blank, deterministic, zero_infinity); + } + auto [log_probs_value, log_probs_bdim] = unwrapTensorAtLevel(log_probs, cur_level); + auto [targets_value, targets_bdim] = unwrapTensorAtLevel(targets, cur_level); + auto results = batch_rule(log_probs_value, log_probs_bdim, targets_value, targets_bdim, input_lengths, target_lengths, blank, deterministic, zero_infinity); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple _cudnn_ctc_loss_Tensor_generated_plumbing(const at::Tensor & log_probs, const at::Tensor & targets, const at::Tensor & input_lengths, const at::Tensor & target_lengths, int64_t blank, bool deterministic, bool zero_infinity) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(log_probs, cur_level) && !isBatchedAtLevel(targets, cur_level) && !isBatchedAtLevel(input_lengths, cur_level) && !isBatchedAtLevel(target_lengths, cur_level)) { + return at::_ops::_cudnn_ctc_loss_Tensor::call(log_probs, targets, input_lengths, target_lengths, blank, deterministic, zero_infinity); + } + auto [log_probs_value, log_probs_bdim] = unwrapTensorAtLevel(log_probs, cur_level); + auto [targets_value, targets_bdim] = unwrapTensorAtLevel(targets, cur_level); + auto [input_lengths_value, input_lengths_bdim] = unwrapTensorAtLevel(input_lengths, 
cur_level); + auto [target_lengths_value, target_lengths_bdim] = unwrapTensorAtLevel(target_lengths, cur_level); + auto results = batch_rule(log_probs_value, log_probs_bdim, targets_value, targets_bdim, input_lengths_value, input_lengths_bdim, target_lengths_value, target_lengths_bdim, blank, deterministic, zero_infinity); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor _cudnn_rnn_flatten_weight_generated_plumbing(at::TensorList weight_arr, int64_t weight_stride0, c10::SymInt input_size, int64_t mode, c10::SymInt hidden_size, c10::SymInt proj_size, int64_t num_layers, bool batch_first, bool bidirectional) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(weight_arr, cur_level)) { + return at::_ops::_cudnn_rnn_flatten_weight::call(weight_arr, weight_stride0, input_size, mode, hidden_size, proj_size, num_layers, batch_first, bidirectional); + } + + auto results = batch_rule(weight_arr, weight_stride0, input_size, mode, hidden_size, proj_size, num_layers, batch_first, bidirectional); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple _cudnn_rnn_generated_plumbing(const at::Tensor & input, at::TensorList weight, int64_t weight_stride0, const ::std::optional & weight_buf, const at::Tensor & hx, const ::std::optional & cx, int64_t mode, c10::SymInt hidden_size, c10::SymInt proj_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, c10::SymIntArrayRef batch_sizes, const ::std::optional & dropout_state) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + 
vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(weight_buf, cur_level) && !isBatchedAtLevel(hx, cur_level) && !isBatchedAtLevel(cx, cur_level) && !isBatchedAtLevel(dropout_state, cur_level)) { + return at::_ops::_cudnn_rnn::call(input, weight, weight_stride0, weight_buf, hx, cx, mode, hidden_size, proj_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [hx_value, hx_bdim] = unwrapTensorAtLevel(hx, cur_level); + std::optional weight_buf_value; + std::optional weight_buf_bdim; + if (weight_buf) { + std::tie(weight_buf_value, weight_buf_bdim) = unwrapTensorAtLevel(weight_buf.value(), cur_level); + } + std::optional cx_value; + std::optional cx_bdim; + if (cx) { + std::tie(cx_value, cx_bdim) = unwrapTensorAtLevel(cx.value(), cur_level); + } + std::optional dropout_state_value; + std::optional dropout_state_bdim; + if (dropout_state) { + std::tie(dropout_state_value, dropout_state_bdim) = unwrapTensorAtLevel(dropout_state.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, weight, weight_stride0, weight_buf_value, weight_buf_bdim, hx_value, hx_bdim, cx_value, cx_bdim, mode, hidden_size, proj_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state_value, dropout_state_bdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level), makeBatched(std::get<6>(results), std::get<7>(results), cur_level), makeBatched(std::get<8>(results), std::get<9>(results), cur_level)); +} +template +::std::tuple> _cudnn_rnn_backward_generated_plumbing(const at::Tensor & input, 
at::TensorList weight, int64_t weight_stride0, const at::Tensor & weight_buf, const at::Tensor & hx, const ::std::optional & cx, const at::Tensor & output, const ::std::optional & grad_output, const ::std::optional & grad_hy, const ::std::optional & grad_cy, int64_t mode, c10::SymInt hidden_size, c10::SymInt proj_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, c10::SymIntArrayRef batch_sizes, const ::std::optional & dropout_state, const at::Tensor & reserve, ::std::array output_mask) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(weight_buf, cur_level) && !isBatchedAtLevel(hx, cur_level) && !isBatchedAtLevel(cx, cur_level) && !isBatchedAtLevel(output, cur_level) && !isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(grad_hy, cur_level) && !isBatchedAtLevel(grad_cy, cur_level) && !isBatchedAtLevel(dropout_state, cur_level) && !isBatchedAtLevel(reserve, cur_level)) { + return at::_ops::_cudnn_rnn_backward::call(input, weight, weight_stride0, weight_buf, hx, cx, output, grad_output, grad_hy, grad_cy, mode, hidden_size, proj_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state, reserve, output_mask); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [weight_buf_value, weight_buf_bdim] = unwrapTensorAtLevel(weight_buf, cur_level); + auto [hx_value, hx_bdim] = unwrapTensorAtLevel(hx, cur_level); + auto [output_value, output_bdim] = unwrapTensorAtLevel(output, cur_level); + auto [reserve_value, reserve_bdim] = unwrapTensorAtLevel(reserve, cur_level); + std::optional cx_value; + std::optional cx_bdim; + if (cx) { + std::tie(cx_value, cx_bdim) = 
unwrapTensorAtLevel(cx.value(), cur_level); + } + std::optional grad_output_value; + std::optional grad_output_bdim; + if (grad_output) { + std::tie(grad_output_value, grad_output_bdim) = unwrapTensorAtLevel(grad_output.value(), cur_level); + } + std::optional grad_hy_value; + std::optional grad_hy_bdim; + if (grad_hy) { + std::tie(grad_hy_value, grad_hy_bdim) = unwrapTensorAtLevel(grad_hy.value(), cur_level); + } + std::optional grad_cy_value; + std::optional grad_cy_bdim; + if (grad_cy) { + std::tie(grad_cy_value, grad_cy_bdim) = unwrapTensorAtLevel(grad_cy.value(), cur_level); + } + std::optional dropout_state_value; + std::optional dropout_state_bdim; + if (dropout_state) { + std::tie(dropout_state_value, dropout_state_bdim) = unwrapTensorAtLevel(dropout_state.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, weight, weight_stride0, weight_buf_value, weight_buf_bdim, hx_value, hx_bdim, cx_value, cx_bdim, output_value, output_bdim, grad_output_value, grad_output_bdim, grad_hy_value, grad_hy_bdim, grad_cy_value, grad_cy_bdim, mode, hidden_size, proj_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state_value, dropout_state_bdim, reserve_value, reserve_bdim, output_mask); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level), makeBatchedVector(std::get<6>(results), std::get<7>(results), cur_level)); +} +template +::std::tuple _fused_dropout_generated_plumbing(const at::Tensor & self, double p, ::std::optional generator) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_fused_dropout::call(self, 
p, generator); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, p, generator); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor _masked_scale_generated_plumbing(const at::Tensor & self, const at::Tensor & mask, double scale) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(mask, cur_level)) { + return at::_ops::_masked_scale::call(self, mask, scale); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [mask_value, mask_bdim] = unwrapTensorAtLevel(mask, cur_level); + auto results = batch_rule(self_value, self_bdim, mask_value, mask_bdim, scale); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple native_dropout_generated_plumbing(const at::Tensor & input, double p, ::std::optional train) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level)) { + return at::_ops::native_dropout::call(input, p, train); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto results = batch_rule(input_value, input_bdim, p, train); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor native_dropout_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & mask, 
double scale) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(mask, cur_level)) { + return at::_ops::native_dropout_backward::call(grad_output, mask, scale); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [mask_value, mask_bdim] = unwrapTensorAtLevel(mask, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, mask_value, mask_bdim, scale); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple _sobol_engine_draw_generated_plumbing(const at::Tensor & quasi, int64_t n, const at::Tensor & sobolstate, int64_t dimension, int64_t num_generated, ::std::optional dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(quasi, cur_level) && !isBatchedAtLevel(sobolstate, cur_level)) { + return at::_ops::_sobol_engine_draw::call(quasi, n, sobolstate, dimension, num_generated, dtype); + } + auto [quasi_value, quasi_bdim] = unwrapTensorAtLevel(quasi, cur_level); + auto [sobolstate_value, sobolstate_bdim] = unwrapTensorAtLevel(sobolstate, cur_level); + auto results = batch_rule(quasi_value, quasi_bdim, n, sobolstate_value, sobolstate_bdim, dimension, num_generated, dtype); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor & _sobol_engine_ff__generated_plumbing(at::Tensor & self, int64_t n, const at::Tensor & sobolstate, int64_t dimension, int64_t num_generated) { + 
c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(sobolstate, cur_level)) { + return at::_ops::_sobol_engine_ff_::call(self, n, sobolstate, dimension, num_generated); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [sobolstate_value, sobolstate_bdim] = unwrapTensorAtLevel(sobolstate, cur_level); + batch_rule(self_value, self_bdim, n, sobolstate_value, sobolstate_bdim, dimension, num_generated); + return self; +} +template +at::Tensor & _sobol_engine_scramble__generated_plumbing(at::Tensor & self, const at::Tensor & ltm, int64_t dimension) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(ltm, cur_level)) { + return at::_ops::_sobol_engine_scramble_::call(self, ltm, dimension); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [ltm_value, ltm_bdim] = unwrapTensorAtLevel(ltm, cur_level); + batch_rule(self_value, self_bdim, ltm_value, ltm_bdim, dimension); + return self; +} +template +at::Tensor & _sobol_engine_initialize_state__generated_plumbing(at::Tensor & self, int64_t dimension) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_sobol_engine_initialize_state_::call(self, dimension); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + 
batch_rule(self_value, self_bdim, dimension); + return self; +} +template +at::Tensor _reshape_from_tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & shape) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(shape, cur_level)) { + return at::_ops::_reshape_from_tensor::call(self, shape); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [shape_value, shape_bdim] = unwrapTensorAtLevel(shape, cur_level); + auto results = batch_rule(self_value, self_bdim, shape_value, shape_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _shape_as_tensor_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_shape_as_tensor::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor dropout_generated_plumbing(const at::Tensor & input, double p, bool train) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level)) { + return at::_ops::dropout::call(input, p, train); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto results = 
batch_rule(input_value, input_bdim, p, train); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & dropout__generated_plumbing(at::Tensor & self, double p, bool train) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::dropout_::call(self, p, train); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, p, train); + return self; +} +template +at::Tensor feature_dropout_generated_plumbing(const at::Tensor & input, double p, bool train) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level)) { + return at::_ops::feature_dropout::call(input, p, train); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto results = batch_rule(input_value, input_bdim, p, train); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & feature_dropout__generated_plumbing(at::Tensor & self, double p, bool train) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::feature_dropout_::call(self, p, train); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, p, train); + return self; +} +template +at::Tensor alpha_dropout_generated_plumbing(const 
at::Tensor & input, double p, bool train) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level)) { + return at::_ops::alpha_dropout::call(input, p, train); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto results = batch_rule(input_value, input_bdim, p, train); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & alpha_dropout__generated_plumbing(at::Tensor & self, double p, bool train) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::alpha_dropout_::call(self, p, train); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, p, train); + return self; +} +template +at::Tensor feature_alpha_dropout_generated_plumbing(const at::Tensor & input, double p, bool train) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level)) { + return at::_ops::feature_alpha_dropout::call(input, p, train); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto results = batch_rule(input_value, input_bdim, p, train); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & feature_alpha_dropout__generated_plumbing(at::Tensor & self, double p, bool train) { + c10::impl::ExcludeDispatchKeyGuard 
guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::feature_alpha_dropout_::call(self, p, train); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, p, train); + return self; +} +template +at::Tensor abs_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::abs::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & abs__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::abs_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor absolute_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::absolute::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, 
cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & absolute__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::absolute_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor angle_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::angle::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor view_as_real_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::view_as_real::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor view_as_complex_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard 
guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::view_as_complex::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor sgn_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::sgn::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & sgn__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::sgn_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor chalf_generated_plumbing(const at::Tensor & self, ::std::optional memory_format) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::chalf::call(self, 
memory_format); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, memory_format); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor real_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::real::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor imag_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::imag::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _conj_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_conj::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); 
+} +template +at::Tensor conj_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::conj::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _conj_physical_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_conj_physical::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor conj_physical_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::conj_physical::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & conj_physical__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = 
maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::conj_physical_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor resolve_conj_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::resolve_conj::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor resolve_neg_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::resolve_neg::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _neg_view_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_neg_view::call(self); + } + auto [self_value, self_bdim] = 
unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor acos_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::acos::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & acos__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::acos_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor arccos_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::arccos::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & arccos__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard 
guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::arccos_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor avg_pool1d_generated_plumbing(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::avg_pool1d::call(self, kernel_size, stride, padding, ceil_mode, count_include_pad); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, kernel_size, stride, padding, ceil_mode, count_include_pad); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor adaptive_avg_pool1d_generated_plumbing(const at::Tensor & self, at::IntArrayRef output_size) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::adaptive_avg_pool1d::call(self, output_size); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, output_size); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple adaptive_max_pool1d_generated_plumbing(const at::Tensor 
& self, at::IntArrayRef output_size) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::adaptive_max_pool1d::call(self, output_size); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, output_size); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor add_Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::add_Tensor::call(self, other, alpha); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim, alpha); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & add__Tensor_generated_plumbing(at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::add__Tensor::call(self, other, alpha); + } + 
auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim, alpha); + return self; +} +template +at::Tensor _add_relu_Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::_add_relu_Tensor::call(self, other, alpha); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim, alpha); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & _add_relu__Tensor_generated_plumbing(at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::_add_relu__Tensor::call(self, other, alpha); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim, alpha); + return self; +} +template +at::Tensor _add_relu_Scalar_generated_plumbing(const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha) { + c10::impl::ExcludeDispatchKeyGuard 
guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_add_relu_Scalar::call(self, other, alpha); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, other, alpha); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & _add_relu__Scalar_generated_plumbing(at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_add_relu__Scalar::call(self, other, alpha); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, other, alpha); + return self; +} +template +at::Tensor add_Scalar_generated_plumbing(const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::add_Scalar::call(self, other, alpha); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, other, alpha); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & add__Scalar_generated_plumbing(at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha) { + c10::impl::ExcludeDispatchKeyGuard 
guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::add__Scalar::call(self, other, alpha); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, other, alpha); + return self; +} +template +at::Tensor addmv_generated_plumbing(const at::Tensor & self, const at::Tensor & mat, const at::Tensor & vec, const at::Scalar & beta, const at::Scalar & alpha) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(mat, cur_level) && !isBatchedAtLevel(vec, cur_level)) { + return at::_ops::addmv::call(self, mat, vec, beta, alpha); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [mat_value, mat_bdim] = unwrapTensorAtLevel(mat, cur_level); + auto [vec_value, vec_bdim] = unwrapTensorAtLevel(vec, cur_level); + auto results = batch_rule(self_value, self_bdim, mat_value, mat_bdim, vec_value, vec_bdim, beta, alpha); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & addmv__generated_plumbing(at::Tensor & self, const at::Tensor & mat, const at::Tensor & vec, const at::Scalar & beta, const at::Scalar & alpha) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(mat, cur_level) && !isBatchedAtLevel(vec, cur_level)) { + return at::_ops::addmv_::call(self, mat, vec, beta, 
alpha); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [mat_value, mat_bdim] = unwrapTensorAtLevel(mat, cur_level); + auto [vec_value, vec_bdim] = unwrapTensorAtLevel(vec, cur_level); + batch_rule(self_value, self_bdim, mat_value, mat_bdim, vec_value, vec_bdim, beta, alpha); + return self; +} +template +at::Tensor addr_generated_plumbing(const at::Tensor & self, const at::Tensor & vec1, const at::Tensor & vec2, const at::Scalar & beta, const at::Scalar & alpha) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(vec1, cur_level) && !isBatchedAtLevel(vec2, cur_level)) { + return at::_ops::addr::call(self, vec1, vec2, beta, alpha); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [vec1_value, vec1_bdim] = unwrapTensorAtLevel(vec1, cur_level); + auto [vec2_value, vec2_bdim] = unwrapTensorAtLevel(vec2, cur_level); + auto results = batch_rule(self_value, self_bdim, vec1_value, vec1_bdim, vec2_value, vec2_bdim, beta, alpha); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & addr__generated_plumbing(at::Tensor & self, const at::Tensor & vec1, const at::Tensor & vec2, const at::Scalar & beta, const at::Scalar & alpha) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(vec1, cur_level) && !isBatchedAtLevel(vec2, cur_level)) { + return at::_ops::addr_::call(self, vec1, vec2, beta, alpha); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto 
[vec1_value, vec1_bdim] = unwrapTensorAtLevel(vec1, cur_level); + auto [vec2_value, vec2_bdim] = unwrapTensorAtLevel(vec2, cur_level); + batch_rule(self_value, self_bdim, vec1_value, vec1_bdim, vec2_value, vec2_bdim, beta, alpha); + return self; +} +template +at::Tensor affine_grid_generator_generated_plumbing(const at::Tensor & theta, c10::SymIntArrayRef size, bool align_corners) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(theta, cur_level)) { + return at::_ops::affine_grid_generator::call(theta, size, align_corners); + } + auto [theta_value, theta_bdim] = unwrapTensorAtLevel(theta, cur_level); + auto results = batch_rule(theta_value, theta_bdim, size, align_corners); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor affine_grid_generator_backward_generated_plumbing(const at::Tensor & grad, c10::SymIntArrayRef size, bool align_corners) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad, cur_level)) { + return at::_ops::affine_grid_generator_backward::call(grad, size, align_corners); + } + auto [grad_value, grad_bdim] = unwrapTensorAtLevel(grad, cur_level); + auto results = batch_rule(grad_value, grad_bdim, size, align_corners); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _is_all_true_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); 
+ if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_is_all_true::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _is_any_true_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_is_any_true::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _test_check_tensor_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_test_check_tensor::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _test_functorch_fallback_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::_test_functorch_fallback::call(self, 
other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor all_dim_generated_plumbing(const at::Tensor & self, int64_t dim, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::all_dim::call(self, dim, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, keepdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor all_dims_generated_plumbing(const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::all_dims::call(self, dim, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, keepdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor all_dimname_generated_plumbing(const at::Tensor & self, at::Dimname dim, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, 
cur_level)) { + return at::_ops::all_dimname::call(self, dim, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, keepdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor any_dim_generated_plumbing(const at::Tensor & self, int64_t dim, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::any_dim::call(self, dim, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, keepdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor any_dims_generated_plumbing(const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::any_dims::call(self, dim, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, keepdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor any_dimname_generated_plumbing(const at::Tensor & self, at::Dimname dim, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return 
at::_ops::any_dimname::call(self, dim, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, keepdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _dim_arange_generated_plumbing(const at::Tensor & like, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(like, cur_level)) { + return at::_ops::_dim_arange::call(like, dim); + } + auto [like_value, like_bdim] = unwrapTensorAtLevel(like, cur_level); + auto results = batch_rule(like_value, like_bdim, dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor argmax_generated_plumbing(const at::Tensor & self, ::std::optional dim, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::argmax::call(self, dim, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, keepdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor argmin_generated_plumbing(const at::Tensor & self, ::std::optional dim, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::argmin::call(self, dim, keepdim); + } + auto 
[self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, keepdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor acosh_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::acosh::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & acosh__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::acosh_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor arccosh_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::arccosh::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & arccosh__generated_plumbing(at::Tensor & self) { + 
c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::arccosh_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor asinh_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::asinh::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & asinh__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::asinh_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor arcsinh_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::arcsinh::call(self); + } + auto [self_value, self_bdim] = 
unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & arcsinh__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::arcsinh_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor atanh_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::atanh::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & atanh__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::atanh_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor arctanh_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); 
+ vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::arctanh::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & arctanh__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::arctanh_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor as_strided_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::as_strided::call(self, size, stride, storage_offset); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, size, stride, storage_offset); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor asin_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, 
cur_level)) { + return at::_ops::asin::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & asin__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::asin_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor arcsin_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::arcsin::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & arcsin__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::arcsin_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor atan_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard 
guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::atan::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & atan__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::atan_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor arctan_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::arctan::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & arctan__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::arctan_::call(self); + } + auto [self_value, 
self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor atleast_1d_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::atleast_1d::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector atleast_1d_Sequence_generated_plumbing(at::TensorList tensors) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(tensors, cur_level)) { + return at::_ops::atleast_1d_Sequence::call(tensors); + } + + auto results = batch_rule(tensors); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor atleast_2d_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::atleast_2d::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector atleast_2d_Sequence_generated_plumbing(at::TensorList tensors) { + c10::impl::ExcludeDispatchKeyGuard 
guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(tensors, cur_level)) { + return at::_ops::atleast_2d_Sequence::call(tensors); + } + + auto results = batch_rule(tensors); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor atleast_3d_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::atleast_3d::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector atleast_3d_Sequence_generated_plumbing(at::TensorList tensors) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(tensors, cur_level)) { + return at::_ops::atleast_3d_Sequence::call(tensors); + } + + auto results = batch_rule(tensors); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor baddbmm_generated_plumbing(const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta, const at::Scalar & alpha) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if 
(!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(batch1, cur_level) && !isBatchedAtLevel(batch2, cur_level)) { + return at::_ops::baddbmm::call(self, batch1, batch2, beta, alpha); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [batch1_value, batch1_bdim] = unwrapTensorAtLevel(batch1, cur_level); + auto [batch2_value, batch2_bdim] = unwrapTensorAtLevel(batch2, cur_level); + auto results = batch_rule(self_value, self_bdim, batch1_value, batch1_bdim, batch2_value, batch2_bdim, beta, alpha); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & baddbmm__generated_plumbing(at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta, const at::Scalar & alpha) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(batch1, cur_level) && !isBatchedAtLevel(batch2, cur_level)) { + return at::_ops::baddbmm_::call(self, batch1, batch2, beta, alpha); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [batch1_value, batch1_bdim] = unwrapTensorAtLevel(batch1, cur_level); + auto [batch2_value, batch2_bdim] = unwrapTensorAtLevel(batch2, cur_level); + batch_rule(self_value, self_bdim, batch1_value, batch1_bdim, batch2_value, batch2_bdim, beta, alpha); + return self; +} +template +at::Tensor baddbmm_dtype_generated_plumbing(const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, at::ScalarType out_dtype, const at::Scalar & beta, const at::Scalar & alpha) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = 
maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(batch1, cur_level) && !isBatchedAtLevel(batch2, cur_level)) { + return at::_ops::baddbmm_dtype::call(self, batch1, batch2, out_dtype, beta, alpha); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [batch1_value, batch1_bdim] = unwrapTensorAtLevel(batch1, cur_level); + auto [batch2_value, batch2_bdim] = unwrapTensorAtLevel(batch2, cur_level); + auto results = batch_rule(self_value, self_bdim, batch1_value, batch1_bdim, batch2_value, batch2_bdim, out_dtype, beta, alpha); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor batch_norm_generated_plumbing(const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const ::std::optional & running_mean, const ::std::optional & running_var, bool training, double momentum, double eps, bool cudnn_enabled) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level) && !isBatchedAtLevel(running_mean, cur_level) && !isBatchedAtLevel(running_var, cur_level)) { + return at::_ops::batch_norm::call(input, weight, bias, running_mean, running_var, training, momentum, eps, cudnn_enabled); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + std::optional weight_value; + std::optional weight_bdim; + if (weight) { + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight.value(), cur_level); + } + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + std::optional running_mean_value; + std::optional running_mean_bdim; + if 
(running_mean) { + std::tie(running_mean_value, running_mean_bdim) = unwrapTensorAtLevel(running_mean.value(), cur_level); + } + std::optional running_var_value; + std::optional running_var_bdim; + if (running_var) { + std::tie(running_var_value, running_var_bdim) = unwrapTensorAtLevel(running_var.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, weight_value, weight_bdim, bias_value, bias_bdim, running_mean_value, running_mean_bdim, running_var_value, running_var_bdim, training, momentum, eps, cudnn_enabled); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor quantized_batch_norm_generated_plumbing(const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const at::Tensor & mean, const at::Tensor & var, double eps, double output_scale, int64_t output_zero_point) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level) && !isBatchedAtLevel(mean, cur_level) && !isBatchedAtLevel(var, cur_level)) { + return at::_ops::quantized_batch_norm::call(input, weight, bias, mean, var, eps, output_scale, output_zero_point); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [mean_value, mean_bdim] = unwrapTensorAtLevel(mean, cur_level); + auto [var_value, var_bdim] = unwrapTensorAtLevel(var, cur_level); + std::optional weight_value; + std::optional weight_bdim; + if (weight) { + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight.value(), cur_level); + } + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = 
batch_rule(input_value, input_bdim, weight_value, weight_bdim, bias_value, bias_bdim, mean_value, mean_bdim, var_value, var_bdim, eps, output_scale, output_zero_point); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple _batch_norm_impl_index_backward_generated_plumbing(int64_t impl_index, const at::Tensor & input, const at::Tensor & grad_output, const ::std::optional & weight, const ::std::optional & running_mean, const ::std::optional & running_var, const ::std::optional & save_mean, const ::std::optional & save_var_transform, bool train, double eps, ::std::array output_mask, const at::Tensor & reservedSpace) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(running_mean, cur_level) && !isBatchedAtLevel(running_var, cur_level) && !isBatchedAtLevel(save_mean, cur_level) && !isBatchedAtLevel(save_var_transform, cur_level) && !isBatchedAtLevel(reservedSpace, cur_level)) { + return at::_ops::_batch_norm_impl_index_backward::call(impl_index, input, grad_output, weight, running_mean, running_var, save_mean, save_var_transform, train, eps, output_mask, reservedSpace); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [reservedSpace_value, reservedSpace_bdim] = unwrapTensorAtLevel(reservedSpace, cur_level); + std::optional weight_value; + std::optional weight_bdim; + if (weight) { + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight.value(), cur_level); + } + std::optional running_mean_value; + std::optional running_mean_bdim; + if (running_mean) { + 
std::tie(running_mean_value, running_mean_bdim) = unwrapTensorAtLevel(running_mean.value(), cur_level); + } + std::optional running_var_value; + std::optional running_var_bdim; + if (running_var) { + std::tie(running_var_value, running_var_bdim) = unwrapTensorAtLevel(running_var.value(), cur_level); + } + std::optional save_mean_value; + std::optional save_mean_bdim; + if (save_mean) { + std::tie(save_mean_value, save_mean_bdim) = unwrapTensorAtLevel(save_mean.value(), cur_level); + } + std::optional save_var_transform_value; + std::optional save_var_transform_bdim; + if (save_var_transform) { + std::tie(save_var_transform_value, save_var_transform_bdim) = unwrapTensorAtLevel(save_var_transform.value(), cur_level); + } + auto results = batch_rule(impl_index, input_value, input_bdim, grad_output_value, grad_output_bdim, weight_value, weight_bdim, running_mean_value, running_mean_bdim, running_var_value, running_var_bdim, save_mean_value, save_mean_bdim, save_var_transform_value, save_var_transform_bdim, train, eps, output_mask, reservedSpace_value, reservedSpace_bdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +at::Tensor bernoulli_generated_plumbing(const at::Tensor & self, ::std::optional generator) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::bernoulli::call(self, generator); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, generator); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & 
bernoulli__Tensor_generated_plumbing(at::Tensor & self, const at::Tensor & p, ::std::optional generator) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(p, cur_level)) { + return at::_ops::bernoulli__Tensor::call(self, p, generator); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [p_value, p_bdim] = unwrapTensorAtLevel(p, cur_level); + batch_rule(self_value, self_bdim, p_value, p_bdim, generator); + return self; +} +template +at::Tensor & bernoulli__float_generated_plumbing(at::Tensor & self, double p, ::std::optional generator) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::bernoulli__float::call(self, p, generator); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, p, generator); + return self; +} +template +at::Tensor bernoulli_p_generated_plumbing(const at::Tensor & self, double p, ::std::optional generator) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::bernoulli_p::call(self, p, generator); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, p, generator); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor 
bilinear_generated_plumbing(const at::Tensor & input1, const at::Tensor & input2, const at::Tensor & weight, const ::std::optional & bias) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input1, cur_level) && !isBatchedAtLevel(input2, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::bilinear::call(input1, input2, weight, bias); + } + auto [input1_value, input1_bdim] = unwrapTensorAtLevel(input1, cur_level); + auto [input2_value, input2_bdim] = unwrapTensorAtLevel(input2, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(input1_value, input1_bdim, input2_value, input2_bdim, weight_value, weight_bdim, bias_value, bias_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor binary_cross_entropy_generated_plumbing(const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(target, cur_level) && !isBatchedAtLevel(weight, cur_level)) { + return at::_ops::binary_cross_entropy::call(self, target, weight, reduction); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [target_value, target_bdim] = unwrapTensorAtLevel(target, cur_level); + std::optional weight_value; + std::optional 
weight_bdim; + if (weight) { + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, target_value, target_bdim, weight_value, weight_bdim, reduction); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor binary_cross_entropy_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(target, cur_level) && !isBatchedAtLevel(weight, cur_level)) { + return at::_ops::binary_cross_entropy_backward::call(grad_output, self, target, weight, reduction); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [target_value, target_bdim] = unwrapTensorAtLevel(target, cur_level); + std::optional weight_value; + std::optional weight_bdim; + if (weight) { + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight.value(), cur_level); + } + auto results = batch_rule(grad_output_value, grad_output_bdim, self_value, self_bdim, target_value, target_bdim, weight_value, weight_bdim, reduction); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor binary_cross_entropy_with_logits_generated_plumbing(const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, const ::std::optional & pos_weight, int64_t reduction) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = 
maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(target, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(pos_weight, cur_level)) { + return at::_ops::binary_cross_entropy_with_logits::call(self, target, weight, pos_weight, reduction); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [target_value, target_bdim] = unwrapTensorAtLevel(target, cur_level); + std::optional weight_value; + std::optional weight_bdim; + if (weight) { + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight.value(), cur_level); + } + std::optional pos_weight_value; + std::optional pos_weight_bdim; + if (pos_weight) { + std::tie(pos_weight_value, pos_weight_bdim) = unwrapTensorAtLevel(pos_weight.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, target_value, target_bdim, weight_value, weight_bdim, pos_weight_value, pos_weight_bdim, reduction); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor bincount_generated_plumbing(const at::Tensor & self, const ::std::optional & weights, c10::SymInt minlength) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(weights, cur_level)) { + return at::_ops::bincount::call(self, weights, minlength); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + std::optional weights_value; + std::optional weights_bdim; + if (weights) { + std::tie(weights_value, weights_bdim) = unwrapTensorAtLevel(weights.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, weights_value, weights_bdim, minlength); + return 
makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor bitwise_not_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::bitwise_not::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & bitwise_not__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::bitwise_not_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor copysign_Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::copysign_Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template 
+at::Tensor & copysign__Tensor_generated_plumbing(at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::copysign__Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim); + return self; +} +template +at::Tensor copysign_Scalar_generated_plumbing(const at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::copysign_Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, other); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & copysign__Scalar_generated_plumbing(at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::copysign__Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, other); + return self; +} +template +at::Tensor _lazy_clone_generated_plumbing(const at::Tensor & 
self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_lazy_clone::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor logical_not_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::logical_not::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & logical_not__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::logical_not_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor logical_xor_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if 
(!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::logical_xor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & logical_xor__generated_plumbing(at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::logical_xor_::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim); + return self; +} +template +at::Tensor logical_and_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::logical_and::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & logical_and__generated_plumbing(at::Tensor & self, const at::Tensor 
& other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::logical_and_::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim); + return self; +} +template +at::Tensor logical_or_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::logical_or::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & logical_or__generated_plumbing(at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::logical_or_::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + 
batch_rule(self_value, self_bdim, other_value, other_bdim); + return self; +} +template +at::Tensor bmm_generated_plumbing(const at::Tensor & self, const at::Tensor & mat2) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(mat2, cur_level)) { + return at::_ops::bmm::call(self, mat2); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [mat2_value, mat2_bdim] = unwrapTensorAtLevel(mat2, cur_level); + auto results = batch_rule(self_value, self_bdim, mat2_value, mat2_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor bmm_dtype_generated_plumbing(const at::Tensor & self, const at::Tensor & mat2, at::ScalarType out_dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(mat2, cur_level)) { + return at::_ops::bmm_dtype::call(self, mat2, out_dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [mat2_value, mat2_bdim] = unwrapTensorAtLevel(mat2, cur_level); + auto results = batch_rule(self_value, self_bdim, mat2_value, mat2_bdim, out_dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector broadcast_tensors_generated_plumbing(at::TensorList tensors) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(tensors, cur_level)) { + 
return at::_ops::broadcast_tensors::call(tensors); + } + + auto results = batch_rule(tensors); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor broadcast_to_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef size) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::broadcast_to::call(self, size); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, size); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _sparse_broadcast_to_generated_plumbing(const at::Tensor & self, at::IntArrayRef size) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_sparse_broadcast_to::call(self, size); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, size); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor cat_generated_plumbing(const at::ITensorListRef & tensors, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(tensors, cur_level)) { + return at::_ops::cat::call(tensors, dim); + } + + auto results = batch_rule(tensors, dim); + return makeBatched(std::get<0>(results), 
std::get<1>(results), cur_level); +} +template +at::Tensor cat_names_generated_plumbing(at::TensorList tensors, at::Dimname dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(tensors, cur_level)) { + return at::_ops::cat_names::call(tensors, dim); + } + + auto results = batch_rule(tensors, dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor concat_generated_plumbing(at::TensorList tensors, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(tensors, cur_level)) { + return at::_ops::concat::call(tensors, dim); + } + + auto results = batch_rule(tensors, dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor concat_names_generated_plumbing(at::TensorList tensors, at::Dimname dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(tensors, cur_level)) { + return at::_ops::concat_names::call(tensors, dim); + } + + auto results = batch_rule(tensors, dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor concatenate_generated_plumbing(at::TensorList tensors, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if 
(!isBatchedAtLevel(tensors, cur_level)) { + return at::_ops::concatenate::call(tensors, dim); + } + + auto results = batch_rule(tensors, dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor concatenate_names_generated_plumbing(at::TensorList tensors, at::Dimname dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(tensors, cur_level)) { + return at::_ops::concatenate_names::call(tensors, dim); + } + + auto results = batch_rule(tensors, dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor block_diag_generated_plumbing(at::TensorList tensors) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(tensors, cur_level)) { + return at::_ops::block_diag::call(tensors); + } + + auto results = batch_rule(tensors); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor ceil_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::ceil::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & ceil__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard 
guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::ceil_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor chain_matmul_generated_plumbing(at::TensorList matrices) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(matrices, cur_level)) { + return at::_ops::chain_matmul::call(matrices); + } + + auto results = batch_rule(matrices); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector unsafe_chunk_generated_plumbing(const at::Tensor & self, int64_t chunks, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::unsafe_chunk::call(self, chunks, dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, chunks, dim); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector chunk_generated_plumbing(const at::Tensor & self, int64_t chunks, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return 
at::_ops::chunk::call(self, chunks, dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, chunks, dim); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector tensor_split_sections_generated_plumbing(const at::Tensor & self, c10::SymInt sections, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::tensor_split_sections::call(self, sections, dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, sections, dim); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector tensor_split_indices_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef indices, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::tensor_split_indices::call(self, indices, dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, indices, dim); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector tensor_split_tensor_indices_or_sections_generated_plumbing(const at::Tensor & self, const at::Tensor & tensor_indices_or_sections, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, 
"gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(tensor_indices_or_sections, cur_level)) { + return at::_ops::tensor_split_tensor_indices_or_sections::call(self, tensor_indices_or_sections, dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [tensor_indices_or_sections_value, tensor_indices_or_sections_bdim] = unwrapTensorAtLevel(tensor_indices_or_sections, cur_level); + auto results = batch_rule(self_value, self_bdim, tensor_indices_or_sections_value, tensor_indices_or_sections_bdim, dim); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor clamp_generated_plumbing(const at::Tensor & self, const ::std::optional & min, const ::std::optional & max) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::clamp::call(self, min, max); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, min, max); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor clamp_Tensor_generated_plumbing(const at::Tensor & self, const ::std::optional & min, const ::std::optional & max) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(min, cur_level) && !isBatchedAtLevel(max, cur_level)) { + return at::_ops::clamp_Tensor::call(self, min, max); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + std::optional 
min_value; + std::optional min_bdim; + if (min) { + std::tie(min_value, min_bdim) = unwrapTensorAtLevel(min.value(), cur_level); + } + std::optional max_value; + std::optional max_bdim; + if (max) { + std::tie(max_value, max_bdim) = unwrapTensorAtLevel(max.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, min_value, min_bdim, max_value, max_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & clamp__generated_plumbing(at::Tensor & self, const ::std::optional & min, const ::std::optional & max) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::clamp_::call(self, min, max); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, min, max); + return self; +} +template +at::Tensor & clamp__Tensor_generated_plumbing(at::Tensor & self, const ::std::optional & min, const ::std::optional & max) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(min, cur_level) && !isBatchedAtLevel(max, cur_level)) { + return at::_ops::clamp__Tensor::call(self, min, max); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + std::optional min_value; + std::optional min_bdim; + if (min) { + std::tie(min_value, min_bdim) = unwrapTensorAtLevel(min.value(), cur_level); + } + std::optional max_value; + std::optional max_bdim; + if (max) { + std::tie(max_value, max_bdim) = unwrapTensorAtLevel(max.value(), cur_level); + } + batch_rule(self_value, 
self_bdim, min_value, min_bdim, max_value, max_bdim); + return self; +} +template +at::Tensor clamp_max_generated_plumbing(const at::Tensor & self, const at::Scalar & max) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::clamp_max::call(self, max); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, max); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor clamp_max_Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & max) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(max, cur_level)) { + return at::_ops::clamp_max_Tensor::call(self, max); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [max_value, max_bdim] = unwrapTensorAtLevel(max, cur_level); + auto results = batch_rule(self_value, self_bdim, max_value, max_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & clamp_max__generated_plumbing(at::Tensor & self, const at::Scalar & max) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::clamp_max_::call(self, max); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, 
self_bdim, max); + return self; +} +template +at::Tensor & clamp_max__Tensor_generated_plumbing(at::Tensor & self, const at::Tensor & max) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(max, cur_level)) { + return at::_ops::clamp_max__Tensor::call(self, max); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [max_value, max_bdim] = unwrapTensorAtLevel(max, cur_level); + batch_rule(self_value, self_bdim, max_value, max_bdim); + return self; +} +template +at::Tensor clamp_min_generated_plumbing(const at::Tensor & self, const at::Scalar & min) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::clamp_min::call(self, min); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, min); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor clamp_min_Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & min) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(min, cur_level)) { + return at::_ops::clamp_min_Tensor::call(self, min); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [min_value, min_bdim] = unwrapTensorAtLevel(min, cur_level); + auto results = 
batch_rule(self_value, self_bdim, min_value, min_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & clamp_min__generated_plumbing(at::Tensor & self, const at::Scalar & min) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::clamp_min_::call(self, min); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, min); + return self; +} +template +at::Tensor & clamp_min__Tensor_generated_plumbing(at::Tensor & self, const at::Tensor & min) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(min, cur_level)) { + return at::_ops::clamp_min__Tensor::call(self, min); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [min_value, min_bdim] = unwrapTensorAtLevel(min, cur_level); + batch_rule(self_value, self_bdim, min_value, min_bdim); + return self; +} +template +at::Tensor clip_generated_plumbing(const at::Tensor & self, const ::std::optional & min, const ::std::optional & max) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::clip::call(self, min, max); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, min, max); + return 
makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor clip_Tensor_generated_plumbing(const at::Tensor & self, const ::std::optional & min, const ::std::optional & max) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(min, cur_level) && !isBatchedAtLevel(max, cur_level)) { + return at::_ops::clip_Tensor::call(self, min, max); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + std::optional min_value; + std::optional min_bdim; + if (min) { + std::tie(min_value, min_bdim) = unwrapTensorAtLevel(min.value(), cur_level); + } + std::optional max_value; + std::optional max_bdim; + if (max) { + std::tie(max_value, max_bdim) = unwrapTensorAtLevel(max.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, min_value, min_bdim, max_value, max_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & clip__generated_plumbing(at::Tensor & self, const ::std::optional & min, const ::std::optional & max) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::clip_::call(self, min, max); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, min, max); + return self; +} +template +at::Tensor & clip__Tensor_generated_plumbing(at::Tensor & self, const ::std::optional & min, const ::std::optional & max) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = 
maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(min, cur_level) && !isBatchedAtLevel(max, cur_level)) { + return at::_ops::clip__Tensor::call(self, min, max); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + std::optional min_value; + std::optional min_bdim; + if (min) { + std::tie(min_value, min_bdim) = unwrapTensorAtLevel(min.value(), cur_level); + } + std::optional max_value; + std::optional max_bdim; + if (max) { + std::tie(max_value, max_bdim) = unwrapTensorAtLevel(max.value(), cur_level); + } + batch_rule(self_value, self_bdim, min_value, min_bdim, max_value, max_bdim); + return self; +} +template +at::Tensor complex_generated_plumbing(const at::Tensor & real, const at::Tensor & imag) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(real, cur_level) && !isBatchedAtLevel(imag, cur_level)) { + return at::_ops::complex::call(real, imag); + } + auto [real_value, real_bdim] = unwrapTensorAtLevel(real, cur_level); + auto [imag_value, imag_bdim] = unwrapTensorAtLevel(imag, cur_level); + auto results = batch_rule(real_value, real_bdim, imag_value, imag_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor polar_generated_plumbing(const at::Tensor & abs, const at::Tensor & angle) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(abs, cur_level) && !isBatchedAtLevel(angle, cur_level)) { + return at::_ops::polar::call(abs, angle); + } + auto 
[abs_value, abs_bdim] = unwrapTensorAtLevel(abs, cur_level); + auto [angle_value, angle_bdim] = unwrapTensorAtLevel(angle, cur_level); + auto results = batch_rule(abs_value, abs_bdim, angle_value, angle_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor constant_pad_nd_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef pad, const at::Scalar & value) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::constant_pad_nd::call(self, pad, value); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, pad, value); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor contiguous_generated_plumbing(const at::Tensor & self, at::MemoryFormat memory_format) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::contiguous::call(self, memory_format); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, memory_format); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor convolution_generated_plumbing(const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups) { + c10::impl::ExcludeDispatchKeyGuard 
guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::convolution::call(input, weight, bias, stride, padding, dilation, transposed, output_padding, groups); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, weight_value, weight_bdim, bias_value, bias_bdim, stride, padding, dilation, transposed, output_padding, groups); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple convolution_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & weight, at::OptionalSymIntArrayRef bias_sizes, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups, ::std::array output_mask) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level)) { + return at::_ops::convolution_backward::call(grad_output, input, weight, bias_sizes, stride, padding, dilation, transposed, output_padding, groups, output_mask); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto 
[input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, input_value, input_bdim, weight_value, weight_bdim, bias_sizes, stride, padding, dilation, transposed, output_padding, groups, output_mask); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +at::Tensor convolution_overrideable_generated_plumbing(const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::convolution_overrideable::call(input, weight, bias, stride, padding, dilation, transposed, output_padding, groups); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, weight_value, weight_bdim, bias_value, bias_bdim, stride, padding, dilation, transposed, output_padding, groups); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple 
convolution_backward_overrideable_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & weight, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups, ::std::array output_mask) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level)) { + return at::_ops::convolution_backward_overrideable::call(grad_output, input, weight, stride, padding, dilation, transposed, output_padding, groups, output_mask); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, input_value, input_bdim, weight_value, weight_bdim, stride, padding, dilation, transposed, output_padding, groups, output_mask); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +at::Tensor _convolution_generated_plumbing(const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = 
maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::_convolution::call(input, weight, bias, stride, padding, dilation, transposed, output_padding, groups, benchmark, deterministic, cudnn_enabled, allow_tf32); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, weight_value, weight_bdim, bias_value, bias_bdim, stride, padding, dilation, transposed, output_padding, groups, benchmark, deterministic, cudnn_enabled, allow_tf32); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _convolution_deprecated_generated_plumbing(const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, c10::SymInt groups, bool benchmark, bool deterministic, bool cudnn_enabled) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::_convolution_deprecated::call(input, weight, bias, stride, padding, dilation, transposed, output_padding, groups, benchmark, deterministic, cudnn_enabled); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, 
cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, weight_value, weight_bdim, bias_value, bias_bdim, stride, padding, dilation, transposed, output_padding, groups, benchmark, deterministic, cudnn_enabled); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _convolution_mode_generated_plumbing(const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::string_view padding, c10::SymIntArrayRef dilation, c10::SymInt groups) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::_convolution_mode::call(input, weight, bias, stride, padding, dilation, groups); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, weight_value, weight_bdim, bias_value, bias_bdim, stride, padding, dilation, groups); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple _convolution_double_backward_generated_plumbing(const ::std::optional & ggI, const ::std::optional & ggW, const ::std::optional & ggb, const at::Tensor & gO, const at::Tensor & weight, const at::Tensor & self, 
c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups, ::std::array output_mask) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(ggI, cur_level) && !isBatchedAtLevel(ggW, cur_level) && !isBatchedAtLevel(ggb, cur_level) && !isBatchedAtLevel(gO, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(self, cur_level)) { + return at::_ops::_convolution_double_backward::call(ggI, ggW, ggb, gO, weight, self, stride, padding, dilation, transposed, output_padding, groups, output_mask); + } + auto [gO_value, gO_bdim] = unwrapTensorAtLevel(gO, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + std::optional ggI_value; + std::optional ggI_bdim; + if (ggI) { + std::tie(ggI_value, ggI_bdim) = unwrapTensorAtLevel(ggI.value(), cur_level); + } + std::optional ggW_value; + std::optional ggW_bdim; + if (ggW) { + std::tie(ggW_value, ggW_bdim) = unwrapTensorAtLevel(ggW.value(), cur_level); + } + std::optional ggb_value; + std::optional ggb_bdim; + if (ggb) { + std::tie(ggb_value, ggb_bdim) = unwrapTensorAtLevel(ggb.value(), cur_level); + } + auto results = batch_rule(ggI_value, ggI_bdim, ggW_value, ggW_bdim, ggb_value, ggb_bdim, gO_value, gO_bdim, weight_value, weight_bdim, self_value, self_bdim, stride, padding, dilation, transposed, output_padding, groups, output_mask); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +at::Tensor conv1d_generated_plumbing(const 
at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, c10::SymInt groups) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::conv1d::call(input, weight, bias, stride, padding, dilation, groups); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, weight_value, weight_bdim, bias_value, bias_bdim, stride, padding, dilation, groups); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor conv2d_generated_plumbing(const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, c10::SymInt groups) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::conv2d::call(input, weight, bias, stride, padding, dilation, groups); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + 
std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, weight_value, weight_bdim, bias_value, bias_bdim, stride, padding, dilation, groups); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor conv3d_generated_plumbing(const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, c10::SymInt groups) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::conv3d::call(input, weight, bias, stride, padding, dilation, groups); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, weight_value, weight_bdim, bias_value, bias_bdim, stride, padding, dilation, groups); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor conv1d_padding_generated_plumbing(const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::string_view padding, c10::SymIntArrayRef dilation, c10::SymInt groups) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, 
"gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::conv1d_padding::call(input, weight, bias, stride, padding, dilation, groups); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, weight_value, weight_bdim, bias_value, bias_bdim, stride, padding, dilation, groups); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor conv2d_padding_generated_plumbing(const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::string_view padding, c10::SymIntArrayRef dilation, c10::SymInt groups) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::conv2d_padding::call(input, weight, bias, stride, padding, dilation, groups); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, weight_value, weight_bdim, bias_value, bias_bdim, stride, padding, dilation, groups); + return makeBatched(std::get<0>(results), 
std::get<1>(results), cur_level); +} +template +at::Tensor conv3d_padding_generated_plumbing(const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::string_view padding, c10::SymIntArrayRef dilation, c10::SymInt groups) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::conv3d_padding::call(input, weight, bias, stride, padding, dilation, groups); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, weight_value, weight_bdim, bias_value, bias_bdim, stride, padding, dilation, groups); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor conv_tbc_generated_plumbing(const at::Tensor & self, const at::Tensor & weight, const at::Tensor & bias, int64_t pad) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::conv_tbc::call(self, weight, bias, pad); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + auto [bias_value, bias_bdim] = 
unwrapTensorAtLevel(bias, cur_level); + auto results = batch_rule(self_value, self_bdim, weight_value, weight_bdim, bias_value, bias_bdim, pad); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple conv_tbc_backward_generated_plumbing(const at::Tensor & self, const at::Tensor & input, const at::Tensor & weight, const at::Tensor & bias, int64_t pad) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::conv_tbc_backward::call(self, input, weight, bias, pad); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + auto [bias_value, bias_bdim] = unwrapTensorAtLevel(bias, cur_level); + auto results = batch_rule(self_value, self_bdim, input_value, input_bdim, weight_value, weight_bdim, bias_value, bias_bdim, pad); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +at::Tensor conv_transpose1d_generated_plumbing(const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymInt groups, c10::SymIntArrayRef dilation) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + 
int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::conv_transpose1d::call(input, weight, bias, stride, padding, output_padding, groups, dilation); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, weight_value, weight_bdim, bias_value, bias_bdim, stride, padding, output_padding, groups, dilation); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor conv_transpose2d_input_generated_plumbing(const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymInt groups, c10::SymIntArrayRef dilation) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::conv_transpose2d_input::call(input, weight, bias, stride, padding, output_padding, groups, dilation); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, weight_value, weight_bdim, bias_value, 
bias_bdim, stride, padding, output_padding, groups, dilation); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor conv_transpose3d_input_generated_plumbing(const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymInt groups, c10::SymIntArrayRef dilation) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::conv_transpose3d_input::call(input, weight, bias, stride, padding, output_padding, groups, dilation); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, weight_value, weight_bdim, bias_value, bias_bdim, stride, padding, output_padding, groups, dilation); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor copy_generated_plumbing(const at::Tensor & self, const at::Tensor & src, bool non_blocking) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(src, cur_level)) { + return at::_ops::copy::call(self, src, non_blocking); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, 
cur_level); + auto [src_value, src_bdim] = unwrapTensorAtLevel(src, cur_level); + auto results = batch_rule(self_value, self_bdim, src_value, src_bdim, non_blocking); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & copy__generated_plumbing(at::Tensor & self, const at::Tensor & src, bool non_blocking) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(src, cur_level)) { + return at::_ops::copy_::call(self, src, non_blocking); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [src_value, src_bdim] = unwrapTensorAtLevel(src, cur_level); + batch_rule(self_value, self_bdim, src_value, src_bdim, non_blocking); + return self; +} +template +at::Tensor _copy_from_generated_plumbing(const at::Tensor & self, const at::Tensor & dst, bool non_blocking) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(dst, cur_level)) { + return at::_ops::_copy_from::call(self, dst, non_blocking); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [dst_value, dst_bdim] = unwrapTensorAtLevel(dst, cur_level); + auto results = batch_rule(self_value, self_bdim, dst_value, dst_bdim, non_blocking); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _copy_from_and_resize_generated_plumbing(const at::Tensor & self, const at::Tensor & dst) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = 
maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(dst, cur_level)) { + return at::_ops::_copy_from_and_resize::call(self, dst); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [dst_value, dst_bdim] = unwrapTensorAtLevel(dst, cur_level); + auto results = batch_rule(self_value, self_bdim, dst_value, dst_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor cos_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::cos::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & cos__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::cos_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor cosh_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, 
cur_level)) { + return at::_ops::cosh::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & cosh__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::cosh_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor cosine_embedding_loss_generated_plumbing(const at::Tensor & input1, const at::Tensor & input2, const at::Tensor & target, double margin, int64_t reduction) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input1, cur_level) && !isBatchedAtLevel(input2, cur_level) && !isBatchedAtLevel(target, cur_level)) { + return at::_ops::cosine_embedding_loss::call(input1, input2, target, margin, reduction); + } + auto [input1_value, input1_bdim] = unwrapTensorAtLevel(input1, cur_level); + auto [input2_value, input2_bdim] = unwrapTensorAtLevel(input2, cur_level); + auto [target_value, target_bdim] = unwrapTensorAtLevel(target, cur_level); + auto results = batch_rule(input1_value, input1_bdim, input2_value, input2_bdim, target_value, target_bdim, margin, reduction); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor count_nonzero_dim_IntList_generated_plumbing(const at::Tensor & self, at::IntArrayRef dim) { + c10::impl::ExcludeDispatchKeyGuard 
guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::count_nonzero_dim_IntList::call(self, dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor count_nonzero_generated_plumbing(const at::Tensor & self, ::std::optional dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::count_nonzero::call(self, dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor cov_generated_plumbing(const at::Tensor & self, int64_t correction, const ::std::optional & fweights, const ::std::optional & aweights) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(fweights, cur_level) && !isBatchedAtLevel(aweights, cur_level)) { + return at::_ops::cov::call(self, correction, fweights, aweights); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + std::optional fweights_value; + std::optional fweights_bdim; + if (fweights) { + std::tie(fweights_value, fweights_bdim) = unwrapTensorAtLevel(fweights.value(), cur_level); + } + std::optional 
aweights_value; + std::optional aweights_bdim; + if (aweights) { + std::tie(aweights_value, aweights_bdim) = unwrapTensorAtLevel(aweights.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, correction, fweights_value, fweights_bdim, aweights_value, aweights_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor corrcoef_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::corrcoef::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor cudnn_affine_grid_generator_generated_plumbing(const at::Tensor & theta, int64_t N, int64_t C, int64_t H, int64_t W) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(theta, cur_level)) { + return at::_ops::cudnn_affine_grid_generator::call(theta, N, C, H, W); + } + auto [theta_value, theta_bdim] = unwrapTensorAtLevel(theta, cur_level); + auto results = batch_rule(theta_value, theta_bdim, N, C, H, W); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor cudnn_affine_grid_generator_backward_generated_plumbing(const at::Tensor & grad, int64_t N, int64_t C, int64_t H, int64_t W) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, 
"gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad, cur_level)) { + return at::_ops::cudnn_affine_grid_generator_backward::call(grad, N, C, H, W); + } + auto [grad_value, grad_bdim] = unwrapTensorAtLevel(grad, cur_level); + auto results = batch_rule(grad_value, grad_bdim, N, C, H, W); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple cudnn_batch_norm_generated_plumbing(const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, const ::std::optional & running_mean, const ::std::optional & running_var, bool training, double exponential_average_factor, double epsilon) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level) && !isBatchedAtLevel(running_mean, cur_level) && !isBatchedAtLevel(running_var, cur_level)) { + return at::_ops::cudnn_batch_norm::call(input, weight, bias, running_mean, running_var, training, exponential_average_factor, epsilon); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + std::optional running_mean_value; + std::optional running_mean_bdim; + if (running_mean) { + std::tie(running_mean_value, running_mean_bdim) = unwrapTensorAtLevel(running_mean.value(), cur_level); + } + std::optional running_var_value; + std::optional running_var_bdim; + if (running_var) { + std::tie(running_var_value, running_var_bdim) = unwrapTensorAtLevel(running_var.value(), cur_level); + } + auto 
results = batch_rule(input_value, input_bdim, weight_value, weight_bdim, bias_value, bias_bdim, running_mean_value, running_mean_bdim, running_var_value, running_var_bdim, training, exponential_average_factor, epsilon); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level), makeBatched(std::get<6>(results), std::get<7>(results), cur_level)); +} +template +::std::tuple cudnn_batch_norm_backward_generated_plumbing(const at::Tensor & input, const at::Tensor & grad_output, const at::Tensor & weight, const ::std::optional & running_mean, const ::std::optional & running_var, const ::std::optional & save_mean, const ::std::optional & save_var, double epsilon, const at::Tensor & reserveSpace) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(running_mean, cur_level) && !isBatchedAtLevel(running_var, cur_level) && !isBatchedAtLevel(save_mean, cur_level) && !isBatchedAtLevel(save_var, cur_level) && !isBatchedAtLevel(reserveSpace, cur_level)) { + return at::_ops::cudnn_batch_norm_backward::call(input, grad_output, weight, running_mean, running_var, save_mean, save_var, epsilon, reserveSpace); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + auto [reserveSpace_value, reserveSpace_bdim] = unwrapTensorAtLevel(reserveSpace, cur_level); + std::optional running_mean_value; + std::optional 
running_mean_bdim; + if (running_mean) { + std::tie(running_mean_value, running_mean_bdim) = unwrapTensorAtLevel(running_mean.value(), cur_level); + } + std::optional running_var_value; + std::optional running_var_bdim; + if (running_var) { + std::tie(running_var_value, running_var_bdim) = unwrapTensorAtLevel(running_var.value(), cur_level); + } + std::optional save_mean_value; + std::optional save_mean_bdim; + if (save_mean) { + std::tie(save_mean_value, save_mean_bdim) = unwrapTensorAtLevel(save_mean.value(), cur_level); + } + std::optional save_var_value; + std::optional save_var_bdim; + if (save_var) { + std::tie(save_var_value, save_var_bdim) = unwrapTensorAtLevel(save_var.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, grad_output_value, grad_output_bdim, weight_value, weight_bdim, running_mean_value, running_mean_bdim, running_var_value, running_var_bdim, save_mean_value, save_mean_bdim, save_var_value, save_var_bdim, epsilon, reserveSpace_value, reserveSpace_bdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +at::Tensor cudnn_convolution_generated_plumbing(const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, bool benchmark, bool deterministic, bool allow_tf32) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(weight, cur_level)) { + return at::_ops::cudnn_convolution::call(self, weight, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32); + } + auto [self_value, self_bdim] 
= unwrapTensorAtLevel(self, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + auto results = batch_rule(self_value, self_bdim, weight_value, weight_bdim, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor cudnn_convolution_transpose_generated_plumbing(const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, bool benchmark, bool deterministic, bool allow_tf32) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(weight, cur_level)) { + return at::_ops::cudnn_convolution_transpose::call(self, weight, padding, output_padding, stride, dilation, groups, benchmark, deterministic, allow_tf32); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + auto results = batch_rule(self_value, self_bdim, weight_value, weight_bdim, padding, output_padding, stride, dilation, groups, benchmark, deterministic, allow_tf32); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _mps_convolution_transpose_generated_plumbing(const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level 
= maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(weight, cur_level)) { + return at::_ops::_mps_convolution_transpose::call(self, weight, padding, output_padding, stride, dilation, groups); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + auto results = batch_rule(self_value, self_bdim, weight_value, weight_bdim, padding, output_padding, stride, dilation, groups); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple mps_convolution_transpose_backward_generated_plumbing(const at::Tensor & self, const at::Tensor & grad_output, const at::Tensor & weight, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, ::std::array output_mask) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(weight, cur_level)) { + return at::_ops::mps_convolution_transpose_backward::call(self, grad_output, weight, padding, output_padding, stride, dilation, groups, output_mask); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + auto results = batch_rule(self_value, self_bdim, grad_output_value, grad_output_bdim, weight_value, weight_bdim, padding, output_padding, stride, dilation, groups, output_mask); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), 
std::get<3>(results), cur_level)); +} +template +at::Tensor cudnn_convolution_relu_generated_plumbing(const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, c10::SymInt groups) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::cudnn_convolution_relu::call(self, weight, bias, stride, padding, dilation, groups); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, weight_value, weight_bdim, bias_value, bias_bdim, stride, padding, dilation, groups); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor cudnn_convolution_add_relu_generated_plumbing(const at::Tensor & self, const at::Tensor & weight, const at::Tensor & z, const ::std::optional & alpha, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, c10::SymInt groups) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(z, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return 
at::_ops::cudnn_convolution_add_relu::call(self, weight, z, alpha, bias, stride, padding, dilation, groups); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + auto [z_value, z_bdim] = unwrapTensorAtLevel(z, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, weight_value, weight_bdim, z_value, z_bdim, alpha, bias_value, bias_bdim, stride, padding, dilation, groups); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor cudnn_grid_sampler_generated_plumbing(const at::Tensor & self, const at::Tensor & grid) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(grid, cur_level)) { + return at::_ops::cudnn_grid_sampler::call(self, grid); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [grid_value, grid_bdim] = unwrapTensorAtLevel(grid, cur_level); + auto results = batch_rule(self_value, self_bdim, grid_value, grid_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple cudnn_grid_sampler_backward_generated_plumbing(const at::Tensor & self, const at::Tensor & grid, const at::Tensor & grad_output) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(grid, cur_level) && !isBatchedAtLevel(grad_output, 
cur_level)) { + return at::_ops::cudnn_grid_sampler_backward::call(self, grid, grad_output); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [grid_value, grid_bdim] = unwrapTensorAtLevel(grid, cur_level); + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto results = batch_rule(self_value, self_bdim, grid_value, grid_bdim, grad_output_value, grad_output_bdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple cummax_generated_plumbing(const at::Tensor & self, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::cummax::call(self, dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple cummax_dimname_generated_plumbing(const at::Tensor & self, at::Dimname dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::cummax_dimname::call(self, dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), 
std::get<3>(results), cur_level)); +} +template +void _cummax_helper_generated_plumbing(const at::Tensor & self, at::Tensor & values, at::Tensor & indices, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(values, cur_level) && !isBatchedAtLevel(indices, cur_level)) { + return at::_ops::_cummax_helper::call(self, values, indices, dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [values_value, values_bdim] = unwrapTensorAtLevel(values, cur_level); + auto [indices_value, indices_bdim] = unwrapTensorAtLevel(indices, cur_level); + batch_rule(self_value, self_bdim, values_value, values_bdim, indices_value, indices_bdim, dim); +} +template +::std::tuple cummin_generated_plumbing(const at::Tensor & self, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::cummin::call(self, dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple cummin_dimname_generated_plumbing(const at::Tensor & self, at::Dimname dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, 
cur_level)) { + return at::_ops::cummin_dimname::call(self, dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +void _cummin_helper_generated_plumbing(const at::Tensor & self, at::Tensor & values, at::Tensor & indices, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(values, cur_level) && !isBatchedAtLevel(indices, cur_level)) { + return at::_ops::_cummin_helper::call(self, values, indices, dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [values_value, values_bdim] = unwrapTensorAtLevel(values, cur_level); + auto [indices_value, indices_bdim] = unwrapTensorAtLevel(indices, cur_level); + batch_rule(self_value, self_bdim, values_value, values_bdim, indices_value, indices_bdim, dim); +} +template +at::Tensor cummaxmin_backward_generated_plumbing(const at::Tensor & grad, const at::Tensor & input, const at::Tensor & indices, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad, cur_level) && !isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(indices, cur_level)) { + return at::_ops::cummaxmin_backward::call(grad, input, indices, dim); + } + auto [grad_value, grad_bdim] = unwrapTensorAtLevel(grad, cur_level); + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto 
[indices_value, indices_bdim] = unwrapTensorAtLevel(indices, cur_level); + auto results = batch_rule(grad_value, grad_bdim, input_value, input_bdim, indices_value, indices_bdim, dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor cumprod_generated_plumbing(const at::Tensor & self, int64_t dim, ::std::optional dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::cumprod::call(self, dim, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & cumprod__generated_plumbing(at::Tensor & self, int64_t dim, ::std::optional dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::cumprod_::call(self, dim, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, dim, dtype); + return self; +} +template +at::Tensor cumprod_dimname_generated_plumbing(const at::Tensor & self, at::Dimname dim, ::std::optional dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::cumprod_dimname::call(self, dim, dtype); + } + auto [self_value, self_bdim] = 
unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & cumprod__dimname_generated_plumbing(at::Tensor & self, at::Dimname dim, ::std::optional dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::cumprod__dimname::call(self, dim, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, dim, dtype); + return self; +} +template +at::Tensor cumprod_backward_generated_plumbing(const at::Tensor & grad, const at::Tensor & input, int64_t dim, const at::Tensor & output) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad, cur_level) && !isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(output, cur_level)) { + return at::_ops::cumprod_backward::call(grad, input, dim, output); + } + auto [grad_value, grad_bdim] = unwrapTensorAtLevel(grad, cur_level); + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [output_value, output_bdim] = unwrapTensorAtLevel(output, cur_level); + auto results = batch_rule(grad_value, grad_bdim, input_value, input_bdim, dim, output_value, output_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor cumsum_generated_plumbing(const at::Tensor & self, int64_t dim, ::std::optional dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = 
maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::cumsum::call(self, dim, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & cumsum__generated_plumbing(at::Tensor & self, int64_t dim, ::std::optional dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::cumsum_::call(self, dim, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, dim, dtype); + return self; +} +template +at::Tensor cumsum_dimname_generated_plumbing(const at::Tensor & self, at::Dimname dim, ::std::optional dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::cumsum_dimname::call(self, dim, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & cumsum__dimname_generated_plumbing(at::Tensor & self, at::Dimname dim, ::std::optional dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, 
"gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::cumsum__dimname::call(self, dim, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, dim, dtype); + return self; +} +template +at::Tensor cumulative_trapezoid_x_generated_plumbing(const at::Tensor & y, const at::Tensor & x, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(y, cur_level) && !isBatchedAtLevel(x, cur_level)) { + return at::_ops::cumulative_trapezoid_x::call(y, x, dim); + } + auto [y_value, y_bdim] = unwrapTensorAtLevel(y, cur_level); + auto [x_value, x_bdim] = unwrapTensorAtLevel(x, cur_level); + auto results = batch_rule(y_value, y_bdim, x_value, x_bdim, dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor cumulative_trapezoid_dx_generated_plumbing(const at::Tensor & y, const at::Scalar & dx, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(y, cur_level)) { + return at::_ops::cumulative_trapezoid_dx::call(y, dx, dim); + } + auto [y_value, y_bdim] = unwrapTensorAtLevel(y, cur_level); + auto results = batch_rule(y_value, y_bdim, dx, dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor ctc_loss_IntList_generated_plumbing(const at::Tensor & log_probs, const at::Tensor & targets, at::IntArrayRef input_lengths, at::IntArrayRef target_lengths, int64_t blank, int64_t reduction, bool zero_infinity) { + 
c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(log_probs, cur_level) && !isBatchedAtLevel(targets, cur_level)) { + return at::_ops::ctc_loss_IntList::call(log_probs, targets, input_lengths, target_lengths, blank, reduction, zero_infinity); + } + auto [log_probs_value, log_probs_bdim] = unwrapTensorAtLevel(log_probs, cur_level); + auto [targets_value, targets_bdim] = unwrapTensorAtLevel(targets, cur_level); + auto results = batch_rule(log_probs_value, log_probs_bdim, targets_value, targets_bdim, input_lengths, target_lengths, blank, reduction, zero_infinity); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor ctc_loss_Tensor_generated_plumbing(const at::Tensor & log_probs, const at::Tensor & targets, const at::Tensor & input_lengths, const at::Tensor & target_lengths, int64_t blank, int64_t reduction, bool zero_infinity) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(log_probs, cur_level) && !isBatchedAtLevel(targets, cur_level) && !isBatchedAtLevel(input_lengths, cur_level) && !isBatchedAtLevel(target_lengths, cur_level)) { + return at::_ops::ctc_loss_Tensor::call(log_probs, targets, input_lengths, target_lengths, blank, reduction, zero_infinity); + } + auto [log_probs_value, log_probs_bdim] = unwrapTensorAtLevel(log_probs, cur_level); + auto [targets_value, targets_bdim] = unwrapTensorAtLevel(targets, cur_level); + auto [input_lengths_value, input_lengths_bdim] = unwrapTensorAtLevel(input_lengths, cur_level); + auto [target_lengths_value, target_lengths_bdim] = unwrapTensorAtLevel(target_lengths, cur_level); + auto 
results = batch_rule(log_probs_value, log_probs_bdim, targets_value, targets_bdim, input_lengths_value, input_lengths_bdim, target_lengths_value, target_lengths_bdim, blank, reduction, zero_infinity); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple _ctc_loss_generated_plumbing(const at::Tensor & log_probs, const at::Tensor & targets, at::IntArrayRef input_lengths, at::IntArrayRef target_lengths, int64_t blank, bool zero_infinity) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(log_probs, cur_level) && !isBatchedAtLevel(targets, cur_level)) { + return at::_ops::_ctc_loss::call(log_probs, targets, input_lengths, target_lengths, blank, zero_infinity); + } + auto [log_probs_value, log_probs_bdim] = unwrapTensorAtLevel(log_probs, cur_level); + auto [targets_value, targets_bdim] = unwrapTensorAtLevel(targets, cur_level); + auto results = batch_rule(log_probs_value, log_probs_bdim, targets_value, targets_bdim, input_lengths, target_lengths, blank, zero_infinity); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple _ctc_loss_Tensor_generated_plumbing(const at::Tensor & log_probs, const at::Tensor & targets, const at::Tensor & input_lengths, const at::Tensor & target_lengths, int64_t blank, bool zero_infinity) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(log_probs, cur_level) && !isBatchedAtLevel(targets, cur_level) && !isBatchedAtLevel(input_lengths, cur_level) && 
!isBatchedAtLevel(target_lengths, cur_level)) { + return at::_ops::_ctc_loss_Tensor::call(log_probs, targets, input_lengths, target_lengths, blank, zero_infinity); + } + auto [log_probs_value, log_probs_bdim] = unwrapTensorAtLevel(log_probs, cur_level); + auto [targets_value, targets_bdim] = unwrapTensorAtLevel(targets, cur_level); + auto [input_lengths_value, input_lengths_bdim] = unwrapTensorAtLevel(input_lengths, cur_level); + auto [target_lengths_value, target_lengths_bdim] = unwrapTensorAtLevel(target_lengths, cur_level); + auto results = batch_rule(log_probs_value, log_probs_bdim, targets_value, targets_bdim, input_lengths_value, input_lengths_bdim, target_lengths_value, target_lengths_bdim, blank, zero_infinity); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor _ctc_loss_backward_generated_plumbing(const at::Tensor & grad, const at::Tensor & log_probs, const at::Tensor & targets, at::IntArrayRef input_lengths, at::IntArrayRef target_lengths, const at::Tensor & neg_log_likelihood, const at::Tensor & log_alpha, int64_t blank, bool zero_infinity) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad, cur_level) && !isBatchedAtLevel(log_probs, cur_level) && !isBatchedAtLevel(targets, cur_level) && !isBatchedAtLevel(neg_log_likelihood, cur_level) && !isBatchedAtLevel(log_alpha, cur_level)) { + return at::_ops::_ctc_loss_backward::call(grad, log_probs, targets, input_lengths, target_lengths, neg_log_likelihood, log_alpha, blank, zero_infinity); + } + auto [grad_value, grad_bdim] = unwrapTensorAtLevel(grad, cur_level); + auto [log_probs_value, log_probs_bdim] = unwrapTensorAtLevel(log_probs, cur_level); + auto [targets_value, 
targets_bdim] = unwrapTensorAtLevel(targets, cur_level); + auto [neg_log_likelihood_value, neg_log_likelihood_bdim] = unwrapTensorAtLevel(neg_log_likelihood, cur_level); + auto [log_alpha_value, log_alpha_bdim] = unwrapTensorAtLevel(log_alpha, cur_level); + auto results = batch_rule(grad_value, grad_bdim, log_probs_value, log_probs_bdim, targets_value, targets_bdim, input_lengths, target_lengths, neg_log_likelihood_value, neg_log_likelihood_bdim, log_alpha_value, log_alpha_bdim, blank, zero_infinity); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _ctc_loss_backward_Tensor_generated_plumbing(const at::Tensor & grad, const at::Tensor & log_probs, const at::Tensor & targets, const at::Tensor & input_lengths, const at::Tensor & target_lengths, const at::Tensor & neg_log_likelihood, const at::Tensor & log_alpha, int64_t blank, bool zero_infinity) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad, cur_level) && !isBatchedAtLevel(log_probs, cur_level) && !isBatchedAtLevel(targets, cur_level) && !isBatchedAtLevel(input_lengths, cur_level) && !isBatchedAtLevel(target_lengths, cur_level) && !isBatchedAtLevel(neg_log_likelihood, cur_level) && !isBatchedAtLevel(log_alpha, cur_level)) { + return at::_ops::_ctc_loss_backward_Tensor::call(grad, log_probs, targets, input_lengths, target_lengths, neg_log_likelihood, log_alpha, blank, zero_infinity); + } + auto [grad_value, grad_bdim] = unwrapTensorAtLevel(grad, cur_level); + auto [log_probs_value, log_probs_bdim] = unwrapTensorAtLevel(log_probs, cur_level); + auto [targets_value, targets_bdim] = unwrapTensorAtLevel(targets, cur_level); + auto [input_lengths_value, input_lengths_bdim] = unwrapTensorAtLevel(input_lengths, cur_level); + auto [target_lengths_value, 
target_lengths_bdim] = unwrapTensorAtLevel(target_lengths, cur_level); + auto [neg_log_likelihood_value, neg_log_likelihood_bdim] = unwrapTensorAtLevel(neg_log_likelihood, cur_level); + auto [log_alpha_value, log_alpha_bdim] = unwrapTensorAtLevel(log_alpha, cur_level); + auto results = batch_rule(grad_value, grad_bdim, log_probs_value, log_probs_bdim, targets_value, targets_bdim, input_lengths_value, input_lengths_bdim, target_lengths_value, target_lengths_bdim, neg_log_likelihood_value, neg_log_likelihood_bdim, log_alpha_value, log_alpha_bdim, blank, zero_infinity); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor diag_embed_generated_plumbing(const at::Tensor & self, int64_t offset, int64_t dim1, int64_t dim2) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::diag_embed::call(self, offset, dim1, dim2); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, offset, dim1, dim2); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor diagflat_generated_plumbing(const at::Tensor & self, int64_t offset) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::diagflat::call(self, offset); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, offset); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor 
diagonal_generated_plumbing(const at::Tensor & self, int64_t offset, int64_t dim1, int64_t dim2) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::diagonal::call(self, offset, dim1, dim2); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, offset, dim1, dim2); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor linalg_diagonal_generated_plumbing(const at::Tensor & A, int64_t offset, int64_t dim1, int64_t dim2) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(A, cur_level)) { + return at::_ops::linalg_diagonal::call(A, offset, dim1, dim2); + } + auto [A_value, A_bdim] = unwrapTensorAtLevel(A, cur_level); + auto results = batch_rule(A_value, A_bdim, offset, dim1, dim2); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor diagonal_Dimname_generated_plumbing(const at::Tensor & self, at::Dimname outdim, at::Dimname dim1, at::Dimname dim2, int64_t offset) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::diagonal_Dimname::call(self, outdim, dim1, dim2, offset); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, outdim, dim1, dim2, offset); + return 
makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor diagonal_backward_generated_plumbing(const at::Tensor & grad_output, c10::SymIntArrayRef input_sizes, int64_t offset, int64_t dim1, int64_t dim2) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level)) { + return at::_ops::diagonal_backward::call(grad_output, input_sizes, offset, dim1, dim2); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, input_sizes, offset, dim1, dim2); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & fill_diagonal__generated_plumbing(at::Tensor & self, const at::Scalar & fill_value, bool wrap) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::fill_diagonal_::call(self, fill_value, wrap); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, fill_value, wrap); + return self; +} +template +at::Tensor diff_generated_plumbing(const at::Tensor & self, int64_t n, int64_t dim, const ::std::optional & prepend, const ::std::optional & append) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(prepend, cur_level) && !isBatchedAtLevel(append, 
cur_level)) { + return at::_ops::diff::call(self, n, dim, prepend, append); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + std::optional prepend_value; + std::optional prepend_bdim; + if (prepend) { + std::tie(prepend_value, prepend_bdim) = unwrapTensorAtLevel(prepend.value(), cur_level); + } + std::optional append_value; + std::optional append_bdim; + if (append) { + std::tie(append_value, append_bdim) = unwrapTensorAtLevel(append.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, n, dim, prepend_value, prepend_bdim, append_value, append_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector gradient_scalarint_generated_plumbing(const at::Tensor & self, const ::std::optional & spacing, ::std::optional dim, int64_t edge_order) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::gradient_scalarint::call(self, spacing, dim, edge_order); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, spacing, dim, edge_order); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector gradient_scalararray_generated_plumbing(const at::Tensor & self, const at::Scalar & spacing, at::IntArrayRef dim, int64_t edge_order) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::gradient_scalararray::call(self, spacing, dim, edge_order); + } + auto [self_value, self_bdim] = 
unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, spacing, dim, edge_order); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector gradient_array_generated_plumbing(const at::Tensor & self, at::IntArrayRef dim, int64_t edge_order) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::gradient_array::call(self, dim, edge_order); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, edge_order); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector gradient_scalarrayint_generated_plumbing(const at::Tensor & self, at::ArrayRef spacing, ::std::optional dim, int64_t edge_order) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::gradient_scalarrayint::call(self, spacing, dim, edge_order); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, spacing, dim, edge_order); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector gradient_scalarrayarray_generated_plumbing(const at::Tensor & self, at::ArrayRef spacing, at::IntArrayRef dim, int64_t edge_order) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = 
maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::gradient_scalarrayarray::call(self, spacing, dim, edge_order); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, spacing, dim, edge_order); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector gradient_tensorarrayint_generated_plumbing(const at::Tensor & self, at::TensorList spacing, ::std::optional dim, int64_t edge_order) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(spacing, cur_level)) { + return at::_ops::gradient_tensorarrayint::call(self, spacing, dim, edge_order); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, spacing, dim, edge_order); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector gradient_tensorarray_generated_plumbing(const at::Tensor & self, at::TensorList spacing, at::IntArrayRef dim, int64_t edge_order) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(spacing, cur_level)) { + return at::_ops::gradient_tensorarray::call(self, spacing, dim, edge_order); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, spacing, dim, edge_order); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor 
div_Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::div_Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & div__Tensor_generated_plumbing(at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::div__Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim); + return self; +} +template +at::Tensor div_Tensor_mode_generated_plumbing(const at::Tensor & self, const at::Tensor & other, ::std::optional rounding_mode) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::div_Tensor_mode::call(self, other, rounding_mode); + } + auto [self_value, 
self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim, rounding_mode); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & div__Tensor_mode_generated_plumbing(at::Tensor & self, const at::Tensor & other, ::std::optional rounding_mode) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::div__Tensor_mode::call(self, other, rounding_mode); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim, rounding_mode); + return self; +} +template +at::Tensor div_Scalar_generated_plumbing(const at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::div_Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, other); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & div__Scalar_generated_plumbing(at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + 
int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::div__Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, other); + return self; +} +template +at::Tensor div_Scalar_mode_generated_plumbing(const at::Tensor & self, const at::Scalar & other, ::std::optional rounding_mode) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::div_Scalar_mode::call(self, other, rounding_mode); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, other, rounding_mode); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & div__Scalar_mode_generated_plumbing(at::Tensor & self, const at::Scalar & other, ::std::optional rounding_mode) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::div__Scalar_mode::call(self, other, rounding_mode); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, other, rounding_mode); + return self; +} +template +at::Tensor divide_Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) 
&& !isBatchedAtLevel(other, cur_level)) { + return at::_ops::divide_Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & divide__Tensor_generated_plumbing(at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::divide__Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim); + return self; +} +template +at::Tensor divide_Scalar_generated_plumbing(const at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::divide_Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, other); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & divide__Scalar_generated_plumbing(at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + 
vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::divide__Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, other); + return self; +} +template +at::Tensor divide_Tensor_mode_generated_plumbing(const at::Tensor & self, const at::Tensor & other, ::std::optional rounding_mode) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::divide_Tensor_mode::call(self, other, rounding_mode); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim, rounding_mode); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & divide__Tensor_mode_generated_plumbing(at::Tensor & self, const at::Tensor & other, ::std::optional rounding_mode) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::divide__Tensor_mode::call(self, other, rounding_mode); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim, rounding_mode); + return self; +} +template +at::Tensor 
divide_Scalar_mode_generated_plumbing(const at::Tensor & self, const at::Scalar & other, ::std::optional rounding_mode) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::divide_Scalar_mode::call(self, other, rounding_mode); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, other, rounding_mode); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & divide__Scalar_mode_generated_plumbing(at::Tensor & self, const at::Scalar & other, ::std::optional rounding_mode) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::divide__Scalar_mode::call(self, other, rounding_mode); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, other, rounding_mode); + return self; +} +template +at::Tensor true_divide_Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::true_divide_Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = 
batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & true_divide__Tensor_generated_plumbing(at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::true_divide__Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim); + return self; +} +template +at::Tensor true_divide_Scalar_generated_plumbing(const at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::true_divide_Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, other); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & true_divide__Scalar_generated_plumbing(at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::true_divide__Scalar::call(self, other); + } + auto [self_value, self_bdim] = 
unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, other); + return self; +} +template +at::Tensor dot_generated_plumbing(const at::Tensor & self, const at::Tensor & tensor) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(tensor, cur_level)) { + return at::_ops::dot::call(self, tensor); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [tensor_value, tensor_bdim] = unwrapTensorAtLevel(tensor, cur_level); + auto results = batch_rule(self_value, self_bdim, tensor_value, tensor_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor vdot_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::vdot::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor einsum_generated_plumbing(c10::string_view equation, at::TensorList tensors, at::OptionalIntArrayRef path) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if 
(!isBatchedAtLevel(tensors, cur_level)) { + return at::_ops::einsum::call(equation, tensors, path); + } + + auto results = batch_rule(equation, tensors, path); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor embedding_generated_plumbing(const at::Tensor & weight, const at::Tensor & indices, c10::SymInt padding_idx, bool scale_grad_by_freq, bool sparse) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(indices, cur_level)) { + return at::_ops::embedding::call(weight, indices, padding_idx, scale_grad_by_freq, sparse); + } + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + auto [indices_value, indices_bdim] = unwrapTensorAtLevel(indices, cur_level); + auto results = batch_rule(weight_value, weight_bdim, indices_value, indices_bdim, padding_idx, scale_grad_by_freq, sparse); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor embedding_backward_generated_plumbing(const at::Tensor & grad, const at::Tensor & indices, c10::SymInt num_weights, c10::SymInt padding_idx, bool scale_grad_by_freq, bool sparse) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad, cur_level) && !isBatchedAtLevel(indices, cur_level)) { + return at::_ops::embedding_backward::call(grad, indices, num_weights, padding_idx, scale_grad_by_freq, sparse); + } + auto [grad_value, grad_bdim] = unwrapTensorAtLevel(grad, cur_level); + auto [indices_value, indices_bdim] = unwrapTensorAtLevel(indices, cur_level); + auto results = 
batch_rule(grad_value, grad_bdim, indices_value, indices_bdim, num_weights, padding_idx, scale_grad_by_freq, sparse); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor embedding_dense_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & indices, c10::SymInt num_weights, c10::SymInt padding_idx, bool scale_grad_by_freq) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(indices, cur_level)) { + return at::_ops::embedding_dense_backward::call(grad_output, indices, num_weights, padding_idx, scale_grad_by_freq); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [indices_value, indices_bdim] = unwrapTensorAtLevel(indices, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, indices_value, indices_bdim, num_weights, padding_idx, scale_grad_by_freq); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & embedding_renorm__generated_plumbing(at::Tensor & self, const at::Tensor & indices, double max_norm, double norm_type) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(indices, cur_level)) { + return at::_ops::embedding_renorm_::call(self, indices, max_norm, norm_type); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [indices_value, indices_bdim] = unwrapTensorAtLevel(indices, cur_level); + batch_rule(self_value, self_bdim, indices_value, indices_bdim, 
max_norm, norm_type); + return self; +} +template +at::Tensor embedding_sparse_backward_generated_plumbing(const at::Tensor & grad, const at::Tensor & indices, int64_t num_weights, int64_t padding_idx, bool scale_grad_by_freq) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad, cur_level) && !isBatchedAtLevel(indices, cur_level)) { + return at::_ops::embedding_sparse_backward::call(grad, indices, num_weights, padding_idx, scale_grad_by_freq); + } + auto [grad_value, grad_bdim] = unwrapTensorAtLevel(grad, cur_level); + auto [indices_value, indices_bdim] = unwrapTensorAtLevel(indices, cur_level); + auto results = batch_rule(grad_value, grad_bdim, indices_value, indices_bdim, num_weights, padding_idx, scale_grad_by_freq); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple _embedding_bag_forward_only_generated_plumbing(const at::Tensor & weight, const at::Tensor & indices, const at::Tensor & offsets, bool scale_grad_by_freq, int64_t mode, bool sparse, const ::std::optional & per_sample_weights, bool include_last_offset, int64_t padding_idx) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(indices, cur_level) && !isBatchedAtLevel(offsets, cur_level) && !isBatchedAtLevel(per_sample_weights, cur_level)) { + return at::_ops::_embedding_bag_forward_only::call(weight, indices, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights, include_last_offset, padding_idx); + } + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + auto [indices_value, 
indices_bdim] = unwrapTensorAtLevel(indices, cur_level); + auto [offsets_value, offsets_bdim] = unwrapTensorAtLevel(offsets, cur_level); + std::optional per_sample_weights_value; + std::optional per_sample_weights_bdim; + if (per_sample_weights) { + std::tie(per_sample_weights_value, per_sample_weights_bdim) = unwrapTensorAtLevel(per_sample_weights.value(), cur_level); + } + auto results = batch_rule(weight_value, weight_bdim, indices_value, indices_bdim, offsets_value, offsets_bdim, scale_grad_by_freq, mode, sparse, per_sample_weights_value, per_sample_weights_bdim, include_last_offset, padding_idx); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level), makeBatched(std::get<6>(results), std::get<7>(results), cur_level)); +} +template +::std::tuple _rowwise_prune_generated_plumbing(const at::Tensor & weight, const at::Tensor & mask, at::ScalarType compressed_indices_dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(mask, cur_level)) { + return at::_ops::_rowwise_prune::call(weight, mask, compressed_indices_dtype); + } + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + auto [mask_value, mask_bdim] = unwrapTensorAtLevel(mask, cur_level); + auto results = batch_rule(weight_value, weight_bdim, mask_value, mask_bdim, compressed_indices_dtype); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor row_stack_generated_plumbing(at::TensorList tensors) { + c10::impl::ExcludeDispatchKeyGuard 
guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(tensors, cur_level)) { + return at::_ops::row_stack::call(tensors); + } + + auto results = batch_rule(tensors); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple embedding_bag_generated_plumbing(const at::Tensor & weight, const at::Tensor & indices, const at::Tensor & offsets, bool scale_grad_by_freq, int64_t mode, bool sparse, const ::std::optional & per_sample_weights, bool include_last_offset) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(indices, cur_level) && !isBatchedAtLevel(offsets, cur_level) && !isBatchedAtLevel(per_sample_weights, cur_level)) { + return at::_ops::embedding_bag::call(weight, indices, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights, include_last_offset); + } + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + auto [indices_value, indices_bdim] = unwrapTensorAtLevel(indices, cur_level); + auto [offsets_value, offsets_bdim] = unwrapTensorAtLevel(offsets, cur_level); + std::optional per_sample_weights_value; + std::optional per_sample_weights_bdim; + if (per_sample_weights) { + std::tie(per_sample_weights_value, per_sample_weights_bdim) = unwrapTensorAtLevel(per_sample_weights.value(), cur_level); + } + auto results = batch_rule(weight_value, weight_bdim, indices_value, indices_bdim, offsets_value, offsets_bdim, scale_grad_by_freq, mode, sparse, per_sample_weights_value, per_sample_weights_bdim, include_last_offset); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), 
cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level), makeBatched(std::get<6>(results), std::get<7>(results), cur_level)); +} +template +::std::tuple embedding_bag_padding_idx_generated_plumbing(const at::Tensor & weight, const at::Tensor & indices, const at::Tensor & offsets, bool scale_grad_by_freq, int64_t mode, bool sparse, const ::std::optional & per_sample_weights, bool include_last_offset, ::std::optional padding_idx) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(indices, cur_level) && !isBatchedAtLevel(offsets, cur_level) && !isBatchedAtLevel(per_sample_weights, cur_level)) { + return at::_ops::embedding_bag_padding_idx::call(weight, indices, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights, include_last_offset, padding_idx); + } + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + auto [indices_value, indices_bdim] = unwrapTensorAtLevel(indices, cur_level); + auto [offsets_value, offsets_bdim] = unwrapTensorAtLevel(offsets, cur_level); + std::optional per_sample_weights_value; + std::optional per_sample_weights_bdim; + if (per_sample_weights) { + std::tie(per_sample_weights_value, per_sample_weights_bdim) = unwrapTensorAtLevel(per_sample_weights.value(), cur_level); + } + auto results = batch_rule(weight_value, weight_bdim, indices_value, indices_bdim, offsets_value, offsets_bdim, scale_grad_by_freq, mode, sparse, per_sample_weights_value, per_sample_weights_bdim, include_last_offset, padding_idx); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), 
makeBatched(std::get<4>(results), std::get<5>(results), cur_level), makeBatched(std::get<6>(results), std::get<7>(results), cur_level)); +} +template +::std::tuple _embedding_bag_generated_plumbing(const at::Tensor & weight, const at::Tensor & indices, const at::Tensor & offsets, bool scale_grad_by_freq, int64_t mode, bool sparse, const ::std::optional & per_sample_weights, bool include_last_offset, int64_t padding_idx) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(indices, cur_level) && !isBatchedAtLevel(offsets, cur_level) && !isBatchedAtLevel(per_sample_weights, cur_level)) { + return at::_ops::_embedding_bag::call(weight, indices, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights, include_last_offset, padding_idx); + } + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + auto [indices_value, indices_bdim] = unwrapTensorAtLevel(indices, cur_level); + auto [offsets_value, offsets_bdim] = unwrapTensorAtLevel(offsets, cur_level); + std::optional per_sample_weights_value; + std::optional per_sample_weights_bdim; + if (per_sample_weights) { + std::tie(per_sample_weights_value, per_sample_weights_bdim) = unwrapTensorAtLevel(per_sample_weights.value(), cur_level); + } + auto results = batch_rule(weight_value, weight_bdim, indices_value, indices_bdim, offsets_value, offsets_bdim, scale_grad_by_freq, mode, sparse, per_sample_weights_value, per_sample_weights_bdim, include_last_offset, padding_idx); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level), makeBatched(std::get<6>(results), std::get<7>(results), cur_level)); 
+} +template +at::Tensor _embedding_bag_backward_generated_plumbing(const at::Tensor & grad, const at::Tensor & indices, const at::Tensor & offsets, const at::Tensor & offset2bag, const at::Tensor & bag_size, const at::Tensor & maximum_indices, c10::SymInt num_weights, bool scale_grad_by_freq, int64_t mode, bool sparse, const ::std::optional & per_sample_weights, int64_t padding_idx) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad, cur_level) && !isBatchedAtLevel(indices, cur_level) && !isBatchedAtLevel(offsets, cur_level) && !isBatchedAtLevel(offset2bag, cur_level) && !isBatchedAtLevel(bag_size, cur_level) && !isBatchedAtLevel(maximum_indices, cur_level) && !isBatchedAtLevel(per_sample_weights, cur_level)) { + return at::_ops::_embedding_bag_backward::call(grad, indices, offsets, offset2bag, bag_size, maximum_indices, num_weights, scale_grad_by_freq, mode, sparse, per_sample_weights, padding_idx); + } + auto [grad_value, grad_bdim] = unwrapTensorAtLevel(grad, cur_level); + auto [indices_value, indices_bdim] = unwrapTensorAtLevel(indices, cur_level); + auto [offsets_value, offsets_bdim] = unwrapTensorAtLevel(offsets, cur_level); + auto [offset2bag_value, offset2bag_bdim] = unwrapTensorAtLevel(offset2bag, cur_level); + auto [bag_size_value, bag_size_bdim] = unwrapTensorAtLevel(bag_size, cur_level); + auto [maximum_indices_value, maximum_indices_bdim] = unwrapTensorAtLevel(maximum_indices, cur_level); + std::optional per_sample_weights_value; + std::optional per_sample_weights_bdim; + if (per_sample_weights) { + std::tie(per_sample_weights_value, per_sample_weights_bdim) = unwrapTensorAtLevel(per_sample_weights.value(), cur_level); + } + auto results = batch_rule(grad_value, grad_bdim, indices_value, indices_bdim, offsets_value, offsets_bdim, offset2bag_value, 
offset2bag_bdim, bag_size_value, bag_size_bdim, maximum_indices_value, maximum_indices_bdim, num_weights, scale_grad_by_freq, mode, sparse, per_sample_weights_value, per_sample_weights_bdim, padding_idx); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _embedding_bag_sparse_backward_generated_plumbing(const at::Tensor & grad, const at::Tensor & indices, const at::Tensor & offsets, const at::Tensor & offset2bag, const at::Tensor & bag_size, c10::SymInt num_weights, bool scale_grad_by_freq, int64_t mode, const ::std::optional & per_sample_weights, int64_t padding_idx) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad, cur_level) && !isBatchedAtLevel(indices, cur_level) && !isBatchedAtLevel(offsets, cur_level) && !isBatchedAtLevel(offset2bag, cur_level) && !isBatchedAtLevel(bag_size, cur_level) && !isBatchedAtLevel(per_sample_weights, cur_level)) { + return at::_ops::_embedding_bag_sparse_backward::call(grad, indices, offsets, offset2bag, bag_size, num_weights, scale_grad_by_freq, mode, per_sample_weights, padding_idx); + } + auto [grad_value, grad_bdim] = unwrapTensorAtLevel(grad, cur_level); + auto [indices_value, indices_bdim] = unwrapTensorAtLevel(indices, cur_level); + auto [offsets_value, offsets_bdim] = unwrapTensorAtLevel(offsets, cur_level); + auto [offset2bag_value, offset2bag_bdim] = unwrapTensorAtLevel(offset2bag, cur_level); + auto [bag_size_value, bag_size_bdim] = unwrapTensorAtLevel(bag_size, cur_level); + std::optional per_sample_weights_value; + std::optional per_sample_weights_bdim; + if (per_sample_weights) { + std::tie(per_sample_weights_value, per_sample_weights_bdim) = unwrapTensorAtLevel(per_sample_weights.value(), cur_level); + } + auto results = batch_rule(grad_value, grad_bdim, 
indices_value, indices_bdim, offsets_value, offsets_bdim, offset2bag_value, offset2bag_bdim, bag_size_value, bag_size_bdim, num_weights, scale_grad_by_freq, mode, per_sample_weights_value, per_sample_weights_bdim, padding_idx); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _embedding_bag_dense_backward_generated_plumbing(const at::Tensor & grad, const at::Tensor & indices, const at::Tensor & offset2bag, const at::Tensor & bag_size, const at::Tensor & maximum_indices, c10::SymInt num_weights, bool scale_grad_by_freq, int64_t mode, const ::std::optional & per_sample_weights, int64_t padding_idx) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad, cur_level) && !isBatchedAtLevel(indices, cur_level) && !isBatchedAtLevel(offset2bag, cur_level) && !isBatchedAtLevel(bag_size, cur_level) && !isBatchedAtLevel(maximum_indices, cur_level) && !isBatchedAtLevel(per_sample_weights, cur_level)) { + return at::_ops::_embedding_bag_dense_backward::call(grad, indices, offset2bag, bag_size, maximum_indices, num_weights, scale_grad_by_freq, mode, per_sample_weights, padding_idx); + } + auto [grad_value, grad_bdim] = unwrapTensorAtLevel(grad, cur_level); + auto [indices_value, indices_bdim] = unwrapTensorAtLevel(indices, cur_level); + auto [offset2bag_value, offset2bag_bdim] = unwrapTensorAtLevel(offset2bag, cur_level); + auto [bag_size_value, bag_size_bdim] = unwrapTensorAtLevel(bag_size, cur_level); + auto [maximum_indices_value, maximum_indices_bdim] = unwrapTensorAtLevel(maximum_indices, cur_level); + std::optional per_sample_weights_value; + std::optional per_sample_weights_bdim; + if (per_sample_weights) { + std::tie(per_sample_weights_value, per_sample_weights_bdim) = unwrapTensorAtLevel(per_sample_weights.value(), 
cur_level); + } + auto results = batch_rule(grad_value, grad_bdim, indices_value, indices_bdim, offset2bag_value, offset2bag_bdim, bag_size_value, bag_size_bdim, maximum_indices_value, maximum_indices_bdim, num_weights, scale_grad_by_freq, mode, per_sample_weights_value, per_sample_weights_bdim, padding_idx); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _embedding_bag_per_sample_weights_backward_generated_plumbing(const at::Tensor & grad, const at::Tensor & weight, const at::Tensor & indices, const at::Tensor & offsets, const at::Tensor & offset2bag, int64_t mode, int64_t padding_idx) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(indices, cur_level) && !isBatchedAtLevel(offsets, cur_level) && !isBatchedAtLevel(offset2bag, cur_level)) { + return at::_ops::_embedding_bag_per_sample_weights_backward::call(grad, weight, indices, offsets, offset2bag, mode, padding_idx); + } + auto [grad_value, grad_bdim] = unwrapTensorAtLevel(grad, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + auto [indices_value, indices_bdim] = unwrapTensorAtLevel(indices, cur_level); + auto [offsets_value, offsets_bdim] = unwrapTensorAtLevel(offsets, cur_level); + auto [offset2bag_value, offset2bag_bdim] = unwrapTensorAtLevel(offset2bag, cur_level); + auto results = batch_rule(grad_value, grad_bdim, weight_value, weight_bdim, indices_value, indices_bdim, offsets_value, offsets_bdim, offset2bag_value, offset2bag_bdim, mode, padding_idx); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor new_empty_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef size, 
::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::new_empty::call(self, size, dtype, layout, device, pin_memory); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, size, dtype, layout, device, pin_memory); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor new_empty_strided_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::new_empty_strided::call(self, size, stride, dtype, layout, device, pin_memory); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, size, stride, dtype, layout, device, pin_memory); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor new_full_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef size, const at::Scalar & fill_value, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level 
= maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::new_full::call(self, size, fill_value, dtype, layout, device, pin_memory); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, size, fill_value, dtype, layout, device, pin_memory); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor new_zeros_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::new_zeros::call(self, size, dtype, layout, device, pin_memory); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, size, dtype, layout, device, pin_memory); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor new_ones_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::new_ones::call(self, size, dtype, layout, device, pin_memory); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, size, dtype, layout, device, pin_memory); + return makeBatched(std::get<0>(results), 
std::get<1>(results), cur_level); +} +template +at::Tensor _empty_per_channel_affine_quantized_generated_plumbing(c10::SymIntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(scales, cur_level) && !isBatchedAtLevel(zero_points, cur_level)) { + return at::_ops::_empty_per_channel_affine_quantized::call(size, scales, zero_points, axis, dtype, layout, device, pin_memory, memory_format); + } + auto [scales_value, scales_bdim] = unwrapTensorAtLevel(scales, cur_level); + auto [zero_points_value, zero_points_bdim] = unwrapTensorAtLevel(zero_points, cur_level); + auto results = batch_rule(size, scales_value, scales_bdim, zero_points_value, zero_points_bdim, axis, dtype, layout, device, pin_memory, memory_format); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +const at::Tensor & _resize_output__generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef size, at::Device device) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_resize_output_::call(self, size, device); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, size, device); + return self; +} +template +at::Tensor empty_quantized_generated_plumbing(at::IntArrayRef size, const at::Tensor & qtensor, ::std::optional dtype, ::std::optional layout, 
::std::optional device, ::std::optional pin_memory, ::std::optional memory_format) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(qtensor, cur_level)) { + return at::_ops::empty_quantized::call(size, qtensor, dtype, layout, device, pin_memory, memory_format); + } + auto [qtensor_value, qtensor_bdim] = unwrapTensorAtLevel(qtensor, cur_level); + auto results = batch_rule(size, qtensor_value, qtensor_bdim, dtype, layout, device, pin_memory, memory_format); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor empty_like_generated_plumbing(const at::Tensor & self, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::empty_like::call(self, dtype, layout, device, pin_memory, memory_format); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dtype, layout, device, pin_memory, memory_format); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor erf_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::erf::call(self); + } + auto [self_value, self_bdim] = 
unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & erf__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::erf_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor erfc_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::erfc::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & erfc__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::erfc_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor exp_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + 
vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::exp::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & exp__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::exp_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor exp2_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::exp2::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & exp2__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::exp2_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return 
self; +} +template +at::Tensor expm1_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::expm1::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & expm1__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::expm1_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor expand_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef size, bool implicit) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::expand::call(self, size, implicit); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, size, implicit); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor expand_as_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = 
maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::expand_as::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor flatten_using_ints_generated_plumbing(const at::Tensor & self, int64_t start_dim, int64_t end_dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::flatten_using_ints::call(self, start_dim, end_dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, start_dim, end_dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor flatten_named_out_dim_generated_plumbing(const at::Tensor & self, int64_t start_dim, int64_t end_dim, at::Dimname out_dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::flatten_named_out_dim::call(self, start_dim, end_dim, out_dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, start_dim, end_dim, out_dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} 
+template +at::Tensor flatten_using_names_generated_plumbing(const at::Tensor & self, at::Dimname start_dim, at::Dimname end_dim, at::Dimname out_dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::flatten_using_names::call(self, start_dim, end_dim, out_dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, start_dim, end_dim, out_dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor flatten_DimnameList_generated_plumbing(const at::Tensor & self, at::DimnameList dims, at::Dimname out_dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::flatten_DimnameList::call(self, dims, out_dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dims, out_dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor unflatten_int_generated_plumbing(const at::Tensor & self, int64_t dim, c10::SymIntArrayRef sizes) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::unflatten_int::call(self, dim, sizes); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, 
self_bdim, dim, sizes); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor unflatten_Dimname_generated_plumbing(const at::Tensor & self, at::Dimname dim, c10::SymIntArrayRef sizes, at::DimnameList names) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::unflatten_Dimname::call(self, dim, sizes, names); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, sizes, names); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor fill_Scalar_generated_plumbing(const at::Tensor & self, const at::Scalar & value) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::fill_Scalar::call(self, value); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, value); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor fill_Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & value) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(value, cur_level)) { + return at::_ops::fill_Tensor::call(self, value); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, 
cur_level); + auto [value_value, value_bdim] = unwrapTensorAtLevel(value, cur_level); + auto results = batch_rule(self_value, self_bdim, value_value, value_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & fill__Scalar_generated_plumbing(at::Tensor & self, const at::Scalar & value) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::fill__Scalar::call(self, value); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, value); + return self; +} +template +at::Tensor & fill__Tensor_generated_plumbing(at::Tensor & self, const at::Tensor & value) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(value, cur_level)) { + return at::_ops::fill__Tensor::call(self, value); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [value_value, value_bdim] = unwrapTensorAtLevel(value, cur_level); + batch_rule(self_value, self_bdim, value_value, value_bdim); + return self; +} +template +at::Tensor floor_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::floor::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = 
batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & floor__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::floor_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor floor_divide_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::floor_divide::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & floor_divide__Tensor_generated_plumbing(at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::floor_divide__Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, 
other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim); + return self; +} +template +at::Tensor floor_divide_Scalar_generated_plumbing(const at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::floor_divide_Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, other); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & floor_divide__Scalar_generated_plumbing(at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::floor_divide__Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, other); + return self; +} +template +at::Tensor frac_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::frac::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & 
frac__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::frac_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor full_like_generated_plumbing(const at::Tensor & self, const at::Scalar & fill_value, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::full_like::call(self, fill_value, dtype, layout, device, pin_memory, memory_format); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, fill_value, dtype, layout, device, pin_memory, memory_format); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor gcd_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::gcd::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = 
batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & gcd__generated_plumbing(at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::gcd_::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim); + return self; +} +template +at::Tensor lcm_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::lcm::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & lcm__generated_plumbing(at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + 
return at::_ops::lcm_::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim); + return self; +} +template +at::Tensor grid_sampler_generated_plumbing(const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(grid, cur_level)) { + return at::_ops::grid_sampler::call(input, grid, interpolation_mode, padding_mode, align_corners); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [grid_value, grid_bdim] = unwrapTensorAtLevel(grid, cur_level); + auto results = batch_rule(input_value, input_bdim, grid_value, grid_bdim, interpolation_mode, padding_mode, align_corners); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor grid_sampler_2d_generated_plumbing(const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(grid, cur_level)) { + return at::_ops::grid_sampler_2d::call(input, grid, interpolation_mode, padding_mode, align_corners); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [grid_value, grid_bdim] = unwrapTensorAtLevel(grid, cur_level); + auto results = batch_rule(input_value, 
input_bdim, grid_value, grid_bdim, interpolation_mode, padding_mode, align_corners); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple grid_sampler_2d_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners, ::std::array output_mask) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(grid, cur_level)) { + return at::_ops::grid_sampler_2d_backward::call(grad_output, input, grid, interpolation_mode, padding_mode, align_corners, output_mask); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [grid_value, grid_bdim] = unwrapTensorAtLevel(grid, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, input_value, input_bdim, grid_value, grid_bdim, interpolation_mode, padding_mode, align_corners, output_mask); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor _grid_sampler_2d_cpu_fallback_generated_plumbing(const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(grid, cur_level)) { + 
return at::_ops::_grid_sampler_2d_cpu_fallback::call(input, grid, interpolation_mode, padding_mode, align_corners); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [grid_value, grid_bdim] = unwrapTensorAtLevel(grid, cur_level); + auto results = batch_rule(input_value, input_bdim, grid_value, grid_bdim, interpolation_mode, padding_mode, align_corners); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple _grid_sampler_2d_cpu_fallback_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(grid, cur_level)) { + return at::_ops::_grid_sampler_2d_cpu_fallback_backward::call(grad_output, input, grid, interpolation_mode, padding_mode, align_corners); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [grid_value, grid_bdim] = unwrapTensorAtLevel(grid, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, input_value, input_bdim, grid_value, grid_bdim, interpolation_mode, padding_mode, align_corners); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor grid_sampler_3d_generated_plumbing(const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners) { + c10::impl::ExcludeDispatchKeyGuard 
guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(grid, cur_level)) { + return at::_ops::grid_sampler_3d::call(input, grid, interpolation_mode, padding_mode, align_corners); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [grid_value, grid_bdim] = unwrapTensorAtLevel(grid, cur_level); + auto results = batch_rule(input_value, input_bdim, grid_value, grid_bdim, interpolation_mode, padding_mode, align_corners); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple grid_sampler_3d_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners, ::std::array output_mask) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(grid, cur_level)) { + return at::_ops::grid_sampler_3d_backward::call(grad_output, input, grid, interpolation_mode, padding_mode, align_corners, output_mask); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [grid_value, grid_bdim] = unwrapTensorAtLevel(grid, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, input_value, input_bdim, grid_value, grid_bdim, interpolation_mode, padding_mode, align_corners, output_mask); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), 
makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor hinge_embedding_loss_generated_plumbing(const at::Tensor & self, const at::Tensor & target, double margin, int64_t reduction) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(target, cur_level)) { + return at::_ops::hinge_embedding_loss::call(self, target, margin, reduction); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [target_value, target_bdim] = unwrapTensorAtLevel(target, cur_level); + auto results = batch_rule(self_value, self_bdim, target_value, target_bdim, margin, reduction); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor group_norm_generated_plumbing(const at::Tensor & input, int64_t num_groups, const ::std::optional & weight, const ::std::optional & bias, double eps, bool cudnn_enabled) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::group_norm::call(input, num_groups, weight, bias, eps, cudnn_enabled); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + std::optional weight_value; + std::optional weight_bdim; + if (weight) { + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight.value(), cur_level); + } + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = 
batch_rule(input_value, input_bdim, num_groups, weight_value, weight_bdim, bias_value, bias_bdim, eps, cudnn_enabled); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple native_group_norm_generated_plumbing(const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, c10::SymInt N, c10::SymInt C, c10::SymInt HxW, int64_t group, double eps) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::native_group_norm::call(input, weight, bias, N, C, HxW, group, eps); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + std::optional weight_value; + std::optional weight_bdim; + if (weight) { + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight.value(), cur_level); + } + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, weight_value, weight_bdim, bias_value, bias_bdim, N, C, HxW, group, eps); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +::std::tuple native_group_norm_backward_generated_plumbing(const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & rstd, const ::std::optional & weight, c10::SymInt N, c10::SymInt C, c10::SymInt HxW, int64_t group, ::std::array output_mask) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto 
maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_out, cur_level) && !isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(mean, cur_level) && !isBatchedAtLevel(rstd, cur_level) && !isBatchedAtLevel(weight, cur_level)) { + return at::_ops::native_group_norm_backward::call(grad_out, input, mean, rstd, weight, N, C, HxW, group, output_mask); + } + auto [grad_out_value, grad_out_bdim] = unwrapTensorAtLevel(grad_out, cur_level); + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [mean_value, mean_bdim] = unwrapTensorAtLevel(mean, cur_level); + auto [rstd_value, rstd_bdim] = unwrapTensorAtLevel(rstd, cur_level); + std::optional weight_value; + std::optional weight_bdim; + if (weight) { + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight.value(), cur_level); + } + auto results = batch_rule(grad_out_value, grad_out_bdim, input_value, input_bdim, mean_value, mean_bdim, rstd_value, rstd_bdim, weight_value, weight_bdim, N, C, HxW, group, output_mask); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +at::Tensor _fft_r2c_generated_plumbing(const at::Tensor & self, at::IntArrayRef dim, int64_t normalization, bool onesided) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_fft_r2c::call(self, dim, normalization, onesided); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, normalization, onesided); + 
return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _fft_c2r_generated_plumbing(const at::Tensor & self, at::IntArrayRef dim, int64_t normalization, c10::SymInt last_dim_size) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_fft_c2r::call(self, dim, normalization, last_dim_size); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, normalization, last_dim_size); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _fft_c2c_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef dim, int64_t normalization, bool forward) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_fft_c2c::call(self, dim, normalization, forward); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, normalization, forward); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _validate_compressed_sparse_indices_generated_plumbing(bool is_crow, const at::Tensor & compressed_idx, const at::Tensor & plain_idx, int64_t cdim, int64_t dim, int64_t nnz) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(compressed_idx, 
cur_level) && !isBatchedAtLevel(plain_idx, cur_level)) { + return at::_ops::_validate_compressed_sparse_indices::call(is_crow, compressed_idx, plain_idx, cdim, dim, nnz); + } + auto [compressed_idx_value, compressed_idx_bdim] = unwrapTensorAtLevel(compressed_idx, cur_level); + auto [plain_idx_value, plain_idx_bdim] = unwrapTensorAtLevel(plain_idx, cur_level); + batch_rule(is_crow, compressed_idx_value, compressed_idx_bdim, plain_idx_value, plain_idx_bdim, cdim, dim, nnz); +} +template +at::Tensor index_Tensor_generated_plumbing(const at::Tensor & self, const c10::List<::std::optional> & indices) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(indices, cur_level)) { + return at::_ops::index_Tensor::call(self, indices); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, indices); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _unsafe_index_Tensor_generated_plumbing(const at::Tensor & self, const c10::List<::std::optional> & indices) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(indices, cur_level)) { + return at::_ops::_unsafe_index_Tensor::call(self, indices); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, indices); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _unsafe_masked_index_generated_plumbing(const at::Tensor & self, 
const at::Tensor & mask, const c10::List<::std::optional> & indices, const at::Scalar & fill) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(mask, cur_level) && !isBatchedAtLevel(indices, cur_level)) { + return at::_ops::_unsafe_masked_index::call(self, mask, indices, fill); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [mask_value, mask_bdim] = unwrapTensorAtLevel(mask, cur_level); + auto results = batch_rule(self_value, self_bdim, mask_value, mask_bdim, indices, fill); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _unsafe_masked_index_put_accumulate_generated_plumbing(const at::Tensor & self, const at::Tensor & mask, const c10::List<::std::optional> & indices, const at::Tensor & values) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(mask, cur_level) && !isBatchedAtLevel(indices, cur_level) && !isBatchedAtLevel(values, cur_level)) { + return at::_ops::_unsafe_masked_index_put_accumulate::call(self, mask, indices, values); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [mask_value, mask_bdim] = unwrapTensorAtLevel(mask, cur_level); + auto [values_value, values_bdim] = unwrapTensorAtLevel(values, cur_level); + auto results = batch_rule(self_value, self_bdim, mask_value, mask_bdim, indices, values_value, values_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & index_copy__generated_plumbing(at::Tensor & 
self, int64_t dim, const at::Tensor & index, const at::Tensor & source) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(index, cur_level) && !isBatchedAtLevel(source, cur_level)) { + return at::_ops::index_copy_::call(self, dim, index, source); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [index_value, index_bdim] = unwrapTensorAtLevel(index, cur_level); + auto [source_value, source_bdim] = unwrapTensorAtLevel(source, cur_level); + batch_rule(self_value, self_bdim, dim, index_value, index_bdim, source_value, source_bdim); + return self; +} +template +at::Tensor index_copy_generated_plumbing(const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & source) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(index, cur_level) && !isBatchedAtLevel(source, cur_level)) { + return at::_ops::index_copy::call(self, dim, index, source); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [index_value, index_bdim] = unwrapTensorAtLevel(index, cur_level); + auto [source_value, source_bdim] = unwrapTensorAtLevel(source, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, index_value, index_bdim, source_value, source_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & index_copy__dimname_generated_plumbing(at::Tensor & self, at::Dimname dim, const at::Tensor & index, const at::Tensor & source) { + c10::impl::ExcludeDispatchKeyGuard 
guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(index, cur_level) && !isBatchedAtLevel(source, cur_level)) { + return at::_ops::index_copy__dimname::call(self, dim, index, source); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [index_value, index_bdim] = unwrapTensorAtLevel(index, cur_level); + auto [source_value, source_bdim] = unwrapTensorAtLevel(source, cur_level); + batch_rule(self_value, self_bdim, dim, index_value, index_bdim, source_value, source_bdim); + return self; +} +template +at::Tensor index_copy_dimname_generated_plumbing(const at::Tensor & self, at::Dimname dim, const at::Tensor & index, const at::Tensor & source) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(index, cur_level) && !isBatchedAtLevel(source, cur_level)) { + return at::_ops::index_copy_dimname::call(self, dim, index, source); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [index_value, index_bdim] = unwrapTensorAtLevel(index, cur_level); + auto [source_value, source_bdim] = unwrapTensorAtLevel(source, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, index_value, index_bdim, source_value, source_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & index_put__generated_plumbing(at::Tensor & self, const c10::List<::std::optional> & indices, const at::Tensor & values, bool accumulate) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = 
maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(indices, cur_level) && !isBatchedAtLevel(values, cur_level)) { + return at::_ops::index_put_::call(self, indices, values, accumulate); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [values_value, values_bdim] = unwrapTensorAtLevel(values, cur_level); + batch_rule(self_value, self_bdim, indices, values_value, values_bdim, accumulate); + return self; +} +template +at::Tensor index_put_generated_plumbing(const at::Tensor & self, const c10::List<::std::optional> & indices, const at::Tensor & values, bool accumulate) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(indices, cur_level) && !isBatchedAtLevel(values, cur_level)) { + return at::_ops::index_put::call(self, indices, values, accumulate); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [values_value, values_bdim] = unwrapTensorAtLevel(values, cur_level); + auto results = batch_rule(self_value, self_bdim, indices, values_value, values_bdim, accumulate); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _unsafe_index_put_generated_plumbing(const at::Tensor & self, const c10::List<::std::optional> & indices, const at::Tensor & values, bool accumulate) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(indices, cur_level) && 
!isBatchedAtLevel(values, cur_level)) { + return at::_ops::_unsafe_index_put::call(self, indices, values, accumulate); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [values_value, values_bdim] = unwrapTensorAtLevel(values, cur_level); + auto results = batch_rule(self_value, self_bdim, indices, values_value, values_bdim, accumulate); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & _index_put_impl__generated_plumbing(at::Tensor & self, const c10::List<::std::optional> & indices, const at::Tensor & values, bool accumulate, bool unsafe) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(indices, cur_level) && !isBatchedAtLevel(values, cur_level)) { + return at::_ops::_index_put_impl_::call(self, indices, values, accumulate, unsafe); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [values_value, values_bdim] = unwrapTensorAtLevel(values, cur_level); + batch_rule(self_value, self_bdim, indices, values_value, values_bdim, accumulate, unsafe); + return self; +} +template +at::Tensor instance_norm_generated_plumbing(const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const ::std::optional & running_mean, const ::std::optional & running_var, bool use_input_stats, double momentum, double eps, bool cudnn_enabled) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level) && 
!isBatchedAtLevel(running_mean, cur_level) && !isBatchedAtLevel(running_var, cur_level)) { + return at::_ops::instance_norm::call(input, weight, bias, running_mean, running_var, use_input_stats, momentum, eps, cudnn_enabled); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + std::optional weight_value; + std::optional weight_bdim; + if (weight) { + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight.value(), cur_level); + } + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + std::optional running_mean_value; + std::optional running_mean_bdim; + if (running_mean) { + std::tie(running_mean_value, running_mean_bdim) = unwrapTensorAtLevel(running_mean.value(), cur_level); + } + std::optional running_var_value; + std::optional running_var_bdim; + if (running_var) { + std::tie(running_var_value, running_var_bdim) = unwrapTensorAtLevel(running_var.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, weight_value, weight_bdim, bias_value, bias_bdim, running_mean_value, running_mean_bdim, running_var_value, running_var_bdim, use_input_stats, momentum, eps, cudnn_enabled); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor isclose_generated_plumbing(const at::Tensor & self, const at::Tensor & other, double rtol, double atol, bool equal_nan) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::isclose::call(self, other, rtol, atol, equal_nan); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, 
cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim, rtol, atol, equal_nan); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor isin_Tensor_Tensor_generated_plumbing(const at::Tensor & elements, const at::Tensor & test_elements, bool assume_unique, bool invert) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(elements, cur_level) && !isBatchedAtLevel(test_elements, cur_level)) { + return at::_ops::isin_Tensor_Tensor::call(elements, test_elements, assume_unique, invert); + } + auto [elements_value, elements_bdim] = unwrapTensorAtLevel(elements, cur_level); + auto [test_elements_value, test_elements_bdim] = unwrapTensorAtLevel(test_elements, cur_level); + auto results = batch_rule(elements_value, elements_bdim, test_elements_value, test_elements_bdim, assume_unique, invert); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor isin_Tensor_Scalar_generated_plumbing(const at::Tensor & elements, const at::Scalar & test_element, bool assume_unique, bool invert) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(elements, cur_level)) { + return at::_ops::isin_Tensor_Scalar::call(elements, test_element, assume_unique, invert); + } + auto [elements_value, elements_bdim] = unwrapTensorAtLevel(elements, cur_level); + auto results = batch_rule(elements_value, elements_bdim, test_element, assume_unique, invert); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor 
isin_Scalar_Tensor_generated_plumbing(const at::Scalar & element, const at::Tensor & test_elements, bool assume_unique, bool invert) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(test_elements, cur_level)) { + return at::_ops::isin_Scalar_Tensor::call(element, test_elements, assume_unique, invert); + } + auto [test_elements_value, test_elements_bdim] = unwrapTensorAtLevel(test_elements, cur_level); + auto results = batch_rule(element, test_elements_value, test_elements_bdim, assume_unique, invert); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor isnan_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::isnan::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor isreal_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::isreal::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor 
kl_div_generated_plumbing(const at::Tensor & self, const at::Tensor & target, int64_t reduction, bool log_target) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(target, cur_level)) { + return at::_ops::kl_div::call(self, target, reduction, log_target); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [target_value, target_bdim] = unwrapTensorAtLevel(target, cur_level); + auto results = batch_rule(self_value, self_bdim, target_value, target_bdim, reduction, log_target); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor kron_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::kron::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple kthvalue_generated_plumbing(const at::Tensor & self, c10::SymInt k, int64_t dim, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return 
at::_ops::kthvalue::call(self, k, dim, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, k, dim, keepdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple kthvalue_dimname_generated_plumbing(const at::Tensor & self, c10::SymInt k, at::Dimname dim, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::kthvalue_dimname::call(self, k, dim, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, k, dim, keepdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor layer_norm_generated_plumbing(const at::Tensor & input, c10::SymIntArrayRef normalized_shape, const ::std::optional & weight, const ::std::optional & bias, double eps, bool cudnn_enable) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::layer_norm::call(input, normalized_shape, weight, bias, eps, cudnn_enable); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + std::optional weight_value; + std::optional weight_bdim; + if (weight) { + std::tie(weight_value, weight_bdim) = 
unwrapTensorAtLevel(weight.value(), cur_level); + } + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, normalized_shape, weight_value, weight_bdim, bias_value, bias_bdim, eps, cudnn_enable); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple native_layer_norm_generated_plumbing(const at::Tensor & input, c10::SymIntArrayRef normalized_shape, const ::std::optional & weight, const ::std::optional & bias, double eps) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::native_layer_norm::call(input, normalized_shape, weight, bias, eps); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + std::optional weight_value; + std::optional weight_bdim; + if (weight) { + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight.value(), cur_level); + } + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, normalized_shape, weight_value, weight_bdim, bias_value, bias_bdim, eps); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +::std::tuple native_layer_norm_backward_generated_plumbing(const at::Tensor & grad_out, const at::Tensor & input, c10::SymIntArrayRef normalized_shape, const 
at::Tensor & mean, const at::Tensor & rstd, const ::std::optional & weight, const ::std::optional & bias, ::std::array output_mask) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_out, cur_level) && !isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(mean, cur_level) && !isBatchedAtLevel(rstd, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::native_layer_norm_backward::call(grad_out, input, normalized_shape, mean, rstd, weight, bias, output_mask); + } + auto [grad_out_value, grad_out_bdim] = unwrapTensorAtLevel(grad_out, cur_level); + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [mean_value, mean_bdim] = unwrapTensorAtLevel(mean, cur_level); + auto [rstd_value, rstd_bdim] = unwrapTensorAtLevel(rstd, cur_level); + std::optional weight_value; + std::optional weight_bdim; + if (weight) { + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight.value(), cur_level); + } + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(grad_out_value, grad_out_bdim, input_value, input_bdim, normalized_shape, mean_value, mean_bdim, rstd_value, rstd_bdim, weight_value, weight_bdim, bias_value, bias_bdim, output_mask); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +at::Tensor rms_norm_generated_plumbing(const at::Tensor & input, c10::SymIntArrayRef normalized_shape, const ::std::optional & weight, ::std::optional eps) { + 
c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level)) { + return at::_ops::rms_norm::call(input, normalized_shape, weight, eps); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + std::optional weight_value; + std::optional weight_bdim; + if (weight) { + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, normalized_shape, weight_value, weight_bdim, eps); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple _fused_rms_norm_generated_plumbing(const at::Tensor & input, at::IntArrayRef normalized_shape, const ::std::optional & weight, ::std::optional eps) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level)) { + return at::_ops::_fused_rms_norm::call(input, normalized_shape, weight, eps); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + std::optional weight_value; + std::optional weight_bdim; + if (weight) { + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, normalized_shape, weight_value, weight_bdim, eps); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple _fused_rms_norm_backward_generated_plumbing(const at::Tensor & grad_out, const 
at::Tensor & input, at::IntArrayRef normalized_shape, const at::Tensor & rstd, const ::std::optional & weight, ::std::array output_mask) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_out, cur_level) && !isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(rstd, cur_level) && !isBatchedAtLevel(weight, cur_level)) { + return at::_ops::_fused_rms_norm_backward::call(grad_out, input, normalized_shape, rstd, weight, output_mask); + } + auto [grad_out_value, grad_out_bdim] = unwrapTensorAtLevel(grad_out, cur_level); + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [rstd_value, rstd_bdim] = unwrapTensorAtLevel(rstd, cur_level); + std::optional weight_value; + std::optional weight_bdim; + if (weight) { + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight.value(), cur_level); + } + auto results = batch_rule(grad_out_value, grad_out_bdim, input_value, input_bdim, normalized_shape, rstd_value, rstd_bdim, weight_value, weight_bdim, output_mask); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor nan_to_num_generated_plumbing(const at::Tensor & self, ::std::optional nan, ::std::optional posinf, ::std::optional neginf) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::nan_to_num::call(self, nan, posinf, neginf); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, nan, posinf, 
neginf); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & nan_to_num__generated_plumbing(at::Tensor & self, ::std::optional nan, ::std::optional posinf, ::std::optional neginf) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::nan_to_num_::call(self, nan, posinf, neginf); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, nan, posinf, neginf); + return self; +} +template +at::Tensor linear_generated_plumbing(const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::linear::call(input, weight, bias); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, weight_value, weight_bdim, bias_value, bias_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple linear_backward_generated_plumbing(const at::Tensor & self, const at::Tensor & grad_output, const at::Tensor & weight, ::std::array output_mask) { + c10::impl::ExcludeDispatchKeyGuard 
guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(weight, cur_level)) { + return at::_ops::linear_backward::call(self, grad_output, weight, output_mask); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + auto results = batch_rule(self_value, self_bdim, grad_output_value, grad_output_bdim, weight_value, weight_bdim, output_mask); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +at::Tensor mkldnn_linear_generated_plumbing(const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::mkldnn_linear::call(self, weight, bias); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, weight_value, weight_bdim, bias_value, bias_bdim); + return 
makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor mkldnn_linear_backward_input_generated_plumbing(at::IntArrayRef input_size, const at::Tensor & grad_output, const at::Tensor & weight) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(weight, cur_level)) { + return at::_ops::mkldnn_linear_backward_input::call(input_size, grad_output, weight); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + auto results = batch_rule(input_size, grad_output_value, grad_output_bdim, weight_value, weight_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple mkldnn_linear_backward_weights_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & weight, bool bias_defined) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level)) { + return at::_ops::mkldnn_linear_backward_weights::call(grad_output, input, weight, bias_defined); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, input_value, input_bdim, weight_value, weight_bdim, 
bias_defined); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple mkldnn_linear_backward_generated_plumbing(const at::Tensor & self, const at::Tensor & grad_output, const at::Tensor & weight, ::std::array output_mask) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(weight, cur_level)) { + return at::_ops::mkldnn_linear_backward::call(self, grad_output, weight, output_mask); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + auto results = batch_rule(self_value, self_bdim, grad_output_value, grad_output_bdim, weight_value, weight_bdim, output_mask); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +at::Tensor _cslt_compress_generated_plumbing(const at::Tensor & input) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level)) { + return at::_ops::_cslt_compress::call(input); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto results = batch_rule(input_value, input_bdim); + return makeBatched(std::get<0>(results), 
std::get<1>(results), cur_level); +} +template +at::Tensor _cslt_sparse_mm_generated_plumbing(const at::Tensor & compressed_A, const at::Tensor & dense_B, const ::std::optional & bias, const ::std::optional & alpha, ::std::optional out_dtype, bool transpose_result, int64_t alg_id, int64_t split_k, int64_t split_k_mode) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(compressed_A, cur_level) && !isBatchedAtLevel(dense_B, cur_level) && !isBatchedAtLevel(bias, cur_level) && !isBatchedAtLevel(alpha, cur_level)) { + return at::_ops::_cslt_sparse_mm::call(compressed_A, dense_B, bias, alpha, out_dtype, transpose_result, alg_id, split_k, split_k_mode); + } + auto [compressed_A_value, compressed_A_bdim] = unwrapTensorAtLevel(compressed_A, cur_level); + auto [dense_B_value, dense_B_bdim] = unwrapTensorAtLevel(dense_B, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + std::optional alpha_value; + std::optional alpha_bdim; + if (alpha) { + std::tie(alpha_value, alpha_bdim) = unwrapTensorAtLevel(alpha.value(), cur_level); + } + auto results = batch_rule(compressed_A_value, compressed_A_bdim, dense_B_value, dense_B_bdim, bias_value, bias_bdim, alpha_value, alpha_bdim, out_dtype, transpose_result, alg_id, split_k, split_k_mode); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple _sparse_semi_structured_tile_generated_plumbing(const at::Tensor & input, c10::string_view algorithm, bool use_cutlass) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = 
maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level)) { + return at::_ops::_sparse_semi_structured_tile::call(input, algorithm, use_cutlass); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto results = batch_rule(input_value, input_bdim, algorithm, use_cutlass); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level), makeBatched(std::get<6>(results), std::get<7>(results), cur_level), makeBatched(std::get<8>(results), std::get<9>(results), cur_level)); +} +template +::std::tuple _sparse_semi_structured_apply_generated_plumbing(const at::Tensor & input, const at::Tensor & thread_masks) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(thread_masks, cur_level)) { + return at::_ops::_sparse_semi_structured_apply::call(input, thread_masks); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [thread_masks_value, thread_masks_bdim] = unwrapTensorAtLevel(thread_masks, cur_level); + auto results = batch_rule(input_value, input_bdim, thread_masks_value, thread_masks_bdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor _sparse_semi_structured_apply_dense_generated_plumbing(const at::Tensor & input, const at::Tensor & thread_masks) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = 
maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(thread_masks, cur_level)) { + return at::_ops::_sparse_semi_structured_apply_dense::call(input, thread_masks); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [thread_masks_value, thread_masks_bdim] = unwrapTensorAtLevel(thread_masks, cur_level); + auto results = batch_rule(input_value, input_bdim, thread_masks_value, thread_masks_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _sparse_semi_structured_linear_generated_plumbing(const at::Tensor & input, const at::Tensor & weight, const at::Tensor & meta, const ::std::optional & bias, ::std::optional activation, ::std::optional out_dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(meta, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::_sparse_semi_structured_linear::call(input, weight, meta, bias, activation, out_dtype); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + auto [meta_value, meta_bdim] = unwrapTensorAtLevel(meta, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, weight_value, weight_bdim, meta_value, meta_bdim, bias_value, bias_bdim, activation, out_dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _sparse_semi_structured_mm_generated_plumbing(const at::Tensor & mat1, const at::Tensor & 
mat1_meta, const at::Tensor & mat2, ::std::optional out_dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(mat1, cur_level) && !isBatchedAtLevel(mat1_meta, cur_level) && !isBatchedAtLevel(mat2, cur_level)) { + return at::_ops::_sparse_semi_structured_mm::call(mat1, mat1_meta, mat2, out_dtype); + } + auto [mat1_value, mat1_bdim] = unwrapTensorAtLevel(mat1, cur_level); + auto [mat1_meta_value, mat1_meta_bdim] = unwrapTensorAtLevel(mat1_meta, cur_level); + auto [mat2_value, mat2_bdim] = unwrapTensorAtLevel(mat2, cur_level); + auto results = batch_rule(mat1_value, mat1_bdim, mat1_meta_value, mat1_meta_bdim, mat2_value, mat2_bdim, out_dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _sparse_semi_structured_addmm_generated_plumbing(const at::Tensor & input, const at::Tensor & mat1, const at::Tensor & mat1_meta, const at::Tensor & mat2, const at::Scalar & alpha, const at::Scalar & beta, ::std::optional out_dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(mat1, cur_level) && !isBatchedAtLevel(mat1_meta, cur_level) && !isBatchedAtLevel(mat2, cur_level)) { + return at::_ops::_sparse_semi_structured_addmm::call(input, mat1, mat1_meta, mat2, alpha, beta, out_dtype); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [mat1_value, mat1_bdim] = unwrapTensorAtLevel(mat1, cur_level); + auto [mat1_meta_value, mat1_meta_bdim] = unwrapTensorAtLevel(mat1_meta, cur_level); + auto [mat2_value, mat2_bdim] = unwrapTensorAtLevel(mat2, cur_level); + 
auto results = batch_rule(input_value, input_bdim, mat1_value, mat1_bdim, mat1_meta_value, mat1_meta_bdim, mat2_value, mat2_bdim, alpha, beta, out_dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _mixed_dtypes_linear_generated_plumbing(const at::Tensor & input, const at::Tensor & weight, const at::Tensor & scale, const ::std::optional & bias, ::std::optional activation) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(scale, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::_mixed_dtypes_linear::call(input, weight, scale, bias, activation); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + auto [scale_value, scale_bdim] = unwrapTensorAtLevel(scale, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, weight_value, weight_bdim, scale_value, scale_bdim, bias_value, bias_bdim, activation); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor fbgemm_linear_int8_weight_fp32_activation_generated_plumbing(const at::Tensor & input, const at::Tensor & weight, const at::Tensor & packed, const at::Tensor & col_offsets, const at::Scalar & weight_scale, const at::Scalar & weight_zero_point, const at::Tensor & bias) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t 
cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(packed, cur_level) && !isBatchedAtLevel(col_offsets, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::fbgemm_linear_int8_weight_fp32_activation::call(input, weight, packed, col_offsets, weight_scale, weight_zero_point, bias); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + auto [packed_value, packed_bdim] = unwrapTensorAtLevel(packed, cur_level); + auto [col_offsets_value, col_offsets_bdim] = unwrapTensorAtLevel(col_offsets, cur_level); + auto [bias_value, bias_bdim] = unwrapTensorAtLevel(bias, cur_level); + auto results = batch_rule(input_value, input_bdim, weight_value, weight_bdim, packed_value, packed_bdim, col_offsets_value, col_offsets_bdim, weight_scale, weight_zero_point, bias_value, bias_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor fbgemm_linear_int8_weight_generated_plumbing(const at::Tensor & input, const at::Tensor & weight, const at::Tensor & packed, const at::Tensor & col_offsets, const at::Scalar & weight_scale, const at::Scalar & weight_zero_point, const at::Tensor & bias) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(packed, cur_level) && !isBatchedAtLevel(col_offsets, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::fbgemm_linear_int8_weight::call(input, weight, packed, col_offsets, weight_scale, weight_zero_point, bias); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto 
[weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + auto [packed_value, packed_bdim] = unwrapTensorAtLevel(packed, cur_level); + auto [col_offsets_value, col_offsets_bdim] = unwrapTensorAtLevel(col_offsets, cur_level); + auto [bias_value, bias_bdim] = unwrapTensorAtLevel(bias, cur_level); + auto results = batch_rule(input_value, input_bdim, weight_value, weight_bdim, packed_value, packed_bdim, col_offsets_value, col_offsets_bdim, weight_scale, weight_zero_point, bias_value, bias_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor fbgemm_pack_gemm_matrix_fp16_generated_plumbing(const at::Tensor & input) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level)) { + return at::_ops::fbgemm_pack_gemm_matrix_fp16::call(input); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto results = batch_rule(input_value, input_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _wrapped_linear_prepack_generated_plumbing(const at::Tensor & weight, const at::Tensor & weight_scale, const at::Tensor & weight_zero_point, const at::Tensor & bias) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(weight_scale, cur_level) && !isBatchedAtLevel(weight_zero_point, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::_wrapped_linear_prepack::call(weight, weight_scale, weight_zero_point, bias); + } + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, 
cur_level); + auto [weight_scale_value, weight_scale_bdim] = unwrapTensorAtLevel(weight_scale, cur_level); + auto [weight_zero_point_value, weight_zero_point_bdim] = unwrapTensorAtLevel(weight_zero_point, cur_level); + auto [bias_value, bias_bdim] = unwrapTensorAtLevel(bias, cur_level); + auto results = batch_rule(weight_value, weight_bdim, weight_scale_value, weight_scale_bdim, weight_zero_point_value, weight_zero_point_bdim, bias_value, bias_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _wrapped_quantized_linear_prepacked_generated_plumbing(const at::Tensor & input, const at::Tensor & input_scale, const at::Tensor & input_zero_point, const at::Tensor & packed_weight, const at::Tensor & output_scale, const at::Tensor & output_zero_point, int64_t out_channel) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(input_scale, cur_level) && !isBatchedAtLevel(input_zero_point, cur_level) && !isBatchedAtLevel(packed_weight, cur_level) && !isBatchedAtLevel(output_scale, cur_level) && !isBatchedAtLevel(output_zero_point, cur_level)) { + return at::_ops::_wrapped_quantized_linear_prepacked::call(input, input_scale, input_zero_point, packed_weight, output_scale, output_zero_point, out_channel); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [input_scale_value, input_scale_bdim] = unwrapTensorAtLevel(input_scale, cur_level); + auto [input_zero_point_value, input_zero_point_bdim] = unwrapTensorAtLevel(input_zero_point, cur_level); + auto [packed_weight_value, packed_weight_bdim] = unwrapTensorAtLevel(packed_weight, cur_level); + auto [output_scale_value, output_scale_bdim] = unwrapTensorAtLevel(output_scale, cur_level); + auto 
[output_zero_point_value, output_zero_point_bdim] = unwrapTensorAtLevel(output_zero_point, cur_level); + auto results = batch_rule(input_value, input_bdim, input_scale_value, input_scale_bdim, input_zero_point_value, input_zero_point_bdim, packed_weight_value, packed_weight_bdim, output_scale_value, output_scale_bdim, output_zero_point_value, output_zero_point_bdim, out_channel); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor fbgemm_linear_fp16_weight_fp32_activation_generated_plumbing(const at::Tensor & input, const at::Tensor & packed_weight, const ::std::optional & bias) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(packed_weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::fbgemm_linear_fp16_weight_fp32_activation::call(input, packed_weight, bias); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [packed_weight_value, packed_weight_bdim] = unwrapTensorAtLevel(packed_weight, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, packed_weight_value, packed_weight_bdim, bias_value, bias_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor fbgemm_linear_fp16_weight_generated_plumbing(const at::Tensor & input, const at::Tensor & packed_weight, const at::Tensor & bias) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); 
+ if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(packed_weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::fbgemm_linear_fp16_weight::call(input, packed_weight, bias); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [packed_weight_value, packed_weight_bdim] = unwrapTensorAtLevel(packed_weight, cur_level); + auto [bias_value, bias_bdim] = unwrapTensorAtLevel(bias, cur_level); + auto results = batch_rule(input_value, input_bdim, packed_weight_value, packed_weight_bdim, bias_value, bias_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor fbgemm_pack_quantized_matrix_generated_plumbing(const at::Tensor & input) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level)) { + return at::_ops::fbgemm_pack_quantized_matrix::call(input); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto results = batch_rule(input_value, input_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor fbgemm_pack_quantized_matrix_KN_generated_plumbing(const at::Tensor & input, int64_t K, int64_t N) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level)) { + return at::_ops::fbgemm_pack_quantized_matrix_KN::call(input, K, N); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto results = batch_rule(input_value, input_bdim, K, N); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template 
+at::Tensor ldexp_Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::ldexp_Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & ldexp__generated_plumbing(at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::ldexp_::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim); + return self; +} +template +at::Tensor linspace_Tensor_Tensor_generated_plumbing(const at::Tensor & start, const at::Tensor & end, int64_t steps, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(start, cur_level) && !isBatchedAtLevel(end, cur_level)) { + return 
at::_ops::linspace_Tensor_Tensor::call(start, end, steps, dtype, layout, device, pin_memory); + } + auto [start_value, start_bdim] = unwrapTensorAtLevel(start, cur_level); + auto [end_value, end_bdim] = unwrapTensorAtLevel(end, cur_level); + auto results = batch_rule(start_value, start_bdim, end_value, end_bdim, steps, dtype, layout, device, pin_memory); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor linspace_Tensor_Scalar_generated_plumbing(const at::Tensor & start, const at::Scalar & end, int64_t steps, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(start, cur_level)) { + return at::_ops::linspace_Tensor_Scalar::call(start, end, steps, dtype, layout, device, pin_memory); + } + auto [start_value, start_bdim] = unwrapTensorAtLevel(start, cur_level); + auto results = batch_rule(start_value, start_bdim, end, steps, dtype, layout, device, pin_memory); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor linspace_Scalar_Tensor_generated_plumbing(const at::Scalar & start, const at::Tensor & end, int64_t steps, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(end, cur_level)) { + return at::_ops::linspace_Scalar_Tensor::call(start, end, steps, dtype, layout, device, pin_memory); + } + auto [end_value, end_bdim] = unwrapTensorAtLevel(end, cur_level); + auto results = 
batch_rule(start, end_value, end_bdim, steps, dtype, layout, device, pin_memory); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor log_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::log::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & log__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::log_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor log10_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::log10::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & log10__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + 
auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::log10_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor log1p_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::log1p::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & log1p__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::log1p_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor log2_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::log2::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return 
makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & log2__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::log2_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor logaddexp_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::logaddexp::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor logaddexp2_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::logaddexp2::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = 
batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor xlogy_Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::xlogy_Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor xlogy_Scalar_Self_generated_plumbing(const at::Scalar & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(other, cur_level)) { + return at::_ops::xlogy_Scalar_Self::call(self, other); + } + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor xlogy_Scalar_Other_generated_plumbing(const at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return 
at::_ops::xlogy_Scalar_Other::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, other); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & xlogy__Tensor_generated_plumbing(at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::xlogy__Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim); + return self; +} +template +at::Tensor & xlogy__Scalar_Other_generated_plumbing(at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::xlogy__Scalar_Other::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, other); + return self; +} +template +at::Tensor logspace_Tensor_Tensor_generated_plumbing(const at::Tensor & start, const at::Tensor & end, int64_t steps, double base, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = 
maybe_layer->layerId(); + if (!isBatchedAtLevel(start, cur_level) && !isBatchedAtLevel(end, cur_level)) { + return at::_ops::logspace_Tensor_Tensor::call(start, end, steps, base, dtype, layout, device, pin_memory); + } + auto [start_value, start_bdim] = unwrapTensorAtLevel(start, cur_level); + auto [end_value, end_bdim] = unwrapTensorAtLevel(end, cur_level); + auto results = batch_rule(start_value, start_bdim, end_value, end_bdim, steps, base, dtype, layout, device, pin_memory); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor logspace_Tensor_Scalar_generated_plumbing(const at::Tensor & start, const at::Scalar & end, int64_t steps, double base, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(start, cur_level)) { + return at::_ops::logspace_Tensor_Scalar::call(start, end, steps, base, dtype, layout, device, pin_memory); + } + auto [start_value, start_bdim] = unwrapTensorAtLevel(start, cur_level); + auto results = batch_rule(start_value, start_bdim, end, steps, base, dtype, layout, device, pin_memory); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor logspace_Scalar_Tensor_generated_plumbing(const at::Scalar & start, const at::Tensor & end, int64_t steps, double base, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(end, cur_level)) { + return 
at::_ops::logspace_Scalar_Tensor::call(start, end, steps, base, dtype, layout, device, pin_memory); + } + auto [end_value, end_bdim] = unwrapTensorAtLevel(end, cur_level); + auto results = batch_rule(start, end_value, end_bdim, steps, base, dtype, layout, device, pin_memory); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor log_softmax_int_generated_plumbing(const at::Tensor & self, int64_t dim, ::std::optional dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::log_softmax_int::call(self, dim, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor log_softmax_Dimname_generated_plumbing(const at::Tensor & self, at::Dimname dim, ::std::optional dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::log_softmax_Dimname::call(self, dim, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _log_softmax_generated_plumbing(const at::Tensor & self, int64_t dim, bool half_to_float) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, 
"gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_log_softmax::call(self, dim, half_to_float); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, half_to_float); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _log_softmax_backward_data_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & output, int64_t dim, at::ScalarType input_dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(output, cur_level)) { + return at::_ops::_log_softmax_backward_data::call(grad_output, output, dim, input_dtype); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [output_value, output_bdim] = unwrapTensorAtLevel(output, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, output_value, output_bdim, dim, input_dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _logcumsumexp_generated_plumbing(const at::Tensor & self, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_logcumsumexp::call(self, dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor 
logcumsumexp_generated_plumbing(const at::Tensor & self, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::logcumsumexp::call(self, dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor logcumsumexp_dimname_generated_plumbing(const at::Tensor & self, at::Dimname dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::logcumsumexp_dimname::call(self, dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor logsumexp_generated_plumbing(const at::Tensor & self, at::IntArrayRef dim, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::logsumexp::call(self, dim, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, keepdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor logsumexp_names_generated_plumbing(const at::Tensor & self, 
at::DimnameList dim, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::logsumexp_names::call(self, dim, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, keepdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor margin_ranking_loss_generated_plumbing(const at::Tensor & input1, const at::Tensor & input2, const at::Tensor & target, double margin, int64_t reduction) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input1, cur_level) && !isBatchedAtLevel(input2, cur_level) && !isBatchedAtLevel(target, cur_level)) { + return at::_ops::margin_ranking_loss::call(input1, input2, target, margin, reduction); + } + auto [input1_value, input1_bdim] = unwrapTensorAtLevel(input1, cur_level); + auto [input2_value, input2_bdim] = unwrapTensorAtLevel(input2, cur_level); + auto [target_value, target_bdim] = unwrapTensorAtLevel(target, cur_level); + auto results = batch_rule(input1_value, input1_bdim, input2_value, input2_bdim, target_value, target_bdim, margin, reduction); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor matmul_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if 
(!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::matmul::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple matmul_backward_generated_plumbing(const at::Tensor & grad, const at::Tensor & self, const at::Tensor & other, ::std::array mask) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad, cur_level) && !isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::matmul_backward::call(grad, self, other, mask); + } + auto [grad_value, grad_bdim] = unwrapTensorAtLevel(grad, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(grad_value, grad_bdim, self_value, self_bdim, other_value, other_bdim, mask); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor matrix_power_generated_plumbing(const at::Tensor & self, int64_t n) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::matrix_power::call(self, n); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = 
batch_rule(self_value, self_bdim, n); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor matrix_exp_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::matrix_exp::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor matrix_exp_backward_generated_plumbing(const at::Tensor & self, const at::Tensor & grad) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(grad, cur_level)) { + return at::_ops::matrix_exp_backward::call(self, grad); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [grad_value, grad_bdim] = unwrapTensorAtLevel(grad, cur_level); + auto results = batch_rule(self_value, self_bdim, grad_value, grad_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple _aminmax_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_aminmax::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = 
batch_rule(self_value, self_bdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple _aminmax_dim_generated_plumbing(const at::Tensor & self, int64_t dim, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_aminmax_dim::call(self, dim, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, keepdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple aminmax_generated_plumbing(const at::Tensor & self, ::std::optional dim, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::aminmax::call(self, dim, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, keepdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor _compute_linear_combination_generated_plumbing(const at::Tensor & input, const at::Tensor & coefficients) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + 
int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(coefficients, cur_level)) { + return at::_ops::_compute_linear_combination::call(input, coefficients); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [coefficients_value, coefficients_bdim] = unwrapTensorAtLevel(coefficients, cur_level); + auto results = batch_rule(input_value, input_bdim, coefficients_value, coefficients_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple max_dim_generated_plumbing(const at::Tensor & self, int64_t dim, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::max_dim::call(self, dim, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, keepdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple max_names_dim_generated_plumbing(const at::Tensor & self, at::Dimname dim, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::max_names_dim::call(self, dim, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, keepdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), 
makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor value_selecting_reduction_backward_generated_plumbing(const at::Tensor & grad, int64_t dim, const at::Tensor & indices, c10::SymIntArrayRef sizes, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad, cur_level) && !isBatchedAtLevel(indices, cur_level)) { + return at::_ops::value_selecting_reduction_backward::call(grad, dim, indices, sizes, keepdim); + } + auto [grad_value, grad_bdim] = unwrapTensorAtLevel(grad, cur_level); + auto [indices_value, indices_bdim] = unwrapTensorAtLevel(indices, cur_level); + auto results = batch_rule(grad_value, grad_bdim, dim, indices_value, indices_bdim, sizes, keepdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor amax_generated_plumbing(const at::Tensor & self, at::IntArrayRef dim, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::amax::call(self, dim, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, keepdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple max_pool1d_with_indices_generated_plumbing(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + 
vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::max_pool1d_with_indices::call(self, kernel_size, stride, padding, dilation, ceil_mode); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, kernel_size, stride, padding, dilation, ceil_mode); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor max_pool1d_generated_plumbing(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::max_pool1d::call(self, kernel_size, stride, padding, dilation, ceil_mode); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, kernel_size, stride, padding, dilation, ceil_mode); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor max_pool2d_generated_plumbing(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::max_pool2d::call(self, kernel_size, stride, padding, dilation, ceil_mode); + } + auto 
[self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, kernel_size, stride, padding, dilation, ceil_mode); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor max_pool2d_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(self, cur_level)) { + return at::_ops::max_pool2d_backward::call(grad_output, self, kernel_size, stride, padding, dilation, ceil_mode); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, self_value, self_bdim, kernel_size, stride, padding, dilation, ceil_mode); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor mkldnn_max_pool2d_generated_plumbing(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::mkldnn_max_pool2d::call(self, kernel_size, stride, padding, dilation, ceil_mode); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, 
self_bdim, kernel_size, stride, padding, dilation, ceil_mode); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor mkldnn_max_pool2d_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & output, const at::Tensor & input, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(output, cur_level) && !isBatchedAtLevel(input, cur_level)) { + return at::_ops::mkldnn_max_pool2d_backward::call(grad_output, output, input, kernel_size, stride, padding, dilation, ceil_mode); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [output_value, output_bdim] = unwrapTensorAtLevel(output, cur_level); + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, output_value, output_bdim, input_value, input_bdim, kernel_size, stride, padding, dilation, ceil_mode); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor mkldnn_max_pool3d_generated_plumbing(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::mkldnn_max_pool3d::call(self, kernel_size, stride, padding, dilation, ceil_mode); + } + auto 
[self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, kernel_size, stride, padding, dilation, ceil_mode); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor mkldnn_max_pool3d_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & output, const at::Tensor & input, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(output, cur_level) && !isBatchedAtLevel(input, cur_level)) { + return at::_ops::mkldnn_max_pool3d_backward::call(grad_output, output, input, kernel_size, stride, padding, dilation, ceil_mode); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [output_value, output_bdim] = unwrapTensorAtLevel(output, cur_level); + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, output_value, output_bdim, input_value, input_bdim, kernel_size, stride, padding, dilation, ceil_mode); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor quantized_max_pool1d_generated_plumbing(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return 
at::_ops::quantized_max_pool1d::call(self, kernel_size, stride, padding, dilation, ceil_mode); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, kernel_size, stride, padding, dilation, ceil_mode); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor quantized_max_pool2d_generated_plumbing(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::quantized_max_pool2d::call(self, kernel_size, stride, padding, dilation, ceil_mode); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, kernel_size, stride, padding, dilation, ceil_mode); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor quantized_max_pool3d_generated_plumbing(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::quantized_max_pool3d::call(self, kernel_size, stride, padding, dilation, ceil_mode); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, kernel_size, stride, padding, dilation, ceil_mode); + return makeBatched(std::get<0>(results), 
std::get<1>(results), cur_level); +} +template +at::Tensor max_pool3d_generated_plumbing(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::max_pool3d::call(self, kernel_size, stride, padding, dilation, ceil_mode); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, kernel_size, stride, padding, dilation, ceil_mode); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor mean_generated_plumbing(const at::Tensor & self, ::std::optional dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::mean::call(self, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor mean_dim_generated_plumbing(const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim, ::std::optional dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::mean_dim::call(self, dim, keepdim, dtype); + } + auto [self_value, self_bdim] = 
unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, keepdim, dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor mean_names_dim_generated_plumbing(const at::Tensor & self, at::DimnameList dim, bool keepdim, ::std::optional dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::mean_names_dim::call(self, dim, keepdim, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, keepdim, dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor nanmean_generated_plumbing(const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim, ::std::optional dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::nanmean::call(self, dim, keepdim, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, keepdim, dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor median_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::median::call(self); + } + auto 
[self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple median_dim_generated_plumbing(const at::Tensor & self, int64_t dim, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::median_dim::call(self, dim, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, keepdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple median_names_dim_generated_plumbing(const at::Tensor & self, at::Dimname dim, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::median_names_dim::call(self, dim, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, keepdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor nanmedian_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + 
if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::nanmedian::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple nanmedian_dim_generated_plumbing(const at::Tensor & self, int64_t dim, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::nanmedian_dim::call(self, dim, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, keepdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple nanmedian_names_dim_generated_plumbing(const at::Tensor & self, at::Dimname dim, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::nanmedian_names_dim::call(self, dim, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, keepdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple min_dim_generated_plumbing(const at::Tensor & self, int64_t dim, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer 
= maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::min_dim::call(self, dim, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, keepdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple min_names_dim_generated_plumbing(const at::Tensor & self, at::Dimname dim, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::min_names_dim::call(self, dim, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, keepdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor amin_generated_plumbing(const at::Tensor & self, at::IntArrayRef dim, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::amin::call(self, dim, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, keepdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor 
_mps_convolution_generated_plumbing(const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::_mps_convolution::call(self, weight, bias, padding, stride, dilation, groups); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, weight_value, weight_bdim, bias_value, bias_bdim, padding, stride, dilation, groups); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple mps_convolution_backward_generated_plumbing(const at::Tensor & self, const at::Tensor & grad_output, const at::Tensor & weight, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, ::std::array output_mask) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(weight, cur_level)) { + return at::_ops::mps_convolution_backward::call(self, grad_output, weight, padding, stride, dilation, groups, output_mask); + } + auto [self_value, self_bdim] 
= unwrapTensorAtLevel(self, cur_level); + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + auto results = batch_rule(self_value, self_bdim, grad_output_value, grad_output_bdim, weight_value, weight_bdim, padding, stride, dilation, groups, output_mask); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +at::Tensor mkldnn_convolution_generated_plumbing(const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::mkldnn_convolution::call(self, weight, bias, padding, stride, dilation, groups); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, weight_value, weight_bdim, bias_value, bias_bdim, padding, stride, dilation, groups); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple mkldnn_rnn_layer_generated_plumbing(const at::Tensor & input, const at::Tensor & weight0, const at::Tensor & weight1, const at::Tensor & weight2, 
const at::Tensor & weight3, const at::Tensor & hx_, const at::Tensor & cx_, bool reverse, at::IntArrayRef batch_sizes, int64_t mode, int64_t hidden_size, int64_t num_layers, bool has_biases, bool bidirectional, bool batch_first, bool train) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight0, cur_level) && !isBatchedAtLevel(weight1, cur_level) && !isBatchedAtLevel(weight2, cur_level) && !isBatchedAtLevel(weight3, cur_level) && !isBatchedAtLevel(hx_, cur_level) && !isBatchedAtLevel(cx_, cur_level)) { + return at::_ops::mkldnn_rnn_layer::call(input, weight0, weight1, weight2, weight3, hx_, cx_, reverse, batch_sizes, mode, hidden_size, num_layers, has_biases, bidirectional, batch_first, train); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [weight0_value, weight0_bdim] = unwrapTensorAtLevel(weight0, cur_level); + auto [weight1_value, weight1_bdim] = unwrapTensorAtLevel(weight1, cur_level); + auto [weight2_value, weight2_bdim] = unwrapTensorAtLevel(weight2, cur_level); + auto [weight3_value, weight3_bdim] = unwrapTensorAtLevel(weight3, cur_level); + auto [hx__value, hx__bdim] = unwrapTensorAtLevel(hx_, cur_level); + auto [cx__value, cx__bdim] = unwrapTensorAtLevel(cx_, cur_level); + auto results = batch_rule(input_value, input_bdim, weight0_value, weight0_bdim, weight1_value, weight1_bdim, weight2_value, weight2_bdim, weight3_value, weight3_bdim, hx__value, hx__bdim, cx__value, cx__bdim, reverse, batch_sizes, mode, hidden_size, num_layers, has_biases, bidirectional, batch_first, train); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), 
std::get<5>(results), cur_level), makeBatched(std::get<6>(results), std::get<7>(results), cur_level)); +} +template +::std::tuple mkldnn_rnn_layer_backward_generated_plumbing(const at::Tensor & input, const at::Tensor & weight1, const at::Tensor & weight2, const at::Tensor & weight3, const at::Tensor & weight4, const at::Tensor & hx_, const at::Tensor & cx_tmp, const at::Tensor & output, const at::Tensor & hy_, const at::Tensor & cy_, const ::std::optional & grad_output, const ::std::optional & grad_hy, const ::std::optional & grad_cy, bool reverse, int64_t mode, int64_t hidden_size, int64_t num_layers, bool has_biases, bool train, bool bidirectional, at::IntArrayRef batch_sizes, bool batch_first, const at::Tensor & workspace) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight1, cur_level) && !isBatchedAtLevel(weight2, cur_level) && !isBatchedAtLevel(weight3, cur_level) && !isBatchedAtLevel(weight4, cur_level) && !isBatchedAtLevel(hx_, cur_level) && !isBatchedAtLevel(cx_tmp, cur_level) && !isBatchedAtLevel(output, cur_level) && !isBatchedAtLevel(hy_, cur_level) && !isBatchedAtLevel(cy_, cur_level) && !isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(grad_hy, cur_level) && !isBatchedAtLevel(grad_cy, cur_level) && !isBatchedAtLevel(workspace, cur_level)) { + return at::_ops::mkldnn_rnn_layer_backward::call(input, weight1, weight2, weight3, weight4, hx_, cx_tmp, output, hy_, cy_, grad_output, grad_hy, grad_cy, reverse, mode, hidden_size, num_layers, has_biases, train, bidirectional, batch_sizes, batch_first, workspace); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [weight1_value, weight1_bdim] = unwrapTensorAtLevel(weight1, cur_level); + auto [weight2_value, weight2_bdim] = 
unwrapTensorAtLevel(weight2, cur_level); + auto [weight3_value, weight3_bdim] = unwrapTensorAtLevel(weight3, cur_level); + auto [weight4_value, weight4_bdim] = unwrapTensorAtLevel(weight4, cur_level); + auto [hx__value, hx__bdim] = unwrapTensorAtLevel(hx_, cur_level); + auto [cx_tmp_value, cx_tmp_bdim] = unwrapTensorAtLevel(cx_tmp, cur_level); + auto [output_value, output_bdim] = unwrapTensorAtLevel(output, cur_level); + auto [hy__value, hy__bdim] = unwrapTensorAtLevel(hy_, cur_level); + auto [cy__value, cy__bdim] = unwrapTensorAtLevel(cy_, cur_level); + auto [workspace_value, workspace_bdim] = unwrapTensorAtLevel(workspace, cur_level); + std::optional grad_output_value; + std::optional grad_output_bdim; + if (grad_output) { + std::tie(grad_output_value, grad_output_bdim) = unwrapTensorAtLevel(grad_output.value(), cur_level); + } + std::optional grad_hy_value; + std::optional grad_hy_bdim; + if (grad_hy) { + std::tie(grad_hy_value, grad_hy_bdim) = unwrapTensorAtLevel(grad_hy.value(), cur_level); + } + std::optional grad_cy_value; + std::optional grad_cy_bdim; + if (grad_cy) { + std::tie(grad_cy_value, grad_cy_bdim) = unwrapTensorAtLevel(grad_cy.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, weight1_value, weight1_bdim, weight2_value, weight2_bdim, weight3_value, weight3_bdim, weight4_value, weight4_bdim, hx__value, hx__bdim, cx_tmp_value, cx_tmp_bdim, output_value, output_bdim, hy__value, hy__bdim, cy__value, cy__bdim, grad_output_value, grad_output_bdim, grad_hy_value, grad_hy_bdim, grad_cy_value, grad_cy_bdim, reverse, mode, hidden_size, num_layers, has_biases, train, bidirectional, batch_sizes, batch_first, workspace_value, workspace_bdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level), makeBatched(std::get<6>(results), std::get<7>(results), cur_level), 
makeBatched(std::get<8>(results), std::get<9>(results), cur_level), makeBatched(std::get<10>(results), std::get<11>(results), cur_level), makeBatched(std::get<12>(results), std::get<13>(results), cur_level)); +} +template +::std::tuple miopen_batch_norm_generated_plumbing(const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, const ::std::optional & running_mean, const ::std::optional & running_var, bool training, double exponential_average_factor, double epsilon) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level) && !isBatchedAtLevel(running_mean, cur_level) && !isBatchedAtLevel(running_var, cur_level)) { + return at::_ops::miopen_batch_norm::call(input, weight, bias, running_mean, running_var, training, exponential_average_factor, epsilon); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + std::optional running_mean_value; + std::optional running_mean_bdim; + if (running_mean) { + std::tie(running_mean_value, running_mean_bdim) = unwrapTensorAtLevel(running_mean.value(), cur_level); + } + std::optional running_var_value; + std::optional running_var_bdim; + if (running_var) { + std::tie(running_var_value, running_var_bdim) = unwrapTensorAtLevel(running_var.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, weight_value, weight_bdim, bias_value, bias_bdim, running_mean_value, running_mean_bdim, running_var_value, running_var_bdim, training, 
exponential_average_factor, epsilon); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +::std::tuple miopen_batch_norm_backward_generated_plumbing(const at::Tensor & input, const at::Tensor & grad_output, const at::Tensor & weight, const ::std::optional & running_mean, const ::std::optional & running_var, const ::std::optional & save_mean, const ::std::optional & save_var, double epsilon) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(running_mean, cur_level) && !isBatchedAtLevel(running_var, cur_level) && !isBatchedAtLevel(save_mean, cur_level) && !isBatchedAtLevel(save_var, cur_level)) { + return at::_ops::miopen_batch_norm_backward::call(input, grad_output, weight, running_mean, running_var, save_mean, save_var, epsilon); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + std::optional running_mean_value; + std::optional running_mean_bdim; + if (running_mean) { + std::tie(running_mean_value, running_mean_bdim) = unwrapTensorAtLevel(running_mean.value(), cur_level); + } + std::optional running_var_value; + std::optional running_var_bdim; + if (running_var) { + std::tie(running_var_value, running_var_bdim) = unwrapTensorAtLevel(running_var.value(), cur_level); + } + std::optional save_mean_value; + std::optional save_mean_bdim; + if (save_mean) { + 
std::tie(save_mean_value, save_mean_bdim) = unwrapTensorAtLevel(save_mean.value(), cur_level); + } + std::optional save_var_value; + std::optional save_var_bdim; + if (save_var) { + std::tie(save_var_value, save_var_bdim) = unwrapTensorAtLevel(save_var.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, grad_output_value, grad_output_bdim, weight_value, weight_bdim, running_mean_value, running_mean_bdim, running_var_value, running_var_bdim, save_mean_value, save_mean_bdim, save_var_value, save_var_bdim, epsilon); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +at::Tensor miopen_convolution_generated_plumbing(const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, bool benchmark, bool deterministic) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::miopen_convolution::call(self, weight, bias, padding, stride, dilation, groups, benchmark, deterministic); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, weight_value, weight_bdim, bias_value, bias_bdim, padding, stride, dilation, groups, benchmark, 
deterministic); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor miopen_convolution_transpose_generated_plumbing(const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, bool benchmark, bool deterministic) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::miopen_convolution_transpose::call(self, weight, bias, padding, output_padding, stride, dilation, groups, benchmark, deterministic); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, weight_value, weight_bdim, bias_value, bias_bdim, padding, output_padding, stride, dilation, groups, benchmark, deterministic); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor miopen_depthwise_convolution_generated_plumbing(const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, bool benchmark, bool deterministic) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = 
maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::miopen_depthwise_convolution::call(self, weight, bias, padding, stride, dilation, groups, benchmark, deterministic); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, weight_value, weight_bdim, bias_value, bias_bdim, padding, stride, dilation, groups, benchmark, deterministic); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor miopen_convolution_relu_generated_plumbing(const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, c10::SymInt groups) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::miopen_convolution_relu::call(self, weight, bias, stride, padding, dilation, groups); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, weight_value, weight_bdim, bias_value, bias_bdim, stride, padding, dilation, groups); + return 
makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor miopen_convolution_add_relu_generated_plumbing(const at::Tensor & self, const at::Tensor & weight, const at::Tensor & z, const ::std::optional & alpha, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, c10::SymInt groups) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(z, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::miopen_convolution_add_relu::call(self, weight, z, alpha, bias, stride, padding, dilation, groups); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + auto [z_value, z_bdim] = unwrapTensorAtLevel(z, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, weight_value, weight_bdim, z_value, z_bdim, alpha, bias_value, bias_bdim, stride, padding, dilation, groups); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple miopen_rnn_generated_plumbing(const at::Tensor & input, at::TensorList weight, int64_t weight_stride0, const at::Tensor & hx, const ::std::optional & cx, int64_t mode, int64_t hidden_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, at::IntArrayRef batch_sizes, const ::std::optional & dropout_state) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + 
vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(hx, cur_level) && !isBatchedAtLevel(cx, cur_level) && !isBatchedAtLevel(dropout_state, cur_level)) { + return at::_ops::miopen_rnn::call(input, weight, weight_stride0, hx, cx, mode, hidden_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [hx_value, hx_bdim] = unwrapTensorAtLevel(hx, cur_level); + std::optional cx_value; + std::optional cx_bdim; + if (cx) { + std::tie(cx_value, cx_bdim) = unwrapTensorAtLevel(cx.value(), cur_level); + } + std::optional dropout_state_value; + std::optional dropout_state_bdim; + if (dropout_state) { + std::tie(dropout_state_value, dropout_state_bdim) = unwrapTensorAtLevel(dropout_state.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, weight, weight_stride0, hx_value, hx_bdim, cx_value, cx_bdim, mode, hidden_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state_value, dropout_state_bdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level), makeBatched(std::get<6>(results), std::get<7>(results), cur_level), makeBatched(std::get<8>(results), std::get<9>(results), cur_level)); +} +template +::std::tuple> miopen_rnn_backward_generated_plumbing(const at::Tensor & input, at::TensorList weight, int64_t weight_stride0, const at::Tensor & weight_buf, const at::Tensor & hx, const ::std::optional & cx, const at::Tensor & output, const ::std::optional & grad_output, const ::std::optional & grad_hy, const ::std::optional & grad_cy, int64_t mode, int64_t hidden_size, int64_t num_layers, 
bool batch_first, double dropout, bool train, bool bidirectional, at::IntArrayRef batch_sizes, const ::std::optional & dropout_state, const at::Tensor & reserve, ::std::array output_mask) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(weight_buf, cur_level) && !isBatchedAtLevel(hx, cur_level) && !isBatchedAtLevel(cx, cur_level) && !isBatchedAtLevel(output, cur_level) && !isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(grad_hy, cur_level) && !isBatchedAtLevel(grad_cy, cur_level) && !isBatchedAtLevel(dropout_state, cur_level) && !isBatchedAtLevel(reserve, cur_level)) { + return at::_ops::miopen_rnn_backward::call(input, weight, weight_stride0, weight_buf, hx, cx, output, grad_output, grad_hy, grad_cy, mode, hidden_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state, reserve, output_mask); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [weight_buf_value, weight_buf_bdim] = unwrapTensorAtLevel(weight_buf, cur_level); + auto [hx_value, hx_bdim] = unwrapTensorAtLevel(hx, cur_level); + auto [output_value, output_bdim] = unwrapTensorAtLevel(output, cur_level); + auto [reserve_value, reserve_bdim] = unwrapTensorAtLevel(reserve, cur_level); + std::optional cx_value; + std::optional cx_bdim; + if (cx) { + std::tie(cx_value, cx_bdim) = unwrapTensorAtLevel(cx.value(), cur_level); + } + std::optional grad_output_value; + std::optional grad_output_bdim; + if (grad_output) { + std::tie(grad_output_value, grad_output_bdim) = unwrapTensorAtLevel(grad_output.value(), cur_level); + } + std::optional grad_hy_value; + std::optional grad_hy_bdim; + if (grad_hy) { + std::tie(grad_hy_value, grad_hy_bdim) = 
unwrapTensorAtLevel(grad_hy.value(), cur_level); + } + std::optional grad_cy_value; + std::optional grad_cy_bdim; + if (grad_cy) { + std::tie(grad_cy_value, grad_cy_bdim) = unwrapTensorAtLevel(grad_cy.value(), cur_level); + } + std::optional dropout_state_value; + std::optional dropout_state_bdim; + if (dropout_state) { + std::tie(dropout_state_value, dropout_state_bdim) = unwrapTensorAtLevel(dropout_state.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, weight, weight_stride0, weight_buf_value, weight_buf_bdim, hx_value, hx_bdim, cx_value, cx_bdim, output_value, output_bdim, grad_output_value, grad_output_bdim, grad_hy_value, grad_hy_bdim, grad_cy_value, grad_cy_bdim, mode, hidden_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state_value, dropout_state_bdim, reserve_value, reserve_bdim, output_mask); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level), makeBatchedVector(std::get<6>(results), std::get<7>(results), cur_level)); +} +template +at::Tensor mm_generated_plumbing(const at::Tensor & self, const at::Tensor & mat2) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(mat2, cur_level)) { + return at::_ops::mm::call(self, mat2); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [mat2_value, mat2_bdim] = unwrapTensorAtLevel(mat2, cur_level); + auto results = batch_rule(self_value, self_bdim, mat2_value, mat2_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor mm_dtype_generated_plumbing(const at::Tensor 
& self, const at::Tensor & mat2, at::ScalarType out_dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(mat2, cur_level)) { + return at::_ops::mm_dtype::call(self, mat2, out_dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [mat2_value, mat2_bdim] = unwrapTensorAtLevel(mat2, cur_level); + auto results = batch_rule(self_value, self_bdim, mat2_value, mat2_bdim, out_dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _int_mm_generated_plumbing(const at::Tensor & self, const at::Tensor & mat2) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(mat2, cur_level)) { + return at::_ops::_int_mm::call(self, mat2); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [mat2_value, mat2_bdim] = unwrapTensorAtLevel(mat2, cur_level); + auto results = batch_rule(self_value, self_bdim, mat2_value, mat2_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _convert_weight_to_int4pack_generated_plumbing(const at::Tensor & self, int64_t innerKTiles) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_convert_weight_to_int4pack::call(self, innerKTiles); + } + auto [self_value, self_bdim] = 
unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, innerKTiles); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _weight_int4pack_mm_generated_plumbing(const at::Tensor & self, const at::Tensor & mat2, int64_t qGroupSize, const at::Tensor & qScaleAndZeros) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(mat2, cur_level) && !isBatchedAtLevel(qScaleAndZeros, cur_level)) { + return at::_ops::_weight_int4pack_mm::call(self, mat2, qGroupSize, qScaleAndZeros); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [mat2_value, mat2_bdim] = unwrapTensorAtLevel(mat2, cur_level); + auto [qScaleAndZeros_value, qScaleAndZeros_bdim] = unwrapTensorAtLevel(qScaleAndZeros, cur_level); + auto results = batch_rule(self_value, self_bdim, mat2_value, mat2_bdim, qGroupSize, qScaleAndZeros_value, qScaleAndZeros_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _weight_int4pack_mm_with_scales_and_zeros_generated_plumbing(const at::Tensor & self, const at::Tensor & mat2, int64_t qGroupSize, const at::Tensor & qScale, const at::Tensor & qZeros) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(mat2, cur_level) && !isBatchedAtLevel(qScale, cur_level) && !isBatchedAtLevel(qZeros, cur_level)) { + return at::_ops::_weight_int4pack_mm_with_scales_and_zeros::call(self, mat2, qGroupSize, qScale, qZeros); + } + auto [self_value, self_bdim] = 
unwrapTensorAtLevel(self, cur_level); + auto [mat2_value, mat2_bdim] = unwrapTensorAtLevel(mat2, cur_level); + auto [qScale_value, qScale_bdim] = unwrapTensorAtLevel(qScale, cur_level); + auto [qZeros_value, qZeros_bdim] = unwrapTensorAtLevel(qZeros, cur_level); + auto results = batch_rule(self_value, self_bdim, mat2_value, mat2_bdim, qGroupSize, qScale_value, qScale_bdim, qZeros_value, qZeros_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _convert_weight_to_int4pack_for_cpu_generated_plumbing(const at::Tensor & self, int64_t innerKTiles) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_convert_weight_to_int4pack_for_cpu::call(self, innerKTiles); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, innerKTiles); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _weight_int4pack_mm_for_cpu_generated_plumbing(const at::Tensor & self, const at::Tensor & mat2, int64_t qGroupSize, const at::Tensor & qScaleAndZeros) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(mat2, cur_level) && !isBatchedAtLevel(qScaleAndZeros, cur_level)) { + return at::_ops::_weight_int4pack_mm_for_cpu::call(self, mat2, qGroupSize, qScaleAndZeros); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [mat2_value, mat2_bdim] = unwrapTensorAtLevel(mat2, cur_level); + auto [qScaleAndZeros_value, 
qScaleAndZeros_bdim] = unwrapTensorAtLevel(qScaleAndZeros, cur_level); + auto results = batch_rule(self_value, self_bdim, mat2_value, mat2_bdim, qGroupSize, qScaleAndZeros_value, qScaleAndZeros_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _dyn_quant_pack_4bit_weight_generated_plumbing(const at::Tensor & weights, const at::Tensor & scales_zeros, const ::std::optional & bias, int64_t block_size, int64_t in_features, int64_t out_features) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(weights, cur_level) && !isBatchedAtLevel(scales_zeros, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::_dyn_quant_pack_4bit_weight::call(weights, scales_zeros, bias, block_size, in_features, out_features); + } + auto [weights_value, weights_bdim] = unwrapTensorAtLevel(weights, cur_level); + auto [scales_zeros_value, scales_zeros_bdim] = unwrapTensorAtLevel(scales_zeros, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(weights_value, weights_bdim, scales_zeros_value, scales_zeros_bdim, bias_value, bias_bdim, block_size, in_features, out_features); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _dyn_quant_matmul_4bit_generated_plumbing(const at::Tensor & inp, const at::Tensor & packed_weights, int64_t block_size, int64_t in_features, int64_t out_features) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if 
(!isBatchedAtLevel(inp, cur_level) && !isBatchedAtLevel(packed_weights, cur_level)) { + return at::_ops::_dyn_quant_matmul_4bit::call(inp, packed_weights, block_size, in_features, out_features); + } + auto [inp_value, inp_bdim] = unwrapTensorAtLevel(inp, cur_level); + auto [packed_weights_value, packed_weights_bdim] = unwrapTensorAtLevel(packed_weights, cur_level); + auto results = batch_rule(inp_value, inp_bdim, packed_weights_value, packed_weights_bdim, block_size, in_features, out_features); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _weight_int8pack_mm_generated_plumbing(const at::Tensor & self, const at::Tensor & mat2, const at::Tensor & scales) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(mat2, cur_level) && !isBatchedAtLevel(scales, cur_level)) { + return at::_ops::_weight_int8pack_mm::call(self, mat2, scales); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [mat2_value, mat2_bdim] = unwrapTensorAtLevel(mat2, cur_level); + auto [scales_value, scales_bdim] = unwrapTensorAtLevel(scales, cur_level); + auto results = batch_rule(self_value, self_bdim, mat2_value, mat2_bdim, scales_value, scales_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _sparse_mm_generated_plumbing(const at::Tensor & sparse, const at::Tensor & dense) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(sparse, cur_level) && !isBatchedAtLevel(dense, cur_level)) { + return 
at::_ops::_sparse_mm::call(sparse, dense); + } + auto [sparse_value, sparse_bdim] = unwrapTensorAtLevel(sparse, cur_level); + auto [dense_value, dense_bdim] = unwrapTensorAtLevel(dense, cur_level); + auto results = batch_rule(sparse_value, sparse_bdim, dense_value, dense_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _sparse_mm_reduce_generated_plumbing(const at::Tensor & sparse, const at::Tensor & dense, c10::string_view reduce) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(sparse, cur_level) && !isBatchedAtLevel(dense, cur_level)) { + return at::_ops::_sparse_mm_reduce::call(sparse, dense, reduce); + } + auto [sparse_value, sparse_bdim] = unwrapTensorAtLevel(sparse, cur_level); + auto [dense_value, dense_bdim] = unwrapTensorAtLevel(dense, cur_level); + auto results = batch_rule(sparse_value, sparse_bdim, dense_value, dense_bdim, reduce); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _sparse_sparse_matmul_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::_sparse_sparse_matmul::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} 
+template +::std::tuple mode_generated_plumbing(const at::Tensor & self, int64_t dim, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::mode::call(self, dim, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, keepdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple mode_dimname_generated_plumbing(const at::Tensor & self, at::Dimname dim, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::mode_dimname::call(self, dim, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, keepdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor mul_Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::mul_Tensor::call(self, other); + } + auto [self_value, self_bdim] = 
unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & mul__Tensor_generated_plumbing(at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::mul__Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim); + return self; +} +template +at::Tensor mul_Scalar_generated_plumbing(const at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::mul_Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, other); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & mul__Scalar_generated_plumbing(at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return 
at::_ops::mul__Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, other); + return self; +} +template +at::Tensor multiply_Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::multiply_Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & multiply__Tensor_generated_plumbing(at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::multiply__Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim); + return self; +} +template +at::Tensor multiply_Scalar_generated_plumbing(const at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = 
maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::multiply_Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, other); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & multiply__Scalar_generated_plumbing(at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::multiply__Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, other); + return self; +} +template +at::Tensor mv_generated_plumbing(const at::Tensor & self, const at::Tensor & vec) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(vec, cur_level)) { + return at::_ops::mv::call(self, vec); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [vec_value, vec_bdim] = unwrapTensorAtLevel(vec, cur_level); + auto results = batch_rule(self_value, self_bdim, vec_value, vec_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor mvlgamma_generated_plumbing(const at::Tensor & self, int64_t p) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, 
cur_level)) { + return at::_ops::mvlgamma::call(self, p); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, p); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & mvlgamma__generated_plumbing(at::Tensor & self, int64_t p) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::mvlgamma_::call(self, p); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, p); + return self; +} +template +at::Tensor narrow_copy_generated_plumbing(const at::Tensor & self, int64_t dim, c10::SymInt start, c10::SymInt length) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::narrow_copy::call(self, dim, start, length); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, start, length); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor narrow_generated_plumbing(const at::Tensor & self, int64_t dim, c10::SymInt start, c10::SymInt length) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::narrow::call(self, dim, start, length); + } + auto [self_value, self_bdim] = 
unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, start, length); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor narrow_Tensor_generated_plumbing(const at::Tensor & self, int64_t dim, const at::Tensor & start, c10::SymInt length) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(start, cur_level)) { + return at::_ops::narrow_Tensor::call(self, dim, start, length); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [start_value, start_bdim] = unwrapTensorAtLevel(start, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, start_value, start_bdim, length); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple native_batch_norm_generated_plumbing(const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const ::std::optional & running_mean, const ::std::optional & running_var, bool training, double momentum, double eps) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level) && !isBatchedAtLevel(running_mean, cur_level) && !isBatchedAtLevel(running_var, cur_level)) { + return at::_ops::native_batch_norm::call(input, weight, bias, running_mean, running_var, training, momentum, eps); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + std::optional weight_value; + std::optional weight_bdim; + if 
(weight) { + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight.value(), cur_level); + } + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + std::optional running_mean_value; + std::optional running_mean_bdim; + if (running_mean) { + std::tie(running_mean_value, running_mean_bdim) = unwrapTensorAtLevel(running_mean.value(), cur_level); + } + std::optional running_var_value; + std::optional running_var_bdim; + if (running_var) { + std::tie(running_var_value, running_var_bdim) = unwrapTensorAtLevel(running_var.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, weight_value, weight_bdim, bias_value, bias_bdim, running_mean_value, running_mean_bdim, running_var_value, running_var_bdim, training, momentum, eps); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +::std::tuple _native_batch_norm_legit_no_training_generated_plumbing(const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const at::Tensor & running_mean, const at::Tensor & running_var, double momentum, double eps) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level) && !isBatchedAtLevel(running_mean, cur_level) && !isBatchedAtLevel(running_var, cur_level)) { + return at::_ops::_native_batch_norm_legit_no_training::call(input, weight, bias, running_mean, running_var, momentum, eps); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, 
cur_level); + auto [running_mean_value, running_mean_bdim] = unwrapTensorAtLevel(running_mean, cur_level); + auto [running_var_value, running_var_bdim] = unwrapTensorAtLevel(running_var, cur_level); + std::optional weight_value; + std::optional weight_bdim; + if (weight) { + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight.value(), cur_level); + } + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, weight_value, weight_bdim, bias_value, bias_bdim, running_mean_value, running_mean_bdim, running_var_value, running_var_bdim, momentum, eps); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +::std::tuple _native_batch_norm_legit_no_stats_generated_plumbing(const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, bool training, double momentum, double eps) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::_native_batch_norm_legit_no_stats::call(input, weight, bias, training, momentum, eps); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + std::optional weight_value; + std::optional weight_bdim; + if (weight) { + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight.value(), cur_level); + } + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), 
cur_level); + } + auto results = batch_rule(input_value, input_bdim, weight_value, weight_bdim, bias_value, bias_bdim, training, momentum, eps); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +::std::tuple batch_norm_stats_generated_plumbing(const at::Tensor & input, double eps) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level)) { + return at::_ops::batch_norm_stats::call(input, eps); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto results = batch_rule(input_value, input_bdim, eps); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor batch_norm_elemt_generated_plumbing(const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const at::Tensor & mean, const at::Tensor & invstd, double eps) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level) && !isBatchedAtLevel(mean, cur_level) && !isBatchedAtLevel(invstd, cur_level)) { + return at::_ops::batch_norm_elemt::call(input, weight, bias, mean, invstd, eps); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [mean_value, mean_bdim] = unwrapTensorAtLevel(mean, cur_level); + auto [invstd_value, 
invstd_bdim] = unwrapTensorAtLevel(invstd, cur_level); + std::optional weight_value; + std::optional weight_bdim; + if (weight) { + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight.value(), cur_level); + } + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, weight_value, weight_bdim, bias_value, bias_bdim, mean_value, mean_bdim, invstd_value, invstd_bdim, eps); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple batch_norm_gather_stats_generated_plumbing(const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & running_mean, const ::std::optional & running_var, double momentum, double eps, int64_t count) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(mean, cur_level) && !isBatchedAtLevel(invstd, cur_level) && !isBatchedAtLevel(running_mean, cur_level) && !isBatchedAtLevel(running_var, cur_level)) { + return at::_ops::batch_norm_gather_stats::call(input, mean, invstd, running_mean, running_var, momentum, eps, count); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [mean_value, mean_bdim] = unwrapTensorAtLevel(mean, cur_level); + auto [invstd_value, invstd_bdim] = unwrapTensorAtLevel(invstd, cur_level); + std::optional running_mean_value; + std::optional running_mean_bdim; + if (running_mean) { + std::tie(running_mean_value, running_mean_bdim) = unwrapTensorAtLevel(running_mean.value(), cur_level); + } + std::optional running_var_value; + std::optional running_var_bdim; + if (running_var) { + std::tie(running_var_value, 
running_var_bdim) = unwrapTensorAtLevel(running_var.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, mean_value, mean_bdim, invstd_value, invstd_bdim, running_mean_value, running_mean_bdim, running_var_value, running_var_bdim, momentum, eps, count); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple batch_norm_gather_stats_with_counts_generated_plumbing(const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & running_mean, const ::std::optional & running_var, double momentum, double eps, const at::Tensor & counts) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(mean, cur_level) && !isBatchedAtLevel(invstd, cur_level) && !isBatchedAtLevel(running_mean, cur_level) && !isBatchedAtLevel(running_var, cur_level) && !isBatchedAtLevel(counts, cur_level)) { + return at::_ops::batch_norm_gather_stats_with_counts::call(input, mean, invstd, running_mean, running_var, momentum, eps, counts); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [mean_value, mean_bdim] = unwrapTensorAtLevel(mean, cur_level); + auto [invstd_value, invstd_bdim] = unwrapTensorAtLevel(invstd, cur_level); + auto [counts_value, counts_bdim] = unwrapTensorAtLevel(counts, cur_level); + std::optional running_mean_value; + std::optional running_mean_bdim; + if (running_mean) { + std::tie(running_mean_value, running_mean_bdim) = unwrapTensorAtLevel(running_mean.value(), cur_level); + } + std::optional running_var_value; + std::optional running_var_bdim; + if (running_var) { + std::tie(running_var_value, running_var_bdim) = 
unwrapTensorAtLevel(running_var.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, mean_value, mean_bdim, invstd_value, invstd_bdim, running_mean_value, running_mean_bdim, running_var_value, running_var_bdim, momentum, eps, counts_value, counts_bdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple native_batch_norm_backward_generated_plumbing(const at::Tensor & grad_out, const at::Tensor & input, const ::std::optional & weight, const ::std::optional & running_mean, const ::std::optional & running_var, const ::std::optional & save_mean, const ::std::optional & save_invstd, bool train, double eps, ::std::array output_mask) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_out, cur_level) && !isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(running_mean, cur_level) && !isBatchedAtLevel(running_var, cur_level) && !isBatchedAtLevel(save_mean, cur_level) && !isBatchedAtLevel(save_invstd, cur_level)) { + return at::_ops::native_batch_norm_backward::call(grad_out, input, weight, running_mean, running_var, save_mean, save_invstd, train, eps, output_mask); + } + auto [grad_out_value, grad_out_bdim] = unwrapTensorAtLevel(grad_out, cur_level); + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + std::optional weight_value; + std::optional weight_bdim; + if (weight) { + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight.value(), cur_level); + } + std::optional running_mean_value; + std::optional running_mean_bdim; + if (running_mean) { + std::tie(running_mean_value, running_mean_bdim) = 
unwrapTensorAtLevel(running_mean.value(), cur_level); + } + std::optional running_var_value; + std::optional running_var_bdim; + if (running_var) { + std::tie(running_var_value, running_var_bdim) = unwrapTensorAtLevel(running_var.value(), cur_level); + } + std::optional save_mean_value; + std::optional save_mean_bdim; + if (save_mean) { + std::tie(save_mean_value, save_mean_bdim) = unwrapTensorAtLevel(save_mean.value(), cur_level); + } + std::optional save_invstd_value; + std::optional save_invstd_bdim; + if (save_invstd) { + std::tie(save_invstd_value, save_invstd_bdim) = unwrapTensorAtLevel(save_invstd.value(), cur_level); + } + auto results = batch_rule(grad_out_value, grad_out_bdim, input_value, input_bdim, weight_value, weight_bdim, running_mean_value, running_mean_bdim, running_var_value, running_var_bdim, save_mean_value, save_mean_bdim, save_invstd_value, save_invstd_bdim, train, eps, output_mask); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +::std::tuple batch_norm_backward_reduce_generated_plumbing(const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & weight, bool input_g, bool weight_g, bool bias_g) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_out, cur_level) && !isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(mean, cur_level) && !isBatchedAtLevel(invstd, cur_level) && !isBatchedAtLevel(weight, cur_level)) { + return at::_ops::batch_norm_backward_reduce::call(grad_out, input, mean, invstd, weight, input_g, weight_g, bias_g); + } + auto [grad_out_value, grad_out_bdim] 
= unwrapTensorAtLevel(grad_out, cur_level); + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [mean_value, mean_bdim] = unwrapTensorAtLevel(mean, cur_level); + auto [invstd_value, invstd_bdim] = unwrapTensorAtLevel(invstd, cur_level); + std::optional weight_value; + std::optional weight_bdim; + if (weight) { + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight.value(), cur_level); + } + auto results = batch_rule(grad_out_value, grad_out_bdim, input_value, input_bdim, mean_value, mean_bdim, invstd_value, invstd_bdim, weight_value, weight_bdim, input_g, weight_g, bias_g); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level), makeBatched(std::get<6>(results), std::get<7>(results), cur_level)); +} +template +at::Tensor batch_norm_backward_elemt_generated_plumbing(const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & weight, const at::Tensor & sum_dy, const at::Tensor & sum_dy_xmu, const at::Tensor & count) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_out, cur_level) && !isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(mean, cur_level) && !isBatchedAtLevel(invstd, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(sum_dy, cur_level) && !isBatchedAtLevel(sum_dy_xmu, cur_level) && !isBatchedAtLevel(count, cur_level)) { + return at::_ops::batch_norm_backward_elemt::call(grad_out, input, mean, invstd, weight, sum_dy, sum_dy_xmu, count); + } + auto [grad_out_value, grad_out_bdim] = unwrapTensorAtLevel(grad_out, cur_level); + auto [input_value, 
input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [mean_value, mean_bdim] = unwrapTensorAtLevel(mean, cur_level); + auto [invstd_value, invstd_bdim] = unwrapTensorAtLevel(invstd, cur_level); + auto [sum_dy_value, sum_dy_bdim] = unwrapTensorAtLevel(sum_dy, cur_level); + auto [sum_dy_xmu_value, sum_dy_xmu_bdim] = unwrapTensorAtLevel(sum_dy_xmu, cur_level); + auto [count_value, count_bdim] = unwrapTensorAtLevel(count, cur_level); + std::optional weight_value; + std::optional weight_bdim; + if (weight) { + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight.value(), cur_level); + } + auto results = batch_rule(grad_out_value, grad_out_bdim, input_value, input_bdim, mean_value, mean_bdim, invstd_value, invstd_bdim, weight_value, weight_bdim, sum_dy_value, sum_dy_bdim, sum_dy_xmu_value, sum_dy_xmu_bdim, count_value, count_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple batch_norm_update_stats_generated_plumbing(const at::Tensor & input, const ::std::optional & running_mean, const ::std::optional & running_var, double momentum) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(running_mean, cur_level) && !isBatchedAtLevel(running_var, cur_level)) { + return at::_ops::batch_norm_update_stats::call(input, running_mean, running_var, momentum); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + std::optional running_mean_value; + std::optional running_mean_bdim; + if (running_mean) { + std::tie(running_mean_value, running_mean_bdim) = unwrapTensorAtLevel(running_mean.value(), cur_level); + } + std::optional running_var_value; + std::optional running_var_bdim; + if (running_var) { + std::tie(running_var_value, running_var_bdim) 
= unwrapTensorAtLevel(running_var.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, running_mean_value, running_mean_bdim, running_var_value, running_var_bdim, momentum); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor _nnpack_spatial_convolution_generated_plumbing(const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::_nnpack_spatial_convolution::call(input, weight, bias, padding, stride); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, weight_value, weight_bdim, bias_value, bias_bdim, padding, stride); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor ones_like_generated_plumbing(const at::Tensor & self, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, 
cur_level)) { + return at::_ops::ones_like::call(self, dtype, layout, device, pin_memory, memory_format); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dtype, layout, device, pin_memory, memory_format); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor pairwise_distance_generated_plumbing(const at::Tensor & x1, const at::Tensor & x2, double p, double eps, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(x1, cur_level) && !isBatchedAtLevel(x2, cur_level)) { + return at::_ops::pairwise_distance::call(x1, x2, p, eps, keepdim); + } + auto [x1_value, x1_bdim] = unwrapTensorAtLevel(x1, cur_level); + auto [x2_value, x2_bdim] = unwrapTensorAtLevel(x2, cur_level); + auto results = batch_rule(x1_value, x1_bdim, x2_value, x2_bdim, p, eps, keepdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor cdist_generated_plumbing(const at::Tensor & x1, const at::Tensor & x2, double p, ::std::optional compute_mode) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(x1, cur_level) && !isBatchedAtLevel(x2, cur_level)) { + return at::_ops::cdist::call(x1, x2, p, compute_mode); + } + auto [x1_value, x1_bdim] = unwrapTensorAtLevel(x1, cur_level); + auto [x2_value, x2_bdim] = unwrapTensorAtLevel(x2, cur_level); + auto results = batch_rule(x1_value, x1_bdim, x2_value, x2_bdim, p, compute_mode); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor 
_euclidean_dist_generated_plumbing(const at::Tensor & x1, const at::Tensor & x2) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(x1, cur_level) && !isBatchedAtLevel(x2, cur_level)) { + return at::_ops::_euclidean_dist::call(x1, x2); + } + auto [x1_value, x1_bdim] = unwrapTensorAtLevel(x1, cur_level); + auto [x2_value, x2_bdim] = unwrapTensorAtLevel(x2, cur_level); + auto results = batch_rule(x1_value, x1_bdim, x2_value, x2_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _cdist_forward_generated_plumbing(const at::Tensor & x1, const at::Tensor & x2, double p, ::std::optional compute_mode) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(x1, cur_level) && !isBatchedAtLevel(x2, cur_level)) { + return at::_ops::_cdist_forward::call(x1, x2, p, compute_mode); + } + auto [x1_value, x1_bdim] = unwrapTensorAtLevel(x1, cur_level); + auto [x2_value, x2_bdim] = unwrapTensorAtLevel(x2, cur_level); + auto results = batch_rule(x1_value, x1_bdim, x2_value, x2_bdim, p, compute_mode); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _cdist_backward_generated_plumbing(const at::Tensor & grad, const at::Tensor & x1, const at::Tensor & x2, double p, const at::Tensor & cdist) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad, cur_level) && !isBatchedAtLevel(x1, cur_level) 
&& !isBatchedAtLevel(x2, cur_level) && !isBatchedAtLevel(cdist, cur_level)) { + return at::_ops::_cdist_backward::call(grad, x1, x2, p, cdist); + } + auto [grad_value, grad_bdim] = unwrapTensorAtLevel(grad, cur_level); + auto [x1_value, x1_bdim] = unwrapTensorAtLevel(x1, cur_level); + auto [x2_value, x2_bdim] = unwrapTensorAtLevel(x2, cur_level); + auto [cdist_value, cdist_bdim] = unwrapTensorAtLevel(cdist, cur_level); + auto results = batch_rule(grad_value, grad_bdim, x1_value, x1_bdim, x2_value, x2_bdim, p, cdist_value, cdist_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor pdist_generated_plumbing(const at::Tensor & self, double p) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::pdist::call(self, p); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, p); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _pdist_forward_generated_plumbing(const at::Tensor & self, double p) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_pdist_forward::call(self, p); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, p); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _pdist_backward_generated_plumbing(const at::Tensor & grad, const at::Tensor & self, double p, const at::Tensor & pdist) { + 
c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad, cur_level) && !isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(pdist, cur_level)) { + return at::_ops::_pdist_backward::call(grad, self, p, pdist); + } + auto [grad_value, grad_bdim] = unwrapTensorAtLevel(grad, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [pdist_value, pdist_bdim] = unwrapTensorAtLevel(pdist, cur_level); + auto results = batch_rule(grad_value, grad_bdim, self_value, self_bdim, p, pdist_value, pdist_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor cosine_similarity_generated_plumbing(const at::Tensor & x1, const at::Tensor & x2, int64_t dim, double eps) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(x1, cur_level) && !isBatchedAtLevel(x2, cur_level)) { + return at::_ops::cosine_similarity::call(x1, x2, dim, eps); + } + auto [x1_value, x1_bdim] = unwrapTensorAtLevel(x1, cur_level); + auto [x2_value, x2_bdim] = unwrapTensorAtLevel(x2, cur_level); + auto results = batch_rule(x1_value, x1_bdim, x2_value, x2_bdim, dim, eps); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor permute_generated_plumbing(const at::Tensor & self, at::IntArrayRef dims) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return 
at::_ops::permute::call(self, dims); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dims); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor movedim_intlist_generated_plumbing(const at::Tensor & self, at::IntArrayRef source, at::IntArrayRef destination) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::movedim_intlist::call(self, source, destination); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, source, destination); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor movedim_int_generated_plumbing(const at::Tensor & self, int64_t source, int64_t destination) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::movedim_int::call(self, source, destination); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, source, destination); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor moveaxis_intlist_generated_plumbing(const at::Tensor & self, at::IntArrayRef source, at::IntArrayRef destination) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + 
if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::moveaxis_intlist::call(self, source, destination); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, source, destination); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor moveaxis_int_generated_plumbing(const at::Tensor & self, int64_t source, int64_t destination) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::moveaxis_int::call(self, source, destination); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, source, destination); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor numpy_T_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::numpy_T::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor matrix_H_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return 
at::_ops::matrix_H::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor mT_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::mT::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor mH_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::mH::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor adjoint_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::adjoint::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} 
+template +at::Tensor pixel_shuffle_generated_plumbing(const at::Tensor & self, int64_t upscale_factor) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::pixel_shuffle::call(self, upscale_factor); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, upscale_factor); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor pixel_unshuffle_generated_plumbing(const at::Tensor & self, int64_t downscale_factor) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::pixel_unshuffle::call(self, downscale_factor); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, downscale_factor); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor channel_shuffle_generated_plumbing(const at::Tensor & self, c10::SymInt groups) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::channel_shuffle::call(self, groups); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, groups); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor 
native_channel_shuffle_generated_plumbing(const at::Tensor & self, c10::SymInt groups) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::native_channel_shuffle::call(self, groups); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, groups); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor pin_memory_generated_plumbing(const at::Tensor & self, ::std::optional device) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::pin_memory::call(self, device); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, device); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _pin_memory_generated_plumbing(const at::Tensor & self, ::std::optional device) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_pin_memory::call(self, device); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, device); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor pinverse_generated_plumbing(const at::Tensor & self, double 
rcond) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::pinverse::call(self, rcond); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, rcond); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor poisson_nll_loss_generated_plumbing(const at::Tensor & input, const at::Tensor & target, bool log_input, bool full, double eps, int64_t reduction) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(target, cur_level)) { + return at::_ops::poisson_nll_loss::call(input, target, log_input, full, eps, reduction); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [target_value, target_bdim] = unwrapTensorAtLevel(target, cur_level); + auto results = batch_rule(input_value, input_bdim, target_value, target_bdim, log_input, full, eps, reduction); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor rad2deg_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::rad2deg::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return 
makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & rad2deg__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::rad2deg_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor deg2rad_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::deg2rad::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & deg2rad__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::deg2rad_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor rand_like_generated_plumbing(const at::Tensor & self, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto 
maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::rand_like::call(self, dtype, layout, device, pin_memory, memory_format); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dtype, layout, device, pin_memory, memory_format); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor rand_like_generator_generated_plumbing(const at::Tensor & self, ::std::optional generator, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::rand_like_generator::call(self, generator, dtype, layout, device, pin_memory, memory_format); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, generator, dtype, layout, device, pin_memory, memory_format); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor randint_like_generated_plumbing(const at::Tensor & self, c10::SymInt high, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::randint_like::call(self, high, dtype, layout, 
device, pin_memory, memory_format); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, high, dtype, layout, device, pin_memory, memory_format); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor randint_like_generator_generated_plumbing(const at::Tensor & self, c10::SymInt high, ::std::optional generator, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::randint_like_generator::call(self, high, generator, dtype, layout, device, pin_memory, memory_format); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, high, generator, dtype, layout, device, pin_memory, memory_format); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor randint_like_Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & high, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(high, cur_level)) { + return at::_ops::randint_like_Tensor::call(self, high, dtype, layout, device, pin_memory, memory_format); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [high_value, high_bdim] = 
unwrapTensorAtLevel(high, cur_level); + auto results = batch_rule(self_value, self_bdim, high_value, high_bdim, dtype, layout, device, pin_memory, memory_format); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor randint_like_Tensor_generator_generated_plumbing(const at::Tensor & self, const at::Tensor & high, ::std::optional generator, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(high, cur_level)) { + return at::_ops::randint_like_Tensor_generator::call(self, high, generator, dtype, layout, device, pin_memory, memory_format); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [high_value, high_bdim] = unwrapTensorAtLevel(high, cur_level); + auto results = batch_rule(self_value, self_bdim, high_value, high_bdim, generator, dtype, layout, device, pin_memory, memory_format); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor randint_like_low_dtype_generated_plumbing(const at::Tensor & self, c10::SymInt low, c10::SymInt high, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::randint_like_low_dtype::call(self, low, high, dtype, layout, device, pin_memory, memory_format); + } + auto [self_value, self_bdim] 
= unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, low, high, dtype, layout, device, pin_memory, memory_format); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor randint_like_low_generator_dtype_generated_plumbing(const at::Tensor & self, c10::SymInt low, c10::SymInt high, ::std::optional generator, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::randint_like_low_generator_dtype::call(self, low, high, generator, dtype, layout, device, pin_memory, memory_format); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, low, high, generator, dtype, layout, device, pin_memory, memory_format); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor randn_like_generated_plumbing(const at::Tensor & self, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::randn_like::call(self, dtype, layout, device, pin_memory, memory_format); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dtype, layout, device, pin_memory, memory_format); + return 
makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor randn_like_generator_generated_plumbing(const at::Tensor & self, ::std::optional generator, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::randn_like_generator::call(self, generator, dtype, layout, device, pin_memory, memory_format); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, generator, dtype, layout, device, pin_memory, memory_format); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor ravel_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::ravel::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor reciprocal_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::reciprocal::call(self); + } + auto [self_value, self_bdim] = 
unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & reciprocal__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::reciprocal_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor neg_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::neg::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & neg__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::neg_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor negative_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + 
vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::negative::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & negative__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::negative_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor repeat_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef repeats) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::repeat::call(self, repeats); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, repeats); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor repeat_interleave_Tensor_generated_plumbing(const at::Tensor & repeats, ::std::optional output_size) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(repeats, cur_level)) { + return 
at::_ops::repeat_interleave_Tensor::call(repeats, output_size); + } + auto [repeats_value, repeats_bdim] = unwrapTensorAtLevel(repeats, cur_level); + auto results = batch_rule(repeats_value, repeats_bdim, output_size); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor repeat_interleave_self_Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & repeats, ::std::optional dim, ::std::optional output_size) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(repeats, cur_level)) { + return at::_ops::repeat_interleave_self_Tensor::call(self, repeats, dim, output_size); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [repeats_value, repeats_bdim] = unwrapTensorAtLevel(repeats, cur_level); + auto results = batch_rule(self_value, self_bdim, repeats_value, repeats_bdim, dim, output_size); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor repeat_interleave_self_int_generated_plumbing(const at::Tensor & self, c10::SymInt repeats, ::std::optional dim, ::std::optional output_size) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::repeat_interleave_self_int::call(self, repeats, dim, output_size); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, repeats, dim, output_size); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor 
reshape_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef shape) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::reshape::call(self, shape); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, shape); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _reshape_copy_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef size) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_reshape_copy::call(self, size); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, size); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _reshape_alias_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_reshape_alias::call(self, size, stride); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, size, stride); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor 
_mkldnn_reshape_generated_plumbing(const at::Tensor & self, at::IntArrayRef shape) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_mkldnn_reshape::call(self, shape); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, shape); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor reshape_as_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::reshape_as::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor round_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::round::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template 
+at::Tensor & round__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::round_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor round_decimals_generated_plumbing(const at::Tensor & self, int64_t decimals) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::round_decimals::call(self, decimals); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, decimals); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & round__decimals_generated_plumbing(at::Tensor & self, int64_t decimals) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::round__decimals::call(self, decimals); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, decimals); + return self; +} +template +at::Tensor rrelu_generated_plumbing(const at::Tensor & self, const at::Scalar & lower, const at::Scalar & upper, bool training, ::std::optional generator) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = 
maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::rrelu::call(self, lower, upper, training, generator); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, lower, upper, training, generator); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & rrelu__generated_plumbing(at::Tensor & self, const at::Scalar & lower, const at::Scalar & upper, bool training, ::std::optional generator) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::rrelu_::call(self, lower, upper, training, generator); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, lower, upper, training, generator); + return self; +} +template +at::Tensor relu_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::relu::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & relu__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, 
"gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::relu_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor relu6_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::relu6::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & relu6__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::relu6_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor prelu_generated_plumbing(const at::Tensor & self, const at::Tensor & weight) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(weight, cur_level)) { + return at::_ops::prelu::call(self, weight); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, 
cur_level); + auto results = batch_rule(self_value, self_bdim, weight_value, weight_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _prelu_kernel_generated_plumbing(const at::Tensor & self, const at::Tensor & weight) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(weight, cur_level)) { + return at::_ops::_prelu_kernel::call(self, weight); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + auto results = batch_rule(self_value, self_bdim, weight_value, weight_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple _prelu_kernel_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & weight) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(weight, cur_level)) { + return at::_ops::_prelu_kernel_backward::call(grad_output, self, weight); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, self_value, self_bdim, weight_value, weight_bdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), 
makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor & gelu__generated_plumbing(at::Tensor & self, c10::string_view approximate) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::gelu_::call(self, approximate); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, approximate); + return self; +} +template +at::Tensor gelu_generated_plumbing(const at::Tensor & self, c10::string_view approximate) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::gelu::call(self, approximate); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, approximate); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor gelu_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & self, c10::string_view approximate) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(self, cur_level)) { + return at::_ops::gelu_backward::call(grad_output, self, approximate); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results 
= batch_rule(grad_output_value, grad_output_bdim, self_value, self_bdim, approximate); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor infinitely_differentiable_gelu_backward_generated_plumbing(const at::Tensor & grad, const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad, cur_level) && !isBatchedAtLevel(self, cur_level)) { + return at::_ops::infinitely_differentiable_gelu_backward::call(grad, self); + } + auto [grad_value, grad_bdim] = unwrapTensorAtLevel(grad, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(grad_value, grad_bdim, self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor hardshrink_generated_plumbing(const at::Tensor & self, const at::Scalar & lambd) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::hardshrink::call(self, lambd); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, lambd); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor hardshrink_backward_generated_plumbing(const at::Tensor & grad_out, const at::Tensor & self, const at::Scalar & lambd) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if 
(!isBatchedAtLevel(grad_out, cur_level) && !isBatchedAtLevel(self, cur_level)) { + return at::_ops::hardshrink_backward::call(grad_out, self, lambd); + } + auto [grad_out_value, grad_out_bdim] = unwrapTensorAtLevel(grad_out, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(grad_out_value, grad_out_bdim, self_value, self_bdim, lambd); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor rsqrt_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::rsqrt::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & rsqrt__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::rsqrt_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor select_Dimname_generated_plumbing(const at::Tensor & self, at::Dimname dim, int64_t index) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return 
at::_ops::select_Dimname::call(self, dim, index); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, index); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor select_int_generated_plumbing(const at::Tensor & self, int64_t dim, c10::SymInt index) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::select_int::call(self, dim, index); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, index); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor select_backward_generated_plumbing(const at::Tensor & grad_output, c10::SymIntArrayRef input_sizes, int64_t dim, c10::SymInt index) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level)) { + return at::_ops::select_backward::call(grad_output, input_sizes, dim, index); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, input_sizes, dim, index); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _nested_select_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & self, int64_t dim, c10::SymInt index) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + 
vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(self, cur_level)) { + return at::_ops::_nested_select_backward::call(grad_output, self, dim, index); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, self_value, self_bdim, dim, index); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor selu_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::selu::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & selu__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::selu_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor celu_generated_plumbing(const at::Tensor & self, const at::Scalar & alpha) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); 
+ int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::celu::call(self, alpha); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, alpha); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & celu__generated_plumbing(at::Tensor & self, const at::Scalar & alpha) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::celu_::call(self, alpha); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, alpha); + return self; +} +template +at::Tensor silu_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::silu::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & silu__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::silu_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return 
self; +} +template +at::Tensor silu_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(self, cur_level)) { + return at::_ops::silu_backward::call(grad_output, self); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor mish_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::mish::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & mish__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::mish_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor mish_backward_generated_plumbing(const 
at::Tensor & grad_output, const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(self, cur_level)) { + return at::_ops::mish_backward::call(grad_output, self); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor sigmoid_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::sigmoid::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & sigmoid__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::sigmoid_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor logit_generated_plumbing(const at::Tensor & self, ::std::optional eps) { + 
c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::logit::call(self, eps); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, eps); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & logit__generated_plumbing(at::Tensor & self, ::std::optional eps) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::logit_::call(self, eps); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, eps); + return self; +} +template +at::Tensor sin_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::sin::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & sin__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, 
cur_level)) { + return at::_ops::sin_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor sinc_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::sinc::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & sinc__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::sinc_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor sinh_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::sinh::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & sinh__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard 
guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::sinh_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor detach_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::detach::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor slice_Tensor_generated_plumbing(const at::Tensor & self, int64_t dim, ::std::optional start, ::std::optional end, c10::SymInt step) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::slice_Tensor::call(self, dim, start, end, step); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, start, end, step); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor slice_backward_generated_plumbing(const at::Tensor & grad_output, c10::SymIntArrayRef input_sizes, int64_t dim, c10::SymInt start, c10::SymInt end, c10::SymInt step) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto 
maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level)) { + return at::_ops::slice_backward::call(grad_output, input_sizes, dim, start, end, step); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, input_sizes, dim, start, end, step); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor slice_inverse_generated_plumbing(const at::Tensor & self, const at::Tensor & src, int64_t dim, ::std::optional start, ::std::optional end, c10::SymInt step) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(src, cur_level)) { + return at::_ops::slice_inverse::call(self, src, dim, start, end, step); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [src_value, src_bdim] = unwrapTensorAtLevel(src, cur_level); + auto results = batch_rule(self_value, self_bdim, src_value, src_bdim, dim, start, end, step); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor slice_scatter_generated_plumbing(const at::Tensor & self, const at::Tensor & src, int64_t dim, ::std::optional start, ::std::optional end, c10::SymInt step) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(src, cur_level)) { + return at::_ops::slice_scatter::call(self, src, dim, start, end, 
step); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [src_value, src_bdim] = unwrapTensorAtLevel(src, cur_level); + auto results = batch_rule(self_value, self_bdim, src_value, src_bdim, dim, start, end, step); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor select_scatter_generated_plumbing(const at::Tensor & self, const at::Tensor & src, int64_t dim, c10::SymInt index) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(src, cur_level)) { + return at::_ops::select_scatter::call(self, src, dim, index); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [src_value, src_bdim] = unwrapTensorAtLevel(src, cur_level); + auto results = batch_rule(self_value, self_bdim, src_value, src_bdim, dim, index); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor diagonal_scatter_generated_plumbing(const at::Tensor & self, const at::Tensor & src, int64_t offset, int64_t dim1, int64_t dim2) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(src, cur_level)) { + return at::_ops::diagonal_scatter::call(self, src, offset, dim1, dim2); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [src_value, src_bdim] = unwrapTensorAtLevel(src, cur_level); + auto results = batch_rule(self_value, self_bdim, src_value, src_bdim, offset, dim1, dim2); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} 
+template +at::Tensor as_strided_scatter_generated_plumbing(const at::Tensor & self, const at::Tensor & src, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(src, cur_level)) { + return at::_ops::as_strided_scatter::call(self, src, size, stride, storage_offset); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [src_value, src_bdim] = unwrapTensorAtLevel(src, cur_level); + auto results = batch_rule(self_value, self_bdim, src_value, src_bdim, size, stride, storage_offset); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor smm_generated_plumbing(const at::Tensor & self, const at::Tensor & mat2) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(mat2, cur_level)) { + return at::_ops::smm::call(self, mat2); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [mat2_value, mat2_bdim] = unwrapTensorAtLevel(mat2, cur_level); + auto results = batch_rule(self_value, self_bdim, mat2_value, mat2_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor softmax_int_generated_plumbing(const at::Tensor & self, int64_t dim, ::std::optional dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = 
maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::softmax_int::call(self, dim, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor softmax_Dimname_generated_plumbing(const at::Tensor & self, at::Dimname dim, ::std::optional dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::softmax_Dimname::call(self, dim, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _softmax_generated_plumbing(const at::Tensor & self, int64_t dim, bool half_to_float) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_softmax::call(self, dim, half_to_float); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, half_to_float); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _softmax_backward_data_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & output, int64_t dim, at::ScalarType input_dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + 
vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(output, cur_level)) { + return at::_ops::_softmax_backward_data::call(grad_output, output, dim, input_dtype); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [output_value, output_bdim] = unwrapTensorAtLevel(output, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, output_value, output_bdim, dim, input_dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector unsafe_split_Tensor_generated_plumbing(const at::Tensor & self, c10::SymInt split_size, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::unsafe_split_Tensor::call(self, split_size, dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, split_size, dim); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector split_Tensor_generated_plumbing(const at::Tensor & self, c10::SymInt split_size, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::split_Tensor::call(self, split_size, dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, split_size, dim); + return makeBatchedVector(std::get<0>(results), 
std::get<1>(results), cur_level); +} +template +::std::vector split_sizes_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef split_size, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::split_sizes::call(self, split_size, dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, split_size, dim); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector unsafe_split_with_sizes_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef split_sizes, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::unsafe_split_with_sizes::call(self, split_sizes, dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, split_sizes, dim); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector split_with_sizes_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef split_sizes, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::split_with_sizes::call(self, split_sizes, dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto 
results = batch_rule(self_value, self_bdim, split_sizes, dim); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector hsplit_int_generated_plumbing(const at::Tensor & self, int64_t sections) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::hsplit_int::call(self, sections); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, sections); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector hsplit_array_generated_plumbing(const at::Tensor & self, at::IntArrayRef indices) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::hsplit_array::call(self, indices); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, indices); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector vsplit_int_generated_plumbing(const at::Tensor & self, int64_t sections) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::vsplit_int::call(self, sections); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, 
sections); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector vsplit_array_generated_plumbing(const at::Tensor & self, at::IntArrayRef indices) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::vsplit_array::call(self, indices); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, indices); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector dsplit_int_generated_plumbing(const at::Tensor & self, int64_t sections) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::dsplit_int::call(self, sections); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, sections); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector dsplit_array_generated_plumbing(const at::Tensor & self, at::IntArrayRef indices) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::dsplit_array::call(self, indices); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, indices); + return 
makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor squeeze_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::squeeze::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor squeeze_dim_generated_plumbing(const at::Tensor & self, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::squeeze_dim::call(self, dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor squeeze_dimname_generated_plumbing(const at::Tensor & self, at::Dimname dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::squeeze_dimname::call(self, dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor squeeze_dims_generated_plumbing(const 
at::Tensor & self, at::IntArrayRef dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::squeeze_dims::call(self, dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor sspaddmm_generated_plumbing(const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta, const at::Scalar & alpha) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(mat1, cur_level) && !isBatchedAtLevel(mat2, cur_level)) { + return at::_ops::sspaddmm::call(self, mat1, mat2, beta, alpha); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [mat1_value, mat1_bdim] = unwrapTensorAtLevel(mat1, cur_level); + auto [mat2_value, mat2_bdim] = unwrapTensorAtLevel(mat2, cur_level); + auto results = batch_rule(self_value, self_bdim, mat1_value, mat1_bdim, mat2_value, mat2_bdim, beta, alpha); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _chunk_cat_generated_plumbing(at::TensorList tensors, int64_t dim, int64_t num_chunks) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(tensors, cur_level)) { + return 
at::_ops::_chunk_cat::call(tensors, dim, num_chunks); + } + + auto results = batch_rule(tensors, dim, num_chunks); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor stack_generated_plumbing(at::TensorList tensors, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(tensors, cur_level)) { + return at::_ops::stack::call(tensors, dim); + } + + auto results = batch_rule(tensors, dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _stack_generated_plumbing(at::TensorList tensors, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(tensors, cur_level)) { + return at::_ops::_stack::call(tensors, dim); + } + + auto results = batch_rule(tensors, dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor hstack_generated_plumbing(at::TensorList tensors) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(tensors, cur_level)) { + return at::_ops::hstack::call(tensors); + } + + auto results = batch_rule(tensors); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor vstack_generated_plumbing(at::TensorList tensors) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + 
vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(tensors, cur_level)) { + return at::_ops::vstack::call(tensors); + } + + auto results = batch_rule(tensors); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor dstack_generated_plumbing(at::TensorList tensors) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(tensors, cur_level)) { + return at::_ops::dstack::call(tensors); + } + + auto results = batch_rule(tensors); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor stft_generated_plumbing(const at::Tensor & self, int64_t n_fft, ::std::optional hop_length, ::std::optional win_length, const ::std::optional & window, bool normalized, ::std::optional onesided, ::std::optional return_complex, ::std::optional align_to_window) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(window, cur_level)) { + return at::_ops::stft::call(self, n_fft, hop_length, win_length, window, normalized, onesided, return_complex, align_to_window); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + std::optional window_value; + std::optional window_bdim; + if (window) { + std::tie(window_value, window_bdim) = unwrapTensorAtLevel(window.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, n_fft, hop_length, win_length, window_value, window_bdim, normalized, onesided, return_complex, align_to_window); + return 
makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor stft_center_generated_plumbing(const at::Tensor & self, int64_t n_fft, ::std::optional hop_length, ::std::optional win_length, const ::std::optional & window, bool center, c10::string_view pad_mode, bool normalized, ::std::optional onesided, ::std::optional return_complex, ::std::optional align_to_window) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(window, cur_level)) { + return at::_ops::stft_center::call(self, n_fft, hop_length, win_length, window, center, pad_mode, normalized, onesided, return_complex, align_to_window); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + std::optional window_value; + std::optional window_bdim; + if (window) { + std::tie(window_value, window_bdim) = unwrapTensorAtLevel(window.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, n_fft, hop_length, win_length, window_value, window_bdim, center, pad_mode, normalized, onesided, return_complex, align_to_window); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor istft_generated_plumbing(const at::Tensor & self, int64_t n_fft, ::std::optional hop_length, ::std::optional win_length, const ::std::optional & window, bool center, bool normalized, ::std::optional onesided, ::std::optional length, bool return_complex) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(window, cur_level)) { + return at::_ops::istft::call(self, 
n_fft, hop_length, win_length, window, center, normalized, onesided, length, return_complex); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + std::optional window_value; + std::optional window_bdim; + if (window) { + std::tie(window_value, window_bdim) = unwrapTensorAtLevel(window.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, n_fft, hop_length, win_length, window_value, window_bdim, center, normalized, onesided, length, return_complex); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor sum_generated_plumbing(const at::Tensor & self, ::std::optional dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::sum::call(self, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor sum_dim_IntList_generated_plumbing(const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim, ::std::optional dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::sum_dim_IntList::call(self, dim, keepdim, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, keepdim, dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor sum_dim_DimnameList_generated_plumbing(const at::Tensor & 
self, at::DimnameList dim, bool keepdim, ::std::optional dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::sum_dim_DimnameList::call(self, dim, keepdim, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, keepdim, dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _nested_sum_backward_generated_plumbing(const at::Tensor & grad, const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad, cur_level) && !isBatchedAtLevel(self, cur_level)) { + return at::_ops::_nested_sum_backward::call(grad, self, dim, keepdim); + } + auto [grad_value, grad_bdim] = unwrapTensorAtLevel(grad, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(grad_value, grad_bdim, self_value, self_bdim, dim, keepdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor nansum_generated_plumbing(const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim, ::std::optional dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::nansum::call(self, dim, keepdim, dtype); + } + auto [self_value, self_bdim] = 
unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, keepdim, dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor hash_tensor_generated_plumbing(const at::Tensor & self, at::IntArrayRef dim, bool keepdim, int64_t mode) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::hash_tensor::call(self, dim, keepdim, mode); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, keepdim, mode); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor sum_to_size_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef size) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::sum_to_size::call(self, size); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, size); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor sqrt_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::sqrt::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = 
batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & sqrt__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::sqrt_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor square_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::square::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & square__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::square_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor std_generated_plumbing(const at::Tensor & self, bool unbiased) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, 
"gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::std::call(self, unbiased); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, unbiased); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor std_dim_generated_plumbing(const at::Tensor & self, at::OptionalIntArrayRef dim, bool unbiased, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::std_dim::call(self, dim, unbiased, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, unbiased, keepdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor std_correction_generated_plumbing(const at::Tensor & self, at::OptionalIntArrayRef dim, const ::std::optional & correction, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::std_correction::call(self, dim, correction, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, correction, keepdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple std_mean_generated_plumbing(const at::Tensor & self, bool unbiased) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer 
= maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::std_mean::call(self, unbiased); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, unbiased); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple std_mean_dim_generated_plumbing(const at::Tensor & self, at::OptionalIntArrayRef dim, bool unbiased, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::std_mean_dim::call(self, dim, unbiased, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, unbiased, keepdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple std_mean_correction_generated_plumbing(const at::Tensor & self, at::OptionalIntArrayRef dim, const ::std::optional & correction, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::std_mean_correction::call(self, dim, correction, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, correction, keepdim); + 
return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple std_mean_names_dim_generated_plumbing(const at::Tensor & self, at::DimnameList dim, bool unbiased, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::std_mean_names_dim::call(self, dim, unbiased, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, unbiased, keepdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple std_mean_correction_names_generated_plumbing(const at::Tensor & self, at::DimnameList dim, const ::std::optional & correction, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::std_mean_correction_names::call(self, dim, correction, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, correction, keepdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor std_names_dim_generated_plumbing(const at::Tensor & self, at::DimnameList dim, bool unbiased, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard 
guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::std_names_dim::call(self, dim, unbiased, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, unbiased, keepdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor std_correction_names_generated_plumbing(const at::Tensor & self, at::DimnameList dim, const ::std::optional & correction, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::std_correction_names::call(self, dim, correction, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, correction, keepdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor prod_generated_plumbing(const at::Tensor & self, ::std::optional dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::prod::call(self, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor prod_dim_int_generated_plumbing(const at::Tensor & self, int64_t dim, bool 
keepdim, ::std::optional dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::prod_dim_int::call(self, dim, keepdim, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, keepdim, dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor prod_dim_Dimname_generated_plumbing(const at::Tensor & self, at::Dimname dim, bool keepdim, ::std::optional dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::prod_dim_Dimname::call(self, dim, keepdim, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, keepdim, dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor t_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::t::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor tan_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard 
guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::tan::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & tan__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::tan_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor tanh_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::tanh::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & tanh__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::tanh_::call(self); + } + auto [self_value, self_bdim] = 
unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor tensordot_generated_plumbing(const at::Tensor & self, const at::Tensor & other, at::IntArrayRef dims_self, at::IntArrayRef dims_other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::tensordot::call(self, other, dims_self, dims_other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim, dims_self, dims_other); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor threshold_generated_plumbing(const at::Tensor & self, const at::Scalar & threshold, const at::Scalar & value) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::threshold::call(self, threshold, value); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, threshold, value); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & threshold__generated_plumbing(at::Tensor & self, const at::Scalar & threshold, const at::Scalar & value) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = 
maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::threshold_::call(self, threshold, value); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, threshold, value); + return self; +} +template +at::Tensor threshold_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & self, const at::Scalar & threshold) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(self, cur_level)) { + return at::_ops::threshold_backward::call(grad_output, self, threshold); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, self_value, self_bdim, threshold); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor tile_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef dims) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::tile::call(self, dims); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dims); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor transpose_int_generated_plumbing(const at::Tensor & self, int64_t dim0, int64_t dim1) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = 
maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::transpose_int::call(self, dim0, dim1); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim0, dim1); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor transpose_Dimname_generated_plumbing(const at::Tensor & self, at::Dimname dim0, at::Dimname dim1) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::transpose_Dimname::call(self, dim0, dim1); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim0, dim1); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _mkldnn_transpose_generated_plumbing(const at::Tensor & self, int64_t dim0, int64_t dim1) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_mkldnn_transpose::call(self, dim0, dim1); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim0, dim1); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & _mkldnn_transpose__generated_plumbing(at::Tensor & self, int64_t dim0, int64_t dim1) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = 
maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_mkldnn_transpose_::call(self, dim0, dim1); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, dim0, dim1); + return self; +} +template +at::Tensor one_hot_generated_plumbing(const at::Tensor & self, int64_t num_classes) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::one_hot::call(self, num_classes); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, num_classes); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor flip_generated_plumbing(const at::Tensor & self, at::IntArrayRef dims) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::flip::call(self, dims); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dims); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor fliplr_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + 
return at::_ops::fliplr::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor flipud_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::flipud::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor roll_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef shifts, at::IntArrayRef dims) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::roll::call(self, shifts, dims); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, shifts, dims); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor rot90_generated_plumbing(const at::Tensor & self, int64_t k, at::IntArrayRef dims) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::rot90::call(self, k, dims); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); 
+ auto results = batch_rule(self_value, self_bdim, k, dims); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor trapezoid_x_generated_plumbing(const at::Tensor & y, const at::Tensor & x, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(y, cur_level) && !isBatchedAtLevel(x, cur_level)) { + return at::_ops::trapezoid_x::call(y, x, dim); + } + auto [y_value, y_bdim] = unwrapTensorAtLevel(y, cur_level); + auto [x_value, x_bdim] = unwrapTensorAtLevel(x, cur_level); + auto results = batch_rule(y_value, y_bdim, x_value, x_bdim, dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor trapezoid_dx_generated_plumbing(const at::Tensor & y, const at::Scalar & dx, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(y, cur_level)) { + return at::_ops::trapezoid_dx::call(y, dx, dim); + } + auto [y_value, y_bdim] = unwrapTensorAtLevel(y, cur_level); + auto results = batch_rule(y_value, y_bdim, dx, dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor trapz_x_generated_plumbing(const at::Tensor & y, const at::Tensor & x, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(y, cur_level) && !isBatchedAtLevel(x, cur_level)) { + return at::_ops::trapz_x::call(y, x, dim); + } + auto [y_value, y_bdim] = 
unwrapTensorAtLevel(y, cur_level); + auto [x_value, x_bdim] = unwrapTensorAtLevel(x, cur_level); + auto results = batch_rule(y_value, y_bdim, x_value, x_bdim, dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor trapz_dx_generated_plumbing(const at::Tensor & y, double dx, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(y, cur_level)) { + return at::_ops::trapz_dx::call(y, dx, dim); + } + auto [y_value, y_bdim] = unwrapTensorAtLevel(y, cur_level); + auto results = batch_rule(y_value, y_bdim, dx, dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple _transform_bias_rescale_qkv_generated_plumbing(const at::Tensor & qkv, const at::Tensor & qkv_bias, int64_t num_heads) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(qkv, cur_level) && !isBatchedAtLevel(qkv_bias, cur_level)) { + return at::_ops::_transform_bias_rescale_qkv::call(qkv, qkv_bias, num_heads); + } + auto [qkv_value, qkv_bdim] = unwrapTensorAtLevel(qkv, cur_level); + auto [qkv_bias_value, qkv_bias_bdim] = unwrapTensorAtLevel(qkv_bias, cur_level); + auto results = batch_rule(qkv_value, qkv_bdim, qkv_bias_value, qkv_bias_bdim, num_heads); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +at::Tensor _nested_tensor_from_mask_generated_plumbing(const at::Tensor & t, const at::Tensor & mask, bool 
mask_check) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(t, cur_level) && !isBatchedAtLevel(mask, cur_level)) { + return at::_ops::_nested_tensor_from_mask::call(t, mask, mask_check); + } + auto [t_value, t_bdim] = unwrapTensorAtLevel(t, cur_level); + auto [mask_value, mask_bdim] = unwrapTensorAtLevel(mask, cur_level); + auto results = batch_rule(t_value, t_bdim, mask_value, mask_bdim, mask_check); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _nested_from_padded_generated_plumbing(const at::Tensor & padded, const at::Tensor & cpu_nested_shape_example, bool fuse_transform_0213) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(padded, cur_level) && !isBatchedAtLevel(cpu_nested_shape_example, cur_level)) { + return at::_ops::_nested_from_padded::call(padded, cpu_nested_shape_example, fuse_transform_0213); + } + auto [padded_value, padded_bdim] = unwrapTensorAtLevel(padded, cur_level); + auto [cpu_nested_shape_example_value, cpu_nested_shape_example_bdim] = unwrapTensorAtLevel(cpu_nested_shape_example, cur_level); + auto results = batch_rule(padded_value, padded_bdim, cpu_nested_shape_example_value, cpu_nested_shape_example_bdim, fuse_transform_0213); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _nested_tensor_size_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t 
cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_nested_tensor_size::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _nested_tensor_strides_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_nested_tensor_strides::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _nested_tensor_storage_offsets_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_nested_tensor_storage_offsets::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _nested_from_padded_and_nested_example_generated_plumbing(const at::Tensor & padded, const at::Tensor & nt_example) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(padded, 
cur_level) && !isBatchedAtLevel(nt_example, cur_level)) { + return at::_ops::_nested_from_padded_and_nested_example::call(padded, nt_example); + } + auto [padded_value, padded_bdim] = unwrapTensorAtLevel(padded, cur_level); + auto [nt_example_value, nt_example_bdim] = unwrapTensorAtLevel(nt_example, cur_level); + auto results = batch_rule(padded_value, padded_bdim, nt_example_value, nt_example_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _nested_view_from_buffer_generated_plumbing(const at::Tensor & self, const at::Tensor & nested_size, const at::Tensor & nested_strides, const at::Tensor & offsets) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(nested_size, cur_level) && !isBatchedAtLevel(nested_strides, cur_level) && !isBatchedAtLevel(offsets, cur_level)) { + return at::_ops::_nested_view_from_buffer::call(self, nested_size, nested_strides, offsets); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [nested_size_value, nested_size_bdim] = unwrapTensorAtLevel(nested_size, cur_level); + auto [nested_strides_value, nested_strides_bdim] = unwrapTensorAtLevel(nested_strides, cur_level); + auto [offsets_value, offsets_bdim] = unwrapTensorAtLevel(offsets, cur_level); + auto results = batch_rule(self_value, self_bdim, nested_size_value, nested_size_bdim, nested_strides_value, nested_strides_bdim, offsets_value, offsets_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _nested_view_from_buffer_copy_generated_plumbing(const at::Tensor & self, const at::Tensor & nested_size, const at::Tensor & nested_strides, const at::Tensor & offsets) { + c10::impl::ExcludeDispatchKeyGuard 
guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(nested_size, cur_level) && !isBatchedAtLevel(nested_strides, cur_level) && !isBatchedAtLevel(offsets, cur_level)) { + return at::_ops::_nested_view_from_buffer_copy::call(self, nested_size, nested_strides, offsets); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [nested_size_value, nested_size_bdim] = unwrapTensorAtLevel(nested_size, cur_level); + auto [nested_strides_value, nested_strides_bdim] = unwrapTensorAtLevel(nested_strides, cur_level); + auto [offsets_value, offsets_bdim] = unwrapTensorAtLevel(offsets, cur_level); + auto results = batch_rule(self_value, self_bdim, nested_size_value, nested_size_bdim, nested_strides_value, nested_strides_bdim, offsets_value, offsets_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _nested_view_from_jagged_generated_plumbing(const at::Tensor & self, const at::Tensor & offsets, const at::Tensor & dummy, const ::std::optional & lengths, int64_t ragged_idx, const ::std::optional & min_seqlen, const ::std::optional & max_seqlen) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(offsets, cur_level) && !isBatchedAtLevel(dummy, cur_level) && !isBatchedAtLevel(lengths, cur_level) && !isBatchedAtLevel(min_seqlen, cur_level) && !isBatchedAtLevel(max_seqlen, cur_level)) { + return at::_ops::_nested_view_from_jagged::call(self, offsets, dummy, lengths, ragged_idx, min_seqlen, max_seqlen); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); 
+ auto [offsets_value, offsets_bdim] = unwrapTensorAtLevel(offsets, cur_level); + auto [dummy_value, dummy_bdim] = unwrapTensorAtLevel(dummy, cur_level); + std::optional lengths_value; + std::optional lengths_bdim; + if (lengths) { + std::tie(lengths_value, lengths_bdim) = unwrapTensorAtLevel(lengths.value(), cur_level); + } + std::optional min_seqlen_value; + std::optional min_seqlen_bdim; + if (min_seqlen) { + std::tie(min_seqlen_value, min_seqlen_bdim) = unwrapTensorAtLevel(min_seqlen.value(), cur_level); + } + std::optional max_seqlen_value; + std::optional max_seqlen_bdim; + if (max_seqlen) { + std::tie(max_seqlen_value, max_seqlen_bdim) = unwrapTensorAtLevel(max_seqlen.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, offsets_value, offsets_bdim, dummy_value, dummy_bdim, lengths_value, lengths_bdim, ragged_idx, min_seqlen_value, min_seqlen_bdim, max_seqlen_value, max_seqlen_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _nested_view_from_jagged_copy_generated_plumbing(const at::Tensor & self, const at::Tensor & offsets, const at::Tensor & dummy, const ::std::optional & lengths, int64_t ragged_idx, const ::std::optional & min_seqlen, const ::std::optional & max_seqlen) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(offsets, cur_level) && !isBatchedAtLevel(dummy, cur_level) && !isBatchedAtLevel(lengths, cur_level) && !isBatchedAtLevel(min_seqlen, cur_level) && !isBatchedAtLevel(max_seqlen, cur_level)) { + return at::_ops::_nested_view_from_jagged_copy::call(self, offsets, dummy, lengths, ragged_idx, min_seqlen, max_seqlen); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [offsets_value, offsets_bdim] = 
unwrapTensorAtLevel(offsets, cur_level); + auto [dummy_value, dummy_bdim] = unwrapTensorAtLevel(dummy, cur_level); + std::optional lengths_value; + std::optional lengths_bdim; + if (lengths) { + std::tie(lengths_value, lengths_bdim) = unwrapTensorAtLevel(lengths.value(), cur_level); + } + std::optional min_seqlen_value; + std::optional min_seqlen_bdim; + if (min_seqlen) { + std::tie(min_seqlen_value, min_seqlen_bdim) = unwrapTensorAtLevel(min_seqlen.value(), cur_level); + } + std::optional max_seqlen_value; + std::optional max_seqlen_bdim; + if (max_seqlen) { + std::tie(max_seqlen_value, max_seqlen_bdim) = unwrapTensorAtLevel(max_seqlen.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, offsets_value, offsets_bdim, dummy_value, dummy_bdim, lengths_value, lengths_bdim, ragged_idx, min_seqlen_value, min_seqlen_bdim, max_seqlen_value, max_seqlen_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _nested_get_values_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_nested_get_values::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _nested_get_values_copy_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_nested_get_values_copy::call(self); 
+ } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _nested_get_offsets_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_nested_get_offsets::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _nested_get_lengths_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_nested_get_lengths::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _nested_get_min_seqlen_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_nested_get_min_seqlen::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return 
makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _nested_get_max_seqlen_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_nested_get_max_seqlen::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _nested_get_jagged_dummy_generated_plumbing(const at::Tensor & any) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(any, cur_level)) { + return at::_ops::_nested_get_jagged_dummy::call(any); + } + auto [any_value, any_bdim] = unwrapTensorAtLevel(any, cur_level); + auto results = batch_rule(any_value, any_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple _nested_compute_contiguous_strides_offsets_generated_plumbing(const at::Tensor & nested_size) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(nested_size, cur_level)) { + return at::_ops::_nested_compute_contiguous_strides_offsets::call(nested_size); + } + auto [nested_size_value, nested_size_bdim] = unwrapTensorAtLevel(nested_size, cur_level); + auto results = batch_rule(nested_size_value, nested_size_bdim); + return 
std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor _trilinear_generated_plumbing(const at::Tensor & i1, const at::Tensor & i2, const at::Tensor & i3, at::IntArrayRef expand1, at::IntArrayRef expand2, at::IntArrayRef expand3, at::IntArrayRef sumdim, int64_t unroll_dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(i1, cur_level) && !isBatchedAtLevel(i2, cur_level) && !isBatchedAtLevel(i3, cur_level)) { + return at::_ops::_trilinear::call(i1, i2, i3, expand1, expand2, expand3, sumdim, unroll_dim); + } + auto [i1_value, i1_bdim] = unwrapTensorAtLevel(i1, cur_level); + auto [i2_value, i2_bdim] = unwrapTensorAtLevel(i2, cur_level); + auto [i3_value, i3_bdim] = unwrapTensorAtLevel(i3, cur_level); + auto results = batch_rule(i1_value, i1_bdim, i2_value, i2_bdim, i3_value, i3_bdim, expand1, expand2, expand3, sumdim, unroll_dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor triplet_margin_loss_generated_plumbing(const at::Tensor & anchor, const at::Tensor & positive, const at::Tensor & negative, double margin, double p, double eps, bool swap, int64_t reduction) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(anchor, cur_level) && !isBatchedAtLevel(positive, cur_level) && !isBatchedAtLevel(negative, cur_level)) { + return at::_ops::triplet_margin_loss::call(anchor, positive, negative, margin, p, eps, swap, reduction); + } + auto [anchor_value, anchor_bdim] = unwrapTensorAtLevel(anchor, 
cur_level); + auto [positive_value, positive_bdim] = unwrapTensorAtLevel(positive, cur_level); + auto [negative_value, negative_bdim] = unwrapTensorAtLevel(negative, cur_level); + auto results = batch_rule(anchor_value, anchor_bdim, positive_value, positive_bdim, negative_value, negative_bdim, margin, p, eps, swap, reduction); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor trunc_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::trunc::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & trunc__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::trunc_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor fix_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::fix::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = 
batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & fix__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::fix_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor type_as_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::type_as::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple _unique_generated_plumbing(const at::Tensor & self, bool sorted, bool return_inverse) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_unique::call(self, sorted, return_inverse); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, sorted, return_inverse); + return 
std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple unique_dim_generated_plumbing(const at::Tensor & self, int64_t dim, bool sorted, bool return_inverse, bool return_counts) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::unique_dim::call(self, dim, sorted, return_inverse, return_counts); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, sorted, return_inverse, return_counts); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +::std::tuple unique_consecutive_generated_plumbing(const at::Tensor & self, bool return_inverse, bool return_counts, ::std::optional dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::unique_consecutive::call(self, return_inverse, return_counts, dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, return_inverse, return_counts, dim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +::std::tuple 
unique_dim_consecutive_generated_plumbing(const at::Tensor & self, int64_t dim, bool return_inverse, bool return_counts) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::unique_dim_consecutive::call(self, dim, return_inverse, return_counts); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, return_inverse, return_counts); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +::std::tuple _unique2_generated_plumbing(const at::Tensor & self, bool sorted, bool return_inverse, bool return_counts) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_unique2::call(self, sorted, return_inverse, return_counts); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, sorted, return_inverse, return_counts); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +at::Tensor _unsafe_view_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef size) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + 
vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_unsafe_view::call(self, size); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, size); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor unsqueeze_generated_plumbing(const at::Tensor & self, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::unsqueeze::call(self, dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor vander_generated_plumbing(const at::Tensor & x, ::std::optional N, bool increasing) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(x, cur_level)) { + return at::_ops::vander::call(x, N, increasing); + } + auto [x_value, x_bdim] = unwrapTensorAtLevel(x, cur_level); + auto results = batch_rule(x_value, x_bdim, N, increasing); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor var_generated_plumbing(const at::Tensor & self, bool unbiased) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if 
(!isBatchedAtLevel(self, cur_level)) { + return at::_ops::var::call(self, unbiased); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, unbiased); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor var_dim_generated_plumbing(const at::Tensor & self, at::OptionalIntArrayRef dim, bool unbiased, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::var_dim::call(self, dim, unbiased, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, unbiased, keepdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor var_correction_generated_plumbing(const at::Tensor & self, at::OptionalIntArrayRef dim, const ::std::optional & correction, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::var_correction::call(self, dim, correction, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, correction, keepdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor var_names_dim_generated_plumbing(const at::Tensor & self, at::DimnameList dim, bool unbiased, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + 
vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::var_names_dim::call(self, dim, unbiased, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, unbiased, keepdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor var_correction_names_generated_plumbing(const at::Tensor & self, at::DimnameList dim, const ::std::optional & correction, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::var_correction_names::call(self, dim, correction, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, correction, keepdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple var_mean_generated_plumbing(const at::Tensor & self, bool unbiased) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::var_mean::call(self, unbiased); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, unbiased); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple var_mean_dim_generated_plumbing(const at::Tensor & self, 
at::OptionalIntArrayRef dim, bool unbiased, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::var_mean_dim::call(self, dim, unbiased, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, unbiased, keepdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple var_mean_correction_generated_plumbing(const at::Tensor & self, at::OptionalIntArrayRef dim, const ::std::optional & correction, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::var_mean_correction::call(self, dim, correction, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, correction, keepdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple var_mean_names_dim_generated_plumbing(const at::Tensor & self, at::DimnameList dim, bool unbiased, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return 
at::_ops::var_mean_names_dim::call(self, dim, unbiased, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, unbiased, keepdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple var_mean_correction_names_generated_plumbing(const at::Tensor & self, at::DimnameList dim, const ::std::optional & correction, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::var_mean_correction_names::call(self, dim, correction, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, correction, keepdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor view_as_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::view_as::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template 
+at::Tensor where_self_generated_plumbing(const at::Tensor & condition, const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(condition, cur_level) && !isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::where_self::call(condition, self, other); + } + auto [condition_value, condition_bdim] = unwrapTensorAtLevel(condition, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(condition_value, condition_bdim, self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor where_ScalarSelf_generated_plumbing(const at::Tensor & condition, const at::Scalar & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(condition, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::where_ScalarSelf::call(condition, self, other); + } + auto [condition_value, condition_bdim] = unwrapTensorAtLevel(condition, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(condition_value, condition_bdim, self, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor where_ScalarOther_generated_plumbing(const at::Tensor & condition, const at::Tensor & self, const at::Scalar & other) { + 
c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(condition, cur_level) && !isBatchedAtLevel(self, cur_level)) { + return at::_ops::where_ScalarOther::call(condition, self, other); + } + auto [condition_value, condition_bdim] = unwrapTensorAtLevel(condition, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(condition_value, condition_bdim, self_value, self_bdim, other); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor where_Scalar_generated_plumbing(const at::Tensor & condition, const at::Scalar & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(condition, cur_level)) { + return at::_ops::where_Scalar::call(condition, self, other); + } + auto [condition_value, condition_bdim] = unwrapTensorAtLevel(condition, cur_level); + auto results = batch_rule(condition_value, condition_bdim, self, other); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector where_generated_plumbing(const at::Tensor & condition) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(condition, cur_level)) { + return at::_ops::where::call(condition); + } + auto [condition_value, condition_bdim] = unwrapTensorAtLevel(condition, cur_level); + auto results = batch_rule(condition_value, condition_bdim); + return 
makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor norm_except_dim_generated_plumbing(const at::Tensor & v, int64_t pow, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(v, cur_level)) { + return at::_ops::norm_except_dim::call(v, pow, dim); + } + auto [v_value, v_bdim] = unwrapTensorAtLevel(v, cur_level); + auto results = batch_rule(v_value, v_bdim, pow, dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _weight_norm_generated_plumbing(const at::Tensor & v, const at::Tensor & g, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(v, cur_level) && !isBatchedAtLevel(g, cur_level)) { + return at::_ops::_weight_norm::call(v, g, dim); + } + auto [v_value, v_bdim] = unwrapTensorAtLevel(v, cur_level); + auto [g_value, g_bdim] = unwrapTensorAtLevel(g, cur_level); + auto results = batch_rule(v_value, v_bdim, g_value, g_bdim, dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple _weight_norm_interface_generated_plumbing(const at::Tensor & v, const at::Tensor & g, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(v, cur_level) && !isBatchedAtLevel(g, cur_level)) { + return at::_ops::_weight_norm_interface::call(v, g, dim); + } + auto [v_value, v_bdim] = unwrapTensorAtLevel(v, 
cur_level); + auto [g_value, g_bdim] = unwrapTensorAtLevel(g, cur_level); + auto results = batch_rule(v_value, v_bdim, g_value, g_bdim, dim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple _weight_norm_interface_backward_generated_plumbing(const at::Tensor & grad_w, const at::Tensor & saved_v, const at::Tensor & saved_g, const at::Tensor & saved_norms, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_w, cur_level) && !isBatchedAtLevel(saved_v, cur_level) && !isBatchedAtLevel(saved_g, cur_level) && !isBatchedAtLevel(saved_norms, cur_level)) { + return at::_ops::_weight_norm_interface_backward::call(grad_w, saved_v, saved_g, saved_norms, dim); + } + auto [grad_w_value, grad_w_bdim] = unwrapTensorAtLevel(grad_w, cur_level); + auto [saved_v_value, saved_v_bdim] = unwrapTensorAtLevel(saved_v, cur_level); + auto [saved_g_value, saved_g_bdim] = unwrapTensorAtLevel(saved_g, cur_level); + auto [saved_norms_value, saved_norms_bdim] = unwrapTensorAtLevel(saved_norms, cur_level); + auto results = batch_rule(grad_w_value, grad_w_bdim, saved_v_value, saved_v_bdim, saved_g_value, saved_g_bdim, saved_norms_value, saved_norms_bdim, dim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple _weight_norm_differentiable_backward_generated_plumbing(const at::Tensor & grad_w, const at::Tensor & saved_v, const at::Tensor & saved_g, const at::Tensor & saved_norms, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = 
maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_w, cur_level) && !isBatchedAtLevel(saved_v, cur_level) && !isBatchedAtLevel(saved_g, cur_level) && !isBatchedAtLevel(saved_norms, cur_level)) { + return at::_ops::_weight_norm_differentiable_backward::call(grad_w, saved_v, saved_g, saved_norms, dim); + } + auto [grad_w_value, grad_w_bdim] = unwrapTensorAtLevel(grad_w, cur_level); + auto [saved_v_value, saved_v_bdim] = unwrapTensorAtLevel(saved_v, cur_level); + auto [saved_g_value, saved_g_bdim] = unwrapTensorAtLevel(saved_g, cur_level); + auto [saved_norms_value, saved_norms_bdim] = unwrapTensorAtLevel(saved_norms, cur_level); + auto results = batch_rule(grad_w_value, grad_w_bdim, saved_v_value, saved_v_bdim, saved_g_value, saved_g_bdim, saved_norms_value, saved_norms_bdim, dim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor zeros_like_generated_plumbing(const at::Tensor & self, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::zeros_like::call(self, dtype, layout, device, pin_memory, memory_format); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dtype, layout, device, pin_memory, memory_format); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _standard_gamma_grad_generated_plumbing(const at::Tensor & self, const at::Tensor & 
output) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(output, cur_level)) { + return at::_ops::_standard_gamma_grad::call(self, output); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [output_value, output_bdim] = unwrapTensorAtLevel(output, cur_level); + auto results = batch_rule(self_value, self_bdim, output_value, output_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _standard_gamma_generated_plumbing(const at::Tensor & self, ::std::optional generator) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_standard_gamma::call(self, generator); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, generator); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _dirichlet_grad_generated_plumbing(const at::Tensor & x, const at::Tensor & alpha, const at::Tensor & total) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(x, cur_level) && !isBatchedAtLevel(alpha, cur_level) && !isBatchedAtLevel(total, cur_level)) { + return at::_ops::_dirichlet_grad::call(x, alpha, total); + } + auto [x_value, x_bdim] = unwrapTensorAtLevel(x, cur_level); + auto [alpha_value, alpha_bdim] = 
unwrapTensorAtLevel(alpha, cur_level); + auto [total_value, total_bdim] = unwrapTensorAtLevel(total, cur_level); + auto results = batch_rule(x_value, x_bdim, alpha_value, alpha_bdim, total_value, total_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _sample_dirichlet_generated_plumbing(const at::Tensor & self, ::std::optional generator) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_sample_dirichlet::call(self, generator); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, generator); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor poisson_generated_plumbing(const at::Tensor & self, ::std::optional generator) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::poisson::call(self, generator); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, generator); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor binomial_generated_plumbing(const at::Tensor & count, const at::Tensor & prob, ::std::optional generator) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(count, cur_level) && 
!isBatchedAtLevel(prob, cur_level)) { + return at::_ops::binomial::call(count, prob, generator); + } + auto [count_value, count_bdim] = unwrapTensorAtLevel(count, cur_level); + auto [prob_value, prob_bdim] = unwrapTensorAtLevel(prob, cur_level); + auto results = batch_rule(count_value, count_bdim, prob_value, prob_bdim, generator); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor native_norm_generated_plumbing(const at::Tensor & self, const at::Scalar & p) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::native_norm::call(self, p); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, p); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor native_norm_ScalarOpt_dim_dtype_generated_plumbing(const at::Tensor & self, const ::std::optional & p, at::IntArrayRef dim, bool keepdim, ::std::optional dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::native_norm_ScalarOpt_dim_dtype::call(self, p, dim, keepdim, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, p, dim, keepdim, dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple _batch_norm_no_update_generated_plumbing(const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const ::std::optional & 
running_mean, const ::std::optional & running_var, double momentum, double eps) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level) && !isBatchedAtLevel(running_mean, cur_level) && !isBatchedAtLevel(running_var, cur_level)) { + return at::_ops::_batch_norm_no_update::call(input, weight, bias, running_mean, running_var, momentum, eps); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + std::optional weight_value; + std::optional weight_bdim; + if (weight) { + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight.value(), cur_level); + } + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + std::optional running_mean_value; + std::optional running_mean_bdim; + if (running_mean) { + std::tie(running_mean_value, running_mean_bdim) = unwrapTensorAtLevel(running_mean.value(), cur_level); + } + std::optional running_var_value; + std::optional running_var_bdim; + if (running_var) { + std::tie(running_var_value, running_var_bdim) = unwrapTensorAtLevel(running_var.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, weight_value, weight_bdim, bias_value, bias_bdim, running_mean_value, running_mean_bdim, running_var_value, running_var_bdim, momentum, eps); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level), makeBatched(std::get<6>(results), std::get<7>(results), cur_level)); +} +template +::std::tuple batch_norm_backward_generated_plumbing(const 
at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & running_mean, const ::std::optional & running_var, const ::std::optional & save_mean, const ::std::optional & save_var, bool update, double eps, ::std::array output_mask, const at::Tensor & reserve) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_out, cur_level) && !isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(running_mean, cur_level) && !isBatchedAtLevel(running_var, cur_level) && !isBatchedAtLevel(save_mean, cur_level) && !isBatchedAtLevel(save_var, cur_level) && !isBatchedAtLevel(reserve, cur_level)) { + return at::_ops::batch_norm_backward::call(grad_out, input, weight, running_mean, running_var, save_mean, save_var, update, eps, output_mask, reserve); + } + auto [grad_out_value, grad_out_bdim] = unwrapTensorAtLevel(grad_out, cur_level); + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + auto [reserve_value, reserve_bdim] = unwrapTensorAtLevel(reserve, cur_level); + std::optional running_mean_value; + std::optional running_mean_bdim; + if (running_mean) { + std::tie(running_mean_value, running_mean_bdim) = unwrapTensorAtLevel(running_mean.value(), cur_level); + } + std::optional running_var_value; + std::optional running_var_bdim; + if (running_var) { + std::tie(running_var_value, running_var_bdim) = unwrapTensorAtLevel(running_var.value(), cur_level); + } + std::optional save_mean_value; + std::optional save_mean_bdim; + if (save_mean) { + std::tie(save_mean_value, save_mean_bdim) = unwrapTensorAtLevel(save_mean.value(), cur_level); + } + std::optional save_var_value; + std::optional save_var_bdim; + if 
(save_var) { + std::tie(save_var_value, save_var_bdim) = unwrapTensorAtLevel(save_var.value(), cur_level); + } + auto results = batch_rule(grad_out_value, grad_out_bdim, input_value, input_bdim, weight_value, weight_bdim, running_mean_value, running_mean_bdim, running_var_value, running_var_bdim, save_mean_value, save_mean_bdim, save_var_value, save_var_bdim, update, eps, output_mask, reserve_value, reserve_bdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +at::Tensor _sparse_sum_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_sparse_sum::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _sparse_sum_dtype_generated_plumbing(const at::Tensor & self, at::ScalarType dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_sparse_sum_dtype::call(self, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _sparse_sum_dim_generated_plumbing(const at::Tensor & self, at::IntArrayRef 
dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_sparse_sum_dim::call(self, dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _sparse_sum_dim_dtype_generated_plumbing(const at::Tensor & self, at::IntArrayRef dim, at::ScalarType dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_sparse_sum_dim_dtype::call(self, dim, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _sparse_sum_backward_generated_plumbing(const at::Tensor & grad, const at::Tensor & self, at::IntArrayRef dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad, cur_level) && !isBatchedAtLevel(self, cur_level)) { + return at::_ops::_sparse_sum_backward::call(grad, self, dim); + } + auto [grad_value, grad_bdim] = unwrapTensorAtLevel(grad, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(grad_value, grad_bdim, self_value, self_bdim, dim); + return 
makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _sparse_csr_sum_dim_dtype_generated_plumbing(const at::Tensor & self, at::IntArrayRef dim, bool keepdim, ::std::optional dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_sparse_csr_sum_dim_dtype::call(self, dim, keepdim, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, keepdim, dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _sparse_csr_prod_dim_dtype_generated_plumbing(const at::Tensor & self, at::IntArrayRef dim, bool keepdim, ::std::optional dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_sparse_csr_prod_dim_dtype::call(self, dim, keepdim, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, keepdim, dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _sparse_softmax_int_generated_plumbing(const at::Tensor & self, int64_t dim, ::std::optional dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_sparse_softmax_int::call(self, dim, dtype); + } + auto 
[self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _sparse_softmax_Dimname_generated_plumbing(const at::Tensor & self, at::Dimname dim, ::std::optional dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_sparse_softmax_Dimname::call(self, dim, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _sparse_softmax_generated_plumbing(const at::Tensor & self, int64_t dim, bool half_to_float) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_sparse_softmax::call(self, dim, half_to_float); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, half_to_float); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _sparse_softmax_backward_data_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & output, int64_t dim, const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if 
(!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(output, cur_level) && !isBatchedAtLevel(self, cur_level)) { + return at::_ops::_sparse_softmax_backward_data::call(grad_output, output, dim, self); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [output_value, output_bdim] = unwrapTensorAtLevel(output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, output_value, output_bdim, dim, self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _sparse_log_softmax_int_generated_plumbing(const at::Tensor & self, int64_t dim, ::std::optional dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_sparse_log_softmax_int::call(self, dim, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _sparse_log_softmax_Dimname_generated_plumbing(const at::Tensor & self, at::Dimname dim, ::std::optional dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_sparse_log_softmax_Dimname::call(self, dim, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, dtype); + return 
makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _sparse_log_softmax_generated_plumbing(const at::Tensor & self, int64_t dim, bool half_to_float) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_sparse_log_softmax::call(self, dim, half_to_float); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, half_to_float); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _sparse_log_softmax_backward_data_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & output, int64_t dim, const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(output, cur_level) && !isBatchedAtLevel(self, cur_level)) { + return at::_ops::_sparse_log_softmax_backward_data::call(grad_output, output, dim, self); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [output_value, output_bdim] = unwrapTensorAtLevel(output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, output_value, output_bdim, dim, self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _spdiags_generated_plumbing(const at::Tensor & diagonals, const at::Tensor & offsets, at::IntArrayRef shape, ::std::optional layout) { + 
c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(diagonals, cur_level) && !isBatchedAtLevel(offsets, cur_level)) { + return at::_ops::_spdiags::call(diagonals, offsets, shape, layout); + } + auto [diagonals_value, diagonals_bdim] = unwrapTensorAtLevel(diagonals, cur_level); + auto [offsets_value, offsets_bdim] = unwrapTensorAtLevel(offsets, cur_level); + auto results = batch_rule(diagonals_value, diagonals_bdim, offsets_value, offsets_bdim, shape, layout); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor norm_ScalarOpt_dtype_generated_plumbing(const at::Tensor & self, const ::std::optional & p, at::ScalarType dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::norm_ScalarOpt_dtype::call(self, p, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, p, dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor norm_Scalar_generated_plumbing(const at::Tensor & self, const at::Scalar & p) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::norm_Scalar::call(self, p); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, p); + return 
makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor norm_ScalarOpt_dim_dtype_generated_plumbing(const at::Tensor & self, const ::std::optional & p, at::IntArrayRef dim, bool keepdim, at::ScalarType dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::norm_ScalarOpt_dim_dtype::call(self, p, dim, keepdim, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, p, dim, keepdim, dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor norm_ScalarOpt_dim_generated_plumbing(const at::Tensor & self, const ::std::optional & p, at::IntArrayRef dim, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::norm_ScalarOpt_dim::call(self, p, dim, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, p, dim, keepdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor norm_names_ScalarOpt_dim_dtype_generated_plumbing(const at::Tensor & self, const ::std::optional & p, at::DimnameList dim, bool keepdim, at::ScalarType dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + 
return at::_ops::norm_names_ScalarOpt_dim_dtype::call(self, p, dim, keepdim, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, p, dim, keepdim, dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor norm_names_ScalarOpt_dim_generated_plumbing(const at::Tensor & self, const ::std::optional & p, at::DimnameList dim, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::norm_names_ScalarOpt_dim::call(self, p, dim, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, p, dim, keepdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple frexp_Tensor_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::frexp_Tensor::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor frobenius_norm_dim_generated_plumbing(const at::Tensor & self, at::IntArrayRef dim, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + 
vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::frobenius_norm_dim::call(self, dim, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, keepdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor nuclear_norm_generated_plumbing(const at::Tensor & self, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::nuclear_norm::call(self, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, keepdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor nuclear_norm_dim_generated_plumbing(const at::Tensor & self, at::IntArrayRef dim, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::nuclear_norm_dim::call(self, dim, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, keepdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor clone_generated_plumbing(const at::Tensor & self, ::std::optional memory_format) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + 
vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::clone::call(self, memory_format); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, memory_format); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor positive_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::positive::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +const at::Tensor & resize_as_sparse__generated_plumbing(const at::Tensor & self, const at::Tensor & the_template) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(the_template, cur_level)) { + return at::_ops::resize_as_sparse_::call(self, the_template); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [the_template_value, the_template_bdim] = unwrapTensorAtLevel(the_template, cur_level); + batch_rule(self_value, self_bdim, the_template_value, the_template_bdim); + return self; +} +template +at::Tensor & zero__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + 
vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::zero_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor sub_Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::sub_Tensor::call(self, other, alpha); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim, alpha); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & sub__Tensor_generated_plumbing(at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::sub__Tensor::call(self, other, alpha); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim, alpha); + return self; +} +template +at::Tensor sub_Scalar_generated_plumbing(const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha) { 
+ c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::sub_Scalar::call(self, other, alpha); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, other, alpha); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & sub__Scalar_generated_plumbing(at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::sub__Scalar::call(self, other, alpha); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, other, alpha); + return self; +} +template +at::Tensor subtract_Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::subtract_Tensor::call(self, other, alpha); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim, alpha); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & 
subtract__Tensor_generated_plumbing(at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::subtract__Tensor::call(self, other, alpha); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim, alpha); + return self; +} +template +at::Tensor subtract_Scalar_generated_plumbing(const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::subtract_Scalar::call(self, other, alpha); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, other, alpha); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & subtract__Scalar_generated_plumbing(at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::subtract__Scalar::call(self, other, alpha); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, 
other, alpha); + return self; +} +template +at::Tensor rsub_Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::rsub_Tensor::call(self, other, alpha); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim, alpha); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor heaviside_generated_plumbing(const at::Tensor & self, const at::Tensor & values) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(values, cur_level)) { + return at::_ops::heaviside::call(self, values); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [values_value, values_bdim] = unwrapTensorAtLevel(values, cur_level); + auto results = batch_rule(self_value, self_bdim, values_value, values_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & heaviside__generated_plumbing(at::Tensor & self, const at::Tensor & values) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) 
&& !isBatchedAtLevel(values, cur_level)) { + return at::_ops::heaviside_::call(self, values); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [values_value, values_bdim] = unwrapTensorAtLevel(values, cur_level); + batch_rule(self_value, self_bdim, values_value, values_bdim); + return self; +} +template +at::Tensor rsub_Scalar_generated_plumbing(const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::rsub_Scalar::call(self, other, alpha); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, other, alpha); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _sparse_addmm_generated_plumbing(const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta, const at::Scalar & alpha) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(mat1, cur_level) && !isBatchedAtLevel(mat2, cur_level)) { + return at::_ops::_sparse_addmm::call(self, mat1, mat2, beta, alpha); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [mat1_value, mat1_bdim] = unwrapTensorAtLevel(mat1, cur_level); + auto [mat2_value, mat2_bdim] = unwrapTensorAtLevel(mat2, cur_level); + auto results = batch_rule(self_value, self_bdim, mat1_value, mat1_bdim, mat2_value, mat2_bdim, beta, alpha); + return makeBatched(std::get<0>(results), 
std::get<1>(results), cur_level); +} +template +at::Tensor sparse_sampled_addmm_generated_plumbing(const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta, const at::Scalar & alpha) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(mat1, cur_level) && !isBatchedAtLevel(mat2, cur_level)) { + return at::_ops::sparse_sampled_addmm::call(self, mat1, mat2, beta, alpha); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [mat1_value, mat1_bdim] = unwrapTensorAtLevel(mat1, cur_level); + auto [mat2_value, mat2_bdim] = unwrapTensorAtLevel(mat2, cur_level); + auto results = batch_rule(self_value, self_bdim, mat1_value, mat1_bdim, mat2_value, mat2_bdim, beta, alpha); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple _sparse_mm_reduce_impl_generated_plumbing(const at::Tensor & self, const at::Tensor & other, c10::string_view reduce) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::_sparse_mm_reduce_impl::call(self, other, reduce); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim, reduce); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template 
+::std::tuple _sparse_mm_reduce_impl_backward_generated_plumbing(const at::Tensor & self, const at::Tensor & grad_out, const at::Tensor & weight, c10::string_view reduce, const at::Tensor & arg_out, ::std::array output_mask) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(grad_out, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(arg_out, cur_level)) { + return at::_ops::_sparse_mm_reduce_impl_backward::call(self, grad_out, weight, reduce, arg_out, output_mask); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [grad_out_value, grad_out_bdim] = unwrapTensorAtLevel(grad_out, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + auto [arg_out_value, arg_out_bdim] = unwrapTensorAtLevel(arg_out, cur_level); + auto results = batch_rule(self_value, self_bdim, grad_out_value, grad_out_bdim, weight_value, weight_bdim, reduce, arg_out_value, arg_out_bdim, output_mask); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor addmm_generated_plumbing(const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta, const at::Scalar & alpha) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(mat1, cur_level) && !isBatchedAtLevel(mat2, cur_level)) { + return at::_ops::addmm::call(self, mat1, mat2, beta, alpha); + } + auto [self_value, self_bdim] = 
unwrapTensorAtLevel(self, cur_level); + auto [mat1_value, mat1_bdim] = unwrapTensorAtLevel(mat1, cur_level); + auto [mat2_value, mat2_bdim] = unwrapTensorAtLevel(mat2, cur_level); + auto results = batch_rule(self_value, self_bdim, mat1_value, mat1_bdim, mat2_value, mat2_bdim, beta, alpha); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor addmm_dtype_generated_plumbing(const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, at::ScalarType out_dtype, const at::Scalar & beta, const at::Scalar & alpha) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(mat1, cur_level) && !isBatchedAtLevel(mat2, cur_level)) { + return at::_ops::addmm_dtype::call(self, mat1, mat2, out_dtype, beta, alpha); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [mat1_value, mat1_bdim] = unwrapTensorAtLevel(mat1, cur_level); + auto [mat2_value, mat2_bdim] = unwrapTensorAtLevel(mat2, cur_level); + auto results = batch_rule(self_value, self_bdim, mat1_value, mat1_bdim, mat2_value, mat2_bdim, out_dtype, beta, alpha); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & addmm__generated_plumbing(at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta, const at::Scalar & alpha) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(mat1, cur_level) && !isBatchedAtLevel(mat2, cur_level)) { + return at::_ops::addmm_::call(self, mat1, mat2, 
beta, alpha); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [mat1_value, mat1_bdim] = unwrapTensorAtLevel(mat1, cur_level); + auto [mat2_value, mat2_bdim] = unwrapTensorAtLevel(mat2, cur_level); + batch_rule(self_value, self_bdim, mat1_value, mat1_bdim, mat2_value, mat2_bdim, beta, alpha); + return self; +} +template +at::Tensor _addmm_activation_generated_plumbing(const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta, const at::Scalar & alpha, bool use_gelu) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(mat1, cur_level) && !isBatchedAtLevel(mat2, cur_level)) { + return at::_ops::_addmm_activation::call(self, mat1, mat2, beta, alpha, use_gelu); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [mat1_value, mat1_bdim] = unwrapTensorAtLevel(mat1, cur_level); + auto [mat2_value, mat2_bdim] = unwrapTensorAtLevel(mat2, cur_level); + auto results = batch_rule(self_value, self_bdim, mat1_value, mat1_bdim, mat2_value, mat2_bdim, beta, alpha, use_gelu); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _scaled_mm_generated_plumbing(const at::Tensor & self, const at::Tensor & mat2, const at::Tensor & scale_a, const at::Tensor & scale_b, const ::std::optional & bias, const ::std::optional & scale_result, ::std::optional out_dtype, bool use_fast_accum) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(mat2, cur_level) && 
!isBatchedAtLevel(scale_a, cur_level) && !isBatchedAtLevel(scale_b, cur_level) && !isBatchedAtLevel(bias, cur_level) && !isBatchedAtLevel(scale_result, cur_level)) { + return at::_ops::_scaled_mm::call(self, mat2, scale_a, scale_b, bias, scale_result, out_dtype, use_fast_accum); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [mat2_value, mat2_bdim] = unwrapTensorAtLevel(mat2, cur_level); + auto [scale_a_value, scale_a_bdim] = unwrapTensorAtLevel(scale_a, cur_level); + auto [scale_b_value, scale_b_bdim] = unwrapTensorAtLevel(scale_b, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + std::optional scale_result_value; + std::optional scale_result_bdim; + if (scale_result) { + std::tie(scale_result_value, scale_result_bdim) = unwrapTensorAtLevel(scale_result.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, mat2_value, mat2_bdim, scale_a_value, scale_a_bdim, scale_b_value, scale_b_bdim, bias_value, bias_bdim, scale_result_value, scale_result_bdim, out_dtype, use_fast_accum); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _scaled_mm_v2_generated_plumbing(const at::Tensor & self, const at::Tensor & mat2, at::TensorList scale_a, at::IntArrayRef recipe_a, at::IntArrayRef swizzle_a, at::TensorList scale_b, at::IntArrayRef recipe_b, at::IntArrayRef swizzle_b, const ::std::optional & bias, ::std::optional out_dtype, at::IntArrayRef contraction_dim, bool use_fast_accum) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(mat2, cur_level) && !isBatchedAtLevel(scale_a, cur_level) && !isBatchedAtLevel(scale_b, 
cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::_scaled_mm_v2::call(self, mat2, scale_a, recipe_a, swizzle_a, scale_b, recipe_b, swizzle_b, bias, out_dtype, contraction_dim, use_fast_accum); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [mat2_value, mat2_bdim] = unwrapTensorAtLevel(mat2, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, mat2_value, mat2_bdim, scale_a, recipe_a, swizzle_a, scale_b, recipe_b, swizzle_b, bias_value, bias_bdim, out_dtype, contraction_dim, use_fast_accum); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _scaled_grouped_mm_generated_plumbing(const at::Tensor & self, const at::Tensor & mat2, const at::Tensor & scale_a, const at::Tensor & scale_b, const ::std::optional & offs, const ::std::optional & bias, const ::std::optional & scale_result, ::std::optional out_dtype, bool use_fast_accum) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(mat2, cur_level) && !isBatchedAtLevel(scale_a, cur_level) && !isBatchedAtLevel(scale_b, cur_level) && !isBatchedAtLevel(offs, cur_level) && !isBatchedAtLevel(bias, cur_level) && !isBatchedAtLevel(scale_result, cur_level)) { + return at::_ops::_scaled_grouped_mm::call(self, mat2, scale_a, scale_b, offs, bias, scale_result, out_dtype, use_fast_accum); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [mat2_value, mat2_bdim] = unwrapTensorAtLevel(mat2, cur_level); + auto [scale_a_value, scale_a_bdim] = unwrapTensorAtLevel(scale_a, cur_level); + auto [scale_b_value, 
scale_b_bdim] = unwrapTensorAtLevel(scale_b, cur_level); + std::optional offs_value; + std::optional offs_bdim; + if (offs) { + std::tie(offs_value, offs_bdim) = unwrapTensorAtLevel(offs.value(), cur_level); + } + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + std::optional scale_result_value; + std::optional scale_result_bdim; + if (scale_result) { + std::tie(scale_result_value, scale_result_bdim) = unwrapTensorAtLevel(scale_result.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, mat2_value, mat2_bdim, scale_a_value, scale_a_bdim, scale_b_value, scale_b_bdim, offs_value, offs_bdim, bias_value, bias_bdim, scale_result_value, scale_result_bdim, out_dtype, use_fast_accum); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _scaled_grouped_mm_v2_generated_plumbing(const at::Tensor & self, const at::Tensor & mat2, at::TensorList scale_a, at::IntArrayRef recipe_a, at::IntArrayRef swizzle_a, at::TensorList scale_b, at::IntArrayRef recipe_b, at::IntArrayRef swizzle_b, const ::std::optional & offs, const ::std::optional & bias, ::std::optional out_dtype, at::IntArrayRef contraction_dim, bool use_fast_accum) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(mat2, cur_level) && !isBatchedAtLevel(scale_a, cur_level) && !isBatchedAtLevel(scale_b, cur_level) && !isBatchedAtLevel(offs, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::_scaled_grouped_mm_v2::call(self, mat2, scale_a, recipe_a, swizzle_a, scale_b, recipe_b, swizzle_b, offs, bias, out_dtype, contraction_dim, use_fast_accum); + } + auto [self_value, self_bdim] = 
unwrapTensorAtLevel(self, cur_level); + auto [mat2_value, mat2_bdim] = unwrapTensorAtLevel(mat2, cur_level); + std::optional offs_value; + std::optional offs_bdim; + if (offs) { + std::tie(offs_value, offs_bdim) = unwrapTensorAtLevel(offs.value(), cur_level); + } + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, mat2_value, mat2_bdim, scale_a, recipe_a, swizzle_a, scale_b, recipe_b, swizzle_b, offs_value, offs_bdim, bias_value, bias_bdim, out_dtype, contraction_dim, use_fast_accum); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _grouped_mm_generated_plumbing(const at::Tensor & self, const at::Tensor & mat2, const ::std::optional & offs, const ::std::optional & bias, ::std::optional out_dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(mat2, cur_level) && !isBatchedAtLevel(offs, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::_grouped_mm::call(self, mat2, offs, bias, out_dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [mat2_value, mat2_bdim] = unwrapTensorAtLevel(mat2, cur_level); + std::optional offs_value; + std::optional offs_bdim; + if (offs) { + std::tie(offs_value, offs_bdim) = unwrapTensorAtLevel(offs.value(), cur_level); + } + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, mat2_value, mat2_bdim, offs_value, offs_bdim, bias_value, bias_bdim, out_dtype); + return 
makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor sparse_compressed_tensor_comp_plain_value_size_generated_plumbing(const at::Tensor & compressed_indices, const at::Tensor & plain_indices, const at::Tensor & values, c10::SymIntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(compressed_indices, cur_level) && !isBatchedAtLevel(plain_indices, cur_level) && !isBatchedAtLevel(values, cur_level)) { + return at::_ops::sparse_compressed_tensor_comp_plain_value_size::call(compressed_indices, plain_indices, values, size, dtype, layout, device, pin_memory); + } + auto [compressed_indices_value, compressed_indices_bdim] = unwrapTensorAtLevel(compressed_indices, cur_level); + auto [plain_indices_value, plain_indices_bdim] = unwrapTensorAtLevel(plain_indices, cur_level); + auto [values_value, values_bdim] = unwrapTensorAtLevel(values, cur_level); + auto results = batch_rule(compressed_indices_value, compressed_indices_bdim, plain_indices_value, plain_indices_bdim, values_value, values_bdim, size, dtype, layout, device, pin_memory); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor sparse_csr_tensor_crow_col_value_size_generated_plumbing(const at::Tensor & crow_indices, const at::Tensor & col_indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if 
(!isBatchedAtLevel(crow_indices, cur_level) && !isBatchedAtLevel(col_indices, cur_level) && !isBatchedAtLevel(values, cur_level)) { + return at::_ops::sparse_csr_tensor_crow_col_value_size::call(crow_indices, col_indices, values, size, dtype, layout, device, pin_memory); + } + auto [crow_indices_value, crow_indices_bdim] = unwrapTensorAtLevel(crow_indices, cur_level); + auto [col_indices_value, col_indices_bdim] = unwrapTensorAtLevel(col_indices, cur_level); + auto [values_value, values_bdim] = unwrapTensorAtLevel(values, cur_level); + auto results = batch_rule(crow_indices_value, crow_indices_bdim, col_indices_value, col_indices_bdim, values_value, values_bdim, size, dtype, layout, device, pin_memory); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor sparse_csc_tensor_ccol_row_value_size_generated_plumbing(const at::Tensor & ccol_indices, const at::Tensor & row_indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(ccol_indices, cur_level) && !isBatchedAtLevel(row_indices, cur_level) && !isBatchedAtLevel(values, cur_level)) { + return at::_ops::sparse_csc_tensor_ccol_row_value_size::call(ccol_indices, row_indices, values, size, dtype, layout, device, pin_memory); + } + auto [ccol_indices_value, ccol_indices_bdim] = unwrapTensorAtLevel(ccol_indices, cur_level); + auto [row_indices_value, row_indices_bdim] = unwrapTensorAtLevel(row_indices, cur_level); + auto [values_value, values_bdim] = unwrapTensorAtLevel(values, cur_level); + auto results = batch_rule(ccol_indices_value, ccol_indices_bdim, row_indices_value, row_indices_bdim, values_value, values_bdim, size, 
dtype, layout, device, pin_memory); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor sparse_bsr_tensor_crow_col_value_size_generated_plumbing(const at::Tensor & crow_indices, const at::Tensor & col_indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(crow_indices, cur_level) && !isBatchedAtLevel(col_indices, cur_level) && !isBatchedAtLevel(values, cur_level)) { + return at::_ops::sparse_bsr_tensor_crow_col_value_size::call(crow_indices, col_indices, values, size, dtype, layout, device, pin_memory); + } + auto [crow_indices_value, crow_indices_bdim] = unwrapTensorAtLevel(crow_indices, cur_level); + auto [col_indices_value, col_indices_bdim] = unwrapTensorAtLevel(col_indices, cur_level); + auto [values_value, values_bdim] = unwrapTensorAtLevel(values, cur_level); + auto results = batch_rule(crow_indices_value, crow_indices_bdim, col_indices_value, col_indices_bdim, values_value, values_bdim, size, dtype, layout, device, pin_memory); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor sparse_bsc_tensor_ccol_row_value_size_generated_plumbing(const at::Tensor & ccol_indices, const at::Tensor & row_indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(ccol_indices, 
cur_level) && !isBatchedAtLevel(row_indices, cur_level) && !isBatchedAtLevel(values, cur_level)) { + return at::_ops::sparse_bsc_tensor_ccol_row_value_size::call(ccol_indices, row_indices, values, size, dtype, layout, device, pin_memory); + } + auto [ccol_indices_value, ccol_indices_bdim] = unwrapTensorAtLevel(ccol_indices, cur_level); + auto [row_indices_value, row_indices_bdim] = unwrapTensorAtLevel(row_indices, cur_level); + auto [values_value, values_bdim] = unwrapTensorAtLevel(values, cur_level); + auto results = batch_rule(ccol_indices_value, ccol_indices_bdim, row_indices_value, row_indices_bdim, values_value, values_bdim, size, dtype, layout, device, pin_memory); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor sparse_compressed_tensor_comp_plain_value_generated_plumbing(const at::Tensor & compressed_indices, const at::Tensor & plain_indices, const at::Tensor & values, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(compressed_indices, cur_level) && !isBatchedAtLevel(plain_indices, cur_level) && !isBatchedAtLevel(values, cur_level)) { + return at::_ops::sparse_compressed_tensor_comp_plain_value::call(compressed_indices, plain_indices, values, dtype, layout, device, pin_memory); + } + auto [compressed_indices_value, compressed_indices_bdim] = unwrapTensorAtLevel(compressed_indices, cur_level); + auto [plain_indices_value, plain_indices_bdim] = unwrapTensorAtLevel(plain_indices, cur_level); + auto [values_value, values_bdim] = unwrapTensorAtLevel(values, cur_level); + auto results = batch_rule(compressed_indices_value, compressed_indices_bdim, plain_indices_value, plain_indices_bdim, values_value, values_bdim, 
dtype, layout, device, pin_memory); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor sparse_csr_tensor_crow_col_value_generated_plumbing(const at::Tensor & crow_indices, const at::Tensor & col_indices, const at::Tensor & values, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(crow_indices, cur_level) && !isBatchedAtLevel(col_indices, cur_level) && !isBatchedAtLevel(values, cur_level)) { + return at::_ops::sparse_csr_tensor_crow_col_value::call(crow_indices, col_indices, values, dtype, layout, device, pin_memory); + } + auto [crow_indices_value, crow_indices_bdim] = unwrapTensorAtLevel(crow_indices, cur_level); + auto [col_indices_value, col_indices_bdim] = unwrapTensorAtLevel(col_indices, cur_level); + auto [values_value, values_bdim] = unwrapTensorAtLevel(values, cur_level); + auto results = batch_rule(crow_indices_value, crow_indices_bdim, col_indices_value, col_indices_bdim, values_value, values_bdim, dtype, layout, device, pin_memory); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor sparse_csc_tensor_ccol_row_value_generated_plumbing(const at::Tensor & ccol_indices, const at::Tensor & row_indices, const at::Tensor & values, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(ccol_indices, cur_level) && !isBatchedAtLevel(row_indices, cur_level) && 
!isBatchedAtLevel(values, cur_level)) { + return at::_ops::sparse_csc_tensor_ccol_row_value::call(ccol_indices, row_indices, values, dtype, layout, device, pin_memory); + } + auto [ccol_indices_value, ccol_indices_bdim] = unwrapTensorAtLevel(ccol_indices, cur_level); + auto [row_indices_value, row_indices_bdim] = unwrapTensorAtLevel(row_indices, cur_level); + auto [values_value, values_bdim] = unwrapTensorAtLevel(values, cur_level); + auto results = batch_rule(ccol_indices_value, ccol_indices_bdim, row_indices_value, row_indices_bdim, values_value, values_bdim, dtype, layout, device, pin_memory); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor sparse_bsr_tensor_crow_col_value_generated_plumbing(const at::Tensor & crow_indices, const at::Tensor & col_indices, const at::Tensor & values, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(crow_indices, cur_level) && !isBatchedAtLevel(col_indices, cur_level) && !isBatchedAtLevel(values, cur_level)) { + return at::_ops::sparse_bsr_tensor_crow_col_value::call(crow_indices, col_indices, values, dtype, layout, device, pin_memory); + } + auto [crow_indices_value, crow_indices_bdim] = unwrapTensorAtLevel(crow_indices, cur_level); + auto [col_indices_value, col_indices_bdim] = unwrapTensorAtLevel(col_indices, cur_level); + auto [values_value, values_bdim] = unwrapTensorAtLevel(values, cur_level); + auto results = batch_rule(crow_indices_value, crow_indices_bdim, col_indices_value, col_indices_bdim, values_value, values_bdim, dtype, layout, device, pin_memory); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor 
sparse_bsc_tensor_ccol_row_value_generated_plumbing(const at::Tensor & ccol_indices, const at::Tensor & row_indices, const at::Tensor & values, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(ccol_indices, cur_level) && !isBatchedAtLevel(row_indices, cur_level) && !isBatchedAtLevel(values, cur_level)) { + return at::_ops::sparse_bsc_tensor_ccol_row_value::call(ccol_indices, row_indices, values, dtype, layout, device, pin_memory); + } + auto [ccol_indices_value, ccol_indices_bdim] = unwrapTensorAtLevel(ccol_indices, cur_level); + auto [row_indices_value, row_indices_bdim] = unwrapTensorAtLevel(row_indices, cur_level); + auto [values_value, values_bdim] = unwrapTensorAtLevel(values, cur_level); + auto results = batch_rule(ccol_indices_value, ccol_indices_bdim, row_indices_value, row_indices_bdim, values_value, values_bdim, dtype, layout, device, pin_memory); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _sparse_compressed_tensor_unsafe_generated_plumbing(const at::Tensor & compressed_indices, const at::Tensor & plain_indices, const at::Tensor & values, c10::SymIntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(compressed_indices, cur_level) && !isBatchedAtLevel(plain_indices, cur_level) && !isBatchedAtLevel(values, cur_level)) { + return 
at::_ops::_sparse_compressed_tensor_unsafe::call(compressed_indices, plain_indices, values, size, dtype, layout, device, pin_memory); + } + auto [compressed_indices_value, compressed_indices_bdim] = unwrapTensorAtLevel(compressed_indices, cur_level); + auto [plain_indices_value, plain_indices_bdim] = unwrapTensorAtLevel(plain_indices, cur_level); + auto [values_value, values_bdim] = unwrapTensorAtLevel(values, cur_level); + auto results = batch_rule(compressed_indices_value, compressed_indices_bdim, plain_indices_value, plain_indices_bdim, values_value, values_bdim, size, dtype, layout, device, pin_memory); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _sparse_csr_tensor_unsafe_generated_plumbing(const at::Tensor & crow_indices, const at::Tensor & col_indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(crow_indices, cur_level) && !isBatchedAtLevel(col_indices, cur_level) && !isBatchedAtLevel(values, cur_level)) { + return at::_ops::_sparse_csr_tensor_unsafe::call(crow_indices, col_indices, values, size, dtype, layout, device, pin_memory); + } + auto [crow_indices_value, crow_indices_bdim] = unwrapTensorAtLevel(crow_indices, cur_level); + auto [col_indices_value, col_indices_bdim] = unwrapTensorAtLevel(col_indices, cur_level); + auto [values_value, values_bdim] = unwrapTensorAtLevel(values, cur_level); + auto results = batch_rule(crow_indices_value, crow_indices_bdim, col_indices_value, col_indices_bdim, values_value, values_bdim, size, dtype, layout, device, pin_memory); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template 
+at::Tensor _sparse_csc_tensor_unsafe_generated_plumbing(const at::Tensor & ccol_indices, const at::Tensor & row_indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(ccol_indices, cur_level) && !isBatchedAtLevel(row_indices, cur_level) && !isBatchedAtLevel(values, cur_level)) { + return at::_ops::_sparse_csc_tensor_unsafe::call(ccol_indices, row_indices, values, size, dtype, layout, device, pin_memory); + } + auto [ccol_indices_value, ccol_indices_bdim] = unwrapTensorAtLevel(ccol_indices, cur_level); + auto [row_indices_value, row_indices_bdim] = unwrapTensorAtLevel(row_indices, cur_level); + auto [values_value, values_bdim] = unwrapTensorAtLevel(values, cur_level); + auto results = batch_rule(ccol_indices_value, ccol_indices_bdim, row_indices_value, row_indices_bdim, values_value, values_bdim, size, dtype, layout, device, pin_memory); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _sparse_bsr_tensor_unsafe_generated_plumbing(const at::Tensor & crow_indices, const at::Tensor & col_indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(crow_indices, cur_level) && !isBatchedAtLevel(col_indices, cur_level) && !isBatchedAtLevel(values, cur_level)) { + return at::_ops::_sparse_bsr_tensor_unsafe::call(crow_indices, 
col_indices, values, size, dtype, layout, device, pin_memory); + } + auto [crow_indices_value, crow_indices_bdim] = unwrapTensorAtLevel(crow_indices, cur_level); + auto [col_indices_value, col_indices_bdim] = unwrapTensorAtLevel(col_indices, cur_level); + auto [values_value, values_bdim] = unwrapTensorAtLevel(values, cur_level); + auto results = batch_rule(crow_indices_value, crow_indices_bdim, col_indices_value, col_indices_bdim, values_value, values_bdim, size, dtype, layout, device, pin_memory); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _sparse_bsc_tensor_unsafe_generated_plumbing(const at::Tensor & ccol_indices, const at::Tensor & row_indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(ccol_indices, cur_level) && !isBatchedAtLevel(row_indices, cur_level) && !isBatchedAtLevel(values, cur_level)) { + return at::_ops::_sparse_bsc_tensor_unsafe::call(ccol_indices, row_indices, values, size, dtype, layout, device, pin_memory); + } + auto [ccol_indices_value, ccol_indices_bdim] = unwrapTensorAtLevel(ccol_indices, cur_level); + auto [row_indices_value, row_indices_bdim] = unwrapTensorAtLevel(row_indices, cur_level); + auto [values_value, values_bdim] = unwrapTensorAtLevel(values, cur_level); + auto results = batch_rule(ccol_indices_value, ccol_indices_bdim, row_indices_value, row_indices_bdim, values_value, values_bdim, size, dtype, layout, device, pin_memory); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor sparse_coo_tensor_indices_generated_plumbing(const at::Tensor & indices, const at::Tensor & values, 
::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional is_coalesced) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(indices, cur_level) && !isBatchedAtLevel(values, cur_level)) { + return at::_ops::sparse_coo_tensor_indices::call(indices, values, dtype, layout, device, pin_memory, is_coalesced); + } + auto [indices_value, indices_bdim] = unwrapTensorAtLevel(indices, cur_level); + auto [values_value, values_bdim] = unwrapTensorAtLevel(values, cur_level); + auto results = batch_rule(indices_value, indices_bdim, values_value, values_bdim, dtype, layout, device, pin_memory, is_coalesced); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor sparse_coo_tensor_indices_size_generated_plumbing(const at::Tensor & indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional is_coalesced) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(indices, cur_level) && !isBatchedAtLevel(values, cur_level)) { + return at::_ops::sparse_coo_tensor_indices_size::call(indices, values, size, dtype, layout, device, pin_memory, is_coalesced); + } + auto [indices_value, indices_bdim] = unwrapTensorAtLevel(indices, cur_level); + auto [values_value, values_bdim] = unwrapTensorAtLevel(values, cur_level); + auto results = batch_rule(indices_value, indices_bdim, values_value, values_bdim, size, dtype, layout, device, pin_memory, is_coalesced); + return 
makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _sparse_coo_tensor_unsafe_generated_plumbing(const at::Tensor & indices, const at::Tensor & values, c10::SymIntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional is_coalesced) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(indices, cur_level) && !isBatchedAtLevel(values, cur_level)) { + return at::_ops::_sparse_coo_tensor_unsafe::call(indices, values, size, dtype, layout, device, pin_memory, is_coalesced); + } + auto [indices_value, indices_bdim] = unwrapTensorAtLevel(indices, cur_level); + auto [values_value, values_bdim] = unwrapTensorAtLevel(values, cur_level); + auto results = batch_rule(indices_value, indices_bdim, values_value, values_bdim, size, dtype, layout, device, pin_memory, is_coalesced); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _validate_sparse_coo_tensor_args_generated_plumbing(const at::Tensor & indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional is_coalesced, ::std::optional check_pinning) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(indices, cur_level) && !isBatchedAtLevel(values, cur_level)) { + return at::_ops::_validate_sparse_coo_tensor_args::call(indices, values, size, is_coalesced, check_pinning); + } + auto [indices_value, indices_bdim] = unwrapTensorAtLevel(indices, cur_level); + auto [values_value, values_bdim] = unwrapTensorAtLevel(values, cur_level); + batch_rule(indices_value, 
indices_bdim, values_value, values_bdim, size, is_coalesced, check_pinning); +} +template +void _validate_sparse_compressed_tensor_args_generated_plumbing(const at::Tensor & compressed_indices, const at::Tensor & plain_indices, const at::Tensor & values, at::IntArrayRef size, at::Layout layout, ::std::optional check_pinning) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(compressed_indices, cur_level) && !isBatchedAtLevel(plain_indices, cur_level) && !isBatchedAtLevel(values, cur_level)) { + return at::_ops::_validate_sparse_compressed_tensor_args::call(compressed_indices, plain_indices, values, size, layout, check_pinning); + } + auto [compressed_indices_value, compressed_indices_bdim] = unwrapTensorAtLevel(compressed_indices, cur_level); + auto [plain_indices_value, plain_indices_bdim] = unwrapTensorAtLevel(plain_indices, cur_level); + auto [values_value, values_bdim] = unwrapTensorAtLevel(values, cur_level); + batch_rule(compressed_indices_value, compressed_indices_bdim, plain_indices_value, plain_indices_bdim, values_value, values_bdim, size, layout, check_pinning); +} +template +void _validate_sparse_csr_tensor_args_generated_plumbing(const at::Tensor & crow_indices, const at::Tensor & col_indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional check_pinning) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(crow_indices, cur_level) && !isBatchedAtLevel(col_indices, cur_level) && !isBatchedAtLevel(values, cur_level)) { + return at::_ops::_validate_sparse_csr_tensor_args::call(crow_indices, col_indices, values, size, 
check_pinning); + } + auto [crow_indices_value, crow_indices_bdim] = unwrapTensorAtLevel(crow_indices, cur_level); + auto [col_indices_value, col_indices_bdim] = unwrapTensorAtLevel(col_indices, cur_level); + auto [values_value, values_bdim] = unwrapTensorAtLevel(values, cur_level); + batch_rule(crow_indices_value, crow_indices_bdim, col_indices_value, col_indices_bdim, values_value, values_bdim, size, check_pinning); +} +template +void _validate_sparse_csc_tensor_args_generated_plumbing(const at::Tensor & ccol_indices, const at::Tensor & row_indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional check_pinning) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(ccol_indices, cur_level) && !isBatchedAtLevel(row_indices, cur_level) && !isBatchedAtLevel(values, cur_level)) { + return at::_ops::_validate_sparse_csc_tensor_args::call(ccol_indices, row_indices, values, size, check_pinning); + } + auto [ccol_indices_value, ccol_indices_bdim] = unwrapTensorAtLevel(ccol_indices, cur_level); + auto [row_indices_value, row_indices_bdim] = unwrapTensorAtLevel(row_indices, cur_level); + auto [values_value, values_bdim] = unwrapTensorAtLevel(values, cur_level); + batch_rule(ccol_indices_value, ccol_indices_bdim, row_indices_value, row_indices_bdim, values_value, values_bdim, size, check_pinning); +} +template +void _validate_sparse_bsr_tensor_args_generated_plumbing(const at::Tensor & crow_indices, const at::Tensor & col_indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional check_pinning) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if 
(!isBatchedAtLevel(crow_indices, cur_level) && !isBatchedAtLevel(col_indices, cur_level) && !isBatchedAtLevel(values, cur_level)) { + return at::_ops::_validate_sparse_bsr_tensor_args::call(crow_indices, col_indices, values, size, check_pinning); + } + auto [crow_indices_value, crow_indices_bdim] = unwrapTensorAtLevel(crow_indices, cur_level); + auto [col_indices_value, col_indices_bdim] = unwrapTensorAtLevel(col_indices, cur_level); + auto [values_value, values_bdim] = unwrapTensorAtLevel(values, cur_level); + batch_rule(crow_indices_value, crow_indices_bdim, col_indices_value, col_indices_bdim, values_value, values_bdim, size, check_pinning); +} +template +void _validate_sparse_bsc_tensor_args_generated_plumbing(const at::Tensor & ccol_indices, const at::Tensor & row_indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional check_pinning) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(ccol_indices, cur_level) && !isBatchedAtLevel(row_indices, cur_level) && !isBatchedAtLevel(values, cur_level)) { + return at::_ops::_validate_sparse_bsc_tensor_args::call(ccol_indices, row_indices, values, size, check_pinning); + } + auto [ccol_indices_value, ccol_indices_bdim] = unwrapTensorAtLevel(ccol_indices, cur_level); + auto [row_indices_value, row_indices_bdim] = unwrapTensorAtLevel(row_indices, cur_level); + auto [values_value, values_bdim] = unwrapTensorAtLevel(values, cur_level); + batch_rule(ccol_indices_value, ccol_indices_bdim, row_indices_value, row_indices_bdim, values_value, values_bdim, size, check_pinning); +} +template +at::Tensor _sparse_coo_tensor_with_dims_and_tensors_generated_plumbing(int64_t sparse_dim, int64_t dense_dim, c10::SymIntArrayRef size, const at::Tensor & indices, const at::Tensor & values, ::std::optional dtype, 
::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional is_coalesced) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(indices, cur_level) && !isBatchedAtLevel(values, cur_level)) { + return at::_ops::_sparse_coo_tensor_with_dims_and_tensors::call(sparse_dim, dense_dim, size, indices, values, dtype, layout, device, pin_memory, is_coalesced); + } + auto [indices_value, indices_bdim] = unwrapTensorAtLevel(indices, cur_level); + auto [values_value, values_bdim] = unwrapTensorAtLevel(values, cur_level); + auto results = batch_rule(sparse_dim, dense_dim, size, indices_value, indices_bdim, values_value, values_bdim, dtype, layout, device, pin_memory, is_coalesced); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +const at::Tensor & sparse_resize__generated_plumbing(const at::Tensor & self, at::IntArrayRef size, int64_t sparse_dim, int64_t dense_dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::sparse_resize_::call(self, size, sparse_dim, dense_dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, size, sparse_dim, dense_dim); + return self; +} +template +const at::Tensor & sparse_resize_and_clear__generated_plumbing(const at::Tensor & self, at::IntArrayRef size, int64_t sparse_dim, int64_t dense_dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, 
"gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::sparse_resize_and_clear_::call(self, size, sparse_dim, dense_dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, size, sparse_dim, dense_dim); + return self; +} +template +at::Tensor sparse_mask_generated_plumbing(const at::Tensor & self, const at::Tensor & mask) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(mask, cur_level)) { + return at::_ops::sparse_mask::call(self, mask); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [mask_value, mask_bdim] = unwrapTensorAtLevel(mask, cur_level); + auto results = batch_rule(self_value, self_bdim, mask_value, mask_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _sparse_mask_projection_generated_plumbing(const at::Tensor & self, const at::Tensor & mask, bool accumulate_matches) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(mask, cur_level)) { + return at::_ops::_sparse_mask_projection::call(self, mask, accumulate_matches); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [mask_value, mask_bdim] = unwrapTensorAtLevel(mask, cur_level); + auto results = batch_rule(self_value, self_bdim, mask_value, mask_bdim, accumulate_matches); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template 
+::std::vector _to_cpu_generated_plumbing(at::TensorList tensors) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(tensors, cur_level)) { + return at::_ops::_to_cpu::call(tensors); + } + + auto results = batch_rule(tensors); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor to_dense_generated_plumbing(const at::Tensor & self, ::std::optional dtype, ::std::optional masked_grad) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::to_dense::call(self, dtype, masked_grad); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dtype, masked_grad); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _to_dense_generated_plumbing(const at::Tensor & self, ::std::optional dtype, ::std::optional masked_grad) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_to_dense::call(self, dtype, masked_grad); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dtype, masked_grad); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor to_dense_backward_generated_plumbing(const at::Tensor & grad, const at::Tensor & input, 
::std::optional masked_grad) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad, cur_level) && !isBatchedAtLevel(input, cur_level)) { + return at::_ops::to_dense_backward::call(grad, input, masked_grad); + } + auto [grad_value, grad_bdim] = unwrapTensorAtLevel(grad, cur_level); + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto results = batch_rule(grad_value, grad_bdim, input_value, input_bdim, masked_grad); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor coalesce_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::coalesce::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _coalesce_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_coalesce::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _indices_generated_plumbing(const at::Tensor & self) { + 
c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_indices::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _values_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_values::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & _coalesced__generated_plumbing(at::Tensor & self, bool coalesced) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_coalesced_::call(self, coalesced); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, coalesced); + return self; +} +template +at::Tensor indices_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if 
(!isBatchedAtLevel(self, cur_level)) { + return at::_ops::indices::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor values_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::values::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor crow_indices_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::crow_indices::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor col_indices_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::col_indices::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, 
self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor ccol_indices_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::ccol_indices::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor row_indices_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::row_indices::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor hspmm_generated_plumbing(const at::Tensor & mat1, const at::Tensor & mat2) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(mat1, cur_level) && !isBatchedAtLevel(mat2, cur_level)) { + return at::_ops::hspmm::call(mat1, mat2); + } + auto [mat1_value, mat1_bdim] = unwrapTensorAtLevel(mat1, cur_level); + auto [mat2_value, mat2_bdim] = unwrapTensorAtLevel(mat2, cur_level); + auto results = batch_rule(mat1_value, mat1_bdim, mat2_value, mat2_bdim); + return 
makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & copy_sparse_to_sparse__generated_plumbing(at::Tensor & self, const at::Tensor & src, bool non_blocking) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(src, cur_level)) { + return at::_ops::copy_sparse_to_sparse_::call(self, src, non_blocking); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [src_value, src_bdim] = unwrapTensorAtLevel(src, cur_level); + batch_rule(self_value, self_bdim, src_value, src_bdim, non_blocking); + return self; +} +template +::std::vector unbind_int_generated_plumbing(const at::Tensor & self, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::unbind_int::call(self, dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector unbind_Dimname_generated_plumbing(const at::Tensor & self, at::Dimname dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::unbind_Dimname::call(self, dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, 
self_bdim, dim); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor to_sparse_sparse_dim_generated_plumbing(const at::Tensor & self, int64_t sparse_dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::to_sparse_sparse_dim::call(self, sparse_dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, sparse_dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _to_sparse_sparse_dim_generated_plumbing(const at::Tensor & self, int64_t sparse_dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_to_sparse_sparse_dim::call(self, sparse_dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, sparse_dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor to_sparse_generated_plumbing(const at::Tensor & self, ::std::optional layout, at::OptionalIntArrayRef blocksize, ::std::optional dense_dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::to_sparse::call(self, layout, blocksize, dense_dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, 
cur_level); + auto results = batch_rule(self_value, self_bdim, layout, blocksize, dense_dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _to_sparse_generated_plumbing(const at::Tensor & self, ::std::optional layout, at::OptionalIntArrayRef blocksize, ::std::optional dense_dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_to_sparse::call(self, layout, blocksize, dense_dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, layout, blocksize, dense_dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor to_sparse_csr_generated_plumbing(const at::Tensor & self, ::std::optional dense_dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::to_sparse_csr::call(self, dense_dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dense_dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _to_sparse_csr_generated_plumbing(const at::Tensor & self, ::std::optional dense_dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_to_sparse_csr::call(self, 
dense_dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dense_dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor to_sparse_csc_generated_plumbing(const at::Tensor & self, ::std::optional dense_dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::to_sparse_csc::call(self, dense_dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dense_dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _to_sparse_csc_generated_plumbing(const at::Tensor & self, ::std::optional dense_dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_to_sparse_csc::call(self, dense_dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dense_dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor to_sparse_bsr_generated_plumbing(const at::Tensor & self, at::IntArrayRef blocksize, ::std::optional dense_dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::to_sparse_bsr::call(self, 
blocksize, dense_dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, blocksize, dense_dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _to_sparse_bsr_generated_plumbing(const at::Tensor & self, at::IntArrayRef blocksize, ::std::optional dense_dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_to_sparse_bsr::call(self, blocksize, dense_dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, blocksize, dense_dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor to_sparse_bsc_generated_plumbing(const at::Tensor & self, at::IntArrayRef blocksize, ::std::optional dense_dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::to_sparse_bsc::call(self, blocksize, dense_dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, blocksize, dense_dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _to_sparse_bsc_generated_plumbing(const at::Tensor & self, at::IntArrayRef blocksize, ::std::optional dense_dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = 
maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_to_sparse_bsc::call(self, blocksize, dense_dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, blocksize, dense_dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple _to_sparse_semi_structured_generated_plumbing(const at::Tensor & dense) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(dense, cur_level)) { + return at::_ops::_to_sparse_semi_structured::call(dense); + } + auto [dense_value, dense_bdim] = unwrapTensorAtLevel(dense, cur_level); + auto results = batch_rule(dense_value, dense_bdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor to_mkldnn_generated_plumbing(const at::Tensor & self, ::std::optional dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::to_mkldnn::call(self, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor mkldnn_reorder_conv2d_weight_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, at::OptionalSymIntArrayRef input_size) { + 
c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::mkldnn_reorder_conv2d_weight::call(self, padding, stride, dilation, groups, input_size); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, padding, stride, dilation, groups, input_size); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor mkldnn_reorder_conv3d_weight_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, at::OptionalSymIntArrayRef input_size) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::mkldnn_reorder_conv3d_weight::call(self, padding, stride, dilation, groups, input_size); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, padding, stride, dilation, groups, input_size); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor to_mkldnn_backward_generated_plumbing(const at::Tensor & grad, const at::Tensor & input) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad, cur_level) && !isBatchedAtLevel(input, cur_level)) { + return at::_ops::to_mkldnn_backward::call(grad, input); + } + auto 
[grad_value, grad_bdim] = unwrapTensorAtLevel(grad, cur_level); + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto results = batch_rule(grad_value, grad_bdim, input_value, input_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor quantize_per_tensor_dynamic_generated_plumbing(const at::Tensor & self, at::ScalarType dtype, bool reduce_range) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::quantize_per_tensor_dynamic::call(self, dtype, reduce_range); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dtype, reduce_range); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor quantize_per_tensor_generated_plumbing(const at::Tensor & self, double scale, int64_t zero_point, at::ScalarType dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::quantize_per_tensor::call(self, scale, zero_point, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, scale, zero_point, dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor quantize_per_tensor_tensor_qparams_generated_plumbing(const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, at::ScalarType dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto 
maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(scale, cur_level) && !isBatchedAtLevel(zero_point, cur_level)) { + return at::_ops::quantize_per_tensor_tensor_qparams::call(self, scale, zero_point, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [scale_value, scale_bdim] = unwrapTensorAtLevel(scale, cur_level); + auto [zero_point_value, zero_point_bdim] = unwrapTensorAtLevel(zero_point, cur_level); + auto results = batch_rule(self_value, self_bdim, scale_value, scale_bdim, zero_point_value, zero_point_bdim, dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector quantize_per_tensor_tensors_generated_plumbing(at::TensorList tensors, const at::Tensor & scales, const at::Tensor & zero_points, at::ScalarType dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(tensors, cur_level) && !isBatchedAtLevel(scales, cur_level) && !isBatchedAtLevel(zero_points, cur_level)) { + return at::_ops::quantize_per_tensor_tensors::call(tensors, scales, zero_points, dtype); + } + auto [scales_value, scales_bdim] = unwrapTensorAtLevel(scales, cur_level); + auto [zero_points_value, zero_points_bdim] = unwrapTensorAtLevel(zero_points, cur_level); + auto results = batch_rule(tensors, scales_value, scales_bdim, zero_points_value, zero_points_bdim, dtype); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor quantize_per_channel_generated_plumbing(const at::Tensor & self, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, at::ScalarType dtype) { + 
c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(scales, cur_level) && !isBatchedAtLevel(zero_points, cur_level)) { + return at::_ops::quantize_per_channel::call(self, scales, zero_points, axis, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [scales_value, scales_bdim] = unwrapTensorAtLevel(scales, cur_level); + auto [zero_points_value, zero_points_bdim] = unwrapTensorAtLevel(zero_points, cur_level); + auto results = batch_rule(self_value, self_bdim, scales_value, scales_bdim, zero_points_value, zero_points_bdim, axis, dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor dequantize_self_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::dequantize_self::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector dequantize_tensors_generated_plumbing(at::TensorList tensors) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(tensors, cur_level)) { + return at::_ops::dequantize_tensors::call(tensors); + } + + auto results = batch_rule(tensors); + return 
makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor q_per_channel_scales_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::q_per_channel_scales::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor q_per_channel_zero_points_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::q_per_channel_zero_points::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor int_repr_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::int_repr::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor 
_make_per_tensor_quantized_tensor_generated_plumbing(const at::Tensor & self, double scale, int64_t zero_point) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_make_per_tensor_quantized_tensor::call(self, scale, zero_point); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, scale, zero_point); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _make_per_channel_quantized_tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, int64_t axis) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(scale, cur_level) && !isBatchedAtLevel(zero_point, cur_level)) { + return at::_ops::_make_per_channel_quantized_tensor::call(self, scale, zero_point, axis); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [scale_value, scale_bdim] = unwrapTensorAtLevel(scale, cur_level); + auto [zero_point_value, zero_point_bdim] = unwrapTensorAtLevel(zero_point, cur_level); + auto results = batch_rule(self_value, self_bdim, scale_value, scale_bdim, zero_point_value, zero_point_bdim, axis); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor fake_quantize_per_tensor_affine_generated_plumbing(const at::Tensor & self, double scale, int64_t zero_point, int64_t quant_min, int64_t quant_max) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + 
auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::fake_quantize_per_tensor_affine::call(self, scale, zero_point, quant_min, quant_max); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, scale, zero_point, quant_min, quant_max); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor fake_quantize_per_tensor_affine_tensor_qparams_generated_plumbing(const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, int64_t quant_min, int64_t quant_max) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(scale, cur_level) && !isBatchedAtLevel(zero_point, cur_level)) { + return at::_ops::fake_quantize_per_tensor_affine_tensor_qparams::call(self, scale, zero_point, quant_min, quant_max); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [scale_value, scale_bdim] = unwrapTensorAtLevel(scale, cur_level); + auto [zero_point_value, zero_point_bdim] = unwrapTensorAtLevel(zero_point, cur_level); + auto results = batch_rule(self_value, self_bdim, scale_value, scale_bdim, zero_point_value, zero_point_bdim, quant_min, quant_max); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple fake_quantize_per_tensor_affine_cachemask_generated_plumbing(const at::Tensor & self, double scale, int64_t zero_point, int64_t quant_min, int64_t quant_max) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + 
vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::fake_quantize_per_tensor_affine_cachemask::call(self, scale, zero_point, quant_min, quant_max); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, scale, zero_point, quant_min, quant_max); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple _fake_quantize_per_tensor_affine_cachemask_tensor_qparams_generated_plumbing(const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, const at::Tensor & fake_quant_enabled, int64_t quant_min, int64_t quant_max) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(scale, cur_level) && !isBatchedAtLevel(zero_point, cur_level) && !isBatchedAtLevel(fake_quant_enabled, cur_level)) { + return at::_ops::_fake_quantize_per_tensor_affine_cachemask_tensor_qparams::call(self, scale, zero_point, fake_quant_enabled, quant_min, quant_max); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [scale_value, scale_bdim] = unwrapTensorAtLevel(scale, cur_level); + auto [zero_point_value, zero_point_bdim] = unwrapTensorAtLevel(zero_point, cur_level); + auto [fake_quant_enabled_value, fake_quant_enabled_bdim] = unwrapTensorAtLevel(fake_quant_enabled, cur_level); + auto results = batch_rule(self_value, self_bdim, scale_value, scale_bdim, zero_point_value, zero_point_bdim, fake_quant_enabled_value, fake_quant_enabled_bdim, quant_min, quant_max); + return 
std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor fake_quantize_per_tensor_affine_cachemask_backward_generated_plumbing(const at::Tensor & grad, const at::Tensor & mask) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad, cur_level) && !isBatchedAtLevel(mask, cur_level)) { + return at::_ops::fake_quantize_per_tensor_affine_cachemask_backward::call(grad, mask); + } + auto [grad_value, grad_bdim] = unwrapTensorAtLevel(grad, cur_level); + auto [mask_value, mask_bdim] = unwrapTensorAtLevel(mask, cur_level); + auto results = batch_rule(grad_value, grad_bdim, mask_value, mask_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _fake_quantize_learnable_per_tensor_affine_generated_plumbing(const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, int64_t quant_min, int64_t quant_max, double grad_factor) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(scale, cur_level) && !isBatchedAtLevel(zero_point, cur_level)) { + return at::_ops::_fake_quantize_learnable_per_tensor_affine::call(self, scale, zero_point, quant_min, quant_max, grad_factor); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [scale_value, scale_bdim] = unwrapTensorAtLevel(scale, cur_level); + auto [zero_point_value, zero_point_bdim] = unwrapTensorAtLevel(zero_point, cur_level); + auto results = batch_rule(self_value, self_bdim, 
scale_value, scale_bdim, zero_point_value, zero_point_bdim, quant_min, quant_max, grad_factor); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple _fake_quantize_learnable_per_tensor_affine_backward_generated_plumbing(const at::Tensor & grad, const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, int64_t quant_min, int64_t quant_max, double grad_factor) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad, cur_level) && !isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(scale, cur_level) && !isBatchedAtLevel(zero_point, cur_level)) { + return at::_ops::_fake_quantize_learnable_per_tensor_affine_backward::call(grad, self, scale, zero_point, quant_min, quant_max, grad_factor); + } + auto [grad_value, grad_bdim] = unwrapTensorAtLevel(grad, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [scale_value, scale_bdim] = unwrapTensorAtLevel(scale, cur_level); + auto [zero_point_value, zero_point_bdim] = unwrapTensorAtLevel(zero_point, cur_level); + auto results = batch_rule(grad_value, grad_bdim, self_value, self_bdim, scale_value, scale_bdim, zero_point_value, zero_point_bdim, quant_min, quant_max, grad_factor); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +at::Tensor fake_quantize_per_channel_affine_generated_plumbing(const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, int64_t axis, int64_t quant_min, int64_t quant_max) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = 
maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(scale, cur_level) && !isBatchedAtLevel(zero_point, cur_level)) { + return at::_ops::fake_quantize_per_channel_affine::call(self, scale, zero_point, axis, quant_min, quant_max); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [scale_value, scale_bdim] = unwrapTensorAtLevel(scale, cur_level); + auto [zero_point_value, zero_point_bdim] = unwrapTensorAtLevel(zero_point, cur_level); + auto results = batch_rule(self_value, self_bdim, scale_value, scale_bdim, zero_point_value, zero_point_bdim, axis, quant_min, quant_max); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple fake_quantize_per_channel_affine_cachemask_generated_plumbing(const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, int64_t axis, int64_t quant_min, int64_t quant_max) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(scale, cur_level) && !isBatchedAtLevel(zero_point, cur_level)) { + return at::_ops::fake_quantize_per_channel_affine_cachemask::call(self, scale, zero_point, axis, quant_min, quant_max); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [scale_value, scale_bdim] = unwrapTensorAtLevel(scale, cur_level); + auto [zero_point_value, zero_point_bdim] = unwrapTensorAtLevel(zero_point, cur_level); + auto results = batch_rule(self_value, self_bdim, scale_value, scale_bdim, zero_point_value, zero_point_bdim, axis, quant_min, quant_max); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), 
makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor fake_quantize_per_channel_affine_cachemask_backward_generated_plumbing(const at::Tensor & grad, const at::Tensor & mask) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad, cur_level) && !isBatchedAtLevel(mask, cur_level)) { + return at::_ops::fake_quantize_per_channel_affine_cachemask_backward::call(grad, mask); + } + auto [grad_value, grad_bdim] = unwrapTensorAtLevel(grad, cur_level); + auto [mask_value, mask_bdim] = unwrapTensorAtLevel(mask, cur_level); + auto results = batch_rule(grad_value, grad_bdim, mask_value, mask_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _fake_quantize_learnable_per_channel_affine_generated_plumbing(const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, int64_t axis, int64_t quant_min, int64_t quant_max, double grad_factor) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(scale, cur_level) && !isBatchedAtLevel(zero_point, cur_level)) { + return at::_ops::_fake_quantize_learnable_per_channel_affine::call(self, scale, zero_point, axis, quant_min, quant_max, grad_factor); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [scale_value, scale_bdim] = unwrapTensorAtLevel(scale, cur_level); + auto [zero_point_value, zero_point_bdim] = unwrapTensorAtLevel(zero_point, cur_level); + auto results = batch_rule(self_value, self_bdim, scale_value, scale_bdim, zero_point_value, zero_point_bdim, 
axis, quant_min, quant_max, grad_factor); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple _fake_quantize_learnable_per_channel_affine_backward_generated_plumbing(const at::Tensor & grad, const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, int64_t axis, int64_t quant_min, int64_t quant_max, double grad_factor) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad, cur_level) && !isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(scale, cur_level) && !isBatchedAtLevel(zero_point, cur_level)) { + return at::_ops::_fake_quantize_learnable_per_channel_affine_backward::call(grad, self, scale, zero_point, axis, quant_min, quant_max, grad_factor); + } + auto [grad_value, grad_bdim] = unwrapTensorAtLevel(grad, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [scale_value, scale_bdim] = unwrapTensorAtLevel(scale, cur_level); + auto [zero_point_value, zero_point_bdim] = unwrapTensorAtLevel(zero_point, cur_level); + auto results = batch_rule(grad_value, grad_bdim, self_value, self_bdim, scale_value, scale_bdim, zero_point_value, zero_point_bdim, axis, quant_min, quant_max, grad_factor); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +at::Tensor _saturate_weight_to_fp16_generated_plumbing(const at::Tensor & weight) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if 
(!isBatchedAtLevel(weight, cur_level)) { + return at::_ops::_saturate_weight_to_fp16::call(weight); + } + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + auto results = batch_rule(weight_value, weight_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple choose_qparams_optimized_generated_plumbing(const at::Tensor & input, int64_t numel, int64_t n_bins, double ratio, int64_t bit_width) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level)) { + return at::_ops::choose_qparams_optimized::call(input, numel, n_bins, ratio, bit_width); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto results = batch_rule(input_value, input_bdim, numel, n_bins, ratio, bit_width); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor _autocast_to_reduced_precision_generated_plumbing(const at::Tensor & self, bool cuda_enabled, bool cpu_enabled, at::ScalarType cuda_dtype, at::ScalarType cpu_dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_autocast_to_reduced_precision::call(self, cuda_enabled, cpu_enabled, cuda_dtype, cpu_dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, cuda_enabled, cpu_enabled, cuda_dtype, cpu_dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} 
+template +at::Tensor _autocast_to_full_precision_generated_plumbing(const at::Tensor & self, bool cuda_enabled, bool cpu_enabled) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_autocast_to_full_precision::call(self, cuda_enabled, cpu_enabled); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, cuda_enabled, cpu_enabled); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _to_copy_generated_plumbing(const at::Tensor & self, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, bool non_blocking, ::std::optional memory_format) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_to_copy::call(self, dtype, layout, device, pin_memory, non_blocking, memory_format); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dtype, layout, device, pin_memory, non_blocking, memory_format); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor to_dtype_layout_generated_plumbing(const at::Tensor & self, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, bool non_blocking, bool copy, ::std::optional memory_format) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, 
"gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::to_dtype_layout::call(self, dtype, layout, device, pin_memory, non_blocking, copy, memory_format); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dtype, layout, device, pin_memory, non_blocking, copy, memory_format); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor to_device_generated_plumbing(const at::Tensor & self, at::Device device, at::ScalarType dtype, bool non_blocking, bool copy, ::std::optional memory_format) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::to_device::call(self, device, dtype, non_blocking, copy, memory_format); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, device, dtype, non_blocking, copy, memory_format); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor to_dtype_generated_plumbing(const at::Tensor & self, at::ScalarType dtype, bool non_blocking, bool copy, ::std::optional memory_format) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::to_dtype::call(self, dtype, non_blocking, copy, memory_format); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dtype, non_blocking, copy, memory_format); + return 
makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor to_other_generated_plumbing(const at::Tensor & self, const at::Tensor & other, bool non_blocking, bool copy, ::std::optional memory_format) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::to_other::call(self, other, non_blocking, copy, memory_format); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim, non_blocking, copy, memory_format); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector meshgrid_generated_plumbing(at::TensorList tensors) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(tensors, cur_level)) { + return at::_ops::meshgrid::call(tensors); + } + + auto results = batch_rule(tensors); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector meshgrid_indexing_generated_plumbing(at::TensorList tensors, c10::string_view indexing) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(tensors, cur_level)) { + return at::_ops::meshgrid_indexing::call(tensors, indexing); + } + + auto results = batch_rule(tensors, indexing); 
+ return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor cartesian_prod_generated_plumbing(at::TensorList tensors) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(tensors, cur_level)) { + return at::_ops::cartesian_prod::call(tensors); + } + + auto results = batch_rule(tensors); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor combinations_generated_plumbing(const at::Tensor & self, int64_t r, bool with_replacement) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::combinations::call(self, r, with_replacement); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, r, with_replacement); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple _lstm_mps_generated_plumbing(const at::Tensor & input, at::TensorList hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional, bool batch_first) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(hx, cur_level) && !isBatchedAtLevel(params, cur_level)) { + return at::_ops::_lstm_mps::call(input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first); + } + 
auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto results = batch_rule(input_value, input_bdim, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level), makeBatched(std::get<6>(results), std::get<7>(results), cur_level), makeBatched(std::get<8>(results), std::get<9>(results), cur_level), makeBatched(std::get<10>(results), std::get<11>(results), cur_level)); +} +template +::std::tuple,::std::vector> lstm_mps_backward_generated_plumbing(const ::std::optional & grad_y, const ::std::optional & grad_hy, const ::std::optional & grad_cy, const at::Tensor & z_state, const at::Tensor & cell_state_fwd, const at::Tensor & input, const at::Tensor & layersOutputs, at::TensorList hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional, bool batch_first) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_y, cur_level) && !isBatchedAtLevel(grad_hy, cur_level) && !isBatchedAtLevel(grad_cy, cur_level) && !isBatchedAtLevel(z_state, cur_level) && !isBatchedAtLevel(cell_state_fwd, cur_level) && !isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(layersOutputs, cur_level) && !isBatchedAtLevel(hx, cur_level) && !isBatchedAtLevel(params, cur_level)) { + return at::_ops::lstm_mps_backward::call(grad_y, grad_hy, grad_cy, z_state, cell_state_fwd, input, layersOutputs, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first); + } + auto [z_state_value, z_state_bdim] = unwrapTensorAtLevel(z_state, cur_level); + auto 
[cell_state_fwd_value, cell_state_fwd_bdim] = unwrapTensorAtLevel(cell_state_fwd, cur_level); + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [layersOutputs_value, layersOutputs_bdim] = unwrapTensorAtLevel(layersOutputs, cur_level); + std::optional grad_y_value; + std::optional grad_y_bdim; + if (grad_y) { + std::tie(grad_y_value, grad_y_bdim) = unwrapTensorAtLevel(grad_y.value(), cur_level); + } + std::optional grad_hy_value; + std::optional grad_hy_bdim; + if (grad_hy) { + std::tie(grad_hy_value, grad_hy_bdim) = unwrapTensorAtLevel(grad_hy.value(), cur_level); + } + std::optional grad_cy_value; + std::optional grad_cy_bdim; + if (grad_cy) { + std::tie(grad_cy_value, grad_cy_bdim) = unwrapTensorAtLevel(grad_cy.value(), cur_level); + } + auto results = batch_rule(grad_y_value, grad_y_bdim, grad_hy_value, grad_hy_bdim, grad_cy_value, grad_cy_bdim, z_state_value, z_state_bdim, cell_state_fwd_value, cell_state_fwd_bdim, input_value, input_bdim, layersOutputs_value, layersOutputs_bdim, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatchedVector(std::get<2>(results), std::get<3>(results), cur_level), makeBatchedVector(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +::std::tuple _thnn_fused_lstm_cell_generated_plumbing(const at::Tensor & input_gates, const at::Tensor & hidden_gates, const at::Tensor & cx, const ::std::optional & input_bias, const ::std::optional & hidden_bias) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input_gates, cur_level) && !isBatchedAtLevel(hidden_gates, cur_level) && !isBatchedAtLevel(cx, cur_level) && !isBatchedAtLevel(input_bias, cur_level) && 
!isBatchedAtLevel(hidden_bias, cur_level)) { + return at::_ops::_thnn_fused_lstm_cell::call(input_gates, hidden_gates, cx, input_bias, hidden_bias); + } + auto [input_gates_value, input_gates_bdim] = unwrapTensorAtLevel(input_gates, cur_level); + auto [hidden_gates_value, hidden_gates_bdim] = unwrapTensorAtLevel(hidden_gates, cur_level); + auto [cx_value, cx_bdim] = unwrapTensorAtLevel(cx, cur_level); + std::optional input_bias_value; + std::optional input_bias_bdim; + if (input_bias) { + std::tie(input_bias_value, input_bias_bdim) = unwrapTensorAtLevel(input_bias.value(), cur_level); + } + std::optional hidden_bias_value; + std::optional hidden_bias_bdim; + if (hidden_bias) { + std::tie(hidden_bias_value, hidden_bias_bdim) = unwrapTensorAtLevel(hidden_bias.value(), cur_level); + } + auto results = batch_rule(input_gates_value, input_gates_bdim, hidden_gates_value, hidden_gates_bdim, cx_value, cx_bdim, input_bias_value, input_bias_bdim, hidden_bias_value, hidden_bias_bdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +::std::tuple _thnn_fused_lstm_cell_backward_impl_generated_plumbing(const ::std::optional & grad_hy, const ::std::optional & grad_cy, const at::Tensor & cx, const at::Tensor & cy, const at::Tensor & workspace, bool has_bias) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_hy, cur_level) && !isBatchedAtLevel(grad_cy, cur_level) && !isBatchedAtLevel(cx, cur_level) && !isBatchedAtLevel(cy, cur_level) && !isBatchedAtLevel(workspace, cur_level)) { + return at::_ops::_thnn_fused_lstm_cell_backward_impl::call(grad_hy, grad_cy, cx, cy, workspace, has_bias); + 
} + auto [cx_value, cx_bdim] = unwrapTensorAtLevel(cx, cur_level); + auto [cy_value, cy_bdim] = unwrapTensorAtLevel(cy, cur_level); + auto [workspace_value, workspace_bdim] = unwrapTensorAtLevel(workspace, cur_level); + std::optional grad_hy_value; + std::optional grad_hy_bdim; + if (grad_hy) { + std::tie(grad_hy_value, grad_hy_bdim) = unwrapTensorAtLevel(grad_hy.value(), cur_level); + } + std::optional grad_cy_value; + std::optional grad_cy_bdim; + if (grad_cy) { + std::tie(grad_cy_value, grad_cy_bdim) = unwrapTensorAtLevel(grad_cy.value(), cur_level); + } + auto results = batch_rule(grad_hy_value, grad_hy_bdim, grad_cy_value, grad_cy_bdim, cx_value, cx_bdim, cy_value, cy_bdim, workspace_value, workspace_bdim, has_bias); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +::std::tuple _thnn_fused_lstm_cell_backward_generated_plumbing(const ::std::optional & grad_hy, const ::std::optional & grad_cy, const at::Tensor & cx, const at::Tensor & cy, const at::Tensor & workspace, bool has_bias) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_hy, cur_level) && !isBatchedAtLevel(grad_cy, cur_level) && !isBatchedAtLevel(cx, cur_level) && !isBatchedAtLevel(cy, cur_level) && !isBatchedAtLevel(workspace, cur_level)) { + return at::_ops::_thnn_fused_lstm_cell_backward::call(grad_hy, grad_cy, cx, cy, workspace, has_bias); + } + auto [cx_value, cx_bdim] = unwrapTensorAtLevel(cx, cur_level); + auto [cy_value, cy_bdim] = unwrapTensorAtLevel(cy, cur_level); + auto [workspace_value, workspace_bdim] = unwrapTensorAtLevel(workspace, cur_level); + std::optional grad_hy_value; + std::optional 
grad_hy_bdim; + if (grad_hy) { + std::tie(grad_hy_value, grad_hy_bdim) = unwrapTensorAtLevel(grad_hy.value(), cur_level); + } + std::optional grad_cy_value; + std::optional grad_cy_bdim; + if (grad_cy) { + std::tie(grad_cy_value, grad_cy_bdim) = unwrapTensorAtLevel(grad_cy.value(), cur_level); + } + auto results = batch_rule(grad_hy_value, grad_hy_bdim, grad_cy_value, grad_cy_bdim, cx_value, cx_bdim, cy_value, cy_bdim, workspace_value, workspace_bdim, has_bias); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level), makeBatched(std::get<6>(results), std::get<7>(results), cur_level), makeBatched(std::get<8>(results), std::get<9>(results), cur_level)); +} +template +::std::tuple _thnn_differentiable_lstm_cell_backward_generated_plumbing(const ::std::optional & grad_hy, const ::std::optional & grad_cy, const at::Tensor & input_gates, const at::Tensor & hidden_gates, const ::std::optional & input_bias, const ::std::optional & hidden_bias, const at::Tensor & cx, const at::Tensor & cy) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_hy, cur_level) && !isBatchedAtLevel(grad_cy, cur_level) && !isBatchedAtLevel(input_gates, cur_level) && !isBatchedAtLevel(hidden_gates, cur_level) && !isBatchedAtLevel(input_bias, cur_level) && !isBatchedAtLevel(hidden_bias, cur_level) && !isBatchedAtLevel(cx, cur_level) && !isBatchedAtLevel(cy, cur_level)) { + return at::_ops::_thnn_differentiable_lstm_cell_backward::call(grad_hy, grad_cy, input_gates, hidden_gates, input_bias, hidden_bias, cx, cy); + } + auto [input_gates_value, input_gates_bdim] = unwrapTensorAtLevel(input_gates, cur_level); + auto 
[hidden_gates_value, hidden_gates_bdim] = unwrapTensorAtLevel(hidden_gates, cur_level); + auto [cx_value, cx_bdim] = unwrapTensorAtLevel(cx, cur_level); + auto [cy_value, cy_bdim] = unwrapTensorAtLevel(cy, cur_level); + std::optional grad_hy_value; + std::optional grad_hy_bdim; + if (grad_hy) { + std::tie(grad_hy_value, grad_hy_bdim) = unwrapTensorAtLevel(grad_hy.value(), cur_level); + } + std::optional grad_cy_value; + std::optional grad_cy_bdim; + if (grad_cy) { + std::tie(grad_cy_value, grad_cy_bdim) = unwrapTensorAtLevel(grad_cy.value(), cur_level); + } + std::optional input_bias_value; + std::optional input_bias_bdim; + if (input_bias) { + std::tie(input_bias_value, input_bias_bdim) = unwrapTensorAtLevel(input_bias.value(), cur_level); + } + std::optional hidden_bias_value; + std::optional hidden_bias_bdim; + if (hidden_bias) { + std::tie(hidden_bias_value, hidden_bias_bdim) = unwrapTensorAtLevel(hidden_bias.value(), cur_level); + } + auto results = batch_rule(grad_hy_value, grad_hy_bdim, grad_cy_value, grad_cy_bdim, input_gates_value, input_gates_bdim, hidden_gates_value, hidden_gates_bdim, input_bias_value, input_bias_bdim, hidden_bias_value, hidden_bias_bdim, cx_value, cx_bdim, cy_value, cy_bdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level), makeBatched(std::get<6>(results), std::get<7>(results), cur_level), makeBatched(std::get<8>(results), std::get<9>(results), cur_level)); +} +template +::std::tuple _thnn_fused_gru_cell_generated_plumbing(const at::Tensor & input_gates, const at::Tensor & hidden_gates, const at::Tensor & hx, const ::std::optional & input_bias, const ::std::optional & hidden_bias) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, 
"gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input_gates, cur_level) && !isBatchedAtLevel(hidden_gates, cur_level) && !isBatchedAtLevel(hx, cur_level) && !isBatchedAtLevel(input_bias, cur_level) && !isBatchedAtLevel(hidden_bias, cur_level)) { + return at::_ops::_thnn_fused_gru_cell::call(input_gates, hidden_gates, hx, input_bias, hidden_bias); + } + auto [input_gates_value, input_gates_bdim] = unwrapTensorAtLevel(input_gates, cur_level); + auto [hidden_gates_value, hidden_gates_bdim] = unwrapTensorAtLevel(hidden_gates, cur_level); + auto [hx_value, hx_bdim] = unwrapTensorAtLevel(hx, cur_level); + std::optional input_bias_value; + std::optional input_bias_bdim; + if (input_bias) { + std::tie(input_bias_value, input_bias_bdim) = unwrapTensorAtLevel(input_bias.value(), cur_level); + } + std::optional hidden_bias_value; + std::optional hidden_bias_bdim; + if (hidden_bias) { + std::tie(hidden_bias_value, hidden_bias_bdim) = unwrapTensorAtLevel(hidden_bias.value(), cur_level); + } + auto results = batch_rule(input_gates_value, input_gates_bdim, hidden_gates_value, hidden_gates_bdim, hx_value, hx_bdim, input_bias_value, input_bias_bdim, hidden_bias_value, hidden_bias_bdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple _thnn_fused_gru_cell_backward_generated_plumbing(const at::Tensor & grad_hy, const at::Tensor & workspace, bool has_bias) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_hy, cur_level) && !isBatchedAtLevel(workspace, cur_level)) { + return at::_ops::_thnn_fused_gru_cell_backward::call(grad_hy, workspace, has_bias); + } + auto [grad_hy_value, grad_hy_bdim] = 
unwrapTensorAtLevel(grad_hy, cur_level); + auto [workspace_value, workspace_bdim] = unwrapTensorAtLevel(workspace, cur_level); + auto results = batch_rule(grad_hy_value, grad_hy_bdim, workspace_value, workspace_bdim, has_bias); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level), makeBatched(std::get<6>(results), std::get<7>(results), cur_level), makeBatched(std::get<8>(results), std::get<9>(results), cur_level)); +} +template +::std::tuple _thnn_differentiable_gru_cell_backward_generated_plumbing(const at::Tensor & grad_hy, const at::Tensor & input_gates, const at::Tensor & hidden_gates, const at::Tensor & hx, const ::std::optional & input_bias, const ::std::optional & hidden_bias) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_hy, cur_level) && !isBatchedAtLevel(input_gates, cur_level) && !isBatchedAtLevel(hidden_gates, cur_level) && !isBatchedAtLevel(hx, cur_level) && !isBatchedAtLevel(input_bias, cur_level) && !isBatchedAtLevel(hidden_bias, cur_level)) { + return at::_ops::_thnn_differentiable_gru_cell_backward::call(grad_hy, input_gates, hidden_gates, hx, input_bias, hidden_bias); + } + auto [grad_hy_value, grad_hy_bdim] = unwrapTensorAtLevel(grad_hy, cur_level); + auto [input_gates_value, input_gates_bdim] = unwrapTensorAtLevel(input_gates, cur_level); + auto [hidden_gates_value, hidden_gates_bdim] = unwrapTensorAtLevel(hidden_gates, cur_level); + auto [hx_value, hx_bdim] = unwrapTensorAtLevel(hx, cur_level); + std::optional input_bias_value; + std::optional input_bias_bdim; + if (input_bias) { + std::tie(input_bias_value, input_bias_bdim) = unwrapTensorAtLevel(input_bias.value(), 
cur_level); + } + std::optional hidden_bias_value; + std::optional hidden_bias_bdim; + if (hidden_bias) { + std::tie(hidden_bias_value, hidden_bias_bdim) = unwrapTensorAtLevel(hidden_bias.value(), cur_level); + } + auto results = batch_rule(grad_hy_value, grad_hy_bdim, input_gates_value, input_gates_bdim, hidden_gates_value, hidden_gates_bdim, hx_value, hx_bdim, input_bias_value, input_bias_bdim, hidden_bias_value, hidden_bias_bdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level), makeBatched(std::get<6>(results), std::get<7>(results), cur_level), makeBatched(std::get<8>(results), std::get<9>(results), cur_level)); +} +template +::std::tuple lstm_input_generated_plumbing(const at::Tensor & input, at::TensorList hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional, bool batch_first) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(hx, cur_level) && !isBatchedAtLevel(params, cur_level)) { + return at::_ops::lstm_input::call(input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto results = batch_rule(input_value, input_bdim, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +::std::tuple 
lstm_data_generated_plumbing(const at::Tensor & data, const at::Tensor & batch_sizes, at::TensorList hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(data, cur_level) && !isBatchedAtLevel(batch_sizes, cur_level) && !isBatchedAtLevel(hx, cur_level) && !isBatchedAtLevel(params, cur_level)) { + return at::_ops::lstm_data::call(data, batch_sizes, hx, params, has_biases, num_layers, dropout, train, bidirectional); + } + auto [data_value, data_bdim] = unwrapTensorAtLevel(data, cur_level); + auto [batch_sizes_value, batch_sizes_bdim] = unwrapTensorAtLevel(batch_sizes, cur_level); + auto results = batch_rule(data_value, data_bdim, batch_sizes_value, batch_sizes_bdim, hx, params, has_biases, num_layers, dropout, train, bidirectional); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +::std::tuple gru_input_generated_plumbing(const at::Tensor & input, const at::Tensor & hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional, bool batch_first) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(hx, cur_level) && !isBatchedAtLevel(params, cur_level)) { + return at::_ops::gru_input::call(input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first); + } + auto 
[input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [hx_value, hx_bdim] = unwrapTensorAtLevel(hx, cur_level); + auto results = batch_rule(input_value, input_bdim, hx_value, hx_bdim, params, has_biases, num_layers, dropout, train, bidirectional, batch_first); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple gru_data_generated_plumbing(const at::Tensor & data, const at::Tensor & batch_sizes, const at::Tensor & hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(data, cur_level) && !isBatchedAtLevel(batch_sizes, cur_level) && !isBatchedAtLevel(hx, cur_level) && !isBatchedAtLevel(params, cur_level)) { + return at::_ops::gru_data::call(data, batch_sizes, hx, params, has_biases, num_layers, dropout, train, bidirectional); + } + auto [data_value, data_bdim] = unwrapTensorAtLevel(data, cur_level); + auto [batch_sizes_value, batch_sizes_bdim] = unwrapTensorAtLevel(batch_sizes, cur_level); + auto [hx_value, hx_bdim] = unwrapTensorAtLevel(hx, cur_level); + auto results = batch_rule(data_value, data_bdim, batch_sizes_value, batch_sizes_bdim, hx_value, hx_bdim, params, has_biases, num_layers, dropout, train, bidirectional); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple rnn_tanh_input_generated_plumbing(const at::Tensor & input, const at::Tensor & hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional, bool 
batch_first) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(hx, cur_level) && !isBatchedAtLevel(params, cur_level)) { + return at::_ops::rnn_tanh_input::call(input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [hx_value, hx_bdim] = unwrapTensorAtLevel(hx, cur_level); + auto results = batch_rule(input_value, input_bdim, hx_value, hx_bdim, params, has_biases, num_layers, dropout, train, bidirectional, batch_first); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple rnn_tanh_data_generated_plumbing(const at::Tensor & data, const at::Tensor & batch_sizes, const at::Tensor & hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(data, cur_level) && !isBatchedAtLevel(batch_sizes, cur_level) && !isBatchedAtLevel(hx, cur_level) && !isBatchedAtLevel(params, cur_level)) { + return at::_ops::rnn_tanh_data::call(data, batch_sizes, hx, params, has_biases, num_layers, dropout, train, bidirectional); + } + auto [data_value, data_bdim] = unwrapTensorAtLevel(data, cur_level); + auto [batch_sizes_value, batch_sizes_bdim] = unwrapTensorAtLevel(batch_sizes, cur_level); + auto [hx_value, hx_bdim] = unwrapTensorAtLevel(hx, cur_level); + auto results = batch_rule(data_value, data_bdim, 
batch_sizes_value, batch_sizes_bdim, hx_value, hx_bdim, params, has_biases, num_layers, dropout, train, bidirectional); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple rnn_relu_input_generated_plumbing(const at::Tensor & input, const at::Tensor & hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional, bool batch_first) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(hx, cur_level) && !isBatchedAtLevel(params, cur_level)) { + return at::_ops::rnn_relu_input::call(input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [hx_value, hx_bdim] = unwrapTensorAtLevel(hx, cur_level); + auto results = batch_rule(input_value, input_bdim, hx_value, hx_bdim, params, has_biases, num_layers, dropout, train, bidirectional, batch_first); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple rnn_relu_data_generated_plumbing(const at::Tensor & data, const at::Tensor & batch_sizes, const at::Tensor & hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(data, cur_level) && !isBatchedAtLevel(batch_sizes, 
cur_level) && !isBatchedAtLevel(hx, cur_level) && !isBatchedAtLevel(params, cur_level)) { + return at::_ops::rnn_relu_data::call(data, batch_sizes, hx, params, has_biases, num_layers, dropout, train, bidirectional); + } + auto [data_value, data_bdim] = unwrapTensorAtLevel(data, cur_level); + auto [batch_sizes_value, batch_sizes_bdim] = unwrapTensorAtLevel(batch_sizes, cur_level); + auto [hx_value, hx_bdim] = unwrapTensorAtLevel(hx, cur_level); + auto results = batch_rule(data_value, data_bdim, batch_sizes_value, batch_sizes_bdim, hx_value, hx_bdim, params, has_biases, num_layers, dropout, train, bidirectional); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple lstm_cell_generated_plumbing(const at::Tensor & input, at::TensorList hx, const at::Tensor & w_ih, const at::Tensor & w_hh, const ::std::optional & b_ih, const ::std::optional & b_hh) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(hx, cur_level) && !isBatchedAtLevel(w_ih, cur_level) && !isBatchedAtLevel(w_hh, cur_level) && !isBatchedAtLevel(b_ih, cur_level) && !isBatchedAtLevel(b_hh, cur_level)) { + return at::_ops::lstm_cell::call(input, hx, w_ih, w_hh, b_ih, b_hh); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [w_ih_value, w_ih_bdim] = unwrapTensorAtLevel(w_ih, cur_level); + auto [w_hh_value, w_hh_bdim] = unwrapTensorAtLevel(w_hh, cur_level); + std::optional b_ih_value; + std::optional b_ih_bdim; + if (b_ih) { + std::tie(b_ih_value, b_ih_bdim) = unwrapTensorAtLevel(b_ih.value(), cur_level); + } + std::optional b_hh_value; + std::optional b_hh_bdim; + if (b_hh) { + std::tie(b_hh_value, 
b_hh_bdim) = unwrapTensorAtLevel(b_hh.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, hx, w_ih_value, w_ih_bdim, w_hh_value, w_hh_bdim, b_ih_value, b_ih_bdim, b_hh_value, b_hh_bdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor gru_cell_generated_plumbing(const at::Tensor & input, const at::Tensor & hx, const at::Tensor & w_ih, const at::Tensor & w_hh, const ::std::optional & b_ih, const ::std::optional & b_hh) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(hx, cur_level) && !isBatchedAtLevel(w_ih, cur_level) && !isBatchedAtLevel(w_hh, cur_level) && !isBatchedAtLevel(b_ih, cur_level) && !isBatchedAtLevel(b_hh, cur_level)) { + return at::_ops::gru_cell::call(input, hx, w_ih, w_hh, b_ih, b_hh); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [hx_value, hx_bdim] = unwrapTensorAtLevel(hx, cur_level); + auto [w_ih_value, w_ih_bdim] = unwrapTensorAtLevel(w_ih, cur_level); + auto [w_hh_value, w_hh_bdim] = unwrapTensorAtLevel(w_hh, cur_level); + std::optional b_ih_value; + std::optional b_ih_bdim; + if (b_ih) { + std::tie(b_ih_value, b_ih_bdim) = unwrapTensorAtLevel(b_ih.value(), cur_level); + } + std::optional b_hh_value; + std::optional b_hh_bdim; + if (b_hh) { + std::tie(b_hh_value, b_hh_bdim) = unwrapTensorAtLevel(b_hh.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, hx_value, hx_bdim, w_ih_value, w_ih_bdim, w_hh_value, w_hh_bdim, b_ih_value, b_ih_bdim, b_hh_value, b_hh_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor 
rnn_tanh_cell_generated_plumbing(const at::Tensor & input, const at::Tensor & hx, const at::Tensor & w_ih, const at::Tensor & w_hh, const ::std::optional & b_ih, const ::std::optional & b_hh) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(hx, cur_level) && !isBatchedAtLevel(w_ih, cur_level) && !isBatchedAtLevel(w_hh, cur_level) && !isBatchedAtLevel(b_ih, cur_level) && !isBatchedAtLevel(b_hh, cur_level)) { + return at::_ops::rnn_tanh_cell::call(input, hx, w_ih, w_hh, b_ih, b_hh); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [hx_value, hx_bdim] = unwrapTensorAtLevel(hx, cur_level); + auto [w_ih_value, w_ih_bdim] = unwrapTensorAtLevel(w_ih, cur_level); + auto [w_hh_value, w_hh_bdim] = unwrapTensorAtLevel(w_hh, cur_level); + std::optional b_ih_value; + std::optional b_ih_bdim; + if (b_ih) { + std::tie(b_ih_value, b_ih_bdim) = unwrapTensorAtLevel(b_ih.value(), cur_level); + } + std::optional b_hh_value; + std::optional b_hh_bdim; + if (b_hh) { + std::tie(b_hh_value, b_hh_bdim) = unwrapTensorAtLevel(b_hh.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, hx_value, hx_bdim, w_ih_value, w_ih_bdim, w_hh_value, w_hh_bdim, b_ih_value, b_ih_bdim, b_hh_value, b_hh_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor rnn_relu_cell_generated_plumbing(const at::Tensor & input, const at::Tensor & hx, const at::Tensor & w_ih, const at::Tensor & w_hh, const ::std::optional & b_ih, const ::std::optional & b_hh) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = 
maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(hx, cur_level) && !isBatchedAtLevel(w_ih, cur_level) && !isBatchedAtLevel(w_hh, cur_level) && !isBatchedAtLevel(b_ih, cur_level) && !isBatchedAtLevel(b_hh, cur_level)) { + return at::_ops::rnn_relu_cell::call(input, hx, w_ih, w_hh, b_ih, b_hh); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [hx_value, hx_bdim] = unwrapTensorAtLevel(hx, cur_level); + auto [w_ih_value, w_ih_bdim] = unwrapTensorAtLevel(w_ih, cur_level); + auto [w_hh_value, w_hh_bdim] = unwrapTensorAtLevel(w_hh, cur_level); + std::optional b_ih_value; + std::optional b_ih_bdim; + if (b_ih) { + std::tie(b_ih_value, b_ih_bdim) = unwrapTensorAtLevel(b_ih.value(), cur_level); + } + std::optional b_hh_value; + std::optional b_hh_bdim; + if (b_hh) { + std::tie(b_hh_value, b_hh_bdim) = unwrapTensorAtLevel(b_hh.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, hx_value, hx_bdim, w_ih_value, w_ih_bdim, w_hh_value, w_hh_bdim, b_ih_value, b_ih_bdim, b_hh_value, b_hh_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple quantized_lstm_cell_generated_plumbing(const at::Tensor & input, at::TensorList hx, const at::Tensor & w_ih, const at::Tensor & w_hh, const at::Tensor & b_ih, const at::Tensor & b_hh, const at::Tensor & packed_ih, const at::Tensor & packed_hh, const at::Tensor & col_offsets_ih, const at::Tensor & col_offsets_hh, const at::Scalar & scale_ih, const at::Scalar & scale_hh, const at::Scalar & zero_point_ih, const at::Scalar & zero_point_hh) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(hx, cur_level) && !isBatchedAtLevel(w_ih, cur_level) && 
!isBatchedAtLevel(w_hh, cur_level) && !isBatchedAtLevel(b_ih, cur_level) && !isBatchedAtLevel(b_hh, cur_level) && !isBatchedAtLevel(packed_ih, cur_level) && !isBatchedAtLevel(packed_hh, cur_level) && !isBatchedAtLevel(col_offsets_ih, cur_level) && !isBatchedAtLevel(col_offsets_hh, cur_level)) { + return at::_ops::quantized_lstm_cell::call(input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih, col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [w_ih_value, w_ih_bdim] = unwrapTensorAtLevel(w_ih, cur_level); + auto [w_hh_value, w_hh_bdim] = unwrapTensorAtLevel(w_hh, cur_level); + auto [b_ih_value, b_ih_bdim] = unwrapTensorAtLevel(b_ih, cur_level); + auto [b_hh_value, b_hh_bdim] = unwrapTensorAtLevel(b_hh, cur_level); + auto [packed_ih_value, packed_ih_bdim] = unwrapTensorAtLevel(packed_ih, cur_level); + auto [packed_hh_value, packed_hh_bdim] = unwrapTensorAtLevel(packed_hh, cur_level); + auto [col_offsets_ih_value, col_offsets_ih_bdim] = unwrapTensorAtLevel(col_offsets_ih, cur_level); + auto [col_offsets_hh_value, col_offsets_hh_bdim] = unwrapTensorAtLevel(col_offsets_hh, cur_level); + auto results = batch_rule(input_value, input_bdim, hx, w_ih_value, w_ih_bdim, w_hh_value, w_hh_bdim, b_ih_value, b_ih_bdim, b_hh_value, b_hh_bdim, packed_ih_value, packed_ih_bdim, packed_hh_value, packed_hh_bdim, col_offsets_ih_value, col_offsets_ih_bdim, col_offsets_hh_value, col_offsets_hh_bdim, scale_ih, scale_hh, zero_point_ih, zero_point_hh); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor quantized_gru_cell_generated_plumbing(const at::Tensor & input, const at::Tensor & hx, const at::Tensor & w_ih, const at::Tensor & w_hh, const at::Tensor & b_ih, const at::Tensor & b_hh, const at::Tensor & packed_ih, const at::Tensor & 
packed_hh, const at::Tensor & col_offsets_ih, const at::Tensor & col_offsets_hh, const at::Scalar & scale_ih, const at::Scalar & scale_hh, const at::Scalar & zero_point_ih, const at::Scalar & zero_point_hh) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(hx, cur_level) && !isBatchedAtLevel(w_ih, cur_level) && !isBatchedAtLevel(w_hh, cur_level) && !isBatchedAtLevel(b_ih, cur_level) && !isBatchedAtLevel(b_hh, cur_level) && !isBatchedAtLevel(packed_ih, cur_level) && !isBatchedAtLevel(packed_hh, cur_level) && !isBatchedAtLevel(col_offsets_ih, cur_level) && !isBatchedAtLevel(col_offsets_hh, cur_level)) { + return at::_ops::quantized_gru_cell::call(input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih, col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [hx_value, hx_bdim] = unwrapTensorAtLevel(hx, cur_level); + auto [w_ih_value, w_ih_bdim] = unwrapTensorAtLevel(w_ih, cur_level); + auto [w_hh_value, w_hh_bdim] = unwrapTensorAtLevel(w_hh, cur_level); + auto [b_ih_value, b_ih_bdim] = unwrapTensorAtLevel(b_ih, cur_level); + auto [b_hh_value, b_hh_bdim] = unwrapTensorAtLevel(b_hh, cur_level); + auto [packed_ih_value, packed_ih_bdim] = unwrapTensorAtLevel(packed_ih, cur_level); + auto [packed_hh_value, packed_hh_bdim] = unwrapTensorAtLevel(packed_hh, cur_level); + auto [col_offsets_ih_value, col_offsets_ih_bdim] = unwrapTensorAtLevel(col_offsets_ih, cur_level); + auto [col_offsets_hh_value, col_offsets_hh_bdim] = unwrapTensorAtLevel(col_offsets_hh, cur_level); + auto results = batch_rule(input_value, input_bdim, hx_value, hx_bdim, w_ih_value, w_ih_bdim, w_hh_value, w_hh_bdim, b_ih_value, b_ih_bdim, b_hh_value, b_hh_bdim, 
packed_ih_value, packed_ih_bdim, packed_hh_value, packed_hh_bdim, col_offsets_ih_value, col_offsets_ih_bdim, col_offsets_hh_value, col_offsets_hh_bdim, scale_ih, scale_hh, zero_point_ih, zero_point_hh); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor quantized_rnn_relu_cell_generated_plumbing(const at::Tensor & input, const at::Tensor & hx, const at::Tensor & w_ih, const at::Tensor & w_hh, const at::Tensor & b_ih, const at::Tensor & b_hh, const at::Tensor & packed_ih, const at::Tensor & packed_hh, const at::Tensor & col_offsets_ih, const at::Tensor & col_offsets_hh, const at::Scalar & scale_ih, const at::Scalar & scale_hh, const at::Scalar & zero_point_ih, const at::Scalar & zero_point_hh) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(hx, cur_level) && !isBatchedAtLevel(w_ih, cur_level) && !isBatchedAtLevel(w_hh, cur_level) && !isBatchedAtLevel(b_ih, cur_level) && !isBatchedAtLevel(b_hh, cur_level) && !isBatchedAtLevel(packed_ih, cur_level) && !isBatchedAtLevel(packed_hh, cur_level) && !isBatchedAtLevel(col_offsets_ih, cur_level) && !isBatchedAtLevel(col_offsets_hh, cur_level)) { + return at::_ops::quantized_rnn_relu_cell::call(input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih, col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [hx_value, hx_bdim] = unwrapTensorAtLevel(hx, cur_level); + auto [w_ih_value, w_ih_bdim] = unwrapTensorAtLevel(w_ih, cur_level); + auto [w_hh_value, w_hh_bdim] = unwrapTensorAtLevel(w_hh, cur_level); + auto [b_ih_value, b_ih_bdim] = unwrapTensorAtLevel(b_ih, cur_level); + auto [b_hh_value, b_hh_bdim] = unwrapTensorAtLevel(b_hh, 
cur_level); + auto [packed_ih_value, packed_ih_bdim] = unwrapTensorAtLevel(packed_ih, cur_level); + auto [packed_hh_value, packed_hh_bdim] = unwrapTensorAtLevel(packed_hh, cur_level); + auto [col_offsets_ih_value, col_offsets_ih_bdim] = unwrapTensorAtLevel(col_offsets_ih, cur_level); + auto [col_offsets_hh_value, col_offsets_hh_bdim] = unwrapTensorAtLevel(col_offsets_hh, cur_level); + auto results = batch_rule(input_value, input_bdim, hx_value, hx_bdim, w_ih_value, w_ih_bdim, w_hh_value, w_hh_bdim, b_ih_value, b_ih_bdim, b_hh_value, b_hh_bdim, packed_ih_value, packed_ih_bdim, packed_hh_value, packed_hh_bdim, col_offsets_ih_value, col_offsets_ih_bdim, col_offsets_hh_value, col_offsets_hh_bdim, scale_ih, scale_hh, zero_point_ih, zero_point_hh); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor quantized_rnn_tanh_cell_generated_plumbing(const at::Tensor & input, const at::Tensor & hx, const at::Tensor & w_ih, const at::Tensor & w_hh, const at::Tensor & b_ih, const at::Tensor & b_hh, const at::Tensor & packed_ih, const at::Tensor & packed_hh, const at::Tensor & col_offsets_ih, const at::Tensor & col_offsets_hh, const at::Scalar & scale_ih, const at::Scalar & scale_hh, const at::Scalar & zero_point_ih, const at::Scalar & zero_point_hh) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(hx, cur_level) && !isBatchedAtLevel(w_ih, cur_level) && !isBatchedAtLevel(w_hh, cur_level) && !isBatchedAtLevel(b_ih, cur_level) && !isBatchedAtLevel(b_hh, cur_level) && !isBatchedAtLevel(packed_ih, cur_level) && !isBatchedAtLevel(packed_hh, cur_level) && !isBatchedAtLevel(col_offsets_ih, cur_level) && !isBatchedAtLevel(col_offsets_hh, cur_level)) { + return 
at::_ops::quantized_rnn_tanh_cell::call(input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih, col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [hx_value, hx_bdim] = unwrapTensorAtLevel(hx, cur_level); + auto [w_ih_value, w_ih_bdim] = unwrapTensorAtLevel(w_ih, cur_level); + auto [w_hh_value, w_hh_bdim] = unwrapTensorAtLevel(w_hh, cur_level); + auto [b_ih_value, b_ih_bdim] = unwrapTensorAtLevel(b_ih, cur_level); + auto [b_hh_value, b_hh_bdim] = unwrapTensorAtLevel(b_hh, cur_level); + auto [packed_ih_value, packed_ih_bdim] = unwrapTensorAtLevel(packed_ih, cur_level); + auto [packed_hh_value, packed_hh_bdim] = unwrapTensorAtLevel(packed_hh, cur_level); + auto [col_offsets_ih_value, col_offsets_ih_bdim] = unwrapTensorAtLevel(col_offsets_ih, cur_level); + auto [col_offsets_hh_value, col_offsets_hh_bdim] = unwrapTensorAtLevel(col_offsets_hh, cur_level); + auto results = batch_rule(input_value, input_bdim, hx_value, hx_bdim, w_ih_value, w_ih_bdim, w_hh_value, w_hh_bdim, b_ih_value, b_ih_bdim, b_hh_value, b_hh_bdim, packed_ih_value, packed_ih_bdim, packed_hh_value, packed_hh_bdim, col_offsets_ih_value, col_offsets_ih_bdim, col_offsets_hh_value, col_offsets_hh_bdim, scale_ih, scale_hh, zero_point_ih, zero_point_hh); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple _pack_padded_sequence_generated_plumbing(const at::Tensor & input, const at::Tensor & lengths, bool batch_first) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(lengths, cur_level)) { + return at::_ops::_pack_padded_sequence::call(input, lengths, batch_first); + } + auto [input_value, input_bdim] = 
unwrapTensorAtLevel(input, cur_level); + auto [lengths_value, lengths_bdim] = unwrapTensorAtLevel(lengths, cur_level); + auto results = batch_rule(input_value, input_bdim, lengths_value, lengths_bdim, batch_first); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor _pack_padded_sequence_backward_generated_plumbing(const at::Tensor & grad, c10::SymIntArrayRef input_size, const at::Tensor & batch_sizes, bool batch_first) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad, cur_level) && !isBatchedAtLevel(batch_sizes, cur_level)) { + return at::_ops::_pack_padded_sequence_backward::call(grad, input_size, batch_sizes, batch_first); + } + auto [grad_value, grad_bdim] = unwrapTensorAtLevel(grad, cur_level); + auto [batch_sizes_value, batch_sizes_bdim] = unwrapTensorAtLevel(batch_sizes, cur_level); + auto results = batch_rule(grad_value, grad_bdim, input_size, batch_sizes_value, batch_sizes_bdim, batch_first); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple _pad_packed_sequence_generated_plumbing(const at::Tensor & data, const at::Tensor & batch_sizes, bool batch_first, const at::Scalar & padding_value, int64_t total_length) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(data, cur_level) && !isBatchedAtLevel(batch_sizes, cur_level)) { + return at::_ops::_pad_packed_sequence::call(data, batch_sizes, batch_first, padding_value, total_length); + } + auto [data_value, data_bdim] = 
unwrapTensorAtLevel(data, cur_level); + auto [batch_sizes_value, batch_sizes_bdim] = unwrapTensorAtLevel(batch_sizes, cur_level); + auto results = batch_rule(data_value, data_bdim, batch_sizes_value, batch_sizes_bdim, batch_first, padding_value, total_length); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor lift_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::lift::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor lift_fresh_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::lift_fresh::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor lift_fresh_copy_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return 
at::_ops::lift_fresh_copy::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & masked_fill__Scalar_generated_plumbing(at::Tensor & self, const at::Tensor & mask, const at::Scalar & value) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(mask, cur_level)) { + return at::_ops::masked_fill__Scalar::call(self, mask, value); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [mask_value, mask_bdim] = unwrapTensorAtLevel(mask, cur_level); + batch_rule(self_value, self_bdim, mask_value, mask_bdim, value); + return self; +} +template +at::Tensor masked_fill_Scalar_generated_plumbing(const at::Tensor & self, const at::Tensor & mask, const at::Scalar & value) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(mask, cur_level)) { + return at::_ops::masked_fill_Scalar::call(self, mask, value); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [mask_value, mask_bdim] = unwrapTensorAtLevel(mask, cur_level); + auto results = batch_rule(self_value, self_bdim, mask_value, mask_bdim, value); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & masked_fill__Tensor_generated_plumbing(at::Tensor & self, const at::Tensor & mask, const at::Tensor & value) { + c10::impl::ExcludeDispatchKeyGuard 
guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(mask, cur_level) && !isBatchedAtLevel(value, cur_level)) { + return at::_ops::masked_fill__Tensor::call(self, mask, value); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [mask_value, mask_bdim] = unwrapTensorAtLevel(mask, cur_level); + auto [value_value, value_bdim] = unwrapTensorAtLevel(value, cur_level); + batch_rule(self_value, self_bdim, mask_value, mask_bdim, value_value, value_bdim); + return self; +} +template +at::Tensor masked_fill_Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & mask, const at::Tensor & value) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(mask, cur_level) && !isBatchedAtLevel(value, cur_level)) { + return at::_ops::masked_fill_Tensor::call(self, mask, value); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [mask_value, mask_bdim] = unwrapTensorAtLevel(mask, cur_level); + auto [value_value, value_bdim] = unwrapTensorAtLevel(value, cur_level); + auto results = batch_rule(self_value, self_bdim, mask_value, mask_bdim, value_value, value_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & masked_scatter__generated_plumbing(at::Tensor & self, const at::Tensor & mask, const at::Tensor & source) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = 
maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(mask, cur_level) && !isBatchedAtLevel(source, cur_level)) { + return at::_ops::masked_scatter_::call(self, mask, source); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [mask_value, mask_bdim] = unwrapTensorAtLevel(mask, cur_level); + auto [source_value, source_bdim] = unwrapTensorAtLevel(source, cur_level); + batch_rule(self_value, self_bdim, mask_value, mask_bdim, source_value, source_bdim); + return self; +} +template +at::Tensor masked_scatter_generated_plumbing(const at::Tensor & self, const at::Tensor & mask, const at::Tensor & source) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(mask, cur_level) && !isBatchedAtLevel(source, cur_level)) { + return at::_ops::masked_scatter::call(self, mask, source); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [mask_value, mask_bdim] = unwrapTensorAtLevel(mask, cur_level); + auto [source_value, source_bdim] = unwrapTensorAtLevel(source, cur_level); + auto results = batch_rule(self_value, self_bdim, mask_value, mask_bdim, source_value, source_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor masked_scatter_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & mask, c10::SymIntArrayRef sizes) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(mask, cur_level)) { + return 
at::_ops::masked_scatter_backward::call(grad_output, mask, sizes); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [mask_value, mask_bdim] = unwrapTensorAtLevel(mask, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, mask_value, mask_bdim, sizes); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _masked_softmax_generated_plumbing(const at::Tensor & self, const at::Tensor & mask, ::std::optional dim, ::std::optional mask_type) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(mask, cur_level)) { + return at::_ops::_masked_softmax::call(self, mask, dim, mask_type); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [mask_value, mask_bdim] = unwrapTensorAtLevel(mask, cur_level); + auto results = batch_rule(self_value, self_bdim, mask_value, mask_bdim, dim, mask_type); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _masked_softmax_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & output, const at::Tensor & mask, ::std::optional dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(output, cur_level) && !isBatchedAtLevel(mask, cur_level)) { + return at::_ops::_masked_softmax_backward::call(grad_output, output, mask, dim); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [output_value, 
output_bdim] = unwrapTensorAtLevel(output, cur_level); + auto [mask_value, mask_bdim] = unwrapTensorAtLevel(mask, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, output_value, output_bdim, mask_value, mask_bdim, dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor view_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef size) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::view::call(self, size); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, size); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor view_dtype_generated_plumbing(const at::Tensor & self, at::ScalarType dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::view_dtype::call(self, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & put__generated_plumbing(at::Tensor & self, const at::Tensor & index, const at::Tensor & source, bool accumulate) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, 
cur_level) && !isBatchedAtLevel(index, cur_level) && !isBatchedAtLevel(source, cur_level)) { + return at::_ops::put_::call(self, index, source, accumulate); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [index_value, index_bdim] = unwrapTensorAtLevel(index, cur_level); + auto [source_value, source_bdim] = unwrapTensorAtLevel(source, cur_level); + batch_rule(self_value, self_bdim, index_value, index_bdim, source_value, source_bdim, accumulate); + return self; +} +template +at::Tensor put_generated_plumbing(const at::Tensor & self, const at::Tensor & index, const at::Tensor & source, bool accumulate) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(index, cur_level) && !isBatchedAtLevel(source, cur_level)) { + return at::_ops::put::call(self, index, source, accumulate); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [index_value, index_bdim] = unwrapTensorAtLevel(index, cur_level); + auto [source_value, source_bdim] = unwrapTensorAtLevel(source, cur_level); + auto results = batch_rule(self_value, self_bdim, index_value, index_bdim, source_value, source_bdim, accumulate); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & index_add__generated_plumbing(at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & source, const at::Scalar & alpha) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(index, cur_level) && !isBatchedAtLevel(source, cur_level)) 
{ + return at::_ops::index_add_::call(self, dim, index, source, alpha); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [index_value, index_bdim] = unwrapTensorAtLevel(index, cur_level); + auto [source_value, source_bdim] = unwrapTensorAtLevel(source, cur_level); + batch_rule(self_value, self_bdim, dim, index_value, index_bdim, source_value, source_bdim, alpha); + return self; +} +template +at::Tensor index_add_generated_plumbing(const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & source, const at::Scalar & alpha) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(index, cur_level) && !isBatchedAtLevel(source, cur_level)) { + return at::_ops::index_add::call(self, dim, index, source, alpha); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [index_value, index_bdim] = unwrapTensorAtLevel(index, cur_level); + auto [source_value, source_bdim] = unwrapTensorAtLevel(source, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, index_value, index_bdim, source_value, source_bdim, alpha); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor index_add_dimname_generated_plumbing(const at::Tensor & self, at::Dimname dim, const at::Tensor & index, const at::Tensor & source, const at::Scalar & alpha) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(index, cur_level) && !isBatchedAtLevel(source, cur_level)) { + return 
at::_ops::index_add_dimname::call(self, dim, index, source, alpha); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [index_value, index_bdim] = unwrapTensorAtLevel(index, cur_level); + auto [source_value, source_bdim] = unwrapTensorAtLevel(source, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, index_value, index_bdim, source_value, source_bdim, alpha); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & index_reduce__generated_plumbing(at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & source, c10::string_view reduce, bool include_self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(index, cur_level) && !isBatchedAtLevel(source, cur_level)) { + return at::_ops::index_reduce_::call(self, dim, index, source, reduce, include_self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [index_value, index_bdim] = unwrapTensorAtLevel(index, cur_level); + auto [source_value, source_bdim] = unwrapTensorAtLevel(source, cur_level); + batch_rule(self_value, self_bdim, dim, index_value, index_bdim, source_value, source_bdim, reduce, include_self); + return self; +} +template +at::Tensor index_reduce_generated_plumbing(const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & source, c10::string_view reduce, bool include_self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(index, cur_level) && 
!isBatchedAtLevel(source, cur_level)) { + return at::_ops::index_reduce::call(self, dim, index, source, reduce, include_self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [index_value, index_bdim] = unwrapTensorAtLevel(index, cur_level); + auto [source_value, source_bdim] = unwrapTensorAtLevel(source, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, index_value, index_bdim, source_value, source_bdim, reduce, include_self); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & index_fill__int_Scalar_generated_plumbing(at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Scalar & value) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(index, cur_level)) { + return at::_ops::index_fill__int_Scalar::call(self, dim, index, value); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [index_value, index_bdim] = unwrapTensorAtLevel(index, cur_level); + batch_rule(self_value, self_bdim, dim, index_value, index_bdim, value); + return self; +} +template +at::Tensor index_fill_int_Scalar_generated_plumbing(const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Scalar & value) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(index, cur_level)) { + return at::_ops::index_fill_int_Scalar::call(self, dim, index, value); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [index_value, index_bdim] = 
unwrapTensorAtLevel(index, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, index_value, index_bdim, value); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & index_fill__int_Tensor_generated_plumbing(at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & value) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(index, cur_level) && !isBatchedAtLevel(value, cur_level)) { + return at::_ops::index_fill__int_Tensor::call(self, dim, index, value); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [index_value, index_bdim] = unwrapTensorAtLevel(index, cur_level); + auto [value_value, value_bdim] = unwrapTensorAtLevel(value, cur_level); + batch_rule(self_value, self_bdim, dim, index_value, index_bdim, value_value, value_bdim); + return self; +} +template +at::Tensor index_fill_int_Tensor_generated_plumbing(const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & value) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(index, cur_level) && !isBatchedAtLevel(value, cur_level)) { + return at::_ops::index_fill_int_Tensor::call(self, dim, index, value); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [index_value, index_bdim] = unwrapTensorAtLevel(index, cur_level); + auto [value_value, value_bdim] = unwrapTensorAtLevel(value, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, index_value, 
index_bdim, value_value, value_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & index_fill__Dimname_Scalar_generated_plumbing(at::Tensor & self, at::Dimname dim, const at::Tensor & index, const at::Scalar & value) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(index, cur_level)) { + return at::_ops::index_fill__Dimname_Scalar::call(self, dim, index, value); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [index_value, index_bdim] = unwrapTensorAtLevel(index, cur_level); + batch_rule(self_value, self_bdim, dim, index_value, index_bdim, value); + return self; +} +template +at::Tensor & index_fill__Dimname_Tensor_generated_plumbing(at::Tensor & self, at::Dimname dim, const at::Tensor & index, const at::Tensor & value) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(index, cur_level) && !isBatchedAtLevel(value, cur_level)) { + return at::_ops::index_fill__Dimname_Tensor::call(self, dim, index, value); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [index_value, index_bdim] = unwrapTensorAtLevel(index, cur_level); + auto [value_value, value_bdim] = unwrapTensorAtLevel(value, cur_level); + batch_rule(self_value, self_bdim, dim, index_value, index_bdim, value_value, value_bdim); + return self; +} +template +at::Tensor index_fill_Dimname_Scalar_generated_plumbing(const at::Tensor & self, at::Dimname dim, const at::Tensor & index, const at::Scalar & 
value) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(index, cur_level)) { + return at::_ops::index_fill_Dimname_Scalar::call(self, dim, index, value); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [index_value, index_bdim] = unwrapTensorAtLevel(index, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, index_value, index_bdim, value); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor index_fill_Dimname_Tensor_generated_plumbing(const at::Tensor & self, at::Dimname dim, const at::Tensor & index, const at::Tensor & value) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(index, cur_level) && !isBatchedAtLevel(value, cur_level)) { + return at::_ops::index_fill_Dimname_Tensor::call(self, dim, index, value); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [index_value, index_bdim] = unwrapTensorAtLevel(index, cur_level); + auto [value_value, value_bdim] = unwrapTensorAtLevel(value, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, index_value, index_bdim, value_value, value_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor scatter_src_generated_plumbing(const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + 
vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(index, cur_level) && !isBatchedAtLevel(src, cur_level)) { + return at::_ops::scatter_src::call(self, dim, index, src); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [index_value, index_bdim] = unwrapTensorAtLevel(index, cur_level); + auto [src_value, src_bdim] = unwrapTensorAtLevel(src, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, index_value, index_bdim, src_value, src_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & scatter__src_generated_plumbing(at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(index, cur_level) && !isBatchedAtLevel(src, cur_level)) { + return at::_ops::scatter__src::call(self, dim, index, src); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [index_value, index_bdim] = unwrapTensorAtLevel(index, cur_level); + auto [src_value, src_bdim] = unwrapTensorAtLevel(src, cur_level); + batch_rule(self_value, self_bdim, dim, index_value, index_bdim, src_value, src_bdim); + return self; +} +template +at::Tensor scatter_value_generated_plumbing(const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Scalar & value) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && 
!isBatchedAtLevel(index, cur_level)) { + return at::_ops::scatter_value::call(self, dim, index, value); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [index_value, index_bdim] = unwrapTensorAtLevel(index, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, index_value, index_bdim, value); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & scatter__value_generated_plumbing(at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Scalar & value) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(index, cur_level)) { + return at::_ops::scatter__value::call(self, dim, index, value); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [index_value, index_bdim] = unwrapTensorAtLevel(index, cur_level); + batch_rule(self_value, self_bdim, dim, index_value, index_bdim, value); + return self; +} +template +at::Tensor scatter_reduce_generated_plumbing(const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src, c10::string_view reduce) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(index, cur_level) && !isBatchedAtLevel(src, cur_level)) { + return at::_ops::scatter_reduce::call(self, dim, index, src, reduce); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [index_value, index_bdim] = unwrapTensorAtLevel(index, cur_level); + auto [src_value, src_bdim] = unwrapTensorAtLevel(src, cur_level); + 
auto results = batch_rule(self_value, self_bdim, dim, index_value, index_bdim, src_value, src_bdim, reduce); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & scatter__reduce_generated_plumbing(at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src, c10::string_view reduce) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(index, cur_level) && !isBatchedAtLevel(src, cur_level)) { + return at::_ops::scatter__reduce::call(self, dim, index, src, reduce); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [index_value, index_bdim] = unwrapTensorAtLevel(index, cur_level); + auto [src_value, src_bdim] = unwrapTensorAtLevel(src, cur_level); + batch_rule(self_value, self_bdim, dim, index_value, index_bdim, src_value, src_bdim, reduce); + return self; +} +template +at::Tensor scatter_value_reduce_generated_plumbing(const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Scalar & value, c10::string_view reduce) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(index, cur_level)) { + return at::_ops::scatter_value_reduce::call(self, dim, index, value, reduce); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [index_value, index_bdim] = unwrapTensorAtLevel(index, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, index_value, index_bdim, value, reduce); + return makeBatched(std::get<0>(results), std::get<1>(results), 
cur_level); +} +template +at::Tensor & scatter__value_reduce_generated_plumbing(at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Scalar & value, c10::string_view reduce) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(index, cur_level)) { + return at::_ops::scatter__value_reduce::call(self, dim, index, value, reduce); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [index_value, index_bdim] = unwrapTensorAtLevel(index, cur_level); + batch_rule(self_value, self_bdim, dim, index_value, index_bdim, value, reduce); + return self; +} +template +at::Tensor scatter_dimname_src_generated_plumbing(const at::Tensor & self, at::Dimname dim, const at::Tensor & index, const at::Tensor & src) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(index, cur_level) && !isBatchedAtLevel(src, cur_level)) { + return at::_ops::scatter_dimname_src::call(self, dim, index, src); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [index_value, index_bdim] = unwrapTensorAtLevel(index, cur_level); + auto [src_value, src_bdim] = unwrapTensorAtLevel(src, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, index_value, index_bdim, src_value, src_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor scatter_dimname_value_generated_plumbing(const at::Tensor & self, at::Dimname dim, const at::Tensor & index, const at::Scalar & value) { + 
c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(index, cur_level)) { + return at::_ops::scatter_dimname_value::call(self, dim, index, value); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [index_value, index_bdim] = unwrapTensorAtLevel(index, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, index_value, index_bdim, value); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor scatter_add_generated_plumbing(const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(index, cur_level) && !isBatchedAtLevel(src, cur_level)) { + return at::_ops::scatter_add::call(self, dim, index, src); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [index_value, index_bdim] = unwrapTensorAtLevel(index, cur_level); + auto [src_value, src_bdim] = unwrapTensorAtLevel(src, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, index_value, index_bdim, src_value, src_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & scatter_add__generated_plumbing(at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = 
maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(index, cur_level) && !isBatchedAtLevel(src, cur_level)) { + return at::_ops::scatter_add_::call(self, dim, index, src); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [index_value, index_bdim] = unwrapTensorAtLevel(index, cur_level); + auto [src_value, src_bdim] = unwrapTensorAtLevel(src, cur_level); + batch_rule(self_value, self_bdim, dim, index_value, index_bdim, src_value, src_bdim); + return self; +} +template +at::Tensor scatter_add_dimname_generated_plumbing(const at::Tensor & self, at::Dimname dim, const at::Tensor & index, const at::Tensor & src) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(index, cur_level) && !isBatchedAtLevel(src, cur_level)) { + return at::_ops::scatter_add_dimname::call(self, dim, index, src); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [index_value, index_bdim] = unwrapTensorAtLevel(index, cur_level); + auto [src_value, src_bdim] = unwrapTensorAtLevel(src, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, index_value, index_bdim, src_value, src_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor scatter_reduce_two_generated_plumbing(const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src, c10::string_view reduce, bool include_self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(index, cur_level) 
&& !isBatchedAtLevel(src, cur_level)) { + return at::_ops::scatter_reduce_two::call(self, dim, index, src, reduce, include_self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [index_value, index_bdim] = unwrapTensorAtLevel(index, cur_level); + auto [src_value, src_bdim] = unwrapTensorAtLevel(src, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, index_value, index_bdim, src_value, src_bdim, reduce, include_self); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & scatter_reduce__two_generated_plumbing(at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src, c10::string_view reduce, bool include_self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(index, cur_level) && !isBatchedAtLevel(src, cur_level)) { + return at::_ops::scatter_reduce__two::call(self, dim, index, src, reduce, include_self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [index_value, index_bdim] = unwrapTensorAtLevel(index, cur_level); + auto [src_value, src_bdim] = unwrapTensorAtLevel(src, cur_level); + batch_rule(self_value, self_bdim, dim, index_value, index_bdim, src_value, src_bdim, reduce, include_self); + return self; +} +template +at::Tensor & eq__Scalar_generated_plumbing(at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::eq__Scalar::call(self, other); + } + auto [self_value, self_bdim] = 
unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, other); + return self; +} +template +at::Tensor & eq__Tensor_generated_plumbing(at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::eq__Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim); + return self; +} +template +at::Tensor bitwise_and_Scalar_generated_plumbing(const at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::bitwise_and_Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, other); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor bitwise_and_Scalar_Tensor_generated_plumbing(const at::Scalar & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(other, cur_level)) { + return at::_ops::bitwise_and_Scalar_Tensor::call(self, other); + } + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto 
results = batch_rule(self, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor bitwise_and_Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::bitwise_and_Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & bitwise_and__Scalar_generated_plumbing(at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::bitwise_and__Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, other); + return self; +} +template +at::Tensor & bitwise_and__Tensor_generated_plumbing(at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::bitwise_and__Tensor::call(self, other); + } + auto 
[self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim); + return self; +} +template +at::Tensor __and___Scalar_generated_plumbing(const at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::__and___Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, other); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor __and___Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::__and___Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & __iand___Scalar_generated_plumbing(at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if 
(!isBatchedAtLevel(self, cur_level)) { + return at::_ops::__iand___Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, other); + return self; +} +template +at::Tensor & __iand___Tensor_generated_plumbing(at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::__iand___Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim); + return self; +} +template +at::Tensor bitwise_or_Scalar_generated_plumbing(const at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::bitwise_or_Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, other); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor bitwise_or_Scalar_Tensor_generated_plumbing(const at::Scalar & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(other, cur_level)) { + return 
at::_ops::bitwise_or_Scalar_Tensor::call(self, other); + } + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor bitwise_or_Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::bitwise_or_Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & bitwise_or__Scalar_generated_plumbing(at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::bitwise_or__Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, other); + return self; +} +template +at::Tensor & bitwise_or__Tensor_generated_plumbing(at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if 
(!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::bitwise_or__Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim); + return self; +} +template +at::Tensor __or___Scalar_generated_plumbing(const at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::__or___Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, other); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor __or___Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::__or___Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & __ior___Scalar_generated_plumbing(at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = 
maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::__ior___Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, other); + return self; +} +template +at::Tensor & __ior___Tensor_generated_plumbing(at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::__ior___Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim); + return self; +} +template +at::Tensor bitwise_xor_Scalar_generated_plumbing(const at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::bitwise_xor_Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, other); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor bitwise_xor_Scalar_Tensor_generated_plumbing(const at::Scalar & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + 
vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(other, cur_level)) { + return at::_ops::bitwise_xor_Scalar_Tensor::call(self, other); + } + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor bitwise_xor_Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::bitwise_xor_Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & bitwise_xor__Scalar_generated_plumbing(at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::bitwise_xor__Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, other); + return self; +} +template +at::Tensor & bitwise_xor__Tensor_generated_plumbing(at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto 
maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::bitwise_xor__Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim); + return self; +} +template +at::Tensor __xor___Scalar_generated_plumbing(const at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::__xor___Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, other); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor __xor___Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::__xor___Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & 
__ixor___Scalar_generated_plumbing(at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::__ixor___Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, other); + return self; +} +template +at::Tensor & __ixor___Tensor_generated_plumbing(at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::__ixor___Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim); + return self; +} +template +at::Tensor __lshift___Scalar_generated_plumbing(const at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::__lshift___Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, other); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor __lshift___Tensor_generated_plumbing(const at::Tensor & self, 
const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::__lshift___Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & __ilshift___Scalar_generated_plumbing(at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::__ilshift___Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, other); + return self; +} +template +at::Tensor & __ilshift___Tensor_generated_plumbing(at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::__ilshift___Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim); + return self; +} 
+template +at::Tensor bitwise_left_shift_Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::bitwise_left_shift_Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & bitwise_left_shift__Tensor_generated_plumbing(at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::bitwise_left_shift__Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim); + return self; +} +template +at::Tensor bitwise_left_shift_Tensor_Scalar_generated_plumbing(const at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::bitwise_left_shift_Tensor_Scalar::call(self, other); + 
} + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, other); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & bitwise_left_shift__Tensor_Scalar_generated_plumbing(at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::bitwise_left_shift__Tensor_Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, other); + return self; +} +template +at::Tensor bitwise_left_shift_Scalar_Tensor_generated_plumbing(const at::Scalar & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(other, cur_level)) { + return at::_ops::bitwise_left_shift_Scalar_Tensor::call(self, other); + } + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor __rshift___Scalar_generated_plumbing(const at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::__rshift___Scalar::call(self, other); + } + auto [self_value, self_bdim] = 
unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, other); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor __rshift___Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::__rshift___Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & __irshift___Scalar_generated_plumbing(at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::__irshift___Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, other); + return self; +} +template +at::Tensor & __irshift___Tensor_generated_plumbing(at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return 
at::_ops::__irshift___Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim); + return self; +} +template +at::Tensor bitwise_right_shift_Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::bitwise_right_shift_Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & bitwise_right_shift__Tensor_generated_plumbing(at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::bitwise_right_shift__Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim); + return self; +} +template +at::Tensor bitwise_right_shift_Tensor_Scalar_generated_plumbing(const at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard 
guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::bitwise_right_shift_Tensor_Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, other); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & bitwise_right_shift__Tensor_Scalar_generated_plumbing(at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::bitwise_right_shift__Tensor_Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, other); + return self; +} +template +at::Tensor bitwise_right_shift_Scalar_Tensor_generated_plumbing(const at::Scalar & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(other, cur_level)) { + return at::_ops::bitwise_right_shift_Scalar_Tensor::call(self, other); + } + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & tril__generated_plumbing(at::Tensor & self, c10::SymInt diagonal) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto 
maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::tril_::call(self, diagonal); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, diagonal); + return self; +} +template +at::Tensor & triu__generated_plumbing(at::Tensor & self, c10::SymInt diagonal) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::triu_::call(self, diagonal); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, diagonal); + return self; +} +template +at::Tensor & digamma__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::digamma_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor & lerp__Scalar_generated_plumbing(at::Tensor & self, const at::Tensor & end, const at::Scalar & weight) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(end, cur_level)) { + return at::_ops::lerp__Scalar::call(self, end, weight); + } + auto [self_value, self_bdim] 
= unwrapTensorAtLevel(self, cur_level); + auto [end_value, end_bdim] = unwrapTensorAtLevel(end, cur_level); + batch_rule(self_value, self_bdim, end_value, end_bdim, weight); + return self; +} +template +at::Tensor & lerp__Tensor_generated_plumbing(at::Tensor & self, const at::Tensor & end, const at::Tensor & weight) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(end, cur_level) && !isBatchedAtLevel(weight, cur_level)) { + return at::_ops::lerp__Tensor::call(self, end, weight); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [end_value, end_bdim] = unwrapTensorAtLevel(end, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + batch_rule(self_value, self_bdim, end_value, end_bdim, weight_value, weight_bdim); + return self; +} +template +at::Tensor & addbmm__generated_plumbing(at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta, const at::Scalar & alpha) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(batch1, cur_level) && !isBatchedAtLevel(batch2, cur_level)) { + return at::_ops::addbmm_::call(self, batch1, batch2, beta, alpha); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [batch1_value, batch1_bdim] = unwrapTensorAtLevel(batch1, cur_level); + auto [batch2_value, batch2_bdim] = unwrapTensorAtLevel(batch2, cur_level); + batch_rule(self_value, self_bdim, batch1_value, batch1_bdim, batch2_value, batch2_bdim, beta, alpha); + 
return self; +} +template +at::Tensor addbmm_generated_plumbing(const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta, const at::Scalar & alpha) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(batch1, cur_level) && !isBatchedAtLevel(batch2, cur_level)) { + return at::_ops::addbmm::call(self, batch1, batch2, beta, alpha); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [batch1_value, batch1_bdim] = unwrapTensorAtLevel(batch1, cur_level); + auto [batch2_value, batch2_bdim] = unwrapTensorAtLevel(batch2, cur_level); + auto results = batch_rule(self_value, self_bdim, batch1_value, batch1_bdim, batch2_value, batch2_bdim, beta, alpha); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & random__from_generated_plumbing(at::Tensor & self, int64_t from, ::std::optional to, ::std::optional generator) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::random__from::call(self, from, to, generator); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, from, to, generator); + return self; +} +template +at::Tensor & random__to_generated_plumbing(at::Tensor & self, int64_t to, ::std::optional generator) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = 
maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::random__to::call(self, to, generator); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, to, generator); + return self; +} +template +at::Tensor & random__generated_plumbing(at::Tensor & self, ::std::optional generator) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::random_::call(self, generator); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, generator); + return self; +} +template +at::Tensor & uniform__generated_plumbing(at::Tensor & self, double from, double to, ::std::optional generator) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::uniform_::call(self, from, to, generator); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, from, to, generator); + return self; +} +template +at::Tensor & cauchy__generated_plumbing(at::Tensor & self, double median, double sigma, ::std::optional generator) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::cauchy_::call(self, median, sigma, generator); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, 
cur_level); + batch_rule(self_value, self_bdim, median, sigma, generator); + return self; +} +template +at::Tensor & log_normal__generated_plumbing(at::Tensor & self, double mean, double std, ::std::optional generator) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::log_normal_::call(self, mean, std, generator); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, mean, std, generator); + return self; +} +template +at::Tensor & exponential__generated_plumbing(at::Tensor & self, double lambd, ::std::optional generator) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::exponential_::call(self, lambd, generator); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, lambd, generator); + return self; +} +template +at::Tensor & geometric__generated_plumbing(at::Tensor & self, double p, ::std::optional generator) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::geometric_::call(self, p, generator); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, p, generator); + return self; +} +template +at::Tensor diag_generated_plumbing(const at::Tensor & self, int64_t 
diagonal) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::diag::call(self, diagonal); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, diagonal); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor cross_generated_plumbing(const at::Tensor & self, const at::Tensor & other, ::std::optional dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::cross::call(self, other, dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim, dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor triu_generated_plumbing(const at::Tensor & self, c10::SymInt diagonal) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::triu::call(self, diagonal); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, diagonal); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor 
tril_generated_plumbing(const at::Tensor & self, c10::SymInt diagonal) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::tril::call(self, diagonal); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, diagonal); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor trace_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::trace::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor trace_backward_generated_plumbing(const at::Tensor & grad, c10::SymIntArrayRef sizes) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad, cur_level)) { + return at::_ops::trace_backward::call(grad, sizes); + } + auto [grad_value, grad_bdim] = unwrapTensorAtLevel(grad, cur_level); + auto results = batch_rule(grad_value, grad_bdim, sizes); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor ne_Scalar_generated_plumbing(const at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard 
guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::ne_Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, other); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor ne_Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::ne_Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & ne__Scalar_generated_plumbing(at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::ne__Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, other); + return self; +} +template +at::Tensor & ne__Tensor_generated_plumbing(at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard 
guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::ne__Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim); + return self; +} +template +at::Tensor not_equal_Scalar_generated_plumbing(const at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::not_equal_Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, other); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor not_equal_Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::not_equal_Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} 
+template +at::Tensor & not_equal__Scalar_generated_plumbing(at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::not_equal__Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, other); + return self; +} +template +at::Tensor & not_equal__Tensor_generated_plumbing(at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::not_equal__Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim); + return self; +} +template +at::Tensor eq_Scalar_generated_plumbing(const at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::eq_Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, other); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor eq_Tensor_generated_plumbing(const at::Tensor & 
self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::eq_Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor ge_Scalar_generated_plumbing(const at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::ge_Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, other); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor ge_Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::ge_Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, 
other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & ge__Scalar_generated_plumbing(at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::ge__Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, other); + return self; +} +template +at::Tensor & ge__Tensor_generated_plumbing(at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::ge__Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim); + return self; +} +template +at::Tensor greater_equal_Scalar_generated_plumbing(const at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::greater_equal_Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, other); + return makeBatched(std::get<0>(results), 
std::get<1>(results), cur_level); +} +template +at::Tensor greater_equal_Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::greater_equal_Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & greater_equal__Scalar_generated_plumbing(at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::greater_equal__Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, other); + return self; +} +template +at::Tensor & greater_equal__Tensor_generated_plumbing(at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::greater_equal__Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto 
[other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim); + return self; +} +template +at::Tensor le_Scalar_generated_plumbing(const at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::le_Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, other); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor le_Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::le_Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & le__Scalar_generated_plumbing(at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::le__Scalar::call(self, other); + } + auto 
[self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, other); + return self; +} +template +at::Tensor & le__Tensor_generated_plumbing(at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::le__Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim); + return self; +} +template +at::Tensor less_equal_Scalar_generated_plumbing(const at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::less_equal_Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, other); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor less_equal_Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::less_equal_Tensor::call(self, other); + } + auto [self_value, self_bdim] = 
unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & less_equal__Scalar_generated_plumbing(at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::less_equal__Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, other); + return self; +} +template +at::Tensor & less_equal__Tensor_generated_plumbing(at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::less_equal__Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim); + return self; +} +template +at::Tensor gt_Scalar_generated_plumbing(const at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::gt_Scalar::call(self, other); + } + auto 
[self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, other); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor gt_Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::gt_Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & gt__Scalar_generated_plumbing(at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::gt__Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, other); + return self; +} +template +at::Tensor & gt__Tensor_generated_plumbing(at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::gt__Tensor::call(self, 
other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim); + return self; +} +template +at::Tensor greater_Scalar_generated_plumbing(const at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::greater_Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, other); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor greater_Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::greater_Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & greater__Scalar_generated_plumbing(at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = 
maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::greater__Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, other); + return self; +} +template +at::Tensor & greater__Tensor_generated_plumbing(at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::greater__Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim); + return self; +} +template +at::Tensor lt_Scalar_generated_plumbing(const at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::lt_Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, other); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor lt_Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && 
!isBatchedAtLevel(other, cur_level)) { + return at::_ops::lt_Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & lt__Scalar_generated_plumbing(at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::lt__Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, other); + return self; +} +template +at::Tensor & lt__Tensor_generated_plumbing(at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::lt__Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim); + return self; +} +template +at::Tensor less_Scalar_generated_plumbing(const at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if 
(!isBatchedAtLevel(self, cur_level)) { + return at::_ops::less_Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, other); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor less_Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::less_Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & less__Scalar_generated_plumbing(at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::less__Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, other); + return self; +} +template +at::Tensor & less__Tensor_generated_plumbing(at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if 
(!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::less__Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim); + return self; +} +template +at::Tensor take_generated_plumbing(const at::Tensor & self, const at::Tensor & index) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(index, cur_level)) { + return at::_ops::take::call(self, index); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [index_value, index_bdim] = unwrapTensorAtLevel(index, cur_level); + auto results = batch_rule(self_value, self_bdim, index_value, index_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor take_along_dim_generated_plumbing(const at::Tensor & self, const at::Tensor & indices, ::std::optional dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(indices, cur_level)) { + return at::_ops::take_along_dim::call(self, indices, dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [indices_value, indices_bdim] = unwrapTensorAtLevel(indices, cur_level); + auto results = batch_rule(self_value, self_bdim, indices_value, indices_bdim, dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor 
index_select_generated_plumbing(const at::Tensor & self, int64_t dim, const at::Tensor & index) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(index, cur_level)) { + return at::_ops::index_select::call(self, dim, index); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [index_value, index_bdim] = unwrapTensorAtLevel(index, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, index_value, index_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor index_select_dimname_generated_plumbing(const at::Tensor & self, at::Dimname dim, const at::Tensor & index) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(index, cur_level)) { + return at::_ops::index_select_dimname::call(self, dim, index); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [index_value, index_bdim] = unwrapTensorAtLevel(index, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, index_value, index_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor index_select_backward_generated_plumbing(const at::Tensor & grad, c10::SymIntArrayRef self_sizes, int64_t dim, const at::Tensor & index) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if 
(!isBatchedAtLevel(grad, cur_level) && !isBatchedAtLevel(index, cur_level)) { + return at::_ops::index_select_backward::call(grad, self_sizes, dim, index); + } + auto [grad_value, grad_bdim] = unwrapTensorAtLevel(grad, cur_level); + auto [index_value, index_bdim] = unwrapTensorAtLevel(index, cur_level); + auto results = batch_rule(grad_value, grad_bdim, self_sizes, dim, index_value, index_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor masked_select_generated_plumbing(const at::Tensor & self, const at::Tensor & mask) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(mask, cur_level)) { + return at::_ops::masked_select::call(self, mask); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [mask_value, mask_bdim] = unwrapTensorAtLevel(mask, cur_level); + auto results = batch_rule(self_value, self_bdim, mask_value, mask_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor masked_select_backward_generated_plumbing(const at::Tensor & grad, const at::Tensor & input, const at::Tensor & mask) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad, cur_level) && !isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(mask, cur_level)) { + return at::_ops::masked_select_backward::call(grad, input, mask); + } + auto [grad_value, grad_bdim] = unwrapTensorAtLevel(grad, cur_level); + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [mask_value, mask_bdim] = 
unwrapTensorAtLevel(mask, cur_level); + auto results = batch_rule(grad_value, grad_bdim, input_value, input_bdim, mask_value, mask_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor nonzero_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::nonzero::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor nonzero_static_generated_plumbing(const at::Tensor & self, c10::SymInt size, int64_t fill_value) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::nonzero_static::call(self, size, fill_value); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, size, fill_value); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector nonzero_numpy_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::nonzero_numpy::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = 
batch_rule(self_value, self_bdim); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor argwhere_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::argwhere::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor gather_generated_plumbing(const at::Tensor & self, int64_t dim, const at::Tensor & index, bool sparse_grad) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(index, cur_level)) { + return at::_ops::gather::call(self, dim, index, sparse_grad); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [index_value, index_bdim] = unwrapTensorAtLevel(index, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, index_value, index_bdim, sparse_grad); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor gather_backward_generated_plumbing(const at::Tensor & grad, const at::Tensor & self, int64_t dim, const at::Tensor & index, bool sparse_grad) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad, cur_level) && 
!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(index, cur_level)) { + return at::_ops::gather_backward::call(grad, self, dim, index, sparse_grad); + } + auto [grad_value, grad_bdim] = unwrapTensorAtLevel(grad, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [index_value, index_bdim] = unwrapTensorAtLevel(index, cur_level); + auto results = batch_rule(grad_value, grad_bdim, self_value, self_bdim, dim, index_value, index_bdim, sparse_grad); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor gather_dimname_generated_plumbing(const at::Tensor & self, at::Dimname dim, const at::Tensor & index, bool sparse_grad) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(index, cur_level)) { + return at::_ops::gather_dimname::call(self, dim, index, sparse_grad); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [index_value, index_bdim] = unwrapTensorAtLevel(index, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, index_value, index_bdim, sparse_grad); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _gather_sparse_backward_generated_plumbing(const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & grad) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(index, cur_level) && !isBatchedAtLevel(grad, cur_level)) { + return at::_ops::_gather_sparse_backward::call(self, dim, index, grad); 
+ } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [index_value, index_bdim] = unwrapTensorAtLevel(index, cur_level); + auto [grad_value, grad_bdim] = unwrapTensorAtLevel(grad, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, index_value, index_bdim, grad_value, grad_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor addcmul_generated_plumbing(const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(tensor1, cur_level) && !isBatchedAtLevel(tensor2, cur_level)) { + return at::_ops::addcmul::call(self, tensor1, tensor2, value); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [tensor1_value, tensor1_bdim] = unwrapTensorAtLevel(tensor1, cur_level); + auto [tensor2_value, tensor2_bdim] = unwrapTensorAtLevel(tensor2, cur_level); + auto results = batch_rule(self_value, self_bdim, tensor1_value, tensor1_bdim, tensor2_value, tensor2_bdim, value); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & addcmul__generated_plumbing(at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(tensor1, cur_level) && !isBatchedAtLevel(tensor2, cur_level)) { + return at::_ops::addcmul_::call(self, tensor1, tensor2, value); + } + 
auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [tensor1_value, tensor1_bdim] = unwrapTensorAtLevel(tensor1, cur_level); + auto [tensor2_value, tensor2_bdim] = unwrapTensorAtLevel(tensor2, cur_level); + batch_rule(self_value, self_bdim, tensor1_value, tensor1_bdim, tensor2_value, tensor2_bdim, value); + return self; +} +template +at::Tensor addcdiv_generated_plumbing(const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(tensor1, cur_level) && !isBatchedAtLevel(tensor2, cur_level)) { + return at::_ops::addcdiv::call(self, tensor1, tensor2, value); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [tensor1_value, tensor1_bdim] = unwrapTensorAtLevel(tensor1, cur_level); + auto [tensor2_value, tensor2_bdim] = unwrapTensorAtLevel(tensor2, cur_level); + auto results = batch_rule(self_value, self_bdim, tensor1_value, tensor1_bdim, tensor2_value, tensor2_bdim, value); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & addcdiv__generated_plumbing(at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(tensor1, cur_level) && !isBatchedAtLevel(tensor2, cur_level)) { + return at::_ops::addcdiv_::call(self, tensor1, tensor2, value); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, 
cur_level); + auto [tensor1_value, tensor1_bdim] = unwrapTensorAtLevel(tensor1, cur_level); + auto [tensor2_value, tensor2_bdim] = unwrapTensorAtLevel(tensor2, cur_level); + batch_rule(self_value, self_bdim, tensor1_value, tensor1_bdim, tensor2_value, tensor2_bdim, value); + return self; +} +template +at::Tensor cross_entropy_loss_generated_plumbing(const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, c10::SymInt ignore_index, double label_smoothing) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(target, cur_level) && !isBatchedAtLevel(weight, cur_level)) { + return at::_ops::cross_entropy_loss::call(self, target, weight, reduction, ignore_index, label_smoothing); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [target_value, target_bdim] = unwrapTensorAtLevel(target, cur_level); + std::optional weight_value; + std::optional weight_bdim; + if (weight) { + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, target_value, target_bdim, weight_value, weight_bdim, reduction, ignore_index, label_smoothing); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple triangular_solve_generated_plumbing(const at::Tensor & self, const at::Tensor & A, bool upper, bool transpose, bool unitriangular) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(A, cur_level)) { + return 
at::_ops::triangular_solve::call(self, A, upper, transpose, unitriangular); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [A_value, A_bdim] = unwrapTensorAtLevel(A, cur_level); + auto results = batch_rule(self_value, self_bdim, A_value, A_bdim, upper, transpose, unitriangular); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +void _linalg_check_errors_generated_plumbing(const at::Tensor & info, c10::string_view api_name, bool is_matrix) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(info, cur_level)) { + return at::_ops::_linalg_check_errors::call(info, api_name, is_matrix); + } + auto [info_value, info_bdim] = unwrapTensorAtLevel(info, cur_level); + batch_rule(info_value, info_bdim, api_name, is_matrix); +} +template +at::Tensor linalg_solve_triangular_generated_plumbing(const at::Tensor & self, const at::Tensor & B, bool upper, bool left, bool unitriangular) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(B, cur_level)) { + return at::_ops::linalg_solve_triangular::call(self, B, upper, left, unitriangular); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [B_value, B_bdim] = unwrapTensorAtLevel(B, cur_level); + auto results = batch_rule(self_value, self_bdim, B_value, B_bdim, upper, left, unitriangular); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor 
linalg_vander_generated_plumbing(const at::Tensor & x, ::std::optional N) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(x, cur_level)) { + return at::_ops::linalg_vander::call(x, N); + } + auto [x_value, x_bdim] = unwrapTensorAtLevel(x, cur_level); + auto results = batch_rule(x_value, x_bdim, N); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple svd_generated_plumbing(const at::Tensor & self, bool some, bool compute_uv) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::svd::call(self, some, compute_uv); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, some, compute_uv); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +at::Tensor swapaxes_generated_plumbing(const at::Tensor & self, int64_t axis0, int64_t axis1) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::swapaxes::call(self, axis0, axis1); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, axis0, axis1); + return 
makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor swapdims_generated_plumbing(const at::Tensor & self, int64_t dim0, int64_t dim1) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::swapdims::call(self, dim0, dim1); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim0, dim1); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor cholesky_generated_plumbing(const at::Tensor & self, bool upper) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::cholesky::call(self, upper); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, upper); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor cholesky_solve_generated_plumbing(const at::Tensor & self, const at::Tensor & input2, bool upper) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(input2, cur_level)) { + return at::_ops::cholesky_solve::call(self, input2, upper); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [input2_value, input2_bdim] = unwrapTensorAtLevel(input2, cur_level); + auto results 
= batch_rule(self_value, self_bdim, input2_value, input2_bdim, upper); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _cholesky_solve_helper_generated_plumbing(const at::Tensor & self, const at::Tensor & A, bool upper) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(A, cur_level)) { + return at::_ops::_cholesky_solve_helper::call(self, A, upper); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [A_value, A_bdim] = unwrapTensorAtLevel(A, cur_level); + auto results = batch_rule(self_value, self_bdim, A_value, A_bdim, upper); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor cholesky_inverse_generated_plumbing(const at::Tensor & self, bool upper) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::cholesky_inverse::call(self, upper); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, upper); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple qr_generated_plumbing(const at::Tensor & self, bool some) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::qr::call(self, some); + } + auto [self_value, 
self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, some); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple geqrf_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::geqrf::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor orgqr_generated_plumbing(const at::Tensor & self, const at::Tensor & input2) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(input2, cur_level)) { + return at::_ops::orgqr::call(self, input2); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [input2_value, input2_bdim] = unwrapTensorAtLevel(input2, cur_level); + auto results = batch_rule(self_value, self_bdim, input2_value, input2_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor ormqr_generated_plumbing(const at::Tensor & self, const at::Tensor & input2, const at::Tensor & input3, bool left, bool transpose) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + 
vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(input2, cur_level) && !isBatchedAtLevel(input3, cur_level)) { + return at::_ops::ormqr::call(self, input2, input3, left, transpose); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [input2_value, input2_bdim] = unwrapTensorAtLevel(input2, cur_level); + auto [input3_value, input3_bdim] = unwrapTensorAtLevel(input3, cur_level); + auto results = batch_rule(self_value, self_bdim, input2_value, input2_bdim, input3_value, input3_bdim, left, transpose); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple _lu_with_info_generated_plumbing(const at::Tensor & self, bool pivot, bool check_errors) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_lu_with_info::call(self, pivot, check_errors); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, pivot, check_errors); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +at::Tensor lu_solve_generated_plumbing(const at::Tensor & self, const at::Tensor & LU_data, const at::Tensor & LU_pivots) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(LU_data, 
cur_level) && !isBatchedAtLevel(LU_pivots, cur_level)) { + return at::_ops::lu_solve::call(self, LU_data, LU_pivots); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [LU_data_value, LU_data_bdim] = unwrapTensorAtLevel(LU_data, cur_level); + auto [LU_pivots_value, LU_pivots_bdim] = unwrapTensorAtLevel(LU_pivots, cur_level); + auto results = batch_rule(self_value, self_bdim, LU_data_value, LU_data_bdim, LU_pivots_value, LU_pivots_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple lu_unpack_generated_plumbing(const at::Tensor & LU_data, const at::Tensor & LU_pivots, bool unpack_data, bool unpack_pivots) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(LU_data, cur_level) && !isBatchedAtLevel(LU_pivots, cur_level)) { + return at::_ops::lu_unpack::call(LU_data, LU_pivots, unpack_data, unpack_pivots); + } + auto [LU_data_value, LU_data_bdim] = unwrapTensorAtLevel(LU_data, cur_level); + auto [LU_pivots_value, LU_pivots_bdim] = unwrapTensorAtLevel(LU_pivots, cur_level); + auto results = batch_rule(LU_data_value, LU_data_bdim, LU_pivots_value, LU_pivots_bdim, unpack_data, unpack_pivots); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +at::Tensor multinomial_generated_plumbing(const at::Tensor & self, c10::SymInt num_samples, bool replacement, ::std::optional generator) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = 
maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::multinomial::call(self, num_samples, replacement, generator); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, num_samples, replacement, generator); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & lgamma__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::lgamma_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor lgamma_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::lgamma::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor digamma_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::digamma::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = 
batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor polygamma_generated_plumbing(int64_t n, const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::polygamma::call(n, self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(n, self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & polygamma__generated_plumbing(at::Tensor & self, int64_t n) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::polygamma_::call(self, n); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, n); + return self; +} +template +at::Tensor erfinv_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::erfinv::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & erfinv__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard 
guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::erfinv_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor i0_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::i0::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & i0__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::i0_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor sign_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::sign::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = 
batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & sign__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::sign_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor signbit_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::signbit::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor dist_generated_plumbing(const at::Tensor & self, const at::Tensor & other, const at::Scalar & p) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::dist::call(self, other, p); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim, p); + return makeBatched(std::get<0>(results), std::get<1>(results), 
cur_level); +} +template +at::Tensor & atan2__generated_plumbing(at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::atan2_::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim); + return self; +} +template +at::Tensor atan2_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::atan2::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor arctan2_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::arctan2::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + 
auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & arctan2__generated_plumbing(at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::arctan2_::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim); + return self; +} +template +at::Tensor lerp_Scalar_generated_plumbing(const at::Tensor & self, const at::Tensor & end, const at::Scalar & weight) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(end, cur_level)) { + return at::_ops::lerp_Scalar::call(self, end, weight); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [end_value, end_bdim] = unwrapTensorAtLevel(end, cur_level); + auto results = batch_rule(self_value, self_bdim, end_value, end_bdim, weight); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor lerp_Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & end, const at::Tensor & weight) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + 
vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(end, cur_level) && !isBatchedAtLevel(weight, cur_level)) { + return at::_ops::lerp_Tensor::call(self, end, weight); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [end_value, end_bdim] = unwrapTensorAtLevel(end, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + auto results = batch_rule(self_value, self_bdim, end_value, end_bdim, weight_value, weight_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor histc_generated_plumbing(const at::Tensor & self, int64_t bins, const at::Scalar & min, const at::Scalar & max) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::histc::call(self, bins, min, max); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, bins, min, max); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple histogram_bins_tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & bins, const ::std::optional & weight, bool density) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(bins, cur_level) && !isBatchedAtLevel(weight, cur_level)) { + return at::_ops::histogram_bins_tensor::call(self, bins, weight, density); + } + auto [self_value, self_bdim] = 
unwrapTensorAtLevel(self, cur_level); + auto [bins_value, bins_bdim] = unwrapTensorAtLevel(bins, cur_level); + std::optional weight_value; + std::optional weight_bdim; + if (weight) { + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, bins_value, bins_bdim, weight_value, weight_bdim, density); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple histogram_bin_ct_generated_plumbing(const at::Tensor & self, int64_t bins, ::std::optional> range, const ::std::optional & weight, bool density) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(weight, cur_level)) { + return at::_ops::histogram_bin_ct::call(self, bins, range, weight, density); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + std::optional weight_value; + std::optional weight_bdim; + if (weight) { + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, bins, range, weight_value, weight_bdim, density); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::vector _histogramdd_bin_edges_generated_plumbing(const at::Tensor & self, at::IntArrayRef bins, ::std::optional> range, const ::std::optional & weight, bool density) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t 
cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(weight, cur_level)) { + return at::_ops::_histogramdd_bin_edges::call(self, bins, range, weight, density); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + std::optional weight_value; + std::optional weight_bdim; + if (weight) { + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, bins, range, weight_value, weight_bdim, density); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _histogramdd_from_bin_cts_generated_plumbing(const at::Tensor & self, at::IntArrayRef bins, ::std::optional> range, const ::std::optional & weight, bool density) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(weight, cur_level)) { + return at::_ops::_histogramdd_from_bin_cts::call(self, bins, range, weight, density); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + std::optional weight_value; + std::optional weight_bdim; + if (weight) { + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, bins, range, weight_value, weight_bdim, density); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _histogramdd_from_bin_tensors_generated_plumbing(const at::Tensor & self, at::TensorList bins, const ::std::optional & weight, bool density) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t 
cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(bins, cur_level) && !isBatchedAtLevel(weight, cur_level)) { + return at::_ops::_histogramdd_from_bin_tensors::call(self, bins, weight, density); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + std::optional weight_value; + std::optional weight_bdim; + if (weight) { + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, bins, weight_value, weight_bdim, density); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple> histogramdd_generated_plumbing(const at::Tensor & self, at::IntArrayRef bins, ::std::optional> range, const ::std::optional & weight, bool density) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(weight, cur_level)) { + return at::_ops::histogramdd::call(self, bins, range, weight, density); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + std::optional weight_value; + std::optional weight_bdim; + if (weight) { + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, bins, range, weight_value, weight_bdim, density); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatchedVector(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple> histogramdd_int_bins_generated_plumbing(const at::Tensor & self, int64_t bins, ::std::optional> range, const ::std::optional & weight, bool density) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto 
maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(weight, cur_level)) { + return at::_ops::histogramdd_int_bins::call(self, bins, range, weight, density); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + std::optional weight_value; + std::optional weight_bdim; + if (weight) { + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, bins, range, weight_value, weight_bdim, density); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatchedVector(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple> histogramdd_TensorList_bins_generated_plumbing(const at::Tensor & self, at::TensorList bins, ::std::optional> range, const ::std::optional & weight, bool density) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(bins, cur_level) && !isBatchedAtLevel(weight, cur_level)) { + return at::_ops::histogramdd_TensorList_bins::call(self, bins, range, weight, density); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + std::optional weight_value; + std::optional weight_bdim; + if (weight) { + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, bins, range, weight_value, weight_bdim, density); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatchedVector(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor 
fmod_Scalar_generated_plumbing(const at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::fmod_Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, other); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & fmod__Scalar_generated_plumbing(at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::fmod__Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, other); + return self; +} +template +at::Tensor fmod_Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::fmod_Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & 
fmod__Tensor_generated_plumbing(at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::fmod__Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim); + return self; +} +template +at::Tensor hypot_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::hypot::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & hypot__generated_plumbing(at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::hypot_::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, 
other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim); + return self; +} +template +at::Tensor igamma_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::igamma::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & igamma__generated_plumbing(at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::igamma_::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim); + return self; +} +template +at::Tensor igammac_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + 
return at::_ops::igammac::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & igammac__generated_plumbing(at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::igammac_::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim); + return self; +} +template +at::Tensor nextafter_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::nextafter::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & nextafter__generated_plumbing(at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto 
maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::nextafter_::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim); + return self; +} +template +at::Tensor remainder_Scalar_generated_plumbing(const at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::remainder_Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, other); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & remainder__Scalar_generated_plumbing(at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::remainder__Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, other); + return self; +} +template +at::Tensor remainder_Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + 
vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::remainder_Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & remainder__Tensor_generated_plumbing(at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::remainder__Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim); + return self; +} +template +at::Tensor remainder_Scalar_Tensor_generated_plumbing(const at::Scalar & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(other, cur_level)) { + return at::_ops::remainder_Scalar_Tensor::call(self, other); + } + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor min_generated_plumbing(const at::Tensor & self) { + 
c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::min::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor fmin_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::fmin::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor max_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::max::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor fmax_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard 
guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::fmax::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor maximum_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::maximum::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor max_other_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::max_other::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + 
auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor minimum_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::minimum::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor min_other_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::min_other::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor quantile_generated_plumbing(const at::Tensor & self, const at::Tensor & q, ::std::optional dim, bool keepdim, c10::string_view interpolation) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + 
vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(q, cur_level)) { + return at::_ops::quantile::call(self, q, dim, keepdim, interpolation); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [q_value, q_bdim] = unwrapTensorAtLevel(q, cur_level); + auto results = batch_rule(self_value, self_bdim, q_value, q_bdim, dim, keepdim, interpolation); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor quantile_scalar_generated_plumbing(const at::Tensor & self, double q, ::std::optional dim, bool keepdim, c10::string_view interpolation) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::quantile_scalar::call(self, q, dim, keepdim, interpolation); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, q, dim, keepdim, interpolation); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor nanquantile_generated_plumbing(const at::Tensor & self, const at::Tensor & q, ::std::optional dim, bool keepdim, c10::string_view interpolation) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(q, cur_level)) { + return at::_ops::nanquantile::call(self, q, dim, keepdim, interpolation); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [q_value, q_bdim] = unwrapTensorAtLevel(q, 
cur_level); + auto results = batch_rule(self_value, self_bdim, q_value, q_bdim, dim, keepdim, interpolation); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor nanquantile_scalar_generated_plumbing(const at::Tensor & self, double q, ::std::optional dim, bool keepdim, c10::string_view interpolation) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::nanquantile_scalar::call(self, q, dim, keepdim, interpolation); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, q, dim, keepdim, interpolation); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple sort_generated_plumbing(const at::Tensor & self, int64_t dim, bool descending) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::sort::call(self, dim, descending); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, descending); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple sort_stable_generated_plumbing(const at::Tensor & self, ::std::optional stable, int64_t dim, bool descending) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); 
+ int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::sort_stable::call(self, stable, dim, descending); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, stable, dim, descending); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple sort_dimname_generated_plumbing(const at::Tensor & self, at::Dimname dim, bool descending) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::sort_dimname::call(self, dim, descending); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, descending); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple sort_dimname_stable_generated_plumbing(const at::Tensor & self, ::std::optional stable, at::Dimname dim, bool descending) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::sort_dimname_stable::call(self, stable, dim, descending); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, stable, dim, descending); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), 
makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor msort_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::msort::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor argsort_generated_plumbing(const at::Tensor & self, int64_t dim, bool descending) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::argsort::call(self, dim, descending); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, descending); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor argsort_stable_generated_plumbing(const at::Tensor & self, bool stable, int64_t dim, bool descending) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::argsort_stable::call(self, stable, dim, descending); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, stable, dim, descending); + return makeBatched(std::get<0>(results), 
std::get<1>(results), cur_level); +} +template +at::Tensor argsort_dimname_generated_plumbing(const at::Tensor & self, at::Dimname dim, bool descending) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::argsort_dimname::call(self, dim, descending); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, descending); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple topk_generated_plumbing(const at::Tensor & self, c10::SymInt k, int64_t dim, bool largest, bool sorted) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::topk::call(self, k, dim, largest, sorted); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, k, dim, largest, sorted); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor all_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::all::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + 
return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor any_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::any::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor renorm_generated_plumbing(const at::Tensor & self, const at::Scalar & p, int64_t dim, const at::Scalar & maxnorm) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::renorm::call(self, p, dim, maxnorm); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, p, dim, maxnorm); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & renorm__generated_plumbing(at::Tensor & self, const at::Scalar & p, int64_t dim, const at::Scalar & maxnorm) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::renorm_::call(self, p, dim, maxnorm); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, p, dim, maxnorm); + return self; +} +template +at::Tensor 
unfold_generated_plumbing(const at::Tensor & self, int64_t dimension, int64_t size, int64_t step) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::unfold::call(self, dimension, size, step); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dimension, size, step); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor unfold_backward_generated_plumbing(const at::Tensor & grad_in, c10::SymIntArrayRef input_sizes, int64_t dim, int64_t size, int64_t step) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_in, cur_level)) { + return at::_ops::unfold_backward::call(grad_in, input_sizes, dim, size, step); + } + auto [grad_in_value, grad_in_bdim] = unwrapTensorAtLevel(grad_in, cur_level); + auto results = batch_rule(grad_in_value, grad_in_bdim, input_sizes, dim, size, step); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor pow_Tensor_Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & exponent) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(exponent, cur_level)) { + return at::_ops::pow_Tensor_Tensor::call(self, exponent); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto 
[exponent_value, exponent_bdim] = unwrapTensorAtLevel(exponent, cur_level); + auto results = batch_rule(self_value, self_bdim, exponent_value, exponent_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor pow_Scalar_generated_plumbing(const at::Scalar & self, const at::Tensor & exponent) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(exponent, cur_level)) { + return at::_ops::pow_Scalar::call(self, exponent); + } + auto [exponent_value, exponent_bdim] = unwrapTensorAtLevel(exponent, cur_level); + auto results = batch_rule(self, exponent_value, exponent_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor pow_Tensor_Scalar_generated_plumbing(const at::Tensor & self, const at::Scalar & exponent) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::pow_Tensor_Scalar::call(self, exponent); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, exponent); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & pow__Scalar_generated_plumbing(at::Tensor & self, const at::Scalar & exponent) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::pow__Scalar::call(self, 
exponent); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, exponent); + return self; +} +template +at::Tensor & pow__Tensor_generated_plumbing(at::Tensor & self, const at::Tensor & exponent) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(exponent, cur_level)) { + return at::_ops::pow__Tensor::call(self, exponent); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [exponent_value, exponent_bdim] = unwrapTensorAtLevel(exponent, cur_level); + batch_rule(self_value, self_bdim, exponent_value, exponent_bdim); + return self; +} +template +at::Tensor float_power_Tensor_Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & exponent) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(exponent, cur_level)) { + return at::_ops::float_power_Tensor_Tensor::call(self, exponent); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [exponent_value, exponent_bdim] = unwrapTensorAtLevel(exponent, cur_level); + auto results = batch_rule(self_value, self_bdim, exponent_value, exponent_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor float_power_Scalar_generated_plumbing(const at::Scalar & self, const at::Tensor & exponent) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + 
int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(exponent, cur_level)) { + return at::_ops::float_power_Scalar::call(self, exponent); + } + auto [exponent_value, exponent_bdim] = unwrapTensorAtLevel(exponent, cur_level); + auto results = batch_rule(self, exponent_value, exponent_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor float_power_Tensor_Scalar_generated_plumbing(const at::Tensor & self, const at::Scalar & exponent) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::float_power_Tensor_Scalar::call(self, exponent); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, exponent); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & float_power__Scalar_generated_plumbing(at::Tensor & self, const at::Scalar & exponent) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::float_power__Scalar::call(self, exponent); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, exponent); + return self; +} +template +at::Tensor & float_power__Tensor_generated_plumbing(at::Tensor & self, const at::Tensor & exponent) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = 
maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(exponent, cur_level)) { + return at::_ops::float_power__Tensor::call(self, exponent); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [exponent_value, exponent_bdim] = unwrapTensorAtLevel(exponent, cur_level); + batch_rule(self_value, self_bdim, exponent_value, exponent_bdim); + return self; +} +template +at::Tensor & normal__generated_plumbing(at::Tensor & self, double mean, double std, ::std::optional generator) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::normal_::call(self, mean, std, generator); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, mean, std, generator); + return self; +} +template +at::Tensor normal_functional_generated_plumbing(const at::Tensor & self, double mean, double std, ::std::optional generator) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::normal_functional::call(self, mean, std, generator); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, mean, std, generator); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor normal_Tensor_float_generated_plumbing(const at::Tensor & mean, double std, ::std::optional generator) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + 
vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(mean, cur_level)) { + return at::_ops::normal_Tensor_float::call(mean, std, generator); + } + auto [mean_value, mean_bdim] = unwrapTensorAtLevel(mean, cur_level); + auto results = batch_rule(mean_value, mean_bdim, std, generator); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor normal_float_Tensor_generated_plumbing(double mean, const at::Tensor & std, ::std::optional generator) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(std, cur_level)) { + return at::_ops::normal_float_Tensor::call(mean, std, generator); + } + auto [std_value, std_bdim] = unwrapTensorAtLevel(std, cur_level); + auto results = batch_rule(mean, std_value, std_bdim, generator); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor normal_Tensor_Tensor_generated_plumbing(const at::Tensor & mean, const at::Tensor & std, ::std::optional generator) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(mean, cur_level) && !isBatchedAtLevel(std, cur_level)) { + return at::_ops::normal_Tensor_Tensor::call(mean, std, generator); + } + auto [mean_value, mean_bdim] = unwrapTensorAtLevel(mean, cur_level); + auto [std_value, std_bdim] = unwrapTensorAtLevel(std, cur_level); + auto results = batch_rule(mean_value, mean_bdim, std_value, std_bdim, generator); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor alias_generated_plumbing(const 
at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::alias::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _amp_foreach_non_finite_check_and_unscale__generated_plumbing(at::TensorList self, at::Tensor & found_inf, const at::Tensor & inv_scale) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(found_inf, cur_level) && !isBatchedAtLevel(inv_scale, cur_level)) { + return at::_ops::_amp_foreach_non_finite_check_and_unscale_::call(self, found_inf, inv_scale); + } + auto [found_inf_value, found_inf_bdim] = unwrapTensorAtLevel(found_inf, cur_level); + auto [inv_scale_value, inv_scale_bdim] = unwrapTensorAtLevel(inv_scale, cur_level); + batch_rule(self, found_inf_value, found_inf_bdim, inv_scale_value, inv_scale_bdim); +} +template +::std::vector _foreach_add_Scalar_generated_plumbing(at::TensorList self, const at::Scalar & scalar) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_add_Scalar::call(self, scalar); + } + + auto results = batch_rule(self, scalar); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} 
+template +void _foreach_add__Scalar_generated_plumbing(at::TensorList self, const at::Scalar & scalar) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_add__Scalar::call(self, scalar); + } + + batch_rule(self, scalar); +} +template +::std::vector _foreach_add_List_generated_plumbing(at::TensorList self, at::TensorList other, const at::Scalar & alpha) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::_foreach_add_List::call(self, other, alpha); + } + + auto results = batch_rule(self, other, alpha); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_add__List_generated_plumbing(at::TensorList self, at::TensorList other, const at::Scalar & alpha) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::_foreach_add__List::call(self, other, alpha); + } + + batch_rule(self, other, alpha); +} +template +::std::vector _foreach_add_ScalarList_generated_plumbing(at::TensorList self, at::ArrayRef scalars) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = 
maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_add_ScalarList::call(self, scalars); + } + + auto results = batch_rule(self, scalars); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_add__ScalarList_generated_plumbing(at::TensorList self, at::ArrayRef scalars) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_add__ScalarList::call(self, scalars); + } + + batch_rule(self, scalars); +} +template +::std::vector _foreach_add_Tensor_generated_plumbing(at::TensorList self, const at::Tensor & other, const at::Scalar & alpha) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::_foreach_add_Tensor::call(self, other, alpha); + } + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self, other_value, other_bdim, alpha); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_add__Tensor_generated_plumbing(at::TensorList self, const at::Tensor & other, const at::Scalar & alpha) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return 
at::_ops::_foreach_add__Tensor::call(self, other, alpha); + } + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self, other_value, other_bdim, alpha); +} +template +::std::vector _foreach_sub_Scalar_generated_plumbing(at::TensorList self, const at::Scalar & scalar) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_sub_Scalar::call(self, scalar); + } + + auto results = batch_rule(self, scalar); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_sub__Scalar_generated_plumbing(at::TensorList self, const at::Scalar & scalar) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_sub__Scalar::call(self, scalar); + } + + batch_rule(self, scalar); +} +template +::std::vector _foreach_sub_List_generated_plumbing(at::TensorList self, at::TensorList other, const at::Scalar & alpha) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::_foreach_sub_List::call(self, other, alpha); + } + + auto results = batch_rule(self, other, alpha); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_sub__List_generated_plumbing(at::TensorList self, at::TensorList other, 
const at::Scalar & alpha) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::_foreach_sub__List::call(self, other, alpha); + } + + batch_rule(self, other, alpha); +} +template +::std::vector _foreach_sub_ScalarList_generated_plumbing(at::TensorList self, at::ArrayRef scalars) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_sub_ScalarList::call(self, scalars); + } + + auto results = batch_rule(self, scalars); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_sub__ScalarList_generated_plumbing(at::TensorList self, at::ArrayRef scalars) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_sub__ScalarList::call(self, scalars); + } + + batch_rule(self, scalars); +} +template +::std::vector _foreach_mul_Scalar_generated_plumbing(at::TensorList self, const at::Scalar & scalar) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_mul_Scalar::call(self, scalar); + } + + auto results = 
batch_rule(self, scalar); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_mul__Scalar_generated_plumbing(at::TensorList self, const at::Scalar & scalar) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_mul__Scalar::call(self, scalar); + } + + batch_rule(self, scalar); +} +template +::std::vector _foreach_mul_List_generated_plumbing(at::TensorList self, at::TensorList other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::_foreach_mul_List::call(self, other); + } + + auto results = batch_rule(self, other); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_mul__List_generated_plumbing(at::TensorList self, at::TensorList other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::_foreach_mul__List::call(self, other); + } + + batch_rule(self, other); +} +template +::std::vector _foreach_mul_ScalarList_generated_plumbing(at::TensorList self, at::ArrayRef scalars) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, 
"gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_mul_ScalarList::call(self, scalars); + } + + auto results = batch_rule(self, scalars); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_mul__ScalarList_generated_plumbing(at::TensorList self, at::ArrayRef scalars) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_mul__ScalarList::call(self, scalars); + } + + batch_rule(self, scalars); +} +template +::std::vector _foreach_mul_Tensor_generated_plumbing(at::TensorList self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::_foreach_mul_Tensor::call(self, other); + } + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self, other_value, other_bdim); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_mul__Tensor_generated_plumbing(at::TensorList self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::_foreach_mul__Tensor::call(self, other); + } + 
auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self, other_value, other_bdim); +} +template +::std::vector _foreach_div_Scalar_generated_plumbing(at::TensorList self, const at::Scalar & scalar) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_div_Scalar::call(self, scalar); + } + + auto results = batch_rule(self, scalar); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_div__Scalar_generated_plumbing(at::TensorList self, const at::Scalar & scalar) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_div__Scalar::call(self, scalar); + } + + batch_rule(self, scalar); +} +template +::std::vector _foreach_div_List_generated_plumbing(at::TensorList self, at::TensorList other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::_foreach_div_List::call(self, other); + } + + auto results = batch_rule(self, other); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_div__List_generated_plumbing(at::TensorList self, at::TensorList other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = 
maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::_foreach_div__List::call(self, other); + } + + batch_rule(self, other); +} +template +::std::vector _foreach_div_ScalarList_generated_plumbing(at::TensorList self, at::ArrayRef scalars) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_div_ScalarList::call(self, scalars); + } + + auto results = batch_rule(self, scalars); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_div__ScalarList_generated_plumbing(at::TensorList self, at::ArrayRef scalars) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_div__ScalarList::call(self, scalars); + } + + batch_rule(self, scalars); +} +template +::std::vector _foreach_div_Tensor_generated_plumbing(at::TensorList self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::_foreach_div_Tensor::call(self, other); + } + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self, other_value, 
other_bdim); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_div__Tensor_generated_plumbing(at::TensorList self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::_foreach_div__Tensor::call(self, other); + } + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self, other_value, other_bdim); +} +template +::std::vector _foreach_clamp_max_Scalar_generated_plumbing(at::TensorList self, const at::Scalar & scalar) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_clamp_max_Scalar::call(self, scalar); + } + + auto results = batch_rule(self, scalar); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_clamp_max__Scalar_generated_plumbing(at::TensorList self, const at::Scalar & scalar) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_clamp_max__Scalar::call(self, scalar); + } + + batch_rule(self, scalar); +} +template +::std::vector _foreach_clamp_max_List_generated_plumbing(at::TensorList self, at::TensorList other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto 
maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::_foreach_clamp_max_List::call(self, other); + } + + auto results = batch_rule(self, other); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_clamp_max__List_generated_plumbing(at::TensorList self, at::TensorList other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::_foreach_clamp_max__List::call(self, other); + } + + batch_rule(self, other); +} +template +::std::vector _foreach_clamp_max_ScalarList_generated_plumbing(at::TensorList self, at::ArrayRef scalars) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_clamp_max_ScalarList::call(self, scalars); + } + + auto results = batch_rule(self, scalars); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_clamp_max__ScalarList_generated_plumbing(at::TensorList self, at::ArrayRef scalars) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_clamp_max__ScalarList::call(self, 
scalars); + } + + batch_rule(self, scalars); +} +template +::std::vector _foreach_clamp_min_Scalar_generated_plumbing(at::TensorList self, const at::Scalar & scalar) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_clamp_min_Scalar::call(self, scalar); + } + + auto results = batch_rule(self, scalar); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_clamp_min__Scalar_generated_plumbing(at::TensorList self, const at::Scalar & scalar) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_clamp_min__Scalar::call(self, scalar); + } + + batch_rule(self, scalar); +} +template +::std::vector _foreach_clamp_min_List_generated_plumbing(at::TensorList self, at::TensorList other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::_foreach_clamp_min_List::call(self, other); + } + + auto results = batch_rule(self, other); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_clamp_min__List_generated_plumbing(at::TensorList self, at::TensorList other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + 
vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::_foreach_clamp_min__List::call(self, other); + } + + batch_rule(self, other); +} +template +::std::vector _foreach_clamp_min_ScalarList_generated_plumbing(at::TensorList self, at::ArrayRef scalars) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_clamp_min_ScalarList::call(self, scalars); + } + + auto results = batch_rule(self, scalars); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_clamp_min__ScalarList_generated_plumbing(at::TensorList self, at::ArrayRef scalars) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_clamp_min__ScalarList::call(self, scalars); + } + + batch_rule(self, scalars); +} +template +::std::vector _foreach_maximum_Scalar_generated_plumbing(at::TensorList self, const at::Scalar & scalar) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_maximum_Scalar::call(self, scalar); + } + + auto results = batch_rule(self, scalar); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void 
_foreach_maximum__Scalar_generated_plumbing(at::TensorList self, const at::Scalar & scalar) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_maximum__Scalar::call(self, scalar); + } + + batch_rule(self, scalar); +} +template +::std::vector _foreach_maximum_List_generated_plumbing(at::TensorList self, at::TensorList other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::_foreach_maximum_List::call(self, other); + } + + auto results = batch_rule(self, other); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_maximum__List_generated_plumbing(at::TensorList self, at::TensorList other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::_foreach_maximum__List::call(self, other); + } + + batch_rule(self, other); +} +template +::std::vector _foreach_maximum_ScalarList_generated_plumbing(at::TensorList self, at::ArrayRef scalars) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { 
+ return at::_ops::_foreach_maximum_ScalarList::call(self, scalars); + } + + auto results = batch_rule(self, scalars); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_maximum__ScalarList_generated_plumbing(at::TensorList self, at::ArrayRef scalars) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_maximum__ScalarList::call(self, scalars); + } + + batch_rule(self, scalars); +} +template +::std::vector _foreach_minimum_Scalar_generated_plumbing(at::TensorList self, const at::Scalar & scalar) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_minimum_Scalar::call(self, scalar); + } + + auto results = batch_rule(self, scalar); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_minimum__Scalar_generated_plumbing(at::TensorList self, const at::Scalar & scalar) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_minimum__Scalar::call(self, scalar); + } + + batch_rule(self, scalar); +} +template +::std::vector _foreach_minimum_List_generated_plumbing(at::TensorList self, at::TensorList other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = 
maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::_foreach_minimum_List::call(self, other); + } + + auto results = batch_rule(self, other); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_minimum__List_generated_plumbing(at::TensorList self, at::TensorList other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::_foreach_minimum__List::call(self, other); + } + + batch_rule(self, other); +} +template +::std::vector _foreach_minimum_ScalarList_generated_plumbing(at::TensorList self, at::ArrayRef scalars) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_minimum_ScalarList::call(self, scalars); + } + + auto results = batch_rule(self, scalars); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_minimum__ScalarList_generated_plumbing(at::TensorList self, at::ArrayRef scalars) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_minimum__ScalarList::call(self, scalars); + } + + 
batch_rule(self, scalars); +} +template +::std::vector _foreach_addcdiv_Scalar_generated_plumbing(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Scalar & value) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(tensor1, cur_level) && !isBatchedAtLevel(tensor2, cur_level)) { + return at::_ops::_foreach_addcdiv_Scalar::call(self, tensor1, tensor2, value); + } + + auto results = batch_rule(self, tensor1, tensor2, value); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector _foreach_addcdiv_ScalarList_generated_plumbing(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, at::ArrayRef scalars) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(tensor1, cur_level) && !isBatchedAtLevel(tensor2, cur_level)) { + return at::_ops::_foreach_addcdiv_ScalarList::call(self, tensor1, tensor2, scalars); + } + + auto results = batch_rule(self, tensor1, tensor2, scalars); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector _foreach_addcdiv_Tensor_generated_plumbing(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Tensor & scalars) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && 
!isBatchedAtLevel(tensor1, cur_level) && !isBatchedAtLevel(tensor2, cur_level) && !isBatchedAtLevel(scalars, cur_level)) { + return at::_ops::_foreach_addcdiv_Tensor::call(self, tensor1, tensor2, scalars); + } + auto [scalars_value, scalars_bdim] = unwrapTensorAtLevel(scalars, cur_level); + auto results = batch_rule(self, tensor1, tensor2, scalars_value, scalars_bdim); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_addcdiv__Scalar_generated_plumbing(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Scalar & value) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(tensor1, cur_level) && !isBatchedAtLevel(tensor2, cur_level)) { + return at::_ops::_foreach_addcdiv__Scalar::call(self, tensor1, tensor2, value); + } + + batch_rule(self, tensor1, tensor2, value); +} +template +void _foreach_addcdiv__ScalarList_generated_plumbing(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, at::ArrayRef scalars) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(tensor1, cur_level) && !isBatchedAtLevel(tensor2, cur_level)) { + return at::_ops::_foreach_addcdiv__ScalarList::call(self, tensor1, tensor2, scalars); + } + + batch_rule(self, tensor1, tensor2, scalars); +} +template +void _foreach_addcdiv__Tensor_generated_plumbing(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Tensor & scalars) { + c10::impl::ExcludeDispatchKeyGuard 
guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(tensor1, cur_level) && !isBatchedAtLevel(tensor2, cur_level) && !isBatchedAtLevel(scalars, cur_level)) { + return at::_ops::_foreach_addcdiv__Tensor::call(self, tensor1, tensor2, scalars); + } + auto [scalars_value, scalars_bdim] = unwrapTensorAtLevel(scalars, cur_level); + batch_rule(self, tensor1, tensor2, scalars_value, scalars_bdim); +} +template +::std::vector _foreach_addcmul_Scalar_generated_plumbing(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Scalar & value) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(tensor1, cur_level) && !isBatchedAtLevel(tensor2, cur_level)) { + return at::_ops::_foreach_addcmul_Scalar::call(self, tensor1, tensor2, value); + } + + auto results = batch_rule(self, tensor1, tensor2, value); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector _foreach_addcmul_ScalarList_generated_plumbing(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, at::ArrayRef scalars) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(tensor1, cur_level) && !isBatchedAtLevel(tensor2, cur_level)) { + return at::_ops::_foreach_addcmul_ScalarList::call(self, tensor1, tensor2, scalars); + } + + auto results = batch_rule(self, 
tensor1, tensor2, scalars); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector _foreach_addcmul_Tensor_generated_plumbing(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Tensor & scalars) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(tensor1, cur_level) && !isBatchedAtLevel(tensor2, cur_level) && !isBatchedAtLevel(scalars, cur_level)) { + return at::_ops::_foreach_addcmul_Tensor::call(self, tensor1, tensor2, scalars); + } + auto [scalars_value, scalars_bdim] = unwrapTensorAtLevel(scalars, cur_level); + auto results = batch_rule(self, tensor1, tensor2, scalars_value, scalars_bdim); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_addcmul__Scalar_generated_plumbing(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Scalar & value) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(tensor1, cur_level) && !isBatchedAtLevel(tensor2, cur_level)) { + return at::_ops::_foreach_addcmul__Scalar::call(self, tensor1, tensor2, value); + } + + batch_rule(self, tensor1, tensor2, value); +} +template +void _foreach_addcmul__ScalarList_generated_plumbing(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, at::ArrayRef scalars) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, 
"gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(tensor1, cur_level) && !isBatchedAtLevel(tensor2, cur_level)) { + return at::_ops::_foreach_addcmul__ScalarList::call(self, tensor1, tensor2, scalars); + } + + batch_rule(self, tensor1, tensor2, scalars); +} +template +void _foreach_addcmul__Tensor_generated_plumbing(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Tensor & scalars) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(tensor1, cur_level) && !isBatchedAtLevel(tensor2, cur_level) && !isBatchedAtLevel(scalars, cur_level)) { + return at::_ops::_foreach_addcmul__Tensor::call(self, tensor1, tensor2, scalars); + } + auto [scalars_value, scalars_bdim] = unwrapTensorAtLevel(scalars, cur_level); + batch_rule(self, tensor1, tensor2, scalars_value, scalars_bdim); +} +template +::std::vector _foreach_abs_generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_abs::call(self); + } + + auto results = batch_rule(self); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_abs__generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if 
(!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_abs_::call(self); + } + + batch_rule(self); +} +template +::std::vector _foreach_acos_generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_acos::call(self); + } + + auto results = batch_rule(self); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_acos__generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_acos_::call(self); + } + + batch_rule(self); +} +template +::std::vector _foreach_asin_generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_asin::call(self); + } + + auto results = batch_rule(self); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_asin__generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_asin_::call(self); + } + 
+ batch_rule(self); +} +template +::std::vector _foreach_atan_generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_atan::call(self); + } + + auto results = batch_rule(self); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_atan__generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_atan_::call(self); + } + + batch_rule(self); +} +template +::std::vector _foreach_ceil_generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_ceil::call(self); + } + + auto results = batch_rule(self); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_ceil__generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_ceil_::call(self); + } + + batch_rule(self); +} +template +::std::vector 
_foreach_cos_generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_cos::call(self); + } + + auto results = batch_rule(self); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_cos__generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_cos_::call(self); + } + + batch_rule(self); +} +template +::std::vector _foreach_cosh_generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_cosh::call(self); + } + + auto results = batch_rule(self); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_cosh__generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_cosh_::call(self); + } + + batch_rule(self); +} +template +::std::vector _foreach_erf_generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard 
guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_erf::call(self); + } + + auto results = batch_rule(self); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_erf__generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_erf_::call(self); + } + + batch_rule(self); +} +template +::std::vector _foreach_erfc_generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_erfc::call(self); + } + + auto results = batch_rule(self); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_erfc__generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_erfc_::call(self); + } + + batch_rule(self); +} +template +::std::vector _foreach_exp_generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + 
vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_exp::call(self); + } + + auto results = batch_rule(self); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_exp__generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_exp_::call(self); + } + + batch_rule(self); +} +template +::std::vector _foreach_expm1_generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_expm1::call(self); + } + + auto results = batch_rule(self); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_expm1__generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_expm1_::call(self); + } + + batch_rule(self); +} +template +::std::vector _foreach_floor_generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = 
maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_floor::call(self); + } + + auto results = batch_rule(self); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_floor__generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_floor_::call(self); + } + + batch_rule(self); +} +template +::std::vector _foreach_frac_generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_frac::call(self); + } + + auto results = batch_rule(self); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_frac__generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_frac_::call(self); + } + + batch_rule(self); +} +template +::std::vector _foreach_lerp_List_generated_plumbing(at::TensorList self, at::TensorList tensors1, at::TensorList weights) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if 
(!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(tensors1, cur_level) && !isBatchedAtLevel(weights, cur_level)) { + return at::_ops::_foreach_lerp_List::call(self, tensors1, weights); + } + + auto results = batch_rule(self, tensors1, weights); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_lerp__List_generated_plumbing(at::TensorList self, at::TensorList tensors1, at::TensorList weights) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(tensors1, cur_level) && !isBatchedAtLevel(weights, cur_level)) { + return at::_ops::_foreach_lerp__List::call(self, tensors1, weights); + } + + batch_rule(self, tensors1, weights); +} +template +::std::vector _foreach_lerp_Scalar_generated_plumbing(at::TensorList self, at::TensorList tensors1, const at::Scalar & weight) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(tensors1, cur_level)) { + return at::_ops::_foreach_lerp_Scalar::call(self, tensors1, weight); + } + + auto results = batch_rule(self, tensors1, weight); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_lerp__Scalar_generated_plumbing(at::TensorList self, at::TensorList tensors1, const at::Scalar & weight) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if 
(!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(tensors1, cur_level)) { + return at::_ops::_foreach_lerp__Scalar::call(self, tensors1, weight); + } + + batch_rule(self, tensors1, weight); +} +template +::std::vector _foreach_lerp_ScalarList_generated_plumbing(at::TensorList self, at::TensorList tensors1, at::ArrayRef weight) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(tensors1, cur_level)) { + return at::_ops::_foreach_lerp_ScalarList::call(self, tensors1, weight); + } + + auto results = batch_rule(self, tensors1, weight); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_lerp__ScalarList_generated_plumbing(at::TensorList self, at::TensorList tensors1, at::ArrayRef weight) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(tensors1, cur_level)) { + return at::_ops::_foreach_lerp__ScalarList::call(self, tensors1, weight); + } + + batch_rule(self, tensors1, weight); +} +template +::std::vector _foreach_lgamma_generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_lgamma::call(self); + } + + auto results = batch_rule(self); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void 
_foreach_lgamma__generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_lgamma_::call(self); + } + + batch_rule(self); +} +template +::std::vector _foreach_log_generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_log::call(self); + } + + auto results = batch_rule(self); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_log__generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_log_::call(self); + } + + batch_rule(self); +} +template +::std::vector _foreach_log10_generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_log10::call(self); + } + + auto results = batch_rule(self); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_log10__generated_plumbing(at::TensorList self) { + 
c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_log10_::call(self); + } + + batch_rule(self); +} +template +::std::vector _foreach_log1p_generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_log1p::call(self); + } + + auto results = batch_rule(self); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_log1p__generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_log1p_::call(self); + } + + batch_rule(self); +} +template +::std::vector _foreach_log2_generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_log2::call(self); + } + + auto results = batch_rule(self); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_log2__generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto 
maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_log2_::call(self); + } + + batch_rule(self); +} +template +::std::vector _foreach_max_generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_max::call(self); + } + + auto results = batch_rule(self); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector _foreach_neg_generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_neg::call(self); + } + + auto results = batch_rule(self); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_neg__generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_neg_::call(self); + } + + batch_rule(self); +} +template +::std::vector _foreach_norm_Scalar_generated_plumbing(at::TensorList self, const at::Scalar & ord, ::std::optional dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = 
maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_norm_Scalar::call(self, ord, dtype); + } + + auto results = batch_rule(self, ord, dtype); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector _foreach_pow_List_generated_plumbing(at::TensorList self, at::TensorList exponent) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(exponent, cur_level)) { + return at::_ops::_foreach_pow_List::call(self, exponent); + } + + auto results = batch_rule(self, exponent); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector _foreach_pow_Scalar_generated_plumbing(at::TensorList self, const at::Scalar & exponent) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_pow_Scalar::call(self, exponent); + } + + auto results = batch_rule(self, exponent); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector _foreach_pow_ScalarList_generated_plumbing(at::TensorList self, at::ArrayRef exponent) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return 
at::_ops::_foreach_pow_ScalarList::call(self, exponent); + } + + auto results = batch_rule(self, exponent); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector _foreach_pow_ScalarAndTensor_generated_plumbing(const at::Scalar & self, at::TensorList exponent) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(exponent, cur_level)) { + return at::_ops::_foreach_pow_ScalarAndTensor::call(self, exponent); + } + + auto results = batch_rule(self, exponent); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_pow__List_generated_plumbing(at::TensorList self, at::TensorList exponent) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(exponent, cur_level)) { + return at::_ops::_foreach_pow__List::call(self, exponent); + } + + batch_rule(self, exponent); +} +template +void _foreach_pow__Scalar_generated_plumbing(at::TensorList self, const at::Scalar & exponent) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_pow__Scalar::call(self, exponent); + } + + batch_rule(self, exponent); +} +template +void _foreach_pow__ScalarList_generated_plumbing(at::TensorList self, at::ArrayRef exponent) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + 
auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_pow__ScalarList::call(self, exponent); + } + + batch_rule(self, exponent); +} +template +::std::vector _foreach_reciprocal_generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_reciprocal::call(self); + } + + auto results = batch_rule(self); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_reciprocal__generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_reciprocal_::call(self); + } + + batch_rule(self); +} +template +::std::vector _foreach_round_generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_round::call(self); + } + + auto results = batch_rule(self); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_round__generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = 
maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_round_::call(self); + } + + batch_rule(self); +} +template +::std::vector _foreach_rsqrt_generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_rsqrt::call(self); + } + + auto results = batch_rule(self); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_rsqrt__generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_rsqrt_::call(self); + } + + batch_rule(self); +} +template +::std::vector _foreach_sigmoid_generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_sigmoid::call(self); + } + + auto results = batch_rule(self); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_sigmoid__generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, 
"gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_sigmoid_::call(self); + } + + batch_rule(self); +} +template +::std::vector _foreach_sign_generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_sign::call(self); + } + + auto results = batch_rule(self); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_sign__generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_sign_::call(self); + } + + batch_rule(self); +} +template +::std::vector _foreach_sin_generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_sin::call(self); + } + + auto results = batch_rule(self); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_sin__generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if 
(!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_sin_::call(self); + } + + batch_rule(self); +} +template +::std::vector _foreach_sinh_generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_sinh::call(self); + } + + auto results = batch_rule(self); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_sinh__generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_sinh_::call(self); + } + + batch_rule(self); +} +template +::std::vector _foreach_sqrt_generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_sqrt::call(self); + } + + auto results = batch_rule(self); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_sqrt__generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_sqrt_::call(self); + } + 
+ batch_rule(self); +} +template +::std::vector _foreach_tan_generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_tan::call(self); + } + + auto results = batch_rule(self); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_tan__generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_tan_::call(self); + } + + batch_rule(self); +} +template +::std::vector _foreach_tanh_generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_tanh::call(self); + } + + auto results = batch_rule(self); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_tanh__generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_tanh_::call(self); + } + + batch_rule(self); +} +template +::std::vector 
_foreach_trunc_generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_trunc::call(self); + } + + auto results = batch_rule(self); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_trunc__generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_trunc_::call(self); + } + + batch_rule(self); +} +template +void _foreach_zero__generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_zero_::call(self); + } + + batch_rule(self); +} +template +void _foreach_copy__generated_plumbing(at::TensorList self, at::TensorList src, bool non_blocking) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(src, cur_level)) { + return at::_ops::_foreach_copy_::call(self, src, non_blocking); + } + + batch_rule(self, src, non_blocking); +} +template +::std::vector _foreach_copy_generated_plumbing(at::TensorList self, 
at::TensorList src, bool non_blocking) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(src, cur_level)) { + return at::_ops::_foreach_copy::call(self, src, non_blocking); + } + + auto results = batch_rule(self, src, non_blocking); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor bucketize_Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & boundaries, bool out_int32, bool right) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(boundaries, cur_level)) { + return at::_ops::bucketize_Tensor::call(self, boundaries, out_int32, right); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [boundaries_value, boundaries_bdim] = unwrapTensorAtLevel(boundaries, cur_level); + auto results = batch_rule(self_value, self_bdim, boundaries_value, boundaries_bdim, out_int32, right); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor bucketize_Scalar_generated_plumbing(const at::Scalar & self, const at::Tensor & boundaries, bool out_int32, bool right) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(boundaries, cur_level)) { + return at::_ops::bucketize_Scalar::call(self, boundaries, out_int32, right); + } + auto [boundaries_value, boundaries_bdim] = 
unwrapTensorAtLevel(boundaries, cur_level); + auto results = batch_rule(self, boundaries_value, boundaries_bdim, out_int32, right); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor searchsorted_Tensor_generated_plumbing(const at::Tensor & sorted_sequence, const at::Tensor & self, bool out_int32, bool right, ::std::optional side, const ::std::optional & sorter) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(sorted_sequence, cur_level) && !isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(sorter, cur_level)) { + return at::_ops::searchsorted_Tensor::call(sorted_sequence, self, out_int32, right, side, sorter); + } + auto [sorted_sequence_value, sorted_sequence_bdim] = unwrapTensorAtLevel(sorted_sequence, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + std::optional sorter_value; + std::optional sorter_bdim; + if (sorter) { + std::tie(sorter_value, sorter_bdim) = unwrapTensorAtLevel(sorter.value(), cur_level); + } + auto results = batch_rule(sorted_sequence_value, sorted_sequence_bdim, self_value, self_bdim, out_int32, right, side, sorter_value, sorter_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor searchsorted_Scalar_generated_plumbing(const at::Tensor & sorted_sequence, const at::Scalar & self, bool out_int32, bool right, ::std::optional side, const ::std::optional & sorter) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(sorted_sequence, cur_level) && !isBatchedAtLevel(sorter, cur_level)) { + return 
at::_ops::searchsorted_Scalar::call(sorted_sequence, self, out_int32, right, side, sorter); + } + auto [sorted_sequence_value, sorted_sequence_bdim] = unwrapTensorAtLevel(sorted_sequence, cur_level); + std::optional sorter_value; + std::optional sorter_bdim; + if (sorter) { + std::tie(sorter_value, sorter_bdim) = unwrapTensorAtLevel(sorter.value(), cur_level); + } + auto results = batch_rule(sorted_sequence_value, sorted_sequence_bdim, self, out_int32, right, side, sorter_value, sorter_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _convert_indices_from_coo_to_csr_generated_plumbing(const at::Tensor & self, int64_t size, bool out_int32) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_convert_indices_from_coo_to_csr::call(self, size, out_int32); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, size, out_int32); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _convert_indices_from_csr_to_coo_generated_plumbing(const at::Tensor & crow_indices, const at::Tensor & col_indices, bool out_int32, bool transpose) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(crow_indices, cur_level) && !isBatchedAtLevel(col_indices, cur_level)) { + return at::_ops::_convert_indices_from_csr_to_coo::call(crow_indices, col_indices, out_int32, transpose); + } + auto [crow_indices_value, crow_indices_bdim] = unwrapTensorAtLevel(crow_indices, cur_level); + auto 
[col_indices_value, col_indices_bdim] = unwrapTensorAtLevel(col_indices, cur_level); + auto results = batch_rule(crow_indices_value, crow_indices_bdim, col_indices_value, col_indices_bdim, out_int32, transpose); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor mse_loss_generated_plumbing(const at::Tensor & self, const at::Tensor & target, int64_t reduction) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(target, cur_level)) { + return at::_ops::mse_loss::call(self, target, reduction); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [target_value, target_bdim] = unwrapTensorAtLevel(target, cur_level); + auto results = batch_rule(self_value, self_bdim, target_value, target_bdim, reduction); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor mse_loss_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, int64_t reduction) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(target, cur_level)) { + return at::_ops::mse_loss_backward::call(grad_output, self, target, reduction); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [target_value, target_bdim] = unwrapTensorAtLevel(target, cur_level); + auto results = batch_rule(grad_output_value, 
grad_output_bdim, self_value, self_bdim, target_value, target_bdim, reduction); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor l1_loss_generated_plumbing(const at::Tensor & self, const at::Tensor & target, int64_t reduction) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(target, cur_level)) { + return at::_ops::l1_loss::call(self, target, reduction); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [target_value, target_bdim] = unwrapTensorAtLevel(target, cur_level); + auto results = batch_rule(self_value, self_bdim, target_value, target_bdim, reduction); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor multi_margin_loss_generated_plumbing(const at::Tensor & self, const at::Tensor & target, const at::Scalar & p, const at::Scalar & margin, const ::std::optional & weight, int64_t reduction) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(target, cur_level) && !isBatchedAtLevel(weight, cur_level)) { + return at::_ops::multi_margin_loss::call(self, target, p, margin, weight, reduction); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [target_value, target_bdim] = unwrapTensorAtLevel(target, cur_level); + std::optional weight_value; + std::optional weight_bdim; + if (weight) { + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, 
target_value, target_bdim, p, margin, weight_value, weight_bdim, reduction); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor multi_margin_loss_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const at::Scalar & p, const at::Scalar & margin, const ::std::optional & weight, int64_t reduction) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(target, cur_level) && !isBatchedAtLevel(weight, cur_level)) { + return at::_ops::multi_margin_loss_backward::call(grad_output, self, target, p, margin, weight, reduction); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [target_value, target_bdim] = unwrapTensorAtLevel(target, cur_level); + std::optional weight_value; + std::optional weight_bdim; + if (weight) { + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight.value(), cur_level); + } + auto results = batch_rule(grad_output_value, grad_output_bdim, self_value, self_bdim, target_value, target_bdim, p, margin, weight_value, weight_bdim, reduction); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor multilabel_margin_loss_generated_plumbing(const at::Tensor & self, const at::Tensor & target, int64_t reduction) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && 
!isBatchedAtLevel(target, cur_level)) { + return at::_ops::multilabel_margin_loss::call(self, target, reduction); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [target_value, target_bdim] = unwrapTensorAtLevel(target, cur_level); + auto results = batch_rule(self_value, self_bdim, target_value, target_bdim, reduction); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple multilabel_margin_loss_forward_generated_plumbing(const at::Tensor & self, const at::Tensor & target, int64_t reduction) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(target, cur_level)) { + return at::_ops::multilabel_margin_loss_forward::call(self, target, reduction); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [target_value, target_bdim] = unwrapTensorAtLevel(target, cur_level); + auto results = batch_rule(self_value, self_bdim, target_value, target_bdim, reduction); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor multilabel_margin_loss_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, int64_t reduction, const at::Tensor & is_target) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(target, cur_level) && !isBatchedAtLevel(is_target, cur_level)) { + return 
at::_ops::multilabel_margin_loss_backward::call(grad_output, self, target, reduction, is_target); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [target_value, target_bdim] = unwrapTensorAtLevel(target, cur_level); + auto [is_target_value, is_target_bdim] = unwrapTensorAtLevel(is_target, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, self_value, self_bdim, target_value, target_bdim, reduction, is_target_value, is_target_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor nll_loss_nd_generated_plumbing(const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, c10::SymInt ignore_index) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(target, cur_level) && !isBatchedAtLevel(weight, cur_level)) { + return at::_ops::nll_loss_nd::call(self, target, weight, reduction, ignore_index); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [target_value, target_bdim] = unwrapTensorAtLevel(target, cur_level); + std::optional weight_value; + std::optional weight_bdim; + if (weight) { + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, target_value, target_bdim, weight_value, weight_bdim, reduction, ignore_index); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor nll_loss_generated_plumbing(const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, c10::SymInt ignore_index) { + 
c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(target, cur_level) && !isBatchedAtLevel(weight, cur_level)) { + return at::_ops::nll_loss::call(self, target, weight, reduction, ignore_index); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [target_value, target_bdim] = unwrapTensorAtLevel(target, cur_level); + std::optional weight_value; + std::optional weight_bdim; + if (weight) { + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, target_value, target_bdim, weight_value, weight_bdim, reduction, ignore_index); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple nll_loss_forward_generated_plumbing(const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, c10::SymInt ignore_index) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(target, cur_level) && !isBatchedAtLevel(weight, cur_level)) { + return at::_ops::nll_loss_forward::call(self, target, weight, reduction, ignore_index); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [target_value, target_bdim] = unwrapTensorAtLevel(target, cur_level); + std::optional weight_value; + std::optional weight_bdim; + if (weight) { + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, target_value, target_bdim, 
weight_value, weight_bdim, reduction, ignore_index); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor nll_loss_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, c10::SymInt ignore_index, const at::Tensor & total_weight) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(target, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(total_weight, cur_level)) { + return at::_ops::nll_loss_backward::call(grad_output, self, target, weight, reduction, ignore_index, total_weight); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [target_value, target_bdim] = unwrapTensorAtLevel(target, cur_level); + auto [total_weight_value, total_weight_bdim] = unwrapTensorAtLevel(total_weight, cur_level); + std::optional weight_value; + std::optional weight_bdim; + if (weight) { + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight.value(), cur_level); + } + auto results = batch_rule(grad_output_value, grad_output_bdim, self_value, self_bdim, target_value, target_bdim, weight_value, weight_bdim, reduction, ignore_index, total_weight_value, total_weight_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor nll_loss2d_generated_plumbing(const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, c10::SymInt 
ignore_index) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(target, cur_level) && !isBatchedAtLevel(weight, cur_level)) { + return at::_ops::nll_loss2d::call(self, target, weight, reduction, ignore_index); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [target_value, target_bdim] = unwrapTensorAtLevel(target, cur_level); + std::optional weight_value; + std::optional weight_bdim; + if (weight) { + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, target_value, target_bdim, weight_value, weight_bdim, reduction, ignore_index); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple nll_loss2d_forward_generated_plumbing(const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, c10::SymInt ignore_index) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(target, cur_level) && !isBatchedAtLevel(weight, cur_level)) { + return at::_ops::nll_loss2d_forward::call(self, target, weight, reduction, ignore_index); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [target_value, target_bdim] = unwrapTensorAtLevel(target, cur_level); + std::optional weight_value; + std::optional weight_bdim; + if (weight) { + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, target_value, 
target_bdim, weight_value, weight_bdim, reduction, ignore_index); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor nll_loss2d_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, c10::SymInt ignore_index, const at::Tensor & total_weight) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(target, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(total_weight, cur_level)) { + return at::_ops::nll_loss2d_backward::call(grad_output, self, target, weight, reduction, ignore_index, total_weight); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [target_value, target_bdim] = unwrapTensorAtLevel(target, cur_level); + auto [total_weight_value, total_weight_bdim] = unwrapTensorAtLevel(total_weight, cur_level); + std::optional weight_value; + std::optional weight_bdim; + if (weight) { + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight.value(), cur_level); + } + auto results = batch_rule(grad_output_value, grad_output_bdim, self_value, self_bdim, target_value, target_bdim, weight_value, weight_bdim, reduction, ignore_index, total_weight_value, total_weight_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor smooth_l1_loss_generated_plumbing(const at::Tensor & self, const at::Tensor & target, int64_t reduction, double beta) { + 
c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(target, cur_level)) { + return at::_ops::smooth_l1_loss::call(self, target, reduction, beta); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [target_value, target_bdim] = unwrapTensorAtLevel(target, cur_level); + auto results = batch_rule(self_value, self_bdim, target_value, target_bdim, reduction, beta); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor smooth_l1_loss_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, int64_t reduction, double beta) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(target, cur_level)) { + return at::_ops::smooth_l1_loss_backward::call(grad_output, self, target, reduction, beta); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [target_value, target_bdim] = unwrapTensorAtLevel(target, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, self_value, self_bdim, target_value, target_bdim, reduction, beta); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor huber_loss_generated_plumbing(const at::Tensor & self, const at::Tensor & target, int64_t reduction, double delta) { + c10::impl::ExcludeDispatchKeyGuard 
guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(target, cur_level)) { + return at::_ops::huber_loss::call(self, target, reduction, delta); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [target_value, target_bdim] = unwrapTensorAtLevel(target, cur_level); + auto results = batch_rule(self_value, self_bdim, target_value, target_bdim, reduction, delta); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor huber_loss_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, int64_t reduction, double delta) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(target, cur_level)) { + return at::_ops::huber_loss_backward::call(grad_output, self, target, reduction, delta); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [target_value, target_bdim] = unwrapTensorAtLevel(target, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, self_value, self_bdim, target_value, target_bdim, reduction, delta); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor soft_margin_loss_generated_plumbing(const at::Tensor & self, const at::Tensor & target, int64_t reduction) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = 
maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(target, cur_level)) { + return at::_ops::soft_margin_loss::call(self, target, reduction); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [target_value, target_bdim] = unwrapTensorAtLevel(target, cur_level); + auto results = batch_rule(self_value, self_bdim, target_value, target_bdim, reduction); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor soft_margin_loss_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, int64_t reduction) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(target, cur_level)) { + return at::_ops::soft_margin_loss_backward::call(grad_output, self, target, reduction); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [target_value, target_bdim] = unwrapTensorAtLevel(target, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, self_value, self_bdim, target_value, target_bdim, reduction); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor elu_generated_plumbing(const at::Tensor & self, const at::Scalar & alpha, const at::Scalar & scale, const at::Scalar & input_scale) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, 
"gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::elu::call(self, alpha, scale, input_scale); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, alpha, scale, input_scale); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor elu_backward_generated_plumbing(const at::Tensor & grad_output, const at::Scalar & alpha, const at::Scalar & scale, const at::Scalar & input_scale, bool is_result, const at::Tensor & self_or_result) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(self_or_result, cur_level)) { + return at::_ops::elu_backward::call(grad_output, alpha, scale, input_scale, is_result, self_or_result); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_or_result_value, self_or_result_bdim] = unwrapTensorAtLevel(self_or_result, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, alpha, scale, input_scale, is_result, self_or_result_value, self_or_result_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & elu__generated_plumbing(at::Tensor & self, const at::Scalar & alpha, const at::Scalar & scale, const at::Scalar & input_scale) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::elu_::call(self, alpha, scale, input_scale); + } + auto 
[self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, alpha, scale, input_scale); + return self; +} +template +at::Tensor glu_generated_plumbing(const at::Tensor & self, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::glu::call(self, dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor glu_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & self, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(self, cur_level)) { + return at::_ops::glu_backward::call(grad_output, self, dim); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, self_value, self_bdim, dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor glu_jvp_generated_plumbing(const at::Tensor & glu, const at::Tensor & x, const at::Tensor & dx, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(glu, cur_level) 
&& !isBatchedAtLevel(x, cur_level) && !isBatchedAtLevel(dx, cur_level)) { + return at::_ops::glu_jvp::call(glu, x, dx, dim); + } + auto [glu_value, glu_bdim] = unwrapTensorAtLevel(glu, cur_level); + auto [x_value, x_bdim] = unwrapTensorAtLevel(x, cur_level); + auto [dx_value, dx_bdim] = unwrapTensorAtLevel(dx, cur_level); + auto results = batch_rule(glu_value, glu_bdim, x_value, x_bdim, dx_value, dx_bdim, dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor glu_backward_jvp_generated_plumbing(const at::Tensor & grad_x, const at::Tensor & grad_glu, const at::Tensor & x, const at::Tensor & dgrad_glu, const at::Tensor & dx, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_x, cur_level) && !isBatchedAtLevel(grad_glu, cur_level) && !isBatchedAtLevel(x, cur_level) && !isBatchedAtLevel(dgrad_glu, cur_level) && !isBatchedAtLevel(dx, cur_level)) { + return at::_ops::glu_backward_jvp::call(grad_x, grad_glu, x, dgrad_glu, dx, dim); + } + auto [grad_x_value, grad_x_bdim] = unwrapTensorAtLevel(grad_x, cur_level); + auto [grad_glu_value, grad_glu_bdim] = unwrapTensorAtLevel(grad_glu, cur_level); + auto [x_value, x_bdim] = unwrapTensorAtLevel(x, cur_level); + auto [dgrad_glu_value, dgrad_glu_bdim] = unwrapTensorAtLevel(dgrad_glu, cur_level); + auto [dx_value, dx_bdim] = unwrapTensorAtLevel(dx, cur_level); + auto results = batch_rule(grad_x_value, grad_x_bdim, grad_glu_value, grad_glu_bdim, x_value, x_bdim, dgrad_glu_value, dgrad_glu_bdim, dx_value, dx_bdim, dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor hardsigmoid_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto 
maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::hardsigmoid::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & hardsigmoid__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::hardsigmoid_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor hardsigmoid_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(self, cur_level)) { + return at::_ops::hardsigmoid_backward::call(grad_output, self); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor hardtanh_generated_plumbing(const at::Tensor & self, const at::Scalar & min_val, const at::Scalar & max_val) { + c10::impl::ExcludeDispatchKeyGuard 
guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::hardtanh::call(self, min_val, max_val); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, min_val, max_val); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor hardtanh_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & self, const at::Scalar & min_val, const at::Scalar & max_val) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(self, cur_level)) { + return at::_ops::hardtanh_backward::call(grad_output, self, min_val, max_val); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, self_value, self_bdim, min_val, max_val); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & hardtanh__generated_plumbing(at::Tensor & self, const at::Scalar & min_val, const at::Scalar & max_val) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::hardtanh_::call(self, min_val, max_val); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + 
batch_rule(self_value, self_bdim, min_val, max_val); + return self; +} +template +at::Tensor hardswish_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::hardswish::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & hardswish__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::hardswish_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor hardswish_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(self, cur_level)) { + return at::_ops::hardswish_backward::call(grad_output, self); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, self_value, self_bdim); + return makeBatched(std::get<0>(results), 
std::get<1>(results), cur_level); +} +template +at::Tensor leaky_relu_generated_plumbing(const at::Tensor & self, const at::Scalar & negative_slope) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::leaky_relu::call(self, negative_slope); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, negative_slope); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor leaky_relu_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & self, const at::Scalar & negative_slope, bool self_is_result) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(self, cur_level)) { + return at::_ops::leaky_relu_backward::call(grad_output, self, negative_slope, self_is_result); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, self_value, self_bdim, negative_slope, self_is_result); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & leaky_relu__generated_plumbing(at::Tensor & self, const at::Scalar & negative_slope) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = 
maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::leaky_relu_::call(self, negative_slope); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, negative_slope); + return self; +} +template +at::Tensor log_sigmoid_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::log_sigmoid::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple log_sigmoid_forward_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::log_sigmoid_forward::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor log_sigmoid_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & buffer) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, 
cur_level) && !isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(buffer, cur_level)) { + return at::_ops::log_sigmoid_backward::call(grad_output, self, buffer); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [buffer_value, buffer_bdim] = unwrapTensorAtLevel(buffer, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, self_value, self_bdim, buffer_value, buffer_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor rrelu_with_noise_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & noise, const at::Scalar & lower, const at::Scalar & upper, bool training, bool self_is_result) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(noise, cur_level)) { + return at::_ops::rrelu_with_noise_backward::call(grad_output, self, noise, lower, upper, training, self_is_result); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [noise_value, noise_bdim] = unwrapTensorAtLevel(noise, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, self_value, self_bdim, noise_value, noise_bdim, lower, upper, training, self_is_result); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor softplus_generated_plumbing(const at::Tensor & self, const at::Scalar & beta, const at::Scalar & threshold) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); 
+ auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::softplus::call(self, beta, threshold); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, beta, threshold); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor softplus_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & self, const at::Scalar & beta, const at::Scalar & threshold) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(self, cur_level)) { + return at::_ops::softplus_backward::call(grad_output, self, beta, threshold); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, self_value, self_bdim, beta, threshold); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor softshrink_generated_plumbing(const at::Tensor & self, const at::Scalar & lambd) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::softshrink::call(self, lambd); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, lambd); + return makeBatched(std::get<0>(results), 
std::get<1>(results), cur_level); +} +template +at::Tensor softshrink_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & self, const at::Scalar & lambd) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(self, cur_level)) { + return at::_ops::softshrink_backward::call(grad_output, self, lambd); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, self_value, self_bdim, lambd); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor adaptive_avg_pool2d_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef output_size) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::adaptive_avg_pool2d::call(self, output_size); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, output_size); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor mkldnn_adaptive_avg_pool2d_generated_plumbing(const at::Tensor & self, at::IntArrayRef output_size) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return 
at::_ops::mkldnn_adaptive_avg_pool2d::call(self, output_size); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, output_size); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor mkldnn_adaptive_avg_pool2d_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(self, cur_level)) { + return at::_ops::mkldnn_adaptive_avg_pool2d_backward::call(grad_output, self); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _adaptive_avg_pool2d_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef output_size) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_adaptive_avg_pool2d::call(self, output_size); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, output_size); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _adaptive_avg_pool2d_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard 
guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(self, cur_level)) { + return at::_ops::_adaptive_avg_pool2d_backward::call(grad_output, self); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor adaptive_avg_pool3d_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef output_size) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::adaptive_avg_pool3d::call(self, output_size); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, output_size); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _adaptive_avg_pool3d_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef output_size) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_adaptive_avg_pool3d::call(self, output_size); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, output_size); + return 
makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _adaptive_avg_pool3d_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(self, cur_level)) { + return at::_ops::_adaptive_avg_pool3d_backward::call(grad_output, self); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple adaptive_max_pool2d_generated_plumbing(const at::Tensor & self, at::IntArrayRef output_size) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::adaptive_max_pool2d::call(self, output_size); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, output_size); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor adaptive_max_pool2d_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + 
vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(indices, cur_level)) { + return at::_ops::adaptive_max_pool2d_backward::call(grad_output, self, indices); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [indices_value, indices_bdim] = unwrapTensorAtLevel(indices, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, self_value, self_bdim, indices_value, indices_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple adaptive_max_pool3d_generated_plumbing(const at::Tensor & self, at::IntArrayRef output_size) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::adaptive_max_pool3d::call(self, output_size); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, output_size); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor adaptive_max_pool3d_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(self, cur_level) && 
!isBatchedAtLevel(indices, cur_level)) { + return at::_ops::adaptive_max_pool3d_backward::call(grad_output, self, indices); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [indices_value, indices_bdim] = unwrapTensorAtLevel(indices, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, self_value, self_bdim, indices_value, indices_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor avg_pool2d_generated_plumbing(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::avg_pool2d::call(self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor avg_pool2d_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if 
(!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(self, cur_level)) { + return at::_ops::avg_pool2d_backward::call(grad_output, self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, self_value, self_bdim, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor avg_pool3d_generated_plumbing(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::avg_pool3d::call(self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor avg_pool3d_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); 
+ int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(self, cur_level)) { + return at::_ops::avg_pool3d_backward::call(grad_output, self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, self_value, self_bdim, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple fractional_max_pool2d_generated_plumbing(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef output_size, const at::Tensor & random_samples) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(random_samples, cur_level)) { + return at::_ops::fractional_max_pool2d::call(self, kernel_size, output_size, random_samples); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [random_samples_value, random_samples_bdim] = unwrapTensorAtLevel(random_samples, cur_level); + auto results = batch_rule(self_value, self_bdim, kernel_size, output_size, random_samples_value, random_samples_bdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor fractional_max_pool2d_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef output_size, const at::Tensor & indices) { + 
c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(indices, cur_level)) { + return at::_ops::fractional_max_pool2d_backward::call(grad_output, self, kernel_size, output_size, indices); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [indices_value, indices_bdim] = unwrapTensorAtLevel(indices, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, self_value, self_bdim, kernel_size, output_size, indices_value, indices_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple fractional_max_pool3d_generated_plumbing(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef output_size, const at::Tensor & random_samples) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(random_samples, cur_level)) { + return at::_ops::fractional_max_pool3d::call(self, kernel_size, output_size, random_samples); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [random_samples_value, random_samples_bdim] = unwrapTensorAtLevel(random_samples, cur_level); + auto results = batch_rule(self_value, self_bdim, kernel_size, output_size, random_samples_value, random_samples_bdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} 
+template +at::Tensor fractional_max_pool3d_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef output_size, const at::Tensor & indices) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(indices, cur_level)) { + return at::_ops::fractional_max_pool3d_backward::call(grad_output, self, kernel_size, output_size, indices); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [indices_value, indices_bdim] = unwrapTensorAtLevel(indices, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, self_value, self_bdim, kernel_size, output_size, indices_value, indices_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple max_pool2d_with_indices_generated_plumbing(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::max_pool2d_with_indices::call(self, kernel_size, stride, padding, dilation, ceil_mode); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, kernel_size, stride, padding, dilation, ceil_mode); + return std::make_tuple(makeBatched(std::get<0>(results), 
std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor max_pool2d_with_indices_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, const at::Tensor & indices) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(indices, cur_level)) { + return at::_ops::max_pool2d_with_indices_backward::call(grad_output, self, kernel_size, stride, padding, dilation, ceil_mode, indices); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [indices_value, indices_bdim] = unwrapTensorAtLevel(indices, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, self_value, self_bdim, kernel_size, stride, padding, dilation, ceil_mode, indices_value, indices_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple max_pool3d_with_indices_generated_plumbing(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::max_pool3d_with_indices::call(self, kernel_size, stride, padding, dilation, ceil_mode); + } + auto [self_value, 
self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, kernel_size, stride, padding, dilation, ceil_mode); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor max_pool3d_with_indices_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, const at::Tensor & indices) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(indices, cur_level)) { + return at::_ops::max_pool3d_with_indices_backward::call(grad_output, self, kernel_size, stride, padding, dilation, ceil_mode, indices); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [indices_value, indices_bdim] = unwrapTensorAtLevel(indices, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, self_value, self_bdim, kernel_size, stride, padding, dilation, ceil_mode, indices_value, indices_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor max_unpool2d_generated_plumbing(const at::Tensor & self, const at::Tensor & indices, c10::SymIntArrayRef output_size) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) 
&& !isBatchedAtLevel(indices, cur_level)) { + return at::_ops::max_unpool2d::call(self, indices, output_size); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [indices_value, indices_bdim] = unwrapTensorAtLevel(indices, cur_level); + auto results = batch_rule(self_value, self_bdim, indices_value, indices_bdim, output_size); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor max_unpool3d_generated_plumbing(const at::Tensor & self, const at::Tensor & indices, c10::SymIntArrayRef output_size, at::IntArrayRef stride, at::IntArrayRef padding) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(indices, cur_level)) { + return at::_ops::max_unpool3d::call(self, indices, output_size, stride, padding); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [indices_value, indices_bdim] = unwrapTensorAtLevel(indices, cur_level); + auto results = batch_rule(self_value, self_bdim, indices_value, indices_bdim, output_size, stride, padding); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor reflection_pad1d_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef padding) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::reflection_pad1d::call(self, padding); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, padding); + return makeBatched(std::get<0>(results), 
std::get<1>(results), cur_level); +} +template +at::Tensor reflection_pad1d_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(self, cur_level)) { + return at::_ops::reflection_pad1d_backward::call(grad_output, self, padding); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, self_value, self_bdim, padding); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor reflection_pad2d_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef padding) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::reflection_pad2d::call(self, padding); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, padding); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor reflection_pad2d_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if 
(!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(self, cur_level)) { + return at::_ops::reflection_pad2d_backward::call(grad_output, self, padding); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, self_value, self_bdim, padding); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor reflection_pad3d_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef padding) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::reflection_pad3d::call(self, padding); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, padding); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor reflection_pad3d_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(self, cur_level)) { + return at::_ops::reflection_pad3d_backward::call(grad_output, self, padding); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, self_value, self_bdim, padding); + return 
makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor replication_pad1d_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef padding) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::replication_pad1d::call(self, padding); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, padding); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor replication_pad1d_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(self, cur_level)) { + return at::_ops::replication_pad1d_backward::call(grad_output, self, padding); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, self_value, self_bdim, padding); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor replication_pad2d_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef padding) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, 
cur_level)) { + return at::_ops::replication_pad2d::call(self, padding); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, padding); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor replication_pad2d_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(self, cur_level)) { + return at::_ops::replication_pad2d_backward::call(grad_output, self, padding); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, self_value, self_bdim, padding); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor replication_pad3d_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef padding) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::replication_pad3d::call(self, padding); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, padding); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor replication_pad3d_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding) 
{ + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(self, cur_level)) { + return at::_ops::replication_pad3d_backward::call(grad_output, self, padding); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, self_value, self_bdim, padding); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _pad_circular_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef pad) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_pad_circular::call(self, pad); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, pad); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _pad_enum_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef pad, int64_t mode, ::std::optional value) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_pad_enum::call(self, pad, mode, value); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, pad, mode, value); + 
return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor pad_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef pad, c10::string_view mode, ::std::optional value) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::pad::call(self, pad, mode, value); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, pad, mode, value); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor upsample_linear1d_vec_generated_plumbing(const at::Tensor & input, at::OptionalSymIntArrayRef output_size, bool align_corners, ::std::optional> scale_factors) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level)) { + return at::_ops::upsample_linear1d_vec::call(input, output_size, align_corners, scale_factors); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto results = batch_rule(input_value, input_bdim, output_size, align_corners, scale_factors); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor upsample_bilinear2d_vec_generated_plumbing(const at::Tensor & input, at::OptionalSymIntArrayRef output_size, bool align_corners, ::std::optional> scale_factors) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if 
(!isBatchedAtLevel(input, cur_level)) { + return at::_ops::upsample_bilinear2d_vec::call(input, output_size, align_corners, scale_factors); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto results = batch_rule(input_value, input_bdim, output_size, align_corners, scale_factors); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _upsample_bilinear2d_aa_vec_generated_plumbing(const at::Tensor & input, at::OptionalSymIntArrayRef output_size, bool align_corners, ::std::optional> scale_factors) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level)) { + return at::_ops::_upsample_bilinear2d_aa_vec::call(input, output_size, align_corners, scale_factors); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto results = batch_rule(input_value, input_bdim, output_size, align_corners, scale_factors); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor upsample_trilinear3d_vec_generated_plumbing(const at::Tensor & input, at::OptionalSymIntArrayRef output_size, bool align_corners, ::std::optional> scale_factors) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level)) { + return at::_ops::upsample_trilinear3d_vec::call(input, output_size, align_corners, scale_factors); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto results = batch_rule(input_value, input_bdim, output_size, align_corners, scale_factors); + return makeBatched(std::get<0>(results), 
std::get<1>(results), cur_level); +} +template +at::Tensor upsample_bicubic2d_vec_generated_plumbing(const at::Tensor & input, at::OptionalSymIntArrayRef output_size, bool align_corners, ::std::optional> scale_factors) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level)) { + return at::_ops::upsample_bicubic2d_vec::call(input, output_size, align_corners, scale_factors); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto results = batch_rule(input_value, input_bdim, output_size, align_corners, scale_factors); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _upsample_bicubic2d_aa_vec_generated_plumbing(const at::Tensor & input, at::OptionalSymIntArrayRef output_size, bool align_corners, ::std::optional> scale_factors) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level)) { + return at::_ops::_upsample_bicubic2d_aa_vec::call(input, output_size, align_corners, scale_factors); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto results = batch_rule(input_value, input_bdim, output_size, align_corners, scale_factors); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor upsample_nearest1d_vec_generated_plumbing(const at::Tensor & input, at::OptionalSymIntArrayRef output_size, ::std::optional> scale_factors) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, 
"gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level)) { + return at::_ops::upsample_nearest1d_vec::call(input, output_size, scale_factors); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto results = batch_rule(input_value, input_bdim, output_size, scale_factors); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _upsample_nearest_exact1d_vec_generated_plumbing(const at::Tensor & input, at::OptionalSymIntArrayRef output_size, ::std::optional> scale_factors) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level)) { + return at::_ops::_upsample_nearest_exact1d_vec::call(input, output_size, scale_factors); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto results = batch_rule(input_value, input_bdim, output_size, scale_factors); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor upsample_nearest2d_vec_generated_plumbing(const at::Tensor & input, at::OptionalSymIntArrayRef output_size, ::std::optional> scale_factors) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level)) { + return at::_ops::upsample_nearest2d_vec::call(input, output_size, scale_factors); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto results = batch_rule(input_value, input_bdim, output_size, scale_factors); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor 
_upsample_nearest_exact2d_vec_generated_plumbing(const at::Tensor & input, at::OptionalSymIntArrayRef output_size, ::std::optional> scale_factors) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level)) { + return at::_ops::_upsample_nearest_exact2d_vec::call(input, output_size, scale_factors); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto results = batch_rule(input_value, input_bdim, output_size, scale_factors); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor upsample_nearest3d_vec_generated_plumbing(const at::Tensor & input, at::OptionalSymIntArrayRef output_size, ::std::optional> scale_factors) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level)) { + return at::_ops::upsample_nearest3d_vec::call(input, output_size, scale_factors); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto results = batch_rule(input_value, input_bdim, output_size, scale_factors); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _upsample_nearest_exact3d_vec_generated_plumbing(const at::Tensor & input, at::OptionalSymIntArrayRef output_size, ::std::optional> scale_factors) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level)) { + return 
at::_ops::_upsample_nearest_exact3d_vec::call(input, output_size, scale_factors); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto results = batch_rule(input_value, input_bdim, output_size, scale_factors); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor upsample_linear1d_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, ::std::optional scales) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::upsample_linear1d::call(self, output_size, align_corners, scales); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, output_size, align_corners, scales); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor upsample_linear1d_backward_generated_plumbing(const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, ::std::optional scales) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level)) { + return at::_ops::upsample_linear1d_backward::call(grad_output, output_size, input_size, align_corners, scales); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, output_size, input_size, align_corners, scales); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor 
upsample_bilinear2d_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, ::std::optional scales_h, ::std::optional scales_w) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::upsample_bilinear2d::call(self, output_size, align_corners, scales_h, scales_w); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, output_size, align_corners, scales_h, scales_w); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor upsample_bilinear2d_backward_generated_plumbing(const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, ::std::optional scales_h, ::std::optional scales_w) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level)) { + return at::_ops::upsample_bilinear2d_backward::call(grad_output, output_size, input_size, align_corners, scales_h, scales_w); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, output_size, input_size, align_corners, scales_h, scales_w); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _upsample_bilinear2d_aa_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, ::std::optional scales_h, ::std::optional scales_w) { + c10::impl::ExcludeDispatchKeyGuard 
guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_upsample_bilinear2d_aa::call(self, output_size, align_corners, scales_h, scales_w); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, output_size, align_corners, scales_h, scales_w); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _upsample_bilinear2d_aa_backward_generated_plumbing(const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, ::std::optional scales_h, ::std::optional scales_w) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level)) { + return at::_ops::_upsample_bilinear2d_aa_backward::call(grad_output, output_size, input_size, align_corners, scales_h, scales_w); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, output_size, input_size, align_corners, scales_h, scales_w); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor upsample_bicubic2d_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, ::std::optional scales_h, ::std::optional scales_w) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, 
cur_level)) { + return at::_ops::upsample_bicubic2d::call(self, output_size, align_corners, scales_h, scales_w); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, output_size, align_corners, scales_h, scales_w); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor upsample_bicubic2d_backward_generated_plumbing(const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, ::std::optional scales_h, ::std::optional scales_w) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level)) { + return at::_ops::upsample_bicubic2d_backward::call(grad_output, output_size, input_size, align_corners, scales_h, scales_w); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, output_size, input_size, align_corners, scales_h, scales_w); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _upsample_bicubic2d_aa_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, ::std::optional scales_h, ::std::optional scales_w) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_upsample_bicubic2d_aa::call(self, output_size, align_corners, scales_h, scales_w); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, 
self_bdim, output_size, align_corners, scales_h, scales_w); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _upsample_bicubic2d_aa_backward_generated_plumbing(const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, ::std::optional scales_h, ::std::optional scales_w) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level)) { + return at::_ops::_upsample_bicubic2d_aa_backward::call(grad_output, output_size, input_size, align_corners, scales_h, scales_w); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, output_size, input_size, align_corners, scales_h, scales_w); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor upsample_trilinear3d_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, ::std::optional scales_d, ::std::optional scales_h, ::std::optional scales_w) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::upsample_trilinear3d::call(self, output_size, align_corners, scales_d, scales_h, scales_w); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, output_size, align_corners, scales_d, scales_h, scales_w); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor 
upsample_trilinear3d_backward_generated_plumbing(const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, ::std::optional scales_d, ::std::optional scales_h, ::std::optional scales_w) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level)) { + return at::_ops::upsample_trilinear3d_backward::call(grad_output, output_size, input_size, align_corners, scales_d, scales_h, scales_w); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, output_size, input_size, align_corners, scales_d, scales_h, scales_w); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor upsample_nearest1d_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef output_size, ::std::optional scales) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::upsample_nearest1d::call(self, output_size, scales); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, output_size, scales); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _upsample_nearest_exact1d_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef output_size, ::std::optional scales) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, 
"gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_upsample_nearest_exact1d::call(self, output_size, scales); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, output_size, scales); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor upsample_nearest1d_backward_generated_plumbing(const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, ::std::optional scales) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level)) { + return at::_ops::upsample_nearest1d_backward::call(grad_output, output_size, input_size, scales); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, output_size, input_size, scales); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _upsample_nearest_exact1d_backward_generated_plumbing(const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, ::std::optional scales) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level)) { + return at::_ops::_upsample_nearest_exact1d_backward::call(grad_output, output_size, input_size, scales); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto results = batch_rule(grad_output_value, 
grad_output_bdim, output_size, input_size, scales); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor upsample_nearest2d_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef output_size, ::std::optional scales_h, ::std::optional scales_w) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::upsample_nearest2d::call(self, output_size, scales_h, scales_w); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, output_size, scales_h, scales_w); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _upsample_nearest_exact2d_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef output_size, ::std::optional scales_h, ::std::optional scales_w) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_upsample_nearest_exact2d::call(self, output_size, scales_h, scales_w); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, output_size, scales_h, scales_w); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor upsample_nearest2d_backward_generated_plumbing(const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, ::std::optional scales_h, ::std::optional scales_w) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = 
maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level)) { + return at::_ops::upsample_nearest2d_backward::call(grad_output, output_size, input_size, scales_h, scales_w); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, output_size, input_size, scales_h, scales_w); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _upsample_nearest_exact2d_backward_generated_plumbing(const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, ::std::optional scales_h, ::std::optional scales_w) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level)) { + return at::_ops::_upsample_nearest_exact2d_backward::call(grad_output, output_size, input_size, scales_h, scales_w); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, output_size, input_size, scales_h, scales_w); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor upsample_nearest3d_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef output_size, ::std::optional scales_d, ::std::optional scales_h, ::std::optional scales_w) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return 
at::_ops::upsample_nearest3d::call(self, output_size, scales_d, scales_h, scales_w); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, output_size, scales_d, scales_h, scales_w); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _upsample_nearest_exact3d_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef output_size, ::std::optional scales_d, ::std::optional scales_h, ::std::optional scales_w) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_upsample_nearest_exact3d::call(self, output_size, scales_d, scales_h, scales_w); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, output_size, scales_d, scales_h, scales_w); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor upsample_nearest3d_backward_generated_plumbing(const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, ::std::optional scales_d, ::std::optional scales_h, ::std::optional scales_w) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level)) { + return at::_ops::upsample_nearest3d_backward::call(grad_output, output_size, input_size, scales_d, scales_h, scales_w); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, output_size, input_size, 
scales_d, scales_h, scales_w); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _upsample_nearest_exact3d_backward_generated_plumbing(const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, ::std::optional scales_d, ::std::optional scales_h, ::std::optional scales_w) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level)) { + return at::_ops::_upsample_nearest_exact3d_backward::call(grad_output, output_size, input_size, scales_d, scales_h, scales_w); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, output_size, input_size, scales_d, scales_h, scales_w); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor sigmoid_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & output) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(output, cur_level)) { + return at::_ops::sigmoid_backward::call(grad_output, output); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [output_value, output_bdim] = unwrapTensorAtLevel(output, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, output_value, output_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor logit_backward_generated_plumbing(const at::Tensor & 
grad_output, const at::Tensor & self, ::std::optional eps) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(self, cur_level)) { + return at::_ops::logit_backward::call(grad_output, self, eps); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, self_value, self_bdim, eps); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor tanh_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & output) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(output, cur_level)) { + return at::_ops::tanh_backward::call(grad_output, output); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [output_value, output_bdim] = unwrapTensorAtLevel(output, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, output_value, output_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor slow_conv_transpose2d_generated_plumbing(const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef dilation) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); 
+ auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::slow_conv_transpose2d::call(self, weight, kernel_size, bias, stride, padding, output_padding, dilation); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, weight_value, weight_bdim, kernel_size, bias_value, bias_bdim, stride, padding, output_padding, dilation); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor slow_conv_transpose3d_generated_plumbing(const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef dilation) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::slow_conv_transpose3d::call(self, weight, kernel_size, bias, stride, padding, output_padding, dilation); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = 
unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, weight_value, weight_bdim, kernel_size, bias_value, bias_bdim, stride, padding, output_padding, dilation); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor thnn_conv2d_generated_plumbing(const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::thnn_conv2d::call(self, weight, kernel_size, bias, stride, padding); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, weight_value, weight_bdim, kernel_size, bias_value, bias_bdim, stride, padding); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _slow_conv2d_forward_generated_plumbing(const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && 
!isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::_slow_conv2d_forward::call(self, weight, kernel_size, bias, stride, padding); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, weight_value, weight_bdim, kernel_size, bias_value, bias_bdim, stride, padding); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple _slow_conv2d_backward_output_mask_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, ::std::array output_mask) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(weight, cur_level)) { + return at::_ops::_slow_conv2d_backward_output_mask::call(grad_output, self, weight, kernel_size, stride, padding, output_mask); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, self_value, self_bdim, weight_value, weight_bdim, kernel_size, stride, padding, output_mask); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), 
std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +at::Tensor _conv_depthwise2d_generated_plumbing(const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::_conv_depthwise2d::call(self, weight, kernel_size, bias, stride, padding, dilation); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, weight_value, weight_bdim, kernel_size, bias_value, bias_bdim, stride, padding, dilation); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor conv_depthwise3d_generated_plumbing(const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return 
at::_ops::conv_depthwise3d::call(self, weight, kernel_size, bias, stride, padding, dilation); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, weight_value, weight_bdim, kernel_size, bias_value, bias_bdim, stride, padding, dilation); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor slow_conv3d_generated_plumbing(const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::slow_conv3d::call(self, weight, kernel_size, bias, stride, padding); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, weight_value, weight_bdim, kernel_size, bias_value, bias_bdim, stride, padding); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor slow_conv3d_forward_generated_plumbing(const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, 
c10::SymIntArrayRef stride, c10::SymIntArrayRef padding) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::slow_conv3d_forward::call(self, weight, kernel_size, bias, stride, padding); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, weight_value, weight_bdim, kernel_size, bias_value, bias_bdim, stride, padding); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor slow_conv_dilated2d_generated_plumbing(const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::slow_conv_dilated2d::call(self, weight, kernel_size, bias, stride, padding, dilation); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = 
unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, weight_value, weight_bdim, kernel_size, bias_value, bias_bdim, stride, padding, dilation); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor slow_conv_dilated3d_generated_plumbing(const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::slow_conv_dilated3d::call(self, weight, kernel_size, bias, stride, padding, dilation); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, weight_value, weight_bdim, kernel_size, bias_value, bias_bdim, stride, padding, dilation); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor col2im_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef output_size, at::IntArrayRef kernel_size, at::IntArrayRef dilation, at::IntArrayRef padding, at::IntArrayRef stride) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) 
{ + return at::_ops::col2im::call(self, output_size, kernel_size, dilation, padding, stride); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, output_size, kernel_size, dilation, padding, stride); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor column_stack_generated_plumbing(at::TensorList tensors) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(tensors, cur_level)) { + return at::_ops::column_stack::call(tensors); + } + + auto results = batch_rule(tensors); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor im2col_generated_plumbing(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef dilation, at::IntArrayRef padding, at::IntArrayRef stride) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::im2col::call(self, kernel_size, dilation, padding, stride); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, kernel_size, dilation, padding, stride); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor isfinite_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) 
{ + return at::_ops::isfinite::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor isinf_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::isinf::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void record_stream_generated_plumbing(at::Tensor & self, at::Stream s) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::record_stream::call(self, s); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, s); +} +template +at::Tensor isposinf_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::isposinf::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor 
isneginf_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::isneginf::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _add_batch_dim_generated_plumbing(const at::Tensor & self, int64_t batch_dim, int64_t level) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_add_batch_dim::call(self, batch_dim, level); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, batch_dim, level); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _remove_batch_dim_generated_plumbing(const at::Tensor & self, int64_t level, c10::SymInt batch_size, int64_t out_dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_remove_batch_dim::call(self, level, batch_size, out_dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, level, batch_size, out_dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor 
special_entr_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::special_entr::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_ndtri_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::special_ndtri::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_log_ndtr_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::special_log_ndtr::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_expm1_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = 
maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::special_expm1::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_exp2_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::special_exp2::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_psi_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::special_psi::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_digamma_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return 
at::_ops::special_digamma::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_gammaln_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::special_gammaln::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_erf_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::special_erf::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_erfc_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::special_erfc::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return 
makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_erfcx_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::special_erfcx::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_erfinv_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::special_erfinv::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_ndtr_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::special_ndtr::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_xlog1py_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + 
c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::special_xlog1py::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_xlog1py_self_scalar_generated_plumbing(const at::Scalar & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(other, cur_level)) { + return at::_ops::special_xlog1py_self_scalar::call(self, other); + } + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_xlog1py_other_scalar_generated_plumbing(const at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::special_xlog1py_other_scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, other); + return makeBatched(std::get<0>(results), 
std::get<1>(results), cur_level); +} +template +at::Tensor special_xlogy_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::special_xlogy::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_xlogy_self_scalar_generated_plumbing(const at::Scalar & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(other, cur_level)) { + return at::_ops::special_xlogy_self_scalar::call(self, other); + } + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_xlogy_other_scalar_generated_plumbing(const at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::special_xlogy_other_scalar::call(self, other); + } + auto [self_value, self_bdim] = 
unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, other); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_zeta_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::special_zeta::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_zeta_self_scalar_generated_plumbing(const at::Scalar & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(other, cur_level)) { + return at::_ops::special_zeta_self_scalar::call(self, other); + } + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_zeta_other_scalar_generated_plumbing(const at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, 
cur_level)) { + return at::_ops::special_zeta_other_scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, other); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_i0_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::special_i0::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_i0e_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::special_i0e::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_i1_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::special_i1::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, 
self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_i1e_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::special_i1e::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_logit_generated_plumbing(const at::Tensor & self, ::std::optional eps) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::special_logit::call(self, eps); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, eps); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_polygamma_generated_plumbing(int64_t n, const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::special_polygamma::call(n, self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(n, self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor 
special_logsumexp_generated_plumbing(const at::Tensor & self, at::IntArrayRef dim, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::special_logsumexp::call(self, dim, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, keepdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_expit_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::special_expit::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_sinc_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::special_sinc::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_round_generated_plumbing(const at::Tensor & self, int64_t decimals) { + c10::impl::ExcludeDispatchKeyGuard 
guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::special_round::call(self, decimals); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, decimals); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_log1p_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::special_log1p::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_log_softmax_generated_plumbing(const at::Tensor & self, int64_t dim, ::std::optional dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::special_log_softmax::call(self, dim, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_gammainc_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = 
maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::special_gammainc::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_gammaincc_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::special_gammaincc::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_multigammaln_generated_plumbing(const at::Tensor & self, int64_t p) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::special_multigammaln::call(self, p); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, p); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} 
+template +at::Tensor special_softmax_generated_plumbing(const at::Tensor & self, int64_t dim, ::std::optional dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::special_softmax::call(self, dim, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor fft_fft_generated_plumbing(const at::Tensor & self, ::std::optional n, int64_t dim, ::std::optional norm) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::fft_fft::call(self, n, dim, norm); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, n, dim, norm); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor fft_ifft_generated_plumbing(const at::Tensor & self, ::std::optional n, int64_t dim, ::std::optional norm) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::fft_ifft::call(self, n, dim, norm); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, n, dim, norm); + return makeBatched(std::get<0>(results), 
std::get<1>(results), cur_level); +} +template +at::Tensor fft_rfft_generated_plumbing(const at::Tensor & self, ::std::optional n, int64_t dim, ::std::optional norm) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::fft_rfft::call(self, n, dim, norm); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, n, dim, norm); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor fft_irfft_generated_plumbing(const at::Tensor & self, ::std::optional n, int64_t dim, ::std::optional norm) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::fft_irfft::call(self, n, dim, norm); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, n, dim, norm); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor fft_hfft_generated_plumbing(const at::Tensor & self, ::std::optional n, int64_t dim, ::std::optional norm) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::fft_hfft::call(self, n, dim, norm); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, n, dim, norm); + return 
makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor fft_ihfft_generated_plumbing(const at::Tensor & self, ::std::optional n, int64_t dim, ::std::optional norm) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::fft_ihfft::call(self, n, dim, norm); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, n, dim, norm); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor fft_fft2_generated_plumbing(const at::Tensor & self, at::OptionalSymIntArrayRef s, at::IntArrayRef dim, ::std::optional norm) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::fft_fft2::call(self, s, dim, norm); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, s, dim, norm); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor fft_ifft2_generated_plumbing(const at::Tensor & self, at::OptionalSymIntArrayRef s, at::IntArrayRef dim, ::std::optional norm) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::fft_ifft2::call(self, s, dim, norm); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + 
auto results = batch_rule(self_value, self_bdim, s, dim, norm); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor fft_rfft2_generated_plumbing(const at::Tensor & self, at::OptionalSymIntArrayRef s, at::IntArrayRef dim, ::std::optional norm) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::fft_rfft2::call(self, s, dim, norm); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, s, dim, norm); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor fft_irfft2_generated_plumbing(const at::Tensor & self, at::OptionalSymIntArrayRef s, at::IntArrayRef dim, ::std::optional norm) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::fft_irfft2::call(self, s, dim, norm); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, s, dim, norm); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor fft_hfft2_generated_plumbing(const at::Tensor & self, at::OptionalSymIntArrayRef s, at::IntArrayRef dim, ::std::optional norm) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return 
at::_ops::fft_hfft2::call(self, s, dim, norm); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, s, dim, norm); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor fft_ihfft2_generated_plumbing(const at::Tensor & self, at::OptionalSymIntArrayRef s, at::IntArrayRef dim, ::std::optional norm) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::fft_ihfft2::call(self, s, dim, norm); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, s, dim, norm); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor fft_fftn_generated_plumbing(const at::Tensor & self, at::OptionalSymIntArrayRef s, at::OptionalIntArrayRef dim, ::std::optional norm) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::fft_fftn::call(self, s, dim, norm); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, s, dim, norm); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor fft_ifftn_generated_plumbing(const at::Tensor & self, at::OptionalSymIntArrayRef s, at::OptionalIntArrayRef dim, ::std::optional norm) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, 
"gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::fft_ifftn::call(self, s, dim, norm); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, s, dim, norm); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor fft_rfftn_generated_plumbing(const at::Tensor & self, at::OptionalSymIntArrayRef s, at::OptionalIntArrayRef dim, ::std::optional norm) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::fft_rfftn::call(self, s, dim, norm); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, s, dim, norm); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor fft_irfftn_generated_plumbing(const at::Tensor & self, at::OptionalSymIntArrayRef s, at::OptionalIntArrayRef dim, ::std::optional norm) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::fft_irfftn::call(self, s, dim, norm); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, s, dim, norm); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor fft_hfftn_generated_plumbing(const at::Tensor & self, at::OptionalSymIntArrayRef s, at::OptionalIntArrayRef dim, ::std::optional norm) { + c10::impl::ExcludeDispatchKeyGuard 
guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::fft_hfftn::call(self, s, dim, norm); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, s, dim, norm); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor fft_ihfftn_generated_plumbing(const at::Tensor & self, at::OptionalSymIntArrayRef s, at::OptionalIntArrayRef dim, ::std::optional norm) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::fft_ihfftn::call(self, s, dim, norm); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, s, dim, norm); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor fft_fftshift_generated_plumbing(const at::Tensor & self, at::OptionalIntArrayRef dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::fft_fftshift::call(self, dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor fft_ifftshift_generated_plumbing(const at::Tensor & self, at::OptionalIntArrayRef dim) { + 
c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::fft_ifftshift::call(self, dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple linalg_cholesky_ex_generated_plumbing(const at::Tensor & self, bool upper, bool check_errors) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::linalg_cholesky_ex::call(self, upper, check_errors); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, upper, check_errors); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor linalg_cholesky_generated_plumbing(const at::Tensor & self, bool upper) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::linalg_cholesky::call(self, upper); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, upper); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor 
linalg_cross_generated_plumbing(const at::Tensor & self, const at::Tensor & other, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::linalg_cross::call(self, other, dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim, dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple linalg_lu_factor_generated_plumbing(const at::Tensor & A, bool pivot) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(A, cur_level)) { + return at::_ops::linalg_lu_factor::call(A, pivot); + } + auto [A_value, A_bdim] = unwrapTensorAtLevel(A, cur_level); + auto results = batch_rule(A_value, A_bdim, pivot); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple linalg_lu_factor_ex_generated_plumbing(const at::Tensor & A, bool pivot, bool check_errors) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(A, cur_level)) { + return at::_ops::linalg_lu_factor_ex::call(A, pivot, check_errors); + } + auto [A_value, A_bdim] = unwrapTensorAtLevel(A, cur_level); + 
auto results = batch_rule(A_value, A_bdim, pivot, check_errors); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +::std::tuple linalg_lu_generated_plumbing(const at::Tensor & A, bool pivot) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(A, cur_level)) { + return at::_ops::linalg_lu::call(A, pivot); + } + auto [A_value, A_bdim] = unwrapTensorAtLevel(A, cur_level); + auto results = batch_rule(A_value, A_bdim, pivot); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +at::Tensor linalg_lu_solve_generated_plumbing(const at::Tensor & LU, const at::Tensor & pivots, const at::Tensor & B, bool left, bool adjoint) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(LU, cur_level) && !isBatchedAtLevel(pivots, cur_level) && !isBatchedAtLevel(B, cur_level)) { + return at::_ops::linalg_lu_solve::call(LU, pivots, B, left, adjoint); + } + auto [LU_value, LU_bdim] = unwrapTensorAtLevel(LU, cur_level); + auto [pivots_value, pivots_bdim] = unwrapTensorAtLevel(pivots, cur_level); + auto [B_value, B_bdim] = unwrapTensorAtLevel(B, cur_level); + auto results = batch_rule(LU_value, LU_bdim, pivots_value, pivots_bdim, B_value, B_bdim, left, adjoint); + return makeBatched(std::get<0>(results), 
std::get<1>(results), cur_level); +} +template +::std::tuple _linalg_det_generated_plumbing(const at::Tensor & A) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(A, cur_level)) { + return at::_ops::_linalg_det::call(A); + } + auto [A_value, A_bdim] = unwrapTensorAtLevel(A, cur_level); + auto results = batch_rule(A_value, A_bdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +at::Tensor linalg_det_generated_plumbing(const at::Tensor & A) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(A, cur_level)) { + return at::_ops::linalg_det::call(A); + } + auto [A_value, A_bdim] = unwrapTensorAtLevel(A, cur_level); + auto results = batch_rule(A_value, A_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor det_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::det::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple linalg_ldl_factor_ex_generated_plumbing(const 
at::Tensor & self, bool hermitian, bool check_errors) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::linalg_ldl_factor_ex::call(self, hermitian, check_errors); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, hermitian, check_errors); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +::std::tuple linalg_ldl_factor_generated_plumbing(const at::Tensor & self, bool hermitian) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::linalg_ldl_factor::call(self, hermitian); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, hermitian); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor linalg_ldl_solve_generated_plumbing(const at::Tensor & LD, const at::Tensor & pivots, const at::Tensor & B, bool hermitian) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(LD, cur_level) && !isBatchedAtLevel(pivots, cur_level) && 
!isBatchedAtLevel(B, cur_level)) { + return at::_ops::linalg_ldl_solve::call(LD, pivots, B, hermitian); + } + auto [LD_value, LD_bdim] = unwrapTensorAtLevel(LD, cur_level); + auto [pivots_value, pivots_bdim] = unwrapTensorAtLevel(pivots, cur_level); + auto [B_value, B_bdim] = unwrapTensorAtLevel(B, cur_level); + auto results = batch_rule(LD_value, LD_bdim, pivots_value, pivots_bdim, B_value, B_bdim, hermitian); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple linalg_lstsq_generated_plumbing(const at::Tensor & self, const at::Tensor & b, ::std::optional rcond, ::std::optional driver) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(b, cur_level)) { + return at::_ops::linalg_lstsq::call(self, b, rcond, driver); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [b_value, b_bdim] = unwrapTensorAtLevel(b, cur_level); + auto results = batch_rule(self_value, self_bdim, b_value, b_bdim, rcond, driver); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level), makeBatched(std::get<6>(results), std::get<7>(results), cur_level)); +} +template +at::Tensor linalg_matmul_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::linalg_matmul::call(self, 
other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor linalg_vecdot_generated_plumbing(const at::Tensor & x, const at::Tensor & y, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(x, cur_level) && !isBatchedAtLevel(y, cur_level)) { + return at::_ops::linalg_vecdot::call(x, y, dim); + } + auto [x_value, x_bdim] = unwrapTensorAtLevel(x, cur_level); + auto [y_value, y_bdim] = unwrapTensorAtLevel(y, cur_level); + auto results = batch_rule(x_value, x_bdim, y_value, y_bdim, dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor linalg_matrix_exp_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::linalg_matrix_exp::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple _linalg_slogdet_generated_plumbing(const at::Tensor & A) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if 
(!isBatchedAtLevel(A, cur_level)) { + return at::_ops::_linalg_slogdet::call(A); + } + auto [A_value, A_bdim] = unwrapTensorAtLevel(A, cur_level); + auto results = batch_rule(A_value, A_bdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level), makeBatched(std::get<6>(results), std::get<7>(results), cur_level)); +} +template +::std::tuple linalg_slogdet_generated_plumbing(const at::Tensor & A) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(A, cur_level)) { + return at::_ops::linalg_slogdet::call(A); + } + auto [A_value, A_bdim] = unwrapTensorAtLevel(A, cur_level); + auto results = batch_rule(A_value, A_bdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple slogdet_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::slogdet::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor logdet_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + 
auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::logdet::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple linalg_eig_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::linalg_eig::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor _linalg_eigvals_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_linalg_eigvals::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor linalg_eigvals_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + 
int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::linalg_eigvals::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple _linalg_eigh_generated_plumbing(const at::Tensor & A, c10::string_view UPLO, bool compute_v) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(A, cur_level)) { + return at::_ops::_linalg_eigh::call(A, UPLO, compute_v); + } + auto [A_value, A_bdim] = unwrapTensorAtLevel(A, cur_level); + auto results = batch_rule(A_value, A_bdim, UPLO, compute_v); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple linalg_eigh_generated_plumbing(const at::Tensor & self, c10::string_view UPLO) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::linalg_eigh::call(self, UPLO); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, UPLO); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor linalg_eigvalsh_generated_plumbing(const at::Tensor & self, c10::string_view UPLO) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto 
maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::linalg_eigvalsh::call(self, UPLO); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, UPLO); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor linalg_householder_product_generated_plumbing(const at::Tensor & input, const at::Tensor & tau) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(tau, cur_level)) { + return at::_ops::linalg_householder_product::call(input, tau); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [tau_value, tau_bdim] = unwrapTensorAtLevel(tau, cur_level); + auto results = batch_rule(input_value, input_bdim, tau_value, tau_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple linalg_inv_ex_generated_plumbing(const at::Tensor & A, bool check_errors) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(A, cur_level)) { + return at::_ops::linalg_inv_ex::call(A, check_errors); + } + auto [A_value, A_bdim] = unwrapTensorAtLevel(A, cur_level); + auto results = batch_rule(A_value, A_bdim, check_errors); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor 
linalg_inv_generated_plumbing(const at::Tensor & A) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(A, cur_level)) { + return at::_ops::linalg_inv::call(A); + } + auto [A_value, A_bdim] = unwrapTensorAtLevel(A, cur_level); + auto results = batch_rule(A_value, A_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor inverse_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::inverse::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor inner_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::inner::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor outer_generated_plumbing(const at::Tensor & self, const at::Tensor & 
vec2) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(vec2, cur_level)) { + return at::_ops::outer::call(self, vec2); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [vec2_value, vec2_bdim] = unwrapTensorAtLevel(vec2, cur_level); + auto results = batch_rule(self_value, self_bdim, vec2_value, vec2_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor ger_generated_plumbing(const at::Tensor & self, const at::Tensor & vec2) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(vec2, cur_level)) { + return at::_ops::ger::call(self, vec2); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [vec2_value, vec2_bdim] = unwrapTensorAtLevel(vec2, cur_level); + auto results = batch_rule(self_value, self_bdim, vec2_value, vec2_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor linalg_norm_generated_plumbing(const at::Tensor & self, const ::std::optional & ord, at::OptionalIntArrayRef dim, bool keepdim, ::std::optional dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::linalg_norm::call(self, ord, dim, keepdim, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); 
+ auto results = batch_rule(self_value, self_bdim, ord, dim, keepdim, dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor linalg_norm_ord_str_generated_plumbing(const at::Tensor & self, c10::string_view ord, at::OptionalIntArrayRef dim, bool keepdim, ::std::optional dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::linalg_norm_ord_str::call(self, ord, dim, keepdim, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, ord, dim, keepdim, dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor linalg_vector_norm_generated_plumbing(const at::Tensor & self, const at::Scalar & ord, at::OptionalIntArrayRef dim, bool keepdim, ::std::optional dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::linalg_vector_norm::call(self, ord, dim, keepdim, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, ord, dim, keepdim, dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor linalg_matrix_norm_generated_plumbing(const at::Tensor & self, const at::Scalar & ord, at::IntArrayRef dim, bool keepdim, ::std::optional dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, 
"gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::linalg_matrix_norm::call(self, ord, dim, keepdim, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, ord, dim, keepdim, dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor linalg_matrix_norm_str_ord_generated_plumbing(const at::Tensor & self, c10::string_view ord, at::IntArrayRef dim, bool keepdim, ::std::optional dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::linalg_matrix_norm_str_ord::call(self, ord, dim, keepdim, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, ord, dim, keepdim, dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple _linalg_svd_generated_plumbing(const at::Tensor & A, bool full_matrices, bool compute_uv, ::std::optional driver) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(A, cur_level)) { + return at::_ops::_linalg_svd::call(A, full_matrices, compute_uv, driver); + } + auto [A_value, A_bdim] = unwrapTensorAtLevel(A, cur_level); + auto results = batch_rule(A_value, A_bdim, full_matrices, compute_uv, driver); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), 
makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +::std::tuple linalg_svd_generated_plumbing(const at::Tensor & A, bool full_matrices, ::std::optional driver) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(A, cur_level)) { + return at::_ops::linalg_svd::call(A, full_matrices, driver); + } + auto [A_value, A_bdim] = unwrapTensorAtLevel(A, cur_level); + auto results = batch_rule(A_value, A_bdim, full_matrices, driver); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +at::Tensor linalg_svdvals_generated_plumbing(const at::Tensor & A, ::std::optional driver) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(A, cur_level)) { + return at::_ops::linalg_svdvals::call(A, driver); + } + auto [A_value, A_bdim] = unwrapTensorAtLevel(A, cur_level); + auto results = batch_rule(A_value, A_bdim, driver); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor linalg_cond_generated_plumbing(const at::Tensor & self, const ::std::optional & p) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::linalg_cond::call(self, p); + } + auto [self_value, self_bdim] = 
unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, p); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor linalg_cond_p_str_generated_plumbing(const at::Tensor & self, c10::string_view p) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::linalg_cond_p_str::call(self, p); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, p); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor linalg_pinv_atol_rtol_tensor_generated_plumbing(const at::Tensor & self, const ::std::optional & atol, const ::std::optional & rtol, bool hermitian) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(atol, cur_level) && !isBatchedAtLevel(rtol, cur_level)) { + return at::_ops::linalg_pinv_atol_rtol_tensor::call(self, atol, rtol, hermitian); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + std::optional atol_value; + std::optional atol_bdim; + if (atol) { + std::tie(atol_value, atol_bdim) = unwrapTensorAtLevel(atol.value(), cur_level); + } + std::optional rtol_value; + std::optional rtol_bdim; + if (rtol) { + std::tie(rtol_value, rtol_bdim) = unwrapTensorAtLevel(rtol.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, atol_value, atol_bdim, rtol_value, rtol_bdim, hermitian); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template 
+at::Tensor linalg_pinv_atol_rtol_float_generated_plumbing(const at::Tensor & self, ::std::optional atol, ::std::optional rtol, bool hermitian) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::linalg_pinv_atol_rtol_float::call(self, atol, rtol, hermitian); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, atol, rtol, hermitian); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor linalg_pinv_generated_plumbing(const at::Tensor & self, double rcond, bool hermitian) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::linalg_pinv::call(self, rcond, hermitian); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, rcond, hermitian); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor linalg_pinv_rcond_tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & rcond, bool hermitian) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(rcond, cur_level)) { + return at::_ops::linalg_pinv_rcond_tensor::call(self, rcond, hermitian); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto 
[rcond_value, rcond_bdim] = unwrapTensorAtLevel(rcond, cur_level); + auto results = batch_rule(self_value, self_bdim, rcond_value, rcond_bdim, hermitian); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple _linalg_solve_ex_generated_plumbing(const at::Tensor & A, const at::Tensor & B, bool left, bool check_errors) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(A, cur_level) && !isBatchedAtLevel(B, cur_level)) { + return at::_ops::_linalg_solve_ex::call(A, B, left, check_errors); + } + auto [A_value, A_bdim] = unwrapTensorAtLevel(A, cur_level); + auto [B_value, B_bdim] = unwrapTensorAtLevel(B, cur_level); + auto results = batch_rule(A_value, A_bdim, B_value, B_bdim, left, check_errors); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level), makeBatched(std::get<6>(results), std::get<7>(results), cur_level)); +} +template +::std::tuple linalg_solve_ex_generated_plumbing(const at::Tensor & A, const at::Tensor & B, bool left, bool check_errors) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(A, cur_level) && !isBatchedAtLevel(B, cur_level)) { + return at::_ops::linalg_solve_ex::call(A, B, left, check_errors); + } + auto [A_value, A_bdim] = unwrapTensorAtLevel(A, cur_level); + auto [B_value, B_bdim] = unwrapTensorAtLevel(B, cur_level); + auto results = batch_rule(A_value, A_bdim, B_value, B_bdim, left, check_errors); + return 
std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor linalg_solve_generated_plumbing(const at::Tensor & A, const at::Tensor & B, bool left) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(A, cur_level) && !isBatchedAtLevel(B, cur_level)) { + return at::_ops::linalg_solve::call(A, B, left); + } + auto [A_value, A_bdim] = unwrapTensorAtLevel(A, cur_level); + auto [B_value, B_bdim] = unwrapTensorAtLevel(B, cur_level); + auto results = batch_rule(A_value, A_bdim, B_value, B_bdim, left); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _spsolve_generated_plumbing(const at::Tensor & A, const at::Tensor & B, bool left) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(A, cur_level) && !isBatchedAtLevel(B, cur_level)) { + return at::_ops::_spsolve::call(A, B, left); + } + auto [A_value, A_bdim] = unwrapTensorAtLevel(A, cur_level); + auto [B_value, B_bdim] = unwrapTensorAtLevel(B, cur_level); + auto results = batch_rule(A_value, A_bdim, B_value, B_bdim, left); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor linalg_tensorinv_generated_plumbing(const at::Tensor & self, int64_t ind) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + 
return at::_ops::linalg_tensorinv::call(self, ind); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, ind); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor linalg_tensorsolve_generated_plumbing(const at::Tensor & self, const at::Tensor & other, at::OptionalIntArrayRef dims) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::linalg_tensorsolve::call(self, other, dims); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim, dims); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple linalg_qr_generated_plumbing(const at::Tensor & A, c10::string_view mode) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(A, cur_level)) { + return at::_ops::linalg_qr::call(A, mode); + } + auto [A_value, A_bdim] = unwrapTensorAtLevel(A, cur_level); + auto results = batch_rule(A_value, A_bdim, mode); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor linalg_matrix_power_generated_plumbing(const at::Tensor & self, int64_t n) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = 
maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::linalg_matrix_power::call(self, n); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, n); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor linalg_matrix_rank_atol_rtol_tensor_generated_plumbing(const at::Tensor & input, const ::std::optional & atol, const ::std::optional & rtol, bool hermitian) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(atol, cur_level) && !isBatchedAtLevel(rtol, cur_level)) { + return at::_ops::linalg_matrix_rank_atol_rtol_tensor::call(input, atol, rtol, hermitian); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + std::optional atol_value; + std::optional atol_bdim; + if (atol) { + std::tie(atol_value, atol_bdim) = unwrapTensorAtLevel(atol.value(), cur_level); + } + std::optional rtol_value; + std::optional rtol_bdim; + if (rtol) { + std::tie(rtol_value, rtol_bdim) = unwrapTensorAtLevel(rtol.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, atol_value, atol_bdim, rtol_value, rtol_bdim, hermitian); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor linalg_matrix_rank_atol_rtol_float_generated_plumbing(const at::Tensor & self, ::std::optional atol, ::std::optional rtol, bool hermitian) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t 
cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::linalg_matrix_rank_atol_rtol_float::call(self, atol, rtol, hermitian); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, atol, rtol, hermitian); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor linalg_matrix_rank_generated_plumbing(const at::Tensor & self, double tol, bool hermitian) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::linalg_matrix_rank::call(self, tol, hermitian); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, tol, hermitian); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor linalg_matrix_rank_tol_tensor_generated_plumbing(const at::Tensor & input, const at::Tensor & tol, bool hermitian) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(tol, cur_level)) { + return at::_ops::linalg_matrix_rank_tol_tensor::call(input, tol, hermitian); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [tol_value, tol_bdim] = unwrapTensorAtLevel(tol, cur_level); + auto results = batch_rule(input_value, input_bdim, tol_value, tol_bdim, hermitian); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor linalg_multi_dot_generated_plumbing(at::TensorList tensors) { + 
c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(tensors, cur_level)) { + return at::_ops::linalg_multi_dot::call(tensors); + } + + auto results = batch_rule(tensors); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor nested_to_padded_tensor_generated_plumbing(const at::Tensor & self, double padding, at::OptionalIntArrayRef output_size) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::nested_to_padded_tensor::call(self, padding, output_size); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, padding, output_size); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _test_serialization_subcmul_generated_plumbing(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::_test_serialization_subcmul::call(self, other, alpha); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim, alpha); + return makeBatched(std::get<0>(results), std::get<1>(results), 
cur_level); +} +template +at::Tensor _test_parallel_materialize_generated_plumbing(const at::Tensor & self, int64_t num_parallel, bool skip_first) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_test_parallel_materialize::call(self, num_parallel, skip_first); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, num_parallel, skip_first); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _test_optional_intlist_generated_plumbing(const at::Tensor & values, at::OptionalIntArrayRef addends) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(values, cur_level)) { + return at::_ops::_test_optional_intlist::call(values, addends); + } + auto [values_value, values_bdim] = unwrapTensorAtLevel(values, cur_level); + auto results = batch_rule(values_value, values_bdim, addends); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _test_optional_filled_intlist_generated_plumbing(const at::Tensor & values, at::OptionalIntArrayRef addends) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(values, cur_level)) { + return at::_ops::_test_optional_filled_intlist::call(values, addends); + } + auto [values_value, values_bdim] = unwrapTensorAtLevel(values, cur_level); + auto 
results = batch_rule(values_value, values_bdim, addends); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _test_optional_floatlist_generated_plumbing(const at::Tensor & values, ::std::optional> addends) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(values, cur_level)) { + return at::_ops::_test_optional_floatlist::call(values, addends); + } + auto [values_value, values_bdim] = unwrapTensorAtLevel(values, cur_level); + auto results = batch_rule(values_value, values_bdim, addends); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _test_string_default_generated_plumbing(const at::Tensor & dummy, c10::string_view a, c10::string_view b) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(dummy, cur_level)) { + return at::_ops::_test_string_default::call(dummy, a, b); + } + auto [dummy_value, dummy_bdim] = unwrapTensorAtLevel(dummy, cur_level); + auto results = batch_rule(dummy_value, dummy_bdim, a, b); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _test_ambiguous_defaults_a_generated_plumbing(const at::Tensor & dummy, int64_t a, int64_t b) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(dummy, cur_level)) { + return at::_ops::_test_ambiguous_defaults_a::call(dummy, a, b); + } + auto [dummy_value, dummy_bdim] = 
unwrapTensorAtLevel(dummy, cur_level); + auto results = batch_rule(dummy_value, dummy_bdim, a, b); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _test_ambiguous_defaults_b_generated_plumbing(const at::Tensor & dummy, int64_t a, c10::string_view b) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(dummy, cur_level)) { + return at::_ops::_test_ambiguous_defaults_b::call(dummy, a, b); + } + auto [dummy_value, dummy_bdim] = unwrapTensorAtLevel(dummy, cur_level); + auto results = batch_rule(dummy_value, dummy_bdim, a, b); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _test_warn_in_autograd_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_test_warn_in_autograd::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _test_autograd_multiple_dispatch_fullcoverage_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_test_autograd_multiple_dispatch_fullcoverage::call(self); + } + auto [self_value, self_bdim] = 
unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _test_autograd_multiple_dispatch_ntonly_generated_plumbing(const at::Tensor & self, bool b) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_test_autograd_multiple_dispatch_ntonly::call(self, b); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, b); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _test_autograd_multiple_dispatch_view_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_test_autograd_multiple_dispatch_view::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _test_autograd_multiple_dispatch_view_copy_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_test_autograd_multiple_dispatch_view_copy::call(self); + } + auto [self_value, self_bdim] = 
unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor segment_reduce_generated_plumbing(const at::Tensor & data, c10::string_view reduce, const ::std::optional & lengths, const ::std::optional & indices, const ::std::optional & offsets, int64_t axis, bool unsafe, const ::std::optional & initial) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(data, cur_level) && !isBatchedAtLevel(lengths, cur_level) && !isBatchedAtLevel(indices, cur_level) && !isBatchedAtLevel(offsets, cur_level)) { + return at::_ops::segment_reduce::call(data, reduce, lengths, indices, offsets, axis, unsafe, initial); + } + auto [data_value, data_bdim] = unwrapTensorAtLevel(data, cur_level); + std::optional lengths_value; + std::optional lengths_bdim; + if (lengths) { + std::tie(lengths_value, lengths_bdim) = unwrapTensorAtLevel(lengths.value(), cur_level); + } + std::optional indices_value; + std::optional indices_bdim; + if (indices) { + std::tie(indices_value, indices_bdim) = unwrapTensorAtLevel(indices.value(), cur_level); + } + std::optional offsets_value; + std::optional offsets_bdim; + if (offsets) { + std::tie(offsets_value, offsets_bdim) = unwrapTensorAtLevel(offsets.value(), cur_level); + } + auto results = batch_rule(data_value, data_bdim, reduce, lengths_value, lengths_bdim, indices_value, indices_bdim, offsets_value, offsets_bdim, axis, unsafe, initial); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _segment_reduce_backward_generated_plumbing(const at::Tensor & grad, const at::Tensor & output, const at::Tensor & data, c10::string_view reduce, const ::std::optional & lengths, const 
::std::optional & offsets, int64_t axis, const ::std::optional & initial) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad, cur_level) && !isBatchedAtLevel(output, cur_level) && !isBatchedAtLevel(data, cur_level) && !isBatchedAtLevel(lengths, cur_level) && !isBatchedAtLevel(offsets, cur_level)) { + return at::_ops::_segment_reduce_backward::call(grad, output, data, reduce, lengths, offsets, axis, initial); + } + auto [grad_value, grad_bdim] = unwrapTensorAtLevel(grad, cur_level); + auto [output_value, output_bdim] = unwrapTensorAtLevel(output, cur_level); + auto [data_value, data_bdim] = unwrapTensorAtLevel(data, cur_level); + std::optional lengths_value; + std::optional lengths_bdim; + if (lengths) { + std::tie(lengths_value, lengths_bdim) = unwrapTensorAtLevel(lengths.value(), cur_level); + } + std::optional offsets_value; + std::optional offsets_bdim; + if (offsets) { + std::tie(offsets_value, offsets_bdim) = unwrapTensorAtLevel(offsets.value(), cur_level); + } + auto results = batch_rule(grad_value, grad_bdim, output_value, output_bdim, data_value, data_bdim, reduce, lengths_value, lengths_bdim, offsets_value, offsets_bdim, axis, initial); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor pad_sequence_generated_plumbing(at::TensorList sequences, bool batch_first, double padding_value, c10::string_view padding_side) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(sequences, cur_level)) { + return at::_ops::pad_sequence::call(sequences, batch_first, padding_value, padding_side); + } + + auto results = 
batch_rule(sequences, batch_first, padding_value, padding_side); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor flatten_dense_tensors_generated_plumbing(at::TensorList tensors) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(tensors, cur_level)) { + return at::_ops::flatten_dense_tensors::call(tensors); + } + + auto results = batch_rule(tensors); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector unflatten_dense_tensors_generated_plumbing(const at::Tensor & flat, at::TensorList tensors) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(flat, cur_level) && !isBatchedAtLevel(tensors, cur_level)) { + return at::_ops::unflatten_dense_tensors::call(flat, tensors); + } + auto [flat_value, flat_bdim] = unwrapTensorAtLevel(flat, cur_level); + auto results = batch_rule(flat_value, flat_bdim, tensors); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _nested_tensor_from_tensor_list_generated_plumbing(at::TensorList list, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(list, cur_level)) { + return at::_ops::_nested_tensor_from_tensor_list::call(list, dtype, layout, device, pin_memory); + } + + auto results = 
batch_rule(list, dtype, layout, device, pin_memory); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _fw_primal_copy_generated_plumbing(const at::Tensor & self, int64_t level) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_fw_primal_copy::call(self, level); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, level); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _make_dual_copy_generated_plumbing(const at::Tensor & primal, const at::Tensor & tangent, int64_t level) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(primal, cur_level) && !isBatchedAtLevel(tangent, cur_level)) { + return at::_ops::_make_dual_copy::call(primal, tangent, level); + } + auto [primal_value, primal_bdim] = unwrapTensorAtLevel(primal, cur_level); + auto [tangent_value, tangent_bdim] = unwrapTensorAtLevel(tangent, cur_level); + auto results = batch_rule(primal_value, primal_bdim, tangent_value, tangent_bdim, level); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor view_as_real_copy_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return 
at::_ops::view_as_real_copy::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor view_as_complex_copy_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::view_as_complex_copy::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _conj_copy_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_conj_copy::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _neg_view_copy_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_neg_view_copy::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + 
return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor as_strided_copy_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::as_strided_copy::call(self, size, stride, storage_offset); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, size, stride, storage_offset); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _sparse_broadcast_to_copy_generated_plumbing(const at::Tensor & self, at::IntArrayRef size) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_sparse_broadcast_to_copy::call(self, size); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, size); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor diagonal_copy_generated_plumbing(const at::Tensor & self, int64_t offset, int64_t dim1, int64_t dim2) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::diagonal_copy::call(self, offset, dim1, dim2); + } + auto [self_value, self_bdim] = 
unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, offset, dim1, dim2); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor expand_copy_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef size, bool implicit) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::expand_copy::call(self, size, implicit); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, size, implicit); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor permute_copy_generated_plumbing(const at::Tensor & self, at::IntArrayRef dims) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::permute_copy::call(self, dims); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dims); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _reshape_alias_copy_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_reshape_alias_copy::call(self, size, stride); + } + auto 
[self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, size, stride); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor select_copy_int_generated_plumbing(const at::Tensor & self, int64_t dim, c10::SymInt index) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::select_copy_int::call(self, dim, index); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, index); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor detach_copy_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::detach_copy::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor slice_copy_Tensor_generated_plumbing(const at::Tensor & self, int64_t dim, ::std::optional start, ::std::optional end, c10::SymInt step) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::slice_copy_Tensor::call(self, dim, start, end, step); + } + auto 
[self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, start, end, step); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector split_copy_Tensor_generated_plumbing(const at::Tensor & self, c10::SymInt split_size, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::split_copy_Tensor::call(self, split_size, dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, split_size, dim); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector split_with_sizes_copy_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef split_sizes, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::split_with_sizes_copy::call(self, split_sizes, dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, split_sizes, dim); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor squeeze_copy_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return 
at::_ops::squeeze_copy::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor squeeze_copy_dim_generated_plumbing(const at::Tensor & self, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::squeeze_copy_dim::call(self, dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor squeeze_copy_dims_generated_plumbing(const at::Tensor & self, at::IntArrayRef dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::squeeze_copy_dims::call(self, dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor t_copy_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::t_copy::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = 
batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor transpose_copy_int_generated_plumbing(const at::Tensor & self, int64_t dim0, int64_t dim1) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::transpose_copy_int::call(self, dim0, dim1); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim0, dim1); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor unsqueeze_copy_generated_plumbing(const at::Tensor & self, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::unsqueeze_copy::call(self, dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _indices_copy_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_indices_copy::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), 
std::get<1>(results), cur_level); +} +template +at::Tensor _values_copy_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_values_copy::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor indices_copy_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::indices_copy::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor values_copy_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::values_copy::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor crow_indices_copy_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard 
guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::crow_indices_copy::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor col_indices_copy_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::col_indices_copy::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor ccol_indices_copy_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::ccol_indices_copy::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor row_indices_copy_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level 
= maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::row_indices_copy::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector unbind_copy_int_generated_plumbing(const at::Tensor & self, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::unbind_copy_int::call(self, dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void unbind_copy_int_out_generated_plumbing(const at::Tensor & self, int64_t dim, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::unbind_copy_int_out::call(self, dim, out); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, dim, out); +} +template +void split_copy_Tensor_out_generated_plumbing(const at::Tensor & self, c10::SymInt split_size, int64_t dim, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, 
cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::split_copy_Tensor_out::call(self, split_size, dim, out); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, split_size, dim, out); +} +template +void split_with_sizes_copy_out_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef split_sizes, int64_t dim, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::split_with_sizes_copy_out::call(self, split_sizes, dim, out); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, split_sizes, dim, out); +} +template +at::Tensor view_copy_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef size) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::view_copy::call(self, size); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, size); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor view_copy_dtype_generated_plumbing(const at::Tensor & self, at::ScalarType dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return 
at::_ops::view_copy_dtype::call(self, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor unfold_copy_generated_plumbing(const at::Tensor & self, int64_t dimension, int64_t size, int64_t step) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::unfold_copy::call(self, dimension, size, step); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dimension, size, step); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor alias_copy_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::alias_copy::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor to_padded_tensor_generated_plumbing(const at::Tensor & self, double padding, at::OptionalSymIntArrayRef output_size) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return 
at::_ops::to_padded_tensor::call(self, padding, output_size); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, padding, output_size); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _jagged_to_padded_dense_forward_generated_plumbing(const at::Tensor & values, at::TensorList offsets, c10::SymIntArrayRef max_lengths, double padding_value) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(values, cur_level) && !isBatchedAtLevel(offsets, cur_level)) { + return at::_ops::_jagged_to_padded_dense_forward::call(values, offsets, max_lengths, padding_value); + } + auto [values_value, values_bdim] = unwrapTensorAtLevel(values, cur_level); + auto results = batch_rule(values_value, values_bdim, offsets, max_lengths, padding_value); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _padded_dense_to_jagged_forward_generated_plumbing(const at::Tensor & dense, at::TensorList offsets, ::std::optional total_L) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(dense, cur_level) && !isBatchedAtLevel(offsets, cur_level)) { + return at::_ops::_padded_dense_to_jagged_forward::call(dense, offsets, total_L); + } + auto [dense_value, dense_bdim] = unwrapTensorAtLevel(dense, cur_level); + auto results = batch_rule(dense_value, dense_bdim, offsets, total_L); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _nested_from_padded_tensor_generated_plumbing(const 
at::Tensor & padded, const at::Tensor & offsets, const at::Tensor & dummy, int64_t ragged_idx, const ::std::optional & min_seqlen, const ::std::optional & max_seqlen, ::std::optional sum_S) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(padded, cur_level) && !isBatchedAtLevel(offsets, cur_level) && !isBatchedAtLevel(dummy, cur_level) && !isBatchedAtLevel(min_seqlen, cur_level) && !isBatchedAtLevel(max_seqlen, cur_level)) { + return at::_ops::_nested_from_padded_tensor::call(padded, offsets, dummy, ragged_idx, min_seqlen, max_seqlen, sum_S); + } + auto [padded_value, padded_bdim] = unwrapTensorAtLevel(padded, cur_level); + auto [offsets_value, offsets_bdim] = unwrapTensorAtLevel(offsets, cur_level); + auto [dummy_value, dummy_bdim] = unwrapTensorAtLevel(dummy, cur_level); + std::optional min_seqlen_value; + std::optional min_seqlen_bdim; + if (min_seqlen) { + std::tie(min_seqlen_value, min_seqlen_bdim) = unwrapTensorAtLevel(min_seqlen.value(), cur_level); + } + std::optional max_seqlen_value; + std::optional max_seqlen_bdim; + if (max_seqlen) { + std::tie(max_seqlen_value, max_seqlen_bdim) = unwrapTensorAtLevel(max_seqlen.value(), cur_level); + } + auto results = batch_rule(padded_value, padded_bdim, offsets_value, offsets_bdim, dummy_value, dummy_bdim, ragged_idx, min_seqlen_value, min_seqlen_bdim, max_seqlen_value, max_seqlen_bdim, sum_S); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _nested_tensor_softmax_with_shape_generated_plumbing(const at::Tensor & self, const at::Tensor & query) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = 
maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(query, cur_level)) { + return at::_ops::_nested_tensor_softmax_with_shape::call(self, query); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [query_value, query_bdim] = unwrapTensorAtLevel(query, cur_level); + auto results = batch_rule(self_value, self_bdim, query_value, query_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _safe_softmax_generated_plumbing(const at::Tensor & self, int64_t dim, ::std::optional dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_safe_softmax::call(self, dim, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _transformer_encoder_layer_fwd_generated_plumbing(const at::Tensor & src, int64_t embed_dim, int64_t num_heads, const at::Tensor & qkv_weight, const at::Tensor & qkv_bias, const at::Tensor & proj_weight, const at::Tensor & proj_bias, bool use_gelu, bool norm_first, double eps, const at::Tensor & norm_weight_1, const at::Tensor & norm_bias_1, const at::Tensor & norm_weight_2, const at::Tensor & norm_bias_2, const at::Tensor & ffn_weight_1, const at::Tensor & ffn_bias_1, const at::Tensor & ffn_weight_2, const at::Tensor & ffn_bias_2, const ::std::optional & mask, ::std::optional mask_type) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + 
if (!isBatchedAtLevel(src, cur_level) && !isBatchedAtLevel(qkv_weight, cur_level) && !isBatchedAtLevel(qkv_bias, cur_level) && !isBatchedAtLevel(proj_weight, cur_level) && !isBatchedAtLevel(proj_bias, cur_level) && !isBatchedAtLevel(norm_weight_1, cur_level) && !isBatchedAtLevel(norm_bias_1, cur_level) && !isBatchedAtLevel(norm_weight_2, cur_level) && !isBatchedAtLevel(norm_bias_2, cur_level) && !isBatchedAtLevel(ffn_weight_1, cur_level) && !isBatchedAtLevel(ffn_bias_1, cur_level) && !isBatchedAtLevel(ffn_weight_2, cur_level) && !isBatchedAtLevel(ffn_bias_2, cur_level) && !isBatchedAtLevel(mask, cur_level)) { + return at::_ops::_transformer_encoder_layer_fwd::call(src, embed_dim, num_heads, qkv_weight, qkv_bias, proj_weight, proj_bias, use_gelu, norm_first, eps, norm_weight_1, norm_bias_1, norm_weight_2, norm_bias_2, ffn_weight_1, ffn_bias_1, ffn_weight_2, ffn_bias_2, mask, mask_type); + } + auto [src_value, src_bdim] = unwrapTensorAtLevel(src, cur_level); + auto [qkv_weight_value, qkv_weight_bdim] = unwrapTensorAtLevel(qkv_weight, cur_level); + auto [qkv_bias_value, qkv_bias_bdim] = unwrapTensorAtLevel(qkv_bias, cur_level); + auto [proj_weight_value, proj_weight_bdim] = unwrapTensorAtLevel(proj_weight, cur_level); + auto [proj_bias_value, proj_bias_bdim] = unwrapTensorAtLevel(proj_bias, cur_level); + auto [norm_weight_1_value, norm_weight_1_bdim] = unwrapTensorAtLevel(norm_weight_1, cur_level); + auto [norm_bias_1_value, norm_bias_1_bdim] = unwrapTensorAtLevel(norm_bias_1, cur_level); + auto [norm_weight_2_value, norm_weight_2_bdim] = unwrapTensorAtLevel(norm_weight_2, cur_level); + auto [norm_bias_2_value, norm_bias_2_bdim] = unwrapTensorAtLevel(norm_bias_2, cur_level); + auto [ffn_weight_1_value, ffn_weight_1_bdim] = unwrapTensorAtLevel(ffn_weight_1, cur_level); + auto [ffn_bias_1_value, ffn_bias_1_bdim] = unwrapTensorAtLevel(ffn_bias_1, cur_level); + auto [ffn_weight_2_value, ffn_weight_2_bdim] = unwrapTensorAtLevel(ffn_weight_2, cur_level); + auto 
[ffn_bias_2_value, ffn_bias_2_bdim] = unwrapTensorAtLevel(ffn_bias_2, cur_level); + std::optional mask_value; + std::optional mask_bdim; + if (mask) { + std::tie(mask_value, mask_bdim) = unwrapTensorAtLevel(mask.value(), cur_level); + } + auto results = batch_rule(src_value, src_bdim, embed_dim, num_heads, qkv_weight_value, qkv_weight_bdim, qkv_bias_value, qkv_bias_bdim, proj_weight_value, proj_weight_bdim, proj_bias_value, proj_bias_bdim, use_gelu, norm_first, eps, norm_weight_1_value, norm_weight_1_bdim, norm_bias_1_value, norm_bias_1_bdim, norm_weight_2_value, norm_weight_2_bdim, norm_bias_2_value, norm_bias_2_bdim, ffn_weight_1_value, ffn_weight_1_bdim, ffn_bias_1_value, ffn_bias_1_bdim, ffn_weight_2_value, ffn_weight_2_bdim, ffn_bias_2_value, ffn_bias_2_bdim, mask_value, mask_bdim, mask_type); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple _native_multi_head_attention_generated_plumbing(const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, int64_t embed_dim, int64_t num_head, const at::Tensor & qkv_weight, const at::Tensor & qkv_bias, const at::Tensor & proj_weight, const at::Tensor & proj_bias, const ::std::optional & mask, bool need_weights, bool average_attn_weights, ::std::optional mask_type) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(query, cur_level) && !isBatchedAtLevel(key, cur_level) && !isBatchedAtLevel(value, cur_level) && !isBatchedAtLevel(qkv_weight, cur_level) && !isBatchedAtLevel(qkv_bias, cur_level) && !isBatchedAtLevel(proj_weight, cur_level) && !isBatchedAtLevel(proj_bias, cur_level) && !isBatchedAtLevel(mask, cur_level)) { + return at::_ops::_native_multi_head_attention::call(query, key, value, embed_dim, num_head, qkv_weight, qkv_bias, proj_weight, 
proj_bias, mask, need_weights, average_attn_weights, mask_type); + } + auto [query_value, query_bdim] = unwrapTensorAtLevel(query, cur_level); + auto [key_value, key_bdim] = unwrapTensorAtLevel(key, cur_level); + auto [value_value, value_bdim] = unwrapTensorAtLevel(value, cur_level); + auto [qkv_weight_value, qkv_weight_bdim] = unwrapTensorAtLevel(qkv_weight, cur_level); + auto [qkv_bias_value, qkv_bias_bdim] = unwrapTensorAtLevel(qkv_bias, cur_level); + auto [proj_weight_value, proj_weight_bdim] = unwrapTensorAtLevel(proj_weight, cur_level); + auto [proj_bias_value, proj_bias_bdim] = unwrapTensorAtLevel(proj_bias, cur_level); + std::optional mask_value; + std::optional mask_bdim; + if (mask) { + std::tie(mask_value, mask_bdim) = unwrapTensorAtLevel(mask.value(), cur_level); + } + auto results = batch_rule(query_value, query_bdim, key_value, key_bdim, value_value, value_bdim, embed_dim, num_head, qkv_weight_value, qkv_weight_bdim, qkv_bias_value, qkv_bias_bdim, proj_weight_value, proj_weight_bdim, proj_bias_value, proj_bias_bdim, mask_value, mask_bdim, need_weights, average_attn_weights, mask_type); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor scaled_dot_product_attention_generated_plumbing(const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const ::std::optional & attn_mask, double dropout_p, bool is_causal, ::std::optional scale, bool enable_gqa) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(query, cur_level) && !isBatchedAtLevel(key, cur_level) && !isBatchedAtLevel(value, cur_level) && !isBatchedAtLevel(attn_mask, cur_level)) { + return at::_ops::scaled_dot_product_attention::call(query, key, 
value, attn_mask, dropout_p, is_causal, scale, enable_gqa); + } + auto [query_value, query_bdim] = unwrapTensorAtLevel(query, cur_level); + auto [key_value, key_bdim] = unwrapTensorAtLevel(key, cur_level); + auto [value_value, value_bdim] = unwrapTensorAtLevel(value, cur_level); + std::optional attn_mask_value; + std::optional attn_mask_bdim; + if (attn_mask) { + std::tie(attn_mask_value, attn_mask_bdim) = unwrapTensorAtLevel(attn_mask.value(), cur_level); + } + auto results = batch_rule(query_value, query_bdim, key_value, key_bdim, value_value, value_bdim, attn_mask_value, attn_mask_bdim, dropout_p, is_causal, scale, enable_gqa); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple _scaled_dot_product_attention_math_generated_plumbing(const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const ::std::optional & attn_mask, double dropout_p, bool is_causal, const ::std::optional & dropout_mask, ::std::optional scale, bool enable_gqa) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(query, cur_level) && !isBatchedAtLevel(key, cur_level) && !isBatchedAtLevel(value, cur_level) && !isBatchedAtLevel(attn_mask, cur_level) && !isBatchedAtLevel(dropout_mask, cur_level)) { + return at::_ops::_scaled_dot_product_attention_math::call(query, key, value, attn_mask, dropout_p, is_causal, dropout_mask, scale, enable_gqa); + } + auto [query_value, query_bdim] = unwrapTensorAtLevel(query, cur_level); + auto [key_value, key_bdim] = unwrapTensorAtLevel(key, cur_level); + auto [value_value, value_bdim] = unwrapTensorAtLevel(value, cur_level); + std::optional attn_mask_value; + std::optional attn_mask_bdim; + if (attn_mask) { + std::tie(attn_mask_value, attn_mask_bdim) = unwrapTensorAtLevel(attn_mask.value(), 
cur_level); + } + std::optional dropout_mask_value; + std::optional dropout_mask_bdim; + if (dropout_mask) { + std::tie(dropout_mask_value, dropout_mask_bdim) = unwrapTensorAtLevel(dropout_mask.value(), cur_level); + } + auto results = batch_rule(query_value, query_bdim, key_value, key_bdim, value_value, value_bdim, attn_mask_value, attn_mask_bdim, dropout_p, is_causal, dropout_mask_value, dropout_mask_bdim, scale, enable_gqa); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple _scaled_dot_product_attention_math_for_mps_generated_plumbing(const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const ::std::optional & attn_mask, double dropout_p, bool is_causal, const ::std::optional & dropout_mask, ::std::optional scale) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(query, cur_level) && !isBatchedAtLevel(key, cur_level) && !isBatchedAtLevel(value, cur_level) && !isBatchedAtLevel(attn_mask, cur_level) && !isBatchedAtLevel(dropout_mask, cur_level)) { + return at::_ops::_scaled_dot_product_attention_math_for_mps::call(query, key, value, attn_mask, dropout_p, is_causal, dropout_mask, scale); + } + auto [query_value, query_bdim] = unwrapTensorAtLevel(query, cur_level); + auto [key_value, key_bdim] = unwrapTensorAtLevel(key, cur_level); + auto [value_value, value_bdim] = unwrapTensorAtLevel(value, cur_level); + std::optional attn_mask_value; + std::optional attn_mask_bdim; + if (attn_mask) { + std::tie(attn_mask_value, attn_mask_bdim) = unwrapTensorAtLevel(attn_mask.value(), cur_level); + } + std::optional dropout_mask_value; + std::optional dropout_mask_bdim; + if (dropout_mask) { + std::tie(dropout_mask_value, 
dropout_mask_bdim) = unwrapTensorAtLevel(dropout_mask.value(), cur_level); + } + auto results = batch_rule(query_value, query_bdim, key_value, key_bdim, value_value, value_bdim, attn_mask_value, attn_mask_bdim, dropout_p, is_causal, dropout_mask_value, dropout_mask_bdim, scale); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple _scaled_dot_product_flash_attention_generated_plumbing(const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, double dropout_p, bool is_causal, bool return_debug_mask, ::std::optional scale) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(query, cur_level) && !isBatchedAtLevel(key, cur_level) && !isBatchedAtLevel(value, cur_level)) { + return at::_ops::_scaled_dot_product_flash_attention::call(query, key, value, dropout_p, is_causal, return_debug_mask, scale); + } + auto [query_value, query_bdim] = unwrapTensorAtLevel(query, cur_level); + auto [key_value, key_bdim] = unwrapTensorAtLevel(key, cur_level); + auto [value_value, value_bdim] = unwrapTensorAtLevel(value, cur_level); + auto results = batch_rule(query_value, query_bdim, key_value, key_bdim, value_value, value_bdim, dropout_p, is_causal, return_debug_mask, scale); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level), makeBatched(std::get<6>(results), std::get<7>(results), cur_level), std::get<8>(results), std::get<9>(results), makeBatched(std::get<10>(results), std::get<11>(results), cur_level), makeBatched(std::get<12>(results), std::get<13>(results), 
cur_level), makeBatched(std::get<14>(results), std::get<15>(results), cur_level)); +} +template +::std::tuple _scaled_dot_product_flash_attention_for_cpu_generated_plumbing(const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, double dropout_p, bool is_causal, const ::std::optional & attn_mask, ::std::optional scale) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(query, cur_level) && !isBatchedAtLevel(key, cur_level) && !isBatchedAtLevel(value, cur_level) && !isBatchedAtLevel(attn_mask, cur_level)) { + return at::_ops::_scaled_dot_product_flash_attention_for_cpu::call(query, key, value, dropout_p, is_causal, attn_mask, scale); + } + auto [query_value, query_bdim] = unwrapTensorAtLevel(query, cur_level); + auto [key_value, key_bdim] = unwrapTensorAtLevel(key, cur_level); + auto [value_value, value_bdim] = unwrapTensorAtLevel(value, cur_level); + std::optional attn_mask_value; + std::optional attn_mask_bdim; + if (attn_mask) { + std::tie(attn_mask_value, attn_mask_bdim) = unwrapTensorAtLevel(attn_mask.value(), cur_level); + } + auto results = batch_rule(query_value, query_bdim, key_value, key_bdim, value_value, value_bdim, dropout_p, is_causal, attn_mask_value, attn_mask_bdim, scale); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple _scaled_dot_product_flash_attention_backward_generated_plumbing(const at::Tensor & grad_out, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const at::Tensor & out, const at::Tensor & logsumexp, const at::Tensor & cum_seq_q, const at::Tensor & cum_seq_k, c10::SymInt max_q, c10::SymInt max_k, double dropout_p, bool is_causal, const at::Tensor & 
philox_seed, const at::Tensor & philox_offset, ::std::optional scale) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_out, cur_level) && !isBatchedAtLevel(query, cur_level) && !isBatchedAtLevel(key, cur_level) && !isBatchedAtLevel(value, cur_level) && !isBatchedAtLevel(out, cur_level) && !isBatchedAtLevel(logsumexp, cur_level) && !isBatchedAtLevel(cum_seq_q, cur_level) && !isBatchedAtLevel(cum_seq_k, cur_level) && !isBatchedAtLevel(philox_seed, cur_level) && !isBatchedAtLevel(philox_offset, cur_level)) { + return at::_ops::_scaled_dot_product_flash_attention_backward::call(grad_out, query, key, value, out, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, philox_seed, philox_offset, scale); + } + auto [grad_out_value, grad_out_bdim] = unwrapTensorAtLevel(grad_out, cur_level); + auto [query_value, query_bdim] = unwrapTensorAtLevel(query, cur_level); + auto [key_value, key_bdim] = unwrapTensorAtLevel(key, cur_level); + auto [value_value, value_bdim] = unwrapTensorAtLevel(value, cur_level); + auto [out_value, out_bdim] = unwrapTensorAtLevel(out, cur_level); + auto [logsumexp_value, logsumexp_bdim] = unwrapTensorAtLevel(logsumexp, cur_level); + auto [cum_seq_q_value, cum_seq_q_bdim] = unwrapTensorAtLevel(cum_seq_q, cur_level); + auto [cum_seq_k_value, cum_seq_k_bdim] = unwrapTensorAtLevel(cum_seq_k, cur_level); + auto [philox_seed_value, philox_seed_bdim] = unwrapTensorAtLevel(philox_seed, cur_level); + auto [philox_offset_value, philox_offset_bdim] = unwrapTensorAtLevel(philox_offset, cur_level); + auto results = batch_rule(grad_out_value, grad_out_bdim, query_value, query_bdim, key_value, key_bdim, value_value, value_bdim, out_value, out_bdim, logsumexp_value, logsumexp_bdim, cum_seq_q_value, cum_seq_q_bdim, cum_seq_k_value, cum_seq_k_bdim, 
max_q, max_k, dropout_p, is_causal, philox_seed_value, philox_seed_bdim, philox_offset_value, philox_offset_bdim, scale); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +::std::tuple _scaled_dot_product_flash_attention_for_cpu_backward_generated_plumbing(const at::Tensor & grad_out, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const at::Tensor & out, const at::Tensor & logsumexp, double dropout_p, bool is_causal, const ::std::optional & attn_mask, ::std::optional scale) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_out, cur_level) && !isBatchedAtLevel(query, cur_level) && !isBatchedAtLevel(key, cur_level) && !isBatchedAtLevel(value, cur_level) && !isBatchedAtLevel(out, cur_level) && !isBatchedAtLevel(logsumexp, cur_level) && !isBatchedAtLevel(attn_mask, cur_level)) { + return at::_ops::_scaled_dot_product_flash_attention_for_cpu_backward::call(grad_out, query, key, value, out, logsumexp, dropout_p, is_causal, attn_mask, scale); + } + auto [grad_out_value, grad_out_bdim] = unwrapTensorAtLevel(grad_out, cur_level); + auto [query_value, query_bdim] = unwrapTensorAtLevel(query, cur_level); + auto [key_value, key_bdim] = unwrapTensorAtLevel(key, cur_level); + auto [value_value, value_bdim] = unwrapTensorAtLevel(value, cur_level); + auto [out_value, out_bdim] = unwrapTensorAtLevel(out, cur_level); + auto [logsumexp_value, logsumexp_bdim] = unwrapTensorAtLevel(logsumexp, cur_level); + std::optional attn_mask_value; + std::optional attn_mask_bdim; + if (attn_mask) { + std::tie(attn_mask_value, attn_mask_bdim) = 
unwrapTensorAtLevel(attn_mask.value(), cur_level); + } + auto results = batch_rule(grad_out_value, grad_out_bdim, query_value, query_bdim, key_value, key_bdim, value_value, value_bdim, out_value, out_bdim, logsumexp_value, logsumexp_bdim, dropout_p, is_causal, attn_mask_value, attn_mask_bdim, scale); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +::std::tuple _scaled_dot_product_fused_attention_overrideable_backward_generated_plumbing(const at::Tensor & grad_out, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const at::Tensor & attn_bias, ::std::array grad_input_mask, const at::Tensor & out, const at::Tensor & logsumexp, const at::Tensor & cum_seq_q, const at::Tensor & cum_seq_k, c10::SymInt max_q, c10::SymInt max_k, double dropout_p, bool is_causal, const at::Tensor & philox_seed, const at::Tensor & philox_offset, ::std::optional scale) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_out, cur_level) && !isBatchedAtLevel(query, cur_level) && !isBatchedAtLevel(key, cur_level) && !isBatchedAtLevel(value, cur_level) && !isBatchedAtLevel(attn_bias, cur_level) && !isBatchedAtLevel(out, cur_level) && !isBatchedAtLevel(logsumexp, cur_level) && !isBatchedAtLevel(cum_seq_q, cur_level) && !isBatchedAtLevel(cum_seq_k, cur_level) && !isBatchedAtLevel(philox_seed, cur_level) && !isBatchedAtLevel(philox_offset, cur_level)) { + return at::_ops::_scaled_dot_product_fused_attention_overrideable_backward::call(grad_out, query, key, value, attn_bias, grad_input_mask, out, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, philox_seed, 
philox_offset, scale); + } + auto [grad_out_value, grad_out_bdim] = unwrapTensorAtLevel(grad_out, cur_level); + auto [query_value, query_bdim] = unwrapTensorAtLevel(query, cur_level); + auto [key_value, key_bdim] = unwrapTensorAtLevel(key, cur_level); + auto [value_value, value_bdim] = unwrapTensorAtLevel(value, cur_level); + auto [attn_bias_value, attn_bias_bdim] = unwrapTensorAtLevel(attn_bias, cur_level); + auto [out_value, out_bdim] = unwrapTensorAtLevel(out, cur_level); + auto [logsumexp_value, logsumexp_bdim] = unwrapTensorAtLevel(logsumexp, cur_level); + auto [cum_seq_q_value, cum_seq_q_bdim] = unwrapTensorAtLevel(cum_seq_q, cur_level); + auto [cum_seq_k_value, cum_seq_k_bdim] = unwrapTensorAtLevel(cum_seq_k, cur_level); + auto [philox_seed_value, philox_seed_bdim] = unwrapTensorAtLevel(philox_seed, cur_level); + auto [philox_offset_value, philox_offset_bdim] = unwrapTensorAtLevel(philox_offset, cur_level); + auto results = batch_rule(grad_out_value, grad_out_bdim, query_value, query_bdim, key_value, key_bdim, value_value, value_bdim, attn_bias_value, attn_bias_bdim, grad_input_mask, out_value, out_bdim, logsumexp_value, logsumexp_bdim, cum_seq_q_value, cum_seq_q_bdim, cum_seq_k_value, cum_seq_k_bdim, max_q, max_k, dropout_p, is_causal, philox_seed_value, philox_seed_bdim, philox_offset_value, philox_offset_bdim, scale); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level), makeBatched(std::get<6>(results), std::get<7>(results), cur_level)); +} +template +::std::tuple _scaled_dot_product_efficient_attention_generated_plumbing(const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const ::std::optional & attn_bias, bool compute_log_sumexp, double dropout_p, bool is_causal, ::std::optional scale) { + c10::impl::ExcludeDispatchKeyGuard 
guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(query, cur_level) && !isBatchedAtLevel(key, cur_level) && !isBatchedAtLevel(value, cur_level) && !isBatchedAtLevel(attn_bias, cur_level)) { + return at::_ops::_scaled_dot_product_efficient_attention::call(query, key, value, attn_bias, compute_log_sumexp, dropout_p, is_causal, scale); + } + auto [query_value, query_bdim] = unwrapTensorAtLevel(query, cur_level); + auto [key_value, key_bdim] = unwrapTensorAtLevel(key, cur_level); + auto [value_value, value_bdim] = unwrapTensorAtLevel(value, cur_level); + std::optional attn_bias_value; + std::optional attn_bias_bdim; + if (attn_bias) { + std::tie(attn_bias_value, attn_bias_bdim) = unwrapTensorAtLevel(attn_bias.value(), cur_level); + } + auto results = batch_rule(query_value, query_bdim, key_value, key_bdim, value_value, value_bdim, attn_bias_value, attn_bias_bdim, compute_log_sumexp, dropout_p, is_causal, scale); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level), makeBatched(std::get<6>(results), std::get<7>(results), cur_level)); +} +template +::std::tuple _scaled_dot_product_efficient_attention_backward_generated_plumbing(const at::Tensor & grad_out_, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const at::Tensor & attn_bias, const at::Tensor & out, const at::Tensor & logsumexp, const at::Tensor & philox_seed, const at::Tensor & philox_offset, double dropout_p, ::std::array grad_input_mask, bool is_causal, ::std::optional scale) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, 
"gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_out_, cur_level) && !isBatchedAtLevel(query, cur_level) && !isBatchedAtLevel(key, cur_level) && !isBatchedAtLevel(value, cur_level) && !isBatchedAtLevel(attn_bias, cur_level) && !isBatchedAtLevel(out, cur_level) && !isBatchedAtLevel(logsumexp, cur_level) && !isBatchedAtLevel(philox_seed, cur_level) && !isBatchedAtLevel(philox_offset, cur_level)) { + return at::_ops::_scaled_dot_product_efficient_attention_backward::call(grad_out_, query, key, value, attn_bias, out, logsumexp, philox_seed, philox_offset, dropout_p, grad_input_mask, is_causal, scale); + } + auto [grad_out__value, grad_out__bdim] = unwrapTensorAtLevel(grad_out_, cur_level); + auto [query_value, query_bdim] = unwrapTensorAtLevel(query, cur_level); + auto [key_value, key_bdim] = unwrapTensorAtLevel(key, cur_level); + auto [value_value, value_bdim] = unwrapTensorAtLevel(value, cur_level); + auto [attn_bias_value, attn_bias_bdim] = unwrapTensorAtLevel(attn_bias, cur_level); + auto [out_value, out_bdim] = unwrapTensorAtLevel(out, cur_level); + auto [logsumexp_value, logsumexp_bdim] = unwrapTensorAtLevel(logsumexp, cur_level); + auto [philox_seed_value, philox_seed_bdim] = unwrapTensorAtLevel(philox_seed, cur_level); + auto [philox_offset_value, philox_offset_bdim] = unwrapTensorAtLevel(philox_offset, cur_level); + auto results = batch_rule(grad_out__value, grad_out__bdim, query_value, query_bdim, key_value, key_bdim, value_value, value_bdim, attn_bias_value, attn_bias_bdim, out_value, out_bdim, logsumexp_value, logsumexp_bdim, philox_seed_value, philox_seed_bdim, philox_offset_value, philox_offset_bdim, dropout_p, grad_input_mask, is_causal, scale); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level), makeBatched(std::get<6>(results), 
std::get<7>(results), cur_level)); +} +template +::std::tuple _scaled_dot_product_cudnn_attention_generated_plumbing(const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const ::std::optional & attn_bias, bool compute_log_sumexp, double dropout_p, bool is_causal, bool return_debug_mask, ::std::optional scale) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(query, cur_level) && !isBatchedAtLevel(key, cur_level) && !isBatchedAtLevel(value, cur_level) && !isBatchedAtLevel(attn_bias, cur_level)) { + return at::_ops::_scaled_dot_product_cudnn_attention::call(query, key, value, attn_bias, compute_log_sumexp, dropout_p, is_causal, return_debug_mask, scale); + } + auto [query_value, query_bdim] = unwrapTensorAtLevel(query, cur_level); + auto [key_value, key_bdim] = unwrapTensorAtLevel(key, cur_level); + auto [value_value, value_bdim] = unwrapTensorAtLevel(value, cur_level); + std::optional attn_bias_value; + std::optional attn_bias_bdim; + if (attn_bias) { + std::tie(attn_bias_value, attn_bias_bdim) = unwrapTensorAtLevel(attn_bias.value(), cur_level); + } + auto results = batch_rule(query_value, query_bdim, key_value, key_bdim, value_value, value_bdim, attn_bias_value, attn_bias_bdim, compute_log_sumexp, dropout_p, is_causal, return_debug_mask, scale); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level), makeBatched(std::get<6>(results), std::get<7>(results), cur_level), std::get<8>(results), std::get<9>(results), makeBatched(std::get<10>(results), std::get<11>(results), cur_level), makeBatched(std::get<12>(results), std::get<13>(results), cur_level), 
makeBatched(std::get<14>(results), std::get<15>(results), cur_level)); +} +template +::std::tuple _scaled_dot_product_cudnn_attention_backward_generated_plumbing(const at::Tensor & grad_out, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const at::Tensor & out, const at::Tensor & logsumexp, const at::Tensor & philox_seed, const at::Tensor & philox_offset, const at::Tensor & attn_bias, const at::Tensor & cum_seq_q, const at::Tensor & cum_seq_k, c10::SymInt max_q, c10::SymInt max_k, double dropout_p, bool is_causal, ::std::optional scale) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_out, cur_level) && !isBatchedAtLevel(query, cur_level) && !isBatchedAtLevel(key, cur_level) && !isBatchedAtLevel(value, cur_level) && !isBatchedAtLevel(out, cur_level) && !isBatchedAtLevel(logsumexp, cur_level) && !isBatchedAtLevel(philox_seed, cur_level) && !isBatchedAtLevel(philox_offset, cur_level) && !isBatchedAtLevel(attn_bias, cur_level) && !isBatchedAtLevel(cum_seq_q, cur_level) && !isBatchedAtLevel(cum_seq_k, cur_level)) { + return at::_ops::_scaled_dot_product_cudnn_attention_backward::call(grad_out, query, key, value, out, logsumexp, philox_seed, philox_offset, attn_bias, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, scale); + } + auto [grad_out_value, grad_out_bdim] = unwrapTensorAtLevel(grad_out, cur_level); + auto [query_value, query_bdim] = unwrapTensorAtLevel(query, cur_level); + auto [key_value, key_bdim] = unwrapTensorAtLevel(key, cur_level); + auto [value_value, value_bdim] = unwrapTensorAtLevel(value, cur_level); + auto [out_value, out_bdim] = unwrapTensorAtLevel(out, cur_level); + auto [logsumexp_value, logsumexp_bdim] = unwrapTensorAtLevel(logsumexp, cur_level); + auto [philox_seed_value, philox_seed_bdim] = 
unwrapTensorAtLevel(philox_seed, cur_level); + auto [philox_offset_value, philox_offset_bdim] = unwrapTensorAtLevel(philox_offset, cur_level); + auto [attn_bias_value, attn_bias_bdim] = unwrapTensorAtLevel(attn_bias, cur_level); + auto [cum_seq_q_value, cum_seq_q_bdim] = unwrapTensorAtLevel(cum_seq_q, cur_level); + auto [cum_seq_k_value, cum_seq_k_bdim] = unwrapTensorAtLevel(cum_seq_k, cur_level); + auto results = batch_rule(grad_out_value, grad_out_bdim, query_value, query_bdim, key_value, key_bdim, value_value, value_bdim, out_value, out_bdim, logsumexp_value, logsumexp_bdim, philox_seed_value, philox_seed_bdim, philox_offset_value, philox_offset_bdim, attn_bias_value, attn_bias_bdim, cum_seq_q_value, cum_seq_q_bdim, cum_seq_k_value, cum_seq_k_bdim, max_q, max_k, dropout_p, is_causal, scale); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +::std::tuple _flash_attention_forward_generated_plumbing(const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const ::std::optional & cum_seq_q, const ::std::optional & cum_seq_k, c10::SymInt max_q, c10::SymInt max_k, double dropout_p, bool is_causal, bool return_debug_mask, ::std::optional scale, ::std::optional window_size_left, ::std::optional window_size_right, const ::std::optional & seqused_k, const ::std::optional & alibi_slopes) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(query, cur_level) && !isBatchedAtLevel(key, cur_level) && !isBatchedAtLevel(value, cur_level) && !isBatchedAtLevel(cum_seq_q, cur_level) && !isBatchedAtLevel(cum_seq_k, cur_level) && !isBatchedAtLevel(seqused_k, cur_level) && 
!isBatchedAtLevel(alibi_slopes, cur_level)) { + return at::_ops::_flash_attention_forward::call(query, key, value, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, return_debug_mask, scale, window_size_left, window_size_right, seqused_k, alibi_slopes); + } + auto [query_value, query_bdim] = unwrapTensorAtLevel(query, cur_level); + auto [key_value, key_bdim] = unwrapTensorAtLevel(key, cur_level); + auto [value_value, value_bdim] = unwrapTensorAtLevel(value, cur_level); + std::optional cum_seq_q_value; + std::optional cum_seq_q_bdim; + if (cum_seq_q) { + std::tie(cum_seq_q_value, cum_seq_q_bdim) = unwrapTensorAtLevel(cum_seq_q.value(), cur_level); + } + std::optional cum_seq_k_value; + std::optional cum_seq_k_bdim; + if (cum_seq_k) { + std::tie(cum_seq_k_value, cum_seq_k_bdim) = unwrapTensorAtLevel(cum_seq_k.value(), cur_level); + } + std::optional seqused_k_value; + std::optional seqused_k_bdim; + if (seqused_k) { + std::tie(seqused_k_value, seqused_k_bdim) = unwrapTensorAtLevel(seqused_k.value(), cur_level); + } + std::optional alibi_slopes_value; + std::optional alibi_slopes_bdim; + if (alibi_slopes) { + std::tie(alibi_slopes_value, alibi_slopes_bdim) = unwrapTensorAtLevel(alibi_slopes.value(), cur_level); + } + auto results = batch_rule(query_value, query_bdim, key_value, key_bdim, value_value, value_bdim, cum_seq_q_value, cum_seq_q_bdim, cum_seq_k_value, cum_seq_k_bdim, max_q, max_k, dropout_p, is_causal, return_debug_mask, scale, window_size_left, window_size_right, seqused_k_value, seqused_k_bdim, alibi_slopes_value, alibi_slopes_bdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level), makeBatched(std::get<6>(results), std::get<7>(results), cur_level), makeBatched(std::get<8>(results), std::get<9>(results), cur_level)); +} +template +::std::tuple 
_flash_attention_backward_generated_plumbing(const at::Tensor & grad_out, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const at::Tensor & out, const at::Tensor & logsumexp, const at::Tensor & cum_seq_q, const at::Tensor & cum_seq_k, c10::SymInt max_q, c10::SymInt max_k, double dropout_p, bool is_causal, const at::Tensor & rng_state, const at::Tensor & unused, ::std::optional scale, ::std::optional window_size_left, ::std::optional window_size_right) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_out, cur_level) && !isBatchedAtLevel(query, cur_level) && !isBatchedAtLevel(key, cur_level) && !isBatchedAtLevel(value, cur_level) && !isBatchedAtLevel(out, cur_level) && !isBatchedAtLevel(logsumexp, cur_level) && !isBatchedAtLevel(cum_seq_q, cur_level) && !isBatchedAtLevel(cum_seq_k, cur_level) && !isBatchedAtLevel(rng_state, cur_level) && !isBatchedAtLevel(unused, cur_level)) { + return at::_ops::_flash_attention_backward::call(grad_out, query, key, value, out, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, rng_state, unused, scale, window_size_left, window_size_right); + } + auto [grad_out_value, grad_out_bdim] = unwrapTensorAtLevel(grad_out, cur_level); + auto [query_value, query_bdim] = unwrapTensorAtLevel(query, cur_level); + auto [key_value, key_bdim] = unwrapTensorAtLevel(key, cur_level); + auto [value_value, value_bdim] = unwrapTensorAtLevel(value, cur_level); + auto [out_value, out_bdim] = unwrapTensorAtLevel(out, cur_level); + auto [logsumexp_value, logsumexp_bdim] = unwrapTensorAtLevel(logsumexp, cur_level); + auto [cum_seq_q_value, cum_seq_q_bdim] = unwrapTensorAtLevel(cum_seq_q, cur_level); + auto [cum_seq_k_value, cum_seq_k_bdim] = unwrapTensorAtLevel(cum_seq_k, cur_level); + auto [rng_state_value, 
rng_state_bdim] = unwrapTensorAtLevel(rng_state, cur_level); + auto [unused_value, unused_bdim] = unwrapTensorAtLevel(unused, cur_level); + auto results = batch_rule(grad_out_value, grad_out_bdim, query_value, query_bdim, key_value, key_bdim, value_value, value_bdim, out_value, out_bdim, logsumexp_value, logsumexp_bdim, cum_seq_q_value, cum_seq_q_bdim, cum_seq_k_value, cum_seq_k_bdim, max_q, max_k, dropout_p, is_causal, rng_state_value, rng_state_bdim, unused_value, unused_bdim, scale, window_size_left, window_size_right); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +::std::tuple _efficient_attention_backward_generated_plumbing(const at::Tensor & grad_out_, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const ::std::optional & bias, const at::Tensor & out, const ::std::optional & cu_seqlens_q, const ::std::optional & cu_seqlens_k, c10::SymInt max_seqlen_q, c10::SymInt max_seqlen_k, const at::Tensor & logsumexp, double dropout_p, const at::Tensor & philox_seed, const at::Tensor & philox_offset, int64_t custom_mask_type, bool bias_requires_grad, ::std::optional scale, ::std::optional num_splits_key, ::std::optional window_size, bool shared_storage_dqdkdv) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_out_, cur_level) && !isBatchedAtLevel(query, cur_level) && !isBatchedAtLevel(key, cur_level) && !isBatchedAtLevel(value, cur_level) && !isBatchedAtLevel(bias, cur_level) && !isBatchedAtLevel(out, cur_level) && !isBatchedAtLevel(cu_seqlens_q, cur_level) && !isBatchedAtLevel(cu_seqlens_k, cur_level) && !isBatchedAtLevel(logsumexp, cur_level) 
&& !isBatchedAtLevel(philox_seed, cur_level) && !isBatchedAtLevel(philox_offset, cur_level)) { + return at::_ops::_efficient_attention_backward::call(grad_out_, query, key, value, bias, out, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, logsumexp, dropout_p, philox_seed, philox_offset, custom_mask_type, bias_requires_grad, scale, num_splits_key, window_size, shared_storage_dqdkdv); + } + auto [grad_out__value, grad_out__bdim] = unwrapTensorAtLevel(grad_out_, cur_level); + auto [query_value, query_bdim] = unwrapTensorAtLevel(query, cur_level); + auto [key_value, key_bdim] = unwrapTensorAtLevel(key, cur_level); + auto [value_value, value_bdim] = unwrapTensorAtLevel(value, cur_level); + auto [out_value, out_bdim] = unwrapTensorAtLevel(out, cur_level); + auto [logsumexp_value, logsumexp_bdim] = unwrapTensorAtLevel(logsumexp, cur_level); + auto [philox_seed_value, philox_seed_bdim] = unwrapTensorAtLevel(philox_seed, cur_level); + auto [philox_offset_value, philox_offset_bdim] = unwrapTensorAtLevel(philox_offset, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + std::optional cu_seqlens_q_value; + std::optional cu_seqlens_q_bdim; + if (cu_seqlens_q) { + std::tie(cu_seqlens_q_value, cu_seqlens_q_bdim) = unwrapTensorAtLevel(cu_seqlens_q.value(), cur_level); + } + std::optional cu_seqlens_k_value; + std::optional cu_seqlens_k_bdim; + if (cu_seqlens_k) { + std::tie(cu_seqlens_k_value, cu_seqlens_k_bdim) = unwrapTensorAtLevel(cu_seqlens_k.value(), cur_level); + } + auto results = batch_rule(grad_out__value, grad_out__bdim, query_value, query_bdim, key_value, key_bdim, value_value, value_bdim, bias_value, bias_bdim, out_value, out_bdim, cu_seqlens_q_value, cu_seqlens_q_bdim, cu_seqlens_k_value, cu_seqlens_k_bdim, max_seqlen_q, max_seqlen_k, logsumexp_value, logsumexp_bdim, dropout_p, philox_seed_value, philox_seed_bdim, philox_offset_value, 
philox_offset_bdim, custom_mask_type, bias_requires_grad, scale, num_splits_key, window_size, shared_storage_dqdkdv); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level), makeBatched(std::get<6>(results), std::get<7>(results), cur_level)); +} +template +::std::tuple _cudnn_attention_backward_generated_plumbing(const at::Tensor & grad_out, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const at::Tensor & out, const at::Tensor & logsumexp, const at::Tensor & philox_seed, const at::Tensor & philox_offset, const at::Tensor & attn_bias, const at::Tensor & cum_seq_q, const at::Tensor & cum_seq_k, c10::SymInt max_q, c10::SymInt max_k, double dropout_p, bool is_causal, ::std::optional scale) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_out, cur_level) && !isBatchedAtLevel(query, cur_level) && !isBatchedAtLevel(key, cur_level) && !isBatchedAtLevel(value, cur_level) && !isBatchedAtLevel(out, cur_level) && !isBatchedAtLevel(logsumexp, cur_level) && !isBatchedAtLevel(philox_seed, cur_level) && !isBatchedAtLevel(philox_offset, cur_level) && !isBatchedAtLevel(attn_bias, cur_level) && !isBatchedAtLevel(cum_seq_q, cur_level) && !isBatchedAtLevel(cum_seq_k, cur_level)) { + return at::_ops::_cudnn_attention_backward::call(grad_out, query, key, value, out, logsumexp, philox_seed, philox_offset, attn_bias, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, scale); + } + auto [grad_out_value, grad_out_bdim] = unwrapTensorAtLevel(grad_out, cur_level); + auto [query_value, query_bdim] = unwrapTensorAtLevel(query, cur_level); + auto [key_value, key_bdim] = 
unwrapTensorAtLevel(key, cur_level); + auto [value_value, value_bdim] = unwrapTensorAtLevel(value, cur_level); + auto [out_value, out_bdim] = unwrapTensorAtLevel(out, cur_level); + auto [logsumexp_value, logsumexp_bdim] = unwrapTensorAtLevel(logsumexp, cur_level); + auto [philox_seed_value, philox_seed_bdim] = unwrapTensorAtLevel(philox_seed, cur_level); + auto [philox_offset_value, philox_offset_bdim] = unwrapTensorAtLevel(philox_offset, cur_level); + auto [attn_bias_value, attn_bias_bdim] = unwrapTensorAtLevel(attn_bias, cur_level); + auto [cum_seq_q_value, cum_seq_q_bdim] = unwrapTensorAtLevel(cum_seq_q, cur_level); + auto [cum_seq_k_value, cum_seq_k_bdim] = unwrapTensorAtLevel(cum_seq_k, cur_level); + auto results = batch_rule(grad_out_value, grad_out_bdim, query_value, query_bdim, key_value, key_bdim, value_value, value_bdim, out_value, out_bdim, logsumexp_value, logsumexp_bdim, philox_seed_value, philox_seed_bdim, philox_offset_value, philox_offset_bdim, attn_bias_value, attn_bias_bdim, cum_seq_q_value, cum_seq_q_bdim, cum_seq_k_value, cum_seq_k_bdim, max_q, max_k, dropout_p, is_causal, scale); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +at::Tensor _triton_scaled_dot_attention_generated_plumbing(const at::Tensor & q, const at::Tensor & k, const at::Tensor & v, double dropout_p) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(q, cur_level) && !isBatchedAtLevel(k, cur_level) && !isBatchedAtLevel(v, cur_level)) { + return at::_ops::_triton_scaled_dot_attention::call(q, k, v, dropout_p); + } + auto [q_value, q_bdim] = unwrapTensorAtLevel(q, cur_level); + auto 
[k_value, k_bdim] = unwrapTensorAtLevel(k, cur_level); + auto [v_value, v_bdim] = unwrapTensorAtLevel(v, cur_level); + auto results = batch_rule(q_value, q_bdim, k_value, k_bdim, v_value, v_bdim, dropout_p); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & _fill_mem_eff_dropout_mask__generated_plumbing(at::Tensor & self, double dropout_p, int64_t seed, int64_t offset) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_fill_mem_eff_dropout_mask_::call(self, dropout_p, seed, offset); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, dropout_p, seed, offset); + return self; +} +template +at::Tensor _triton_multi_head_attention_generated_plumbing(const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, int64_t embed_dim, int64_t num_head, const at::Tensor & qkv_weight, const at::Tensor & qkv_bias, const at::Tensor & proj_weight, const at::Tensor & proj_bias, const ::std::optional & mask) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(query, cur_level) && !isBatchedAtLevel(key, cur_level) && !isBatchedAtLevel(value, cur_level) && !isBatchedAtLevel(qkv_weight, cur_level) && !isBatchedAtLevel(qkv_bias, cur_level) && !isBatchedAtLevel(proj_weight, cur_level) && !isBatchedAtLevel(proj_bias, cur_level) && !isBatchedAtLevel(mask, cur_level)) { + return at::_ops::_triton_multi_head_attention::call(query, key, value, embed_dim, num_head, qkv_weight, qkv_bias, proj_weight, proj_bias, mask); + } + auto 
[query_value, query_bdim] = unwrapTensorAtLevel(query, cur_level); + auto [key_value, key_bdim] = unwrapTensorAtLevel(key, cur_level); + auto [value_value, value_bdim] = unwrapTensorAtLevel(value, cur_level); + auto [qkv_weight_value, qkv_weight_bdim] = unwrapTensorAtLevel(qkv_weight, cur_level); + auto [qkv_bias_value, qkv_bias_bdim] = unwrapTensorAtLevel(qkv_bias, cur_level); + auto [proj_weight_value, proj_weight_bdim] = unwrapTensorAtLevel(proj_weight, cur_level); + auto [proj_bias_value, proj_bias_bdim] = unwrapTensorAtLevel(proj_bias, cur_level); + std::optional mask_value; + std::optional mask_bdim; + if (mask) { + std::tie(mask_value, mask_bdim) = unwrapTensorAtLevel(mask.value(), cur_level); + } + auto results = batch_rule(query_value, query_bdim, key_value, key_bdim, value_value, value_bdim, embed_dim, num_head, qkv_weight_value, qkv_weight_bdim, qkv_bias_value, qkv_bias_bdim, proj_weight_value, proj_weight_bdim, proj_bias_value, proj_bias_bdim, mask_value, mask_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_airy_ai_generated_plumbing(const at::Tensor & x) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(x, cur_level)) { + return at::_ops::special_airy_ai::call(x); + } + auto [x_value, x_bdim] = unwrapTensorAtLevel(x, cur_level); + auto results = batch_rule(x_value, x_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_bessel_j0_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if 
(!isBatchedAtLevel(self, cur_level)) { + return at::_ops::special_bessel_j0::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_bessel_j1_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::special_bessel_j1::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_bessel_y0_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::special_bessel_y0::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_bessel_y1_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::special_bessel_y1::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, 
cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_chebyshev_polynomial_t_generated_plumbing(const at::Tensor & x, const at::Tensor & n) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(x, cur_level) && !isBatchedAtLevel(n, cur_level)) { + return at::_ops::special_chebyshev_polynomial_t::call(x, n); + } + auto [x_value, x_bdim] = unwrapTensorAtLevel(x, cur_level); + auto [n_value, n_bdim] = unwrapTensorAtLevel(n, cur_level); + auto results = batch_rule(x_value, x_bdim, n_value, n_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_chebyshev_polynomial_t_x_scalar_generated_plumbing(const at::Scalar & x, const at::Tensor & n) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(n, cur_level)) { + return at::_ops::special_chebyshev_polynomial_t_x_scalar::call(x, n); + } + auto [n_value, n_bdim] = unwrapTensorAtLevel(n, cur_level); + auto results = batch_rule(x, n_value, n_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_chebyshev_polynomial_t_n_scalar_generated_plumbing(const at::Tensor & x, const at::Scalar & n) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(x, cur_level)) { + return 
at::_ops::special_chebyshev_polynomial_t_n_scalar::call(x, n); + } + auto [x_value, x_bdim] = unwrapTensorAtLevel(x, cur_level); + auto results = batch_rule(x_value, x_bdim, n); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_chebyshev_polynomial_u_generated_plumbing(const at::Tensor & x, const at::Tensor & n) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(x, cur_level) && !isBatchedAtLevel(n, cur_level)) { + return at::_ops::special_chebyshev_polynomial_u::call(x, n); + } + auto [x_value, x_bdim] = unwrapTensorAtLevel(x, cur_level); + auto [n_value, n_bdim] = unwrapTensorAtLevel(n, cur_level); + auto results = batch_rule(x_value, x_bdim, n_value, n_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_chebyshev_polynomial_u_x_scalar_generated_plumbing(const at::Scalar & x, const at::Tensor & n) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(n, cur_level)) { + return at::_ops::special_chebyshev_polynomial_u_x_scalar::call(x, n); + } + auto [n_value, n_bdim] = unwrapTensorAtLevel(n, cur_level); + auto results = batch_rule(x, n_value, n_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_chebyshev_polynomial_u_n_scalar_generated_plumbing(const at::Tensor & x, const at::Scalar & n) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level 
= maybe_layer->layerId(); + if (!isBatchedAtLevel(x, cur_level)) { + return at::_ops::special_chebyshev_polynomial_u_n_scalar::call(x, n); + } + auto [x_value, x_bdim] = unwrapTensorAtLevel(x, cur_level); + auto results = batch_rule(x_value, x_bdim, n); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_chebyshev_polynomial_v_generated_plumbing(const at::Tensor & x, const at::Tensor & n) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(x, cur_level) && !isBatchedAtLevel(n, cur_level)) { + return at::_ops::special_chebyshev_polynomial_v::call(x, n); + } + auto [x_value, x_bdim] = unwrapTensorAtLevel(x, cur_level); + auto [n_value, n_bdim] = unwrapTensorAtLevel(n, cur_level); + auto results = batch_rule(x_value, x_bdim, n_value, n_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_chebyshev_polynomial_v_x_scalar_generated_plumbing(const at::Scalar & x, const at::Tensor & n) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(n, cur_level)) { + return at::_ops::special_chebyshev_polynomial_v_x_scalar::call(x, n); + } + auto [n_value, n_bdim] = unwrapTensorAtLevel(n, cur_level); + auto results = batch_rule(x, n_value, n_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_chebyshev_polynomial_v_n_scalar_generated_plumbing(const at::Tensor & x, const at::Scalar & n) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); 
+ vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(x, cur_level)) { + return at::_ops::special_chebyshev_polynomial_v_n_scalar::call(x, n); + } + auto [x_value, x_bdim] = unwrapTensorAtLevel(x, cur_level); + auto results = batch_rule(x_value, x_bdim, n); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_chebyshev_polynomial_w_generated_plumbing(const at::Tensor & x, const at::Tensor & n) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(x, cur_level) && !isBatchedAtLevel(n, cur_level)) { + return at::_ops::special_chebyshev_polynomial_w::call(x, n); + } + auto [x_value, x_bdim] = unwrapTensorAtLevel(x, cur_level); + auto [n_value, n_bdim] = unwrapTensorAtLevel(n, cur_level); + auto results = batch_rule(x_value, x_bdim, n_value, n_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_chebyshev_polynomial_w_x_scalar_generated_plumbing(const at::Scalar & x, const at::Tensor & n) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(n, cur_level)) { + return at::_ops::special_chebyshev_polynomial_w_x_scalar::call(x, n); + } + auto [n_value, n_bdim] = unwrapTensorAtLevel(n, cur_level); + auto results = batch_rule(x, n_value, n_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_chebyshev_polynomial_w_n_scalar_generated_plumbing(const at::Tensor & x, const at::Scalar & n) { + c10::impl::ExcludeDispatchKeyGuard 
guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(x, cur_level)) { + return at::_ops::special_chebyshev_polynomial_w_n_scalar::call(x, n); + } + auto [x_value, x_bdim] = unwrapTensorAtLevel(x, cur_level); + auto results = batch_rule(x_value, x_bdim, n); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_hermite_polynomial_h_generated_plumbing(const at::Tensor & x, const at::Tensor & n) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(x, cur_level) && !isBatchedAtLevel(n, cur_level)) { + return at::_ops::special_hermite_polynomial_h::call(x, n); + } + auto [x_value, x_bdim] = unwrapTensorAtLevel(x, cur_level); + auto [n_value, n_bdim] = unwrapTensorAtLevel(n, cur_level); + auto results = batch_rule(x_value, x_bdim, n_value, n_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_hermite_polynomial_h_x_scalar_generated_plumbing(const at::Scalar & x, const at::Tensor & n) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(n, cur_level)) { + return at::_ops::special_hermite_polynomial_h_x_scalar::call(x, n); + } + auto [n_value, n_bdim] = unwrapTensorAtLevel(n, cur_level); + auto results = batch_rule(x, n_value, n_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_hermite_polynomial_h_n_scalar_generated_plumbing(const at::Tensor 
& x, const at::Scalar & n) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(x, cur_level)) { + return at::_ops::special_hermite_polynomial_h_n_scalar::call(x, n); + } + auto [x_value, x_bdim] = unwrapTensorAtLevel(x, cur_level); + auto results = batch_rule(x_value, x_bdim, n); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_hermite_polynomial_he_generated_plumbing(const at::Tensor & x, const at::Tensor & n) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(x, cur_level) && !isBatchedAtLevel(n, cur_level)) { + return at::_ops::special_hermite_polynomial_he::call(x, n); + } + auto [x_value, x_bdim] = unwrapTensorAtLevel(x, cur_level); + auto [n_value, n_bdim] = unwrapTensorAtLevel(n, cur_level); + auto results = batch_rule(x_value, x_bdim, n_value, n_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_hermite_polynomial_he_x_scalar_generated_plumbing(const at::Scalar & x, const at::Tensor & n) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(n, cur_level)) { + return at::_ops::special_hermite_polynomial_he_x_scalar::call(x, n); + } + auto [n_value, n_bdim] = unwrapTensorAtLevel(n, cur_level); + auto results = batch_rule(x, n_value, n_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor 
special_hermite_polynomial_he_n_scalar_generated_plumbing(const at::Tensor & x, const at::Scalar & n) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(x, cur_level)) { + return at::_ops::special_hermite_polynomial_he_n_scalar::call(x, n); + } + auto [x_value, x_bdim] = unwrapTensorAtLevel(x, cur_level); + auto results = batch_rule(x_value, x_bdim, n); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_laguerre_polynomial_l_generated_plumbing(const at::Tensor & x, const at::Tensor & n) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(x, cur_level) && !isBatchedAtLevel(n, cur_level)) { + return at::_ops::special_laguerre_polynomial_l::call(x, n); + } + auto [x_value, x_bdim] = unwrapTensorAtLevel(x, cur_level); + auto [n_value, n_bdim] = unwrapTensorAtLevel(n, cur_level); + auto results = batch_rule(x_value, x_bdim, n_value, n_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_laguerre_polynomial_l_x_scalar_generated_plumbing(const at::Scalar & x, const at::Tensor & n) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(n, cur_level)) { + return at::_ops::special_laguerre_polynomial_l_x_scalar::call(x, n); + } + auto [n_value, n_bdim] = unwrapTensorAtLevel(n, cur_level); + auto results = batch_rule(x, n_value, n_bdim); + return 
makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_laguerre_polynomial_l_n_scalar_generated_plumbing(const at::Tensor & x, const at::Scalar & n) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(x, cur_level)) { + return at::_ops::special_laguerre_polynomial_l_n_scalar::call(x, n); + } + auto [x_value, x_bdim] = unwrapTensorAtLevel(x, cur_level); + auto results = batch_rule(x_value, x_bdim, n); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_legendre_polynomial_p_generated_plumbing(const at::Tensor & x, const at::Tensor & n) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(x, cur_level) && !isBatchedAtLevel(n, cur_level)) { + return at::_ops::special_legendre_polynomial_p::call(x, n); + } + auto [x_value, x_bdim] = unwrapTensorAtLevel(x, cur_level); + auto [n_value, n_bdim] = unwrapTensorAtLevel(n, cur_level); + auto results = batch_rule(x_value, x_bdim, n_value, n_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_legendre_polynomial_p_x_scalar_generated_plumbing(const at::Scalar & x, const at::Tensor & n) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(n, cur_level)) { + return at::_ops::special_legendre_polynomial_p_x_scalar::call(x, n); + } + auto [n_value, n_bdim] = unwrapTensorAtLevel(n, 
cur_level); + auto results = batch_rule(x, n_value, n_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_legendre_polynomial_p_n_scalar_generated_plumbing(const at::Tensor & x, const at::Scalar & n) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(x, cur_level)) { + return at::_ops::special_legendre_polynomial_p_n_scalar::call(x, n); + } + auto [x_value, x_bdim] = unwrapTensorAtLevel(x, cur_level); + auto results = batch_rule(x_value, x_bdim, n); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_modified_bessel_i0_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::special_modified_bessel_i0::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_modified_bessel_i1_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::special_modified_bessel_i1::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return 
makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_modified_bessel_k0_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::special_modified_bessel_k0::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_modified_bessel_k1_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::special_modified_bessel_k1::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_scaled_modified_bessel_k0_generated_plumbing(const at::Tensor & x) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(x, cur_level)) { + return at::_ops::special_scaled_modified_bessel_k0::call(x); + } + auto [x_value, x_bdim] = unwrapTensorAtLevel(x, cur_level); + auto results = batch_rule(x_value, x_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor 
special_scaled_modified_bessel_k1_generated_plumbing(const at::Tensor & x) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(x, cur_level)) { + return at::_ops::special_scaled_modified_bessel_k1::call(x); + } + auto [x_value, x_bdim] = unwrapTensorAtLevel(x, cur_level); + auto results = batch_rule(x_value, x_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_shifted_chebyshev_polynomial_t_generated_plumbing(const at::Tensor & x, const at::Tensor & n) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(x, cur_level) && !isBatchedAtLevel(n, cur_level)) { + return at::_ops::special_shifted_chebyshev_polynomial_t::call(x, n); + } + auto [x_value, x_bdim] = unwrapTensorAtLevel(x, cur_level); + auto [n_value, n_bdim] = unwrapTensorAtLevel(n, cur_level); + auto results = batch_rule(x_value, x_bdim, n_value, n_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_shifted_chebyshev_polynomial_t_x_scalar_generated_plumbing(const at::Scalar & x, const at::Tensor & n) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(n, cur_level)) { + return at::_ops::special_shifted_chebyshev_polynomial_t_x_scalar::call(x, n); + } + auto [n_value, n_bdim] = unwrapTensorAtLevel(n, cur_level); + auto results = batch_rule(x, n_value, n_bdim); + return 
makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_shifted_chebyshev_polynomial_t_n_scalar_generated_plumbing(const at::Tensor & x, const at::Scalar & n) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(x, cur_level)) { + return at::_ops::special_shifted_chebyshev_polynomial_t_n_scalar::call(x, n); + } + auto [x_value, x_bdim] = unwrapTensorAtLevel(x, cur_level); + auto results = batch_rule(x_value, x_bdim, n); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_shifted_chebyshev_polynomial_u_generated_plumbing(const at::Tensor & x, const at::Tensor & n) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(x, cur_level) && !isBatchedAtLevel(n, cur_level)) { + return at::_ops::special_shifted_chebyshev_polynomial_u::call(x, n); + } + auto [x_value, x_bdim] = unwrapTensorAtLevel(x, cur_level); + auto [n_value, n_bdim] = unwrapTensorAtLevel(n, cur_level); + auto results = batch_rule(x_value, x_bdim, n_value, n_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_shifted_chebyshev_polynomial_u_x_scalar_generated_plumbing(const at::Scalar & x, const at::Tensor & n) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(n, cur_level)) { + return at::_ops::special_shifted_chebyshev_polynomial_u_x_scalar::call(x, n); 
+ } + auto [n_value, n_bdim] = unwrapTensorAtLevel(n, cur_level); + auto results = batch_rule(x, n_value, n_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_shifted_chebyshev_polynomial_u_n_scalar_generated_plumbing(const at::Tensor & x, const at::Scalar & n) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(x, cur_level)) { + return at::_ops::special_shifted_chebyshev_polynomial_u_n_scalar::call(x, n); + } + auto [x_value, x_bdim] = unwrapTensorAtLevel(x, cur_level); + auto results = batch_rule(x_value, x_bdim, n); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_shifted_chebyshev_polynomial_v_generated_plumbing(const at::Tensor & x, const at::Tensor & n) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(x, cur_level) && !isBatchedAtLevel(n, cur_level)) { + return at::_ops::special_shifted_chebyshev_polynomial_v::call(x, n); + } + auto [x_value, x_bdim] = unwrapTensorAtLevel(x, cur_level); + auto [n_value, n_bdim] = unwrapTensorAtLevel(n, cur_level); + auto results = batch_rule(x_value, x_bdim, n_value, n_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_shifted_chebyshev_polynomial_v_x_scalar_generated_plumbing(const at::Scalar & x, const at::Tensor & n) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = 
maybe_layer->layerId(); + if (!isBatchedAtLevel(n, cur_level)) { + return at::_ops::special_shifted_chebyshev_polynomial_v_x_scalar::call(x, n); + } + auto [n_value, n_bdim] = unwrapTensorAtLevel(n, cur_level); + auto results = batch_rule(x, n_value, n_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_shifted_chebyshev_polynomial_v_n_scalar_generated_plumbing(const at::Tensor & x, const at::Scalar & n) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(x, cur_level)) { + return at::_ops::special_shifted_chebyshev_polynomial_v_n_scalar::call(x, n); + } + auto [x_value, x_bdim] = unwrapTensorAtLevel(x, cur_level); + auto results = batch_rule(x_value, x_bdim, n); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_shifted_chebyshev_polynomial_w_generated_plumbing(const at::Tensor & x, const at::Tensor & n) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(x, cur_level) && !isBatchedAtLevel(n, cur_level)) { + return at::_ops::special_shifted_chebyshev_polynomial_w::call(x, n); + } + auto [x_value, x_bdim] = unwrapTensorAtLevel(x, cur_level); + auto [n_value, n_bdim] = unwrapTensorAtLevel(n, cur_level); + auto results = batch_rule(x_value, x_bdim, n_value, n_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_shifted_chebyshev_polynomial_w_x_scalar_generated_plumbing(const at::Scalar & x, const at::Tensor & n) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + 
auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(n, cur_level)) { + return at::_ops::special_shifted_chebyshev_polynomial_w_x_scalar::call(x, n); + } + auto [n_value, n_bdim] = unwrapTensorAtLevel(n, cur_level); + auto results = batch_rule(x, n_value, n_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_shifted_chebyshev_polynomial_w_n_scalar_generated_plumbing(const at::Tensor & x, const at::Scalar & n) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(x, cur_level)) { + return at::_ops::special_shifted_chebyshev_polynomial_w_n_scalar::call(x, n); + } + auto [x_value, x_bdim] = unwrapTensorAtLevel(x, cur_level); + auto results = batch_rule(x_value, x_bdim, n); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_spherical_bessel_j0_generated_plumbing(const at::Tensor & x) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(x, cur_level)) { + return at::_ops::special_spherical_bessel_j0::call(x); + } + auto [x_value, x_bdim] = unwrapTensorAtLevel(x, cur_level); + auto results = batch_rule(x_value, x_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _foobar_generated_plumbing(const at::Tensor & self, bool arg1, bool arg2, bool arg3) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + 
vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foobar::call(self, arg1, arg2, arg3); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, arg1, arg2, arg3); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _fused_adam__generated_plumbing(at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, double lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const ::std::optional & grad_scale, const ::std::optional & found_inf) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(grads, cur_level) && !isBatchedAtLevel(exp_avgs, cur_level) && !isBatchedAtLevel(exp_avg_sqs, cur_level) && !isBatchedAtLevel(max_exp_avg_sqs, cur_level) && !isBatchedAtLevel(state_steps, cur_level) && !isBatchedAtLevel(grad_scale, cur_level) && !isBatchedAtLevel(found_inf, cur_level)) { + return at::_ops::_fused_adam_::call(self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale, found_inf); + } + std::optional grad_scale_value; + std::optional grad_scale_bdim; + if (grad_scale) { + std::tie(grad_scale_value, grad_scale_bdim) = unwrapTensorAtLevel(grad_scale.value(), cur_level); + } + std::optional found_inf_value; + std::optional found_inf_bdim; + if (found_inf) { + std::tie(found_inf_value, found_inf_bdim) = unwrapTensorAtLevel(found_inf.value(), cur_level); + } + batch_rule(self, grads, exp_avgs, 
exp_avg_sqs, max_exp_avg_sqs, state_steps, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale_value, grad_scale_bdim, found_inf_value, found_inf_bdim); +} +template +void _fused_adam__tensor_lr_generated_plumbing(at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, const at::Tensor & lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const ::std::optional & grad_scale, const ::std::optional & found_inf) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(grads, cur_level) && !isBatchedAtLevel(exp_avgs, cur_level) && !isBatchedAtLevel(exp_avg_sqs, cur_level) && !isBatchedAtLevel(max_exp_avg_sqs, cur_level) && !isBatchedAtLevel(state_steps, cur_level) && !isBatchedAtLevel(lr, cur_level) && !isBatchedAtLevel(grad_scale, cur_level) && !isBatchedAtLevel(found_inf, cur_level)) { + return at::_ops::_fused_adam__tensor_lr::call(self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale, found_inf); + } + auto [lr_value, lr_bdim] = unwrapTensorAtLevel(lr, cur_level); + std::optional grad_scale_value; + std::optional grad_scale_bdim; + if (grad_scale) { + std::tie(grad_scale_value, grad_scale_bdim) = unwrapTensorAtLevel(grad_scale.value(), cur_level); + } + std::optional found_inf_value; + std::optional found_inf_bdim; + if (found_inf) { + std::tie(found_inf_value, found_inf_bdim) = unwrapTensorAtLevel(found_inf.value(), cur_level); + } + batch_rule(self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr_value, lr_bdim, beta1, beta2, weight_decay, eps, amsgrad, maximize, 
grad_scale_value, grad_scale_bdim, found_inf_value, found_inf_bdim); +} +template +void _fused_adamw__generated_plumbing(at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, double lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const ::std::optional & grad_scale, const ::std::optional & found_inf) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(grads, cur_level) && !isBatchedAtLevel(exp_avgs, cur_level) && !isBatchedAtLevel(exp_avg_sqs, cur_level) && !isBatchedAtLevel(max_exp_avg_sqs, cur_level) && !isBatchedAtLevel(state_steps, cur_level) && !isBatchedAtLevel(grad_scale, cur_level) && !isBatchedAtLevel(found_inf, cur_level)) { + return at::_ops::_fused_adamw_::call(self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale, found_inf); + } + std::optional grad_scale_value; + std::optional grad_scale_bdim; + if (grad_scale) { + std::tie(grad_scale_value, grad_scale_bdim) = unwrapTensorAtLevel(grad_scale.value(), cur_level); + } + std::optional found_inf_value; + std::optional found_inf_bdim; + if (found_inf) { + std::tie(found_inf_value, found_inf_bdim) = unwrapTensorAtLevel(found_inf.value(), cur_level); + } + batch_rule(self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale_value, grad_scale_bdim, found_inf_value, found_inf_bdim); +} +template +void _fused_adamw__tensor_lr_generated_plumbing(at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, 
at::TensorList state_steps, const at::Tensor & lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const ::std::optional & grad_scale, const ::std::optional & found_inf) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(grads, cur_level) && !isBatchedAtLevel(exp_avgs, cur_level) && !isBatchedAtLevel(exp_avg_sqs, cur_level) && !isBatchedAtLevel(max_exp_avg_sqs, cur_level) && !isBatchedAtLevel(state_steps, cur_level) && !isBatchedAtLevel(lr, cur_level) && !isBatchedAtLevel(grad_scale, cur_level) && !isBatchedAtLevel(found_inf, cur_level)) { + return at::_ops::_fused_adamw__tensor_lr::call(self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale, found_inf); + } + auto [lr_value, lr_bdim] = unwrapTensorAtLevel(lr, cur_level); + std::optional grad_scale_value; + std::optional grad_scale_bdim; + if (grad_scale) { + std::tie(grad_scale_value, grad_scale_bdim) = unwrapTensorAtLevel(grad_scale.value(), cur_level); + } + std::optional found_inf_value; + std::optional found_inf_bdim; + if (found_inf) { + std::tie(found_inf_value, found_inf_bdim) = unwrapTensorAtLevel(found_inf.value(), cur_level); + } + batch_rule(self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr_value, lr_bdim, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale_value, grad_scale_bdim, found_inf_value, found_inf_bdim); +} +template +void _fused_sgd__generated_plumbing(at::TensorList self, at::TensorList grads, at::TensorList momentum_buffer_list, double weight_decay, double momentum, double lr, double dampening, bool nesterov, bool maximize, bool is_first_step, const ::std::optional & grad_scale, const 
::std::optional & found_inf) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(grads, cur_level) && !isBatchedAtLevel(momentum_buffer_list, cur_level) && !isBatchedAtLevel(grad_scale, cur_level) && !isBatchedAtLevel(found_inf, cur_level)) { + return at::_ops::_fused_sgd_::call(self, grads, momentum_buffer_list, weight_decay, momentum, lr, dampening, nesterov, maximize, is_first_step, grad_scale, found_inf); + } + std::optional grad_scale_value; + std::optional grad_scale_bdim; + if (grad_scale) { + std::tie(grad_scale_value, grad_scale_bdim) = unwrapTensorAtLevel(grad_scale.value(), cur_level); + } + std::optional found_inf_value; + std::optional found_inf_bdim; + if (found_inf) { + std::tie(found_inf_value, found_inf_bdim) = unwrapTensorAtLevel(found_inf.value(), cur_level); + } + batch_rule(self, grads, momentum_buffer_list, weight_decay, momentum, lr, dampening, nesterov, maximize, is_first_step, grad_scale_value, grad_scale_bdim, found_inf_value, found_inf_bdim); +} +template +void _fused_sgd__tensor_lr_generated_plumbing(at::TensorList self, at::TensorList grads, at::TensorList momentum_buffer_list, double weight_decay, double momentum, const at::Tensor & lr, double dampening, bool nesterov, bool maximize, bool is_first_step, const ::std::optional & grad_scale, const ::std::optional & found_inf) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(grads, cur_level) && !isBatchedAtLevel(momentum_buffer_list, cur_level) && !isBatchedAtLevel(lr, cur_level) && 
!isBatchedAtLevel(grad_scale, cur_level) && !isBatchedAtLevel(found_inf, cur_level)) { + return at::_ops::_fused_sgd__tensor_lr::call(self, grads, momentum_buffer_list, weight_decay, momentum, lr, dampening, nesterov, maximize, is_first_step, grad_scale, found_inf); + } + auto [lr_value, lr_bdim] = unwrapTensorAtLevel(lr, cur_level); + std::optional grad_scale_value; + std::optional grad_scale_bdim; + if (grad_scale) { + std::tie(grad_scale_value, grad_scale_bdim) = unwrapTensorAtLevel(grad_scale.value(), cur_level); + } + std::optional found_inf_value; + std::optional found_inf_bdim; + if (found_inf) { + std::tie(found_inf_value, found_inf_bdim) = unwrapTensorAtLevel(found_inf.value(), cur_level); + } + batch_rule(self, grads, momentum_buffer_list, weight_decay, momentum, lr_value, lr_bdim, dampening, nesterov, maximize, is_first_step, grad_scale_value, grad_scale_bdim, found_inf_value, found_inf_bdim); +} +template +void _fused_adagrad__generated_plumbing(at::TensorList self, at::TensorList grads, at::TensorList state_sums, at::TensorList state_steps, double lr, double lr_decay, double weight_decay, double eps, bool maximize, const ::std::optional & grad_scale, const ::std::optional & found_inf) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(grads, cur_level) && !isBatchedAtLevel(state_sums, cur_level) && !isBatchedAtLevel(state_steps, cur_level) && !isBatchedAtLevel(grad_scale, cur_level) && !isBatchedAtLevel(found_inf, cur_level)) { + return at::_ops::_fused_adagrad_::call(self, grads, state_sums, state_steps, lr, lr_decay, weight_decay, eps, maximize, grad_scale, found_inf); + } + std::optional grad_scale_value; + std::optional grad_scale_bdim; + if (grad_scale) { + std::tie(grad_scale_value, grad_scale_bdim) = 
unwrapTensorAtLevel(grad_scale.value(), cur_level); + } + std::optional found_inf_value; + std::optional found_inf_bdim; + if (found_inf) { + std::tie(found_inf_value, found_inf_bdim) = unwrapTensorAtLevel(found_inf.value(), cur_level); + } + batch_rule(self, grads, state_sums, state_steps, lr, lr_decay, weight_decay, eps, maximize, grad_scale_value, grad_scale_bdim, found_inf_value, found_inf_bdim); +} +template +void _fused_adagrad__tensor_lr_generated_plumbing(at::TensorList self, at::TensorList grads, at::TensorList state_sums, at::TensorList state_steps, const at::Tensor & lr, double lr_decay, double weight_decay, double eps, bool maximize, const ::std::optional & grad_scale, const ::std::optional & found_inf) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(grads, cur_level) && !isBatchedAtLevel(state_sums, cur_level) && !isBatchedAtLevel(state_steps, cur_level) && !isBatchedAtLevel(lr, cur_level) && !isBatchedAtLevel(grad_scale, cur_level) && !isBatchedAtLevel(found_inf, cur_level)) { + return at::_ops::_fused_adagrad__tensor_lr::call(self, grads, state_sums, state_steps, lr, lr_decay, weight_decay, eps, maximize, grad_scale, found_inf); + } + auto [lr_value, lr_bdim] = unwrapTensorAtLevel(lr, cur_level); + std::optional grad_scale_value; + std::optional grad_scale_bdim; + if (grad_scale) { + std::tie(grad_scale_value, grad_scale_bdim) = unwrapTensorAtLevel(grad_scale.value(), cur_level); + } + std::optional found_inf_value; + std::optional found_inf_bdim; + if (found_inf) { + std::tie(found_inf_value, found_inf_bdim) = unwrapTensorAtLevel(found_inf.value(), cur_level); + } + batch_rule(self, grads, state_sums, state_steps, lr_value, lr_bdim, lr_decay, weight_decay, eps, maximize, grad_scale_value, 
grad_scale_bdim, found_inf_value, found_inf_bdim); +} +template +void _propagate_xla_data_generated_plumbing(const at::Tensor & input, const at::Tensor & output) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(output, cur_level)) { + return at::_ops::_propagate_xla_data::call(input, output); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [output_value, output_bdim] = unwrapTensorAtLevel(output, cur_level); + batch_rule(input_value, input_bdim, output_value, output_bdim); +} +template +void _cudnn_rnn_backward_out_generated_plumbing(const at::Tensor & input, at::TensorList weight, int64_t weight_stride0, const at::Tensor & weight_buf, const at::Tensor & hx, const ::std::optional & cx, const at::Tensor & output, const ::std::optional & grad_output, const ::std::optional & grad_hy, const ::std::optional & grad_cy, int64_t mode, c10::SymInt hidden_size, c10::SymInt proj_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, c10::SymIntArrayRef batch_sizes, const ::std::optional & dropout_state, const at::Tensor & reserve, ::std::array output_mask, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::TensorList out3) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(weight_buf, cur_level) && !isBatchedAtLevel(hx, cur_level) && !isBatchedAtLevel(cx, cur_level) && !isBatchedAtLevel(output, cur_level) && !isBatchedAtLevel(grad_output, cur_level) && 
!isBatchedAtLevel(grad_hy, cur_level) && !isBatchedAtLevel(grad_cy, cur_level) && !isBatchedAtLevel(dropout_state, cur_level) && !isBatchedAtLevel(reserve, cur_level) && !isBatchedAtLevel(out0, cur_level) && !isBatchedAtLevel(out1, cur_level) && !isBatchedAtLevel(out2, cur_level) && !isBatchedAtLevel(out3, cur_level)) { + return at::_ops::_cudnn_rnn_backward_out::call(input, weight, weight_stride0, weight_buf, hx, cx, output, grad_output, grad_hy, grad_cy, mode, hidden_size, proj_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state, reserve, output_mask, out0, out1, out2, out3); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [weight_buf_value, weight_buf_bdim] = unwrapTensorAtLevel(weight_buf, cur_level); + auto [hx_value, hx_bdim] = unwrapTensorAtLevel(hx, cur_level); + auto [output_value, output_bdim] = unwrapTensorAtLevel(output, cur_level); + auto [reserve_value, reserve_bdim] = unwrapTensorAtLevel(reserve, cur_level); + auto [out0_value, out0_bdim] = unwrapTensorAtLevel(out0, cur_level); + auto [out1_value, out1_bdim] = unwrapTensorAtLevel(out1, cur_level); + auto [out2_value, out2_bdim] = unwrapTensorAtLevel(out2, cur_level); + std::optional cx_value; + std::optional cx_bdim; + if (cx) { + std::tie(cx_value, cx_bdim) = unwrapTensorAtLevel(cx.value(), cur_level); + } + std::optional grad_output_value; + std::optional grad_output_bdim; + if (grad_output) { + std::tie(grad_output_value, grad_output_bdim) = unwrapTensorAtLevel(grad_output.value(), cur_level); + } + std::optional grad_hy_value; + std::optional grad_hy_bdim; + if (grad_hy) { + std::tie(grad_hy_value, grad_hy_bdim) = unwrapTensorAtLevel(grad_hy.value(), cur_level); + } + std::optional grad_cy_value; + std::optional grad_cy_bdim; + if (grad_cy) { + std::tie(grad_cy_value, grad_cy_bdim) = unwrapTensorAtLevel(grad_cy.value(), cur_level); + } + std::optional dropout_state_value; + std::optional dropout_state_bdim; + if 
(dropout_state) { + std::tie(dropout_state_value, dropout_state_bdim) = unwrapTensorAtLevel(dropout_state.value(), cur_level); + } + batch_rule(input_value, input_bdim, weight, weight_stride0, weight_buf_value, weight_buf_bdim, hx_value, hx_bdim, cx_value, cx_bdim, output_value, output_bdim, grad_output_value, grad_output_bdim, grad_hy_value, grad_hy_bdim, grad_cy_value, grad_cy_bdim, mode, hidden_size, proj_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state_value, dropout_state_bdim, reserve_value, reserve_bdim, output_mask, out0_value, out0_bdim, out1_value, out1_bdim, out2_value, out2_bdim, out3); +} +template +at::Tensor bernoulli_Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & p, ::std::optional generator) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(p, cur_level)) { + return at::_ops::bernoulli_Tensor::call(self, p, generator); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [p_value, p_bdim] = unwrapTensorAtLevel(p, cur_level); + auto results = batch_rule(self_value, self_bdim, p_value, p_bdim, generator); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor embedding_renorm_generated_plumbing(const at::Tensor & self, const at::Tensor & indices, double max_norm, double norm_type) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(indices, cur_level)) { + return at::_ops::embedding_renorm::call(self, indices, max_norm, norm_type); + } + auto 
[self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [indices_value, indices_bdim] = unwrapTensorAtLevel(indices, cur_level); + auto results = batch_rule(self_value, self_bdim, indices_value, indices_bdim, max_norm, norm_type); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor resize_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef size, ::std::optional memory_format) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::resize::call(self, size, memory_format); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, size, memory_format); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _resize_output_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef size, at::Device device) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_resize_output::call(self, size, device); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, size, device); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _index_put_impl_generated_plumbing(const at::Tensor & self, const c10::List<::std::optional> & indices, const at::Tensor & values, bool accumulate, bool unsafe) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = 
maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(indices, cur_level) && !isBatchedAtLevel(values, cur_level)) { + return at::_ops::_index_put_impl::call(self, indices, values, accumulate, unsafe); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [values_value, values_bdim] = unwrapTensorAtLevel(values, cur_level); + auto results = batch_rule(self_value, self_bdim, indices, values_value, values_bdim, accumulate, unsafe); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void miopen_rnn_backward_out_generated_plumbing(const at::Tensor & input, at::TensorList weight, int64_t weight_stride0, const at::Tensor & weight_buf, const at::Tensor & hx, const ::std::optional & cx, const at::Tensor & output, const ::std::optional & grad_output, const ::std::optional & grad_hy, const ::std::optional & grad_cy, int64_t mode, int64_t hidden_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, at::IntArrayRef batch_sizes, const ::std::optional & dropout_state, const at::Tensor & reserve, ::std::array output_mask, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::TensorList out3) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(weight_buf, cur_level) && !isBatchedAtLevel(hx, cur_level) && !isBatchedAtLevel(cx, cur_level) && !isBatchedAtLevel(output, cur_level) && !isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(grad_hy, cur_level) && !isBatchedAtLevel(grad_cy, cur_level) && !isBatchedAtLevel(dropout_state, cur_level) && 
!isBatchedAtLevel(reserve, cur_level) && !isBatchedAtLevel(out0, cur_level) && !isBatchedAtLevel(out1, cur_level) && !isBatchedAtLevel(out2, cur_level) && !isBatchedAtLevel(out3, cur_level)) { + return at::_ops::miopen_rnn_backward_out::call(input, weight, weight_stride0, weight_buf, hx, cx, output, grad_output, grad_hy, grad_cy, mode, hidden_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state, reserve, output_mask, out0, out1, out2, out3); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [weight_buf_value, weight_buf_bdim] = unwrapTensorAtLevel(weight_buf, cur_level); + auto [hx_value, hx_bdim] = unwrapTensorAtLevel(hx, cur_level); + auto [output_value, output_bdim] = unwrapTensorAtLevel(output, cur_level); + auto [reserve_value, reserve_bdim] = unwrapTensorAtLevel(reserve, cur_level); + auto [out0_value, out0_bdim] = unwrapTensorAtLevel(out0, cur_level); + auto [out1_value, out1_bdim] = unwrapTensorAtLevel(out1, cur_level); + auto [out2_value, out2_bdim] = unwrapTensorAtLevel(out2, cur_level); + std::optional cx_value; + std::optional cx_bdim; + if (cx) { + std::tie(cx_value, cx_bdim) = unwrapTensorAtLevel(cx.value(), cur_level); + } + std::optional grad_output_value; + std::optional grad_output_bdim; + if (grad_output) { + std::tie(grad_output_value, grad_output_bdim) = unwrapTensorAtLevel(grad_output.value(), cur_level); + } + std::optional grad_hy_value; + std::optional grad_hy_bdim; + if (grad_hy) { + std::tie(grad_hy_value, grad_hy_bdim) = unwrapTensorAtLevel(grad_hy.value(), cur_level); + } + std::optional grad_cy_value; + std::optional grad_cy_bdim; + if (grad_cy) { + std::tie(grad_cy_value, grad_cy_bdim) = unwrapTensorAtLevel(grad_cy.value(), cur_level); + } + std::optional dropout_state_value; + std::optional dropout_state_bdim; + if (dropout_state) { + std::tie(dropout_state_value, dropout_state_bdim) = unwrapTensorAtLevel(dropout_state.value(), cur_level); + } + 
batch_rule(input_value, input_bdim, weight, weight_stride0, weight_buf_value, weight_buf_bdim, hx_value, hx_bdim, cx_value, cx_bdim, output_value, output_bdim, grad_output_value, grad_output_bdim, grad_hy_value, grad_hy_bdim, grad_cy_value, grad_cy_bdim, mode, hidden_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state_value, dropout_state_bdim, reserve_value, reserve_bdim, output_mask, out0_value, out0_bdim, out1_value, out1_bdim, out2_value, out2_bdim, out3); +} +template +::std::tuple _native_batch_norm_legit_functional_generated_plumbing(const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const at::Tensor & running_mean, const at::Tensor & running_var, bool training, double momentum, double eps) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level) && !isBatchedAtLevel(running_mean, cur_level) && !isBatchedAtLevel(running_var, cur_level)) { + return at::_ops::_native_batch_norm_legit_functional::call(input, weight, bias, running_mean, running_var, training, momentum, eps); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [running_mean_value, running_mean_bdim] = unwrapTensorAtLevel(running_mean, cur_level); + auto [running_var_value, running_var_bdim] = unwrapTensorAtLevel(running_var, cur_level); + std::optional weight_value; + std::optional weight_bdim; + if (weight) { + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight.value(), cur_level); + } + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, 
weight_value, weight_bdim, bias_value, bias_bdim, running_mean_value, running_mean_bdim, running_var_value, running_var_bdim, training, momentum, eps); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level), makeBatched(std::get<6>(results), std::get<7>(results), cur_level), makeBatched(std::get<8>(results), std::get<9>(results), cur_level)); +} +template +void unsafe_split_Tensor_out_generated_plumbing(const at::Tensor & self, c10::SymInt split_size, int64_t dim, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::unsafe_split_Tensor_out::call(self, split_size, dim, out); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, split_size, dim, out); +} +template +void unsafe_split_with_sizes_out_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef split_sizes, int64_t dim, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::unsafe_split_with_sizes_out::call(self, split_sizes, dim, out); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, split_sizes, dim, out); +} +template +::std::tuple _batch_norm_with_update_functional_generated_plumbing(const at::Tensor & input, 
const ::std::optional & weight, const ::std::optional & bias, const at::Tensor & running_mean, const at::Tensor & running_var, double momentum, double eps) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level) && !isBatchedAtLevel(running_mean, cur_level) && !isBatchedAtLevel(running_var, cur_level)) { + return at::_ops::_batch_norm_with_update_functional::call(input, weight, bias, running_mean, running_var, momentum, eps); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [running_mean_value, running_mean_bdim] = unwrapTensorAtLevel(running_mean, cur_level); + auto [running_var_value, running_var_bdim] = unwrapTensorAtLevel(running_var, cur_level); + std::optional weight_value; + std::optional weight_bdim; + if (weight) { + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight.value(), cur_level); + } + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, weight_value, weight_bdim, bias_value, bias_bdim, running_mean_value, running_mean_bdim, running_var_value, running_var_bdim, momentum, eps); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level), makeBatched(std::get<6>(results), std::get<7>(results), cur_level), makeBatched(std::get<8>(results), std::get<9>(results), cur_level), makeBatched(std::get<10>(results), std::get<11>(results), cur_level)); +} +template +at::Tensor resize_as_generated_plumbing(const 
at::Tensor & self, const at::Tensor & the_template, ::std::optional memory_format) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(the_template, cur_level)) { + return at::_ops::resize_as::call(self, the_template, memory_format); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [the_template_value, the_template_bdim] = unwrapTensorAtLevel(the_template, cur_level); + auto results = batch_rule(self_value, self_bdim, the_template_value, the_template_bdim, memory_format); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor resize_as_sparse_generated_plumbing(const at::Tensor & self, const at::Tensor & the_template) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(the_template, cur_level)) { + return at::_ops::resize_as_sparse::call(self, the_template); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [the_template_value, the_template_bdim] = unwrapTensorAtLevel(the_template, cur_level); + auto results = batch_rule(self_value, self_bdim, the_template_value, the_template_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor zero_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, 
cur_level)) { + return at::_ops::zero::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor sparse_resize_generated_plumbing(const at::Tensor & self, at::IntArrayRef size, int64_t sparse_dim, int64_t dense_dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::sparse_resize::call(self, size, sparse_dim, dense_dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, size, sparse_dim, dense_dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor sparse_resize_and_clear_generated_plumbing(const at::Tensor & self, at::IntArrayRef size, int64_t sparse_dim, int64_t dense_dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::sparse_resize_and_clear::call(self, size, sparse_dim, dense_dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, size, sparse_dim, dense_dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _coalesced_generated_plumbing(const at::Tensor & self, bool coalesced) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, 
"gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_coalesced::call(self, coalesced); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, coalesced); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor copy_sparse_to_sparse_generated_plumbing(const at::Tensor & self, const at::Tensor & src, bool non_blocking) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(src, cur_level)) { + return at::_ops::copy_sparse_to_sparse::call(self, src, non_blocking); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [src_value, src_bdim] = unwrapTensorAtLevel(src, cur_level); + auto results = batch_rule(self_value, self_bdim, src_value, src_bdim, non_blocking); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void quantize_per_tensor_tensors_out_generated_plumbing(at::TensorList tensors, const at::Tensor & scales, const at::Tensor & zero_points, at::ScalarType dtype, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(tensors, cur_level) && !isBatchedAtLevel(scales, cur_level) && !isBatchedAtLevel(zero_points, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::quantize_per_tensor_tensors_out::call(tensors, scales, zero_points, dtype, out); + } + auto [scales_value, scales_bdim] = unwrapTensorAtLevel(scales, cur_level); + auto 
[zero_points_value, zero_points_bdim] = unwrapTensorAtLevel(zero_points, cur_level); + batch_rule(tensors, scales_value, scales_bdim, zero_points_value, zero_points_bdim, dtype, out); +} +template +void dequantize_tensors_out_generated_plumbing(at::TensorList tensors, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(tensors, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::dequantize_tensors_out::call(tensors, out); + } + + batch_rule(tensors, out); +} +template +::std::tuple _fused_moving_avg_obs_fq_helper_functional_generated_plumbing(const at::Tensor & self, const at::Tensor & observer_on, const at::Tensor & fake_quant_on, const at::Tensor & running_min, const at::Tensor & running_max, const at::Tensor & scale, const at::Tensor & zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, bool per_row_fake_quant, bool symmetric_quant) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(observer_on, cur_level) && !isBatchedAtLevel(fake_quant_on, cur_level) && !isBatchedAtLevel(running_min, cur_level) && !isBatchedAtLevel(running_max, cur_level) && !isBatchedAtLevel(scale, cur_level) && !isBatchedAtLevel(zero_point, cur_level)) { + return at::_ops::_fused_moving_avg_obs_fq_helper_functional::call(self, observer_on, fake_quant_on, running_min, running_max, scale, zero_point, averaging_const, quant_min, quant_max, ch_axis, per_row_fake_quant, symmetric_quant); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto 
[observer_on_value, observer_on_bdim] = unwrapTensorAtLevel(observer_on, cur_level); + auto [fake_quant_on_value, fake_quant_on_bdim] = unwrapTensorAtLevel(fake_quant_on, cur_level); + auto [running_min_value, running_min_bdim] = unwrapTensorAtLevel(running_min, cur_level); + auto [running_max_value, running_max_bdim] = unwrapTensorAtLevel(running_max, cur_level); + auto [scale_value, scale_bdim] = unwrapTensorAtLevel(scale, cur_level); + auto [zero_point_value, zero_point_bdim] = unwrapTensorAtLevel(zero_point, cur_level); + auto results = batch_rule(self_value, self_bdim, observer_on_value, observer_on_bdim, fake_quant_on_value, fake_quant_on_bdim, running_min_value, running_min_bdim, running_max_value, running_max_bdim, scale_value, scale_bdim, zero_point_value, zero_point_bdim, averaging_const, quant_min, quant_max, ch_axis, per_row_fake_quant, symmetric_quant); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level), makeBatched(std::get<6>(results), std::get<7>(results), cur_level), makeBatched(std::get<8>(results), std::get<9>(results), cur_level), makeBatched(std::get<10>(results), std::get<11>(results), cur_level)); +} +template +void lstm_mps_backward_out_generated_plumbing(const ::std::optional & grad_y, const ::std::optional & grad_hy, const ::std::optional & grad_cy, const at::Tensor & z_state, const at::Tensor & cell_state_fwd, const at::Tensor & input, const at::Tensor & layersOutputs, at::TensorList hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional, bool batch_first, at::Tensor & out0, at::TensorList out1, at::TensorList out2) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + 
int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_y, cur_level) && !isBatchedAtLevel(grad_hy, cur_level) && !isBatchedAtLevel(grad_cy, cur_level) && !isBatchedAtLevel(z_state, cur_level) && !isBatchedAtLevel(cell_state_fwd, cur_level) && !isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(layersOutputs, cur_level) && !isBatchedAtLevel(hx, cur_level) && !isBatchedAtLevel(params, cur_level) && !isBatchedAtLevel(out0, cur_level) && !isBatchedAtLevel(out1, cur_level) && !isBatchedAtLevel(out2, cur_level)) { + return at::_ops::lstm_mps_backward_out::call(grad_y, grad_hy, grad_cy, z_state, cell_state_fwd, input, layersOutputs, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first, out0, out1, out2); + } + auto [z_state_value, z_state_bdim] = unwrapTensorAtLevel(z_state, cur_level); + auto [cell_state_fwd_value, cell_state_fwd_bdim] = unwrapTensorAtLevel(cell_state_fwd, cur_level); + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [layersOutputs_value, layersOutputs_bdim] = unwrapTensorAtLevel(layersOutputs, cur_level); + auto [out0_value, out0_bdim] = unwrapTensorAtLevel(out0, cur_level); + std::optional grad_y_value; + std::optional grad_y_bdim; + if (grad_y) { + std::tie(grad_y_value, grad_y_bdim) = unwrapTensorAtLevel(grad_y.value(), cur_level); + } + std::optional grad_hy_value; + std::optional grad_hy_bdim; + if (grad_hy) { + std::tie(grad_hy_value, grad_hy_bdim) = unwrapTensorAtLevel(grad_hy.value(), cur_level); + } + std::optional grad_cy_value; + std::optional grad_cy_bdim; + if (grad_cy) { + std::tie(grad_cy_value, grad_cy_bdim) = unwrapTensorAtLevel(grad_cy.value(), cur_level); + } + batch_rule(grad_y_value, grad_y_bdim, grad_hy_value, grad_hy_bdim, grad_cy_value, grad_cy_bdim, z_state_value, z_state_bdim, cell_state_fwd_value, cell_state_fwd_bdim, input_value, input_bdim, layersOutputs_value, layersOutputs_bdim, hx, params, has_biases, num_layers, dropout, train, bidirectional, 
batch_first, out0_value, out0_bdim, out1, out2); +} +template +at::Tensor set_source_Storage_generated_plumbing(const at::Tensor & self, at::Storage source) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::set_source_Storage::call(self, source); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, source); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor set_source_Storage_storage_offset_generated_plumbing(const at::Tensor & self, at::Storage source, c10::SymInt storage_offset, c10::SymIntArrayRef size, c10::SymIntArrayRef stride) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::set_source_Storage_storage_offset::call(self, source, storage_offset, size, stride); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, source, storage_offset, size, stride); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor set_source_Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & source) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(source, cur_level)) { + return at::_ops::set_source_Tensor::call(self, 
source); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [source_value, source_bdim] = unwrapTensorAtLevel(source, cur_level); + auto results = batch_rule(self_value, self_bdim, source_value, source_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor set_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::set::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor random_from_generated_plumbing(const at::Tensor & self, int64_t from, ::std::optional to, ::std::optional generator) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::random_from::call(self, from, to, generator); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, from, to, generator); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor random_to_generated_plumbing(const at::Tensor & self, int64_t to, ::std::optional generator) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + 
return at::_ops::random_to::call(self, to, generator); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, to, generator); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor random_generated_plumbing(const at::Tensor & self, ::std::optional generator) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::random::call(self, generator); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, generator); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor uniform_generated_plumbing(const at::Tensor & self, double from, double to, ::std::optional generator) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::uniform::call(self, from, to, generator); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, from, to, generator); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor cauchy_generated_plumbing(const at::Tensor & self, double median, double sigma, ::std::optional generator) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, 
cur_level)) { + return at::_ops::cauchy::call(self, median, sigma, generator); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, median, sigma, generator); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor log_normal_generated_plumbing(const at::Tensor & self, double mean, double std, ::std::optional generator) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::log_normal::call(self, mean, std, generator); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, mean, std, generator); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor exponential_generated_plumbing(const at::Tensor & self, double lambd, ::std::optional generator) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::exponential::call(self, lambd, generator); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, lambd, generator); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor geometric_generated_plumbing(const at::Tensor & self, double p, ::std::optional generator) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level 
= maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::geometric::call(self, p, generator); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, p, generator); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _histogramdd_bin_edges_out_generated_plumbing(const at::Tensor & self, at::IntArrayRef bins, ::std::optional> range, const ::std::optional & weight, bool density, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_histogramdd_bin_edges_out::call(self, bins, range, weight, density, out); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + std::optional weight_value; + std::optional weight_bdim; + if (weight) { + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight.value(), cur_level); + } + batch_rule(self_value, self_bdim, bins, range, weight_value, weight_bdim, density, out); +} +template +void _amp_foreach_non_finite_check_and_unscale_out_generated_plumbing(at::TensorList self, at::Tensor & found_inf, const at::Tensor & inv_scale, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(found_inf, cur_level) && !isBatchedAtLevel(inv_scale, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_amp_foreach_non_finite_check_and_unscale_out::call(self, 
found_inf, inv_scale, out); + } + auto [found_inf_value, found_inf_bdim] = unwrapTensorAtLevel(found_inf, cur_level); + auto [inv_scale_value, inv_scale_bdim] = unwrapTensorAtLevel(inv_scale, cur_level); + batch_rule(self, found_inf_value, found_inf_bdim, inv_scale_value, inv_scale_bdim, out); +} +template +::std::tuple<::std::vector,at::Tensor> _amp_foreach_non_finite_check_and_unscale_generated_plumbing(at::TensorList self, const at::Tensor & found_inf, const at::Tensor & inv_scale) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(found_inf, cur_level) && !isBatchedAtLevel(inv_scale, cur_level)) { + return at::_ops::_amp_foreach_non_finite_check_and_unscale::call(self, found_inf, inv_scale); + } + auto [found_inf_value, found_inf_bdim] = unwrapTensorAtLevel(found_inf, cur_level); + auto [inv_scale_value, inv_scale_bdim] = unwrapTensorAtLevel(inv_scale, cur_level); + auto results = batch_rule(self, found_inf_value, found_inf_bdim, inv_scale_value, inv_scale_bdim); + return std::make_tuple(makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple _amp_update_scale_generated_plumbing(const at::Tensor & self, const at::Tensor & growth_tracker, const at::Tensor & found_inf, double scale_growth_factor, double scale_backoff_factor, int64_t growth_interval) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(growth_tracker, cur_level) && !isBatchedAtLevel(found_inf, cur_level)) { + return 
at::_ops::_amp_update_scale::call(self, growth_tracker, found_inf, scale_growth_factor, scale_backoff_factor, growth_interval); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [growth_tracker_value, growth_tracker_bdim] = unwrapTensorAtLevel(growth_tracker, cur_level); + auto [found_inf_value, found_inf_bdim] = unwrapTensorAtLevel(found_inf, cur_level); + auto results = batch_rule(self_value, self_bdim, growth_tracker_value, growth_tracker_bdim, found_inf_value, found_inf_bdim, scale_growth_factor, scale_backoff_factor, growth_interval); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +void _foreach_add_Scalar_out_generated_plumbing(at::TensorList self, const at::Scalar & scalar, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_add_Scalar_out::call(self, scalar, out); + } + + batch_rule(self, scalar, out); +} +template +void _foreach_add_List_out_generated_plumbing(at::TensorList self, at::TensorList other, const at::Scalar & alpha, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_add_List_out::call(self, other, alpha, out); + } + + batch_rule(self, other, alpha, out); +} +template +void 
_foreach_add_ScalarList_out_generated_plumbing(at::TensorList self, at::ArrayRef scalars, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_add_ScalarList_out::call(self, scalars, out); + } + + batch_rule(self, scalars, out); +} +template +void _foreach_add_Tensor_out_generated_plumbing(at::TensorList self, const at::Tensor & other, const at::Scalar & alpha, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_add_Tensor_out::call(self, other, alpha, out); + } + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self, other_value, other_bdim, alpha, out); +} +template +void _foreach_sub_Scalar_out_generated_plumbing(at::TensorList self, const at::Scalar & scalar, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_sub_Scalar_out::call(self, scalar, out); + } + + batch_rule(self, scalar, out); +} +template +void _foreach_sub_List_out_generated_plumbing(at::TensorList self, at::TensorList other, const at::Scalar & alpha, at::TensorList out) { + 
c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_sub_List_out::call(self, other, alpha, out); + } + + batch_rule(self, other, alpha, out); +} +template +void _foreach_sub_ScalarList_out_generated_plumbing(at::TensorList self, at::ArrayRef scalars, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_sub_ScalarList_out::call(self, scalars, out); + } + + batch_rule(self, scalars, out); +} +template +void _foreach_mul_Scalar_out_generated_plumbing(at::TensorList self, const at::Scalar & scalar, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_mul_Scalar_out::call(self, scalar, out); + } + + batch_rule(self, scalar, out); +} +template +void _foreach_mul_List_out_generated_plumbing(at::TensorList self, at::TensorList other, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && 
!isBatchedAtLevel(other, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_mul_List_out::call(self, other, out); + } + + batch_rule(self, other, out); +} +template +void _foreach_mul_ScalarList_out_generated_plumbing(at::TensorList self, at::ArrayRef scalars, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_mul_ScalarList_out::call(self, scalars, out); + } + + batch_rule(self, scalars, out); +} +template +void _foreach_mul_Tensor_out_generated_plumbing(at::TensorList self, const at::Tensor & other, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_mul_Tensor_out::call(self, other, out); + } + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self, other_value, other_bdim, out); +} +template +void _foreach_div_Scalar_out_generated_plumbing(at::TensorList self, const at::Scalar & scalar, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_div_Scalar_out::call(self, scalar, out); + } + + batch_rule(self, scalar, out); +} +template 
+void _foreach_div_List_out_generated_plumbing(at::TensorList self, at::TensorList other, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_div_List_out::call(self, other, out); + } + + batch_rule(self, other, out); +} +template +void _foreach_div_ScalarList_out_generated_plumbing(at::TensorList self, at::ArrayRef scalars, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_div_ScalarList_out::call(self, scalars, out); + } + + batch_rule(self, scalars, out); +} +template +void _foreach_div_Tensor_out_generated_plumbing(at::TensorList self, const at::Tensor & other, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_div_Tensor_out::call(self, other, out); + } + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self, other_value, other_bdim, out); +} +template +void _foreach_clamp_max_Scalar_out_generated_plumbing(at::TensorList self, const at::Scalar & scalar, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard 
guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_clamp_max_Scalar_out::call(self, scalar, out); + } + + batch_rule(self, scalar, out); +} +template +void _foreach_clamp_max_List_out_generated_plumbing(at::TensorList self, at::TensorList other, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_clamp_max_List_out::call(self, other, out); + } + + batch_rule(self, other, out); +} +template +void _foreach_clamp_max_ScalarList_out_generated_plumbing(at::TensorList self, at::ArrayRef scalars, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_clamp_max_ScalarList_out::call(self, scalars, out); + } + + batch_rule(self, scalars, out); +} +template +void _foreach_clamp_min_Scalar_out_generated_plumbing(at::TensorList self, const at::Scalar & scalar, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && 
!isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_clamp_min_Scalar_out::call(self, scalar, out); + } + + batch_rule(self, scalar, out); +} +template +void _foreach_clamp_min_List_out_generated_plumbing(at::TensorList self, at::TensorList other, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_clamp_min_List_out::call(self, other, out); + } + + batch_rule(self, other, out); +} +template +void _foreach_clamp_min_ScalarList_out_generated_plumbing(at::TensorList self, at::ArrayRef scalars, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_clamp_min_ScalarList_out::call(self, scalars, out); + } + + batch_rule(self, scalars, out); +} +template +void _foreach_maximum_Scalar_out_generated_plumbing(at::TensorList self, const at::Scalar & scalar, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_maximum_Scalar_out::call(self, scalar, out); + } + + batch_rule(self, scalar, out); +} +template +void _foreach_maximum_List_out_generated_plumbing(at::TensorList self, at::TensorList other, 
at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_maximum_List_out::call(self, other, out); + } + + batch_rule(self, other, out); +} +template +void _foreach_maximum_ScalarList_out_generated_plumbing(at::TensorList self, at::ArrayRef scalars, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_maximum_ScalarList_out::call(self, scalars, out); + } + + batch_rule(self, scalars, out); +} +template +void _foreach_minimum_Scalar_out_generated_plumbing(at::TensorList self, const at::Scalar & scalar, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_minimum_Scalar_out::call(self, scalar, out); + } + + batch_rule(self, scalar, out); +} +template +void _foreach_minimum_List_out_generated_plumbing(at::TensorList self, at::TensorList other, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if 
(!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_minimum_List_out::call(self, other, out); + } + + batch_rule(self, other, out); +} +template +void _foreach_minimum_ScalarList_out_generated_plumbing(at::TensorList self, at::ArrayRef scalars, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_minimum_ScalarList_out::call(self, scalars, out); + } + + batch_rule(self, scalars, out); +} +template +void _foreach_addcdiv_Scalar_out_generated_plumbing(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Scalar & value, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(tensor1, cur_level) && !isBatchedAtLevel(tensor2, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_addcdiv_Scalar_out::call(self, tensor1, tensor2, value, out); + } + + batch_rule(self, tensor1, tensor2, value, out); +} +template +void _foreach_addcdiv_ScalarList_out_generated_plumbing(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, at::ArrayRef scalars, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && 
!isBatchedAtLevel(tensor1, cur_level) && !isBatchedAtLevel(tensor2, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_addcdiv_ScalarList_out::call(self, tensor1, tensor2, scalars, out); + } + + batch_rule(self, tensor1, tensor2, scalars, out); +} +template +void _foreach_addcdiv_Tensor_out_generated_plumbing(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Tensor & scalars, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(tensor1, cur_level) && !isBatchedAtLevel(tensor2, cur_level) && !isBatchedAtLevel(scalars, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_addcdiv_Tensor_out::call(self, tensor1, tensor2, scalars, out); + } + auto [scalars_value, scalars_bdim] = unwrapTensorAtLevel(scalars, cur_level); + batch_rule(self, tensor1, tensor2, scalars_value, scalars_bdim, out); +} +template +void _foreach_addcmul_Scalar_out_generated_plumbing(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Scalar & value, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(tensor1, cur_level) && !isBatchedAtLevel(tensor2, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_addcmul_Scalar_out::call(self, tensor1, tensor2, value, out); + } + + batch_rule(self, tensor1, tensor2, value, out); +} +template +void _foreach_addcmul_ScalarList_out_generated_plumbing(at::TensorList self, at::TensorList tensor1, 
at::TensorList tensor2, at::ArrayRef scalars, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(tensor1, cur_level) && !isBatchedAtLevel(tensor2, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_addcmul_ScalarList_out::call(self, tensor1, tensor2, scalars, out); + } + + batch_rule(self, tensor1, tensor2, scalars, out); +} +template +void _foreach_addcmul_Tensor_out_generated_plumbing(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Tensor & scalars, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(tensor1, cur_level) && !isBatchedAtLevel(tensor2, cur_level) && !isBatchedAtLevel(scalars, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_addcmul_Tensor_out::call(self, tensor1, tensor2, scalars, out); + } + auto [scalars_value, scalars_bdim] = unwrapTensorAtLevel(scalars, cur_level); + batch_rule(self, tensor1, tensor2, scalars_value, scalars_bdim, out); +} +template +void _foreach_abs_out_generated_plumbing(at::TensorList self, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_abs_out::call(self, out); + } + + batch_rule(self, out); +} 
+template +void _foreach_acos_out_generated_plumbing(at::TensorList self, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_acos_out::call(self, out); + } + + batch_rule(self, out); +} +template +void _foreach_asin_out_generated_plumbing(at::TensorList self, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_asin_out::call(self, out); + } + + batch_rule(self, out); +} +template +void _foreach_atan_out_generated_plumbing(at::TensorList self, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_atan_out::call(self, out); + } + + batch_rule(self, out); +} +template +void _foreach_ceil_out_generated_plumbing(at::TensorList self, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_ceil_out::call(self, out); + } + + 
batch_rule(self, out); +} +template +void _foreach_cos_out_generated_plumbing(at::TensorList self, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_cos_out::call(self, out); + } + + batch_rule(self, out); +} +template +void _foreach_cosh_out_generated_plumbing(at::TensorList self, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_cosh_out::call(self, out); + } + + batch_rule(self, out); +} +template +void _foreach_erf_out_generated_plumbing(at::TensorList self, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_erf_out::call(self, out); + } + + batch_rule(self, out); +} +template +void _foreach_erfc_out_generated_plumbing(at::TensorList self, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_erfc_out::call(self, out); 
+ } + + batch_rule(self, out); +} +template +void _foreach_exp_out_generated_plumbing(at::TensorList self, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_exp_out::call(self, out); + } + + batch_rule(self, out); +} +template +void _foreach_expm1_out_generated_plumbing(at::TensorList self, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_expm1_out::call(self, out); + } + + batch_rule(self, out); +} +template +void _foreach_floor_out_generated_plumbing(at::TensorList self, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_floor_out::call(self, out); + } + + batch_rule(self, out); +} +template +void _foreach_frac_out_generated_plumbing(at::TensorList self, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return 
at::_ops::_foreach_frac_out::call(self, out); + } + + batch_rule(self, out); +} +template +void _foreach_lerp_List_out_generated_plumbing(at::TensorList self, at::TensorList tensors1, at::TensorList weights, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(tensors1, cur_level) && !isBatchedAtLevel(weights, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_lerp_List_out::call(self, tensors1, weights, out); + } + + batch_rule(self, tensors1, weights, out); +} +template +void _foreach_lerp_Scalar_out_generated_plumbing(at::TensorList self, at::TensorList tensors1, const at::Scalar & weight, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(tensors1, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_lerp_Scalar_out::call(self, tensors1, weight, out); + } + + batch_rule(self, tensors1, weight, out); +} +template +void _foreach_lerp_ScalarList_out_generated_plumbing(at::TensorList self, at::TensorList tensors1, at::ArrayRef weight, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(tensors1, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_lerp_ScalarList_out::call(self, tensors1, 
weight, out); + } + + batch_rule(self, tensors1, weight, out); +} +template +void _foreach_lgamma_out_generated_plumbing(at::TensorList self, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_lgamma_out::call(self, out); + } + + batch_rule(self, out); +} +template +void _foreach_log_out_generated_plumbing(at::TensorList self, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_log_out::call(self, out); + } + + batch_rule(self, out); +} +template +void _foreach_log10_out_generated_plumbing(at::TensorList self, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_log10_out::call(self, out); + } + + batch_rule(self, out); +} +template +void _foreach_log1p_out_generated_plumbing(at::TensorList self, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + 
return at::_ops::_foreach_log1p_out::call(self, out); + } + + batch_rule(self, out); +} +template +void _foreach_log2_out_generated_plumbing(at::TensorList self, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_log2_out::call(self, out); + } + + batch_rule(self, out); +} +template +void _foreach_max_out_generated_plumbing(at::TensorList self, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_max_out::call(self, out); + } + + batch_rule(self, out); +} +template +void _foreach_neg_out_generated_plumbing(at::TensorList self, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_neg_out::call(self, out); + } + + batch_rule(self, out); +} +template +void _foreach_norm_Scalar_out_generated_plumbing(at::TensorList self, const at::Scalar & ord, ::std::optional dtype, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if 
(!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_norm_Scalar_out::call(self, ord, dtype, out); + } + + batch_rule(self, ord, dtype, out); +} +template +void _foreach_pow_List_out_generated_plumbing(at::TensorList self, at::TensorList exponent, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(exponent, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_pow_List_out::call(self, exponent, out); + } + + batch_rule(self, exponent, out); +} +template +void _foreach_pow_Scalar_out_generated_plumbing(at::TensorList self, const at::Scalar & exponent, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_pow_Scalar_out::call(self, exponent, out); + } + + batch_rule(self, exponent, out); +} +template +void _foreach_pow_ScalarList_out_generated_plumbing(at::TensorList self, at::ArrayRef exponent, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_pow_ScalarList_out::call(self, exponent, out); + } + + batch_rule(self, exponent, out); +} +template +void _foreach_reciprocal_out_generated_plumbing(at::TensorList 
self, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_reciprocal_out::call(self, out); + } + + batch_rule(self, out); +} +template +void _foreach_round_out_generated_plumbing(at::TensorList self, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_round_out::call(self, out); + } + + batch_rule(self, out); +} +template +void _foreach_rsqrt_out_generated_plumbing(at::TensorList self, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_rsqrt_out::call(self, out); + } + + batch_rule(self, out); +} +template +void _foreach_sigmoid_out_generated_plumbing(at::TensorList self, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_sigmoid_out::call(self, out); + } + + batch_rule(self, out); +} +template +void 
_foreach_sign_out_generated_plumbing(at::TensorList self, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_sign_out::call(self, out); + } + + batch_rule(self, out); +} +template +void _foreach_sin_out_generated_plumbing(at::TensorList self, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_sin_out::call(self, out); + } + + batch_rule(self, out); +} +template +void _foreach_sinh_out_generated_plumbing(at::TensorList self, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_sinh_out::call(self, out); + } + + batch_rule(self, out); +} +template +void _foreach_sqrt_out_generated_plumbing(at::TensorList self, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_sqrt_out::call(self, out); + } + + batch_rule(self, out); +} 
+template +void _foreach_tan_out_generated_plumbing(at::TensorList self, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_tan_out::call(self, out); + } + + batch_rule(self, out); +} +template +void _foreach_tanh_out_generated_plumbing(at::TensorList self, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_tanh_out::call(self, out); + } + + batch_rule(self, out); +} +template +void _foreach_trunc_out_generated_plumbing(at::TensorList self, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_trunc_out::call(self, out); + } + + batch_rule(self, out); +} +template +void _foreach_zero_out_generated_plumbing(at::TensorList self, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_zero_out::call(self, out); + } + + 
batch_rule(self, out); +} +template +::std::vector _foreach_zero_generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_zero::call(self); + } + + auto results = batch_rule(self); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_copy_out_generated_plumbing(at::TensorList self, at::TensorList src, bool non_blocking, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(src, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_copy_out::call(self, src, non_blocking, out); + } + + batch_rule(self, src, non_blocking, out); +} +template +::std::tuple rrelu_with_noise_functional_generated_plumbing(const at::Tensor & self, const at::Tensor & noise, const at::Scalar & lower, const at::Scalar & upper, bool training, ::std::optional generator) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(noise, cur_level)) { + return at::_ops::rrelu_with_noise_functional::call(self, noise, lower, upper, training, generator); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [noise_value, noise_bdim] = unwrapTensorAtLevel(noise, cur_level); + auto results = batch_rule(self_value, self_bdim, 
noise_value, noise_bdim, lower, upper, training, generator); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +void _fused_adam_out_generated_plumbing(at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, double lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const ::std::optional & grad_scale, const ::std::optional & found_inf, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(grads, cur_level) && !isBatchedAtLevel(exp_avgs, cur_level) && !isBatchedAtLevel(exp_avg_sqs, cur_level) && !isBatchedAtLevel(max_exp_avg_sqs, cur_level) && !isBatchedAtLevel(state_steps, cur_level) && !isBatchedAtLevel(grad_scale, cur_level) && !isBatchedAtLevel(found_inf, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_fused_adam_out::call(self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale, found_inf, out); + } + std::optional grad_scale_value; + std::optional grad_scale_bdim; + if (grad_scale) { + std::tie(grad_scale_value, grad_scale_bdim) = unwrapTensorAtLevel(grad_scale.value(), cur_level); + } + std::optional found_inf_value; + std::optional found_inf_bdim; + if (found_inf) { + std::tie(found_inf_value, found_inf_bdim) = unwrapTensorAtLevel(found_inf.value(), cur_level); + } + batch_rule(self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale_value, grad_scale_bdim, 
found_inf_value, found_inf_bdim, out); +} +template +::std::tuple<::std::vector,::std::vector,::std::vector,::std::vector,::std::vector> _fused_adam_generated_plumbing(at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, double lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const ::std::optional & grad_scale, const ::std::optional & found_inf) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(grads, cur_level) && !isBatchedAtLevel(exp_avgs, cur_level) && !isBatchedAtLevel(exp_avg_sqs, cur_level) && !isBatchedAtLevel(max_exp_avg_sqs, cur_level) && !isBatchedAtLevel(state_steps, cur_level) && !isBatchedAtLevel(grad_scale, cur_level) && !isBatchedAtLevel(found_inf, cur_level)) { + return at::_ops::_fused_adam::call(self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale, found_inf); + } + std::optional grad_scale_value; + std::optional grad_scale_bdim; + if (grad_scale) { + std::tie(grad_scale_value, grad_scale_bdim) = unwrapTensorAtLevel(grad_scale.value(), cur_level); + } + std::optional found_inf_value; + std::optional found_inf_bdim; + if (found_inf) { + std::tie(found_inf_value, found_inf_bdim) = unwrapTensorAtLevel(found_inf.value(), cur_level); + } + auto results = batch_rule(self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale_value, grad_scale_bdim, found_inf_value, found_inf_bdim); + return std::make_tuple(makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level), makeBatchedVector(std::get<2>(results), 
std::get<3>(results), cur_level), makeBatchedVector(std::get<4>(results), std::get<5>(results), cur_level), makeBatchedVector(std::get<6>(results), std::get<7>(results), cur_level), makeBatchedVector(std::get<8>(results), std::get<9>(results), cur_level)); +} +template +void _fused_adam_tensor_lr_out_generated_plumbing(at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, const at::Tensor & lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const ::std::optional & grad_scale, const ::std::optional & found_inf, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(grads, cur_level) && !isBatchedAtLevel(exp_avgs, cur_level) && !isBatchedAtLevel(exp_avg_sqs, cur_level) && !isBatchedAtLevel(max_exp_avg_sqs, cur_level) && !isBatchedAtLevel(state_steps, cur_level) && !isBatchedAtLevel(lr, cur_level) && !isBatchedAtLevel(grad_scale, cur_level) && !isBatchedAtLevel(found_inf, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_fused_adam_tensor_lr_out::call(self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale, found_inf, out); + } + auto [lr_value, lr_bdim] = unwrapTensorAtLevel(lr, cur_level); + std::optional grad_scale_value; + std::optional grad_scale_bdim; + if (grad_scale) { + std::tie(grad_scale_value, grad_scale_bdim) = unwrapTensorAtLevel(grad_scale.value(), cur_level); + } + std::optional found_inf_value; + std::optional found_inf_bdim; + if (found_inf) { + std::tie(found_inf_value, found_inf_bdim) = unwrapTensorAtLevel(found_inf.value(), cur_level); + } + 
batch_rule(self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr_value, lr_bdim, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale_value, grad_scale_bdim, found_inf_value, found_inf_bdim, out); +} +template +::std::tuple<::std::vector,::std::vector,::std::vector,::std::vector,::std::vector> _fused_adam_tensor_lr_generated_plumbing(at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, const at::Tensor & lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const ::std::optional & grad_scale, const ::std::optional & found_inf) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(grads, cur_level) && !isBatchedAtLevel(exp_avgs, cur_level) && !isBatchedAtLevel(exp_avg_sqs, cur_level) && !isBatchedAtLevel(max_exp_avg_sqs, cur_level) && !isBatchedAtLevel(state_steps, cur_level) && !isBatchedAtLevel(lr, cur_level) && !isBatchedAtLevel(grad_scale, cur_level) && !isBatchedAtLevel(found_inf, cur_level)) { + return at::_ops::_fused_adam_tensor_lr::call(self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale, found_inf); + } + auto [lr_value, lr_bdim] = unwrapTensorAtLevel(lr, cur_level); + std::optional grad_scale_value; + std::optional grad_scale_bdim; + if (grad_scale) { + std::tie(grad_scale_value, grad_scale_bdim) = unwrapTensorAtLevel(grad_scale.value(), cur_level); + } + std::optional found_inf_value; + std::optional found_inf_bdim; + if (found_inf) { + std::tie(found_inf_value, found_inf_bdim) = unwrapTensorAtLevel(found_inf.value(), cur_level); + } + auto results = batch_rule(self, grads, 
exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr_value, lr_bdim, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale_value, grad_scale_bdim, found_inf_value, found_inf_bdim); + return std::make_tuple(makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level), makeBatchedVector(std::get<2>(results), std::get<3>(results), cur_level), makeBatchedVector(std::get<4>(results), std::get<5>(results), cur_level), makeBatchedVector(std::get<6>(results), std::get<7>(results), cur_level), makeBatchedVector(std::get<8>(results), std::get<9>(results), cur_level)); +} +template +void _fused_adamw_out_generated_plumbing(at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, double lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const ::std::optional & grad_scale, const ::std::optional & found_inf, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(grads, cur_level) && !isBatchedAtLevel(exp_avgs, cur_level) && !isBatchedAtLevel(exp_avg_sqs, cur_level) && !isBatchedAtLevel(max_exp_avg_sqs, cur_level) && !isBatchedAtLevel(state_steps, cur_level) && !isBatchedAtLevel(grad_scale, cur_level) && !isBatchedAtLevel(found_inf, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_fused_adamw_out::call(self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale, found_inf, out); + } + std::optional grad_scale_value; + std::optional grad_scale_bdim; + if (grad_scale) { + std::tie(grad_scale_value, grad_scale_bdim) = unwrapTensorAtLevel(grad_scale.value(), 
cur_level); + } + std::optional found_inf_value; + std::optional found_inf_bdim; + if (found_inf) { + std::tie(found_inf_value, found_inf_bdim) = unwrapTensorAtLevel(found_inf.value(), cur_level); + } + batch_rule(self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale_value, grad_scale_bdim, found_inf_value, found_inf_bdim, out); +} +template +::std::tuple<::std::vector,::std::vector,::std::vector,::std::vector,::std::vector> _fused_adamw_generated_plumbing(at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, double lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const ::std::optional & grad_scale, const ::std::optional & found_inf) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(grads, cur_level) && !isBatchedAtLevel(exp_avgs, cur_level) && !isBatchedAtLevel(exp_avg_sqs, cur_level) && !isBatchedAtLevel(max_exp_avg_sqs, cur_level) && !isBatchedAtLevel(state_steps, cur_level) && !isBatchedAtLevel(grad_scale, cur_level) && !isBatchedAtLevel(found_inf, cur_level)) { + return at::_ops::_fused_adamw::call(self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale, found_inf); + } + std::optional grad_scale_value; + std::optional grad_scale_bdim; + if (grad_scale) { + std::tie(grad_scale_value, grad_scale_bdim) = unwrapTensorAtLevel(grad_scale.value(), cur_level); + } + std::optional found_inf_value; + std::optional found_inf_bdim; + if (found_inf) { + std::tie(found_inf_value, found_inf_bdim) = unwrapTensorAtLevel(found_inf.value(), 
cur_level); + } + auto results = batch_rule(self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale_value, grad_scale_bdim, found_inf_value, found_inf_bdim); + return std::make_tuple(makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level), makeBatchedVector(std::get<2>(results), std::get<3>(results), cur_level), makeBatchedVector(std::get<4>(results), std::get<5>(results), cur_level), makeBatchedVector(std::get<6>(results), std::get<7>(results), cur_level), makeBatchedVector(std::get<8>(results), std::get<9>(results), cur_level)); +} +template +void _fused_adamw_tensor_lr_out_generated_plumbing(at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, const at::Tensor & lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const ::std::optional & grad_scale, const ::std::optional & found_inf, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(grads, cur_level) && !isBatchedAtLevel(exp_avgs, cur_level) && !isBatchedAtLevel(exp_avg_sqs, cur_level) && !isBatchedAtLevel(max_exp_avg_sqs, cur_level) && !isBatchedAtLevel(state_steps, cur_level) && !isBatchedAtLevel(lr, cur_level) && !isBatchedAtLevel(grad_scale, cur_level) && !isBatchedAtLevel(found_inf, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_fused_adamw_tensor_lr_out::call(self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale, found_inf, out); + } + auto [lr_value, lr_bdim] = unwrapTensorAtLevel(lr, cur_level); + 
std::optional grad_scale_value; + std::optional grad_scale_bdim; + if (grad_scale) { + std::tie(grad_scale_value, grad_scale_bdim) = unwrapTensorAtLevel(grad_scale.value(), cur_level); + } + std::optional found_inf_value; + std::optional found_inf_bdim; + if (found_inf) { + std::tie(found_inf_value, found_inf_bdim) = unwrapTensorAtLevel(found_inf.value(), cur_level); + } + batch_rule(self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr_value, lr_bdim, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale_value, grad_scale_bdim, found_inf_value, found_inf_bdim, out); +} +template +::std::tuple<::std::vector,::std::vector,::std::vector,::std::vector,::std::vector> _fused_adamw_tensor_lr_generated_plumbing(at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, const at::Tensor & lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const ::std::optional & grad_scale, const ::std::optional & found_inf) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(grads, cur_level) && !isBatchedAtLevel(exp_avgs, cur_level) && !isBatchedAtLevel(exp_avg_sqs, cur_level) && !isBatchedAtLevel(max_exp_avg_sqs, cur_level) && !isBatchedAtLevel(state_steps, cur_level) && !isBatchedAtLevel(lr, cur_level) && !isBatchedAtLevel(grad_scale, cur_level) && !isBatchedAtLevel(found_inf, cur_level)) { + return at::_ops::_fused_adamw_tensor_lr::call(self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale, found_inf); + } + auto [lr_value, lr_bdim] = unwrapTensorAtLevel(lr, cur_level); + std::optional grad_scale_value; + 
std::optional grad_scale_bdim; + if (grad_scale) { + std::tie(grad_scale_value, grad_scale_bdim) = unwrapTensorAtLevel(grad_scale.value(), cur_level); + } + std::optional found_inf_value; + std::optional found_inf_bdim; + if (found_inf) { + std::tie(found_inf_value, found_inf_bdim) = unwrapTensorAtLevel(found_inf.value(), cur_level); + } + auto results = batch_rule(self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr_value, lr_bdim, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale_value, grad_scale_bdim, found_inf_value, found_inf_bdim); + return std::make_tuple(makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level), makeBatchedVector(std::get<2>(results), std::get<3>(results), cur_level), makeBatchedVector(std::get<4>(results), std::get<5>(results), cur_level), makeBatchedVector(std::get<6>(results), std::get<7>(results), cur_level), makeBatchedVector(std::get<8>(results), std::get<9>(results), cur_level)); +} +template +void _fused_sgd_out_generated_plumbing(at::TensorList self, at::TensorList grads, at::TensorList momentum_buffer_list, double weight_decay, double momentum, double lr, double dampening, bool nesterov, bool maximize, bool is_first_step, const ::std::optional & grad_scale, const ::std::optional & found_inf, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(grads, cur_level) && !isBatchedAtLevel(momentum_buffer_list, cur_level) && !isBatchedAtLevel(grad_scale, cur_level) && !isBatchedAtLevel(found_inf, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_fused_sgd_out::call(self, grads, momentum_buffer_list, weight_decay, momentum, lr, dampening, nesterov, maximize, is_first_step, grad_scale, found_inf, out); + } + 
std::optional grad_scale_value; + std::optional grad_scale_bdim; + if (grad_scale) { + std::tie(grad_scale_value, grad_scale_bdim) = unwrapTensorAtLevel(grad_scale.value(), cur_level); + } + std::optional found_inf_value; + std::optional found_inf_bdim; + if (found_inf) { + std::tie(found_inf_value, found_inf_bdim) = unwrapTensorAtLevel(found_inf.value(), cur_level); + } + batch_rule(self, grads, momentum_buffer_list, weight_decay, momentum, lr, dampening, nesterov, maximize, is_first_step, grad_scale_value, grad_scale_bdim, found_inf_value, found_inf_bdim, out); +} +template +::std::tuple<::std::vector,::std::vector,::std::vector> _fused_sgd_generated_plumbing(at::TensorList self, at::TensorList grads, at::TensorList momentum_buffer_list, double weight_decay, double momentum, double lr, double dampening, bool nesterov, bool maximize, bool is_first_step, const ::std::optional & grad_scale, const ::std::optional & found_inf) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(grads, cur_level) && !isBatchedAtLevel(momentum_buffer_list, cur_level) && !isBatchedAtLevel(grad_scale, cur_level) && !isBatchedAtLevel(found_inf, cur_level)) { + return at::_ops::_fused_sgd::call(self, grads, momentum_buffer_list, weight_decay, momentum, lr, dampening, nesterov, maximize, is_first_step, grad_scale, found_inf); + } + std::optional grad_scale_value; + std::optional grad_scale_bdim; + if (grad_scale) { + std::tie(grad_scale_value, grad_scale_bdim) = unwrapTensorAtLevel(grad_scale.value(), cur_level); + } + std::optional found_inf_value; + std::optional found_inf_bdim; + if (found_inf) { + std::tie(found_inf_value, found_inf_bdim) = unwrapTensorAtLevel(found_inf.value(), cur_level); + } + auto results = batch_rule(self, grads, 
momentum_buffer_list, weight_decay, momentum, lr, dampening, nesterov, maximize, is_first_step, grad_scale_value, grad_scale_bdim, found_inf_value, found_inf_bdim); + return std::make_tuple(makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level), makeBatchedVector(std::get<2>(results), std::get<3>(results), cur_level), makeBatchedVector(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +void _fused_sgd_tensor_lr_out_generated_plumbing(at::TensorList self, at::TensorList grads, at::TensorList momentum_buffer_list, double weight_decay, double momentum, const at::Tensor & lr, double dampening, bool nesterov, bool maximize, bool is_first_step, const ::std::optional & grad_scale, const ::std::optional & found_inf, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(grads, cur_level) && !isBatchedAtLevel(momentum_buffer_list, cur_level) && !isBatchedAtLevel(lr, cur_level) && !isBatchedAtLevel(grad_scale, cur_level) && !isBatchedAtLevel(found_inf, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_fused_sgd_tensor_lr_out::call(self, grads, momentum_buffer_list, weight_decay, momentum, lr, dampening, nesterov, maximize, is_first_step, grad_scale, found_inf, out); + } + auto [lr_value, lr_bdim] = unwrapTensorAtLevel(lr, cur_level); + std::optional grad_scale_value; + std::optional grad_scale_bdim; + if (grad_scale) { + std::tie(grad_scale_value, grad_scale_bdim) = unwrapTensorAtLevel(grad_scale.value(), cur_level); + } + std::optional found_inf_value; + std::optional found_inf_bdim; + if (found_inf) { + std::tie(found_inf_value, found_inf_bdim) = unwrapTensorAtLevel(found_inf.value(), cur_level); + } + batch_rule(self, grads, momentum_buffer_list, 
weight_decay, momentum, lr_value, lr_bdim, dampening, nesterov, maximize, is_first_step, grad_scale_value, grad_scale_bdim, found_inf_value, found_inf_bdim, out); +} +template +::std::tuple<::std::vector,::std::vector,::std::vector> _fused_sgd_tensor_lr_generated_plumbing(at::TensorList self, at::TensorList grads, at::TensorList momentum_buffer_list, double weight_decay, double momentum, const at::Tensor & lr, double dampening, bool nesterov, bool maximize, bool is_first_step, const ::std::optional & grad_scale, const ::std::optional & found_inf) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(grads, cur_level) && !isBatchedAtLevel(momentum_buffer_list, cur_level) && !isBatchedAtLevel(lr, cur_level) && !isBatchedAtLevel(grad_scale, cur_level) && !isBatchedAtLevel(found_inf, cur_level)) { + return at::_ops::_fused_sgd_tensor_lr::call(self, grads, momentum_buffer_list, weight_decay, momentum, lr, dampening, nesterov, maximize, is_first_step, grad_scale, found_inf); + } + auto [lr_value, lr_bdim] = unwrapTensorAtLevel(lr, cur_level); + std::optional grad_scale_value; + std::optional grad_scale_bdim; + if (grad_scale) { + std::tie(grad_scale_value, grad_scale_bdim) = unwrapTensorAtLevel(grad_scale.value(), cur_level); + } + std::optional found_inf_value; + std::optional found_inf_bdim; + if (found_inf) { + std::tie(found_inf_value, found_inf_bdim) = unwrapTensorAtLevel(found_inf.value(), cur_level); + } + auto results = batch_rule(self, grads, momentum_buffer_list, weight_decay, momentum, lr_value, lr_bdim, dampening, nesterov, maximize, is_first_step, grad_scale_value, grad_scale_bdim, found_inf_value, found_inf_bdim); + return std::make_tuple(makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level), 
makeBatchedVector(std::get<2>(results), std::get<3>(results), cur_level), makeBatchedVector(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +void _fused_adagrad_out_generated_plumbing(at::TensorList self, at::TensorList grads, at::TensorList state_sums, at::TensorList state_steps, double lr, double lr_decay, double weight_decay, double eps, bool maximize, const ::std::optional & grad_scale, const ::std::optional & found_inf, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(grads, cur_level) && !isBatchedAtLevel(state_sums, cur_level) && !isBatchedAtLevel(state_steps, cur_level) && !isBatchedAtLevel(grad_scale, cur_level) && !isBatchedAtLevel(found_inf, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_fused_adagrad_out::call(self, grads, state_sums, state_steps, lr, lr_decay, weight_decay, eps, maximize, grad_scale, found_inf, out); + } + std::optional grad_scale_value; + std::optional grad_scale_bdim; + if (grad_scale) { + std::tie(grad_scale_value, grad_scale_bdim) = unwrapTensorAtLevel(grad_scale.value(), cur_level); + } + std::optional found_inf_value; + std::optional found_inf_bdim; + if (found_inf) { + std::tie(found_inf_value, found_inf_bdim) = unwrapTensorAtLevel(found_inf.value(), cur_level); + } + batch_rule(self, grads, state_sums, state_steps, lr, lr_decay, weight_decay, eps, maximize, grad_scale_value, grad_scale_bdim, found_inf_value, found_inf_bdim, out); +} +template +::std::tuple<::std::vector,::std::vector,::std::vector,::std::vector> _fused_adagrad_generated_plumbing(at::TensorList self, at::TensorList grads, at::TensorList state_sums, at::TensorList state_steps, double lr, double lr_decay, double weight_decay, double eps, bool maximize, 
const ::std::optional & grad_scale, const ::std::optional & found_inf) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(grads, cur_level) && !isBatchedAtLevel(state_sums, cur_level) && !isBatchedAtLevel(state_steps, cur_level) && !isBatchedAtLevel(grad_scale, cur_level) && !isBatchedAtLevel(found_inf, cur_level)) { + return at::_ops::_fused_adagrad::call(self, grads, state_sums, state_steps, lr, lr_decay, weight_decay, eps, maximize, grad_scale, found_inf); + } + std::optional grad_scale_value; + std::optional grad_scale_bdim; + if (grad_scale) { + std::tie(grad_scale_value, grad_scale_bdim) = unwrapTensorAtLevel(grad_scale.value(), cur_level); + } + std::optional found_inf_value; + std::optional found_inf_bdim; + if (found_inf) { + std::tie(found_inf_value, found_inf_bdim) = unwrapTensorAtLevel(found_inf.value(), cur_level); + } + auto results = batch_rule(self, grads, state_sums, state_steps, lr, lr_decay, weight_decay, eps, maximize, grad_scale_value, grad_scale_bdim, found_inf_value, found_inf_bdim); + return std::make_tuple(makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level), makeBatchedVector(std::get<2>(results), std::get<3>(results), cur_level), makeBatchedVector(std::get<4>(results), std::get<5>(results), cur_level), makeBatchedVector(std::get<6>(results), std::get<7>(results), cur_level)); +} +template +void _fused_adagrad_tensor_lr_out_generated_plumbing(at::TensorList self, at::TensorList grads, at::TensorList state_sums, at::TensorList state_steps, const at::Tensor & lr, double lr_decay, double weight_decay, double eps, bool maximize, const ::std::optional & grad_scale, const ::std::optional & found_inf, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard 
guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(grads, cur_level) && !isBatchedAtLevel(state_sums, cur_level) && !isBatchedAtLevel(state_steps, cur_level) && !isBatchedAtLevel(lr, cur_level) && !isBatchedAtLevel(grad_scale, cur_level) && !isBatchedAtLevel(found_inf, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_fused_adagrad_tensor_lr_out::call(self, grads, state_sums, state_steps, lr, lr_decay, weight_decay, eps, maximize, grad_scale, found_inf, out); + } + auto [lr_value, lr_bdim] = unwrapTensorAtLevel(lr, cur_level); + std::optional grad_scale_value; + std::optional grad_scale_bdim; + if (grad_scale) { + std::tie(grad_scale_value, grad_scale_bdim) = unwrapTensorAtLevel(grad_scale.value(), cur_level); + } + std::optional found_inf_value; + std::optional found_inf_bdim; + if (found_inf) { + std::tie(found_inf_value, found_inf_bdim) = unwrapTensorAtLevel(found_inf.value(), cur_level); + } + batch_rule(self, grads, state_sums, state_steps, lr_value, lr_bdim, lr_decay, weight_decay, eps, maximize, grad_scale_value, grad_scale_bdim, found_inf_value, found_inf_bdim, out); +} +template +::std::tuple<::std::vector,::std::vector,::std::vector> _fused_adagrad_tensor_lr_generated_plumbing(at::TensorList self, at::TensorList grads, at::TensorList state_sums, at::TensorList state_steps, const at::Tensor & lr, double lr_decay, double weight_decay, double eps, bool maximize, const ::std::optional & grad_scale, const ::std::optional & found_inf) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(grads, 
cur_level) && !isBatchedAtLevel(state_sums, cur_level) && !isBatchedAtLevel(state_steps, cur_level) && !isBatchedAtLevel(lr, cur_level) && !isBatchedAtLevel(grad_scale, cur_level) && !isBatchedAtLevel(found_inf, cur_level)) { + return at::_ops::_fused_adagrad_tensor_lr::call(self, grads, state_sums, state_steps, lr, lr_decay, weight_decay, eps, maximize, grad_scale, found_inf); + } + auto [lr_value, lr_bdim] = unwrapTensorAtLevel(lr, cur_level); + std::optional grad_scale_value; + std::optional grad_scale_bdim; + if (grad_scale) { + std::tie(grad_scale_value, grad_scale_bdim) = unwrapTensorAtLevel(grad_scale.value(), cur_level); + } + std::optional found_inf_value; + std::optional found_inf_bdim; + if (found_inf) { + std::tie(found_inf_value, found_inf_bdim) = unwrapTensorAtLevel(found_inf.value(), cur_level); + } + auto results = batch_rule(self, grads, state_sums, state_steps, lr_value, lr_bdim, lr_decay, weight_decay, eps, maximize, grad_scale_value, grad_scale_bdim, found_inf_value, found_inf_bdim); + return std::make_tuple(makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level), makeBatchedVector(std::get<2>(results), std::get<3>(results), cur_level), makeBatchedVector(std::get<4>(results), std::get<5>(results), cur_level)); +} + +}} // namespace at::functorch + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." 
+#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/WrapDimUtils.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/WrapDimUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..0cce3541b090df1c1758a111c879131dda187de6 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/WrapDimUtils.h @@ -0,0 +1,161 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include +#include + +namespace at { + +// if dim_post_expr is 0 and wrap_scalar is true, then dim must be in the +// range [-1, 0]. This is a special case for scalar tensors and manifests in +// e.g. torch.sum(scalar_tensor, 0) Otherwise, dim should be in the range +// [-dim_post_expr, dim_post_expr-1]. +using c10::maybe_wrap_dim; + +inline int64_t maybe_wrap_dim(int64_t dim, TensorImpl* tensor) { + return maybe_wrap_dim(dim, tensor->dim()); +} + +inline int64_t maybe_wrap_dim(int64_t dim, TensorList tensors) { + if (tensors.empty()) { + // can't wrap empty TensorList; rely on underlying implementation to throw + // error if necessary. + return dim; + } + return maybe_wrap_dim(dim, tensors[0].dim()); +} + +inline int64_t maybe_wrap_dim( + int64_t dim, + const std::vector>& tensor_sizes) { + if (tensor_sizes.empty()) { + // can't wrap empty list; rely on underlying implementation to throw error + // if necessary + return dim; + } + return maybe_wrap_dim(dim, static_cast(tensor_sizes[0].size())); +} + +// Given an array of dimensions `dims` of length `ndims`, this function "Wraps" +// each dim in-place for a tensor of rank `dim_post_expr`, allowing dims to be +// specified using negative indices. +// +// Additionally, if `wrap_scalar` is true then scalar tensors with rank 0, will +// allow dimensions in the range [-1, 0]. 
Otherwise, an IndexError is raised for +// dimensions not in the range [-dim_post_expr, dim_post_expr). +inline void maybe_wrap_dims_n( + int64_t* dims, + int64_t ndims, + int64_t dim_post_expr, + bool wrap_scalars = true) { + if (dim_post_expr <= 0) { + if (wrap_scalars) { + dim_post_expr = 1; // this will make range [-1, 0] + } else { + TORCH_CHECK_INDEX( + ndims == 0, + "Dimension specified as ", + dims[0], + " but tensor has no dimensions"); + return; + } + } + int64_t min = -dim_post_expr; + int64_t max = dim_post_expr - 1; + for (const auto i : c10::irange(ndims)) { + auto& dim = dims[i]; + if (dim < min || dim > max) { + TORCH_CHECK_INDEX( + false, + "Dimension out of range (expected to be in range of [", + min, + ", ", + max, + "], but got ", + dim, + ")"); + } + if (dim < 0) + dim += dim_post_expr; + } +} + +// Given a contiguous container of dimensions `dims`, this function "Wraps" +// each dim in-place for a tensor of rank `dim_post_expr`, allowing dims to be +// specified using negative indices. +// +// Additionally, if `wrap_scalar` is true then scalar tensors with rank 0, will +// allow dimensions in the range [-1, 0]. Otherwise, an IndexError is raised for +// dimensions not in the range [-dim_post_expr, dim_post_expr). +template +inline void maybe_wrap_dims( + Container& dims, + int64_t dim_post_expr, + bool wrap_scalars = true) { + return maybe_wrap_dims_n( + dims.data(), dims.size(), dim_post_expr, wrap_scalars); +} + +// previously, size [0] tensors were the only possible empty tensors; thus, it +// wasn't possible to cat empty tensors unless all the other tensors were +// 1-dimensional, so we allowed these tensors to be "skipped" (both for wrap +// dimension behavior and dimension size checking). We maintain this behavior +// for backwards compatibility, but only for this specific size (i.e. other +// empty sizes are not skipped). 
+inline int64_t legacy_cat_wrap_dim( + int64_t dim, + const std::vector>& tensor_sizes) { + for (auto& sizes : tensor_sizes) { + if (sizes.size() == 1 && sizes[0] == 0) { + continue; + } + return maybe_wrap_dim(dim, static_cast(sizes.size())); + } + return dim; +} + +inline int64_t legacy_cat_wrap_dim_symint( + int64_t dim, + const std::vector>& tensor_sizes) { + for (auto& sizes : tensor_sizes) { + if (sizes.size() == 1) { + if (TORCH_GUARD_OR_FALSE(sizes[0].sym_eq(0))) { + continue; + } + } + return maybe_wrap_dim(dim, static_cast(sizes.size())); + } + return dim; +} + +inline int64_t legacy_cat_wrap_dim( + int64_t dim, + const MaterializedITensorListRef& tensors) { + for (const Tensor& tensor : tensors) { + if (tensor.dim() == 1) { + if (TORCH_GUARD_OR_FALSE(tensor.sym_sizes()[0].sym_eq(0))) { + continue; + } + } + return maybe_wrap_dim(dim, tensor.dim()); + } + return dim; +} + +// wrap negative dims in a vector +inline void wrap_all_dims( + std::vector& dims_to_wrap, + int64_t tensor_total_dims) { + for (const auto i : c10::irange(dims_to_wrap.size())) { + dims_to_wrap[i] = maybe_wrap_dim(dims_to_wrap[i], tensor_total_dims); + } +} + +} // namespace at + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." 
+#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/WrapDimUtilsMulti.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/WrapDimUtilsMulti.h new file mode 100644 index 0000000000000000000000000000000000000000..42ea7643ec3b8be712dab89e8941fa244cbc75d8 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/WrapDimUtilsMulti.h @@ -0,0 +1,49 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include +#include + +namespace at { + +// This is in an extra file to work around strange interaction of +// bitset on Windows with operator overloading + +constexpr size_t dim_bitset_size = 64; + +inline std::bitset dim_list_to_bitset( + OptionalIntArrayRef opt_dims, + size_t ndims) { + TORCH_CHECK( + ndims <= dim_bitset_size, + "only tensors with up to ", + dim_bitset_size, + " dims are supported"); + std::bitset seen; + if (opt_dims.has_value()) { + auto dims = opt_dims.value(); + for (const auto i : c10::irange(dims.size())) { + size_t dim = maybe_wrap_dim(dims[i], static_cast(ndims)); + TORCH_CHECK( + !seen[dim], + "dim ", + dim, + " appears multiple times in the list of dims"); + seen[dim] = true; + } + } else { + for (size_t dim = 0; dim < ndims; dim++) { + seen[dim] = true; + } + } + return seen; +} + +} // namespace at + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." 
+#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/autocast_mode.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/autocast_mode.h new file mode 100644 index 0000000000000000000000000000000000000000..ccae04da35f508c9de895ee377dcbc37f0ad724f --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/autocast_mode.h @@ -0,0 +1,976 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include + +#include +#include + +namespace at::autocast { + +TORCH_API bool is_autocast_enabled(at::DeviceType device_type); +TORCH_API void set_autocast_enabled(at::DeviceType device_type, bool enabled); +TORCH_API at::ScalarType get_autocast_dtype(at::DeviceType device_type); +TORCH_API void set_autocast_dtype( + at::DeviceType device_type, + at::ScalarType dtype); +TORCH_API void clear_cache(); +TORCH_API int increment_nesting(); +TORCH_API int decrement_nesting(); +TORCH_API bool is_autocast_cache_enabled(); +TORCH_API void set_autocast_cache_enabled(bool enabled); + +// deprecated CUDA-specific autocast APIs +C10_DEPRECATED_MESSAGE( + "at::autocast::is_enabled() is deprecated. Please use at::autocast::is_autocast_enabled(at::kCUDA) instead.") +inline bool is_enabled() { + TORCH_WARN_DEPRECATION( + "at::autocast::", + __func__, + "() is deprecated. Please use at::autocast::is_autocast_enabled(at::kCUDA) instead.") + return is_autocast_enabled(at::kCUDA); +} +C10_DEPRECATED_MESSAGE( + "at::autocast::set_enabled(enabled) is deprecated. Please use at::autocast::set_autocast_enabled(at::kCUDA, enabled) instead.") +inline void set_enabled(bool enabled) { + TORCH_WARN_DEPRECATION( + "at::autocast::", + __func__, + "(enabled) is deprecated. 
Please use at::autocast::set_autocast_enabled(at::kCUDA, enabled) instead.") + set_autocast_enabled(at::kCUDA, enabled); +} +C10_DEPRECATED_MESSAGE( + "at::autocast::get_autocast_gpu_dtype() is deprecated. Please use at::autocast::get_autocast_dtype(at::kCUDA) instead.") +inline at::ScalarType get_autocast_gpu_dtype() { + TORCH_WARN_DEPRECATION( + "at::autocast::", + __func__, + "() is deprecated. Please use at::autocast::get_autocast_dtype(at::kCUDA) instead.") + return get_autocast_dtype(at::kCUDA); +} +C10_DEPRECATED_MESSAGE( + "at::autocast::set_autocast_gpu_dtype(dtype) is deprecated. Please use at::autocast::set_autocast_dtype(at::kCUDA, dtype) instead.") +inline void set_autocast_gpu_dtype(at::ScalarType dtype) { + TORCH_WARN_DEPRECATION( + "at::autocast::", + __func__, + "(dtype) is deprecated. Please use at::autocast::set_autocast_dtype(at::kCUDA, dtype) instead.") + set_autocast_dtype(at::kCUDA, dtype); +} + +#define DECLARE_DEPRECATED_AUTOCAST_APIS(name, device_type) \ + C10_DEPRECATED_MESSAGE( \ + "at::autocast::is_" #name \ + "_enabled() is deprecated. Please use at::autocast::is_autocast_enabled(" #device_type \ + ") instead.") \ + inline bool is_##name##_enabled() { \ + TORCH_WARN_DEPRECATION( \ + "at::autocast::", \ + __func__, \ + "() is deprecated. Please use at::autocast::is_autocast_enabled(" #device_type \ + ") instead.") \ + return is_autocast_enabled(device_type); \ + } \ + \ + C10_DEPRECATED_MESSAGE( \ + "at::autocast::set_" #name \ + "_enabled(enabled) is deprecated. Please use at::autocast::set_autocast_enabled(" #device_type \ + ", enabled) instead.") \ + inline void set_##name##_enabled(bool enabled) { \ + TORCH_WARN_DEPRECATION( \ + "at::autocast::", \ + __func__, \ + "(enabled) is deprecated. Please use at::autocast::set_autocast_enabled(" #device_type \ + ", enabled) instead.") \ + set_autocast_enabled(device_type, enabled); \ + } \ + \ + C10_DEPRECATED_MESSAGE( \ + "at::autocast::get_autocast_" #name \ + "_dtype() is deprecated. 
Please use at::autocast::get_autocast_dtype(" #device_type \ + ") instead.") \ + inline at::ScalarType get_autocast_##name##_dtype() { \ + TORCH_WARN_DEPRECATION( \ + "at::autocast::", \ + __func__, \ + "() is deprecated. Please at::autocast::get_autocast_dtype(" #device_type \ + ") instead.") \ + return get_autocast_dtype(device_type); \ + } \ + \ + C10_DEPRECATED_MESSAGE( \ + "at::autocast::set_autocast_" #name \ + "_dtype(dtype) is deprecated. Please use at::autocast::set_autocast_dtype(" #device_type \ + ", dtype) instead.") \ + inline void set_autocast_##name##_dtype(at::ScalarType dtype) { \ + TORCH_WARN_DEPRECATION( \ + "at::autocast::", \ + __func__, \ + "(dtype) is deprecated. Please use at::autocast::set_autocast_dtype(" #device_type \ + ", dtype) instead.") \ + set_autocast_dtype(device_type, dtype); \ + } + +#define AT_FORALL_DEPRECATED_AUTOCAST_BACKENDS(_) \ + _(cpu, at::kCPU) \ + _(mtia, at::kMTIA) \ + _(xpu, at::kXPU) \ + _(xla, at::kXLA) \ + _(hpu, at::kHPU) \ + _(ipu, at::kIPU) \ + _(privateuseone, at::kPrivateUse1) + +// deprecated other backend specific autocast APIs +// NOLINTNEXTLINE(misc-use-internal-linkage) +AT_FORALL_DEPRECATED_AUTOCAST_BACKENDS(DECLARE_DEPRECATED_AUTOCAST_APIS) + +const std::array _AUTOCAST_SUPPORTED_DEVICES{ + at::kCPU, + at::kCUDA, + at::kMTIA, + at::kMAIA, + at::kXPU, + at::kIPU, + at::kHPU, + at::kXLA, + at::kPrivateUse1, + at::kMPS}; + +namespace { +inline bool is_autocast_eligible( + const Tensor& tensor, + c10::DeviceType device_type) { + switch (device_type) { + case c10::DeviceType::CUDA: + return (tensor.is_cuda() || tensor.is_xla()) && + tensor.is_floating_point(); + case c10::DeviceType::CPU: + return (tensor.is_cpu() || tensor.is_mkldnn()) && + tensor.is_floating_point(); + case c10::DeviceType::MTIA: + return tensor.is_mtia() && tensor.is_floating_point(); + case c10::DeviceType::MAIA: + return tensor.is_maia() && tensor.is_floating_point(); + case c10::DeviceType::XPU: + return tensor.is_xpu() && 
tensor.is_floating_point(); + case c10::DeviceType::IPU: + return tensor.is_ipu() && tensor.is_floating_point(); + case c10::DeviceType::HPU: + return tensor.is_hpu() && tensor.is_floating_point(); + case c10::DeviceType::XLA: + return tensor.is_xla() && tensor.is_floating_point(); + case c10::DeviceType::PrivateUse1: + return tensor.is_privateuseone() && tensor.is_floating_point(); + case c10::DeviceType::MPS: + return tensor.is_mps() && tensor.is_floating_point(); + default: + return false; + } +} +} // namespace + +inline DispatchKey get_autocast_dispatch_key_from_device_type( + c10::DeviceType device_type) { + switch (device_type) { + case c10::DeviceType::CUDA: + return DispatchKey::Autocast; + case c10::DeviceType::CPU: + return DispatchKey::AutocastCPU; + case c10::DeviceType::MTIA: + return DispatchKey::AutocastMTIA; + case c10::DeviceType::MAIA: + return DispatchKey::AutocastMAIA; + case c10::DeviceType::XPU: + return DispatchKey::AutocastXPU; + case c10::DeviceType::IPU: + return DispatchKey::AutocastIPU; + case c10::DeviceType::HPU: + return DispatchKey::AutocastHPU; + case c10::DeviceType::XLA: + return DispatchKey::AutocastXLA; + case c10::DeviceType::PrivateUse1: + return DispatchKey::AutocastPrivateUse1; + case c10::DeviceType::MPS: + return DispatchKey::AutocastMPS; + default: + TORCH_CHECK( + false, + "unknown device type for autocast in get_autocast_dispatch_key_from_device_type"); + } +} + +inline bool is_autocast_available(c10::DeviceType device_type) { + if (std::find( + _AUTOCAST_SUPPORTED_DEVICES.begin(), + _AUTOCAST_SUPPORTED_DEVICES.end(), + device_type) != _AUTOCAST_SUPPORTED_DEVICES.end()) { + return true; + } else { + return false; + } +} + +inline at::ScalarType get_lower_precision_fp_from_device_type( + c10::DeviceType device_type) { + if (is_autocast_available(device_type)) { + return get_autocast_dtype(device_type); + } else { + TORCH_CHECK( + false, + "unknown device type for autocast in get_lower_precision_fp_from_device_type"); + 
} +} + +/******************************************************************** +Logic to extract the promote type from any Tensor or TensorList args. +********************************************************************/ + +// Overload to catch Tensor args. +// If nextArg is floating-point, compare its scalar_type with our +// current best guess for the promote type, and update if necessary. +inline at::ScalarType prioritize( + at::ScalarType current, + const Tensor& nextArg, + c10::DeviceType device_type = c10::DeviceType::CUDA) { + if (current == at::kDouble) { + TORCH_CHECK(false, "promote type is double in at::autocast::prioritize"); + return current; + } + at::ScalarType lower_precision_fp = + get_lower_precision_fp_from_device_type(device_type); + if (is_autocast_eligible(nextArg, device_type)) { + auto next = nextArg.scalar_type(); + if (next == at::kDouble) { + return current; // ignores double tensors + } else if (current == at::kFloat || next == at::kFloat) { + return at::kFloat; // prioritizes float over lower_precision_fp + } else if (current == lower_precision_fp && next == lower_precision_fp) { + return lower_precision_fp; + } else { + TORCH_CHECK( + false, "Unexpected floating ScalarType in at::autocast::prioritize"); + return current; + } + } else { + return current; + } +} + +// Overload to catch TensorList args (for e.g. cat, stack). +// Reuses the overload above to process each Tensor in the list. 
+inline at::ScalarType prioritize( + at::ScalarType current, + const TensorList& list, + c10::DeviceType device_type = c10::DeviceType::CUDA) { + for (const auto& tensor : list) { + current = prioritize(current, tensor, device_type); + } + return current; +} + +inline at::ScalarType prioritize( + at::ScalarType current, + const ITensorListRef& list, + c10::DeviceType device_type = c10::DeviceType::CUDA) { + for (const auto& tensor : list) { + current = prioritize(current, tensor, device_type); + } + return current; +} + +// Template to catch non-Tensor args (no-op that returns current best guess) +template +inline at::ScalarType prioritize( + at::ScalarType current, + T nextArg, + c10::DeviceType device_type = c10::DeviceType::CUDA) { + return current; +} + +// Overload for the tail case. +inline at::ScalarType promote_type( + at::ScalarType current, + c10::DeviceType device_type) { + return current; +} + +// Unpack args and determine if incoming lower_precision_fp tensors need to be +// promoted to float32. Non-Tensor arguments are ignored. +template +inline at::ScalarType promote_type( + at::ScalarType current, + c10::DeviceType device_type, + Arg0 arg0, + Args... args) { + auto new_current = prioritize(current, arg0, device_type); + return promote_type(new_current, device_type, args...); +} + +/**************************************************** +Logic to apply cached casting to any Tensor argument. 
+****************************************************/ +inline bool is_eligible( + const Tensor& arg, + c10::DeviceType device_type = c10::DeviceType::CUDA) { + return ( + arg.defined() && is_autocast_eligible(arg, device_type) && + (arg.scalar_type() != at::kDouble)); +} + +// Overload to catch Tensor args +TORCH_API Tensor cached_cast( + at::ScalarType to_type, + const Tensor& arg, + c10::DeviceType device_type = c10::DeviceType::CUDA); + +// Overload to process std::optional +inline std::optional cached_cast( + at::ScalarType to_type, + const std::optional& arg, + c10::DeviceType device_type = c10::DeviceType::CUDA) { + if (arg.has_value()) { + return cached_cast(to_type, *arg, device_type); + } else { + return std::nullopt; + } +} + +// Overload to process TensorLists +inline std::vector cached_cast( + at::ScalarType to_type, + const TensorList& arg, + c10::DeviceType device_type = c10::DeviceType::CUDA) { + std::vector vec; + vec.reserve(arg.size()); + for (const auto& t : arg) { + vec.emplace_back(cached_cast(to_type, t, device_type)); + } + return vec; +} + +inline std::vector cached_cast( + at::ScalarType to_type, + const ITensorListRef& arg, + c10::DeviceType device_type = c10::DeviceType::CUDA) { + std::vector vec; + vec.reserve(arg.size()); + for (const auto& t : arg) { + vec.emplace_back(cached_cast(to_type, t, device_type)); + } + return vec; +} + +// Template to catch non-Tensor args. +template +inline T cached_cast( + at::ScalarType to_type, + T arg, + c10::DeviceType device_type = c10::DeviceType::CUDA) { + return arg; +} + +/******************************************************* +Logic to flip an output dtype flag. +Keep it simple for now by assuming only one such flag is +present in the argument list. If I ever need a function +with more than flag I'll figure out something else. +The policy is: +If the user has explicitly specified a dtype, respect it. +Otherwise, set it to the autocast type. 
+********************************************************/ + +// Overload to catch dtype flags +std::optional inline set_opt_dtype( + at::ScalarType to_type, + const std::optional& dtype) { + return dtype.has_value() ? dtype : to_type; +} + +// Template to catch other args +template +inline T set_opt_dtype(at::ScalarType to_type, T arg) { + return arg; +} + +template +inline bool firstarg_is_eligible( + c10::DeviceType device_type, + const Tensor& arg, + Args... args) { + return is_eligible(arg, device_type); +} + +template +inline at::ScalarType type_from_firstarg( + c10::DeviceType device_type, + at::ScalarType to_type, + const Tensor& arg, + Args... args) { + return (is_eligible(arg, device_type) ? to_type : arg.scalar_type()); +} + +// Policies correspond to op categories that need code-divergent handling. +// Wrapper templates below are specialized based on a policy template parameter. +enum class CastPolicy : uint8_t { + lower_precision_fp = 0, // Cast all inputs to lower_precision_fp before + // running the op. Currently, lower_precision_fp is + // fp16 for AutocastCUDA, and is defined by user + // (default bf16) for AutocastCPU or other device. + fp32, // Cast all inputs to at::kFloat before running the op. + fp32_set_opt_dtype, // Treats functions (like softmax) that + // 1. we'd like to run in fp32 and + // 2. have a std::optional arg that controls + // the output type. + // fp32_set_opt_dtype wrappers' policy is: if the output + // type is already set, don't touch it, otherwise, set + // it to at::kFloat. + fp32_append_dtype, // Treats functions (like norm) that + // 1. we'd like to run in fp32 and + // 2. have some overloads that accept an output type and + // other overloads that don't. + // fp32_append_dtype wrappers wrap the overloads that don't + // have an output dtype. + // The wrapper policy is: append at::kFloat to the args, + // and redispatch to the type-aware overload. + promote, // Run in the widest dtype among several args. 
+}; + +/******************************************************************************************************** +Templates to provide wrapper functions + +I'm copying the pattern used in core/boxing/impl/WrapFunctionIntoFunctor.h to +extract args and return type. (see also +https://stackoverflow.com/questions/46533698/how-to-deduce-argument-list-from-function-pointer) + +This strategy uses an exterior "WrapFunction" that extracts arguments on behalf +of (in my case several specializations of) an interior "WrapFunction_". +Interior WrapFunction_ specializations are defined for each CastPolicy. +********************************************************************************************************/ + +// Base template for WrapFunction_, which is specialized to contain a "call" +// method each CastPolicy +template < + CastPolicy policy, + c10::DeviceType device_type, + class Redispatch, + Redispatch* F, + class Ret, + class ArgList> +struct WrapFunction_ {}; + +// CastPolicy::lower_precision_fp General_DeviceType +template < + c10::DeviceType device_type, + class Redispatch, + Redispatch* F, + class Ret, + class... Args> +struct WrapFunction_< + CastPolicy::lower_precision_fp, + device_type, + Redispatch, + F, + Ret, + guts::typelist::typelist> { + static Ret call(Args... args) { + c10::impl::ExcludeDispatchKeyGuard no_autocast( + get_autocast_dispatch_key_from_device_type(device_type)); + return (*F)(cached_cast( + get_lower_precision_fp_from_device_type(device_type), + args, + device_type)...); + } +}; + +// CastPolicy::fp32 General_DeviceType +template < + c10::DeviceType device_type, + class Redispatch, + Redispatch* F, + class Ret, + class... Args> +struct WrapFunction_< + CastPolicy::fp32, + device_type, + Redispatch, + F, + Ret, + guts::typelist::typelist> { + static Ret call(Args... 
args) { + c10::impl::ExcludeDispatchKeyGuard no_autocast( + get_autocast_dispatch_key_from_device_type(device_type)); + return (*F)(cached_cast(at::kFloat, args, device_type)...); + } +}; + +// CastPolicy::fp32_set_opt_dtype General_DeviceType +template < + c10::DeviceType device_type, + class Redispatch, + Redispatch* F, + class Ret, + class... Args> +struct WrapFunction_< + CastPolicy::fp32_set_opt_dtype, + device_type, + Redispatch, + F, + Ret, + guts::typelist::typelist> { + static Ret call(Args... args) { + c10::impl::ExcludeDispatchKeyGuard no_autocast( + get_autocast_dispatch_key_from_device_type(device_type)); + if (firstarg_is_eligible(device_type, args...)) { + return (*F)(set_opt_dtype(at::kFloat, args)...); + } else { + // If ineligible, calls F with unaltered args. Does not set opt dtype, + // because setting opt dtype explicitly may interfere with internal + // implicit promotion decisions. + return (*F)(args...); + } + } +}; + +// CastPolicy::fp32_append_dtype General_DeviceType +template < + c10::DeviceType device_type, + class Redispatch, + Redispatch* F, + class Ret, + class... Args> +struct WrapFunction_< + CastPolicy::fp32_append_dtype, + device_type, + Redispatch, + F, + Ret, + guts::typelist::typelist> { + static Ret call(Args... args) { + c10::impl::ExcludeDispatchKeyGuard no_autocast( + get_autocast_dispatch_key_from_device_type(device_type)); + at::ScalarType out_type = + type_from_firstarg(device_type, at::kFloat, args...); + return (*F)(args..., out_type); + } +}; + +// CastPolicy::promote General_DeviceType +template < + c10::DeviceType device_type, + class Redispatch, + Redispatch* F, + class Ret, + class... Args> +struct WrapFunction_< + CastPolicy::promote, + device_type, + Redispatch, + F, + Ret, + guts::typelist::typelist> { + static Ret call(Args... 
args) { + c10::impl::ExcludeDispatchKeyGuard no_autocast( + get_autocast_dispatch_key_from_device_type(device_type)); + auto to_type = promote_type( + get_lower_precision_fp_from_device_type(device_type), + device_type, + args...); + return (*F)(cached_cast(to_type, args, device_type)...); + } +}; + +// Wrapper to infer return_type and parameter_types for WrapFunction_ (imitating +// core/boxing/impl/WrapFunctionIntoFunctor.h) +template < + CastPolicy policy, + c10::DeviceType device_type, + class Registered, // The signature for which we're registering. The + // dispatcher's calling code invokes our registered + // functions with arguments matching Registered, so we + // register WrapFunction_::call methods with a matching + // signature to properly field those arguments. + // guts::function_traits below extracts return_type and + // parameter_types from Registered, which WrapFunction_ + // templates above use to declare their call methods. + class Redispatch, // The signature for the function we're redispatching to. + // In most cases this is the same as Registered, but for + // some ops (for example, ops where we append a dtype) + // it's useful to redispatch to a function with a + // different signature. + Redispatch* F> // The actual function we're redispatching to. +struct WrapFunction final { + using type = WrapFunction_< + policy, + device_type, + Redispatch, + F, + typename guts::function_traits::return_type, + typename guts::function_traits::parameter_types>; +}; + +/***************************************************************************************************************** +This section performs load-time registration for autocast wrappers. + +It's debatable at what level operations should be patched. We'd like casts to +be autograd-exposed and precede autograd history recording, so that for +lower_precision_fp ops, input tensors are saved for backward in +lower_precision_fp rather than fp32. 
Saving inputs in lower_precision_fp +can significantly reduce a model's memory footprint. + +Option 1 (strawman): Patch only at the level of explicit calls into +cudnn/cublas (cudnn_convolution, etc), because those are the code paths that are +guaranteed to use Tensor Cores, therefore they're the ones that will benefit +most from lower_precision_fp. Potential pitfall: convolutions (and other ops) +are wrapped in several layers of at::* calls. If one of those happens to record +autograd history, then we've lost the opportunity to save inputs in +lower_precision_fp. + +Option 2: Patch the Python-exposed surface of calls, to make 100% sure autograd +history recording can't sneak in ahead of autocast. This mirrors Apex most +closely. + +I think Option 2 is the right answer for all ops, not just convolutions. Option +2 is what I implement here. +*****************************************************************************************************************/ + +/******************************************************************************************************************** +Explicit registration for out-of-place ops + +The stuff below could be codegenned. Ed said +> you are going to have to write the function definition at some point, I +wouldn't try to get clever about it Therefore, for the moment, this is all +copy pasted in from VariableTypeEverything.cpp with appropriate substitutions. +********************************************************************************************************************/ + +} // namespace at::autocast + +#define ADD_NS(RAW_OP) at::RAW_OP + +#define _KERNEL_OVERLOAD_NARG_IMPL(_0, _1, _2, N, ...) N +#define _KERNEL_OVERLOAD_NARG(...) 
\ + C10_EXPAND_MSVC_WORKAROUND(_KERNEL_OVERLOAD_NARG_IMPL(__VA_ARGS__, 2, 1)) + +// Common cases where registration signature matches redispatch signature +// (that's why SIGNATURE is repeated in the WrapFunction instantiation) +#define KERNEL1(DISPATCHKEY, OP, POLICY) \ + m.impl( \ + TORCH_SELECTIVE_NAME("aten::" #OP), \ + &::at::autocast::WrapFunction< \ + ::at::autocast::CastPolicy::POLICY, \ + DISPATCHKEY, \ + decltype(ATEN_FN(OP)), \ + decltype(ATEN_FN(OP)), \ + &ATEN_FN(OP)>::type::call); + +#define KERNEL2(DISPATCHKEY, OP, OVERLOAD, POLICY) \ + m.impl( \ + TORCH_SELECTIVE_NAME("aten::" #OP "." #OVERLOAD), \ + &::at::autocast::WrapFunction< \ + ::at::autocast::CastPolicy::POLICY, \ + DISPATCHKEY, \ + decltype(ATEN_FN2(OP, OVERLOAD)), \ + decltype(ATEN_FN2(OP, OVERLOAD)), \ + &ATEN_FN2(OP, OVERLOAD)>::type::call); + +#define _KERNEL_DISPATCH(DISPATCHKEY, NARG, ...) \ + C10_CONCATENATE(KERNEL, NARG)(DISPATCHKEY, __VA_ARGS__) + +#define _KERNEL_IMPL(DISPATCHKEY, ...) \ + _KERNEL_DISPATCH(DISPATCHKEY, _KERNEL_OVERLOAD_NARG(__VA_ARGS__), __VA_ARGS__) + +// It will dispatch to KERNEL1 or KERNEL2 based on its inputs. +#define KERNEL(DISPATCHKEY, ...) _KERNEL_IMPL(DISPATCHKEY, __VA_ARGS__) + +// Less-common but still useful case: redispatching to a function +// with a new signature (e.g. appending a dtype) +#define KERNEL_DIFFERENT_REDISPATCH_SIGNATURE( \ + DISPATCHKEY, \ + REDISPATCH_FUNC, \ + REGISTER_NAME, \ + REGISTER_SIGNATURE, \ + REDISPATCH_SIGNATURE, \ + POLICY) \ + m.impl( \ + TORCH_SELECTIVE_NAME("aten::" REGISTER_NAME), \ + &::at::autocast::WrapFunction< \ + ::at::autocast::CastPolicy::POLICY, \ + DISPATCHKEY, \ + REGISTER_SIGNATURE, \ + REDISPATCH_SIGNATURE, \ + &REDISPATCH_FUNC>::type::call); + +// KERNEL_CPU/KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_CPU +// registration (OP, POLICY) or (OP, OVERLOAD, POLICY) for AutocastCPU +#define KERNEL_CPU(...) 
KERNEL(c10::DeviceType::CPU, __VA_ARGS__) + +#define KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_CPU( \ + REDISPATCH_FUNC, \ + REGISTER_NAME, \ + REGISTER_SIGNATURE, \ + REDISPATCH_SIGNATURE, \ + POLICY) \ + KERNEL_DIFFERENT_REDISPATCH_SIGNATURE( \ + c10::DeviceType::CPU, \ + REDISPATCH_FUNC, \ + REGISTER_NAME, \ + REGISTER_SIGNATURE, \ + REDISPATCH_SIGNATURE, \ + POLICY) + +// KERNEL_CUDA/KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_CUDA +// registration (OP, POLICY) or (OP, OVERLOAD, POLICY) for AutocastCUDA +#define KERNEL_CUDA(...) KERNEL(c10::DeviceType::CUDA, __VA_ARGS__) + +#define KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_CUDA( \ + REDISPATCH_FUNC, \ + REGISTER_NAME, \ + REGISTER_SIGNATURE, \ + REDISPATCH_SIGNATURE, \ + POLICY) \ + KERNEL_DIFFERENT_REDISPATCH_SIGNATURE( \ + c10::DeviceType::CUDA, \ + REDISPATCH_FUNC, \ + REGISTER_NAME, \ + REGISTER_SIGNATURE, \ + REDISPATCH_SIGNATURE, \ + POLICY) + +// KERNEL_MTIA/KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_MTIA +// registration (OP, POLICY) or (OP, OVERLOAD, POLICY) for AutocastMTIA +#define KERNEL_MTIA(...) KERNEL(c10::DeviceType::MTIA, __VA_ARGS__) + +#define KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_MTIA( \ + REDISPATCH_FUNC, \ + REGISTER_NAME, \ + REGISTER_SIGNATURE, \ + REDISPATCH_SIGNATURE, \ + POLICY) \ + KERNEL_DIFFERENT_REDISPATCH_SIGNATURE( \ + c10::DeviceType::MTIA, \ + REDISPATCH_FUNC, \ + REGISTER_NAME, \ + REGISTER_SIGNATURE, \ + REDISPATCH_SIGNATURE, \ + POLICY) + +// KERNEL_MAIA/KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_MAIA +// registration (OP, POLICY) or (OP, OVERLOAD, POLICY) for AutocastMAIA +#define KERNEL_MAIA(...) 
KERNEL(c10::DeviceType::MAIA, __VA_ARGS__) + +#define KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_MAIA( \ + REDISPATCH_FUNC, \ + REGISTER_NAME, \ + REGISTER_SIGNATURE, \ + REDISPATCH_SIGNATURE, \ + POLICY) \ + KERNEL_DIFFERENT_REDISPATCH_SIGNATURE( \ + c10::DeviceType::MAIA, \ + REDISPATCH_FUNC, \ + REGISTER_NAME, \ + REGISTER_SIGNATURE, \ + REDISPATCH_SIGNATURE, \ + POLICY) + +// KERNEL_XPU/KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_XPU +// registration (OP, POLICY) or (OP, OVERLOAD, POLICY) for AutocastXPU +#define KERNEL_XPU(...) KERNEL(c10::DeviceType::XPU, __VA_ARGS__) + +#define KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_XPU( \ + REDISPATCH_FUNC, \ + REGISTER_NAME, \ + REGISTER_SIGNATURE, \ + REDISPATCH_SIGNATURE, \ + POLICY) \ + KERNEL_DIFFERENT_REDISPATCH_SIGNATURE( \ + c10::DeviceType::XPU, \ + REDISPATCH_FUNC, \ + REGISTER_NAME, \ + REGISTER_SIGNATURE, \ + REDISPATCH_SIGNATURE, \ + POLICY) + +// KERNEL_PRIVATEUSEONE/KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_PRIVATEUSEONE +// registration (OP, POLICY) or (OP, OVERLOAD, POLICY) for AutocastPrivateUse1 +#define KERNEL_PRIVATEUSEONE(...) \ + KERNEL(c10::DeviceType::PrivateUse1, __VA_ARGS__) + +#define KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_PRIVATEUSEONE( \ + REDISPATCH_FUNC, \ + REGISTER_NAME, \ + REGISTER_SIGNATURE, \ + REDISPATCH_SIGNATURE, \ + POLICY) \ + KERNEL_DIFFERENT_REDISPATCH_SIGNATURE( \ + c10::DeviceType::PrivateUse1, \ + REDISPATCH_FUNC, \ + REGISTER_NAME, \ + REGISTER_SIGNATURE, \ + REDISPATCH_SIGNATURE, \ + POLICY) + +// KERNEL_MPS +// registration (OP, POLICY) or (OP, OVERLOAD, POLICY) for AutocastMPS +#define KERNEL_MPS(...) KERNEL(c10::DeviceType::MPS, __VA_ARGS__) + +// Op lists for different policies. +// To make sure other backends can reuse the policy op list. 
+#define AT_FORALL_LOWER_PRECISION_FP(_) \ + _(_convolution, deprecated) \ + _(_convolution) \ + _(conv1d) \ + _(conv2d) \ + _(conv3d) \ + _(conv_tbc) \ + _(conv_transpose1d) \ + _(conv_transpose2d, input) \ + _(conv_transpose3d, input) \ + _(convolution) \ + _(prelu) \ + _(addmm) \ + _(addmv) \ + _(addr) \ + _(matmul) \ + _(einsum) \ + _(mm) \ + _(mv) \ + _(linalg_vecdot) \ + _(linear) \ + _(addbmm) \ + _(baddbmm) \ + _(bmm) \ + _(chain_matmul) \ + _(linalg_multi_dot) \ + _(_thnn_fused_lstm_cell) \ + _(_thnn_fused_gru_cell) \ + _(lstm_cell) \ + _(gru_cell) \ + _(rnn_tanh_cell) \ + _(rnn_relu_cell) \ + _(_scaled_dot_product_flash_attention) \ + _(scaled_dot_product_attention) + +#define AT_FORALL_FP32(_) \ + _(acos) \ + _(asin) \ + _(cosh) \ + _(erfinv) \ + _(exp) \ + _(expm1) \ + _(log) \ + _(log10) \ + _(log2) \ + _(log1p) \ + _(reciprocal) \ + _(rsqrt) \ + _(sinh) \ + _(tan) \ + _(pow, Tensor_Scalar) \ + _(pow, Tensor_Tensor) \ + _(pow, Scalar) \ + _(softplus) \ + _(layer_norm) \ + _(native_layer_norm) \ + _(group_norm) \ + _(frobenius_norm, dim) \ + _(nuclear_norm) \ + _(nuclear_norm, dim) \ + _(cosine_similarity) \ + _(poisson_nll_loss) \ + _(cosine_embedding_loss) \ + _(nll_loss) \ + _(nll_loss2d) \ + _(hinge_embedding_loss) \ + _(kl_div) \ + _(l1_loss) \ + _(smooth_l1_loss) \ + _(huber_loss) \ + _(mse_loss) \ + _(margin_ranking_loss) \ + _(multilabel_margin_loss) \ + _(soft_margin_loss) \ + _(triplet_margin_loss) \ + _(multi_margin_loss) \ + _(binary_cross_entropy_with_logits) \ + _(dist) \ + _(pdist) \ + _(cdist) \ + _(renorm) \ + _(logsumexp) \ + _(upsample_nearest1d) \ + _(_upsample_nearest_exact1d) \ + _(upsample_nearest2d) \ + _(_upsample_nearest_exact2d) \ + _(upsample_nearest3d) \ + _(_upsample_nearest_exact3d) \ + _(upsample_linear1d) \ + _(upsample_bilinear2d) \ + _(_upsample_bilinear2d_aa) \ + _(upsample_trilinear3d) \ + _(upsample_bicubic2d) \ + _(_upsample_bicubic2d_aa) + +#define AT_FORALL_FP32_SET_OPT_DTYPE(_) \ + _(prod) \ + _(prod, dim_int) \ 
+ _(prod, dim_Dimname) \ + _(softmax, int) \ + _(softmax, Dimname) \ + _(log_softmax, int) \ + _(log_softmax, Dimname) \ + _(cumprod) \ + _(cumprod, dimname) \ + _(cumsum) \ + _(cumsum, dimname) \ + _(linalg_vector_norm) \ + _(linalg_matrix_norm) \ + _(linalg_matrix_norm, str_ord) \ + _(sum) \ + _(sum, dim_IntList) \ + _(sum, dim_DimnameList) + +#define AT_FORALL_DIFFERENT_REDISPATCH_SIGNATURE(_) \ + _(ADD_NS(norm), \ + "norm.Scalar", \ + Tensor(const Tensor&, const Scalar&), \ + Tensor(const Tensor&, const std::optional&, ScalarType), \ + fp32_append_dtype) \ + _(ADD_NS(norm), \ + "norm.ScalarOpt_dim", \ + Tensor(const Tensor&, const std::optional&, IntArrayRef, bool), \ + Tensor( \ + const Tensor&, \ + const std::optional&, \ + IntArrayRef, \ + bool, \ + ScalarType), \ + fp32_append_dtype) \ + _(ADD_NS(norm), \ + "norm.names_ScalarOpt_dim", \ + Tensor(const Tensor&, const std::optional&, DimnameList, bool), \ + Tensor( \ + const Tensor&, \ + const std::optional&, \ + DimnameList, \ + bool, \ + ScalarType), \ + fp32_append_dtype) + +#define AT_FORALL_PROMOTE(_) \ + _(addcdiv) \ + _(addcmul) \ + _(atan2) \ + _(bilinear) \ + _(cross) \ + _(dot) \ + _(vdot) \ + _(grid_sampler) \ + _(index_put) \ + _(tensordot) \ + _(scatter_add) + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." 
+#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/ceil_div.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/ceil_div.h new file mode 100644 index 0000000000000000000000000000000000000000..416bd91640662022385ed83ff64b98467b882898 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/ceil_div.h @@ -0,0 +1,29 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +#include +#include + +namespace at { + +/** + Computes ceil(a / b) +*/ +template >> +C10_ALWAYS_INLINE C10_HOST_DEVICE T ceil_div(T a, T b) { + return (a + b - 1) / b; +} + +/** + Computes ceil(a / b) * b; i.e., rounds up `a` to the next highest + multiple of b +*/ +template +C10_ALWAYS_INLINE C10_HOST_DEVICE T round_up(T a, T b) { + return ceil_div(a, b) * b; +} + +} // namespace at + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/code_template.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/code_template.h new file mode 100644 index 0000000000000000000000000000000000000000..593a55aa70c72f1357729410ae3c3cf8a395c966 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/code_template.h @@ -0,0 +1,250 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include + +#include +#include +#include +#include + +namespace at::jit { + +// A template environment is a mapping from template variable names, e.g., +// identifier (corresponding to $identifier) to their expansions. 
+// +// This template environment supports storing strings, numbers and lists +// of strings, and can be chained together (so that lookup proceeds in +// in the top level environment, and then recurses into a parent +// environment if the key is not found.) +struct TemplateEnv { + TemplateEnv() = default; + TemplateEnv(TemplateEnv& parent) : parent(&parent) {} + TemplateEnv(TemplateEnv&&) = delete; + TemplateEnv& operator=(const TemplateEnv& parent) = delete; + TemplateEnv& operator=(TemplateEnv&& parent) = delete; + ~TemplateEnv() = default; + + using string_list = std::vector; + + // Add a string 'v' to the map at key 'k'. + void s(const std::string& k, const std::string& v) { + strings_[k] = v; + lists_.erase(k); + } + + // Add a number 'v' to the map at key 'k' + template + void d(const std::string& k, const T& v) { + strings_[k] = std::to_string(v); + lists_.erase(k); + } + + // Retrieve the string representation of the value stored at 'k' from the map. + // Raises an exception if the key is not found. + const std::string& s(const std::string& k) const { + if (strings_.count(k) == 0) { + if (parent) { + return parent->s(k); + } + notFound(k); + } + return strings_.at(k); + } + + // Store a list of strings 'v' in the map at 'k'. + void v(const std::string& k, const string_list& v) { + lists_[k] = v; + strings_.erase(k); + } + + // Retrieve a list of strings stored at 'k' from the map. + // Raises an exception if the key is not found. + const string_list& v(const std::string& k) const { + if (lists_.count(k) == 0) { + if (parent) { + return parent->v(k); + } + notFound(k); + } + return lists_.at(k); + } + + // Test if a string 'k' is a string (as opposed to a list.) 
+ bool keyIsString(const std::string& k) const { + if (strings_.count(k) > 0) + return true; + if (lists_.count(k) > 0) + return false; + if (parent) + return parent->keyIsString(k); + notFound(k); + } + + private: + [[noreturn]] void notFound(const std::string& k) const { + std::stringstream ss; + ss << "key not found: " << k; + throw std::logic_error(ss.str()); + } + + std::unordered_map strings_; + std::unordered_map lists_; + TemplateEnv* parent{nullptr}; +}; + +/* +# Match $identifier or ${identifier} and replace with the value in env. +# If this identifier is at the beginning of whitespace on a line +# and its value is a list then it is treated as +# block substitution by indenting all lines of all elements. +# If the identifier is on a line starting with non-whitespace and a list +# then it is comma separated. ${,foo} will insert a comma before the list +# if this list is not empty and ${foo,} will insert one after. +*/ +struct CodeTemplate { + /* implicit */ CodeTemplate(std::string t) : template_text(std::move(t)) {} + + std::string format(const TemplateEnv& env) const { + std::stringstream out; + size_t pos = 0; + size_t indent = 0; + bool all_whitespace = true; + while (pos < template_text.size()) { + char c = template_text[pos]; + if (c == '$') { + std::stringstream kss; + bool comma_before = false; + bool comma_after = false; + size_t new_pos = parseKey(pos, kss, comma_before, comma_after); + std::string k = kss.str(); + bool is_string = env.keyIsString(k); + if (all_whitespace) { + if (is_string) + emitStringWithIndents(out, indent, env.s(k)); + else + emitLinesIndented(out, indent, env.v(k)); + } else { + if (is_string) + out << env.s(k); + else + emitCommaSeparatedList(out, env.v(k), comma_before, comma_after); + } + all_whitespace = false; + pos = new_pos; + } else { + out << c; + if (!isspace(c)) + all_whitespace = false; + indent++; + if (c == '\n') { + indent = 0; + all_whitespace = true; + } + pos++; + } + } + return out.str(); + } + + private: 
+ using string_list = std::vector; + char charAt(size_t p) const { + if (p >= template_text.size()) + throw std::logic_error("EOS found in key"); + return template_text[p]; + } + size_t parseKey( + size_t pos, + std::ostream& k, + bool& comma_before, + bool& comma_after) const { + comma_before = false; + comma_after = false; + pos++; + if (charAt(pos) == '{') { + pos++; + if (charAt(pos) == ',') { + comma_before = true; + pos++; + } + pos = parseIdent(pos, k); + if (charAt(pos) == ',') { + comma_after = true; + pos++; + } + if (charAt(pos) != '}') + throw std::logic_error("missing terminating '}'"); + pos++; + return pos; + } else { + return parseIdent(pos, k); + } + } + size_t parseIdent(size_t pos, std::ostream& k) const { + while (pos < template_text.size() && + (isalnum(template_text[pos]) || template_text[pos] == '_')) { + k << template_text[pos]; + pos++; + } + return pos; + } + void emitCommaSeparatedList( + std::ostream& out, + const string_list& strings, + bool comma_before, + bool comma_after) const { + if (comma_before && !strings.empty()) + out << ", "; + for (const auto i : c10::irange(strings.size())) { + if (i > 0) + out << ", "; + out << strings[i]; + } + if (comma_after && !strings.empty()) + out << ", "; + } + // These indentation functions follow the convention that they never emit + // leading or trailing newlines when the input string does not have leading + // or trailing newlines. It's the responsibility of the calling function + // to indent correctly in the context. 
+ void emitIndent(std::ostream& out, size_t indent) const { + for ([[maybe_unused]] const auto i : c10::irange(indent)) { + out << ' '; + } + } + void emitStringWithIndents( + std::ostream& out, + size_t indent, + const std::string& str) const { + for (auto c : str) { + out << c; + if (c == '\n') { + emitIndent(out, indent); + } + } + } + void emitLinesIndented( + std::stringstream& out, + size_t indent, + const string_list& strings) const { + for (const auto i : c10::irange(strings.size())) { + if (i > 0) + emitIndent(out, indent); + emitStringWithIndents(out, indent, strings[i]); + if (i + 1 != strings.size()) + out << '\n'; + } + } + std::string template_text; +}; + +static inline std::string format(const std::string& fmt, TemplateEnv& env) { + return CodeTemplate(fmt).format(env); +} + +} // namespace at::jit + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/cpp_custom_type_hack.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/cpp_custom_type_hack.h new file mode 100644 index 0000000000000000000000000000000000000000..73137aef68809ce395dba1245e58195996f1f607 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/cpp_custom_type_hack.h @@ -0,0 +1,115 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP 
STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP + +// YOU ARE IN THE WRONG PLACE! TURN BACK NOW! + +// This code was a temporary hack to enable embedding arbitrary C++ structures +// into Tensors. THIS IS UNSAFE AND IS NOT SUPPORTED. IF YOU USE THIS CODE, +// IT __WILL__ BREAK. + +// This code has been superseded by custom classes: +// https://pytorch.org/tutorials/advanced/torch_script_custom_classes.html + +// Please use custom classes and **DO NOT ADD MORE CALLSITES TO THINGS DEFINED +// IN THIS FILE**. 
+ +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP + +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#endif + +namespace at::cpp_custom_type_hack { + +template +[[deprecated( + "Use custom classes instead: " + "https://pytorch.org/tutorials/advanced/torch_script_custom_classes.html")]] bool +isa(const Tensor& packed) { + return (packed.scalar_type() == kByte) && + (packed.storage().data_ptr().get_deleter() == + caffe2::TypeMeta::Make().deleteFn()); +} + +template +[[deprecated( + "Use custom classes instead: " + 
"https://pytorch.org/tutorials/advanced/torch_script_custom_classes.html")]] T& +cast(const Tensor& packed) { + TORCH_CHECK( + packed.scalar_type() == kByte, "Expected temporary cpp type wrapper"); + TORCH_CHECK( + packed.storage().data_ptr().get_deleter() == + caffe2::TypeMeta::Make().deleteFn(), + "Expected temporary cpp type wrapper of type ", + caffe2::TypeMeta::TypeName()); + return *reinterpret_cast(packed.storage().data_ptr().get()); +} + +template +[[deprecated( + "Use custom classes instead: " + "https://pytorch.org/tutorials/advanced/torch_script_custom_classes.html")]] Tensor +create(std::unique_ptr ptr, TensorOptions options) { + // None of this should trace, so turn off Tracer dispatching + at::AutoDispatchBelowADInplaceOrView guard; // TODO: remove + at::tracer::impl::NoTracerDispatchMode tracer_guard; + + // We store this instance away in a Tensor and register a deleter function + // so that we do not leak memory. On the other side, we pull out the storage's + // data_ptr and get the right typed pointer. + void* raw_ptr = ptr.release(); + at::DataPtr at_ptr( + raw_ptr, raw_ptr, caffe2::TypeMeta::Make().deleteFn(), at::kCPU); + + // size doesn't really matter, but we can align it to the actual size + // returning variables because one likely want to use this hack from python + auto retval = at::empty({sizeof(T)}, options.device(kCPU).dtype(at::kByte)); + retval.storage().set_data_ptr_noswap(std::move(at_ptr)); + return retval; +} + +} // namespace at::cpp_custom_type_hack + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." 
+#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/div_rtn.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/div_rtn.h new file mode 100644 index 0000000000000000000000000000000000000000..888e67703821286fa0e3150e7463aeaf4518e175 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/div_rtn.h @@ -0,0 +1,16 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +// Integer division rounding to -Infinity +template +static inline T div_rtn(T x, T y) { + int q = x / y; + int r = x % y; + if ((r != 0) && ((r < 0) != (y < 0))) + --q; + return q; +} + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/dlpack.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/dlpack.h new file mode 100644 index 0000000000000000000000000000000000000000..c159b677e79e02f8610f72532f1ee4e5487de322 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/dlpack.h @@ -0,0 +1,646 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +/*! + * Copyright (c) 2017 - by Contributors + * \file dlpack.h + * \brief The common header of DLPack. + */ +#ifndef DLPACK_DLPACK_H_ +#define DLPACK_DLPACK_H_ + +/** + * \brief Compatibility with C++ + */ +#ifdef __cplusplus +#define DLPACK_EXTERN_C extern "C" +#else +#define DLPACK_EXTERN_C +#endif + +/*! \brief The current major version of dlpack */ +#define DLPACK_MAJOR_VERSION 1 + +/*! \brief The current minor version of dlpack */ +#define DLPACK_MINOR_VERSION 3 + +/*! 
\brief DLPACK_DLL prefix for windows */ +#ifdef _WIN32 +#ifdef DLPACK_EXPORTS +#define DLPACK_DLL __declspec(dllexport) +#else +#define DLPACK_DLL __declspec(dllimport) +#endif +#else +#define DLPACK_DLL +#endif + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +/*! + * \brief The DLPack version. + * + * A change in major version indicates that we have changed the + * data layout of the ABI - DLManagedTensorVersioned. + * + * A change in minor version indicates that we have added new + * code, such as a new device type, but the ABI is kept the same. + * + * If an obtained DLPack tensor has a major version that disagrees + * with the version number specified in this header file + * (i.e. major != DLPACK_MAJOR_VERSION), the consumer must call the deleter + * (and it is safe to do so). It is not safe to access any other fields + * as the memory layout will have changed. + * + * In the case of a minor version mismatch, the tensor can be safely used as + * long as the consumer knows how to interpret all fields. Minor version + * updates indicate the addition of enumeration values. + */ +typedef struct { + /*! \brief DLPack major version. */ + uint32_t major; + /*! \brief DLPack minor version. */ + uint32_t minor; +} DLPackVersion; + +/*! + * \brief The device type in DLDevice. + */ +#ifdef __cplusplus +typedef enum : int32_t { +#else +typedef enum { +#endif + /*! \brief CPU device */ + kDLCPU = 1, + /*! \brief CUDA GPU device */ + kDLCUDA = 2, + /*! + * \brief Pinned CUDA CPU memory by cudaMallocHost + */ + kDLCUDAHost = 3, + /*! \brief OpenCL devices. */ + kDLOpenCL = 4, + /*! \brief Vulkan buffer for next generation graphics. */ + kDLVulkan = 7, + /*! \brief Metal for Apple GPU. */ + kDLMetal = 8, + /*! \brief Verilog simulator buffer */ + kDLVPI = 9, + /*! \brief ROCm GPUs for AMD GPUs */ + kDLROCM = 10, + /*! + * \brief Pinned ROCm CPU memory allocated by hipMallocHost + */ + kDLROCMHost = 11, + /*! 
+ * \brief Reserved extension device type, + * used for quickly test extension device + * The semantics can differ depending on the implementation. + */ + kDLExtDev = 12, + /*! + * \brief CUDA managed/unified memory allocated by cudaMallocManaged + */ + kDLCUDAManaged = 13, + /*! + * \brief Unified shared memory allocated on a oneAPI non-partititioned + * device. Call to oneAPI runtime is required to determine the device + * type, the USM allocation type and the sycl context it is bound to. + * + */ + kDLOneAPI = 14, + /*! \brief GPU support for next generation WebGPU standard. */ + kDLWebGPU = 15, + /*! \brief Qualcomm Hexagon DSP */ + kDLHexagon = 16, + /*! \brief Microsoft MAIA devices */ + kDLMAIA = 17, + /*! \brief AWS Trainium */ + kDLTrn = 18, +} DLDeviceType; + +/*! + * \brief A Device for Tensor and operator. + */ +typedef struct { + /*! \brief The device type used in the device. */ + DLDeviceType device_type; + /*! + * \brief The device index. + * For vanilla CPU memory, pinned memory, or managed memory, this is set to 0. + */ + int32_t device_id; +} DLDevice; + +/*! + * \brief The type code options DLDataType. + */ +typedef enum { + /*! \brief signed integer */ + kDLInt = 0U, + /*! \brief unsigned integer */ + kDLUInt = 1U, + /*! \brief IEEE floating point */ + kDLFloat = 2U, + /*! + * \brief Opaque handle type, reserved for testing purposes. + * Frameworks need to agree on the handle data type for the exchange to be well-defined. + */ + kDLOpaqueHandle = 3U, + /*! \brief bfloat16 */ + kDLBfloat = 4U, + /*! + * \brief complex number + * (C/C++/Python layout: compact struct per complex number) + */ + kDLComplex = 5U, + /*! \brief boolean */ + kDLBool = 6U, + /*! \brief FP8 data types */ + kDLFloat8_e3m4 = 7U, + kDLFloat8_e4m3 = 8U, + kDLFloat8_e4m3b11fnuz = 9U, + kDLFloat8_e4m3fn = 10U, + kDLFloat8_e4m3fnuz = 11U, + kDLFloat8_e5m2 = 12U, + kDLFloat8_e5m2fnuz = 13U, + kDLFloat8_e8m0fnu = 14U, + /*! 
\brief FP6 data types + * Setting bits != 6 is currently unspecified, and the producer must ensure it is set + * while the consumer must stop importing if the value is unexpected. + */ + kDLFloat6_e2m3fn = 15U, + kDLFloat6_e3m2fn = 16U, + /*! \brief FP4 data types + * Setting bits != 4 is currently unspecified, and the producer must ensure it is set + * while the consumer must stop importing if the value is unexpected. + */ + kDLFloat4_e2m1fn = 17U, +} DLDataTypeCode; + +/*! + * \brief The data type the tensor can hold. The data type is assumed to follow the + * native endian-ness. An explicit error message should be raised when attempting to + * export an array with non-native endianness + * + * Examples + * - float: type_code = 2, bits = 32, lanes = 1 + * - float4(vectorized 4 float): type_code = 2, bits = 32, lanes = 4 + * - int8: type_code = 0, bits = 8, lanes = 1 + * - std::complex: type_code = 5, bits = 64, lanes = 1 + * - bool: type_code = 6, bits = 8, lanes = 1 (as per common array library convention, the underlying storage size of bool is 8 bits) + * - float8_e4m3: type_code = 8, bits = 8, lanes = 1 (packed in memory) + * - float6_e3m2fn: type_code = 16, bits = 6, lanes = 1 (packed in memory) + * - float4_e2m1fn: type_code = 17, bits = 4, lanes = 1 (packed in memory) + * + * When a sub-byte type is packed, DLPack requires the data to be in little bit-endian, i.e., + * for a packed data set D ((D >> (i * bits)) && bit_mask) stores the i-th element. + */ +typedef struct { + /*! + * \brief Type code of base types. + * We keep it uint8_t instead of DLDataTypeCode for minimal memory + * footprint, but the value should be one of DLDataTypeCode enum values. + * */ + uint8_t code; + /*! + * \brief Number of bits, common choices are 8, 16, 32. + */ + uint8_t bits; + /*! \brief Number of lanes in the type, used for vector types. */ + uint16_t lanes; +} DLDataType; + +/*! + * \brief Plain C Tensor object, does not manage memory. + */ +typedef struct { + /*! 
+ * \brief The data pointer points to the allocated data. This will be CUDA + * device pointer or cl_mem handle in OpenCL. It may be opaque on some device + * types. This pointer is always aligned to 256 bytes as in CUDA. The + * `byte_offset` field should be used to point to the beginning of the data. + * + * Note that as of Nov 2021, multiple libraries (CuPy, PyTorch, TensorFlow, + * TVM, perhaps others) do not adhere to this 256 byte alignment requirement + * on CPU/CUDA/ROCm, and always use `byte_offset=0`. This must be fixed + * (after which this note will be updated); at the moment it is recommended + * to not rely on the data pointer being correctly aligned. + * + * For given DLTensor, the size of memory required to store the contents of + * data is calculated as follows: + * + * \code{.c} + * static inline size_t GetDataSize(const DLTensor* t) { + * size_t size = 1; + * for (tvm_index_t i = 0; i < t->ndim; ++i) { + * size *= t->shape[i]; + * } + * size *= (t->dtype.bits * t->dtype.lanes + 7) / 8; + * return size; + * } + * \endcode + * + * Note that if the tensor is of size zero, then the data pointer should be + * set to `NULL`. + */ + void* data; + /*! \brief The device of the tensor */ + DLDevice device; + /*! \brief Number of dimensions */ + int32_t ndim; + /*! \brief The data type of the pointer*/ + DLDataType dtype; + /*! + * \brief The shape of the tensor + * + * When ndim == 0, shape can be set to NULL. + */ + int64_t* shape; + /*! + * \brief strides of the tensor (in number of elements, not bytes), + * can not be NULL if ndim != 0, must points to + * an array of ndim elements that specifies the strides, + * so consumer can always rely on strides[dim] being valid for 0 <= dim < ndim. + * + * When ndim == 0, strides can be set to NULL. + * + * \note Before DLPack v1.2, strides can be NULL to indicate contiguous data. + * This is not allowed in DLPack v1.2 and later. The rationale + * is to simplify the consumer handling. 
+ */ + int64_t* strides; + /*! \brief The offset in bytes to the beginning pointer to data */ + uint64_t byte_offset; +} DLTensor; + +/*! + * \brief C Tensor object, manage memory of DLTensor. This data structure is + * intended to facilitate the borrowing of DLTensor by another framework. It is + * not meant to transfer the tensor. When the borrowing framework doesn't need + * the tensor, it should call the deleter to notify the host that the resource + * is no longer needed. + * + * \note This data structure is used as Legacy DLManagedTensor + * in DLPack exchange and is deprecated after DLPack v0.8 + * Use DLManagedTensorVersioned instead. + * This data structure may get renamed or deleted in future versions. + * + * \sa DLManagedTensorVersioned + */ +typedef struct DLManagedTensor { + /*! \brief DLTensor which is being memory managed */ + DLTensor dl_tensor; + /*! \brief the context of the original host framework of DLManagedTensor in + * which DLManagedTensor is used in the framework. It can also be NULL. + */ + void * manager_ctx; + /*! + * \brief Destructor - this should be called + * to destruct the manager_ctx which backs the DLManagedTensor. It can be + * NULL if there is no way for the caller to provide a reasonable destructor. + * The destructor deletes the argument self as well. + */ + void (*deleter)(struct DLManagedTensor * self); +} DLManagedTensor; + +// bit masks used in the DLManagedTensorVersioned + +/*! \brief bit mask to indicate that the tensor is read only. */ +#define DLPACK_FLAG_BITMASK_READ_ONLY (1UL << 0UL) + +/*! + * \brief bit mask to indicate that the tensor is a copy made by the producer. + * + * If set, the tensor is considered solely owned throughout its lifetime by the + * consumer, until the producer-provided deleter is invoked. + */ +#define DLPACK_FLAG_BITMASK_IS_COPIED (1UL << 1UL) + +/*! + * \brief bit mask to indicate that whether a sub-byte type is packed or padded. 
+ * + * The default for sub-byte types (ex: fp4/fp6) is assumed packed. This flag can + * be set by the producer to signal that a tensor of sub-byte type is padded. + */ +#define DLPACK_FLAG_BITMASK_IS_SUBBYTE_TYPE_PADDED (1UL << 2UL) + +/*! + * \brief A versioned and managed C Tensor object, manage memory of DLTensor. + * + * This data structure is intended to facilitate the borrowing of DLTensor by + * another framework. It is not meant to transfer the tensor. When the borrowing + * framework doesn't need the tensor, it should call the deleter to notify the + * host that the resource is no longer needed. + * + * \note This is the current standard DLPack exchange data structure. + */ +typedef struct DLManagedTensorVersioned { + /*! + * \brief The API and ABI version of the current managed Tensor + */ + DLPackVersion version; + /*! + * \brief the context of the original host framework. + * + * Stores DLManagedTensorVersioned is used in the + * framework. It can also be NULL. + */ + void *manager_ctx; + /*! + * \brief Destructor. + * + * This should be called to destruct manager_ctx which holds the DLManagedTensorVersioned. + * It can be NULL if there is no way for the caller to provide a reasonable + * destructor. The destructor deletes the argument self as well. + */ + void (*deleter)(struct DLManagedTensorVersioned *self); + /*! + * \brief Additional bitmask flags information about the tensor. + * + * By default the flags should be set to 0. + * + * \note Future ABI changes should keep everything until this field + * stable, to ensure that deleter can be correctly called. + * + * \sa DLPACK_FLAG_BITMASK_READ_ONLY + * \sa DLPACK_FLAG_BITMASK_IS_COPIED + */ + uint64_t flags; + /*! 
\brief DLTensor which is being memory managed */ + DLTensor dl_tensor; +} DLManagedTensorVersioned; + +//---------------------------------------------------------------------- +// DLPack `__dlpack_c_exchange_api__` fast exchange protocol definitions +//---------------------------------------------------------------------- +/*! + * \brief Request a producer library to create a new tensor. + * + * Create a new `DLManagedTensorVersioned` within the context of the producer + * library. The allocation is defined via the prototype DLTensor. + * + * This function is exposed by the framework through the DLPackExchangeAPI. + * + * \param prototype The prototype DLTensor. Only the dtype, ndim, shape, + * and device fields are used. + * \param out The output DLManagedTensorVersioned. + * \param error_ctx Context for `SetError`. + * \param SetError The function to set the error. + * \return The owning DLManagedTensorVersioned* or NULL on failure. + * SetError is called exactly when NULL is returned (the implementer + * must ensure this). + * \note - As a C function, must not thrown C++ exceptions. + * - Error propagation via SetError to avoid any direct need + * of Python API. Due to this `SetError` may have to ensure the GIL is + * held since it will presumably set a Python error. + * + * \sa DLPackExchangeAPI + */ +typedef int (*DLPackManagedTensorAllocator)( // + DLTensor* prototype, DLManagedTensorVersioned** out, void* error_ctx, // + void (*SetError)(void* error_ctx, const char* kind, const char* message) // +); + +/*! + * \brief Exports a PyObject* Tensor/NDArray to a DLManagedTensorVersioned. + * + * This function does not perform any stream synchronization. The consumer should query + * DLPackCurrentWorkStream to get the current work stream and launch kernels on it. + * + * This function is exposed by the framework through the DLPackExchangeAPI. + * + * \param py_object The Python object to convert. 
Must have the same type + * as the one the `DLPackExchangeAPI` was discovered from. + * \return The owning DLManagedTensorVersioned* or NULL on failure with a + * Python exception set. If the data cannot be described using DLPack + * this should be a BufferError if possible. + * \note - As a C function, must not thrown C++ exceptions. + * + * \sa DLPackExchangeAPI, DLPackCurrentWorkStream + */ +typedef int (*DLPackManagedTensorFromPyObjectNoSync)( // + void* py_object, // + DLManagedTensorVersioned** out // +); + +/*! + * \brief Exports a PyObject* Tensor/NDArray to a provided DLTensor. + * + * This function provides a faster interface for temporary, non-owning, + * exchange. The producer (implementer) still owns the memory of data, strides, + * shape. The liveness of the DLTensor and the data it views is only guaranteed + * until control is returned. + * + * This function currently assumes that the producer (implementer) can fill + * in the DLTensor shape and strides without the need for temporary allocations. + * + * This function does not perform any stream synchronization. The consumer + * should query DLPackCurrentWorkStream to get the current work stream and + * launch kernels on it. + * + * This function is exposed by the framework through the DLPackExchangeAPI. + * + * \param py_object The Python object to convert. Must have the same type + * as the one the `DLPackExchangeAPI` was discovered from. + * \param out The output DLTensor, whose space is pre-allocated on stack. + * \return 0 on success, -1 on failure with a Python exception set. + * \note - As a C function, must not thrown C++ exceptions. + * + * \sa DLPackExchangeAPI, DLPackCurrentWorkStream + */ +typedef int (*DLPackDLTensorFromPyObjectNoSync)( // + void* py_object, // + DLTensor* out // +); + +/*! + * \brief Obtain the current work stream of a device. + * + * Obtain the current work stream of a device from the producer framework. 
+ * For example, it should map to torch.cuda.current_stream in PyTorch. + * + * When device_type is kDLCPU, the consumer do not have to query the stream + * and the producer can simply return NULL when queried. + * The consumer do not have to do anything on stream sync or setting. + * So CPU only framework can just provide a dummy implementation that + * always set out_current_stream[0] to NULL. + * + * \param device_type The device type. + * \param device_id The device id. + * \param out_current_stream The output current work stream. + * + * \return 0 on success, -1 on failure with a Python exception set. + * \note - As a C function, must not thrown C++ exceptions. + * + * \sa DLPackExchangeAPI + */ +typedef int (*DLPackCurrentWorkStream)( // + DLDeviceType device_type, // + int32_t device_id, // + void** out_current_stream // +); + +/*! + * \brief Imports a DLManagedTensorVersioned to a PyObject* Tensor/NDArray. + * + * Convert an owning DLManagedTensorVersioned* to the Python tensor of the + * producer (implementer) library with the correct type. + * + * This function does not perform any stream synchronization. + * + * This function is exposed by the framework through the DLPackExchangeAPI. + * + * \param tensor The DLManagedTensorVersioned to convert the ownership of the + * tensor is stolen. + * \param out_py_object The output Python object. + * \return 0 on success, -1 on failure with a Python exception set. + * + * \sa DLPackExchangeAPI + */ +typedef int (*DLPackManagedTensorToPyObjectNoSync)( // + DLManagedTensorVersioned* tensor, // + void** out_py_object // +); + +/*! + * \brief DLPackExchangeAPI stable header. + * \sa DLPackExchangeAPI + */ +typedef struct DLPackExchangeAPIHeader { + /*! + * \brief The provided DLPack version the consumer must check major version + * compatibility before using this struct. + */ + DLPackVersion version; + /*! + * \brief Optional pointer to an older DLPackExchangeAPI in the chain. 
+ * + * It must be NULL if the framework does not support older versions. + * If the current major version is larger than the one supported by the + * consumer, the consumer may walk this to find an earlier supported version. + * + * \sa DLPackExchangeAPI + */ + struct DLPackExchangeAPIHeader* prev_api; +} DLPackExchangeAPIHeader; + +/*! + * \brief Framework-specific function pointers table for DLPack exchange. + * + * Additionally to `__dlpack__()` we define a C function table sharable by + * + * Python implementations via `__dlpack_c_exchange_api__`. + * This attribute must be set on the type as a Python PyCapsule + * with name "dlpack_exchange_api". + * + * A consumer library may use a pattern such as: + * + * \code + * + * PyObject *api_obj = type(tensor_obj).__dlpack_c_exchange_api__; // as C-code + * MyDLPackExchangeAPI *api = PyCapsule_GetPointer(api_obj, "dlpack_exchange_api"); + * if (api == NULL && PyErr_Occurred()) { goto handle_error; } + * + * \endcode + * + * Note that this must be defined on the type. The consumer should look up the + * attribute on the type and may cache the result for each unique type. 
+ * + * The precise API table is given by: + * \code + * struct MyDLPackExchangeAPI : public DLPackExchangeAPI { + * MyDLPackExchangeAPI() { + * header.version.major = DLPACK_MAJOR_VERSION; + * header.version.minor = DLPACK_MINOR_VERSION; + * header.prev_version_api = nullptr; + * + * managed_tensor_allocator = MyDLPackManagedTensorAllocator; + * managed_tensor_from_py_object_no_sync = MyDLPackManagedTensorFromPyObjectNoSync; + * managed_tensor_to_py_object_no_sync = MyDLPackManagedTensorToPyObjectNoSync; + * dltensor_from_py_object_no_sync = MyDLPackDLTensorFromPyObjectNoSync; + * current_work_stream = MyDLPackCurrentWorkStream; + * } + * + * static const DLPackExchangeAPI* Global() { + * static MyDLPackExchangeAPI inst; + * return &inst; + * } + * }; + * \endcode + * + * Guidelines for leveraging DLPackExchangeAPI: + * + * There are generally two kinds of consumer needs for DLPack exchange: + * - N0: library support, where consumer.kernel(x, y, z) would like to run a kernel + * with the data from x, y, z. The consumer is also expected to run the kernel with the same + * stream context as the producer. For example, when x, y, z is torch.Tensor, + * consumer should query exchange_api->current_work_stream to get the + * current stream and launch the kernel with the same stream. + * This setup is necessary for no synchronization in kernel launch and maximum compatibility + * with CUDA graph capture in the producer. + * This is the desirable behavior for library extension support for frameworks like PyTorch. + * - N1: data ingestion and retention + * + * Note that obj.__dlpack__() API should provide useful ways for N1. + * The primary focus of the current DLPackExchangeAPI is to enable faster exchange N0 + * with the support of the function pointer current_work_stream. + * + * Array/Tensor libraries should statically create and initialize this structure + * then return a pointer to DLPackExchangeAPI as an int value in Tensor/Array. 
+ * The DLPackExchangeAPI* must stay alive throughout the lifetime of the process. + * + * One simple way to do so is to create a static instance of DLPackExchangeAPI + * within the framework and return a pointer to it. The following code + * shows an example to do so in C++. It should also be reasonably easy + * to do so in other languages. + */ +typedef struct DLPackExchangeAPI { + /*! + * \brief The header that remains stable across versions. + */ + DLPackExchangeAPIHeader header; + /*! + * \brief Producer function pointer for DLPackManagedTensorAllocator + * This function must not be NULL. + * \sa DLPackManagedTensorAllocator + */ + DLPackManagedTensorAllocator managed_tensor_allocator; + /*! + * \brief Producer function pointer for DLPackManagedTensorFromPyObject + * This function must be not NULL. + * \sa DLPackManagedTensorFromPyObject + */ + DLPackManagedTensorFromPyObjectNoSync managed_tensor_from_py_object_no_sync; + /*! + * \brief Producer function pointer for DLPackManagedTensorToPyObject + * This function must be not NULL. + * \sa DLPackManagedTensorToPyObject + */ + DLPackManagedTensorToPyObjectNoSync managed_tensor_to_py_object_no_sync; + /*! + * \brief Producer function pointer for DLPackDLTensorFromPyObject + * This function can be NULL when the producer does not support this function. + * \sa DLPackDLTensorFromPyObjectNoSync + */ + DLPackDLTensorFromPyObjectNoSync dltensor_from_py_object_no_sync; + /*! + * \brief Producer function pointer for DLPackCurrentWorkStream + * This function must be not NULL. + * \sa DLPackCurrentWorkStream + */ + DLPackCurrentWorkStream current_work_stream; +} DLPackExchangeAPI; + +#ifdef __cplusplus +} // DLPACK_EXTERN_C +#endif +#endif // DLPACK_DLPACK_H_ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." 
+#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/jit_macros.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/jit_macros.h new file mode 100644 index 0000000000000000000000000000000000000000..7ca584767df4958f4c97a829da9708843cea30f8 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/jit_macros.h @@ -0,0 +1,12 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +#include +#include + +// AT_USE_JITERATOR(), controls whether we jit some elementwise kernels +#define AT_USE_JITERATOR() true +#define jiterator_stringify(...) std::string(#__VA_ARGS__); + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/jiterator_macros.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/jiterator_macros.h new file mode 100644 index 0000000000000000000000000000000000000000..298e030a2c42f71a795c8534752bbf9df99ebba9 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/jiterator_macros.h @@ -0,0 +1,43 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +#include +#include + +#define JITERATOR_HOST_DEVICE C10_HOST_DEVICE +#if defined(_MSC_VER) && defined(__CUDACC__) +// NVRTC on Windows errors if __host__ __device__ attribute is +// present on kernel. 
+// error: attribute "__host__" does not apply here +// error: attribute "__device__" does not apply here +#define JITERATOR_HOST_DEVICE +#endif + +// jiterator_also_stringify_as macro is used to define code (for CPU/ROCm) +// and generate code string for `jiterator` (only when compiling for CUDA). +// Usage : +// jiterator_also_stringify_as( +// jiterator_code(template T identity(T x) { return x; }), +// identity_string); +// This will define the template `identity` as present in code and +// also define `std::string identity_string` with the code as the string +// if this is being compiled for CUDA. + +// `jiterator_code` macro is to deal with `,` in the kernel code. +// These `,`s confuse the preprocessor into thinking we are passing +// multiple arguments to the macro. +#define jiterator_code(...) __VA_ARGS__ +#if defined(__CUDACC__) || defined(__HIPCC__) +// CPU and CUDA and ROCm case +#define stringify_code(...) #__VA_ARGS__ +#define jiterator_also_stringify_as(code, str_name) \ + code /* define the function */ \ + const std::string str_name = std::string(stringify_code(code)); +#else +// CPU only or CPU and ROCm case +// Only needs the function +#define jiterator_also_stringify_as(code, str_name) code +#endif + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." 
+#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/record_function.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/record_function.h new file mode 100644 index 0000000000000000000000000000000000000000..c0ca1a69ee445bd52776a669da4f012113b9deed --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/record_function.h @@ -0,0 +1,799 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace c10 { +class TORCH_API OperatorHandle; +} + +namespace at { + +// Function name to record NCCL metadata +extern TORCH_API const std::string kParamCommsCallName; + +// Kind of record function scope; +enum class C10_API_ENUM RecordScope : uint8_t { + // c10/ATen ops, autograd nodes + FUNCTION = 0, + // Functions/nodes called from the autograd + BACKWARD_FUNCTION, + // TorchScript functions, methods + TORCHSCRIPT_FUNCTION, + // Kernel Function dtype Tag + KERNEL_FUNCTION_DTYPE, + // Torchbind custom class, + CUSTOM_CLASS, + // Generic Build Feature + BUILD_FEATURE, + // Kernel Function dtype Tag + LITE_INTERPRETER, + // User defined scope (e.g. 
with record_function()) + USER_SCOPE, + // Scopes for static runtime, a specialized TorchScript interpreter + STATIC_RUNTIME_OP, + STATIC_RUNTIME_MODEL, + NUM_SCOPES, // must be the last in the list +}; + +} // namespace at + +namespace std { +template <> +struct hash { + size_t operator()(const at::RecordScope& sc) const { + return static_cast(sc); + } +}; +} // namespace std + +namespace at { + +struct TORCH_API StringView { + StringView() : StringView(nullptr) {} + explicit StringView(const char* str_ptr) + : owned_str_ptr_(nullptr), str_ptr_(str_ptr) {} + explicit StringView(std::string str) + : owned_str_ptr_(std::make_shared(std::move(str))), + str_ptr_(owned_str_ptr_->c_str()) {} + + const char* str() const { + return str_ptr_; + } + + friend std::ostream& operator<<(std::ostream& os, const StringView& dt) { + os << dt.str(); + return os; + } + + friend bool operator==(const StringView& lhs, const StringView& rhs) { + return strcmp(lhs.str(), rhs.str()) == 0; + } + + friend bool operator!=(const StringView& lhs, const StringView& rhs) { + return !(lhs == rhs); + } + + private: + std::shared_ptr owned_str_ptr_; + const char* str_ptr_; +}; + +// Soft limit on the number of callbacks to use; +constexpr std::size_t kSoftLimitCallbacks = 4; + +// An abstract base class for various observer contexts that can be attached to +// the RecordFunction. 
+struct ObserverContext { + virtual ~ObserverContext() = default; + + protected: + ObserverContext() = default; +}; + +typedef c10::SmallVector CallbackHandles; +typedef c10::SmallVector, kSoftLimitCallbacks> + ObserverContextList; +typedef uint64_t RecordFunctionHandle; +struct RecordFunction; + +// +// PyTorch callbacks/observers API: +// + +/** + * RecordFunctionCallback represents a pair of callbacks to be used with + * RecordFunction, members: + * start, end - the callbacks to run when entering and exiting the scope; + * optionally, the start callback may return an ObserverContext which will + * be passed to the end callback, use appropriate constructor accordingly. + * needs_inputs - whether the callbacks need the inputs passed from the + * observed function/range; NOTE: passing the inputs incurs an additional + * overhead; sampling_probability - if not 1.0, then the callback is + * probabilistically sampled to run; NOTE: start and end callbacks always run as + * a pair and are sampled together; scopes - types of scopes to execute the + * callbacks on (see RecordScope); passing empty set means the callbacks will be + * executed for all possible scope types should_run - optional function that + * returns whether this callback should run; overwrites the effect of setting + * sampling_probability + */ +class TORCH_API RecordFunctionCallback { + public: + using StartCallback = + std::unique_ptr (*)(const RecordFunction&); + using EndCallback = void (*)(const RecordFunction&, ObserverContext*); + + // This interface supports observers that require passing an ObserverContext + // between start and end callbacks. 
+ explicit RecordFunctionCallback( + StartCallback start, + EndCallback end = nullptr) + : start_(start), end_(end) { + scopes_.fill(true); + } + + RecordFunctionCallback& needsInputs(bool needs_inputs) { + needs_inputs_ = needs_inputs; + return *this; + } + + RecordFunctionCallback& needsOutputs(bool needs_outputs) { + needs_outputs_ = needs_outputs; + return *this; + } + + RecordFunctionCallback& needsIds(bool needs_ids) { + needs_ids_ = needs_ids; + return *this; + } + + RecordFunctionCallback& samplingProb(double sampling_prob) { + TORCH_CHECK( + sampling_prob >= 0.0 && sampling_prob <= 1.0, + "Invalid sampling probability"); + sampling_prob_ = sampling_prob; + return *this; + } + + RecordFunctionCallback& scopes( + const std::unordered_set>& scopes) { + if (!scopes.empty()) { + scopes_.fill(false); + for (auto sc : scopes) { + scopes_[static_cast(sc)] = true; + } + } else { + scopes_.fill(true); + } + return *this; + } + + bool needsInputs() const { + return needs_inputs_; + } + + bool needsOutputs() const { + return needs_outputs_; + } + + bool needsIds() const { + return needs_ids_; + } + + double samplingProb() const { + return sampling_prob_; + } + + bool checkScope(RecordScope sc) const { + return scopes_[(size_t)sc]; + } + + StartCallback start() const { + return start_; + } + + EndCallback end() const { + return end_; + } + + private: + StartCallback start_; + EndCallback end_; + double sampling_prob_ = 1.0; + std::array(RecordScope::NUM_SCOPES)> scopes_ = {}; + bool needs_inputs_ = false; + bool needs_outputs_ = false; + bool needs_ids_ = false; +}; + +// Notes: +// - two types of callbacks are provided: thread local and global +// - thread local callbacks are added/removed only for the given thread +// and are stored locally for each thread and separately from the list +// of the global callbacks +// - global callbacks are stored in a single per process list and are +// invoked by every RecordFunction, in addition to the thread local +// callbacks 
specific to the given thread +// - we allow the added callbacks to be sampled, by specifying a sampling +// probability for each callback pair, if the start callback is +// not picked to run, the corresponding end callback won't be called +// - a typical use case for the global callbacks is passive monitoring +// in the background (e.g. fleet-wide monitoring), without focusing on +// the specific piece of code +// - in contrast, thread local callbacks are enabled locally, on demand, +// for the specific piece of code (range) and are not sampled +// - a typical use case for thread local callbacks is profiler and code +// execution tracer +// - note, thread local callbacks are automatically propagated with +// ThreadLocalState across JIT continuations and async tasks (at::launch) + +typedef uint64_t CallbackHandle; + +constexpr CallbackHandle INVALID_CALLBACK_HANDLE{0}; + +// It is unnecessary to use atomic operations for enabling +// thread-local function callbacks. Moreover, it prevents saving to +// ThreadLocalState because std::atomic is non-copyable. +struct RecordFunctionCallbacksEntry { + RecordFunctionCallbacksEntry(RecordFunctionCallback cb, CallbackHandle h) + : callback_(cb), handle_(h) {} + + RecordFunctionCallback callback_; + bool enabled_{true}; + CallbackHandle handle_; +}; + +// Holds pairs (callbacks, unique_id) +using RecordFunctionCallbacks = std::vector; + +// Generated by the callback managers to determine which functions to run. 
+struct StepCallbacks { + StepCallbacks() = default; + StepCallbacks(uint64_t thread_id, RecordScope scope) + : thread_id_{thread_id}, scope_{scope} {} + + bool empty() const { + return callbacks_.empty(); + } + + struct StartEndPair { + RecordFunctionCallback::StartCallback start_; + RecordFunctionCallback::EndCallback end_; + }; + + using StartEndPairs = c10::SmallVector; + + StartEndPairs callbacks_; + uint64_t thread_id_{0}; + RecordScope scope_{RecordScope::FUNCTION}; + bool needs_inputs_{false}; + bool needs_outputs_{false}; + bool needs_ids_{false}; +}; + +struct TORCH_API RecordFunction { + // Default constructor is used with before function called afterwards: + // scope - record scope that this function tracks + // pre_sampled - whether this RecordFunction was already pre-sampled with + // kLowProb probability + explicit RecordFunction(RecordScope scope = RecordScope::FUNCTION); + explicit RecordFunction(StepCallbacks&& step_callbacks); + + using schema_ref_t = std::reference_wrapper; + using FunctionDescriptor = std::variant; + + void before( + FunctionDescriptor fn, + c10::ArrayRef args, + int64_t current_sequence_nr = -1) { + if (!isActive()) { + return; + } + inputs_ = args; + before(fn, current_sequence_nr); + } + + void before( + FunctionDescriptor fn, + c10::ArrayRef args, + const std::unordered_map* kwargs, + int64_t current_sequence_nr = -1) { + if (!isActive()) { + return; + } + kwinputs_ = *kwargs; + before(fn, args, current_sequence_nr); + } + + void before( + FunctionDescriptor fn, + const std::unordered_map* kwargs, + int64_t current_sequence_nr = -1) { + if (!isActive()) { + return; + } + kwinputs_ = *kwargs; + before(fn, current_sequence_nr); + } + + void before( + FunctionDescriptor fn, + const std::vector* args, + int64_t current_sequence_nr = -1) { + before( + fn, + c10::ArrayRef(args->data(), args->size()), + current_sequence_nr); + } + + void before( + FunctionDescriptor fn, + const std::vector* args, + const std::unordered_map* 
kwargs, + int64_t current_sequence_nr = -1) { + if (!isActive()) { + return; + } + kwinputs_ = *kwargs; + before(std::move(fn), args, current_sequence_nr); + } + + // Destructor calls end callbacks + virtual ~RecordFunction(); + + RecordFunction(const RecordFunction&) = delete; + RecordFunction& operator=(const RecordFunction&) = delete; + RecordFunction(RecordFunction&&) = delete; + RecordFunction& operator=(RecordFunction&&) = delete; + + const char* name() const; + const char* overload_name() const; + + int64_t seqNr() const { + return sequence_nr_; + } + + c10::ArrayRef inputs() const { +#ifndef NDEBUG + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + inputs_valid_, "Called inputs() outside RecordFunction start callback"); +#endif + return inputs_; + } + + std::unordered_map kwinputs() const { +#ifndef NDEBUG + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + inputs_valid_, + "Called kwinputs() outside RecordFunction start callback"); +#endif + return kwinputs_; + } + + const std::vector& outputs() const { + return outputs_; + } + + void setOutputs(std::vector&& outputs) { + outputs_ = std::move(outputs); + } + + void setOutputs(c10::ArrayRef outputs) { + outputs_ = outputs.vec(); + } + + size_t num_inputs() const; + size_t num_outputs() const; + + // Retrieves the thread_id that this RecordFunction ran start callbacks with. 
+ // Useful for writing thread safe end callbacks that may be potentially + // executed in a different thread (async ops) + uint64_t threadId() const { + return step_callbacks_.thread_id_; + } + + // For backward functions - thread id of the corresponding forward function, + // or zero otherwise; + // used alongside with sequence number to correlate backward functions with + // the forward ones + uint64_t forwardThreadId() const { + return fwd_thread_id_; + } + + void setForwardThreadId(uint64_t thread_id) { + fwd_thread_id_ = thread_id; + } + + RecordScope scope() const { + return step_callbacks_.scope_; + } + + // Returns logical thread_id for the current thread + static uint64_t currentThreadId(); + + // Internal functions, do not use directly; + // used in python's context manager + + // before functions initialize RecordFunction members and call + // start callbacks + void before(FunctionDescriptor schema, int64_t sequence_nr = -1); + + // Sets node ID for distributed profiling + static void setDefaultNodeId(int64_t defaultNodeId); + // Gets node ID for distributed profiling + static int64_t getDefaultNodeId(); + + // Calls end callbacks. After end(), accessors will no longer provide useful + // results. + void end(); + + // Internal-only, used only force async event for distributed events + // profiling. + void _setAsync(); + + // Returns whether this RecordFunction corresponds to an async event or not. + bool isAsync() const; + + // Returns whether this RecordFunction corresponds to NCCL metadata collection + // or not. + bool isNcclMeta() const { + return is_nccl_meta_; + } + + // Internal-only, used to denote out variant used for Static Runtime execution + void _setStaticRuntimeOutVariant(); + bool isStaticRuntimeOutVariant() const; + + RecordFunctionHandle handle() const { + return handle_; + } + + std::optional operator_name() const; + + // This method returns a copy of the FunctionSchema and can be expensive. 
+ std::optional operator_schema() const; + + void setHandle(RecordFunctionHandle handle) { + handle_ = handle; + } + + // Whether this RecordFunction runs any callbacks. + bool isActive() const { + return !step_callbacks_.empty(); + } + + bool needsInputs() const { + return step_callbacks_.needs_inputs_; + } + + bool needsOutputs() const { + return step_callbacks_.needs_outputs_; + } + + int64_t debugHandle() const { + return debug_handle_; + } + + void setDebugHandle(int64_t debug_handle) { + debug_handle_ = debug_handle; + } + + void invalidateInputs() { +#ifndef NDEBUG + inputs_valid_ = false; +#endif + } + + private: + void runStartCallbacks(); + + StepCallbacks step_callbacks_; + + // In cases when RecordFunction might be active but we chose not to + // use the observers (e.g. operator is not observed), this boolean + // flag is used to check whether the start callbacks were called + bool called_start_callbacks_ = false; + +#ifndef NDEBUG + bool inputs_valid_ = false; +#endif + + // Stores various ObserverContext objects with event metadata for callbacks. + ObserverContextList ctx_; + + std::variant fn_; + + int64_t sequence_nr_ = -1; + c10::ArrayRef inputs_; + std::unordered_map kwinputs_; + std::vector outputs_; + + // For backward functions - thread id of the forward function + uint64_t fwd_thread_id_ = 0; + + // Unique id for this RecordFunction, used in callbacks to track start + // and end of ranges + RecordFunctionHandle handle_{0}; + + // Whether this record_function corresponds to an async event or not. Async + // events can complete in different threads or follow a future-like pattern + // of use. + bool is_async_{false}; + + // Debug handles are used for lazy annotation of module hierarchy + // and callstack. 
+ // This is specifically is useful for mobile runtime, where generated + // debug handles can be lazily symbolicated using debug information + int64_t debug_handle_{-1}; + + // Whether this RecordFunction is used for an out variant run with + // Static Runtime + bool is_static_runtime_out_variant_{false}; + + // Whether this RecordFunction is used for NCCL metadata collection + bool is_nccl_meta_{false}; +}; + +TORCH_API StepCallbacks getStepCallbacks(RecordScope scope); + +TORCH_API std::optional getStepCallbacksUnlessEmpty( + RecordScope scope); + +namespace detail { +template +void record_function_with_scope( + RecordFunction& guard, + RecordFunction::FunctionDescriptor fn, + const Inputs& inputs, + Args&&... args) { + if (guard.needsInputs()) { + guard.before( + fn, + c10::ArrayRef(inputs.data(), inputs.size()), + std::forward(args)...); + } else { + guard.before(fn, std::forward(args)...); + } +} + +template +void record_function_with_scope_and_debug_handle( + RecordFunction& guard, + RecordFunction::FunctionDescriptor fn, + int64_t debug_handle, + const Inputs& inputs, + Args&&... args) { + guard.setDebugHandle(debug_handle); + if (guard.needsInputs()) { + guard.before( + fn, + c10::ArrayRef(inputs.data(), inputs.size()), + std::forward(args)...); + } else { + guard.before(fn, std::forward(args)...); + } +} + +template +void record_function_with_scope( + RecordFunction& guard, + RecordFunction::FunctionDescriptor fn, + c10::ArrayRef inputs, + Args&&... args) { + return record_function_with_scope, Args...>( + guard, fn, inputs, std::forward(args)...); +} + +template +void record_function_with_scope_and_debug_handle( + RecordFunction& guard, + RecordFunction::FunctionDescriptor fn, + int64_t debug_handle, + c10::ArrayRef inputs, + Args&&... 
args) { + return record_function_with_scope_and_debug_handle< + c10::ArrayRef, + Args...>(guard, fn, debug_handle, inputs, std::forward(args)...); +} + +} // namespace detail + +// optional argument - function's seq_no +#define RECORD_FUNCTION_WITH_SCOPE(scope, fn, inputs, ...) \ + at::RecordFunction guard(scope); \ + if (guard.isActive()) { \ + ::at::detail::record_function_with_scope( \ + guard, fn, inputs, ##__VA_ARGS__); \ + } + +#define RECORD_FUNCTION_WITH_SCOPE_INPUTS_OUTPUTS( \ + scope, fn, inputs, outputs, ...) \ + at::RecordFunction guard(scope); \ + if (guard.isActive()) { \ + if (guard.needsInputs()) { \ + guard.before(fn, inputs, ##__VA_ARGS__); \ + } else { \ + guard.before(fn, ##__VA_ARGS__); \ + } \ + if (guard.needsOutputs()) { \ + guard.setOutputs(outputs); \ + } \ + } + +#define RECORD_FUNCTION(fn, inputs, ...) \ + RECORD_FUNCTION_WITH_SCOPE( \ + at::RecordScope::FUNCTION, fn, inputs, ##__VA_ARGS__) + +#define RECORD_TORCHSCRIPT_FUNCTION(mn, inputs) \ + RECORD_FUNCTION_WITH_SCOPE(at::RecordScope::TORCHSCRIPT_FUNCTION, mn, inputs) + +#define RECORD_FUNCTION_WITH_INPUTS_OUTPUTS(fn, inputs, outputs, ...) \ + RECORD_FUNCTION_WITH_SCOPE_INPUTS_OUTPUTS( \ + at::RecordScope::FUNCTION, fn, inputs, outputs, ##__VA_ARGS__) + +// Custom user scopes in C++; similar to Python's 'with record_function("..."):' +#define RECORD_USER_SCOPE(fn) \ + RECORD_FUNCTION_WITH_SCOPE( \ + at::RecordScope::USER_SCOPE, fn, c10::ArrayRef{}) + +// RECORD_USER_SCOPE with inputs +#define RECORD_USER_SCOPE_WITH_INPUTS(fn, inputs) \ + RECORD_FUNCTION_WITH_SCOPE(at::RecordScope::USER_SCOPE, fn, inputs) + +#define RECORD_USER_SCOPE_WITH_KWARGS_ONLY(fn, kwargs) \ + RECORD_FUNCTION_WITH_SCOPE( \ + at::RecordScope::USER_SCOPE, \ + fn, \ + c10::ArrayRef{}, \ + kwargs) + +// Helper macro to pass in debug handle that is used to +// post process events +#define RECORD_WITH_SCOPE_DEBUG_HANDLE_AND_INPUTS( \ + scope, fn, debug_handle, inputs, ...) 
\ + at::RecordFunction guard(scope); \ + if (guard.isActive()) { \ + ::at::detail::record_function_with_scope_and_debug_handle( \ + guard, fn, debug_handle, inputs, ##__VA_ARGS__); \ + } + +// Helper macros to record LITE INTERPRETER scope events with debug handles +#define RECORD_EDGE_SCOPE_WITH_DEBUG_HANDLE_AND_INPUTS( \ + fn, debug_handle, inputs) \ + RECORD_WITH_SCOPE_DEBUG_HANDLE_AND_INPUTS( \ + at::RecordScope::LITE_INTERPRETER, fn, debug_handle, inputs) + +// Bookend to the RECORD_FUNCTION macros. Use this after the kernel +// launch to let the profiler bind the outputs to the op that produced +// them. Note that guard is declared by RECORD_FUNCTION so this macro +// needs to be called from the same scope as RECORD_FUNCTION +#define RECORD_OUTPUTS(outputs) \ + if (guard.needsOutputs()) { \ + guard.setOutputs( \ + std::vector(outputs.begin(), outputs.end())); \ + } + +/** + * addThreadLocalCallback adds a thread local callback to run with + * RecordFunction, returns handle to use with removeThreadLocalCallback + */ +TORCH_API CallbackHandle addThreadLocalCallback(RecordFunctionCallback cb); + +/** + * hasThreadLocalCallbacks returns whether there're callbacks registered + * with addThreadLocalCallback + */ +TORCH_API bool hasThreadLocalCallbacks(); + +/** + * clearThreadLocalCallbacks removes all thread local callbacks + */ +TORCH_API void clearThreadLocalCallbacks(); + +/** + * addGlobalCallback adds a global callback to run with RecordFunction: + * + * only during the program initialization + */ +TORCH_API CallbackHandle addGlobalCallback(RecordFunctionCallback cb); + +/** + * removeCallback removes a callback given the handle returned by + * addThreadLocalCallback or addGlobalCallback; + * + * no other code can run simultaneously + */ +TORCH_API void removeCallback(CallbackHandle handle); + +/** + * Prevent the given callback from executing. If handle is invalid, + * does nothing. 
+ */ +TORCH_API void disableCallback(CallbackHandle handle); + +/** + * Allow the given callback, previously disabled with disableCallback, to + * execute again. If handle is invalid, does nothing. + */ +TORCH_API void reenableCallback(CallbackHandle handle); + +/** + * hasGlobalCallbacks returns whether there're global callbacks + * registered with pushGlobalCallback + */ +TORCH_API bool hasGlobalCallbacks(); + +/** + * clearGlobalCallbacks removes all global callbacks + */ +TORCH_API void clearGlobalCallbacks(); + +// for both thread local and global callbacks +TORCH_API bool hasCallbacks(); +TORCH_API void clearCallbacks(); + +/** + * enableRecordFunction enables RecordFunction thread locally + */ +TORCH_API void enableRecordFunction(bool enable = true); + +/** + * isRecordFunctionEnabled returns whether RecordFunction + * is enabled thread locally + */ +TORCH_API bool isRecordFunctionEnabled(); + +class TORCH_API RecordFunctionGuard { + public: + explicit RecordFunctionGuard(bool is_enabled = true) + : prev_value_(isRecordFunctionEnabled()) { + enableRecordFunction(is_enabled); + } + + RecordFunctionGuard(RecordFunctionGuard&& other) = delete; + RecordFunctionGuard(const RecordFunctionGuard&) = delete; + RecordFunctionGuard& operator=(const RecordFunctionGuard&) = delete; + RecordFunctionGuard& operator=(RecordFunctionGuard&&) = delete; + virtual ~RecordFunctionGuard() { + enableRecordFunction(prev_value_); + } + + private: + bool prev_value_ = false; +}; + +class TORCH_API DisableRecordFunctionGuard : public RecordFunctionGuard { + public: + DisableRecordFunctionGuard() : RecordFunctionGuard(false) {} + ~DisableRecordFunctionGuard() override = default; +}; + +struct TORCH_API RecordFunctionTLS { + // Thread local vector of callbacks, holds pairs (callbacks, unique_id); + // must be sorted in increasing handles order + RecordFunctionCallbacks sorted_tls_callbacks_; + + bool tls_record_function_enabled_ = true; +}; + +TORCH_API const RecordFunctionTLS& 
get_record_function_tls_(); + +TORCH_API void set_record_function_tls_(const RecordFunctionTLS& tls); + +TORCH_API void set_record_function_seed_for_testing(uint32_t seed); + +} // namespace at + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/THC/THCAtomics.cuh b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/THC/THCAtomics.cuh new file mode 100644 index 0000000000000000000000000000000000000000..cb269d477f2c8d906d0c5c3e101189cb85c971d3 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/THC/THCAtomics.cuh @@ -0,0 +1,8 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +// TODO: Remove once torchvision has been updated to use the ATen header +#include + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/THC/THCDeviceUtils.cuh b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/THC/THCDeviceUtils.cuh new file mode 100644 index 0000000000000000000000000000000000000000..251098b34aec9c7a11706c03213c2a5035bf5050 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/THC/THCDeviceUtils.cuh @@ -0,0 +1,8 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +// TODO: Remove this header +#include + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." 
+#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fbgemm/ConvUtils.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fbgemm/ConvUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..e02f7219333029289853acf1883b29b840da4198 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fbgemm/ConvUtils.h @@ -0,0 +1,195 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include + +namespace fbgemm { + +template +constexpr std::enable_if_t> +array_of_ones() { + return std::array{{Vals...}}; +} + +template +constexpr std::enable_if_t> +array_of_ones() { + return array_of_ones(); +} + +template +constexpr std::enable_if_t> +array_of_zeroes() { + return std::array{{Vals...}}; +} + +template +constexpr std::enable_if_t> +array_of_zeroes() { + return array_of_zeroes(); +} + +/** + * @brief A struct to conveniently store all convolution parameters. 
+ */ +template +struct conv_param_t { + int MB; ///< Mini Batch size + int IC; ///< Number of Input Channels + int OC; ///< Number of Output Channels + std::array IN_DIM; ///< Input Image Dimension + int G; ///< Number of Groups + std::array K; ///< Filter (Kernel) dimensions + std::array stride; //< Strides + std::array + pad; //< Padding (first SPATIAL_DIM is for prev/top/left padding, second + // SPATIAL_DIM is for next/bottom/right padding) + std::array dilation; //< Kernel dilation + + // The following are derived parameters + std::array OUT_DIM; //< Output Image Dimension + std::array IN_DIMP; //< Input Image Dimension Padded + + // The following is for tranposed convolution + std::array + output_pad; //< Padding (next/bottom/right padding in output buffer) + bool transposed; + + /** + * @brief Constructor for initializing the convolution parameters. + */ + conv_param_t( + int mb, + int ic, + int oc, + std::array in_dim, + int g, + std::array k, + std::array strd, + std::array pd, + std::array dilations = array_of_ones(), + std::array otpt_pd = array_of_zeroes(), + bool transposed = false) + : MB(mb), + IC(ic), + OC(oc), + IN_DIM(in_dim), + G(g), + K(k), + stride(strd), + pad(pd), + dilation(dilations), + output_pad(otpt_pd), + transposed(transposed) { + if (ic % g != 0) { + throw std::runtime_error( + "groups = " + std::to_string(g) + + " does not divide number of input channels = " + std::to_string(ic)); + } + if (oc % g != 0) { + throw std::runtime_error( + "groups = " + std::to_string(g) + + " does not divide number of output channels = " + std::to_string(oc)); + } + + for (int d = 0; d < SPATIAL_DIM; ++d) { + if (transposed) { + this->IN_DIMP[d] = this->IN_DIM[d] + + (this->dilation[d] * (this->K[d] - 1) - this->pad[d]) + + (this->dilation[d] * (this->K[d] - 1) - this->pad[SPATIAL_DIM + d]); + this->OUT_DIM[d] = (this->IN_DIM[d] - 1) * this->stride[d] - + this->pad[d] - this->pad[SPATIAL_DIM + d] + + this->dilation[d] * (this->K[d] - 1) + output_pad[d] + 
1; + } else { + IN_DIMP[d] = IN_DIM[d] + pad[d] + pad[SPATIAL_DIM + d]; + OUT_DIM[d] = + (IN_DIMP[d] - dilation[d] * (K[d] - 1) - 1) / stride[d] + 1; + } + } + } + + /** + * @brief Helper function to get convolution parameters as string. + */ + std::string toString() const { + std::string dim_string[3] = {"T", "H", "W"}; + + std::string out; + out += "MB:" + std::to_string(MB) + ", "; + out += "IC:" + std::to_string(IC) + ", "; + out += "OC:" + std::to_string(OC) + ", "; + if constexpr (SPATIAL_DIM <= 3) { + for (int d = 0; d < SPATIAL_DIM; ++d) { + out += "I" + dim_string[3 - SPATIAL_DIM + d] + ":" + + std::to_string(IN_DIM[d]) + ", "; + } + } else { + for (int d = 0; d < SPATIAL_DIM; ++d) { + out += "I" + std::to_string(d) + ":" + std::to_string(IN_DIM[d]) + ", "; + } + } + out += "G:" + std::to_string(G) + ", "; + if constexpr (SPATIAL_DIM <= 3) { + for (int d = 0; d < SPATIAL_DIM; ++d) { + out += "K" + dim_string[3 - SPATIAL_DIM + d] + ":" + + std::to_string(K[d]) + ", "; + } + for (int d = 0; d < SPATIAL_DIM; ++d) { + out += "stride_" + dim_string[3 - SPATIAL_DIM + d] + ":" + + std::to_string(stride[d]) + ", "; + } + for (int d = 0; d < SPATIAL_DIM * 2; ++d) { + out += "pad_" + dim_string[3 - SPATIAL_DIM + (d % SPATIAL_DIM)] + ":" + + std::to_string(pad[d]) + ", "; + } + for (int d = 0; d < SPATIAL_DIM; ++d) { + out += "dilation_" + dim_string[3 - SPATIAL_DIM + d] + ":" + + std::to_string(dilation[d]); + if (d < SPATIAL_DIM - 1) { + out += ", "; + } + } + } else { + for (int d = 0; d < SPATIAL_DIM; ++d) { + out += "K" + std::to_string(d) + ":" + std::to_string(K[d]) + ", "; + } + for (int d = 0; d < SPATIAL_DIM; ++d) { + out += "stride_" + std::to_string(d) + ":" + std::to_string(stride[d]) + + ", "; + } + for (int d = 0; d < SPATIAL_DIM; ++d) { + out += "pad_" + std::to_string(d) + ":" + std::to_string(pad[d]); + if (d < SPATIAL_DIM * 2 - 1) { + out += ", "; + } + } + for (int d = 0; d < SPATIAL_DIM; ++d) { + out += "dilation_" + std::to_string(d) + ":" + + 
std::to_string(dilation[d]) + ", "; + } + } + if (transposed) { + for (int d = 0; d < SPATIAL_DIM; ++d) { + out += "output_padding_" + std::to_string(d) + ":" + + std::to_string(output_pad[d]) + ", "; + } + } + return out; + } +}; +} // namespace fbgemm + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fbgemm/Fbgemm.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fbgemm/Fbgemm.h new file mode 100644 index 0000000000000000000000000000000000000000..518d014dd8bdf904cb561071de323cce8d2517eb --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fbgemm/Fbgemm.h @@ -0,0 +1,1515 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +/** + * Top level include file for FBGEMM. + */ +#include +#include +#include "./ConvUtils.h" // @manual +#include "./FbgemmBuild.h" // @manual +#include "./FbgemmEmbedding.h" // @manual +#include "./FbgemmI8DepthwiseAvx2.h" // @manual +#include "./FbgemmI8DirectconvAvx2.h" // @manual +#include "./FbgemmI8Spmdm.h" // @manual +#include "./FloatConversion.h" // @manual +#include "./QuantUtilsAvx2.h" // @manual +#include "./Types.h" // @manual +#include "./Utils.h" // @manual + +// Turning on this option will print out time breakdown of each stage (e.g., +// input packing, the main GEMM kernel, each output processing pipeline). +// Please note that currently this option won't report accurate timing if +// multiple threads are used. 
+// #define FBGEMM_MEASURE_TIME_BREAKDOWN + +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN +#include +#include +extern double packing_time; +extern double computing_time; +extern double kernel_time; +extern double postprocessing_time; +extern double run_time; +#endif + +namespace fbgemm { + +/** + * @brief Templatized struct for packing parameters for A and B matrices. + * + * @tparam T input type + * @tparam accT the type used for accumulation + * @tparam instSet anyarch/avx2/avx512 + * @tparam int8Type an auxiliary template parameter to specialize for 8-bit + * input types. + */ +template < + typename T, + typename accT, + inst_set_t instSet, + typename int8Type = void> +struct PackingTraits; + +// type specialized implementation in an include file +#include "./PackingTraits-inl.h" // @manual + +/** + * @brief Base class for packing matrices for higher GEMM performance. + * + * Matrix is tiled into blockRows() * blockCols() blocks. + * Each block is with size blockRowSize() * blockColSize(). + * This class is designed using CRTP + * (https://en.wikipedia.org/wiki/Curiously_recurring_template_pattern) + * + * @tparam PT actual packing type, e.g., PackAWithRowOffset + */ +template +class PackMatrix { + public: + PackMatrix() = delete; // no default constructor + PackMatrix(const PackMatrix&) = delete; // no copy + PackMatrix& operator=(const PackMatrix&) = delete; // no copy + PackMatrix(PackMatrix&&) = delete; // no move + PackMatrix& operator=(PackMatrix&& rhs) noexcept = delete; // no move + + /** + * @param rows total number of rows in the matrix + * (packed rows can be less than rows). + * @param cols total number of columns in the matrix + * @param pmat A buffer to contain the packed matrix. + * If nullptr, a buffer owned by PackMatrix will be allocated + * internally to contain the packed matrix. 
+ * For non-constant matrices like activation matrices, the client + * code may want to pass a pre-allocated pmat to avoid the + * overhead of internal memory allocation everytime a PackMatrix + * is constructed. The client code can query how big patm should + * be with packedBufferSize function. + * @param groups when groups > 1, we compute groups number of GEMMs each + * multiplies A.rows by A.cols/A.groups matrix with + * B.rows/B.groups by B.cols matrix (in conventional BLAS + * terminology, this is a batched GEMM but we use the name group + * to follow deep learning terminology). The result matrix has + * dimension A.rows by B.cols*B.groups . + * A.groups must be same as B.groups, A.groups must divide + * A.cols, and B.groups must divide B.rows and C.cols. + */ + PackMatrix( + std::int32_t rows, + std::int32_t cols, + inpType* pmat, + int groups = 1, + const BlockingFactors* params = nullptr); + + /** + * @return true usually when the matrix is constant matrix (e.g., weight + * matrices) that can be prepacked + */ + bool isPrePacked() const { + return static_cast(this)->isPrePacked(); + } + + /** + * @return true if this is the first input matrix in GEMM (i.e., A in C = A * + * B) + */ + static bool isA() { + return PT::isA(); + } + + /** + * @brief The size of the buffer used for packing (The size is in number of + * elements). + * + * rows and cols are only used for fully packing, i.e., for B matrix. The + * client code can use this function to query how big the buffer used for + * packing should be. + */ + static int packedBufferSize( + int rows = 0, + int cols = 0, + const BlockingFactors* params = nullptr); + + FBGEMM_PUSH_WARNING_AND_DISABLE("-Wpragmas") + FBGEMM_PUSH_WARNING_AND_DISABLE("-Winfinite-recursion") + /** + * @return Pointer to a buffer containing row offset results. Some packing + * objects fuse row offset computation for later requantization step. 
+ */ + std::int32_t* getRowOffsetBuffer() const { + return static_cast(this)->getRowOffsetBuffer(); + } + /** + * @brief When k loop is also tiled/blocked, this function is used to check if + * have executed computations for the last k block so that we can perform + * post-GEMM operations. + */ + bool isThisLastKBlock(int block_id) const { + return static_cast(this)->isThisLastKBlock(block_id); + } + FBGEMM_POP_WARNING + FBGEMM_POP_WARNING + + /** + * @brief Actual packing of a block of the source matrix in pmat buffer. + */ + void pack(const block_type_t& block) { +#if defined(FBGEMM_FBCODE) || !defined(__aarch64__) + static_cast(this)->pack(block); +#else + throw std::runtime_error("PackMatrix::pack() not implemented for aarch64"); +#endif // __aarch64__ + } + + std::int32_t numRows() const { + return nrows_; + } + + std::int32_t numCols() const { + return ncols_; + } + + /** + * @return The number of rows in each block + */ + std::int32_t blockRowSize() const { + return brow_; + } + + /** + * @return The number of columns in each block + */ + std::int32_t blockColSize() const { + return bcol_; + } + + /** + * @return The number of blocks along rows + */ + std::int32_t blockRows() const { + return nbrow_; + } + + /** + * @return The number of blocks along columns + */ + std::int32_t blockCols() const { + return nbcol_; + } + + /** + * @return The number of the rows in the currently packed block of a matrix. + * For pre-packed (i.e., fully-packed), it's equal to the total number + * of rows. + */ + std::int32_t numPackedRows() const { + return packedBlock_.row_size; + } + + /** + * @return The number of columns in the currently packed block of a matrix. + * For pre-packed (i.e., fully-packed), it's equal to the number of + * columns. + */ + std::int32_t numPackedCols() const { + return packedBlock_.col_size; + } + + /** + * @return The first row of the block we're working on. 
+ */ + std::int32_t packedRowStart() const { + return packedBlock_.row_start; + } + + /** + * @return The first column of the block we're working on. + */ + std::int32_t packedColStart() const { + return packedBlock_.col_start; + } + + /** + * @return The beginning of (rowBlockNum, colBlockNum)th block + */ + inpType* getBuf(std::int32_t rowBlockNum = 0, std::int32_t colBlockNum = 0) { + return buf_ + blockRowSize() * blockColSize() * rowBlockNum + + blockRowSize() * blockColSize() * blockCols() * colBlockNum; + } + + /** + * @brief Print the packed block. + */ + void printPackedMatrix(const std::string& name) { + static_cast(this)->printPackedMatrix(name); + } + + /** + * @return The number of rows in the last row block. + */ + std::int32_t lastBrow() const { + return last_brow_; + } + + /** + * @return The number of columns in the last column block. + */ + std::int32_t lastBcol() const { + return last_bcol_; + } + + int numGroups() const { + return G_; + } + + /** + * @return True if the last column block has fewer columns than the block + * size. + */ + bool isThereColRemainder() const { + return last_bcol_ != blockColSize(); + } + + virtual ~PackMatrix() { + if (bufAllocatedHere_) { + fbgemmAlignedFree(buf_); + } + } + + protected: + /** + * Set which block we're packing + */ + void packedBlock(const block_type_t& block) { + packedBlock_ = block; + nbrow_ = (numPackedRows() + blockRowSize() - 1) / blockRowSize(); + nbcol_ = (numPackedCols() + blockColSize() - 1) / blockColSize(); + + last_brow_ = ((numPackedRows() % blockRowSize()) == 0) + ? blockRowSize() + : (numPackedRows() % blockRowSize()); + last_bcol_ = ((numPackedCols() % blockColSize()) == 0) + ? 
blockColSize() + : (numPackedCols() % blockColSize()); + } + + inpType* buf_; + std::int32_t brow_; ///< the number of rows in each block + std::int32_t bcol_; ///< the number of columns in each block + std::int32_t nbrow_; ///< the number of blocks along rows + std::int32_t nbcol_; ///< the number of blocks along columns + bool bufAllocatedHere_{false}; + const BlockingFactors* + blocking_params; ///< MCB, KCB, NCB, MR, NR, NR_MIN, ROW_INTERLEAVE; + + private: + std::int32_t nrows_, ncols_; + int G_; + block_type_t packedBlock_; ///< The block in the source matrix just packed + std::int32_t last_brow_, last_bcol_; +}; + +/** + * @brief Matrix packed for the first input matrix in GEMM (usually + * activation). The source matrix is already quantized. Default + * accumulation type is int32. + */ +template +class FBGEMM_API PackAMatrix final + : public PackMatrix, T, accT> { + public: + using This = PackAMatrix; + using BaseType = PackMatrix; + using inpType = T; + using accType = accT; + + PackAMatrix() = delete; // no default constructor + + PackAMatrix( + matrix_op_t trans, + std::int32_t nRow, + std::int32_t nCol, + const inpType* smat, + std::int32_t ld, + inpType* pmat = nullptr, + int groups = 1, + const BlockingFactors* params = nullptr); + + /** + * Activation matrices are not constant so cannot amortize the cost of + * pre-packing. + */ + bool isPrePacked() const { + return false; + } + + /** + * @return True if this is used as A matrix. + */ + static constexpr bool isA() { + return true; + } + + /** + * @return A pointer to the row offset buffer. There is no row offset buffer + * calculations with this packing class, hence, it returns nullptr. + */ + std::int32_t* getRowOffsetBuffer() const { + return nullptr; + } + + /** + * @return Offset of the element in the packed matrix that was at (i, j) in + * the source matrix. + */ + std::int32_t addr(std::int32_t i, std::int32_t j) const; + + /** + * @brief Packs a block of source matrix into pmat buffer. 
+ */ + void pack(const block_type_t& block); + + /** + * @brief Print the packed block. + */ + void printPackedMatrix(const std::string& name); + + private: + matrix_op_t trans_; + const T* smat_; + std::int32_t ld_; + std::int32_t row_interleave_B_; +}; + +/** + * @brief Matrix packed for the second input matrix in GEMM (usually weight). + * The source matrix is already quantized. Default accumulation + * type is int32. + */ +template +class FBGEMM_API PackBMatrix final + : public PackMatrix, T, accT> { + public: + using This = PackBMatrix; + using BaseType = PackMatrix; + using inpType = T; + using accType = accT; + + PackBMatrix() = delete; // no default constructor + + /** + * @param groups if > 1 and trans == NoTranspose, smat is nRow x nCol with + * groups are vertically concatenated: each group is + * (nRow / groups) x nCol . + * if > 1 and trans == Transpose, smat is (nCol * groups) x + * (nRow / groups) with groups are horizontally concatenated: + * each group is nCol x (nRow / groups) . Each group is + * transposed and vertically concatenated to match with the + * NoTranspose case. + */ + PackBMatrix( + matrix_op_t trans, + std::int32_t nRow, + std::int32_t nCol, + const inpType* smat, + std::int32_t ld, + inpType* pmat = nullptr, + int groups = 1, + const BlockingFactors* params = nullptr); + + /** + * Weight matrices are usually constant so worth pre-packing. + */ + bool isPrePacked() const { + return true; + } + + /** + * @return True if to be used as A matrix, False otherwise. + */ + static constexpr bool isA() { + return false; + } + + /** + * @brief When k loop is also tiled/blocked, this function is used to check if + * have executed computations for the last k block so that we can perform + * post-GEMM operations. + */ + bool isThisLastKBlock(int block_id) const { + return (BaseType::blockRows() - 1) == block_id; + } + + /** + * @return Offset of the element in the packed matrix that was at (i, j) in + * the source matrix. 
+ */ + std::int32_t addr(std::int32_t i, std::int32_t j) const; + + /** + * @brief Packs a block of source matrix into pmat buffer. The blocking + * parameters are needed to compute the buffer size of each group. + * It will use default blocking parameters if params is not provided. + */ + void pack(const block_type_t& block, const BlockingFactors* params = nullptr); + + /** + * @brief Print the packed block. + */ + void printPackedMatrix( + const std::string& name, + const BlockingFactors* params = nullptr); + + /** + * @return true if meta information like matrix shape is the same. + */ + bool metaEquals(const PackBMatrix& that) const; + /** + * @return true if matrices are the same. + */ + bool equals(const PackBMatrix& that) const; + + /** + * @brief Unpack pmat buffer to the origin_buf (Used for the serialization to + * recover weight matrix). + */ + void unpack(T* origin_buf, const BlockingFactors* params = nullptr); + + ~PackBMatrix() override = default; + + private: + matrix_op_t trans_; + const T* smat_; + std::int32_t ld_; + std::int32_t row_interleave_; + + /** + * @brief Internal function performing both pack & unpack + */ + void pack_unpack_( + const block_type_t& block, + T* unpack_buf, + T* pack_buf, + bool ispack, + const BlockingFactors* params = nullptr); +}; + +/** + * @brief Matrix packed for direct group convolution. + * The source matrix is already quantized. Default accumulation + * type is int32. 
+ */ +template +class FBGEMM_API PackWeightMatrixForGConv { + public: + using This = PackWeightMatrixForGConv; + using inpType = T; + using accType = accT; + + PackWeightMatrixForGConv() = delete; // no default constructor + PackWeightMatrixForGConv(const PackWeightMatrixForGConv&) = delete; // no copy + PackWeightMatrixForGConv& operator=(const PackWeightMatrixForGConv&) = + delete; // no copy + + PackWeightMatrixForGConv(PackWeightMatrixForGConv&&) = delete; // no move + PackWeightMatrixForGConv& operator=(PackWeightMatrixForGConv&&) = + delete; // no move + + /** + * @param pmat if nullptr, a buffer is allocated and owned by this class. + */ + PackWeightMatrixForGConv( + matrix_op_t trans, + const conv_param_t& conv_param, + const inpType* sdata, + inpType* pdata = nullptr); + + /** + * Number of groups we work at a time to fill the full simd width + * e.g., IC_PER_G = 4 and OC_PER_G = 4, we work on two groups at a time + * to fill the avx2 width of 256 bits. + */ + static int numOfGroupsTogether(const conv_param_t& conv_param); + + /** + * @brief Packs a block of source matrix into pmat buffer. + */ + void pack(); + + /** + * @brief Unpacks a pmat buffer into source matrix. 
+ */ + void unpack(T* origin_buf); + + /** + * @brief Return packed data + */ + inpType* getBuf() { + return pdata_; + } + + ~PackWeightMatrixForGConv() { + if (bufAllocatedHere_) { + fbgemmAlignedFree(pdata_); + } + } + + private: + matrix_op_t trans_; + const conv_param_t conv_param_; + const T* sdata_; + T* pdata_; + bool bufAllocatedHere_{false}; + // Number of groups we work at a time to fill the full simd width + int GTogether_; + + /** + * @brief Internal function performing both pack & unpack + */ + void pack_unpack_(const T* src, T* dst, bool ispack); + + /** + * @brief Get the index of the unpacked data + */ + int unpacked_index_(int t, int r, int s, int k, int g, int c, bool tr); + + /** + * @brief Get the index of the packed data + */ + int packed_index_(int t, int r, int s, int k, int g, int c); +}; + +/** + * @brief A container class to keep packed weight tensor for convolution. + * The source tensor should already be quantized. + * + * @tparam SPATIAL_DIM is equal to 2 for 2D convolutions and 3 for 3D + * convolutions. Default value is 2. + * @tparam T is the datatype for source tensor. Default value is int8. + * @tparam accT is the datatype to accumulate into. Default value is int32. 
+ */ +template < + int SPATIAL_DIM = 2, + typename T = std::int8_t, + typename accT = std::int32_t> +class FBGEMM_API PackWeightsForConv { + public: + using This = PackWeightsForConv; + using inpType = T; + using accType = accT; + + PackWeightsForConv() = delete; // no default constructor + + PackWeightsForConv( + const conv_param_t& conv_param, + const inpType* sdata, + const BlockingFactors* blocking_params = nullptr); + + std::shared_ptr> getPackedWForIm2col() { + return W_im2col_packed_; + } + +#if defined(FBGEMM_FBCODE) || !defined(__aarch64__) + std::shared_ptr getPackedWForDepthwise() { + return W_dw_packed_; + } +#endif // __aarch64__ + + std::shared_ptr getPackedWForDirectconv() { + return W_dc_packed_; + } + + std::shared_ptr> + getPackedWForGroupwise() { + return W_gconv_packed_; + } + + std::shared_ptr> getPackedWForPointwise() { + return W_pointwise_packed_; + } + + int inputChannels() { + return conv_param_.IC; + } + + int outputChannels() { + return conv_param_.OC; + } + + std::array kernelDims() { + return conv_param_.K; + } + + int groups() { + return conv_param_.G; + } + + /** + * @brief Returns true if the packed weights would work for the given + * convolution parameters, and false otherwise + */ + bool isPackingCompliant(const conv_param_t& conv_p); + + /** + * @brief Returns a string of mismatching parameters + */ + std::string mismatchingParams(const conv_param_t& conv_p); + + /** + * @brief Unpack packed matric into origin_buf (Used for the serialization to + * recover weight matrix). 
+ */ + void unpack(T* origin_buf); + + private: + const conv_param_t conv_param_; + // Packed weights if we use im2col based convolution implementation + std::shared_ptr> W_im2col_packed_; +#if defined(FBGEMM_FBCODE) || !defined(__aarch64__) + // Packed weights if we use depthwise convolution implementation + std::shared_ptr W_dw_packed_; +#endif // __aarch64__ + // Packed weights if we use direct convolution implementation + std::shared_ptr W_dc_packed_; + // Packed weights if we use groupwise (small channels per group) convolution + // implementation + std::shared_ptr> + W_gconv_packed_; + // Packed weights if we use direct gemm for pointwise convolution + std::shared_ptr> W_pointwise_packed_; +}; + +/** + * @brief Matrix packed for the first input matrix in GEMM (usually activation), + * and row offsets used for requantization is computed during packing. + * Im2col is fused with packing here. The source matrix is already + * quantized. + */ +template +class FBGEMM_API PackAWithIm2Col + : public PackMatrix, T, accT> { + public: + using This = PackAWithIm2Col; + using BaseType = PackMatrix; + using inpType = T; + using accType = accT; + + PackAWithIm2Col() = delete; // no default constructor + /** + * @param zero_pt the quantized value that maps to 0.0f floating-point number. + * @param row_offset If nullptr, this constructor internally allocates a + * buffer and owns it. Otherwise, this class doesn't own + * the buffer. The buffer will be populated when pack + * function is called. 
+ * @param b_symmetric if true we skip row offset computation + */ + PackAWithIm2Col( + const conv_param_t& conv_param, + const T* sdata, + inpType* pmat = nullptr, + std::int32_t a_zero_pt = 0, + std::int32_t* row_offset = nullptr, + bool b_symmetric = false, + const BlockingFactors* params = nullptr); + + PackAWithIm2Col(const PackAWithIm2Col&) = delete; + PackAWithIm2Col(PackAWithIm2Col&&) = delete; + PackAWithIm2Col& operator=(const PackAWithIm2Col&) = delete; + PackAWithIm2Col& operator=(PackAWithIm2Col&&) = delete; + + /** + * Activation matrices are not constant so cannot amortize the cost of + * pre-packing. + */ + bool isPrePacked() const { + return false; + } + + /** + * @return True if this is used as A matrix. + */ + static constexpr bool isA() { + return true; + } + + /** + * @brief Packs a block of source matrix into pmat buffer. + */ + void pack(const block_type_t& block); + + /** + * @return A pointer to the row offset buffer. + */ + std::int32_t* getRowOffsetBuffer() const { + return row_offset_; + } + + /** + * @brief Print the packed block. + */ + void printPackedMatrix(const std::string& name); + + /** + * @return Size of row offset buffer in number of elements + */ + static int rowOffsetBufferSize(const BlockingFactors* params = nullptr); + + ~PackAWithIm2Col() override { + if (rowOffsetAllocatedHere) { + fbgemmAlignedFree(row_offset_); + } + } + + private: + const conv_param_t conv_p_; + const T* sdata_; + std::int32_t a_zero_pt_; + std::int32_t* row_offset_{nullptr}; + bool rowOffsetAllocatedHere{false}; + std::int32_t row_interleave_B_; +}; + +/** + * @brief Matrix packed for the first input matrix in GEMM (usually activation), + * and row offsets used for requantization is computed during packing. + * The source matrix is already quantized. 
+ */ +template +class FBGEMM_API PackAWithRowOffset final + : public PackMatrix, T, accT> { + public: + using This = PackAWithRowOffset; + using BaseType = PackMatrix; + using inpType = T; + using accType = accT; + + PackAWithRowOffset() = delete; // no default constructor + /** + * @param row_offset If nullptr, this constructor internally allocates a + * buffer and owns it. Otherwise, this class doesn't own + * the buffer. The buffer will be populated when pack + * function is called. + */ + PackAWithRowOffset( + matrix_op_t trans, + std::uint32_t nRow, + std::uint32_t nCol, + const T* smat, + std::uint32_t ld, + inpType* pmat = nullptr, + int groups = 1, + std::int32_t* row_offset = nullptr, + const BlockingFactors* params = nullptr); + + PackAWithRowOffset(const PackAWithRowOffset&) = delete; + PackAWithRowOffset(PackAWithRowOffset&&) = delete; + PackAWithRowOffset& operator=(const PackAWithRowOffset&) = delete; + PackAWithRowOffset& operator=(PackAWithRowOffset&&) = delete; + + /** + * Activation matrices are not constant so cannot amortize the cost of + * pre-packing. + */ + bool isPrePacked() const { + return false; + } + + /** + * @return True if this is used as A matrix. + */ + static constexpr bool isA() { + return true; + } + + /** + * @return Offset of the element in the packed matrix that was at (i, j) in + * the source matrix + */ + std::int32_t addr(std::int32_t i, std::int32_t j) const; + + /** + * @brief Packs a block of source matrix into pmat buffer. + */ + void pack(const block_type_t& block); + + /** + * @return A pointer to the row offset buffer. + */ + std::int32_t* getRowOffsetBuffer() const { + return row_offset_; + } + + /** + * @brief Print the packed block. 
+ */ + void printPackedMatrix(const std::string& name); + + /** + * @return size of row offset buffer in number of elements + */ + static int rowOffsetBufferSize(const BlockingFactors* params = nullptr); + + ~PackAWithRowOffset() override { + if (rowOffsetAllocatedHere) { + fbgemmAlignedFree(row_offset_); + } + } + + private: + matrix_op_t trans_; + const T* smat_; + std::uint32_t ld_; + std::int32_t* row_offset_{nullptr}; + bool rowOffsetAllocatedHere{false}; + std::int32_t row_interleave_B_; +}; + +/** + * @brief Matrix packed for the first input matrix in GEMM (usually activation), + * and row offsets used for requantization is computed during packing. + * The source matrix is in fp32 and quantized during packing. + */ +template +class FBGEMM_API PackAWithQuantRowOffset final + : public PackMatrix, T, accT> { + public: + using This = PackAWithQuantRowOffset; + using BaseType = PackMatrix; + using inpType = T; + using accType = accT; + + PackAWithQuantRowOffset() = delete; // no default constructor + /** + * @param row_offset If nullptr, this constructor internally allocates a + * buffer and owns it. Otherwise, this class doesn't own + * the buffer. The buffer will be populated when pack + * function is called. + */ + PackAWithQuantRowOffset( + matrix_op_t trans, + std::int32_t nRow, + std::int32_t nCol, + const float* smat, + std::int32_t ld, + inpType* pmat = nullptr, + float scale = 1.0f, + std::int32_t zero_pt = 0, + int groups = 1, + std::int32_t* row_offset = nullptr, + const BlockingFactors* params = nullptr); + PackAWithQuantRowOffset(const PackAWithQuantRowOffset&) = delete; + PackAWithQuantRowOffset(PackAWithQuantRowOffset&&) = delete; + PackAWithQuantRowOffset& operator=(const PackAWithQuantRowOffset&) = delete; + PackAWithQuantRowOffset& operator=(PackAWithQuantRowOffset&&) = delete; + + /** + * Activation matrices are not constant so cannot amortize the cost of + * pre-packing. 
+ */ + bool isPrePacked() const { + return false; + } + + /** + * @return True if this is used as A matrix. + */ + static constexpr bool isA() { + return true; + } + + /** + * @return offset of the element in the packed matrix that was at (i, j) in + * the source matrix + */ + std::int32_t addr(std::int32_t i, std::int32_t j) const; + + /** + * @brief Packs a block of source matrix into pmat buffer. + */ + void pack(const block_type_t& block); + + /** + * @return A pointer to the row offset buffer. + */ + std::int32_t* getRowOffsetBuffer() const { + return row_offset_; + } + + /** + * @brief Print the packed block. + */ + void printPackedMatrix(const std::string& name); + + /** + * @return Size of row offset buffer in number of elements + */ + static int rowOffsetBufferSize(const BlockingFactors* params = nullptr); + + ~PackAWithQuantRowOffset() override { + if (rowOffsetAllocatedHere) { + fbgemmAlignedFree(row_offset_); + } + } + + private: + matrix_op_t trans_; + const float* smat_; + std::int32_t ld_; + float scale_; + std::int32_t zero_pt_; + std::int32_t* row_offset_{nullptr}; + bool rowOffsetAllocatedHere{false}; + std::int32_t row_interleave_B_; +}; + +/* + * + * Post Processing of outputs + * + */ + +/** + * @brief Does nothing. NoOp. Used as the last operation in the output + * processing pipeline. + * + */ +template +class FBGEMM_API DoNothing { + public: + using outType = outT; + using inpType = inT; + DoNothing() = default; + template + int f( + outType* /* unused */, + inpType* /* unused */, + const block_type_t& /* unused */, + int /* unused */, + int /* unused */) const { + return 0; + } +}; + +/** + * @brief Copy data pointed by inp ptr to out ptr when + * inp ptr and out ptr are not the same. 
+ * inp buffer: row and column start points: (0, 0) + * output buffer: row and column start points: + * (block.row_start, block.col_start) + * + * This is the output processing stage that should passed when there is no + * requantization and output is required in the same format as internal buffer + * used for accumulation. + */ +template < + typename outT = std::int32_t, + typename inT = std::int32_t, + typename nextOPType = DoNothing> +class FBGEMM_API memCopy { + public: + using outType = outT; + using inpType = inT; + explicit memCopy(nextOPType& nextop) : nextop_(nextop) {} + template + inline int f( + outType* out, + inpType* inp, + const block_type_t& block, + int ld_out, + int ld_in) const; + + private: + nextOPType& nextop_; +}; + +/** + * @brief Perform scaling on accumulated data. + */ +template < + typename outT = std::int32_t, + typename inT = std::int32_t, + typename nextOPType = DoNothing> +class ScaleOP { + public: + using outType = outT; + using inpType = inT; + explicit ScaleOP(inpType scalingFactor) : scalingFactor_(scalingFactor) {} + + template + inline int f( + outType* out, + inpType* inp, + const block_type_t& block, + int ld_out, + int ld_in) const; + + private: + inpType scalingFactor_; +}; + +/** + * @brief Perform Relu on accumulated data. + */ +template < + typename outT = std::int32_t, + typename inT = std::int32_t, + typename nextOPType = DoNothing> +class ReluOutput { + public: + using outType = outT; + using inpType = inT; + explicit ReluOutput(inpType zero_pt) : zero_pt_(zero_pt) {} + + template + inline int f( + outType* out, + inpType* inp, + const block_type_t& block, + int ld_out, + int ld_in) const; + + private: + inpType zero_pt_; +}; + +/** + * @brief Perform Dense-Matrix * Sparse-Matrix as a part the of output + * processing pipeline. + * + * SPMDM (SParse Matrix times Dense Matrix) inplace on the 32-bit input buffer + * (inp). After modifying the input buffer, pass it to the next op. 
+ * When groups > 1, each group is numRows() x (numCols()/groups) matrix. + */ +template < + typename outT = std::int32_t, + typename inT = std::int32_t, + typename nextOPType = DoNothing> +class FBGEMM_API DoSpmdmOnInpBuffer { + public: + using outType = outT; + using inpType = inT; + DoSpmdmOnInpBuffer( + nextOPType& nextop, + const std::uint8_t* A, + int lda, + const CompressedSparseColumn& B_csc, + int groups = 1) + : nextop_(nextop), A_(A), lda_(lda), B_csc_(B_csc), groups_(groups) {} + + template + inline int f( + outT* out, + inT* inp, + const block_type_t& block, + int ld_out, + int ld_in) const; + + private: + nextOPType& nextop_; + const std::uint8_t* A_; + const int lda_; + const CompressedSparseColumn& B_csc_; + const int groups_; +}; + +/** + * @brief Perform Dense-Matrix * Sparse-Matrix as a part the of output + * processing pipeline. + * + * SPMDM (SParse Matrix times Dense Matrix) inplace on the 32-bit input buffer + * (inp). After modifying the input buffer, pass it to the next op. + * When groups > 1, each group is numRows() x (numCols()/groups) matrix. + */ +template < + typename outT = std::int32_t, + typename inT = std::int32_t, + typename nextOPType = DoNothing> +class FBGEMM_API DoSConvOnInpBuffer { + public: + using outType = outT; + using inpType = inT; + DoSConvOnInpBuffer( + nextOPType& nextop, + const std::uint8_t* A, + const conv_param_t<>& conv_p, + std::int32_t A_zero_point, + const CompressedSparseColumn& B_csc) + : nextop_(nextop), + A_(A), + conv_p_(conv_p), + A_zero_point_(A_zero_point), + B_csc_(B_csc) {} + + template + inline int f( + outT* out, + inT* inp, + const block_type_t& block, + int ld_out, + int ld_in) const; + + private: + nextOPType& nextop_; + const std::uint8_t* A_; + const conv_param_t<> conv_p_; + const std::int32_t A_zero_point_; + const CompressedSparseColumn& B_csc_; +}; + +/** + * @brief Requantize values in inp buffer and write to out buffer. + * pass the out buffer to next op for further processing. 
+ */ +template < + bool FUSE_RELU, + QuantizationGranularity Q_GRAN = QuantizationGranularity::TENSOR, + typename BIAS_TYPE = std::int32_t, + typename outT = std::uint8_t, + typename inT = std::int32_t, + typename nextOPType = DoNothing> +class FBGEMM_API ReQuantizeOutput { + public: + static constexpr int RELU_FUSED = FUSE_RELU; + static constexpr QuantizationGranularity QGRANType = Q_GRAN; + using BIAS_T = BIAS_TYPE; + using outType = outT; + using inpType = inT; + /** + * @param C_multiplier The length of this array is + * 1 when Q_GRAN == QuantizationGranularity::TENSOR, + * groups when Q_GRAN == QuantizationGranularity::GROUP, + * nCol if Q_GRAN == QuantizationGranularity::OUT_CHANNEL + * @param Bq_zero_point The length of this array should be the same as + * C_multiplier. + * @param row_offsets Typically, this should've been computed by a + * PackAMatrix and should be obtained by + * PackMatrix::getRowOffsetBuffer(). + * If Bq_zero_point == 0 (symmetric quantization of B + * matrix), we can pass nullptr. + * @param col_offsets This should be pre-computed for example using + * col_offsets_with_zero_pt_s8acc32_ref. + * The length should be nCol. + * See PackedRequantizeTest.cc for an example. + * TODO: if Aq_zero_point == 0, allow passing nullptr. + * @param bias can be nullptr otherwise the length should be nCol + * @param act_times_w_scale activation_scale * weight_scale. This is only + * used if bias is unquantized (i.e., float). 
+ */ + ReQuantizeOutput( + nextOPType& nextop, + const float* C_multiplier, + std::int32_t C_zero_point, + std::int32_t Aq_zero_point, + const std::int32_t* Bq_zero_point, + const std::int32_t* row_offsets, + const std::int32_t* col_offsets, + const BIAS_T* bias, + std::uint32_t nCol, + int groups = 1, + const float* act_times_w_scale = nullptr) + : nextop_(nextop), + C_multiplier_(C_multiplier), + C_zero_point_(C_zero_point), + Aq_zero_point_(Aq_zero_point), + Bq_zero_point_(Bq_zero_point), + q_row_offsets_(row_offsets), + q_col_offsets_(col_offsets), + bias_(bias), + ncols_(nCol), + groups_(groups), + act_times_w_scale_(act_times_w_scale) {} + + template + inline int f( + outT* out, + const inT* inp, + const block_type_t& block, + int ld_out, + int ld_in) const; + + const float* getCMultiplier() const { + return C_multiplier_; + } + std::int32_t getAZeroPoint() const { + return Aq_zero_point_; + } + std::int32_t getCZeroPoint() const { + return C_zero_point_; + } + const std::int32_t* getBZeroPoint() const { + return Bq_zero_point_; + } + const std::int32_t* getRowOffsets() const { + return q_row_offsets_; + } + const std::int32_t* getColOffsets() const { + return q_col_offsets_; + } + const BIAS_T* getBias() const { + return bias_; + } + std::uint32_t getNCols() const { + return ncols_; + } + const float* getActWScale() const { + return act_times_w_scale_; + } + + void setRowOffsets(const std::int32_t* row_offsets) { + q_row_offsets_ = row_offsets; + } + + private: + nextOPType& nextop_; + const float* C_multiplier_; + std::int32_t C_zero_point_; + std::int32_t Aq_zero_point_; + const std::int32_t* Bq_zero_point_; + const std::int32_t* q_row_offsets_; + const std::int32_t* q_col_offsets_; + const BIAS_T* bias_; + std::uint32_t ncols_; + int groups_; + const float* act_times_w_scale_; +}; + +/** + * @brief Requantize to convert accumulated data to be used as float, i.e., the + * output would be used as float. 
+ */ +template < + bool FUSE_RELU, + QuantizationGranularity Q_GRAN = QuantizationGranularity::TENSOR, + typename outT = float, + typename inT = std::int32_t, + typename nextOPType = DoNothing> +class FBGEMM_API ReQuantizeForFloat { + public: + using outType = outT; + using inpType = inT; + /** + * @param Bq_scale The length of this array is + * 1 when Q_GRAN == QuantizationGranularity::TENSOR, + * groups when Q_GRAN == QuantizationGranularity::GROUP, + * nCol if Q_GRAN == QuantizationGranularity::OUT_CHANNEL + * @param Bq_zero_point The length of this array should be the same as + * Bq_scale. + * @param row_offsets Typically, this should've been computed by a + * PackAMatrix and should be obtained by + * PackMatrix::getRowOffsetBuffer(). + * If Bq_zero_point == 0 (symmetric quantization of B + * matrix), we can pass nullptr. + * @param col_offsets This should be pre-computed for example using + * col_offsets_with_zero_pt_s8acc32_ref. + * The length should be nCol. + * See PackedRequantizeTest.cc for an example. + * TODO: if Aq_zero_point == 0, allow passing nullptr. 
+ * @param bias can be nullptr otherwise the length should be nCol + */ + ReQuantizeForFloat( + nextOPType& nextop, + float Aq_scale, + const float* Bq_scale, + std::int32_t Aq_zero_point, + const std::int32_t* Bq_zero_point, + const std::int32_t* row_offsets, + const std::int32_t* col_offsets, + const float* bias, + std::uint32_t nCol, + int groups = 1) + : nextop_(nextop), + Aq_scale_(Aq_scale), + Bq_scale_(Bq_scale), + Aq_zero_point_(Aq_zero_point), + Bq_zero_point_(Bq_zero_point), + q_row_offsets_(row_offsets), + q_col_offsets_(col_offsets), + bias_(bias), + ncols_(nCol), + groups_(groups) {} + + template + inline int f( + outT* out, + inT* inp, + const block_type_t& block, + int ld_out, + int ld_in) const; + + private: + nextOPType& nextop_; + float Aq_scale_; + const float* Bq_scale_; + std::int32_t Aq_zero_point_; + const std::int32_t* Bq_zero_point_; + const std::int32_t* q_row_offsets_; + const std::int32_t* q_col_offsets_; + const float* bias_; + std::uint32_t ncols_; + int groups_; +}; + +// type specialized implementation in an include file +#include "./OutputProcessing-inl.h" // @manual + +/* + * + * ####### GEMM related functions ####### + * + */ + +/** + * Matrix B must be prepacked. For matrix A, packA.pack function is called to + * pack it. 
+ * + * @tparam packingAMatrix processing of A matrix while packing, + * e.g., PackAWithQuantRowOffset + * + * @tparam packingBMatrix processing of B matrix while packing, + * e.g., pre-multiply by alpha + * @tparam cT data type of C matrix + * @tparam processOutputType further processing of outputs, e.g., Relu + */ +template < + typename packingAMatrix, + typename packingBMatrix, + typename cT, + typename processOutputType> +FBGEMM_API void fbgemmPacked( + PackMatrix< + packingAMatrix, + typename packingAMatrix::inpType, + typename packingAMatrix::accType>& packA, + PackMatrix< + packingBMatrix, + typename packingBMatrix::inpType, + typename packingBMatrix::accType>& packB, + cT* C, + std::int32_t* C_buffer, + std::uint32_t ldc, + const processOutputType& outProcess, + int thread_id, + int num_threads, + const BlockingFactors* blocking_params = nullptr); + +/** + * @brief Perform small-channels-per-group groupwise convolution + * Note: Currently threading is not supported. This function does + * nothing for thread_ids > 0, i.e., returns early. + * + * @param rowOffsetBuf nullptr if B uses symmetric quantization + * Note: Currently threading is not supported. This function does + * nothing for thread_ids > 0, i.e., returns early. 
+ */ +template < + typename packed_W, + typename outType, + bool FUSE_RELU, + QuantizationGranularity Q_GRAN, + int SPATIAL_DIM = 2, + typename BIAS_TYPE = std::int32_t> +FBGEMM_API void fbgemmGroupwiseConv( + const conv_param_t& conv_param, + const std::uint8_t* activations, + std::int32_t a_zero_point, + std::int32_t* rowOffsetBuf, + packed_W& packed_weights, + outType* out, + std::int32_t* outBuffer, + const ReQuantizeOutput& outProcess, + int thread_id, + int num_threads); + +template < + int SPATIAL_DIM, + QuantizationGranularity Q_GRAN, + bool FUSE_RELU, + typename BIAS_TYPE = std::int32_t> +FBGEMM_API void fbgemmDirectConv( + const conv_param_t& conv_p, + const uint8_t* Aint8, + PackedDirectConvMatrix& Bint8_tr, + uint8_t* C, + int32_t* C_buffer, + const ReQuantizeOutput& outProcess, + const BIAS_TYPE* bias, + int thread_id, + int num_threads); + +/** + * @return Size of row offset buffer in number of elements needed for + * fbgemmGroupwiseConv + */ +template +FBGEMM_API int rowOffsetBufferSizeGConv( + const conv_param_t& conv_param); + +/** + * @brief Is this depthwise convolution optimized? + */ +template +bool takeDepthWiseFastPath(const conv_param_t& conv_p); + +/** + * @brief Is this groupwise convolution supported? + */ +template +FBGEMM_API bool fbgemmOptimizedGConv(const conv_param_t& conv_p); + +/** + * @brief Is this convolution a direct matrix-matrix multiplication, i.e., 1x1 + * (aka pointwise) with right paddings etc.? + */ +template +FBGEMM_API bool takePointWiseFastPath(const conv_param_t& conv_p); + +/** + * @brief Are we running on a fbgemm supported cpu? + */ +FBGEMM_API bool fbgemmSupportedCPU(); + +/** + * @brief Performs convolution using fastest path available. + * + * @tparam SPATIAL_DIM It's 2 for 2D convolutions and 3 for 3D convolutions. 
+ */ +template < + typename processOutputType, + int SPATIAL_DIM = 2, + typename ACC_T = std::int32_t> +FBGEMM_API int fbgemmConv( + const conv_param_t& conv_p, + const std::uint8_t* activations, + PackWeightsForConv& packed_weights, + typename processOutputType::outType* out, + std::int32_t* outBuffer, + processOutputType& outProcess, + int thread_id, + int num_threads, + const BlockingFactors* blocking_params = nullptr); + +/** + * @brief Returns which fast path to take + * + * @tparam SPATIAL_DIM It's 2 for 2D convolutions and 3 for 3D convolutions. + * + * @return optimized_conv_t::depthwise, optimized_conv_t::groupwise or + * optimized_conv_t::im2col + * + */ +template +FBGEMM_API optimized_conv_t +ConvFastPath(const conv_param_t& conv_p); +} // namespace fbgemm + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fbgemm/FbgemmBuild.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fbgemm/FbgemmBuild.h new file mode 100644 index 0000000000000000000000000000000000000000..f413d17980a5b8fcc3ea3e4f98b6ec81a98512c5 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fbgemm/FbgemmBuild.h @@ -0,0 +1,116 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +// For details about dllexport/dllimport, check out the following SO question +// https://stackoverflow.com/questions/57999/what-is-the-difference-between-dllexport-and-dllimport +#if !defined(FBGEMM_API) +#if defined(FBGEMM_STATIC) +#define FBGEMM_API +#define FBGEMM_ENUM_CLASS_API +#elif defined _WIN32 || defined __CYGWIN__ +#if (__GNUC__ || __clang__) && !(__MINGW64__ || __MINGW32__) +#if defined(FBGEMM_EXPORTS) +#define FBGEMM_API __attribute__((__dllexport__)) +#else +#define FBGEMM_API __attribute__((__dllimport__)) +#endif +#else +#if defined(FBGEMM_EXPORTS) +#define FBGEMM_API __declspec(dllexport) +#else +#define FBGEMM_API __declspec(dllimport) +#endif +#endif +#define FBGEMM_ENUM_CLASS_API +#else +#if __clang__ || __GNUC__ || __INTEL_COMPILER +#define FBGEMM_API __attribute__((__visibility__("default"))) +#else +#define FBGEMM_API +#endif +// Currently, enum classes need to be declared explicitly for shared build on +// macOS +#if __clang__ +#define FBGEMM_ENUM_CLASS_API __attribute__((__visibility__("default"))) +#else +#define FBGEMM_ENUM_CLASS_API +#endif +#endif +#endif + +// Use this to indicate that a function must not be inlined +#if __clang__ || __GNUC__ || __INTEL_COMPILER +#define NOINLINE __attribute__((noinline)) +#elif _MSC_VER +#define NOINLINE __declspec(noinline) +#else +#define NOINLINE +#endif + +// Use this to indicate that a function should always be inlined +#if __clang__ || __GNUC__ || __INTEL_COMPILER +#define ALWAYS_INLINE inline __attribute__((__always_inline__)) +#elif _MSC_VER +// __forceinline is commented out because it takes too long in MSVC +#define ALWAYS_INLINE // __forceinline +#else +#define ALWAYS_INLINE inline +#endif + +// Use the C++11 keyword "alignas" if you can +#if _MSC_VER +#define ALIGNAS(byte_alignment) __declspec(align(byte_alignment)) +#else +#define ALIGNAS(byte_alignment) __attribute__((aligned(byte_alignment))) +#endif + +// Sanitizer annotations +#if defined(__has_attribute) +#if
__has_attribute(no_sanitize) +#define NO_SANITIZE(what) __attribute__((no_sanitize(what))) +#endif +#endif +#if !defined(NO_SANITIZE) +#define NO_SANITIZE(what) +#endif + +// Ignore __builtin_assume() when not supported by compiler. +#ifndef __has_builtin +#define __has_builtin(x) 0 +#endif +#if !__has_builtin(__builtin_assume) +#define __builtin_assume(x) (static_cast(0)) +#endif + +// Macro for silencing warnings +#if __clang__ || __GNUC__ +// clang-format off +#define FBGEMM_PUSH_WARNING _Pragma("GCC diagnostic push") +#define FBGEMM_DISABLE_WARNING_INTERNAL2(warningName) #warningName +#define FBGEMM_DISABLE_WARNING(warningName) \ + _Pragma( \ + FBGEMM_DISABLE_WARNING_INTERNAL2(GCC diagnostic ignored warningName)) +#define FBGEMM_PUSH_WARNING_AND_DISABLE(warningName) \ + _Pragma("GCC diagnostic push") \ + _Pragma( \ + FBGEMM_DISABLE_WARNING_INTERNAL2(GCC diagnostic ignored warningName)) +#define FBGEMM_POP_WARNING _Pragma("GCC diagnostic pop") +// clang-format on +#else +#define FBGEMM_PUSH_WARNING +#define FBGEMM_DISABLE_WARNING(NAME) +#define FBGEMM_PUSH_WARNING_AND_DISABLE(NAME) +#define FBGEMM_POP_WARNING +#endif + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fbgemm/FbgemmConvert.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fbgemm/FbgemmConvert.h new file mode 100644 index 0000000000000000000000000000000000000000..6574dfc305700356f3ea6d16be69512c45208212 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fbgemm/FbgemmConvert.h @@ -0,0 +1,205 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. 
+ * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include "fbgemm/FbgemmBuild.h" +#include "fbgemm/Types.h" + +namespace fbgemm { + +/** + * @brief Transform all entries in a matrix from fp32 to bfloat16: reference + * implementation. + * + */ +FBGEMM_API void +FloatToBfloat16_ref(const float* src, bfloat16* dst, size_t size); + +/** + * @brief Transform all entries in a matrix from bfloat16 to fp32: reference + * implementation. + * + */ +FBGEMM_API void +Bfloat16ToFloat_ref(const bfloat16* src, float* dst, size_t size); + +/** + * @brief Transform all entries in a matrix from fp32 to bfloat16: simd + * implementation. + * + */ +FBGEMM_API void +FloatToBfloat16_simd(const float* src, bfloat16* dst, size_t size); + +/** + * @brief Transform all entries in a matrix from bfloat16 to fp32: simd + * implementation. + * + */ +FBGEMM_API void +Bfloat16ToFloat_simd(const bfloat16* src, float* dst, size_t size); + +#if !defined(__aarch64__) +/** + * @brief AVX2 implementation to convert fp32 numbers to bf16 numbers. + * + */ +FBGEMM_API void +FloatToBfloat16_avx2(const float* src, bfloat16* dst, size_t size); + +/** + * @brief AVX512 implementation to convert fp32 numbers to bf16 numbers. + * + */ +FBGEMM_API void +FloatToBfloat16_avx512(const float* src, bfloat16* dst, size_t size); + +/** + * @brief AVX2 implementation to convert bf16 numbers to fp32 numbers. + * + */ +FBGEMM_API void +Bfloat16ToFloat_avx2(const bfloat16* src, float* dst, size_t size); + +/** + * @brief AVX512 implementation to convert bf16 numbers to fp32 numbers. + * + */ +FBGEMM_API void +Bfloat16ToFloat_avx512(const bfloat16* src, float* dst, size_t size); +#endif + +/** + * @brief Transform all entries in a matrix from fp32 to float16: reference + * implementation. + * + * @param do_clip if true we saturate to fp16 min and max instead of generating + * infinities.
+ */ +FBGEMM_API void FloatToFloat16_ref( + const float* src, + float16* dst, + size_t size, + bool do_clip = false); + +/** + * @brief Transform all entries in a matrix from float16 to fp32: reference + * implementation. + * + */ +FBGEMM_API void Float16ToFloat_ref(const float16* src, float* dst, size_t size); + +/** + * @brief Transform all entries in a matrix from fp32 to float16: simd + * implementation. + * + * @param do_clip if true we saturate to fp16 min and max instead of generating + * infinities. + */ +FBGEMM_API void FloatToFloat16_simd( + const float* src, + float16* dst, + size_t size, + bool do_clip = false); + +/** + * @brief Transform all entries in a matrix from float16 to fp32: simd + * implementation. + * + */ +FBGEMM_API void +Float16ToFloat_simd(const float16* src, float* dst, size_t size); + +/** + * @brief AVX2 implementation to convert fp32 numbers to fp16 numbers. + * + */ +#if !defined(__aarch64__) +FBGEMM_API void FloatToFloat16_avx2( + const float* src, + float16* dst, + size_t size, + bool do_clip = false); + +/** + * @brief AVX512 implementation to convert fp32 numbers to fp16 numbers. + * + */ +FBGEMM_API void FloatToFloat16_avx512( + const float* src, + float16* dst, + size_t size, + bool do_clip = false); +#endif + +/** + * @brief SVE2 implementation to convert fp32 numbers to fp16 numbers. + * + */ +FBGEMM_API void FloatToFloat16_sve2( + const float* src, + float16* dst, + size_t size, + bool do_clip = false); + +#if !defined(__aarch64__) +/** + * @brief AVX2 implementation to convert fp16 numbers to fp32 numbers. + * + */ +FBGEMM_API void +Float16ToFloat_avx2(const float16* src, float* dst, size_t size); + +/** + * @brief AVX512 implementation to convert fp16 numbers to fp32 numbers. + * + */ +FBGEMM_API void +Float16ToFloat_avx512(const float16* src, float* dst, size_t size); +#endif + +/** + * @brief Transform all entries in a matrix from fp32 to float16 and back to + * fp32.
+ */ +FBGEMM_API void RoundToFloat16( + const float* input, + float* output, + size_t size, + bool clamp = false, + bool clamp_denorms = false); + +/** + * @brief Quantize float32 to float8. The code is a copy of float_to_hfp8() in + * fbgemm_gpu/quantize_ops_utils.h + */ +FBGEMM_API void FloatToFloat8_ref( + float input, + uint8_t* output, + int exponent_bits, + int exponent_bias); + +/** + * @brief Dequantize float8 to float32. The code is a copy of hf8_to_float() in + * fbgemm_gpu/quantize_ops_utils.h + */ +FBGEMM_API void Float8ToFloat_ref( + uint8_t input, + float* output, + int exponent_bits, + int exponent_bias); + +} // namespace fbgemm + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fbgemm/FbgemmEmbedding.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fbgemm/FbgemmEmbedding.h new file mode 100644 index 0000000000000000000000000000000000000000..0da43d39fbdfd5d0df2958cc1f6d0b5fe5ff4cdb --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fbgemm/FbgemmEmbedding.h @@ -0,0 +1,383 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once +#include +#include + +#include "fbgemm/FbgemmBuild.h" + +namespace fbgemm { + +template < + typename InType, + typename IndexType, + typename OffsetType = std::int32_t, + typename OutType = float> +class EmbeddingSpMDMKernelSignature { + public: + /** + * Behavior follows the pseudocode below + * (when use_offsets == true, lengths[i] == offsets[i + 1] - offsets[i]) + * (when is_weight_positional == true, use weights[j - offsets[i]] instead of + * weights[j]) + * + * for i in range(output_size): + * out[i * block_size : (i + 1) * block_size] = 0 + * for j in range(offsets[i], offsets[i + 1]): + * for k in range(block_size): + * out[i * block_size + k] += input[indices[j] * block_size + k] * + * (weights ? weights[j] : 1) + * if normalize_by_lengths and lengths[i] > 0: + * out[i * block_size : (i + 1) * block_size] /= lengths[i] + * + * @param data_size the number of rows in embedding table + */ + using Type = std::function; +}; + +/** + * @tparam InType can be float, float16, or uint8_t + * @tparam IndexType can be int32_t or int64_t + * @tparam OffsetType can be int32_t or int64_t + * + * @param use_offsets If true, the generated code assumes we will pass offsets + * instead of lengths, conforming to the PyTorch EmbeddingBag + * interface. In this case, the length of offsets array + * should be output_size + 1 and offsets[output_size] should + * be index_size. + * If false, the generated code assumes we will pass lengths, + * conforming to the Caffe2 SparseLengthsSum interface.
+ */ +template < + typename InType, + typename IndexType, + typename OffsetType = std::int32_t, + typename OutType = float, + bool THREAD_LOCAL = false> +FBGEMM_API typename EmbeddingSpMDMKernelSignature< + InType, + IndexType, + OffsetType, + OutType>::Type +GenerateEmbeddingSpMDM( + const std::int64_t block_size, + bool has_weight, + bool normalize_by_lengths, + int prefetch = 16, + bool is_weight_positional = false, + bool use_offsets = true, + bool is_bf16_out = false, + bool is_bf16_in = false); + +/** + * @param output_stride If -1, output_stride is same as block_size + * @param input_stride If -1, input_stride is same as block_size + * @param scale_bias_last if false, scale and bias appear at the beginning + * of each row and are in fp16 for table batched embedding (TBE) + * in FBGEMM_GPU. If false, it can also take -1 indices (output from + * pruned embedding id mapping) + */ +template < + typename InType, + typename IndexType, + typename OffsetType = std::int32_t, + typename OutType = float, + bool THREAD_LOCAL = false> +FBGEMM_API typename EmbeddingSpMDMKernelSignature< + InType, + IndexType, + OffsetType, + OutType>::Type +GenerateEmbeddingSpMDMWithStrides( + const std::int64_t block_size, + bool has_weight, + bool normalize_by_lengths, + int prefetch = 16, + bool is_weight_positional = false, + bool use_offsets = true, + std::int64_t output_stride = -1, + std::int64_t input_stride = -1, + bool scale_bias_last = true, + bool no_bag = false, + bool is_bf16_out = false, + bool is_bf16_in = false); + +/** + * @tparam IndexType can be int32_t or int64_t + * @tparam OffsetType can be int32_t or int64_t + * @param bit_rate can be 2 or 4 + */ +template < + typename IndexType, + typename OffsetType = std::int32_t, + typename OutType = float> +FBGEMM_API typename EmbeddingSpMDMKernelSignature< + std::uint8_t, + IndexType, + OffsetType, + OutType>::Type +GenerateEmbeddingSpMDMNBit( + int bit_rate, + const std::int64_t block_size, + bool has_weight, + bool 
normalize_by_lengths, + int prefetch = 16, + bool is_weight_positional = false, + bool use_offsets = true); + +/** + * @param output_stride If -1, output_stride is same as block_size + * @param input_stride in Bytes. If -1, input_stride is same as + * block_size / num_elem_per_byte + 2 * sizeof(float16) + * @param scale_bias_last if false, scale and bias appear at the beginning + * of each row and are in fp16 for table batched embedding (TBE) + * in FBGEMM_GPU. If false, it can also take -1 indices (output from + * pruned embedding id mapping) + */ +template < + typename IndexType, + typename OffsetType = std::int32_t, + typename OutType = float, + bool THREAD_LOCAL = false> +FBGEMM_API typename EmbeddingSpMDMKernelSignature< + std::uint8_t, + IndexType, + OffsetType, + OutType>::Type +GenerateEmbeddingSpMDMNBitWithStrides( + const int input_bit_rate, + const std::int64_t block_size, + bool has_weight, + bool normalize_by_lengths, + int prefetch = 16, + bool is_weight_positional = false, + bool use_offsets = true, + std::int64_t output_stride = -1, + std::int64_t input_stride = -1, + bool scale_bias_last = true, + const bool is_bf16_out = false, + const bool no_bag = false, + int output_bit_rate = -1); + +/** + * @param output_stride If -1, output_stride is same as block_size + * @param input_stride in Bytes. 
If -1, input_stride is same as + * block_size / num_elem_per_byte + 2 * sizeof(float16) + * @param exponent_bits is the number of exponent bits in the FP8 encode + * (normally 4 or 5) + * @param exponent_bias is subtracted from the exponent to obtain the actual + * exponent for the floating-point number + */ +template < + typename IndexType, + typename OffsetType = std::int32_t, + typename OutType = float> +FBGEMM_API typename EmbeddingSpMDMKernelSignature< + std::uint8_t, + IndexType, + OffsetType, + OutType>::Type +GenerateEmbeddingSpMDMFP8WithStrides( + const std::int64_t block_size, + bool normalize_by_lengths, + bool is_weight_positional = false, + bool use_offsets = true, + std::int64_t output_stride = -1, + std::int64_t input_stride = -1, + int exponent_bits = 4, + int exponent_bias = 7, + bool is_bf16_out = false); + +template < + typename InType, + typename IndexType, + typename OffsetType = std::int32_t> +class EmbeddingSpMDMRowWiseSparseKernelSignature { + public: + using Type = std::function; +}; + +/** + * @tparam InType can be float, float16, or uint8_t + * @tparam IndexType can be int32_t or int64_t + * @tparam OffsetType can be int32_t or int64_t + */ +template < + typename InType, + typename IndexType, + typename OffsetType = std::int32_t> +FBGEMM_API typename EmbeddingSpMDMRowWiseSparseKernelSignature< + InType, + IndexType, + OffsetType>::Type +GenerateEmbeddingSpMDMRowWiseSparse( + const std::int64_t block_size, + bool has_weight, + bool normalize_by_lengths, + int prefetch = 16, + bool is_weight_positional = false, + bool use_offsets = true); + +/** + * @tparam IndexType can be int32_t or int64_t + * @tparam OffsetType can be int32_t or int64_t + * @param bit_rate can be 2 or 4 + */ +template +FBGEMM_API typename EmbeddingSpMDMRowWiseSparseKernelSignature< + std::uint8_t, + IndexType, + OffsetType>::Type +GenerateEmbeddingSpMDMNBitRowWiseSparse( + int bit_rate, + const std::int64_t block_size, + bool has_weight, + bool normalize_by_lengths, + 
int prefetch = 16, + bool is_weight_positional = false, + bool use_offsets = true); + +/** + * @return The number of rows processed. If smaller than num_rows, an error + * must have happened at the last row processed. + */ +template +class SparseAdaGradSignature { + public: + using Type = std::function; // frequency adjust happens only after +}; + +template +FBGEMM_API typename SparseAdaGradSignature::Type +GenerateSparseAdaGrad( + int block_size, // number of parameters per row + bool rowwise = false, + int prefetch = 16, + bool use_weight_decay = false); + +// RowWiseSparseAdaGrad fused with SLS gradient +// Weights can be either float or float16 +template < + typename IndexType, + typename OffsetType = std::int32_t, + typename DataType = float> +class RowWiseSparseAdaGradFusedSignature { + public: + using Type = std::function; +}; + +/** + * @param grad_stride If -1, grad_stride is same as block size + */ +template < + typename IndexType, + typename OffsetType = std::int32_t, + typename DataType = float> +FBGEMM_API typename RowWiseSparseAdaGradFusedSignature< + IndexType, + OffsetType, + DataType>::Type +GenerateRowWiseSparseAdaGradFused( + int block_size, // number of parameters per row + int prefetch = 16, + bool use_offsets = true, + bool use_stochastic_rounding = true, + int grad_stride = -1); + +namespace internal { +// Specialization for block size 1 internally called by GenerateEmbeddingSpMDM +template +FBGEMM_API bool EmbeddingSpMDMBlockSize1_( + const std::int64_t output_size, + const std::int64_t index_size, + const std::int64_t data_size, // the number of rows in input + const InType* input, + const IndexType* indices, + const OffsetType* offsets_or_lengths, + const float* weights, // optional, can be null for non-weighted sum + bool normalize_by_lengths, + float* out, + bool is_weight_positional = false, + bool use_offsets = true, + bool is_bf16 = false); + +#if !defined(__aarch64__) +template +void compressed_indices_remap_avx512( + std::int32_t 
offsets_numel, + const IndexType* indices, + const int32_t* compressed_indices_mapping, + const IndexType* offsets, + const float* weights, // optional, can be null, + IndexType* out_indices, + IndexType* out_offsets, + float* out_weights); +#endif + +} // namespace internal + +template +FBGEMM_API void compressed_indices_remap( + std::int32_t offsets_numel, + const IndexType* indices, + const int32_t* compressed_indices_mapping, + const IndexType* offsets, + const float* weights, // optional, can be null, + IndexType* out_indices, + IndexType* out_offsets, + float* out_weights); + +} // namespace fbgemm + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fbgemm/FbgemmFP16.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fbgemm/FbgemmFP16.h new file mode 100644 index 0000000000000000000000000000000000000000..254142290ef03e8aeb5ee40c28443af6f737a6bb --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fbgemm/FbgemmFP16.h @@ -0,0 +1,60 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +// WARNING: this is a legacy fp16 fbgemm implementation and will soon be +// upgraded to match with new fbgemm interface. 
+ +#include + +#include "./FbgemmPackMatrixB.h" // @manual +#include "./FloatConversion.h" // @manual +#include "./Types.h" // @manual +#include "./Utils.h" // @manual + +namespace fbgemm { + +template <> +struct TypeConverter { + float16 operator()(float src) const { + constexpr float FP16_MAX = 65504.f; + const float fp16 = std::max(-FP16_MAX, std::min(src, FP16_MAX)); + return cpu_float2half(fp16); + } +}; + +using PackedGemmMatrixFP16 = PackedGemmMatrixB; + +template +FBGEMM_API void cblas_gemm_compute( + const matrix_op_t transa, + const int m, + const float* A, + const PackedGemmMatrixB& Bp, + const float beta, + float* C, + int thread_id = 0, + int num_threads = 1); + +extern template void cblas_gemm_compute( + const matrix_op_t transa, + const int m, + const float* A, + const PackedGemmMatrixFP16& Bp, + const float beta, + float* C, + int thread_id, + int num_threads); + +}; // namespace fbgemm + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fbgemm/FbgemmFP32.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fbgemm/FbgemmFP32.h new file mode 100644 index 0000000000000000000000000000000000000000..91c0c4c7ce6bb8e062391b1b612dfc86720d2ef0 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fbgemm/FbgemmFP32.h @@ -0,0 +1,54 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#pragma once + +// WARNING: this is a legacy fp16 fbgemm implementation and will soon be +// upgraded to match with new fbgemm interface. 
+ +#include + +#include "fbgemm/FbgemmFPCommon.h" +#include "fbgemm/FbgemmPackMatrixB.h" +#include "fbgemm/Utils.h" + +namespace fbgemm { +template <> +struct TypeConverter { + float operator()(float src) const { + return src; + } +}; + +using GemmParamsFP32 = GemmParams; +using PackedGemmMatrixFP32 = PackedGemmMatrixB; + +template +void cblas_gemm_compute( + const matrix_op_t transa, + const int m, + const float* A, + const PackedGemmMatrixB& Bp, + const float beta, + float* C, + int thread_id = 0, + int num_threads = 1); + +extern template void cblas_gemm_compute( + const matrix_op_t transa, + const int m, + const float* A, + const PackedGemmMatrixFP32& Bp, + const float beta, + float* C, + int thread_id, + int num_threads); + +template <> +const isa_descriptor& getIsaHandlers(inst_set_t isa); + +} // namespace fbgemm + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fbgemm/FbgemmFPCommon.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fbgemm/FbgemmFPCommon.h new file mode 100644 index 0000000000000000000000000000000000000000..e4fd09a18f70ab177229aad256043b955dea884d --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fbgemm/FbgemmFPCommon.h @@ -0,0 +1,319 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * Copyright 2024-2025 Arm Limited and/or its affiliates + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include +#include +#include +#include +#include + +#if defined(FBGEMM_FP16_FALLBACK_TO_REF_KERNEL) || \ + defined(FBGEMM_FP32_FALLBACK_TO_REF_KERNEL) +#if defined(__APPLE__) && defined(__aarch64__) +#define FBGEMM_USE_REF_KERNEL +#endif +#endif + +namespace fbgemm { + +using partition_array_t = std::array, 2>, 121>; +extern partition_array_t partition_avx2; +extern partition_array_t partition_avx512; +extern partition_array_t partition_sve128; +#ifdef FBGEMM_ENABLE_KLEIDIAI +extern partition_array_t partition_neon; +#endif + +template +struct GemmParams { + uint64_t k; + float* A; + const T* B; + float beta; + float* C; + uint64_t ldc; + uint64_t b_block_cols; + uint64_t b_block_size; +}; + +template <> +struct GemmParams { + uint64_t k; + float* A; + const float16* B; + float beta; + float* C; + uint64_t ldc; + uint64_t b_block_cols; +#ifdef FBGEMM_ENABLE_KLEIDIAI + uint64_t lda; +#else + uint64_t b_block_size; +#endif +}; + +template <> +struct GemmParams { + uint64_t k; + float* A; + const float* B; + float beta; + float* C; + uint64_t ldc; + uint64_t b_block_cols; +#ifdef FBGEMM_ENABLE_KLEIDIAI + uint64_t lda; +#else + uint64_t b_block_size; +#endif +}; + +template +using funcptr_t = void (*)(GemmParams*); +template +using kernel_array_t = std::array, 15>; +template +using isa_descriptor = std::tuple, partition_array_t>; + +template +extern const isa_descriptor& getIsaHandlers(inst_set_t isa); + +void PackA(int nrow, int ncol, const float* from, int ldim, float* to); + +// define fp16/fp32 kernels using a reference C implementation +#if defined(FBGEMM_FP16_FALLBACK_TO_REF_KERNEL) || \ + defined(FBGEMM_FP32_FALLBACK_TO_REF_KERNEL) +template +FBGEMM_API void ref_kernel( + int kernel_nrows, + GemmParams* gp, + const float* C_base, + int m_total, + int n_total, + int vlen); +#endif + +template +FBGEMM_API void cblas_gemm_compute( + const matrix_op_t transa, + const int m, + const float* A, + const PackedGemmMatrixB& Bp, + const float beta, 
+ float* C, + int thread_id = 0, + int num_threads = 1); + +#if defined(FBGEMM_EXPORTS) +// autotuned kernel splits for various cases m = 1:mb_max +template +void cblas_gemm_compute( + const matrix_op_t transa [[maybe_unused]], + const int m, + const float* A, + const PackedGemmMatrixB& Bp, + const float beta, + float* C, + int thread_id, + int num_threads) { + // ground truth + assert(cpuinfo_initialize()); +#ifndef __aarch64__ + assert(cpuinfo_has_x86_fma3()); + assert(cpuinfo_has_x86_f16c()); +#endif + assert(transa == matrix_op_t::NoTranspose); + + // private scratchpad storage + static thread_local std::unique_ptr> scratchpad( + new std::array()); + + // constants + const int n = Bp.numCols(), k = Bp.numRows(), ldc = n; + const int mb_max = 120; + +#if defined(FBGEMM_USE_REF_KERNEL) && defined(__APPLE__) + const auto& [_, partition] = getIsaHandlers(inst_set_t::sve); +#else + const auto iset = fbgemmInstructionSet(); + const auto& [kernels, partition] = getIsaHandlers(iset); +#endif + +#ifdef FBGEMM_USE_REF_KERNEL + // By some reason, if packed B is using packing layout for avx2, we just use + // avx2 even if avx512 is available. + const int simd_width = +#ifndef __aarch64__ + (iset == inst_set_t::avx512 || iset == inst_set_t::avx512_vnni) && + (Bp.blockColSize() == 16 * Bp.kernelNumColBlocks()) + ? simd_info::WIDTH_32BIT_ELEMS + : simd_info::WIDTH_32BIT_ELEMS; +#else + simd_info::WIDTH_32BIT_ELEMS; +#endif +#endif + + GemmParams gp; + int i_begin = 0, i_end = 0; + i_begin = 0; + i_end = m; + for (auto m0 = i_begin; m0 < i_end; m0 += mb_max) { + int mb = std::min(mb_max, i_end - m0); + assert(mb < static_cast(partition.size())); + for (auto k_ind = 0; k_ind < k; k_ind += Bp.blockRowSize()) { + // set up proper accumulation to avoid "Nan" problem + // accumulate of beta != 0.0 + // do not!!! 
accumulate otherwise + float beta_ = beta; + if (k_ind != 0) { + // always accumulate with beta_ = 1.0f + beta_ = 1.0f; + } + + const int kb = std::min(Bp.blockRowSize(), Bp.numRows() - k_ind); + + auto m1 = m0; + auto const num_cycles = partition[mb].size(); + for (size_t c = 0; c < num_cycles; ++c) { + auto kernel_nrows = partition[mb][c][0]; + auto nkernel_nrows = partition[mb][c][1]; + auto m_start = m1; + auto m_end = m1 + kernel_nrows * nkernel_nrows; + for (auto m2 = m_start; m2 < m_end; m2 += kernel_nrows) { + assert(kernel_nrows * kb < static_cast(scratchpad->size())); + if (m != 1) { +#ifdef FBGEMM_ENABLE_KLEIDIAI + if constexpr ( + std::is_same::value || + std::is_same::value) { + gp.A = const_cast(&A[m2 * k + k_ind]); + } else { +#endif + PackA( + kernel_nrows, kb, &A[m2 * k + k_ind], k, scratchpad->data()); + gp.A = scratchpad->data(); +#ifdef FBGEMM_ENABLE_KLEIDIAI + } +#endif + } else { + // When m == 1, it is actually vector matrix multiplication. We + // don't need to do the transposition for packA here. Instead, we + // can just pass the pointer of the original A matrix buffer to the + // packed A buffer. 
+ gp.A = const_cast(&A[k_ind]); + } + + int nbcol = n / Bp.blockColSize(); + gp.k = kb; + gp.B = &(Bp(k_ind, 0)); + gp.beta = beta_; + gp.C = &C[m2 * ldc]; + gp.ldc = ldc * sizeof(C[0]); + gp.b_block_cols = nbcol; +#ifdef FBGEMM_ENABLE_KLEIDIAI + if constexpr ( + std::is_same::value || + std::is_same::value) { + gp.lda = k * sizeof(A[0]); + } else { +#endif + gp.b_block_size = gp.k * Bp.blockColSize() * sizeof(gp.B[0]); +#ifdef FBGEMM_ENABLE_KLEIDIAI + } +#endif + if ((n % Bp.blockColSize()) == 0) { + int64_t jb_begin = 0, jb_end = 0; + fbgemmPartition1D( + thread_id, num_threads, gp.b_block_cols, jb_begin, jb_end); + gp.B += gp.k * Bp.blockColSize() * jb_begin; + gp.C += Bp.blockColSize() * jb_begin; + gp.b_block_cols = jb_end - jb_begin; + if (gp.b_block_cols) { +#ifdef FBGEMM_USE_REF_KERNEL + ref_kernel(kernel_nrows, &gp, C, m, n, simd_width); +#else + kernels[kernel_nrows](&gp); +#endif + } + } else { + int last_blk_col = nbcol * Bp.blockColSize(); + if (nbcol) { + int64_t jb_begin = 0, jb_end = 0; + fbgemmPartition1D( + thread_id, num_threads, gp.b_block_cols, jb_begin, jb_end); + gp.B += gp.k * Bp.blockColSize() * jb_begin; + gp.C += Bp.blockColSize() * jb_begin; + gp.b_block_cols = jb_end - jb_begin; + if (gp.b_block_cols) { +#ifdef FBGEMM_USE_REF_KERNEL + ref_kernel(kernel_nrows, &gp, C, m, n, simd_width); +#else + kernels[kernel_nrows](&gp); +#endif + } + } + + // use one thread to handle the fringe cases + if (thread_id == num_threads - 1) { + // leftover + const int rem [[maybe_unused]] = n - last_blk_col; + assert(rem < Bp.blockColSize()); + + // small temporary buffer: the size should be larger than the + // required kernel_nrow x kernel_ncols elements computed in the + // registers. 
+ std::array c_tmp{0.f}; + assert( + static_cast(c_tmp.size()) >= + kernel_nrows * Bp.blockColSize()); + + gp.B = &(Bp(k_ind, last_blk_col)); + gp.C = c_tmp.data(); + gp.ldc = Bp.blockColSize() * sizeof(C[0]); + gp.b_block_cols = 1; +#ifdef FBGEMM_USE_REF_KERNEL + ref_kernel( + kernel_nrows, &gp, c_tmp.data(), 14, 32, simd_width); +#else + kernels[kernel_nrows](&gp); +#endif + for (int i = 0; i < kernel_nrows; i++) { + // Todo: use assembly + for (int j = last_blk_col; j < n; j++) { + assert( + i * Bp.blockColSize() + (j - last_blk_col) < + static_cast(sizeof(c_tmp) / sizeof(c_tmp[0]))); + if (beta_ == 0.f) { + C[(m2 + i) * ldc + j] = + c_tmp[i * Bp.blockColSize() + (j - last_blk_col)]; + } else { + C[(m2 + i) * ldc + j] = beta_ * C[(m2 + i) * ldc + j] + + c_tmp[i * Bp.blockColSize() + (j - last_blk_col)]; + } + } + } + } + } + } + m1 += kernel_nrows * nkernel_nrows; + } + } + } +} +#endif + +#undef FBGEMM_USE_REF_KERNEL +} // namespace fbgemm + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fbgemm/FbgemmI64.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fbgemm/FbgemmI64.h new file mode 100644 index 0000000000000000000000000000000000000000..8d95013257c8419d468815ce055bfb56898a432c --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fbgemm/FbgemmI64.h @@ -0,0 +1,36 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include + +#include "fbgemm/Utils.h" + +namespace fbgemm { + +FBGEMM_API void cblas_gemm_i64_i64acc( + matrix_op_t transa, + matrix_op_t transb, + int M, + int N, + int K, + const std::int64_t* A, + int lda, + const std::int64_t* B, + int ldb, + bool accumulate, + std::int64_t* C, + int ldc); + +} // namespace fbgemm + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fbgemm/FbgemmI8DepthwiseAvx2.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fbgemm/FbgemmI8DepthwiseAvx2.h new file mode 100644 index 0000000000000000000000000000000000000000..70571cded3d6f574afd2f471cb418baf2ab9022a --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fbgemm/FbgemmI8DepthwiseAvx2.h @@ -0,0 +1,117 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include "fbgemm/ConvUtils.h" +#include "fbgemm/FbgemmBuild.h" +#include "fbgemm/UtilsAvx2.h" + +namespace fbgemm { + +class FBGEMM_API PackedDepthWiseConvMatrix { + public: + /** + * @param IC the number of input channels (same as the number of groups + * because depth-wise convolution has one input channel per group) + * @param OC the number of output channels + * @param kernel_prod the product of all kernels. For example, kernel_prod = + * 9 for 3x3 conv, and 27 for 3x3x3 conv. 
+ * @param smat the source unpacked weight in GRS layout + */ + PackedDepthWiseConvMatrix(int OC, int kernel_prod, const std::int8_t* smat); + PackedDepthWiseConvMatrix(const PackedDepthWiseConvMatrix&) = delete; + PackedDepthWiseConvMatrix(PackedDepthWiseConvMatrix&&) = delete; + PackedDepthWiseConvMatrix& operator=(const PackedDepthWiseConvMatrix&) = + delete; + PackedDepthWiseConvMatrix& operator=(PackedDepthWiseConvMatrix&&) = delete; + virtual ~PackedDepthWiseConvMatrix(); + + const std::int8_t* PackedMat() const { + return pmat_; + } + + int GetKernelProduct() const { + return kernel_prod_; + } + + /** + * @brief Unpacks pmat_ into unpack_data. + * Used for recovering the weight matrix into the original format + */ + void unpack(std::int8_t* unpacked_data); + + /** + * @brief returns the index into pmat_ given the row and column for smat + */ + int addr(int r, int c); + + private: + const int OC_; /**< the number of output channels */ + const int kernel_prod_; /** the product of all kernel dims */ + std::int8_t* pmat_; /** packed weight */ +}; // PackedDepthWiseConvMatrix + +/** + * Depth-wise convolution that results in the same output feature size as the + * input feature. That is PAD_T = PAD_B = (R - 1) / 2 and PAD_L = PAD_R = + * (S - 1) / 2. This function also does requantization. + * @param col_offsets nullptr if col_offsets are folded into bias + * @param act_times_w_scale Only used if BIAS_TYPE is float, i.e., bias is + * unquantized. 
+ */ +template +FBGEMM_API void depthwise_2d_same_pad( + int N, + int H, + int W, + int IC, + int OC, + int stride_h, + int stride_w, + std::int32_t A_zero_point, + const std::uint8_t* A, + const std::int32_t* B_zero_point, + const PackedDepthWiseConvMatrix& Bp, + const float* C_multiplier, + std::int32_t C_zero_point, + std::uint8_t* C, + const std::int32_t* col_offsets, + const BIAS_TYPE* bias, + bool fuse_relu = false, + const float* act_times_w_scale = nullptr, + int thread_id = 0, + int num_threads = 1); + +/** + * @param col_offsets nullptr if col_offsets are folded into bias + */ +template +FBGEMM_API void depthwise_3d_same_pad( + const conv_param_t<3>& conv_p, + std::int32_t A_zero_point, + const std::uint8_t* A, + const std::int32_t* B_zero_point, + const PackedDepthWiseConvMatrix& Bp, + const float* C_multiplier, + std::int32_t C_zero_point, + std::uint8_t* C, + const std::int32_t* col_offsets, + const BIAS_TYPE* bias, + bool fuse_relu = false, + const float* act_times_w_scale = nullptr, + int thread_id = 0, + int num_threads = 1); + +} // namespace fbgemm + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fbgemm/FbgemmI8DirectconvAvx2.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fbgemm/FbgemmI8DirectconvAvx2.h new file mode 100644 index 0000000000000000000000000000000000000000..e0cd02f1eea7b442fa31bbe03ca5074ca30b2f1c --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fbgemm/FbgemmI8DirectconvAvx2.h @@ -0,0 +1,69 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. 
+ * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include "fbgemm/ConvUtils.h" +#include "fbgemm/FbgemmBuild.h" + +namespace fbgemm { + +class FBGEMM_API PackedDirectConvMatrix { + public: + /** + * @param IC the number of input channels + * @param OC the number of output channels + * @param kernel_prod the product of all kernels. For example, kernel_prod = + * 9 for 3x3 conv, and 27 for 3x3x3 conv. + * @param smat the source unpacked weight in GRS layout + */ + PackedDirectConvMatrix( + int IC_per_G, + int OC_per_G, + int filter_prod, + const std::int8_t* smat); + PackedDirectConvMatrix(const PackedDirectConvMatrix&) = delete; + PackedDirectConvMatrix(PackedDirectConvMatrix&&) = delete; + PackedDirectConvMatrix& operator=(const PackedDirectConvMatrix&) = delete; + PackedDirectConvMatrix& operator=(PackedDirectConvMatrix&&) = delete; + + virtual ~PackedDirectConvMatrix(); + + const std::int8_t* PackedMat() const { + return pmat_; + } + + const bool& is_first_call() const { + return first_call; + } + + /** + compute the column offsets of the weight matrix. + output of this function is the col_offsets vector + col_offses dimension is the same as conv_p.OUT_DIM + */ + template + FBGEMM_API void col_offsets_with_zero_pt_s8acc32_DirectConvT( + const fbgemm::conv_param_t& conv_p, + std::int32_t* B_zero_point, + std::vector& col_offsets, + int ncols_per_quant_group); + + private: + std::int8_t* pmat_; /** packed weight */ + bool first_call{true}; +}; + +} // namespace fbgemm + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." 
+#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fbgemm/FbgemmI8Spmdm.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fbgemm/FbgemmI8Spmdm.h new file mode 100644 index 0000000000000000000000000000000000000000..650f5d6bf6ab5c67548f79997f170ef26cdaf614 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fbgemm/FbgemmI8Spmdm.h @@ -0,0 +1,140 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include "./ConvUtils.h" // @manual +#include "./FbgemmBuild.h" // @manual +#include "./Utils.h" // @manual + +// #define FBGEMM_MEASURE_TIME_BREAKDOWN + +#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN +#include +#include +extern double spmdm_initial_time; +extern double spmdm_transpose_uint8_time; +extern double spmdm_transpose_32xN_time; +extern double spmdm_compute_time; +extern double spmdm_transpose_Nx32_time; +extern double spmdm_run_time; +extern double sconv_run_time; +#endif + +namespace fbgemm { + +/** + * @brief A class to represent a matrix in Compressed Sparse Column (CSC) + * format. + * + * The second input matrix of matrix multiplication is usually weight and can + * be sparse, and it's usually more efficient to use CSC format to represent + * the second input matrix. 
+ */ +class FBGEMM_API CompressedSparseColumn { + public: + CompressedSparseColumn(int num_of_rows, int num_of_cols); + + std::vector& ColPtr() { + return colptr_; + } + std::vector& RowIdx() { + return rowidx_; + } + std::vector& Values() { + return values_; + } + std::vector& KHs() { + return kh_; + } + std::vector& KWs() { + return kw_; + } + /** + * ICs include group: i.e. for ith input channels withint group g, ICs contain + * g*(groups_per_input_channels) + i + */ + std::vector& ICs() { + return ic_; + } + + std::size_t NumOfRows() const { + return num_rows_; + } + std::size_t NumOfCols() const { + return colptr_.size() - 1; + } + std::int32_t NumOfNonZeros() const { + return colptr_.back(); + } + + /** + * @return Total number of non-zero elements as a fraction of total + * elements. + */ + double Density() const; + + /** + * @return True if the number of non-zeros per row is smaller than a small + * threshold. + */ + bool IsHyperSparse() const; + + /** + * @brief Perform dense-matrix * sparse matrix. + * + * C += A (dense matrix) * B (this CSC matrix) if accumulation = true \n + * C = A (dense matrix) * B (this CSC matrix) if accumulation = false + */ + void SpMDM( + const block_type_t& block, + const std::uint8_t* A, + int lda, + bool accumulation, + std::int32_t* C, + int ldc) const; + + void SparseConv( + const conv_param_t<>& conv_p, + const block_type_t& block, + const std::uint8_t* A, + std::int32_t A_zero_point, + bool accumulation, + std::int32_t* C, + int ldc) const; + + private: + const std::size_t num_rows_; + std::vector colptr_; // corresponds to out channels + std::vector values_; + + // For SpMDM + std::vector rowidx_; // kh kw ic are flattened with im2col + + // For direct sparse convolution + std::vector kh_; + std::vector kw_; + std::vector ic_; // in channels + + // Cache IsHyperSparse to minimize its overhead. 
+ mutable bool hyper_sparse_{false}; + + // Whether we can reuse the cached hyper_sparse_ is determined by checking + // if NumOfNonZeros() is same as old_nnz_ saved in previous invocation of + // IsHyperSparse call. + mutable std::int32_t old_nnz_{-1}; +}; + +} // namespace fbgemm + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fbgemm/FbgemmPackMatrixB.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fbgemm/FbgemmPackMatrixB.h new file mode 100644 index 0000000000000000000000000000000000000000..407c372c434d7091455f7664baccdc902805c4e2 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fbgemm/FbgemmPackMatrixB.h @@ -0,0 +1,339 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * Copyright 2024-2025 Arm Limited and/or its affiliates + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include + +#include "SimdUtils.h" // @manual +#include "Types.h" // @manual +#include "Utils.h" // @manual + +namespace fbgemm { + +template +struct TypeConverter { + template + T operator()(F) const; +}; + +#define PMAT_ALIGNMENT 64 + +/// class that performs packing of matrix in +/// row-major format into +/// internal packed blocked-row major format +template > +class PackedGemmMatrixB { + public: + using value_type = T; + using size_type = uint64_t; + + // takes smat input mamtrix in row-major format; + // packs it into gemm-friendly blocked format; + // allocate space and sets up all the internal variables; + // also premultiplies by alpha during packing. 
+ // brow_ contains tile size along k dimension + // and also is # of fmas updates into int16 container + // before flushing into fp32. + // the smaller the brow_, the higher overhead + // of flushing is. + // kernel_ncol_blocks is the number of column blocks (in the size of 8 fp16, + // or 128 bit, or 1 xmm register size) in the kernel. Because the batch size + // can be dynamic and we need to prepack the weight matrix B, the internal + // packing layout of the weight matrix and kernel_ncol_blocks have to be + // fixed. We can choose kernel_ncol_blocks = 1 (with kernels of 1x1~14x1 + // register layouts), 2 (with kernels of 1x2~6x2 register layout), or 3 (with + // kernels of 1x3~4x3 register layout). + PackedGemmMatrixB( + const matrix_op_t trans, + const int nrow, + const int ncol, + const float alpha, + const float* smat, + const int brow = 512) + : nrow_(nrow), ncol_(ncol), brow_(brow), kernel_ncol_blocks_(2) { +#ifdef FBGEMM_ENABLE_KLEIDIAI + if constexpr (std::is_same::value) { + kernel_ncol_blocks_ = 1; + } +#endif + initializeParam(); + initializeMemory(); + // copy source matrix into packed matrix + this->packFromSrc(trans, alpha, smat); + } + + PackedGemmMatrixB( + const int nrow, + const int ncol, + const int brow, + const int last_brow, + const int bcol, + const int nbrow, + const int nbcol, + const uint64_t size) + : nrow_(nrow), + ncol_(ncol), + brow_(brow), + last_brow_(last_brow), + bcol_(bcol), + nbrow_(nbrow), + nbcol_(nbcol), + size_(size), + kernel_ncol_blocks_(2) { +#ifdef FBGEMM_ENABLE_KLEIDIAI + if constexpr (std::is_same::value) { + kernel_ncol_blocks_ = 1; + } +#endif + initializeMemory(); + } + + PackedGemmMatrixB( + const int nrow, + const int ncol, + const int brow, + const int last_brow, + const int bcol, + const int nbrow, + const int nbcol, + const uint64_t size, + const int kernel_ncol_blocks, + void* pmat) + : nrow_(nrow), + ncol_(ncol), + brow_(brow), + last_brow_(last_brow), + bcol_(bcol), + nbrow_(nbrow), + nbcol_(nbcol), + 
size_(size), + kernel_ncol_blocks_(kernel_ncol_blocks) { +#ifdef FBGEMM_ENABLE_KLEIDIAI + if constexpr (std::is_same::value) { + kernel_ncol_blocks_ = 1; + } +#endif + pmat_ = static_cast(pmat); + packed_ = true; + pmat_passed_in = true; + } + PackedGemmMatrixB(const PackedGemmMatrixB&) = delete; + PackedGemmMatrixB(PackedGemmMatrixB&&) = delete; + PackedGemmMatrixB& operator=(const PackedGemmMatrixB&) = delete; + PackedGemmMatrixB& operator=(PackedGemmMatrixB&&) = delete; + + void initializeParam() { + if (!cpuinfo_initialize()) { + throw std::runtime_error("Failed to initialize cpuinfo!"); + } + bcol_ = (isZmm(fbgemmInstructionSet()) + ? simd_info::WIDTH_32BIT_ELEMS + : simd_info::WIDTH_32BIT_ELEMS) * + kernelNumColBlocks(); + + // set up internal packing parameters + nbrow_ = (numRows() + blockRowSize() - 1) / blockRowSize(); + last_brow_ = ((nrow_ % blockRowSize()) == 0) ? blockRowSize() + : (nrow_ % blockRowSize()); + nbcol_ = (numCols() + blockColSize() - 1) / blockColSize(); + + if (numCols() != blockColSize() * nbcol_) { +#ifdef VLOG + VLOG(0) << "Packer warning: ncol(" << numCols() + << ") is not a multiple of internal block size (" + << blockColSize() << ")"; + VLOG(0) << "lefover is not super optimized hence overhead will inccur"; +#endif + } + } + + void setPacked(bool p) { + packed_ = p; + } + + bool packed() const { + return packed_; + } + + void initializeMemory() { + // allocate and initialize packed memory + size_ = (blockRowSize() * nbrow_) * (blockColSize() * nbcol_); + pmat_ = static_cast( + fbgemmAlignedAlloc(PMAT_ALIGNMENT, matSize() * sizeof(T))); + memset(pmat_, 0, matSize() * sizeof(T)); + } + + ~PackedGemmMatrixB() { + if (pmat_passed_in == false) { + fbgemmAlignedFree(pmat_); + } + } + + void unpackFromSrc(const matrix_op_t trans, T* src_mat) { + bool tr = (trans == matrix_op_t::Transpose); + for (int i = 0; i < numRows(); i++) { + for (int j = 0; j < numCols(); j++) { + pmat_[tr ? 
i + numRows() * j : i * numCols() + j] = src_mat[addr(i, j)]; + } + } + packed_ = false; + } + + void unpack(T* origin_buf, const matrix_op_t trans) { + assert(packed_); + bool tr = (trans == matrix_op_t::Transpose); + for (int i = 0; i < numRows(); i++) { + for (int j = 0; j < numCols(); j++) { + origin_buf[tr ? i + numRows() * j : i * numCols() + j] = + pmat_[addr(i, j)]; + } + } + } + + // protected: + // blocked row-major format address arithmetic + uint64_t addr(const int r_, const int c_) const { + uint64_t r = (uint64_t)r_; + uint64_t c = (uint64_t)c_; + + uint64_t block_row_id = r / blockRowSize(); + uint64_t brow_offset = + (block_row_id * nbcol_) * (blockRowSize() * blockColSize()); + uint64_t block_col_id = c / blockColSize(); + uint64_t bcol_offset = block_col_id * + ((static_cast(block_row_id) != nbrow_ - 1) + ? (blockRowSize() * blockColSize()) + : (last_brow_ * blockColSize())); + uint64_t block_offset = brow_offset + bcol_offset; + uint64_t inblock_offset = + r % blockRowSize() * blockColSize() + c % blockColSize(); + + uint64_t index = block_offset + inblock_offset; + assert(static_cast(index) < matSize()); + return index; + } + + void + packFromSrc(const matrix_op_t trans, const float alpha, const float* smat) { + bool tr = (trans == matrix_op_t::Transpose); + // pack + for (int i = 0; i < numRows(); i++) { + for (int j = 0; j < numCols(); j++) { + float src = alpha * + ((tr == false) ? smat[i * numCols() + j] : smat[i + numRows() * j]); + pmat_[addr(i, j)] = C()(src); + } + } + packed_ = true; + } + + // This function takes in an unpacked T matrix of the same size and + // packs it. There is no floating type conversion. + void packFromSrc(const matrix_op_t trans, const T* smat) { + bool tr = (trans == matrix_op_t::Transpose); + for (int i = 0; i < numRows(); ++i) { + for (int j = 0; j < numCols(); ++j) { + pmat_[addr(i, j)] = smat[tr ? 
i + numRows() * j : i * numCols() + j]; + } + } + packed_ = true; + } + + const T& operator()(const int r, const int c) const { + const auto a = addr(r, c); + assert(r < numRows()); + assert(c < numCols()); + assert(static_cast(a) < this->matSize()); + return pmat_[a]; + } + + int matSize() const { + return size_; + } + int numRows() const { + return nrow_; + } + int numCols() const { + return ncol_; + } + int lastBrow() const { + return last_brow_; + } + int numBrow() const { + return nbrow_; + } + int numBcol() const { + return nbcol_; + } + T* pmat() const { + return pmat_; + } + int blockRowSize() const { + return brow_; + } + int blockColSize() const { + return bcol_; + } + int kernelNumColBlocks() const { + return kernel_ncol_blocks_; + } + + const value_type* data() const { + return pmat_; + } + + uint64_t size() const { + return size_ / sizeof(value_type); + } + + int nrow_, ncol_; + int brow_, last_brow_, bcol_; + int nbrow_, nbcol_; + uint64_t size_; + int kernel_ncol_blocks_; + T* pmat_; + bool packed_{false}; + bool pmat_passed_in{false}; +}; + +#ifndef _M_X64 + +template <> +FBGEMM_API +PackedGemmMatrixB>::PackedGemmMatrixB( + const matrix_op_t trans, + const int nrow, + const int ncol, + const float alpha, + const float* smat, + const int brow); + +template <> +FBGEMM_API +PackedGemmMatrixB>::PackedGemmMatrixB( + const int nrow, + const int ncol, + const int brow, + const int last_brow, + const int bcol, + const int nbrow, + const int nbcol, + const uint64_t size); + +#endif + +} // namespace fbgemm + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." 
+#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fbgemm/FbgemmSparse.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fbgemm/FbgemmSparse.h new file mode 100644 index 0000000000000000000000000000000000000000..21a3a111271bc10b949cc9912d75fc775ffe3488 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fbgemm/FbgemmSparse.h @@ -0,0 +1,230 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include + +#include "fbgemm/FbgemmBuild.h" +#include "fbgemm/UtilsAvx2.h" +#include "fbgemm/spmmUtilsAvx2.h" + +namespace fbgemm { + +template +struct FBGEMM_API CSRMatrix { + std::vector rowPtr; + std::vector colIdx; + std::vector values; +}; + +/** + * Tiled block CSR format + * Partial blocks are zero-filled + * + */ +template +struct FBGEMM_API BCSRMatrix { + using DTYPE = T; + static constexpr int RB = ROW_BLOCK; // Block size for rows + static constexpr int CB = COL_BLOCK; // Block size for cols + // We only tile in column dimension currently + // COLTILE must be a multiple of COL_BLOCK + static constexpr int COLTILE = 4000; + std::vector rowBPtr; // rowPtr for blocks + std::vector colBIdx; // colIdx for blocks + std::vector values; + // Sum of all elements in a row + std::vector row_offsets; + int R; + int C; + + BCSRMatrix(int Rows, int Cols) { + R = Rows; + C = Cols; + row_offsets.resize(R, 0); + } + + /** + * @brief pack from dense to tiled block CSR format + * @param R number of rows in the matrix + * @param C number of columns in the matrix + * @param src is the source matrix with data type DTYPE + * @param ld is 
the leading dimension + */ + void pack(const DTYPE* src, size_t ld); + + /** + * @brief pack from dense to tiled block CSR format + * @param R number of rows in the matrix + * @param C number of columns in the matrix + * @param src is the source matrix with data type DTYPE + * + * leading dim of the matrix is assumed to be equal to C + */ + void pack(const DTYPE* src); + + /** + * @brief unpack from tiled block CSR to dense + * @param dst should be able to hold R*C elements of type DTYPE + * @param ld is the leading dimension + */ + void unpack(DTYPE* dst, size_t ld); + + /* + * @brief unpack from tiled block CSR to dense + * @param dst should be able to hold R*C elements of type DTYPE + * + * leading dimension of the matrix is assumed to be equal to C + */ + void unpack(DTYPE* dst); +}; + +template +FBGEMM_API std::unique_ptr> +fbgemmDenseToCSR(int R, int C, const T* inp, int ld); + +template +FBGEMM_API std::unique_ptr> +fbgemmDenseToCSR(int R, int C, const T* inp); + +template +FBGEMM_API std::unique_ptr> +fbgemmDenseToBCSR(int R, int C, const T* inp, int ld); + +template +FBGEMM_API std::unique_ptr> +fbgemmDenseToBCSR(int R, int C, const T* inp); + +/** + * @param accum Controls accumulation. + * 1 means we're accumulating to the C Matrix. + * + * Note on matrix order and layout: + * Unlike other fbgemm functions that follow PyTorch convention where A + * matrix is activation (so in uint8_t for quantized FC/Conv or fp32) and B + * matrix is weight (so in int8_t for quantized FC/Conv or fp32), here A is + * weight matrix. This is because we mostly target sparsity in weights and for + * row-major layout it's more efficient to have A as a sparse matrix: for each + * non-zero of A at ith row and kth column, we can access kth row of B, whose + * elements are contiguous in memory. 
If B matrix was sparse, for each non-zero + * of B at kth row and jth column, we would've needed to access kth column of A, + * whose elements are not contiguous in memory with C/C++'s row-major layout. + * Alternatively, we can call this function as if we're computing + * C^T = B^T * A^T while maintaining PyTorch's convention that the lefthand + * side matrix B is activation. If B matrix is in column-major layout, we don't + * need to do an extra transposition. The C matrix will be output in + * column-major layout, so if we have a back-to-back Sparse-Dense matrix-matrix + * multiplications, B matrices of subsequent matrices will be already in + * column-major layout. Refer to SparseDenseMMFP32Benchmark.cc for an example. + * + */ +FBGEMM_API void SparseDenseMM( + int M, + int N, + const int* row_ptr, + const int* col_idx, + const float* values, + const float* B, + int ldb, + float* C, + int ldc, + bool accum = false); + +template +FBGEMM_API void fbgemmSparseDenseInt8MM( + int N, + const std::unique_ptr>& bcsr, + const uint8_t* B, + int ldb, + int32_t* C_i32, + uint8_t* C_u8, + int ldc, + trRequantizationParams_t& rParams, + bool accum = false, + int thread_id = 0, + int num_threads = 1); + +namespace internal { + +void SparseDenseMMAvx2( + int M, + int N, + const int* row_ptr, + const int* col_idx, + const float* values, + const float* B, + int ldb, + float* C, + int ldc, + bool accum = false); + +#if !defined(__aarch64__) +void SparseDenseMMAvx512( + int M, + int N, + const int* row_ptr, + const int* col_idx, + const float* values, + const float* B, + int ldb, + float* C, + int ldc, + bool accum = false); + +template +void SparseDenseInt8MMAvx2( + int N, + const std::unique_ptr>& bcsr, + const uint8_t* B, + int ldb, + int32_t* C_i32, + uint8_t* C_u8, + int ldc, + trRequantizationParams_t& rParams, + bool accum = false, + int thread_id = 0, + int num_threads = 1); + +template +void SparseDenseInt8MMAvx512( + int N, + const std::unique_ptr>& bcsr, + const 
uint8_t* B, + int ldb, + int32_t* C_i32, + uint8_t* C_u8, + int ldc, + trRequantizationParams_t& rParams, + bool accum = false, + int thread_id = 0, + int num_threads = 1); + +template +void SparseDenseInt8MVAvx512( + const std::unique_ptr>& bcsr, + const uint8_t* B, + int ldb, + int32_t* C_i32, + uint8_t* C_u8, + trRequantizationParams_t& rParams, + bool accum = false, + int thread_id = 0, + int num_threads = 1); +#endif + +} // namespace internal + +} // namespace fbgemm + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fbgemm/FloatConversion.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fbgemm/FloatConversion.h new file mode 100644 index 0000000000000000000000000000000000000000..949e9ebefde3ad93ecdd934f7703a6be2b84ea79 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fbgemm/FloatConversion.h @@ -0,0 +1,331 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +#include +#include +#include +#include +#include + +#include "./Types.h" // @manual + +#ifndef __is_identifier +#define __is_identifier(x) 1 +#endif + +#define __has_keyword(__x) !(__is_identifier(__x)) + +// TODO: we're disabling native fp16 on Windows to workaround test failures +// due to "undefined symbol __gnu_h2f_ieee" error. We should follup on this +// later. 
+#if __has_keyword(__fp16) && !defined(_WIN32) +#define HAS_NATIVE_FP16_TYPE +using native_fp16_t = __fp16; +#elif __has_keyword(_Float16) && !defined(_WIN32) +#define HAS_NATIVE_FP16_TYPE +using native_fp16_t = _Float16; +#else +using native_fp16_t = void; +#endif + +namespace fbgemm { + +namespace detail { + +template +struct FloatFormat { + using value_type = T; + static constexpr int bits = sizeof(T) * CHAR_BIT; + static constexpr int exponent_bits = ExponentBits; + static constexpr int mantissa_bits = bits - exponent_bits - 1; + static constexpr int sign_bit_pos = bits - 1; + static constexpr int exponent_bias = (1 << (exponent_bits - 1)) - 1; + static constexpr int unbiased_exponent_min = -exponent_bias + 1; + static constexpr int unbiased_exponent_max = + HasInfinity ? exponent_bias : (exponent_bias + 1); + static constexpr T sign_bit = T{1} << sign_bit_pos; + static constexpr T exponent_mask = ((T{1} << exponent_bits) - 1) + << mantissa_bits; + static constexpr T mantissa_mask = (T{1} << mantissa_bits) - 1; + // signaling/quiet encoding is unspecified by IEEE754. This mirrors x86/ARM. + static constexpr T quiet_nan_bit = T{1} << (mantissa_bits - 1); + + static constexpr T nan = exponent_mask | mantissa_mask; + static constexpr T overflow_value = HasInfinity ? exponent_mask : nan; + static constexpr bool has_infinity = HasInfinity; + static constexpr bool has_nan_payload = HasInfinity; +}; + +using IEEE754Single = FloatFormat; +using IEEE754Half = FloatFormat; +// See https://arxiv.org/abs/1905.12322v3 +using BFloat16 = FloatFormat; +// See https://doi.org/10.48550/arXiv.2209.05433 +using FP8_E5M2 = FloatFormat; +// See https://doi.org/10.48550/arXiv.2209.05433 +using FP8_E4M3FN = FloatFormat< + /*T=*/uint8_t, + /*ExponentBits=*/4, + /*HasInfinity=*/false>; + +enum class RoundingMode { + ToNearestTiesToEven, + ToZero, +}; + +// Generic IEEE754 truncation algorithm. 
+template +[[gnu::always_inline]] inline typename Tgt::value_type ieee754_trunc( + typename Src::value_type value) { + static_assert(Src::exponent_bits >= Tgt::exponent_bits); + static_assert(Src::mantissa_bits > Tgt::mantissa_bits); + using ST = typename Src::value_type; + using TT = typename Tgt::value_type; + + ST src_exponent = value & Src::exponent_mask; + ST src_mantissa = value & Src::mantissa_mask; + // Fast-path: If there is no difference in exponent sizes (e.g. fp32 -> bf16) + // and we round toward zero, then we can just drop the least significant bits. + if constexpr ( + Src::exponent_bits == Tgt::exponent_bits && Src::has_infinity && + Tgt::has_infinity && RoundingMode == RoundingMode::ToZero) { + TT result = value >> (Src::bits - Tgt::bits); + // Turn signaling NaN into quiet NaN. This also avoids that the mantissa + // is completely zero after truncation (which would be misinterpreted as + // INF). + if (src_exponent == Src::exponent_mask && src_mantissa != 0) { + result |= Tgt::quiet_nan_bit; + } + return result; + } + + ST tgt_sign = + (value & Src::sign_bit) >> (Src::sign_bit_pos - Tgt::sign_bit_pos); + constexpr bool denormal_becomes_zero = + Tgt::unbiased_exponent_min - Src::unbiased_exponent_min > + Src::mantissa_bits - Tgt::mantissa_bits; + if constexpr (denormal_becomes_zero) { + // Fast-path for zero exponentbits: This means the number was zero or a + // denormal number that will turn into zero in the Tgt format. + if (src_exponent == 0) { + return tgt_sign; // tgt_exponent == 0, tgt_mantissa == 0 + } + } + + int unbiased_exponent = + (src_exponent >> Src::mantissa_bits) - Src::exponent_bias; + if (unbiased_exponent < Tgt::unbiased_exponent_min) { + int shift = Tgt::unbiased_exponent_min - unbiased_exponent; + if (shift <= Tgt::mantissa_bits + 1) { + // Result is denormal. + ST src_mantissa_one = src_mantissa; + // Add explicit one if the source was not denormal. 
+ if (denormal_becomes_zero || src_exponent != 0) { + src_mantissa_one |= TT{1} << Src::mantissa_bits; + } else { + shift--; + } + TT tgt_mantissa = + src_mantissa_one >> (Src::mantissa_bits - Tgt::mantissa_bits + shift); + + if constexpr (RoundingMode == RoundingMode::ToNearestTiesToEven) { + int half_pos = Src::mantissa_bits - Tgt::mantissa_bits + shift - 1; + ST half = 1 << half_pos; + ST remainder = src_mantissa_one & ((half << 1) - 1); + if (remainder > half || + (remainder == half && (tgt_mantissa & 1) != 0)) { + tgt_mantissa += 1; + } + } else { + static_assert(RoundingMode == RoundingMode::ToZero); + } + return tgt_sign | tgt_mantissa; // tgt_exponent == 0 + } else { + // Result is +/- zero + return tgt_sign; // tgt_exponent == 0, tgt_mantissa == 0 + } + } + + if (unbiased_exponent > Tgt::unbiased_exponent_max) { + if (unbiased_exponent == Src::exponent_bias + 1 && src_mantissa != 0) { + TT tgt_mantissa; + if constexpr (Tgt::has_nan_payload) { + // NaN; not a number + tgt_mantissa = + src_mantissa >> (Src::mantissa_bits - Tgt::mantissa_bits); + tgt_mantissa |= Tgt::quiet_nan_bit; + } else { + tgt_mantissa = Tgt::mantissa_mask; + } + return tgt_sign | Tgt::exponent_mask | tgt_mantissa; + } else { + if (RoundingMode == RoundingMode::ToZero && + (!Src::has_infinity || src_exponent != Src::exponent_mask)) { + // Return largest finite number. + return tgt_sign | (Tgt::exponent_mask - Tgt::has_infinity) | + Tgt::mantissa_mask; + } + // Infinity or NaN for formats without infinity. + return tgt_sign | Tgt::overflow_value; + } + } + + // Normal number. 
+ TT tgt_mantissa = src_mantissa >> (Src::mantissa_bits - Tgt::mantissa_bits); + TT tgt_exponent = (unbiased_exponent + Tgt::exponent_bias) + << Tgt::mantissa_bits; + if constexpr (RoundingMode == RoundingMode::ToNearestTiesToEven) { + ST half = 1 << (Src::mantissa_bits - Tgt::mantissa_bits - 1); + ST remainder = src_mantissa & ((half << 1) - 1); + if (remainder > half || (remainder == half && (tgt_mantissa & 1) != 0)) { + if (tgt_mantissa < Tgt::mantissa_mask) { + tgt_mantissa += 1; + } else { + // Mantissa overflowed, increment exponent. + + // Normally we can just add to the exponent and will naturally end up + // on infinity on overflow. But we need special treatments for formats + // without infinity. + if (Tgt::has_infinity || tgt_exponent != Tgt::exponent_mask) { + tgt_mantissa = 0; + tgt_exponent += TT{1} << Tgt::mantissa_bits; + } else { + // Return NaN. + tgt_mantissa = Tgt::mantissa_mask; + } + } + } + } else { + static_assert(RoundingMode == RoundingMode::ToZero); + } + return tgt_sign | tgt_exponent | tgt_mantissa; +} + +} // namespace detail + +inline float16 cpu_float2half_rn(float f) { + uint32_t f_u32 = 0; + std::memcpy(&f_u32, &f, sizeof(f_u32)); + return detail::ieee754_trunc< + /*Src=*/detail::IEEE754Single, + /*Tgt=*/detail::IEEE754Half, + detail::RoundingMode::ToNearestTiesToEven>(f_u32); +} + +inline float16 cpu_float2half_rz(float f) { + uint32_t f_u32 = 0; + std::memcpy(&f_u32, &f, sizeof(f_u32)); + return detail::ieee754_trunc< + /*Src=*/detail::IEEE754Single, + /*Tgt=*/detail::IEEE754Half, + detail::RoundingMode::ToZero>(f_u32); +} + +// Converts a 16-bit unsigned integer representation of a IEEE754 half-precision +// float into an IEEE754 32-bit single-precision float +inline float cpu_half2float_ref(const float16 h) { + constexpr uint32_t f16_num_exponent_bits = 5; + constexpr uint32_t f16_num_mantissa_bits = 10; + constexpr uint32_t f16_num_non_sign_bits = + f16_num_exponent_bits + f16_num_mantissa_bits; + constexpr uint32_t 
f16_exponent_bias = 15; + constexpr uint32_t f16_exponent_mask = 0b1'1111; + constexpr uint32_t f16_mantissa_mask = 0b11'1111'1111; + + constexpr uint32_t f32_num_exponent_bits = 8; + constexpr uint32_t f32_num_mantissa_bits = 23; + constexpr uint32_t f32_num_non_sign_bits = + f32_num_exponent_bits + f32_num_mantissa_bits; + constexpr uint32_t f32_exponent_bias = 127; + constexpr uint32_t f32_exponent_mask = 0b1111'1111; + constexpr uint32_t f32_mantissa_mask = 0x7F'FF'FF; + constexpr uint32_t f32_most_significant_bit = 1u << 22; + + // Get sign and exponent alone by themselves + uint32_t sign_bit = (h >> f16_num_non_sign_bits) & 1; + uint32_t exponent = (h >> f16_num_mantissa_bits) & f16_exponent_mask; + // Shift mantissa so that it fills the most significant bits of a float32 + uint32_t mantissa = (h & f16_mantissa_mask) + << (f32_num_mantissa_bits - f16_num_mantissa_bits); + + if (exponent == f16_exponent_mask) { // NaN or Inf + if (mantissa) { + mantissa = f32_mantissa_mask; + sign_bit = 0; + } + exponent = f32_exponent_mask; + } else if (!exponent) { // Denorm or Zero + if (mantissa) { + uint32_t msb = 0; + exponent = f32_exponent_bias - f16_exponent_bias + 1; + do { + msb = mantissa & f32_most_significant_bit; + mantissa <<= 1; // normalize + --exponent; + } while (!msb); + mantissa &= f32_mantissa_mask; // 1.mantissa is implicit + } + } else { + exponent += f32_exponent_bias - f16_exponent_bias; + } + + const uint32_t i = (sign_bit << f32_num_non_sign_bits) | + (exponent << f32_num_mantissa_bits) | mantissa; + + float ret = NAN; + std::memcpy(&ret, &i, sizeof(float)); + return ret; +} + +// Same as the previous function, but use the built-in fp16 to fp32 +// conversion provided by the compiler +inline float cpu_half2float(const float16 h) { +#if defined(HAS_NATIVE_FP16_TYPE) && !defined(MISSING_GNU_F2H_IEEE) + __fp16 h_fp16 = NAN; + std::memcpy(&h_fp16, &h, sizeof(__fp16)); + return h_fp16; +#else + return cpu_half2float_ref(h); +#endif +} + +inline float16 
cpu_float2half(const float f) { +#if defined(HAS_NATIVE_FP16_TYPE) && !defined(MISSING_GNU_F2H_IEEE) + __fp16 h = f; + float16 res = 0; + std::memcpy(&res, &h, sizeof(__fp16)); + return res; +#else + return cpu_float2half_rn(f); +#endif +} + +inline float cpu_bf162float(bfloat16 src) { + float ret = NAN; + uint32_t val_fp32 = + static_cast(reinterpret_cast(&src)[0]) << 16; + std::memcpy(&ret, &val_fp32, sizeof(float)); + return ret; +} + +inline bfloat16 cpu_float2bfloat16(float src) { + uint32_t temp = 0; + std::memcpy(&temp, &src, sizeof(uint32_t)); + return (temp + (1u << 15)) >> 16; +} + +} // namespace fbgemm + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fbgemm/OutputProcessing-inl.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fbgemm/OutputProcessing-inl.h new file mode 100644 index 0000000000000000000000000000000000000000..4457d5ee0ba204e875009e9632bdda2480fc0f4d --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fbgemm/OutputProcessing-inl.h @@ -0,0 +1,320 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#include + +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +template +template +inline int memCopy::f( + outT* out, + inT* inp, + const block_type_t& block, + int ld_out, + int ld_in) const { + static_assert( + std::is_same_v, + "input and output data type must be of same type"); + // only copy if destination is not the same as source + if (out + block.row_start * ld_out + block.col_start != inp) { + for (int i = block.row_start; i < block.row_start + block.row_size; ++i) { + memcpy( + out + block.col_start + i * ld_out, + inp + (i - block.row_start) * ld_in, + block.col_size * sizeof(inT)); + } + } + return nextop_.template f(out, out, block, ld_out, ld_out); +} + +template +template +inline int DoSpmdmOnInpBuffer::f( + outT* out, + inT* inp, + const block_type_t& block, + int ld_out, + int ld_in) const { + assert(B_csc_.NumOfCols() % groups_ == 0); + int n_per_group = B_csc_.NumOfCols() / groups_; + int g = block.col_start / n_per_group; + B_csc_.SpMDM(block, A_ + g * B_csc_.NumOfRows(), lda_, true, inp, ld_in); + return nextop_.template f(out, inp, block, ld_out, ld_in); +} + +template +template +inline int DoSConvOnInpBuffer::f( + outT* out, + inT* inp, + const block_type_t& block, + int ld_out, + int ld_in) const { + B_csc_.SparseConv(conv_p_, block, A_, A_zero_point_, true, inp, ld_in); + return nextop_.template f(out, inp, block, ld_out, ld_in); +} + +template < + bool FUSE_RELU, + QuantizationGranularity Q_GRAN, + typename BIAS_TYPE, + typename outT, + typename inT, + typename nextOPType> +template +inline int +ReQuantizeOutput::f( + outT* out, + const inT* inp, + const block_type_t& block, + int ld_out, + int ld_in) const { + static_assert( + std::is_same_v, "input data type must be of int32_t type"); + int ncol_per_group = ncols_ / groups_; + assert( + block.col_size <= ncol_per_group && + "ReQuantizeOutput should be called at most 1 group at a time."); + if constexpr ( + instSet == inst_set_t::anyarch || !std::is_same_v) { + for (int i = block.row_start; i < block.row_start + 
block.row_size; ++i) { + for (int j = block.col_start; j < block.col_start + block.col_size; ++j) { + inT raw = inp[(i - block.row_start) * ld_in + (j - block.col_start)]; + if (Aq_zero_point_) { + raw -= Aq_zero_point_ * q_col_offsets_[j]; + } + int Bq_zero_point_idx = 0; + if constexpr (Q_GRAN == QuantizationGranularity::TENSOR) { + Bq_zero_point_idx = 0; + } else if constexpr (Q_GRAN == QuantizationGranularity::GROUP) { + int g = block.col_start / ncol_per_group; + Bq_zero_point_idx = g; + } else { + static_assert(Q_GRAN == QuantizationGranularity::OUT_CHANNEL); + Bq_zero_point_idx = j; + } + if (q_row_offsets_) { + raw -= q_row_offsets_[i - block.row_start] * + Bq_zero_point_[Bq_zero_point_idx]; + } + float raw_f = NAN; + if (bias_) { + if constexpr (std::is_same_v) { + raw_f = raw; + raw_f += bias_[j] / act_times_w_scale_[Bq_zero_point_idx]; + } else { + raw += bias_[j]; + raw_f = raw; + } + } else { + raw_f = raw; + } + + float ab = raw_f * C_multiplier_[Bq_zero_point_idx]; + long rounded = std::lrintf(ab) + C_zero_point_; + + out[i * ld_out + j] = std::max( + FUSE_RELU ? 
static_cast(C_zero_point_) : 0l, + std::min(255l, rounded)); + } + } + +#if !defined(__aarch64__) + + } else if constexpr ( + instSet == inst_set_t::avx2 || instSet == inst_set_t::avx512) { + bool b_symmetric = + (Q_GRAN == QuantizationGranularity::TENSOR && Bq_zero_point_[0] == 0) || + q_row_offsets_ == nullptr; + + requantizationParams_t r = { + Aq_zero_point_, + Bq_zero_point_, + C_zero_point_, + C_multiplier_, + q_row_offsets_, + q_col_offsets_, + bias_, + ncols_, + groups_, + act_times_w_scale_}; + + if (Aq_zero_point_ == 0) { + if (b_symmetric) { + if (bias_ == nullptr) { + requantizeOutputProcessingAvx2( + out, inp, block, ld_out, ld_in, r); + } else { + requantizeOutputProcessingAvx2( + out, inp, block, ld_out, ld_in, r); + } + } else { + if (bias_ == nullptr) { + requantizeOutputProcessingAvx2( + out, inp, block, ld_out, ld_in, r); + } else { + requantizeOutputProcessingAvx2( + out, inp, block, ld_out, ld_in, r); + } + } + } else { + if (b_symmetric) { + if (bias_ == nullptr) { + requantizeOutputProcessingAvx2( + out, inp, block, ld_out, ld_in, r); + } else { + requantizeOutputProcessingAvx2( + out, inp, block, ld_out, ld_in, r); + } + } else { + if (bias_ == nullptr) { + requantizeOutputProcessingAvx2< + false, + false, + Q_GRAN, + false, + FUSE_RELU>(out, inp, block, ld_out, ld_in, r); + } else { + requantizeOutputProcessingAvx2( + out, inp, block, ld_out, ld_in, r); + } + } + } + +#endif // __aarch64__ + + } else { + assert(0 && "Not supported yet"); + } + return nextop_.template f(out, out, block, ld_out, ld_out); +} + +template < + bool FUSE_RELU, + QuantizationGranularity Q_GRAN, + typename outT, + typename inT, + typename nextOPType> +template +inline int ReQuantizeForFloat::f( + outT* out, + inT* inp, + const block_type_t& block, + int ld_out, + int ld_in) const { + static_assert( + std::is_same_v, "input data type is of not expected type"); + static_assert( + std::is_same_v, "output data type is of not expected type"); + int ncol_per_group = 
ncols_ / groups_; + assert( + block.col_size <= ncol_per_group && + "ReQuantizeOutput should be called at most 1 group at a time."); + if constexpr ( + instSet == inst_set_t::anyarch || !std::is_same_v) { + for (int i = block.row_start; i < block.row_start + block.row_size; ++i) { + for (int j = block.col_start; j < block.col_start + block.col_size; ++j) { + inT raw = inp[(i - block.row_start) * ld_in + j - block.col_start]; + if (Aq_zero_point_) { + raw -= Aq_zero_point_ * q_col_offsets_[j]; + } + int Bq_zero_point_idx = 0; + if constexpr (Q_GRAN == QuantizationGranularity::TENSOR) { + Bq_zero_point_idx = 0; + } else if constexpr (Q_GRAN == QuantizationGranularity::GROUP) { + int g = block.col_start / ncol_per_group; + Bq_zero_point_idx = g; + } else { + static_assert(Q_GRAN == QuantizationGranularity::OUT_CHANNEL); + Bq_zero_point_idx = j; + } + if (q_row_offsets_) { + raw -= q_row_offsets_[i - block.row_start] * + Bq_zero_point_[Bq_zero_point_idx]; + } + float res = raw * Aq_scale_ * Bq_scale_[Bq_zero_point_idx]; + if (bias_) { + res += bias_[j]; + } + out[i * ld_out + j] = res; + if constexpr (FUSE_RELU) { + out[i * ld_out + j] = std::max(0.0f, out[i * ld_out + j]); + } + } + } + +#if !defined(__aarch64__) + } else if constexpr ( + instSet == inst_set_t::avx2 || instSet == inst_set_t::avx512) { + bool b_symmetric = + (Q_GRAN == QuantizationGranularity::TENSOR && Bq_zero_point_[0] == 0) || + q_row_offsets_ == nullptr; + + requantizationForFloatParams_t r = { + Aq_zero_point_, + Bq_zero_point_, + Aq_scale_, + Bq_scale_, + q_row_offsets_, + q_col_offsets_, + bias_, + ncols_, + groups_}; + + if (Aq_zero_point_ == 0) { + if (b_symmetric) { + if (bias_ == nullptr) { + requantizeForFloatAvx2( + out, inp, block, ld_out, ld_in, r); + } else { + requantizeForFloatAvx2( + out, inp, block, ld_out, ld_in, r); + } + } else { + if (bias_ == nullptr) { + requantizeForFloatAvx2( + out, inp, block, ld_out, ld_in, r); + } else { + requantizeForFloatAvx2( + out, inp, block, 
ld_out, ld_in, r); + } + } + } else { + if (b_symmetric) { + if (bias_ == nullptr) { + requantizeForFloatAvx2( + out, inp, block, ld_out, ld_in, r); + } else { + requantizeForFloatAvx2( + out, inp, block, ld_out, ld_in, r); + } + } else { + if (bias_ == nullptr) { + requantizeForFloatAvx2( + out, inp, block, ld_out, ld_in, r); + } else { + requantizeForFloatAvx2( + out, inp, block, ld_out, ld_in, r); + } + } + } + +#endif // __aarch64__ + + } else { + assert(0 && "Not supported yet"); + } + + return nextop_.template f(out, out, block, ld_out, ld_out); +} + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fbgemm/PackingTraits-inl.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fbgemm/PackingTraits-inl.h new file mode 100644 index 0000000000000000000000000000000000000000..3d9f26ab2b30b805f9399a61bc5063ee58764d37 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fbgemm/PackingTraits-inl.h @@ -0,0 +1,541 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +/* + * This file configures the important cache blocking parameters and registers + * blocking parameters for the matrix multiplication loops inside FBGEMM. + * + * ROW_INTERLEAVE: the number of interleaved rows to use vpmaddubsw instructions + * for packing B matrix. For 32-bit accumulation, ROW_INTERLEAVE = 4; For 16-bit + * accumulation, ROW_INTERLEAVE = 2. + * + * VLEN: the vector length of one SIMD register. For avx2, VLEN = 256; For + * avx512, VLEN = 512. 
+ * + * NR: the register blocking parameters for N dimension. The total registers + * used in N dimension for C accumulations are NR * ROW_INTERLEAVE * 8 (int8) / + * VLEN. + * + * MR: the register blocking parameters for M dimension. The total number of + * registers used in M dimension for C accumulations is MR. This indicates the + * number of vpbroadcastw instructions for A. + * + * (MR) * (NR * ROW_INTERLEAVE * 8 (int8) / VLEN): the number of registers used + * for C accumulations. This number should be less than the maximum registers we + * can use for C accumulations (A max of 12 out of 16 ymm registers for avx2; a + * max of 28 out of 32 zmm registers for avx512 ). The remaining are used for A + * matrix loading, B matrix loading and as temp registers. C accumulation + * registers should be as large as possible to increase the register + * utilization. + * + * MCB: the cache blocking parameters for M dimension. MCB needs to be a + * multiple of MR. + * + * NCB: the cache blocking parameters for N dimension. NCB needs to be a + * multiple of NR. + * + * KCB: the cache blocking parameters for K dimension. KCB needs to be a + * multiple of ROW_INTERLEAVE. + */ + +/** + * @brief Packing parameter specialization for accumulation into 32-bit + * integers. + * + * This is picked when T is of int8 type (signed or unsigned) and instruction + * set is avx2 + */ +template +struct PackingTraits< + T, + std::int32_t, + inst_set_t::avx2, + std::enable_if_t::value>> { + static constexpr int MR{12}; ///< Register block for M dimension. + static constexpr int NR_MIN{8}; ///< Minimum register block for N dimension. + ///< 8 because 8*ROW_INTERLEAVE int8 elements + ///< completely fill a 256-bit wide vector. + static constexpr int NR{8}; ///< Register block for N dimension. + ///< NR = VLEN/8/ROW_INTERLEAVE = 256 / 8 / 4 = 8. + ///< Total registers used for N dimension: NCB/NR. + ///< Here we use 12 x 1 ymm register blocking for + ///< the registers used for accumulation C. 
+ + static constexpr int ROW_INTERLEAVE{ + 4}; ///< 4 rows are interleaved to use vpmaddubsw instruction for packing + ///< B matrix. + + static constexpr int MCB{ + 120}; ///< Cache block for M dimension (multiple of MR). + static constexpr int NCB{ + 8}; ///< Cache block for N dimension (multiple of NR). + static constexpr int KCB{512}; ///< Cache block for K dimension. + + static std::tuple getCacheBlockParams() { + return std::tuple(int(MCB), int(KCB), int(MR)); + } + static std::tuple getKernelParams() { + return std::tuple( + int(MCB), int(NCB), int(NR_MIN), int(NR)); + } + static std::tuple getMatrixPackAParams() { + return std::tuple(int(MCB), int(KCB), int(ROW_INTERLEAVE)); + } + static std::tuple getMatrixPackBParams() { + return std::tuple(int(KCB), int(NCB), int(ROW_INTERLEAVE)); + } +}; + +/** + * @brief Packing parameter specialization for accumulation into 16-bit + * integers. + * + * This is picked when T is of int8 type (signed or unsigned) and instruction + * set is avx2. + */ +template +struct PackingTraits< + T, + std::int16_t, + inst_set_t::avx2, + std::enable_if_t::value>> { + static constexpr int MR{3}; ///< Register block for M dimension. + static constexpr int NR_MIN{ + 16}; ///< Minimum register block for N dimension. + ///< 16 because 16*ROW_INTERLEAVE int8 elements + ///< completely fill a 256-bit wide vector. + + static constexpr int NR{ + 16}; ///< Register block for N dimension; + ///< NR = VLEN/8/ROW_INTERLEAVE = 256 / 8 / 2 = 16. + ///< Total registers used for N dimension: NCB/NR. + ///< Here we use 3 x 4 ymm register blocking for the + ///< registers used for accumulation C. + + static constexpr int ROW_INTERLEAVE{ + 2}; ///< 2 rows are interleaved to use vpmaddubsw instruction for packing + ///< B matrix. + + static constexpr int MCB{ + 60}; ///< Cache block for M dimension (multiple of MR). + static constexpr int NCB{ + 64}; ///< Cache block for N dimension (multiple of NR). 
+ static constexpr int KCB{256}; ///< Cache block for K dimension. + + static std::tuple getCacheBlockParams() { + return std::tuple(int(MCB), int(KCB), int(MR)); + } + static std::tuple getKernelParams() { + return std::tuple( + int(MCB), int(NCB), int(NR_MIN), int(NR)); + } + static std::tuple getMatrixPackAParams() { + return std::tuple(int(MCB), int(KCB), int(ROW_INTERLEAVE)); + } + static std::tuple getMatrixPackBParams() { + return std::tuple(int(KCB), int(NCB), int(ROW_INTERLEAVE)); + } +}; + +/** + * @brief Packing parameter specialization for float input and float + * accumulation. + * + * This is picked when template paramtere T is of float type and instruction + * set is avx2. + */ +template <> +struct PackingTraits { + static constexpr int MR{3}; ///< Register block for M dimension + static constexpr int NR{32}; ///< Register block for N dimension + + static constexpr int ROW_INTERLEAVE{1}; ///< No Row interleave. + + static constexpr int MCB{ + 24}; ///< Cache block for M dimension (multiple of MR) + static constexpr int NCB{ + 64}; ///< Cache block for N dimension (multiple of NR) + static constexpr int KCB{256}; ///< Cache block for K dimension + + static std::tuple getCacheBlockParams() { + return std::tuple(int(MCB), int(KCB), int(MR)); + } + static std::tuple getMatrixPackAParams() { + return std::tuple(int(MCB), int(KCB), int(ROW_INTERLEAVE)); + } + static std::tuple getMatrixPackBParams() { + return std::tuple(int(KCB), int(NCB), int(ROW_INTERLEAVE)); + } +}; + +/** + * @brief Packing parameter specialization for fp16 input and float + * accumulation. + * + * This is picked when template parameter T is of float16 type and instruction + * set is avx2 + */ +template <> +struct PackingTraits { + static constexpr int BCOL{8}; + static constexpr int ROW_INTERLEAVE{1}; +}; + +/** + * @brief Packing parameter specialization for accumulation into 32-bit + * integers. 
+ * + * This is picked when T is of int8 type (signed or unsigned) and instruction + * set is avx512. + */ +template +struct PackingTraits< + T, + std::int32_t, + inst_set_t::avx512, + std::enable_if_t::value>> { + static constexpr int MR{14}; ///< Register block for M dimension. + static constexpr int NR_MIN{ + 16}; ///< Minimum register block for N dimension. + ///< 16 because 16*ROW_INTERLEAVE int8 elements + ///< completely fill a 512-bit wide vector. + static constexpr int NR{ + 32}; ///< Register block for N dimension. + ///< Must be a multiple of 16 because 16*ROW_INTERLEAVE int8 elements + ///< completely fill a 512-bit wide vector. Total registers used for + ///< N dimension: NR*ROW_INTERLEAVE*8/VLEN. We use MR x + ///< NR*ROW_INTERLEAVE*8/VLEN zmm registers + ///< for C accumulations. + + static constexpr int ROW_INTERLEAVE{ + 4}; ///< 4 rows are interleaved to use vpmaddubsw instruction for packing + ///< B matrix. + + static constexpr int MCB{ + 56}; ///< Cache block for M dimension (multiple of MR). + static constexpr int NCB{ + 32}; ///< Cache block for N dimension (multiple of NR). + static constexpr int KCB{256}; ///< Cache block for K dimension. + + static std::tuple getCacheBlockParams() { + return std::tuple(int(MCB), int(KCB), int(MR)); + } + static std::tuple getKernelParams() { + return std::tuple( + int(MCB), int(NCB), int(NR_MIN), int(NR)); + } + static std::tuple getMatrixPackAParams() { + return std::tuple(int(MCB), int(KCB), int(ROW_INTERLEAVE)); + } + static std::tuple getMatrixPackBParams() { + return std::tuple(int(KCB), int(NCB), int(ROW_INTERLEAVE)); + } +}; + +/** + * @brief Packing parameter specialization for accumulation into 32-bit + * integers. + * + * This is picked when T is of int8 type (signed or unsigned) and instruction + * set is avx512_ymm. + */ +template +struct PackingTraits< + T, + std::int32_t, + inst_set_t::avx512_ymm, + std::enable_if_t::value>> { + static constexpr int MR{7}; ///< Register block for M dimension. 
+ static constexpr int NR_MIN{16}; ///< Minimum register block for N dimension. + ///< 8 because 8*ROW_INTERLEAVE int8 elements + ///< completely fill a 256-bit wide vector. + static constexpr int NR{ + 32}; ///< Register block for N dimension. + ///< NR = VLEN/8/ROW_INTERLEAVE = 256 / 8 / 4 = 8. + ///< Total registers used for N dimension: NCB/NR. + ///< Here we use 12 x 1 ymm register blocking for + ///< the registers used for accumulation C. + + static constexpr int ROW_INTERLEAVE{ + 4}; ///< 4 rows are interleaved to use vpmaddubsw instruction for packing + ///< B matrix. + + static constexpr int MCB{ + 56}; ///< Cache block for M dimension (multiple of MR). + static constexpr int NCB{ + 32}; ///< Cache block for N dimension (multiple of NR). + static constexpr int KCB{256}; ///< Cache block for K dimension. + + static std::tuple getCacheBlockParams() { + return std::tuple(int(MCB), int(KCB), int(MR)); + } + static std::tuple getKernelParams() { + return std::tuple( + int(MCB), int(NCB), int(NR_MIN), int(NR)); + } + static std::tuple getMatrixPackAParams() { + return std::tuple(int(MCB), int(KCB), int(ROW_INTERLEAVE)); + } + static std::tuple getMatrixPackBParams() { + return std::tuple(int(KCB), int(NCB), int(ROW_INTERLEAVE)); + } +}; + +/** + * @brief Packing parameter specialization for accumulation into 16-bit + * integers. + * + * This is picked when T is of int8 type (signed or unsigned) and instruction + * set is avx512. + */ +template +struct PackingTraits< + T, + std::int16_t, + inst_set_t::avx512, + std::enable_if_t::value>> { + static constexpr int MR{6}; ///< Register block for M dimension + static constexpr int NR_MIN{ + 32}; ///< Minimum register block for N dimension; + ///< 32 because 32*ROW_INTERLEAVE int8 elements + ///< completely fill a 512-bit wide vector. 
+ static constexpr int NR{ + 128}; ///< Register block for N dimension; + ///< Must be a multiple of 32 because 32*ROW_INTERLEAVE int8 + ///< elements completely fill a 512-bit wide vector. Total registers + ///< used for N dimension: NR*ROW_INTERLEAVE*8/VLEN. We use MR x + ///< NR*ROW_INTERLEAVE*8/VLEN zmm registers + ///< for C accumulations. + + static constexpr int ROW_INTERLEAVE{ + 2}; ///< 2 rows are interleaved to use vpmaddubsw instruction for packing + ///< B matrix. + + static constexpr int MCB{ + 60}; ///< Cache block for M dimension (multiple of MR). + static constexpr int NCB{ + 128}; ///< Cache block for N dimension (multiple of NR). + static constexpr int KCB{256}; ///< Cache block for K dimension. + + static std::tuple getCacheBlockParams() { + return std::tuple(int(MCB), int(KCB), int(MR)); + } + static std::tuple getKernelParams() { + return std::tuple( + int(MCB), int(NCB), int(NR_MIN), int(NR)); + } + static std::tuple getMatrixPackAParams() { + return std::tuple(int(MCB), int(KCB), int(ROW_INTERLEAVE)); + } + static std::tuple getMatrixPackBParams() { + return std::tuple(int(KCB), int(NCB), int(ROW_INTERLEAVE)); + } +}; + +/** + * @brief Packing parameter specialization for accumulation into 16-bit + * integers. + * + * This is picked when T is of int8 type (signed or unsigned) and instruction + * set is avx512_ymm. + */ +template +struct PackingTraits< + T, + std::int16_t, + inst_set_t::avx512_ymm, + std::enable_if_t::value>> { + static constexpr int MR{6}; ///< Register block for M dimension. + static constexpr int NR_MIN{ + 16}; ///< Minimum register block for N dimension. + ///< 16 because 16*ROW_INTERLEAVE int8 elements + ///< completely fill a 256-bit wide vector. + + static constexpr int NR{ + 16}; ///< Register block for N dimension; + ///< NR = VLEN/8/ROW_INTERLEAVE = 256 / 8 / 2 = 16. + ///< Total registers used for N dimension: NCB/NR. + ///< Here we use 3 x 4 ymm register blocking for the + ///< registers used for accumulation C. 
+ + static constexpr int ROW_INTERLEAVE{ + 2}; ///< 2 rows are interleaved to use vpmaddubsw instruction for packing + ///< B matrix. + + static constexpr int MCB{ + 60}; ///< Cache block for M dimension (multiple of MR). + static constexpr int NCB{ + 64}; ///< Cache block for N dimension (multiple of NR). + static constexpr int KCB{256}; ///< Cache block for K dimension. + + static std::tuple getCacheBlockParams() { + return std::tuple(int(MCB), int(KCB), int(MR)); + } + static std::tuple getKernelParams() { + return std::tuple( + int(MCB), int(NCB), int(NR_MIN), int(NR)); + } + static std::tuple getMatrixPackAParams() { + return std::tuple(int(MCB), int(KCB), int(ROW_INTERLEAVE)); + } + static std::tuple getMatrixPackBParams() { + return std::tuple(int(KCB), int(NCB), int(ROW_INTERLEAVE)); + } +}; + +/** + * @brief Helper struct to type specialize for int16_t and int32_t together. + */ +template +struct is_16or32bit { + static constexpr bool value = + std::is_same_v || std::is_same_v; +}; + +/** + * @brief Packing parameter specialization for accumulation into 32-bit/16-bit + * integers. + * + * Since there is no int16_t accumulation for AVX512 VNNI, we redirect int16_t + * to int32_t accumulation and use the same blocking parameters as int32_t. + * + * This is picked when T is of int8 type (signed or unsigned) and instruction + * set is avx512_vnni. + */ +template +struct PackingTraits< + T, + accT, + inst_set_t::avx512_vnni, + std::enable_if_t::value && is_16or32bit::value>> { + static constexpr int MR{8}; ///< Register block for M dimension. + static constexpr int NR_MIN{ + 16}; ///< Minimum register block for N dimension. + ///< 16 because 16*ROW_INTERLEAVE int8 elements + ///< completely fill a 512-bit wide vector. + static constexpr int NR{ + 48}; ///< Register block for N dimension. + ///< Must be a multiple of 16 because 16*ROW_INTERLEAVE int8 elements + ///< completely fill a 512-bit wide vector. 
Total registers used for + ///< N dimension: NR*ROW_INTERLEAVE*8/VLEN. We use MR x + ///< NR*ROW_INTERLEAVE*8/VLEN zmm registers + ///< for C accumulations. + + static constexpr int ROW_INTERLEAVE{ + 4}; ///< 4 rows are interleaved to use vpmaddubsw instruction for packing + ///< B matrix. + + static constexpr int MCB{ + 384}; ///< Cache block for M dimension (multiple of MR). + static constexpr int NCB{ + 48}; ///< Cache block for N dimension (multiple of NR). + static constexpr int KCB{512}; ///< Cache block for K dimension. + + static std::tuple getCacheBlockParams() { + return std::tuple(int(MCB), int(KCB), int(MR)); + } + static std::tuple getKernelParams() { + return std::tuple( + int(MCB), int(NCB), int(NR_MIN), int(NR)); + } + static std::tuple getMatrixPackAParams() { + return std::tuple(int(MCB), int(KCB), int(ROW_INTERLEAVE)); + } + static std::tuple getMatrixPackBParams() { + return std::tuple(int(KCB), int(NCB), int(ROW_INTERLEAVE)); + } +}; + +/** + * @brief Packing parameter specialization for accumulation into 32-bit/16-bit + * integers. + * + * Since there is no int16_t accumulation for AVX512 VNNI, we redirect int16_t + * to int32_t accumulation and use the same blocking parameters as int32_t. + * + * This is picked when T is of int8 type (signed or unsigned) and instruction + * set is avx512_vnni_ymm. + */ +template +struct PackingTraits< + T, + accT, + inst_set_t::avx512_vnni_ymm, + std::enable_if_t::value && is_16or32bit::value>> { + static constexpr int MR{4}; ///< Register block for M dimension. + static constexpr int NR_MIN{ + 16}; ///< Minimum register block for N dimension. + ///< 16 because 16*ROW_INTERLEAVE int8 elements + ///< completely fill a 512-bit wide vector. + static constexpr int NR{ + 48}; ///< Register block for N dimension. + ///< Must be a multiple of 16 because 16*ROW_INTERLEAVE int8 elements + ///< completely fill a 512-bit wide vector. Total registers used for + ///< N dimension: NR*ROW_INTERLEAVE*8/VLEN. 
We use MR x + ///< NR*ROW_INTERLEAVE*8/VLEN zmm registers + ///< for C accumulations. + + static constexpr int ROW_INTERLEAVE{ + 4}; ///< 4 rows are interleaved to use vpmaddubsw instruction for packing + ///< B matrix. + + static constexpr int MCB{ + 384}; ///< Cache block for M dimension (multiple of MR). + static constexpr int NCB{ + 48}; ///< Cache block for N dimension (multiple of NR). + static constexpr int KCB{512}; ///< Cache block for K dimension. + + static std::tuple getCacheBlockParams() { + return std::tuple(int(MCB), int(KCB), int(MR)); + } + static std::tuple getKernelParams() { + return std::tuple( + int(MCB), int(NCB), int(NR_MIN), int(NR)); + } + static std::tuple getMatrixPackAParams() { + return std::tuple(int(MCB), int(KCB), int(ROW_INTERLEAVE)); + } + static std::tuple getMatrixPackBParams() { + return std::tuple(int(KCB), int(NCB), int(ROW_INTERLEAVE)); + } +}; + +/** + * @brief Packing parameter specialization for I64 GEMM + * integers. + * + * This is picked when T is of int64 type and instruction + * set is avx512. + */ +template <> +struct PackingTraits { + static constexpr int MR{2}; ///< Register block for M dimension. + static constexpr int NR_MIN{8}; ///< Minimum register block for N dimension. + ///< 8 because 8 int64 elements + ///< completely fill a 512-bit wide vector. + static constexpr int NR{ + 32}; ///< Register block for N dimension. + ///< Must be a multiple of 16 because 16*ROW_INTERLEAVE int8 elements + ///< completely fill a 512-bit wide vector. Total registers used for + ///< N dimension: NR*8/VLEN. We use MR x + ///< NR*8/VLEN zmm registers + ///< for C accumulations. + + static constexpr int MCB{ + 16}; ///< Cache block for M dimension (multiple of MR). + static constexpr int NCB{ + 64}; ///< Cache block for N dimension (multiple of NR). + static constexpr int KCB{8}; ///< Cache block for K dimension. 
+}; + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fbgemm/QuantUtils.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fbgemm/QuantUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..90e95beb1ed4d8b32ff0ebe84ac37d6686ca1b60 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fbgemm/QuantUtils.h @@ -0,0 +1,397 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include "./FbgemmBuild.h" // @manual +#include "./QuantUtilsAvx2.h" // @manual +#include "./QuantUtilsAvx512.h" // @manual +#include "./QuantUtilsNeon.h" // @manual +#include "./Types.h" // @manual +#include "./Utils.h" // @manual + +#include +#include +#include +#include +#include + +/// @defgroup fbgemm-quant-utils-generic Quantization Utilities (Generic) +/// + +namespace fbgemm { + +FBGEMM_API TensorQuantizationParams ChooseQuantizationParams( + float min, + float max, + std::int32_t qmin, + std::int32_t qmax, + bool preserve_sparsity = false, + bool force_scale_power_of_two = false); + +FBGEMM_API void ChooseRequantizationMultiplier( + float real_multiplier, + std::int32_t* quantized_multiplier, + int* right_shift, + int requantization_multiplier_precision = 32); + +//////////////////////////////////////////////////////////////////////////////// +// Utility functions + +// Clamp src in T1 to the desired precision and convert it to T2 +// TODO: T26263653 fix signed-integer-overflow undefined behavior +template 
+NO_SANITIZE("signed-integer-overflow") +T2 clamp(T1 src, int precision, bool is_signed = false) { + std::int32_t min = is_signed ? -(1LL << (precision - 1)) : 0; + std::int32_t max = + is_signed ? ((1LL << (precision - 1)) - 1) : (1LL << precision) - 1; + + // Make sure T1 and T2 can represent the precision + assert(min >= std::numeric_limits::lowest()); + assert(min >= std::numeric_limits::lowest()); + assert(max <= std::numeric_limits::max()); + assert(max <= std::numeric_limits::max()); + + return std::min(std::max(src, min), max); +} + +/// Quantize src using zero_point and scale, clamp to the specified precision, +/// and convert it to type T +template +T Quantize( + float src, + std::int32_t zero_point, + float scale, + int result_precision, + bool result_is_signed = std::is_signed_v) { + // Note: We want to multiply with src with inv_scale instead of + // dividing src by scale. The same is done in vector code and + // at other places. + // + // Example: + // With scale = 0.00214854861f, zero_point = 0 and src = 0.273939937f + // transformed_val is 127.5 for src * inv_scale while + // transformed_val is 127.499992 for src / scale. + // Eventually 127.5 gets rounded to 128 while 127.499992 gets rounded to 127. + float inv_scale = 1.0f / scale; + + float transformed_val = src * inv_scale; + // nearbyint here performs round-to-nearest-ties-to-even with + // default rounding mode. + // For example, nearbyint(1.4) is 1.0, nearbyint(1.5) is 2.0 + // and nearbyint(2.5) is 2.0 + // Adding zero_point before or after rounding can make a difference + // in exactly halfway cases. + if constexpr (LEGACY) { + transformed_val = std::nearbyint(zero_point + transformed_val); + } else { + transformed_val = zero_point + std::nearbyint(transformed_val); + } + // Please note the use of double. Unlike float, a double can represent + // all int32 values exactly. Using a float results in a float value > + // INT32_MAX conversion to int32 in clamp function and hence an UBSAN error. 
+ return clamp(transformed_val, result_precision, result_is_signed); +} + +template +T Quantize(float src, const TensorQuantizationParams& qparams) { + return Quantize( + src, qparams.zero_point, qparams.scale, qparams.precision); +} + +template +FBGEMM_API void Quantize( + const float* src, + T* dst, + std::int64_t len, + const TensorQuantizationParams& qparams, + int thread_id = 0, + int num_threads = 1); + +/// @ingroup fbgemm-quant-utils-generic +/// +/// Quantize floating point data in `src` to type `T`. +/// +/// @tparam T output quantized data type (`int8_t`, `uint8_t`, and `int32_t` are +/// supported) +/// +/// @tparam LAYOUT layout of input tensor in `src`. (`KCX` and `KXC` are +/// supported) +/// `KCX` corresponds to `KCRS` or `KCTRS` (for weight tensors with time +/// dimension) +/// `KXC` corresponds to `KRSC` or `KTRSC` (for weight tensors with time +/// dimension) +/// +/// @param K Output channels for weight tensors +/// @param C Number of channels +/// @param X `R*S` or `T*R*S` +/// @param G Groups (if `G == C` the function performs channelwise +/// quantization; +/// if `1 < G < C` the function performs groupwise +/// quantization; if `G == 1` the function performs per tensor +/// quantization;) +/// @param scales floating point scales. Size should be equal `G` +/// @param zero_points zero points (should be reprsentable in type `T`). 
+/// Size should be equal `G` +template +FBGEMM_API void QuantizeGroupwise( + const float* src, + int K, + int C, + int X, + int G, + const float* scales, + const std::int32_t* zero_points, + T* dst); + +template +float Dequantize(T src, const TensorQuantizationParams& qparams) { + return qparams.scale * (src - qparams.zero_point); +} + +template +void Dequantize( + const T* src, + float* dst, + std::int64_t len, + const TensorQuantizationParams& qparams, + int thread_id = 0, + int num_threads = 1) { + int64_t i_begin = 0, i_end = 0; + fbgemmPartition1D(thread_id, num_threads, len, i_begin, i_end); + for (int64_t i = i_begin; i < i_end; i++) { + dst[i] = Dequantize(src[i], qparams); + } +} + +template +float FusedQuantizeDequantize( + float src, + const TensorQuantizationParams& qparams) { + T q = Quantize( + src, qparams.zero_point, qparams.scale, qparams.precision); + return Dequantize(q, qparams); +} + +/// @ingroup fbgemm-quant-utils-generic +/// +/// Fused integer quantization dequantization kernel to accelerate +/// quantization-aware training. Quantize `fp32` values in src to `(u)int8` +/// using the provided qparams, and dequantize quantized integer values back +/// into `fp32`. 
+template +FBGEMM_API void FusedQuantizeDequantize( + const float* src, + float* dst, + std::int64_t len, + const TensorQuantizationParams& qparams, + int thread_id = 0, + int num_threads = 1, + float noise_ratio = 0.0f); + +//////////////////////////////////////////////////////////////////////////////// +// Requantization (pure fixed-point) + +FBGEMM_API std::int64_t +SaturatingRoundingMulWithShift(std::int32_t a, std::int32_t b, int right_shift); + +template +T Requantize( + std::int32_t src, // int32 input before requantization + std::int32_t zero_point, + std::int32_t multiplier, + int right_shift, + int result_precision, + bool result_is_signed = false) { + std::int64_t quantized_down = + zero_point + SaturatingRoundingMulWithShift(src, multiplier, right_shift); + return clamp( + quantized_down, result_precision, result_is_signed); +} + +template +T RequantizeFixedPoint( + std::int32_t src, // int32 input before requantization + const RequantizationParams& params) { + return Requantize( + src, + params.target_qparams.zero_point, + params.multiplier, + params.right_shift, + params.target_qparams.precision); +} + +template +FBGEMM_API void RequantizeFixedPoint( + const std::int32_t* src, + T* dst, + std::int64_t len, + const RequantizationParams& params, + int thread_id = 0, + int num_threads = 1); + +//////////////////////////////////////////////////////////////////////////////// +// Requantization (with floats) + +template +T Requantize( + std::int32_t src, // int32 input before requantization + std::int32_t zero_point, + float multiplier, + int result_precision, + bool result_is_signed = false) { + long quantized_down = zero_point + std::lrintf(src * multiplier); + return clamp(quantized_down, result_precision, result_is_signed); +} + +template +T Requantize( + std::int32_t src, // int32 input before requantization + const RequantizationParams& params) { + return Requantize( + src, + params.target_qparams.zero_point, + params.real_multiplier, + 
params.target_qparams.precision); +} + +template +FBGEMM_API void Requantize( + const std::int32_t* src, + T* dst, + std::int64_t len, + const RequantizationParams& params, + int thread_id = 0, + int num_threads = 1); + +/** + * @ingroup fbgemm-quant-utils-generic + * + * Convert float (fp32 or fp16) inputs to rowwise quantized outputs. + * bitrate specifies the number of bits in quantized output. + * Scale and Bias are in fp16. Each row's Scale and Bias are stored in + * the row itself (fused) at the end. + * + * @param bit_rate can be 2, 4, or 8 + */ +template +FBGEMM_API void FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf( + int bit_rate, + const InputType* input, + size_t input_rows, + int input_columns, + std::uint8_t* output, + const InputType* rowwise_min_max = nullptr); + +/** + * Convert fused rowwise quantized inputs to float (fp32 or fp16). + * bitrate specifies the number of bits in quantized input. + * Scale and Bias are in fp16. Each row's Scale and Bias are stored in + * the row itself (fused) at the end. + * + * @param bit_rate can be 2, 4, or 8 + */ +template +FBGEMM_API void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalf( + int bit_rate, + const uint8_t* input, + size_t input_rows, + int input_columns, + OutputType* output, + bool scale_bias_last = true); + +/** + * Convert float or half inputs to rowwise quantized (8-bit) outputs. + * Scale and Bias are in float. Each row's Scale and Bias are stored in + * the row itself (fused) at the end. + * + * This version intentionally supports only 8-bit because we want to discourage + * the usage of float scale and bias with 2 and 4 bit cases as that diminishes + * the overall memory savings. + */ +template +FBGEMM_API void FloatOrHalfToFused8BitRowwiseQuantizedSBFloat( + const InputType* input, + size_t input_rows, + int input_columns, + std::uint8_t* output, + const InputType* rowwise_min_max = nullptr); + +/** + * Convert fused rowwise quantized (8-bit) inputs to float or half outputs. 
+ * Scale and Bias are in float. Each row's Scale and Bias are stored in + * the row itself (fused) at the end. + * + * This version intentionally supports only 8-bit because + * the corresponding quantize version only supports 8-bit. + */ +template +FBGEMM_API void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf( + const uint8_t* input, + size_t input_rows, + int input_columns, + OutputType* output, + const bool scale_bias_last = true, + const bool quant_padding_float_type = true); + +/** + * Same as ToFusedNBitRowwiseQuantizedSBHalf but unoptimized. + * This should not be called directly except in testing. + */ +template +FBGEMM_API void FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfRef( + int bit_rate, + const InputType* input, + size_t input_rows, + int input_columns, + std::uint8_t* output); + +/** + * Same as FloatOrHalfToFused8BitRowwiseQuantizedSBFloat but unoptimized. + * This should not be called directly except in testing. + */ +template +FBGEMM_API void FloatOrHalfToFused8BitRowwiseQuantizedSBFloatRef( + const InputType* input, + size_t input_rows, + int input_columns, + std::uint8_t* output); + +/** + * Same as FusedNBitRowwiseQuantizedSBHalfToFloat but unoptimized. + * This should not be called directly except in testing. + */ +template +FBGEMM_API void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef( + int bit_rate, + const uint8_t* input, + size_t input_rows, + int input_columns, + OutputType* output, + bool scale_bias_last = true); + +/** + * Same as Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf but unoptimized. + * This should not be called directly except in testing. 
+ */ +template +FBGEMM_API void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfRef( + const uint8_t* input, + size_t input_rows, + int input_columns, + OutputType* output, + const bool scale_bias_last = true, + const bool quant_padding_float_type = true); + +} // namespace fbgemm + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fbgemm/QuantUtilsAvx2.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fbgemm/QuantUtilsAvx2.h new file mode 100644 index 0000000000000000000000000000000000000000..ec985aeba579fdb4bec0cffedc9115fc194ee861 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fbgemm/QuantUtilsAvx2.h @@ -0,0 +1,192 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include "./FbgemmBuild.h" // @manual +#include "./UtilsAvx2.h" // @manual + +/// @defgroup fbgemm-quant-utils-avx2 Quantization Utilities (AVX2) +/// + +namespace fbgemm { + +/// Number of columns in the rowwise min/max buffer passed to the quantization +/// function(s) +constexpr int kRowwiseMinMaxNumCols = 2; + +/// Struct from `gemmlowp` +/// +/// A structure to hold quantization parameters `scale` and `zero_point`. 
+/// The meaning of these values is as the constants in the quantization equation +/// +/// `real_value = scale * (quantized_value - zero_point)` +/// +/// In other words, 'zero_point' is the quantized value that corresponds +/// to the real value 0, and 'scale' is the difference of real values +/// corresponding to consecutive quantized values. +struct FBGEMM_API TensorQuantizationParams { + float scale; + std::int32_t zero_point; + int precision; + float Min() const; + float Max() const; +}; + +/// Parameters when we scale from int32 intermediate matrix multiplication +/// results to 8-bit integers +struct FBGEMM_API RequantizationParams { + /// For floating-point requantization + float real_multiplier; + + /// For fixed-point requantization + std::int32_t multiplier; + int right_shift; + + TensorQuantizationParams target_qparams; +}; + +/// @ingroup fbgemm-quant-utils-avx2 +/// +/// @brief Find the min and max value in a float matrix. +void FBGEMM_API FindMinMax(const float* m, float* min, float* max, int64_t len); + +#if !defined(__aarch64__) + +//////////////////////////////////////////////////////////////////////////////// +// Utility functions +//////////////////////////////////////////////////////////////////////////////// + +template +void QuantizeAvx2( + const float* src, + T* dst, + int64_t len, + const TensorQuantizationParams& qparams); + +template +void FusedQuantizeDequantizeAvx2( + const float* src, + float* dst, + int len, + const TensorQuantizationParams& qparams, + float noise_ratio = 0.0f); + +/// @ingroup fbgemm-quant-utils-avx2 +/// +/// Random number generator in [0, 9] based on +/// this paper. 
+uint32_t FBGEMM_API Xor128(); + +void RequantizeFixedPointAvx2( + const std::int32_t* src, + std::uint8_t* dst, + int len, + const RequantizationParams& params); + +void RequantizeAvx2( + const std::int32_t* src, + std::uint8_t* dst, + int len, + const RequantizationParams& params); + +#endif // !defined(__aarch64__) + +/// @ingroup fbgemm-quant-utils-avx2 +/// +/// Requantize with avx2 and bias is fused. +template < + bool A_SYMMETRIC, + bool B_SYMMETRIC, + QuantizationGranularity Q_GRAN, + bool HAS_BIAS, + bool FUSE_RELU, + typename BIAS_TYPE = std::int32_t, + bool DIRECT = false> +FBGEMM_API void requantizeOutputProcessingAvx2( + std::uint8_t* out, + const std::int32_t* inp, + const block_type_t& block, + int ld_out, + int ld_in, + const requantizationParams_t& r); + +template < + bool A_SYMMETRIC, + bool B_SYMMETRIC, + QuantizationGranularity Q_GRAN, + bool HAS_BIAS, + bool FUSE_RELU, + int C_PER_G, + typename BIAS_TYPE = std::int32_t> +FBGEMM_API void requantizeOutputProcessingGConvAvx2( + std::uint8_t* out, + const std::int32_t* inp, + const block_type_t& block, + int ld_out, + int ld_in, + const requantizationParams_t& r); + +template < + bool A_SYMMETRIC, + bool B_SYMMETRIC, + QuantizationGranularity Q_GRAN, + bool HAS_BIAS, + bool FUSE_RELU> +FBGEMM_API void requantizeForFloatAvx2( + float* out, + const std::int32_t* inp, + const block_type_t& block, + int ld_out, + int ld_in, + const requantizationForFloatParams_t& r); + +#if !defined(__aarch64__) + +template +void FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfAvx2( + const InputType* input, + size_t input_rows, + int input_columns, + std::uint8_t* output, + const InputType* rowwise_min_max = nullptr); + +template +void FloatOrHalfToFused8BitRowwiseQuantizedSBFloatAvx2( + const InputType* input, + size_t input_rows, + int input_columns, + std::uint8_t* output, + const InputType* rowwise_min_max = nullptr); + +template +void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfAvx2( + const std::uint8_t* input, + 
size_t input_rows, + int input_columns, + OutputType* output); + +template < + typename OutputType, + bool scale_bias_last = true, + bool quant_padding_float_type = true> +void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfAvx2( + const std::uint8_t* input, + size_t input_rows, + int input_columns, + OutputType* output); + +#endif // !defined(__aarch64__) + +} // namespace fbgemm + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fbgemm/QuantUtilsAvx512.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fbgemm/QuantUtilsAvx512.h new file mode 100644 index 0000000000000000000000000000000000000000..d8330f4808c1b258e94d58c68cd14b8060b31b1c --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fbgemm/QuantUtilsAvx512.h @@ -0,0 +1,55 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include "Types.h" +#if !defined(__aarch64__) + +#include +#include "./FbgemmBuild.h" // @manual +#include "./UtilsAvx2.h" // @manual + +/// @defgroup fbgemm-quant-utils-avx512 Quantization Utilities (AVX512) +/// + +namespace fbgemm { + +/// @ingroup fbgemm-quant-utils-avx512 +/// +/// Requantize with AVX512. 
+template < + bool A_SYMMETRIC, + bool B_SYMMETRIC, + QuantizationGranularity Q_GRAN, + bool HAS_BIAS, + bool FUSE_RELU, + int C_PER_G, + typename BIAS_TYPE = std::int32_t> +FBGEMM_API void requantizeOutputProcessingGConvAvx512( + std::uint8_t* out, + const std::int32_t* inp, + const block_type_t& block, + int ld_out, + int ld_in, + const requantizationParams_t& r); + +template +void Fused8BitRowwiseQuantizedSBFloatToBfloat16Avx512( + const std::uint8_t* input, + size_t input_rows, + int input_columns, + bfloat16* output); +} // namespace fbgemm + +#endif + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fbgemm/QuantUtilsNeon.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fbgemm/QuantUtilsNeon.h new file mode 100644 index 0000000000000000000000000000000000000000..32e571213b6c83d9dd7dd65b4b5930b9a3974224 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fbgemm/QuantUtilsNeon.h @@ -0,0 +1,53 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#ifdef __aarch64__ + +#include +#include "./FbgemmBuild.h" // @manual + +/// @defgroup fbgemm-quant-utils-avx2 Quantization Utilities (AVX2) +/// + +namespace fbgemm { + +//////////////////////////////////////////////////////////////////////////////// +// Utility functions +//////////////////////////////////////////////////////////////////////////////// + +template +void FloatOrHalfToFused8BitRowwiseQuantizedSBFloatNeon( + const InputType* input, + size_t input_rows, + int input_columns, + uint8_t* output); + +template +void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfNeon( + const std::uint8_t* input, + size_t input_rows, + int input_columns, + OutputType* output); + +template +void FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfNeon( + const InputType* input, + size_t input_rows, + int input_columns, + std::uint8_t* output); + +} // namespace fbgemm + +#endif // __aarch64__ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fbgemm/SimdUtils.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fbgemm/SimdUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..6a1f4aca84462476cec8e1aee731c6f691765881 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fbgemm/SimdUtils.h @@ -0,0 +1,118 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include "./Utils.h" // @manual + +#include // @manual +#include // @manual + +namespace fbgemm { + +#if ASMJIT_LIBRARY_VERSION >= ASMJIT_LIBRARY_MAKE_VERSION(1, 17, 0) +//! 128-bit XMM register (SSE+). +class Xmm : public asmjit::x86::Vec { + public: + using Vec::Vec; + using Vec::operator=; + Xmm(uint32_t regId) : Vec(asmjit::x86::Vec::make_xmm(regId)) {} + //! Casts this register to a register that has half the size (XMM). + ASMJIT_INLINE_NODEBUG Xmm half() const noexcept { + return Xmm(id()); + } +}; + +//! 256-bit YMM register (AVX+). +class Ymm : public asmjit::x86::Vec { + public: + using Vec::Vec; + using Vec::operator=; + Ymm(uint32_t regId) : Vec(asmjit::x86::Vec::make_ymm(regId)) {} + //! Casts this register to a register that has half the size (XMM). + ASMJIT_INLINE_NODEBUG Xmm half() const noexcept { + return Xmm(id()); + } +}; + +//! 512-bit ZMM register (AVX512+). +class Zmm : public asmjit::x86::Vec { + public: + using Vec::Vec; + using Vec::operator=; + Zmm(uint32_t regId) : Vec(asmjit::x86::Vec::make_zmm(regId)) {} + //! Casts this register to a register that has half the size (YMM). 
+ ASMJIT_INLINE_NODEBUG Ymm half() const noexcept { + return Ymm(id()); + } +}; +#else +using Xmm = asmjit::x86::Xmm; +using Ymm = asmjit::x86::Ymm; +using Zmm = asmjit::x86::Zmm; +#endif + +/** + * @brief Some commonly used variables for different instruction sets + */ +template +struct simd_info; + +template <> +struct simd_info { + static constexpr int WIDTH_BITS = 256; + static constexpr int WIDTH_BYTES = 32; + static constexpr int WIDTH_32BIT_ELEMS = 8; + static constexpr int NUM_VEC_REGS = 16; + + using vec_reg_t = Ymm; +}; + +template <> +struct simd_info { + // Implementation is unrolled to match params used on avx2 + static constexpr int WIDTH_BITS = 256; + static constexpr int WIDTH_BYTES = 32; + static constexpr int WIDTH_32BIT_ELEMS = 8; + static constexpr int NUM_VEC_REGS = 32; +}; + +template <> +struct simd_info { + static constexpr int WIDTH_BITS = 512; + static constexpr int WIDTH_BYTES = 64; + static constexpr int WIDTH_32BIT_ELEMS = 16; + static constexpr int NUM_VEC_REGS = 32; + + using vec_reg_t = Zmm; +}; + +template <> +struct simd_info + : public simd_info {}; + +template <> +struct simd_info { + static constexpr int WIDTH_BITS = 256; + static constexpr int WIDTH_BYTES = 32; + static constexpr int WIDTH_32BIT_ELEMS = 8; + static constexpr int NUM_VEC_REGS = 32; + + using vec_reg_t = Ymm; +}; + +template <> +struct simd_info + : public simd_info {}; + +} // namespace fbgemm + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." 
+#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fbgemm/Types.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fbgemm/Types.h new file mode 100644 index 0000000000000000000000000000000000000000..615ea0d87471ee752b9bc76873736ad1a36b0ef4 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fbgemm/Types.h @@ -0,0 +1,31 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +namespace fbgemm { + +using float16 = std::uint16_t; +using bfloat16 = std::uint16_t; + +inline int64_t round_up(int64_t val, int64_t unit) { + return (val + unit - 1) / unit * unit; +} + +inline int64_t div_up(int64_t val, int64_t unit) { + return (val + unit - 1) / unit; +} + +} // namespace fbgemm + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fbgemm/Utils.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fbgemm/Utils.h new file mode 100644 index 0000000000000000000000000000000000000000..dc0ef013d11fb4a65fbff6337a5eeac36fbadbc3 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fbgemm/Utils.h @@ -0,0 +1,505 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. 
+ * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include "./FbgemmBuild.h" // @manual +#include "./UtilsAvx2.h" // @manual + +#include +#include +#include +#include +#include +#include +#include +#include + +#ifndef HAVE_SVE +#if defined(__aarch64__) && __ARM_FEATURE_SVE && \ + __has_include() +#define HAVE_SVE 1 +#else +#define HAVE_SVE 0 +#endif +#endif + +namespace fbgemm { + +/** + * @brief Helper struct to type specialize for uint8 and int8 together. + */ +template +struct is_8bit { + static constexpr bool value = + std::is_same_v || std::is_same_v; +}; + +/** + * @brief Typed enum to specify matrix operations. + */ +enum class matrix_op_t { NoTranspose, Transpose }; + +/** + * @brief Typed enum for supported instruction sets. + */ +enum class inst_set_t { + anyarch, + avx2, + avx512, + avx512_ymm, + avx512_vnni, + avx512_vnni_ymm, + sve +}; + +/** + * @brief Typed enum for optimized paths for convolutions + */ +enum class optimized_conv_t { + depthwise, + groupwise, + pointwise, + fastpath1d, + im2col, + directconv +}; + +/** + * @brief Typed enum for implementation type. + * + * ref is reference and opt is optimized. + */ +enum class impl_type_t { ref, opt }; + +/** + * @brief Typed enum to specify data layout. + * KCX can be KCRS format or KCTRS format (e.g., for 3-D convolutions) + * KXC can be KRSC format or KTRSC format (e.g., for 3-D convolutions) + */ +enum class FBGEMM_ENUM_CLASS_API layout_t { KCX, KXC }; + +/** + * @brief A function to compare data in two buffers for closeness/equality. + */ +template +FBGEMM_API int compare_buffers( + const T* ref, + const T* test, + int m, + int n, + int ld, + size_t max_mismatches_to_report, + float atol = 1e-3); + +/** + * @brief Print the matrix. + * @param op Transpose type of the matrix. + * @param R The height of the matrix. + * @param C The width of the matrix. 
+ * @param ld The leading dimension of the matrix. + * @param name The prefix string before printing the matrix. + */ +template +void printMatrix( + matrix_op_t op, + const T* inp, + size_t R, + size_t C, + size_t ld, + const std::string& name) { + // R: number of rows in op(inp) + // C: number of cols in op(inp) + // ld: leading dimension in inp + std::cout << name << ":" << "[" << R << ", " << C << "]" << '\n'; + bool tr = (op == matrix_op_t::Transpose); + for (size_t r = 0; r < R; ++r) { + for (size_t c = 0; c < C; ++c) { + T res = tr ? inp[c * ld + r] : inp[r * ld + c]; + if constexpr (std::is_integral_v) { + std::cout << std::setw(5) << static_cast(res) << " "; + } else { + std::cout << std::setw(5) << res << " "; + } + } + std::cout << '\n'; + } +} + +/** + * @brief Transpose a matrix. + * + * @param M the number of rows of input matrix + * @param N the number of columns of input matrix + */ +template +FBGEMM_API void transpose_simd( + int64_t M, + int64_t N, + const T* src, + int64_t ld_src, + T* dst, + int64_t ld_dst); + +/** + * @brief Explicitly set instruction set to be used + */ +FBGEMM_API void fbgemmForceIsa(inst_set_t /*isa*/); + +/** + * @brief Enable AVX512-256 path for Intel(r) Xeon(r) D servers + */ +FBGEMM_API void fbgemmEnableAvx512Ymm(bool /*flag*/); + +/** + * @brief Are we running on a Xeon-D cpu? + */ +FBGEMM_API bool fbgemmIsIntelXeonD(); + +/** + * @brief Are we running on a AVX512 supported cpu? + */ +FBGEMM_API bool fbgemmHasAvx512Support(); + +/** + * @brief Are we running on a AVX2 supported cpu? + */ +FBGEMM_API bool fbgemmHasAvx2Support(); + +/** + * @brief Are we running on a AVX512_VNNI supported cpu? + */ +FBGEMM_API bool fbgemmHasAvx512VnniSupport(); + +/** + * @brief Are we running on a AVX512_BF16 supported cpu? + */ +FBGEMM_API bool fbgemmHasAvx512Bf16Support(); + +/** + * @brief Are we running on a ARM Neon supported cpu? 
+ */ +FBGEMM_API bool fbgemmHasArmNeonSupport(); + +/** + * @brief Are we running on a ARM SVE supported cpu? + */ +FBGEMM_API bool fbgemmHasArmSveSupport(); + +/** + * @brief Are we running on a ARM SVE2 supported cpu? + */ +FBGEMM_API bool fbgemmHasArmSve2Support(); + +/** + * @brief Retrieve current CPU instruction set + */ +FBGEMM_API inst_set_t fbgemmInstructionSet(); + +/** + * @brief Is ISA is wide vector ZMM + */ +FBGEMM_API bool isZmm(inst_set_t /*isa*/); + +/** + * @brief Is ISA is wide vector ZMM + */ +FBGEMM_API bool isYmm(inst_set_t /*isa*/); + +/** + * @brief Helper struct to enable autotuning of FBGEMM packing and kernels. + * + * This structure is optional. If not used, the default values for these + * parameters are picked up from PackingTraits-inl.h. Please see this + * file for details on these parameters. + */ +struct FBGEMM_API BlockingFactors { + int MR; + int NR; + int NR_MIN; + int ROW_INTERLEAVE; + int MCB; + int KCB; + int NCB; +}; + +/** + * @brief A struct to represent the partition information for the threads on the + * m and n dimensions. 
+ */ +struct FBGEMM_API thread_type_t { + int g_num_threads; + int m_num_threads; + int n_num_threads; + int g_thread_id; + int m_thread_id; + int n_thread_id; + + std::string toString() const { + std::string out; + out += "g num threads: " + std::to_string(g_num_threads) + ", "; + out += "m num threads: " + std::to_string(m_num_threads) + ", "; + out += "n num threads: " + std::to_string(n_num_threads) + ", "; + out += "g thread id: " + std::to_string(g_thread_id) + ", "; + out += "m thread id: " + std::to_string(m_thread_id) + ", "; + out += "n thread id: " + std::to_string(n_thread_id); + return out; + } +}; + +/** + * @brief A heuristic algorithm to partition the threads across m and n + * dimensions for parallelization, ensuring the ratio between the number of rows + * allocated to each thread in the m dimension and the number of columns + * allocated to each thread in the n dimension is approximately aspect_ratio. + * + * The less aspect_ratio is, the more favorable it is to parallelize the m + * dimension over the n dimension. + */ +FBGEMM_API int fbgemmGet2DPartition( + int m, + int n, + int nthreads, + int n_align, + double aspect_ratio); + +/** + * @brief A heuristic way to partition the threads across g, m and n dimensions + * for parallelization. + */ +FBGEMM_API thread_type_t fbgemmGetThreadPartition( + int g, + int m, + int n, + int thread_id, + int num_threads, + int n_align = 64); + +template +std::string arrayToString(const std::array& inp) { + std::string out = "["; + for (int i = 0; i < SIZE; ++i) { + out += std::to_string(inp[i]); + out += (i != SIZE - 1) ? 
std::string(", ") : std::string("]"); + } + return out; +} + +template +bool isValidBlockingFactor(const BlockingFactors* const param) { + constexpr bool is_32bit = std::is_same_v; + constexpr bool is_16bit = std::is_same_v; + static const auto iset = fbgemmInstructionSet(); + + if constexpr (is_32bit) { + if (param->ROW_INTERLEAVE != 4) + return false; + + if (isZmm(iset)) { + if (param->NR_MIN != 16 || param->NR % param->NR_MIN) + return false; + } else if (isYmm(iset)) { + if (param->NR_MIN != 8 || param->NR % param->NR_MIN) + return false; + } + } else if constexpr (is_16bit) { + if (param->ROW_INTERLEAVE != 2) + return false; + + if (isZmm(iset)) { + if (param->NR_MIN != 32 || param->NR % param->NR_MIN) + return false; + } else if (isYmm(iset)) { + if (param->NR_MIN != 16 || param->NR % param->NR_MIN) + return false; + } + } + + if (param->MCB % param->MR) + return false; + if (param->NCB % param->NR) + return false; + if (isZmm(iset)) { + if constexpr (is_32bit) { + // Zmm register usage for C + if (param->MR * (param->NR / param->NR_MIN) > 28) + return false; + } else if constexpr (is_16bit) { + // Zmm register usage for C + one row for loading B + if ((param->MR * (param->NR / param->NR_MIN) + + (param->NR / param->NR_MIN)) > 28) + return false; + } + + } else if (isYmm(iset)) { + if (param->MR * (param->NR / param->NR_MIN) > 12) + return false; + } + return true; +} + +/** + * @brief Partition work across given number of threads + * + * @param start Given thread_id should execute starting from the index + * start + * @param stop Given thread_id should stop executing at the index stop + * + * i.e., the loop should be equivalent to for(int i = start; i < end; ++i) + */ +FBGEMM_API void fbgemmPartition1D( + int thread_id, + int num_threads, + std::int64_t total_work, + std::int64_t& start, + std::int64_t& end); + +/** + * @brief Partition work across given number of threads in blocks + * of size block_size. 
Each thread gets a multiple of block_size + * work or nothing, except the last one. The last one might + * receive the fringe case. + * + * @param start Given thread_id should execute starting from the index + * start + * @param stop Given thread_id should stop executing at the index stop + * + * The loop can be equivalent to for(int i = start; i < end; i+=block_size) + * except for the last thread. (i.e., thread_id = num_threads - 1) + * + * Example 1: block_size = 2, num_threads = 2 + * total_work start(th 0) end(th 0) start(th 1) end(th 1) + * 4 0 2 2 4 + * 5 0 2 2 5 + * + * Example 2: block_size = 2, num_threads = 3 + * total_work start(th 0) end(th 0) start(th 1) end(th 1) + * 4 0 2 2 4 + * 5 0 2 2 4 + * + * total_work start(th 2) end(th 2) + * 4 4 4 + * 5 4 5 + * + * Example 3: block_size = 2, num_threads = 4 + * total_work start(th 0) end(th 0) start(th 1) end(th 1) + * 4 0 2 2 4 + * 5 0 2 2 4 + * + * total_work start(th 2) end(th 2) start(th 3) end(th 3) + * 4 4 4 4 4 + * 5 4 4 4 5 + */ +FBGEMM_API void fbgemmPartition1DBlocked( + int thread_id, + int num_threads, + std::int64_t total_work, + int block_size, + std::int64_t& start, + std::int64_t& end); + +/** + * @brief A stable sorting algorithm. It sorts 8 bits at a time, hence in a + * worst-case performing sizeof(K) / 8 passes. Providing meaningful max_value + * may help reduce the number of passes performed by radix_sort. If + * maybe_with_neg_vals is set to true, we are performing all possible passes, + * up to a sign bit. If OpenMP is available in a build system, radix_sort works + * in parallel. + */ +template +FBGEMM_API std::pair radix_sort_parallel( + K* const inp_key_buf, + V* const inp_value_buf, + K* const tmp_key_buf, + V* const tmp_value_buf, + const int64_t elements_count, + const int64_t max_value, + const bool maybe_with_neg_vals = false); + +/** + * @brief Helper function that allows us to check whether radix_sort is + * accelerated with OpenMP or not. 
+ */ +FBGEMM_API bool is_radix_sort_accelerated_with_openmp(); + +/** + * Choosing which kernel (autovec/asmjit/ref) to use for nbit-CPU-TBE + * Available kernels: + * * ref: non-optimized, reference implementation that focuses on + * correctness, not performance + * * asmjit: hand-optimized kernel by having asmjit emit SIMD + * instructions during runtime. Only supports x86_64 CPUs with + * AVX2/AVX512 instruction sets + * * autovec: the kernel written in regular C++ code but in a + * way that makes compilers easier to generate vectorized SIMD + * instructions out of it. Supports both x86_64 and aarch64 CPUs. + * Currently only available on Linux. + * How to set environment variables: + * * No environment variables: on x86_64 we will default to asmjit + * kernel, and on aarch64 and linux we will default to autovec. + * On non-linux aarch64 we will fall back to ref. + * * Set FBGEMM_NO_AUTOVEC: on aarch64 linux we will use ref. On other + * platforms this will have no effect. + * * Set FBGEMM_NO_ASMJIT: on x86_64 we will use ref. On other + * platforms this will have no effect. + * * Set FBGEMM_NO_ASMJIT AND FBGEMM_FORCE_AUTOVEC: on x86_64 we will + * use autovec if these two variables are set at the same time. + * No effect on other platforms. + * * FBGEMM_FORCE_AUTOVEC will override FBGEMM_NO_AUTOVEC if they + * are set at the same time. + * * These variables are considered set as long as they exist regardless + * of content. That means assigning values like "1", "true", "y", "0", + * "false" or "no" has the same effect. The easiest way of setting a + * variable is to prepend `=1` before the benchmarking command. + */ +FBGEMM_API bool is_autovec_disabled(); +FBGEMM_API bool is_autovec_forced(); +FBGEMM_API bool is_asmjit_disabled(); +FBGEMM_API bool is_stats_enabled(); + +/** + * @brief A function to check if the input parameter in the nbit CPU TBE kernel + * is valid. 
+ */ +template +void nbit_embedding_sanity_check( + // assertions are ignored in release mode, in which case these parameters + // will be unused + [[maybe_unused]] const int input_bit_rate, + [[maybe_unused]] const int output_bit_rate, + [[maybe_unused]] const bool no_bag) { + assert( + (input_bit_rate == 2 || input_bit_rate == 4) && + "input_bit_rate must be 2 or 4"); + // NOLINTNEXTLINE(bugprone-branch-clone) + if constexpr (std::is_same_v) { + assert( + (no_bag && input_bit_rate == 4 && output_bit_rate == 4) && + "we currently only support int4 to int4 for sequential TBE"); + } else { + assert( + (output_bit_rate == 8 * sizeof(OutType)) && + "output_bit_rate should be equal to 8 * sizeof(OutType)"); + } +} + +#define WARN_ONCE(...) \ + do { \ + static bool _warned = false; \ + if (!_warned) { \ + _warned = true; \ + fprintf(stderr, __VA_ARGS__); \ + } \ + } while (0) + +} // namespace fbgemm + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fbgemm/UtilsAvx2.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fbgemm/UtilsAvx2.h new file mode 100644 index 0000000000000000000000000000000000000000..6a774a3fb71b17dda3bbf67bf8e5819d6fb95654 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fbgemm/UtilsAvx2.h @@ -0,0 +1,97 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once +// This file defines common utilities used in code compiled with avx2/avx512 +// flags. 
+ +#include +#include + +namespace fbgemm { + +enum class FBGEMM_ENUM_CLASS_API QuantizationGranularity { + TENSOR, + GROUP, + OUT_CHANNEL, +}; + +/** + * @brief A struct to represent a block of a matrix. + */ +struct FBGEMM_API block_type_t { + int row_start; + int row_size; + int col_start; + int col_size; + + std::string toString() const { + std::string out; + out += "row start:" + std::to_string(row_start) + ", "; + out += "row size:" + std::to_string(row_size) + ", "; + out += "col start:" + std::to_string(col_start) + ", "; + out += "col size:" + std::to_string(col_size); + return out; + } +}; + +/** + * @brief A struct to represent all the requantization parameters. + * + * Please note that this is different from RequantizationParams in + * QuantUtilsAvx2.h as it combines all the parameters needed for various + * quantization granularities + */ +template +struct requantizationParams_t { + using BIAS_T = BIAS_TYPE; + std::int32_t A_zero_point; + const std::int32_t* B_zero_point; + std::int32_t C_zero_point; + const float* C_multiplier; + const std::int32_t* row_offsets; + const std::int32_t* col_offsets; + const BIAS_T* bias; + std::uint32_t ncols; + int groups; + const float* act_times_w_scale; +}; + +/** + * @brief A struct to represent all the parameters for requantizing for floats. + */ +struct requantizationForFloatParams_t { + std::int32_t A_zero_point; + const std::int32_t* B_zero_point; + float A_scale; + const float* B_scale; + const std::int32_t* row_offsets; + const std::int32_t* col_offsets; + const float* bias; + std::uint32_t ncols; + int groups; +}; + +/** + * @brief Allocate size bytes of uninitialized storage whose alignment is + * specified by align. 
+ */ +FBGEMM_API void* +fbgemmAlignedAlloc(size_t align, size_t size, bool raiseException = false); + +/** + * @brief Free memory allocated by fbgemmAlignedAlloc + */ +FBGEMM_API void fbgemmAlignedFree(void* p); + +} // namespace fbgemm + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fbgemm/spmmUtils.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fbgemm/spmmUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..ba473fb3d8e16a65f743b5a048eacca64c360558 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fbgemm/spmmUtils.h @@ -0,0 +1,62 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once +#include + +#include "fbgemm/FbgemmBuild.h" +#include "fbgemm/FbgemmSparse.h" +#include "fbgemm/UtilsAvx2.h" +#include "fbgemm/spmmUtilsAvx2.h" + +namespace fbgemm { + +FBGEMM_API void sparseDenseMMRef( + int M, + int N, + const int* row_ptr, + const int* col_idx, + const float* values, + const float* B, + int ldb, + float* C, + int ldc, + bool accum = false); + +template +FBGEMM_API void sparseDenseInt8MMRef( + int N, + const std::unique_ptr>& bcsr, + const uint8_t* B, + int ldb, + int32_t* C_i32, + uint8_t* C_u8, + int ldc, + trRequantizationParams_t& rParams, + bool accum = false, + int thread_id = 0, + int num_threads = 1); + +template +FBGEMM_API void trRequantizeRef( + uint8_t* out, + const int32_t* inp, + const block_type_t& block, + int ld_out, + int ld_in, + const trRequantizationParams_t& r); + +// Get matrix shapes of interest +FBGEMM_API std::vector> getSparseMatrixShapes(); + +} // namespace fbgemm + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fbgemm/spmmUtilsAvx2.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fbgemm/spmmUtilsAvx2.h new file mode 100644 index 0000000000000000000000000000000000000000..7543b56dfa0410fd678e2ec24fa39a9054dee7d4 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fbgemm/spmmUtilsAvx2.h @@ -0,0 +1,44 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once +#include +#include "./FbgemmBuild.h" // @manual +#include "fbgemm/UtilsAvx2.h" + +namespace fbgemm { +struct FBGEMM_API trRequantizationParams_t { + std::int32_t act_zero_point; // activation zero point + const std::int32_t* weight_zero_points; // weight zero point(s) + std::int32_t C_zero_point; + const float C_scale; + const std::int32_t* weight_row_offsets; + const std::int32_t* act_col_offsets; + const float* bias; + const float* act_times_w_scale; +}; + +template < + bool FUSE_RELU, + bool ACT_SYMMETRIC, // whether activation matrix is symmetric + bool WEIGHT_SYMMETRIC, // whether weight matrix is symmetric + bool HAS_BIAS, + QuantizationGranularity Q_GRAN> +FBGEMM_API void trRequantizeOpt( + uint8_t* out, + const int32_t* inp, + const block_type_t& block, + int ld_out, + int ld_in, + const trRequantizationParams_t& rParams); +} // namespace fbgemm + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fmt/args.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fmt/args.h new file mode 100644 index 0000000000000000000000000000000000000000..33309f51f704da42ec73074839969d2ef578da17 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fmt/args.h @@ -0,0 +1,225 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Formatting library for C++ - dynamic argument lists +// +// Copyright (c) 2012 - present, Victor Zverovich +// All rights reserved. +// +// For the license information refer to format.h. 
+ +#ifndef FMT_ARGS_H_ +#define FMT_ARGS_H_ + +#ifndef FMT_MODULE +# include // std::reference_wrapper +# include // std::unique_ptr +# include +#endif + +#include "format.h" // std_string_view + +FMT_BEGIN_NAMESPACE +namespace detail { + +template struct is_reference_wrapper : std::false_type {}; +template +struct is_reference_wrapper> : std::true_type {}; + +template auto unwrap(const T& v) -> const T& { return v; } +template +auto unwrap(const std::reference_wrapper& v) -> const T& { + return static_cast(v); +} + +// node is defined outside dynamic_arg_list to workaround a C2504 bug in MSVC +// 2022 (v17.10.0). +// +// Workaround for clang's -Wweak-vtables. Unlike for regular classes, for +// templates it doesn't complain about inability to deduce single translation +// unit for placing vtable. So node is made a fake template. +template struct node { + virtual ~node() = default; + std::unique_ptr> next; +}; + +class dynamic_arg_list { + template struct typed_node : node<> { + T value; + + template + FMT_CONSTEXPR typed_node(const Arg& arg) : value(arg) {} + + template + FMT_CONSTEXPR typed_node(const basic_string_view& arg) + : value(arg.data(), arg.size()) {} + }; + + std::unique_ptr> head_; + + public: + template auto push(const Arg& arg) -> const T& { + auto new_node = std::unique_ptr>(new typed_node(arg)); + auto& value = new_node->value; + new_node->next = std::move(head_); + head_ = std::move(new_node); + return value; + } +}; +} // namespace detail + +/** + * A dynamic list of formatting arguments with storage. + * + * It can be implicitly converted into `fmt::basic_format_args` for passing + * into type-erased formatting functions such as `fmt::vformat`. 
+ */ +FMT_EXPORT template class dynamic_format_arg_store { + private: + using char_type = typename Context::char_type; + + template struct need_copy { + static constexpr detail::type mapped_type = + detail::mapped_type_constant::value; + + enum { + value = !(detail::is_reference_wrapper::value || + std::is_same>::value || + std::is_same>::value || + (mapped_type != detail::type::cstring_type && + mapped_type != detail::type::string_type && + mapped_type != detail::type::custom_type)) + }; + }; + + template + using stored_t = conditional_t< + std::is_convertible>::value && + !detail::is_reference_wrapper::value, + std::basic_string, T>; + + // Storage of basic_format_arg must be contiguous. + std::vector> data_; + std::vector> named_info_; + + // Storage of arguments not fitting into basic_format_arg must grow + // without relocation because items in data_ refer to it. + detail::dynamic_arg_list dynamic_args_; + + friend class basic_format_args; + + auto data() const -> const basic_format_arg* { + return named_info_.empty() ? data_.data() : data_.data() + 1; + } + + template void emplace_arg(const T& arg) { + data_.emplace_back(arg); + } + + template + void emplace_arg(const detail::named_arg& arg) { + if (named_info_.empty()) + data_.insert(data_.begin(), basic_format_arg(nullptr, 0)); + data_.emplace_back(detail::unwrap(arg.value)); + auto pop_one = [](std::vector>* data) { + data->pop_back(); + }; + std::unique_ptr>, decltype(pop_one)> + guard{&data_, pop_one}; + named_info_.push_back({arg.name, static_cast(data_.size() - 2u)}); + data_[0] = {named_info_.data(), named_info_.size()}; + guard.release(); + } + + public: + constexpr dynamic_format_arg_store() = default; + + operator basic_format_args() const { + return basic_format_args(data(), static_cast(data_.size()), + !named_info_.empty()); + } + + /** + * Adds an argument into the dynamic store for later passing to a formatting + * function. 
+ * + * Note that custom types and string types (but not string views) are copied + * into the store dynamically allocating memory if necessary. + * + * **Example**: + * + * fmt::dynamic_format_arg_store store; + * store.push_back(42); + * store.push_back("abc"); + * store.push_back(1.5f); + * std::string result = fmt::vformat("{} and {} and {}", store); + */ + template void push_back(const T& arg) { + if (detail::const_check(need_copy::value)) + emplace_arg(dynamic_args_.push>(arg)); + else + emplace_arg(detail::unwrap(arg)); + } + + /** + * Adds a reference to the argument into the dynamic store for later passing + * to a formatting function. + * + * **Example**: + * + * fmt::dynamic_format_arg_store store; + * char band[] = "Rolling Stones"; + * store.push_back(std::cref(band)); + * band[9] = 'c'; // Changing str affects the output. + * std::string result = fmt::vformat("{}", store); + * // result == "Rolling Scones" + */ + template void push_back(std::reference_wrapper arg) { + static_assert( + need_copy::value, + "objects of built-in types and string views are always copied"); + emplace_arg(arg.get()); + } + + /** + * Adds named argument into the dynamic store for later passing to a + * formatting function. `std::reference_wrapper` is supported to avoid + * copying of the argument. The name is always copied into the store. + */ + template + void push_back(const detail::named_arg& arg) { + const char_type* arg_name = + dynamic_args_.push>(arg.name).c_str(); + if (detail::const_check(need_copy::value)) { + emplace_arg( + fmt::arg(arg_name, dynamic_args_.push>(arg.value))); + } else { + emplace_arg(fmt::arg(arg_name, arg.value)); + } + } + + /// Erase all elements from the store. + void clear() { + data_.clear(); + named_info_.clear(); + dynamic_args_ = {}; + } + + /// Reserves space to store at least `new_cap` arguments including + /// `new_cap_named` named arguments. 
+ void reserve(size_t new_cap, size_t new_cap_named) { + FMT_ASSERT(new_cap >= new_cap_named, + "set of arguments includes set of named arguments"); + data_.reserve(new_cap); + named_info_.reserve(new_cap_named); + } + + /// Returns the number of elements in the store. + auto size() const noexcept -> size_t { return data_.size(); } +}; + +FMT_END_NAMESPACE + +#endif // FMT_ARGS_H_ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fmt/base.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fmt/base.h new file mode 100644 index 0000000000000000000000000000000000000000..c72f2fbe80572fc8bb73b04f262ebdbd866c278b --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fmt/base.h @@ -0,0 +1,3015 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Formatting library for C++ - the base API for char/UTF-8 +// +// Copyright (c) 2012 - present, Victor Zverovich +// All rights reserved. +// +// For the license information refer to format.h. + +#ifndef FMT_BASE_H_ +#define FMT_BASE_H_ + +#if defined(FMT_IMPORT_STD) && !defined(FMT_MODULE) +# define FMT_MODULE +#endif + +#ifndef FMT_MODULE +# include // CHAR_BIT +# include // FILE +# include // memcmp + +# include // std::enable_if +#endif + +// The fmt library version in the form major * 10000 + minor * 100 + patch. +#define FMT_VERSION 120100 + +// Detect compiler versions. 
+#if defined(__clang__) && !defined(__ibmxl__) +# define FMT_CLANG_VERSION (__clang_major__ * 100 + __clang_minor__) +#else +# define FMT_CLANG_VERSION 0 +#endif +#if defined(__GNUC__) && !defined(__clang__) && !defined(__INTEL_COMPILER) +# define FMT_GCC_VERSION (__GNUC__ * 100 + __GNUC_MINOR__) +#else +# define FMT_GCC_VERSION 0 +#endif +#if defined(__ICL) +# define FMT_ICC_VERSION __ICL +#elif defined(__INTEL_COMPILER) +# define FMT_ICC_VERSION __INTEL_COMPILER +#else +# define FMT_ICC_VERSION 0 +#endif +#if defined(_MSC_VER) +# define FMT_MSC_VERSION _MSC_VER +#else +# define FMT_MSC_VERSION 0 +#endif + +// Detect standard library versions. +#ifdef _GLIBCXX_RELEASE +# define FMT_GLIBCXX_RELEASE _GLIBCXX_RELEASE +#else +# define FMT_GLIBCXX_RELEASE 0 +#endif +#ifdef _LIBCPP_VERSION +# define FMT_LIBCPP_VERSION _LIBCPP_VERSION +#else +# define FMT_LIBCPP_VERSION 0 +#endif + +#ifdef _MSVC_LANG +# define FMT_CPLUSPLUS _MSVC_LANG +#else +# define FMT_CPLUSPLUS __cplusplus +#endif + +// Detect __has_*. +#ifdef __has_feature +# define FMT_HAS_FEATURE(x) __has_feature(x) +#else +# define FMT_HAS_FEATURE(x) 0 +#endif +#ifdef __has_include +# define FMT_HAS_INCLUDE(x) __has_include(x) +#else +# define FMT_HAS_INCLUDE(x) 0 +#endif +#ifdef __has_builtin +# define FMT_HAS_BUILTIN(x) __has_builtin(x) +#else +# define FMT_HAS_BUILTIN(x) 0 +#endif +#ifdef __has_cpp_attribute +# define FMT_HAS_CPP_ATTRIBUTE(x) __has_cpp_attribute(x) +#else +# define FMT_HAS_CPP_ATTRIBUTE(x) 0 +#endif + +#define FMT_HAS_CPP14_ATTRIBUTE(attribute) \ + (FMT_CPLUSPLUS >= 201402L && FMT_HAS_CPP_ATTRIBUTE(attribute)) + +#define FMT_HAS_CPP17_ATTRIBUTE(attribute) \ + (FMT_CPLUSPLUS >= 201703L && FMT_HAS_CPP_ATTRIBUTE(attribute)) + +// Detect C++14 relaxed constexpr. +#ifdef FMT_USE_CONSTEXPR +// Use the provided definition. 
+#elif FMT_GCC_VERSION >= 702 && FMT_CPLUSPLUS >= 201402L +// GCC only allows constexpr member functions in non-literal types since 7.2: +// https://gcc.gnu.org/bugzilla/show_bug.cgi?id=66297. +# define FMT_USE_CONSTEXPR 1 +#elif FMT_ICC_VERSION +# define FMT_USE_CONSTEXPR 0 // https://github.com/fmtlib/fmt/issues/1628 +#elif FMT_HAS_FEATURE(cxx_relaxed_constexpr) || FMT_MSC_VERSION >= 1912 +# define FMT_USE_CONSTEXPR 1 +#else +# define FMT_USE_CONSTEXPR 0 +#endif +#if FMT_USE_CONSTEXPR +# define FMT_CONSTEXPR constexpr +#else +# define FMT_CONSTEXPR +#endif + +// Detect consteval, C++20 constexpr extensions and std::is_constant_evaluated. +#ifdef FMT_USE_CONSTEVAL +// Use the provided definition. +#elif !defined(__cpp_lib_is_constant_evaluated) +# define FMT_USE_CONSTEVAL 0 +#elif FMT_CPLUSPLUS < 201709L +# define FMT_USE_CONSTEVAL 0 +#elif FMT_GLIBCXX_RELEASE && FMT_GLIBCXX_RELEASE < 10 +# define FMT_USE_CONSTEVAL 0 +#elif FMT_LIBCPP_VERSION && FMT_LIBCPP_VERSION < 10000 +# define FMT_USE_CONSTEVAL 0 +#elif defined(__apple_build_version__) && __apple_build_version__ < 14000029L +# define FMT_USE_CONSTEVAL 0 // consteval is broken in Apple clang < 14. +#elif FMT_MSC_VERSION && FMT_MSC_VERSION < 1929 +# define FMT_USE_CONSTEVAL 0 // consteval is broken in MSVC VS2019 < 16.10. +#elif defined(__cpp_consteval) +# define FMT_USE_CONSTEVAL 1 +#elif FMT_GCC_VERSION >= 1002 || FMT_CLANG_VERSION >= 1101 +# define FMT_USE_CONSTEVAL 1 +#else +# define FMT_USE_CONSTEVAL 0 +#endif +#if FMT_USE_CONSTEVAL +# define FMT_CONSTEVAL consteval +# define FMT_CONSTEXPR20 constexpr +#else +# define FMT_CONSTEVAL +# define FMT_CONSTEXPR20 +#endif + +// Check if exceptions are disabled. +#ifdef FMT_USE_EXCEPTIONS +// Use the provided definition. 
+#elif defined(__GNUC__) && !defined(__EXCEPTIONS) +# define FMT_USE_EXCEPTIONS 0 +#elif defined(__clang__) && !defined(__cpp_exceptions) +# define FMT_USE_EXCEPTIONS 0 +#elif FMT_MSC_VERSION && !_HAS_EXCEPTIONS +# define FMT_USE_EXCEPTIONS 0 +#else +# define FMT_USE_EXCEPTIONS 1 +#endif +#if FMT_USE_EXCEPTIONS +# define FMT_TRY try +# define FMT_CATCH(x) catch (x) +#else +# define FMT_TRY if (true) +# define FMT_CATCH(x) if (false) +#endif + +#ifdef FMT_NO_UNIQUE_ADDRESS +// Use the provided definition. +#elif FMT_CPLUSPLUS < 202002L +// Not supported. +#elif FMT_HAS_CPP_ATTRIBUTE(no_unique_address) +# define FMT_NO_UNIQUE_ADDRESS [[no_unique_address]] +// VS2019 v16.10 and later except clang-cl (https://reviews.llvm.org/D110485). +#elif FMT_MSC_VERSION >= 1929 && !FMT_CLANG_VERSION +# define FMT_NO_UNIQUE_ADDRESS [[msvc::no_unique_address]] +#endif +#ifndef FMT_NO_UNIQUE_ADDRESS +# define FMT_NO_UNIQUE_ADDRESS +#endif + +#if FMT_HAS_CPP17_ATTRIBUTE(fallthrough) +# define FMT_FALLTHROUGH [[fallthrough]] +#elif defined(__clang__) +# define FMT_FALLTHROUGH [[clang::fallthrough]] +#elif FMT_GCC_VERSION >= 700 && \ + (!defined(__EDG_VERSION__) || __EDG_VERSION__ >= 520) +# define FMT_FALLTHROUGH [[gnu::fallthrough]] +#else +# define FMT_FALLTHROUGH +#endif + +// Disable [[noreturn]] on MSVC/NVCC because of bogus unreachable code warnings. +#if FMT_HAS_CPP_ATTRIBUTE(noreturn) && !FMT_MSC_VERSION && !defined(__NVCC__) +# define FMT_NORETURN [[noreturn]] +#else +# define FMT_NORETURN +#endif + +#ifdef FMT_NODISCARD +// Use the provided definition. +#elif FMT_HAS_CPP17_ATTRIBUTE(nodiscard) +# define FMT_NODISCARD [[nodiscard]] +#else +# define FMT_NODISCARD +#endif + +#if FMT_GCC_VERSION || FMT_CLANG_VERSION +# define FMT_VISIBILITY(value) __attribute__((visibility(value))) +#else +# define FMT_VISIBILITY(value) +#endif + +// Detect pragmas. 
+#define FMT_PRAGMA_IMPL(x) _Pragma(#x) +#if FMT_GCC_VERSION >= 504 && !defined(__NVCOMPILER) +// Workaround a _Pragma bug https://gcc.gnu.org/bugzilla/show_bug.cgi?id=59884 +// and an nvhpc warning: https://github.com/fmtlib/fmt/pull/2582. +# define FMT_PRAGMA_GCC(x) FMT_PRAGMA_IMPL(GCC x) +#else +# define FMT_PRAGMA_GCC(x) +#endif +#if FMT_CLANG_VERSION +# define FMT_PRAGMA_CLANG(x) FMT_PRAGMA_IMPL(clang x) +#else +# define FMT_PRAGMA_CLANG(x) +#endif +#if FMT_MSC_VERSION +# define FMT_MSC_WARNING(...) __pragma(warning(__VA_ARGS__)) +#else +# define FMT_MSC_WARNING(...) +#endif + +// Enable minimal optimizations for more compact code in debug mode. +FMT_PRAGMA_GCC(push_options) +#if !defined(__OPTIMIZE__) && !defined(__CUDACC__) && !defined(FMT_MODULE) +FMT_PRAGMA_GCC(optimize("Og")) +# define FMT_GCC_OPTIMIZED +#endif +FMT_PRAGMA_CLANG(diagnostic push) +FMT_PRAGMA_GCC(diagnostic push) + +#ifdef FMT_ALWAYS_INLINE +// Use the provided definition. +#elif FMT_GCC_VERSION || FMT_CLANG_VERSION +# define FMT_ALWAYS_INLINE inline __attribute__((always_inline)) +#else +# define FMT_ALWAYS_INLINE inline +#endif +// A version of FMT_ALWAYS_INLINE to prevent code bloat in debug mode. 
+#if defined(NDEBUG) || defined(FMT_GCC_OPTIMIZED) +# define FMT_INLINE FMT_ALWAYS_INLINE +#else +# define FMT_INLINE inline +#endif + +#ifndef FMT_BEGIN_NAMESPACE +# define FMT_BEGIN_NAMESPACE \ + namespace fmt { \ + inline namespace v12 { +# define FMT_END_NAMESPACE \ + } \ + } +#endif + +#ifndef FMT_EXPORT +# define FMT_EXPORT +# define FMT_BEGIN_EXPORT +# define FMT_END_EXPORT +#endif + +#ifdef _WIN32 +# define FMT_WIN32 1 +#else +# define FMT_WIN32 0 +#endif + +#if !defined(FMT_HEADER_ONLY) && FMT_WIN32 +# if defined(FMT_LIB_EXPORT) +# define FMT_API __declspec(dllexport) +# elif defined(FMT_SHARED) +# define FMT_API __declspec(dllimport) +# endif +#elif defined(FMT_LIB_EXPORT) || defined(FMT_SHARED) +# define FMT_API FMT_VISIBILITY("default") +#endif +#ifndef FMT_API +# define FMT_API +#endif + +#ifndef FMT_OPTIMIZE_SIZE +# define FMT_OPTIMIZE_SIZE 0 +#endif + +// FMT_BUILTIN_TYPE=0 may result in smaller library size at the cost of higher +// per-call binary size by passing built-in types through the extension API. +#ifndef FMT_BUILTIN_TYPES +# define FMT_BUILTIN_TYPES 1 +#endif + +#define FMT_APPLY_VARIADIC(expr) \ + using unused = int[]; \ + (void)unused { 0, (expr, 0)... } + +FMT_BEGIN_NAMESPACE + +// Implementations of enable_if_t and other metafunctions for older systems. 
+template +using enable_if_t = typename std::enable_if::type; +template +using conditional_t = typename std::conditional::type; +template using bool_constant = std::integral_constant; +template +using remove_reference_t = typename std::remove_reference::type; +template +using remove_const_t = typename std::remove_const::type; +template +using remove_cvref_t = typename std::remove_cv>::type; +template +using make_unsigned_t = typename std::make_unsigned::type; +template +using underlying_t = typename std::underlying_type::type; +template using decay_t = typename std::decay::type; +using nullptr_t = decltype(nullptr); + +#if (FMT_GCC_VERSION && FMT_GCC_VERSION < 500) || FMT_MSC_VERSION +// A workaround for gcc 4.9 & MSVC v141 to make void_t work in a SFINAE context. +template struct void_t_impl { + using type = void; +}; +template using void_t = typename void_t_impl::type; +#else +template using void_t = void; +#endif + +struct monostate { + constexpr monostate() {} +}; + +// An enable_if helper to be used in template parameters which results in much +// shorter symbols: https://godbolt.org/z/sWw4vP. Extra parentheses are needed +// to workaround a bug in MSVC 2019 (see #1140 and #1186). +#ifdef FMT_DOC +# define FMT_ENABLE_IF(...) +#else +# define FMT_ENABLE_IF(...) fmt::enable_if_t<(__VA_ARGS__), int> = 0 +#endif + +template constexpr auto min_of(T a, T b) -> T { + return a < b ? a : b; +} +template constexpr auto max_of(T a, T b) -> T { + return a > b ? a : b; +} + +FMT_NORETURN FMT_API void assert_fail(const char* file, int line, + const char* message); + +namespace detail { +// Suppresses "unused variable" warnings with the method described in +// https://herbsutter.com/2009/10/18/mailbag-shutting-up-compiler-warnings/. +// (void)var does not work on many Intel compilers. +template FMT_CONSTEXPR void ignore_unused(const T&...) 
{} + +constexpr auto is_constant_evaluated(bool default_value = false) noexcept + -> bool { +// Workaround for incompatibility between clang 14 and libstdc++ consteval-based +// std::is_constant_evaluated: https://github.com/fmtlib/fmt/issues/3247. +#if FMT_CPLUSPLUS >= 202002L && FMT_GLIBCXX_RELEASE >= 12 && \ + (FMT_CLANG_VERSION >= 1400 && FMT_CLANG_VERSION < 1500) + ignore_unused(default_value); + return __builtin_is_constant_evaluated(); +#elif defined(__cpp_lib_is_constant_evaluated) + ignore_unused(default_value); + return std::is_constant_evaluated(); +#else + return default_value; +#endif +} + +// Suppresses "conditional expression is constant" warnings. +template FMT_ALWAYS_INLINE constexpr auto const_check(T val) -> T { + return val; +} + +FMT_NORETURN FMT_API void assert_fail(const char* file, int line, + const char* message); + +#if defined(FMT_ASSERT) +// Use the provided definition. +#elif defined(NDEBUG) +// FMT_ASSERT is not empty to avoid -Wempty-body. +# define FMT_ASSERT(condition, message) \ + fmt::detail::ignore_unused((condition), (message)) +#else +# define FMT_ASSERT(condition, message) \ + ((condition) /* void() fails with -Winvalid-constexpr on clang 4.0.1 */ \ + ? (void)0 \ + : ::fmt::assert_fail(__FILE__, __LINE__, (message))) +#endif + +#ifdef FMT_USE_INT128 +// Use the provided definition. +#elif defined(__SIZEOF_INT128__) && !defined(__NVCC__) && \ + !(FMT_CLANG_VERSION && FMT_MSC_VERSION) +# define FMT_USE_INT128 1 +using int128_opt = __int128_t; // An optional native 128-bit integer. +using uint128_opt = __uint128_t; +inline auto map(int128_opt x) -> int128_opt { return x; } +inline auto map(uint128_opt x) -> uint128_opt { return x; } +#else +# define FMT_USE_INT128 0 +#endif +#if !FMT_USE_INT128 +enum class int128_opt {}; +enum class uint128_opt {}; +// Reduce template instantiations. 
+inline auto map(int128_opt) -> monostate { return {}; } +inline auto map(uint128_opt) -> monostate { return {}; } +#endif + +#ifdef FMT_USE_BITINT +// Use the provided definition. +#elif FMT_CLANG_VERSION >= 1500 && !defined(__CUDACC__) +# define FMT_USE_BITINT 1 +#else +# define FMT_USE_BITINT 0 +#endif + +#if FMT_USE_BITINT +FMT_PRAGMA_CLANG(diagnostic ignored "-Wbit-int-extension") +template using bitint = _BitInt(N); +template using ubitint = unsigned _BitInt(N); +#else +template struct bitint {}; +template struct ubitint {}; +#endif // FMT_USE_BITINT + +// Casts a nonnegative integer to unsigned. +template +FMT_CONSTEXPR auto to_unsigned(Int value) -> make_unsigned_t { + FMT_ASSERT(std::is_unsigned::value || value >= 0, "negative value"); + return static_cast>(value); +} + +template +using unsigned_char = conditional_t; + +// A heuristic to detect std::string and std::[experimental::]string_view. +// It is mainly used to avoid dependency on <[experimental/]string_view>. +template +struct is_std_string_like : std::false_type {}; +template +struct is_std_string_like().find_first_of( + typename T::value_type(), 0))>> + : std::is_convertible().data()), + const typename T::value_type*> {}; + +// Check if the literal encoding is UTF-8. 
+enum { is_utf8_enabled = "\u00A7"[1] == '\xA7' }; +enum { use_utf8 = !FMT_WIN32 || is_utf8_enabled }; + +#ifndef FMT_UNICODE +# define FMT_UNICODE 1 +#endif + +static_assert(!FMT_UNICODE || use_utf8, + "Unicode support requires compiling with /utf-8"); + +template constexpr auto narrow(T*) -> char* { return nullptr; } +constexpr FMT_ALWAYS_INLINE auto narrow(const char* s) -> const char* { + return s; +} + +template +FMT_CONSTEXPR auto compare(const Char* s1, const Char* s2, size_t n) -> int { + if (!is_constant_evaluated() && sizeof(Char) == 1) return memcmp(s1, s2, n); + for (; n != 0; ++s1, ++s2, --n) { + if (*s1 < *s2) return -1; + if (*s1 > *s2) return 1; + } + return 0; +} + +namespace adl { +using namespace std; + +template +auto invoke_back_inserter() + -> decltype(back_inserter(std::declval())); +} // namespace adl + +template +struct is_back_insert_iterator : std::false_type {}; + +template +struct is_back_insert_iterator< + It, bool_constant()), + It>::value>> : std::true_type {}; + +// Extracts a reference to the container from *insert_iterator. +template +inline FMT_CONSTEXPR20 auto get_container(OutputIt it) -> + typename OutputIt::container_type& { + struct accessor : OutputIt { + FMT_CONSTEXPR20 accessor(OutputIt base) : OutputIt(base) {} + using OutputIt::container; + }; + return *accessor(it).container; +} +} // namespace detail + +// Parsing-related public API and forward declarations. +FMT_BEGIN_EXPORT + +/** + * An implementation of `std::basic_string_view` for pre-C++17. It provides a + * subset of the API. `fmt::basic_string_view` is used for format strings even + * if `std::basic_string_view` is available to prevent issues when a library is + * compiled with a different `-std` option than the client code (which is not + * recommended). 
+ */ +template class basic_string_view { + private: + const Char* data_; + size_t size_; + + public: + using value_type = Char; + using iterator = const Char*; + + constexpr basic_string_view() noexcept : data_(nullptr), size_(0) {} + + /// Constructs a string view object from a C string and a size. + constexpr basic_string_view(const Char* s, size_t count) noexcept + : data_(s), size_(count) {} + + constexpr basic_string_view(nullptr_t) = delete; + + /// Constructs a string view object from a C string. +#if FMT_GCC_VERSION + FMT_ALWAYS_INLINE +#endif + FMT_CONSTEXPR20 basic_string_view(const Char* s) : data_(s) { +#if FMT_HAS_BUILTIN(__builtin_strlen) || FMT_GCC_VERSION || FMT_CLANG_VERSION + if (std::is_same::value && !detail::is_constant_evaluated()) { + size_ = __builtin_strlen(detail::narrow(s)); // strlen is not constexpr. + return; + } +#endif + size_t len = 0; + while (*s++) ++len; + size_ = len; + } + + /// Constructs a string view from a `std::basic_string` or a + /// `std::basic_string_view` object. + template ::value&& std::is_same< + typename S::value_type, Char>::value)> + FMT_CONSTEXPR basic_string_view(const S& s) noexcept + : data_(s.data()), size_(s.size()) {} + + /// Returns a pointer to the string data. + constexpr auto data() const noexcept -> const Char* { return data_; } + + /// Returns the string size. 
+ constexpr auto size() const noexcept -> size_t { return size_; } + + constexpr auto begin() const noexcept -> iterator { return data_; } + constexpr auto end() const noexcept -> iterator { return data_ + size_; } + + constexpr auto operator[](size_t pos) const noexcept -> const Char& { + return data_[pos]; + } + + FMT_CONSTEXPR void remove_prefix(size_t n) noexcept { + data_ += n; + size_ -= n; + } + + FMT_CONSTEXPR auto starts_with(basic_string_view sv) const noexcept + -> bool { + return size_ >= sv.size_ && detail::compare(data_, sv.data_, sv.size_) == 0; + } + FMT_CONSTEXPR auto starts_with(Char c) const noexcept -> bool { + return size_ >= 1 && *data_ == c; + } + FMT_CONSTEXPR auto starts_with(const Char* s) const -> bool { + return starts_with(basic_string_view(s)); + } + + FMT_CONSTEXPR auto compare(basic_string_view other) const -> int { + int result = + detail::compare(data_, other.data_, min_of(size_, other.size_)); + if (result != 0) return result; + return size_ == other.size_ ? 0 : (size_ < other.size_ ? -1 : 1); + } + + FMT_CONSTEXPR friend auto operator==(basic_string_view lhs, + basic_string_view rhs) -> bool { + return lhs.compare(rhs) == 0; + } + friend auto operator!=(basic_string_view lhs, basic_string_view rhs) -> bool { + return lhs.compare(rhs) != 0; + } + friend auto operator<(basic_string_view lhs, basic_string_view rhs) -> bool { + return lhs.compare(rhs) < 0; + } + friend auto operator<=(basic_string_view lhs, basic_string_view rhs) -> bool { + return lhs.compare(rhs) <= 0; + } + friend auto operator>(basic_string_view lhs, basic_string_view rhs) -> bool { + return lhs.compare(rhs) > 0; + } + friend auto operator>=(basic_string_view lhs, basic_string_view rhs) -> bool { + return lhs.compare(rhs) >= 0; + } +}; + +using string_view = basic_string_view; + +template class basic_appender; +using appender = basic_appender; + +// Checks whether T is a container with contiguous storage. 
+template struct is_contiguous : std::false_type {}; + +class context; +template class generic_context; +template class parse_context; + +// Longer aliases for C++20 compatibility. +template using basic_format_parse_context = parse_context; +using format_parse_context = parse_context; +template +using basic_format_context = + conditional_t::value, context, + generic_context>; +using format_context = context; + +template +using buffered_context = + conditional_t::value, context, + generic_context, Char>>; + +template class basic_format_arg; +template class basic_format_args; + +// A separate type would result in shorter symbols but break ABI compatibility +// between clang and gcc on ARM (#1919). +using format_args = basic_format_args; + +// A formatter for objects of type T. +template +struct formatter { + // A deleted default constructor indicates a disabled formatter. + formatter() = delete; +}; + +/// Reports a format error at compile time or, via a `format_error` exception, +/// at runtime. +// This function is intentionally not constexpr to give a compile-time error. +FMT_NORETURN FMT_API void report_error(const char* message); + +enum class presentation_type : unsigned char { + // Common specifiers: + none = 0, + debug = 1, // '?' + string = 2, // 's' (string, bool) + + // Integral, bool and character specifiers: + dec = 3, // 'd' + hex, // 'x' or 'X' + oct, // 'o' + bin, // 'b' or 'B' + chr, // 'c' + + // String and pointer specifiers: + pointer = 3, // 'p' + + // Floating-point specifiers: + exp = 1, // 'e' or 'E' (1 since there is no FP debug presentation) + fixed, // 'f' or 'F' + general, // 'g' or 'G' + hexfloat // 'a' or 'A' +}; + +enum class align { none, left, right, center, numeric }; +enum class sign { none, minus, plus, space }; +enum class arg_id_kind { none, index, name }; + +// Basic format specifiers for built-in and string types. 
+class basic_specs { + private: + // Data is arranged as follows: + // + // 0 1 2 3 + // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + // |type |align| w | p | s |u|#|L| f | unused | + // +-----+-----+---+---+---+-+-+-+-----+---------------------------+ + // + // w - dynamic width info + // p - dynamic precision info + // s - sign + // u - uppercase (e.g. 'X' for 'x') + // # - alternate form ('#') + // L - localized + // f - fill size + // + // Bitfields are not used because of compiler bugs such as gcc bug 61414. + enum : unsigned { + type_mask = 0x00007, + align_mask = 0x00038, + width_mask = 0x000C0, + precision_mask = 0x00300, + sign_mask = 0x00C00, + uppercase_mask = 0x01000, + alternate_mask = 0x02000, + localized_mask = 0x04000, + fill_size_mask = 0x38000, + + align_shift = 3, + width_shift = 6, + precision_shift = 8, + sign_shift = 10, + fill_size_shift = 15, + + max_fill_size = 4 + }; + + unsigned data_ = 1 << fill_size_shift; + static_assert(sizeof(basic_specs::data_) * CHAR_BIT >= 18, ""); + + // Character (code unit) type is erased to prevent template bloat. 
+ char fill_data_[max_fill_size] = {' '}; + + FMT_CONSTEXPR void set_fill_size(size_t size) { + data_ = (data_ & ~fill_size_mask) | + (static_cast(size) << fill_size_shift); + } + + public: + constexpr auto type() const -> presentation_type { + return static_cast(data_ & type_mask); + } + FMT_CONSTEXPR void set_type(presentation_type t) { + data_ = (data_ & ~type_mask) | static_cast(t); + } + + constexpr auto align() const -> align { + return static_cast((data_ & align_mask) >> align_shift); + } + FMT_CONSTEXPR void set_align(fmt::align a) { + data_ = (data_ & ~align_mask) | (static_cast(a) << align_shift); + } + + constexpr auto dynamic_width() const -> arg_id_kind { + return static_cast((data_ & width_mask) >> width_shift); + } + FMT_CONSTEXPR void set_dynamic_width(arg_id_kind w) { + data_ = (data_ & ~width_mask) | (static_cast(w) << width_shift); + } + + FMT_CONSTEXPR auto dynamic_precision() const -> arg_id_kind { + return static_cast((data_ & precision_mask) >> + precision_shift); + } + FMT_CONSTEXPR void set_dynamic_precision(arg_id_kind p) { + data_ = (data_ & ~precision_mask) | + (static_cast(p) << precision_shift); + } + + constexpr auto dynamic() const -> bool { + return (data_ & (width_mask | precision_mask)) != 0; + } + + constexpr auto sign() const -> sign { + return static_cast((data_ & sign_mask) >> sign_shift); + } + FMT_CONSTEXPR void set_sign(fmt::sign s) { + data_ = (data_ & ~sign_mask) | (static_cast(s) << sign_shift); + } + + constexpr auto upper() const -> bool { return (data_ & uppercase_mask) != 0; } + FMT_CONSTEXPR void set_upper() { data_ |= uppercase_mask; } + + constexpr auto alt() const -> bool { return (data_ & alternate_mask) != 0; } + FMT_CONSTEXPR void set_alt() { data_ |= alternate_mask; } + FMT_CONSTEXPR void clear_alt() { data_ &= ~alternate_mask; } + + constexpr auto localized() const -> bool { + return (data_ & localized_mask) != 0; + } + FMT_CONSTEXPR void set_localized() { data_ |= localized_mask; } + + constexpr auto 
fill_size() const -> size_t { + return (data_ & fill_size_mask) >> fill_size_shift; + } + + template ::value)> + constexpr auto fill() const -> const Char* { + return fill_data_; + } + template ::value)> + constexpr auto fill() const -> const Char* { + return nullptr; + } + + template constexpr auto fill_unit() const -> Char { + using uchar = unsigned char; + return static_cast(static_cast(fill_data_[0]) | + (static_cast(fill_data_[1]) << 8) | + (static_cast(fill_data_[2]) << 16)); + } + + FMT_CONSTEXPR void set_fill(char c) { + fill_data_[0] = c; + set_fill_size(1); + } + + template + FMT_CONSTEXPR void set_fill(basic_string_view s) { + auto size = s.size(); + set_fill_size(size); + if (size == 1) { + unsigned uchar = static_cast>(s[0]); + fill_data_[0] = static_cast(uchar); + fill_data_[1] = static_cast(uchar >> 8); + fill_data_[2] = static_cast(uchar >> 16); + return; + } + FMT_ASSERT(size <= max_fill_size, "invalid fill"); + for (size_t i = 0; i < size; ++i) + fill_data_[i & 3] = static_cast(s[i]); + } + + FMT_CONSTEXPR void copy_fill_from(const basic_specs& specs) { + set_fill_size(specs.fill_size()); + for (size_t i = 0; i < max_fill_size; ++i) + fill_data_[i] = specs.fill_data_[i]; + } +}; + +// Format specifiers for built-in and string types. +struct format_specs : basic_specs { + int width; + int precision; + + constexpr format_specs() : width(0), precision(-1) {} +}; + +/** + * Parsing context consisting of a format string range being parsed and an + * argument counter for automatic indexing. 
+ */ +template class parse_context { + private: + basic_string_view fmt_; + int next_arg_id_; + + enum { use_constexpr_cast = !FMT_GCC_VERSION || FMT_GCC_VERSION >= 1200 }; + + FMT_CONSTEXPR void do_check_arg_id(int arg_id); + + public: + using char_type = Char; + using iterator = const Char*; + + constexpr explicit parse_context(basic_string_view fmt, + int next_arg_id = 0) + : fmt_(fmt), next_arg_id_(next_arg_id) {} + + /// Returns an iterator to the beginning of the format string range being + /// parsed. + constexpr auto begin() const noexcept -> iterator { return fmt_.begin(); } + + /// Returns an iterator past the end of the format string range being parsed. + constexpr auto end() const noexcept -> iterator { return fmt_.end(); } + + /// Advances the begin iterator to `it`. + FMT_CONSTEXPR void advance_to(iterator it) { + fmt_.remove_prefix(detail::to_unsigned(it - begin())); + } + + /// Reports an error if using the manual argument indexing; otherwise returns + /// the next argument index and switches to the automatic indexing. + FMT_CONSTEXPR auto next_arg_id() -> int { + if (next_arg_id_ < 0) { + report_error("cannot switch from manual to automatic argument indexing"); + return 0; + } + int id = next_arg_id_++; + do_check_arg_id(id); + return id; + } + + /// Reports an error if using the automatic argument indexing; otherwise + /// switches to the manual indexing. + FMT_CONSTEXPR void check_arg_id(int id) { + if (next_arg_id_ > 0) { + report_error("cannot switch from automatic to manual argument indexing"); + return; + } + next_arg_id_ = -1; + do_check_arg_id(id); + } + FMT_CONSTEXPR void check_arg_id(basic_string_view) { + next_arg_id_ = -1; + } + FMT_CONSTEXPR void check_dynamic_spec(int arg_id); +}; + +#ifndef FMT_USE_LOCALE +# define FMT_USE_LOCALE (FMT_OPTIMIZE_SIZE <= 1) +#endif + +// A type-erased reference to std::locale to avoid the heavy include. 
+class locale_ref { +#if FMT_USE_LOCALE + private: + const void* locale_; // A type-erased pointer to std::locale. + + public: + constexpr locale_ref() : locale_(nullptr) {} + + template + locale_ref(const Locale& loc) : locale_(&loc) { + // Check if std::isalpha is found via ADL to reduce the chance of misuse. + isalpha('x', loc); + } + + inline explicit operator bool() const noexcept { return locale_ != nullptr; } +#endif // FMT_USE_LOCALE + + public: + template auto get() const -> Locale; +}; + +FMT_END_EXPORT + +namespace detail { + +// Specifies if `T` is a code unit type. +template struct is_code_unit : std::false_type {}; +template <> struct is_code_unit : std::true_type {}; +template <> struct is_code_unit : std::true_type {}; +template <> struct is_code_unit : std::true_type {}; +template <> struct is_code_unit : std::true_type {}; +#ifdef __cpp_char8_t +template <> struct is_code_unit : bool_constant {}; +#endif + +// Constructs fmt::basic_string_view from types implicitly convertible +// to it, deducing Char. Explicitly convertible types such as the ones returned +// from FMT_STRING are intentionally excluded. +template ::value)> +constexpr auto to_string_view(const Char* s) -> basic_string_view { + return s; +} +template ::value)> +constexpr auto to_string_view(const T& s) + -> basic_string_view { + return s; +} +template +constexpr auto to_string_view(basic_string_view s) + -> basic_string_view { + return s; +} + +template +struct has_to_string_view : std::false_type {}; +// detail:: is intentional since to_string_view is not an extension point. +template +struct has_to_string_view< + T, void_t()))>> + : std::true_type {}; + +/// String's character (code unit) type. detail:: is intentional to prevent ADL. 
+template ()))> +using char_t = typename V::value_type; + +enum class type { + none_type, + // Integer types should go first, + int_type, + uint_type, + long_long_type, + ulong_long_type, + int128_type, + uint128_type, + bool_type, + char_type, + last_integer_type = char_type, + // followed by floating-point types. + float_type, + double_type, + long_double_type, + last_numeric_type = long_double_type, + cstring_type, + string_type, + pointer_type, + custom_type +}; + +// Maps core type T to the corresponding type enum constant. +template +struct type_constant : std::integral_constant {}; + +#define FMT_TYPE_CONSTANT(Type, constant) \ + template \ + struct type_constant \ + : std::integral_constant {} + +FMT_TYPE_CONSTANT(int, int_type); +FMT_TYPE_CONSTANT(unsigned, uint_type); +FMT_TYPE_CONSTANT(long long, long_long_type); +FMT_TYPE_CONSTANT(unsigned long long, ulong_long_type); +FMT_TYPE_CONSTANT(int128_opt, int128_type); +FMT_TYPE_CONSTANT(uint128_opt, uint128_type); +FMT_TYPE_CONSTANT(bool, bool_type); +FMT_TYPE_CONSTANT(Char, char_type); +FMT_TYPE_CONSTANT(float, float_type); +FMT_TYPE_CONSTANT(double, double_type); +FMT_TYPE_CONSTANT(long double, long_double_type); +FMT_TYPE_CONSTANT(const Char*, cstring_type); +FMT_TYPE_CONSTANT(basic_string_view, string_type); +FMT_TYPE_CONSTANT(const void*, pointer_type); + +constexpr auto is_integral_type(type t) -> bool { + return t > type::none_type && t <= type::last_integer_type; +} +constexpr auto is_arithmetic_type(type t) -> bool { + return t > type::none_type && t <= type::last_numeric_type; +} + +constexpr auto set(type rhs) -> int { return 1 << static_cast(rhs); } +constexpr auto in(type t, int set) -> bool { + return ((set >> static_cast(t)) & 1) != 0; +} + +// Bitsets of types. 
+enum { + sint_set = + set(type::int_type) | set(type::long_long_type) | set(type::int128_type), + uint_set = set(type::uint_type) | set(type::ulong_long_type) | + set(type::uint128_type), + bool_set = set(type::bool_type), + char_set = set(type::char_type), + float_set = set(type::float_type) | set(type::double_type) | + set(type::long_double_type), + string_set = set(type::string_type), + cstring_set = set(type::cstring_type), + pointer_set = set(type::pointer_type) +}; + +struct view {}; + +template +struct is_view : std::false_type {}; +template +struct is_view> : std::is_base_of {}; + +template struct named_arg; +template struct is_named_arg : std::false_type {}; +template struct is_static_named_arg : std::false_type {}; + +template +struct is_named_arg> : std::true_type {}; + +template struct named_arg : view { + const Char* name; + const T& value; + + named_arg(const Char* n, const T& v) : name(n), value(v) {} + static_assert(!is_named_arg::value, "nested named arguments"); +}; + +template constexpr auto count() -> int { return B ? 1 : 0; } +template constexpr auto count() -> int { + return (B1 ? 1 : 0) + count(); +} + +template constexpr auto count_named_args() -> int { + return count::value...>(); +} +template constexpr auto count_static_named_args() -> int { + return count::value...>(); +} + +template struct named_arg_info { + const Char* name; + int id; +}; + +// named_args is non-const to suppress a bogus -Wmaybe-uninitialized in gcc 13. 
+template +FMT_CONSTEXPR void check_for_duplicate(named_arg_info* named_args, + int named_arg_index, + basic_string_view arg_name) { + for (int i = 0; i < named_arg_index; ++i) { + if (named_args[i].name == arg_name) report_error("duplicate named arg"); + } +} + +template ::value)> +void init_named_arg(named_arg_info*, int& arg_index, int&, const T&) { + ++arg_index; +} +template ::value)> +void init_named_arg(named_arg_info* named_args, int& arg_index, + int& named_arg_index, const T& arg) { + check_for_duplicate(named_args, named_arg_index, arg.name); + named_args[named_arg_index++] = {arg.name, arg_index++}; +} + +template ::value)> +FMT_CONSTEXPR void init_static_named_arg(named_arg_info*, int& arg_index, + int&) { + ++arg_index; +} +template ::value)> +FMT_CONSTEXPR void init_static_named_arg(named_arg_info* named_args, + int& arg_index, int& named_arg_index) { + check_for_duplicate(named_args, named_arg_index, T::name); + named_args[named_arg_index++] = {T::name, arg_index++}; +} + +// To minimize the number of types we need to deal with, long is translated +// either to int or to long long depending on its size. +enum { long_short = sizeof(long) == sizeof(int) && FMT_BUILTIN_TYPES }; +using long_type = conditional_t; +using ulong_type = conditional_t; + +template +using format_as_result = + remove_cvref_t()))>; +template +using format_as_member_result = + remove_cvref_t::format_as(std::declval()))>; + +template +struct use_format_as : std::false_type {}; +// format_as member is only used to avoid injection into the std namespace. +template +struct use_format_as_member : std::false_type {}; + +// Only map owning types because mapping views can be unsafe. 
+template +struct use_format_as< + T, bool_constant>::value>> + : std::true_type {}; +template +struct use_format_as_member< + T, bool_constant>::value>> + : std::true_type {}; + +template > +using use_formatter = + bool_constant<(std::is_class::value || std::is_enum::value || + std::is_union::value || std::is_array::value) && + !has_to_string_view::value && !is_named_arg::value && + !use_format_as::value && !use_format_as_member::value>; + +template > +auto has_formatter_impl(T* p, buffered_context* ctx = nullptr) + -> decltype(formatter().format(*p, *ctx), std::true_type()); +template auto has_formatter_impl(...) -> std::false_type; + +// T can be const-qualified to check if it is const-formattable. +template constexpr auto has_formatter() -> bool { + return decltype(has_formatter_impl(static_cast(nullptr)))::value; +} + +// Maps formatting argument types to natively supported types or user-defined +// types with formatters. Returns void on errors to be SFINAE-friendly. +template struct type_mapper { + static auto map(signed char) -> int; + static auto map(unsigned char) -> unsigned; + static auto map(short) -> int; + static auto map(unsigned short) -> unsigned; + static auto map(int) -> int; + static auto map(unsigned) -> unsigned; + static auto map(long) -> long_type; + static auto map(unsigned long) -> ulong_type; + static auto map(long long) -> long long; + static auto map(unsigned long long) -> unsigned long long; + static auto map(int128_opt) -> int128_opt; + static auto map(uint128_opt) -> uint128_opt; + static auto map(bool) -> bool; + + template + static auto map(bitint) -> conditional_t; + template + static auto map(ubitint) + -> conditional_t; + + template ::value)> + static auto map(T) -> conditional_t< + std::is_same::value || std::is_same::value, Char, void>; + + static auto map(float) -> float; + static auto map(double) -> double; + static auto map(long double) -> long double; + + static auto map(Char*) -> const Char*; + static auto map(const 
Char*) -> const Char*; + template , + FMT_ENABLE_IF(!std::is_pointer::value)> + static auto map(const T&) -> conditional_t::value, + basic_string_view, void>; + + static auto map(void*) -> const void*; + static auto map(const void*) -> const void*; + static auto map(volatile void*) -> const void*; + static auto map(const volatile void*) -> const void*; + static auto map(nullptr_t) -> const void*; + template ::value || + std::is_member_pointer::value)> + static auto map(const T&) -> void; + + template ::value)> + static auto map(const T& x) -> decltype(map(format_as(x))); + template ::value)> + static auto map(const T& x) -> decltype(map(formatter::format_as(x))); + + template ::value)> + static auto map(T&) -> conditional_t(), T&, void>; + + template ::value)> + static auto map(const T& named_arg) -> decltype(map(named_arg.value)); +}; + +// detail:: is used to workaround a bug in MSVC 2017. +template +using mapped_t = decltype(detail::type_mapper::map(std::declval())); + +// A type constant after applying type_mapper. +template +using mapped_type_constant = type_constant, Char>; + +template ::value> +using stored_type_constant = std::integral_constant< + type, Context::builtin_types || TYPE == type::int_type ? TYPE + : type::custom_type>; +// A parse context with extra data used only in compile-time checks. 
+template +class compile_parse_context : public parse_context { + private: + int num_args_; + const type* types_; + using base = parse_context; + + public: + FMT_CONSTEXPR explicit compile_parse_context(basic_string_view fmt, + int num_args, const type* types, + int next_arg_id = 0) + : base(fmt, next_arg_id), num_args_(num_args), types_(types) {} + + constexpr auto num_args() const -> int { return num_args_; } + constexpr auto arg_type(int id) const -> type { return types_[id]; } + + FMT_CONSTEXPR auto next_arg_id() -> int { + int id = base::next_arg_id(); + if (id >= num_args_) report_error("argument not found"); + return id; + } + + FMT_CONSTEXPR void check_arg_id(int id) { + base::check_arg_id(id); + if (id >= num_args_) report_error("argument not found"); + } + using base::check_arg_id; + + FMT_CONSTEXPR void check_dynamic_spec(int arg_id) { + ignore_unused(arg_id); + if (arg_id < num_args_ && types_ && !is_integral_type(types_[arg_id])) + report_error("width/precision is not integer"); + } +}; + +// An argument reference. +template union arg_ref { + FMT_CONSTEXPR arg_ref(int idx = 0) : index(idx) {} + FMT_CONSTEXPR arg_ref(basic_string_view n) : name(n) {} + + int index; + basic_string_view name; +}; + +// Format specifiers with width and precision resolved at formatting rather +// than parsing time to allow reusing the same parsed specifiers with +// different sets of arguments (precompilation of format strings). +template struct dynamic_format_specs : format_specs { + arg_ref width_ref; + arg_ref precision_ref; +}; + +// Converts a character to ASCII. Returns '\0' on conversion failure. +template ::value)> +constexpr auto to_ascii(Char c) -> char { + return c <= 0xff ? static_cast(c) : '\0'; +} + +// Returns the number of code units in a code point or 1 on error. 
+template +FMT_CONSTEXPR auto code_point_length(const Char* begin) -> int { + if (const_check(sizeof(Char) != 1)) return 1; + auto c = static_cast(*begin); + return static_cast((0x3a55000000000000ull >> (2 * (c >> 3))) & 3) + 1; +} + +// Parses the range [begin, end) as an unsigned integer. This function assumes +// that the range is non-empty and the first character is a digit. +template +FMT_CONSTEXPR auto parse_nonnegative_int(const Char*& begin, const Char* end, + int error_value) noexcept -> int { + FMT_ASSERT(begin != end && '0' <= *begin && *begin <= '9', ""); + unsigned value = 0, prev = 0; + auto p = begin; + do { + prev = value; + value = value * 10 + unsigned(*p - '0'); + ++p; + } while (p != end && '0' <= *p && *p <= '9'); + auto num_digits = p - begin; + begin = p; + int digits10 = static_cast(sizeof(int) * CHAR_BIT * 3 / 10); + if (num_digits <= digits10) return static_cast(value); + // Check for overflow. + unsigned max = INT_MAX; + return num_digits == digits10 + 1 && + prev * 10ull + unsigned(p[-1] - '0') <= max + ? 
static_cast(value) + : error_value; +} + +FMT_CONSTEXPR inline auto parse_align(char c) -> align { + switch (c) { + case '<': return align::left; + case '>': return align::right; + case '^': return align::center; + } + return align::none; +} + +template constexpr auto is_name_start(Char c) -> bool { + return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '_'; +} + +template +FMT_CONSTEXPR auto parse_arg_id(const Char* begin, const Char* end, + Handler&& handler) -> const Char* { + Char c = *begin; + if (c >= '0' && c <= '9') { + int index = 0; + if (c != '0') + index = parse_nonnegative_int(begin, end, INT_MAX); + else + ++begin; + if (begin == end || (*begin != '}' && *begin != ':')) + report_error("invalid format string"); + else + handler.on_index(index); + return begin; + } + if (FMT_OPTIMIZE_SIZE > 1 || !is_name_start(c)) { + report_error("invalid format string"); + return begin; + } + auto it = begin; + do { + ++it; + } while (it != end && (is_name_start(*it) || ('0' <= *it && *it <= '9'))); + handler.on_name({begin, to_unsigned(it - begin)}); + return it; +} + +template struct dynamic_spec_handler { + parse_context& ctx; + arg_ref& ref; + arg_id_kind& kind; + + FMT_CONSTEXPR void on_index(int id) { + ref = id; + kind = arg_id_kind::index; + ctx.check_arg_id(id); + ctx.check_dynamic_spec(id); + } + FMT_CONSTEXPR void on_name(basic_string_view id) { + ref = id; + kind = arg_id_kind::name; + ctx.check_arg_id(id); + } +}; + +template struct parse_dynamic_spec_result { + const Char* end; + arg_id_kind kind; +}; + +// Parses integer | "{" [arg_id] "}". 
+template +FMT_CONSTEXPR auto parse_dynamic_spec(const Char* begin, const Char* end, + int& value, arg_ref& ref, + parse_context& ctx) + -> parse_dynamic_spec_result { + FMT_ASSERT(begin != end, ""); + auto kind = arg_id_kind::none; + if ('0' <= *begin && *begin <= '9') { + int val = parse_nonnegative_int(begin, end, -1); + if (val == -1) report_error("number is too big"); + value = val; + } else { + if (*begin == '{') { + ++begin; + if (begin != end) { + Char c = *begin; + if (c == '}' || c == ':') { + int id = ctx.next_arg_id(); + ref = id; + kind = arg_id_kind::index; + ctx.check_dynamic_spec(id); + } else { + begin = parse_arg_id(begin, end, + dynamic_spec_handler{ctx, ref, kind}); + } + } + if (begin != end && *begin == '}') return {++begin, kind}; + } + report_error("invalid format string"); + } + return {begin, kind}; +} + +template +FMT_CONSTEXPR auto parse_width(const Char* begin, const Char* end, + format_specs& specs, arg_ref& width_ref, + parse_context& ctx) -> const Char* { + auto result = parse_dynamic_spec(begin, end, specs.width, width_ref, ctx); + specs.set_dynamic_width(result.kind); + return result.end; +} + +template +FMT_CONSTEXPR auto parse_precision(const Char* begin, const Char* end, + format_specs& specs, + arg_ref& precision_ref, + parse_context& ctx) -> const Char* { + ++begin; + if (begin == end) { + report_error("invalid precision"); + return begin; + } + auto result = + parse_dynamic_spec(begin, end, specs.precision, precision_ref, ctx); + specs.set_dynamic_precision(result.kind); + return result.end; +} + +enum class state { start, align, sign, hash, zero, width, precision, locale }; + +// Parses standard format specifiers. +template +FMT_CONSTEXPR auto parse_format_specs(const Char* begin, const Char* end, + dynamic_format_specs& specs, + parse_context& ctx, type arg_type) + -> const Char* { + auto c = '\0'; + if (end - begin > 1) { + auto next = to_ascii(begin[1]); + c = parse_align(next) == align::none ? 
to_ascii(*begin) : '\0'; + } else { + if (begin == end) return begin; + c = to_ascii(*begin); + } + + struct { + state current_state = state::start; + FMT_CONSTEXPR void operator()(state s, bool valid = true) { + if (current_state >= s || !valid) + report_error("invalid format specifier"); + current_state = s; + } + } enter_state; + + using pres = presentation_type; + constexpr auto integral_set = sint_set | uint_set | bool_set | char_set; + struct { + const Char*& begin; + format_specs& specs; + type arg_type; + + FMT_CONSTEXPR auto operator()(pres pres_type, int set) -> const Char* { + if (!in(arg_type, set)) report_error("invalid format specifier"); + specs.set_type(pres_type); + return begin + 1; + } + } parse_presentation_type{begin, specs, arg_type}; + + for (;;) { + switch (c) { + case '<': + case '>': + case '^': + enter_state(state::align); + specs.set_align(parse_align(c)); + ++begin; + break; + case '+': + case ' ': + specs.set_sign(c == ' ' ? sign::space : sign::plus); + FMT_FALLTHROUGH; + case '-': + enter_state(state::sign, in(arg_type, sint_set | float_set)); + ++begin; + break; + case '#': + enter_state(state::hash, is_arithmetic_type(arg_type)); + specs.set_alt(); + ++begin; + break; + case '0': + enter_state(state::zero); + if (!is_arithmetic_type(arg_type)) + report_error("format specifier requires numeric argument"); + if (specs.align() == align::none) { + // Ignore 0 if align is specified for compatibility with std::format. 
+ specs.set_align(align::numeric); + specs.set_fill('0'); + } + ++begin; + break; + // clang-format off + case '1': case '2': case '3': case '4': case '5': + case '6': case '7': case '8': case '9': case '{': + // clang-format on + enter_state(state::width); + begin = parse_width(begin, end, specs, specs.width_ref, ctx); + break; + case '.': + enter_state(state::precision, + in(arg_type, float_set | string_set | cstring_set)); + begin = parse_precision(begin, end, specs, specs.precision_ref, ctx); + break; + case 'L': + enter_state(state::locale, is_arithmetic_type(arg_type)); + specs.set_localized(); + ++begin; + break; + case 'd': return parse_presentation_type(pres::dec, integral_set); + case 'X': specs.set_upper(); FMT_FALLTHROUGH; + case 'x': return parse_presentation_type(pres::hex, integral_set); + case 'o': return parse_presentation_type(pres::oct, integral_set); + case 'B': specs.set_upper(); FMT_FALLTHROUGH; + case 'b': return parse_presentation_type(pres::bin, integral_set); + case 'E': specs.set_upper(); FMT_FALLTHROUGH; + case 'e': return parse_presentation_type(pres::exp, float_set); + case 'F': specs.set_upper(); FMT_FALLTHROUGH; + case 'f': return parse_presentation_type(pres::fixed, float_set); + case 'G': specs.set_upper(); FMT_FALLTHROUGH; + case 'g': return parse_presentation_type(pres::general, float_set); + case 'A': specs.set_upper(); FMT_FALLTHROUGH; + case 'a': return parse_presentation_type(pres::hexfloat, float_set); + case 'c': + if (arg_type == type::bool_type) report_error("invalid format specifier"); + return parse_presentation_type(pres::chr, integral_set); + case 's': + return parse_presentation_type(pres::string, + bool_set | string_set | cstring_set); + case 'p': + return parse_presentation_type(pres::pointer, pointer_set | cstring_set); + case '?': + return parse_presentation_type(pres::debug, + char_set | string_set | cstring_set); + case '}': return begin; + default: { + if (*begin == '}') return begin; + // Parse fill and 
alignment. + auto fill_end = begin + code_point_length(begin); + if (end - fill_end <= 0) { + report_error("invalid format specifier"); + return begin; + } + if (*begin == '{') { + report_error("invalid fill character '{'"); + return begin; + } + auto alignment = parse_align(to_ascii(*fill_end)); + enter_state(state::align, alignment != align::none); + specs.set_fill( + basic_string_view(begin, to_unsigned(fill_end - begin))); + specs.set_align(alignment); + begin = fill_end + 1; + } + } + if (begin == end) return begin; + c = to_ascii(*begin); + } +} + +template +FMT_CONSTEXPR FMT_INLINE auto parse_replacement_field(const Char* begin, + const Char* end, + Handler&& handler) + -> const Char* { + ++begin; + if (begin == end) { + handler.on_error("invalid format string"); + return end; + } + int arg_id = 0; + switch (*begin) { + case '}': + handler.on_replacement_field(handler.on_arg_id(), begin); + return begin + 1; + case '{': handler.on_text(begin, begin + 1); return begin + 1; + case ':': arg_id = handler.on_arg_id(); break; + default: { + struct id_adapter { + Handler& handler; + int arg_id; + + FMT_CONSTEXPR void on_index(int id) { arg_id = handler.on_arg_id(id); } + FMT_CONSTEXPR void on_name(basic_string_view id) { + arg_id = handler.on_arg_id(id); + } + } adapter = {handler, 0}; + begin = parse_arg_id(begin, end, adapter); + arg_id = adapter.arg_id; + Char c = begin != end ? 
*begin : Char(); + if (c == '}') { + handler.on_replacement_field(arg_id, begin); + return begin + 1; + } + if (c != ':') { + handler.on_error("missing '}' in format string"); + return end; + } + break; + } + } + begin = handler.on_format_specs(arg_id, begin + 1, end); + if (begin == end || *begin != '}') + return handler.on_error("unknown format specifier"), end; + return begin + 1; +} + +template +FMT_CONSTEXPR void parse_format_string(basic_string_view fmt, + Handler&& handler) { + auto begin = fmt.data(), end = begin + fmt.size(); + auto p = begin; + while (p != end) { + auto c = *p++; + if (c == '{') { + handler.on_text(begin, p - 1); + begin = p = parse_replacement_field(p - 1, end, handler); + } else if (c == '}') { + if (p == end || *p != '}') + return handler.on_error("unmatched '}' in format string"); + handler.on_text(begin, p); + begin = ++p; + } + } + handler.on_text(begin, end); +} + +// Checks char specs and returns true iff the presentation type is char-like. +FMT_CONSTEXPR inline auto check_char_specs(const format_specs& specs) -> bool { + auto type = specs.type(); + if (type != presentation_type::none && type != presentation_type::chr && + type != presentation_type::debug) { + return false; + } + if (specs.align() == align::numeric || specs.sign() != sign::none || + specs.alt()) { + report_error("invalid format specifier for char"); + } + return true; +} + +// A base class for compile-time strings. +struct compile_string {}; + +template +FMT_VISIBILITY("hidden") // Suppress an ld warning on macOS (#3769). +FMT_CONSTEXPR auto invoke_parse(parse_context& ctx) -> const Char* { + using mapped_type = remove_cvref_t>; + constexpr bool formattable = + std::is_constructible>::value; + if (!formattable) return ctx.begin(); // Error is reported in the value ctor. 
+ using formatted_type = conditional_t; + return formatter().parse(ctx); +} + +template struct arg_pack {}; + +template +class format_string_checker { + private: + type types_[max_of(1, NUM_ARGS)]; + named_arg_info named_args_[max_of(1, NUM_NAMED_ARGS)]; + compile_parse_context context_; + + using parse_func = auto (*)(parse_context&) -> const Char*; + parse_func parse_funcs_[max_of(1, NUM_ARGS)]; + + public: + template + FMT_CONSTEXPR explicit format_string_checker(basic_string_view fmt, + arg_pack) + : types_{mapped_type_constant::value...}, + named_args_{}, + context_(fmt, NUM_ARGS, types_), + parse_funcs_{&invoke_parse...} { + int arg_index = 0, named_arg_index = 0; + FMT_APPLY_VARIADIC( + init_static_named_arg(named_args_, arg_index, named_arg_index)); + ignore_unused(arg_index, named_arg_index); + } + + FMT_CONSTEXPR void on_text(const Char*, const Char*) {} + + FMT_CONSTEXPR auto on_arg_id() -> int { return context_.next_arg_id(); } + FMT_CONSTEXPR auto on_arg_id(int id) -> int { + context_.check_arg_id(id); + return id; + } + FMT_CONSTEXPR auto on_arg_id(basic_string_view id) -> int { + for (int i = 0; i < NUM_NAMED_ARGS; ++i) { + if (named_args_[i].name == id) return named_args_[i].id; + } + if (!DYNAMIC_NAMES) on_error("argument not found"); + return -1; + } + + FMT_CONSTEXPR void on_replacement_field(int id, const Char* begin) { + on_format_specs(id, begin, begin); // Call parse() on empty specs. + } + + FMT_CONSTEXPR auto on_format_specs(int id, const Char* begin, const Char* end) + -> const Char* { + context_.advance_to(begin); + if (id >= 0 && id < NUM_ARGS) return parse_funcs_[id](context_); + + // If id is out of range, it means we do not know the type and cannot parse + // the format at compile time. Instead, skip over content until we finish + // the format spec, accounting for any nested replacements. 
+ for (int bracket_count = 0; + begin != end && (bracket_count > 0 || *begin != '}'); ++begin) { + if (*begin == '{') + ++bracket_count; + else if (*begin == '}') + --bracket_count; + } + return begin; + } + + FMT_NORETURN FMT_CONSTEXPR void on_error(const char* message) { + report_error(message); + } +}; + +/// A contiguous memory buffer with an optional growing ability. It is an +/// internal class and shouldn't be used directly, only via `memory_buffer`. +template class buffer { + private: + T* ptr_; + size_t size_; + size_t capacity_; + + using grow_fun = void (*)(buffer& buf, size_t capacity); + grow_fun grow_; + + protected: + // Don't initialize ptr_ since it is not accessed to save a few cycles. + FMT_MSC_WARNING(suppress : 26495) + FMT_CONSTEXPR buffer(grow_fun grow, size_t sz) noexcept + : size_(sz), capacity_(sz), grow_(grow) {} + + constexpr buffer(grow_fun grow, T* p = nullptr, size_t sz = 0, + size_t cap = 0) noexcept + : ptr_(p), size_(sz), capacity_(cap), grow_(grow) {} + + FMT_CONSTEXPR20 ~buffer() = default; + buffer(buffer&&) = default; + + /// Sets the buffer data and capacity. + FMT_CONSTEXPR void set(T* buf_data, size_t buf_capacity) noexcept { + ptr_ = buf_data; + capacity_ = buf_capacity; + } + + public: + using value_type = T; + using const_reference = const T&; + + buffer(const buffer&) = delete; + void operator=(const buffer&) = delete; + + auto begin() noexcept -> T* { return ptr_; } + auto end() noexcept -> T* { return ptr_ + size_; } + + auto begin() const noexcept -> const T* { return ptr_; } + auto end() const noexcept -> const T* { return ptr_ + size_; } + + /// Returns the size of this buffer. + constexpr auto size() const noexcept -> size_t { return size_; } + + /// Returns the capacity of this buffer. + constexpr auto capacity() const noexcept -> size_t { return capacity_; } + + /// Returns a pointer to the buffer data (not null-terminated). 
+ FMT_CONSTEXPR auto data() noexcept -> T* { return ptr_; } + FMT_CONSTEXPR auto data() const noexcept -> const T* { return ptr_; } + + /// Clears this buffer. + FMT_CONSTEXPR void clear() { size_ = 0; } + + // Tries resizing the buffer to contain `count` elements. If T is a POD type + // the new elements may not be initialized. + FMT_CONSTEXPR void try_resize(size_t count) { + try_reserve(count); + size_ = min_of(count, capacity_); + } + + // Tries increasing the buffer capacity to `new_capacity`. It can increase the + // capacity by a smaller amount than requested but guarantees there is space + // for at least one additional element either by increasing the capacity or by + // flushing the buffer if it is full. + FMT_CONSTEXPR void try_reserve(size_t new_capacity) { + if (new_capacity > capacity_) grow_(*this, new_capacity); + } + + FMT_CONSTEXPR void push_back(const T& value) { + try_reserve(size_ + 1); + ptr_[size_++] = value; + } + + /// Appends data to the end of the buffer. + template +// Workaround for MSVC2019 to fix error C2893: Failed to specialize function +// template 'void fmt::v11::detail::buffer::append(const U *,const U *)'. +#if !FMT_MSC_VERSION || FMT_MSC_VERSION >= 1940 + FMT_CONSTEXPR20 +#endif + void + append(const U* begin, const U* end) { + while (begin != end) { + auto size = size_; + auto free_cap = capacity_ - size; + auto count = to_unsigned(end - begin); + if (free_cap < count) { + grow_(*this, size + count); + size = size_; + free_cap = capacity_ - size; + count = count < free_cap ? count : free_cap; + } + // A loop is faster than memcpy on small sizes. 
+ T* out = ptr_ + size; + for (size_t i = 0; i < count; ++i) out[i] = begin[i]; + size_ += count; + begin += count; + } + } + + template FMT_CONSTEXPR auto operator[](Idx index) -> T& { + return ptr_[index]; + } + template + FMT_CONSTEXPR auto operator[](Idx index) const -> const T& { + return ptr_[index]; + } +}; + +struct buffer_traits { + constexpr explicit buffer_traits(size_t) {} + constexpr auto count() const -> size_t { return 0; } + constexpr auto limit(size_t size) const -> size_t { return size; } +}; + +class fixed_buffer_traits { + private: + size_t count_ = 0; + size_t limit_; + + public: + constexpr explicit fixed_buffer_traits(size_t limit) : limit_(limit) {} + constexpr auto count() const -> size_t { return count_; } + FMT_CONSTEXPR auto limit(size_t size) -> size_t { + size_t n = limit_ > count_ ? limit_ - count_ : 0; + count_ += size; + return min_of(size, n); + } +}; + +// A buffer that writes to an output iterator when flushed. +template +class iterator_buffer : public Traits, public buffer { + private: + OutputIt out_; + enum { buffer_size = 256 }; + T data_[buffer_size]; + + static FMT_CONSTEXPR void grow(buffer& buf, size_t) { + if (buf.size() == buffer_size) static_cast(buf).flush(); + } + + void flush() { + auto size = this->size(); + this->clear(); + const T* begin = data_; + const T* end = begin + this->limit(size); + while (begin != end) *out_++ = *begin++; + } + + public: + explicit iterator_buffer(OutputIt out, size_t n = buffer_size) + : Traits(n), buffer(grow, data_, 0, buffer_size), out_(out) {} + iterator_buffer(iterator_buffer&& other) noexcept + : Traits(other), + buffer(grow, data_, 0, buffer_size), + out_(other.out_) {} + ~iterator_buffer() { + // Don't crash if flush fails during unwinding. + FMT_TRY { flush(); } + FMT_CATCH(...) 
{} + } + + auto out() -> OutputIt { + flush(); + return out_; + } + auto count() const -> size_t { return Traits::count() + this->size(); } +}; + +template +class iterator_buffer : public fixed_buffer_traits, + public buffer { + private: + T* out_; + enum { buffer_size = 256 }; + T data_[buffer_size]; + + static FMT_CONSTEXPR void grow(buffer& buf, size_t) { + if (buf.size() == buf.capacity()) + static_cast(buf).flush(); + } + + void flush() { + size_t n = this->limit(this->size()); + if (this->data() == out_) { + out_ += n; + this->set(data_, buffer_size); + } + this->clear(); + } + + public: + explicit iterator_buffer(T* out, size_t n = buffer_size) + : fixed_buffer_traits(n), buffer(grow, out, 0, n), out_(out) {} + iterator_buffer(iterator_buffer&& other) noexcept + : fixed_buffer_traits(other), + buffer(static_cast(other)), + out_(other.out_) { + if (this->data() != out_) { + this->set(data_, buffer_size); + this->clear(); + } + } + ~iterator_buffer() { flush(); } + + auto out() -> T* { + flush(); + return out_; + } + auto count() const -> size_t { + return fixed_buffer_traits::count() + this->size(); + } +}; + +template class iterator_buffer : public buffer { + public: + explicit iterator_buffer(T* out, size_t = 0) + : buffer([](buffer&, size_t) {}, out, 0, ~size_t()) {} + + auto out() -> T* { return &*this->end(); } +}; + +template +class container_buffer : public buffer { + private: + using value_type = typename Container::value_type; + + static FMT_CONSTEXPR void grow(buffer& buf, size_t capacity) { + auto& self = static_cast(buf); + self.container.resize(capacity); + self.set(&self.container[0], capacity); + } + + public: + Container& container; + + explicit container_buffer(Container& c) + : buffer(grow, c.size()), container(c) {} +}; + +// A buffer that writes to a container with the contiguous storage. 
+template +class iterator_buffer< + OutputIt, + enable_if_t::value && + is_contiguous::value, + typename OutputIt::container_type::value_type>> + : public container_buffer { + private: + using base = container_buffer; + + public: + explicit iterator_buffer(typename OutputIt::container_type& c) : base(c) {} + explicit iterator_buffer(OutputIt out, size_t = 0) + : base(get_container(out)) {} + + auto out() -> OutputIt { return OutputIt(this->container); } +}; + +// A buffer that counts the number of code units written discarding the output. +template class counting_buffer : public buffer { + private: + enum { buffer_size = 256 }; + T data_[buffer_size]; + size_t count_ = 0; + + static FMT_CONSTEXPR void grow(buffer& buf, size_t) { + if (buf.size() != buffer_size) return; + static_cast(buf).count_ += buf.size(); + buf.clear(); + } + + public: + FMT_CONSTEXPR counting_buffer() : buffer(grow, data_, 0, buffer_size) {} + + constexpr auto count() const noexcept -> size_t { + return count_ + this->size(); + } +}; + +template +struct is_back_insert_iterator> : std::true_type {}; + +template +struct has_back_insert_iterator_container_append : std::false_type {}; +template +struct has_back_insert_iterator_container_append< + OutputIt, InputIt, + void_t()) + .append(std::declval(), + std::declval()))>> : std::true_type {}; + +template +struct has_back_insert_iterator_container_insert_at_end : std::false_type {}; + +template +struct has_back_insert_iterator_container_insert_at_end< + OutputIt, InputIt, + void_t()) + .insert(get_container(std::declval()).end(), + std::declval(), + std::declval()))>> : std::true_type {}; + +// An optimized version of std::copy with the output value type (T). 
+template ::value&& + has_back_insert_iterator_container_append< + OutputIt, InputIt>::value)> +FMT_CONSTEXPR20 auto copy(InputIt begin, InputIt end, OutputIt out) + -> OutputIt { + get_container(out).append(begin, end); + return out; +} + +template ::value && + !has_back_insert_iterator_container_append< + OutputIt, InputIt>::value && + has_back_insert_iterator_container_insert_at_end< + OutputIt, InputIt>::value)> +FMT_CONSTEXPR20 auto copy(InputIt begin, InputIt end, OutputIt out) + -> OutputIt { + auto& c = get_container(out); + c.insert(c.end(), begin, end); + return out; +} + +template ::value && + (has_back_insert_iterator_container_append< + OutputIt, InputIt>::value || + has_back_insert_iterator_container_insert_at_end< + OutputIt, InputIt>::value)))> +FMT_CONSTEXPR auto copy(InputIt begin, InputIt end, OutputIt out) -> OutputIt { + while (begin != end) *out++ = static_cast(*begin++); + return out; +} + +template +FMT_CONSTEXPR auto copy(basic_string_view s, OutputIt out) -> OutputIt { + return copy(s.begin(), s.end(), out); +} + +template +struct is_buffer_appender : std::false_type {}; +template +struct is_buffer_appender< + It, bool_constant< + is_back_insert_iterator::value && + std::is_base_of, + typename It::container_type>::value>> + : std::true_type {}; + +// Maps an output iterator to a buffer. +template ::value)> +auto get_buffer(OutputIt out) -> iterator_buffer { + return iterator_buffer(out); +} +template ::value)> +auto get_buffer(OutputIt out) -> buffer& { + return get_container(out); +} + +template +auto get_iterator(Buf& buf, OutputIt) -> decltype(buf.out()) { + return buf.out(); +} +template +auto get_iterator(buffer&, OutputIt out) -> OutputIt { + return out; +} + +// This type is intentionally undefined, only used for errors. 
+template struct type_is_unformattable_for; + +template struct string_value { + const Char* data; + size_t size; + auto str() const -> basic_string_view { return {data, size}; } +}; + +template struct custom_value { + using char_type = typename Context::char_type; + void* value; + void (*format)(void* arg, parse_context& parse_ctx, Context& ctx); +}; + +template struct named_arg_value { + const named_arg_info* data; + size_t size; +}; + +struct custom_tag {}; + +#if !FMT_BUILTIN_TYPES +# define FMT_BUILTIN , monostate +#else +# define FMT_BUILTIN +#endif + +// A formatting argument value. +template class value { + public: + using char_type = typename Context::char_type; + + union { + monostate no_value; + int int_value; + unsigned uint_value; + long long long_long_value; + unsigned long long ulong_long_value; + int128_opt int128_value; + uint128_opt uint128_value; + bool bool_value; + char_type char_value; + float float_value; + double double_value; + long double long_double_value; + const void* pointer; + string_value string; + custom_value custom; + named_arg_value named_args; + }; + + constexpr FMT_INLINE value() : no_value() {} + constexpr FMT_INLINE value(signed char x) : int_value(x) {} + constexpr FMT_INLINE value(unsigned char x FMT_BUILTIN) : uint_value(x) {} + constexpr FMT_INLINE value(signed short x) : int_value(x) {} + constexpr FMT_INLINE value(unsigned short x FMT_BUILTIN) : uint_value(x) {} + constexpr FMT_INLINE value(int x) : int_value(x) {} + constexpr FMT_INLINE value(unsigned x FMT_BUILTIN) : uint_value(x) {} + FMT_CONSTEXPR FMT_INLINE value(long x FMT_BUILTIN) : value(long_type(x)) {} + FMT_CONSTEXPR FMT_INLINE value(unsigned long x FMT_BUILTIN) + : value(ulong_type(x)) {} + constexpr FMT_INLINE value(long long x FMT_BUILTIN) : long_long_value(x) {} + constexpr FMT_INLINE value(unsigned long long x FMT_BUILTIN) + : ulong_long_value(x) {} + FMT_INLINE value(int128_opt x FMT_BUILTIN) : int128_value(x) {} + FMT_INLINE value(uint128_opt x 
FMT_BUILTIN) : uint128_value(x) {} + constexpr FMT_INLINE value(bool x FMT_BUILTIN) : bool_value(x) {} + + template + constexpr FMT_INLINE value(bitint x FMT_BUILTIN) : long_long_value(x) { + static_assert(N <= 64, "unsupported _BitInt"); + } + template + constexpr FMT_INLINE value(ubitint x FMT_BUILTIN) : ulong_long_value(x) { + static_assert(N <= 64, "unsupported _BitInt"); + } + + template ::value)> + constexpr FMT_INLINE value(T x FMT_BUILTIN) : char_value(x) { + static_assert( + std::is_same::value || std::is_same::value, + "mixing character types is disallowed"); + } + + constexpr FMT_INLINE value(float x FMT_BUILTIN) : float_value(x) {} + constexpr FMT_INLINE value(double x FMT_BUILTIN) : double_value(x) {} + FMT_INLINE value(long double x FMT_BUILTIN) : long_double_value(x) {} + + FMT_CONSTEXPR FMT_INLINE value(char_type* x FMT_BUILTIN) { + string.data = x; + if (is_constant_evaluated()) string.size = 0; + } + FMT_CONSTEXPR FMT_INLINE value(const char_type* x FMT_BUILTIN) { + string.data = x; + if (is_constant_evaluated()) string.size = 0; + } + template , + FMT_ENABLE_IF(!std::is_pointer::value)> + FMT_CONSTEXPR value(const T& x FMT_BUILTIN) { + static_assert(std::is_same::value, + "mixing character types is disallowed"); + auto sv = to_string_view(x); + string.data = sv.data(); + string.size = sv.size(); + } + FMT_INLINE value(void* x FMT_BUILTIN) : pointer(x) {} + FMT_INLINE value(const void* x FMT_BUILTIN) : pointer(x) {} + FMT_INLINE value(volatile void* x FMT_BUILTIN) + : pointer(const_cast(x)) {} + FMT_INLINE value(const volatile void* x FMT_BUILTIN) + : pointer(const_cast(x)) {} + FMT_INLINE value(nullptr_t) : pointer(nullptr) {} + + template ::value || + std::is_member_pointer::value)> + value(const T&) { + // Formatting of arbitrary pointers is disallowed. If you want to format a + // pointer cast it to `void*` or `const void*`. In particular, this forbids + // formatting of `[const] volatile char*` printed as bool by iostreams. 
+ static_assert(sizeof(T) == 0, + "formatting of non-void pointers is disallowed"); + } + + template ::value)> + value(const T& x) : value(format_as(x)) {} + template ::value)> + value(const T& x) : value(formatter::format_as(x)) {} + + template ::value)> + value(const T& named_arg) : value(named_arg.value) {} + + template ::value || !FMT_BUILTIN_TYPES)> + FMT_CONSTEXPR20 FMT_INLINE value(T& x) : value(x, custom_tag()) {} + + FMT_ALWAYS_INLINE value(const named_arg_info* args, size_t size) + : named_args{args, size} {} + + private: + template ())> + FMT_CONSTEXPR value(T& x, custom_tag) { + using value_type = remove_const_t; + // T may overload operator& e.g. std::vector::reference in libc++. + if (!is_constant_evaluated()) { + custom.value = + const_cast(&reinterpret_cast(x)); + } else { + custom.value = nullptr; +#if defined(__cpp_if_constexpr) + if constexpr (std::is_same*>::value) + custom.value = const_cast(&x); +#endif + } + custom.format = format_custom; + } + + template ())> + FMT_CONSTEXPR value(const T&, custom_tag) { + // Cannot format an argument; to make type T formattable provide a + // formatter specialization: https://fmt.dev/latest/api.html#udt. + type_is_unformattable_for _; + } + + // Formats an argument of a custom type, such as a user-defined class. + template + static void format_custom(void* arg, parse_context& parse_ctx, + Context& ctx) { + auto f = formatter(); + parse_ctx.advance_to(f.parse(parse_ctx)); + using qualified_type = + conditional_t(), const T, T>; + // format must be const for compatibility with std::format and compilation. + const auto& cf = f; + ctx.advance_to(cf.format(*static_cast(arg), ctx)); + } +}; + +enum { packed_arg_bits = 4 }; +// Maximum number of arguments with packed types. 
+enum { max_packed_args = 62 / packed_arg_bits }; +enum : unsigned long long { is_unpacked_bit = 1ULL << 63 }; +enum : unsigned long long { has_named_args_bit = 1ULL << 62 }; + +template +struct is_output_iterator : std::false_type {}; + +template <> struct is_output_iterator : std::true_type {}; + +template +struct is_output_iterator< + It, T, + enable_if_t&>()++), + T>::value>> : std::true_type {}; + +template constexpr auto encode_types() -> unsigned long long { + return 0; +} + +template +constexpr auto encode_types() -> unsigned long long { + return static_cast(stored_type_constant::value) | + (encode_types() << packed_arg_bits); +} + +template +constexpr auto make_descriptor() -> unsigned long long { + return NUM_ARGS <= max_packed_args ? encode_types() + : is_unpacked_bit | NUM_ARGS; +} + +template +using arg_t = conditional_t, + basic_format_arg>; + +template +struct named_arg_store { + // args_[0].named_args points to named_args to avoid bloating format_args. + arg_t args[1u + NUM_ARGS]; + named_arg_info + named_args[static_cast(NUM_NAMED_ARGS)]; + + template + FMT_CONSTEXPR FMT_ALWAYS_INLINE named_arg_store(T&... values) + : args{{named_args, NUM_NAMED_ARGS}, values...} { + int arg_index = 0, named_arg_index = 0; + FMT_APPLY_VARIADIC( + init_named_arg(named_args, arg_index, named_arg_index, values)); + } + + named_arg_store(named_arg_store&& rhs) { + args[0] = {named_args, NUM_NAMED_ARGS}; + for (size_t i = 1; i < sizeof(args) / sizeof(*args); ++i) + args[i] = rhs.args[i]; + for (size_t i = 0; i < NUM_NAMED_ARGS; ++i) + named_args[i] = rhs.named_args[i]; + } + + named_arg_store(const named_arg_store& rhs) = delete; + auto operator=(const named_arg_store& rhs) -> named_arg_store& = delete; + auto operator=(named_arg_store&& rhs) -> named_arg_store& = delete; + operator const arg_t*() const { return args + 1; } +}; + +// An array of references to arguments. 
It can be implicitly converted to +// `basic_format_args` for passing into type-erased formatting functions +// such as `vformat`. It is a plain struct to reduce binary size in debug mode. +template +struct format_arg_store { + // +1 to workaround a bug in gcc 7.5 that causes duplicated-branches warning. + using type = + conditional_t[max_of(1, NUM_ARGS)], + named_arg_store>; + type args; +}; + +// TYPE can be different from type_constant, e.g. for __float128. +template struct native_formatter { + private: + dynamic_format_specs specs_; + + public: + using nonlocking = void; + + FMT_CONSTEXPR auto parse(parse_context& ctx) -> const Char* { + if (ctx.begin() == ctx.end() || *ctx.begin() == '}') return ctx.begin(); + auto end = parse_format_specs(ctx.begin(), ctx.end(), specs_, ctx, TYPE); + if (const_check(TYPE == type::char_type)) check_char_specs(specs_); + return end; + } + + template + FMT_CONSTEXPR void set_debug_format(bool set = true) { + specs_.set_type(set ? presentation_type::debug : presentation_type::none); + } + + FMT_PRAGMA_CLANG(diagnostic ignored "-Wundefined-inline") + template + FMT_CONSTEXPR auto format(const T& val, FormatContext& ctx) const + -> decltype(ctx.out()); +}; + +template +struct locking + : bool_constant::value == type::custom_type> {}; +template +struct locking>::nonlocking>> + : std::false_type {}; + +template FMT_CONSTEXPR inline auto is_locking() -> bool { + return locking::value; +} +template +FMT_CONSTEXPR inline auto is_locking() -> bool { + return locking::value || is_locking(); +} + +FMT_API void vformat_to(buffer& buf, string_view fmt, format_args args, + locale_ref loc = {}); + +#if FMT_WIN32 +FMT_API void vprint_mojibake(FILE*, string_view, format_args, bool); +#else // format_args is passed by reference since it is defined later. +inline void vprint_mojibake(FILE*, string_view, const format_args&, bool) {} +#endif +} // namespace detail + +// The main public API. 
+ +template +FMT_CONSTEXPR void parse_context::do_check_arg_id(int arg_id) { + // Argument id is only checked at compile time during parsing because + // formatting has its own validation. + if (detail::is_constant_evaluated() && use_constexpr_cast) { + auto ctx = static_cast*>(this); + if (arg_id >= ctx->num_args()) report_error("argument not found"); + } +} + +template +FMT_CONSTEXPR void parse_context::check_dynamic_spec(int arg_id) { + using detail::compile_parse_context; + if (detail::is_constant_evaluated() && use_constexpr_cast) + static_cast*>(this)->check_dynamic_spec(arg_id); +} + +FMT_BEGIN_EXPORT + +// An output iterator that appends to a buffer. It is used instead of +// back_insert_iterator to reduce symbol sizes and avoid dependency. +template class basic_appender { + protected: + detail::buffer* container; + + public: + using container_type = detail::buffer; + + FMT_CONSTEXPR basic_appender(detail::buffer& buf) : container(&buf) {} + + FMT_CONSTEXPR20 auto operator=(T c) -> basic_appender& { + container->push_back(c); + return *this; + } + FMT_CONSTEXPR20 auto operator*() -> basic_appender& { return *this; } + FMT_CONSTEXPR20 auto operator++() -> basic_appender& { return *this; } + FMT_CONSTEXPR20 auto operator++(int) -> basic_appender { return *this; } +}; + +// A formatting argument. Context is a template parameter for the compiled API +// where output can be unbuffered. 
+template class basic_format_arg { + private: + detail::value value_; + detail::type type_; + + friend class basic_format_args; + + using char_type = typename Context::char_type; + + public: + class handle { + private: + detail::custom_value custom_; + + public: + explicit handle(detail::custom_value custom) : custom_(custom) {} + + void format(parse_context& parse_ctx, Context& ctx) const { + custom_.format(custom_.value, parse_ctx, ctx); + } + }; + + constexpr basic_format_arg() : type_(detail::type::none_type) {} + basic_format_arg(const detail::named_arg_info* args, size_t size) + : value_(args, size) {} + template + basic_format_arg(T&& val) + : value_(val), type_(detail::stored_type_constant::value) {} + + constexpr explicit operator bool() const noexcept { + return type_ != detail::type::none_type; + } + auto type() const -> detail::type { return type_; } + + /** + * Visits an argument dispatching to the appropriate visit method based on + * the argument type. For example, if the argument type is `double` then + * `vis(value)` will be called with the value of type `double`. 
+ */ + template + FMT_CONSTEXPR FMT_INLINE auto visit(Visitor&& vis) const -> decltype(vis(0)) { + using detail::map; + switch (type_) { + case detail::type::none_type: break; + case detail::type::int_type: return vis(value_.int_value); + case detail::type::uint_type: return vis(value_.uint_value); + case detail::type::long_long_type: return vis(value_.long_long_value); + case detail::type::ulong_long_type: return vis(value_.ulong_long_value); + case detail::type::int128_type: return vis(map(value_.int128_value)); + case detail::type::uint128_type: return vis(map(value_.uint128_value)); + case detail::type::bool_type: return vis(value_.bool_value); + case detail::type::char_type: return vis(value_.char_value); + case detail::type::float_type: return vis(value_.float_value); + case detail::type::double_type: return vis(value_.double_value); + case detail::type::long_double_type: return vis(value_.long_double_value); + case detail::type::cstring_type: return vis(value_.string.data); + case detail::type::string_type: return vis(value_.string.str()); + case detail::type::pointer_type: return vis(value_.pointer); + case detail::type::custom_type: return vis(handle(value_.custom)); + } + return vis(monostate()); + } + + auto format_custom(const char_type* parse_begin, + parse_context& parse_ctx, Context& ctx) + -> bool { + if (type_ != detail::type::custom_type) return false; + parse_ctx.advance_to(parse_begin); + value_.custom.format(value_.custom.value, parse_ctx, ctx); + return true; + } +}; + +/** + * A view of a collection of formatting arguments. To avoid lifetime issues it + * should only be used as a parameter type in type-erased functions such as + * `vformat`: + * + * void vlog(fmt::string_view fmt, fmt::format_args args); // OK + * fmt::format_args args = fmt::make_format_args(); // Dangling reference + */ +template class basic_format_args { + private: + // A descriptor that contains information about formatting arguments. 
+ // If the number of arguments is less or equal to max_packed_args then + // argument types are passed in the descriptor. This reduces binary code size + // per formatting function call. + unsigned long long desc_; + union { + // If is_packed() returns true then argument values are stored in values_; + // otherwise they are stored in args_. This is done to improve cache + // locality and reduce compiled code size since storing larger objects + // may require more code (at least on x86-64) even if the same amount of + // data is actually copied to stack. It saves ~10% on the bloat test. + const detail::value* values_; + const basic_format_arg* args_; + }; + + constexpr auto is_packed() const -> bool { + return (desc_ & detail::is_unpacked_bit) == 0; + } + constexpr auto has_named_args() const -> bool { + return (desc_ & detail::has_named_args_bit) != 0; + } + + FMT_CONSTEXPR auto type(int index) const -> detail::type { + int shift = index * detail::packed_arg_bits; + unsigned mask = (1 << detail::packed_arg_bits) - 1; + return static_cast((desc_ >> shift) & mask); + } + + template + using store = + detail::format_arg_store; + + public: + using format_arg = basic_format_arg; + + constexpr basic_format_args() : desc_(0), args_(nullptr) {} + + /// Constructs a `basic_format_args` object from `format_arg_store`. + template + constexpr FMT_ALWAYS_INLINE basic_format_args( + const store& s) + : desc_(DESC | (NUM_NAMED_ARGS != 0 ? +detail::has_named_args_bit : 0)), + values_(s.args) {} + + template detail::max_packed_args)> + constexpr basic_format_args(const store& s) + : desc_(DESC | (NUM_NAMED_ARGS != 0 ? +detail::has_named_args_bit : 0)), + args_(s.args) {} + + /// Constructs a `basic_format_args` object from a dynamic list of arguments. + constexpr basic_format_args(const format_arg* args, int count, + bool has_named = false) + : desc_(detail::is_unpacked_bit | detail::to_unsigned(count) | + (has_named ? 
+detail::has_named_args_bit : 0)), + args_(args) {} + + /// Returns the argument with the specified id. + FMT_CONSTEXPR auto get(int id) const -> format_arg { + auto arg = format_arg(); + if (!is_packed()) { + if (id < max_size()) arg = args_[id]; + return arg; + } + if (static_cast(id) >= detail::max_packed_args) return arg; + arg.type_ = type(id); + if (arg.type_ != detail::type::none_type) arg.value_ = values_[id]; + return arg; + } + + template + auto get(basic_string_view name) const -> format_arg { + int id = get_id(name); + return id >= 0 ? get(id) : format_arg(); + } + + template + FMT_CONSTEXPR auto get_id(basic_string_view name) const -> int { + if (!has_named_args()) return -1; + const auto& named_args = + (is_packed() ? values_[-1] : args_[-1].value_).named_args; + for (size_t i = 0; i < named_args.size; ++i) { + if (named_args.data[i].name == name) return named_args.data[i].id; + } + return -1; + } + + auto max_size() const -> int { + unsigned long long max_packed = detail::max_packed_args; + return static_cast(is_packed() ? max_packed + : desc_ & ~detail::is_unpacked_bit); + } +}; + +// A formatting context. +class context { + private: + appender out_; + format_args args_; + FMT_NO_UNIQUE_ADDRESS locale_ref loc_; + + public: + using char_type = char; ///< The character type for the output. + using iterator = appender; + using format_arg = basic_format_arg; + enum { builtin_types = FMT_BUILTIN_TYPES }; + + /// Constructs a `context` object. References to the arguments are stored + /// in the object so make sure they have appropriate lifetimes. 
+ FMT_CONSTEXPR context(iterator out, format_args args, locale_ref loc = {}) + : out_(out), args_(args), loc_(loc) {} + context(context&&) = default; + context(const context&) = delete; + void operator=(const context&) = delete; + + FMT_CONSTEXPR auto arg(int id) const -> format_arg { return args_.get(id); } + inline auto arg(string_view name) const -> format_arg { + return args_.get(name); + } + FMT_CONSTEXPR auto arg_id(string_view name) const -> int { + return args_.get_id(name); + } + auto args() const -> const format_args& { return args_; } + + // Returns an iterator to the beginning of the output range. + FMT_CONSTEXPR auto out() const -> iterator { return out_; } + + // Advances the begin iterator to `it`. + FMT_CONSTEXPR void advance_to(iterator) {} + + FMT_CONSTEXPR auto locale() const -> locale_ref { return loc_; } +}; + +template struct runtime_format_string { + basic_string_view str; +}; + +/** + * Creates a runtime format string. + * + * **Example**: + * + * // Check format string at runtime instead of compile-time. + * fmt::print(fmt::runtime("{:d}"), "I am not a number"); + */ +inline auto runtime(string_view s) -> runtime_format_string<> { return {{s}}; } + +/// A compile-time format string. Use `format_string` in the public API to +/// prevent type deduction. +template struct fstring { + private: + static constexpr int num_static_named_args = + detail::count_static_named_args(); + + using checker = detail::format_string_checker< + char, static_cast(sizeof...(T)), num_static_named_args, + num_static_named_args != detail::count_named_args()>; + + using arg_pack = detail::arg_pack; + + public: + string_view str; + using t = fstring; + + // Reports a compile-time error if S is not a valid format string for T. 
+ template + FMT_CONSTEVAL FMT_ALWAYS_INLINE fstring(const char (&s)[N]) : str(s, N - 1) { + using namespace detail; + static_assert(count<(is_view>::value && + std::is_reference::value)...>() == 0, + "passing views as lvalues is disallowed"); + if (FMT_USE_CONSTEVAL) parse_format_string(s, checker(s, arg_pack())); +#ifdef FMT_ENFORCE_COMPILE_STRING + static_assert( + FMT_USE_CONSTEVAL && sizeof(s) != 0, + "FMT_ENFORCE_COMPILE_STRING requires format strings to use FMT_STRING"); +#endif + } + template ::value)> + FMT_CONSTEVAL FMT_ALWAYS_INLINE fstring(const S& s) : str(s) { + auto sv = string_view(str); + if (FMT_USE_CONSTEVAL) + detail::parse_format_string(sv, checker(sv, arg_pack())); +#ifdef FMT_ENFORCE_COMPILE_STRING + static_assert( + FMT_USE_CONSTEVAL && sizeof(s) != 0, + "FMT_ENFORCE_COMPILE_STRING requires format strings to use FMT_STRING"); +#endif + } + template ::value&& + std::is_same::value)> + FMT_ALWAYS_INLINE fstring(const S&) : str(S()) { + FMT_CONSTEXPR auto sv = string_view(S()); + FMT_CONSTEXPR int unused = + (parse_format_string(sv, checker(sv, arg_pack())), 0); + detail::ignore_unused(unused); + } + fstring(runtime_format_string<> fmt) : str(fmt.str) {} + + // Returning by reference generates better code in debug mode. + FMT_ALWAYS_INLINE operator const string_view&() const { return str; } + auto get() const -> string_view { return str; } +}; + +template using format_string = typename fstring::t; + +template +using is_formattable = bool_constant::value, int*, T>, Char>, + void>::value>; +#ifdef __cpp_concepts +template +concept formattable = is_formattable, Char>::value; +#endif + +// A formatter specialization for natively supported types. +template +struct formatter::value != + detail::type::custom_type>> + : detail::native_formatter::value> { +}; + +/** + * Constructs an object that stores references to arguments and can be + * implicitly converted to `format_args`. `Context` can be omitted in which case + * it defaults to `context`. 
See `arg` for lifetime considerations. + */ +// Take arguments by lvalue references to avoid some lifetime issues, e.g. +// auto args = make_format_args(std::string()); +template (), + unsigned long long DESC = detail::make_descriptor()> +constexpr FMT_ALWAYS_INLINE auto make_format_args(T&... args) + -> detail::format_arg_store { + // Suppress warnings for pathological types convertible to detail::value. + FMT_PRAGMA_GCC(diagnostic ignored "-Wconversion") + return {{args...}}; +} + +template +using vargs = + detail::format_arg_store(), + detail::make_descriptor()>; + +/** + * Returns a named argument to be used in a formatting function. + * It should only be used in a call to a formatting function. + * + * **Example**: + * + * fmt::print("The answer is {answer}.", fmt::arg("answer", 42)); + */ +template +inline auto arg(const Char* name, const T& arg) -> detail::named_arg { + return {name, arg}; +} + +/// Formats a string and writes the output to `out`. +template , + char>::value)> +auto vformat_to(OutputIt&& out, string_view fmt, format_args args) + -> remove_cvref_t { + auto&& buf = detail::get_buffer(out); + detail::vformat_to(buf, fmt, args, {}); + return detail::get_iterator(buf, out); +} + +/** + * Formats `args` according to specifications in `fmt`, writes the result to + * the output iterator `out` and returns the iterator past the end of the output + * range. `format_to` does not append a terminating null character. + * + * **Example**: + * + * auto out = std::vector(); + * fmt::format_to(std::back_inserter(out), "{}", 42); + */ +template , + char>::value)> +FMT_INLINE auto format_to(OutputIt&& out, format_string fmt, T&&... args) + -> remove_cvref_t { + return vformat_to(out, fmt.str, vargs{{args...}}); +} + +template struct format_to_n_result { + /// Iterator past the end of the output range. + OutputIt out; + /// Total (not truncated) output size. 
+ size_t size; +}; + +template ::value)> +auto vformat_to_n(OutputIt out, size_t n, string_view fmt, format_args args) + -> format_to_n_result { + using traits = detail::fixed_buffer_traits; + auto buf = detail::iterator_buffer(out, n); + detail::vformat_to(buf, fmt, args, {}); + return {buf.out(), buf.count()}; +} + +/** + * Formats `args` according to specifications in `fmt`, writes up to `n` + * characters of the result to the output iterator `out` and returns the total + * (not truncated) output size and the iterator past the end of the output + * range. `format_to_n` does not append a terminating null character. + */ +template ::value)> +FMT_INLINE auto format_to_n(OutputIt out, size_t n, format_string fmt, + T&&... args) -> format_to_n_result { + return vformat_to_n(out, n, fmt.str, vargs{{args...}}); +} + +struct format_to_result { + /// Pointer to just after the last successful write in the array. + char* out; + /// Specifies if the output was truncated. + bool truncated; + + FMT_CONSTEXPR operator char*() const { + // Report truncation to prevent silent data loss. + if (truncated) report_error("output is truncated"); + return out; + } +}; + +template +auto vformat_to(char (&out)[N], string_view fmt, format_args args) + -> format_to_result { + auto result = vformat_to_n(out, N, fmt, args); + return {result.out, result.size > N}; +} + +template +FMT_INLINE auto format_to(char (&out)[N], format_string fmt, T&&... args) + -> format_to_result { + auto result = vformat_to_n(out, N, fmt.str, vargs{{args...}}); + return {result.out, result.size > N}; +} + +/// Returns the number of chars in the output of `format(fmt, args...)`. +template +FMT_NODISCARD FMT_INLINE auto formatted_size(format_string fmt, + T&&... 
args) -> size_t { + auto buf = detail::counting_buffer<>(); + detail::vformat_to(buf, fmt.str, vargs{{args...}}, {}); + return buf.count(); +} + +FMT_API void vprint(string_view fmt, format_args args); +FMT_API void vprint(FILE* f, string_view fmt, format_args args); +FMT_API void vprintln(FILE* f, string_view fmt, format_args args); +FMT_API void vprint_buffered(FILE* f, string_view fmt, format_args args); + +/** + * Formats `args` according to specifications in `fmt` and writes the output + * to `stdout`. + * + * **Example**: + * + * fmt::print("The answer is {}.", 42); + */ +template +FMT_INLINE void print(format_string fmt, T&&... args) { + vargs va = {{args...}}; + if (detail::const_check(!detail::use_utf8)) + return detail::vprint_mojibake(stdout, fmt.str, va, false); + return detail::is_locking() ? vprint_buffered(stdout, fmt.str, va) + : vprint(fmt.str, va); +} + +/** + * Formats `args` according to specifications in `fmt` and writes the + * output to the file `f`. + * + * **Example**: + * + * fmt::print(stderr, "Don't {}!", "panic"); + */ +template +FMT_INLINE void print(FILE* f, format_string fmt, T&&... args) { + vargs va = {{args...}}; + if (detail::const_check(!detail::use_utf8)) + return detail::vprint_mojibake(f, fmt.str, va, false); + return detail::is_locking() ? vprint_buffered(f, fmt.str, va) + : vprint(f, fmt.str, va); +} + +/// Formats `args` according to specifications in `fmt` and writes the output +/// to the file `f` followed by a newline. +template +FMT_INLINE void println(FILE* f, format_string fmt, T&&... args) { + vargs va = {{args...}}; + return detail::const_check(detail::use_utf8) + ? vprintln(f, fmt.str, va) + : detail::vprint_mojibake(f, fmt.str, va, true); +} + +/// Formats `args` according to specifications in `fmt` and writes the output +/// to `stdout` followed by a newline. +template +FMT_INLINE void println(format_string fmt, T&&... 
args) { + return fmt::println(stdout, fmt, static_cast(args)...); +} + +FMT_PRAGMA_GCC(diagnostic pop) +FMT_PRAGMA_CLANG(diagnostic pop) +FMT_PRAGMA_GCC(pop_options) +FMT_END_EXPORT +FMT_END_NAMESPACE + +#ifdef FMT_HEADER_ONLY +# include "format.h" +#endif +#endif // FMT_BASE_H_ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fmt/chrono.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fmt/chrono.h new file mode 100644 index 0000000000000000000000000000000000000000..d4519b16353e00d73bdcc493926625937f8c1808 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fmt/chrono.h @@ -0,0 +1,2251 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Formatting library for C++ - chrono support +// +// Copyright (c) 2012 - present, Victor Zverovich +// All rights reserved. +// +// For the license information refer to format.h. + +#ifndef FMT_CHRONO_H_ +#define FMT_CHRONO_H_ + +#ifndef FMT_MODULE +# include +# include +# include // std::isfinite +# include // std::memcpy +# include +# include +# include +# include +# include +#endif + +#include "format.h" + +FMT_BEGIN_NAMESPACE + +// Enable safe chrono durations, unless explicitly disabled. +#ifndef FMT_SAFE_DURATION_CAST +# define FMT_SAFE_DURATION_CAST 1 +#endif +#if FMT_SAFE_DURATION_CAST + +// For conversion between std::chrono::durations without undefined +// behaviour or erroneous results. +// This is a stripped down version of duration_cast, for inclusion in fmt. +// See https://github.com/pauldreik/safe_duration_cast +// +// Copyright Paul Dreik 2019 +namespace safe_duration_cast { + +// DEPRECATED! 
+template ::value && + std::numeric_limits::is_signed == + std::numeric_limits::is_signed)> +FMT_CONSTEXPR auto lossless_integral_conversion(const From from, int& ec) + -> To { + ec = 0; + using F = std::numeric_limits; + using T = std::numeric_limits; + static_assert(F::is_integer, "From must be integral"); + static_assert(T::is_integer, "To must be integral"); + + // A and B are both signed, or both unsigned. + if (detail::const_check(F::digits <= T::digits)) { + // From fits in To without any problem. + } else { + // From does not always fit in To, resort to a dynamic check. + if (from < (T::min)() || from > (T::max)()) { + // outside range. + ec = 1; + return {}; + } + } + return static_cast(from); +} + +/// Converts From to To, without loss. If the dynamic value of from +/// can't be converted to To without loss, ec is set. +template ::value && + std::numeric_limits::is_signed != + std::numeric_limits::is_signed)> +FMT_CONSTEXPR auto lossless_integral_conversion(const From from, int& ec) + -> To { + ec = 0; + using F = std::numeric_limits; + using T = std::numeric_limits; + static_assert(F::is_integer, "From must be integral"); + static_assert(T::is_integer, "To must be integral"); + + if (detail::const_check(F::is_signed && !T::is_signed)) { + // From may be negative, not allowed! + if (fmt::detail::is_negative(from)) { + ec = 1; + return {}; + } + // From is positive. Can it always fit in To? + if (detail::const_check(F::digits > T::digits) && + from > static_cast(detail::max_value())) { + ec = 1; + return {}; + } + } + + if (detail::const_check(!F::is_signed && T::is_signed && + F::digits >= T::digits) && + from > static_cast(detail::max_value())) { + ec = 1; + return {}; + } + return static_cast(from); // Lossless conversion. 
+} + +template ::value)> +FMT_CONSTEXPR auto lossless_integral_conversion(const From from, int& ec) + -> To { + ec = 0; + return from; +} // function + +// clang-format off +/** + * converts From to To if possible, otherwise ec is set. + * + * input | output + * ---------------------------------|--------------- + * NaN | NaN + * Inf | Inf + * normal, fits in output | converted (possibly lossy) + * normal, does not fit in output | ec is set + * subnormal | best effort + * -Inf | -Inf + */ +// clang-format on +template ::value)> +FMT_CONSTEXPR auto safe_float_conversion(const From from, int& ec) -> To { + ec = 0; + using T = std::numeric_limits; + static_assert(std::is_floating_point::value, "From must be floating"); + static_assert(std::is_floating_point::value, "To must be floating"); + + // catch the only happy case + if (std::isfinite(from)) { + if (from >= T::lowest() && from <= (T::max)()) { + return static_cast(from); + } + // not within range. + ec = 1; + return {}; + } + + // nan and inf will be preserved + return static_cast(from); +} // function + +template ::value)> +FMT_CONSTEXPR auto safe_float_conversion(const From from, int& ec) -> To { + ec = 0; + static_assert(std::is_floating_point::value, "From must be floating"); + return from; +} + +/// Safe duration_cast between floating point durations +template ::value), + FMT_ENABLE_IF(std::is_floating_point::value)> +auto safe_duration_cast(std::chrono::duration from, + int& ec) -> To { + using From = std::chrono::duration; + ec = 0; + + // the basic idea is that we need to convert from count() in the from type + // to count() in the To type, by multiplying it with this: + struct Factor + : std::ratio_divide {}; + + static_assert(Factor::num > 0, "num must be positive"); + static_assert(Factor::den > 0, "den must be positive"); + + // the conversion is like this: multiply from.count() with Factor::num + // /Factor::den and convert it to To::rep, all this without + // overflow/underflow. 
let's start by finding a suitable type that can hold + // both To, From and Factor::num + using IntermediateRep = + typename std::common_type::type; + + // force conversion of From::rep -> IntermediateRep to be safe, + // even if it will never happen be narrowing in this context. + IntermediateRep count = + safe_float_conversion(from.count(), ec); + if (ec) { + return {}; + } + + // multiply with Factor::num without overflow or underflow + if (detail::const_check(Factor::num != 1)) { + constexpr auto max1 = detail::max_value() / + static_cast(Factor::num); + if (count > max1) { + ec = 1; + return {}; + } + constexpr auto min1 = std::numeric_limits::lowest() / + static_cast(Factor::num); + if (count < min1) { + ec = 1; + return {}; + } + count *= static_cast(Factor::num); + } + + // this can't go wrong, right? den>0 is checked earlier. + if (detail::const_check(Factor::den != 1)) { + using common_t = typename std::common_type::type; + count /= static_cast(Factor::den); + } + + // convert to the to type, safely + using ToRep = typename To::rep; + + const ToRep tocount = safe_float_conversion(count, ec); + if (ec) { + return {}; + } + return To{tocount}; +} +} // namespace safe_duration_cast +#endif + +namespace detail { + +// Check if std::chrono::utc_time is available. +#ifdef FMT_USE_UTC_TIME +// Use the provided definition. +#elif defined(__cpp_lib_chrono) +# define FMT_USE_UTC_TIME (__cpp_lib_chrono >= 201907L) +#else +# define FMT_USE_UTC_TIME 0 +#endif +#if FMT_USE_UTC_TIME +using utc_clock = std::chrono::utc_clock; +#else +struct utc_clock { + template void to_sys(T); +}; +#endif + +// Check if std::chrono::local_time is available. +#ifdef FMT_USE_LOCAL_TIME +// Use the provided definition. 
+#elif defined(__cpp_lib_chrono) +# define FMT_USE_LOCAL_TIME (__cpp_lib_chrono >= 201907L) +#else +# define FMT_USE_LOCAL_TIME 0 +#endif +#if FMT_USE_LOCAL_TIME +using local_t = std::chrono::local_t; +#else +struct local_t {}; +#endif + +} // namespace detail + +template +using sys_time = std::chrono::time_point; + +template +using utc_time = std::chrono::time_point; + +template +using local_time = std::chrono::time_point; + +namespace detail { + +// Prevents expansion of a preceding token as a function-style macro. +// Usage: f FMT_NOMACRO() +#define FMT_NOMACRO + +template struct null {}; +inline auto gmtime_r(...) -> null<> { return null<>(); } +inline auto gmtime_s(...) -> null<> { return null<>(); } + +// It is defined here and not in ostream.h because the latter has expensive +// includes. +template class formatbuf : public StreamBuf { + private: + using char_type = typename StreamBuf::char_type; + using streamsize = decltype(std::declval().sputn(nullptr, 0)); + using int_type = typename StreamBuf::int_type; + using traits_type = typename StreamBuf::traits_type; + + buffer& buffer_; + + public: + explicit formatbuf(buffer& buf) : buffer_(buf) {} + + protected: + // The put area is always empty. This makes the implementation simpler and has + // the advantage that the streambuf and the buffer are always in sync and + // sputc never writes into uninitialized memory. A disadvantage is that each + // call to sputc always results in a (virtual) call to overflow. There is no + // disadvantage here for sputn since this always results in a call to xsputn. 
+ + auto overflow(int_type ch) -> int_type override { + if (!traits_type::eq_int_type(ch, traits_type::eof())) + buffer_.push_back(static_cast(ch)); + return ch; + } + + auto xsputn(const char_type* s, streamsize count) -> streamsize override { + buffer_.append(s, s + count); + return count; + } +}; + +inline auto get_classic_locale() -> const std::locale& { + static const auto& locale = std::locale::classic(); + return locale; +} + +template struct codecvt_result { + static constexpr size_t max_size = 32; + CodeUnit buf[max_size]; + CodeUnit* end; +}; + +template +void write_codecvt(codecvt_result& out, string_view in, + const std::locale& loc) { + FMT_PRAGMA_CLANG(diagnostic push) + FMT_PRAGMA_CLANG(diagnostic ignored "-Wdeprecated") + auto& f = std::use_facet>(loc); + FMT_PRAGMA_CLANG(diagnostic pop) + auto mb = std::mbstate_t(); + const char* from_next = nullptr; + auto result = f.in(mb, in.begin(), in.end(), from_next, std::begin(out.buf), + std::end(out.buf), out.end); + if (result != std::codecvt_base::ok) + FMT_THROW(format_error("failed to format time")); +} + +template +auto write_encoded_tm_str(OutputIt out, string_view in, const std::locale& loc) + -> OutputIt { + if (const_check(detail::use_utf8) && loc != get_classic_locale()) { + // char16_t and char32_t codecvts are broken in MSVC (linkage errors) and + // gcc-4. +#if FMT_MSC_VERSION != 0 || \ + (defined(__GLIBCXX__) && \ + (!defined(_GLIBCXX_USE_DUAL_ABI) || _GLIBCXX_USE_DUAL_ABI == 0)) + // The _GLIBCXX_USE_DUAL_ABI macro is always defined in libstdc++ from gcc-5 + // and newer. + using code_unit = wchar_t; +#else + using code_unit = char32_t; +#endif + + using unit_t = codecvt_result; + unit_t unit; + write_codecvt(unit, in, loc); + // In UTF-8 is used one to four one-byte code units. 
+ auto u = + to_utf8>(); + if (!u.convert({unit.buf, to_unsigned(unit.end - unit.buf)})) + FMT_THROW(format_error("failed to format time")); + return copy(u.c_str(), u.c_str() + u.size(), out); + } + return copy(in.data(), in.data() + in.size(), out); +} + +template ::value)> +auto write_tm_str(OutputIt out, string_view sv, const std::locale& loc) + -> OutputIt { + codecvt_result unit; + write_codecvt(unit, sv, loc); + return copy(unit.buf, unit.end, out); +} + +template ::value)> +auto write_tm_str(OutputIt out, string_view sv, const std::locale& loc) + -> OutputIt { + return write_encoded_tm_str(out, sv, loc); +} + +template +inline void do_write(buffer& buf, const std::tm& time, + const std::locale& loc, char format, char modifier) { + auto&& format_buf = formatbuf>(buf); + auto&& os = std::basic_ostream(&format_buf); + os.imbue(loc); + const auto& facet = std::use_facet>(loc); + auto end = facet.put(os, os, Char(' '), &time, format, modifier); + if (end.failed()) FMT_THROW(format_error("failed to format time")); +} + +template ::value)> +auto write(OutputIt out, const std::tm& time, const std::locale& loc, + char format, char modifier = 0) -> OutputIt { + auto&& buf = get_buffer(out); + do_write(buf, time, loc, format, modifier); + return get_iterator(buf, out); +} + +template ::value)> +auto write(OutputIt out, const std::tm& time, const std::locale& loc, + char format, char modifier = 0) -> OutputIt { + auto&& buf = basic_memory_buffer(); + do_write(buf, time, loc, format, modifier); + return write_encoded_tm_str(out, string_view(buf.data(), buf.size()), loc); +} + +template +using is_similar_arithmetic_type = + bool_constant<(std::is_integral::value && std::is_integral::value) || + (std::is_floating_point::value && + std::is_floating_point::value)>; + +FMT_NORETURN inline void throw_duration_error() { + FMT_THROW(format_error("cannot format duration")); +} + +// Cast one integral duration to another with an overflow check. 
+template ::value&& + std::is_integral::value)> +auto duration_cast(std::chrono::duration from) -> To { +#if !FMT_SAFE_DURATION_CAST + return std::chrono::duration_cast(from); +#else + // The conversion factor: to.count() == factor * from.count(). + using factor = std::ratio_divide; + + using common_rep = typename std::common_type::type; + common_rep count = from.count(); // This conversion is lossless. + + // Multiply from.count() by factor and check for overflow. + if (const_check(factor::num != 1)) { + if (count > max_value() / factor::num) throw_duration_error(); + const auto min = (std::numeric_limits::min)() / factor::num; + if (const_check(!std::is_unsigned::value) && count < min) + throw_duration_error(); + count *= factor::num; + } + if (const_check(factor::den != 1)) count /= factor::den; + int ec = 0; + auto to = + To(safe_duration_cast::lossless_integral_conversion( + count, ec)); + if (ec) throw_duration_error(); + return to; +#endif +} + +template ::value&& + std::is_floating_point::value)> +auto duration_cast(std::chrono::duration from) -> To { +#if FMT_SAFE_DURATION_CAST + // Preserve infinity and NaN. + if (!isfinite(from.count())) return static_cast(from.count()); + // Throwing version of safe_duration_cast is only available for + // integer to integer or float to float casts. + int ec; + To to = safe_duration_cast::safe_duration_cast(from, ec); + if (ec) throw_duration_error(); + return to; +#else + // Standard duration cast, may overflow. + return std::chrono::duration_cast(from); +#endif +} + +template ::value)> +auto duration_cast(std::chrono::duration from) -> To { + // Mixed integer <-> float cast is not supported by safe duration_cast. + return std::chrono::duration_cast(from); +} + +template +auto to_time_t(sys_time time_point) -> std::time_t { + // Cannot use std::chrono::system_clock::to_time_t since this would first + // require a cast to std::chrono::system_clock::time_point, which could + // overflow. 
+ return detail::duration_cast>( + time_point.time_since_epoch()) + .count(); +} + +} // namespace detail + +FMT_BEGIN_EXPORT + +/** + * Converts given time since epoch as `std::time_t` value into calendar time, + * expressed in Coordinated Universal Time (UTC). Unlike `std::gmtime`, this + * function is thread-safe on most platforms. + */ +inline auto gmtime(std::time_t time) -> std::tm { + struct dispatcher { + std::time_t time_; + std::tm tm_; + + inline dispatcher(std::time_t t) : time_(t) {} + + inline auto run() -> bool { + using namespace fmt::detail; + return handle(gmtime_r(&time_, &tm_)); + } + + inline auto handle(std::tm* tm) -> bool { return tm != nullptr; } + + inline auto handle(detail::null<>) -> bool { + using namespace fmt::detail; + return fallback(gmtime_s(&tm_, &time_)); + } + + inline auto fallback(int res) -> bool { return res == 0; } + +#if !FMT_MSC_VERSION + inline auto fallback(detail::null<>) -> bool { + std::tm* tm = std::gmtime(&time_); + if (tm) tm_ = *tm; + return tm != nullptr; + } +#endif + }; + auto gt = dispatcher(time); + // Too big time values may be unsupported. + if (!gt.run()) FMT_THROW(format_error("time_t value out of range")); + return gt.tm_; +} + +template +inline auto gmtime(sys_time time_point) -> std::tm { + return gmtime(detail::to_time_t(time_point)); +} + +namespace detail { + +// Writes two-digit numbers a, b and c separated by sep to buf. +// The method by Pavel Novikov based on +// https://johnnylee-sde.github.io/Fast-unsigned-integer-to-time-string/. +inline void write_digit2_separated(char* buf, unsigned a, unsigned b, + unsigned c, char sep) { + unsigned long long digits = + a | (b << 24) | (static_cast(c) << 48); + // Convert each value to BCD. + // We have x = a * 10 + b and we want to convert it to BCD y = a * 16 + b. 
+ // The difference is + // y - x = a * 6 + // a can be found from x: + // a = floor(x / 10) + // then + // y = x + a * 6 = x + floor(x / 10) * 6 + // floor(x / 10) is (x * 205) >> 11 (needs 16 bits). + digits += (((digits * 205) >> 11) & 0x000f00000f00000f) * 6; + // Put low nibbles to high bytes and high nibbles to low bytes. + digits = ((digits & 0x00f00000f00000f0) >> 4) | + ((digits & 0x000f00000f00000f) << 8); + auto usep = static_cast(sep); + // Add ASCII '0' to each digit byte and insert separators. + digits |= 0x3030003030003030 | (usep << 16) | (usep << 40); + + constexpr size_t len = 8; + if (const_check(is_big_endian())) { + char tmp[len]; + std::memcpy(tmp, &digits, len); + std::reverse_copy(tmp, tmp + len, buf); + } else { + std::memcpy(buf, &digits, len); + } +} + +template +FMT_CONSTEXPR inline auto get_units() -> const char* { + if (std::is_same::value) return "as"; + if (std::is_same::value) return "fs"; + if (std::is_same::value) return "ps"; + if (std::is_same::value) return "ns"; + if (std::is_same::value) + return detail::use_utf8 ? "µs" : "us"; + if (std::is_same::value) return "ms"; + if (std::is_same::value) return "cs"; + if (std::is_same::value) return "ds"; + if (std::is_same>::value) return "s"; + if (std::is_same::value) return "das"; + if (std::is_same::value) return "hs"; + if (std::is_same::value) return "ks"; + if (std::is_same::value) return "Ms"; + if (std::is_same::value) return "Gs"; + if (std::is_same::value) return "Ts"; + if (std::is_same::value) return "Ps"; + if (std::is_same::value) return "Es"; + if (std::is_same>::value) return "min"; + if (std::is_same>::value) return "h"; + if (std::is_same>::value) return "d"; + return nullptr; +} + +enum class numeric_system { + standard, + // Alternative numeric system, e.g. 十二 instead of 12 in ja_JP locale. + alternative +}; + +// Glibc extensions for formatting numeric values. +enum class pad_type { + // Pad a numeric result string with zeros (the default). 
+ zero, + // Do not pad a numeric result string. + none, + // Pad a numeric result string with spaces. + space, +}; + +template +auto write_padding(OutputIt out, pad_type pad, int width) -> OutputIt { + if (pad == pad_type::none) return out; + return detail::fill_n(out, width, pad == pad_type::space ? ' ' : '0'); +} + +template +auto write_padding(OutputIt out, pad_type pad) -> OutputIt { + if (pad != pad_type::none) *out++ = pad == pad_type::space ? ' ' : '0'; + return out; +} + +// Parses a put_time-like format string and invokes handler actions. +template +FMT_CONSTEXPR auto parse_chrono_format(const Char* begin, const Char* end, + Handler&& handler) -> const Char* { + if (begin == end || *begin == '}') return begin; + if (*begin != '%') FMT_THROW(format_error("invalid format")); + auto ptr = begin; + while (ptr != end) { + pad_type pad = pad_type::zero; + auto c = *ptr; + if (c == '}') break; + if (c != '%') { + ++ptr; + continue; + } + if (begin != ptr) handler.on_text(begin, ptr); + ++ptr; // consume '%' + if (ptr == end) FMT_THROW(format_error("invalid format")); + c = *ptr; + switch (c) { + case '_': + pad = pad_type::space; + ++ptr; + break; + case '-': + pad = pad_type::none; + ++ptr; + break; + } + if (ptr == end) FMT_THROW(format_error("invalid format")); + c = *ptr++; + switch (c) { + case '%': handler.on_text(ptr - 1, ptr); break; + case 'n': { + const Char newline[] = {'\n'}; + handler.on_text(newline, newline + 1); + break; + } + case 't': { + const Char tab[] = {'\t'}; + handler.on_text(tab, tab + 1); + break; + } + // Year: + case 'Y': handler.on_year(numeric_system::standard, pad); break; + case 'y': handler.on_short_year(numeric_system::standard); break; + case 'C': handler.on_century(numeric_system::standard); break; + case 'G': handler.on_iso_week_based_year(); break; + case 'g': handler.on_iso_week_based_short_year(); break; + // Day of the week: + case 'a': handler.on_abbr_weekday(); break; + case 'A': handler.on_full_weekday(); break; + 
case 'w': handler.on_dec0_weekday(numeric_system::standard); break; + case 'u': handler.on_dec1_weekday(numeric_system::standard); break; + // Month: + case 'b': + case 'h': handler.on_abbr_month(); break; + case 'B': handler.on_full_month(); break; + case 'm': handler.on_dec_month(numeric_system::standard, pad); break; + // Day of the year/month: + case 'U': + handler.on_dec0_week_of_year(numeric_system::standard, pad); + break; + case 'W': + handler.on_dec1_week_of_year(numeric_system::standard, pad); + break; + case 'V': handler.on_iso_week_of_year(numeric_system::standard, pad); break; + case 'j': handler.on_day_of_year(pad); break; + case 'd': handler.on_day_of_month(numeric_system::standard, pad); break; + case 'e': + handler.on_day_of_month(numeric_system::standard, pad_type::space); + break; + // Hour, minute, second: + case 'H': handler.on_24_hour(numeric_system::standard, pad); break; + case 'I': handler.on_12_hour(numeric_system::standard, pad); break; + case 'M': handler.on_minute(numeric_system::standard, pad); break; + case 'S': handler.on_second(numeric_system::standard, pad); break; + // Other: + case 'c': handler.on_datetime(numeric_system::standard); break; + case 'x': handler.on_loc_date(numeric_system::standard); break; + case 'X': handler.on_loc_time(numeric_system::standard); break; + case 'D': handler.on_us_date(); break; + case 'F': handler.on_iso_date(); break; + case 'r': handler.on_12_hour_time(); break; + case 'R': handler.on_24_hour_time(); break; + case 'T': handler.on_iso_time(); break; + case 'p': handler.on_am_pm(); break; + case 'Q': handler.on_duration_value(); break; + case 'q': handler.on_duration_unit(); break; + case 'z': handler.on_utc_offset(numeric_system::standard); break; + case 'Z': handler.on_tz_name(); break; + // Alternative representation: + case 'E': { + if (ptr == end) FMT_THROW(format_error("invalid format")); + c = *ptr++; + switch (c) { + case 'Y': handler.on_year(numeric_system::alternative, pad); break; + case 
'y': handler.on_offset_year(); break; + case 'C': handler.on_century(numeric_system::alternative); break; + case 'c': handler.on_datetime(numeric_system::alternative); break; + case 'x': handler.on_loc_date(numeric_system::alternative); break; + case 'X': handler.on_loc_time(numeric_system::alternative); break; + case 'z': handler.on_utc_offset(numeric_system::alternative); break; + default: FMT_THROW(format_error("invalid format")); + } + break; + } + case 'O': + if (ptr == end) FMT_THROW(format_error("invalid format")); + c = *ptr++; + switch (c) { + case 'y': handler.on_short_year(numeric_system::alternative); break; + case 'm': handler.on_dec_month(numeric_system::alternative, pad); break; + case 'U': + handler.on_dec0_week_of_year(numeric_system::alternative, pad); + break; + case 'W': + handler.on_dec1_week_of_year(numeric_system::alternative, pad); + break; + case 'V': + handler.on_iso_week_of_year(numeric_system::alternative, pad); + break; + case 'd': + handler.on_day_of_month(numeric_system::alternative, pad); + break; + case 'e': + handler.on_day_of_month(numeric_system::alternative, pad_type::space); + break; + case 'w': handler.on_dec0_weekday(numeric_system::alternative); break; + case 'u': handler.on_dec1_weekday(numeric_system::alternative); break; + case 'H': handler.on_24_hour(numeric_system::alternative, pad); break; + case 'I': handler.on_12_hour(numeric_system::alternative, pad); break; + case 'M': handler.on_minute(numeric_system::alternative, pad); break; + case 'S': handler.on_second(numeric_system::alternative, pad); break; + case 'z': handler.on_utc_offset(numeric_system::alternative); break; + default: FMT_THROW(format_error("invalid format")); + } + break; + default: FMT_THROW(format_error("invalid format")); + } + begin = ptr; + } + if (begin != ptr) handler.on_text(begin, ptr); + return ptr; +} + +template struct null_chrono_spec_handler { + FMT_CONSTEXPR void unsupported() { + static_cast(this)->unsupported(); + } + FMT_CONSTEXPR void 
on_year(numeric_system, pad_type) { unsupported(); } + FMT_CONSTEXPR void on_short_year(numeric_system) { unsupported(); } + FMT_CONSTEXPR void on_offset_year() { unsupported(); } + FMT_CONSTEXPR void on_century(numeric_system) { unsupported(); } + FMT_CONSTEXPR void on_iso_week_based_year() { unsupported(); } + FMT_CONSTEXPR void on_iso_week_based_short_year() { unsupported(); } + FMT_CONSTEXPR void on_abbr_weekday() { unsupported(); } + FMT_CONSTEXPR void on_full_weekday() { unsupported(); } + FMT_CONSTEXPR void on_dec0_weekday(numeric_system) { unsupported(); } + FMT_CONSTEXPR void on_dec1_weekday(numeric_system) { unsupported(); } + FMT_CONSTEXPR void on_abbr_month() { unsupported(); } + FMT_CONSTEXPR void on_full_month() { unsupported(); } + FMT_CONSTEXPR void on_dec_month(numeric_system, pad_type) { unsupported(); } + FMT_CONSTEXPR void on_dec0_week_of_year(numeric_system, pad_type) { + unsupported(); + } + FMT_CONSTEXPR void on_dec1_week_of_year(numeric_system, pad_type) { + unsupported(); + } + FMT_CONSTEXPR void on_iso_week_of_year(numeric_system, pad_type) { + unsupported(); + } + FMT_CONSTEXPR void on_day_of_year(pad_type) { unsupported(); } + FMT_CONSTEXPR void on_day_of_month(numeric_system, pad_type) { + unsupported(); + } + FMT_CONSTEXPR void on_24_hour(numeric_system) { unsupported(); } + FMT_CONSTEXPR void on_12_hour(numeric_system) { unsupported(); } + FMT_CONSTEXPR void on_minute(numeric_system) { unsupported(); } + FMT_CONSTEXPR void on_second(numeric_system) { unsupported(); } + FMT_CONSTEXPR void on_datetime(numeric_system) { unsupported(); } + FMT_CONSTEXPR void on_loc_date(numeric_system) { unsupported(); } + FMT_CONSTEXPR void on_loc_time(numeric_system) { unsupported(); } + FMT_CONSTEXPR void on_us_date() { unsupported(); } + FMT_CONSTEXPR void on_iso_date() { unsupported(); } + FMT_CONSTEXPR void on_12_hour_time() { unsupported(); } + FMT_CONSTEXPR void on_24_hour_time() { unsupported(); } + FMT_CONSTEXPR void on_iso_time() { 
unsupported(); } + FMT_CONSTEXPR void on_am_pm() { unsupported(); } + FMT_CONSTEXPR void on_duration_value() { unsupported(); } + FMT_CONSTEXPR void on_duration_unit() { unsupported(); } + FMT_CONSTEXPR void on_utc_offset(numeric_system) { unsupported(); } + FMT_CONSTEXPR void on_tz_name() { unsupported(); } +}; + +class tm_format_checker : public null_chrono_spec_handler { + private: + bool has_timezone_ = false; + + public: + constexpr explicit tm_format_checker(bool has_timezone) + : has_timezone_(has_timezone) {} + + FMT_NORETURN inline void unsupported() { + FMT_THROW(format_error("no format")); + } + + template + FMT_CONSTEXPR void on_text(const Char*, const Char*) {} + FMT_CONSTEXPR void on_year(numeric_system, pad_type) {} + FMT_CONSTEXPR void on_short_year(numeric_system) {} + FMT_CONSTEXPR void on_offset_year() {} + FMT_CONSTEXPR void on_century(numeric_system) {} + FMT_CONSTEXPR void on_iso_week_based_year() {} + FMT_CONSTEXPR void on_iso_week_based_short_year() {} + FMT_CONSTEXPR void on_abbr_weekday() {} + FMT_CONSTEXPR void on_full_weekday() {} + FMT_CONSTEXPR void on_dec0_weekday(numeric_system) {} + FMT_CONSTEXPR void on_dec1_weekday(numeric_system) {} + FMT_CONSTEXPR void on_abbr_month() {} + FMT_CONSTEXPR void on_full_month() {} + FMT_CONSTEXPR void on_dec_month(numeric_system, pad_type) {} + FMT_CONSTEXPR void on_dec0_week_of_year(numeric_system, pad_type) {} + FMT_CONSTEXPR void on_dec1_week_of_year(numeric_system, pad_type) {} + FMT_CONSTEXPR void on_iso_week_of_year(numeric_system, pad_type) {} + FMT_CONSTEXPR void on_day_of_year(pad_type) {} + FMT_CONSTEXPR void on_day_of_month(numeric_system, pad_type) {} + FMT_CONSTEXPR void on_24_hour(numeric_system, pad_type) {} + FMT_CONSTEXPR void on_12_hour(numeric_system, pad_type) {} + FMT_CONSTEXPR void on_minute(numeric_system, pad_type) {} + FMT_CONSTEXPR void on_second(numeric_system, pad_type) {} + FMT_CONSTEXPR void on_datetime(numeric_system) {} + FMT_CONSTEXPR void on_loc_date(numeric_system) 
{} + FMT_CONSTEXPR void on_loc_time(numeric_system) {} + FMT_CONSTEXPR void on_us_date() {} + FMT_CONSTEXPR void on_iso_date() {} + FMT_CONSTEXPR void on_12_hour_time() {} + FMT_CONSTEXPR void on_24_hour_time() {} + FMT_CONSTEXPR void on_iso_time() {} + FMT_CONSTEXPR void on_am_pm() {} + FMT_CONSTEXPR void on_utc_offset(numeric_system) { + if (!has_timezone_) FMT_THROW(format_error("no timezone")); + } + FMT_CONSTEXPR void on_tz_name() { + if (!has_timezone_) FMT_THROW(format_error("no timezone")); + } +}; + +inline auto tm_wday_full_name(int wday) -> const char* { + static constexpr const char* full_name_list[] = { + "Sunday", "Monday", "Tuesday", "Wednesday", + "Thursday", "Friday", "Saturday"}; + return wday >= 0 && wday <= 6 ? full_name_list[wday] : "?"; +} +inline auto tm_wday_short_name(int wday) -> const char* { + static constexpr const char* short_name_list[] = {"Sun", "Mon", "Tue", "Wed", + "Thu", "Fri", "Sat"}; + return wday >= 0 && wday <= 6 ? short_name_list[wday] : "???"; +} + +inline auto tm_mon_full_name(int mon) -> const char* { + static constexpr const char* full_name_list[] = { + "January", "February", "March", "April", "May", "June", + "July", "August", "September", "October", "November", "December"}; + return mon >= 0 && mon <= 11 ? full_name_list[mon] : "?"; +} +inline auto tm_mon_short_name(int mon) -> const char* { + static constexpr const char* short_name_list[] = { + "Jan", "Feb", "Mar", "Apr", "May", "Jun", + "Jul", "Aug", "Sep", "Oct", "Nov", "Dec", + }; + return mon >= 0 && mon <= 11 ? 
short_name_list[mon] : "???"; +} + +template +struct has_tm_gmtoff : std::false_type {}; +template +struct has_tm_gmtoff> : std::true_type {}; + +template struct has_tm_zone : std::false_type {}; +template +struct has_tm_zone> : std::true_type {}; + +template ::value)> +auto set_tm_zone(T& time, char* tz) -> bool { + time.tm_zone = tz; + return true; +} +template ::value)> +auto set_tm_zone(T&, char*) -> bool { + return false; +} + +inline auto utc() -> char* { + static char tz[] = "UTC"; + return tz; +} + +// Converts value to Int and checks that it's in the range [0, upper). +template ::value)> +inline auto to_nonnegative_int(T value, Int upper) -> Int { + if (!std::is_unsigned::value && + (value < 0 || to_unsigned(value) > to_unsigned(upper))) { + FMT_THROW(format_error("chrono value is out of range")); + } + return static_cast(value); +} +template ::value)> +inline auto to_nonnegative_int(T value, Int upper) -> Int { + auto int_value = static_cast(value); + if (int_value < 0 || value > static_cast(upper)) + FMT_THROW(format_error("invalid value")); + return int_value; +} + +constexpr auto pow10(std::uint32_t n) -> long long { + return n == 0 ? 1 : 10 * pow10(n - 1); +} + +// Counts the number of fractional digits in the range [0, 18] according to the +// C++20 spec. If more than 18 fractional digits are required then returns 6 for +// microseconds precision. +template () / 10)> +struct count_fractional_digits { + static constexpr int value = + Num % Den == 0 ? N : count_fractional_digits::value; +}; + +// Base case that doesn't instantiate any more templates +// in order to avoid overflow. +template +struct count_fractional_digits { + static constexpr int value = (Num % Den == 0) ? N : 6; +}; + +// Format subseconds which are given as an integer type with an appropriate +// number of digits. 
+template +void write_fractional_seconds(OutputIt& out, Duration d, int precision = -1) { + constexpr auto num_fractional_digits = + count_fractional_digits::value; + + using subsecond_precision = std::chrono::duration< + typename std::common_type::type, + std::ratio<1, pow10(num_fractional_digits)>>; + + const auto fractional = d - detail::duration_cast(d); + const auto subseconds = + std::chrono::treat_as_floating_point< + typename subsecond_precision::rep>::value + ? fractional.count() + : detail::duration_cast(fractional).count(); + auto n = static_cast>(subseconds); + const int num_digits = count_digits(n); + + int leading_zeroes = (std::max)(0, num_fractional_digits - num_digits); + if (precision < 0) { + FMT_ASSERT(!std::is_floating_point::value, ""); + if (std::ratio_less::value) { + *out++ = '.'; + out = detail::fill_n(out, leading_zeroes, '0'); + out = format_decimal(out, n, num_digits); + } + } else if (precision > 0) { + *out++ = '.'; + leading_zeroes = min_of(leading_zeroes, precision); + int remaining = precision - leading_zeroes; + out = detail::fill_n(out, leading_zeroes, '0'); + if (remaining < num_digits) { + int num_truncated_digits = num_digits - remaining; + n /= to_unsigned(pow10(to_unsigned(num_truncated_digits))); + if (n != 0) out = format_decimal(out, n, remaining); + return; + } + if (n != 0) { + out = format_decimal(out, n, num_digits); + remaining -= num_digits; + } + out = detail::fill_n(out, remaining, '0'); + } +} + +// Format subseconds which are given as a floating point type with an +// appropriate number of digits. We cannot pass the Duration here, as we +// explicitly need to pass the Rep value in the duration_formatter. 
+template +void write_floating_seconds(memory_buffer& buf, Duration duration, + int num_fractional_digits = -1) { + using rep = typename Duration::rep; + FMT_ASSERT(std::is_floating_point::value, ""); + + auto val = duration.count(); + + if (num_fractional_digits < 0) { + // For `std::round` with fallback to `round`: + // On some toolchains `std::round` is not available (e.g. GCC 6). + using namespace std; + num_fractional_digits = + count_fractional_digits::value; + if (num_fractional_digits < 6 && static_cast(round(val)) != val) + num_fractional_digits = 6; + } + + fmt::format_to(std::back_inserter(buf), FMT_STRING("{:.{}f}"), + std::fmod(val * static_cast(Duration::period::num) / + static_cast(Duration::period::den), + static_cast(60)), + num_fractional_digits); +} + +template +class tm_writer { + private: + static constexpr int days_per_week = 7; + + const std::locale& loc_; + bool is_classic_; + OutputIt out_; + const Duration* subsecs_; + const std::tm& tm_; + + auto tm_sec() const noexcept -> int { + FMT_ASSERT(tm_.tm_sec >= 0 && tm_.tm_sec <= 61, ""); + return tm_.tm_sec; + } + auto tm_min() const noexcept -> int { + FMT_ASSERT(tm_.tm_min >= 0 && tm_.tm_min <= 59, ""); + return tm_.tm_min; + } + auto tm_hour() const noexcept -> int { + FMT_ASSERT(tm_.tm_hour >= 0 && tm_.tm_hour <= 23, ""); + return tm_.tm_hour; + } + auto tm_mday() const noexcept -> int { + FMT_ASSERT(tm_.tm_mday >= 1 && tm_.tm_mday <= 31, ""); + return tm_.tm_mday; + } + auto tm_mon() const noexcept -> int { + FMT_ASSERT(tm_.tm_mon >= 0 && tm_.tm_mon <= 11, ""); + return tm_.tm_mon; + } + auto tm_year() const noexcept -> long long { return 1900ll + tm_.tm_year; } + auto tm_wday() const noexcept -> int { + FMT_ASSERT(tm_.tm_wday >= 0 && tm_.tm_wday <= 6, ""); + return tm_.tm_wday; + } + auto tm_yday() const noexcept -> int { + FMT_ASSERT(tm_.tm_yday >= 0 && tm_.tm_yday <= 365, ""); + return tm_.tm_yday; + } + + auto tm_hour12() const noexcept -> int { + auto h = tm_hour(); + auto z = h < 12 
? h : h - 12; + return z == 0 ? 12 : z; + } + + // POSIX and the C Standard are unclear or inconsistent about what %C and %y + // do if the year is negative or exceeds 9999. Use the convention that %C + // concatenated with %y yields the same output as %Y, and that %Y contains at + // least 4 characters, with more only if necessary. + auto split_year_lower(long long year) const noexcept -> int { + auto l = year % 100; + if (l < 0) l = -l; // l in [0, 99] + return static_cast(l); + } + + // Algorithm: https://en.wikipedia.org/wiki/ISO_week_date. + auto iso_year_weeks(long long curr_year) const noexcept -> int { + auto prev_year = curr_year - 1; + auto curr_p = + (curr_year + curr_year / 4 - curr_year / 100 + curr_year / 400) % + days_per_week; + auto prev_p = + (prev_year + prev_year / 4 - prev_year / 100 + prev_year / 400) % + days_per_week; + return 52 + ((curr_p == 4 || prev_p == 3) ? 1 : 0); + } + auto iso_week_num(int tm_yday, int tm_wday) const noexcept -> int { + return (tm_yday + 11 - (tm_wday == 0 ? 
days_per_week : tm_wday)) / + days_per_week; + } + auto tm_iso_week_year() const noexcept -> long long { + auto year = tm_year(); + auto w = iso_week_num(tm_yday(), tm_wday()); + if (w < 1) return year - 1; + if (w > iso_year_weeks(year)) return year + 1; + return year; + } + auto tm_iso_week_of_year() const noexcept -> int { + auto year = tm_year(); + auto w = iso_week_num(tm_yday(), tm_wday()); + if (w < 1) return iso_year_weeks(year - 1); + if (w > iso_year_weeks(year)) return 1; + return w; + } + + void write1(int value) { + *out_++ = static_cast('0' + to_unsigned(value) % 10); + } + void write2(int value) { + const char* d = digits2(to_unsigned(value) % 100); + *out_++ = *d++; + *out_++ = *d; + } + void write2(int value, pad_type pad) { + unsigned int v = to_unsigned(value) % 100; + if (v >= 10) { + const char* d = digits2(v); + *out_++ = *d++; + *out_++ = *d; + } else { + out_ = detail::write_padding(out_, pad); + *out_++ = static_cast('0' + v); + } + } + + void write_year_extended(long long year, pad_type pad) { + // At least 4 characters. 
+ int width = 4; + bool negative = year < 0; + if (negative) { + year = 0 - year; + --width; + } + uint32_or_64_or_128_t n = to_unsigned(year); + const int num_digits = count_digits(n); + if (negative && pad == pad_type::zero) *out_++ = '-'; + if (width > num_digits) + out_ = detail::write_padding(out_, pad, width - num_digits); + if (negative && pad != pad_type::zero) *out_++ = '-'; + out_ = format_decimal(out_, n, num_digits); + } + void write_year(long long year, pad_type pad) { + write_year_extended(year, pad); + } + + void write_utc_offset(long long offset, numeric_system ns) { + if (offset < 0) { + *out_++ = '-'; + offset = -offset; + } else { + *out_++ = '+'; + } + offset /= 60; + write2(static_cast(offset / 60)); + if (ns != numeric_system::standard) *out_++ = ':'; + write2(static_cast(offset % 60)); + } + + template ::value)> + void format_utc_offset(const T& tm, numeric_system ns) { + write_utc_offset(tm.tm_gmtoff, ns); + } + template ::value)> + void format_utc_offset(const T&, numeric_system ns) { + write_utc_offset(0, ns); + } + + template ::value)> + void format_tz_name(const T& tm) { + out_ = write_tm_str(out_, tm.tm_zone, loc_); + } + template ::value)> + void format_tz_name(const T&) { + out_ = std::copy_n(utc(), 3, out_); + } + + void format_localized(char format, char modifier = 0) { + out_ = write(out_, tm_, loc_, format, modifier); + } + + public: + tm_writer(const std::locale& loc, OutputIt out, const std::tm& tm, + const Duration* subsecs = nullptr) + : loc_(loc), + is_classic_(loc_ == get_classic_locale()), + out_(out), + subsecs_(subsecs), + tm_(tm) {} + + auto out() const -> OutputIt { return out_; } + + FMT_CONSTEXPR void on_text(const Char* begin, const Char* end) { + out_ = copy(begin, end, out_); + } + + void on_abbr_weekday() { + if (is_classic_) + out_ = write(out_, tm_wday_short_name(tm_wday())); + else + format_localized('a'); + } + void on_full_weekday() { + if (is_classic_) + out_ = write(out_, tm_wday_full_name(tm_wday())); + 
else + format_localized('A'); + } + void on_dec0_weekday(numeric_system ns) { + if (is_classic_ || ns == numeric_system::standard) return write1(tm_wday()); + format_localized('w', 'O'); + } + void on_dec1_weekday(numeric_system ns) { + if (is_classic_ || ns == numeric_system::standard) { + auto wday = tm_wday(); + write1(wday == 0 ? days_per_week : wday); + } else { + format_localized('u', 'O'); + } + } + + void on_abbr_month() { + if (is_classic_) + out_ = write(out_, tm_mon_short_name(tm_mon())); + else + format_localized('b'); + } + void on_full_month() { + if (is_classic_) + out_ = write(out_, tm_mon_full_name(tm_mon())); + else + format_localized('B'); + } + + void on_datetime(numeric_system ns) { + if (is_classic_) { + on_abbr_weekday(); + *out_++ = ' '; + on_abbr_month(); + *out_++ = ' '; + on_day_of_month(numeric_system::standard, pad_type::space); + *out_++ = ' '; + on_iso_time(); + *out_++ = ' '; + on_year(numeric_system::standard, pad_type::space); + } else { + format_localized('c', ns == numeric_system::standard ? '\0' : 'E'); + } + } + void on_loc_date(numeric_system ns) { + if (is_classic_) + on_us_date(); + else + format_localized('x', ns == numeric_system::standard ? '\0' : 'E'); + } + void on_loc_time(numeric_system ns) { + if (is_classic_) + on_iso_time(); + else + format_localized('X', ns == numeric_system::standard ? 
'\0' : 'E'); + } + void on_us_date() { + char buf[8]; + write_digit2_separated(buf, to_unsigned(tm_mon() + 1), + to_unsigned(tm_mday()), + to_unsigned(split_year_lower(tm_year())), '/'); + out_ = copy(std::begin(buf), std::end(buf), out_); + } + void on_iso_date() { + auto year = tm_year(); + char buf[10]; + size_t offset = 0; + if (year >= 0 && year < 10000) { + write2digits(buf, static_cast(year / 100)); + } else { + offset = 4; + write_year_extended(year, pad_type::zero); + year = 0; + } + write_digit2_separated(buf + 2, static_cast(year % 100), + to_unsigned(tm_mon() + 1), to_unsigned(tm_mday()), + '-'); + out_ = copy(std::begin(buf) + offset, std::end(buf), out_); + } + + void on_utc_offset(numeric_system ns) { format_utc_offset(tm_, ns); } + void on_tz_name() { format_tz_name(tm_); } + + void on_year(numeric_system ns, pad_type pad) { + if (is_classic_ || ns == numeric_system::standard) + return write_year(tm_year(), pad); + format_localized('Y', 'E'); + } + void on_short_year(numeric_system ns) { + if (is_classic_ || ns == numeric_system::standard) + return write2(split_year_lower(tm_year())); + format_localized('y', 'O'); + } + void on_offset_year() { + if (is_classic_) return write2(split_year_lower(tm_year())); + format_localized('y', 'E'); + } + + void on_century(numeric_system ns) { + if (is_classic_ || ns == numeric_system::standard) { + auto year = tm_year(); + auto upper = year / 100; + if (year >= -99 && year < 0) { + // Zero upper on negative year. 
+ *out_++ = '-'; + *out_++ = '0'; + } else if (upper >= 0 && upper < 100) { + write2(static_cast(upper)); + } else { + out_ = write(out_, upper); + } + } else { + format_localized('C', 'E'); + } + } + + void on_dec_month(numeric_system ns, pad_type pad) { + if (is_classic_ || ns == numeric_system::standard) + return write2(tm_mon() + 1, pad); + format_localized('m', 'O'); + } + + void on_dec0_week_of_year(numeric_system ns, pad_type pad) { + if (is_classic_ || ns == numeric_system::standard) + return write2((tm_yday() + days_per_week - tm_wday()) / days_per_week, + pad); + format_localized('U', 'O'); + } + void on_dec1_week_of_year(numeric_system ns, pad_type pad) { + if (is_classic_ || ns == numeric_system::standard) { + auto wday = tm_wday(); + write2((tm_yday() + days_per_week - + (wday == 0 ? (days_per_week - 1) : (wday - 1))) / + days_per_week, + pad); + } else { + format_localized('W', 'O'); + } + } + void on_iso_week_of_year(numeric_system ns, pad_type pad) { + if (is_classic_ || ns == numeric_system::standard) + return write2(tm_iso_week_of_year(), pad); + format_localized('V', 'O'); + } + + void on_iso_week_based_year() { + write_year(tm_iso_week_year(), pad_type::zero); + } + void on_iso_week_based_short_year() { + write2(split_year_lower(tm_iso_week_year())); + } + + void on_day_of_year(pad_type pad) { + auto yday = tm_yday() + 1; + auto digit1 = yday / 100; + if (digit1 != 0) + write1(digit1); + else + out_ = detail::write_padding(out_, pad); + write2(yday % 100, pad); + } + + void on_day_of_month(numeric_system ns, pad_type pad) { + if (is_classic_ || ns == numeric_system::standard) + return write2(tm_mday(), pad); + format_localized('d', 'O'); + } + + void on_24_hour(numeric_system ns, pad_type pad) { + if (is_classic_ || ns == numeric_system::standard) + return write2(tm_hour(), pad); + format_localized('H', 'O'); + } + void on_12_hour(numeric_system ns, pad_type pad) { + if (is_classic_ || ns == numeric_system::standard) + return write2(tm_hour12(), 
pad); + format_localized('I', 'O'); + } + void on_minute(numeric_system ns, pad_type pad) { + if (is_classic_ || ns == numeric_system::standard) + return write2(tm_min(), pad); + format_localized('M', 'O'); + } + + void on_second(numeric_system ns, pad_type pad) { + if (is_classic_ || ns == numeric_system::standard) { + write2(tm_sec(), pad); + if (subsecs_) { + if (std::is_floating_point::value) { + auto buf = memory_buffer(); + write_floating_seconds(buf, *subsecs_); + if (buf.size() > 1) { + // Remove the leading "0", write something like ".123". + out_ = copy(buf.begin() + 1, buf.end(), out_); + } + } else { + write_fractional_seconds(out_, *subsecs_); + } + } + } else { + // Currently no formatting of subseconds when a locale is set. + format_localized('S', 'O'); + } + } + + void on_12_hour_time() { + if (is_classic_) { + char buf[8]; + write_digit2_separated(buf, to_unsigned(tm_hour12()), + to_unsigned(tm_min()), to_unsigned(tm_sec()), ':'); + out_ = copy(std::begin(buf), std::end(buf), out_); + *out_++ = ' '; + on_am_pm(); + } else { + format_localized('r'); + } + } + void on_24_hour_time() { + write2(tm_hour()); + *out_++ = ':'; + write2(tm_min()); + } + void on_iso_time() { + on_24_hour_time(); + *out_++ = ':'; + on_second(numeric_system::standard, pad_type::zero); + } + + void on_am_pm() { + if (is_classic_) { + *out_++ = tm_hour() < 12 ? 'A' : 'P'; + *out_++ = 'M'; + } else { + format_localized('p'); + } + } + + // These apply to chrono durations but not tm. 
+ void on_duration_value() {} + void on_duration_unit() {} +}; + +struct chrono_format_checker : null_chrono_spec_handler { + bool has_precision_integral = false; + + FMT_NORETURN inline void unsupported() { FMT_THROW(format_error("no date")); } + + template + FMT_CONSTEXPR void on_text(const Char*, const Char*) {} + FMT_CONSTEXPR void on_day_of_year(pad_type) {} + FMT_CONSTEXPR void on_24_hour(numeric_system, pad_type) {} + FMT_CONSTEXPR void on_12_hour(numeric_system, pad_type) {} + FMT_CONSTEXPR void on_minute(numeric_system, pad_type) {} + FMT_CONSTEXPR void on_second(numeric_system, pad_type) {} + FMT_CONSTEXPR void on_12_hour_time() {} + FMT_CONSTEXPR void on_24_hour_time() {} + FMT_CONSTEXPR void on_iso_time() {} + FMT_CONSTEXPR void on_am_pm() {} + FMT_CONSTEXPR void on_duration_value() const { + if (has_precision_integral) + FMT_THROW(format_error("precision not allowed for this argument type")); + } + FMT_CONSTEXPR void on_duration_unit() {} +}; + +template ::value&& has_isfinite::value)> +inline auto isfinite(T) -> bool { + return true; +} + +template ::value)> +inline auto mod(T x, int y) -> T { + return x % static_cast(y); +} +template ::value)> +inline auto mod(T x, int y) -> T { + return std::fmod(x, static_cast(y)); +} + +// If T is an integral type, maps T to its unsigned counterpart, otherwise +// leaves it unchanged (unlike std::make_unsigned). +template ::value> +struct make_unsigned_or_unchanged { + using type = T; +}; + +template struct make_unsigned_or_unchanged { + using type = typename std::make_unsigned::type; +}; + +template ::value)> +inline auto get_milliseconds(std::chrono::duration d) + -> std::chrono::duration { + // This may overflow and/or the result may not fit in the target type. +#if FMT_SAFE_DURATION_CAST + using common_seconds_type = + typename std::common_type::type; + auto d_as_common = detail::duration_cast(d); + auto d_as_whole_seconds = + detail::duration_cast(d_as_common); + // This conversion should be nonproblematic. 
+ auto diff = d_as_common - d_as_whole_seconds; + auto ms = detail::duration_cast>(diff); + return ms; +#else + auto s = detail::duration_cast(d); + return detail::duration_cast(d - s); +#endif +} + +template ::value)> +auto format_duration_value(OutputIt out, Rep val, int) -> OutputIt { + return write(out, val); +} + +template ::value)> +auto format_duration_value(OutputIt out, Rep val, int precision) -> OutputIt { + auto specs = format_specs(); + specs.precision = precision; + specs.set_type(precision >= 0 ? presentation_type::fixed + : presentation_type::general); + return write(out, val, specs); +} + +template +auto copy_unit(string_view unit, OutputIt out, Char) -> OutputIt { + return copy(unit.begin(), unit.end(), out); +} + +template +auto copy_unit(string_view unit, OutputIt out, wchar_t) -> OutputIt { + // This works when wchar_t is UTF-32 because units only contain characters + // that have the same representation in UTF-16 and UTF-32. + utf8_to_utf16 u(unit); + return copy(u.c_str(), u.c_str() + u.size(), out); +} + +template +auto format_duration_unit(OutputIt out) -> OutputIt { + if (const char* unit = get_units()) + return copy_unit(string_view(unit), out, Char()); + *out++ = '['; + out = write(out, Period::num); + if (const_check(Period::den != 1)) { + *out++ = '/'; + out = write(out, Period::den); + } + *out++ = ']'; + *out++ = 's'; + return out; +} + +class get_locale { + private: + union { + std::locale locale_; + }; + bool has_locale_ = false; + + public: + inline get_locale(bool localized, locale_ref loc) : has_locale_(localized) { + if (!localized) return; + ignore_unused(loc); + ::new (&locale_) std::locale( +#if FMT_USE_LOCALE + loc.template get() +#endif + ); + } + inline ~get_locale() { + if (has_locale_) locale_.~locale(); + } + inline operator const std::locale&() const { + return has_locale_ ? 
locale_ : get_classic_locale(); + } +}; + +template +struct duration_formatter { + using iterator = basic_appender; + iterator out; + // rep is unsigned to avoid overflow. + using rep = + conditional_t::value && sizeof(Rep) < sizeof(int), + unsigned, typename make_unsigned_or_unchanged::type>; + rep val; + int precision; + locale_ref locale; + bool localized = false; + using seconds = std::chrono::duration; + seconds s; + using milliseconds = std::chrono::duration; + bool negative; + + using tm_writer_type = tm_writer; + + duration_formatter(iterator o, std::chrono::duration d, + locale_ref loc) + : out(o), val(static_cast(d.count())), locale(loc), negative(false) { + if (d.count() < 0) { + val = 0 - val; + negative = true; + } + + // this may overflow and/or the result may not fit in the + // target type. + // might need checked conversion (rep!=Rep) + s = detail::duration_cast(std::chrono::duration(val)); + } + + // returns true if nan or inf, writes to out. + auto handle_nan_inf() -> bool { + if (isfinite(val)) return false; + if (isnan(val)) { + write_nan(); + return true; + } + // must be +-inf + if (val > 0) + std::copy_n("inf", 3, out); + else + std::copy_n("-inf", 4, out); + return true; + } + + auto days() const -> Rep { return static_cast(s.count() / 86400); } + auto hour() const -> Rep { + return static_cast(mod((s.count() / 3600), 24)); + } + + auto hour12() const -> Rep { + Rep hour = static_cast(mod((s.count() / 3600), 12)); + return hour <= 0 ? 
12 : hour; + } + + auto minute() const -> Rep { + return static_cast(mod((s.count() / 60), 60)); + } + auto second() const -> Rep { return static_cast(mod(s.count(), 60)); } + + auto time() const -> std::tm { + auto time = std::tm(); + time.tm_hour = to_nonnegative_int(hour(), 24); + time.tm_min = to_nonnegative_int(minute(), 60); + time.tm_sec = to_nonnegative_int(second(), 60); + return time; + } + + void write_sign() { + if (!negative) return; + *out++ = '-'; + negative = false; + } + + void write(Rep value, int width, pad_type pad = pad_type::zero) { + write_sign(); + if (isnan(value)) return write_nan(); + uint32_or_64_or_128_t n = + to_unsigned(to_nonnegative_int(value, max_value())); + int num_digits = detail::count_digits(n); + if (width > num_digits) { + out = detail::write_padding(out, pad, width - num_digits); + } + out = format_decimal(out, n, num_digits); + } + + void write_nan() { std::copy_n("nan", 3, out); } + + template + void format_tm(const tm& time, Callback cb, Args... args) { + if (isnan(val)) return write_nan(); + get_locale loc(localized, locale); + auto w = tm_writer_type(loc, out, time); + (w.*cb)(args...); + out = w.out(); + } + + void on_text(const Char* begin, const Char* end) { + copy(begin, end, out); + } + + // These are not implemented because durations don't have date information. 
+ void on_abbr_weekday() {} + void on_full_weekday() {} + void on_dec0_weekday(numeric_system) {} + void on_dec1_weekday(numeric_system) {} + void on_abbr_month() {} + void on_full_month() {} + void on_datetime(numeric_system) {} + void on_loc_date(numeric_system) {} + void on_loc_time(numeric_system) {} + void on_us_date() {} + void on_iso_date() {} + void on_utc_offset(numeric_system) {} + void on_tz_name() {} + void on_year(numeric_system, pad_type) {} + void on_short_year(numeric_system) {} + void on_offset_year() {} + void on_century(numeric_system) {} + void on_iso_week_based_year() {} + void on_iso_week_based_short_year() {} + void on_dec_month(numeric_system, pad_type) {} + void on_dec0_week_of_year(numeric_system, pad_type) {} + void on_dec1_week_of_year(numeric_system, pad_type) {} + void on_iso_week_of_year(numeric_system, pad_type) {} + void on_day_of_month(numeric_system, pad_type) {} + + void on_day_of_year(pad_type) { + if (handle_nan_inf()) return; + write(days(), 0); + } + + void on_24_hour(numeric_system ns, pad_type pad) { + if (handle_nan_inf()) return; + + if (ns == numeric_system::standard) return write(hour(), 2, pad); + auto time = tm(); + time.tm_hour = to_nonnegative_int(hour(), 24); + format_tm(time, &tm_writer_type::on_24_hour, ns, pad); + } + + void on_12_hour(numeric_system ns, pad_type pad) { + if (handle_nan_inf()) return; + + if (ns == numeric_system::standard) return write(hour12(), 2, pad); + auto time = tm(); + time.tm_hour = to_nonnegative_int(hour12(), 12); + format_tm(time, &tm_writer_type::on_12_hour, ns, pad); + } + + void on_minute(numeric_system ns, pad_type pad) { + if (handle_nan_inf()) return; + + if (ns == numeric_system::standard) return write(minute(), 2, pad); + auto time = tm(); + time.tm_min = to_nonnegative_int(minute(), 60); + format_tm(time, &tm_writer_type::on_minute, ns, pad); + } + + void on_second(numeric_system ns, pad_type pad) { + if (handle_nan_inf()) return; + + if (ns == numeric_system::standard) { + 
if (std::is_floating_point::value) { + auto buf = memory_buffer(); + write_floating_seconds(buf, std::chrono::duration(val), + precision); + if (negative) *out++ = '-'; + if (buf.size() < 2 || buf[1] == '.') + out = detail::write_padding(out, pad); + out = copy(buf.begin(), buf.end(), out); + } else { + write(second(), 2, pad); + write_fractional_seconds( + out, std::chrono::duration(val), precision); + } + return; + } + auto time = tm(); + time.tm_sec = to_nonnegative_int(second(), 60); + format_tm(time, &tm_writer_type::on_second, ns, pad); + } + + void on_12_hour_time() { + if (handle_nan_inf()) return; + format_tm(time(), &tm_writer_type::on_12_hour_time); + } + + void on_24_hour_time() { + if (handle_nan_inf()) { + *out++ = ':'; + handle_nan_inf(); + return; + } + + write(hour(), 2); + *out++ = ':'; + write(minute(), 2); + } + + void on_iso_time() { + on_24_hour_time(); + *out++ = ':'; + if (handle_nan_inf()) return; + on_second(numeric_system::standard, pad_type::zero); + } + + void on_am_pm() { + if (handle_nan_inf()) return; + format_tm(time(), &tm_writer_type::on_am_pm); + } + + void on_duration_value() { + if (handle_nan_inf()) return; + write_sign(); + out = format_duration_value(out, val, precision); + } + + void on_duration_unit() { out = format_duration_unit(out); } +}; + +} // namespace detail + +#if defined(__cpp_lib_chrono) && __cpp_lib_chrono >= 201907 +using weekday = std::chrono::weekday; +using day = std::chrono::day; +using month = std::chrono::month; +using year = std::chrono::year; +using year_month_day = std::chrono::year_month_day; +#else +// A fallback version of weekday. +class weekday { + private: + unsigned char value_; + + public: + weekday() = default; + constexpr explicit weekday(unsigned wd) noexcept + : value_(static_cast(wd != 7 ? 
wd : 0)) {} + constexpr auto c_encoding() const noexcept -> unsigned { return value_; } +}; + +class day { + private: + unsigned char value_; + + public: + day() = default; + constexpr explicit day(unsigned d) noexcept + : value_(static_cast(d)) {} + constexpr explicit operator unsigned() const noexcept { return value_; } +}; + +class month { + private: + unsigned char value_; + + public: + month() = default; + constexpr explicit month(unsigned m) noexcept + : value_(static_cast(m)) {} + constexpr explicit operator unsigned() const noexcept { return value_; } +}; + +class year { + private: + int value_; + + public: + year() = default; + constexpr explicit year(int y) noexcept : value_(y) {} + constexpr explicit operator int() const noexcept { return value_; } +}; + +class year_month_day { + private: + fmt::year year_; + fmt::month month_; + fmt::day day_; + + public: + year_month_day() = default; + constexpr year_month_day(const year& y, const month& m, const day& d) noexcept + : year_(y), month_(m), day_(d) {} + constexpr auto year() const noexcept -> fmt::year { return year_; } + constexpr auto month() const noexcept -> fmt::month { return month_; } + constexpr auto day() const noexcept -> fmt::day { return day_; } +}; +#endif // __cpp_lib_chrono >= 201907 + +template +struct formatter : private formatter { + private: + bool use_tm_formatter_ = false; + + public: + FMT_CONSTEXPR auto parse(parse_context& ctx) -> const Char* { + auto it = ctx.begin(), end = ctx.end(); + if (it != end && *it == 'L') { + ++it; + this->set_localized(); + } + use_tm_formatter_ = it != end && *it != '}'; + return use_tm_formatter_ ? 
formatter::parse(ctx) : it; + } + + template + auto format(weekday wd, FormatContext& ctx) const -> decltype(ctx.out()) { + auto time = std::tm(); + time.tm_wday = static_cast(wd.c_encoding()); + if (use_tm_formatter_) return formatter::format(time, ctx); + detail::get_locale loc(this->localized(), ctx.locale()); + auto w = detail::tm_writer(loc, ctx.out(), time); + w.on_abbr_weekday(); + return w.out(); + } +}; + +template +struct formatter : private formatter { + private: + bool use_tm_formatter_ = false; + + public: + FMT_CONSTEXPR auto parse(parse_context& ctx) -> const Char* { + auto it = ctx.begin(), end = ctx.end(); + use_tm_formatter_ = it != end && *it != '}'; + return use_tm_formatter_ ? formatter::parse(ctx) : it; + } + + template + auto format(day d, FormatContext& ctx) const -> decltype(ctx.out()) { + auto time = std::tm(); + time.tm_mday = static_cast(static_cast(d)); + if (use_tm_formatter_) return formatter::format(time, ctx); + detail::get_locale loc(false, ctx.locale()); + auto w = detail::tm_writer(loc, ctx.out(), time); + w.on_day_of_month(detail::numeric_system::standard, detail::pad_type::zero); + return w.out(); + } +}; + +template +struct formatter : private formatter { + private: + bool use_tm_formatter_ = false; + + public: + FMT_CONSTEXPR auto parse(parse_context& ctx) -> const Char* { + auto it = ctx.begin(), end = ctx.end(); + if (it != end && *it == 'L') { + ++it; + this->set_localized(); + } + use_tm_formatter_ = it != end && *it != '}'; + return use_tm_formatter_ ? 
formatter::parse(ctx) : it; + } + + template + auto format(month m, FormatContext& ctx) const -> decltype(ctx.out()) { + auto time = std::tm(); + time.tm_mon = static_cast(static_cast(m)) - 1; + if (use_tm_formatter_) return formatter::format(time, ctx); + detail::get_locale loc(this->localized(), ctx.locale()); + auto w = detail::tm_writer(loc, ctx.out(), time); + w.on_abbr_month(); + return w.out(); + } +}; + +template +struct formatter : private formatter { + private: + bool use_tm_formatter_ = false; + + public: + FMT_CONSTEXPR auto parse(parse_context& ctx) -> const Char* { + auto it = ctx.begin(), end = ctx.end(); + use_tm_formatter_ = it != end && *it != '}'; + return use_tm_formatter_ ? formatter::parse(ctx) : it; + } + + template + auto format(year y, FormatContext& ctx) const -> decltype(ctx.out()) { + auto time = std::tm(); + time.tm_year = static_cast(y) - 1900; + if (use_tm_formatter_) return formatter::format(time, ctx); + detail::get_locale loc(false, ctx.locale()); + auto w = detail::tm_writer(loc, ctx.out(), time); + w.on_year(detail::numeric_system::standard, detail::pad_type::zero); + return w.out(); + } +}; + +template +struct formatter : private formatter { + private: + bool use_tm_formatter_ = false; + + public: + FMT_CONSTEXPR auto parse(parse_context& ctx) -> const Char* { + auto it = ctx.begin(), end = ctx.end(); + use_tm_formatter_ = it != end && *it != '}'; + return use_tm_formatter_ ? 
formatter::parse(ctx) : it; + } + + template + auto format(year_month_day val, FormatContext& ctx) const + -> decltype(ctx.out()) { + auto time = std::tm(); + time.tm_year = static_cast(val.year()) - 1900; + time.tm_mon = static_cast(static_cast(val.month())) - 1; + time.tm_mday = static_cast(static_cast(val.day())); + if (use_tm_formatter_) return formatter::format(time, ctx); + detail::get_locale loc(true, ctx.locale()); + auto w = detail::tm_writer(loc, ctx.out(), time); + w.on_iso_date(); + return w.out(); + } +}; + +template +struct formatter, Char> { + private: + format_specs specs_; + detail::arg_ref width_ref_; + detail::arg_ref precision_ref_; + basic_string_view fmt_; + + public: + FMT_CONSTEXPR auto parse(parse_context& ctx) -> const Char* { + auto it = ctx.begin(), end = ctx.end(); + if (it == end || *it == '}') return it; + + it = detail::parse_align(it, end, specs_); + if (it == end) return it; + + Char c = *it; + if ((c >= '0' && c <= '9') || c == '{') { + it = detail::parse_width(it, end, specs_, width_ref_, ctx); + if (it == end) return it; + } + + auto checker = detail::chrono_format_checker(); + if (*it == '.') { + checker.has_precision_integral = !std::is_floating_point::value; + it = detail::parse_precision(it, end, specs_, precision_ref_, ctx); + } + if (it != end && *it == 'L') { + specs_.set_localized(); + ++it; + } + end = detail::parse_chrono_format(it, end, checker); + fmt_ = {it, detail::to_unsigned(end - it)}; + return end; + } + + template + auto format(std::chrono::duration d, FormatContext& ctx) const + -> decltype(ctx.out()) { + auto specs = specs_; + auto precision = specs.precision; + specs.precision = -1; + auto begin = fmt_.begin(), end = fmt_.end(); + // As a possible future optimization, we could avoid extra copying if width + // is not specified. 
+ auto buf = basic_memory_buffer(); + auto out = basic_appender(buf); + detail::handle_dynamic_spec(specs.dynamic_width(), specs.width, width_ref_, + ctx); + detail::handle_dynamic_spec(specs.dynamic_precision(), precision, + precision_ref_, ctx); + if (begin == end || *begin == '}') { + out = detail::format_duration_value(out, d.count(), precision); + detail::format_duration_unit(out); + } else { + auto f = + detail::duration_formatter(out, d, ctx.locale()); + f.precision = precision; + f.localized = specs_.localized(); + detail::parse_chrono_format(begin, end, f); + } + return detail::write( + ctx.out(), basic_string_view(buf.data(), buf.size()), specs); + } +}; + +template struct formatter { + private: + format_specs specs_; + detail::arg_ref width_ref_; + basic_string_view fmt_ = + detail::string_literal(); + + protected: + auto localized() const -> bool { return specs_.localized(); } + FMT_CONSTEXPR void set_localized() { specs_.set_localized(); } + + FMT_CONSTEXPR auto do_parse(parse_context& ctx, bool has_timezone) + -> const Char* { + auto it = ctx.begin(), end = ctx.end(); + if (it == end || *it == '}') return it; + + it = detail::parse_align(it, end, specs_); + if (it == end) return it; + + Char c = *it; + if ((c >= '0' && c <= '9') || c == '{') { + it = detail::parse_width(it, end, specs_, width_ref_, ctx); + if (it == end) return it; + } + + if (*it == 'L') { + specs_.set_localized(); + ++it; + } + + end = detail::parse_chrono_format(it, end, + detail::tm_format_checker(has_timezone)); + // Replace the default format string only if the new spec is not empty. 
+ if (end != it) fmt_ = {it, detail::to_unsigned(end - it)}; + return end; + } + + template + auto do_format(const std::tm& tm, FormatContext& ctx, + const Duration* subsecs) const -> decltype(ctx.out()) { + auto specs = specs_; + auto buf = basic_memory_buffer(); + auto out = basic_appender(buf); + detail::handle_dynamic_spec(specs.dynamic_width(), specs.width, width_ref_, + ctx); + + auto loc_ref = specs.localized() ? ctx.locale() : locale_ref(); + detail::get_locale loc(static_cast(loc_ref), loc_ref); + auto w = detail::tm_writer, Char, Duration>( + loc, out, tm, subsecs); + detail::parse_chrono_format(fmt_.begin(), fmt_.end(), w); + return detail::write( + ctx.out(), basic_string_view(buf.data(), buf.size()), specs); + } + + public: + FMT_CONSTEXPR auto parse(parse_context& ctx) -> const Char* { + return do_parse(ctx, detail::has_tm_gmtoff::value); + } + + template + auto format(const std::tm& tm, FormatContext& ctx) const + -> decltype(ctx.out()) { + return do_format(tm, ctx, nullptr); + } +}; + +// DEPRECATED! Reversed order of template parameters. 
+template +struct formatter, Char> : private formatter { + FMT_CONSTEXPR auto parse(parse_context& ctx) -> const Char* { + return this->do_parse(ctx, true); + } + + template + auto format(sys_time val, FormatContext& ctx) const + -> decltype(ctx.out()) { + std::tm tm = gmtime(val); + using period = typename Duration::period; + if (detail::const_check( + period::num == 1 && period::den == 1 && + !std::is_floating_point::value)) { + detail::set_tm_zone(tm, detail::utc()); + return formatter::format(tm, ctx); + } + Duration epoch = val.time_since_epoch(); + Duration subsecs = detail::duration_cast( + epoch - detail::duration_cast(epoch)); + if (subsecs.count() < 0) { + auto second = detail::duration_cast(std::chrono::seconds(1)); + if (tm.tm_sec != 0) { + --tm.tm_sec; + } else { + tm = gmtime(val - second); + detail::set_tm_zone(tm, detail::utc()); + } + subsecs += second; + } + return formatter::do_format(tm, ctx, &subsecs); + } +}; + +template +struct formatter, Char> + : formatter, Char> { + template + auto format(utc_time val, FormatContext& ctx) const + -> decltype(ctx.out()) { + return formatter, Char>::format( + detail::utc_clock::to_sys(val), ctx); + } +}; + +template +struct formatter, Char> + : private formatter { + FMT_CONSTEXPR auto parse(parse_context& ctx) -> const Char* { + return this->do_parse(ctx, false); + } + + template + auto format(local_time val, FormatContext& ctx) const + -> decltype(ctx.out()) { + auto time_since_epoch = val.time_since_epoch(); + auto seconds_since_epoch = + detail::duration_cast(time_since_epoch); + // Use gmtime to prevent time zone conversion since local_time has an + // unspecified time zone. 
+ std::tm t = gmtime(seconds_since_epoch.count()); + using period = typename Duration::period; + if (period::num == 1 && period::den == 1 && + !std::is_floating_point::value) { + return formatter::format(t, ctx); + } + auto subsecs = + detail::duration_cast(time_since_epoch - seconds_since_epoch); + return formatter::do_format(t, ctx, &subsecs); + } +}; + +FMT_END_EXPORT +FMT_END_NAMESPACE + +#endif // FMT_CHRONO_H_ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fmt/color.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fmt/color.h new file mode 100644 index 0000000000000000000000000000000000000000..3246ddceea28f0df7c347a5661d597e21a149edc --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fmt/color.h @@ -0,0 +1,642 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Formatting library for C++ - color support +// +// Copyright (c) 2018 - present, Victor Zverovich and fmt contributors +// All rights reserved. +// +// For the license information refer to format.h. 
+ +#ifndef FMT_COLOR_H_ +#define FMT_COLOR_H_ + +#include "format.h" + +FMT_BEGIN_NAMESPACE +FMT_BEGIN_EXPORT + +enum class color : uint32_t { + alice_blue = 0xF0F8FF, // rgb(240,248,255) + antique_white = 0xFAEBD7, // rgb(250,235,215) + aqua = 0x00FFFF, // rgb(0,255,255) + aquamarine = 0x7FFFD4, // rgb(127,255,212) + azure = 0xF0FFFF, // rgb(240,255,255) + beige = 0xF5F5DC, // rgb(245,245,220) + bisque = 0xFFE4C4, // rgb(255,228,196) + black = 0x000000, // rgb(0,0,0) + blanched_almond = 0xFFEBCD, // rgb(255,235,205) + blue = 0x0000FF, // rgb(0,0,255) + blue_violet = 0x8A2BE2, // rgb(138,43,226) + brown = 0xA52A2A, // rgb(165,42,42) + burly_wood = 0xDEB887, // rgb(222,184,135) + cadet_blue = 0x5F9EA0, // rgb(95,158,160) + chartreuse = 0x7FFF00, // rgb(127,255,0) + chocolate = 0xD2691E, // rgb(210,105,30) + coral = 0xFF7F50, // rgb(255,127,80) + cornflower_blue = 0x6495ED, // rgb(100,149,237) + cornsilk = 0xFFF8DC, // rgb(255,248,220) + crimson = 0xDC143C, // rgb(220,20,60) + cyan = 0x00FFFF, // rgb(0,255,255) + dark_blue = 0x00008B, // rgb(0,0,139) + dark_cyan = 0x008B8B, // rgb(0,139,139) + dark_golden_rod = 0xB8860B, // rgb(184,134,11) + dark_gray = 0xA9A9A9, // rgb(169,169,169) + dark_green = 0x006400, // rgb(0,100,0) + dark_khaki = 0xBDB76B, // rgb(189,183,107) + dark_magenta = 0x8B008B, // rgb(139,0,139) + dark_olive_green = 0x556B2F, // rgb(85,107,47) + dark_orange = 0xFF8C00, // rgb(255,140,0) + dark_orchid = 0x9932CC, // rgb(153,50,204) + dark_red = 0x8B0000, // rgb(139,0,0) + dark_salmon = 0xE9967A, // rgb(233,150,122) + dark_sea_green = 0x8FBC8F, // rgb(143,188,143) + dark_slate_blue = 0x483D8B, // rgb(72,61,139) + dark_slate_gray = 0x2F4F4F, // rgb(47,79,79) + dark_turquoise = 0x00CED1, // rgb(0,206,209) + dark_violet = 0x9400D3, // rgb(148,0,211) + deep_pink = 0xFF1493, // rgb(255,20,147) + deep_sky_blue = 0x00BFFF, // rgb(0,191,255) + dim_gray = 0x696969, // rgb(105,105,105) + dodger_blue = 0x1E90FF, // rgb(30,144,255) + fire_brick = 0xB22222, // 
rgb(178,34,34) + floral_white = 0xFFFAF0, // rgb(255,250,240) + forest_green = 0x228B22, // rgb(34,139,34) + fuchsia = 0xFF00FF, // rgb(255,0,255) + gainsboro = 0xDCDCDC, // rgb(220,220,220) + ghost_white = 0xF8F8FF, // rgb(248,248,255) + gold = 0xFFD700, // rgb(255,215,0) + golden_rod = 0xDAA520, // rgb(218,165,32) + gray = 0x808080, // rgb(128,128,128) + green = 0x008000, // rgb(0,128,0) + green_yellow = 0xADFF2F, // rgb(173,255,47) + honey_dew = 0xF0FFF0, // rgb(240,255,240) + hot_pink = 0xFF69B4, // rgb(255,105,180) + indian_red = 0xCD5C5C, // rgb(205,92,92) + indigo = 0x4B0082, // rgb(75,0,130) + ivory = 0xFFFFF0, // rgb(255,255,240) + khaki = 0xF0E68C, // rgb(240,230,140) + lavender = 0xE6E6FA, // rgb(230,230,250) + lavender_blush = 0xFFF0F5, // rgb(255,240,245) + lawn_green = 0x7CFC00, // rgb(124,252,0) + lemon_chiffon = 0xFFFACD, // rgb(255,250,205) + light_blue = 0xADD8E6, // rgb(173,216,230) + light_coral = 0xF08080, // rgb(240,128,128) + light_cyan = 0xE0FFFF, // rgb(224,255,255) + light_golden_rod_yellow = 0xFAFAD2, // rgb(250,250,210) + light_gray = 0xD3D3D3, // rgb(211,211,211) + light_green = 0x90EE90, // rgb(144,238,144) + light_pink = 0xFFB6C1, // rgb(255,182,193) + light_salmon = 0xFFA07A, // rgb(255,160,122) + light_sea_green = 0x20B2AA, // rgb(32,178,170) + light_sky_blue = 0x87CEFA, // rgb(135,206,250) + light_slate_gray = 0x778899, // rgb(119,136,153) + light_steel_blue = 0xB0C4DE, // rgb(176,196,222) + light_yellow = 0xFFFFE0, // rgb(255,255,224) + lime = 0x00FF00, // rgb(0,255,0) + lime_green = 0x32CD32, // rgb(50,205,50) + linen = 0xFAF0E6, // rgb(250,240,230) + magenta = 0xFF00FF, // rgb(255,0,255) + maroon = 0x800000, // rgb(128,0,0) + medium_aquamarine = 0x66CDAA, // rgb(102,205,170) + medium_blue = 0x0000CD, // rgb(0,0,205) + medium_orchid = 0xBA55D3, // rgb(186,85,211) + medium_purple = 0x9370DB, // rgb(147,112,219) + medium_sea_green = 0x3CB371, // rgb(60,179,113) + medium_slate_blue = 0x7B68EE, // rgb(123,104,238) + 
medium_spring_green = 0x00FA9A, // rgb(0,250,154) + medium_turquoise = 0x48D1CC, // rgb(72,209,204) + medium_violet_red = 0xC71585, // rgb(199,21,133) + midnight_blue = 0x191970, // rgb(25,25,112) + mint_cream = 0xF5FFFA, // rgb(245,255,250) + misty_rose = 0xFFE4E1, // rgb(255,228,225) + moccasin = 0xFFE4B5, // rgb(255,228,181) + navajo_white = 0xFFDEAD, // rgb(255,222,173) + navy = 0x000080, // rgb(0,0,128) + old_lace = 0xFDF5E6, // rgb(253,245,230) + olive = 0x808000, // rgb(128,128,0) + olive_drab = 0x6B8E23, // rgb(107,142,35) + orange = 0xFFA500, // rgb(255,165,0) + orange_red = 0xFF4500, // rgb(255,69,0) + orchid = 0xDA70D6, // rgb(218,112,214) + pale_golden_rod = 0xEEE8AA, // rgb(238,232,170) + pale_green = 0x98FB98, // rgb(152,251,152) + pale_turquoise = 0xAFEEEE, // rgb(175,238,238) + pale_violet_red = 0xDB7093, // rgb(219,112,147) + papaya_whip = 0xFFEFD5, // rgb(255,239,213) + peach_puff = 0xFFDAB9, // rgb(255,218,185) + peru = 0xCD853F, // rgb(205,133,63) + pink = 0xFFC0CB, // rgb(255,192,203) + plum = 0xDDA0DD, // rgb(221,160,221) + powder_blue = 0xB0E0E6, // rgb(176,224,230) + purple = 0x800080, // rgb(128,0,128) + rebecca_purple = 0x663399, // rgb(102,51,153) + red = 0xFF0000, // rgb(255,0,0) + rosy_brown = 0xBC8F8F, // rgb(188,143,143) + royal_blue = 0x4169E1, // rgb(65,105,225) + saddle_brown = 0x8B4513, // rgb(139,69,19) + salmon = 0xFA8072, // rgb(250,128,114) + sandy_brown = 0xF4A460, // rgb(244,164,96) + sea_green = 0x2E8B57, // rgb(46,139,87) + sea_shell = 0xFFF5EE, // rgb(255,245,238) + sienna = 0xA0522D, // rgb(160,82,45) + silver = 0xC0C0C0, // rgb(192,192,192) + sky_blue = 0x87CEEB, // rgb(135,206,235) + slate_blue = 0x6A5ACD, // rgb(106,90,205) + slate_gray = 0x708090, // rgb(112,128,144) + snow = 0xFFFAFA, // rgb(255,250,250) + spring_green = 0x00FF7F, // rgb(0,255,127) + steel_blue = 0x4682B4, // rgb(70,130,180) + tan = 0xD2B48C, // rgb(210,180,140) + teal = 0x008080, // rgb(0,128,128) + thistle = 0xD8BFD8, // rgb(216,191,216) + tomato 
= 0xFF6347, // rgb(255,99,71) + turquoise = 0x40E0D0, // rgb(64,224,208) + violet = 0xEE82EE, // rgb(238,130,238) + wheat = 0xF5DEB3, // rgb(245,222,179) + white = 0xFFFFFF, // rgb(255,255,255) + white_smoke = 0xF5F5F5, // rgb(245,245,245) + yellow = 0xFFFF00, // rgb(255,255,0) + yellow_green = 0x9ACD32 // rgb(154,205,50) +}; // enum class color + +enum class terminal_color : uint8_t { + black = 30, + red, + green, + yellow, + blue, + magenta, + cyan, + white, + bright_black = 90, + bright_red, + bright_green, + bright_yellow, + bright_blue, + bright_magenta, + bright_cyan, + bright_white +}; + +enum class emphasis : uint8_t { + bold = 1, + faint = 1 << 1, + italic = 1 << 2, + underline = 1 << 3, + blink = 1 << 4, + reverse = 1 << 5, + conceal = 1 << 6, + strikethrough = 1 << 7, +}; + +// rgb is a struct for red, green and blue colors. +// Using the name "rgb" makes some editors show the color in a tooltip. +struct rgb { + constexpr rgb() : r(0), g(0), b(0) {} + constexpr rgb(uint8_t r_, uint8_t g_, uint8_t b_) : r(r_), g(g_), b(b_) {} + constexpr rgb(uint32_t hex) + : r((hex >> 16) & 0xFF), g((hex >> 8) & 0xFF), b(hex & 0xFF) {} + constexpr rgb(color hex) + : r((uint32_t(hex) >> 16) & 0xFF), + g((uint32_t(hex) >> 8) & 0xFF), + b(uint32_t(hex) & 0xFF) {} + uint8_t r; + uint8_t g; + uint8_t b; +}; + +namespace detail { + +// A bit-packed variant of an RGB color, a terminal color, or unset color. +// see text_style for the bit-packing scheme. 
+struct color_type { + constexpr color_type() noexcept = default; + constexpr color_type(color rgb_color) noexcept + : value_(static_cast(rgb_color) | (1 << 24)) {} + constexpr color_type(rgb rgb_color) noexcept + : color_type(static_cast( + (static_cast(rgb_color.r) << 16) | + (static_cast(rgb_color.g) << 8) | rgb_color.b)) {} + constexpr color_type(terminal_color term_color) noexcept + : value_(static_cast(term_color) | (3 << 24)) {} + + constexpr auto is_terminal_color() const noexcept -> bool { + return (value_ & (1 << 25)) != 0; + } + + constexpr auto value() const noexcept -> uint32_t { + return value_ & 0xFFFFFF; + } + + constexpr color_type(uint32_t value) noexcept : value_(value) {} + + uint32_t value_ = 0; +}; +} // namespace detail + +/// A text style consisting of foreground and background colors and emphasis. +class text_style { + // The information is packed as follows: + // ┌──┐ + // │ 0│─┐ + // │..│ ├── foreground color value + // │23│─┘ + // ├──┤ + // │24│─┬── discriminator for the above value. 00 if unset, 01 if it's + // │25│─┘ an RGB color, or 11 if it's a terminal color (10 is unused) + // ├──┤ + // │26│──── overflow bit, always zero (see below) + // ├──┤ + // │27│─┐ + // │..│ │ + // │50│ │ + // ├──┤ │ + // │51│ ├── background color (same format as the foreground color) + // │52│ │ + // ├──┤ │ + // │53│─┘ + // ├──┤ + // │54│─┐ + // │..│ ├── emphases + // │61│─┘ + // ├──┤ + // │62│─┬── unused + // │63│─┘ + // └──┘ + // The overflow bits are there to make operator|= efficient. + // When ORing, we must throw if, for either the foreground or background, + // one style specifies a terminal color and the other specifies any color + // (terminal or RGB); in other words, if one discriminator is 11 and the + // other is 11 or 01. + // + // We do that check by adding the styles. Consider what adding does to each + // possible pair of discriminators: + // 00 + 00 = 000 + // 01 + 00 = 001 + // 11 + 00 = 011 + // 01 + 01 = 010 + // 11 + 01 = 100 (!!) 
+ // 11 + 11 = 110 (!!) + // In the last two cases, the ones we want to catch, the third bit——the + // overflow bit——is set. Bingo. + // + // We must take into account the possible carry bit from the bits + // before the discriminator. The only potentially problematic case is + // 11 + 00 = 011 (a carry bit would make it 100, not good!), but a carry + // bit is impossible in that case, because 00 (unset color) means the + // 24 bits that precede the discriminator are all zero. + // + // This test can be applied to both colors simultaneously. + + public: + FMT_CONSTEXPR text_style(emphasis em = emphasis()) noexcept + : style_(static_cast(em) << 54) {} + + FMT_CONSTEXPR auto operator|=(text_style rhs) -> text_style& { + if (((style_ + rhs.style_) & ((1ULL << 26) | (1ULL << 53))) != 0) + report_error("can't OR a terminal color"); + style_ |= rhs.style_; + return *this; + } + + friend FMT_CONSTEXPR auto operator|(text_style lhs, text_style rhs) + -> text_style { + return lhs |= rhs; + } + + FMT_CONSTEXPR auto operator==(text_style rhs) const noexcept -> bool { + return style_ == rhs.style_; + } + + FMT_CONSTEXPR auto operator!=(text_style rhs) const noexcept -> bool { + return !(*this == rhs); + } + + FMT_CONSTEXPR auto has_foreground() const noexcept -> bool { + return (style_ & (1 << 24)) != 0; + } + FMT_CONSTEXPR auto has_background() const noexcept -> bool { + return (style_ & (1ULL << 51)) != 0; + } + FMT_CONSTEXPR auto has_emphasis() const noexcept -> bool { + return (style_ >> 54) != 0; + } + FMT_CONSTEXPR auto get_foreground() const noexcept -> detail::color_type { + FMT_ASSERT(has_foreground(), "no foreground specified for this style"); + return style_ & 0x3FFFFFF; + } + FMT_CONSTEXPR auto get_background() const noexcept -> detail::color_type { + FMT_ASSERT(has_background(), "no background specified for this style"); + return (style_ >> 27) & 0x3FFFFFF; + } + FMT_CONSTEXPR auto get_emphasis() const noexcept -> emphasis { + FMT_ASSERT(has_emphasis(), "no 
emphasis specified for this style"); + return static_cast(style_ >> 54); + } + + private: + FMT_CONSTEXPR text_style(uint64_t style) noexcept : style_(style) {} + + friend FMT_CONSTEXPR auto fg(detail::color_type foreground) noexcept + -> text_style; + + friend FMT_CONSTEXPR auto bg(detail::color_type background) noexcept + -> text_style; + + uint64_t style_ = 0; +}; + +/// Creates a text style from the foreground (text) color. +FMT_CONSTEXPR inline auto fg(detail::color_type foreground) noexcept + -> text_style { + return foreground.value_; +} + +/// Creates a text style from the background color. +FMT_CONSTEXPR inline auto bg(detail::color_type background) noexcept + -> text_style { + return static_cast(background.value_) << 27; +} + +FMT_CONSTEXPR inline auto operator|(emphasis lhs, emphasis rhs) noexcept + -> text_style { + return text_style(lhs) | rhs; +} + +namespace detail { + +template struct ansi_color_escape { + FMT_CONSTEXPR ansi_color_escape(color_type text_color, + const char* esc) noexcept { + // If we have a terminal color, we need to output another escape code + // sequence. + if (text_color.is_terminal_color()) { + bool is_background = esc == string_view("\x1b[48;2;"); + uint32_t value = text_color.value(); + // Background ASCII codes are the same as the foreground ones but with + // 10 more. 
+ if (is_background) value += 10u; + + buffer[size++] = static_cast('\x1b'); + buffer[size++] = static_cast('['); + + if (value >= 100u) { + buffer[size++] = static_cast('1'); + value %= 100u; + } + buffer[size++] = static_cast('0' + value / 10u); + buffer[size++] = static_cast('0' + value % 10u); + + buffer[size++] = static_cast('m'); + return; + } + + for (int i = 0; i < 7; i++) { + buffer[i] = static_cast(esc[i]); + } + rgb color(text_color.value()); + to_esc(color.r, buffer + 7, ';'); + to_esc(color.g, buffer + 11, ';'); + to_esc(color.b, buffer + 15, 'm'); + size = 19; + } + FMT_CONSTEXPR ansi_color_escape(emphasis em) noexcept { + uint8_t em_codes[num_emphases] = {}; + if (has_emphasis(em, emphasis::bold)) em_codes[0] = 1; + if (has_emphasis(em, emphasis::faint)) em_codes[1] = 2; + if (has_emphasis(em, emphasis::italic)) em_codes[2] = 3; + if (has_emphasis(em, emphasis::underline)) em_codes[3] = 4; + if (has_emphasis(em, emphasis::blink)) em_codes[4] = 5; + if (has_emphasis(em, emphasis::reverse)) em_codes[5] = 7; + if (has_emphasis(em, emphasis::conceal)) em_codes[6] = 8; + if (has_emphasis(em, emphasis::strikethrough)) em_codes[7] = 9; + + buffer[size++] = static_cast('\x1b'); + buffer[size++] = static_cast('['); + + for (size_t i = 0; i < num_emphases; ++i) { + if (!em_codes[i]) continue; + buffer[size++] = static_cast('0' + em_codes[i]); + buffer[size++] = static_cast(';'); + } + + buffer[size - 1] = static_cast('m'); + } + FMT_CONSTEXPR operator const Char*() const noexcept { return buffer; } + + FMT_CONSTEXPR auto begin() const noexcept -> const Char* { return buffer; } + FMT_CONSTEXPR auto end() const noexcept -> const Char* { + return buffer + size; + } + + private: + static constexpr size_t num_emphases = 8; + Char buffer[7u + 4u * num_emphases] = {}; + size_t size = 0; + + static FMT_CONSTEXPR void to_esc(uint8_t c, Char* out, + char delimiter) noexcept { + out[0] = static_cast('0' + c / 100); + out[1] = static_cast('0' + c / 10 % 10); + out[2] = 
static_cast('0' + c % 10); + out[3] = static_cast(delimiter); + } + static FMT_CONSTEXPR auto has_emphasis(emphasis em, emphasis mask) noexcept + -> bool { + return static_cast(em) & static_cast(mask); + } +}; + +template +FMT_CONSTEXPR auto make_foreground_color(color_type foreground) noexcept + -> ansi_color_escape { + return ansi_color_escape(foreground, "\x1b[38;2;"); +} + +template +FMT_CONSTEXPR auto make_background_color(color_type background) noexcept + -> ansi_color_escape { + return ansi_color_escape(background, "\x1b[48;2;"); +} + +template +FMT_CONSTEXPR auto make_emphasis(emphasis em) noexcept + -> ansi_color_escape { + return ansi_color_escape(em); +} + +template inline void reset_color(buffer& buffer) { + auto reset_color = string_view("\x1b[0m"); + buffer.append(reset_color.begin(), reset_color.end()); +} + +template struct styled_arg : view { + const T& value; + text_style style; + styled_arg(const T& v, text_style s) : value(v), style(s) {} +}; + +template +void vformat_to(buffer& buf, text_style ts, basic_string_view fmt, + basic_format_args> args) { + if (ts.has_emphasis()) { + auto emphasis = make_emphasis(ts.get_emphasis()); + buf.append(emphasis.begin(), emphasis.end()); + } + if (ts.has_foreground()) { + auto foreground = make_foreground_color(ts.get_foreground()); + buf.append(foreground.begin(), foreground.end()); + } + if (ts.has_background()) { + auto background = make_background_color(ts.get_background()); + buf.append(background.begin(), background.end()); + } + vformat_to(buf, fmt, args); + if (ts != text_style()) reset_color(buf); +} +} // namespace detail + +inline void vprint(FILE* f, text_style ts, string_view fmt, format_args args) { + auto buf = memory_buffer(); + detail::vformat_to(buf, ts, fmt, args); + print(f, FMT_STRING("{}"), string_view(buf.begin(), buf.size())); +} + +/** + * Formats a string and prints it to the specified file stream using ANSI + * escape sequences to specify text formatting. 
+ * + * **Example**: + * + * fmt::print(fmt::emphasis::bold | fg(fmt::color::red), + * "Elapsed time: {0:.2f} seconds", 1.23); + */ +template +void print(FILE* f, text_style ts, format_string fmt, T&&... args) { + vprint(f, ts, fmt.str, vargs{{args...}}); +} + +/** + * Formats a string and prints it to stdout using ANSI escape sequences to + * specify text formatting. + * + * **Example**: + * + * fmt::print(fmt::emphasis::bold | fg(fmt::color::red), + * "Elapsed time: {0:.2f} seconds", 1.23); + */ +template +void print(text_style ts, format_string fmt, T&&... args) { + return print(stdout, ts, fmt, std::forward(args)...); +} + +inline auto vformat(text_style ts, string_view fmt, format_args args) + -> std::string { + auto buf = memory_buffer(); + detail::vformat_to(buf, ts, fmt, args); + return fmt::to_string(buf); +} + +/** + * Formats arguments and returns the result as a string using ANSI escape + * sequences to specify text formatting. + * + * **Example**: + * + * ``` + * #include + * std::string message = fmt::format(fmt::emphasis::bold | fg(fmt::color::red), + * "The answer is {}", 42); + * ``` + */ +template +inline auto format(text_style ts, format_string fmt, T&&... args) + -> std::string { + return fmt::vformat(ts, fmt.str, vargs{{args...}}); +} + +/// Formats a string with the given text_style and writes the output to `out`. +template ::value)> +auto vformat_to(OutputIt out, text_style ts, string_view fmt, format_args args) + -> OutputIt { + auto&& buf = detail::get_buffer(out); + detail::vformat_to(buf, ts, fmt, args); + return detail::get_iterator(buf, out); +} + +/** + * Formats arguments with the given text style, writes the result to the output + * iterator `out` and returns the iterator past the end of the output range. 
+ * + * **Example**: + * + * std::vector out; + * fmt::format_to(std::back_inserter(out), + * fmt::emphasis::bold | fg(fmt::color::red), "{}", 42); + */ +template ::value)> +inline auto format_to(OutputIt out, text_style ts, format_string fmt, + T&&... args) -> OutputIt { + return vformat_to(out, ts, fmt.str, vargs{{args...}}); +} + +template +struct formatter, Char> : formatter { + template + auto format(const detail::styled_arg& arg, FormatContext& ctx) const + -> decltype(ctx.out()) { + const auto& ts = arg.style; + auto out = ctx.out(); + + bool has_style = false; + if (ts.has_emphasis()) { + has_style = true; + auto emphasis = detail::make_emphasis(ts.get_emphasis()); + out = detail::copy(emphasis.begin(), emphasis.end(), out); + } + if (ts.has_foreground()) { + has_style = true; + auto foreground = + detail::make_foreground_color(ts.get_foreground()); + out = detail::copy(foreground.begin(), foreground.end(), out); + } + if (ts.has_background()) { + has_style = true; + auto background = + detail::make_background_color(ts.get_background()); + out = detail::copy(background.begin(), background.end(), out); + } + out = formatter::format(arg.value, ctx); + if (has_style) { + auto reset_color = string_view("\x1b[0m"); + out = detail::copy(reset_color.begin(), reset_color.end(), out); + } + return out; + } +}; + +/** + * Returns an argument that will be formatted using ANSI escape sequences, + * to be used in a formatting function. + * + * **Example**: + * + * fmt::print("Elapsed time: {0:.2f} seconds", + * fmt::styled(1.23, fmt::fg(fmt::color::green) | + * fmt::bg(fmt::color::blue))); + */ +template +FMT_CONSTEXPR auto styled(const T& value, text_style ts) + -> detail::styled_arg> { + return detail::styled_arg>{value, ts}; +} + +FMT_END_EXPORT +FMT_END_NAMESPACE + +#endif // FMT_COLOR_H_ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." 
+#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fmt/compile.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fmt/compile.h new file mode 100644 index 0000000000000000000000000000000000000000..955135a39f4eab50b903bd3bfee0cb9217f0238d --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fmt/compile.h @@ -0,0 +1,593 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Formatting library for C++ - experimental format string compilation +// +// Copyright (c) 2012 - present, Victor Zverovich and fmt contributors +// All rights reserved. +// +// For the license information refer to format.h. + +#ifndef FMT_COMPILE_H_ +#define FMT_COMPILE_H_ + +#ifndef FMT_MODULE +# include // std::back_inserter +#endif + +#include "format.h" + +FMT_BEGIN_NAMESPACE +FMT_BEGIN_EXPORT + +// A compile-time string which is compiled into fast formatting code. +class compiled_string {}; + +template +struct is_compiled_string : std::is_base_of {}; + +/** + * Converts a string literal `s` into a format string that will be parsed at + * compile time and converted into efficient formatting code. Requires C++17 + * `constexpr if` compiler support. + * + * **Example**: + * + * // Converts 42 into std::string using the most efficient method and no + * // runtime format string processing. + * std::string s = fmt::format(FMT_COMPILE("{}"), 42); + */ +#if defined(__cpp_if_constexpr) && defined(__cpp_return_type_deduction) +# define FMT_COMPILE(s) FMT_STRING_IMPL(s, fmt::compiled_string) +#else +# define FMT_COMPILE(s) FMT_STRING(s) +#endif + +/** + * Converts a string literal into a format string that will be parsed at + * compile time and converted into efficient formatting code. Requires support + * for class types in constant template parameters (a C++20 feature). 
+ * + * **Example**: + * + * // Converts 42 into std::string using the most efficient method and no + * // runtime format string processing. + * using namespace fmt::literals; + * std::string s = fmt::format("{}"_cf, 42); + */ +#if FMT_USE_NONTYPE_TEMPLATE_ARGS +inline namespace literals { +template constexpr auto operator""_cf() { + return FMT_COMPILE(Str.data); +} +} // namespace literals +#endif + +FMT_END_EXPORT + +namespace detail { + +template +constexpr auto first(const T& value, const Tail&...) -> const T& { + return value; +} + +#if defined(__cpp_if_constexpr) && defined(__cpp_return_type_deduction) +template struct type_list {}; + +// Returns a reference to the argument at index N from [first, rest...]. +template +constexpr auto get([[maybe_unused]] const T& first, + [[maybe_unused]] const Args&... rest) -> const auto& { + static_assert(N < 1 + sizeof...(Args), "index is out of bounds"); + if constexpr (N == 0) + return first; + else + return detail::get(rest...); +} + +# if FMT_USE_NONTYPE_TEMPLATE_ARGS +template +constexpr auto get_arg_index_by_name(basic_string_view name) -> int { + if constexpr (is_static_named_arg()) { + if (name == T::name) return N; + } + if constexpr (sizeof...(Args) > 0) + return get_arg_index_by_name(name); + (void)name; // Workaround an MSVC bug about "unused" parameter. 
+ return -1; +} +# endif + +template +FMT_CONSTEXPR auto get_arg_index_by_name(basic_string_view name) -> int { +# if FMT_USE_NONTYPE_TEMPLATE_ARGS + if constexpr (sizeof...(Args) > 0) + return get_arg_index_by_name<0, Args...>(name); +# endif + (void)name; + return -1; +} + +template +constexpr auto get_arg_index_by_name(basic_string_view name, + type_list) -> int { + return get_arg_index_by_name(name); +} + +template struct get_type_impl; + +template struct get_type_impl> { + using type = + remove_cvref_t(std::declval()...))>; +}; + +template +using get_type = typename get_type_impl::type; + +template struct is_compiled_format : std::false_type {}; + +template struct text { + basic_string_view data; + using char_type = Char; + + template + constexpr auto format(OutputIt out, const T&...) const -> OutputIt { + return write(out, data); + } +}; + +template +struct is_compiled_format> : std::true_type {}; + +template +constexpr auto make_text(basic_string_view s, size_t pos, size_t size) + -> text { + return {{&s[pos], size}}; +} + +template struct code_unit { + Char value; + using char_type = Char; + + template + constexpr auto format(OutputIt out, const T&...) const -> OutputIt { + *out++ = value; + return out; + } +}; + +// This ensures that the argument type is convertible to `const T&`. +template +constexpr auto get_arg_checked(const Args&... args) -> const T& { + const auto& arg = detail::get(args...); + if constexpr (detail::is_named_arg>()) { + return arg.value; + } else { + return arg; + } +} + +template +struct is_compiled_format> : std::true_type {}; + +// A replacement field that refers to argument N. +template struct field { + using char_type = Char; + + template + constexpr auto format(OutputIt out, const T&... 
args) const -> OutputIt { + const V& arg = get_arg_checked(args...); + if constexpr (std::is_convertible>::value) { + auto s = basic_string_view(arg); + return copy(s.begin(), s.end(), out); + } else { + return write(out, arg); + } + } +}; + +template +struct is_compiled_format> : std::true_type {}; + +// A replacement field that refers to argument with name. +template struct runtime_named_field { + using char_type = Char; + basic_string_view name; + + template + constexpr static auto try_format_argument( + OutputIt& out, + // [[maybe_unused]] due to unused-but-set-parameter warning in GCC 7,8,9 + [[maybe_unused]] basic_string_view arg_name, const T& arg) -> bool { + if constexpr (is_named_arg::type>::value) { + if (arg_name == arg.name) { + out = write(out, arg.value); + return true; + } + } + return false; + } + + template + constexpr auto format(OutputIt out, const T&... args) const -> OutputIt { + bool found = (try_format_argument(out, name, args) || ...); + if (!found) { + FMT_THROW(format_error("argument with specified name is not found")); + } + return out; + } +}; + +template +struct is_compiled_format> : std::true_type {}; + +// A replacement field that refers to argument N and has format specifiers. +template struct spec_field { + using char_type = Char; + formatter fmt; + + template + constexpr FMT_INLINE auto format(OutputIt out, const T&... args) const + -> OutputIt { + const auto& vargs = + fmt::make_format_args>(args...); + basic_format_context ctx(out, vargs); + return fmt.format(get_arg_checked(args...), ctx); + } +}; + +template +struct is_compiled_format> : std::true_type {}; + +template struct concat { + L lhs; + R rhs; + using char_type = typename L::char_type; + + template + constexpr auto format(OutputIt out, const T&... 
args) const -> OutputIt { + out = lhs.format(out, args...); + return rhs.format(out, args...); + } +}; + +template +struct is_compiled_format> : std::true_type {}; + +template +constexpr auto make_concat(L lhs, R rhs) -> concat { + return {lhs, rhs}; +} + +struct unknown_format {}; + +template +constexpr auto parse_text(basic_string_view str, size_t pos) -> size_t { + for (size_t size = str.size(); pos != size; ++pos) { + if (str[pos] == '{' || str[pos] == '}') break; + } + return pos; +} + +template +constexpr auto compile_format_string(S fmt); + +template +constexpr auto parse_tail(T head, S fmt) { + if constexpr (POS != basic_string_view(fmt).size()) { + constexpr auto tail = compile_format_string(fmt); + if constexpr (std::is_same, + unknown_format>()) + return tail; + else + return make_concat(head, tail); + } else { + return head; + } +} + +template struct parse_specs_result { + formatter fmt; + size_t end; + int next_arg_id; +}; + +enum { manual_indexing_id = -1 }; + +template +constexpr auto parse_specs(basic_string_view str, size_t pos, + int next_arg_id) -> parse_specs_result { + str.remove_prefix(pos); + auto ctx = + compile_parse_context(str, max_value(), nullptr, next_arg_id); + auto f = formatter(); + auto end = f.parse(ctx); + return {f, pos + fmt::detail::to_unsigned(end - str.data()), + next_arg_id == 0 ? 
manual_indexing_id : ctx.next_arg_id()}; +} + +template struct arg_id_handler { + arg_id_kind kind; + arg_ref arg_id; + + constexpr auto on_auto() -> int { + FMT_ASSERT(false, "handler cannot be used with automatic indexing"); + return 0; + } + constexpr auto on_index(int id) -> int { + kind = arg_id_kind::index; + arg_id = arg_ref(id); + return 0; + } + constexpr auto on_name(basic_string_view id) -> int { + kind = arg_id_kind::name; + arg_id = arg_ref(id); + return 0; + } +}; + +template struct parse_arg_id_result { + arg_id_kind kind; + arg_ref arg_id; + const Char* arg_id_end; +}; + +template +constexpr auto parse_arg_id(const Char* begin, const Char* end) { + auto handler = arg_id_handler{arg_id_kind::none, arg_ref{}}; + auto arg_id_end = parse_arg_id(begin, end, handler); + return parse_arg_id_result{handler.kind, handler.arg_id, arg_id_end}; +} + +template struct field_type { + using type = remove_cvref_t; +}; + +template +struct field_type::value>> { + using type = remove_cvref_t; +}; + +template +constexpr auto parse_replacement_field_then_tail(S fmt) { + using char_type = typename S::char_type; + constexpr auto str = basic_string_view(fmt); + constexpr char_type c = END_POS != str.size() ? str[END_POS] : char_type(); + if constexpr (c == '}') { + return parse_tail( + field::type, ARG_INDEX>(), fmt); + } else if constexpr (c != ':') { + FMT_THROW(format_error("expected ':'")); + } else { + constexpr auto result = parse_specs::type>( + str, END_POS + 1, NEXT_ID == manual_indexing_id ? 0 : NEXT_ID); + if constexpr (result.end >= str.size() || str[result.end] != '}') { + FMT_THROW(format_error("expected '}'")); + return 0; + } else { + return parse_tail( + spec_field::type, ARG_INDEX>{ + result.fmt}, + fmt); + } + } +} + +// Compiles a non-empty format string and returns the compiled representation +// or unknown_format() on unrecognized input. 
+template +constexpr auto compile_format_string(S fmt) { + using char_type = typename S::char_type; + constexpr auto str = basic_string_view(fmt); + if constexpr (str[POS] == '{') { + if constexpr (POS + 1 == str.size()) + FMT_THROW(format_error("unmatched '{' in format string")); + if constexpr (str[POS + 1] == '{') { + return parse_tail(make_text(str, POS, 1), fmt); + } else if constexpr (str[POS + 1] == '}' || str[POS + 1] == ':') { + static_assert(ID != manual_indexing_id, + "cannot switch from manual to automatic argument indexing"); + constexpr auto next_id = + ID != manual_indexing_id ? ID + 1 : manual_indexing_id; + return parse_replacement_field_then_tail, Args, + POS + 1, ID, next_id>(fmt); + } else { + constexpr auto arg_id_result = + parse_arg_id(str.data() + POS + 1, str.data() + str.size()); + constexpr auto arg_id_end_pos = arg_id_result.arg_id_end - str.data(); + constexpr char_type c = + arg_id_end_pos != str.size() ? str[arg_id_end_pos] : char_type(); + static_assert(c == '}' || c == ':', "missing '}' in format string"); + if constexpr (arg_id_result.kind == arg_id_kind::index) { + static_assert( + ID == manual_indexing_id || ID == 0, + "cannot switch from automatic to manual argument indexing"); + constexpr auto arg_index = arg_id_result.arg_id.index; + return parse_replacement_field_then_tail, + Args, arg_id_end_pos, + arg_index, manual_indexing_id>( + fmt); + } else if constexpr (arg_id_result.kind == arg_id_kind::name) { + constexpr auto arg_index = + get_arg_index_by_name(arg_id_result.arg_id.name, Args{}); + if constexpr (arg_index >= 0) { + constexpr auto next_id = + ID != manual_indexing_id ? 
ID + 1 : manual_indexing_id; + return parse_replacement_field_then_tail< + decltype(get_type::value), Args, arg_id_end_pos, + arg_index, next_id>(fmt); + } else if constexpr (c == '}') { + return parse_tail( + runtime_named_field{arg_id_result.arg_id.name}, fmt); + } else if constexpr (c == ':') { + return unknown_format(); // no type info for specs parsing + } + } + } + } else if constexpr (str[POS] == '}') { + if constexpr (POS + 1 == str.size()) + FMT_THROW(format_error("unmatched '}' in format string")); + return parse_tail(make_text(str, POS, 1), fmt); + } else { + constexpr auto end = parse_text(str, POS + 1); + if constexpr (end - POS > 1) { + return parse_tail(make_text(str, POS, end - POS), fmt); + } else { + return parse_tail(code_unit{str[POS]}, fmt); + } + } +} + +template ::value)> +constexpr auto compile(S fmt) { + constexpr auto str = basic_string_view(fmt); + if constexpr (str.size() == 0) { + return detail::make_text(str, 0, 0); + } else { + constexpr auto result = + detail::compile_format_string, 0, 0>(fmt); + return result; + } +} +#endif // defined(__cpp_if_constexpr) && defined(__cpp_return_type_deduction) +} // namespace detail + +FMT_BEGIN_EXPORT + +#if defined(__cpp_if_constexpr) && defined(__cpp_return_type_deduction) + +template ::value)> +FMT_INLINE FMT_CONSTEXPR_STRING auto format(const CompiledFormat& cf, + const T&... args) + -> std::basic_string { + auto s = std::basic_string(); + cf.format(std::back_inserter(s), args...); + return s; +} + +template ::value)> +constexpr FMT_INLINE auto format_to(OutputIt out, const CompiledFormat& cf, + const T&... args) -> OutputIt { + return cf.format(out, args...); +} + +template ::value)> +FMT_INLINE FMT_CONSTEXPR_STRING auto format(const S&, T&&... 
args) + -> std::basic_string { + if constexpr (std::is_same::value) { + constexpr auto str = basic_string_view(S()); + if constexpr (str.size() == 2 && str[0] == '{' && str[1] == '}') { + const auto& first = detail::first(args...); + if constexpr (detail::is_named_arg< + remove_cvref_t>::value) { + return fmt::to_string(first.value); + } else { + return fmt::to_string(first); + } + } + } + constexpr auto compiled = detail::compile(S()); + if constexpr (std::is_same, + detail::unknown_format>()) { + return fmt::format( + static_cast>(S()), + std::forward(args)...); + } else { + return fmt::format(compiled, std::forward(args)...); + } +} + +template ::value)> +FMT_CONSTEXPR auto format_to(OutputIt out, const S&, T&&... args) -> OutputIt { + constexpr auto compiled = detail::compile(S()); + if constexpr (std::is_same, + detail::unknown_format>()) { + return fmt::format_to( + out, static_cast>(S()), + std::forward(args)...); + } else { + return fmt::format_to(out, compiled, std::forward(args)...); + } +} +#endif + +template ::value)> +auto format_to_n(OutputIt out, size_t n, const S& fmt, T&&... args) + -> format_to_n_result { + using traits = detail::fixed_buffer_traits; + auto buf = detail::iterator_buffer(out, n); + fmt::format_to(std::back_inserter(buf), fmt, std::forward(args)...); + return {buf.out(), buf.count()}; +} + +template ::value)> +FMT_CONSTEXPR20 auto formatted_size(const S& fmt, T&&... args) -> size_t { + auto buf = detail::counting_buffer<>(); + fmt::format_to(appender(buf), fmt, std::forward(args)...); + return buf.count(); +} + +template ::value)> +void print(std::FILE* f, const S& fmt, T&&... args) { + auto buf = memory_buffer(); + fmt::format_to(appender(buf), fmt, std::forward(args)...); + detail::print(f, {buf.data(), buf.size()}); +} + +template ::value)> +void print(const S& fmt, T&&... 
args) { + print(stdout, fmt, std::forward(args)...); +} + +template class static_format_result { + private: + char data[N]; + + public: + template ::value)> + explicit FMT_CONSTEXPR static_format_result(const S& fmt, T&&... args) { + *fmt::format_to(data, fmt, std::forward(args)...) = '\0'; + } + + auto str() const -> fmt::string_view { return {data, N - 1}; } + auto c_str() const -> const char* { return data; } +}; + +/** + * Formats arguments according to the format string `fmt_str` and produces + * a string of the exact required size at compile time. Both the format string + * and the arguments must be compile-time expressions. + * + * The resulting string can be accessed as a C string via `c_str()` or as + * a `fmt::string_view` via `str()`. + * + * **Example**: + * + * // Produces the static string "42" at compile time. + * static constexpr auto result = FMT_STATIC_FORMAT("{}", 42); + * const char* s = result.c_str(); + */ +#define FMT_STATIC_FORMAT(fmt_str, ...) \ + fmt::static_format_result< \ + fmt::formatted_size(FMT_COMPILE(fmt_str), __VA_ARGS__) + 1>( \ + FMT_COMPILE(fmt_str), __VA_ARGS__) + +FMT_END_EXPORT +FMT_END_NAMESPACE + +#endif // FMT_COMPILE_H_ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fmt/core.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fmt/core.h new file mode 100644 index 0000000000000000000000000000000000000000..6d2b271f05c9ff22b8c67f8cfa7bc2c8dbdef5b1 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fmt/core.h @@ -0,0 +1,10 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// This file is only provided for compatibility and may be removed in future +// versions. 
Use fmt/base.h if you don't need fmt::format and fmt/format.h +// otherwise. + +#include "format.h" + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fmt/format-inl.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fmt/format-inl.h new file mode 100644 index 0000000000000000000000000000000000000000..ccf962a394ea93bc040d5e9b94bfb7e0bcf30bb7 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fmt/format-inl.h @@ -0,0 +1,1953 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Formatting library for C++ - implementation +// +// Copyright (c) 2012 - 2016, Victor Zverovich +// All rights reserved. +// +// For the license information refer to format.h. + +#ifndef FMT_FORMAT_INL_H_ +#define FMT_FORMAT_INL_H_ + +#ifndef FMT_MODULE +# include +# include // errno +# include +# include +# include +#endif + +#if defined(_WIN32) && !defined(FMT_USE_WRITE_CONSOLE) +# include // _isatty +#endif + +#include "format.h" + +#if FMT_USE_LOCALE && !defined(FMT_MODULE) +# include +#endif + +#ifndef FMT_FUNC +# define FMT_FUNC +#endif + +FMT_BEGIN_NAMESPACE + +#ifndef FMT_CUSTOM_ASSERT_FAIL +FMT_FUNC void assert_fail(const char* file, int line, const char* message) { + // Use unchecked std::fprintf to avoid triggering another assertion when + // writing to stderr fails. 
+ std::fprintf(stderr, "%s:%d: assertion failed: %s", file, line, message); + abort(); +} +#endif + +#if FMT_USE_LOCALE +namespace detail { +using std::locale; +using std::numpunct; +using std::use_facet; +} // namespace detail +#else +namespace detail { +struct locale {}; +template struct numpunct { + auto grouping() const -> std::string { return "\03"; } + auto thousands_sep() const -> Char { return ','; } + auto decimal_point() const -> Char { return '.'; } +}; +template Facet use_facet(locale) { return {}; } +} // namespace detail +#endif // FMT_USE_LOCALE + +template auto locale_ref::get() const -> Locale { + using namespace detail; + static_assert(std::is_same::value, ""); +#if FMT_USE_LOCALE + if (locale_) return *static_cast(locale_); +#endif + return locale(); +} + +namespace detail { + +FMT_FUNC void format_error_code(detail::buffer& out, int error_code, + string_view message) noexcept { + // Report error code making sure that the output fits into + // inline_buffer_size to avoid dynamic memory allocation and potential + // bad_alloc. + out.try_resize(0); + static const char SEP[] = ": "; + static const char ERROR_STR[] = "error "; + // Subtract 2 to account for terminating null characters in SEP and ERROR_STR. 
+ size_t error_code_size = sizeof(SEP) + sizeof(ERROR_STR) - 2; + auto abs_value = static_cast>(error_code); + if (detail::is_negative(error_code)) { + abs_value = 0 - abs_value; + ++error_code_size; + } + error_code_size += detail::to_unsigned(detail::count_digits(abs_value)); + auto it = appender(out); + if (message.size() <= inline_buffer_size - error_code_size) + fmt::format_to(it, FMT_STRING("{}{}"), message, SEP); + fmt::format_to(it, FMT_STRING("{}{}"), ERROR_STR, error_code); + FMT_ASSERT(out.size() <= inline_buffer_size, ""); +} + +FMT_FUNC void do_report_error(format_func func, int error_code, + const char* message) noexcept { + memory_buffer full_message; + func(full_message, error_code, message); + // Don't use fwrite_all because the latter may throw. + if (std::fwrite(full_message.data(), full_message.size(), 1, stderr) > 0) + std::fputc('\n', stderr); +} + +// A wrapper around fwrite that throws on error. +inline void fwrite_all(const void* ptr, size_t count, FILE* stream) { + size_t written = std::fwrite(ptr, 1, count, stream); + if (written < count) + FMT_THROW(system_error(errno, FMT_STRING("cannot write to file"))); +} + +template +FMT_FUNC auto thousands_sep_impl(locale_ref loc) -> thousands_sep_result { + auto&& facet = use_facet>(loc.get()); + auto grouping = facet.grouping(); + auto thousands_sep = grouping.empty() ? Char() : facet.thousands_sep(); + return {std::move(grouping), thousands_sep}; +} +template +FMT_FUNC auto decimal_point_impl(locale_ref loc) -> Char { + return use_facet>(loc.get()).decimal_point(); +} + +#if FMT_USE_LOCALE +FMT_FUNC auto write_loc(appender out, loc_value value, + const format_specs& specs, locale_ref loc) -> bool { + auto locale = loc.get(); + // We cannot use the num_put facet because it may produce output in + // a wrong encoding. 
+ using facet = format_facet; + if (std::has_facet(locale)) + return use_facet(locale).put(out, value, specs); + return facet(locale).put(out, value, specs); +} +#endif +} // namespace detail + +FMT_FUNC void report_error(const char* message) { +#if FMT_MSC_VERSION || defined(__NVCC__) + // Silence unreachable code warnings in MSVC and NVCC because these + // are nearly impossible to fix in a generic code. + volatile bool b = true; + if (!b) return; +#endif + FMT_THROW(format_error(message)); +} + +template typename Locale::id format_facet::id; + +template format_facet::format_facet(Locale& loc) { + auto& np = detail::use_facet>(loc); + grouping_ = np.grouping(); + if (!grouping_.empty()) separator_ = std::string(1, np.thousands_sep()); +} + +#if FMT_USE_LOCALE +template <> +FMT_API FMT_FUNC auto format_facet::do_put( + appender out, loc_value val, const format_specs& specs) const -> bool { + return val.visit( + detail::loc_writer<>{out, specs, separator_, grouping_, decimal_point_}); +} +#endif + +FMT_FUNC auto vsystem_error(int error_code, string_view fmt, format_args args) + -> std::system_error { + auto ec = std::error_code(error_code, std::generic_category()); + return std::system_error(ec, vformat(fmt, args)); +} + +namespace detail { + +template +inline auto operator==(basic_fp x, basic_fp y) -> bool { + return x.f == y.f && x.e == y.e; +} + +// Compilers should be able to optimize this into the ror instruction. +FMT_INLINE auto rotr(uint32_t n, uint32_t r) noexcept -> uint32_t { + r &= 31; + return (n >> r) | (n << (32 - r)); +} +FMT_INLINE auto rotr(uint64_t n, uint32_t r) noexcept -> uint64_t { + r &= 63; + return (n >> r) | (n << (64 - r)); +} + +// Implementation of Dragonbox algorithm: https://github.com/jk-jeon/dragonbox. +namespace dragonbox { +// Computes upper 64 bits of multiplication of a 32-bit unsigned integer and a +// 64-bit unsigned integer. 
+inline auto umul96_upper64(uint32_t x, uint64_t y) noexcept -> uint64_t { + return umul128_upper64(static_cast(x) << 32, y); +} + +// Computes lower 128 bits of multiplication of a 64-bit unsigned integer and a +// 128-bit unsigned integer. +inline auto umul192_lower128(uint64_t x, uint128_fallback y) noexcept + -> uint128_fallback { + uint64_t high = x * y.high(); + uint128_fallback high_low = umul128(x, y.low()); + return {high + high_low.high(), high_low.low()}; +} + +// Computes lower 64 bits of multiplication of a 32-bit unsigned integer and a +// 64-bit unsigned integer. +inline auto umul96_lower64(uint32_t x, uint64_t y) noexcept -> uint64_t { + return x * y; +} + +// Various fast log computations. +inline auto floor_log10_pow2_minus_log10_4_over_3(int e) noexcept -> int { + FMT_ASSERT(e <= 2936 && e >= -2985, "too large exponent"); + return (e * 631305 - 261663) >> 21; +} + +FMT_INLINE_VARIABLE constexpr struct div_small_pow10_infos_struct { + uint32_t divisor; + int shift_amount; +} div_small_pow10_infos[] = {{10, 16}, {100, 16}}; + +// Replaces n by floor(n / pow(10, N)) returning true if and only if n is +// divisible by pow(10, N). +// Precondition: n <= pow(10, N + 1). +template +auto check_divisibility_and_divide_by_pow10(uint32_t& n) noexcept -> bool { + // The numbers below are chosen such that: + // 1. floor(n/d) = floor(nm / 2^k) where d=10 or d=100, + // 2. nm mod 2^k < m if and only if n is divisible by d, + // where m is magic_number, k is shift_amount + // and d is divisor. + // + // Item 1 is a common technique of replacing division by a constant with + // multiplication, see e.g. "Division by Invariant Integers Using + // Multiplication" by Granlund and Montgomery (1994). magic_number (m) is set + // to ceil(2^k/d) for large enough k. + // The idea for item 2 originates from Schubfach. 
+ constexpr auto info = div_small_pow10_infos[N - 1]; + FMT_ASSERT(n <= info.divisor * 10, "n is too large"); + constexpr uint32_t magic_number = + (1u << info.shift_amount) / info.divisor + 1; + n *= magic_number; + const uint32_t comparison_mask = (1u << info.shift_amount) - 1; + bool result = (n & comparison_mask) < magic_number; + n >>= info.shift_amount; + return result; +} + +// Computes floor(n / pow(10, N)) for small n and N. +// Precondition: n <= pow(10, N + 1). +template auto small_division_by_pow10(uint32_t n) noexcept -> uint32_t { + constexpr auto info = div_small_pow10_infos[N - 1]; + FMT_ASSERT(n <= info.divisor * 10, "n is too large"); + constexpr uint32_t magic_number = + (1u << info.shift_amount) / info.divisor + 1; + return (n * magic_number) >> info.shift_amount; +} + +// Computes floor(n / 10^(kappa + 1)) (float) +inline auto divide_by_10_to_kappa_plus_1(uint32_t n) noexcept -> uint32_t { + // 1374389535 = ceil(2^37/100) + return static_cast((static_cast(n) * 1374389535) >> 37); +} +// Computes floor(n / 10^(kappa + 1)) (double) +inline auto divide_by_10_to_kappa_plus_1(uint64_t n) noexcept -> uint64_t { + // 2361183241434822607 = ceil(2^(64+7)/1000) + return umul128_upper64(n, 2361183241434822607ull) >> 7; +} + +// Various subroutines using pow10 cache +template struct cache_accessor; + +template <> struct cache_accessor { + using carrier_uint = float_info::carrier_uint; + using cache_entry_type = uint64_t; + + static auto get_cached_power(int k) noexcept -> uint64_t { + FMT_ASSERT(k >= float_info::min_k && k <= float_info::max_k, + "k is out of range"); + static constexpr uint64_t pow10_significands[] = { + 0x81ceb32c4b43fcf5, 0xa2425ff75e14fc32, 0xcad2f7f5359a3b3f, + 0xfd87b5f28300ca0e, 0x9e74d1b791e07e49, 0xc612062576589ddb, + 0xf79687aed3eec552, 0x9abe14cd44753b53, 0xc16d9a0095928a28, + 0xf1c90080baf72cb2, 0x971da05074da7bef, 0xbce5086492111aeb, + 0xec1e4a7db69561a6, 0x9392ee8e921d5d08, 0xb877aa3236a4b44a, + 0xe69594bec44de15c, 
0x901d7cf73ab0acda, 0xb424dc35095cd810, + 0xe12e13424bb40e14, 0x8cbccc096f5088cc, 0xafebff0bcb24aaff, + 0xdbe6fecebdedd5bf, 0x89705f4136b4a598, 0xabcc77118461cefd, + 0xd6bf94d5e57a42bd, 0x8637bd05af6c69b6, 0xa7c5ac471b478424, + 0xd1b71758e219652c, 0x83126e978d4fdf3c, 0xa3d70a3d70a3d70b, + 0xcccccccccccccccd, 0x8000000000000000, 0xa000000000000000, + 0xc800000000000000, 0xfa00000000000000, 0x9c40000000000000, + 0xc350000000000000, 0xf424000000000000, 0x9896800000000000, + 0xbebc200000000000, 0xee6b280000000000, 0x9502f90000000000, + 0xba43b74000000000, 0xe8d4a51000000000, 0x9184e72a00000000, + 0xb5e620f480000000, 0xe35fa931a0000000, 0x8e1bc9bf04000000, + 0xb1a2bc2ec5000000, 0xde0b6b3a76400000, 0x8ac7230489e80000, + 0xad78ebc5ac620000, 0xd8d726b7177a8000, 0x878678326eac9000, + 0xa968163f0a57b400, 0xd3c21bcecceda100, 0x84595161401484a0, + 0xa56fa5b99019a5c8, 0xcecb8f27f4200f3a, 0x813f3978f8940985, + 0xa18f07d736b90be6, 0xc9f2c9cd04674edf, 0xfc6f7c4045812297, + 0x9dc5ada82b70b59e, 0xc5371912364ce306, 0xf684df56c3e01bc7, + 0x9a130b963a6c115d, 0xc097ce7bc90715b4, 0xf0bdc21abb48db21, + 0x96769950b50d88f5, 0xbc143fa4e250eb32, 0xeb194f8e1ae525fe, + 0x92efd1b8d0cf37bf, 0xb7abc627050305ae, 0xe596b7b0c643c71a, + 0x8f7e32ce7bea5c70, 0xb35dbf821ae4f38c, 0xe0352f62a19e306f}; + return pow10_significands[k - float_info::min_k]; + } + + struct compute_mul_result { + carrier_uint result; + bool is_integer; + }; + struct compute_mul_parity_result { + bool parity; + bool is_integer; + }; + + static auto compute_mul(carrier_uint u, + const cache_entry_type& cache) noexcept + -> compute_mul_result { + auto r = umul96_upper64(u, cache); + return {static_cast(r >> 32), + static_cast(r) == 0}; + } + + static auto compute_delta(const cache_entry_type& cache, int beta) noexcept + -> uint32_t { + return static_cast(cache >> (64 - 1 - beta)); + } + + static auto compute_mul_parity(carrier_uint two_f, + const cache_entry_type& cache, + int beta) noexcept + -> compute_mul_parity_result { + 
FMT_ASSERT(beta >= 1, ""); + FMT_ASSERT(beta < 64, ""); + + auto r = umul96_lower64(two_f, cache); + return {((r >> (64 - beta)) & 1) != 0, + static_cast(r >> (32 - beta)) == 0}; + } + + static auto compute_left_endpoint_for_shorter_interval_case( + const cache_entry_type& cache, int beta) noexcept -> carrier_uint { + return static_cast( + (cache - (cache >> (num_significand_bits() + 2))) >> + (64 - num_significand_bits() - 1 - beta)); + } + + static auto compute_right_endpoint_for_shorter_interval_case( + const cache_entry_type& cache, int beta) noexcept -> carrier_uint { + return static_cast( + (cache + (cache >> (num_significand_bits() + 1))) >> + (64 - num_significand_bits() - 1 - beta)); + } + + static auto compute_round_up_for_shorter_interval_case( + const cache_entry_type& cache, int beta) noexcept -> carrier_uint { + return (static_cast( + cache >> (64 - num_significand_bits() - 2 - beta)) + + 1) / + 2; + } +}; + +template <> struct cache_accessor { + using carrier_uint = float_info::carrier_uint; + using cache_entry_type = uint128_fallback; + + static auto get_cached_power(int k) noexcept -> uint128_fallback { + FMT_ASSERT(k >= float_info::min_k && k <= float_info::max_k, + "k is out of range"); + + static constexpr uint128_fallback pow10_significands[] = { +#if FMT_USE_FULL_CACHE_DRAGONBOX + {0xff77b1fcbebcdc4f, 0x25e8e89c13bb0f7b}, + {0x9faacf3df73609b1, 0x77b191618c54e9ad}, + {0xc795830d75038c1d, 0xd59df5b9ef6a2418}, + {0xf97ae3d0d2446f25, 0x4b0573286b44ad1e}, + {0x9becce62836ac577, 0x4ee367f9430aec33}, + {0xc2e801fb244576d5, 0x229c41f793cda740}, + {0xf3a20279ed56d48a, 0x6b43527578c11110}, + {0x9845418c345644d6, 0x830a13896b78aaaa}, + {0xbe5691ef416bd60c, 0x23cc986bc656d554}, + {0xedec366b11c6cb8f, 0x2cbfbe86b7ec8aa9}, + {0x94b3a202eb1c3f39, 0x7bf7d71432f3d6aa}, + {0xb9e08a83a5e34f07, 0xdaf5ccd93fb0cc54}, + {0xe858ad248f5c22c9, 0xd1b3400f8f9cff69}, + {0x91376c36d99995be, 0x23100809b9c21fa2}, + {0xb58547448ffffb2d, 0xabd40a0c2832a78b}, + 
{0xe2e69915b3fff9f9, 0x16c90c8f323f516d}, + {0x8dd01fad907ffc3b, 0xae3da7d97f6792e4}, + {0xb1442798f49ffb4a, 0x99cd11cfdf41779d}, + {0xdd95317f31c7fa1d, 0x40405643d711d584}, + {0x8a7d3eef7f1cfc52, 0x482835ea666b2573}, + {0xad1c8eab5ee43b66, 0xda3243650005eed0}, + {0xd863b256369d4a40, 0x90bed43e40076a83}, + {0x873e4f75e2224e68, 0x5a7744a6e804a292}, + {0xa90de3535aaae202, 0x711515d0a205cb37}, + {0xd3515c2831559a83, 0x0d5a5b44ca873e04}, + {0x8412d9991ed58091, 0xe858790afe9486c3}, + {0xa5178fff668ae0b6, 0x626e974dbe39a873}, + {0xce5d73ff402d98e3, 0xfb0a3d212dc81290}, + {0x80fa687f881c7f8e, 0x7ce66634bc9d0b9a}, + {0xa139029f6a239f72, 0x1c1fffc1ebc44e81}, + {0xc987434744ac874e, 0xa327ffb266b56221}, + {0xfbe9141915d7a922, 0x4bf1ff9f0062baa9}, + {0x9d71ac8fada6c9b5, 0x6f773fc3603db4aa}, + {0xc4ce17b399107c22, 0xcb550fb4384d21d4}, + {0xf6019da07f549b2b, 0x7e2a53a146606a49}, + {0x99c102844f94e0fb, 0x2eda7444cbfc426e}, + {0xc0314325637a1939, 0xfa911155fefb5309}, + {0xf03d93eebc589f88, 0x793555ab7eba27cb}, + {0x96267c7535b763b5, 0x4bc1558b2f3458df}, + {0xbbb01b9283253ca2, 0x9eb1aaedfb016f17}, + {0xea9c227723ee8bcb, 0x465e15a979c1cadd}, + {0x92a1958a7675175f, 0x0bfacd89ec191eca}, + {0xb749faed14125d36, 0xcef980ec671f667c}, + {0xe51c79a85916f484, 0x82b7e12780e7401b}, + {0x8f31cc0937ae58d2, 0xd1b2ecb8b0908811}, + {0xb2fe3f0b8599ef07, 0x861fa7e6dcb4aa16}, + {0xdfbdcece67006ac9, 0x67a791e093e1d49b}, + {0x8bd6a141006042bd, 0xe0c8bb2c5c6d24e1}, + {0xaecc49914078536d, 0x58fae9f773886e19}, + {0xda7f5bf590966848, 0xaf39a475506a899f}, + {0x888f99797a5e012d, 0x6d8406c952429604}, + {0xaab37fd7d8f58178, 0xc8e5087ba6d33b84}, + {0xd5605fcdcf32e1d6, 0xfb1e4a9a90880a65}, + {0x855c3be0a17fcd26, 0x5cf2eea09a550680}, + {0xa6b34ad8c9dfc06f, 0xf42faa48c0ea481f}, + {0xd0601d8efc57b08b, 0xf13b94daf124da27}, + {0x823c12795db6ce57, 0x76c53d08d6b70859}, + {0xa2cb1717b52481ed, 0x54768c4b0c64ca6f}, + {0xcb7ddcdda26da268, 0xa9942f5dcf7dfd0a}, + {0xfe5d54150b090b02, 0xd3f93b35435d7c4d}, + 
{0x9efa548d26e5a6e1, 0xc47bc5014a1a6db0}, + {0xc6b8e9b0709f109a, 0x359ab6419ca1091c}, + {0xf867241c8cc6d4c0, 0xc30163d203c94b63}, + {0x9b407691d7fc44f8, 0x79e0de63425dcf1e}, + {0xc21094364dfb5636, 0x985915fc12f542e5}, + {0xf294b943e17a2bc4, 0x3e6f5b7b17b2939e}, + {0x979cf3ca6cec5b5a, 0xa705992ceecf9c43}, + {0xbd8430bd08277231, 0x50c6ff782a838354}, + {0xece53cec4a314ebd, 0xa4f8bf5635246429}, + {0x940f4613ae5ed136, 0x871b7795e136be9a}, + {0xb913179899f68584, 0x28e2557b59846e40}, + {0xe757dd7ec07426e5, 0x331aeada2fe589d0}, + {0x9096ea6f3848984f, 0x3ff0d2c85def7622}, + {0xb4bca50b065abe63, 0x0fed077a756b53aa}, + {0xe1ebce4dc7f16dfb, 0xd3e8495912c62895}, + {0x8d3360f09cf6e4bd, 0x64712dd7abbbd95d}, + {0xb080392cc4349dec, 0xbd8d794d96aacfb4}, + {0xdca04777f541c567, 0xecf0d7a0fc5583a1}, + {0x89e42caaf9491b60, 0xf41686c49db57245}, + {0xac5d37d5b79b6239, 0x311c2875c522ced6}, + {0xd77485cb25823ac7, 0x7d633293366b828c}, + {0x86a8d39ef77164bc, 0xae5dff9c02033198}, + {0xa8530886b54dbdeb, 0xd9f57f830283fdfd}, + {0xd267caa862a12d66, 0xd072df63c324fd7c}, + {0x8380dea93da4bc60, 0x4247cb9e59f71e6e}, + {0xa46116538d0deb78, 0x52d9be85f074e609}, + {0xcd795be870516656, 0x67902e276c921f8c}, + {0x806bd9714632dff6, 0x00ba1cd8a3db53b7}, + {0xa086cfcd97bf97f3, 0x80e8a40eccd228a5}, + {0xc8a883c0fdaf7df0, 0x6122cd128006b2ce}, + {0xfad2a4b13d1b5d6c, 0x796b805720085f82}, + {0x9cc3a6eec6311a63, 0xcbe3303674053bb1}, + {0xc3f490aa77bd60fc, 0xbedbfc4411068a9d}, + {0xf4f1b4d515acb93b, 0xee92fb5515482d45}, + {0x991711052d8bf3c5, 0x751bdd152d4d1c4b}, + {0xbf5cd54678eef0b6, 0xd262d45a78a0635e}, + {0xef340a98172aace4, 0x86fb897116c87c35}, + {0x9580869f0e7aac0e, 0xd45d35e6ae3d4da1}, + {0xbae0a846d2195712, 0x8974836059cca10a}, + {0xe998d258869facd7, 0x2bd1a438703fc94c}, + {0x91ff83775423cc06, 0x7b6306a34627ddd0}, + {0xb67f6455292cbf08, 0x1a3bc84c17b1d543}, + {0xe41f3d6a7377eeca, 0x20caba5f1d9e4a94}, + {0x8e938662882af53e, 0x547eb47b7282ee9d}, + {0xb23867fb2a35b28d, 0xe99e619a4f23aa44}, + 
{0xdec681f9f4c31f31, 0x6405fa00e2ec94d5}, + {0x8b3c113c38f9f37e, 0xde83bc408dd3dd05}, + {0xae0b158b4738705e, 0x9624ab50b148d446}, + {0xd98ddaee19068c76, 0x3badd624dd9b0958}, + {0x87f8a8d4cfa417c9, 0xe54ca5d70a80e5d7}, + {0xa9f6d30a038d1dbc, 0x5e9fcf4ccd211f4d}, + {0xd47487cc8470652b, 0x7647c32000696720}, + {0x84c8d4dfd2c63f3b, 0x29ecd9f40041e074}, + {0xa5fb0a17c777cf09, 0xf468107100525891}, + {0xcf79cc9db955c2cc, 0x7182148d4066eeb5}, + {0x81ac1fe293d599bf, 0xc6f14cd848405531}, + {0xa21727db38cb002f, 0xb8ada00e5a506a7d}, + {0xca9cf1d206fdc03b, 0xa6d90811f0e4851d}, + {0xfd442e4688bd304a, 0x908f4a166d1da664}, + {0x9e4a9cec15763e2e, 0x9a598e4e043287ff}, + {0xc5dd44271ad3cdba, 0x40eff1e1853f29fe}, + {0xf7549530e188c128, 0xd12bee59e68ef47d}, + {0x9a94dd3e8cf578b9, 0x82bb74f8301958cf}, + {0xc13a148e3032d6e7, 0xe36a52363c1faf02}, + {0xf18899b1bc3f8ca1, 0xdc44e6c3cb279ac2}, + {0x96f5600f15a7b7e5, 0x29ab103a5ef8c0ba}, + {0xbcb2b812db11a5de, 0x7415d448f6b6f0e8}, + {0xebdf661791d60f56, 0x111b495b3464ad22}, + {0x936b9fcebb25c995, 0xcab10dd900beec35}, + {0xb84687c269ef3bfb, 0x3d5d514f40eea743}, + {0xe65829b3046b0afa, 0x0cb4a5a3112a5113}, + {0x8ff71a0fe2c2e6dc, 0x47f0e785eaba72ac}, + {0xb3f4e093db73a093, 0x59ed216765690f57}, + {0xe0f218b8d25088b8, 0x306869c13ec3532d}, + {0x8c974f7383725573, 0x1e414218c73a13fc}, + {0xafbd2350644eeacf, 0xe5d1929ef90898fb}, + {0xdbac6c247d62a583, 0xdf45f746b74abf3a}, + {0x894bc396ce5da772, 0x6b8bba8c328eb784}, + {0xab9eb47c81f5114f, 0x066ea92f3f326565}, + {0xd686619ba27255a2, 0xc80a537b0efefebe}, + {0x8613fd0145877585, 0xbd06742ce95f5f37}, + {0xa798fc4196e952e7, 0x2c48113823b73705}, + {0xd17f3b51fca3a7a0, 0xf75a15862ca504c6}, + {0x82ef85133de648c4, 0x9a984d73dbe722fc}, + {0xa3ab66580d5fdaf5, 0xc13e60d0d2e0ebbb}, + {0xcc963fee10b7d1b3, 0x318df905079926a9}, + {0xffbbcfe994e5c61f, 0xfdf17746497f7053}, + {0x9fd561f1fd0f9bd3, 0xfeb6ea8bedefa634}, + {0xc7caba6e7c5382c8, 0xfe64a52ee96b8fc1}, + {0xf9bd690a1b68637b, 0x3dfdce7aa3c673b1}, + 
{0x9c1661a651213e2d, 0x06bea10ca65c084f}, + {0xc31bfa0fe5698db8, 0x486e494fcff30a63}, + {0xf3e2f893dec3f126, 0x5a89dba3c3efccfb}, + {0x986ddb5c6b3a76b7, 0xf89629465a75e01d}, + {0xbe89523386091465, 0xf6bbb397f1135824}, + {0xee2ba6c0678b597f, 0x746aa07ded582e2d}, + {0x94db483840b717ef, 0xa8c2a44eb4571cdd}, + {0xba121a4650e4ddeb, 0x92f34d62616ce414}, + {0xe896a0d7e51e1566, 0x77b020baf9c81d18}, + {0x915e2486ef32cd60, 0x0ace1474dc1d122f}, + {0xb5b5ada8aaff80b8, 0x0d819992132456bb}, + {0xe3231912d5bf60e6, 0x10e1fff697ed6c6a}, + {0x8df5efabc5979c8f, 0xca8d3ffa1ef463c2}, + {0xb1736b96b6fd83b3, 0xbd308ff8a6b17cb3}, + {0xddd0467c64bce4a0, 0xac7cb3f6d05ddbdf}, + {0x8aa22c0dbef60ee4, 0x6bcdf07a423aa96c}, + {0xad4ab7112eb3929d, 0x86c16c98d2c953c7}, + {0xd89d64d57a607744, 0xe871c7bf077ba8b8}, + {0x87625f056c7c4a8b, 0x11471cd764ad4973}, + {0xa93af6c6c79b5d2d, 0xd598e40d3dd89bd0}, + {0xd389b47879823479, 0x4aff1d108d4ec2c4}, + {0x843610cb4bf160cb, 0xcedf722a585139bb}, + {0xa54394fe1eedb8fe, 0xc2974eb4ee658829}, + {0xce947a3da6a9273e, 0x733d226229feea33}, + {0x811ccc668829b887, 0x0806357d5a3f5260}, + {0xa163ff802a3426a8, 0xca07c2dcb0cf26f8}, + {0xc9bcff6034c13052, 0xfc89b393dd02f0b6}, + {0xfc2c3f3841f17c67, 0xbbac2078d443ace3}, + {0x9d9ba7832936edc0, 0xd54b944b84aa4c0e}, + {0xc5029163f384a931, 0x0a9e795e65d4df12}, + {0xf64335bcf065d37d, 0x4d4617b5ff4a16d6}, + {0x99ea0196163fa42e, 0x504bced1bf8e4e46}, + {0xc06481fb9bcf8d39, 0xe45ec2862f71e1d7}, + {0xf07da27a82c37088, 0x5d767327bb4e5a4d}, + {0x964e858c91ba2655, 0x3a6a07f8d510f870}, + {0xbbe226efb628afea, 0x890489f70a55368c}, + {0xeadab0aba3b2dbe5, 0x2b45ac74ccea842f}, + {0x92c8ae6b464fc96f, 0x3b0b8bc90012929e}, + {0xb77ada0617e3bbcb, 0x09ce6ebb40173745}, + {0xe55990879ddcaabd, 0xcc420a6a101d0516}, + {0x8f57fa54c2a9eab6, 0x9fa946824a12232e}, + {0xb32df8e9f3546564, 0x47939822dc96abfa}, + {0xdff9772470297ebd, 0x59787e2b93bc56f8}, + {0x8bfbea76c619ef36, 0x57eb4edb3c55b65b}, + {0xaefae51477a06b03, 0xede622920b6b23f2}, + 
{0xdab99e59958885c4, 0xe95fab368e45ecee}, + {0x88b402f7fd75539b, 0x11dbcb0218ebb415}, + {0xaae103b5fcd2a881, 0xd652bdc29f26a11a}, + {0xd59944a37c0752a2, 0x4be76d3346f04960}, + {0x857fcae62d8493a5, 0x6f70a4400c562ddc}, + {0xa6dfbd9fb8e5b88e, 0xcb4ccd500f6bb953}, + {0xd097ad07a71f26b2, 0x7e2000a41346a7a8}, + {0x825ecc24c873782f, 0x8ed400668c0c28c9}, + {0xa2f67f2dfa90563b, 0x728900802f0f32fb}, + {0xcbb41ef979346bca, 0x4f2b40a03ad2ffba}, + {0xfea126b7d78186bc, 0xe2f610c84987bfa9}, + {0x9f24b832e6b0f436, 0x0dd9ca7d2df4d7ca}, + {0xc6ede63fa05d3143, 0x91503d1c79720dbc}, + {0xf8a95fcf88747d94, 0x75a44c6397ce912b}, + {0x9b69dbe1b548ce7c, 0xc986afbe3ee11abb}, + {0xc24452da229b021b, 0xfbe85badce996169}, + {0xf2d56790ab41c2a2, 0xfae27299423fb9c4}, + {0x97c560ba6b0919a5, 0xdccd879fc967d41b}, + {0xbdb6b8e905cb600f, 0x5400e987bbc1c921}, + {0xed246723473e3813, 0x290123e9aab23b69}, + {0x9436c0760c86e30b, 0xf9a0b6720aaf6522}, + {0xb94470938fa89bce, 0xf808e40e8d5b3e6a}, + {0xe7958cb87392c2c2, 0xb60b1d1230b20e05}, + {0x90bd77f3483bb9b9, 0xb1c6f22b5e6f48c3}, + {0xb4ecd5f01a4aa828, 0x1e38aeb6360b1af4}, + {0xe2280b6c20dd5232, 0x25c6da63c38de1b1}, + {0x8d590723948a535f, 0x579c487e5a38ad0f}, + {0xb0af48ec79ace837, 0x2d835a9df0c6d852}, + {0xdcdb1b2798182244, 0xf8e431456cf88e66}, + {0x8a08f0f8bf0f156b, 0x1b8e9ecb641b5900}, + {0xac8b2d36eed2dac5, 0xe272467e3d222f40}, + {0xd7adf884aa879177, 0x5b0ed81dcc6abb10}, + {0x86ccbb52ea94baea, 0x98e947129fc2b4ea}, + {0xa87fea27a539e9a5, 0x3f2398d747b36225}, + {0xd29fe4b18e88640e, 0x8eec7f0d19a03aae}, + {0x83a3eeeef9153e89, 0x1953cf68300424ad}, + {0xa48ceaaab75a8e2b, 0x5fa8c3423c052dd8}, + {0xcdb02555653131b6, 0x3792f412cb06794e}, + {0x808e17555f3ebf11, 0xe2bbd88bbee40bd1}, + {0xa0b19d2ab70e6ed6, 0x5b6aceaeae9d0ec5}, + {0xc8de047564d20a8b, 0xf245825a5a445276}, + {0xfb158592be068d2e, 0xeed6e2f0f0d56713}, + {0x9ced737bb6c4183d, 0x55464dd69685606c}, + {0xc428d05aa4751e4c, 0xaa97e14c3c26b887}, + {0xf53304714d9265df, 0xd53dd99f4b3066a9}, + 
{0x993fe2c6d07b7fab, 0xe546a8038efe402a}, + {0xbf8fdb78849a5f96, 0xde98520472bdd034}, + {0xef73d256a5c0f77c, 0x963e66858f6d4441}, + {0x95a8637627989aad, 0xdde7001379a44aa9}, + {0xbb127c53b17ec159, 0x5560c018580d5d53}, + {0xe9d71b689dde71af, 0xaab8f01e6e10b4a7}, + {0x9226712162ab070d, 0xcab3961304ca70e9}, + {0xb6b00d69bb55c8d1, 0x3d607b97c5fd0d23}, + {0xe45c10c42a2b3b05, 0x8cb89a7db77c506b}, + {0x8eb98a7a9a5b04e3, 0x77f3608e92adb243}, + {0xb267ed1940f1c61c, 0x55f038b237591ed4}, + {0xdf01e85f912e37a3, 0x6b6c46dec52f6689}, + {0x8b61313bbabce2c6, 0x2323ac4b3b3da016}, + {0xae397d8aa96c1b77, 0xabec975e0a0d081b}, + {0xd9c7dced53c72255, 0x96e7bd358c904a22}, + {0x881cea14545c7575, 0x7e50d64177da2e55}, + {0xaa242499697392d2, 0xdde50bd1d5d0b9ea}, + {0xd4ad2dbfc3d07787, 0x955e4ec64b44e865}, + {0x84ec3c97da624ab4, 0xbd5af13bef0b113f}, + {0xa6274bbdd0fadd61, 0xecb1ad8aeacdd58f}, + {0xcfb11ead453994ba, 0x67de18eda5814af3}, + {0x81ceb32c4b43fcf4, 0x80eacf948770ced8}, + {0xa2425ff75e14fc31, 0xa1258379a94d028e}, + {0xcad2f7f5359a3b3e, 0x096ee45813a04331}, + {0xfd87b5f28300ca0d, 0x8bca9d6e188853fd}, + {0x9e74d1b791e07e48, 0x775ea264cf55347e}, + {0xc612062576589dda, 0x95364afe032a819e}, + {0xf79687aed3eec551, 0x3a83ddbd83f52205}, + {0x9abe14cd44753b52, 0xc4926a9672793543}, + {0xc16d9a0095928a27, 0x75b7053c0f178294}, + {0xf1c90080baf72cb1, 0x5324c68b12dd6339}, + {0x971da05074da7bee, 0xd3f6fc16ebca5e04}, + {0xbce5086492111aea, 0x88f4bb1ca6bcf585}, + {0xec1e4a7db69561a5, 0x2b31e9e3d06c32e6}, + {0x9392ee8e921d5d07, 0x3aff322e62439fd0}, + {0xb877aa3236a4b449, 0x09befeb9fad487c3}, + {0xe69594bec44de15b, 0x4c2ebe687989a9b4}, + {0x901d7cf73ab0acd9, 0x0f9d37014bf60a11}, + {0xb424dc35095cd80f, 0x538484c19ef38c95}, + {0xe12e13424bb40e13, 0x2865a5f206b06fba}, + {0x8cbccc096f5088cb, 0xf93f87b7442e45d4}, + {0xafebff0bcb24aafe, 0xf78f69a51539d749}, + {0xdbe6fecebdedd5be, 0xb573440e5a884d1c}, + {0x89705f4136b4a597, 0x31680a88f8953031}, + {0xabcc77118461cefc, 0xfdc20d2b36ba7c3e}, + 
{0xd6bf94d5e57a42bc, 0x3d32907604691b4d}, + {0x8637bd05af6c69b5, 0xa63f9a49c2c1b110}, + {0xa7c5ac471b478423, 0x0fcf80dc33721d54}, + {0xd1b71758e219652b, 0xd3c36113404ea4a9}, + {0x83126e978d4fdf3b, 0x645a1cac083126ea}, + {0xa3d70a3d70a3d70a, 0x3d70a3d70a3d70a4}, + {0xcccccccccccccccc, 0xcccccccccccccccd}, + {0x8000000000000000, 0x0000000000000000}, + {0xa000000000000000, 0x0000000000000000}, + {0xc800000000000000, 0x0000000000000000}, + {0xfa00000000000000, 0x0000000000000000}, + {0x9c40000000000000, 0x0000000000000000}, + {0xc350000000000000, 0x0000000000000000}, + {0xf424000000000000, 0x0000000000000000}, + {0x9896800000000000, 0x0000000000000000}, + {0xbebc200000000000, 0x0000000000000000}, + {0xee6b280000000000, 0x0000000000000000}, + {0x9502f90000000000, 0x0000000000000000}, + {0xba43b74000000000, 0x0000000000000000}, + {0xe8d4a51000000000, 0x0000000000000000}, + {0x9184e72a00000000, 0x0000000000000000}, + {0xb5e620f480000000, 0x0000000000000000}, + {0xe35fa931a0000000, 0x0000000000000000}, + {0x8e1bc9bf04000000, 0x0000000000000000}, + {0xb1a2bc2ec5000000, 0x0000000000000000}, + {0xde0b6b3a76400000, 0x0000000000000000}, + {0x8ac7230489e80000, 0x0000000000000000}, + {0xad78ebc5ac620000, 0x0000000000000000}, + {0xd8d726b7177a8000, 0x0000000000000000}, + {0x878678326eac9000, 0x0000000000000000}, + {0xa968163f0a57b400, 0x0000000000000000}, + {0xd3c21bcecceda100, 0x0000000000000000}, + {0x84595161401484a0, 0x0000000000000000}, + {0xa56fa5b99019a5c8, 0x0000000000000000}, + {0xcecb8f27f4200f3a, 0x0000000000000000}, + {0x813f3978f8940984, 0x4000000000000000}, + {0xa18f07d736b90be5, 0x5000000000000000}, + {0xc9f2c9cd04674ede, 0xa400000000000000}, + {0xfc6f7c4045812296, 0x4d00000000000000}, + {0x9dc5ada82b70b59d, 0xf020000000000000}, + {0xc5371912364ce305, 0x6c28000000000000}, + {0xf684df56c3e01bc6, 0xc732000000000000}, + {0x9a130b963a6c115c, 0x3c7f400000000000}, + {0xc097ce7bc90715b3, 0x4b9f100000000000}, + {0xf0bdc21abb48db20, 0x1e86d40000000000}, + 
{0x96769950b50d88f4, 0x1314448000000000}, + {0xbc143fa4e250eb31, 0x17d955a000000000}, + {0xeb194f8e1ae525fd, 0x5dcfab0800000000}, + {0x92efd1b8d0cf37be, 0x5aa1cae500000000}, + {0xb7abc627050305ad, 0xf14a3d9e40000000}, + {0xe596b7b0c643c719, 0x6d9ccd05d0000000}, + {0x8f7e32ce7bea5c6f, 0xe4820023a2000000}, + {0xb35dbf821ae4f38b, 0xdda2802c8a800000}, + {0xe0352f62a19e306e, 0xd50b2037ad200000}, + {0x8c213d9da502de45, 0x4526f422cc340000}, + {0xaf298d050e4395d6, 0x9670b12b7f410000}, + {0xdaf3f04651d47b4c, 0x3c0cdd765f114000}, + {0x88d8762bf324cd0f, 0xa5880a69fb6ac800}, + {0xab0e93b6efee0053, 0x8eea0d047a457a00}, + {0xd5d238a4abe98068, 0x72a4904598d6d880}, + {0x85a36366eb71f041, 0x47a6da2b7f864750}, + {0xa70c3c40a64e6c51, 0x999090b65f67d924}, + {0xd0cf4b50cfe20765, 0xfff4b4e3f741cf6d}, + {0x82818f1281ed449f, 0xbff8f10e7a8921a5}, + {0xa321f2d7226895c7, 0xaff72d52192b6a0e}, + {0xcbea6f8ceb02bb39, 0x9bf4f8a69f764491}, + {0xfee50b7025c36a08, 0x02f236d04753d5b5}, + {0x9f4f2726179a2245, 0x01d762422c946591}, + {0xc722f0ef9d80aad6, 0x424d3ad2b7b97ef6}, + {0xf8ebad2b84e0d58b, 0xd2e0898765a7deb3}, + {0x9b934c3b330c8577, 0x63cc55f49f88eb30}, + {0xc2781f49ffcfa6d5, 0x3cbf6b71c76b25fc}, + {0xf316271c7fc3908a, 0x8bef464e3945ef7b}, + {0x97edd871cfda3a56, 0x97758bf0e3cbb5ad}, + {0xbde94e8e43d0c8ec, 0x3d52eeed1cbea318}, + {0xed63a231d4c4fb27, 0x4ca7aaa863ee4bde}, + {0x945e455f24fb1cf8, 0x8fe8caa93e74ef6b}, + {0xb975d6b6ee39e436, 0xb3e2fd538e122b45}, + {0xe7d34c64a9c85d44, 0x60dbbca87196b617}, + {0x90e40fbeea1d3a4a, 0xbc8955e946fe31ce}, + {0xb51d13aea4a488dd, 0x6babab6398bdbe42}, + {0xe264589a4dcdab14, 0xc696963c7eed2dd2}, + {0x8d7eb76070a08aec, 0xfc1e1de5cf543ca3}, + {0xb0de65388cc8ada8, 0x3b25a55f43294bcc}, + {0xdd15fe86affad912, 0x49ef0eb713f39ebf}, + {0x8a2dbf142dfcc7ab, 0x6e3569326c784338}, + {0xacb92ed9397bf996, 0x49c2c37f07965405}, + {0xd7e77a8f87daf7fb, 0xdc33745ec97be907}, + {0x86f0ac99b4e8dafd, 0x69a028bb3ded71a4}, + {0xa8acd7c0222311bc, 0xc40832ea0d68ce0d}, + 
{0xd2d80db02aabd62b, 0xf50a3fa490c30191}, + {0x83c7088e1aab65db, 0x792667c6da79e0fb}, + {0xa4b8cab1a1563f52, 0x577001b891185939}, + {0xcde6fd5e09abcf26, 0xed4c0226b55e6f87}, + {0x80b05e5ac60b6178, 0x544f8158315b05b5}, + {0xa0dc75f1778e39d6, 0x696361ae3db1c722}, + {0xc913936dd571c84c, 0x03bc3a19cd1e38ea}, + {0xfb5878494ace3a5f, 0x04ab48a04065c724}, + {0x9d174b2dcec0e47b, 0x62eb0d64283f9c77}, + {0xc45d1df942711d9a, 0x3ba5d0bd324f8395}, + {0xf5746577930d6500, 0xca8f44ec7ee3647a}, + {0x9968bf6abbe85f20, 0x7e998b13cf4e1ecc}, + {0xbfc2ef456ae276e8, 0x9e3fedd8c321a67f}, + {0xefb3ab16c59b14a2, 0xc5cfe94ef3ea101f}, + {0x95d04aee3b80ece5, 0xbba1f1d158724a13}, + {0xbb445da9ca61281f, 0x2a8a6e45ae8edc98}, + {0xea1575143cf97226, 0xf52d09d71a3293be}, + {0x924d692ca61be758, 0x593c2626705f9c57}, + {0xb6e0c377cfa2e12e, 0x6f8b2fb00c77836d}, + {0xe498f455c38b997a, 0x0b6dfb9c0f956448}, + {0x8edf98b59a373fec, 0x4724bd4189bd5ead}, + {0xb2977ee300c50fe7, 0x58edec91ec2cb658}, + {0xdf3d5e9bc0f653e1, 0x2f2967b66737e3ee}, + {0x8b865b215899f46c, 0xbd79e0d20082ee75}, + {0xae67f1e9aec07187, 0xecd8590680a3aa12}, + {0xda01ee641a708de9, 0xe80e6f4820cc9496}, + {0x884134fe908658b2, 0x3109058d147fdcde}, + {0xaa51823e34a7eede, 0xbd4b46f0599fd416}, + {0xd4e5e2cdc1d1ea96, 0x6c9e18ac7007c91b}, + {0x850fadc09923329e, 0x03e2cf6bc604ddb1}, + {0xa6539930bf6bff45, 0x84db8346b786151d}, + {0xcfe87f7cef46ff16, 0xe612641865679a64}, + {0x81f14fae158c5f6e, 0x4fcb7e8f3f60c07f}, + {0xa26da3999aef7749, 0xe3be5e330f38f09e}, + {0xcb090c8001ab551c, 0x5cadf5bfd3072cc6}, + {0xfdcb4fa002162a63, 0x73d9732fc7c8f7f7}, + {0x9e9f11c4014dda7e, 0x2867e7fddcdd9afb}, + {0xc646d63501a1511d, 0xb281e1fd541501b9}, + {0xf7d88bc24209a565, 0x1f225a7ca91a4227}, + {0x9ae757596946075f, 0x3375788de9b06959}, + {0xc1a12d2fc3978937, 0x0052d6b1641c83af}, + {0xf209787bb47d6b84, 0xc0678c5dbd23a49b}, + {0x9745eb4d50ce6332, 0xf840b7ba963646e1}, + {0xbd176620a501fbff, 0xb650e5a93bc3d899}, + {0xec5d3fa8ce427aff, 0xa3e51f138ab4cebf}, + 
{0x93ba47c980e98cdf, 0xc66f336c36b10138}, + {0xb8a8d9bbe123f017, 0xb80b0047445d4185}, + {0xe6d3102ad96cec1d, 0xa60dc059157491e6}, + {0x9043ea1ac7e41392, 0x87c89837ad68db30}, + {0xb454e4a179dd1877, 0x29babe4598c311fc}, + {0xe16a1dc9d8545e94, 0xf4296dd6fef3d67b}, + {0x8ce2529e2734bb1d, 0x1899e4a65f58660d}, + {0xb01ae745b101e9e4, 0x5ec05dcff72e7f90}, + {0xdc21a1171d42645d, 0x76707543f4fa1f74}, + {0x899504ae72497eba, 0x6a06494a791c53a9}, + {0xabfa45da0edbde69, 0x0487db9d17636893}, + {0xd6f8d7509292d603, 0x45a9d2845d3c42b7}, + {0x865b86925b9bc5c2, 0x0b8a2392ba45a9b3}, + {0xa7f26836f282b732, 0x8e6cac7768d7141f}, + {0xd1ef0244af2364ff, 0x3207d795430cd927}, + {0x8335616aed761f1f, 0x7f44e6bd49e807b9}, + {0xa402b9c5a8d3a6e7, 0x5f16206c9c6209a7}, + {0xcd036837130890a1, 0x36dba887c37a8c10}, + {0x802221226be55a64, 0xc2494954da2c978a}, + {0xa02aa96b06deb0fd, 0xf2db9baa10b7bd6d}, + {0xc83553c5c8965d3d, 0x6f92829494e5acc8}, + {0xfa42a8b73abbf48c, 0xcb772339ba1f17fa}, + {0x9c69a97284b578d7, 0xff2a760414536efc}, + {0xc38413cf25e2d70d, 0xfef5138519684abb}, + {0xf46518c2ef5b8cd1, 0x7eb258665fc25d6a}, + {0x98bf2f79d5993802, 0xef2f773ffbd97a62}, + {0xbeeefb584aff8603, 0xaafb550ffacfd8fb}, + {0xeeaaba2e5dbf6784, 0x95ba2a53f983cf39}, + {0x952ab45cfa97a0b2, 0xdd945a747bf26184}, + {0xba756174393d88df, 0x94f971119aeef9e5}, + {0xe912b9d1478ceb17, 0x7a37cd5601aab85e}, + {0x91abb422ccb812ee, 0xac62e055c10ab33b}, + {0xb616a12b7fe617aa, 0x577b986b314d600a}, + {0xe39c49765fdf9d94, 0xed5a7e85fda0b80c}, + {0x8e41ade9fbebc27d, 0x14588f13be847308}, + {0xb1d219647ae6b31c, 0x596eb2d8ae258fc9}, + {0xde469fbd99a05fe3, 0x6fca5f8ed9aef3bc}, + {0x8aec23d680043bee, 0x25de7bb9480d5855}, + {0xada72ccc20054ae9, 0xaf561aa79a10ae6b}, + {0xd910f7ff28069da4, 0x1b2ba1518094da05}, + {0x87aa9aff79042286, 0x90fb44d2f05d0843}, + {0xa99541bf57452b28, 0x353a1607ac744a54}, + {0xd3fa922f2d1675f2, 0x42889b8997915ce9}, + {0x847c9b5d7c2e09b7, 0x69956135febada12}, + {0xa59bc234db398c25, 0x43fab9837e699096}, + 
{0xcf02b2c21207ef2e, 0x94f967e45e03f4bc}, + {0x8161afb94b44f57d, 0x1d1be0eebac278f6}, + {0xa1ba1ba79e1632dc, 0x6462d92a69731733}, + {0xca28a291859bbf93, 0x7d7b8f7503cfdcff}, + {0xfcb2cb35e702af78, 0x5cda735244c3d43f}, + {0x9defbf01b061adab, 0x3a0888136afa64a8}, + {0xc56baec21c7a1916, 0x088aaa1845b8fdd1}, + {0xf6c69a72a3989f5b, 0x8aad549e57273d46}, + {0x9a3c2087a63f6399, 0x36ac54e2f678864c}, + {0xc0cb28a98fcf3c7f, 0x84576a1bb416a7de}, + {0xf0fdf2d3f3c30b9f, 0x656d44a2a11c51d6}, + {0x969eb7c47859e743, 0x9f644ae5a4b1b326}, + {0xbc4665b596706114, 0x873d5d9f0dde1fef}, + {0xeb57ff22fc0c7959, 0xa90cb506d155a7eb}, + {0x9316ff75dd87cbd8, 0x09a7f12442d588f3}, + {0xb7dcbf5354e9bece, 0x0c11ed6d538aeb30}, + {0xe5d3ef282a242e81, 0x8f1668c8a86da5fb}, + {0x8fa475791a569d10, 0xf96e017d694487bd}, + {0xb38d92d760ec4455, 0x37c981dcc395a9ad}, + {0xe070f78d3927556a, 0x85bbe253f47b1418}, + {0x8c469ab843b89562, 0x93956d7478ccec8f}, + {0xaf58416654a6babb, 0x387ac8d1970027b3}, + {0xdb2e51bfe9d0696a, 0x06997b05fcc0319f}, + {0x88fcf317f22241e2, 0x441fece3bdf81f04}, + {0xab3c2fddeeaad25a, 0xd527e81cad7626c4}, + {0xd60b3bd56a5586f1, 0x8a71e223d8d3b075}, + {0x85c7056562757456, 0xf6872d5667844e4a}, + {0xa738c6bebb12d16c, 0xb428f8ac016561dc}, + {0xd106f86e69d785c7, 0xe13336d701beba53}, + {0x82a45b450226b39c, 0xecc0024661173474}, + {0xa34d721642b06084, 0x27f002d7f95d0191}, + {0xcc20ce9bd35c78a5, 0x31ec038df7b441f5}, + {0xff290242c83396ce, 0x7e67047175a15272}, + {0x9f79a169bd203e41, 0x0f0062c6e984d387}, + {0xc75809c42c684dd1, 0x52c07b78a3e60869}, + {0xf92e0c3537826145, 0xa7709a56ccdf8a83}, + {0x9bbcc7a142b17ccb, 0x88a66076400bb692}, + {0xc2abf989935ddbfe, 0x6acff893d00ea436}, + {0xf356f7ebf83552fe, 0x0583f6b8c4124d44}, + {0x98165af37b2153de, 0xc3727a337a8b704b}, + {0xbe1bf1b059e9a8d6, 0x744f18c0592e4c5d}, + {0xeda2ee1c7064130c, 0x1162def06f79df74}, + {0x9485d4d1c63e8be7, 0x8addcb5645ac2ba9}, + {0xb9a74a0637ce2ee1, 0x6d953e2bd7173693}, + {0xe8111c87c5c1ba99, 0xc8fa8db6ccdd0438}, + 
{0x910ab1d4db9914a0, 0x1d9c9892400a22a3}, + {0xb54d5e4a127f59c8, 0x2503beb6d00cab4c}, + {0xe2a0b5dc971f303a, 0x2e44ae64840fd61e}, + {0x8da471a9de737e24, 0x5ceaecfed289e5d3}, + {0xb10d8e1456105dad, 0x7425a83e872c5f48}, + {0xdd50f1996b947518, 0xd12f124e28f7771a}, + {0x8a5296ffe33cc92f, 0x82bd6b70d99aaa70}, + {0xace73cbfdc0bfb7b, 0x636cc64d1001550c}, + {0xd8210befd30efa5a, 0x3c47f7e05401aa4f}, + {0x8714a775e3e95c78, 0x65acfaec34810a72}, + {0xa8d9d1535ce3b396, 0x7f1839a741a14d0e}, + {0xd31045a8341ca07c, 0x1ede48111209a051}, + {0x83ea2b892091e44d, 0x934aed0aab460433}, + {0xa4e4b66b68b65d60, 0xf81da84d56178540}, + {0xce1de40642e3f4b9, 0x36251260ab9d668f}, + {0x80d2ae83e9ce78f3, 0xc1d72b7c6b42601a}, + {0xa1075a24e4421730, 0xb24cf65b8612f820}, + {0xc94930ae1d529cfc, 0xdee033f26797b628}, + {0xfb9b7cd9a4a7443c, 0x169840ef017da3b2}, + {0x9d412e0806e88aa5, 0x8e1f289560ee864f}, + {0xc491798a08a2ad4e, 0xf1a6f2bab92a27e3}, + {0xf5b5d7ec8acb58a2, 0xae10af696774b1dc}, + {0x9991a6f3d6bf1765, 0xacca6da1e0a8ef2a}, + {0xbff610b0cc6edd3f, 0x17fd090a58d32af4}, + {0xeff394dcff8a948e, 0xddfc4b4cef07f5b1}, + {0x95f83d0a1fb69cd9, 0x4abdaf101564f98f}, + {0xbb764c4ca7a4440f, 0x9d6d1ad41abe37f2}, + {0xea53df5fd18d5513, 0x84c86189216dc5ee}, + {0x92746b9be2f8552c, 0x32fd3cf5b4e49bb5}, + {0xb7118682dbb66a77, 0x3fbc8c33221dc2a2}, + {0xe4d5e82392a40515, 0x0fabaf3feaa5334b}, + {0x8f05b1163ba6832d, 0x29cb4d87f2a7400f}, + {0xb2c71d5bca9023f8, 0x743e20e9ef511013}, + {0xdf78e4b2bd342cf6, 0x914da9246b255417}, + {0x8bab8eefb6409c1a, 0x1ad089b6c2f7548f}, + {0xae9672aba3d0c320, 0xa184ac2473b529b2}, + {0xda3c0f568cc4f3e8, 0xc9e5d72d90a2741f}, + {0x8865899617fb1871, 0x7e2fa67c7a658893}, + {0xaa7eebfb9df9de8d, 0xddbb901b98feeab8}, + {0xd51ea6fa85785631, 0x552a74227f3ea566}, + {0x8533285c936b35de, 0xd53a88958f872760}, + {0xa67ff273b8460356, 0x8a892abaf368f138}, + {0xd01fef10a657842c, 0x2d2b7569b0432d86}, + {0x8213f56a67f6b29b, 0x9c3b29620e29fc74}, + {0xa298f2c501f45f42, 0x8349f3ba91b47b90}, + 
{0xcb3f2f7642717713, 0x241c70a936219a74}, + {0xfe0efb53d30dd4d7, 0xed238cd383aa0111}, + {0x9ec95d1463e8a506, 0xf4363804324a40ab}, + {0xc67bb4597ce2ce48, 0xb143c6053edcd0d6}, + {0xf81aa16fdc1b81da, 0xdd94b7868e94050b}, + {0x9b10a4e5e9913128, 0xca7cf2b4191c8327}, + {0xc1d4ce1f63f57d72, 0xfd1c2f611f63a3f1}, + {0xf24a01a73cf2dccf, 0xbc633b39673c8ced}, + {0x976e41088617ca01, 0xd5be0503e085d814}, + {0xbd49d14aa79dbc82, 0x4b2d8644d8a74e19}, + {0xec9c459d51852ba2, 0xddf8e7d60ed1219f}, + {0x93e1ab8252f33b45, 0xcabb90e5c942b504}, + {0xb8da1662e7b00a17, 0x3d6a751f3b936244}, + {0xe7109bfba19c0c9d, 0x0cc512670a783ad5}, + {0x906a617d450187e2, 0x27fb2b80668b24c6}, + {0xb484f9dc9641e9da, 0xb1f9f660802dedf7}, + {0xe1a63853bbd26451, 0x5e7873f8a0396974}, + {0x8d07e33455637eb2, 0xdb0b487b6423e1e9}, + {0xb049dc016abc5e5f, 0x91ce1a9a3d2cda63}, + {0xdc5c5301c56b75f7, 0x7641a140cc7810fc}, + {0x89b9b3e11b6329ba, 0xa9e904c87fcb0a9e}, + {0xac2820d9623bf429, 0x546345fa9fbdcd45}, + {0xd732290fbacaf133, 0xa97c177947ad4096}, + {0x867f59a9d4bed6c0, 0x49ed8eabcccc485e}, + {0xa81f301449ee8c70, 0x5c68f256bfff5a75}, + {0xd226fc195c6a2f8c, 0x73832eec6fff3112}, + {0x83585d8fd9c25db7, 0xc831fd53c5ff7eac}, + {0xa42e74f3d032f525, 0xba3e7ca8b77f5e56}, + {0xcd3a1230c43fb26f, 0x28ce1bd2e55f35ec}, + {0x80444b5e7aa7cf85, 0x7980d163cf5b81b4}, + {0xa0555e361951c366, 0xd7e105bcc3326220}, + {0xc86ab5c39fa63440, 0x8dd9472bf3fefaa8}, + {0xfa856334878fc150, 0xb14f98f6f0feb952}, + {0x9c935e00d4b9d8d2, 0x6ed1bf9a569f33d4}, + {0xc3b8358109e84f07, 0x0a862f80ec4700c9}, + {0xf4a642e14c6262c8, 0xcd27bb612758c0fb}, + {0x98e7e9cccfbd7dbd, 0x8038d51cb897789d}, + {0xbf21e44003acdd2c, 0xe0470a63e6bd56c4}, + {0xeeea5d5004981478, 0x1858ccfce06cac75}, + {0x95527a5202df0ccb, 0x0f37801e0c43ebc9}, + {0xbaa718e68396cffd, 0xd30560258f54e6bb}, + {0xe950df20247c83fd, 0x47c6b82ef32a206a}, + {0x91d28b7416cdd27e, 0x4cdc331d57fa5442}, + {0xb6472e511c81471d, 0xe0133fe4adf8e953}, + {0xe3d8f9e563a198e5, 0x58180fddd97723a7}, + 
{0x8e679c2f5e44ff8f, 0x570f09eaa7ea7649}, + {0xb201833b35d63f73, 0x2cd2cc6551e513db}, + {0xde81e40a034bcf4f, 0xf8077f7ea65e58d2}, + {0x8b112e86420f6191, 0xfb04afaf27faf783}, + {0xadd57a27d29339f6, 0x79c5db9af1f9b564}, + {0xd94ad8b1c7380874, 0x18375281ae7822bd}, + {0x87cec76f1c830548, 0x8f2293910d0b15b6}, + {0xa9c2794ae3a3c69a, 0xb2eb3875504ddb23}, + {0xd433179d9c8cb841, 0x5fa60692a46151ec}, + {0x849feec281d7f328, 0xdbc7c41ba6bcd334}, + {0xa5c7ea73224deff3, 0x12b9b522906c0801}, + {0xcf39e50feae16bef, 0xd768226b34870a01}, + {0x81842f29f2cce375, 0xe6a1158300d46641}, + {0xa1e53af46f801c53, 0x60495ae3c1097fd1}, + {0xca5e89b18b602368, 0x385bb19cb14bdfc5}, + {0xfcf62c1dee382c42, 0x46729e03dd9ed7b6}, + {0x9e19db92b4e31ba9, 0x6c07a2c26a8346d2}, + {0xc5a05277621be293, 0xc7098b7305241886}, + {0xf70867153aa2db38, 0xb8cbee4fc66d1ea8}, + {0x9a65406d44a5c903, 0x737f74f1dc043329}, + {0xc0fe908895cf3b44, 0x505f522e53053ff3}, + {0xf13e34aabb430a15, 0x647726b9e7c68ff0}, + {0x96c6e0eab509e64d, 0x5eca783430dc19f6}, + {0xbc789925624c5fe0, 0xb67d16413d132073}, + {0xeb96bf6ebadf77d8, 0xe41c5bd18c57e890}, + {0x933e37a534cbaae7, 0x8e91b962f7b6f15a}, + {0xb80dc58e81fe95a1, 0x723627bbb5a4adb1}, + {0xe61136f2227e3b09, 0xcec3b1aaa30dd91d}, + {0x8fcac257558ee4e6, 0x213a4f0aa5e8a7b2}, + {0xb3bd72ed2af29e1f, 0xa988e2cd4f62d19e}, + {0xe0accfa875af45a7, 0x93eb1b80a33b8606}, + {0x8c6c01c9498d8b88, 0xbc72f130660533c4}, + {0xaf87023b9bf0ee6a, 0xeb8fad7c7f8680b5}, + {0xdb68c2ca82ed2a05, 0xa67398db9f6820e2}, +#else + {0xff77b1fcbebcdc4f, 0x25e8e89c13bb0f7b}, + {0xce5d73ff402d98e3, 0xfb0a3d212dc81290}, + {0xa6b34ad8c9dfc06f, 0xf42faa48c0ea481f}, + {0x86a8d39ef77164bc, 0xae5dff9c02033198}, + {0xd98ddaee19068c76, 0x3badd624dd9b0958}, + {0xafbd2350644eeacf, 0xe5d1929ef90898fb}, + {0x8df5efabc5979c8f, 0xca8d3ffa1ef463c2}, + {0xe55990879ddcaabd, 0xcc420a6a101d0516}, + {0xb94470938fa89bce, 0xf808e40e8d5b3e6a}, + {0x95a8637627989aad, 0xdde7001379a44aa9}, + {0xf1c90080baf72cb1, 0x5324c68b12dd6339}, + 
{0xc350000000000000, 0x0000000000000000}, + {0x9dc5ada82b70b59d, 0xf020000000000000}, + {0xfee50b7025c36a08, 0x02f236d04753d5b5}, + {0xcde6fd5e09abcf26, 0xed4c0226b55e6f87}, + {0xa6539930bf6bff45, 0x84db8346b786151d}, + {0x865b86925b9bc5c2, 0x0b8a2392ba45a9b3}, + {0xd910f7ff28069da4, 0x1b2ba1518094da05}, + {0xaf58416654a6babb, 0x387ac8d1970027b3}, + {0x8da471a9de737e24, 0x5ceaecfed289e5d3}, + {0xe4d5e82392a40515, 0x0fabaf3feaa5334b}, + {0xb8da1662e7b00a17, 0x3d6a751f3b936244}, + {0x95527a5202df0ccb, 0x0f37801e0c43ebc9}, + {0xf13e34aabb430a15, 0x647726b9e7c68ff0} +#endif + }; + +#if FMT_USE_FULL_CACHE_DRAGONBOX + return pow10_significands[k - float_info::min_k]; +#else + static constexpr uint64_t powers_of_5_64[] = { + 0x0000000000000001, 0x0000000000000005, 0x0000000000000019, + 0x000000000000007d, 0x0000000000000271, 0x0000000000000c35, + 0x0000000000003d09, 0x000000000001312d, 0x000000000005f5e1, + 0x00000000001dcd65, 0x00000000009502f9, 0x0000000002e90edd, + 0x000000000e8d4a51, 0x0000000048c27395, 0x000000016bcc41e9, + 0x000000071afd498d, 0x0000002386f26fc1, 0x000000b1a2bc2ec5, + 0x000003782dace9d9, 0x00001158e460913d, 0x000056bc75e2d631, + 0x0001b1ae4d6e2ef5, 0x000878678326eac9, 0x002a5a058fc295ed, + 0x00d3c21bcecceda1, 0x0422ca8b0a00a425, 0x14adf4b7320334b9}; + + static const int compression_ratio = 27; + + // Compute base index. + int cache_index = (k - float_info::min_k) / compression_ratio; + int kb = cache_index * compression_ratio + float_info::min_k; + int offset = k - kb; + + // Get base cache. + uint128_fallback base_cache = pow10_significands[cache_index]; + if (offset == 0) return base_cache; + + // Compute the required amount of bit-shift. + int alpha = floor_log2_pow10(kb + offset) - floor_log2_pow10(kb) - offset; + FMT_ASSERT(alpha > 0 && alpha < 64, "shifting error detected"); + + // Try to recover the real cache. 
+ uint64_t pow5 = powers_of_5_64[offset]; + uint128_fallback recovered_cache = umul128(base_cache.high(), pow5); + uint128_fallback middle_low = umul128(base_cache.low(), pow5); + + recovered_cache += middle_low.high(); + + uint64_t high_to_middle = recovered_cache.high() << (64 - alpha); + uint64_t middle_to_low = recovered_cache.low() << (64 - alpha); + + recovered_cache = + uint128_fallback{(recovered_cache.low() >> alpha) | high_to_middle, + ((middle_low.low() >> alpha) | middle_to_low)}; + FMT_ASSERT(recovered_cache.low() + 1 != 0, ""); + return {recovered_cache.high(), recovered_cache.low() + 1}; +#endif + } + + struct compute_mul_result { + carrier_uint result; + bool is_integer; + }; + struct compute_mul_parity_result { + bool parity; + bool is_integer; + }; + + static auto compute_mul(carrier_uint u, + const cache_entry_type& cache) noexcept + -> compute_mul_result { + auto r = umul192_upper128(u, cache); + return {r.high(), r.low() == 0}; + } + + static auto compute_delta(const cache_entry_type& cache, int beta) noexcept + -> uint32_t { + return static_cast(cache.high() >> (64 - 1 - beta)); + } + + static auto compute_mul_parity(carrier_uint two_f, + const cache_entry_type& cache, + int beta) noexcept + -> compute_mul_parity_result { + FMT_ASSERT(beta >= 1, ""); + FMT_ASSERT(beta < 64, ""); + + auto r = umul192_lower128(two_f, cache); + return {((r.high() >> (64 - beta)) & 1) != 0, + ((r.high() << beta) | (r.low() >> (64 - beta))) == 0}; + } + + static auto compute_left_endpoint_for_shorter_interval_case( + const cache_entry_type& cache, int beta) noexcept -> carrier_uint { + return (cache.high() - + (cache.high() >> (num_significand_bits() + 2))) >> + (64 - num_significand_bits() - 1 - beta); + } + + static auto compute_right_endpoint_for_shorter_interval_case( + const cache_entry_type& cache, int beta) noexcept -> carrier_uint { + return (cache.high() + + (cache.high() >> (num_significand_bits() + 1))) >> + (64 - num_significand_bits() - 1 - beta); + } 
+ + static auto compute_round_up_for_shorter_interval_case( + const cache_entry_type& cache, int beta) noexcept -> carrier_uint { + return ((cache.high() >> (64 - num_significand_bits() - 2 - beta)) + + 1) / + 2; + } +}; + +FMT_FUNC auto get_cached_power(int k) noexcept -> uint128_fallback { + return cache_accessor::get_cached_power(k); +} + +// Various integer checks +template +auto is_left_endpoint_integer_shorter_interval(int exponent) noexcept -> bool { + const int case_shorter_interval_left_endpoint_lower_threshold = 2; + const int case_shorter_interval_left_endpoint_upper_threshold = 3; + return exponent >= case_shorter_interval_left_endpoint_lower_threshold && + exponent <= case_shorter_interval_left_endpoint_upper_threshold; +} + +// Remove trailing zeros from n and return the number of zeros removed (float). +FMT_INLINE auto remove_trailing_zeros(uint32_t& n, int s = 0) noexcept -> int { + FMT_ASSERT(n != 0, ""); + // Modular inverse of 5 (mod 2^32): (mod_inv_5 * 5) mod 2^32 = 1. + constexpr uint32_t mod_inv_5 = 0xcccccccd; + constexpr uint32_t mod_inv_25 = 0xc28f5c29; // = mod_inv_5 * mod_inv_5 + + while (true) { + auto q = rotr(n * mod_inv_25, 2); + if (q > max_value() / 100) break; + n = q; + s += 2; + } + auto q = rotr(n * mod_inv_5, 1); + if (q <= max_value() / 10) { + n = q; + s |= 1; + } + return s; +} + +// Removes trailing zeros and returns the number of zeros removed (double). +FMT_INLINE auto remove_trailing_zeros(uint64_t& n) noexcept -> int { + FMT_ASSERT(n != 0, ""); + + // Is n is divisible by 10^8? + constexpr uint32_t ten_pow_8 = 100000000u; + if ((n % ten_pow_8) == 0) { + // If yes, work with the quotient... + auto n32 = static_cast(n / ten_pow_8); + // ... and use the 32 bit variant of the function + int num_zeros = remove_trailing_zeros(n32, 8); + n = n32; + return num_zeros; + } + + // If n is not divisible by 10^8, work with n itself. 
+ constexpr uint64_t mod_inv_5 = 0xcccccccccccccccd; + constexpr uint64_t mod_inv_25 = 0x8f5c28f5c28f5c29; // mod_inv_5 * mod_inv_5 + + int s = 0; + while (true) { + auto q = rotr(n * mod_inv_25, 2); + if (q > max_value() / 100) break; + n = q; + s += 2; + } + auto q = rotr(n * mod_inv_5, 1); + if (q <= max_value() / 10) { + n = q; + s |= 1; + } + + return s; +} + +// The main algorithm for shorter interval case +template +FMT_INLINE auto shorter_interval_case(int exponent) noexcept -> decimal_fp { + decimal_fp ret_value; + // Compute k and beta + const int minus_k = floor_log10_pow2_minus_log10_4_over_3(exponent); + const int beta = exponent + floor_log2_pow10(-minus_k); + + // Compute xi and zi + using cache_entry_type = typename cache_accessor::cache_entry_type; + const cache_entry_type cache = cache_accessor::get_cached_power(-minus_k); + + auto xi = cache_accessor::compute_left_endpoint_for_shorter_interval_case( + cache, beta); + auto zi = cache_accessor::compute_right_endpoint_for_shorter_interval_case( + cache, beta); + + // If the left endpoint is not an integer, increase it + if (!is_left_endpoint_integer_shorter_interval(exponent)) ++xi; + + // Try bigger divisor + ret_value.significand = zi / 10; + + // If succeed, remove trailing zeros if necessary and return + if (ret_value.significand * 10 >= xi) { + ret_value.exponent = minus_k + 1; + ret_value.exponent += remove_trailing_zeros(ret_value.significand); + return ret_value; + } + + // Otherwise, compute the round-up of y + ret_value.significand = + cache_accessor::compute_round_up_for_shorter_interval_case(cache, + beta); + ret_value.exponent = minus_k; + + // When tie occurs, choose one of them according to the rule + if (exponent >= float_info::shorter_interval_tie_lower_threshold && + exponent <= float_info::shorter_interval_tie_upper_threshold) { + ret_value.significand = ret_value.significand % 2 == 0 + ? 
ret_value.significand + : ret_value.significand - 1; + } else if (ret_value.significand < xi) { + ++ret_value.significand; + } + return ret_value; +} + +template auto to_decimal(T x) noexcept -> decimal_fp { + // Step 1: integer promotion & Schubfach multiplier calculation. + + using carrier_uint = typename float_info::carrier_uint; + using cache_entry_type = typename cache_accessor::cache_entry_type; + auto br = bit_cast(x); + + // Extract significand bits and exponent bits. + const carrier_uint significand_mask = + (static_cast(1) << num_significand_bits()) - 1; + carrier_uint significand = (br & significand_mask); + int exponent = + static_cast((br & exponent_mask()) >> num_significand_bits()); + + if (exponent != 0) { // Check if normal. + exponent -= exponent_bias() + num_significand_bits(); + + // Shorter interval case; proceed like Schubfach. + // In fact, when exponent == 1 and significand == 0, the interval is + // regular. However, it can be shown that the end-results are anyway same. + if (significand == 0) return shorter_interval_case(exponent); + + significand |= (static_cast(1) << num_significand_bits()); + } else { + // Subnormal case; the interval is always regular. + if (significand == 0) return {0, 0}; + exponent = + std::numeric_limits::min_exponent - num_significand_bits() - 1; + } + + const bool include_left_endpoint = (significand % 2 == 0); + const bool include_right_endpoint = include_left_endpoint; + + // Compute k and beta. + const int minus_k = floor_log10_pow2(exponent) - float_info::kappa; + const cache_entry_type cache = cache_accessor::get_cached_power(-minus_k); + const int beta = exponent + floor_log2_pow10(-minus_k); + + // Compute zi and deltai. 
+ // 10^kappa <= deltai < 10^(kappa + 1) + const uint32_t deltai = cache_accessor::compute_delta(cache, beta); + const carrier_uint two_fc = significand << 1; + + // For the case of binary32, the result of integer check is not correct for + // 29711844 * 2^-82 + // = 6.1442653300000000008655037797566933477355632930994033813476... * 10^-18 + // and 29711844 * 2^-81 + // = 1.2288530660000000001731007559513386695471126586198806762695... * 10^-17, + // and they are the unique counterexamples. However, since 29711844 is even, + // this does not cause any problem for the endpoints calculations; it can only + // cause a problem when we need to perform integer check for the center. + // Fortunately, with these inputs, that branch is never executed, so we are + // fine. + const typename cache_accessor::compute_mul_result z_mul = + cache_accessor::compute_mul((two_fc | 1) << beta, cache); + + // Step 2: Try larger divisor; remove trailing zeros if necessary. + + // Using an upper bound on zi, we might be able to optimize the division + // better than the compiler; we are computing zi / big_divisor here. + decimal_fp ret_value; + ret_value.significand = divide_by_10_to_kappa_plus_1(z_mul.result); + uint32_t r = static_cast(z_mul.result - float_info::big_divisor * + ret_value.significand); + + if (r < deltai) { + // Exclude the right endpoint if necessary. + if (r == 0 && (z_mul.is_integer & !include_right_endpoint)) { + --ret_value.significand; + r = float_info::big_divisor; + goto small_divisor_case_label; + } + } else if (r > deltai) { + goto small_divisor_case_label; + } else { + // r == deltai; compare fractional parts. + const typename cache_accessor::compute_mul_parity_result x_mul = + cache_accessor::compute_mul_parity(two_fc - 1, cache, beta); + + if (!(x_mul.parity | (x_mul.is_integer & include_left_endpoint))) + goto small_divisor_case_label; + } + ret_value.exponent = minus_k + float_info::kappa + 1; + + // We may need to remove trailing zeros. 
+ ret_value.exponent += remove_trailing_zeros(ret_value.significand); + return ret_value; + + // Step 3: Find the significand with the smaller divisor. + +small_divisor_case_label: + ret_value.significand *= 10; + ret_value.exponent = minus_k + float_info::kappa; + + uint32_t dist = r - (deltai / 2) + (float_info::small_divisor / 2); + const bool approx_y_parity = + ((dist ^ (float_info::small_divisor / 2)) & 1) != 0; + + // Is dist divisible by 10^kappa? + const bool divisible_by_small_divisor = + check_divisibility_and_divide_by_pow10::kappa>(dist); + + // Add dist / 10^kappa to the significand. + ret_value.significand += dist; + + if (!divisible_by_small_divisor) return ret_value; + + // Check z^(f) >= epsilon^(f). + // We have either yi == zi - epsiloni or yi == (zi - epsiloni) - 1, + // where yi == zi - epsiloni if and only if z^(f) >= epsilon^(f). + // Since there are only 2 possibilities, we only need to care about the + // parity. Also, zi and r should have the same parity since the divisor + // is an even number. + const auto y_mul = cache_accessor::compute_mul_parity(two_fc, cache, beta); + + // If z^(f) >= epsilon^(f), we might have a tie when z^(f) == epsilon^(f), + // or equivalently, when y is an integer. 
+ if (y_mul.parity != approx_y_parity) + --ret_value.significand; + else if (y_mul.is_integer & (ret_value.significand % 2 != 0)) + --ret_value.significand; + return ret_value; +} +} // namespace dragonbox +} // namespace detail + +template <> struct formatter { + FMT_CONSTEXPR auto parse(format_parse_context& ctx) + -> format_parse_context::iterator { + return ctx.begin(); + } + + auto format(const detail::bigint& n, format_context& ctx) const + -> format_context::iterator { + auto out = ctx.out(); + bool first = true; + for (auto i = n.bigits_.size(); i > 0; --i) { + auto value = n.bigits_[i - 1u]; + if (first) { + out = fmt::format_to(out, FMT_STRING("{:x}"), value); + first = false; + continue; + } + out = fmt::format_to(out, FMT_STRING("{:08x}"), value); + } + if (n.exp_ > 0) + out = fmt::format_to(out, FMT_STRING("p{}"), + n.exp_ * detail::bigint::bigit_bits); + return out; + } +}; + +FMT_FUNC detail::utf8_to_utf16::utf8_to_utf16(string_view s) { + for_each_codepoint(s, [this](uint32_t cp, string_view) { + if (cp == invalid_code_point) FMT_THROW(std::runtime_error("invalid utf8")); + if (cp <= 0xFFFF) { + buffer_.push_back(static_cast(cp)); + } else { + cp -= 0x10000; + buffer_.push_back(static_cast(0xD800 + (cp >> 10))); + buffer_.push_back(static_cast(0xDC00 + (cp & 0x3FF))); + } + return true; + }); + buffer_.push_back(0); +} + +FMT_FUNC void format_system_error(detail::buffer& out, int error_code, + const char* message) noexcept { + FMT_TRY { + auto ec = std::error_code(error_code, std::generic_category()); + detail::write(appender(out), std::system_error(ec, message).what()); + return; + } + FMT_CATCH(...) 
{} + format_error_code(out, error_code, message); +} + +FMT_FUNC void report_system_error(int error_code, + const char* message) noexcept { + do_report_error(format_system_error, error_code, message); +} + +FMT_FUNC auto vformat(string_view fmt, format_args args) -> std::string { + // Don't optimize the "{}" case to keep the binary size small and because it + // can be better optimized in fmt::format anyway. + auto buffer = memory_buffer(); + detail::vformat_to(buffer, fmt, args); + return to_string(buffer); +} + +namespace detail { + +FMT_FUNC void vformat_to(buffer& buf, string_view fmt, format_args args, + locale_ref loc) { + auto out = appender(buf); + if (fmt.size() == 2 && equal2(fmt.data(), "{}")) + return args.get(0).visit(default_arg_formatter{out}); + parse_format_string(fmt, + format_handler<>{parse_context<>(fmt), {out, args, loc}}); +} + +template struct span { + T* data; + size_t size; +}; + +template auto flockfile(F* f) -> decltype(_lock_file(f)) { + _lock_file(f); +} +template auto funlockfile(F* f) -> decltype(_unlock_file(f)) { + _unlock_file(f); +} + +#ifndef getc_unlocked +template auto getc_unlocked(F* f) -> decltype(_fgetc_nolock(f)) { + return _fgetc_nolock(f); +} +#endif + +template +struct has_flockfile : std::false_type {}; + +template +struct has_flockfile()))>> + : std::true_type {}; + +// A FILE wrapper. F is FILE defined as a template parameter to make system API +// detection work. +template class file_base { + public: + F* file_; + + public: + file_base(F* file) : file_(file) {} + operator F*() const { return file_; } + + // Reads a code unit from the stream. + auto get() -> int { + int result = getc_unlocked(file_); + if (result == EOF && ferror(file_) != 0) + FMT_THROW(system_error(errno, FMT_STRING("getc failed"))); + return result; + } + + // Puts the code unit back into the stream buffer. 
+ void unget(char c) { + if (ungetc(c, file_) == EOF) + FMT_THROW(system_error(errno, FMT_STRING("ungetc failed"))); + } + + void flush() { fflush(this->file_); } +}; + +// A FILE wrapper for glibc. +template class glibc_file : public file_base { + private: + enum { + line_buffered = 0x200, // _IO_LINE_BUF + unbuffered = 2 // _IO_UNBUFFERED + }; + + public: + using file_base::file_base; + + auto is_buffered() const -> bool { + return (this->file_->_flags & unbuffered) == 0; + } + + void init_buffer() { + if (this->file_->_IO_write_ptr < this->file_->_IO_write_end) return; + // Force buffer initialization by placing and removing a char in a buffer. + putc_unlocked(0, this->file_); + --this->file_->_IO_write_ptr; + } + + // Returns the file's read buffer. + auto get_read_buffer() const -> span { + auto ptr = this->file_->_IO_read_ptr; + return {ptr, to_unsigned(this->file_->_IO_read_end - ptr)}; + } + + // Returns the file's write buffer. + auto get_write_buffer() const -> span { + auto ptr = this->file_->_IO_write_ptr; + return {ptr, to_unsigned(this->file_->_IO_buf_end - ptr)}; + } + + void advance_write_buffer(size_t size) { this->file_->_IO_write_ptr += size; } + + auto needs_flush() const -> bool { + if ((this->file_->_flags & line_buffered) == 0) return false; + char* end = this->file_->_IO_write_end; + auto size = max_of(this->file_->_IO_write_ptr - end, 0); + return memchr(end, '\n', static_cast(size)); + } + + void flush() { fflush_unlocked(this->file_); } +}; + +// A FILE wrapper for Apple's libc. +template class apple_file : public file_base { + private: + enum { + line_buffered = 1, // __SNBF + unbuffered = 2 // __SLBF + }; + + public: + using file_base::file_base; + + auto is_buffered() const -> bool { + return (this->file_->_flags & unbuffered) == 0; + } + + void init_buffer() { + if (this->file_->_p) return; + // Force buffer initialization by placing and removing a char in a buffer. 
+ if (!FMT_CLANG_ANALYZER) putc_unlocked(0, this->file_); + --this->file_->_p; + ++this->file_->_w; + } + + auto get_read_buffer() const -> span { + return {reinterpret_cast(this->file_->_p), + to_unsigned(this->file_->_r)}; + } + + auto get_write_buffer() const -> span { + return {reinterpret_cast(this->file_->_p), + to_unsigned(this->file_->_bf._base + this->file_->_bf._size - + this->file_->_p)}; + } + + void advance_write_buffer(size_t size) { + this->file_->_p += size; + this->file_->_w -= size; + } + + auto needs_flush() const -> bool { + if ((this->file_->_flags & line_buffered) == 0) return false; + return memchr(this->file_->_p + this->file_->_w, '\n', + to_unsigned(-this->file_->_w)); + } +}; + +// A fallback FILE wrapper. +template class fallback_file : public file_base { + private: + char next_; // The next unconsumed character in the buffer. + bool has_next_ = false; + + public: + using file_base::file_base; + + auto is_buffered() const -> bool { return false; } + auto needs_flush() const -> bool { return false; } + void init_buffer() {} + + auto get_read_buffer() const -> span { + return {&next_, has_next_ ? 1u : 0u}; + } + + auto get_write_buffer() const -> span { return {nullptr, 0}; } + + void advance_write_buffer(size_t) {} + + auto get() -> int { + has_next_ = false; + return file_base::get(); + } + + void unget(char c) { + file_base::unget(c); + next_ = c; + has_next_ = true; + } +}; + +#ifndef FMT_USE_FALLBACK_FILE +# define FMT_USE_FALLBACK_FILE 0 +#endif + +template +auto get_file(F* f, int) -> apple_file { + return f; +} +template +inline auto get_file(F* f, int) -> glibc_file { + return f; +} + +inline auto get_file(FILE* f, ...) 
-> fallback_file { return f; } + +using file_ref = decltype(get_file(static_cast(nullptr), 0)); + +template +class file_print_buffer : public buffer { + public: + explicit file_print_buffer(F*) : buffer(nullptr, size_t()) {} +}; + +template +class file_print_buffer::value>> + : public buffer { + private: + file_ref file_; + + static void grow(buffer& base, size_t) { + auto& self = static_cast(base); + self.file_.advance_write_buffer(self.size()); + if (self.file_.get_write_buffer().size == 0) self.file_.flush(); + auto buf = self.file_.get_write_buffer(); + FMT_ASSERT(buf.size > 0, ""); + self.set(buf.data, buf.size); + self.clear(); + } + + public: + explicit file_print_buffer(F* f) : buffer(grow, size_t()), file_(f) { + flockfile(f); + file_.init_buffer(); + auto buf = file_.get_write_buffer(); + set(buf.data, buf.size); + } + ~file_print_buffer() { + file_.advance_write_buffer(size()); + bool flush = file_.needs_flush(); + F* f = file_; // Make funlockfile depend on the template parameter F + funlockfile(f); // for the system API detection to work. + if (flush) fflush(file_); + } +}; + +#if !defined(_WIN32) || defined(FMT_USE_WRITE_CONSOLE) +FMT_FUNC auto write_console(int, string_view) -> bool { return false; } +#else +using dword = conditional_t; +extern "C" __declspec(dllimport) int __stdcall WriteConsoleW( // + void*, const void*, dword, dword*, void*); + +FMT_FUNC bool write_console(int fd, string_view text) { + auto u16 = utf8_to_utf16(text); + return WriteConsoleW(reinterpret_cast(_get_osfhandle(fd)), u16.c_str(), + static_cast(u16.size()), nullptr, nullptr) != 0; +} +#endif + +#ifdef _WIN32 +// Print assuming legacy (non-Unicode) encoding. 
+FMT_FUNC void vprint_mojibake(std::FILE* f, string_view fmt, format_args args, + bool newline) { + auto buffer = memory_buffer(); + detail::vformat_to(buffer, fmt, args); + if (newline) buffer.push_back('\n'); + fwrite_all(buffer.data(), buffer.size(), f); +} +#endif + +FMT_FUNC void print(std::FILE* f, string_view text) { +#if defined(_WIN32) && !defined(FMT_USE_WRITE_CONSOLE) + int fd = _fileno(f); + if (_isatty(fd)) { + std::fflush(f); + if (write_console(fd, text)) return; + } +#endif + fwrite_all(text.data(), text.size(), f); +} +} // namespace detail + +FMT_FUNC void vprint_buffered(std::FILE* f, string_view fmt, format_args args) { + auto buffer = memory_buffer(); + detail::vformat_to(buffer, fmt, args); + detail::print(f, {buffer.data(), buffer.size()}); +} + +FMT_FUNC void vprint(std::FILE* f, string_view fmt, format_args args) { + if (!detail::file_ref(f).is_buffered() || !detail::has_flockfile<>()) + return vprint_buffered(f, fmt, args); + auto&& buffer = detail::file_print_buffer<>(f); + return detail::vformat_to(buffer, fmt, args); +} + +FMT_FUNC void vprintln(std::FILE* f, string_view fmt, format_args args) { + auto buffer = memory_buffer(); + detail::vformat_to(buffer, fmt, args); + buffer.push_back('\n'); + detail::print(f, {buffer.data(), buffer.size()}); +} + +FMT_FUNC void vprint(string_view fmt, format_args args) { + vprint(stdout, fmt, args); +} + +namespace detail { + +struct singleton { + unsigned char upper; + unsigned char lower_count; +}; + +inline auto is_printable(uint16_t x, const singleton* singletons, + size_t singletons_size, + const unsigned char* singleton_lowers, + const unsigned char* normal, size_t normal_size) + -> bool { + auto upper = x >> 8; + auto lower_start = 0; + for (size_t i = 0; i < singletons_size; ++i) { + auto s = singletons[i]; + auto lower_end = lower_start + s.lower_count; + if (upper < s.upper) break; + if (upper == s.upper) { + for (auto j = lower_start; j < lower_end; ++j) { + if (singleton_lowers[j] == (x & 
0xff)) return false; + } + } + lower_start = lower_end; + } + + auto xsigned = static_cast(x); + auto current = true; + for (size_t i = 0; i < normal_size; ++i) { + auto v = static_cast(normal[i]); + auto len = (v & 0x80) != 0 ? (v & 0x7f) << 8 | normal[++i] : v; + xsigned -= len; + if (xsigned < 0) break; + current = !current; + } + return current; +} + +// This code is generated by support/printable.py. +FMT_FUNC auto is_printable(uint32_t cp) -> bool { + static constexpr singleton singletons0[] = { + {0x00, 1}, {0x03, 5}, {0x05, 6}, {0x06, 3}, {0x07, 6}, {0x08, 8}, + {0x09, 17}, {0x0a, 28}, {0x0b, 25}, {0x0c, 20}, {0x0d, 16}, {0x0e, 13}, + {0x0f, 4}, {0x10, 3}, {0x12, 18}, {0x13, 9}, {0x16, 1}, {0x17, 5}, + {0x18, 2}, {0x19, 3}, {0x1a, 7}, {0x1c, 2}, {0x1d, 1}, {0x1f, 22}, + {0x20, 3}, {0x2b, 3}, {0x2c, 2}, {0x2d, 11}, {0x2e, 1}, {0x30, 3}, + {0x31, 2}, {0x32, 1}, {0xa7, 2}, {0xa9, 2}, {0xaa, 4}, {0xab, 8}, + {0xfa, 2}, {0xfb, 5}, {0xfd, 4}, {0xfe, 3}, {0xff, 9}, + }; + static constexpr unsigned char singletons0_lower[] = { + 0xad, 0x78, 0x79, 0x8b, 0x8d, 0xa2, 0x30, 0x57, 0x58, 0x8b, 0x8c, 0x90, + 0x1c, 0x1d, 0xdd, 0x0e, 0x0f, 0x4b, 0x4c, 0xfb, 0xfc, 0x2e, 0x2f, 0x3f, + 0x5c, 0x5d, 0x5f, 0xb5, 0xe2, 0x84, 0x8d, 0x8e, 0x91, 0x92, 0xa9, 0xb1, + 0xba, 0xbb, 0xc5, 0xc6, 0xc9, 0xca, 0xde, 0xe4, 0xe5, 0xff, 0x00, 0x04, + 0x11, 0x12, 0x29, 0x31, 0x34, 0x37, 0x3a, 0x3b, 0x3d, 0x49, 0x4a, 0x5d, + 0x84, 0x8e, 0x92, 0xa9, 0xb1, 0xb4, 0xba, 0xbb, 0xc6, 0xca, 0xce, 0xcf, + 0xe4, 0xe5, 0x00, 0x04, 0x0d, 0x0e, 0x11, 0x12, 0x29, 0x31, 0x34, 0x3a, + 0x3b, 0x45, 0x46, 0x49, 0x4a, 0x5e, 0x64, 0x65, 0x84, 0x91, 0x9b, 0x9d, + 0xc9, 0xce, 0xcf, 0x0d, 0x11, 0x29, 0x45, 0x49, 0x57, 0x64, 0x65, 0x8d, + 0x91, 0xa9, 0xb4, 0xba, 0xbb, 0xc5, 0xc9, 0xdf, 0xe4, 0xe5, 0xf0, 0x0d, + 0x11, 0x45, 0x49, 0x64, 0x65, 0x80, 0x84, 0xb2, 0xbc, 0xbe, 0xbf, 0xd5, + 0xd7, 0xf0, 0xf1, 0x83, 0x85, 0x8b, 0xa4, 0xa6, 0xbe, 0xbf, 0xc5, 0xc7, + 0xce, 0xcf, 0xda, 0xdb, 0x48, 0x98, 0xbd, 0xcd, 0xc6, 0xce, 0xcf, 
0x49, + 0x4e, 0x4f, 0x57, 0x59, 0x5e, 0x5f, 0x89, 0x8e, 0x8f, 0xb1, 0xb6, 0xb7, + 0xbf, 0xc1, 0xc6, 0xc7, 0xd7, 0x11, 0x16, 0x17, 0x5b, 0x5c, 0xf6, 0xf7, + 0xfe, 0xff, 0x80, 0x0d, 0x6d, 0x71, 0xde, 0xdf, 0x0e, 0x0f, 0x1f, 0x6e, + 0x6f, 0x1c, 0x1d, 0x5f, 0x7d, 0x7e, 0xae, 0xaf, 0xbb, 0xbc, 0xfa, 0x16, + 0x17, 0x1e, 0x1f, 0x46, 0x47, 0x4e, 0x4f, 0x58, 0x5a, 0x5c, 0x5e, 0x7e, + 0x7f, 0xb5, 0xc5, 0xd4, 0xd5, 0xdc, 0xf0, 0xf1, 0xf5, 0x72, 0x73, 0x8f, + 0x74, 0x75, 0x96, 0x2f, 0x5f, 0x26, 0x2e, 0x2f, 0xa7, 0xaf, 0xb7, 0xbf, + 0xc7, 0xcf, 0xd7, 0xdf, 0x9a, 0x40, 0x97, 0x98, 0x30, 0x8f, 0x1f, 0xc0, + 0xc1, 0xce, 0xff, 0x4e, 0x4f, 0x5a, 0x5b, 0x07, 0x08, 0x0f, 0x10, 0x27, + 0x2f, 0xee, 0xef, 0x6e, 0x6f, 0x37, 0x3d, 0x3f, 0x42, 0x45, 0x90, 0x91, + 0xfe, 0xff, 0x53, 0x67, 0x75, 0xc8, 0xc9, 0xd0, 0xd1, 0xd8, 0xd9, 0xe7, + 0xfe, 0xff, + }; + static constexpr singleton singletons1[] = { + {0x00, 6}, {0x01, 1}, {0x03, 1}, {0x04, 2}, {0x08, 8}, {0x09, 2}, + {0x0a, 5}, {0x0b, 2}, {0x0e, 4}, {0x10, 1}, {0x11, 2}, {0x12, 5}, + {0x13, 17}, {0x14, 1}, {0x15, 2}, {0x17, 2}, {0x19, 13}, {0x1c, 5}, + {0x1d, 8}, {0x24, 1}, {0x6a, 3}, {0x6b, 2}, {0xbc, 2}, {0xd1, 2}, + {0xd4, 12}, {0xd5, 9}, {0xd6, 2}, {0xd7, 2}, {0xda, 1}, {0xe0, 5}, + {0xe1, 2}, {0xe8, 2}, {0xee, 32}, {0xf0, 4}, {0xf8, 2}, {0xf9, 2}, + {0xfa, 2}, {0xfb, 1}, + }; + static constexpr unsigned char singletons1_lower[] = { + 0x0c, 0x27, 0x3b, 0x3e, 0x4e, 0x4f, 0x8f, 0x9e, 0x9e, 0x9f, 0x06, 0x07, + 0x09, 0x36, 0x3d, 0x3e, 0x56, 0xf3, 0xd0, 0xd1, 0x04, 0x14, 0x18, 0x36, + 0x37, 0x56, 0x57, 0x7f, 0xaa, 0xae, 0xaf, 0xbd, 0x35, 0xe0, 0x12, 0x87, + 0x89, 0x8e, 0x9e, 0x04, 0x0d, 0x0e, 0x11, 0x12, 0x29, 0x31, 0x34, 0x3a, + 0x45, 0x46, 0x49, 0x4a, 0x4e, 0x4f, 0x64, 0x65, 0x5c, 0xb6, 0xb7, 0x1b, + 0x1c, 0x07, 0x08, 0x0a, 0x0b, 0x14, 0x17, 0x36, 0x39, 0x3a, 0xa8, 0xa9, + 0xd8, 0xd9, 0x09, 0x37, 0x90, 0x91, 0xa8, 0x07, 0x0a, 0x3b, 0x3e, 0x66, + 0x69, 0x8f, 0x92, 0x6f, 0x5f, 0xee, 0xef, 0x5a, 0x62, 0x9a, 0x9b, 0x27, + 0x28, 0x55, 0x9d, 
0xa0, 0xa1, 0xa3, 0xa4, 0xa7, 0xa8, 0xad, 0xba, 0xbc, + 0xc4, 0x06, 0x0b, 0x0c, 0x15, 0x1d, 0x3a, 0x3f, 0x45, 0x51, 0xa6, 0xa7, + 0xcc, 0xcd, 0xa0, 0x07, 0x19, 0x1a, 0x22, 0x25, 0x3e, 0x3f, 0xc5, 0xc6, + 0x04, 0x20, 0x23, 0x25, 0x26, 0x28, 0x33, 0x38, 0x3a, 0x48, 0x4a, 0x4c, + 0x50, 0x53, 0x55, 0x56, 0x58, 0x5a, 0x5c, 0x5e, 0x60, 0x63, 0x65, 0x66, + 0x6b, 0x73, 0x78, 0x7d, 0x7f, 0x8a, 0xa4, 0xaa, 0xaf, 0xb0, 0xc0, 0xd0, + 0xae, 0xaf, 0x79, 0xcc, 0x6e, 0x6f, 0x93, + }; + static constexpr unsigned char normal0[] = { + 0x00, 0x20, 0x5f, 0x22, 0x82, 0xdf, 0x04, 0x82, 0x44, 0x08, 0x1b, 0x04, + 0x06, 0x11, 0x81, 0xac, 0x0e, 0x80, 0xab, 0x35, 0x28, 0x0b, 0x80, 0xe0, + 0x03, 0x19, 0x08, 0x01, 0x04, 0x2f, 0x04, 0x34, 0x04, 0x07, 0x03, 0x01, + 0x07, 0x06, 0x07, 0x11, 0x0a, 0x50, 0x0f, 0x12, 0x07, 0x55, 0x07, 0x03, + 0x04, 0x1c, 0x0a, 0x09, 0x03, 0x08, 0x03, 0x07, 0x03, 0x02, 0x03, 0x03, + 0x03, 0x0c, 0x04, 0x05, 0x03, 0x0b, 0x06, 0x01, 0x0e, 0x15, 0x05, 0x3a, + 0x03, 0x11, 0x07, 0x06, 0x05, 0x10, 0x07, 0x57, 0x07, 0x02, 0x07, 0x15, + 0x0d, 0x50, 0x04, 0x43, 0x03, 0x2d, 0x03, 0x01, 0x04, 0x11, 0x06, 0x0f, + 0x0c, 0x3a, 0x04, 0x1d, 0x25, 0x5f, 0x20, 0x6d, 0x04, 0x6a, 0x25, 0x80, + 0xc8, 0x05, 0x82, 0xb0, 0x03, 0x1a, 0x06, 0x82, 0xfd, 0x03, 0x59, 0x07, + 0x15, 0x0b, 0x17, 0x09, 0x14, 0x0c, 0x14, 0x0c, 0x6a, 0x06, 0x0a, 0x06, + 0x1a, 0x06, 0x59, 0x07, 0x2b, 0x05, 0x46, 0x0a, 0x2c, 0x04, 0x0c, 0x04, + 0x01, 0x03, 0x31, 0x0b, 0x2c, 0x04, 0x1a, 0x06, 0x0b, 0x03, 0x80, 0xac, + 0x06, 0x0a, 0x06, 0x21, 0x3f, 0x4c, 0x04, 0x2d, 0x03, 0x74, 0x08, 0x3c, + 0x03, 0x0f, 0x03, 0x3c, 0x07, 0x38, 0x08, 0x2b, 0x05, 0x82, 0xff, 0x11, + 0x18, 0x08, 0x2f, 0x11, 0x2d, 0x03, 0x20, 0x10, 0x21, 0x0f, 0x80, 0x8c, + 0x04, 0x82, 0x97, 0x19, 0x0b, 0x15, 0x88, 0x94, 0x05, 0x2f, 0x05, 0x3b, + 0x07, 0x02, 0x0e, 0x18, 0x09, 0x80, 0xb3, 0x2d, 0x74, 0x0c, 0x80, 0xd6, + 0x1a, 0x0c, 0x05, 0x80, 0xff, 0x05, 0x80, 0xdf, 0x0c, 0xee, 0x0d, 0x03, + 0x84, 0x8d, 0x03, 0x37, 0x09, 0x81, 0x5c, 0x14, 0x80, 0xb8, 0x08, 0x80, 
+ 0xcb, 0x2a, 0x38, 0x03, 0x0a, 0x06, 0x38, 0x08, 0x46, 0x08, 0x0c, 0x06, + 0x74, 0x0b, 0x1e, 0x03, 0x5a, 0x04, 0x59, 0x09, 0x80, 0x83, 0x18, 0x1c, + 0x0a, 0x16, 0x09, 0x4c, 0x04, 0x80, 0x8a, 0x06, 0xab, 0xa4, 0x0c, 0x17, + 0x04, 0x31, 0xa1, 0x04, 0x81, 0xda, 0x26, 0x07, 0x0c, 0x05, 0x05, 0x80, + 0xa5, 0x11, 0x81, 0x6d, 0x10, 0x78, 0x28, 0x2a, 0x06, 0x4c, 0x04, 0x80, + 0x8d, 0x04, 0x80, 0xbe, 0x03, 0x1b, 0x03, 0x0f, 0x0d, + }; + static constexpr unsigned char normal1[] = { + 0x5e, 0x22, 0x7b, 0x05, 0x03, 0x04, 0x2d, 0x03, 0x66, 0x03, 0x01, 0x2f, + 0x2e, 0x80, 0x82, 0x1d, 0x03, 0x31, 0x0f, 0x1c, 0x04, 0x24, 0x09, 0x1e, + 0x05, 0x2b, 0x05, 0x44, 0x04, 0x0e, 0x2a, 0x80, 0xaa, 0x06, 0x24, 0x04, + 0x24, 0x04, 0x28, 0x08, 0x34, 0x0b, 0x01, 0x80, 0x90, 0x81, 0x37, 0x09, + 0x16, 0x0a, 0x08, 0x80, 0x98, 0x39, 0x03, 0x63, 0x08, 0x09, 0x30, 0x16, + 0x05, 0x21, 0x03, 0x1b, 0x05, 0x01, 0x40, 0x38, 0x04, 0x4b, 0x05, 0x2f, + 0x04, 0x0a, 0x07, 0x09, 0x07, 0x40, 0x20, 0x27, 0x04, 0x0c, 0x09, 0x36, + 0x03, 0x3a, 0x05, 0x1a, 0x07, 0x04, 0x0c, 0x07, 0x50, 0x49, 0x37, 0x33, + 0x0d, 0x33, 0x07, 0x2e, 0x08, 0x0a, 0x81, 0x26, 0x52, 0x4e, 0x28, 0x08, + 0x2a, 0x56, 0x1c, 0x14, 0x17, 0x09, 0x4e, 0x04, 0x1e, 0x0f, 0x43, 0x0e, + 0x19, 0x07, 0x0a, 0x06, 0x48, 0x08, 0x27, 0x09, 0x75, 0x0b, 0x3f, 0x41, + 0x2a, 0x06, 0x3b, 0x05, 0x0a, 0x06, 0x51, 0x06, 0x01, 0x05, 0x10, 0x03, + 0x05, 0x80, 0x8b, 0x62, 0x1e, 0x48, 0x08, 0x0a, 0x80, 0xa6, 0x5e, 0x22, + 0x45, 0x0b, 0x0a, 0x06, 0x0d, 0x13, 0x39, 0x07, 0x0a, 0x36, 0x2c, 0x04, + 0x10, 0x80, 0xc0, 0x3c, 0x64, 0x53, 0x0c, 0x48, 0x09, 0x0a, 0x46, 0x45, + 0x1b, 0x48, 0x08, 0x53, 0x1d, 0x39, 0x81, 0x07, 0x46, 0x0a, 0x1d, 0x03, + 0x47, 0x49, 0x37, 0x03, 0x0e, 0x08, 0x0a, 0x06, 0x39, 0x07, 0x0a, 0x81, + 0x36, 0x19, 0x80, 0xb7, 0x01, 0x0f, 0x32, 0x0d, 0x83, 0x9b, 0x66, 0x75, + 0x0b, 0x80, 0xc4, 0x8a, 0xbc, 0x84, 0x2f, 0x8f, 0xd1, 0x82, 0x47, 0xa1, + 0xb9, 0x82, 0x39, 0x07, 0x2a, 0x04, 0x02, 0x60, 0x26, 0x0a, 0x46, 0x0a, + 0x28, 0x05, 0x13, 0x82, 0xb0, 0x5b, 
0x65, 0x4b, 0x04, 0x39, 0x07, 0x11, + 0x40, 0x05, 0x0b, 0x02, 0x0e, 0x97, 0xf8, 0x08, 0x84, 0xd6, 0x2a, 0x09, + 0xa2, 0xf7, 0x81, 0x1f, 0x31, 0x03, 0x11, 0x04, 0x08, 0x81, 0x8c, 0x89, + 0x04, 0x6b, 0x05, 0x0d, 0x03, 0x09, 0x07, 0x10, 0x93, 0x60, 0x80, 0xf6, + 0x0a, 0x73, 0x08, 0x6e, 0x17, 0x46, 0x80, 0x9a, 0x14, 0x0c, 0x57, 0x09, + 0x19, 0x80, 0x87, 0x81, 0x47, 0x03, 0x85, 0x42, 0x0f, 0x15, 0x85, 0x50, + 0x2b, 0x80, 0xd5, 0x2d, 0x03, 0x1a, 0x04, 0x02, 0x81, 0x70, 0x3a, 0x05, + 0x01, 0x85, 0x00, 0x80, 0xd7, 0x29, 0x4c, 0x04, 0x0a, 0x04, 0x02, 0x83, + 0x11, 0x44, 0x4c, 0x3d, 0x80, 0xc2, 0x3c, 0x06, 0x01, 0x04, 0x55, 0x05, + 0x1b, 0x34, 0x02, 0x81, 0x0e, 0x2c, 0x04, 0x64, 0x0c, 0x56, 0x0a, 0x80, + 0xae, 0x38, 0x1d, 0x0d, 0x2c, 0x04, 0x09, 0x07, 0x02, 0x0e, 0x06, 0x80, + 0x9a, 0x83, 0xd8, 0x08, 0x0d, 0x03, 0x0d, 0x03, 0x74, 0x0c, 0x59, 0x07, + 0x0c, 0x14, 0x0c, 0x04, 0x38, 0x08, 0x0a, 0x06, 0x28, 0x08, 0x22, 0x4e, + 0x81, 0x54, 0x0c, 0x15, 0x03, 0x03, 0x05, 0x07, 0x09, 0x19, 0x07, 0x07, + 0x09, 0x03, 0x0d, 0x07, 0x29, 0x80, 0xcb, 0x25, 0x0a, 0x84, 0x06, + }; + auto lower = static_cast(cp); + if (cp < 0x10000) { + return is_printable(lower, singletons0, + sizeof(singletons0) / sizeof(*singletons0), + singletons0_lower, normal0, sizeof(normal0)); + } + if (cp < 0x20000) { + return is_printable(lower, singletons1, + sizeof(singletons1) / sizeof(*singletons1), + singletons1_lower, normal1, sizeof(normal1)); + } + if (0x2a6de <= cp && cp < 0x2a700) return false; + if (0x2b735 <= cp && cp < 0x2b740) return false; + if (0x2b81e <= cp && cp < 0x2b820) return false; + if (0x2cea2 <= cp && cp < 0x2ceb0) return false; + if (0x2ebe1 <= cp && cp < 0x2f800) return false; + if (0x2fa1e <= cp && cp < 0x30000) return false; + if (0x3134b <= cp && cp < 0xe0100) return false; + if (0xe01f0 <= cp && cp < 0x110000) return false; + return cp < 0x110000; +} + +} // namespace detail + +FMT_END_NAMESPACE + +#endif // FMT_FORMAT_INL_H_ + +#else +#error "This file should not be included when 
either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fmt/format.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fmt/format.h new file mode 100644 index 0000000000000000000000000000000000000000..a16acbf64ad9b35b0b008127d7a80a794da7ac0f --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fmt/format.h @@ -0,0 +1,4400 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +/* + Formatting library for C++ + + Copyright (c) 2012 - present, Victor Zverovich + + Permission is hereby granted, free of charge, to any person obtaining + a copy of this software and associated documentation files (the + "Software"), to deal in the Software without restriction, including + without limitation the rights to use, copy, modify, merge, publish, + distribute, sublicense, and/or sell copies of the Software, and to + permit persons to whom the Software is furnished to do so, subject to + the following conditions: + + The above copyright notice and this permission notice shall be + included in all copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE + LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION + OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION + WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
+ + --- Optional exception to the license --- + + As an exception, if, as a result of your compiling your source code, portions + of this Software are embedded into a machine-executable object form of such + source code, you may redistribute such embedded portions in such object form + without including the above copyright and permission notices. + */ + +#ifndef FMT_FORMAT_H_ +#define FMT_FORMAT_H_ + +#ifndef _LIBCPP_REMOVE_TRANSITIVE_INCLUDES +# define _LIBCPP_REMOVE_TRANSITIVE_INCLUDES +# define FMT_REMOVE_TRANSITIVE_INCLUDES +#endif + +#include "base.h" + +// libc++ supports string_view in pre-c++17. +#if FMT_HAS_INCLUDE() && \ + (FMT_CPLUSPLUS >= 201703L || defined(_LIBCPP_VERSION)) +# define FMT_USE_STRING_VIEW +#endif + +#ifndef FMT_MODULE +# include // malloc, free + +# include // std::signbit +# include // std::byte +# include // uint32_t +# include // std::memcpy +# include // std::numeric_limits +# include // std::bad_alloc +# if defined(__GLIBCXX__) && !defined(_GLIBCXX_USE_DUAL_ABI) +// Workaround for pre gcc 5 libstdc++. +# include // std::allocator_traits +# endif +# include // std::runtime_error +# include // std::string +# include // std::system_error + +// Check FMT_CPLUSPLUS to avoid a warning in MSVC. +# if FMT_HAS_INCLUDE() && FMT_CPLUSPLUS > 201703L +# include // std::bit_cast +# endif + +# if defined(FMT_USE_STRING_VIEW) +# include +# endif + +# if FMT_MSC_VERSION +# include // _BitScanReverse[64], _umul128 +# endif +#endif // FMT_MODULE + +#if defined(FMT_USE_NONTYPE_TEMPLATE_ARGS) +// Use the provided definition. 
+#elif defined(__NVCOMPILER) +# define FMT_USE_NONTYPE_TEMPLATE_ARGS 0 +#elif FMT_GCC_VERSION >= 903 && FMT_CPLUSPLUS >= 201709L +# define FMT_USE_NONTYPE_TEMPLATE_ARGS 1 +#elif defined(__cpp_nontype_template_args) && \ + __cpp_nontype_template_args >= 201911L +# define FMT_USE_NONTYPE_TEMPLATE_ARGS 1 +#elif FMT_CLANG_VERSION >= 1200 && FMT_CPLUSPLUS >= 202002L +# define FMT_USE_NONTYPE_TEMPLATE_ARGS 1 +#else +# define FMT_USE_NONTYPE_TEMPLATE_ARGS 0 +#endif + +#if defined __cpp_inline_variables && __cpp_inline_variables >= 201606L +# define FMT_INLINE_VARIABLE inline +#else +# define FMT_INLINE_VARIABLE +#endif + +// Check if RTTI is disabled. +#ifdef FMT_USE_RTTI +// Use the provided definition. +#elif defined(__GXX_RTTI) || FMT_HAS_FEATURE(cxx_rtti) || defined(_CPPRTTI) || \ + defined(__INTEL_RTTI__) || defined(__RTTI) +// __RTTI is for EDG compilers. _CPPRTTI is for MSVC. +# define FMT_USE_RTTI 1 +#else +# define FMT_USE_RTTI 0 +#endif + +// Visibility when compiled as a shared library/object. +#if defined(FMT_LIB_EXPORT) || defined(FMT_SHARED) +# define FMT_SO_VISIBILITY(value) FMT_VISIBILITY(value) +#else +# define FMT_SO_VISIBILITY(value) +#endif + +#if FMT_GCC_VERSION || FMT_CLANG_VERSION +# define FMT_NOINLINE __attribute__((noinline)) +#else +# define FMT_NOINLINE +#endif + +#ifdef FMT_DEPRECATED +// Use the provided definition. +#elif FMT_HAS_CPP14_ATTRIBUTE(deprecated) +# define FMT_DEPRECATED [[deprecated]] +#else +# define FMT_DEPRECATED /* deprecated */ +#endif + +// Detect constexpr std::string. 
+#if !FMT_USE_CONSTEVAL +# define FMT_USE_CONSTEXPR_STRING 0 +#elif defined(__cpp_lib_constexpr_string) && \ + __cpp_lib_constexpr_string >= 201907L +# if FMT_CLANG_VERSION && FMT_GLIBCXX_RELEASE +// clang + libstdc++ are able to work only starting with gcc13.3 +// https://gcc.gnu.org/bugzilla/show_bug.cgi?id=113294 +# if FMT_GLIBCXX_RELEASE < 13 +# define FMT_USE_CONSTEXPR_STRING 0 +# elif FMT_GLIBCXX_RELEASE == 13 && __GLIBCXX__ < 20240521 +# define FMT_USE_CONSTEXPR_STRING 0 +# else +# define FMT_USE_CONSTEXPR_STRING 1 +# endif +# else +# define FMT_USE_CONSTEXPR_STRING 1 +# endif +#else +# define FMT_USE_CONSTEXPR_STRING 0 +#endif +#if FMT_USE_CONSTEXPR_STRING +# define FMT_CONSTEXPR_STRING constexpr +#else +# define FMT_CONSTEXPR_STRING +#endif + +// GCC 4.9 doesn't support qualified names in specializations. +namespace std { +template struct iterator_traits> { + using iterator_category = output_iterator_tag; + using value_type = T; + using difference_type = + decltype(static_cast(nullptr) - static_cast(nullptr)); + using pointer = void; + using reference = void; +}; +} // namespace std + +#ifdef FMT_THROW +// Use the provided definition. +#elif FMT_USE_EXCEPTIONS +# define FMT_THROW(x) throw x +#else +# define FMT_THROW(x) ::fmt::assert_fail(__FILE__, __LINE__, (x).what()) +#endif + +#ifdef __clang_analyzer__ +# define FMT_CLANG_ANALYZER 1 +#else +# define FMT_CLANG_ANALYZER 0 +#endif + +// Defining FMT_REDUCE_INT_INSTANTIATIONS to 1, will reduce the number of +// integer formatter template instantiations to just one by only using the +// largest integer type. This results in a reduction in binary size but will +// cause a decrease in integer formatting performance. 
+#if !defined(FMT_REDUCE_INT_INSTANTIATIONS) +# define FMT_REDUCE_INT_INSTANTIATIONS 0 +#endif + +FMT_BEGIN_NAMESPACE + +template +struct is_contiguous> + : std::true_type {}; + +namespace detail { + +// __builtin_clz is broken in clang with Microsoft codegen: +// https://github.com/fmtlib/fmt/issues/519. +#if !FMT_MSC_VERSION +# if FMT_HAS_BUILTIN(__builtin_clz) || FMT_GCC_VERSION || FMT_ICC_VERSION +# define FMT_BUILTIN_CLZ(n) __builtin_clz(n) +# endif +# if FMT_HAS_BUILTIN(__builtin_clzll) || FMT_GCC_VERSION || FMT_ICC_VERSION +# define FMT_BUILTIN_CLZLL(n) __builtin_clzll(n) +# endif +#endif + +// Some compilers masquerade as both MSVC and GCC but otherwise support +// __builtin_clz and __builtin_clzll, so only define FMT_BUILTIN_CLZ using the +// MSVC intrinsics if the clz and clzll builtins are not available. +#if FMT_MSC_VERSION && !defined(FMT_BUILTIN_CLZLL) +// Avoid Clang with Microsoft CodeGen's -Wunknown-pragmas warning. +# ifndef __clang__ +# pragma intrinsic(_BitScanReverse) +# ifdef _WIN64 +# pragma intrinsic(_BitScanReverse64) +# endif +# endif + +inline auto clz(uint32_t x) -> int { + FMT_ASSERT(x != 0, ""); + FMT_MSC_WARNING(suppress : 6102) // Suppress a bogus static analysis warning. + unsigned long r = 0; + _BitScanReverse(&r, x); + return 31 ^ static_cast(r); +} +# define FMT_BUILTIN_CLZ(n) detail::clz(n) + +inline auto clzll(uint64_t x) -> int { + FMT_ASSERT(x != 0, ""); + FMT_MSC_WARNING(suppress : 6102) // Suppress a bogus static analysis warning. + unsigned long r = 0; +# ifdef _WIN64 + _BitScanReverse64(&r, x); +# else + // Scan the high 32 bits. + if (_BitScanReverse(&r, static_cast(x >> 32))) + return 63 ^ static_cast(r + 32); + // Scan the low 32 bits. 
+ _BitScanReverse(&r, static_cast(x)); +# endif + return 63 ^ static_cast(r); +} +# define FMT_BUILTIN_CLZLL(n) detail::clzll(n) +#endif // FMT_MSC_VERSION && !defined(FMT_BUILTIN_CLZLL) + +FMT_CONSTEXPR inline void abort_fuzzing_if(bool condition) { + ignore_unused(condition); +#ifdef FMT_FUZZ + if (condition) throw std::runtime_error("fuzzing limit reached"); +#endif +} + +#if defined(FMT_USE_STRING_VIEW) +template using std_string_view = std::basic_string_view; +#else +template struct std_string_view { + operator basic_string_view() const; +}; +#endif + +template struct string_literal { + static constexpr Char value[sizeof...(C)] = {C...}; + constexpr operator basic_string_view() const { + return {value, sizeof...(C)}; + } +}; +#if FMT_CPLUSPLUS < 201703L +template +constexpr Char string_literal::value[sizeof...(C)]; +#endif + +// Implementation of std::bit_cast for pre-C++20. +template +FMT_CONSTEXPR20 auto bit_cast(const From& from) -> To { +#ifdef __cpp_lib_bit_cast + if (is_constant_evaluated()) return std::bit_cast(from); +#endif + auto to = To(); + // The cast suppresses a bogus -Wclass-memaccess on GCC. 
+ std::memcpy(static_cast(&to), &from, sizeof(to)); + return to; +} + +inline auto is_big_endian() -> bool { +#ifdef _WIN32 + return false; +#elif defined(__BIG_ENDIAN__) + return true; +#elif defined(__BYTE_ORDER__) && defined(__ORDER_BIG_ENDIAN__) + return __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__; +#else + struct bytes { + char data[sizeof(int)]; + }; + return bit_cast(1).data[0] == 0; +#endif +} + +class uint128_fallback { + private: + uint64_t lo_, hi_; + + public: + constexpr uint128_fallback(uint64_t hi, uint64_t lo) : lo_(lo), hi_(hi) {} + constexpr uint128_fallback(uint64_t value = 0) : lo_(value), hi_(0) {} + + constexpr auto high() const noexcept -> uint64_t { return hi_; } + constexpr auto low() const noexcept -> uint64_t { return lo_; } + + template ::value)> + constexpr explicit operator T() const { + return static_cast(lo_); + } + + friend constexpr auto operator==(const uint128_fallback& lhs, + const uint128_fallback& rhs) -> bool { + return lhs.hi_ == rhs.hi_ && lhs.lo_ == rhs.lo_; + } + friend constexpr auto operator!=(const uint128_fallback& lhs, + const uint128_fallback& rhs) -> bool { + return !(lhs == rhs); + } + friend constexpr auto operator>(const uint128_fallback& lhs, + const uint128_fallback& rhs) -> bool { + return lhs.hi_ != rhs.hi_ ? 
lhs.hi_ > rhs.hi_ : lhs.lo_ > rhs.lo_; + } + friend constexpr auto operator|(const uint128_fallback& lhs, + const uint128_fallback& rhs) + -> uint128_fallback { + return {lhs.hi_ | rhs.hi_, lhs.lo_ | rhs.lo_}; + } + friend constexpr auto operator&(const uint128_fallback& lhs, + const uint128_fallback& rhs) + -> uint128_fallback { + return {lhs.hi_ & rhs.hi_, lhs.lo_ & rhs.lo_}; + } + friend constexpr auto operator~(const uint128_fallback& n) + -> uint128_fallback { + return {~n.hi_, ~n.lo_}; + } + friend FMT_CONSTEXPR auto operator+(const uint128_fallback& lhs, + const uint128_fallback& rhs) + -> uint128_fallback { + auto result = uint128_fallback(lhs); + result += rhs; + return result; + } + friend FMT_CONSTEXPR auto operator*(const uint128_fallback& lhs, uint32_t rhs) + -> uint128_fallback { + FMT_ASSERT(lhs.hi_ == 0, ""); + uint64_t hi = (lhs.lo_ >> 32) * rhs; + uint64_t lo = (lhs.lo_ & ~uint32_t()) * rhs; + uint64_t new_lo = (hi << 32) + lo; + return {(hi >> 32) + (new_lo < lo ? 1 : 0), new_lo}; + } + friend constexpr auto operator-(const uint128_fallback& lhs, uint64_t rhs) + -> uint128_fallback { + return {lhs.hi_ - (lhs.lo_ < rhs ? 1 : 0), lhs.lo_ - rhs}; + } + FMT_CONSTEXPR auto operator>>(int shift) const -> uint128_fallback { + if (shift == 64) return {0, hi_}; + if (shift > 64) return uint128_fallback(0, hi_) >> (shift - 64); + return {hi_ >> shift, (hi_ << (64 - shift)) | (lo_ >> shift)}; + } + FMT_CONSTEXPR auto operator<<(int shift) const -> uint128_fallback { + if (shift == 64) return {lo_, 0}; + if (shift > 64) return uint128_fallback(lo_, 0) << (shift - 64); + return {hi_ << shift | (lo_ >> (64 - shift)), (lo_ << shift)}; + } + FMT_CONSTEXPR auto operator>>=(int shift) -> uint128_fallback& { + return *this = *this >> shift; + } + FMT_CONSTEXPR void operator+=(uint128_fallback n) { + uint64_t new_lo = lo_ + n.lo_; + uint64_t new_hi = hi_ + n.hi_ + (new_lo < lo_ ? 
1 : 0); + FMT_ASSERT(new_hi >= hi_, ""); + lo_ = new_lo; + hi_ = new_hi; + } + FMT_CONSTEXPR void operator&=(uint128_fallback n) { + lo_ &= n.lo_; + hi_ &= n.hi_; + } + + FMT_CONSTEXPR20 auto operator+=(uint64_t n) noexcept -> uint128_fallback& { + if (is_constant_evaluated()) { + lo_ += n; + hi_ += (lo_ < n ? 1 : 0); + return *this; + } +#if FMT_HAS_BUILTIN(__builtin_addcll) && !defined(__ibmxl__) + unsigned long long carry; + lo_ = __builtin_addcll(lo_, n, 0, &carry); + hi_ += carry; +#elif FMT_HAS_BUILTIN(__builtin_ia32_addcarryx_u64) && !defined(__ibmxl__) + unsigned long long result; + auto carry = __builtin_ia32_addcarryx_u64(0, lo_, n, &result); + lo_ = result; + hi_ += carry; +#elif defined(_MSC_VER) && defined(_M_X64) + auto carry = _addcarry_u64(0, lo_, n, &lo_); + _addcarry_u64(carry, hi_, 0, &hi_); +#else + lo_ += n; + hi_ += (lo_ < n ? 1 : 0); +#endif + return *this; + } +}; + +using uint128_t = conditional_t; + +#ifdef UINTPTR_MAX +using uintptr_t = ::uintptr_t; +#else +using uintptr_t = uint128_t; +#endif + +// Returns the largest possible value for type T. Same as +// std::numeric_limits::max() but shorter and not affected by the max macro. +template constexpr auto max_value() -> T { + return (std::numeric_limits::max)(); +} +template constexpr auto num_bits() -> int { + return std::numeric_limits::digits; +} +// std::numeric_limits::digits may return 0 for 128-bit ints. +template <> constexpr auto num_bits() -> int { return 128; } +template <> constexpr auto num_bits() -> int { return 128; } +template <> constexpr auto num_bits() -> int { return 128; } + +// A heterogeneous bit_cast used for converting 96-bit long double to uint128_t +// and 128-bit pointers to uint128_fallback. 
+template sizeof(From))> +inline auto bit_cast(const From& from) -> To { + constexpr auto size = static_cast(sizeof(From) / sizeof(unsigned short)); + struct data_t { + unsigned short value[static_cast(size)]; + } data = bit_cast(from); + auto result = To(); + if (const_check(is_big_endian())) { + for (int i = 0; i < size; ++i) + result = (result << num_bits()) | data.value[i]; + } else { + for (int i = size - 1; i >= 0; --i) + result = (result << num_bits()) | data.value[i]; + } + return result; +} + +template +FMT_CONSTEXPR20 inline auto countl_zero_fallback(UInt n) -> int { + int lz = 0; + constexpr UInt msb_mask = static_cast(1) << (num_bits() - 1); + for (; (n & msb_mask) == 0; n <<= 1) lz++; + return lz; +} + +FMT_CONSTEXPR20 inline auto countl_zero(uint32_t n) -> int { +#ifdef FMT_BUILTIN_CLZ + if (!is_constant_evaluated()) return FMT_BUILTIN_CLZ(n); +#endif + return countl_zero_fallback(n); +} + +FMT_CONSTEXPR20 inline auto countl_zero(uint64_t n) -> int { +#ifdef FMT_BUILTIN_CLZLL + if (!is_constant_evaluated()) return FMT_BUILTIN_CLZLL(n); +#endif + return countl_zero_fallback(n); +} + +FMT_INLINE void assume(bool condition) { + (void)condition; +#if FMT_HAS_BUILTIN(__builtin_assume) && !FMT_ICC_VERSION + __builtin_assume(condition); +#elif FMT_GCC_VERSION + if (!condition) __builtin_unreachable(); +#endif +} + +// Attempts to reserve space for n extra characters in the output range. +// Returns a pointer to the reserved range or a reference to it. 
+template ::value&& + is_contiguous::value)> +#if FMT_CLANG_VERSION >= 307 && !FMT_ICC_VERSION +__attribute__((no_sanitize("undefined"))) +#endif +FMT_CONSTEXPR20 inline auto +reserve(OutputIt it, size_t n) -> typename OutputIt::value_type* { + auto& c = get_container(it); + size_t size = c.size(); + c.resize(size + n); + return &c[size]; +} + +template +FMT_CONSTEXPR20 inline auto reserve(basic_appender it, size_t n) + -> basic_appender { + buffer& buf = get_container(it); + buf.try_reserve(buf.size() + n); + return it; +} + +template +constexpr auto reserve(Iterator& it, size_t) -> Iterator& { + return it; +} + +template +using reserve_iterator = + remove_reference_t(), 0))>; + +template +constexpr auto to_pointer(OutputIt, size_t) -> T* { + return nullptr; +} +template FMT_CONSTEXPR auto to_pointer(T*& ptr, size_t n) -> T* { + T* begin = ptr; + ptr += n; + return begin; +} +template +FMT_CONSTEXPR20 auto to_pointer(basic_appender it, size_t n) -> T* { + buffer& buf = get_container(it); + buf.try_reserve(buf.size() + n); + auto size = buf.size(); + if (buf.capacity() < size + n) return nullptr; + buf.try_resize(size + n); + return buf.data() + size; +} + +template ::value&& + is_contiguous::value)> +inline auto base_iterator(OutputIt it, + typename OutputIt::container_type::value_type*) + -> OutputIt { + return it; +} + +template +constexpr auto base_iterator(Iterator, Iterator it) -> Iterator { + return it; +} + +// is spectacularly slow to compile in C++20 so use a simple fill_n +// instead (#1998). 
+template +FMT_CONSTEXPR auto fill_n(OutputIt out, Size count, const T& value) + -> OutputIt { + for (Size i = 0; i < count; ++i) *out++ = value; + return out; +} +template +FMT_CONSTEXPR20 auto fill_n(T* out, Size count, char value) -> T* { + if (is_constant_evaluated()) return fill_n(out, count, value); + static_assert(sizeof(T) == 1, + "sizeof(T) must be 1 to use char for initialization"); + std::memset(out, value, to_unsigned(count)); + return out + count; +} + +template +FMT_CONSTEXPR FMT_NOINLINE auto copy_noinline(InputIt begin, InputIt end, + OutputIt out) -> OutputIt { + return copy(begin, end, out); +} + +// A public domain branchless UTF-8 decoder by Christopher Wellons: +// https://github.com/skeeto/branchless-utf8 +/* Decode the next character, c, from s, reporting errors in e. + * + * Since this is a branchless decoder, four bytes will be read from the + * buffer regardless of the actual length of the next character. This + * means the buffer _must_ have at least three bytes of zero padding + * following the end of the data stream. + * + * Errors are reported in e, which will be non-zero if the parsed + * character was somehow invalid: invalid byte sequence, non-canonical + * encoding, or a surrogate half. + * + * The function returns a pointer to the next character. When an error + * occurs, this pointer will be a guess that depends on the particular + * error, but it will always advance at least one byte. + */ +FMT_CONSTEXPR inline auto utf8_decode(const char* s, uint32_t* c, int* e) + -> const char* { + constexpr int masks[] = {0x00, 0x7f, 0x1f, 0x0f, 0x07}; + constexpr uint32_t mins[] = {4194304, 0, 128, 2048, 65536}; + constexpr int shiftc[] = {0, 18, 12, 6, 0}; + constexpr int shifte[] = {0, 6, 4, 2, 0}; + + int len = "\1\1\1\1\1\1\1\1\1\1\1\1\1\1\1\1\0\0\0\0\0\0\0\0\2\2\2\2\3\3\4" + [static_cast(*s) >> 3]; + // Compute the pointer to the next character early so that the next + // iteration can start working on the next character. 
Neither Clang + // nor GCC figure out this reordering on their own. + const char* next = s + len + !len; + + using uchar = unsigned char; + + // Assume a four-byte character and load four bytes. Unused bits are + // shifted out. + *c = uint32_t(uchar(s[0]) & masks[len]) << 18; + *c |= uint32_t(uchar(s[1]) & 0x3f) << 12; + *c |= uint32_t(uchar(s[2]) & 0x3f) << 6; + *c |= uint32_t(uchar(s[3]) & 0x3f) << 0; + *c >>= shiftc[len]; + + // Accumulate the various error conditions. + *e = (*c < mins[len]) << 6; // non-canonical encoding + *e |= ((*c >> 11) == 0x1b) << 7; // surrogate half? + *e |= (*c > 0x10FFFF) << 8; // out of range? + *e |= (uchar(s[1]) & 0xc0) >> 2; + *e |= (uchar(s[2]) & 0xc0) >> 4; + *e |= uchar(s[3]) >> 6; + *e ^= 0x2a; // top two bits of each tail byte correct? + *e >>= shifte[len]; + + return next; +} + +constexpr FMT_INLINE_VARIABLE uint32_t invalid_code_point = ~uint32_t(); + +// Invokes f(cp, sv) for every code point cp in s with sv being the string view +// corresponding to the code point. cp is invalid_code_point on error. +template +FMT_CONSTEXPR void for_each_codepoint(string_view s, F f) { + auto decode = [f](const char* buf_ptr, const char* ptr) { + auto cp = uint32_t(); + auto error = 0; + auto end = utf8_decode(buf_ptr, &cp, &error); + bool result = f(error ? invalid_code_point : cp, + string_view(ptr, error ? 1 : to_unsigned(end - buf_ptr))); + return result ? (error ? buf_ptr + 1 : end) : nullptr; + }; + + auto p = s.data(); + const size_t block_size = 4; // utf8_decode always reads blocks of 4 chars. + if (s.size() >= block_size) { + for (auto end = p + s.size() - block_size + 1; p < end;) { + p = decode(p, p); + if (!p) return; + } + } + auto num_chars_left = to_unsigned(s.data() + s.size() - p); + if (num_chars_left == 0) return; + + // Suppress bogus -Wstringop-overflow. 
+ if (FMT_GCC_VERSION) num_chars_left &= 3; + char buf[2 * block_size - 1] = {}; + copy(p, p + num_chars_left, buf); + const char* buf_ptr = buf; + do { + auto end = decode(buf_ptr, p); + if (!end) return; + p += end - buf_ptr; + buf_ptr = end; + } while (buf_ptr < buf + num_chars_left); +} + +FMT_CONSTEXPR inline auto display_width_of(uint32_t cp) noexcept -> size_t { + return to_unsigned( + 1 + (cp >= 0x1100 && + (cp <= 0x115f || // Hangul Jamo init. consonants + cp == 0x2329 || // LEFT-POINTING ANGLE BRACKET + cp == 0x232a || // RIGHT-POINTING ANGLE BRACKET + // CJK ... Yi except IDEOGRAPHIC HALF FILL SPACE: + (cp >= 0x2e80 && cp <= 0xa4cf && cp != 0x303f) || + (cp >= 0xac00 && cp <= 0xd7a3) || // Hangul Syllables + (cp >= 0xf900 && cp <= 0xfaff) || // CJK Compatibility Ideographs + (cp >= 0xfe10 && cp <= 0xfe19) || // Vertical Forms + (cp >= 0xfe30 && cp <= 0xfe6f) || // CJK Compatibility Forms + (cp >= 0xff00 && cp <= 0xff60) || // Fullwidth Forms + (cp >= 0xffe0 && cp <= 0xffe6) || // Fullwidth Forms + (cp >= 0x20000 && cp <= 0x2fffd) || // CJK + (cp >= 0x30000 && cp <= 0x3fffd) || + // Miscellaneous Symbols and Pictographs + Emoticons: + (cp >= 0x1f300 && cp <= 0x1f64f) || + // Supplemental Symbols and Pictographs: + (cp >= 0x1f900 && cp <= 0x1f9ff)))); +} + +template struct is_integral : std::is_integral {}; +template <> struct is_integral : std::true_type {}; +template <> struct is_integral : std::true_type {}; + +template +using is_signed = + std::integral_constant::is_signed || + std::is_same::value>; + +template +using is_integer = + bool_constant::value && !std::is_same::value && + !std::is_same::value && + !std::is_same::value>; + +#if defined(FMT_USE_FLOAT128) +// Use the provided definition. 
+#elif FMT_CLANG_VERSION >= 309 && FMT_HAS_INCLUDE() +# define FMT_USE_FLOAT128 1 +#elif FMT_GCC_VERSION && defined(_GLIBCXX_USE_FLOAT128) && \ + !defined(__STRICT_ANSI__) +# define FMT_USE_FLOAT128 1 +#else +# define FMT_USE_FLOAT128 0 +#endif +#if FMT_USE_FLOAT128 +using float128 = __float128; +#else +struct float128 {}; +#endif + +template using is_float128 = std::is_same; + +template struct is_floating_point : std::is_floating_point {}; +template <> struct is_floating_point : std::true_type {}; + +template ::value> +struct is_fast_float : bool_constant::is_iec559 && + sizeof(T) <= sizeof(double)> {}; +template struct is_fast_float : std::false_type {}; + +template +using fast_float_t = conditional_t; + +template +using is_double_double = bool_constant::digits == 106>; + +#ifndef FMT_USE_FULL_CACHE_DRAGONBOX +# define FMT_USE_FULL_CACHE_DRAGONBOX 0 +#endif + +// An allocator that uses malloc/free to allow removing dependency on the C++ +// standard libary runtime. std::decay is used for back_inserter to be found by +// ADL when applied to memory_buffer. +template struct allocator : private std::decay { + using value_type = T; + + auto allocate(size_t n) -> T* { + FMT_ASSERT(n <= max_value() / sizeof(T), ""); + T* p = static_cast(malloc(n * sizeof(T))); + if (!p) FMT_THROW(std::bad_alloc()); + return p; + } + + void deallocate(T* p, size_t) { free(p); } + + constexpr friend auto operator==(allocator, allocator) noexcept -> bool { + return true; // All instances of this allocator are equivalent. + } + constexpr friend auto operator!=(allocator, allocator) noexcept -> bool { + return false; + } +}; + +template +FMT_CONSTEXPR auto maybe_set_debug_format(Formatter& f, bool set) + -> decltype(f.set_debug_format(set)) { + f.set_debug_format(set); +} +template +FMT_CONSTEXPR void maybe_set_debug_format(Formatter&, ...) 
{} + +} // namespace detail + +FMT_BEGIN_EXPORT + +// The number of characters to store in the basic_memory_buffer object itself +// to avoid dynamic memory allocation. +enum { inline_buffer_size = 500 }; + +/** + * A dynamically growing memory buffer for trivially copyable/constructible + * types with the first `SIZE` elements stored in the object itself. Most + * commonly used via the `memory_buffer` alias for `char`. + * + * **Example**: + * + * auto out = fmt::memory_buffer(); + * fmt::format_to(std::back_inserter(out), "The answer is {}.", 42); + * + * This will append "The answer is 42." to `out`. The buffer content can be + * converted to `std::string` with `to_string(out)`. + */ +template > +class basic_memory_buffer : public detail::buffer { + private: + T store_[SIZE]; + + // Don't inherit from Allocator to avoid generating type_info for it. + FMT_NO_UNIQUE_ADDRESS Allocator alloc_; + + // Deallocate memory allocated by the buffer. + FMT_CONSTEXPR20 void deallocate() { + T* data = this->data(); + if (data != store_) alloc_.deallocate(data, this->capacity()); + } + + static FMT_CONSTEXPR20 void grow(detail::buffer& buf, size_t size) { + detail::abort_fuzzing_if(size > 5000); + auto& self = static_cast(buf); + const size_t max_size = + std::allocator_traits::max_size(self.alloc_); + size_t old_capacity = buf.capacity(); + size_t new_capacity = old_capacity + old_capacity / 2; + if (size > new_capacity) + new_capacity = size; + else if (new_capacity > max_size) + new_capacity = max_of(size, max_size); + T* old_data = buf.data(); + T* new_data = self.alloc_.allocate(new_capacity); + // Suppress a bogus -Wstringop-overflow in gcc 13.1 (#3481). + detail::assume(buf.size() <= new_capacity); + // The following code doesn't throw, so the raw pointer above doesn't leak. 
+ memcpy(new_data, old_data, buf.size() * sizeof(T)); + self.set(new_data, new_capacity); + // deallocate must not throw according to the standard, but even if it does, + // the buffer already uses the new storage and will deallocate it in + // destructor. + if (old_data != self.store_) self.alloc_.deallocate(old_data, old_capacity); + } + + public: + using value_type = T; + using const_reference = const T&; + + FMT_CONSTEXPR explicit basic_memory_buffer( + const Allocator& alloc = Allocator()) + : detail::buffer(grow), alloc_(alloc) { + this->set(store_, SIZE); + if (detail::is_constant_evaluated()) detail::fill_n(store_, SIZE, T()); + } + FMT_CONSTEXPR20 ~basic_memory_buffer() { deallocate(); } + + private: + template :: + propagate_on_container_move_assignment::value)> + FMT_CONSTEXPR20 auto move_alloc(basic_memory_buffer& other) -> bool { + alloc_ = std::move(other.alloc_); + return true; + } + // If the allocator does not propagate then copy the data from other. + template :: + propagate_on_container_move_assignment::value)> + FMT_CONSTEXPR20 auto move_alloc(basic_memory_buffer& other) -> bool { + T* data = other.data(); + if (alloc_ == other.alloc_ || data == other.store_) return true; + size_t size = other.size(); + // Perform copy operation, allocators are different. + this->resize(size); + detail::copy(data, data + size, this->data()); + return false; + } + + // Move data from other to this buffer. + FMT_CONSTEXPR20 void move(basic_memory_buffer& other) { + T* data = other.data(); + size_t size = other.size(), capacity = other.capacity(); + if (!move_alloc(other)) return; + if (data == other.store_) { + this->set(store_, capacity); + detail::copy(other.store_, other.store_ + size, store_); + } else { + this->set(data, capacity); + // Set pointer to the inline array so that delete is not called + // when deallocating. 
+ other.set(other.store_, 0); + other.clear(); + } + this->resize(size); + } + + public: + /// Constructs a `basic_memory_buffer` object moving the content of the other + /// object to it. + FMT_CONSTEXPR20 basic_memory_buffer(basic_memory_buffer&& other) noexcept + : detail::buffer(grow) { + move(other); + } + + /// Moves the content of the other `basic_memory_buffer` object to this one. + auto operator=(basic_memory_buffer&& other) noexcept -> basic_memory_buffer& { + FMT_ASSERT(this != &other, ""); + deallocate(); + move(other); + return *this; + } + + // Returns a copy of the allocator associated with this buffer. + auto get_allocator() const -> Allocator { return alloc_; } + + /// Resizes the buffer to contain `count` elements. If T is a POD type new + /// elements may not be initialized. + FMT_CONSTEXPR void resize(size_t count) { this->try_resize(count); } + + /// Increases the buffer capacity to `new_capacity`. + void reserve(size_t new_capacity) { this->try_reserve(new_capacity); } + + using detail::buffer::append; + template + FMT_CONSTEXPR20 void append(const ContiguousRange& range) { + append(range.data(), range.data() + range.size()); + } +}; + +using memory_buffer = basic_memory_buffer; + +template +FMT_NODISCARD auto to_string(const basic_memory_buffer& buf) + -> std::string { + auto size = buf.size(); + detail::assume(size < std::string().max_size()); + return {buf.data(), size}; +} + +// A writer to a buffered stream. It doesn't own the underlying stream. +class writer { + private: + detail::buffer* buf_; + + // We cannot create a file buffer in advance because any write to a FILE may + // invalidate it. + FILE* file_; + + public: + inline writer(FILE* f) : buf_(nullptr), file_(f) {} + inline writer(detail::buffer& buf) : buf_(&buf) {} + + /// Formats `args` according to specifications in `fmt` and writes the + /// output to the file. + template void print(format_string fmt, T&&... 
args) { + if (buf_) + fmt::format_to(appender(*buf_), fmt, std::forward(args)...); + else + fmt::print(file_, fmt, std::forward(args)...); + } +}; + +class string_buffer { + private: + std::string str_; + detail::container_buffer buf_; + + public: + inline string_buffer() : buf_(str_) {} + + inline operator writer() { return buf_; } + inline auto str() -> std::string& { return str_; } +}; + +template +struct is_contiguous> : std::true_type { +}; + +// Suppress a misleading warning in older versions of clang. +FMT_PRAGMA_CLANG(diagnostic ignored "-Wweak-vtables") + +/// An error reported from a formatting function. +class FMT_SO_VISIBILITY("default") format_error : public std::runtime_error { + public: + using std::runtime_error::runtime_error; +}; + +class loc_value; + +FMT_END_EXPORT +namespace detail { +FMT_API auto write_console(int fd, string_view text) -> bool; +FMT_API void print(FILE*, string_view); +} // namespace detail + +namespace detail { +template struct fixed_string { + FMT_CONSTEXPR20 fixed_string(const Char (&s)[N]) { + detail::copy(static_cast(s), s + N, + data); + } + Char data[N] = {}; +}; + +// Converts a compile-time string to basic_string_view. +FMT_EXPORT template +constexpr auto compile_string_to_view(const Char (&s)[N]) + -> basic_string_view { + // Remove trailing NUL character if needed. Won't be present if this is used + // with a raw character array (i.e. not defined as a string). + return {s, N - (std::char_traits::to_int_type(s[N - 1]) == 0 ? 1 : 0)}; +} +FMT_EXPORT template +constexpr auto compile_string_to_view(basic_string_view s) + -> basic_string_view { + return s; +} + +// Returns true if value is negative, false otherwise. +// Same as `value < 0` but doesn't produce warnings if T is an unsigned type. 
+template ::value)> +constexpr auto is_negative(T value) -> bool { + return value < 0; +} +template ::value)> +constexpr auto is_negative(T) -> bool { + return false; +} + +// Smallest of uint32_t, uint64_t, uint128_t that is large enough to +// represent all values of an integral type T. +template +using uint32_or_64_or_128_t = + conditional_t() <= 32 && !FMT_REDUCE_INT_INSTANTIATIONS, + uint32_t, + conditional_t() <= 64, uint64_t, uint128_t>>; +template +using uint64_or_128_t = conditional_t() <= 64, uint64_t, uint128_t>; + +#define FMT_POWERS_OF_10(factor) \ + factor * 10, (factor) * 100, (factor) * 1000, (factor) * 10000, \ + (factor) * 100000, (factor) * 1000000, (factor) * 10000000, \ + (factor) * 100000000, (factor) * 1000000000 + +// Converts value in the range [0, 100) to a string. +// GCC generates slightly better code when value is pointer-size. +inline auto digits2(size_t value) -> const char* { + // Align data since unaligned access may be slower when crossing a + // hardware-specific boundary. + alignas(2) static const char data[] = + "0001020304050607080910111213141516171819" + "2021222324252627282930313233343536373839" + "4041424344454647484950515253545556575859" + "6061626364656667686970717273747576777879" + "8081828384858687888990919293949596979899"; + return &data[value * 2]; +} + +template constexpr auto getsign(sign s) -> Char { + return static_cast(((' ' << 24) | ('+' << 16) | ('-' << 8)) >> + (static_cast(s) * 8)); +} + +template FMT_CONSTEXPR auto count_digits_fallback(T n) -> int { + int count = 1; + for (;;) { + // Integer division is slow so do it for a group of four digits instead + // of for every digit. The idea comes from the talk by Alexandrescu + // "Three Optimization Tips for C++". See speed-test for a comparison. 
+ if (n < 10) return count; + if (n < 100) return count + 1; + if (n < 1000) return count + 2; + if (n < 10000) return count + 3; + n /= 10000u; + count += 4; + } +} +#if FMT_USE_INT128 +FMT_CONSTEXPR inline auto count_digits(uint128_opt n) -> int { + return count_digits_fallback(n); +} +#endif + +#ifdef FMT_BUILTIN_CLZLL +// It is a separate function rather than a part of count_digits to workaround +// the lack of static constexpr in constexpr functions. +inline auto do_count_digits(uint64_t n) -> int { + // This has comparable performance to the version by Kendall Willets + // (https://github.com/fmtlib/format-benchmark/blob/master/digits10) + // but uses smaller tables. + // Maps bsr(n) to ceil(log10(pow(2, bsr(n) + 1) - 1)). + static constexpr uint8_t bsr2log10[] = { + 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, + 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 9, 9, 9, 10, 10, 10, + 10, 11, 11, 11, 12, 12, 12, 13, 13, 13, 13, 14, 14, 14, 15, 15, + 15, 16, 16, 16, 16, 17, 17, 17, 18, 18, 18, 19, 19, 19, 19, 20}; + auto t = bsr2log10[FMT_BUILTIN_CLZLL(n | 1) ^ 63]; + static constexpr uint64_t zero_or_powers_of_10[] = { + 0, 0, FMT_POWERS_OF_10(1U), FMT_POWERS_OF_10(1000000000ULL), + 10000000000000000000ULL}; + return t - (n < zero_or_powers_of_10[t]); +} +#endif + +// Returns the number of decimal digits in n. Leading zeros are not counted +// except for n == 0 in which case count_digits returns 1. +FMT_CONSTEXPR20 inline auto count_digits(uint64_t n) -> int { +#ifdef FMT_BUILTIN_CLZLL + if (!is_constant_evaluated() && !FMT_OPTIMIZE_SIZE) return do_count_digits(n); +#endif + return count_digits_fallback(n); +} + +// Counts the number of digits in n. BITS = log2(radix). +template +FMT_CONSTEXPR auto count_digits(UInt n) -> int { +#ifdef FMT_BUILTIN_CLZ + if (!is_constant_evaluated() && num_bits() == 32) + return (FMT_BUILTIN_CLZ(static_cast(n) | 1) ^ 31) / BITS + 1; +#endif + // Lambda avoids unreachable code warnings from NVHPC. 
+ return [](UInt m) { + int num_digits = 0; + do { + ++num_digits; + } while ((m >>= BITS) != 0); + return num_digits; + }(n); +} + +#ifdef FMT_BUILTIN_CLZ +// It is a separate function rather than a part of count_digits to workaround +// the lack of static constexpr in constexpr functions. +FMT_INLINE auto do_count_digits(uint32_t n) -> int { +// An optimization by Kendall Willets from https://bit.ly/3uOIQrB. +// This increments the upper 32 bits (log10(T) - 1) when >= T is added. +# define FMT_INC(T) (((sizeof(#T) - 1ull) << 32) - T) + static constexpr uint64_t table[] = { + FMT_INC(0), FMT_INC(0), FMT_INC(0), // 8 + FMT_INC(10), FMT_INC(10), FMT_INC(10), // 64 + FMT_INC(100), FMT_INC(100), FMT_INC(100), // 512 + FMT_INC(1000), FMT_INC(1000), FMT_INC(1000), // 4096 + FMT_INC(10000), FMT_INC(10000), FMT_INC(10000), // 32k + FMT_INC(100000), FMT_INC(100000), FMT_INC(100000), // 256k + FMT_INC(1000000), FMT_INC(1000000), FMT_INC(1000000), // 2048k + FMT_INC(10000000), FMT_INC(10000000), FMT_INC(10000000), // 16M + FMT_INC(100000000), FMT_INC(100000000), FMT_INC(100000000), // 128M + FMT_INC(1000000000), FMT_INC(1000000000), FMT_INC(1000000000), // 1024M + FMT_INC(1000000000), FMT_INC(1000000000) // 4B + }; + auto inc = table[FMT_BUILTIN_CLZ(n | 1) ^ 31]; + return static_cast((n + inc) >> 32); +} +#endif + +// Optional version of count_digits for better performance on 32-bit platforms. 
+FMT_CONSTEXPR20 inline auto count_digits(uint32_t n) -> int { +#ifdef FMT_BUILTIN_CLZ + if (!is_constant_evaluated() && !FMT_OPTIMIZE_SIZE) return do_count_digits(n); +#endif + return count_digits_fallback(n); +} + +template constexpr auto digits10() noexcept -> int { + return std::numeric_limits::digits10; +} +template <> constexpr auto digits10() noexcept -> int { return 38; } +template <> constexpr auto digits10() noexcept -> int { return 38; } + +template struct thousands_sep_result { + std::string grouping; + Char thousands_sep; +}; + +template +FMT_API auto thousands_sep_impl(locale_ref loc) -> thousands_sep_result; +template +inline auto thousands_sep(locale_ref loc) -> thousands_sep_result { + auto result = thousands_sep_impl(loc); + return {result.grouping, Char(result.thousands_sep)}; +} +template <> +inline auto thousands_sep(locale_ref loc) -> thousands_sep_result { + return thousands_sep_impl(loc); +} + +template +FMT_API auto decimal_point_impl(locale_ref loc) -> Char; +template inline auto decimal_point(locale_ref loc) -> Char { + return Char(decimal_point_impl(loc)); +} +template <> inline auto decimal_point(locale_ref loc) -> wchar_t { + return decimal_point_impl(loc); +} + +#ifndef FMT_HEADER_ONLY +FMT_BEGIN_EXPORT +extern template FMT_API auto thousands_sep_impl(locale_ref) + -> thousands_sep_result; +extern template FMT_API auto thousands_sep_impl(locale_ref) + -> thousands_sep_result; +extern template FMT_API auto decimal_point_impl(locale_ref) -> char; +extern template FMT_API auto decimal_point_impl(locale_ref) -> wchar_t; +FMT_END_EXPORT +#endif // FMT_HEADER_ONLY + +// Compares two characters for equality. +template auto equal2(const Char* lhs, const char* rhs) -> bool { + return lhs[0] == Char(rhs[0]) && lhs[1] == Char(rhs[1]); +} +inline auto equal2(const char* lhs, const char* rhs) -> bool { + return memcmp(lhs, rhs, 2) == 0; +} + +// Writes a two-digit value to out. 
+template +FMT_CONSTEXPR20 FMT_INLINE void write2digits(Char* out, size_t value) { + if (!is_constant_evaluated() && std::is_same::value && + !FMT_OPTIMIZE_SIZE) { + memcpy(out, digits2(value), 2); + return; + } + *out++ = static_cast('0' + value / 10); + *out = static_cast('0' + value % 10); +} + +// Formats a decimal unsigned integer value writing to out pointing to a buffer +// of specified size. The caller must ensure that the buffer is large enough. +template +FMT_CONSTEXPR20 auto do_format_decimal(Char* out, UInt value, int size) + -> Char* { + FMT_ASSERT(size >= count_digits(value), "invalid digit count"); + unsigned n = to_unsigned(size); + while (value >= 100) { + // Integer division is slow so do it for a group of two digits instead + // of for every digit. The idea comes from the talk by Alexandrescu + // "Three Optimization Tips for C++". See speed-test for a comparison. + n -= 2; + write2digits(out + n, static_cast(value % 100)); + value /= 100; + } + if (value >= 10) { + n -= 2; + write2digits(out + n, static_cast(value)); + } else { + out[--n] = static_cast('0' + value); + } + return out + n; +} + +template +FMT_CONSTEXPR FMT_INLINE auto format_decimal(Char* out, UInt value, + int num_digits) -> Char* { + do_format_decimal(out, value, num_digits); + return out + num_digits; +} + +template >::value)> +FMT_CONSTEXPR auto format_decimal(OutputIt out, UInt value, int num_digits) + -> OutputIt { + if (auto ptr = to_pointer(out, to_unsigned(num_digits))) { + do_format_decimal(ptr, value, num_digits); + return out; + } + // Buffer is large enough to hold all digits (digits10 + 1). 
+ char buffer[digits10() + 1]; + if (is_constant_evaluated()) fill_n(buffer, sizeof(buffer), '\0'); + do_format_decimal(buffer, value, num_digits); + return copy_noinline(buffer, buffer + num_digits, out); +} + +template +FMT_CONSTEXPR auto do_format_base2e(int base_bits, Char* out, UInt value, + int size, bool upper = false) -> Char* { + out += size; + do { + const char* digits = upper ? "0123456789ABCDEF" : "0123456789abcdef"; + unsigned digit = static_cast(value & ((1u << base_bits) - 1)); + *--out = static_cast(base_bits < 4 ? static_cast('0' + digit) + : digits[digit]); + } while ((value >>= base_bits) != 0); + return out; +} + +// Formats an unsigned integer in the power of two base (binary, octal, hex). +template +FMT_CONSTEXPR auto format_base2e(int base_bits, Char* out, UInt value, + int num_digits, bool upper = false) -> Char* { + do_format_base2e(base_bits, out, value, num_digits, upper); + return out + num_digits; +} + +template ::value)> +FMT_CONSTEXPR inline auto format_base2e(int base_bits, OutputIt out, UInt value, + int num_digits, bool upper = false) + -> OutputIt { + if (auto ptr = to_pointer(out, to_unsigned(num_digits))) { + format_base2e(base_bits, ptr, value, num_digits, upper); + return out; + } + // Make buffer large enough for any base. + char buffer[num_bits()]; + if (is_constant_evaluated()) fill_n(buffer, sizeof(buffer), '\0'); + format_base2e(base_bits, buffer, value, num_digits, upper); + return detail::copy_noinline(buffer, buffer + num_digits, out); +} + +// A converter from UTF-8 to UTF-16. 
+class utf8_to_utf16 { + private: + basic_memory_buffer buffer_; + + public: + FMT_API explicit utf8_to_utf16(string_view s); + inline operator basic_string_view() const { + return {&buffer_[0], size()}; + } + inline auto size() const -> size_t { return buffer_.size() - 1; } + inline auto c_str() const -> const wchar_t* { return &buffer_[0]; } + inline auto str() const -> std::wstring { return {&buffer_[0], size()}; } +}; + +enum class to_utf8_error_policy { abort, replace }; + +// A converter from UTF-16/UTF-32 (host endian) to UTF-8. +template class to_utf8 { + private: + Buffer buffer_; + + public: + to_utf8() {} + explicit to_utf8(basic_string_view s, + to_utf8_error_policy policy = to_utf8_error_policy::abort) { + static_assert(sizeof(WChar) == 2 || sizeof(WChar) == 4, + "expected utf16 or utf32"); + if (!convert(s, policy)) { + FMT_THROW(std::runtime_error(sizeof(WChar) == 2 ? "invalid utf16" + : "invalid utf32")); + } + } + operator string_view() const { return string_view(&buffer_[0], size()); } + auto size() const -> size_t { return buffer_.size() - 1; } + auto c_str() const -> const char* { return &buffer_[0]; } + auto str() const -> std::string { return std::string(&buffer_[0], size()); } + + // Performs conversion returning a bool instead of throwing exception on + // conversion error. This method may still throw in case of memory allocation + // error. + auto convert(basic_string_view s, + to_utf8_error_policy policy = to_utf8_error_policy::abort) + -> bool { + if (!convert(buffer_, s, policy)) return false; + buffer_.push_back(0); + return true; + } + static auto convert(Buffer& buf, basic_string_view s, + to_utf8_error_policy policy = to_utf8_error_policy::abort) + -> bool { + for (auto p = s.begin(); p != s.end(); ++p) { + uint32_t c = static_cast(*p); + if (sizeof(WChar) == 2 && c >= 0xd800 && c <= 0xdfff) { + // Handle a surrogate pair. 
+ ++p; + if (p == s.end() || (c & 0xfc00) != 0xd800 || (*p & 0xfc00) != 0xdc00) { + if (policy == to_utf8_error_policy::abort) return false; + buf.append(string_view("\xEF\xBF\xBD")); + --p; + continue; + } + c = (c << 10) + static_cast(*p) - 0x35fdc00; + } + if (c < 0x80) { + buf.push_back(static_cast(c)); + } else if (c < 0x800) { + buf.push_back(static_cast(0xc0 | (c >> 6))); + buf.push_back(static_cast(0x80 | (c & 0x3f))); + } else if ((c >= 0x800 && c <= 0xd7ff) || (c >= 0xe000 && c <= 0xffff)) { + buf.push_back(static_cast(0xe0 | (c >> 12))); + buf.push_back(static_cast(0x80 | ((c & 0xfff) >> 6))); + buf.push_back(static_cast(0x80 | (c & 0x3f))); + } else if (c >= 0x10000 && c <= 0x10ffff) { + buf.push_back(static_cast(0xf0 | (c >> 18))); + buf.push_back(static_cast(0x80 | ((c & 0x3ffff) >> 12))); + buf.push_back(static_cast(0x80 | ((c & 0xfff) >> 6))); + buf.push_back(static_cast(0x80 | (c & 0x3f))); + } else { + return false; + } + } + return true; + } +}; + +// Computes 128-bit result of multiplication of two 64-bit unsigned integers. +FMT_INLINE auto umul128(uint64_t x, uint64_t y) noexcept -> uint128_fallback { +#if FMT_USE_INT128 + auto p = static_cast(x) * static_cast(y); + return {static_cast(p >> 64), static_cast(p)}; +#elif defined(_MSC_VER) && defined(_M_X64) + auto hi = uint64_t(); + auto lo = _umul128(x, y, &hi); + return {hi, lo}; +#else + const uint64_t mask = static_cast(max_value()); + + uint64_t a = x >> 32; + uint64_t b = x & mask; + uint64_t c = y >> 32; + uint64_t d = y & mask; + + uint64_t ac = a * c; + uint64_t bc = b * c; + uint64_t ad = a * d; + uint64_t bd = b * d; + + uint64_t intermediate = (bd >> 32) + (ad & mask) + (bc & mask); + + return {ac + (intermediate >> 32) + (ad >> 32) + (bc >> 32), + (intermediate << 32) + (bd & mask)}; +#endif +} + +namespace dragonbox { +// Computes floor(log10(pow(2, e))) for e in [-2620, 2620] using the method from +// https://fmt.dev/papers/Dragonbox.pdf#page=28, section 6.1. 
+inline auto floor_log10_pow2(int e) noexcept -> int { + FMT_ASSERT(e <= 2620 && e >= -2620, "too large exponent"); + static_assert((-1 >> 1) == -1, "right shift is not arithmetic"); + return (e * 315653) >> 20; +} + +inline auto floor_log2_pow10(int e) noexcept -> int { + FMT_ASSERT(e <= 1233 && e >= -1233, "too large exponent"); + return (e * 1741647) >> 19; +} + +// Computes upper 64 bits of multiplication of two 64-bit unsigned integers. +inline auto umul128_upper64(uint64_t x, uint64_t y) noexcept -> uint64_t { +#if FMT_USE_INT128 + auto p = static_cast(x) * static_cast(y); + return static_cast(p >> 64); +#elif defined(_MSC_VER) && defined(_M_X64) + return __umulh(x, y); +#else + return umul128(x, y).high(); +#endif +} + +// Computes upper 128 bits of multiplication of a 64-bit unsigned integer and a +// 128-bit unsigned integer. +inline auto umul192_upper128(uint64_t x, uint128_fallback y) noexcept + -> uint128_fallback { + uint128_fallback r = umul128(x, y.high()); + r += umul128_upper64(x, y.low()); + return r; +} + +FMT_API auto get_cached_power(int k) noexcept -> uint128_fallback; + +// Type-specific information that Dragonbox uses. 
+template struct float_info; + +template <> struct float_info { + using carrier_uint = uint32_t; + static const int exponent_bits = 8; + static const int kappa = 1; + static const int big_divisor = 100; + static const int small_divisor = 10; + static const int min_k = -31; + static const int max_k = 46; + static const int shorter_interval_tie_lower_threshold = -35; + static const int shorter_interval_tie_upper_threshold = -35; +}; + +template <> struct float_info { + using carrier_uint = uint64_t; + static const int exponent_bits = 11; + static const int kappa = 2; + static const int big_divisor = 1000; + static const int small_divisor = 100; + static const int min_k = -292; + static const int max_k = 341; + static const int shorter_interval_tie_lower_threshold = -77; + static const int shorter_interval_tie_upper_threshold = -77; +}; + +// An 80- or 128-bit floating point number. +template +struct float_info::digits == 64 || + std::numeric_limits::digits == 113 || + is_float128::value>> { + using carrier_uint = detail::uint128_t; + static const int exponent_bits = 15; +}; + +// A double-double floating point number. +template +struct float_info::value>> { + using carrier_uint = detail::uint128_t; +}; + +template struct decimal_fp { + using significand_type = typename float_info::carrier_uint; + significand_type significand; + int exponent; +}; + +template FMT_API auto to_decimal(T x) noexcept -> decimal_fp; +} // namespace dragonbox + +// Returns true iff Float has the implicit bit which is not stored. +template constexpr auto has_implicit_bit() -> bool { + // An 80-bit FP number has a 64-bit significand an no implicit bit. + return std::numeric_limits::digits != 64; +} + +// Returns the number of significand bits stored in Float. The implicit bit is +// not counted since it is not stored. +template constexpr auto num_significand_bits() -> int { + // std::numeric_limits may not support __float128. + return is_float128() ? 
112 + : (std::numeric_limits::digits - + (has_implicit_bit() ? 1 : 0)); +} + +template +constexpr auto exponent_mask() -> + typename dragonbox::float_info::carrier_uint { + using float_uint = typename dragonbox::float_info::carrier_uint; + return ((float_uint(1) << dragonbox::float_info::exponent_bits) - 1) + << num_significand_bits(); +} +template constexpr auto exponent_bias() -> int { + // std::numeric_limits may not support __float128. + return is_float128() ? 16383 + : std::numeric_limits::max_exponent - 1; +} + +FMT_CONSTEXPR inline auto compute_exp_size(int exp) -> int { + auto prefix_size = 2; // sign + 'e' + auto abs_exp = exp >= 0 ? exp : -exp; + if (abs_exp < 100) return prefix_size + 2; + return prefix_size + (abs_exp >= 1000 ? 4 : 3); +} + +// Writes the exponent exp in the form "[+-]d{2,3}" to buffer. +template +FMT_CONSTEXPR auto write_exponent(int exp, OutputIt out) -> OutputIt { + FMT_ASSERT(-10000 < exp && exp < 10000, "exponent out of range"); + if (exp < 0) { + *out++ = static_cast('-'); + exp = -exp; + } else { + *out++ = static_cast('+'); + } + auto uexp = static_cast(exp); + if (is_constant_evaluated()) { + if (uexp < 10) *out++ = '0'; + return format_decimal(out, uexp, count_digits(uexp)); + } + if (uexp >= 100u) { + const char* top = digits2(uexp / 100); + if (uexp >= 1000u) *out++ = static_cast(top[0]); + *out++ = static_cast(top[1]); + uexp %= 100; + } + const char* d = digits2(uexp); + *out++ = static_cast(d[0]); + *out++ = static_cast(d[1]); + return out; +} + +// A floating-point number f * pow(2, e) where F is an unsigned type. +template struct basic_fp { + F f; + int e; + + static constexpr int num_significand_bits = + static_cast(sizeof(F) * num_bits()); + + constexpr basic_fp() : f(0), e(0) {} + constexpr basic_fp(uint64_t f_val, int e_val) : f(f_val), e(e_val) {} + + // Constructs fp from an IEEE754 floating-point number. 
+ template FMT_CONSTEXPR basic_fp(Float n) { assign(n); } + + // Assigns n to this and return true iff predecessor is closer than successor. + template ::value)> + FMT_CONSTEXPR auto assign(Float n) -> bool { + static_assert(std::numeric_limits::digits <= 113, "unsupported FP"); + // Assume Float is in the format [sign][exponent][significand]. + using carrier_uint = typename dragonbox::float_info::carrier_uint; + const auto num_float_significand_bits = + detail::num_significand_bits(); + const auto implicit_bit = carrier_uint(1) << num_float_significand_bits; + const auto significand_mask = implicit_bit - 1; + auto u = bit_cast(n); + f = static_cast(u & significand_mask); + auto biased_e = static_cast((u & exponent_mask()) >> + num_float_significand_bits); + // The predecessor is closer if n is a normalized power of 2 (f == 0) + // other than the smallest normalized number (biased_e > 1). + auto is_predecessor_closer = f == 0 && biased_e > 1; + if (biased_e == 0) + biased_e = 1; // Subnormals use biased exponent 1 (min exponent). + else if (has_implicit_bit()) + f += static_cast(implicit_bit); + e = biased_e - exponent_bias() - num_float_significand_bits; + if (!has_implicit_bit()) ++e; + return is_predecessor_closer; + } + + template ::value)> + FMT_CONSTEXPR auto assign(Float n) -> bool { + static_assert(std::numeric_limits::is_iec559, "unsupported FP"); + return assign(static_cast(n)); + } +}; + +using fp = basic_fp; + +// Normalizes the value converted from double and multiplied by (1 << SHIFT). +template +FMT_CONSTEXPR auto normalize(basic_fp value) -> basic_fp { + // Handle subnormals. + const auto implicit_bit = F(1) << num_significand_bits(); + const auto shifted_implicit_bit = implicit_bit << SHIFT; + while ((value.f & shifted_implicit_bit) == 0) { + value.f <<= 1; + --value.e; + } + // Subtract 1 to account for hidden bit. 
+ const auto offset = basic_fp::num_significand_bits - + num_significand_bits() - SHIFT - 1; + value.f <<= offset; + value.e -= offset; + return value; +} + +// Computes lhs * rhs / pow(2, 64) rounded to nearest with half-up tie breaking. +FMT_CONSTEXPR inline auto multiply(uint64_t lhs, uint64_t rhs) -> uint64_t { +#if FMT_USE_INT128 + auto product = static_cast<__uint128_t>(lhs) * rhs; + auto f = static_cast(product >> 64); + return (static_cast(product) & (1ULL << 63)) != 0 ? f + 1 : f; +#else + // Multiply 32-bit parts of significands. + uint64_t mask = (1ULL << 32) - 1; + uint64_t a = lhs >> 32, b = lhs & mask; + uint64_t c = rhs >> 32, d = rhs & mask; + uint64_t ac = a * c, bc = b * c, ad = a * d, bd = b * d; + // Compute mid 64-bit of result and round. + uint64_t mid = (bd >> 32) + (ad & mask) + (bc & mask) + (1U << 31); + return ac + (ad >> 32) + (bc >> 32) + (mid >> 32); +#endif +} + +FMT_CONSTEXPR inline auto operator*(fp x, fp y) -> fp { + return {multiply(x.f, y.f), x.e + y.e + 64}; +} + +template () == num_bits()> +using convert_float_result = + conditional_t::value || doublish, double, T>; + +template +constexpr auto convert_float(T value) -> convert_float_result { + return static_cast>(value); +} + +template +auto select(T true_value, F) -> T { + return true_value; +} +template +auto select(T, F false_value) -> F { + return false_value; +} + +template +FMT_CONSTEXPR FMT_NOINLINE auto fill(OutputIt it, size_t n, + const basic_specs& specs) -> OutputIt { + auto fill_size = specs.fill_size(); + if (fill_size == 1) return detail::fill_n(it, n, specs.fill_unit()); + if (const Char* data = specs.fill()) { + for (size_t i = 0; i < n; ++i) it = copy(data, data + fill_size, it); + } + return it; +} + +// Writes the output of f, padded according to format specifications in specs. +// size: output size in code units. +// width: output display width in (terminal) column positions. 
+template +FMT_CONSTEXPR auto write_padded(OutputIt out, const format_specs& specs, + size_t size, size_t width, F&& f) -> OutputIt { + static_assert(default_align == align::left || default_align == align::right, + ""); + unsigned spec_width = to_unsigned(specs.width); + size_t padding = spec_width > width ? spec_width - width : 0; + // Shifts are encoded as string literals because static constexpr is not + // supported in constexpr functions. + auto* shifts = + default_align == align::left ? "\x1f\x1f\x00\x01" : "\x00\x1f\x00\x01"; + size_t left_padding = padding >> shifts[static_cast(specs.align())]; + size_t right_padding = padding - left_padding; + auto it = reserve(out, size + padding * specs.fill_size()); + if (left_padding != 0) it = fill(it, left_padding, specs); + it = f(it); + if (right_padding != 0) it = fill(it, right_padding, specs); + return base_iterator(out, it); +} + +template +constexpr auto write_padded(OutputIt out, const format_specs& specs, + size_t size, F&& f) -> OutputIt { + return write_padded(out, specs, size, size, f); +} + +template +FMT_CONSTEXPR auto write_bytes(OutputIt out, string_view bytes, + const format_specs& specs = {}) -> OutputIt { + return write_padded( + out, specs, bytes.size(), [bytes](reserve_iterator it) { + const char* data = bytes.data(); + return copy(data, data + bytes.size(), it); + }); +} + +template +auto write_ptr(OutputIt out, UIntPtr value, const format_specs* specs) + -> OutputIt { + int num_digits = count_digits<4>(value); + auto size = to_unsigned(num_digits) + size_t(2); + auto write = [=](reserve_iterator it) { + *it++ = static_cast('0'); + *it++ = static_cast('x'); + return format_base2e(4, it, value, num_digits); + }; + return specs ? write_padded(out, *specs, size, write) + : base_iterator(out, write(reserve(out, size))); +} + +// Returns true iff the code point cp is printable. 
+FMT_API auto is_printable(uint32_t cp) -> bool; + +inline auto needs_escape(uint32_t cp) -> bool { + if (cp < 0x20 || cp == 0x7f || cp == '"' || cp == '\\') return true; + if (const_check(FMT_OPTIMIZE_SIZE > 1)) return false; + return !is_printable(cp); +} + +template struct find_escape_result { + const Char* begin; + const Char* end; + uint32_t cp; +}; + +template +auto find_escape(const Char* begin, const Char* end) + -> find_escape_result { + for (; begin != end; ++begin) { + uint32_t cp = static_cast>(*begin); + if (const_check(sizeof(Char) == 1) && cp >= 0x80) continue; + if (needs_escape(cp)) return {begin, begin + 1, cp}; + } + return {begin, nullptr, 0}; +} + +inline auto find_escape(const char* begin, const char* end) + -> find_escape_result { + if (const_check(!use_utf8)) return find_escape(begin, end); + auto result = find_escape_result{end, nullptr, 0}; + for_each_codepoint(string_view(begin, to_unsigned(end - begin)), + [&](uint32_t cp, string_view sv) { + if (needs_escape(cp)) { + result = {sv.begin(), sv.end(), cp}; + return false; + } + return true; + }); + return result; +} + +template +auto write_codepoint(OutputIt out, char prefix, uint32_t cp) -> OutputIt { + *out++ = static_cast('\\'); + *out++ = static_cast(prefix); + Char buf[width]; + fill_n(buf, width, static_cast('0')); + format_base2e(4, buf, cp, width); + return copy(buf, buf + width, out); +} + +template +auto write_escaped_cp(OutputIt out, const find_escape_result& escape) + -> OutputIt { + auto c = static_cast(escape.cp); + switch (escape.cp) { + case '\n': + *out++ = static_cast('\\'); + c = static_cast('n'); + break; + case '\r': + *out++ = static_cast('\\'); + c = static_cast('r'); + break; + case '\t': + *out++ = static_cast('\\'); + c = static_cast('t'); + break; + case '"': FMT_FALLTHROUGH; + case '\'': FMT_FALLTHROUGH; + case '\\': *out++ = static_cast('\\'); break; + default: + if (escape.cp < 0x100) return write_codepoint<2, Char>(out, 'x', escape.cp); + if (escape.cp < 
0x10000) + return write_codepoint<4, Char>(out, 'u', escape.cp); + if (escape.cp < 0x110000) + return write_codepoint<8, Char>(out, 'U', escape.cp); + for (Char escape_char : basic_string_view( + escape.begin, to_unsigned(escape.end - escape.begin))) { + out = write_codepoint<2, Char>(out, 'x', + static_cast(escape_char) & 0xFF); + } + return out; + } + *out++ = c; + return out; +} + +template +auto write_escaped_string(OutputIt out, basic_string_view str) + -> OutputIt { + *out++ = static_cast('"'); + auto begin = str.begin(), end = str.end(); + do { + auto escape = find_escape(begin, end); + out = copy(begin, escape.begin, out); + begin = escape.end; + if (!begin) break; + out = write_escaped_cp(out, escape); + } while (begin != end); + *out++ = static_cast('"'); + return out; +} + +template +auto write_escaped_char(OutputIt out, Char v) -> OutputIt { + Char v_array[1] = {v}; + *out++ = static_cast('\''); + if ((needs_escape(static_cast(v)) && v != static_cast('"')) || + v == static_cast('\'')) { + out = write_escaped_cp(out, + find_escape_result{v_array, v_array + 1, + static_cast(v)}); + } else { + *out++ = v; + } + *out++ = static_cast('\''); + return out; +} + +template +FMT_CONSTEXPR auto write_char(OutputIt out, Char value, + const format_specs& specs) -> OutputIt { + bool is_debug = specs.type() == presentation_type::debug; + return write_padded(out, specs, 1, [=](reserve_iterator it) { + if (is_debug) return write_escaped_char(it, value); + *it++ = value; + return it; + }); +} + +template class digit_grouping { + private: + std::string grouping_; + std::basic_string thousands_sep_; + + struct next_state { + std::string::const_iterator group; + int pos; + }; + auto initial_state() const -> next_state { return {grouping_.begin(), 0}; } + + // Returns the next digit group separator position. 
+ auto next(next_state& state) const -> int { + if (thousands_sep_.empty()) return max_value(); + if (state.group == grouping_.end()) return state.pos += grouping_.back(); + if (*state.group <= 0 || *state.group == max_value()) + return max_value(); + state.pos += *state.group++; + return state.pos; + } + + public: + explicit digit_grouping(locale_ref loc, bool localized = true) { + if (!localized) return; + auto sep = thousands_sep(loc); + grouping_ = sep.grouping; + if (sep.thousands_sep) thousands_sep_.assign(1, sep.thousands_sep); + } + digit_grouping(std::string grouping, std::basic_string sep) + : grouping_(std::move(grouping)), thousands_sep_(std::move(sep)) {} + + auto has_separator() const -> bool { return !thousands_sep_.empty(); } + + auto count_separators(int num_digits) const -> int { + int count = 0; + auto state = initial_state(); + while (num_digits > next(state)) ++count; + return count; + } + + // Applies grouping to digits and writes the output to out. + template + auto apply(Out out, basic_string_view digits) const -> Out { + auto num_digits = static_cast(digits.size()); + auto separators = basic_memory_buffer(); + separators.push_back(0); + auto state = initial_state(); + while (int i = next(state)) { + if (i >= num_digits) break; + separators.push_back(i); + } + for (int i = 0, sep_index = static_cast(separators.size() - 1); + i < num_digits; ++i) { + if (num_digits - i == separators[sep_index]) { + out = copy(thousands_sep_.data(), + thousands_sep_.data() + thousands_sep_.size(), out); + --sep_index; + } + *out++ = static_cast(digits[to_unsigned(i)]); + } + return out; + } +}; + +FMT_CONSTEXPR inline void prefix_append(unsigned& prefix, unsigned value) { + prefix |= prefix != 0 ? value << 8 : value; + prefix += (1u + (value > 0xff ? 1 : 0)) << 24; +} + +// Writes a decimal integer with digit grouping. 
+template +auto write_int(OutputIt out, UInt value, unsigned prefix, + const format_specs& specs, const digit_grouping& grouping) + -> OutputIt { + static_assert(std::is_same, UInt>::value, ""); + int num_digits = 0; + auto buffer = memory_buffer(); + switch (specs.type()) { + default: FMT_ASSERT(false, ""); FMT_FALLTHROUGH; + case presentation_type::none: + case presentation_type::dec: + num_digits = count_digits(value); + format_decimal(appender(buffer), value, num_digits); + break; + case presentation_type::hex: + if (specs.alt()) + prefix_append(prefix, unsigned(specs.upper() ? 'X' : 'x') << 8 | '0'); + num_digits = count_digits<4>(value); + format_base2e(4, appender(buffer), value, num_digits, specs.upper()); + break; + case presentation_type::oct: + num_digits = count_digits<3>(value); + // Octal prefix '0' is counted as a digit, so only add it if precision + // is not greater than the number of digits. + if (specs.alt() && specs.precision <= num_digits && value != 0) + prefix_append(prefix, '0'); + format_base2e(3, appender(buffer), value, num_digits); + break; + case presentation_type::bin: + if (specs.alt()) + prefix_append(prefix, unsigned(specs.upper() ? 'B' : 'b') << 8 | '0'); + num_digits = count_digits<1>(value); + format_base2e(1, appender(buffer), value, num_digits); + break; + case presentation_type::chr: + return write_char(out, static_cast(value), specs); + } + + unsigned size = (prefix != 0 ? prefix >> 24 : 0) + to_unsigned(num_digits) + + to_unsigned(grouping.count_separators(num_digits)); + return write_padded( + out, specs, size, size, [&](reserve_iterator it) { + for (unsigned p = prefix & 0xffffff; p != 0; p >>= 8) + *it++ = static_cast(p & 0xff); + return grouping.apply(it, string_view(buffer.data(), buffer.size())); + }); +} + +#if FMT_USE_LOCALE +// Writes a localized value. 
+FMT_API auto write_loc(appender out, loc_value value, const format_specs& specs, + locale_ref loc) -> bool; +auto write_loc(basic_appender out, loc_value value, + const format_specs& specs, locale_ref loc) -> bool; +#endif +template +inline auto write_loc(OutputIt, const loc_value&, const format_specs&, + locale_ref) -> bool { + return false; +} + +template struct write_int_arg { + UInt abs_value; + unsigned prefix; +}; + +template +FMT_CONSTEXPR auto make_write_int_arg(T value, sign s) + -> write_int_arg> { + auto prefix = 0u; + auto abs_value = static_cast>(value); + if (is_negative(value)) { + prefix = 0x01000000 | '-'; + abs_value = 0 - abs_value; + } else { + constexpr unsigned prefixes[4] = {0, 0, 0x1000000u | '+', 0x1000000u | ' '}; + prefix = prefixes[static_cast(s)]; + } + return {abs_value, prefix}; +} + +template struct loc_writer { + basic_appender out; + const format_specs& specs; + std::basic_string sep; + std::string grouping; + std::basic_string decimal_point; + + template ::value)> + auto operator()(T value) -> bool { + auto arg = make_write_int_arg(value, specs.sign()); + write_int(out, static_cast>(arg.abs_value), arg.prefix, + specs, digit_grouping(grouping, sep)); + return true; + } + + template ::value)> + auto operator()(T) -> bool { + return false; + } +}; + +// Size and padding computation separate from write_int to avoid template bloat. 
+struct size_padding { + unsigned size; + unsigned padding; + + FMT_CONSTEXPR size_padding(int num_digits, unsigned prefix, + const format_specs& specs) + : size((prefix >> 24) + to_unsigned(num_digits)), padding(0) { + if (specs.align() == align::numeric) { + auto width = to_unsigned(specs.width); + if (width > size) { + padding = width - size; + size = width; + } + } else if (specs.precision > num_digits) { + size = (prefix >> 24) + to_unsigned(specs.precision); + padding = to_unsigned(specs.precision - num_digits); + } + } +}; + +template +FMT_CONSTEXPR FMT_INLINE auto write_int(OutputIt out, write_int_arg arg, + const format_specs& specs) -> OutputIt { + static_assert(std::is_same>::value, ""); + + constexpr size_t buffer_size = num_bits(); + char buffer[buffer_size]; + if (is_constant_evaluated()) fill_n(buffer, buffer_size, '\0'); + const char* begin = nullptr; + const char* end = buffer + buffer_size; + + auto abs_value = arg.abs_value; + auto prefix = arg.prefix; + switch (specs.type()) { + default: FMT_ASSERT(false, ""); FMT_FALLTHROUGH; + case presentation_type::none: + case presentation_type::dec: + begin = do_format_decimal(buffer, abs_value, buffer_size); + break; + case presentation_type::hex: + begin = do_format_base2e(4, buffer, abs_value, buffer_size, specs.upper()); + if (specs.alt()) + prefix_append(prefix, unsigned(specs.upper() ? 'X' : 'x') << 8 | '0'); + break; + case presentation_type::oct: { + begin = do_format_base2e(3, buffer, abs_value, buffer_size); + // Octal prefix '0' is counted as a digit, so only add it if precision + // is not greater than the number of digits. + auto num_digits = end - begin; + if (specs.alt() && specs.precision <= num_digits && abs_value != 0) + prefix_append(prefix, '0'); + break; + } + case presentation_type::bin: + begin = do_format_base2e(1, buffer, abs_value, buffer_size); + if (specs.alt()) + prefix_append(prefix, unsigned(specs.upper() ? 
'B' : 'b') << 8 | '0'); + break; + case presentation_type::chr: + return write_char(out, static_cast(abs_value), specs); + } + + // Write an integer in the format + // + // prefix contains chars in three lower bytes and the size in the fourth byte. + int num_digits = static_cast(end - begin); + // Slightly faster check for specs.width == 0 && specs.precision == -1. + if ((specs.width | (specs.precision + 1)) == 0) { + auto it = reserve(out, to_unsigned(num_digits) + (prefix >> 24)); + for (unsigned p = prefix & 0xffffff; p != 0; p >>= 8) + *it++ = static_cast(p & 0xff); + return base_iterator(out, copy(begin, end, it)); + } + auto sp = size_padding(num_digits, prefix, specs); + unsigned padding = sp.padding; + return write_padded( + out, specs, sp.size, [=](reserve_iterator it) { + for (unsigned p = prefix & 0xffffff; p != 0; p >>= 8) + *it++ = static_cast(p & 0xff); + it = detail::fill_n(it, padding, static_cast('0')); + return copy(begin, end, it); + }); +} + +template +FMT_CONSTEXPR FMT_NOINLINE auto write_int_noinline(OutputIt out, + write_int_arg arg, + const format_specs& specs) + -> OutputIt { + return write_int(out, arg, specs); +} + +template ::value && + !std::is_same::value && + !std::is_same::value)> +FMT_CONSTEXPR FMT_INLINE auto write(basic_appender out, T value, + const format_specs& specs, locale_ref loc) + -> basic_appender { + if (specs.localized() && write_loc(out, value, specs, loc)) return out; + return write_int_noinline(out, make_write_int_arg(value, specs.sign()), + specs); +} + +// An inlined version of write used in format string compilation. 
+template ::value && + !std::is_same::value && + !std::is_same::value && + !std::is_same>::value)> +FMT_CONSTEXPR FMT_INLINE auto write(OutputIt out, T value, + const format_specs& specs, locale_ref loc) + -> OutputIt { + if (specs.localized() && write_loc(out, value, specs, loc)) return out; + return write_int(out, make_write_int_arg(value, specs.sign()), specs); +} + +template +FMT_CONSTEXPR auto write(OutputIt out, Char value, const format_specs& specs, + locale_ref loc = {}) -> OutputIt { + // char is formatted as unsigned char for consistency across platforms. + using unsigned_type = + conditional_t::value, unsigned char, unsigned>; + return check_char_specs(specs) + ? write_char(out, value, specs) + : write(out, static_cast(value), specs, loc); +} + +template ::value)> +FMT_CONSTEXPR auto write(OutputIt out, basic_string_view s, + const format_specs& specs) -> OutputIt { + bool is_debug = specs.type() == presentation_type::debug; + if (specs.precision < 0 && specs.width == 0) { + auto&& it = reserve(out, s.size()); + return is_debug ? write_escaped_string(it, s) : copy(s, it); + } + + size_t display_width_limit = + specs.precision < 0 ? SIZE_MAX : to_unsigned(specs.precision); + size_t display_width = + !is_debug || specs.precision == 0 ? 0 : 1; // Account for opening '"'. + size_t size = !is_debug || specs.precision == 0 ? 0 : 1; + for_each_codepoint(s, [&](uint32_t cp, string_view sv) { + if (is_debug && needs_escape(cp)) { + counting_buffer buf; + write_escaped_cp(basic_appender(buf), + find_escape_result{sv.begin(), sv.end(), cp}); + // We're reinterpreting bytes as display width. That's okay + // because write_escaped_cp() only writes ASCII characters. + size_t cp_width = buf.count(); + if (display_width + cp_width <= display_width_limit) { + display_width += cp_width; + size += cp_width; + // If this is the end of the string, account for closing '"'. 
+ if (display_width < display_width_limit && sv.end() == s.end()) { + ++display_width; + ++size; + } + return true; + } + + size += display_width_limit - display_width; + display_width = display_width_limit; + return false; + } + + size_t cp_width = display_width_of(cp); + if (cp_width + display_width <= display_width_limit) { + display_width += cp_width; + size += sv.size(); + // If this is the end of the string, account for closing '"'. + if (is_debug && display_width < display_width_limit && + sv.end() == s.end()) { + ++display_width; + ++size; + } + return true; + } + + return false; + }); + + struct bounded_output_iterator { + reserve_iterator underlying_iterator; + size_t bound; + + FMT_CONSTEXPR auto operator*() -> bounded_output_iterator& { return *this; } + FMT_CONSTEXPR auto operator++() -> bounded_output_iterator& { + return *this; + } + FMT_CONSTEXPR auto operator++(int) -> bounded_output_iterator& { + return *this; + } + FMT_CONSTEXPR auto operator=(char c) -> bounded_output_iterator& { + if (bound > 0) { + *underlying_iterator++ = c; + --bound; + } + return *this; + } + }; + + return write_padded( + out, specs, size, display_width, [=](reserve_iterator it) { + return is_debug + ? write_escaped_string(bounded_output_iterator{it, size}, s) + .underlying_iterator + : copy(s.data(), s.data() + size, it); + }); +} + +template ::value)> +FMT_CONSTEXPR auto write(OutputIt out, basic_string_view s, + const format_specs& specs) -> OutputIt { + auto data = s.data(); + auto size = s.size(); + if (specs.precision >= 0 && to_unsigned(specs.precision) < size) + size = to_unsigned(specs.precision); + + bool is_debug = specs.type() == presentation_type::debug; + if (is_debug) { + auto buf = counting_buffer(); + write_escaped_string(basic_appender(buf), s); + size = buf.count(); + } + + return write_padded( + out, specs, size, [=](reserve_iterator it) { + return is_debug ? 
write_escaped_string(it, s) + : copy(data, data + size, it); + }); +} + +template +FMT_CONSTEXPR auto write(OutputIt out, basic_string_view s, + const format_specs& specs, locale_ref) -> OutputIt { + return write(out, s, specs); +} + +template +FMT_CONSTEXPR auto write(OutputIt out, const Char* s, const format_specs& specs, + locale_ref) -> OutputIt { + if (specs.type() == presentation_type::pointer) + return write_ptr(out, bit_cast(s), &specs); + if (!s) report_error("string pointer is null"); + return write(out, basic_string_view(s), specs, {}); +} + +template ::value && + !std::is_same::value && + !std::is_same::value)> +FMT_CONSTEXPR auto write(OutputIt out, T value) -> OutputIt { + auto abs_value = static_cast>(value); + bool negative = is_negative(value); + // Don't do -abs_value since it trips unsigned-integer-overflow sanitizer. + if (negative) abs_value = ~abs_value + 1; + int num_digits = count_digits(abs_value); + auto size = (negative ? 1 : 0) + static_cast(num_digits); + if (auto ptr = to_pointer(out, size)) { + if (negative) *ptr++ = static_cast('-'); + format_decimal(ptr, abs_value, num_digits); + return out; + } + if (negative) *out++ = static_cast('-'); + return format_decimal(out, abs_value, num_digits); +} + +template +FMT_CONSTEXPR auto parse_align(const Char* begin, const Char* end, + format_specs& specs) -> const Char* { + FMT_ASSERT(begin != end, ""); + auto alignment = align::none; + auto p = begin + code_point_length(begin); + if (end - p <= 0) p = begin; + for (;;) { + switch (to_ascii(*p)) { + case '<': alignment = align::left; break; + case '>': alignment = align::right; break; + case '^': alignment = align::center; break; + } + if (alignment != align::none) { + if (p != begin) { + auto c = *begin; + if (c == '}') return begin; + if (c == '{') { + report_error("invalid fill character '{'"); + return begin; + } + specs.set_fill(basic_string_view(begin, to_unsigned(p - begin))); + begin = p + 1; + } else { + ++begin; + } + break; + } else 
if (p == begin) { + break; + } + p = begin; + } + specs.set_align(alignment); + return begin; +} + +template +FMT_CONSTEXPR20 auto write_nonfinite(OutputIt out, bool isnan, + format_specs specs, sign s) -> OutputIt { + auto str = + isnan ? (specs.upper() ? "NAN" : "nan") : (specs.upper() ? "INF" : "inf"); + constexpr size_t str_size = 3; + auto size = str_size + (s != sign::none ? 1 : 0); + // Replace '0'-padding with space for non-finite values. + const bool is_zero_fill = + specs.fill_size() == 1 && specs.fill_unit() == '0'; + if (is_zero_fill) specs.set_fill(' '); + return write_padded(out, specs, size, + [=](reserve_iterator it) { + if (s != sign::none) + *it++ = detail::getsign(s); + return copy(str, str + str_size, it); + }); +} + +// A decimal floating-point number significand * pow(10, exp). +struct big_decimal_fp { + const char* significand; + int significand_size; + int exponent; +}; + +constexpr auto get_significand_size(const big_decimal_fp& f) -> int { + return f.significand_size; +} +template +inline auto get_significand_size(const dragonbox::decimal_fp& f) -> int { + return count_digits(f.significand); +} + +template +constexpr auto write_significand(OutputIt out, const char* significand, + int significand_size) -> OutputIt { + return copy(significand, significand + significand_size, out); +} +template +inline auto write_significand(OutputIt out, UInt significand, + int significand_size) -> OutputIt { + return format_decimal(out, significand, significand_size); +} +template +FMT_CONSTEXPR20 auto write_significand(OutputIt out, T significand, + int significand_size, int exponent, + const Grouping& grouping) -> OutputIt { + if (!grouping.has_separator()) { + out = write_significand(out, significand, significand_size); + return detail::fill_n(out, exponent, static_cast('0')); + } + auto buffer = memory_buffer(); + write_significand(appender(buffer), significand, significand_size); + detail::fill_n(appender(buffer), exponent, '0'); + return 
grouping.apply(out, string_view(buffer.data(), buffer.size())); +} + +template ::value)> +inline auto write_significand(Char* out, UInt significand, int significand_size, + int integral_size, Char decimal_point) -> Char* { + if (!decimal_point) return format_decimal(out, significand, significand_size); + out += significand_size + 1; + Char* end = out; + int floating_size = significand_size - integral_size; + for (int i = floating_size / 2; i > 0; --i) { + out -= 2; + write2digits(out, static_cast(significand % 100)); + significand /= 100; + } + if (floating_size % 2 != 0) { + *--out = static_cast('0' + significand % 10); + significand /= 10; + } + *--out = decimal_point; + format_decimal(out - integral_size, significand, integral_size); + return end; +} + +template >::value)> +inline auto write_significand(OutputIt out, UInt significand, + int significand_size, int integral_size, + Char decimal_point) -> OutputIt { + // Buffer is large enough to hold digits (digits10 + 1) and a decimal point. 
+ Char buffer[digits10() + 2]; + auto end = write_significand(buffer, significand, significand_size, + integral_size, decimal_point); + return detail::copy_noinline(buffer, end, out); +} + +template +FMT_CONSTEXPR auto write_significand(OutputIt out, const char* significand, + int significand_size, int integral_size, + Char decimal_point) -> OutputIt { + out = detail::copy_noinline(significand, significand + integral_size, + out); + if (!decimal_point) return out; + *out++ = decimal_point; + return detail::copy_noinline(significand + integral_size, + significand + significand_size, out); +} + +template +FMT_CONSTEXPR20 auto write_significand(OutputIt out, T significand, + int significand_size, int integral_size, + Char decimal_point, + const Grouping& grouping) -> OutputIt { + if (!grouping.has_separator()) { + return write_significand(out, significand, significand_size, integral_size, + decimal_point); + } + auto buffer = basic_memory_buffer(); + write_significand(basic_appender(buffer), significand, significand_size, + integral_size, decimal_point); + grouping.apply( + out, basic_string_view(buffer.data(), to_unsigned(integral_size))); + return detail::copy_noinline(buffer.data() + integral_size, + buffer.end(), out); +} + +// Numbers with exponents greater or equal to the returned value will use +// the exponential notation. +template FMT_CONSTEVAL auto exp_upper() -> int { + return std::numeric_limits::digits10 != 0 + ? min_of(16, std::numeric_limits::digits10 + 1) + : 16; +} + +// Use the fixed notation if the exponent is in [-4, exp_upper), +// e.g. 0.0001 instead of 1e-04. Otherwise use the exponent notation. 
+constexpr auto use_fixed(int exp, int exp_upper) -> bool { + return exp >= -4 && exp < exp_upper; +} + +template class fallback_digit_grouping { + public: + constexpr fallback_digit_grouping(locale_ref, bool) {} + + constexpr auto has_separator() const -> bool { return false; } + + constexpr auto count_separators(int) const -> int { return 0; } + + template + constexpr auto apply(Out out, basic_string_view) const -> Out { + return out; + } +}; + +template +FMT_CONSTEXPR20 auto write_fixed(OutputIt out, const DecimalFP& f, + int significand_size, Char decimal_point, + const format_specs& specs, sign s, + locale_ref loc = {}) -> OutputIt { + using iterator = reserve_iterator; + + int exp = f.exponent + significand_size; + long long size = significand_size + (s != sign::none ? 1 : 0); + if (f.exponent >= 0) { + // 1234e5 -> 123400000[.0+] + size += f.exponent; + int num_zeros = specs.precision - exp; + abort_fuzzing_if(num_zeros > 5000); + if (specs.alt()) { + ++size; + if (num_zeros <= 0 && specs.type() != presentation_type::fixed) + num_zeros = 0; + if (num_zeros > 0) size += num_zeros; + } + auto grouping = Grouping(loc, specs.localized()); + size += grouping.count_separators(exp); + return write_padded( + out, specs, static_cast(size), [&](iterator it) { + if (s != sign::none) *it++ = detail::getsign(s); + it = write_significand(it, f.significand, significand_size, + f.exponent, grouping); + if (!specs.alt()) return it; + *it++ = decimal_point; + return num_zeros > 0 ? detail::fill_n(it, num_zeros, Char('0')) : it; + }); + } + if (exp > 0) { + // 1234e-2 -> 12.34[0+] + int num_zeros = specs.alt() ? 
specs.precision - significand_size : 0; + size += 1 + max_of(num_zeros, 0); + auto grouping = Grouping(loc, specs.localized()); + size += grouping.count_separators(exp); + return write_padded( + out, specs, to_unsigned(size), [&](iterator it) { + if (s != sign::none) *it++ = detail::getsign(s); + it = write_significand(it, f.significand, significand_size, exp, + decimal_point, grouping); + return num_zeros > 0 ? detail::fill_n(it, num_zeros, Char('0')) : it; + }); + } + // 1234e-6 -> 0.001234 + int num_zeros = -exp; + if (significand_size == 0 && specs.precision >= 0 && + specs.precision < num_zeros) { + num_zeros = specs.precision; + } + bool pointy = num_zeros != 0 || significand_size != 0 || specs.alt(); + size += 1 + (pointy ? 1 : 0) + num_zeros; + return write_padded( + out, specs, to_unsigned(size), [&](iterator it) { + if (s != sign::none) *it++ = detail::getsign(s); + *it++ = Char('0'); + if (!pointy) return it; + *it++ = decimal_point; + it = detail::fill_n(it, num_zeros, Char('0')); + return write_significand(it, f.significand, significand_size); + }); +} + +template +FMT_CONSTEXPR20 auto do_write_float(OutputIt out, const DecimalFP& f, + const format_specs& specs, sign s, + int exp_upper, locale_ref loc) -> OutputIt { + Char point = specs.localized() ? detail::decimal_point(loc) : Char('.'); + int significand_size = get_significand_size(f); + int exp = f.exponent + significand_size - 1; + if (specs.type() == presentation_type::fixed || + (specs.type() != presentation_type::exp && + use_fixed(exp, specs.precision > 0 ? specs.precision : exp_upper))) { + return write_fixed(out, f, significand_size, point, specs, + s, loc); + } + + // Write value in the exponential format. + int num_zeros = 0; + long long size = significand_size + (s != sign::none ? 1 : 0); + if (specs.alt()) { + num_zeros = max_of(specs.precision - significand_size, 0); + size += num_zeros; + } else if (significand_size == 1) { + point = Char(); + } + size += (point ? 
1 : 0) + compute_exp_size(exp); + char exp_char = specs.upper() ? 'E' : 'e'; + auto write = [=](reserve_iterator it) { + if (s != sign::none) *it++ = detail::getsign(s); + // Insert a decimal point after the first digit and add an exponent. + it = write_significand(it, f.significand, significand_size, 1, point); + if (num_zeros > 0) it = detail::fill_n(it, num_zeros, Char('0')); + *it++ = Char(exp_char); + return write_exponent(exp, it); + }; + auto usize = to_unsigned(size); + return specs.width > 0 + ? write_padded(out, specs, usize, write) + : base_iterator(out, write(reserve(out, usize))); +} + +template +FMT_CONSTEXPR20 auto write_float(OutputIt out, const DecimalFP& f, + const format_specs& specs, sign s, + int exp_upper, locale_ref loc) -> OutputIt { + if (is_constant_evaluated()) { + return do_write_float>(out, f, specs, s, + exp_upper, loc); + } else { + return do_write_float>(out, f, specs, s, + exp_upper, loc); + } +} + +template constexpr auto isnan(T value) -> bool { + return value != value; // std::isnan doesn't support __float128. +} + +template +struct has_isfinite : std::false_type {}; + +template +struct has_isfinite> + : std::true_type {}; + +template ::value&& has_isfinite::value)> +FMT_CONSTEXPR20 auto isfinite(T value) -> bool { + constexpr T inf = T(std::numeric_limits::infinity()); + if (is_constant_evaluated()) + return !detail::isnan(value) && value < inf && value > -inf; + return std::isfinite(value); +} +template ::value)> +FMT_CONSTEXPR auto isfinite(T value) -> bool { + T inf = T(std::numeric_limits::infinity()); + // std::isfinite doesn't support __float128. 
+ return !detail::isnan(value) && value < inf && value > -inf; +} + +template ::value)> +FMT_INLINE FMT_CONSTEXPR auto signbit(T value) -> bool { + if (is_constant_evaluated()) { +#ifdef __cpp_if_constexpr + if constexpr (std::numeric_limits::is_iec559) { + auto bits = detail::bit_cast(static_cast(value)); + return (bits >> (num_bits() - 1)) != 0; + } +#endif + } + return std::signbit(static_cast(value)); +} + +inline FMT_CONSTEXPR20 void adjust_precision(int& precision, int exp10) { + // Adjust fixed precision by exponent because it is relative to decimal + // point. + if (exp10 > 0 && precision > max_value() - exp10) + FMT_THROW(format_error("number is too big")); + precision += exp10; +} + +class bigint { + private: + // A bigint is a number in the form bigit_[N - 1] ... bigit_[0] * 32^exp_. + using bigit = uint32_t; // A big digit. + using double_bigit = uint64_t; + enum { bigit_bits = num_bits() }; + enum { bigits_capacity = 32 }; + basic_memory_buffer bigits_; + int exp_; + + friend struct formatter; + + FMT_CONSTEXPR auto get_bigit(int i) const -> bigit { + return i >= exp_ && i < num_bigits() ? bigits_[i - exp_] : 0; + } + + FMT_CONSTEXPR void subtract_bigits(int index, bigit other, bigit& borrow) { + auto result = double_bigit(bigits_[index]) - other - borrow; + bigits_[index] = static_cast(result); + borrow = static_cast(result >> (bigit_bits * 2 - 1)); + } + + FMT_CONSTEXPR void remove_leading_zeros() { + int num_bigits = static_cast(bigits_.size()) - 1; + while (num_bigits > 0 && bigits_[num_bigits] == 0) --num_bigits; + bigits_.resize(to_unsigned(num_bigits + 1)); + } + + // Computes *this -= other assuming aligned bigints and *this >= other. 
+ FMT_CONSTEXPR void subtract_aligned(const bigint& other) { + FMT_ASSERT(other.exp_ >= exp_, "unaligned bigints"); + FMT_ASSERT(compare(*this, other) >= 0, ""); + bigit borrow = 0; + int i = other.exp_ - exp_; + for (size_t j = 0, n = other.bigits_.size(); j != n; ++i, ++j) + subtract_bigits(i, other.bigits_[j], borrow); + if (borrow != 0) subtract_bigits(i, 0, borrow); + FMT_ASSERT(borrow == 0, ""); + remove_leading_zeros(); + } + + FMT_CONSTEXPR void multiply(uint32_t value) { + bigit carry = 0; + const double_bigit wide_value = value; + for (size_t i = 0, n = bigits_.size(); i < n; ++i) { + double_bigit result = bigits_[i] * wide_value + carry; + bigits_[i] = static_cast(result); + carry = static_cast(result >> bigit_bits); + } + if (carry != 0) bigits_.push_back(carry); + } + + template ::value || + std::is_same::value)> + FMT_CONSTEXPR void multiply(UInt value) { + using half_uint = + conditional_t::value, uint64_t, uint32_t>; + const int shift = num_bits() - bigit_bits; + const UInt lower = static_cast(value); + const UInt upper = value >> num_bits(); + UInt carry = 0; + for (size_t i = 0, n = bigits_.size(); i < n; ++i) { + UInt result = lower * bigits_[i] + static_cast(carry); + carry = (upper * bigits_[i] << shift) + (result >> bigit_bits) + + (carry >> bigit_bits); + bigits_[i] = static_cast(result); + } + while (carry != 0) { + bigits_.push_back(static_cast(carry)); + carry >>= bigit_bits; + } + } + + template ::value || + std::is_same::value)> + FMT_CONSTEXPR void assign(UInt n) { + size_t num_bigits = 0; + do { + bigits_[num_bigits++] = static_cast(n); + n >>= bigit_bits; + } while (n != 0); + bigits_.resize(num_bigits); + exp_ = 0; + } + + public: + FMT_CONSTEXPR bigint() : exp_(0) {} + explicit bigint(uint64_t n) { assign(n); } + + bigint(const bigint&) = delete; + void operator=(const bigint&) = delete; + + FMT_CONSTEXPR void assign(const bigint& other) { + auto size = other.bigits_.size(); + bigits_.resize(size); + auto data = 
other.bigits_.data(); + copy(data, data + size, bigits_.data()); + exp_ = other.exp_; + } + + template FMT_CONSTEXPR void operator=(Int n) { + FMT_ASSERT(n > 0, ""); + assign(uint64_or_128_t(n)); + } + + FMT_CONSTEXPR auto num_bigits() const -> int { + return static_cast(bigits_.size()) + exp_; + } + + FMT_CONSTEXPR auto operator<<=(int shift) -> bigint& { + FMT_ASSERT(shift >= 0, ""); + exp_ += shift / bigit_bits; + shift %= bigit_bits; + if (shift == 0) return *this; + bigit carry = 0; + for (size_t i = 0, n = bigits_.size(); i < n; ++i) { + bigit c = bigits_[i] >> (bigit_bits - shift); + bigits_[i] = (bigits_[i] << shift) + carry; + carry = c; + } + if (carry != 0) bigits_.push_back(carry); + return *this; + } + + template FMT_CONSTEXPR auto operator*=(Int value) -> bigint& { + FMT_ASSERT(value > 0, ""); + multiply(uint32_or_64_or_128_t(value)); + return *this; + } + + friend FMT_CONSTEXPR auto compare(const bigint& b1, const bigint& b2) -> int { + int num_bigits1 = b1.num_bigits(), num_bigits2 = b2.num_bigits(); + if (num_bigits1 != num_bigits2) return num_bigits1 > num_bigits2 ? 1 : -1; + int i = static_cast(b1.bigits_.size()) - 1; + int j = static_cast(b2.bigits_.size()) - 1; + int end = i - j; + if (end < 0) end = 0; + for (; i >= end; --i, --j) { + bigit b1_bigit = b1.bigits_[i], b2_bigit = b2.bigits_[j]; + if (b1_bigit != b2_bigit) return b1_bigit > b2_bigit ? 1 : -1; + } + if (i != j) return i > j ? 1 : -1; + return 0; + } + + // Returns compare(lhs1 + lhs2, rhs). 
+ friend FMT_CONSTEXPR auto add_compare(const bigint& lhs1, const bigint& lhs2, + const bigint& rhs) -> int { + int max_lhs_bigits = max_of(lhs1.num_bigits(), lhs2.num_bigits()); + int num_rhs_bigits = rhs.num_bigits(); + if (max_lhs_bigits + 1 < num_rhs_bigits) return -1; + if (max_lhs_bigits > num_rhs_bigits) return 1; + double_bigit borrow = 0; + int min_exp = min_of(min_of(lhs1.exp_, lhs2.exp_), rhs.exp_); + for (int i = num_rhs_bigits - 1; i >= min_exp; --i) { + double_bigit sum = double_bigit(lhs1.get_bigit(i)) + lhs2.get_bigit(i); + bigit rhs_bigit = rhs.get_bigit(i); + if (sum > rhs_bigit + borrow) return 1; + borrow = rhs_bigit + borrow - sum; + if (borrow > 1) return -1; + borrow <<= bigit_bits; + } + return borrow != 0 ? -1 : 0; + } + + // Assigns pow(10, exp) to this bigint. + FMT_CONSTEXPR20 void assign_pow10(int exp) { + FMT_ASSERT(exp >= 0, ""); + if (exp == 0) return *this = 1; + int bitmask = 1 << (num_bits() - + countl_zero(static_cast(exp)) - 1); + // pow(10, exp) = pow(5, exp) * pow(2, exp). First compute pow(5, exp) by + // repeated squaring and multiplication. + *this = 5; + bitmask >>= 1; + while (bitmask != 0) { + square(); + if ((exp & bitmask) != 0) *this *= 5; + bitmask >>= 1; + } + *this <<= exp; // Multiply by pow(2, exp) by shifting. + } + + FMT_CONSTEXPR20 void square() { + int num_bigits = static_cast(bigits_.size()); + int num_result_bigits = 2 * num_bigits; + basic_memory_buffer n(std::move(bigits_)); + bigits_.resize(to_unsigned(num_result_bigits)); + auto sum = uint128_t(); + for (int bigit_index = 0; bigit_index < num_bigits; ++bigit_index) { + // Compute bigit at position bigit_index of the result by adding + // cross-product terms n[i] * n[j] such that i + j == bigit_index. + for (int i = 0, j = bigit_index; j >= 0; ++i, --j) { + // Most terms are multiplied twice which can be optimized in the future. + sum += double_bigit(n[i]) * n[j]; + } + bigits_[bigit_index] = static_cast(sum); + sum >>= num_bits(); // Compute the carry. 
+ } + // Do the same for the top half. + for (int bigit_index = num_bigits; bigit_index < num_result_bigits; + ++bigit_index) { + for (int j = num_bigits - 1, i = bigit_index - j; i < num_bigits;) + sum += double_bigit(n[i++]) * n[j--]; + bigits_[bigit_index] = static_cast(sum); + sum >>= num_bits(); + } + remove_leading_zeros(); + exp_ *= 2; + } + + // If this bigint has a bigger exponent than other, adds trailing zero to make + // exponents equal. This simplifies some operations such as subtraction. + FMT_CONSTEXPR void align(const bigint& other) { + int exp_difference = exp_ - other.exp_; + if (exp_difference <= 0) return; + int num_bigits = static_cast(bigits_.size()); + bigits_.resize(to_unsigned(num_bigits + exp_difference)); + for (int i = num_bigits - 1, j = i + exp_difference; i >= 0; --i, --j) + bigits_[j] = bigits_[i]; + fill_n(bigits_.data(), to_unsigned(exp_difference), 0U); + exp_ -= exp_difference; + } + + // Divides this bignum by divisor, assigning the remainder to this and + // returning the quotient. + FMT_CONSTEXPR auto divmod_assign(const bigint& divisor) -> int { + FMT_ASSERT(this != &divisor, ""); + if (compare(*this, divisor) < 0) return 0; + FMT_ASSERT(divisor.bigits_[divisor.bigits_.size() - 1u] != 0, ""); + align(divisor); + int quotient = 0; + do { + subtract_aligned(divisor); + ++quotient; + } while (compare(*this, divisor) >= 0); + return quotient; + } +}; + +// format_dragon flags. +enum dragon { + predecessor_closer = 1, + fixup = 2, // Run fixup to correct exp10 which can be off by one. + fixed = 4, +}; + +// Formats a floating-point number using a variation of the Fixed-Precision +// Positive Floating-Point Printout ((FPP)^2) algorithm by Steele & White: +// https://fmt.dev/papers/p372-steele.pdf. +FMT_CONSTEXPR20 inline void format_dragon(basic_fp value, + unsigned flags, int num_digits, + buffer& buf, int& exp10) { + bigint numerator; // 2 * R in (FPP)^2. + bigint denominator; // 2 * S in (FPP)^2. 
+ // lower and upper are differences between value and corresponding boundaries. + bigint lower; // (M^- in (FPP)^2). + bigint upper_store; // upper's value if different from lower. + bigint* upper = nullptr; // (M^+ in (FPP)^2). + // Shift numerator and denominator by an extra bit or two (if lower boundary + // is closer) to make lower and upper integers. This eliminates multiplication + // by 2 during later computations. + bool is_predecessor_closer = (flags & dragon::predecessor_closer) != 0; + int shift = is_predecessor_closer ? 2 : 1; + if (value.e >= 0) { + numerator = value.f; + numerator <<= value.e + shift; + lower = 1; + lower <<= value.e; + if (is_predecessor_closer) { + upper_store = 1; + upper_store <<= value.e + 1; + upper = &upper_store; + } + denominator.assign_pow10(exp10); + denominator <<= shift; + } else if (exp10 < 0) { + numerator.assign_pow10(-exp10); + lower.assign(numerator); + if (is_predecessor_closer) { + upper_store.assign(numerator); + upper_store <<= 1; + upper = &upper_store; + } + numerator *= value.f; + numerator <<= shift; + denominator = 1; + denominator <<= shift - value.e; + } else { + numerator = value.f; + numerator <<= shift; + denominator.assign_pow10(exp10); + denominator <<= shift - value.e; + lower = 1; + if (is_predecessor_closer) { + upper_store = 1ULL << 1; + upper = &upper_store; + } + } + int even = static_cast((value.f & 1) == 0); + if (!upper) upper = &lower; + bool shortest = num_digits < 0; + if ((flags & dragon::fixup) != 0) { + if (add_compare(numerator, *upper, denominator) + even <= 0) { + --exp10; + numerator *= 10; + if (num_digits < 0) { + lower *= 10; + if (upper != &lower) *upper *= 10; + } + } + if ((flags & dragon::fixed) != 0) adjust_precision(num_digits, exp10 + 1); + } + // Invariant: value == (numerator / denominator) * pow(10, exp10). + if (shortest) { + // Generate the shortest representation. 
+ num_digits = 0; + char* data = buf.data(); + for (;;) { + int digit = numerator.divmod_assign(denominator); + bool low = compare(numerator, lower) - even < 0; // numerator <[=] lower. + // numerator + upper >[=] pow10: + bool high = add_compare(numerator, *upper, denominator) + even > 0; + data[num_digits++] = static_cast('0' + digit); + if (low || high) { + if (!low) { + ++data[num_digits - 1]; + } else if (high) { + int result = add_compare(numerator, numerator, denominator); + // Round half to even. + if (result > 0 || (result == 0 && (digit % 2) != 0)) + ++data[num_digits - 1]; + } + buf.try_resize(to_unsigned(num_digits)); + exp10 -= num_digits - 1; + return; + } + numerator *= 10; + lower *= 10; + if (upper != &lower) *upper *= 10; + } + } + // Generate the given number of digits. + exp10 -= num_digits - 1; + if (num_digits <= 0) { + auto digit = '0'; + if (num_digits == 0) { + denominator *= 10; + digit = add_compare(numerator, numerator, denominator) > 0 ? '1' : '0'; + } + buf.push_back(digit); + return; + } + buf.try_resize(to_unsigned(num_digits)); + for (int i = 0; i < num_digits - 1; ++i) { + int digit = numerator.divmod_assign(denominator); + buf[i] = static_cast('0' + digit); + numerator *= 10; + } + int digit = numerator.divmod_assign(denominator); + auto result = add_compare(numerator, numerator, denominator); + if (result > 0 || (result == 0 && (digit % 2) != 0)) { + if (digit == 9) { + const auto overflow = '0' + 10; + buf[num_digits - 1] = overflow; + // Propagate the carry. + for (int i = num_digits - 1; i > 0 && buf[i] == overflow; --i) { + buf[i] = '0'; + ++buf[i - 1]; + } + if (buf[0] == overflow) { + buf[0] = '1'; + if ((flags & dragon::fixed) != 0) + buf.push_back('0'); + else + ++exp10; + } + return; + } + ++digit; + } + buf[num_digits - 1] = static_cast('0' + digit); +} + +// Formats a floating-point number using the hexfloat format. 
+template ::value)> +FMT_CONSTEXPR20 void format_hexfloat(Float value, format_specs specs, + buffer& buf) { + // float is passed as double to reduce the number of instantiations and to + // simplify implementation. + static_assert(!std::is_same::value, ""); + + using info = dragonbox::float_info; + + // Assume Float is in the format [sign][exponent][significand]. + using carrier_uint = typename info::carrier_uint; + + const auto num_float_significand_bits = detail::num_significand_bits(); + + basic_fp f(value); + f.e += num_float_significand_bits; + if (!has_implicit_bit()) --f.e; + + const auto num_fraction_bits = + num_float_significand_bits + (has_implicit_bit() ? 1 : 0); + const auto num_xdigits = (num_fraction_bits + 3) / 4; + + const auto leading_shift = ((num_xdigits - 1) * 4); + const auto leading_mask = carrier_uint(0xF) << leading_shift; + const auto leading_xdigit = + static_cast((f.f & leading_mask) >> leading_shift); + if (leading_xdigit > 1) f.e -= (32 - countl_zero(leading_xdigit) - 1); + + int print_xdigits = num_xdigits - 1; + if (specs.precision >= 0 && print_xdigits > specs.precision) { + const int shift = ((print_xdigits - specs.precision - 1) * 4); + const auto mask = carrier_uint(0xF) << shift; + const auto v = static_cast((f.f & mask) >> shift); + + if (v >= 8) { + const auto inc = carrier_uint(1) << (shift + 4); + f.f += inc; + f.f &= ~(inc - 1); + } + + // Check long double overflow + if (!has_implicit_bit()) { + const auto implicit_bit = carrier_uint(1) << num_float_significand_bits; + if ((f.f & implicit_bit) == implicit_bit) { + f.f >>= 4; + f.e += 4; + } + } + + print_xdigits = specs.precision; + } + + char xdigits[num_bits() / 4]; + detail::fill_n(xdigits, sizeof(xdigits), '0'); + format_base2e(4, xdigits, f.f, num_xdigits, specs.upper()); + + // Remove zero tail + while (print_xdigits > 0 && xdigits[print_xdigits] == '0') --print_xdigits; + + buf.push_back('0'); + buf.push_back(specs.upper() ? 
'X' : 'x'); + buf.push_back(xdigits[0]); + if (specs.alt() || print_xdigits > 0 || print_xdigits < specs.precision) + buf.push_back('.'); + buf.append(xdigits + 1, xdigits + 1 + print_xdigits); + for (; print_xdigits < specs.precision; ++print_xdigits) buf.push_back('0'); + + buf.push_back(specs.upper() ? 'P' : 'p'); + + uint32_t abs_e; + if (f.e < 0) { + buf.push_back('-'); + abs_e = static_cast(-f.e); + } else { + buf.push_back('+'); + abs_e = static_cast(f.e); + } + format_decimal(appender(buf), abs_e, detail::count_digits(abs_e)); +} + +template ::value)> +FMT_CONSTEXPR20 void format_hexfloat(Float value, format_specs specs, + buffer& buf) { + format_hexfloat(static_cast(value), specs, buf); +} + +constexpr auto fractional_part_rounding_thresholds(int index) -> uint32_t { + // For checking rounding thresholds. + // The kth entry is chosen to be the smallest integer such that the + // upper 32-bits of 10^(k+1) times it is strictly bigger than 5 * 10^k. + // It is equal to ceil(2^31 + 2^32/10^(k + 1)). + // These are stored in a string literal because we cannot have static arrays + // in constexpr functions and non-static ones are poorly optimized. + return U"\x9999999a\x828f5c29\x80418938\x80068db9\x8000a7c6\x800010c7" + U"\x800001ae\x8000002b"[index]; +} + +template +FMT_CONSTEXPR20 auto format_float(Float value, int precision, + const format_specs& specs, bool binary32, + buffer& buf) -> int { + // float is passed as double to reduce the number of instantiations. 
+ static_assert(!std::is_same::value, ""); + auto converted_value = convert_float(value); + + const bool fixed = specs.type() == presentation_type::fixed; + if (value == 0) { + if (precision <= 0 || !fixed) { + buf.push_back('0'); + return 0; + } + buf.try_resize(to_unsigned(precision)); + fill_n(buf.data(), precision, '0'); + return -precision; + } + + int exp = 0; + bool use_dragon = true; + unsigned dragon_flags = 0; + if (!is_fast_float() || is_constant_evaluated()) { + const auto inv_log2_10 = 0.3010299956639812; // 1 / log2(10) + using info = dragonbox::float_info; + const auto f = basic_fp(converted_value); + // Compute exp, an approximate power of 10, such that + // 10^(exp - 1) <= value < 10^exp or 10^exp <= value < 10^(exp + 1). + // This is based on log10(value) == log2(value) / log2(10) and approximation + // of log2(value) by e + num_fraction_bits idea from double-conversion. + auto e = (f.e + count_digits<1>(f.f) - 1) * inv_log2_10 - 1e-10; + exp = static_cast(e); + if (e > exp) ++exp; // Compute ceil. + dragon_flags = dragon::fixup; + } else { + // Extract significand bits and exponent bits. + using info = dragonbox::float_info; + auto br = bit_cast(static_cast(value)); + + const uint64_t significand_mask = + (static_cast(1) << num_significand_bits()) - 1; + uint64_t significand = (br & significand_mask); + int exponent = static_cast((br & exponent_mask()) >> + num_significand_bits()); + + if (exponent != 0) { // Check if normal. + exponent -= exponent_bias() + num_significand_bits(); + significand |= + (static_cast(1) << num_significand_bits()); + significand <<= 1; + } else { + // Normalize subnormal inputs. 
+ FMT_ASSERT(significand != 0, "zeros should not appear here"); + int shift = countl_zero(significand); + FMT_ASSERT(shift >= num_bits() - num_significand_bits(), + ""); + shift -= (num_bits() - num_significand_bits() - 2); + exponent = (std::numeric_limits::min_exponent - + num_significand_bits()) - + shift; + significand <<= shift; + } + + // Compute the first several nonzero decimal significand digits. + // We call the number we get the first segment. + const int k = info::kappa - dragonbox::floor_log10_pow2(exponent); + exp = -k; + const int beta = exponent + dragonbox::floor_log2_pow10(k); + uint64_t first_segment; + bool has_more_segments; + int digits_in_the_first_segment; + { + const auto r = dragonbox::umul192_upper128( + significand << beta, dragonbox::get_cached_power(k)); + first_segment = r.high(); + has_more_segments = r.low() != 0; + + // The first segment can have 18 ~ 19 digits. + if (first_segment >= 1000000000000000000ULL) { + digits_in_the_first_segment = 19; + } else { + // When it is of 18-digits, we align it to 19-digits by adding a bogus + // zero at the end. + digits_in_the_first_segment = 18; + first_segment *= 10; + } + } + + // Compute the actual number of decimal digits to print. + if (fixed) adjust_precision(precision, exp + digits_in_the_first_segment); + + // Use Dragon4 only when there might be not enough digits in the first + // segment. + if (digits_in_the_first_segment > precision) { + use_dragon = false; + + if (precision <= 0) { + exp += digits_in_the_first_segment; + + if (precision < 0) { + // Nothing to do, since all we have are just leading zeros. + buf.try_resize(0); + } else { + // We may need to round-up. 
+ buf.try_resize(1); + if ((first_segment | static_cast(has_more_segments)) > + 5000000000000000000ULL) { + buf[0] = '1'; + } else { + buf[0] = '0'; + } + } + } // precision <= 0 + else { + exp += digits_in_the_first_segment - precision; + + // When precision > 0, we divide the first segment into three + // subsegments, each with 9, 9, and 0 ~ 1 digits so that each fits + // in 32-bits which usually allows faster calculation than in + // 64-bits. Since some compiler (e.g. MSVC) doesn't know how to optimize + // division-by-constant for large 64-bit divisors, we do it here + // manually. The magic number 7922816251426433760 below is equal to + // ceil(2^(64+32) / 10^10). + const uint32_t first_subsegment = static_cast( + dragonbox::umul128_upper64(first_segment, 7922816251426433760ULL) >> + 32); + const uint64_t second_third_subsegments = + first_segment - first_subsegment * 10000000000ULL; + + uint64_t prod; + uint32_t digits; + bool should_round_up; + int number_of_digits_to_print = min_of(precision, 9); + + // Print a 9-digits subsegment, either the first or the second. + auto print_subsegment = [&](uint32_t subsegment, char* buffer) { + int number_of_digits_printed = 0; + + // If we want to print an odd number of digits from the subsegment, + if ((number_of_digits_to_print & 1) != 0) { + // Convert to 64-bit fixed-point fractional form with 1-digit + // integer part. The magic number 720575941 is a good enough + // approximation of 2^(32 + 24) / 10^8; see + // https://jk-jeon.github.io/posts/2022/12/fixed-precision-formatting/#fixed-length-case + // for details. + prod = ((subsegment * static_cast(720575941)) >> 24) + 1; + digits = static_cast(prod >> 32); + *buffer = static_cast('0' + digits); + number_of_digits_printed++; + } + // If we want to print an even number of digits from the + // first_subsegment, + else { + // Convert to 64-bit fixed-point fractional form with 2-digits + // integer part. 
The magic number 450359963 is a good enough + // approximation of 2^(32 + 20) / 10^7; see + // https://jk-jeon.github.io/posts/2022/12/fixed-precision-formatting/#fixed-length-case + // for details. + prod = ((subsegment * static_cast(450359963)) >> 20) + 1; + digits = static_cast(prod >> 32); + write2digits(buffer, digits); + number_of_digits_printed += 2; + } + + // Print all digit pairs. + while (number_of_digits_printed < number_of_digits_to_print) { + prod = static_cast(prod) * static_cast(100); + digits = static_cast(prod >> 32); + write2digits(buffer + number_of_digits_printed, digits); + number_of_digits_printed += 2; + } + }; + + // Print first subsegment. + print_subsegment(first_subsegment, buf.data()); + + // Perform rounding if the first subsegment is the last subsegment to + // print. + if (precision <= 9) { + // Rounding inside the subsegment. + // We round-up if: + // - either the fractional part is strictly larger than 1/2, or + // - the fractional part is exactly 1/2 and the last digit is odd. + // We rely on the following observations: + // - If fractional_part >= threshold, then the fractional part is + // strictly larger than 1/2. + // - If the MSB of fractional_part is set, then the fractional part + // must be at least 1/2. + // - When the MSB of fractional_part is set, either + // second_third_subsegments being nonzero or has_more_segments + // being true means there are further digits not printed, so the + // fractional part is strictly larger than 1/2. + if (precision < 9) { + uint32_t fractional_part = static_cast(prod); + should_round_up = + fractional_part >= fractional_part_rounding_thresholds( + 8 - number_of_digits_to_print) || + ((fractional_part >> 31) & + ((digits & 1) | (second_third_subsegments != 0) | + has_more_segments)) != 0; + } + // Rounding at the subsegment boundary. 
+ // In this case, the fractional part is at least 1/2 if and only if + // second_third_subsegments >= 5000000000ULL, and is strictly larger + // than 1/2 if we further have either second_third_subsegments > + // 5000000000ULL or has_more_segments == true. + else { + should_round_up = second_third_subsegments > 5000000000ULL || + (second_third_subsegments == 5000000000ULL && + ((digits & 1) != 0 || has_more_segments)); + } + } + // Otherwise, print the second subsegment. + else { + // Compilers are not aware of how to leverage the maximum value of + // second_third_subsegments to find out a better magic number which + // allows us to eliminate an additional shift. 1844674407370955162 = + // ceil(2^64/10) < ceil(2^64*(10^9/(10^10 - 1))). + const uint32_t second_subsegment = + static_cast(dragonbox::umul128_upper64( + second_third_subsegments, 1844674407370955162ULL)); + const uint32_t third_subsegment = + static_cast(second_third_subsegments) - + second_subsegment * 10; + + number_of_digits_to_print = precision - 9; + print_subsegment(second_subsegment, buf.data() + 9); + + // Rounding inside the subsegment. + if (precision < 18) { + // The condition third_subsegment != 0 implies that the segment was + // of 19 digits, so in this case the third segment should be + // consisting of a genuine digit from the input. + uint32_t fractional_part = static_cast(prod); + should_round_up = + fractional_part >= fractional_part_rounding_thresholds( + 8 - number_of_digits_to_print) || + ((fractional_part >> 31) & + ((digits & 1) | (third_subsegment != 0) | + has_more_segments)) != 0; + } + // Rounding at the subsegment boundary. + else { + // In this case, the segment must be of 19 digits, thus + // the third subsegment should be consisting of a genuine digit from + // the input. + should_round_up = third_subsegment > 5 || + (third_subsegment == 5 && + ((digits & 1) != 0 || has_more_segments)); + } + } + + // Round-up if necessary. 
+ if (should_round_up) { + ++buf[precision - 1]; + for (int i = precision - 1; i > 0 && buf[i] > '9'; --i) { + buf[i] = '0'; + ++buf[i - 1]; + } + if (buf[0] > '9') { + buf[0] = '1'; + if (fixed) + buf[precision++] = '0'; + else + ++exp; + } + } + buf.try_resize(to_unsigned(precision)); + } + } // if (digits_in_the_first_segment > precision) + else { + // Adjust the exponent for its use in Dragon4. + exp += digits_in_the_first_segment - 1; + } + } + if (use_dragon) { + auto f = basic_fp(); + bool is_predecessor_closer = binary32 ? f.assign(static_cast(value)) + : f.assign(converted_value); + if (is_predecessor_closer) dragon_flags |= dragon::predecessor_closer; + if (fixed) dragon_flags |= dragon::fixed; + // Limit precision to the maximum possible number of significant digits in + // an IEEE754 double because we don't need to generate zeros. + const int max_double_digits = 767; + if (precision > max_double_digits) precision = max_double_digits; + format_dragon(f, dragon_flags, precision, buf, exp); + } + if (!fixed && !specs.alt()) { + // Remove trailing zeros. + auto num_digits = buf.size(); + while (num_digits > 0 && buf[num_digits - 1] == '0') { + --num_digits; + ++exp; + } + buf.try_resize(num_digits); + } + return exp; +} + +template ::value)> +FMT_CONSTEXPR20 auto write(OutputIt out, T value, format_specs specs, + locale_ref loc = {}) -> OutputIt { + if (specs.localized() && write_loc(out, value, specs, loc)) return out; + + // Use signbit because value < 0 is false for NaN. + sign s = detail::signbit(value) ? 
sign::minus : specs.sign(); + + if (!detail::isfinite(value)) + return write_nonfinite(out, detail::isnan(value), specs, s); + + if (specs.align() == align::numeric && s != sign::none) { + *out++ = detail::getsign(s); + s = sign::none; + if (specs.width != 0) --specs.width; + } + + const int exp_upper = detail::exp_upper(); + int precision = specs.precision; + if (precision < 0) { + if (specs.type() != presentation_type::none) { + precision = 6; + } else if (is_fast_float::value && !is_constant_evaluated()) { + // Use Dragonbox for the shortest format. + auto dec = dragonbox::to_decimal(static_cast>(value)); + return write_float(out, dec, specs, s, exp_upper, loc); + } + } + + memory_buffer buffer; + if (specs.type() == presentation_type::hexfloat) { + if (s != sign::none) buffer.push_back(detail::getsign(s)); + format_hexfloat(convert_float(value), specs, buffer); + return write_bytes(out, {buffer.data(), buffer.size()}, + specs); + } + + if (specs.type() == presentation_type::exp) { + if (precision == max_value()) + report_error("number is too big"); + else + ++precision; + if (specs.precision != 0) specs.set_alt(); + } else if (specs.type() == presentation_type::fixed) { + if (specs.precision != 0) specs.set_alt(); + } else if (precision == 0) { + precision = 1; + } + int exp = format_float(convert_float(value), precision, specs, + std::is_same(), buffer); + + specs.precision = precision; + auto f = big_decimal_fp{buffer.data(), static_cast(buffer.size()), exp}; + return write_float(out, f, specs, s, exp_upper, loc); +} + +template ::value)> +FMT_CONSTEXPR20 auto write(OutputIt out, T value) -> OutputIt { + if (is_constant_evaluated()) return write(out, value, format_specs()); + + auto s = detail::signbit(value) ? 
sign::minus : sign::none; + auto mask = exponent_mask>(); + if ((bit_cast(value) & mask) == mask) + return write_nonfinite(out, std::isnan(value), {}, s); + + auto dec = dragonbox::to_decimal(static_cast>(value)); + auto significand = dec.significand; + int significand_size = count_digits(significand); + int exponent = dec.exponent + significand_size - 1; + if (use_fixed(exponent, detail::exp_upper())) { + return write_fixed>( + out, dec, significand_size, Char('.'), {}, s); + } + + // Write value in the exponential format. + const char* prefix = "e+"; + int abs_exponent = exponent; + if (exponent < 0) { + abs_exponent = -exponent; + prefix = "e-"; + } + auto has_decimal_point = significand_size != 1; + size_t size = std::is_pointer::value + ? 0u + : to_unsigned((s != sign::none ? 1 : 0) + significand_size + + (has_decimal_point ? 1 : 0) + + (abs_exponent >= 100 ? 5 : 4)); + if (auto ptr = to_pointer(out, size)) { + if (s != sign::none) *ptr++ = Char('-'); + if (has_decimal_point) { + auto begin = ptr; + ptr = format_decimal(ptr, significand, significand_size + 1); + *begin = begin[1]; + begin[1] = '.'; + } else { + *ptr++ = static_cast('0' + significand); + } + if (std::is_same::value) { + memcpy(ptr, prefix, 2); + ptr += 2; + } else { + *ptr++ = prefix[0]; + *ptr++ = prefix[1]; + } + if (abs_exponent >= 100) { + *ptr++ = static_cast('0' + abs_exponent / 100); + abs_exponent %= 100; + } + write2digits(ptr, static_cast(abs_exponent)); + return select::value>(ptr + 2, out); + } + auto it = reserve(out, size); + if (s != sign::none) *it++ = Char('-'); + // Insert a decimal point after the first digit and add an exponent. + it = write_significand(it, significand, significand_size, 1, + has_decimal_point ? 
Char('.') : Char()); + *it++ = Char('e'); + it = write_exponent(exponent, it); + return base_iterator(out, it); +} + +template ::value && + !is_fast_float::value)> +inline auto write(OutputIt out, T value) -> OutputIt { + return write(out, value, {}); +} + +template +auto write(OutputIt out, monostate, format_specs = {}, locale_ref = {}) + -> OutputIt { + FMT_ASSERT(false, ""); + return out; +} + +template +FMT_CONSTEXPR auto write(OutputIt out, basic_string_view value) + -> OutputIt { + return copy_noinline(value.begin(), value.end(), out); +} + +template ::value)> +constexpr auto write(OutputIt out, const T& value) -> OutputIt { + return write(out, to_string_view(value)); +} + +// FMT_ENABLE_IF() condition separated to workaround an MSVC bug. +template < + typename Char, typename OutputIt, typename T, + bool check = std::is_enum::value && !std::is_same::value && + mapped_type_constant::value != type::custom_type, + FMT_ENABLE_IF(check)> +FMT_CONSTEXPR auto write(OutputIt out, T value) -> OutputIt { + return write(out, static_cast>(value)); +} + +template ::value)> +FMT_CONSTEXPR auto write(OutputIt out, T value, const format_specs& specs = {}, + locale_ref = {}) -> OutputIt { + return specs.type() != presentation_type::none && + specs.type() != presentation_type::string + ? write(out, value ? 1 : 0, specs, {}) + : write_bytes(out, value ? 
"true" : "false", specs); +} + +template +FMT_CONSTEXPR auto write(OutputIt out, Char value) -> OutputIt { + auto it = reserve(out, 1); + *it++ = value; + return base_iterator(out, it); +} + +template +FMT_CONSTEXPR20 auto write(OutputIt out, const Char* value) -> OutputIt { + if (value) return write(out, basic_string_view(value)); + report_error("string pointer is null"); + return out; +} + +template ::value)> +auto write(OutputIt out, const T* value, const format_specs& specs = {}, + locale_ref = {}) -> OutputIt { + return write_ptr(out, bit_cast(value), &specs); +} + +template ::value == + type::custom_type && + !std::is_fundamental::value)> +FMT_CONSTEXPR auto write(OutputIt out, const T& value) -> OutputIt { + auto f = formatter(); + auto parse_ctx = parse_context({}); + f.parse(parse_ctx); + auto ctx = basic_format_context(out, {}, {}); + return f.format(value, ctx); +} + +template +using is_builtin = + bool_constant::value || FMT_BUILTIN_TYPES>; + +// An argument visitor that formats the argument and writes it via the output +// iterator. It's a class and not a generic lambda for compatibility with C++11. +template struct default_arg_formatter { + using context = buffered_context; + + basic_appender out; + + void operator()(monostate) { report_error("argument not found"); } + + template ::value)> + void operator()(T value) { + write(out, value); + } + + template ::value)> + void operator()(T) { + FMT_ASSERT(false, ""); + } + + void operator()(typename basic_format_arg::handle h) { + // Use a null locale since the default format must be unlocalized. 
+ auto parse_ctx = parse_context({}); + auto format_ctx = context(out, {}, {}); + h.format(parse_ctx, format_ctx); + } +}; + +template struct arg_formatter { + basic_appender out; + const format_specs& specs; + FMT_NO_UNIQUE_ADDRESS locale_ref locale; + + template ::value)> + FMT_CONSTEXPR FMT_INLINE void operator()(T value) { + detail::write(out, value, specs, locale); + } + + template ::value)> + void operator()(T) { + FMT_ASSERT(false, ""); + } + + void operator()(typename basic_format_arg>::handle) { + // User-defined types are handled separately because they require access + // to the parse context. + } +}; + +struct dynamic_spec_getter { + template ::value)> + FMT_CONSTEXPR auto operator()(T value) -> unsigned long long { + return is_negative(value) ? ~0ull : static_cast(value); + } + + template ::value)> + FMT_CONSTEXPR auto operator()(T) -> unsigned long long { + report_error("width/precision is not integer"); + return 0; + } +}; + +template +FMT_CONSTEXPR void handle_dynamic_spec( + arg_id_kind kind, int& value, + const arg_ref& ref, Context& ctx) { + if (kind == arg_id_kind::none) return; + auto arg = + kind == arg_id_kind::index ? 
ctx.arg(ref.index) : ctx.arg(ref.name); + if (!arg) report_error("argument not found"); + unsigned long long result = arg.visit(dynamic_spec_getter()); + if (result > to_unsigned(max_value())) + report_error("width/precision is out of range"); + value = static_cast(result); +} + +#if FMT_USE_NONTYPE_TEMPLATE_ARGS +template Str> +struct static_named_arg : view { + static constexpr auto name = Str.data; + + const T& value; + static_named_arg(const T& v) : value(v) {} +}; + +template Str> +struct is_named_arg> : std::true_type {}; + +template Str> +struct is_static_named_arg> : std::true_type { +}; + +template Str> +struct udl_arg { + template auto operator=(T&& value) const { + return static_named_arg(std::forward(value)); + } +}; +#else +template struct udl_arg { + const Char* str; + + template auto operator=(T&& value) const -> named_arg { + return {str, std::forward(value)}; + } +}; +#endif // FMT_USE_NONTYPE_TEMPLATE_ARGS + +template struct format_handler { + parse_context parse_ctx; + buffered_context ctx; + + void on_text(const Char* begin, const Char* end) { + copy_noinline(begin, end, ctx.out()); + } + + FMT_CONSTEXPR auto on_arg_id() -> int { return parse_ctx.next_arg_id(); } + FMT_CONSTEXPR auto on_arg_id(int id) -> int { + parse_ctx.check_arg_id(id); + return id; + } + FMT_CONSTEXPR auto on_arg_id(basic_string_view id) -> int { + parse_ctx.check_arg_id(id); + int arg_id = ctx.arg_id(id); + if (arg_id < 0) report_error("argument not found"); + return arg_id; + } + + FMT_INLINE void on_replacement_field(int id, const Char*) { + ctx.arg(id).visit(default_arg_formatter{ctx.out()}); + } + + auto on_format_specs(int id, const Char* begin, const Char* end) + -> const Char* { + auto arg = ctx.arg(id); + if (!arg) report_error("argument not found"); + // Not using a visitor for custom types gives better codegen. 
+ if (arg.format_custom(begin, parse_ctx, ctx)) return parse_ctx.begin(); + + auto specs = dynamic_format_specs(); + begin = parse_format_specs(begin, end, specs, parse_ctx, arg.type()); + if (specs.dynamic()) { + handle_dynamic_spec(specs.dynamic_width(), specs.width, specs.width_ref, + ctx); + handle_dynamic_spec(specs.dynamic_precision(), specs.precision, + specs.precision_ref, ctx); + } + + arg.visit(arg_formatter{ctx.out(), specs, ctx.locale()}); + return begin; + } + + FMT_NORETURN void on_error(const char* message) { report_error(message); } +}; + +// It is used in format-inl.h and os.cc. +using format_func = void (*)(detail::buffer&, int, const char*); +FMT_API void do_report_error(format_func func, int error_code, + const char* message) noexcept; + +FMT_API void format_error_code(buffer& out, int error_code, + string_view message) noexcept; + +template +template +FMT_CONSTEXPR auto native_formatter::format( + const T& val, FormatContext& ctx) const -> decltype(ctx.out()) { + if (!specs_.dynamic()) + return write(ctx.out(), val, specs_, ctx.locale()); + auto specs = format_specs(specs_); + handle_dynamic_spec(specs.dynamic_width(), specs.width, specs_.width_ref, + ctx); + handle_dynamic_spec(specs.dynamic_precision(), specs.precision, + specs_.precision_ref, ctx); + return write(ctx.out(), val, specs, ctx.locale()); +} +} // namespace detail + +FMT_BEGIN_EXPORT + +// A generic formatting context with custom output iterator and character +// (code unit) support. Char is the format string code unit type which can be +// different from OutputIt::value_type. 
+template class generic_context { + private: + OutputIt out_; + basic_format_args args_; + locale_ref loc_; + + public: + using char_type = Char; + using iterator = OutputIt; + enum { builtin_types = FMT_BUILTIN_TYPES }; + + constexpr generic_context(OutputIt out, + basic_format_args args, + locale_ref loc = {}) + : out_(out), args_(args), loc_(loc) {} + generic_context(generic_context&&) = default; + generic_context(const generic_context&) = delete; + void operator=(const generic_context&) = delete; + + constexpr auto arg(int id) const -> basic_format_arg { + return args_.get(id); + } + auto arg(basic_string_view name) const + -> basic_format_arg { + return args_.get(name); + } + constexpr auto arg_id(basic_string_view name) const -> int { + return args_.get_id(name); + } + + constexpr auto out() const -> iterator { return out_; } + + void advance_to(iterator it) { + if (!detail::is_back_insert_iterator()) out_ = it; + } + + constexpr auto locale() const -> locale_ref { return loc_; } +}; + +class loc_value { + private: + basic_format_arg value_; + + public: + template ::value)> + loc_value(T value) : value_(value) {} + + template ::value)> + loc_value(T) {} + + template auto visit(Visitor&& vis) -> decltype(vis(0)) { + return value_.visit(vis); + } +}; + +// A locale facet that formats values in UTF-8. +// It is parameterized on the locale to avoid the heavy include. 
+template class format_facet : public Locale::facet { + private: + std::string separator_; + std::string grouping_; + std::string decimal_point_; + + protected: + virtual auto do_put(appender out, loc_value val, + const format_specs& specs) const -> bool; + + public: + static FMT_API typename Locale::id id; + + explicit format_facet(Locale& loc); + explicit format_facet(string_view sep = "", std::string grouping = "\3", + std::string decimal_point = ".") + : separator_(sep.data(), sep.size()), + grouping_(grouping), + decimal_point_(decimal_point) {} + + auto put(appender out, loc_value val, const format_specs& specs) const + -> bool { + return do_put(out, val, specs); + } +}; + +#define FMT_FORMAT_AS(Type, Base) \ + template \ + struct formatter : formatter { \ + template \ + FMT_CONSTEXPR auto format(Type value, FormatContext& ctx) const \ + -> decltype(ctx.out()) { \ + return formatter::format(value, ctx); \ + } \ + } + +FMT_FORMAT_AS(signed char, int); +FMT_FORMAT_AS(unsigned char, unsigned); +FMT_FORMAT_AS(short, int); +FMT_FORMAT_AS(unsigned short, unsigned); +FMT_FORMAT_AS(long, detail::long_type); +FMT_FORMAT_AS(unsigned long, detail::ulong_type); +FMT_FORMAT_AS(Char*, const Char*); +FMT_FORMAT_AS(detail::std_string_view, basic_string_view); +FMT_FORMAT_AS(std::nullptr_t, const void*); +FMT_FORMAT_AS(void*, const void*); + +template +struct formatter : formatter, Char> {}; + +template +class formatter, Char> + : public formatter, Char> {}; + +template +struct formatter, Char> : formatter {}; +template +struct formatter, Char> + : formatter {}; + +template +struct formatter + : detail::native_formatter {}; + +template +struct formatter>> + : formatter, Char> { + template + FMT_CONSTEXPR auto format(const T& value, FormatContext& ctx) const + -> decltype(ctx.out()) { + auto&& val = format_as(value); // Make an lvalue reference for format. + return formatter, Char>::format(val, ctx); + } +}; + +/** + * Converts `p` to `const void*` for pointer formatting. 
+ * + * **Example**: + * + * auto s = fmt::format("{}", fmt::ptr(p)); + */ +template auto ptr(T p) -> const void* { + static_assert(std::is_pointer::value, "fmt::ptr used with non-pointer"); + return detail::bit_cast(p); +} + +/** + * Converts `e` to the underlying type. + * + * **Example**: + * + * enum class color { red, green, blue }; + * auto s = fmt::format("{}", fmt::underlying(color::red)); // s == "0" + */ +template +constexpr auto underlying(Enum e) noexcept -> underlying_t { + return static_cast>(e); +} + +namespace enums { +template ::value)> +constexpr auto format_as(Enum e) noexcept -> underlying_t { + return static_cast>(e); +} +} // namespace enums + +#ifdef __cpp_lib_byte +template +struct formatter : formatter { + static auto format_as(std::byte b) -> unsigned char { + return static_cast(b); + } + template + auto format(std::byte b, Context& ctx) const -> decltype(ctx.out()) { + return formatter::format(format_as(b), ctx); + } +}; +#endif + +struct bytes { + string_view data; + + inline explicit bytes(string_view s) : data(s) {} +}; + +template <> struct formatter { + private: + detail::dynamic_format_specs<> specs_; + + public: + FMT_CONSTEXPR auto parse(parse_context<>& ctx) -> const char* { + return parse_format_specs(ctx.begin(), ctx.end(), specs_, ctx, + detail::type::string_type); + } + + template + auto format(bytes b, FormatContext& ctx) const -> decltype(ctx.out()) { + auto specs = specs_; + detail::handle_dynamic_spec(specs.dynamic_width(), specs.width, + specs.width_ref, ctx); + detail::handle_dynamic_spec(specs.dynamic_precision(), specs.precision, + specs.precision_ref, ctx); + return detail::write_bytes(ctx.out(), b.data, specs); + } +}; + +// group_digits_view is not derived from view because it copies the argument. +template struct group_digits_view { + T value; +}; + +/** + * Returns a view that formats an integer value using ',' as a + * locale-independent thousands separator. 
+ * + * **Example**: + * + * fmt::print("{}", fmt::group_digits(12345)); + * // Output: "12,345" + */ +template auto group_digits(T value) -> group_digits_view { + return {value}; +} + +template struct formatter> : formatter { + private: + detail::dynamic_format_specs<> specs_; + + public: + FMT_CONSTEXPR auto parse(parse_context<>& ctx) -> const char* { + return parse_format_specs(ctx.begin(), ctx.end(), specs_, ctx, + detail::type::int_type); + } + + template + auto format(group_digits_view view, FormatContext& ctx) const + -> decltype(ctx.out()) { + auto specs = specs_; + detail::handle_dynamic_spec(specs.dynamic_width(), specs.width, + specs.width_ref, ctx); + detail::handle_dynamic_spec(specs.dynamic_precision(), specs.precision, + specs.precision_ref, ctx); + auto arg = detail::make_write_int_arg(view.value, specs.sign()); + return detail::write_int( + ctx.out(), static_cast>(arg.abs_value), + arg.prefix, specs, detail::digit_grouping("\3", ",")); + } +}; + +template struct nested_view { + const formatter* fmt; + const T* value; +}; + +template +struct formatter, Char> { + FMT_CONSTEXPR auto parse(parse_context& ctx) -> const Char* { + return ctx.begin(); + } + template + auto format(nested_view view, FormatContext& ctx) const + -> decltype(ctx.out()) { + return view.fmt->format(*view.value, ctx); + } +}; + +template struct nested_formatter { + private: + basic_specs specs_; + int width_; + formatter formatter_; + + public: + constexpr nested_formatter() : width_(0) {} + + FMT_CONSTEXPR auto parse(parse_context& ctx) -> const Char* { + auto it = ctx.begin(), end = ctx.end(); + if (it == end) return it; + auto specs = format_specs(); + it = detail::parse_align(it, end, specs); + specs_ = specs; + Char c = *it; + auto width_ref = detail::arg_ref(); + if ((c >= '0' && c <= '9') || c == '{') { + it = detail::parse_width(it, end, specs, width_ref, ctx); + width_ = specs.width; + } + ctx.advance_to(it); + return formatter_.parse(ctx); + } + + template + auto 
write_padded(FormatContext& ctx, F write) const -> decltype(ctx.out()) { + if (width_ == 0) return write(ctx.out()); + auto buf = basic_memory_buffer(); + write(basic_appender(buf)); + auto specs = format_specs(); + specs.width = width_; + specs.copy_fill_from(specs_); + specs.set_align(specs_.align()); + return detail::write( + ctx.out(), basic_string_view(buf.data(), buf.size()), specs); + } + + auto nested(const T& value) const -> nested_view { + return nested_view{&formatter_, &value}; + } +}; + +inline namespace literals { +#if FMT_USE_NONTYPE_TEMPLATE_ARGS +template constexpr auto operator""_a() { + using char_t = remove_cvref_t; + return detail::udl_arg(); +} +#else +/** + * User-defined literal equivalent of `fmt::arg`. + * + * **Example**: + * + * using namespace fmt::literals; + * fmt::print("The answer is {answer}.", "answer"_a=42); + */ +constexpr auto operator""_a(const char* s, size_t) -> detail::udl_arg { + return {s}; +} +#endif // FMT_USE_NONTYPE_TEMPLATE_ARGS +} // namespace literals + +/// A fast integer formatter. +class format_int { + private: + // Buffer should be large enough to hold all digits (digits10 + 1), + // a sign and a null character. 
+ enum { buffer_size = std::numeric_limits::digits10 + 3 }; + mutable char buffer_[buffer_size]; + char* str_; + + template + FMT_CONSTEXPR20 auto format_unsigned(UInt value) -> char* { + auto n = static_cast>(value); + return detail::do_format_decimal(buffer_, n, buffer_size - 1); + } + + template + FMT_CONSTEXPR20 auto format_signed(Int value) -> char* { + auto abs_value = static_cast>(value); + bool negative = value < 0; + if (negative) abs_value = 0 - abs_value; + auto begin = format_unsigned(abs_value); + if (negative) *--begin = '-'; + return begin; + } + + public: + FMT_CONSTEXPR20 explicit format_int(int value) : str_(format_signed(value)) {} + FMT_CONSTEXPR20 explicit format_int(long value) + : str_(format_signed(value)) {} + FMT_CONSTEXPR20 explicit format_int(long long value) + : str_(format_signed(value)) {} + FMT_CONSTEXPR20 explicit format_int(unsigned value) + : str_(format_unsigned(value)) {} + FMT_CONSTEXPR20 explicit format_int(unsigned long value) + : str_(format_unsigned(value)) {} + FMT_CONSTEXPR20 explicit format_int(unsigned long long value) + : str_(format_unsigned(value)) {} + + /// Returns the number of characters written to the output buffer. + FMT_CONSTEXPR20 auto size() const -> size_t { + return detail::to_unsigned(buffer_ - str_ + buffer_size - 1); + } + + /// Returns a pointer to the output buffer content. No terminating null + /// character is appended. + FMT_CONSTEXPR20 auto data() const -> const char* { return str_; } + + /// Returns a pointer to the output buffer content with terminating null + /// character appended. + FMT_CONSTEXPR20 auto c_str() const -> const char* { + buffer_[buffer_size - 1] = '\0'; + return str_; + } + + /// Returns the content of the output buffer as an `std::string`. 
+ inline auto str() const -> std::string { return {str_, size()}; } +}; + +#if FMT_CLANG_ANALYZER +# define FMT_STRING_IMPL(s, base) s +#else +# define FMT_STRING_IMPL(s, base) \ + [] { \ + /* Use the hidden visibility as a workaround for a GCC bug (#1973). */ \ + /* Use a macro-like name to avoid shadowing warnings. */ \ + struct FMT_VISIBILITY("hidden") FMT_COMPILE_STRING : base { \ + using char_type = fmt::remove_cvref_t; \ + constexpr explicit operator fmt::basic_string_view() \ + const { \ + return fmt::detail::compile_string_to_view(s); \ + } \ + }; \ + using FMT_STRING_VIEW = \ + fmt::basic_string_view; \ + fmt::detail::ignore_unused(FMT_STRING_VIEW(FMT_COMPILE_STRING())); \ + return FMT_COMPILE_STRING(); \ + }() +#endif // FMT_CLANG_ANALYZER + +/** + * Constructs a legacy compile-time format string from a string literal `s`. + * + * **Example**: + * + * // A compile-time error because 'd' is an invalid specifier for strings. + * std::string s = fmt::format(FMT_STRING("{:d}"), "foo"); + */ +#define FMT_STRING(s) FMT_STRING_IMPL(s, fmt::detail::compile_string) + +FMT_API auto vsystem_error(int error_code, string_view fmt, format_args args) + -> std::system_error; + +/** + * Constructs `std::system_error` with a message formatted with + * `fmt::format(fmt, args...)`. + * `error_code` is a system error code as given by `errno`. + * + * **Example**: + * + * // This throws std::system_error with the description + * // cannot open file 'madeup': No such file or directory + * // or similar (system message may vary). + * const char* filename = "madeup"; + * FILE* file = fopen(filename, "r"); + * if (!file) + * throw fmt::system_error(errno, "cannot open file '{}'", filename); + */ +template +auto system_error(int error_code, format_string fmt, T&&... 
args) + -> std::system_error { + return vsystem_error(error_code, fmt.str, vargs{{args...}}); +} + +/** + * Formats an error message for an error returned by an operating system or a + * language runtime, for example a file opening error, and writes it to `out`. + * The format is the same as the one used by `std::system_error(ec, message)` + * where `ec` is `std::error_code(error_code, std::generic_category())`. + * It is implementation-defined but normally looks like: + * + * : + * + * where `` is the passed message and `` is the system + * message corresponding to the error code. + * `error_code` is a system error code as given by `errno`. + */ +FMT_API void format_system_error(detail::buffer& out, int error_code, + const char* message) noexcept; + +// Reports a system error without throwing an exception. +// Can be used to report errors from destructors. +FMT_API void report_system_error(int error_code, const char* message) noexcept; + +inline auto vformat(locale_ref loc, string_view fmt, format_args args) + -> std::string { + auto buf = memory_buffer(); + detail::vformat_to(buf, fmt, args, loc); + return {buf.data(), buf.size()}; +} + +template +FMT_INLINE auto format(locale_ref loc, format_string fmt, T&&... args) + -> std::string { + return vformat(loc, fmt.str, vargs{{args...}}); +} + +template ::value)> +auto vformat_to(OutputIt out, locale_ref loc, string_view fmt, format_args args) + -> OutputIt { + auto&& buf = detail::get_buffer(out); + detail::vformat_to(buf, fmt, args, loc); + return detail::get_iterator(buf, out); +} + +template ::value)> +FMT_INLINE auto format_to(OutputIt out, locale_ref loc, format_string fmt, + T&&... args) -> OutputIt { + return fmt::vformat_to(out, loc, fmt.str, vargs{{args...}}); +} + +template +FMT_NODISCARD FMT_INLINE auto formatted_size(locale_ref loc, + format_string fmt, + T&&... 
args) -> size_t { + auto buf = detail::counting_buffer<>(); + detail::vformat_to(buf, fmt.str, vargs{{args...}}, loc); + return buf.count(); +} + +FMT_API auto vformat(string_view fmt, format_args args) -> std::string; + +/** + * Formats `args` according to specifications in `fmt` and returns the result + * as a string. + * + * **Example**: + * + * #include + * std::string message = fmt::format("The answer is {}.", 42); + */ +template +FMT_NODISCARD FMT_INLINE auto format(format_string fmt, T&&... args) + -> std::string { + return vformat(fmt.str, vargs{{args...}}); +} + +/** + * Converts `value` to `std::string` using the default format for type `T`. + * + * **Example**: + * + * std::string answer = fmt::to_string(42); + */ +template ::value)> +FMT_NODISCARD FMT_CONSTEXPR_STRING auto to_string(T value) -> std::string { + // The buffer should be large enough to store the number including the sign + // or "false" for bool. + char buffer[max_of(detail::digits10() + 2, 5)]; + return {buffer, detail::write(buffer, value)}; +} + +template ::value)> +FMT_NODISCARD FMT_CONSTEXPR_STRING auto to_string(const T& value) + -> std::string { + return to_string(format_as(value)); +} + +template ::value && + !detail::use_format_as::value)> +FMT_NODISCARD FMT_CONSTEXPR_STRING auto to_string(const T& value) + -> std::string { + auto buffer = memory_buffer(); + detail::write(appender(buffer), value); + return {buffer.data(), buffer.size()}; +} + +FMT_END_EXPORT +FMT_END_NAMESPACE + +#ifdef FMT_HEADER_ONLY +# define FMT_FUNC inline +# include "format-inl.h" +#endif + +// Restore _LIBCPP_REMOVE_TRANSITIVE_INCLUDES. +#ifdef FMT_REMOVE_TRANSITIVE_INCLUDES +# undef _LIBCPP_REMOVE_TRANSITIVE_INCLUDES +#endif + +#endif // FMT_FORMAT_H_ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." 
+#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fmt/os.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fmt/os.h new file mode 100644 index 0000000000000000000000000000000000000000..a412fd64a3045f2128b2746dc5bbd0cb90b81654 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fmt/os.h @@ -0,0 +1,432 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Formatting library for C++ - optional OS-specific functionality +// +// Copyright (c) 2012 - present, Victor Zverovich +// All rights reserved. +// +// For the license information refer to format.h. + +#ifndef FMT_OS_H_ +#define FMT_OS_H_ + +#include "format.h" + +#ifndef FMT_MODULE +# include +# include +# include +# include // std::system_error + +# if FMT_HAS_INCLUDE() +# include // LC_NUMERIC_MASK on macOS +# endif +#endif // FMT_MODULE + +#ifndef FMT_USE_FCNTL +// UWP doesn't provide _pipe. +# if FMT_HAS_INCLUDE("winapifamily.h") +# include +# endif +# if (FMT_HAS_INCLUDE() || defined(__APPLE__) || \ + defined(__linux__)) && \ + (!defined(WINAPI_FAMILY) || \ + (WINAPI_FAMILY == WINAPI_FAMILY_DESKTOP_APP)) && \ + !defined(__wasm__) +# include // for O_RDONLY +# define FMT_USE_FCNTL 1 +# else +# define FMT_USE_FCNTL 0 +# endif +#endif + +#ifndef FMT_POSIX +# if defined(_WIN32) && !defined(__MINGW32__) +// Fix warnings about deprecated symbols. +# define FMT_POSIX(call) _##call +# else +# define FMT_POSIX(call) call +# endif +#endif + +// Calls to system functions are wrapped in FMT_SYSTEM for testability. +#ifdef FMT_SYSTEM +# define FMT_HAS_SYSTEM +# define FMT_POSIX_CALL(call) FMT_SYSTEM(call) +#else +# define FMT_SYSTEM(call) ::call +# ifdef _WIN32 +// Fix warnings about deprecated symbols. 
+# define FMT_POSIX_CALL(call) ::_##call +# else +# define FMT_POSIX_CALL(call) ::call +# endif +#endif + +// Retries the expression while it evaluates to error_result and errno +// equals to EINTR. +#ifndef _WIN32 +# define FMT_RETRY_VAL(result, expression, error_result) \ + do { \ + (result) = (expression); \ + } while ((result) == (error_result) && errno == EINTR) +#else +# define FMT_RETRY_VAL(result, expression, error_result) result = (expression) +#endif + +#define FMT_RETRY(result, expression) FMT_RETRY_VAL(result, expression, -1) + +FMT_BEGIN_NAMESPACE +FMT_BEGIN_EXPORT + +/** + * A reference to a null-terminated string. It can be constructed from a C + * string or `std::string`. + * + * You can use one of the following type aliases for common character types: + * + * +---------------+-----------------------------+ + * | Type | Definition | + * +===============+=============================+ + * | cstring_view | basic_cstring_view | + * +---------------+-----------------------------+ + * | wcstring_view | basic_cstring_view | + * +---------------+-----------------------------+ + * + * This class is most useful as a parameter type for functions that wrap C APIs. + */ +template class basic_cstring_view { + private: + const Char* data_; + + public: + /// Constructs a string reference object from a C string. + basic_cstring_view(const Char* s) : data_(s) {} + + /// Constructs a string reference from an `std::string` object. + basic_cstring_view(const std::basic_string& s) : data_(s.c_str()) {} + + /// Returns the pointer to a C string. 
+ auto c_str() const -> const Char* { return data_; } +}; + +using cstring_view = basic_cstring_view; +using wcstring_view = basic_cstring_view; + +#ifdef _WIN32 +FMT_API const std::error_category& system_category() noexcept; + +namespace detail { +FMT_API void format_windows_error(buffer& out, int error_code, + const char* message) noexcept; +} + +FMT_API std::system_error vwindows_error(int error_code, string_view fmt, + format_args args); + +/** + * Constructs a `std::system_error` object with the description of the form + * + * : + * + * where `` is the formatted message and `` is the + * system message corresponding to the error code. + * `error_code` is a Windows error code as given by `GetLastError`. + * If `error_code` is not a valid error code such as -1, the system message + * will look like "error -1". + * + * **Example**: + * + * // This throws a system_error with the description + * // cannot open file 'foo': The system cannot find the file specified. + * // or similar (system message may vary) if the file doesn't exist. + * const char *filename = "foo"; + * LPOFSTRUCT of = LPOFSTRUCT(); + * HFILE file = OpenFile(filename, &of, OF_READ); + * if (file == HFILE_ERROR) { + * throw fmt::windows_error(GetLastError(), + * "cannot open file '{}'", filename); + * } + */ +template +auto windows_error(int error_code, string_view message, const T&... args) + -> std::system_error { + return vwindows_error(error_code, message, vargs{{args...}}); +} + +// Reports a Windows error without throwing an exception. +// Can be used to report errors from destructors. +FMT_API void report_windows_error(int error_code, const char* message) noexcept; +#else +inline auto system_category() noexcept -> const std::error_category& { + return std::system_category(); +} +#endif // _WIN32 + +// std::system is not available on some platforms such as iOS (#2248). +#ifdef __OSX__ +template > +void say(const S& fmt, Args&&... 
args) { + std::system(format("say \"{}\"", format(fmt, args...)).c_str()); +} +#endif + +// A buffered file. +class buffered_file { + private: + FILE* file_; + + friend class file; + + inline explicit buffered_file(FILE* f) : file_(f) {} + + public: + buffered_file(const buffered_file&) = delete; + void operator=(const buffered_file&) = delete; + + // Constructs a buffered_file object which doesn't represent any file. + inline buffered_file() noexcept : file_(nullptr) {} + + // Destroys the object closing the file it represents if any. + FMT_API ~buffered_file() noexcept; + + public: + inline buffered_file(buffered_file&& other) noexcept : file_(other.file_) { + other.file_ = nullptr; + } + + inline auto operator=(buffered_file&& other) -> buffered_file& { + close(); + file_ = other.file_; + other.file_ = nullptr; + return *this; + } + + // Opens a file. + FMT_API buffered_file(cstring_view filename, cstring_view mode); + + // Closes the file. + FMT_API void close(); + + // Returns the pointer to a FILE object representing this file. + inline auto get() const noexcept -> FILE* { return file_; } + + FMT_API auto descriptor() const -> int; + + template + inline void print(string_view fmt, const T&... args) { + fmt::vargs vargs = {{args...}}; + detail::is_locking() ? fmt::vprint_buffered(file_, fmt, vargs) + : fmt::vprint(file_, fmt, vargs); + } +}; + +#if FMT_USE_FCNTL + +// A file. Closed file is represented by a file object with descriptor -1. +// Methods that are not declared with noexcept may throw +// fmt::system_error in case of failure. Note that some errors such as +// closing the file multiple times will cause a crash on Windows rather +// than an exception. You can get standard behavior by overriding the +// invalid parameter handler with _set_invalid_parameter_handler. +class FMT_API file { + private: + int fd_; // File descriptor. + + // Constructs a file object with a given descriptor. 
+ explicit file(int fd) : fd_(fd) {} + + friend struct pipe; + + public: + // Possible values for the oflag argument to the constructor. + enum { + RDONLY = FMT_POSIX(O_RDONLY), // Open for reading only. + WRONLY = FMT_POSIX(O_WRONLY), // Open for writing only. + RDWR = FMT_POSIX(O_RDWR), // Open for reading and writing. + CREATE = FMT_POSIX(O_CREAT), // Create if the file doesn't exist. + APPEND = FMT_POSIX(O_APPEND), // Open in append mode. + TRUNC = FMT_POSIX(O_TRUNC) // Truncate the content of the file. + }; + + // Constructs a file object which doesn't represent any file. + inline file() noexcept : fd_(-1) {} + + // Opens a file and constructs a file object representing this file. + file(cstring_view path, int oflag); + + public: + file(const file&) = delete; + void operator=(const file&) = delete; + + inline file(file&& other) noexcept : fd_(other.fd_) { other.fd_ = -1; } + + // Move assignment is not noexcept because close may throw. + inline auto operator=(file&& other) -> file& { + close(); + fd_ = other.fd_; + other.fd_ = -1; + return *this; + } + + // Destroys the object closing the file it represents if any. + ~file() noexcept; + + // Returns the file descriptor. + inline auto descriptor() const noexcept -> int { return fd_; } + + // Closes the file. + void close(); + + // Returns the file size. The size has signed type for consistency with + // stat::st_size. + auto size() const -> long long; + + // Attempts to read count bytes from the file into the specified buffer. + auto read(void* buffer, size_t count) -> size_t; + + // Attempts to write count bytes from the specified buffer to the file. + auto write(const void* buffer, size_t count) -> size_t; + + // Duplicates a file descriptor with the dup function and returns + // the duplicate as a file object. + static auto dup(int fd) -> file; + + // Makes fd be the copy of this file descriptor, closing fd first if + // necessary. 
+ void dup2(int fd); + + // Makes fd be the copy of this file descriptor, closing fd first if + // necessary. + void dup2(int fd, std::error_code& ec) noexcept; + + // Creates a buffered_file object associated with this file and detaches + // this file object from the file. + auto fdopen(const char* mode) -> buffered_file; + +# if defined(_WIN32) && !defined(__MINGW32__) + // Opens a file and constructs a file object representing this file by + // wcstring_view filename. Windows only. + static file open_windows_file(wcstring_view path, int oflag); +# endif +}; + +struct FMT_API pipe { + file read_end; + file write_end; + + // Creates a pipe setting up read_end and write_end file objects for reading + // and writing respectively. + pipe(); +}; + +// Returns the memory page size. +auto getpagesize() -> long; + +namespace detail { + +struct buffer_size { + constexpr buffer_size() = default; + size_t value = 0; + FMT_CONSTEXPR auto operator=(size_t val) const -> buffer_size { + auto bs = buffer_size(); + bs.value = val; + return bs; + } +}; + +struct ostream_params { + int oflag = file::WRONLY | file::CREATE | file::TRUNC; + size_t buffer_size = BUFSIZ > 32768 ? BUFSIZ : 32768; + + constexpr ostream_params() {} + + template + ostream_params(T... params, int new_oflag) : ostream_params(params...) { + oflag = new_oflag; + } + + template + ostream_params(T... params, detail::buffer_size bs) + : ostream_params(params...) { + this->buffer_size = bs.value; + } + +// Intel has a bug that results in failure to deduce a constructor +// for empty parameter packs. +# if defined(__INTEL_COMPILER) && __INTEL_COMPILER < 2000 + ostream_params(int new_oflag) : oflag(new_oflag) {} + ostream_params(detail::buffer_size bs) : buffer_size(bs.value) {} +# endif +}; + +} // namespace detail + +FMT_INLINE_VARIABLE constexpr auto buffer_size = detail::buffer_size(); + +/// A fast buffered output stream for writing from a single thread. 
Writing from +/// multiple threads without external synchronization may result in a data race. +class ostream : private detail::buffer { + private: + file file_; + + FMT_API ostream(cstring_view path, const detail::ostream_params& params); + + FMT_API static void grow(buffer& buf, size_t); + + public: + FMT_API ostream(ostream&& other) noexcept; + FMT_API ~ostream(); + + operator writer() { + detail::buffer& buf = *this; + return buf; + } + + inline void flush() { + if (size() == 0) return; + file_.write(data(), size() * sizeof(data()[0])); + clear(); + } + + template + friend auto output_file(cstring_view path, T... params) -> ostream; + + inline void close() { + flush(); + file_.close(); + } + + /// Formats `args` according to specifications in `fmt` and writes the + /// output to the file. + template void print(format_string fmt, T&&... args) { + vformat_to(appender(*this), fmt.str, vargs{{args...}}); + } +}; + +/** + * Opens a file for writing. Supported parameters passed in `params`: + * + * - ``: Flags passed to [open]( + * https://pubs.opengroup.org/onlinepubs/007904875/functions/open.html) + * (`file::WRONLY | file::CREATE | file::TRUNC` by default) + * - `buffer_size=`: Output buffer size + * + * **Example**: + * + * auto out = fmt::output_file("guide.txt"); + * out.print("Don't {}", "Panic"); + */ +template +inline auto output_file(cstring_view path, T... params) -> ostream { + return {path, detail::ostream_params(params...)}; +} +#endif // FMT_USE_FCNTL + +FMT_END_EXPORT +FMT_END_NAMESPACE + +#endif // FMT_OS_H_ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." 
+#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fmt/ostream.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fmt/ostream.h new file mode 100644 index 0000000000000000000000000000000000000000..a3c4887750d4bc5227e3e79f34f8869f706d82f5 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fmt/ostream.h @@ -0,0 +1,172 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Formatting library for C++ - std::ostream support +// +// Copyright (c) 2012 - present, Victor Zverovich +// All rights reserved. +// +// For the license information refer to format.h. + +#ifndef FMT_OSTREAM_H_ +#define FMT_OSTREAM_H_ + +#ifndef FMT_MODULE +# include // std::filebuf +#endif + +#ifdef _WIN32 +# ifdef __GLIBCXX__ +# include +# include +# endif +# include +#endif + +#include "chrono.h" // formatbuf + +#ifdef _MSVC_STL_UPDATE +# define FMT_MSVC_STL_UPDATE _MSVC_STL_UPDATE +#elif defined(_MSC_VER) && _MSC_VER < 1912 // VS 15.5 +# define FMT_MSVC_STL_UPDATE _MSVC_LANG +#else +# define FMT_MSVC_STL_UPDATE 0 +#endif + +FMT_BEGIN_NAMESPACE +namespace detail { + +// Generate a unique explicit instantiation in every translation unit using a +// tag type in an anonymous namespace. +namespace { +struct file_access_tag {}; +} // namespace +template +class file_access { + friend auto get_file(BufType& obj) -> FILE* { return obj.*FileMemberPtr; } +}; + +#if FMT_MSVC_STL_UPDATE +template class file_access; +auto get_file(std::filebuf&) -> FILE*; +#endif + +// Write the content of buf to os. +// It is a separate function rather than a part of vprint to simplify testing. 
+template +void write_buffer(std::basic_ostream& os, buffer& buf) { + const Char* buf_data = buf.data(); + using unsigned_streamsize = make_unsigned_t; + unsigned_streamsize size = buf.size(); + unsigned_streamsize max_size = to_unsigned(max_value()); + do { + unsigned_streamsize n = size <= max_size ? size : max_size; + os.write(buf_data, static_cast(n)); + buf_data += n; + size -= n; + } while (size != 0); +} + +template struct streamed_view { + const T& value; +}; +} // namespace detail + +// Formats an object of type T that has an overloaded ostream operator<<. +template +struct basic_ostream_formatter : formatter, Char> { + void set_debug_format() = delete; + + template + auto format(const T& value, Context& ctx) const -> decltype(ctx.out()) { + auto buffer = basic_memory_buffer(); + auto&& formatbuf = detail::formatbuf>(buffer); + auto&& output = std::basic_ostream(&formatbuf); + output.imbue(std::locale::classic()); // The default is always unlocalized. + output << value; + output.exceptions(std::ios_base::failbit | std::ios_base::badbit); + return formatter, Char>::format( + {buffer.data(), buffer.size()}, ctx); + } +}; + +using ostream_formatter = basic_ostream_formatter; + +template +struct formatter, Char> + : basic_ostream_formatter { + template + auto format(detail::streamed_view view, Context& ctx) const + -> decltype(ctx.out()) { + return basic_ostream_formatter::format(view.value, ctx); + } +}; + +/** + * Returns a view that formats `value` via an ostream `operator<<`. 
+ * + * **Example**: + * + * fmt::print("Current thread id: {}\n", + * fmt::streamed(std::this_thread::get_id())); + */ +template +constexpr auto streamed(const T& value) -> detail::streamed_view { + return {value}; +} + +inline void vprint(std::ostream& os, string_view fmt, format_args args) { + auto buffer = memory_buffer(); + detail::vformat_to(buffer, fmt, args); + FILE* f = nullptr; +#if FMT_MSVC_STL_UPDATE && FMT_USE_RTTI + if (auto* buf = dynamic_cast(os.rdbuf())) + f = detail::get_file(*buf); +#elif defined(_WIN32) && defined(__GLIBCXX__) && FMT_USE_RTTI + auto* rdbuf = os.rdbuf(); + if (auto* sfbuf = dynamic_cast<__gnu_cxx::stdio_sync_filebuf*>(rdbuf)) + f = sfbuf->file(); + else if (auto* fbuf = dynamic_cast<__gnu_cxx::stdio_filebuf*>(rdbuf)) + f = fbuf->file(); +#endif +#ifdef _WIN32 + if (f) { + int fd = _fileno(f); + if (_isatty(fd)) { + os.flush(); + if (detail::write_console(fd, {buffer.data(), buffer.size()})) return; + } + } +#endif + detail::ignore_unused(f); + detail::write_buffer(os, buffer); +} + +/** + * Prints formatted data to the stream `os`. + * + * **Example**: + * + * fmt::print(cerr, "Don't {}!", "panic"); + */ +FMT_EXPORT template +void print(std::ostream& os, format_string fmt, T&&... args) { + fmt::vargs vargs = {{args...}}; + if (detail::const_check(detail::use_utf8)) return vprint(os, fmt.str, vargs); + auto buffer = memory_buffer(); + detail::vformat_to(buffer, fmt.str, vargs); + detail::write_buffer(os, buffer); +} + +FMT_EXPORT template +void println(std::ostream& os, format_string fmt, T&&... args) { + fmt::print(os, FMT_STRING("{}\n"), + fmt::format(fmt, std::forward(args)...)); +} + +FMT_END_NAMESPACE + +#endif // FMT_OSTREAM_H_ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." 
+#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fmt/printf.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fmt/printf.h new file mode 100644 index 0000000000000000000000000000000000000000..087cbae23c6d72af09fb88bdc6c5db721f4f6832 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fmt/printf.h @@ -0,0 +1,629 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Formatting library for C++ - legacy printf implementation +// +// Copyright (c) 2012 - 2016, Victor Zverovich +// All rights reserved. +// +// For the license information refer to format.h. + +#ifndef FMT_PRINTF_H_ +#define FMT_PRINTF_H_ + +#ifndef FMT_MODULE +# include // std::find +# include // std::numeric_limits +#endif + +#include "format.h" + +FMT_BEGIN_NAMESPACE +FMT_BEGIN_EXPORT + +template class basic_printf_context { + private: + basic_appender out_; + basic_format_args args_; + + static_assert(std::is_same::value || + std::is_same::value, + "Unsupported code unit type."); + + public: + using char_type = Char; + enum { builtin_types = 1 }; + + /// Constructs a `printf_context` object. References to the arguments are + /// stored in the context object so make sure they have appropriate lifetimes. + basic_printf_context(basic_appender out, + basic_format_args args) + : out_(out), args_(args) {} + + auto out() -> basic_appender { return out_; } + void advance_to(basic_appender) {} + + auto locale() -> locale_ref { return {}; } + + auto arg(int id) const -> basic_format_arg { + return args_.get(id); + } +}; + +namespace detail { + +// Return the result via the out param to workaround gcc bug 77539. 
+template +FMT_CONSTEXPR auto find(Ptr first, Ptr last, T value, Ptr& out) -> bool { + for (out = first; out != last; ++out) { + if (*out == value) return true; + } + return false; +} + +template <> +inline auto find(const char* first, const char* last, char value, + const char*& out) -> bool { + out = + static_cast(memchr(first, value, to_unsigned(last - first))); + return out != nullptr; +} + +// Checks if a value fits in int - used to avoid warnings about comparing +// signed and unsigned integers. +template struct int_checker { + template static auto fits_in_int(T value) -> bool { + return value <= to_unsigned(max_value()); + } + inline static auto fits_in_int(bool) -> bool { return true; } +}; + +template <> struct int_checker { + template static auto fits_in_int(T value) -> bool { + return value >= (std::numeric_limits::min)() && + value <= max_value(); + } + inline static auto fits_in_int(int) -> bool { return true; } +}; + +struct printf_precision_handler { + template ::value)> + auto operator()(T value) -> int { + if (!int_checker::is_signed>::fits_in_int(value)) + report_error("number is too big"); + return max_of(static_cast(value), 0); + } + + template ::value)> + auto operator()(T) -> int { + report_error("precision is not integer"); + return 0; + } +}; + +// An argument visitor that returns true iff arg is a zero integer. 
+struct is_zero_int { + template ::value)> + auto operator()(T value) -> bool { + return value == 0; + } + + template ::value)> + auto operator()(T) -> bool { + return false; + } +}; + +template struct make_unsigned_or_bool : std::make_unsigned {}; + +template <> struct make_unsigned_or_bool { + using type = bool; +}; + +template class arg_converter { + private: + using char_type = typename Context::char_type; + + basic_format_arg& arg_; + char_type type_; + + public: + arg_converter(basic_format_arg& arg, char_type type) + : arg_(arg), type_(type) {} + + void operator()(bool value) { + if (type_ != 's') operator()(value); + } + + template ::value)> + void operator()(U value) { + bool is_signed = type_ == 'd' || type_ == 'i'; + using target_type = conditional_t::value, U, T>; + if (const_check(sizeof(target_type) <= sizeof(int))) { + // Extra casts are used to silence warnings. + using unsigned_type = typename make_unsigned_or_bool::type; + if (is_signed) + arg_ = static_cast(static_cast(value)); + else + arg_ = static_cast(static_cast(value)); + } else { + // glibc's printf doesn't sign extend arguments of smaller types: + // std::printf("%lld", -42); // prints "4294967254" + // but we don't have to do the same because it's a UB. + if (is_signed) + arg_ = static_cast(value); + else + arg_ = static_cast::type>(value); + } + } + + template ::value)> + void operator()(U) {} // No conversion needed for non-integral types. +}; + +// Converts an integer argument to T for printf, if T is an integral type. +// If T is void, the argument is converted to corresponding signed or unsigned +// type depending on the type specifier: 'd' and 'i' - signed, other - +// unsigned). +template +void convert_arg(basic_format_arg& arg, Char type) { + arg.visit(arg_converter(arg, type)); +} + +// Converts an integer argument to char for printf. 
+template class char_converter { + private: + basic_format_arg& arg_; + + public: + explicit char_converter(basic_format_arg& arg) : arg_(arg) {} + + template ::value)> + void operator()(T value) { + arg_ = static_cast(value); + } + + template ::value)> + void operator()(T) {} // No conversion needed for non-integral types. +}; + +// An argument visitor that return a pointer to a C string if argument is a +// string or null otherwise. +template struct get_cstring { + template auto operator()(T) -> const Char* { return nullptr; } + auto operator()(const Char* s) -> const Char* { return s; } +}; + +// Checks if an argument is a valid printf width specifier and sets +// left alignment if it is negative. +class printf_width_handler { + private: + format_specs& specs_; + + public: + inline explicit printf_width_handler(format_specs& specs) : specs_(specs) {} + + template ::value)> + auto operator()(T value) -> unsigned { + auto width = static_cast>(value); + if (detail::is_negative(value)) { + specs_.set_align(align::left); + width = 0 - width; + } + unsigned int_max = to_unsigned(max_value()); + if (width > int_max) report_error("number is too big"); + return static_cast(width); + } + + template ::value)> + auto operator()(T) -> unsigned { + report_error("width is not integer"); + return 0; + } +}; + +// Workaround for a bug with the XL compiler when initializing +// printf_arg_formatter's base class. +template +auto make_arg_formatter(basic_appender iter, format_specs& s) + -> arg_formatter { + return {iter, s, locale_ref()}; +} + +// The `printf` argument formatter. +template +class printf_arg_formatter : public arg_formatter { + private: + using base = arg_formatter; + using context_type = basic_printf_context; + + context_type& context_; + + void write_null_pointer(bool is_string = false) { + auto s = this->specs; + s.set_type(presentation_type::none); + write_bytes(this->out, is_string ? 
"(null)" : "(nil)", s); + } + + template void write(T value) { + detail::write(this->out, value, this->specs, this->locale); + } + + public: + printf_arg_formatter(basic_appender iter, format_specs& s, + context_type& ctx) + : base(make_arg_formatter(iter, s)), context_(ctx) {} + + void operator()(monostate value) { write(value); } + + template ::value)> + void operator()(T value) { + // MSVC2013 fails to compile separate overloads for bool and Char so use + // std::is_same instead. + if (!std::is_same::value) { + write(value); + return; + } + format_specs s = this->specs; + if (s.type() != presentation_type::none && + s.type() != presentation_type::chr) { + return (*this)(static_cast(value)); + } + s.set_sign(sign::none); + s.clear_alt(); + s.set_fill(' '); // Ignore '0' flag for char types. + // align::numeric needs to be overwritten here since the '0' flag is + // ignored for non-numeric types + if (s.align() == align::none || s.align() == align::numeric) + s.set_align(align::right); + detail::write(this->out, static_cast(value), s); + } + + template ::value)> + void operator()(T value) { + write(value); + } + + void operator()(const char* value) { + if (value) + write(value); + else + write_null_pointer(this->specs.type() != presentation_type::pointer); + } + + void operator()(const wchar_t* value) { + if (value) + write(value); + else + write_null_pointer(this->specs.type() != presentation_type::pointer); + } + + void operator()(basic_string_view value) { write(value); } + + void operator()(const void* value) { + if (value) + write(value); + else + write_null_pointer(); + } + + void operator()(typename basic_format_arg::handle handle) { + auto parse_ctx = parse_context({}); + handle.format(parse_ctx, context_); + } +}; + +template +void parse_flags(format_specs& specs, const Char*& it, const Char* end) { + for (; it != end; ++it) { + switch (*it) { + case '-': specs.set_align(align::left); break; + case '+': specs.set_sign(sign::plus); break; + case '0': 
specs.set_fill('0'); break; + case ' ': + if (specs.sign() != sign::plus) specs.set_sign(sign::space); + break; + case '#': specs.set_alt(); break; + default: return; + } + } +} + +template +auto parse_header(const Char*& it, const Char* end, format_specs& specs, + GetArg get_arg) -> int { + int arg_index = -1; + Char c = *it; + if (c >= '0' && c <= '9') { + // Parse an argument index (if followed by '$') or a width possibly + // preceded with '0' flag(s). + int value = parse_nonnegative_int(it, end, -1); + if (it != end && *it == '$') { // value is an argument index + ++it; + arg_index = value != -1 ? value : max_value(); + } else { + if (c == '0') specs.set_fill('0'); + if (value != 0) { + // Nonzero value means that we parsed width and don't need to + // parse it or flags again, so return now. + if (value == -1) report_error("number is too big"); + specs.width = value; + return arg_index; + } + } + } + parse_flags(specs, it, end); + // Parse width. + if (it != end) { + if (*it >= '0' && *it <= '9') { + specs.width = parse_nonnegative_int(it, end, -1); + if (specs.width == -1) report_error("number is too big"); + } else if (*it == '*') { + ++it; + specs.width = static_cast( + get_arg(-1).visit(detail::printf_width_handler(specs))); + } + } + return arg_index; +} + +inline auto parse_printf_presentation_type(char c, type t, bool& upper) + -> presentation_type { + using pt = presentation_type; + constexpr auto integral_set = sint_set | uint_set | bool_set | char_set; + switch (c) { + case 'd': return in(t, integral_set) ? pt::dec : pt::none; + case 'o': return in(t, integral_set) ? pt::oct : pt::none; + case 'X': upper = true; FMT_FALLTHROUGH; + case 'x': return in(t, integral_set) ? pt::hex : pt::none; + case 'E': upper = true; FMT_FALLTHROUGH; + case 'e': return in(t, float_set) ? pt::exp : pt::none; + case 'F': upper = true; FMT_FALLTHROUGH; + case 'f': return in(t, float_set) ? 
pt::fixed : pt::none; + case 'G': upper = true; FMT_FALLTHROUGH; + case 'g': return in(t, float_set) ? pt::general : pt::none; + case 'A': upper = true; FMT_FALLTHROUGH; + case 'a': return in(t, float_set) ? pt::hexfloat : pt::none; + case 'c': return in(t, integral_set) ? pt::chr : pt::none; + case 's': return in(t, string_set | cstring_set) ? pt::string : pt::none; + case 'p': return in(t, pointer_set | cstring_set) ? pt::pointer : pt::none; + default: return pt::none; + } +} + +template +void vprintf(buffer& buf, basic_string_view format, + basic_format_args args) { + using iterator = basic_appender; + auto out = iterator(buf); + auto context = basic_printf_context(out, args); + auto parse_ctx = parse_context(format); + + // Returns the argument with specified index or, if arg_index is -1, the next + // argument. + auto get_arg = [&](int arg_index) { + if (arg_index < 0) + arg_index = parse_ctx.next_arg_id(); + else + parse_ctx.check_arg_id(--arg_index); + auto arg = context.arg(arg_index); + if (!arg) report_error("argument not found"); + return arg; + }; + + const Char* start = parse_ctx.begin(); + const Char* end = parse_ctx.end(); + auto it = start; + while (it != end) { + if (!find(it, end, '%', it)) { + it = end; // find leaves it == nullptr if it doesn't find '%'. + break; + } + Char c = *it++; + if (it != end && *it == c) { + write(out, basic_string_view(start, to_unsigned(it - start))); + start = ++it; + continue; + } + write(out, basic_string_view(start, to_unsigned(it - 1 - start))); + + auto specs = format_specs(); + specs.set_align(align::right); + + // Parse argument index, flags and width. + int arg_index = parse_header(it, end, specs, get_arg); + if (arg_index == 0) report_error("argument not found"); + + // Parse precision. + if (it != end && *it == '.') { + ++it; + c = it != end ? 
*it : 0; + if ('0' <= c && c <= '9') { + specs.precision = parse_nonnegative_int(it, end, 0); + } else if (c == '*') { + ++it; + specs.precision = + static_cast(get_arg(-1).visit(printf_precision_handler())); + } else { + specs.precision = 0; + } + } + + auto arg = get_arg(arg_index); + // For d, i, o, u, x, and X conversion specifiers, if a precision is + // specified, the '0' flag is ignored + if (specs.precision >= 0 && is_integral_type(arg.type())) { + // Ignore '0' for non-numeric types or if '-' present. + specs.set_fill(' '); + } + if (specs.precision >= 0 && arg.type() == type::cstring_type) { + auto str = arg.visit(get_cstring()); + auto str_end = str + specs.precision; + auto nul = std::find(str, str_end, Char()); + auto sv = basic_string_view( + str, to_unsigned(nul != str_end ? nul - str : specs.precision)); + arg = sv; + } + if (specs.alt() && arg.visit(is_zero_int())) specs.clear_alt(); + if (specs.fill_unit() == '0') { + if (is_arithmetic_type(arg.type()) && specs.align() != align::left) { + specs.set_align(align::numeric); + } else { + // Ignore '0' flag for non-numeric types or if '-' flag is also present. + specs.set_fill(' '); + } + } + + // Parse length and convert the argument to the required type. + c = it != end ? *it++ : 0; + Char t = it != end ? *it : 0; + switch (c) { + case 'h': + if (t == 'h') { + ++it; + t = it != end ? *it : 0; + convert_arg(arg, t); + } else { + convert_arg(arg, t); + } + break; + case 'l': + if (t == 'l') { + ++it; + t = it != end ? *it : 0; + convert_arg(arg, t); + } else { + convert_arg(arg, t); + } + break; + case 'j': convert_arg(arg, t); break; + case 'z': convert_arg(arg, t); break; + case 't': convert_arg(arg, t); break; + case 'L': + // printf produces garbage when 'L' is omitted for long double, no + // need to do the same. + break; + default: --it; convert_arg(arg, c); + } + + // Parse type. 
+ if (it == end) report_error("invalid format string"); + char type = static_cast(*it++); + if (is_integral_type(arg.type())) { + // Normalize type. + switch (type) { + case 'i': + case 'u': type = 'd'; break; + case 'c': + arg.visit(char_converter>(arg)); + break; + } + } + bool upper = false; + specs.set_type(parse_printf_presentation_type(type, arg.type(), upper)); + if (specs.type() == presentation_type::none) + report_error("invalid format specifier"); + if (upper) specs.set_upper(); + + start = it; + + // Format argument. + arg.visit(printf_arg_formatter(out, specs, context)); + } + write(out, basic_string_view(start, to_unsigned(it - start))); +} +} // namespace detail + +using printf_context = basic_printf_context; +using wprintf_context = basic_printf_context; + +using printf_args = basic_format_args; +using wprintf_args = basic_format_args; + +/// Constructs an `format_arg_store` object that contains references to +/// arguments and can be implicitly converted to `printf_args`. +template +inline auto make_printf_args(T&... args) + -> decltype(fmt::make_format_args>(args...)) { + return fmt::make_format_args>(args...); +} + +template struct vprintf_args { + using type = basic_format_args>; +}; + +template +inline auto vsprintf(basic_string_view fmt, + typename vprintf_args::type args) + -> std::basic_string { + auto buf = basic_memory_buffer(); + detail::vprintf(buf, fmt, args); + return {buf.data(), buf.size()}; +} + +/** + * Formats `args` according to specifications in `fmt` and returns the result + * as as string. + * + * **Example**: + * + * std::string message = fmt::sprintf("The answer is %d", 42); + */ +template +inline auto sprintf(string_view fmt, const T&... args) -> std::string { + return vsprintf(fmt, make_printf_args(args...)); +} +template +FMT_DEPRECATED auto sprintf(basic_string_view fmt, const T&... 
args) + -> std::wstring { + return vsprintf(fmt, make_printf_args(args...)); +} + +template +auto vfprintf(std::FILE* f, basic_string_view fmt, + typename vprintf_args::type args) -> int { + auto buf = basic_memory_buffer(); + detail::vprintf(buf, fmt, args); + size_t size = buf.size(); + return std::fwrite(buf.data(), sizeof(Char), size, f) < size + ? -1 + : static_cast(size); +} + +/** + * Formats `args` according to specifications in `fmt` and writes the output + * to `f`. + * + * **Example**: + * + * fmt::fprintf(stderr, "Don't %s!", "panic"); + */ +template +inline auto fprintf(std::FILE* f, string_view fmt, const T&... args) -> int { + return vfprintf(f, fmt, make_printf_args(args...)); +} +template +FMT_DEPRECATED auto fprintf(std::FILE* f, basic_string_view fmt, + const T&... args) -> int { + return vfprintf(f, fmt, make_printf_args(args...)); +} + +/** + * Formats `args` according to specifications in `fmt` and writes the output + * to `stdout`. + * + * **Example**: + * + * fmt::printf("Elapsed time: %.2f seconds", 1.23); + */ +template +inline auto printf(string_view fmt, const T&... args) -> int { + return vfprintf(stdout, fmt, make_printf_args(args...)); +} + +FMT_END_EXPORT +FMT_END_NAMESPACE + +#endif // FMT_PRINTF_H_ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." 
+#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fmt/ranges.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fmt/ranges.h new file mode 100644 index 0000000000000000000000000000000000000000..f8df05f9a4517bb9add87f261a6c29b9c7ec19f6 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fmt/ranges.h @@ -0,0 +1,856 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Formatting library for C++ - range and tuple support +// +// Copyright (c) 2012 - present, Victor Zverovich and {fmt} contributors +// All rights reserved. +// +// For the license information refer to format.h. + +#ifndef FMT_RANGES_H_ +#define FMT_RANGES_H_ + +#ifndef FMT_MODULE +# include +# include +# include +# include +# include +#endif + +#include "format.h" + +#if FMT_HAS_CPP_ATTRIBUTE(clang::lifetimebound) +# define FMT_LIFETIMEBOUND [[clang::lifetimebound]] +#else +# define FMT_LIFETIMEBOUND +#endif +FMT_PRAGMA_CLANG(diagnostic error "-Wreturn-stack-address") + +FMT_BEGIN_NAMESPACE + +FMT_EXPORT +enum class range_format { disabled, map, set, sequence, string, debug_string }; + +namespace detail { + +template class is_map { + template static auto check(U*) -> typename U::mapped_type; + template static void check(...); + + public: + static constexpr bool value = + !std::is_void(nullptr))>::value; +}; + +template class is_set { + template static auto check(U*) -> typename U::key_type; + template static void check(...); + + public: + static constexpr bool value = + !std::is_void(nullptr))>::value && !is_map::value; +}; + +// C array overload +template +auto range_begin(const T (&arr)[N]) -> const T* { + return arr; +} +template auto range_end(const T (&arr)[N]) -> const T* { + return arr + N; +} + +template +struct has_member_fn_begin_end_t : std::false_type {}; + +template +struct 
has_member_fn_begin_end_t().begin()), + decltype(std::declval().end())>> + : std::true_type {}; + +// Member function overloads. +template +auto range_begin(T&& rng) -> decltype(static_cast(rng).begin()) { + return static_cast(rng).begin(); +} +template +auto range_end(T&& rng) -> decltype(static_cast(rng).end()) { + return static_cast(rng).end(); +} + +// ADL overloads. Only participate in overload resolution if member functions +// are not found. +template +auto range_begin(T&& rng) + -> enable_if_t::value, + decltype(begin(static_cast(rng)))> { + return begin(static_cast(rng)); +} +template +auto range_end(T&& rng) -> enable_if_t::value, + decltype(end(static_cast(rng)))> { + return end(static_cast(rng)); +} + +template +struct has_const_begin_end : std::false_type {}; +template +struct has_mutable_begin_end : std::false_type {}; + +template +struct has_const_begin_end< + T, void_t&>())), + decltype(detail::range_end( + std::declval&>()))>> + : std::true_type {}; + +template +struct has_mutable_begin_end< + T, void_t())), + decltype(detail::range_end(std::declval())), + // the extra int here is because older versions of MSVC don't + // SFINAE properly unless there are distinct types + int>> : std::true_type {}; + +template struct is_range_ : std::false_type {}; +template +struct is_range_ + : std::integral_constant::value || + has_mutable_begin_end::value)> {}; + +// tuple_size and tuple_element check. 
+template class is_tuple_like_ { + template ::type> + static auto check(U* p) -> decltype(std::tuple_size::value, 0); + template static void check(...); + + public: + static constexpr bool value = + !std::is_void(nullptr))>::value; +}; + +// Check for integer_sequence +#if defined(__cpp_lib_integer_sequence) || FMT_MSC_VERSION >= 1900 +template +using integer_sequence = std::integer_sequence; +template using index_sequence = std::index_sequence; +template using make_index_sequence = std::make_index_sequence; +#else +template struct integer_sequence { + using value_type = T; + + static FMT_CONSTEXPR auto size() -> size_t { return sizeof...(N); } +}; + +template using index_sequence = integer_sequence; + +template +struct make_integer_sequence : make_integer_sequence {}; +template +struct make_integer_sequence : integer_sequence {}; + +template +using make_index_sequence = make_integer_sequence; +#endif + +template +using tuple_index_sequence = make_index_sequence::value>; + +template ::value> +class is_tuple_formattable_ { + public: + static constexpr bool value = false; +}; +template class is_tuple_formattable_ { + template + static auto all_true(index_sequence, + integer_sequence= 0)...>) -> std::true_type; + static auto all_true(...) -> std::false_type; + + template + static auto check(index_sequence) -> decltype(all_true( + index_sequence{}, + integer_sequence::type, + C>::value)...>{})); + + public: + static constexpr bool value = + decltype(check(tuple_index_sequence{}))::value; +}; + +template +FMT_CONSTEXPR void for_each(index_sequence, Tuple&& t, F&& f) { + using std::get; + // Using a free function get(Tuple) now. 
+ const int unused[] = {0, ((void)f(get(t)), 0)...}; + ignore_unused(unused); +} + +template +FMT_CONSTEXPR void for_each(Tuple&& t, F&& f) { + for_each(tuple_index_sequence>(), + std::forward(t), std::forward(f)); +} + +template +void for_each2(index_sequence, Tuple1&& t1, Tuple2&& t2, F&& f) { + using std::get; + const int unused[] = {0, ((void)f(get(t1), get(t2)), 0)...}; + ignore_unused(unused); +} + +template +void for_each2(Tuple1&& t1, Tuple2&& t2, F&& f) { + for_each2(tuple_index_sequence>(), + std::forward(t1), std::forward(t2), + std::forward(f)); +} + +namespace tuple { +// Workaround a bug in MSVC 2019 (v140). +template +using result_t = std::tuple, Char>...>; + +using std::get; +template +auto get_formatters(index_sequence) + -> result_t(std::declval()))...>; +} // namespace tuple + +#if FMT_MSC_VERSION && FMT_MSC_VERSION < 1920 +// Older MSVC doesn't get the reference type correctly for arrays. +template struct range_reference_type_impl { + using type = decltype(*detail::range_begin(std::declval())); +}; + +template struct range_reference_type_impl { + using type = T&; +}; + +template +using range_reference_type = typename range_reference_type_impl::type; +#else +template +using range_reference_type = + decltype(*detail::range_begin(std::declval())); +#endif + +// We don't use the Range's value_type for anything, but we do need the Range's +// reference type, with cv-ref stripped. +template +using uncvref_type = remove_cvref_t>; + +template +struct range_format_kind_ + : std::integral_constant, T>::value + ? range_format::disabled + : is_map::value ? range_format::map + : is_set::value ? range_format::set + : range_format::sequence> {}; + +template +using range_format_constant = std::integral_constant; + +// These are not generic lambdas for compatibility with C++11. 
+template struct parse_empty_specs { + template FMT_CONSTEXPR void operator()(Formatter& f) { + f.parse(ctx); + detail::maybe_set_debug_format(f, true); + } + parse_context& ctx; +}; +template struct format_tuple_element { + using char_type = typename FormatContext::char_type; + + template + void operator()(const formatter& f, const T& v) { + if (i > 0) ctx.advance_to(detail::copy(separator, ctx.out())); + ctx.advance_to(f.format(v, ctx)); + ++i; + } + + int i; + FormatContext& ctx; + basic_string_view separator; +}; + +} // namespace detail + +FMT_EXPORT +template struct is_tuple_like { + static constexpr bool value = + detail::is_tuple_like_::value && !detail::is_range_::value; +}; + +FMT_EXPORT +template struct is_tuple_formattable { + static constexpr bool value = detail::is_tuple_formattable_::value; +}; + +template +struct formatter::value && + fmt::is_tuple_formattable::value>> { + private: + decltype(detail::tuple::get_formatters( + detail::tuple_index_sequence())) formatters_; + + basic_string_view separator_ = detail::string_literal{}; + basic_string_view opening_bracket_ = + detail::string_literal{}; + basic_string_view closing_bracket_ = + detail::string_literal{}; + + public: + FMT_CONSTEXPR formatter() {} + + FMT_CONSTEXPR void set_separator(basic_string_view sep) { + separator_ = sep; + } + + FMT_CONSTEXPR void set_brackets(basic_string_view open, + basic_string_view close) { + opening_bracket_ = open; + closing_bracket_ = close; + } + + FMT_CONSTEXPR auto parse(parse_context& ctx) -> const Char* { + auto it = ctx.begin(); + auto end = ctx.end(); + if (it != end && detail::to_ascii(*it) == 'n') { + ++it; + set_brackets({}, {}); + set_separator({}); + } + if (it != end && *it != '}') report_error("invalid format specifier"); + ctx.advance_to(it); + detail::for_each(formatters_, detail::parse_empty_specs{ctx}); + return it; + } + + template + auto format(const Tuple& value, FormatContext& ctx) const + -> decltype(ctx.out()) { + 
ctx.advance_to(detail::copy(opening_bracket_, ctx.out())); + detail::for_each2( + formatters_, value, + detail::format_tuple_element{0, ctx, separator_}); + return detail::copy(closing_bracket_, ctx.out()); + } +}; + +FMT_EXPORT +template struct is_range { + static constexpr bool value = + detail::is_range_::value && !detail::has_to_string_view::value; +}; + +namespace detail { + +template +using range_formatter_type = formatter, Char>; + +template +using maybe_const_range = + conditional_t::value, const R, R>; + +template +struct is_formattable_delayed + : is_formattable>, Char> {}; +} // namespace detail + +template struct conjunction : std::true_type {}; +template struct conjunction

: P {}; +template +struct conjunction + : conditional_t, P1> {}; + +FMT_EXPORT +template +struct range_formatter; + +template +struct range_formatter< + T, Char, + enable_if_t>, + is_formattable>::value>> { + private: + detail::range_formatter_type underlying_; + basic_string_view separator_ = detail::string_literal{}; + basic_string_view opening_bracket_ = + detail::string_literal{}; + basic_string_view closing_bracket_ = + detail::string_literal{}; + bool is_debug = false; + + template ::value)> + auto write_debug_string(Output& out, It it, Sentinel end) const -> Output { + auto buf = basic_memory_buffer(); + for (; it != end; ++it) buf.push_back(*it); + auto specs = format_specs(); + specs.set_type(presentation_type::debug); + return detail::write( + out, basic_string_view(buf.data(), buf.size()), specs); + } + + template ::value)> + auto write_debug_string(Output& out, It, Sentinel) const -> Output { + return out; + } + + public: + FMT_CONSTEXPR range_formatter() {} + + FMT_CONSTEXPR auto underlying() -> detail::range_formatter_type& { + return underlying_; + } + + FMT_CONSTEXPR void set_separator(basic_string_view sep) { + separator_ = sep; + } + + FMT_CONSTEXPR void set_brackets(basic_string_view open, + basic_string_view close) { + opening_bracket_ = open; + closing_bracket_ = close; + } + + FMT_CONSTEXPR auto parse(parse_context& ctx) -> const Char* { + auto it = ctx.begin(); + auto end = ctx.end(); + detail::maybe_set_debug_format(underlying_, true); + if (it == end) return underlying_.parse(ctx); + + switch (detail::to_ascii(*it)) { + case 'n': + set_brackets({}, {}); + ++it; + break; + case '?': + is_debug = true; + set_brackets({}, {}); + ++it; + if (it == end || *it != 's') report_error("invalid format specifier"); + FMT_FALLTHROUGH; + case 's': + if (!std::is_same::value) + report_error("invalid format specifier"); + if (!is_debug) { + set_brackets(detail::string_literal{}, + detail::string_literal{}); + set_separator({}); + 
detail::maybe_set_debug_format(underlying_, false); + } + ++it; + return it; + } + + if (it != end && *it != '}') { + if (*it != ':') report_error("invalid format specifier"); + detail::maybe_set_debug_format(underlying_, false); + ++it; + } + + ctx.advance_to(it); + return underlying_.parse(ctx); + } + + template + auto format(R&& range, FormatContext& ctx) const -> decltype(ctx.out()) { + auto out = ctx.out(); + auto it = detail::range_begin(range); + auto end = detail::range_end(range); + if (is_debug) return write_debug_string(out, std::move(it), end); + + out = detail::copy(opening_bracket_, out); + int i = 0; + for (; it != end; ++it) { + if (i > 0) out = detail::copy(separator_, out); + ctx.advance_to(out); + auto&& item = *it; // Need an lvalue + out = underlying_.format(item, ctx); + ++i; + } + out = detail::copy(closing_bracket_, out); + return out; + } +}; + +FMT_EXPORT +template +struct range_format_kind + : conditional_t< + is_range::value, detail::range_format_kind_, + std::integral_constant> {}; + +template +struct formatter< + R, Char, + enable_if_t::value != range_format::disabled && + range_format_kind::value != range_format::map && + range_format_kind::value != range_format::string && + range_format_kind::value != range_format::debug_string>, + detail::is_formattable_delayed>::value>> { + private: + using range_type = detail::maybe_const_range; + range_formatter, Char> range_formatter_; + + public: + using nonlocking = void; + + FMT_CONSTEXPR formatter() { + if (detail::const_check(range_format_kind::value != + range_format::set)) + return; + range_formatter_.set_brackets(detail::string_literal{}, + detail::string_literal{}); + } + + FMT_CONSTEXPR auto parse(parse_context& ctx) -> const Char* { + return range_formatter_.parse(ctx); + } + + template + auto format(range_type& range, FormatContext& ctx) const + -> decltype(ctx.out()) { + return range_formatter_.format(range, ctx); + } +}; + +// A map formatter. 
+template +struct formatter< + R, Char, + enable_if_t::value == range_format::map>, + detail::is_formattable_delayed>::value>> { + private: + using map_type = detail::maybe_const_range; + using element_type = detail::uncvref_type; + + decltype(detail::tuple::get_formatters( + detail::tuple_index_sequence())) formatters_; + bool no_delimiters_ = false; + + public: + FMT_CONSTEXPR formatter() {} + + FMT_CONSTEXPR auto parse(parse_context& ctx) -> const Char* { + auto it = ctx.begin(); + auto end = ctx.end(); + if (it != end) { + if (detail::to_ascii(*it) == 'n') { + no_delimiters_ = true; + ++it; + } + if (it != end && *it != '}') { + if (*it != ':') report_error("invalid format specifier"); + ++it; + } + ctx.advance_to(it); + } + detail::for_each(formatters_, detail::parse_empty_specs{ctx}); + return it; + } + + template + auto format(map_type& map, FormatContext& ctx) const -> decltype(ctx.out()) { + auto out = ctx.out(); + basic_string_view open = detail::string_literal{}; + if (!no_delimiters_) out = detail::copy(open, out); + int i = 0; + basic_string_view sep = detail::string_literal{}; + for (auto&& value : map) { + if (i > 0) out = detail::copy(sep, out); + ctx.advance_to(out); + detail::for_each2(formatters_, value, + detail::format_tuple_element{ + 0, ctx, detail::string_literal{}}); + ++i; + } + basic_string_view close = detail::string_literal{}; + if (!no_delimiters_) out = detail::copy(close, out); + return out; + } +}; + +// A (debug_)string formatter. 
+template +struct formatter< + R, Char, + enable_if_t::value == range_format::string || + range_format_kind::value == + range_format::debug_string>> { + private: + using range_type = detail::maybe_const_range; + using string_type = + conditional_t, + decltype(detail::range_begin(std::declval())), + decltype(detail::range_end(std::declval()))>::value, + detail::std_string_view, std::basic_string>; + + formatter underlying_; + + public: + FMT_CONSTEXPR auto parse(parse_context& ctx) -> const Char* { + return underlying_.parse(ctx); + } + + template + auto format(range_type& range, FormatContext& ctx) const + -> decltype(ctx.out()) { + auto out = ctx.out(); + if (detail::const_check(range_format_kind::value == + range_format::debug_string)) + *out++ = '"'; + out = underlying_.format( + string_type{detail::range_begin(range), detail::range_end(range)}, ctx); + if (detail::const_check(range_format_kind::value == + range_format::debug_string)) + *out++ = '"'; + return out; + } +}; + +template +struct join_view : detail::view { + It begin; + Sentinel end; + basic_string_view sep; + + join_view(It b, Sentinel e, basic_string_view s) + : begin(std::move(b)), end(e), sep(s) {} +}; + +template +struct formatter, Char> { + private: + using value_type = +#ifdef __cpp_lib_ranges + std::iter_value_t; +#else + typename std::iterator_traits::value_type; +#endif + formatter, Char> value_formatter_; + + using view = conditional_t::value, + const join_view, + join_view>; + + public: + using nonlocking = void; + + FMT_CONSTEXPR auto parse(parse_context& ctx) -> const Char* { + return value_formatter_.parse(ctx); + } + + template + auto format(view& value, FormatContext& ctx) const -> decltype(ctx.out()) { + using iter = + conditional_t::value, It, It&>; + iter it = value.begin; + auto out = ctx.out(); + if (it == value.end) return out; + out = value_formatter_.format(*it, ctx); + ++it; + while (it != value.end) { + out = detail::copy(value.sep.begin(), value.sep.end(), out); + 
ctx.advance_to(out); + out = value_formatter_.format(*it, ctx); + ++it; + } + return out; + } +}; + +FMT_EXPORT +template struct tuple_join_view : detail::view { + const Tuple& tuple; + basic_string_view sep; + + tuple_join_view(const Tuple& t, basic_string_view s) + : tuple(t), sep{s} {} +}; + +// Define FMT_TUPLE_JOIN_SPECIFIERS to enable experimental format specifiers +// support in tuple_join. It is disabled by default because of issues with +// the dynamic width and precision. +#ifndef FMT_TUPLE_JOIN_SPECIFIERS +# define FMT_TUPLE_JOIN_SPECIFIERS 0 +#endif + +template +struct formatter, Char, + enable_if_t::value>> { + FMT_CONSTEXPR auto parse(parse_context& ctx) -> const Char* { + return do_parse(ctx, std::tuple_size()); + } + + template + auto format(const tuple_join_view& value, + FormatContext& ctx) const -> typename FormatContext::iterator { + return do_format(value, ctx, std::tuple_size()); + } + + private: + decltype(detail::tuple::get_formatters( + detail::tuple_index_sequence())) formatters_; + + FMT_CONSTEXPR auto do_parse(parse_context& ctx, + std::integral_constant) + -> const Char* { + return ctx.begin(); + } + + template + FMT_CONSTEXPR auto do_parse(parse_context& ctx, + std::integral_constant) + -> const Char* { + auto end = ctx.begin(); +#if FMT_TUPLE_JOIN_SPECIFIERS + end = std::get::value - N>(formatters_).parse(ctx); + if (N > 1) { + auto end1 = do_parse(ctx, std::integral_constant()); + if (end != end1) + report_error("incompatible format specs for tuple elements"); + } +#endif + return end; + } + + template + auto do_format(const tuple_join_view&, FormatContext& ctx, + std::integral_constant) const -> + typename FormatContext::iterator { + return ctx.out(); + } + + template + auto do_format(const tuple_join_view& value, FormatContext& ctx, + std::integral_constant) const -> + typename FormatContext::iterator { + using std::get; + auto out = + std::get::value - N>(formatters_) + .format(get::value - N>(value.tuple), ctx); + if (N <= 1) 
return out; + out = detail::copy(value.sep, out); + ctx.advance_to(out); + return do_format(value, ctx, std::integral_constant()); + } +}; + +namespace detail { +// Check if T has an interface like a container adaptor (e.g. std::stack, +// std::queue, std::priority_queue). +template class is_container_adaptor_like { + template static auto check(U* p) -> typename U::container_type; + template static void check(...); + + public: + static constexpr bool value = + !std::is_void(nullptr))>::value; +}; + +template struct all { + const Container& c; + auto begin() const -> typename Container::const_iterator { return c.begin(); } + auto end() const -> typename Container::const_iterator { return c.end(); } +}; +} // namespace detail + +template +struct formatter< + T, Char, + enable_if_t, + bool_constant::value == + range_format::disabled>>::value>> + : formatter, Char> { + using all = detail::all; + template + auto format(const T& value, FormatContext& ctx) const -> decltype(ctx.out()) { + struct getter : T { + static auto get(const T& v) -> all { + return {v.*(&getter::c)}; // Access c through the derived class. + } + }; + return formatter::format(getter::get(value), ctx); + } +}; + +FMT_BEGIN_EXPORT + +/// Returns a view that formats the iterator range `[begin, end)` with elements +/// separated by `sep`. +template +auto join(It begin, Sentinel end, string_view sep) -> join_view { + return {std::move(begin), end, sep}; +} + +/** + * Returns a view that formats `range` with elements separated by `sep`. 
+ * + * **Example**: + * + * auto v = std::vector{1, 2, 3}; + * fmt::print("{}", fmt::join(v, ", ")); + * // Output: 1, 2, 3 + * + * `fmt::join` applies passed format specifiers to the range elements: + * + * fmt::print("{:02}", fmt::join(v, ", ")); + * // Output: 01, 02, 03 + */ +template ::value)> +auto join(Range&& r, string_view sep) + -> join_view { + return {detail::range_begin(r), detail::range_end(r), sep}; +} + +/** + * Returns an object that formats `std::tuple` with elements separated by `sep`. + * + * **Example**: + * + * auto t = std::tuple(1, 'a'); + * fmt::print("{}", fmt::join(t, ", ")); + * // Output: 1, a + */ +template ::value)> +FMT_CONSTEXPR auto join(const Tuple& tuple FMT_LIFETIMEBOUND, string_view sep) + -> tuple_join_view { + return {tuple, sep}; +} + +/** + * Returns an object that formats `std::initializer_list` with elements + * separated by `sep`. + * + * **Example**: + * + * fmt::print("{}", fmt::join({1, 2, 3}, ", ")); + * // Output: "1, 2, 3" + */ +template +auto join(std::initializer_list list, string_view sep) + -> join_view { + return join(std::begin(list), std::end(list), sep); +} + +FMT_END_EXPORT +FMT_END_NAMESPACE + +#endif // FMT_RANGES_H_ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." 
+#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fmt/std.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fmt/std.h new file mode 100644 index 0000000000000000000000000000000000000000..1c166432291fad18366482ec158f73c137e84806 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fmt/std.h @@ -0,0 +1,732 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Formatting library for C++ - formatters for standard library types +// +// Copyright (c) 2012 - present, Victor Zverovich +// All rights reserved. +// +// For the license information refer to format.h. + +#ifndef FMT_STD_H_ +#define FMT_STD_H_ + +#include "format.h" +#include "ostream.h" + +#ifndef FMT_MODULE +# include +# include +# include +# include +# include // std::reference_wrapper +# include +# include +# include +# include // std::type_info +# include // std::make_index_sequence + +// Check FMT_CPLUSPLUS to suppress a bogus warning in MSVC. +# if FMT_CPLUSPLUS >= 201703L +# if FMT_HAS_INCLUDE() && \ + (!defined(FMT_CPP_LIB_FILESYSTEM) || FMT_CPP_LIB_FILESYSTEM != 0) +# include +# endif +# if FMT_HAS_INCLUDE() +# include +# endif +# if FMT_HAS_INCLUDE() +# include +# endif +# endif +// Use > instead of >= in the version check because may be +// available after C++17 but before C++20 is marked as implemented. +# if FMT_CPLUSPLUS > 201703L && FMT_HAS_INCLUDE() +# include +# endif +# if FMT_CPLUSPLUS > 202002L && FMT_HAS_INCLUDE() +# include +# endif +#endif // FMT_MODULE + +#if FMT_HAS_INCLUDE() +# include +#endif + +// GCC 4 does not support FMT_HAS_INCLUDE. +#if FMT_HAS_INCLUDE() || defined(__GLIBCXX__) +# include +// Android NDK with gabi++ library on some architectures does not implement +// abi::__cxa_demangle(). 
+# ifndef __GABIXX_CXXABI_H__ +# define FMT_HAS_ABI_CXA_DEMANGLE +# endif +#endif + +#ifdef FMT_CPP_LIB_FILESYSTEM +// Use the provided definition. +#elif defined(__cpp_lib_filesystem) +# define FMT_CPP_LIB_FILESYSTEM __cpp_lib_filesystem +#else +# define FMT_CPP_LIB_FILESYSTEM 0 +#endif + +#ifdef FMT_CPP_LIB_VARIANT +// Use the provided definition. +#elif defined(__cpp_lib_variant) +# define FMT_CPP_LIB_VARIANT __cpp_lib_variant +#else +# define FMT_CPP_LIB_VARIANT 0 +#endif + +FMT_BEGIN_NAMESPACE +namespace detail { + +#if FMT_CPP_LIB_FILESYSTEM + +template +auto get_path_string(const std::filesystem::path& p, + const std::basic_string& native) { + if constexpr (std::is_same_v && std::is_same_v) + return to_utf8(native, to_utf8_error_policy::replace); + else + return p.string(); +} + +template +void write_escaped_path(basic_memory_buffer& quoted, + const std::filesystem::path& p, + const std::basic_string& native) { + if constexpr (std::is_same_v && + std::is_same_v) { + auto buf = basic_memory_buffer(); + write_escaped_string(std::back_inserter(buf), native); + bool valid = to_utf8::convert(quoted, {buf.data(), buf.size()}); + FMT_ASSERT(valid, "invalid utf16"); + } else if constexpr (std::is_same_v) { + write_escaped_string( + std::back_inserter(quoted), native); + } else { + write_escaped_string(std::back_inserter(quoted), p.string()); + } +} + +#endif // FMT_CPP_LIB_FILESYSTEM + +#if defined(__cpp_lib_expected) || FMT_CPP_LIB_VARIANT + +template +auto write_escaped_alternative(OutputIt out, const T& v, FormatContext& ctx) + -> OutputIt { + if constexpr (has_to_string_view::value) + return write_escaped_string(out, detail::to_string_view(v)); + if constexpr (std::is_same_v) return write_escaped_char(out, v); + + formatter, Char> underlying; + maybe_set_debug_format(underlying, true); + return underlying.format(v, ctx); +} +#endif + +#if FMT_CPP_LIB_VARIANT + +template struct is_variant_like_ : std::false_type {}; +template +struct is_variant_like_> : 
std::true_type {}; + +template class is_variant_formattable { + template + static auto check(std::index_sequence) -> std::conjunction< + is_formattable, Char>...>; + + public: + static constexpr bool value = decltype(check( + std::make_index_sequence::value>()))::value; +}; + +#endif // FMT_CPP_LIB_VARIANT + +#if FMT_USE_RTTI +inline auto normalize_libcxx_inline_namespaces(string_view demangled_name_view, + char* begin) -> string_view { + // Normalization of stdlib inline namespace names. + // libc++ inline namespaces. + // std::__1::* -> std::* + // std::__1::__fs::* -> std::* + // libstdc++ inline namespaces. + // std::__cxx11::* -> std::* + // std::filesystem::__cxx11::* -> std::filesystem::* + if (demangled_name_view.starts_with("std::")) { + char* to = begin + 5; // std:: + for (const char *from = to, *end = begin + demangled_name_view.size(); + from < end;) { + // This is safe, because demangled_name is NUL-terminated. + if (from[0] == '_' && from[1] == '_') { + const char* next = from + 1; + while (next < end && *next != ':') next++; + if (next[0] == ':' && next[1] == ':') { + from = next + 2; + continue; + } + } + *to++ = *from++; + } + demangled_name_view = {begin, detail::to_unsigned(to - begin)}; + } + return demangled_name_view; +} + +template +auto normalize_msvc_abi_name(string_view abi_name_view, OutputIt out) + -> OutputIt { + const string_view demangled_name(abi_name_view); + for (size_t i = 0; i < demangled_name.size(); ++i) { + auto sub = demangled_name; + sub.remove_prefix(i); + if (sub.starts_with("enum ")) { + i += 4; + continue; + } + if (sub.starts_with("class ") || sub.starts_with("union ")) { + i += 5; + continue; + } + if (sub.starts_with("struct ")) { + i += 6; + continue; + } + if (*sub.begin() != ' ') *out++ = *sub.begin(); + } + return out; +} + +template +auto write_demangled_name(OutputIt out, const std::type_info& ti) -> OutputIt { +# ifdef FMT_HAS_ABI_CXA_DEMANGLE + int status = 0; + size_t size = 0; + std::unique_ptr 
demangled_name_ptr( + abi::__cxa_demangle(ti.name(), nullptr, &size, &status), &free); + + string_view demangled_name_view; + if (demangled_name_ptr) { + demangled_name_view = normalize_libcxx_inline_namespaces( + demangled_name_ptr.get(), demangled_name_ptr.get()); + } else { + demangled_name_view = string_view(ti.name()); + } + return detail::write_bytes(out, demangled_name_view); +# elif FMT_MSC_VERSION && defined(_MSVC_STL_UPDATE) + return normalize_msvc_abi_name(ti.name(), out); +# elif FMT_MSC_VERSION && defined(_LIBCPP_VERSION) + const string_view demangled_name = ti.name(); + std::string name_copy(demangled_name.size(), '\0'); + // normalize_msvc_abi_name removes class, struct, union etc that MSVC has in + // front of types + name_copy.erase(normalize_msvc_abi_name(demangled_name, name_copy.begin()), + name_copy.end()); + // normalize_libcxx_inline_namespaces removes the inline __1, __2, etc + // namespaces libc++ uses for ABI versioning On MSVC ABI + libc++ + // environments, we need to eliminate both of them. + const string_view normalized_name = + normalize_libcxx_inline_namespaces(name_copy, name_copy.data()); + return detail::write_bytes(out, normalized_name); +# else + return detail::write_bytes(out, string_view(ti.name())); +# endif +} + +#endif // FMT_USE_RTTI + +template +struct has_flip : std::false_type {}; + +template +struct has_flip().flip())>> + : std::true_type {}; + +template struct is_bit_reference_like { + static constexpr bool value = std::is_convertible::value && + std::is_nothrow_assignable::value && + has_flip::value; +}; + +// Workaround for libc++ incompatibility with C++ standard. +// According to the Standard, `bitset::operator[] const` returns bool. 
+#if defined(_LIBCPP_VERSION) && !defined(FMT_IMPORT_STD) +template +struct is_bit_reference_like> { + static constexpr bool value = true; +}; +#endif + +template +struct has_format_as : std::false_type {}; +template +struct has_format_as()))>> + : std::true_type {}; + +template +struct has_format_as_member : std::false_type {}; +template +struct has_format_as_member< + T, void_t::format_as(std::declval()))>> + : std::true_type {}; + +} // namespace detail + +template +auto ptr(const std::unique_ptr& p) -> const void* { + return p.get(); +} +template auto ptr(const std::shared_ptr& p) -> const void* { + return p.get(); +} + +#if FMT_CPP_LIB_FILESYSTEM + +template struct formatter { + private: + format_specs specs_; + detail::arg_ref width_ref_; + bool debug_ = false; + char path_type_ = 0; + + public: + FMT_CONSTEXPR void set_debug_format(bool set = true) { debug_ = set; } + + FMT_CONSTEXPR auto parse(parse_context& ctx) { + auto it = ctx.begin(), end = ctx.end(); + if (it == end) return it; + + it = detail::parse_align(it, end, specs_); + if (it == end) return it; + + Char c = *it; + if ((c >= '0' && c <= '9') || c == '{') + it = detail::parse_width(it, end, specs_, width_ref_, ctx); + if (it != end && *it == '?') { + debug_ = true; + ++it; + } + if (it != end && (*it == 'g')) path_type_ = detail::to_ascii(*it++); + return it; + } + + template + auto format(const std::filesystem::path& p, FormatContext& ctx) const { + auto specs = specs_; + auto path_string = + !path_type_ ? 
p.native() + : p.generic_string(); + + detail::handle_dynamic_spec(specs.dynamic_width(), specs.width, width_ref_, + ctx); + if (!debug_) { + auto s = detail::get_path_string(p, path_string); + return detail::write(ctx.out(), basic_string_view(s), specs); + } + auto quoted = basic_memory_buffer(); + detail::write_escaped_path(quoted, p, path_string); + return detail::write(ctx.out(), + basic_string_view(quoted.data(), quoted.size()), + specs); + } +}; + +class path : public std::filesystem::path { + public: + auto display_string() const -> std::string { + const std::filesystem::path& base = *this; + return fmt::format(FMT_STRING("{}"), base); + } + auto system_string() const -> std::string { return string(); } + + auto generic_display_string() const -> std::string { + const std::filesystem::path& base = *this; + return fmt::format(FMT_STRING("{:g}"), base); + } + auto generic_system_string() const -> std::string { return generic_string(); } +}; + +#endif // FMT_CPP_LIB_FILESYSTEM + +template +struct formatter, Char> + : nested_formatter, Char> { + private: + // This is a functor because C++11 doesn't support generic lambdas. + struct writer { + const std::bitset& bs; + + template + FMT_CONSTEXPR auto operator()(OutputIt out) -> OutputIt { + for (auto pos = N; pos > 0; --pos) + out = detail::write(out, bs[pos - 1] ? 
Char('1') : Char('0')); + return out; + } + }; + + public: + template + auto format(const std::bitset& bs, FormatContext& ctx) const + -> decltype(ctx.out()) { + return this->write_padded(ctx, writer{bs}); + } +}; + +template +struct formatter : basic_ostream_formatter {}; + +#ifdef __cpp_lib_optional +template +struct formatter, Char, + std::enable_if_t::value>> { + private: + formatter, Char> underlying_; + static constexpr basic_string_view optional = + detail::string_literal{}; + static constexpr basic_string_view none = + detail::string_literal{}; + + public: + FMT_CONSTEXPR auto parse(parse_context& ctx) { + detail::maybe_set_debug_format(underlying_, true); + return underlying_.parse(ctx); + } + + template + auto format(const std::optional& opt, FormatContext& ctx) const + -> decltype(ctx.out()) { + if (!opt) return detail::write(ctx.out(), none); + + auto out = ctx.out(); + out = detail::write(out, optional); + ctx.advance_to(out); + out = underlying_.format(*opt, ctx); + return detail::write(out, ')'); + } +}; +#endif // __cpp_lib_optional + +#ifdef __cpp_lib_expected +template +struct formatter, Char, + std::enable_if_t<(std::is_void::value || + is_formattable::value) && + is_formattable::value>> { + FMT_CONSTEXPR auto parse(parse_context& ctx) -> const Char* { + return ctx.begin(); + } + + template + auto format(const std::expected& value, FormatContext& ctx) const + -> decltype(ctx.out()) { + auto out = ctx.out(); + + if (value.has_value()) { + out = detail::write(out, "expected("); + if constexpr (!std::is_void::value) + out = detail::write_escaped_alternative(out, *value, ctx); + } else { + out = detail::write(out, "unexpected("); + out = detail::write_escaped_alternative(out, value.error(), ctx); + } + *out++ = ')'; + return out; + } +}; +#endif // __cpp_lib_expected + +#ifdef __cpp_lib_source_location +template <> struct formatter { + FMT_CONSTEXPR auto parse(parse_context<>& ctx) { return ctx.begin(); } + + template + auto format(const 
std::source_location& loc, FormatContext& ctx) const + -> decltype(ctx.out()) { + auto out = ctx.out(); + out = detail::write(out, loc.file_name()); + out = detail::write(out, ':'); + out = detail::write(out, loc.line()); + out = detail::write(out, ':'); + out = detail::write(out, loc.column()); + out = detail::write(out, ": "); + out = detail::write(out, loc.function_name()); + return out; + } +}; +#endif + +#if FMT_CPP_LIB_VARIANT + +template struct is_variant_like { + static constexpr bool value = detail::is_variant_like_::value; +}; + +template struct formatter { + FMT_CONSTEXPR auto parse(parse_context& ctx) -> const Char* { + return ctx.begin(); + } + + template + auto format(const std::monostate&, FormatContext& ctx) const + -> decltype(ctx.out()) { + return detail::write(ctx.out(), "monostate"); + } +}; + +template +struct formatter, + detail::is_variant_formattable>>> { + FMT_CONSTEXPR auto parse(parse_context& ctx) -> const Char* { + return ctx.begin(); + } + + template + auto format(const Variant& value, FormatContext& ctx) const + -> decltype(ctx.out()) { + auto out = ctx.out(); + + out = detail::write(out, "variant("); + FMT_TRY { + std::visit( + [&](const auto& v) { + out = detail::write_escaped_alternative(out, v, ctx); + }, + value); + } + FMT_CATCH(const std::bad_variant_access&) { + detail::write(out, "valueless by exception"); + } + *out++ = ')'; + return out; + } +}; + +#endif // FMT_CPP_LIB_VARIANT + +template <> struct formatter { + private: + format_specs specs_; + detail::arg_ref width_ref_; + bool debug_ = false; + + public: + FMT_CONSTEXPR void set_debug_format(bool set = true) { debug_ = set; } + + FMT_CONSTEXPR auto parse(parse_context<>& ctx) -> const char* { + auto it = ctx.begin(), end = ctx.end(); + if (it == end) return it; + + it = detail::parse_align(it, end, specs_); + + char c = *it; + if (it != end && ((c >= '0' && c <= '9') || c == '{')) + it = detail::parse_width(it, end, specs_, width_ref_, ctx); + + if (it != end && *it == 
'?') { + debug_ = true; + ++it; + } + if (it != end && *it == 's') { + specs_.set_type(presentation_type::string); + ++it; + } + return it; + } + + template + FMT_CONSTEXPR20 auto format(const std::error_code& ec, + FormatContext& ctx) const -> decltype(ctx.out()) { + auto specs = specs_; + detail::handle_dynamic_spec(specs.dynamic_width(), specs.width, width_ref_, + ctx); + auto buf = memory_buffer(); + if (specs_.type() == presentation_type::string) { + buf.append(ec.message()); + } else { + buf.append(string_view(ec.category().name())); + buf.push_back(':'); + detail::write(appender(buf), ec.value()); + } + auto quoted = memory_buffer(); + auto str = string_view(buf.data(), buf.size()); + if (debug_) { + detail::write_escaped_string(std::back_inserter(quoted), str); + str = string_view(quoted.data(), quoted.size()); + } + return detail::write(ctx.out(), str, specs); + } +}; + +#if FMT_USE_RTTI +template <> struct formatter { + public: + FMT_CONSTEXPR auto parse(parse_context<>& ctx) -> const char* { + return ctx.begin(); + } + + template + auto format(const std::type_info& ti, Context& ctx) const + -> decltype(ctx.out()) { + return detail::write_demangled_name(ctx.out(), ti); + } +}; +#endif // FMT_USE_RTTI + +template +struct formatter< + T, char, + typename std::enable_if::value>::type> { + private: + bool with_typename_ = false; + + public: + FMT_CONSTEXPR auto parse(parse_context<>& ctx) -> const char* { + auto it = ctx.begin(); + auto end = ctx.end(); + if (it == end || *it == '}') return it; + if (*it == 't') { + ++it; + with_typename_ = FMT_USE_RTTI != 0; + } + return it; + } + + template + auto format(const std::exception& ex, Context& ctx) const + -> decltype(ctx.out()) { + auto out = ctx.out(); +#if FMT_USE_RTTI + if (with_typename_) { + out = detail::write_demangled_name(out, typeid(ex)); + *out++ = ':'; + *out++ = ' '; + } +#endif + return detail::write_bytes(out, string_view(ex.what())); + } +}; + +// We can't use std::vector::reference and +// 
std::bitset::reference because the compiler can't deduce Allocator and N +// in partial specialization. +template +struct formatter::value>> + : formatter { + template + FMT_CONSTEXPR auto format(const BitRef& v, FormatContext& ctx) const + -> decltype(ctx.out()) { + return formatter::format(v, ctx); + } +}; + +template +struct formatter, Char, + enable_if_t::value>> + : formatter { + template + auto format(const std::atomic& v, FormatContext& ctx) const + -> decltype(ctx.out()) { + return formatter::format(v.load(), ctx); + } +}; + +#ifdef __cpp_lib_atomic_flag_test +template +struct formatter : formatter { + template + auto format(const std::atomic_flag& v, FormatContext& ctx) const + -> decltype(ctx.out()) { + return formatter::format(v.test(), ctx); + } +}; +#endif // __cpp_lib_atomic_flag_test + +template struct formatter, Char> { + private: + detail::dynamic_format_specs specs_; + + template + FMT_CONSTEXPR auto do_format(const std::complex& c, + detail::dynamic_format_specs& specs, + FormatContext& ctx, OutputIt out) const + -> OutputIt { + if (c.real() != 0) { + *out++ = Char('('); + out = detail::write(out, c.real(), specs, ctx.locale()); + specs.set_sign(sign::plus); + out = detail::write(out, c.imag(), specs, ctx.locale()); + if (!detail::isfinite(c.imag())) *out++ = Char(' '); + *out++ = Char('i'); + *out++ = Char(')'); + return out; + } + out = detail::write(out, c.imag(), specs, ctx.locale()); + if (!detail::isfinite(c.imag())) *out++ = Char(' '); + *out++ = Char('i'); + return out; + } + + public: + FMT_CONSTEXPR auto parse(parse_context& ctx) -> const Char* { + if (ctx.begin() == ctx.end() || *ctx.begin() == '}') return ctx.begin(); + return parse_format_specs(ctx.begin(), ctx.end(), specs_, ctx, + detail::type_constant::value); + } + + template + auto format(const std::complex& c, FormatContext& ctx) const + -> decltype(ctx.out()) { + auto specs = specs_; + if (specs.dynamic()) { + detail::handle_dynamic_spec(specs.dynamic_width(), specs.width, + 
specs.width_ref, ctx); + detail::handle_dynamic_spec(specs.dynamic_precision(), specs.precision, + specs.precision_ref, ctx); + } + + if (specs.width == 0) return do_format(c, specs, ctx, ctx.out()); + auto buf = basic_memory_buffer(); + + auto outer_specs = format_specs(); + outer_specs.width = specs.width; + outer_specs.copy_fill_from(specs); + outer_specs.set_align(specs.align()); + + specs.width = 0; + specs.set_fill({}); + specs.set_align(align::none); + + do_format(c, specs, ctx, basic_appender(buf)); + return detail::write(ctx.out(), + basic_string_view(buf.data(), buf.size()), + outer_specs); + } +}; + +template +struct formatter, Char, + // Guard against format_as because reference_wrapper is + // implicitly convertible to T&. + enable_if_t, Char>::value && + !detail::has_format_as::value && + !detail::has_format_as_member::value>> + : formatter, Char> { + template + auto format(std::reference_wrapper ref, FormatContext& ctx) const + -> decltype(ctx.out()) { + return formatter, Char>::format(ref.get(), ctx); + } +}; + +FMT_END_NAMESPACE + +#endif // FMT_STD_H_ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fmt/xchar.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fmt/xchar.h new file mode 100644 index 0000000000000000000000000000000000000000..1cf7170e8bc421671f0e575a60e8b5c1f53d0ce6 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fmt/xchar.h @@ -0,0 +1,361 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +// Formatting library for C++ - optional wchar_t and exotic character support +// +// Copyright (c) 2012 - present, Victor Zverovich +// All rights reserved. +// +// For the license information refer to format.h. 
+ +#ifndef FMT_XCHAR_H_ +#define FMT_XCHAR_H_ + +#include "color.h" +#include "format.h" +#include "ostream.h" +#include "ranges.h" + +#ifndef FMT_MODULE +# include +# if FMT_USE_LOCALE +# include +# endif +#endif + +FMT_BEGIN_NAMESPACE +namespace detail { + +template +using is_exotic_char = bool_constant::value>; + +template struct format_string_char {}; + +template +struct format_string_char< + S, void_t())))>> { + using type = char_t; +}; + +template +struct format_string_char< + S, enable_if_t::value>> { + using type = typename S::char_type; +}; + +template +using format_string_char_t = typename format_string_char::type; + +inline auto write_loc(basic_appender out, loc_value value, + const format_specs& specs, locale_ref loc) -> bool { +#if FMT_USE_LOCALE + auto& numpunct = + std::use_facet>(loc.get()); + auto separator = std::wstring(); + auto grouping = numpunct.grouping(); + if (!grouping.empty()) separator = std::wstring(1, numpunct.thousands_sep()); + return value.visit(loc_writer{out, specs, separator, grouping, {}}); +#endif + return false; +} + +template +void vformat_to(buffer& buf, basic_string_view fmt, + basic_format_args> args, + locale_ref loc = {}) { + static_assert(!std::is_same::value, ""); + auto out = basic_appender(buf); + parse_format_string( + fmt, format_handler{parse_context(fmt), {out, args, loc}}); +} +} // namespace detail + +FMT_BEGIN_EXPORT + +using wstring_view = basic_string_view; +using wformat_parse_context = parse_context; +using wformat_context = buffered_context; +using wformat_args = basic_format_args; +using wmemory_buffer = basic_memory_buffer; + +template struct basic_fstring { + private: + basic_string_view str_; + + static constexpr int num_static_named_args = + detail::count_static_named_args(); + + using checker = detail::format_string_checker< + Char, static_cast(sizeof...(T)), num_static_named_args, + num_static_named_args != detail::count_named_args()>; + + using arg_pack = detail::arg_pack; + + public: + using t = 
basic_fstring; + + template >::value)> + FMT_CONSTEVAL FMT_ALWAYS_INLINE basic_fstring(const S& s) : str_(s) { + if (FMT_USE_CONSTEVAL) + detail::parse_format_string(s, checker(s, arg_pack())); + } + template ::value&& + std::is_same::value)> + FMT_ALWAYS_INLINE basic_fstring(const S&) : str_(S()) { + FMT_CONSTEXPR auto sv = basic_string_view(S()); + FMT_CONSTEXPR int ignore = + (parse_format_string(sv, checker(sv, arg_pack())), 0); + detail::ignore_unused(ignore); + } + basic_fstring(runtime_format_string fmt) : str_(fmt.str) {} + + operator basic_string_view() const { return str_; } + auto get() const -> basic_string_view { return str_; } +}; + +template +using basic_format_string = basic_fstring; + +template +using wformat_string = typename basic_format_string::t; +inline auto runtime(wstring_view s) -> runtime_format_string { + return {{s}}; +} + +template +constexpr auto make_wformat_args(T&... args) + -> decltype(fmt::make_format_args(args...)) { + return fmt::make_format_args(args...); +} + +#if !FMT_USE_NONTYPE_TEMPLATE_ARGS +inline namespace literals { +inline auto operator""_a(const wchar_t* s, size_t) -> detail::udl_arg { + return {s}; +} +} // namespace literals +#endif + +template +auto join(It begin, Sentinel end, wstring_view sep) + -> join_view { + return {begin, end, sep}; +} + +template ::value)> +auto join(Range&& range, wstring_view sep) + -> join_view { + return join(std::begin(range), std::end(range), sep); +} + +template +auto join(std::initializer_list list, wstring_view sep) + -> join_view { + return join(std::begin(list), std::end(list), sep); +} + +template ::value)> +auto join(const Tuple& tuple, basic_string_view sep) + -> tuple_join_view { + return {tuple, sep}; +} + +template ::value)> +auto vformat(basic_string_view fmt, + basic_format_args> args) + -> std::basic_string { + auto buf = basic_memory_buffer(); + detail::vformat_to(buf, fmt, args); + return {buf.data(), buf.size()}; +} + +template +auto format(wformat_string fmt, T&&... 
args) -> std::wstring { + return vformat(fmt::wstring_view(fmt), fmt::make_wformat_args(args...)); +} + +template +auto format_to(OutputIt out, wformat_string fmt, T&&... args) + -> OutputIt { + return vformat_to(out, fmt::wstring_view(fmt), + fmt::make_wformat_args(args...)); +} + +// Pass char_t as a default template parameter instead of using +// std::basic_string> to reduce the symbol size. +template , + FMT_ENABLE_IF(!std::is_same::value && + !std::is_same::value)> +auto format(const S& fmt, T&&... args) -> std::basic_string { + return vformat(detail::to_string_view(fmt), + fmt::make_format_args>(args...)); +} + +template , + FMT_ENABLE_IF(detail::is_exotic_char::value)> +inline auto vformat(locale_ref loc, const S& fmt, + basic_format_args> args) + -> std::basic_string { + auto buf = basic_memory_buffer(); + detail::vformat_to(buf, detail::to_string_view(fmt), args, loc); + return {buf.data(), buf.size()}; +} + +template , + FMT_ENABLE_IF(detail::is_exotic_char::value)> +inline auto format(locale_ref loc, const S& fmt, T&&... args) + -> std::basic_string { + return vformat(loc, detail::to_string_view(fmt), + fmt::make_format_args>(args...)); +} + +template , + FMT_ENABLE_IF(detail::is_output_iterator::value&& + detail::is_exotic_char::value)> +auto vformat_to(OutputIt out, const S& fmt, + basic_format_args> args) -> OutputIt { + auto&& buf = detail::get_buffer(out); + detail::vformat_to(buf, detail::to_string_view(fmt), args); + return detail::get_iterator(buf, out); +} + +template , + FMT_ENABLE_IF(detail::is_output_iterator::value && + !std::is_same::value && + !std::is_same::value)> +inline auto format_to(OutputIt out, const S& fmt, T&&... 
args) -> OutputIt { + return vformat_to(out, detail::to_string_view(fmt), + fmt::make_format_args>(args...)); +} + +template , + FMT_ENABLE_IF(detail::is_output_iterator::value&& + detail::is_exotic_char::value)> +inline auto vformat_to(OutputIt out, locale_ref loc, const S& fmt, + basic_format_args> args) + -> OutputIt { + auto&& buf = detail::get_buffer(out); + vformat_to(buf, detail::to_string_view(fmt), args, loc); + return detail::get_iterator(buf, out); +} + +template , + bool enable = detail::is_output_iterator::value && + detail::is_exotic_char::value> +inline auto format_to(OutputIt out, locale_ref loc, const S& fmt, T&&... args) + -> typename std::enable_if::type { + return vformat_to(out, loc, detail::to_string_view(fmt), + fmt::make_format_args>(args...)); +} + +template ::value&& + detail::is_exotic_char::value)> +inline auto vformat_to_n(OutputIt out, size_t n, basic_string_view fmt, + basic_format_args> args) + -> format_to_n_result { + using traits = detail::fixed_buffer_traits; + auto buf = detail::iterator_buffer(out, n); + detail::vformat_to(buf, fmt, args); + return {buf.out(), buf.count()}; +} + +template , + FMT_ENABLE_IF(detail::is_output_iterator::value&& + detail::is_exotic_char::value)> +inline auto format_to_n(OutputIt out, size_t n, const S& fmt, T&&... args) + -> format_to_n_result { + return vformat_to_n(out, n, fmt::basic_string_view(fmt), + fmt::make_format_args>(args...)); +} + +template , + FMT_ENABLE_IF(detail::is_exotic_char::value)> +inline auto formatted_size(const S& fmt, T&&... 
args) -> size_t { + auto buf = detail::counting_buffer(); + detail::vformat_to(buf, detail::to_string_view(fmt), + fmt::make_format_args>(args...)); + return buf.count(); +} + +inline void vprint(std::FILE* f, wstring_view fmt, wformat_args args) { + auto buf = wmemory_buffer(); + detail::vformat_to(buf, fmt, args); + buf.push_back(L'\0'); + if (std::fputws(buf.data(), f) == -1) + FMT_THROW(system_error(errno, FMT_STRING("cannot write to file"))); +} + +inline void vprint(wstring_view fmt, wformat_args args) { + vprint(stdout, fmt, args); +} + +template +void print(std::FILE* f, wformat_string fmt, T&&... args) { + return vprint(f, wstring_view(fmt), fmt::make_wformat_args(args...)); +} + +template void print(wformat_string fmt, T&&... args) { + return vprint(wstring_view(fmt), fmt::make_wformat_args(args...)); +} + +template +void println(std::FILE* f, wformat_string fmt, T&&... args) { + return print(f, L"{}\n", fmt::format(fmt, std::forward(args)...)); +} + +template void println(wformat_string fmt, T&&... args) { + return print(L"{}\n", fmt::format(fmt, std::forward(args)...)); +} + +inline auto vformat(text_style ts, wstring_view fmt, wformat_args args) + -> std::wstring { + auto buf = wmemory_buffer(); + detail::vformat_to(buf, ts, fmt, args); + return {buf.data(), buf.size()}; +} + +template +inline auto format(text_style ts, wformat_string fmt, T&&... args) + -> std::wstring { + return fmt::vformat(ts, fmt, fmt::make_wformat_args(args...)); +} + +inline void vprint(std::wostream& os, wstring_view fmt, wformat_args args) { + auto buffer = basic_memory_buffer(); + detail::vformat_to(buffer, fmt, args); + detail::write_buffer(os, buffer); +} + +template +void print(std::wostream& os, wformat_string fmt, T&&... args) { + vprint(os, fmt, fmt::make_format_args>(args...)); +} + +template +void println(std::wostream& os, wformat_string fmt, T&&... 
args) { + print(os, L"{}\n", fmt::format(fmt, std::forward(args)...)); +} + +/// Converts `value` to `std::wstring` using the default format for type `T`. +template inline auto to_wstring(const T& value) -> std::wstring { + return format(FMT_STRING(L"{}"), value); +} +FMT_END_EXPORT +FMT_END_NAMESPACE + +#endif // FMT_XCHAR_H_ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fp16/bitcasts.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fp16/bitcasts.h new file mode 100644 index 0000000000000000000000000000000000000000..55461b797fb041b8f5d5b6a313cf2ec6fedcce6d --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fp16/bitcasts.h @@ -0,0 +1,97 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +#ifndef FP16_BITCASTS_H +#define FP16_BITCASTS_H + +#if defined(__cplusplus) && (__cplusplus >= 201103L) + #include +#elif !defined(__OPENCL_VERSION__) + #include +#endif + +#if defined(__INTEL_COMPILER) + #include +#endif + +#if defined(_MSC_VER) && (defined(_M_ARM) || defined(_M_ARM64)) + #include +#endif + + +static inline float fp32_from_bits(uint32_t w) { +#if defined(__OPENCL_VERSION__) + return as_float(w); +#elif defined(__CUDA_ARCH__) + return __uint_as_float((unsigned int) w); +#elif defined(__INTEL_COMPILER) + return _castu32_f32(w); +#elif defined(_MSC_VER) && (defined(_M_ARM) || defined(_M_ARM64)) + return _CopyFloatFromInt32((__int32) w); +#else + union { + uint32_t as_bits; + float as_value; + } fp32 = { w }; + return fp32.as_value; +#endif +} + +static inline uint32_t fp32_to_bits(float f) { +#if defined(__OPENCL_VERSION__) + return as_uint(f); +#elif defined(__CUDA_ARCH__) + return (uint32_t) __float_as_uint(f); +#elif 
defined(__INTEL_COMPILER) + return _castf32_u32(f); +#elif defined(_MSC_VER) && (defined(_M_ARM) || defined(_M_ARM64)) + return (uint32_t) _CopyInt32FromFloat(f); +#else + union { + float as_value; + uint32_t as_bits; + } fp32 = { f }; + return fp32.as_bits; +#endif +} + +static inline double fp64_from_bits(uint64_t w) { +#if defined(__OPENCL_VERSION__) + return as_double(w); +#elif defined(__CUDA_ARCH__) + return __longlong_as_double((long long) w); +#elif defined(__INTEL_COMPILER) + return _castu64_f64(w); +#elif defined(_MSC_VER) && (defined(_M_ARM) || defined(_M_ARM64)) + return _CopyDoubleFromInt64((__int64) w); +#else + union { + uint64_t as_bits; + double as_value; + } fp64 = { w }; + return fp64.as_value; +#endif +} + +static inline uint64_t fp64_to_bits(double f) { +#if defined(__OPENCL_VERSION__) + return as_ulong(f); +#elif defined(__CUDA_ARCH__) + return (uint64_t) __double_as_longlong(f); +#elif defined(__INTEL_COMPILER) + return _castf64_u64(f); +#elif defined(_MSC_VER) && (defined(_M_ARM) || defined(_M_ARM64)) + return (uint64_t) _CopyInt64FromDouble(f); +#else + union { + double as_value; + uint64_t as_bits; + } fp64 = { f }; + return fp64.as_bits; +#endif +} + +#endif /* FP16_BITCASTS_H */ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." 
+#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fp16/fp16.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fp16/fp16.h new file mode 100644 index 0000000000000000000000000000000000000000..4e29f28ccea774a363d6c624d53a4acc9783b4a3 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fp16/fp16.h @@ -0,0 +1,456 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +#ifndef FP16_FP16_H +#define FP16_FP16_H + +#if defined(__cplusplus) && (__cplusplus >= 201103L) + #include + #include +#elif !defined(__OPENCL_VERSION__) + #include + #include +#endif + +#ifdef _MSC_VER + #include +#endif + +#include + + +/* + * Convert a 16-bit floating-point number in IEEE half-precision format, in bit representation, to + * a 32-bit floating-point number in IEEE single-precision format, in bit representation. + * + * @note The implementation doesn't use any floating-point operations. + */ +static inline uint32_t fp16_ieee_to_fp32_bits(uint16_t h) { + /* + * Extend the half-precision floating-point number to 32 bits and shift to the upper part of the 32-bit word: + * +---+-----+------------+-------------------+ + * | S |EEEEE|MM MMMM MMMM|0000 0000 0000 0000| + * +---+-----+------------+-------------------+ + * Bits 31 26-30 16-25 0-15 + * + * S - sign bit, E - bits of the biased exponent, M - bits of the mantissa, 0 - zero bits. 
+ */ + const uint32_t w = (uint32_t) h << 16; + /* + * Extract the sign of the input number into the high bit of the 32-bit word: + * + * +---+----------------------------------+ + * | S |0000000 00000000 00000000 00000000| + * +---+----------------------------------+ + * Bits 31 0-31 + */ + const uint32_t sign = w & UINT32_C(0x80000000); + /* + * Extract mantissa and biased exponent of the input number into the bits 0-30 of the 32-bit word: + * + * +---+-----+------------+-------------------+ + * | 0 |EEEEE|MM MMMM MMMM|0000 0000 0000 0000| + * +---+-----+------------+-------------------+ + * Bits 30 27-31 17-26 0-16 + */ + const uint32_t nonsign = w & UINT32_C(0x7FFFFFFF); + /* + * Renorm shift is the number of bits to shift mantissa left to make the half-precision number normalized. + * If the initial number is normalized, some of its high 6 bits (sign == 0 and 5-bit exponent) equals one. + * In this case renorm_shift == 0. If the number is denormalize, renorm_shift > 0. Note that if we shift + * denormalized nonsign by renorm_shift, the unit bit of mantissa will shift into exponent, turning the + * biased exponent into 1, and making mantissa normalized (i.e. without leading 1). + */ +#ifdef _MSC_VER + unsigned long nonsign_bsr; + _BitScanReverse(&nonsign_bsr, (unsigned long) nonsign); + uint32_t renorm_shift = (uint32_t) nonsign_bsr ^ 31; +#else + uint32_t renorm_shift = __builtin_clz(nonsign); +#endif + renorm_shift = renorm_shift > 5 ? renorm_shift - 5 : 0; + /* + * Iff half-precision number has exponent of 15, the addition overflows it into bit 31, + * and the subsequent shift turns the high 9 bits into 1. Thus + * inf_nan_mask == + * 0x7F800000 if the half-precision number had exponent of 15 (i.e. was NaN or infinity) + * 0x00000000 otherwise + */ + const int32_t inf_nan_mask = ((int32_t) (nonsign + 0x04000000) >> 8) & INT32_C(0x7F800000); + /* + * Iff nonsign is 0, it overflows into 0xFFFFFFFF, turning bit 31 into 1. Otherwise, bit 31 remains 0. 
+ * The signed shift right by 31 broadcasts bit 31 into all bits of the zero_mask. Thus + * zero_mask == + * 0xFFFFFFFF if the half-precision number was zero (+0.0h or -0.0h) + * 0x00000000 otherwise + */ + const int32_t zero_mask = (int32_t) (nonsign - 1) >> 31; + /* + * 1. Shift nonsign left by renorm_shift to normalize it (if the input was denormal) + * 2. Shift nonsign right by 3 so the exponent (5 bits originally) becomes an 8-bit field and 10-bit mantissa + * shifts into the 10 high bits of the 23-bit mantissa of IEEE single-precision number. + * 3. Add 0x70 to the exponent (starting at bit 23) to compensate the different in exponent bias + * (0x7F for single-precision number less 0xF for half-precision number). + * 4. Subtract renorm_shift from the exponent (starting at bit 23) to account for renormalization. As renorm_shift + * is less than 0x70, this can be combined with step 3. + * 5. Binary OR with inf_nan_mask to turn the exponent into 0xFF if the input was NaN or infinity. + * 6. Binary ANDNOT with zero_mask to turn the mantissa and exponent into zero if the input was zero. + * 7. Combine with the sign of the input number. + */ + return sign | ((((nonsign << renorm_shift >> 3) + ((0x70 - renorm_shift) << 23)) | inf_nan_mask) & ~zero_mask); +} + +/* + * Convert a 16-bit floating-point number in IEEE half-precision format, in bit representation, to + * a 32-bit floating-point number in IEEE single-precision format. + * + * @note The implementation relies on IEEE-like (no assumption about rounding mode and no operations on denormals) + * floating-point operations and bitcasts between integer and floating-point variables. 
+ */ +static inline float fp16_ieee_to_fp32_value(uint16_t h) { + /* + * Extend the half-precision floating-point number to 32 bits and shift to the upper part of the 32-bit word: + * +---+-----+------------+-------------------+ + * | S |EEEEE|MM MMMM MMMM|0000 0000 0000 0000| + * +---+-----+------------+-------------------+ + * Bits 31 26-30 16-25 0-15 + * + * S - sign bit, E - bits of the biased exponent, M - bits of the mantissa, 0 - zero bits. + */ + const uint32_t w = (uint32_t) h << 16; + /* + * Extract the sign of the input number into the high bit of the 32-bit word: + * + * +---+----------------------------------+ + * | S |0000000 00000000 00000000 00000000| + * +---+----------------------------------+ + * Bits 31 0-31 + */ + const uint32_t sign = w & UINT32_C(0x80000000); + /* + * Extract mantissa and biased exponent of the input number into the high bits of the 32-bit word: + * + * +-----+------------+---------------------+ + * |EEEEE|MM MMMM MMMM|0 0000 0000 0000 0000| + * +-----+------------+---------------------+ + * Bits 27-31 17-26 0-16 + */ + const uint32_t two_w = w + w; + + /* + * Shift mantissa and exponent into bits 23-28 and bits 13-22 so they become mantissa and exponent + * of a single-precision floating-point number: + * + * S|Exponent | Mantissa + * +-+---+-----+------------+----------------+ + * |0|000|EEEEE|MM MMMM MMMM|0 0000 0000 0000| + * +-+---+-----+------------+----------------+ + * Bits | 23-31 | 0-22 + * + * Next, there are some adjustments to the exponent: + * - The exponent needs to be corrected by the difference in exponent bias between single-precision and half-precision + * formats (0x7F - 0xF = 0x70) + * - Inf and NaN values in the inputs should become Inf and NaN values after conversion to the single-precision number. + * Therefore, if the biased exponent of the half-precision input was 0x1F (max possible value), the biased exponent + * of the single-precision output must be 0xFF (max possible value). 
We do this correction in two steps: + * - First, we adjust the exponent by (0xFF - 0x1F) = 0xE0 (see exp_offset below) rather than by 0x70 suggested + * by the difference in the exponent bias (see above). + * - Then we multiply the single-precision result of exponent adjustment by 2**(-112) to reverse the effect of + * exponent adjustment by 0xE0 less the necessary exponent adjustment by 0x70 due to difference in exponent bias. + * The floating-point multiplication hardware would ensure than Inf and NaN would retain their value on at least + * partially IEEE754-compliant implementations. + * + * Note that the above operations do not handle denormal inputs (where biased exponent == 0). However, they also do not + * operate on denormal inputs, and do not produce denormal results. + */ + const uint32_t exp_offset = UINT32_C(0xE0) << 23; +#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__) + const float exp_scale = 0x1.0p-112f; +#else + const float exp_scale = fp32_from_bits(UINT32_C(0x7800000)); +#endif + const float normalized_value = fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale; + + /* + * Convert denormalized half-precision inputs into single-precision results (always normalized). + * Zero inputs are also handled here. + * + * In a denormalized number the biased exponent is zero, and mantissa has on-zero bits. + * First, we shift mantissa into bits 0-9 of the 32-bit word. + * + * zeros | mantissa + * +---------------------------+------------+ + * |0000 0000 0000 0000 0000 00|MM MMMM MMMM| + * +---------------------------+------------+ + * Bits 10-31 0-9 + * + * Now, remember that denormalized half-precision numbers are represented as: + * FP16 = mantissa * 2**(-24). + * The trick is to construct a normalized single-precision number with the same mantissa and thehalf-precision input + * and with an exponent which would scale the corresponding mantissa bits to 2**(-24). 
+ * A normalized single-precision floating-point number is represented as: + * FP32 = (1 + mantissa * 2**(-23)) * 2**(exponent - 127) + * Therefore, when the biased exponent is 126, a unit change in the mantissa of the input denormalized half-precision + * number causes a change of the constructud single-precision number by 2**(-24), i.e. the same ammount. + * + * The last step is to adjust the bias of the constructed single-precision number. When the input half-precision number + * is zero, the constructed single-precision number has the value of + * FP32 = 1 * 2**(126 - 127) = 2**(-1) = 0.5 + * Therefore, we need to subtract 0.5 from the constructed single-precision number to get the numerical equivalent of + * the input half-precision number. + */ + const uint32_t magic_mask = UINT32_C(126) << 23; + const float magic_bias = 0.5f; + const float denormalized_value = fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias; + + /* + * - Choose either results of conversion of input as a normalized number, or as a denormalized number, depending on the + * input exponent. The variable two_w contains input exponent in bits 27-31, therefore if its smaller than 2**27, the + * input is either a denormal number, or zero. + * - Combine the result of conversion of exponent and mantissa with the sign of the input number. + */ + const uint32_t denormalized_cutoff = UINT32_C(1) << 27; + const uint32_t result = sign | + (two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value) : fp32_to_bits(normalized_value)); + return fp32_from_bits(result); +} + +/* + * Convert a 32-bit floating-point number in IEEE single-precision format to a 16-bit floating-point number in + * IEEE half-precision format, in bit representation. + * + * @note The implementation relies on IEEE-like (no assumption about rounding mode and no operations on denormals) + * floating-point operations and bitcasts between integer and floating-point variables. 
+ */ +static inline uint16_t fp16_ieee_from_fp32_value(float f) { +#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__) + const float scale_to_inf = 0x1.0p+112f; + const float scale_to_zero = 0x1.0p-110f; +#else + const float scale_to_inf = fp32_from_bits(UINT32_C(0x77800000)); + const float scale_to_zero = fp32_from_bits(UINT32_C(0x08800000)); +#endif + float base = (fabsf(f) * scale_to_inf) * scale_to_zero; + + const uint32_t w = fp32_to_bits(f); + const uint32_t shl1_w = w + w; + const uint32_t sign = w & UINT32_C(0x80000000); + uint32_t bias = shl1_w & UINT32_C(0xFF000000); + if (bias < UINT32_C(0x71000000)) { + bias = UINT32_C(0x71000000); + } + + base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base; + const uint32_t bits = fp32_to_bits(base); + const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00); + const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF); + const uint32_t nonsign = exp_bits + mantissa_bits; + return (sign >> 16) | (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign); +} + +/* + * Convert a 16-bit floating-point number in ARM alternative half-precision format, in bit representation, to + * a 32-bit floating-point number in IEEE single-precision format, in bit representation. + * + * @note The implementation doesn't use any floating-point operations. + */ +static inline uint32_t fp16_alt_to_fp32_bits(uint16_t h) { + /* + * Extend the half-precision floating-point number to 32 bits and shift to the upper part of the 32-bit word: + * +---+-----+------------+-------------------+ + * | S |EEEEE|MM MMMM MMMM|0000 0000 0000 0000| + * +---+-----+------------+-------------------+ + * Bits 31 26-30 16-25 0-15 + * + * S - sign bit, E - bits of the biased exponent, M - bits of the mantissa, 0 - zero bits. 
+ */ + const uint32_t w = (uint32_t) h << 16; + /* + * Extract the sign of the input number into the high bit of the 32-bit word: + * + * +---+----------------------------------+ + * | S |0000000 00000000 00000000 00000000| + * +---+----------------------------------+ + * Bits 31 0-31 + */ + const uint32_t sign = w & UINT32_C(0x80000000); + /* + * Extract mantissa and biased exponent of the input number into the bits 0-30 of the 32-bit word: + * + * +---+-----+------------+-------------------+ + * | 0 |EEEEE|MM MMMM MMMM|0000 0000 0000 0000| + * +---+-----+------------+-------------------+ + * Bits 30 27-31 17-26 0-16 + */ + const uint32_t nonsign = w & UINT32_C(0x7FFFFFFF); + /* + * Renorm shift is the number of bits to shift mantissa left to make the half-precision number normalized. + * If the initial number is normalized, some of its high 6 bits (sign == 0 and 5-bit exponent) equals one. + * In this case renorm_shift == 0. If the number is denormalize, renorm_shift > 0. Note that if we shift + * denormalized nonsign by renorm_shift, the unit bit of mantissa will shift into exponent, turning the + * biased exponent into 1, and making mantissa normalized (i.e. without leading 1). + */ +#ifdef _MSC_VER + unsigned long nonsign_bsr; + _BitScanReverse(&nonsign_bsr, (unsigned long) nonsign); + uint32_t renorm_shift = (uint32_t) nonsign_bsr ^ 31; +#else + uint32_t renorm_shift = __builtin_clz(nonsign); +#endif + renorm_shift = renorm_shift > 5 ? renorm_shift - 5 : 0; + /* + * Iff nonsign is 0, it overflows into 0xFFFFFFFF, turning bit 31 into 1. Otherwise, bit 31 remains 0. + * The signed shift right by 31 broadcasts bit 31 into all bits of the zero_mask. Thus + * zero_mask == + * 0xFFFFFFFF if the half-precision number was zero (+0.0h or -0.0h) + * 0x00000000 otherwise + */ + const int32_t zero_mask = (int32_t) (nonsign - 1) >> 31; + /* + * 1. Shift nonsign left by renorm_shift to normalize it (if the input was denormal) + * 2. 
Shift nonsign right by 3 so the exponent (5 bits originally) becomes an 8-bit field and 10-bit mantissa + * shifts into the 10 high bits of the 23-bit mantissa of IEEE single-precision number. + * 3. Add 0x70 to the exponent (starting at bit 23) to compensate the different in exponent bias + * (0x7F for single-precision number less 0xF for half-precision number). + * 4. Subtract renorm_shift from the exponent (starting at bit 23) to account for renormalization. As renorm_shift + * is less than 0x70, this can be combined with step 3. + * 5. Binary ANDNOT with zero_mask to turn the mantissa and exponent into zero if the input was zero. + * 6. Combine with the sign of the input number. + */ + return sign | (((nonsign << renorm_shift >> 3) + ((0x70 - renorm_shift) << 23)) & ~zero_mask); +} + +/* + * Convert a 16-bit floating-point number in ARM alternative half-precision format, in bit representation, to + * a 32-bit floating-point number in IEEE single-precision format. + * + * @note The implementation relies on IEEE-like (no assumption about rounding mode and no operations on denormals) + * floating-point operations and bitcasts between integer and floating-point variables. + */ +static inline float fp16_alt_to_fp32_value(uint16_t h) { + /* + * Extend the half-precision floating-point number to 32 bits and shift to the upper part of the 32-bit word: + * +---+-----+------------+-------------------+ + * | S |EEEEE|MM MMMM MMMM|0000 0000 0000 0000| + * +---+-----+------------+-------------------+ + * Bits 31 26-30 16-25 0-15 + * + * S - sign bit, E - bits of the biased exponent, M - bits of the mantissa, 0 - zero bits. 
+ */ + const uint32_t w = (uint32_t) h << 16; + /* + * Extract the sign of the input number into the high bit of the 32-bit word: + * + * +---+----------------------------------+ + * | S |0000000 00000000 00000000 00000000| + * +---+----------------------------------+ + * Bits 31 0-31 + */ + const uint32_t sign = w & UINT32_C(0x80000000); + /* + * Extract mantissa and biased exponent of the input number into the high bits of the 32-bit word: + * + * +-----+------------+---------------------+ + * |EEEEE|MM MMMM MMMM|0 0000 0000 0000 0000| + * +-----+------------+---------------------+ + * Bits 27-31 17-26 0-16 + */ + const uint32_t two_w = w + w; + + /* + * Shift mantissa and exponent into bits 23-28 and bits 13-22 so they become mantissa and exponent + * of a single-precision floating-point number: + * + * S|Exponent | Mantissa + * +-+---+-----+------------+----------------+ + * |0|000|EEEEE|MM MMMM MMMM|0 0000 0000 0000| + * +-+---+-----+------------+----------------+ + * Bits | 23-31 | 0-22 + * + * Next, the exponent is adjusted for the difference in exponent bias between single-precision and half-precision + * formats (0x7F - 0xF = 0x70). This operation never overflows or generates non-finite values, as the largest + * half-precision exponent is 0x1F and after the adjustment is can not exceed 0x8F < 0xFE (largest single-precision + * exponent for non-finite values). + * + * Note that this operation does not handle denormal inputs (where biased exponent == 0). However, they also do not + * operate on denormal inputs, and do not produce denormal results. + */ + const uint32_t exp_offset = UINT32_C(0x70) << 23; + const float normalized_value = fp32_from_bits((two_w >> 4) + exp_offset); + + /* + * Convert denormalized half-precision inputs into single-precision results (always normalized). + * Zero inputs are also handled here. + * + * In a denormalized number the biased exponent is zero, and mantissa has on-zero bits. 
+ * First, we shift mantissa into bits 0-9 of the 32-bit word. + * + * zeros | mantissa + * +---------------------------+------------+ + * |0000 0000 0000 0000 0000 00|MM MMMM MMMM| + * +---------------------------+------------+ + * Bits 10-31 0-9 + * + * Now, remember that denormalized half-precision numbers are represented as: + * FP16 = mantissa * 2**(-24). + * The trick is to construct a normalized single-precision number with the same mantissa and thehalf-precision input + * and with an exponent which would scale the corresponding mantissa bits to 2**(-24). + * A normalized single-precision floating-point number is represented as: + * FP32 = (1 + mantissa * 2**(-23)) * 2**(exponent - 127) + * Therefore, when the biased exponent is 126, a unit change in the mantissa of the input denormalized half-precision + * number causes a change of the constructud single-precision number by 2**(-24), i.e. the same ammount. + * + * The last step is to adjust the bias of the constructed single-precision number. When the input half-precision number + * is zero, the constructed single-precision number has the value of + * FP32 = 1 * 2**(126 - 127) = 2**(-1) = 0.5 + * Therefore, we need to subtract 0.5 from the constructed single-precision number to get the numerical equivalent of + * the input half-precision number. + */ + const uint32_t magic_mask = UINT32_C(126) << 23; + const float magic_bias = 0.5f; + const float denormalized_value = fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias; + + /* + * - Choose either results of conversion of input as a normalized number, or as a denormalized number, depending on the + * input exponent. The variable two_w contains input exponent in bits 27-31, therefore if its smaller than 2**27, the + * input is either a denormal number, or zero. + * - Combine the result of conversion of exponent and mantissa with the sign of the input number. 
+ */ + const uint32_t denormalized_cutoff = UINT32_C(1) << 27; + const uint32_t result = sign | + (two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value) : fp32_to_bits(normalized_value)); + return fp32_from_bits(result); +} + +/* + * Convert a 32-bit floating-point number in IEEE single-precision format to a 16-bit floating-point number in + * ARM alternative half-precision format, in bit representation. + * + * @note The implementation relies on IEEE-like (no assumption about rounding mode and no operations on denormals) + * floating-point operations and bitcasts between integer and floating-point variables. + */ +static inline uint16_t fp16_alt_from_fp32_value(float f) { + const uint32_t w = fp32_to_bits(f); + const uint32_t sign = w & UINT32_C(0x80000000); + const uint32_t shl1_w = w + w; + + const uint32_t shl1_max_fp16_fp32 = UINT32_C(0x8FFFC000); + const uint32_t shl1_base = shl1_w > shl1_max_fp16_fp32 ? shl1_max_fp16_fp32 : shl1_w; + uint32_t shl1_bias = shl1_base & UINT32_C(0xFF000000); + const uint32_t exp_difference = 23 - 10; + const uint32_t shl1_bias_min = (127 - 1 - exp_difference) << 24; + if (shl1_bias < shl1_bias_min) { + shl1_bias = shl1_bias_min; + } + + const float bias = fp32_from_bits((shl1_bias >> 1) + ((exp_difference + 2) << 23)); + const float base = fp32_from_bits((shl1_base >> 1) + (2 << 23)) + bias; + + const uint32_t exp_f = fp32_to_bits(base) >> 13; + return (sign >> 16) | ((exp_f & UINT32_C(0x00007C00)) + (fp32_to_bits(base) & UINT32_C(0x00000FFF))); +} + +#endif /* FP16_FP16_H */ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." 
+#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fp16/psimd.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fp16/psimd.h new file mode 100644 index 0000000000000000000000000000000000000000..346df7b3de42625fb0c1b6774b614ffce0e94993 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/fp16/psimd.h @@ -0,0 +1,136 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +#pragma once +#ifndef FP16_PSIMD_H +#define FP16_PSIMD_H + +#if defined(__cplusplus) && (__cplusplus >= 201103L) + #include +#elif !defined(__OPENCL_VERSION__) + #include +#endif + +#include + + +PSIMD_INTRINSIC psimd_f32 fp16_ieee_to_fp32_psimd(psimd_u16 half) { + const psimd_u32 word = (psimd_u32) psimd_interleave_lo_u16(psimd_zero_u16(), half); + + const psimd_u32 sign = word & psimd_splat_u32(UINT32_C(0x80000000)); + const psimd_u32 shr3_nonsign = (word + word) >> psimd_splat_u32(4); + + const psimd_u32 exp_offset = psimd_splat_u32(UINT32_C(0x70000000)); +#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__) + const psimd_f32 exp_scale = psimd_splat_f32(0x1.0p-112f); +#else + const psimd_f32 exp_scale = psimd_splat_f32(fp32_from_bits(UINT32_C(0x7800000))); +#endif + const psimd_f32 norm_nonsign = psimd_mul_f32((psimd_f32) (shr3_nonsign + exp_offset), exp_scale); + + const psimd_u16 magic_mask = psimd_splat_u16(UINT16_C(0x3E80)); + const psimd_f32 magic_bias = psimd_splat_f32(0.25f); + const psimd_f32 denorm_nonsign = psimd_sub_f32((psimd_f32) psimd_interleave_lo_u16(half + half, magic_mask), magic_bias); + + const psimd_s32 denorm_cutoff = psimd_splat_s32(INT32_C(0x00800000)); + const psimd_s32 denorm_mask = (psimd_s32) shr3_nonsign < denorm_cutoff; + return (psimd_f32) (sign | (psimd_s32) psimd_blend_f32(denorm_mask, denorm_nonsign, norm_nonsign)); +} 
+ +PSIMD_INTRINSIC psimd_f32x2 fp16_ieee_to_fp32x2_psimd(psimd_u16 half) { + const psimd_u32 word_lo = (psimd_u32) psimd_interleave_lo_u16(psimd_zero_u16(), half); + const psimd_u32 word_hi = (psimd_u32) psimd_interleave_hi_u16(psimd_zero_u16(), half); + + const psimd_u32 sign_mask = psimd_splat_u32(UINT32_C(0x80000000)); + const psimd_u32 sign_lo = word_lo & sign_mask; + const psimd_u32 sign_hi = word_hi & sign_mask; + const psimd_u32 shr3_nonsign_lo = (word_lo + word_lo) >> psimd_splat_u32(4); + const psimd_u32 shr3_nonsign_hi = (word_hi + word_hi) >> psimd_splat_u32(4); + + const psimd_u32 exp_offset = psimd_splat_u32(UINT32_C(0x70000000)); +#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__) + const psimd_f32 exp_scale = psimd_splat_f32(0x1.0p-112f); +#else + const psimd_f32 exp_scale = psimd_splat_f32(fp32_from_bits(UINT32_C(0x7800000))); +#endif + const psimd_f32 norm_nonsign_lo = psimd_mul_f32((psimd_f32) (shr3_nonsign_lo + exp_offset), exp_scale); + const psimd_f32 norm_nonsign_hi = psimd_mul_f32((psimd_f32) (shr3_nonsign_hi + exp_offset), exp_scale); + + const psimd_u16 magic_mask = psimd_splat_u16(UINT16_C(0x3E80)); + const psimd_u16 shl1_half = half + half; + const psimd_f32 magic_bias = psimd_splat_f32(0.25f); + const psimd_f32 denorm_nonsign_lo = psimd_sub_f32((psimd_f32) psimd_interleave_lo_u16(shl1_half, magic_mask), magic_bias); + const psimd_f32 denorm_nonsign_hi = psimd_sub_f32((psimd_f32) psimd_interleave_hi_u16(shl1_half, magic_mask), magic_bias); + + const psimd_s32 denorm_cutoff = psimd_splat_s32(INT32_C(0x00800000)); + const psimd_s32 denorm_mask_lo = (psimd_s32) shr3_nonsign_lo < denorm_cutoff; + const psimd_s32 denorm_mask_hi = (psimd_s32) shr3_nonsign_hi < denorm_cutoff; + + psimd_f32x2 result; + result.lo = (psimd_f32) (sign_lo | (psimd_s32) psimd_blend_f32(denorm_mask_lo, denorm_nonsign_lo, norm_nonsign_lo)); + result.hi = (psimd_f32) (sign_hi | (psimd_s32) 
psimd_blend_f32(denorm_mask_hi, denorm_nonsign_hi, norm_nonsign_hi)); + return result; +} + +PSIMD_INTRINSIC psimd_f32 fp16_alt_to_fp32_psimd(psimd_u16 half) { + const psimd_u32 word = (psimd_u32) psimd_interleave_lo_u16(psimd_zero_u16(), half); + + const psimd_u32 sign = word & psimd_splat_u32(INT32_C(0x80000000)); + const psimd_u32 shr3_nonsign = (word + word) >> psimd_splat_u32(4); + +#if 0 + const psimd_s32 exp112_offset = psimd_splat_s32(INT32_C(0x38000000)); + const psimd_s32 nonsign_bits = (psimd_s32) shr3_nonsign + exp112_offset; + const psimd_s32 exp1_offset = psimd_splat_s32(INT32_C(0x00800000)); + const psimd_f32 two_nonsign = (psimd_f32) (nonsign_bits + exp1_offset); + const psimd_s32 exp113_offset = exp112_offset | exp1_offset; + return (psimd_f32) (sign | (psimd_s32) psimd_sub_f32(two_nonsign, (psimd_f32) psimd_max_s32(nonsign_bits, exp113_offset))); +#else + const psimd_u32 exp_offset = psimd_splat_u32(UINT32_C(0x38000000)); + const psimd_f32 nonsign = (psimd_f32) (shr3_nonsign + exp_offset); +#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__) + const psimd_f32 denorm_bias = psimd_splat_f32(0x1.0p-14f); +#else + const psimd_f32 denorm_bias = psimd_splat_f32(fp32_from_bits(UINT32_C(0x38800000))); +#endif + return (psimd_f32) (sign | (psimd_s32) psimd_sub_f32(psimd_add_f32(nonsign, nonsign), psimd_max_f32(nonsign, denorm_bias))); +#endif +} + +PSIMD_INTRINSIC psimd_f32x2 fp16_alt_to_fp32x2_psimd(psimd_u16 half) { + const psimd_u32 word_lo = (psimd_u32) psimd_interleave_lo_u16(psimd_zero_u16(), half); + const psimd_u32 word_hi = (psimd_u32) psimd_interleave_hi_u16(psimd_zero_u16(), half); + + const psimd_u32 sign_mask = psimd_splat_u32(UINT32_C(0x80000000)); + const psimd_u32 sign_lo = word_lo & sign_mask; + const psimd_u32 sign_hi = word_hi & sign_mask; + const psimd_u32 shr3_nonsign_lo = (word_lo + word_lo) >> psimd_splat_u32(4); + const psimd_u32 shr3_nonsign_hi = (word_hi + word_hi) >> 
psimd_splat_u32(4); + +#if 1 + const psimd_s32 exp112_offset = psimd_splat_s32(INT32_C(0x38000000)); + const psimd_s32 nonsign_bits_lo = (psimd_s32) shr3_nonsign_lo + exp112_offset; + const psimd_s32 nonsign_bits_hi = (psimd_s32) shr3_nonsign_hi + exp112_offset; + const psimd_s32 exp1_offset = psimd_splat_s32(INT32_C(0x00800000)); + const psimd_f32 two_nonsign_lo = (psimd_f32) (nonsign_bits_lo + exp1_offset); + const psimd_f32 two_nonsign_hi = (psimd_f32) (nonsign_bits_hi + exp1_offset); + const psimd_s32 exp113_offset = exp1_offset | exp112_offset; + psimd_f32x2 result; + result.lo = (psimd_f32) (sign_lo | (psimd_s32) psimd_sub_f32(two_nonsign_lo, (psimd_f32) psimd_max_s32(nonsign_bits_lo, exp113_offset))); + result.hi = (psimd_f32) (sign_hi | (psimd_s32) psimd_sub_f32(two_nonsign_hi, (psimd_f32) psimd_max_s32(nonsign_bits_hi, exp113_offset))); + return result; +#else + const psimd_u32 exp_offset = psimd_splat_u32(UINT32_C(0x38000000)); + const psimd_f32 nonsign_lo = (psimd_f32) (shr3_nonsign_lo + exp_offset); + const psimd_f32 nonsign_hi = (psimd_f32) (shr3_nonsign_hi + exp_offset); + const psimd_f32 denorm_bias = psimd_splat_f32(0x1.0p-14f); + psimd_f32x2 result; + result.lo = (psimd_f32) (sign_lo | (psimd_s32) psimd_sub_f32(psimd_add_f32(nonsign_lo, nonsign_lo), psimd_max_f32(nonsign_lo, denorm_bias))); + result.hi = (psimd_f32) (sign_hi | (psimd_s32) psimd_sub_f32(psimd_add_f32(nonsign_hi, nonsign_hi), psimd_max_f32(nonsign_hi, denorm_bias))); + return result; +#endif +} + +#endif /* FP16_PSIMD_H */ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." 
+#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/kineto/AbstractConfig.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/kineto/AbstractConfig.h new file mode 100644 index 0000000000000000000000000000000000000000..ac90d4a31c29d03e72fa1d43bf7e747c75926678 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/kineto/AbstractConfig.h @@ -0,0 +1,128 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include +#include + +namespace libkineto { + +class AbstractConfig { + public: + AbstractConfig& operator=(const AbstractConfig&) = delete; + AbstractConfig(AbstractConfig&&) = delete; + AbstractConfig& operator=(AbstractConfig&&) = delete; + + virtual ~AbstractConfig() { + for (const auto& p : featureConfigs_) { + delete p.second; + } + } + + // Return a copy of the full derived class + virtual AbstractConfig* cloneDerived(AbstractConfig& parent) const = 0; + + // Returns true if successfully parsed the config string + bool parse(const std::string& conf); + + // Default setup for signal-triggered profiling + virtual void setSignalDefaults() { + for (auto& p : featureConfigs_) { + p.second->setSignalDefaults(); + } + } + + // Default setup for client-triggered profiling + virtual void setClientDefaults() { + for (auto& p : featureConfigs_) { + p.second->setClientDefaults(); + } + } + + // Time config was created / updated + std::chrono::time_point timestamp() const { + return timestamp_; + } + + // Source config string that this was parsed from + const std::string& source() const { + return source_; + } + + 
AbstractConfig& feature(const std::string& name) const { + const auto& pos = featureConfigs_.find(name); + return *pos->second; + } + + // Transfers ownership of cfg arg + void addFeature(const std::string& name, AbstractConfig* cfg) { + featureConfigs_[name] = cfg; + } + + protected: + AbstractConfig() {} + AbstractConfig(const AbstractConfig& other) = default; + + // Return true if the option was recognized and successfully parsed. + // Throw std::invalid_argument if val is invalid. + virtual bool handleOption(const std::string& name, std::string& val); + + // Perform post-validation checks, typically conditons involving + // multiple options. + // Throw std::invalid_argument if automatic correction can not be made. + // + // @param fallbackProfileStartTime Specify a fallback profile start timestamp + // in case it was never specified by the client + virtual void validate( + const std::chrono::time_point& + fallbackProfileStartTime) = 0; + + // TODO: Separate out each profiler type into features? 
+ virtual void printActivityProfilerConfig(std::ostream& s) const; + virtual void setActivityDependentConfig(); + + // Helpers for use in handleOption + // Split a string by delimiter and remove external white space + std::vector splitAndTrim(const std::string& s, char delim) const; + // Lowercase for case-insensitive comparisons + std::string toLower(std::string& s) const; + // Does string end with suffix + bool endsWith(const std::string& s, const std::string& suffix) const; + // Conversions + int64_t toIntRange(const std::string& val, int64_t min, int64_t max) const; + int32_t toInt32(const std::string& val) const; + int64_t toInt64(const std::string& val) const; + bool toBool(std::string& val) const; + + void cloneFeaturesInto(AbstractConfig& cfg) const { + for (const auto& feature : featureConfigs_) { + cfg.featureConfigs_[feature.first] = feature.second->cloneDerived(cfg); + } + } + + private: + // Time config was created / updated + std::chrono::time_point timestamp_{}; + + // Original configuration string, used for comparison + std::string source_; + + // Configuration objects for optional features + std::map featureConfigs_{}; +}; + +} // namespace libkineto + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/kineto/ActivityProfilerInterface.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/kineto/ActivityProfilerInterface.h new file mode 100644 index 0000000000000000000000000000000000000000..b179c703d52849a10fa447a63d7fab697fb946d6 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/kineto/ActivityProfilerInterface.h @@ -0,0 +1,113 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. 
+ * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include + +#include "ActivityTraceInterface.h" +#include "ActivityType.h" +#include "IActivityProfiler.h" + +namespace libkineto { + +class ActivityProfilerController; +struct CpuTraceBuffer; +class Config; + +class ActivityProfilerInterface { + public: + virtual ~ActivityProfilerInterface() {} + + virtual void init() {} + virtual bool isInitialized() { + return false; + } + virtual bool isActive() { + return false; + } + + // *** Asynchronous API *** + // Instead of starting and stopping the trace manually, provide a start time + // and duration and / or iteration stop criterion. + // Tracing terminates when either condition is met. + virtual void scheduleTrace(const std::string& configStr) {} + + // *** Synchronous API *** + // These must be called in order: + // prepareTrace -> startTrace -> stopTrace. + + // Many tracing structures are lazily initialized during trace collection, + // with potentially high overhead. + // Call prepareTrace to enable tracing, then run the region to trace + // at least once (and ideally run the same code that is to be traced) to + // allow tracing structures to be initialized. + virtual void prepareTrace( + const std::set& activityTypes, + const std::string& configStr = "") {} + + // Toggle GPU tracing as a trace is running to omit certain parts of a graph + virtual void toggleCollectionDynamic(const bool enable) {} + + // Start recording, potentially reusing any buffers allocated since + // prepareTrace was called. + virtual void startTrace() {} + + // Stop and process trace, producing an in-memory list of trace records. + // The processing will be done synchronously (using the calling thread.) 
+ virtual std::unique_ptr stopTrace() { + return nullptr; + } + + // Re-evaluate internal state to allow for triggering operations based + // on number of iteration. each implicitly increments the iteration count + virtual void step() {} + + // *** TraceActivity API *** + // FIXME: Pass activityProfiler interface into clientInterface? + virtual void pushCorrelationId(uint64_t id) {} + virtual void popCorrelationId() {} + virtual void transferCpuTrace(std::unique_ptr traceBuffer) {} + + // Correlation ids for user defined spans + virtual void pushUserCorrelationId(uint64_t) {} + virtual void popUserCorrelationId() {} + + // Saves information for the current thread to be used in profiler output + // Client must record any new kernel thread where the activity has occured. + virtual void recordThreadInfo() {} + + // Record trace metadata, currently supporting only string key and values, + // values with the same key are overwritten + virtual void addMetadata( + const std::string& key, + const std::string& value) = 0; + + // Add a child activity profiler, this enables frameworks in the application + // to enable custom framework events. + virtual void addChildActivityProfiler( + std::unique_ptr profiler) {} + + // Log Invariant Violation to factories enabled. This helps record + // instances when the profiler behaves unexpectedly. + virtual void logInvariantViolation( + const std::string&, + const std::string&, + const std::string&, + const std::string& = "") {} +}; + +} // namespace libkineto + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." 
+#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/kineto/ActivityTraceInterface.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/kineto/ActivityTraceInterface.h new file mode 100644 index 0000000000000000000000000000000000000000..ccf65f0671af7a2ddc69e11f29011f7d927f46e0 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/kineto/ActivityTraceInterface.h @@ -0,0 +1,33 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include + +namespace libkineto { + +struct ITraceActivity; + +class ActivityTraceInterface { + public: + virtual ~ActivityTraceInterface() {} + virtual const std::vector* activities() { + return nullptr; + } + virtual void save(const std::string& path) {} +}; + +} // namespace libkineto + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/kineto/ActivityType.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/kineto/ActivityType.h new file mode 100644 index 0000000000000000000000000000000000000000..832271f2cad7993c41f7b89617f727d7a9035357 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/kineto/ActivityType.h @@ -0,0 +1,71 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. 
+ * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include + +namespace libkineto { + +// Note : All activity types are not enabled by default. Please add them +// at correct position in the enum +enum class ActivityType { + // Activity types enabled by default + CPU_OP = 0, // cpu side ops + USER_ANNOTATION, + GPU_USER_ANNOTATION, + GPU_MEMCPY, + GPU_MEMSET, + CONCURRENT_KERNEL, // on-device kernels + EXTERNAL_CORRELATION, + CUDA_RUNTIME, // host side cuda runtime events + CUDA_DRIVER, // host side cuda driver events + CPU_INSTANT_EVENT, // host side point-like events + PYTHON_FUNCTION, + OVERHEAD, // CUPTI induced overhead events sampled from its overhead API. + MTIA_RUNTIME, // host side MTIA runtime events + MTIA_CCP_EVENTS, // MTIA ondevice CCP events + MTIA_INSIGHT, // MTIA Insight Events + CUDA_SYNC, // synchronization events between runtime and kernels + CUDA_EVENT, // CUDA event activities (cudaEventRecord, etc.) + + // Optional Activity types + GLOW_RUNTIME, // host side glow runtime events + CUDA_PROFILER_RANGE, // CUPTI Profiler range for performance metrics + HPU_OP, // HPU host side runtime event + XPU_RUNTIME, // host side xpu runtime events + COLLECTIVE_COMM, // collective communication + + // PRIVATEUSE1 Activity types are used for custom backends. + // The corresponding device type is `DeviceType::PrivateUse1` in PyTorch. + PRIVATEUSE1_RUNTIME, // host side privateUse1 runtime events + PRIVATEUSE1_DRIVER, // host side privateUse1 driver events + + ENUM_COUNT, // This is to add buffer and not used for any profiling logic. Add + // your new type before it. 
+ OPTIONAL_ACTIVITY_TYPE_START = GLOW_RUNTIME, +}; + +const char* toString(ActivityType t); +ActivityType toActivityType(const std::string& str); + +// Return an array of all activity types except COUNT +constexpr int activityTypeCount = (int)ActivityType::ENUM_COUNT; +constexpr int defaultActivityTypeCount = + (int)ActivityType::OPTIONAL_ACTIVITY_TYPE_START; +const std::array activityTypes(); +const std::array defaultActivityTypes(); + +} // namespace libkineto + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/kineto/ClientInterface.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/kineto/ClientInterface.h new file mode 100644 index 0000000000000000000000000000000000000000..b017252cd43eff6ff0d6afef7daab3bcd7efabe9 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/kineto/ClientInterface.h @@ -0,0 +1,31 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +namespace libkineto { + +class ClientInterface { + public: + virtual ~ClientInterface() {} + virtual void init() = 0; + virtual void prepare(bool, bool, bool, bool, bool) = 0; + virtual void start() = 0; + virtual void stop() = 0; + virtual void start_memory_profile() = 0; + virtual void stop_memory_profile() = 0; + virtual void export_memory_profile(const std::string&) = 0; +}; + +} // namespace libkineto + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." 
+#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/kineto/Config.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/kineto/Config.h new file mode 100644 index 0000000000000000000000000000000000000000..81e356b596f8b613b5095b7d3732d942b0382d14 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/kineto/Config.h @@ -0,0 +1,549 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include "AbstractConfig.h" +#include "ActivityType.h" + +#include +#include +#include +#include +#include +#include + +namespace libkineto { + +class Config : public AbstractConfig { + public: + Config(); + Config& operator=(const Config&) = delete; + Config(Config&&) = delete; + Config& operator=(Config&&) = delete; + ~Config() override = default; + + // Return a full copy including feature config object + std::unique_ptr clone() const { + auto cfg = std::unique_ptr(new Config(*this)); + cloneFeaturesInto(*cfg); + return cfg; + } + + bool handleOption(const std::string& name, std::string& val) override; + + void setClientDefaults() override; + + // Log events to this file + const std::string& eventLogFile() const { + return eventLogFile_; + } + + bool activityProfilerEnabled() const { + return activityProfilerEnabled_ || + activitiesOnDemandTimestamp_.time_since_epoch().count() > 0; + } + + // Log activitiy trace to this file + const std::string& activitiesLogFile() const { + return activitiesLogFile_; + } + + // Log activitiy trace to this url + const std::string& activitiesLogUrl() const { + return activitiesLogUrl_; + } + + void setActivitiesLogUrl(const 
std::string& url) { + activitiesLogUrl_ = url; + } + + bool activitiesLogToMemory() const { + return activitiesLogToMemory_; + } + + bool eventProfilerEnabled() const { + return !eventNames_.empty() || !metricNames_.empty(); + } + + // Is profiling enabled for the given device? + bool eventProfilerEnabledForDevice(uint32_t dev) const { + return 0 != (eventProfilerDeviceMask_ & (1 << dev)); + } + + // Take a sample (read hardware counters) at this frequency. + // This controls how often counters are read - if all counters cannot + // be collected simultaneously then multiple samples are needed to + // collect all requested counters - see multiplex period. + std::chrono::milliseconds samplePeriod() const { + return samplePeriod_; + } + + void setSamplePeriod(std::chrono::milliseconds period) { + samplePeriod_ = period; + } + + // When all requested counters cannot be collected simultaneously, + // counters will be multiplexed at this frequency. + // Multiplexing can have a large performance impact if done frequently. + // To avoid a perf impact, keep this at 1s or above. + std::chrono::milliseconds multiplexPeriod() const { + return multiplexPeriod_; + } + + void setMultiplexPeriod(std::chrono::milliseconds period) { + multiplexPeriod_ = period; + } + + // Report counters at this frequency. Note that several samples can + // be reported each time, see samplesPerReport. + std::chrono::milliseconds reportPeriod() const { + return reportPeriod_; + } + + void setReportPeriod(std::chrono::milliseconds msecs); + + // Number of samples dispatched each report period. + // Must be in the range [1, report period / sample period]. + // In other words, aggregation is supported but not interpolation. 
+ int samplesPerReport() const { + return samplesPerReport_; + } + + void setSamplesPerReport(int count) { + samplesPerReport_ = count; + } + + // The names of events to collect + const std::set& eventNames() const { + return eventNames_; + } + + // Add additional events to be profiled + void addEvents(const std::set& names) { + eventNames_.insert(names.begin(), names.end()); + } + + // The names of metrics to collect + const std::set& metricNames() const { + return metricNames_; + } + + // Add additional metrics to be profiled + void addMetrics(const std::set& names) { + metricNames_.insert(names.begin(), names.end()); + } + + const std::vector& percentiles() const { + return eventReportPercentiles_; + } + + // Profile for this long, then revert to base config + std::chrono::seconds eventProfilerOnDemandDuration() const { + return eventProfilerOnDemandDuration_; + } + + void setEventProfilerOnDemandDuration(std::chrono::seconds duration) { + eventProfilerOnDemandDuration_ = duration; + } + + // Too many event profilers on a single system can overload the driver. + // At some point, latencies shoot through the roof and collection of samples + // becomes impossible. To avoid this situation we have a limit of profilers + // per GPU. + // NOTE: Communication with a daemon is needed for this feature. + // Library must be built with an active DaemonConfigLoader. + int maxEventProfilersPerGpu() const { + return eventProfilerMaxInstancesPerGpu_; + } + + // On Cuda11 we've seen occasional hangs when reprogramming counters + // Monitor profiling threads and report when a thread is not responding + // for a given number of seconds. + // A period of 0 means disable. 
+ std::chrono::seconds eventProfilerHeartbeatMonitorPeriod() const { + return eventProfilerHeartbeatMonitorPeriod_; + } + + // The types of activities selected in the configuration file + const std::set& selectedActivityTypes() const { + return selectedActivityTypes_; + } + + // Set the types of activities to be traced + bool perThreadBufferEnabled() const { + return perThreadBufferEnabled_; + } + + void setSelectedActivityTypes(const std::set& types) { + selectedActivityTypes_ = types; + } + + bool isReportInputShapesEnabled() const { + return enableReportInputShapes_; + } + + bool isProfileMemoryEnabled() const { + return enableProfileMemory_; + } + + bool isWithStackEnabled() const { + return enableWithStack_; + } + + bool isWithFlopsEnabled() const { + return enableWithFlops_; + } + + bool isWithModulesEnabled() const { + return enableWithModules_; + } + + // Trace for this long + std::chrono::milliseconds activitiesDuration() const { + return activitiesDuration_; + } + + // Trace for this many iterations, determined by external API + int activitiesRunIterations() const { + return activitiesRunIterations_; + } + + int activitiesMaxGpuBufferSize() const { + return activitiesMaxGpuBufferSize_; + } + + std::chrono::seconds activitiesWarmupDuration() const { + return activitiesWarmupDuration_; + } + + int activitiesWarmupIterations() const { + return activitiesWarmupIterations_; + } + + // Show CUDA Synchronization Stream Wait Events + bool activitiesCudaSyncWaitEvents() const { + return activitiesCudaSyncWaitEvents_; + } + + void setActivitiesCudaSyncWaitEvents(bool enable) { + activitiesCudaSyncWaitEvents_ = enable; + } + + // Timestamp at which the profiling to start, requested by the user. + const std::chrono::time_point requestTimestamp() + const { + if (profileStartTime_.time_since_epoch().count()) { + return profileStartTime_; + } + // If no one requested timestamp, return 0. 
+ if (requestTimestamp_.time_since_epoch().count() == 0) { + return requestTimestamp_; + } + + // TODO(T94634890): Deprecate requestTimestamp + return requestTimestamp_ + maxRequestAge() + activitiesWarmupDuration(); + } + + bool hasProfileStartTime() const { + return requestTimestamp_.time_since_epoch().count() > 0 || + profileStartTime_.time_since_epoch().count() > 0; + } + + int profileStartIteration() const { + return profileStartIteration_; + } + + bool hasProfileStartIteration() const { + return profileStartIteration_ >= 0 && activitiesRunIterations_ > 0; + } + + void setProfileStartIteration(int iter) { + profileStartIteration_ = iter; + } + + int profileStartIterationRoundUp() const { + return profileStartIterationRoundUp_; + } + + // calculate the start iteration accounting for warmup + int startIterationIncludingWarmup() const { + if (!hasProfileStartIteration()) { + return -1; + } + return profileStartIteration_ - activitiesWarmupIterations_; + } + + const std::chrono::seconds maxRequestAge() const; + + // All VLOG* macros will log if the verbose log level is >= + // the verbosity specified for the verbose log message. + // Default value is -1, so messages with log level 0 will log by default. + int verboseLogLevel() const { + return verboseLogLevel_; + } + + // Modules for which verbose logging is enabled. + // If empty, logging is enabled for all modules. 
+ const std::vector& verboseLogModules() const { + return verboseLogModules_; + } + + bool sigUsr2Enabled() const { + return enableSigUsr2_; + } + + bool ipcFabricEnabled() const { + return enableIpcFabric_; + } + + std::chrono::seconds onDemandConfigUpdateIntervalSecs() const { + return onDemandConfigUpdateIntervalSecs_; + } + + static std::chrono::milliseconds alignUp( + std::chrono::milliseconds duration, + std::chrono::milliseconds alignment) { + duration += alignment; + return duration - (duration % alignment); + } + + std::chrono::time_point + eventProfilerOnDemandStartTime() const { + return eventProfilerOnDemandTimestamp_; + } + + std::chrono::time_point + eventProfilerOnDemandEndTime() const { + return eventProfilerOnDemandTimestamp_ + eventProfilerOnDemandDuration_; + } + + std::chrono::time_point + activityProfilerRequestReceivedTime() const { + return activitiesOnDemandTimestamp_; + } + + static constexpr std::chrono::milliseconds kControllerIntervalMsecs{1000}; + + // Users may request and set trace id and group trace id. 
+ const std::string& requestTraceID() const { + return requestTraceID_; + } + + void setRequestTraceID(const std::string& tid) { + requestTraceID_ = tid; + } + + const std::string& requestGroupTraceID() const { + return requestGroupTraceID_; + } + + void setRequestGroupTraceID(const std::string& gtid) { + requestGroupTraceID_ = gtid; + } + + size_t cuptiDeviceBufferSize() const { + return cuptiDeviceBufferSize_; + } + + size_t cuptiDeviceBufferPoolLimit() const { + return cuptiDeviceBufferPoolLimit_; + } + + bool memoryProfilerEnabled() const { + return memoryProfilerEnabled_; + } + + int profileMemoryDuration() const { + return profileMemoryDuration_; + } + void updateActivityProfilerRequestReceivedTime(); + + void printActivityProfilerConfig(std::ostream& s) const override; + void setActivityDependentConfig() override; + + void validate( + const std::chrono::time_point& + fallbackProfileStartTime) override; + + static void addConfigFactory( + std::string name, + std::function factory); + + void print(std::ostream& s) const; + + // Config relies on some state with global static lifetime. If other + // threads are using the config, it's possible that the global state + // is destroyed before the threads stop. By hanging onto this handle, + // correct destruction order can be ensured. 
+ static std::shared_ptr getStaticObjectsLifetimeHandle(); + + bool getTSCTimestampFlag() const { + return useTSCTimestamp_; + } + + void setTSCTimestampFlag(bool flag) { + useTSCTimestamp_ = flag; + } + + const std::string& getCustomConfig() const { + return customConfig_; + } + + uint32_t maxEvents() const { + return maxEvents_; + } + + private: + explicit Config(const Config& other) = default; + + AbstractConfig* cloneDerived(AbstractConfig& parent) const override { + // Clone from AbstractConfig not supported + assert(false); + return nullptr; + } + + uint8_t createDeviceMask(const std::string& val); + + // Adds valid activity types from the user defined string list in the + // configuration file + void setActivityTypes(const std::vector& selected_activities); + + // Sets the default activity types to be traced + void selectDefaultActivityTypes() { + // If the user has not specified an activity list, add all types + for (ActivityType t : defaultActivityTypes()) { + selectedActivityTypes_.insert(t); + } + } + + int verboseLogLevel_; + std::vector verboseLogModules_; + + // Event profiler + // These settings are also supported in on-demand mode + std::chrono::milliseconds samplePeriod_; + std::chrono::milliseconds reportPeriod_; + int samplesPerReport_; + std::set eventNames_; + std::set metricNames_; + + // On-demand duration + std::chrono::seconds eventProfilerOnDemandDuration_; + // Last on-demand request + std::chrono::time_point + eventProfilerOnDemandTimestamp_; + + int eventProfilerMaxInstancesPerGpu_; + + // Monitor whether event profiler threads are stuck + // at this frequency + std::chrono::seconds eventProfilerHeartbeatMonitorPeriod_; + + // These settings can not be changed on-demand + std::string eventLogFile_; + std::vector eventReportPercentiles_ = {5, 25, 50, 75, 95}; + uint8_t eventProfilerDeviceMask_ = ~0; + std::chrono::milliseconds multiplexPeriod_; + + // Activity profiler + bool activityProfilerEnabled_; + + // Enable per-thread buffer + 
bool perThreadBufferEnabled_; + std::set selectedActivityTypes_; + + // The activity profiler settings are all on-demand + std::string activitiesLogFile_; + + std::string activitiesLogUrl_; + + // Log activities to memory buffer + bool activitiesLogToMemory_{false}; + + int activitiesMaxGpuBufferSize_; + std::chrono::seconds activitiesWarmupDuration_; + int activitiesWarmupIterations_; + bool activitiesCudaSyncWaitEvents_; + + // Enable Profiler Config Options + // Temporarily disable shape collection until we re-roll out the feature for + // on-demand cases + bool enableReportInputShapes_{false}; + bool enableProfileMemory_{false}; + bool enableWithStack_{false}; + bool enableWithFlops_{false}; + bool enableWithModules_{false}; + + // Profile for specified iterations and duration + std::chrono::milliseconds activitiesDuration_; + int activitiesRunIterations_; + + // Below are not used + // Use this net name for iteration count + std::string activitiesExternalAPIIterationsTarget_; + // Only profile nets that includes this in the name + std::vector activitiesExternalAPIFilter_; + // Only profile nets with at least this many operators + int activitiesExternalAPINetSizeThreshold_; + // Only profile nets with at least this many GPU operators + int activitiesExternalAPIGpuOpCountThreshold_; + // Last activity profiler request + std::chrono::time_point + activitiesOnDemandTimestamp_; + + // ActivityProfilers are triggered by either: + // Synchronized start timestamps + std::chrono::time_point profileStartTime_; + // Or start iterations. 
+ int profileStartIteration_; + int profileStartIterationRoundUp_; + + // DEPRECATED + std::chrono::time_point requestTimestamp_; + + // Enable profiling via SIGUSR2 + bool enableSigUsr2_; + + // Enable IPC Fabric instead of thrift communication + bool enableIpcFabric_; + std::chrono::seconds onDemandConfigUpdateIntervalSecs_; + + // Logger Metadata + std::string requestTraceID_; + std::string requestGroupTraceID_; + + // CUPTI Device Buffer + size_t cuptiDeviceBufferSize_; + size_t cuptiDeviceBufferPoolLimit_; + + // CUPTI Timestamp Format + bool useTSCTimestamp_{true}; + + // Memory Profiler + bool memoryProfilerEnabled_{false}; + int profileMemoryDuration_{1000}; + + // Used to flexibly configure some custom options, especially for custom + // backends. How to parse this string is handled by the custom backend. + std::string customConfig_; + + // Roctracer settings + uint32_t maxEvents_{5000000}; +}; + +constexpr char kUseDaemonEnvVar[] = "KINETO_USE_DAEMON"; + +bool isDaemonEnvVarSet(); + +} // namespace libkineto + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/kineto/GenericTraceActivity.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/kineto/GenericTraceActivity.h new file mode 100644 index 0000000000000000000000000000000000000000..dd5187464ef429129d7035b2c8792b2df8d35ed0 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/kineto/GenericTraceActivity.h @@ -0,0 +1,164 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "ITraceActivity.h" +#include "ThreadUtil.h" +#include "TraceSpan.h" + +namespace libkineto { + +// Link type, used in GenericTraceActivity.flow.type +constexpr unsigned int kLinkFwdBwd = 1; +constexpr unsigned int kLinkAsyncCpuGpu = 2; + +// @lint-ignore-every CLANGTIDY +// cppcoreguidelines-non-private-member-variables-in-classes +// @lint-ignore-every CLANGTIDY cppcoreguidelines-pro-type-member-init +class GenericTraceActivity : public ITraceActivity { + public: + GenericTraceActivity() + : activityType(ActivityType::ENUM_COUNT), traceSpan_(nullptr) {} + + GenericTraceActivity( + const TraceSpan& trace, + ActivityType type, + const std::string& name) + : activityType(type), activityName(name), traceSpan_(&trace) {} + + int64_t deviceId() const override { + return device; + } + + int64_t resourceId() const override { + return resource; + } + + void setDevice(int32_t newDevice) { + device = newDevice; + } + + int32_t getThreadId() const override { + return threadId; + } + + int64_t timestamp() const override { + return startTime; + } + + int64_t duration() const override { + return endTime - startTime; + } + + int64_t correlationId() const override { + return id; + } + + ActivityType type() const override { + return activityType; + } + + const ITraceActivity* linkedActivity() const override { + return linked; + } + + int flowType() const override { + return flow.type; + } + + int64_t flowId() const override { + return flow.id; + } + + bool flowStart() const override { + return flow.start; + } + + const std::string name() const override { + return activityName; + } + + const TraceSpan* traceSpan() const override { + return traceSpan_; + } + + void log(ActivityLogger& logger) const override; + + // Encode client side metadata as a key/value + template + void addMetadata(const std::string& key, const ValType& value) { + metadataMap_.emplace(key, 
std::make_pair(fmt::format("{}", value), false)); + } + + void addMetadataQuoted(const std::string& key, const std::string& value) { + metadataMap_.emplace(key, std::make_pair(value, true)); + } + + const std::string getMetadataValue(const std::string& key) const override { + if (auto it = metadataMap_.find(key); it != metadataMap_.end()) { + return it->second.first; + } + return ""; + } + + const std::string metadataJson() const override { + std::stringstream json; + bool first = true; + for (const auto& [key, val] : metadataMap_) { + if (!first) { + json << ", "; + } + // Ok to use fmt::format here as we are not logging + val.second ? json << fmt::format("\"{}\": \"{}\"", key, val.first) + : json << fmt::format("\"{}\": {}", key, val.first); + first = false; + } + return json.str(); + } + + virtual ~GenericTraceActivity() override {} + + int64_t startTime{0}; + int64_t endTime{0}; + int32_t id{0}; + int32_t device{0}; + int32_t resource{0}; + int32_t threadId{0}; + ActivityType activityType; + std::string activityName; + struct Flow { + Flow() : id(0), type(0), start(0) {} + // Ids must be unique within each type + uint32_t id; + // Type will be used to connect flows between profilers, as + // well as look up flow information (name etc) + uint32_t type : 4; + uint32_t start : 1; + } flow; + const ITraceActivity* linked{nullptr}; + + private: + const TraceSpan* traceSpan_; + // Metadata map: { key: (value, quoted)} + std::unordered_map> metadataMap_; +}; + +} // namespace libkineto + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." 
+#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/kineto/IActivityProfiler.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/kineto/IActivityProfiler.h new file mode 100644 index 0000000000000000000000000000000000000000..1edf56eb2d6d6d57162ad9ade73f0973b7dd7a1c --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/kineto/IActivityProfiler.h @@ -0,0 +1,176 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include + +#include "Config.h" +#include "GenericTraceActivity.h" + +/* This file includes an abstract base class for an activity profiler + * that can be implemented by multiple tracing agents in the application. + * The high level Kineto profiler can co-ordinate start and end of tracing + * and combine together events from multiple such activity profilers. + */ + +namespace libkineto { + +struct CpuTraceBuffer; + +#ifdef _MSC_VER +// workaround for the predefined ERROR macro on Windows +#undef ERROR +#endif // _MSC_VER + +enum class TraceStatus { + READY, // Accepting trace requests + WARMUP, // Performing trace warmup + RECORDING, // Actively collecting activities + PROCESSING, // Recording is complete, preparing results + ERROR, // One or more errors (and possibly also warnings) occurred. + WARNING, // One or more warnings occurred. +}; + +/* DeviceInfo: + * Can be used to specify process name, sort order, PID and device label. + * The sort order is determined by the sortIndex field to handle ordering of + * processes and gpu rows in the trace viewer. 
+ */ +struct DeviceInfo { + DeviceInfo( + int64_t id, + int64_t sortIndex, + const std::string& name, + const std::string& label) + : id(id), sortIndex(sortIndex), name(name), label(label) {} + int64_t id; // process id + int64_t sortIndex; // position in trace view + const std::string name; // process name + const std::string label; // device label +}; + +/* ResourceInfo: + * Can be used to specify resource inside device + */ +struct ResourceInfo { + ResourceInfo( + int64_t deviceId, + int64_t id, + int64_t sortIndex, + const std::string& name) + : id(id), sortIndex(sortIndex), deviceId(deviceId), name(name) {} + int64_t id; // resource id + int64_t sortIndex; // position in trace view + int64_t deviceId; // id of device which owns this resource (specified in + // DeviceInfo.id) + const std::string name; // resource name +}; + +using getLinkedActivityCallback = std::function; + +/* IActivityProfilerSession: + * an opaque object that can be used by a high level profiler to + * start/stop and return trace events. 
+ */ +class IActivityProfilerSession { + public: + virtual ~IActivityProfilerSession() {} + + // start the trace collection synchronously + virtual void start() = 0; + + // stop the trace collection synchronously + virtual void stop() = 0; + + TraceStatus status() { + return status_; + } + + // returns errors with this trace + virtual std::vector errors() = 0; + + // processes trace activities using logger + virtual void processTrace(ActivityLogger& logger) = 0; + + virtual void processTrace( + ActivityLogger& logger, + getLinkedActivityCallback /*getLinkedActivity*/, + int64_t /*startTime*/, + int64_t /*endTime*/) { + processTrace(logger); + } + + // returns device info used in this trace, could be nullptr + virtual std::unique_ptr getDeviceInfo() = 0; + + // returns resource info used in this trace, could be empty + virtual std::vector getResourceInfos() = 0; + + // release ownership of the trace events and metadata + virtual std::unique_ptr getTraceBuffer() = 0; + + // XXX define trace formats + // virtual save(string name, TraceFormat format) + + virtual void pushCorrelationId(uint64_t /*id*/) {} + virtual void popCorrelationId() {} + + virtual void pushUserCorrelationId(uint64_t /*id*/) {} + virtual void popUserCorrelationId() {} + + virtual std::string getDeviceProperties() { + return ""; + } + + virtual std::unordered_map getMetadata() { + return {}; + } + + protected: + TraceStatus status_ = TraceStatus::READY; +}; + +/* Activity Profiler Plugins: + * These allow other frameworks to integrate into Kineto's primariy + * activity profiler. While the primary activity profiler handles + * timing the trace collections and correlating events the plugins + * can become source of new trace activity types. 
+ */ +class IActivityProfiler { + public: + virtual ~IActivityProfiler() {} + + // name of profiler + virtual const std::string& name() const = 0; + + // returns activity types this profiler supports + virtual const std::set& availableActivities() const = 0; + + // Calls prepare() on registered tracer providers passing in the relevant + // activity types. Returns a profiler session handle + virtual std::unique_ptr configure( + const std::set& activity_types, + const Config& config) = 0; + + // asynchronous version of the above with future timestamp and duration. + virtual std::unique_ptr configure( + int64_t ts_ms, + int64_t duration_ms, + const std::set& activity_types, + const Config& config) = 0; +}; + +} // namespace libkineto + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/kineto/ILoggerObserver.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/kineto/ILoggerObserver.h new file mode 100644 index 0000000000000000000000000000000000000000..758f0194d066eb9c288654f1ad0f5d02ea130133 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/kineto/ILoggerObserver.h @@ -0,0 +1,77 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#define NOGDI +#include + +// Stages in libkineto used when pushing logs to UST Logger. 
+constexpr char kWarmUpStage[] = "Warm Up"; +constexpr char kCollectionStage[] = "Collection"; +constexpr char kPostProcessingStage[] = "Post Processing"; + +// Special string in UST for determining if traces are empty +constexpr char kEmptyTrace[] = + "No Valid Trace Events (CPU/GPU) found. Outputting empty trace."; + +#if !USE_GOOGLE_LOG + +#include +#include + +#include + +#ifdef _MSC_VER +// unset a predefined ERROR (windows) +#undef ERROR +#endif // _MSC_VER + +namespace libkineto { + +enum LoggerOutputType { + VERBOSE = 0, + INFO = 1, + WARNING = 2, + STAGE = 3, + ERROR = 4, + ENUM_COUNT = 5 +}; + +const char* toString(LoggerOutputType t); +LoggerOutputType toLoggerOutputType(const std::string& str); + +constexpr int LoggerTypeCount = (int)LoggerOutputType::ENUM_COUNT; + +class ILoggerObserver { + public: + virtual ~ILoggerObserver() = default; + virtual void write(const std::string& message, LoggerOutputType ot) = 0; + virtual const std::map> + extractCollectorMetadata() = 0; + virtual void reset() = 0; + virtual void addDevice(const int64_t device) = 0; + virtual void setTraceDurationMS(const int64_t duration) = 0; + virtual void addEventCount(const int64_t count) = 0; + virtual void setTraceID(const std::string&) {} + virtual void setGroupTraceID(const std::string&) {} + virtual void addDestination(const std::string& dest) = 0; + virtual void setTriggerOnDemand() {} + virtual void addMetadata( + const std::string& key, + const std::string& value) = 0; +}; + +} // namespace libkineto + +#endif // !USE_GOOGLE_LOG + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." 
+#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/kineto/ITraceActivity.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/kineto/ITraceActivity.h new file mode 100644 index 0000000000000000000000000000000000000000..bbdd708dbf102e3e766771e86bbe1d701a1b8844 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/kineto/ITraceActivity.h @@ -0,0 +1,69 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +#include "ActivityType.h" + +namespace libkineto { + +class ActivityLogger; +struct TraceSpan; + +// Generic activity interface is borrowed from tensorboard protobuf format. +struct ITraceActivity { + virtual ~ITraceActivity() {} + // Device is a physical or logical entity, e.g. CPU, GPU or process + virtual int64_t deviceId() const = 0; + // A resource is something on the device, h/w thread, + // functional units etc. 
+ virtual int64_t resourceId() const = 0; + // s/w thread + virtual int32_t getThreadId() const = 0; + // Start timestamp in nanoseconds + virtual int64_t timestamp() const = 0; + // Duration in nanoseconds + virtual int64_t duration() const = 0; + // Used to link up async activities + virtual int64_t correlationId() const = 0; + // Part of a flow, identified by flow id and type + virtual int flowType() const = 0; + virtual int64_t flowId() const = 0; + virtual bool flowStart() const = 0; + virtual ActivityType type() const = 0; + virtual const std::string name() const = 0; + // Optional linked activity + virtual const ITraceActivity* linkedActivity() const = 0; + // Optional containing trace object + virtual const TraceSpan* traceSpan() const = 0; + // Log activity + virtual void log(ActivityLogger& logger) const = 0; + // Return json formatted metadata + // FIXME: Return iterator to dynamic type map here instead + virtual const std::string metadataJson() const = 0; + // Return the metadata value in string format with key + // @lint-ignore CLANGTIDY: clang-diagnostic-unused-parameter + virtual const std::string getMetadataValue(const std::string& key) const { + return ""; + } + + static int64_t nsToUs(int64_t ns) { + // It's important that this conversion is the same everywhere. + // No rounding! + return ns / 1000; + } +}; + +} // namespace libkineto + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." 
+#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/kineto/LoggingAPI.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/kineto/LoggingAPI.h new file mode 100644 index 0000000000000000000000000000000000000000..d27484403234139ff0153276b08d1953c6821d37 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/kineto/LoggingAPI.h @@ -0,0 +1,19 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +namespace libkineto { +int getLogSeverityLevel(); +void setLogSeverityLevel(int level); +} // namespace libkineto + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/kineto/ThreadUtil.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/kineto/ThreadUtil.h new file mode 100644 index 0000000000000000000000000000000000000000..efc92dc98be2eba2d151fd9b9cfdbbc2603d3e46 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/kineto/ThreadUtil.h @@ -0,0 +1,41 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include +#include +#include +#include +#include + +namespace libkineto { + +int32_t systemThreadId(bool cache = true); +int32_t threadId(); +bool setThreadName(const std::string& name); +std::string getThreadName(); + +int32_t pidNamespace(ino_t& ns); +int32_t processId(bool cache = true); +std::string processName(int32_t pid); + +// Return a list of pids and process names for the current process +// and its parents. +std::vector> pidCommandPairsOfAncestors(); + +// Resets all cached Thread local state, this must be done on +// forks to prevent stale values from being retained. +void resetTLS(); + +} // namespace libkineto + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/kineto/TraceSpan.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/kineto/TraceSpan.h new file mode 100644 index 0000000000000000000000000000000000000000..395075bff5e4a104b21e4c90d09226723a722331 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/kineto/TraceSpan.h @@ -0,0 +1,43 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include +#include +#include + +namespace libkineto { + +struct TraceSpan { + TraceSpan() = delete; + TraceSpan(int64_t startTime, int64_t endTime, std::string name) + : startTime(startTime), endTime(endTime), name(std::move(name)) {} + TraceSpan(int opCount, int it, std::string name, std::string prefix) + : opCount(opCount), + iteration(it), + name(std::move(name)), + prefix(std::move(prefix)) {} + + // FIXME: change to duration? + int64_t startTime{0}; + int64_t endTime{0}; + int opCount{0}; + int iteration{-1}; + // Name is used to identify timeline + std::string name; + // Prefix used to distinguish trace spans on the same timeline + std::string prefix; +}; + +} // namespace libkineto + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/kineto/libkineto.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/kineto/libkineto.h new file mode 100644 index 0000000000000000000000000000000000000000..e956a06075b317ecabaf903913181330c00304e5 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/kineto/libkineto.h @@ -0,0 +1,167 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +// Mediator for initialization and profiler control + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ActivityProfilerInterface.h" +#include "ActivityTraceInterface.h" +#include "ActivityType.h" +#include "ClientInterface.h" +#include "GenericTraceActivity.h" +#include "IActivityProfiler.h" +#include "ILoggerObserver.h" +#include "LoggingAPI.h" +#include "TraceSpan.h" + +#include "ThreadUtil.h" + +extern "C" { +void suppressLibkinetoLogMessages(); +int InitializeInjection(void); +void libkineto_init(bool cpuOnly, bool logOnError); +bool hasTestEnvVar(); +} + +namespace libkineto { + +class Config; +class ConfigLoader; + +struct CpuTraceBuffer { + template + void emplace_activity(Args&&... args) { + activities.emplace_back( + std::make_unique(std::forward(args)...)); + } + + static GenericTraceActivity& toRef( + std::unique_ptr& ref) { + return *ref; + } + + static const GenericTraceActivity& toRef( + const std::unique_ptr& ref) { + return *ref; + } + + TraceSpan span{0, 0, "none"}; + int gpuOpCount; + std::deque> activities; +}; + +using ChildActivityProfilerFactory = + std::function()>; + +class LibkinetoApi { + public: + explicit LibkinetoApi(ConfigLoader& configLoader) + : configLoader_(configLoader) {} + + // Called by client that supports tracing API. + // libkineto can still function without this. 
+ void registerClient(ClientInterface* client); + + // Called by libkineto on init + void registerProfiler(std::unique_ptr profiler) { + activityProfiler_ = std::move(profiler); + initClientIfRegistered(); + } + + ActivityProfilerInterface& activityProfiler() { + return *activityProfiler_; + } + + ClientInterface* client() { + return client_; + } + + void initProfilerIfRegistered() { + static std::once_flag once; + if (activityProfiler_) { + std::call_once(once, [this] { + if (!activityProfiler_->isInitialized()) { + activityProfiler_->init(); + initChildActivityProfilers(); + } + }); + } + } + + bool isProfilerInitialized() const { + return activityProfiler_ && activityProfiler_->isInitialized(); + } + + bool isProfilerRegistered() const { + return activityProfiler_ != nullptr; + } + + void suppressLogMessages() { + suppressLibkinetoLogMessages(); + } + + void resetKinetoTLS() { + resetTLS(); + } + + // Provides access to profier configuration manaegement + ConfigLoader& configLoader() { + return configLoader_; + } + + void registerProfilerFactory(const ChildActivityProfilerFactory& factory) { + if (isProfilerInitialized()) { + activityProfiler_->addChildActivityProfiler(factory()); + } else { + childProfilerFactories_.push_back(factory); + } + } + + private: + void initChildActivityProfilers() { + if (!isProfilerInitialized()) { + return; + } + for (const auto& factory : childProfilerFactories_) { + activityProfiler_->addChildActivityProfiler(factory()); + } + childProfilerFactories_.clear(); + } + + // Client is initialized once both it and libkineto has registered + void initClientIfRegistered(); + + ConfigLoader& configLoader_; + std::unique_ptr activityProfiler_{}; + ClientInterface* client_{}; + int32_t clientRegisterThread_{0}; + + std::vector childProfilerFactories_; +}; + +// Singleton +LibkinetoApi& api(); + +} // namespace libkineto + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." 
+#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/kineto/output_base.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/kineto/output_base.h new file mode 100644 index 0000000000000000000000000000000000000000..010a4145cfe3f848a185d58f5fb8f0d822684741 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/kineto/output_base.h @@ -0,0 +1,83 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include +#include + +// TODO(T90238193) +// @lint-ignore-every CLANGTIDY facebook-hte-RelativeInclude +#include "GenericTraceActivity.h" +#include "IActivityProfiler.h" +#include "ThreadUtil.h" +#include "TraceSpan.h" + +namespace KINETO_NAMESPACE { +struct ActivityBuffers; +} + +namespace libkineto { + +using namespace KINETO_NAMESPACE; + +// Used by sortIndex to put GPU tracks at the bottom +// of the trace timelines. The largest valid CPU PID is 4,194,304, +// so 5000000 is enough to guarantee that GPU tracks are sorted after CPU. 
+constexpr int64_t kExceedMaxPid = 5000000; + +class ActivityLogger { + public: + virtual ~ActivityLogger() = default; + + struct OverheadInfo { + explicit OverheadInfo(const std::string& name) : name(name) {} + const std::string name; + }; + + virtual void handleDeviceInfo(const DeviceInfo& info, uint64_t time) = 0; + + virtual void handleResourceInfo(const ResourceInfo& info, int64_t time) = 0; + + virtual void handleOverheadInfo(const OverheadInfo& info, int64_t time) = 0; + + virtual void handleTraceSpan(const TraceSpan& span) = 0; + + virtual void handleActivity(const libkineto::ITraceActivity& activity) = 0; + virtual void handleGenericActivity( + const libkineto::GenericTraceActivity& activity) = 0; + + virtual void handleTraceStart( + const std::unordered_map& metadata, + const std::string& device_properties) = 0; + + void handleTraceStart() { + handleTraceStart(std::unordered_map(), ""); + } + + virtual void finalizeMemoryTrace(const std::string&, const Config&) = 0; + + virtual void finalizeTrace( + const Config& config, + std::unique_ptr buffers, + int64_t endTime, + std::unordered_map>& metadata) = 0; + + protected: + ActivityLogger() = default; +}; + +} // namespace libkineto + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/kineto/time_since_epoch.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/kineto/time_since_epoch.h new file mode 100644 index 0000000000000000000000000000000000000000..341e36d7ae1381769818e3c320275346f302380e --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/kineto/time_since_epoch.h @@ -0,0 +1,26 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. 
+ * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +namespace libkineto { +template +inline int64_t timeSinceEpoch(const std::chrono::time_point& t) { + return std::chrono::duration_cast( + t.time_since_epoch()) + .count(); +} + +} // namespace libkineto + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/legacy/ittnotify.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/legacy/ittnotify.h new file mode 100644 index 0000000000000000000000000000000000000000..307580cd0d0faa459abf5452c8a3f273cba55942 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/legacy/ittnotify.h @@ -0,0 +1,1009 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +/* + Copyright (C) 2005-2019 Intel Corporation + + SPDX-License-Identifier: GPL-2.0-only OR BSD-3-Clause +*/ +#ifndef _LEGACY_ITTNOTIFY_H_ +#define _LEGACY_ITTNOTIFY_H_ + +/** + * @file + * @brief Legacy User API functions and types + */ + +/** @cond exclude_from_documentation */ +#ifndef ITT_OS_WIN +# define ITT_OS_WIN 1 +#endif /* ITT_OS_WIN */ + +#ifndef ITT_OS_LINUX +# define ITT_OS_LINUX 2 +#endif /* ITT_OS_LINUX */ + +#ifndef ITT_OS_MAC +# define ITT_OS_MAC 3 +#endif /* ITT_OS_MAC */ + +#ifndef ITT_OS_FREEBSD +# define ITT_OS_FREEBSD 4 +#endif /* ITT_OS_FREEBSD */ + +#ifndef ITT_OS_OPENBSD +# define ITT_OS_OPENBSD 5 +#endif /* ITT_OS_OPENBSD */ + +#ifndef ITT_OS +# if defined WIN32 || defined _WIN32 +# define ITT_OS ITT_OS_WIN +# elif defined( __APPLE__ ) && defined( __MACH__ ) +# define ITT_OS ITT_OS_MAC +# elif defined( __FreeBSD__ ) +# define ITT_OS ITT_OS_FREEBSD +# elif 
defined( __OpenBSD__ ) +# define ITT_OS ITT_OS_OPENBSD +# else +# define ITT_OS ITT_OS_LINUX +# endif +#endif /* ITT_OS */ + +#ifndef ITT_PLATFORM_WIN +# define ITT_PLATFORM_WIN 1 +#endif /* ITT_PLATFORM_WIN */ + +#ifndef ITT_PLATFORM_POSIX +# define ITT_PLATFORM_POSIX 2 +#endif /* ITT_PLATFORM_POSIX */ + +#ifndef ITT_PLATFORM_MAC +# define ITT_PLATFORM_MAC 3 +#endif /* ITT_PLATFORM_MAC */ + +#ifndef ITT_PLATFORM_FREEBSD +# define ITT_PLATFORM_FREEBSD 4 +#endif /* ITT_PLATFORM_FREEBSD */ + +#ifndef ITT_PLATFORM_OPENBSD +# define ITT_PLATFORM_OPENBSD 5 +#endif /* ITT_PLATFORM_OPENBSD */ + +#ifndef ITT_PLATFORM +# if ITT_OS==ITT_OS_WIN +# define ITT_PLATFORM ITT_PLATFORM_WIN +# elif ITT_OS==ITT_OS_MAC +# define ITT_PLATFORM ITT_PLATFORM_MAC +# elif ITT_OS==ITT_OS_FREEBSD +# define ITT_PLATFORM ITT_PLATFORM_FREEBSD +# elif ITT_OS==ITT_OS_OPENBSD +# define ITT_PLATFORM ITT_PLATFORM_OPENBSD +# else +# define ITT_PLATFORM ITT_PLATFORM_POSIX +# endif +#endif /* ITT_PLATFORM */ + +#if defined(_UNICODE) && !defined(UNICODE) +#define UNICODE +#endif + +#include +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#include +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#include +#if defined(UNICODE) || defined(_UNICODE) +#include +#endif /* UNICODE || _UNICODE */ +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ + +#ifndef ITTAPI_CDECL +# if ITT_PLATFORM==ITT_PLATFORM_WIN +# define ITTAPI_CDECL __cdecl +# else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +# if defined _M_IX86 || defined __i386__ +# define ITTAPI_CDECL __attribute__ ((cdecl)) +# else /* _M_IX86 || __i386__ */ +# define ITTAPI_CDECL /* actual only on x86 platform */ +# endif /* _M_IX86 || __i386__ */ +# endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#endif /* ITTAPI_CDECL */ + +#ifndef STDCALL +# if ITT_PLATFORM==ITT_PLATFORM_WIN +# define STDCALL __stdcall +# else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +# if defined _M_IX86 || defined __i386__ +# define STDCALL __attribute__ ((stdcall)) +# else /* _M_IX86 || __i386__ */ +# define STDCALL /* 
supported only on x86 platform */ +# endif /* _M_IX86 || __i386__ */ +# endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#endif /* STDCALL */ + +#define ITTAPI ITTAPI_CDECL +#define LIBITTAPI ITTAPI_CDECL + +/* TODO: Temporary for compatibility! */ +#define ITTAPI_CALL ITTAPI_CDECL +#define LIBITTAPI_CALL ITTAPI_CDECL + +#if ITT_PLATFORM==ITT_PLATFORM_WIN +/* use __forceinline (VC++ specific) */ +#if defined(__MINGW32__) && !defined(__cplusplus) +#define ITT_INLINE static __inline__ __attribute__((__always_inline__,__gnu_inline__)) +#else +#define ITT_INLINE static __forceinline +#endif /* __MINGW32__ */ + +#define ITT_INLINE_ATTRIBUTE /* nothing */ +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +/* + * Generally, functions are not inlined unless optimization is specified. + * For functions declared inline, this attribute inlines the function even + * if no optimization level was specified. + */ +#ifdef __STRICT_ANSI__ +#define ITT_INLINE static +#define ITT_INLINE_ATTRIBUTE __attribute__((unused)) +#else /* __STRICT_ANSI__ */ +#define ITT_INLINE static inline +#define ITT_INLINE_ATTRIBUTE __attribute__((always_inline, unused)) +#endif /* __STRICT_ANSI__ */ +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +/** @endcond */ + +/** @cond exclude_from_documentation */ +/* Helper macro for joining tokens */ +#define ITT_JOIN_AUX(p,n) p##n +#define ITT_JOIN(p,n) ITT_JOIN_AUX(p,n) + +#ifdef ITT_MAJOR +#undef ITT_MAJOR +#endif +#ifdef ITT_MINOR +#undef ITT_MINOR +#endif +#define ITT_MAJOR 3 +#define ITT_MINOR 0 + +/* Standard versioning of a token with major and minor version numbers */ +#define ITT_VERSIONIZE(x) \ + ITT_JOIN(x, \ + ITT_JOIN(_, \ + ITT_JOIN(ITT_MAJOR, \ + ITT_JOIN(_, ITT_MINOR)))) + +#ifndef INTEL_ITTNOTIFY_PREFIX +# define INTEL_ITTNOTIFY_PREFIX __itt_ +#endif /* INTEL_ITTNOTIFY_PREFIX */ +#ifndef INTEL_ITTNOTIFY_POSTFIX +# define INTEL_ITTNOTIFY_POSTFIX _ptr_ +#endif /* INTEL_ITTNOTIFY_POSTFIX */ + +#define ITTNOTIFY_NAME_AUX(n) ITT_JOIN(INTEL_ITTNOTIFY_PREFIX,n) 
+#define ITTNOTIFY_NAME(n) ITT_VERSIONIZE(ITTNOTIFY_NAME_AUX(ITT_JOIN(n,INTEL_ITTNOTIFY_POSTFIX))) + +#define ITTNOTIFY_VOID(n) (!ITTNOTIFY_NAME(n)) ? (void)0 : ITTNOTIFY_NAME(n) +#define ITTNOTIFY_DATA(n) (!ITTNOTIFY_NAME(n)) ? 0 : ITTNOTIFY_NAME(n) + +#define ITTNOTIFY_VOID_D0(n,d) (d == NULL) ? (void)0 : (!(d)->flags) ? (void)0 : (!ITTNOTIFY_NAME(n)) ? (void)0 : ITTNOTIFY_NAME(n)(d) +#define ITTNOTIFY_VOID_D1(n,d,x) (d == NULL) ? (void)0 : (!(d)->flags) ? (void)0 : (!ITTNOTIFY_NAME(n)) ? (void)0 : ITTNOTIFY_NAME(n)(d,x) +#define ITTNOTIFY_VOID_D2(n,d,x,y) (d == NULL) ? (void)0 : (!(d)->flags) ? (void)0 : (!ITTNOTIFY_NAME(n)) ? (void)0 : ITTNOTIFY_NAME(n)(d,x,y) +#define ITTNOTIFY_VOID_D3(n,d,x,y,z) (d == NULL) ? (void)0 : (!(d)->flags) ? (void)0 : (!ITTNOTIFY_NAME(n)) ? (void)0 : ITTNOTIFY_NAME(n)(d,x,y,z) +#define ITTNOTIFY_VOID_D4(n,d,x,y,z,a) (d == NULL) ? (void)0 : (!(d)->flags) ? (void)0 : (!ITTNOTIFY_NAME(n)) ? (void)0 : ITTNOTIFY_NAME(n)(d,x,y,z,a) +#define ITTNOTIFY_VOID_D5(n,d,x,y,z,a,b) (d == NULL) ? (void)0 : (!(d)->flags) ? (void)0 : (!ITTNOTIFY_NAME(n)) ? (void)0 : ITTNOTIFY_NAME(n)(d,x,y,z,a,b) +#define ITTNOTIFY_VOID_D6(n,d,x,y,z,a,b,c) (d == NULL) ? (void)0 : (!(d)->flags) ? (void)0 : (!ITTNOTIFY_NAME(n)) ? (void)0 : ITTNOTIFY_NAME(n)(d,x,y,z,a,b,c) +#define ITTNOTIFY_DATA_D0(n,d) (d == NULL) ? 0 : (!(d)->flags) ? 0 : (!ITTNOTIFY_NAME(n)) ? 0 : ITTNOTIFY_NAME(n)(d) +#define ITTNOTIFY_DATA_D1(n,d,x) (d == NULL) ? 0 : (!(d)->flags) ? 0 : (!ITTNOTIFY_NAME(n)) ? 0 : ITTNOTIFY_NAME(n)(d,x) +#define ITTNOTIFY_DATA_D2(n,d,x,y) (d == NULL) ? 0 : (!(d)->flags) ? 0 : (!ITTNOTIFY_NAME(n)) ? 0 : ITTNOTIFY_NAME(n)(d,x,y) +#define ITTNOTIFY_DATA_D3(n,d,x,y,z) (d == NULL) ? 0 : (!(d)->flags) ? 0 : (!ITTNOTIFY_NAME(n)) ? 0 : ITTNOTIFY_NAME(n)(d,x,y,z) +#define ITTNOTIFY_DATA_D4(n,d,x,y,z,a) (d == NULL) ? 0 : (!(d)->flags) ? 0 : (!ITTNOTIFY_NAME(n)) ? 0 : ITTNOTIFY_NAME(n)(d,x,y,z,a) +#define ITTNOTIFY_DATA_D5(n,d,x,y,z,a,b) (d == NULL) ? 0 : (!(d)->flags) ? 
0 : (!ITTNOTIFY_NAME(n)) ? 0 : ITTNOTIFY_NAME(n)(d,x,y,z,a,b) +#define ITTNOTIFY_DATA_D6(n,d,x,y,z,a,b,c) (d == NULL) ? 0 : (!(d)->flags) ? 0 : (!ITTNOTIFY_NAME(n)) ? 0 : ITTNOTIFY_NAME(n)(d,x,y,z,a,b,c) + +#ifdef ITT_STUB +#undef ITT_STUB +#endif +#ifdef ITT_STUBV +#undef ITT_STUBV +#endif +#define ITT_STUBV(api,type,name,args) \ + typedef type (api* ITT_JOIN(ITTNOTIFY_NAME(name),_t)) args; \ + extern ITT_JOIN(ITTNOTIFY_NAME(name),_t) ITTNOTIFY_NAME(name); +#define ITT_STUB ITT_STUBV +/** @endcond */ + +#ifdef __cplusplus +extern "C" { +#endif /* __cplusplus */ + +/** + * @defgroup legacy Legacy API + * @{ + * @} + */ + +/** + * @defgroup legacy_control Collection Control + * @ingroup legacy + * General behavior: application continues to run, but no profiling information is being collected + * + * Pausing occurs not only for the current thread but for all process as well as spawned processes + * - Intel(R) Parallel Inspector and Intel(R) Inspector XE: + * - Does not analyze or report errors that involve memory access. + * - Other errors are reported as usual. Pausing data collection in + * Intel(R) Parallel Inspector and Intel(R) Inspector XE + * only pauses tracing and analyzing memory access. + * It does not pause tracing or analyzing threading APIs. + * . + * - Intel(R) VTune(TM) Profiler: + * - Does continue to record when new threads are started. + * . + * - Other effects: + * - Possible reduction of runtime overhead. + * . 
+ * @{ + */ +#ifndef _ITTNOTIFY_H_ +/** @brief Pause collection */ +void ITTAPI __itt_pause(void); +/** @brief Resume collection */ +void ITTAPI __itt_resume(void); +/** @brief Detach collection */ +void ITTAPI __itt_detach(void); + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUBV(ITTAPI, void, pause, (void)) +ITT_STUBV(ITTAPI, void, resume, (void)) +ITT_STUBV(ITTAPI, void, detach, (void)) +#define __itt_pause ITTNOTIFY_VOID(pause) +#define __itt_pause_ptr ITTNOTIFY_NAME(pause) +#define __itt_resume ITTNOTIFY_VOID(resume) +#define __itt_resume_ptr ITTNOTIFY_NAME(resume) +#define __itt_detach ITTNOTIFY_VOID(detach) +#define __itt_detach_ptr ITTNOTIFY_NAME(detach) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_pause() +#define __itt_pause_ptr 0 +#define __itt_resume() +#define __itt_resume_ptr 0 +#define __itt_detach() +#define __itt_detach_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_pause_ptr 0 +#define __itt_resume_ptr 0 +#define __itt_detach_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ +#endif /* _ITTNOTIFY_H_ */ +/** @} legacy_control group */ + +/** + * @defgroup legacy_threads Threads + * @ingroup legacy + * Threads group + * @warning Legacy API + * @{ + */ +/** + * @deprecated Legacy API + * @brief Set name to be associated with thread in analysis GUI. 
+ * @return __itt_err upon failure (name or namelen being null,name and namelen mismatched) + */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +int LIBITTAPI __itt_thr_name_setA(const char *name, int namelen); +int LIBITTAPI __itt_thr_name_setW(const wchar_t *name, int namelen); +#if defined(UNICODE) || defined(_UNICODE) +# define __itt_thr_name_set __itt_thr_name_setW +# define __itt_thr_name_set_ptr __itt_thr_name_setW_ptr +#else +# define __itt_thr_name_set __itt_thr_name_setA +# define __itt_thr_name_set_ptr __itt_thr_name_setA_ptr +#endif /* UNICODE */ +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +int LIBITTAPI __itt_thr_name_set(const char *name, int namelen); +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +#if ITT_PLATFORM==ITT_PLATFORM_WIN +ITT_STUB(LIBITTAPI, int, thr_name_setA, (const char *name, int namelen)) +ITT_STUB(LIBITTAPI, int, thr_name_setW, (const wchar_t *name, int namelen)) +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +ITT_STUB(LIBITTAPI, int, thr_name_set, (const char *name, int namelen)) +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#define __itt_thr_name_setA ITTNOTIFY_DATA(thr_name_setA) +#define __itt_thr_name_setA_ptr ITTNOTIFY_NAME(thr_name_setA) +#define __itt_thr_name_setW ITTNOTIFY_DATA(thr_name_setW) +#define __itt_thr_name_setW_ptr ITTNOTIFY_NAME(thr_name_setW) +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#define __itt_thr_name_set ITTNOTIFY_DATA(thr_name_set) +#define __itt_thr_name_set_ptr ITTNOTIFY_NAME(thr_name_set) +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#else /* INTEL_NO_ITTNOTIFY_API */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#define __itt_thr_name_setA(name, namelen) +#define __itt_thr_name_setA_ptr 0 +#define __itt_thr_name_setW(name, namelen) +#define __itt_thr_name_setW_ptr 0 +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#define __itt_thr_name_set(name, namelen) +#define __itt_thr_name_set_ptr 0 
+#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#define __itt_thr_name_setA_ptr 0 +#define __itt_thr_name_setW_ptr 0 +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#define __itt_thr_name_set_ptr 0 +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +/** + * @deprecated Legacy API + * @brief Mark current thread as ignored from this point on, for the duration of its existence. + */ +void LIBITTAPI __itt_thr_ignore(void); + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUBV(LIBITTAPI, void, thr_ignore, (void)) +#define __itt_thr_ignore ITTNOTIFY_VOID(thr_ignore) +#define __itt_thr_ignore_ptr ITTNOTIFY_NAME(thr_ignore) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_thr_ignore() +#define __itt_thr_ignore_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_thr_ignore_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ +/** @} legacy_threads group */ + +/** + * @defgroup legacy_sync Synchronization + * @ingroup legacy + * Synchronization group + * @warning Legacy API + * @{ + */ +/** + * @hideinitializer + * @brief possible value of attribute argument for sync object type + */ +#define __itt_attr_barrier 1 + +/** + * @hideinitializer + * @brief possible value of attribute argument for sync object type + */ +#define __itt_attr_mutex 2 + +/** + * @deprecated Legacy API + * @brief Assign a name to a sync object using char or Unicode string + * @param[in] addr - pointer to the sync object. You should use a real pointer to your object + * to make sure that the values don't clash with other object addresses + * @param[in] objtype - null-terminated object type string. If NULL is passed, the object will + * be assumed to be of generic "User Synchronization" type + * @param[in] objname - null-terminated object name string. 
If NULL, no name will be assigned + * to the object -- you can use the __itt_sync_rename call later to assign + * the name + * @param[in] attribute - one of [#__itt_attr_barrier, #__itt_attr_mutex] values which defines the + * exact semantics of how prepare/acquired/releasing calls work. + */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +void ITTAPI __itt_sync_set_nameA(void *addr, const char *objtype, const char *objname, int attribute); +void ITTAPI __itt_sync_set_nameW(void *addr, const wchar_t *objtype, const wchar_t *objname, int attribute); +#if defined(UNICODE) || defined(_UNICODE) +# define __itt_sync_set_name __itt_sync_set_nameW +# define __itt_sync_set_name_ptr __itt_sync_set_nameW_ptr +#else /* UNICODE */ +# define __itt_sync_set_name __itt_sync_set_nameA +# define __itt_sync_set_name_ptr __itt_sync_set_nameA_ptr +#endif /* UNICODE */ +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +void ITTAPI __itt_sync_set_name(void *addr, const char* objtype, const char* objname, int attribute); +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +#if ITT_PLATFORM==ITT_PLATFORM_WIN +ITT_STUBV(ITTAPI, void, sync_set_nameA, (void *addr, const char *objtype, const char *objname, int attribute)) +ITT_STUBV(ITTAPI, void, sync_set_nameW, (void *addr, const wchar_t *objtype, const wchar_t *objname, int attribute)) +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +ITT_STUBV(ITTAPI, void, sync_set_name, (void *addr, const char *objtype, const char *objname, int attribute)) +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#define __itt_sync_set_nameA ITTNOTIFY_VOID(sync_set_nameA) +#define __itt_sync_set_nameA_ptr ITTNOTIFY_NAME(sync_set_nameA) +#define __itt_sync_set_nameW ITTNOTIFY_VOID(sync_set_nameW) +#define __itt_sync_set_nameW_ptr ITTNOTIFY_NAME(sync_set_nameW) +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#define __itt_sync_set_name 
ITTNOTIFY_VOID(sync_set_name) +#define __itt_sync_set_name_ptr ITTNOTIFY_NAME(sync_set_name) +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#else /* INTEL_NO_ITTNOTIFY_API */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#define __itt_sync_set_nameA(addr, objtype, objname, attribute) +#define __itt_sync_set_nameA_ptr 0 +#define __itt_sync_set_nameW(addr, objtype, objname, attribute) +#define __itt_sync_set_nameW_ptr 0 +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#define __itt_sync_set_name(addr, objtype, objname, attribute) +#define __itt_sync_set_name_ptr 0 +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#define __itt_sync_set_nameA_ptr 0 +#define __itt_sync_set_nameW_ptr 0 +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#define __itt_sync_set_name_ptr 0 +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +/** + * @deprecated Legacy API + * @brief Assign a name and type to a sync object using char or Unicode string + * @param[in] addr - pointer to the sync object. You should use a real pointer to your object + * to make sure that the values don't clash with other object addresses + * @param[in] objtype - null-terminated object type string. If NULL is passed, the object will + * be assumed to be of generic "User Synchronization" type + * @param[in] objname - null-terminated object name string. If NULL, no name will be assigned + * to the object -- you can use the __itt_sync_rename call later to assign + * the name + * @param[in] typelen, namelen - a length of string for appropriate objtype and objname parameter + * @param[in] attribute - one of [#__itt_attr_barrier, #__itt_attr_mutex] values which defines the + * exact semantics of how prepare/acquired/releasing calls work. 
+ * @return __itt_err upon failure (name or namelen being null,name and namelen mismatched) + */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +int LIBITTAPI __itt_notify_sync_nameA(void *addr, const char *objtype, int typelen, const char *objname, int namelen, int attribute); +int LIBITTAPI __itt_notify_sync_nameW(void *addr, const wchar_t *objtype, int typelen, const wchar_t *objname, int namelen, int attribute); +#if defined(UNICODE) || defined(_UNICODE) +# define __itt_notify_sync_name __itt_notify_sync_nameW +#else +# define __itt_notify_sync_name __itt_notify_sync_nameA +#endif +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +int LIBITTAPI __itt_notify_sync_name(void *addr, const char *objtype, int typelen, const char *objname, int namelen, int attribute); +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +#if ITT_PLATFORM==ITT_PLATFORM_WIN +ITT_STUB(LIBITTAPI, int, notify_sync_nameA, (void *addr, const char *objtype, int typelen, const char *objname, int namelen, int attribute)) +ITT_STUB(LIBITTAPI, int, notify_sync_nameW, (void *addr, const wchar_t *objtype, int typelen, const wchar_t *objname, int namelen, int attribute)) +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +ITT_STUB(LIBITTAPI, int, notify_sync_name, (void *addr, const char *objtype, int typelen, const char *objname, int namelen, int attribute)) +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#define __itt_notify_sync_nameA ITTNOTIFY_DATA(notify_sync_nameA) +#define __itt_notify_sync_nameA_ptr ITTNOTIFY_NAME(notify_sync_nameA) +#define __itt_notify_sync_nameW ITTNOTIFY_DATA(notify_sync_nameW) +#define __itt_notify_sync_nameW_ptr ITTNOTIFY_NAME(notify_sync_nameW) +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#define __itt_notify_sync_name ITTNOTIFY_DATA(notify_sync_name) +#define __itt_notify_sync_name_ptr ITTNOTIFY_NAME(notify_sync_name) +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ 
+#else /* INTEL_NO_ITTNOTIFY_API */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#define __itt_notify_sync_nameA(addr, objtype, typelen, objname, namelen, attribute) +#define __itt_notify_sync_nameA_ptr 0 +#define __itt_notify_sync_nameW(addr, objtype, typelen, objname, namelen, attribute) +#define __itt_notify_sync_nameW_ptr 0 +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#define __itt_notify_sync_name(addr, objtype, typelen, objname, namelen, attribute) +#define __itt_notify_sync_name_ptr 0 +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#define __itt_notify_sync_nameA_ptr 0 +#define __itt_notify_sync_nameW_ptr 0 +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#define __itt_notify_sync_name_ptr 0 +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +/** + * @deprecated Legacy API + * @brief Enter spin loop on user-defined sync object + */ +void LIBITTAPI __itt_notify_sync_prepare(void* addr); + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUBV(LIBITTAPI, void, notify_sync_prepare, (void *addr)) +#define __itt_notify_sync_prepare ITTNOTIFY_VOID(notify_sync_prepare) +#define __itt_notify_sync_prepare_ptr ITTNOTIFY_NAME(notify_sync_prepare) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_notify_sync_prepare(addr) +#define __itt_notify_sync_prepare_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_notify_sync_prepare_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +/** + * @deprecated Legacy API + * @brief Quit spin loop without acquiring spin object + */ +void LIBITTAPI __itt_notify_sync_cancel(void *addr); + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUBV(LIBITTAPI, void, notify_sync_cancel, (void *addr)) +#define __itt_notify_sync_cancel 
ITTNOTIFY_VOID(notify_sync_cancel) +#define __itt_notify_sync_cancel_ptr ITTNOTIFY_NAME(notify_sync_cancel) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_notify_sync_cancel(addr) +#define __itt_notify_sync_cancel_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_notify_sync_cancel_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +/** + * @deprecated Legacy API + * @brief Successful spin loop completion (sync object acquired) + */ +void LIBITTAPI __itt_notify_sync_acquired(void *addr); + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUBV(LIBITTAPI, void, notify_sync_acquired, (void *addr)) +#define __itt_notify_sync_acquired ITTNOTIFY_VOID(notify_sync_acquired) +#define __itt_notify_sync_acquired_ptr ITTNOTIFY_NAME(notify_sync_acquired) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_notify_sync_acquired(addr) +#define __itt_notify_sync_acquired_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_notify_sync_acquired_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +/** + * @deprecated Legacy API + * @brief Start sync object releasing code. Is called before the lock release call. 
+ */ +void LIBITTAPI __itt_notify_sync_releasing(void* addr); + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUBV(LIBITTAPI, void, notify_sync_releasing, (void *addr)) +#define __itt_notify_sync_releasing ITTNOTIFY_VOID(notify_sync_releasing) +#define __itt_notify_sync_releasing_ptr ITTNOTIFY_NAME(notify_sync_releasing) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_notify_sync_releasing(addr) +#define __itt_notify_sync_releasing_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_notify_sync_releasing_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ +/** @} legacy_sync group */ + +#ifndef _ITTNOTIFY_H_ +/** + * @defgroup legacy_events Events + * @ingroup legacy + * Events group + * @{ + */ + +/** @brief user event type */ +typedef int __itt_event; + +/** + * @brief Create an event notification + * @note name or namelen being null/name and namelen not matching, user event feature not enabled + * @return non-zero event identifier upon success and __itt_err otherwise + */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +__itt_event LIBITTAPI __itt_event_createA(const char *name, int namelen); +__itt_event LIBITTAPI __itt_event_createW(const wchar_t *name, int namelen); +#if defined(UNICODE) || defined(_UNICODE) +# define __itt_event_create __itt_event_createW +# define __itt_event_create_ptr __itt_event_createW_ptr +#else +# define __itt_event_create __itt_event_createA +# define __itt_event_create_ptr __itt_event_createA_ptr +#endif /* UNICODE */ +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +__itt_event LIBITTAPI __itt_event_create(const char *name, int namelen); +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +#if ITT_PLATFORM==ITT_PLATFORM_WIN +ITT_STUB(LIBITTAPI, __itt_event, event_createA, (const char *name, int namelen)) +ITT_STUB(LIBITTAPI, __itt_event, 
event_createW, (const wchar_t *name, int namelen)) +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +ITT_STUB(LIBITTAPI, __itt_event, event_create, (const char *name, int namelen)) +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#define __itt_event_createA ITTNOTIFY_DATA(event_createA) +#define __itt_event_createA_ptr ITTNOTIFY_NAME(event_createA) +#define __itt_event_createW ITTNOTIFY_DATA(event_createW) +#define __itt_event_createW_ptr ITTNOTIFY_NAME(event_createW) +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#define __itt_event_create ITTNOTIFY_DATA(event_create) +#define __itt_event_create_ptr ITTNOTIFY_NAME(event_create) +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#else /* INTEL_NO_ITTNOTIFY_API */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#define __itt_event_createA(name, namelen) (__itt_event)0 +#define __itt_event_createA_ptr 0 +#define __itt_event_createW(name, namelen) (__itt_event)0 +#define __itt_event_createW_ptr 0 +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#define __itt_event_create(name, namelen) (__itt_event)0 +#define __itt_event_create_ptr 0 +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#define __itt_event_createA_ptr 0 +#define __itt_event_createW_ptr 0 +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#define __itt_event_create_ptr 0 +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +/** + * @brief Record an event occurrence. 
+ * @return __itt_err upon failure (invalid event id/user event feature not enabled) + */ +int LIBITTAPI __itt_event_start(__itt_event event); + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUB(LIBITTAPI, int, event_start, (__itt_event event)) +#define __itt_event_start ITTNOTIFY_DATA(event_start) +#define __itt_event_start_ptr ITTNOTIFY_NAME(event_start) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_event_start(event) (int)0 +#define __itt_event_start_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_event_start_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +/** + * @brief Record an event end occurrence. + * @note It is optional if events do not have durations. + * @return __itt_err upon failure (invalid event id/user event feature not enabled) + */ +int LIBITTAPI __itt_event_end(__itt_event event); + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUB(LIBITTAPI, int, event_end, (__itt_event event)) +#define __itt_event_end ITTNOTIFY_DATA(event_end) +#define __itt_event_end_ptr ITTNOTIFY_NAME(event_end) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_event_end(event) (int)0 +#define __itt_event_end_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_event_end_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ +/** @} legacy_events group */ +#endif /* _ITTNOTIFY_H_ */ + +/** + * @defgroup legacy_memory Memory Accesses + * @ingroup legacy + */ + +/** + * @deprecated Legacy API + * @brief Inform the tool of memory accesses on reading + */ +void LIBITTAPI __itt_memory_read(void *addr, size_t size); + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUBV(LIBITTAPI, void, memory_read, (void *addr, size_t size)) +#define __itt_memory_read ITTNOTIFY_VOID(memory_read) +#define 
__itt_memory_read_ptr ITTNOTIFY_NAME(memory_read) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_memory_read(addr, size) +#define __itt_memory_read_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_memory_read_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +/** + * @deprecated Legacy API + * @brief Inform the tool of memory accesses on writing + */ +void LIBITTAPI __itt_memory_write(void *addr, size_t size); + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUBV(LIBITTAPI, void, memory_write, (void *addr, size_t size)) +#define __itt_memory_write ITTNOTIFY_VOID(memory_write) +#define __itt_memory_write_ptr ITTNOTIFY_NAME(memory_write) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_memory_write(addr, size) +#define __itt_memory_write_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_memory_write_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +/** + * @deprecated Legacy API + * @brief Inform the tool of memory accesses on updating + */ +void LIBITTAPI __itt_memory_update(void *address, size_t size); + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUBV(LIBITTAPI, void, memory_update, (void *addr, size_t size)) +#define __itt_memory_update ITTNOTIFY_VOID(memory_update) +#define __itt_memory_update_ptr ITTNOTIFY_NAME(memory_update) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_memory_update(addr, size) +#define __itt_memory_update_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_memory_update_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ +/** @} legacy_memory group */ + +/** + * @defgroup legacy_state Thread and Object States + * @ingroup legacy + */ + +/** @brief state type */ +typedef int __itt_state_t; + +/** @cond exclude_from_documentation */ +typedef enum __itt_obj_state { + 
__itt_obj_state_err = 0, + __itt_obj_state_clr = 1, + __itt_obj_state_set = 2, + __itt_obj_state_use = 3 +} __itt_obj_state_t; + +typedef enum __itt_thr_state { + __itt_thr_state_err = 0, + __itt_thr_state_clr = 1, + __itt_thr_state_set = 2 +} __itt_thr_state_t; + +typedef enum __itt_obj_prop { + __itt_obj_prop_watch = 1, + __itt_obj_prop_ignore = 2, + __itt_obj_prop_sharable = 3 +} __itt_obj_prop_t; + +typedef enum __itt_thr_prop { + __itt_thr_prop_quiet = 1 +} __itt_thr_prop_t; +/** @endcond */ + +/** + * @deprecated Legacy API + * @brief managing thread and object states + */ +__itt_state_t LIBITTAPI __itt_state_get(void); + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUB(ITTAPI, __itt_state_t, state_get, (void)) +#define __itt_state_get ITTNOTIFY_DATA(state_get) +#define __itt_state_get_ptr ITTNOTIFY_NAME(state_get) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_state_get(void) (__itt_state_t)0 +#define __itt_state_get_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_state_get_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +/** + * @deprecated Legacy API + * @brief managing thread and object states + */ +__itt_state_t LIBITTAPI __itt_state_set(__itt_state_t s); + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUB(ITTAPI, __itt_state_t, state_set, (__itt_state_t s)) +#define __itt_state_set ITTNOTIFY_DATA(state_set) +#define __itt_state_set_ptr ITTNOTIFY_NAME(state_set) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_state_set(s) (__itt_state_t)0 +#define __itt_state_set_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_state_set_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +/** + * @deprecated Legacy API + * @brief managing thread and object modes + */ +__itt_thr_state_t LIBITTAPI __itt_thr_mode_set(__itt_thr_prop_t p, 
__itt_thr_state_t s); + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUB(ITTAPI, __itt_thr_state_t, thr_mode_set, (__itt_thr_prop_t p, __itt_thr_state_t s)) +#define __itt_thr_mode_set ITTNOTIFY_DATA(thr_mode_set) +#define __itt_thr_mode_set_ptr ITTNOTIFY_NAME(thr_mode_set) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_thr_mode_set(p, s) (__itt_thr_state_t)0 +#define __itt_thr_mode_set_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_thr_mode_set_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +/** + * @deprecated Legacy API + * @brief managing thread and object modes + */ +__itt_obj_state_t LIBITTAPI __itt_obj_mode_set(__itt_obj_prop_t p, __itt_obj_state_t s); + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUB(ITTAPI, __itt_obj_state_t, obj_mode_set, (__itt_obj_prop_t p, __itt_obj_state_t s)) +#define __itt_obj_mode_set ITTNOTIFY_DATA(obj_mode_set) +#define __itt_obj_mode_set_ptr ITTNOTIFY_NAME(obj_mode_set) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_obj_mode_set(p, s) (__itt_obj_state_t)0 +#define __itt_obj_mode_set_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_obj_mode_set_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ +/** @} legacy_state group */ + +/** + * @defgroup frames Frames + * @ingroup legacy + * Frames group + * @{ + */ +/** + * @brief opaque structure for frame identification + */ +typedef struct __itt_frame_t *__itt_frame; + +/** + * @brief Create a global frame with given domain + */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +__itt_frame ITTAPI __itt_frame_createA(const char *domain); +__itt_frame ITTAPI __itt_frame_createW(const wchar_t *domain); +#if defined(UNICODE) || defined(_UNICODE) +# define __itt_frame_create __itt_frame_createW +# define __itt_frame_create_ptr __itt_frame_createW_ptr +#else /* UNICODE 
*/ +# define __itt_frame_create __itt_frame_createA +# define __itt_frame_create_ptr __itt_frame_createA_ptr +#endif /* UNICODE */ +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +__itt_frame ITTAPI __itt_frame_create(const char *domain); +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +#if ITT_PLATFORM==ITT_PLATFORM_WIN +ITT_STUB(ITTAPI, __itt_frame, frame_createA, (const char *domain)) +ITT_STUB(ITTAPI, __itt_frame, frame_createW, (const wchar_t *domain)) +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +ITT_STUB(ITTAPI, __itt_frame, frame_create, (const char *domain)) +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#define __itt_frame_createA ITTNOTIFY_DATA(frame_createA) +#define __itt_frame_createA_ptr ITTNOTIFY_NAME(frame_createA) +#define __itt_frame_createW ITTNOTIFY_DATA(frame_createW) +#define __itt_frame_createW_ptr ITTNOTIFY_NAME(frame_createW) +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#define __itt_frame_create ITTNOTIFY_DATA(frame_create) +#define __itt_frame_create_ptr ITTNOTIFY_NAME(frame_create) +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#else /* INTEL_NO_ITTNOTIFY_API */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#define __itt_frame_createA(domain) +#define __itt_frame_createA_ptr 0 +#define __itt_frame_createW(domain) +#define __itt_frame_createW_ptr 0 +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#define __itt_frame_create(domain) +#define __itt_frame_create_ptr 0 +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#define __itt_frame_createA_ptr 0 +#define __itt_frame_createW_ptr 0 +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#define __itt_frame_create_ptr 0 +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +/** @brief Record a frame begin occurrence. 
*/ +void ITTAPI __itt_frame_begin(__itt_frame frame); +/** @brief Record a frame end occurrence. */ +void ITTAPI __itt_frame_end (__itt_frame frame); + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUBV(ITTAPI, void, frame_begin, (__itt_frame frame)) +ITT_STUBV(ITTAPI, void, frame_end, (__itt_frame frame)) +#define __itt_frame_begin ITTNOTIFY_VOID(frame_begin) +#define __itt_frame_begin_ptr ITTNOTIFY_NAME(frame_begin) +#define __itt_frame_end ITTNOTIFY_VOID(frame_end) +#define __itt_frame_end_ptr ITTNOTIFY_NAME(frame_end) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_frame_begin(frame) +#define __itt_frame_begin_ptr 0 +#define __itt_frame_end(frame) +#define __itt_frame_end_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_frame_begin_ptr 0 +#define __itt_frame_end_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ +/** @} frames group */ + +#ifdef __cplusplus +} +#endif /* __cplusplus */ + +#endif /* _LEGACY_ITTNOTIFY_H_ */ + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/pybind11/attr.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/pybind11/attr.h new file mode 100644 index 0000000000000000000000000000000000000000..cdf277758c1b84ed67ebd8199abfcf9198ef4a65 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/pybind11/attr.h @@ -0,0 +1,727 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +/* + pybind11/attr.h: Infrastructure for processing custom + type and function attributes + + Copyright (c) 2016 Wenzel Jakob + + All rights reserved. Use of this source code is governed by a + BSD-style license that can be found in the LICENSE file. 
+*/ + +#pragma once + +#include "detail/common.h" +#include "cast.h" +#include "trampoline_self_life_support.h" + +#include + +PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE) + +/// \addtogroup annotations +/// @{ + +/// Annotation for methods +struct is_method { + handle class_; + explicit is_method(const handle &c) : class_(c) {} +}; + +/// Annotation for setters +struct is_setter {}; + +/// Annotation for operators +struct is_operator {}; + +/// Annotation for classes that cannot be subclassed +struct is_final {}; + +/// Annotation for parent scope +struct scope { + handle value; + explicit scope(const handle &s) : value(s) {} +}; + +/// Annotation for documentation +struct doc { + const char *value; + explicit doc(const char *value) : value(value) {} +}; + +/// Annotation for function names +struct name { + const char *value; + explicit name(const char *value) : value(value) {} +}; + +/// Annotation indicating that a function is an overload associated with a given "sibling" +struct sibling { + handle value; + explicit sibling(const handle &value) : value(value.ptr()) {} +}; + +/// Annotation indicating that a class derives from another given type +template +struct base { + + PYBIND11_DEPRECATED( + "base() was deprecated in favor of specifying 'T' as a template argument to class_") + base() = default; +}; + +/// Keep patient alive while nurse lives +template +struct keep_alive {}; + +/// Annotation indicating that a class is involved in a multiple inheritance relationship +struct multiple_inheritance {}; + +/// Annotation which enables dynamic attributes, i.e. adds `__dict__` to a class +struct dynamic_attr {}; + +/// Annotation which enables the buffer protocol for a type +struct buffer_protocol {}; + +/// Annotation which enables releasing the GIL before calling the C++ destructor of wrapped +/// instances (pybind/pybind11#1446). 
+struct release_gil_before_calling_cpp_dtor {}; + +/// Annotation which requests that a special metaclass is created for a type +struct metaclass { + handle value; + + PYBIND11_DEPRECATED("py::metaclass() is no longer required. It's turned on by default now.") + metaclass() = default; + + /// Override pybind11's default metaclass + explicit metaclass(handle value) : value(value) {} +}; + +/// Specifies a custom callback with signature `void (PyHeapTypeObject*)` that +/// may be used to customize the Python type. +/// +/// The callback is invoked immediately before `PyType_Ready`. +/// +/// Note: This is an advanced interface, and uses of it may require changes to +/// work with later versions of pybind11. You may wish to consult the +/// implementation of `make_new_python_type` in `detail/classes.h` to understand +/// the context in which the callback will be run. +struct custom_type_setup { + using callback = std::function; + + explicit custom_type_setup(callback value) : value(std::move(value)) {} + + callback value; +}; + +/// Annotation that marks a class as local to the module: +struct module_local { + const bool value; + constexpr explicit module_local(bool v = true) : value(v) {} +}; + +/// Annotation to mark enums as an arithmetic type +struct arithmetic {}; + +/// Mark a function for addition at the beginning of the existing overload chain instead of the end +struct prepend {}; + +/** \rst + A call policy which places one or more guard variables (``Ts...``) around the function call. + + For example, this definition: + + .. code-block:: cpp + + m.def("foo", foo, py::call_guard()); + + is equivalent to the following pseudocode: + + .. code-block:: cpp + + m.def("foo", [](args...) 
{ + T scope_guard; + return foo(args...); // forwarded arguments + }); + \endrst */ +template +struct call_guard; + +template <> +struct call_guard<> { + using type = detail::void_type; +}; + +template +struct call_guard { + static_assert(std::is_default_constructible::value, + "The guard type must be default constructible"); + + using type = T; +}; + +template +struct call_guard { + struct type { + T guard{}; // Compose multiple guard types with left-to-right default-constructor order + typename call_guard::type next{}; + }; +}; + +/// @} annotations + +PYBIND11_NAMESPACE_BEGIN(detail) +/* Forward declarations */ +enum op_id : int; +enum op_type : int; +struct undefined_t; +template +struct op_; +void keep_alive_impl(size_t Nurse, size_t Patient, function_call &call, handle ret); + +/// Internal data structure which holds metadata about a keyword argument +struct argument_record { + const char *name; ///< Argument name + const char *descr; ///< Human-readable version of the argument value + handle value; ///< Associated Python object + bool convert : 1; ///< True if the argument is allowed to convert when loading + bool none : 1; ///< True if None is allowed when loading + + argument_record(const char *name, const char *descr, handle value, bool convert, bool none) + : name(name), descr(descr), value(value), convert(convert), none(none) {} +}; + +/// Internal data structure which holds metadata about a bound function (signature, overloads, +/// etc.) +#define PYBIND11_DETAIL_FUNCTION_RECORD_ABI_ID "v1" // PLEASE UPDATE if the struct is changed. +struct function_record { + function_record() + : is_constructor(false), is_new_style_constructor(false), is_stateless(false), + is_operator(false), is_method(false), is_setter(false), has_args(false), + has_kwargs(false), prepend(false) {} + + /// Function name + char *name = nullptr; /* why no C++ strings? They generate heavier code.. 
*/ + + // User-specified documentation string + char *doc = nullptr; + + /// Human-readable version of the function signature + char *signature = nullptr; + + /// List of registered keyword arguments + std::vector args; + + /// Pointer to lambda function which converts arguments and performs the actual call + handle (*impl)(function_call &) = nullptr; + + /// Storage for the wrapped function pointer and captured data, if any + void *data[3] = {}; + + /// Pointer to custom destructor for 'data' (if needed) + void (*free_data)(function_record *ptr) = nullptr; + + /// Return value policy associated with this function + return_value_policy policy = return_value_policy::automatic; + + /// True if name == '__init__' + bool is_constructor : 1; + + /// True if this is a new-style `__init__` defined in `detail/init.h` + bool is_new_style_constructor : 1; + + /// True if this is a stateless function pointer + bool is_stateless : 1; + + /// True if this is an operator (__add__), etc. + bool is_operator : 1; + + /// True if this is a method + bool is_method : 1; + + /// True if this is a setter + bool is_setter : 1; + + /// True if the function has a '*args' argument + bool has_args : 1; + + /// True if the function has a '**kwargs' argument + bool has_kwargs : 1; + + /// True if this function is to be inserted at the beginning of the overload resolution chain + bool prepend : 1; + + /// Number of arguments (including py::args and/or py::kwargs, if present) + std::uint16_t nargs; + + /// Number of leading positional arguments, which are terminated by a py::args or py::kwargs + /// argument or by a py::kw_only annotation. 
+ std::uint16_t nargs_pos = 0; + + /// Number of leading arguments (counted in `nargs`) that are positional-only + std::uint16_t nargs_pos_only = 0; + + /// Python method object + PyMethodDef *def = nullptr; + + /// Python handle to the parent scope (a class or a module) + handle scope; + + /// Python handle to the sibling function representing an overload chain + handle sibling; + + /// Pointer to next overload + function_record *next = nullptr; +}; +// The main purpose of this macro is to make it easy to pin-point the critically related code +// sections. +#define PYBIND11_ENSURE_PRECONDITION_FOR_FUNCTIONAL_H_PERFORMANCE_OPTIMIZATIONS(...) \ + static_assert( \ + __VA_ARGS__, \ + "Violation of precondition for pybind11/functional.h performance optimizations!") + +/// Special data structure which (temporarily) holds metadata about a bound class +struct type_record { + PYBIND11_NOINLINE type_record() + : multiple_inheritance(false), dynamic_attr(false), buffer_protocol(false), + module_local(false), is_final(false), release_gil_before_calling_cpp_dtor(false) {} + + /// Handle to the parent scope + handle scope; + + /// Name of the class + const char *name = nullptr; + + // Pointer to RTTI type_info data structure + const std::type_info *type = nullptr; + + /// How large is the underlying C++ type? + size_t type_size = 0; + + /// What is the alignment of the underlying C++ type? + size_t type_align = 0; + + /// How large is the type's holder? + size_t holder_size = 0; + + /// The global operator new can be overridden with a class-specific variant + void *(*operator_new)(size_t) = nullptr; + + /// Function pointer to class_<..>::init_instance + void (*init_instance)(instance *, const void *) = nullptr; + + /// Function pointer to class_<..>::dealloc + void (*dealloc)(detail::value_and_holder &) = nullptr; + + /// Function pointer for casting alias class (aka trampoline) pointer to + /// trampoline_self_life_support pointer. 
Sidesteps cross-DSO RTTI issues + /// on platforms like macOS (see PR #5728 for details). + get_trampoline_self_life_support_fn get_trampoline_self_life_support + = [](void *) -> trampoline_self_life_support * { return nullptr; }; + + /// List of base classes of the newly created type + list bases; + + /// Optional docstring + const char *doc = nullptr; + + /// Custom metaclass (optional) + handle metaclass; + + /// Custom type setup. + custom_type_setup::callback custom_type_setup_callback; + + /// Multiple inheritance marker + bool multiple_inheritance : 1; + + /// Does the class manage a __dict__? + bool dynamic_attr : 1; + + /// Does the class implement the buffer protocol? + bool buffer_protocol : 1; + + /// Is the class definition local to the module shared object? + bool module_local : 1; + + /// Is the class inheritable from python classes? + bool is_final : 1; + + /// Solves pybind/pybind11#1446 + bool release_gil_before_calling_cpp_dtor : 1; + + holder_enum_t holder_enum_v = holder_enum_t::undefined; + + PYBIND11_NOINLINE void add_base(const std::type_info &base, void *(*caster)(void *) ) { + auto *base_info = detail::get_type_info(base, false); + if (!base_info) { + std::string tname(base.name()); + detail::clean_type_id(tname); + pybind11_fail("generic_type: type \"" + std::string(name) + + "\" referenced unknown base type \"" + tname + "\""); + } + + // SMART_HOLDER_BAKEIN_FOLLOW_ON: Refine holder compatibility checks. + bool this_has_unique_ptr_holder = (holder_enum_v == holder_enum_t::std_unique_ptr); + bool base_has_unique_ptr_holder + = (base_info->holder_enum_v == holder_enum_t::std_unique_ptr); + if (this_has_unique_ptr_holder != base_has_unique_ptr_holder) { + std::string tname(base.name()); + detail::clean_type_id(tname); + pybind11_fail("generic_type: type \"" + std::string(name) + "\" " + + (this_has_unique_ptr_holder ? 
"does not have" : "has") + + " a non-default holder type while its base \"" + tname + "\" " + + (base_has_unique_ptr_holder ? "does not" : "does")); + } + + bases.append((PyObject *) base_info->type); + +#ifdef PYBIND11_BACKWARD_COMPATIBILITY_TP_DICTOFFSET + dynamic_attr |= base_info->type->tp_dictoffset != 0; +#else + dynamic_attr |= (base_info->type->tp_flags & Py_TPFLAGS_MANAGED_DICT) != 0; +#endif + + if (caster) { + base_info->implicit_casts.emplace_back(type, caster); + } + } +}; + +inline function_call::function_call(const function_record &f, handle p) : func(f), parent(p) { + args.reserve(f.nargs); + args_convert.reserve(f.nargs); +} + +/// Tag for a new-style `__init__` defined in `detail/init.h` +struct is_new_style_constructor {}; + +/** + * Partial template specializations to process custom attributes provided to + * cpp_function_ and class_. These are either used to initialize the respective + * fields in the type_record and function_record data structures or executed at + * runtime to deal with custom call policies (e.g. keep_alive). 
+ */ +template +struct process_attribute; + +template +struct process_attribute_default { + /// Default implementation: do nothing + static void init(const T &, function_record *) {} + static void init(const T &, type_record *) {} + static void precall(function_call &) {} + static void postcall(function_call &, handle) {} +}; + +/// Process an attribute specifying the function's name +template <> +struct process_attribute : process_attribute_default { + static void init(const name &n, function_record *r) { r->name = const_cast(n.value); } +}; + +/// Process an attribute specifying the function's docstring +template <> +struct process_attribute : process_attribute_default { + static void init(const doc &n, function_record *r) { r->doc = const_cast(n.value); } +}; + +/// Process an attribute specifying the function's docstring (provided as a C-style string) +template <> +struct process_attribute : process_attribute_default { + static void init(const char *d, function_record *r) { r->doc = const_cast(d); } + static void init(const char *d, type_record *r) { r->doc = d; } +}; +template <> +struct process_attribute : process_attribute {}; + +/// Process an attribute indicating the function's return value policy +template <> +struct process_attribute : process_attribute_default { + static void init(const return_value_policy &p, function_record *r) { r->policy = p; } +}; + +/// Process an attribute which indicates that this is an overloaded function associated with a +/// given sibling +template <> +struct process_attribute : process_attribute_default { + static void init(const sibling &s, function_record *r) { r->sibling = s.value; } +}; + +/// Process an attribute which indicates that this function is a method +template <> +struct process_attribute : process_attribute_default { + static void init(const is_method &s, function_record *r) { + r->is_method = true; + r->scope = s.class_; + } +}; + +/// Process an attribute which indicates that this function is a setter 
+template <> +struct process_attribute : process_attribute_default { + static void init(const is_setter &, function_record *r) { r->is_setter = true; } +}; + +/// Process an attribute which indicates the parent scope of a method +template <> +struct process_attribute : process_attribute_default { + static void init(const scope &s, function_record *r) { r->scope = s.value; } +}; + +/// Process an attribute which indicates that this function is an operator +template <> +struct process_attribute : process_attribute_default { + static void init(const is_operator &, function_record *r) { r->is_operator = true; } +}; + +template <> +struct process_attribute + : process_attribute_default { + static void init(const is_new_style_constructor &, function_record *r) { + r->is_new_style_constructor = true; + } +}; + +inline void check_kw_only_arg(const arg &a, function_record *r) { + if (r->args.size() > r->nargs_pos && (!a.name || a.name[0] == '\0')) { + pybind11_fail("arg(): cannot specify an unnamed argument after a kw_only() annotation or " + "args() argument"); + } +} + +inline void append_self_arg_if_needed(function_record *r) { + if (r->is_method && r->args.empty()) { + r->args.emplace_back("self", nullptr, handle(), /*convert=*/true, /*none=*/false); + } +} + +/// Process a keyword argument attribute (*without* a default value) +template <> +struct process_attribute : process_attribute_default { + static void init(const arg &a, function_record *r) { + append_self_arg_if_needed(r); + r->args.emplace_back(a.name, nullptr, handle(), !a.flag_noconvert, a.flag_none); + + check_kw_only_arg(a, r); + } +}; + +/// Process a keyword argument attribute (*with* a default value) +template <> +struct process_attribute : process_attribute_default { + static void init(const arg_v &a, function_record *r) { + if (r->is_method && r->args.empty()) { + r->args.emplace_back( + "self", /*descr=*/nullptr, /*parent=*/handle(), /*convert=*/true, /*none=*/false); + } + + if (!a.value) { +#if 
defined(PYBIND11_DETAILED_ERROR_MESSAGES) + std::string descr("'"); + if (a.name) { + descr += std::string(a.name) + ": "; + } + descr += a.type + "'"; + if (r->is_method) { + if (r->name) { + descr += " in method '" + (std::string) str(r->scope) + "." + + (std::string) r->name + "'"; + } else { + descr += " in method of '" + (std::string) str(r->scope) + "'"; + } + } else if (r->name) { + descr += " in function '" + (std::string) r->name + "'"; + } + pybind11_fail("arg(): could not convert default argument " + descr + + " into a Python object (type not registered yet?)"); +#else + pybind11_fail("arg(): could not convert default argument " + "into a Python object (type not registered yet?). " + "#define PYBIND11_DETAILED_ERROR_MESSAGES or compile in debug mode for " + "more information."); +#endif + } + r->args.emplace_back(a.name, a.descr, a.value.inc_ref(), !a.flag_noconvert, a.flag_none); + + check_kw_only_arg(a, r); + } +}; + +/// Process a keyword-only-arguments-follow pseudo argument +template <> +struct process_attribute : process_attribute_default { + static void init(const kw_only &, function_record *r) { + append_self_arg_if_needed(r); + if (r->has_args && r->nargs_pos != static_cast(r->args.size())) { + pybind11_fail("Mismatched args() and kw_only(): they must occur at the same relative " + "argument location (or omit kw_only() entirely)"); + } + r->nargs_pos = static_cast(r->args.size()); + } +}; + +/// Process a positional-only-argument maker +template <> +struct process_attribute : process_attribute_default { + static void init(const pos_only &, function_record *r) { + append_self_arg_if_needed(r); + r->nargs_pos_only = static_cast(r->args.size()); + if (r->nargs_pos_only > r->nargs_pos) { + pybind11_fail("pos_only(): cannot follow a py::args() argument"); + } + // It also can't follow a kw_only, but a static_assert in pybind11.h checks that + } +}; + +/// Process a parent class attribute. 
Single inheritance only (class_ itself already guarantees +/// that) +template +struct process_attribute::value>> + : process_attribute_default { + static void init(const handle &h, type_record *r) { r->bases.append(h); } +}; + +/// Process a parent class attribute (deprecated, does not support multiple inheritance) +template +struct process_attribute> : process_attribute_default> { + static void init(const base &, type_record *r) { r->add_base(typeid(T), nullptr); } +}; + +/// Process a multiple inheritance attribute +template <> +struct process_attribute : process_attribute_default { + static void init(const multiple_inheritance &, type_record *r) { + r->multiple_inheritance = true; + } +}; + +template <> +struct process_attribute : process_attribute_default { + static void init(const dynamic_attr &, type_record *r) { r->dynamic_attr = true; } +}; + +template <> +struct process_attribute { + static void init(const custom_type_setup &value, type_record *r) { + r->custom_type_setup_callback = value.value; + } +}; + +template <> +struct process_attribute : process_attribute_default { + static void init(const is_final &, type_record *r) { r->is_final = true; } +}; + +template <> +struct process_attribute : process_attribute_default { + static void init(const buffer_protocol &, type_record *r) { r->buffer_protocol = true; } +}; + +template <> +struct process_attribute : process_attribute_default { + static void init(const metaclass &m, type_record *r) { r->metaclass = m.value; } +}; + +template <> +struct process_attribute : process_attribute_default { + static void init(const module_local &l, type_record *r) { r->module_local = l.value; } +}; + +template <> +struct process_attribute + : process_attribute_default { + static void init(const release_gil_before_calling_cpp_dtor &, type_record *r) { + r->release_gil_before_calling_cpp_dtor = true; + } +}; + +/// Process a 'prepend' attribute, putting this at the beginning of the overload chain +template <> +struct 
process_attribute : process_attribute_default { + static void init(const prepend &, function_record *r) { r->prepend = true; } +}; + +/// Process an 'arithmetic' attribute for enums (does nothing here) +template <> +struct process_attribute : process_attribute_default {}; + +template +struct process_attribute> : process_attribute_default> {}; + +/** + * Process a keep_alive call policy -- invokes keep_alive_impl during the + * pre-call handler if both Nurse, Patient != 0 and use the post-call handler + * otherwise + */ +template +struct process_attribute> + : public process_attribute_default> { + template = 0> + static void precall(function_call &call) { + keep_alive_impl(Nurse, Patient, call, handle()); + } + template = 0> + static void postcall(function_call &, handle) {} + template = 0> + static void precall(function_call &) {} + template = 0> + static void postcall(function_call &call, handle ret) { + keep_alive_impl(Nurse, Patient, call, ret); + } +}; + +/// Recursively iterate over variadic template arguments +template +struct process_attributes { + static void init(const Args &...args, function_record *r) { + PYBIND11_WORKAROUND_INCORRECT_MSVC_C4100(r); + PYBIND11_WORKAROUND_INCORRECT_GCC_UNUSED_BUT_SET_PARAMETER(r); + using expander = int[]; + (void) expander{ + 0, ((void) process_attribute::type>::init(args, r), 0)...}; + } + static void init(const Args &...args, type_record *r) { + PYBIND11_WORKAROUND_INCORRECT_MSVC_C4100(r); + PYBIND11_WORKAROUND_INCORRECT_GCC_UNUSED_BUT_SET_PARAMETER(r); + using expander = int[]; + (void) expander{0, + (process_attribute::type>::init(args, r), 0)...}; + } + static void precall(function_call &call) { + PYBIND11_WORKAROUND_INCORRECT_MSVC_C4100(call); + using expander = int[]; + (void) expander{0, + (process_attribute::type>::precall(call), 0)...}; + } + static void postcall(function_call &call, handle fn_ret) { + PYBIND11_WORKAROUND_INCORRECT_MSVC_C4100(call, fn_ret); + 
PYBIND11_WORKAROUND_INCORRECT_GCC_UNUSED_BUT_SET_PARAMETER(fn_ret); + using expander = int[]; + (void) expander{ + 0, (process_attribute::type>::postcall(call, fn_ret), 0)...}; + } +}; + +template +using is_call_guard = is_instantiation; + +/// Extract the ``type`` from the first `call_guard` in `Extras...` (or `void_type` if none found) +template +using extract_guard_t = typename exactly_one_t, Extra...>::type; + +/// Check the number of named arguments at compile time +template ::value...), + size_t self = constexpr_sum(std::is_same::value...)> +constexpr bool expected_num_args(size_t nargs, bool has_args, bool has_kwargs) { + PYBIND11_WORKAROUND_INCORRECT_MSVC_C4100(nargs, has_args, has_kwargs); + return named == 0 || (self + named + size_t(has_args) + size_t(has_kwargs)) == nargs; +} + +PYBIND11_NAMESPACE_END(detail) +PYBIND11_NAMESPACE_END(PYBIND11_NAMESPACE) + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." +#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/pybind11/buffer_info.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/pybind11/buffer_info.h new file mode 100644 index 0000000000000000000000000000000000000000..78a026db838f59376549ee53908fc892cecc8817 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/pybind11/buffer_info.h @@ -0,0 +1,213 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +/* + pybind11/buffer_info.h: Python buffer object interface + + Copyright (c) 2016 Wenzel Jakob + + All rights reserved. Use of this source code is governed by a + BSD-style license that can be found in the LICENSE file. 
+*/ + +#pragma once + +#include "detail/common.h" + +PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE) + +PYBIND11_NAMESPACE_BEGIN(detail) + +// Default, C-style strides +inline std::vector c_strides(const std::vector &shape, ssize_t itemsize) { + auto ndim = shape.size(); + std::vector strides(ndim, itemsize); + if (ndim > 0) { + for (size_t i = ndim - 1; i > 0; --i) { + strides[i - 1] = strides[i] * shape[i]; + } + } + return strides; +} + +// F-style strides; default when constructing an array_t with `ExtraFlags & f_style` +inline std::vector f_strides(const std::vector &shape, ssize_t itemsize) { + auto ndim = shape.size(); + std::vector strides(ndim, itemsize); + for (size_t i = 1; i < ndim; ++i) { + strides[i] = strides[i - 1] * shape[i - 1]; + } + return strides; +} + +template +struct compare_buffer_info; + +PYBIND11_NAMESPACE_END(detail) + +/// Information record describing a Python buffer object +struct buffer_info { + void *ptr = nullptr; // Pointer to the underlying storage + ssize_t itemsize = 0; // Size of individual items in bytes + ssize_t size = 0; // Total number of entries + std::string format; // For homogeneous buffers, this should be set to + // format_descriptor::format() + ssize_t ndim = 0; // Number of dimensions + std::vector shape; // Shape of the tensor (1 entry per dimension) + std::vector strides; // Number of bytes between adjacent entries + // (for each per dimension) + bool readonly = false; // flag to indicate if the underlying storage may be written to + + buffer_info() = default; + + buffer_info(void *ptr, + ssize_t itemsize, + const std::string &format, + ssize_t ndim, + detail::any_container shape_in, + detail::any_container strides_in, + bool readonly = false) + : ptr(ptr), itemsize(itemsize), size(1), format(format), ndim(ndim), + shape(std::move(shape_in)), strides(std::move(strides_in)), readonly(readonly) { + if (ndim != (ssize_t) shape.size() || ndim != (ssize_t) strides.size()) { + pybind11_fail("buffer_info: ndim doesn't 
match shape and/or strides length"); + } + for (size_t i = 0; i < (size_t) ndim; ++i) { + size *= shape[i]; + } + } + + template + buffer_info(T *ptr, + detail::any_container shape_in, + detail::any_container strides_in, + bool readonly = false) + : buffer_info(private_ctr_tag(), + ptr, + sizeof(T), + format_descriptor::format(), + static_cast(shape_in->size()), + std::move(shape_in), + std::move(strides_in), + readonly) {} + + buffer_info(void *ptr, + ssize_t itemsize, + const std::string &format, + ssize_t size, + bool readonly = false) + : buffer_info(ptr, itemsize, format, 1, {size}, {itemsize}, readonly) {} + + template + buffer_info(T *ptr, ssize_t size, bool readonly = false) + : buffer_info(ptr, sizeof(T), format_descriptor::format(), size, readonly) {} + + template + buffer_info(const T *ptr, ssize_t size, bool readonly = true) + : buffer_info( + const_cast(ptr), sizeof(T), format_descriptor::format(), size, readonly) {} + + explicit buffer_info(Py_buffer *view, bool ownview = true) + : buffer_info( + view->buf, + view->itemsize, + view->format, + view->ndim, + {view->shape, view->shape + view->ndim}, + /* Though buffer::request() requests PyBUF_STRIDES, ctypes objects + * ignore this flag and return a view with NULL strides. + * When strides are NULL, build them manually. */ + view->strides + ? 
std::vector(view->strides, view->strides + view->ndim) + : detail::c_strides({view->shape, view->shape + view->ndim}, view->itemsize), + (view->readonly != 0)) { + // NOLINTNEXTLINE(cppcoreguidelines-prefer-member-initializer) + this->m_view = view; + // NOLINTNEXTLINE(cppcoreguidelines-prefer-member-initializer) + this->ownview = ownview; + } + + buffer_info(const buffer_info &) = delete; + buffer_info &operator=(const buffer_info &) = delete; + + buffer_info(buffer_info &&other) noexcept { (*this) = std::move(other); } + + buffer_info &operator=(buffer_info &&rhs) noexcept { + ptr = rhs.ptr; + itemsize = rhs.itemsize; + size = rhs.size; + format = std::move(rhs.format); + ndim = rhs.ndim; + shape = std::move(rhs.shape); + strides = std::move(rhs.strides); + std::swap(m_view, rhs.m_view); + std::swap(ownview, rhs.ownview); + readonly = rhs.readonly; + return *this; + } + + ~buffer_info() { + if (m_view && ownview) { + PyBuffer_Release(m_view); + delete m_view; + } + } + + Py_buffer *view() const { return m_view; } + Py_buffer *&view() { return m_view; } + + /* True if the buffer item type is equivalent to `T`. */ + // To define "equivalent" by example: + // `buffer_info::item_type_is_equivalent_to(b)` and + // `buffer_info::item_type_is_equivalent_to(b)` may both be true + // on some platforms, but `int` and `unsigned` will never be equivalent. + // For the ground truth, please inspect `detail::compare_buffer_info<>`. 
+ template + bool item_type_is_equivalent_to() const { + return detail::compare_buffer_info::compare(*this); + } + +private: + struct private_ctr_tag {}; + + buffer_info(private_ctr_tag, + void *ptr, + ssize_t itemsize, + const std::string &format, + ssize_t ndim, + detail::any_container &&shape_in, + detail::any_container &&strides_in, + bool readonly) + : buffer_info( + ptr, itemsize, format, ndim, std::move(shape_in), std::move(strides_in), readonly) {} + + Py_buffer *m_view = nullptr; + bool ownview = false; +}; + +PYBIND11_NAMESPACE_BEGIN(detail) + +template +struct compare_buffer_info { + static bool compare(const buffer_info &b) { + // NOLINTNEXTLINE(bugprone-sizeof-expression) Needed for `PyObject *` + return b.format == format_descriptor::format() && b.itemsize == (ssize_t) sizeof(T); + } +}; + +template +struct compare_buffer_info::value>> { + static bool compare(const buffer_info &b) { + return (size_t) b.itemsize == sizeof(T) + && (b.format == format_descriptor::value + || ((sizeof(T) == sizeof(long)) + && b.format == (std::is_unsigned::value ? "L" : "l")) + || ((sizeof(T) == sizeof(size_t)) + && b.format == (std::is_unsigned::value ? "N" : "n"))); + } +}; + +PYBIND11_NAMESPACE_END(detail) +PYBIND11_NAMESPACE_END(PYBIND11_NAMESPACE) + +#else +#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined." 
+#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) diff --git a/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/pybind11/cast.h b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/pybind11/cast.h new file mode 100644 index 0000000000000000000000000000000000000000..6949956d6ac4fa47efa9168daa1a6b82735d9f42 --- /dev/null +++ b/Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/pybind11/cast.h @@ -0,0 +1,2366 @@ +#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION) +/* + pybind11/cast.h: Partial template specializations to cast between + C++ and Python types + + Copyright (c) 2016 Wenzel Jakob + + All rights reserved. Use of this source code is governed by a + BSD-style license that can be found in the LICENSE file. +*/ + +#pragma once + +#include "detail/common.h" +#include "detail/descr.h" +#include "detail/native_enum_data.h" +#include "detail/type_caster_base.h" +#include "detail/typeid.h" +#include "pytypes.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE) + +PYBIND11_WARNING_DISABLE_MSVC(4127) + +PYBIND11_NAMESPACE_BEGIN(detail) + +template +class type_caster : public type_caster_base {}; +template +using make_caster = type_caster>; + +// Shortcut for calling a caster's `cast_op_type` cast operator for casting a type_caster to a T +template +typename make_caster::template cast_op_type cast_op(make_caster &caster) { + using result_t = typename make_caster::template cast_op_type; // See PR #4893 + return caster.operator result_t(); +} +template +typename make_caster::template cast_op_type::type> +cast_op(make_caster &&caster) { + using result_t = typename make_caster::template cast_op_type< + typename std::add_rvalue_reference::type>; // See PR #4893 + return std::move(caster).operator result_t(); +} + +template +class type_caster_enum_type { 
+private: + using Underlying = typename std::underlying_type::type; + +public: + static constexpr auto name = const_name(); + + template + static handle cast(SrcType &&src, return_value_policy, handle parent) { + handle native_enum + = global_internals_native_enum_type_map_get_item(std::type_index(typeid(EnumType))); + if (native_enum) { + return native_enum(static_cast(src)).release(); + } + return type_caster_base::cast( + std::forward(src), + // Fixes https://github.com/pybind/pybind11/pull/3643#issuecomment-1022987818: + return_value_policy::copy, + parent); + } + + template + static handle cast(SrcType *src, return_value_policy policy, handle parent) { + return cast(*src, policy, parent); + } + + bool load(handle src, bool convert) { + handle native_enum + = global_internals_native_enum_type_map_get_item(std::type_index(typeid(EnumType))); + if (native_enum) { + if (!isinstance(src, native_enum)) { + return false; + } + type_caster underlying_caster; + if (!underlying_caster.load(src.attr("value"), convert)) { + pybind11_fail("native_enum internal consistency failure."); + } + value = static_cast(static_cast(underlying_caster)); + return true; + } + if (!pybind11_enum_) { + pybind11_enum_.reset(new type_caster_base()); + } + return pybind11_enum_->load(src, convert); + } + + template + using cast_op_type = detail::cast_op_type; + + // NOLINTNEXTLINE(google-explicit-constructor) + operator EnumType *() { + if (!pybind11_enum_) { + return &value; + } + return pybind11_enum_->operator EnumType *(); + } + + // NOLINTNEXTLINE(google-explicit-constructor) + operator EnumType &() { + if (!pybind11_enum_) { + return value; + } + return pybind11_enum_->operator EnumType &(); + } + +private: + std::unique_ptr> pybind11_enum_; + EnumType value; +}; + +template +struct type_caster_enum_type_enabled : std::true_type {}; + +template +struct type_uses_type_caster_enum_type { + static constexpr bool value + = std::is_enum::value && type_caster_enum_type_enabled::value; +}; + 
+template +class type_caster::value>> + : public type_caster_enum_type {}; + +template ::value, int> = 0> +bool isinstance_native_enum_impl(handle obj, const std::type_info &tp) { + handle native_enum = global_internals_native_enum_type_map_get_item(tp); + if (!native_enum) { + return false; + } + return isinstance(obj, native_enum); +} + +template ::value, int> = 0> +bool isinstance_native_enum_impl(handle, const std::type_info &) { + return false; +} + +template +bool isinstance_native_enum(handle obj, const std::type_info &tp) { + return isinstance_native_enum_impl>(obj, tp); +} + +template +class type_caster> { +private: + using caster_t = make_caster; + caster_t subcaster; + using reference_t = type &; + using subcaster_cast_op_type = typename caster_t::template cast_op_type; + + static_assert( + std::is_same::type &, subcaster_cast_op_type>::value + || std::is_same::value, + "std::reference_wrapper caster requires T to have a caster with an " + "`operator T &()` or `operator const T &()`"); + +public: + bool load(handle src, bool convert) { return subcaster.load(src, convert); } + static constexpr auto name = caster_t::name; + static handle + cast(const std::reference_wrapper &src, return_value_policy policy, handle parent) { + // It is definitely wrong to take ownership of this pointer, so mask that rvp + if (policy == return_value_policy::take_ownership + || policy == return_value_policy::automatic) { + policy = return_value_policy::automatic_reference; + } + return caster_t::cast(&src.get(), policy, parent); + } + template + using cast_op_type = std::reference_wrapper; + explicit operator std::reference_wrapper() { return cast_op(subcaster); } +}; + +#define PYBIND11_TYPE_CASTER(type, py_name) \ +protected: \ + type value; \ + \ +public: \ + static constexpr auto name = py_name; \ + template >::value, \ + int> \ + = 0> \ + static ::pybind11::handle cast( \ + T_ *src, ::pybind11::return_value_policy policy, ::pybind11::handle parent) { \ + if (!src) \ + 
return ::pybind11::none().release(); \ + if (policy == ::pybind11::return_value_policy::take_ownership) { \ + auto h = cast(std::move(*src), policy, parent); \ + delete src; \ + return h; \ + } \ + return cast(*src, policy, parent); \ + } \ + operator type *() { return &value; } /* NOLINT(bugprone-macro-parentheses) */ \ + operator type &() { return value; } /* NOLINT(bugprone-macro-parentheses) */ \ + operator type &&() && { return std::move(value); } /* NOLINT(bugprone-macro-parentheses) */ \ + template \ + using cast_op_type = ::pybind11::detail::movable_cast_op_type + +template +using is_std_char_type = any_of, /* std::string */ +#if defined(PYBIND11_HAS_U8STRING) + std::is_same, /* std::u8string */ +#endif + std::is_same, /* std::u16string */ + std::is_same, /* std::u32string */ + std::is_same /* std::wstring */ + >; + +template +struct type_caster::value && !is_std_char_type::value>> { + using _py_type_0 = conditional_t; + using _py_type_1 = conditional_t::value, + _py_type_0, + typename std::make_unsigned<_py_type_0>::type>; + using py_type = conditional_t::value, double, _py_type_1>; + +public: + bool load(handle src, bool convert) { + py_type py_value; + + if (!src) { + return false; + } + +#if !defined(PYPY_VERSION) + auto index_check = [](PyObject *o) { return PyIndex_Check(o); }; +#else + // In PyPy 7.3.3, `PyIndex_Check` is implemented by calling `__index__`, + // while CPython only considers the existence of `nb_index`/`__index__`. + auto index_check = [](PyObject *o) { return hasattr(o, "__index__"); }; +#endif + + if (std::is_floating_point::value) { + if (convert || PyFloat_Check(src.ptr())) { + py_value = (py_type) PyFloat_AsDouble(src.ptr()); + } else { + return false; + } + } else if (PyFloat_Check(src.ptr()) + || (!convert && !PYBIND11_LONG_CHECK(src.ptr()) && !index_check(src.ptr()))) { + return false; + } else { + handle src_or_index = src; + // PyPy: 7.3.7's 3.8 does not implement PyLong_*'s __index__ calls. 
+#if defined(PYPY_VERSION) + object index; + if (!PYBIND11_LONG_CHECK(src.ptr())) { // So: index_check(src.ptr()) + index = reinterpret_steal(PyNumber_Index(src.ptr())); + if (!index) { + PyErr_Clear(); + if (!convert) + return false; + } else { + src_or_index = index; + } + } +#endif + if (std::is_unsigned::value) { + py_value = as_unsigned(src_or_index.ptr()); + } else { // signed integer: + py_value = sizeof(T) <= sizeof(long) + ? (py_type) PyLong_AsLong(src_or_index.ptr()) + : (py_type) PYBIND11_LONG_AS_LONGLONG(src_or_index.ptr()); + } + } + + // Python API reported an error + bool py_err = py_value == (py_type) -1 && PyErr_Occurred(); + + // Check to see if the conversion is valid (integers should match exactly) + // Signed/unsigned checks happen elsewhere + if (py_err + || (std::is_integral::value && sizeof(py_type) != sizeof(T) + && py_value != (py_type) (T) py_value)) { + PyErr_Clear(); + if (py_err && convert && (PyNumber_Check(src.ptr()) != 0)) { + auto tmp = reinterpret_steal(std::is_floating_point::value + ? 
PyNumber_Float(src.ptr()) + : PyNumber_Long(src.ptr())); + PyErr_Clear(); + return load(tmp, false); + } + return false; + } + + value = (T) py_value; + return true; + } + + template + static typename std::enable_if::value, handle>::type + cast(U src, return_value_policy /* policy */, handle /* parent */) { + return PyFloat_FromDouble((double) src); + } + + template + static typename std::enable_if::value && std::is_signed::value + && (sizeof(U) <= sizeof(long)), + handle>::type + cast(U src, return_value_policy /* policy */, handle /* parent */) { + return PYBIND11_LONG_FROM_SIGNED((long) src); + } + + template + static typename std::enable_if::value && std::is_unsigned::value + && (sizeof(U) <= sizeof(unsigned long)), + handle>::type + cast(U src, return_value_policy /* policy */, handle /* parent */) { + return PYBIND11_LONG_FROM_UNSIGNED((unsigned long) src); + } + + template + static typename std::enable_if::value && std::is_signed::value + && (sizeof(U) > sizeof(long)), + handle>::type + cast(U src, return_value_policy /* policy */, handle /* parent */) { + return PyLong_FromLongLong((long long) src); + } + + template + static typename std::enable_if::value && std::is_unsigned::value + && (sizeof(U) > sizeof(unsigned long)), + handle>::type + cast(U src, return_value_policy /* policy */, handle /* parent */) { + return PyLong_FromUnsignedLongLong((unsigned long long) src); + } + + PYBIND11_TYPE_CASTER(T, + io_name::value>( + "typing.SupportsInt", "int", "typing.SupportsFloat", "float")); +}; + +template +struct void_caster { +public: + bool load(handle src, bool) { + if (src && src.is_none()) { + return true; + } + return false; + } + static handle cast(T, return_value_policy /* policy */, handle /* parent */) { + return none().release(); + } + PYBIND11_TYPE_CASTER(T, const_name("None")); +}; + +template <> +class type_caster : public void_caster {}; + +template <> +class type_caster : public type_caster { +public: + using type_caster::cast; + + bool 
load(handle h, bool) { + if (!h) { + return false; + } + if (h.is_none()) { + value = nullptr; + return true; + } + + /* Check if this is a capsule */ + if (isinstance(h)) { + value = reinterpret_borrow(h); + return true; + } + + /* Check if this is a C++ type */ + const auto &bases = all_type_info((PyTypeObject *) type::handle_of(h).ptr()); + if (bases.size() == 1) { // Only allowing loading from a single-value type + value = values_and_holders(reinterpret_cast(h.ptr())).begin()->value_ptr(); + return true; + } + + /* Fail */ + return false; + } + + static handle cast(const void *ptr, return_value_policy /* policy */, handle /* parent */) { + if (ptr) { + return capsule(ptr).release(); + } + return none().release(); + } + + template + using cast_op_type = void *&; + explicit operator void *&() { return value; } + static constexpr auto name = const_name(PYBIND11_CAPSULE_TYPE_TYPE_HINT); + +private: + void *value = nullptr; +}; + +template <> +class type_caster : public void_caster {}; + +template <> +class type_caster { +public: + bool load(handle src, bool convert) { + if (!src) { + return false; + } + if (src.ptr() == Py_True) { + value = true; + return true; + } + if (src.ptr() == Py_False) { + value = false; + return true; + } + if (convert || is_numpy_bool(src)) { + // (allow non-implicit conversion for numpy booleans), use strncmp + // since NumPy 1.x had an additional trailing underscore. + + Py_ssize_t res = -1; + if (src.is_none()) { + res = 0; // None is implicitly converted to False + } +#if defined(PYPY_VERSION) + // On PyPy, check that "__bool__" attr exists + else if (hasattr(src, PYBIND11_BOOL_ATTR)) { + res = PyObject_IsTrue(src.ptr()); + } +#else + // Alternate approach for CPython: this does the same as the above, but optimized + // using the CPython API so as to avoid an unneeded attribute lookup. 
+ else if (auto *tp_as_number = Py_TYPE(src.ptr())->tp_as_number) { + if (PYBIND11_NB_BOOL(tp_as_number)) { + res = (*PYBIND11_NB_BOOL(tp_as_number))(src.ptr()); + } + } +#endif + if (res == 0 || res == 1) { + value = (res != 0); + return true; + } + PyErr_Clear(); + } + return false; + } + static handle cast(bool src, return_value_policy /* policy */, handle /* parent */) { + return handle(src ? Py_True : Py_False).inc_ref(); + } + PYBIND11_TYPE_CASTER(bool, const_name("bool")); + +private: + // Test if an object is a NumPy boolean (without fetching the type). + static inline bool is_numpy_bool(handle object) { + const char *type_name = Py_TYPE(object.ptr())->tp_name; + // Name changed to `numpy.bool` in NumPy 2, `numpy.bool_` is needed for 1.x support + return std::strcmp("numpy.bool", type_name) == 0 + || std::strcmp("numpy.bool_", type_name) == 0; + } +}; + +// Helper class for UTF-{8,16,32} C++ stl strings: +template +struct string_caster { + using CharT = typename StringType::value_type; + + // Simplify life by being able to assume standard char sizes (the standard only guarantees + // minimums, but Python requires exact sizes) + static_assert(!std::is_same::value || sizeof(CharT) == 1, + "Unsupported char size != 1"); +#if defined(PYBIND11_HAS_U8STRING) + static_assert(!std::is_same::value || sizeof(CharT) == 1, + "Unsupported char8_t size != 1"); +#endif + static_assert(!std::is_same::value || sizeof(CharT) == 2, + "Unsupported char16_t size != 2"); + static_assert(!std::is_same::value || sizeof(CharT) == 4, + "Unsupported char32_t size != 4"); + // wchar_t can be either 16 bits (Windows) or 32 (everywhere else) + static_assert(!std::is_same::value || sizeof(CharT) == 2 || sizeof(CharT) == 4, + "Unsupported wchar_t size != 2/4"); + static constexpr size_t UTF_N = 8 * sizeof(CharT); + + bool load(handle src, bool) { + handle load_src = src; + if (!src) { + return false; + } + if (!PyUnicode_Check(load_src.ptr())) { + return load_raw(load_src); + } + + // For 
UTF-8 we avoid the need for a temporary `bytes` object by using + // `PyUnicode_AsUTF8AndSize`. + if (UTF_N == 8) { + Py_ssize_t size = -1; + const auto *buffer + = reinterpret_cast(PyUnicode_AsUTF8AndSize(load_src.ptr(), &size)); + if (!buffer) { + PyErr_Clear(); + return false; + } + value = StringType(buffer, static_cast(size)); + return true; + } + + auto utfNbytes + = reinterpret_steal(PyUnicode_AsEncodedString(load_src.ptr(), + UTF_N == 8 ? "utf-8" + : UTF_N == 16 ? "utf-16" + : "utf-32", + nullptr)); + if (!utfNbytes) { + PyErr_Clear(); + return false; + } + + const auto *buffer + = reinterpret_cast(PYBIND11_BYTES_AS_STRING(utfNbytes.ptr())); + size_t length = (size_t) PYBIND11_BYTES_SIZE(utfNbytes.ptr()) / sizeof(CharT); + // Skip BOM for UTF-16/32 + if (UTF_N > 8) { + buffer++; + length--; + } + value = StringType(buffer, length); + + // If we're loading a string_view we need to keep the encoded Python object alive: + if (IsView) { + loader_life_support::add_patient(utfNbytes); + } + + return true; + } + + static handle + cast(const StringType &src, return_value_policy /* policy */, handle /* parent */) { + const char *buffer = reinterpret_cast(src.data()); + auto nbytes = ssize_t(src.size() * sizeof(CharT)); + handle s = decode_utfN(buffer, nbytes); + if (!s) { + throw error_already_set(); + } + return s; + } + + PYBIND11_TYPE_CASTER(StringType, const_name(PYBIND11_STRING_NAME)); + +private: + static handle decode_utfN(const char *buffer, ssize_t nbytes) { +#if !defined(PYPY_VERSION) + return UTF_N == 8 ? PyUnicode_DecodeUTF8(buffer, nbytes, nullptr) + : UTF_N == 16 ? PyUnicode_DecodeUTF16(buffer, nbytes, nullptr, nullptr) + : PyUnicode_DecodeUTF32(buffer, nbytes, nullptr, nullptr); +#else + // PyPy segfaults when on PyUnicode_DecodeUTF16 (and possibly on PyUnicode_DecodeUTF32 as + // well), so bypass the whole thing by just passing the encoding as a string value, which + // works properly: + return PyUnicode_Decode(buffer, + nbytes, + UTF_N == 8 ? 
"utf-8" + : UTF_N == 16 ? "utf-16" + : "utf-32", + nullptr); +#endif + } + + // When loading into a std::string or char*, accept a bytes/bytearray object as-is (i.e. + // without any encoding/decoding attempt). For other C++ char sizes this is a no-op. + // which supports loading a unicode from a str, doesn't take this path. + template + bool load_raw(enable_if_t::value, handle> src) { + if (PYBIND11_BYTES_CHECK(src.ptr())) { + // We were passed raw bytes; accept it into a std::string or char* + // without any encoding attempt. + const char *bytes = PYBIND11_BYTES_AS_STRING(src.ptr()); + if (!bytes) { + pybind11_fail("Unexpected PYBIND11_BYTES_AS_STRING() failure."); + } + value = StringType(bytes, (size_t) PYBIND11_BYTES_SIZE(src.ptr())); + return true; + } + if (PyByteArray_Check(src.ptr())) { + // We were passed a bytearray; accept it into a std::string or char* + // without any encoding attempt. + const char *bytearray = PyByteArray_AsString(src.ptr()); + if (!bytearray) { + pybind11_fail("Unexpected PyByteArray_AsString() failure."); + } + value = StringType(bytearray, (size_t) PyByteArray_Size(src.ptr())); + return true; + } + + return false; + } + + template + bool load_raw(enable_if_t::value, handle>) { + return false; + } +}; + +template +struct type_caster, + enable_if_t::value>> + : string_caster> {}; + +#ifdef PYBIND11_HAS_STRING_VIEW +template +struct type_caster, + enable_if_t::value>> + : string_caster, true> {}; +#endif + +// Type caster for C-style strings. We basically use a std::string type caster, but also add the +// ability to use None as a nullptr char* (which the string caster doesn't allow). 
+template +struct type_caster::value>> { + using StringType = std::basic_string; + using StringCaster = make_caster; + StringCaster str_caster; + bool none = false; + CharT one_char = 0; + +public: + bool load(handle src, bool convert) { + if (!src) { + return false; + } + if (src.is_none()) { + // Defer accepting None to other overloads (if we aren't in convert mode): + if (!convert) { + return false; + } + none = true; + return true; + } + return str_caster.load(src, convert); + } + + static handle cast(const CharT *src, return_value_policy policy, handle parent) { + if (src == nullptr) { + return pybind11::none().release(); + } + return StringCaster::cast(StringType(src), policy, parent); + } + + static handle cast(CharT src, return_value_policy policy, handle parent) { + if (std::is_same::value) { + handle s = PyUnicode_DecodeLatin1((const char *) &src, 1, nullptr); + if (!s) { + throw error_already_set(); + } + return s; + } + return StringCaster::cast(StringType(1, src), policy, parent); + } + + explicit operator CharT *() { + return none ? nullptr : const_cast(static_cast(str_caster).c_str()); + } + explicit operator CharT &() { + if (none) { + throw value_error("Cannot convert None to a character"); + } + + auto &value = static_cast(str_caster); + size_t str_len = value.size(); + if (str_len == 0) { + throw value_error("Cannot convert empty string to a character"); + } + + // If we're in UTF-8 mode, we have two possible failures: one for a unicode character that + // is too high, and one for multiple unicode characters (caught later), so we need to + // figure out how long the first encoded character is in bytes to distinguish between these + // two errors. We also allow want to allow unicode characters U+0080 through U+00FF, as + // those can fit into a single char value. 
+ if (StringCaster::UTF_N == 8 && str_len > 1 && str_len <= 4) { + auto v0 = static_cast(value[0]); + // low bits only: 0-127 + // 0b110xxxxx - start of 2-byte sequence + // 0b1110xxxx - start of 3-byte sequence + // 0b11110xxx - start of 4-byte sequence + size_t char0_bytes = (v0 & 0x80) == 0 ? 1 + : (v0 & 0xE0) == 0xC0 ? 2 + : (v0 & 0xF0) == 0xE0 ? 3 + : 4; + + if (char0_bytes == str_len) { + // If we have a 128-255 value, we can decode it into a single char: + if (char0_bytes == 2 && (v0 & 0xFC) == 0xC0) { // 0x110000xx 0x10xxxxxx + one_char = static_cast(((v0 & 3) << 6) + + (static_cast(value[1]) & 0x3F)); + return one_char; + } + // Otherwise we have a single character, but it's > U+00FF + throw value_error("Character code point not in range(0x100)"); + } + } + + // UTF-16 is much easier: we can only have a surrogate pair for values above U+FFFF, thus a + // surrogate pair with total length 2 instantly indicates a range error (but not a "your + // string was too long" error). + else if (StringCaster::UTF_N == 16 && str_len == 2) { + one_char = static_cast(value[0]); + if (one_char >= 0xD800 && one_char < 0xE000) { + throw value_error("Character code point not in range(0x10000)"); + } + } + + if (str_len != 1) { + throw value_error("Expected a character, but multi-character string found"); + } + + one_char = value[0]; + return one_char; + } + + static constexpr auto name = const_name(PYBIND11_STRING_NAME); + template + using cast_op_type = pybind11::detail::cast_op_type<_T>; +}; + +// Base implementation for std::tuple and std::pair +template